├── LICENSE
├── README.md
├── datasets
│   ├── __init__.py
│   ├── coco.py
│   ├── coco_eval.py
│   ├── coco_panoptic.py
│   ├── panoptic_eval.py
│   ├── transforms.py
│   └── ytvos.py
├── engine.py
├── inference.py
├── main.py
├── models
│   ├── __init__.py
│   ├── backbone.py
│   ├── dcn
│   │   ├── __init__.py
│   │   ├── deform_conv.py
│   │   ├── deformable
│   │   │   ├── deform_conv.cpp
│   │   │   ├── deform_conv.h
│   │   │   ├── deform_conv_cuda.cu
│   │   │   └── deform_conv_cuda_kernel.cu
│   │   ├── setup.py
│   │   └── test_deform.py
│   ├── matcher.py
│   ├── position_encoding.py
│   ├── segmentation.py
│   ├── transformer.py
│   └── vistr.py
└── util
    ├── __init__.py
    ├── box_ops.py
    ├── detr_weights_to_vistr.py
    ├── misc.py
    └── plot_utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2020 - present, Facebook, Inc
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## VisTR: End-to-End Video Instance Segmentation with Transformers
2 |
3 | This is the official implementation of the [VisTR paper](https://arxiv.org/abs/2011.14503).
4 |
10 | ### Installation
11 | We provide instructions on how to install dependencies via conda.
12 | First, clone the repository locally:
13 | ```
14 | git clone https://github.com/Epiphqny/vistr.git
15 | ```
16 | Then, install PyTorch 1.6 and torchvision 0.7:
17 | ```
18 | conda install pytorch==1.6.0 torchvision==0.7.0
19 | ```
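If conda cannot resolve these packages from the default channels, installing from the official `pytorch` channel with an explicit CUDA toolkit should also work; the pinned toolkit version below (10.1) is only an example chosen to satisfy the CUDA >= 10.0 requirement of the DCN module:
```
conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.1 -c pytorch
```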
20 | Install pycocotools (the second command below installs the YouTube-VOS fork of the COCO API, which provides the `ytvos` interface used by the data loader):
21 | ```
22 | conda install cython scipy
23 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
24 | pip install git+https://github.com/youtubevos/cocoapi.git#"egg=pycocotools&subdirectory=PythonAPI"
25 | ```
26 | Compile the DCN module (requires GCC >= 5.3 and CUDA >= 10.0):
27 | ```
28 | cd models/dcn
29 | python setup.py build_ext --inplace
30 | ```
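If the extension compiles successfully, the bundled test script in the same directory can serve as a quick sanity check (only a suggestion; it assumes a CUDA-capable GPU is visible and that the script runs without arguments):
```
python test_deform.py
```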
31 |
32 | ### Preparation
33 |
34 | Download and extract the 2019 version of the YouTubeVIS train and val images, together with the annotations, from
35 | [CodaLab](https://competitions.codalab.org/competitions/20128#participate-get_data) or [YouTubeVIS](https://youtube-vos.org/dataset/vis/).
36 | We expect the directory structure to be the following:
37 | ```
38 | VisTR
39 | ├── data
40 | │   ├── train
41 | │   ├── val
42 | │   ├── annotations
43 | │   │   ├── instances_train_sub.json
44 | │   │   ├── instances_val_sub.json
45 | ├── models
46 | ...
47 | ```
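To quickly confirm that the annotation files are readable by the YouTube-VOS COCO API fork (the same `YTVOS` class used in `datasets/ytvos.py`), the following one-liner can be run from the repository root, assuming the layout above:
```
python -c "from pycocotools.ytvos import YTVOS; YTVOS('data/annotations/instances_train_sub.json')"
```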
48 |
49 | Download the DETR models pretrained on COCO from [Google Drive](https://drive.google.com/drive/folders/1DlN8uWHT2WaKruarGW2_XChhpZeI9MFG?usp=sharing) or [BaiduYun](https://pan.baidu.com/s/12omUNDRjhAeGZ5olqQPpHA) (passcode: alge) and save them to the pretrained path.
50 |
51 |
52 | ### Training
53 |
54 | Training the model requires a GPU with at least 32 GB of memory; we performed the experiments on 32 GB V100 cards. (The training resolution is limited by GPU memory. If you have a GPU with more memory and want to run the experiment at a higher resolution, please contact me, thanks very much.)
55 |
56 | To train baseline VisTR on a single node with 8 GPUs for 18 epochs, run:
57 | ```
58 | python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py --backbone resnet101/50 --ytvos_path /path/to/ytvos --masks --pretrained_weights /path/to/pretrained_path
59 | ```
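For quick debugging on a single GPU, the same entry point can also be launched without the distributed wrapper. This is only a sketch: it keeps the flags from the command above and assumes the remaining arguments are left at their defaults and that the reduced batch fits in memory:
```
python main.py --backbone resnet50 --ytvos_path /path/to/ytvos --masks --pretrained_weights /path/to/pretrained_path
```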
60 |
61 | ### Inference
62 |
63 | ```
64 | python inference.py --masks --model_path /path/to/model_weights --save_path /path/to/results.json
65 | ```
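The CodaLab evaluation server accepts the predictions as a zip archive; packing the output file produced above should be enough (assuming the server expects the json at the archive root, as the result json zips in the table below suggest):
```
zip results.zip results.json
```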
66 |
67 | ### Models
68 |
69 | We provide baseline VisTR models and plan to include more in the future. AP is computed on the YouTubeVIS dataset by submitting the result json file to the [CodaLab](https://competitions.codalab.org/competitions/20128#results) evaluation server, and FPS is measured from pure model inference time (without data loading and post-processing).
70 |
71 |
72 | |   | name  | backbone | FPS  | mask AP | model          | result json zip | detailed AP |
73 | |---|-------|----------|------|---------|----------------|-----------------|-------------|
74 | | 0 | VisTR | R50      | 69.9 | 36.2    | vistr_r50.pth  | vistr_r50.zip   |             |
75 | | 1 | VisTR | R101     | 57.7 | 40.1    | vistr_r101.pth | vistr_r101.zip  |             |
115 |
116 | ### License
117 |
118 | VisTR is released under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information.
119 |
120 | ### Acknowledgement
121 | We would like to thank the [DETR](https://github.com/facebookresearch/detr) open-source project for its awesome work; part of our code is modified from it.
122 |
123 | ### Citation
124 |
125 | Please consider citing our paper in your publications if the project helps your research. The BibTeX reference is as follows.
126 |
127 | ```
128 | @inproceedings{wang2020end,
129 | title={End-to-End Video Instance Segmentation with Transformers},
130 | author={Wang, Yuqing and Xu, Zhaoliang and Wang, Xinlong and Shen, Chunhua and Cheng, Baoshan and Shen, Hao and Xia, Huaxia},
131 | booktitle = {Proc. IEEE Conf. Computer Vision and Pattern Recognition (CVPR)},
132 | year={2021}
133 | }
134 | ```
135 |
136 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | import torchvision
3 |
4 | from .coco import build as build_coco
5 | from .ytvos import build as build_ytvos
6 |
7 | def get_coco_api_from_dataset(dataset):
8 | for _ in range(10):
9 | # if isinstance(dataset, torchvision.datasets.CocoDetection):
10 | # break
11 | if isinstance(dataset, torch.utils.data.Subset):
12 | dataset = dataset.dataset
13 | if isinstance(dataset, torchvision.datasets.CocoDetection):
14 | return dataset.coco
15 |
16 |
17 | def build_dataset(image_set, args):
18 | if args.dataset_file == 'coco':
19 | return build_coco(image_set, args)
20 | if args.dataset_file == 'coco_panoptic':
21 | # to avoid making panopticapi required for coco
22 | from .coco_panoptic import build as build_coco_panoptic
23 | return build_coco_panoptic(image_set, args)
24 | if args.dataset_file == 'ytvos':
25 | return build_ytvos(image_set, args)
26 | raise ValueError(f'dataset {args.dataset_file} not supported')
27 |
--------------------------------------------------------------------------------
/datasets/coco.py:
--------------------------------------------------------------------------------
1 | """
2 | COCO dataset which returns image_id for evaluation.
3 |
4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
5 | """
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.utils.data
10 | import torchvision
11 | from pycocotools import mask as coco_mask
12 |
13 | import datasets.transforms as T
14 |
15 |
16 | class CocoDetection(torchvision.datasets.CocoDetection):
17 | def __init__(self, img_folder, ann_file, transforms, return_masks):
18 | super(CocoDetection, self).__init__(img_folder, ann_file)
19 | self._transforms = transforms
20 | self.prepare = ConvertCocoPolysToMask(return_masks)
21 |
22 | def __getitem__(self, idx):
23 | img, target = super(CocoDetection, self).__getitem__(idx)
24 | image_id = self.ids[idx]
25 | target = {'image_id': image_id, 'annotations': target}
26 |
27 | img, target = self.prepare(img, target)
28 | if self._transforms is not None:
29 | img, target = self._transforms(img, target)
30 | return img, target
31 |
32 |
33 | def convert_coco_poly_to_mask(segmentations, height, width):
34 | masks = []
35 | for polygons in segmentations:
36 | rles = coco_mask.frPyObjects(polygons, height, width)
37 | mask = coco_mask.decode(rles)
38 | if len(mask.shape) < 3:
39 | mask = mask[..., None]
40 | mask = torch.as_tensor(mask, dtype=torch.uint8)
41 | mask = mask.any(dim=2)
42 | masks.append(mask)
43 | if masks:
44 | masks = torch.stack(masks, dim=0)
45 | else:
46 | masks = torch.zeros((0, height, width), dtype=torch.uint8)
47 | return masks
48 |
49 |
50 | class ConvertCocoPolysToMask(object):
51 | def __init__(self, return_masks=False):
52 | self.return_masks = return_masks
53 |
54 | def __call__(self, image, target):
55 | w, h = image.size
56 |
57 | image_id = target["image_id"]
58 | image_id = torch.tensor([image_id])
59 |
60 | anno = target["annotations"]
61 |
62 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
63 |
64 | boxes = [obj["bbox"] for obj in anno]
65 | # guard against no boxes via resizing
66 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
67 | boxes[:, 2:] += boxes[:, :2]
68 | boxes[:, 0::2].clamp_(min=0, max=w)
69 | boxes[:, 1::2].clamp_(min=0, max=h)
70 |
71 | classes = [obj["category_id"] for obj in anno]
72 | classes = torch.tensor(classes, dtype=torch.int64)
73 |
74 | if self.return_masks:
75 | segmentations = [obj["segmentation"] for obj in anno]
76 | masks = convert_coco_poly_to_mask(segmentations, h, w)
77 |
78 | keypoints = None
79 | if anno and "keypoints" in anno[0]:
80 | keypoints = [obj["keypoints"] for obj in anno]
81 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
82 | num_keypoints = keypoints.shape[0]
83 | if num_keypoints:
84 | keypoints = keypoints.view(num_keypoints, -1, 3)
85 |
86 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
87 | boxes = boxes[keep]
88 | classes = classes[keep]
89 | if self.return_masks:
90 | masks = masks[keep]
91 | if keypoints is not None:
92 | keypoints = keypoints[keep]
93 |
94 | target = {}
95 | target["boxes"] = boxes
96 | target["labels"] = classes
97 | if self.return_masks:
98 | target["masks"] = masks
99 | target["image_id"] = image_id
100 | if keypoints is not None:
101 | target["keypoints"] = keypoints
102 |
103 | # for conversion to coco api
104 | area = torch.tensor([obj["area"] for obj in anno])
105 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
106 | target["area"] = area[keep]
107 | target["iscrowd"] = iscrowd[keep]
108 |
109 | target["orig_size"] = torch.as_tensor([int(h), int(w)])
110 | target["size"] = torch.as_tensor([int(h), int(w)])
111 |
112 | return image, target
113 |
114 |
115 | def make_coco_transforms(image_set):
116 |
117 | normalize = T.Compose([
118 | T.ToTensor(),
119 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
120 | ])
121 |
122 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
123 |
124 | if image_set == 'train':
125 | return T.Compose([
126 | T.RandomHorizontalFlip(),
127 | T.RandomSelect(
128 | T.RandomResize(scales, max_size=1333),
129 | T.Compose([
130 | T.RandomResize([400, 500, 600]),
131 | T.RandomSizeCrop(384, 600),
132 | T.RandomResize(scales, max_size=1333),
133 | ])
134 | ),
135 | normalize,
136 | ])
137 |
138 | if image_set == 'val':
139 | return T.Compose([
140 | T.RandomResize([800], max_size=1333),
141 | normalize,
142 | ])
143 |
144 | raise ValueError(f'unknown {image_set}')
145 |
146 |
147 | def build(image_set, args):
148 | root = Path(args.coco_path)
149 | assert root.exists(), f'provided COCO path {root} does not exist'
150 | mode = 'instances'
151 | PATHS = {
152 | "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
153 | "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
154 | }
155 | img_folder, ann_file = PATHS[image_set]
156 | dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
157 | return dataset
158 |
--------------------------------------------------------------------------------
/datasets/coco_eval.py:
--------------------------------------------------------------------------------
1 | """
2 | COCO evaluator that works in distributed mode.
3 |
4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
5 | The difference is that there is less copy-pasting from pycocotools
6 | in the end of the file, as python3 can suppress prints with contextlib
7 | """
8 | import os
9 | import contextlib
10 | import copy
11 | import numpy as np
12 | import torch
13 |
14 | from pycocotools.cocoeval import COCOeval
15 | from pycocotools.coco import COCO
16 | import pycocotools.mask as mask_util
17 |
18 | from util.misc import all_gather
19 |
20 |
21 | class CocoEvaluator(object):
22 | def __init__(self, coco_gt, iou_types):
23 | assert isinstance(iou_types, (list, tuple))
24 | coco_gt = copy.deepcopy(coco_gt)
25 | self.coco_gt = coco_gt
26 |
27 | self.iou_types = iou_types
28 | self.coco_eval = {}
29 | for iou_type in iou_types:
30 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
31 |
32 | self.img_ids = []
33 | self.eval_imgs = {k: [] for k in iou_types}
34 |
35 | def update(self, predictions):
36 | img_ids = list(np.unique(list(predictions.keys())))
37 | self.img_ids.extend(img_ids)
38 |
39 | for iou_type in self.iou_types:
40 | results = self.prepare(predictions, iou_type)
41 |
42 | # suppress pycocotools prints
43 | with open(os.devnull, 'w') as devnull:
44 | with contextlib.redirect_stdout(devnull):
45 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
46 | coco_eval = self.coco_eval[iou_type]
47 |
48 | coco_eval.cocoDt = coco_dt
49 | coco_eval.params.imgIds = list(img_ids)
50 | img_ids, eval_imgs = evaluate(coco_eval)
51 |
52 | self.eval_imgs[iou_type].append(eval_imgs)
53 |
54 | def synchronize_between_processes(self):
55 | for iou_type in self.iou_types:
56 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
57 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
58 |
59 | def accumulate(self):
60 | for coco_eval in self.coco_eval.values():
61 | coco_eval.accumulate()
62 |
63 | def summarize(self):
64 | for iou_type, coco_eval in self.coco_eval.items():
65 | print("IoU metric: {}".format(iou_type))
66 | coco_eval.summarize()
67 |
68 | def prepare(self, predictions, iou_type):
69 | if iou_type == "bbox":
70 | return self.prepare_for_coco_detection(predictions)
71 | elif iou_type == "segm":
72 | return self.prepare_for_coco_segmentation(predictions)
73 | elif iou_type == "keypoints":
74 | return self.prepare_for_coco_keypoint(predictions)
75 | else:
76 | raise ValueError("Unknown iou type {}".format(iou_type))
77 |
78 | def prepare_for_coco_detection(self, predictions):
79 | coco_results = []
80 | for original_id, prediction in predictions.items():
81 | if len(prediction) == 0:
82 | continue
83 |
84 | boxes = prediction["boxes"]
85 | boxes = convert_to_xywh(boxes).tolist()
86 | scores = prediction["scores"].tolist()
87 | labels = prediction["labels"].tolist()
88 |
89 | coco_results.extend(
90 | [
91 | {
92 | "image_id": original_id,
93 | "category_id": labels[k],
94 | "bbox": box,
95 | "score": scores[k],
96 | }
97 | for k, box in enumerate(boxes)
98 | ]
99 | )
100 | return coco_results
101 |
102 | def prepare_for_coco_segmentation(self, predictions):
103 | coco_results = []
104 | for original_id, prediction in predictions.items():
105 | if len(prediction) == 0:
106 | continue
107 |
108 | scores = prediction["scores"]
109 | labels = prediction["labels"]
110 | masks = prediction["masks"]
111 |
112 | masks = masks > 0.5
113 |
114 | scores = prediction["scores"].tolist()
115 | labels = prediction["labels"].tolist()
116 |
117 | rles = [
118 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
119 | for mask in masks
120 | ]
121 | for rle in rles:
122 | rle["counts"] = rle["counts"].decode("utf-8")
123 |
124 | coco_results.extend(
125 | [
126 | {
127 | "image_id": original_id,
128 | "category_id": labels[k],
129 | "segmentation": rle,
130 | "score": scores[k],
131 | }
132 | for k, rle in enumerate(rles)
133 | ]
134 | )
135 | return coco_results
136 |
137 | def prepare_for_coco_keypoint(self, predictions):
138 | coco_results = []
139 | for original_id, prediction in predictions.items():
140 | if len(prediction) == 0:
141 | continue
142 |
143 | boxes = prediction["boxes"]
144 | boxes = convert_to_xywh(boxes).tolist()
145 | scores = prediction["scores"].tolist()
146 | labels = prediction["labels"].tolist()
147 | keypoints = prediction["keypoints"]
148 | keypoints = keypoints.flatten(start_dim=1).tolist()
149 |
150 | coco_results.extend(
151 | [
152 | {
153 | "image_id": original_id,
154 | "category_id": labels[k],
155 | 'keypoints': keypoint,
156 | "score": scores[k],
157 | }
158 | for k, keypoint in enumerate(keypoints)
159 | ]
160 | )
161 | return coco_results
162 |
163 |
164 | def convert_to_xywh(boxes):
165 | xmin, ymin, xmax, ymax = boxes.unbind(1)
166 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
167 |
168 |
169 | def merge(img_ids, eval_imgs):
170 | all_img_ids = all_gather(img_ids)
171 | all_eval_imgs = all_gather(eval_imgs)
172 |
173 | merged_img_ids = []
174 | for p in all_img_ids:
175 | merged_img_ids.extend(p)
176 |
177 | merged_eval_imgs = []
178 | for p in all_eval_imgs:
179 | merged_eval_imgs.append(p)
180 |
181 | merged_img_ids = np.array(merged_img_ids)
182 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
183 |
184 | # keep only unique (and in sorted order) images
185 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
186 | merged_eval_imgs = merged_eval_imgs[..., idx]
187 |
188 | return merged_img_ids, merged_eval_imgs
189 |
190 |
191 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
192 | img_ids, eval_imgs = merge(img_ids, eval_imgs)
193 | img_ids = list(img_ids)
194 | eval_imgs = list(eval_imgs.flatten())
195 |
196 | coco_eval.evalImgs = eval_imgs
197 | coco_eval.params.imgIds = img_ids
198 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
199 |
200 |
201 | #################################################################
202 | # From pycocotools, just removed the prints and fixed
203 | # a Python3 bug about unicode not defined
204 | #################################################################
205 |
206 |
207 | def evaluate(self):
208 | '''
209 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
210 | :return: None
211 | '''
212 | # tic = time.time()
213 | # print('Running per image evaluation...')
214 | p = self.params
215 | # add backward compatibility if useSegm is specified in params
216 | if p.useSegm is not None:
217 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
218 | print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
219 | # print('Evaluate annotation type *{}*'.format(p.iouType))
220 | p.imgIds = list(np.unique(p.imgIds))
221 | if p.useCats:
222 | p.catIds = list(np.unique(p.catIds))
223 | p.maxDets = sorted(p.maxDets)
224 | self.params = p
225 |
226 | self._prepare()
227 | # loop through images, area range, max detection number
228 | catIds = p.catIds if p.useCats else [-1]
229 |
230 | if p.iouType == 'segm' or p.iouType == 'bbox':
231 | computeIoU = self.computeIoU
232 | elif p.iouType == 'keypoints':
233 | computeIoU = self.computeOks
234 | self.ious = {
235 | (imgId, catId): computeIoU(imgId, catId)
236 | for imgId in p.imgIds
237 | for catId in catIds}
238 |
239 | evaluateImg = self.evaluateImg
240 | maxDet = p.maxDets[-1]
241 | evalImgs = [
242 | evaluateImg(imgId, catId, areaRng, maxDet)
243 | for catId in catIds
244 | for areaRng in p.areaRng
245 | for imgId in p.imgIds
246 | ]
247 | # this is NOT in the pycocotools code, but could be done outside
248 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
249 | self._paramsEval = copy.deepcopy(self.params)
250 | # toc = time.time()
251 | # print('DONE (t={:0.2f}s).'.format(toc-tic))
252 | return p.imgIds, evalImgs
253 |
254 | #################################################################
255 | # end of straight copy from pycocotools, just removing the prints
256 | #################################################################
257 |
--------------------------------------------------------------------------------
/datasets/coco_panoptic.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 |
8 | from panopticapi.utils import rgb2id
9 | from util.box_ops import masks_to_boxes
10 |
11 | from .coco import make_coco_transforms
12 |
13 |
14 | class CocoPanoptic:
15 | def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True):
16 | with open(ann_file, 'r') as f:
17 | self.coco = json.load(f)
18 |
19 | # sort 'images' field so that they are aligned with 'annotations'
20 | # i.e., in alphabetical order
21 | self.coco['images'] = sorted(self.coco['images'], key=lambda x: x['id'])
22 | # sanity check
23 | if "annotations" in self.coco:
24 | for img, ann in zip(self.coco['images'], self.coco['annotations']):
25 | assert img['file_name'][:-4] == ann['file_name'][:-4]
26 |
27 | self.img_folder = img_folder
28 | self.ann_folder = ann_folder
29 | self.ann_file = ann_file
30 | self.transforms = transforms
31 | self.return_masks = return_masks
32 |
33 | def __getitem__(self, idx):
34 | ann_info = self.coco['annotations'][idx] if "annotations" in self.coco else self.coco['images'][idx]
35 | img_path = Path(self.img_folder) / ann_info['file_name'].replace('.png', '.jpg')
36 | ann_path = Path(self.ann_folder) / ann_info['file_name']
37 |
38 | img = Image.open(img_path).convert('RGB')
39 | w, h = img.size
40 | if "segments_info" in ann_info:
41 | masks = np.asarray(Image.open(ann_path), dtype=np.uint32)
42 | masks = rgb2id(masks)
43 |
44 | ids = np.array([ann['id'] for ann in ann_info['segments_info']])
45 | masks = masks == ids[:, None, None]
46 |
47 | masks = torch.as_tensor(masks, dtype=torch.uint8)
48 | labels = torch.tensor([ann['category_id'] for ann in ann_info['segments_info']], dtype=torch.int64)
49 |
50 | target = {}
51 | target['image_id'] = torch.tensor([ann_info['image_id'] if "image_id" in ann_info else ann_info["id"]])
52 | if self.return_masks:
53 | target['masks'] = masks
54 | target['labels'] = labels
55 |
56 | target["boxes"] = masks_to_boxes(masks)
57 |
58 | target['size'] = torch.as_tensor([int(h), int(w)])
59 | target['orig_size'] = torch.as_tensor([int(h), int(w)])
60 | if "segments_info" in ann_info:
61 | for name in ['iscrowd', 'area']:
62 | target[name] = torch.tensor([ann[name] for ann in ann_info['segments_info']])
63 |
64 | if self.transforms is not None:
65 | img, target = self.transforms(img, target)
66 |
67 | return img, target
68 |
69 | def __len__(self):
70 | return len(self.coco['images'])
71 |
72 | def get_height_and_width(self, idx):
73 | img_info = self.coco['images'][idx]
74 | height = img_info['height']
75 | width = img_info['width']
76 | return height, width
77 |
78 |
79 | def build(image_set, args):
80 | img_folder_root = Path(args.coco_path)
81 | ann_folder_root = Path(args.coco_panoptic_path)
82 | assert img_folder_root.exists(), f'provided COCO path {img_folder_root} does not exist'
83 | assert ann_folder_root.exists(), f'provided COCO path {ann_folder_root} does not exist'
84 | mode = 'panoptic'
85 | PATHS = {
86 | "train": ("train2017", Path("annotations") / f'{mode}_train2017.json'),
87 | "val": ("val2017", Path("annotations") / f'{mode}_val2017.json'),
88 | }
89 |
90 | img_folder, ann_file = PATHS[image_set]
91 | img_folder_path = img_folder_root / img_folder
92 | ann_folder = ann_folder_root / f'{mode}_{img_folder}'
93 | ann_file = ann_folder_root / ann_file
94 |
95 | dataset = CocoPanoptic(img_folder_path, ann_folder, ann_file,
96 | transforms=make_coco_transforms(image_set), return_masks=args.masks)
97 |
98 | return dataset
99 |
--------------------------------------------------------------------------------
/datasets/panoptic_eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 | import util.misc as utils
5 |
6 | try:
7 | from panopticapi.evaluation import pq_compute
8 | except ImportError:
9 | pass
10 |
11 |
12 | class PanopticEvaluator(object):
13 | def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"):
14 | self.gt_json = ann_file
15 | self.gt_folder = ann_folder
16 | if utils.is_main_process():
17 | if not os.path.exists(output_dir):
18 | os.mkdir(output_dir)
19 | self.output_dir = output_dir
20 | self.predictions = []
21 |
22 | def update(self, predictions):
23 | for p in predictions:
24 | with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f:
25 | f.write(p.pop("png_string"))
26 |
27 | self.predictions += predictions
28 |
29 | def synchronize_between_processes(self):
30 | all_predictions = utils.all_gather(self.predictions)
31 | merged_predictions = []
32 | for p in all_predictions:
33 | merged_predictions += p
34 | self.predictions = merged_predictions
35 |
36 | def summarize(self):
37 | if utils.is_main_process():
38 | json_data = {"annotations": self.predictions}
39 | predictions_json = os.path.join(self.output_dir, "predictions.json")
40 | with open(predictions_json, "w") as f:
41 | f.write(json.dumps(json_data))
42 | return pq_compute(self.gt_json, predictions_json, gt_folder=self.gt_folder, pred_folder=self.output_dir)
43 | return None
44 |
--------------------------------------------------------------------------------
/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | """
2 | Transforms and data augmentation for sequence level images, bboxes and masks.
3 | """
4 | import random
5 |
6 | import PIL
7 | import torch
8 | import torchvision.transforms as T
9 | import torchvision.transforms.functional as F
10 |
11 | from util.box_ops import box_xyxy_to_cxcywh, box_iou
12 | from util.misc import interpolate
13 | import numpy as np
14 | from numpy import random as rand
15 | from PIL import Image
16 | import cv2
17 |
18 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', eps=1e-6):
19 | assert mode in ['iou', 'iof']
20 | bboxes1 = bboxes1.astype(np.float32)
21 | bboxes2 = bboxes2.astype(np.float32)
22 | rows = bboxes1.shape[0]
23 | cols = bboxes2.shape[0]
24 | ious = np.zeros((rows, cols), dtype=np.float32)
25 | if rows * cols == 0:
26 | return ious
27 | exchange = False
28 | if bboxes1.shape[0] > bboxes2.shape[0]:
29 | bboxes1, bboxes2 = bboxes2, bboxes1
30 | ious = np.zeros((cols, rows), dtype=np.float32)
31 | exchange = True
32 | area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
33 | area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
34 | for i in range(bboxes1.shape[0]):
35 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
36 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
37 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
38 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
39 | overlap = np.maximum(x_end - x_start, 0) * np.maximum(y_end - y_start, 0)
40 | if mode == 'iou':
41 | union = area1[i] + area2 - overlap
42 | else:
43 | union = area1[i] if not exchange else area2
44 | union = np.maximum(union, eps)
45 | ious[i, :] = overlap / union
46 | if exchange:
47 | ious = ious.T
48 | return ious
49 |
50 |
51 | def crop(clip, target, region):
52 | cropped_image = []
53 | for image in clip:
54 | cropped_image.append(F.crop(image, *region))
55 |
56 | target = target.copy()
57 | i, j, h, w = region
58 |
59 | # should we do something wrt the original size?
60 | target["size"] = torch.tensor([h, w])
61 |
62 | fields = ["labels", "area", "iscrowd"]
63 |
64 | if "boxes" in target:
65 | boxes = target["boxes"]
66 | max_size = torch.as_tensor([w, h], dtype=torch.float32)
67 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
68 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
69 | cropped_boxes = cropped_boxes.clamp(min=0)
70 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
71 | target["boxes"] = cropped_boxes.reshape(-1, 4)
72 | target["area"] = area
73 | fields.append("boxes")
74 |
75 | if "masks" in target:
76 | # FIXME should we update the area here if there are no boxes?
77 | target['masks'] = target['masks'][:, i:i + h, j:j + w]
78 | fields.append("masks")
79 |
80 | return cropped_image, target
81 |
82 |
83 | def hflip(clip, target):
84 | flipped_image = []
85 | for image in clip:
86 | flipped_image.append(F.hflip(image))
87 |
88 | w, h = clip[0].size
89 |
90 | target = target.copy()
91 | if "boxes" in target:
92 | boxes = target["boxes"]
93 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
94 | target["boxes"] = boxes
95 |
96 | if "masks" in target:
97 | target['masks'] = target['masks'].flip(-1)
98 |
99 | return flipped_image, target
100 |
101 | def vflip(clip, target):
102 | flipped_image = []
103 | for image in clip:
104 | flipped_image.append(F.vflip(image))
105 | w, h = clip[0].size
106 | target = target.copy()
107 | if "boxes" in target:
108 | boxes = target["boxes"]
109 | boxes = boxes[:, [0, 3, 2, 1]] * torch.as_tensor([1, -1, 1, -1]) + torch.as_tensor([0, h, 0, h])
110 | target["boxes"] = boxes
111 |
112 | if "masks" in target:
113 | target['masks'] = target['masks'].flip(1)
114 |
115 | return flipped_image, target
116 |
117 | def resize(clip, target, size, max_size=None):
118 | # size can be min_size (scalar) or (w, h) tuple
119 |
120 | def get_size_with_aspect_ratio(image_size, size, max_size=None):
121 | w, h = image_size
122 | if max_size is not None:
123 | min_original_size = float(min((w, h)))
124 | max_original_size = float(max((w, h)))
125 | if max_original_size / min_original_size * size > max_size:
126 | size = int(round(max_size * min_original_size / max_original_size))
127 |
128 | if (w <= h and w == size) or (h <= w and h == size):
129 | return (h, w)
130 |
131 | if w < h:
132 | ow = size
133 | oh = int(size * h / w)
134 | else:
135 | oh = size
136 | ow = int(size * w / h)
137 |
138 | return (oh, ow)
139 |
140 | def get_size(image_size, size, max_size=None):
141 | if isinstance(size, (list, tuple)):
142 | return size[::-1]
143 | else:
144 | return get_size_with_aspect_ratio(image_size, size, max_size)
145 |
146 | size = get_size(clip[0].size, size, max_size)
147 | rescaled_image = []
148 | for image in clip:
149 | rescaled_image.append(F.resize(image, size))
150 |
151 | if target is None:
152 | return rescaled_image, None
153 |
154 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image[0].size, clip[0].size))
155 | ratio_width, ratio_height = ratios
156 |
157 | target = target.copy()
158 | if "boxes" in target:
159 | boxes = target["boxes"]
160 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
161 | target["boxes"] = scaled_boxes
162 |
163 | if "area" in target:
164 | area = target["area"]
165 | scaled_area = area * (ratio_width * ratio_height)
166 | target["area"] = scaled_area
167 |
168 | h, w = size
169 | target["size"] = torch.tensor([h, w])
170 |
171 | if "masks" in target:
172 | if target['masks'].shape[0]>0:
173 | target['masks'] = interpolate(
174 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5
175 | else:
176 | target['masks'] = torch.zeros((target['masks'].shape[0],h,w))
177 | return rescaled_image, target
178 |
179 |
180 | def pad(clip, target, padding):
181 | # assumes that we only pad on the bottom right corners
182 | padded_image = []
183 | for image in clip:
184 | padded_image.append(F.pad(image, (0, 0, padding[0], padding[1])))
185 | if target is None:
186 | return padded_image, None
187 | target = target.copy()
188 | # should we do something wrt the original size?
189 | target["size"] = torch.tensor(padded_image[0].size[::-1])
190 | if "masks" in target:
191 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
192 | return padded_image, target
193 |
194 |
195 | class RandomCrop(object):
196 | def __init__(self, size):
197 | self.size = size
198 |
199 | def __call__(self, img, target):
200 | region = T.RandomCrop.get_params(img, self.size)
201 | return crop(img, target, region)
202 |
203 |
204 | class RandomSizeCrop(object):
205 | def __init__(self, min_size: int, max_size: int):
206 | self.min_size = min_size
207 | self.max_size = max_size
208 |
209 | def __call__(self, img: PIL.Image.Image, target: dict):
210 | w = random.randint(self.min_size, min(img[0].width, self.max_size))
211 | h = random.randint(self.min_size, min(img[0].height, self.max_size))
212 | region = T.RandomCrop.get_params(img[0], [h, w])
213 | return crop(img, target, region)
214 |
215 |
216 | class CenterCrop(object):
217 | def __init__(self, size):
218 | self.size = size
219 |
220 | def __call__(self, img, target):
221 | image_width, image_height = img.size
222 | crop_height, crop_width = self.size
223 | crop_top = int(round((image_height - crop_height) / 2.))
224 | crop_left = int(round((image_width - crop_width) / 2.))
225 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
226 |
227 |
228 | class MinIoURandomCrop(object):
229 | def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3):
230 | self.min_ious = min_ious
231 | self.sample_mode = (1, *min_ious, 0)
232 | self.min_crop_size = min_crop_size
233 |
234 | def __call__(self, img, target):
235 | w,h = img.size
236 | while True:
237 | mode = random.choice(self.sample_mode)
238 | self.mode = mode
239 | if mode == 1:
240 | return img,target
241 | min_iou = mode
242 | boxes = target['boxes'].numpy()
243 | labels = target['labels']
244 |
245 | for i in range(50):
246 | new_w = rand.uniform(self.min_crop_size * w, w)
247 | new_h = rand.uniform(self.min_crop_size * h, h)
248 | if new_h / new_w < 0.5 or new_h / new_w > 2:
249 | continue
250 | left = rand.uniform(w - new_w)
251 | top = rand.uniform(h - new_h)
252 | patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h)))
253 | if patch[2] == patch[0] or patch[3] == patch[1]:
254 | continue
255 | overlaps = bbox_overlaps(patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
256 | if len(overlaps) > 0 and overlaps.min() < min_iou:
257 | continue
258 |
259 | if len(overlaps) > 0:
260 | def is_center_of_bboxes_in_patch(boxes, patch):
261 | center = (boxes[:, :2] + boxes[:, 2:]) / 2
262 | mask = ((center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * (center[:, 1] < patch[3]))
263 | return mask
264 | mask = is_center_of_bboxes_in_patch(boxes, patch)
265 | if False in mask:
266 | continue
267 | #TODO: use no center boxes
268 | #if not mask.any():
269 | # continue
270 |
271 | boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
272 | boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
273 | boxes -= np.tile(patch[:2], 2)
274 | target['boxes'] = torch.tensor(boxes)
275 |
276 | img = np.asarray(img)[patch[1]:patch[3], patch[0]:patch[2]]
277 | img = Image.fromarray(img)
278 | width, height = img.size
279 | target['orig_size'] = torch.tensor([height,width])
280 | target['size'] = torch.tensor([height,width])
281 | return img,target
282 |
283 |
284 | class RandomContrast(object):
285 | def __init__(self, lower=0.5, upper=1.5):
286 | self.lower = lower
287 | self.upper = upper
288 | assert self.upper >= self.lower, "contrast upper must be >= lower."
289 | assert self.lower >= 0, "contrast lower must be non-negative."
290 | def __call__(self, image, target):
291 |
292 | if rand.randint(2):
293 | alpha = rand.uniform(self.lower, self.upper)
294 | image *= alpha
295 | return image, target
296 |
297 | class RandomBrightness(object):
298 | def __init__(self, delta=32):
299 | assert delta >= 0.0
300 | assert delta <= 255.0
301 | self.delta = delta
302 | def __call__(self, image, target):
303 | if rand.randint(2):
304 | delta = rand.uniform(-self.delta, self.delta)
305 | image += delta
306 | return image, target
307 |
308 | class RandomSaturation(object):
309 | def __init__(self, lower=0.5, upper=1.5):
310 | self.lower = lower
311 | self.upper = upper
312 | assert self.upper >= self.lower, "contrast upper must be >= lower."
313 | assert self.lower >= 0, "contrast lower must be non-negative."
314 |
315 | def __call__(self, image, target):
316 | if rand.randint(2):
317 | image[:, :, 1] *= rand.uniform(self.lower, self.upper)
318 | return image, target
319 |
320 | class RandomHue(object): #
321 | def __init__(self, delta=18.0):
322 | assert delta >= 0.0 and delta <= 360.0
323 | self.delta = delta
324 |
325 | def __call__(self, image, target):
326 | if rand.randint(2):
327 | image[:, :, 0] += rand.uniform(-self.delta, self.delta)
328 | image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0
329 | image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
330 | return image, target
331 |
332 | class RandomLightingNoise(object):
333 | def __init__(self):
334 | self.perms = ((0, 1, 2), (0, 2, 1),
335 | (1, 0, 2), (1, 2, 0),
336 | (2, 0, 1), (2, 1, 0))
337 | def __call__(self, image, target):
338 | if rand.randint(2):
339 | swap = self.perms[rand.randint(len(self.perms))]
340 | shuffle = SwapChannels(swap) # shuffle channels
341 | image = shuffle(image)
342 | return image, target
343 |
344 | class ConvertColor(object):
345 | def __init__(self, current='BGR', transform='HSV'):
346 | self.transform = transform
347 | self.current = current
348 |
349 | def __call__(self, image, target):
350 | if self.current == 'BGR' and self.transform == 'HSV':
351 | image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
352 | elif self.current == 'HSV' and self.transform == 'BGR':
353 | image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
354 | else:
355 | raise NotImplementedError
356 | return image, target
357 |
358 | class SwapChannels(object):
359 | def __init__(self, swaps):
360 | self.swaps = swaps
361 | def __call__(self, image):
362 | image = image[:, :, self.swaps]
363 | return image
364 |
365 | class PhotometricDistort(object):
366 | def __init__(self):
367 | self.pd = [
368 | RandomContrast(),
369 | ConvertColor(transform='HSV'),
370 | RandomSaturation(),
371 | RandomHue(),
372 | ConvertColor(current='HSV', transform='BGR'),
373 | RandomContrast()
374 | ]
375 | self.rand_brightness = RandomBrightness()
376 | self.rand_light_noise = RandomLightingNoise()
377 |
378 | def __call__(self,clip,target):
379 | imgs = []
380 | for img in clip:
381 | img = np.asarray(img).astype('float32')
382 | img, target = self.rand_brightness(img, target)
383 | if rand.randint(2):
384 | distort = Compose(self.pd[:-1])
385 | else:
386 | distort = Compose(self.pd[1:])
387 | img, target = distort(img, target)
388 | img, target = self.rand_light_noise(img, target)
389 | imgs.append(Image.fromarray(img.astype('uint8')))
390 | return imgs, target
391 |
392 | #NOTICE: if used for mask, need to change
393 | class Expand(object):
394 | def __init__(self, mean):
395 | self.mean = mean
396 | def __call__(self, clip, target):
397 | if rand.randint(2):
398 | return clip,target
399 | imgs = []
400 | masks = []
401 | image = np.asarray(clip[0]).astype('float32')
402 | height, width, depth = image.shape
403 | ratio = rand.uniform(1, 4)
404 | left = rand.uniform(0, width*ratio - width)
405 | top = rand.uniform(0, height*ratio - height)
406 | for i in range(len(clip)):
407 | image = np.asarray(clip[i]).astype('float32')
408 | expand_image = np.zeros((int(height*ratio), int(width*ratio), depth),dtype=image.dtype)
409 | expand_image[:, :, :] = self.mean
410 | expand_image[int(top):int(top + height),int(left):int(left + width)] = image
411 | imgs.append(Image.fromarray(expand_image.astype('uint8')))
412 | expand_mask = torch.zeros((int(height*ratio), int(width*ratio)),dtype=torch.uint8)
413 | expand_mask[int(top):int(top + height),int(left):int(left + width)] = target['masks'][i]
414 | masks.append(expand_mask)
415 | boxes = target['boxes'].numpy()
416 | boxes[:, :2] += (int(left), int(top))
417 | boxes[:, 2:] += (int(left), int(top))
418 | target['boxes'] = torch.tensor(boxes)
419 | target['masks']=torch.stack(masks)
420 | return imgs, target
421 |
422 | class RandomHorizontalFlip(object):
423 | def __init__(self, p=0.5):
424 | self.p = p
425 |
426 | def __call__(self, img, target):
427 | if random.random() < self.p:
428 | return hflip(img, target)
429 | return img, target
430 |
431 | class RandomVerticalFlip(object):
432 | def __init__(self, p=0.5):
433 | self.p = p
434 |
435 | def __call__(self, img, target):
436 | if random.random() < self.p:
437 | return vflip(img, target)
438 | return img, target
439 |
440 |
441 | class RandomResize(object):
442 | def __init__(self, sizes, max_size=None):
443 | assert isinstance(sizes, (list, tuple))
444 | self.sizes = sizes
445 | self.max_size = max_size
446 |
447 | def __call__(self, img, target=None):
448 | size = random.choice(self.sizes)
449 | return resize(img, target, size, self.max_size)
450 |
451 |
452 | class RandomPad(object):
453 | def __init__(self, max_pad):
454 | self.max_pad = max_pad
455 |
456 | def __call__(self, img, target):
457 | pad_x = random.randint(0, self.max_pad)
458 | pad_y = random.randint(0, self.max_pad)
459 | return pad(img, target, (pad_x, pad_y))
460 |
461 |
462 | class RandomSelect(object):
463 | """
464 | Randomly selects between transforms1 and transforms2,
465 | with probability p for transforms1 and (1 - p) for transforms2
466 | """
467 | def __init__(self, transforms1, transforms2, p=0.5):
468 | self.transforms1 = transforms1
469 | self.transforms2 = transforms2
470 | self.p = p
471 |
472 | def __call__(self, img, target):
473 | if random.random() < self.p:
474 | return self.transforms1(img, target)
475 | return self.transforms2(img, target)
476 |
477 |
478 | class ToTensor(object):
479 | def __call__(self, clip, target):
480 | img = []
481 | for im in clip:
482 | img.append(F.to_tensor(im))
483 | return img, target
484 |
485 |
486 | class RandomErasing(object):
487 |
488 | def __init__(self, *args, **kwargs):
489 | self.eraser = T.RandomErasing(*args, **kwargs)
490 |
491 | def __call__(self, img, target):
492 | return self.eraser(img), target
493 |
494 |
495 | class Normalize(object):
496 | def __init__(self, mean, std):
497 | self.mean = mean
498 | self.std = std
499 |
500 | def __call__(self, clip, target=None):
501 | image = []
502 | for im in clip:
503 | image.append(F.normalize(im, mean=self.mean, std=self.std))
504 | if target is None:
505 | return image, None
506 | target = target.copy()
507 | h, w = image[0].shape[-2:]
508 | if "boxes" in target:
509 | boxes = target["boxes"]
510 | boxes = box_xyxy_to_cxcywh(boxes)
511 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
512 | target["boxes"] = boxes
513 | return image, target
514 |
515 |
516 | class Compose(object):
517 | def __init__(self, transforms):
518 | self.transforms = transforms
519 |
520 | def __call__(self, image, target):
521 | for t in self.transforms:
522 | image, target = t(image, target)
523 | return image, target
524 |
525 | def __repr__(self):
526 | format_string = self.__class__.__name__ + "("
527 | for t in self.transforms:
528 | format_string += "\n"
529 | format_string += " {0}".format(t)
530 | format_string += "\n)"
531 | return format_string
532 |
--------------------------------------------------------------------------------
/datasets/ytvos.py:
--------------------------------------------------------------------------------
1 | """
2 | YoutubeVIS data loader
3 | """
4 | from pathlib import Path
5 |
6 | import torch
7 | import torch.utils.data
8 | import torchvision
9 | from pycocotools.ytvos import YTVOS
10 | from pycocotools.ytvoseval import YTVOSeval
11 | import datasets.transforms as T
12 | from pycocotools import mask as coco_mask
13 | import os
14 | from PIL import Image
15 | from random import randint
16 | import cv2
17 | import random
18 |
19 | class YTVOSDataset:
20 | def __init__(self, img_folder, ann_file, transforms, return_masks, num_frames):
21 | self.img_folder = img_folder
22 | self.ann_file = ann_file
23 | self._transforms = transforms
24 | self.return_masks = return_masks
25 | self.num_frames = num_frames
26 | self.prepare = ConvertCocoPolysToMask(return_masks)
27 | self.ytvos = YTVOS(ann_file)
28 | self.cat_ids = self.ytvos.getCatIds()
29 | self.vid_ids = self.ytvos.getVidIds()
30 | self.vid_infos = []
31 | for i in self.vid_ids:
32 | info = self.ytvos.loadVids([i])[0]
33 | info['filenames'] = info['file_names']
34 | self.vid_infos.append(info)
35 | self.img_ids = []
36 | for idx, vid_info in enumerate(self.vid_infos):
37 | for frame_id in range(len(vid_info['filenames'])):
38 | self.img_ids.append((idx, frame_id))
39 | def __len__(self):
40 | return len(self.img_ids)
41 |
42 | def __getitem__(self, idx):
43 | vid, frame_id = self.img_ids[idx]
44 | vid_id = self.vid_infos[vid]['id']
45 | img = []
46 | vid_len = len(self.vid_infos[vid]['file_names'])
47 | inds = list(range(self.num_frames))
48 | inds = [i%vid_len for i in inds][::-1]
49 | # if random
50 | # random.shuffle(inds)
51 | for j in range(self.num_frames):
52 | img_path = os.path.join(str(self.img_folder), self.vid_infos[vid]['file_names'][frame_id-inds[j]])
53 | img.append(Image.open(img_path).convert('RGB'))
54 | ann_ids = self.ytvos.getAnnIds(vidIds=[vid_id])
55 | target = self.ytvos.loadAnns(ann_ids)
56 | target = {'image_id': idx, 'video_id': vid, 'frame_id': frame_id, 'annotations': target}
57 | target = self.prepare(img[0], target, inds, self.num_frames)
58 | if self._transforms is not None:
59 | img, target = self._transforms(img, target)
60 | return torch.cat(img,dim=0), target
61 |
62 |
63 | def convert_coco_poly_to_mask(segmentations, height, width):
64 | masks = []
65 | for polygons in segmentations:
66 | if not polygons:
67 | mask = torch.zeros((height,width), dtype=torch.uint8)
68 | else:
69 | rles = coco_mask.frPyObjects(polygons, height, width)
70 | mask = coco_mask.decode(rles)
71 | if len(mask.shape) < 3:
72 | mask = mask[..., None]
73 | mask = torch.as_tensor(mask, dtype=torch.uint8)
74 | mask = mask.any(dim=2)
75 | masks.append(mask)
76 | if masks:
77 | masks = torch.stack(masks, dim=0)
78 | else:
79 | masks = torch.zeros((0, height, width), dtype=torch.uint8)
80 | return masks
81 |
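# Shape note for convert_coco_poly_to_mask: it returns one binary mask per annotation,
# so N annotations of an HxW frame give an (N, H, W) tensor (boolean after the
# .any(dim=2) merge of per-polygon channels, or uint8 zeros of shape (0, H, W) when N == 0).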
82 |
83 | class ConvertCocoPolysToMask(object):
84 | def __init__(self, return_masks=False):
85 | self.return_masks = return_masks
86 |
87 | def __call__(self, image, target, inds, num_frames):
88 | w, h = image.size
89 | image_id = target["image_id"]
90 | frame_id = target['frame_id']
91 | image_id = torch.tensor([image_id])
92 |
93 | anno = target["annotations"]
94 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
95 | boxes = []
96 | classes = []
97 | segmentations = []
98 | area = []
99 | iscrowd = []
100 | valid = []
101 | # add valid flag for bboxes
102 | for i, ann in enumerate(anno):
103 | for j in range(num_frames):
104 | bbox = ann['bboxes'][frame_id-inds[j]]
105 | areas = ann['areas'][frame_id-inds[j]]
106 | segm = ann['segmentations'][frame_id-inds[j]]
107 | clas = ann["category_id"]
108 | # for empty boxes
109 | if bbox is None:
110 | bbox = [0,0,0,0]
111 | areas = 0
112 | valid.append(0)
113 | clas = 0
114 | else:
115 | valid.append(1)
116 | crowd = ann["iscrowd"] if "iscrowd" in ann else 0
117 | boxes.append(bbox)
118 | area.append(areas)
119 | segmentations.append(segm)
120 | classes.append(clas)
121 | iscrowd.append(crowd)
122 | # guard against no boxes via resizing
123 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
124 | boxes[:, 2:] += boxes[:, :2]
125 | boxes[:, 0::2].clamp_(min=0, max=w)
126 | boxes[:, 1::2].clamp_(min=0, max=h)
127 | classes = torch.tensor(classes, dtype=torch.int64)
128 | if self.return_masks:
129 | masks = convert_coco_poly_to_mask(segmentations, h, w)
130 | target = {}
131 | target["boxes"] = boxes
132 | target["labels"] = classes
133 | if self.return_masks:
134 | target["masks"] = masks
135 | target["image_id"] = image_id
136 |
137 | # for conversion to coco api
138 | area = torch.tensor(area)
139 | iscrowd = torch.tensor(iscrowd)
140 | target["valid"] = torch.tensor(valid)
141 | target["area"] = area
142 | target["iscrowd"] = iscrowd
143 | target["orig_size"] = torch.as_tensor([int(h), int(w)])
144 | target["size"] = torch.as_tensor([int(h), int(w)])
145 | return target
146 |
147 |
148 | def make_coco_transforms(image_set):
149 |
150 | normalize = T.Compose([
151 | T.ToTensor(),
152 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
153 | ])
154 |
155 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768]
156 |
157 | if image_set == 'train':
158 | return T.Compose([
159 | T.RandomHorizontalFlip(),
160 | T.RandomResize(scales, max_size=800),
161 | T.PhotometricDistort(),
162 | T.Compose([
163 | T.RandomResize([400, 500, 600]),
164 | T.RandomSizeCrop(384, 600),
165 |             # the final rescale may need to be smaller to fit into GPU memory
166 |             T.RandomResize([300], max_size=540),  # for r50
167 |             # T.RandomResize([280], max_size=504),  # for r101
168 | ]),
169 | normalize,
170 | ])
171 |
172 | if image_set == 'val':
173 | return T.Compose([
174 | T.RandomResize([360], max_size=640),
175 | normalize,
176 | ])
177 |
178 | raise ValueError(f'unknown {image_set}')
179 |
180 |
181 | def build(image_set, args):
182 | root = Path(args.ytvos_path)
183 | assert root.exists(), f'provided YTVOS path {root} does not exist'
184 | mode = 'instances'
185 | PATHS = {
186 | "train": (root / "train/JPEGImages", root / "annotations" / f'{mode}_train_sub.json'),
187 | "val": (root / "valid/JPEGImages", root / "annotations" / f'{mode}_val_sub.json'),
188 | }
189 | img_folder, ann_file = PATHS[image_set]
190 |     dataset = YTVOSDataset(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks, num_frames=args.num_frames)
191 | return dataset
192 |
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | """
2 | Train and eval functions used in main.py
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import math
6 | import os
7 | import sys
8 | from typing import Iterable
9 |
10 | import torch
11 |
12 | import util.misc as utils
13 | from datasets.coco_eval import CocoEvaluator
14 | from datasets.panoptic_eval import PanopticEvaluator
15 |
16 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
17 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
18 | device: torch.device, epoch: int, max_norm: float = 0):
19 | model.train()
20 | criterion.train()
21 | metric_logger = utils.MetricLogger(delimiter=" ")
22 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
23 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
24 | header = 'Epoch: [{}]'.format(epoch)
25 | print_freq = 10
26 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
27 | samples = samples.to(device)
28 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
29 | outputs = model(samples)
30 | loss_dict = criterion(outputs, targets)
31 | weight_dict = criterion.weight_dict
32 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
33 |
34 | # reduce losses over all GPUs for logging purposes
35 | loss_dict_reduced = utils.reduce_dict(loss_dict)
36 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v
37 | for k, v in loss_dict_reduced.items()}
38 | loss_dict_reduced_scaled = {k: v * weight_dict[k]
39 | for k, v in loss_dict_reduced.items() if k in weight_dict}
40 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
41 |
42 | loss_value = losses_reduced_scaled.item()
43 |
44 | if not math.isfinite(loss_value):
45 | print("Loss is {}, stopping training".format(loss_value))
46 | print(loss_dict_reduced)
47 | sys.exit(1)
48 | optimizer.zero_grad()
49 | losses.backward()
50 | if max_norm > 0:
51 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
52 | optimizer.step()
53 |
54 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
55 | metric_logger.update(class_error=loss_dict_reduced['class_error'])
56 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
57 |
58 | # gather the stats from all processes
59 | metric_logger.synchronize_between_processes()
60 | print("Averaged stats:", metric_logger)
61 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
62 |
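# Minimal illustration of the loss weighting in train_one_epoch (made-up values, not the
# real config): with loss_dict = {'loss_ce': 0.7, 'loss_bbox': 0.2, 'class_error': 25.0}
# and weight_dict = {'loss_ce': 1, 'loss_bbox': 5}, the optimized total is
# 0.7*1 + 0.2*5 = 1.7; keys absent from weight_dict (such as class_error) are only
# logged and never backpropagated.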
63 |
64 | @torch.no_grad()
65 | def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
66 | model.eval()
67 | criterion.eval()
68 |
69 | metric_logger = utils.MetricLogger(delimiter=" ")
70 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
71 | header = 'Test:'
72 |
73 | iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
74 | coco_evaluator = CocoEvaluator(base_ds, iou_types)
75 | # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
76 |
77 | panoptic_evaluator = None
78 | if 'panoptic' in postprocessors.keys():
79 | panoptic_evaluator = PanopticEvaluator(
80 | data_loader.dataset.ann_file,
81 | data_loader.dataset.ann_folder,
82 | output_dir=os.path.join(output_dir, "panoptic_eval"),
83 | )
84 |
85 | for samples, targets in metric_logger.log_every(data_loader, 10, header):
86 | samples = samples.to(device)
87 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
88 |
89 | outputs = model(samples)
90 | loss_dict = criterion(outputs, targets)
91 | weight_dict = criterion.weight_dict
92 |
93 | # reduce losses over all GPUs for logging purposes
94 | loss_dict_reduced = utils.reduce_dict(loss_dict)
95 | loss_dict_reduced_scaled = {k: v * weight_dict[k]
96 | for k, v in loss_dict_reduced.items() if k in weight_dict}
97 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v
98 | for k, v in loss_dict_reduced.items()}
99 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
100 | **loss_dict_reduced_scaled,
101 | **loss_dict_reduced_unscaled)
102 | metric_logger.update(class_error=loss_dict_reduced['class_error'])
103 |
104 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
105 | results = postprocessors['bbox'](outputs, orig_target_sizes)
106 | if 'segm' in postprocessors.keys():
107 | target_sizes = torch.stack([t["size"] for t in targets], dim=0)
108 | results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
109 | res = {target['image_id'].item(): output for target, output in zip(targets, results)}
110 | if coco_evaluator is not None:
111 | coco_evaluator.update(res)
112 |
113 | if panoptic_evaluator is not None:
114 | res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
115 | for i, target in enumerate(targets):
116 | image_id = target["image_id"].item()
117 | file_name = f"{image_id:012d}.png"
118 | res_pano[i]["image_id"] = image_id
119 | res_pano[i]["file_name"] = file_name
120 |
121 | panoptic_evaluator.update(res_pano)
122 |
123 | # gather the stats from all processes
124 | metric_logger.synchronize_between_processes()
125 | print("Averaged stats:", metric_logger)
126 | if coco_evaluator is not None:
127 | coco_evaluator.synchronize_between_processes()
128 | if panoptic_evaluator is not None:
129 | panoptic_evaluator.synchronize_between_processes()
130 |
131 | # accumulate predictions from all images
132 | if coco_evaluator is not None:
133 | coco_evaluator.accumulate()
134 | coco_evaluator.summarize()
135 | panoptic_res = None
136 | if panoptic_evaluator is not None:
137 | panoptic_res = panoptic_evaluator.summarize()
138 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
139 | if coco_evaluator is not None:
140 | if 'bbox' in postprocessors.keys():
141 | stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
142 | if 'segm' in postprocessors.keys():
143 | stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
144 | if panoptic_res is not None:
145 | stats['PQ_all'] = panoptic_res["All"]
146 | stats['PQ_th'] = panoptic_res["Things"]
147 | stats['PQ_st'] = panoptic_res["Stuff"]
148 | return stats, coco_evaluator
149 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | '''
2 | Inference code for VisTR
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | '''
5 | import argparse
6 | import datetime
7 | import json
8 | import random
9 | import time
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | from torch.utils.data import DataLoader, DistributedSampler
15 |
16 | import datasets
17 | import util.misc as utils
18 | from datasets import build_dataset, get_coco_api_from_dataset
19 | from engine import evaluate, train_one_epoch
20 | from models import build_model
21 | import torchvision.transforms as T
22 | import matplotlib.pyplot as plt
23 | import os
24 | from PIL import Image
25 | import math
26 | import torch.nn.functional as F
27 | import json
28 | from scipy.optimize import linear_sum_assignment
29 | import pycocotools.mask as mask_util
30 |
31 |
32 |
33 | def get_args_parser():
34 | parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
35 | parser.add_argument('--lr', default=1e-4, type=float)
36 | parser.add_argument('--lr_backbone', default=1e-5, type=float)
37 | parser.add_argument('--batch_size', default=2, type=int)
38 | parser.add_argument('--weight_decay', default=1e-4, type=float)
39 | parser.add_argument('--epochs', default=150, type=int)
40 | parser.add_argument('--lr_drop', default=100, type=int)
41 | parser.add_argument('--clip_max_norm', default=0.1, type=float,
42 | help='gradient clipping max norm')
43 |
44 | # Model parameters
45 | parser.add_argument('--model_path', type=str, default=None,
46 | help="Path to the model weights.")
47 | # * Backbone
48 | parser.add_argument('--backbone', default='resnet101', type=str,
49 | help="Name of the convolutional backbone to use")
50 | parser.add_argument('--dilation', action='store_true',
51 | help="If true, we replace stride with dilation in the last convolutional block (DC5)")
52 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
53 | help="Type of positional embedding to use on top of the image features")
54 |
55 | # * Transformer
56 | parser.add_argument('--enc_layers', default=6, type=int,
57 | help="Number of encoding layers in the transformer")
58 | parser.add_argument('--dec_layers', default=6, type=int,
59 | help="Number of decoding layers in the transformer")
60 | parser.add_argument('--dim_feedforward', default=2048, type=int,
61 | help="Intermediate size of the feedforward layers in the transformer blocks")
62 | parser.add_argument('--hidden_dim', default=384, type=int,
63 | help="Size of the embeddings (dimension of the transformer)")
64 | parser.add_argument('--dropout', default=0.1, type=float,
65 | help="Dropout applied in the transformer")
66 | parser.add_argument('--nheads', default=8, type=int,
67 | help="Number of attention heads inside the transformer's attentions")
68 | parser.add_argument('--num_frames', default=36, type=int,
69 | help="Number of frames")
70 | parser.add_argument('--num_ins', default=10, type=int,
71 | help="Number of instances")
72 | parser.add_argument('--num_queries', default=360, type=int,
73 | help="Number of query slots")
74 | parser.add_argument('--pre_norm', action='store_true')
75 |
76 | # * Segmentation
77 | parser.add_argument('--masks', action='store_true',
78 | help="Train segmentation head if the flag is provided")
79 |
80 | # Loss
81 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
82 | help="Disables auxiliary decoding losses (loss at each layer)")
83 | # * Matcher
84 | parser.add_argument('--set_cost_class', default=1, type=float,
85 | help="Class coefficient in the matching cost")
86 | parser.add_argument('--set_cost_bbox', default=5, type=float,
87 | help="L1 box coefficient in the matching cost")
88 | parser.add_argument('--set_cost_giou', default=2, type=float,
89 | help="giou box coefficient in the matching cost")
90 | # * Loss coefficients
91 | parser.add_argument('--mask_loss_coef', default=1, type=float)
92 | parser.add_argument('--dice_loss_coef', default=1, type=float)
93 | parser.add_argument('--bbox_loss_coef', default=5, type=float)
94 | parser.add_argument('--giou_loss_coef', default=2, type=float)
95 | parser.add_argument('--eos_coef', default=0.1, type=float,
96 | help="Relative classification weight of the no-object class")
97 |
98 | # dataset parameters
99 | parser.add_argument('--img_path', default='data/ytvos/valid/JPEGImages/')
100 | parser.add_argument('--ann_path', default='data/ytvos/annotations/instances_val_sub.json')
101 | parser.add_argument('--save_path', default='results.json')
102 | parser.add_argument('--dataset_file', default='ytvos')
103 | parser.add_argument('--coco_path', type=str)
104 | parser.add_argument('--coco_panoptic_path', type=str)
105 | parser.add_argument('--remove_difficult', action='store_true')
106 |
107 | parser.add_argument('--output_dir', default='output_ytvos',
108 | help='path where to save, empty for no saving')
109 | parser.add_argument('--device', default='cuda',
110 | help='device to use for training / testing')
111 | parser.add_argument('--seed', default=42, type=int)
112 | parser.add_argument('--resume', default='', help='resume from checkpoint')
113 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
114 | help='start epoch')
115 | #parser.add_argument('--eval', action='store_true')
116 |     parser.add_argument('--eval', action='store_false')  # evaluation mode is enabled by default for inference
117 | parser.add_argument('--num_workers', default=0, type=int)
118 |
119 | # distributed training parameters
120 | parser.add_argument('--world_size', default=1, type=int,
121 | help='number of distributed processes')
122 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
123 | return parser
124 |
125 | CLASSES=['person','giant_panda','lizard','parrot','skateboard','sedan','ape',
126 | 'dog','snake','monkey','hand','rabbit','duck','cat','cow','fish',
127 | 'train','horse','turtle','bear','motorbike','giraffe','leopard',
128 | 'fox','deer','owl','surfboard','airplane','truck','zebra','tiger',
129 | 'elephant','snowboard','boat','shark','mouse','frog','eagle','earless_seal',
130 | 'tennis_racket']
131 | COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
132 | [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933],
133 | [0.494, 0.000, 0.556], [0.494, 0.000, 0.000], [0.000, 0.745, 0.000],
134 | [0.700, 0.300, 0.600]]
135 | transform = T.Compose([
136 | T.Resize(300),
137 | T.ToTensor(),
138 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
139 | ])
140 |
141 |
142 | # for output bounding box post-processing
143 | def box_cxcywh_to_xyxy(x):
144 | x_c, y_c, w, h = x.unbind(1)
145 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
146 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
147 | return torch.stack(b, dim=1)
148 |
149 | def rescale_bboxes(out_bbox, size):
150 | img_w, img_h = size
151 | b = box_cxcywh_to_xyxy(out_bbox)
152 | b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
153 | return b
154 |
155 |
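# Worked example for the two helpers above: a normalized cxcywh box [0.5, 0.5, 0.2, 0.4]
# converts to xyxy [0.4, 0.3, 0.6, 0.7], and rescale_bboxes on a 640x360 frame then
# gives the pixel coordinates [256.0, 108.0, 384.0, 252.0].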
156 |
157 | def main(args):
158 |
159 | device = torch.device(args.device)
160 |
161 | # fix the seed for reproducibility
162 | seed = args.seed + utils.get_rank()
163 | torch.manual_seed(seed)
164 | np.random.seed(seed)
165 | random.seed(seed)
166 | num_frames = args.num_frames
167 | num_ins = args.num_ins
168 | with torch.no_grad():
169 | model, criterion, postprocessors = build_model(args)
170 | model.to(device)
171 | state_dict = torch.load(args.model_path)['model']
172 | model.load_state_dict(state_dict)
173 | folder = args.img_path
174 | videos = json.load(open(args.ann_path,'rb'))['videos']
175 | vis_num = len(videos)
176 | result = []
177 | for i in range(vis_num):
178 | print("Process video: ",i)
179 | id_ = videos[i]['id']
180 | length = videos[i]['length']
181 | file_names = videos[i]['file_names']
182 | clip_num = math.ceil(length/num_frames)
183 |
184 | img_set=[]
185 |             if length<num_frames:
186 |                 clip_names = file_names*(math.ceil(num_frames/length))
187 |                 clip_names = clip_names[:num_frames]
188 |             else:
189 |                 clip_names = file_names[:num_frames]
190 |             if clip_names==[]:
191 |                 continue
192 |             if clip_names[0]==[]:
193 |                 continue
194 |             for k in range(num_frames):
195 |                 im = Image.open(os.path.join(folder,clip_names[k]))
196 |                 img_set.append(transform(im).unsqueeze(0).cuda())
197 |             img = torch.cat(img_set,0)
198 |             # run the model on the whole clip at once
199 |             outputs = model(img)
200 |             logits, boxes, masks = outputs['pred_logits'].softmax(-1)[0,:,:-1], outputs['pred_boxes'][0], outputs['pred_masks'][0]
201 |             pred_masks = F.interpolate(masks.reshape(num_frames,num_ins,masks.shape[-2],masks.shape[-1]),
202 |                                        (im.size[1],im.size[0]),mode="bilinear").sigmoid().cpu().detach().numpy()>0.5
203 | pred_logits = logits.reshape(num_frames,num_ins,logits.shape[-1]).cpu().detach().numpy()
204 | pred_masks = pred_masks[:length]
205 | pred_logits = pred_logits[:length]
206 | pred_scores = np.max(pred_logits,axis=-1)
207 | pred_logits = np.argmax(pred_logits,axis=-1)
208 | for m in range(num_ins):
209 | if pred_masks[:,m].max()==0:
210 | continue
211 | score = pred_scores[:,m].mean()
212 | #category_id = pred_logits[:,m][pred_scores[:,m].argmax()]
213 | category_id = np.argmax(np.bincount(pred_logits[:,m]))
214 | instance = {'video_id':id_, 'score':float(score), 'category_id':int(category_id)}
215 | segmentation = []
216 | for n in range(length):
217 | if pred_scores[n,m]<0.001:
218 | segmentation.append(None)
219 | else:
220 | mask = (pred_masks[n,m]).astype(np.uint8)
221 | rle = mask_util.encode(np.array(mask[:,:,np.newaxis], order='F'))[0]
222 | rle["counts"] = rle["counts"].decode("utf-8")
223 | segmentation.append(rle)
224 | instance['segmentations'] = segmentation
225 | result.append(instance)
226 | with open(args.save_path, 'w', encoding='utf-8') as f:
227 | json.dump(result,f)
228 |
229 |
230 |
231 |
232 |
233 | if __name__ == '__main__':
234 | parser = argparse.ArgumentParser('VisTR inference script', parents=[get_args_parser()])
235 | args = parser.parse_args()
236 | main(args)
237 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Training script of VisTR
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import argparse
6 | import datetime
7 | import json
8 | import random
9 | import time
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | from torch.utils.data import DataLoader, DistributedSampler
15 |
16 | import datasets
17 | import util.misc as utils
18 | from datasets import build_dataset, get_coco_api_from_dataset
19 | from engine import evaluate, train_one_epoch
20 | from models import build_model
21 |
22 |
23 | def get_args_parser():
24 | parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
25 | parser.add_argument('--lr', default=1e-4, type=float)
26 | parser.add_argument('--lr_backbone', default=1e-5, type=float)
27 | parser.add_argument('--batch_size', default=1, type=int)
28 | parser.add_argument('--weight_decay', default=1e-4, type=float)
29 | parser.add_argument('--epochs', default=18, type=int)
30 | parser.add_argument('--lr_drop', default=12, type=int)
31 | parser.add_argument('--clip_max_norm', default=0.1, type=float,
32 | help='gradient clipping max norm')
33 |
34 | # Model parameters
35 | parser.add_argument('--pretrained_weights', type=str, default="r101_pretrained.pth",
36 | help="Path to the pretrained model.")
37 | # * Backbone
38 | parser.add_argument('--backbone', default='resnet101', type=str,
39 | help="Name of the convolutional backbone to use")
40 | parser.add_argument('--dilation', action='store_true',
41 | help="If true, we replace stride with dilation in the last convolutional block (DC5)")
42 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
43 | help="Type of positional embedding to use on top of the image features")
44 |
45 | # * Transformer
46 | parser.add_argument('--enc_layers', default=6, type=int,
47 | help="Number of encoding layers in the transformer")
48 | parser.add_argument('--dec_layers', default=6, type=int,
49 | help="Number of decoding layers in the transformer")
50 | parser.add_argument('--dim_feedforward', default=2048, type=int,
51 | help="Intermediate size of the feedforward layers in the transformer blocks")
52 | parser.add_argument('--hidden_dim', default=384, type=int,
53 | help="Size of the embeddings (dimension of the transformer)")
54 | parser.add_argument('--dropout', default=0.1, type=float,
55 | help="Dropout applied in the transformer")
56 | parser.add_argument('--nheads', default=8, type=int,
57 | help="Number of attention heads inside the transformer's attentions")
58 | parser.add_argument('--num_frames', default=36, type=int,
59 | help="Number of frames")
60 | parser.add_argument('--num_queries', default=360, type=int,
61 | help="Number of query slots")
62 | parser.add_argument('--pre_norm', action='store_true')
63 |
64 | # * Segmentation
65 | parser.add_argument('--masks', action='store_true',
66 | help="Train segmentation head if the flag is provided")
67 |
68 | # Loss
69 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
70 | help="Disables auxiliary decoding losses (loss at each layer)")
71 | # * Matcher
72 | parser.add_argument('--set_cost_class', default=1, type=float,
73 | help="Class coefficient in the matching cost")
74 | parser.add_argument('--set_cost_bbox', default=5, type=float,
75 | help="L1 box coefficient in the matching cost")
76 | parser.add_argument('--set_cost_giou', default=2, type=float,
77 | help="giou box coefficient in the matching cost")
78 | # * Loss coefficients
79 | parser.add_argument('--mask_loss_coef', default=1, type=float)
80 | parser.add_argument('--dice_loss_coef', default=1, type=float)
81 | parser.add_argument('--bbox_loss_coef', default=5, type=float)
82 | parser.add_argument('--giou_loss_coef', default=2, type=float)
83 | parser.add_argument('--eos_coef', default=0.1, type=float,
84 | help="Relative classification weight of the no-object class")
85 |
86 | # dataset parameters
87 | parser.add_argument('--dataset_file', default='ytvos')
88 | parser.add_argument('--ytvos_path', type=str)
89 | parser.add_argument('--remove_difficult', action='store_true')
90 |
91 | parser.add_argument('--output_dir', default='r101_vistr',
92 | help='path where to save, empty for no saving')
93 | parser.add_argument('--device', default='cuda',
94 | help='device to use for training / testing')
95 | parser.add_argument('--seed', default=42, type=int)
96 | parser.add_argument('--resume', default='', help='resume from checkpoint')
97 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
98 | help='start epoch')
99 | parser.add_argument('--eval', action='store_true')
100 | parser.add_argument('--num_workers', default=4, type=int)
101 |
102 | # distributed training parameters
103 | parser.add_argument('--world_size', default=1, type=int,
104 | help='number of distributed processes')
105 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
106 | return parser
107 |
108 |
109 | def main(args):
110 | utils.init_distributed_mode(args)
111 | print("git:\n {}\n".format(utils.get_sha()))
112 |
113 |
114 | device = torch.device(args.device)
115 |
116 | # fix the seed for reproducibility
117 | seed = args.seed + utils.get_rank()
118 | torch.manual_seed(seed)
119 | np.random.seed(seed)
120 | random.seed(seed)
121 |
122 | model, criterion, postprocessors = build_model(args)
123 | model.to(device)
124 |
125 | model_without_ddp = model
126 | if args.distributed:
127 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
128 | model_without_ddp = model.module
129 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
130 | print('number of params:', n_parameters)
131 |
132 | param_dicts = [
133 | {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]},
134 | {
135 | "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad],
136 | "lr": args.lr_backbone,
137 | },
138 | ]
139 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
140 | weight_decay=args.weight_decay)
141 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
142 |
143 | # no validation ground truth for ytvos dataset
144 | dataset_train = build_dataset(image_set='train', args=args)
145 | if args.distributed:
146 | sampler_train = DistributedSampler(dataset_train)
147 | else:
148 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
149 |
150 | batch_sampler_train = torch.utils.data.BatchSampler(
151 | sampler_train, args.batch_size, drop_last=True)
152 |
153 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
154 | collate_fn=utils.collate_fn, num_workers=args.num_workers)
155 |
156 | output_dir = Path(args.output_dir)
157 |
158 | # load coco pretrained weight
159 | checkpoint = torch.load(args.pretrained_weights, map_location='cpu')['model']
160 | del checkpoint["vistr.class_embed.weight"]
161 | del checkpoint["vistr.class_embed.bias"]
162 | del checkpoint["vistr.query_embed.weight"]
163 |     model_without_ddp.load_state_dict(checkpoint, strict=False)  # use the unwrapped model so this also works without DDP
164 |
165 | if args.resume:
166 | if args.resume.startswith('https'):
167 | checkpoint = torch.hub.load_state_dict_from_url(
168 | args.resume, map_location='cpu', check_hash=True)
169 | else:
170 | checkpoint = torch.load(args.resume, map_location='cpu')
171 | model_without_ddp.load_state_dict(checkpoint['model'])
172 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
173 | optimizer.load_state_dict(checkpoint['optimizer'])
174 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
175 | args.start_epoch = checkpoint['epoch'] + 1
176 |
177 | print("Start training")
178 | start_time = time.time()
179 | for epoch in range(args.start_epoch, args.epochs):
180 | if args.distributed:
181 | sampler_train.set_epoch(epoch)
182 | train_stats = train_one_epoch(
183 | model, criterion, data_loader_train, optimizer, device, epoch,
184 | args.clip_max_norm)
185 | lr_scheduler.step()
186 | if args.output_dir:
187 | checkpoint_paths = [output_dir / 'checkpoint.pth']
188 |             # extra checkpoint before LR drop and after every epoch
189 | if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 1 == 0:
190 | checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
191 | for checkpoint_path in checkpoint_paths:
192 | utils.save_on_master({
193 | 'model': model_without_ddp.state_dict(),
194 | 'optimizer': optimizer.state_dict(),
195 | 'lr_scheduler': lr_scheduler.state_dict(),
196 | 'epoch': epoch,
197 | 'args': args,
198 | }, checkpoint_path)
199 |
200 |
201 | total_time = time.time() - start_time
202 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
203 | print('Training time {}'.format(total_time_str))
204 |
205 |
206 | if __name__ == '__main__':
207 | parser = argparse.ArgumentParser('VisTR training and evaluation script', parents=[get_args_parser()])
208 | args = parser.parse_args()
209 | if args.output_dir:
210 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
211 | main(args)
212 |
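# Example single-node launch (a sketch only -- the dataset path, GPU count and weight
# file are placeholders to adapt to your setup):
#   python -m torch.distributed.launch --nproc_per_node=8 --use_env main.py \
#       --backbone resnet101 --ytvos_path /path/to/ytvos --masks \
#       --pretrained_weights r101_pretrained.pth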
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .vistr import build
2 |
3 |
4 | def build_model(args):
5 | return build(args)
6 |
--------------------------------------------------------------------------------
/models/backbone.py:
--------------------------------------------------------------------------------
1 | """
2 | Backbone modules.
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision
10 | from torch import nn
11 | from torchvision.models._utils import IntermediateLayerGetter
12 | from typing import Dict, List
13 |
14 | from util.misc import NestedTensor, is_main_process
15 |
16 | from .position_encoding import build_position_encoding
17 |
18 |
19 | class FrozenBatchNorm2d(torch.nn.Module):
20 | """
21 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
22 |
23 |     Copy-paste from torchvision.misc.ops with added eps before rsqrt,
24 | without which any other models than torchvision.models.resnet[18,34,50,101]
25 | produce nans.
26 | """
27 |
28 | def __init__(self, n):
29 | super(FrozenBatchNorm2d, self).__init__()
30 | self.register_buffer("weight", torch.ones(n))
31 | self.register_buffer("bias", torch.zeros(n))
32 | self.register_buffer("running_mean", torch.zeros(n))
33 | self.register_buffer("running_var", torch.ones(n))
34 |
35 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
36 | missing_keys, unexpected_keys, error_msgs):
37 | num_batches_tracked_key = prefix + 'num_batches_tracked'
38 | if num_batches_tracked_key in state_dict:
39 | del state_dict[num_batches_tracked_key]
40 |
41 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
42 | state_dict, prefix, local_metadata, strict,
43 | missing_keys, unexpected_keys, error_msgs)
44 |
45 | def forward(self, x):
46 | # move reshapes to the beginning
47 | # to make it fuser-friendly
48 | w = self.weight.reshape(1, -1, 1, 1)
49 | b = self.bias.reshape(1, -1, 1, 1)
50 | rv = self.running_var.reshape(1, -1, 1, 1)
51 | rm = self.running_mean.reshape(1, -1, 1, 1)
52 | eps = 1e-5
53 | scale = w * (rv + eps).rsqrt()
54 | bias = b - rm * scale
55 | return x * scale + bias
56 |
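# The forward above is algebraically ordinary (frozen) batch norm,
#   y = (x - running_mean) / sqrt(running_var + eps) * weight + bias,
# rewritten as y = x * scale + (bias - running_mean * scale) with
# scale = weight / sqrt(running_var + eps), i.e. one fused multiply-add per element.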
57 |
58 | class BackboneBase(nn.Module):
59 |
60 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
61 | super().__init__()
62 | for name, parameter in backbone.named_parameters():
63 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
64 | parameter.requires_grad_(False)
65 | if return_interm_layers:
66 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
67 | else:
68 | return_layers = {'layer4': "0"}
69 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
70 | self.num_channels = num_channels
71 |
72 | def forward(self, tensor_list: NestedTensor):
73 | xs = self.body(tensor_list.tensors)
74 | out: Dict[str, NestedTensor] = {}
75 | for name, x in xs.items():
76 | m = tensor_list.mask
77 | assert m is not None
78 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
79 | out[name] = NestedTensor(x, mask)
80 | return out
81 |
82 |
83 | class Backbone(BackboneBase):
84 | """ResNet backbone with frozen BatchNorm."""
85 | def __init__(self, name: str,
86 | train_backbone: bool,
87 | return_interm_layers: bool,
88 | dilation: bool):
89 | backbone = getattr(torchvision.models, name)(
90 | replace_stride_with_dilation=[False, False, dilation],
91 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
92 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
93 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers)
94 |
95 |
96 | class Joiner(nn.Sequential):
97 | def __init__(self, backbone, position_embedding):
98 | super().__init__(backbone, position_embedding)
99 |
100 | def forward(self, tensor_list: NestedTensor):
101 | xs = self[0](tensor_list)
102 | out: List[NestedTensor] = []
103 | pos = []
104 | for name, x in xs.items():
105 | out.append(x)
106 | # position encoding
107 | pos.append(self[1](x).to(x.tensors.dtype))
108 |
109 | return out, pos
110 |
111 |
112 | def build_backbone(args):
113 | position_embedding = build_position_encoding(args)
114 | train_backbone = args.lr_backbone > 0
115 | return_interm_layers = args.masks
116 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
117 | model = Joiner(backbone, position_embedding)
118 | model.num_channels = backbone.num_channels
119 | return model
120 |
--------------------------------------------------------------------------------
/models/dcn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuqingWang1029/VisTR/445c9e4e787a1fb3c959d7e7bb6ecf809bdac155/models/dcn/__init__.py
--------------------------------------------------------------------------------
/models/dcn/deform_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import math
3 | from functools import lru_cache
4 | import torch
5 | from torch import nn
6 | from torch.autograd import Function
7 | from torch.autograd.function import once_differentiable
8 | from torch.nn.modules.utils import _pair
9 |
10 | from . import _C
11 |
12 | class _NewEmptyTensorOp(torch.autograd.Function):
13 | @staticmethod
14 | def forward(ctx, x, new_shape):
15 | ctx.shape = x.shape
16 | return x.new_empty(new_shape)
17 |
18 | @staticmethod
19 | def backward(ctx, grad):
20 | shape = ctx.shape
21 | return _NewEmptyTensorOp.apply(grad, shape), None
22 |
23 | class _DeformConv(Function):
24 | @staticmethod
25 | def forward(
26 | ctx,
27 | input,
28 | offset,
29 | weight,
30 | stride=1,
31 | padding=0,
32 | dilation=1,
33 | groups=1,
34 | deformable_groups=1,
35 | im2col_step=64,
36 | ):
37 | if input is not None and input.dim() != 4:
38 | raise ValueError(
39 | "Expected 4D tensor as input, got {}D tensor instead.".format(input.dim())
40 | )
41 | ctx.stride = _pair(stride)
42 | ctx.padding = _pair(padding)
43 | ctx.dilation = _pair(dilation)
44 | ctx.groups = groups
45 | ctx.deformable_groups = deformable_groups
46 | ctx.im2col_step = im2col_step
47 |
48 | ctx.save_for_backward(input, offset, weight)
49 |
50 | output = input.new_empty(
51 | _DeformConv._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride)
52 | )
53 |
54 | ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones
55 |
56 | if not input.is_cuda:
57 | raise NotImplementedError("Deformable Conv is not supported on CPUs!")
58 | else:
59 | cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step)
60 | assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize"
61 |
62 | _C.deform_conv_forward(
63 | input,
64 | weight,
65 | offset,
66 | output,
67 | ctx.bufs_[0],
68 | ctx.bufs_[1],
69 | weight.size(3),
70 | weight.size(2),
71 | ctx.stride[1],
72 | ctx.stride[0],
73 | ctx.padding[1],
74 | ctx.padding[0],
75 | ctx.dilation[1],
76 | ctx.dilation[0],
77 | ctx.groups,
78 | ctx.deformable_groups,
79 | cur_im2col_step,
80 | )
81 | return output
82 |
83 | @staticmethod
84 | @once_differentiable
85 | def backward(ctx, grad_output):
86 | input, offset, weight = ctx.saved_tensors
87 |
88 | grad_input = grad_offset = grad_weight = None
89 |
90 | if not grad_output.is_cuda:
91 | raise NotImplementedError("Deformable Conv is not supported on CPUs!")
92 | else:
93 | cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step)
94 | assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize"
95 |
96 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
97 | grad_input = torch.zeros_like(input)
98 | grad_offset = torch.zeros_like(offset)
99 | _C.deform_conv_backward_input(
100 | input,
101 | offset,
102 | grad_output,
103 | grad_input,
104 | grad_offset,
105 | weight,
106 | ctx.bufs_[0],
107 | weight.size(3),
108 | weight.size(2),
109 | ctx.stride[1],
110 | ctx.stride[0],
111 | ctx.padding[1],
112 | ctx.padding[0],
113 | ctx.dilation[1],
114 | ctx.dilation[0],
115 | ctx.groups,
116 | ctx.deformable_groups,
117 | cur_im2col_step,
118 | )
119 |
120 | if ctx.needs_input_grad[2]:
121 | grad_weight = torch.zeros_like(weight)
122 | _C.deform_conv_backward_filter(
123 | input,
124 | offset,
125 | grad_output,
126 | grad_weight,
127 | ctx.bufs_[0],
128 | ctx.bufs_[1],
129 | weight.size(3),
130 | weight.size(2),
131 | ctx.stride[1],
132 | ctx.stride[0],
133 | ctx.padding[1],
134 | ctx.padding[0],
135 | ctx.dilation[1],
136 | ctx.dilation[0],
137 | ctx.groups,
138 | ctx.deformable_groups,
139 | 1,
140 | cur_im2col_step,
141 | )
142 |
143 | return grad_input, grad_offset, grad_weight, None, None, None, None, None, None
144 |
145 | @staticmethod
146 | def _output_size(input, weight, padding, dilation, stride):
147 | channels = weight.size(0)
148 | output_size = (input.size(0), channels)
149 | for d in range(input.dim() - 2):
150 | in_size = input.size(d + 2)
151 | pad = padding[d]
152 | kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
153 | stride_ = stride[d]
154 | output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,)
155 | if not all(map(lambda s: s > 0, output_size)):
156 | raise ValueError(
157 | "convolution input is too small (output would be {})".format(
158 | "x".join(map(str, output_size))
159 | )
160 | )
161 | return output_size
162 |
163 | @staticmethod
164 | @lru_cache(maxsize=128)
165 | def _cal_im2col_step(input_size, default_size):
166 | """
167 |         Calculate a proper im2col step size: one that divides input_size and is not
168 |         larger than default_size. The step should also be as large as possible for
169 |         efficiency, so we choose the largest divisor of input_size that does not
170 |         exceed default_size.
171 |         :param input_size: input batch size.
172 | :param default_size: default preferred im2col step size.
173 | :return: the largest proper step size.
174 | """
175 | if input_size <= default_size:
176 | return input_size
177 | best_step = 1
178 | for step in range(2, min(int(math.sqrt(input_size)) + 1, default_size)):
179 | if input_size % step == 0:
180 | if input_size // step <= default_size:
181 | return input_size // step
182 | best_step = step
183 |
184 | return best_step
185 |
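# Worked examples for _cal_im2col_step above: a batch of 48 with the default of 64
# returns 48 (the whole batch fits in one step), while a batch of 96 returns 48
# (96 // 2, the largest divisor of 96 not exceeding 64), so every im2col chunk has
# the same size.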
186 |
187 | class _ModulatedDeformConv(Function):
188 | @staticmethod
189 | def forward(
190 | ctx,
191 | input,
192 | offset,
193 | mask,
194 | weight,
195 | bias=None,
196 | stride=1,
197 | padding=0,
198 | dilation=1,
199 | groups=1,
200 | deformable_groups=1,
201 | ):
202 | ctx.stride = stride
203 | ctx.padding = padding
204 | ctx.dilation = dilation
205 | ctx.groups = groups
206 | ctx.deformable_groups = deformable_groups
207 | ctx.with_bias = bias is not None
208 | if not ctx.with_bias:
209 | bias = input.new_empty(1) # fake tensor
210 | if not input.is_cuda:
211 | raise NotImplementedError("Deformable Conv is not supported on CPUs!")
212 | if (
213 | weight.requires_grad
214 | or mask.requires_grad
215 | or offset.requires_grad
216 | or input.requires_grad
217 | ):
218 | ctx.save_for_backward(input, offset, mask, weight, bias)
219 | output = input.new_empty(_ModulatedDeformConv._infer_shape(ctx, input, weight))
220 | ctx._bufs = [input.new_empty(0), input.new_empty(0)]
221 | _C.modulated_deform_conv_forward(
222 | input,
223 | weight,
224 | bias,
225 | ctx._bufs[0],
226 | offset,
227 | mask,
228 | output,
229 | ctx._bufs[1],
230 | weight.shape[2],
231 | weight.shape[3],
232 | ctx.stride,
233 | ctx.stride,
234 | ctx.padding,
235 | ctx.padding,
236 | ctx.dilation,
237 | ctx.dilation,
238 | ctx.groups,
239 | ctx.deformable_groups,
240 | ctx.with_bias,
241 | )
242 | return output
243 |
244 | @staticmethod
245 | @once_differentiable
246 | def backward(ctx, grad_output):
247 | if not grad_output.is_cuda:
248 | raise NotImplementedError("Deformable Conv is not supported on CPUs!")
249 | input, offset, mask, weight, bias = ctx.saved_tensors
250 | grad_input = torch.zeros_like(input)
251 | grad_offset = torch.zeros_like(offset)
252 | grad_mask = torch.zeros_like(mask)
253 | grad_weight = torch.zeros_like(weight)
254 | grad_bias = torch.zeros_like(bias)
255 | _C.modulated_deform_conv_backward(
256 | input,
257 | weight,
258 | bias,
259 | ctx._bufs[0],
260 | offset,
261 | mask,
262 | ctx._bufs[1],
263 | grad_input,
264 | grad_weight,
265 | grad_bias,
266 | grad_offset,
267 | grad_mask,
268 | grad_output,
269 | weight.shape[2],
270 | weight.shape[3],
271 | ctx.stride,
272 | ctx.stride,
273 | ctx.padding,
274 | ctx.padding,
275 | ctx.dilation,
276 | ctx.dilation,
277 | ctx.groups,
278 | ctx.deformable_groups,
279 | ctx.with_bias,
280 | )
281 | if not ctx.with_bias:
282 | grad_bias = None
283 |
284 | return (
285 | grad_input,
286 | grad_offset,
287 | grad_mask,
288 | grad_weight,
289 | grad_bias,
290 | None,
291 | None,
292 | None,
293 | None,
294 | None,
295 | )
296 |
297 | @staticmethod
298 | def _infer_shape(ctx, input, weight):
299 | n = input.size(0)
300 | channels_out = weight.size(0)
301 | height, width = input.shape[2:4]
302 | kernel_h, kernel_w = weight.shape[2:4]
303 | height_out = (
304 | height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1)
305 | ) // ctx.stride + 1
306 | width_out = (
307 | width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1)
308 | ) // ctx.stride + 1
309 | return n, channels_out, height_out, width_out
310 |
311 |
312 | deform_conv = _DeformConv.apply
313 | modulated_deform_conv = _ModulatedDeformConv.apply
314 |
315 |
316 | class DeformConv(nn.Module):
317 | def __init__(
318 | self,
319 | in_channels,
320 | out_channels,
321 | kernel_size,
322 | stride=1,
323 | padding=0,
324 | dilation=1,
325 | groups=1,
326 | deformable_groups=1,
327 | bias=False,
328 | norm=None,
329 | activation=None,
330 | ):
331 | """
332 | Deformable convolution from :paper:`deformconv`.
333 |
334 | Arguments are similar to :class:`Conv2D`. Extra arguments:
335 |
336 | Args:
337 | deformable_groups (int): number of groups used in deformable convolution.
338 | norm (nn.Module, optional): a normalization layer
339 | activation (callable(Tensor) -> Tensor): a callable activation function
340 | """
341 | super(DeformConv, self).__init__()
342 |
343 | assert not bias
344 |         assert in_channels % groups == 0, "in_channels {} is not divisible by groups {}".format(
345 |             in_channels, groups
346 |         )
347 |         assert (
348 |             out_channels % groups == 0
349 |         ), "out_channels {} is not divisible by groups {}".format(out_channels, groups)
350 |
351 | self.in_channels = in_channels
352 | self.out_channels = out_channels
353 | self.kernel_size = _pair(kernel_size)
354 | self.stride = _pair(stride)
355 | self.padding = _pair(padding)
356 | self.dilation = _pair(dilation)
357 | self.groups = groups
358 | self.deformable_groups = deformable_groups
359 | self.norm = norm
360 | self.activation = activation
361 |
362 | self.weight = nn.Parameter(
363 | torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size)
364 | )
365 | self.bias = None
366 |
367 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
368 |
369 | def forward(self, x, offset):
370 | if x.numel() == 0:
371 |             # When input is empty, we want to return an empty tensor with "correct" shape,
372 | # So that the following operations will not panic
373 | # if they check for the shape of the tensor.
374 | # This computes the height and width of the output tensor
375 | output_shape = [
376 | (i + 2 * p - (di * (k - 1) + 1)) // s + 1
377 | for i, p, di, k, s in zip(
378 | x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
379 | )
380 | ]
381 | output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
382 | return _NewEmptyTensorOp.apply(x, output_shape)
383 |
384 | x = deform_conv(
385 | x,
386 | offset,
387 | self.weight,
388 | self.stride,
389 | self.padding,
390 | self.dilation,
391 | self.groups,
392 | self.deformable_groups,
393 | )
394 | if self.norm is not None:
395 | x = self.norm(x)
396 | if self.activation is not None:
397 | x = self.activation(x)
398 | return x
399 |
400 | def extra_repr(self):
401 | tmpstr = "in_channels=" + str(self.in_channels)
402 | tmpstr += ", out_channels=" + str(self.out_channels)
403 | tmpstr += ", kernel_size=" + str(self.kernel_size)
404 | tmpstr += ", stride=" + str(self.stride)
405 | tmpstr += ", padding=" + str(self.padding)
406 | tmpstr += ", dilation=" + str(self.dilation)
407 | tmpstr += ", groups=" + str(self.groups)
408 | tmpstr += ", deformable_groups=" + str(self.deformable_groups)
409 | tmpstr += ", bias=False"
410 | return tmpstr
411 |
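# Shape reminder for DeformConv.forward: for a (kh, kw) kernel the offset tensor needs
# 2 * deformable_groups * kh * kw channels (a (dy, dx) pair per sampling location) and
# the same spatial size as the convolution output.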
412 |
413 | class ModulatedDeformConv(nn.Module):
414 | def __init__(
415 | self,
416 | in_channels,
417 | out_channels,
418 | kernel_size,
419 | stride=1,
420 | padding=0,
421 | dilation=1,
422 | groups=1,
423 | deformable_groups=1,
424 | bias=True,
425 | norm=None,
426 | activation=None,
427 | ):
428 | """
429 | Modulated deformable convolution from :paper:`deformconv2`.
430 |
431 | Arguments are similar to :class:`Conv2D`. Extra arguments:
432 |
433 | Args:
434 | deformable_groups (int): number of groups used in deformable convolution.
435 | norm (nn.Module, optional): a normalization layer
436 | activation (callable(Tensor) -> Tensor): a callable activation function
437 | """
438 | super(ModulatedDeformConv, self).__init__()
439 | self.in_channels = in_channels
440 | self.out_channels = out_channels
441 | self.kernel_size = _pair(kernel_size)
442 | self.stride = stride
443 | self.padding = padding
444 | self.dilation = dilation
445 | self.groups = groups
446 | self.deformable_groups = deformable_groups
447 | self.with_bias = bias
448 | self.norm = norm
449 | self.activation = activation
450 |
451 | self.weight = nn.Parameter(
452 | torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)
453 | )
454 | if bias:
455 | self.bias = nn.Parameter(torch.Tensor(out_channels))
456 | else:
457 | self.bias = None
458 |
459 | nn.init.kaiming_uniform_(self.weight, nonlinearity="relu")
460 | if self.bias is not None:
461 | nn.init.constant_(self.bias, 0)
462 |
463 | def forward(self, x, offset, mask):
464 | if x.numel() == 0:
465 | output_shape = [
466 | (i + 2 * p - (di * (k - 1) + 1)) // s + 1
467 | for i, p, di, k, s in zip(
468 | x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride
469 | )
470 | ]
471 | output_shape = [x.shape[0], self.weight.shape[0]] + output_shape
472 | return _NewEmptyTensorOp.apply(x, output_shape)
473 |
474 | x = modulated_deform_conv(
475 | x,
476 | offset,
477 | mask,
478 | self.weight,
479 | self.bias,
480 | self.stride,
481 | self.padding,
482 | self.dilation,
483 | self.groups,
484 | self.deformable_groups,
485 | )
486 | if self.norm is not None:
487 | x = self.norm(x)
488 | if self.activation is not None:
489 | x = self.activation(x)
490 | return x
491 |
492 | def extra_repr(self):
493 | tmpstr = "in_channels=" + str(self.in_channels)
494 | tmpstr += ", out_channels=" + str(self.out_channels)
495 | tmpstr += ", kernel_size=" + str(self.kernel_size)
496 | tmpstr += ", stride=" + str(self.stride)
497 | tmpstr += ", padding=" + str(self.padding)
498 | tmpstr += ", dilation=" + str(self.dilation)
499 | tmpstr += ", groups=" + str(self.groups)
500 | tmpstr += ", deformable_groups=" + str(self.deformable_groups)
501 | tmpstr += ", bias=" + str(self.with_bias)
502 | return tmpstr
503 |
--------------------------------------------------------------------------------
/models/dcn/deformable/deform_conv.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | #include "deform_conv.h"
3 |
4 | #if defined(WITH_CUDA) || defined(WITH_HIP)
5 | extern int get_cudart_version();
6 | #endif
7 |
8 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
9 |
10 | m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward");
11 | m.def(
12 | "deform_conv_backward_input",
13 | &deform_conv_backward_input,
14 | "deform_conv_backward_input");
15 | m.def(
16 | "deform_conv_backward_filter",
17 | &deform_conv_backward_filter,
18 | "deform_conv_backward_filter");
19 | m.def(
20 | "modulated_deform_conv_forward",
21 | &modulated_deform_conv_forward,
22 | "modulated_deform_conv_forward");
23 | m.def(
24 | "modulated_deform_conv_backward",
25 | &modulated_deform_conv_backward,
26 | "modulated_deform_conv_backward");
27 | }
--------------------------------------------------------------------------------
/models/dcn/deformable/deform_conv.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | #pragma once
3 | #include <torch/types.h>
4 |
5 |
6 |
7 | #if defined(WITH_CUDA) || defined(WITH_HIP)
8 | int deform_conv_forward_cuda(
9 | at::Tensor input,
10 | at::Tensor weight,
11 | at::Tensor offset,
12 | at::Tensor output,
13 | at::Tensor columns,
14 | at::Tensor ones,
15 | int kW,
16 | int kH,
17 | int dW,
18 | int dH,
19 | int padW,
20 | int padH,
21 | int dilationW,
22 | int dilationH,
23 | int group,
24 | int deformable_group,
25 | int im2col_step);
26 |
27 | int deform_conv_backward_input_cuda(
28 | at::Tensor input,
29 | at::Tensor offset,
30 | at::Tensor gradOutput,
31 | at::Tensor gradInput,
32 | at::Tensor gradOffset,
33 | at::Tensor weight,
34 | at::Tensor columns,
35 | int kW,
36 | int kH,
37 | int dW,
38 | int dH,
39 | int padW,
40 | int padH,
41 | int dilationW,
42 | int dilationH,
43 | int group,
44 | int deformable_group,
45 | int im2col_step);
46 |
47 | int deform_conv_backward_parameters_cuda(
48 | at::Tensor input,
49 | at::Tensor offset,
50 | at::Tensor gradOutput,
51 | at::Tensor gradWeight, // at::Tensor gradBias,
52 | at::Tensor columns,
53 | at::Tensor ones,
54 | int kW,
55 | int kH,
56 | int dW,
57 | int dH,
58 | int padW,
59 | int padH,
60 | int dilationW,
61 | int dilationH,
62 | int group,
63 | int deformable_group,
64 | float scale,
65 | int im2col_step);
66 |
67 | void modulated_deform_conv_cuda_forward(
68 | at::Tensor input,
69 | at::Tensor weight,
70 | at::Tensor bias,
71 | at::Tensor ones,
72 | at::Tensor offset,
73 | at::Tensor mask,
74 | at::Tensor output,
75 | at::Tensor columns,
76 | int kernel_h,
77 | int kernel_w,
78 | const int stride_h,
79 | const int stride_w,
80 | const int pad_h,
81 | const int pad_w,
82 | const int dilation_h,
83 | const int dilation_w,
84 | const int group,
85 | const int deformable_group,
86 | const bool with_bias);
87 |
88 | void modulated_deform_conv_cuda_backward(
89 | at::Tensor input,
90 | at::Tensor weight,
91 | at::Tensor bias,
92 | at::Tensor ones,
93 | at::Tensor offset,
94 | at::Tensor mask,
95 | at::Tensor columns,
96 | at::Tensor grad_input,
97 | at::Tensor grad_weight,
98 | at::Tensor grad_bias,
99 | at::Tensor grad_offset,
100 | at::Tensor grad_mask,
101 | at::Tensor grad_output,
102 | int kernel_h,
103 | int kernel_w,
104 | int stride_h,
105 | int stride_w,
106 | int pad_h,
107 | int pad_w,
108 | int dilation_h,
109 | int dilation_w,
110 | int group,
111 | int deformable_group,
112 | const bool with_bias);
113 |
114 | #endif
115 |
116 | inline int deform_conv_forward(
117 | at::Tensor input,
118 | at::Tensor weight,
119 | at::Tensor offset,
120 | at::Tensor output,
121 | at::Tensor columns,
122 | at::Tensor ones,
123 | int kW,
124 | int kH,
125 | int dW,
126 | int dH,
127 | int padW,
128 | int padH,
129 | int dilationW,
130 | int dilationH,
131 | int group,
132 | int deformable_group,
133 | int im2col_step) {
134 | if (input.is_cuda()) {
135 | #if defined(WITH_CUDA) || defined(WITH_HIP)
136 | TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
137 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
138 | return deform_conv_forward_cuda(
139 | input,
140 | weight,
141 | offset,
142 | output,
143 | columns,
144 | ones,
145 | kW,
146 | kH,
147 | dW,
148 | dH,
149 | padW,
150 | padH,
151 | dilationW,
152 | dilationH,
153 | group,
154 | deformable_group,
155 | im2col_step);
156 | #else
157 | AT_ERROR("Not compiled with GPU support");
158 | #endif
159 | }
160 | AT_ERROR("Not implemented on the CPU");
161 | }
162 |
163 | inline int deform_conv_backward_input(
164 | at::Tensor input,
165 | at::Tensor offset,
166 | at::Tensor gradOutput,
167 | at::Tensor gradInput,
168 | at::Tensor gradOffset,
169 | at::Tensor weight,
170 | at::Tensor columns,
171 | int kW,
172 | int kH,
173 | int dW,
174 | int dH,
175 | int padW,
176 | int padH,
177 | int dilationW,
178 | int dilationH,
179 | int group,
180 | int deformable_group,
181 | int im2col_step) {
182 | if (gradOutput.is_cuda()) {
183 | #if defined(WITH_CUDA) || defined(WITH_HIP)
184 | TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!");
185 | TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
186 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
187 | return deform_conv_backward_input_cuda(
188 | input,
189 | offset,
190 | gradOutput,
191 | gradInput,
192 | gradOffset,
193 | weight,
194 | columns,
195 | kW,
196 | kH,
197 | dW,
198 | dH,
199 | padW,
200 | padH,
201 | dilationW,
202 | dilationH,
203 | group,
204 | deformable_group,
205 | im2col_step);
206 | #else
207 | AT_ERROR("Not compiled with GPU support");
208 | #endif
209 | }
210 | AT_ERROR("Not implemented on the CPU");
211 | }
212 |
213 | inline int deform_conv_backward_filter(
214 | at::Tensor input,
215 | at::Tensor offset,
216 | at::Tensor gradOutput,
217 | at::Tensor gradWeight, // at::Tensor gradBias,
218 | at::Tensor columns,
219 | at::Tensor ones,
220 | int kW,
221 | int kH,
222 | int dW,
223 | int dH,
224 | int padW,
225 | int padH,
226 | int dilationW,
227 | int dilationH,
228 | int group,
229 | int deformable_group,
230 | float scale,
231 | int im2col_step) {
232 | if (gradOutput.is_cuda()) {
233 | #if defined(WITH_CUDA) || defined(WITH_HIP)
234 | TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!");
235 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
236 | return deform_conv_backward_parameters_cuda(
237 | input,
238 | offset,
239 | gradOutput,
240 | gradWeight,
241 | columns,
242 | ones,
243 | kW,
244 | kH,
245 | dW,
246 | dH,
247 | padW,
248 | padH,
249 | dilationW,
250 | dilationH,
251 | group,
252 | deformable_group,
253 | scale,
254 | im2col_step);
255 | #else
256 | AT_ERROR("Not compiled with GPU support");
257 | #endif
258 | }
259 | AT_ERROR("Not implemented on the CPU");
260 | }
261 |
262 | inline void modulated_deform_conv_forward(
263 | at::Tensor input,
264 | at::Tensor weight,
265 | at::Tensor bias,
266 | at::Tensor ones,
267 | at::Tensor offset,
268 | at::Tensor mask,
269 | at::Tensor output,
270 | at::Tensor columns,
271 | int kernel_h,
272 | int kernel_w,
273 | const int stride_h,
274 | const int stride_w,
275 | const int pad_h,
276 | const int pad_w,
277 | const int dilation_h,
278 | const int dilation_w,
279 | const int group,
280 | const int deformable_group,
281 | const bool with_bias) {
282 | if (input.is_cuda()) {
283 | #if defined(WITH_CUDA) || defined(WITH_HIP)
284 | TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
285 | TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!");
286 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
287 | return modulated_deform_conv_cuda_forward(
288 | input,
289 | weight,
290 | bias,
291 | ones,
292 | offset,
293 | mask,
294 | output,
295 | columns,
296 | kernel_h,
297 | kernel_w,
298 | stride_h,
299 | stride_w,
300 | pad_h,
301 | pad_w,
302 | dilation_h,
303 | dilation_w,
304 | group,
305 | deformable_group,
306 | with_bias);
307 | #else
308 | AT_ERROR("Not compiled with GPU support");
309 | #endif
310 | }
311 | AT_ERROR("Not implemented on the CPU");
312 | }
313 |
314 | inline void modulated_deform_conv_backward(
315 | at::Tensor input,
316 | at::Tensor weight,
317 | at::Tensor bias,
318 | at::Tensor ones,
319 | at::Tensor offset,
320 | at::Tensor mask,
321 | at::Tensor columns,
322 | at::Tensor grad_input,
323 | at::Tensor grad_weight,
324 | at::Tensor grad_bias,
325 | at::Tensor grad_offset,
326 | at::Tensor grad_mask,
327 | at::Tensor grad_output,
328 | int kernel_h,
329 | int kernel_w,
330 | int stride_h,
331 | int stride_w,
332 | int pad_h,
333 | int pad_w,
334 | int dilation_h,
335 | int dilation_w,
336 | int group,
337 | int deformable_group,
338 | const bool with_bias) {
339 | if (grad_output.is_cuda()) {
340 | #if defined(WITH_CUDA) || defined(WITH_HIP)
341 | TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!");
342 | TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!");
343 | TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!");
344 | TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!");
345 | return modulated_deform_conv_cuda_backward(
346 | input,
347 | weight,
348 | bias,
349 | ones,
350 | offset,
351 | mask,
352 | columns,
353 | grad_input,
354 | grad_weight,
355 | grad_bias,
356 | grad_offset,
357 | grad_mask,
358 | grad_output,
359 | kernel_h,
360 | kernel_w,
361 | stride_h,
362 | stride_w,
363 | pad_h,
364 | pad_w,
365 | dilation_h,
366 | dilation_w,
367 | group,
368 | deformable_group,
369 | with_bias);
370 | #else
371 | AT_ERROR("Not compiled with GPU support");
372 | #endif
373 | }
374 | AT_ERROR("Not implemented on the CPU");
375 | }
376 |
--------------------------------------------------------------------------------
/models/dcn/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 | import os
4 | setup(
5 | name='deform_conv',
6 | ext_modules=[
7 | CUDAExtension('_C', [
8 | 'deformable/deform_conv.cpp',
9 | 'deformable/deform_conv_cuda.cu',
10 | 'deformable/deform_conv_cuda_kernel.cu'],
11 | include_dirs=["deformable"],
12 | define_macros = [("WITH_CUDA", None)],
13 | extra_compile_args={"nvcc":[
14 | "-O3",
15 | "-DCUDA_HAS_FP16=1",
16 | "-D__CUDA_NO_HALF_OPERATORS__",
17 | "-D__CUDA_NO_HALF_CONVERSIONS__",
18 | "-D__CUDA_NO_HALF2_OPERATORS__"],
19 | "cxx":[],}
20 | )
21 | ],
22 | cmdclass={'build_ext': BuildExtension})
--------------------------------------------------------------------------------
/models/dcn/test_deform.py:
--------------------------------------------------------------------------------
1 | from deform_conv import DeformConv
2 | import torch
3 | conv = DeformConv(10,20,3,padding=1)
4 | conv.cuda()
5 | x = torch.rand([1,10,416,416]).cuda()
6 | offset = torch.rand([1,2*9,416,416]).cuda()
7 | y = conv(x,offset)
8 | print(y.shape)
9 |
--------------------------------------------------------------------------------
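
The smoke test above assumes the custom extension in models/dcn has been compiled (for example with "python setup.py build_ext --inplace" run from that directory) and that a CUDA device is available. As a rough, non-repository fallback for environments without the compiled extension, recent torchvision releases ship an equivalent operator; the sketch below mirrors the channel layout of test_deform.py with smaller spatial sizes and runs on CPU.

    # Minimal sketch: deformable convolution via torchvision's built-in operator.
    # 3x3 kernel -> 2 * 3 * 3 = 18 offset channels, as in test_deform.py.
    import torch
    from torchvision.ops import deform_conv2d

    x = torch.rand(1, 10, 64, 64)
    offset = torch.rand(1, 18, 64, 64)     # 2 * kH * kW offset channels
    weight = torch.rand(20, 10, 3, 3)      # out_channels x in_channels x kH x kW
    y = deform_conv2d(x, offset, weight, padding=1)
    print(y.shape)                         # torch.Size([1, 20, 64, 64])
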
/models/matcher.py:
--------------------------------------------------------------------------------
1 | """
2 | Instance Sequence Matching
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import torch
6 | from scipy.optimize import linear_sum_assignment
7 | from torch import nn
8 |
9 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou, multi_iou
10 | INF = 100000000
11 |
12 | class HungarianMatcher(nn.Module):
13 | """This class computes an assignment between the targets and the predictions of the network
14 |
15 | For efficiency reasons, the targets don't include the no_object. Because of this, in general,
16 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
17 | while the others are un-matched (and thus treated as non-objects).
18 | """
19 |
20 | def __init__(self, num_frames : int = 36, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
21 | """Creates the matcher
22 |
23 | Params:
24 | cost_class: This is the relative weight of the classification error in the matching cost
25 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
26 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
27 | """
28 | super().__init__()
29 | self.cost_class = cost_class
30 | self.cost_bbox = cost_bbox
31 | self.cost_giou = cost_giou
32 | self.num_frames = num_frames
33 |         assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
34 |
35 | @torch.no_grad()
36 | def forward(self, outputs, targets):
37 | """ Performs the sequence level matching
38 | """
39 | bs, num_queries = outputs["pred_logits"].shape[:2]
40 | indices = []
41 | for i in range(bs):
42 | out_prob = outputs["pred_logits"][i].softmax(-1)
43 | out_bbox = outputs["pred_boxes"][i]
44 | tgt_ids = targets[i]["labels"]
45 | tgt_bbox = targets[i]["boxes"]
46 | tgt_valid = targets[i]["valid"]
47 | num_out = 10
48 | num_tgt = len(tgt_ids)//self.num_frames
49 | out_prob_split = out_prob.reshape(self.num_frames,num_out,out_prob.shape[-1]).permute(1,0,2)
50 | out_bbox_split = out_bbox.reshape(self.num_frames,num_out,out_bbox.shape[-1]).permute(1,0,2).unsqueeze(1)
51 | tgt_bbox_split = tgt_bbox.reshape(num_tgt,self.num_frames,4).unsqueeze(0)
52 | tgt_valid_split = tgt_valid.reshape(num_tgt,self.num_frames)
53 | frame_index = torch.arange(start=0,end=self.num_frames).repeat(num_tgt).long()
54 | class_cost = -1 * out_prob_split[:,frame_index,tgt_ids].view(num_out,num_tgt,self.num_frames).mean(dim=-1)
55 | bbox_cost = (out_bbox_split-tgt_bbox_split).abs().mean((-1,-2))
56 | iou_cost = -1 * multi_iou(box_cxcywh_to_xyxy(out_bbox_split),box_cxcywh_to_xyxy(tgt_bbox_split)).mean(-1)
57 | #TODO: only deal with box and mask with empty target
58 | cost = self.cost_class*class_cost + self.cost_bbox*bbox_cost + self.cost_giou*iou_cost
59 | out_i, tgt_i = linear_sum_assignment(cost.cpu())
60 | index_i,index_j = [],[]
61 | for j in range(len(out_i)):
62 | tgt_valid_ind_j = tgt_valid_split[j].nonzero().flatten()
63 | index_i.append(tgt_valid_ind_j*num_out + out_i[j])
64 | index_j.append(tgt_valid_ind_j + tgt_i[j]* self.num_frames)
65 | if index_i==[] or index_j==[]:
66 | indices.append((torch.tensor([]).long().to(out_prob.device),torch.tensor([]).long().to(out_prob.device)))
67 | else:
68 | index_i = torch.cat(index_i).long()
69 | index_j = torch.cat(index_j).long()
70 | indices.append((index_i,index_j))
71 | return indices
72 |
73 | def build_matcher(args):
74 | return HungarianMatcher(num_frames = args.num_frames, cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)
75 |
--------------------------------------------------------------------------------
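
HungarianMatcher collapses each query sequence and each ground-truth trajectory into a single scalar cost (classification, box L1 and IoU terms averaged over the clip's frames) and then solves a standard assignment problem. The toy example below, a self-contained sketch with a hand-written cost matrix rather than repository data, shows the linear_sum_assignment step in isolation.

    # Toy Hungarian assignment as used in HungarianMatcher.forward:
    # rows = predicted instance sequences, columns = ground-truth trajectories.
    import torch
    from scipy.optimize import linear_sum_assignment

    cost = torch.tensor([[0.90, 0.10, 0.80],
                         [0.20, 0.70, 0.60],
                         [0.50, 0.40, 0.05]])
    row_ind, col_ind = linear_sum_assignment(cost.numpy())
    print(list(zip(row_ind.tolist(), col_ind.tolist())))   # [(0, 1), (1, 0), (2, 2)]
    matched = cost[torch.as_tensor(row_ind), torch.as_tensor(col_ind)]
    print(round(matched.sum().item(), 2))                  # 0.35, the minimal total cost
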
/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | """
2 | Various positional encodings for the transformer.
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 |
9 | from util.misc import NestedTensor
10 |
11 | # position encoding for 3 dims
12 | class PositionEmbeddingSine(nn.Module):
13 | """
14 | This is a more standard version of the position embedding, very similar to the one
15 | used by the Attention is all you need paper, generalized to work on images.
16 | """
17 | def __init__(self, num_pos_feats=64, num_frames = 36, temperature=10000, normalize=False, scale=None):
18 | super().__init__()
19 | self.num_pos_feats = num_pos_feats
20 | self.temperature = temperature
21 | self.normalize = normalize
22 | self.frames = num_frames
23 | if scale is not None and normalize is False:
24 | raise ValueError("normalize should be True if scale is passed")
25 | if scale is None:
26 | scale = 2 * math.pi
27 | self.scale = scale
28 |
29 | def forward(self, tensor_list: NestedTensor):
30 | x = tensor_list.tensors
31 | mask = tensor_list.mask
32 | n,h,w = mask.shape
33 | mask = mask.reshape(n//self.frames, self.frames,h,w)
34 | assert mask is not None
35 | not_mask = ~mask
36 | z_embed = not_mask.cumsum(1, dtype=torch.float32)
37 | y_embed = not_mask.cumsum(2, dtype=torch.float32)
38 | x_embed = not_mask.cumsum(3, dtype=torch.float32)
39 | if self.normalize:
40 | eps = 1e-6
41 | z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale
42 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
43 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale
44 |
45 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
46 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
47 |
48 | pos_x = x_embed[:, :, :, :, None] / dim_t
49 | pos_y = y_embed[:, :, :, :, None] / dim_t
50 | pos_z = z_embed[:, :, :, :, None] / dim_t
51 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
52 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
53 | pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
54 | pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3)
55 | return pos
56 |
57 |
58 |
59 |
60 | def build_position_encoding(args):
61 | N_steps = args.hidden_dim // 3
62 | if args.position_embedding in ('v2', 'sine'):
63 | # TODO find a better way of exposing other arguments
64 | position_embedding = PositionEmbeddingSine(N_steps, num_frames = args.num_frames, normalize=True)
65 | else:
66 | raise ValueError(f"not supported {args.position_embedding}")
67 |
68 | return position_embedding
69 |
--------------------------------------------------------------------------------
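
PositionEmbeddingSine generalizes DETR's 2D sine embedding to three axes: the inverted padding mask is cumulatively summed along the frame, height and width dimensions, and each of the three coordinates is expanded into interleaved sine/cosine features before the results are concatenated (hence N_steps = hidden_dim // 3 in build_position_encoding). The self-contained sketch below, not part of the repository, reproduces the per-axis sin/cos interleaving to show where the feature dimension comes from.

    # Per-axis sin/cos interleaving as in PositionEmbeddingSine (toy 1-D coordinate).
    import math
    import torch

    num_pos_feats, temperature = 64, 10000
    coord = torch.arange(1, 6, dtype=torch.float32)            # e.g. cumulative positions along one axis
    coord = coord / (coord[-1] + 1e-6) * (2 * math.pi)         # the normalize=True branch

    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)  # frequencies shared by sin/cos pairs

    pos = coord[:, None] / dim_t                               # [5, 64]
    pos = torch.stack((pos[:, 0::2].sin(), pos[:, 1::2].cos()), dim=2).flatten(1)
    print(pos.shape)                                           # torch.Size([5, 64]) features for this axis
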
/models/segmentation.py:
--------------------------------------------------------------------------------
1 | """
2 | Instance Sequence Segmentation
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import io
6 | from collections import defaultdict
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torch import Tensor
13 | from PIL import Image
14 | from .dcn.deform_conv import DeformConv
15 |
16 | import util.box_ops as box_ops
17 | from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list
18 |
19 | try:
20 | from panopticapi.utils import id2rgb, rgb2id
21 | except ImportError:
22 | pass
23 | import time
24 | BN_MOMENTUM = 0.1
25 |
26 | def conv3x3(in_planes, out_planes, stride=1):
27 | """3x3 convolution with padding"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=1, bias=False)
30 |
31 | class BasicBlock(nn.Module):
32 | expansion = 1
33 | def __init__(self, inplanes, planes, stride=1, downsample=None):
34 | super(BasicBlock, self).__init__()
35 | self.conv1 = conv3x3(inplanes, planes, stride)
36 | self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
37 | self.relu = nn.ReLU(inplace=True)
38 | self.conv2 = conv3x3(planes, planes)
39 | self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
40 | self.downsample = downsample
41 | self.stride = stride
42 |
43 | def forward(self, x):
44 | residual = x
45 | out = self.conv1(x)
46 | out = self.bn1(out)
47 | out = self.relu(out)
48 | out = self.conv2(out)
49 | out = self.bn2(out)
50 |
51 | if self.downsample is not None:
52 | residual = self.downsample(x)
53 |
54 | out += residual
55 | out = self.relu(out)
56 | return out
57 |
58 |
59 |
60 | class VisTRsegm(nn.Module):
61 | def __init__(self, vistr, freeze_vistr=False):
62 | super().__init__()
63 | self.vistr = vistr
64 |
65 | if freeze_vistr:
66 | for p in self.parameters():
67 | p.requires_grad_(False)
68 |
69 | hidden_dim, nheads = vistr.transformer.d_model, vistr.transformer.nhead
70 | self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0)
71 | self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim)
72 | self.insmask_head = nn.Sequential(
73 | nn.Conv3d(24,12,3,padding=2,dilation=2),
74 | nn.GroupNorm(4,12),
75 | nn.ReLU(),
76 | nn.Conv3d(12,12,3,padding=2,dilation=2),
77 | nn.GroupNorm(4,12),
78 | nn.ReLU(),
79 | nn.Conv3d(12,12,3,padding=2,dilation=2),
80 | nn.GroupNorm(4,12),
81 | nn.ReLU(),
82 | nn.Conv3d(12,1,1))
83 | def forward(self, samples: NestedTensor):
84 | if not isinstance(samples, NestedTensor):
85 | samples = nested_tensor_from_tensor_list(samples)
86 | features, pos = self.vistr.backbone(samples)
87 | bs = features[-1].tensors.shape[0]
88 | src, mask = features[-1].decompose()
89 | assert mask is not None
90 | src_proj = self.vistr.input_proj(src)
91 | n,c,s_h,s_w = src_proj.shape
92 | bs_f = bs//self.vistr.num_frames
93 | src_proj = src_proj.reshape(bs_f, self.vistr.num_frames,c, s_h, s_w).permute(0,2,1,3,4).flatten(-2)
94 | mask = mask.reshape(bs_f, self.vistr.num_frames, s_h*s_w)
95 | pos = pos[-1].permute(0,2,1,3,4).flatten(-2)
96 | hs, memory = self.vistr.transformer(src_proj, mask, self.vistr.query_embed.weight, pos)
97 | outputs_class = self.vistr.class_embed(hs)
98 | outputs_coord = self.vistr.bbox_embed(hs).sigmoid()
99 | out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]}
100 | if self.vistr.aux_loss:
101 | out['aux_outputs'] = self.vistr._set_aux_loss(outputs_class, outputs_coord)
102 | for i in range(3):
103 | _,c_f,h,w = features[i].tensors.shape
104 | features[i].tensors = features[i].tensors.reshape(bs_f, self.vistr.num_frames, c_f, h,w)
105 | n_f = self.vistr.num_queries//self.vistr.num_frames
106 | outputs_seg_masks = []
107 |
108 | # image level processing using box attention
109 | for i in range(self.vistr.num_frames):
110 | hs_f = hs[-1][:,i*n_f:(i+1)*n_f,:]
111 | memory_f = memory[:,:,i,:].reshape(bs_f, c, s_h,s_w)
112 | mask_f = mask[:,i,:].reshape(bs_f, s_h,s_w)
113 | bbox_mask_f = self.bbox_attention(hs_f, memory_f, mask=mask_f)
114 | seg_masks_f = self.mask_head(memory_f, bbox_mask_f, [features[2].tensors[:,i], features[1].tensors[:,i], features[0].tensors[:,i]])
115 | outputs_seg_masks_f = seg_masks_f.view(bs_f, n_f, 24, seg_masks_f.shape[-2], seg_masks_f.shape[-1])
116 | outputs_seg_masks.append(outputs_seg_masks_f)
117 | frame_masks = torch.cat(outputs_seg_masks,dim=0)
118 | outputs_seg_masks = []
119 |
120 | # instance level processing using 3D convolution
121 | for i in range(frame_masks.size(1)):
122 | mask_ins = frame_masks[:,i].unsqueeze(0)
123 | mask_ins = mask_ins.permute(0,2,1,3,4)
124 | outputs_seg_masks.append(self.insmask_head(mask_ins))
125 | outputs_seg_masks = torch.cat(outputs_seg_masks,1).squeeze(0).permute(1,0,2,3)
126 | outputs_seg_masks = outputs_seg_masks.reshape(1,self.vistr.num_queries,outputs_seg_masks.size(-2),outputs_seg_masks.size(-1))
127 | out["pred_masks"] = outputs_seg_masks
128 | return out
129 |
130 |
131 | def _expand(tensor, length: int):
132 | return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
133 |
134 |
135 | class MaskHeadSmallConv(nn.Module):
136 | """
137 | Simple convolutional head, using group norm.
138 |     Upsampling is done using an FPN approach
139 | """
140 |
141 | def __init__(self, dim, fpn_dims, context_dim):
142 | super().__init__()
143 |
144 | inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64]
145 | self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
146 | self.gn1 = torch.nn.GroupNorm(8, dim)
147 | self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
148 | self.gn2 = torch.nn.GroupNorm(8, inter_dims[1])
149 | self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
150 | self.gn3 = torch.nn.GroupNorm(8, inter_dims[2])
151 | self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
152 | self.gn4 = torch.nn.GroupNorm(8, inter_dims[3])
153 | self.gn5 = torch.nn.GroupNorm(8, inter_dims[4])
154 | self.conv_offset = torch.nn.Conv2d(inter_dims[3], 18, 1)#, bias=False)
155 | self.dcn = DeformConv(inter_dims[3],inter_dims[4], 3, padding=1)
156 |
157 | self.dim = dim
158 |
159 | self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
160 | self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
161 | self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
162 |
163 | for name, m in self.named_modules():
164 | if name == "conv_offset":
165 | nn.init.constant_(m.weight, 0)
166 | nn.init.constant_(m.bias, 0)
167 | else:
168 | if isinstance(m, nn.Conv2d):
169 | nn.init.kaiming_uniform_(m.weight, a=1)
170 | nn.init.constant_(m.bias, 0)
171 |
172 | def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]):
173 | x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1)
174 |
175 | x = self.lay1(x)
176 | x = self.gn1(x)
177 | x = F.relu(x)
178 | x = self.lay2(x)
179 | x = self.gn2(x)
180 | x = F.relu(x)
181 |
182 | cur_fpn = self.adapter1(fpns[0])
183 | if cur_fpn.size(0) != x.size(0):
184 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
185 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
186 | x = self.lay3(x)
187 | x = self.gn3(x)
188 | x = F.relu(x)
189 |
190 | cur_fpn = self.adapter2(fpns[1])
191 | if cur_fpn.size(0) != x.size(0):
192 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
193 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
194 | x = self.lay4(x)
195 | x = self.gn4(x)
196 | x = F.relu(x)
197 |
198 | cur_fpn = self.adapter3(fpns[2])
199 | if cur_fpn.size(0) != x.size(0):
200 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0))
201 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
202 | # dcn for the last layer
203 | offset = self.conv_offset(x)
204 | x = self.dcn(x,offset)
205 | x = self.gn5(x)
206 | x = F.relu(x)
207 | return x
208 |
209 |
210 | class MHAttentionMap(nn.Module):
211 | """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""
212 |
213 | def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True):
214 | super().__init__()
215 | self.num_heads = num_heads
216 | self.hidden_dim = hidden_dim
217 | self.dropout = nn.Dropout(dropout)
218 |
219 | self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
220 | self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
221 |
222 | nn.init.zeros_(self.k_linear.bias)
223 | nn.init.zeros_(self.q_linear.bias)
224 | nn.init.xavier_uniform_(self.k_linear.weight)
225 | nn.init.xavier_uniform_(self.q_linear.weight)
226 | self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5
227 |
228 | def forward(self, q, k, mask: Optional[Tensor] = None):
229 | q = self.q_linear(q)
230 | k = F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
231 | qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads)
232 | kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1])
233 | weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
234 |
235 | if mask is not None:
236 | weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
237 | weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
238 | weights = self.dropout(weights)
239 | return weights
240 |
241 |
242 | def dice_loss(inputs, targets, num_boxes):
243 | """
244 | Compute the DICE loss, similar to generalized IOU for masks
245 | Args:
246 | inputs: A float tensor of arbitrary shape.
247 | The predictions for each example.
248 | targets: A float tensor with the same shape as inputs. Stores the binary
249 | classification label for each element in inputs
250 | (0 for the negative class and 1 for the positive class).
251 | """
252 | inputs = inputs.sigmoid()
253 | inputs = inputs.flatten(1)
254 | numerator = 2 * (inputs * targets).sum(1)
255 | denominator = inputs.sum(-1) + targets.sum(-1)
256 | loss = 1 - (numerator + 1) / (denominator + 1)
257 | return loss.sum() / num_boxes
258 |
259 |
260 | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
261 | """
262 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
263 | Args:
264 | inputs: A float tensor of arbitrary shape.
265 | The predictions for each example.
266 | targets: A float tensor with the same shape as inputs. Stores the binary
267 | classification label for each element in inputs
268 | (0 for the negative class and 1 for the positive class).
269 | alpha: (optional) Weighting factor in range (0,1) to balance
270 |                positive vs negative examples. Default = 0.25.
271 | gamma: Exponent of the modulating factor (1 - p_t) to
272 | balance easy vs hard examples.
273 | Returns:
274 | Loss tensor
275 | """
276 | prob = inputs.sigmoid()
277 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
278 | p_t = prob * targets + (1 - prob) * (1 - targets)
279 | loss = ce_loss * ((1 - p_t) ** gamma)
280 |
281 | if alpha >= 0:
282 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
283 | loss = alpha_t * loss
284 |
285 | return loss.mean(1).sum() / num_boxes
286 |
287 |
288 | class PostProcessSegm(nn.Module):
289 | def __init__(self, threshold=0.5):
290 | super().__init__()
291 | self.threshold = threshold
292 |
293 | @torch.no_grad()
294 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
295 | assert len(orig_target_sizes) == len(max_target_sizes)
296 | max_h, max_w = max_target_sizes.max(0)[0].tolist()
297 | outputs_masks = outputs["pred_masks"].squeeze(2)
298 | outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
299 | outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()
300 |
301 | for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
302 | img_h, img_w = t[0], t[1]
303 | results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
304 | results[i]["masks"] = F.interpolate(
305 | results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
306 | ).byte()
307 |
308 | return results
309 |
310 |
311 |
--------------------------------------------------------------------------------
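
Both mask losses above operate on flattened per-instance masks and are normalized by the number of matched boxes. The sketch below calls them on random tensors purely to illustrate the expected shapes; it assumes the repository root is on PYTHONPATH and the dcn extension is built, since importing models.segmentation pulls in DeformConv.

    # Toy call of the mask losses; shapes are [num_matched_instances, H*W].
    import torch
    from models.segmentation import dice_loss, sigmoid_focal_loss

    torch.manual_seed(0)
    num_matched, hw = 4, 32 * 32
    mask_logits = torch.randn(num_matched, hw)                     # raw predictions (pre-sigmoid)
    mask_targets = torch.randint(0, 2, (num_matched, hw)).float()  # binary ground truth

    print(dice_loss(mask_logits, mask_targets, num_boxes=num_matched))
    print(sigmoid_focal_loss(mask_logits, mask_targets, num_boxes=num_matched))
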
/models/transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | VisTR Transformer class.
3 |
4 | Copy-paste from torch.nn.Transformer with modifications:
5 | * positional encodings are passed in MHattention
6 | * extra LN at the end of encoder is removed
7 | * decoder returns a stack of activations from all decoding layers
8 | Modified from DETR (https://github.com/facebookresearch/detr)
9 | """
10 | import copy
11 | from typing import Optional, List
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch import nn, Tensor
16 |
17 |
18 | class Transformer(nn.Module):
19 |
20 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
21 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
22 | activation="relu", normalize_before=False,
23 | return_intermediate_dec=False):
24 | super().__init__()
25 |
26 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
27 | dropout, activation, normalize_before)
28 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
29 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
30 |
31 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
32 | dropout, activation, normalize_before)
33 | decoder_norm = nn.LayerNorm(d_model)
34 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
35 | return_intermediate=return_intermediate_dec)
36 |
37 | self._reset_parameters()
38 |
39 | self.d_model = d_model
40 | self.nhead = nhead
41 |
42 | def _reset_parameters(self):
43 | for p in self.parameters():
44 | if p.dim() > 1:
45 | nn.init.xavier_uniform_(p)
46 |
47 | def forward(self, src, mask, query_embed, pos_embed):
48 | # flatten NxCxHxW to HWxNxC
49 | bs, c, h, w = src.shape
50 | src = src.flatten(2).permute(2, 0, 1)
51 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
52 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
53 | mask = mask.flatten(1)
54 |
55 | tgt = torch.zeros_like(query_embed)
56 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
57 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
58 | pos=pos_embed, query_pos=query_embed)
59 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
60 |
61 |
62 | class TransformerEncoder(nn.Module):
63 |
64 | def __init__(self, encoder_layer, num_layers, norm=None):
65 | super().__init__()
66 | self.layers = _get_clones(encoder_layer, num_layers)
67 | self.num_layers = num_layers
68 | self.norm = norm
69 |
70 | def forward(self, src,
71 | mask: Optional[Tensor] = None,
72 | src_key_padding_mask: Optional[Tensor] = None,
73 | pos: Optional[Tensor] = None):
74 | output = src
75 |
76 | for layer in self.layers:
77 | output = layer(output, src_mask=mask,
78 | src_key_padding_mask=src_key_padding_mask, pos=pos)
79 |
80 | if self.norm is not None:
81 | output = self.norm(output)
82 |
83 | return output
84 |
85 |
86 | class TransformerDecoder(nn.Module):
87 |
88 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
89 | super().__init__()
90 | self.layers = _get_clones(decoder_layer, num_layers)
91 | self.num_layers = num_layers
92 | self.norm = norm
93 | self.return_intermediate = return_intermediate
94 |
95 | def forward(self, tgt, memory,
96 | tgt_mask: Optional[Tensor] = None,
97 | memory_mask: Optional[Tensor] = None,
98 | tgt_key_padding_mask: Optional[Tensor] = None,
99 | memory_key_padding_mask: Optional[Tensor] = None,
100 | pos: Optional[Tensor] = None,
101 | query_pos: Optional[Tensor] = None):
102 | output = tgt
103 |
104 | intermediate = []
105 |
106 | for layer in self.layers:
107 | output = layer(output, memory, tgt_mask=tgt_mask,
108 | memory_mask=memory_mask,
109 | tgt_key_padding_mask=tgt_key_padding_mask,
110 | memory_key_padding_mask=memory_key_padding_mask,
111 | pos=pos, query_pos=query_pos)
112 | if self.return_intermediate:
113 | intermediate.append(self.norm(output))
114 |
115 | if self.norm is not None:
116 | output = self.norm(output)
117 | if self.return_intermediate:
118 | intermediate.pop()
119 | intermediate.append(output)
120 |
121 | if self.return_intermediate:
122 | return torch.stack(intermediate)
123 |
124 | return output
125 |
126 |
127 | class TransformerEncoderLayer(nn.Module):
128 |
129 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
130 | activation="relu", normalize_before=False):
131 | super().__init__()
132 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
133 | # Implementation of Feedforward model
134 | self.linear1 = nn.Linear(d_model, dim_feedforward)
135 | self.dropout = nn.Dropout(dropout)
136 | self.linear2 = nn.Linear(dim_feedforward, d_model)
137 |
138 | self.norm1 = nn.LayerNorm(d_model)
139 | self.norm2 = nn.LayerNorm(d_model)
140 | self.dropout1 = nn.Dropout(dropout)
141 | self.dropout2 = nn.Dropout(dropout)
142 |
143 | self.activation = _get_activation_fn(activation)
144 | self.normalize_before = normalize_before
145 |
146 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
147 | return tensor if pos is None else tensor + pos
148 |
149 | def forward_post(self,
150 | src,
151 | src_mask: Optional[Tensor] = None,
152 | src_key_padding_mask: Optional[Tensor] = None,
153 | pos: Optional[Tensor] = None):
154 | q = k = self.with_pos_embed(src, pos)
155 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
156 | key_padding_mask=src_key_padding_mask)[0]
157 | src = src + self.dropout1(src2)
158 | src = self.norm1(src)
159 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
160 | src = src + self.dropout2(src2)
161 | src = self.norm2(src)
162 | return src
163 |
164 | def forward_pre(self, src,
165 | src_mask: Optional[Tensor] = None,
166 | src_key_padding_mask: Optional[Tensor] = None,
167 | pos: Optional[Tensor] = None):
168 | src2 = self.norm1(src)
169 | q = k = self.with_pos_embed(src2, pos)
170 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
171 | key_padding_mask=src_key_padding_mask)[0]
172 | src = src + self.dropout1(src2)
173 | src2 = self.norm2(src)
174 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
175 | src = src + self.dropout2(src2)
176 | return src
177 |
178 | def forward(self, src,
179 | src_mask: Optional[Tensor] = None,
180 | src_key_padding_mask: Optional[Tensor] = None,
181 | pos: Optional[Tensor] = None):
182 | if self.normalize_before:
183 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
184 | return self.forward_post(src, src_mask, src_key_padding_mask, pos)
185 |
186 |
187 | class TransformerDecoderLayer(nn.Module):
188 |
189 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
190 | activation="relu", normalize_before=False):
191 | super().__init__()
192 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
193 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
194 | # Implementation of Feedforward model
195 | self.linear1 = nn.Linear(d_model, dim_feedforward)
196 | self.dropout = nn.Dropout(dropout)
197 | self.linear2 = nn.Linear(dim_feedforward, d_model)
198 |
199 | self.norm1 = nn.LayerNorm(d_model)
200 | self.norm2 = nn.LayerNorm(d_model)
201 | self.norm3 = nn.LayerNorm(d_model)
202 | self.dropout1 = nn.Dropout(dropout)
203 | self.dropout2 = nn.Dropout(dropout)
204 | self.dropout3 = nn.Dropout(dropout)
205 |
206 | self.activation = _get_activation_fn(activation)
207 | self.normalize_before = normalize_before
208 |
209 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
210 | return tensor if pos is None else tensor + pos
211 |
212 | def forward_post(self, tgt, memory,
213 | tgt_mask: Optional[Tensor] = None,
214 | memory_mask: Optional[Tensor] = None,
215 | tgt_key_padding_mask: Optional[Tensor] = None,
216 | memory_key_padding_mask: Optional[Tensor] = None,
217 | pos: Optional[Tensor] = None,
218 | query_pos: Optional[Tensor] = None):
219 | q = k = self.with_pos_embed(tgt, query_pos)
220 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
221 | key_padding_mask=tgt_key_padding_mask)[0]
222 | tgt = tgt + self.dropout1(tgt2)
223 | tgt = self.norm1(tgt)
224 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
225 | key=self.with_pos_embed(memory, pos),
226 | value=memory, attn_mask=memory_mask,
227 | key_padding_mask=memory_key_padding_mask)[0]
228 | tgt = tgt + self.dropout2(tgt2)
229 | tgt = self.norm2(tgt)
230 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
231 | tgt = tgt + self.dropout3(tgt2)
232 | tgt = self.norm3(tgt)
233 | return tgt
234 |
235 | def forward_pre(self, tgt, memory,
236 | tgt_mask: Optional[Tensor] = None,
237 | memory_mask: Optional[Tensor] = None,
238 | tgt_key_padding_mask: Optional[Tensor] = None,
239 | memory_key_padding_mask: Optional[Tensor] = None,
240 | pos: Optional[Tensor] = None,
241 | query_pos: Optional[Tensor] = None):
242 | tgt2 = self.norm1(tgt)
243 | q = k = self.with_pos_embed(tgt2, query_pos)
244 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
245 | key_padding_mask=tgt_key_padding_mask)[0]
246 | tgt = tgt + self.dropout1(tgt2)
247 | tgt2 = self.norm2(tgt)
248 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
249 | key=self.with_pos_embed(memory, pos),
250 | value=memory, attn_mask=memory_mask,
251 | key_padding_mask=memory_key_padding_mask)[0]
252 | tgt = tgt + self.dropout2(tgt2)
253 | tgt2 = self.norm3(tgt)
254 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
255 | tgt = tgt + self.dropout3(tgt2)
256 | return tgt
257 |
258 | def forward(self, tgt, memory,
259 | tgt_mask: Optional[Tensor] = None,
260 | memory_mask: Optional[Tensor] = None,
261 | tgt_key_padding_mask: Optional[Tensor] = None,
262 | memory_key_padding_mask: Optional[Tensor] = None,
263 | pos: Optional[Tensor] = None,
264 | query_pos: Optional[Tensor] = None):
265 | if self.normalize_before:
266 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
267 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
268 | return self.forward_post(tgt, memory, tgt_mask, memory_mask,
269 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
270 |
271 |
272 | def _get_clones(module, N):
273 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
274 |
275 |
276 | def build_transformer(args):
277 | return Transformer(
278 | d_model=args.hidden_dim,
279 | dropout=args.dropout,
280 | nhead=args.nheads,
281 | dim_feedforward=args.dim_feedforward,
282 | num_encoder_layers=args.enc_layers,
283 | num_decoder_layers=args.dec_layers,
284 | normalize_before=args.pre_norm,
285 | return_intermediate_dec=True,
286 | )
287 |
288 |
289 | def _get_activation_fn(activation):
290 | """Return an activation function given a string"""
291 | if activation == "relu":
292 | return F.relu
293 | if activation == "gelu":
294 | return F.gelu
295 | if activation == "glu":
296 | return F.glu
297 |     raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
298 |
--------------------------------------------------------------------------------
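
build_transformer sets return_intermediate_dec=True, so the decoder returns one activation stack per decoder layer and VisTR can attach auxiliary losses to each of them. The shape-level sketch below uses toy dimensions rather than the paper's settings, assumes the repository root is on PYTHONPATH, and runs the Transformer directly to show what the two return values look like.

    # Shape check for Transformer.forward: src is the flattened clip feature map,
    # mask marks padded positions, query_embed holds the learned instance queries.
    import torch
    from models.transformer import Transformer

    d_model, num_queries = 96, 12
    model = Transformer(d_model=d_model, nhead=8, num_encoder_layers=2,
                        num_decoder_layers=2, return_intermediate_dec=True)

    src = torch.randn(1, d_model, 4, 20)             # [bs, C, num_frames, H*W] after VisTR's flatten
    mask = torch.zeros(1, 4, 20, dtype=torch.bool)   # no padded positions in this toy clip
    query_embed = torch.randn(num_queries, d_model)
    pos = torch.randn_like(src)

    hs, memory = model(src, mask, query_embed, pos)
    print(hs.shape)       # [num_decoder_layers, bs, num_queries, d_model] -> [2, 1, 12, 96]
    print(memory.shape)   # [bs, d_model, num_frames, H*W] -> [1, 96, 4, 20]
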
/models/vistr.py:
--------------------------------------------------------------------------------
1 | """
2 | VisTR model and criterion classes.
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 | from util import box_ops
10 | from util.misc import (NestedTensor, nested_tensor_from_tensor_list,
11 | accuracy, get_world_size, interpolate,
12 | is_dist_avail_and_initialized)
13 |
14 | from .backbone import build_backbone
15 | from .matcher import build_matcher
16 | from .segmentation import (VisTRsegm, PostProcessSegm,
17 | dice_loss, sigmoid_focal_loss)
18 | from .transformer import build_transformer
19 |
20 |
21 | class VisTR(nn.Module):
22 | """ This is the VisTR module that performs video object detection """
23 | def __init__(self, backbone, transformer, num_classes, num_frames, num_queries, aux_loss=False):
24 | """ Initializes the model.
25 | Parameters:
26 | backbone: torch module of the backbone to be used. See backbone.py
27 | transformer: torch module of the transformer architecture. See transformer.py
28 | num_classes: number of object classes
29 |             num_queries: number of object queries, i.e. detection slots. This is the maximal number of objects
30 | VisTR can detect in a video. For ytvos, we recommend 10 queries for each frame,
31 | thus 360 queries for 36 frames.
32 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
33 | """
34 | super().__init__()
35 | self.num_queries = num_queries
36 | self.transformer = transformer
37 | hidden_dim = transformer.d_model
38 | self.hidden_dim = hidden_dim
39 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
40 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
41 | self.query_embed = nn.Embedding(num_queries, hidden_dim)
42 | self.num_frames = num_frames
43 | self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
44 | self.backbone = backbone
45 | self.aux_loss = aux_loss
46 |
47 | def forward(self, samples: NestedTensor):
48 | """ The forward expects a NestedTensor, which consists of:
49 | - samples.tensors: image sequences, of shape [num_frames x 3 x H x W]
50 | - samples.mask: a binary mask of shape [num_frames x H x W], containing 1 on padded pixels
51 |
52 | It returns a dict with the following elements:
53 | - "pred_logits": the classification logits (including no-object) for all queries.
54 | Shape= [batch_size x num_queries x (num_classes + 1)]
55 |                - "pred_boxes": The normalized box coordinates for all queries, represented as
56 |                                (center_x, center_y, width, height). These values are normalized in [0, 1],
57 | relative to the size of each individual image (disregarding possible padding).
58 | See PostProcess for information on how to retrieve the unnormalized bounding box.
59 |                - "aux_outputs": Optional, only returned when auxiliary losses are activated. It is a list of
60 |                                 dictionaries containing the two above keys for each decoder layer.
61 | """
62 | if not isinstance(samples, NestedTensor):
63 | samples = nested_tensor_from_tensor_list(samples)
64 |         # move the frames to the batch dimension for computational efficiency
65 | features, pos = self.backbone(samples)
66 | pos = pos[-1]
67 | src, mask = features[-1].decompose()
68 | src_proj = self.input_proj(src)
69 | n,c,h,w = src_proj.shape
70 | assert mask is not None
71 | src_proj = src_proj.reshape(n//self.num_frames, self.num_frames, c, h, w).permute(0,2,1,3,4).flatten(-2)
72 | mask = mask.reshape(n//self.num_frames, self.num_frames, h*w)
73 | pos = pos.permute(0,2,1,3,4).flatten(-2)
74 | hs = self.transformer(src_proj, mask, self.query_embed.weight, pos)[0]
75 |
76 | outputs_class = self.class_embed(hs)
77 | outputs_coord = self.bbox_embed(hs).sigmoid()
78 | out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
79 | if self.aux_loss:
80 | out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
81 | return out
82 |
83 | @torch.jit.unused
84 | def _set_aux_loss(self, outputs_class, outputs_coord):
85 | # this is a workaround to make torchscript happy, as torchscript
86 |         # doesn't support dictionaries with non-homogeneous values, such
87 | # as a dict having both a Tensor and a list.
88 | return [{'pred_logits': a, 'pred_boxes': b}
89 | for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
90 |
91 |
92 | class SetCriterion(nn.Module):
93 | """ This class computes the loss for VisTR.
94 | The process happens in two steps:
95 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
96 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
97 | """
98 | def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses):
99 | """ Create the criterion.
100 | Parameters:
101 | num_classes: number of object categories, omitting the special no-object category
102 | matcher: module able to compute a matching between targets and proposals
103 | weight_dict: dict containing as key the names of the losses and as values their relative weight.
104 | eos_coef: relative classification weight applied to the no-object category
105 | losses: list of all the losses to be applied. See get_loss for list of available losses.
106 | """
107 | super().__init__()
108 | self.num_classes = num_classes
109 | self.matcher = matcher
110 | self.weight_dict = weight_dict
111 | self.eos_coef = eos_coef
112 | self.losses = losses
113 | empty_weight = torch.ones(self.num_classes + 1)
114 | empty_weight[-1] = self.eos_coef
115 | self.register_buffer('empty_weight', empty_weight)
116 |
117 | def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
118 | """Classification loss (NLL)
119 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
120 | """
121 | assert 'pred_logits' in outputs
122 | src_logits = outputs['pred_logits']
123 |
124 | idx = self._get_src_permutation_idx(indices)
125 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
126 | target_classes = torch.full(src_logits.shape[:2], self.num_classes,
127 | dtype=torch.int64, device=src_logits.device)
128 | target_classes[idx] = target_classes_o
129 |
130 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
131 | losses = {'loss_ce': loss_ce}
132 |
133 | if log:
134 | # TODO this should probably be a separate loss, not hacked in this one here
135 | losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
136 | return losses
137 |
138 | @torch.no_grad()
139 | def loss_cardinality(self, outputs, targets, indices, num_boxes):
140 |         """ Compute the cardinality error, i.e. the absolute error in the number of predicted non-empty boxes
141 | This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
142 | """
143 | pred_logits = outputs['pred_logits']
144 | device = pred_logits.device
145 | tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device)
146 | # Count the number of predictions that are NOT "no-object" (which is the last class)
147 | card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
148 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
149 | losses = {'cardinality_error': card_err}
150 | return losses
151 |
152 | def loss_boxes(self, outputs, targets, indices, num_boxes):
153 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
154 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
155 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
156 | """
157 | assert 'pred_boxes' in outputs
158 | idx = self._get_src_permutation_idx(indices)
159 | src_boxes = outputs['pred_boxes'][idx]
160 | target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
161 |
162 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
163 |
164 | losses = {}
165 | losses['loss_bbox'] = loss_bbox.sum() / num_boxes
166 |
167 | loss_giou = 1 - torch.diag(box_ops.generalized_box_iou(
168 | box_ops.box_cxcywh_to_xyxy(src_boxes),
169 | box_ops.box_cxcywh_to_xyxy(target_boxes)))
170 | losses['loss_giou'] = loss_giou.sum() / num_boxes
171 | return losses
172 |
173 | def loss_masks(self, outputs, targets, indices, num_boxes):
174 | """Compute the losses related to the masks: the focal loss and the dice loss.
175 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w]
176 | """
177 | assert "pred_masks" in outputs
178 |
179 | src_idx = self._get_src_permutation_idx(indices)
180 | tgt_idx = self._get_tgt_permutation_idx(indices)
181 |
182 | src_masks = outputs["pred_masks"]
183 | # TODO use valid to mask invalid areas due to padding in loss
184 | target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets], split=False).decompose()
185 | target_masks = target_masks.to(src_masks)
186 | src_masks = src_masks[src_idx]
187 | # upsample predictions to the target size
188 | try:
189 | src_masks = interpolate(src_masks[:, None], size=target_masks.shape[-2:],
190 | mode="bilinear", align_corners=False)
191 | src_masks = src_masks[:, 0].flatten(1)
192 | target_masks = target_masks[tgt_idx].flatten(1)
193 | except:
194 | src_masks = src_masks.flatten(1)
195 | target_masks = src_masks.clone()
196 | losses = {
197 | "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes),
198 | "loss_dice": dice_loss(src_masks, target_masks, num_boxes),
199 | }
200 | return losses
201 |
202 | def _get_src_permutation_idx(self, indices):
203 | # permute predictions following indices
204 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
205 | src_idx = torch.cat([src for (src, _) in indices])
206 | return batch_idx, src_idx
207 |
208 | def _get_tgt_permutation_idx(self, indices):
209 | # permute targets following indices
210 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
211 | tgt_idx = torch.cat([tgt for (_, tgt) in indices])
212 | return batch_idx, tgt_idx
213 |
214 | def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
215 | loss_map = {
216 | 'labels': self.loss_labels,
217 | 'cardinality': self.loss_cardinality,
218 | 'boxes': self.loss_boxes,
219 | 'masks': self.loss_masks
220 | }
221 | assert loss in loss_map, f'do you really want to compute {loss} loss?'
222 | return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
223 |
224 | def forward(self, outputs, targets):
225 | """ This performs the loss computation.
226 | Parameters:
227 | outputs: dict of tensors, see the output specification of the model for the format
228 | targets: list of dicts, such that len(targets) == batch_size.
229 |                       The expected keys in each dict depend on the losses applied; see each loss's doc
230 | """
231 | outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
232 | # Retrieve the matching between the outputs of the last layer and the targets
233 | indices = self.matcher(outputs_without_aux, targets)
234 |
235 |         # Compute the average number of target boxes across all nodes, for normalization purposes
236 | num_boxes = sum(len(t["labels"]) for t in targets)
237 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
238 | if is_dist_avail_and_initialized():
239 | torch.distributed.all_reduce(num_boxes)
240 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
241 |
242 | # Compute all the requested losses
243 | losses = {}
244 | for loss in self.losses:
245 | losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
246 |
247 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
248 | if 'aux_outputs' in outputs:
249 | for i, aux_outputs in enumerate(outputs['aux_outputs']):
250 | indices = self.matcher(aux_outputs, targets)
251 | for loss in self.losses:
252 | if loss == 'masks':
253 |                         # Intermediate mask losses are too costly to compute, so we ignore them.
254 | continue
255 | kwargs = {}
256 | if loss == 'labels':
257 | # Logging is enabled only for the last layer
258 | kwargs = {'log': False}
259 | l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
260 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
261 | losses.update(l_dict)
262 |
263 | return losses
264 |
265 |
266 | class PostProcess(nn.Module):
267 | """ This module converts the model's output into the format expected by the coco api"""
268 | @torch.no_grad()
269 | def forward(self, outputs, target_sizes):
270 | """ Perform the computation
271 | Parameters:
272 | outputs: raw outputs of the model
273 |                           target_sizes: tensor of dimension [batch_size x 2] containing the size of each image of the batch
274 |                                         For evaluation, this must be the original image size (before any data augmentation)
275 |                                         For visualization, this should be the image size after data augmentation, but before padding
276 | """
277 | out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
278 |
279 | assert len(out_logits) == len(target_sizes)
280 | assert target_sizes.shape[1] == 2
281 |
282 | prob = F.softmax(out_logits, -1)
283 | scores, labels = prob[..., :-1].max(-1)
284 |
285 | # convert to [x0, y0, x1, y1] format
286 | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
287 | # and from relative [0, 1] to absolute [0, height] coordinates
288 | img_h, img_w = target_sizes.unbind(1)
289 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
290 | boxes = boxes * scale_fct[:, None, :]
291 |
292 | results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
293 |
294 | return results
295 |
296 |
297 | class MLP(nn.Module):
298 | """ Very simple multi-layer perceptron (also called FFN)"""
299 |
300 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
301 | super().__init__()
302 | self.num_layers = num_layers
303 | h = [hidden_dim] * (num_layers - 1)
304 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
305 |
306 | def forward(self, x):
307 | for i, layer in enumerate(self.layers):
308 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
309 | return x
310 |
311 |
312 | def build(args):
313 | if args.dataset_file == "ytvos":
314 | num_classes = 41
315 | device = torch.device(args.device)
316 |
317 | backbone = build_backbone(args)
318 |
319 | transformer = build_transformer(args)
320 |
321 | model = VisTR(
322 | backbone,
323 | transformer,
324 | num_classes=num_classes,
325 | num_frames=args.num_frames,
326 | num_queries=args.num_queries,
327 | aux_loss=args.aux_loss,
328 | )
329 | if args.masks:
330 | model = VisTRsegm(model)
331 | matcher = build_matcher(args)
332 | weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
333 | weight_dict['loss_giou'] = args.giou_loss_coef
334 | if args.masks:
335 | weight_dict["loss_mask"] = args.mask_loss_coef
336 | weight_dict["loss_dice"] = args.dice_loss_coef
337 | # TODO this is a hack
338 | if args.aux_loss:
339 | aux_weight_dict = {}
340 | for i in range(args.dec_layers - 1):
341 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
342 | weight_dict.update(aux_weight_dict)
343 |
344 | losses = ['labels', 'boxes', 'cardinality']
345 | if args.masks:
346 | losses += ["masks"]
347 | criterion = SetCriterion(num_classes, matcher=matcher, weight_dict=weight_dict,
348 | eos_coef=args.eos_coef, losses=losses)
349 | criterion.to(device)
350 | postprocessors = {'bbox': PostProcess()}
351 | if args.masks:
352 | postprocessors['segm'] = PostProcessSegm()
353 | return model, criterion, postprocessors
354 |
--------------------------------------------------------------------------------
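
PostProcess turns per-query logits and normalized (cx, cy, w, h) boxes into scored xyxy boxes in absolute pixel coordinates of the original frames. Below is a toy call with random tensors (41 YouTube-VIS classes plus the no-object slot), again assuming the repository root is importable and the dcn extension is built, since models.vistr imports the segmentation head.

    # Toy PostProcess call: 5 queries, 41 classes + 1 no-object slot.
    import torch
    from models.vistr import PostProcess

    bs, num_queries, num_classes = 1, 5, 41
    outputs = {
        "pred_logits": torch.randn(bs, num_queries, num_classes + 1),
        "pred_boxes": torch.rand(bs, num_queries, 4),  # normalized (cx, cy, w, h)
    }
    target_sizes = torch.tensor([[720, 1280]])         # (height, width) of the original frame

    results = PostProcess()(outputs, target_sizes)
    print(results[0]["scores"].shape, results[0]["boxes"].shape)   # torch.Size([5]) torch.Size([5, 4])
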
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YuqingWang1029/VisTR/445c9e4e787a1fb3c959d7e7bb6ecf809bdac155/util/__init__.py
--------------------------------------------------------------------------------
/util/box_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for bounding box manipulation and GIoU.
3 | """
4 | import torch
5 | from torchvision.ops.boxes import box_area
6 |
7 | def clip_iou(boxes1,boxes2):
8 | area1 = box_area(boxes1)
9 | area2 = box_area(boxes2)
10 | lt = torch.max(boxes1[:, :2], boxes2[:, :2])
11 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
12 | wh = (rb-lt).clamp(min=0)
13 | inter = wh[:,0]*wh[:,1]
14 | union = area1 + area2 - inter
15 | iou = (inter+1e-6) / (union+1e-6)
16 | # generalized version
17 | # iou=iou-(inter-union)/inter
18 | return iou
19 |
20 | def multi_iou(boxes1, boxes2):
21 | lt = torch.max(boxes1[...,:2], boxes2[...,:2])
22 | rb = torch.min(boxes1[...,2:], boxes2[...,2:])
23 | wh = (rb-lt).clamp(min=0)
24 | wh_1 = boxes1[...,2:] - boxes1[...,:2]
25 | wh_2 = boxes2[...,2:] - boxes2[...,:2]
26 | inter = wh[...,0] * wh[...,1]
27 | union = wh_1[...,0] * wh_1[...,1] + wh_2[...,0] * wh_2[...,1] - inter
28 | iou = (inter+1e-6) / (union+1e-6)
29 | return iou
30 |
31 | def box_cxcywh_to_xyxy(x):
32 | x_c, y_c, w, h = x.unbind(-1)
33 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
34 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
35 | return torch.stack(b, dim=-1)
36 |
37 |
38 | def box_xyxy_to_cxcywh(x):
39 | x0, y0, x1, y1 = x.unbind(-1)
40 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
41 | (x1 - x0), (y1 - y0)]
42 | return torch.stack(b, dim=-1)
43 |
44 |
45 | # modified from torchvision to also return the union
46 | def box_iou(boxes1, boxes2):
47 | area1 = box_area(boxes1)
48 | area2 = box_area(boxes2)
49 |
50 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
51 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
52 |
53 | wh = (rb - lt).clamp(min=0) # [N,M,2]
54 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
55 |
56 | union = area1[:, None] + area2 - inter
57 |
58 | iou = (inter+1e-6) / (union+1e-6)
59 | return iou, union
60 |
61 |
62 | def generalized_box_iou(boxes1, boxes2):
63 | """
64 | Generalized IoU from https://giou.stanford.edu/
65 |
66 | The boxes should be in [x0, y0, x1, y1] format
67 |
68 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
69 | and M = len(boxes2)
70 | """
71 | # degenerate boxes gives inf / nan results
72 | # so do an early check
73 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
74 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
75 | iou, union = box_iou(boxes1, boxes2)
76 |
77 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
78 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
79 |
80 | wh = (rb - lt).clamp(min=0) # [N,M,2]
81 | area = wh[:, :, 0] * wh[:, :, 1]
82 |
83 | return iou - ((area - union)+1e-6) / (area+1e-6)
84 |
85 |
86 | def masks_to_boxes(masks):
87 | """Compute the bounding boxes around the provided masks
88 |
89 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
90 |
91 | Returns a [N, 4] tensors, with the boxes in xyxy format
92 | """
93 | if masks.numel() == 0:
94 | return torch.zeros((0, 4), device=masks.device)
95 |
96 | h, w = masks.shape[-2:]
97 |
98 | y = torch.arange(0, h, dtype=torch.float)
99 | x = torch.arange(0, w, dtype=torch.float)
100 | y, x = torch.meshgrid(y, x)
101 |
102 | x_mask = (masks * x.unsqueeze(0))
103 | x_max = x_mask.flatten(1).max(-1)[0]
104 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
105 |
106 | y_mask = (masks * y.unsqueeze(0))
107 | y_max = y_mask.flatten(1).max(-1)[0]
108 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
109 |
110 | return torch.stack([x_min, y_min, x_max, y_max], 1)
111 |
--------------------------------------------------------------------------------
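
generalized_box_iou pairs every box in boxes1 with every box in boxes2 and subtracts from the IoU the fraction of the smallest enclosing box not covered by the union, so non-overlapping boxes receive a negative score instead of a flat zero. A small worked example using only the helpers defined above (run from the repository root):

    # GIoU of an exact match versus a disjoint box, in normalized coordinates.
    import torch
    from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

    preds = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.2, 0.2],
                                             [0.2, 0.2, 0.2, 0.2]]))
    gts = box_cxcywh_to_xyxy(torch.tensor([[0.5, 0.5, 0.2, 0.2]]))

    print(generalized_box_iou(preds, gts))
    # ~[[1.00], [-0.68]]: the identical box scores 1, the disjoint one is pushed negative.
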
/util/detr_weights_to_vistr.py:
--------------------------------------------------------------------------------
1 | '''
2 | convert detr pretrained weights to vistr format
3 | '''
4 | import sys
5 | import torch
6 | import collections
7 |
8 | if __name__ == "__main__":
9 | input_path = sys.argv[1]
10 | detr_weights = torch.load(input_path)['model']
11 | vistr_weights = collections.OrderedDict()
12 |
13 | for k,v in detr_weights.items():
14 | if k.startswith("detr"):
15 | k = k.replace("detr","vistr")
16 | vistr_weights[k]=v
17 | res = {"model":vistr_weights}
18 |
19 | torch.save(res,sys.argv[2])
20 |
21 |
22 |
--------------------------------------------------------------------------------
/util/misc.py:
--------------------------------------------------------------------------------
1 | """
2 | Misc functions, including distributed helpers.
3 |
4 | Mostly copy-paste from torchvision references.
5 | """
6 | import os
7 | import subprocess
8 | import time
9 | from collections import defaultdict, deque
10 | import datetime
11 | import pickle
12 | from typing import Optional, List
13 |
14 | import torch
15 | import torch.distributed as dist
16 | from torch import Tensor
17 |
18 | # needed due to empty tensor bug in pytorch and torchvision 0.5
19 | import torchvision
20 | if tuple(int(v) for v in torchvision.__version__.split('.')[:2]) < (0, 7):
21 | from torchvision.ops import _new_empty_tensor
22 | from torchvision.ops.misc import _output_size
23 |
24 |
25 | class SmoothedValue(object):
26 | """Track a series of values and provide access to smoothed values over a
27 | window or the global series average.
28 | """
29 |
30 | def __init__(self, window_size=20, fmt=None):
31 | if fmt is None:
32 | fmt = "{median:.4f} ({global_avg:.4f})"
33 | self.deque = deque(maxlen=window_size)
34 | self.total = 0.0
35 | self.count = 0
36 | self.fmt = fmt
37 |
38 | def update(self, value, n=1):
39 | self.deque.append(value)
40 | self.count += n
41 | self.total += value * n
42 |
43 | def synchronize_between_processes(self):
44 | """
45 | Warning: does not synchronize the deque!
46 | """
47 | if not is_dist_avail_and_initialized():
48 | return
49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
50 | dist.barrier()
51 | dist.all_reduce(t)
52 | t = t.tolist()
53 | self.count = int(t[0])
54 | self.total = t[1]
55 |
56 | @property
57 | def median(self):
58 | d = torch.tensor(list(self.deque))
59 | return d.median().item()
60 |
61 | @property
62 | def avg(self):
63 | d = torch.tensor(list(self.deque), dtype=torch.float32)
64 | return d.mean().item()
65 |
66 | @property
67 | def global_avg(self):
68 | return self.total / self.count
69 |
70 | @property
71 | def max(self):
72 | return max(self.deque)
73 |
74 | @property
75 | def value(self):
76 | return self.deque[-1]
77 |
78 | def __str__(self):
79 | return self.fmt.format(
80 | median=self.median,
81 | avg=self.avg,
82 | global_avg=self.global_avg,
83 | max=self.max,
84 | value=self.value)
85 |
86 |
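# A minimal usage sketch for SmoothedValue (the loss values are illustrative):
def _example_smoothed_value():
    meter = SmoothedValue(window_size=20)
    for loss in (0.90, 0.72, 0.65):
        meter.update(loss)
    # str(meter) formats the windowed median and global average, e.g. "0.7200 (0.7567)";
    # global_avg is total / count over the whole run, not just the window.
    return str(meter), meter.global_avg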
87 | def all_gather(data):
88 | """
89 | Run all_gather on arbitrary picklable data (not necessarily tensors)
90 | Args:
91 | data: any picklable object
92 | Returns:
93 | list[data]: list of data gathered from each rank
94 | """
95 | world_size = get_world_size()
96 | if world_size == 1:
97 | return [data]
98 |
99 | # serialized to a Tensor
100 | buffer = pickle.dumps(data)
101 | storage = torch.ByteStorage.from_buffer(buffer)
102 | tensor = torch.ByteTensor(storage).to("cuda")
103 |
104 | # obtain Tensor size of each rank
105 | local_size = torch.tensor([tensor.numel()], device="cuda")
106 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
107 | dist.all_gather(size_list, local_size)
108 | size_list = [int(size.item()) for size in size_list]
109 | max_size = max(size_list)
110 |
111 | # receiving Tensor from all ranks
112 | # we pad the tensor because torch all_gather does not support
113 | # gathering tensors of different shapes
114 | tensor_list = []
115 | for _ in size_list:
116 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
117 | if local_size != max_size:
118 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
119 | tensor = torch.cat((tensor, padding), dim=0)
120 | dist.all_gather(tensor_list, tensor)
121 |
122 | data_list = []
123 | for size, tensor in zip(size_list, tensor_list):
124 | buffer = tensor.cpu().numpy().tobytes()[:size]
125 | data_list.append(pickle.loads(buffer))
126 |
127 | return data_list
128 |
129 |
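# A minimal usage sketch for all_gather (the stats dict is illustrative); with a single
# process it simply returns [data], under torch.distributed one entry per rank is returned:
def _example_all_gather():
    stats = {"num_samples": 128, "loss": 0.42}
    return all_gather(stats)  # list with world_size entries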
130 | def reduce_dict(input_dict, average=True):
131 | """
132 | Args:
133 | input_dict (dict): all the values will be reduced
134 | average (bool): whether to do average or sum
135 | Reduce the values in the dictionary from all processes so that all processes
136 | have the averaged results. Returns a dict with the same fields as
137 | input_dict, after reduction.
138 | """
139 | world_size = get_world_size()
140 | if world_size < 2:
141 | return input_dict
142 | with torch.no_grad():
143 | names = []
144 | values = []
145 | # sort the keys so that they are consistent across processes
146 | for k in sorted(input_dict.keys()):
147 | names.append(k)
148 | values.append(input_dict[k])
149 | values = torch.stack(values, dim=0)
150 | dist.all_reduce(values)
151 | if average:
152 | values /= world_size
153 | reduced_dict = {k: v for k, v in zip(names, values)}
154 | return reduced_dict
155 |
156 |
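# A minimal usage sketch for reduce_dict (loss values are illustrative): average a loss
# dict across ranks for logging; on a single process the dict is returned unchanged.
def _example_reduce_losses():
    loss_dict = {"loss_ce": torch.tensor(0.8), "loss_bbox": torch.tensor(0.3)}
    return reduce_dict(loss_dict, average=True)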
157 | class MetricLogger(object):
158 | def __init__(self, delimiter="\t"):
159 | self.meters = defaultdict(SmoothedValue)
160 | self.delimiter = delimiter
161 |
162 | def update(self, **kwargs):
163 | for k, v in kwargs.items():
164 | if isinstance(v, torch.Tensor):
165 | v = v.item()
166 | assert isinstance(v, (float, int))
167 | self.meters[k].update(v)
168 |
169 | def __getattr__(self, attr):
170 | if attr in self.meters:
171 | return self.meters[attr]
172 | if attr in self.__dict__:
173 | return self.__dict__[attr]
174 | raise AttributeError("'{}' object has no attribute '{}'".format(
175 | type(self).__name__, attr))
176 |
177 | def __str__(self):
178 | loss_str = []
179 | for name, meter in self.meters.items():
180 | loss_str.append(
181 | "{}: {}".format(name, str(meter))
182 | )
183 | return self.delimiter.join(loss_str)
184 |
185 | def synchronize_between_processes(self):
186 | for meter in self.meters.values():
187 | meter.synchronize_between_processes()
188 |
189 | def add_meter(self, name, meter):
190 | self.meters[name] = meter
191 |
192 | def log_every(self, iterable, print_freq, header=None):
193 | i = 0
194 | if not header:
195 | header = ''
196 | start_time = time.time()
197 | end = time.time()
198 | iter_time = SmoothedValue(fmt='{avg:.4f}')
199 | data_time = SmoothedValue(fmt='{avg:.4f}')
200 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
201 | if torch.cuda.is_available():
202 | log_msg = self.delimiter.join([
203 | header,
204 | '[{0' + space_fmt + '}/{1}]',
205 | 'eta: {eta}',
206 | '{meters}',
207 | 'time: {time}',
208 | 'data: {data}',
209 | 'max mem: {memory:.0f}'
210 | ])
211 | else:
212 | log_msg = self.delimiter.join([
213 | header,
214 | '[{0' + space_fmt + '}/{1}]',
215 | 'eta: {eta}',
216 | '{meters}',
217 | 'time: {time}',
218 | 'data: {data}'
219 | ])
220 | MB = 1024.0 * 1024.0
221 | for obj in iterable:
222 | data_time.update(time.time() - end)
223 | yield obj
224 | iter_time.update(time.time() - end)
225 | if i % print_freq == 0 or i == len(iterable) - 1:
226 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
227 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
228 | if torch.cuda.is_available():
229 | print(log_msg.format(
230 | i, len(iterable), eta=eta_string,
231 | meters=str(self),
232 | time=str(iter_time), data=str(data_time),
233 | memory=torch.cuda.max_memory_allocated() / MB))
234 | else:
235 | print(log_msg.format(
236 | i, len(iterable), eta=eta_string,
237 | meters=str(self),
238 | time=str(iter_time), data=str(data_time)))
239 | i += 1
240 | end = time.time()
241 | total_time = time.time() - start_time
242 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
243 | print('{} Total time: {} ({:.4f} s / it)'.format(
244 | header, total_time_str, total_time / len(iterable)))
245 |
246 |
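# A minimal usage sketch for MetricLogger (the loop and the logged values are illustrative):
def _example_metric_logger():
    logger = MetricLogger(delimiter="  ")
    for _ in logger.log_every(range(10), print_freq=5, header="Example:"):
        logger.update(loss=0.5, lr=1e-4)  # a real loop would pass the current loss and lr
    print("Averaged stats:", logger)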
247 | def get_sha():
248 | cwd = os.path.dirname(os.path.abspath(__file__))
249 |
250 | def _run(command):
251 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
252 | sha = 'N/A'
253 | diff = "clean"
254 | branch = 'N/A'
255 | try:
256 | sha = _run(['git', 'rev-parse', 'HEAD'])
257 | subprocess.check_output(['git', 'diff'], cwd=cwd)
258 | diff = _run(['git', 'diff-index', 'HEAD'])
259 |         diff = "has uncommitted changes" if diff else "clean"
260 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
261 | except Exception:
262 | pass
263 | message = f"sha: {sha}, status: {diff}, branch: {branch}"
264 | return message
265 |
266 |
267 | def collate_fn(batch):
268 | batch = list(zip(*batch))
269 | batch[0] = nested_tensor_from_tensor_list(batch[0])
270 | return tuple(batch)
271 |
272 |
273 | def _max_by_axis(the_list):
274 | # type: (List[List[int]]) -> List[int]
275 | maxes = the_list[0]
276 | for sublist in the_list[1:]:
277 | for index, item in enumerate(sublist):
278 | maxes[index] = max(maxes[index], item)
279 | return maxes
280 |
281 |
282 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor], split=True):
283 | # TODO make this more general
284 |     if split:
285 |         tensor_list = [tensor.split(3, dim=0) for tensor in tensor_list]  # split each clip into per-frame (3, H, W) chunks
286 |         tensor_list = [item for sublist in tensor_list for item in sublist]  # flatten: one entry per frame
287 | if tensor_list[0].ndim == 3:
288 | # TODO make it support different-sized images
289 | max_size = _max_by_axis([list(img.shape) for img in tensor_list])
290 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
291 | batch_shape = [len(tensor_list)] + max_size
292 | b, c, h, w = batch_shape
293 | dtype = tensor_list[0].dtype
294 | device = tensor_list[0].device
295 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
296 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
297 | for img, pad_img, m in zip(tensor_list, tensor, mask):
298 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
299 | m[: img.shape[1], :img.shape[2]] = False
300 | else:
301 | raise ValueError('not supported')
302 | return NestedTensor(tensor, mask)
303 |
304 |
305 | class NestedTensor(object):
306 | def __init__(self, tensors, mask: Optional[Tensor]):
307 | self.tensors = tensors
308 | self.mask = mask
309 |
310 | def to(self, device):
311 | # type: (Device) -> NestedTensor # noqa
312 | cast_tensor = self.tensors.to(device)
313 | mask = self.mask
314 | if mask is not None:
315 | assert mask is not None
316 | cast_mask = mask.to(device)
317 | else:
318 | cast_mask = None
319 | return NestedTensor(cast_tensor, cast_mask)
320 |
321 | def decompose(self):
322 | return self.tensors, self.mask
323 |
324 | def __repr__(self):
325 | return str(self.tensors)
326 |
327 |
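# A minimal usage sketch for nested_tensor_from_tensor_list / NestedTensor (shapes are
# illustrative): each clip stacks its frames along dim 0 as (3 * num_frames, H, W).
def _example_nested_tensor():
    clips = [torch.rand(3 * 2, 300, 400), torch.rand(3 * 2, 280, 360)]
    nested = nested_tensor_from_tensor_list(clips)  # split=True -> one entry per frame
    frames, mask = nested.decompose()
    # frames: (4, 3, 300, 400) padded to the largest H, W; mask: (4, 300, 400), True marks padding
    return frames.shape, mask.shape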
328 | def setup_for_distributed(is_master):
329 | """
330 |     This function disables printing when not in the master process
331 | """
332 | import builtins as __builtin__
333 | builtin_print = __builtin__.print
334 |
335 | def print(*args, **kwargs):
336 | force = kwargs.pop('force', False)
337 | if is_master or force:
338 | builtin_print(*args, **kwargs)
339 |
340 | __builtin__.print = print
341 |
342 |
343 | def is_dist_avail_and_initialized():
344 | if not dist.is_available():
345 | return False
346 | if not dist.is_initialized():
347 | return False
348 | return True
349 |
350 |
351 | def get_world_size():
352 | if not is_dist_avail_and_initialized():
353 | return 1
354 | return dist.get_world_size()
355 |
356 |
357 | def get_rank():
358 | if not is_dist_avail_and_initialized():
359 | return 0
360 | return dist.get_rank()
361 |
362 |
363 | def is_main_process():
364 | return get_rank() == 0
365 |
366 |
367 | def save_on_master(*args, **kwargs):
368 | if is_main_process():
369 | torch.save(*args, **kwargs, _use_new_zipfile_serialization=False)
370 |
371 |
372 | def init_distributed_mode(args):
373 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
374 | args.rank = int(os.environ["RANK"])
375 | args.world_size = int(os.environ['WORLD_SIZE'])
376 | args.gpu = int(os.environ['LOCAL_RANK'])
377 | elif 'SLURM_PROCID' in os.environ:
378 | args.rank = int(os.environ['SLURM_PROCID'])
379 | args.gpu = args.rank % torch.cuda.device_count()
380 | else:
381 | print('Not using distributed mode')
382 | args.distributed = False
383 | return
384 |
385 | args.distributed = True
386 |
387 | torch.cuda.set_device(args.gpu)
388 | args.dist_backend = 'nccl'
389 | print('| distributed init (rank {}): {}'.format(
390 | args.rank, args.dist_url), flush=True)
391 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
392 | world_size=args.world_size, rank=args.rank)
393 | torch.distributed.barrier()
394 | setup_for_distributed(args.rank == 0)
395 |
396 |
397 | @torch.no_grad()
398 | def accuracy(output, target, topk=(1,)):
399 | """Computes the precision@k for the specified values of k"""
400 | if target.numel() == 0:
401 | return [torch.zeros([], device=output.device)]
402 | maxk = max(topk)
403 | batch_size = target.size(0)
404 |
405 | _, pred = output.topk(maxk, 1, True, True)
406 | pred = pred.t()
407 | correct = pred.eq(target.view(1, -1).expand_as(pred))
408 |
409 | res = []
410 | for k in topk:
411 | correct_k = correct[:k].view(-1).float().sum(0)
412 | res.append(correct_k.mul_(100.0 / batch_size))
413 | return res
414 |
415 |
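# A minimal usage sketch for accuracy (logits and targets are illustrative):
def _example_accuracy():
    logits = torch.tensor([[2.0, 0.5, 0.1],
                           [0.2, 1.5, 0.3]])
    targets = torch.tensor([0, 2])
    top1, top2 = accuracy(logits, targets, topk=(1, 2))
    return top1.item(), top2.item()  # 50.0 and 100.0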
416 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
417 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
418 | """
419 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
420 | This will eventually be supported natively by PyTorch, and this
421 |     function can go away.
422 | """
423 |     if tuple(int(v) for v in torchvision.__version__.split('.')[:2]) < (0, 7):  # parse major/minor so versions like "0.10" compare correctly
424 | if input.numel() > 0:
425 | return torch.nn.functional.interpolate(
426 | input, size, scale_factor, mode, align_corners
427 | )
428 |
429 | output_shape = _output_size(2, input, size, scale_factor)
430 | output_shape = list(input.shape[:-2]) + list(output_shape)
431 | return _new_empty_tensor(input, output_shape)
432 | else:
433 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
434 |
--------------------------------------------------------------------------------
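The interpolate wrapper at the end of util/misc.py exists only to tolerate empty batches on torchvision < 0.7; on newer stacks it defers to torchvision. A minimal sketch of the case it guards against (shapes are illustrative):

    import torch
    from util.misc import interpolate

    empty_masks = torch.rand(0, 1, 64, 64)   # e.g. a frame with no predicted instances
    resized = interpolate(empty_masks, size=(128, 128), mode="bilinear", align_corners=False)
    print(resized.shape)                     # torch.Size([0, 1, 128, 128])
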
/util/plot_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Plotting utilities to visualize training logs.
3 | """
4 | import numpy as np
5 | import torch
6 | import pandas as pd
7 | import seaborn as sns
8 | import matplotlib.pyplot as plt
8 |
9 | from pathlib import Path, PurePath
10 |
11 |
12 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'):
13 | '''
14 | Function to plot specific fields from training log(s). Plots both training and test results.
15 |
16 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file
17 | - fields = which results to plot from each log file - plots both training and test for each field.
18 |               - ewm_col = optional, which column to use for exponentially weighted smoothing of the plots.
19 |               - log_name = optional, name of the log file if different from the default 'log.txt'.
20 |
21 | :: Outputs - matplotlib plots of results in fields, color coded for each log file.
22 | - solid lines are training results, dashed lines are test results.
23 |
24 | '''
25 | func_name = "plot_utils.py::plot_logs"
26 |
27 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path,
28 | # convert single Path to list to avoid 'not iterable' error
29 |
30 | if not isinstance(logs, list):
31 | if isinstance(logs, PurePath):
32 | logs = [logs]
33 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].")
34 | else:
35 |             raise ValueError(f"{func_name} - invalid argument for logs parameter.\n"
36 |                              f"Expect list[Path] or single Path obj, received {type(logs)}")
37 |
38 | # verify valid dir(s) and that every item in list is Path object
39 | for i, dir in enumerate(logs):
40 | if not isinstance(dir, PurePath):
41 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}")
42 | if dir.exists():
43 | continue
44 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}")
45 |
46 | # load log file(s) and plot
47 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs]
48 |
49 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5))
50 |
51 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))):
52 | for j, field in enumerate(fields):
53 | if field == 'mAP':
54 |                 coco_eval = pd.DataFrame(np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean()
55 | axs[j].plot(coco_eval, c=color)
56 | else:
57 | df.interpolate().ewm(com=ewm_col).mean().plot(
58 | y=[f'train_{field}', f'test_{field}'],
59 | ax=axs[j],
60 | color=[color] * 2,
61 | style=['-', '--']
62 | )
63 | for ax, field in zip(axs, fields):
64 | ax.legend([Path(p).name for p in logs])
65 | ax.set_title(field)
66 |
67 |
68 | def plot_precision_recall(files, naming_scheme='iter'):
69 | if naming_scheme == 'exp_id':
70 | # name becomes exp_id
71 | names = [f.parts[-3] for f in files]
72 | elif naming_scheme == 'iter':
73 | names = [f.stem for f in files]
74 | else:
75 | raise ValueError(f'not supported {naming_scheme}')
76 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5))
77 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names):
78 | data = torch.load(f)
79 | # precision is n_iou, n_points, n_cat, n_area, max_det
80 | precision = data['precision']
81 | recall = data['params'].recThrs
82 | scores = data['scores']
83 | # take precision for all classes, all areas and 100 detections
84 | precision = precision[0, :, :, 0, -1].mean(1)
85 | scores = scores[0, :, :, 0, -1].mean(1)
86 | prec = precision.mean()
87 | rec = data['recall'][0, :, 0, -1].mean()
88 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' +
89 | f'score={scores.mean():0.3f}, ' +
90 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}'
91 | )
92 | axs[0].plot(recall, precision, c=color)
93 | axs[1].plot(recall, scores, c=color)
94 |
95 | axs[0].set_title('Precision / Recall')
96 | axs[0].legend(names)
97 | axs[1].set_title('Scores / Recall')
98 | axs[1].legend(names)
99 | return fig, axs
100 |
--------------------------------------------------------------------------------
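A minimal sketch of how the plotting helpers above are typically driven (the experiment directories are hypothetical; each is expected to contain the log.txt written during training):

    from pathlib import Path
    import matplotlib.pyplot as plt
    from util.plot_utils import plot_logs

    log_dirs = [Path("output/exp1"), Path("output/exp2")]
    plot_logs(log_dirs, fields=("loss", "class_error"), ewm_col=0)
    plt.show()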