├── .github ├── attention-maps.png ├── conditional-detr.png ├── convergence-curve.png └── ISSUE_TEMPLATE │ ├── bugs.md │ ├── questions-help-support.md │ └── unexpected-problems-bugs.md ├── requirements.txt ├── scripts ├── conddetr_r50_epoch50.sh ├── conddetr_r101_epoch50.sh ├── conddetr_r50dc5_epoch50.sh └── conddetr_r101dc5_epoch50.sh ├── .gitignore ├── util ├── __init__.py ├── box_ops.py └── misc.py ├── models ├── __init__.py ├── position_encoding.py ├── matcher.py ├── backbone.py ├── segmentation.py ├── transformer.py ├── conditional_detr.py └── attention.py ├── datasets ├── __init__.py ├── panoptic_eval.py ├── coco_panoptic.py ├── coco.py ├── transforms.py └── coco_eval.py ├── hubconf.py ├── engine.py ├── README.md ├── LICENSE └── main.py /.github/attention-maps.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atten4Vis/ConditionalDETR/HEAD/.github/attention-maps.png -------------------------------------------------------------------------------- /.github/conditional-detr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atten4Vis/ConditionalDETR/HEAD/.github/conditional-detr.png -------------------------------------------------------------------------------- /.github/convergence-curve.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Atten4Vis/ConditionalDETR/HEAD/.github/convergence-curve.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI&egg=pycocotools 3 | torch>=1.7.0 4 | torchvision>=0.6.0 5 | git+https://github.com/cocodataset/panopticapi.git#egg=panopticapi 6 | scipy 7 | termcolor 8 | -------------------------------------------------------------------------------- /scripts/conddetr_r50_epoch50.sh: -------------------------------------------------------------------------------- 1 | script_name1=`basename $0` 2 | script_name=${script_name1:0:${#script_name1}-3} 3 | 4 | python -m torch.distributed.launch \ 5 | --nproc_per_node=8 \ 6 | --use_env \ 7 | main.py \ 8 | --coco_path ../data/coco \ 9 | --output_dir output/$script_name -------------------------------------------------------------------------------- /scripts/conddetr_r101_epoch50.sh: -------------------------------------------------------------------------------- 1 | script_name1=`basename $0` 2 | script_name=${script_name1:0:${#script_name1}-3} 3 | 4 | python -m torch.distributed.launch \ 5 | --nproc_per_node=8 \ 6 | --use_env \ 7 | main.py \ 8 | --backbone resnet101 \ 9 | --coco_path ../data/coco \ 10 | --output_dir output/$script_name -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .nfs* 2 | *.ipynb 3 | *.pyc 4 | .dumbo.json 5 | .DS_Store 6 | .*.swp 7 | *.pth 8 | **/__pycache__/** 9 | .ipynb_checkpoints/ 10 | datasets/data/ 11 | experiment-* 12 | *.tmp 13 | *.pkl 14 | **/.mypy_cache/* 15 | .mypy_cache/* 16 | not_tracked_dir/ 17 | .vscode 18 | output/ 19 | imgs/ 20 | figs/ 21 | logs/ 22 | plot*.py 23 | -------------------------------------------------------------------------------- /scripts/conddetr_r50dc5_epoch50.sh: 
-------------------------------------------------------------------------------- 1 | script_name1=`basename $0` 2 | script_name=${script_name1:0:${#script_name1}-3} 3 | 4 | python -m torch.distributed.launch \ 5 | --nproc_per_node=8 \ 6 | --use_env \ 7 | main.py \ 8 | --batch_size 1 \ 9 | --dilation \ 10 | --coco_path ../data/coco \ 11 | --output_dir output/$script_name -------------------------------------------------------------------------------- /scripts/conddetr_r101dc5_epoch50.sh: -------------------------------------------------------------------------------- 1 | script_name1=`basename $0` 2 | script_name=${script_name1:0:${#script_name1}-3} 3 | 4 | python -m torch.distributed.launch \ 5 | --nproc_per_node=8 \ 6 | --use_env \ 7 | main.py \ 8 | --backbone resnet101 \ 9 | --batch_size 1 \ 10 | --dilation \ 11 | --coco_path ../data/coco \ 12 | --output_dir output/$script_name -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | from .conditional_detr import build 11 | 12 | 13 | def build_model(args): 14 | return build(args) 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bugs.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "🐛 Bugs" 3 | about: Report bugs in DETR 4 | title: Please read & provide the following 5 | 6 | --- 7 | 8 | ## Instructions To Reproduce the 🐛 Bug: 9 | 10 | 1. what changes you made (`git diff`) or what code you wrote 11 | ``` 12 | 13 | ``` 14 | 2. what exact command you run: 15 | 3. what you observed (including __full logs__): 16 | ``` 17 | 18 | ``` 19 | 4. please simplify the steps as much as possible so they do not require additional resources to 20 | run, such as a private dataset. 21 | 22 | ## Expected behavior: 23 | 24 | If there are no obvious error in "what you observed" provided above, 25 | please tell us the expected behavior. 
26 | 27 | ## Environment: 28 | 29 | Provide your environment information using the following command: 30 | ``` 31 | python -m torch.utils.collect_env 32 | ``` 33 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions-help-support.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "How to do something❓" 3 | about: How to do something using DETR? 4 | 5 | --- 6 | 7 | ## ❓ How to do something using DETR 8 | 9 | Describe what you want to do, including: 10 | 1. what inputs you will provide, if any: 11 | 2. what outputs you are expecting: 12 | 13 | 14 | NOTE: 15 | 16 | 1. Only general answers are provided. 17 | If you want to ask about "why X did not work", please use the 18 | [Unexpected behaviors](https://github.com/facebookresearch/detr/issues/new/choose) issue template. 19 | 20 | 2. About how to implement new models / new dataloader / new training logic, etc., check documentation first. 21 | 22 | 3. We do not answer general machine learning / computer vision questions that are not specific to DETR, such as how a model works, how to improve your training/make it converge, or what algorithm/methods can be used to achieve X. 23 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/unexpected-problems-bugs.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: "Unexpected behaviors" 3 | about: Run into unexpected behaviors when using DETR 4 | title: Please read & provide the following 5 | 6 | --- 7 | 8 | If you do not know the root cause of the problem, and wish someone to help you, please 9 | post according to this template: 10 | 11 | ## Instructions To Reproduce the Issue: 12 | 13 | 1. what changes you made (`git diff`) or what code you wrote 14 | ``` 15 | 16 | ``` 17 | 2. what exact command you run: 18 | 3. what you observed (including __full logs__): 19 | ``` 20 | 21 | ``` 22 | 4. please simplify the steps as much as possible so they do not require additional resources to 23 | run, such as a private dataset. 24 | 25 | ## Expected behavior: 26 | 27 | If there are no obvious error in "what you observed" provided above, 28 | please tell us the expected behavior. 29 | 30 | If you expect the model to converge / work better, note that we do not give suggestions 31 | on how to train a new model. 32 | Only in one of the two conditions we will help with it: 33 | (1) You're unable to reproduce the results in DETR model zoo. 34 | (2) It indicates a DETR bug. 35 | 36 | ## Environment: 37 | 38 | Provide your environment information using the following command: 39 | ``` 40 | python -m torch.utils.collect_env 41 | ``` 42 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
8 | # ------------------------------------------------------------------------ 9 | 10 | import torch.utils.data 11 | import torchvision 12 | 13 | from .coco import build as build_coco 14 | 15 | 16 | def get_coco_api_from_dataset(dataset): 17 | for _ in range(10): 18 | # if isinstance(dataset, torchvision.datasets.CocoDetection): 19 | # break 20 | if isinstance(dataset, torch.utils.data.Subset): 21 | dataset = dataset.dataset 22 | if isinstance(dataset, torchvision.datasets.CocoDetection): 23 | return dataset.coco 24 | 25 | 26 | def build_dataset(image_set, args): 27 | if args.dataset_file == 'coco': 28 | return build_coco(image_set, args) 29 | if args.dataset_file == 'coco_panoptic': 30 | # to avoid making panopticapi required for coco 31 | from .coco_panoptic import build as build_coco_panoptic 32 | return build_coco_panoptic(image_set, args) 33 | raise ValueError(f'dataset {args.dataset_file} not supported') 34 | -------------------------------------------------------------------------------- /datasets/panoptic_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | import json 11 | import os 12 | 13 | import util.misc as utils 14 | 15 | try: 16 | from panopticapi.evaluation import pq_compute 17 | except ImportError: 18 | pass 19 | 20 | 21 | class PanopticEvaluator(object): 22 | def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): 23 | self.gt_json = ann_file 24 | self.gt_folder = ann_folder 25 | if utils.is_main_process(): 26 | if not os.path.exists(output_dir): 27 | os.mkdir(output_dir) 28 | self.output_dir = output_dir 29 | self.predictions = [] 30 | 31 | def update(self, predictions): 32 | for p in predictions: 33 | with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: 34 | f.write(p.pop("png_string")) 35 | 36 | self.predictions += predictions 37 | 38 | def synchronize_between_processes(self): 39 | all_predictions = utils.all_gather(self.predictions) 40 | merged_predictions = [] 41 | for p in all_predictions: 42 | merged_predictions += p 43 | self.predictions = merged_predictions 44 | 45 | def summarize(self): 46 | if utils.is_main_process(): 47 | json_data = {"annotations": self.predictions} 48 | predictions_json = os.path.join(self.output_dir, "predictions.json") 49 | with open(predictions_json, "w") as f: 50 | f.write(json.dumps(json_data)) 51 | return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir) 52 | return None 53 | -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Utilities for bounding box manipulation and GIoU. 12 | """ 13 | import torch 14 | from torchvision.ops.boxes import box_area 15 | 16 | 17 | def box_cxcywh_to_xyxy(x): 18 | x_c, y_c, w, h = x.unbind(-1) 19 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 20 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 21 | return torch.stack(b, dim=-1) 22 | 23 | 24 | def box_xyxy_to_cxcywh(x): 25 | x0, y0, x1, y1 = x.unbind(-1) 26 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 27 | (x1 - x0), (y1 - y0)] 28 | return torch.stack(b, dim=-1) 29 | 30 | 31 | # modified from torchvision to also return the union 32 | def box_iou(boxes1, boxes2): 33 | area1 = box_area(boxes1) 34 | area2 = box_area(boxes2) 35 | 36 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 37 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 38 | 39 | wh = (rb - lt).clamp(min=0) # [N,M,2] 40 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 41 | 42 | union = area1[:, None] + area2 - inter 43 | 44 | iou = inter / union 45 | return iou, union 46 | 47 | 48 | def generalized_box_iou(boxes1, boxes2): 49 | """ 50 | Generalized IoU from https://giou.stanford.edu/ 51 | 52 | The boxes should be in [x0, y0, x1, y1] format 53 | 54 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 55 | and M = len(boxes2) 56 | """ 57 | # degenerate boxes gives inf / nan results 58 | # so do an early check 59 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 60 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 61 | iou, union = box_iou(boxes1, boxes2) 62 | 63 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 64 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 65 | 66 | wh = (rb - lt).clamp(min=0) # [N,M,2] 67 | area = wh[:, :, 0] * wh[:, :, 1] 68 | 69 | return iou - (area - union) / area 70 | 71 | 72 | def masks_to_boxes(masks): 73 | """Compute the bounding boxes around the provided masks 74 | 75 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 76 | 77 | Returns a [N, 4] tensors, with the boxes in xyxy format 78 | """ 79 | if masks.numel() == 0: 80 | return torch.zeros((0, 4), device=masks.device) 81 | 82 | h, w = masks.shape[-2:] 83 | 84 | y = torch.arange(0, h, dtype=torch.float) 85 | x = torch.arange(0, w, dtype=torch.float) 86 | y, x = torch.meshgrid(y, x) 87 | 88 | x_mask = (masks * x.unsqueeze(0)) 89 | x_max = x_mask.flatten(1).max(-1)[0] 90 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 91 | 92 | y_mask = (masks * y.unsqueeze(0)) 93 | y_max = y_mask.flatten(1).max(-1)[0] 94 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 95 | 96 | return torch.stack([x_min, y_min, x_max, y_max], 1) 97 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Various positional encodings for the transformer. 12 | """ 13 | import math 14 | import torch 15 | from torch import nn 16 | 17 | from util.misc import NestedTensor 18 | 19 | 20 | class PositionEmbeddingSine(nn.Module): 21 | """ 22 | This is a more standard version of the position embedding, very similar to the one 23 | used by the Attention is all you need paper, generalized to work on images. 24 | """ 25 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 26 | super().__init__() 27 | self.num_pos_feats = num_pos_feats 28 | self.temperature = temperature 29 | self.normalize = normalize 30 | if scale is not None and normalize is False: 31 | raise ValueError("normalize should be True if scale is passed") 32 | if scale is None: 33 | scale = 2 * math.pi 34 | self.scale = scale 35 | 36 | def forward(self, tensor_list: NestedTensor): 37 | x = tensor_list.tensors 38 | mask = tensor_list.mask 39 | assert mask is not None 40 | not_mask = ~mask 41 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 42 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 43 | if self.normalize: 44 | eps = 1e-6 45 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 46 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 47 | 48 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 49 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 50 | 51 | pos_x = x_embed[:, :, :, None] / dim_t 52 | pos_y = y_embed[:, :, :, None] / dim_t 53 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 54 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 55 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 56 | return pos 57 | 58 | 59 | class PositionEmbeddingLearned(nn.Module): 60 | """ 61 | Absolute pos embedding, learned. 
62 | """ 63 | def __init__(self, num_pos_feats=256): 64 | super().__init__() 65 | self.row_embed = nn.Embedding(50, num_pos_feats) 66 | self.col_embed = nn.Embedding(50, num_pos_feats) 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self): 70 | nn.init.uniform_(self.row_embed.weight) 71 | nn.init.uniform_(self.col_embed.weight) 72 | 73 | def forward(self, tensor_list: NestedTensor): 74 | x = tensor_list.tensors 75 | h, w = x.shape[-2:] 76 | i = torch.arange(w, device=x.device) 77 | j = torch.arange(h, device=x.device) 78 | x_emb = self.col_embed(i) 79 | y_emb = self.row_embed(j) 80 | pos = torch.cat([ 81 | x_emb.unsqueeze(0).repeat(h, 1, 1), 82 | y_emb.unsqueeze(1).repeat(1, w, 1), 83 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 84 | return pos 85 | 86 | 87 | def build_position_encoding(args): 88 | N_steps = args.hidden_dim // 2 89 | if args.position_embedding in ('v2', 'sine'): 90 | # TODO find a better way of exposing other arguments 91 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 92 | elif args.position_embedding in ('v3', 'learned'): 93 | position_embedding = PositionEmbeddingLearned(N_steps) 94 | else: 95 | raise ValueError(f"not supported {args.position_embedding}") 96 | 97 | return position_embedding 98 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import torch 3 | 4 | from models.backbone import Backbone, Joiner 5 | from models.conditional_detr import ConditionalDETR, PostProcess 6 | from models.position_encoding import PositionEmbeddingSine 7 | from models.segmentation import DETRsegm, PostProcessPanoptic 8 | from models.transformer import Transformer 9 | 10 | dependencies = ["torch", "torchvision"] 11 | 12 | 13 | def _make_conditional_detr(backbone_name: str, dilation=False, num_classes=91, mask=False): 14 | hidden_dim = 256 15 | backbone = Backbone(backbone_name, train_backbone=True, return_interm_layers=mask, dilation=dilation) 16 | pos_enc = PositionEmbeddingSine(hidden_dim // 2, normalize=True) 17 | backbone_with_pos_enc = Joiner(backbone, pos_enc) 18 | backbone_with_pos_enc.num_channels = backbone.num_channels 19 | transformer = Transformer(d_model=hidden_dim, return_intermediate_dec=True) 20 | detr = ConditionalDETR(backbone_with_pos_enc, transformer, num_classes=num_classes, num_queries=300) 21 | if mask: 22 | return DETRsegm(detr) 23 | return detr 24 | 25 | 26 | def conditional_detr_resnet50(pretrained=False, num_classes=91, return_postprocessor=False): 27 | """ 28 | ConditionalDETR R50 with 6 encoder and 6 decoder layers. 29 | 30 | Achieves 40.9 AP on COCO val5k. 31 | """ 32 | model = _make_conditional_detr("resnet50", dilation=False, num_classes=num_classes) 33 | if pretrained: 34 | checkpoint = torch.hub.load_state_dict_from_url( 35 | url="https://github.com/DeppMeng/ConditionalDETR/releases/download/v1.0/ConditionalDETR_r50_epoch50.pth", map_location="cpu", check_hash=True 36 | ) 37 | model.load_state_dict(checkpoint["model"]) 38 | if return_postprocessor: 39 | return model, PostProcess() 40 | return model 41 | 42 | 43 | def conditional_detr_resnet50_dc5(pretrained=False, num_classes=91, return_postprocessor=False): 44 | """ 45 | ConditionalDETR-DC5 R50 with 6 encoder and 6 decoder layers. 46 | 47 | The last block of RessNet-50 has dilation to increase 48 | output resolution. 49 | Achieves 43. 
AP on COCO val5k. 50 | """ 51 | model = _make_conditional_detr("resnet50", dilation=True, num_classes=num_classes) 52 | if pretrained: 53 | checkpoint = torch.hub.load_state_dict_from_url( 54 | url="https://github.com/DeppMeng/ConditionalDETR/releases/download/v1.0/ConditionalDETR_r50dc5_epoch50.pth", map_location="cpu", check_hash=True 55 | ) 56 | model.load_state_dict(checkpoint["model"]) 57 | if return_postprocessor: 58 | return model, PostProcess() 59 | return model 60 | 61 | 62 | def conditional_detr_resnet101(pretrained=False, num_classes=91, return_postprocessor=False): 63 | """ 64 | ConditionalDETR R101 with 6 encoder and 6 decoder layers. 65 | 66 | Achieves 42.8 AP on COCO val5k. 67 | """ 68 | model = _make_conditional_detr("resnet101", dilation=False, num_classes=num_classes) 69 | if pretrained: 70 | checkpoint = torch.hub.load_state_dict_from_url( 71 | url="https://github.com/DeppMeng/ConditionalDETR/releases/download/v1.0/ConditionalDETR_r101_epoch50.pth", map_location="cpu", check_hash=True 72 | ) 73 | model.load_state_dict(checkpoint["model"]) 74 | if return_postprocessor: 75 | return model, PostProcess() 76 | return model 77 | 78 | 79 | def conditional_detr_resnet101_dc5(pretrained=False, num_classes=91, return_postprocessor=False): 80 | """ 81 | ConditionalDETR-DC5 R101 with 6 encoder and 6 decoder layers. 82 | 83 | The last block of ResNet-101 has dilation to increase 84 | output resolution. 85 | Achieves 45.0 AP on COCO val5k. 86 | """ 87 | model = _make_conditional_detr("resnet101", dilation=True, num_classes=num_classes) 88 | if pretrained: 89 | checkpoint = torch.hub.load_state_dict_from_url( 90 | url="https://github.com/DeppMeng/ConditionalDETR/releases/download/v1.0/ConditionalDETR_r101dc5_epoch50.pth", map_location="cpu", check_hash=True 91 | ) 92 | model.load_state_dict(checkpoint["model"]) 93 | if return_postprocessor: 94 | return model, PostProcess() 95 | return model 96 | 97 | -------------------------------------------------------------------------------- /datasets/coco_panoptic.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
8 | # ------------------------------------------------------------------------ 9 | 10 | import json 11 | from pathlib import Path 12 | 13 | import numpy as np 14 | import torch 15 | from PIL import Image 16 | 17 | from panopticapi.utils import rgb2id 18 | from util.box_ops import masks_to_boxes 19 | 20 | from .coco import make_coco_transforms 21 | 22 | 23 | class CocoPanoptic: 24 | def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True): 25 | with open(ann_file, 'r') as f: 26 | self.coco = json.load(f) 27 | 28 | # sort 'images' field so that they are aligned with 'annotations' 29 | # i.e., in alphabetical order 30 | self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id']) 31 | # sanity check 32 | if "annotations" in self.coco: 33 | for img, ann in zip(self.coco['images'], self.coco['annotations']): 34 | assert img['file_name'][:-4] == ann['file_name'][:-4] 35 | 36 | self.img_folder = img_folder 37 | self.ann_folder = ann_folder 38 | self.ann_file = ann_file 39 | self.transforms = transforms 40 | self.return_masks = return_masks 41 | 42 | def __getitem__(self, idx): 43 | ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx] 44 | img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg') 45 | ann_path = Path(self.ann_folder) / ann_info['file_name'] 46 | 47 | img = Image.open(img_path).convert('RGB') 48 | w, h = img.size 49 | if "segments_info" in ann_info: 50 | masks = np.asarray(Image.open(ann_path), dtype=np.uint32) 51 | masks = rgb2id(masks) 52 | 53 | ids = np.array([ann['id'] for ann in ann_info['segments_info']]) 54 | masks = masks == ids[:, None, None] 55 | 56 | masks = torch.as_tensor(masks, dtype=torch.uint8) 57 | labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64) 58 | 59 | target = {} 60 | target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]]) 61 | if self.return_masks: 62 | target['masks'] = masks 63 | target['labels'] = labels 64 | 65 | target["boxes"] = masks_to_boxes(masks) 66 | 67 | target['size'] = torch.as_tensor([int(h), int(w)]) 68 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 69 | if "segments_info" in ann_info: 70 | for name in ['iscrowd', 'area']: 71 | target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']]) 72 | 73 | if self.transforms is not None: 74 | img, target = self.transforms(img, target) 75 | 76 | return img, target 77 | 78 | def __len__(self): 79 | return len(self.coco['images']) 80 | 81 | def get_height_and_width(self, idx): 82 | img_info = self.coco['images'][idx] 83 | height = img_info['height'] 84 | width = img_info['width'] 85 | return height, width 86 | 87 | 88 | def build(image_set, args): 89 | img_folder_root = Path(args.coco_path) 90 | ann_folder_root = Path(args.coco_panoptic_path) 91 | assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist' 92 | assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist' 93 | mode = 'panoptic' 94 | PATHS = { 95 | "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'), 96 | "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'), 97 | } 98 | 99 | img_folder, ann_file = PATHS[image_set] 100 | img_folder_path = img_folder_root / img_folder 101 | ann_folder = ann_folder_root / f'{mode}_{img_folder}' 102 | ann_file = ann_folder_root / ann_file 103 | 104 | dataset = CocoPanoptic(img_folder_path, 
ann_folder, ann_file, 105 | transforms=make_coco_transforms(image_set), return_masks=args.masks) 106 | 107 | return dataset 108 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modules to compute the matching cost and solve the corresponding LSAP. 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 10 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 11 | # ------------------------------------------------------------------------ 12 | 13 | import torch 14 | from scipy.optimize import linear_sum_assignment 15 | from torch import nn 16 | 17 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 18 | 19 | 20 | class HungarianMatcher(nn.Module): 21 | """This class computes an assignment between the targets and the predictions of the network 22 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 23 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 24 | while the others are un-matched (and thus treated as non-objects). 25 | """ 26 | 27 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 28 | """Creates the matcher 29 | Params: 30 | cost_class: This is the relative weight of the classification error in the matching cost 31 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 32 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 33 | """ 34 | super().__init__() 35 | self.cost_class = cost_class 36 | self.cost_bbox = cost_bbox 37 | self.cost_giou = cost_giou 38 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 39 | 40 | @torch.no_grad() 41 | def forward(self, outputs, targets): 42 | """ Performs the matching 43 | Params: 44 | outputs: This is a dict that contains at least these entries: 45 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 46 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 47 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 48 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 49 | objects in the target) containing the class labels 50 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 51 | Returns: 52 | A list of size batch_size, containing tuples of (index_i, index_j) where: 53 | - index_i is the indices of the selected predictions (in order) 54 | - index_j is the indices of the corresponding selected targets (in order) 55 | For each batch element, it holds: 56 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 57 | """ 58 | bs, num_queries = 
outputs["pred_logits"].shape[:2] 59 | 60 | # We flatten to compute the cost matrices in a batch 61 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] 62 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 63 | 64 | # Also concat the target labels and boxes 65 | tgt_ids = torch.cat([v["labels"] for v in targets]) 66 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 67 | 68 | # Compute the classification cost. 69 | alpha = 0.25 70 | gamma = 2.0 71 | neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) 72 | pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) 73 | cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] 74 | 75 | # Compute the L1 cost between boxes 76 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 77 | 78 | # Compute the giou cost betwen boxes 79 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 80 | 81 | # Final cost matrix 82 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 83 | C = C.view(bs, num_queries, -1).cpu() 84 | 85 | sizes = [len(v["boxes"]) for v in targets] 86 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 87 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 88 | 89 | 90 | def build_matcher(args): 91 | return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Backbone modules. 12 | """ 13 | from collections import OrderedDict 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | import torchvision 18 | from torch import nn 19 | from torchvision.models._utils import IntermediateLayerGetter 20 | from typing import Dict, List 21 | 22 | from util.misc import NestedTensor, is_main_process 23 | 24 | from .position_encoding import build_position_encoding 25 | 26 | 27 | class FrozenBatchNorm2d(torch.nn.Module): 28 | """ 29 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 30 | 31 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 32 | without which any other models than torchvision.models.resnet[18,34,50,101] 33 | produce nans. 
34 | """ 35 | 36 | def __init__(self, n): 37 | super(FrozenBatchNorm2d, self).__init__() 38 | self.register_buffer("weight", torch.ones(n)) 39 | self.register_buffer("bias", torch.zeros(n)) 40 | self.register_buffer("running_mean", torch.zeros(n)) 41 | self.register_buffer("running_var", torch.ones(n)) 42 | 43 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 44 | missing_keys, unexpected_keys, error_msgs): 45 | num_batches_tracked_key = prefix + 'num_batches_tracked' 46 | if num_batches_tracked_key in state_dict: 47 | del state_dict[num_batches_tracked_key] 48 | 49 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 50 | state_dict, prefix, local_metadata, strict, 51 | missing_keys, unexpected_keys, error_msgs) 52 | 53 | def forward(self, x): 54 | # move reshapes to the beginning 55 | # to make it fuser-friendly 56 | w = self.weight.reshape(1, -1, 1, 1) 57 | b = self.bias.reshape(1, -1, 1, 1) 58 | rv = self.running_var.reshape(1, -1, 1, 1) 59 | rm = self.running_mean.reshape(1, -1, 1, 1) 60 | eps = 1e-5 61 | scale = w * (rv + eps).rsqrt() 62 | bias = b - rm * scale 63 | return x * scale + bias 64 | 65 | 66 | class BackboneBase(nn.Module): 67 | 68 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 69 | super().__init__() 70 | for name, parameter in backbone.named_parameters(): 71 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 72 | parameter.requires_grad_(False) 73 | if return_interm_layers: 74 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 75 | else: 76 | return_layers = {'layer4': "0"} 77 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 78 | self.num_channels = num_channels 79 | 80 | def forward(self, tensor_list: NestedTensor): 81 | xs = self.body(tensor_list.tensors) 82 | out: Dict[str, NestedTensor] = {} 83 | for name, x in xs.items(): 84 | m = tensor_list.mask 85 | assert m is not None 86 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 87 | out[name] = NestedTensor(x, mask) 88 | return out 89 | 90 | 91 | class Backbone(BackboneBase): 92 | """ResNet backbone with frozen BatchNorm.""" 93 | def __init__(self, name: str, 94 | train_backbone: bool, 95 | return_interm_layers: bool, 96 | dilation: bool): 97 | backbone = getattr(torchvision.models, name)( 98 | replace_stride_with_dilation=[False, False, dilation], 99 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 100 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 101 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 102 | 103 | 104 | class Joiner(nn.Sequential): 105 | def __init__(self, backbone, position_embedding): 106 | super().__init__(backbone, position_embedding) 107 | 108 | def forward(self, tensor_list: NestedTensor): 109 | xs = self[0](tensor_list) 110 | out: List[NestedTensor] = [] 111 | pos = [] 112 | for name, x in xs.items(): 113 | out.append(x) 114 | # position encoding 115 | pos.append(self[1](x).to(x.tensors.dtype)) 116 | 117 | return out, pos 118 | 119 | 120 | def build_backbone(args): 121 | position_embedding = build_position_encoding(args) 122 | train_backbone = args.lr_backbone > 0 123 | return_interm_layers = args.masks 124 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 125 | model = Joiner(backbone, position_embedding) 126 | model.num_channels = backbone.num_channels 127 | return 
model 128 | -------------------------------------------------------------------------------- /datasets/coco.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | COCO dataset which returns image_id for evaluation. 12 | 13 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 14 | """ 15 | from pathlib import Path 16 | 17 | import torch 18 | import torch.utils.data 19 | import torchvision 20 | from pycocotools import mask as coco_mask 21 | 22 | import datasets.transforms as T 23 | 24 | 25 | class CocoDetection(torchvision.datasets.CocoDetection): 26 | def __init__(self, img_folder, ann_file, transforms, return_masks): 27 | super(CocoDetection, self).__init__(img_folder, ann_file) 28 | self._transforms = transforms 29 | self.prepare = ConvertCocoPolysToMask(return_masks) 30 | 31 | def __getitem__(self, idx): 32 | img, target = super(CocoDetection, self).__getitem__(idx) 33 | image_id = self.ids[idx] 34 | target = {'image_id': image_id, 'annotations': target} 35 | img, target = self.prepare(img, target) 36 | if self._transforms is not None: 37 | img, target = self._transforms(img, target) 38 | return img, target 39 | 40 | 41 | def convert_coco_poly_to_mask(segmentations, height, width): 42 | masks = [] 43 | for polygons in segmentations: 44 | rles = coco_mask.frPyObjects(polygons, height, width) 45 | mask = coco_mask.decode(rles) 46 | if len(mask.shape) < 3: 47 | mask = mask[..., None] 48 | mask = torch.as_tensor(mask, dtype=torch.uint8) 49 | mask = mask.any(dim=2) 50 | masks.append(mask) 51 | if masks: 52 | masks = torch.stack(masks, dim=0) 53 | else: 54 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 55 | return masks 56 | 57 | 58 | class ConvertCocoPolysToMask(object): 59 | def __init__(self, return_masks=False): 60 | self.return_masks = return_masks 61 | 62 | def __call__(self, image, target): 63 | w, h = image.size 64 | 65 | image_id = target["image_id"] 66 | image_id = torch.tensor([image_id]) 67 | 68 | anno = target["annotations"] 69 | 70 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 71 | 72 | boxes = [obj["bbox"] for obj in anno] 73 | # guard against no boxes via resizing 74 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 75 | boxes[:, 2:] += boxes[:, :2] 76 | boxes[:, 0::2].clamp_(min=0, max=w) 77 | boxes[:, 1::2].clamp_(min=0, max=h) 78 | 79 | classes = [obj["category_id"] for obj in anno] 80 | classes = torch.tensor(classes, dtype=torch.int64) 81 | 82 | if self.return_masks: 83 | segmentations = [obj["segmentation"] for obj in anno] 84 | masks = convert_coco_poly_to_mask(segmentations, h, w) 85 | 86 | keypoints = None 87 | if anno and "keypoints" in anno[0]: 88 | keypoints = [obj["keypoints"] for obj in anno] 89 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 90 | num_keypoints = keypoints.shape[0] 91 | if num_keypoints: 92 | keypoints = keypoints.view(num_keypoints, -1, 3) 
93 | 94 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 95 | boxes = boxes[keep] 96 | classes = classes[keep] 97 | if self.return_masks: 98 | masks = masks[keep] 99 | if keypoints is not None: 100 | keypoints = keypoints[keep] 101 | 102 | target = {} 103 | target["boxes"] = boxes 104 | target["labels"] = classes 105 | if self.return_masks: 106 | target["masks"] = masks 107 | target["image_id"] = image_id 108 | if keypoints is not None: 109 | target["keypoints"] = keypoints 110 | 111 | # for conversion to coco api 112 | area = torch.tensor([obj["area"] for obj in anno]) 113 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 114 | target["area"] = area[keep] 115 | target["iscrowd"] = iscrowd[keep] 116 | 117 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 118 | target["size"] = torch.as_tensor([int(h), int(w)]) 119 | 120 | return image, target 121 | 122 | 123 | def make_coco_transforms(image_set): 124 | 125 | normalize = T.Compose([ 126 | T.ToTensor(), 127 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 128 | ]) 129 | 130 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 131 | 132 | if image_set == 'train': 133 | return T.Compose([ 134 | T.RandomHorizontalFlip(), 135 | T.RandomSelect( 136 | T.RandomResize(scales, max_size=1333), 137 | T.Compose([ 138 | T.RandomResize([400, 500, 600]), 139 | T.RandomSizeCrop(384, 600), 140 | T.RandomResize(scales, max_size=1333), 141 | ]) 142 | ), 143 | normalize, 144 | ]) 145 | 146 | if image_set == 'val': 147 | return T.Compose([ 148 | T.RandomResize([800], max_size=1333), 149 | normalize, 150 | ]) 151 | 152 | raise ValueError(f'unknown {image_set}') 153 | 154 | 155 | def build(image_set, args): 156 | root = Path(args.coco_path) 157 | assert root.exists(), f'provided COCO path {root} does not exist' 158 | mode = 'instances' 159 | PATHS = { 160 | "train": (root / "images" / "train2017", root / "annotations" / f'{mode}_train2017.json'), 161 | "val": (root / "images" / "val2017", root / "annotations" / f'{mode}_val2017.json'), 162 | "test": (root / "images" / "test2017", root / "annotations" / f'image_info_test-dev2017.json'), 163 | } 164 | 165 | img_folder, ann_file = PATHS[image_set] 166 | dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks) 167 | return dataset 168 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Train and eval functions used in main.py 12 | """ 13 | import math 14 | import os 15 | import sys 16 | from typing import Iterable 17 | 18 | import torch 19 | 20 | import util.misc as utils 21 | from datasets.coco_eval import CocoEvaluator 22 | from datasets.panoptic_eval import PanopticEvaluator 23 | 24 | 25 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 26 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 27 | device: torch.device, epoch: int, max_norm: float = 0): 28 | model.train() 29 | criterion.train() 30 | metric_logger = utils.MetricLogger(delimiter=" ") 31 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 32 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 33 | header = 'Epoch: [{}]'.format(epoch) 34 | print_freq = 10 35 | 36 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 37 | samples = samples.to(device) 38 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 39 | 40 | outputs = model(samples) 41 | loss_dict = criterion(outputs, targets) 42 | weight_dict = criterion.weight_dict 43 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 44 | 45 | # reduce losses over all GPUs for logging purposes 46 | loss_dict_reduced = utils.reduce_dict(loss_dict) 47 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 48 | for k, v in loss_dict_reduced.items()} 49 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 50 | for k, v in loss_dict_reduced.items() if k in weight_dict} 51 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 52 | 53 | loss_value = losses_reduced_scaled.item() 54 | 55 | if not math.isfinite(loss_value): 56 | print("Loss is {}, stopping training".format(loss_value)) 57 | print(loss_dict_reduced) 58 | sys.exit(1) 59 | 60 | optimizer.zero_grad() 61 | losses.backward() 62 | if max_norm > 0: 63 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 64 | optimizer.step() 65 | 66 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 67 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 68 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 69 | # gather the stats from all processes 70 | metric_logger.synchronize_between_processes() 71 | print("Averaged stats:", metric_logger) 72 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 73 | 74 | 75 | @torch.no_grad() 76 | def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir): 77 | model.eval() 78 | criterion.eval() 79 | 80 | metric_logger = utils.MetricLogger(delimiter=" ") 81 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 82 | header = 'Test:' 83 | 84 | iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) 85 | coco_evaluator = CocoEvaluator(base_ds, iou_types) 86 | # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] 87 | 88 | panoptic_evaluator = None 89 | if 'panoptic' in postprocessors.keys(): 90 | panoptic_evaluator = PanopticEvaluator( 91 | data_loader.dataset.ann_file, 92 | data_loader.dataset.ann_folder, 93 | output_dir=os.path.join(output_dir, "panoptic_eval"), 94 | ) 95 | 96 | for samples, targets in metric_logger.log_every(data_loader, 10, header): 97 | samples = samples.to(device) 
98 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 99 | 100 | outputs = model(samples) 101 | loss_dict = criterion(outputs, targets) 102 | weight_dict = criterion.weight_dict 103 | 104 | # reduce losses over all GPUs for logging purposes 105 | loss_dict_reduced = utils.reduce_dict(loss_dict) 106 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 107 | for k, v in loss_dict_reduced.items() if k in weight_dict} 108 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 109 | for k, v in loss_dict_reduced.items()} 110 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), 111 | **loss_dict_reduced_scaled, 112 | **loss_dict_reduced_unscaled) 113 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 114 | 115 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 116 | results = postprocessors['bbox'](outputs, orig_target_sizes) 117 | if 'segm' in postprocessors.keys(): 118 | target_sizes = torch.stack([t["size"] for t in targets], dim=0) 119 | results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) 120 | res = {target['image_id'].item(): output for target, output in zip(targets, results)} 121 | if coco_evaluator is not None: 122 | coco_evaluator.update(res) 123 | 124 | if panoptic_evaluator is not None: 125 | res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes) 126 | for i, target in enumerate(targets): 127 | image_id = target["image_id"].item() 128 | file_name = f"{image_id:012d}.png" 129 | res_pano[i]["image_id"] = image_id 130 | res_pano[i]["file_name"] = file_name 131 | 132 | panoptic_evaluator.update(res_pano) 133 | 134 | # gather the stats from all processes 135 | metric_logger.synchronize_between_processes() 136 | print("Averaged stats:", metric_logger) 137 | if coco_evaluator is not None: 138 | coco_evaluator.synchronize_between_processes() 139 | if panoptic_evaluator is not None: 140 | panoptic_evaluator.synchronize_between_processes() 141 | 142 | # accumulate predictions from all images 143 | if coco_evaluator is not None: 144 | coco_evaluator.accumulate() 145 | coco_evaluator.summarize() 146 | panoptic_res = None 147 | if panoptic_evaluator is not None: 148 | panoptic_res = panoptic_evaluator.summarize() 149 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 150 | if coco_evaluator is not None: 151 | if 'bbox' in postprocessors.keys(): 152 | stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() 153 | if 'segm' in postprocessors.keys(): 154 | stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist() 155 | if panoptic_res is not None: 156 | stats['PQ_all'] = panoptic_res["All"] 157 | stats['PQ_th'] = panoptic_res["Things"] 158 | stats['PQ_st'] = panoptic_res["Stuff"] 159 | return stats, coco_evaluator 160 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Transforms and data augmentation for both image + bbox. 12 | """ 13 | import random 14 | 15 | import PIL 16 | import torch 17 | import torchvision.transforms as T 18 | import torchvision.transforms.functional as F 19 | 20 | from util.box_ops import box_xyxy_to_cxcywh 21 | from util.misc import interpolate 22 | 23 | 24 | def crop(image, target, region): 25 | cropped_image = F.crop(image, *region) 26 | 27 | target = target.copy() 28 | i, j, h, w = region 29 | 30 | # should we do something wrt the original size? 31 | target["size"] = torch.tensor([h, w]) 32 | 33 | fields = ["labels", "area", "iscrowd"] 34 | 35 | if "boxes" in target: 36 | boxes = target["boxes"] 37 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 38 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 39 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 40 | cropped_boxes = cropped_boxes.clamp(min=0) 41 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 42 | target["boxes"] = cropped_boxes.reshape(-1, 4) 43 | target["area"] = area 44 | fields.append("boxes") 45 | 46 | if "masks" in target: 47 | # FIXME should we update the area here if there are no boxes? 48 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 49 | fields.append("masks") 50 | 51 | # remove elements for which the boxes or masks that have zero area 52 | if "boxes" in target or "masks" in target: 53 | # favor boxes selection when defining which elements to keep 54 | # this is compatible with previous implementation 55 | if "boxes" in target: 56 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 57 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 58 | else: 59 | keep = target['masks'].flatten(1).any(1) 60 | 61 | for field in fields: 62 | target[field] = target[field][keep] 63 | 64 | return cropped_image, target 65 | 66 | 67 | def hflip(image, target): 68 | flipped_image = F.hflip(image) 69 | 70 | w, h = image.size 71 | 72 | target = target.copy() 73 | if "boxes" in target: 74 | boxes = target["boxes"] 75 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 76 | target["boxes"] = boxes 77 | 78 | if "masks" in target: 79 | target['masks'] = target['masks'].flip(-1) 80 | 81 | return flipped_image, target 82 | 83 | 84 | def resize(image, target, size, max_size=None): 85 | # size can be min_size (scalar) or (w, h) tuple 86 | 87 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 88 | w, h = image_size 89 | if max_size is not None: 90 | min_original_size = float(min((w, h))) 91 | max_original_size = float(max((w, h))) 92 | if max_original_size / min_original_size * size > max_size: 93 | size = int(round(max_size * min_original_size / max_original_size)) 94 | 95 | if (w <= h and w == size) or (h <= w and h == size): 96 | return (h, w) 97 | 98 | if w < h: 99 | ow = size 100 | oh = int(size * h / w) 101 | else: 102 | oh = size 103 | ow = int(size * w / h) 104 | 105 | return (oh, ow) 106 | 107 | def get_size(image_size, size, max_size=None): 108 | if isinstance(size, (list, tuple)): 109 | return size[::-1] 110 | else: 111 | return get_size_with_aspect_ratio(image_size, size, max_size) 112 | 113 | size = get_size(image.size, size, max_size) 114 | rescaled_image = F.resize(image, size) 115 | 116 | if target is None: 117 | return rescaled_image, None 118 | 119 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in 
zip(rescaled_image.size, image.size)) 120 | ratio_width, ratio_height = ratios 121 | 122 | target = target.copy() 123 | if "boxes" in target: 124 | boxes = target["boxes"] 125 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 126 | target["boxes"] = scaled_boxes 127 | 128 | if "area" in target: 129 | area = target["area"] 130 | scaled_area = area * (ratio_width * ratio_height) 131 | target["area"] = scaled_area 132 | 133 | h, w = size 134 | target["size"] = torch.tensor([h, w]) 135 | 136 | if "masks" in target: 137 | target['masks'] = interpolate( 138 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 139 | 140 | return rescaled_image, target 141 | 142 | 143 | def pad(image, target, padding): 144 | # assumes that we only pad on the bottom right corners 145 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 146 | if target is None: 147 | return padded_image, None 148 | target = target.copy() 149 | # should we do something wrt the original size? 150 | target["size"] = torch.tensor(padded_image.size[::-1]) 151 | if "masks" in target: 152 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 153 | return padded_image, target 154 | 155 | 156 | class RandomCrop(object): 157 | def __init__(self, size): 158 | self.size = size 159 | 160 | def __call__(self, img, target): 161 | region = T.RandomCrop.get_params(img, self.size) 162 | return crop(img, target, region) 163 | 164 | 165 | class RandomSizeCrop(object): 166 | def __init__(self, min_size: int, max_size: int): 167 | self.min_size = min_size 168 | self.max_size = max_size 169 | 170 | def __call__(self, img: PIL.Image.Image, target: dict): 171 | w = random.randint(self.min_size, min(img.width, self.max_size)) 172 | h = random.randint(self.min_size, min(img.height, self.max_size)) 173 | region = T.RandomCrop.get_params(img, [h, w]) 174 | return crop(img, target, region) 175 | 176 | 177 | class CenterCrop(object): 178 | def __init__(self, size): 179 | self.size = size 180 | 181 | def __call__(self, img, target): 182 | image_width, image_height = img.size 183 | crop_height, crop_width = self.size 184 | crop_top = int(round((image_height - crop_height) / 2.)) 185 | crop_left = int(round((image_width - crop_width) / 2.)) 186 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 187 | 188 | 189 | class RandomHorizontalFlip(object): 190 | def __init__(self, p=0.5): 191 | self.p = p 192 | 193 | def __call__(self, img, target): 194 | if random.random() < self.p: 195 | return hflip(img, target) 196 | return img, target 197 | 198 | 199 | class RandomResize(object): 200 | def __init__(self, sizes, max_size=None): 201 | assert isinstance(sizes, (list, tuple)) 202 | self.sizes = sizes 203 | self.max_size = max_size 204 | 205 | def __call__(self, img, target=None): 206 | size = random.choice(self.sizes) 207 | return resize(img, target, size, self.max_size) 208 | 209 | 210 | class RandomPad(object): 211 | def __init__(self, max_pad): 212 | self.max_pad = max_pad 213 | 214 | def __call__(self, img, target): 215 | pad_x = random.randint(0, self.max_pad) 216 | pad_y = random.randint(0, self.max_pad) 217 | return pad(img, target, (pad_x, pad_y)) 218 | 219 | 220 | class RandomSelect(object): 221 | """ 222 | Randomly selects between transforms1 and transforms2, 223 | with probability p for transforms1 and (1 - p) for transforms2 224 | """ 225 | def __init__(self, transforms1, transforms2, p=0.5): 226 | self.transforms1 = transforms1 
227 | self.transforms2 = transforms2 228 | self.p = p 229 | 230 | def __call__(self, img, target): 231 | if random.random() < self.p: 232 | return self.transforms1(img, target) 233 | return self.transforms2(img, target) 234 | 235 | 236 | class ToTensor(object): 237 | def __call__(self, img, target): 238 | return F.to_tensor(img), target 239 | 240 | 241 | class RandomErasing(object): 242 | 243 | def __init__(self, *args, **kwargs): 244 | self.eraser = T.RandomErasing(*args, **kwargs) 245 | 246 | def __call__(self, img, target): 247 | return self.eraser(img), target 248 | 249 | 250 | class Normalize(object): 251 | def __init__(self, mean, std): 252 | self.mean = mean 253 | self.std = std 254 | 255 | def __call__(self, image, target=None): 256 | image = F.normalize(image, mean=self.mean, std=self.std) 257 | if target is None: 258 | return image, None 259 | target = target.copy() 260 | h, w = image.shape[-2:] 261 | if "boxes" in target: 262 | boxes = target["boxes"] 263 | boxes = box_xyxy_to_cxcywh(boxes) 264 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 265 | target["boxes"] = boxes 266 | return image, target 267 | 268 | 269 | class Compose(object): 270 | def __init__(self, transforms): 271 | self.transforms = transforms 272 | 273 | def __call__(self, image, target): 274 | for t in self.transforms: 275 | image, target = t(image, target) 276 | return image, target 277 | 278 | def __repr__(self): 279 | format_string = self.__class__.__name__ + "(" 280 | for t in self.transforms: 281 | format_string += "\n" 282 | format_string += " {0}".format(t) 283 | format_string += "\n)" 284 | return format_string 285 | -------------------------------------------------------------------------------- /datasets/coco_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | COCO evaluator that works in distributed mode. 
12 | 13 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 14 | The difference is that there is less copy-pasting from pycocotools 15 | in the end of the file, as python3 can suppress prints with contextlib 16 | """ 17 | import os 18 | import contextlib 19 | import copy 20 | import numpy as np 21 | import torch 22 | 23 | from pycocotools.cocoeval import COCOeval 24 | from pycocotools.coco import COCO 25 | import pycocotools.mask as mask_util 26 | 27 | from util.misc import all_gather 28 | 29 | 30 | class CocoEvaluator(object): 31 | def __init__(self, coco_gt, iou_types): 32 | assert isinstance(iou_types, (list, tuple)) 33 | coco_gt = copy.deepcopy(coco_gt) 34 | self.coco_gt = coco_gt 35 | 36 | self.iou_types = iou_types 37 | self.coco_eval = {} 38 | for iou_type in iou_types: 39 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 40 | 41 | self.img_ids = [] 42 | self.eval_imgs = {k: [] for k in iou_types} 43 | 44 | def update(self, predictions): 45 | img_ids = list(np.unique(list(predictions.keys()))) 46 | self.img_ids.extend(img_ids) 47 | 48 | for iou_type in self.iou_types: 49 | results = self.prepare(predictions, iou_type) 50 | 51 | # suppress pycocotools prints 52 | with open(os.devnull, 'w') as devnull: 53 | with contextlib.redirect_stdout(devnull): 54 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 55 | coco_eval = self.coco_eval[iou_type] 56 | 57 | coco_eval.cocoDt = coco_dt 58 | coco_eval.params.imgIds = list(img_ids) 59 | img_ids, eval_imgs = evaluate(coco_eval) 60 | 61 | self.eval_imgs[iou_type].append(eval_imgs) 62 | 63 | def synchronize_between_processes(self): 64 | for iou_type in self.iou_types: 65 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 66 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 67 | 68 | def accumulate(self): 69 | for coco_eval in self.coco_eval.values(): 70 | coco_eval.accumulate() 71 | 72 | def summarize(self): 73 | for iou_type, coco_eval in self.coco_eval.items(): 74 | print("IoU metric: {}".format(iou_type)) 75 | coco_eval.summarize() 76 | 77 | def prepare(self, predictions, iou_type): 78 | if iou_type == "bbox": 79 | return self.prepare_for_coco_detection(predictions) 80 | elif iou_type == "segm": 81 | return self.prepare_for_coco_segmentation(predictions) 82 | elif iou_type == "keypoints": 83 | return self.prepare_for_coco_keypoint(predictions) 84 | else: 85 | raise ValueError("Unknown iou type {}".format(iou_type)) 86 | 87 | def prepare_for_coco_detection(self, predictions): 88 | coco_results = [] 89 | for original_id, prediction in predictions.items(): 90 | if len(prediction) == 0: 91 | continue 92 | 93 | boxes = prediction["boxes"] 94 | boxes = convert_to_xywh(boxes).tolist() 95 | scores = prediction["scores"].tolist() 96 | labels = prediction["labels"].tolist() 97 | 98 | coco_results.extend( 99 | [ 100 | { 101 | "image_id": original_id, 102 | "category_id": labels[k], 103 | "bbox": box, 104 | "score": scores[k], 105 | } 106 | for k, box in enumerate(boxes) 107 | ] 108 | ) 109 | return coco_results 110 | 111 | def prepare_for_coco_segmentation(self, predictions): 112 | coco_results = [] 113 | for original_id, prediction in predictions.items(): 114 | if len(prediction) == 0: 115 | continue 116 | 117 | scores = prediction["scores"] 118 | labels = prediction["labels"] 119 | masks = prediction["masks"] 120 | 121 | masks = masks > 0.5 122 | 123 | scores = prediction["scores"].tolist() 124 | 
labels = prediction["labels"].tolist() 125 | 126 | rles = [ 127 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 128 | for mask in masks 129 | ] 130 | for rle in rles: 131 | rle["counts"] = rle["counts"].decode("utf-8") 132 | 133 | coco_results.extend( 134 | [ 135 | { 136 | "image_id": original_id, 137 | "category_id": labels[k], 138 | "segmentation": rle, 139 | "score": scores[k], 140 | } 141 | for k, rle in enumerate(rles) 142 | ] 143 | ) 144 | return coco_results 145 | 146 | def prepare_for_coco_keypoint(self, predictions): 147 | coco_results = [] 148 | for original_id, prediction in predictions.items(): 149 | if len(prediction) == 0: 150 | continue 151 | 152 | boxes = prediction["boxes"] 153 | boxes = convert_to_xywh(boxes).tolist() 154 | scores = prediction["scores"].tolist() 155 | labels = prediction["labels"].tolist() 156 | keypoints = prediction["keypoints"] 157 | keypoints = keypoints.flatten(start_dim=1).tolist() 158 | 159 | coco_results.extend( 160 | [ 161 | { 162 | "image_id": original_id, 163 | "category_id": labels[k], 164 | 'keypoints': keypoint, 165 | "score": scores[k], 166 | } 167 | for k, keypoint in enumerate(keypoints) 168 | ] 169 | ) 170 | return coco_results 171 | 172 | 173 | def convert_to_xywh(boxes): 174 | xmin, ymin, xmax, ymax = boxes.unbind(1) 175 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 176 | 177 | 178 | def merge(img_ids, eval_imgs): 179 | all_img_ids = all_gather(img_ids) 180 | all_eval_imgs = all_gather(eval_imgs) 181 | 182 | merged_img_ids = [] 183 | for p in all_img_ids: 184 | merged_img_ids.extend(p) 185 | 186 | merged_eval_imgs = [] 187 | for p in all_eval_imgs: 188 | merged_eval_imgs.append(p) 189 | 190 | merged_img_ids = np.array(merged_img_ids) 191 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 192 | 193 | # keep only unique (and in sorted order) images 194 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 195 | merged_eval_imgs = merged_eval_imgs[..., idx] 196 | 197 | return merged_img_ids, merged_eval_imgs 198 | 199 | 200 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 201 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 202 | img_ids = list(img_ids) 203 | eval_imgs = list(eval_imgs.flatten()) 204 | 205 | coco_eval.evalImgs = eval_imgs 206 | coco_eval.params.imgIds = img_ids 207 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 208 | 209 | 210 | ################################################################# 211 | # From pycocotools, just removed the prints and fixed 212 | # a Python3 bug about unicode not defined 213 | ################################################################# 214 | 215 | 216 | def evaluate(self): 217 | ''' 218 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 219 | :return: None 220 | ''' 221 | # tic = time.time() 222 | # print('Running per image evaluation...') 223 | p = self.params 224 | # add backward compatibility if useSegm is specified in params 225 | if p.useSegm is not None: 226 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 227 | print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) 228 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 229 | p.imgIds = list(np.unique(p.imgIds)) 230 | if p.useCats: 231 | p.catIds = list(np.unique(p.catIds)) 232 | p.maxDets = sorted(p.maxDets) 233 | self.params = p 234 | 235 | self._prepare() 236 | # loop through images, area range, max detection number 237 | catIds = p.catIds if p.useCats else [-1] 238 | 239 | if p.iouType == 'segm' or p.iouType == 'bbox': 240 | computeIoU = self.computeIoU 241 | elif p.iouType == 'keypoints': 242 | computeIoU = self.computeOks 243 | self.ious = { 244 | (imgId, catId): computeIoU(imgId, catId) 245 | for imgId in p.imgIds 246 | for catId in catIds} 247 | 248 | evaluateImg = self.evaluateImg 249 | maxDet = p.maxDets[-1] 250 | evalImgs = [ 251 | evaluateImg(imgId, catId, areaRng, maxDet) 252 | for catId in catIds 253 | for areaRng in p.areaRng 254 | for imgId in p.imgIds 255 | ] 256 | # this is NOT in the pycocotools code, but could be done outside 257 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 258 | self._paramsEval = copy.deepcopy(self.params) 259 | # toc = time.time() 260 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 261 | return p.imgIds, evalImgs 262 | 263 | ################################################################# 264 | # end of straight copy from pycocotools, just removing the prints 265 | ################################################################# 266 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditional DETR 2 | 3 | This repository is an official implementation of the ICCV 2021 paper "[Conditional DETR for Fast Training Convergence](https://arxiv.org/abs/2108.06152)". 4 | 5 | * Conditional DETR is integrated into [Huggingface](https://huggingface.co/docs/transformers/main/en/model_doc/conditional_detr); try out our model [here](https://huggingface.co/microsoft/conditional-detr-resnet-50). 6 | 7 | ## Introduction 8 | 9 | The DETR approach applies the 10 | transformer encoder and decoder architecture to object detection 11 | and achieves promising performance. In this paper, 12 | we address the critical issue of slow training convergence 13 | and present a conditional cross-attention mechanism for 14 | fast DETR training. Our approach is motivated by the observation that the 15 | cross-attention in DETR relies heavily on the content embeddings 16 | and that the spatial embeddings make minor contributions, 17 | increasing the need for high-quality content embeddings 18 | and thus increasing the training difficulty. 19 | 20 |
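To make the motivation concrete: in the DETR decoder, the cross-attention query is the sum of a content embedding (the decoder self-attention output) and a spatial embedding (the object query), and the key is the sum of the encoder content embedding and its positional embedding. The attention score therefore expands into four terms. The block below is an illustrative restatement of this analysis with our own notation (c_q, p_q for the query parts, c_k, p_k for the key parts); it is not text from the repository.

```latex
(c_q + p_q)^\top (c_k + p_k)
    = \underbrace{c_q^\top c_k}_{\text{content-content}}
    + \underbrace{c_q^\top p_k}_{\text{content-spatial}}
    + \underbrace{p_q^\top c_k}_{\text{spatial-content}}
    + \underbrace{p_q^\top p_k}_{\text{spatial-spatial}}
```

Early in training the spatial terms are not yet informative, so localization falls almost entirely on the content term, which is the dependence on high-quality content embeddings referred to above.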
21 | *(figure placeholder; see `.github/attention-maps.png` in the repository)* 22 |
23 | 24 | Our conditional DETR learns a conditional 25 | spatial query from the decoder embedding 26 | for decoder multi-head cross-attention. 27 | The benefit is that, through the conditional spatial query, 28 | each cross-attention head is able to 29 | attend 30 | to a band containing a distinct region, 31 | e.g., one object extremity or a region inside the object box (Figure 1). 32 | This narrows down the spatial range for localizing the distinct regions 33 | for object classification and box regression, 34 | thus relaxing the dependence on the content embeddings and 35 | easing the training. Empirical results show that conditional 36 | DETR converges 6.7x faster for the backbones R50 and 37 | R101 and 10x faster for the stronger backbones DC5-R50 and 38 | DC5-R101. 39 | 40 |
41 | *(figure placeholder; see `.github/convergence-curve.png` in the repository)* 42 |
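As a concrete illustration of the mechanism described above, the sketch below shows one way to form a conditional spatial query: predict a reference point from the object query, embed it with a 2D sinusoidal encoding, and scale that embedding elementwise by a transformation conditioned on the decoder content embedding. This is a simplified, self-contained sketch written for this description, not an excerpt of `models/transformer.py`; the helper names, the single-linear reference-point head, and the tensor shapes are our own simplifications.

```python
# Illustrative sketch of a conditional spatial query (not the repository's code).
import math
import torch
import torch.nn as nn


def sine_embed_2d(points: torch.Tensor, num_feats: int = 128, temperature: int = 10000) -> torch.Tensor:
    """Sinusoidal embedding of normalized 2D points: (..., 2) -> (..., 2 * num_feats)."""
    scale = 2 * math.pi
    dim_t = torch.arange(num_feats, device=points.device)
    dim_t = temperature ** (2 * (dim_t // 2).float() / num_feats)
    x = points[..., 0, None] * scale / dim_t
    y = points[..., 1, None] * scale / dim_t
    x = torch.stack((x[..., 0::2].sin(), x[..., 1::2].cos()), dim=-1).flatten(-2)
    y = torch.stack((y[..., 0::2].sin(), y[..., 1::2].cos()), dim=-1).flatten(-2)
    return torch.cat((y, x), dim=-1)


class ConditionalSpatialQuery(nn.Module):
    """p_q = FFN(decoder content embedding f) * sinusoidal(reference point s)."""

    def __init__(self, d_model: int = 256):
        super().__init__()
        self.ref_point_head = nn.Linear(d_model, 2)          # reference point s from the object query (simplified)
        self.query_scale = nn.Sequential(                    # transform T conditioned on the decoder embedding f
            nn.Linear(d_model, d_model), nn.ReLU(), nn.Linear(d_model, d_model))

    def forward(self, object_query: torch.Tensor, decoder_embedding: torch.Tensor) -> torch.Tensor:
        s = self.ref_point_head(object_query).sigmoid()      # normalized reference point in [0, 1]^2
        p_s = sine_embed_2d(s)                               # spatial embedding of s
        return self.query_scale(decoder_embedding) * p_s     # conditional spatial query p_q


if __name__ == "__main__":
    nq, bs, d = 300, 2, 256                                  # 300 query slots, as in main.py's default
    p_q = ConditionalSpatialQuery(d)(torch.randn(nq, bs, d), torch.randn(nq, bs, d))
    print(p_q.shape)                                         # torch.Size([300, 2, 256])
```

In the decoder cross-attention, the query then concatenates the content query with this spatial query, and the key concatenates the encoder content with its positional embedding, so each head scores content and spatial similarity separately.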
44 | 45 | 46 | 47 | ## Model Zoo 48 | 49 | We provide conditional DETR and conditional DETR-DC5 models. 50 | AP is computed on COCO 2017 *val*. 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 |
MethodEpochsParams (M)FLOPs (G)APAPSAPMAPLURL
DETR-R50500418642.020.545.861.1model
log
DETR-R5050418634.813.937.354.4model
log
DETR-DC5-R505004118743.322.547.361.1model
log
DETR-R1015006015243.521.048.061.8model
log
DETR-R101506015236.915.540.655.6model
log
DETR-DC5-R1015006025344.923.749.562.3model
log
Conditional DETR-R5050449041.020.644.359.3model
log
Conditional DETR-DC5-R50504419543.723.947.660.1model
log
Conditional DETR-R101506315642.821.746.660.9model
log
Conditional DETR-DC5-R101506326245.026.148.962.8model
log
179 | 180 | The models are also available via torch hub; to load conditional DETR-R50 with pretrained weights, simply do: 181 | ```python 182 | model = torch.hub.load('Atten4Vis/ConditionalDETR:main', 'conditional_detr_resnet50', pretrained=True) 183 | ``` 184 | 185 | Note: 186 | 1. The numbers in the table are slightly different 187 | from the numbers in the paper. We re-ran some experiments when releasing the code. 188 | 2. "DC5" means removing the stride in the C5 stage of ResNet and adding a dilation of 2 instead. 189 | 190 | 191 | 192 | 193 | ## Installation 194 | 195 | ### Requirements 196 | - Python >= 3.7, CUDA >= 10.1 197 | - PyTorch >= 1.7.0, torchvision >= 0.6.1 198 | - Cython, COCOAPI, scipy, termcolor 199 | 200 | The code is developed using Python 3.8 with PyTorch 1.7.0. 201 | First, clone the repository locally: 202 | ```shell 203 | git clone https://github.com/Atten4Vis/ConditionalDETR.git 204 | ``` 205 | Then, install PyTorch and torchvision: 206 | ```shell 207 | conda install pytorch=1.7.0 torchvision=0.6.1 cudatoolkit=10.1 -c pytorch 208 | ``` 209 | Install other requirements: 210 | ```shell 211 | cd ConditionalDETR 212 | pip install -r requirements.txt 213 | ``` 214 | 215 | 216 | 217 | ## Usage 218 | 219 | ### Data preparation 220 | 221 | Download and extract COCO 2017 train and val images with annotations from 222 | [http://cocodataset.org](http://cocodataset.org/#download). 223 | We expect the directory structure to be the following: 224 | ``` 225 | path/to/coco/ 226 | ├── annotations/ # annotation json files 227 | └── images/ 228 | ├── train2017/ # train images 229 | ├── val2017/ # val images 230 | └── test2017/ # test images 231 | ``` 232 | 233 | ### Training 234 | 235 | To train conditional DETR-R50 on a single node with 8 GPUs for 50 epochs, run: 236 | ```shell 237 | bash scripts/conddetr_r50_epoch50.sh 238 | ``` 239 | or 240 | ```shell 241 | python -m torch.distributed.launch \ 242 | --nproc_per_node=8 \ 243 | --use_env \ 244 | main.py \ 246 | --coco_path /path/to/coco \ 247 | --output_dir output/conddetr_r50_epoch50 248 | ``` 249 | The training process takes around 30 hours on a single machine with 8 V100 cards. 250 | 251 | Following the DETR training setting, we train conditional DETR with AdamW, setting the learning rate to 1e-4 in the transformer and 1e-5 in the backbone. 252 | Horizontal flips, scales and crops are used for augmentation. 253 | Images are rescaled to have a minimum size of 800 and a maximum size of 1333. 254 | The transformer is trained with a dropout of 0.1, and the whole model is trained with gradient clipping of 0.1. 255 | 256 | ### Evaluation 257 | To evaluate conditional DETR-R50 on COCO *val* with 8 GPUs, run: 258 | ```shell 259 | python -m torch.distributed.launch \ 260 | --nproc_per_node=8 \ 261 | --use_env \ 262 | main.py \ 263 | --batch_size 2 \ 264 | --eval \ 265 | --resume <path/to/checkpoint.pth> \ 266 | --coco_path /path/to/coco \ 267 | --output_dir output/ 268 | ``` 269 | 270 | Note that the numbers vary depending on the batch size (number of images) per GPU. 271 | Non-DC5 models were trained with batch size 2, and DC5 models with 1, 272 | so DC5 models show a significant drop in AP if evaluated with more 273 | than 1 image per GPU. 274 | 275 | ## License 276 | 277 | Conditional DETR is released under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information.
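To complement the torch hub snippet and the evaluation command above, here is a minimal, illustrative inference sketch. The ImageNet normalization matches the values used by the repository's COCO transforms, and the shorter-side-800 resize mirrors the training setting; the sigmoid class scoring, the 0.5 score threshold, and the image path are our own simplifications and placeholders (the post-processors returned by `build_model` are the reference implementation).

```python
# Illustrative inference sketch (assumptions: sigmoid class scoring, a 0.5 score
# threshold, and a placeholder image path).
import torch
import torchvision.transforms as T
from PIL import Image

preprocess = T.Compose([
    T.Resize(800),                                           # shorter side 800, as in training
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

model = torch.hub.load('Atten4Vis/ConditionalDETR:main',
                       'conditional_detr_resnet50', pretrained=True)
model.eval()

img = Image.open('path/to/image.jpg').convert('RGB')         # placeholder path
x = preprocess(img).unsqueeze(0)                             # batch of one image

with torch.no_grad():
    out = model(x)                                           # dict with 'pred_logits' and 'pred_boxes'

# pred_boxes are normalized (cx, cy, w, h); convert to absolute (x1, y1, x2, y2)
# on the original image for visualization.
scores, labels = out['pred_logits'].sigmoid()[0].max(-1)
cx, cy, bw, bh = out['pred_boxes'][0].unbind(-1)
w, h = img.size
boxes = torch.stack([(cx - bw / 2) * w, (cy - bh / 2) * h,
                     (cx + bw / 2) * w, (cy + bh / 2) * h], dim=-1)

keep = scores > 0.5                                          # arbitrary confidence threshold
print(boxes[keep].shape, labels[keep], scores[keep])
```

With the default of 300 query slots (see `--num_queries` in main.py), `pred_logits` has shape (1, 300, num_classes) and `pred_boxes` has shape (1, 300, 4).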
278 | 279 | 280 | 281 | ## Citation 282 | 283 | ```bibtex 284 | @inproceedings{meng2021-CondDETR, 285 | title = {Conditional DETR for Fast Training Convergence}, 286 | author = {Meng, Depu and Chen, Xiaokang and Fan, Zejia and Zeng, Gang and Li, Houqiang and Yuan, Yuhui and Sun, Lei and Wang, Jingdong}, 287 | booktitle = {Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, 288 | year = {2021} 289 | } 290 | ``` 291 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | import argparse 11 | import datetime 12 | import json 13 | import random 14 | import time 15 | from pathlib import Path 16 | 17 | import numpy as np 18 | import torch 19 | from torch.utils.data import DataLoader, DistributedSampler 20 | 21 | import datasets 22 | import util.misc as utils 23 | from datasets import build_dataset, get_coco_api_from_dataset 24 | from engine import evaluate, train_one_epoch 25 | from models import build_model 26 | 27 | 28 | def get_args_parser(): 29 | parser = argparse.ArgumentParser('Set transformer detector', add_help=False) 30 | parser.add_argument('--lr', default=1e-4, type=float) 31 | parser.add_argument('--lr_backbone', default=1e-5, type=float) 32 | parser.add_argument('--batch_size', default=2, type=int) 33 | parser.add_argument('--weight_decay', default=1e-4, type=float) 34 | parser.add_argument('--epochs', default=50, type=int) 35 | parser.add_argument('--lr_drop', default=40, type=int) 36 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 37 | help='gradient clipping max norm') 38 | 39 | # Model parameters 40 | parser.add_argument('--frozen_weights', type=str, default=None, 41 | help="Path to the pretrained model. If set, only the mask head will be trained") 42 | # * Backbone 43 | parser.add_argument('--backbone', default='resnet50', type=str, 44 | help="Name of the convolutional backbone to use") 45 | parser.add_argument('--dilation', action='store_true', 46 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 47 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 48 | help="Type of positional embedding to use on top of the image features") 49 | 50 | # * Transformer 51 | parser.add_argument('--enc_layers', default=6, type=int, 52 | help="Number of encoding layers in the transformer") 53 | parser.add_argument('--dec_layers', default=6, type=int, 54 | help="Number of decoding layers in the transformer") 55 | parser.add_argument('--dim_feedforward', default=2048, type=int, 56 | help="Intermediate size of the feedforward layers in the transformer blocks") 57 | parser.add_argument('--hidden_dim', default=256, type=int, 58 | help="Size of the embeddings (dimension of the transformer)") 59 | parser.add_argument('--dropout', default=0.1, type=float, 60 | help="Dropout applied in the transformer") 61 | parser.add_argument('--nheads', default=8, type=int, 62 | help="Number of attention heads inside the transformer's attentions") 63 | parser.add_argument('--num_queries', default=300, type=int, 64 | help="Number of query slots") 65 | parser.add_argument('--pre_norm', action='store_true') 66 | 67 | # * Segmentation 68 | parser.add_argument('--masks', action='store_true', 69 | help="Train segmentation head if the flag is provided") 70 | 71 | # Loss 72 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 73 | help="Disables auxiliary decoding losses (loss at each layer)") 74 | 75 | # * Matcher 76 | parser.add_argument('--set_cost_class', default=2, type=float, 77 | help="Class coefficient in the matching cost") 78 | parser.add_argument('--set_cost_bbox', 
default=5, type=float, 79 | help="L1 box coefficient in the matching cost") 80 | parser.add_argument('--set_cost_giou', default=2, type=float, 81 | help="giou box coefficient in the matching cost") 82 | 83 | # * Loss coefficients 84 | parser.add_argument('--mask_loss_coef', default=1, type=float) 85 | parser.add_argument('--dice_loss_coef', default=1, type=float) 86 | parser.add_argument('--cls_loss_coef', default=2, type=float) 87 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 88 | parser.add_argument('--giou_loss_coef', default=2, type=float) 89 | parser.add_argument('--focal_alpha', default=0.25, type=float) 90 | 91 | # dataset parameters 92 | parser.add_argument('--dataset_file', default='coco') 93 | parser.add_argument('--coco_path', type=str) 94 | parser.add_argument('--coco_panoptic_path', type=str) 95 | parser.add_argument('--remove_difficult', action='store_true') 96 | 97 | parser.add_argument('--output_dir', default='', 98 | help='path where to save, empty for no saving') 99 | parser.add_argument('--device', default='cuda', 100 | help='device to use for training / testing') 101 | parser.add_argument('--seed', default=42, type=int) 102 | parser.add_argument('--resume', default='', help='resume from checkpoint') 103 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 104 | help='start epoch') 105 | parser.add_argument('--eval', action='store_true') 106 | parser.add_argument('--num_workers', default=2, type=int) 107 | 108 | # distributed training parameters 109 | parser.add_argument('--world_size', default=1, type=int, 110 | help='number of distributed processes') 111 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 112 | return parser 113 | 114 | 115 | def main(args): 116 | utils.init_distributed_mode(args) 117 | print("git:\n {}\n".format(utils.get_sha())) 118 | 119 | if args.frozen_weights is not None: 120 | assert args.masks, "Frozen training is meant for segmentation only" 121 | print(args) 122 | 123 | device = torch.device(args.device) 124 | 125 | # fix the seed for reproducibility 126 | seed = args.seed + utils.get_rank() 127 | torch.manual_seed(seed) 128 | np.random.seed(seed) 129 | random.seed(seed) 130 | 131 | model, criterion, postprocessors = build_model(args) 132 | model.to(device) 133 | 134 | model_without_ddp = model 135 | if args.distributed: 136 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 137 | model_without_ddp = model.module 138 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 139 | print('number of params:', n_parameters) 140 | 141 | param_dicts = [ 142 | {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]}, 143 | { 144 | "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad], 145 | "lr": args.lr_backbone, 146 | }, 147 | ] 148 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 149 | weight_decay=args.weight_decay) 150 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 151 | 152 | dataset_train = build_dataset(image_set='train', args=args) 153 | dataset_val = build_dataset(image_set='val', args=args) 154 | 155 | if args.distributed: 156 | sampler_train = DistributedSampler(dataset_train) 157 | sampler_val = DistributedSampler(dataset_val, shuffle=False) 158 | else: 159 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 160 | sampler_val = 
torch.utils.data.SequentialSampler(dataset_val) 161 | 162 | batch_sampler_train = torch.utils.data.BatchSampler( 163 | sampler_train, args.batch_size, drop_last=True) 164 | 165 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, 166 | collate_fn=utils.collate_fn, num_workers=args.num_workers) 167 | data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, 168 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) 169 | 170 | if args.dataset_file == "coco_panoptic": 171 | # We also evaluate AP during panoptic training, on original coco DS 172 | coco_val = datasets.coco.build("val", args) 173 | base_ds = get_coco_api_from_dataset(coco_val) 174 | else: 175 | base_ds = get_coco_api_from_dataset(dataset_val) 176 | 177 | if args.frozen_weights is not None: 178 | checkpoint = torch.load(args.frozen_weights, map_location='cpu') 179 | model_without_ddp.detr.load_state_dict(checkpoint['model']) 180 | 181 | output_dir = Path(args.output_dir) 182 | if args.resume: 183 | if args.resume.startswith('https'): 184 | checkpoint = torch.hub.load_state_dict_from_url( 185 | args.resume, map_location='cpu', check_hash=True) 186 | else: 187 | checkpoint = torch.load(args.resume, map_location='cpu') 188 | model_without_ddp.load_state_dict(checkpoint['model']) 189 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 190 | optimizer.load_state_dict(checkpoint['optimizer']) 191 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 192 | args.start_epoch = checkpoint['epoch'] + 1 193 | 194 | if args.eval: 195 | test_stats, coco_evaluator = evaluate(model, criterion, postprocessors, 196 | data_loader_val, base_ds, device, args.output_dir) 197 | if args.output_dir: 198 | utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth") 199 | return 200 | 201 | print("Start training") 202 | start_time = time.time() 203 | for epoch in range(args.start_epoch, args.epochs): 204 | if args.distributed: 205 | sampler_train.set_epoch(epoch) 206 | train_stats = train_one_epoch( 207 | model, criterion, data_loader_train, optimizer, device, epoch, 208 | args.clip_max_norm) 209 | lr_scheduler.step() 210 | if args.output_dir: 211 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 212 | # extra checkpoint before LR drop and every 100 epochs 213 | if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 10 == 0: 214 | checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') 215 | for checkpoint_path in checkpoint_paths: 216 | utils.save_on_master({ 217 | 'model': model_without_ddp.state_dict(), 218 | 'optimizer': optimizer.state_dict(), 219 | 'lr_scheduler': lr_scheduler.state_dict(), 220 | 'epoch': epoch, 221 | 'args': args, 222 | }, checkpoint_path) 223 | 224 | test_stats, coco_evaluator = evaluate( 225 | model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir 226 | ) 227 | 228 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 229 | **{f'test_{k}': v for k, v in test_stats.items()}, 230 | 'epoch': epoch, 231 | 'n_parameters': n_parameters} 232 | 233 | if args.output_dir and utils.is_main_process(): 234 | with (output_dir / "log.txt").open("a") as f: 235 | f.write(json.dumps(log_stats) + "\n") 236 | 237 | # for evaluation logs 238 | if coco_evaluator is not None: 239 | (output_dir / 'eval').mkdir(exist_ok=True) 240 | if "bbox" in coco_evaluator.coco_eval: 241 | filenames = ['latest.pth'] 242 | if epoch % 50 == 0: 243 | 
filenames.append(f'{epoch:03}.pth') 244 | for name in filenames: 245 | torch.save(coco_evaluator.coco_eval["bbox"].eval, 246 | output_dir / "eval" / name) 247 | 248 | total_time = time.time() - start_time 249 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 250 | print('Training time {}'.format(total_time_str)) 251 | 252 | 253 | if __name__ == '__main__': 254 | parser = argparse.ArgumentParser('Conditional DETR training and evaluation script', parents=[get_args_parser()]) 255 | args = parser.parse_args() 256 | if args.output_dir: 257 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 258 | main(args) 259 | -------------------------------------------------------------------------------- /models/segmentation.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | This file provides the definition of the convolutional heads used to predict masks, as well as the losses 12 | """ 13 | import io 14 | from collections import defaultdict 15 | from typing import List, Optional 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from torch import Tensor 21 | from PIL import Image 22 | 23 | import util.box_ops as box_ops 24 | from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list 25 | 26 | try: 27 | from panopticapi.utils import id2rgb, rgb2id 28 | except ImportError: 29 | pass 30 | 31 | 32 | class DETRsegm(nn.Module): 33 | def __init__(self, detr, freeze_detr=False): 34 | super().__init__() 35 | self.detr = detr 36 | 37 | if freeze_detr: 38 | for p in self.parameters(): 39 | p.requires_grad_(False) 40 | 41 | hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead 42 | self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) 43 | self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) 44 | 45 | def forward(self, samples: NestedTensor): 46 | if isinstance(samples, (list, torch.Tensor)): 47 | samples = nested_tensor_from_tensor_list(samples) 48 | features, pos = self.detr.backbone(samples) 49 | 50 | bs = features[-1].tensors.shape[0] 51 | 52 | src, mask = features[-1].decompose() 53 | assert mask is not None 54 | src_proj = self.detr.input_proj(src) 55 | hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1]) 56 | 57 | outputs_class = self.detr.class_embed(hs) 58 | outputs_coord = self.detr.bbox_embed(hs).sigmoid() 59 | out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 60 | if self.detr.aux_loss: 61 | out['aux_outputs'] = self.detr._set_aux_loss(outputs_class, outputs_coord) 62 | 63 | # FIXME h_boxes takes the last one computed, keep this in mind 64 | bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) 65 | 66 | seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors]) 67 | outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, 
seg_masks.shape[-2], seg_masks.shape[-1]) 68 | 69 | out["pred_masks"] = outputs_seg_masks 70 | return out 71 | 72 | 73 | def _expand(tensor, length: int): 74 | return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) 75 | 76 | 77 | class MaskHeadSmallConv(nn.Module): 78 | """ 79 | Simple convolutional head, using group norm. 80 | Upsampling is done using a FPN approach 81 | """ 82 | 83 | def __init__(self, dim, fpn_dims, context_dim): 84 | super().__init__() 85 | 86 | inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] 87 | self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) 88 | self.gn1 = torch.nn.GroupNorm(8, dim) 89 | self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) 90 | self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) 91 | self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) 92 | self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) 93 | self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) 94 | self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) 95 | self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) 96 | self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) 97 | self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) 98 | 99 | self.dim = dim 100 | 101 | self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) 102 | self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) 103 | self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) 104 | 105 | for m in self.modules(): 106 | if isinstance(m, nn.Conv2d): 107 | nn.init.kaiming_uniform_(m.weight, a=1) 108 | nn.init.constant_(m.bias, 0) 109 | 110 | def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]): 111 | x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) 112 | 113 | x = self.lay1(x) 114 | x = self.gn1(x) 115 | x = F.relu(x) 116 | x = self.lay2(x) 117 | x = self.gn2(x) 118 | x = F.relu(x) 119 | 120 | cur_fpn = self.adapter1(fpns[0]) 121 | if cur_fpn.size(0) != x.size(0): 122 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 123 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 124 | x = self.lay3(x) 125 | x = self.gn3(x) 126 | x = F.relu(x) 127 | 128 | cur_fpn = self.adapter2(fpns[1]) 129 | if cur_fpn.size(0) != x.size(0): 130 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 131 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 132 | x = self.lay4(x) 133 | x = self.gn4(x) 134 | x = F.relu(x) 135 | 136 | cur_fpn = self.adapter3(fpns[2]) 137 | if cur_fpn.size(0) != x.size(0): 138 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 139 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 140 | x = self.lay5(x) 141 | x = self.gn5(x) 142 | x = F.relu(x) 143 | 144 | x = self.out_lay(x) 145 | return x 146 | 147 | 148 | class MHAttentionMap(nn.Module): 149 | """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" 150 | 151 | def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True): 152 | super().__init__() 153 | self.num_heads = num_heads 154 | self.hidden_dim = hidden_dim 155 | self.dropout = nn.Dropout(dropout) 156 | 157 | self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 158 | self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 159 | 160 | nn.init.zeros_(self.k_linear.bias) 161 | nn.init.zeros_(self.q_linear.bias) 162 | nn.init.xavier_uniform_(self.k_linear.weight) 163 
| nn.init.xavier_uniform_(self.q_linear.weight) 164 | self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 165 | 166 | def forward(self, q, k, mask: Optional[Tensor] = None): 167 | q = self.q_linear(q) 168 | k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) 169 | qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) 170 | kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) 171 | weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) 172 | 173 | if mask is not None: 174 | weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) 175 | weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size()) 176 | weights = self.dropout(weights) 177 | return weights 178 | 179 | 180 | def dice_loss(inputs, targets, num_boxes): 181 | """ 182 | Compute the DICE loss, similar to generalized IOU for masks 183 | Args: 184 | inputs: A float tensor of arbitrary shape. 185 | The predictions for each example. 186 | targets: A float tensor with the same shape as inputs. Stores the binary 187 | classification label for each element in inputs 188 | (0 for the negative class and 1 for the positive class). 189 | """ 190 | inputs = inputs.sigmoid() 191 | inputs = inputs.flatten(1) 192 | numerator = 2 * (inputs * targets).sum(1) 193 | denominator = inputs.sum(-1) + targets.sum(-1) 194 | loss = 1 - (numerator + 1) / (denominator + 1) 195 | return loss.sum() / num_boxes 196 | 197 | 198 | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): 199 | """ 200 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 201 | Args: 202 | inputs: A float tensor of arbitrary shape. 203 | The predictions for each example. 204 | targets: A float tensor with the same shape as inputs. Stores the binary 205 | classification label for each element in inputs 206 | (0 for the negative class and 1 for the positive class). 207 | alpha: (optional) Weighting factor in range (0,1) to balance 208 | positive vs negative examples. Default = -1 (no weighting). 209 | gamma: Exponent of the modulating factor (1 - p_t) to 210 | balance easy vs hard examples. 
211 | Returns: 212 | Loss tensor 213 | """ 214 | prob = inputs.sigmoid() 215 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 216 | p_t = prob * targets + (1 - prob) * (1 - targets) 217 | loss = ce_loss * ((1 - p_t) ** gamma) 218 | 219 | if alpha >= 0: 220 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 221 | loss = alpha_t * loss 222 | 223 | return loss.mean(1).sum() / num_boxes 224 | 225 | 226 | class PostProcessSegm(nn.Module): 227 | def __init__(self, threshold=0.5): 228 | super().__init__() 229 | self.threshold = threshold 230 | 231 | @torch.no_grad() 232 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes): 233 | assert len(orig_target_sizes) == len(max_target_sizes) 234 | max_h, max_w = max_target_sizes.max(0)[0].tolist() 235 | outputs_masks = outputs["pred_masks"].squeeze(2) 236 | outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) 237 | outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() 238 | 239 | for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): 240 | img_h, img_w = t[0], t[1] 241 | results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) 242 | results[i]["masks"] = F.interpolate( 243 | results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" 244 | ).byte() 245 | 246 | return results 247 | 248 | 249 | class PostProcessPanoptic(nn.Module): 250 | """This class converts the output of the model to the final panoptic result, in the format expected by the 251 | coco panoptic API """ 252 | 253 | def __init__(self, is_thing_map, threshold=0.85): 254 | """ 255 | Parameters: 256 | is_thing_map: This is a whose keys are the class ids, and the values a boolean indicating whether 257 | the class is a thing (True) or a stuff (False) class 258 | threshold: confidence threshold: segments with confidence lower than this will be deleted 259 | """ 260 | super().__init__() 261 | self.threshold = threshold 262 | self.is_thing_map = is_thing_map 263 | 264 | def forward(self, outputs, processed_sizes, target_sizes=None): 265 | """ This function computes the panoptic prediction from the model's predictions. 266 | Parameters: 267 | outputs: This is a dict coming directly from the model. See the model doc for the content. 268 | processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the 269 | model, ie the size after data augmentation but before batching. 270 | target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size 271 | of each prediction. 
If left to None, it will default to the processed_sizes 272 | """ 273 | if target_sizes is None: 274 | target_sizes = processed_sizes 275 | assert len(processed_sizes) == len(target_sizes) 276 | out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"] 277 | assert len(out_logits) == len(raw_masks) == len(target_sizes) 278 | preds = [] 279 | 280 | def to_tuple(tup): 281 | if isinstance(tup, tuple): 282 | return tup 283 | return tuple(tup.cpu().tolist()) 284 | 285 | for cur_logits, cur_masks, cur_boxes, size, target_size in zip( 286 | out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes 287 | ): 288 | # we filter empty queries and detection below threshold 289 | scores, labels = cur_logits.softmax(-1).max(-1) 290 | keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold) 291 | cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) 292 | cur_scores = cur_scores[keep] 293 | cur_classes = cur_classes[keep] 294 | cur_masks = cur_masks[keep] 295 | cur_masks = interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) 296 | cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) 297 | 298 | h, w = cur_masks.shape[-2:] 299 | assert len(cur_boxes) == len(cur_classes) 300 | 301 | # It may be that we have several predicted masks for the same stuff class. 302 | # In the following, we track the list of masks ids for each stuff class (they are merged later on) 303 | cur_masks = cur_masks.flatten(1) 304 | stuff_equiv_classes = defaultdict(lambda: []) 305 | for k, label in enumerate(cur_classes): 306 | if not self.is_thing_map[label.item()]: 307 | stuff_equiv_classes[label.item()].append(k) 308 | 309 | def get_ids_area(masks, scores, dedup=False): 310 | # This helper function creates the final panoptic segmentation image 311 | # It also returns the area of the masks that appears on the image 312 | 313 | m_id = masks.transpose(0, 1).softmax(-1) 314 | 315 | if m_id.shape[-1] == 0: 316 | # We didn't detect any mask :( 317 | m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) 318 | else: 319 | m_id = m_id.argmax(-1).view(h, w) 320 | 321 | if dedup: 322 | # Merge the masks corresponding to the same stuff class 323 | for equiv in stuff_equiv_classes.values(): 324 | if len(equiv) > 1: 325 | for eq_id in equiv: 326 | m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) 327 | 328 | final_h, final_w = to_tuple(target_size) 329 | 330 | seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) 331 | seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) 332 | 333 | np_seg_img = ( 334 | torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy() 335 | ) 336 | m_id = torch.from_numpy(rgb2id(np_seg_img)) 337 | 338 | area = [] 339 | for i in range(len(scores)): 340 | area.append(m_id.eq(i).sum().item()) 341 | return area, seg_img 342 | 343 | area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) 344 | if cur_classes.numel() > 0: 345 | # We know filter empty masks as long as we find some 346 | while True: 347 | filtered_small = torch.as_tensor( 348 | [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device 349 | ) 350 | if filtered_small.any().item(): 351 | cur_scores = cur_scores[~filtered_small] 352 | cur_classes = cur_classes[~filtered_small] 353 | cur_masks = cur_masks[~filtered_small] 354 | area, seg_img = get_ids_area(cur_masks, cur_scores) 355 | else: 356 | break 357 | 358 | else: 359 | cur_classes = 
torch.ones(1, dtype=torch.long, device=cur_classes.device) 360 | 361 | segments_info = [] 362 | for i, a in enumerate(area): 363 | cat = cur_classes[i].item() 364 | segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a}) 365 | del cur_classes 366 | 367 | with io.BytesIO() as out: 368 | seg_img.save(out, format="PNG") 369 | predictions = {"png_string": out.getvalue(), "segments_info": segments_info} 370 | preds.append(predictions) 371 | return preds 372 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Copied from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Misc functions, including distributed helpers. 12 | 13 | Mostly copy-paste from torchvision references. 14 | """ 15 | import os 16 | import subprocess 17 | import time 18 | from collections import defaultdict, deque 19 | import datetime 20 | import pickle 21 | from typing import Optional, List 22 | 23 | import torch 24 | import torch.distributed as dist 25 | from torch import Tensor 26 | 27 | # needed due to empty tensor bug in pytorch and torchvision 0.5 28 | import torchvision 29 | if float(torchvision.__version__.split(".")[1]) < 7.0: 30 | from torchvision.ops import _new_empty_tensor 31 | from torchvision.ops.misc import _output_size 32 | 33 | 34 | class SmoothedValue(object): 35 | """Track a series of values and provide access to smoothed values over a 36 | window or the global series average. 37 | """ 38 | 39 | def __init__(self, window_size=20, fmt=None): 40 | if fmt is None: 41 | fmt = "{median:.4f} ({global_avg:.4f})" 42 | self.deque = deque(maxlen=window_size) 43 | self.total = 0.0 44 | self.count = 0 45 | self.fmt = fmt 46 | 47 | def update(self, value, n=1): 48 | self.deque.append(value) 49 | self.count += n 50 | self.total += value * n 51 | 52 | def synchronize_between_processes(self): 53 | """ 54 | Warning: does not synchronize the deque! 
55 | """ 56 | if not is_dist_avail_and_initialized(): 57 | return 58 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 59 | dist.barrier() 60 | dist.all_reduce(t) 61 | t = t.tolist() 62 | self.count = int(t[0]) 63 | self.total = t[1] 64 | 65 | @property 66 | def median(self): 67 | d = torch.tensor(list(self.deque)) 68 | return d.median().item() 69 | 70 | @property 71 | def avg(self): 72 | d = torch.tensor(list(self.deque), dtype=torch.float32) 73 | return d.mean().item() 74 | 75 | @property 76 | def global_avg(self): 77 | return self.total / self.count 78 | 79 | @property 80 | def max(self): 81 | return max(self.deque) 82 | 83 | @property 84 | def value(self): 85 | return self.deque[-1] 86 | 87 | def __str__(self): 88 | return self.fmt.format( 89 | median=self.median, 90 | avg=self.avg, 91 | global_avg=self.global_avg, 92 | max=self.max, 93 | value=self.value) 94 | 95 | 96 | def all_gather(data): 97 | """ 98 | Run all_gather on arbitrary picklable data (not necessarily tensors) 99 | Args: 100 | data: any picklable object 101 | Returns: 102 | list[data]: list of data gathered from each rank 103 | """ 104 | world_size = get_world_size() 105 | if world_size == 1: 106 | return [data] 107 | 108 | # serialized to a Tensor 109 | buffer = pickle.dumps(data) 110 | storage = torch.ByteStorage.from_buffer(buffer) 111 | tensor = torch.ByteTensor(storage).to("cuda") 112 | 113 | # obtain Tensor size of each rank 114 | local_size = torch.tensor([tensor.numel()], device="cuda") 115 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 116 | dist.all_gather(size_list, local_size) 117 | size_list = [int(size.item()) for size in size_list] 118 | max_size = max(size_list) 119 | 120 | # receiving Tensor from all ranks 121 | # we pad the tensor because torch all_gather does not support 122 | # gathering tensors of different shapes 123 | tensor_list = [] 124 | for _ in size_list: 125 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 126 | if local_size != max_size: 127 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 128 | tensor = torch.cat((tensor, padding), dim=0) 129 | dist.all_gather(tensor_list, tensor) 130 | 131 | data_list = [] 132 | for size, tensor in zip(size_list, tensor_list): 133 | buffer = tensor.cpu().numpy().tobytes()[:size] 134 | data_list.append(pickle.loads(buffer)) 135 | 136 | return data_list 137 | 138 | 139 | def reduce_dict(input_dict, average=True): 140 | """ 141 | Args: 142 | input_dict (dict): all the values will be reduced 143 | average (bool): whether to do average or sum 144 | Reduce the values in the dictionary from all processes so that all processes 145 | have the averaged results. Returns a dict with the same fields as 146 | input_dict, after reduction. 
147 | """ 148 | world_size = get_world_size() 149 | if world_size < 2: 150 | return input_dict 151 | with torch.no_grad(): 152 | names = [] 153 | values = [] 154 | # sort the keys so that they are consistent across processes 155 | for k in sorted(input_dict.keys()): 156 | names.append(k) 157 | values.append(input_dict[k]) 158 | values = torch.stack(values, dim=0) 159 | dist.all_reduce(values) 160 | if average: 161 | values /= world_size 162 | reduced_dict = {k: v for k, v in zip(names, values)} 163 | return reduced_dict 164 | 165 | 166 | class MetricLogger(object): 167 | def __init__(self, delimiter="\t"): 168 | self.meters = defaultdict(SmoothedValue) 169 | self.delimiter = delimiter 170 | 171 | def update(self, **kwargs): 172 | for k, v in kwargs.items(): 173 | if isinstance(v, torch.Tensor): 174 | v = v.item() 175 | assert isinstance(v, (float, int)) 176 | self.meters[k].update(v) 177 | 178 | def __getattr__(self, attr): 179 | if attr in self.meters: 180 | return self.meters[attr] 181 | if attr in self.__dict__: 182 | return self.__dict__[attr] 183 | raise AttributeError("'{}' object has no attribute '{}'".format( 184 | type(self).__name__, attr)) 185 | 186 | def __str__(self): 187 | loss_str = [] 188 | for name, meter in self.meters.items(): 189 | loss_str.append( 190 | "{}: {}".format(name, str(meter)) 191 | ) 192 | return self.delimiter.join(loss_str) 193 | 194 | def synchronize_between_processes(self): 195 | for meter in self.meters.values(): 196 | meter.synchronize_between_processes() 197 | 198 | def add_meter(self, name, meter): 199 | self.meters[name] = meter 200 | 201 | def log_every(self, iterable, print_freq, header=None): 202 | i = 0 203 | if not header: 204 | header = '' 205 | start_time = time.time() 206 | end = time.time() 207 | iter_time = SmoothedValue(fmt='{avg:.4f}') 208 | data_time = SmoothedValue(fmt='{avg:.4f}') 209 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 210 | if torch.cuda.is_available(): 211 | log_msg = self.delimiter.join([ 212 | header, 213 | '[{0' + space_fmt + '}/{1}]', 214 | 'eta: {eta}', 215 | '{meters}', 216 | 'time: {time}', 217 | 'data: {data}', 218 | 'max mem: {memory:.0f}' 219 | ]) 220 | else: 221 | log_msg = self.delimiter.join([ 222 | header, 223 | '[{0' + space_fmt + '}/{1}]', 224 | 'eta: {eta}', 225 | '{meters}', 226 | 'time: {time}', 227 | 'data: {data}' 228 | ]) 229 | MB = 1024.0 * 1024.0 230 | for obj in iterable: 231 | data_time.update(time.time() - end) 232 | yield obj 233 | iter_time.update(time.time() - end) 234 | if i % print_freq == 0 or i == len(iterable) - 1: 235 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 236 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 237 | if torch.cuda.is_available(): 238 | print(log_msg.format( 239 | i, len(iterable), eta=eta_string, 240 | meters=str(self), 241 | time=str(iter_time), data=str(data_time), 242 | memory=torch.cuda.max_memory_allocated() / MB)) 243 | else: 244 | print(log_msg.format( 245 | i, len(iterable), eta=eta_string, 246 | meters=str(self), 247 | time=str(iter_time), data=str(data_time))) 248 | i += 1 249 | end = time.time() 250 | total_time = time.time() - start_time 251 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 252 | print('{} Total time: {} ({:.4f} s / it)'.format( 253 | header, total_time_str, total_time / len(iterable))) 254 | 255 | 256 | def get_sha(): 257 | cwd = os.path.dirname(os.path.abspath(__file__)) 258 | 259 | def _run(command): 260 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 
261 | sha = 'N/A' 262 | diff = "clean" 263 | branch = 'N/A' 264 | try: 265 | sha = _run(['git', 'rev-parse', 'HEAD']) 266 | subprocess.check_output(['git', 'diff'], cwd=cwd) 267 | diff = _run(['git', 'diff-index', 'HEAD']) 268 | diff = "has uncommited changes" if diff else "clean" 269 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 270 | except Exception: 271 | pass 272 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 273 | return message 274 | 275 | 276 | def collate_fn(batch): 277 | batch = list(zip(*batch)) 278 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 279 | return tuple(batch) 280 | 281 | 282 | def _max_by_axis(the_list): 283 | # type: (List[List[int]]) -> List[int] 284 | maxes = the_list[0] 285 | for sublist in the_list[1:]: 286 | for index, item in enumerate(sublist): 287 | maxes[index] = max(maxes[index], item) 288 | return maxes 289 | 290 | 291 | class NestedTensor(object): 292 | def __init__(self, tensors, mask: Optional[Tensor]): 293 | self.tensors = tensors 294 | self.mask = mask 295 | 296 | def to(self, device): 297 | # type: (Device) -> NestedTensor # noqa 298 | cast_tensor = self.tensors.to(device) 299 | mask = self.mask 300 | if mask is not None: 301 | assert mask is not None 302 | cast_mask = mask.to(device) 303 | else: 304 | cast_mask = None 305 | return NestedTensor(cast_tensor, cast_mask) 306 | 307 | def decompose(self): 308 | return self.tensors, self.mask 309 | 310 | def __repr__(self): 311 | return str(self.tensors) 312 | 313 | 314 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 315 | # TODO make this more general 316 | if tensor_list[0].ndim == 3: 317 | if torchvision._is_tracing(): 318 | # nested_tensor_from_tensor_list() does not export well to ONNX 319 | # call _onnx_nested_tensor_from_tensor_list() instead 320 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 321 | 322 | # TODO make it support different-sized images 323 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 324 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 325 | batch_shape = [len(tensor_list)] + max_size 326 | b, c, h, w = batch_shape 327 | dtype = tensor_list[0].dtype 328 | device = tensor_list[0].device 329 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 330 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 331 | for img, pad_img, m in zip(tensor_list, tensor, mask): 332 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 333 | m[: img.shape[1], :img.shape[2]] = False 334 | else: 335 | raise ValueError('not supported') 336 | return NestedTensor(tensor, mask) 337 | 338 | 339 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 340 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
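# --------------------------------------------------------------------------
# Illustrative usage of the padding helpers defined above (not part of the
# original file). A minimal sketch: `_demo_nested_tensor` is a hypothetical
# name added only to show how `nested_tensor_from_tensor_list` pads a batch
# of differently sized images and builds the boolean padding mask.
# --------------------------------------------------------------------------
def _demo_nested_tensor():
    # Two 3-channel images with different spatial sizes.
    imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
    nt = nested_tensor_from_tensor_list(imgs)
    tensors, mask = nt.decompose()
    # Both images are zero-padded to the per-batch maximum size (512 x 640);
    # mask is True on padded pixels and False on valid ones.
    assert tensors.shape == (2, 3, 512, 640)
    assert mask.shape == (2, 512, 640)
    assert mask.dtype == torch.bool
    return nt
# (The ONNX-traceable variant described in the comment above follows.)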
341 | @torch.jit.unused 342 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 343 | max_size = [] 344 | for i in range(tensor_list[0].dim()): 345 | max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) 346 | max_size.append(max_size_i) 347 | max_size = tuple(max_size) 348 | 349 | # work around for 350 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 351 | # m[: img.shape[1], :img.shape[2]] = False 352 | # which is not yet supported in onnx 353 | padded_imgs = [] 354 | padded_masks = [] 355 | for img in tensor_list: 356 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 357 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 358 | padded_imgs.append(padded_img) 359 | 360 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 361 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 362 | padded_masks.append(padded_mask.to(torch.bool)) 363 | 364 | tensor = torch.stack(padded_imgs) 365 | mask = torch.stack(padded_masks) 366 | 367 | return NestedTensor(tensor, mask=mask) 368 | 369 | 370 | def setup_for_distributed(is_master): 371 | """ 372 | This function disables printing when not in master process 373 | """ 374 | import builtins as __builtin__ 375 | builtin_print = __builtin__.print 376 | 377 | def print(*args, **kwargs): 378 | force = kwargs.pop('force', False) 379 | if is_master or force: 380 | builtin_print(*args, **kwargs) 381 | 382 | __builtin__.print = print 383 | 384 | 385 | def is_dist_avail_and_initialized(): 386 | if not dist.is_available(): 387 | return False 388 | if not dist.is_initialized(): 389 | return False 390 | return True 391 | 392 | 393 | def get_world_size(): 394 | if not is_dist_avail_and_initialized(): 395 | return 1 396 | return dist.get_world_size() 397 | 398 | 399 | def get_rank(): 400 | if not is_dist_avail_and_initialized(): 401 | return 0 402 | return dist.get_rank() 403 | 404 | 405 | def is_main_process(): 406 | return get_rank() == 0 407 | 408 | 409 | def save_on_master(*args, **kwargs): 410 | if is_main_process(): 411 | torch.save(*args, **kwargs) 412 | 413 | 414 | def init_distributed_mode(args): 415 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 416 | args.rank = int(os.environ["RANK"]) 417 | args.world_size = int(os.environ['WORLD_SIZE']) 418 | args.gpu = int(os.environ['LOCAL_RANK']) 419 | elif 'SLURM_PROCID' in os.environ: 420 | args.rank = int(os.environ['SLURM_PROCID']) 421 | args.gpu = args.rank % torch.cuda.device_count() 422 | else: 423 | print('Not using distributed mode') 424 | args.distributed = False 425 | return 426 | 427 | args.distributed = True 428 | 429 | torch.cuda.set_device(args.gpu) 430 | args.dist_backend = 'nccl' 431 | print('| distributed init (rank {}): {}'.format( 432 | args.rank, args.dist_url), flush=True) 433 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 434 | world_size=args.world_size, rank=args.rank) 435 | torch.distributed.barrier() 436 | setup_for_distributed(args.rank == 0) 437 | 438 | 439 | @torch.no_grad() 440 | def accuracy(output, target, topk=(1,)): 441 | """Computes the precision@k for the specified values of k""" 442 | if target.numel() == 0: 443 | return [torch.zeros([], device=output.device)] 444 | maxk = max(topk) 445 | batch_size = target.size(0) 446 | 447 | _, pred = output.topk(maxk, 1, True, True) 448 | pred = pred.t() 449 | correct = 
pred.eq(target.view(1, -1).expand_as(pred)) 450 | 451 | res = [] 452 | for k in topk: 453 | correct_k = correct[:k].view(-1).float().sum(0) 454 | res.append(correct_k.mul_(100.0 / batch_size)) 455 | return res 456 | 457 | 458 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 459 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 460 | """ 461 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 462 | This will eventually be supported natively by PyTorch, and this 463 | class can go away. 464 | """ 465 | if float(torchvision.__version__.split(".")[1]) < 7.0: 466 | if input.numel() > 0: 467 | return torch.nn.functional.interpolate( 468 | input, size, scale_factor, mode, align_corners 469 | ) 470 | 471 | output_shape = _output_size(2, input, size, scale_factor) 472 | output_shape = list(input.shape[:-2]) + list(output_shape) 473 | return _new_empty_tensor(input, output_shape) 474 | else: 475 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 476 | 477 | def inverse_sigmoid(x, eps=1e-5): 478 | x = x.clamp(min=0, max=1) 479 | x1 = x.clamp(min=eps) 480 | x2 = (1 - x).clamp(min=eps) 481 | return torch.log(x1/x2) -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR Transformer class. 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
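# --------------------------------------------------------------------------
# Illustrative round trip through the `inverse_sigmoid` helper defined at the
# end of util/misc.py above (not part of the original file). A minimal
# sketch: `_demo_inverse_sigmoid` is a hypothetical name; it shows that the
# helper is the logit function with clamping, so values at or near 0 and 1
# stay finite (the eps clamp bounds the output by roughly +/- log(1/eps)).
# --------------------------------------------------------------------------
def _demo_inverse_sigmoid():
    import torch
    from util.misc import inverse_sigmoid

    x = torch.tensor([0.0, 0.25, 0.5, 0.9, 1.0])
    y = inverse_sigmoid(x, eps=1e-5)
    # Away from the boundaries the round trip is numerically exact.
    assert torch.allclose(y[1:4].sigmoid(), x[1:4], atol=1e-6)
    # At the boundaries the clamp keeps the logits finite instead of +/- inf.
    assert torch.isfinite(y).all()
    return y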
8 | # ------------------------------------------------------------------------ 9 | 10 | import math 11 | import copy 12 | from typing import Optional, List 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import nn, Tensor 17 | from .attention import MultiheadAttention 18 | 19 | class MLP(nn.Module): 20 | """ Very simple multi-layer perceptron (also called FFN)""" 21 | 22 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 23 | super().__init__() 24 | self.num_layers = num_layers 25 | h = [hidden_dim] * (num_layers - 1) 26 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 27 | 28 | def forward(self, x): 29 | for i, layer in enumerate(self.layers): 30 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 31 | return x 32 | 33 | def gen_sineembed_for_position(pos_tensor): 34 | # n_query, bs, _ = pos_tensor.size() 35 | # sineembed_tensor = torch.zeros(n_query, bs, 256) 36 | scale = 2 * math.pi 37 | dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) 38 | dim_t = 10000 ** (2 * (dim_t // 2) / 128) 39 | x_embed = pos_tensor[:, :, 0] * scale 40 | y_embed = pos_tensor[:, :, 1] * scale 41 | pos_x = x_embed[:, :, None] / dim_t 42 | pos_y = y_embed[:, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 44 | pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) 45 | pos = torch.cat((pos_y, pos_x), dim=2) 46 | return pos 47 | 48 | class Transformer(nn.Module): 49 | 50 | def __init__(self, d_model=512, nhead=8, num_queries=300, num_encoder_layers=6, 51 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 52 | activation="relu", normalize_before=False, 53 | return_intermediate_dec=False): 54 | super().__init__() 55 | 56 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 57 | dropout, activation, normalize_before) 58 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 59 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 60 | 61 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 62 | dropout, activation, normalize_before) 63 | decoder_norm = nn.LayerNorm(d_model) 64 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 65 | return_intermediate=return_intermediate_dec, 66 | d_model=d_model) 67 | 68 | self._reset_parameters() 69 | 70 | self.d_model = d_model 71 | self.nhead = nhead 72 | self.dec_layers = num_decoder_layers 73 | 74 | def _reset_parameters(self): 75 | for p in self.parameters(): 76 | if p.dim() > 1: 77 | nn.init.xavier_uniform_(p) 78 | 79 | def forward(self, src, mask, query_embed, pos_embed): 80 | # flatten NxCxHxW to HWxNxC 81 | bs, c, h, w = src.shape 82 | src = src.flatten(2).permute(2, 0, 1) 83 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 84 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 85 | mask = mask.flatten(1) 86 | 87 | tgt = torch.zeros_like(query_embed) 88 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 89 | hs, references = self.decoder(tgt, memory, memory_key_padding_mask=mask, 90 | pos=pos_embed, query_pos=query_embed) 91 | return hs, references 92 | 93 | 94 | class TransformerEncoder(nn.Module): 95 | 96 | def __init__(self, encoder_layer, num_layers, norm=None): 97 | super().__init__() 98 | self.layers = _get_clones(encoder_layer, num_layers) 99 | self.num_layers = num_layers 100 | 
self.norm = norm 101 | 102 | def forward(self, src, 103 | mask: Optional[Tensor] = None, 104 | src_key_padding_mask: Optional[Tensor] = None, 105 | pos: Optional[Tensor] = None): 106 | output = src 107 | 108 | for layer in self.layers: 109 | output = layer(output, src_mask=mask, 110 | src_key_padding_mask=src_key_padding_mask, pos=pos) 111 | 112 | if self.norm is not None: 113 | output = self.norm(output) 114 | 115 | return output 116 | 117 | 118 | class TransformerDecoder(nn.Module): 119 | 120 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False, d_model=256): 121 | super().__init__() 122 | self.layers = _get_clones(decoder_layer, num_layers) 123 | self.num_layers = num_layers 124 | self.norm = norm 125 | self.return_intermediate = return_intermediate 126 | self.query_scale = MLP(d_model, d_model, d_model, 2) 127 | self.ref_point_head = MLP(d_model, d_model, 2, 2) 128 | for layer_id in range(num_layers - 1): 129 | self.layers[layer_id + 1].ca_qpos_proj = None 130 | 131 | def forward(self, tgt, memory, 132 | tgt_mask: Optional[Tensor] = None, 133 | memory_mask: Optional[Tensor] = None, 134 | tgt_key_padding_mask: Optional[Tensor] = None, 135 | memory_key_padding_mask: Optional[Tensor] = None, 136 | pos: Optional[Tensor] = None, 137 | query_pos: Optional[Tensor] = None): 138 | output = tgt 139 | 140 | intermediate = [] 141 | reference_points_before_sigmoid = self.ref_point_head(query_pos) # [num_queries, batch_size, 2] 142 | reference_points = reference_points_before_sigmoid.sigmoid().transpose(0, 1) 143 | 144 | for layer_id, layer in enumerate(self.layers): 145 | obj_center = reference_points[..., :2].transpose(0, 1) # [num_queries, batch_size, 2] 146 | 147 | # For the first decoder layer, we do not apply transformation over p_s 148 | if layer_id == 0: 149 | pos_transformation = 1 150 | else: 151 | pos_transformation = self.query_scale(output) 152 | 153 | # get sine embedding for the query vector 154 | query_sine_embed = gen_sineembed_for_position(obj_center) 155 | # apply transformation 156 | query_sine_embed = query_sine_embed * pos_transformation 157 | output = layer(output, memory, tgt_mask=tgt_mask, 158 | memory_mask=memory_mask, 159 | tgt_key_padding_mask=tgt_key_padding_mask, 160 | memory_key_padding_mask=memory_key_padding_mask, 161 | pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed, 162 | is_first=(layer_id == 0)) 163 | if self.return_intermediate: 164 | intermediate.append(self.norm(output)) 165 | 166 | if self.norm is not None: 167 | output = self.norm(output) 168 | if self.return_intermediate: 169 | intermediate.pop() 170 | intermediate.append(output) 171 | 172 | if self.return_intermediate: 173 | return [torch.stack(intermediate).transpose(1, 2), reference_points] 174 | 175 | return output.unsqueeze(0) 176 | 177 | 178 | class TransformerEncoderLayer(nn.Module): 179 | 180 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 181 | activation="relu", normalize_before=False): 182 | super().__init__() 183 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 184 | # Implementation of Feedforward model 185 | self.linear1 = nn.Linear(d_model, dim_feedforward) 186 | self.dropout = nn.Dropout(dropout) 187 | self.linear2 = nn.Linear(dim_feedforward, d_model) 188 | 189 | self.norm1 = nn.LayerNorm(d_model) 190 | self.norm2 = nn.LayerNorm(d_model) 191 | self.dropout1 = nn.Dropout(dropout) 192 | self.dropout2 = nn.Dropout(dropout) 193 | 194 | self.activation = _get_activation_fn(activation) 195 | 
self.normalize_before = normalize_before 196 | 197 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 198 | return tensor if pos is None else tensor + pos 199 | 200 | def forward_post(self, 201 | src, 202 | src_mask: Optional[Tensor] = None, 203 | src_key_padding_mask: Optional[Tensor] = None, 204 | pos: Optional[Tensor] = None): 205 | q = k = self.with_pos_embed(src, pos) 206 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 207 | key_padding_mask=src_key_padding_mask)[0] 208 | src = src + self.dropout1(src2) 209 | src = self.norm1(src) 210 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 211 | src = src + self.dropout2(src2) 212 | src = self.norm2(src) 213 | return src 214 | 215 | def forward_pre(self, src, 216 | src_mask: Optional[Tensor] = None, 217 | src_key_padding_mask: Optional[Tensor] = None, 218 | pos: Optional[Tensor] = None): 219 | src2 = self.norm1(src) 220 | q = k = self.with_pos_embed(src2, pos) 221 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 222 | key_padding_mask=src_key_padding_mask)[0] 223 | src = src + self.dropout1(src2) 224 | src2 = self.norm2(src) 225 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 226 | src = src + self.dropout2(src2) 227 | return src 228 | 229 | def forward(self, src, 230 | src_mask: Optional[Tensor] = None, 231 | src_key_padding_mask: Optional[Tensor] = None, 232 | pos: Optional[Tensor] = None): 233 | if self.normalize_before: 234 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 235 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 236 | 237 | 238 | class TransformerDecoderLayer(nn.Module): 239 | 240 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 241 | activation="relu", normalize_before=False): 242 | super().__init__() 243 | # Decoder Self-Attention 244 | self.sa_qcontent_proj = nn.Linear(d_model, d_model) 245 | self.sa_qpos_proj = nn.Linear(d_model, d_model) 246 | self.sa_kcontent_proj = nn.Linear(d_model, d_model) 247 | self.sa_kpos_proj = nn.Linear(d_model, d_model) 248 | self.sa_v_proj = nn.Linear(d_model, d_model) 249 | self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, vdim=d_model) 250 | 251 | # Decoder Cross-Attention 252 | self.ca_qcontent_proj = nn.Linear(d_model, d_model) 253 | self.ca_qpos_proj = nn.Linear(d_model, d_model) 254 | self.ca_kcontent_proj = nn.Linear(d_model, d_model) 255 | self.ca_kpos_proj = nn.Linear(d_model, d_model) 256 | self.ca_v_proj = nn.Linear(d_model, d_model) 257 | self.ca_qpos_sine_proj = nn.Linear(d_model, d_model) 258 | self.cross_attn = MultiheadAttention(d_model*2, nhead, dropout=dropout, vdim=d_model) 259 | 260 | self.nhead = nhead 261 | 262 | # Implementation of Feedforward model 263 | self.linear1 = nn.Linear(d_model, dim_feedforward) 264 | self.dropout = nn.Dropout(dropout) 265 | self.linear2 = nn.Linear(dim_feedforward, d_model) 266 | 267 | self.norm1 = nn.LayerNorm(d_model) 268 | self.norm2 = nn.LayerNorm(d_model) 269 | self.norm3 = nn.LayerNorm(d_model) 270 | self.dropout1 = nn.Dropout(dropout) 271 | self.dropout2 = nn.Dropout(dropout) 272 | self.dropout3 = nn.Dropout(dropout) 273 | 274 | self.activation = _get_activation_fn(activation) 275 | self.normalize_before = normalize_before 276 | 277 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 278 | return tensor if pos is None else tensor + pos 279 | 280 | def forward_post(self, tgt, memory, 281 | tgt_mask: Optional[Tensor] = None, 282 | memory_mask: Optional[Tensor] = None, 
283 | tgt_key_padding_mask: Optional[Tensor] = None, 284 | memory_key_padding_mask: Optional[Tensor] = None, 285 | pos: Optional[Tensor] = None, 286 | query_pos: Optional[Tensor] = None, 287 | query_sine_embed = None, 288 | is_first = False): 289 | 290 | # ========== Begin of Self-Attention ============= 291 | # Apply projections here 292 | # shape: num_queries x batch_size x 256 293 | q_content = self.sa_qcontent_proj(tgt) # target is the input of the first decoder layer. zero by default. 294 | q_pos = self.sa_qpos_proj(query_pos) 295 | k_content = self.sa_kcontent_proj(tgt) 296 | k_pos = self.sa_kpos_proj(query_pos) 297 | v = self.sa_v_proj(tgt) 298 | 299 | num_queries, bs, n_model = q_content.shape 300 | hw, _, _ = k_content.shape 301 | 302 | q = q_content + q_pos 303 | k = k_content + k_pos 304 | 305 | tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask, 306 | key_padding_mask=tgt_key_padding_mask)[0] 307 | # ========== End of Self-Attention ============= 308 | 309 | tgt = tgt + self.dropout1(tgt2) 310 | tgt = self.norm1(tgt) 311 | 312 | # ========== Begin of Cross-Attention ============= 313 | # Apply projections here 314 | # shape: num_queries x batch_size x 256 315 | q_content = self.ca_qcontent_proj(tgt) 316 | k_content = self.ca_kcontent_proj(memory) 317 | v = self.ca_v_proj(memory) 318 | 319 | num_queries, bs, n_model = q_content.shape 320 | hw, _, _ = k_content.shape 321 | 322 | k_pos = self.ca_kpos_proj(pos) 323 | 324 | # For the first decoder layer, we concatenate the positional embedding predicted from 325 | # the object query (the positional embedding) into the original query (key) in DETR. 326 | if is_first: 327 | q_pos = self.ca_qpos_proj(query_pos) 328 | q = q_content + q_pos 329 | k = k_content + k_pos 330 | else: 331 | q = q_content 332 | k = k_content 333 | 334 | q = q.view(num_queries, bs, self.nhead, n_model//self.nhead) 335 | query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed) 336 | query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead) 337 | q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2) 338 | k = k.view(hw, bs, self.nhead, n_model//self.nhead) 339 | k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead) 340 | k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2) 341 | 342 | tgt2 = self.cross_attn(query=q, 343 | key=k, 344 | value=v, attn_mask=memory_mask, 345 | key_padding_mask=memory_key_padding_mask)[0] 346 | # ========== End of Cross-Attention ============= 347 | 348 | tgt = tgt + self.dropout2(tgt2) 349 | tgt = self.norm2(tgt) 350 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 351 | tgt = tgt + self.dropout3(tgt2) 352 | tgt = self.norm3(tgt) 353 | return tgt 354 | 355 | def forward_pre(self, tgt, memory, 356 | tgt_mask: Optional[Tensor] = None, 357 | memory_mask: Optional[Tensor] = None, 358 | tgt_key_padding_mask: Optional[Tensor] = None, 359 | memory_key_padding_mask: Optional[Tensor] = None, 360 | pos: Optional[Tensor] = None, 361 | query_pos: Optional[Tensor] = None): 362 | tgt2 = self.norm1(tgt) 363 | q = k = self.with_pos_embed(tgt2, query_pos) 364 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 365 | key_padding_mask=tgt_key_padding_mask)[0] 366 | tgt = tgt + self.dropout1(tgt2) 367 | tgt2 = self.norm2(tgt) 368 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 369 | key=self.with_pos_embed(memory, pos), 370 | value=memory, attn_mask=memory_mask, 371 | key_padding_mask=memory_key_padding_mask)[0] 372 | 
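        # NOTE: this pre-norm branch is inherited from DETR and is effectively dead code here:
        # forward() below raises NotImplementedError when normalize_before is True, and
        # `self.multihead_attn` is not defined on this layer (cross-attention lives in `self.cross_attn`).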
tgt = tgt + self.dropout2(tgt2) 373 | tgt2 = self.norm3(tgt) 374 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 375 | tgt = tgt + self.dropout3(tgt2) 376 | return tgt 377 | 378 | def forward(self, tgt, memory, 379 | tgt_mask: Optional[Tensor] = None, 380 | memory_mask: Optional[Tensor] = None, 381 | tgt_key_padding_mask: Optional[Tensor] = None, 382 | memory_key_padding_mask: Optional[Tensor] = None, 383 | pos: Optional[Tensor] = None, 384 | query_pos: Optional[Tensor] = None, 385 | query_sine_embed = None, 386 | is_first = False): 387 | if self.normalize_before: 388 | raise NotImplementedError 389 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 390 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 391 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 392 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos, query_sine_embed, is_first) 393 | 394 | 395 | def _get_clones(module, N): 396 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 397 | 398 | 399 | def build_transformer(args): 400 | return Transformer( 401 | d_model=args.hidden_dim, 402 | dropout=args.dropout, 403 | nhead=args.nheads, 404 | num_queries=args.num_queries, 405 | dim_feedforward=args.dim_feedforward, 406 | num_encoder_layers=args.enc_layers, 407 | num_decoder_layers=args.dec_layers, 408 | normalize_before=args.pre_norm, 409 | return_intermediate_dec=True, 410 | ) 411 | 412 | 413 | def _get_activation_fn(activation): 414 | """Return an activation function given a string""" 415 | if activation == "relu": 416 | return F.relu 417 | if activation == "gelu": 418 | return F.gelu 419 | if activation == "glu": 420 | return F.glu 421 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 422 | -------------------------------------------------------------------------------- /models/conditional_detr.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR model and criterion classes. 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 8 | # ------------------------------------------------------------------------ 9 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 10 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
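# --------------------------------------------------------------------------
# Illustrative shape check for the conditional transformer defined in
# models/transformer.py above (not part of the original file). A minimal
# sketch: `_demo_transformer_shapes` is a hypothetical name. d_model is set
# to 256 because gen_sineembed_for_position hard-codes a 2 x 128 = 256-dim
# sine embedding, matching the usual hidden_dim=256 setting.
# --------------------------------------------------------------------------
def _demo_transformer_shapes():
    import torch
    from models.transformer import Transformer

    bs, c, h, w, num_queries = 2, 256, 16, 20, 10
    model = Transformer(d_model=256, nhead=8, num_queries=num_queries,
                        num_encoder_layers=1, num_decoder_layers=2,
                        return_intermediate_dec=True)
    src = torch.randn(bs, c, h, w)                  # projected backbone feature map
    mask = torch.zeros(bs, h, w, dtype=torch.bool)  # False = valid (no padding)
    query_embed = torch.randn(num_queries, 256)     # learnable object queries
    pos_embed = torch.randn(bs, c, h, w)            # positional encoding of the memory
    hs, references = model(src, mask, query_embed, pos_embed)
    # hs: one output per decoder layer; references: per-query 2-D reference points in (0, 1).
    assert hs.shape == (2, bs, num_queries, 256)
    assert references.shape == (bs, num_queries, 2)
    return hs, references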
11 | # ------------------------------------------------------------------------ 12 | 13 | import math 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import nn 17 | 18 | from util import box_ops 19 | from util.misc import (NestedTensor, nested_tensor_from_tensor_list, 20 | accuracy, get_world_size, interpolate, 21 | is_dist_avail_and_initialized, inverse_sigmoid) 22 | 23 | from .backbone import build_backbone 24 | from .matcher import build_matcher 25 | from .segmentation import (DETRsegm, PostProcessPanoptic, PostProcessSegm, 26 | dice_loss, sigmoid_focal_loss) 27 | from .transformer import build_transformer 28 | 29 | 30 | class ConditionalDETR(nn.Module): 31 | """ This is the Conditional DETR module that performs object detection """ 32 | def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False): 33 | """ Initializes the model. 34 | Parameters: 35 | backbone: torch module of the backbone to be used. See backbone.py 36 | transformer: torch module of the transformer architecture. See transformer.py 37 | num_classes: number of object classes 38 | num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects 39 | Conditional DETR can detect in a single image. For COCO, we recommend 100 queries. 40 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 41 | """ 42 | super().__init__() 43 | self.num_queries = num_queries 44 | self.transformer = transformer 45 | hidden_dim = transformer.d_model 46 | self.class_embed = nn.Linear(hidden_dim, num_classes) 47 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 48 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 49 | self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) 50 | self.backbone = backbone 51 | self.aux_loss = aux_loss 52 | 53 | # init prior_prob setting for focal loss 54 | prior_prob = 0.01 55 | bias_value = -math.log((1 - prior_prob) / prior_prob) 56 | self.class_embed.bias.data = torch.ones(num_classes) * bias_value 57 | 58 | # init bbox_embed 59 | nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) 60 | nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) 61 | 62 | def forward(self, samples: NestedTensor): 63 | """ The forward expects a NestedTensor, which consists of: 64 | - samples.tensor: batched images, of shape [batch_size x 3 x H x W] 65 | - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels 66 | 67 | It returns a dict with the following elements: 68 | - "pred_logits": the classification logits (including no-object) for all queries. 69 | Shape = [batch_size x num_queries x num_classes] 70 | - "pred_boxes": The normalized box coordinates for all queries, represented as 71 | (center_x, center_y, width, height). These values are normalized in [0, 1], 72 | relative to the size of each individual image (disregarding possible padding). 73 | See PostProcess for information on how to retrieve the unnormalized bounding box. 74 | - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of 75 | dictionaries containing the two above keys for each decoder layer.
76 | """ 77 | if isinstance(samples, (list, torch.Tensor)): 78 | samples = nested_tensor_from_tensor_list(samples) 79 | features, pos = self.backbone(samples) 80 | 81 | src, mask = features[-1].decompose() 82 | assert mask is not None 83 | hs, reference = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1]) 84 | 85 | reference_before_sigmoid = inverse_sigmoid(reference) 86 | outputs_coords = [] 87 | for lvl in range(hs.shape[0]): 88 | tmp = self.bbox_embed(hs[lvl]) 89 | tmp[..., :2] += reference_before_sigmoid 90 | outputs_coord = tmp.sigmoid() 91 | outputs_coords.append(outputs_coord) 92 | outputs_coord = torch.stack(outputs_coords) 93 | 94 | outputs_class = self.class_embed(hs) 95 | out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} 96 | if self.aux_loss: 97 | out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) 98 | return out 99 | 100 | @torch.jit.unused 101 | def _set_aux_loss(self, outputs_class, outputs_coord): 102 | # this is a workaround to make torchscript happy, as torchscript 103 | # doesn't support dictionary with non-homogeneous values, such 104 | # as a dict having both a Tensor and a list. 105 | return [{'pred_logits': a, 'pred_boxes': b} 106 | for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] 107 | 108 | 109 | class SetCriterion(nn.Module): 110 | """ This class computes the loss for Conditional DETR. 111 | The process happens in two steps: 112 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 113 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 114 | """ 115 | def __init__(self, num_classes, matcher, weight_dict, focal_alpha, losses): 116 | """ Create the criterion. 117 | Parameters: 118 | num_classes: number of object categories, omitting the special no-object category 119 | matcher: module able to compute a matching between targets and proposals 120 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 121 | losses: list of all the losses to be applied. See get_loss for list of available losses. 
122 | focal_alpha: alpha in Focal Loss 123 | """ 124 | super().__init__() 125 | self.num_classes = num_classes 126 | self.matcher = matcher 127 | self.weight_dict = weight_dict 128 | self.losses = losses 129 | self.focal_alpha = focal_alpha 130 | 131 | 132 | def loss_labels(self, outputs, targets, indices, num_boxes, log=True): 133 | """Classification loss (Binary focal loss) 134 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 135 | """ 136 | assert 'pred_logits' in outputs 137 | src_logits = outputs['pred_logits'] 138 | 139 | idx = self._get_src_permutation_idx(indices) 140 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) 141 | target_classes = torch.full(src_logits.shape[:2], self.num_classes, 142 | dtype=torch.int64, device=src_logits.device) 143 | target_classes[idx] = target_classes_o 144 | 145 | target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2]+1], 146 | dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) 147 | target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) 148 | 149 | target_classes_onehot = target_classes_onehot[:,:,:-1] 150 | loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] 151 | losses = {'loss_ce': loss_ce} 152 | 153 | if log: 154 | # TODO this should probably be a separate loss, not hacked in this one here 155 | losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] 156 | return losses 157 | 158 | @torch.no_grad() 159 | def loss_cardinality(self, outputs, targets, indices, num_boxes): 160 | """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes 161 | This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients 162 | """ 163 | pred_logits = outputs['pred_logits'] 164 | device = pred_logits.device 165 | tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) 166 | # Count the number of predictions that are NOT "no-object" (which is the last class) 167 | card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) 168 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 169 | losses = {'cardinality_error': card_err} 170 | return losses 171 | 172 | def loss_boxes(self, outputs, targets, indices, num_boxes): 173 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss 174 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 175 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 176 | """ 177 | assert 'pred_boxes' in outputs 178 | idx = self._get_src_permutation_idx(indices) 179 | src_boxes = outputs['pred_boxes'][idx] 180 | target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) 181 | 182 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') 183 | 184 | losses = {} 185 | losses['loss_bbox'] = loss_bbox.sum() / num_boxes 186 | 187 | loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( 188 | box_ops.box_cxcywh_to_xyxy(src_boxes), 189 | box_ops.box_cxcywh_to_xyxy(target_boxes))) 190 | losses['loss_giou'] = loss_giou.sum() / num_boxes 191 | return losses 192 | 193 | def loss_masks(self, outputs, targets, indices, num_boxes): 194 | """Compute the losses related to the masks: the focal loss and the dice loss. 
195 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 196 | """ 197 | assert "pred_masks" in outputs 198 | 199 | src_idx = self._get_src_permutation_idx(indices) 200 | tgt_idx = self._get_tgt_permutation_idx(indices) 201 | src_masks = outputs["pred_masks"] 202 | src_masks = src_masks[src_idx] 203 | masks = [t["masks"] for t in targets] 204 | # TODO use valid to mask invalid areas due to padding in loss 205 | target_masks, valid = nested_tensor_from_tensor_list(masks).decompose() 206 | target_masks = target_masks.to(src_masks) 207 | target_masks = target_masks[tgt_idx] 208 | 209 | # upsample predictions to the target size 210 | src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], 211 | mode="bilinear", align_corners=False) 212 | src_masks = src_masks[:, 0].flatten(1) 213 | 214 | target_masks = target_masks.flatten(1) 215 | target_masks = target_masks.view(src_masks.shape) 216 | losses = { 217 | "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), 218 | "loss_dice": dice_loss(src_masks, target_masks, num_boxes), 219 | } 220 | return losses 221 | 222 | def _get_src_permutation_idx(self, indices): 223 | # permute predictions following indices 224 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 225 | src_idx = torch.cat([src for (src, _) in indices]) 226 | return batch_idx, src_idx 227 | 228 | def _get_tgt_permutation_idx(self, indices): 229 | # permute targets following indices 230 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 231 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 232 | return batch_idx, tgt_idx 233 | 234 | def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): 235 | loss_map = { 236 | 'labels': self.loss_labels, 237 | 'cardinality': self.loss_cardinality, 238 | 'boxes': self.loss_boxes, 239 | 'masks': self.loss_masks 240 | } 241 | assert loss in loss_map, f'do you really want to compute {loss} loss?' 242 | return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) 243 | 244 | def forward(self, outputs, targets): 245 | """ This performs the loss computation. 246 | Parameters: 247 | outputs: dict of tensors, see the output specification of the model for the format 248 | targets: list of dicts, such that len(targets) == batch_size. 249 | The expected keys in each dict depends on the losses applied, see each loss' doc 250 | """ 251 | outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} 252 | 253 | # Retrieve the matching between the outputs of the last layer and the targets 254 | indices = self.matcher(outputs_without_aux, targets) 255 | 256 | # Compute the average number of target boxes accross all nodes, for normalization purposes 257 | num_boxes = sum(len(t["labels"]) for t in targets) 258 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) 259 | if is_dist_avail_and_initialized(): 260 | torch.distributed.all_reduce(num_boxes) 261 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 262 | 263 | # Compute all the requested losses 264 | losses = {} 265 | for loss in self.losses: 266 | losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) 267 | 268 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 
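        # The per-layer loss keys are suffixed with the decoder-layer index (e.g. `loss_ce_0`)
        # so that they match the `aux_weight_dict` entries created in build() below.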
269 | if 'aux_outputs' in outputs: 270 | for i, aux_outputs in enumerate(outputs['aux_outputs']): 271 | indices = self.matcher(aux_outputs, targets) 272 | for loss in self.losses: 273 | if loss == 'masks': 274 | # Intermediate masks losses are too costly to compute, we ignore them. 275 | continue 276 | kwargs = {} 277 | if loss == 'labels': 278 | # Logging is enabled only for the last layer 279 | kwargs = {'log': False} 280 | l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) 281 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()} 282 | losses.update(l_dict) 283 | 284 | return losses 285 | 286 | 287 | class PostProcess(nn.Module): 288 | """ This module converts the model's output into the format expected by the coco api""" 289 | @torch.no_grad() 290 | def forward(self, outputs, target_sizes): 291 | """ Perform the computation 292 | Parameters: 293 | outputs: raw outputs of the model 294 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 295 | For evaluation, this must be the original image size (before any data augmentation) 296 | For visualization, this should be the image size after data augment, but before padding 297 | """ 298 | out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] 299 | 300 | assert len(out_logits) == len(target_sizes) 301 | assert target_sizes.shape[1] == 2 302 | 303 | prob = out_logits.sigmoid() 304 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), 100, dim=1) 305 | scores = topk_values 306 | topk_boxes = topk_indexes // out_logits.shape[2] 307 | labels = topk_indexes % out_logits.shape[2] 308 | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) 309 | boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) 310 | 311 | # and from relative [0, 1] to absolute [0, height] coordinates 312 | img_h, img_w = target_sizes.unbind(1) 313 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 314 | boxes = boxes * scale_fct[:, None, :] 315 | 316 | results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] 317 | 318 | return results 319 | 320 | 321 | class MLP(nn.Module): 322 | """ Very simple multi-layer perceptron (also called FFN)""" 323 | 324 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 325 | super().__init__() 326 | self.num_layers = num_layers 327 | h = [hidden_dim] * (num_layers - 1) 328 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 329 | 330 | def forward(self, x): 331 | for i, layer in enumerate(self.layers): 332 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 333 | return x 334 | 335 | 336 | def build(args): 337 | # the `num_classes` naming here is somewhat misleading. 338 | # it indeed corresponds to `max_obj_id + 1`, where max_obj_id 339 | # is the maximum id for a class in your dataset. For example, 340 | # COCO has a max_obj_id of 90, so we pass `num_classes` to be 91. 341 | # As another example, for a dataset that has a single class with id 1, 342 | # you should pass `num_classes` to be 2 (max_obj_id + 1). 
343 | # For more details on this, check the following discussion 344 | # https://github.com/facebookresearch/detr/issues/108#issuecomment-650269223 345 | num_classes = 20 if args.dataset_file != 'coco' else 91 346 | if args.dataset_file == "coco_panoptic": 347 | # for panoptic, we just add a num_classes that is large enough to hold 348 | # max_obj_id + 1, but the exact value doesn't really matter 349 | num_classes = 250 350 | device = torch.device(args.device) 351 | 352 | backbone = build_backbone(args) 353 | 354 | transformer = build_transformer(args) 355 | 356 | model = ConditionalDETR( 357 | backbone, 358 | transformer, 359 | num_classes=num_classes, 360 | num_queries=args.num_queries, 361 | aux_loss=args.aux_loss, 362 | ) 363 | if args.masks: 364 | model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None)) 365 | matcher = build_matcher(args) 366 | weight_dict = {'loss_ce': args.cls_loss_coef, 'loss_bbox': args.bbox_loss_coef} 367 | weight_dict['loss_giou'] = args.giou_loss_coef 368 | if args.masks: 369 | weight_dict["loss_mask"] = args.mask_loss_coef 370 | weight_dict["loss_dice"] = args.dice_loss_coef 371 | # TODO this is a hack 372 | if args.aux_loss: 373 | aux_weight_dict = {} 374 | for i in range(args.dec_layers - 1): 375 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) 376 | weight_dict.update(aux_weight_dict) 377 | 378 | losses = ['labels', 'boxes', 'cardinality'] 379 | if args.masks: 380 | losses += ["masks"] 381 | criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, 382 | focal_alpha=args.focal_alpha, losses=losses) 383 | criterion.to(device) 384 | postprocessors = {'bbox': PostProcess()} 385 | if args.masks: 386 | postprocessors['segm'] = PostProcessSegm() 387 | if args.dataset_file == "coco_panoptic": 388 | is_thing_map = {i: i <= 90 for i in range(201)} 389 | postprocessors["panoptic"] = PostProcessPanoptic(is_thing_map, threshold=0.85) 390 | 391 | return model, criterion, postprocessors 392 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Conditional DETR 3 | # Copyright (c) 2021 Microsoft. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from codes in torch.nn 7 | # ------------------------------------------------------------------------ 8 | 9 | """ 10 | MultiheadAttention that support query, key, and value to have different dimensions. 11 | Query, key, and value projections are removed. 
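In Conditional DETR the decoder layer projects and concatenates the content and spatial
parts itself, so this module expects already-projected inputs and keeps only the output projection.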
12 | 13 | Mostly copy-paste from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/activation.py#L873 14 | and https://github.com/pytorch/pytorch/blob/master/torch/nn/functional.py#L4837 15 | """ 16 | 17 | import copy 18 | from typing import Optional, List 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | from torch import nn, Tensor 23 | 24 | import warnings 25 | from typing import Tuple, Optional 26 | 27 | import torch 28 | from torch import Tensor 29 | if float(torch.__version__.split('.')[0]) == 0 or (float(torch.__version__.split('.')[0]) == 1 and float(torch.__version__.split('.')[1])) < 9: 30 | from torch.nn.modules.linear import _LinearWithBias 31 | else: 32 | from torch.nn.modules.linear import NonDynamicallyQuantizableLinear as _LinearWithBias 33 | from torch.nn.init import xavier_uniform_ 34 | from torch.nn.init import constant_ 35 | from torch.nn.init import xavier_normal_ 36 | from torch.nn.parameter import Parameter 37 | from torch.nn.modules.module import Module 38 | from torch.nn import functional as F 39 | 40 | import warnings 41 | import math 42 | 43 | from torch._C import _infer_size, _add_docstr 44 | from torch.nn import _reduction as _Reduction 45 | from torch.nn.modules import utils 46 | from torch.nn.modules.utils import _single, _pair, _triple, _list_with_default 47 | from torch.nn import grad 48 | from torch import _VF 49 | from torch._jit_internal import boolean_dispatch, List, Optional, _overload, Tuple 50 | if float(torch.__version__.split('.')[0]) == 0 or (float(torch.__version__.split('.')[0]) == 1 and float(torch.__version__.split('.')[1])) < 7: 51 | from torch._overrides import has_torch_function, handle_torch_function 52 | else: 53 | from torch.overrides import has_torch_function, handle_torch_function 54 | Tensor = torch.Tensor 55 | 56 | from torch.nn.functional import linear, pad, softmax, dropout 57 | 58 | class MultiheadAttention(Module): 59 | r"""Allows the model to jointly attend to information 60 | from different representation subspaces. 61 | See reference: Attention Is All You Need 62 | .. math:: 63 | \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O 64 | \text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) 65 | Args: 66 | embed_dim: total dimension of the model. 67 | num_heads: parallel attention heads. 68 | dropout: a Dropout layer on attn_output_weights. Default: 0.0. 69 | bias: add bias as module parameter. Default: True. 70 | add_bias_kv: add bias to the key and value sequences at dim=0. 71 | add_zero_attn: add a new batch of zeros to the key and 72 | value sequences at dim=1. 73 | kdim: total number of features in key. Default: None. 74 | vdim: total number of features in value. Default: None. 75 | Note: if kdim and vdim are None, they will be set to embed_dim such that 76 | query, key, and value have the same number of features. 
77 | Examples:: 78 | >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) 79 | >>> attn_output, attn_output_weights = multihead_attn(query, key, value) 80 | """ 81 | bias_k: Optional[torch.Tensor] 82 | bias_v: Optional[torch.Tensor] 83 | 84 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None): 85 | super(MultiheadAttention, self).__init__() 86 | self.embed_dim = embed_dim 87 | self.kdim = kdim if kdim is not None else embed_dim 88 | self.vdim = vdim if vdim is not None else embed_dim 89 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 90 | 91 | self.num_heads = num_heads 92 | self.dropout = dropout 93 | self.head_dim = embed_dim // num_heads 94 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 95 | 96 | self.out_proj = _LinearWithBias(vdim, vdim) 97 | 98 | self.in_proj_bias = None 99 | self.in_proj_weight = None 100 | self.bias_k = self.bias_v = None 101 | self.q_proj_weight = None 102 | self.k_proj_weight = None 103 | self.v_proj_weight = None 104 | 105 | self.add_zero_attn = add_zero_attn 106 | 107 | self._reset_parameters() 108 | 109 | def _reset_parameters(self): 110 | constant_(self.out_proj.bias, 0.) 111 | 112 | def __setstate__(self, state): 113 | # Support loading old MultiheadAttention checkpoints generated by v1.1.0 114 | if '_qkv_same_embed_dim' not in state: 115 | state['_qkv_same_embed_dim'] = True 116 | 117 | super(MultiheadAttention, self).__setstate__(state) 118 | 119 | def forward(self, query, key, value, key_padding_mask=None, 120 | need_weights=True, attn_mask=None): 121 | # type: (Tensor, Tensor, Tensor, Optional[Tensor], bool, Optional[Tensor]) -> Tuple[Tensor, Optional[Tensor]] 122 | r""" 123 | Args: 124 | query, key, value: map a query and a set of key-value pairs to an output. 125 | See "Attention Is All You Need" for more details. 126 | key_padding_mask: if provided, specified padding elements in the key will 127 | be ignored by the attention. When given a binary mask and a value is True, 128 | the corresponding value on the attention layer will be ignored. When given 129 | a byte mask and a value is non-zero, the corresponding value on the attention 130 | layer will be ignored 131 | need_weights: output attn_output_weights. 132 | attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all 133 | the batches while a 3D mask allows to specify a different mask for the entries of each batch. 134 | Shape: 135 | - Inputs: 136 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 137 | the embedding dimension. 138 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 139 | the embedding dimension. 140 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 141 | the embedding dimension. 142 | - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 143 | If a ByteTensor is provided, the non-zero positions will be ignored while the position 144 | with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the 145 | value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. 146 | - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 
147 | 3D mask :math:`(N*\text{num_heads}, L, S)` where N is the batch size, L is the target sequence length, 148 | S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked 149 | positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend 150 | while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` 151 | is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor 152 | is provided, it will be added to the attention weight. 153 | - Outputs: 154 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 155 | E is the embedding dimension. 156 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 157 | L is the target sequence length, S is the source sequence length. 158 | """ 159 | if not self._qkv_same_embed_dim: 160 | return multi_head_attention_forward( 161 | query, key, value, self.embed_dim, self.num_heads, 162 | self.in_proj_weight, self.in_proj_bias, 163 | self.bias_k, self.bias_v, self.add_zero_attn, 164 | self.dropout, self.out_proj.weight, self.out_proj.bias, 165 | training=self.training, 166 | key_padding_mask=key_padding_mask, need_weights=need_weights, 167 | attn_mask=attn_mask, use_separate_proj_weight=True, 168 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 169 | v_proj_weight=self.v_proj_weight, out_dim=self.vdim) 170 | else: 171 | return multi_head_attention_forward( 172 | query, key, value, self.embed_dim, self.num_heads, 173 | self.in_proj_weight, self.in_proj_bias, 174 | self.bias_k, self.bias_v, self.add_zero_attn, 175 | self.dropout, self.out_proj.weight, self.out_proj.bias, 176 | training=self.training, 177 | key_padding_mask=key_padding_mask, need_weights=need_weights, 178 | attn_mask=attn_mask, out_dim=self.vdim) 179 | 180 | 181 | def multi_head_attention_forward(query: Tensor, 182 | key: Tensor, 183 | value: Tensor, 184 | embed_dim_to_check: int, 185 | num_heads: int, 186 | in_proj_weight: Tensor, 187 | in_proj_bias: Tensor, 188 | bias_k: Optional[Tensor], 189 | bias_v: Optional[Tensor], 190 | add_zero_attn: bool, 191 | dropout_p: float, 192 | out_proj_weight: Tensor, 193 | out_proj_bias: Tensor, 194 | training: bool = True, 195 | key_padding_mask: Optional[Tensor] = None, 196 | need_weights: bool = True, 197 | attn_mask: Optional[Tensor] = None, 198 | use_separate_proj_weight: bool = False, 199 | q_proj_weight: Optional[Tensor] = None, 200 | k_proj_weight: Optional[Tensor] = None, 201 | v_proj_weight: Optional[Tensor] = None, 202 | static_k: Optional[Tensor] = None, 203 | static_v: Optional[Tensor] = None, 204 | out_dim: Optional[Tensor] = None 205 | ) -> Tuple[Tensor, Optional[Tensor]]: 206 | r""" 207 | Args: 208 | query, key, value: map a query and a set of key-value pairs to an output. 209 | See "Attention Is All You Need" for more details. 210 | embed_dim_to_check: total dimension of the model. 211 | num_heads: parallel attention heads. 212 | in_proj_weight, in_proj_bias: input projection weight and bias. 213 | bias_k, bias_v: bias of the key and value sequences to be added at dim=0. 214 | add_zero_attn: add a new batch of zeros to the key and 215 | value sequences at dim=1. 216 | dropout_p: probability of an element to be zeroed. 217 | out_proj_weight, out_proj_bias: the output projection weight and bias. 218 | training: apply dropout if is ``True``. 
219 | key_padding_mask: if provided, specified padding elements in the key will 220 | be ignored by the attention. This is a binary mask. When the value is True, 221 | the corresponding value on the attention layer will be filled with -inf. 222 | need_weights: output attn_output_weights. 223 | attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all 224 | the batches while a 3D mask allows specifying a different mask for the entries of each batch. 225 | use_separate_proj_weight: the function accepts the projection weights for query, key, 226 | and value in different forms. If false, in_proj_weight will be used, which is 227 | a combination of q_proj_weight, k_proj_weight, v_proj_weight. 228 | q_proj_weight, k_proj_weight, v_proj_weight, in_proj_bias: input projection weight and bias. 229 | static_k, static_v: static key and value used for attention operators. out_dim: total dimension of the values and of the attention output; the module above passes its ``vdim`` here, so it may differ from ``embed_dim_to_check``. 230 | Shape: 231 | Inputs: 232 | - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is 233 | the embedding dimension. 234 | - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is 235 | the embedding dimension. 236 | - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is 237 | the embedding dimension. 238 | - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. 239 | If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions 240 | will be unchanged. If a BoolTensor is provided, the positions with the 241 | value of ``True`` will be ignored while the positions with the value of ``False`` will be unchanged. 242 | - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. 243 | 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, 244 | S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked 245 | positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend 246 | while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` 247 | are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor 248 | is provided, it will be added to the attention weight. 249 | - static_k: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 250 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 251 | - static_v: :math:`(N*num_heads, S, E/num_heads)`, where S is the source sequence length, 252 | N is the batch size, E is the embedding dimension. E/num_heads is the head dimension. 253 | Outputs: 254 | - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, 255 | E is the embedding dimension. 256 | - attn_output_weights: :math:`(N, L, S)` where N is the batch size, 257 | L is the target sequence length, S is the source sequence length.
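    Example (an illustrative, shape-only sketch: the sizes below are hypothetical, and, as in the
    Conditional DETR decoder, ``query``/``key``/``value`` are assumed to be projected already by the
    caller, since the ``in_proj_*`` arguments are not applied inside this function)::

        >>> L, S, N, E, V = 4, 6, 2, 16, 8
        >>> q, k = torch.rand(L, N, E), torch.rand(S, N, E)
        >>> v = torch.rand(S, N, V)
        >>> out, w = multi_head_attention_forward(
        ...     q, k, v, E, 2, None, None, None, None, False, 0.0,
        ...     torch.rand(V, V), torch.rand(V), out_dim=V)
        >>> out.shape, w.shape   # -> (L, N, V) and (N, L, S)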
258 | """ 259 | if not torch.jit.is_scripting(): 260 | tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v, 261 | out_proj_weight, out_proj_bias) 262 | if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops): 263 | return handle_torch_function( 264 | multi_head_attention_forward, tens_ops, query, key, value, 265 | embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias, 266 | bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight, 267 | out_proj_bias, training=training, key_padding_mask=key_padding_mask, 268 | need_weights=need_weights, attn_mask=attn_mask, 269 | use_separate_proj_weight=use_separate_proj_weight, 270 | q_proj_weight=q_proj_weight, k_proj_weight=k_proj_weight, 271 | v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) 272 | tgt_len, bsz, embed_dim = query.size() 273 | assert embed_dim == embed_dim_to_check 274 | # allow MHA to have different sizes for the feature dimension 275 | assert key.size(0) == value.size(0) and key.size(1) == value.size(1) 276 | 277 | head_dim = embed_dim // num_heads 278 | v_head_dim = out_dim // num_heads 279 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 280 | scaling = float(head_dim) ** -0.5 281 | 282 | q = query * scaling 283 | k = key 284 | v = value 285 | 286 | if attn_mask is not None: 287 | assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \ 288 | attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \ 289 | 'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype) 290 | if attn_mask.dtype == torch.uint8: 291 | warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 292 | attn_mask = attn_mask.to(torch.bool) 293 | 294 | if attn_mask.dim() == 2: 295 | attn_mask = attn_mask.unsqueeze(0) 296 | if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: 297 | raise RuntimeError('The size of the 2D attn_mask is not correct.') 298 | elif attn_mask.dim() == 3: 299 | if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]: 300 | raise RuntimeError('The size of the 3D attn_mask is not correct.') 301 | else: 302 | raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim())) 303 | # attn_mask's dim is 3 now. 304 | 305 | # convert ByteTensor key_padding_mask to bool 306 | if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: 307 | warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.") 308 | key_padding_mask = key_padding_mask.to(torch.bool) 309 | 310 | if bias_k is not None and bias_v is not None: 311 | if static_k is None and static_v is None: 312 | k = torch.cat([k, bias_k.repeat(1, bsz, 1)]) 313 | v = torch.cat([v, bias_v.repeat(1, bsz, 1)]) 314 | if attn_mask is not None: 315 | attn_mask = pad(attn_mask, (0, 1)) 316 | if key_padding_mask is not None: 317 | key_padding_mask = pad(key_padding_mask, (0, 1)) 318 | else: 319 | assert static_k is None, "bias cannot be added to static key." 320 | assert static_v is None, "bias cannot be added to static value." 
321 | else: 322 | assert bias_k is None 323 | assert bias_v is None 324 | 325 | q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1) 326 | if k is not None: 327 | k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) 328 | if v is not None: 329 | v = v.contiguous().view(-1, bsz * num_heads, v_head_dim).transpose(0, 1) 330 | 331 | if static_k is not None: 332 | assert static_k.size(0) == bsz * num_heads 333 | assert static_k.size(2) == head_dim 334 | k = static_k 335 | 336 | if static_v is not None: 337 | assert static_v.size(0) == bsz * num_heads 338 | assert static_v.size(2) == v_head_dim 339 | v = static_v 340 | 341 | src_len = k.size(1) 342 | 343 | if key_padding_mask is not None: 344 | assert key_padding_mask.size(0) == bsz 345 | assert key_padding_mask.size(1) == src_len 346 | 347 | if add_zero_attn: 348 | src_len += 1 349 | k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1) 350 | v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1) 351 | if attn_mask is not None: 352 | attn_mask = pad(attn_mask, (0, 1)) 353 | if key_padding_mask is not None: 354 | key_padding_mask = pad(key_padding_mask, (0, 1)) 355 | 356 | attn_output_weights = torch.bmm(q, k.transpose(1, 2)) 357 | assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 358 | 359 | if attn_mask is not None: 360 | if attn_mask.dtype == torch.bool: 361 | attn_output_weights.masked_fill_(attn_mask, float('-inf')) 362 | else: 363 | attn_output_weights += attn_mask 364 | 365 | 366 | if key_padding_mask is not None: 367 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 368 | attn_output_weights = attn_output_weights.masked_fill( 369 | key_padding_mask.unsqueeze(1).unsqueeze(2), 370 | float('-inf'), 371 | ) 372 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 373 | 374 | attn_output_weights = softmax( 375 | attn_output_weights, dim=-1) 376 | attn_output_weights = dropout(attn_output_weights, p=dropout_p, training=training) 377 | 378 | attn_output = torch.bmm(attn_output_weights, v) 379 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, v_head_dim] 380 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, out_dim) 381 | attn_output = linear(attn_output, out_proj_weight, out_proj_bias) 382 | 383 | if need_weights: 384 | # average attention weights over heads 385 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 386 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 387 | else: 388 | return attn_output, None --------------------------------------------------------------------------------
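For orientation, here is a minimal usage sketch of the MultiheadAttention variant defined above. All sizes are hypothetical; the point of the extra vdim/out_dim plumbing is that queries and keys can live in a wider space than the values, which is how the Conditional DETR decoder drives its cross-attention (content and spatial embeddings concatenated for queries/keys, values kept at d_model). The q/k/v projections are assumed to be applied by the caller, since this module leaves its in_proj_* weights unset.

import torch

from models.attention import MultiheadAttention  # import path as laid out in this repository

d_model, nhead = 256, 8
attn = MultiheadAttention(embed_dim=2 * d_model, num_heads=nhead, vdim=d_model)

L, S, N = 300, 950, 2               # number of queries, encoder tokens, batch size (hypothetical)
q = torch.rand(L, N, 2 * d_model)   # already-projected content + spatial query embeddings
k = torch.rand(S, N, 2 * d_model)   # already-projected content + spatial key embeddings
v = torch.rand(S, N, d_model)       # already-projected value features, kept at d_model
out, weights = attn(q, k, v)        # attention scores computed in the 2*d_model query/key space
print(out.shape)                    # torch.Size([300, 2, 256])
print(weights.shape)                # torch.Size([2, 300, 950])

Because the output projection maps vdim to vdim, the attention output comes back at d_model and can be added to the decoder stream, even though the attention weights are computed in the wider query/key space.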