├── assets └── framework.png ├── ovdetr ├── datasets │ ├── torchvision_datasets │ │ ├── __init__.py │ │ ├── lvis.py │ │ └── coco.py │ ├── __init__.py │ ├── data_prefetcher.py │ ├── samplers.py │ ├── lvis.py │ ├── coco.py │ ├── transforms.py │ └── coco_eval.py ├── models │ ├── ops │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn.py │ │ ├── make.sh │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn_func.py │ │ ├── src │ │ │ ├── vision.cpp │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_attn_cuda.cu │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.h │ │ │ │ └── ms_deform_attn_cpu.cpp │ │ │ └── ms_deform_attn.h │ │ ├── setup.py │ │ └── test.py │ ├── __init__.py │ ├── post_process.py │ ├── position_encoding.py │ ├── backbone.py │ ├── matcher.py │ ├── segmentation.py │ └── model.py ├── util │ ├── __init__.py │ ├── pos_embed.py │ ├── coco_categories.py │ ├── box_ops.py │ ├── clip_utils.py │ └── misc.py ├── scripts │ └── save_clip_features.py └── engine_ov.py ├── dataset_prepare.md ├── .gitignore ├── run_scripts.md └── README.md /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuhangzang/OV-DETR/HEAD/assets/framework.png -------------------------------------------------------------------------------- /ovdetr/datasets/torchvision_datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | from .coco import CocoDetection 7 | from .lvis import LvisDetection 8 | -------------------------------------------------------------------------------- /ovdetr/models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /ovdetr/models/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /ovdetr/models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | -------------------------------------------------------------------------------- /ovdetr/util/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # OV DETR 3 | # Copyright (c) S-LAB, Nanyang Technological University. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | # Modified from DETR (https://github.com/facebookresearch/detr) 9 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 10 | # ------------------------------------------------------------------------ 11 | -------------------------------------------------------------------------------- /dataset_prepare.md: -------------------------------------------------------------------------------- 1 | ## Open-Vocabulary COCO 2 | 3 | 1. Download the [COCO](https://cocodataset.org/#home) dataset. 4 | 5 | 2. Create the annotation jsons for the open-vocabulary setting. We use the scripts provide by [OVR-CNN](https://github.com/alirezazareian/ovr-cnn/blob/master/ipynb/003.ipynb). Then add the object proposals that may cover the novel classes. You can download our pre-generated json file in [Google Drive](https://drive.google.com/file/d/1O_RU6k_s3UI74RFcpxAyIQhHmnmbRdIe/view?usp=sharing). 6 | 7 | 3. Extract the CLIP image features. You can download our pre-generated file in [Google Drive](https://drive.google.com/file/d/1nZJcr0Rl1Osy6qxbNPd1eIgZiZ0Warc6/view?usp=sharing), or use the [script](./ovdetr/scripts/save_clip_features.py) to extract it by yourself. 8 | 9 | ## Open-Vocabulary LVIS 10 | 11 | Under preparation. 12 | -------------------------------------------------------------------------------- /ovdetr/models/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # output dir 2 | output 3 | instant_test_output 4 | inference_test_output 5 | 6 | # dataset 7 | data 8 | logs 9 | exps 10 | 11 | *.png 12 | *.json 13 | *.diff 14 | *.jpg 15 | !/projects/DensePose/doc/images/*.jpg 16 | 17 | # compilation and distribution 18 | __pycache__ 19 | _ext 20 | *.pyc 21 | *.pyd 22 | *.so 23 | *.dll 24 | *.egg-info/ 25 | build/ 26 | dist/ 27 | wheels/ 28 | 29 | # pytorch/python/numpy formats 30 | *.pth 31 | *.pkl 32 | *.npy 33 | *.ts 34 | model_ts*.txt 35 | 36 | # ipython/jupyter notebooks 37 | *.ipynb 38 | **/.ipynb_checkpoints/ 39 | .ipynb_checkpoints 40 | */.ipynb_checkpoints/* 41 | 42 | # Editor temporaries 43 | *.swn 44 | *.swo 45 | *.swp 46 | *~ 47 | 48 | # editor settings 49 | .idea 50 | .vscode 51 | _darcs 52 | 53 | # project dirs 54 | /detectron2/model_zoo/configs 55 | /datasets/* 56 | !/datasets/*.* 57 | /projects/*/datasets 58 | /models 59 | /snippet 60 | -------------------------------------------------------------------------------- /ovdetr/models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /ovdetr/models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. 
All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /ovdetr/models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /ovdetr/util/pos_embed.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | 5 | 6 | def gen_sineembed_for_position(pos_tensor): 7 | # n_query, bs, _ = pos_tensor.size() 8 | # sineembed_tensor = torch.zeros(n_query, bs, 256) 9 | scale = 2 * math.pi 10 | dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) 11 | dim_t = 10000 ** (2 * (dim_t // 2) / 128) 12 | x_embed = pos_tensor[:, :, 0] * scale 13 | y_embed = pos_tensor[:, :, 1] * scale 14 | pos_x = x_embed[:, :, None] / dim_t 15 | pos_y = y_embed[:, :, None] / dim_t 16 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 17 | pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) 18 | if pos_tensor.size(-1) == 2: 19 | pos = torch.cat((pos_y, pos_x), dim=2) 20 | elif 
pos_tensor.size(-1) == 4: 21 | w_embed = pos_tensor[:, :, 2] * scale 22 | pos_w = w_embed[:, :, None] / dim_t 23 | pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2) 24 | 25 | h_embed = pos_tensor[:, :, 3] * scale 26 | pos_h = h_embed[:, :, None] / dim_t 27 | pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2) 28 | 29 | pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) 30 | else: 31 | raise ValueError("Unknown pos_tensor shape(-1):{}".format(pos_tensor.size(-1))) 32 | return pos 33 | -------------------------------------------------------------------------------- /ovdetr/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR 3 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | 9 | import torch.utils.data 10 | import torchvision 11 | 12 | from .coco import build as build_coco 13 | from .lvis import build as build_lvis 14 | from .torchvision_datasets import CocoDetection, LvisDetection 15 | 16 | 17 | def get_coco_api_from_dataset(dataset): 18 | for _ in range(10): 19 | # if isinstance(dataset, torchvision.datasets.CocoDetection): 20 | # break 21 | if isinstance(dataset, torch.utils.data.Subset): 22 | dataset = dataset.dataset 23 | if isinstance(dataset, CocoDetection): 24 | return dataset.coco 25 | elif isinstance(dataset, LvisDetection): 26 | return dataset.lvis 27 | 28 | 29 | def build_dataset(image_set, args): 30 | if args.dataset_file == "coco": 31 | return build_coco(image_set, args) 32 | elif args.dataset_file == "coco_panoptic": 33 | # to avoid making panopticapi required for coco 34 | from .coco_panoptic import build as build_coco_panoptic 35 | 36 | return build_coco_panoptic(image_set, args) 37 | elif args.dataset_file == "lvis": 38 | return build_lvis(image_set, args) 39 | else: 40 | raise ValueError(f"dataset {args.dataset_file} not supported") 41 | -------------------------------------------------------------------------------- /ovdetr/scripts/save_clip_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from clip import clip 4 | from PIL import Image, ImageOps 5 | from tqdm import tqdm 6 | import json 7 | from collections import defaultdict 8 | 9 | 10 | device = "cuda" if torch.cuda.is_available() else "cpu" 11 | model, preprocess = clip.load("ViT-B/32", device=device) 12 | for _, param in model.named_parameters(): 13 | param.requires_grad = False 14 | 15 | # Json and COCO dataset dir path 16 | json_path = 'xxx/instances_train2017_seen_2_proposal.json' 17 | file_dir = "xxx/coco/train2017/" 18 | save_path = "xxx/coco/zero-shot/clip_feat.pkl" 19 | 20 | with open(json_path, "r") as f: 21 | data = json.load(f) 22 | 23 | img2ann_gt = defaultdict(list) 24 | for temp in data['annotations']: 25 | img2ann_gt[temp['image_id']].append(temp) 26 | 27 | dic = {} 28 | for image_id in tqdm(img2ann_gt.keys()): 29 | file_name = file_dir + f"{image_id}".zfill(12) + ".jpg" 30 | image = Image.open(file_name).convert("RGB") 31 | 32 | for value in 
img2ann_gt[image_id]: 33 | ind = value['id'] 34 | bbox = copy.deepcopy(value['bbox']) 35 | if (bbox[1] < 16) or (bbox[2] < 16): 36 | continue 37 | bbox[2] += bbox[0] 38 | bbox[3] += bbox[1] 39 | roi = preprocess(image.crop(bbox)).to(device).unsqueeze(0) 40 | roi_features = model.encode_image(roi) 41 | 42 | category_id = value['category_id'] 43 | 44 | if category_id in dic.keys(): 45 | dic[category_id].append(roi_features) 46 | else: 47 | dic[category_id] = [roi_features] 48 | 49 | 50 | for key in dic.keys(): 51 | dic[key] = torch.cat(dic[key], 0) 52 | 53 | torch.save(dic, save_path) 54 | -------------------------------------------------------------------------------- /ovdetr/models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /ovdetr/util/coco_categories.py: -------------------------------------------------------------------------------- 1 | COCO_CATEGORIES = { 2 | 1: "person", 3 | 2: "bicycle", 4 | 3: "car", 5 | 4: "motorcycle", 6 | 5: "airplane", 7 | 6: "bus", 8 | 7: "train", 9 | 8: "truck", 10 | 9: "boat", 11 | 10: "traffic light", 12 | 11: "fire hydrant", 13 | 12: "street sign", 14 | 13: "stop sign", 15 | 14: "parking meter", 16 | 15: "bench", 17 | 16: "bird", 18 | 17: "cat", 19 | 18: "dog", 20 | 19: "horse", 21 | 20: "sheep", 22 | 21: "cow", 23 | 22: "elephant", 24 | 23: "bear", 25 | 24: "zebra", 26 | 25: "giraffe", 27 | 26: "hat", 28 | 27: "backpack", 29 | 28: "umbrella", 30 | 29: "shoe", 31 | 30: 
"eye glasses", 32 | 31: "handbag", 33 | 32: "tie", 34 | 33: "suitcase", 35 | 34: "frisbee", 36 | 35: "skis", 37 | 36: "snowboard", 38 | 37: "sports ball", 39 | 38: "kite", 40 | 39: "baseball bat", 41 | 40: "baseball glove", 42 | 41: "skateboard", 43 | 42: "surfboard", 44 | 43: "tennis racket", 45 | 44: "bottle", 46 | 45: "plate", 47 | 46: "wine glass", 48 | 47: "cup", 49 | 48: "fork", 50 | 49: "knife", 51 | 50: "spoon", 52 | 51: "bowl", 53 | 52: "banana", 54 | 53: "apple", 55 | 54: "sandwich", 56 | 55: "orange", 57 | 56: "broccoli", 58 | 57: "carrot", 59 | 58: "hot dog", 60 | 59: "pizza", 61 | 60: "donut", 62 | 61: "cake", 63 | 62: "chair", 64 | 63: "couch", 65 | 64: "potted plant", 66 | 65: "bed", 67 | 66: "mirror", 68 | 67: "dining table", 69 | 68: "window", 70 | 69: "desk", 71 | 70: "toilet", 72 | 71: "door", 73 | 72: "tv", 74 | 73: "laptop", 75 | 74: "computer mouse", 76 | 75: "remote", 77 | 76: "keyboard", 78 | 77: "cell phone", 79 | 78: "microwave", 80 | 79: "oven", 81 | 80: "toaster", 82 | 81: "sink", 83 | 82: "refrigerator", 84 | 83: "blender", 85 | 84: "book", 86 | 85: "clock", 87 | 86: "vase", 88 | 87: "scissors", 89 | 88: "teddy bear", 90 | 89: "hair drier", 91 | 90: "toothbrush", 92 | 91: "hair brush", 93 | } 94 | -------------------------------------------------------------------------------- /run_scripts.md: -------------------------------------------------------------------------------- 1 | We use the same scripts as [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR). 2 | If you are not familiar with DETR series papers, you are recommend to first read the documents of [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR) and [DETR](https://github.com/facebookresearch/detr). 3 | 4 | --- 5 | For example, to train the model on single node and 8 GPUs: 6 | 7 | ``` 8 | python -m torch.distributed.launch \ 9 | --nproc_per_node=8 \ 10 | --use_env main.py \ 11 | --dataset_file coco \ 12 | --coco_path xxxx/COCO/ \ 13 | --output_dir ./output/ \ 14 | --num_queries 300 \ 15 | --with_box_refine \ 16 | --two_stage \ 17 | --label_map \ 18 | --max_len 15 \ 19 | --prob 0.75 \ 20 | --clip_feat_path xxxx/clip_feat_coco.pkl \ 21 | ``` 22 | 23 | The meaning of these arguments: 24 | * `dataset_file`: `coco` or `lvis`. 25 | * `coco_path` / `--lvis_path`: the dataset directory path. 26 | * `label_map`: mapping the default categories ids (e.g, 0-91 for COCO) to a contiguous array (e.g, 0-64 for the open-vocabulary COCO setting). 27 | * `output_dir`: path to the log and checkpoint files. 28 | * `max_len`: the symbol `R` in the paper to control the repeat times for object queries. You are recommended to use large value to reduce the convergence time. 29 | * `prob`: the probability of selecting CLIP text or image features for conditional matching. `prob=1.0` refers to merely use the CLIP text embeddings. 30 | * `clip_feat_path`: path to the pre-computed file of CLIP image features. 31 | 32 | --- 33 | To evaluate the model, you need add two arguments, `--eval` and `--resume` (same as Deformable DETR): 34 | ``` 35 | python -m torch.distributed.launch \ 36 | --nproc_per_node=8 \ 37 | --use_env main.py \ 38 | --dataset_file coco \ 39 | --coco_path xxxx/COCO/ \ 40 | --output_dir ./output/ \ 41 | --num_queries 300 \ 42 | --with_box_refine \ 43 | --two_stage \ 44 | --label_map \ 45 | --eval \ 46 | --resume xxx/checkpoint.pth \ 47 | ``` 48 | 49 | * `resume`: path to the checkpoint file. 
50 | 51 | --- 52 | To train on the instance segmentation task, you need to add two arguments, `--masks` and `--frozen_weights` (same as DETR): 53 | ``` 54 | python -m torch.distributed.launch \ 55 | --nproc_per_node=8 \ 56 | --use_env main.py \ 57 | --dataset_file coco \ 58 | --coco_path xxxx/COCO/ \ 59 | --output_dir ./output/ \ 60 | --num_queries 300 \ 61 | --with_box_refine \ 62 | --two_stage \ 63 | --label_map \ 64 | --masks \ 65 | --frozen_weights xxx/checkpoint.pth \ 66 | ``` 67 | 68 | * `frozen_weights`: path to the pretrained model. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 

# Open-Vocabulary DETR with Conditional Matching 2 | 3 | 4 | arXiv | 5 | Project Page | 6 | Code 7 | 8 | 9 | This repository contains the implementation of the following paper: 10 | > **Open-Vocabulary DETR with Conditional Matching** 11 | > Yuhang Zang, Wei Li, Kaiyang Zhou, Chen Huang, Chen Change Loy 12 | > European Conference on Computer Vision (**ECCV**), 2022 13 | 14 | (framework figure: assets/framework.png) 15 | 16 | 
17 | 18 | ## Installation 19 | 20 | We use the same environment as [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR). 21 | You are also required to install the following packages: 22 | 23 | - [CLIP](https://github.com/openai/CLIP) 24 | - [cocoapi](https://github.com/cocodataset/cocoapi) 25 | - [lvis-api](https://github.com/lvis-dataset/lvis-api) 26 | 27 | We test our models with ```python=3.8, pytorch=1.11.0, cuda=10.1``` on 8 NVIDIA V100 32GB GPUs. 28 | 29 | ## Data 30 | Please refer to [dataset_prepare.md](./dataset_prepare.md). 31 | 32 | ## Running the Model 33 | Please refer to [run_scripts.md](./run_scripts.md). 34 | 35 | ## Model Zoo 36 | - Open-vocabulary COCO (AP50 metric) 37 | 38 | | Base | Novel | All | Model | 39 | |------|------|-----|-------| 40 | | 61.0 | 29.4 | 52.7 | [Google Drive](https://drive.google.com/file/d/1_iypFgVsLQwXVrT5zDtKeFaxOcC_A3uO/view?usp=sharing) | 41 | 42 | ## Citation 43 | If you find our work useful for your research, please consider citing the paper: 44 | ``` 45 | @InProceedings{zang2022open, 46 |   author = {Zang, Yuhang and Li, Wei and Zhou, Kaiyang and Huang, Chen and Loy, Chen Change}, 47 |   title = {Open-Vocabulary DETR with Conditional Matching}, 48 |   booktitle = {European Conference on Computer Vision}, 49 |   year = {2022} 50 | } 51 | ``` 52 | 53 | ## License 54 | 
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. 55 | 56 | ## Acknowledgement 57 | We would like to thanks [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR), [CLIP](https://github.com/openai/CLIP) and [ViLD](https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild) for their open-source projects. 58 | 59 | ## Contact 60 | Please contact [Yuhang Zang](mailto:zang0012@ntu.edu.sg) if you have any questions. 61 | -------------------------------------------------------------------------------- /ovdetr/models/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import glob 10 | import os 11 | 12 | import torch 13 | from setuptools import find_packages, setup 14 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 15 | 16 | requirements = ["torch", "torchvision"] 17 | 18 | 19 | def get_extensions(): 20 | this_dir = os.path.dirname(os.path.abspath(__file__)) 21 | extensions_dir = os.path.join(this_dir, "src") 22 | 23 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 24 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 25 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 26 | 27 | sources = main_file + source_cpu 28 | extension = CppExtension 29 | extra_compile_args = {"cxx": []} 30 | define_macros = [] 31 | 32 | if torch.cuda.is_available() and CUDA_HOME is not None: 33 | extension = CUDAExtension 34 | sources += source_cuda 35 | define_macros += [("WITH_CUDA", None)] 36 | extra_compile_args["nvcc"] = [ 37 | "-DCUDA_HAS_FP16=1", 38 | "-D__CUDA_NO_HALF_OPERATORS__", 39 | "-D__CUDA_NO_HALF_CONVERSIONS__", 40 | "-D__CUDA_NO_HALF2_OPERATORS__", 41 | ] 42 | else: 43 | raise NotImplementedError("Cuda is not availabel") 44 | 45 | sources = [os.path.join(extensions_dir, s) for s in sources] 46 | include_dirs = [extensions_dir] 47 | ext_modules = [ 48 | extension( 49 | "MultiScaleDeformableAttention", 50 | sources, 51 | include_dirs=include_dirs, 52 | define_macros=define_macros, 53 | extra_compile_args=extra_compile_args, 54 | ) 55 | ] 56 | return ext_modules 57 | 58 | 59 | setup( 60 | name="MultiScaleDeformableAttention", 61 | version="1.0", 62 | author="Weijie Su", 63 | url="https://github.com/fundamentalvision/Deformable-DETR", 64 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 65 | packages=find_packages( 66 | exclude=( 67 | "configs", 68 | "tests", 69 | ) 70 | ), 71 | ext_modules=get_extensions(), 72 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 73 | ) 74 | -------------------------------------------------------------------------------- /ovdetr/datasets/torchvision_datasets/lvis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | from io import BytesIO 4 | 5 | import tqdm 
6 | from PIL import Image 7 | from torchvision.datasets.vision import VisionDataset 8 | 9 | 10 | class LvisDetection(VisionDataset): 11 | """`LVIS Dataset. 12 | Args: 13 | root (string): Root directory where images are downloaded to. 14 | annFile (string): Path to json annotation file. 15 | transform (callable, optional): A function/transform that takes in an PIL image 16 | and returns a transformed version. E.g, ``transforms.ToTensor`` 17 | target_transform (callable, optional): A function/transform that takes in the 18 | target and transforms it. 19 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 20 | and returns a transformed version. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | root, 26 | annFile, 27 | transform=None, 28 | target_transform=None, 29 | transforms=None, 30 | cache_mode=False, 31 | local_rank=0, 32 | local_size=1, 33 | ): 34 | super(LvisDetection, self).__init__(root, transforms, transform, target_transform) 35 | from lvis import LVIS 36 | 37 | self.lvis = LVIS(annFile) 38 | self.ids = list(sorted(self.lvis.imgs.keys())) 39 | self.cache_mode = cache_mode 40 | self.local_rank = local_rank 41 | self.local_size = local_size 42 | if cache_mode: 43 | self.cache = {} 44 | self.cache_images() 45 | 46 | def cache_images(self): 47 | self.cache = {} 48 | for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids): 49 | if index % self.local_size != self.local_rank: 50 | continue 51 | path = self.lvis.load_imgs(img_id)[0]["file_name"] 52 | with open(os.path.join(self.root, path), "rb") as f: 53 | self.cache[path] = f.read() 54 | 55 | def get_image(self, path): 56 | if self.cache_mode: 57 | if path not in self.cache.keys(): 58 | with open(os.path.join(self.root, path), "rb") as f: 59 | self.cache[path] = f.read() 60 | return Image.open(BytesIO(self.cache[path])).convert("RGB") 61 | return Image.open(os.path.join(self.root, path)).convert("RGB") 62 | 63 | def __getitem__(self, index): 64 | lvis = self.lvis 65 | img_id = self.ids[index] 66 | ann_ids = lvis.get_ann_ids(img_ids=[img_id]) 67 | target = lvis.load_anns(ann_ids) 68 | 69 | split_folder, file_name = lvis.load_imgs([img_id])[0]["coco_url"].split("/")[-2:] 70 | path = os.path.join(split_folder, file_name) 71 | 72 | img = self.get_image(path) 73 | if self.transforms is not None: 74 | img, target = self.transforms(img, target) 75 | 76 | return img, target 77 | 78 | def __len__(self): 79 | return len(self.ids) 80 | -------------------------------------------------------------------------------- /ovdetr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from util.clip_utils import build_text_embedding_coco, build_text_embedding_lvis 4 | 5 | from .backbone import build_backbone 6 | from .deformable_transformer import build_deforamble_transformer 7 | from .matcher import build_matcher 8 | from .model import OVDETR 9 | from .post_process import OVPostProcess, PostProcess, PostProcessSegm 10 | from .segmentation import DETRsegm 11 | from .set_criterion import OVSetCriterion 12 | 13 | 14 | def build_model(args): 15 | if args.dataset_file == "coco": 16 | num_classes = 91 17 | elif args.dataset_file == "lvis": 18 | num_classes = 1204 19 | else: 20 | raise NotImplementedError 21 | 22 | device = torch.device(args.device) 23 | 24 | backbone = build_backbone(args) 25 | 26 | transformer = build_deforamble_transformer(args) 27 | 28 | if args.dataset_file == "coco": 29 | zeroshot_w = build_text_embedding_coco() 
30 | elif args.dataset_file == "lvis": 31 | zeroshot_w = build_text_embedding_lvis() 32 | else: 33 | raise NotImplementedError 34 | model = OVDETR( 35 | backbone, 36 | transformer, 37 | num_classes=num_classes, 38 | num_queries=args.num_queries, 39 | num_feature_levels=args.num_feature_levels, 40 | aux_loss=args.aux_loss, 41 | with_box_refine=args.with_box_refine, 42 | two_stage=args.two_stage, 43 | cls_out_channels=1, 44 | dataset_file=args.dataset_file, 45 | zeroshot_w=zeroshot_w, 46 | max_len=args.max_len, 47 | clip_feat_path=args.clip_feat_path, 48 | prob=args.prob, 49 | ) 50 | 51 | if args.masks: 52 | model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None)) 53 | 54 | matcher = build_matcher(args) 55 | weight_dict = {"loss_ce": args.cls_loss_coef, "loss_bbox": args.bbox_loss_coef} 56 | weight_dict["loss_giou"] = args.giou_loss_coef 57 | weight_dict["loss_embed"] = args.feature_loss_coef 58 | if args.masks: 59 | weight_dict["loss_mask"] = args.mask_loss_coef 60 | weight_dict["loss_dice"] = args.dice_loss_coef 61 | # TODO this is a hack 62 | if args.aux_loss: 63 | aux_weight_dict = {} 64 | for i in range(args.dec_layers - 1): 65 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 66 | aux_weight_dict.update({k + "_enc": v for k, v in weight_dict.items()}) 67 | weight_dict.update(aux_weight_dict) 68 | 69 | losses = ["labels", "boxes", "embed"] 70 | if args.masks: 71 | losses = ["labels", "boxes", "masks"] 72 | 73 | criterion = OVSetCriterion( 74 | num_classes, 75 | matcher, 76 | weight_dict, 77 | losses, 78 | focal_alpha=args.focal_alpha, 79 | ) 80 | postprocessors = {"bbox": OVPostProcess(num_queries=args.num_queries)} 81 | criterion.to(device) 82 | 83 | if args.masks: 84 | postprocessors["segm"] = PostProcessSegm() 85 | 86 | return model, criterion, postprocessors 87 | -------------------------------------------------------------------------------- /ovdetr/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | 9 | """ 10 | Utilities for bounding box manipulation and GIoU. 
11 | """ 12 | import torch 13 | from torchvision.ops.boxes import box_area 14 | 15 | 16 | def box_cxcywh_to_xyxy(x): 17 | x_c, y_c, w, h = x.unbind(-1) 18 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 19 | return torch.stack(b, dim=-1) 20 | 21 | 22 | def box_xyxy_to_cxcywh(x): 23 | x0, y0, x1, y1 = x.unbind(-1) 24 | b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] 25 | return torch.stack(b, dim=-1) 26 | 27 | 28 | # modified from torchvision to also return the union 29 | def box_iou(boxes1, boxes2): 30 | area1 = box_area(boxes1) 31 | area2 = box_area(boxes2) 32 | 33 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 34 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 35 | 36 | wh = (rb - lt).clamp(min=0) # [N,M,2] 37 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 38 | 39 | union = area1[:, None] + area2 - inter 40 | 41 | iou = inter / union 42 | return iou, union 43 | 44 | 45 | def generalized_box_iou(boxes1, boxes2): 46 | """ 47 | Generalized IoU from https://giou.stanford.edu/ 48 | 49 | The boxes should be in [x0, y0, x1, y1] format 50 | 51 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 52 | and M = len(boxes2) 53 | """ 54 | # degenerate boxes gives inf / nan results 55 | # so do an early check 56 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 57 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 58 | iou, union = box_iou(boxes1, boxes2) 59 | 60 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 61 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 62 | 63 | wh = (rb - lt).clamp(min=0) # [N,M,2] 64 | area = wh[:, :, 0] * wh[:, :, 1] 65 | 66 | return iou - (area - union) / area 67 | 68 | 69 | def masks_to_boxes(masks): 70 | """Compute the bounding boxes around the provided masks 71 | 72 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 73 | 74 | Returns a [N, 4] tensors, with the boxes in xyxy format 75 | """ 76 | if masks.numel() == 0: 77 | return torch.zeros((0, 4), device=masks.device) 78 | 79 | h, w = masks.shape[-2:] 80 | 81 | y = torch.arange(0, h, dtype=torch.float) 82 | x = torch.arange(0, w, dtype=torch.float) 83 | y, x = torch.meshgrid(y, x) 84 | 85 | x_mask = masks * x.unsqueeze(0) 86 | x_max = x_mask.flatten(1).max(-1)[0] 87 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 88 | 89 | y_mask = masks * y.unsqueeze(0) 90 | y_max = y_mask.flatten(1).max(-1)[0] 91 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 92 | 93 | return torch.stack([x_min, y_min, x_max, y_max], 1) 94 | -------------------------------------------------------------------------------- /ovdetr/datasets/data_prefetcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | 6 | import torch 7 | 8 | 9 | def to_cuda(samples, targets, device): 10 | samples = samples.to(device, non_blocking=True) 11 | targets = [{k: v.to(device, non_blocking=True) for k, v in t.items()} for t in targets] 12 | return samples, targets 13 | 14 | 15 | class data_prefetcher: 16 | def __init__(self, loader, device, prefetch=True): 17 | self.loader = iter(loader) 18 | self.prefetch = prefetch 19 | self.device = device 20 | if prefetch: 21 | self.stream = torch.cuda.Stream() 22 | self.preload() 23 | 24 | def preload(self): 25 | try: 26 | self.next_samples, self.next_targets = next(self.loader) 27 | except StopIteration: 28 | self.next_samples = None 29 | self.next_targets = None 30 | return 31 | # if record_stream() doesn't work, another option is to make sure device inputs are created 32 | # on the main stream. 33 | # self.next_input_gpu = torch.empty_like(self.next_input, device='cuda') 34 | # self.next_target_gpu = torch.empty_like(self.next_target, device='cuda') 35 | # Need to make sure the memory allocated for next_* is not still in use by the main stream 36 | # at the time we start copying to next_*: 37 | # self.stream.wait_stream(torch.cuda.current_stream()) 38 | with torch.cuda.stream(self.stream): 39 | self.next_samples, self.next_targets = to_cuda( 40 | self.next_samples, self.next_targets, self.device 41 | ) 42 | # more code for the alternative if record_stream() doesn't work: 43 | # copy_ will record the use of the pinned source tensor in this side stream. 44 | # self.next_input_gpu.copy_(self.next_input, non_blocking=True) 45 | # self.next_target_gpu.copy_(self.next_target, non_blocking=True) 46 | # self.next_input = self.next_input_gpu 47 | # self.next_target = self.next_target_gpu 48 | 49 | # With Amp, it isn't necessary to manually convert data to half. 50 | # if args.fp16: 51 | # self.next_input = self.next_input.half() 52 | # else: 53 | 54 | def next(self): 55 | if self.prefetch: 56 | torch.cuda.current_stream().wait_stream(self.stream) 57 | samples = self.next_samples 58 | targets = self.next_targets 59 | if samples is not None: 60 | samples.record_stream(torch.cuda.current_stream()) 61 | if targets is not None: 62 | for t in targets: 63 | for k, v in t.items(): 64 | v.record_stream(torch.cuda.current_stream()) 65 | self.preload() 66 | else: 67 | try: 68 | samples, targets = next(self.loader) 69 | samples, targets = to_cuda(samples, targets, self.device) 70 | except StopIteration: 71 | samples = None 72 | targets = None 73 | return samples, targets 74 | -------------------------------------------------------------------------------- /ovdetr/datasets/torchvision_datasets/coco.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from torchvision 6 | # ------------------------------------------------------------------------ 7 | 8 | """ 9 | Copy-Paste from torchvision, but add utility of caching images on memory 10 | """ 11 | import os 12 | import os.path 13 | from io import BytesIO 14 | 15 | import tqdm 16 | from PIL import Image 17 | from torchvision.datasets.vision import VisionDataset 18 | 19 | 20 | class CocoDetection(VisionDataset): 21 | """`MS Coco Detection `_ Dataset. 
22 | Args: 23 | root (string): Root directory where images are downloaded to. 24 | annFile (string): Path to json annotation file. 25 | transform (callable, optional): A function/transform that takes in an PIL image 26 | and returns a transformed version. E.g, ``transforms.ToTensor`` 27 | target_transform (callable, optional): A function/transform that takes in the 28 | target and transforms it. 29 | transforms (callable, optional): A function/transform that takes input sample and its target as entry 30 | and returns a transformed version. 31 | """ 32 | 33 | def __init__( 34 | self, 35 | root, 36 | annFile, 37 | transform=None, 38 | target_transform=None, 39 | transforms=None, 40 | cache_mode=False, 41 | local_rank=0, 42 | local_size=1, 43 | ): 44 | super(CocoDetection, self).__init__(root, transforms, transform, target_transform) 45 | from pycocotools.coco import COCO 46 | 47 | self.coco = COCO(annFile) 48 | self.ids = list(sorted(self.coco.imgs.keys())) 49 | self.cache_mode = cache_mode 50 | self.local_rank = local_rank 51 | self.local_size = local_size 52 | if cache_mode: 53 | self.cache = {} 54 | self.cache_images() 55 | 56 | def cache_images(self): 57 | self.cache = {} 58 | for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids): 59 | if index % self.local_size != self.local_rank: 60 | continue 61 | path = self.coco.loadImgs(img_id)[0]["file_name"] 62 | with open(os.path.join(self.root, path), "rb") as f: 63 | self.cache[path] = f.read() 64 | 65 | def get_image(self, path): 66 | if self.cache_mode: 67 | if path not in self.cache.keys(): 68 | with open(os.path.join(self.root, path), "rb") as f: 69 | self.cache[path] = f.read() 70 | return Image.open(BytesIO(self.cache[path])).convert("RGB") 71 | return Image.open(os.path.join(self.root, path)).convert("RGB") 72 | 73 | def __getitem__(self, index): 74 | """ 75 | Args: 76 | index (int): Index 77 | Returns: 78 | tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``. 79 | """ 80 | coco = self.coco 81 | img_id = self.ids[index] 82 | ann_ids = coco.getAnnIds(imgIds=img_id) 83 | target = coco.loadAnns(ann_ids) 84 | 85 | path = coco.loadImgs(img_id)[0]["file_name"] 86 | 87 | img = self.get_image(path) 88 | if self.transforms is not None: 89 | img, target = self.transforms(img, target) 90 | 91 | return img, target 92 | 93 | def __len__(self): 94 | return len(self.ids) 95 | -------------------------------------------------------------------------------- /ovdetr/models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import, division, print_function 10 | 11 | import MultiScaleDeformableAttention as MSDA 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Function 15 | from torch.autograd.function import once_differentiable 16 | 17 | 18 | class MSDeformAttnFunction(Function): 19 | @staticmethod 20 | def forward( 21 | ctx, 22 | value, 23 | value_spatial_shapes, 24 | value_level_start_index, 25 | sampling_locations, 26 | attention_weights, 27 | im2col_step, 28 | ): 29 | ctx.im2col_step = im2col_step 30 | output = MSDA.ms_deform_attn_forward( 31 | value, 32 | value_spatial_shapes, 33 | value_level_start_index, 34 | sampling_locations, 35 | attention_weights, 36 | ctx.im2col_step, 37 | ) 38 | ctx.save_for_backward( 39 | value, 40 | value_spatial_shapes, 41 | value_level_start_index, 42 | sampling_locations, 43 | attention_weights, 44 | ) 45 | return output 46 | 47 | @staticmethod 48 | @once_differentiable 49 | def backward(ctx, grad_output): 50 | ( 51 | value, 52 | value_spatial_shapes, 53 | value_level_start_index, 54 | sampling_locations, 55 | attention_weights, 56 | ) = ctx.saved_tensors 57 | grad_value, grad_sampling_loc, grad_attn_weight = MSDA.ms_deform_attn_backward( 58 | value, 59 | value_spatial_shapes, 60 | value_level_start_index, 61 | sampling_locations, 62 | attention_weights, 63 | grad_output, 64 | ctx.im2col_step, 65 | ) 66 | 67 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 68 | 69 | 70 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 71 | # for debug and test only, 72 | # need to use cuda version instead 73 | N_, S_, M_, D_ = value.shape 74 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 75 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 76 | sampling_grids = 2 * sampling_locations - 1 77 | sampling_value_list = [] 78 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 79 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 80 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_ * M_, D_, H_, W_) 81 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 82 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 83 | # N_*M_, D_, Lq_, P_ 84 | sampling_value_l_ = F.grid_sample( 85 | value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False 86 | ) 87 | sampling_value_list.append(sampling_value_l_) 88 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 89 | attention_weights = attention_weights.transpose(1, 2).reshape(N_ * M_, 1, Lq_, L_ * P_) 90 | output = ( 91 | (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights) 92 | .sum(-1) 93 | .view(N_, M_ * D_, Lq_) 94 | ) 95 | return output.transpose(1, 2).contiguous() 96 | -------------------------------------------------------------------------------- /ovdetr/models/post_process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.nn.functional as F 4 | 5 | from util import box_ops 6 | 7 | 8 | class PostProcess(nn.Module): 9 | """This module converts the model's output into the format expected by the coco api""" 10 | 11 | @torch.no_grad() 12 | def forward(self, outputs, target_sizes): 13 | out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] 14 | 15 | assert len(out_logits) == len(target_sizes) 16 | assert target_sizes.shape[1] == 2 17 | 18 | prob = out_logits.sigmoid() 19 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1) 20 | scores = topk_values 21 | topk_boxes = topk_indexes // out_logits.shape[2] 22 | labels = topk_indexes % out_logits.shape[2] 23 | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) 24 | boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) 25 | 26 | # and from relative [0, 1] to absolute [0, height] coordinates 27 | img_h, img_w = target_sizes.unbind(1) 28 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 29 | boxes = boxes * scale_fct[:, None, :] 30 | 31 | results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] 32 | 33 | return results 34 | 35 | 36 | class PostProcessSegm(nn.Module): 37 | def __init__(self, threshold=0.5): 38 | super().__init__() 39 | self.threshold = threshold 40 | 41 | @torch.no_grad() 42 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes): 43 | assert len(orig_target_sizes) == len(max_target_sizes) 44 | max_h, max_w = max_target_sizes.max(0)[0].tolist() 45 | outputs_masks = outputs["pred_masks"] 46 | outputs_masks = F.interpolate( 47 | outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False 48 | ) 49 | outputs_masks = outputs_masks.sigmoid() > self.threshold 50 | 51 | for i, (cur_mask, t, tt) in enumerate( 52 | zip(outputs_masks, max_target_sizes, orig_target_sizes) 53 | ): 54 | img_h, img_w = t[0], t[1] 55 | results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) 56 | results[i]["masks"] = F.interpolate( 57 | results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" 58 | ).byte() 59 | 60 | return results 61 | 62 | 63 | class OVPostProcess(nn.Module): 64 | """This module converts the model's output into the format expected by the coco api""" 65 | 66 | def __init__(self, num_queries=300): 67 | super().__init__() 68 | self.num_queries = num_queries 69 | 70 | @torch.no_grad() 71 | def forward(self, outputs, target_sizes): 72 | out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] 73 | select_id = outputs["select_id"] 74 | if type(select_id) == int: 75 | select_id = [select_id] 76 | 77 | assert len(out_logits) == len(target_sizes) 78 | assert target_sizes.shape[1] == 2 79 | 80 | prob = out_logits.sigmoid() 81 | 82 | scores, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 300, dim=1) 83 | topk_boxes = topk_indexes // out_logits.shape[2] 84 | 85 | labels = torch.zeros_like(prob).flatten(1) 86 | num_queries = self.num_queries 87 | for ind, c in enumerate(select_id): 88 | labels[:, ind * num_queries : (ind + 1) * num_queries] = c 89 | labels = torch.gather(labels, 1, topk_boxes) 90 | 91 | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) 92 | boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) 93 | 94 | # and from relative [0, 1] to absolute [0, height] coordinates 95 | img_h, img_w = target_sizes.unbind(1) 96 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 97 | boxes = boxes * scale_fct[:, None, :] 98 | 99 | results = [{"scores": s, "labels": 
l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] 100 | 101 | return results, topk_indexes 102 | -------------------------------------------------------------------------------- /ovdetr/models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Various positional encodings for the transformer. 12 | """ 13 | import math 14 | 15 | import torch 16 | from torch import nn 17 | 18 | from util.misc import NestedTensor 19 | 20 | 21 | class PositionEmbeddingSine(nn.Module): 22 | """ 23 | This is a more standard version of the position embedding, very similar to the one 24 | used by the Attention is all you need paper, generalized to work on images. 25 | """ 26 | 27 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 28 | super().__init__() 29 | self.num_pos_feats = num_pos_feats 30 | self.temperature = temperature 31 | self.normalize = normalize 32 | if scale is not None and normalize is False: 33 | raise ValueError("normalize should be True if scale is passed") 34 | if scale is None: 35 | scale = 2 * math.pi 36 | self.scale = scale 37 | 38 | def forward(self, tensor_list: NestedTensor): 39 | x = tensor_list.tensors 40 | mask = tensor_list.mask 41 | assert mask is not None 42 | not_mask = ~mask 43 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 44 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 45 | if self.normalize: 46 | eps = 1e-6 47 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 48 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 49 | 50 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 51 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 52 | 53 | pos_x = x_embed[:, :, :, None] / dim_t 54 | pos_y = y_embed[:, :, :, None] / dim_t 55 | pos_x = torch.stack( 56 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 57 | ).flatten(3) 58 | pos_y = torch.stack( 59 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 60 | ).flatten(3) 61 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 62 | return pos 63 | 64 | 65 | class PositionEmbeddingLearned(nn.Module): 66 | """ 67 | Absolute pos embedding, learned. 
68 | """ 69 | 70 | def __init__(self, num_pos_feats=256): 71 | super().__init__() 72 | self.row_embed = nn.Embedding(50, num_pos_feats) 73 | self.col_embed = nn.Embedding(50, num_pos_feats) 74 | self.reset_parameters() 75 | 76 | def reset_parameters(self): 77 | nn.init.uniform_(self.row_embed.weight) 78 | nn.init.uniform_(self.col_embed.weight) 79 | 80 | def forward(self, tensor_list: NestedTensor): 81 | x = tensor_list.tensors 82 | h, w = x.shape[-2:] 83 | i = torch.arange(w, device=x.device) 84 | j = torch.arange(h, device=x.device) 85 | x_emb = self.col_embed(i) 86 | y_emb = self.row_embed(j) 87 | pos = ( 88 | torch.cat( 89 | [ 90 | x_emb.unsqueeze(0).repeat(h, 1, 1), 91 | y_emb.unsqueeze(1).repeat(1, w, 1), 92 | ], 93 | dim=-1, 94 | ) 95 | .permute(2, 0, 1) 96 | .unsqueeze(0) 97 | .repeat(x.shape[0], 1, 1, 1) 98 | ) 99 | return pos 100 | 101 | 102 | def build_position_encoding(args): 103 | N_steps = args.hidden_dim // 2 104 | if args.position_embedding in ("v2", "sine"): 105 | # TODO find a better way of exposing other arguments 106 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 107 | elif args.position_embedding in ("v3", "learned"): 108 | position_embedding = PositionEmbeddingLearned(N_steps) 109 | else: 110 | raise ValueError(f"not supported {args.position_embedding}") 111 | 112 | return position_embedding 113 | -------------------------------------------------------------------------------- /ovdetr/models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import, division, print_function 10 | 11 | import torch 12 | from functions.ms_deform_attn_func import (MSDeformAttnFunction, 13 | ms_deform_attn_core_pytorch) 14 | from torch.autograd import gradcheck 15 | 16 | N, M, D = 1, 2, 2 17 | Lq, L, P = 2, 2, 2 18 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 19 | level_start_index = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1])) 20 | S = sum([(H * W).item() for H, W in shapes]) 21 | 22 | 23 | torch.manual_seed(3) 24 | 25 | 26 | @torch.no_grad() 27 | def check_forward_equal_with_pytorch_double(): 28 | value = torch.rand(N, S, M, D).cuda() * 0.01 29 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 30 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 31 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 32 | im2col_step = 2 33 | output_pytorch = ( 34 | ms_deform_attn_core_pytorch( 35 | value.double(), shapes, sampling_locations.double(), attention_weights.double() 36 | ) 37 | .detach() 38 | .cpu() 39 | ) 40 | output_cuda = ( 41 | MSDeformAttnFunction.apply( 42 | value.double(), 43 | shapes, 44 | level_start_index, 45 | sampling_locations.double(), 46 | attention_weights.double(), 47 | im2col_step, 48 | ) 49 | .detach() 50 | .cpu() 51 | ) 52 | fwdok = torch.allclose(output_cuda, output_pytorch) 53 | max_abs_err = (output_cuda - 
output_pytorch).abs().max() 54 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 55 | 56 | print(f"* {fwdok} max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}") 57 | 58 | 59 | @torch.no_grad() 60 | def check_forward_equal_with_pytorch_float(): 61 | value = torch.rand(N, S, M, D).cuda() * 0.01 62 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 63 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 64 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 65 | im2col_step = 2 66 | output_pytorch = ( 67 | ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights) 68 | .detach() 69 | .cpu() 70 | ) 71 | output_cuda = ( 72 | MSDeformAttnFunction.apply( 73 | value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step 74 | ) 75 | .detach() 76 | .cpu() 77 | ) 78 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 79 | max_abs_err = (output_cuda - output_pytorch).abs().max() 80 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 81 | 82 | print( 83 | f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}" 84 | ) 85 | 86 | 87 | def check_gradient_numerical( 88 | channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True 89 | ): 90 | 91 | value = torch.rand(N, S, M, channels).cuda() * 0.01 92 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 93 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 94 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 95 | im2col_step = 2 96 | func = MSDeformAttnFunction.apply 97 | 98 | value.requires_grad = grad_value 99 | sampling_locations.requires_grad = grad_sampling_loc 100 | attention_weights.requires_grad = grad_attn_weight 101 | 102 | gradok = gradcheck( 103 | func, 104 | ( 105 | value.double(), 106 | shapes, 107 | level_start_index, 108 | sampling_locations.double(), 109 | attention_weights.double(), 110 | im2col_step, 111 | ), 112 | ) 113 | 114 | print(f"* {gradok} check_gradient_numerical(D={channels})") 115 | 116 | 117 | if __name__ == "__main__": 118 | check_forward_equal_with_pytorch_double() 119 | check_forward_equal_with_pytorch_float() 120 | 121 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 122 | check_gradient_numerical(channels, True, True, True) 123 | -------------------------------------------------------------------------------- /ovdetr/models/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | 9 | """ 10 | Backbone modules. 
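Builds the convolutional feature extractor (a torchvision ResNet with frozen
BatchNorm) together with its positional encoding; `Joiner` returns both the
multi-scale feature maps and their position embeddings.

Illustrative usage sketch (the `args` fields are assumptions inferred from
`build_backbone` and `build_position_encoding` in this repo, not an official
entry point):

    from types import SimpleNamespace
    args = SimpleNamespace(
        backbone="resnet50", lr_backbone=1e-5, dilation=False,
        masks=False, num_feature_levels=4,
        hidden_dim=256, position_embedding="sine",
    )
    backbone = build_backbone(args)
    # backbone(NestedTensor(images, padding_mask)) ->
    #   ([NestedTensor per feature level], [position embedding per level])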
11 | """ 12 | from typing import Dict, List 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | import torchvision 17 | from torch import nn 18 | from torchvision.models._utils import IntermediateLayerGetter 19 | 20 | from util.misc import NestedTensor, is_main_process 21 | 22 | from .position_encoding import build_position_encoding 23 | 24 | 25 | class FrozenBatchNorm2d(torch.nn.Module): 26 | """ 27 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 28 | 29 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 30 | without which any other models than torchvision.models.resnet[18,34,50,101] 31 | produce nans. 32 | """ 33 | 34 | def __init__(self, n, eps=1e-5): 35 | super(FrozenBatchNorm2d, self).__init__() 36 | self.register_buffer("weight", torch.ones(n)) 37 | self.register_buffer("bias", torch.zeros(n)) 38 | self.register_buffer("running_mean", torch.zeros(n)) 39 | self.register_buffer("running_var", torch.ones(n)) 40 | self.eps = eps 41 | 42 | def _load_from_state_dict( 43 | self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 44 | ): 45 | num_batches_tracked_key = prefix + "num_batches_tracked" 46 | if num_batches_tracked_key in state_dict: 47 | del state_dict[num_batches_tracked_key] 48 | 49 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 50 | state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs 51 | ) 52 | 53 | def forward(self, x): 54 | # move reshapes to the beginning 55 | # to make it fuser-friendly 56 | w = self.weight.reshape(1, -1, 1, 1) 57 | b = self.bias.reshape(1, -1, 1, 1) 58 | rv = self.running_var.reshape(1, -1, 1, 1) 59 | rm = self.running_mean.reshape(1, -1, 1, 1) 60 | eps = self.eps 61 | scale = w * (rv + eps).rsqrt() 62 | bias = b - rm * scale 63 | return x * scale + bias 64 | 65 | 66 | class BackboneBase(nn.Module): 67 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): 68 | super().__init__() 69 | for name, parameter in backbone.named_parameters(): 70 | if ( 71 | not train_backbone 72 | or "layer2" not in name 73 | and "layer3" not in name 74 | and "layer4" not in name 75 | ): 76 | parameter.requires_grad_(False) 77 | if return_interm_layers: 78 | return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 79 | self.strides = [8, 16, 32] 80 | self.num_channels = [512, 1024, 2048] 81 | else: 82 | return_layers = {"layer4": "0"} 83 | self.strides = [32] 84 | self.num_channels = [2048] 85 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 86 | 87 | def forward(self, tensor_list: NestedTensor): 88 | xs = self.body(tensor_list.tensors) 89 | out: Dict[str, NestedTensor] = {} 90 | for name, x in xs.items(): 91 | m = tensor_list.mask 92 | assert m is not None 93 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 94 | out[name] = NestedTensor(x, mask) 95 | return out 96 | 97 | 98 | class Backbone(BackboneBase): 99 | """ResNet backbone with frozen BatchNorm.""" 100 | 101 | def __init__( 102 | self, 103 | name: str, 104 | train_backbone: bool, 105 | return_interm_layers: bool, 106 | dilation: bool, 107 | ): 108 | norm_layer = FrozenBatchNorm2d 109 | backbone = getattr(torchvision.models, name)( 110 | replace_stride_with_dilation=[False, False, dilation], 111 | pretrained=is_main_process(), 112 | norm_layer=norm_layer, 113 | ) 114 | assert name not in ("resnet18", "resnet34"), "number of channels are hard coded" 115 | super().__init__(backbone, 
train_backbone, return_interm_layers) 116 | if dilation: 117 | self.strides[-1] = self.strides[-1] // 2 118 | 119 | 120 | class Joiner(nn.Sequential): 121 | def __init__(self, backbone, position_embedding): 122 | super().__init__(backbone, position_embedding) 123 | self.strides = backbone.strides 124 | self.num_channels = backbone.num_channels 125 | 126 | def forward(self, tensor_list: NestedTensor): 127 | xs = self[0](tensor_list) 128 | out: List[NestedTensor] = [] 129 | pos = [] 130 | for name, x in sorted(xs.items()): 131 | out.append(x) 132 | 133 | # position encoding 134 | for x in out: 135 | pos.append(self[1](x).to(x.tensors.dtype)) 136 | 137 | return out, pos 138 | 139 | 140 | def build_backbone(args): 141 | position_embedding = build_position_encoding(args) 142 | train_backbone = args.lr_backbone > 0 143 | return_interm_layers = args.masks or (args.num_feature_levels > 1) 144 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 145 | model = Joiner(backbone, position_embedding) 146 | return model 147 | -------------------------------------------------------------------------------- /ovdetr/datasets/samplers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from codes in torch.utils.data.distributed 6 | # ------------------------------------------------------------------------ 7 | 8 | import math 9 | import os 10 | 11 | import torch 12 | import torch.distributed as dist 13 | from torch.utils.data.sampler import Sampler 14 | 15 | 16 | class DistributedSampler(Sampler): 17 | """Sampler that restricts data loading to a subset of the dataset. 18 | It is especially useful in conjunction with 19 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 20 | process can pass a DistributedSampler instance as a DataLoader sampler, 21 | and load a subset of the original dataset that is exclusive to it. 22 | .. note:: 23 | Dataset is assumed to be of constant size. 24 | Arguments: 25 | dataset: Dataset used for sampling. 26 | num_replicas (optional): Number of processes participating in 27 | distributed training. 28 | rank (optional): Rank of the current process within num_replicas. 
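        shuffle (optional): Whether to reshuffle the indices every epoch.

    Illustrative usage sketch (assumes torch.distributed is already initialised,
    `dataset` is a map-style dataset and `num_epochs` is the training length):

        sampler = DistributedSampler(dataset, shuffle=True)
        loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler)
        for epoch in range(num_epochs):
            sampler.set_epoch(epoch)  # changes the shuffling seed each epoch
            for samples in loader:
                ...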
29 | """ 30 | 31 | def __init__( 32 | self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True 33 | ): 34 | if num_replicas is None: 35 | if not dist.is_available(): 36 | raise RuntimeError("Requires distributed package to be available") 37 | num_replicas = dist.get_world_size() 38 | if rank is None: 39 | if not dist.is_available(): 40 | raise RuntimeError("Requires distributed package to be available") 41 | rank = dist.get_rank() 42 | self.dataset = dataset 43 | self.num_replicas = num_replicas 44 | self.rank = rank 45 | self.epoch = 0 46 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 47 | self.total_size = self.num_samples * self.num_replicas 48 | self.shuffle = shuffle 49 | 50 | def __iter__(self): 51 | if self.shuffle: 52 | # deterministically shuffle based on epoch 53 | g = torch.Generator() 54 | g.manual_seed(self.epoch) 55 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 56 | else: 57 | indices = torch.arange(len(self.dataset)).tolist() 58 | 59 | # add extra samples to make it evenly divisible 60 | indices += indices[: (self.total_size - len(indices))] 61 | assert len(indices) == self.total_size 62 | 63 | # subsample 64 | offset = self.num_samples * self.rank 65 | indices = indices[offset : offset + self.num_samples] 66 | assert len(indices) == self.num_samples 67 | 68 | return iter(indices) 69 | 70 | def __len__(self): 71 | return self.num_samples 72 | 73 | def set_epoch(self, epoch): 74 | self.epoch = epoch 75 | 76 | 77 | class NodeDistributedSampler(Sampler): 78 | """Sampler that restricts data loading to a subset of the dataset. 79 | It is especially useful in conjunction with 80 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 81 | process can pass a DistributedSampler instance as a DataLoader sampler, 82 | and load a subset of the original dataset that is exclusive to it. 83 | .. note:: 84 | Dataset is assumed to be of constant size. 85 | Arguments: 86 | dataset: Dataset used for sampling. 87 | num_replicas (optional): Number of processes participating in 88 | distributed training. 89 | rank (optional): Rank of the current process within num_replicas. 
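        local_rank (optional): Rank of the process within its node; taken from
            the LOCAL_RANK environment variable when not given.
        local_size (optional): Number of processes per node; taken from the
            LOCAL_SIZE environment variable when not given.

    Unlike the plain DistributedSampler above, this sampler first keeps only the
    dataset indices whose value modulo `local_size` equals `local_rank`, and then
    strides the remaining indices across nodes. The apparent intent (it matches
    the per-process cache options of the dataset classes) is that every process
    always reads the same fixed subset of the data.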
90 | """ 91 | 92 | def __init__( 93 | self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True 94 | ): 95 | if num_replicas is None: 96 | if not dist.is_available(): 97 | raise RuntimeError("Requires distributed package to be available") 98 | num_replicas = dist.get_world_size() 99 | if rank is None: 100 | if not dist.is_available(): 101 | raise RuntimeError("Requires distributed package to be available") 102 | rank = dist.get_rank() 103 | if local_rank is None: 104 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 105 | if local_size is None: 106 | local_size = int(os.environ.get("LOCAL_SIZE", 1)) 107 | self.dataset = dataset 108 | self.shuffle = shuffle 109 | self.num_replicas = num_replicas 110 | self.num_parts = local_size 111 | self.rank = rank 112 | self.local_rank = local_rank 113 | self.epoch = 0 114 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 115 | self.total_size = self.num_samples * self.num_replicas 116 | 117 | self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts 118 | 119 | def __iter__(self): 120 | if self.shuffle: 121 | # deterministically shuffle based on epoch 122 | g = torch.Generator() 123 | g.manual_seed(self.epoch) 124 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 125 | else: 126 | indices = torch.arange(len(self.dataset)).tolist() 127 | indices = [i for i in indices if i % self.num_parts == self.local_rank] 128 | 129 | # add extra samples to make it evenly divisible 130 | indices += indices[: (self.total_size_parts - len(indices))] 131 | assert len(indices) == self.total_size_parts 132 | 133 | # subsample 134 | indices = indices[ 135 | self.rank 136 | // self.num_parts : self.total_size_parts : self.num_replicas 137 | // self.num_parts 138 | ] 139 | assert len(indices) == self.num_samples 140 | 141 | return iter(indices) 142 | 143 | def __len__(self): 144 | return self.num_samples 145 | 146 | def set_epoch(self, epoch): 147 | self.epoch = epoch 148 | -------------------------------------------------------------------------------- /ovdetr/models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import, division, print_function 10 | 11 | import math 12 | import warnings 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import nn 17 | from torch.nn.init import constant_, xavier_uniform_ 18 | 19 | from ..functions import MSDeformAttnFunction 20 | 21 | 22 | def _is_power_of_2(n): 23 | if (not isinstance(n, int)) or (n < 0): 24 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 25 | return (n & (n - 1) == 0) and n != 0 26 | 27 | 28 | class MSDeformAttn(nn.Module): 29 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 30 | """ 31 | Multi-Scale Deformable Attention Module 32 | :param d_model hidden dimension 33 | :param n_levels number of feature levels 34 | :param n_heads number of attention heads 35 | :param n_points number of sampling points per attention head per feature level 36 | """ 37 | super().__init__() 38 | if d_model % n_heads != 0: 39 | raise ValueError( 40 | "d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads) 41 | ) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn( 46 | "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 47 | "which is more efficient in our CUDA implementation." 
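                # note: with the defaults in this repo (d_model=256, n_heads=8) the
                # per-head dimension is 256 // 8 = 32, a power of two, so the warning
                # above is not triggered.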
48 | ) 49 | 50 | self.im2col_step = 64 51 | 52 | self.d_model = d_model 53 | self.n_levels = n_levels 54 | self.n_heads = n_heads 55 | self.n_points = n_points 56 | 57 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 58 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 59 | self.value_proj = nn.Linear(d_model, d_model) 60 | self.output_proj = nn.Linear(d_model, d_model) 61 | 62 | self._reset_parameters() 63 | 64 | def _reset_parameters(self): 65 | constant_(self.sampling_offsets.weight.data, 0.0) 66 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 67 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 68 | grid_init = ( 69 | (grid_init / grid_init.abs().max(-1, keepdim=True)[0]) 70 | .view(self.n_heads, 1, 1, 2) 71 | .repeat(1, self.n_levels, self.n_points, 1) 72 | ) 73 | for i in range(self.n_points): 74 | grid_init[:, :, i, :] *= i + 1 75 | with torch.no_grad(): 76 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 77 | constant_(self.attention_weights.weight.data, 0.0) 78 | constant_(self.attention_weights.bias.data, 0.0) 79 | xavier_uniform_(self.value_proj.weight.data) 80 | constant_(self.value_proj.bias.data, 0.0) 81 | xavier_uniform_(self.output_proj.weight.data) 82 | constant_(self.output_proj.bias.data, 0.0) 83 | 84 | def forward( 85 | self, 86 | query, 87 | reference_points, 88 | input_flatten, 89 | input_spatial_shapes, 90 | input_level_start_index, 91 | input_padding_mask=None, 92 | ): 93 | N, Len_q, _ = query.shape 94 | N, Len_in, _ = input_flatten.shape 95 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 96 | 97 | value = self.value_proj(input_flatten) 98 | if input_padding_mask is not None: 99 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 100 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 101 | sampling_offsets = self.sampling_offsets(query).view( 102 | N, Len_q, self.n_heads, self.n_levels, self.n_points, 2 103 | ) 104 | attention_weights = self.attention_weights(query).view( 105 | N, Len_q, self.n_heads, self.n_levels * self.n_points 106 | ) 107 | attention_weights = F.softmax(attention_weights, -1).view( 108 | N, Len_q, self.n_heads, self.n_levels, self.n_points 109 | ) 110 | # N, Len_q, n_heads, n_levels, n_points, 2 111 | if reference_points.shape[-1] == 2: 112 | offset_normalizer = torch.stack( 113 | [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1 114 | ) 115 | sampling_locations = ( 116 | reference_points[:, :, None, :, None, :] 117 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 118 | ) 119 | elif reference_points.shape[-1] == 4: 120 | sampling_locations = ( 121 | reference_points[:, :, None, :, None, :2] 122 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 123 | ) 124 | else: 125 | raise ValueError( 126 | "Last dim of reference_points must be 2 or 4, but get {} instead.".format( 127 | reference_points.shape[-1] 128 | ) 129 | ) 130 | output = MSDeformAttnFunction.apply( 131 | value, 132 | input_spatial_shapes, 133 | input_level_start_index, 134 | sampling_locations, 135 | attention_weights, 136 | self.im2col_step, 137 | ) 138 | output = self.output_proj(output) 139 | return output 140 | -------------------------------------------------------------------------------- /ovdetr/datasets/lvis.py: -------------------------------------------------------------------------------- 1 | """ 2 | 
LVIS dataset which returns image_id for evaluation. 3 | 4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 5 | """ 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.utils.data 10 | from pycocotools import mask as coco_mask 11 | 12 | import datasets.transforms as T 13 | 14 | from .torchvision_datasets import LvisDetection as TvLvisDetection 15 | 16 | 17 | class LvisDetection(TvLvisDetection): 18 | def __init__(self, img_folder, ann_file, transforms, return_masks, label_map): 19 | super(LvisDetection, self).__init__(img_folder, ann_file) 20 | self._transforms = transforms 21 | self.cat_ids = self.lvis.get_cat_ids() 22 | self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} 23 | self.prepare = ConvertCocoPolysToMask(return_masks, self.cat2label, label_map) 24 | 25 | def __getitem__(self, idx): 26 | img, target = super(LvisDetection, self).__getitem__(idx) 27 | image_id = self.ids[idx] 28 | target = {"image_id": image_id, "annotations": target} 29 | img, target = self.prepare(img, target) 30 | if self._transforms is not None: 31 | img, target = self._transforms(img, target) 32 | if len(target["labels"]) == 0: 33 | return self[(idx + 1) % len(self)] 34 | else: 35 | return img, target 36 | 37 | 38 | def convert_coco_poly_to_mask(segmentations, height, width): 39 | masks = [] 40 | for polygons in segmentations: 41 | rles = coco_mask.frPyObjects(polygons, height, width) 42 | mask = coco_mask.decode(rles) 43 | if len(mask.shape) < 3: 44 | mask = mask[..., None] 45 | mask = torch.as_tensor(mask, dtype=torch.uint8) 46 | mask = mask.any(dim=2) 47 | masks.append(mask) 48 | if masks: 49 | masks = torch.stack(masks, dim=0) 50 | else: 51 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 52 | return masks 53 | 54 | 55 | class ConvertCocoPolysToMask(object): 56 | def __init__(self, return_masks=False, cat2label=None, label_map=False): 57 | self.return_masks = return_masks 58 | self.cat2label = cat2label 59 | self.label_map = label_map 60 | 61 | def __call__(self, image, target): 62 | w, h = image.size 63 | 64 | image_id = target["image_id"] 65 | image_id = torch.tensor([image_id]) 66 | 67 | anno = target["annotations"] 68 | 69 | anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] 70 | 71 | boxes = [obj["bbox"] for obj in anno] 72 | # guard against no boxes via resizing 73 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 74 | boxes[:, 2:] += boxes[:, :2] 75 | boxes[:, 0::2].clamp_(min=0, max=w) 76 | boxes[:, 1::2].clamp_(min=0, max=h) 77 | 78 | if self.label_map: 79 | classes = [self.cat2label[obj["category_id"]] for obj in anno] 80 | else: 81 | classes = [obj["category_id"] for obj in anno] 82 | classes = torch.tensor(classes, dtype=torch.int64) 83 | 84 | if self.return_masks: 85 | segmentations = [obj["segmentation"] for obj in anno] 86 | masks = convert_coco_poly_to_mask(segmentations, h, w) 87 | 88 | keypoints = None 89 | if anno and "keypoints" in anno[0]: 90 | keypoints = [obj["keypoints"] for obj in anno] 91 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 92 | num_keypoints = keypoints.shape[0] 93 | if num_keypoints: 94 | keypoints = keypoints.view(num_keypoints, -1, 3) 95 | 96 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 97 | boxes = boxes[keep] 98 | classes = classes[keep] 99 | if self.return_masks: 100 | masks = masks[keep] 101 | if keypoints is not None: 102 | keypoints = keypoints[keep] 103 | 104 | target = {} 105 | 
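        # assemble the target dict consumed by the transforms and the criterion:
        # absolute xyxy boxes, (optionally remapped) label ids, optional masks and
        # keypoints, plus the COCO-style bookkeeping fields filled in below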
target["boxes"] = boxes 106 | target["labels"] = classes 107 | if self.return_masks: 108 | target["masks"] = masks 109 | target["image_id"] = image_id 110 | if keypoints is not None: 111 | target["keypoints"] = keypoints 112 | 113 | # for conversion to coco api 114 | area = torch.tensor([obj["area"] for obj in anno]) 115 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 116 | target["area"] = area[keep] 117 | target["iscrowd"] = iscrowd[keep] 118 | 119 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 120 | target["size"] = torch.as_tensor([int(h), int(w)]) 121 | 122 | return image, target 123 | 124 | 125 | def make_coco_transforms(image_set): 126 | 127 | normalize = T.Compose([T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 128 | 129 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 130 | 131 | if image_set == "train": 132 | return T.Compose( 133 | [ 134 | T.RandomHorizontalFlip(), 135 | T.RandomSelect( 136 | T.RandomResize(scales, max_size=1333), 137 | T.Compose( 138 | [ 139 | T.RandomResize([400, 500, 600]), 140 | T.RandomSizeCrop(384, 600), 141 | T.RandomResize(scales, max_size=1333), 142 | ] 143 | ), 144 | ), 145 | normalize, 146 | ] 147 | ) 148 | 149 | if image_set == "val": 150 | return T.Compose( 151 | [ 152 | T.RandomResize([800], max_size=1333), 153 | normalize, 154 | ] 155 | ) 156 | 157 | raise ValueError(f"unknown {image_set}") 158 | 159 | 160 | def build(image_set, args): 161 | root = Path(args.lvis_path) 162 | assert root.exists(), f"provided LVIS path {root} does not exist" 163 | PATHS = { 164 | "train": (root, root / "lvis_v1_train_norare.json"), 165 | "val": (root, root / "lvis_v1_val.json"), 166 | } 167 | 168 | img_folder, ann_file = PATHS[image_set] 169 | dataset = LvisDetection( 170 | img_folder, 171 | ann_file, 172 | transforms=make_coco_transforms(image_set), 173 | return_masks=args.masks, 174 | label_map=args.label_map, 175 | ) 176 | return dataset 177 | -------------------------------------------------------------------------------- /ovdetr/models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | 
const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /ovdetr/util/clip_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from [ViLD](https://github.com/tensorflow/tpu/tree/master/models/official/detection/projects/vild) 
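# Overview of the flow implemented below: every category name is formatted into
# all prompt templates, the prompts are tokenized with clip.tokenize and encoded
# by the CLIP text tower, the per-prompt embeddings are L2-normalised, averaged
# and re-normalised, and the per-class vectors are stacked into a
# (num_classes, embed_dim) matrix, e.g.
#
#     coco_text = build_text_embedding_coco()   # (65, 512) with ViT-B/32
#     lvis_text = build_text_embedding_lvis()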
2 | 3 | import os 4 | 5 | import torch 6 | import torch.nn as nn 7 | from clip import clip 8 | 9 | from .coco_categories import COCO_CATEGORIES 10 | from .lvis_v1_categories import LVIS_CATEGORIES 11 | 12 | 13 | def article(name): 14 | return "an" if name[0] in "aeiou" else "a" 15 | 16 | 17 | def processed_name(name, rm_dot=False): 18 | # _ for lvis 19 | # / for obj365 20 | res = name.replace("_", " ").replace("/", " or ").lower() 21 | if rm_dot: 22 | res = res.rstrip(".") 23 | return res 24 | 25 | 26 | single_template = ["a photo of a {}."] 27 | 28 | multiple_templates = [ 29 | "There is {article} {} in the scene.", 30 | "There is the {} in the scene.", 31 | "a photo of {article} {} in the scene.", 32 | "a photo of the {} in the scene.", 33 | "a photo of one {} in the scene.", 34 | "itap of {article} {}.", 35 | "itap of my {}.", # itap: I took a picture of 36 | "itap of the {}.", 37 | "a photo of {article} {}.", 38 | "a photo of my {}.", 39 | "a photo of the {}.", 40 | "a photo of one {}.", 41 | "a photo of many {}.", 42 | "a good photo of {article} {}.", 43 | "a good photo of the {}.", 44 | "a bad photo of {article} {}.", 45 | "a bad photo of the {}.", 46 | "a photo of a nice {}.", 47 | "a photo of the nice {}.", 48 | "a photo of a cool {}.", 49 | "a photo of the cool {}.", 50 | "a photo of a weird {}.", 51 | "a photo of the weird {}.", 52 | "a photo of a small {}.", 53 | "a photo of the small {}.", 54 | "a photo of a large {}.", 55 | "a photo of the large {}.", 56 | "a photo of a clean {}.", 57 | "a photo of the clean {}.", 58 | "a photo of a dirty {}.", 59 | "a photo of the dirty {}.", 60 | "a bright photo of {article} {}.", 61 | "a bright photo of the {}.", 62 | "a dark photo of {article} {}.", 63 | "a dark photo of the {}.", 64 | "a photo of a hard to see {}.", 65 | "a photo of the hard to see {}.", 66 | "a low resolution photo of {article} {}.", 67 | "a low resolution photo of the {}.", 68 | "a cropped photo of {article} {}.", 69 | "a cropped photo of the {}.", 70 | "a close-up photo of {article} {}.", 71 | "a close-up photo of the {}.", 72 | "a jpeg corrupted photo of {article} {}.", 73 | "a jpeg corrupted photo of the {}.", 74 | "a blurry photo of {article} {}.", 75 | "a blurry photo of the {}.", 76 | "a pixelated photo of {article} {}.", 77 | "a pixelated photo of the {}.", 78 | "a black and white photo of the {}.", 79 | "a black and white photo of {article} {}.", 80 | "a plastic {}.", 81 | "the plastic {}.", 82 | "a toy {}.", 83 | "the toy {}.", 84 | "a plushie {}.", 85 | "the plushie {}.", 86 | "a cartoon {}.", 87 | "the cartoon {}.", 88 | "an embroidered {}.", 89 | "the embroidered {}.", 90 | "a painting of the {}.", 91 | "a painting of a {}.", 92 | ] 93 | 94 | 95 | def load_clip_to_cpu(visual_backbone): 96 | backbone_name = visual_backbone 97 | url = clip._MODELS[backbone_name] 98 | model_path = clip._download(url, os.path.expanduser("~/.cache/clip")) 99 | 100 | try: 101 | # loading JIT archive 102 | model = torch.jit.load(model_path, map_location="cpu").eval() 103 | state_dict = None 104 | 105 | except RuntimeError: 106 | state_dict = torch.load(model_path, map_location="cpu") 107 | 108 | model = clip.build_model(state_dict or model.state_dict()) 109 | 110 | return model 111 | 112 | 113 | class TextEncoder(nn.Module): 114 | def __init__(self, clip_model): 115 | super().__init__() 116 | self.transformer = clip_model.transformer 117 | self.positional_embedding = clip_model.positional_embedding 118 | self.ln_final = clip_model.ln_final 119 | self.text_projection = 
clip_model.text_projection 120 | self.dtype = clip_model.dtype 121 | self.token_embedding = clip_model.token_embedding 122 | 123 | def forward(self, text): 124 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 125 | 126 | x = x + self.positional_embedding.type(self.dtype) 127 | x = x.permute(1, 0, 2) # NLD -> LND 128 | x = self.transformer(x) 129 | x = x.permute(1, 0, 2) # LND -> NLD 130 | x = self.ln_final(x).type(self.dtype) 131 | 132 | # x.shape = [batch_size, n_ctx, transformer.width] 133 | # take features from the eot embedding (eot_token is the highest number in each sequence) 134 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 135 | 136 | return x 137 | 138 | 139 | def build_text_embedding_coco(): 140 | categories = COCO_CATEGORIES 141 | run_on_gpu = torch.cuda.is_available() 142 | 143 | clip_model = load_clip_to_cpu("ViT-B/32") 144 | text_model = TextEncoder(clip_model) 145 | if run_on_gpu: 146 | text_model = text_model.cuda() 147 | 148 | for _, param in text_model.named_parameters(): 149 | param.requires_grad = False 150 | templates = multiple_templates 151 | with torch.no_grad(): 152 | zeroshot_weights = [] 153 | for _, category in categories.items(): 154 | texts = [ 155 | template.format(processed_name(category, rm_dot=True), article=article(category)) 156 | for template in templates 157 | ] 158 | texts = [ 159 | "This is " + text if text.startswith("a") or text.startswith("the") else text 160 | for text in texts 161 | ] 162 | texts = clip.tokenize(texts) # tokenize 163 | if run_on_gpu: 164 | texts = texts.cuda() 165 | text_embeddings = text_model(texts) 166 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 167 | text_embedding = text_embeddings.mean(dim=0) 168 | text_embedding /= text_embedding.norm() 169 | zeroshot_weights.append(text_embedding) 170 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1) 171 | if run_on_gpu: 172 | zeroshot_weights = zeroshot_weights.cuda() 173 | zeroshot_weights = zeroshot_weights.t() 174 | all_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 38, 41, 42, 44, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 59, 60, 61, 62, 63, 65, 70, 72, 73, 74, 75, 76, 78, 79, 80, 81, 82, 84, 85, 86, 87, 90] # noqa 175 | all_ids = [i - 1 for i in all_ids] 176 | return zeroshot_weights[all_ids] 177 | 178 | 179 | def build_text_embedding_lvis(): 180 | categories = LVIS_CATEGORIES 181 | model, _ = clip.load("ViT-B/32") 182 | templates = multiple_templates 183 | 184 | run_on_gpu = torch.cuda.is_available() 185 | 186 | with torch.no_grad(): 187 | all_text_embeddings = [] 188 | for category in categories: 189 | texts = [ 190 | template.format( 191 | processed_name(category["name"], rm_dot=True), article=article(category["name"]) 192 | ) 193 | for template in templates 194 | ] 195 | texts = [ 196 | "This is " + text if text.startswith("a") or text.startswith("the") else text 197 | for text in texts 198 | ] 199 | texts = clip.tokenize(texts) # tokenize 200 | if run_on_gpu: 201 | texts = texts.cuda() 202 | model = model.cuda() 203 | text_embeddings = model.encode_text(texts) 204 | text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True) 205 | text_embedding = text_embeddings.mean(dim=0) 206 | text_embedding /= text_embedding.norm() 207 | all_text_embeddings.append(text_embedding) 208 | all_text_embeddings = torch.stack(all_text_embeddings, dim=1) 209 | if run_on_gpu: 210 | all_text_embeddings = all_text_embeddings.cuda() 211 | 212 | 
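        # transpose so that rows index categories: (num_lvis_classes, embed_dim)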
all_text_embeddings = all_text_embeddings.t() 213 | return all_text_embeddings 214 | -------------------------------------------------------------------------------- /ovdetr/models/matcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Modules to compute the matching cost and solve the corresponding LSAP. 12 | """ 13 | import torch 14 | from scipy.optimize import linear_sum_assignment 15 | from torch import nn 16 | 17 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 18 | 19 | 20 | class HungarianMatcher(nn.Module): 21 | """This class computes an assignment between the targets and the predictions of the network 22 | 23 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 24 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 25 | while the others are un-matched (and thus treated as non-objects). 26 | """ 27 | 28 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 29 | """Creates the matcher 30 | 31 | Params: 32 | cost_class: This is the relative weight of the classification error in the matching cost 33 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 34 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 35 | """ 36 | super().__init__() 37 | self.cost_class = cost_class 38 | self.cost_bbox = cost_bbox 39 | self.cost_giou = cost_giou 40 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 41 | 42 | def forward(self, outputs, targets): 43 | """Performs the matching 44 | 45 | Params: 46 | outputs: This is a dict that contains at least these entries: 47 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 48 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 49 | 50 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 51 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 52 | objects in the target) containing the class labels 53 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 54 | 55 | Returns: 56 | A list of size batch_size, containing tuples of (index_i, index_j) where: 57 | - index_i is the indices of the selected predictions (in order) 58 | - index_j is the indices of the corresponding selected targets (in order) 59 | For each batch element, it holds: 60 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 61 | """ 62 | with torch.no_grad(): 63 | bs, num_queries = outputs["pred_logits"].shape[:2] 64 | 65 | # We flatten to compute the cost matrices in a batch 66 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() 67 | out_bbox = outputs["pred_boxes"].flatten(0, 
1) # [batch_size * num_queries, 4] 68 | 69 | # Also concat the target labels and boxes 70 | tgt_ids = torch.cat([v["labels"] for v in targets]) 71 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 72 | 73 | # Compute the classification cost. 74 | alpha = 0.25 75 | gamma = 2.0 76 | neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) 77 | pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) 78 | cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] 79 | 80 | # Compute the L1 cost between boxes 81 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 82 | 83 | # Compute the giou cost betwen boxes 84 | cost_giou = -generalized_box_iou( 85 | box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) 86 | ) 87 | 88 | # Final cost matrix 89 | C = ( 90 | self.cost_bbox * cost_bbox 91 | + self.cost_class * cost_class 92 | + self.cost_giou * cost_giou 93 | ) 94 | C = C.view(bs, num_queries, -1).cpu() 95 | 96 | sizes = [len(v["boxes"]) for v in targets] 97 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 98 | return [ 99 | (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) 100 | for i, j in indices 101 | ] 102 | 103 | 104 | class OVHungarianMatcher(HungarianMatcher): 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | 108 | @torch.no_grad() 109 | def forward(self, outputs, targets, select_id): 110 | # We flatten to compute the cost matrices in a batch 111 | num_patch = len(select_id) 112 | bs, num_queries = outputs["pred_logits"].shape[:2] 113 | num_queries = num_queries // num_patch 114 | out_prob_all = outputs["pred_logits"].view(bs, num_patch, num_queries, -1) 115 | out_bbox_all = outputs["pred_boxes"].view(bs, num_patch, num_queries, -1) 116 | 117 | # Also concat the target labels and boxes 118 | tgt_ids_all = torch.cat([v["labels"] for v in targets]) 119 | tgt_bbox_all = torch.cat([v["boxes"] for v in targets]) 120 | 121 | alpha = 0.25 122 | gamma = 2.0 123 | 124 | ans = [[[], []] for _ in range(bs)] 125 | 126 | for index, label in enumerate(select_id): 127 | out_prob = out_prob_all[:, index, :, :].flatten(0, 1).sigmoid() 128 | out_bbox = out_bbox_all[:, index, :, :].flatten(0, 1) 129 | 130 | mask = (tgt_ids_all == label).nonzero().squeeze(1) 131 | tgt_bbox = tgt_bbox_all[mask] 132 | 133 | # Compute the classification cost. 
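            # focal-style cost, same form as in the parent HungarianMatcher:
            #   pos = alpha * (1 - p) ** gamma * (-log(p + eps))
            #   neg = (1 - alpha) * p ** gamma * (-log(1 - p + eps))
            # only the first (binary) logit enters the cost here, and the targets
            # are restricted to the boxes of the current `label`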
134 | neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) 135 | pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) 136 | cost_class = pos_cost_class[:, 0:1] - neg_cost_class[:, 0:1] 137 | 138 | # Compute the L1 cost between boxes 139 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 140 | 141 | # Compute the giou cost betwen boxes 142 | cost_giou = -generalized_box_iou( 143 | box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox) 144 | ) 145 | 146 | # Final cost matrix 147 | C = ( 148 | self.cost_bbox * cost_bbox 149 | + self.cost_class * cost_class 150 | + self.cost_giou * cost_giou 151 | ) 152 | C = C.view(bs, num_queries, -1).cpu() 153 | 154 | sizes = [len(v["labels"][v["labels"] == label]) for ind, v in enumerate(targets)] 155 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 156 | 157 | for ind in range(bs): 158 | x, y = indices[ind] 159 | if len(x) == 0: 160 | continue 161 | x += index * num_queries 162 | ans[ind][0] += x.tolist() 163 | y_label = (targets[ind]["labels"] == label).nonzero().squeeze(1).data.cpu().numpy() 164 | y_label = y_label[y].tolist() 165 | ans[ind][1] += y_label 166 | 167 | return [ 168 | (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) 169 | for i, j in ans 170 | ] 171 | 172 | 173 | def build_matcher(args): 174 | return OVHungarianMatcher( 175 | cost_class=args.set_cost_class, 176 | cost_bbox=args.set_cost_bbox, 177 | cost_giou=args.set_cost_giou, 178 | ), HungarianMatcher( 179 | cost_class=args.set_cost_class, 180 | cost_bbox=args.set_cost_bbox, 181 | cost_giou=args.set_cost_giou, 182 | ) 183 | -------------------------------------------------------------------------------- /ovdetr/datasets/coco.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | 9 | """ 10 | COCO dataset which returns image_id for evaluation. 
11 | 12 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 13 | """ 14 | from pathlib import Path 15 | 16 | import torch 17 | import torch.utils.data 18 | from pycocotools import mask as coco_mask 19 | 20 | import datasets.transforms as T 21 | from util.misc import get_local_rank, get_local_size 22 | 23 | from .torchvision_datasets import CocoDetection as TvCocoDetection 24 | 25 | 26 | class CocoDetection(TvCocoDetection): 27 | SEEN_CLASSES = ( 28 | "toilet", 29 | "bicycle", 30 | "apple", 31 | "train", 32 | "laptop", 33 | "carrot", 34 | "motorcycle", 35 | "oven", 36 | "chair", 37 | "mouse", 38 | "boat", 39 | "kite", 40 | "sheep", 41 | "horse", 42 | "sandwich", 43 | "clock", 44 | "tv", 45 | "backpack", 46 | "toaster", 47 | "bowl", 48 | "microwave", 49 | "bench", 50 | "book", 51 | "orange", 52 | "bird", 53 | "pizza", 54 | "fork", 55 | "frisbee", 56 | "bear", 57 | "vase", 58 | "toothbrush", 59 | "spoon", 60 | "giraffe", 61 | "handbag", 62 | "broccoli", 63 | "refrigerator", 64 | "remote", 65 | "surfboard", 66 | "car", 67 | "bed", 68 | "banana", 69 | "donut", 70 | "skis", 71 | "person", 72 | "truck", 73 | "bottle", 74 | "suitcase", 75 | "zebra", 76 | ) 77 | UNSEEN_CLASSES = ( 78 | "umbrella", 79 | "cow", 80 | "cup", 81 | "bus", 82 | "keyboard", 83 | "skateboard", 84 | "dog", 85 | "couch", 86 | "tie", 87 | "snowboard", 88 | "sink", 89 | "elephant", 90 | "cake", 91 | "scissors", 92 | "airplane", 93 | "cat", 94 | "knife", 95 | ) 96 | 97 | def __init__( 98 | self, 99 | img_folder, 100 | ann_file, 101 | transforms, 102 | return_masks, 103 | cache_mode=False, 104 | local_rank=0, 105 | local_size=1, 106 | label_map=False, 107 | ): 108 | super(CocoDetection, self).__init__( 109 | img_folder, 110 | ann_file, 111 | cache_mode=cache_mode, 112 | local_rank=local_rank, 113 | local_size=local_size, 114 | ) 115 | self._transforms = transforms 116 | self.cat_ids = self.coco.getCatIds(self.SEEN_CLASSES + self.UNSEEN_CLASSES) 117 | self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)} 118 | self.cat_ids_unseen = self.coco.getCatIds(self.UNSEEN_CLASSES) 119 | self.prepare = ConvertCocoPolysToMask( 120 | return_masks, self.cat2label, label_map, self.cat_ids_unseen 121 | ) 122 | 123 | def __getitem__(self, idx): 124 | img, target = super(CocoDetection, self).__getitem__(idx) 125 | image_id = self.ids[idx] 126 | target = {"image_id": image_id, "annotations": target} 127 | img, target = self.prepare(img, target) 128 | if self._transforms is not None: 129 | img, target = self._transforms(img, target) 130 | if len(target["labels"]) == 0: 131 | return self[(idx + 1) % len(self)] 132 | else: 133 | return img, target 134 | return img, target 135 | 136 | 137 | def convert_coco_poly_to_mask(segmentations, height, width): 138 | masks = [] 139 | for polygons in segmentations: 140 | rles = coco_mask.frPyObjects(polygons, height, width) 141 | mask = coco_mask.decode(rles) 142 | if len(mask.shape) < 3: 143 | mask = mask[..., None] 144 | mask = torch.as_tensor(mask, dtype=torch.uint8) 145 | mask = mask.any(dim=2) 146 | masks.append(mask) 147 | if masks: 148 | masks = torch.stack(masks, dim=0) 149 | else: 150 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 151 | return masks 152 | 153 | 154 | class ConvertCocoPolysToMask(object): 155 | def __init__(self, return_masks=False, cat2label=None, label_map=False, cat_ids_unseen=None): 156 | self.return_masks = return_masks 157 | self.cat2label = cat2label 158 | self.label_map = label_map 159 | 
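        # COCO category ids of the unseen / novel classes, as passed in from
        # CocoDetection (computed there from UNSEEN_CLASSES)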
self.cat_ids_unseen = cat_ids_unseen 160 | 161 | def __call__(self, image, target): 162 | w, h = image.size 163 | 164 | image_id = target["image_id"] 165 | image_id = torch.tensor([image_id]) 166 | 167 | anno = target["annotations"] 168 | 169 | anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] 170 | 171 | boxes = [obj["bbox"] for obj in anno] 172 | # guard against no boxes via resizing 173 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 174 | boxes[:, 2:] += boxes[:, :2] 175 | boxes[:, 0::2].clamp_(min=0, max=w) 176 | boxes[:, 1::2].clamp_(min=0, max=h) 177 | 178 | if self.label_map: 179 | classes = [ 180 | self.cat2label[obj["category_id"]] 181 | if obj["category_id"] >= 0 182 | else obj["category_id"] 183 | for obj in anno 184 | ] 185 | else: 186 | classes = [obj["category_id"] for obj in anno] 187 | classes = torch.tensor(classes, dtype=torch.int64) 188 | 189 | if self.return_masks: 190 | segmentations = [obj["segmentation"] for obj in anno] 191 | masks = convert_coco_poly_to_mask(segmentations, h, w) 192 | 193 | keypoints = None 194 | if anno and "keypoints" in anno[0]: 195 | keypoints = [obj["keypoints"] for obj in anno] 196 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 197 | num_keypoints = keypoints.shape[0] 198 | if num_keypoints: 199 | keypoints = keypoints.view(num_keypoints, -1, 3) 200 | 201 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 202 | boxes = boxes[keep] 203 | classes = classes[keep] 204 | if self.return_masks: 205 | masks = masks[keep] 206 | if keypoints is not None: 207 | keypoints = keypoints[keep] 208 | 209 | target = {} 210 | target["boxes"] = boxes 211 | target["labels"] = classes 212 | if self.return_masks: 213 | target["masks"] = masks 214 | target["image_id"] = image_id 215 | if keypoints is not None: 216 | target["keypoints"] = keypoints 217 | 218 | # for conversion to coco api 219 | area = torch.tensor([obj["area"] for obj in anno]) 220 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 221 | target["area"] = area[keep] 222 | target["iscrowd"] = iscrowd[keep] 223 | 224 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 225 | target["size"] = torch.as_tensor([int(h), int(w)]) 226 | 227 | return image, target 228 | 229 | 230 | def make_coco_transforms(image_set): 231 | 232 | normalize = T.Compose([T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 233 | 234 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 235 | 236 | if image_set == "train": 237 | return T.Compose( 238 | [ 239 | T.RandomHorizontalFlip(), 240 | T.RandomSelect( 241 | T.RandomResize(scales, max_size=1333), 242 | T.Compose( 243 | [ 244 | T.RandomResize([400, 500, 600]), 245 | T.RandomSizeCrop(384, 600), 246 | T.RandomResize(scales, max_size=1333), 247 | ] 248 | ), 249 | ), 250 | normalize, 251 | ] 252 | ) 253 | 254 | if image_set == "val": 255 | return T.Compose( 256 | [ 257 | T.RandomResize([800], max_size=1333), 258 | normalize, 259 | ] 260 | ) 261 | 262 | raise ValueError(f"unknown {image_set}") 263 | 264 | 265 | def build(image_set, args): 266 | root = Path(args.coco_path) 267 | assert root.exists(), f"provided COCO path {root} does not exist" 268 | mode = "instances" 269 | PATHS = { 270 | "train": ( 271 | root / "train2017", 272 | root / "zero-shot" / f"{mode}_train2017_seen_2_proposal.json", 273 | ), 274 | "val": (root / "val2017", root / "zero-shot" / f"{mode}_val2017_all_2.json"), 275 | } 276 | 277 | img_folder, ann_file = 
PATHS[image_set] 278 | dataset = CocoDetection( 279 | img_folder, 280 | ann_file, 281 | transforms=make_coco_transforms(image_set), 282 | return_masks=args.masks, 283 | cache_mode=args.cache_mode, 284 | local_rank=get_local_rank(), 285 | local_size=get_local_size(), 286 | label_map=args.label_map, 287 | ) 288 | return dataset 289 | -------------------------------------------------------------------------------- /ovdetr/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | 9 | """ 10 | Transforms and data augmentation for both image + bbox. 11 | """ 12 | import random 13 | 14 | import PIL 15 | import torch 16 | import torchvision.transforms as T 17 | import torchvision.transforms.functional as F 18 | 19 | from util.box_ops import box_xyxy_to_cxcywh 20 | from util.misc import interpolate 21 | 22 | 23 | def crop(image, target, region): 24 | cropped_image = F.crop(image, *region) 25 | 26 | target = target.copy() 27 | i, j, h, w = region 28 | 29 | # should we do something wrt the original size? 30 | target["size"] = torch.tensor([h, w]) 31 | 32 | fields = ["labels", "area", "iscrowd"] 33 | 34 | if "clip_image" in target: 35 | fields.append("clip_image") 36 | 37 | if "boxes" in target: 38 | boxes = target["boxes"] 39 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 40 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 41 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 42 | cropped_boxes = cropped_boxes.clamp(min=0) 43 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 44 | target["boxes"] = cropped_boxes.reshape(-1, 4) 45 | target["area"] = area 46 | fields.append("boxes") 47 | 48 | if "masks" in target: 49 | # FIXME should we update the area here if there are no boxes? 
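        # crop the instance masks to the same (i, j, h, w) region as the image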
50 | target["masks"] = target["masks"][:, i : i + h, j : j + w] 51 | fields.append("masks") 52 | 53 | # remove elements for which the boxes or masks that have zero area 54 | if "boxes" in target or "masks" in target: 55 | # favor boxes selection when defining which elements to keep 56 | # this is compatible with previous implementation 57 | if "boxes" in target: 58 | cropped_boxes = target["boxes"].reshape(-1, 2, 2) 59 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 60 | else: 61 | keep = target["masks"].flatten(1).any(1) 62 | 63 | for field in fields: 64 | target[field] = target[field][keep] 65 | 66 | return cropped_image, target 67 | 68 | 69 | def hflip(image, target): 70 | flipped_image = F.hflip(image) 71 | 72 | w, h = image.size 73 | 74 | target = target.copy() 75 | if "boxes" in target: 76 | boxes = target["boxes"] 77 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor( 78 | [w, 0, w, 0] 79 | ) 80 | target["boxes"] = boxes 81 | 82 | if "masks" in target: 83 | target["masks"] = target["masks"].flip(-1) 84 | 85 | return flipped_image, target 86 | 87 | 88 | def resize(image, target, size, max_size=None): 89 | # size can be min_size (scalar) or (w, h) tuple 90 | 91 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 92 | w, h = image_size 93 | if max_size is not None: 94 | min_original_size = float(min((w, h))) 95 | max_original_size = float(max((w, h))) 96 | if max_original_size / min_original_size * size > max_size: 97 | size = int(round(max_size * min_original_size / max_original_size)) 98 | 99 | if (w <= h and w == size) or (h <= w and h == size): 100 | return (h, w) 101 | 102 | if w < h: 103 | ow = size 104 | oh = int(size * h / w) 105 | else: 106 | oh = size 107 | ow = int(size * w / h) 108 | 109 | return (oh, ow) 110 | 111 | def get_size(image_size, size, max_size=None): 112 | if isinstance(size, (list, tuple)): 113 | return size[::-1] 114 | else: 115 | return get_size_with_aspect_ratio(image_size, size, max_size) 116 | 117 | size = get_size(image.size, size, max_size) 118 | rescaled_image = F.resize(image, size) 119 | 120 | if target is None: 121 | return rescaled_image, None 122 | 123 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 124 | ratio_width, ratio_height = ratios 125 | 126 | target = target.copy() 127 | if "boxes" in target: 128 | boxes = target["boxes"] 129 | scaled_boxes = boxes * torch.as_tensor( 130 | [ratio_width, ratio_height, ratio_width, ratio_height] 131 | ) 132 | target["boxes"] = scaled_boxes 133 | 134 | if "area" in target: 135 | area = target["area"] 136 | scaled_area = area * (ratio_width * ratio_height) 137 | target["area"] = scaled_area 138 | 139 | h, w = size 140 | target["size"] = torch.tensor([h, w]) 141 | 142 | if "masks" in target: 143 | target["masks"] = ( 144 | interpolate(target["masks"][:, None].float(), size, mode="nearest")[:, 0] > 0.5 145 | ) 146 | 147 | return rescaled_image, target 148 | 149 | 150 | def pad(image, target, padding): 151 | # assumes that we only pad on the bottom right corners 152 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 153 | if target is None: 154 | return padded_image, None 155 | target = target.copy() 156 | # should we do something wrt the original size? 
157 | target["size"] = torch.tensor(padded_image[::-1]) 158 | if "masks" in target: 159 | target["masks"] = torch.nn.functional.pad(target["masks"], (0, padding[0], 0, padding[1])) 160 | return padded_image, target 161 | 162 | 163 | class RandomCrop(object): 164 | def __init__(self, size): 165 | self.size = size 166 | 167 | def __call__(self, img, target): 168 | region = T.RandomCrop.get_params(img, self.size) 169 | return crop(img, target, region) 170 | 171 | 172 | class RandomSizeCrop(object): 173 | def __init__(self, min_size: int, max_size: int): 174 | self.min_size = min_size 175 | self.max_size = max_size 176 | 177 | def __call__(self, img: PIL.Image.Image, target: dict): 178 | w = random.randint(self.min_size, min(img.width, self.max_size)) 179 | h = random.randint(self.min_size, min(img.height, self.max_size)) 180 | region = T.RandomCrop.get_params(img, [h, w]) 181 | return crop(img, target, region) 182 | 183 | 184 | class CenterCrop(object): 185 | def __init__(self, size): 186 | self.size = size 187 | 188 | def __call__(self, img, target): 189 | image_width, image_height = img.size 190 | crop_height, crop_width = self.size 191 | crop_top = int(round((image_height - crop_height) / 2.0)) 192 | crop_left = int(round((image_width - crop_width) / 2.0)) 193 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 194 | 195 | 196 | class RandomHorizontalFlip(object): 197 | def __init__(self, p=0.5): 198 | self.p = p 199 | 200 | def __call__(self, img, target): 201 | if random.random() < self.p: 202 | return hflip(img, target) 203 | return img, target 204 | 205 | 206 | class RandomResize(object): 207 | def __init__(self, sizes, max_size=None): 208 | assert isinstance(sizes, (list, tuple)) 209 | self.sizes = sizes 210 | self.max_size = max_size 211 | 212 | def __call__(self, img, target=None): 213 | size = random.choice(self.sizes) 214 | return resize(img, target, size, self.max_size) 215 | 216 | 217 | class RandomPad(object): 218 | def __init__(self, max_pad): 219 | self.max_pad = max_pad 220 | 221 | def __call__(self, img, target): 222 | pad_x = random.randint(0, self.max_pad) 223 | pad_y = random.randint(0, self.max_pad) 224 | return pad(img, target, (pad_x, pad_y)) 225 | 226 | 227 | class RandomSelect(object): 228 | """ 229 | Randomly selects between transforms1 and transforms2, 230 | with probability p for transforms1 and (1 - p) for transforms2 231 | """ 232 | 233 | def __init__(self, transforms1, transforms2, p=0.5): 234 | self.transforms1 = transforms1 235 | self.transforms2 = transforms2 236 | self.p = p 237 | 238 | def __call__(self, img, target): 239 | if random.random() < self.p: 240 | return self.transforms1(img, target) 241 | return self.transforms2(img, target) 242 | 243 | 244 | class ToTensor(object): 245 | def __call__(self, img, target): 246 | return F.to_tensor(img), target 247 | 248 | 249 | class RandomErasing(object): 250 | def __init__(self, *args, **kwargs): 251 | self.eraser = T.RandomErasing(*args, **kwargs) 252 | 253 | def __call__(self, img, target): 254 | return self.eraser(img), target 255 | 256 | 257 | class Normalize(object): 258 | def __init__(self, mean, std): 259 | self.mean = mean 260 | self.std = std 261 | 262 | def __call__(self, image, target=None): 263 | image = F.normalize(image, mean=self.mean, std=self.std) 264 | if target is None: 265 | return image, None 266 | target = target.copy() 267 | h, w = image.shape[-2:] 268 | if "boxes" in target: 269 | boxes = target["boxes"] 270 | boxes = box_xyxy_to_cxcywh(boxes) 271 | boxes = boxes 
/ torch.tensor([w, h, w, h], dtype=torch.float32) 272 | target["boxes"] = boxes 273 | return image, target 274 | 275 | 276 | class Compose(object): 277 | def __init__(self, transforms): 278 | self.transforms = transforms 279 | 280 | def __call__(self, image, target): 281 | for t in self.transforms: 282 | image, target = t(image, target) 283 | return image, target 284 | 285 | def __repr__(self): 286 | format_string = self.__class__.__name__ + "(" 287 | for t in self.transforms: 288 | format_string += "\n" 289 | format_string += " {0}".format(t) 290 | format_string += "\n)" 291 | return format_string 292 | -------------------------------------------------------------------------------- /ovdetr/engine_ov.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # OV DETR 3 | # Copyright (c) S-LAB, Nanyang Technological University. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | # Modified from DETR (https://github.com/facebookresearch/detr) 9 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 10 | # ------------------------------------------------------------------------ 11 | 12 | 13 | import math 14 | import sys 15 | from typing import Iterable 16 | 17 | import numpy as np 18 | import pycocotools.mask as mask_util 19 | import torch 20 | from torch.cuda.amp import GradScaler, autocast 21 | 22 | import util.misc as utils 23 | from datasets.coco_eval import CocoEvaluator, convert_to_xywh 24 | from datasets.data_prefetcher import data_prefetcher 25 | 26 | 27 | def train_one_epoch( 28 | model: torch.nn.Module, 29 | criterion: torch.nn.Module, 30 | data_loader: Iterable, 31 | optimizer: torch.optim.Optimizer, 32 | device: torch.device, 33 | epoch: int, 34 | max_norm: float = 0, 35 | masks: bool = False, 36 | amp: bool = False, 37 | ): 38 | model.train() 39 | criterion.train() 40 | metric_logger = utils.MetricLogger(delimiter=" ") 41 | metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) 42 | metric_logger.add_meter("grad_norm", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")) 43 | header = "Epoch: [{}]".format(epoch) 44 | print_freq = 10 45 | 46 | scaler = GradScaler() 47 | 48 | prefetcher = data_prefetcher(data_loader, device, prefetch=True) 49 | samples, targets = prefetcher.next() 50 | 51 | # for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 52 | for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header): 53 | with autocast(enabled=amp): 54 | if not masks: 55 | outputs = model(samples, targets) 56 | else: 57 | outputs = model(samples, targets, criterion) 58 | loss_dict = criterion(outputs, targets) 59 | weight_dict = criterion.weight_dict 60 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 61 | 62 | # reduce losses over all GPUs for logging purposes 63 | loss_dict_reduced = utils.reduce_dict(loss_dict) 64 | loss_dict_reduced_unscaled = {f"{k}_unscaled": v for k, v in loss_dict_reduced.items()} 65 | loss_dict_reduced_scaled = { 66 | k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict 67 | } 68 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 69 | 70 | loss_value = 
losses_reduced_scaled.item() 71 | 72 | if not math.isfinite(loss_value): 73 | print("Loss is {}, stopping training".format(loss_value)) 74 | print(loss_dict_reduced) 75 | sys.exit(1) 76 | 77 | if amp: 78 | optimizer.zero_grad() 79 | scaler.scale(losses).backward() 80 | if max_norm > 0: 81 | scaler.unscale_(optimizer) 82 | grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 83 | else: 84 | grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) 85 | scaler.step(optimizer) 86 | scaler.update() 87 | else: 88 | optimizer.zero_grad() 89 | losses.backward() 90 | if max_norm > 0: 91 | grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 92 | else: 93 | grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) 94 | optimizer.step() 95 | 96 | metric_logger.update( 97 | loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled 98 | ) 99 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 100 | metric_logger.update(grad_norm=grad_total_norm) 101 | 102 | samples, targets = prefetcher.next() 103 | # gather the stats from all processes 104 | metric_logger.synchronize_between_processes() 105 | print("Averaged stats:", metric_logger) 106 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 107 | 108 | 109 | @torch.no_grad() 110 | def evaluate( 111 | model, criterion, postprocessors, data_loader, base_ds, device, output_dir, label_map, amp 112 | ): 113 | model.eval() 114 | criterion.eval() 115 | 116 | metric_logger = utils.MetricLogger(delimiter=" ") 117 | header = "Test:" 118 | 119 | iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys()) 120 | coco_evaluator = CocoEvaluator(base_ds, iou_types, data_loader.dataset.cat2label, label_map) 121 | 122 | for samples, targets in metric_logger.log_every(data_loader, 10, header): 123 | samples = samples.to(device) 124 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 125 | 126 | with autocast(enabled=amp): 127 | outputs = model(samples) 128 | 129 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 130 | results, topk_boxes = postprocessors["bbox"](outputs, orig_target_sizes) 131 | if "segm" in postprocessors.keys(): 132 | target_sizes = torch.stack([t["size"] for t in targets], dim=0) 133 | outputs_masks = outputs["pred_masks"].squeeze(2) 134 | 135 | bs = len(topk_boxes) 136 | outputs_masks_new = [[] for _ in range(bs)] 137 | for b in range(bs): 138 | for index in topk_boxes[b]: 139 | outputs_masks_new[b].append(outputs_masks[b : b + 1, index : index + 1, :, :]) 140 | for b in range(bs): 141 | outputs_masks_new[b] = torch.cat(outputs_masks_new[b], 1) 142 | outputs["pred_masks"] = torch.cat(outputs_masks_new, 0) 143 | 144 | results = postprocessors["segm"](results, outputs, orig_target_sizes, target_sizes) 145 | res = {target["image_id"].item(): output for target, output in zip(targets, results)} 146 | 147 | if coco_evaluator is not None: 148 | coco_evaluator.update(res) 149 | 150 | # gather the stats from all processes 151 | metric_logger.synchronize_between_processes() 152 | print("Averaged stats:", metric_logger) 153 | if coco_evaluator is not None: 154 | coco_evaluator.synchronize_between_processes() 155 | 156 | # accumulate predictions from all images 157 | if coco_evaluator is not None: 158 | coco_evaluator.accumulate() 159 | coco_evaluator.summarize() 160 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 161 | if coco_evaluator is not 
None: 162 | if "bbox" in postprocessors.keys(): 163 | stats["coco_eval_bbox"] = coco_evaluator.coco_eval["bbox"].stats.tolist() 164 | if "segm" in postprocessors.keys(): 165 | stats["coco_eval_masks"] = coco_evaluator.coco_eval["segm"].stats.tolist() 166 | return stats, coco_evaluator 167 | 168 | 169 | @torch.no_grad() 170 | def lvis_evaluate( 171 | model, criterion, postprocessors, data_loader, base_ds, device, output_dir, label_map, amp 172 | ): 173 | model.eval() 174 | criterion.eval() 175 | 176 | metric_logger = utils.MetricLogger(delimiter=" ") 177 | header = "Test:" 178 | 179 | iou_types = tuple(k for k in ("segm", "bbox") if k in postprocessors.keys()) 180 | lvis_results = [] 181 | 182 | cat2label = data_loader.dataset.cat2label 183 | label2cat = {v: k for k, v in cat2label.items()} 184 | 185 | for samples, targets in metric_logger.log_every(data_loader, 10, header): 186 | samples = samples.to(device) 187 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 188 | 189 | with autocast(enabled=amp): 190 | outputs = model(samples) 191 | 192 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 193 | results, topk_boxes = postprocessors["bbox"](outputs, orig_target_sizes) 194 | if "segm" in postprocessors.keys(): 195 | target_sizes = torch.stack([t["size"] for t in targets], dim=0) 196 | outputs_masks = outputs["pred_masks"].squeeze(2) 197 | 198 | bs = len(topk_boxes) 199 | outputs_masks_new = [[] for _ in range(bs)] 200 | for b in range(bs): 201 | for index in topk_boxes[b]: 202 | outputs_masks_new[b].append(outputs_masks[b : b + 1, index : index + 1, :, :]) 203 | for b in range(bs): 204 | outputs_masks_new[b] = torch.cat(outputs_masks_new[b], 1) 205 | outputs["pred_masks"] = torch.cat(outputs_masks_new, 0) 206 | 207 | results = postprocessors["segm"](results, outputs, orig_target_sizes, target_sizes) 208 | 209 | for target, output in zip(targets, results): 210 | image_id = target["image_id"].item() 211 | 212 | if "masks" in output.keys(): 213 | masks = output["masks"].data.cpu().numpy() 214 | masks = masks > 0.5 215 | rles = [ 216 | mask_util.encode( 217 | np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F") 218 | )[0] 219 | for mask in masks 220 | ] 221 | for rle in rles: 222 | rle["counts"] = rle["counts"].decode("utf-8") 223 | 224 | boxes = convert_to_xywh(output["boxes"]) 225 | for ind in range(len(output["scores"])): 226 | temp = { 227 | "image_id": image_id, 228 | "score": output["scores"][ind].item(), 229 | "category_id": output["labels"][ind].item(), 230 | "bbox": boxes[ind].tolist(), 231 | } 232 | if label_map: 233 | temp["category_id"] = label2cat[temp["category_id"]] 234 | if "masks" in output.keys(): 235 | temp["segmentation"] = rles[ind] 236 | 237 | lvis_results.append(temp) 238 | 239 | rank = torch.distributed.get_rank() 240 | torch.save(lvis_results, output_dir + f"/pred_{rank}.pth") 241 | 242 | # gather the stats from all processes 243 | metric_logger.synchronize_between_processes() 244 | torch.distributed.barrier() 245 | if rank == 0: 246 | world_size = torch.distributed.get_world_size() 247 | for i in range(1, world_size): 248 | temp = torch.load(output_dir + f"/pred_{i}.pth") 249 | lvis_results += temp 250 | 251 | from lvis import LVISEval, LVISResults 252 | 253 | lvis_results = LVISResults(base_ds, lvis_results, max_dets=300) 254 | for iou_type in iou_types: 255 | lvis_eval = LVISEval(base_ds, lvis_results, iou_type) 256 | lvis_eval.run() 257 | lvis_eval.print_results() 258 | torch.distributed.barrier() 259 | stats = 
{k: meter.global_avg for k, meter in metric_logger.meters.items()} 260 | return stats, None 261 | -------------------------------------------------------------------------------- /ovdetr/datasets/coco_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | 9 | """ 10 | COCO evaluator that works in distributed mode. 11 | 12 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 13 | The difference is that there is less copy-pasting from pycocotools 14 | in the end of the file, as python3 can suppress prints with contextlib 15 | """ 16 | import contextlib 17 | import copy 18 | import os 19 | 20 | import numpy as np 21 | import pycocotools.mask as mask_util 22 | import torch 23 | from pycocotools.coco import COCO 24 | from pycocotools.cocoeval import COCOeval 25 | 26 | from util.misc import all_gather 27 | 28 | 29 | class CocoEvaluator(object): 30 | def __init__(self, coco_gt, iou_types, cat2label, label_map=False): 31 | assert isinstance(iou_types, (list, tuple)) 32 | coco_gt = copy.deepcopy(coco_gt) 33 | self.coco_gt = coco_gt 34 | 35 | self.iou_types = iou_types 36 | self.coco_eval = {} 37 | for iou_type in iou_types: 38 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 39 | self.coco_eval[iou_type].useCats = True 40 | 41 | self.img_ids = [] 42 | self.eval_imgs = {k: [] for k in iou_types} 43 | 44 | self.label2cat = {v: k for k, v in cat2label.items()} 45 | self.label_map = label_map 46 | if label_map: 47 | self.unseen_list = [4, 5, 11, 12, 15, 16, 21, 23, 27, 29, 32, 34, 45, 47, 54, 58, 63] 48 | else: 49 | self.unseen_list = [4, 5, 16, 17, 20, 21, 27, 31, 35, 40, 46, 48, 60, 62, 75, 80, 86] 50 | 51 | def update(self, predictions): 52 | img_ids = list(np.unique(list(predictions.keys()))) 53 | self.img_ids.extend(img_ids) 54 | 55 | for iou_type in self.iou_types: 56 | results = self.prepare(predictions, iou_type) 57 | 58 | # suppress pycocotools prints 59 | with open(os.devnull, "w") as devnull: 60 | with contextlib.redirect_stdout(devnull): 61 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 62 | coco_eval = self.coco_eval[iou_type] 63 | 64 | coco_eval.cocoDt = coco_dt 65 | coco_eval.params.imgIds = list(img_ids) 66 | coco_eval.params.useCats = True 67 | img_ids, eval_imgs = evaluate(coco_eval) 68 | 69 | self.eval_imgs[iou_type].append(eval_imgs) 70 | 71 | def synchronize_between_processes(self): 72 | for iou_type in self.iou_types: 73 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 74 | create_common_coco_eval( 75 | self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type] 76 | ) 77 | 78 | def accumulate(self): 79 | for coco_eval in self.coco_eval.values(): 80 | coco_eval.accumulate() 81 | 82 | def summarize(self): 83 | for iou_type, coco_eval in self.coco_eval.items(): 84 | print("IoU metric: {}".format(iou_type)) 85 | coco_eval.summarize() 86 | 87 | precisions = self.coco_eval[iou_type].eval["precision"] 88 | 89 | results_seen = [] 90 | results_unseen = [] 91 | for idx in 
range(precisions.shape[-3]): 92 | # area range index 0: all area ranges 93 | # max dets index -1: typically 100 per image 94 | precision = precisions[0, :, idx, 0, -1] 95 | precision = precision[precision > -1] 96 | if precision.size: 97 | ap = np.mean(precision) 98 | # print(f"AP {idx}: {ap}") 99 | if idx not in self.unseen_list: 100 | results_seen.append(float(ap * 100)) 101 | else: 102 | results_unseen.append(float(ap * 100)) 103 | print(f"{iou_type} AP seen: {np.mean(results_seen)}") 104 | print(f"{iou_type} AP unseen: {np.mean(results_unseen)}") 105 | 106 | def prepare(self, predictions, iou_type): 107 | if iou_type == "bbox": 108 | return self.prepare_for_coco_detection(predictions) 109 | elif iou_type == "segm": 110 | return self.prepare_for_coco_segmentation(predictions) 111 | elif iou_type == "keypoints": 112 | return self.prepare_for_coco_keypoint(predictions) 113 | else: 114 | raise ValueError("Unknown iou type {}".format(iou_type)) 115 | 116 | def prepare_for_coco_detection(self, predictions): 117 | coco_results = [] 118 | for original_id, prediction in predictions.items(): 119 | if len(prediction) == 0: 120 | continue 121 | 122 | boxes = prediction["boxes"] 123 | boxes = convert_to_xywh(boxes).tolist() 124 | scores = prediction["scores"].tolist() 125 | labels = prediction["labels"].tolist() 126 | 127 | coco_results.extend( 128 | [ 129 | { 130 | "image_id": original_id, 131 | "category_id": self.label2cat[labels[k]] if self.label_map else labels[k], 132 | "bbox": box, 133 | "score": scores[k], 134 | } 135 | for k, box in enumerate(boxes) 136 | ] 137 | ) 138 | return coco_results 139 | 140 | def prepare_for_coco_segmentation(self, predictions): 141 | coco_results = [] 142 | for original_id, prediction in predictions.items(): 143 | if len(prediction) == 0: 144 | continue 145 | 146 | scores = prediction["scores"] 147 | labels = prediction["labels"] 148 | masks = prediction["masks"].cpu() 149 | 150 | masks = masks > 0.5 151 | 152 | scores = prediction["scores"].tolist() 153 | labels = prediction["labels"].tolist() 154 | 155 | rles = [ 156 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 157 | for mask in masks 158 | ] 159 | for rle in rles: 160 | rle["counts"] = rle["counts"].decode("utf-8") 161 | 162 | coco_results.extend( 163 | [ 164 | { 165 | "image_id": original_id, 166 | "category_id": self.label2cat[labels[k]] if self.label_map else labels[k], 167 | "segmentation": rle, 168 | "score": scores[k], 169 | } 170 | for k, rle in enumerate(rles) 171 | ] 172 | ) 173 | return coco_results 174 | 175 | def prepare_for_coco_keypoint(self, predictions): 176 | coco_results = [] 177 | for original_id, prediction in predictions.items(): 178 | if len(prediction) == 0: 179 | continue 180 | 181 | boxes = prediction["boxes"] 182 | boxes = convert_to_xywh(boxes).tolist() 183 | scores = prediction["scores"].tolist() 184 | labels = prediction["labels"].tolist() 185 | keypoints = prediction["keypoints"] 186 | keypoints = keypoints.flatten(start_dim=1).tolist() 187 | 188 | coco_results.extend( 189 | [ 190 | { 191 | "image_id": original_id, 192 | "category_id": self.label2cat[labels[k]] if self.label_map else labels[k], 193 | "keypoints": keypoint, 194 | "score": scores[k], 195 | } 196 | for k, keypoint in enumerate(keypoints) 197 | ] 198 | ) 199 | return coco_results 200 | 201 | 202 | def convert_to_xywh(boxes): 203 | xmin, ymin, xmax, ymax = boxes.unbind(1) 204 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 205 | 206 | 207 | def 
merge(img_ids, eval_imgs): 208 | all_img_ids = all_gather(img_ids) 209 | all_eval_imgs = all_gather(eval_imgs) 210 | 211 | merged_img_ids = [] 212 | for p in all_img_ids: 213 | merged_img_ids.extend(p) 214 | 215 | merged_eval_imgs = [] 216 | for p in all_eval_imgs: 217 | merged_eval_imgs.append(p) 218 | 219 | merged_img_ids = np.array(merged_img_ids) 220 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 221 | 222 | # keep only unique (and in sorted order) images 223 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 224 | merged_eval_imgs = merged_eval_imgs[..., idx] 225 | 226 | return merged_img_ids, merged_eval_imgs 227 | 228 | 229 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 230 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 231 | img_ids = list(img_ids) 232 | eval_imgs = list(eval_imgs.flatten()) 233 | 234 | coco_eval.evalImgs = eval_imgs 235 | coco_eval.params.imgIds = img_ids 236 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 237 | 238 | 239 | ################################################################# 240 | # From pycocotools, just removed the prints and fixed 241 | # a Python3 bug about unicode not defined 242 | ################################################################# 243 | 244 | 245 | def evaluate(self): 246 | """ 247 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 248 | :return: None 249 | """ 250 | # tic = time.time() 251 | # print('Running per image evaluation...') 252 | p = self.params 253 | # add backward compatibility if useSegm is specified in params 254 | if p.useSegm is not None: 255 | p.iouType = "segm" if p.useSegm == 1 else "bbox" 256 | print("useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType)) 257 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 258 | p.imgIds = list(np.unique(p.imgIds)) 259 | if p.useCats: 260 | p.catIds = list(np.unique(p.catIds)) 261 | p.maxDets = sorted(p.maxDets) 262 | self.params = p 263 | 264 | self._prepare() 265 | # loop through images, area range, max detection number 266 | catIds = p.catIds if p.useCats else [-1] 267 | 268 | if p.iouType == "segm" or p.iouType == "bbox": 269 | computeIoU = self.computeIoU 270 | elif p.iouType == "keypoints": 271 | computeIoU = self.computeOks 272 | self.ious = {(imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds} 273 | 274 | evaluateImg = self.evaluateImg 275 | maxDet = p.maxDets[-1] 276 | evalImgs = [ 277 | evaluateImg(imgId, catId, areaRng, maxDet) 278 | for catId in catIds 279 | for areaRng in p.areaRng 280 | for imgId in p.imgIds 281 | ] 282 | # this is NOT in the pycocotools code, but could be done outside 283 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 284 | self._paramsEval = copy.deepcopy(self.params) 285 | # toc = time.time() 286 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 287 | return p.imgIds, evalImgs 288 | 289 | 290 | ################################################################# 291 | # end of straight copy from pycocotools, just removing the prints 292 | ################################################################# 293 | -------------------------------------------------------------------------------- /ovdetr/models/segmentation.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR 3 | # Copyright (c) 2020 SenseTime. 
All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | This file provides the definition of the convolutional heads used to predict masks, as well as the losses 12 | """ 13 | from typing import List, Optional 14 | 15 | import torch 16 | import torch.nn as nn 17 | import torch.nn.functional as F 18 | from torch import Tensor 19 | 20 | from util.misc import NestedTensor, inverse_sigmoid, nested_tensor_from_tensor_list 21 | 22 | 23 | class DETRsegm(nn.Module): 24 | def __init__(self, detr, freeze_detr=False): 25 | super().__init__() 26 | self.detr = detr 27 | 28 | if freeze_detr: 29 | for p in self.parameters(): 30 | p.requires_grad_(False) 31 | 32 | hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead 33 | self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) 34 | self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 512], hidden_dim) 35 | 36 | def forward(self, samples: NestedTensor, targets=None, criterion=None): 37 | if self.training: 38 | return self.forward_train(samples, targets, criterion) 39 | else: 40 | return self.forward_test(samples) 41 | 42 | def forward_train(self, samples: NestedTensor, targets=None, criterion=None): 43 | with torch.no_grad(): 44 | if not isinstance(samples, NestedTensor): 45 | samples = nested_tensor_from_tensor_list(samples) 46 | features, pos = self.detr.backbone(samples) 47 | 48 | srcs = [] 49 | masks = [] 50 | for l, feat in enumerate(features): 51 | src, mask = feat.decompose() 52 | srcs.append(self.detr.input_proj[l](src)) 53 | masks.append(mask) 54 | assert mask is not None 55 | if self.detr.num_feature_levels > len(srcs): 56 | _len_srcs = len(srcs) 57 | for l in range(_len_srcs, self.detr.num_feature_levels): 58 | if l == _len_srcs: 59 | src = self.detr.input_proj[l](features[-1].tensors) 60 | else: 61 | src = self.detr.input_proj[l](srcs[-1]) 62 | m = samples.mask 63 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] 64 | pos_l = self.detr.backbone[1](NestedTensor(src, mask)).to(src.dtype) 65 | srcs.append(src) 66 | masks.append(mask) 67 | pos.append(pos_l) 68 | 69 | max_len = 20 70 | uniq_labels = torch.cat([t["labels"] for t in targets]) 71 | uniq_labels = torch.unique(uniq_labels).to("cpu") 72 | uniq_labels = uniq_labels[torch.randperm(len(uniq_labels))][:max_len] 73 | select_id = uniq_labels.tolist() 74 | 75 | clip_query = self.detr.zeroshot_w[:, select_id].t() 76 | clip_query = self.detr.patch2query(clip_query) 77 | 78 | query_embeds = None 79 | if not self.detr.two_stage: 80 | query_embeds = self.detr.query_embed.weight 81 | ( 82 | hs, 83 | init_reference, 84 | inter_references, 85 | enc_outputs_class, 86 | enc_outputs_coord_unact, 87 | _, 88 | ), memory = self.detr.transformer(srcs, masks, pos, query_embeds, text_query=clip_query) 89 | 90 | for lvl in [hs.shape[0] - 1]: 91 | if lvl == 0: 92 | reference = init_reference 93 | else: 94 | reference = inter_references[lvl - 1] 95 | reference = inverse_sigmoid(reference) 96 | outputs_class = self.detr.get_outputs_class(self.detr.class_embed[lvl], hs[lvl]) 97 | tmp = self.detr.bbox_embed[lvl](hs[lvl]) 98 | if reference.shape[-1] == 4: 99 | tmp += reference 
100 | else: 101 | assert reference.shape[-1] == 2 102 | tmp[..., :2] += reference 103 | outputs_coord = tmp.sigmoid() 104 | out = {"pred_logits": outputs_class, "pred_boxes": outputs_coord} 105 | 106 | # FIXME h_boxes takes the last one computed, keep this in mind 107 | indices = criterion.matcher(out, targets, select_id) 108 | src_idx = criterion._get_src_permutation_idx(indices) 109 | hs_select = hs[-1][src_idx[0], src_idx[1], :] 110 | 111 | bbox_mask = self.bbox_attention( 112 | hs_select[ 113 | None, 114 | ], 115 | memory[1], 116 | mask=masks[1], 117 | ) 118 | 119 | seg_masks = self.mask_head( 120 | srcs[1], bbox_mask, [features[1].tensors, features[0].tensors, features[0].tensors] 121 | ) 122 | bs = features[-1].tensors.shape[0] 123 | outputs_seg_masks = seg_masks.view( 124 | bs, len(src_idx[0]), seg_masks.shape[-2], seg_masks.shape[-1] 125 | ) 126 | out["pred_masks"] = outputs_seg_masks 127 | out["select_id"] = select_id 128 | 129 | return out 130 | 131 | def forward_test(self, samples: NestedTensor, targets=None, criterion=None): 132 | if not isinstance(samples, NestedTensor): 133 | samples = nested_tensor_from_tensor_list(samples) 134 | features, pos = self.detr.backbone(samples) 135 | 136 | srcs = [] 137 | masks = [] 138 | for l, feat in enumerate(features): 139 | src, mask = feat.decompose() 140 | srcs.append(self.detr.input_proj[l](src)) 141 | masks.append(mask) 142 | assert mask is not None 143 | if self.detr.num_feature_levels > len(srcs): 144 | _len_srcs = len(srcs) 145 | for l in range(_len_srcs, self.detr.num_feature_levels): 146 | if l == _len_srcs: 147 | src = self.detr.input_proj[l](features[-1].tensors) 148 | else: 149 | src = self.detr.input_proj[l](srcs[-1]) 150 | m = samples.mask 151 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] 152 | pos_l = self.detr.backbone[1](NestedTensor(src, mask)).to(src.dtype) 153 | srcs.append(src) 154 | masks.append(mask) 155 | pos.append(pos_l) 156 | 157 | select_id = list(range(self.detr.zeroshot_w_val.shape[-1])) 158 | query_embeds = None 159 | if not self.detr.two_stage: 160 | query_embeds = self.detr.query_embed.weight 161 | 162 | outputs_class_list = [] 163 | num_patch = 5 164 | bs = features[-1].tensors.shape[0] 165 | cache = None 166 | for c in range(len(select_id) // num_patch + 1): 167 | clip_query = self.detr.zeroshot_w_val[:, c * num_patch : (c + 1) * num_patch].t() 168 | clip_query = self.detr.patch2query(clip_query) 169 | ( 170 | hs, 171 | init_reference, 172 | inter_references, 173 | enc_outputs_class, 174 | enc_outputs_coord_unact, 175 | cache, 176 | ), memory = self.detr.transformer( 177 | srcs, masks, pos, query_embeds, text_query=clip_query, cache=cache 178 | ) 179 | 180 | outputs_classes = [] 181 | outputs_coords = [] 182 | for lvl in range(hs.shape[0]): 183 | if lvl == 0: 184 | reference = init_reference 185 | else: 186 | reference = inter_references[lvl - 1] 187 | reference = inverse_sigmoid(reference) 188 | outputs_class = self.detr.get_outputs_class(self.detr.class_embed[lvl], hs[lvl]) 189 | tmp = self.detr.bbox_embed[lvl](hs[lvl]) 190 | if reference.shape[-1] == 4: 191 | tmp += reference 192 | else: 193 | assert reference.shape[-1] == 2 194 | tmp[..., :2] += reference 195 | outputs_coord = tmp.sigmoid() 196 | outputs_classes.append(outputs_class) 197 | outputs_coords.append(outputs_coord) 198 | outputs_class = torch.stack(outputs_classes) 199 | outputs_class_list.append(outputs_class) 200 | 201 | outputs_class = torch.cat(outputs_class_list, -2) 202 | prob = 
outputs_class[-1].sigmoid() 203 | scores, topk_indexes = torch.topk(prob.view(outputs_class[-1].shape[0], -1), 100, dim=1) 204 | labels = torch.zeros_like(prob, dtype=torch.int16).flatten(1) 205 | num_queries = self.detr.num_queries 206 | for ind, c in enumerate(select_id): 207 | labels[:, ind * num_queries : (ind + 1) * num_queries] = c 208 | labels = torch.gather(labels, 1, topk_indexes) 209 | select_id = torch.unique(labels).tolist() 210 | 211 | outputs_class_list = [] 212 | outputs_coord_list = [] 213 | outputs_seg_masks_list = [] 214 | bs = features[-1].tensors.shape[0] 215 | cache = None 216 | for c in range(len(select_id) // num_patch + 1): 217 | select_c = select_id[c * num_patch : (c + 1) * num_patch] 218 | clip_query = self.detr.zeroshot_w_val[:, select_c].t() 219 | clip_query = self.detr.patch2query(clip_query) 220 | ( 221 | hs, 222 | init_reference, 223 | inter_references, 224 | enc_outputs_class, 225 | enc_outputs_coord_unact, 226 | cache, 227 | ), memory = self.detr.transformer( 228 | srcs, masks, pos, query_embeds, text_query=clip_query, cache=cache 229 | ) 230 | outputs_classes = [] 231 | outputs_coords = [] 232 | for lvl in range(hs.shape[0]): 233 | if lvl == 0: 234 | reference = init_reference 235 | else: 236 | reference = inter_references[lvl - 1] 237 | reference = inverse_sigmoid(reference) 238 | outputs_class = self.detr.get_outputs_class(self.detr.class_embed[lvl], hs[lvl]) 239 | tmp = self.detr.bbox_embed[lvl](hs[lvl]) 240 | if reference.shape[-1] == 4: 241 | tmp += reference 242 | else: 243 | assert reference.shape[-1] == 2 244 | tmp[..., :2] += reference 245 | outputs_coord = tmp.sigmoid() 246 | outputs_classes.append(outputs_class) 247 | outputs_coords.append(outputs_coord) 248 | outputs_class = torch.stack(outputs_classes) 249 | outputs_coord = torch.stack(outputs_coords) 250 | outputs_class_list.append(outputs_class) 251 | outputs_coord_list.append(outputs_coord) 252 | 253 | bbox_mask = self.bbox_attention(hs[-1], memory[1], mask=masks[1]) 254 | seg_masks = self.mask_head( 255 | srcs[1], bbox_mask, [features[1].tensors, features[0].tensors, features[0].tensors] 256 | ) 257 | outputs_seg_masks = seg_masks.view( 258 | bs, self.detr.num_queries * len(select_c), seg_masks.shape[-2], seg_masks.shape[-1] 259 | ) 260 | outputs_seg_masks_list.append(outputs_seg_masks) 261 | 262 | outputs_class = torch.cat(outputs_class_list, -2) 263 | outputs_coord = torch.cat(outputs_coord_list, -2) 264 | outputs_seg_masks = torch.cat(outputs_seg_masks_list, 1) 265 | 266 | out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 267 | out["pred_masks"] = outputs_seg_masks 268 | out["select_id"] = select_id 269 | 270 | del outputs_class_list, outputs_coord_list, outputs_seg_masks_list 271 | 272 | return out 273 | 274 | 275 | def _expand(tensor, length: int): 276 | return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) 277 | 278 | 279 | class MaskHeadSmallConv(nn.Module): 280 | """ 281 | Simple convolutional head, using group norm. 
282 | Upsampling is done using a FPN approach 283 | """ 284 | 285 | def __init__(self, dim, fpn_dims, context_dim): 286 | super().__init__() 287 | 288 | inter_dims = [ 289 | dim, 290 | context_dim // 2, 291 | context_dim // 4, 292 | context_dim // 8, 293 | context_dim // 16, 294 | context_dim // 64, 295 | ] 296 | self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) 297 | self.gn1 = torch.nn.GroupNorm(8, dim) 298 | self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) 299 | self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) 300 | self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) 301 | self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) 302 | self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) 303 | self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) 304 | self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) 305 | self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) 306 | self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) 307 | 308 | self.dim = dim 309 | 310 | self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) 311 | self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) 312 | self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) 313 | 314 | for m in self.modules(): 315 | if isinstance(m, nn.Conv2d): 316 | nn.init.kaiming_uniform_(m.weight, a=1) 317 | nn.init.constant_(m.bias, 0) 318 | 319 | def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]): 320 | x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) 321 | 322 | x = self.lay1(x) 323 | x = self.gn1(x) 324 | x = F.relu(x) 325 | x = self.lay2(x) 326 | x = self.gn2(x) 327 | x = F.relu(x) 328 | 329 | cur_fpn = self.adapter1(fpns[0]) 330 | if cur_fpn.size(0) != x.size(0): 331 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 332 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 333 | x = self.lay3(x) 334 | x = self.gn3(x) 335 | x = F.relu(x) 336 | 337 | cur_fpn = self.adapter2(fpns[1]) 338 | if cur_fpn.size(0) != x.size(0): 339 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 340 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 341 | x = self.lay4(x) 342 | x = self.gn4(x) 343 | x = F.relu(x) 344 | 345 | cur_fpn = self.adapter3(fpns[2]) 346 | if cur_fpn.size(0) != x.size(0): 347 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 348 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 349 | x = self.lay5(x) 350 | x = self.gn5(x) 351 | x = F.relu(x) 352 | 353 | x = self.out_lay(x) 354 | return x 355 | 356 | 357 | class MHAttentionMap(nn.Module): 358 | """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" 359 | 360 | def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True): 361 | super().__init__() 362 | self.num_heads = num_heads 363 | self.hidden_dim = hidden_dim 364 | self.dropout = nn.Dropout(dropout) 365 | 366 | self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 367 | self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 368 | 369 | nn.init.zeros_(self.k_linear.bias) 370 | nn.init.zeros_(self.q_linear.bias) 371 | nn.init.xavier_uniform_(self.k_linear.weight) 372 | nn.init.xavier_uniform_(self.q_linear.weight) 373 | self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 374 | 375 | def forward(self, q, k, mask: Optional[Tensor] = None): 376 | q = self.q_linear(q) 377 | k = F.conv2d(k, 
self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) 378 | qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) 379 | kh = k.view( 380 | k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1] 381 | ) 382 | weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) 383 | 384 | if mask is not None: 385 | weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) 386 | weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size()) 387 | weights = self.dropout(weights) 388 | return weights 389 | 390 | 391 | def dice_loss(inputs, targets, num_boxes): 392 | """ 393 | Compute the DICE loss, similar to generalized IOU for masks 394 | Args: 395 | inputs: A float tensor of arbitrary shape. 396 | The predictions for each example. 397 | targets: A float tensor with the same shape as inputs. Stores the binary 398 | classification label for each element in inputs 399 | (0 for the negative class and 1 for the positive class). 400 | """ 401 | inputs = inputs.sigmoid() 402 | inputs = inputs.flatten(1) 403 | numerator = 2 * (inputs * targets).sum(1) 404 | denominator = inputs.sum(-1) + targets.sum(-1) 405 | loss = 1 - (numerator + 1) / (denominator + 1) 406 | return loss.sum() / num_boxes 407 | 408 | 409 | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): 410 | """ 411 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 412 | Args: 413 | inputs: A float tensor of arbitrary shape. 414 | The predictions for each example. 415 | targets: A float tensor with the same shape as inputs. Stores the binary 416 | classification label for each element in inputs 417 | (0 for the negative class and 1 for the positive class). 418 | alpha: (optional) Weighting factor in range (0,1) to balance 419 | positive vs negative examples. Default = -1 (no weighting). 420 | gamma: Exponent of the modulating factor (1 - p_t) to 421 | balance easy vs hard examples. 
422 | Returns: 423 | Loss tensor 424 | """ 425 | prob = inputs.sigmoid() 426 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 427 | p_t = prob * targets + (1 - prob) * (1 - targets) 428 | loss = ce_loss * ((1 - p_t) ** gamma) 429 | 430 | if alpha >= 0: 431 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 432 | loss = alpha_t * loss 433 | 434 | return loss.mean(1).sum() / num_boxes 435 | -------------------------------------------------------------------------------- /ovdetr/models/model.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import math 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from util.misc import NestedTensor, inverse_sigmoid, nested_tensor_from_tensor_list 9 | 10 | 11 | def _get_clones(module, N): 12 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 13 | 14 | 15 | class MLP(nn.Module): 16 | """Very simple multi-layer perceptron (also called FFN)""" 17 | 18 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 19 | super().__init__() 20 | self.num_layers = num_layers 21 | h = [hidden_dim] * (num_layers - 1) 22 | self.layers = nn.ModuleList( 23 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 24 | ) 25 | 26 | def forward(self, x): 27 | for i, layer in enumerate(self.layers): 28 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 29 | return x 30 | 31 | 32 | class DeformableDETR(nn.Module): 33 | def __init__( 34 | self, 35 | backbone, 36 | transformer, 37 | num_classes, 38 | num_queries, 39 | num_feature_levels, 40 | aux_loss=True, 41 | with_box_refine=False, 42 | two_stage=False, 43 | cls_out_channels=91, 44 | ): 45 | super().__init__() 46 | self.num_queries = num_queries 47 | self.transformer = transformer 48 | hidden_dim = transformer.d_model 49 | self.class_embed = nn.Linear(hidden_dim, cls_out_channels) 50 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 51 | self.num_feature_levels = num_feature_levels 52 | if not two_stage: 53 | self.query_embed = nn.Embedding(num_queries, hidden_dim * 2) 54 | if num_feature_levels > 1: 55 | num_backbone_outs = len(backbone.strides) 56 | input_proj_list = [] 57 | for _ in range(num_backbone_outs): 58 | in_channels = backbone.num_channels[_] 59 | input_proj_list.append( 60 | nn.Sequential( 61 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1), 62 | nn.GroupNorm(32, hidden_dim), 63 | ) 64 | ) 65 | for _ in range(num_feature_levels - num_backbone_outs): 66 | input_proj_list.append( 67 | nn.Sequential( 68 | nn.Conv2d(in_channels, hidden_dim, kernel_size=3, stride=2, padding=1), 69 | nn.GroupNorm(32, hidden_dim), 70 | ) 71 | ) 72 | in_channels = hidden_dim 73 | self.input_proj = nn.ModuleList(input_proj_list) 74 | else: 75 | self.input_proj = nn.ModuleList( 76 | [ 77 | nn.Sequential( 78 | nn.Conv2d(backbone.num_channels[0], hidden_dim, kernel_size=1), 79 | nn.GroupNorm(32, hidden_dim), 80 | ) 81 | ] 82 | ) 83 | self.backbone = backbone 84 | self.aux_loss = aux_loss 85 | self.with_box_refine = with_box_refine 86 | self.two_stage = two_stage 87 | 88 | prior_prob = 0.01 89 | bias_value = -math.log((1 - prior_prob) / prior_prob) 90 | self.class_embed.bias.data = torch.ones(cls_out_channels) * bias_value 91 | nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) 92 | nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) 93 | for proj in self.input_proj: 94 | nn.init.xavier_uniform_(proj[0].weight, gain=1) 95 | 
nn.init.constant_(proj[0].bias, 0) 96 | 97 | # if two-stage, the last class_embed and bbox_embed is for region proposal generation 98 | num_pred = ( 99 | (transformer.decoder.num_layers + 1) if two_stage else transformer.decoder.num_layers 100 | ) 101 | if with_box_refine: 102 | self.class_embed = _get_clones(self.class_embed, num_pred) 103 | self.bbox_embed = _get_clones(self.bbox_embed, num_pred) 104 | nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0) 105 | # hack implementation for iterative bounding box refinement 106 | self.transformer.decoder.bbox_embed = self.bbox_embed 107 | else: 108 | nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) 109 | self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)]) 110 | self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)]) 111 | self.transformer.decoder.bbox_embed = None 112 | if two_stage: 113 | # hack implementation for two-stage 114 | self.transformer.decoder.class_embed = self.class_embed 115 | for box_embed in self.bbox_embed: 116 | nn.init.constant_(box_embed.layers[-1].bias.data[2:], 0.0) 117 | 118 | def get_outputs_class(self, layer, data): 119 | return layer(data) 120 | 121 | def forward(self, samples: NestedTensor): 122 | if not isinstance(samples, NestedTensor): 123 | samples = nested_tensor_from_tensor_list(samples) 124 | features, pos = self.backbone(samples) 125 | 126 | srcs = [] 127 | masks = [] 128 | for l, feat in enumerate(features): 129 | src, mask = feat.decompose() 130 | srcs.append(self.input_proj[l](src)) 131 | masks.append(mask) 132 | assert mask is not None 133 | if self.num_feature_levels > len(srcs): 134 | _len_srcs = len(srcs) 135 | for l in range(_len_srcs, self.num_feature_levels): 136 | if l == _len_srcs: 137 | src = self.input_proj[l](features[-1].tensors) 138 | else: 139 | src = self.input_proj[l](srcs[-1]) 140 | m = samples.mask 141 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] 142 | pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) 143 | srcs.append(src) 144 | masks.append(mask) 145 | pos.append(pos_l) 146 | 147 | query_embeds = None 148 | if not self.two_stage: 149 | query_embeds = self.query_embed.weight 150 | ( 151 | hs, 152 | init_reference, 153 | inter_references, 154 | enc_outputs_class, 155 | enc_outputs_coord_unact, 156 | ), _ = self.transformer(srcs, masks, pos, query_embeds) 157 | 158 | outputs_classes = [] 159 | outputs_coords = [] 160 | for lvl in range(hs.shape[0]): 161 | if lvl == 0: 162 | reference = init_reference 163 | else: 164 | reference = inter_references[lvl - 1] 165 | reference = inverse_sigmoid(reference) 166 | outputs_class = self.get_outputs_class(self.class_embed[lvl], hs[lvl]) 167 | tmp = self.bbox_embed[lvl](hs[lvl]) 168 | if reference.shape[-1] == 4: 169 | tmp += reference 170 | else: 171 | assert reference.shape[-1] == 2 172 | tmp[..., :2] += reference 173 | outputs_coord = tmp.sigmoid() 174 | outputs_classes.append(outputs_class) 175 | outputs_coords.append(outputs_coord) 176 | outputs_class = torch.stack(outputs_classes) 177 | outputs_coord = torch.stack(outputs_coords) 178 | out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 179 | if self.aux_loss: 180 | out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) 181 | 182 | if self.two_stage: 183 | enc_outputs_coord = enc_outputs_coord_unact.sigmoid() 184 | out["enc_outputs"] = {"pred_logits": enc_outputs_class, "pred_boxes": enc_outputs_coord} 185 | return out 186 | 187 | 
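# _set_aux_loss packages the intermediate decoder layers' predictions so the criterion can apply the same matching-based losses to every layer (deep supervision); only the last layer's outputs populate pred_logits / pred_boxes above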
@torch.jit.unused 188 | def _set_aux_loss(self, outputs_class, outputs_coord): 189 | # this is a workaround to make torchscript happy, as torchscript 190 | # doesn't support dictionary with non-homogeneous values, such 191 | # as a dict having both a Tensor and a list. 192 | return [ 193 | {"pred_logits": a, "pred_boxes": b} 194 | for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) 195 | ] 196 | 197 | 198 | class OVDETR(DeformableDETR): 199 | def __init__( 200 | self, 201 | backbone, 202 | transformer, 203 | num_classes, 204 | num_queries, 205 | num_feature_levels, 206 | aux_loss=True, 207 | with_box_refine=False, 208 | two_stage=False, 209 | cls_out_channels=2, 210 | dataset_file="coco", 211 | zeroshot_w=None, 212 | max_len=15, 213 | clip_feat_path=None, 214 | prob=0.5, 215 | ): 216 | super().__init__( 217 | backbone, 218 | transformer, 219 | num_classes, 220 | num_queries, 221 | num_feature_levels, 222 | aux_loss, 223 | with_box_refine, 224 | two_stage, 225 | cls_out_channels=1, 226 | ) 227 | self.zeroshot_w = zeroshot_w.t() 228 | 229 | self.patch2query = nn.Linear(512, 256) 230 | self.patch2query_img = nn.Linear(512, 256) 231 | for layer in [self.patch2query]: 232 | nn.init.xavier_uniform_(self.patch2query.weight) 233 | nn.init.constant_(self.patch2query.bias, 0) 234 | 235 | self.feature_align = nn.Linear(256, 512) 236 | nn.init.xavier_uniform_(self.feature_align.weight) 237 | nn.init.constant_(self.feature_align.bias, 0) 238 | 239 | num_pred = transformer.decoder.num_layers 240 | if with_box_refine: 241 | self.feature_align = _get_clones(self.feature_align, num_pred) 242 | else: 243 | self.feature_align = nn.ModuleList([self.feature_align for _ in range(num_pred)]) 244 | 245 | self.all_ids = torch.tensor(range(self.zeroshot_w.shape[-1])) 246 | self.max_len = max_len 247 | self.max_pad_len = max_len - 3 248 | 249 | self.clip_feat = torch.load(clip_feat_path) 250 | self.prob = prob 251 | 252 | def forward(self, samples: NestedTensor, targets=None): 253 | if self.training: 254 | return self.forward_train(samples, targets) 255 | else: 256 | return self.forward_test(samples) 257 | 258 | def forward_train(self, samples: NestedTensor, targets=None): 259 | if not isinstance(samples, NestedTensor): 260 | samples = nested_tensor_from_tensor_list(samples) 261 | features, pos = self.backbone(samples) 262 | 263 | srcs = [] 264 | masks = [] 265 | for l, feat in enumerate(features): 266 | src, mask = feat.decompose() 267 | srcs.append(self.input_proj[l](src)) 268 | masks.append(mask) 269 | assert mask is not None 270 | if self.num_feature_levels > len(srcs): 271 | _len_srcs = len(srcs) 272 | for l in range(_len_srcs, self.num_feature_levels): 273 | if l == _len_srcs: 274 | src = self.input_proj[l](features[-1].tensors) 275 | else: 276 | src = self.input_proj[l](srcs[-1]) 277 | m = samples.mask 278 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]).to(torch.bool)[0] 279 | pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) 280 | srcs.append(src) 281 | masks.append(mask) 282 | pos.append(pos_l) 283 | 284 | uniq_labels = torch.cat([t["labels"] for t in targets]) 285 | uniq_labels = torch.unique(uniq_labels).to("cpu") 286 | uniq_labels = uniq_labels[torch.randperm(len(uniq_labels))][: self.max_len] 287 | select_id = uniq_labels.tolist() 288 | if len(select_id) < self.max_pad_len: 289 | pad_len = self.max_pad_len - len(uniq_labels) 290 | extra_list = torch.tensor([i for i in self.all_ids if i not in uniq_labels]) 291 | extra_labels = 
extra_list[torch.randperm(len(extra_list))][:pad_len] 292 | select_id += extra_labels.tolist() 293 | 294 | text_query = self.zeroshot_w[:, select_id].t() 295 | img_query = [] 296 | for cat_id in select_id: 297 | index = torch.randperm(len(self.clip_feat[cat_id]))[0:1] 298 | img_query.append(self.clip_feat[cat_id][index]) 299 | img_query = torch.cat(img_query).to(text_query.device) 300 | img_query = img_query / img_query.norm(dim=-1, keepdim=True) 301 | 302 | mask = (torch.rand(len(text_query)) < self.prob).float().unsqueeze(1).to(text_query.device) 303 | clip_query_ori = (text_query * mask + img_query * (1 - mask)).detach() 304 | 305 | dtype = self.patch2query.weight.dtype 306 | text_query = self.patch2query(text_query.type(dtype)) 307 | img_query = self.patch2query_img(img_query.type(dtype)) 308 | clip_query = text_query * mask + img_query * (1 - mask) 309 | 310 | query_embeds = None 311 | if not self.two_stage: 312 | query_embeds = self.query_embed.weight 313 | ( 314 | hs, 315 | init_reference, 316 | inter_references, 317 | enc_outputs_class, 318 | enc_outputs_coord_unact, 319 | _, 320 | ), _ = self.transformer(srcs, masks, pos, query_embeds, text_query=clip_query) 321 | 322 | outputs_classes = [] 323 | outputs_coords = [] 324 | outputs_embeds = [] 325 | for lvl in range(hs.shape[0]): 326 | if lvl == 0: 327 | reference = init_reference 328 | else: 329 | reference = inter_references[lvl - 1] 330 | reference = inverse_sigmoid(reference) 331 | outputs_class = self.get_outputs_class(self.class_embed[lvl], hs[lvl]) 332 | tmp = self.bbox_embed[lvl](hs[lvl]) 333 | if reference.shape[-1] == 4: 334 | tmp += reference 335 | else: 336 | assert reference.shape[-1] == 2 337 | tmp[..., :2] += reference 338 | outputs_coord = tmp.sigmoid() 339 | outputs_classes.append(outputs_class) 340 | outputs_coords.append(outputs_coord) 341 | outputs_embeds.append(self.feature_align[lvl](hs[lvl])) 342 | 343 | outputs_class = torch.stack(outputs_classes) 344 | outputs_coord = torch.stack(outputs_coords) 345 | outputs_embed = torch.stack(outputs_embeds) 346 | out = { 347 | "pred_logits": outputs_class[-1], 348 | "pred_boxes": outputs_coord[-1], 349 | "pred_embed": outputs_embed[-1], 350 | "select_id": select_id, 351 | "clip_query": clip_query_ori, 352 | } 353 | if self.aux_loss: 354 | out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) 355 | for temp, embed in zip(out["aux_outputs"], outputs_embed[:-1]): 356 | temp["select_id"] = select_id 357 | temp["pred_embed"] = embed 358 | temp["clip_query"] = clip_query_ori 359 | 360 | if self.two_stage: 361 | enc_outputs_coord = enc_outputs_coord_unact.sigmoid() 362 | out["enc_outputs"] = { 363 | "pred_logits": enc_outputs_class, 364 | "pred_boxes": enc_outputs_coord, 365 | "select_id": select_id, 366 | } 367 | return out 368 | 369 | def forward_test(self, samples: NestedTensor): 370 | if not isinstance(samples, NestedTensor): 371 | samples = nested_tensor_from_tensor_list(samples) 372 | features, pos = self.backbone(samples) 373 | 374 | srcs = [] 375 | masks = [] 376 | for l, feat in enumerate(features): 377 | src, mask = feat.decompose() 378 | srcs.append(self.input_proj[l](src)) 379 | masks.append(mask) 380 | assert mask is not None 381 | if self.num_feature_levels > len(srcs): 382 | _len_srcs = len(srcs) 383 | for l in range(_len_srcs, self.num_feature_levels): 384 | if l == _len_srcs: 385 | src = self.input_proj[l](features[-1].tensors) 386 | else: 387 | src = self.input_proj[l](srcs[-1]) 388 | m = samples.mask 389 | mask = F.interpolate(m[None].float(), 
size=src.shape[-2:]).to(torch.bool)[0] 390 | pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) 391 | srcs.append(src) 392 | masks.append(mask) 393 | pos.append(pos_l) 394 | 395 | select_id = list(range(self.zeroshot_w.shape[-1])) 396 | query_embeds = None 397 | if not self.two_stage: 398 | query_embeds = self.query_embed.weight 399 | 400 | outputs_class_list = [] 401 | outputs_coord_list = [] 402 | num_patch = 15 403 | cache = None 404 | dtype = self.patch2query.weight.dtype 405 | for c in range(len(select_id) // num_patch + 1): 406 | clip_query = self.zeroshot_w[:, c * num_patch : (c + 1) * num_patch].t() 407 | clip_query = self.patch2query(clip_query.type(dtype)) 408 | ( 409 | hs, 410 | init_reference, 411 | inter_references, 412 | enc_outputs_class, 413 | enc_outputs_coord_unact, 414 | cache, 415 | ), _ = self.transformer( 416 | srcs, masks, pos, query_embeds, text_query=clip_query, cache=cache 417 | ) 418 | 419 | outputs_classes = [] 420 | outputs_coords = [] 421 | for lvl in range(hs.shape[0]): 422 | if lvl == 0: 423 | reference = init_reference 424 | else: 425 | reference = inter_references[lvl - 1] 426 | reference = inverse_sigmoid(reference) 427 | outputs_class = self.get_outputs_class(self.class_embed[lvl], hs[lvl]) 428 | tmp = self.bbox_embed[lvl](hs[lvl]) 429 | if reference.shape[-1] == 4: 430 | tmp += reference 431 | else: 432 | assert reference.shape[-1] == 2 433 | tmp[..., :2] += reference 434 | outputs_coord = tmp.sigmoid() 435 | outputs_classes.append(outputs_class) 436 | outputs_coords.append(outputs_coord) 437 | outputs_class = torch.stack(outputs_classes) 438 | outputs_coord = torch.stack(outputs_coords) 439 | outputs_class_list.append(outputs_class) 440 | outputs_coord_list.append(outputs_coord) 441 | outputs_class = torch.cat(outputs_class_list, -2) 442 | outputs_coord = torch.cat(outputs_coord_list, -2) 443 | 444 | out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 445 | out["select_id"] = select_id 446 | if self.aux_loss: 447 | out["aux_outputs"] = self._set_aux_loss(outputs_class, outputs_coord) 448 | for temp in out["aux_outputs"]: 449 | temp["select_id"] = select_id 450 | 451 | if self.two_stage: 452 | enc_outputs_coord = enc_outputs_coord_unact.sigmoid() 453 | out["enc_outputs"] = { 454 | "pred_logits": enc_outputs_class, 455 | "pred_boxes": enc_outputs_coord, 456 | "select_id": select_id, 457 | } 458 | return out 459 | -------------------------------------------------------------------------------- /ovdetr/util/misc.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # OV DETR 3 | # Copyright (c) S-LAB, Nanyang Technological University. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | # Modified from DETR (https://github.com/facebookresearch/detr) 9 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 10 | # ------------------------------------------------------------------------ 11 | 12 | """ 13 | Misc functions, including distributed helpers. 14 | 15 | Mostly copy-paste from torchvision references. 
16 | """ 17 | import datetime 18 | import os 19 | import pickle 20 | import subprocess 21 | import time 22 | from collections import defaultdict, deque 23 | from typing import List, Optional 24 | 25 | import torch 26 | import torch.distributed as dist 27 | 28 | # needed due to empty tensor bug in pytorch and torchvision 0.5 29 | import torchvision 30 | from torch import Tensor 31 | 32 | if float(torchvision.__version__[:3]) < 0.5: 33 | import math 34 | from torchvision.ops.misc import _NewEmptyTensorOp 35 | 36 | def _check_size_scale_factor(dim, size, scale_factor): 37 | # type: (int, Optional[List[int]], Optional[float]) -> None 38 | if size is None and scale_factor is None: 39 | raise ValueError("either size or scale_factor should be defined") 40 | if size is not None and scale_factor is not None: 41 | raise ValueError("only one of size or scale_factor should be defined") 42 | if not (scale_factor is not None and len(scale_factor) != dim): 43 | raise ValueError( 44 | "scale_factor shape must match input shape. " 45 | "Input is {}D, scale_factor size is {}".format(dim, len(scale_factor)) 46 | ) 47 | 48 | def _output_size(dim, input, size, scale_factor): 49 | # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int] 50 | assert dim == 2 51 | _check_size_scale_factor(dim, size, scale_factor) 52 | if size is not None: 53 | return size 54 | # if dim is not 2 or scale_factor is iterable use _ntuple instead of concat 55 | assert scale_factor is not None and isinstance(scale_factor, (int, float)) 56 | scale_factors = [scale_factor, scale_factor] 57 | # math.floor might return float in py2.7 58 | return [int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)] 59 | 60 | 61 | elif float(torchvision.__version__[:3]) < 0.7: 62 | from torchvision.ops import _new_empty_tensor 63 | from torchvision.ops.misc import _output_size 64 | 65 | 66 | class SmoothedValue(object): 67 | """Track a series of values and provide access to smoothed values over a 68 | window or the global series average. 69 | """ 70 | 71 | def __init__(self, window_size=20, fmt=None): 72 | if fmt is None: 73 | fmt = "{median:.4f} ({global_avg:.4f})" 74 | self.deque = deque(maxlen=window_size) 75 | self.total = 0.0 76 | self.count = 0 77 | self.fmt = fmt 78 | 79 | def update(self, value, n=1): 80 | self.deque.append(value) 81 | self.count += n 82 | self.total += value * n 83 | 84 | def synchronize_between_processes(self): 85 | """ 86 | Warning: does not synchronize the deque! 
87 | """ 88 | if not is_dist_avail_and_initialized(): 89 | return 90 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 91 | dist.barrier() 92 | dist.all_reduce(t) 93 | t = t.tolist() 94 | self.count = int(t[0]) 95 | self.total = t[1] 96 | 97 | @property 98 | def median(self): 99 | d = torch.tensor(list(self.deque)) 100 | return d.median().item() 101 | 102 | @property 103 | def avg(self): 104 | d = torch.tensor(list(self.deque), dtype=torch.float32) 105 | return d.mean().item() 106 | 107 | @property 108 | def global_avg(self): 109 | return self.total / self.count 110 | 111 | @property 112 | def max(self): 113 | return max(self.deque) 114 | 115 | @property 116 | def value(self): 117 | return self.deque[-1] 118 | 119 | def __str__(self): 120 | return self.fmt.format( 121 | median=self.median, 122 | avg=self.avg, 123 | global_avg=self.global_avg, 124 | max=self.max, 125 | value=self.value, 126 | ) 127 | 128 | 129 | def all_gather(data): 130 | """ 131 | Run all_gather on arbitrary picklable data (not necessarily tensors) 132 | Args: 133 | data: any picklable object 134 | Returns: 135 | list[data]: list of data gathered from each rank 136 | """ 137 | world_size = get_world_size() 138 | if world_size == 1: 139 | return [data] 140 | 141 | # serialized to a Tensor 142 | buffer = pickle.dumps(data) 143 | storage = torch.ByteStorage.from_buffer(buffer) 144 | tensor = torch.ByteTensor(storage).to("cuda") 145 | 146 | # obtain Tensor size of each rank 147 | local_size = torch.tensor([tensor.numel()], device="cuda") 148 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 149 | dist.all_gather(size_list, local_size) 150 | size_list = [int(size.item()) for size in size_list] 151 | max_size = max(size_list) 152 | 153 | # receiving Tensor from all ranks 154 | # we pad the tensor because torch all_gather does not support 155 | # gathering tensors of different shapes 156 | tensor_list = [] 157 | for _ in size_list: 158 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 159 | if local_size != max_size: 160 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 161 | tensor = torch.cat((tensor, padding), dim=0) 162 | dist.all_gather(tensor_list, tensor) 163 | 164 | data_list = [] 165 | for size, tensor in zip(size_list, tensor_list): 166 | buffer = tensor.cpu().numpy().tobytes()[:size] 167 | data_list.append(pickle.loads(buffer)) 168 | 169 | return data_list 170 | 171 | 172 | def reduce_dict(input_dict, average=True): 173 | """ 174 | Args: 175 | input_dict (dict): all the values will be reduced 176 | average (bool): whether to do average or sum 177 | Reduce the values in the dictionary from all processes so that all processes 178 | have the averaged results. Returns a dict with the same fields as 179 | input_dict, after reduction. 
180 | """ 181 | world_size = get_world_size() 182 | if world_size < 2: 183 | return input_dict 184 | with torch.no_grad(): 185 | names = [] 186 | values = [] 187 | # sort the keys so that they are consistent across processes 188 | for k in sorted(input_dict.keys()): 189 | names.append(k) 190 | values.append(input_dict[k]) 191 | values = torch.stack(values, dim=0) 192 | dist.all_reduce(values) 193 | if average: 194 | values /= world_size 195 | reduced_dict = {k: v for k, v in zip(names, values)} 196 | return reduced_dict 197 | 198 | 199 | class MetricLogger(object): 200 | def __init__(self, delimiter="\t"): 201 | self.meters = defaultdict(SmoothedValue) 202 | self.delimiter = delimiter 203 | 204 | def update(self, **kwargs): 205 | for k, v in kwargs.items(): 206 | if isinstance(v, torch.Tensor): 207 | v = v.item() 208 | assert isinstance(v, (float, int)) 209 | self.meters[k].update(v) 210 | 211 | def __getattr__(self, attr): 212 | if attr in self.meters: 213 | return self.meters[attr] 214 | if attr in self.__dict__: 215 | return self.__dict__[attr] 216 | raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) 217 | 218 | def __str__(self): 219 | loss_str = [] 220 | for name, meter in self.meters.items(): 221 | loss_str.append("{}: {}".format(name, str(meter))) 222 | return self.delimiter.join(loss_str) 223 | 224 | def synchronize_between_processes(self): 225 | for meter in self.meters.values(): 226 | meter.synchronize_between_processes() 227 | 228 | def add_meter(self, name, meter): 229 | self.meters[name] = meter 230 | 231 | def log_every(self, iterable, print_freq, header=None): 232 | i = 0 233 | if not header: 234 | header = "" 235 | start_time = time.time() 236 | end = time.time() 237 | iter_time = SmoothedValue(fmt="{avg:.4f}") 238 | data_time = SmoothedValue(fmt="{avg:.4f}") 239 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 240 | if torch.cuda.is_available(): 241 | log_msg = self.delimiter.join( 242 | [ 243 | header, 244 | "[{0" + space_fmt + "}/{1}]", 245 | "eta: {eta}", 246 | "{meters}", 247 | "time: {time}", 248 | "data: {data}", 249 | "max mem: {memory:.0f}", 250 | ] 251 | ) 252 | else: 253 | log_msg = self.delimiter.join( 254 | [ 255 | header, 256 | "[{0" + space_fmt + "}/{1}]", 257 | "eta: {eta}", 258 | "{meters}", 259 | "time: {time}", 260 | "data: {data}", 261 | ] 262 | ) 263 | MB = 1024.0 * 1024.0 264 | for obj in iterable: 265 | data_time.update(time.time() - end) 266 | yield obj 267 | iter_time.update(time.time() - end) 268 | if i % print_freq == 0 or i == len(iterable) - 1: 269 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 270 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 271 | if torch.cuda.is_available(): 272 | print( 273 | log_msg.format( 274 | i, 275 | len(iterable), 276 | eta=eta_string, 277 | meters=str(self), 278 | time=str(iter_time), 279 | data=str(data_time), 280 | memory=torch.cuda.max_memory_allocated() / MB, 281 | ) 282 | ) 283 | else: 284 | print( 285 | log_msg.format( 286 | i, 287 | len(iterable), 288 | eta=eta_string, 289 | meters=str(self), 290 | time=str(iter_time), 291 | data=str(data_time), 292 | ) 293 | ) 294 | i += 1 295 | end = time.time() 296 | total_time = time.time() - start_time 297 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 298 | print( 299 | "{} Total time: {} ({:.4f} s / it)".format( 300 | header, total_time_str, total_time / len(iterable) 301 | ) 302 | ) 303 | 304 | 305 | def get_sha(): 306 | cwd = os.path.dirname(os.path.abspath(__file__)) 
307 | 308 | def _run(command): 309 | return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() 310 | 311 | sha = "N/A" 312 | diff = "clean" 313 | branch = "N/A" 314 | try: 315 | sha = _run(["git", "rev-parse", "HEAD"]) 316 | subprocess.check_output(["git", "diff"], cwd=cwd) 317 | diff = _run(["git", "diff-index", "HEAD"]) 318 | diff = "has uncommited changes" if diff else "clean" 319 | branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) 320 | except Exception: 321 | pass 322 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 323 | return message 324 | 325 | 326 | def collate_fn(batch): 327 | batch = list(zip(*batch)) 328 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 329 | return tuple(batch) 330 | 331 | 332 | def _max_by_axis(the_list): 333 | # type: (List[List[int]]) -> List[int] 334 | maxes = the_list[0] 335 | for sublist in the_list[1:]: 336 | for index, item in enumerate(sublist): 337 | maxes[index] = max(maxes[index], item) 338 | return maxes 339 | 340 | 341 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 342 | # TODO make this more general 343 | if tensor_list[0].ndim == 3: 344 | # TODO make it support different-sized images 345 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 346 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 347 | batch_shape = [len(tensor_list)] + max_size 348 | b, c, h, w = batch_shape 349 | dtype = tensor_list[0].dtype 350 | device = tensor_list[0].device 351 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 352 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 353 | for img, pad_img, m in zip(tensor_list, tensor, mask): 354 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 355 | m[: img.shape[1], : img.shape[2]] = False 356 | else: 357 | raise ValueError("not supported") 358 | return NestedTensor(tensor, mask) 359 | 360 | 361 | class NestedTensor(object): 362 | def __init__(self, tensors, mask: Optional[Tensor]): 363 | self.tensors = tensors 364 | self.mask = mask 365 | 366 | def to(self, device, non_blocking=False): 367 | # type: (Device) -> NestedTensor # noqa 368 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking) 369 | mask = self.mask 370 | if mask is not None: 371 | assert mask is not None 372 | cast_mask = mask.to(device, non_blocking=non_blocking) 373 | else: 374 | cast_mask = None 375 | return NestedTensor(cast_tensor, cast_mask) 376 | 377 | def record_stream(self, *args, **kwargs): 378 | self.tensors.record_stream(*args, **kwargs) 379 | if self.mask is not None: 380 | self.mask.record_stream(*args, **kwargs) 381 | 382 | def decompose(self): 383 | return self.tensors, self.mask 384 | 385 | def __repr__(self): 386 | return str(self.tensors) 387 | 388 | 389 | def setup_for_distributed(is_master): 390 | """ 391 | This function disables printing when not in master process 392 | """ 393 | import builtins as __builtin__ 394 | 395 | builtin_print = __builtin__.print 396 | 397 | def print(*args, **kwargs): 398 | force = kwargs.pop("force", False) 399 | if is_master or force: 400 | builtin_print(*args, **kwargs) 401 | 402 | __builtin__.print = print 403 | 404 | 405 | def is_dist_avail_and_initialized(): 406 | if not dist.is_available(): 407 | return False 408 | if not dist.is_initialized(): 409 | return False 410 | return True 411 | 412 | 413 | def get_world_size(): 414 | if not is_dist_avail_and_initialized(): 415 | return 1 416 | return dist.get_world_size() 417 | 418 | 419 | def get_rank(): 
420 | if not is_dist_avail_and_initialized(): 421 | return 0 422 | return dist.get_rank() 423 | 424 | 425 | def get_local_size(): 426 | if not is_dist_avail_and_initialized(): 427 | return 1 428 | return int(os.environ["LOCAL_SIZE"]) 429 | 430 | 431 | def get_local_rank(): 432 | if not is_dist_avail_and_initialized(): 433 | return 0 434 | return int(os.environ["LOCAL_RANK"]) 435 | 436 | 437 | def is_main_process(): 438 | return get_rank() == 0 439 | 440 | 441 | def save_on_master(*args, **kwargs): 442 | if is_main_process(): 443 | torch.save(*args, **kwargs) 444 | 445 | 446 | def init_distributed_mode(args): 447 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 448 | args.rank = int(os.environ["RANK"]) 449 | args.world_size = int(os.environ["WORLD_SIZE"]) 450 | args.gpu = int(os.environ["LOCAL_RANK"]) 451 | args.dist_url = "env://" 452 | os.environ["LOCAL_SIZE"] = str(torch.cuda.device_count()) 453 | elif "SLURM_PROCID" in os.environ: 454 | proc_id = int(os.environ["SLURM_PROCID"]) 455 | ntasks = int(os.environ["SLURM_NTASKS"]) 456 | node_list = os.environ["SLURM_NODELIST"] 457 | num_gpus = torch.cuda.device_count() 458 | addr = subprocess.getoutput("scontrol show hostname {} | head -n1".format(node_list)) 459 | os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") 460 | os.environ["MASTER_ADDR"] = addr 461 | os.environ["WORLD_SIZE"] = str(ntasks) 462 | os.environ["RANK"] = str(proc_id) 463 | os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) 464 | os.environ["LOCAL_SIZE"] = str(num_gpus) 465 | args.dist_url = "env://" 466 | args.world_size = ntasks 467 | args.rank = proc_id 468 | args.gpu = proc_id % num_gpus 469 | else: 470 | print("Not using distributed mode") 471 | args.distributed = False 472 | return 473 | 474 | args.distributed = True 475 | 476 | torch.cuda.set_device(args.gpu) 477 | args.dist_backend = "nccl" 478 | print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) 479 | torch.distributed.init_process_group( 480 | backend=args.dist_backend, 481 | init_method=args.dist_url, 482 | world_size=args.world_size, 483 | rank=args.rank, 484 | ) 485 | torch.distributed.barrier() 486 | setup_for_distributed(args.rank == 0) 487 | 488 | 489 | @torch.no_grad() 490 | def accuracy(output, target, topk=(1,)): 491 | """Computes the precision@k for the specified values of k""" 492 | if target.numel() == 0: 493 | return [torch.zeros([], device=output.device)] 494 | maxk = max(topk) 495 | batch_size = target.size(0) 496 | 497 | _, pred = output.topk(maxk, 1, True, True) 498 | pred = pred.t() 499 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 500 | 501 | res = [] 502 | for k in topk: 503 | correct_k = correct[:k].view(-1).float().sum(0) 504 | res.append(correct_k.mul_(100.0 / batch_size)) 505 | return res 506 | 507 | 508 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 509 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 510 | """ 511 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 512 | This will eventually be supported natively by PyTorch, and this 513 | class can go away. 
514 | """ 515 | if float(torchvision.__version__[:3]) < 0.7: 516 | if input.numel() > 0: 517 | return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners) 518 | 519 | output_shape = _output_size(2, input, size, scale_factor) 520 | output_shape = list(input.shape[:-2]) + list(output_shape) 521 | if float(torchvision.__version__[:3]) < 0.5: 522 | return _NewEmptyTensorOp.apply(input, output_shape) 523 | return _new_empty_tensor(input, output_shape) 524 | else: 525 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 526 | 527 | 528 | def get_total_grad_norm(parameters, norm_type=2): 529 | parameters = list(filter(lambda p: p.grad is not None, parameters)) 530 | norm_type = float(norm_type) 531 | device = parameters[0].grad.device 532 | total_norm = torch.norm( 533 | torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), 534 | norm_type, 535 | ) 536 | return total_norm 537 | 538 | 539 | def inverse_sigmoid(x, eps=1e-5): 540 | x = x.clamp(min=0, max=1) 541 | x1 = x.clamp(min=eps) 542 | x2 = (1 - x).clamp(min=eps) 543 | return torch.log(x1 / x2) 544 | --------------------------------------------------------------------------------