├── LICENSE ├── README.md ├── datasets ├── __init__.py ├── coco.py ├── coco_eval.py ├── coco_panoptic.py ├── panoptic_eval.py ├── transforms.py └── ytvos.py ├── engine.py ├── inference.py ├── main.py ├── models ├── __init__.py ├── backbone.py ├── dcn │ ├── __init__.py │ ├── deform_conv.py │ ├── deformable │ │ ├── deform_conv.cpp │ │ ├── deform_conv.h │ │ ├── deform_conv_cuda.cu │ │ └── deform_conv_cuda_kernel.cu │ ├── setup.py │ └── test_deform.py ├── matcher.py ├── position_encoding.py ├── segmentation.py ├── transformer.py └── vistr.py └── util ├── __init__.py ├── box_ops.py ├── detr_weights_to_vistr.py ├── misc.py └── plot_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2020 - present, Facebook, Inc 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## VisTR: End-to-End Video Instance Segmentation with Transformers 2 | 3 | This is the official implementation of the [VisTR paper](https://arxiv.org/abs/2011.14503): 4 | 5 |

6 | 7 |

8 | 9 | 10 | ### Installation 11 | We provide instructions on how to install dependencies via conda. 12 | First, clone the repository locally: 13 | ``` 14 | git clone https://github.com/Epiphqny/vistr.git 15 | ``` 16 | Then, install PyTorch 1.6 and torchvision 0.7: 17 | ``` 18 | conda install pytorch==1.6.0 torchvision==0.7.0 19 | ``` 20 | Install pycocotools: 21 | ``` 22 | conda install cython scipy 23 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 24 | pip install git+https://github.com/youtubevos/cocoapi.git#"egg=pycocotools&subdirectory=PythonAPI" 25 | ``` 26 | Compile the DCN module (requires GCC >= 5.3, CUDA >= 10.0): 27 | ``` 28 | cd models/dcn 29 | python setup.py build_ext --inplace 30 | ``` 31 | 32 | ### Preparation 33 | 34 | Download and extract the 2019 version of the YouTubeVIS train and val images with annotations from 35 | [CodaLab](https://competitions.codalab.org/competitions/20128#participate-get_data) or [YouTubeVIS](https://youtube-vos.org/dataset/vis/). 36 | We expect the directory structure to be the following: 37 | ``` 38 | VisTR 39 | ├── data 40 | │ ├── train 41 | │ ├── val 42 | │ ├── annotations 43 | │ │ ├── instances_train_sub.json 44 | │ │ ├── instances_val_sub.json 45 | ├── models 46 | ... 47 | ``` 48 | 49 | Download the pretrained DETR models (trained on COCO) from [Google Drive](https://drive.google.com/drive/folders/1DlN8uWHT2WaKruarGW2_XChhpZeI9MFG?usp=sharing) or [BaiduYun](https://pan.baidu.com/s/12omUNDRjhAeGZ5olqQPpHA) (passcode: alge) and save them to the pretrained path. 50 | 51 | 52 | ### Training 53 | 54 | Training the model requires a GPU with at least 32 GB of memory; we ran the experiments on a 32 GB V100 card. (The training resolution is limited by GPU memory. If you have a GPU with more memory and want to run the experiment at a higher resolution, please contact me.) 55 | 56 | To train the baseline VisTR on a single node with 8 GPUs for 18 epochs, run: 57 | ``` 58 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --backbone resnet101/50 --ytvos_path /path/to/ytvos --masks --pretrained_weights /path/to/pretrained_path 59 | ``` 60 | 61 | ### Inference 62 | 63 | ``` 64 | python inference.py --masks --model_path /path/to/model_weights --save_path /path/to/results.json 65 | ``` 66 | 67 | ### Models 68 | 69 | We provide baseline VisTR models and plan to include more in the future. AP is computed on the YouTubeVIS dataset by submitting the result json file to the [CodaLab](https://competitions.codalab.org/competitions/20128#results) evaluation server, and inference time is measured as pure model inference time (excluding data loading and post-processing).
| name | backbone | FPS | mask AP | model | result json zip |
| :---: | :---: | :---: | :---: | :---: | :---: |
| VisTR | R50 | 69.9 | 36.2 | vistr_r50.pth | vistr_r50.zip |
| VisTR | R101 | 57.7 | 40.1 | vistr_r101.pth | vistr_r101.zip |
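For a quick programmatic sanity check of the data layout described in the Preparation section, the dataset can also be built directly with the repository's builder. The snippet below is a minimal sketch rather than an official entry point: the argument names mirror those consumed by `datasets/ytvos.py`, while the paths and `num_frames=36` are assumptions you should adapt to your setup.

```python
# Minimal sketch: build the YouTubeVIS training set via the repo's dataset builder.
# The ytvos_path and num_frames values below are assumptions, not fixed defaults.
from argparse import Namespace
from datasets import build_dataset

args = Namespace(
    dataset_file='ytvos',   # selects the YouTubeVIS builder in datasets/__init__.py
    ytvos_path='data',      # root passed as --ytvos_path when training
    masks=True,             # also load per-frame instance masks
    num_frames=36,          # frames sampled per clip
)

dataset = build_dataset('train', args)
clip, target = dataset[0]   # clip: frames concatenated along the channel dim; target: boxes/labels/masks
print(clip.shape, target['labels'])
```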
114 | 115 | 116 | ### License 117 | 118 | VisTR is released under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information. 119 | 120 | ### Acknowledgement 121 | We would like to thank the [DETR](https://github.com/facebookresearch/detr) open-source project for its awesome work; part of the code is modified from that project. 122 | 123 | ### Citation 124 | 125 | Please consider citing our paper in your publications if the project helps your research. The BibTeX reference is as follows. 126 | 127 | ``` 128 | @inproceedings{wang2020end, 129 | title={End-to-End Video Instance Segmentation with Transformers}, 130 | author={Wang, Yuqing and Xu, Zhaoliang and Wang, Xinlong and Shen, Chunhua and Cheng, Baoshan and Shen, Hao and Xia, Huaxia}, 131 | booktitle = {Proc. IEEE Conf. Computer Vision and Pattern Recognition (CVPR)}, 132 | year={2021} 133 | } 134 | ``` 135 | 136 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import torchvision 3 | 4 | from .coco import build as build_coco 5 | from .ytvos import build as build_ytvos 6 | 7 | def get_coco_api_from_dataset(dataset): 8 | for _ in range(10): 9 | # if isinstance(dataset, torchvision.datasets.CocoDetection): 10 | # break 11 | if isinstance(dataset, torch.utils.data.Subset): 12 | dataset = dataset.dataset 13 | if isinstance(dataset, torchvision.datasets.CocoDetection): 14 | return dataset.coco 15 | 16 | 17 | def build_dataset(image_set, args): 18 | if args.dataset_file == 'coco': 19 | return build_coco(image_set, args) 20 | if args.dataset_file == 'coco_panoptic': 21 | # to avoid making panopticapi required for coco 22 | from .coco_panoptic import build as build_coco_panoptic 23 | return build_coco_panoptic(image_set, args) 24 | if args.dataset_file == 'ytvos': 25 | return build_ytvos(image_set, args) 26 | raise ValueError(f'dataset {args.dataset_file} not supported') 27 | -------------------------------------------------------------------------------- /datasets/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | COCO dataset which returns image_id for evaluation.
3 | 4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 5 | """ 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.utils.data 10 | import torchvision 11 | from pycocotools import mask as coco_mask 12 | 13 | import datasets.transforms as T 14 | 15 | 16 | class CocoDetection(torchvision.datasets.CocoDetection): 17 | def __init__(self, img_folder, ann_file, transforms, return_masks): 18 | super(CocoDetection, self).__init__(img_folder, ann_file) 19 | self._transforms = transforms 20 | self.prepare = ConvertCocoPolysToMask(return_masks) 21 | 22 | def __getitem__(self, idx): 23 | img, target = super(CocoDetection, self).__getitem__(idx) 24 | image_id = self.ids[idx] 25 | target = {'image_id': image_id, 'annotations': target} 26 | 27 | img, target = self.prepare(img, target) 28 | if self._transforms is not None: 29 | img, target = self._transforms(img, target) 30 | return img, target 31 | 32 | 33 | def convert_coco_poly_to_mask(segmentations, height, width): 34 | masks = [] 35 | for polygons in segmentations: 36 | rles = coco_mask.frPyObjects(polygons, height, width) 37 | mask = coco_mask.decode(rles) 38 | if len(mask.shape) < 3: 39 | mask = mask[..., None] 40 | mask = torch.as_tensor(mask, dtype=torch.uint8) 41 | mask = mask.any(dim=2) 42 | masks.append(mask) 43 | if masks: 44 | masks = torch.stack(masks, dim=0) 45 | else: 46 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 47 | return masks 48 | 49 | 50 | class ConvertCocoPolysToMask(object): 51 | def __init__(self, return_masks=False): 52 | self.return_masks = return_masks 53 | 54 | def __call__(self, image, target): 55 | w, h = image.size 56 | 57 | image_id = target["image_id"] 58 | image_id = torch.tensor([image_id]) 59 | 60 | anno = target["annotations"] 61 | 62 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 63 | 64 | boxes = [obj["bbox"] for obj in anno] 65 | # guard against no boxes via resizing 66 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 67 | boxes[:, 2:] += boxes[:, :2] 68 | boxes[:, 0::2].clamp_(min=0, max=w) 69 | boxes[:, 1::2].clamp_(min=0, max=h) 70 | 71 | classes = [obj["category_id"] for obj in anno] 72 | classes = torch.tensor(classes, dtype=torch.int64) 73 | 74 | if self.return_masks: 75 | segmentations = [obj["segmentation"] for obj in anno] 76 | masks = convert_coco_poly_to_mask(segmentations, h, w) 77 | 78 | keypoints = None 79 | if anno and "keypoints" in anno[0]: 80 | keypoints = [obj["keypoints"] for obj in anno] 81 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 82 | num_keypoints = keypoints.shape[0] 83 | if num_keypoints: 84 | keypoints = keypoints.view(num_keypoints, -1, 3) 85 | 86 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 87 | boxes = boxes[keep] 88 | classes = classes[keep] 89 | if self.return_masks: 90 | masks = masks[keep] 91 | if keypoints is not None: 92 | keypoints = keypoints[keep] 93 | 94 | target = {} 95 | target["boxes"] = boxes 96 | target["labels"] = classes 97 | if self.return_masks: 98 | target["masks"] = masks 99 | target["image_id"] = image_id 100 | if keypoints is not None: 101 | target["keypoints"] = keypoints 102 | 103 | # for conversion to coco api 104 | area = torch.tensor([obj["area"] for obj in anno]) 105 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 106 | target["area"] = area[keep] 107 | target["iscrowd"] = iscrowd[keep] 108 | 109 | target["orig_size"] = 
torch.as_tensor([int(h), int(w)]) 110 | target["size"] = torch.as_tensor([int(h), int(w)]) 111 | 112 | return image, target 113 | 114 | 115 | def make_coco_transforms(image_set): 116 | 117 | normalize = T.Compose([ 118 | T.ToTensor(), 119 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 120 | ]) 121 | 122 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 123 | 124 | if image_set == 'train': 125 | return T.Compose([ 126 | T.RandomHorizontalFlip(), 127 | T.RandomSelect( 128 | T.RandomResize(scales, max_size=1333), 129 | T.Compose([ 130 | T.RandomResize([400, 500, 600]), 131 | T.RandomSizeCrop(384, 600), 132 | T.RandomResize(scales, max_size=1333), 133 | ]) 134 | ), 135 | normalize, 136 | ]) 137 | 138 | if image_set == 'val': 139 | return T.Compose([ 140 | T.RandomResize([800], max_size=1333), 141 | normalize, 142 | ]) 143 | 144 | raise ValueError(f'unknown {image_set}') 145 | 146 | 147 | def build(image_set, args): 148 | root = Path(args.coco_path) 149 | assert root.exists(), f'provided COCO path {root} does not exist' 150 | mode = 'instances' 151 | PATHS = { 152 | "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'), 153 | "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'), 154 | } 155 | img_folder, ann_file = PATHS[image_set] 156 | dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks) 157 | return dataset 158 | -------------------------------------------------------------------------------- /datasets/coco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | COCO evaluator that works in distributed mode. 3 | 4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 5 | The difference is that there is less copy-pasting from pycocotools 6 | in the end of the file, as python3 can suppress prints with contextlib 7 | """ 8 | import os 9 | import contextlib 10 | import copy 11 | import numpy as np 12 | import torch 13 | 14 | from pycocotools.cocoeval import COCOeval 15 | from pycocotools.coco import COCO 16 | import pycocotools.mask as mask_util 17 | 18 | from util.misc import all_gather 19 | 20 | 21 | class CocoEvaluator(object): 22 | def __init__(self, coco_gt, iou_types): 23 | assert isinstance(iou_types, (list, tuple)) 24 | coco_gt = copy.deepcopy(coco_gt) 25 | self.coco_gt = coco_gt 26 | 27 | self.iou_types = iou_types 28 | self.coco_eval = {} 29 | for iou_type in iou_types: 30 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 31 | 32 | self.img_ids = [] 33 | self.eval_imgs = {k: [] for k in iou_types} 34 | 35 | def update(self, predictions): 36 | img_ids = list(np.unique(list(predictions.keys()))) 37 | self.img_ids.extend(img_ids) 38 | 39 | for iou_type in self.iou_types: 40 | results = self.prepare(predictions, iou_type) 41 | 42 | # suppress pycocotools prints 43 | with open(os.devnull, 'w') as devnull: 44 | with contextlib.redirect_stdout(devnull): 45 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 46 | coco_eval = self.coco_eval[iou_type] 47 | 48 | coco_eval.cocoDt = coco_dt 49 | coco_eval.params.imgIds = list(img_ids) 50 | img_ids, eval_imgs = evaluate(coco_eval) 51 | 52 | self.eval_imgs[iou_type].append(eval_imgs) 53 | 54 | def synchronize_between_processes(self): 55 | for iou_type in self.iou_types: 56 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 57 | 
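        # merge the image ids and per-image eval results gathered from all processes into this COCOeval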
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 58 | 59 | def accumulate(self): 60 | for coco_eval in self.coco_eval.values(): 61 | coco_eval.accumulate() 62 | 63 | def summarize(self): 64 | for iou_type, coco_eval in self.coco_eval.items(): 65 | print("IoU metric: {}".format(iou_type)) 66 | coco_eval.summarize() 67 | 68 | def prepare(self, predictions, iou_type): 69 | if iou_type == "bbox": 70 | return self.prepare_for_coco_detection(predictions) 71 | elif iou_type == "segm": 72 | return self.prepare_for_coco_segmentation(predictions) 73 | elif iou_type == "keypoints": 74 | return self.prepare_for_coco_keypoint(predictions) 75 | else: 76 | raise ValueError("Unknown iou type {}".format(iou_type)) 77 | 78 | def prepare_for_coco_detection(self, predictions): 79 | coco_results = [] 80 | for original_id, prediction in predictions.items(): 81 | if len(prediction) == 0: 82 | continue 83 | 84 | boxes = prediction["boxes"] 85 | boxes = convert_to_xywh(boxes).tolist() 86 | scores = prediction["scores"].tolist() 87 | labels = prediction["labels"].tolist() 88 | 89 | coco_results.extend( 90 | [ 91 | { 92 | "image_id": original_id, 93 | "category_id": labels[k], 94 | "bbox": box, 95 | "score": scores[k], 96 | } 97 | for k, box in enumerate(boxes) 98 | ] 99 | ) 100 | return coco_results 101 | 102 | def prepare_for_coco_segmentation(self, predictions): 103 | coco_results = [] 104 | for original_id, prediction in predictions.items(): 105 | if len(prediction) == 0: 106 | continue 107 | 108 | scores = prediction["scores"] 109 | labels = prediction["labels"] 110 | masks = prediction["masks"] 111 | 112 | masks = masks > 0.5 113 | 114 | scores = prediction["scores"].tolist() 115 | labels = prediction["labels"].tolist() 116 | 117 | rles = [ 118 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 119 | for mask in masks 120 | ] 121 | for rle in rles: 122 | rle["counts"] = rle["counts"].decode("utf-8") 123 | 124 | coco_results.extend( 125 | [ 126 | { 127 | "image_id": original_id, 128 | "category_id": labels[k], 129 | "segmentation": rle, 130 | "score": scores[k], 131 | } 132 | for k, rle in enumerate(rles) 133 | ] 134 | ) 135 | return coco_results 136 | 137 | def prepare_for_coco_keypoint(self, predictions): 138 | coco_results = [] 139 | for original_id, prediction in predictions.items(): 140 | if len(prediction) == 0: 141 | continue 142 | 143 | boxes = prediction["boxes"] 144 | boxes = convert_to_xywh(boxes).tolist() 145 | scores = prediction["scores"].tolist() 146 | labels = prediction["labels"].tolist() 147 | keypoints = prediction["keypoints"] 148 | keypoints = keypoints.flatten(start_dim=1).tolist() 149 | 150 | coco_results.extend( 151 | [ 152 | { 153 | "image_id": original_id, 154 | "category_id": labels[k], 155 | 'keypoints': keypoint, 156 | "score": scores[k], 157 | } 158 | for k, keypoint in enumerate(keypoints) 159 | ] 160 | ) 161 | return coco_results 162 | 163 | 164 | def convert_to_xywh(boxes): 165 | xmin, ymin, xmax, ymax = boxes.unbind(1) 166 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 167 | 168 | 169 | def merge(img_ids, eval_imgs): 170 | all_img_ids = all_gather(img_ids) 171 | all_eval_imgs = all_gather(eval_imgs) 172 | 173 | merged_img_ids = [] 174 | for p in all_img_ids: 175 | merged_img_ids.extend(p) 176 | 177 | merged_eval_imgs = [] 178 | for p in all_eval_imgs: 179 | merged_eval_imgs.append(p) 180 | 181 | merged_img_ids = np.array(merged_img_ids) 182 | merged_eval_imgs = 
np.concatenate(merged_eval_imgs, 2) 183 | 184 | # keep only unique (and in sorted order) images 185 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 186 | merged_eval_imgs = merged_eval_imgs[..., idx] 187 | 188 | return merged_img_ids, merged_eval_imgs 189 | 190 | 191 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 192 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 193 | img_ids = list(img_ids) 194 | eval_imgs = list(eval_imgs.flatten()) 195 | 196 | coco_eval.evalImgs = eval_imgs 197 | coco_eval.params.imgIds = img_ids 198 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 199 | 200 | 201 | ################################################################# 202 | # From pycocotools, just removed the prints and fixed 203 | # a Python3 bug about unicode not defined 204 | ################################################################# 205 | 206 | 207 | def evaluate(self): 208 | ''' 209 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 210 | :return: None 211 | ''' 212 | # tic = time.time() 213 | # print('Running per image evaluation...') 214 | p = self.params 215 | # add backward compatibility if useSegm is specified in params 216 | if p.useSegm is not None: 217 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 218 | print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) 219 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 220 | p.imgIds = list(np.unique(p.imgIds)) 221 | if p.useCats: 222 | p.catIds = list(np.unique(p.catIds)) 223 | p.maxDets = sorted(p.maxDets) 224 | self.params = p 225 | 226 | self._prepare() 227 | # loop through images, area range, max detection number 228 | catIds = p.catIds if p.useCats else [-1] 229 | 230 | if p.iouType == 'segm' or p.iouType == 'bbox': 231 | computeIoU = self.computeIoU 232 | elif p.iouType == 'keypoints': 233 | computeIoU = self.computeOks 234 | self.ious = { 235 | (imgId, catId): computeIoU(imgId, catId) 236 | for imgId in p.imgIds 237 | for catId in catIds} 238 | 239 | evaluateImg = self.evaluateImg 240 | maxDet = p.maxDets[-1] 241 | evalImgs = [ 242 | evaluateImg(imgId, catId, areaRng, maxDet) 243 | for catId in catIds 244 | for areaRng in p.areaRng 245 | for imgId in p.imgIds 246 | ] 247 | # this is NOT in the pycocotools code, but could be done outside 248 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 249 | self._paramsEval = copy.deepcopy(self.params) 250 | # toc = time.time() 251 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 252 | return p.imgIds, evalImgs 253 | 254 | ################################################################# 255 | # end of straight copy from pycocotools, just removing the prints 256 | ################################################################# 257 | -------------------------------------------------------------------------------- /datasets/coco_panoptic.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | from panopticapi.utils import rgb2id 9 | from util.box_ops import masks_to_boxes 10 | 11 | from .coco import make_coco_transforms 12 | 13 | 14 | class CocoPanoptic: 15 | def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True): 16 | with open(ann_file, 'r') as f: 17 | self.coco = json.load(f) 18 | 19 | # sort 'images' field so that they are aligned with 
'annotations' 20 | # i.e., in alphabetical order 21 | self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id']) 22 | # sanity check 23 | if "annotations" in self.coco: 24 | for img, ann in zip(self.coco['images'], self.coco['annotations']): 25 | assert img['file_name'][:-4] == ann['file_name'][:-4] 26 | 27 | self.img_folder = img_folder 28 | self.ann_folder = ann_folder 29 | self.ann_file = ann_file 30 | self.transforms = transforms 31 | self.return_masks = return_masks 32 | 33 | def __getitem__(self, idx): 34 | ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx] 35 | img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg') 36 | ann_path = Path(self.ann_folder) / ann_info['file_name'] 37 | 38 | img = Image.open(img_path).convert('RGB') 39 | w, h = img.size 40 | if "segments_info" in ann_info: 41 | masks = np.asarray(Image.open(ann_path), dtype=np.uint32) 42 | masks = rgb2id(masks) 43 | 44 | ids = np.array([ann['id'] for ann in ann_info['segments_info']]) 45 | masks = masks == ids[:, None, None] 46 | 47 | masks = torch.as_tensor(masks, dtype=torch.uint8) 48 | labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64) 49 | 50 | target = {} 51 | target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]]) 52 | if self.return_masks: 53 | target['masks'] = masks 54 | target['labels'] = labels 55 | 56 | target["boxes"] = masks_to_boxes(masks) 57 | 58 | target['size'] = torch.as_tensor([int(h), int(w)]) 59 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 60 | if "segments_info" in ann_info: 61 | for name in ['iscrowd', 'area']: 62 | target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']]) 63 | 64 | if self.transforms is not None: 65 | img, target = self.transforms(img, target) 66 | 67 | return img, target 68 | 69 | def __len__(self): 70 | return len(self.coco['images']) 71 | 72 | def get_height_and_width(self, idx): 73 | img_info = self.coco['images'][idx] 74 | height = img_info['height'] 75 | width = img_info['width'] 76 | return height, width 77 | 78 | 79 | def build(image_set, args): 80 | img_folder_root = Path(args.coco_path) 81 | ann_folder_root = Path(args.coco_panoptic_path) 82 | assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist' 83 | assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist' 84 | mode = 'panoptic' 85 | PATHS = { 86 | "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'), 87 | "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'), 88 | } 89 | 90 | img_folder, ann_file = PATHS[image_set] 91 | img_folder_path = img_folder_root / img_folder 92 | ann_folder = ann_folder_root / f'{mode}_{img_folder}' 93 | ann_file = ann_folder_root / ann_file 94 | 95 | dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file, 96 | transforms=make_coco_transforms(image_set), return_masks=args.masks) 97 | 98 | return dataset 99 | -------------------------------------------------------------------------------- /datasets/panoptic_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | import util.misc as utils 5 | 6 | try: 7 | from panopticapi.evaluation import pq_compute 8 | except ImportError: 9 | pass 10 | 11 | 12 | class PanopticEvaluator(object): 13 | def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): 14 | 
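        # keep references to the ground-truth panoptic json/folder; predicted PNGs and a
        # predictions.json are written to output_dir and scored by pq_compute in summarize()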
self.gt_json = ann_file 15 | self.gt_folder = ann_folder 16 | if utils.is_main_process(): 17 | if not os.path.exists(output_dir): 18 | os.mkdir(output_dir) 19 | self.output_dir = output_dir 20 | self.predictions = [] 21 | 22 | def update(self, predictions): 23 | for p in predictions: 24 | with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: 25 | f.write(p.pop("png_string")) 26 | 27 | self.predictions += predictions 28 | 29 | def synchronize_between_processes(self): 30 | all_predictions = utils.all_gather(self.predictions) 31 | merged_predictions = [] 32 | for p in all_predictions: 33 | merged_predictions += p 34 | self.predictions = merged_predictions 35 | 36 | def summarize(self): 37 | if utils.is_main_process(): 38 | json_data = {"annotations": self.predictions} 39 | predictions_json = os.path.join(self.output_dir, "predictions.json") 40 | with open(predictions_json, "w") as f: 41 | f.write(json.dumps(json_data)) 42 | return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir) 43 | return None 44 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | """ 2 | Transforms and data augmentation for sequence level images, bboxes and masks. 3 | """ 4 | import random 5 | 6 | import PIL 7 | import torch 8 | import torchvision.transforms as T 9 | import torchvision.transforms.functional as F 10 | 11 | from util.box_ops import box_xyxy_to_cxcywh, box_iou 12 | from util.misc import interpolate 13 | import numpy as np 14 | from numpy import random as rand 15 | from PIL import Image 16 | import cv2 17 | 18 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', eps=1e-6): 19 | assert mode in ['iou', 'iof'] 20 | bboxes1 = bboxes1.astype(np.float32) 21 | bboxes2 = bboxes2.astype(np.float32) 22 | rows = bboxes1.shape[0] 23 | cols = bboxes2.shape[0] 24 | ious = np.zeros((rows, cols), dtype=np.float32) 25 | if rows * cols == 0: 26 | return ious 27 | exchange = False 28 | if bboxes1.shape[0] > bboxes2.shape[0]: 29 | bboxes1, bboxes2 = bboxes2, bboxes1 30 | ious = np.zeros((cols, rows), dtype=np.float32) 31 | exchange = True 32 | area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1]) 33 | area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1]) 34 | for i in range(bboxes1.shape[0]): 35 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) 36 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) 37 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) 38 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) 39 | overlap = np.maximum(x_end - x_start, 0) * np.maximum(y_end - y_start, 0) 40 | if mode == 'iou': 41 | union = area1[i] + area2 - overlap 42 | else: 43 | union = area1[i] if not exchange else area2 44 | union = np.maximum(union, eps) 45 | ious[i, :] = overlap / union 46 | if exchange: 47 | ious = ious.T 48 | return ious 49 | 50 | 51 | def crop(clip, target, region): 52 | cropped_image = [] 53 | for image in clip: 54 | cropped_image.append(F.crop(image, *region)) 55 | 56 | target = target.copy() 57 | i, j, h, w = region 58 | 59 | # should we do something wrt the original size? 
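    # as written, only "size" is updated to the crop's (h, w) below; "orig_size" keeps its pre-crop value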
60 | target["size"] = torch.tensor([h, w]) 61 | 62 | fields = ["labels", "area", "iscrowd"] 63 | 64 | if "boxes" in target: 65 | boxes = target["boxes"] 66 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 67 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 68 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 69 | cropped_boxes = cropped_boxes.clamp(min=0) 70 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 71 | target["boxes"] = cropped_boxes.reshape(-1, 4) 72 | target["area"] = area 73 | fields.append("boxes") 74 | 75 | if "masks" in target: 76 | # FIXME should we update the area here if there are no boxes? 77 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 78 | fields.append("masks") 79 | 80 | return cropped_image, target 81 | 82 | 83 | def hflip(clip, target): 84 | flipped_image = [] 85 | for image in clip: 86 | flipped_image.append(F.hflip(image)) 87 | 88 | w, h = clip[0].size 89 | 90 | target = target.copy() 91 | if "boxes" in target: 92 | boxes = target["boxes"] 93 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 94 | target["boxes"] = boxes 95 | 96 | if "masks" in target: 97 | target['masks'] = target['masks'].flip(-1) 98 | 99 | return flipped_image, target 100 | 101 | def vflip(image,target): 102 | flipped_image = [] 103 | for image in clip: 104 | flipped_image.append(F.vflip(image)) 105 | w, h = clip[0].size 106 | target = target.copy() 107 | if "boxes" in target: 108 | boxes = target["boxes"] 109 | boxes = boxes[:, [0, 3, 2, 1]] * torch.as_tensor([1, -1, 1, -1]) + torch.as_tensor([0, h, 0, h]) 110 | target["boxes"] = boxes 111 | 112 | if "masks" in target: 113 | target['masks'] = target['masks'].flip(1) 114 | 115 | return flipped_image, target 116 | 117 | def resize(clip, target, size, max_size=None): 118 | # size can be min_size (scalar) or (w, h) tuple 119 | 120 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 121 | w, h = image_size 122 | if max_size is not None: 123 | min_original_size = float(min((w, h))) 124 | max_original_size = float(max((w, h))) 125 | if max_original_size / min_original_size * size > max_size: 126 | size = int(round(max_size * min_original_size / max_original_size)) 127 | 128 | if (w <= h and w == size) or (h <= w and h == size): 129 | return (h, w) 130 | 131 | if w < h: 132 | ow = size 133 | oh = int(size * h / w) 134 | else: 135 | oh = size 136 | ow = int(size * w / h) 137 | 138 | return (oh, ow) 139 | 140 | def get_size(image_size, size, max_size=None): 141 | if isinstance(size, (list, tuple)): 142 | return size[::-1] 143 | else: 144 | return get_size_with_aspect_ratio(image_size, size, max_size) 145 | 146 | size = get_size(clip[0].size, size, max_size) 147 | rescaled_image = [] 148 | for image in clip: 149 | rescaled_image.append(F.resize(image, size)) 150 | 151 | if target is None: 152 | return rescaled_image, None 153 | 154 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image[0].size, clip[0].size)) 155 | ratio_width, ratio_height = ratios 156 | 157 | target = target.copy() 158 | if "boxes" in target: 159 | boxes = target["boxes"] 160 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 161 | target["boxes"] = scaled_boxes 162 | 163 | if "area" in target: 164 | area = target["area"] 165 | scaled_area = area * (ratio_width * ratio_height) 166 | target["area"] = scaled_area 167 | 168 | h, w = size 169 | target["size"] = torch.tensor([h, w]) 170 | 
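    # masks are resized with nearest-neighbour interpolation and re-binarised at 0.5; when there are
    # no instances, an empty (0, h, w) placeholder is created at the new resolution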
171 | if "masks" in target: 172 | if target['masks'].shape[0]>0: 173 | target['masks'] = interpolate( 174 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 175 | else: 176 | target['masks'] = torch.zeros((target['masks'].shape[0],h,w)) 177 | return rescaled_image, target 178 | 179 | 180 | def pad(clip, target, padding): 181 | # assumes that we only pad on the bottom right corners 182 | padded_image = [] 183 | for image in clip: 184 | padded_image.append(F.pad(image, (0, 0, padding[0], padding[1]))) 185 | if target is None: 186 | return padded_image, None 187 | target = target.copy() 188 | # should we do something wrt the original size? 189 | target["size"] = torch.tensor(padded_image[0].size[::-1]) 190 | if "masks" in target: 191 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 192 | return padded_image, target 193 | 194 | 195 | class RandomCrop(object): 196 | def __init__(self, size): 197 | self.size = size 198 | 199 | def __call__(self, img, target): 200 | region = T.RandomCrop.get_params(img, self.size) 201 | return crop(img, target, region) 202 | 203 | 204 | class RandomSizeCrop(object): 205 | def __init__(self, min_size: int, max_size: int): 206 | self.min_size = min_size 207 | self.max_size = max_size 208 | 209 | def __call__(self, img: PIL.Image.Image, target: dict): 210 | w = random.randint(self.min_size, min(img[0].width, self.max_size)) 211 | h = random.randint(self.min_size, min(img[0].height, self.max_size)) 212 | region = T.RandomCrop.get_params(img[0], [h, w]) 213 | return crop(img, target, region) 214 | 215 | 216 | class CenterCrop(object): 217 | def __init__(self, size): 218 | self.size = size 219 | 220 | def __call__(self, img, target): 221 | image_width, image_height = img.size 222 | crop_height, crop_width = self.size 223 | crop_top = int(round((image_height - crop_height) / 2.)) 224 | crop_left = int(round((image_width - crop_width) / 2.)) 225 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 226 | 227 | 228 | class MinIoURandomCrop(object): 229 | def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3): 230 | self.min_ious = min_ious 231 | self.sample_mode = (1, *min_ious, 0) 232 | self.min_crop_size = min_crop_size 233 | 234 | def __call__(self, img, target): 235 | w,h = img.size 236 | while True: 237 | mode = random.choice(self.sample_mode) 238 | self.mode = mode 239 | if mode == 1: 240 | return img,target 241 | min_iou = mode 242 | boxes = target['boxes'].numpy() 243 | labels = target['labels'] 244 | 245 | for i in range(50): 246 | new_w = rand.uniform(self.min_crop_size * w, w) 247 | new_h = rand.uniform(self.min_crop_size * h, h) 248 | if new_h / new_w < 0.5 or new_h / new_w > 2: 249 | continue 250 | left = rand.uniform(w - new_w) 251 | top = rand.uniform(h - new_h) 252 | patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h))) 253 | if patch[2] == patch[0] or patch[3] == patch[1]: 254 | continue 255 | overlaps = bbox_overlaps(patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1) 256 | if len(overlaps) > 0 and overlaps.min() < min_iou: 257 | continue 258 | 259 | if len(overlaps) > 0: 260 | def is_center_of_bboxes_in_patch(boxes, patch): 261 | center = (boxes[:, :2] + boxes[:, 2:]) / 2 262 | mask = ((center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * (center[:, 1] < patch[3])) 263 | return mask 264 | mask = is_center_of_bboxes_in_patch(boxes, patch) 265 | if False in mask: 266 | continue 267 | #TODO: use no 
center boxes 268 | #if not mask.any(): 269 | # continue 270 | 271 | boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) 272 | boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) 273 | boxes -= np.tile(patch[:2], 2) 274 | target['boxes'] = torch.tensor(boxes) 275 | 276 | img = np.asarray(img)[patch[1]:patch[3], patch[0]:patch[2]] 277 | img = Image.fromarray(img) 278 | width, height = img.size 279 | target['orig_size'] = torch.tensor([height,width]) 280 | target['size'] = torch.tensor([height,width]) 281 | return img,target 282 | 283 | 284 | class RandomContrast(object): 285 | def __init__(self, lower=0.5, upper=1.5): 286 | self.lower = lower 287 | self.upper = upper 288 | assert self.upper >= self.lower, "contrast upper must be >= lower." 289 | assert self.lower >= 0, "contrast lower must be non-negative." 290 | def __call__(self, image, target): 291 | 292 | if rand.randint(2): 293 | alpha = rand.uniform(self.lower, self.upper) 294 | image *= alpha 295 | return image, target 296 | 297 | class RandomBrightness(object): 298 | def __init__(self, delta=32): 299 | assert delta >= 0.0 300 | assert delta <= 255.0 301 | self.delta = delta 302 | def __call__(self, image, target): 303 | if rand.randint(2): 304 | delta = rand.uniform(-self.delta, self.delta) 305 | image += delta 306 | return image, target 307 | 308 | class RandomSaturation(object): 309 | def __init__(self, lower=0.5, upper=1.5): 310 | self.lower = lower 311 | self.upper = upper 312 | assert self.upper >= self.lower, "contrast upper must be >= lower." 313 | assert self.lower >= 0, "contrast lower must be non-negative." 314 | 315 | def __call__(self, image, target): 316 | if rand.randint(2): 317 | image[:, :, 1] *= rand.uniform(self.lower, self.upper) 318 | return image, target 319 | 320 | class RandomHue(object): # 321 | def __init__(self, delta=18.0): 322 | assert delta >= 0.0 and delta <= 360.0 323 | self.delta = delta 324 | 325 | def __call__(self, image, target): 326 | if rand.randint(2): 327 | image[:, :, 0] += rand.uniform(-self.delta, self.delta) 328 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 329 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 330 | return image, target 331 | 332 | class RandomLightingNoise(object): 333 | def __init__(self): 334 | self.perms = ((0, 1, 2), (0, 2, 1), 335 | (1, 0, 2), (1, 2, 0), 336 | (2, 0, 1), (2, 1, 0)) 337 | def __call__(self, image, target): 338 | if rand.randint(2): 339 | swap = self.perms[rand.randint(len(self.perms))] 340 | shuffle = SwapChannels(swap) # shuffle channels 341 | image = shuffle(image) 342 | return image, target 343 | 344 | class ConvertColor(object): 345 | def __init__(self, current='BGR', transform='HSV'): 346 | self.transform = transform 347 | self.current = current 348 | 349 | def __call__(self, image, target): 350 | if self.current == 'BGR' and self.transform == 'HSV': 351 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 352 | elif self.current == 'HSV' and self.transform == 'BGR': 353 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) 354 | else: 355 | raise NotImplementedError 356 | return image, target 357 | 358 | class SwapChannels(object): 359 | def __init__(self, swaps): 360 | self.swaps = swaps 361 | def __call__(self, image): 362 | image = image[:, :, self.swaps] 363 | return image 364 | 365 | class PhotometricDistort(object): 366 | def __init__(self): 367 | self.pd = [ 368 | RandomContrast(), 369 | ConvertColor(transform='HSV'), 370 | RandomSaturation(), 371 | RandomHue(), 372 | ConvertColor(current='HSV', transform='BGR'), 373 | RandomContrast() 374 | ] 375 | 
self.rand_brightness = RandomBrightness() 376 | self.rand_light_noise = RandomLightingNoise() 377 | 378 | def __call__(self,clip,target): 379 | imgs = [] 380 | for img in clip: 381 | img = np.asarray(img).astype('float32') 382 | img, target = self.rand_brightness(img, target) 383 | if rand.randint(2): 384 | distort = Compose(self.pd[:-1]) 385 | else: 386 | distort = Compose(self.pd[1:]) 387 | img, target = distort(img, target) 388 | img, target = self.rand_light_noise(img, target) 389 | imgs.append(Image.fromarray(img.astype('uint8'))) 390 | return imgs, target 391 | 392 | #NOTICE: if used for mask, need to change 393 | class Expand(object): 394 | def __init__(self, mean): 395 | self.mean = mean 396 | def __call__(self, clip, target): 397 | if rand.randint(2): 398 | return clip,target 399 | imgs = [] 400 | masks = [] 401 | image = np.asarray(clip[0]).astype('float32') 402 | height, width, depth = image.shape 403 | ratio = rand.uniform(1, 4) 404 | left = rand.uniform(0, width*ratio - width) 405 | top = rand.uniform(0, height*ratio - height) 406 | for i in range(len(clip)): 407 | image = np.asarray(clip[i]).astype('float32') 408 | expand_image = np.zeros((int(height*ratio), int(width*ratio), depth),dtype=image.dtype) 409 | expand_image[:, :, :] = self.mean 410 | expand_image[int(top):int(top + height),int(left):int(left + width)] = image 411 | imgs.append(Image.fromarray(expand_image.astype('uint8'))) 412 | expand_mask = torch.zeros((int(height*ratio), int(width*ratio)),dtype=torch.uint8) 413 | expand_mask[int(top):int(top + height),int(left):int(left + width)] = target['masks'][i] 414 | masks.append(expand_mask) 415 | boxes = target['boxes'].numpy() 416 | boxes[:, :2] += (int(left), int(top)) 417 | boxes[:, 2:] += (int(left), int(top)) 418 | target['boxes'] = torch.tensor(boxes) 419 | target['masks']=torch.stack(masks) 420 | return imgs, target 421 | 422 | class RandomHorizontalFlip(object): 423 | def __init__(self, p=0.5): 424 | self.p = p 425 | 426 | def __call__(self, img, target): 427 | if random.random() < self.p: 428 | return hflip(img, target) 429 | return img, target 430 | 431 | class RandomVerticalFlip(object): 432 | def __init__(self, p=0.5): 433 | self.p = p 434 | 435 | def __call__(self, img, target): 436 | if random.random() < self.p: 437 | return vflip(img, target) 438 | return img, target 439 | 440 | 441 | class RandomResize(object): 442 | def __init__(self, sizes, max_size=None): 443 | assert isinstance(sizes, (list, tuple)) 444 | self.sizes = sizes 445 | self.max_size = max_size 446 | 447 | def __call__(self, img, target=None): 448 | size = random.choice(self.sizes) 449 | return resize(img, target, size, self.max_size) 450 | 451 | 452 | class RandomPad(object): 453 | def __init__(self, max_pad): 454 | self.max_pad = max_pad 455 | 456 | def __call__(self, img, target): 457 | pad_x = random.randint(0, self.max_pad) 458 | pad_y = random.randint(0, self.max_pad) 459 | return pad(img, target, (pad_x, pad_y)) 460 | 461 | 462 | class RandomSelect(object): 463 | """ 464 | Randomly selects between transforms1 and transforms2, 465 | with probability p for transforms1 and (1 - p) for transforms2 466 | """ 467 | def __init__(self, transforms1, transforms2, p=0.5): 468 | self.transforms1 = transforms1 469 | self.transforms2 = transforms2 470 | self.p = p 471 | 472 | def __call__(self, img, target): 473 | if random.random() < self.p: 474 | return self.transforms1(img, target) 475 | return self.transforms2(img, target) 476 | 477 | 478 | class ToTensor(object): 479 | def __call__(self, 
clip, target): 480 | img = [] 481 | for im in clip: 482 | img.append(F.to_tensor(im)) 483 | return img, target 484 | 485 | 486 | class RandomErasing(object): 487 | 488 | def __init__(self, *args, **kwargs): 489 | self.eraser = T.RandomErasing(*args, **kwargs) 490 | 491 | def __call__(self, img, target): 492 | return self.eraser(img), target 493 | 494 | 495 | class Normalize(object): 496 | def __init__(self, mean, std): 497 | self.mean = mean 498 | self.std = std 499 | 500 | def __call__(self, clip, target=None): 501 | image = [] 502 | for im in clip: 503 | image.append(F.normalize(im, mean=self.mean, std=self.std)) 504 | if target is None: 505 | return image, None 506 | target = target.copy() 507 | h, w = image[0].shape[-2:] 508 | if "boxes" in target: 509 | boxes = target["boxes"] 510 | boxes = box_xyxy_to_cxcywh(boxes) 511 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 512 | target["boxes"] = boxes 513 | return image, target 514 | 515 | 516 | class Compose(object): 517 | def __init__(self, transforms): 518 | self.transforms = transforms 519 | 520 | def __call__(self, image, target): 521 | for t in self.transforms: 522 | image, target = t(image, target) 523 | return image, target 524 | 525 | def __repr__(self): 526 | format_string = self.__class__.__name__ + "(" 527 | for t in self.transforms: 528 | format_string += "\n" 529 | format_string += " {0}".format(t) 530 | format_string += "\n)" 531 | return format_string 532 | -------------------------------------------------------------------------------- /datasets/ytvos.py: -------------------------------------------------------------------------------- 1 | """ 2 | YoutubeVIS data loader 3 | """ 4 | from pathlib import Path 5 | 6 | import torch 7 | import torch.utils.data 8 | import torchvision 9 | from pycocotools.ytvos import YTVOS 10 | from pycocotools.ytvoseval import YTVOSeval 11 | import datasets.transforms as T 12 | from pycocotools import mask as coco_mask 13 | import os 14 | from PIL import Image 15 | from random import randint 16 | import cv2 17 | import random 18 | 19 | class YTVOSDataset: 20 | def __init__(self, img_folder, ann_file, transforms, return_masks, num_frames): 21 | self.img_folder = img_folder 22 | self.ann_file = ann_file 23 | self._transforms = transforms 24 | self.return_masks = return_masks 25 | self.num_frames = num_frames 26 | self.prepare = ConvertCocoPolysToMask(return_masks) 27 | self.ytvos = YTVOS(ann_file) 28 | self.cat_ids = self.ytvos.getCatIds() 29 | self.vid_ids = self.ytvos.getVidIds() 30 | self.vid_infos = [] 31 | for i in self.vid_ids: 32 | info = self.ytvos.loadVids([i])[0] 33 | info['filenames'] = info['file_names'] 34 | self.vid_infos.append(info) 35 | self.img_ids = [] 36 | for idx, vid_info in enumerate(self.vid_infos): 37 | for frame_id in range(len(vid_info['filenames'])): 38 | self.img_ids.append((idx, frame_id)) 39 | def __len__(self): 40 | return len(self.img_ids) 41 | 42 | def __getitem__(self, idx): 43 | vid, frame_id = self.img_ids[idx] 44 | vid_id = self.vid_infos[vid]['id'] 45 | img = [] 46 | vid_len = len(self.vid_infos[vid]['file_names']) 47 | inds = list(range(self.num_frames)) 48 | inds = [i%vid_len for i in inds][::-1] 49 | # if random 50 | # random.shuffle(inds) 51 | for j in range(self.num_frames): 52 | img_path = os.path.join(str(self.img_folder), self.vid_infos[vid]['file_names'][frame_id-inds[j]]) 53 | img.append(Image.open(img_path).convert('RGB')) 54 | ann_ids = self.ytvos.getAnnIds(vidIds=[vid_id]) 55 | target = self.ytvos.loadAnns(ann_ids) 56 | target = 
{'image_id': idx, 'video_id': vid, 'frame_id': frame_id, 'annotations': target} 57 | target = self.prepare(img[0], target, inds, self.num_frames) 58 | if self._transforms is not None: 59 | img, target = self._transforms(img, target) 60 | return torch.cat(img,dim=0), target 61 | 62 | 63 | def convert_coco_poly_to_mask(segmentations, height, width): 64 | masks = [] 65 | for polygons in segmentations: 66 | if not polygons: 67 | mask = torch.zeros((height,width), dtype=torch.uint8) 68 | else: 69 | rles = coco_mask.frPyObjects(polygons, height, width) 70 | mask = coco_mask.decode(rles) 71 | if len(mask.shape) < 3: 72 | mask = mask[..., None] 73 | mask = torch.as_tensor(mask, dtype=torch.uint8) 74 | mask = mask.any(dim=2) 75 | masks.append(mask) 76 | if masks: 77 | masks = torch.stack(masks, dim=0) 78 | else: 79 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 80 | return masks 81 | 82 | 83 | class ConvertCocoPolysToMask(object): 84 | def __init__(self, return_masks=False): 85 | self.return_masks = return_masks 86 | 87 | def __call__(self, image, target, inds, num_frames): 88 | w, h = image.size 89 | image_id = target["image_id"] 90 | frame_id = target['frame_id'] 91 | image_id = torch.tensor([image_id]) 92 | 93 | anno = target["annotations"] 94 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 95 | boxes = [] 96 | classes = [] 97 | segmentations = [] 98 | area = [] 99 | iscrowd = [] 100 | valid = [] 101 | # add valid flag for bboxes 102 | for i, ann in enumerate(anno): 103 | for j in range(num_frames): 104 | bbox = ann['bboxes'][frame_id-inds[j]] 105 | areas = ann['areas'][frame_id-inds[j]] 106 | segm = ann['segmentations'][frame_id-inds[j]] 107 | clas = ann["category_id"] 108 | # for empty boxes 109 | if bbox is None: 110 | bbox = [0,0,0,0] 111 | areas = 0 112 | valid.append(0) 113 | clas = 0 114 | else: 115 | valid.append(1) 116 | crowd = ann["iscrowd"] if "iscrowd" in ann else 0 117 | boxes.append(bbox) 118 | area.append(areas) 119 | segmentations.append(segm) 120 | classes.append(clas) 121 | iscrowd.append(crowd) 122 | # guard against no boxes via resizing 123 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 124 | boxes[:, 2:] += boxes[:, :2] 125 | boxes[:, 0::2].clamp_(min=0, max=w) 126 | boxes[:, 1::2].clamp_(min=0, max=h) 127 | classes = torch.tensor(classes, dtype=torch.int64) 128 | if self.return_masks: 129 | masks = convert_coco_poly_to_mask(segmentations, h, w) 130 | target = {} 131 | target["boxes"] = boxes 132 | target["labels"] = classes 133 | if self.return_masks: 134 | target["masks"] = masks 135 | target["image_id"] = image_id 136 | 137 | # for conversion to coco api 138 | area = torch.tensor(area) 139 | iscrowd = torch.tensor(iscrowd) 140 | target["valid"] = torch.tensor(valid) 141 | target["area"] = area 142 | target["iscrowd"] = iscrowd 143 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 144 | target["size"] = torch.as_tensor([int(h), int(w)]) 145 | return target 146 | 147 | 148 | def make_coco_transforms(image_set): 149 | 150 | normalize = T.Compose([ 151 | T.ToTensor(), 152 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 153 | ]) 154 | 155 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768] 156 | 157 | if image_set == 'train': 158 | return T.Compose([ 159 | T.RandomHorizontalFlip(), 160 | T.RandomResize(scales, max_size=800), 161 | T.PhotometricDistort(), 162 | T.Compose([ 163 | T.RandomResize([400, 500, 600]), 164 | T.RandomSizeCrop(384, 600), 165 | # To suit the GPU memory the 
scale might be different 166 | T.RandomResize([300], max_size=540),#for r50 167 | #T.RandomResize([280], max_size=504),#for r101 168 | ]), 169 | normalize, 170 | ]) 171 | 172 | if image_set == 'val': 173 | return T.Compose([ 174 | T.RandomResize([360], max_size=640), 175 | normalize, 176 | ]) 177 | 178 | raise ValueError(f'unknown {image_set}') 179 | 180 | 181 | def build(image_set, args): 182 | root = Path(args.ytvos_path) 183 | assert root.exists(), f'provided YTVOS path {root} does not exist' 184 | mode = 'instances' 185 | PATHS = { 186 | "train": (root / "train/JPEGImages", root / "annotations" / f'{mode}_train_sub.json'), 187 | "val": (root / "valid/JPEGImages", root / "annotations" / f'{mode}_val_sub.json'), 188 | } 189 | img_folder, ann_file = PATHS[image_set] 190 | dataset = YTVOSDataset(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks, num_frames = args.num_frames) 191 | return dataset 192 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train and eval functions used in main.py 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import math 6 | import os 7 | import sys 8 | from typing import Iterable 9 | 10 | import torch 11 | 12 | import util.misc as utils 13 | from datasets.coco_eval import CocoEvaluator 14 | from datasets.panoptic_eval import PanopticEvaluator 15 | 16 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 17 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 18 | device: torch.device, epoch: int, max_norm: float = 0): 19 | model.train() 20 | criterion.train() 21 | metric_logger = utils.MetricLogger(delimiter=" ") 22 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 23 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 24 | header = 'Epoch: [{}]'.format(epoch) 25 | print_freq = 10 26 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 27 | samples = samples.to(device) 28 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 29 | outputs = model(samples) 30 | loss_dict = criterion(outputs, targets) 31 | weight_dict = criterion.weight_dict 32 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 33 | 34 | # reduce losses over all GPUs for logging purposes 35 | loss_dict_reduced = utils.reduce_dict(loss_dict) 36 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 37 | for k, v in loss_dict_reduced.items()} 38 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 39 | for k, v in loss_dict_reduced.items() if k in weight_dict} 40 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 41 | 42 | loss_value = losses_reduced_scaled.item() 43 | 44 | if not math.isfinite(loss_value): 45 | print("Loss is {}, stopping training".format(loss_value)) 46 | print(loss_dict_reduced) 47 | sys.exit(1) 48 | optimizer.zero_grad() 49 | losses.backward() 50 | if max_norm > 0: 51 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 52 | optimizer.step() 53 | 54 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 55 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 56 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 57 | 58 | # gather the stats from all processes 59 | 
metric_logger.synchronize_between_processes() 60 | print("Averaged stats:", metric_logger) 61 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 62 | 63 | 64 | @torch.no_grad() 65 | def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir): 66 | model.eval() 67 | criterion.eval() 68 | 69 | metric_logger = utils.MetricLogger(delimiter=" ") 70 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 71 | header = 'Test:' 72 | 73 | iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) 74 | coco_evaluator = CocoEvaluator(base_ds, iou_types) 75 | # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] 76 | 77 | panoptic_evaluator = None 78 | if 'panoptic' in postprocessors.keys(): 79 | panoptic_evaluator = PanopticEvaluator( 80 | data_loader.dataset.ann_file, 81 | data_loader.dataset.ann_folder, 82 | output_dir=os.path.join(output_dir, "panoptic_eval"), 83 | ) 84 | 85 | for samples, targets in metric_logger.log_every(data_loader, 10, header): 86 | samples = samples.to(device) 87 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 88 | 89 | outputs = model(samples) 90 | loss_dict = criterion(outputs, targets) 91 | weight_dict = criterion.weight_dict 92 | 93 | # reduce losses over all GPUs for logging purposes 94 | loss_dict_reduced = utils.reduce_dict(loss_dict) 95 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 96 | for k, v in loss_dict_reduced.items() if k in weight_dict} 97 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 98 | for k, v in loss_dict_reduced.items()} 99 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), 100 | **loss_dict_reduced_scaled, 101 | **loss_dict_reduced_unscaled) 102 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 103 | 104 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 105 | results = postprocessors['bbox'](outputs, orig_target_sizes) 106 | if 'segm' in postprocessors.keys(): 107 | target_sizes = torch.stack([t["size"] for t in targets], dim=0) 108 | results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) 109 | res = {target['image_id'].item(): output for target, output in zip(targets, results)} 110 | if coco_evaluator is not None: 111 | coco_evaluator.update(res) 112 | 113 | if panoptic_evaluator is not None: 114 | res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes) 115 | for i, target in enumerate(targets): 116 | image_id = target["image_id"].item() 117 | file_name = f"{image_id:012d}.png" 118 | res_pano[i]["image_id"] = image_id 119 | res_pano[i]["file_name"] = file_name 120 | 121 | panoptic_evaluator.update(res_pano) 122 | 123 | # gather the stats from all processes 124 | metric_logger.synchronize_between_processes() 125 | print("Averaged stats:", metric_logger) 126 | if coco_evaluator is not None: 127 | coco_evaluator.synchronize_between_processes() 128 | if panoptic_evaluator is not None: 129 | panoptic_evaluator.synchronize_between_processes() 130 | 131 | # accumulate predictions from all images 132 | if coco_evaluator is not None: 133 | coco_evaluator.accumulate() 134 | coco_evaluator.summarize() 135 | panoptic_res = None 136 | if panoptic_evaluator is not None: 137 | panoptic_res = panoptic_evaluator.summarize() 138 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 139 | if coco_evaluator is not None: 140 | if 'bbox' in postprocessors.keys(): 141 | 
stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() 142 | if 'segm' in postprocessors.keys(): 143 | stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist() 144 | if panoptic_res is not None: 145 | stats['PQ_all'] = panoptic_res["All"] 146 | stats['PQ_th'] = panoptic_res["Things"] 147 | stats['PQ_st'] = panoptic_res["Stuff"] 148 | return stats, coco_evaluator 149 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Inference code for VisTR 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | ''' 5 | import argparse 6 | import datetime 7 | import json 8 | import random 9 | import time 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import DataLoader, DistributedSampler 15 | 16 | import datasets 17 | import util.misc as utils 18 | from datasets import build_dataset, get_coco_api_from_dataset 19 | from engine import evaluate, train_one_epoch 20 | from models import build_model 21 | import torchvision.transforms as T 22 | import matplotlib.pyplot as plt 23 | import os 24 | from PIL import Image 25 | import math 26 | import torch.nn.functional as F 27 | import json 28 | from scipy.optimize import linear_sum_assignment 29 | import pycocotools.mask as mask_util 30 | 31 | 32 | 33 | def get_args_parser(): 34 | parser = argparse.ArgumentParser('Set transformer detector', add_help=False) 35 | parser.add_argument('--lr', default=1e-4, type=float) 36 | parser.add_argument('--lr_backbone', default=1e-5, type=float) 37 | parser.add_argument('--batch_size', default=2, type=int) 38 | parser.add_argument('--weight_decay', default=1e-4, type=float) 39 | parser.add_argument('--epochs', default=150, type=int) 40 | parser.add_argument('--lr_drop', default=100, type=int) 41 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 42 | help='gradient clipping max norm') 43 | 44 | # Model parameters 45 | parser.add_argument('--model_path', type=str, default=None, 46 | help="Path to the model weights.") 47 | # * Backbone 48 | parser.add_argument('--backbone', default='resnet101', type=str, 49 | help="Name of the convolutional backbone to use") 50 | parser.add_argument('--dilation', action='store_true', 51 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 52 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 53 | help="Type of positional embedding to use on top of the image features") 54 | 55 | # * Transformer 56 | parser.add_argument('--enc_layers', default=6, type=int, 57 | help="Number of encoding layers in the transformer") 58 | parser.add_argument('--dec_layers', default=6, type=int, 59 | help="Number of decoding layers in the transformer") 60 | parser.add_argument('--dim_feedforward', default=2048, type=int, 61 | help="Intermediate size of the feedforward layers in the transformer blocks") 62 | parser.add_argument('--hidden_dim', default=384, type=int, 63 | help="Size of the embeddings (dimension of the transformer)") 64 | parser.add_argument('--dropout', default=0.1, type=float, 65 | help="Dropout applied in the transformer") 66 | parser.add_argument('--nheads', default=8, type=int, 67 | help="Number of attention heads inside the transformer's attentions") 68 | parser.add_argument('--num_frames', default=36, type=int, 69 | help="Number of frames") 70 | 
parser.add_argument('--num_ins', default=10, type=int, 71 | help="Number of instances") 72 | parser.add_argument('--num_queries', default=360, type=int, 73 | help="Number of query slots") 74 | parser.add_argument('--pre_norm', action='store_true') 75 | 76 | # * Segmentation 77 | parser.add_argument('--masks', action='store_true', 78 | help="Train segmentation head if the flag is provided") 79 | 80 | # Loss 81 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 82 | help="Disables auxiliary decoding losses (loss at each layer)") 83 | # * Matcher 84 | parser.add_argument('--set_cost_class', default=1, type=float, 85 | help="Class coefficient in the matching cost") 86 | parser.add_argument('--set_cost_bbox', default=5, type=float, 87 | help="L1 box coefficient in the matching cost") 88 | parser.add_argument('--set_cost_giou', default=2, type=float, 89 | help="giou box coefficient in the matching cost") 90 | # * Loss coefficients 91 | parser.add_argument('--mask_loss_coef', default=1, type=float) 92 | parser.add_argument('--dice_loss_coef', default=1, type=float) 93 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 94 | parser.add_argument('--giou_loss_coef', default=2, type=float) 95 | parser.add_argument('--eos_coef', default=0.1, type=float, 96 | help="Relative classification weight of the no-object class") 97 | 98 | # dataset parameters 99 | parser.add_argument('--img_path', default='data/ytvos/valid/JPEGImages/') 100 | parser.add_argument('--ann_path', default='data/ytvos/annotations/instances_val_sub.json') 101 | parser.add_argument('--save_path', default='results.json') 102 | parser.add_argument('--dataset_file', default='ytvos') 103 | parser.add_argument('--coco_path', type=str) 104 | parser.add_argument('--coco_panoptic_path', type=str) 105 | parser.add_argument('--remove_difficult', action='store_true') 106 | 107 | parser.add_argument('--output_dir', default='output_ytvos', 108 | help='path where to save, empty for no saving') 109 | parser.add_argument('--device', default='cuda', 110 | help='device to use for training / testing') 111 | parser.add_argument('--seed', default=42, type=int) 112 | parser.add_argument('--resume', default='', help='resume from checkpoint') 113 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 114 | help='start epoch') 115 | #parser.add_argument('--eval', action='store_true') 116 | parser.add_argument('--eval', action='store_false') 117 | parser.add_argument('--num_workers', default=0, type=int) 118 | 119 | # distributed training parameters 120 | parser.add_argument('--world_size', default=1, type=int, 121 | help='number of distributed processes') 122 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 123 | return parser 124 | 125 | CLASSES=['person','giant_panda','lizard','parrot','skateboard','sedan','ape', 126 | 'dog','snake','monkey','hand','rabbit','duck','cat','cow','fish', 127 | 'train','horse','turtle','bear','motorbike','giraffe','leopard', 128 | 'fox','deer','owl','surfboard','airplane','truck','zebra','tiger', 129 | 'elephant','snowboard','boat','shark','mouse','frog','eagle','earless_seal', 130 | 'tennis_racket'] 131 | COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125], 132 | [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933], 133 | [0.494, 0.000, 0.556], [0.494, 0.000, 0.000], [0.000, 0.745, 0.000], 134 | [0.700, 0.300, 0.600]] 135 | transform = T.Compose([ 136 | T.Resize(300), 137 | 
T.ToTensor(), 138 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 139 | ]) 140 | 141 | 142 | # for output bounding box post-processing 143 | def box_cxcywh_to_xyxy(x): 144 | x_c, y_c, w, h = x.unbind(1) 145 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 146 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 147 | return torch.stack(b, dim=1) 148 | 149 | def rescale_bboxes(out_bbox, size): 150 | img_w, img_h = size 151 | b = box_cxcywh_to_xyxy(out_bbox) 152 | b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) 153 | return b 154 | 155 | 156 | 157 | def main(args): 158 | 159 | device = torch.device(args.device) 160 | 161 | # fix the seed for reproducibility 162 | seed = args.seed + utils.get_rank() 163 | torch.manual_seed(seed) 164 | np.random.seed(seed) 165 | random.seed(seed) 166 | num_frames = args.num_frames 167 | num_ins = args.num_ins 168 | with torch.no_grad(): 169 | model, criterion, postprocessors = build_model(args) 170 | model.to(device) 171 | state_dict = torch.load(args.model_path)['model'] 172 | model.load_state_dict(state_dict) 173 | folder = args.img_path 174 | videos = json.load(open(args.ann_path,'rb'))['videos'] 175 | vis_num = len(videos) 176 | result = [] 177 | for i in range(vis_num): 178 | print("Process video: ",i) 179 | id_ = videos[i]['id'] 180 | length = videos[i]['length'] 181 | file_names = videos[i]['file_names'] 182 | clip_num = math.ceil(length/num_frames) 183 | 184 | img_set=[] 185 | if length0.5 203 | pred_logits = logits.reshape(num_frames,num_ins,logits.shape[-1]).cpu().detach().numpy() 204 | pred_masks = pred_masks[:length] 205 | pred_logits = pred_logits[:length] 206 | pred_scores = np.max(pred_logits,axis=-1) 207 | pred_logits = np.argmax(pred_logits,axis=-1) 208 | for m in range(num_ins): 209 | if pred_masks[:,m].max()==0: 210 | continue 211 | score = pred_scores[:,m].mean() 212 | #category_id = pred_logits[:,m][pred_scores[:,m].argmax()] 213 | category_id = np.argmax(np.bincount(pred_logits[:,m])) 214 | instance = {'video_id':id_, 'score':float(score), 'category_id':int(category_id)} 215 | segmentation = [] 216 | for n in range(length): 217 | if pred_scores[n,m]<0.001: 218 | segmentation.append(None) 219 | else: 220 | mask = (pred_masks[n,m]).astype(np.uint8) 221 | rle = mask_util.encode(np.array(mask[:,:,np.newaxis], order='F'))[0] 222 | rle["counts"] = rle["counts"].decode("utf-8") 223 | segmentation.append(rle) 224 | instance['segmentations'] = segmentation 225 | result.append(instance) 226 | with open(args.save_path, 'w', encoding='utf-8') as f: 227 | json.dump(result,f) 228 | 229 | 230 | 231 | 232 | 233 | if __name__ == '__main__': 234 | parser = argparse.ArgumentParser('VisTR inference script', parents=[get_args_parser()]) 235 | args = parser.parse_args() 236 | main(args) 237 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script of VisTR 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import argparse 6 | import datetime 7 | import json 8 | import random 9 | import time 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import DataLoader, DistributedSampler 15 | 16 | import datasets 17 | import util.misc as utils 18 | from datasets import build_dataset, get_coco_api_from_dataset 19 | from engine import evaluate, train_one_epoch 20 | from models import build_model 21 | 22 | 23 | def get_args_parser(): 24 
| parser = argparse.ArgumentParser('Set transformer detector', add_help=False) 25 | parser.add_argument('--lr', default=1e-4, type=float) 26 | parser.add_argument('--lr_backbone', default=1e-5, type=float) 27 | parser.add_argument('--batch_size', default=1, type=int) 28 | parser.add_argument('--weight_decay', default=1e-4, type=float) 29 | parser.add_argument('--epochs', default=18, type=int) 30 | parser.add_argument('--lr_drop', default=12, type=int) 31 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 32 | help='gradient clipping max norm') 33 | 34 | # Model parameters 35 | parser.add_argument('--pretrained_weights', type=str, default="r101_pretrained.pth", 36 | help="Path to the pretrained model.") 37 | # * Backbone 38 | parser.add_argument('--backbone', default='resnet101', type=str, 39 | help="Name of the convolutional backbone to use") 40 | parser.add_argument('--dilation', action='store_true', 41 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 42 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 43 | help="Type of positional embedding to use on top of the image features") 44 | 45 | # * Transformer 46 | parser.add_argument('--enc_layers', default=6, type=int, 47 | help="Number of encoding layers in the transformer") 48 | parser.add_argument('--dec_layers', default=6, type=int, 49 | help="Number of decoding layers in the transformer") 50 | parser.add_argument('--dim_feedforward', default=2048, type=int, 51 | help="Intermediate size of the feedforward layers in the transformer blocks") 52 | parser.add_argument('--hidden_dim', default=384, type=int, 53 | help="Size of the embeddings (dimension of the transformer)") 54 | parser.add_argument('--dropout', default=0.1, type=float, 55 | help="Dropout applied in the transformer") 56 | parser.add_argument('--nheads', default=8, type=int, 57 | help="Number of attention heads inside the transformer's attentions") 58 | parser.add_argument('--num_frames', default=36, type=int, 59 | help="Number of frames") 60 | parser.add_argument('--num_queries', default=360, type=int, 61 | help="Number of query slots") 62 | parser.add_argument('--pre_norm', action='store_true') 63 | 64 | # * Segmentation 65 | parser.add_argument('--masks', action='store_true', 66 | help="Train segmentation head if the flag is provided") 67 | 68 | # Loss 69 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 70 | help="Disables auxiliary decoding losses (loss at each layer)") 71 | # * Matcher 72 | parser.add_argument('--set_cost_class', default=1, type=float, 73 | help="Class coefficient in the matching cost") 74 | parser.add_argument('--set_cost_bbox', default=5, type=float, 75 | help="L1 box coefficient in the matching cost") 76 | parser.add_argument('--set_cost_giou', default=2, type=float, 77 | help="giou box coefficient in the matching cost") 78 | # * Loss coefficients 79 | parser.add_argument('--mask_loss_coef', default=1, type=float) 80 | parser.add_argument('--dice_loss_coef', default=1, type=float) 81 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 82 | parser.add_argument('--giou_loss_coef', default=2, type=float) 83 | parser.add_argument('--eos_coef', default=0.1, type=float, 84 | help="Relative classification weight of the no-object class") 85 | 86 | # dataset parameters 87 | parser.add_argument('--dataset_file', default='ytvos') 88 | parser.add_argument('--ytvos_path', type=str) 89 | 
parser.add_argument('--remove_difficult', action='store_true') 90 | 91 | parser.add_argument('--output_dir', default='r101_vistr', 92 | help='path where to save, empty for no saving') 93 | parser.add_argument('--device', default='cuda', 94 | help='device to use for training / testing') 95 | parser.add_argument('--seed', default=42, type=int) 96 | parser.add_argument('--resume', default='', help='resume from checkpoint') 97 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 98 | help='start epoch') 99 | parser.add_argument('--eval', action='store_true') 100 | parser.add_argument('--num_workers', default=4, type=int) 101 | 102 | # distributed training parameters 103 | parser.add_argument('--world_size', default=1, type=int, 104 | help='number of distributed processes') 105 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 106 | return parser 107 | 108 | 109 | def main(args): 110 | utils.init_distributed_mode(args) 111 | print("git:\n {}\n".format(utils.get_sha())) 112 | 113 | 114 | device = torch.device(args.device) 115 | 116 | # fix the seed for reproducibility 117 | seed = args.seed + utils.get_rank() 118 | torch.manual_seed(seed) 119 | np.random.seed(seed) 120 | random.seed(seed) 121 | 122 | model, criterion, postprocessors = build_model(args) 123 | model.to(device) 124 | 125 | model_without_ddp = model 126 | if args.distributed: 127 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 128 | model_without_ddp = model.module 129 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 130 | print('number of params:', n_parameters) 131 | 132 | param_dicts = [ 133 | {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]}, 134 | { 135 | "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad], 136 | "lr": args.lr_backbone, 137 | }, 138 | ] 139 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 140 | weight_decay=args.weight_decay) 141 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 142 | 143 | # no validation ground truth for ytvos dataset 144 | dataset_train = build_dataset(image_set='train', args=args) 145 | if args.distributed: 146 | sampler_train = DistributedSampler(dataset_train) 147 | else: 148 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 149 | 150 | batch_sampler_train = torch.utils.data.BatchSampler( 151 | sampler_train, args.batch_size, drop_last=True) 152 | 153 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, 154 | collate_fn=utils.collate_fn, num_workers=args.num_workers) 155 | 156 | output_dir = Path(args.output_dir) 157 | 158 | # load coco pretrained weight 159 | checkpoint = torch.load(args.pretrained_weights, map_location='cpu')['model'] 160 | del checkpoint["vistr.class_embed.weight"] 161 | del checkpoint["vistr.class_embed.bias"] 162 | del checkpoint["vistr.query_embed.weight"] 163 | model.module.load_state_dict(checkpoint,strict=False) 164 | 165 | if args.resume: 166 | if args.resume.startswith('https'): 167 | checkpoint = torch.hub.load_state_dict_from_url( 168 | args.resume, map_location='cpu', check_hash=True) 169 | else: 170 | checkpoint = torch.load(args.resume, map_location='cpu') 171 | model_without_ddp.load_state_dict(checkpoint['model']) 172 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 173 | 
optimizer.load_state_dict(checkpoint['optimizer']) 174 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 175 | args.start_epoch = checkpoint['epoch'] + 1 176 | 177 | print("Start training") 178 | start_time = time.time() 179 | for epoch in range(args.start_epoch, args.epochs): 180 | if args.distributed: 181 | sampler_train.set_epoch(epoch) 182 | train_stats = train_one_epoch( 183 | model, criterion, data_loader_train, optimizer, device, epoch, 184 | args.clip_max_norm) 185 | lr_scheduler.step() 186 | if args.output_dir: 187 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 188 | # extra checkpoint before the LR drop and at every epoch 189 | if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 1 == 0: 190 | checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') 191 | for checkpoint_path in checkpoint_paths: 192 | utils.save_on_master({ 193 | 'model': model_without_ddp.state_dict(), 194 | 'optimizer': optimizer.state_dict(), 195 | 'lr_scheduler': lr_scheduler.state_dict(), 196 | 'epoch': epoch, 197 | 'args': args, 198 | }, checkpoint_path) 199 | 200 | 201 | total_time = time.time() - start_time 202 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 203 | print('Training time {}'.format(total_time_str)) 204 | 205 | 206 | if __name__ == '__main__': 207 | parser = argparse.ArgumentParser('VisTR training and evaluation script', parents=[get_args_parser()]) 208 | args = parser.parse_args() 209 | if args.output_dir: 210 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 211 | main(args) 212 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .vistr import build 2 | 3 | 4 | def build_model(args): 5 | return build(args) 6 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Backbone modules. 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | 14 | from util.misc import NestedTensor, is_main_process 15 | 16 | from .position_encoding import build_position_encoding 17 | 18 | 19 | class FrozenBatchNorm2d(torch.nn.Module): 20 | """ 21 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 22 | 23 | Copy-paste from torchvision.misc.ops with added eps before rsqrt, 24 | without which any other models than torchvision.models.resnet[18,34,50,101] 25 | produce nans. 
26 | """ 27 | 28 | def __init__(self, n): 29 | super(FrozenBatchNorm2d, self).__init__() 30 | self.register_buffer("weight", torch.ones(n)) 31 | self.register_buffer("bias", torch.zeros(n)) 32 | self.register_buffer("running_mean", torch.zeros(n)) 33 | self.register_buffer("running_var", torch.ones(n)) 34 | 35 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 36 | missing_keys, unexpected_keys, error_msgs): 37 | num_batches_tracked_key = prefix + 'num_batches_tracked' 38 | if num_batches_tracked_key in state_dict: 39 | del state_dict[num_batches_tracked_key] 40 | 41 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 42 | state_dict, prefix, local_metadata, strict, 43 | missing_keys, unexpected_keys, error_msgs) 44 | 45 | def forward(self, x): 46 | # move reshapes to the beginning 47 | # to make it fuser-friendly 48 | w = self.weight.reshape(1, -1, 1, 1) 49 | b = self.bias.reshape(1, -1, 1, 1) 50 | rv = self.running_var.reshape(1, -1, 1, 1) 51 | rm = self.running_mean.reshape(1, -1, 1, 1) 52 | eps = 1e-5 53 | scale = w * (rv + eps).rsqrt() 54 | bias = b - rm * scale 55 | return x * scale + bias 56 | 57 | 58 | class BackboneBase(nn.Module): 59 | 60 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 61 | super().__init__() 62 | for name, parameter in backbone.named_parameters(): 63 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 64 | parameter.requires_grad_(False) 65 | if return_interm_layers: 66 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 67 | else: 68 | return_layers = {'layer4': "0"} 69 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 70 | self.num_channels = num_channels 71 | 72 | def forward(self, tensor_list: NestedTensor): 73 | xs = self.body(tensor_list.tensors) 74 | out: Dict[str, NestedTensor] = {} 75 | for name, x in xs.items(): 76 | m = tensor_list.mask 77 | assert m is not None 78 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 79 | out[name] = NestedTensor(x, mask) 80 | return out 81 | 82 | 83 | class Backbone(BackboneBase): 84 | """ResNet backbone with frozen BatchNorm.""" 85 | def __init__(self, name: str, 86 | train_backbone: bool, 87 | return_interm_layers: bool, 88 | dilation: bool): 89 | backbone = getattr(torchvision.models, name)( 90 | replace_stride_with_dilation=[False, False, dilation], 91 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 92 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 93 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 94 | 95 | 96 | class Joiner(nn.Sequential): 97 | def __init__(self, backbone, position_embedding): 98 | super().__init__(backbone, position_embedding) 99 | 100 | def forward(self, tensor_list: NestedTensor): 101 | xs = self[0](tensor_list) 102 | out: List[NestedTensor] = [] 103 | pos = [] 104 | for name, x in xs.items(): 105 | out.append(x) 106 | # position encoding 107 | pos.append(self[1](x).to(x.tensors.dtype)) 108 | 109 | return out, pos 110 | 111 | 112 | def build_backbone(args): 113 | position_embedding = build_position_encoding(args) 114 | train_backbone = args.lr_backbone > 0 115 | return_interm_layers = args.masks 116 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 117 | model = Joiner(backbone, position_embedding) 118 | model.num_channels = backbone.num_channels 119 | return model 120 | 
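The Joiner above pairs each backbone feature map with its space-time positional encoding, and build_backbone assembles both from the command-line arguments. Below is a minimal usage sketch, not part of the repository: the argument values are illustrative (deliberately smaller than the training defaults of a resnet101 backbone and 36 frames), and it assumes the models/dcn extension has already been compiled, since importing the models package pulls in the DCN module.

# Hypothetical sketch: build the backbone + positional encoding and run it on a
# short clip. All frames of one video are stacked along the batch dimension, so
# the batch size must be a multiple of num_frames for PositionEmbeddingSine.
from argparse import Namespace

import torch

from models.backbone import build_backbone
from util.misc import nested_tensor_from_tensor_list

args = Namespace(backbone='resnet50', dilation=False, position_embedding='sine',
                 hidden_dim=384, lr_backbone=1e-5, masks=False, num_frames=4)
backbone = build_backbone(args)  # downloads ImageNet weights on first use

clip = [torch.rand(3, 300, 400) for _ in range(args.num_frames)]  # one 4-frame clip
features, pos = backbone(nested_tensor_from_tensor_list(clip))
print(features[-1].tensors.shape)  # [num_frames, 2048, H', W'] with H', W' roughly H/32, W/32
print(pos[-1].shape)               # [1, num_frames, hidden_dim, H', W']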
-------------------------------------------------------------------------------- /models/dcn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuqingWang1029/VisTR/445c9e4e787a1fb3c959d7e7bb6ecf809bdac155/models/dcn/__init__.py -------------------------------------------------------------------------------- /models/dcn/deform_conv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import math 3 | from functools import lru_cache 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | from torch.autograd.function import once_differentiable 8 | from torch.nn.modules.utils import _pair 9 | 10 | from . import _C 11 | 12 | class _NewEmptyTensorOp(torch.autograd.Function): 13 | @staticmethod 14 | def forward(ctx, x, new_shape): 15 | ctx.shape = x.shape 16 | return x.new_empty(new_shape) 17 | 18 | @staticmethod 19 | def backward(ctx, grad): 20 | shape = ctx.shape 21 | return _NewEmptyTensorOp.apply(grad, shape), None 22 | 23 | class _DeformConv(Function): 24 | @staticmethod 25 | def forward( 26 | ctx, 27 | input, 28 | offset, 29 | weight, 30 | stride=1, 31 | padding=0, 32 | dilation=1, 33 | groups=1, 34 | deformable_groups=1, 35 | im2col_step=64, 36 | ): 37 | if input is not None and input.dim() != 4: 38 | raise ValueError( 39 | "Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()) 40 | ) 41 | ctx.stride = _pair(stride) 42 | ctx.padding = _pair(padding) 43 | ctx.dilation = _pair(dilation) 44 | ctx.groups = groups 45 | ctx.deformable_groups = deformable_groups 46 | ctx.im2col_step = im2col_step 47 | 48 | ctx.save_for_backward(input, offset, weight) 49 | 50 | output = input.new_empty( 51 | _DeformConv._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride) 52 | ) 53 | 54 | ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones 55 | 56 | if not input.is_cuda: 57 | raise NotImplementedError("Deformable Conv is not supported on CPUs!") 58 | else: 59 | cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step) 60 | assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize" 61 | 62 | _C.deform_conv_forward( 63 | input, 64 | weight, 65 | offset, 66 | output, 67 | ctx.bufs_[0], 68 | ctx.bufs_[1], 69 | weight.size(3), 70 | weight.size(2), 71 | ctx.stride[1], 72 | ctx.stride[0], 73 | ctx.padding[1], 74 | ctx.padding[0], 75 | ctx.dilation[1], 76 | ctx.dilation[0], 77 | ctx.groups, 78 | ctx.deformable_groups, 79 | cur_im2col_step, 80 | ) 81 | return output 82 | 83 | @staticmethod 84 | @once_differentiable 85 | def backward(ctx, grad_output): 86 | input, offset, weight = ctx.saved_tensors 87 | 88 | grad_input = grad_offset = grad_weight = None 89 | 90 | if not grad_output.is_cuda: 91 | raise NotImplementedError("Deformable Conv is not supported on CPUs!") 92 | else: 93 | cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step) 94 | assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize" 95 | 96 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 97 | grad_input = torch.zeros_like(input) 98 | grad_offset = torch.zeros_like(offset) 99 | _C.deform_conv_backward_input( 100 | input, 101 | offset, 102 | grad_output, 103 | grad_input, 104 | grad_offset, 105 | weight, 106 | ctx.bufs_[0], 107 | weight.size(3), 108 | weight.size(2), 109 | ctx.stride[1], 110 | 
ctx.stride[0], 111 | ctx.padding[1], 112 | ctx.padding[0], 113 | ctx.dilation[1], 114 | ctx.dilation[0], 115 | ctx.groups, 116 | ctx.deformable_groups, 117 | cur_im2col_step, 118 | ) 119 | 120 | if ctx.needs_input_grad[2]: 121 | grad_weight = torch.zeros_like(weight) 122 | _C.deform_conv_backward_filter( 123 | input, 124 | offset, 125 | grad_output, 126 | grad_weight, 127 | ctx.bufs_[0], 128 | ctx.bufs_[1], 129 | weight.size(3), 130 | weight.size(2), 131 | ctx.stride[1], 132 | ctx.stride[0], 133 | ctx.padding[1], 134 | ctx.padding[0], 135 | ctx.dilation[1], 136 | ctx.dilation[0], 137 | ctx.groups, 138 | ctx.deformable_groups, 139 | 1, 140 | cur_im2col_step, 141 | ) 142 | 143 | return grad_input, grad_offset, grad_weight, None, None, None, None, None, None 144 | 145 | @staticmethod 146 | def _output_size(input, weight, padding, dilation, stride): 147 | channels = weight.size(0) 148 | output_size = (input.size(0), channels) 149 | for d in range(input.dim() - 2): 150 | in_size = input.size(d + 2) 151 | pad = padding[d] 152 | kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 153 | stride_ = stride[d] 154 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,) 155 | if not all(map(lambda s: s > 0, output_size)): 156 | raise ValueError( 157 | "convolution input is too small (output would be {})".format( 158 | "x".join(map(str, output_size)) 159 | ) 160 | ) 161 | return output_size 162 | 163 | @staticmethod 164 | @lru_cache(maxsize=128) 165 | def _cal_im2col_step(input_size, default_size): 166 | """ 167 | Calculate proper im2col step size, which should be divisible by input_size and not larger 168 | than prefer_size. Meanwhile the step size should be as large as possible to be more 169 | efficient. So we choose the largest one among all divisors of input_size which are smaller 170 | than prefer_size. 171 | :param input_size: input batch size . 172 | :param default_size: default preferred im2col step size. 173 | :return: the largest proper step size. 
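For example, _cal_im2col_step(64, 64) simply returns 64, while _cal_im2col_step(100, 64) returns 50, the largest divisor of 100 that does not exceed 64.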
174 | """ 175 | if input_size <= default_size: 176 | return input_size 177 | best_step = 1 178 | for step in range(2, min(int(math.sqrt(input_size)) + 1, default_size)): 179 | if input_size % step == 0: 180 | if input_size // step <= default_size: 181 | return input_size // step 182 | best_step = step 183 | 184 | return best_step 185 | 186 | 187 | class _ModulatedDeformConv(Function): 188 | @staticmethod 189 | def forward( 190 | ctx, 191 | input, 192 | offset, 193 | mask, 194 | weight, 195 | bias=None, 196 | stride=1, 197 | padding=0, 198 | dilation=1, 199 | groups=1, 200 | deformable_groups=1, 201 | ): 202 | ctx.stride = stride 203 | ctx.padding = padding 204 | ctx.dilation = dilation 205 | ctx.groups = groups 206 | ctx.deformable_groups = deformable_groups 207 | ctx.with_bias = bias is not None 208 | if not ctx.with_bias: 209 | bias = input.new_empty(1) # fake tensor 210 | if not input.is_cuda: 211 | raise NotImplementedError("Deformable Conv is not supported on CPUs!") 212 | if ( 213 | weight.requires_grad 214 | or mask.requires_grad 215 | or offset.requires_grad 216 | or input.requires_grad 217 | ): 218 | ctx.save_for_backward(input, offset, mask, weight, bias) 219 | output = input.new_empty(_ModulatedDeformConv._infer_shape(ctx, input, weight)) 220 | ctx._bufs = [input.new_empty(0), input.new_empty(0)] 221 | _C.modulated_deform_conv_forward( 222 | input, 223 | weight, 224 | bias, 225 | ctx._bufs[0], 226 | offset, 227 | mask, 228 | output, 229 | ctx._bufs[1], 230 | weight.shape[2], 231 | weight.shape[3], 232 | ctx.stride, 233 | ctx.stride, 234 | ctx.padding, 235 | ctx.padding, 236 | ctx.dilation, 237 | ctx.dilation, 238 | ctx.groups, 239 | ctx.deformable_groups, 240 | ctx.with_bias, 241 | ) 242 | return output 243 | 244 | @staticmethod 245 | @once_differentiable 246 | def backward(ctx, grad_output): 247 | if not grad_output.is_cuda: 248 | raise NotImplementedError("Deformable Conv is not supported on CPUs!") 249 | input, offset, mask, weight, bias = ctx.saved_tensors 250 | grad_input = torch.zeros_like(input) 251 | grad_offset = torch.zeros_like(offset) 252 | grad_mask = torch.zeros_like(mask) 253 | grad_weight = torch.zeros_like(weight) 254 | grad_bias = torch.zeros_like(bias) 255 | _C.modulated_deform_conv_backward( 256 | input, 257 | weight, 258 | bias, 259 | ctx._bufs[0], 260 | offset, 261 | mask, 262 | ctx._bufs[1], 263 | grad_input, 264 | grad_weight, 265 | grad_bias, 266 | grad_offset, 267 | grad_mask, 268 | grad_output, 269 | weight.shape[2], 270 | weight.shape[3], 271 | ctx.stride, 272 | ctx.stride, 273 | ctx.padding, 274 | ctx.padding, 275 | ctx.dilation, 276 | ctx.dilation, 277 | ctx.groups, 278 | ctx.deformable_groups, 279 | ctx.with_bias, 280 | ) 281 | if not ctx.with_bias: 282 | grad_bias = None 283 | 284 | return ( 285 | grad_input, 286 | grad_offset, 287 | grad_mask, 288 | grad_weight, 289 | grad_bias, 290 | None, 291 | None, 292 | None, 293 | None, 294 | None, 295 | ) 296 | 297 | @staticmethod 298 | def _infer_shape(ctx, input, weight): 299 | n = input.size(0) 300 | channels_out = weight.size(0) 301 | height, width = input.shape[2:4] 302 | kernel_h, kernel_w = weight.shape[2:4] 303 | height_out = ( 304 | height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1) 305 | ) // ctx.stride + 1 306 | width_out = ( 307 | width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1) 308 | ) // ctx.stride + 1 309 | return n, channels_out, height_out, width_out 310 | 311 | 312 | deform_conv = _DeformConv.apply 313 | modulated_deform_conv = _ModulatedDeformConv.apply 314 | 315 
| 316 | class DeformConv(nn.Module): 317 | def __init__( 318 | self, 319 | in_channels, 320 | out_channels, 321 | kernel_size, 322 | stride=1, 323 | padding=0, 324 | dilation=1, 325 | groups=1, 326 | deformable_groups=1, 327 | bias=False, 328 | norm=None, 329 | activation=None, 330 | ): 331 | """ 332 | Deformable convolution from :paper:`deformconv`. 333 | 334 | Arguments are similar to :class:`Conv2D`. Extra arguments: 335 | 336 | Args: 337 | deformable_groups (int): number of groups used in deformable convolution. 338 | norm (nn.Module, optional): a normalization layer 339 | activation (callable(Tensor) -> Tensor): a callable activation function 340 | """ 341 | super(DeformConv, self).__init__() 342 | 343 | assert not bias 344 | assert in_channels % groups == 0, "in_channels {} cannot be divisible by groups {}".format( 345 | in_channels, groups 346 | ) 347 | assert ( 348 | out_channels % groups == 0 349 | ), "out_channels {} cannot be divisible by groups {}".format(out_channels, groups) 350 | 351 | self.in_channels = in_channels 352 | self.out_channels = out_channels 353 | self.kernel_size = _pair(kernel_size) 354 | self.stride = _pair(stride) 355 | self.padding = _pair(padding) 356 | self.dilation = _pair(dilation) 357 | self.groups = groups 358 | self.deformable_groups = deformable_groups 359 | self.norm = norm 360 | self.activation = activation 361 | 362 | self.weight = nn.Parameter( 363 | torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size) 364 | ) 365 | self.bias = None 366 | 367 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 368 | 369 | def forward(self, x, offset): 370 | if x.numel() == 0: 371 | # When input is empty, we want to return a empty tensor with "correct" shape, 372 | # So that the following operations will not panic 373 | # if they check for the shape of the tensor. 374 | # This computes the height and width of the output tensor 375 | output_shape = [ 376 | (i + 2 * p - (di * (k - 1) + 1)) // s + 1 377 | for i, p, di, k, s in zip( 378 | x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride 379 | ) 380 | ] 381 | output_shape = [x.shape[0], self.weight.shape[0]] + output_shape 382 | return _NewEmptyTensorOp.apply(x, output_shape) 383 | 384 | x = deform_conv( 385 | x, 386 | offset, 387 | self.weight, 388 | self.stride, 389 | self.padding, 390 | self.dilation, 391 | self.groups, 392 | self.deformable_groups, 393 | ) 394 | if self.norm is not None: 395 | x = self.norm(x) 396 | if self.activation is not None: 397 | x = self.activation(x) 398 | return x 399 | 400 | def extra_repr(self): 401 | tmpstr = "in_channels=" + str(self.in_channels) 402 | tmpstr += ", out_channels=" + str(self.out_channels) 403 | tmpstr += ", kernel_size=" + str(self.kernel_size) 404 | tmpstr += ", stride=" + str(self.stride) 405 | tmpstr += ", padding=" + str(self.padding) 406 | tmpstr += ", dilation=" + str(self.dilation) 407 | tmpstr += ", groups=" + str(self.groups) 408 | tmpstr += ", deformable_groups=" + str(self.deformable_groups) 409 | tmpstr += ", bias=False" 410 | return tmpstr 411 | 412 | 413 | class ModulatedDeformConv(nn.Module): 414 | def __init__( 415 | self, 416 | in_channels, 417 | out_channels, 418 | kernel_size, 419 | stride=1, 420 | padding=0, 421 | dilation=1, 422 | groups=1, 423 | deformable_groups=1, 424 | bias=True, 425 | norm=None, 426 | activation=None, 427 | ): 428 | """ 429 | Modulated deformable convolution from :paper:`deformconv2`. 430 | 431 | Arguments are similar to :class:`Conv2D`. 
Extra arguments: 432 | 433 | Args: 434 | deformable_groups (int): number of groups used in deformable convolution. 435 | norm (nn.Module, optional): a normalization layer 436 | activation (callable(Tensor) -> Tensor): a callable activation function 437 | """ 438 | super(ModulatedDeformConv, self).__init__() 439 | self.in_channels = in_channels 440 | self.out_channels = out_channels 441 | self.kernel_size = _pair(kernel_size) 442 | self.stride = stride 443 | self.padding = padding 444 | self.dilation = dilation 445 | self.groups = groups 446 | self.deformable_groups = deformable_groups 447 | self.with_bias = bias 448 | self.norm = norm 449 | self.activation = activation 450 | 451 | self.weight = nn.Parameter( 452 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) 453 | ) 454 | if bias: 455 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 456 | else: 457 | self.bias = None 458 | 459 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") 460 | if self.bias is not None: 461 | nn.init.constant_(self.bias, 0) 462 | 463 | def forward(self, x, offset, mask): 464 | if x.numel() == 0: 465 | output_shape = [ 466 | (i + 2 * p - (di * (k - 1) + 1)) // s + 1 467 | for i, p, di, k, s in zip( 468 | x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride 469 | ) 470 | ] 471 | output_shape = [x.shape[0], self.weight.shape[0]] + output_shape 472 | return _NewEmptyTensorOp.apply(x, output_shape) 473 | 474 | x = modulated_deform_conv( 475 | x, 476 | offset, 477 | mask, 478 | self.weight, 479 | self.bias, 480 | self.stride, 481 | self.padding, 482 | self.dilation, 483 | self.groups, 484 | self.deformable_groups, 485 | ) 486 | if self.norm is not None: 487 | x = self.norm(x) 488 | if self.activation is not None: 489 | x = self.activation(x) 490 | return x 491 | 492 | def extra_repr(self): 493 | tmpstr = "in_channels=" + str(self.in_channels) 494 | tmpstr += ", out_channels=" + str(self.out_channels) 495 | tmpstr += ", kernel_size=" + str(self.kernel_size) 496 | tmpstr += ", stride=" + str(self.stride) 497 | tmpstr += ", padding=" + str(self.padding) 498 | tmpstr += ", dilation=" + str(self.dilation) 499 | tmpstr += ", groups=" + str(self.groups) 500 | tmpstr += ", deformable_groups=" + str(self.deformable_groups) 501 | tmpstr += ", bias=" + str(self.with_bias) 502 | return tmpstr 503 | -------------------------------------------------------------------------------- /models/dcn/deformable/deform_conv.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include "deform_conv.h" 3 | 4 | #if defined(WITH_CUDA) || defined(WITH_HIP) 5 | extern int get_cudart_version(); 6 | #endif 7 | 8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 9 | 10 | m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward"); 11 | m.def( 12 | "deform_conv_backward_input", 13 | &deform_conv_backward_input, 14 | "deform_conv_backward_input"); 15 | m.def( 16 | "deform_conv_backward_filter", 17 | &deform_conv_backward_filter, 18 | "deform_conv_backward_filter"); 19 | m.def( 20 | "modulated_deform_conv_forward", 21 | &modulated_deform_conv_forward, 22 | "modulated_deform_conv_forward"); 23 | m.def( 24 | "modulated_deform_conv_backward", 25 | &modulated_deform_conv_backward, 26 | "modulated_deform_conv_backward"); 27 | } -------------------------------------------------------------------------------- /models/dcn/deformable/deform_conv.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 
Facebook, Inc. and its affiliates. All Rights Reserved 2 | #pragma once 3 | #include 4 | 5 | 6 | 7 | #if defined(WITH_CUDA) || defined(WITH_HIP) 8 | int deform_conv_forward_cuda( 9 | at::Tensor input, 10 | at::Tensor weight, 11 | at::Tensor offset, 12 | at::Tensor output, 13 | at::Tensor columns, 14 | at::Tensor ones, 15 | int kW, 16 | int kH, 17 | int dW, 18 | int dH, 19 | int padW, 20 | int padH, 21 | int dilationW, 22 | int dilationH, 23 | int group, 24 | int deformable_group, 25 | int im2col_step); 26 | 27 | int deform_conv_backward_input_cuda( 28 | at::Tensor input, 29 | at::Tensor offset, 30 | at::Tensor gradOutput, 31 | at::Tensor gradInput, 32 | at::Tensor gradOffset, 33 | at::Tensor weight, 34 | at::Tensor columns, 35 | int kW, 36 | int kH, 37 | int dW, 38 | int dH, 39 | int padW, 40 | int padH, 41 | int dilationW, 42 | int dilationH, 43 | int group, 44 | int deformable_group, 45 | int im2col_step); 46 | 47 | int deform_conv_backward_parameters_cuda( 48 | at::Tensor input, 49 | at::Tensor offset, 50 | at::Tensor gradOutput, 51 | at::Tensor gradWeight, // at::Tensor gradBias, 52 | at::Tensor columns, 53 | at::Tensor ones, 54 | int kW, 55 | int kH, 56 | int dW, 57 | int dH, 58 | int padW, 59 | int padH, 60 | int dilationW, 61 | int dilationH, 62 | int group, 63 | int deformable_group, 64 | float scale, 65 | int im2col_step); 66 | 67 | void modulated_deform_conv_cuda_forward( 68 | at::Tensor input, 69 | at::Tensor weight, 70 | at::Tensor bias, 71 | at::Tensor ones, 72 | at::Tensor offset, 73 | at::Tensor mask, 74 | at::Tensor output, 75 | at::Tensor columns, 76 | int kernel_h, 77 | int kernel_w, 78 | const int stride_h, 79 | const int stride_w, 80 | const int pad_h, 81 | const int pad_w, 82 | const int dilation_h, 83 | const int dilation_w, 84 | const int group, 85 | const int deformable_group, 86 | const bool with_bias); 87 | 88 | void modulated_deform_conv_cuda_backward( 89 | at::Tensor input, 90 | at::Tensor weight, 91 | at::Tensor bias, 92 | at::Tensor ones, 93 | at::Tensor offset, 94 | at::Tensor mask, 95 | at::Tensor columns, 96 | at::Tensor grad_input, 97 | at::Tensor grad_weight, 98 | at::Tensor grad_bias, 99 | at::Tensor grad_offset, 100 | at::Tensor grad_mask, 101 | at::Tensor grad_output, 102 | int kernel_h, 103 | int kernel_w, 104 | int stride_h, 105 | int stride_w, 106 | int pad_h, 107 | int pad_w, 108 | int dilation_h, 109 | int dilation_w, 110 | int group, 111 | int deformable_group, 112 | const bool with_bias); 113 | 114 | #endif 115 | 116 | inline int deform_conv_forward( 117 | at::Tensor input, 118 | at::Tensor weight, 119 | at::Tensor offset, 120 | at::Tensor output, 121 | at::Tensor columns, 122 | at::Tensor ones, 123 | int kW, 124 | int kH, 125 | int dW, 126 | int dH, 127 | int padW, 128 | int padH, 129 | int dilationW, 130 | int dilationH, 131 | int group, 132 | int deformable_group, 133 | int im2col_step) { 134 | if (input.is_cuda()) { 135 | #if defined(WITH_CUDA) || defined(WITH_HIP) 136 | TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); 137 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); 138 | return deform_conv_forward_cuda( 139 | input, 140 | weight, 141 | offset, 142 | output, 143 | columns, 144 | ones, 145 | kW, 146 | kH, 147 | dW, 148 | dH, 149 | padW, 150 | padH, 151 | dilationW, 152 | dilationH, 153 | group, 154 | deformable_group, 155 | im2col_step); 156 | #else 157 | AT_ERROR("Not compiled with GPU support"); 158 | #endif 159 | } 160 | AT_ERROR("Not implemented on the CPU"); 161 | } 162 | 163 | inline int 
deform_conv_backward_input( 164 | at::Tensor input, 165 | at::Tensor offset, 166 | at::Tensor gradOutput, 167 | at::Tensor gradInput, 168 | at::Tensor gradOffset, 169 | at::Tensor weight, 170 | at::Tensor columns, 171 | int kW, 172 | int kH, 173 | int dW, 174 | int dH, 175 | int padW, 176 | int padH, 177 | int dilationW, 178 | int dilationH, 179 | int group, 180 | int deformable_group, 181 | int im2col_step) { 182 | if (gradOutput.is_cuda()) { 183 | #if defined(WITH_CUDA) || defined(WITH_HIP) 184 | TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!"); 185 | TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); 186 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); 187 | return deform_conv_backward_input_cuda( 188 | input, 189 | offset, 190 | gradOutput, 191 | gradInput, 192 | gradOffset, 193 | weight, 194 | columns, 195 | kW, 196 | kH, 197 | dW, 198 | dH, 199 | padW, 200 | padH, 201 | dilationW, 202 | dilationH, 203 | group, 204 | deformable_group, 205 | im2col_step); 206 | #else 207 | AT_ERROR("Not compiled with GPU support"); 208 | #endif 209 | } 210 | AT_ERROR("Not implemented on the CPU"); 211 | } 212 | 213 | inline int deform_conv_backward_filter( 214 | at::Tensor input, 215 | at::Tensor offset, 216 | at::Tensor gradOutput, 217 | at::Tensor gradWeight, // at::Tensor gradBias, 218 | at::Tensor columns, 219 | at::Tensor ones, 220 | int kW, 221 | int kH, 222 | int dW, 223 | int dH, 224 | int padW, 225 | int padH, 226 | int dilationW, 227 | int dilationH, 228 | int group, 229 | int deformable_group, 230 | float scale, 231 | int im2col_step) { 232 | if (gradOutput.is_cuda()) { 233 | #if defined(WITH_CUDA) || defined(WITH_HIP) 234 | TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!"); 235 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); 236 | return deform_conv_backward_parameters_cuda( 237 | input, 238 | offset, 239 | gradOutput, 240 | gradWeight, 241 | columns, 242 | ones, 243 | kW, 244 | kH, 245 | dW, 246 | dH, 247 | padW, 248 | padH, 249 | dilationW, 250 | dilationH, 251 | group, 252 | deformable_group, 253 | scale, 254 | im2col_step); 255 | #else 256 | AT_ERROR("Not compiled with GPU support"); 257 | #endif 258 | } 259 | AT_ERROR("Not implemented on the CPU"); 260 | } 261 | 262 | inline void modulated_deform_conv_forward( 263 | at::Tensor input, 264 | at::Tensor weight, 265 | at::Tensor bias, 266 | at::Tensor ones, 267 | at::Tensor offset, 268 | at::Tensor mask, 269 | at::Tensor output, 270 | at::Tensor columns, 271 | int kernel_h, 272 | int kernel_w, 273 | const int stride_h, 274 | const int stride_w, 275 | const int pad_h, 276 | const int pad_w, 277 | const int dilation_h, 278 | const int dilation_w, 279 | const int group, 280 | const int deformable_group, 281 | const bool with_bias) { 282 | if (input.is_cuda()) { 283 | #if defined(WITH_CUDA) || defined(WITH_HIP) 284 | TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); 285 | TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!"); 286 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); 287 | return modulated_deform_conv_cuda_forward( 288 | input, 289 | weight, 290 | bias, 291 | ones, 292 | offset, 293 | mask, 294 | output, 295 | columns, 296 | kernel_h, 297 | kernel_w, 298 | stride_h, 299 | stride_w, 300 | pad_h, 301 | pad_w, 302 | dilation_h, 303 | dilation_w, 304 | group, 305 | deformable_group, 306 | with_bias); 307 | #else 308 | AT_ERROR("Not compiled with GPU support"); 309 | #endif 310 | } 311 | AT_ERROR("Not implemented on the CPU"); 312 | } 313 | 
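// The backward wrapper below follows the same dispatch pattern as the wrappers
// above: when built with WITH_CUDA it forwards to modulated_deform_conv_cuda_backward,
// and CPU tensors always end in AT_ERROR("Not implemented on the CPU").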
314 | inline void modulated_deform_conv_backward( 315 | at::Tensor input, 316 | at::Tensor weight, 317 | at::Tensor bias, 318 | at::Tensor ones, 319 | at::Tensor offset, 320 | at::Tensor mask, 321 | at::Tensor columns, 322 | at::Tensor grad_input, 323 | at::Tensor grad_weight, 324 | at::Tensor grad_bias, 325 | at::Tensor grad_offset, 326 | at::Tensor grad_mask, 327 | at::Tensor grad_output, 328 | int kernel_h, 329 | int kernel_w, 330 | int stride_h, 331 | int stride_w, 332 | int pad_h, 333 | int pad_w, 334 | int dilation_h, 335 | int dilation_w, 336 | int group, 337 | int deformable_group, 338 | const bool with_bias) { 339 | if (grad_output.is_cuda()) { 340 | #if defined(WITH_CUDA) || defined(WITH_HIP) 341 | TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!"); 342 | TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); 343 | TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!"); 344 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); 345 | return modulated_deform_conv_cuda_backward( 346 | input, 347 | weight, 348 | bias, 349 | ones, 350 | offset, 351 | mask, 352 | columns, 353 | grad_input, 354 | grad_weight, 355 | grad_bias, 356 | grad_offset, 357 | grad_mask, 358 | grad_output, 359 | kernel_h, 360 | kernel_w, 361 | stride_h, 362 | stride_w, 363 | pad_h, 364 | pad_w, 365 | dilation_h, 366 | dilation_w, 367 | group, 368 | deformable_group, 369 | with_bias); 370 | #else 371 | AT_ERROR("Not compiled with GPU support"); 372 | #endif 373 | } 374 | AT_ERROR("Not implemented on the CPU"); 375 | } 376 | -------------------------------------------------------------------------------- /models/dcn/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | import os 4 | setup( 5 | name='deform_conv', 6 | ext_modules=[ 7 | CUDAExtension('_C', [ 8 | 'deformable/deform_conv.cpp', 9 | 'deformable/deform_conv_cuda.cu', 10 | 'deformable/deform_conv_cuda_kernel.cu'], 11 | include_dirs=["deformable"], 12 | define_macros = [("WITH_CUDA", None)], 13 | extra_compile_args={"nvcc":[ 14 | "-O3", 15 | "-DCUDA_HAS_FP16=1", 16 | "-D__CUDA_NO_HALF_OPERATORS__", 17 | "-D__CUDA_NO_HALF_CONVERSIONS__", 18 | "-D__CUDA_NO_HALF2_OPERATORS__"], 19 | "cxx":[],} 20 | ) 21 | ], 22 | cmdclass={'build_ext': BuildExtension}) -------------------------------------------------------------------------------- /models/dcn/test_deform.py: -------------------------------------------------------------------------------- 1 | from deform_conv import DeformConv 2 | import torch 3 | conv = DeformConv(10,20,3,padding=1) 4 | conv.cuda() 5 | x = torch.rand([1,10,416,416]).cuda() 6 | offset = torch.rand([1,2*9,416,416]).cuda() 7 | y = conv(x,offset) 8 | print(y.shape) 9 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instance Sequence Matching 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | 9 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou, multi_iou 10 | INF = 100000000 11 | 12 | class HungarianMatcher(nn.Module): 13 | """This class computes an assignment between the targets and the predictions of the network 14 | 15 | For efficiency reasons, the targets don't include the no_object. 
Because of this, in general, 16 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 17 | while the others are un-matched (and thus treated as non-objects). 18 | """ 19 | 20 | def __init__(self, num_frames : int = 36, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 21 | """Creates the matcher 22 | 23 | Params: 24 | cost_class: This is the relative weight of the classification error in the matching cost 25 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 26 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 27 | """ 28 | super().__init__() 29 | self.cost_class = cost_class 30 | self.cost_bbox = cost_bbox 31 | self.cost_giou = cost_giou 32 | self.num_frames = num_frames 33 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 34 | 35 | @torch.no_grad() 36 | def forward(self, outputs, targets): 37 | """ Performs the sequence level matching 38 | """ 39 | bs, num_queries = outputs["pred_logits"].shape[:2] 40 | indices = [] 41 | for i in range(bs): 42 | out_prob = outputs["pred_logits"][i].softmax(-1) 43 | out_bbox = outputs["pred_boxes"][i] 44 | tgt_ids = targets[i]["labels"] 45 | tgt_bbox = targets[i]["boxes"] 46 | tgt_valid = targets[i]["valid"] 47 | num_out = 10 48 | num_tgt = len(tgt_ids)//self.num_frames 49 | out_prob_split = out_prob.reshape(self.num_frames,num_out,out_prob.shape[-1]).permute(1,0,2) 50 | out_bbox_split = out_bbox.reshape(self.num_frames,num_out,out_bbox.shape[-1]).permute(1,0,2).unsqueeze(1) 51 | tgt_bbox_split = tgt_bbox.reshape(num_tgt,self.num_frames,4).unsqueeze(0) 52 | tgt_valid_split = tgt_valid.reshape(num_tgt,self.num_frames) 53 | frame_index = torch.arange(start=0,end=self.num_frames).repeat(num_tgt).long() 54 | class_cost = -1 * out_prob_split[:,frame_index,tgt_ids].view(num_out,num_tgt,self.num_frames).mean(dim=-1) 55 | bbox_cost = (out_bbox_split-tgt_bbox_split).abs().mean((-1,-2)) 56 | iou_cost = -1 * multi_iou(box_cxcywh_to_xyxy(out_bbox_split),box_cxcywh_to_xyxy(tgt_bbox_split)).mean(-1) 57 | #TODO: only deal with box and mask with empty target 58 | cost = self.cost_class*class_cost + self.cost_bbox*bbox_cost + self.cost_giou*iou_cost 59 | out_i, tgt_i = linear_sum_assignment(cost.cpu()) 60 | index_i,index_j = [],[] 61 | for j in range(len(out_i)): 62 | tgt_valid_ind_j = tgt_valid_split[j].nonzero().flatten() 63 | index_i.append(tgt_valid_ind_j*num_out + out_i[j]) 64 | index_j.append(tgt_valid_ind_j + tgt_i[j]* self.num_frames) 65 | if index_i==[] or index_j==[]: 66 | indices.append((torch.tensor([]).long().to(out_prob.device),torch.tensor([]).long().to(out_prob.device))) 67 | else: 68 | index_i = torch.cat(index_i).long() 69 | index_j = torch.cat(index_j).long() 70 | indices.append((index_i,index_j)) 71 | return indices 72 | 73 | def build_matcher(args): 74 | return HungarianMatcher(num_frames = args.num_frames, cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) 75 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various positional encodings for the transformer. 
3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from util.misc import NestedTensor 10 | 11 | # position encoding for 3 dims 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=64, num_frames = 36, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | self.frames = num_frames 23 | if scale is not None and normalize is False: 24 | raise ValueError("normalize should be True if scale is passed") 25 | if scale is None: 26 | scale = 2 * math.pi 27 | self.scale = scale 28 | 29 | def forward(self, tensor_list: NestedTensor): 30 | x = tensor_list.tensors 31 | mask = tensor_list.mask 32 | n,h,w = mask.shape 33 | mask = mask.reshape(n//self.frames, self.frames,h,w) 34 | assert mask is not None 35 | not_mask = ~mask 36 | z_embed = not_mask.cumsum(1, dtype=torch.float32) 37 | y_embed = not_mask.cumsum(2, dtype=torch.float32) 38 | x_embed = not_mask.cumsum(3, dtype=torch.float32) 39 | if self.normalize: 40 | eps = 1e-6 41 | z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale 42 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale 43 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale 44 | 45 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 46 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 47 | 48 | pos_x = x_embed[:, :, :, :, None] / dim_t 49 | pos_y = y_embed[:, :, :, :, None] / dim_t 50 | pos_z = z_embed[:, :, :, :, None] / dim_t 51 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 52 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 53 | pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 54 | pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3) 55 | return pos 56 | 57 | 58 | 59 | 60 | def build_position_encoding(args): 61 | N_steps = args.hidden_dim // 3 62 | if args.position_embedding in ('v2', 'sine'): 63 | # TODO find a better way of exposing other arguments 64 | position_embedding = PositionEmbeddingSine(N_steps, num_frames = args.num_frames, normalize=True) 65 | else: 66 | raise ValueError(f"not supported {args.position_embedding}") 67 | 68 | return position_embedding 69 | -------------------------------------------------------------------------------- /models/segmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instance Sequence Segmentation 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import io 6 | from collections import defaultdict 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torch import Tensor 13 | from PIL import Image 14 | from .dcn.deform_conv import DeformConv 15 | 16 | import util.box_ops as box_ops 17 | from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list 18 | 19 | try: 20 | from panopticapi.utils import id2rgb, rgb2id 21 | except ImportError: 22 | pass 23 | import time 24 | BN_MOMENTUM = 
0.1 25 | 26 | def conv3x3(in_planes, out_planes, stride=1): 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=1, bias=False) 30 | 31 | class BasicBlock(nn.Module): 32 | expansion = 1 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3(inplanes, planes, stride) 36 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3(planes, planes) 39 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | out = self.conv2(out) 49 | out = self.bn2(out) 50 | 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | 54 | out += residual 55 | out = self.relu(out) 56 | return out 57 | 58 | 59 | 60 | class VisTRsegm(nn.Module): 61 | def __init__(self, vistr, freeze_vistr=False): 62 | super().__init__() 63 | self.vistr = vistr 64 | 65 | if freeze_vistr: 66 | for p in self.parameters(): 67 | p.requires_grad_(False) 68 | 69 | hidden_dim, nheads = vistr.transformer.d_model, vistr.transformer.nhead 70 | self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) 71 | self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) 72 | self.insmask_head = nn.Sequential( 73 | nn.Conv3d(24,12,3,padding=2,dilation=2), 74 | nn.GroupNorm(4,12), 75 | nn.ReLU(), 76 | nn.Conv3d(12,12,3,padding=2,dilation=2), 77 | nn.GroupNorm(4,12), 78 | nn.ReLU(), 79 | nn.Conv3d(12,12,3,padding=2,dilation=2), 80 | nn.GroupNorm(4,12), 81 | nn.ReLU(), 82 | nn.Conv3d(12,1,1)) 83 | def forward(self, samples: NestedTensor): 84 | if not isinstance(samples, NestedTensor): 85 | samples = nested_tensor_from_tensor_list(samples) 86 | features, pos = self.vistr.backbone(samples) 87 | bs = features[-1].tensors.shape[0] 88 | src, mask = features[-1].decompose() 89 | assert mask is not None 90 | src_proj = self.vistr.input_proj(src) 91 | n,c,s_h,s_w = src_proj.shape 92 | bs_f = bs//self.vistr.num_frames 93 | src_proj = src_proj.reshape(bs_f, self.vistr.num_frames,c, s_h, s_w).permute(0,2,1,3,4).flatten(-2) 94 | mask = mask.reshape(bs_f, self.vistr.num_frames, s_h*s_w) 95 | pos = pos[-1].permute(0,2,1,3,4).flatten(-2) 96 | hs, memory = self.vistr.transformer(src_proj, mask, self.vistr.query_embed.weight, pos) 97 | outputs_class = self.vistr.class_embed(hs) 98 | outputs_coord = self.vistr.bbox_embed(hs).sigmoid() 99 | out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 100 | if self.vistr.aux_loss: 101 | out['aux_outputs'] = self.vistr._set_aux_loss(outputs_class, outputs_coord) 102 | for i in range(3): 103 | _,c_f,h,w = features[i].tensors.shape 104 | features[i].tensors = features[i].tensors.reshape(bs_f, self.vistr.num_frames, c_f, h,w) 105 | n_f = self.vistr.num_queries//self.vistr.num_frames 106 | outputs_seg_masks = [] 107 | 108 | # image level processing using box attention 109 | for i in range(self.vistr.num_frames): 110 | hs_f = hs[-1][:,i*n_f:(i+1)*n_f,:] 111 | memory_f = memory[:,:,i,:].reshape(bs_f, c, s_h,s_w) 112 | mask_f = mask[:,i,:].reshape(bs_f, s_h,s_w) 113 | bbox_mask_f = self.bbox_attention(hs_f, memory_f, mask=mask_f) 114 | seg_masks_f = self.mask_head(memory_f, bbox_mask_f, [features[2].tensors[:,i], features[1].tensors[:,i], 
features[0].tensors[:,i]]) 115 | outputs_seg_masks_f = seg_masks_f.view(bs_f, n_f, 24, seg_masks_f.shape[-2], seg_masks_f.shape[-1]) 116 | outputs_seg_masks.append(outputs_seg_masks_f) 117 | frame_masks = torch.cat(outputs_seg_masks,dim=0) 118 | outputs_seg_masks = [] 119 | 120 | # instance level processing using 3D convolution 121 | for i in range(frame_masks.size(1)): 122 | mask_ins = frame_masks[:,i].unsqueeze(0) 123 | mask_ins = mask_ins.permute(0,2,1,3,4) 124 | outputs_seg_masks.append(self.insmask_head(mask_ins)) 125 | outputs_seg_masks = torch.cat(outputs_seg_masks,1).squeeze(0).permute(1,0,2,3) 126 | outputs_seg_masks = outputs_seg_masks.reshape(1,self.vistr.num_queries,outputs_seg_masks.size(-2),outputs_seg_masks.size(-1)) 127 | out["pred_masks"] = outputs_seg_masks 128 | return out 129 | 130 | 131 | def _expand(tensor, length: int): 132 | return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) 133 | 134 | 135 | class MaskHeadSmallConv(nn.Module): 136 | """ 137 | Simple convolutional head, using group norm. 138 | Upsampling is done using a FPN approach 139 | """ 140 | 141 | def __init__(self, dim, fpn_dims, context_dim): 142 | super().__init__() 143 | 144 | inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] 145 | self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) 146 | self.gn1 = torch.nn.GroupNorm(8, dim) 147 | self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) 148 | self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) 149 | self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) 150 | self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) 151 | self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) 152 | self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) 153 | self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) 154 | self.conv_offset = torch.nn.Conv2d(inter_dims[3], 18, 1)#, bias=False) 155 | self.dcn = DeformConv(inter_dims[3],inter_dims[4], 3, padding=1) 156 | 157 | self.dim = dim 158 | 159 | self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) 160 | self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) 161 | self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) 162 | 163 | for name, m in self.named_modules(): 164 | if name == "conv_offset": 165 | nn.init.constant_(m.weight, 0) 166 | nn.init.constant_(m.bias, 0) 167 | else: 168 | if isinstance(m, nn.Conv2d): 169 | nn.init.kaiming_uniform_(m.weight, a=1) 170 | nn.init.constant_(m.bias, 0) 171 | 172 | def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]): 173 | x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) 174 | 175 | x = self.lay1(x) 176 | x = self.gn1(x) 177 | x = F.relu(x) 178 | x = self.lay2(x) 179 | x = self.gn2(x) 180 | x = F.relu(x) 181 | 182 | cur_fpn = self.adapter1(fpns[0]) 183 | if cur_fpn.size(0) != x.size(0): 184 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 185 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 186 | x = self.lay3(x) 187 | x = self.gn3(x) 188 | x = F.relu(x) 189 | 190 | cur_fpn = self.adapter2(fpns[1]) 191 | if cur_fpn.size(0) != x.size(0): 192 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 193 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 194 | x = self.lay4(x) 195 | x = self.gn4(x) 196 | x = F.relu(x) 197 | 198 | cur_fpn = self.adapter3(fpns[2]) 199 | if cur_fpn.size(0) != x.size(0): 200 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 201 
| x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 202 | # dcn for the last layer 203 | offset = self.conv_offset(x) 204 | x = self.dcn(x,offset) 205 | x = self.gn5(x) 206 | x = F.relu(x) 207 | return x 208 | 209 | 210 | class MHAttentionMap(nn.Module): 211 | """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" 212 | 213 | def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True): 214 | super().__init__() 215 | self.num_heads = num_heads 216 | self.hidden_dim = hidden_dim 217 | self.dropout = nn.Dropout(dropout) 218 | 219 | self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 220 | self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 221 | 222 | nn.init.zeros_(self.k_linear.bias) 223 | nn.init.zeros_(self.q_linear.bias) 224 | nn.init.xavier_uniform_(self.k_linear.weight) 225 | nn.init.xavier_uniform_(self.q_linear.weight) 226 | self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 227 | 228 | def forward(self, q, k, mask: Optional[Tensor] = None): 229 | q = self.q_linear(q) 230 | k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) 231 | qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) 232 | kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) 233 | weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) 234 | 235 | if mask is not None: 236 | weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) 237 | weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights) 238 | weights = self.dropout(weights) 239 | return weights 240 | 241 | 242 | def dice_loss(inputs, targets, num_boxes): 243 | """ 244 | Compute the DICE loss, similar to generalized IOU for masks 245 | Args: 246 | inputs: A float tensor of arbitrary shape. 247 | The predictions for each example. 248 | targets: A float tensor with the same shape as inputs. Stores the binary 249 | classification label for each element in inputs 250 | (0 for the negative class and 1 for the positive class). 251 | """ 252 | inputs = inputs.sigmoid() 253 | inputs = inputs.flatten(1) 254 | numerator = 2 * (inputs * targets).sum(1) 255 | denominator = inputs.sum(-1) + targets.sum(-1) 256 | loss = 1 - (numerator + 1) / (denominator + 1) 257 | return loss.sum() / num_boxes 258 | 259 | 260 | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): 261 | """ 262 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 263 | Args: 264 | inputs: A float tensor of arbitrary shape. 265 | The predictions for each example. 266 | targets: A float tensor with the same shape as inputs. Stores the binary 267 | classification label for each element in inputs 268 | (0 for the negative class and 1 for the positive class). 269 | alpha: (optional) Weighting factor in range (0,1) to balance 270 | positive vs negative examples. Default = -1 (no weighting). 271 | gamma: Exponent of the modulating factor (1 - p_t) to 272 | balance easy vs hard examples. 
273 | Returns: 274 | Loss tensor 275 | """ 276 | prob = inputs.sigmoid() 277 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 278 | p_t = prob * targets + (1 - prob) * (1 - targets) 279 | loss = ce_loss * ((1 - p_t) ** gamma) 280 | 281 | if alpha >= 0: 282 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 283 | loss = alpha_t * loss 284 | 285 | return loss.mean(1).sum() / num_boxes 286 | 287 | 288 | class PostProcessSegm(nn.Module): 289 | def __init__(self, threshold=0.5): 290 | super().__init__() 291 | self.threshold = threshold 292 | 293 | @torch.no_grad() 294 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes): 295 | assert len(orig_target_sizes) == len(max_target_sizes) 296 | max_h, max_w = max_target_sizes.max(0)[0].tolist() 297 | outputs_masks = outputs["pred_masks"].squeeze(2) 298 | outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) 299 | outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() 300 | 301 | for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): 302 | img_h, img_w = t[0], t[1] 303 | results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) 304 | results[i]["masks"] = F.interpolate( 305 | results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" 306 | ).byte() 307 | 308 | return results 309 | 310 | 311 | -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | VisTR Transformer class. 3 | 4 | Copy-paste from torch.nn.Transformer with modifications: 5 | * positional encodings are passed in MHattention 6 | * extra LN at the end of encoder is removed 7 | * decoder returns a stack of activations from all decoding layers 8 | Modified from DETR (https://github.com/facebookresearch/detr) 9 | """ 10 | import copy 11 | from typing import Optional, List 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn, Tensor 16 | 17 | 18 | class Transformer(nn.Module): 19 | 20 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 21 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 22 | activation="relu", normalize_before=False, 23 | return_intermediate_dec=False): 24 | super().__init__() 25 | 26 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 27 | dropout, activation, normalize_before) 28 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 29 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 30 | 31 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 32 | dropout, activation, normalize_before) 33 | decoder_norm = nn.LayerNorm(d_model) 34 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 35 | return_intermediate=return_intermediate_dec) 36 | 37 | self._reset_parameters() 38 | 39 | self.d_model = d_model 40 | self.nhead = nhead 41 | 42 | def _reset_parameters(self): 43 | for p in self.parameters(): 44 | if p.dim() > 1: 45 | nn.init.xavier_uniform_(p) 46 | 47 | def forward(self, src, mask, query_embed, pos_embed): 48 | # flatten NxCxHxW to HWxNxC 49 | bs, c, h, w = src.shape 50 | src = src.flatten(2).permute(2, 0, 1) 51 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 52 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 53 | mask = mask.flatten(1) 54 | 55 | tgt = 
torch.zeros_like(query_embed) 56 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 57 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 58 | pos=pos_embed, query_pos=query_embed) 59 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 60 | 61 | 62 | class TransformerEncoder(nn.Module): 63 | 64 | def __init__(self, encoder_layer, num_layers, norm=None): 65 | super().__init__() 66 | self.layers = _get_clones(encoder_layer, num_layers) 67 | self.num_layers = num_layers 68 | self.norm = norm 69 | 70 | def forward(self, src, 71 | mask: Optional[Tensor] = None, 72 | src_key_padding_mask: Optional[Tensor] = None, 73 | pos: Optional[Tensor] = None): 74 | output = src 75 | 76 | for layer in self.layers: 77 | output = layer(output, src_mask=mask, 78 | src_key_padding_mask=src_key_padding_mask, pos=pos) 79 | 80 | if self.norm is not None: 81 | output = self.norm(output) 82 | 83 | return output 84 | 85 | 86 | class TransformerDecoder(nn.Module): 87 | 88 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 89 | super().__init__() 90 | self.layers = _get_clones(decoder_layer, num_layers) 91 | self.num_layers = num_layers 92 | self.norm = norm 93 | self.return_intermediate = return_intermediate 94 | 95 | def forward(self, tgt, memory, 96 | tgt_mask: Optional[Tensor] = None, 97 | memory_mask: Optional[Tensor] = None, 98 | tgt_key_padding_mask: Optional[Tensor] = None, 99 | memory_key_padding_mask: Optional[Tensor] = None, 100 | pos: Optional[Tensor] = None, 101 | query_pos: Optional[Tensor] = None): 102 | output = tgt 103 | 104 | intermediate = [] 105 | 106 | for layer in self.layers: 107 | output = layer(output, memory, tgt_mask=tgt_mask, 108 | memory_mask=memory_mask, 109 | tgt_key_padding_mask=tgt_key_padding_mask, 110 | memory_key_padding_mask=memory_key_padding_mask, 111 | pos=pos, query_pos=query_pos) 112 | if self.return_intermediate: 113 | intermediate.append(self.norm(output)) 114 | 115 | if self.norm is not None: 116 | output = self.norm(output) 117 | if self.return_intermediate: 118 | intermediate.pop() 119 | intermediate.append(output) 120 | 121 | if self.return_intermediate: 122 | return torch.stack(intermediate) 123 | 124 | return output 125 | 126 | 127 | class TransformerEncoderLayer(nn.Module): 128 | 129 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 130 | activation="relu", normalize_before=False): 131 | super().__init__() 132 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 133 | # Implementation of Feedforward model 134 | self.linear1 = nn.Linear(d_model, dim_feedforward) 135 | self.dropout = nn.Dropout(dropout) 136 | self.linear2 = nn.Linear(dim_feedforward, d_model) 137 | 138 | self.norm1 = nn.LayerNorm(d_model) 139 | self.norm2 = nn.LayerNorm(d_model) 140 | self.dropout1 = nn.Dropout(dropout) 141 | self.dropout2 = nn.Dropout(dropout) 142 | 143 | self.activation = _get_activation_fn(activation) 144 | self.normalize_before = normalize_before 145 | 146 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 147 | return tensor if pos is None else tensor + pos 148 | 149 | def forward_post(self, 150 | src, 151 | src_mask: Optional[Tensor] = None, 152 | src_key_padding_mask: Optional[Tensor] = None, 153 | pos: Optional[Tensor] = None): 154 | q = k = self.with_pos_embed(src, pos) 155 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 156 | key_padding_mask=src_key_padding_mask)[0] 157 | src = src + self.dropout1(src2) 158 | src = 
self.norm1(src) 159 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 160 | src = src + self.dropout2(src2) 161 | src = self.norm2(src) 162 | return src 163 | 164 | def forward_pre(self, src, 165 | src_mask: Optional[Tensor] = None, 166 | src_key_padding_mask: Optional[Tensor] = None, 167 | pos: Optional[Tensor] = None): 168 | src2 = self.norm1(src) 169 | q = k = self.with_pos_embed(src2, pos) 170 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 171 | key_padding_mask=src_key_padding_mask)[0] 172 | src = src + self.dropout1(src2) 173 | src2 = self.norm2(src) 174 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 175 | src = src + self.dropout2(src2) 176 | return src 177 | 178 | def forward(self, src, 179 | src_mask: Optional[Tensor] = None, 180 | src_key_padding_mask: Optional[Tensor] = None, 181 | pos: Optional[Tensor] = None): 182 | if self.normalize_before: 183 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 184 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 185 | 186 | 187 | class TransformerDecoderLayer(nn.Module): 188 | 189 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 190 | activation="relu", normalize_before=False): 191 | super().__init__() 192 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 193 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 194 | # Implementation of Feedforward model 195 | self.linear1 = nn.Linear(d_model, dim_feedforward) 196 | self.dropout = nn.Dropout(dropout) 197 | self.linear2 = nn.Linear(dim_feedforward, d_model) 198 | 199 | self.norm1 = nn.LayerNorm(d_model) 200 | self.norm2 = nn.LayerNorm(d_model) 201 | self.norm3 = nn.LayerNorm(d_model) 202 | self.dropout1 = nn.Dropout(dropout) 203 | self.dropout2 = nn.Dropout(dropout) 204 | self.dropout3 = nn.Dropout(dropout) 205 | 206 | self.activation = _get_activation_fn(activation) 207 | self.normalize_before = normalize_before 208 | 209 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 210 | return tensor if pos is None else tensor + pos 211 | 212 | def forward_post(self, tgt, memory, 213 | tgt_mask: Optional[Tensor] = None, 214 | memory_mask: Optional[Tensor] = None, 215 | tgt_key_padding_mask: Optional[Tensor] = None, 216 | memory_key_padding_mask: Optional[Tensor] = None, 217 | pos: Optional[Tensor] = None, 218 | query_pos: Optional[Tensor] = None): 219 | q = k = self.with_pos_embed(tgt, query_pos) 220 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 221 | key_padding_mask=tgt_key_padding_mask)[0] 222 | tgt = tgt + self.dropout1(tgt2) 223 | tgt = self.norm1(tgt) 224 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 225 | key=self.with_pos_embed(memory, pos), 226 | value=memory, attn_mask=memory_mask, 227 | key_padding_mask=memory_key_padding_mask)[0] 228 | tgt = tgt + self.dropout2(tgt2) 229 | tgt = self.norm2(tgt) 230 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 231 | tgt = tgt + self.dropout3(tgt2) 232 | tgt = self.norm3(tgt) 233 | return tgt 234 | 235 | def forward_pre(self, tgt, memory, 236 | tgt_mask: Optional[Tensor] = None, 237 | memory_mask: Optional[Tensor] = None, 238 | tgt_key_padding_mask: Optional[Tensor] = None, 239 | memory_key_padding_mask: Optional[Tensor] = None, 240 | pos: Optional[Tensor] = None, 241 | query_pos: Optional[Tensor] = None): 242 | tgt2 = self.norm1(tgt) 243 | q = k = self.with_pos_embed(tgt2, query_pos) 244 | tgt2 = 
self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 245 | key_padding_mask=tgt_key_padding_mask)[0] 246 | tgt = tgt + self.dropout1(tgt2) 247 | tgt2 = self.norm2(tgt) 248 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 249 | key=self.with_pos_embed(memory, pos), 250 | value=memory, attn_mask=memory_mask, 251 | key_padding_mask=memory_key_padding_mask)[0] 252 | tgt = tgt + self.dropout2(tgt2) 253 | tgt2 = self.norm3(tgt) 254 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 255 | tgt = tgt + self.dropout3(tgt2) 256 | return tgt 257 | 258 | def forward(self, tgt, memory, 259 | tgt_mask: Optional[Tensor] = None, 260 | memory_mask: Optional[Tensor] = None, 261 | tgt_key_padding_mask: Optional[Tensor] = None, 262 | memory_key_padding_mask: Optional[Tensor] = None, 263 | pos: Optional[Tensor] = None, 264 | query_pos: Optional[Tensor] = None): 265 | if self.normalize_before: 266 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 267 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 268 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 269 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 270 | 271 | 272 | def _get_clones(module, N): 273 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 274 | 275 | 276 | def build_transformer(args): 277 | return Transformer( 278 | d_model=args.hidden_dim, 279 | dropout=args.dropout, 280 | nhead=args.nheads, 281 | dim_feedforward=args.dim_feedforward, 282 | num_encoder_layers=args.enc_layers, 283 | num_decoder_layers=args.dec_layers, 284 | normalize_before=args.pre_norm, 285 | return_intermediate_dec=True, 286 | ) 287 | 288 | 289 | def _get_activation_fn(activation): 290 | """Return an activation function given a string""" 291 | if activation == "relu": 292 | return F.relu 293 | if activation == "gelu": 294 | return F.gelu 295 | if activation == "glu": 296 | return F.glu 297 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 298 | -------------------------------------------------------------------------------- /models/vistr.py: -------------------------------------------------------------------------------- 1 | """ 2 | VisTR model and criterion classes. 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from util import box_ops 10 | from util.misc import (NestedTensor, nested_tensor_from_tensor_list, 11 | accuracy, get_world_size, interpolate, 12 | is_dist_avail_and_initialized) 13 | 14 | from .backbone import build_backbone 15 | from .matcher import build_matcher 16 | from .segmentation import (VisTRsegm, PostProcessSegm, 17 | dice_loss, sigmoid_focal_loss) 18 | from .transformer import build_transformer 19 | 20 | 21 | class VisTR(nn.Module): 22 | """ This is the VisTR module that performs video object detection """ 23 | def __init__(self, backbone, transformer, num_classes, num_frames, num_queries, aux_loss=False): 24 | """ Initializes the model. 25 | Parameters: 26 | backbone: torch module of the backbone to be used. See backbone.py 27 | transformer: torch module of the transformer architecture. See transformer.py 28 | num_classes: number of object classes 29 | num_queries: number of object queries, ie detection slot. This is the maximal number of objects 30 | VisTR can detect in a video. For ytvos, we recommend 10 queries for each frame, 31 | thus 360 queries for 36 frames. 
32 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 33 | """ 34 | super().__init__() 35 | self.num_queries = num_queries 36 | self.transformer = transformer 37 | hidden_dim = transformer.d_model 38 | self.hidden_dim = hidden_dim 39 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1) 40 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 41 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 42 | self.num_frames = num_frames 43 | self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) 44 | self.backbone = backbone 45 | self.aux_loss = aux_loss 46 | 47 | def forward(self, samples: NestedTensor): 48 | """ The forward expects a NestedTensor, which consists of: 49 | - samples.tensors: image sequences, of shape [num_frames x 3 x H x W] 50 | - samples.mask: a binary mask of shape [num_frames x H x W], containing 1 on padded pixels 51 | 52 | It returns a dict with the following elements: 53 | - "pred_logits": the classification logits (including no-object) for all queries. 54 | Shape= [batch_size x num_queries x (num_classes + 1)] 55 | - "pred_boxes": The normalized boxes coordinates for all queries, represented as 56 | (center_x, center_y, height, width). These values are normalized in [0, 1], 57 | relative to the size of each individual image (disregarding possible padding). 58 | See PostProcess for information on how to retrieve the unnormalized bounding box. 59 | - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of 60 | dictionnaries containing the two above keys for each decoder layer. 61 | """ 62 | if not isinstance(samples, NestedTensor): 63 | samples = nested_tensor_from_tensor_list(samples) 64 | # moved the frame to batch dimension for computation efficiency 65 | features, pos = self.backbone(samples) 66 | pos = pos[-1] 67 | src, mask = features[-1].decompose() 68 | src_proj = self.input_proj(src) 69 | n,c,h,w = src_proj.shape 70 | assert mask is not None 71 | src_proj = src_proj.reshape(n//self.num_frames, self.num_frames, c, h, w).permute(0,2,1,3,4).flatten(-2) 72 | mask = mask.reshape(n//self.num_frames, self.num_frames, h*w) 73 | pos = pos.permute(0,2,1,3,4).flatten(-2) 74 | hs = self.transformer(src_proj, mask, self.query_embed.weight, pos)[0] 75 | 76 | outputs_class = self.class_embed(hs) 77 | outputs_coord = self.bbox_embed(hs).sigmoid() 78 | out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} 79 | if self.aux_loss: 80 | out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) 81 | return out 82 | 83 | @torch.jit.unused 84 | def _set_aux_loss(self, outputs_class, outputs_coord): 85 | # this is a workaround to make torchscript happy, as torchscript 86 | # doesn't support dictionary with non-homogeneous values, such 87 | # as a dict having both a Tensor and a list. 88 | return [{'pred_logits': a, 'pred_boxes': b} 89 | for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] 90 | 91 | 92 | class SetCriterion(nn.Module): 93 | """ This class computes the loss for VisTR. 94 | The process happens in two steps: 95 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 96 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 97 | """ 98 | def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses): 99 | """ Create the criterion. 
100 | Parameters: 101 | num_classes: number of object categories, omitting the special no-object category 102 | matcher: module able to compute a matching between targets and proposals 103 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 104 | eos_coef: relative classification weight applied to the no-object category 105 | losses: list of all the losses to be applied. See get_loss for list of available losses. 106 | """ 107 | super().__init__() 108 | self.num_classes = num_classes 109 | self.matcher = matcher 110 | self.weight_dict = weight_dict 111 | self.eos_coef = eos_coef 112 | self.losses = losses 113 | empty_weight = torch.ones(self.num_classes + 1) 114 | empty_weight[-1] = self.eos_coef 115 | self.register_buffer('empty_weight', empty_weight) 116 | 117 | def loss_labels(self, outputs, targets, indices, num_boxes, log=True): 118 | """Classification loss (NLL) 119 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 120 | """ 121 | assert 'pred_logits' in outputs 122 | src_logits = outputs['pred_logits'] 123 | 124 | idx = self._get_src_permutation_idx(indices) 125 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) 126 | target_classes = torch.full(src_logits.shape[:2], self.num_classes, 127 | dtype=torch.int64, device=src_logits.device) 128 | target_classes[idx] = target_classes_o 129 | 130 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) 131 | losses = {'loss_ce': loss_ce} 132 | 133 | if log: 134 | # TODO this should probably be a separate loss, not hacked in this one here 135 | losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] 136 | return losses 137 | 138 | @torch.no_grad() 139 | def loss_cardinality(self, outputs, targets, indices, num_boxes): 140 | """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes 141 | This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients 142 | """ 143 | pred_logits = outputs['pred_logits'] 144 | device = pred_logits.device 145 | tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) 146 | # Count the number of predictions that are NOT "no-object" (which is the last class) 147 | card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) 148 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 149 | losses = {'cardinality_error': card_err} 150 | return losses 151 | 152 | def loss_boxes(self, outputs, targets, indices, num_boxes): 153 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss 154 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 155 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
156 | """ 157 | assert 'pred_boxes' in outputs 158 | idx = self._get_src_permutation_idx(indices) 159 | src_boxes = outputs['pred_boxes'][idx] 160 | target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) 161 | 162 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') 163 | 164 | losses = {} 165 | losses['loss_bbox'] = loss_bbox.sum() / num_boxes 166 | 167 | loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( 168 | box_ops.box_cxcywh_to_xyxy(src_boxes), 169 | box_ops.box_cxcywh_to_xyxy(target_boxes))) 170 | losses['loss_giou'] = loss_giou.sum() / num_boxes 171 | return losses 172 | 173 | def loss_masks(self, outputs, targets, indices, num_boxes): 174 | """Compute the losses related to the masks: the focal loss and the dice loss. 175 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 176 | """ 177 | assert "pred_masks" in outputs 178 | 179 | src_idx = self._get_src_permutation_idx(indices) 180 | tgt_idx = self._get_tgt_permutation_idx(indices) 181 | 182 | src_masks = outputs["pred_masks"] 183 | # TODO use valid to mask invalid areas due to padding in loss 184 | target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets], split=False).decompose() 185 | target_masks = target_masks.to(src_masks) 186 | src_masks = src_masks[src_idx] 187 | # upsample predictions to the target size 188 | try: 189 | src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:], 190 | mode="bilinear", align_corners=False) 191 | src_masks = src_masks[:, 0].flatten(1) 192 | target_masks = target_masks[tgt_idx].flatten(1) 193 | except: 194 | src_masks = src_masks.flatten(1) 195 | target_masks = src_masks.clone() 196 | losses = { 197 | "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), 198 | "loss_dice": dice_loss(src_masks, target_masks, num_boxes), 199 | } 200 | return losses 201 | 202 | def _get_src_permutation_idx(self, indices): 203 | # permute predictions following indices 204 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 205 | src_idx = torch.cat([src for (src, _) in indices]) 206 | return batch_idx, src_idx 207 | 208 | def _get_tgt_permutation_idx(self, indices): 209 | # permute targets following indices 210 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 211 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 212 | return batch_idx, tgt_idx 213 | 214 | def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): 215 | loss_map = { 216 | 'labels': self.loss_labels, 217 | 'cardinality': self.loss_cardinality, 218 | 'boxes': self.loss_boxes, 219 | 'masks': self.loss_masks 220 | } 221 | assert loss in loss_map, f'do you really want to compute {loss} loss?' 222 | return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) 223 | 224 | def forward(self, outputs, targets): 225 | """ This performs the loss computation. 226 | Parameters: 227 | outputs: dict of tensors, see the output specification of the model for the format 228 | targets: list of dicts, such that len(targets) == batch_size. 
229 | The expected keys in each dict depends on the losses applied, see each loss' doc 230 | """ 231 | outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} 232 | # Retrieve the matching between the outputs of the last layer and the targets 233 | indices = self.matcher(outputs_without_aux, targets) 234 | 235 | # Compute the average number of target boxes accross all nodes, for normalization purposes 236 | num_boxes = sum(len(t["labels"]) for t in targets) 237 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) 238 | if is_dist_avail_and_initialized(): 239 | torch.distributed.all_reduce(num_boxes) 240 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 241 | 242 | # Compute all the requested losses 243 | losses = {} 244 | for loss in self.losses: 245 | losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) 246 | 247 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 248 | if 'aux_outputs' in outputs: 249 | for i, aux_outputs in enumerate(outputs['aux_outputs']): 250 | indices = self.matcher(aux_outputs, targets) 251 | for loss in self.losses: 252 | if loss == 'masks': 253 | # Intermediate masks losses are too costly to compute, we ignore them. 254 | continue 255 | kwargs = {} 256 | if loss == 'labels': 257 | # Logging is enabled only for the last layer 258 | kwargs = {'log': False} 259 | l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) 260 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()} 261 | losses.update(l_dict) 262 | 263 | return losses 264 | 265 | 266 | class PostProcess(nn.Module): 267 | """ This module converts the model's output into the format expected by the coco api""" 268 | @torch.no_grad() 269 | def forward(self, outputs, target_sizes): 270 | """ Perform the computation 271 | Parameters: 272 | outputs: raw outputs of the model 273 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 274 | For evaluation, this must be the original image size (before any data augmentation) 275 | For visualization, this should be the image size after data augment, but before padding 276 | """ 277 | out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] 278 | 279 | assert len(out_logits) == len(target_sizes) 280 | assert target_sizes.shape[1] == 2 281 | 282 | prob = F.softmax(out_logits, -1) 283 | scores, labels = prob[..., :-1].max(-1) 284 | 285 | # convert to [x0, y0, x1, y1] format 286 | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) 287 | # and from relative [0, 1] to absolute [0, height] coordinates 288 | img_h, img_w = target_sizes.unbind(1) 289 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 290 | boxes = boxes * scale_fct[:, None, :] 291 | 292 | results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] 293 | 294 | return results 295 | 296 | 297 | class MLP(nn.Module): 298 | """ Very simple multi-layer perceptron (also called FFN)""" 299 | 300 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 301 | super().__init__() 302 | self.num_layers = num_layers 303 | h = [hidden_dim] * (num_layers - 1) 304 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 305 | 306 | def forward(self, x): 307 | for i, layer in enumerate(self.layers): 308 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 309 | return x 310 | 311 | 
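# ---------------------------------------------------------------------------
# Editor's sketch (not part of the original VisTR code): the inert helper below
# only illustrates the shape contract of the MLP box head defined above.  The
# hidden size (384) and query count (360 = 36 frames x 10 queries per frame)
# are assumptions taken from the surrounding comments and may differ from the
# values supplied via args at runtime; the function is never called anywhere.
# ---------------------------------------------------------------------------
def _mlp_box_head_sketch():
    box_head = MLP(input_dim=384, hidden_dim=384, output_dim=4, num_layers=3)
    hs = torch.randn(6, 1, 360, 384)   # [decoder layers, batch, queries, hidden_dim]
    boxes = box_head(hs).sigmoid()     # [6, 1, 360, 4], normalized (cx, cy, w, h)
    return boxes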
312 | def build(args): 313 | if args.dataset_file == "ytvos": 314 | num_classes = 41 315 | device = torch.device(args.device) 316 | 317 | backbone = build_backbone(args) 318 | 319 | transformer = build_transformer(args) 320 | 321 | model = VisTR( 322 | backbone, 323 | transformer, 324 | num_classes=num_classes, 325 | num_frames=args.num_frames, 326 | num_queries=args.num_queries, 327 | aux_loss=args.aux_loss, 328 | ) 329 | if args.masks: 330 | model = VisTRsegm(model) 331 | matcher = build_matcher(args) 332 | weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} 333 | weight_dict['loss_giou'] = args.giou_loss_coef 334 | if args.masks: 335 | weight_dict["loss_mask"] = args.mask_loss_coef 336 | weight_dict["loss_dice"] = args.dice_loss_coef 337 | # TODO this is a hack 338 | if args.aux_loss: 339 | aux_weight_dict = {} 340 | for i in range(args.dec_layers - 1): 341 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) 342 | weight_dict.update(aux_weight_dict) 343 | 344 | losses = ['labels', 'boxes', 'cardinality'] 345 | if args.masks: 346 | losses += ["masks"] 347 | criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict, 348 | eos_coef=args.eos_coef, losses=losses) 349 | criterion.to(device) 350 | postprocessors = {'bbox': PostProcess()} 351 | if args.masks: 352 | postprocessors['segm'] = PostProcessSegm() 353 | return model, criterion, postprocessors 354 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuqingWang1029/VisTR/445c9e4e787a1fb3c959d7e7bb6ecf809bdac155/util/__init__.py -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for bounding box manipulation and GIoU. 
3 | """ 4 | import torch 5 | from torchvision.ops.boxes import box_area 6 | 7 | def clip_iou(boxes1,boxes2): 8 | area1 = box_area(boxes1) 9 | area2 = box_area(boxes2) 10 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) 11 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) 12 | wh = (rb-lt).clamp(min=0) 13 | inter = wh[:,0]*wh[:,1] 14 | union = area1 + area2 - inter 15 | iou = (inter+1e-6) / (union+1e-6) 16 | # generalized version 17 | # iou=iou-(inter-union)/inter 18 | return iou 19 | 20 | def multi_iou(boxes1, boxes2): 21 | lt = torch.max(boxes1[...,:2], boxes2[...,:2]) 22 | rb = torch.min(boxes1[...,2:], boxes2[...,2:]) 23 | wh = (rb-lt).clamp(min=0) 24 | wh_1 = boxes1[...,2:] - boxes1[...,:2] 25 | wh_2 = boxes2[...,2:] - boxes2[...,:2] 26 | inter = wh[...,0] * wh[...,1] 27 | union = wh_1[...,0] * wh_1[...,1] + wh_2[...,0] * wh_2[...,1] - inter 28 | iou = (inter+1e-6) / (union+1e-6) 29 | return iou 30 | 31 | def box_cxcywh_to_xyxy(x): 32 | x_c, y_c, w, h = x.unbind(-1) 33 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 34 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 35 | return torch.stack(b, dim=-1) 36 | 37 | 38 | def box_xyxy_to_cxcywh(x): 39 | x0, y0, x1, y1 = x.unbind(-1) 40 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 41 | (x1 - x0), (y1 - y0)] 42 | return torch.stack(b, dim=-1) 43 | 44 | 45 | # modified from torchvision to also return the union 46 | def box_iou(boxes1, boxes2): 47 | area1 = box_area(boxes1) 48 | area2 = box_area(boxes2) 49 | 50 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 51 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 52 | 53 | wh = (rb - lt).clamp(min=0) # [N,M,2] 54 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 55 | 56 | union = area1[:, None] + area2 - inter 57 | 58 | iou = (inter+1e-6) / (union+1e-6) 59 | return iou, union 60 | 61 | 62 | def generalized_box_iou(boxes1, boxes2): 63 | """ 64 | Generalized IoU from https://giou.stanford.edu/ 65 | 66 | The boxes should be in [x0, y0, x1, y1] format 67 | 68 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 69 | and M = len(boxes2) 70 | """ 71 | # degenerate boxes gives inf / nan results 72 | # so do an early check 73 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 74 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 75 | iou, union = box_iou(boxes1, boxes2) 76 | 77 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 78 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 79 | 80 | wh = (rb - lt).clamp(min=0) # [N,M,2] 81 | area = wh[:, :, 0] * wh[:, :, 1] 82 | 83 | return iou - ((area - union)+1e-6) / (area+1e-6) 84 | 85 | 86 | def masks_to_boxes(masks): 87 | """Compute the bounding boxes around the provided masks 88 | 89 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
90 | 91 | Returns a [N, 4] tensors, with the boxes in xyxy format 92 | """ 93 | if masks.numel() == 0: 94 | return torch.zeros((0, 4), device=masks.device) 95 | 96 | h, w = masks.shape[-2:] 97 | 98 | y = torch.arange(0, h, dtype=torch.float) 99 | x = torch.arange(0, w, dtype=torch.float) 100 | y, x = torch.meshgrid(y, x) 101 | 102 | x_mask = (masks * x.unsqueeze(0)) 103 | x_max = x_mask.flatten(1).max(-1)[0] 104 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 105 | 106 | y_mask = (masks * y.unsqueeze(0)) 107 | y_max = y_mask.flatten(1).max(-1)[0] 108 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 109 | 110 | return torch.stack([x_min, y_min, x_max, y_max], 1) 111 | -------------------------------------------------------------------------------- /util/detr_weights_to_vistr.py: -------------------------------------------------------------------------------- 1 | ''' 2 | convert detr pretrained weights to vistr format 3 | ''' 4 | import sys 5 | import torch 6 | import collections 7 | 8 | if __name__ == "__main__": 9 | input_path = sys.argv[1] 10 | detr_weights = torch.load(input_path)['model'] 11 | vistr_weights = collections.OrderedDict() 12 | 13 | for k,v in detr_weights.items(): 14 | if k.startswith("detr"): 15 | k = k.replace("detr","vistr") 16 | vistr_weights[k]=v 17 | res = {"model":vistr_weights} 18 | 19 | torch.save(res,sys.argv[2]) 20 | 21 | 22 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers. 3 | 4 | Mostly copy-paste from torchvision references. 5 | """ 6 | import os 7 | import subprocess 8 | import time 9 | from collections import defaultdict, deque 10 | import datetime 11 | import pickle 12 | from typing import Optional, List 13 | 14 | import torch 15 | import torch.distributed as dist 16 | from torch import Tensor 17 | 18 | # needed due to empty tensor bug in pytorch and torchvision 0.5 19 | import torchvision 20 | if float(torchvision.__version__[:3]) < 0.7: 21 | from torchvision.ops import _new_empty_tensor 22 | from torchvision.ops.misc import _output_size 23 | 24 | 25 | class SmoothedValue(object): 26 | """Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 
46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 50 | dist.barrier() 51 | dist.all_reduce(t) 52 | t = t.tolist() 53 | self.count = int(t[0]) 54 | self.total = t[1] 55 | 56 | @property 57 | def median(self): 58 | d = torch.tensor(list(self.deque)) 59 | return d.median().item() 60 | 61 | @property 62 | def avg(self): 63 | d = torch.tensor(list(self.deque), dtype=torch.float32) 64 | return d.mean().item() 65 | 66 | @property 67 | def global_avg(self): 68 | return self.total / self.count 69 | 70 | @property 71 | def max(self): 72 | return max(self.deque) 73 | 74 | @property 75 | def value(self): 76 | return self.deque[-1] 77 | 78 | def __str__(self): 79 | return self.fmt.format( 80 | median=self.median, 81 | avg=self.avg, 82 | global_avg=self.global_avg, 83 | max=self.max, 84 | value=self.value) 85 | 86 | 87 | def all_gather(data): 88 | """ 89 | Run all_gather on arbitrary picklable data (not necessarily tensors) 90 | Args: 91 | data: any picklable object 92 | Returns: 93 | list[data]: list of data gathered from each rank 94 | """ 95 | world_size = get_world_size() 96 | if world_size == 1: 97 | return [data] 98 | 99 | # serialized to a Tensor 100 | buffer = pickle.dumps(data) 101 | storage = torch.ByteStorage.from_buffer(buffer) 102 | tensor = torch.ByteTensor(storage).to("cuda") 103 | 104 | # obtain Tensor size of each rank 105 | local_size = torch.tensor([tensor.numel()], device="cuda") 106 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 107 | dist.all_gather(size_list, local_size) 108 | size_list = [int(size.item()) for size in size_list] 109 | max_size = max(size_list) 110 | 111 | # receiving Tensor from all ranks 112 | # we pad the tensor because torch all_gather does not support 113 | # gathering tensors of different shapes 114 | tensor_list = [] 115 | for _ in size_list: 116 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 117 | if local_size != max_size: 118 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 119 | tensor = torch.cat((tensor, padding), dim=0) 120 | dist.all_gather(tensor_list, tensor) 121 | 122 | data_list = [] 123 | for size, tensor in zip(size_list, tensor_list): 124 | buffer = tensor.cpu().numpy().tobytes()[:size] 125 | data_list.append(pickle.loads(buffer)) 126 | 127 | return data_list 128 | 129 | 130 | def reduce_dict(input_dict, average=True): 131 | """ 132 | Args: 133 | input_dict (dict): all the values will be reduced 134 | average (bool): whether to do average or sum 135 | Reduce the values in the dictionary from all processes so that all processes 136 | have the averaged results. Returns a dict with the same fields as 137 | input_dict, after reduction. 
138 | """ 139 | world_size = get_world_size() 140 | if world_size < 2: 141 | return input_dict 142 | with torch.no_grad(): 143 | names = [] 144 | values = [] 145 | # sort the keys so that they are consistent across processes 146 | for k in sorted(input_dict.keys()): 147 | names.append(k) 148 | values.append(input_dict[k]) 149 | values = torch.stack(values, dim=0) 150 | dist.all_reduce(values) 151 | if average: 152 | values /= world_size 153 | reduced_dict = {k: v for k, v in zip(names, values)} 154 | return reduced_dict 155 | 156 | 157 | class MetricLogger(object): 158 | def __init__(self, delimiter="\t"): 159 | self.meters = defaultdict(SmoothedValue) 160 | self.delimiter = delimiter 161 | 162 | def update(self, **kwargs): 163 | for k, v in kwargs.items(): 164 | if isinstance(v, torch.Tensor): 165 | v = v.item() 166 | assert isinstance(v, (float, int)) 167 | self.meters[k].update(v) 168 | 169 | def __getattr__(self, attr): 170 | if attr in self.meters: 171 | return self.meters[attr] 172 | if attr in self.__dict__: 173 | return self.__dict__[attr] 174 | raise AttributeError("'{}' object has no attribute '{}'".format( 175 | type(self).__name__, attr)) 176 | 177 | def __str__(self): 178 | loss_str = [] 179 | for name, meter in self.meters.items(): 180 | loss_str.append( 181 | "{}: {}".format(name, str(meter)) 182 | ) 183 | return self.delimiter.join(loss_str) 184 | 185 | def synchronize_between_processes(self): 186 | for meter in self.meters.values(): 187 | meter.synchronize_between_processes() 188 | 189 | def add_meter(self, name, meter): 190 | self.meters[name] = meter 191 | 192 | def log_every(self, iterable, print_freq, header=None): 193 | i = 0 194 | if not header: 195 | header = '' 196 | start_time = time.time() 197 | end = time.time() 198 | iter_time = SmoothedValue(fmt='{avg:.4f}') 199 | data_time = SmoothedValue(fmt='{avg:.4f}') 200 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 201 | if torch.cuda.is_available(): 202 | log_msg = self.delimiter.join([ 203 | header, 204 | '[{0' + space_fmt + '}/{1}]', 205 | 'eta: {eta}', 206 | '{meters}', 207 | 'time: {time}', 208 | 'data: {data}', 209 | 'max mem: {memory:.0f}' 210 | ]) 211 | else: 212 | log_msg = self.delimiter.join([ 213 | header, 214 | '[{0' + space_fmt + '}/{1}]', 215 | 'eta: {eta}', 216 | '{meters}', 217 | 'time: {time}', 218 | 'data: {data}' 219 | ]) 220 | MB = 1024.0 * 1024.0 221 | for obj in iterable: 222 | data_time.update(time.time() - end) 223 | yield obj 224 | iter_time.update(time.time() - end) 225 | if i % print_freq == 0 or i == len(iterable) - 1: 226 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 227 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 228 | if torch.cuda.is_available(): 229 | print(log_msg.format( 230 | i, len(iterable), eta=eta_string, 231 | meters=str(self), 232 | time=str(iter_time), data=str(data_time), 233 | memory=torch.cuda.max_memory_allocated() / MB)) 234 | else: 235 | print(log_msg.format( 236 | i, len(iterable), eta=eta_string, 237 | meters=str(self), 238 | time=str(iter_time), data=str(data_time))) 239 | i += 1 240 | end = time.time() 241 | total_time = time.time() - start_time 242 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 243 | print('{} Total time: {} ({:.4f} s / it)'.format( 244 | header, total_time_str, total_time / len(iterable))) 245 | 246 | 247 | def get_sha(): 248 | cwd = os.path.dirname(os.path.abspath(__file__)) 249 | 250 | def _run(command): 251 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 
252 | sha = 'N/A' 253 | diff = "clean" 254 | branch = 'N/A' 255 | try: 256 | sha = _run(['git', 'rev-parse', 'HEAD']) 257 | subprocess.check_output(['git', 'diff'], cwd=cwd) 258 | diff = _run(['git', 'diff-index', 'HEAD']) 259 | diff = "has uncommitted changes" if diff else "clean" 260 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 261 | except Exception: 262 | pass 263 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 264 | return message 265 | 266 | 267 | def collate_fn(batch): 268 | batch = list(zip(*batch)) 269 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 270 | return tuple(batch) 271 | 272 | 273 | def _max_by_axis(the_list): 274 | # type: (List[List[int]]) -> List[int] 275 | maxes = the_list[0] 276 | for sublist in the_list[1:]: 277 | for index, item in enumerate(sublist): 278 | maxes[index] = max(maxes[index], item) 279 | return maxes 280 | 281 | 282 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor], split=True): 283 | # TODO make this more general 284 | if split: 285 | tensor_list = [tensor.split(3, dim=0) for tensor in tensor_list]  # split each stacked clip (3*T, H, W) into per-frame 3-channel images 286 | tensor_list = [item for sublist in tensor_list for item in sublist] 287 | if tensor_list[0].ndim == 3: 288 | # TODO make it support different-sized images 289 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 290 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 291 | batch_shape = [len(tensor_list)] + max_size 292 | b, c, h, w = batch_shape 293 | dtype = tensor_list[0].dtype 294 | device = tensor_list[0].device 295 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 296 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 297 | for img, pad_img, m in zip(tensor_list, tensor, mask): 298 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 299 | m[: img.shape[1], :img.shape[2]] = False 300 | else: 301 | raise ValueError('not supported') 302 | return NestedTensor(tensor, mask) 303 | 304 | 305 | class NestedTensor(object): 306 | def __init__(self, tensors, mask: Optional[Tensor]): 307 | self.tensors = tensors 308 | self.mask = mask 309 | 310 | def to(self, device): 311 | # type: (Device) -> NestedTensor # noqa 312 | cast_tensor = self.tensors.to(device) 313 | mask = self.mask 314 | if mask is not None: 315 | assert mask is not None 316 | cast_mask = mask.to(device) 317 | else: 318 | cast_mask = None 319 | return NestedTensor(cast_tensor, cast_mask) 320 | 321 | def decompose(self): 322 | return self.tensors, self.mask 323 | 324 | def __repr__(self): 325 | return str(self.tensors) 326 | 327 | 328 | def setup_for_distributed(is_master): 329 | """ 330 | This function disables printing when not in the master process 331 | """ 332 | import builtins as __builtin__ 333 | builtin_print = __builtin__.print 334 | 335 | def print(*args, **kwargs): 336 | force = kwargs.pop('force', False) 337 | if is_master or force: 338 | builtin_print(*args, **kwargs) 339 | 340 | __builtin__.print = print 341 | 342 | 343 | def is_dist_avail_and_initialized(): 344 | if not dist.is_available(): 345 | return False 346 | if not dist.is_initialized(): 347 | return False 348 | return True 349 | 350 | 351 | def get_world_size(): 352 | if not is_dist_avail_and_initialized(): 353 | return 1 354 | return dist.get_world_size() 355 | 356 | 357 | def get_rank(): 358 | if not is_dist_avail_and_initialized(): 359 | return 0 360 | return dist.get_rank() 361 | 362 | 363 | def is_main_process(): 364 | return get_rank() == 0 365 | 366 | 367 | def save_on_master(*args,
**kwargs): 368 | if is_main_process(): 369 | torch.save(*args, **kwargs, _use_new_zipfile_serialization=False) 370 | 371 | 372 | def init_distributed_mode(args): 373 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 374 | args.rank = int(os.environ["RANK"]) 375 | args.world_size = int(os.environ['WORLD_SIZE']) 376 | args.gpu = int(os.environ['LOCAL_RANK']) 377 | elif 'SLURM_PROCID' in os.environ: 378 | args.rank = int(os.environ['SLURM_PROCID']) 379 | args.gpu = args.rank % torch.cuda.device_count() 380 | else: 381 | print('Not using distributed mode') 382 | args.distributed = False 383 | return 384 | 385 | args.distributed = True 386 | 387 | torch.cuda.set_device(args.gpu) 388 | args.dist_backend = 'nccl' 389 | print('| distributed init (rank {}): {}'.format( 390 | args.rank, args.dist_url), flush=True) 391 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 392 | world_size=args.world_size, rank=args.rank) 393 | torch.distributed.barrier() 394 | setup_for_distributed(args.rank == 0) 395 | 396 | 397 | @torch.no_grad() 398 | def accuracy(output, target, topk=(1,)): 399 | """Computes the precision@k for the specified values of k""" 400 | if target.numel() == 0: 401 | return [torch.zeros([], device=output.device)] 402 | maxk = max(topk) 403 | batch_size = target.size(0) 404 | 405 | _, pred = output.topk(maxk, 1, True, True) 406 | pred = pred.t() 407 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 408 | 409 | res = [] 410 | for k in topk: 411 | correct_k = correct[:k].view(-1).float().sum(0) 412 | res.append(correct_k.mul_(100.0 / batch_size)) 413 | return res 414 | 415 | 416 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 417 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 418 | """ 419 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 420 | This will eventually be supported natively by PyTorch, and this 421 | class can go away. 422 | """ 423 | if float(torchvision.__version__[:3]) < 0.7: 424 | if input.numel() > 0: 425 | return torch.nn.functional.interpolate( 426 | input, size, scale_factor, mode, align_corners 427 | ) 428 | 429 | output_shape = _output_size(2, input, size, scale_factor) 430 | output_shape = list(input.shape[:-2]) + list(output_shape) 431 | return _new_empty_tensor(input, output_shape) 432 | else: 433 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 434 | -------------------------------------------------------------------------------- /util/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotting utilities to visualize training logs. 3 | """ 4 | import torch 5 | import pandas as pd 6 | import seaborn as sns 7 | import matplotlib.pyplot as plt 8 | 9 | from pathlib import Path, PurePath 10 | 11 | 12 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 13 | ''' 14 | Function to plot specific fields from training log(s). Plots both training and test results. 15 | 16 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 17 | - fields = which results to plot from each log file - plots both training and test for each field. 18 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 19 | - log_name = optional, name of log file if different than default 'log.txt'. 
20 | 21 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 22 | - solid lines are training results, dashed lines are test results. 23 | 24 | ''' 25 | func_name = "plot_utils.py::plot_logs" 26 | 27 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 28 | # convert single Path to list to avoid 'not iterable' error 29 | 30 | if not isinstance(logs, list): 31 | if isinstance(logs, PurePath): 32 | logs = [logs] 33 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 34 | else: 35 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 36 | Expect list[Path] or single Path obj, received {type(logs)}") 37 | 38 | # verify valid dir(s) and that every item in list is Path object 39 | for i, dir in enumerate(logs): 40 | if not isinstance(dir, PurePath): 41 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 42 | if dir.exists(): 43 | continue 44 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 45 | 46 | # load log file(s) and plot 47 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 48 | 49 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 50 | 51 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 52 | for j, field in enumerate(fields): 53 | if field == 'mAP': 54 | coco_eval = pd.DataFrame([stats[1] for stats in df.test_coco_eval.dropna().values]).ewm(com=ewm_col).mean()  # stats[1] is AP@IoU=0.5; avoids pd.np, which recent pandas no longer provides 55 | axs[j].plot(coco_eval, c=color) 56 | else: 57 | df.interpolate().ewm(com=ewm_col).mean().plot( 58 | y=[f'train_{field}', f'test_{field}'], 59 | ax=axs[j], 60 | color=[color] * 2, 61 | style=['-', '--'] 62 | ) 63 | for ax, field in zip(axs, fields): 64 | ax.legend([Path(p).name for p in logs]) 65 | ax.set_title(field) 66 | 67 | 68 | def plot_precision_recall(files, naming_scheme='iter'): 69 | if naming_scheme == 'exp_id': 70 | # name becomes exp_id 71 | names = [f.parts[-3] for f in files] 72 | elif naming_scheme == 'iter': 73 | names = [f.stem for f in files] 74 | else: 75 | raise ValueError(f'not supported {naming_scheme}') 76 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 77 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 78 | data = torch.load(f) 79 | # precision is n_iou, n_points, n_cat, n_area, max_det 80 | precision = data['precision'] 81 | recall = data['params'].recThrs 82 | scores = data['scores'] 83 | # take precision for all classes, all areas and 100 detections 84 | precision = precision[0, :, :, 0, -1].mean(1) 85 | scores = scores[0, :, :, 0, -1].mean(1) 86 | prec = precision.mean() 87 | rec = data['recall'][0, :, 0, -1].mean() 88 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 89 | f'score={scores.mean():0.3f}, ' + 90 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 91 | ) 92 | axs[0].plot(recall, precision, c=color) 93 | axs[1].plot(recall, scores, c=color) 94 | 95 | axs[0].set_title('Precision / Recall') 96 | axs[0].legend(names) 97 | axs[1].set_title('Scores / Recall') 98 | axs[1].legend(names) 99 | return fig, axs 100 | --------------------------------------------------------------------------------
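A minimal usage sketch for the two plotting helpers above, assuming training has already produced a 'log.txt' per output directory and per-checkpoint COCO-style eval dumps saved with torch.save; every directory and file name below is hypothetical:

    from pathlib import Path
    from util.plot_utils import plot_logs, plot_precision_recall

    # Hypothetical run directories, each expected to contain a log.txt with
    # train_<field> / test_<field> columns for the requested fields.
    log_dirs = [Path('outputs/r50_run'), Path('outputs/r101_run')]
    plot_logs(log_dirs, fields=('loss', 'class_error'), ewm_col=5)

    # Hypothetical eval dumps: dicts holding 'precision', 'recall', 'scores'
    # and 'params' (with recThrs), as produced by a COCO-style evaluator.
    eval_files = [Path('outputs/r50_run/eval/001.pth'), Path('outputs/r50_run/eval/002.pth')]
    fig, axs = plot_precision_recall(eval_files, naming_scheme='iter')
    fig.savefig('pr_curves.png')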