├── data └── coco ├── images ├── 1.jpg ├── 2.jpg ├── det_res_1.jpg └── det_res_2.jpg ├── .assets ├── decoder_layer.jpg ├── semantics_aligner.jpg └── matching_complication.jpg ├── models ├── __init__.py ├── ops │ ├── make.sh │ ├── modules │ │ ├── __init__.py │ │ └── ms_deform_attn.py │ ├── functions │ │ ├── __init__.py │ │ └── ms_deform_attn_func.py │ ├── src │ │ ├── vision.cpp │ │ ├── cuda │ │ │ ├── ms_deform_attn_cuda.h │ │ │ └── ms_deform_attn_cuda.cu │ │ ├── cpu │ │ │ ├── ms_deform_attn_cpu.h │ │ │ └── ms_deform_attn_cpu.cpp │ │ └── ms_deform_attn.h │ ├── setup.py │ └── test.py ├── misc.py ├── position_encoding.py ├── matcher.py ├── backbone.py ├── transformer_encoder.py ├── transformer.py ├── transformer_decoder.py └── segmentation.py ├── scripts ├── r50_e12_4gpu.sh ├── r50_e50_4gpu.sh ├── r50_dc5_e12_8gpu.sh ├── r50_dc5_e50_8gpu.sh ├── r50_smca_e50_4gpu.sh ├── r50_smca_e12_4gpu.sh ├── r50_dc5_smca_e12_8gpu.sh ├── r50_dc5_smca_e50_8gpu.sh ├── r50_ms_smca_e12_8gpu.sh └── r50_ms_smca_e50_8gpu.sh ├── util ├── __init__.py ├── box_ops.py └── misc.py ├── scripts_slurm ├── r50_e12_4gpu.sh ├── r50_e50_4gpu.sh ├── r50_dc5_e12_8gpu.sh ├── r50_dc5_e50_8gpu.sh ├── r50_smca_e12_4gpu.sh ├── r50_smca_e50_4gpu.sh ├── r50_dc5_smca_e12_8gpu.sh ├── r50_dc5_smca_e50_8gpu.sh ├── r50_ms_smca_e12_8gpu.sh └── r50_ms_smca_e50_8gpu.sh ├── LICENSE ├── datasets ├── __init__.py ├── panoptic_eval.py ├── coco_panoptic.py ├── coco.py ├── transforms.py └── coco_eval.py ├── .gitignore ├── engine.py ├── demo.py ├── main.py └── README.md /data/coco: -------------------------------------------------------------------------------- 1 | /mnt/lustre/gjzhang/Datasets/coco/coco -------------------------------------------------------------------------------- /images/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangGongjie/SAM-DETR/HEAD/images/1.jpg -------------------------------------------------------------------------------- /images/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangGongjie/SAM-DETR/HEAD/images/2.jpg -------------------------------------------------------------------------------- /images/det_res_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangGongjie/SAM-DETR/HEAD/images/det_res_1.jpg -------------------------------------------------------------------------------- /images/det_res_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangGongjie/SAM-DETR/HEAD/images/det_res_2.jpg -------------------------------------------------------------------------------- /.assets/decoder_layer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangGongjie/SAM-DETR/HEAD/.assets/decoder_layer.jpg -------------------------------------------------------------------------------- /.assets/semantics_aligner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangGongjie/SAM-DETR/HEAD/.assets/semantics_aligner.jpg -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .fast_detr import build 2 | 3 | 4 | def build_model(args): 5 | return build(args) 6 | 
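7 | 
8 | # A minimal usage sketch (illustrative only, not part of the original file).
9 | # It assumes the DETR-style convention that `build(args)` returns a
10 | # (model, criterion, postprocessors) tuple and that `args` is the
11 | # argparse.Namespace constructed by main.py's argument parser:
12 | #
13 | #     from models import build_model
14 | #     model, criterion, postprocessors = build_model(args)
15 | #     model.to(args.device)
16 | #     outputs = model(samples)                 # samples: batched images (NestedTensor)
17 | #     loss_dict = criterion(outputs, targets)  # targets: list of dicts with 'labels', 'boxes'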
-------------------------------------------------------------------------------- /.assets/matching_complication.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZhangGongjie/SAM-DETR/HEAD/.assets/matching_complication.jpg -------------------------------------------------------------------------------- /scripts/r50_e12_4gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_e12 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=4 \ 13 | --use_env main.py \ 14 | --batch_size 4 \ 15 | --epochs 12 \ 16 | --lr_drop 10 \ 17 | --output_dir ${EXP_DIR} \ 18 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 19 | 20 | -------------------------------------------------------------------------------- /scripts/r50_e50_4gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=4 \ 13 | --use_env main.py \ 14 | --batch_size 4 \ 15 | --epochs 50 \ 16 | --lr_drop 40 \ 17 | --output_dir ${EXP_DIR} \ 18 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 19 | 20 | 21 | -------------------------------------------------------------------------------- /scripts/r50_dc5_e12_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_dc5_e12 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=8 \ 13 | --use_env main.py \ 14 | --batch_size 1 \ 15 | --dilation \ 16 | --epochs 12 \ 17 | --lr_drop 10 \ 18 | --output_dir ${EXP_DIR} \ 19 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 20 | 21 | -------------------------------------------------------------------------------- /scripts/r50_dc5_e50_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_dc5_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=8 \ 13 | --use_env main.py \ 14 | --batch_size 1 \ 15 | --dilation \ 16 | --epochs 50 \ 17 | --lr_drop 40 \ 18 | --output_dir ${EXP_DIR} \ 19 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 20 | 21 | -------------------------------------------------------------------------------- /scripts/r50_smca_e50_4gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_smca_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=4 \ 13 | --use_env main.py \ 14 | --batch_size 4 \ 15 | --smca \ 16 | --epochs 50 \ 17 | --lr_drop 40 \ 18 | --output_dir ${EXP_DIR} \ 19 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 20 | 21 | -------------------------------------------------------------------------------- /scripts/r50_smca_e12_4gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_smca_e12 2 | 3 | if [ ! 
-d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=4 \ 13 | --use_env main.py \ 14 | --batch_size 4 \ 15 | --smca \ 16 | --epochs 12 \ 17 | --lr_drop 10 \ 18 | --output_dir ${EXP_DIR} \ 19 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 20 | 21 | 22 | -------------------------------------------------------------------------------- /scripts/r50_dc5_smca_e12_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_dc5_smca_e12 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=8 \ 13 | --use_env main.py \ 14 | --batch_size 1 \ 15 | --smca \ 16 | --dilation \ 17 | --epochs 12 \ 18 | --lr_drop 10 \ 19 | --output_dir ${EXP_DIR} \ 20 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 21 | 22 | -------------------------------------------------------------------------------- /scripts/r50_dc5_smca_e50_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_dc5_smca_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=8 \ 13 | --use_env main.py \ 14 | --batch_size 1 \ 15 | --smca \ 16 | --dilation \ 17 | --epochs 50 \ 18 | --lr_drop 40 \ 19 | --output_dir ${EXP_DIR} \ 20 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 21 | 22 | -------------------------------------------------------------------------------- /scripts/r50_ms_smca_e12_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_ms_smca_e12 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=8 \ 13 | --use_env main.py \ 14 | --batch_size 2 \ 15 | --smca \ 16 | --multiscale \ 17 | --epochs 12 \ 18 | --lr_drop 10 \ 19 | --output_dir ${EXP_DIR} \ 20 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 21 | 22 | 23 | -------------------------------------------------------------------------------- /scripts/r50_ms_smca_e50_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_ms_smca_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | python -m torch.distributed.launch \ 12 | --nproc_per_node=8 \ 13 | --use_env main.py \ 14 | --batch_size 2 \ 15 | --smca \ 16 | --multiscale \ 17 | --epochs 50 \ 18 | --lr_drop 40 \ 19 | --output_dir ${EXP_DIR} \ 20 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 21 | 22 | 23 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | -------------------------------------------------------------------------------- /scripts_slurm/r50_e12_4gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_e12 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:4 \ 14 | --ntasks=4 \ 15 | --ntasks-per-node=4 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 4 \ 20 | --epochs 12 \ 21 | --lr_drop 10 \ 22 | --output_dir ${EXP_DIR} \ 23 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 24 | -------------------------------------------------------------------------------- /scripts_slurm/r50_e50_4gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:4 \ 14 | --ntasks=4 \ 15 | --ntasks-per-node=4 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 4 \ 20 | --epochs 50 \ 21 | --lr_drop 40 \ 22 | --output_dir ${EXP_DIR} \ 23 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 24 | -------------------------------------------------------------------------------- /scripts_slurm/r50_dc5_e12_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_dc5_e12 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:8 \ 14 | --ntasks=8 \ 15 | --ntasks-per-node=8 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 1 \ 20 | --dilation \ 21 | --epochs 12 \ 22 | --lr_drop 10 \ 23 | --output_dir ${EXP_DIR} \ 24 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 25 | -------------------------------------------------------------------------------- /scripts_slurm/r50_dc5_e50_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_dc5_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:8 \ 14 | --ntasks=8 \ 15 | --ntasks-per-node=8 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 1 \ 20 | --dilation \ 21 | --epochs 50 \ 22 | --lr_drop 40 \ 23 | --output_dir ${EXP_DIR} \ 24 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 25 | -------------------------------------------------------------------------------- /scripts_slurm/r50_smca_e12_4gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_smca_e12 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! 
-d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:4 \ 14 | --ntasks=4 \ 15 | --ntasks-per-node=4 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 4 \ 20 | --smca \ 21 | --epochs 12 \ 22 | --lr_drop 10 \ 23 | --output_dir ${EXP_DIR} \ 24 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 25 | -------------------------------------------------------------------------------- /scripts_slurm/r50_smca_e50_4gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_smca_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:4 \ 14 | --ntasks=4 \ 15 | --ntasks-per-node=4 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 4 \ 20 | --smca \ 21 | --epochs 50 \ 22 | --lr_drop 40 \ 23 | --output_dir ${EXP_DIR} \ 24 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 25 | -------------------------------------------------------------------------------- /scripts_slurm/r50_dc5_smca_e12_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_dc5_smca_e12 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:8 \ 14 | --ntasks=8 \ 15 | --ntasks-per-node=8 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 1 \ 20 | --smca \ 21 | --dilation \ 22 | --epochs 12 \ 23 | --lr_drop 10 \ 24 | --output_dir ${EXP_DIR} \ 25 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 26 | -------------------------------------------------------------------------------- /scripts_slurm/r50_dc5_smca_e50_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_dc5_smca_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:8 \ 14 | --ntasks=8 \ 15 | --ntasks-per-node=8 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 1 \ 20 | --smca \ 21 | --dilation \ 22 | --epochs 50 \ 23 | --lr_drop 40 \ 24 | --output_dir ${EXP_DIR} \ 25 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 26 | -------------------------------------------------------------------------------- /scripts_slurm/r50_ms_smca_e12_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_ms_smca_e12 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! 
-d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:8 \ 14 | --ntasks=8 \ 15 | --ntasks-per-node=8 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 2 \ 20 | --smca \ 21 | --multiscale \ 22 | --epochs 12 \ 23 | --lr_drop 10 \ 24 | --output_dir ${EXP_DIR} \ 25 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 26 | -------------------------------------------------------------------------------- /scripts_slurm/r50_ms_smca_e50_8gpu.sh: -------------------------------------------------------------------------------- 1 | EXP_DIR=output/r50_ms_smca_e50 2 | 3 | if [ ! -d "output" ]; then 4 | mkdir output 5 | fi 6 | 7 | if [ ! -d "${EXP_DIR}" ]; then 8 | mkdir ${EXP_DIR} 9 | fi 10 | 11 | srun -p cluster_name \ 12 | --job-name=SAM-DETR \ 13 | --gres=gpu:8 \ 14 | --ntasks=8 \ 15 | --ntasks-per-node=8 \ 16 | --cpus-per-task=2 \ 17 | --kill-on-bad-exit=1 \ 18 | python main.py \ 19 | --batch_size 2 \ 20 | --smca \ 21 | --multiscale \ 22 | --epochs 50 \ 23 | --lr_drop 40 \ 24 | --output_dir ${EXP_DIR} \ 25 | 2>&1 | tee ${EXP_DIR}/detailed_log.txt 26 | -------------------------------------------------------------------------------- /models/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /models/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 GJ Zhang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copied from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | 6 | import torch.utils.data 7 | import torchvision 8 | 9 | from .coco import build as build_coco 10 | 11 | 12 | def get_coco_api_from_dataset(dataset): 13 | for _ in range(10): 14 | # if isinstance(dataset, torchvision.datasets.CocoDetection): 15 | # break 16 | if isinstance(dataset, torch.utils.data.Subset): 17 | dataset = dataset.dataset 18 | if isinstance(dataset, torchvision.datasets.CocoDetection): 19 | return dataset.coco 20 | 21 | 22 | def build_dataset(image_set, args): 23 | if args.dataset_file == 'coco': 24 | return build_coco(image_set, args) 25 | if args.dataset_file == 'coco_panoptic': 26 | # to avoid making panopticapi required for coco 27 | from .coco_panoptic import build as build_coco_panoptic 28 | return build_coco_panoptic(image_set, args) 29 | raise ValueError(f'dataset {args.dataset_file} not supported') 30 | -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /models/misc.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | import copy 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def _get_clones(module, num_layers): 11 | return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)]) 12 | 13 | 14 | class MLP(nn.Module): 15 | """ Very simple multi-layer perceptron (also called Feed-Forward-Networks -- FFN) """ 16 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 17 | super().__init__() 18 | self.num_layers = num_layers 19 | h = [hidden_dim] * (num_layers - 1) 20 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 21 | 22 | def forward(self, x): 23 | for i, layer in enumerate(self.layers): 24 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 25 | return x 26 | 27 | 28 | def _get_activation_fn(activation): 29 | """Return an activation function given a string""" 30 | if activation == "relu": 31 | return F.relu 32 | if activation == "gelu": 33 | return F.gelu 34 | if activation == "glu": 35 | return F.glu 36 | raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.") 37 | -------------------------------------------------------------------------------- /datasets/panoptic_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copied from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | import json 7 | import os 8 | 9 | import util.misc as utils 10 | 11 | try: 12 | from panopticapi.evaluation import pq_compute 13 | except ImportError: 14 | pass 15 | 16 | 17 | class PanopticEvaluator(object): 18 | def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): 19 | self.gt_json = ann_file 20 | self.gt_folder = ann_folder 21 | if utils.is_main_process(): 22 | if not os.path.exists(output_dir): 23 | os.mkdir(output_dir) 24 | self.output_dir = output_dir 25 | self.predictions = [] 26 | 27 | def update(self, predictions): 28 | for p in predictions: 29 | with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: 30 | f.write(p.pop("png_string")) 31 | 32 | self.predictions += predictions 33 | 34 | def synchronize_between_processes(self): 35 | all_predictions = utils.all_gather(self.predictions) 36 | merged_predictions = [] 37 | for p in all_predictions: 38 | merged_predictions += p 39 | self.predictions = merged_predictions 40 | 41 | def summarize(self): 42 | if utils.is_main_process(): 43 | json_data = {"annotations": self.predictions} 44 | predictions_json = os.path.join(self.output_dir, "predictions.json") 45 | with open(predictions_json, "w") as f: 46 | f.write(json.dumps(json_data)) 47 | return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir) 48 | return None 49 | -------------------------------------------------------------------------------- /models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /models/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=1", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | 
packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | """ 7 | Utilities for bounding box manipulation and GIoU. 8 | """ 9 | import torch 10 | from torchvision.ops.boxes import box_area 11 | 12 | 13 | def box_cxcywh_to_xyxy(x): 14 | x_c, y_c, w, h = x.unbind(-1) 15 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 16 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 17 | return torch.stack(b, dim=-1) 18 | 19 | 20 | def box_xyxy_to_cxcywh(x): 21 | x0, y0, x1, y1 = x.unbind(-1) 22 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 23 | (x1 - x0), (y1 - y0)] 24 | return torch.stack(b, dim=-1) 25 | 26 | 27 | # modified from torchvision to also return the union 28 | def box_iou(boxes1, boxes2): 29 | area1 = box_area(boxes1) 30 | area2 = box_area(boxes2) 31 | 32 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 33 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 34 | 35 | wh = (rb - lt).clamp(min=0) # [N,M,2] 36 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 37 | 38 | union = area1[:, None] + area2 - inter 39 | 40 | iou = inter / union 41 | return iou, union 42 | 43 | 44 | def generalized_box_iou(boxes1, boxes2): 45 | """ 46 | Generalized IoU from https://giou.stanford.edu/ 47 | 48 | The boxes should be in [x0, y0, x1, y1] format 49 | 50 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 51 | and M = len(boxes2) 52 | """ 53 | # degenerate boxes gives inf / nan results 54 | # so do an early check 55 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 56 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 57 | iou, union = box_iou(boxes1, boxes2) 58 | 59 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 60 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 61 | 62 | wh = (rb - lt).clamp(min=0) # [N,M,2] 63 | area = wh[:, :, 0] * wh[:, :, 1] 64 | 65 | return iou - (area - union) / area 66 | 67 | 68 | def masks_to_boxes(masks): 69 | """Compute the bounding boxes around the provided masks 70 | 71 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
72 | 73 | Returns a [N, 4] tensors, with the boxes in xyxy format 74 | """ 75 | if masks.numel() == 0: 76 | return torch.zeros((0, 4), device=masks.device) 77 | 78 | h, w = masks.shape[-2:] 79 | 80 | y = torch.arange(0, h, dtype=torch.float) 81 | x = torch.arange(0, w, dtype=torch.float) 82 | y, x = torch.meshgrid(y, x) 83 | 84 | x_mask = (masks * x.unsqueeze(0)) 85 | x_max = x_mask.flatten(1).max(-1)[0] 86 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 87 | 88 | y_mask = (masks * y.unsqueeze(0)) 89 | y_max = y_mask.flatten(1).max(-1)[0] 90 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 91 | 92 | return torch.stack([x_min, y_min, x_max, y_max], 1) 93 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | """ 7 | positional encodings for the transformer. 8 | """ 9 | import math 10 | import torch 11 | import torch.nn as nn 12 | 13 | from util.misc import NestedTensor 14 | 15 | 16 | class PositionEmbeddingSine(nn.Module): 17 | """ 18 | This is a more standard version of the position embedding, very similar to the one 19 | used by the Attention is all you need paper, generalized to work on images. 20 | """ 21 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 22 | super().__init__() 23 | self.num_pos_feats = num_pos_feats 24 | self.temperature = temperature 25 | self.normalize = normalize 26 | if scale is not None and normalize is False: 27 | raise ValueError("normalize should be True if scale is passed") 28 | if scale is None: 29 | scale = 2 * math.pi 30 | self.scale = scale 31 | 32 | def forward(self, tensor_list: NestedTensor): 33 | x = tensor_list.tensors 34 | mask = tensor_list.mask 35 | assert mask is not None 36 | not_mask = ~mask 37 | y_embed = not_mask.cumsum(1, dtype=torch.float32) - 0.5 38 | x_embed = not_mask.cumsum(2, dtype=torch.float32) - 0.5 39 | if self.normalize: 40 | eps = 1e-6 41 | y_embed = y_embed / (y_embed[:, -1:, :] + eps + 0.5) * self.scale 42 | x_embed = x_embed / (x_embed[:, :, -1:] + eps + 0.5) * self.scale 43 | 44 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 45 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 46 | 47 | pos_x = x_embed[:, :, :, None] / dim_t 48 | pos_y = y_embed[:, :, :, None] / dim_t 49 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 50 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 51 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 52 | return pos 53 | 54 | 55 | @torch.no_grad() 56 | def gen_sineembed_for_position(pos_tensor): 57 | scale = 2 * math.pi 58 | dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) 59 | dim_t = 10000 ** (2 * (dim_t // 2) / 128) 60 | x_embed = pos_tensor[:, :, 0] * scale 61 | y_embed = pos_tensor[:, :, 1] * scale 62 | pos_x = x_embed[:, :, None] / dim_t 63 | pos_y = y_embed[:, :, None] / dim_t 64 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 65 | pos_y = 
torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) 66 | pos = torch.cat((pos_y, pos_x), dim=2) 67 | return pos 68 | 69 | 70 | def build_position_encoding(args): 71 | if args.position_embedding in ('sine'): 72 | position_embedding = PositionEmbeddingSine(args.hidden_dim // 2, normalize=True) 73 | else: 74 | raise ValueError(f"Unknown args.position_embedding: {args.position_embedding}.") 75 | return position_embedding 76 | -------------------------------------------------------------------------------- /models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | 
mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | -------------------------------------------------------------------------------- /models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = 
torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /datasets/coco_panoptic.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copied from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | 6 | import json 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image 12 | 13 | from panopticapi.utils import rgb2id 14 | from util.box_ops import masks_to_boxes 15 | 16 | from .coco import make_coco_transforms 17 | 18 | 19 | class CocoPanoptic: 20 | def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True): 21 | with open(ann_file, 'r') as f: 22 | self.coco = json.load(f) 23 | 24 | # sort 'images' field so that they are aligned with 'annotations' 25 | # i.e., in alphabetical order 26 | self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id']) 27 | # sanity check 28 | if "annotations" in self.coco: 29 | for img, ann in zip(self.coco['images'], self.coco['annotations']): 30 | assert img['file_name'][:-4] == ann['file_name'][:-4] 31 | 32 | self.img_folder = img_folder 33 | self.ann_folder = ann_folder 34 | self.ann_file = ann_file 35 | self.transforms = transforms 36 | self.return_masks = return_masks 37 | 38 | def __getitem__(self, idx): 39 | ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx] 40 | img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg') 41 | ann_path = Path(self.ann_folder) / ann_info['file_name'] 42 | 43 | img = Image.open(img_path).convert('RGB') 44 | w, h = img.size 45 | if "segments_info" in ann_info: 46 | masks = np.asarray(Image.open(ann_path), dtype=np.uint32) 47 | masks = rgb2id(masks) 48 | 49 | ids = np.array([ann['id'] for ann in ann_info['segments_info']]) 50 | masks = masks == ids[:, None, None] 51 | 52 | masks = torch.as_tensor(masks, dtype=torch.uint8) 53 | labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64) 54 | 55 | target = {} 56 | target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]]) 57 | if self.return_masks: 58 | target['masks'] = masks 59 | target['labels'] = labels 60 | 61 | target["boxes"] = masks_to_boxes(masks) 62 | 63 | target['size'] = torch.as_tensor([int(h), int(w)]) 64 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 65 | if "segments_info" in ann_info: 66 | for name in ['iscrowd', 'area']: 67 | target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']]) 68 | 69 | if self.transforms is not None: 70 | img, target = self.transforms(img, target) 71 | 72 | return img, target 73 | 74 | def __len__(self): 75 | return len(self.coco['images']) 76 | 77 | def get_height_and_width(self, idx): 78 | img_info = self.coco['images'][idx] 79 | height = img_info['height'] 80 | width = img_info['width'] 81 | return height, width 82 | 83 | 84 | def build(image_set, args): 85 | img_folder_root = Path(args.coco_path) 86 | ann_folder_root = Path(args.coco_panoptic_path) 87 | assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist' 88 | assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist' 89 | mode = 'panoptic' 90 | PATHS = { 91 | "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'), 92 | "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'), 93 | } 94 | 95 | img_folder, ann_file = PATHS[image_set] 96 | img_folder_path = img_folder_root / img_folder 97 | ann_folder = ann_folder_root / f'{mode}_{img_folder}' 98 | ann_file = ann_folder_root / ann_file 99 | 100 | dataset = CocoPanoptic(img_folder_path, 
ann_folder, ann_file, 101 | transforms=make_coco_transforms(image_set), return_masks=args.masks) 102 | 103 | return dataset 104 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 6 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 7 | # ------------------------------------------------------------------------ 8 | 9 | import torch 10 | from scipy.optimize import linear_sum_assignment 11 | from torch import nn 12 | 13 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 14 | 15 | 16 | class HungarianMatcher(nn.Module): 17 | """This class computes an assignment between the targets and the predictions of the network 18 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 19 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 20 | while the others are un-matched (and thus treated as non-objects). 21 | """ 22 | 23 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 24 | """Creates the matcher 25 | Params: 26 | cost_class: This is the relative weight of the classification error in the matching cost 27 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 28 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 29 | """ 30 | super().__init__() 31 | self.cost_class = cost_class 32 | self.cost_bbox = cost_bbox 33 | self.cost_giou = cost_giou 34 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 35 | 36 | @torch.no_grad() 37 | def forward(self, outputs, targets): 38 | """ Performs the matching 39 | Params: 40 | outputs: This is a dict that contains at least these entries: 41 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 42 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 43 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 44 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 45 | objects in the target) containing the class labels 46 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 47 | Returns: 48 | A list of size batch_size, containing tuples of (index_i, index_j) where: 49 | - index_i is the indices of the selected predictions (in order) 50 | - index_j is the indices of the corresponding selected targets (in order) 51 | For each batch element, it holds: 52 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 53 | """ 54 | bs, num_queries = outputs["pred_logits"].shape[:2] 55 | 56 | # We flatten to compute the cost matrices in a batch 57 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] 58 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 59 | 60 | # Also concat the 
target labels and boxes 61 | tgt_ids = torch.cat([v["labels"] for v in targets]) 62 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 63 | 64 | # Compute the classification cost. 65 | alpha = 0.25 66 | gamma = 2.0 67 | neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) 68 | pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) 69 | cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] 70 | 71 | # Compute the L1 cost between boxes 72 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 73 | 74 | # Compute the giou cost betwen boxes 75 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 76 | 77 | # Final cost matrix 78 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 79 | C = C.view(bs, num_queries, -1).cpu() 80 | 81 | sizes = [len(v["boxes"]) for v in targets] 82 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 83 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 84 | 85 | 86 | def build_matcher(args): 87 | return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) 88 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Mofified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | """ 7 | Backbone modules. 8 | """ 9 | import torch 10 | import torch.nn.functional as F 11 | import torchvision 12 | from torch import nn 13 | from torchvision.models._utils import IntermediateLayerGetter 14 | from typing import Dict, List 15 | 16 | from util.misc import NestedTensor, is_main_process 17 | from .position_encoding import build_position_encoding 18 | 19 | 20 | class FrozenBatchNorm2d(torch.nn.Module): 21 | """ 22 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 23 | 24 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 25 | without which any other models than torchvision.models.resnet[18,34,50,101] 26 | produce nans. 
27 | """ 28 | 29 | def __init__(self, n): 30 | super(FrozenBatchNorm2d, self).__init__() 31 | self.register_buffer("weight", torch.ones(n)) 32 | self.register_buffer("bias", torch.zeros(n)) 33 | self.register_buffer("running_mean", torch.zeros(n)) 34 | self.register_buffer("running_var", torch.ones(n)) 35 | 36 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 37 | num_batches_tracked_key = prefix + 'num_batches_tracked' 38 | if num_batches_tracked_key in state_dict: 39 | del state_dict[num_batches_tracked_key] 40 | 41 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 42 | state_dict, prefix, local_metadata, strict, 43 | missing_keys, unexpected_keys, error_msgs) 44 | 45 | def forward(self, x): 46 | # move reshapes to the beginning 47 | # to make it fuser-friendly 48 | w = self.weight.reshape(1, -1, 1, 1) 49 | b = self.bias.reshape(1, -1, 1, 1) 50 | rv = self.running_var.reshape(1, -1, 1, 1) 51 | rm = self.running_mean.reshape(1, -1, 1, 1) 52 | eps = 1e-5 53 | scale = w * (rv + eps).rsqrt() 54 | bias = b - rm * scale 55 | return x * scale + bias 56 | 57 | 58 | class BackboneBase(nn.Module): 59 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 60 | super().__init__() 61 | for name, parameter in backbone.named_parameters(): 62 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 63 | parameter.requires_grad_(False) 64 | if return_interm_layers: 65 | return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 66 | # Hard-coded backbone parameters 67 | self.strides = [8, 16, 32] 68 | self.num_channels = [512, 1024, 2048] 69 | else: 70 | return_layers = {'layer4': "0"} 71 | # Hard-coded backbone parameters 72 | self.strides = [32] 73 | self.num_channels = [2048] 74 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 75 | 76 | def forward(self, tensor_list: NestedTensor): 77 | xs = self.body(tensor_list.tensors) 78 | out: Dict[str, NestedTensor] = {} 79 | for name, x in xs.items(): 80 | m = tensor_list.mask 81 | assert m is not None 82 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 83 | out[name] = NestedTensor(x, mask) 84 | return out 85 | 86 | 87 | class Backbone(BackboneBase): 88 | """ResNet backbone with frozen BatchNorm.""" 89 | def __init__(self, name: str, 90 | train_backbone: bool, 91 | return_interm_layers: bool, 92 | dilation: bool): 93 | backbone = getattr(torchvision.models, name)( 94 | replace_stride_with_dilation=[False, False, dilation], 95 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 96 | assert name not in ('resnet18', 'resnet34'), "Number of channels are hard coded, thus do not support res18/34." 
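# Note: the num_channels computed on the next line is forwarded to BackboneBase.__init__, which never reads it;
# the effective channel counts are the hard-coded self.num_channels values set inside BackboneBase above.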
97 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 98 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 99 | 100 | 101 | class Joiner(nn.Sequential): 102 | def __init__(self, backbone, position_embedding): 103 | super().__init__(backbone, position_embedding) 104 | 105 | def forward(self, tensor_list: NestedTensor): 106 | xs = self[0](tensor_list) 107 | out: List[NestedTensor] = [] 108 | pos = [] 109 | for name, x in xs.items(): 110 | out.append(x) 111 | # position encoding 112 | pos.append(self[1](x).to(x.tensors.dtype)) 113 | return out, pos 114 | 115 | 116 | def build_backbone(args): 117 | position_embedding = build_position_encoding(args) 118 | train_backbone = args.lr_backbone > 0 119 | return_interm_layers = args.masks or args.multiscale 120 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 121 | model = Joiner(backbone, position_embedding) 122 | model.num_channels = backbone.num_channels 123 | return model 124 | -------------------------------------------------------------------------------- /datasets/coco.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copied from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | """ 7 | COCO dataset which returns image_id for evaluation. 8 | """ 9 | from pathlib import Path 10 | 11 | import torch 12 | import torch.utils.data 13 | import torchvision 14 | from pycocotools import mask as coco_mask 15 | 16 | import datasets.transforms as T 17 | 18 | 19 | class CocoDetection(torchvision.datasets.CocoDetection): 20 | def __init__(self, img_folder, ann_file, transforms, return_masks): 21 | super(CocoDetection, self).__init__(img_folder, ann_file) 22 | self._transforms = transforms 23 | self.prepare = ConvertCocoPolysToMask(return_masks) 24 | 25 | def __getitem__(self, idx): 26 | img, target = super(CocoDetection, self).__getitem__(idx) 27 | image_id = self.ids[idx] 28 | target = {'image_id': image_id, 'annotations': target} 29 | img, target = self.prepare(img, target) 30 | if self._transforms is not None: 31 | img, target = self._transforms(img, target) 32 | return img, target 33 | 34 | 35 | def convert_coco_poly_to_mask(segmentations, height, width): 36 | masks = [] 37 | for polygons in segmentations: 38 | rles = coco_mask.frPyObjects(polygons, height, width) 39 | mask = coco_mask.decode(rles) 40 | if len(mask.shape) < 3: 41 | mask = mask[..., None] 42 | mask = torch.as_tensor(mask, dtype=torch.uint8) 43 | mask = mask.any(dim=2) 44 | masks.append(mask) 45 | if masks: 46 | masks = torch.stack(masks, dim=0) 47 | else: 48 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 49 | return masks 50 | 51 | 52 | class ConvertCocoPolysToMask(object): 53 | def __init__(self, return_masks=False): 54 | self.return_masks = return_masks 55 | 56 | def __call__(self, image, target): 57 | w, h = image.size 58 | 59 | image_id = target["image_id"] 60 | image_id = torch.tensor([image_id]) 61 | 62 | anno = target["annotations"] 63 | 64 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 65 | 66 | boxes = [obj["bbox"] for obj in anno] 67 | # guard against no boxes via resizing 68 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 69 | boxes[:, 2:] += boxes[:, 
:2] 70 | boxes[:, 0::2].clamp_(min=0, max=w) 71 | boxes[:, 1::2].clamp_(min=0, max=h) 72 | 73 | classes = [obj["category_id"] for obj in anno] 74 | classes = torch.tensor(classes, dtype=torch.int64) 75 | 76 | if self.return_masks: 77 | segmentations = [obj["segmentation"] for obj in anno] 78 | masks = convert_coco_poly_to_mask(segmentations, h, w) 79 | 80 | keypoints = None 81 | if anno and "keypoints" in anno[0]: 82 | keypoints = [obj["keypoints"] for obj in anno] 83 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 84 | num_keypoints = keypoints.shape[0] 85 | if num_keypoints: 86 | keypoints = keypoints.view(num_keypoints, -1, 3) 87 | 88 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 89 | boxes = boxes[keep] 90 | classes = classes[keep] 91 | if self.return_masks: 92 | masks = masks[keep] 93 | if keypoints is not None: 94 | keypoints = keypoints[keep] 95 | 96 | target = {} 97 | target["boxes"] = boxes 98 | target["labels"] = classes 99 | if self.return_masks: 100 | target["masks"] = masks 101 | target["image_id"] = image_id 102 | if keypoints is not None: 103 | target["keypoints"] = keypoints 104 | 105 | # for conversion to coco api 106 | area = torch.tensor([obj["area"] for obj in anno]) 107 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 108 | target["area"] = area[keep] 109 | target["iscrowd"] = iscrowd[keep] 110 | 111 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 112 | target["size"] = torch.as_tensor([int(h), int(w)]) 113 | 114 | return image, target 115 | 116 | 117 | def make_coco_transforms(image_set): 118 | 119 | normalize = T.Compose([ 120 | T.ToTensor(), 121 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 122 | ]) 123 | 124 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 125 | 126 | if image_set == 'train': 127 | return T.Compose([ 128 | T.RandomHorizontalFlip(), 129 | T.RandomSelect( 130 | T.RandomResize(scales, max_size=1333), 131 | T.Compose([ 132 | T.RandomResize([400, 500, 600]), 133 | T.RandomSizeCrop(384, 600), 134 | T.RandomResize(scales, max_size=1333), 135 | ]) 136 | ), 137 | normalize, 138 | ]) 139 | 140 | if image_set == 'val': 141 | return T.Compose([ 142 | T.RandomResize([800], max_size=1333), 143 | normalize, 144 | ]) 145 | 146 | raise ValueError(f'unknown {image_set}') 147 | 148 | 149 | def build(image_set, args): 150 | root = Path(args.coco_path) 151 | assert root.exists(), f'provided COCO path {root} does not exist' 152 | mode = 'instances' 153 | PATHS = { 154 | "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'), 155 | "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'), 156 | "test": (root / "test2017", root / "annotations" / f'image_info_test-dev2017.json'), 157 | } 158 | 159 | img_folder, ann_file = PATHS[image_set] 160 | dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks) 161 | return dataset 162 | -------------------------------------------------------------------------------- /models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 46 | "which is more efficient in our CUDA implementation.") 47 | 48 | self.im2col_step = 64 49 | 50 | self.d_model = d_model 51 | self.n_levels = n_levels 52 | self.n_heads = n_heads 53 | self.n_points = n_points 54 | 55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 57 | self.value_proj = nn.Linear(d_model, d_model) 58 | self.output_proj = nn.Linear(d_model, d_model) 59 | 60 | self._reset_parameters() 61 | 62 | def _reset_parameters(self): 63 | constant_(self.sampling_offsets.weight.data, 0.) 64 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 65 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 66 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 67 | for i in range(self.n_points): 68 | grid_init[:, :, i, :] *= i + 1 69 | with torch.no_grad(): 70 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 71 | constant_(self.attention_weights.weight.data, 0.) 72 | constant_(self.attention_weights.bias.data, 0.) 73 | xavier_uniform_(self.value_proj.weight.data) 74 | constant_(self.value_proj.bias.data, 0.) 75 | xavier_uniform_(self.output_proj.weight.data) 76 | constant_(self.output_proj.bias.data, 0.) 
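For reference, a minimal usage sketch of this module (its forward method follows below). The tensor sizes are illustrative only; the sketch assumes the CUDA extension under models/ops has been compiled and a GPU is available, and it follows the shape conventions documented in forward() and exercised in models/ops/test.py.

import torch
from models.ops.modules import MSDeformAttn

# Illustrative sizes: batch of 2, two feature levels (16x16 and 8x8), 100 queries.
N, len_q, d_model = 2, 100, 256
spatial_shapes = torch.as_tensor([[16, 16], [8, 8]], dtype=torch.long).cuda()
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
len_in = int(spatial_shapes.prod(1).sum())            # 16*16 + 8*8 = 320

attn = MSDeformAttn(d_model=d_model, n_levels=2, n_heads=8, n_points=4).cuda()
query = torch.rand(N, len_q, d_model).cuda()
input_flatten = torch.rand(N, len_in, d_model).cuda()
reference_points = torch.rand(N, len_q, 2, 2).cuda()  # normalized (x, y) per level, in [0, 1]

output = attn(query, reference_points, input_flatten, spatial_shapes, level_start_index)
assert output.shape == (N, len_q, d_model)            # (2, 100, 256)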
77 | 78 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 79 | """ 80 | :param query (N, Length_{query}, C) 81 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 82 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 83 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 84 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 85 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 86 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 87 | 88 | :return output (N, Length_{query}, C) 89 | """ 90 | N, Len_q, _ = query.shape 91 | N, Len_in, _ = input_flatten.shape 92 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 93 | 94 | value = self.value_proj(input_flatten) 95 | if input_padding_mask is not None: 96 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 97 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 98 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 99 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 100 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 101 | # N, Len_q, n_heads, n_levels, n_points, 2 102 | if reference_points.shape[-1] == 2: 103 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 104 | sampling_locations = reference_points[:, :, None, :, None, :] \ 105 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 106 | elif reference_points.shape[-1] == 4: 107 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 108 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 109 | else: 110 | raise ValueError( 111 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 112 | output = MSDeformAttnFunction.apply( 113 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 114 | output = self.output_proj(output) 115 | return output 116 | -------------------------------------------------------------------------------- /models/transformer_encoder.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 6 | # Copyright (c) SenseTime. All Rights Reserved. 
7 | # ------------------------------------------------------------------------ 8 | from typing import Optional 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch import Tensor 13 | 14 | from models.misc import _get_clones, _get_activation_fn 15 | 16 | 17 | class TransformerEncoder(nn.Module): 18 | def __init__(self, args, encoder_layer, num_layers): 19 | super().__init__() 20 | self.args = args 21 | self.num_layers = num_layers 22 | self.layers = _get_clones(encoder_layer, num_layers) 23 | assert num_layers == self.args.enc_layers 24 | 25 | def forward(self, src, 26 | mask: Optional[Tensor] = None, 27 | src_key_padding_mask: Optional[Tensor] = None, 28 | pos: Optional[Tensor] = None): 29 | output = src 30 | for layer in self.layers: 31 | output = layer(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos) 32 | return output 33 | 34 | 35 | class TransformerEncoderLayer(nn.Module): 36 | def __init__(self, args, activation="relu"): 37 | super().__init__() 38 | self.args = args 39 | self.d_model = args.hidden_dim 40 | self.nheads = args.nheads 41 | self.num_queries = args.num_queries 42 | self.dim_feedforward = args.dim_feedforward 43 | self.dropout = args.dropout 44 | self.activation = _get_activation_fn(activation) 45 | 46 | # Encoder Self-Attention 47 | self.self_attn = nn.MultiheadAttention(self.d_model, self.nheads, dropout=self.dropout) 48 | self.norm1 = nn.LayerNorm(self.d_model) 49 | self.dropout1 = nn.Dropout(self.dropout) 50 | 51 | # FFN 52 | self.linear1 = nn.Linear(self.d_model, self.dim_feedforward) 53 | self.dropout2 = nn.Dropout(self.dropout) 54 | self.linear2 = nn.Linear(self.dim_feedforward, self.d_model) 55 | self.dropout3 = nn.Dropout(self.dropout) 56 | self.norm2 = nn.LayerNorm(self.d_model) 57 | 58 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 59 | return tensor if pos is None else tensor + pos 60 | 61 | def forward(self, src, 62 | src_mask: Optional[Tensor] = None, 63 | src_key_padding_mask: Optional[Tensor] = None, 64 | pos: Optional[Tensor] = None): 65 | # Self-Attention 66 | q = k = self.with_pos_embed(src, pos) 67 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask)[0] 68 | src = src + self.dropout1(src2) 69 | src = self.norm1(src) 70 | 71 | # FFN 72 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 73 | src = src + self.dropout3(src2) 74 | src = self.norm2(src) 75 | 76 | return src 77 | 78 | 79 | class DeformableTransformerEncoder(nn.Module): 80 | def __init__(self, args, encoder_layer, num_layers): 81 | super().__init__() 82 | self.args = args 83 | self.layers = _get_clones(encoder_layer, num_layers) 84 | self.num_layers = num_layers 85 | assert num_layers == self.args.enc_layers 86 | 87 | @staticmethod 88 | def get_reference_points(spatial_shapes, valid_ratios, device): 89 | reference_points_list = [] 90 | for lvl, (H_, W_) in enumerate(spatial_shapes): 91 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 92 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 93 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) 94 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) 95 | ref = torch.stack((ref_x, ref_y), -1) 96 | reference_points_list.append(ref) 97 | reference_points = torch.cat(reference_points_list, 1) 98 | reference_points = reference_points[:, :, None] * valid_ratios[:, None] 99 | return reference_points 100 | 101 | def 
forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): 102 | output = src 103 | reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) 104 | for _, layer in enumerate(self.layers): 105 | output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) 106 | return output 107 | 108 | 109 | class DeformableTransformerEncoderLayer(nn.Module): 110 | def __init__(self, args, activation='relu'): 111 | super().__init__() 112 | 113 | self.args = args 114 | self.d_model = args.hidden_dim 115 | self.nheads = args.nheads 116 | self.num_queries = args.num_queries 117 | # Note: Multiscale encoder's dim_feedforward halved for memory efficiency 118 | self.dim_feedforward = args.dim_feedforward // 2 119 | self.dropout = args.dropout 120 | 121 | # Hard-coded Hyper-parameters 122 | self.n_feature_levels = 3 123 | self.n_points = 4 124 | 125 | # self attention 126 | from models.ops.modules import MSDeformAttn 127 | self.self_attn = MSDeformAttn(self.d_model, self.n_feature_levels, self.nheads, self.n_points) 128 | self.dropout1 = nn.Dropout(self.dropout) 129 | self.norm1 = nn.LayerNorm(self.d_model) 130 | 131 | # ffn 132 | self.linear1 = nn.Linear(self.d_model, self.dim_feedforward) 133 | self.activation = _get_activation_fn(activation) 134 | self.dropout2 = nn.Dropout(self.dropout) 135 | self.linear2 = nn.Linear(self.dim_feedforward, self.d_model) 136 | self.dropout3 = nn.Dropout(self.dropout) 137 | self.norm2 = nn.LayerNorm(self.d_model) 138 | 139 | @staticmethod 140 | def with_pos_embed(tensor, pos): 141 | return tensor if pos is None else tensor + pos 142 | 143 | def forward_ffn(self, src): 144 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 145 | src = src + self.dropout3(src2) 146 | src = self.norm2(src) 147 | return src 148 | 149 | def forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): 150 | # self attention 151 | src2 = self.self_attn(self.with_pos_embed(src, pos), 152 | reference_points, 153 | src, 154 | spatial_shapes, 155 | level_start_index, 156 | padding_mask) 157 | src = src + self.dropout1(src2) 158 | src = self.norm1(src) 159 | 160 | # ffn 161 | src = self.forward_ffn(src) 162 | 163 | return src 164 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | 6 | """ 7 | Train and eval functions used in main.py 8 | """ 9 | import math 10 | import os 11 | import sys 12 | from typing import Iterable 13 | 14 | import torch 15 | 16 | import util.misc as utils 17 | from datasets.coco_eval import CocoEvaluator 18 | from datasets.panoptic_eval import PanopticEvaluator 19 | 20 | 21 | def train_one_epoch(model: torch.nn.Module, 22 | criterion: torch.nn.Module, 23 | data_loader: Iterable, 24 | optimizer: torch.optim.Optimizer, 25 | device: torch.device, 26 | epoch: int, 27 | max_norm: float = 0): 28 | model.train() 29 | criterion.train() 30 | metric_logger = utils.MetricLogger(delimiter=" ") 31 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 32 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 100 35 | 36 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 37 | samples = samples.to(device) 38 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 39 | 40 | outputs = model(samples) 41 | loss_dict = criterion(outputs, targets) 42 | weight_dict = criterion.weight_dict 43 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 44 | 45 | # reduce losses over all GPUs for logging purposes 46 | loss_dict_reduced = utils.reduce_dict(loss_dict) 47 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()} 48 | loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict} 49 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 50 | 51 | loss_value = losses_reduced_scaled.item() 52 | 53 | if not math.isfinite(loss_value): 54 | print("Loss is {}.\n Training terminated.".format(loss_value)) 55 | print(loss_dict_reduced) 56 | sys.exit(1) 57 | 58 | optimizer.zero_grad() 59 | losses.backward() 60 | if max_norm > 0: 61 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 62 | optimizer.step() 63 | 64 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 65 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 66 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 67 | 68 | del samples 69 | del targets 70 | del outputs 71 | del loss_dict 72 | del loss_dict_reduced 73 | del loss_dict_reduced_unscaled 74 | del weight_dict 75 | del losses 76 | del losses_reduced_scaled 77 | 78 | # gather the stats from all processes 79 | metric_logger.synchronize_between_processes() 80 | print("Averaged stats:", metric_logger) 81 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 82 | 83 | 84 | @torch.no_grad() 85 | def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir): 86 | model.eval() 87 | criterion.eval() 88 | 89 | metric_logger = utils.MetricLogger(delimiter=" ") 90 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 91 | header = 'Test:' 92 | 93 | iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) 94 | coco_evaluator = CocoEvaluator(base_ds, iou_types) 95 | 96 | panoptic_evaluator = None 97 | if 'panoptic' in postprocessors.keys(): 98 | panoptic_evaluator = PanopticEvaluator( 99 | data_loader.dataset.ann_file, 100 | data_loader.dataset.ann_folder, 101 | output_dir=os.path.join(output_dir, 
"panoptic_eval"), 102 | ) 103 | 104 | print_freq = 100 105 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 106 | samples = samples.to(device) 107 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 108 | 109 | outputs = model(samples) 110 | loss_dict = criterion(outputs, targets) 111 | weight_dict = criterion.weight_dict 112 | 113 | # reduce losses over all GPUs for logging purposes 114 | loss_dict_reduced = utils.reduce_dict(loss_dict) 115 | loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict} 116 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()} 117 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), 118 | **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 119 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 120 | 121 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 122 | results = postprocessors['bbox'](outputs, orig_target_sizes) 123 | if 'segm' in postprocessors.keys(): 124 | target_sizes = torch.stack([t["size"] for t in targets], dim=0) 125 | results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) 126 | res = {target['image_id'].item(): output for target, output in zip(targets, results)} 127 | if coco_evaluator is not None: 128 | coco_evaluator.update(res) 129 | 130 | if panoptic_evaluator is not None: 131 | res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes) 132 | for i, target in enumerate(targets): 133 | image_id = target["image_id"].item() 134 | file_name = f"{image_id:012d}.png" 135 | res_pano[i]["image_id"] = image_id 136 | res_pano[i]["file_name"] = file_name 137 | panoptic_evaluator.update(res_pano) 138 | 139 | # gather the stats from all processes 140 | metric_logger.synchronize_between_processes() 141 | print("Averaged stats:", metric_logger) 142 | if coco_evaluator is not None: 143 | coco_evaluator.synchronize_between_processes() 144 | if panoptic_evaluator is not None: 145 | panoptic_evaluator.synchronize_between_processes() 146 | 147 | # accumulate predictions from all images 148 | if coco_evaluator is not None: 149 | coco_evaluator.accumulate() 150 | coco_evaluator.summarize() 151 | panoptic_res = None 152 | if panoptic_evaluator is not None: 153 | panoptic_res = panoptic_evaluator.summarize() 154 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 155 | if coco_evaluator is not None: 156 | if 'bbox' in postprocessors.keys(): 157 | stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() 158 | if 'segm' in postprocessors.keys(): 159 | stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist() 160 | if panoptic_res is not None: 161 | stats['PQ_all'] = panoptic_res["All"] 162 | stats['PQ_th'] = panoptic_res["Things"] 163 | stats['PQ_st'] = panoptic_res["Stuff"] 164 | 165 | del samples 166 | del targets 167 | del outputs 168 | del loss_dict 169 | del loss_dict_reduced 170 | del loss_dict_reduced_unscaled 171 | del weight_dict 172 | 173 | torch.cuda.empty_cache() 174 | 175 | return stats, coco_evaluator 176 | -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 
80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /datasets/transforms.py: 
-------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copied from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | """ 7 | Transforms and data augmentation for both image + bbox. 8 | """ 9 | import random 10 | 11 | import PIL 12 | import torch 13 | import torchvision.transforms as T 14 | import torchvision.transforms.functional as F 15 | 16 | from util.box_ops import box_xyxy_to_cxcywh 17 | from util.misc import interpolate 18 | 19 | 20 | def crop(image, target, region): 21 | cropped_image = F.crop(image, *region) 22 | 23 | target = target.copy() 24 | i, j, h, w = region 25 | 26 | # should we do something wrt the original size? 27 | target["size"] = torch.tensor([h, w]) 28 | 29 | fields = ["labels", "area", "iscrowd"] 30 | 31 | if "boxes" in target: 32 | boxes = target["boxes"] 33 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 34 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 35 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 36 | cropped_boxes = cropped_boxes.clamp(min=0) 37 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 38 | target["boxes"] = cropped_boxes.reshape(-1, 4) 39 | target["area"] = area 40 | fields.append("boxes") 41 | 42 | if "masks" in target: 43 | # FIXME should we update the area here if there are no boxes? 44 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 45 | fields.append("masks") 46 | 47 | # remove elements for which the boxes or masks that have zero area 48 | if "boxes" in target or "masks" in target: 49 | # favor boxes selection when defining which elements to keep 50 | # this is compatible with previous implementation 51 | if "boxes" in target: 52 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 53 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 54 | else: 55 | keep = target['masks'].flatten(1).any(1) 56 | 57 | for field in fields: 58 | target[field] = target[field][keep] 59 | 60 | return cropped_image, target 61 | 62 | 63 | def hflip(image, target): 64 | flipped_image = F.hflip(image) 65 | 66 | w, h = image.size 67 | 68 | target = target.copy() 69 | if "boxes" in target: 70 | boxes = target["boxes"] 71 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 72 | target["boxes"] = boxes 73 | 74 | if "masks" in target: 75 | target['masks'] = target['masks'].flip(-1) 76 | 77 | return flipped_image, target 78 | 79 | 80 | def resize(image, target, size, max_size=None): 81 | # size can be min_size (scalar) or (w, h) tuple 82 | 83 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 84 | w, h = image_size 85 | if max_size is not None: 86 | min_original_size = float(min((w, h))) 87 | max_original_size = float(max((w, h))) 88 | if max_original_size / min_original_size * size > max_size: 89 | size = int(round(max_size * min_original_size / max_original_size)) 90 | 91 | if (w <= h and w == size) or (h <= w and h == size): 92 | return (h, w) 93 | 94 | if w < h: 95 | ow = size 96 | oh = int(size * h / w) 97 | else: 98 | oh = size 99 | ow = int(size * w / h) 100 | 101 | return (oh, ow) 102 | 103 | def get_size(image_size, size, max_size=None): 104 | if isinstance(size, (list, tuple)): 105 | return size[::-1] 106 | else: 107 | return 
get_size_with_aspect_ratio(image_size, size, max_size) 108 | 109 | size = get_size(image.size, size, max_size) 110 | rescaled_image = F.resize(image, size) 111 | 112 | if target is None: 113 | return rescaled_image, None 114 | 115 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 116 | ratio_width, ratio_height = ratios 117 | 118 | target = target.copy() 119 | if "boxes" in target: 120 | boxes = target["boxes"] 121 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 122 | target["boxes"] = scaled_boxes 123 | 124 | if "area" in target: 125 | area = target["area"] 126 | scaled_area = area * (ratio_width * ratio_height) 127 | target["area"] = scaled_area 128 | 129 | h, w = size 130 | target["size"] = torch.tensor([h, w]) 131 | 132 | if "masks" in target: 133 | target['masks'] = interpolate( 134 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 135 | 136 | return rescaled_image, target 137 | 138 | 139 | def pad(image, target, padding): 140 | # assumes that we only pad on the bottom right corners 141 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 142 | if target is None: 143 | return padded_image, None 144 | target = target.copy() 145 | # should we do something wrt the original size? 146 | target["size"] = torch.tensor(padded_image.size[::-1]) 147 | if "masks" in target: 148 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 149 | return padded_image, target 150 | 151 | 152 | class RandomCrop(object): 153 | def __init__(self, size): 154 | self.size = size 155 | 156 | def __call__(self, img, target): 157 | region = T.RandomCrop.get_params(img, self.size) 158 | return crop(img, target, region) 159 | 160 | 161 | class RandomSizeCrop(object): 162 | def __init__(self, min_size: int, max_size: int): 163 | self.min_size = min_size 164 | self.max_size = max_size 165 | 166 | def __call__(self, img: PIL.Image.Image, target: dict): 167 | w = random.randint(self.min_size, min(img.width, self.max_size)) 168 | h = random.randint(self.min_size, min(img.height, self.max_size)) 169 | region = T.RandomCrop.get_params(img, [h, w]) 170 | return crop(img, target, region) 171 | 172 | 173 | class CenterCrop(object): 174 | def __init__(self, size): 175 | self.size = size 176 | 177 | def __call__(self, img, target): 178 | image_width, image_height = img.size 179 | crop_height, crop_width = self.size 180 | crop_top = int(round((image_height - crop_height) / 2.)) 181 | crop_left = int(round((image_width - crop_width) / 2.)) 182 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 183 | 184 | 185 | class RandomHorizontalFlip(object): 186 | def __init__(self, p=0.5): 187 | self.p = p 188 | 189 | def __call__(self, img, target): 190 | if random.random() < self.p: 191 | return hflip(img, target) 192 | return img, target 193 | 194 | 195 | class RandomResize(object): 196 | def __init__(self, sizes, max_size=None): 197 | assert isinstance(sizes, (list, tuple)) 198 | self.sizes = sizes 199 | self.max_size = max_size 200 | 201 | def __call__(self, img, target=None): 202 | size = random.choice(self.sizes) 203 | return resize(img, target, size, self.max_size) 204 | 205 | 206 | class RandomPad(object): 207 | def __init__(self, max_pad): 208 | self.max_pad = max_pad 209 | 210 | def __call__(self, img, target): 211 | pad_x = random.randint(0, self.max_pad) 212 | pad_y = random.randint(0, self.max_pad) 213 | return pad(img, target, (pad_x, pad_y)) 
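As a quick illustration of how these transform classes compose, here is a sketch mirroring the 'val' pipeline that make_coco_transforms builds in datasets/coco.py (Compose, ToTensor, and Normalize are defined further below in this file). The image path is a placeholder, and passing target=None keeps this a pure image-preprocessing example.

from PIL import Image
import datasets.transforms as T

# Validation-style preprocessing, equivalent to make_coco_transforms('val'):
preprocess = T.Compose([
    T.RandomResize([800], max_size=1333),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

img = Image.open('path/to/image.jpg').convert('RGB')  # placeholder path
img_tensor, _ = preprocess(img, None)                 # these transforms accept target=None
print(img_tensor.shape)  # (3, H, W): shorter side resized to 800, longer side capped at 1333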
214 | 215 | 216 | class RandomSelect(object): 217 | """ 218 | Randomly selects between transforms1 and transforms2, 219 | with probability p for transforms1 and (1 - p) for transforms2 220 | """ 221 | def __init__(self, transforms1, transforms2, p=0.5): 222 | self.transforms1 = transforms1 223 | self.transforms2 = transforms2 224 | self.p = p 225 | 226 | def __call__(self, img, target): 227 | if random.random() < self.p: 228 | return self.transforms1(img, target) 229 | return self.transforms2(img, target) 230 | 231 | 232 | class ToTensor(object): 233 | def __call__(self, img, target): 234 | return F.to_tensor(img), target 235 | 236 | 237 | class RandomErasing(object): 238 | 239 | def __init__(self, *args, **kwargs): 240 | self.eraser = T.RandomErasing(*args, **kwargs) 241 | 242 | def __call__(self, img, target): 243 | return self.eraser(img), target 244 | 245 | 246 | class Normalize(object): 247 | def __init__(self, mean, std): 248 | self.mean = mean 249 | self.std = std 250 | 251 | def __call__(self, image, target=None): 252 | image = F.normalize(image, mean=self.mean, std=self.std) 253 | if target is None: 254 | return image, None 255 | target = target.copy() 256 | h, w = image.shape[-2:] 257 | if "boxes" in target: 258 | boxes = target["boxes"] 259 | boxes = box_xyxy_to_cxcywh(boxes) 260 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 261 | target["boxes"] = boxes 262 | return image, target 263 | 264 | 265 | class Compose(object): 266 | def __init__(self, transforms): 267 | self.transforms = transforms 268 | 269 | def __call__(self, image, target): 270 | for t in self.transforms: 271 | image, target = t(image, target) 272 | return image, target 273 | 274 | def __repr__(self): 275 | format_string = self.__class__.__name__ + "(" 276 | for t in self.transforms: 277 | format_string += "\n" 278 | format_string += " {0}".format(t) 279 | format_string += "\n)" 280 | return format_string 281 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from models.transformer_encoder import TransformerEncoder, TransformerEncoderLayer 10 | from models.transformer_decoder import TransformerDecoder, TransformerDecoderLayer 11 | 12 | 13 | class Transformer(nn.Module): 14 | def __init__(self, args, activation="relu"): 15 | super().__init__() 16 | self.args = args 17 | self.multiscale = args.multiscale 18 | self.d_model = args.hidden_dim 19 | self.nheads = args.nheads 20 | self.num_queries = args.num_queries 21 | self.enc_layers = args.enc_layers 22 | self.dec_layers = args.dec_layers 23 | self.dim_feedforward = args.dim_feedforward 24 | self.dropout = args.dropout 25 | 26 | if self.multiscale: 27 | # Reminder: To use multiscale SAM-DETR, you need to compile CUDA operators for Deformable Attention. 
28 | from models.transformer_encoder import DeformableTransformerEncoder, DeformableTransformerEncoderLayer 29 | self.num_feature_levels = 3 # Hard-coded multiscale parameters 30 | # Use Deformable Attention in Transformer Encoder for efficient computation of multiscale features 31 | encoder_layer = DeformableTransformerEncoderLayer(args, activation) 32 | self.encoder = DeformableTransformerEncoder(args, encoder_layer, self.enc_layers) 33 | self.level_embed = nn.Parameter(torch.Tensor(self.num_feature_levels, self.d_model)) 34 | else: 35 | encoder_layer = TransformerEncoderLayer(args, activation) 36 | self.encoder = TransformerEncoder(args, encoder_layer, self.enc_layers) 37 | 38 | decoder_layer = TransformerDecoderLayer(args, activation) 39 | self.decoder = TransformerDecoder(args, decoder_layer, self.dec_layers) 40 | 41 | self._reset_parameters() 42 | 43 | def _reset_parameters(self): 44 | for p in self.parameters(): 45 | if p.dim() > 1: 46 | nn.init.xavier_uniform_(p) 47 | if self.multiscale: 48 | from models.ops.modules import MSDeformAttn 49 | for m in self.modules(): 50 | if isinstance(m, MSDeformAttn): 51 | m._reset_parameters() 52 | nn.init.normal_(self.level_embed) 53 | 54 | def get_valid_ratio(self, mask): 55 | _, H, W = mask.shape 56 | valid_H = torch.sum(~mask[:, :, 0], 1) 57 | valid_W = torch.sum(~mask[:, 0, :], 1) 58 | valid_ratio_h = valid_H.float() / H 59 | valid_ratio_w = valid_W.float() / W 60 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 61 | return valid_ratio 62 | 63 | def forward(self, srcs, masks, query_embed, pos_embeds): 64 | if self.multiscale: 65 | return self.forward_multi_scale(srcs, masks, query_embed, pos_embeds) 66 | else: 67 | return self.forward_single_scale(srcs[0], masks[0], query_embed, pos_embeds[0]) 68 | 69 | def forward_single_scale(self, src, mask, query_embed, pos): 70 | bs, c, memory_h, memory_w = src.shape 71 | 72 | if self.args.smca: 73 | grid_y, grid_x = torch.meshgrid(torch.arange(0, memory_h), torch.arange(0, memory_w)) 74 | grid = torch.stack((grid_x, grid_y), 2).float().to(src.device) 75 | grid = grid.reshape(-1, 2).unsqueeze(1).repeat(1, bs * self.nheads, 1) 76 | else: 77 | grid = None 78 | 79 | src = src.flatten(2).permute(2, 0, 1) # flatten NxCxHxW to HWxNxC 80 | pos = pos.flatten(2).permute(2, 0, 1) # flatten NxCxHxW to HWxNxC 81 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 82 | mask = mask.flatten(1) 83 | 84 | tgt = torch.zeros(self.num_queries, bs, c, device=query_embed.device) 85 | 86 | # encoder 87 | memory = self.encoder(src, 88 | src_key_padding_mask=mask, 89 | pos=pos) 90 | 91 | # decoder 92 | hs, references = self.decoder(tgt, 93 | memory, 94 | memory_key_padding_mask=mask, 95 | pos=pos, 96 | query_pos=query_embed, 97 | memory_h=memory_h, 98 | memory_w=memory_w, 99 | grid=grid) 100 | return hs, references 101 | 102 | def forward_multi_scale(self, srcs, masks, query_embed, pos_embeds): 103 | 104 | bs, c, h_16, w_16 = srcs[0].shape 105 | bs, c, h_32, w_32 = srcs[1].shape 106 | bs, c, h_64, w_64 = srcs[2].shape 107 | 108 | src_16 = srcs[0].flatten(2).permute(2, 0, 1) 109 | orig_pos_embed_16 = pos_embeds[0] + self.level_embed[0].unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 110 | pos_embed_16 = orig_pos_embed_16.flatten(2).permute(2, 0, 1) 111 | mask_16 = masks[0].flatten(1) 112 | 113 | src_32 = srcs[1].flatten(2).permute(2, 0, 1) 114 | orig_pos_embed_32 = pos_embeds[1] + self.level_embed[1].unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 115 | pos_embed_32 = orig_pos_embed_32.flatten(2).permute(2, 0, 1) 116 | 
mask_32 = masks[1].flatten(1) 117 | 118 | src_64 = srcs[2].flatten(2).permute(2, 0, 1) 119 | orig_pos_embed_64 = pos_embeds[2] + self.level_embed[2].unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 120 | pos_embed_64 = orig_pos_embed_64.flatten(2).permute(2, 0, 1) 121 | mask_64 = masks[2].flatten(1) 122 | 123 | # prepare input for encoder 124 | src_flatten = [] 125 | mask_flatten = [] 126 | lvl_pos_embed_flatten = [] 127 | spatial_shapes = [] 128 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): 129 | bs, _, h, w = src.shape 130 | spatial_shape = (h, w) 131 | spatial_shapes.append(spatial_shape) 132 | src = src.flatten(2).transpose(1, 2) 133 | mask = mask.flatten(1) 134 | pos_embed = pos_embed.flatten(2).transpose(1, 2) 135 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 136 | lvl_pos_embed_flatten.append(lvl_pos_embed) 137 | src_flatten.append(src) 138 | mask_flatten.append(mask) 139 | src_flatten = torch.cat(src_flatten, 1) 140 | mask_flatten = torch.cat(mask_flatten, 1) 141 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 142 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) 143 | level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) 144 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 145 | 146 | # encoder 147 | memory = self.encoder(src_flatten, 148 | spatial_shapes, 149 | level_start_index, 150 | valid_ratios, 151 | lvl_pos_embed_flatten, 152 | mask_flatten) 153 | 154 | # prepare input for decoder 155 | tgt = torch.zeros(self.num_queries, bs, self.d_model, device=query_embed.device) 156 | mem1 = memory[:, level_start_index[0]: level_start_index[1], :] 157 | mem2 = memory[:, level_start_index[1]: level_start_index[2], :] 158 | mem3 = memory[:, level_start_index[2]:, :] 159 | memory_flatten = [] 160 | for m in [mem1, mem2, mem3]: 161 | memory_flatten.append(m.permute(1, 0, 2)) 162 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 163 | 164 | grid_y_16, grid_x_16 = torch.meshgrid(torch.arange(0, h_16), torch.arange(0, w_16)) 165 | grid_16 = torch.stack((grid_x_16, grid_y_16), 2).float() 166 | grid_16.requires_grad = False 167 | grid_16 = grid_16.type_as(srcs[0]) 168 | grid_16 = grid_16.unsqueeze(0).permute(0, 3, 1, 2).flatten(2).permute(2, 0, 1) 169 | grid_16 = grid_16.repeat(1, bs * 8, 1) 170 | 171 | grid_y_32, grid_x_32 = torch.meshgrid(torch.arange(0, h_32), torch.arange(0, w_32)) 172 | grid_32 = torch.stack((grid_x_32, grid_y_32), 2).float() 173 | grid_32.requires_grad = False 174 | grid_32 = grid_32.type_as(srcs[0]) 175 | grid_32 = grid_32.unsqueeze(0).permute(0, 3, 1, 2).flatten(2).permute(2, 0, 1) 176 | grid_32 = grid_32.repeat(1, bs * 8, 1) 177 | 178 | grid_y_64, grid_x_64 = torch.meshgrid(torch.arange(0, h_64), torch.arange(0, w_64)) 179 | grid_64 = torch.stack((grid_x_64, grid_y_64), 2).float() 180 | grid_64.requires_grad = False 181 | grid_64 = grid_64.type_as(srcs[0]) 182 | grid_64 = grid_64.unsqueeze(0).permute(0, 3, 1, 2).flatten(2).permute(2, 0, 1) 183 | grid_64 = grid_64.repeat(1, bs * 8, 1) 184 | 185 | # decoder 186 | hs, references = self.decoder(tgt, 187 | [memory_flatten[0], memory_flatten[1], memory_flatten[2]], 188 | memory_key_padding_mask=[mask_16, mask_32, mask_64], 189 | pos=[pos_embed_16, pos_embed_32, pos_embed_64], 190 | query_pos=query_embed, 191 | memory_h=[h_16, h_32, h_64], 192 | memory_w=[w_16, w_32, w_64], 193 | grid=[grid_16, grid_32, grid_64]) 194 | 195 | return hs, 
references 196 | 197 | 198 | def build_transformer(args): 199 | return Transformer(args, activation="relu") 200 | -------------------------------------------------------------------------------- /datasets/coco_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copied from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | """ 7 | COCO evaluator that works in distributed mode. 8 | 9 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 10 | The difference is that there is less copy-pasting from pycocotools 11 | in the end of the file, as python3 can suppress prints with contextlib 12 | """ 13 | import os 14 | import contextlib 15 | import copy 16 | import numpy as np 17 | import torch 18 | 19 | from pycocotools.cocoeval import COCOeval 20 | from pycocotools.coco import COCO 21 | import pycocotools.mask as mask_util 22 | 23 | from util.misc import all_gather 24 | 25 | 26 | class CocoEvaluator(object): 27 | def __init__(self, coco_gt, iou_types): 28 | assert isinstance(iou_types, (list, tuple)) 29 | coco_gt = copy.deepcopy(coco_gt) 30 | self.coco_gt = coco_gt 31 | 32 | self.iou_types = iou_types 33 | self.coco_eval = {} 34 | for iou_type in iou_types: 35 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 36 | 37 | self.img_ids = [] 38 | self.eval_imgs = {k: [] for k in iou_types} 39 | 40 | def update(self, predictions): 41 | img_ids = list(np.unique(list(predictions.keys()))) 42 | self.img_ids.extend(img_ids) 43 | 44 | for iou_type in self.iou_types: 45 | results = self.prepare(predictions, iou_type) 46 | 47 | # suppress pycocotools prints 48 | with open(os.devnull, 'w') as devnull: 49 | with contextlib.redirect_stdout(devnull): 50 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 51 | coco_eval = self.coco_eval[iou_type] 52 | 53 | coco_eval.cocoDt = coco_dt 54 | coco_eval.params.imgIds = list(img_ids) 55 | img_ids, eval_imgs = evaluate(coco_eval) 56 | 57 | self.eval_imgs[iou_type].append(eval_imgs) 58 | 59 | def synchronize_between_processes(self): 60 | for iou_type in self.iou_types: 61 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 62 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 63 | 64 | def accumulate(self): 65 | for coco_eval in self.coco_eval.values(): 66 | coco_eval.accumulate() 67 | 68 | def summarize(self): 69 | for iou_type, coco_eval in self.coco_eval.items(): 70 | print("IoU metric: {}".format(iou_type)) 71 | coco_eval.summarize() 72 | 73 | def prepare(self, predictions, iou_type): 74 | if iou_type == "bbox": 75 | return self.prepare_for_coco_detection(predictions) 76 | elif iou_type == "segm": 77 | return self.prepare_for_coco_segmentation(predictions) 78 | elif iou_type == "keypoints": 79 | return self.prepare_for_coco_keypoint(predictions) 80 | else: 81 | raise ValueError("Unknown iou type {}".format(iou_type)) 82 | 83 | def prepare_for_coco_detection(self, predictions): 84 | coco_results = [] 85 | for original_id, prediction in predictions.items(): 86 | if len(prediction) == 0: 87 | continue 88 | 89 | boxes = prediction["boxes"] 90 | boxes = convert_to_xywh(boxes).tolist() 91 | scores = prediction["scores"].tolist() 92 | labels = 
prediction["labels"].tolist() 93 | 94 | coco_results.extend( 95 | [ 96 | { 97 | "image_id": original_id, 98 | "category_id": labels[k], 99 | "bbox": box, 100 | "score": scores[k], 101 | } 102 | for k, box in enumerate(boxes) 103 | ] 104 | ) 105 | return coco_results 106 | 107 | def prepare_for_coco_segmentation(self, predictions): 108 | coco_results = [] 109 | for original_id, prediction in predictions.items(): 110 | if len(prediction) == 0: 111 | continue 112 | 113 | scores = prediction["scores"] 114 | labels = prediction["labels"] 115 | masks = prediction["masks"] 116 | 117 | masks = masks > 0.5 118 | 119 | scores = prediction["scores"].tolist() 120 | labels = prediction["labels"].tolist() 121 | 122 | rles = [ 123 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 124 | for mask in masks 125 | ] 126 | for rle in rles: 127 | rle["counts"] = rle["counts"].decode("utf-8") 128 | 129 | coco_results.extend( 130 | [ 131 | { 132 | "image_id": original_id, 133 | "category_id": labels[k], 134 | "segmentation": rle, 135 | "score": scores[k], 136 | } 137 | for k, rle in enumerate(rles) 138 | ] 139 | ) 140 | return coco_results 141 | 142 | def prepare_for_coco_keypoint(self, predictions): 143 | coco_results = [] 144 | for original_id, prediction in predictions.items(): 145 | if len(prediction) == 0: 146 | continue 147 | 148 | boxes = prediction["boxes"] 149 | boxes = convert_to_xywh(boxes).tolist() 150 | scores = prediction["scores"].tolist() 151 | labels = prediction["labels"].tolist() 152 | keypoints = prediction["keypoints"] 153 | keypoints = keypoints.flatten(start_dim=1).tolist() 154 | 155 | coco_results.extend( 156 | [ 157 | { 158 | "image_id": original_id, 159 | "category_id": labels[k], 160 | 'keypoints': keypoint, 161 | "score": scores[k], 162 | } 163 | for k, keypoint in enumerate(keypoints) 164 | ] 165 | ) 166 | return coco_results 167 | 168 | 169 | def convert_to_xywh(boxes): 170 | xmin, ymin, xmax, ymax = boxes.unbind(1) 171 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 172 | 173 | 174 | def merge(img_ids, eval_imgs): 175 | all_img_ids = all_gather(img_ids) 176 | all_eval_imgs = all_gather(eval_imgs) 177 | 178 | merged_img_ids = [] 179 | for p in all_img_ids: 180 | merged_img_ids.extend(p) 181 | 182 | merged_eval_imgs = [] 183 | for p in all_eval_imgs: 184 | merged_eval_imgs.append(p) 185 | 186 | merged_img_ids = np.array(merged_img_ids) 187 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 188 | 189 | # keep only unique (and in sorted order) images 190 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 191 | merged_eval_imgs = merged_eval_imgs[..., idx] 192 | 193 | return merged_img_ids, merged_eval_imgs 194 | 195 | 196 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 197 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 198 | img_ids = list(img_ids) 199 | eval_imgs = list(eval_imgs.flatten()) 200 | 201 | coco_eval.evalImgs = eval_imgs 202 | coco_eval.params.imgIds = img_ids 203 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 204 | 205 | 206 | ################################################################# 207 | # From pycocotools, just removed the prints and fixed 208 | # a Python3 bug about unicode not defined 209 | ################################################################# 210 | 211 | 212 | def evaluate(self): 213 | ''' 214 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 215 | :return: None 216 | ''' 217 | # tic = time.time() 
218 | # print('Running per image evaluation...') 219 | p = self.params 220 | # add backward compatibility if useSegm is specified in params 221 | if p.useSegm is not None: 222 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 223 | print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) 224 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 225 | p.imgIds = list(np.unique(p.imgIds)) 226 | if p.useCats: 227 | p.catIds = list(np.unique(p.catIds)) 228 | p.maxDets = sorted(p.maxDets) 229 | self.params = p 230 | 231 | self._prepare() 232 | # loop through images, area range, max detection number 233 | catIds = p.catIds if p.useCats else [-1] 234 | 235 | if p.iouType == 'segm' or p.iouType == 'bbox': 236 | computeIoU = self.computeIoU 237 | elif p.iouType == 'keypoints': 238 | computeIoU = self.computeOks 239 | self.ious = { 240 | (imgId, catId): computeIoU(imgId, catId) 241 | for imgId in p.imgIds 242 | for catId in catIds} 243 | 244 | evaluateImg = self.evaluateImg 245 | maxDet = p.maxDets[-1] 246 | evalImgs = [ 247 | evaluateImg(imgId, catId, areaRng, maxDet) 248 | for catId in catIds 249 | for areaRng in p.areaRng 250 | for imgId in p.imgIds 251 | ] 252 | # this is NOT in the pycocotools code, but could be done outside 253 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 254 | self._paramsEval = copy.deepcopy(self.params) 255 | # toc = time.time() 256 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 257 | return p.imgIds, evalImgs 258 | 259 | ################################################################# 260 | # end of straight copy from pycocotools, just removing the prints 261 | ################################################################# 262 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | 6 | import os 7 | import argparse 8 | import random 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | from PIL import Image 13 | import torch 14 | from torch.utils.data import DataLoader, DistributedSampler 15 | 16 | import datasets 17 | import util.misc as utils 18 | from datasets import build_dataset, get_coco_api_from_dataset 19 | from datasets.coco import make_coco_transforms 20 | from models import build_model 21 | 22 | 23 | def get_args_parser(): 24 | parser = argparse.ArgumentParser('SAM-DETR: Accelerating DETR Convergence via Semantic-Aligned Matching', add_help=False) 25 | parser.add_argument('--lr', default=1e-4, type=float) 26 | parser.add_argument('--lr_backbone', default=1e-5, type=float) 27 | parser.add_argument('--lr_linear_proj_names', default=[], type=str, nargs='+') 28 | parser.add_argument('--lr_linear_proj_mult', default=0.1, type=float) 29 | parser.add_argument('--batch_size', default=1, type=int) 30 | parser.add_argument('--weight_decay', default=1e-4, type=float) 31 | parser.add_argument('--epochs', default=50, type=int) 32 | parser.add_argument('--lr_drop', default=40, type=int) 33 | parser.add_argument('--clip_max_norm', default=0.1, type=float, help='gradient clipping max norm') 34 | 35 | # Model parameters 36 | parser.add_argument('--frozen_weights', type=str, default=None, 37 | help="Path to the pretrained model. If set, only the mask head will be trained") 38 | parser.add_argument('--multiscale', default=False, action='store_true') 39 | # * Backbone 40 | parser.add_argument('--backbone', default='resnet50', type=str, 41 | help="Name of the convolutional backbone to use") 42 | parser.add_argument('--dilation', action='store_true', 43 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 44 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine',), 45 | help="Type of positional embedding to use on top of the image features") 46 | 47 | # * Transformer 48 | parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer") 49 | parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer") 50 | parser.add_argument('--dim_feedforward', default=2048, type=int, help="dimension of the FFN in the transformer") 51 | parser.add_argument('--hidden_dim', default=256, type=int, help="dimension of the transformer") 52 | parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer") 53 | parser.add_argument('--nheads', default=8, type=int, help="Number of attention heads in the transformer attention") 54 | parser.add_argument('--num_queries', default=300, type=int, help="Number of query slots") 55 | 56 | parser.add_argument('--smca', default=False, action='store_true') 57 | 58 | # * Segmentation 59 | parser.add_argument('--masks', action='store_true', help="Train segmentation head if the flag is provided") 60 | 61 | # Loss 62 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 63 | help="Disables auxiliary decoding losses (loss at each layer)") 64 | 65 | # * Matcher 66 | parser.add_argument('--set_cost_class', default=2.0, type=float, help="Class coefficient in the matching cost") 67 | parser.add_argument('--set_cost_bbox', default=5.0, type=float, help="L1 box coefficient in the matching cost") 68 | parser.add_argument('--set_cost_giou', default=2.0, type=float, 
help="giou box coefficient in the matching cost") 69 | 70 | # * Loss coefficients 71 | parser.add_argument('--mask_loss_coef', default=1.0, type=float) 72 | parser.add_argument('--dice_loss_coef', default=1.0, type=float) 73 | parser.add_argument('--cls_loss_coef', default=2.0, type=float) 74 | parser.add_argument('--bbox_loss_coef', default=5.0, type=float) 75 | parser.add_argument('--giou_loss_coef', default=2.0, type=float) 76 | parser.add_argument('--focal_alpha', default=0.25, type=float) 77 | 78 | # dataset parameters 79 | parser.add_argument('--dataset_file', default='coco') 80 | parser.add_argument('--coco_path', type=str, default='data/coco') 81 | parser.add_argument('--coco_panoptic_path', type=str) 82 | parser.add_argument('--remove_difficult', action='store_true') 83 | 84 | parser.add_argument('--output_dir', default='', help='path where to save, empty for no saving') 85 | parser.add_argument('--device', default='cuda', help='device to use for training / testing. We must use cuda.') 86 | parser.add_argument('--seed', default=42, type=int) 87 | parser.add_argument('--resume', default='', help='resume from checkpoint, empty for training from scratch') 88 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch') 89 | parser.add_argument('--eval', action='store_true') 90 | parser.add_argument('--eval_every_epoch', default=1, type=int, help='eval every ? epoch') 91 | parser.add_argument('--save_every_epoch', default=1, type=int, help='save model weights every ? epoch') 92 | parser.add_argument('--num_workers', default=2, type=int) 93 | 94 | # distributed training parameters 95 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 96 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 97 | 98 | return parser 99 | 100 | 101 | def main(args): 102 | 103 | utils.init_distributed_mode(args) 104 | 105 | if args.frozen_weights is not None: 106 | assert args.masks, "Frozen training is meant for segmentation only." 
107 | print(args) 108 | 109 | device = torch.device(args.device) 110 | 111 | # fix the seed for reproducibility 112 | seed = args.seed + utils.get_rank() 113 | torch.manual_seed(seed) 114 | np.random.seed(seed) 115 | random.seed(seed) 116 | 117 | model, criterion, post_processors = build_model(args) 118 | model.to(device) 119 | 120 | model_without_ddp = model 121 | if args.distributed: 122 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 123 | model_without_ddp = model.module 124 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 125 | print('Total number of params in model: ', n_parameters) 126 | 127 | def match_keywords(n, name_keywords): 128 | out = False 129 | for b in name_keywords: 130 | if b in n: 131 | out = True 132 | break 133 | return out 134 | 135 | param_dicts = [ 136 | { 137 | "params": 138 | [p for n, p in model_without_ddp.named_parameters() 139 | if "backbone.0" not in n and not match_keywords(n, args.lr_linear_proj_names) and p.requires_grad], 140 | "lr": args.lr, 141 | }, 142 | { 143 | "params": [p for n, p in model_without_ddp.named_parameters() 144 | if "backbone.0" in n and p.requires_grad], 145 | "lr": args.lr_backbone, 146 | }, 147 | { 148 | "params": [p for n, p in model_without_ddp.named_parameters() 149 | if match_keywords(n, args.lr_linear_proj_names) and p.requires_grad], 150 | "lr": args.lr * args.lr_linear_proj_mult, 151 | } 152 | ] 153 | 154 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) 155 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 156 | 157 | # dataset_train = build_dataset(image_set='train', args=args) 158 | # dataset_val = build_dataset(image_set='val', args=args) 159 | # 160 | # if args.distributed: 161 | # sampler_train = DistributedSampler(dataset_train) 162 | # sampler_val = DistributedSampler(dataset_val, shuffle=False) 163 | # else: 164 | # sampler_train = torch.utils.data.RandomSampler(dataset_train) 165 | # sampler_val = torch.utils.data.SequentialSampler(dataset_val) 166 | # 167 | # batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) 168 | 169 | # data_loader_train = DataLoader(dataset_train, 170 | # batch_sampler=batch_sampler_train, 171 | # collate_fn=utils.collate_fn, 172 | # num_workers=args.num_workers) 173 | # 174 | # data_loader_val = DataLoader(dataset_val, 175 | # args.batch_size, 176 | # sampler=sampler_val, 177 | # drop_last=False, 178 | # collate_fn=utils.collate_fn, 179 | # num_workers=args.num_workers) 180 | # 181 | # if args.dataset_file == "coco_panoptic": 182 | # # We also evaluate AP during panoptic training, on original coco DS 183 | # coco_val = datasets.coco.build("val", args) 184 | # base_ds = get_coco_api_from_dataset(coco_val) 185 | # else: 186 | # base_ds = get_coco_api_from_dataset(dataset_val) 187 | 188 | if args.frozen_weights is not None: 189 | checkpoint = torch.load(args.frozen_weights, map_location='cpu') 190 | model_without_ddp.detr.load_state_dict(checkpoint['model']) 191 | 192 | output_dir = Path(args.output_dir) 193 | if args.resume: 194 | if args.resume.startswith('https'): 195 | checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) 196 | else: 197 | checkpoint = torch.load(args.resume, map_location='cpu') 198 | model_without_ddp.load_state_dict(checkpoint['model']) 199 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 
'epoch' in checkpoint: 200 | optimizer.load_state_dict(checkpoint['optimizer']) 201 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 202 | args.start_epoch = checkpoint['epoch'] + 1 203 | 204 | transforms = make_coco_transforms("val") 205 | DETECTION_THRESHOLD = 0.5 206 | inference_dir = "./images/" 207 | image_dirs = os.listdir(inference_dir) 208 | image_dirs = [filename for filename in image_dirs if filename.endswith(".jpg") and 'det_res' not in filename] 209 | model.eval() 210 | with torch.no_grad(): 211 | for image_dir in image_dirs: 212 | img = Image.open(os.path.join(inference_dir, image_dir)).convert("RGB") 213 | w, h = img.size 214 | orig_target_sizes = torch.tensor([[h, w]], device=device) 215 | img, _ = transforms(img, target=None) 216 | img = img.to(device) 217 | img = img.unsqueeze(0) # adding batch dimension 218 | outputs = model(img) 219 | results = post_processors['bbox'](outputs, orig_target_sizes)[0] 220 | indexes = results['scores'] >= DETECTION_THRESHOLD 221 | scores = results['scores'][indexes] 222 | labels = results['labels'][indexes] 223 | boxes = results['boxes'][indexes] 224 | 225 | # Visualize the detection results 226 | import cv2 227 | img_det_result = cv2.imread(os.path.join(inference_dir, image_dir)) 228 | for i in range(scores.shape[0]): 229 | x1, y1, x2, y2 = round(float(boxes[i, 0])), round(float(boxes[i, 1])), round(float(boxes[i, 2])), round(float(boxes[i, 3])) 230 | img_det_result = cv2.rectangle(img_det_result, (x1, y1), (x2, y2), (0, 0, 255), 2) 231 | cv2.imwrite(os.path.join(inference_dir, "det_res_" + image_dir), img_det_result) 232 | 233 | 234 | if __name__ == '__main__': 235 | parser = argparse.ArgumentParser("SAM-DETR", parents=[get_args_parser()]) 236 | args = parser.parse_args() 237 | if args.output_dir: 238 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 239 | main(args) 240 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | 6 | import argparse 7 | import datetime 8 | import json 9 | import random 10 | import time 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import torch 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | 17 | import datasets 18 | import util.misc as utils 19 | from datasets import build_dataset, get_coco_api_from_dataset 20 | from engine import evaluate, train_one_epoch 21 | from models import build_model 22 | 23 | 24 | def get_args_parser(): 25 | parser = argparse.ArgumentParser('SAM-DETR: Accelerating DETR Convergence via Semantic-Aligned Matching', add_help=False) 26 | parser.add_argument('--lr', default=1e-4, type=float) 27 | parser.add_argument('--lr_backbone', default=1e-5, type=float) 28 | parser.add_argument('--lr_linear_proj_names', default=[], type=str, nargs='+') 29 | parser.add_argument('--lr_linear_proj_mult', default=0.1, type=float) 30 | parser.add_argument('--batch_size', default=2, type=int) 31 | parser.add_argument('--weight_decay', default=1e-4, type=float) 32 | parser.add_argument('--epochs', default=50, type=int) 33 | parser.add_argument('--lr_drop', default=40, type=int) 34 | parser.add_argument('--clip_max_norm', default=0.1, type=float, help='gradient clipping max norm') 35 | 36 | # Model parameters 37 | parser.add_argument('--frozen_weights', type=str, default=None, 38 | help="Path to the pretrained model. If set, only the mask head will be trained") 39 | parser.add_argument('--multiscale', default=False, action='store_true') 40 | # * Backbone 41 | parser.add_argument('--backbone', default='resnet50', type=str, 42 | help="Name of the convolutional backbone to use") 43 | parser.add_argument('--dilation', action='store_true', 44 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 45 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine',), 46 | help="Type of positional embedding to use on top of the image features") 47 | 48 | # * Transformer 49 | parser.add_argument('--enc_layers', default=6, type=int, help="Number of encoding layers in the transformer") 50 | parser.add_argument('--dec_layers', default=6, type=int, help="Number of decoding layers in the transformer") 51 | parser.add_argument('--dim_feedforward', default=2048, type=int, help="dimension of the FFN in the transformer") 52 | parser.add_argument('--hidden_dim', default=256, type=int, help="dimension of the transformer") 53 | parser.add_argument('--dropout', default=0.1, type=float, help="Dropout applied in the transformer") 54 | parser.add_argument('--nheads', default=8, type=int, help="Number of attention heads in the transformer attention") 55 | parser.add_argument('--num_queries', default=300, type=int, help="Number of query slots") 56 | 57 | parser.add_argument('--smca', default=False, action='store_true') 58 | 59 | # * Segmentation 60 | parser.add_argument('--masks', action='store_true', help="Train segmentation head if the flag is provided") 61 | 62 | # Loss 63 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 64 | help="Disables auxiliary decoding losses (loss at each layer)") 65 | 66 | # * Matcher 67 | parser.add_argument('--set_cost_class', default=2.0, type=float, help="Class coefficient in the matching cost") 68 | parser.add_argument('--set_cost_bbox', default=5.0, type=float, help="L1 box coefficient in the matching cost") 69 | parser.add_argument('--set_cost_giou', default=2.0, 
type=float, help="giou box coefficient in the matching cost") 70 | 71 | # * Loss coefficients 72 | parser.add_argument('--mask_loss_coef', default=1.0, type=float) 73 | parser.add_argument('--dice_loss_coef', default=1.0, type=float) 74 | parser.add_argument('--cls_loss_coef', default=2.0, type=float) 75 | parser.add_argument('--bbox_loss_coef', default=5.0, type=float) 76 | parser.add_argument('--giou_loss_coef', default=2.0, type=float) 77 | parser.add_argument('--focal_alpha', default=0.25, type=float) 78 | 79 | # dataset parameters 80 | parser.add_argument('--dataset_file', default='coco') 81 | parser.add_argument('--coco_path', type=str, default='data/coco') 82 | parser.add_argument('--coco_panoptic_path', type=str) 83 | parser.add_argument('--remove_difficult', action='store_true') 84 | 85 | parser.add_argument('--output_dir', default='', help='path where to save, empty for no saving') 86 | parser.add_argument('--device', default='cuda', help='device to use for training / testing. We must use cuda.') 87 | parser.add_argument('--seed', default=42, type=int) 88 | parser.add_argument('--resume', default='', help='resume from checkpoint, empty for training from scratch') 89 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch') 90 | parser.add_argument('--eval', action='store_true') 91 | parser.add_argument('--eval_every_epoch', default=1, type=int, help='eval every ? epoch') 92 | parser.add_argument('--save_every_epoch', default=1, type=int, help='save model weights every ? epoch') 93 | parser.add_argument('--num_workers', default=2, type=int) 94 | 95 | # distributed training parameters 96 | parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes') 97 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 98 | 99 | return parser 100 | 101 | 102 | def main(args): 103 | utils.init_distributed_mode(args) 104 | 105 | if args.frozen_weights is not None: 106 | assert args.masks, "Frozen training is meant for segmentation only." 
107 | print(args) 108 | 109 | device = torch.device(args.device) 110 | 111 | # fix the seed for reproducibility 112 | seed = args.seed + utils.get_rank() 113 | torch.manual_seed(seed) 114 | np.random.seed(seed) 115 | random.seed(seed) 116 | 117 | model, criterion, post_processors = build_model(args) 118 | model.to(device) 119 | 120 | model_without_ddp = model 121 | if args.distributed: 122 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=False) 123 | model_without_ddp = model.module 124 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 125 | print('Total number of params in model: ', n_parameters) 126 | 127 | def match_keywords(n, name_keywords): 128 | out = False 129 | for b in name_keywords: 130 | if b in n: 131 | out = True 132 | break 133 | return out 134 | 135 | param_dicts = [ 136 | { 137 | "params": 138 | [p for n, p in model_without_ddp.named_parameters() 139 | if "backbone.0" not in n and not match_keywords(n, args.lr_linear_proj_names) and p.requires_grad], 140 | "lr": args.lr, 141 | }, 142 | { 143 | "params": [p for n, p in model_without_ddp.named_parameters() 144 | if "backbone.0" in n and p.requires_grad], 145 | "lr": args.lr_backbone, 146 | }, 147 | { 148 | "params": [p for n, p in model_without_ddp.named_parameters() 149 | if match_keywords(n, args.lr_linear_proj_names) and p.requires_grad], 150 | "lr": args.lr * args.lr_linear_proj_mult, 151 | } 152 | ] 153 | 154 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) 155 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 156 | 157 | dataset_train = build_dataset(image_set='train', args=args) 158 | dataset_val = build_dataset(image_set='val', args=args) 159 | 160 | if args.distributed: 161 | sampler_train = DistributedSampler(dataset_train) 162 | sampler_val = DistributedSampler(dataset_val, shuffle=False) 163 | else: 164 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 165 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 166 | 167 | batch_sampler_train = torch.utils.data.BatchSampler(sampler_train, args.batch_size, drop_last=True) 168 | 169 | data_loader_train = DataLoader(dataset_train, 170 | batch_sampler=batch_sampler_train, 171 | collate_fn=utils.collate_fn, 172 | num_workers=args.num_workers) 173 | 174 | data_loader_val = DataLoader(dataset_val, 175 | args.batch_size, 176 | sampler=sampler_val, 177 | drop_last=False, 178 | collate_fn=utils.collate_fn, 179 | num_workers=args.num_workers) 180 | 181 | if args.dataset_file == "coco_panoptic": 182 | # We also evaluate AP during panoptic training, on original coco DS 183 | coco_val = datasets.coco.build("val", args) 184 | base_ds = get_coco_api_from_dataset(coco_val) 185 | else: 186 | base_ds = get_coco_api_from_dataset(dataset_val) 187 | 188 | if args.frozen_weights is not None: 189 | checkpoint = torch.load(args.frozen_weights, map_location='cpu') 190 | model_without_ddp.detr.load_state_dict(checkpoint['model']) 191 | 192 | output_dir = Path(args.output_dir) 193 | if args.resume: 194 | if args.resume.startswith('https'): 195 | checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True) 196 | else: 197 | checkpoint = torch.load(args.resume, map_location='cpu') 198 | model_without_ddp.load_state_dict(checkpoint['model']) 199 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 200 | 
optimizer.load_state_dict(checkpoint['optimizer']) 201 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 202 | args.start_epoch = checkpoint['epoch'] + 1 203 | 204 | if args.eval: 205 | test_stats, coco_evaluator = evaluate( 206 | model, criterion, post_processors, data_loader_val, base_ds, device, args.output_dir 207 | ) 208 | if args.output_dir: 209 | utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") 210 | return 211 | 212 | print("Start training...") 213 | start_time = time.time() 214 | 215 | for epoch in range(args.start_epoch, args.epochs): 216 | if args.distributed: 217 | sampler_train.set_epoch(epoch) 218 | 219 | train_stats = train_one_epoch( 220 | model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm 221 | ) 222 | lr_scheduler.step() 223 | 224 | if args.output_dir: 225 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 226 | if (epoch + 1) % args.save_every_epoch == 0: 227 | checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') 228 | for checkpoint_path in checkpoint_paths: 229 | utils.save_on_master({ 230 | 'model': model_without_ddp.state_dict(), 231 | 'optimizer': optimizer.state_dict(), 232 | 'lr_scheduler': lr_scheduler.state_dict(), 233 | 'epoch': epoch, 234 | 'args': args, 235 | }, checkpoint_path) 236 | 237 | if (epoch + 1) % args.eval_every_epoch == 0: 238 | test_stats, coco_evaluator = evaluate( 239 | model, criterion, post_processors, data_loader_val, base_ds, device, args.output_dir 240 | ) 241 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 242 | **{f'test_{k}': v for k, v in test_stats.items()}, 243 | 'epoch': epoch, 244 | 'n_parameters': n_parameters} 245 | if args.output_dir and utils.is_main_process(): 246 | with (output_dir / "log.txt").open("a") as f: 247 | f.write(json.dumps(log_stats) + "\n") 248 | # for evaluation logs 249 | if coco_evaluator is not None: 250 | (output_dir / 'eval').mkdir(exist_ok=True) 251 | if "bbox" in coco_evaluator.coco_eval: 252 | filenames = ['latest.pth'] 253 | filenames.append(f'{epoch:04}.pth') 254 | for name in filenames: 255 | torch.save(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval" / name) 256 | 257 | total_time = time.time() - start_time 258 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 259 | print('Training completed.\nTotal training time: {}'.format(total_time_str)) 260 | 261 | 262 | if __name__ == '__main__': 263 | parser = argparse.ArgumentParser("SAM-DETR", parents=[get_args_parser()]) 264 | args = parser.parse_args() 265 | if args.output_dir: 266 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 267 | main(args) 268 | -------------------------------------------------------------------------------- /models/transformer_decoder.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | from typing import Optional 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torch import Tensor 11 | import torchvision 12 | 13 | from models.misc import _get_clones, _get_activation_fn, MLP 14 | from models.position_encoding import gen_sineembed_for_position 15 | from models.attention import MultiheadAttention 16 | from util.box_ops import box_cxcywh_to_xyxy 17 | 18 | 19 | class TransformerDecoder(nn.Module): 20 | def __init__(self, args, decoder_layer, num_layers): 21 | super().__init__() 22 | self.args = args 23 | self.multiscale = args.multiscale 24 | self.num_layers = num_layers 25 | self.layers = _get_clones(decoder_layer, num_layers) 26 | assert num_layers == self.args.dec_layers 27 | self.box_embed = None 28 | 29 | def forward(self, tgt, memory, 30 | tgt_mask: Optional[Tensor] = None, 31 | memory_mask: Optional[Tensor] = None, 32 | tgt_key_padding_mask: Optional[Tensor] = None, 33 | memory_key_padding_mask: Optional[Tensor] = None, 34 | pos: Optional[Tensor] = None, 35 | query_pos: Optional[Tensor] = None, 36 | memory_h=None, 37 | memory_w=None, 38 | grid=None): 39 | output = tgt 40 | 41 | intermediate = [] 42 | intermediate_reference_boxes = [] 43 | 44 | for layer_id, layer in enumerate(self.layers): 45 | 46 | if layer_id == 0 or layer_id == 1: 47 | scale_level = 2 48 | elif layer_id == 2 or layer_id == 3: 49 | scale_level = 1 50 | elif layer_id == 4 or layer_id == 5: 51 | scale_level = 0 52 | else: 53 | assert False 54 | 55 | if layer_id == 0: 56 | reference_boxes_before_sigmoid = query_pos # [num_queries, batch_size, 4] 57 | reference_boxes = reference_boxes_before_sigmoid.sigmoid().transpose(0, 1) 58 | else: 59 | tmp = self.bbox_embed[layer_id - 1](output) 60 | reference_boxes_before_sigmoid = tmp + reference_boxes_before_sigmoid 61 | reference_boxes = reference_boxes_before_sigmoid.sigmoid().transpose(0, 1) 62 | reference_boxes_before_sigmoid = reference_boxes_before_sigmoid.detach() 63 | reference_boxes = reference_boxes.detach() 64 | 65 | obj_center = reference_boxes[..., :2].transpose(0, 1) # [num_queries, batch_size, 2] 66 | 67 | # get sine embedding for the query vector 68 | query_ref_boxes_sine_embed = gen_sineembed_for_position(obj_center) 69 | 70 | if self.multiscale: 71 | memory_ = memory[scale_level] 72 | memory_h_ = memory_h[scale_level] 73 | memory_w_ = memory_w[scale_level] 74 | memory_key_padding_mask_ = memory_key_padding_mask[scale_level] 75 | pos_ = pos[scale_level] 76 | grid_ = grid[scale_level] 77 | else: 78 | memory_ = memory 79 | memory_h_ = memory_h 80 | memory_w_ = memory_w 81 | memory_key_padding_mask_ = memory_key_padding_mask 82 | pos_ = pos 83 | grid_ = grid 84 | 85 | output = layer(output, 86 | memory_, 87 | tgt_mask=tgt_mask, 88 | memory_mask=memory_mask, 89 | tgt_key_padding_mask=tgt_key_padding_mask, 90 | memory_key_padding_mask=memory_key_padding_mask_, 91 | pos=pos_, 92 | query_ref_boxes_sine_embed=query_ref_boxes_sine_embed, 93 | reference_boxes=reference_boxes, 94 | memory_h=memory_h_, 95 | memory_w=memory_w_, 96 | grid=grid_,) 97 | 98 | intermediate.append(output) 99 | intermediate_reference_boxes.append(reference_boxes) 100 | 101 | return torch.stack(intermediate).transpose(1, 2), \ 102 | torch.stack(intermediate_reference_boxes) 103 | 104 | 105 | class TransformerDecoderLayer(nn.Module): 106 | def __init__(self, args, activation="relu"): 107 | super().__init__() 108 | self.args = args 109 | self.d_model = 
args.hidden_dim 110 | self.nheads = args.nheads 111 | self.num_queries = args.num_queries 112 | self.dim_feedforward = args.dim_feedforward 113 | self.dropout = args.dropout 114 | self.activation = _get_activation_fn(activation) 115 | 116 | # Decoder Self-Attention 117 | self.sa_qcontent_proj = nn.Linear(self.d_model, self.d_model) 118 | self.sa_qpos_proj = nn.Linear(self.d_model, self.d_model) 119 | self.sa_kcontent_proj = nn.Linear(self.d_model, self.d_model) 120 | self.sa_kpos_proj = nn.Linear(self.d_model, self.d_model) 121 | self.sa_v_proj = nn.Linear(self.d_model, self.d_model) 122 | self.self_attn = MultiheadAttention(self.d_model, self.nheads, dropout=self.dropout, vdim=self.d_model) 123 | self.dropout1 = nn.Dropout(self.dropout) 124 | self.norm1 = nn.LayerNorm(self.d_model) 125 | 126 | # Decoder Cross-Attention 127 | self.ca_qcontent_proj = nn.Linear(self.d_model, self.d_model) 128 | self.ca_kcontent_proj = nn.Linear(self.d_model, self.d_model) 129 | self.ca_v_proj = nn.Linear(self.d_model, self.d_model) 130 | self.ca_qpos_sine_proj = MLP(self.d_model, self.d_model, self.d_model, 2) 131 | self.ca_kpos_sine_proj = MLP(self.d_model, self.d_model, self.d_model, 2) 132 | self.cross_attn = MultiheadAttention(self.nheads * self.d_model, self.nheads, dropout=self.dropout, vdim=self.d_model) 133 | self.dropout2 = nn.Dropout(self.dropout) 134 | self.norm2 = nn.LayerNorm(self.d_model) 135 | 136 | self.point1 = nn.Sequential( 137 | nn.Conv2d(self.d_model, self.d_model // 4, kernel_size=1, stride=1, padding=0), 138 | nn.ReLU(), 139 | ) 140 | if self.args.smca: 141 | self.point2 = nn.Sequential( 142 | nn.Linear(self.d_model // 4 * 7 * 7, 256), 143 | nn.ReLU(), 144 | nn.Linear(256, 512), 145 | nn.ReLU(), 146 | nn.Linear(512, 512), 147 | nn.ReLU(), 148 | nn.Linear(512, self.nheads * 4), 149 | ) 150 | nn.init.constant_(self.point2[-1].weight.data, 0) 151 | nn.init.constant_(self.point2[-1].bias.data, 0) 152 | else: 153 | self.point2 = nn.Sequential( 154 | nn.Linear(self.d_model // 4 * 7 * 7, 256), 155 | nn.ReLU(), 156 | nn.Linear(256, 512), 157 | nn.ReLU(), 158 | nn.Linear(512, 512), 159 | nn.ReLU(), 160 | nn.Linear(512, self.nheads * 2), 161 | ) 162 | nn.init.constant_(self.point2[-1].weight.data, 0) 163 | nn.init.constant_(self.point2[-1].bias.data, 0) 164 | 165 | 166 | self.attn1 = nn.Linear(self.d_model, self.d_model * self.nheads) 167 | self.attn2 = nn.Linear(self.d_model, self.d_model * self.nheads) 168 | 169 | # FFN 170 | self.linear1 = nn.Linear(self.d_model, self.dim_feedforward) 171 | self.dropout88 = nn.Dropout(self.dropout) 172 | self.linear2 = nn.Linear(self.dim_feedforward, self.d_model) 173 | self.dropout3 = nn.Dropout(self.dropout) 174 | self.norm3 = nn.LayerNorm(self.d_model) 175 | 176 | def get_valid_ratio(self, mask): 177 | _, H, W = mask.shape 178 | valid_H = torch.sum(~mask[:, :, 0], 1) 179 | valid_W = torch.sum(~mask[:, 0, :], 1) 180 | valid_ratio_h = valid_H.float() / H 181 | valid_ratio_w = valid_W.float() / W 182 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h, valid_ratio_w, valid_ratio_h], -1) 183 | return valid_ratio 184 | 185 | def forward(self, tgt, memory, 186 | tgt_mask: Optional[Tensor] = None, 187 | memory_mask: Optional[Tensor] = None, 188 | tgt_key_padding_mask: Optional[Tensor] = None, 189 | memory_key_padding_mask: Optional[Tensor] = None, 190 | pos: Optional[Tensor] = None, 191 | query_ref_boxes_sine_embed = None, 192 | reference_boxes: Optional[Tensor] = None, 193 | memory_h=None, 194 | memory_w=None, 195 | grid=None): 196 | 197 | num_queries = 
tgt.shape[0] 198 | bs = tgt.shape[1] 199 | c = tgt.shape[2] 200 | n_model = c 201 | valid_ratio = self.get_valid_ratio(memory_key_padding_mask.view(bs, memory_h, memory_w)) 202 | 203 | memory_2d = memory.view(memory_h, memory_w, bs, c) 204 | memory_2d = memory_2d.permute(2, 3, 0, 1) 205 | 206 | # ========== Begin of Self-Attention ============= 207 | q_content = self.sa_qcontent_proj(tgt) 208 | q_pos = self.sa_qpos_proj(query_ref_boxes_sine_embed) 209 | k_content = self.sa_kcontent_proj(tgt) 210 | k_pos = self.sa_kpos_proj(query_ref_boxes_sine_embed) 211 | v = self.sa_v_proj(tgt) 212 | q = q_content + q_pos 213 | k = k_content + k_pos 214 | tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask)[0] 215 | # ========== End of Self-Attention ============= 216 | tgt = tgt + self.dropout1(tgt2) 217 | tgt = self.norm1(tgt) 218 | 219 | reference_boxes_xyxy = box_cxcywh_to_xyxy(reference_boxes) 220 | reference_boxes_xyxy[:, :, 0] *= memory_w 221 | reference_boxes_xyxy[:, :, 1] *= memory_h 222 | reference_boxes_xyxy[:, :, 2] *= memory_w 223 | reference_boxes_xyxy[:, :, 3] *= memory_h 224 | reference_boxes_xyxy = reference_boxes_xyxy * valid_ratio.view(bs, 1, 4) 225 | 226 | q_content = torchvision.ops.roi_align( 227 | memory_2d, 228 | list(torch.unbind(reference_boxes_xyxy, dim=0)), 229 | output_size=(7, 7), 230 | spatial_scale=1.0, 231 | aligned=True) # (bs * num_queries, c, 7, 7) 232 | 233 | q_content_points = torchvision.ops.roi_align( 234 | memory_2d, 235 | list(torch.unbind(reference_boxes_xyxy, dim=0)), 236 | output_size=(7, 7), 237 | spatial_scale=1.0, 238 | aligned=True) # (bs * num_queries, c, 7, 7) 239 | 240 | q_content_index = q_content_points.view(bs * num_queries, -1, 7, 7) 241 | 242 | points = self.point1(q_content_index) 243 | points = points.reshape(bs * num_queries, -1) 244 | points = self.point2(points) 245 | if not self.args.smca: 246 | points = points.view(bs * num_queries, 1, self.nheads, 2).tanh() 247 | else: 248 | points_scale = points[:, 2 * self.nheads:].reshape(bs, num_queries, self.nheads, 2).permute(1, 0, 2, 3) 249 | points = points[:, :2 * self.nheads].view(bs * num_queries, 1, self.nheads, 2).tanh() 250 | 251 | q_content = F.grid_sample(q_content, points, padding_mode="zeros", align_corners=False).view(bs * num_queries, -1) 252 | q_content = q_content.view(bs, num_queries, -1, 8).permute(1, 0, 3, 2) # (num_query, bs, n_head, 256) 253 | q_content = q_content * self.attn1(tgt).view(num_queries, bs, self.nheads, n_model).sigmoid() 254 | 255 | q_pos_center = reference_boxes[:, :, :2].reshape(bs, num_queries, 1, 2).expand(-1, -1, self.nheads, -1) 256 | q_pos_scale = reference_boxes[:, :, 2:].reshape(bs, num_queries, 1, 2).expand(-1, -1, self.nheads, -1) * 0.5 257 | q_pos_delta = points.reshape(bs, num_queries, self.nheads, 2) 258 | q_pos = q_pos_center + q_pos_scale * q_pos_delta 259 | 260 | q_pos = q_pos.permute(1, 0, 2, 3) # (num_query, bs, n_head, 2) 261 | q_pos = q_pos.reshape(num_queries, bs * self.nheads, 2) 262 | 263 | if self.args.smca: 264 | # SMCA: start 265 | gau_point = torch.clone(q_pos) 266 | gau_point[:, :, 0] *= memory_w 267 | gau_point[:, :, 1] *= memory_h 268 | gau_point = gau_point.reshape(num_queries, bs, self.nheads, 2) 269 | gau_point = gau_point * valid_ratio[:, :2].reshape(1, bs, 1, 2) 270 | gau_point = gau_point.reshape(num_queries, bs * self.nheads, 2) 271 | gau_distance = (gau_point.unsqueeze(1) - (grid + 0.5).unsqueeze(0)).pow(2) 272 | gau_scale = points_scale 273 | gau_scale = gau_scale * gau_scale 274 | 
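            # Descriptive note: in this SMCA branch, the scale-weighted squared distances between each
            # head's re-projected sampling point and every memory location in `grid` are summed over
            # (x, y) below and converted into a negative, Gaussian-like spatial bias (`gaussian`).
            # The custom cross-attention (models/attention.py) receives it via `gaussian=[gaussian]`
            # and uses it as a spatial prior that emphasizes locations near each head's predicted
            # center, following SMCA-DETR.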
gau_scale = gau_scale.reshape(num_queries, -1, 2).unsqueeze(1) 275 | gau_distance = (gau_distance * gau_scale).sum(-1) 276 | gaussian = -(gau_distance - 0).abs() / 8.0 # 8.0 is the number used in SMCA-DETR 277 | # SMCA: end 278 | else: 279 | gaussian = None 280 | 281 | q_pos = gen_sineembed_for_position(q_pos).reshape(num_queries, bs, self.nheads, c) 282 | q_pos = q_pos * self.attn2(tgt).view(num_queries, bs, self.nheads, n_model).sigmoid() 283 | 284 | # ========== Begin of Cross-Attention ============= 285 | # Apply projections here 286 | # shape: num_queries x batch_size x 256 287 | q_content = self.ca_qcontent_proj(q_content) 288 | k_content = self.ca_kcontent_proj(memory).view(-1, bs, 1, 256).expand(-1, -1, self.nheads, -1) 289 | v = self.ca_v_proj(memory).view(-1, bs, n_model) 290 | 291 | num_queries, bs, n_head, n_model = q_content.shape 292 | hw, _, _, _ = k_content.shape 293 | 294 | q = q_content 295 | k = k_content 296 | 297 | query_sine_embed = self.ca_qpos_sine_proj(q_pos) 298 | q = (q + query_sine_embed).view(num_queries, bs, self.nheads * n_model) 299 | 300 | k = k.view(hw, bs, self.nheads, n_model) 301 | k_pos = self.ca_kpos_sine_proj(pos) 302 | k_pos = k_pos.view(hw, bs, 1, n_model).expand(-1, -1, self.nheads, -1) 303 | k = (k + k_pos).view(hw, bs, self.nheads * n_model) 304 | 305 | if self.args.smca: 306 | tgt2 = self.cross_attn(query=q, 307 | key=k, 308 | value=v, attn_mask=memory_mask, 309 | key_padding_mask=memory_key_padding_mask, 310 | gaussian=[gaussian])[0] 311 | else: 312 | tgt2 = self.cross_attn(query=q, 313 | key=k, 314 | value=v, attn_mask=memory_mask, 315 | key_padding_mask=memory_key_padding_mask, 316 | gaussian=None)[0] 317 | # ========== End of Cross-Attention ============= 318 | tgt = tgt + self.dropout2(tgt2) 319 | tgt = self.norm2(tgt) 320 | 321 | # FFN 322 | tgt2 = self.linear2(self.dropout88(self.activation(self.linear1(tgt)))) 323 | tgt = tgt + self.dropout3(tgt2) 324 | tgt = self.norm3(tgt) 325 | 326 | return tgt 327 | -------------------------------------------------------------------------------- /models/segmentation.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copied from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
4 | # ------------------------------------------------------------------------ 5 | 6 | """ 7 | This file provides the definition of the convolutional heads used to predict masks, as well as the losses 8 | """ 9 | import io 10 | from collections import defaultdict 11 | from typing import List, Optional 12 | 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | from torch import Tensor 17 | from PIL import Image 18 | 19 | import util.box_ops as box_ops 20 | from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list 21 | 22 | try: 23 | from panopticapi.utils import id2rgb, rgb2id 24 | except ImportError: 25 | pass 26 | 27 | 28 | class DETRsegm(nn.Module): 29 | def __init__(self, detr, freeze_detr=False): 30 | super().__init__() 31 | self.detr = detr 32 | 33 | if freeze_detr: 34 | for p in self.parameters(): 35 | p.requires_grad_(False) 36 | 37 | hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead 38 | self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) 39 | self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) 40 | 41 | def forward(self, samples: NestedTensor): 42 | if isinstance(samples, (list, torch.Tensor)): 43 | samples = nested_tensor_from_tensor_list(samples) 44 | features, pos = self.detr.backbone(samples) 45 | 46 | bs = features[-1].tensors.shape[0] 47 | 48 | src, mask = features[-1].decompose() 49 | assert mask is not None 50 | src_proj = self.detr.input_proj(src) 51 | hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1]) 52 | 53 | outputs_class = self.detr.class_embed(hs) 54 | outputs_coord = self.detr.bbox_embed(hs).sigmoid() 55 | out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 56 | if self.detr.aux_loss: 57 | out['aux_outputs'] = self.detr._set_aux_loss(outputs_class, outputs_coord) 58 | 59 | # FIXME h_boxes takes the last one computed, keep this in mind 60 | bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) 61 | 62 | seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors]) 63 | outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) 64 | 65 | out["pred_masks"] = outputs_seg_masks 66 | return out 67 | 68 | 69 | def _expand(tensor, length: int): 70 | return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) 71 | 72 | 73 | class MaskHeadSmallConv(nn.Module): 74 | """ 75 | Simple convolutional head, using group norm. 
76 | Upsampling is done using a FPN approach 77 | """ 78 | 79 | def __init__(self, dim, fpn_dims, context_dim): 80 | super().__init__() 81 | 82 | inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] 83 | self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) 84 | self.gn1 = torch.nn.GroupNorm(8, dim) 85 | self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) 86 | self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) 87 | self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) 88 | self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) 89 | self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) 90 | self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) 91 | self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) 92 | self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) 93 | self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) 94 | 95 | self.dim = dim 96 | 97 | self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) 98 | self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) 99 | self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) 100 | 101 | for m in self.modules(): 102 | if isinstance(m, nn.Conv2d): 103 | nn.init.kaiming_uniform_(m.weight, a=1) 104 | nn.init.constant_(m.bias, 0) 105 | 106 | def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]): 107 | x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) 108 | 109 | x = self.lay1(x) 110 | x = self.gn1(x) 111 | x = F.relu(x) 112 | x = self.lay2(x) 113 | x = self.gn2(x) 114 | x = F.relu(x) 115 | 116 | cur_fpn = self.adapter1(fpns[0]) 117 | if cur_fpn.size(0) != x.size(0): 118 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 119 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 120 | x = self.lay3(x) 121 | x = self.gn3(x) 122 | x = F.relu(x) 123 | 124 | cur_fpn = self.adapter2(fpns[1]) 125 | if cur_fpn.size(0) != x.size(0): 126 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 127 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 128 | x = self.lay4(x) 129 | x = self.gn4(x) 130 | x = F.relu(x) 131 | 132 | cur_fpn = self.adapter3(fpns[2]) 133 | if cur_fpn.size(0) != x.size(0): 134 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 135 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 136 | x = self.lay5(x) 137 | x = self.gn5(x) 138 | x = F.relu(x) 139 | 140 | x = self.out_lay(x) 141 | return x 142 | 143 | 144 | class MHAttentionMap(nn.Module): 145 | """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" 146 | 147 | def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True): 148 | super().__init__() 149 | self.num_heads = num_heads 150 | self.hidden_dim = hidden_dim 151 | self.dropout = nn.Dropout(dropout) 152 | 153 | self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 154 | self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 155 | 156 | nn.init.zeros_(self.k_linear.bias) 157 | nn.init.zeros_(self.q_linear.bias) 158 | nn.init.xavier_uniform_(self.k_linear.weight) 159 | nn.init.xavier_uniform_(self.q_linear.weight) 160 | self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 161 | 162 | def forward(self, q, k, mask: Optional[Tensor] = None): 163 | q = self.q_linear(q) 164 | k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) 165 | qh = q.view(q.shape[0], 
q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) 166 | kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) 167 | weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) 168 | 169 | if mask is not None: 170 | weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) 171 | weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size()) 172 | weights = self.dropout(weights) 173 | return weights 174 | 175 | 176 | def dice_loss(inputs, targets, num_boxes): 177 | """ 178 | Compute the DICE loss, similar to generalized IOU for masks 179 | Args: 180 | inputs: A float tensor of arbitrary shape. 181 | The predictions for each example. 182 | targets: A float tensor with the same shape as inputs. Stores the binary 183 | classification label for each element in inputs 184 | (0 for the negative class and 1 for the positive class). 185 | """ 186 | inputs = inputs.sigmoid() 187 | inputs = inputs.flatten(1) 188 | numerator = 2 * (inputs * targets).sum(1) 189 | denominator = inputs.sum(-1) + targets.sum(-1) 190 | loss = 1 - (numerator + 1) / (denominator + 1) 191 | return loss.sum() / num_boxes 192 | 193 | 194 | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): 195 | """ 196 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 197 | Args: 198 | inputs: A float tensor of arbitrary shape. 199 | The predictions for each example. 200 | targets: A float tensor with the same shape as inputs. Stores the binary 201 | classification label for each element in inputs 202 | (0 for the negative class and 1 for the positive class). 203 | alpha: (optional) Weighting factor in range (0,1) to balance 204 | positive vs negative examples. Default = -1 (no weighting). 205 | gamma: Exponent of the modulating factor (1 - p_t) to 206 | balance easy vs hard examples. 
207 | Returns: 208 | Loss tensor 209 | """ 210 | prob = inputs.sigmoid() 211 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 212 | p_t = prob * targets + (1 - prob) * (1 - targets) 213 | loss = ce_loss * ((1 - p_t) ** gamma) 214 | 215 | if alpha >= 0: 216 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 217 | loss = alpha_t * loss 218 | 219 | return loss.mean(1).sum() / num_boxes 220 | 221 | 222 | class PostProcessSegm(nn.Module): 223 | def __init__(self, threshold=0.5): 224 | super().__init__() 225 | self.threshold = threshold 226 | 227 | @torch.no_grad() 228 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes): 229 | assert len(orig_target_sizes) == len(max_target_sizes) 230 | max_h, max_w = max_target_sizes.max(0)[0].tolist() 231 | outputs_masks = outputs["pred_masks"].squeeze(2) 232 | outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) 233 | outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() 234 | 235 | for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): 236 | img_h, img_w = t[0], t[1] 237 | results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) 238 | results[i]["masks"] = F.interpolate( 239 | results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" 240 | ).byte() 241 | 242 | return results 243 | 244 | 245 | class PostProcessPanoptic(nn.Module): 246 | """This class converts the output of the model to the final panoptic result, in the format expected by the 247 | coco panoptic API """ 248 | 249 | def __init__(self, is_thing_map, threshold=0.85): 250 | """ 251 | Parameters: 252 | is_thing_map: This is a whose keys are the class ids, and the values a boolean indicating whether 253 | the class is a thing (True) or a stuff (False) class 254 | threshold: confidence threshold: segments with confidence lower than this will be deleted 255 | """ 256 | super().__init__() 257 | self.threshold = threshold 258 | self.is_thing_map = is_thing_map 259 | 260 | def forward(self, outputs, processed_sizes, target_sizes=None): 261 | """ This function computes the panoptic prediction from the model's predictions. 262 | Parameters: 263 | outputs: This is a dict coming directly from the model. See the model doc for the content. 264 | processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the 265 | model, ie the size after data augmentation but before batching. 266 | target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size 267 | of each prediction. 
If left to None, it will default to the processed_sizes 268 | """ 269 | if target_sizes is None: 270 | target_sizes = processed_sizes 271 | assert len(processed_sizes) == len(target_sizes) 272 | out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"] 273 | assert len(out_logits) == len(raw_masks) == len(target_sizes) 274 | preds = [] 275 | 276 | def to_tuple(tup): 277 | if isinstance(tup, tuple): 278 | return tup 279 | return tuple(tup.cpu().tolist()) 280 | 281 | for cur_logits, cur_masks, cur_boxes, size, target_size in zip( 282 | out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes 283 | ): 284 | # we filter empty queries and detection below threshold 285 | scores, labels = cur_logits.softmax(-1).max(-1) 286 | keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold) 287 | cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) 288 | cur_scores = cur_scores[keep] 289 | cur_classes = cur_classes[keep] 290 | cur_masks = cur_masks[keep] 291 | cur_masks = interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) 292 | cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) 293 | 294 | h, w = cur_masks.shape[-2:] 295 | assert len(cur_boxes) == len(cur_classes) 296 | 297 | # It may be that we have several predicted masks for the same stuff class. 298 | # In the following, we track the list of masks ids for each stuff class (they are merged later on) 299 | cur_masks = cur_masks.flatten(1) 300 | stuff_equiv_classes = defaultdict(lambda: []) 301 | for k, label in enumerate(cur_classes): 302 | if not self.is_thing_map[label.item()]: 303 | stuff_equiv_classes[label.item()].append(k) 304 | 305 | def get_ids_area(masks, scores, dedup=False): 306 | # This helper function creates the final panoptic segmentation image 307 | # It also returns the area of the masks that appears on the image 308 | 309 | m_id = masks.transpose(0, 1).softmax(-1) 310 | 311 | if m_id.shape[-1] == 0: 312 | # We didn't detect any mask :( 313 | m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) 314 | else: 315 | m_id = m_id.argmax(-1).view(h, w) 316 | 317 | if dedup: 318 | # Merge the masks corresponding to the same stuff class 319 | for equiv in stuff_equiv_classes.values(): 320 | if len(equiv) > 1: 321 | for eq_id in equiv: 322 | m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) 323 | 324 | final_h, final_w = to_tuple(target_size) 325 | 326 | seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) 327 | seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) 328 | 329 | np_seg_img = ( 330 | torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy() 331 | ) 332 | m_id = torch.from_numpy(rgb2id(np_seg_img)) 333 | 334 | area = [] 335 | for i in range(len(scores)): 336 | area.append(m_id.eq(i).sum().item()) 337 | return area, seg_img 338 | 339 | area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) 340 | if cur_classes.numel() > 0: 341 | # We know filter empty masks as long as we find some 342 | while True: 343 | filtered_small = torch.as_tensor( 344 | [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device 345 | ) 346 | if filtered_small.any().item(): 347 | cur_scores = cur_scores[~filtered_small] 348 | cur_classes = cur_classes[~filtered_small] 349 | cur_masks = cur_masks[~filtered_small] 350 | area, seg_img = get_ids_area(cur_masks, cur_scores) 351 | else: 352 | break 353 | 354 | else: 355 | cur_classes = 
torch.ones(1, dtype=torch.long, device=cur_classes.device) 356 | 357 | segments_info = [] 358 | for i, a in enumerate(area): 359 | cat = cur_classes[i].item() 360 | segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a}) 361 | del cur_classes 362 | 363 | with io.BytesIO() as out: 364 | seg_img.save(out, format="PNG") 365 | predictions = {"png_string": out.getvalue(), "segments_info": segments_info} 366 | preds.append(predictions) 367 | return preds 368 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SAM-DETR (Semantic-Aligned-Matching DETR) 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2203.06883-b31b1b.svg)](https://arxiv.org/abs/2203.06883) 4 | [![Survey](https://github.com/sindresorhus/awesome/blob/main/media/mentioned-badge.svg)](https://github.com/dk-liang/Awesome-Visual-Transformer) 5 | [![Maintenance](https://img.shields.io/badge/Maintained%3F-yes-green.svg)](https://GitHub.com/Naereen/StrapDown.js/graphs/commit-activity) 6 | [![PR's Welcome](https://img.shields.io/badge/PRs-welcome-brightgreen.svg?style=flat)](http://makeapullrequest.com) 7 | [![GitHub license](https://badgen.net/github/license/ZhangGongjie/SAM-DETR)](https://github.com/ZhangGongjie/SAM-DETR/blob/master/LICENSE) 8 | 9 | This repository is an official PyTorch implementation of the 10 | CVPR 2022 paper "[Accelerating DETR Convergence via Semantic-Aligned Matching](https://arxiv.org/abs/2203.06883)". 11 | 12 | *[UPDATE on 21 Apr 2022]*   We found that with a very simple modification (with no extra computational cost), SAM-DETR can achieve better performance. On MS-COCO, **SAM-DETR w/ SMCA** can achieve **37.0 AP** within 12 epochs, and **42.7 AP** within 50 epochs. We will release the updated training scripts, model weights, and logs in the future. Please stay tuned! 13 | 14 | ## Introduction 15 | 16 | TL;DR   SAM-DETR is an efficient DETR-like object detector that can 17 | converge within 12 epochs and outperform the strong Faster R-CNN (w/ FPN) baseline. 18 | 19 | The recently developed DEtection TRansformer (DETR) has established a new 20 | object detection paradigm by eliminating a series of hand-crafted components. 21 | However, DETR suffers from extremely slow convergence, which increases the 22 | training cost significantly. We observe that the slow convergence can be largely 23 | attributed to the complication in matching object queries to encoded image features 24 | in DETR's decoder cross-attention modules. 25 | 26 |
27 | *(figure: illustration of the matching complication in DETR's decoder cross-attention)* 28 |
29 | 30 | Motivated by this observation, in our paper, we propose SAM-DETR, a 31 | Semantic-Aligned-Matching DETR that can greatly accelerate DETR's convergence 32 | without sacrificing its accuracy. SAM-DETR addresses the slow convergence issue 33 | from two perspectives. First, it projects object queries into the same 34 | embedding space as encoded image features, where the matching can be accomplished 35 | efficiently with aligned semantics. Second, it explicitly searches for salient 36 | points with the most discriminative features for semantic-aligned matching, 37 | which further speeds up convergence and boosts detection accuracy as well. 38 | Working as a plug-and-play module, SAM-DETR complements existing convergence solutions 39 | well, yet introduces only slight computational overhead. Experiments 40 | show that the proposed SAM-DETR achieves superior convergence as well as 41 | competitive detection accuracy. 42 | 43 | At the core of SAM-DETR is a plug-and-play module named "Semantics Aligner" appended 44 | ahead of the cross-attention module in each of DETR's decoder layers. It also models a learnable 45 | reference box for each object query, whose center location is used to generate 46 | corresponding position embeddings. 47 | 48 |
49 | *(figure: overview of a SAM-DETR decoder layer with the appended Semantics Aligner)* 50 |
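The learnable reference boxes are easiest to picture with a small sketch. The snippet below is illustrative only (it is not the repository's actual implementation, which may differ in detail): it shows one DETR-style way to turn the center of each learnable reference box into a sinusoidal position embedding for its object query. The query count and hidden size are placeholder values.

```python
# Illustrative sketch only (not the repository's actual implementation): one DETR-style way
# to turn the center of a learnable reference box into a sinusoidal position embedding.
import math
import torch
import torch.nn as nn


def sine_embed_of_centers(centers: torch.Tensor, num_pos_feats: int = 128) -> torch.Tensor:
    """centers: (num_queries, 2) normalized (cx, cy) in [0, 1].
    Returns (num_queries, 2 * num_pos_feats) sinusoidal embeddings."""
    temperature, scale = 10000, 2 * math.pi
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=centers.device)
    dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
    pos = centers * scale                  # (num_queries, 2)
    pos = pos[:, :, None] / dim_t          # (num_queries, 2, num_pos_feats)
    pos = torch.stack((pos[..., 0::2].sin(), pos[..., 1::2].cos()), dim=3).flatten(2)
    return pos.flatten(1)                  # x- and y-embeddings concatenated


# Each object query carries a learnable reference box (cx, cy, w, h); its center drives
# the query's position embedding, so attention is anchored to the box location.
num_queries, hidden_dim = 300, 256                    # placeholder sizes
reference_boxes = nn.Embedding(num_queries, 4)        # learned jointly with the detector
centers = reference_boxes.weight[:, :2].sigmoid()     # keep centers inside the image
query_pos = sine_embed_of_centers(centers, num_pos_feats=hidden_dim // 2)
print(query_pos.shape)                                # torch.Size([300, 256])
```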
51 | 52 | The figure below illustrates the architecture of the appended "Semantics Aligner", which 53 | aligns the semantics of "encoded image features" and "object queries" by re-sampling features 54 | from multiple salient points as new object queries. 55 | 56 |
57 | *(figure: architecture of the Semantics Aligner)* 58 |
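To make the resampling idea concrete, here is a deliberately simplified sketch. It is not the repository's implementation: the module name, tensor shapes, number of salient points, and the use of `roi_align`/`grid_sample` are illustrative assumptions about how features inside each query's reference box can be resampled at salient points and reused as new, semantically aligned object queries.

```python
# Conceptual sketch only, NOT the repository's implementation. Module name, shapes,
# the number of salient points, and the roi_align/grid_sample choices are assumptions.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.ops import roi_align


class SemanticsAlignerSketch(nn.Module):
    """Resample encoded image features at salient points inside each query's reference box
    and use them as new object queries for the subsequent cross-attention."""

    def __init__(self, d_model: int = 256, num_points: int = 8, roi_size: int = 7):
        super().__init__()
        self.num_points = num_points
        self.roi_size = roi_size
        # predicts normalized (x, y) salient-point locations inside each RoI
        self.point_head = nn.Linear(d_model * roi_size * roi_size, num_points * 2)
        self.out_proj = nn.Linear(d_model * num_points, d_model)

    def forward(self, feat_map: torch.Tensor, ref_boxes_xyxy: torch.Tensor) -> torch.Tensor:
        """feat_map: (1, C, H, W) encoded image features (single image for brevity).
        ref_boxes_xyxy: (num_queries, 4) reference boxes in absolute feature-map coords."""
        n_q = ref_boxes_xyxy.shape[0]
        # crop a fixed-size feature patch per reference box
        region = roi_align(feat_map, [ref_boxes_xyxy], output_size=self.roi_size)
        # predict salient points inside each box, in grid_sample's [-1, 1] coordinates
        points = self.point_head(region.flatten(1)).view(n_q, self.num_points, 1, 2).tanh()
        # sample the region features at those points: (num_queries, C, num_points, 1)
        sampled = F.grid_sample(region, points, align_corners=False)
        # concatenate the sampled features into new, semantically aligned queries
        return self.out_proj(sampled.flatten(1))       # (num_queries, d_model)


# usage sketch
aligner = SemanticsAlignerSketch()
feats = torch.rand(1, 256, 32, 32)
boxes = torch.tensor([[4.0, 4.0, 12.0, 12.0], [10.0, 8.0, 30.0, 28.0]])
print(aligner(feats, boxes).shape)                     # torch.Size([2, 256])
```

In the full model, the resampled queries, together with the position embeddings generated from the reference boxes, are what the decoder's cross-attention consumes; this is what keeps object queries and encoded image features in the same embedding space.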
59 | 60 | Working as a plug-and-play module, our approach can be 61 | easily integrated with existing convergence solutions (*e.g.*, SMCA) in a complementary manner, 62 | boosting detection accuracy and convergence speed further. 63 | 64 | Please check [our CVPR 2022 paper](https://arxiv.org/abs/2203.06883) for more details. 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | ## Installation 73 | 74 | ### Pre-Requisites 75 | You must have NVIDIA GPUs to run the code. 76 | 77 | The implementation is developed and tested with the following environment setup: 78 | - Linux 79 | - 8x NVIDIA V100 GPUs (32GB) 80 | - CUDA 10.1 81 | - Python == 3.8 82 | - PyTorch == 1.8.1+cu101, TorchVision == 0.9.1+cu101 83 | - GCC == 7.5.0 84 | - cython, pycocotools, tqdm, scipy 85 | 86 | We recommend using the exact setup above. However, other environments (Linux, Python>=3.7, CUDA>=9.2, GCC>=5.4, PyTorch>=1.5.1, TorchVision>=0.6.1) should also work. 87 | 88 | ### Code Installation 89 | 90 | First, clone the repository locally: 91 | ```shell 92 | git clone https://github.com/ZhangGongjie/SAM-DETR.git 93 | ``` 94 | 95 | We recommend using [Anaconda](https://www.anaconda.com/) to create a conda environment: 96 | ```bash 97 | conda create -n sam_detr python=3.8 pip 98 | ``` 99 | 100 | Then, activate the environment: 101 | ```bash 102 | conda activate sam_detr 103 | ``` 104 | 105 | Then, install PyTorch and TorchVision: 106 | 107 | (preferably using our recommended setup; the CUDA version should match your local environment) 108 | ```bash 109 | conda install pytorch=1.8.1 torchvision=0.9.1 cudatoolkit=10.1 -c pytorch 110 | ``` 111 | 112 | After that, install other requirements: 113 | ```bash 114 | conda install cython scipy tqdm 115 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 116 | ``` 117 | 118 | *[Optional]*   If you wish to run the multi-scale version of SAM-DETR (results not reported in the CVPR paper), you need to compile [*Deformable Attention*](https://github.com/fundamentalvision/Deformable-DETR), 119 | which is used in the DETR encoder to generate the feature pyramid efficiently. If you don't need the multi-scale 120 | version of SAM-DETR, you may skip this step. 121 | ```bash 122 | # Optionally compile CUDA operators of Deformable Attention for multi-scale SAM-DETR 123 | cd SAM-DETR 124 | cd ./models/ops 125 | sh ./make.sh 126 | python test.py # unit test (should see all checking is True) 127 | ``` 128 | 129 | ### Data Preparation 130 | 131 | Please download the [COCO 2017 dataset](https://cocodataset.org/) and organize it as follows: 132 | 133 | ``` 134 | code_root/ 135 | └── data/ 136 | └── coco/ 137 | ├── train2017/ 138 | ├── val2017/ 139 | └── annotations/ 140 | ├── instances_train2017.json 141 | └── instances_val2017.json 142 | ``` 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | ## Usage 151 | 152 | ### Reproducing Paper Results 153 | 154 | All scripts to reproduce results reported in [our CVPR 2022 paper](https://arxiv.org/abs/2203.06883) 155 | are stored in ```./scripts```. We also provide scripts for slurm clusters, 156 | which are stored in ```./scripts_slurm```.
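Before launching any of these scripts, it can help to verify that the COCO data is organized exactly as described in the Data Preparation section above. The small helper below is illustrative only (it is not part of the repository) and assumes it is run from the code root:

```python
# Illustrative sanity check (not part of the repository): confirm the expected COCO layout
# before launching the training scripts. Assumes it is run from the code root.
from pathlib import Path

coco_root = Path("data/coco")
expected = [
    coco_root / "train2017",
    coco_root / "val2017",
    coco_root / "annotations" / "instances_train2017.json",
    coco_root / "annotations" / "instances_val2017.json",
]
missing = [str(p) for p in expected if not p.exists()]
if missing:
    raise FileNotFoundError(f"COCO dataset is not fully prepared; missing: {missing}")
print("COCO layout looks correct, ready to launch the scripts.")
```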
157 | 158 | Taking SAM-DETR-R50 w/ SMCA (12 epochs) for example, to reproduce its results, simply 159 | run: 160 | ```shell 161 | bash scripts/r50_smca_e12_4gpu.sh 162 | ``` 163 | 164 | Taking SAM-DETR-R50 multiscale w/ SMCA (50 epochs) for example, to reproduce its results on a slurm cluster, simply 165 | run: 166 | ```shell 167 | bash scripts_slurm/r50_ms_smca_e50_8gpu.sh 168 | ``` 169 | 170 | Reminder: To reproduce results, please make sure the total batch size matches the implementation details described in our paper. For ```R50 (single-scale)``` 171 | experiments, we use 4 GPUs with a batch size of 4 on each GPU. For ```R50 (multi-scale)``` 172 | experiments, we use 8 GPUs with a batch size of 2 on each GPU. For ```R50-DC5 (single-scale)``` 173 | experiments, we use 8 GPUs with a batch size of 1 on each GPU. 174 | 175 | 176 | 177 | ### Training 178 | To perform training on COCO *train2017*, modify the arguments based on the scripts below: 179 | ```shell 180 | python -m torch.distributed.launch \ 181 | --nproc_per_node=4 \ # number of GPUs to perform training 182 | --use_env main.py \ 183 | --batch_size 4 \ # batch_size on individual GPU (this is *NOT* total batch_size) 184 | --smca \ # to integrate with SMCA, remove this line to disable SMCA 185 | --dilation \ # to enable DC5, remove this line to disable DC5 186 | --multiscale \ # to enable multi-scale, remove this line to disable multiscale 187 | --epochs 50 \ # total number of epochs to train 188 | --lr_drop 40 \ # when to drop learning rate 189 | --output_dir output/xxxx # where to store outputs, remove this line for not storing outputs 190 | ``` 191 | More arguments and their explanations are available at ```main.py```. 192 | 193 | ### Evaluation 194 | To evaluate a model on COCO *val2017*, simply add ```--resume``` and ```--eval``` arguments to your training scripts: 195 | ```shell 196 | python -m torch.distributed.launch \ 197 | --nproc_per_node=4 \ 198 | --use_env main.py \ 199 | --batch_size 4 \ 200 | --smca \ 201 | --dilation \ 202 | --multiscale \ 203 | --epochs 50 \ 204 | --lr_drop 40 \ 205 | --resume \ # trained model weights 206 | --eval \ # this means that only evaluation will be performed 207 | --output_dir output/xxxx 208 | ``` 209 | 210 | 211 | ### Visualize Detection Results 212 | We provide `demo.py`, which is a minimal implementation that allows users to visualize model's detection predictions. It performs detection on images inside the `./images` folder, and stores detection visualizations in that folder. Taking SAM-DETR-R50 w/ SMCA (50 epochs) for example, simply run: 213 | ```shell 214 | python demo.py \ # do NOT use distributed mode 215 | --smca \ 216 | --epochs 50 \ # you need to set this correct. See models/fast_detr.py L50-79 for details. 217 | --resume # trained model weights 218 | ``` 219 | 220 | 221 | 222 | ## Model Zoo 223 | 224 | *Trained model weights are stored in Google Drive.* 225 | 226 | The original DETR models trained for 500 epochs: 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 |
| Method | Epochs | Params (M) | GFLOPs | AP | URL |
|:------:|:------:|:----------:|:------:|:--:|:---:|
| DETR-R50 | 500 | 41 | 86 | 42.0 | log |
| DETR-R50-DC5 | 500 | 41 | 187 | 43.3 | log |
258 | 259 | 260 | Our proposed SAM-DETR models (results reported in [our CVPR paper](https://arxiv.org/abs/2203.06883)):
| Method | Epochs | Params (M) | GFLOPs | AP | URL |
|:------:|:------:|:----------:|:------:|:--:|:---:|
| SAM-DETR-R50 | 12 | 58 | 100 | 33.1 | model / log |
| SAM-DETR-R50 w/ SMCA | 12 | 58 | 100 | 36.0 | model / log |
| SAM-DETR-R50-DC5 | 12 | 58 | 210 | 38.3 | model / log |
| SAM-DETR-R50-DC5 w/ SMCA | 12 | 58 | 210 | 40.6 | model / log |
| SAM-DETR-R50 | 50 | 58 | 100 | 39.8 | model / log |
| SAM-DETR-R50 w/ SMCA | 50 | 58 | 100 | 41.8 | model / log |
| SAM-DETR-R50-DC5 | 50 | 58 | 210 | 43.3 | model / log |
| SAM-DETR-R50-DC5 w/ SMCA | 50 | 58 | 210 | 45.0 | model / log |
339 | 340 | 341 | 342 | 343 | 344 | Our proposed multi-scale SAM-DETR models (results to appear in a journal extension):
| Method | Epochs | Params (M) | GFLOPs | AP | URL |
|:------:|:------:|:----------:|:------:|:--:|:---:|
| SAM-DETR-R50-MS | 12 | 55 | 203 | 41.1 | model / log |
| SAM-DETR-R50-MS w/ SMCA | 12 | 55 | 203 | 42.8 | model / log |
| SAM-DETR-R50-MS | 50 | 55 | 203 | 46.1 | model / log |
| SAM-DETR-R50-MS w/ SMCA | 50 | 55 | 203 | 47.1 | model / log |
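If you download one of the released checkpoints above, a quick way to inspect it before passing it to `--resume` is sketched below. This is illustrative only: the file name is a placeholder, and the `"model"` key is an assumption based on the DETR-style checkpoint convention, so verify it against `main.py` and the file you actually downloaded.

```python
# Illustrative only: inspect a downloaded checkpoint before using it with --resume.
# The file name is a placeholder, and the "model" key is an assumption following the
# DETR-style checkpoint convention; confirm against main.py / your downloaded file.
import torch

ckpt_path = "sam_detr_r50_smca_e50.pth"             # placeholder file name
checkpoint = torch.load(ckpt_path, map_location="cpu")
print(sorted(checkpoint.keys()))                     # expect an entry like 'model' holding the weights

state_dict = checkpoint["model"]                     # assumed key (DETR convention)
num_params = sum(p.numel() for p in state_dict.values())
print(f"{len(state_dict)} tensors, ~{num_params / 1e6:.1f}M parameters")
```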
391 | 392 | Note: 393 | 1. AP is computed on *COCO val2017*. 394 | 2. "DC5" means removing the stride in C5 stage of ResNet and add a dilation of 2 instead. 395 | 3. The GFLOPs of our models are estimated using [fvcore](https://github.com/facebookresearch/fvcore) on the first 100 images in *COCO val2017*. GFLOPs varies as input image sizes change. There may exist slight difference from actual values. 396 | 397 | 398 | 399 | 400 | 401 | ## License 402 | 403 | The implementation codes of SAM-DETR are released under the MIT license. 404 | 405 | Please see the [LICENSE](LICENSE) file for more information. 406 | 407 | However, prior works' licenses also apply. It is your responsibility to ensure you comply with all license requirements. 408 | 409 | ## Citation 410 | 411 | If you find SAM-DETR useful or inspiring, please consider citing: 412 | 413 | ```bibtex 414 | @inproceedings{zhang2022-SAMDETR, 415 | title = {Accelerating {DETR} Convergence via Semantic-Aligned Matching}, 416 | author = {Zhang, Gongjie and Luo, Zhipeng and Yu, Yingchen and Cui, Kaiwen and Lu, Shijian}, 417 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 418 | pages = {949-958}, 419 | year = {2022}, 420 | } 421 | ``` 422 | 423 | 424 | 425 | ## Acknowledgement 426 | 427 | Our SAM-DETR is heavily inspired by many outstanding prior works, including [DETR](https://github.com/facebookresearch/detr), [Conditional-DETR](https://github.com/Atten4Vis/ConditionalDETR), 428 | [SMCA-DETR](https://github.com/gaopengcuhk/SMCA-DETR), and [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR). 429 | Thank the authors of above projects for open-sourcing their implementation codes! 430 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | """ 7 | Misc functions, including distributed helpers. 8 | 9 | Mostly copy-paste from torchvision references. 10 | """ 11 | import os 12 | import subprocess 13 | import time 14 | from collections import defaultdict, deque 15 | import datetime 16 | import pickle 17 | from typing import Optional, List 18 | 19 | import torch 20 | import torch.distributed as dist 21 | from torch import Tensor 22 | 23 | # needed due to empty tensor bug in pytorch and torchvision 0.5 24 | import torchvision 25 | if float(torchvision.__version__.split(".")[1]) < 7.0: 26 | from torchvision.ops import _new_empty_tensor 27 | from torchvision.ops.misc import _output_size 28 | 29 | 30 | class SmoothedValue(object): 31 | """Track a series of values and provide access to smoothed values over a 32 | window or the global series average. 33 | """ 34 | 35 | def __init__(self, window_size=20, fmt=None): 36 | if fmt is None: 37 | fmt = "{median:.4f} ({global_avg:.4f})" 38 | self.deque = deque(maxlen=window_size) 39 | self.total = 0.0 40 | self.count = 0 41 | self.fmt = fmt 42 | 43 | def update(self, value, n=1): 44 | self.deque.append(value) 45 | self.count += n 46 | self.total += value * n 47 | 48 | def synchronize_between_processes(self): 49 | """ 50 | Warning: does not synchronize the deque! 
51 | """ 52 | if not is_dist_avail_and_initialized(): 53 | return 54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 55 | dist.barrier() 56 | dist.all_reduce(t) 57 | t = t.tolist() 58 | self.count = int(t[0]) 59 | self.total = t[1] 60 | 61 | @property 62 | def median(self): 63 | d = torch.tensor(list(self.deque)) 64 | return d.median().item() 65 | 66 | @property 67 | def avg(self): 68 | d = torch.tensor(list(self.deque), dtype=torch.float32) 69 | return d.mean().item() 70 | 71 | @property 72 | def global_avg(self): 73 | return self.total / self.count 74 | 75 | @property 76 | def max(self): 77 | return max(self.deque) 78 | 79 | @property 80 | def value(self): 81 | return self.deque[-1] 82 | 83 | def __str__(self): 84 | return self.fmt.format( 85 | median=self.median, 86 | avg=self.avg, 87 | global_avg=self.global_avg, 88 | max=self.max, 89 | value=self.value) 90 | 91 | 92 | def all_gather(data): 93 | """ 94 | Run all_gather on arbitrary picklable data (not necessarily tensors) 95 | Args: 96 | data: any picklable object 97 | Returns: 98 | list[data]: list of data gathered from each rank 99 | """ 100 | world_size = get_world_size() 101 | if world_size == 1: 102 | return [data] 103 | 104 | # serialized to a Tensor 105 | buffer = pickle.dumps(data) 106 | storage = torch.ByteStorage.from_buffer(buffer) 107 | tensor = torch.ByteTensor(storage).to("cuda") 108 | 109 | # obtain Tensor size of each rank 110 | local_size = torch.tensor([tensor.numel()], device="cuda") 111 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 112 | dist.all_gather(size_list, local_size) 113 | size_list = [int(size.item()) for size in size_list] 114 | max_size = max(size_list) 115 | 116 | # receiving Tensor from all ranks 117 | # we pad the tensor because torch all_gather does not support 118 | # gathering tensors of different shapes 119 | tensor_list = [] 120 | for _ in size_list: 121 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 122 | if local_size != max_size: 123 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 124 | tensor = torch.cat((tensor, padding), dim=0) 125 | dist.all_gather(tensor_list, tensor) 126 | 127 | data_list = [] 128 | for size, tensor in zip(size_list, tensor_list): 129 | buffer = tensor.cpu().numpy().tobytes()[:size] 130 | data_list.append(pickle.loads(buffer)) 131 | 132 | return data_list 133 | 134 | 135 | def reduce_dict(input_dict, average=True): 136 | """ 137 | Args: 138 | input_dict (dict): all the values will be reduced 139 | average (bool): whether to do average or sum 140 | Reduce the values in the dictionary from all processes so that all processes 141 | have the averaged results. Returns a dict with the same fields as 142 | input_dict, after reduction. 
143 | """ 144 | world_size = get_world_size() 145 | if world_size < 2: 146 | return input_dict 147 | with torch.no_grad(): 148 | names = [] 149 | values = [] 150 | # sort the keys so that they are consistent across processes 151 | for k in sorted(input_dict.keys()): 152 | names.append(k) 153 | values.append(input_dict[k]) 154 | values = torch.stack(values, dim=0) 155 | dist.all_reduce(values) 156 | if average: 157 | values /= world_size 158 | reduced_dict = {k: v for k, v in zip(names, values)} 159 | return reduced_dict 160 | 161 | 162 | class MetricLogger(object): 163 | def __init__(self, delimiter="\t"): 164 | self.meters = defaultdict(SmoothedValue) 165 | self.delimiter = delimiter 166 | 167 | def update(self, **kwargs): 168 | for k, v in kwargs.items(): 169 | if isinstance(v, torch.Tensor): 170 | v = v.item() 171 | assert isinstance(v, (float, int)) 172 | self.meters[k].update(v) 173 | 174 | def __getattr__(self, attr): 175 | if attr in self.meters: 176 | return self.meters[attr] 177 | if attr in self.__dict__: 178 | return self.__dict__[attr] 179 | raise AttributeError("'{}' object has no attribute '{}'".format( 180 | type(self).__name__, attr)) 181 | 182 | def __str__(self): 183 | loss_str = [] 184 | for name, meter in self.meters.items(): 185 | loss_str.append( 186 | "{}: {}".format(name, str(meter)) 187 | ) 188 | return self.delimiter.join(loss_str) 189 | 190 | def synchronize_between_processes(self): 191 | for meter in self.meters.values(): 192 | meter.synchronize_between_processes() 193 | 194 | def add_meter(self, name, meter): 195 | self.meters[name] = meter 196 | 197 | def log_every(self, iterable, print_freq, header=None): 198 | i = 0 199 | if not header: 200 | header = '' 201 | start_time = time.time() 202 | end = time.time() 203 | iter_time = SmoothedValue(fmt='{avg:.4f}') 204 | data_time = SmoothedValue(fmt='{avg:.4f}') 205 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 206 | if torch.cuda.is_available(): 207 | log_msg = self.delimiter.join([ 208 | header, 209 | '[{0' + space_fmt + '}/{1}]', 210 | 'eta: {eta}', 211 | '{meters}', 212 | 'time: {time}', 213 | 'data: {data}', 214 | 'max mem: {memory:.0f}' 215 | ]) 216 | else: 217 | log_msg = self.delimiter.join([ 218 | header, 219 | '[{0' + space_fmt + '}/{1}]', 220 | 'eta: {eta}', 221 | '{meters}', 222 | 'time: {time}', 223 | 'data: {data}' 224 | ]) 225 | MB = 1024.0 * 1024.0 226 | for obj in iterable: 227 | data_time.update(time.time() - end) 228 | yield obj 229 | iter_time.update(time.time() - end) 230 | if i % print_freq == 0 or i == len(iterable) - 1: 231 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 232 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 233 | if torch.cuda.is_available(): 234 | print(log_msg.format( 235 | i, len(iterable), eta=eta_string, 236 | meters=str(self), 237 | time=str(iter_time), data=str(data_time), 238 | memory=torch.cuda.max_memory_allocated() / MB)) 239 | else: 240 | print(log_msg.format( 241 | i, len(iterable), eta=eta_string, 242 | meters=str(self), 243 | time=str(iter_time), data=str(data_time))) 244 | i += 1 245 | end = time.time() 246 | total_time = time.time() - start_time 247 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 248 | print('{} Total time: {} ({:.4f} s / it)'.format( 249 | header, total_time_str, total_time / len(iterable))) 250 | 251 | 252 | def get_sha(): 253 | cwd = os.path.dirname(os.path.abspath(__file__)) 254 | 255 | def _run(command): 256 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 
257 | sha = 'N/A' 258 | diff = "clean" 259 | branch = 'N/A' 260 | try: 261 | sha = _run(['git', 'rev-parse', 'HEAD']) 262 | subprocess.check_output(['git', 'diff'], cwd=cwd) 263 | diff = _run(['git', 'diff-index', 'HEAD']) 264 | diff = "has uncommited changes" if diff else "clean" 265 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 266 | except Exception: 267 | pass 268 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 269 | return message 270 | 271 | 272 | def collate_fn(batch): 273 | batch = list(zip(*batch)) 274 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 275 | return tuple(batch) 276 | 277 | 278 | def _max_by_axis(the_list): 279 | # type: (List[List[int]]) -> List[int] 280 | maxes = the_list[0] 281 | for sublist in the_list[1:]: 282 | for index, item in enumerate(sublist): 283 | maxes[index] = max(maxes[index], item) 284 | return maxes 285 | 286 | 287 | class NestedTensor(object): 288 | def __init__(self, tensors, mask: Optional[Tensor]): 289 | self.tensors = tensors 290 | self.mask = mask 291 | 292 | def to(self, device): 293 | # type: (Device) -> NestedTensor # noqa 294 | cast_tensor = self.tensors.to(device) 295 | mask = self.mask 296 | if mask is not None: 297 | assert mask is not None 298 | cast_mask = mask.to(device) 299 | else: 300 | cast_mask = None 301 | return NestedTensor(cast_tensor, cast_mask) 302 | 303 | def decompose(self): 304 | return self.tensors, self.mask 305 | 306 | def __repr__(self): 307 | return str(self.tensors) 308 | 309 | 310 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 311 | # TODO make this more general 312 | if tensor_list[0].ndim == 3: 313 | if torchvision._is_tracing(): 314 | # nested_tensor_from_tensor_list() does not export well to ONNX 315 | # call _onnx_nested_tensor_from_tensor_list() instead 316 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 317 | 318 | # TODO make it support different-sized images 319 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 320 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 321 | batch_shape = [len(tensor_list)] + max_size 322 | b, c, h, w = batch_shape 323 | dtype = tensor_list[0].dtype 324 | device = tensor_list[0].device 325 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 326 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 327 | for img, pad_img, m in zip(tensor_list, tensor, mask): 328 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 329 | m[: img.shape[1], :img.shape[2]] = False 330 | else: 331 | raise ValueError('not supported') 332 | return NestedTensor(tensor, mask) 333 | 334 | 335 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 336 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
337 | @torch.jit.unused 338 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 339 | max_size = [] 340 | for i in range(tensor_list[0].dim()): 341 | max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) 342 | max_size.append(max_size_i) 343 | max_size = tuple(max_size) 344 | 345 | # work around for 346 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 347 | # m[: img.shape[1], :img.shape[2]] = False 348 | # which is not yet supported in onnx 349 | padded_imgs = [] 350 | padded_masks = [] 351 | for img in tensor_list: 352 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 353 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 354 | padded_imgs.append(padded_img) 355 | 356 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 357 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 358 | padded_masks.append(padded_mask.to(torch.bool)) 359 | 360 | tensor = torch.stack(padded_imgs) 361 | mask = torch.stack(padded_masks) 362 | 363 | return NestedTensor(tensor, mask=mask) 364 | 365 | 366 | def setup_for_distributed(is_master): 367 | """ 368 | This function disables printing when not in master process 369 | """ 370 | import builtins as __builtin__ 371 | builtin_print = __builtin__.print 372 | 373 | def print(*args, **kwargs): 374 | force = kwargs.pop('force', False) 375 | if is_master or force: 376 | builtin_print(*args, **kwargs) 377 | 378 | __builtin__.print = print 379 | 380 | 381 | def is_dist_avail_and_initialized(): 382 | if not dist.is_available(): 383 | return False 384 | if not dist.is_initialized(): 385 | return False 386 | return True 387 | 388 | 389 | def get_world_size(): 390 | if not is_dist_avail_and_initialized(): 391 | return 1 392 | return dist.get_world_size() 393 | 394 | 395 | def get_rank(): 396 | if not is_dist_avail_and_initialized(): 397 | return 0 398 | return dist.get_rank() 399 | 400 | 401 | def is_main_process(): 402 | return get_rank() == 0 403 | 404 | 405 | def save_on_master(*args, **kwargs): 406 | if is_main_process(): 407 | torch.save(*args, **kwargs) 408 | 409 | 410 | def init_distributed_mode(args): 411 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 412 | args.rank = int(os.environ["RANK"]) 413 | args.world_size = int(os.environ['WORLD_SIZE']) 414 | args.gpu = int(os.environ['LOCAL_RANK']) 415 | elif 'SLURM_PROCID' in os.environ: 416 | proc_id = int(os.environ['SLURM_PROCID']) 417 | ntasks = int(os.environ['SLURM_NTASKS']) 418 | node_list = os.environ['SLURM_NODELIST'] 419 | num_gpus = torch.cuda.device_count() 420 | addr = subprocess.getoutput('scontrol show hostname {} | head -n1'.format(node_list)) 421 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29686') 422 | os.environ['MASTER_ADDR'] = addr 423 | os.environ['WORLD_SIZE'] = str(ntasks) 424 | os.environ['RANK'] = str(proc_id) 425 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 426 | os.environ['LOCAL_SIZE'] = str(num_gpus) 427 | args.dist_url = 'env://' 428 | args.world_size = ntasks 429 | args.rank = proc_id 430 | args.gpu = proc_id % num_gpus 431 | else: 432 | print('Not using distributed mode') 433 | args.distributed = False 434 | return 435 | 436 | args.distributed = True 437 | 438 | torch.cuda.set_device(args.gpu) 439 | args.dist_backend = 'nccl' 440 | print('| distributed init (rank {}): {}'.format( 441 | args.rank, args.dist_url), flush=True) 442 | 
torch.distributed.init_process_group(backend=args.dist_backend, 443 | init_method=args.dist_url, 444 | world_size=args.world_size, 445 | rank=args.rank) 446 | torch.distributed.barrier() 447 | setup_for_distributed(args.rank == 0) 448 | 449 | 450 | @torch.no_grad() 451 | def accuracy(output, target, topk=(1,)): 452 | """Computes the precision@k for the specified values of k""" 453 | if target.numel() == 0: 454 | return [torch.zeros([], device=output.device)] 455 | maxk = max(topk) 456 | batch_size = target.size(0) 457 | 458 | _, pred = output.topk(maxk, 1, True, True) 459 | pred = pred.t() 460 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 461 | 462 | res = [] 463 | for k in topk: 464 | correct_k = correct[:k].view(-1).float().sum(0) 465 | res.append(correct_k.mul_(100.0 / batch_size)) 466 | return res 467 | 468 | 469 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 470 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 471 | """ 472 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 473 | This will eventually be supported natively by PyTorch, and this 474 | class can go away. 475 | """ 476 | if float(torchvision.__version__.split(".")[1]) < 7.0: 477 | if input.numel() > 0: 478 | return torch.nn.functional.interpolate( 479 | input, size, scale_factor, mode, align_corners 480 | ) 481 | 482 | output_shape = _output_size(2, input, size, scale_factor) 483 | output_shape = list(input.shape[:-2]) + list(output_shape) 484 | return _new_empty_tensor(input, output_shape) 485 | else: 486 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 487 | 488 | 489 | def inverse_sigmoid(x, eps=1e-5): 490 | x = x.clamp(min=0, max=1) 491 | x1 = x.clamp(min=eps) 492 | x2 = (1 - x).clamp(min=eps) 493 | return torch.log(x1 / x2) 494 | --------------------------------------------------------------------------------
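As a closing illustration of the helpers above, the short usage sketch below (not part of the repository; it assumes the code root is on `PYTHONPATH` so that `util.misc` is importable) batches two differently sized images into a `NestedTensor` and checks that `inverse_sigmoid` is the clamped inverse of `torch.sigmoid`.

```python
# Usage sketch for the helpers defined above (illustrative; assumes the code root is on PYTHONPATH).
import torch
from util.misc import nested_tensor_from_tensor_list, inverse_sigmoid

# Two images of different spatial sizes, as typically produced by the COCO transforms.
imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
samples = nested_tensor_from_tensor_list(imgs)

tensors, mask = samples.decompose()
print(tensors.shape)   # torch.Size([2, 3, 512, 640]), padded to the per-batch maximum size
print(mask.shape)      # torch.Size([2, 512, 640]), True marks padded (invalid) pixels
print(bool(mask[0].any()), bool(mask[1].any()))   # both images needed some padding here

# inverse_sigmoid is the (clamped) inverse of torch.sigmoid
x = torch.tensor([0.1, 0.5, 0.9])
print(torch.allclose(torch.sigmoid(inverse_sigmoid(x)), x, atol=1e-4))   # True
```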