├── LICENSE ├── README.md ├── configs └── CIS-R50.yaml ├── dcnet ├── __init__.py ├── config.py ├── data │ ├── __init__.py │ ├── dataset_mappers │ │ ├── __init__.py │ │ └── mapper_with_Fourier_amplitude.py │ └── datasets │ │ ├── __init__.py │ │ └── register_cis.py ├── dcnet.py ├── modeling │ ├── ICS │ │ ├── __init__.py │ │ ├── ics.py │ │ ├── position_encoding.py │ │ └── reference_attention.py │ ├── PCD │ │ ├── __init__.py │ │ ├── difference_attention.py │ │ ├── ops │ │ │ ├── functions │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn_func.py │ │ │ ├── make.sh │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ └── ms_deform_attn.py │ │ │ ├── setup.py │ │ │ ├── src │ │ │ │ ├── cpu │ │ │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ │ │ └── ms_deform_attn_cpu.h │ │ │ │ ├── cuda │ │ │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ │ └── ms_deform_im2col_cuda.cuh │ │ │ │ ├── ms_deform_attn.h │ │ │ │ └── vision.cpp │ │ │ └── test.py │ │ └── pcd.py │ ├── __init__.py │ ├── criterion.py │ └── matcher.py └── utils │ ├── __init__.py │ └── misc.py ├── framework.png ├── requirements.txt └── train_net.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 USTCL 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Camouflaged Instance Segmentation via Explicit De-camouflaging 2 | 3 | Official Implementation of CVPR2023 Highlight [paper](http://openaccess.thecvf.com/content/CVPR2023/html/Luo_Camouflaged_Instance_Segmentation_via_Explicit_De-Camouflaging_CVPR_2023_paper.html) "Camouflaged Instance Segmentation via Explicit De-camouflaging" 4 | 5 | ## DCNet 6 | 7 | ![Alt text](framework.png) 8 | 9 | We propose a novel De-camouflaging Network (DCNet) by jointly modeling pixel-level camouflage decoupling and instance-level camouflage suppression for Camouflaged Instance Segmentation (CIS) task. 10 | 11 | 12 | ## Environment preparation 13 | 14 | The code is tested on CUDA 11.3 and pytorch 1.10.1, change the versions below to your desired ones. 
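If you need to match a different CUDA setup, you can first check what your machine provides and then adapt the versions accordingly (an optional sanity check, not part of the original instructions):

```shell
nvidia-smi       # driver version and the highest CUDA runtime it supports
nvcc --version   # local CUDA toolkit version, if one is installed
# after the installation below, verify the environment:
python -c "import torch, detectron2; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
```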
15 | 16 | ```shell 17 | conda create -n dcnet python=3.9 -y 18 | conda activate dcnet 19 | conda install pytorch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1 cudatoolkit=11.3 -c pytorch -c conda-forge 20 | python -m pip install detectron2 -f \ 21 | https://dl.fbaipublicfiles.com/detectron2/wheels/cu113/torch1.10/index.html 22 | 23 | cd DCNet 24 | pip install -r requirements.txt 25 | cd dcnet/modeling/PCD/ops 26 | sh make.sh 27 | ``` 28 | 29 | ## Dataset preparation 30 | 31 | ### Download the datasets 32 | 33 | - **COD10K**: [Google](https://drive.google.com/file/d/1YGa3v-MiXy-3MMJDkidLXPt0KQwygt-Z/view?usp=sharing) **JSON files:** [Google](https://drive.google.com/drive/folders/1Yvz63C8c7LOHFRgm06viUM9XupARRPif?usp=sharing) 34 | - **NC4K**: [Google](https://drive.google.com/file/d/1eK_oi-N4Rmo6IIxUNbYHBiNWuDDLGr_k/view?usp=sharing); **JSON files:** [Google](https://drive.google.com/drive/folders/1LyK7tl2QVZBFiNaWI_n0ZVa0QiwF2B8e?usp=sharing) 35 | 36 | ### Register datasets 37 | 38 | Change the paths of the datasets and annotations in `dcnet/data/datasets/register_cis.py`. 39 | 40 | ```python 41 | # dcnet/data/datasets/register_cis.py 42 | # change the paths 43 | COD10K_ROOT = 'COD10K' # path to your COD10K dataset 44 | ANN_ROOT = os.path.join(COD10K_ROOT, 'annotations') 45 | TRAIN_PATH = os.path.join(COD10K_ROOT, 'Train_Image_CAM') 46 | TEST_PATH = os.path.join(COD10K_ROOT, 'Test_Image_CAM') 47 | TRAIN_JSON = os.path.join(ANN_ROOT, 'train_instance.json') 48 | TEST_JSON = os.path.join(ANN_ROOT, 'test2026.json') 49 | 50 | NC4K_ROOT = 'NC4K' # path to your NC4K dataset 51 | NC4K_PATH = os.path.join(NC4K_ROOT, 'test/image') 52 | NC4K_JSON = os.path.join(NC4K_ROOT, 'nc4k_test.json') 53 | ``` 54 | 55 | ## Train 56 | 57 | ```shell 58 | python train_net.py \ 59 | --config-file configs/CIS-R50.yaml \ 60 | MODEL.WEIGHTS {PATH_TO_PRE_TRAINED_WEIGHTS} 61 | ``` 62 | 63 | ## Pre-trained models 64 | 65 | DCNet model (ResNet-50) weights: [Google](https://drive.google.com/file/d/1xeB_F713KiGHhMSvwBcGmlgKap1IbOvQ/view?usp=sharing) 66 | 67 | 68 | ## Evaluation 69 | 70 | ```shell 71 | python train_net.py \ 72 | --eval-only \ 73 | --config-file configs/CIS-R50.yaml \ 74 | MODEL.WEIGHTS {PATH_TO_PRE_TRAINED_WEIGHTS} 75 | ``` 76 | 77 | Please replace `{PATH_TO_PRE_TRAINED_WEIGHTS}` with the path to the pre-trained weights. 78 | 79 | ## Citation 80 | 81 | If you find this code useful for your research, please cite our paper: 82 | ``` 83 | @inproceedings{luo2023camouflaged, 84 | title={Camouflaged Instance Segmentation via Explicit De-Camouflaging}, 85 | author={Luo, Naisong and Pan, Yuwen and Sun, Rui and Zhang, Tianzhu and Xiong, Zhiwei and Wu, Feng}, 86 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 87 | pages={17918--17927}, 88 | year={2023} 89 | } 90 | ``` 91 | 92 | ## Acknowledgements 93 | 94 | Some code is adapted from [OSFormer](https://github.com/PJLallen/OSFormer) and [Mask2Former](https://github.com/facebookresearch/Mask2Former). We thank them for their excellent projects.
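## Fourier amplitude input (illustrative)

Besides the RGB image, the dataset mapper (`dcnet/data/dataset_mappers/mapper_with_Fourier_amplitude.py`) feeds the model an amplitude-only image, obtained by keeping the Fourier amplitude of the input and replacing its phase with a constant. The snippet below is a standalone sketch of that step; the image path and the use of OpenCV for loading are illustrative only and not part of the training pipeline.

```python
import numpy as np
import cv2  # used here only to load an example image

image = cv2.imread("example.jpg")                    # hypothetical H x W x 3 uint8 image
fre = np.fft.fft2(image, axes=(0, 1))                # per-channel 2D FFT
fre_a = np.abs(fre)                                  # amplitude spectrum (kept)
fre_p = np.angle(fre)                                # phase spectrum (discarded)

constant = fre_p.mean()                              # constant phase, as in the mapper
img_a = np.abs(np.fft.ifft2(fre_a * np.e ** (1j * constant), axes=(0, 1)))
img_a = np.clip(img_a, 0, 255).astype(np.uint8)      # becomes the "image_a" input to DCNet
```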
95 | -------------------------------------------------------------------------------- /configs/CIS-R50.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | BACKBONE: 3 | FREEZE_AT: 0 4 | NAME: "build_resnet_backbone" 5 | WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" 6 | RESNETS: 7 | DEPTH: 50 8 | STEM_TYPE: "basic" # not used 9 | STEM_OUT_CHANNELS: 64 10 | STRIDE_IN_1X1: False 11 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 12 | META_ARCHITECTURE: "DCNet" 13 | SEM_SEG_HEAD: 14 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 15 | DCNET: 16 | DEEP_SUPERVISION: True 17 | CLASS_WEIGHT: 1.0 18 | DICE_WEIGHT: 1.0 19 | MASK_WEIGHT: 20.0 20 | NUM_OBJECT_QUERIES: 10 21 | DEC_LAYERS: 6 22 | DATASETS: 23 | TRAIN: ("cod10k_train",) 24 | TEST: ("cod10k_test", "nc4k_test") 25 | SOLVER: 26 | IMS_PER_BATCH: 2 27 | BASE_LR: 0.0001 28 | STEPS: (70000, 90000) 29 | MAX_ITER: 100000 30 | WARMUP_FACTOR: 1.0 31 | WARMUP_ITERS: 10 32 | WEIGHT_DECAY: 0.05 33 | OPTIMIZER: "ADAMW" 34 | BACKBONE_MULTIPLIER: 0.1 35 | AMP: 36 | ENABLED: True 37 | INPUT: 38 | IMAGE_SIZE: 1024 39 | MIN_SCALE: 0.1 40 | MAX_SCALE: 2.0 41 | FORMAT: "RGB" 42 | TEST: 43 | EVAL_PERIOD: 5000 44 | DETECTIONS_PER_IMAGE: 10 45 | DATALOADER: 46 | FILTER_EMPTY_ANNOTATIONS: True 47 | NUM_WORKERS: 4 48 | OUTPUT_DIR: output -------------------------------------------------------------------------------- /dcnet/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data 2 | from . import modeling 3 | 4 | # config 5 | from .config import add_dcnet_config 6 | 7 | # dataset loading 8 | from .data.datasets.register_cis import register_dataset 9 | from .data.dataset_mappers.mapper_with_Fourier_amplitude import DatasetMapper_Fourier_amplitude 10 | 11 | # model 12 | from .dcnet import DCNet -------------------------------------------------------------------------------- /dcnet/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode as CN 2 | 3 | 4 | def add_dcnet_config(cfg): 5 | """ 6 | Add config for DCNET. 7 | """ 8 | 9 | # Color augmentation 10 | cfg.INPUT.COLOR_AUG_SSD = False 11 | 12 | # Pad image and segmentation GT in dataset mapper. 
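    # (-1 disables padding at this stage; the model itself pads batched images to a
    # multiple of MODEL.DCNET.SIZE_DIVISIBILITY via ImageList.from_tensors, see dcnet/dcnet.py)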
13 | cfg.INPUT.SIZE_DIVISIBILITY = -1 14 | 15 | 16 | # optimizer 17 | cfg.SOLVER.OPTIMIZER = "ADAMW" 18 | cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 19 | 20 | # DCNet model config 21 | cfg.MODEL.DCNET = CN() 22 | 23 | # loss 24 | cfg.MODEL.DCNET.DEEP_SUPERVISION = True 25 | # cfg.MODEL.DCNET.NO_OBJECT_WEIGHT = 0.1 26 | cfg.MODEL.DCNET.CLASS_WEIGHT = 1.0 27 | cfg.MODEL.DCNET.DICE_WEIGHT = 1.0 28 | cfg.MODEL.DCNET.MASK_WEIGHT = 20.0 29 | cfg.MODEL.DCNET.DEC_LAYERS = 6 30 | cfg.MODEL.DCNET.NUM_OBJECT_QUERIES = 10 31 | cfg.MODEL.DCNET.SIZE_DIVISIBILITY = 32 32 | 33 | 34 | # LSJ aug 35 | cfg.INPUT.IMAGE_SIZE = 1024 36 | cfg.INPUT.MIN_SCALE = 0.1 37 | cfg.INPUT.MAX_SCALE = 2.0 38 | 39 | # MSDeformAttn configs 40 | cfg.MODEL.SEM_SEG_HEAD.DEFORMABLE_TRANSFORMER_ENCODER_N_POINTS = 4 -------------------------------------------------------------------------------- /dcnet/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTCL/DCNet/f3c9098d1e0696cae8a5cfe59f952487d089ac4c/dcnet/data/__init__.py -------------------------------------------------------------------------------- /dcnet/data/dataset_mappers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTCL/DCNet/f3c9098d1e0696cae8a5cfe59f952487d089ac4c/dcnet/data/dataset_mappers/__init__.py -------------------------------------------------------------------------------- /dcnet/data/dataset_mappers/mapper_with_Fourier_amplitude.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from detectron2.config import configurable 8 | from detectron2.data import detection_utils as utils 9 | from detectron2.data import transforms as T 10 | from detectron2.data.transforms import TransformGen 11 | from detectron2.structures import BitMasks, Instances 12 | 13 | from pycocotools import mask as coco_mask 14 | 15 | __all__ = ["DatasetMapper_Fourier_amplitude"] 16 | 17 | 18 | def convert_coco_poly_to_mask(segmentations, height, width): 19 | masks = [] 20 | for polygons in segmentations: 21 | rles = coco_mask.frPyObjects(polygons, height, width) 22 | mask = coco_mask.decode(rles) 23 | if len(mask.shape) < 3: 24 | mask = mask[..., None] 25 | mask = torch.as_tensor(mask, dtype=torch.uint8) 26 | mask = mask.any(dim=2) 27 | masks.append(mask) 28 | if masks: 29 | masks = torch.stack(masks, dim=0) 30 | else: 31 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 32 | return masks 33 | 34 | def img_cut(img): 35 | img[img < 0] = 0 36 | img[img > 255] = 255 37 | return img.astype(np.uint8) 38 | 39 | def build_transform_gen(cfg, is_train): 40 | """ 41 | Create a list of default :class:`Augmentation` from config. 42 | Now it includes resizing and flipping. 
43 | Returns: 44 | list[Augmentation] 45 | """ 46 | assert is_train, "Only support training augmentation" 47 | image_size = cfg.INPUT.IMAGE_SIZE 48 | min_scale = cfg.INPUT.MIN_SCALE 49 | max_scale = cfg.INPUT.MAX_SCALE 50 | 51 | augmentation = [] 52 | 53 | if cfg.INPUT.RANDOM_FLIP != "none": 54 | augmentation.append( 55 | T.RandomFlip( 56 | horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal", 57 | vertical=cfg.INPUT.RANDOM_FLIP == "vertical", 58 | ) 59 | ) 60 | 61 | augmentation.extend([ 62 | T.ResizeScale( 63 | min_scale=min_scale, max_scale=max_scale, target_height=image_size, target_width=image_size 64 | ), 65 | T.FixedSizeCrop(crop_size=(image_size, image_size)), 66 | ]) 67 | 68 | return augmentation 69 | 70 | def build_transform_test(cfg, is_train): 71 | """ 72 | Create a list of default :class:`Augmentation` from config. 73 | Now it includes resizing and flipping. 74 | Returns: 75 | list[Augmentation] 76 | """ 77 | assert not is_train, "Only support testing augmentation" 78 | image_size = cfg.INPUT.IMAGE_SIZE 79 | 80 | augmentation = [] 81 | 82 | augmentation.extend([ 83 | T.ResizeScale( 84 | min_scale=1, max_scale=1, target_height=image_size, target_width=image_size 85 | )] 86 | ) 87 | 88 | return augmentation 89 | 90 | class DatasetMapper_Fourier_amplitude: 91 | """ 92 | A callable which takes a dataset dict in Detectron2 Dataset format, 93 | and map it into a format used by DCNet. 94 | 95 | This dataset mapper applies the same transformation as DETR for COCO panoptic segmentation. 96 | 97 | The callable currently does the following: 98 | 99 | 1. Read the image from "file_name" 100 | 2. Applies geometric transforms to the image and annotation 101 | 3. Find and applies suitable cropping to the image and annotation 102 | 4. Prepare image and annotation to Tensors 103 | """ 104 | 105 | @configurable 106 | def __init__( 107 | self, 108 | is_train=True, 109 | *, 110 | tfm_gens, 111 | image_format, 112 | ): 113 | """ 114 | NOTE: this interface is experimental. 115 | Args: 116 | is_train: for training or inference 117 | augmentations: a list of augmentations or deterministic transforms to apply 118 | tfm_gens: data augmentation 119 | image_format: an image format supported by :func:`detection_utils.read_image`. 120 | """ 121 | self.tfm_gens = tfm_gens 122 | logging.getLogger(__name__).info( 123 | "[DatasetMapper_Fourier_amplitude] Full TransformGens used in training: {}".format(str(self.tfm_gens)) 124 | ) 125 | 126 | self.img_format = image_format 127 | self.is_train = is_train 128 | 129 | @classmethod 130 | def from_config(cls, cfg, is_train=True): 131 | # Build augmentation 132 | if is_train: 133 | tfm_gens = build_transform_gen(cfg, is_train) 134 | else: 135 | tfm_gens = build_transform_test(cfg, is_train) 136 | 137 | ret = { 138 | "is_train": is_train, 139 | "tfm_gens": tfm_gens, 140 | "image_format": cfg.INPUT.FORMAT, 141 | } 142 | return ret 143 | 144 | def __call__(self, dataset_dict): 145 | """ 146 | Args: 147 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 
148 | 149 | Returns: 150 | dict: a format that builtin models in detectron2 accept 151 | """ 152 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 153 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 154 | utils.check_image_size(dataset_dict, image) 155 | 156 | 157 | # TODO: get padding mask 158 | # by feeding a "segmentation mask" to the same transforms 159 | padding_mask = np.ones(image.shape[:2]) 160 | 161 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 162 | # the crop transformation has default padding value 0 for segmentation 163 | padding_mask = transforms.apply_segmentation(padding_mask) 164 | padding_mask = ~ padding_mask.astype(bool) 165 | 166 | image_shape = image.shape[:2] # h, w 167 | 168 | # get Fourier amplitude and phase 169 | fre = np.fft.fft2(image,axes=(0,1)) 170 | fre_a = np.abs(fre) 171 | fre_p = np.angle(fre) 172 | 173 | # keep the amplitude and set the phase to a constant 174 | constant = fre_p.mean() 175 | fre_ = fre_a * np.e**(1j*constant) 176 | img_a = np.abs(np.fft.ifft2(fre_,axes=(0,1))) 177 | img_a = img_cut(img_a) 178 | 179 | dataset_dict["image_a"] = torch.as_tensor(np.ascontiguousarray(img_a.transpose(2, 0, 1))) 180 | 181 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 182 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 183 | # Therefore it's important to use torch.Tensor. 184 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 185 | dataset_dict["padding_mask"] = torch.as_tensor(np.ascontiguousarray(padding_mask)) 186 | 187 | if not self.is_train: 188 | # USER: Modify this if you want to keep them for some reason. 189 | dataset_dict.pop("annotations", None) 190 | return dataset_dict 191 | 192 | if "annotations" in dataset_dict: 193 | # USER: Modify this if you want to keep them for some reason. 194 | for anno in dataset_dict["annotations"]: 195 | # Let's always keep mask 196 | # if not self.mask_on: 197 | # anno.pop("segmentation", None) 198 | anno.pop("keypoints", None) 199 | 200 | # USER: Implement additional transformations if you have other types of data 201 | annos = [ 202 | utils.transform_instance_annotations(obj, transforms, image_shape) 203 | for obj in dataset_dict.pop("annotations") 204 | if obj.get("iscrowd", 0) == 0 205 | ] 206 | # NOTE: does not support BitMask due to augmentation 207 | # Current BitMask cannot handle empty objects 208 | instances = utils.annotations_to_instances(annos, image_shape) 209 | # After transforms such as cropping are applied, the bounding box may no longer 210 | # tightly bound the object. As an example, imagine a triangle object 211 | # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight 212 | # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to 213 | # the intersection of original bounding box and the cropping box.
214 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 215 | # Need to filter empty instances first (due to augmentation) 216 | instances = utils.filter_empty_instances(instances) 217 | # Generate masks from polygon 218 | h, w = instances.image_size 219 | # image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float) 220 | if hasattr(instances, 'gt_masks'): 221 | gt_masks = instances.gt_masks 222 | gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) 223 | instances.gt_masks = gt_masks 224 | dataset_dict["instances"] = instances 225 | 226 | return dataset_dict 227 | -------------------------------------------------------------------------------- /dcnet/data/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTCL/DCNet/f3c9098d1e0696cae8a5cfe59f952487d089ac4c/dcnet/data/datasets/__init__.py -------------------------------------------------------------------------------- /dcnet/data/datasets/register_cis.py: -------------------------------------------------------------------------------- 1 | import os 2 | from detectron2.data.datasets.coco import load_coco_json 3 | from detectron2.data import MetadataCatalog, DatasetCatalog 4 | 5 | COD10K_ROOT = '/data1/dataset/COD10K' 6 | ANN_ROOT = os.path.join(COD10K_ROOT, 'annotations') 7 | TRAIN_PATH = os.path.join(COD10K_ROOT, 'Train_Image_CAM') 8 | TEST_PATH = os.path.join(COD10K_ROOT, 'Test_Image_CAM') 9 | TRAIN_JSON = os.path.join(ANN_ROOT, 'train_instance.json') 10 | TEST_JSON = os.path.join(ANN_ROOT, 'test2026.json') 11 | 12 | NC4K_ROOT = '/data1/dataset/NC4K/NC4K' 13 | NC4K_PATH = os.path.join(NC4K_ROOT, 'test/image') 14 | NC4K_JSON = os.path.join(NC4K_ROOT, 'nc4k_test.json') 15 | 16 | CLASS_NAMES = ["foreground"] 17 | 18 | PREDEFINED_SPLITS_DATASET = { 19 | "cod10k_train": (TRAIN_PATH, TRAIN_JSON), 20 | "cod10k_test": (TEST_PATH, TEST_JSON), 21 | "nc4k_test": (NC4K_PATH, NC4K_JSON), 22 | } 23 | 24 | 25 | def register_dataset(): 26 | """ 27 | purpose: register all splits of dataset with PREDEFINED_SPLITS_DATASET 28 | """ 29 | for key, (image_root, json_file) in PREDEFINED_SPLITS_DATASET.items(): 30 | register_dataset_instances(name=key, 31 | json_file=json_file, 32 | image_root=image_root) 33 | 34 | 35 | def register_dataset_instances(name, json_file, image_root): 36 | """ 37 | purpose: register dataset to DatasetCatalog, 38 | register metadata to MetadataCatalog and set attribute 39 | """ 40 | DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name)) 41 | MetadataCatalog.get(name).set(json_file=json_file, 42 | image_root=image_root, 43 | evaluator_type="coco") -------------------------------------------------------------------------------- /dcnet/dcnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | # from detectron2.data import MetadataCatalog 6 | from detectron2.modeling import build_backbone, META_ARCH_REGISTRY 7 | from detectron2.modeling.postprocessing import sem_seg_postprocess 8 | from detectron2.structures import Boxes, ImageList, Instances 9 | from detectron2.utils.memory import retry_if_cuda_oom 10 | 11 | from .modeling.criterion import SetCriterion 12 | from .modeling.matcher import HungarianMatcher 13 | from .modeling import pcd, ics 14 | 15 | @META_ARCH_REGISTRY.register() 16 | class DCNet(nn.Module): 17 | 18 | def __init__(self, cfg): 19 | 20 | super().__init__() 21 | 
self.backbone = build_backbone(cfg) 22 | self.PCD = pcd(cfg, self.backbone.output_shape()) 23 | self.ICP = ics(cfg) 24 | 25 | matcher = HungarianMatcher( 26 | cost_class=cfg.MODEL.DCNET.CLASS_WEIGHT, 27 | cost_mask=cfg.MODEL.DCNET.MASK_WEIGHT, 28 | cost_dice=cfg.MODEL.DCNET.DICE_WEIGHT, 29 | num_points=112 ** 2, 30 | ) 31 | weight_dict = { 32 | "loss_ce": cfg.MODEL.DCNET.CLASS_WEIGHT, 33 | "loss_mask": cfg.MODEL.DCNET.MASK_WEIGHT, 34 | "loss_dice": cfg.MODEL.DCNET.DICE_WEIGHT 35 | } 36 | 37 | if cfg.MODEL.DCNET.DEEP_SUPERVISION: 38 | dec_layers = cfg.MODEL.DCNET.DEC_LAYERS 39 | aux_weight_dict = {} 40 | for i in range(dec_layers - 1): 41 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 42 | weight_dict.update(aux_weight_dict) 43 | 44 | self.criterion = SetCriterion( 45 | num_classes=1, 46 | matcher=matcher, 47 | weight_dict=weight_dict, 48 | eos_coef=0.1, 49 | losses=["labels", "masks"], 50 | num_points=112 ** 2, 51 | oversample_ratio=3, 52 | importance_sample_ratio=0.75, 53 | ) 54 | self.size_divisibility = cfg.MODEL.DCNET.SIZE_DIVISIBILITY 55 | 56 | pixel_mean = [123.675, 116.280, 103.530] 57 | pixel_std = [58.395, 57.120, 57.375] 58 | self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) 59 | self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) 60 | 61 | # additional args 62 | self.num_queries = cfg.MODEL.DCNET.NUM_OBJECT_QUERIES 63 | self.test_topk_per_image = cfg.TEST.DETECTIONS_PER_IMAGE 64 | 65 | @property 66 | def device(self): 67 | return self.pixel_mean.device 68 | 69 | def forward(self, batched_inputs): 70 | 71 | images = [x["image"].to(self.device) for x in batched_inputs] 72 | images = [(x - self.pixel_mean) / self.pixel_std for x in images] 73 | images_a = [x["image_a"].to(self.device) for x in batched_inputs] 74 | images_a = [(x - self.pixel_mean) / self.pixel_std for x in images_a] 75 | 76 | images = ImageList.from_tensors(images, self.size_divisibility) 77 | images_a = ImageList.from_tensors(images_a, self.size_divisibility) 78 | 79 | features = self.backbone(images.tensor) 80 | dc_pixel_features, pixel_embedding = self.PCD(features, images_a.tensor) 81 | outputs = self.ICP(dc_pixel_features, pixel_embedding) 82 | 83 | if self.training: 84 | # mask classification target 85 | if "instances" in batched_inputs[0]: 86 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 87 | targets = self.prepare_targets(gt_instances, images) 88 | else: 89 | targets = None 90 | 91 | # bipartite matching-based loss 92 | losses = self.criterion(outputs, targets) 93 | 94 | for k in list(losses.keys()): 95 | if k in self.criterion.weight_dict: 96 | losses[k] *= self.criterion.weight_dict[k] 97 | else: 98 | # remove this loss if not specified in `weight_dict` 99 | losses.pop(k) 100 | return losses 101 | else: 102 | mask_cls_results = outputs["pred_logits"] 103 | mask_pred_results = outputs["pred_masks"] 104 | # upsample masks 105 | mask_pred_results = F.interpolate( 106 | mask_pred_results, 107 | size=(images.tensor.shape[-2], images.tensor.shape[-1]), 108 | mode="bilinear", 109 | align_corners=False, 110 | ) 111 | 112 | del outputs 113 | 114 | processed_results = [] 115 | for mask_cls_result, mask_pred_result, input_per_image, image_size in zip( 116 | mask_cls_results, mask_pred_results, batched_inputs, images.image_sizes 117 | ): 118 | height = input_per_image.get("height", image_size[0]) 119 | width = input_per_image.get("width", image_size[1]) 120 | processed_results.append({}) 121 | 122 | 
mask_pred_result = retry_if_cuda_oom(sem_seg_postprocess)( 123 | mask_pred_result, image_size, height, width 124 | ) 125 | mask_cls_result = mask_cls_result.to(mask_pred_result) 126 | 127 | # instance segmentation inference 128 | instance_r = retry_if_cuda_oom(self.instance_inference)(mask_cls_result, mask_pred_result) 129 | processed_results[-1]["instances"] = instance_r 130 | 131 | return processed_results 132 | 133 | def prepare_targets(self, targets, images): 134 | h_pad, w_pad = images.tensor.shape[-2:] 135 | new_targets = [] 136 | for targets_per_image in targets: 137 | # pad gt 138 | gt_masks = targets_per_image.gt_masks 139 | padded_masks = torch.zeros((gt_masks.shape[0], h_pad, w_pad), dtype=gt_masks.dtype, device=gt_masks.device) 140 | padded_masks[:, : gt_masks.shape[1], : gt_masks.shape[2]] = gt_masks 141 | new_targets.append( 142 | { 143 | "labels": targets_per_image.gt_classes, 144 | "masks": padded_masks, 145 | } 146 | ) 147 | return new_targets 148 | 149 | def instance_inference(self, mask_cls, mask_pred): 150 | # mask_pred is already processed to have the same shape as original input 151 | image_size = mask_pred.shape[-2:] 152 | 153 | # [Q, K] 154 | scores = F.softmax(mask_cls, dim=-1)[:, :-1] 155 | num_classes = 1 156 | labels = torch.arange(num_classes, device=self.device).unsqueeze(0).repeat(self.num_queries, 1).flatten(0, 1) 157 | scores_per_image, topk_indices = scores.flatten(0, 1).topk(self.test_topk_per_image, sorted=False) 158 | labels_per_image = labels[topk_indices] 159 | 160 | topk_indices = topk_indices // num_classes 161 | mask_pred = mask_pred[topk_indices] 162 | 163 | result = Instances(image_size) 164 | # mask (before sigmoid) 165 | result.pred_masks = (mask_pred > 0).float() 166 | result.pred_boxes = Boxes(torch.zeros(mask_pred.size(0), 4)) 167 | # Uncomment the following to get boxes from masks (this is slow) 168 | # result.pred_boxes = BitMasks(mask_pred > 0).get_bounding_boxes() 169 | 170 | # calculate average mask prob 171 | mask_scores_per_image = (mask_pred.sigmoid().flatten(1) * result.pred_masks.flatten(1)).sum(1) / (result.pred_masks.flatten(1).sum(1) + 1e-6) 172 | result.scores = scores_per_image * mask_scores_per_image 173 | result.pred_classes = labels_per_image 174 | return result 175 | -------------------------------------------------------------------------------- /dcnet/modeling/ICS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTCL/DCNet/f3c9098d1e0696cae8a5cfe59f952487d089ac4c/dcnet/modeling/ICS/__init__.py -------------------------------------------------------------------------------- /dcnet/modeling/ICS/ics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instance-level Camouflage Suppression Module 3 | """ 4 | 5 | import fvcore.nn.weight_init as weight_init 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | from detectron2.layers import Conv2d 11 | 12 | from .position_encoding import PositionEmbeddingSine 13 | from .reference_attention import ReferenceTransformer 14 | 15 | class ics(nn.Module): 16 | def __init__( 17 | self, 18 | cfg, 19 | in_channels=256, 20 | mask_classification=True, 21 | num_classes: int = 1, 22 | hidden_dim: int = 256, 23 | nheads: int = 8, 24 | dropout: float = 0.1, 25 | dim_feedforward: int = 2048, 26 | pre_norm: bool = False, 27 | mask_dim: int = 256, 28 | enforce_input_project: bool = False, 29 | ): 30 | 31 | super().__init__() 32 | 33 | 
self.mask_classification = mask_classification 34 | deep_supervision = cfg.MODEL.DCNET.DEEP_SUPERVISION 35 | 36 | # positional encoding 37 | N_steps = hidden_dim // 2 38 | self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) 39 | 40 | transformer = ReferenceTransformer( 41 | d_model=hidden_dim, 42 | dropout=dropout, 43 | nhead=nheads, 44 | dim_feedforward=dim_feedforward, 45 | num_encoder_layers=0, 46 | num_decoder_layers=cfg.MODEL.DCNET.DEC_LAYERS, 47 | normalize_before=pre_norm, 48 | return_intermediate_dec=deep_supervision, 49 | ) 50 | 51 | self.num_queries = cfg.MODEL.DCNET.NUM_OBJECT_QUERIES 52 | self.transformer = transformer 53 | hidden_dim = transformer.d_model 54 | 55 | self.query_embed = nn.Embedding(self.num_queries, hidden_dim) 56 | 57 | if in_channels != hidden_dim or enforce_input_project: 58 | self.input_proj = Conv2d(in_channels, hidden_dim, kernel_size=1) 59 | weight_init.c2_xavier_fill(self.input_proj) 60 | else: 61 | self.input_proj = nn.Sequential() 62 | self.aux_loss = deep_supervision 63 | 64 | # output FFNs 65 | if self.mask_classification: 66 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1) 67 | self.mask_embed = MLP(hidden_dim, hidden_dim, mask_dim, 3) 68 | 69 | def forward(self, x, mask_features, mask=None): 70 | if mask is not None: 71 | mask = F.interpolate(mask[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 72 | pos = self.pe_layer(x, mask) 73 | 74 | src = x 75 | hs, memory = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos) 76 | 77 | if self.mask_classification: 78 | outputs_class = self.class_embed(hs) 79 | out = {"pred_logits": outputs_class[-1]} 80 | else: 81 | out = {} 82 | 83 | if self.aux_loss: 84 | # [l, bs, queries, embed] 85 | mask_embed = self.mask_embed(hs) 86 | outputs_seg_masks = torch.einsum("lbqc,bchw->lbqhw", mask_embed, mask_features) 87 | out["pred_masks"] = outputs_seg_masks[-1] 88 | out["aux_outputs"] = self._set_aux_loss( 89 | outputs_class if self.mask_classification else None, outputs_seg_masks 90 | ) 91 | else: 92 | # FIXME h_boxes takes the last one computed, keep this in mind 93 | # [bs, queries, embed] 94 | mask_embed = self.mask_embed(hs[-1]) 95 | outputs_seg_masks = torch.einsum("bqc,bchw->bqhw", mask_embed, mask_features) 96 | out["pred_masks"] = outputs_seg_masks 97 | return out 98 | 99 | @torch.jit.unused 100 | def _set_aux_loss(self, outputs_class, outputs_seg_masks): 101 | # this is a workaround to make torchscript happy, as torchscript 102 | # doesn't support dictionary with non-homogeneous values, such 103 | # as a dict having both a Tensor and a list. 
104 | if self.mask_classification: 105 | return [ 106 | {"pred_logits": a, "pred_masks": b} 107 | for a, b in zip(outputs_class[:-1], outputs_seg_masks[:-1]) 108 | ] 109 | else: 110 | return [{"pred_masks": b} for b in outputs_seg_masks[:-1]] 111 | 112 | class MLP(nn.Module): 113 | """Very simple multi-layer perceptron (also called FFN)""" 114 | 115 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 116 | super().__init__() 117 | self.num_layers = num_layers 118 | h = [hidden_dim] * (num_layers - 1) 119 | self.layers = nn.ModuleList( 120 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 121 | ) 122 | 123 | def forward(self, x): 124 | for i, layer in enumerate(self.layers): 125 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 126 | return x -------------------------------------------------------------------------------- /dcnet/modeling/ICS/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/detr/blob/master/models/position_encoding.py 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | class PositionEmbeddingSine(nn.Module): 12 | """ 13 | This is a more standard version of the position embedding, very similar to the one 14 | used by the Attention is all you need paper, generalized to work on images. 15 | """ 16 | 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, x, mask=None): 29 | if mask is None: 30 | mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) 31 | not_mask = ~mask 32 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 33 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 34 | if self.normalize: 35 | eps = 1e-6 36 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 37 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 38 | 39 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 40 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 41 | 42 | pos_x = x_embed[:, :, :, None] / dim_t 43 | pos_y = y_embed[:, :, :, None] / dim_t 44 | pos_x = torch.stack( 45 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 46 | ).flatten(3) 47 | pos_y = torch.stack( 48 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 49 | ).flatten(3) 50 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 51 | return pos 52 | 53 | def __repr__(self, _repr_indent=4): 54 | head = "Positional encoding " + self.__class__.__name__ 55 | body = [ 56 | "num_pos_feats: {}".format(self.num_pos_feats), 57 | "temperature: {}".format(self.temperature), 58 | "normalize: {}".format(self.normalize), 59 | "scale: {}".format(self.scale), 60 | ] 61 | # _repr_indent = 4 62 | lines = [head] + [" " * _repr_indent + line for line in body] 63 | return "\n".join(lines) 64 | -------------------------------------------------------------------------------- /dcnet/modeling/ICS/reference_attention.py: 
-------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Optional, Tuple, List 3 | import copy 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch import nn 8 | from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ 9 | from torch.nn.parameter import Parameter 10 | from torch.nn.modules import Module 11 | from torch.nn import functional as F 12 | import math 13 | 14 | from torch.overrides import ( 15 | has_torch_function, handle_torch_function) 16 | 17 | 18 | class Reference(nn.Module): 19 | def __init__(self, topk=64, match_dim=32, dropout=0.1): 20 | super(Reference, self).__init__() 21 | 22 | self.dropout = nn.Dropout(dropout) 23 | self.match_dim = match_dim # HW->D :D: (should be << HW) 24 | 25 | self.corr_proj_q = nn.Linear(topk, match_dim) 26 | self.corr_proj_k = nn.Linear(topk, match_dim) 27 | self.LN_k = nn.LayerNorm(match_dim) 28 | self.LN_q = nn.LayerNorm(match_dim) 29 | 30 | self.proj_high_q = nn.Linear(match_dim, match_dim) 31 | self.proj_high_k = nn.Linear(match_dim, match_dim) 32 | 33 | self._reset_parameters() 34 | 35 | def _reset_parameters(self): 36 | xavier_uniform_(self.corr_proj_q.weight) 37 | xavier_uniform_(self.corr_proj_k.weight) 38 | xavier_uniform_(self.proj_high_q.weight) 39 | xavier_uniform_(self.proj_high_k.weight) 40 | 41 | # if self.out_proj.bias is not None: 42 | constant_(self.corr_proj_q.bias, 0.) 43 | constant_(self.corr_proj_k.bias, 0.) 44 | constant_(self.proj_high_q.bias, 0.) 45 | constant_(self.proj_high_k.bias, 0.) 46 | 47 | def forward(self, att_q2k: Tensor, key: Tensor, key2: Tensor, topk=64): 48 | # att_q2k : [bs, Nq, Nk] 49 | # key, key2 : [B, Nk, E] 50 | 51 | bs, Nq, Nk = att_q2k.shape 52 | _, _, E = key.shape 53 | 54 | att_confidence = att_q2k.detach().sum(dim=1) 55 | _, topk_index = torch.topk(att_confidence, k=topk, dim=1) # [B, k] 56 | 57 | index_att = topk_index.unsqueeze(1).repeat(1, Nq, 1) 58 | index_key = topk_index.unsqueeze(2).repeat(1, 1, E) 59 | 60 | att_q2k_topk = torch.gather(att_q2k, dim=2, index=index_att) # [B, Nq, k] 61 | key_topk = torch.gather(key, dim=1, index=index_key) # [B, k, E] 62 | 63 | att_k2k_topk = torch.bmm(key2 / math.sqrt(E), key_topk.transpose(-2, -1)) # [B, Nk, k] 64 | 65 | q = self.corr_proj_q(att_q2k_topk) # [bs, Nq, match_dim] 66 | q = F.relu(q) 67 | 68 | k = self.corr_proj_k(att_k2k_topk) # [bs, Nk, match_dim] 69 | k = F.relu(k) 70 | 71 | q = self.LN_q(q) 72 | k = self.LN_k(k) 73 | q = self.proj_high_q(q) 74 | k = self.proj_high_k(k) 75 | 76 | att_q2k_refer = torch.bmm(q / math.sqrt(self.match_dim), k.transpose(-2,-1)) # [bs*nH, Nq, Nk] 77 | return att_q2k_refer 78 | 79 | 80 | class ReferenceAttention(Module): 81 | 82 | __constants__ = ['batch_first'] 83 | bias_k: Optional[torch.Tensor] 84 | bias_v: Optional[torch.Tensor] 85 | 86 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, 87 | kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None: 88 | factory_kwargs = {'device': device, 'dtype': dtype} 89 | super(ReferenceAttention, self).__init__() 90 | self.embed_dim = embed_dim 91 | self.kdim = kdim if kdim is not None else embed_dim 92 | self.vdim = vdim if vdim is not None else embed_dim 93 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 94 | 95 | self.num_heads = num_heads 96 | self.dropout = dropout 97 | self.batch_first = batch_first 98 | self.head_dim = embed_dim // num_heads 99 | assert self.head_dim * 
num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 100 | 101 | if self._qkv_same_embed_dim is False: 102 | self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs)) 103 | self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs)) 104 | self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs)) 105 | self.register_parameter('in_proj_weight', None) 106 | else: 107 | self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)) 108 | self.register_parameter('q_proj_weight', None) 109 | self.register_parameter('k_proj_weight', None) 110 | self.register_parameter('v_proj_weight', None) 111 | 112 | if bias: 113 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs)) 114 | else: 115 | self.register_parameter('in_proj_bias', None) 116 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) 117 | 118 | if add_bias_kv: 119 | self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) 120 | self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) 121 | else: 122 | self.bias_k = self.bias_v = None 123 | 124 | self.add_zero_attn = add_zero_attn 125 | 126 | 127 | self.k_proj_weight2 = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs)) 128 | self.b_k2 = Parameter(torch.empty((embed_dim), **factory_kwargs)) 129 | 130 | self._reset_parameters() 131 | 132 | def _reset_parameters(self): 133 | if self._qkv_same_embed_dim: 134 | xavier_uniform_(self.in_proj_weight) 135 | else: 136 | xavier_uniform_(self.q_proj_weight) 137 | xavier_uniform_(self.k_proj_weight) 138 | xavier_uniform_(self.v_proj_weight) 139 | 140 | if self.in_proj_bias is not None: 141 | constant_(self.in_proj_bias, 0.) 142 | constant_(self.out_proj.bias, 0.) 
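        # note: the extra key projection (k_proj_weight2, b_k2) defined above is created
        # with torch.empty and is not re-initialized here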
143 | if self.bias_k is not None: 144 | xavier_normal_(self.bias_k) 145 | if self.bias_v is not None: 146 | xavier_normal_(self.bias_v) 147 | 148 | def __setstate__(self, state): 149 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 150 | if '_qkv_same_embed_dim' not in state: 151 | state['_qkv_same_embed_dim'] = True 152 | 153 | super(ReferenceAttention, self).__setstate__(state) 154 | 155 | def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, 156 | need_weights: bool = True, attn_mask: Optional[Tensor] = None, reference=None) -> Tuple[Tensor, Optional[Tensor]]: 157 | 158 | if self.batch_first: 159 | query, key, value = [x.transpose(1, 0) for x in (query, key, value)] 160 | 161 | if not self._qkv_same_embed_dim: 162 | attn_output, attn_output_weights = multi_head_reference_attention_forward( 163 | query, key, value, self.embed_dim, self.num_heads, 164 | self.in_proj_weight, self.in_proj_bias, 165 | self.bias_k, self.bias_v, self.add_zero_attn, 166 | self.dropout, self.out_proj.weight, self.out_proj.bias, 167 | training=self.training, 168 | key_padding_mask=key_padding_mask, need_weights=need_weights, 169 | attn_mask=attn_mask, use_separate_proj_weight=True, 170 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 171 | v_proj_weight=self.v_proj_weight, 172 | reference=reference, k_proj_weight2=self.k_proj_weight2, b_k2=self.b_k2) 173 | else: 174 | attn_output, attn_output_weights = multi_head_reference_attention_forward( 175 | query, key, value, self.embed_dim, self.num_heads, 176 | self.in_proj_weight, self.in_proj_bias, 177 | self.bias_k, self.bias_v, self.add_zero_attn, 178 | self.dropout, self.out_proj.weight, self.out_proj.bias, 179 | training=self.training, 180 | key_padding_mask=key_padding_mask, need_weights=need_weights, 181 | attn_mask=attn_mask, 182 | reference=reference, k_proj_weight2=self.k_proj_weight2, b_k2=self.b_k2) 183 | if self.batch_first: 184 | return attn_output.transpose(1, 0), attn_output_weights 185 | else: 186 | return attn_output, attn_output_weights 187 | 188 | class ReferenceTransformer(nn.Module): 189 | def __init__( 190 | self, 191 | d_model=512, 192 | nhead=8, 193 | num_encoder_layers=6, 194 | num_decoder_layers=6, 195 | dim_feedforward=2048, 196 | dropout=0.1, 197 | activation="relu", 198 | normalize_before=False, 199 | return_intermediate_dec=False, 200 | ): 201 | super().__init__() 202 | 203 | reference = Reference() 204 | 205 | encoder_layer = TransformerEncoderLayer( 206 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before 207 | ) 208 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 209 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 210 | 211 | decoder_layer = TransformerDecoderLayer( 212 | d_model, nhead, dim_feedforward, dropout, activation, normalize_before, reference 213 | ) 214 | decoder_norm = nn.LayerNorm(d_model) 215 | self.decoder = TransformerDecoder( 216 | decoder_layer, 217 | num_decoder_layers, 218 | decoder_norm, 219 | return_intermediate=return_intermediate_dec, 220 | ) 221 | 222 | self._reset_parameters() 223 | 224 | self.d_model = d_model 225 | self.nhead = nhead 226 | 227 | def _reset_parameters(self): 228 | for p in self.parameters(): 229 | if p.dim() > 1: 230 | nn.init.xavier_uniform_(p) 231 | 232 | def forward(self, src, mask, query_embed, pos_embed): 233 | # flatten NxCxHxW to HWxNxC 234 | bs, c, h, w = src.shape 235 | src = src.flatten(2).permute(2, 0, 1) 236 | 
pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 237 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 238 | if mask is not None: 239 | mask = mask.flatten(1) 240 | 241 | tgt = torch.zeros_like(query_embed) 242 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 243 | hs = self.decoder( 244 | tgt, memory, memory_key_padding_mask=mask, pos=pos_embed, query_pos=query_embed 245 | ) 246 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 247 | 248 | 249 | def multi_head_reference_attention_forward( 250 | query: Tensor, 251 | key: Tensor, 252 | value: Tensor, 253 | embed_dim_to_check: int, 254 | num_heads: int, 255 | in_proj_weight: Tensor, 256 | in_proj_bias: Optional[Tensor], 257 | bias_k: Optional[Tensor], 258 | bias_v: Optional[Tensor], 259 | add_zero_attn: bool, 260 | dropout_p: float, 261 | out_proj_weight: Tensor, 262 | out_proj_bias: Optional[Tensor], 263 | training: bool = True, 264 | key_padding_mask: Optional[Tensor] = None, 265 | need_weights: bool = True, 266 | attn_mask: Optional[Tensor] = None, 267 | use_separate_proj_weight: bool = False, 268 | q_proj_weight: Optional[Tensor] = None, 269 | k_proj_weight: Optional[Tensor] = None, 270 | v_proj_weight: Optional[Tensor] = None, 271 | static_k: Optional[Tensor] = None, 272 | static_v: Optional[Tensor] = None, 273 | reference = None, 274 | k_proj_weight2: Optional[Tensor] = None, 275 | b_k2: Optional[Tensor] = None, 276 | ) -> Tuple[Tensor, Optional[Tensor]]: 277 | 278 | tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) 279 | if has_torch_function(tens_ops): 280 | return handle_torch_function( 281 | multi_head_reference_attention_forward, 282 | tens_ops, 283 | query, 284 | key, 285 | value, 286 | embed_dim_to_check, 287 | num_heads, 288 | in_proj_weight, 289 | in_proj_bias, 290 | bias_k, 291 | bias_v, 292 | add_zero_attn, 293 | dropout_p, 294 | out_proj_weight, 295 | out_proj_bias, 296 | training=training, 297 | key_padding_mask=key_padding_mask, 298 | need_weights=need_weights, 299 | attn_mask=attn_mask, 300 | use_separate_proj_weight=use_separate_proj_weight, 301 | q_proj_weight=q_proj_weight, 302 | k_proj_weight=k_proj_weight, 303 | v_proj_weight=v_proj_weight, 304 | static_k=static_k, 305 | static_v=static_v, 306 | reference=reference, 307 | k_proj_weight2=k_proj_weight2, 308 | b_k2=b_k2, 309 | ) 310 | 311 | # set up shape vars 312 | tgt_len, bsz, embed_dim = query.shape 313 | src_len, _, _ = key.shape 314 | assert embed_dim == embed_dim_to_check, \ 315 | f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" 316 | if isinstance(embed_dim, torch.Tensor): 317 | # embed_dim can be a tensor when JIT tracing 318 | head_dim = embed_dim.div(num_heads, rounding_mode='trunc') 319 | else: 320 | head_dim = embed_dim // num_heads 321 | assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" 322 | if use_separate_proj_weight: 323 | # allow MHA to have different embedding dimensions when separate projection weights are used 324 | assert key.shape[:2] == value.shape[:2], \ 325 | f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" 326 | else: 327 | assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" 328 | 329 | # 330 | # compute in-projection 331 | # 332 | if not use_separate_proj_weight: 333 | q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) 334 | 
else: 335 | assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" 336 | assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" 337 | assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" 338 | if in_proj_bias is None: 339 | b_q = b_k = b_v = None 340 | else: 341 | b_q, b_k, b_v = in_proj_bias.chunk(3) 342 | q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) 343 | # apply an extra linear projection to the key to obtain k2 344 | k2 = F.linear(key, k_proj_weight2, b_k2) 345 | 346 | # prep attention mask 347 | if attn_mask is not None: 348 | if attn_mask.dtype == torch.uint8: 349 | warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 350 | attn_mask = attn_mask.to(torch.bool) 351 | else: 352 | assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ 353 | f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" 354 | # ensure attn_mask's dim is 3 355 | if attn_mask.dim() == 2: 356 | correct_2d_size = (tgt_len, src_len) 357 | if attn_mask.shape != correct_2d_size: 358 | raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") 359 | attn_mask = attn_mask.unsqueeze(0) 360 | elif attn_mask.dim() == 3: 361 | correct_3d_size = (bsz * num_heads, tgt_len, src_len) 362 | if attn_mask.shape != correct_3d_size: 363 | raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") 364 | else: 365 | raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") 366 | 367 | # prep key padding mask 368 | if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: 369 | warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 370 | key_padding_mask = key_padding_mask.to(torch.bool) 371 | 372 | # add bias along batch dimension (currently second) 373 | if bias_k is not None and bias_v is not None: 374 | assert static_k is None, "bias cannot be added to static key." 375 | assert static_v is None, "bias cannot be added to static value."
376 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 377 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 378 | if attn_mask is not None: 379 | attn_mask = F.pad(attn_mask, (0, 1)) 380 | if key_padding_mask is not None: 381 | key_padding_mask = F.pad(key_padding_mask, (0, 1)) 382 | else: 383 | assert bias_k is None 384 | assert bias_v is None 385 | 386 | # 387 | # reshape q, k, v for multihead attention and make em batch first 388 | # 389 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 390 | if static_k is None: 391 | k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) 392 | k2 = k2.contiguous().view(k2.shape[0], bsz * num_heads, head_dim).transpose(0, 1) 393 | else: 394 | # TODO finish disentangling control flow so we don't do in-projections when statics are passed 395 | assert static_k.size(0) == bsz * num_heads, \ 396 | f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" 397 | assert static_k.size(2) == head_dim, \ 398 | f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" 399 | k = static_k 400 | if static_v is None: 401 | v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) 402 | else: 403 | # TODO finish disentangling control flow so we don't do in-projections when statics are passed 404 | assert static_v.size(0) == bsz * num_heads, \ 405 | f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" 406 | assert static_v.size(2) == head_dim, \ 407 | f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" 408 | v = static_v 409 | 410 | # add zero attention along batch dimension (now first) 411 | if add_zero_attn: 412 | zero_attn_shape = (bsz * num_heads, 1, head_dim) 413 | k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) 414 | v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) 415 | if attn_mask is not None: 416 | attn_mask = F.pad(attn_mask, (0, 1)) 417 | if key_padding_mask is not None: 418 | key_padding_mask = F.pad(key_padding_mask, (0, 1)) 419 | 420 | # update source sequence length after adjustments 421 | src_len = k.size(1) 422 | 423 | # merge key padding and attention masks 424 | if key_padding_mask is not None: 425 | assert key_padding_mask.shape == (bsz, src_len), \ 426 | f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" 427 | key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). 
\ 428 | expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) 429 | if attn_mask is None: 430 | attn_mask = key_padding_mask 431 | elif attn_mask.dtype == torch.bool: 432 | attn_mask = attn_mask.logical_or(key_padding_mask) 433 | else: 434 | attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) 435 | 436 | # convert mask to float 437 | if attn_mask is not None and attn_mask.dtype == torch.bool: 438 | new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float) 439 | new_attn_mask.masked_fill_(attn_mask, float("-inf")) 440 | attn_mask = new_attn_mask 441 | 442 | # adjust dropout probability 443 | if not training: 444 | dropout_p = 0.0 445 | 446 | # 447 | # (deep breath) calculate attention and out projection 448 | # 449 | attn_output, attn_output_weights = _scaled_dot_product_attention_with_reference(q, k, v, k2, attn_mask, dropout_p, reference=reference) 450 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 451 | attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) 452 | 453 | if need_weights: 454 | # average attention weights over heads 455 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 456 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 457 | else: 458 | return attn_output, None 459 | 460 | def _in_projection_packed( 461 | q: Tensor, 462 | k: Tensor, 463 | v: Tensor, 464 | w: Tensor, 465 | b: Optional[Tensor] = None, 466 | ) -> List[Tensor]: 467 | 468 | E = q.size(-1) 469 | if k is v: 470 | if q is k: 471 | # self-attention 472 | return F.linear(q, w, b).chunk(3, dim=-1) 473 | else: 474 | # encoder-decoder attention 475 | w_q, w_kv = w.split([E, E * 2]) 476 | if b is None: 477 | b_q = b_kv = None 478 | else: 479 | b_q, b_kv = b.split([E, E * 2]) 480 | return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1) 481 | else: 482 | w_q, w_k, w_v = w.chunk(3) 483 | if b is None: 484 | b_q = b_k = b_v = None 485 | else: 486 | b_q, b_k, b_v = b.chunk(3) 487 | return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) 488 | 489 | 490 | def _in_projection( 491 | q: Tensor, 492 | k: Tensor, 493 | v: Tensor, 494 | w_q: Tensor, 495 | w_k: Tensor, 496 | w_v: Tensor, 497 | b_q: Optional[Tensor] = None, 498 | b_k: Optional[Tensor] = None, 499 | b_v: Optional[Tensor] = None, 500 | ) -> Tuple[Tensor, Tensor, Tensor]: 501 | 502 | Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) 503 | assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" 504 | assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" 505 | assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" 506 | assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" 507 | assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" 508 | assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" 509 | return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) 510 | 511 | 512 | def _scaled_dot_product_attention_with_reference( 513 | q: Tensor, 514 | k: Tensor, 515 | v: Tensor, 516 | k2: Tensor, 517 | attn_mask: Optional[Tensor] = None, 518 | dropout_p: float = 0.0, 519 | reference = None, 520 | ) -> Tuple[Tensor, Tensor]: 521 | 522 | B, Nt, E = q.shape 523 | q = q / math.sqrt(E) 524 | # (B, Nt, E) x (B, E, Ns) -> (B, Nt, 
Ns) 525 | attn = torch.bmm(q, k.transpose(-2, -1)) 526 | 527 | # process attention mask by reference attention 528 | if reference is not None: 529 | attn = reference(attn, k, k2) 530 | 531 | if attn_mask is not None: 532 | attn += attn_mask 533 | 534 | attn = F.softmax(attn, dim=-1) 535 | if dropout_p > 0.0: 536 | attn = F.dropout(attn, p=dropout_p) 537 | # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) 538 | output = torch.bmm(attn, v) 539 | return output, attn 540 | 541 | class TransformerEncoder(nn.Module): 542 | def __init__(self, encoder_layer, num_layers, norm=None): 543 | super().__init__() 544 | self.layers = _get_clones(encoder_layer, num_layers) 545 | self.num_layers = num_layers 546 | self.norm = norm 547 | 548 | def forward( 549 | self, 550 | src, 551 | mask: Optional[Tensor] = None, 552 | src_key_padding_mask: Optional[Tensor] = None, 553 | pos: Optional[Tensor] = None, 554 | ): 555 | output = src 556 | 557 | for layer in self.layers: 558 | output = layer( 559 | output, src_mask=mask, src_key_padding_mask=src_key_padding_mask, pos=pos 560 | ) 561 | 562 | if self.norm is not None: 563 | output = self.norm(output) 564 | 565 | return output 566 | 567 | 568 | class TransformerDecoder(nn.Module): 569 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 570 | super().__init__() 571 | self.layers = _get_clones(decoder_layer, num_layers) 572 | self.num_layers = num_layers 573 | self.norm = norm 574 | self.return_intermediate = return_intermediate 575 | 576 | def forward( 577 | self, 578 | tgt, 579 | memory, 580 | tgt_mask: Optional[Tensor] = None, 581 | memory_mask: Optional[Tensor] = None, 582 | tgt_key_padding_mask: Optional[Tensor] = None, 583 | memory_key_padding_mask: Optional[Tensor] = None, 584 | pos: Optional[Tensor] = None, 585 | query_pos: Optional[Tensor] = None, 586 | ): 587 | output = tgt 588 | 589 | intermediate = [] 590 | 591 | for layer in self.layers: 592 | output = layer( 593 | output, 594 | memory, 595 | tgt_mask=tgt_mask, 596 | memory_mask=memory_mask, 597 | tgt_key_padding_mask=tgt_key_padding_mask, 598 | memory_key_padding_mask=memory_key_padding_mask, 599 | pos=pos, 600 | query_pos=query_pos, 601 | ) 602 | if self.return_intermediate: 603 | intermediate.append(self.norm(output)) 604 | 605 | if self.norm is not None: 606 | output = self.norm(output) 607 | if self.return_intermediate: 608 | intermediate.pop() 609 | intermediate.append(output) 610 | 611 | if self.return_intermediate: 612 | return torch.stack(intermediate) 613 | 614 | return output.unsqueeze(0) 615 | 616 | 617 | class TransformerEncoderLayer(nn.Module): 618 | def __init__( 619 | self, 620 | d_model, 621 | nhead, 622 | dim_feedforward=2048, 623 | dropout=0.1, 624 | activation="relu", 625 | normalize_before=False, 626 | ): 627 | super().__init__() 628 | self.self_attn = ReferenceAttention(d_model, nhead, dropout=dropout) 629 | # Implementation of Feedforward model 630 | self.linear1 = nn.Linear(d_model, dim_feedforward) 631 | self.dropout = nn.Dropout(dropout) 632 | self.linear2 = nn.Linear(dim_feedforward, d_model) 633 | 634 | self.norm1 = nn.LayerNorm(d_model) 635 | self.norm2 = nn.LayerNorm(d_model) 636 | self.dropout1 = nn.Dropout(dropout) 637 | self.dropout2 = nn.Dropout(dropout) 638 | 639 | self.activation = _get_activation_fn(activation) 640 | self.normalize_before = normalize_before 641 | 642 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 643 | return tensor if pos is None else tensor + pos 644 | 645 | def forward_post( 646 | self, 647 | src, 648 | 
src_mask: Optional[Tensor] = None, 649 | src_key_padding_mask: Optional[Tensor] = None, 650 | pos: Optional[Tensor] = None, 651 | ): 652 | q = k = self.with_pos_embed(src, pos) 653 | src2 = self.self_attn( 654 | q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 655 | )[0] 656 | src = src + self.dropout1(src2) 657 | src = self.norm1(src) 658 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 659 | src = src + self.dropout2(src2) 660 | src = self.norm2(src) 661 | return src 662 | 663 | def forward_pre( 664 | self, 665 | src, 666 | src_mask: Optional[Tensor] = None, 667 | src_key_padding_mask: Optional[Tensor] = None, 668 | pos: Optional[Tensor] = None, 669 | ): 670 | src2 = self.norm1(src) 671 | q = k = self.with_pos_embed(src2, pos) 672 | src2 = self.self_attn( 673 | q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask 674 | )[0] 675 | src = src + self.dropout1(src2) 676 | src2 = self.norm2(src) 677 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 678 | src = src + self.dropout2(src2) 679 | return src 680 | 681 | def forward( 682 | self, 683 | src, 684 | src_mask: Optional[Tensor] = None, 685 | src_key_padding_mask: Optional[Tensor] = None, 686 | pos: Optional[Tensor] = None, 687 | ): 688 | if self.normalize_before: 689 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 690 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 691 | 692 | 693 | class TransformerDecoderLayer(nn.Module): 694 | def __init__( 695 | self, 696 | d_model, 697 | nhead, 698 | dim_feedforward=2048, 699 | dropout=0.1, 700 | activation="relu", 701 | normalize_before=False, 702 | reference=None, 703 | ): 704 | super().__init__() 705 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 706 | self.multihead_attn = ReferenceAttention(d_model, nhead, dropout=dropout) 707 | # Implementation of Feedforward model 708 | self.linear1 = nn.Linear(d_model, dim_feedforward) 709 | self.dropout = nn.Dropout(dropout) 710 | self.linear2 = nn.Linear(dim_feedforward, d_model) 711 | 712 | self.norm1 = nn.LayerNorm(d_model) 713 | self.norm2 = nn.LayerNorm(d_model) 714 | self.norm3 = nn.LayerNorm(d_model) 715 | self.dropout1 = nn.Dropout(dropout) 716 | self.dropout2 = nn.Dropout(dropout) 717 | self.dropout3 = nn.Dropout(dropout) 718 | 719 | self.activation = _get_activation_fn(activation) 720 | self.normalize_before = normalize_before 721 | 722 | self.reference = reference 723 | 724 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 725 | return tensor if pos is None else tensor + pos 726 | 727 | def forward_post( 728 | self, 729 | tgt, 730 | memory, 731 | tgt_mask: Optional[Tensor] = None, 732 | memory_mask: Optional[Tensor] = None, 733 | tgt_key_padding_mask: Optional[Tensor] = None, 734 | memory_key_padding_mask: Optional[Tensor] = None, 735 | pos: Optional[Tensor] = None, 736 | query_pos: Optional[Tensor] = None, 737 | ): 738 | q = k = self.with_pos_embed(tgt, query_pos) 739 | tgt2 = self.self_attn( 740 | q, k, value=tgt, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 741 | )[0] 742 | tgt = tgt + self.dropout1(tgt2) 743 | tgt = self.norm1(tgt) 744 | tgt2 = self.multihead_attn( 745 | query=self.with_pos_embed(tgt, query_pos), 746 | key=self.with_pos_embed(memory, pos), 747 | value=memory, 748 | attn_mask=memory_mask, 749 | key_padding_mask=memory_key_padding_mask, 750 | reference=self.reference, 751 | )[0] 752 | tgt = tgt + self.dropout2(tgt2) 753 | tgt = self.norm2(tgt) 754 
| tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 755 | tgt = tgt + self.dropout3(tgt2) 756 | tgt = self.norm3(tgt) 757 | return tgt 758 | 759 | def forward_pre( 760 | self, 761 | tgt, 762 | memory, 763 | tgt_mask: Optional[Tensor] = None, 764 | memory_mask: Optional[Tensor] = None, 765 | tgt_key_padding_mask: Optional[Tensor] = None, 766 | memory_key_padding_mask: Optional[Tensor] = None, 767 | pos: Optional[Tensor] = None, 768 | query_pos: Optional[Tensor] = None, 769 | ): 770 | tgt2 = self.norm1(tgt) 771 | q = k = self.with_pos_embed(tgt2, query_pos) 772 | tgt2 = self.self_attn( 773 | q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask 774 | )[0] 775 | tgt = tgt + self.dropout1(tgt2) 776 | tgt2 = self.norm2(tgt) 777 | tgt2 = self.multihead_attn( 778 | query=self.with_pos_embed(tgt2, query_pos), 779 | key=self.with_pos_embed(memory, pos), 780 | value=memory, 781 | attn_mask=memory_mask, 782 | key_padding_mask=memory_key_padding_mask, 783 | reference=self.reference 784 | )[0] 785 | tgt = tgt + self.dropout2(tgt2) 786 | tgt2 = self.norm3(tgt) 787 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 788 | tgt = tgt + self.dropout3(tgt2) 789 | return tgt 790 | 791 | def forward( 792 | self, 793 | tgt, 794 | memory, 795 | tgt_mask: Optional[Tensor] = None, 796 | memory_mask: Optional[Tensor] = None, 797 | tgt_key_padding_mask: Optional[Tensor] = None, 798 | memory_key_padding_mask: Optional[Tensor] = None, 799 | pos: Optional[Tensor] = None, 800 | query_pos: Optional[Tensor] = None, 801 | ): 802 | if self.normalize_before: 803 | return self.forward_pre( 804 | tgt, 805 | memory, 806 | tgt_mask, 807 | memory_mask, 808 | tgt_key_padding_mask, 809 | memory_key_padding_mask, 810 | pos, 811 | query_pos, 812 | ) 813 | return self.forward_post( 814 | tgt, 815 | memory, 816 | tgt_mask, 817 | memory_mask, 818 | tgt_key_padding_mask, 819 | memory_key_padding_mask, 820 | pos, 821 | query_pos, 822 | ) 823 | 824 | 825 | def _get_clones(module, N): 826 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 827 | 828 | 829 | def _get_activation_fn(activation): 830 | """Return an activation function given a string""" 831 | if activation == "relu": 832 | return F.relu 833 | if activation == "gelu": 834 | return F.gelu 835 | if activation == "glu": 836 | return F.glu 837 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 838 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTCL/DCNet/f3c9098d1e0696cae8a5cfe59f952487d089ac4c/dcnet/modeling/PCD/__init__.py -------------------------------------------------------------------------------- /dcnet/modeling/PCD/difference_attention.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Optional, Tuple, List 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch import nn 7 | from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ 8 | from torch.nn.parameter import Parameter 9 | from torch.nn.modules import Module 10 | from torch.nn import functional as F 11 | import math 12 | 13 | from torch.overrides import ( 14 | has_torch_function, handle_torch_function) 15 | 16 | 17 | class DiffAttention(Module): 18 | 19 | __constants__ = ['batch_first'] 20 | bias_k: Optional[torch.Tensor] 21 | bias_v: 
Optional[torch.Tensor] 22 | 23 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, 24 | kdim=None, vdim=None, batch_first=False, device=None, dtype=None) -> None: 25 | factory_kwargs = {'device': device, 'dtype': dtype} 26 | super(DiffAttention, self).__init__() 27 | self.embed_dim = embed_dim 28 | self.kdim = kdim if kdim is not None else embed_dim 29 | self.vdim = vdim if vdim is not None else embed_dim 30 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 31 | 32 | self.num_heads = num_heads 33 | self.dropout = dropout 34 | self.batch_first = batch_first 35 | self.head_dim = embed_dim // num_heads 36 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 37 | 38 | if self._qkv_same_embed_dim is False: 39 | self.q_proj_weight = Parameter(torch.empty((embed_dim, embed_dim), **factory_kwargs)) 40 | self.k_proj_weight = Parameter(torch.empty((embed_dim, self.kdim), **factory_kwargs)) 41 | self.v_proj_weight = Parameter(torch.empty((embed_dim, self.vdim), **factory_kwargs)) 42 | self.register_parameter('in_proj_weight', None) 43 | else: 44 | self.in_proj_weight = Parameter(torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)) 45 | self.register_parameter('q_proj_weight', None) 46 | self.register_parameter('k_proj_weight', None) 47 | self.register_parameter('v_proj_weight', None) 48 | 49 | if bias: 50 | self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs)) 51 | else: 52 | self.register_parameter('in_proj_bias', None) 53 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) 54 | 55 | if add_bias_kv: 56 | self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) 57 | self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) 58 | else: 59 | self.bias_k = self.bias_v = None 60 | 61 | self.add_zero_attn = add_zero_attn 62 | 63 | self._reset_parameters() 64 | 65 | def _reset_parameters(self): 66 | if self._qkv_same_embed_dim: 67 | xavier_uniform_(self.in_proj_weight) 68 | else: 69 | xavier_uniform_(self.q_proj_weight) 70 | xavier_uniform_(self.k_proj_weight) 71 | xavier_uniform_(self.v_proj_weight) 72 | 73 | if self.in_proj_bias is not None: 74 | constant_(self.in_proj_bias, 0.) 75 | constant_(self.out_proj.bias, 0.) 
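
The `DiffAttention` constructor above mirrors `nn.MultiheadAttention`'s interface (packed or separate in-projection weights, optional `bias_k`/`bias_v`, `batch_first`). A minimal usage sketch follows; it assumes the repository root is on `PYTHONPATH` and that the repo's dependencies (detectron2, the compiled ops) are installed, and the toy tensor sizes are illustrative only. With the default `batch_first=False`, inputs are laid out as (sequence, batch, embed):

```python
import torch
from dcnet.modeling.PCD.difference_attention import DiffAttention  # assumes repo root on PYTHONPATH

# Toy self-attention call with the default (sequence, batch, embed) layout.
x = torch.randn(10, 2, 256)
diff_attn = DiffAttention(embed_dim=256, num_heads=1, dropout=0.1)  # same arguments as used in pcd.py
out, weights = diff_attn(x, x, x)
print(out.shape, weights.shape)  # torch.Size([10, 2, 256]) torch.Size([2, 10, 10])
```
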
76 | if self.bias_k is not None: 77 | xavier_normal_(self.bias_k) 78 | if self.bias_v is not None: 79 | xavier_normal_(self.bias_v) 80 | 81 | def __setstate__(self, state): 82 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 83 | if '_qkv_same_embed_dim' not in state: 84 | state['_qkv_same_embed_dim'] = True 85 | 86 | super(DiffAttention, self).__setstate__(state) 87 | 88 | def forward(self, query: Tensor, key: Tensor, value: Tensor, key_padding_mask: Optional[Tensor] = None, 89 | need_weights: bool = True, attn_mask: Optional[Tensor] = None) -> Tuple[Tensor, Optional[Tensor]]: 90 | 91 | if self.batch_first: 92 | query, key, value = [x.transpose(1, 0) for x in (query, key, value)] 93 | 94 | if not self._qkv_same_embed_dim: 95 | attn_output, attn_output_weights = multi_head_difference_attention_forward( 96 | query, key, value, self.embed_dim, self.num_heads, 97 | self.in_proj_weight, self.in_proj_bias, 98 | self.bias_k, self.bias_v, self.add_zero_attn, 99 | self.dropout, self.out_proj.weight, self.out_proj.bias, 100 | training=self.training, 101 | key_padding_mask=key_padding_mask, need_weights=need_weights, 102 | attn_mask=attn_mask, use_separate_proj_weight=True, 103 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 104 | v_proj_weight=self.v_proj_weight) 105 | else: 106 | attn_output, attn_output_weights = multi_head_difference_attention_forward( 107 | query, key, value, self.embed_dim, self.num_heads, 108 | self.in_proj_weight, self.in_proj_bias, 109 | self.bias_k, self.bias_v, self.add_zero_attn, 110 | self.dropout, self.out_proj.weight, self.out_proj.bias, 111 | training=self.training, 112 | key_padding_mask=key_padding_mask, need_weights=need_weights, 113 | attn_mask=attn_mask) 114 | if self.batch_first: 115 | return attn_output.transpose(1, 0), attn_output_weights 116 | else: 117 | return attn_output, attn_output_weights 118 | 119 | def multi_head_difference_attention_forward( 120 | query: Tensor, 121 | key: Tensor, 122 | value: Tensor, 123 | embed_dim_to_check: int, 124 | num_heads: int, 125 | in_proj_weight: Tensor, 126 | in_proj_bias: Optional[Tensor], 127 | bias_k: Optional[Tensor], 128 | bias_v: Optional[Tensor], 129 | add_zero_attn: bool, 130 | dropout_p: float, 131 | out_proj_weight: Tensor, 132 | out_proj_bias: Optional[Tensor], 133 | training: bool = True, 134 | key_padding_mask: Optional[Tensor] = None, 135 | need_weights: bool = True, 136 | attn_mask: Optional[Tensor] = None, 137 | use_separate_proj_weight: bool = False, 138 | q_proj_weight: Optional[Tensor] = None, 139 | k_proj_weight: Optional[Tensor] = None, 140 | v_proj_weight: Optional[Tensor] = None, 141 | static_k: Optional[Tensor] = None, 142 | static_v: Optional[Tensor] = None, 143 | ) -> Tuple[Tensor, Optional[Tensor]]: 144 | 145 | tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, out_proj_weight, out_proj_bias) 146 | if has_torch_function(tens_ops): 147 | return handle_torch_function( 148 | multi_head_difference_attention_forward, 149 | tens_ops, 150 | query, 151 | key, 152 | value, 153 | embed_dim_to_check, 154 | num_heads, 155 | in_proj_weight, 156 | in_proj_bias, 157 | bias_k, 158 | bias_v, 159 | add_zero_attn, 160 | dropout_p, 161 | out_proj_weight, 162 | out_proj_bias, 163 | training=training, 164 | key_padding_mask=key_padding_mask, 165 | need_weights=need_weights, 166 | attn_mask=attn_mask, 167 | use_separate_proj_weight=use_separate_proj_weight, 168 | q_proj_weight=q_proj_weight, 169 | k_proj_weight=k_proj_weight, 170 | 
v_proj_weight=v_proj_weight, 171 | static_k=static_k, 172 | static_v=static_v, 173 | ) 174 | 175 | # set up shape vars 176 | tgt_len, bsz, embed_dim = query.shape 177 | src_len, _, _ = key.shape 178 | assert embed_dim == embed_dim_to_check, \ 179 | f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}" 180 | if isinstance(embed_dim, torch.Tensor): 181 | # embed_dim can be a tensor when JIT tracing 182 | head_dim = embed_dim.div(num_heads, rounding_mode='trunc') 183 | else: 184 | head_dim = embed_dim // num_heads 185 | assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}" 186 | if use_separate_proj_weight: 187 | # allow MHA to have different embedding dimensions when separate projection weights are used 188 | assert key.shape[:2] == value.shape[:2], \ 189 | f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}" 190 | else: 191 | assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}" 192 | 193 | # 194 | # compute in-projection 195 | # 196 | if not use_separate_proj_weight: 197 | q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias) 198 | else: 199 | assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None" 200 | assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None" 201 | assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None" 202 | if in_proj_bias is None: 203 | b_q = b_k = b_v = None 204 | else: 205 | b_q, b_k, b_v = in_proj_bias.chunk(3) 206 | q, k, v = _in_projection(query, key, value, q_proj_weight, k_proj_weight, v_proj_weight, b_q, b_k, b_v) 207 | 208 | # prep attention mask 209 | if attn_mask is not None: 210 | if attn_mask.dtype == torch.uint8: 211 | warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 212 | attn_mask = attn_mask.to(torch.bool) 213 | else: 214 | assert attn_mask.is_floating_point() or attn_mask.dtype == torch.bool, \ 215 | f"Only float, byte, and bool types are supported for attn_mask, not {attn_mask.dtype}" 216 | # ensure attn_mask's dim is 3 217 | if attn_mask.dim() == 2: 218 | correct_2d_size = (tgt_len, src_len) 219 | if attn_mask.shape != correct_2d_size: 220 | raise RuntimeError(f"The shape of the 2D attn_mask is {attn_mask.shape}, but should be {correct_2d_size}.") 221 | attn_mask = attn_mask.unsqueeze(0) 222 | elif attn_mask.dim() == 3: 223 | correct_3d_size = (bsz * num_heads, tgt_len, src_len) 224 | if attn_mask.shape != correct_3d_size: 225 | raise RuntimeError(f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}.") 226 | else: 227 | raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported") 228 | 229 | # prep key padding mask 230 | if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: 231 | warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 232 | key_padding_mask = key_padding_mask.to(torch.bool) 233 | 234 | # add bias along batch dimension (currently second) 235 | if bias_k is not None and bias_v is not None: 236 | assert static_k is None, "bias cannot be added to static key." 237 | assert static_v is None, "bias cannot be added to static value." 
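
The mask-preparation code above accepts either a boolean or floating-point `attn_mask`, shaped 2-D `(tgt_len, src_len)` or 3-D `(bsz * num_heads, tgt_len, src_len)`, plus a boolean `(bsz, src_len)` `key_padding_mask`; byte masks are accepted but deprecated. A small sketch of masks that satisfy those shape checks (the sizes are assumptions for illustration):

```python
import torch

bsz, tgt_len, src_len = 2, 5, 7

# Boolean key-padding mask: True marks padded source positions to ignore.
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
key_padding_mask[:, -2:] = True          # e.g. the last two source tokens are padding

# 2-D additive float mask, broadcast across batch and heads.
attn_mask = torch.zeros(tgt_len, src_len)
attn_mask[:, -1] = float("-inf")         # additionally block the last source position
```
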
238 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 239 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 240 | if attn_mask is not None: 241 | attn_mask = F.pad(attn_mask, (0, 1)) 242 | if key_padding_mask is not None: 243 | key_padding_mask = F.pad(key_padding_mask, (0, 1)) 244 | else: 245 | assert bias_k is None 246 | assert bias_v is None 247 | 248 | # 249 | # reshape q, k, v for multihead attention and make em batch first 250 | # 251 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 252 | if static_k is None: 253 | k = k.contiguous().view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1) 254 | else: 255 | # TODO finish disentangling control flow so we don't do in-projections when statics are passed 256 | assert static_k.size(0) == bsz * num_heads, \ 257 | f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}" 258 | assert static_k.size(2) == head_dim, \ 259 | f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}" 260 | k = static_k 261 | if static_v is None: 262 | v = v.contiguous().view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1) 263 | else: 264 | # TODO finish disentangling control flow so we don't do in-projections when statics are passed 265 | assert static_v.size(0) == bsz * num_heads, \ 266 | f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}" 267 | assert static_v.size(2) == head_dim, \ 268 | f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}" 269 | v = static_v 270 | 271 | # add zero attention along batch dimension (now first) 272 | if add_zero_attn: 273 | zero_attn_shape = (bsz * num_heads, 1, head_dim) 274 | k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1) 275 | v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1) 276 | if attn_mask is not None: 277 | attn_mask = F.pad(attn_mask, (0, 1)) 278 | if key_padding_mask is not None: 279 | key_padding_mask = F.pad(key_padding_mask, (0, 1)) 280 | 281 | # update source sequence length after adjustments 282 | src_len = k.size(1) 283 | 284 | # merge key padding and attention masks 285 | if key_padding_mask is not None: 286 | assert key_padding_mask.shape == (bsz, src_len), \ 287 | f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}" 288 | key_padding_mask = key_padding_mask.view(bsz, 1, 1, src_len). 
\ 289 | expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len) 290 | if attn_mask is None: 291 | attn_mask = key_padding_mask 292 | elif attn_mask.dtype == torch.bool: 293 | attn_mask = attn_mask.logical_or(key_padding_mask) 294 | else: 295 | attn_mask = attn_mask.masked_fill(key_padding_mask, float("-inf")) 296 | 297 | # convert mask to float 298 | if attn_mask is not None and attn_mask.dtype == torch.bool: 299 | new_attn_mask = torch.zeros_like(attn_mask, dtype=torch.float) 300 | new_attn_mask.masked_fill_(attn_mask, float("-inf")) 301 | attn_mask = new_attn_mask 302 | 303 | # adjust dropout probability 304 | if not training: 305 | dropout_p = 0.0 306 | 307 | # 308 | # (deep breath) calculate attention and out projection 309 | # 310 | attn_output, attn_output_weights = _difference_attention(q, k, v, attn_mask, dropout_p) 311 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 312 | attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) 313 | 314 | if need_weights: 315 | # average attention weights over heads 316 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 317 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 318 | else: 319 | return attn_output, None 320 | 321 | def _in_projection_packed( 322 | q: Tensor, 323 | k: Tensor, 324 | v: Tensor, 325 | w: Tensor, 326 | b: Optional[Tensor] = None, 327 | ) -> List[Tensor]: 328 | 329 | E = q.size(-1) 330 | if k is v: 331 | if q is k: 332 | # self-attention 333 | return F.linear(q, w, b).chunk(3, dim=-1) 334 | else: 335 | # encoder-decoder attention 336 | w_q, w_kv = w.split([E, E * 2]) 337 | if b is None: 338 | b_q = b_kv = None 339 | else: 340 | b_q, b_kv = b.split([E, E * 2]) 341 | return (F.linear(q, w_q, b_q),) + F.linear(k, w_kv, b_kv).chunk(2, dim=-1) 342 | else: 343 | w_q, w_k, w_v = w.chunk(3) 344 | if b is None: 345 | b_q = b_k = b_v = None 346 | else: 347 | b_q, b_k, b_v = b.chunk(3) 348 | return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) 349 | 350 | 351 | def _in_projection( 352 | q: Tensor, 353 | k: Tensor, 354 | v: Tensor, 355 | w_q: Tensor, 356 | w_k: Tensor, 357 | w_v: Tensor, 358 | b_q: Optional[Tensor] = None, 359 | b_k: Optional[Tensor] = None, 360 | b_v: Optional[Tensor] = None, 361 | ) -> Tuple[Tensor, Tensor, Tensor]: 362 | 363 | Eq, Ek, Ev = q.size(-1), k.size(-1), v.size(-1) 364 | assert w_q.shape == (Eq, Eq), f"expecting query weights shape of {(Eq, Eq)}, but got {w_q.shape}" 365 | assert w_k.shape == (Eq, Ek), f"expecting key weights shape of {(Eq, Ek)}, but got {w_k.shape}" 366 | assert w_v.shape == (Eq, Ev), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}" 367 | assert b_q is None or b_q.shape == (Eq,), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}" 368 | assert b_k is None or b_k.shape == (Eq,), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}" 369 | assert b_v is None or b_v.shape == (Eq,), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}" 370 | return F.linear(q, w_q, b_q), F.linear(k, w_k, b_k), F.linear(v, w_v, b_v) 371 | 372 | 373 | def _difference_attention( 374 | q: Tensor, 375 | k: Tensor, 376 | v: Tensor, 377 | attn_mask: Optional[Tensor] = None, 378 | dropout_p: float = 0.0, 379 | ) -> Tuple[Tensor, Tensor]: 380 | 381 | B, Nt, E = q.shape 382 | B, Ns, _ = k.shape 383 | # q = q / math.sqrt(E) 384 | # (B, Nt, E) x (B, E, Ns) -> (B, Nt, Ns) 385 | # attn = torch.bmm(q, k.transpose(-2, -1)) 386 | 387 | q = 
q.unsqueeze(2).repeat(1, 1, Ns, 1) 388 | k = k.unsqueeze(1).repeat(1, Nt, 1, 1) 389 | 390 | diff = q - k # (B, Nt, Ns, E) 391 | attn = (diff ** 2).sum(dim=-1) / math.sqrt(E) # (B, Nt, Ns) 392 | 393 | if attn_mask is not None: 394 | attn += attn_mask 395 | # attn = F.softmax(attn, dim=-1) 396 | if dropout_p > 0.0: 397 | attn = F.dropout(attn, p=dropout_p) 398 | # (B, Nt, Ns) x (B, Ns, E) -> (B, Nt, E) 399 | output = torch.bmm(attn, v) 400 | return output, attn 401 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn_func import MSDeformAttnFunction 13 | 14 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
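
Unlike the dot-product affinity in `reference_attention.py`, `_difference_attention` above scores each query-key pair by the scaled sum of squared differences and, as written, applies no softmax before aggregating the values. A standalone sketch of that scoring rule, using broadcasting instead of `repeat` and toy shapes chosen for illustration:

```python
import math
import torch

B, Nt, Ns, E = 1, 3, 4, 8                       # batch*heads, queries, keys, head dim (toy sizes)
q, k, v = torch.randn(B, Nt, E), torch.randn(B, Ns, E), torch.randn(B, Ns, E)

diff = q.unsqueeze(2) - k.unsqueeze(1)          # (B, Nt, Ns, E), broadcast instead of repeat
attn = diff.pow(2).sum(-1) / math.sqrt(E)       # (B, Nt, Ns) squared-difference affinity
out = torch.bmm(attn, v)                        # (B, Nt, E); no softmax, matching the source
print(attn.shape, out.shape)
```
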
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch.autograd import Function 19 | from torch.autograd.function import once_differentiable 20 | 21 | try: 22 | import MultiScaleDeformableAttention as MSDA 23 | except ModuleNotFoundError as e: 24 | info_string = ( 25 | "\n\nPlease compile MultiScaleDeformableAttention CUDA op with the following commands:\n" 26 | "\t`cd dcnet\modeling\PCD\ops`\n" 27 | "\t`sh make.sh`\n" 28 | ) 29 | raise ModuleNotFoundError(info_string) 30 | 31 | 32 | class MSDeformAttnFunction(Function): 33 | @staticmethod 34 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 35 | ctx.im2col_step = im2col_step 36 | output = MSDA.ms_deform_attn_forward( 37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 38 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 39 | return output 40 | 41 | @staticmethod 42 | @once_differentiable 43 | def backward(ctx, grad_output): 44 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 45 | grad_value, grad_sampling_loc, grad_attn_weight = \ 46 | MSDA.ms_deform_attn_backward( 47 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 48 | 49 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 50 | 51 | 52 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 53 | # for debug and test only, 54 | # need to use cuda version instead 55 | N_, S_, M_, D_ = value.shape 56 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 57 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 58 | sampling_grids = 2 * sampling_locations - 1 59 | sampling_value_list = [] 60 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 61 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 62 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 63 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 64 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 65 | # N_*M_, D_, Lq_, P_ 66 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 67 | mode='bilinear', padding_mode='zeros', align_corners=False) 68 | sampling_value_list.append(sampling_value_l_) 69 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 70 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 71 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 72 | return output.transpose(1, 2).contiguous() 73 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
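
`ms_deform_attn_core_pytorch` above is the pure-PyTorch reference path ("for debug and test only") and runs on CPU, but note that the module imports the compiled `MultiScaleDeformableAttention` extension at load time, so the op must be built first via `make.sh`. A shape sketch with toy sizes (all sizes are assumptions for illustration):

```python
import torch
# Importing this module requires the compiled op, since it is imported at module load time.
from dcnet.modeling.PCD.ops.functions.ms_deform_attn_func import ms_deform_attn_core_pytorch

N, M, D = 1, 2, 4                 # batch, heads, channels per head
Lq, L, P = 3, 2, 2                # queries, levels, sampling points per level
shapes = [(6, 4), (3, 2)]         # (H, W) of each feature level
S = sum(H * W for H, W in shapes)

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)   # normalized to [0, 1]
attention_weights = torch.rand(N, Lq, M, L, P)
attention_weights = attention_weights / attention_weights.sum((-1, -2), keepdim=True)

out = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
print(out.shape)                  # torch.Size([1, 3, 8]) -> (N, Lq, M * D)
```
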
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | # Copyright (c) Facebook, Inc. and its affiliates. 11 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 12 | 13 | python setup.py build install 14 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from .ms_deform_attn import MSDeformAttn 13 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import warnings 17 | import math 18 | 19 | import torch 20 | from torch import nn 21 | import torch.nn.functional as F 22 | from torch.nn.init import xavier_uniform_, constant_ 23 | 24 | from ..functions import MSDeformAttnFunction 25 | from ..functions.ms_deform_attn_func import ms_deform_attn_core_pytorch 26 | 27 | 28 | def _is_power_of_2(n): 29 | if (not isinstance(n, int)) or (n < 0): 30 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 31 | return (n & (n-1) == 0) and n != 0 32 | 33 | 34 | class MSDeformAttn(nn.Module): 35 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 36 | """ 37 | Multi-Scale Deformable Attention Module 38 | :param d_model hidden dimension 39 | :param n_levels number of feature levels 40 | :param n_heads number of attention heads 41 | :param n_points number of sampling points per attention head per feature level 42 | """ 43 | super().__init__() 44 | if d_model % n_heads != 0: 45 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 46 | _d_per_head = d_model // n_heads 47 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 48 | if not _is_power_of_2(_d_per_head): 49 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 50 | "which is more efficient in our CUDA implementation.") 51 | 52 | self.im2col_step = 128 53 | 54 | self.d_model = d_model 55 | self.n_levels = n_levels 56 | self.n_heads = n_heads 57 | self.n_points = n_points 58 | 59 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 60 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 61 | self.value_proj = nn.Linear(d_model, d_model) 62 | self.output_proj = nn.Linear(d_model, d_model) 63 | 64 | self._reset_parameters() 65 | 66 | def _reset_parameters(self): 67 | constant_(self.sampling_offsets.weight.data, 0.) 68 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 69 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 70 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 71 | for i in range(self.n_points): 72 | grid_init[:, :, i, :] *= i + 1 73 | with torch.no_grad(): 74 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 75 | constant_(self.attention_weights.weight.data, 0.) 76 | constant_(self.attention_weights.bias.data, 0.) 77 | xavier_uniform_(self.value_proj.weight.data) 78 | constant_(self.value_proj.bias.data, 0.) 79 | xavier_uniform_(self.output_proj.weight.data) 80 | constant_(self.output_proj.bias.data, 0.) 
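
The `_reset_parameters` above zeroes the offset weights and sets the offset bias so that, before any training, each attention head samples along a distinct direction around the reference point, with successive points per head placed at increasing radii. A small standalone sketch that reproduces just that initialization geometry (toy head/level/point counts):

```python
import math
import torch

n_heads, n_levels, n_points = 8, 4, 4
thetas = torch.arange(n_heads, dtype=torch.float32) * (2.0 * math.pi / n_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)[0]          # push onto the unit square
grid_init = grid_init.view(n_heads, 1, 1, 2).repeat(1, n_levels, n_points, 1)
for i in range(n_points):
    grid_init[:, :, i, :] *= i + 1                                         # farther points per head
print(grid_init[:, 0, 0])  # one distinct unit-square direction per head
```
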
81 | 82 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 83 | """ 84 | :param query (N, Length_{query}, C) 85 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 86 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 87 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 88 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 89 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 90 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 91 | 92 | :return output (N, Length_{query}, C) 93 | """ 94 | N, Len_q, _ = query.shape 95 | N, Len_in, _ = input_flatten.shape 96 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 97 | 98 | value = self.value_proj(input_flatten) 99 | if input_padding_mask is not None: 100 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 101 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 102 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 103 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 104 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 105 | # N, Len_q, n_heads, n_levels, n_points, 2 106 | if reference_points.shape[-1] == 2: 107 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 108 | sampling_locations = reference_points[:, :, None, :, None, :] \ 109 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 110 | elif reference_points.shape[-1] == 4: 111 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 112 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 113 | else: 114 | raise ValueError( 115 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 116 | output = MSDeformAttnFunction.apply( 117 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 118 | # # For FLOPs calculation only 119 | # output = ms_deform_attn_core_pytorch(value, input_spatial_shapes, sampling_locations, attention_weights) 120 | output = self.output_proj(output) 121 | return output 122 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. and its affiliates. 
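
The docstring above spells out the expected input layout: features from all levels flattened and concatenated along the length dimension, together with `input_spatial_shapes`, `input_level_start_index`, and normalized `reference_points`. A sketch of how those auxiliary tensors are typically assembled; the level sizes and the `make_reference_points` helper are assumptions for illustration, and the actual module call (commented out) additionally needs the compiled CUDA op and a GPU:

```python
import torch

# Assumed multi-scale sizes for n_levels = 4.
spatial_shapes = torch.as_tensor([(32, 32), (16, 16), (8, 8), (4, 4)], dtype=torch.long)
level_start_index = torch.cat(
    (spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])
)

def make_reference_points(spatial_shapes):
    # One normalized (x, y) pixel centre per location of every level, replicated over levels.
    points = []
    for H, W in spatial_shapes.tolist():
        ys, xs = torch.meshgrid(
            torch.linspace(0.5, H - 0.5, H) / H,
            torch.linspace(0.5, W - 0.5, W) / W,
            indexing="ij",
        )
        points.append(torch.stack((xs, ys), -1).reshape(-1, 2))
    points = torch.cat(points, 0)                                 # (sum H*W, 2)
    return points[None, :, None, :].expand(1, -1, spatial_shapes.size(0), -1)

reference_points = make_reference_points(spatial_shapes)          # (1, Len_q, n_levels, 2)
# msda = MSDeformAttn(d_model=256, n_levels=4, n_heads=8, n_points=4).cuda()
# out = msda(query, reference_points, input_flatten, spatial_shapes, level_start_index)
```
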
10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | import os 13 | import glob 14 | 15 | import torch 16 | 17 | from torch.utils.cpp_extension import CUDA_HOME 18 | from torch.utils.cpp_extension import CppExtension 19 | from torch.utils.cpp_extension import CUDAExtension 20 | 21 | from setuptools import find_packages 22 | from setuptools import setup 23 | 24 | requirements = ["torch", "torchvision"] 25 | 26 | def get_extensions(): 27 | this_dir = os.path.dirname(os.path.abspath(__file__)) 28 | extensions_dir = os.path.join(this_dir, "src") 29 | 30 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 31 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 32 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 33 | 34 | sources = main_file + source_cpu 35 | extension = CppExtension 36 | extra_compile_args = {"cxx": []} 37 | define_macros = [] 38 | 39 | # Force cuda since torch ask for a device, not if cuda is in fact available. 40 | if (os.environ.get('FORCE_CUDA') or torch.cuda.is_available()) and CUDA_HOME is not None: 41 | extension = CUDAExtension 42 | sources += source_cuda 43 | define_macros += [("WITH_CUDA", None)] 44 | extra_compile_args["nvcc"] = [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ] 50 | else: 51 | if CUDA_HOME is None: 52 | raise NotImplementedError('CUDA_HOME is None. Please set environment variable CUDA_HOME.') 53 | else: 54 | raise NotImplementedError('No CUDA runtime is found. Please set FORCE_CUDA=1 or test it by running torch.cuda.is_available().') 55 | 56 | sources = [os.path.join(extensions_dir, s) for s in sources] 57 | include_dirs = [extensions_dir] 58 | ext_modules = [ 59 | extension( 60 | "MultiScaleDeformableAttention", 61 | sources, 62 | include_dirs=include_dirs, 63 | define_macros=define_macros, 64 | extra_compile_args=extra_compile_args, 65 | ) 66 | ] 67 | return ext_modules 68 | 69 | setup( 70 | name="MultiScaleDeformableAttention", 71 | version="1.0", 72 | author="Weijie Su", 73 | url="https://github.com/fundamentalvision/Deformable-DETR", 74 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 75 | packages=find_packages(exclude=("configs", "tests",)), 76 | ext_modules=get_extensions(), 77 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 78 | ) 79 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | 18 | #include 19 | #include 20 | 21 | 22 | at::Tensor 23 | ms_deform_attn_cpu_forward( 24 | const at::Tensor &value, 25 | const at::Tensor &spatial_shapes, 26 | const at::Tensor &level_start_index, 27 | const at::Tensor &sampling_loc, 28 | const at::Tensor &attn_weight, 29 | const int im2col_step) 30 | { 31 | AT_ERROR("Not implement on cpu"); 32 | } 33 | 34 | std::vector 35 | ms_deform_attn_cpu_backward( 36 | const at::Tensor &value, 37 | const at::Tensor &spatial_shapes, 38 | const at::Tensor &level_start_index, 39 | const at::Tensor &sampling_loc, 40 | const at::Tensor &attn_weight, 41 | const at::Tensor &grad_output, 42 | const int im2col_step) 43 | { 44 | AT_ERROR("Not implement on cpu"); 45 | } 46 | 47 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor 20 | ms_deform_attn_cpu_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step); 27 | 28 | std::vector 29 | ms_deform_attn_cpu_backward( 30 | const at::Tensor &value, 31 | const at::Tensor &spatial_shapes, 32 | const at::Tensor &level_start_index, 33 | const at::Tensor &sampling_loc, 34 | const at::Tensor &attn_weight, 35 | const at::Tensor &grad_output, 36 | const int im2col_step); 37 | 38 | 39 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include 17 | #include "cuda/ms_deform_im2col_cuda.cuh" 18 | 19 | #include 20 | #include 21 | #include 22 | #include 23 | 24 | 25 | at::Tensor ms_deform_attn_cuda_forward( 26 | const at::Tensor &value, 27 | const at::Tensor &spatial_shapes, 28 | const at::Tensor &level_start_index, 29 | const at::Tensor &sampling_loc, 30 | const at::Tensor &attn_weight, 31 | const int im2col_step) 32 | { 33 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 34 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 35 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 36 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 37 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 38 | 39 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 40 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 41 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 42 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 43 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 44 | 45 | const int batch = value.size(0); 46 | const int spatial_size = value.size(1); 47 | const int num_heads = value.size(2); 48 | const int channels = value.size(3); 49 | 50 | const int num_levels = spatial_shapes.size(0); 51 | 52 | const int num_query = sampling_loc.size(1); 53 | const int num_point = sampling_loc.size(4); 54 | 55 | const int im2col_step_ = std::min(batch, im2col_step); 56 | 57 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 58 | 59 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 60 | 61 | const int batch_n = im2col_step_; 62 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 63 | auto per_value_size = spatial_size * num_heads * channels; 64 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 65 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 66 | for (int n = 0; n < batch/im2col_step_; ++n) 67 | { 68 | auto columns = output_n.select(0, n); 69 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 70 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 71 | value.data() + n * im2col_step_ * per_value_size, 72 | spatial_shapes.data(), 73 | level_start_index.data(), 74 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 75 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 76 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 77 | columns.data()); 78 | 79 | })); 80 | } 81 | 82 | output = output.view({batch, num_query, num_heads*channels}); 83 | 84 | return output; 85 | } 86 | 87 | 88 | std::vector ms_deform_attn_cuda_backward( 89 | const at::Tensor &value, 90 | const at::Tensor &spatial_shapes, 91 | const at::Tensor &level_start_index, 92 | const at::Tensor &sampling_loc, 93 | const at::Tensor &attn_weight, 94 | const at::Tensor &grad_output, 95 | const int im2col_step) 96 | { 97 | 98 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 99 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to 
be contiguous"); 100 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 101 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 102 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 103 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 104 | 105 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 106 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 107 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 108 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 109 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 110 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 111 | 112 | const int batch = value.size(0); 113 | const int spatial_size = value.size(1); 114 | const int num_heads = value.size(2); 115 | const int channels = value.size(3); 116 | 117 | const int num_levels = spatial_shapes.size(0); 118 | 119 | const int num_query = sampling_loc.size(1); 120 | const int num_point = sampling_loc.size(4); 121 | 122 | const int im2col_step_ = std::min(batch, im2col_step); 123 | 124 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 125 | 126 | auto grad_value = at::zeros_like(value); 127 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 128 | auto grad_attn_weight = at::zeros_like(attn_weight); 129 | 130 | const int batch_n = im2col_step_; 131 | auto per_value_size = spatial_size * num_heads * channels; 132 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 133 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 134 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 135 | 136 | for (int n = 0; n < batch/im2col_step_; ++n) 137 | { 138 | auto grad_output_g = grad_output_n.select(0, n); 139 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 140 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 141 | grad_output_g.data(), 142 | value.data() + n * im2col_step_ * per_value_size, 143 | spatial_shapes.data(), 144 | level_start_index.data(), 145 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 146 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 147 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 148 | grad_value.data() + n * im2col_step_ * per_value_size, 149 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 150 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 151 | 152 | })); 153 | } 154 | 155 | return { 156 | grad_value, grad_sampling_loc, grad_attn_weight 157 | }; 158 | } -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | #include 18 | 19 | at::Tensor ms_deform_attn_cuda_forward( 20 | const at::Tensor &value, 21 | const at::Tensor &spatial_shapes, 22 | const at::Tensor &level_start_index, 23 | const at::Tensor &sampling_loc, 24 | const at::Tensor &attn_weight, 25 | const int im2col_step); 26 | 27 | std::vector ms_deform_attn_cuda_backward( 28 | const at::Tensor &value, 29 | const at::Tensor &spatial_shapes, 30 | const at::Tensor &level_start_index, 31 | const at::Tensor &sampling_loc, 32 | const at::Tensor &attn_weight, 33 | const at::Tensor &grad_output, 34 | const int im2col_step); 35 | 36 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 
13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #pragma once 17 | 18 | #include "cpu/ms_deform_attn_cpu.h" 19 | 20 | #ifdef WITH_CUDA 21 | #include "cuda/ms_deform_attn_cuda.h" 22 | #endif 23 | 24 | 25 | at::Tensor 26 | ms_deform_attn_forward( 27 | const at::Tensor &value, 28 | const at::Tensor &spatial_shapes, 29 | const at::Tensor &level_start_index, 30 | const at::Tensor &sampling_loc, 31 | const at::Tensor &attn_weight, 32 | const int im2col_step) 33 | { 34 | if (value.type().is_cuda()) 35 | { 36 | #ifdef WITH_CUDA 37 | return ms_deform_attn_cuda_forward( 38 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 39 | #else 40 | AT_ERROR("Not compiled with GPU support"); 41 | #endif 42 | } 43 | AT_ERROR("Not implemented on the CPU"); 44 | } 45 | 46 | std::vector 47 | ms_deform_attn_backward( 48 | const at::Tensor &value, 49 | const at::Tensor &spatial_shapes, 50 | const at::Tensor &level_start_index, 51 | const at::Tensor &sampling_loc, 52 | const at::Tensor &attn_weight, 53 | const at::Tensor &grad_output, 54 | const int im2col_step) 55 | { 56 | if (value.type().is_cuda()) 57 | { 58 | #ifdef WITH_CUDA 59 | return ms_deform_attn_cuda_backward( 60 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 61 | #else 62 | AT_ERROR("Not compiled with GPU support"); 63 | #endif 64 | } 65 | AT_ERROR("Not implemented on the CPU"); 66 | } 67 | 68 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | /*! 12 | * Copyright (c) Facebook, Inc. and its affiliates. 13 | * Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 14 | */ 15 | 16 | #include "ms_deform_attn.h" 17 | 18 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 19 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 20 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 21 | } 22 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | # Copyright (c) Facebook, Inc. 
and its affiliates. 10 | # Modified by Bowen Cheng from https://github.com/fundamentalvision/Deformable-DETR 11 | 12 | from __future__ import absolute_import 13 | from __future__ import print_function 14 | from __future__ import division 15 | 16 | import time 17 | import torch 18 | import torch.nn as nn 19 | from torch.autograd import gradcheck 20 | 21 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 22 | 23 | 24 | N, M, D = 1, 2, 2 25 | Lq, L, P = 2, 2, 2 26 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 27 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 28 | S = sum([(H*W).item() for H, W in shapes]) 29 | 30 | 31 | torch.manual_seed(3) 32 | 33 | 34 | @torch.no_grad() 35 | def check_forward_equal_with_pytorch_double(): 36 | value = torch.rand(N, S, M, D).cuda() * 0.01 37 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 38 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 39 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 40 | im2col_step = 2 41 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 42 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 43 | fwdok = torch.allclose(output_cuda, output_pytorch) 44 | max_abs_err = (output_cuda - output_pytorch).abs().max() 45 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 46 | 47 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 48 | 49 | 50 | @torch.no_grad() 51 | def check_forward_equal_with_pytorch_float(): 52 | value = torch.rand(N, S, M, D).cuda() * 0.01 53 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 54 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 55 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 56 | im2col_step = 2 57 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 58 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 59 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 60 | max_abs_err = (output_cuda - output_pytorch).abs().max() 61 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 62 | 63 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 64 | 65 | 66 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 67 | 68 | value = torch.rand(N, S, M, channels).cuda() * 0.01 69 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 70 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 71 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 72 | im2col_step = 2 73 | func = MSDeformAttnFunction.apply 74 | 75 | value.requires_grad = grad_value 76 | sampling_locations.requires_grad = grad_sampling_loc 77 | attention_weights.requires_grad = grad_attn_weight 78 | 79 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 80 | 81 
| print(f'* {gradok} check_gradient_numerical(D={channels})') 82 | 83 | 84 | if __name__ == '__main__': 85 | check_forward_equal_with_pytorch_double() 86 | check_forward_equal_with_pytorch_float() 87 | 88 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 89 | check_gradient_numerical(channels, True, True, True) 90 | 91 | 92 | 93 | -------------------------------------------------------------------------------- /dcnet/modeling/PCD/pcd.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pixel-level Camouflage Decoupling Module 3 | """ 4 | 5 | import torch 6 | import numpy as np 7 | from torch import nn, Tensor 8 | from torch.nn import functional as F 9 | from typing import Callable, Dict, List, Optional, Tuple, Union 10 | import torchvision 11 | from torch.nn.init import constant_, xavier_normal_, xavier_uniform_, normal_ 12 | import fvcore.nn.weight_init as weight_init 13 | from torch.cuda.amp import autocast 14 | 15 | 16 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 17 | 18 | from ..ICS.position_encoding import PositionEmbeddingSine 19 | from ..ICS.reference_attention import _get_clones, _get_activation_fn 20 | from .ops.modules import MSDeformAttn 21 | from .difference_attention import DiffAttention 22 | 23 | class pcd(nn.Module): 24 | def __init__( 25 | self, 26 | cfg, 27 | input_shape: Dict[str, ShapeSpec], 28 | ): 29 | super().__init__() 30 | self.cam_extractor = CamExtractor(input_shape) 31 | self.fusion_layer = FusionLayer(cfg, input_shape) 32 | 33 | def forward(self, features, images_a): 34 | dc_features = self.cam_extractor(images_a, features) 35 | dc_pixel_features, pixel_embedding = self.fusion_layer.forward_features(dc_features) 36 | 37 | return dc_pixel_features, pixel_embedding 38 | 39 | class CamExtractor(nn.Module): 40 | def __init__( 41 | self, 42 | input_shape, 43 | conv_dim=256 44 | ): 45 | 46 | super().__init__() 47 | 48 | self.backbone = torchvision.models.resnet18(pretrained=True) 49 | 50 | 51 | self.in_features = [k for k, v in input_shape.items()] 52 | 53 | self.feature_names = self.in_features[1:] # [res3, res4, res5] 54 | input_channel2 = 512 55 | 56 | self.input_proj1 = nn.ModuleList() 57 | self.input_proj2 = nn.ModuleList() 58 | 59 | self.attention = nn.ModuleList() 60 | 61 | for feature_name in self.feature_names: 62 | dim1 = input_shape[feature_name].channels 63 | dim2 = input_channel2 64 | 65 | self.input_proj1.append(Conv2d(dim1, conv_dim, kernel_size=1)) 66 | self.input_proj2.append(Conv2d(dim2, conv_dim, kernel_size=1)) 67 | 68 | self.attention.append(DiffAttention(embed_dim=conv_dim, num_heads=1, dropout=0.1)) 69 | 70 | self._reset_parameters() 71 | 72 | def _reset_parameters(self): 73 | for proj in self.input_proj1: 74 | xavier_uniform_(proj.weight) 75 | constant_(proj.bias, 0) 76 | 77 | for proj in self.input_proj2: 78 | xavier_uniform_(proj.weight) 79 | constant_(proj.bias, 0) 80 | 81 | 82 | def forward(self, images_a : Tensor ,features: Dict[str, Tensor]): 83 | x = self.backbone.conv1(images_a) 84 | x = self.backbone.bn1(x) 85 | x = self.backbone.relu(x) 86 | x = self.backbone.maxpool(x) 87 | 88 | x = self.backbone.layer1(x) 89 | x = self.backbone.layer2(x) 90 | x = self.backbone.layer3(x) 91 | x = self.backbone.layer4(x) 92 | 93 | x = self.backbone.avgpool(x) # [B, 512, 1, 1] 94 | 95 | features_enrich = {} 96 | 97 | for i, name in enumerate(self.feature_names): 98 | feature1 = features[name] 99 | 100 | B, C, H, W = feature1.shape 101 | 102 | feature_conv = self.input_proj1[i](feature1) 103 
| style_vector = self.input_proj2[i](x) 104 | 105 | feature_conv = feature_conv.flatten(2).permute(2, 0, 1) 106 | style_vector = style_vector.flatten(2).permute(2, 0, 1) 107 | style_vector_update, attn_map = self.attention[i](style_vector, feature_conv, feature_conv) 108 | # attn_map: [B, 1, HW] 109 | 110 | attn_map = attn_map.view(B, 1, H, W) 111 | feature_style = feature1 * attn_map 112 | 113 | features_enrich[name] = feature_style 114 | 115 | for k, v in features.items(): 116 | if k in self.feature_names: 117 | features[k] = features_enrich[k] 118 | 119 | return features 120 | 121 | class FusionLayer(nn.Module): 122 | def __init__( 123 | self, 124 | cfg, 125 | input_shape: Dict[str, ShapeSpec], 126 | transformer_dropout: float = 0.1, 127 | transformer_nheads: int = 8, 128 | transformer_dim_feedforward: int = 1024, 129 | transformer_enc_layers: int = 6, 130 | conv_dim: int = 256, 131 | mask_dim: int = 256, 132 | norm: Optional[Union[str, Callable]] = "GN", 133 | transformer_in_features: List[str] = ["res3", "res4", "res5"], 134 | common_stride: int = 4, 135 | ): 136 | 137 | super().__init__() 138 | 139 | input_shape = {k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES} 140 | transformer_input_shape = { 141 | k: v for k, v in input_shape.items() if k in transformer_in_features 142 | } 143 | # this is the input shape of pixel decoder 144 | input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) 145 | self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" 146 | self.feature_strides = [v.stride for k, v in input_shape] 147 | self.feature_channels = [v.channels for k, v in input_shape] 148 | 149 | # this is the input shape of transformer encoder (could use less features than pixel decoder 150 | transformer_input_shape = sorted(transformer_input_shape.items(), key=lambda x: x[1].stride) 151 | self.transformer_in_features = [k for k, v in transformer_input_shape] # starting from "res2" to "res5" 152 | transformer_in_channels = [v.channels for k, v in transformer_input_shape] 153 | self.transformer_feature_strides = [v.stride for k, v in transformer_input_shape] # to decide extra FPN layers 154 | 155 | self.transformer_num_feature_levels = len(self.transformer_in_features) 156 | if self.transformer_num_feature_levels > 1: 157 | input_proj_list = [] 158 | # from low resolution to high resolution (res5 -> res2) 159 | for in_channels in transformer_in_channels[::-1]: 160 | input_proj_list.append(nn.Sequential( 161 | nn.Conv2d(in_channels, conv_dim, kernel_size=1), 162 | nn.GroupNorm(32, conv_dim), 163 | )) 164 | self.input_proj = nn.ModuleList(input_proj_list) 165 | else: 166 | self.input_proj = nn.ModuleList([ 167 | nn.Sequential( 168 | nn.Conv2d(transformer_in_channels[-1], conv_dim, kernel_size=1), 169 | nn.GroupNorm(32, conv_dim), 170 | )]) 171 | 172 | for proj in self.input_proj: 173 | nn.init.xavier_uniform_(proj[0].weight, gain=1) 174 | nn.init.constant_(proj[0].bias, 0) 175 | 176 | self.transformer = MSDeformAttnTransformerEncoderOnly( 177 | d_model=conv_dim, 178 | dropout=transformer_dropout, 179 | nhead=transformer_nheads, 180 | dim_feedforward=transformer_dim_feedforward, 181 | num_encoder_layers=transformer_enc_layers, 182 | num_feature_levels=self.transformer_num_feature_levels, 183 | ) 184 | N_steps = conv_dim // 2 185 | self.pe_layer = PositionEmbeddingSine(N_steps, normalize=True) 186 | 187 | self.mask_dim = mask_dim 188 | # use 1x1 conv instead 189 | self.mask_features = Conv2d( 190 | conv_dim, 191 | mask_dim, 192 
| kernel_size=1, 193 | stride=1, 194 | padding=0, 195 | ) 196 | weight_init.c2_xavier_fill(self.mask_features) 197 | 198 | self.fusion_num_feature_levels = 3 # always use 3 scales 199 | self.common_stride = common_stride 200 | 201 | # extra fpn levels 202 | stride = min(self.transformer_feature_strides) 203 | self.num_fpn_levels = int(np.log2(stride) - np.log2(self.common_stride)) 204 | 205 | lateral_convs = [] 206 | output_convs = [] 207 | 208 | use_bias = norm == "" 209 | for idx, in_channels in enumerate(self.feature_channels[:self.num_fpn_levels]): 210 | lateral_norm = get_norm(norm, conv_dim) 211 | output_norm = get_norm(norm, conv_dim) 212 | 213 | lateral_conv = Conv2d( 214 | in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm 215 | ) 216 | output_conv = Conv2d( 217 | conv_dim, 218 | conv_dim, 219 | kernel_size=3, 220 | stride=1, 221 | padding=1, 222 | bias=use_bias, 223 | norm=output_norm, 224 | activation=F.relu, 225 | ) 226 | weight_init.c2_xavier_fill(lateral_conv) 227 | weight_init.c2_xavier_fill(output_conv) 228 | self.add_module("adapter_{}".format(idx + 1), lateral_conv) 229 | self.add_module("layer_{}".format(idx + 1), output_conv) 230 | 231 | lateral_convs.append(lateral_conv) 232 | output_convs.append(output_conv) 233 | # Place convs into top-down order (from low to high resolution) 234 | # to make the top-down computation in forward clearer. 235 | self.lateral_convs = lateral_convs[::-1] 236 | self.output_convs = output_convs[::-1] 237 | 238 | @autocast(enabled=False) 239 | def forward_features(self, features): 240 | srcs = [] 241 | pos = [] 242 | # Reverse feature maps into top-down order (from low to high resolution) 243 | for idx, f in enumerate(self.transformer_in_features[::-1]): 244 | x = features[f].float() # deformable detr does not support half precision 245 | srcs.append(self.input_proj[idx](x)) 246 | pos.append(self.pe_layer(x)) 247 | 248 | y, spatial_shapes, level_start_index = self.transformer(srcs, pos) 249 | bs = y.shape[0] 250 | 251 | split_size_or_sections = [None] * self.transformer_num_feature_levels 252 | for i in range(self.transformer_num_feature_levels): 253 | if i < self.transformer_num_feature_levels - 1: 254 | split_size_or_sections[i] = level_start_index[i + 1] - level_start_index[i] 255 | else: 256 | split_size_or_sections[i] = y.shape[1] - level_start_index[i] 257 | y = torch.split(y, split_size_or_sections, dim=1) 258 | 259 | out = [] 260 | num_cur_levels = 0 261 | for i, z in enumerate(y): 262 | out.append(z.transpose(1, 2).view(bs, -1, spatial_shapes[i][0], spatial_shapes[i][1])) 263 | 264 | # append `out` with extra FPN levels 265 | # Reverse feature maps into top-down order (from low to high resolution) 266 | for idx, f in enumerate(self.in_features[:self.num_fpn_levels][::-1]): 267 | x = features[f].float() 268 | lateral_conv = self.lateral_convs[idx] 269 | output_conv = self.output_convs[idx] 270 | cur_fpn = lateral_conv(x) 271 | # Following FPN implementation, we use nearest upsampling here 272 | y = cur_fpn + F.interpolate(out[-1], size=cur_fpn.shape[-2:], mode="bilinear", align_corners=False) 273 | y = output_conv(y) 274 | out.append(y) 275 | 276 | return out[0], self.mask_features(out[-1]) 277 | 278 | 279 | # MSDeformAttn Transformer encoder in deformable detr 280 | class MSDeformAttnTransformerEncoderOnly(nn.Module): 281 | def __init__(self, d_model=256, nhead=8, 282 | num_encoder_layers=6, dim_feedforward=1024, dropout=0.1, 283 | activation="relu", 284 | num_feature_levels=4, enc_n_points=4, 285 | ): 286 | 
super().__init__() 287 | 288 | self.d_model = d_model 289 | self.nhead = nhead 290 | 291 | encoder_layer = MSDeformAttnTransformerEncoderLayer(d_model, dim_feedforward, 292 | dropout, activation, 293 | num_feature_levels, nhead, enc_n_points) 294 | self.encoder = MSDeformAttnTransformerEncoder(encoder_layer, num_encoder_layers) 295 | 296 | self.level_embed = nn.Parameter(torch.Tensor(num_feature_levels, d_model)) 297 | 298 | self._reset_parameters() 299 | 300 | def _reset_parameters(self): 301 | for p in self.parameters(): 302 | if p.dim() > 1: 303 | nn.init.xavier_uniform_(p) 304 | for m in self.modules(): 305 | if isinstance(m, MSDeformAttn): 306 | m._reset_parameters() 307 | normal_(self.level_embed) 308 | 309 | def get_valid_ratio(self, mask): 310 | _, H, W = mask.shape 311 | valid_H = torch.sum(~mask[:, :, 0], 1) 312 | valid_W = torch.sum(~mask[:, 0, :], 1) 313 | valid_ratio_h = valid_H.float() / H 314 | valid_ratio_w = valid_W.float() / W 315 | valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) 316 | return valid_ratio 317 | 318 | def forward(self, srcs, pos_embeds): 319 | masks = [torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) for x in srcs] 320 | # prepare input for encoder 321 | src_flatten = [] 322 | mask_flatten = [] 323 | lvl_pos_embed_flatten = [] 324 | spatial_shapes = [] 325 | for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): 326 | bs, c, h, w = src.shape 327 | spatial_shape = (h, w) 328 | spatial_shapes.append(spatial_shape) 329 | src = src.flatten(2).transpose(1, 2) 330 | mask = mask.flatten(1) 331 | pos_embed = pos_embed.flatten(2).transpose(1, 2) 332 | lvl_pos_embed = pos_embed + self.level_embed[lvl].view(1, 1, -1) 333 | lvl_pos_embed_flatten.append(lvl_pos_embed) 334 | src_flatten.append(src) 335 | mask_flatten.append(mask) 336 | src_flatten = torch.cat(src_flatten, 1) 337 | mask_flatten = torch.cat(mask_flatten, 1) 338 | lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) 339 | spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=src_flatten.device) 340 | level_start_index = torch.cat((spatial_shapes.new_zeros((1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) 341 | valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) 342 | 343 | # encoder 344 | memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, lvl_pos_embed_flatten, mask_flatten) 345 | 346 | return memory, spatial_shapes, level_start_index 347 | 348 | 349 | class MSDeformAttnTransformerEncoderLayer(nn.Module): 350 | def __init__(self, 351 | d_model=256, d_ffn=1024, 352 | dropout=0.1, activation="relu", 353 | n_levels=4, n_heads=8, n_points=4): 354 | super().__init__() 355 | 356 | # self attention 357 | self.self_attn = MSDeformAttn(d_model, n_levels, n_heads, n_points) 358 | self.dropout1 = nn.Dropout(dropout) 359 | self.norm1 = nn.LayerNorm(d_model) 360 | 361 | # ffn 362 | self.linear1 = nn.Linear(d_model, d_ffn) 363 | self.activation = _get_activation_fn(activation) 364 | self.dropout2 = nn.Dropout(dropout) 365 | self.linear2 = nn.Linear(d_ffn, d_model) 366 | self.dropout3 = nn.Dropout(dropout) 367 | self.norm2 = nn.LayerNorm(d_model) 368 | 369 | @staticmethod 370 | def with_pos_embed(tensor, pos): 371 | return tensor if pos is None else tensor + pos 372 | 373 | def forward_ffn(self, src): 374 | src2 = self.linear2(self.dropout2(self.activation(self.linear1(src)))) 375 | src = src + self.dropout3(src2) 376 | src = self.norm2(src) 377 | return src 378 | 379 | def 
forward(self, src, pos, reference_points, spatial_shapes, level_start_index, padding_mask=None): 380 | # self attention 381 | src2 = self.self_attn(self.with_pos_embed(src, pos), reference_points, src, spatial_shapes, level_start_index, padding_mask) 382 | src = src + self.dropout1(src2) 383 | src = self.norm1(src) 384 | 385 | # ffn 386 | src = self.forward_ffn(src) 387 | 388 | return src 389 | 390 | 391 | class MSDeformAttnTransformerEncoder(nn.Module): 392 | def __init__(self, encoder_layer, num_layers): 393 | super().__init__() 394 | self.layers = _get_clones(encoder_layer, num_layers) 395 | self.num_layers = num_layers 396 | 397 | @staticmethod 398 | def get_reference_points(spatial_shapes, valid_ratios, device): 399 | reference_points_list = [] 400 | for lvl, (H_, W_) in enumerate(spatial_shapes): 401 | 402 | ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device), 403 | torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device)) 404 | ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * H_) 405 | ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * W_) 406 | ref = torch.stack((ref_x, ref_y), -1) 407 | reference_points_list.append(ref) 408 | reference_points = torch.cat(reference_points_list, 1) 409 | reference_points = reference_points[:, :, None] * valid_ratios[:, None] 410 | return reference_points 411 | 412 | def forward(self, src, spatial_shapes, level_start_index, valid_ratios, pos=None, padding_mask=None): 413 | output = src 414 | reference_points = self.get_reference_points(spatial_shapes, valid_ratios, device=src.device) 415 | for _, layer in enumerate(self.layers): 416 | output = layer(output, pos, reference_points, spatial_shapes, level_start_index, padding_mask) 417 | 418 | return output 419 | 420 | -------------------------------------------------------------------------------- /dcnet/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .PCD.pcd import pcd 2 | from .ICS.ics import ics -------------------------------------------------------------------------------- /dcnet/modeling/criterion.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/facebookresearch/detr/blob/master/models/detr.py 2 | import logging 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | from detectron2.utils.comm import get_world_size 9 | from detectron2.projects.point_rend.point_features import ( 10 | get_uncertain_point_coords_with_randomness, 11 | point_sample, 12 | ) 13 | 14 | from ..utils.misc import is_dist_avail_and_initialized, nested_tensor_from_tensor_list 15 | 16 | 17 | def dice_loss( 18 | inputs: torch.Tensor, 19 | targets: torch.Tensor, 20 | num_masks: float, 21 | ): 22 | """ 23 | Compute the DICE loss, similar to generalized IOU for masks 24 | Args: 25 | inputs: A float tensor of arbitrary shape. 26 | The predictions for each example. 27 | targets: A float tensor with the same shape as inputs. Stores the binary 28 | classification label for each element in inputs 29 | (0 for the negative class and 1 for the positive class). 
30 | """ 31 | inputs = inputs.sigmoid() 32 | inputs = inputs.flatten(1) 33 | numerator = 2 * (inputs * targets).sum(-1) 34 | denominator = inputs.sum(-1) + targets.sum(-1) 35 | loss = 1 - (numerator + 1) / (denominator + 1) 36 | return loss.sum() / num_masks 37 | 38 | 39 | dice_loss_jit = torch.jit.script( 40 | dice_loss 41 | ) # type: torch.jit.ScriptModule 42 | 43 | 44 | def sigmoid_ce_loss( 45 | inputs: torch.Tensor, 46 | targets: torch.Tensor, 47 | num_masks: float, 48 | ): 49 | """ 50 | Args: 51 | inputs: A float tensor of arbitrary shape. 52 | The predictions for each example. 53 | targets: A float tensor with the same shape as inputs. Stores the binary 54 | classification label for each element in inputs 55 | (0 for the negative class and 1 for the positive class). 56 | Returns: 57 | Loss tensor 58 | """ 59 | loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 60 | 61 | return loss.mean(1).sum() / num_masks 62 | 63 | 64 | sigmoid_ce_loss_jit = torch.jit.script( 65 | sigmoid_ce_loss 66 | ) # type: torch.jit.ScriptModule 67 | 68 | 69 | def calculate_uncertainty(logits): 70 | """ 71 | We estimate uncerainty as L1 distance between 0.0 and the logit prediction in 'logits' for the 72 | foreground class in `classes`. 73 | Args: 74 | logits (Tensor): A tensor of shape (R, 1, ...) for class-specific or 75 | class-agnostic, where R is the total number of predicted masks in all images and C is 76 | the number of foreground classes. The values are logits. 77 | Returns: 78 | scores (Tensor): A tensor of shape (R, 1, ...) that contains uncertainty scores with 79 | the most uncertain locations having the highest uncertainty score. 80 | """ 81 | assert logits.shape[1] == 1 82 | gt_class_logits = logits.clone() 83 | return -(torch.abs(gt_class_logits)) 84 | 85 | 86 | class SetCriterion(nn.Module): 87 | """This class computes the loss for DETR. 88 | The process happens in two steps: 89 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 90 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 91 | """ 92 | 93 | def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, 94 | num_points, oversample_ratio, importance_sample_ratio): 95 | """Create the criterion. 96 | Parameters: 97 | num_classes: number of object categories, omitting the special no-object category 98 | matcher: module able to compute a matching between targets and proposals 99 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 100 | eos_coef: relative classification weight applied to the no-object category 101 | losses: list of all the losses to be applied. See get_loss for list of available losses. 
102 | """ 103 | super().__init__() 104 | self.num_classes = num_classes 105 | self.matcher = matcher 106 | self.weight_dict = weight_dict 107 | self.eos_coef = eos_coef 108 | self.losses = losses 109 | empty_weight = torch.ones(self.num_classes + 1) 110 | empty_weight[-1] = self.eos_coef 111 | self.register_buffer("empty_weight", empty_weight) 112 | 113 | # pointwise mask loss parameters 114 | self.num_points = num_points 115 | self.oversample_ratio = oversample_ratio 116 | self.importance_sample_ratio = importance_sample_ratio 117 | 118 | def loss_labels(self, outputs, targets, indices, num_masks): 119 | """Classification loss (NLL) 120 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 121 | """ 122 | assert "pred_logits" in outputs 123 | src_logits = outputs["pred_logits"].float() 124 | 125 | idx = self._get_src_permutation_idx(indices) 126 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) 127 | target_classes = torch.full( 128 | src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device 129 | ) 130 | target_classes[idx] = target_classes_o 131 | 132 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) 133 | losses = {"loss_ce": loss_ce} 134 | return losses 135 | 136 | def loss_masks(self, outputs, targets, indices, num_masks): 137 | """Compute the losses related to the masks: the focal loss and the dice loss. 138 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 139 | """ 140 | assert "pred_masks" in outputs 141 | 142 | src_idx = self._get_src_permutation_idx(indices) 143 | tgt_idx = self._get_tgt_permutation_idx(indices) 144 | src_masks = outputs["pred_masks"] 145 | src_masks = src_masks[src_idx] 146 | masks = [t["masks"] for t in targets] 147 | # TODO use valid to mask invalid areas due to padding in loss 148 | target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() 149 | target_masks = target_masks.to(src_masks) 150 | target_masks = target_masks[tgt_idx] 151 | 152 | # No need to upsample predictions as we are using normalized coordinates :) 153 | # N x 1 x H x W 154 | src_masks = src_masks[:, None] 155 | target_masks = target_masks[:, None] 156 | 157 | with torch.no_grad(): 158 | # sample point_coords 159 | point_coords = get_uncertain_point_coords_with_randomness( 160 | src_masks, 161 | lambda logits: calculate_uncertainty(logits), 162 | self.num_points, 163 | self.oversample_ratio, 164 | self.importance_sample_ratio, 165 | ) 166 | # get gt labels 167 | point_labels = point_sample( 168 | target_masks, 169 | point_coords, 170 | align_corners=False, 171 | ).squeeze(1) 172 | 173 | point_logits = point_sample( 174 | src_masks, 175 | point_coords, 176 | align_corners=False, 177 | ).squeeze(1) 178 | 179 | losses = { 180 | "loss_mask": sigmoid_ce_loss_jit(point_logits, point_labels, num_masks), 181 | "loss_dice": dice_loss_jit(point_logits, point_labels, num_masks), 182 | } 183 | 184 | del src_masks 185 | del target_masks 186 | return losses 187 | 188 | def _get_src_permutation_idx(self, indices): 189 | # permute predictions following indices 190 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 191 | src_idx = torch.cat([src for (src, _) in indices]) 192 | return batch_idx, src_idx 193 | 194 | def _get_tgt_permutation_idx(self, indices): 195 | # permute targets following indices 196 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in 
enumerate(indices)]) 197 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 198 | return batch_idx, tgt_idx 199 | 200 | def get_loss(self, loss, outputs, targets, indices, num_masks): 201 | loss_map = { 202 | 'labels': self.loss_labels, 203 | 'masks': self.loss_masks, 204 | } 205 | assert loss in loss_map, f"do you really want to compute {loss} loss?" 206 | return loss_map[loss](outputs, targets, indices, num_masks) 207 | 208 | def forward(self, outputs, targets): 209 | """This performs the loss computation. 210 | Parameters: 211 | outputs: dict of tensors, see the output specification of the model for the format 212 | targets: list of dicts, such that len(targets) == batch_size. 213 | The expected keys in each dict depend on the losses applied, see each loss's doc 214 | """ 215 | outputs_without_aux = {k: v for k, v in outputs.items() if k != "aux_outputs"} 216 | 217 | # Retrieve the matching between the outputs of the last layer and the targets 218 | indices = self.matcher(outputs_without_aux, targets) 219 | 220 | # Compute the average number of target boxes across all nodes, for normalization purposes 221 | num_masks = sum(len(t["labels"]) for t in targets) 222 | num_masks = torch.as_tensor( 223 | [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device 224 | ) 225 | if is_dist_avail_and_initialized(): 226 | torch.distributed.all_reduce(num_masks) 227 | num_masks = torch.clamp(num_masks / get_world_size(), min=1).item() 228 | 229 | # Compute all the requested losses 230 | losses = {} 231 | for loss in self.losses: 232 | losses.update(self.get_loss(loss, outputs, targets, indices, num_masks)) 233 | 234 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 235 | if "aux_outputs" in outputs: 236 | for i, aux_outputs in enumerate(outputs["aux_outputs"]): 237 | indices = self.matcher(aux_outputs, targets) 238 | for loss in self.losses: 239 | l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_masks) 240 | l_dict = {k + f"_{i}": v for k, v in l_dict.items()} 241 | losses.update(l_dict) 242 | 243 | return losses 244 | 245 | def __repr__(self): 246 | head = "Criterion " + self.__class__.__name__ 247 | body = [ 248 | "matcher: {}".format(self.matcher.__repr__(_repr_indent=8)), 249 | "losses: {}".format(self.losses), 250 | "weight_dict: {}".format(self.weight_dict), 251 | "num_classes: {}".format(self.num_classes), 252 | "eos_coef: {}".format(self.eos_coef), 253 | "num_points: {}".format(self.num_points), 254 | "oversample_ratio: {}".format(self.oversample_ratio), 255 | "importance_sample_ratio: {}".format(self.importance_sample_ratio), 256 | ] 257 | _repr_indent = 4 258 | lines = [head] + [" " * _repr_indent + line for line in body] 259 | return "\n".join(lines) 260 | -------------------------------------------------------------------------------- /dcnet/modeling/matcher.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/facebookresearch/detr/blob/master/models/matcher.py 2 | """ 3 | Modules to compute the matching cost and solve the corresponding LSAP. 
4 | """ 5 | import torch 6 | import torch.nn.functional as F 7 | from scipy.optimize import linear_sum_assignment 8 | from torch import nn 9 | from torch.cuda.amp import autocast 10 | 11 | from detectron2.projects.point_rend.point_features import point_sample 12 | 13 | 14 | def batch_dice_loss(inputs: torch.Tensor, targets: torch.Tensor): 15 | """ 16 | Compute the DICE loss, similar to generalized IOU for masks 17 | Args: 18 | inputs: A float tensor of arbitrary shape. 19 | The predictions for each example. 20 | targets: A float tensor with the same shape as inputs. Stores the binary 21 | classification label for each element in inputs 22 | (0 for the negative class and 1 for the positive class). 23 | """ 24 | inputs = inputs.sigmoid() 25 | inputs = inputs.flatten(1) 26 | numerator = 2 * torch.einsum("nc,mc->nm", inputs, targets) 27 | denominator = inputs.sum(-1)[:, None] + targets.sum(-1)[None, :] 28 | loss = 1 - (numerator + 1) / (denominator + 1) 29 | return loss 30 | 31 | 32 | batch_dice_loss_jit = torch.jit.script( 33 | batch_dice_loss 34 | ) # type: torch.jit.ScriptModule 35 | 36 | 37 | def batch_sigmoid_ce_loss(inputs: torch.Tensor, targets: torch.Tensor): 38 | """ 39 | Args: 40 | inputs: A float tensor of arbitrary shape. 41 | The predictions for each example. 42 | targets: A float tensor with the same shape as inputs. Stores the binary 43 | classification label for each element in inputs 44 | (0 for the negative class and 1 for the positive class). 45 | Returns: 46 | Loss tensor 47 | """ 48 | hw = inputs.shape[1] 49 | 50 | pos = F.binary_cross_entropy_with_logits( 51 | inputs, torch.ones_like(inputs), reduction="none" 52 | ) 53 | neg = F.binary_cross_entropy_with_logits( 54 | inputs, torch.zeros_like(inputs), reduction="none" 55 | ) 56 | 57 | loss = torch.einsum("nc,mc->nm", pos, targets) + torch.einsum( 58 | "nc,mc->nm", neg, (1 - targets) 59 | ) 60 | 61 | return loss / hw 62 | 63 | 64 | batch_sigmoid_ce_loss_jit = torch.jit.script( 65 | batch_sigmoid_ce_loss 66 | ) # type: torch.jit.ScriptModule 67 | 68 | 69 | class HungarianMatcher(nn.Module): 70 | """This class computes an assignment between the targets and the predictions of the network 71 | 72 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 73 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 74 | while the others are un-matched (and thus treated as non-objects). 
75 | """ 76 | 77 | def __init__(self, cost_class: float = 1, cost_mask: float = 1, cost_dice: float = 1, num_points: int = 0): 78 | """Creates the matcher 79 | 80 | Params: 81 | cost_class: This is the relative weight of the classification error in the matching cost 82 | cost_mask: This is the relative weight of the focal loss of the binary mask in the matching cost 83 | cost_dice: This is the relative weight of the dice loss of the binary mask in the matching cost 84 | """ 85 | super().__init__() 86 | self.cost_class = cost_class 87 | self.cost_mask = cost_mask 88 | self.cost_dice = cost_dice 89 | 90 | assert cost_class != 0 or cost_mask != 0 or cost_dice != 0, "all costs cant be 0" 91 | 92 | self.num_points = num_points 93 | 94 | @torch.no_grad() 95 | def memory_efficient_forward(self, outputs, targets): 96 | """More memory-friendly matching""" 97 | bs, num_queries = outputs["pred_logits"].shape[:2] 98 | 99 | indices = [] 100 | 101 | # Iterate through batch size 102 | for b in range(bs): 103 | 104 | out_prob = outputs["pred_logits"][b].softmax(-1) # [num_queries, num_classes] 105 | tgt_ids = targets[b]["labels"] 106 | 107 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 108 | # but approximate it in 1 - proba[target class]. 109 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 110 | cost_class = -out_prob[:, tgt_ids] 111 | 112 | out_mask = outputs["pred_masks"][b] # [num_queries, H_pred, W_pred] 113 | # gt masks are already padded when preparing target 114 | tgt_mask = targets[b]["masks"].to(out_mask) 115 | 116 | out_mask = out_mask[:, None] 117 | tgt_mask = tgt_mask[:, None] 118 | # all masks share the same set of points for efficient matching! 119 | point_coords = torch.rand(1, self.num_points, 2, device=out_mask.device) 120 | # get gt labels 121 | tgt_mask = point_sample( 122 | tgt_mask, 123 | point_coords.repeat(tgt_mask.shape[0], 1, 1), 124 | align_corners=False, 125 | ).squeeze(1) 126 | 127 | out_mask = point_sample( 128 | out_mask, 129 | point_coords.repeat(out_mask.shape[0], 1, 1), 130 | align_corners=False, 131 | ).squeeze(1) 132 | 133 | with autocast(enabled=False): 134 | out_mask = out_mask.float() 135 | tgt_mask = tgt_mask.float() 136 | # Compute the focal loss between masks 137 | cost_mask = batch_sigmoid_ce_loss_jit(out_mask, tgt_mask) 138 | 139 | # Compute the dice loss betwen masks 140 | # FIX: batch_dice_loss_jit 141 | cost_dice = batch_dice_loss(out_mask, tgt_mask) 142 | 143 | # Final cost matrix 144 | C = ( 145 | self.cost_mask * cost_mask 146 | + self.cost_class * cost_class 147 | + self.cost_dice * cost_dice 148 | ) 149 | C = C.reshape(num_queries, -1).cpu() 150 | 151 | indices.append(linear_sum_assignment(C)) 152 | # ValueError: matrix contains invalid numeric entries 153 | # reason: the learning rete is too large to diverging 154 | 155 | return [ 156 | (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) 157 | for i, j in indices 158 | ] 159 | 160 | @torch.no_grad() 161 | def forward(self, outputs, targets): 162 | """Performs the matching 163 | 164 | Params: 165 | outputs: This is a dict that contains at least these entries: 166 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 167 | "pred_masks": Tensor of dim [batch_size, num_queries, H_pred, W_pred] with the predicted masks 168 | 169 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 170 | "labels": Tensor of dim 
[num_target_boxes] (where num_target_boxes is the number of ground-truth 171 | objects in the target) containing the class labels 172 | "masks": Tensor of dim [num_target_boxes, H_gt, W_gt] containing the target masks 173 | 174 | Returns: 175 | A list of size batch_size, containing tuples of (index_i, index_j) where: 176 | - index_i is the indices of the selected predictions (in order) 177 | - index_j is the indices of the corresponding selected targets (in order) 178 | For each batch element, it holds: 179 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 180 | """ 181 | return self.memory_efficient_forward(outputs, targets) 182 | 183 | def __repr__(self, _repr_indent=4): 184 | head = "Matcher " + self.__class__.__name__ 185 | body = [ 186 | "cost_class: {}".format(self.cost_class), 187 | "cost_mask: {}".format(self.cost_mask), 188 | "cost_dice: {}".format(self.cost_dice), 189 | ] 190 | lines = [head] + [" " * _repr_indent + line for line in body] 191 | return "\n".join(lines) 192 | -------------------------------------------------------------------------------- /dcnet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTCL/DCNet/f3c9098d1e0696cae8a5cfe59f952487d089ac4c/dcnet/utils/__init__.py -------------------------------------------------------------------------------- /dcnet/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/facebookresearch/detr/blob/master/util/misc.py 2 | """ 3 | Misc functions, including distributed helpers. 4 | 5 | Mostly copy-paste from torchvision references. 6 | """ 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.distributed as dist 11 | import torchvision 12 | from torch import Tensor 13 | 14 | 15 | def _max_by_axis(the_list): 16 | # type: (List[List[int]]) -> List[int] 17 | maxes = the_list[0] 18 | for sublist in the_list[1:]: 19 | for index, item in enumerate(sublist): 20 | maxes[index] = max(maxes[index], item) 21 | return maxes 22 | 23 | 24 | class NestedTensor(object): 25 | def __init__(self, tensors, mask: Optional[Tensor]): 26 | self.tensors = tensors 27 | self.mask = mask 28 | 29 | def to(self, device): 30 | # type: (Device) -> NestedTensor # noqa 31 | cast_tensor = self.tensors.to(device) 32 | mask = self.mask 33 | if mask is not None: 34 | assert mask is not None 35 | cast_mask = mask.to(device) 36 | else: 37 | cast_mask = None 38 | return NestedTensor(cast_tensor, cast_mask) 39 | 40 | def decompose(self): 41 | return self.tensors, self.mask 42 | 43 | def __repr__(self): 44 | return str(self.tensors) 45 | 46 | 47 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 48 | # TODO make this more general 49 | if tensor_list[0].ndim == 3: 50 | if torchvision._is_tracing(): 51 | # nested_tensor_from_tensor_list() does not export well to ONNX 52 | # call _onnx_nested_tensor_from_tensor_list() instead 53 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 54 | 55 | # TODO make it support different-sized images 56 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 57 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 58 | batch_shape = [len(tensor_list)] + max_size 59 | b, c, h, w = batch_shape 60 | dtype = tensor_list[0].dtype 61 | device = tensor_list[0].device 62 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 63 | mask = torch.ones((b, h, w), dtype=torch.bool, 
device=device) 64 | for img, pad_img, m in zip(tensor_list, tensor, mask): 65 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 66 | m[: img.shape[1], : img.shape[2]] = False 67 | else: 68 | raise ValueError("not supported") 69 | return NestedTensor(tensor, mask) 70 | 71 | 72 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 73 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 74 | @torch.jit.unused 75 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 76 | max_size = [] 77 | for i in range(tensor_list[0].dim()): 78 | max_size_i = torch.max( 79 | torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32) 80 | ).to(torch.int64) 81 | max_size.append(max_size_i) 82 | max_size = tuple(max_size) 83 | 84 | # work around for 85 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 86 | # m[: img.shape[1], :img.shape[2]] = False 87 | # which is not yet supported in onnx 88 | padded_imgs = [] 89 | padded_masks = [] 90 | for img in tensor_list: 91 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 92 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 93 | padded_imgs.append(padded_img) 94 | 95 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 96 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 97 | padded_masks.append(padded_mask.to(torch.bool)) 98 | 99 | tensor = torch.stack(padded_imgs) 100 | mask = torch.stack(padded_masks) 101 | 102 | return NestedTensor(tensor, mask=mask) 103 | 104 | 105 | def is_dist_avail_and_initialized(): 106 | if not dist.is_available(): 107 | return False 108 | if not dist.is_initialized(): 109 | return False 110 | return True 111 | -------------------------------------------------------------------------------- /framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/USTCL/DCNet/f3c9098d1e0696cae8a5cfe59f952487d089ac4c/framework.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | scipy 3 | shapely 4 | timm 5 | h5py 6 | submitit 7 | scikit-image 8 | torchvision 9 | opencv-python -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | import os 2 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0' 3 | 4 | import copy 5 | import itertools 6 | import logging 7 | 8 | from collections import OrderedDict 9 | from typing import Any, Dict, List, Set 10 | 11 | import torch 12 | 13 | import detectron2.utils.comm as comm 14 | from detectron2.checkpoint import DetectionCheckpointer 15 | from detectron2.config import get_cfg 16 | from detectron2.data import MetadataCatalog, build_detection_train_loader, build_detection_test_loader 17 | from detectron2.engine import ( 18 | DefaultTrainer, 19 | default_argument_parser, 20 | default_setup, 21 | launch, 22 | ) 23 | from detectron2.evaluation import ( 24 | COCOEvaluator, 25 | DatasetEvaluators, 26 | verify_results, 27 | ) 28 | from detectron2.projects.deeplab import add_deeplab_config, build_lr_scheduler 29 | from detectron2.solver.build import maybe_add_gradient_clipping 30 | from detectron2.utils.logger import setup_logger 31 | 32 | from dcnet import ( 33 | 34 | 
add_dcnet_config, 35 | register_dataset, 36 | DatasetMapper_Fourier_amplitude, 37 | ) 38 | 39 | 40 | class Trainer(DefaultTrainer): 41 | """ 42 | Extension of the Trainer class adapted to DCNet. 43 | """ 44 | 45 | @classmethod 46 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 47 | """ 48 | Create evaluator(s) for a given dataset. 49 | This uses the special metadata "evaluator_type" associated with each 50 | builtin dataset. For your own dataset, you can simply create an 51 | evaluator manually in your script and do not have to worry about the 52 | hacky if-else logic here. 53 | """ 54 | if output_folder is None: 55 | # inference output dir 56 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference", dataset_name) 57 | evaluator_list = [] 58 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 59 | 60 | # instance segmentation 61 | if evaluator_type == "coco": 62 | evaluator_list.append(COCOEvaluator(dataset_name, output_dir=output_folder)) 63 | 64 | if len(evaluator_list) == 0: 65 | raise NotImplementedError( 66 | "no Evaluator for the dataset {} with the type {}".format( 67 | dataset_name, evaluator_type 68 | ) 69 | ) 70 | elif len(evaluator_list) == 1: 71 | return evaluator_list[0] 72 | return DatasetEvaluators(evaluator_list) 73 | 74 | @classmethod 75 | def build_train_loader(cls, cfg): 76 | mapper = DatasetMapper_Fourier_amplitude(cfg, True) 77 | return build_detection_train_loader(cfg, mapper=mapper) 78 | 79 | 80 | @classmethod 81 | def build_test_loader(cls, cfg, dataset_name): 82 | mapper = DatasetMapper_Fourier_amplitude(cfg, False) 83 | return build_detection_test_loader(cfg, dataset_name, mapper=mapper) 84 | 85 | @classmethod 86 | def build_lr_scheduler(cls, cfg, optimizer): 87 | """ 88 | It now calls :func:`detectron2.solver.build_lr_scheduler`. 89 | Overwrite it if you'd like a different scheduler. 
90 | """ 91 | return build_lr_scheduler(cfg, optimizer) 92 | 93 | @classmethod 94 | def build_optimizer(cls, cfg, model): 95 | weight_decay_norm = 0 96 | weight_decay_embed = 0 97 | 98 | defaults = {} 99 | defaults["lr"] = cfg.SOLVER.BASE_LR 100 | defaults["weight_decay"] = 0.05 101 | 102 | norm_module_types = ( 103 | torch.nn.BatchNorm1d, 104 | torch.nn.BatchNorm2d, 105 | torch.nn.BatchNorm3d, 106 | torch.nn.SyncBatchNorm, 107 | # NaiveSyncBatchNorm inherits from BatchNorm2d 108 | torch.nn.GroupNorm, 109 | torch.nn.InstanceNorm1d, 110 | torch.nn.InstanceNorm2d, 111 | torch.nn.InstanceNorm3d, 112 | torch.nn.LayerNorm, 113 | torch.nn.LocalResponseNorm, 114 | ) 115 | 116 | params: List[Dict[str, Any]] = [] 117 | memo: Set[torch.nn.parameter.Parameter] = set() 118 | for module_name, module in model.named_modules(): 119 | for module_param_name, value in module.named_parameters(recurse=False): 120 | if not value.requires_grad: 121 | # print('requires_grad is False:', module_param_name) 122 | continue 123 | # Avoid duplicating parameters 124 | if value in memo: 125 | continue 126 | memo.add(value) 127 | 128 | hyperparams = copy.copy(defaults) 129 | if "backbone" in module_name: 130 | hyperparams["lr"] = hyperparams["lr"] * cfg.SOLVER.BACKBONE_MULTIPLIER 131 | # print(module_name, '+', module_param_name) 132 | if ( 133 | "relative_position_bias_table" in module_param_name 134 | or "absolute_pos_embed" in module_param_name 135 | ): 136 | print(module_param_name) 137 | hyperparams["weight_decay"] = 0.0 138 | if isinstance(module, norm_module_types): 139 | hyperparams["weight_decay"] = weight_decay_norm 140 | if isinstance(module, torch.nn.Embedding): 141 | hyperparams["weight_decay"] = weight_decay_embed 142 | params.append({"params": [value], **hyperparams}) 143 | # print('add to optimizer:', module_name, module_param_name) 144 | 145 | def maybe_add_full_model_gradient_clipping(optim): 146 | # detectron2 doesn't have full model gradient clipping now 147 | clip_norm_val = 0.01 148 | 149 | class FullModelGradientClippingOptimizer(optim): 150 | def step(self, closure=None): 151 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 152 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 153 | super().step(closure=closure) 154 | 155 | return FullModelGradientClippingOptimizer 156 | 157 | optimizer_type = cfg.SOLVER.OPTIMIZER 158 | if optimizer_type == "SGD": 159 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( 160 | params, cfg.SOLVER.BASE_LR, momentum=0.9 161 | ) 162 | elif optimizer_type == "ADAMW": 163 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 164 | params, cfg.SOLVER.BASE_LR 165 | ) 166 | else: 167 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 168 | 169 | return optimizer 170 | 171 | 172 | def setup(args): 173 | """ 174 | Create configs and perform basic setups. 
175 | """ 176 | cfg = get_cfg() 177 | # for poly lr schedule 178 | add_deeplab_config(cfg) 179 | add_dcnet_config(cfg) 180 | cfg.merge_from_file(args.config_file) 181 | cfg.merge_from_list(args.opts) 182 | cfg.freeze() 183 | default_setup(cfg, args) 184 | # Setup logger for "dcnet" module 185 | setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="dcnet") 186 | return cfg 187 | 188 | 189 | def main(args): 190 | cfg = setup(args) 191 | 192 | register_dataset() 193 | 194 | if args.eval_only: 195 | model = Trainer.build_model(cfg) 196 | # print('Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0)) 197 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 198 | cfg.MODEL.WEIGHTS, resume=args.resume 199 | ) 200 | res = Trainer.test(cfg, model) 201 | if comm.is_main_process(): 202 | verify_results(cfg, res) 203 | return res 204 | 205 | trainer = Trainer(cfg) 206 | trainer.resume_or_load(resume=args.resume) 207 | return trainer.train() 208 | 209 | 210 | if __name__ == "__main__": 211 | args = default_argument_parser().parse_args() 212 | print("Command Line Args:", args) 213 | launch( 214 | main, 215 | args.num_gpus, 216 | num_machines=args.num_machines, 217 | machine_rank=args.machine_rank, 218 | dist_url=args.dist_url, 219 | args=(args,), 220 | ) 221 | --------------------------------------------------------------------------------