├── src ├── __init__.py ├── model │ ├── __init__.py │ ├── ops │ │ ├── make.sh │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn.py │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn_func.py │ │ ├── src │ │ │ ├── vision.cpp │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.h │ │ │ │ └── ms_deform_attn_cpu.cpp │ │ │ └── ms_deform_attn.h │ │ ├── setup.py │ │ └── test.py │ ├── mlp.py │ ├── model_args.py │ ├── postprocess.py │ ├── position_encoding.py │ ├── mil_loss.py │ ├── backbone.py │ ├── deformable_transformer.py │ └── ws_detr.py ├── utils │ ├── __init__.py │ ├── box_ops.py │ └── misc.py ├── coco_tools │ ├── __init__.py │ ├── infer.py │ ├── transforms.py │ └── coco.py ├── args.py └── main.py ├── ws-detr.png ├── .gitignore ├── make.sh ├── Dockerfile ├── cfgs ├── fsod_split1.yaml ├── fsod_split0.yaml ├── fsod_split2.yaml ├── quickstart.yaml ├── inaturalist_superclass.yaml ├── inaturalist.yaml └── fgvc.yaml ├── README.md └── download.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/coco_tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/model/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | python setup.py build install 3 | -------------------------------------------------------------------------------- /ws-detr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tmlabonte/weakly-supervised-DETR/HEAD/ws-detr.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | ckpts 3 | data 4 | imgs 5 | lightning_logs 6 | out 7 | src/model/ops 8 | -------------------------------------------------------------------------------- /make.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | cd src/model/ops 3 | chmod u+x make.sh 4 | sudo -E env PATH=$PATH sh make.sh 5 | -------------------------------------------------------------------------------- /src/model/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /src/model/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.9.0-cuda11.1-cudnn8-devel 2 | 3 | RUN apt-get update 4 | 5 | RUN apt-get -y install libgl1-mesa-glx libsm6 libxext6 ffmpeg gcc sudo wget vim git tmux unzip ninja-build 6 | 7 | RUN pip install azureml-core==1.35.0post1 8 | RUN pip install configargparse==1.5.3 9 | RUN pip install gdown==4.4.0 10 | RUN pip install numpy==1.21.0 11 | RUN pip install opencv-contrib-python==4.5.4.60 12 | RUN pip install pycocotools==2.0.2 13 | RUN pip install pytorch-lightning==1.5.1 14 | RUN pip install scipy==1.7.1 15 | RUN pip install sparsemax==0.1.9 16 | RUN pip install tensorboard==2.7.0 17 | 18 | RUN conda uninstall torchvision 19 | RUN conda install torchvision==0.10.0 -c pytorch 20 | -------------------------------------------------------------------------------- /src/model/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /src/model/mlp.py: -------------------------------------------------------------------------------- 1 | """Defines simple multi-layer perceptron.""" 2 | 3 | # Imports PyTorch packages. 
4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class MLP(nn.Module): 9 | """Defines simple multi-layer perceptron.""" 10 | 11 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 12 | """Initializes MLP with linear layers.""" 13 | 14 | super().__init__() 15 | 16 | self.num_layers = num_layers 17 | h = [hidden_dim] * (num_layers - 1) 18 | 19 | # Initializes desired number and dimensionality of linear layers. 20 | self.layers = nn.ModuleList( 21 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 22 | ) 23 | 24 | def forward(self, x): 25 | """Applies MLP with ReLU activations.""" 26 | 27 | for i, layer in enumerate(self.layers): 28 | # Performs ReLU activation on all except the last layer. 29 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 30 | 31 | return x 32 | 33 | -------------------------------------------------------------------------------- /src/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | """Operations for transforming bounding boxes.""" 2 | 3 | # Import PyTorch packages. 4 | import torch 5 | 6 | 7 | def box_cxcywh_to_xyxy(x): 8 | """Converts boxes from center-x center-y width height to xyxy.""" 9 | 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)] 12 | return torch.stack(b, dim=-1) 13 | 14 | def box_cxcywh_to_xywh(x): 15 | """Converts boxes from center-x center-y width height to xywh.""" 16 | 17 | x_c, y_c, w, h = x.unbind(-1) 18 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), w, h] 19 | return torch.stack(b, dim=-1) 20 | 21 | def box_xyxy_to_cxcywh(x): 22 | """Converts boxes from xyxy to center-x center-y width height.""" 23 | 24 | x0, y0, x1, y1 = x.unbind(-1) 25 | b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)] 26 | return torch.stack(b, dim=-1) 27 | 28 | def box_xyxy_to_xywh(x): 29 | """Converts boxes from xyxy to x y width height.""" 30 | 31 | x0, y0, x1, y1 = x.unbind(-1) 32 | b = ((x0, y0, x1 - x0, y1 - y0)) 33 | return torch.stack(b, dim=-1) 34 | -------------------------------------------------------------------------------- /src/model/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /src/model/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /cfgs/fsod_split1.yaml: -------------------------------------------------------------------------------- 1 | train_imgs_dir: data/fsod/images 2 | train_anns: data/fsod/annotations/fsod_200_train_seed1_trn0.8.json 3 | val_imgs_dir: data/fsod/images 4 | val_anns: data/fsod/annotations/fsod_200_test_seed1_trn0.8.json 5 | test_imgs_dir: data/fsod/images 6 | test_anns: data/fsod/annotations/fsod_200_test_seed1_trn0.8.json 7 | infer_imgs_dir: "" 8 | save_dir: "" 9 | 10 | task: train 11 | batch_size: 2 12 | class_agnostic_weights: ckpts/deformable-detr_fsod-800_class-agnostic_50epochs.pth 13 | classes: 200 14 | dropout: 0.1 15 | joint_probability: True 16 | infer_display_thresh: 0.2 17 | nms_thresh: 0.3 18 | offset: 1 19 | refresh_rate: 1 20 | resume_training: False 21 | resume_weights: False 22 | sampler: False 23 | sparse: True 24 | supervised: False 25 | viz_test_batches: 20 26 | weights: ckpts/deformable-detr_fsod-800_50epochs.pth 27 | workers: 4 28 | 29 | lr_backbone: 2e-5 30 | lr_detr: 2e-4 31 | lr_drop: 0.1 32 | lr_mil: 1e-3 33 | 
lr_patience: 0 34 | lr_step_size: 15 35 | objectness_scale: 1000 36 | weight_decay: 1e-4 37 | 38 | activation: "relu" 39 | dec_layers: 6 40 | dec_points: 4 41 | dilation: False 42 | enc_layers: 6 43 | enc_points: 4 44 | feature_levels: 4 45 | feedforward_dim: 1024 46 | hidden_dim: 256 47 | heads: 8 48 | position_embedding: "sine" 49 | position_embedding_scale: 6.283185307179586 # 2π 50 | queries: 300 51 | 52 | accumulate_grad_batches: 2 53 | deterministic: True 54 | gpus: 4 55 | max_epochs: 30 56 | -------------------------------------------------------------------------------- /cfgs/fsod_split0.yaml: -------------------------------------------------------------------------------- 1 | train_imgs_dir: data/fsod/images 2 | train_anns: data/fsod/annotations/fsod_200_train_seed0_trn0.8.json 3 | val_imgs_dir: data/fsod/images 4 | val_anns: data/fsod/annotations/fsod_200_test_seed0_trn0.8.json 5 | test_imgs_dir: data/fsod/images 6 | test_anns: data/fsod/annotations/fsod_200_test_seed0_trn0.8.json 7 | infer_imgs_dir: "" 8 | save_dir: "" 9 | 10 | task: train 11 | 12 | batch_size: 2 13 | class_agnostic_weights: ckpts/deformable-detr_fsod-800_class-agnostic_50epochs.pth 14 | classes: 200 15 | dropout: 0.1 16 | joint_probability: True 17 | infer_display_thresh: 0.2 18 | nms_thresh: 0.3 19 | offset: 1 20 | refresh_rate: 1 21 | resume_training: False 22 | resume_weights: False 23 | sampler: False 24 | sparse: True 25 | supervised: False 26 | viz_test_batches: 20 27 | weights: ckpts/deformable-detr_fsod-800_50epochs.pth 28 | workers: 4 29 | 30 | lr_backbone: 2e-5 31 | lr_detr: 2e-4 32 | lr_drop: 0.1 33 | lr_mil: 1e-3 34 | lr_patience: 0 35 | lr_step_size: 15 36 | objectness_scale: 1000 37 | weight_decay: 1e-4 38 | 39 | activation: "relu" 40 | dec_layers: 6 41 | dec_points: 4 42 | dilation: False 43 | enc_layers: 6 44 | enc_points: 4 45 | feature_levels: 4 46 | feedforward_dim: 1024 47 | hidden_dim: 256 48 | heads: 8 49 | position_embedding: "sine" 50 | position_embedding_scale: 6.283185307179586 # 2π 51 | queries: 300 52 | 53 | accumulate_grad_batches: 2 54 | deterministic: True 55 | gpus: 4 56 | max_epochs: 30 57 | -------------------------------------------------------------------------------- /cfgs/fsod_split2.yaml: -------------------------------------------------------------------------------- 1 | train_imgs_dir: data/fsod/images 2 | train_anns: data/fsod/annotations/fsod_200_train_seed2_trn0.8.json 3 | val_imgs_dir: data/fsod/images 4 | val_anns: data/fsod/annotations/fsod_200_test_seed2_trn0.8.json 5 | test_imgs_dir: data/fsod/images 6 | test_anns: data/fsod/annotations/fsod_200_test_seed2_trn0.8.json 7 | infer_imgs_dir: "" 8 | save_dir: "" 9 | 10 | task: train 11 | 12 | batch_size: 2 13 | class_agnostic_weights: ckpts/deformable-detr_fsod-800_class-agnostic_50epochs.pth 14 | classes: 200 15 | dropout: 0.1 16 | joint_probability: True 17 | infer_display_thresh: 0.2 18 | nms_thresh: 0.3 19 | offset: 1 20 | refresh_rate: 1 21 | resume_training: False 22 | resume_weights: False 23 | sampler: False 24 | sparse: True 25 | supervised: False 26 | viz_test_batches: 20 27 | weights: ckpts/deformable-detr_fsod-800_50epochs.pth 28 | workers: 4 29 | 30 | lr_backbone: 2e-5 31 | lr_detr: 2e-4 32 | lr_drop: 0.1 33 | lr_mil: 1e-3 34 | lr_patience: 0 35 | lr_step_size: 15 36 | objectness_scale: 1000 37 | weight_decay: 1e-4 38 | 39 | activation: "relu" 40 | dec_layers: 6 41 | dec_points: 4 42 | dilation: False 43 | enc_layers: 6 44 | enc_points: 4 45 | feature_levels: 4 46 | feedforward_dim: 1024 47 | hidden_dim: 
256 48 | heads: 8 49 | position_embedding: "sine" 50 | position_embedding_scale: 6.283185307179586 # 2π 51 | queries: 300 52 | 53 | accumulate_grad_batches: 2 54 | deterministic: True 55 | gpus: 4 56 | max_epochs: 30 57 | -------------------------------------------------------------------------------- /cfgs/quickstart.yaml: -------------------------------------------------------------------------------- 1 | train_imgs_dir: data/fsod/images 2 | train_anns: data/fsod/annotations/fsod_200_train_seed0_trn0.8.json 3 | val_imgs_dir: data/fsod/images 4 | val_anns: data/fsod/annotations/fsod_200_test_seed0_trn0.8.json 5 | test_imgs_dir: data/fsod/images 6 | test_anns: data/fsod/annotations/fsod_200_test_seed0_trn0.8.json 7 | infer_imgs_dir: "" 8 | save_dir: out 9 | 10 | task: train 11 | 12 | batch_size: 2 13 | class_agnostic_weights: ckpts/deformable-detr_fsod-800_class-agnostic_50epochs.pth 14 | classes: 200 15 | dropout: 0.1 16 | joint_probability: True 17 | infer_display_thresh: 0.2 18 | nms_thresh: 0.3 19 | offset: 1 20 | refresh_rate: 1 21 | resume_training: False 22 | resume_weights: False 23 | sampler: False 24 | sparse: True 25 | supervised: False 26 | viz_test_batches: 20 27 | weights: ckpts/deformable-detr_fsod-800_50epochs.pth 28 | workers: 4 29 | 30 | lr_backbone: 2e-5 31 | lr_detr: 2e-4 32 | lr_drop: 0.1 33 | lr_mil: 1e-3 34 | lr_patience: 0 35 | lr_step_size: 15 36 | objectness_scale: 1000 37 | weight_decay: 1e-4 38 | 39 | activation: "relu" 40 | dec_layers: 6 41 | dec_points: 4 42 | dilation: False 43 | enc_layers: 6 44 | enc_points: 4 45 | feature_levels: 4 46 | feedforward_dim: 1024 47 | hidden_dim: 256 48 | heads: 8 49 | position_embedding: "sine" 50 | position_embedding_scale: 6.283185307179586 # 2π 51 | queries: 300 52 | 53 | accumulate_grad_batches: 8 54 | deterministic: True 55 | gpus: 1 56 | max_epochs: 1 57 | -------------------------------------------------------------------------------- /cfgs/inaturalist_superclass.yaml: -------------------------------------------------------------------------------- 1 | train_imgs_dir: data/inaturalist/ 2 | train_anns: data/inaturalist/annotations/train_2017_bboxes_superclass.json 3 | val_imgs_dir: data/inaturalist/ 4 | val_anns: data/inaturalist/annotations/val_2017_bboxes_superclass.json 5 | test_imgs_dir: data/inaturalist/ 6 | test_anns: data/inaturalist/annotations/val_2017_bboxes_superclass.json 7 | infer_imgs_dir: "" 8 | save_dir: "" 9 | 10 | task: train 11 | 12 | batch_size: 2 13 | class_agnostic_weights: ckpts/deformable-detr_fsod-800_class-agnostic_50epochs.pth 14 | classes: 13 15 | dropout: 0.1 16 | joint_probability: True 17 | infer_display_thresh: 0.2 18 | nms_thresh: 0.3 19 | offset: 0 20 | refresh_rate: 1 21 | resume_training: False 22 | resume_weights: False 23 | sampler: False 24 | sparse: True 25 | supervised: False 26 | viz_test_batches: 20 27 | weights: ckpts/deformable-detr_fsod-800_50epochs.pth 28 | workers: 4 29 | 30 | lr_backbone: 2e-5 31 | lr_detr: 2e-4 32 | lr_drop: 0.1 33 | lr_mil: 1e-3 34 | lr_patience: 0 35 | lr_step_size: 6 36 | objectness_scale: 1000 37 | weight_decay: 1e-4 38 | 39 | activation: "relu" 40 | dec_layers: 6 41 | dec_points: 4 42 | dilation: False 43 | enc_layers: 6 44 | enc_points: 4 45 | feature_levels: 4 46 | feedforward_dim: 1024 47 | hidden_dim: 256 48 | heads: 8 49 | position_embedding: "sine" 50 | position_embedding_scale: 6.283185307179586 # 2π 51 | queries: 300 52 | 53 | deterministic: True 54 | gpus: 16 55 | max_epochs: 10 56 | 
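# Note: per the formula in the README, the effective batch size is gpus x batch_size x accumulate_grad_batches,
# so with gpus: 16 and batch_size: 2 (and no gradient accumulation) this config trains at an effective batch size of 32.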
-------------------------------------------------------------------------------- /cfgs/inaturalist.yaml: -------------------------------------------------------------------------------- 1 | train_imgs_dir: data/inaturalist/ 2 | train_anns: data/inaturalist/annotations/train_2017_bboxes_clean.json 3 | val_imgs_dir: data/inaturalist/ 4 | val_anns: data/inaturalist/annotations/val_2017_bboxes_clean.json 5 | test_imgs_dir: data/inaturalist/ 6 | test_anns: data/inaturalist/annotations/val_2017_bboxes_clean.json 7 | infer_imgs_dir: "" 8 | save_dir: "" 9 | 10 | task: train 11 | 12 | batch_size: 2 13 | class_agnostic_weights: ckpts/deformable-detr_fsod-800_class-agnostic_50epochs.pth 14 | classes: 2854 15 | dropout: 0.1 16 | joint_probability: True 17 | infer_display_thresh: 0.2 18 | nms_thresh: 0.3 19 | offset: 0 20 | refresh_rate: 1 21 | resume_training: False 22 | resume_weights: False 23 | sampler: False 24 | sparse: True 25 | supervised: False 26 | viz_test_batches: 20 27 | weights: ckpts/deformable-detr_fsod-800_50epochs.pth 28 | workers: 1 29 | 30 | lr_backbone: 2e-5 31 | lr_detr: 2e-4 32 | lr_drop: 0.1 33 | lr_mil: 1e-3 34 | lr_patience: 0 35 | lr_step_size: 6 36 | objectness_scale: 1000 37 | weight_decay: 1e-4 38 | 39 | activation: "relu" 40 | dec_layers: 6 41 | dec_points: 4 42 | dilation: False 43 | enc_layers: 6 44 | enc_points: 4 45 | feature_levels: 4 46 | feedforward_dim: 1024 47 | hidden_dim: 256 48 | heads: 8 49 | position_embedding: "sine" 50 | position_embedding_scale: 6.283185307179586 # 2π 51 | queries: 300 52 | 53 | accumulate_grad_batches: 2 54 | deterministic: True 55 | gpus: 4 56 | max_epochs: 10 57 | -------------------------------------------------------------------------------- /cfgs/fgvc.yaml: -------------------------------------------------------------------------------- 1 | train_imgs_dir: data/fgvc-aircraft-2013b/images 2 | train_anns: data/fgvc-aircraft-2013b/annotations/trainval.json 3 | val_imgs_dir: data/fgvc-aircraft-2013b/images 4 | val_anns: data/fgvc-aircraft-2013b/annotations/test.json 5 | test_imgs_dir: data/fgvc-aircraft-2013b/images 6 | test_anns: data/fgvc-aircraft-2013b/annotations/test.json 7 | infer_imgs_dir: "" 8 | save_dir: "" 9 | 10 | task: train 11 | 12 | batch_size: 2 13 | class_agnostic_weights: ckpts/deformable-detr_fsod-800_class-agnostic_50epochs.pth 14 | classes: 100 15 | dropout: 0.1 16 | joint_probability: True 17 | infer_display_thresh: 0.2 18 | nms_thresh: 0.3 19 | offset: 0 20 | refresh_rate: 1 21 | resume_training: False 22 | resume_weights: False 23 | sampler: False 24 | sparse: True 25 | supervised: False 26 | viz_test_batches: 20 27 | weights: ckpts/deformable-detr_fsod-800_50epochs.pth 28 | workers: 4 29 | 30 | lr_backbone: 2e-5 31 | lr_detr: 2e-4 32 | lr_drop: 0.1 33 | lr_mil: 1e-3 34 | lr_patience: 0 35 | lr_step_size: 15 36 | objectness_scale: 1000 37 | weight_decay: 1e-4 38 | 39 | activation: "relu" 40 | dec_layers: 6 41 | dec_points: 4 42 | dilation: False 43 | enc_layers: 6 44 | enc_points: 4 45 | feature_levels: 4 46 | feedforward_dim: 1024 47 | hidden_dim: 256 48 | heads: 8 49 | position_embedding: "sine" 50 | position_embedding_scale: 6.283185307179586 # 2 * pi 51 | queries: 300 52 | 53 | accumulate_grad_batches: 2 54 | deterministic: True 55 | gpus: 4 56 | max_epochs: 30 57 | -------------------------------------------------------------------------------- /src/model/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /src/coco_tools/infer.py: -------------------------------------------------------------------------------- 1 | """Simple image loading dataset for inference.""" 2 | 3 | # Imports Python builtins. 4 | import os 5 | import os.path as osp 6 | 7 | # Imports PyTorch packages. 8 | from torch.utils.data import Dataset, DataLoader 9 | 10 | # Imports other packages. 11 | from PIL import Image 12 | 13 | # Imports local packages. 14 | from .coco import get_transform 15 | from utils.misc import nested_collate 16 | 17 | 18 | class SimpleImageDataset(Dataset): 19 | """A simple image loading dataset for inference without ground truth.""" 20 | 21 | def __init__(self, imgs_dir, transform): 22 | self.img_names = sorted(os.listdir(imgs_dir)) 23 | self.img_paths = [ 24 | osp.join(imgs_dir, name) for name in self.img_names 25 | ] 26 | self.orig_sizes = [ 27 | Image.open(img).size[::-1] for img in self.img_paths 28 | ] 29 | self.transform = transform 30 | 31 | def __len__(self): 32 | return len(self.img_names) 33 | 34 | def __getitem__(self, idx): 35 | img_name = self.img_names[idx] 36 | img_path = self.img_paths[idx] 37 | orig_size = self.orig_sizes[idx] 38 | 39 | img = Image.open(img_path).convert("RGB") 40 | img, _ = self.transform(img, None) 41 | 42 | return img, img_name, orig_size 43 | 44 | def infer_loader(args): 45 | """Creates DataLoader for inference without ground truth.""" 46 | 47 | transform = get_transform() 48 | 49 | dataset = SimpleImageDataset(args.infer_dir, transform) 50 | 51 | loader = DataLoader( 52 | dataset, 53 | batch_size=args.batch_size, 54 | collate_fn=nested_collate, 55 | num_workers=args.workers, 56 | pin_memory=True, 57 | ) 58 | 59 | return loader 60 | 61 | -------------------------------------------------------------------------------- /src/model/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /src/model/model_args.py: -------------------------------------------------------------------------------- 1 | """Define model configuration arguments.""" 2 | 3 | def add_model_args(parent_parser): 4 | """Adds model configuration arguments to parser.""" 5 | 6 | parser = parent_parser.add_argument_group("DeformableDETR") 7 | 8 | parser.add( 9 | "--activation", 10 | choices=["relu", "gelu", "glu"], 11 | type=str, 12 | help="Which activation to use in the transformer.", 13 | ) 14 | parser.add( 15 | "--dec_layers", 16 | type=int, 17 | help="How many layers to use in the decoder.", 18 | ) 19 | parser.add( 20 | "--dec_points", 21 | type=int, 22 | help="How many reference points to use in the decoder.", 23 | ) 24 | parser.add( 25 | "--dilation", 26 | action="store_true", 27 | help="Whether to replace stride with dilation in the last conv block.", 28 | ) 29 | parser.add( 30 | "--enc_layers", 31 | type=int, 32 | help="How many layers to use in the encoder.", 33 | ) 34 | parser.add( 35 | "--enc_points", 36 | type=int, 37 | help="How many reference points to use in the encoder.", 38 | ) 39 | parser.add( 40 | "--feature_levels", 41 | type=int, 42 | help="How many feature levels to use in the Transformer.", 43 | ) 44 | parser.add( 45 | "--feedforward_dim", 46 | type=int, 47 | help="Intermediate size of the feedforward layers in the Transformer.", 48 | ) 49 | parser.add( 50 | "--heads", 51 | type=int, 52 | help="How many attention heads to use in the Transformer.", 53 | ) 54 | parser.add( 55 | "--hidden_dim", 56 | type=int, 57 | help="Transformer embedding dimensionality.", 58 | ) 59 | parser.add( 60 | "--position_embedding", 61 | choices=["learned", "sine"], 62 | 
help="Type of positional embedding to use on the image features.", 63 | ) 64 | parser.add( 65 | "--position_embedding_scale", 66 | type=float, 67 | help="Scale of the transformer position embedding.", 68 | ) 69 | parser.add( 70 | "--queries", 71 | type=int, 72 | help="How many object queries to use in the transformer.", 73 | ) 74 | 75 | return parent_parser 76 | 77 | -------------------------------------------------------------------------------- /src/model/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | import numpy 12 | 13 | import torch 14 | 15 | from torch.utils.cpp_extension import CUDA_HOME 16 | from torch.utils.cpp_extension import CppExtension 17 | from torch.utils.cpp_extension import CUDAExtension 18 | 19 | from setuptools import find_packages 20 | from setuptools import setup 21 | 22 | requirements = ["torch", "torchvision"] 23 | 24 | def get_extensions(): 25 | this_dir = os.path.dirname(os.path.abspath(__file__)) 26 | extensions_dir = os.path.join(this_dir, "src") 27 | 28 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 29 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 30 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 31 | 32 | sources = main_file + source_cpu 33 | extension = CppExtension 34 | extra_compile_args = {"cxx": []} 35 | define_macros = [] 36 | 37 | if torch.cuda.is_available() and CUDA_HOME is not None: 38 | extension = CUDAExtension 39 | sources += source_cuda 40 | define_macros += [("WITH_CUDA", None)] 41 | extra_compile_args["nvcc"] = [ 42 | "-DCUDA_HAS_FP16=1", 43 | "-D__CUDA_NO_HALF_OPERATORS__", 44 | "-D__CUDA_NO_HALF_CONVERSIONS__", 45 | "-D__CUDA_NO_HALF2_OPERATORS__", 46 | ] 47 | else: 48 | raise NotImplementedError('Cuda is not availabel') 49 | 50 | sources = [os.path.join(extensions_dir, s) for s in sources] 51 | include_dirs = [extensions_dir] 52 | ext_modules = [ 53 | extension( 54 | "MultiScaleDeformableAttention", 55 | sources, 56 | include_dirs=include_dirs, 57 | define_macros=define_macros, 58 | extra_compile_args=extra_compile_args, 59 | ) 60 | ] 61 | return ext_modules 62 | 63 | setup( 64 | name="MultiScaleDeformableAttention", 65 | version="1.0", 66 | author="Weijie Su", 67 | url="https://github.com/fundamentalvision/Deformable-DETR", 68 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 69 | packages=find_packages(exclude=("configs", "tests",)), 70 | ext_modules=get_extensions(), 71 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 72 | ) 73 | 74 | -------------------------------------------------------------------------------- /src/model/postprocess.py: -------------------------------------------------------------------------------- 1 | """Postprocesses WS-DETR output for validation and inference.""" 2 | 3 | # Imports PyTorch packages. 
4 | import torch 5 | from torchvision.ops import nms 6 | 7 | # Imports local packages. 8 | from utils import box_ops 9 | 10 | # Imports local model packages. 11 | from .mil_loss import mil_score 12 | 13 | 14 | def postprocess( 15 | outputs, 16 | target_sizes, 17 | joint_probability=None, 18 | nms_thresh=None, 19 | offset=0, 20 | sparse=None, 21 | supervised=None, 22 | ): 23 | """Scales and formats model output for validation and inference.""" 24 | 25 | # Gets classification probabilities from supervised model. 26 | if supervised: 27 | prob = outputs["classes_logits"].sigmoid() 28 | # Gets MIL score from weakly supervised model. 29 | # Note: Even if sparsemax is used for training, it is not 30 | # applied during postprocessing (similarly to dropout). 31 | else: 32 | _, prob = mil_score(outputs, joint_probability=joint_probability) 33 | 34 | # Postprocessing from Deformable DETR; sorts boxes and preds by confidence. 35 | # Adds offset at the end (e.g., in case the labels are 1-indexed). 36 | confs, topk_indexes = torch.topk(prob.view(prob.shape[0], -1), 300, dim=1) 37 | topk_boxes = topk_indexes // prob.shape[2] 38 | preds = topk_indexes % prob.shape[2] + offset 39 | 40 | # Converts boxes to ((x1, y1), (x2, y2)) coordinates. 41 | boxes = box_ops.box_cxcywh_to_xyxy(outputs["boxes"]) 42 | boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) 43 | 44 | # Converts output from a tensor to a list. 45 | boxes = list(torch.unbind(boxes)) 46 | confs = list(torch.unbind(confs)) 47 | preds = list(torch.unbind(preds)) 48 | 49 | # Performs non-maximum suppression on the model output. 50 | # Not strictly necessary and sometimes makes model output worse, 51 | # but can be useful for visualization and inference. 52 | if nms_thresh: 53 | it = zip(boxes, confs, preds) 54 | for j, (img_boxes, img_confs, img_preds) in enumerate(it): 55 | inds = nms(img_boxes, img_confs, iou_threshold=nms_thresh) 56 | boxes[j] = img_boxes[inds] 57 | confs[j] = img_confs[inds] 58 | preds[j] = img_preds[inds] 59 | 60 | # Converts from relative [0, 1] to absolute [0, height] coordinates. 61 | img_h, img_w = target_sizes.unbind(1) 62 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 63 | for j, img_boxes in enumerate(boxes): 64 | boxes[j] = img_boxes * scale_fct[j, None, :] 65 | 66 | # Clamps box width and height to image. 67 | x_coords = boxes[j][:, 0::2] 68 | y_coords = boxes[j][:, 1::2] 69 | boxes[j][:, 0::2] = torch.clamp(x_coords, min=0., max=float(img_w[j])) 70 | boxes[j][:, 1::2] = torch.clamp(y_coords, min=0., max=float(img_h[j])) 71 | 72 | # Clamps confs to [0, 1]. 73 | for j, _ in enumerate(confs): 74 | confs[j] = torch.clamp(confs[j], min=0., max=1.) 75 | 76 | # Builds list of output dicts. 77 | results = [ 78 | {"boxes": b, "confs": c, "preds": p} 79 | for b, c, p in zip(boxes, confs, preds) 80 | ] 81 | 82 | return results 83 | 84 | -------------------------------------------------------------------------------- /src/model/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | -------------------------------------------------------------------------------- /src/model/position_encoding.py: -------------------------------------------------------------------------------- 1 | """Various positional encodings for the transformer.""" 2 | 3 | # Imports Python builtins. 4 | import math 5 | 6 | # Imports PyTorch packages. 
7 | import torch 8 | from torch import nn 9 | 10 | # Imports local packages. 11 | from utils.misc import NestedTensor 12 | 13 | 14 | class PositionEmbeddingSine(nn.Module): 15 | """ 16 | This is a more standard version of the position embedding, very similar to the one 17 | used by the "Attention is All You Need" paper, generalized to work on images. 18 | """ 19 | 20 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 21 | super().__init__() 22 | self.num_pos_feats = num_pos_feats 23 | self.temperature = temperature 24 | self.normalize = normalize 25 | if scale is not None and normalize is False: 26 | raise ValueError("normalize should be True if scale is passed") 27 | if scale is None: 28 | scale = 2 * math.pi 29 | self.scale = scale 30 | 31 | def forward(self, tensor_list: NestedTensor): 32 | x = tensor_list.tensors 33 | mask = tensor_list.mask 34 | assert mask is not None 35 | not_mask = ~mask 36 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 37 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 38 | if self.normalize: 39 | eps = 1e-6 40 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 41 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 42 | 43 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 44 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 45 | 46 | pos_x = x_embed[:, :, :, None] / dim_t 47 | pos_y = y_embed[:, :, :, None] / dim_t 48 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 49 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 50 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 51 | return pos 52 | 53 | class PositionEmbeddingLearned(nn.Module): 54 | """Absolute pos embedding, learned.""" 55 | 56 | def __init__(self, num_pos_feats=256): 57 | super().__init__() 58 | self.row_embed = nn.Embedding(50, num_pos_feats) 59 | self.col_embed = nn.Embedding(50, num_pos_feats) 60 | self.reset_parameters() 61 | 62 | def reset_parameters(self): 63 | nn.init.uniform_(self.row_embed.weight) 64 | nn.init.uniform_(self.col_embed.weight) 65 | 66 | def forward(self, tensor_list: NestedTensor): 67 | x = tensor_list.tensors 68 | h, w = x.shape[-2:] 69 | i = torch.arange(w, device=x.device) 70 | j = torch.arange(h, device=x.device) 71 | x_emb = self.col_embed(i) 72 | y_emb = self.row_embed(j) 73 | pos = torch.cat([ 74 | x_emb.unsqueeze(0).repeat(h, 1, 1), 75 | y_emb.unsqueeze(1).repeat(1, w, 1), 76 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 77 | return pos 78 | 79 | def build_position_encoding(args): 80 | N_steps = args.hidden_dim // 2 81 | 82 | if args.position_embedding == "sine": 83 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 84 | elif args.position_embedding == "learned": 85 | position_embedding = PositionEmbeddingLearned(N_steps) 86 | else: 87 | raise ValueError(f"Not supported {args.position_embedding}") 88 | 89 | return position_embedding 90 | 91 | -------------------------------------------------------------------------------- /src/model/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | torch.manual_seed(3) 28 | 29 | 30 | @torch.no_grad() 31 | def check_forward_equal_with_pytorch_double(): 32 | value = torch.rand(N, S, M, D).cuda() * 0.01 33 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 34 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 35 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 36 | im2col_step = 2 37 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 38 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 39 | fwdok = torch.allclose(output_cuda, output_pytorch) 40 | max_abs_err = (output_cuda - output_pytorch).abs().max() 41 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 42 | 43 | print(f"* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}") 44 | 45 | @torch.no_grad() 46 | def check_forward_equal_with_pytorch_float(): 47 | value = torch.rand(N, S, M, D).cuda() * 0.01 48 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 49 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 50 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 51 | im2col_step = 2 52 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 53 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 54 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 55 | max_abs_err = (output_cuda - output_pytorch).abs().max() 56 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 57 | 58 | print(f"* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}") 59 | 60 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 61 | value = torch.rand(N, S, M, channels).cuda() * 0.01 62 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 63 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 64 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 65 | im2col_step = 2 66 | func = MSDeformAttnFunction.apply 67 | 68 | value.requires_grad = grad_value 69 | 
sampling_locations.requires_grad = grad_sampling_loc 70 | attention_weights.requires_grad = grad_attn_weight 71 | 72 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 73 | 74 | print(f"* {gradok} check_gradient_numerical(D={channels})") 75 | 76 | 77 | if __name__ == "__main__": 78 | check_forward_equal_with_pytorch_double() 79 | check_forward_equal_with_pytorch_float() 80 | 81 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 82 | check_gradient_numerical(channels, True, True, True) 83 | 84 | -------------------------------------------------------------------------------- /src/model/mil_loss.py: -------------------------------------------------------------------------------- 1 | """Defines loss for multiple instance learning (MIL).""" 2 | 3 | # Imports PyTorch packages. 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | # Imports other packages. 8 | from sparsemax import Sparsemax 9 | 10 | from utils.misc import inverse_sigmoid 11 | 12 | 13 | def mil_score(outputs, joint_probability=None, sparse=None): 14 | """Computes box-wise MIL score. 15 | 16 | The MIL score is the elementwise product of the detection and 17 | classification softmax scores. We optionally sparsemax over 18 | the detections dimension, which typically increases performance. 19 | """ 20 | 21 | if joint_probability: 22 | dets_logits = inverse_sigmoid(outputs["obj_confs"]) 23 | else: 24 | dets_logits = outputs["dets_logits"] 25 | 26 | # Softmaxes over the classes dimension. 27 | classes_logits = outputs["classes_logits"] 28 | classes = F.softmax(classes_logits, dim=2) 29 | 30 | # Computes detection sigmoid as proxy for detection confidence. 31 | dets_sigmoid = dets_logits.sigmoid() 32 | 33 | # Softmaxes or sparsemaxes over the detections dimension. 34 | if sparse: 35 | dets = Sparsemax(dim=1)(dets_logits) 36 | else: 37 | dets = F.softmax(dets_logits, dim=1) 38 | 39 | if joint_probability: 40 | num_classes = classes.shape[-1] 41 | dets = dets.unsqueeze(-1).repeat(1, 1, num_classes) 42 | 43 | # Computes element-wise product of the two scores. 44 | scores = classes * dets 45 | 46 | return dets_sigmoid, scores 47 | 48 | def mil_label(batch_size, num_classes, targets, offset=0): 49 | """Gets the weak supervision label for MIL. 50 | 51 | There is a 1 in the class slot if there is 52 | at least one instance of that class in the image. 53 | """ 54 | 55 | # Creates empty tensor for MIL labels. 56 | mil_labels = torch.zeros((batch_size, num_classes)) 57 | mil_labels = mil_labels.type_as(targets[0]["boxes"]) 58 | 59 | # Populates MIL label from targets. 60 | for j, img_target in enumerate(targets): 61 | for cls in img_target["image_labels"]: 62 | # Subtracts offset (e.g., if the labels are 1-indexed). 63 | cls -= offset 64 | 65 | # Sets class slot to 1 in the label. 66 | mil_labels[j][cls] = 1 67 | 68 | return mil_labels 69 | 70 | def mil_nll(mil_scores, mil_labels, eps=1e-5): 71 | """Computes negative log-likelihood between MIL scores and MIL labels. 72 | 73 | eps argument prevents loss from becoming NaN. 74 | """ 75 | 76 | # Computes class-wise log-likelihoods. 77 | class_likelihoods = mil_labels * torch.log(mil_scores + eps) \ 78 | + (1 - mil_labels) * torch.log(1 - mil_scores + eps) 79 | 80 | # Computes mean NLL loss across batch. 
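    # As a sketch of the math: for each image, class_likelihoods sums to
    #   sum_c [ y_c * log(s_c + eps) + (1 - y_c) * log(1 - s_c + eps) ],
    # and the loss returned below is the mean of the negated sums over the batch.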
81 | nll = -torch.sum(class_likelihoods, 1) 82 | mil_loss = torch.mean(nll, 0) 83 | 84 | return mil_loss 85 | 86 | def objectness_mse_loss(dets_sigmoid, obj_confs): 87 | """Computes MSE regularization loss on detections and objectness. 88 | 89 | Aligns the maximum detection confidence of the weakly supervised 90 | model with the objectness confidence of the class-agnostic model. 91 | """ 92 | 93 | # Computes maximum detection confidences and MSE loss. 94 | max_det_confs, _ = torch.max(dets_sigmoid, dim=2) 95 | objectness_loss = F.mse_loss(max_det_confs, obj_confs) 96 | 97 | return objectness_loss 98 | 99 | def mil_loss( 100 | outputs, 101 | targets, 102 | joint_probability=None, 103 | objectness_scale=1, 104 | offset=0, 105 | sparse=None, 106 | ): 107 | """Computes MIL score and label, then returns NLL loss with objectness.""" 108 | 109 | batch_size, _, classes = outputs["classes_logits"].shape 110 | mil_labels = mil_label(batch_size, classes, targets, offset=offset) 111 | 112 | dets_sigmoid, scores = mil_score( 113 | outputs, 114 | joint_probability=joint_probability, 115 | sparse=sparse, 116 | ) 117 | mil_scores = torch.sum(scores, 1) 118 | mil_loss = mil_nll(mil_scores, mil_labels) 119 | 120 | if joint_probability: 121 | objectness_loss = 0. * torch.sum(outputs["dets_logits"]) 122 | else: 123 | objectness_loss = objectness_mse_loss( 124 | dets_sigmoid, 125 | outputs["obj_confs"], 126 | ) 127 | objectness_loss *= objectness_scale 128 | 129 | return mil_loss, objectness_loss 130 | 131 | -------------------------------------------------------------------------------- /src/model/backbone.py: -------------------------------------------------------------------------------- 1 | """Defines DETR backbone (here, a ResNet). From Deformable DETR.""" 2 | 3 | # Imports Python builtins. 4 | from collections import OrderedDict 5 | from typing import Dict, List 6 | 7 | # Imports PyTorch packages. 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | import torchvision 12 | from torchvision.models._utils import IntermediateLayerGetter 13 | 14 | # Imports local packages. 15 | from utils.misc import NestedTensor 16 | 17 | # Imports local model packages. 18 | from .position_encoding import build_position_encoding 19 | 20 | 21 | class FrozenBatchNorm2d(torch.nn.Module): 22 | """ 23 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 24 | 25 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 26 | without which any other models than torchvision.models.resnet[18,34,50,101] 27 | produce nans. 
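    Freezing both the running statistics and the affine parameters keeps the
    ImageNet-pretrained normalization fixed during detection fine-tuning, where
    per-GPU batches are too small to estimate reliable statistics.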
28 | """ 29 | 30 | def __init__(self, n, eps=1e-5): 31 | super(FrozenBatchNorm2d, self).__init__() 32 | self.register_buffer("weight", torch.ones(n)) 33 | self.register_buffer("bias", torch.zeros(n)) 34 | self.register_buffer("running_mean", torch.zeros(n)) 35 | self.register_buffer("running_var", torch.ones(n)) 36 | self.eps = eps 37 | 38 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 39 | missing_keys, unexpected_keys, error_msgs): 40 | num_batches_tracked_key = prefix + "num_batches_tracked" 41 | if num_batches_tracked_key in state_dict: 42 | del state_dict[num_batches_tracked_key] 43 | 44 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 45 | state_dict, prefix, local_metadata, strict, 46 | missing_keys, unexpected_keys, error_msgs) 47 | 48 | def forward(self, x): 49 | # move reshapes to the beginning 50 | # to make it fuser-friendly 51 | w = self.weight.reshape(1, -1, 1, 1) 52 | b = self.bias.reshape(1, -1, 1, 1) 53 | rv = self.running_var.reshape(1, -1, 1, 1) 54 | rm = self.running_mean.reshape(1, -1, 1, 1) 55 | eps = self.eps 56 | scale = w * (rv + eps).rsqrt() 57 | bias = b - rm * scale 58 | return x * scale + bias 59 | 60 | 61 | class BackboneBase(nn.Module): 62 | 63 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): 64 | super().__init__() 65 | for name, parameter in backbone.named_parameters(): 66 | if not train_backbone or "layer2" not in name and "layer3" not in name and "layer4" not in name: 67 | parameter.requires_grad_(False) 68 | if return_interm_layers: 69 | return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 70 | self.strides = [8, 16, 32] 71 | self.num_channels = [512, 1024, 2048] 72 | else: 73 | return_layers = {"layer4": "0"} 74 | self.strides = [32] 75 | self.num_channels = [2048] 76 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 77 | 78 | def forward(self, tensor_list: NestedTensor): 79 | xs = self.body(tensor_list.tensors) 80 | out: Dict[str, NestedTensor] = {} 81 | for name, x in xs.items(): 82 | m = tensor_list.mask 83 | assert m is not None 84 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 85 | out[name] = NestedTensor(x, mask) 86 | return out 87 | 88 | 89 | class Backbone(BackboneBase): 90 | """ResNet backbone with frozen BatchNorm.""" 91 | def __init__(self, name: str, 92 | train_backbone: bool, 93 | return_interm_layers: bool, 94 | dilation: bool): 95 | if dilation: 96 | self.strides[-1] = self.strides[-1] // 2 97 | 98 | backbone = getattr(torchvision.models, "resnet50")( 99 | replace_stride_with_dilation=[False, False, dilation], 100 | pretrained=True, norm_layer=FrozenBatchNorm2d) 101 | 102 | super().__init__(backbone, train_backbone, return_interm_layers) 103 | 104 | 105 | class Joiner(nn.Sequential): 106 | def __init__(self, backbone, position_embedding): 107 | super().__init__(backbone, position_embedding) 108 | self.strides = backbone.strides 109 | self.num_channels = backbone.num_channels 110 | 111 | def forward(self, tensor_list: NestedTensor): 112 | xs = self[0](tensor_list) 113 | out: List[NestedTensor] = [] 114 | pos = [] 115 | for name, x in sorted(xs.items()): 116 | out.append(x) 117 | 118 | # position encoding 119 | for x in out: 120 | pos.append(self[1](x).to(x.tensors.dtype)) 121 | 122 | return out, pos 123 | 124 | 125 | def build_backbone(args): 126 | position_embedding = build_position_encoding(args) 127 | train_backbone = args.lr_backbone > 0 128 | return_interm_layers = args.feature_levels > 1 
129 | backbone = Backbone("resnet50", train_backbone, return_interm_layers, args.dilation) 130 | model = Joiner(backbone, position_embedding) 131 | return model 132 | 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Weakly Supervised Detection Transformer (WS-DETR) 2 | 3 | ![WS-DETR Architecture](ws-detr.png) 4 | 5 | Official codebase for the paper: [Scaling Novel Object Detection with Weakly Supervised Detection Transformers.](https://arxiv.org/abs/2207.05205) 6 | 7 | WS-DETR leverages large-scale fully supervised pretraining to detect hundreds of novel classes using only image-level classification labels. 8 | 9 | ### Setup Instructions 10 | First, you will need to [install Docker](https://docs.docker.com/engine/install/ubuntu/) if it is not already available on your machine. To download the WS-DETR Docker image and build the deformable attention modules, use: 11 | 12 | ``` 13 | git clone https://github.com/tmlabonte/weakly-supervised-DETR 14 | cd weakly-supervised-DETR 15 | sudo docker run -it --gpus all --privileged --shm-size 32g -v $(pwd):/local tmlabonte/ws-detr:latest 16 | cd /local 17 | pip install --upgrade --no-cache-dir gdown 18 | sh make.sh 19 | ``` 20 | 21 | To download the FGVC-Aircraft and FSOD datasets, use the following command. You can also download the iNaturalist 2017 dataset, but it is quite large, so we suggest starting with FGVC-Aircraft and FSOD. 22 | 23 | `python download.py --datasets fgvc fsod` 24 | 25 | To download our FSOD-800 pretrained Deformable DETR checkpoints, use: 26 | 27 | `gdown https://drive.google.com/drive/folders/1ZJIElm5A7TaZtIvWjaqnNet2llDQVwGq -O ckpts --folder` 28 | 29 | ### Quick Start 30 | To train WS-DETR on the FSOD-200 novel classes for 1 epoch on a single 16GB GPU, use the following command. Note that WS-DETR Full (with our joint probability estimation and sparsity techniques) is enabled by default. 31 | 32 | `python src/main.py -c cfgs/quickstart.yaml` 33 | 34 | The model will automatically visualize the output and save it to `out/imgs`. 35 | 36 | ### Configs 37 | Config files are located in the `cfgs/` directory. The pre-made configs correspond to experiments from the paper. To run a new experiment, you can make a new config or just use command line arguments: 38 | 39 | `python src/main.py -c cfgs/fsod_split0.yaml --batch_size 1` 40 | 41 | All PyTorch Lightning 1.5 [Trainer options](https://pytorch-lightning.readthedocs.io/en/1.5.1/common/trainer.html#trainer-flags) are valid config variables. For an explanation of what each config variable does, use `python src/main.py -h`. 42 | 43 | ### Testing and Inference 44 | To test on a labeled dataset, set the `test` task. To predict on a directory of images, set the `infer` task. For example, 45 | 46 | `python src/main.py -c cfgs/quickstart.yaml --task test --weights out/version_0/checkpoints/last.ckpt` 47 | 48 | You can also test with a class-agnostic model (e.g., to visualize the boxes before training) as follows: 49 | 50 | `python src/main.py -c cfgs/quickstart.yaml --task test --weights ckpts/deformable-detr_fsod-800_class-agnostic_50epochs.pth --supervised --classes 2` 51 | 52 | Note that the labeled datasets given have no infer directory. 53 | 54 | ### Multi-GPU Training 55 | To perform multi-GPU training, simply set the `gpus` argument: 56 | 57 | `python src/main.py -c cfgs/fsod_split0.yaml --gpus 8` 58 | 59 | Note that `batch_size` is per-GPU. 
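For example, the default `batch_size` of 2 on 8 GPUs processes 8 x 2 = 16 images per optimizer step.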
If training on fewer than 8 GPUs, set the `accumulate_grad_batches` option to increase the effective batch size: 60 | 61 | `python src/main.py -c cfgs/fsod_split0.yaml --gpus 2 --accumulate_grad_batches 4` 62 | 63 | The effective batch size is `gpus` x `batch_size` x `accumulate_grad_batches`. We use a default batch size of 2 per GPU (for 16GB GPUs) and an effective batch size of 16. We have found that a batch size of 32 also works well. 64 | 65 | ### Adding Your Dataset 66 | Training WS-DETR on your own dataset is simple. First, you will need a [COCO-style annotation file](https://cocodataset.org/#format-data). If you are labeling your own dataset, you can add a `classes` field to each image that contains a list of the category IDs present in the image, or you can make a box annotation for each category ID and our code will convert it for you. Second, make a new config file for your experiment following the examples in the `cfgs/` directory. Remember to set the image directories and annotation locations at the top, as well as the `classes` field. If your category IDs are 1-indexed, set `offset: 1`. To use the `ReduceLROnPlateau` scheduler instead of `StepLR`, set `lr_patience`. Finally, to run your experiment, use the command: 67 | 68 | `python src/main.py -c cfgs/my_config.yaml` 69 | 70 | ### Suppressing Caffe2 Warning 71 | When training, especially with multiple GPUs and workers, you may see this warning: 72 | 73 | `[W pthreadpool-cpp.cc:88] Warning: Leaking Caffe2 thread-pool after fork. (function pthreadpool)` 74 | 75 | This is a harmless warning caused by our version of PyTorch. To suppress it, simply append `2>/dev/null` to your command to send `stderr` to `/dev/null`: 76 | 77 | `python src/main.py -c cfgs/fsod_split0.yaml 2>/dev/null` 78 | 79 | Note that this will also send any legitimate error messages to `/dev/null`. For this reason, we recommend debugging on a single GPU with few workers. 80 | 81 | ### Note on Nondeterminism 82 | Due to the use of `atomicAdd` CUDA operations in the deformable attention module from Deformable DETR, training is inherently nondeterministic even with all seeds set. Therefore, it is unlikely that one can exactly reproduce the results reported in the paper. However, training with a large batch size and proper learning rate decay mitigates most of this nondeterminism. 83 | 84 | ### Previous Works 85 | Original [DETR code](https://github.com/facebookresearch/detr) by Facebook. 86 | 87 | Original [Deformable DETR code](https://github.com/fundamentalvision/Deformable-DETR) by SenseTime. 88 | 89 | We use the [FSOD](https://arxiv.org/abs/1908.01998), [FGVC-Aircraft](https://arxiv.org/abs/1306.5151), and [iNaturalist 2017](https://arxiv.org/abs/1707.06642) datasets. We also use the [sparsemax](https://arxiv.org/abs/1602.02068) function. 90 | 91 | From a WSOD perspective, our work builds most heavily on [Uijlings et al. 2018](https://arxiv.org/abs/1708.06128) and [Zhong et al. 2020](https://arxiv.org/abs/2007.07986). Our MIL classifier is based on [WSDDN](https://arxiv.org/abs/1511.02853). Check them out! 92 | -------------------------------------------------------------------------------- /src/model/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, dim=256, feature_levels=4, heads=8, points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | """ 35 | super().__init__() 36 | if dim % heads != 0: 37 | raise ValueError('dim must be divisible by heads, but got {} and {}'.format(dim, heads)) 38 | _d_per_head = dim // heads 39 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 40 | if not _is_power_of_2(_d_per_head): 41 | warnings.warn("You'd better set dim in MSDeformAttn to make the dimension of each attention head a power of 2 " 42 | "which is more efficient in our CUDA implementation.") 43 | 44 | self.im2col_step = 64 45 | 46 | self.dim = dim 47 | self.feature_levels = feature_levels 48 | self.heads = heads 49 | self.points = points 50 | 51 | self.sampling_offsets = nn.Linear(dim, heads * feature_levels * points * 2) 52 | self.attention_weights = nn.Linear(dim, heads * feature_levels * points) 53 | self.value_proj = nn.Linear(dim, dim) 54 | self.output_proj = nn.Linear(dim, dim) 55 | 56 | self._reset_parameters() 57 | 58 | def _reset_parameters(self): 59 | constant_(self.sampling_offsets.weight.data, 0.) 60 | thetas = torch.arange(self.heads, dtype=torch.float32) * (2.0 * math.pi / self.heads) 61 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 62 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.heads, 1, 1, 2).repeat(1, self.feature_levels, self.points, 1) 63 | for i in range(self.points): 64 | grid_init[:, :, i, :] *= i + 1 65 | with torch.no_grad(): 66 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 67 | constant_(self.attention_weights.weight.data, 0.) 68 | constant_(self.attention_weights.bias.data, 0.) 69 | xavier_uniform_(self.value_proj.weight.data) 70 | constant_(self.value_proj.bias.data, 0.) 71 | xavier_uniform_(self.output_proj.weight.data) 72 | constant_(self.output_proj.bias.data, 0.) 
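    # Note on the initialization above: each attention head's offset bias
    # starts as a distinct direction around a circle (via thetas), scaled so
    # its largest coordinate has magnitude 1 and then multiplied by
    # (point index + 1), so successive sampling points begin progressively
    # farther from the reference point. Since the offset weights are zeroed,
    # these biases are the initial sampling offsets.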
73 | 74 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 75 | """ 76 | :param query (N, Length_{query}, C) 77 | :param reference_points (N, Length_{query}, feature_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 78 | or (N, Length_{query}, feature_levels, 4), add additional (w, h) to form reference boxes 79 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 80 | :param input_spatial_shapes (feature_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 81 | :param input_level_start_index (feature_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 82 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 83 | 84 | :return output (N, Length_{query}, C) 85 | """ 86 | N, Len_q, _ = query.shape 87 | N, Len_in, _ = input_flatten.shape 88 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 89 | 90 | value = self.value_proj(input_flatten) 91 | if input_padding_mask is not None: 92 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 93 | value = value.view(N, Len_in, self.heads, self.dim // self.heads) 94 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.heads, self.feature_levels, self.points, 2) 95 | attention_weights = self.attention_weights(query).view(N, Len_q, self.heads, self.feature_levels * self.points) 96 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.heads, self.feature_levels, self.points) 97 | # N, Len_q, heads, feature_levels, points, 2 98 | if reference_points.shape[-1] == 2: 99 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 100 | sampling_locations = reference_points[:, :, None, :, None, :] \ 101 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 102 | elif reference_points.shape[-1] == 4: 103 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 104 | + sampling_offsets / self.points * reference_points[:, :, None, :, None, 2:] * 0.5 105 | else: 106 | raise ValueError( 107 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 108 | output = MSDeformAttnFunction.apply( 109 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 110 | output = self.output_proj(output) 111 | return output 112 | -------------------------------------------------------------------------------- /src/args.py: -------------------------------------------------------------------------------- 1 | """Sets command line and config file arguments.""" 2 | 3 | # Imports other packages. 4 | from configargparse import Parser 5 | from pytorch_lightning import Trainer 6 | 7 | # Imports local packages. 8 | from model.ws_detr import WS_DETR 9 | 10 | 11 | def parse_args(): 12 | """Parses command line and config file arguments.""" 13 | 14 | # Instantiates config arg parser with required config file. 15 | parser = Parser( 16 | args_for_setting_config_path=["-c", "--cfg", "--config"], 17 | config_arg_is_required=True, 18 | ) 19 | 20 | # Adds command line, Trainer, and model arguments. 
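    # Trainer flags come from PyTorch Lightning via add_argparse_args, and
    # model hyperparameters are registered by WS_DETR; any of them can be
    # set in the config file or overridden on the command line.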
21 | parser = add_input_args(parser) 22 | parser = Trainer.add_argparse_args(parser) 23 | parser = WS_DETR.add_model_specific_args(parser) 24 | 25 | args = parser.parse_args() 26 | 27 | return args 28 | 29 | def add_input_args(parser): 30 | """Adds arguments not handled by Trainer or model.""" 31 | 32 | parser.add( 33 | "--train_imgs_dir", 34 | help="Training images directory.", 35 | ) 36 | parser.add( 37 | "--train_anns", 38 | help="Training labels formatted as COCO json.", 39 | ) 40 | parser.add( 41 | "--val_imgs_dir", 42 | help="Validation images directory.", 43 | ) 44 | parser.add( 45 | "--val_anns", 46 | help="Validation labels formatted as COCO json.", 47 | ) 48 | parser.add( 49 | "--test_imgs_dir", 50 | help="Testing images directory.", 51 | ) 52 | parser.add( 53 | "--test_anns", 54 | help="Testing labels formatted as COCO json.", 55 | ) 56 | parser.add( 57 | "--infer_imgs_dir", 58 | help="Images directory for performing inference.", 59 | ) 60 | parser.add( 61 | "--save_dir", 62 | default="out", 63 | help="Directory to save images and checkpoints; overrides PT dir.", 64 | ) 65 | 66 | parser.add( 67 | "--task", 68 | choices=["train", "test", "infer"], 69 | help="Mode to run the model in.", 70 | ) 71 | 72 | parser.add( 73 | "--batch_size", 74 | type=int, 75 | help="Number of images per batch.", 76 | ) 77 | parser.add( 78 | "--class_agnostic_weights", 79 | help="Filepath of class-agnostic model weights.", 80 | ) 81 | parser.add( 82 | "--classes", 83 | type=int, 84 | help="Number of classes in the dataset.", 85 | ) 86 | parser.add( 87 | "--dropout", 88 | type=float, 89 | help="Dropout probability in the Transformer.", 90 | ) 91 | parser.add( 92 | "--joint_probability", 93 | action="store_true", 94 | help=( 95 | "Whether to use our joint probability technique instead of" 96 | " learning the detection branch in the MIL classifier." 97 | ), 98 | ) 99 | parser.add( 100 | "--infer_display_thresh", 101 | type=float, 102 | help="Confidence threshold to display images during inference.", 103 | ) 104 | parser.add( 105 | "--nms_thresh", 106 | type=float, 107 | help="IoU threshold for non-maximum suppression (0 for no NMS).", 108 | ) 109 | parser.add( 110 | "--offset", 111 | type=int, 112 | help="Offset of image label indices.", 113 | ) 114 | parser.add( 115 | "--refresh_rate", 116 | type=int, 117 | help="Batch interval for updating training progress bar.", 118 | ) 119 | parser.add( 120 | "--resume_training", 121 | action="store_true", 122 | help="Whether to resume training using the PL Trainer.", 123 | ) 124 | parser.add( 125 | "--resume_weights", 126 | action="store_true", 127 | help="Whether to load all possible model weights from checkpoint.", 128 | ) 129 | parser.add( 130 | "--sampler", 131 | action="store_true", 132 | help="Whether to use a balanced random sampler in DataLoader.", 133 | ) 134 | parser.add( 135 | "--sparse", 136 | action="store_true", 137 | help=( 138 | "Whether to use sparsemax instead of softmax in the MIL head" 139 | " across the detections dimension." 140 | ), 141 | ) 142 | parser.add( 143 | "--supervised", 144 | action="store_true", 145 | help=( 146 | "Whether a fully-supervised model is being used for testing or" 147 | " inference (e.g., when visualizing class-agnostic boxes)." 148 | ), 149 | ) 150 | parser.add( 151 | "--viz_test_batches", 152 | type=int, 153 | help=("How many batches to visualize with prediction and" 154 | " ground-truth boxes during validation and test steps." 
155 | ), 156 | ) 157 | parser.add( 158 | "--weights", 159 | help="Filepath of model weights.", 160 | ) 161 | parser.add( 162 | "--workers", 163 | type=int, 164 | help="Number of workers in DataLoader.", 165 | ) 166 | 167 | parser.add( 168 | "--lr_backbone", 169 | type=float, 170 | help="Learning rate for backbone and input projection.", 171 | ) 172 | parser.add( 173 | "--lr_detr", 174 | type=float, 175 | help="Learning rate for DETR modules.", 176 | ) 177 | parser.add( 178 | "--lr_drop", 179 | type=float, 180 | help=("Factor by which to drop the learning rate every" 181 | " lr_patience epochs with no loss improvement." 182 | ), 183 | ) 184 | parser.add( 185 | "--lr_mil", 186 | type=float, 187 | help="Learning rate for MIL head." 188 | ) 189 | parser.add( 190 | "--lr_patience", 191 | type=int, 192 | help=("How many epochs with no loss improvement" 193 | " after which the learning rate will drop." 194 | ), 195 | ) 196 | parser.add( 197 | "--lr_step_size", 198 | type=int, 199 | help="How many epochs to run before dropping the learning rate.", 200 | ) 201 | parser.add( 202 | "--objectness_scale", 203 | type=float, 204 | help="Scaling term for objectness regularization.", 205 | ) 206 | parser.add( 207 | "--weight_decay", 208 | type=float, 209 | help="Weight decay factor.", 210 | ) 211 | 212 | return parser 213 | 214 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | """Main script for training, validation, and inference.""" 2 | 3 | # Imports Python builtins. 4 | from copy import deepcopy 5 | import os 6 | import os.path as osp 7 | import resource 8 | 9 | # Imports PyTorch packages. 10 | import torch 11 | 12 | # Imports other packages. 13 | from configargparse import Parser 14 | from PIL import ImageFile 15 | from pytorch_lightning import Trainer 16 | from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint 17 | from pytorch_lightning.callbacks.progress.tqdm_progress import TQDMProgressBar 18 | from pytorch_lightning.utilities.seed import seed_everything 19 | 20 | # Imports local packages. 21 | from args import parse_args 22 | from coco_tools.coco import coco_loader 23 | from coco_tools.infer import infer_loader 24 | from model.ws_detr import WS_DETR 25 | from utils.misc import get_state_dict_from_checkpoint 26 | 27 | # Prevents PIL from throwing invalid error on large image files. 28 | ImageFile.LOAD_TRUNCATED_IMAGES = True 29 | 30 | # Prevents DataLoader memory error. 31 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 32 | resource.setrlimit(resource.RLIMIT_NOFILE, (8192, rlimit[1])) 33 | 34 | 35 | def load_class_agnostic_detector(args): 36 | """Loads class agnostic detector and freezes parameters.""" 37 | 38 | agnostic_args = deepcopy(args) 39 | agnostic_args.classes = 2 40 | class_agnostic_detector = WS_DETR(agnostic_args) 41 | 42 | checkpoint = torch.load(args.class_agnostic_weights, map_location="cpu") 43 | state_dict = get_state_dict_from_checkpoint(checkpoint) 44 | class_agnostic_detector.load_state_dict(state_dict, strict=False) 45 | print( 46 | f"Class-agnostic detector loaded from {args.class_agnostic_weights}." 
47 | ) 48 | 49 | class_agnostic_detector.eval() 50 | for p in class_agnostic_detector.parameters(): 51 | p.requires_grad = False 52 | 53 | return class_agnostic_detector 54 | 55 | def load_model(args, coco_groundtruth, class_names=None): 56 | """Loads WS-DETR model and optionally loads weights.""" 57 | 58 | # Loads class-agnostic detector during training. 59 | # Otherwise, it is saved in the WS-DETR weights. 60 | class_agnostic_detector = None 61 | if args.task == "train" or args.supervised: 62 | class_agnostic_detector = load_class_agnostic_detector(args) 63 | 64 | # Instantiates WS-DETR model. 65 | model = WS_DETR( 66 | args, 67 | class_agnostic_detector=class_agnostic_detector, 68 | class_names=class_names, 69 | coco_groundtruth=coco_groundtruth, 70 | ) 71 | 72 | # Loads model weights. 73 | if args.weights: 74 | checkpoint = torch.load(args.weights, map_location="cpu") 75 | state_dict = get_state_dict_from_checkpoint(checkpoint) 76 | 77 | if args.resume_training and args.weights.endswith("ckpt"): 78 | args.ckpt_path = args.weights 79 | print(f"Resuming training state from {args.weights}.") 80 | elif args.resume_weights or args.task in ("test", "infer"): 81 | model.load_state_dict(state_dict, strict=False) 82 | print(f"Weights loaded from {args.weights}.") 83 | else: 84 | # Drops MIL head from checkpoint. 85 | state_dict = { 86 | k: v for k, v in state_dict.items() 87 | if "class_embed" not in k 88 | and "det_embed" not in k 89 | } 90 | 91 | model.load_state_dict(state_dict, strict=False) 92 | print(f"Weights loaded from {args.weights}.") 93 | 94 | return model 95 | 96 | def load_trainer(args): 97 | """Loads PyTorch Lightning Trainer with callbacks.""" 98 | 99 | # Instantiates checkpointer to save model 100 | # when a new best mAP is reached. 101 | checkpointer = ModelCheckpoint( 102 | filename="{epoch}-{mAP:.2f}", 103 | mode="max", 104 | monitor="mAP", 105 | save_last=True, 106 | ) 107 | 108 | # Instantiates progress bar. Changing refresh rate is useful when 109 | # stdout goes to a logfile (e.g., on cluster). 1 is normal and 0 disables. 110 | progress_bar = TQDMProgressBar(refresh_rate=args.refresh_rate) 111 | 112 | # Sets DDP strategy for multi-GPU training. 113 | args.strategy = "ddp" if args.gpus > 1 else None 114 | 115 | # Instantiates PL Trainer using args. 116 | callbacks = [checkpointer, progress_bar] 117 | trainer = Trainer.from_argparse_args(args, callbacks=callbacks) 118 | 119 | return trainer 120 | 121 | def main(args): 122 | """Trains, tests, or infers with model as specified by args.""" 123 | 124 | # Sets global seed for reproducibility. 125 | # Note: Due to CUDA operations which cannot be made deterministic, 126 | # the code will still not be perfectly reproducible. 127 | seed_everything(seed=42, workers=True) 128 | 129 | # Sets output directory. 130 | if "PT_OUTPUT_DIR" in os.environ: 131 | args.default_root_dir = os.environ["PT_OUTPUT_DIR"] 132 | elif args.save_dir: 133 | args.default_root_dir = args.save_dir 134 | else: 135 | args.default_root_dir = os.getcwd() 136 | 137 | # Sets outputs directory for inference images. 138 | args.imgs_dir = osp.join(args.default_root_dir, "imgs") 139 | os.makedirs(args.imgs_dir, exist_ok=True) 140 | 141 | # Instantiates COCO dataloaders and ground truth. 
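    # Training and testing evaluate mAP against the val/test split's COCO
    # annotations, so that split's ground truth is kept; inference runs on
    # an unlabeled directory, so there is no ground truth.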
142 | if args.task == "train": 143 | train_loader = coco_loader(args, task="train") 144 | val_loader = coco_loader(args, task="val") 145 | coco_groundtruth = val_loader.dataset.coco 146 | elif args.task == "test": 147 | val_loader = coco_loader(args, task="test") 148 | coco_groundtruth = val_loader.dataset.coco 149 | elif args.task == "infer": 150 | val_loader = infer_loader(args) 151 | coco_groundtruth = None 152 | 153 | class_names = None 154 | if args.task != "infer": 155 | cats = list(val_loader.dataset.coco.cats.values()) 156 | cats = sorted(cats, key=lambda x: x["id"]) 157 | class_names = [c["name"] for c in cats] 158 | 159 | model = load_model(args, coco_groundtruth, class_names=class_names) 160 | trainer = load_trainer(args) 161 | 162 | if args.task == "train": 163 | trainer.fit(model, train_loader, val_loader) 164 | elif args.task == "test": 165 | trainer.test(model, val_loader) 166 | elif args.task == "infer": 167 | trainer.predict(model.eval(), val_loader) 168 | 169 | 170 | if __name__ == "__main__": 171 | args = parse_args() 172 | 173 | main(args) 174 | 175 | -------------------------------------------------------------------------------- /src/model/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include <vector> 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include <ATen/ATen.h> 15 | #include <ATen/cuda/CUDAContext.h> 16 | #include <cuda.h> 17 | #include <cuda_runtime.h> 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch,
im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data<scalar_t>() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data<int64_t>(), 68 | level_start_index.data<int64_t>(), 69 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data<scalar_t>()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query *
num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data<scalar_t>(), 137 | value.data<scalar_t>() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data<int64_t>(), 139 | level_start_index.data<int64_t>(), 140 | sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /src/coco_tools/transforms.py: -------------------------------------------------------------------------------- 1 | """Transforms that work on both bboxes and targets; from Deformable DETR.""" 2 | 3 | # Imports Python builtins. 4 | import random 5 | 6 | # Imports PyTorch packages. 7 | import torch 8 | import torchvision.transforms as T 9 | import torchvision.transforms.functional as F 10 | 11 | # Imports other packages. 12 | import PIL 13 | 14 | # Imports local packages. 15 | from utils.box_ops import box_xyxy_to_cxcywh 16 | 17 | 18 | import numpy as np 19 | 20 | def crop(image, target, region): 21 | cropped_image = F.crop(image, *region) 22 | 23 | if target is None: 24 | return cropped_image, None 25 | 26 | target = target.copy() 27 | i, j, h, w = region 28 | 29 | target["size"] = torch.tensor([h, w]) 30 | 31 | fields = ["labels", "area", "iscrowd"] 32 | 33 | if "boxes" in target: 34 | boxes = target["boxes"] 35 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 36 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 37 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 38 | cropped_boxes = cropped_boxes.clamp(min=0) 39 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 40 | target["boxes"] = cropped_boxes.reshape(-1, 4) 41 | target["area"] = area 42 | fields.append("boxes") 43 | 44 | # remove elements for which the boxes have zero area 45 | if "boxes" in target: 46 | # favor boxes selection when defining which elements to keep 47 | # this is compatible with previous implementation 48 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 49 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 50 | 51 | for field in fields: 52 | target[field] = target[field][keep] 53 | 54 | return cropped_image, target 55 | 56 | def hflip(image, target): 57 | flipped_image = F.hflip(image) 58 | 59 | if target is None: 60 | return flipped_image, None 61 | 62 | w, h = image.size 63 | 64 | target = target.copy() 65 | if "boxes" in target: 66 | boxes = target["boxes"] 67 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 68 | target["boxes"] = boxes 69 | 70 | return flipped_image, target 71 | 72 | def resize(image, target, size, max_size=None):
73 | # size can be min_size (scalar) or (w, h) tuple 74 | 75 | rescaled_image = F.resize(image, size, max_size=max_size) 76 | 77 | if target is None: 78 | return rescaled_image, None 79 | 80 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 81 | ratio_width, ratio_height = ratios 82 | 83 | target = target.copy() 84 | if "boxes" in target: 85 | boxes = target["boxes"] 86 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 87 | target["boxes"] = scaled_boxes 88 | 89 | if "area" in target: 90 | area = target["area"] 91 | scaled_area = area * (ratio_width * ratio_height) 92 | target["area"] = scaled_area 93 | 94 | h, w = rescaled_image.size[::-1] 95 | target["size"] = torch.tensor([h, w]) 96 | 97 | return rescaled_image, target 98 | 99 | def pad(image, target, padding): 100 | # assumes that we only pad on the bottom right corners 101 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 102 | if target is None: 103 | return padded_image, None 104 | target = target.copy() 105 | # should we do something wrt the original size? 106 | target["size"] = torch.tensor(padded_image[::-1]) 107 | return padded_image, target 108 | 109 | class RandomCrop(object): 110 | def __init__(self, size): 111 | self.size = size 112 | 113 | def __call__(self, img, target): 114 | region = T.RandomCrop.get_params(img, self.size) 115 | return crop(img, target, region) 116 | 117 | class RandomSizeCrop(object): 118 | def __init__(self, min_size: int, max_size: int): 119 | self.min_size = min_size 120 | self.max_size = max_size 121 | 122 | def __call__(self, img: PIL.Image.Image, target: dict): 123 | w = random.randint(self.min_size, min(img.width, self.max_size)) 124 | h = random.randint(self.min_size, min(img.height, self.max_size)) 125 | region = T.RandomCrop.get_params(img, [h, w]) 126 | return crop(img, target, region) 127 | 128 | class CenterCrop(object): 129 | def __init__(self, size): 130 | self.size = size 131 | 132 | def __call__(self, img, target): 133 | image_width, image_height = img.size 134 | crop_height, crop_width = self.size 135 | crop_top = int(round((image_height - crop_height) / 2.)) 136 | crop_left = int(round((image_width - crop_width) / 2.)) 137 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 138 | 139 | class RandomHorizontalFlip(object): 140 | def __init__(self, p=0.5): 141 | self.p = p 142 | 143 | def __call__(self, img, target): 144 | if random.random() < self.p: 145 | return hflip(img, target) 146 | return img, target 147 | 148 | class RandomResize(object): 149 | def __init__(self, sizes, max_size=None): 150 | assert isinstance(sizes, (list, tuple)) 151 | self.sizes = sizes 152 | self.max_size = max_size 153 | 154 | def __call__(self, img, target=None): 155 | size = random.choice(self.sizes) 156 | return resize(img, target, size, self.max_size) 157 | 158 | class RandomPad(object): 159 | def __init__(self, max_pad): 160 | self.max_pad = max_pad 161 | 162 | def __call__(self, img, target): 163 | pad_x = random.randint(0, self.max_pad) 164 | pad_y = random.randint(0, self.max_pad) 165 | return pad(img, target, (pad_x, pad_y)) 166 | 167 | class RandomSelect(object): 168 | """ 169 | Randomly selects between transforms1 and transforms2, 170 | with probability p for transforms1 and (1 - p) for transforms2 171 | """ 172 | def __init__(self, transforms1, transforms2, p=0.5): 173 | self.transforms1 = transforms1 174 | self.transforms2 = transforms2 175 | self.p = p 176 | 177 | def 
__call__(self, img, target): 178 | if random.random() < self.p: 179 | return self.transforms1(img, target) 180 | return self.transforms2(img, target) 181 | 182 | class ToTensor(object): 183 | def __call__(self, img, target): 184 | return F.to_tensor(img), target 185 | 186 | class RandomErasing(object): 187 | def __init__(self, *args, **kwargs): 188 | self.eraser = T.RandomErasing(*args, **kwargs) 189 | 190 | def __call__(self, img, target): 191 | return self.eraser(img), target 192 | 193 | class Normalize(object): 194 | def __init__(self, mean, std): 195 | self.mean = mean 196 | self.std = std 197 | 198 | def __call__(self, image, target=None): 199 | image = F.normalize(image, mean=self.mean, std=self.std) 200 | if target is None: 201 | return image, None 202 | target = target.copy() 203 | h, w = image.shape[-2:] 204 | if "boxes" in target: 205 | boxes = target["boxes"] 206 | boxes = box_xyxy_to_cxcywh(boxes) 207 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 208 | target["boxes"] = boxes 209 | return image, target 210 | 211 | class Compose(object): 212 | def __init__(self, transforms): 213 | self.transforms = transforms 214 | 215 | def __call__(self, image, target): 216 | for t in self.transforms: 217 | image, target = t(image, target) 218 | return image, target 219 | 220 | def __repr__(self): 221 | format_string = self.__class__.__name__ + "(" 222 | for t in self.transforms: 223 | format_string += "\n" 224 | format_string += " {0}".format(t) 225 | format_string += "\n)" 226 | return format_string 227 | 228 | -------------------------------------------------------------------------------- /src/coco_tools/coco.py: -------------------------------------------------------------------------------- 1 | """Defines dataset, dataloader, and processing for COCO-style datasets.""" 2 | 3 | # Imports Python builtins. 4 | import json 5 | 6 | # Imports PyTorch packages. 7 | import torch 8 | from torch.utils.data import DataLoader, WeightedRandomSampler 9 | from torchvision.datasets.coco import CocoDetection as TorchvisionCocoDetection 10 | 11 | # Imports other packages. 12 | from pycocotools.coco import COCO 13 | from pycocotools.cocoeval import COCOeval 14 | 15 | # Imports local packages. 16 | import coco_tools.transforms as T 17 | from utils.box_ops import box_xyxy_to_xywh 18 | from utils.misc import ( 19 | get_balanced_sampler_weights_by_id, 20 | nested_collate, 21 | ) 22 | 23 | 24 | def get_transform(aug=False): 25 | """Converts to tensor and normalizes to ImageNet statistics. 26 | 27 | Optionally performs data augmentation according to Deformable DETR. 28 | """ 29 | 30 | normalize = T.Compose([ 31 | T.ToTensor(), 32 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 33 | ]) 34 | 35 | if aug: 36 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 37 | 38 | return T.Compose([ 39 | T.RandomHorizontalFlip(), 40 | T.RandomResize(scales, max_size=1333), 41 | normalize, 42 | ]) 43 | else: 44 | return T.Compose([ 45 | T.RandomResize([800], max_size=1333), 46 | normalize, 47 | ]) 48 | 49 | class CocoDetection(TorchvisionCocoDetection): 50 | """Processing object detection datasets with COCO-style annotations. 51 | 52 | Similar to Torchvision CocoDetection with additional 53 | processing and formatting for weakly supervised training. 
54 | """ 55 | 56 | def __init__(self, data_dir, labels_json, task="train"): 57 | super().__init__(data_dir, labels_json) 58 | 59 | labels = json.load(open(labels_json, "r")) 60 | 61 | self.image_classes_by_id = { 62 | img["id"]: img["classes"] for img in labels["images"] 63 | } 64 | 65 | # Sets transform as specified by task. 66 | aug = True if task == "train" else False 67 | self.transform = get_transform(aug=aug) 68 | 69 | def __getitem__(self, idx): 70 | # Gets image and target from dataset. 71 | img_id = self.ids[idx] 72 | img = self._load_image(img_id) 73 | target = self._load_target(img_id) 74 | 75 | # Adds image id as a key to the target dict. 76 | target = {"image_id": img_id, "annotations": target} 77 | 78 | # Preprocesses targets into dict. 79 | target = self.prepare(target, *img.size) 80 | 81 | # Applies transformations to image and target. 82 | img, target = self.transform(img, target) 83 | 84 | return img, target 85 | 86 | def prepare(self, target, width, height): 87 | """Loads COCO annotations into dict of tensors.""" 88 | 89 | # Loads image id and annotation from target. 90 | image_id = target["image_id"] 91 | ann = target["annotations"] 92 | 93 | # Removes crowd RLEs (keeps only single-instance bounding boxes). 94 | ann = [ 95 | obj for obj in ann 96 | if "iscrowd" not in obj 97 | or not obj["iscrowd"] 98 | ] 99 | 100 | # Extracts targets of interest from annotations into lists. 101 | boxes = [obj["bbox"] for obj in ann] 102 | area = [obj["area"] for obj in ann] 103 | iscrowd = [0] * len(ann) 104 | 105 | box_classes = [] 106 | image_classes = [] 107 | if self.image_classes_by_id: 108 | image_classes = self.image_classes_by_id[image_id] 109 | if ann and "category_id" in ann[0]: 110 | box_classes = [obj["category_id"] for obj in ann] 111 | 112 | # If conf is in annotation keys, i.e., this json is from 113 | # the pseudo-label prediction, extract it. 114 | confs = torch.ones(len(ann)) 115 | if ann and "conf" in ann[0]: 116 | confs = torch.stack(torch.tensor([obj["conf"] for obj in ann])) 117 | 118 | # Converts targets to tensors. 119 | image_id = torch.tensor([image_id]) 120 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 121 | image_classes = torch.tensor(image_classes, dtype=torch.int64) 122 | box_classes = torch.tensor(box_classes, dtype=torch.int64) 123 | area = torch.tensor(area, dtype=torch.float32) 124 | iscrowd = torch.tensor(iscrowd) 125 | size = torch.as_tensor([int(height), int(width)]) 126 | 127 | if boxes.shape[0]: 128 | # Converts from (x1, y1, w, h) to (x1, y1, x2, y2). 129 | boxes[:, 2:] += boxes[:, :2] 130 | 131 | # Clamps boxes to image size. 132 | boxes[:, 0::2].clamp_(min=0., max=width) 133 | boxes[:, 1::2].clamp_(min=0., max=height) 134 | 135 | # Removes invalid boxes from targets. 136 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 137 | boxes = boxes[keep] 138 | if not self.image_classes_by_id: 139 | image_classes = image_classes[keep] 140 | box_classes = box_classes[keep] 141 | area = area[keep] 142 | iscrowd = iscrowd[keep] 143 | 144 | # Populates targets dict. 
145 | target = {} 146 | target["image_id"] = image_id 147 | target["boxes"] = boxes 148 | target["confs"] = confs 149 | target["image_labels"] = image_classes 150 | target["box_labels"] = box_classes 151 | target["area"] = area 152 | target["iscrowd"] = iscrowd 153 | target["orig_size"] = size 154 | target["size"] = size 155 | 156 | return target 157 | 158 | def coco_loader(args, task="train"): 159 | """Builds a dataloader for COCO-style datasets.""" 160 | 161 | # Extracts data and labels location from args based on task. 162 | data_dir = vars(args)[task + "_imgs_dir"] 163 | labels_json = vars(args)[task + "_anns"] 164 | 165 | labels = json.load(open(labels_json, "r")) 166 | if "classes" not in labels["images"][0]: 167 | # Adds a classes field to each image with its image-level labels. 168 | print("Updating labels json") 169 | tmp_dict = {img["id"]: [] for img in labels["images"]} 170 | for ann in labels["annotations"]: 171 | tmp_dict[ann["image_id"]].append(ann["category_id"]) 172 | for img_id in tmp_dict.keys(): 173 | tmp_dict[img_id] = sorted(list(set(tmp_dict[img_id]))) 174 | for img in labels["images"]: 175 | img["classes"] = tmp_dict[img["id"]] 176 | # Removes images with no annotations. 177 | labels["images"] = [img for img in labels["images"] if img["classes"]] 178 | json.dump(labels, open(labels_json, "w")) 179 | print("Updated labels json.") 180 | 181 | dataset = CocoDetection(data_dir, labels_json, task=task) 182 | 183 | # Initializes a balanced random sampler for single-GPU training only. 184 | sampler = None 185 | shuffle = True if task == "train" else False 186 | if task == "train" and args.sampler: 187 | if args.gpus == 1: 188 | shuffle = None 189 | weights_by_id = get_balanced_sampler_weights_by_id(labels, offset) 190 | weights_by_idx = [weights_by_id[img_id] for img_id in dataset.ids] 191 | num_imgs = len(weights_by_idx) 192 | sampler = WeightedRandomSampler(weights_by_idx, num_imgs) 193 | else: 194 | raise NotImplementedError( 195 | "Balanced random sampler is not" 196 | " implemented for multi-GPU training." 197 | ) 198 | 199 | loader = DataLoader( 200 | dataset, 201 | batch_size=args.batch_size, 202 | collate_fn=nested_collate, 203 | num_workers=args.workers, 204 | pin_memory=True, 205 | sampler=sampler, 206 | shuffle=shuffle, 207 | ) 208 | 209 | return loader 210 | 211 | def prepare_coco_results(results): 212 | """Loads model results into COCO dict for evaluation.""" 213 | 214 | def coco_dict(img_id, box, conf, pred): 215 | return { 216 | "image_id": img_id, 217 | "bbox": box, 218 | "score": conf, 219 | "category_id": pred, 220 | } 221 | 222 | coco_results = [] 223 | for orig_id, result in results.items(): 224 | if not result: 225 | continue 226 | 227 | # Extracts results from model outputs. 228 | boxes = result["boxes"] 229 | confs = result["confs"].tolist() 230 | preds = result["preds"].tolist() 231 | 232 | # Converts boxes to COCO format. 233 | boxes = box_xyxy_to_xywh(boxes).tolist() 234 | 235 | result_iter = zip(boxes, confs, preds) 236 | 237 | # Builds COCO-style dict from results. 238 | coco_result = [ 239 | coco_dict(orig_id, *res) 240 | for res in result_iter 241 | ] 242 | 243 | # Adds dict to list of all results. 244 | coco_results.extend(coco_result) 245 | 246 | return coco_results 247 | 248 | def coco_evaluate(results, coco_groundtruth): 249 | """Runs COCO evaluation and returns statistics including mAP.""" 250 | 251 | # Processes results into COCO format. 252 | coco_results = prepare_coco_results(results) 253 | 254 | # Initializes COCO evaluator. 
255 | coco_detections = COCO.loadRes(coco_groundtruth, coco_results) 256 | coco_evaluator = COCOeval( 257 | cocoGt=coco_groundtruth, 258 | cocoDt=coco_detections, 259 | iouType="bbox", 260 | ) 261 | 262 | # Evaluates COCO results. 263 | coco_evaluator.evaluate() 264 | coco_evaluator.accumulate() 265 | coco_evaluator.summarize() 266 | 267 | stats = coco_evaluator.stats.tolist() 268 | 269 | return stats 270 | 271 | -------------------------------------------------------------------------------- /src/utils/misc.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous utility functions.""" 2 | 3 | # Imports Python builtins. 4 | from copy import deepcopy 5 | import os.path as osp 6 | from typing import Optional, List 7 | 8 | # Imports PyTorch packages. 9 | import torch 10 | from torch import nn 11 | from torch import Tensor 12 | 13 | # Imports other packages. 14 | import numpy as np 15 | from PIL import Image, ImageDraw, ImageFont 16 | 17 | 18 | class NestedTensor(object): 19 | """Class for collection of Tensors of different sizes. From DETR.""" 20 | 21 | def __init__(self, tensors, mask: Optional[Tensor]): 22 | self.tensors = tensors 23 | self.mask = mask 24 | 25 | def to(self, device, non_blocking=False): 26 | cast_tensor = self.tensors.to(device, non_blocking=non_blocking) 27 | mask = self.mask 28 | if mask is not None: 29 | assert mask is not None 30 | cast_mask = mask.to(device, non_blocking=non_blocking) 31 | else: 32 | cast_mask = None 33 | return NestedTensor(cast_tensor, cast_mask) 34 | 35 | def record_stream(self, *args, **kwargs): 36 | self.tensors.record_stream(*args, **kwargs) 37 | if self.mask is not None: 38 | self.mask.record_stream(*args, **kwargs) 39 | 40 | def decompose(self): 41 | return self.tensors, self.mask 42 | 43 | def __repr__(self): 44 | return str(self.tensors) 45 | 46 | def _max_by_axis(the_list): 47 | """Helper function for creating a NestedTensor. From DETR.""" 48 | 49 | maxes = the_list[0] 50 | for sublist in the_list[1:]: 51 | for index, item in enumerate(sublist): 52 | maxes[index] = max(maxes[index], item) 53 | return maxes 54 | 55 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 56 | """Creates a NestedTensor from a list of Tensors. 
From DETR.""" 57 | 58 | if tensor_list[0].ndim == 3: 59 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 60 | batch_shape = [len(tensor_list)] + max_size 61 | b, c, h, w = batch_shape 62 | dtype = tensor_list[0].dtype 63 | device = tensor_list[0].device 64 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 65 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 66 | for img, pad_img, m in zip(tensor_list, tensor, mask): 67 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 68 | m[: img.shape[1], :img.shape[2]] = False 69 | else: 70 | raise ValueError("Tensors must be 3-dimensional.") 71 | return NestedTensor(tensor, mask) 72 | 73 | def inverse_sigmoid(x, eps=1e-5): 74 | """Transforms a sigmoid vector back into logits.""" 75 | 76 | x = x.clamp(min=0, max=1) 77 | x1 = x.clamp(min=eps) 78 | x2 = (1 - x).clamp(min=eps) 79 | return torch.log(x1/x2) 80 | 81 | def nested_collate(batch): 82 | """Collates batch of images as NestedTensor for use in DataLoader.""" 83 | 84 | batch = list(zip(*batch)) 85 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 86 | return tuple(batch) 87 | 88 | def get_clones(module, num): 89 | """Duplicates modules for multi-scale learning.""" 90 | 91 | return nn.ModuleList([deepcopy(module) for j in range(num)]) 92 | 93 | def tensor_to_pil(img): 94 | """De-normalizes a Tensor and converts to PIL Image for saving.""" 95 | 96 | mean = [0.485, 0.456, 0.406] 97 | std = [0.229, 0.224, 0.225] 98 | 99 | # De-normalizes image. 100 | for c, m, s in zip(img, mean, std): 101 | c.mul_(s).add_(m) 102 | 103 | # Switches channels as required by PIL. 104 | img = torch.transpose(img, 1, 2) 105 | img = torch.transpose(img, 0, 2) 106 | img = Image.fromarray((img.cpu().numpy() * 255).astype(np.uint8)) 107 | 108 | return img 109 | 110 | def exclude_params(params, to_exclude): 111 | """Removes a subset of parameters from a list of parameters. 112 | 113 | Useful for setting learning rates for different modules. 114 | """ 115 | 116 | new_params = [] 117 | 118 | for p in params: 119 | exclude = False 120 | 121 | for param_set in to_exclude: 122 | if not exclude: 123 | for q in param_set: 124 | if torch.equal(p, q): 125 | exclude = True 126 | 127 | if not exclude: 128 | new_params.append(p) 129 | 130 | return new_params 131 | 132 | def get_state_dict_from_checkpoint(checkpoint): 133 | """Finds state dict in checkpoint.""" 134 | 135 | if "model" in checkpoint.keys(): 136 | state_dict = checkpoint["model"] 137 | elif "state_dict" in checkpoint.keys(): 138 | state_dict = checkpoint["state_dict"] 139 | else: 140 | raise ValueError("No state dict found in checkpoint.") 141 | 142 | return state_dict 143 | 144 | def get_balanced_sampler_weights_by_id(labels, offset): 145 | """Gets weights for each image for use in a balanced sampler. 146 | 147 | Note that the sampler is not exactly balanced as images may have more 148 | than one class present. The image weight is the mean weight of all 149 | its present classes. 150 | """ 151 | 152 | images = labels["images"] 153 | classes = len(labels["categories"]) 154 | 155 | def to_array(indices): 156 | """Expands list of indices into class-length array with offset.""" 157 | a = np.zeros(classes) 158 | indices = np.asarray(indices) 159 | a[indices - offset] = 1 160 | return a 161 | 162 | vals = [to_array(img["classes"]) for img in images] 163 | totals = np.sum(np.stack(vals), axis=0) 164 | weights = 1. 
/ totals 165 | 166 | classes_by_id = { 167 | img["id"]: np.asarray(img["classes"] - offset) for img in images 168 | } 169 | weights_by_id = { 170 | img["id"]: np.mean(weights[classes_by_id[img["id"]]]) for img in images 171 | } 172 | 173 | return weights_by_id 174 | 175 | def compute_accuracy(results, coco_groundtruth, thresh=0.1): 176 | """Computes (proxy) top1 and top5 classification accuracy. 177 | 178 | The accuracy is the proportion of images for which a class present 179 | in the image is predicted in the top1/top5 most confidence boxes 180 | respectively. Not very rigorous; mostly useful for debugging. 181 | """ 182 | 183 | groundtruth_by_img = {} 184 | for j in coco_groundtruth.imgs.keys(): 185 | groundtruth_by_img[j] = {"boxes": [], "preds": []} 186 | 187 | inds = coco_groundtruth.getAnnIds(imgIds=[j]) 188 | for i in inds: 189 | ann = coco_groundtruth.anns[i] 190 | groundtruth_by_img[j]["boxes"].append(ann["bbox"]) 191 | groundtruth_by_img[j]["preds"].append(ann["category_id"]) 192 | 193 | top1_total = 0 194 | top5_total = 0 195 | for img in results: 196 | res = results[img] 197 | gt = groundtruth_by_img[img] 198 | 199 | classes_in_img = set([p for p in gt["preds"]]) 200 | 201 | # Computes top1 accuracy. 202 | preds = res["preds"][:1] 203 | confs = res["confs"][:1] 204 | for cls in classes_in_img: 205 | for pred, conf in zip(preds, confs): 206 | if cls == pred and conf >= thresh: 207 | top1_total += 1 208 | break 209 | 210 | # Computes top5 accuracy. 211 | preds = res["preds"][:5] 212 | confs = res["confs"][:5] 213 | for cls in classes_in_img: 214 | for pred, conf in zip(preds, confs): 215 | if cls == pred and conf >= thresh: 216 | top5_total += 1 217 | break 218 | 219 | top1_acc = top1_total / len(results) 220 | top5_acc = top5_total / len(results) 221 | 222 | return top1_acc, top5_acc 223 | 224 | def gather_coco_results_across_gpus(results): 225 | """Collates COCO results across multiple GPUs for evaluation.""" 226 | 227 | coco_results = {} 228 | for result in results: 229 | img_results = {} 230 | if len(result["image_id"]) > 1: 231 | for j, img_id in enumerate(result["image_id"]): 232 | img_results[img_id.item()] = { 233 | "boxes": result["boxes"][j], 234 | "confs": result["confs"][j], 235 | "preds": result["preds"][j], 236 | } 237 | else: 238 | img_results[result["image_id"].item()] = { 239 | "boxes": result["boxes"], 240 | "confs": result["confs"], 241 | "preds": result["preds"], 242 | } 243 | coco_results.update(img_results) 244 | 245 | return coco_results 246 | 247 | def save_infer_img( 248 | img, 249 | imgs_dir, 250 | name, 251 | classes, 252 | boxes, 253 | confs, 254 | preds, 255 | offset, 256 | target_boxes=None, 257 | target_labels=None, 258 | ): 259 | """Plots predictions and ground-truth boxes on an image and saves.""" 260 | 261 | img = tensor_to_pil(img) 262 | draw = ImageDraw.Draw(img) 263 | 264 | try: 265 | font = ImageFont.truetype( 266 | "/usr/share/fonts/truetype/lato/Lato-Bold.ttf", 267 | size=16, 268 | ) 269 | text_height, _ = font.getmetrics() 270 | except: 271 | font = ImageFont.load_default() 272 | text_height = 0 273 | 274 | # Plots predictions. 275 | for box, conf, pred in zip(boxes, confs, preds): 276 | draw.rectangle( 277 | ((box[0], box[1]), (box[2], box[3])), 278 | outline="red", 279 | width=3, 280 | ) 281 | 282 | text_anchor = (box[0], box[1] - text_height) 283 | draw.text( 284 | text_anchor, 285 | f"{classes[pred - offset]} @ {conf:.2f}", 286 | fill="red", 287 | font=font, 288 | ) 289 | 290 | # Plots ground-truth boxes. 
291 | if target_boxes is not None and target_labels is not None: 292 | for box, label in zip(target_boxes, target_labels): 293 | draw.rectangle( 294 | ((box[0], box[1]), (box[2], box[3])), 295 | outline="blue", 296 | width=3, 297 | ) 298 | 299 | text_anchor = (box[0], box[1] - text_height) 300 | draw.text( 301 | text_anchor, 302 | f"{classes[label - offset]}", 303 | fill="blue", 304 | font=font, 305 | ) 306 | 307 | out_path = osp.join(imgs_dir, name) 308 | img.save(out_path, "JPEG") 309 | 310 | -------------------------------------------------------------------------------- /src/model/deformable_transformer.py: -------------------------------------------------------------------------------- 1 | """Defines DeformableTransformer; from Deformable DETR.""" 2 | 3 | # Imports Python builtins. 4 | import math 5 | 6 | # Imports PyTorch packages. 7 | import torch 8 | from torch import nn 9 | from torch.nn.init import xavier_uniform_, constant_, normal_ 10 | import torch.nn.functional as F 11 | 12 | # Imports local packages. 13 | from model.ops.modules import MSDeformAttn 14 | from utils.misc import get_clones 15 | 16 | 17 | class DeformableTransformer(nn.Module): 18 | def __init__(self, args): 19 | super().__init__() 20 | 21 | self.queries = args.queries 22 | self.batch_size = args.batch_size 23 | 24 | # Initializes encoder and decoder layers. 25 | encoder_layer = DeformableTransformerEncoderLayer(args) 26 | decoder_layer = DeformableTransformerDecoderLayer(args) 27 | 28 | # Initializes encoder by stacking layers. 29 | self.encoder = DeformableTransformerEncoder( 30 | encoder_layer, 31 | args.enc_layers, 32 | ) 33 | 34 | # Initializes decoder by stacking layers. 35 | self.decoder = DeformableTransformerDecoder( 36 | decoder_layer, 37 | args.dec_layers, 38 | ) 39 | 40 | # Initializes feature levels embedding. 41 | self.level_embed = nn.Parameter( 42 | torch.Tensor(args.feature_levels, args.hidden_dim) 43 | ) 44 | 45 | self.reference_points = nn.Linear(args.hidden_dim, 2) 46 | 47 | # Initializes transformer weights and biases. 48 | self._reset_parameters() 49 | 50 | def _reset_parameters(self): 51 | for p in self.parameters(): 52 | if p.dim() > 1: 53 | nn.init.xavier_uniform_(p) 54 | 55 | for m in self.modules(): 56 | if isinstance(m, MSDeformAttn): 57 | m._reset_parameters() 58 | 59 | xavier_uniform_(self.reference_points.weight.data, gain=1.0) 60 | constant_(self.reference_points.bias.data, 0.) 
61 | normal_(self.level_embed) 62 | 63 | def get_proposal_pos_embed(self, proposals): 64 | num_pos_feats = 128 65 | temperature = 10000 66 | scale = 2 * math.pi 67 | 68 | dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device) 69 | dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats) 70 | # N, L, 4 71 | proposals = proposals.sigmoid() * scale 72 | # N, L, 4, 128 73 | pos = proposals[:, :, :, None] / dim_t 74 | # N, L, 4, 64, 2 75 | pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2) 76 | 77 | return pos 78 | 79 | def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): 80 | N_, S_, C_ = memory.shape 81 | proposals = [] 82 | _cur = 0 83 | for lvl, (H_, W_) in enumerate(spatial_shapes): 84 | mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H_ * W_)].view(N_, H_, W_, 1) 85 | valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) 86 | valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) 87 | 88 | grid_y, grid_x = torch.meshgrid(torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), 89 | torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device)) 90 | grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) 91 | 92 | scale = torch.cat([valid_W.unsqueeze(-1), valid_H.unsqueeze(-1)], 1).view(N_, 1, 1, 2) 93 | grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) / scale 94 | wh = torch.ones_like(grid) * 0.05 * (2.0 ** lvl) 95 | proposal = torch.cat((grid, wh), -1).view(N_, -1, 4) 96 | proposals.append(proposal) 97 | _cur += (H_ * W_) 98 | output_proposals = torch.cat(proposals, 1) 99 | output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True) 100 | output_proposals = torch.log(output_proposals / (1 - output_proposals)) 101 | output_proposals = output_proposals.masked_fill(memory_padding_mask.unsqueeze(-1), float('inf')) 102 | output_proposals = output_proposals.masked_fill(~output_proposals_valid, float('inf')) 103 | 104 | output_memory = memory 105 | output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0)) 106 | output_memory = output_memory.masked_fill(~output_proposals_valid, float(0)) 107 | output_memory = self.enc_output_norm(self.enc_output(output_memory)) 108 | return output_memory, output_proposals 109 | 110 | def get_valid_ratio(self, mask): 111 | _, H, W = mask.shape 112 | valid_H = torch.sum(~mask[:, :, 0], 1) 113 | valid_W = torch.sum(~mask[:, 0, :], 1) 114 | valid_ratio_h = valid_H.float() / H 115 | valid_ratio_w = valid_W.float() / W 116 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 117 | return valid_ratio 118 | 119 | def forward(self, srcs, masks, pos_embeds, query_embed=None): 120 | # prepare input for encoder 121 | src_flatten = [] 122 | mask_flatten = [] 123 | lvl_pos_embed_flatten = [] 124 | spatial_shapes = [] 125 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): 126 | bs, c, h, w = src.shape 127 | spatial_shape = (h, w) 128 | spatial_shapes.append(spatial_shape) 129 | src = src.flatten(2).transpose(1, 2) 130 | mask = mask.flatten(1) 131 | pos_embed = pos_embed.flatten(2).transpose(1, 2) 132 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 133 | lvl_pos_embed_flatten.append(lvl_pos_embed) 134 | src_flatten.append(src) 135 | mask_flatten.append(mask) 136 | src_flatten = torch.cat(src_flatten, 1) 137 | mask_flatten = torch.cat(mask_flatten, 1) 138 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 
1) 139 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) 140 | level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) 141 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 142 | 143 | # encoder 144 | memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten) 145 | # prepare input for decoder 146 | bs, _, c = memory.shape 147 | query_embed, tgt = torch.split(query_embed, c, dim=1) 148 | query_embed = query_embed.unsqueeze(0).expand(bs, -1, -1) 149 | tgt = tgt.unsqueeze(0).expand(bs, -1, -1) 150 | reference_points = self.reference_points(query_embed).sigmoid() 151 | 152 | # decoder 153 | hs, inter_references = self.decoder(tgt, reference_points, memory, 154 | spatial_shapes, level_start_index, valid_ratios, query_embed, mask_flatten) 155 | 156 | return hs, reference_points, inter_references 157 | 158 | 159 | class DeformableTransformerEncoderLayer(nn.Module): 160 | def __init__(self, args): 161 | super().__init__() 162 | 163 | # self attention 164 | self.self_attn = MSDeformAttn( 165 | dim=args.hidden_dim, 166 | feature_levels=args.feature_levels, 167 | heads=args.heads, 168 | points=args.enc_points, 169 | ) 170 | self.dropout1 = nn.Dropout(args.dropout) 171 | self.norm1 = nn.LayerNorm(args.hidden_dim) 172 | 173 | # ffn 174 | self.linear1 = nn.Linear(args.hidden_dim, args.feedforward_dim) 175 | self.activation = vars(F)[args.activation] 176 | self.dropout2 = nn.Dropout(args.dropout) 177 | self.linear2 = nn.Linear(args.feedforward_dim, args.hidden_dim) 178 | self.dropout3 = nn.Dropout(args.dropout) 179 | self.norm2 = nn.LayerNorm(args.hidden_dim) 180 | 181 | @staticmethod 182 | def with_pos_embed(tensor, pos): 183 | return tensor if pos is None else tensor + pos 184 | 185 | def forward_ffn(self, src): 186 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 187 | src = src + self.dropout3(src2) 188 | src = self.norm2(src) 189 | return src 190 | 191 | def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): 192 | # self attention 193 | src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) 194 | src = src + self.dropout1(src2) 195 | src = self.norm1(src) 196 | 197 | # ffn 198 | src = self.forward_ffn(src) 199 | 200 | return src 201 | 202 | 203 | class DeformableTransformerEncoder(nn.Module): 204 | def __init__(self, encoder_layer, num_layers): 205 | super().__init__() 206 | self.layers = get_clones(encoder_layer, num_layers) 207 | self.num_layers = num_layers 208 | 209 | @staticmethod 210 | def get_reference_points(spatial_shapes, valid_ratios, device): 211 | reference_points_list = [] 212 | for lvl, (H_, W_) in enumerate(spatial_shapes): 213 | 214 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 215 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 216 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) 217 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) 218 | ref = torch.stack((ref_x, ref_y), -1) 219 | reference_points_list.append(ref) 220 | reference_points = torch.cat(reference_points_list, 1) 221 | reference_points = reference_points[:, :, None] * valid_ratios[:, None] 222 | return reference_points 223 | 224 | def forward(self, src, spatial_shapes, 
level_start_index, valid_ratios, pos=None, padding_mask=None): 225 | output = src 226 | reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) 227 | for _, layer in enumerate(self.layers): 228 | output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) 229 | 230 | return output 231 | 232 | 233 | class DeformableTransformerDecoderLayer(nn.Module): 234 | def __init__(self, args): 235 | super().__init__() 236 | 237 | # cross attention 238 | self.cross_attn = MSDeformAttn( 239 | dim=args.hidden_dim, 240 | feature_levels=args.feature_levels, 241 | heads=args.heads, 242 | points=args.dec_points, 243 | ) 244 | self.dropout1 = nn.Dropout(args.dropout) 245 | self.norm1 = nn.LayerNorm(args.hidden_dim) 246 | 247 | # self attention 248 | self.self_attn = nn.MultiheadAttention( 249 | args.hidden_dim, 250 | args.heads, 251 | dropout=args.dropout 252 | ) 253 | self.dropout2 = nn.Dropout(args.dropout) 254 | self.norm2 = nn.LayerNorm(args.hidden_dim) 255 | 256 | # ffn 257 | self.linear1 = nn.Linear(args.hidden_dim, args.feedforward_dim) 258 | self.activation = vars(F)[args.activation] 259 | self.dropout3 = nn.Dropout(args.dropout) 260 | self.linear2 = nn.Linear(args.feedforward_dim, args.hidden_dim) 261 | self.dropout4 = nn.Dropout(args.dropout) 262 | self.norm3 = nn.LayerNorm(args.hidden_dim) 263 | 264 | @staticmethod 265 | def with_pos_embed(tensor, pos): 266 | return tensor if pos is None else tensor + pos 267 | 268 | def forward_ffn(self, tgt): 269 | tgt2 = self.linear2(self.dropout3(self.activation(self.linear1(tgt)))) 270 | tgt = tgt + self.dropout4(tgt2) 271 | tgt = self.norm3(tgt) 272 | return tgt 273 | 274 | def forward(self, tgt, query_pos, reference_points, src, src_spatial_shapes, level_start_index, src_padding_mask=None): 275 | # self attention 276 | q = k = self.with_pos_embed(tgt, query_pos) 277 | tgt2 = self.self_attn(q.transpose(0, 1), k.transpose(0, 1), tgt.transpose(0, 1))[0].transpose(0, 1) 278 | tgt = tgt + self.dropout2(tgt2) 279 | tgt = self.norm2(tgt) 280 | 281 | # cross attention 282 | tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos), 283 | reference_points, 284 | src, src_spatial_shapes, level_start_index, src_padding_mask) 285 | tgt = tgt + self.dropout1(tgt2) 286 | tgt = self.norm1(tgt) 287 | 288 | # ffn 289 | tgt = self.forward_ffn(tgt) 290 | 291 | return tgt 292 | 293 | 294 | class DeformableTransformerDecoder(nn.Module): 295 | def __init__(self, decoder_layer, num_layers): 296 | super().__init__() 297 | self.layers = get_clones(decoder_layer, num_layers) 298 | self.num_layers = num_layers 299 | 300 | def forward(self, tgt, reference_points, src, src_spatial_shapes, src_level_start_index, src_valid_ratios, 301 | query_pos=None, src_padding_mask=None): 302 | output = tgt 303 | 304 | intermediate = [] 305 | intermediate_reference_points = [] 306 | for lid, layer in enumerate(self.layers): 307 | if reference_points.shape[-1] == 4: 308 | reference_points_input = reference_points[:, :, None] \ 309 | * torch.cat([src_valid_ratios, src_valid_ratios], -1)[:, None] 310 | else: 311 | assert reference_points.shape[-1] == 2 312 | reference_points_input = reference_points[:, :, None] * src_valid_ratios[:, None] 313 | output = layer(output, query_pos, reference_points_input, src, src_spatial_shapes, src_level_start_index, src_padding_mask) 314 | 315 | intermediate.append(output) 316 | intermediate_reference_points.append(reference_points) 317 | 318 | return torch.stack(intermediate), 
torch.stack(intermediate_reference_points) 319 | 320 | -------------------------------------------------------------------------------- /download.py: -------------------------------------------------------------------------------- 1 | """Downloads datasets and creates train/test splits.""" 2 | 3 | # Imports Python builtins. 4 | from copy import deepcopy 5 | import json 6 | import math 7 | import os 8 | import os.path as osp 9 | import random 10 | import requests 11 | import shutil 12 | 13 | # Imports other packages. 14 | from configargparse import Parser 15 | import gdown 16 | from PIL import Image 17 | 18 | 19 | def load_anns(path): 20 | """Loads annotation dict from JSON file.""" 21 | 22 | with open(path, "r") as f: 23 | anns = json.load(f) 24 | return anns 25 | 26 | def save_anns(anns, path): 27 | """Saves annotation dict to disk as JSON file.""" 28 | 29 | with open(path, "w") as f: 30 | json.dump(anns, f) 31 | 32 | def make_class_agnostic(anns): 33 | """Sets all categories to 1 in an annotation dict.""" 34 | 35 | class_agnostic_anns = deepcopy(anns) 36 | 37 | categories = [ 38 | {"id": 0, "name": "no object"}, 39 | {"id": 1, "name": "object"}, 40 | ] 41 | 42 | class_agnostic_anns["categories"] = categories 43 | 44 | for ann in class_agnostic_anns["annotations"]: 45 | ann["category_id"] = 1 46 | 47 | return class_agnostic_anns 48 | 49 | def make_subset(anns, img_ids): 50 | """Gets images corresponding to ids from an annotation dict.""" 51 | 52 | subset_anns = deepcopy(anns) 53 | 54 | subset_anns["images"] = [ 55 | img for img in anns["images"] if img["id"] in img_ids 56 | ] 57 | subset_anns["annotations"] = [ 58 | ann for ann in anns["annotations"] if ann["image_id"] in img_ids 59 | ] 60 | 61 | return subset_anns 62 | 63 | def get_classes_by_img(anns): 64 | """Gets dict of {img: classes}.""" 65 | 66 | classes_by_img = {img["id"]: [] for img in anns["images"]} 67 | for ann in anns["annotations"]: 68 | classes_by_img[ann["image_id"]].append(ann["category_id"]) 69 | classes_by_img = { 70 | img_id: set(classes) for img_id, classes in classes_by_img.items() 71 | } 72 | 73 | return classes_by_img 74 | 75 | def get_num_samples_by_cls(anns, img_ids=None): 76 | """Gets dict of {cls: num samples}.""" 77 | 78 | if img_ids: 79 | img_ids = set(img_ids) 80 | else: 81 | img_ids = set([img["id"] for img in anns["images"]]) 82 | 83 | num_samples_by_cls = {} 84 | for ann in anns["annotations"]: 85 | if ann["image_id"] in img_ids: 86 | if ann["category_id"] in num_samples_by_cls: 87 | num_samples_by_cls[ann["category_id"]] += 1 88 | else: 89 | num_samples_by_cls[ann["category_id"]] = 1 90 | 91 | return num_samples_by_cls 92 | 93 | def make_splits(path, train_pcts, splits=1): 94 | """Makes random train/test split(s) from an annotation file.""" 95 | 96 | anns = load_anns(path) 97 | img_ids = [img["id"] for img in anns["images"]] 98 | 99 | for train_pct in train_pcts: 100 | for seed in range(splits): 101 | random.seed(seed) 102 | train_size = math.ceil(train_pct * len(img_ids)) 103 | 104 | train_img_ids = random.sample(img_ids, train_size) 105 | test_img_ids = [ 106 | img_id for img_id in img_ids if img_id not in train_img_ids 107 | ] 108 | 109 | x = zip(("train", "test"), (train_img_ids, test_img_ids)) 110 | for name, ids in x: 111 | subset_anns = make_subset(anns, ids) 112 | class_agnostic_anns = make_class_agnostic(subset_anns) 113 | 114 | suffix = f"_{name}_seed{seed}_trn{train_pct}.json" 115 | subset_anns_path = osp.splitext(path)[0] + suffix 116 | class_agnostic_anns_path = 
osp.splitext(subset_anns_path)[0] \ 117 | + "_class_agnostic.json" 118 | 119 | save_anns(subset_anns, subset_anns_path) 120 | save_anns(class_agnostic_anns, class_agnostic_anns_path) 121 | 122 | def make_cls_splits(path, cls_pcts, splits=1): 123 | """Makes random class-wise train/test splits from an annotation file.""" 124 | 125 | anns = load_anns(path) 126 | img_ids = [img["id"] for img in anns["images"]] 127 | classes = [cls["id"] for cls in anns["categories"]] 128 | classes_by_img = get_classes_by_img(anns) 129 | 130 | for cls_pct in cls_pcts: 131 | cls_size = math.ceil(cls_pct * len(classes)) 132 | 133 | for seed in range(splits): 134 | random.seed(seed) 135 | cls_ids = random.sample(classes, cls_size) 136 | 137 | train_img_ids = [] 138 | for img_id in img_ids: 139 | if all(cls in cls_ids for cls in classes_by_img[img_id]): 140 | train_img_ids.append(img_id) 141 | 142 | subset_anns = make_subset(anns, train_img_ids) 143 | class_agnostic_anns = make_class_agnostic(subset_anns) 144 | 145 | suffix = f"_train_seed{seed}_cls{cls_pct}.json" 146 | subset_anns_path = osp.splitext(path)[0] + suffix 147 | class_agnostic_anns_path = osp.splitext(subset_anns_path)[0] \ 148 | + "_class_agnostic.json" 149 | 150 | save_anns(subset_anns, subset_anns_path) 151 | save_anns(class_agnostic_anns, class_agnostic_anns_path) 152 | 153 | def fgvc_download(base_dir): 154 | """Downloads and extracts FGVC dataset.""" 155 | 156 | print("Downloading FGVC-Aircraft dataset...") 157 | 158 | url = ( 159 | "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/" 160 | "archives/fgvc-aircraft-2013b.tar.gz" 161 | ) 162 | 163 | fgvc_dir = osp.join(base_dir, "fgvc-aircraft-2013b") 164 | os.makedirs(fgvc_dir, exist_ok=True) 165 | 166 | # Downloads FGVC dataset from VGG. 167 | data = requests.get(url) 168 | tar_path = osp.join(base_dir, "fgvc-aircraft-2013b.tar.gz") 169 | with open(tar_path, "wb") as f: 170 | f.write(data.content) 171 | 172 | print("Done.") 173 | print("Extracting FGVC-Aircraft dataset...") 174 | 175 | # Extracts FGVC dataset. May take a while. 176 | tmp_dir = osp.join(base_dir, "tmp") 177 | shutil.unpack_archive(tar_path, tmp_dir) 178 | 179 | print("Done.") 180 | print("Formatting images...") 181 | 182 | tmp_fgvc_dir = osp.join(tmp_dir, "fgvc-aircraft-2013b", "data") 183 | old_imgs_dir = osp.join(tmp_fgvc_dir, "images") 184 | imgs_dir = osp.join(fgvc_dir, "images") 185 | shutil.move(old_imgs_dir, imgs_dir) 186 | 187 | # Crops images (the bottom 20px are an info banner). 188 | name_to_wh = {} 189 | for img_name in os.listdir(imgs_dir): 190 | img_path = osp.join(imgs_dir, img_name) 191 | img = Image.open(img_path) 192 | w, h = img.size 193 | 194 | img = img.crop((0, 0, w, h - 20)) 195 | img.save(img_path) 196 | 197 | name_to_wh[osp.splitext(img_name)[0]] = img.size 198 | 199 | print("Done.") 200 | print("Converting annotations...") 201 | 202 | # Gets FGVC classes. 203 | name_to_id = {} 204 | categories = [] 205 | fgvc_cat_path = osp.join(tmp_fgvc_dir, "variants.txt") 206 | with open(fgvc_cat_path, "r") as f: 207 | for j, line in enumerate(f): 208 | line = line.strip("\n") 209 | name_to_id[line] = j 210 | categories.append({"id": j, "name": line}) 211 | 212 | # Gets FGVC boxes. 213 | id_to_box = {} 214 | fgvc_box_path = osp.join(tmp_fgvc_dir, "images_box.txt") 215 | with open(fgvc_box_path, "r") as f: 216 | for line in f: 217 | box = line.strip("\n").split() 218 | id_to_box[box[0]] = [int(b) for b in box[1:]] 219 | 220 | # Makes COCO json annotations. 
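    # Each FGVC image has exactly one labeled aircraft, so the image id and
    # annotation id below share the same counter. The output follows the
    # standard COCO layout of "images", "categories", and "annotations".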
221 | ann_dir = osp.join(fgvc_dir, "annotations") 222 | os.makedirs(ann_dir, exist_ok=True) 223 | for split in ("trainval", "test"): 224 | ann = {"images": [], "categories": categories, "annotations": []} 225 | 226 | fgvc_ann_path = osp.join(tmp_fgvc_dir, f"images_variant_{split}.txt") 227 | with open(fgvc_ann_path, "r") as f: 228 | for j, line in enumerate(f): 229 | line = line.strip("\n").split() 230 | name = f"{line[0]}.jpg" 231 | cat = name_to_id[" ".join(line[1:])] 232 | box = id_to_box[line[0]] 233 | wh = name_to_wh[line[0]] 234 | 235 | img = { 236 | "file_name": name, 237 | "height": wh[1], 238 | "id": j, 239 | "width": wh[0], 240 | } 241 | 242 | box_ann = { 243 | "area": wh[0] * wh[1], 244 | "bbox": box, 245 | "category_id": cat, 246 | "id": j, 247 | "image_id": j, 248 | "iscrowd": 0, 249 | } 250 | 251 | ann["images"].append(img) 252 | ann["annotations"].append(box_ann) 253 | 254 | json_path = osp.join(ann_dir, f"{split}.json") 255 | with open(json_path, "w") as f: 256 | json.dump(ann, f) 257 | 258 | shutil.rmtree(tmp_dir) 259 | 260 | print("Done.") 261 | 262 | def fsod_download(base_dir): 263 | """Downloads and extracts FSOD dataset.""" 264 | 265 | url = ( 266 | "https://drive.google.com/drive/folders/" 267 | "1XXADD7GvW8M_xzgFpHfudYDYtKtDgZGM" 268 | ) 269 | 270 | # Downloads FSOD dataset from Google Drive. 271 | fsod_dir = osp.join(base_dir, "fsod") 272 | gdown.download_folder(url, output=fsod_dir) 273 | 274 | print("Extracting FSOD dataset...") 275 | 276 | # Extracts FSOD dataset. May take a while. 277 | img_dir = osp.join(fsod_dir, "images") 278 | for img_tar_name in os.listdir(img_dir): 279 | img_tar_path = osp.join(img_dir, img_tar_name) 280 | out_path = osp.splitext(img_tar_path)[0] 281 | shutil.unpack_archive(img_tar_path, out_path) 282 | 283 | print("Done.") 284 | 285 | def fsod_split(base_dir): 286 | """Creates train/test splits from FSOD dataset.""" 287 | 288 | print("Creating FSOD splits...") 289 | 290 | ann_dir = osp.join(base_dir, "fsod", "annotations") 291 | 292 | fsod_train_anns_path = osp.join(ann_dir, "fsod_train.json") 293 | fsod_800_anns_path = osp.join(ann_dir, "fsod_800.json") 294 | fsod_test_anns_path = osp.join(ann_dir, "fsod_test.json") 295 | fsod_200_anns_path = osp.join(ann_dir, "fsod_200.json") 296 | 297 | # Renames base annotations. 298 | shutil.copyfile(fsod_train_anns_path, fsod_800_anns_path) 299 | shutil.copyfile(fsod_test_anns_path, fsod_200_anns_path) 300 | 301 | # Saves class-agnostic annotations. 302 | for anns_path in (fsod_800_anns_path, fsod_200_anns_path): 303 | anns = load_anns(anns_path) 304 | class_agnostic_anns_path = anns_path[:-5] + "_class_agnostic.json" 305 | class_agnostic_anns = make_class_agnostic(anns) 306 | save_anns(class_agnostic_anns, class_agnostic_anns_path) 307 | 308 | # Creates 3x 80/20 train/test splits from FSOD 200. 309 | make_splits(fsod_200_anns_path, [0.8], splits=3) 310 | 311 | # Makes 20/40/60/80 class and data splits from FSOD 800. 
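    # Class-wise splits keep only images whose classes all fall in the sampled
    # subset, while image-wise splits sample images directly; both use three
    # random seeds, matching the FSOD 200 splits above.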
312 | make_cls_splits(fsod_800_anns_path, [0.2, 0.4, 0.6, 0.8], splits=3) 313 | make_splits(fsod_800_anns_path, [0.2, 0.4, 0.6, 0.8], splits=3) 314 | 315 | print("Done.") 316 | 317 | def inaturalist_download(base_dir): 318 | """Downloads and extracts iNaturalist 2017 dataset.""" 319 | 320 | print("Downloading iNaturalist dataset...") 321 | 322 | base_url = "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/" 323 | img_url = osp.join(base_url, "train_val_images.tar.gz") 324 | train_anns_url = osp.join(base_url, "train_2017_bboxes.zip") 325 | val_anns_url = osp.join(base_url, "val_2017_bboxes.zip") 326 | 327 | # Downloads iNaturalist 2017 images from AWS. 328 | img_data = requests.get(img_url) 329 | img_tar_path = osp.join(base_dir, "inaturalist/train_val_images.tar.gz") 330 | with open(img_tar_path, "wb") as f: 331 | f.write(img_data.content) 332 | 333 | train_anns_data = requests.get(train_anns_url) 334 | train_anns_zip_path = osp.join(base_dir, "inaturalist/train_2017_bboxes.zip") 335 | with open(train_anns_zip_path, "wb") as f: 336 | f.write(train_anns_data.content) 337 | 338 | val_anns_data = requests.get(val_anns_url) 339 | val_anns_zip_path = osp.join(base_dir, "inaturalist/val_2017_bboxes.zip") 340 | with open(val_anns_zip_path, "wb") as f: 341 | f.write(val_anns_data.content) 342 | 343 | print("Done.") 344 | print("Extracting iNaturalist dataset...") 345 | 346 | inat_path = osp.join(base_dir, "inaturalist") 347 | anns_path = osp.join(inat_path, "annotations") 348 | 349 | # Extracts iNaturalist 2017 images and annotations. May take a while. 350 | shutil.unpack_archive(img_tar_path, inat_path) 351 | shutil.unpack_archive(train_anns_zip_path, anns_path) 352 | shutil.unpack_archive(val_anns_zip_path, anns_path) 353 | 354 | print("Done.") 355 | print("Formatting iNaturalist annotations...") 356 | 357 | train_anns = json.load(open(osp.join(anns_path, "train_2017_bboxes.json"), "r")) 358 | val_anns = json.load(open(osp.join(anns_path, "val_2017_bboxes.json"), "r")) 359 | 360 | # Makes clean fine-grained annotation files. 361 | id_to_name = {c["id"]: c["name"] for c in train_anns["categories"]} 362 | ann_cat_ids = [ann["category_id"] for ann in train_anns["annotations"]] 363 | class_names = [id_to_name[ann["category_id"]] for ann in train_anns["annotations"]] 364 | class_names = sorted(list(set(class_names))) 365 | classes = [{"id": j, "name": n} for j, n in enumerate(class_names)] 366 | name_to_id = {c["name"]: c["id"] for c in classes} 367 | old_to_new = {c["id"]: name_to_id[c["name"]] for c in train_anns["categories"]} 368 | 369 | for split, anns in zip(("train", "val"), (train_anns, val_anns)): 370 | clean_anns = deepcopy(anns) 371 | clean_anns["categories"] = classes 372 | for ann in clean_anns["annotations"]: 373 | ann["category_id"] = old_to_new[ann["category_id"]] 374 | 375 | clean_anns_path = osp.join(anns_path, f"{split}_2017_bboxes_clean.json") 376 | with open(clean_anns_path, "w") as f: 377 | json.dump(clean_anns, f) 378 | 379 | # Makes superclass annotation files. 
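    # Superclass labels come from the "supercategory" field of the original
    # annotations; fine-grained category ids are remapped onto the smaller
    # superclass id space, mirroring the clean fine-grained files above.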
380 | superclass_names = sorted(list(set([c["supercategory"] for c in train_anns["categories"]]))) 381 | superclasses = [{"id": j, "name": n} for j, n in enumerate(superclass_names)] 382 | name_to_id = {c["name"]: c["id"] for c in superclasses} 383 | old_to_new = {c["id"]: name_to_id[c["supercategory"]] for c in train_anns["categories"]} 384 | 385 | for split, anns in zip(("train", "val"), (train_anns, val_anns)): 386 | superclass_anns = deepcopy(anns) 387 | superclass_anns["categories"] = superclasses 388 | for ann in superclass_anns["annotations"]: 389 | ann["category_id"] = old_to_new[ann["category_id"]] 390 | 391 | superclass_anns_path = osp.join(anns_path, f"{split}_2017_bboxes_superclass.json") 392 | with open(superclass_anns_path, "w") as f: 393 | json.dump(superclass_anns, f) 394 | 395 | print("Done.") 396 | 397 | if __name__ == "__main__": 398 | parser = Parser() 399 | 400 | parser.add( 401 | "--base_dir", 402 | default="data", 403 | help="Where to extract dataset files.", 404 | ) 405 | parser.add( 406 | "--datasets", 407 | choices=["all", "fgvc", "fsod", "inaturalist"], 408 | default="all", 409 | nargs="+", 410 | help="Which dataset(s) to download.", 411 | ) 412 | 413 | args = parser.parse_args() 414 | if "all" in args.datasets: 415 | args.datasets = ["fgvc", "fsod", "inaturalist"] 416 | 417 | os.makedirs(args.base_dir, exist_ok=True) 418 | 419 | if "fgvc" in args.datasets: 420 | fgvc_download(args.base_dir) 421 | if "fsod" in args.datasets: 422 | fsod_download(args.base_dir) 423 | fsod_split(args.base_dir) 424 | if "inaturalist" in args.datasets: 425 | inaturalist_download(args.base_dir) 426 | 427 | -------------------------------------------------------------------------------- /src/model/ws_detr.py: -------------------------------------------------------------------------------- 1 | """Defines WS-DETR LightningModule.""" 2 | 3 | # Imports Python builtins. 4 | from io import StringIO 5 | import math 6 | import sys 7 | 8 | # Imports PyTorch packages. 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | from torch.optim import AdamW 13 | from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR 14 | import torchvision.transforms as T 15 | 16 | # Imports other packages. 17 | from azureml.core.run import Run 18 | import pytorch_lightning as pl 19 | 20 | # Imports local packages. 21 | from coco_tools.coco import coco_evaluate 22 | from utils.box_ops import box_cxcywh_to_xyxy 23 | from utils.misc import ( 24 | compute_accuracy, 25 | exclude_params, 26 | gather_coco_results_across_gpus, 27 | inverse_sigmoid, 28 | NestedTensor, 29 | nested_tensor_from_tensor_list, 30 | save_infer_img, 31 | ) 32 | 33 | # Imports local model packages. 34 | from .backbone import build_backbone 35 | from .deformable_transformer import DeformableTransformer 36 | from .mil_loss import mil_loss 37 | from .mlp import MLP 38 | from .model_args import add_model_args 39 | from .postprocess import postprocess 40 | 41 | 42 | class WS_DETR(pl.LightningModule): 43 | """Defines Weakly Supervised Detection Transformer (WS-DETR).""" 44 | 45 | def __init__( 46 | self, 47 | args, 48 | class_agnostic_detector=None, 49 | class_names=None, 50 | coco_groundtruth=None, 51 | ): 52 | """Initializes WS-DETR with backbone, Transformer, and embeddings.""" 53 | 54 | super().__init__() 55 | 56 | # Saves class names for visualization. 
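        # With two classes the model is class-agnostic, so the label offset is
        # zero and generic names are used; otherwise the dataset-specific
        # names supplied by the caller are kept for plotting.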
57 | if args.classes == 2: 58 | args.offset = 0 59 | self.class_names = ["no object", "object"] 60 | else: 61 | self.class_names = class_names 62 | 63 | # Saves hyperparameters to self.hparams. 64 | self.save_hyperparameters(args) 65 | 66 | # Saves instance variables. 67 | self.class_agnostic_detector = class_agnostic_detector 68 | self.coco_groundtruth = coco_groundtruth 69 | self.queries = args.queries 70 | self.feature_levels = args.feature_levels 71 | 72 | # Initializes backbone and Transformer. 73 | self.backbone = build_backbone(args) 74 | self.transformer = DeformableTransformer(args) 75 | self.query_embed = nn.Embedding(args.queries, args.hidden_dim * 2) 76 | 77 | # Initializes box and MIL head embeddings. 78 | self.bbox_embed = MLP(args.hidden_dim, args.hidden_dim, 4, 3) 79 | self.det_embed = nn.Linear(args.hidden_dim, args.classes) 80 | self.class_embed = nn.Linear(args.hidden_dim, args.classes) 81 | 82 | # Initializes input projection. 83 | self.init_input_proj(args.feature_levels, args.hidden_dim) 84 | 85 | # Initializes weights and biases for embeddings. 86 | self.init_weights_and_biases(args.classes) 87 | 88 | # Duplicates embeddings for each decoder layer. 89 | layers = range(self.transformer.decoder.num_layers) 90 | self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in layers]) 91 | self.det_embed = nn.ModuleList([self.det_embed for _ in layers]) 92 | self.class_embed = nn.ModuleList([self.class_embed for _ in layers]) 93 | 94 | # Shares class-agnostic query embedding with multi-class model. 95 | if class_agnostic_detector: 96 | self.query_embed = class_agnostic_detector.query_embed 97 | 98 | # Freezes query embedding. 99 | for p in self.query_embed.parameters(): 100 | p.requires_grad = False 101 | 102 | def init_input_proj(self, feature_levels, hidden_dim): 103 | """Initializes input projection.""" 104 | 105 | if feature_levels > 1: 106 | # Initializes multi-scale input projection. 107 | num_backbone_outs = len(self.backbone.strides) 108 | input_proj_list = [] 109 | 110 | for _ in range(num_backbone_outs): 111 | in_channels = self.backbone.num_channels[_] 112 | input_proj_list.append(nn.Sequential( 113 | nn.Conv2d(in_channels, hidden_dim, kernel_size=1), 114 | nn.GroupNorm(32, hidden_dim), 115 | )) 116 | 117 | for _ in range(feature_levels - num_backbone_outs): 118 | input_proj_list.append(nn.Sequential( 119 | nn.Conv2d( 120 | in_channels, 121 | hidden_dim, 122 | kernel_size=3, 123 | stride=2, 124 | padding=1, 125 | ), 126 | nn.GroupNorm(32, hidden_dim), 127 | )) 128 | in_channels = hidden_dim 129 | self.input_proj = nn.ModuleList(input_proj_list) 130 | else: 131 | # Initializes single-scale input projection. 132 | self.input_proj = nn.ModuleList([ 133 | nn.Sequential( 134 | nn.Conv2d( 135 | self.backbone.num_channels[0], 136 | hidden_dim, 137 | kernel_size=1, 138 | ), 139 | nn.GroupNorm(32, hidden_dim), 140 | )]) 141 | 142 | def init_weights_and_biases(self, classes): 143 | """Initializes embedding weights and biases.""" 144 | 145 | prior_prob = 0.01 146 | bias_value = -math.log((1 - prior_prob) / prior_prob) 147 | 148 | # Initializes MIL head embeddings. 149 | self.det_embed.bias.data = torch.ones(classes) * bias_value 150 | self.class_embed.bias.data = torch.ones(classes) * bias_value 151 | 152 | # Initializes bbox embeddings. 
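        # The final MLP layer starts at zero weights; the width/height biases
        # are set to -2.0 so initial boxes are small after the sigmoid.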
153 | nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) 154 | nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) 155 | nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0) 156 | 157 | # Initializes input projection. 158 | for proj in self.input_proj: 159 | nn.init.xavier_uniform_(proj[0].weight, gain=1) 160 | nn.init.constant_(proj[0].bias, 0) 161 | 162 | def forward(self, imgs): 163 | """Applies WS-DETR to a batch of images.""" 164 | 165 | # Casts images to NestedTensor. 166 | if not isinstance(imgs, NestedTensor): 167 | imgs = nested_tensor_from_tensor_list(imgs) 168 | 169 | # Computes agnostic boxes and object confidences. 170 | if self.class_agnostic_detector: 171 | agnostic_outputs = self.class_agnostic_detector(imgs) 172 | agnostic_boxes = agnostic_outputs["boxes"] 173 | agnostic_scores = agnostic_outputs["classes_logits"].sigmoid() 174 | obj_confs = agnostic_scores[:, :, 1] 175 | 176 | # Extracts features and position embedding using backbone. 177 | features, pos = self.backbone(imgs) 178 | 179 | # Decomposes NestedTensor and feeds image through input projection. 180 | srcs = [] 181 | masks = [] 182 | for l, feat in enumerate(features): 183 | src, mask = feat.decompose() 184 | srcs.append(self.input_proj[l](src)) 185 | masks.append(mask) 186 | assert mask is not None 187 | 188 | # Computes multi-scale feature maps. 189 | if self.feature_levels > len(srcs): 190 | _len_srcs = len(srcs) 191 | for l in range(_len_srcs, self.feature_levels): 192 | if l == _len_srcs: 193 | src = self.input_proj[l](features[-1].tensors) 194 | else: 195 | src = self.input_proj[l](srcs[-1]) 196 | m = imgs.mask 197 | mask = F.interpolate(m[None].float(), size=src.shape[-2:]) 198 | mask = mask.to(torch.bool)[0] 199 | pos_l = self.backbone[1](NestedTensor(src, mask)).to(src.dtype) 200 | srcs.append(src) 201 | masks.append(mask) 202 | pos.append(pos_l) 203 | 204 | # Passes images, masks, and position embedding through Transformer 205 | # with query embedding applied in the decoder. 206 | hs, init_reference, inter_references = self.transformer( 207 | srcs, 208 | masks, 209 | pos, 210 | self.query_embed.weight, 211 | ) 212 | 213 | # Computes multi-level embedding outputs from Transformer embedding. 214 | boxes = [] 215 | classes = [] 216 | dets = [] 217 | for lvl in range(hs.shape[0]): 218 | lvl_boxes = self.bbox_embed[lvl](hs[lvl]) 219 | lvl_classes = self.class_embed[lvl](hs[lvl]) 220 | lvl_dets = self.det_embed[lvl](hs[lvl]) 221 | 222 | # Postprocesses box embedding with reference points. 223 | if lvl == 0: 224 | reference = init_reference 225 | else: 226 | reference = inter_references[lvl - 1] 227 | reference = inverse_sigmoid(reference) 228 | if reference.shape[-1] == 4: 229 | lvl_boxes += reference 230 | else: 231 | assert reference.shape[-1] == 2 232 | lvl_boxes[..., :2] += reference 233 | 234 | lvl_boxes = lvl_boxes.sigmoid() 235 | boxes.append(lvl_boxes) 236 | classes.append(lvl_classes) 237 | dets.append(lvl_dets) 238 | 239 | # Converts lists to tensors. 240 | boxes = torch.stack(boxes) 241 | classes = torch.stack(classes) 242 | dets = torch.stack(dets) 243 | 244 | # Combines results into output dict. 245 | out = { 246 | "boxes": boxes[-1], 247 | "classes_logits": classes[-1], 248 | "dets_logits": dets[-1], 249 | } 250 | 251 | # Replaces box prediction with class agnostic boxes and 252 | # adds objectness confidences to results. 
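        # When a class-agnostic detector is attached, its boxes override the
        # multi-class box head and its per-query "object" confidences are
        # passed along for the objectness term of the loss.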
253 |         if self.class_agnostic_detector:
254 |             out["boxes"] = agnostic_boxes
255 |             out["obj_confs"] = obj_confs
256 |
257 |         return out
258 |
259 |     def configure_optimizers(self):
260 |         """Configures AdamW optimizer and StepLR scheduler."""
261 |
262 |         # Separates backbone, DETR, and MIL parameters.
263 |         backbone_params = list(self.backbone.parameters())
264 |         backbone_params.extend(list(self.input_proj.parameters()))
265 |
266 |         mil_params = list(self.class_embed.parameters())
267 |         mil_params.extend(list(self.det_embed.parameters()))
268 |
269 |         to_exclude = [backbone_params, mil_params]
270 |         detr_params = exclude_params(
271 |             self.parameters(),
272 |             to_exclude,
273 |         )
274 |
275 |         # Assigns different learning rates to backbone, DETR, and MIL head.
276 |         param_dicts = [
277 |             {
278 |                 "params": mil_params,
279 |                 "lr": self.hparams.lr_mil,
280 |             },
281 |             {
282 |                 "params": detr_params,
283 |                 "lr": self.hparams.lr_detr,
284 |             },
285 |             {
286 |                 "params": backbone_params,
287 |                 "lr": self.hparams.lr_backbone,
288 |             },
289 |         ]
290 |
291 |         # Initializes AdamW optimizer with specified LR and weight decay.
292 |         optimizer = AdamW(
293 |             param_dicts,
294 |             weight_decay=self.hparams.weight_decay,
295 |         )
296 |
297 |         if self.hparams.lr_patience and self.hparams.lr_step_size:
298 |             raise ValueError(
299 |                 "Please only enable one of lr_patience (for ReduceLROnPlateau"
300 |                 " scheduler) and lr_step_size (for StepLR scheduler)."
301 |             )
302 |
303 |         # Initializes scheduler which drops LR by a factor of 10
304 |         # if val_loss does not decrease within lr_patience epochs.
305 |         if self.hparams.lr_patience:
306 |             scheduler = ReduceLROnPlateau(
307 |                 optimizer,
308 |                 patience=self.hparams.lr_patience,
309 |             )
310 |         # Initializes scheduler which drops LR by a factor of 10
311 |         # after every lr_step_size epochs.
312 |         elif self.hparams.lr_step_size:
313 |             scheduler = StepLR(
314 |                 optimizer,
315 |                 self.hparams.lr_step_size,
316 |                 gamma=self.hparams.lr_drop,
317 |             )
318 |
319 |         # Builds optimizer config as expected by PL.
320 |         cfg = {
321 |             "optimizer": optimizer,
322 |             "lr_scheduler": {
323 |                 "scheduler": scheduler,
324 |                 "monitor": "val_loss",
325 |             },
326 |         }
327 |
328 |         return cfg
329 |
330 |     def forward_with_loss(self, batch, idx):
331 |         """Computes prediction and loss."""
332 |
333 |         imgs, targets = batch
334 |
335 |         outputs = self(imgs)
336 |
337 |         # Computes MIL and objectness loss.
338 |         mil, obj = mil_loss(
339 |             outputs,
340 |             targets,
341 |             joint_probability=self.hparams.joint_probability,
342 |             objectness_scale=self.hparams.objectness_scale,
343 |             offset=self.hparams.offset,
344 |             sparse=self.hparams.sparse,
345 |         )
346 |
347 |         loss = torch.stack((mil, obj)).sum()
348 |
349 |         return outputs, loss
350 |
351 |     def training_step(self, batch, idx):
352 |         """Computes loss."""
353 |
354 |         _, loss = self.forward_with_loss(batch, idx)
355 |
356 |         return loss
357 |
358 |     def training_epoch_end(self, training_step_outputs):
359 |         """Computes and logs epoch training loss."""
360 |
361 |         # Gathers loss across GPUs.
362 |         loss = torch.stack(training_step_outputs).mean()
363 |         loss = self.all_gather(loss).mean().item()
364 |
365 |         if self.trainer.is_global_zero:
366 |             try:
367 |                 # Logs to AzureML.
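                # Run.get_context(allow_offline=False) raises outside an
                # AzureML run, so offline/local training skips this logging.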
368 | writer = Run.get_context(allow_offline=False) 369 | writer.log("Train Loss", loss) 370 | except: 371 | pass 372 | 373 | def validation_step(self, batch, idx): 374 | """Computes loss and postprocesses prediction.""" 375 | 376 | imgs, targets = batch 377 | orig_sizes = torch.stack([t["orig_size"] for t in targets]) 378 | 379 | # Plots predictions with ground-truth boxes. 380 | if idx < self.hparams.viz_test_batches: 381 | names = [str(t["image_id"].item()) + ".jpg" for t in targets] 382 | self.predict_helper(imgs, names, orig_sizes, targets=targets) 383 | 384 | outputs, loss = self.forward_with_loss(batch, idx) 385 | 386 | # Postprocesses model outputs for COCO metrics computation. 387 | results = postprocess( 388 | outputs, 389 | orig_sizes, 390 | joint_probability=self.hparams.joint_probability, 391 | nms_thresh=self.hparams.nms_thresh, 392 | offset=self.hparams.offset, 393 | sparse=self.hparams.sparse, 394 | supervised=self.hparams.supervised, 395 | ) 396 | 397 | for target, result in zip(targets, results): 398 | result["image_id"] = target["image_id"] 399 | 400 | return results, loss 401 | 402 | def validation_epoch_end(self, validation_step_outputs): 403 | """Computes COCO metrics over all validation batches.""" 404 | 405 | results = [] 406 | losses = [] 407 | for result, loss in validation_step_outputs: 408 | results.extend(result) 409 | losses.append(loss) 410 | 411 | # Gathers loss across GPUs. 412 | loss = torch.stack(losses).mean() 413 | loss = self.all_gather(loss).mean().item() 414 | 415 | # Gathers COCO results across GPUs. 416 | results = self.all_gather(results) 417 | coco_results = gather_coco_results_across_gpus(results) 418 | 419 | # Performs COCO evaluation while suppressing prints so 420 | # it doesn't print on every GPU. I tried doing evaluation 421 | # on rank zero only, but it doesn't work (possible PL bug). 422 | coco_prints = StringIO() 423 | sys.stdout = coco_prints 424 | stats = coco_evaluate(coco_results, self.coco_groundtruth) 425 | sys.stdout = sys.__stdout__ 426 | top1_acc, top5_acc = compute_accuracy( 427 | coco_results, 428 | self.coco_groundtruth, 429 | ) 430 | 431 | # Logs to PL logger and syncs across GPUs. 432 | self.log("val_loss", loss, sync_dist=True) 433 | self.log("mAP", stats[0], sync_dist=True) 434 | self.log("AP50", stats[1], sync_dist=True) 435 | self.log("AP75", stats[2], sync_dist=True) 436 | self.log("Top1 Acc", top1_acc, sync_dist=True) 437 | self.log("Top5 Acc", top5_acc, sync_dist=True) 438 | 439 | if self.trainer.is_global_zero: 440 | try: 441 | # Prints COCO evaluation results. 442 | print(coco_prints.getvalue().strip("\n")) 443 | 444 | # Logs to AzureML. 
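                # Only the rank-zero process reaches this block, so metrics
                # are written to AzureML once rather than once per GPU.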
445 | writer = Run.get_context(allow_offline=False) 446 | writer.log("Val Loss", loss) 447 | writer.log("mAP", stats[0]) 448 | writer.log("AP50", stats[1]) 449 | writer.log("AP75", stats[2]) 450 | writer.log("Top1 Acc", top1_acc) 451 | writer.log("Top5 Acc", top5_acc) 452 | except: 453 | pass 454 | 455 | def test_step(self, batch, idx): 456 | """Computes loss and postprocesses prediction.""" 457 | 458 | return self.validation_step(batch, idx) 459 | 460 | def test_epoch_end(self, test_step_outputs): 461 | """Computes COCO metrics over all test batches.""" 462 | 463 | return self.validation_epoch_end(test_step_outputs) 464 | 465 | def predict_helper(self, imgs, names, orig_sizes, targets=None): 466 | """Helper function for prediction and image saving.""" 467 | 468 | outputs = self(imgs) 469 | 470 | results = postprocess( 471 | outputs, 472 | orig_sizes, 473 | joint_probability=self.hparams.joint_probability, 474 | nms_thresh=self.hparams.nms_thresh, 475 | offset=self.hparams.offset, 476 | sparse=self.hparams.sparse, 477 | supervised=self.hparams.supervised, 478 | ) 479 | 480 | if not targets: 481 | targets = [None for _ in range(len(results))] 482 | 483 | z = zip(imgs.tensors, imgs.mask, names, orig_sizes, results, targets) 484 | for img, mask, name, orig_size, result, target in z: 485 | # Applies mask to get original tensor. 486 | for j, x in enumerate(mask[0]): 487 | if x: 488 | break 489 | for i, x in enumerate(torch.transpose(mask, 0, 1)[0]): 490 | if x: 491 | break 492 | img = img[:, :i, :j] 493 | img = T.Resize(orig_size.int().tolist())(img) 494 | 495 | # Thresholds output for display. 496 | keep = result["confs"] > self.hparams.infer_display_thresh 497 | boxes = result["boxes"][keep] 498 | confs = result["confs"][keep] 499 | preds = result["preds"][keep] 500 | 501 | target_boxes = None if not target else target["boxes"] 502 | target_labels = None if not target else target["box_labels"] 503 | 504 | if target_boxes is not None: 505 | target_boxes = box_cxcywh_to_xyxy(target_boxes) 506 | 507 | # Converts from [0, 1] to [0, height] coordinates. 508 | img_h, img_w = orig_size 509 | dims = [img_w, img_h, img_w, img_h] 510 | scale_fct = torch.tensor(dims).type_as(target_boxes) 511 | scale_fct = scale_fct.unsqueeze(0).repeat(len(target_boxes), 1) 512 | target_boxes = target_boxes * scale_fct 513 | 514 | # Plots prediction and ground-truth boxes and saves image. 515 | save_infer_img( 516 | img, 517 | self.hparams.imgs_dir, 518 | name, 519 | self.class_names, 520 | boxes, 521 | confs, 522 | preds, 523 | self.hparams.offset, 524 | target_boxes=target_boxes, 525 | target_labels=target_labels, 526 | ) 527 | 528 | def predict_step(self, batch, batch_idx, dataloader_idx=None): 529 | """Performs inference without targets and saves images.""" 530 | 531 | imgs = batch[0] 532 | names = batch[1] 533 | orig_sizes = batch[2] 534 | orig_sizes = [torch.tensor(x) for x in orig_sizes] 535 | orig_sizes = torch.stack(orig_sizes).type_as(imgs.tensors) 536 | 537 | return self.predict_helper(imgs, names, orig_sizes) 538 | 539 | @staticmethod 540 | def add_model_specific_args(parent_parser): 541 | """Adds model configuration arguments to parser.""" 542 | 543 | return add_model_args(parent_parser) 544 | 545 | -------------------------------------------------------------------------------- /src/model/ops/src/cuda/ms_deform_im2col_cuda.cuh: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************** 7 | * Modified from DCN (https://github.com/msracver/Deformable-ConvNets) 8 | * Copyright (c) 2018 Microsoft 9 | ************************************************************************** 10 | */ 11 | 12 | #include 13 | #include 14 | #include 15 | 16 | #include 17 | #include 18 | 19 | #include 20 | 21 | #define CUDA_KERNEL_LOOP(i, n) \ 22 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 23 | i < (n); \ 24 | i += blockDim.x * gridDim.x) 25 | 26 | const int CUDA_NUM_THREADS = 1024; 27 | inline int GET_BLOCKS(const int N, const int num_threads) 28 | { 29 | return (N + num_threads - 1) / num_threads; 30 | } 31 | 32 | 33 | template 34 | __device__ scalar_t ms_deform_attn_im2col_bilinear(const scalar_t* &bottom_data, 35 | const int &height, const int &width, const int &nheads, const int &channels, 36 | const scalar_t &h, const scalar_t &w, const int &m, const int &c) 37 | { 38 | const int h_low = floor(h); 39 | const int w_low = floor(w); 40 | const int h_high = h_low + 1; 41 | const int w_high = w_low + 1; 42 | 43 | const scalar_t lh = h - h_low; 44 | const scalar_t lw = w - w_low; 45 | const scalar_t hh = 1 - lh, hw = 1 - lw; 46 | 47 | const int w_stride = nheads * channels; 48 | const int h_stride = width * w_stride; 49 | const int h_low_ptr_offset = h_low * h_stride; 50 | const int h_high_ptr_offset = h_low_ptr_offset + h_stride; 51 | const int w_low_ptr_offset = w_low * w_stride; 52 | const int w_high_ptr_offset = w_low_ptr_offset + w_stride; 53 | const int base_ptr = m * channels + c; 54 | 55 | scalar_t v1 = 0; 56 | if (h_low >= 0 && w_low >= 0) 57 | { 58 | const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; 59 | v1 = bottom_data[ptr1]; 60 | } 61 | scalar_t v2 = 0; 62 | if (h_low >= 0 && w_high <= width - 1) 63 | { 64 | const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; 65 | v2 = bottom_data[ptr2]; 66 | } 67 | scalar_t v3 = 0; 68 | if (h_high <= height - 1 && w_low >= 0) 69 | { 70 | const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; 71 | v3 = bottom_data[ptr3]; 72 | } 73 | scalar_t v4 = 0; 74 | if (h_high <= height - 1 && w_high <= width - 1) 75 | { 76 | const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; 77 | v4 = bottom_data[ptr4]; 78 | } 79 | 80 | const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 81 | 82 | const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 83 | return val; 84 | } 85 | 86 | 87 | template 88 | __device__ void ms_deform_attn_col2im_bilinear(const scalar_t* &bottom_data, 89 | const int &height, const int &width, const int &nheads, const int &channels, 90 | const scalar_t &h, const scalar_t &w, const int &m, const int &c, 91 | const scalar_t &top_grad, 92 | const scalar_t &attn_weight, 93 | scalar_t* &grad_value, 94 | scalar_t* grad_sampling_loc, 95 | scalar_t* grad_attn_weight) 96 | { 97 | const int h_low = floor(h); 98 | const int w_low = floor(w); 99 | const int h_high = h_low + 1; 100 | const int w_high = w_low + 1; 101 | 102 | const scalar_t lh = h - h_low; 103 | const scalar_t lw = w - w_low; 104 | const scalar_t hh = 1 - lh, hw = 1 - lw; 105 | 106 | const int w_stride = nheads * channels; 107 | const int h_stride = width * w_stride; 108 | const int 
h_low_ptr_offset = h_low * h_stride; 109 | const int h_high_ptr_offset = h_low_ptr_offset + h_stride; 110 | const int w_low_ptr_offset = w_low * w_stride; 111 | const int w_high_ptr_offset = w_low_ptr_offset + w_stride; 112 | const int base_ptr = m * channels + c; 113 | 114 | const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 115 | const scalar_t top_grad_value = top_grad * attn_weight; 116 | scalar_t grad_h_weight = 0, grad_w_weight = 0; 117 | 118 | scalar_t v1 = 0; 119 | if (h_low >= 0 && w_low >= 0) 120 | { 121 | const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; 122 | v1 = bottom_data[ptr1]; 123 | grad_h_weight -= hw * v1; 124 | grad_w_weight -= hh * v1; 125 | atomicAdd(grad_value+ptr1, w1*top_grad_value); 126 | } 127 | scalar_t v2 = 0; 128 | if (h_low >= 0 && w_high <= width - 1) 129 | { 130 | const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; 131 | v2 = bottom_data[ptr2]; 132 | grad_h_weight -= lw * v2; 133 | grad_w_weight += hh * v2; 134 | atomicAdd(grad_value+ptr2, w2*top_grad_value); 135 | } 136 | scalar_t v3 = 0; 137 | if (h_high <= height - 1 && w_low >= 0) 138 | { 139 | const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; 140 | v3 = bottom_data[ptr3]; 141 | grad_h_weight += hw * v3; 142 | grad_w_weight -= lh * v3; 143 | atomicAdd(grad_value+ptr3, w3*top_grad_value); 144 | } 145 | scalar_t v4 = 0; 146 | if (h_high <= height - 1 && w_high <= width - 1) 147 | { 148 | const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; 149 | v4 = bottom_data[ptr4]; 150 | grad_h_weight += lw * v4; 151 | grad_w_weight += lh * v4; 152 | atomicAdd(grad_value+ptr4, w4*top_grad_value); 153 | } 154 | 155 | const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 156 | *grad_attn_weight = top_grad * val; 157 | *grad_sampling_loc = width * grad_w_weight * top_grad_value; 158 | *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value; 159 | } 160 | 161 | 162 | template 163 | __device__ void ms_deform_attn_col2im_bilinear_gm(const scalar_t* &bottom_data, 164 | const int &height, const int &width, const int &nheads, const int &channels, 165 | const scalar_t &h, const scalar_t &w, const int &m, const int &c, 166 | const scalar_t &top_grad, 167 | const scalar_t &attn_weight, 168 | scalar_t* &grad_value, 169 | scalar_t* grad_sampling_loc, 170 | scalar_t* grad_attn_weight) 171 | { 172 | const int h_low = floor(h); 173 | const int w_low = floor(w); 174 | const int h_high = h_low + 1; 175 | const int w_high = w_low + 1; 176 | 177 | const scalar_t lh = h - h_low; 178 | const scalar_t lw = w - w_low; 179 | const scalar_t hh = 1 - lh, hw = 1 - lw; 180 | 181 | const int w_stride = nheads * channels; 182 | const int h_stride = width * w_stride; 183 | const int h_low_ptr_offset = h_low * h_stride; 184 | const int h_high_ptr_offset = h_low_ptr_offset + h_stride; 185 | const int w_low_ptr_offset = w_low * w_stride; 186 | const int w_high_ptr_offset = w_low_ptr_offset + w_stride; 187 | const int base_ptr = m * channels + c; 188 | 189 | const scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; 190 | const scalar_t top_grad_value = top_grad * attn_weight; 191 | scalar_t grad_h_weight = 0, grad_w_weight = 0; 192 | 193 | scalar_t v1 = 0; 194 | if (h_low >= 0 && w_low >= 0) 195 | { 196 | const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; 197 | v1 = bottom_data[ptr1]; 198 | grad_h_weight -= hw * v1; 199 | grad_w_weight -= hh * v1; 200 | atomicAdd(grad_value+ptr1, w1*top_grad_value); 201 | } 202 | scalar_t v2 = 0; 
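  // Second bilinear corner (h_low, w_high): accumulate its contribution to
  // the interpolated value and scatter the weighted gradient via atomicAdd.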
203 | if (h_low >= 0 && w_high <= width - 1) 204 | { 205 | const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; 206 | v2 = bottom_data[ptr2]; 207 | grad_h_weight -= lw * v2; 208 | grad_w_weight += hh * v2; 209 | atomicAdd(grad_value+ptr2, w2*top_grad_value); 210 | } 211 | scalar_t v3 = 0; 212 | if (h_high <= height - 1 && w_low >= 0) 213 | { 214 | const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; 215 | v3 = bottom_data[ptr3]; 216 | grad_h_weight += hw * v3; 217 | grad_w_weight -= lh * v3; 218 | atomicAdd(grad_value+ptr3, w3*top_grad_value); 219 | } 220 | scalar_t v4 = 0; 221 | if (h_high <= height - 1 && w_high <= width - 1) 222 | { 223 | const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; 224 | v4 = bottom_data[ptr4]; 225 | grad_h_weight += lw * v4; 226 | grad_w_weight += lh * v4; 227 | atomicAdd(grad_value+ptr4, w4*top_grad_value); 228 | } 229 | 230 | const scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); 231 | atomicAdd(grad_attn_weight, top_grad * val); 232 | atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value); 233 | atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value); 234 | } 235 | 236 | 237 | template 238 | __global__ void ms_deformable_im2col_gpu_kernel(const int n, 239 | const scalar_t *data_value, 240 | const int64_t *data_spatial_shapes, 241 | const int64_t *data_level_start_index, 242 | const scalar_t *data_sampling_loc, 243 | const scalar_t *data_attn_weight, 244 | const int batch_size, 245 | const int spatial_size, 246 | const int num_heads, 247 | const int channels, 248 | const int num_levels, 249 | const int num_query, 250 | const int num_point, 251 | scalar_t *data_col) 252 | { 253 | CUDA_KERNEL_LOOP(index, n) 254 | { 255 | int _temp = index; 256 | const int c_col = _temp % channels; 257 | _temp /= channels; 258 | const int sampling_index = _temp; 259 | const int m_col = _temp % num_heads; 260 | _temp /= num_heads; 261 | const int q_col = _temp % num_query; 262 | _temp /= num_query; 263 | const int b_col = _temp; 264 | 265 | scalar_t *data_col_ptr = data_col + index; 266 | int data_weight_ptr = sampling_index * num_levels * num_point; 267 | int data_loc_w_ptr = data_weight_ptr << 1; 268 | const int qid_stride = num_heads * channels; 269 | const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; 270 | scalar_t col = 0; 271 | 272 | for (int l_col=0; l_col < num_levels; ++l_col) 273 | { 274 | const int level_start_id = data_level_start_index[l_col]; 275 | const int spatial_h_ptr = l_col << 1; 276 | const int spatial_h = data_spatial_shapes[spatial_h_ptr]; 277 | const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; 278 | const scalar_t *data_value_ptr = data_value + (data_value_ptr_init_offset + level_start_id * qid_stride); 279 | for (int p_col=0; p_col < num_point; ++p_col) 280 | { 281 | const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; 282 | const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; 283 | const scalar_t weight = data_attn_weight[data_weight_ptr]; 284 | 285 | const scalar_t h_im = loc_h * spatial_h - 0.5; 286 | const scalar_t w_im = loc_w * spatial_w - 0.5; 287 | 288 | if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) 289 | { 290 | col += ms_deform_attn_im2col_bilinear(data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col) * weight; 291 | } 292 | 293 | data_weight_ptr += 1; 294 | data_loc_w_ptr += 2; 295 | } 296 | } 297 | *data_col_ptr = col; 298 | } 299 | } 300 | 301 | template 302 | 
__global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1(const int n, 303 | const scalar_t *grad_col, 304 | const scalar_t *data_value, 305 | const int64_t *data_spatial_shapes, 306 | const int64_t *data_level_start_index, 307 | const scalar_t *data_sampling_loc, 308 | const scalar_t *data_attn_weight, 309 | const int batch_size, 310 | const int spatial_size, 311 | const int num_heads, 312 | const int channels, 313 | const int num_levels, 314 | const int num_query, 315 | const int num_point, 316 | scalar_t *grad_value, 317 | scalar_t *grad_sampling_loc, 318 | scalar_t *grad_attn_weight) 319 | { 320 | CUDA_KERNEL_LOOP(index, n) 321 | { 322 | __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; 323 | __shared__ scalar_t cache_grad_attn_weight[blockSize]; 324 | unsigned int tid = threadIdx.x; 325 | int _temp = index; 326 | const int c_col = _temp % channels; 327 | _temp /= channels; 328 | const int sampling_index = _temp; 329 | const int m_col = _temp % num_heads; 330 | _temp /= num_heads; 331 | const int q_col = _temp % num_query; 332 | _temp /= num_query; 333 | const int b_col = _temp; 334 | 335 | const scalar_t top_grad = grad_col[index]; 336 | 337 | int data_weight_ptr = sampling_index * num_levels * num_point; 338 | int data_loc_w_ptr = data_weight_ptr << 1; 339 | const int grad_sampling_ptr = data_weight_ptr; 340 | grad_sampling_loc += grad_sampling_ptr << 1; 341 | grad_attn_weight += grad_sampling_ptr; 342 | const int grad_weight_stride = 1; 343 | const int grad_loc_stride = 2; 344 | const int qid_stride = num_heads * channels; 345 | const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; 346 | 347 | for (int l_col=0; l_col < num_levels; ++l_col) 348 | { 349 | const int level_start_id = data_level_start_index[l_col]; 350 | const int spatial_h_ptr = l_col << 1; 351 | const int spatial_h = data_spatial_shapes[spatial_h_ptr]; 352 | const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; 353 | const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; 354 | const scalar_t *data_value_ptr = data_value + value_ptr_offset; 355 | scalar_t *grad_value_ptr = grad_value + value_ptr_offset; 356 | 357 | for (int p_col=0; p_col < num_point; ++p_col) 358 | { 359 | const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; 360 | const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; 361 | const scalar_t weight = data_attn_weight[data_weight_ptr]; 362 | 363 | const scalar_t h_im = loc_h * spatial_h - 0.5; 364 | const scalar_t w_im = loc_w * spatial_w - 0.5; 365 | *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; 366 | *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; 367 | *(cache_grad_attn_weight+threadIdx.x)=0; 368 | if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) 369 | { 370 | ms_deform_attn_col2im_bilinear( 371 | data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, 372 | top_grad, weight, grad_value_ptr, 373 | cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); 374 | } 375 | 376 | __syncthreads(); 377 | if (tid == 0) 378 | { 379 | scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; 380 | int sid=2; 381 | for (unsigned int tid = 1; tid < blockSize; ++tid) 382 | { 383 | _grad_w += cache_grad_sampling_loc[sid]; 384 | _grad_h += cache_grad_sampling_loc[sid + 1]; 385 | _grad_a += cache_grad_attn_weight[tid]; 386 | sid += 2; 387 | } 388 | 389 | 390 | *grad_sampling_loc 
= _grad_w; 391 | *(grad_sampling_loc + 1) = _grad_h; 392 | *grad_attn_weight = _grad_a; 393 | } 394 | __syncthreads(); 395 | 396 | data_weight_ptr += 1; 397 | data_loc_w_ptr += 2; 398 | grad_attn_weight += grad_weight_stride; 399 | grad_sampling_loc += grad_loc_stride; 400 | } 401 | } 402 | } 403 | } 404 | 405 | 406 | template 407 | __global__ void ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2(const int n, 408 | const scalar_t *grad_col, 409 | const scalar_t *data_value, 410 | const int64_t *data_spatial_shapes, 411 | const int64_t *data_level_start_index, 412 | const scalar_t *data_sampling_loc, 413 | const scalar_t *data_attn_weight, 414 | const int batch_size, 415 | const int spatial_size, 416 | const int num_heads, 417 | const int channels, 418 | const int num_levels, 419 | const int num_query, 420 | const int num_point, 421 | scalar_t *grad_value, 422 | scalar_t *grad_sampling_loc, 423 | scalar_t *grad_attn_weight) 424 | { 425 | CUDA_KERNEL_LOOP(index, n) 426 | { 427 | __shared__ scalar_t cache_grad_sampling_loc[blockSize * 2]; 428 | __shared__ scalar_t cache_grad_attn_weight[blockSize]; 429 | unsigned int tid = threadIdx.x; 430 | int _temp = index; 431 | const int c_col = _temp % channels; 432 | _temp /= channels; 433 | const int sampling_index = _temp; 434 | const int m_col = _temp % num_heads; 435 | _temp /= num_heads; 436 | const int q_col = _temp % num_query; 437 | _temp /= num_query; 438 | const int b_col = _temp; 439 | 440 | const scalar_t top_grad = grad_col[index]; 441 | 442 | int data_weight_ptr = sampling_index * num_levels * num_point; 443 | int data_loc_w_ptr = data_weight_ptr << 1; 444 | const int grad_sampling_ptr = data_weight_ptr; 445 | grad_sampling_loc += grad_sampling_ptr << 1; 446 | grad_attn_weight += grad_sampling_ptr; 447 | const int grad_weight_stride = 1; 448 | const int grad_loc_stride = 2; 449 | const int qid_stride = num_heads * channels; 450 | const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; 451 | 452 | for (int l_col=0; l_col < num_levels; ++l_col) 453 | { 454 | const int level_start_id = data_level_start_index[l_col]; 455 | const int spatial_h_ptr = l_col << 1; 456 | const int spatial_h = data_spatial_shapes[spatial_h_ptr]; 457 | const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; 458 | const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; 459 | const scalar_t *data_value_ptr = data_value + value_ptr_offset; 460 | scalar_t *grad_value_ptr = grad_value + value_ptr_offset; 461 | 462 | for (int p_col=0; p_col < num_point; ++p_col) 463 | { 464 | const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; 465 | const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; 466 | const scalar_t weight = data_attn_weight[data_weight_ptr]; 467 | 468 | const scalar_t h_im = loc_h * spatial_h - 0.5; 469 | const scalar_t w_im = loc_w * spatial_w - 0.5; 470 | *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; 471 | *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; 472 | *(cache_grad_attn_weight+threadIdx.x)=0; 473 | if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) 474 | { 475 | ms_deform_attn_col2im_bilinear( 476 | data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, 477 | top_grad, weight, grad_value_ptr, 478 | cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); 479 | } 480 | 481 | __syncthreads(); 482 | 483 | for (unsigned int s=blockSize/2; s>0; s>>=1) 484 | { 485 | if (tid < s) { 486 | const 
unsigned int xid1 = tid << 1; 487 | const unsigned int xid2 = (tid + s) << 1; 488 | cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; 489 | cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; 490 | cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; 491 | } 492 | __syncthreads(); 493 | } 494 | 495 | if (tid == 0) 496 | { 497 | *grad_sampling_loc = cache_grad_sampling_loc[0]; 498 | *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; 499 | *grad_attn_weight = cache_grad_attn_weight[0]; 500 | } 501 | __syncthreads(); 502 | 503 | data_weight_ptr += 1; 504 | data_loc_w_ptr += 2; 505 | grad_attn_weight += grad_weight_stride; 506 | grad_sampling_loc += grad_loc_stride; 507 | } 508 | } 509 | } 510 | } 511 | 512 | 513 | template 514 | __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v1(const int n, 515 | const scalar_t *grad_col, 516 | const scalar_t *data_value, 517 | const int64_t *data_spatial_shapes, 518 | const int64_t *data_level_start_index, 519 | const scalar_t *data_sampling_loc, 520 | const scalar_t *data_attn_weight, 521 | const int batch_size, 522 | const int spatial_size, 523 | const int num_heads, 524 | const int channels, 525 | const int num_levels, 526 | const int num_query, 527 | const int num_point, 528 | scalar_t *grad_value, 529 | scalar_t *grad_sampling_loc, 530 | scalar_t *grad_attn_weight) 531 | { 532 | CUDA_KERNEL_LOOP(index, n) 533 | { 534 | extern __shared__ int _s[]; 535 | scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; 536 | scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; 537 | unsigned int tid = threadIdx.x; 538 | int _temp = index; 539 | const int c_col = _temp % channels; 540 | _temp /= channels; 541 | const int sampling_index = _temp; 542 | const int m_col = _temp % num_heads; 543 | _temp /= num_heads; 544 | const int q_col = _temp % num_query; 545 | _temp /= num_query; 546 | const int b_col = _temp; 547 | 548 | const scalar_t top_grad = grad_col[index]; 549 | 550 | int data_weight_ptr = sampling_index * num_levels * num_point; 551 | int data_loc_w_ptr = data_weight_ptr << 1; 552 | const int grad_sampling_ptr = data_weight_ptr; 553 | grad_sampling_loc += grad_sampling_ptr << 1; 554 | grad_attn_weight += grad_sampling_ptr; 555 | const int grad_weight_stride = 1; 556 | const int grad_loc_stride = 2; 557 | const int qid_stride = num_heads * channels; 558 | const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; 559 | 560 | for (int l_col=0; l_col < num_levels; ++l_col) 561 | { 562 | const int level_start_id = data_level_start_index[l_col]; 563 | const int spatial_h_ptr = l_col << 1; 564 | const int spatial_h = data_spatial_shapes[spatial_h_ptr]; 565 | const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; 566 | const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; 567 | const scalar_t *data_value_ptr = data_value + value_ptr_offset; 568 | scalar_t *grad_value_ptr = grad_value + value_ptr_offset; 569 | 570 | for (int p_col=0; p_col < num_point; ++p_col) 571 | { 572 | const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; 573 | const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; 574 | const scalar_t weight = data_attn_weight[data_weight_ptr]; 575 | 576 | const scalar_t h_im = loc_h * spatial_h - 0.5; 577 | const scalar_t w_im = loc_w * spatial_w - 0.5; 578 | *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; 579 | *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; 580 | 
*(cache_grad_attn_weight+threadIdx.x)=0; 581 | if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) 582 | { 583 | ms_deform_attn_col2im_bilinear( 584 | data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, 585 | top_grad, weight, grad_value_ptr, 586 | cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); 587 | } 588 | 589 | __syncthreads(); 590 | if (tid == 0) 591 | { 592 | scalar_t _grad_w=cache_grad_sampling_loc[0], _grad_h=cache_grad_sampling_loc[1], _grad_a=cache_grad_attn_weight[0]; 593 | int sid=2; 594 | for (unsigned int tid = 1; tid < blockDim.x; ++tid) 595 | { 596 | _grad_w += cache_grad_sampling_loc[sid]; 597 | _grad_h += cache_grad_sampling_loc[sid + 1]; 598 | _grad_a += cache_grad_attn_weight[tid]; 599 | sid += 2; 600 | } 601 | 602 | 603 | *grad_sampling_loc = _grad_w; 604 | *(grad_sampling_loc + 1) = _grad_h; 605 | *grad_attn_weight = _grad_a; 606 | } 607 | __syncthreads(); 608 | 609 | data_weight_ptr += 1; 610 | data_loc_w_ptr += 2; 611 | grad_attn_weight += grad_weight_stride; 612 | grad_sampling_loc += grad_loc_stride; 613 | } 614 | } 615 | } 616 | } 617 | 618 | template 619 | __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2(const int n, 620 | const scalar_t *grad_col, 621 | const scalar_t *data_value, 622 | const int64_t *data_spatial_shapes, 623 | const int64_t *data_level_start_index, 624 | const scalar_t *data_sampling_loc, 625 | const scalar_t *data_attn_weight, 626 | const int batch_size, 627 | const int spatial_size, 628 | const int num_heads, 629 | const int channels, 630 | const int num_levels, 631 | const int num_query, 632 | const int num_point, 633 | scalar_t *grad_value, 634 | scalar_t *grad_sampling_loc, 635 | scalar_t *grad_attn_weight) 636 | { 637 | CUDA_KERNEL_LOOP(index, n) 638 | { 639 | extern __shared__ int _s[]; 640 | scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; 641 | scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; 642 | unsigned int tid = threadIdx.x; 643 | int _temp = index; 644 | const int c_col = _temp % channels; 645 | _temp /= channels; 646 | const int sampling_index = _temp; 647 | const int m_col = _temp % num_heads; 648 | _temp /= num_heads; 649 | const int q_col = _temp % num_query; 650 | _temp /= num_query; 651 | const int b_col = _temp; 652 | 653 | const scalar_t top_grad = grad_col[index]; 654 | 655 | int data_weight_ptr = sampling_index * num_levels * num_point; 656 | int data_loc_w_ptr = data_weight_ptr << 1; 657 | const int grad_sampling_ptr = data_weight_ptr; 658 | grad_sampling_loc += grad_sampling_ptr << 1; 659 | grad_attn_weight += grad_sampling_ptr; 660 | const int grad_weight_stride = 1; 661 | const int grad_loc_stride = 2; 662 | const int qid_stride = num_heads * channels; 663 | const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; 664 | 665 | for (int l_col=0; l_col < num_levels; ++l_col) 666 | { 667 | const int level_start_id = data_level_start_index[l_col]; 668 | const int spatial_h_ptr = l_col << 1; 669 | const int spatial_h = data_spatial_shapes[spatial_h_ptr]; 670 | const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; 671 | const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; 672 | const scalar_t *data_value_ptr = data_value + value_ptr_offset; 673 | scalar_t *grad_value_ptr = grad_value + value_ptr_offset; 674 | 675 | for (int p_col=0; p_col < num_point; ++p_col) 676 | { 677 | const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; 678 | 
const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; 679 | const scalar_t weight = data_attn_weight[data_weight_ptr]; 680 | 681 | const scalar_t h_im = loc_h * spatial_h - 0.5; 682 | const scalar_t w_im = loc_w * spatial_w - 0.5; 683 | *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; 684 | *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; 685 | *(cache_grad_attn_weight+threadIdx.x)=0; 686 | if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) 687 | { 688 | ms_deform_attn_col2im_bilinear( 689 | data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, 690 | top_grad, weight, grad_value_ptr, 691 | cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); 692 | } 693 | 694 | __syncthreads(); 695 | 696 | for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) 697 | { 698 | if (tid < s) { 699 | const unsigned int xid1 = tid << 1; 700 | const unsigned int xid2 = (tid + s) << 1; 701 | cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; 702 | cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; 703 | cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; 704 | if (tid + (s << 1) < spre) 705 | { 706 | cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; 707 | cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; 708 | cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; 709 | } 710 | } 711 | __syncthreads(); 712 | } 713 | 714 | if (tid == 0) 715 | { 716 | *grad_sampling_loc = cache_grad_sampling_loc[0]; 717 | *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1]; 718 | *grad_attn_weight = cache_grad_attn_weight[0]; 719 | } 720 | __syncthreads(); 721 | 722 | data_weight_ptr += 1; 723 | data_loc_w_ptr += 2; 724 | grad_attn_weight += grad_weight_stride; 725 | grad_sampling_loc += grad_loc_stride; 726 | } 727 | } 728 | } 729 | } 730 | 731 | template 732 | __global__ void ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks(const int n, 733 | const scalar_t *grad_col, 734 | const scalar_t *data_value, 735 | const int64_t *data_spatial_shapes, 736 | const int64_t *data_level_start_index, 737 | const scalar_t *data_sampling_loc, 738 | const scalar_t *data_attn_weight, 739 | const int batch_size, 740 | const int spatial_size, 741 | const int num_heads, 742 | const int channels, 743 | const int num_levels, 744 | const int num_query, 745 | const int num_point, 746 | scalar_t *grad_value, 747 | scalar_t *grad_sampling_loc, 748 | scalar_t *grad_attn_weight) 749 | { 750 | CUDA_KERNEL_LOOP(index, n) 751 | { 752 | extern __shared__ int _s[]; 753 | scalar_t* cache_grad_sampling_loc = (scalar_t*)_s; 754 | scalar_t* cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x; 755 | unsigned int tid = threadIdx.x; 756 | int _temp = index; 757 | const int c_col = _temp % channels; 758 | _temp /= channels; 759 | const int sampling_index = _temp; 760 | const int m_col = _temp % num_heads; 761 | _temp /= num_heads; 762 | const int q_col = _temp % num_query; 763 | _temp /= num_query; 764 | const int b_col = _temp; 765 | 766 | const scalar_t top_grad = grad_col[index]; 767 | 768 | int data_weight_ptr = sampling_index * num_levels * num_point; 769 | int data_loc_w_ptr = data_weight_ptr << 1; 770 | const int grad_sampling_ptr = data_weight_ptr; 771 | grad_sampling_loc += grad_sampling_ptr << 1; 772 | grad_attn_weight += grad_sampling_ptr; 773 | const int grad_weight_stride = 1; 774 | const int 
grad_loc_stride = 2; 775 | const int qid_stride = num_heads * channels; 776 | const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride; 777 | 778 | for (int l_col=0; l_col < num_levels; ++l_col) 779 | { 780 | const int level_start_id = data_level_start_index[l_col]; 781 | const int spatial_h_ptr = l_col << 1; 782 | const int spatial_h = data_spatial_shapes[spatial_h_ptr]; 783 | const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1]; 784 | const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; 785 | const scalar_t *data_value_ptr = data_value + value_ptr_offset; 786 | scalar_t *grad_value_ptr = grad_value + value_ptr_offset; 787 | 788 | for (int p_col=0; p_col < num_point; ++p_col) 789 | { 790 | const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr]; 791 | const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1]; 792 | const scalar_t weight = data_attn_weight[data_weight_ptr]; 793 | 794 | const scalar_t h_im = loc_h * spatial_h - 0.5; 795 | const scalar_t w_im = loc_w * spatial_w - 0.5; 796 | *(cache_grad_sampling_loc+(threadIdx.x << 1)) = 0; 797 | *(cache_grad_sampling_loc+((threadIdx.x << 1) + 1)) = 0; 798 | *(cache_grad_attn_weight+threadIdx.x)=0; 799 | if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) 800 | { 801 | ms_deform_attn_col2im_bilinear( 802 | data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col, 803 | top_grad, weight, grad_value_ptr, 804 | cache_grad_sampling_loc+(threadIdx.x << 1), cache_grad_attn_weight+threadIdx.x); 805 | } 806 | 807 | __syncthreads(); 808 | 809 | for (unsigned int s=blockDim.x/2, spre=blockDim.x; s>0; s>>=1, spre>>=1) 810 | { 811 | if (tid < s) { 812 | const unsigned int xid1 = tid << 1; 813 | const unsigned int xid2 = (tid + s) << 1; 814 | cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s]; 815 | cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2]; 816 | cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1]; 817 | if (tid + (s << 1) < spre) 818 | { 819 | cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + (s << 1)]; 820 | cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2 + (s << 1)]; 821 | cache_grad_sampling_loc[xid1 + 1] += cache_grad_sampling_loc[xid2 + 1 + (s << 1)]; 822 | } 823 | } 824 | __syncthreads(); 825 | } 826 | 827 | if (tid == 0) 828 | { 829 | atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]); 830 | atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]); 831 | atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]); 832 | } 833 | __syncthreads(); 834 | 835 | data_weight_ptr += 1; 836 | data_loc_w_ptr += 2; 837 | grad_attn_weight += grad_weight_stride; 838 | grad_sampling_loc += grad_loc_stride; 839 | } 840 | } 841 | } 842 | } 843 | 844 | 845 | template 846 | __global__ void ms_deformable_col2im_gpu_kernel_gm(const int n, 847 | const scalar_t *grad_col, 848 | const scalar_t *data_value, 849 | const int64_t *data_spatial_shapes, 850 | const int64_t *data_level_start_index, 851 | const scalar_t *data_sampling_loc, 852 | const scalar_t *data_attn_weight, 853 | const int batch_size, 854 | const int spatial_size, 855 | const int num_heads, 856 | const int channels, 857 | const int num_levels, 858 | const int num_query, 859 | const int num_point, 860 | scalar_t *grad_value, 861 | scalar_t *grad_sampling_loc, 862 | scalar_t *grad_attn_weight) 863 | { 864 | CUDA_KERNEL_LOOP(index, n) 865 | { 866 | int _temp = index; 867 | const int c_col = _temp % channels; 868 | _temp /= 
channels;
869 |     const int sampling_index = _temp;
870 |     const int m_col = _temp % num_heads;
871 |     _temp /= num_heads;
872 |     const int q_col = _temp % num_query;
873 |     _temp /= num_query;
874 |     const int b_col = _temp;
875 |
876 |     const scalar_t top_grad = grad_col[index];
877 |
878 |     int data_weight_ptr = sampling_index * num_levels * num_point;
879 |     int data_loc_w_ptr = data_weight_ptr << 1;
880 |     const int grad_sampling_ptr = data_weight_ptr;
881 |     grad_sampling_loc += grad_sampling_ptr << 1;
882 |     grad_attn_weight += grad_sampling_ptr;
883 |     const int grad_weight_stride = 1;
884 |     const int grad_loc_stride = 2;
885 |     const int qid_stride = num_heads * channels;
886 |     const int data_value_ptr_init_offset = b_col * spatial_size * qid_stride;
887 |
888 |     for (int l_col=0; l_col < num_levels; ++l_col)
889 |     {
890 |       const int level_start_id = data_level_start_index[l_col];
891 |       const int spatial_h_ptr = l_col << 1;
892 |       const int spatial_h = data_spatial_shapes[spatial_h_ptr];
893 |       const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
894 |       const int value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride;
895 |       const scalar_t *data_value_ptr = data_value + value_ptr_offset;
896 |       scalar_t *grad_value_ptr = grad_value + value_ptr_offset;
897 |
898 |       for (int p_col=0; p_col < num_point; ++p_col)
899 |       {
900 |         const scalar_t loc_w = data_sampling_loc[data_loc_w_ptr];
901 |         const scalar_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
902 |         const scalar_t weight = data_attn_weight[data_weight_ptr];
903 |
904 |         const scalar_t h_im = loc_h * spatial_h - 0.5;
905 |         const scalar_t w_im = loc_w * spatial_w - 0.5;
906 |         if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w)
907 |         {
908 |           ms_deform_attn_col2im_bilinear_gm(
909 |             data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im, w_im, m_col, c_col,
910 |             top_grad, weight, grad_value_ptr,
911 |             grad_sampling_loc, grad_attn_weight);
912 |         }
913 |         data_weight_ptr += 1;
914 |         data_loc_w_ptr += 2;
915 |         grad_attn_weight += grad_weight_stride;
916 |         grad_sampling_loc += grad_loc_stride;
917 |       }
918 |     }
919 |   }
920 | }
921 |
922 |
923 | template <typename scalar_t>
924 | void ms_deformable_im2col_cuda(cudaStream_t stream,
925 |                               const scalar_t* data_value,
926 |                               const int64_t* data_spatial_shapes,
927 |                               const int64_t* data_level_start_index,
928 |                               const scalar_t* data_sampling_loc,
929 |                               const scalar_t* data_attn_weight,
930 |                               const int batch_size,
931 |                               const int spatial_size,
932 |                               const int num_heads,
933 |                               const int channels,
934 |                               const int num_levels,
935 |                               const int num_query,
936 |                               const int num_point,
937 |                               scalar_t* data_col)
938 | {
939 |   const int num_kernels = batch_size * num_query * num_heads * channels;
940 |   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
941 |   const int num_threads = CUDA_NUM_THREADS;
942 |   ms_deformable_im2col_gpu_kernel<scalar_t>
943 |       <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
944 |           0, stream>>>(
945 |       num_kernels, data_value, data_spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight,
946 |       batch_size, spatial_size, num_heads, channels, num_levels, num_query, num_point, data_col);
947 |
948 |   cudaError_t err = cudaGetLastError();
949 |   if (err != cudaSuccess)
950 |   {
951 |     printf("error in ms_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
952 |   }
953 |
954 | }
955 |
956 | template <typename scalar_t>
957 | void ms_deformable_col2im_cuda(cudaStream_t stream,
958 |                               const scalar_t* grad_col,
959 |                               const scalar_t* data_value,
960 |                               const int64_t * data_spatial_shapes,
961 |                               const int64_t * data_level_start_index,
962 |                               const scalar_t * data_sampling_loc,
963 |                               const scalar_t * data_attn_weight,
964 |                               const int batch_size,
965 |                               const int spatial_size,
966 |                               const int num_heads,
967 |                               const int channels,
968 |                               const int num_levels,
969 |                               const int num_query,
970 |                               const int num_point,
971 |                               scalar_t* grad_value,
972 |                               scalar_t* grad_sampling_loc,
973 |                               scalar_t* grad_attn_weight)
974 | {
975 |   const int num_threads = (channels > CUDA_NUM_THREADS)?CUDA_NUM_THREADS:channels;
976 |   const int num_kernels = batch_size * num_query * num_heads * channels;
977 |   const int num_actual_kernels = batch_size * num_query * num_heads * channels;
978 |   if (channels > 1024)
979 |   {
980 |     if ((channels & 1023) == 0)
981 |     {
982 |       ms_deformable_col2im_gpu_kernel_shm_reduce_v2_multi_blocks<scalar_t>
983 |           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
984 |               num_threads*3*sizeof(scalar_t), stream>>>(
985 |                 num_kernels,
986 |                 grad_col,
987 |                 data_value,
988 |                 data_spatial_shapes,
989 |                 data_level_start_index,
990 |                 data_sampling_loc,
991 |                 data_attn_weight,
992 |                 batch_size,
993 |                 spatial_size,
994 |                 num_heads,
995 |                 channels,
996 |                 num_levels,
997 |                 num_query,
998 |                 num_point,
999 |                 grad_value,
1000 |                 grad_sampling_loc,
1001 |                 grad_attn_weight);
1002 |     }
1003 |     else
1004 |     {
1005 |       ms_deformable_col2im_gpu_kernel_gm<scalar_t>
1006 |           <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1007 |               0, stream>>>(
1008 |                 num_kernels,
1009 |                 grad_col,
1010 |                 data_value,
1011 |                 data_spatial_shapes,
1012 |                 data_level_start_index,
1013 |                 data_sampling_loc,
1014 |                 data_attn_weight,
1015 |                 batch_size,
1016 |                 spatial_size,
1017 |                 num_heads,
1018 |                 channels,
1019 |                 num_levels,
1020 |                 num_query,
1021 |                 num_point,
1022 |                 grad_value,
1023 |                 grad_sampling_loc,
1024 |                 grad_attn_weight);
1025 |     }
1026 |   }
1027 |   else{
1028 |     switch(channels)
1029 |     {
1030 |       case 1:
1031 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 1>
1032 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1033 |                 0, stream>>>(
1034 |                   num_kernels,
1035 |                   grad_col,
1036 |                   data_value,
1037 |                   data_spatial_shapes,
1038 |                   data_level_start_index,
1039 |                   data_sampling_loc,
1040 |                   data_attn_weight,
1041 |                   batch_size,
1042 |                   spatial_size,
1043 |                   num_heads,
1044 |                   channels,
1045 |                   num_levels,
1046 |                   num_query,
1047 |                   num_point,
1048 |                   grad_value,
1049 |                   grad_sampling_loc,
1050 |                   grad_attn_weight);
1051 |         break;
1052 |       case 2:
1053 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 2>
1054 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1055 |                 0, stream>>>(
1056 |                   num_kernels,
1057 |                   grad_col,
1058 |                   data_value,
1059 |                   data_spatial_shapes,
1060 |                   data_level_start_index,
1061 |                   data_sampling_loc,
1062 |                   data_attn_weight,
1063 |                   batch_size,
1064 |                   spatial_size,
1065 |                   num_heads,
1066 |                   channels,
1067 |                   num_levels,
1068 |                   num_query,
1069 |                   num_point,
1070 |                   grad_value,
1071 |                   grad_sampling_loc,
1072 |                   grad_attn_weight);
1073 |         break;
1074 |       case 4:
1075 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 4>
1076 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1077 |                 0, stream>>>(
1078 |                   num_kernels,
1079 |                   grad_col,
1080 |                   data_value,
1081 |                   data_spatial_shapes,
1082 |                   data_level_start_index,
1083 |                   data_sampling_loc,
1084 |                   data_attn_weight,
1085 |                   batch_size,
1086 |                   spatial_size,
1087 |                   num_heads,
1088 |                   channels,
1089 |                   num_levels,
1090 |                   num_query,
1091 |                   num_point,
1092 |                   grad_value,
1093 |                   grad_sampling_loc,
1094 |                   grad_attn_weight);
1095 |         break;
1096 |       case 8:
1097 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 8>
1098 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1099 |                 0, stream>>>(
1100 |                   num_kernels,
1101 |                   grad_col,
1102 |                   data_value,
1103 |                   data_spatial_shapes,
1104 |                   data_level_start_index,
1105 |                   data_sampling_loc,
1106 |                   data_attn_weight,
1107 |                   batch_size,
1108 |                   spatial_size,
1109 |                   num_heads,
1110 |                   channels,
1111 |                   num_levels,
1112 |                   num_query,
1113 |                   num_point,
1114 |                   grad_value,
1115 |                   grad_sampling_loc,
1116 |                   grad_attn_weight);
1117 |         break;
1118 |       case 16:
1119 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 16>
1120 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1121 |                 0, stream>>>(
1122 |                   num_kernels,
1123 |                   grad_col,
1124 |                   data_value,
1125 |                   data_spatial_shapes,
1126 |                   data_level_start_index,
1127 |                   data_sampling_loc,
1128 |                   data_attn_weight,
1129 |                   batch_size,
1130 |                   spatial_size,
1131 |                   num_heads,
1132 |                   channels,
1133 |                   num_levels,
1134 |                   num_query,
1135 |                   num_point,
1136 |                   grad_value,
1137 |                   grad_sampling_loc,
1138 |                   grad_attn_weight);
1139 |         break;
1140 |       case 32:
1141 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v1<scalar_t, 32>
1142 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1143 |                 0, stream>>>(
1144 |                   num_kernels,
1145 |                   grad_col,
1146 |                   data_value,
1147 |                   data_spatial_shapes,
1148 |                   data_level_start_index,
1149 |                   data_sampling_loc,
1150 |                   data_attn_weight,
1151 |                   batch_size,
1152 |                   spatial_size,
1153 |                   num_heads,
1154 |                   channels,
1155 |                   num_levels,
1156 |                   num_query,
1157 |                   num_point,
1158 |                   grad_value,
1159 |                   grad_sampling_loc,
1160 |                   grad_attn_weight);
1161 |         break;
1162 |       case 64:
1163 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 64>
1164 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1165 |                 0, stream>>>(
1166 |                   num_kernels,
1167 |                   grad_col,
1168 |                   data_value,
1169 |                   data_spatial_shapes,
1170 |                   data_level_start_index,
1171 |                   data_sampling_loc,
1172 |                   data_attn_weight,
1173 |                   batch_size,
1174 |                   spatial_size,
1175 |                   num_heads,
1176 |                   channels,
1177 |                   num_levels,
1178 |                   num_query,
1179 |                   num_point,
1180 |                   grad_value,
1181 |                   grad_sampling_loc,
1182 |                   grad_attn_weight);
1183 |         break;
1184 |       case 128:
1185 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 128>
1186 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1187 |                 0, stream>>>(
1188 |                   num_kernels,
1189 |                   grad_col,
1190 |                   data_value,
1191 |                   data_spatial_shapes,
1192 |                   data_level_start_index,
1193 |                   data_sampling_loc,
1194 |                   data_attn_weight,
1195 |                   batch_size,
1196 |                   spatial_size,
1197 |                   num_heads,
1198 |                   channels,
1199 |                   num_levels,
1200 |                   num_query,
1201 |                   num_point,
1202 |                   grad_value,
1203 |                   grad_sampling_loc,
1204 |                   grad_attn_weight);
1205 |         break;
1206 |       case 256:
1207 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 256>
1208 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1209 |                 0, stream>>>(
1210 |                   num_kernels,
1211 |                   grad_col,
1212 |                   data_value,
1213 |                   data_spatial_shapes,
1214 |                   data_level_start_index,
1215 |                   data_sampling_loc,
1216 |                   data_attn_weight,
1217 |                   batch_size,
1218 |                   spatial_size,
1219 |                   num_heads,
1220 |                   channels,
1221 |                   num_levels,
1222 |                   num_query,
1223 |                   num_point,
1224 |                   grad_value,
1225 |                   grad_sampling_loc,
1226 |                   grad_attn_weight);
1227 |         break;
1228 |       case 512:
1229 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 512>
1230 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1231 |                 0, stream>>>(
1232 |                   num_kernels,
1233 |                   grad_col,
1234 |                   data_value,
1235 |                   data_spatial_shapes,
1236 |                   data_level_start_index,
1237 |                   data_sampling_loc,
1238 |                   data_attn_weight,
1239 |                   batch_size,
1240 |                   spatial_size,
1241 |                   num_heads,
1242 |                   channels,
1243 |                   num_levels,
1244 |                   num_query,
1245 |                   num_point,
1246 |                   grad_value,
1247 |                   grad_sampling_loc,
1248 |                   grad_attn_weight);
1249 |         break;
1250 |       case 1024:
1251 |         ms_deformable_col2im_gpu_kernel_shm_blocksize_aware_reduce_v2<scalar_t, 1024>
1252 |             <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1253 |                 0, stream>>>(
1254 |                   num_kernels,
1255 |                   grad_col,
1256 |                   data_value,
1257 |                   data_spatial_shapes,
1258 |                   data_level_start_index,
1259 |                   data_sampling_loc,
1260 |                   data_attn_weight,
1261 |                   batch_size,
1262 |                   spatial_size,
1263 |                   num_heads,
1264 |                   channels,
1265 |                   num_levels,
1266 |                   num_query,
1267 |                   num_point,
1268 |                   grad_value,
1269 |                   grad_sampling_loc,
1270 |                   grad_attn_weight);
1271 |         break;
1272 |       default:
1273 |         if (channels < 64)
1274 |         {
1275 |           ms_deformable_col2im_gpu_kernel_shm_reduce_v1<scalar_t>
1276 |               <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1277 |                   num_threads*3*sizeof(scalar_t), stream>>>(
1278 |                     num_kernels,
1279 |                     grad_col,
1280 |                     data_value,
1281 |                     data_spatial_shapes,
1282 |                     data_level_start_index,
1283 |                     data_sampling_loc,
1284 |                     data_attn_weight,
1285 |                     batch_size,
1286 |                     spatial_size,
1287 |                     num_heads,
1288 |                     channels,
1289 |                     num_levels,
1290 |                     num_query,
1291 |                     num_point,
1292 |                     grad_value,
1293 |                     grad_sampling_loc,
1294 |                     grad_attn_weight);
1295 |         }
1296 |         else
1297 |         {
1298 |           ms_deformable_col2im_gpu_kernel_shm_reduce_v2<scalar_t>
1299 |               <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
1300 |                   num_threads*3*sizeof(scalar_t), stream>>>(
1301 |                     num_kernels,
1302 |                     grad_col,
1303 |                     data_value,
1304 |                     data_spatial_shapes,
1305 |                     data_level_start_index,
1306 |                     data_sampling_loc,
1307 |                     data_attn_weight,
1308 |                     batch_size,
1309 |                     spatial_size,
1310 |                     num_heads,
1311 |                     channels,
1312 |                     num_levels,
1313 |                     num_query,
1314 |                     num_point,
1315 |                     grad_value,
1316 |                     grad_sampling_loc,
1317 |                     grad_attn_weight);
1318 |         }
1319 |     }
1320 |   }
1321 |   cudaError_t err = cudaGetLastError();
1322 |   if (err != cudaSuccess)
1323 |   {
1324 |     printf("error in ms_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
1325 |   }
1326 |
1327 | }
--------------------------------------------------------------------------------
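For reference, the backward launcher `ms_deformable_col2im_cuda` above dispatches on the channel count: power-of-two counts up to 1024 use a `blockSize`-templated kernel whose shared-memory reduction is sized at compile time, other counts below 1024 fall back to the `shm_reduce_v1`/`v2` kernels driven by `blockDim.x` and dynamic shared memory, and counts above 1024 use either the multi-block variant (accumulating results with `atomicAdd`) or the purely global-memory kernel. Below is a minimal, standalone sketch of the halving shared-memory reduction that the blocksize-aware `_reduce_v2` kernel uses (it assumes `blockDim.x` is a power of two, as the dispatch guarantees); the kernel name, buffers, and sizes here are illustrative and not part of this repository.

// block_sum.cu: illustrative block-wide sum via a halving shared-memory reduction.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void block_sum(const float* in, float* out)
{
    extern __shared__ float cache[];
    const unsigned int tid = threadIdx.x;

    // Each thread stages its own partial value, mirroring how each thread
    // above caches its per-sample gradient before the reduction.
    cache[tid] = in[blockIdx.x * blockDim.x + tid];
    __syncthreads();

    // Halve the number of active threads each step: threads with tid < s
    // fold the upper half of the cache into the lower half.
    for (unsigned int s = blockDim.x / 2; s > 0; s >>= 1)
    {
        if (tid < s)
            cache[tid] += cache[tid + s];
        __syncthreads();
    }

    // Thread 0 now holds the block-wide sum, analogous to the single write
    // of *grad_sampling_loc / *grad_attn_weight per block in the kernels above.
    if (tid == 0)
        out[blockIdx.x] = cache[0];
}

int main()
{
    const int threads = 256, blocks = 4, n = threads * blocks;
    float *in, *out;
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, blocks * sizeof(float));
    for (int i = 0; i < n; ++i) in[i] = 1.0f;

    // Dynamic shared memory holds one float per thread, like the
    // num_threads*3*sizeof(scalar_t) allocation in the launchers above.
    block_sum<<<blocks, threads, threads * sizeof(float)>>>(in, out);
    cudaDeviceSynchronize();

    for (int b = 0; b < blocks; ++b)
        printf("block %d sum = %.0f\n", b, out[b]);  // expect 256 per block
    cudaFree(in);
    cudaFree(out);
    return 0;
}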