├── README.md
├── demo.py
├── demo
│   └── demo_examples
│       ├── 00000.jpg
│       ├── 00005.jpg
│       ├── 00010.jpg
│       ├── 00015.jpg
│       ├── 00020.jpg
│       ├── 00025.jpg
│       ├── 00030.jpg
│       ├── 00035.jpg
│       ├── 00040.jpg
│       ├── 00045.jpg
│       ├── 00050.jpg
│       ├── 00055.jpg
│       ├── 00060.jpg
│       ├── 00065.jpg
│       ├── 00070.jpg
│       ├── 00075.jpg
│       ├── 00080.jpg
│       ├── 00085.jpg
│       └── 00090.jpg
├── engine.py
├── illustration.jpg
├── inference_ytvos.py
├── inference_ytvos_segm.py
├── main.py
├── models
│   ├── __init__.py
│   ├── amm_resnet.py
│   ├── backbone.py
│   ├── criterion.py
│   ├── cycle.py
│   ├── deformable_transformer.py
│   ├── matcher.py
│   ├── ops
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   └── ms_deform_attn_func.py
│   │   ├── make.sh
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   └── ms_deform_attn.py
│   │   ├── setup.py
│   │   ├── src
│   │   │   ├── cpu
│   │   │   │   ├── ms_deform_attn_cpu.cpp
│   │   │   │   └── ms_deform_attn_cpu.h
│   │   │   ├── cuda
│   │   │   │   ├── ms_deform_attn_cuda.cu
│   │   │   │   ├── ms_deform_attn_cuda.h
│   │   │   │   └── ms_deform_im2col_cuda.cuh
│   │   │   ├── ms_deform_attn.h
│   │   │   └── vision.cpp
│   │   └── test.py
│   ├── position_encoding.py
│   ├── postprocessors.py
│   ├── referformer.py
│   ├── segmentation.py
│   ├── swin_transformer.py
│   ├── vector_quantitizer.py
│   └── video_swin_transformer.py
├── opts.py
├── requirements.txt
├── tools
│   ├── __init__.py
│   ├── colormap.py
│   ├── data
│   │   ├── convert_davis_to_ytvos.py
│   │   └── convert_refexp_to_coco.py
│   ├── load_pretrained_weights.py
│   └── warmup_poly_lr_scheduler.py
└── util
    ├── __init__.py
    ├── box_ops.py
    └── misc.py
/README.md:
--------------------------------------------------------------------------------
1 | > [**Towards Robust Referring Video Object Segmentation with Cyclic Relational Consistency**](https://arxiv.org/abs/2207.01203)
2 | >
3 | > Xiang Li, Jinglu Wang, Xiaohao Xu, Xiao Li, Bhiksha Raj, Yan Lu
4 |
5 |

6 |
7 | # Updates
8 | - **(2023-05-30)** Code released.
9 | - **(2023-07-13)** R2VOS is accepted to ICCV 2023!
10 |
11 | # Install
12 |
13 | ```
14 | conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 -c pytorch
15 | pip install -r requirements.txt
16 | pip install 'git+https://github.com/facebookresearch/fvcore'
17 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
18 | cd models/ops
19 | python setup.py build install
20 | cd ../..
21 | ```
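
After the build finishes, it can be worth checking that the compiled deformable-attention extension is importable before launching training or inference. Below is a minimal sketch; the extension name `MultiScaleDeformableAttention` follows the standard Deformable-DETR packaging and is an assumption about this repo's `models/ops/setup.py` (the bundled `models/ops/test.py` provides a more thorough check).

```
# Optional post-install check (a sketch; the extension name assumes the standard
# Deformable-DETR packaging in models/ops/setup.py and may differ here).
import torch

print('CUDA available:', torch.cuda.is_available())  # the custom op needs a CUDA-enabled PyTorch
try:
    import MultiScaleDeformableAttention  # assumed name of the compiled extension
    print('Deformable attention op found:', MultiScaleDeformableAttention.__file__)
except ImportError as err:
    print('Op not importable; re-run `python setup.py build install` in models/ops:', err)
```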
22 |
23 | # Docker
24 | You may try the [Docker image](https://hub.docker.com/r/ang9867/refer) for a quick start.
25 |
26 | # Weights
27 | Please download [checkpoint.pth](https://drive.google.com/file/d/1gknDDMxWKqZ7yPuTh1fome1-Ba4f_G9K/view?usp=share_link) and place it in the repository root.
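
Before launching a long job, you can quickly check that the download is intact. A minimal sketch, assuming the checkpoint stores its weights under the `'model'` key, which is what the inference scripts in this repo load (`checkpoint['model']` in demo.py and inference_ytvos.py):

```
# Quick checkpoint inspection (a sketch; assumes the weights live under the 'model' key).
import torch

ckpt = torch.load('checkpoint.pth', map_location='cpu')
print('top-level keys:', list(ckpt.keys()))
print('tensors in state dict:', len(ckpt['model']))
```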
28 |
29 | # Run demo
30 | Run inference on the images in `demo/demo_examples`.
31 | ```
32 | python demo.py --with_box_refine --binary --freeze_text_encoder --output_dir=output/demo --resume=checkpoint.pth --backbone resnet50 --ngpu 1 --use_cycle --mix_query --neg_cls --is_eval --use_cls --demo_exp 'a big track on the road' --demo_path 'demo/demo_examples'
33 | ```
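
The demo writes one mask per frame to `--output_dir` as an L-mode PNG named by frame index (`0.png`, `1.png`, ...). For a quick visual check, the sketch below overlays those masks on the input frames; the paths assume the defaults used in the command above.

```
# Overlay the predicted masks on the demo frames (a sketch; assumes the command above,
# i.e. frames in demo/demo_examples and masks saved as output/demo/<frame_index>.png).
import glob
import numpy as np
from PIL import Image

frames = sorted(glob.glob('demo/demo_examples/*.jpg'))
for idx, frame_path in enumerate(frames):
    frame = np.array(Image.open(frame_path).convert('RGB'), dtype=np.float32)
    mask = np.array(Image.open(f'output/demo/{idx}.png')) > 127   # masks are saved as 0/255
    frame[mask] = frame[mask] * 0.5 + np.array([255.0, 0.0, 0.0]) * 0.5  # blend in red
    Image.fromarray(frame.astype(np.uint8)).save(f'output/demo/overlay_{idx}.png')
```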
34 |
35 | # Inference
36 | To evaluate on Ref-Youtube-VOS, run inference_ytvos.py. If whole-video inference runs out of memory, use inference_ytvos_segm.py instead, which processes each video in shorter segments.
37 | ```
38 | python inference_ytvos.py --with_box_refine --binary --freeze_text_encoder --output_dir=output/eval --resume=checkpoint.pth --backbone resnet50 --ngpu 1 --use_cycle --mix_query --neg_cls --is_eval --use_cls --ytvos_path=/data/ref-ytvos
39 | ```
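
Both scripts read Ref-Youtube-VOS from `--ytvos_path` with the layout sketched below, inferred from the data-loading code in inference_ytvos.py (the split name, typically `valid`, comes from `--split`; the test meta file is only needed to filter the test videos out of the validation expressions):

```
/data/ref-ytvos
├── valid
│   └── JPEGImages
│       └── <video_id>/<frame_id>.jpg
└── meta_expressions
    ├── valid
    │   └── meta_expressions.json
    └── test
        └── meta_expressions.json
```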
40 | # Related works for robust multimodal video segmentation
41 | > [R2-Bench: Benchmarking the Robustness of Referring Perception Models under Perturbations](https://arxiv.org/abs/2403.04924), arXiv 2024
42 |
43 |
44 | > [Towards Robust Audiovisual Segmentation in Complex Environments with Quantization-based Semantic Decomposition](https://arxiv.org/abs/2310.00132), CVPR 2024
45 | # Citation
46 | ```
47 | @inproceedings{li2023robust,
48 | title={Robust referring video object segmentation with cyclic structural consensus},
49 | author={Li, Xiang and Wang, Jinglu and Xu, Xiaohao and Li, Xiao and Raj, Bhiksha and Lu, Yan},
50 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
51 | pages={22236--22245},
52 | year={2023}
53 | }
54 | ```
55 |
56 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | '''
2 | Inference code for ReferFormer, on Ref-Youtube-VOS
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | '''
5 | import argparse
6 | import json
7 | import random
8 | import sys
9 | import time
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 |
15 | import util.misc as utils
16 | from models import build_model
17 | import torchvision.transforms as T
18 | import matplotlib.pyplot as plt
19 | import os
20 | import cv2
21 | from PIL import Image, ImageDraw
22 | import math
23 | import torch.nn.functional as F
24 | import json
25 |
26 | import opts
27 | from tqdm import tqdm
28 |
29 | import multiprocessing as mp
30 | import threading
31 | import glob
32 |
33 |
34 | from tools.colormap import colormap
35 |
36 |
37 | # colormap
38 | color_list = colormap()
39 | color_list = color_list.astype('uint8').tolist()
40 |
41 | def main(args):
42 | args.masks = True
43 |     args.batch_size = 1  # inference only supports batch size 1
44 |     print("Inference only supports batch size = 1")
45 |
46 | global transform
47 | transform = T.Compose([
48 | T.Resize(args.inf_res),
49 | T.ToTensor(),
50 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
51 | ])
52 |
53 | # fix the seed for reproducibility
54 | seed = args.seed + utils.get_rank()
55 | torch.manual_seed(seed)
56 | np.random.seed(seed)
57 | random.seed(seed)
58 |
59 | # save path
60 | output_dir = args.output_dir
61 | save_path_prefix = os.path.join(output_dir)
62 | if not os.path.exists(save_path_prefix):
63 | os.makedirs(save_path_prefix)
64 |
65 | global result_dict
66 | result_dict = mp.Manager().dict()
67 | frames = sorted(glob.glob(args.demo_path+'/*'))
68 | sub_processor(0, args, args.demo_exp, frames, save_path_prefix)
69 |
70 | result_dict = dict(result_dict)
71 | num_all_frames_gpus = 0
72 | for pid, num_all_frames in result_dict.items():
73 | num_all_frames_gpus += num_all_frames
74 |
75 | def sub_processor(pid, args, exp, frames, save_path_prefix):
76 | torch.cuda.set_device(pid)
77 |
78 | # model
79 | model, criterion, _ = build_model(args)
80 | device = args.device
81 | model.to(device)
82 |
83 | model_without_ddp = model
84 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
85 |
86 | if pid == 0:
87 | print('number of params:', n_parameters)
88 |
89 | if args.resume:
90 | checkpoint = torch.load(args.resume, map_location='cpu')
91 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
92 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
93 | if len(missing_keys) > 0:
94 | print('Missing Keys: {}'.format(missing_keys))
95 | if len(unexpected_keys) > 0:
96 | print('Unexpected Keys: {}'.format(unexpected_keys))
97 | else:
98 | raise ValueError('Please specify the checkpoint for inference.')
99 |
100 |
101 | # start inference
102 | num_all_frames = 0
103 | model.eval()
104 |
105 | sentence_features = []
106 | pseudo_sentence_features = []
107 | video_name = 'demo'
108 | # exp = meta[i]["exp"]
109 | # # exp = 'a dog is with its puppies on the cloth'
110 | # # TODO: temp
111 | # frames = meta[i]["frames"]
112 | # frames = [f'/home/mcg/ReferFormer/demo/frames_{fid}.jpg' for fid in range(1,2)]
113 |
114 | video_len = len(frames)
115 | # store images
116 | imgs = []
117 | for t in range(video_len):
118 | frame = frames[t]
119 | img_path = os.path.join(frame)
120 | img = Image.open(img_path).convert('RGB')
121 | origin_w, origin_h = img.size
122 | imgs.append(transform(img)) # list[img]
123 |
124 | imgs = torch.stack(imgs, dim=0).to(args.device) # [video_len, 3, h, w]
125 | img_h, img_w = imgs.shape[-2:]
126 | size = torch.as_tensor([int(img_h), int(img_w)]).to(args.device)
127 | target = {"size": size}
128 |
129 | with torch.no_grad():
130 | outputs = model([imgs], [exp], [target])
131 |
132 | pred_logits = outputs["pred_logits"][0]
133 | pred_boxes = outputs["pred_boxes"][0]
134 | pred_masks = outputs["pred_masks"][0]
135 | pred_ref_points = outputs["reference_points"][0]
136 | text_sentence_features = outputs['sentence_feature']
137 | if args.use_cycle:
138 | pseudo_text_sentence_features = outputs['pseudo_sentence_feature']
139 | # anchor = outputs['negative_anchor']
140 | sentence_features.append(text_sentence_features)
141 | pseudo_sentence_features.append(pseudo_text_sentence_features)
142 | # print(F.pairwise_distance(text_sentence_features, pseudo_text_sentence_features.squeeze(0), p=2))
143 | # print(anchor)
144 | # according to pred_logits, select the query index
145 | pred_scores = pred_logits.sigmoid() # [t, q, k]
146 | pred_score = pred_scores
147 | pred_scores = pred_scores.mean(0) # [q, k]
148 | max_scores, _ = pred_scores.max(-1) # [q,]
149 | # print(max_scores)
150 | _, max_ind = max_scores.max(-1) # [1,]
151 | max_inds = max_ind.repeat(video_len)
152 | pred_masks = pred_masks[range(video_len), max_inds, ...] # [t, h, w]
153 | pred_masks = pred_masks.unsqueeze(0)
154 | pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False)
155 | if args.save_prob:
156 | pred_masks = pred_masks.sigmoid().squeeze(0).detach().cpu().numpy()
157 | else:
158 | pred_masks = (pred_masks.sigmoid() > args.threshold).squeeze(0).detach().cpu().numpy()
159 | if args.use_score:
160 | pred_score = pred_score[range(video_len), max_inds, 0].unsqueeze(-1).unsqueeze(-1)
161 |             pred_masks = (pred_score > 0.3).cpu().numpy() * pred_masks  # keep masks only where the score exceeds 0.3
162 |
163 | # store the video results
164 | all_pred_logits = pred_logits[range(video_len), max_inds].sigmoid().cpu().numpy()
165 | all_pred_boxes = pred_boxes[range(video_len), max_inds]
166 | all_pred_ref_points = pred_ref_points[range(video_len), max_inds]
167 | all_pred_masks = pred_masks
168 |
169 | save_path = os.path.join(save_path_prefix)
170 | if not os.path.exists(save_path):
171 | os.makedirs(save_path)
172 | for j in range(video_len):
173 | frame_name = frames[j]
174 | confidence = all_pred_logits[j]
175 | mask = all_pred_masks[j].astype(np.float32)
176 | save_file = os.path.join(save_path, f"{j}" + ".png")
177 | # print(save_file)
178 | if 'pair_logits' in outputs.keys() and args.use_cls:
179 | if outputs['pair_logits'].cpu().numpy() >= 0.5:
180 |                 print('This is a negative pair, misalignment degree:', outputs['pair_logits'].cpu().numpy().item())
181 |             else:
182 |                 print('This is a positive pair, misalignment degree:', outputs['pair_logits'].cpu().numpy().item())
183 | mask *= 0 if outputs['pair_logits'].cpu().numpy() >= 0.5 else 1
184 | mask = Image.fromarray(mask * 255).convert('L')
185 | mask.save(save_file)
186 | print(f'Results saved to {save_path}')
187 | result_dict[str(pid)] = num_all_frames
188 |
189 |
190 | # visualize functions
191 | def box_cxcywh_to_xyxy(x):
192 | x_c, y_c, w, h = x.unbind(1)
193 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
194 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
195 | return torch.stack(b, dim=1)
196 |
197 | def rescale_bboxes(out_bbox, size):
198 | img_w, img_h = size
199 | b = box_cxcywh_to_xyxy(out_bbox)
200 | b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
201 | return b
202 |
203 |
204 | # Visualization functions
205 | def draw_reference_points(draw, reference_points, img_size, color):
206 | W, H = img_size
207 | for i, ref_point in enumerate(reference_points):
208 | init_x, init_y = ref_point
209 | x, y = W * init_x, H * init_y
210 | cur_color = color
211 | draw.line((x-10, y, x+10, y), tuple(cur_color), width=4)
212 | draw.line((x, y-10, x, y+10), tuple(cur_color), width=4)
213 |
214 | def draw_sample_points(draw, sample_points, img_size, color_list):
215 | alpha = 255
216 | for i, samples in enumerate(sample_points):
217 | for sample in samples:
218 | x, y = sample
219 | cur_color = color_list[i % len(color_list)][::-1]
220 | cur_color += [alpha]
221 | draw.ellipse((x-2, y-2, x+2, y+2),
222 | fill=tuple(cur_color), outline=tuple(cur_color), width=1)
223 |
224 | def vis_add_mask(img, mask, color):
225 | origin_img = np.asarray(img.convert('RGB')).copy()
226 | color = np.array(color)
227 |
228 | mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8') # np
229 | mask = mask > 0.5
230 |
231 | origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5
232 | origin_img = Image.fromarray(origin_img)
233 | return origin_img
234 |
235 |
236 |
237 | if __name__ == '__main__':
238 | parser = argparse.ArgumentParser('ReferFormer inference script', parents=[opts.get_args_parser()])
239 | args = parser.parse_args()
240 | main(args)
241 |
--------------------------------------------------------------------------------
/demo/demo_examples/00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00000.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00005.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00005.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00010.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00010.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00015.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00015.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00020.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00020.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00025.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00025.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00030.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00030.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00035.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00035.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00040.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00040.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00045.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00045.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00050.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00050.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00055.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00055.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00060.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00060.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00065.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00065.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00070.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00070.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00075.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00075.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00080.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00080.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00085.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00085.jpg
--------------------------------------------------------------------------------
/demo/demo_examples/00090.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00090.jpg
--------------------------------------------------------------------------------
/engine.py:
--------------------------------------------------------------------------------
1 | """
2 | Train and eval functions used in main.py
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import math
6 | from models import postprocessors
7 | import os
8 | import sys
9 | from typing import Iterable
10 |
11 | import torch
12 | import torch.distributed as dist
13 |
14 | import util.misc as utils
15 | from datasets.coco_eval import CocoEvaluator
16 | from datasets.refexp_eval import RefExpEvaluator
17 |
18 | from pycocotools.coco import COCO
19 | from pycocotools.cocoeval import COCOeval
20 | from datasets.a2d_eval import calculate_precision_at_k_and_iou_metrics
21 |
22 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
23 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
24 | device: torch.device, lr_scheduler, epoch: int, max_norm: float = 0):
25 | model.train()
26 | criterion.train()
27 | metric_logger = utils.MetricLogger(delimiter=" ")
28 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
29 | header = 'Epoch: [{}]'.format(epoch)
30 | print_freq = 10
31 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
32 | samples = samples.to(device)
33 | captions = [t["caption"] for t in targets]
34 | targets = utils.targets_to(targets, device)
35 |
36 | outputs = model(samples, captions, targets)
37 | loss_dict = criterion(outputs, targets)
38 | weight_dict = criterion.weight_dict
39 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
40 |
41 | # reduce losses over all GPUs for logging purposes
42 | loss_dict_reduced = utils.reduce_dict(loss_dict)
43 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v
44 | for k, v in loss_dict_reduced.items()}
45 | loss_dict_reduced_scaled = {k: v * weight_dict[k]
46 | for k, v in loss_dict_reduced.items() if k in weight_dict}
47 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
48 | loss_value = losses_reduced_scaled.item()
49 |
50 | if not math.isfinite(loss_value):
51 | print("Loss is {}, stopping training".format(loss_value))
52 | print(loss_dict_reduced)
53 | sys.exit(1)
54 | optimizer.zero_grad()
55 | losses.backward()
56 | if max_norm > 0:
57 | grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
58 | else:
59 | grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
60 | optimizer.step()
61 | # lr_scheduler.step()
62 |
63 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
64 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
65 | metric_logger.update(grad_norm=grad_total_norm)
66 |
67 | # gather the stats from all processes
68 | metric_logger.synchronize_between_processes()
69 | print("Averaged stats:", metric_logger)
70 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
71 |
72 |
73 | @torch.no_grad()
74 | def evaluate(model, criterion, postprocessors, data_loader, evaluator_list, device, args):
75 | model.eval()
76 | criterion.eval()
77 |
78 | metric_logger = utils.MetricLogger(delimiter=" ")
79 | header = 'Test:'
80 |
81 | for samples, targets in metric_logger.log_every(data_loader, 10, header):
82 | samples = samples.to(device)
83 | captions = [t["caption"] for t in targets]
84 | targets = utils.targets_to(targets, device)
85 | outputs = model(samples, captions, targets)
86 | loss_dict = criterion(outputs, targets)
87 | weight_dict = criterion.weight_dict
88 |
89 | # reduce losses over all GPUs for logging purposes
90 | loss_dict_reduced = utils.reduce_dict(loss_dict)
91 | loss_dict_reduced_scaled = {k: v * weight_dict[k]
92 | for k, v in loss_dict_reduced.items() if k in weight_dict}
93 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v
94 | for k, v in loss_dict_reduced.items()}
95 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
96 | **loss_dict_reduced_scaled,
97 | **loss_dict_reduced_unscaled)
98 |
99 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
100 | results = postprocessors['bbox'](outputs, orig_target_sizes)
101 | if 'segm' in postprocessors.keys():
102 | target_sizes = torch.stack([t["size"] for t in targets], dim=0)
103 | results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
104 | res = {target['image_id']: output for target, output in zip(targets, results)}
105 |
106 | for evaluator in evaluator_list:
107 | evaluator.update(res)
108 |
109 | # gather the stats from all processes
110 | metric_logger.synchronize_between_processes()
111 | print("Averaged stats:", metric_logger)
112 | for evaluator in evaluator_list:
113 | evaluator.synchronize_between_processes()
114 |
115 | # accumulate predictions from all images
116 | refexp_res = None
117 | for evaluator in evaluator_list:
118 | if isinstance(evaluator, CocoEvaluator):
119 | evaluator.accumulate()
120 | evaluator.summarize()
121 | elif isinstance(evaluator, RefExpEvaluator):
122 | refexp_res = evaluator.summarize()
123 |
124 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
125 |
126 | # update stats
127 | for evaluator in evaluator_list:
128 | if isinstance(evaluator, CocoEvaluator):
129 | if "bbox" in postprocessors.keys():
130 | stats["coco_eval_bbox"] = evaluator.coco_eval["bbox"].stats.tolist()
131 | if "segm" in postprocessors.keys():
132 | stats["coco_eval_masks"] = evaluator.coco_eval["segm"].stats.tolist()
133 | if refexp_res is not None:
134 | stats.update(refexp_res)
135 |
136 | return stats
137 |
138 |
139 | @torch.no_grad()
140 | def evaluate_a2d(model, data_loader, postprocessor, device, args):
141 | model.eval()
142 | predictions = []
143 | metric_logger = utils.MetricLogger(delimiter=" ")
144 | header = 'Test:'
145 |
146 | for samples, targets in metric_logger.log_every(data_loader, 10, header):
147 | image_ids = [t['image_id'] for t in targets]
148 |
149 | samples = samples.to(device)
150 | captions = [t["caption"] for t in targets]
151 | targets = utils.targets_to(targets, device)
152 |
153 | outputs = model(samples, captions, targets)
154 |
155 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
156 | target_sizes = torch.stack([t["size"] for t in targets], dim=0)
157 | processed_outputs = postprocessor(outputs, orig_target_sizes, target_sizes)
158 |
159 | for p, image_id in zip(processed_outputs, image_ids):
160 | for s, m in zip(p['scores'], p['rle_masks']):
161 | predictions.append({'image_id': image_id,
162 | 'category_id': 1, # dummy label, as categories are not predicted in ref-vos
163 | 'segmentation': m,
164 | 'score': s.item()})
165 |
166 | # gather and merge predictions from all gpus
167 | gathered_pred_lists = utils.all_gather(predictions)
168 | predictions = [p for p_list in gathered_pred_lists for p in p_list]
169 | # evaluation
170 | eval_metrics = {}
171 | if utils.is_main_process():
172 | if args.dataset_file == 'a2d':
173 | coco_gt = COCO(os.path.join(args.a2d_path, 'a2d_sentences_test_annotations_in_coco_format.json'))
174 | elif args.dataset_file == 'jhmdb':
175 | coco_gt = COCO(os.path.join(args.jhmdb_path, 'jhmdb_sentences_gt_annotations_in_coco_format.json'))
176 | elif args.dataset_file == 'refcocoVideo':
177 | coco_gt = COCO(os.path.join(args.coco_path, 'refcocoVideo/finetune_refcoco_val.json'))
178 | else:
179 | raise NotImplementedError
180 | coco_pred = coco_gt.loadRes(predictions)
181 | coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm')
182 | coco_eval.params.useCats = 0 # ignore categories as they are not predicted in ref-vos task
183 | coco_eval.evaluate()
184 | coco_eval.accumulate()
185 | coco_eval.summarize()
186 | ap_labels = ['mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', 'AP 0.5:0.95 M', 'AP 0.5:0.95 L']
187 | ap_metrics = coco_eval.stats[:6]
188 | eval_metrics = {l: m for l, m in zip(ap_labels, ap_metrics)}
189 | # Precision and IOU
190 | precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics(coco_gt, coco_pred)
191 | eval_metrics.update({f'P@{k}': m for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)})
192 | eval_metrics.update({'overall_iou': overall_iou, 'mean_iou': mean_iou})
193 | print(eval_metrics)
194 |
195 | # sync all processes before starting a new epoch or exiting
196 | dist.barrier()
197 | return eval_metrics
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
--------------------------------------------------------------------------------
/illustration.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/illustration.jpg
--------------------------------------------------------------------------------
/inference_ytvos.py:
--------------------------------------------------------------------------------
1 | '''
2 | Inference code for ReferFormer, on Ref-Youtube-VOS
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | '''
5 | import argparse
6 | import json
7 | import random
8 | import sys
9 | import time
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 |
15 | import util.misc as utils
16 | from models import build_model
17 | import torchvision.transforms as T
18 | import matplotlib.pyplot as plt
19 | import os
20 | import cv2
21 | from PIL import Image, ImageDraw
22 | import math
23 | import torch.nn.functional as F
24 | import json
25 |
26 | import opts
27 | from tqdm import tqdm
28 |
29 | import multiprocessing as mp
30 | import threading
31 |
32 |
33 | from tools.colormap import colormap
34 |
35 |
36 | # colormap
37 | color_list = colormap()
38 | color_list = color_list.astype('uint8').tolist()
39 |
40 | def main(args):
41 | args.masks = True
42 |     args.batch_size = 1  # inference only supports batch size 1
43 |     print("Inference only supports batch size = 1")
44 |
45 | global transform
46 | transform = T.Compose([
47 | T.Resize(args.inf_res),
48 | T.ToTensor(),
49 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
50 | ])
51 |
52 | # fix the seed for reproducibility
53 | seed = args.seed + utils.get_rank()
54 | torch.manual_seed(seed)
55 | np.random.seed(seed)
56 | random.seed(seed)
57 |
58 | split = args.split
59 | # save path
60 | output_dir = args.output_dir
61 | save_path_prefix = os.path.join(output_dir, split)
62 | if not os.path.exists(save_path_prefix):
63 | os.makedirs(save_path_prefix)
64 |
65 | save_visualize_path_prefix = os.path.join(output_dir, split + '_images')
66 | if args.visualize:
67 | if not os.path.exists(save_visualize_path_prefix):
68 | os.makedirs(save_visualize_path_prefix)
69 | if args.val_type:
70 |         print(f'Warning: using the robust RVOS validation set, type: {args.val_type}')
71 | assert args.val_type in ['_random', '_color', '_old_color', '_object', '_action', '_dyn', '_1', '_3', '_5'], \
72 | f'Unknown val type for robust rvos, {args.val_type}'
73 | # load data
74 | root = Path(args.ytvos_path) # data/ref-youtube-vos
75 | img_folder = os.path.join(root, split, f"JPEGImages{args.val_type}")
76 | meta_file = os.path.join(root, "meta_expressions", split, f"meta_expressions{args.val_type}.json")
77 | with open(meta_file, "r") as f:
78 | data = json.load(f)["videos"]
79 | valid_test_videos = set(data.keys())
80 |     # for some reason the competition's validation expressions dict contains both the validation (202) &
81 | # test videos (305). so we simply load the test expressions dict and use it to filter out the test videos from
82 | # the validation expressions dict:
83 | test_meta_file = os.path.join(root, "meta_expressions", "test", "meta_expressions.json")
84 | with open(test_meta_file, 'r') as f:
85 | test_data = json.load(f)['videos']
86 | test_videos = set(test_data.keys())
87 | valid_videos = valid_test_videos - test_videos
88 | video_list = sorted([video for video in valid_videos])
89 | if 'color' not in args.val_type:
90 | assert len(video_list) == 202 or len(video_list) == 202+305, 'error: incorrect number of validation videos'
91 |
92 | # create subprocess
93 | thread_num = args.ngpu
94 | global result_dict
95 | result_dict = mp.Manager().dict()
96 |
97 | processes = []
98 | lock = threading.Lock()
99 |
100 | video_num = len(video_list)
101 | per_thread_video_num = video_num // thread_num
102 |
103 | start_time = time.time()
104 | print('Start inference')
105 | for i in range(thread_num):
106 | if i == thread_num - 1:
107 | sub_video_list = video_list[i * per_thread_video_num:]
108 | else:
109 | sub_video_list = video_list[i * per_thread_video_num: (i + 1) * per_thread_video_num]
110 | p = mp.Process(target=sub_processor, args=(lock, i, args, data,
111 | save_path_prefix, save_visualize_path_prefix,
112 | img_folder, sub_video_list))
113 | p.start()
114 | processes.append(p)
115 |
116 | for p in processes:
117 | p.join()
118 |
119 | end_time = time.time()
120 | total_time = end_time - start_time
121 |
122 | result_dict = dict(result_dict)
123 | num_all_frames_gpus = 0
124 | for pid, num_all_frames in result_dict.items():
125 | num_all_frames_gpus += num_all_frames
126 |
127 | print("Total inference time: %.4f s" %(total_time))
128 |
129 | def sub_processor(lock, pid, args, data, save_path_prefix, save_visualize_path_prefix, img_folder, video_list):
130 | text = 'processor %d' % pid
131 | with lock:
132 | progress = tqdm(
133 | total=len(video_list),
134 | position=pid,
135 | desc=text,
136 | ncols=0
137 | )
138 | torch.cuda.set_device(pid)
139 |
140 | # model
141 | model, criterion, _ = build_model(args)
142 | device = args.device
143 | model.to(device)
144 |
145 | model_without_ddp = model
146 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
147 |
148 | if pid == 0:
149 | print('number of params:', n_parameters)
150 |
151 | if args.resume:
152 | checkpoint = torch.load(args.resume, map_location='cpu')
153 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
154 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
155 | if len(missing_keys) > 0:
156 | print('Missing Keys: {}'.format(missing_keys))
157 | if len(unexpected_keys) > 0:
158 | print('Unexpected Keys: {}'.format(unexpected_keys))
159 | else:
160 | raise ValueError('Please specify the checkpoint for inference.')
161 |
162 |
163 | # start inference
164 | num_all_frames = 0
165 | model.eval()
166 |
167 | sentence_features = []
168 | pseudo_sentence_features = []
169 | # 1. For each video
170 | for vid, video in enumerate(video_list):
171 | metas = [] # list[dict], length is number of expressions
172 |
173 | expressions = data[video]["expressions"]
174 | expression_list = list(expressions.keys())
175 | num_expressions = len(expression_list)
176 | video_len = len(data[video]["frames"])
177 |
178 | # read all the anno meta
179 | for i in range(num_expressions):
180 | meta = {}
181 | meta["video"] = video
182 | meta["exp"] = expressions[expression_list[i]]["exp"]
183 | meta["exp_id"] = expression_list[i]
184 | meta["frames"] = data[video]["frames"]
185 | metas.append(meta)
186 | meta = metas
187 |
188 | # if vid < 8: # 46, 81
189 | # print(video, metas[0]["exp"])
190 | # continue
191 |
192 | vis = False
193 | # 2. For each expression
194 | for i in range(num_expressions):
195 | video_name = meta[i]["video"]
196 | exp = meta[i]["exp"]
197 | # exp = 'a dog is with its puppies on the cloth'
198 | # TODO: temp
199 | exp_id = meta[i]["exp_id"]
200 | frames = meta[i]["frames"]
201 | # frames = [f'/home/mcg/ReferFormer/demo/frames_{fid}.jpg' for fid in range(1,2)]
202 | if vis:
203 | os.makedirs('vis/tmp', exist_ok=True)
204 |
205 | video_len = len(frames)
206 | # store images
207 | imgs = []
208 | for t in range(video_len):
209 | frame = frames[t]
210 | img_path = os.path.join(img_folder, video_name, frame + ".jpg")
211 | # img_path = frame
212 | img = Image.open(img_path).convert('RGB')
213 | if vis:
214 | if t < 1:
215 | img.save(f'vis/tmp/{t}.png')
216 | origin_w, origin_h = img.size
217 | imgs.append(transform(img)) # list[img]
218 |
219 | imgs = torch.stack(imgs, dim=0).to(args.device) # [video_len, 3, h, w]
220 | img_h, img_w = imgs.shape[-2:]
221 | size = torch.as_tensor([int(img_h), int(img_w)]).to(args.device)
222 | target = {"size": size}
223 |
224 | # start = time.time()
225 | with torch.no_grad():
226 | outputs = model([imgs], [exp], [target])
227 | # end = time.time()
228 | # print((end-start)/video_len)
229 |
230 | if vis:
231 | os.system(f'mv vis/tmp vis/{vid}-{exp.replace(" ", "_")}')
232 |
233 | pred_logits = outputs["pred_logits"][0]
234 | pred_boxes = outputs["pred_boxes"][0]
235 | pred_masks = outputs["pred_masks"][0]
236 | pred_ref_points = outputs["reference_points"][0]
237 | text_sentence_features = outputs['sentence_feature']
238 | if args.use_cycle:
239 | pseudo_text_sentence_features = outputs['pseudo_sentence_feature']
240 | # anchor = outputs['negative_anchor']
241 | sentence_features.append(text_sentence_features)
242 | pseudo_sentence_features.append(pseudo_text_sentence_features)
243 | # print(F.pairwise_distance(text_sentence_features, pseudo_text_sentence_features.squeeze(0), p=2))
244 | # print(anchor)
245 | # according to pred_logits, select the query index
246 | pred_scores = pred_logits.sigmoid() # [t, q, k]
247 | pred_score = pred_scores
248 | pred_scores = pred_scores.mean(0) # [q, k]
249 | max_scores, _ = pred_scores.max(-1) # [q,]
250 | # print(max_scores)
251 | _, max_ind = max_scores.max(-1) # [1,]
252 | max_inds = max_ind.repeat(video_len)
253 | pred_masks = pred_masks[range(video_len), max_inds, ...] # [t, h, w]
254 | pred_masks = pred_masks.unsqueeze(0)
255 | pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False)
256 | if args.save_prob:
257 | pred_masks = pred_masks.sigmoid().squeeze(0).detach().cpu().numpy()
258 | else:
259 | pred_masks = (pred_masks.sigmoid() > args.threshold).squeeze(0).detach().cpu().numpy()
260 | if args.use_score:
261 | pred_score = pred_score[range(video_len), max_inds, 0].unsqueeze(-1).unsqueeze(-1)
262 |                 pred_masks = (pred_score > 0.3).cpu().numpy() * pred_masks  # keep masks only where the score exceeds 0.3
263 |
264 | # store the video results
265 | all_pred_logits = pred_logits[range(video_len), max_inds].sigmoid().cpu().numpy()
266 | all_pred_boxes = pred_boxes[range(video_len), max_inds]
267 | all_pred_ref_points = pred_ref_points[range(video_len), max_inds]
268 | all_pred_masks = pred_masks
269 |
270 | if args.visualize:
271 | for t, frame in enumerate(frames):
272 | # original
273 | img_path = os.path.join(img_folder, video_name, frame + '.jpg')
274 | source_img = Image.open(img_path).convert('RGBA') # PIL image
275 |
276 | draw = ImageDraw.Draw(source_img)
277 | draw_boxes = all_pred_boxes[t].unsqueeze(0)
278 | draw_boxes = rescale_bboxes(draw_boxes.detach(), (origin_w, origin_h)).tolist()
279 |
280 | # draw boxes
281 | xmin, ymin, xmax, ymax = draw_boxes[0]
282 | draw.rectangle(((xmin, ymin), (xmax, ymax)), outline=tuple(color_list[i%len(color_list)]), width=2)
283 |
284 | # draw reference point
285 | ref_points = all_pred_ref_points[t].unsqueeze(0).detach().cpu().tolist()
286 | draw_reference_points(draw, ref_points, source_img.size, color=color_list[i%len(color_list)])
287 |
288 | # draw mask
289 | source_img = vis_add_mask(source_img, all_pred_masks[t], color_list[i%len(color_list)])
290 |
291 | # save
292 | save_visualize_path_dir = os.path.join(save_visualize_path_prefix, video, str(i))
293 | if not os.path.exists(save_visualize_path_dir):
294 | os.makedirs(save_visualize_path_dir)
295 | save_visualize_path = os.path.join(save_visualize_path_dir, frame + '.png')
296 | source_img.save(save_visualize_path)
297 |
298 |
299 | # save binary image
300 | save_path = os.path.join(save_path_prefix, video_name, exp_id)
301 | if not os.path.exists(save_path):
302 | os.makedirs(save_path)
303 | for j in range(video_len):
304 | frame_name = frames[j]
305 | confidence = all_pred_logits[j]
306 | mask = all_pred_masks[j].astype(np.float32)
307 | save_file = os.path.join(save_path, frame_name + ".png")
308 | # print(save_file)
309 | if 'pair_logits' in outputs.keys() and args.use_cls:
310 | mask *= 0 if outputs['pair_logits'].cpu().numpy() >= 0.5 else 1
311 | mask = Image.fromarray(mask * 255).convert('L')
312 | mask.save(save_file)
313 |
314 | # print(torch.nn.functional.mse_loss(text_sentence_features, pseudo_text_sentence_features))
315 | # if vid == 10:
316 | # bert_embed = torch.cat(sentence_features, dim=0)
317 | # np.save(f'output/bert_embed{args.val_type}.npy', bert_embed.cpu().numpy())
318 | # pseudo_bert_embed = torch.cat(pseudo_sentence_features, dim=0)
319 | # np.save(f'output/pseudo_bert_embed{args.val_type}.npy', pseudo_bert_embed.cpu().numpy())
320 |
321 | with lock:
322 | progress.update(1)
323 | # bert_embed = torch.cat(sentence_features, dim=0)
324 | # np.save('output/bert_embed.npy', bert_embed.cpu().numpy())
325 |
326 | result_dict[str(pid)] = num_all_frames
327 | with lock:
328 | progress.close()
329 |
330 |
331 | # visualize functions
332 | def box_cxcywh_to_xyxy(x):
333 | x_c, y_c, w, h = x.unbind(1)
334 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
335 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
336 | return torch.stack(b, dim=1)
337 |
338 | def rescale_bboxes(out_bbox, size):
339 | img_w, img_h = size
340 | b = box_cxcywh_to_xyxy(out_bbox)
341 | b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
342 | return b
343 |
344 |
345 | # Visualization functions
346 | def draw_reference_points(draw, reference_points, img_size, color):
347 | W, H = img_size
348 | for i, ref_point in enumerate(reference_points):
349 | init_x, init_y = ref_point
350 | x, y = W * init_x, H * init_y
351 | cur_color = color
352 | draw.line((x-10, y, x+10, y), tuple(cur_color), width=4)
353 | draw.line((x, y-10, x, y+10), tuple(cur_color), width=4)
354 |
355 | def draw_sample_points(draw, sample_points, img_size, color_list):
356 | alpha = 255
357 | for i, samples in enumerate(sample_points):
358 | for sample in samples:
359 | x, y = sample
360 | cur_color = color_list[i % len(color_list)][::-1]
361 | cur_color += [alpha]
362 | draw.ellipse((x-2, y-2, x+2, y+2),
363 | fill=tuple(cur_color), outline=tuple(cur_color), width=1)
364 |
365 | def vis_add_mask(img, mask, color):
366 | origin_img = np.asarray(img.convert('RGB')).copy()
367 | color = np.array(color)
368 |
369 | mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8') # np
370 | mask = mask > 0.5
371 |
372 | origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5
373 | origin_img = Image.fromarray(origin_img)
374 | return origin_img
375 |
376 |
377 |
378 | if __name__ == '__main__':
379 | parser = argparse.ArgumentParser('ReferFormer inference script', parents=[opts.get_args_parser()])
380 | args = parser.parse_args()
381 | main(args)
382 |
--------------------------------------------------------------------------------
/inference_ytvos_segm.py:
--------------------------------------------------------------------------------
1 | '''
2 | Inference code for ReferFormer, on Ref-Youtube-VOS
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | '''
5 | import argparse
6 | import json
7 | import random
8 | import sys
9 | import time
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 |
15 | import util.misc as utils
16 | from models import build_model
17 | import torchvision.transforms as T
18 | import matplotlib.pyplot as plt
19 | import os
20 | import cv2
21 | from PIL import Image, ImageDraw
22 | import math
23 | import torch.nn.functional as F
24 | import json
25 |
26 | import opts
27 | from tqdm import tqdm
28 |
29 | import multiprocessing as mp
30 | import threading
31 |
32 |
33 | from tools.colormap import colormap
34 |
35 |
36 | # colormap
37 | color_list = colormap()
38 | color_list = color_list.astype('uint8').tolist()
39 |
40 | def main(args):
41 | args.masks = True
42 |     args.batch_size = 1  # inference only supports batch size 1
43 |     print("Inference only supports batch size = 1")
44 |
45 | global transform
46 | transform = T.Compose([
47 | T.Resize(args.inf_res),
48 | T.ToTensor(),
49 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
50 | ])
51 |
52 | # fix the seed for reproducibility
53 | seed = args.seed + utils.get_rank()
54 | torch.manual_seed(seed)
55 | np.random.seed(seed)
56 | random.seed(seed)
57 |
58 | split = args.split
59 | # save path
60 | output_dir = args.output_dir
61 | save_path_prefix = os.path.join(output_dir, split)
62 | if not os.path.exists(save_path_prefix):
63 | os.makedirs(save_path_prefix)
64 |
65 | save_visualize_path_prefix = os.path.join(output_dir, split + '_images')
66 | if args.visualize:
67 | if not os.path.exists(save_visualize_path_prefix):
68 | os.makedirs(save_visualize_path_prefix)
69 |
70 | if args.val_type:
71 |         print(f'Warning: using the robust RVOS validation set, type: {args.val_type}')
72 | assert args.val_type in ['_random', '_color', '_old_color', '_object', '_action', '_dyn', '_1', '_3', '_5'], \
73 | f'Unknown val type for robust rvos, {args.val_type}'
74 | # load data
75 | root = Path(args.ytvos_path) # data/ref-youtube-vos
76 | img_folder = os.path.join(root, split, f"JPEGImages{args.val_type}")
77 | meta_file = os.path.join(root, "meta_expressions", split, f"meta_expressions{args.val_type}.json")
78 | with open(meta_file, "r") as f:
79 | data = json.load(f)["videos"]
80 | valid_test_videos = set(data.keys())
81 |     # for some reason the competition's validation expressions dict contains both the validation (202) &
82 | # test videos (305). so we simply load the test expressions dict and use it to filter out the test videos from
83 | # the validation expressions dict:
84 | test_meta_file = os.path.join(root, "meta_expressions", "test", "meta_expressions.json")
85 | with open(test_meta_file, 'r') as f:
86 | test_data = json.load(f)['videos']
87 | test_videos = set(test_data.keys())
88 | valid_videos = valid_test_videos - test_videos
89 | video_list = sorted([video for video in valid_videos])
90 | if 'color' not in args.val_type:
91 | assert len(video_list) == 202 or len(video_list) == 202 + 305, 'error: incorrect number of validation videos'
92 |
93 | # create subprocess
94 | thread_num = args.ngpu
95 | global result_dict
96 | result_dict = mp.Manager().dict()
97 |
98 | processes = []
99 | lock = threading.Lock()
100 |
101 | video_num = len(video_list)
102 | per_thread_video_num = video_num // thread_num
103 |
104 | start_time = time.time()
105 | print('Start inference')
106 | for i in range(thread_num):
107 | if i == thread_num - 1:
108 | sub_video_list = video_list[i * per_thread_video_num:]
109 | else:
110 | sub_video_list = video_list[i * per_thread_video_num: (i + 1) * per_thread_video_num]
111 | p = mp.Process(target=sub_processor, args=(lock, i, args, data,
112 | save_path_prefix, save_visualize_path_prefix,
113 | img_folder, sub_video_list))
114 | p.start()
115 | processes.append(p)
116 |
117 | for p in processes:
118 | p.join()
119 |
120 | end_time = time.time()
121 | total_time = end_time - start_time
122 |
123 | result_dict = dict(result_dict)
124 | num_all_frames_gpus = 0
125 | for pid, num_all_frames in result_dict.items():
126 | num_all_frames_gpus += num_all_frames
127 |
128 | print("Total inference time: %.4f s" %(total_time))
129 |
130 | def sub_processor(lock, pid, args, data, save_path_prefix, save_visualize_path_prefix, img_folder, video_list):
131 | text = 'processor %d' % pid
132 | with lock:
133 | progress = tqdm(
134 | total=len(video_list),
135 | position=pid,
136 | desc=text,
137 | ncols=0
138 | )
139 | torch.cuda.set_device(pid)
140 |
141 | # model
142 | model, criterion, _ = build_model(args)
143 | device = args.device
144 | model.to(device)
145 |
146 | model_without_ddp = model
147 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
148 |
149 | if pid == 0:
150 | print('number of params:', n_parameters)
151 |
152 | if args.resume:
153 | checkpoint = torch.load(args.resume, map_location='cpu')
154 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
155 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
156 | if len(missing_keys) > 0:
157 | print('Missing Keys: {}'.format(missing_keys))
158 | if len(unexpected_keys) > 0:
159 | print('Unexpected Keys: {}'.format(unexpected_keys))
160 | else:
161 | raise ValueError('Please specify the checkpoint for inference.')
162 |
163 |
164 | # start inference
165 | num_all_frames = 0
166 | model.eval()
167 |
168 | sentence_features = []
169 | pseudo_sentence_features = []
170 | # 1. For each video
171 | for vid, video in enumerate(video_list):
172 | metas = [] # list[dict], length is number of expressions
173 |
174 | expressions = data[video]["expressions"]
175 | expression_list = list(expressions.keys())
176 | num_expressions = len(expression_list)
177 | video_len = len(data[video]["frames"])
178 |
179 | # read all the anno meta
180 | for i in range(num_expressions):
181 | meta = {}
182 | meta["video"] = video
183 | meta["exp"] = expressions[expression_list[i]]["exp"]
184 | meta["exp_id"] = expression_list[i]
185 | meta["frames"] = data[video]["frames"]
186 | metas.append(meta)
187 | meta = metas
188 |
189 | vis = False
190 | # 2. For each expression
191 | for i in range(num_expressions):
192 | video_name = meta[i]["video"]
193 | exp = meta[i]["exp"]
194 | # TODO: temp
195 | exp_id = meta[i]["exp_id"]
196 | frames = meta[i]["frames"]
197 | if vis:
198 | os.makedirs('vis/tmp', exist_ok=True)
199 |
200 | video_len = len(frames)
201 | # store images
202 | imgs = []
203 | for t in range(video_len):
204 | frame = frames[t]
205 | img_path = os.path.join(img_folder, video_name, frame + ".jpg")
206 | # img_path = frame
207 | img = Image.open(img_path).convert('RGB')
208 | if vis:
209 | if t < 1:
210 | img.save(f'vis/tmp/{t}.png')
211 | origin_w, origin_h = img.size
212 | imgs.append(transform(img)) # list[img]
213 |
214 | segm_logits = []
215 | segm_masks = []
216 |
217 | for segm_idx in range(video_len // args.segm_frame):
218 | if segm_idx != (video_len // args.segm_frame) - 1:
219 | segm = imgs[segm_idx * args.segm_frame:(segm_idx+1) * args.segm_frame]
220 | else:
221 | segm = imgs[segm_idx * args.segm_frame:]
222 |
223 | segm_len = len(segm)
224 | segm = torch.stack(segm, dim=0).to(args.device) # [video_len, 3, h, w]
225 | img_h, img_w = segm.shape[-2:]
226 | size = torch.as_tensor([int(img_h), int(img_w)]).to(args.device)
227 | target = {"size": size}
228 |
229 | # start = time.time()
230 | with torch.no_grad():
231 | outputs = model([segm], [exp], [target])
232 | # end = time.time()
233 | # print((end-start)/video_len)
234 |
235 | if vis:
236 | os.system(f'mv vis/tmp vis/{vid}-{exp.replace(" ", "_")}')
237 |
238 | pred_logits = outputs["pred_logits"][0]
239 | pred_boxes = outputs["pred_boxes"][0]
240 | pred_masks = outputs["pred_masks"][0]
241 | pred_ref_points = outputs["reference_points"][0]
242 | text_sentence_features = outputs['sentence_feature']
243 | if args.use_cycle:
244 | pseudo_text_sentence_features = outputs['pseudo_sentence_feature']
245 | # anchor = outputs['negative_anchor']
246 | sentence_features.append(text_sentence_features)
247 | pseudo_sentence_features.append(pseudo_text_sentence_features)
248 | # print(F.pairwise_distance(text_sentence_features, pseudo_text_sentence_features.squeeze(0), p=2))
249 | # print(anchor)
250 | # according to pred_logits, select the query index
251 | pred_scores = pred_logits.sigmoid() # [t, q, k]
252 | pred_score = pred_scores
253 | pred_scores = pred_scores.mean(0) # [q, k]
254 | max_scores, _ = pred_scores.max(-1) # [q,]
255 | # print(max_scores)
256 | _, max_ind = max_scores.max(-1) # [1,]
257 | max_inds = max_ind.repeat(segm_len)
258 | pred_masks = pred_masks[range(segm_len), max_inds, ...] # [t, h, w]
259 | pred_masks = pred_masks.unsqueeze(0)
260 | pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False)
261 | if args.save_prob:
262 | pred_masks = pred_masks.sigmoid().squeeze(0).detach().cpu().numpy()
263 | else:
264 | pred_masks = (pred_masks.sigmoid() > args.threshold).squeeze(0).detach().cpu().numpy()
265 | if args.use_score:
266 | pred_score = pred_score[range(segm_len), max_inds, 0].unsqueeze(-1).unsqueeze(-1)
267 |                         pred_masks = (pred_score > 0.3).cpu().numpy() * pred_masks  # keep masks only where the score exceeds 0.3
268 |
269 | # store the video results
270 | all_pred_logits = pred_logits[range(segm_len), max_inds]
271 | all_pred_boxes = pred_boxes[range(segm_len), max_inds]
272 | all_pred_ref_points = pred_ref_points[range(segm_len), max_inds]
273 | all_pred_masks = pred_masks
274 | segm_logits.append(all_pred_logits.cpu().numpy())
275 | segm_masks.append(all_pred_masks)
276 |
277 | all_pred_logits = np.concatenate(segm_logits, axis=0)
278 | all_pred_masks = np.concatenate(segm_masks, axis=0)
279 | assert all_pred_masks.shape[0] == video_len
280 | # save binary image
281 | save_path = os.path.join(save_path_prefix, video_name, exp_id)
282 | if not os.path.exists(save_path):
283 | os.makedirs(save_path)
284 | for j in range(video_len):
285 | frame_name = frames[j]
286 | mask = all_pred_masks[j].astype(np.float32)
287 | save_file = os.path.join(save_path, frame_name + ".png")
288 | if 'pair_logits' in outputs.keys() and args.use_cls:
289 | mask *= 0 if outputs['pair_logits'].cpu().numpy() >= 0.5 else 1
290 | mask = Image.fromarray(mask * 255).convert('L')
291 | mask.save(save_file)
292 |
293 | with lock:
294 | progress.update(1)
295 | # bert_embed = torch.cat(sentence_features, dim=0)
296 | # np.save('output/bert_embed.npy', bert_embed.cpu().numpy())
297 |
298 | result_dict[str(pid)] = num_all_frames
299 | with lock:
300 | progress.close()
301 |
302 |
303 | # visualize functions
304 | def box_cxcywh_to_xyxy(x):
305 | x_c, y_c, w, h = x.unbind(1)
306 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
307 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
308 | return torch.stack(b, dim=1)
309 |
310 | def rescale_bboxes(out_bbox, size):
311 | img_w, img_h = size
312 | b = box_cxcywh_to_xyxy(out_bbox)
313 | b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
314 | return b
315 |
316 |
317 | # Visualization functions
318 | def draw_reference_points(draw, reference_points, img_size, color):
319 | W, H = img_size
320 | for i, ref_point in enumerate(reference_points):
321 | init_x, init_y = ref_point
322 | x, y = W * init_x, H * init_y
323 | cur_color = color
324 | draw.line((x-10, y, x+10, y), tuple(cur_color), width=4)
325 | draw.line((x, y-10, x, y+10), tuple(cur_color), width=4)
326 |
327 | def draw_sample_points(draw, sample_points, img_size, color_list):
328 | alpha = 255
329 | for i, samples in enumerate(sample_points):
330 | for sample in samples:
331 | x, y = sample
332 | cur_color = color_list[i % len(color_list)][::-1]
333 | cur_color += [alpha]
334 | draw.ellipse((x-2, y-2, x+2, y+2),
335 | fill=tuple(cur_color), outline=tuple(cur_color), width=1)
336 |
337 | def vis_add_mask(img, mask, color):
338 | origin_img = np.asarray(img.convert('RGB')).copy()
339 | color = np.array(color)
340 |
341 | mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8') # np
342 | mask = mask > 0.5
343 |
344 | origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5
345 | origin_img = Image.fromarray(origin_img)
346 | return origin_img
347 |
348 |
349 |
350 | if __name__ == '__main__':
351 | parser = argparse.ArgumentParser('ReferFormer inference script', parents=[opts.get_args_parser()])
352 | args = parser.parse_args()
353 | main(args)
354 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | """
2 | Training script of ReferFormer
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import argparse
6 | import datetime
7 | import json
8 | import random
9 | import time
10 | from pathlib import Path
11 |
12 | import numpy as np
13 | import torch
14 | from torch.utils.data import DataLoader, DistributedSampler
15 | import torch.distributed as dist
16 |
17 | import util.misc as utils
18 | import datasets.samplers as samplers
19 | from datasets import build_dataset, get_coco_api_from_dataset
20 | from engine import train_one_epoch, evaluate, evaluate_a2d
21 | from models import build_model
22 |
23 | from tools.load_pretrained_weights import pre_trained_model_to_finetune
24 | from tools.warmup_poly_lr_scheduler import WarmupPolyLR
25 | import opts
26 |
27 |
28 |
29 | def main(args):
30 | args.masks = True
31 |
32 | utils.init_distributed_mode(args)
33 | print("git:\n {}\n".format(utils.get_sha()))
34 | print(args)
35 |
36 | print(f'\n Run on {args.dataset_file} dataset.')
37 | print('\n')
38 |
39 | device = torch.device(args.device)
40 |
41 | # fix the seed for reproducibility
42 | seed = args.seed + utils.get_rank()
43 | torch.manual_seed(seed)
44 | np.random.seed(seed)
45 | random.seed(seed)
46 |
47 | model, criterion, postprocessor = build_model(args)
48 | model.to(device)
49 |
50 | model_without_ddp = model
51 | if args.distributed:
52 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
53 | model_without_ddp = model.module
54 |
55 | # for n, p in model_without_ddp.named_parameters():
56 | # if p.requires_grad:
57 | # print(n)
58 |
59 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
60 | print('number of params:', n_parameters)
61 |
62 | def match_name_keywords(n, name_keywords):
63 | out = False
64 | for b in name_keywords:
65 | if b in n:
66 | out = True
67 | break
68 | return out
69 |
70 | if args.only_cycle:
71 | for n, p in model_without_ddp.named_parameters():
72 |             if not (match_name_keywords(n, ['text_query']) or match_name_keywords(n, ['text_decoder'])):  # keywords must be passed as a list of substrings
73 | p.requires_grad = False
74 |
75 | param_dicts = [
76 | {
77 | "params":
78 | [p for n, p in model_without_ddp.named_parameters()
79 | if not match_name_keywords(n, args.lr_backbone_names) and not match_name_keywords(n, args.lr_text_encoder_names)
80 | and not match_name_keywords(n, args.lr_linear_proj_names) and not match_name_keywords(n, args.lr_anchor_names) and p.requires_grad],
81 | "lr": args.lr * args.lr_multi,
82 | },
83 | {
84 | "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad],
85 | "lr": args.lr_backbone * args.lr_multi,
86 | },
87 | {
88 | "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_text_encoder_names) and p.requires_grad],
89 | "lr": args.lr_text_encoder * args.lr_multi,
90 | },
91 | {
92 | "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
93 | "lr": args.lr * args.lr_linear_proj_mult * args.lr_multi,
94 | },
95 | {
96 | "params": [p for n, p in model_without_ddp.named_parameters() if
97 | match_name_keywords(n, args.lr_anchor_names) and p.requires_grad],
98 | "lr": args.lr * args.lr_anchor_mult * args.lr_multi,
99 | }
100 | ]
101 |
102 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
103 | weight_decay=args.weight_decay)
104 |
105 | # no validation ground truth for ytvos dataset
106 | dataset_train = build_dataset(args.dataset_file, image_set='train', args=args)
107 |
108 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_drop)
109 |
110 | # max_iter = len(dataset_train) / dist.get_world_size() / args.batch_size * args.epochs
111 | # lr_scheduler = WarmupPolyLR(optimizer, max_iter, warmup_iters=100, warmup_factor=0.01)
112 | # print(f"dataset size: {len(dataset_train)}, max_iter: {max_iter}")
113 |
114 | if args.distributed:
115 | if args.cache_mode:
116 | sampler_train = samplers.NodeDistributedSampler(dataset_train)
117 | else:
118 | sampler_train = samplers.DistributedSampler(dataset_train)
119 | else:
120 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
121 |
122 | batch_sampler_train = torch.utils.data.BatchSampler(
123 | sampler_train, args.batch_size, drop_last=True)
124 |
125 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
126 | collate_fn=utils.collate_fn, num_workers=args.num_workers)
127 |
128 | # A2D-Sentences
129 | if args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb' or args.dataset_file == 'refcocoa2d' or args.dataset_file=='refcocoVideo':
130 | dataset_val = build_dataset(args.dataset_file, image_set='val', args=args)
131 | if args.distributed:
132 | if args.cache_mode:
133 | sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
134 | else:
135 | sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
136 | else:
137 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
138 | data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
139 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers,
140 | pin_memory=True)
141 |
142 |
143 | if args.dataset_file == "davis":
144 |         assert args.pretrained_weights is not None, "Please provide pretrained weights to finetune on Ref-DAVIS17"
145 | print("============================================>")
146 |         print("Ref-DAVIS17 is finetuned using the checkpoint trained on Ref-Youtube-VOS")
147 | print("Load checkpoint weights from {} ...".format(args.pretrained_weights))
148 | checkpoint = torch.load(args.pretrained_weights, map_location="cpu")
149 | checkpoint_dict = pre_trained_model_to_finetune(checkpoint, args)
150 | model_without_ddp.load_state_dict(checkpoint_dict, strict=False)
151 | print("============================================>")
152 |
153 | if args.dataset_file == "jhmdb":
154 | assert args.resume is not None, "Please provide the checkpoint to resume for JHMDB-Sentences"
155 | print("============================================>")
156 |         print("JHMDB-Sentences is directly evaluated using the checkpoint trained on A2D-Sentences")
157 |         print("Load checkpoint weights from {} ...".format(args.resume))
158 | # load checkpoint in the args.resume
159 | print("============================================>")
160 |
161 | # for Ref-Youtube-VOS and A2D-Sentences
162 | # finetune using the pretrained weights on Ref-COCO
163 | if args.dataset_file != "davis" and args.dataset_file != "jhmdb" and args.pretrained_weights is not None:
164 | print("============================================>")
165 | print("Load pretrained weights from {} ...".format(args.pretrained_weights))
166 | checkpoint = torch.load(args.pretrained_weights, map_location="cpu")
167 | checkpoint_dict = pre_trained_model_to_finetune(checkpoint, args)
168 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint_dict, strict=False)
169 | print(checkpoint_dict.keys())
170 | print("============================================>")
171 | print(missing_keys)
172 | print("============================================>")
173 | print(unexpected_keys)
174 | print("============================================>")
175 |
176 |
177 | output_dir = Path(args.output_dir)
178 | if args.resume:
179 | if args.resume.startswith('https'):
180 | checkpoint = torch.hub.load_state_dict_from_url(
181 | args.resume, map_location='cpu', check_hash=True)
182 | else:
183 | checkpoint = torch.load(args.resume, map_location='cpu')
184 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
185 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
186 | if len(missing_keys) > 0:
187 | print('Missing Keys: {}'.format(missing_keys))
188 | if len(unexpected_keys) > 0:
189 | print('Unexpected Keys: {}'.format(unexpected_keys))
190 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
191 | import copy
192 | p_groups = copy.deepcopy(optimizer.param_groups)
193 | optimizer.load_state_dict(checkpoint['optimizer'])
194 | for pg, pg_old in zip(optimizer.param_groups, p_groups):
195 | pg['lr'] = pg_old['lr']
196 | pg['initial_lr'] = pg_old['initial_lr']
197 | # print(optimizer.param_groups)
198 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
199 |             # TODO: this is a hack for experiments that resume from a checkpoint and also modify the lr scheduler (e.g., decrease the lr in advance).
200 | args.override_resumed_lr_drop = True
201 | if args.override_resumed_lr_drop:
202 | print('Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler.')
203 | lr_scheduler.step_size = args.lr_drop
204 | lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
205 | lr_scheduler.step(lr_scheduler.last_epoch)
206 | args.start_epoch = checkpoint['epoch'] + 1
207 |
208 | if args.eval:
209 | assert args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb' or args.dataset_file == 'refcocoVideo', \
210 |             'Only A2D-Sentences, JHMDB-Sentences and refcocoVideo datasets support evaluation'
211 | test_stats = evaluate_a2d(model, data_loader_val, postprocessor, device, args)
212 | return
213 |
214 |
215 | print("Start training")
216 | start_time = time.time()
217 | for epoch in range(args.start_epoch, args.epochs):
218 | if args.distributed:
219 | sampler_train.set_epoch(epoch)
220 | train_stats = train_one_epoch(
221 | model, criterion, data_loader_train, optimizer, device, lr_scheduler, epoch,
222 | args.clip_max_norm)
223 | lr_scheduler.step()
224 | if args.output_dir:
225 | checkpoint_paths = [output_dir / 'checkpoint.pth']
226 |             # extra checkpoint before LR drop and after every epoch
227 | # if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 1 == 0:
228 | if (epoch + 1) % 1 == 0:
229 | checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
230 | for checkpoint_path in checkpoint_paths:
231 | utils.save_on_master({
232 | 'model': model_without_ddp.state_dict(),
233 | 'optimizer': optimizer.state_dict(),
234 | 'lr_scheduler': lr_scheduler.state_dict(),
235 | 'epoch': epoch,
236 | 'args': args,
237 | }, checkpoint_path)
238 |
239 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
240 | 'epoch': epoch,
241 | 'n_parameters': n_parameters}
242 |
243 | if args.dataset_file == 'a2d' or args.dataset_file == 'refcocoa2d':
244 | test_stats = evaluate_a2d(model, data_loader_val, postprocessor, device, args)
245 | log_stats.update({**{f'{k}': v for k, v in test_stats.items()}})
246 |
247 | if args.output_dir and utils.is_main_process():
248 | with (output_dir / "log.txt").open("a") as f:
249 | f.write(json.dumps(log_stats) + "\n")
250 |
251 |
252 | total_time = time.time() - start_time
253 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
254 | print('Training time {}'.format(total_time_str))
255 |
256 |
257 | if __name__ == '__main__':
258 | parser = argparse.ArgumentParser('RVOSNet training and evaluation script', parents=[opts.get_args_parser()])
259 | args = parser.parse_args()
260 | if args.output_dir:
261 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
262 | main(args)
263 |
264 |
265 |
266 |
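
The optimizer block above assigns separate learning rates to parameter groups by substring-matching parameter names (backbone, text encoder, linear projections, anchors). A minimal sketch of the same pattern on a toy two-module model, with illustrative names and rates rather than the repo's defaults:

```python
import torch
from torch import nn

# toy stand-in: one "backbone" module and one "head" module
model = nn.ModuleDict({'backbone': nn.Linear(8, 8), 'head': nn.Linear(8, 2)})

def match_name_keywords(n, name_keywords):
    return any(k in n for k in name_keywords)

lr, lr_backbone = 1e-4, 5e-5  # illustrative values
param_dicts = [
    {"params": [p for n, p in model.named_parameters()
                if not match_name_keywords(n, ['backbone']) and p.requires_grad],
     "lr": lr},
    {"params": [p for n, p in model.named_parameters()
                if match_name_keywords(n, ['backbone']) and p.requires_grad],
     "lr": lr_backbone},
]
optimizer = torch.optim.AdamW(param_dicts, lr=lr, weight_decay=1e-4)
print([g["lr"] for g in optimizer.param_groups])  # [0.0001, 5e-05]
```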
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .referformer import build
2 |
3 | def build_model(args):
4 | return build(args)
5 |
--------------------------------------------------------------------------------
/models/amm_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import Tensor
3 | import torch.nn as nn
4 | from torchvision.models.utils import load_state_dict_from_url
5 | from typing import Type, Any, Callable, Union, List, Optional
6 |
7 |
8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'amm_resnet50', 'amm_resnet101',
9 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
10 | 'wide_resnet50_2', 'wide_resnet101_2']
11 |
12 |
13 | model_urls = {
14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
23 | }
24 |
25 |
26 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
27 | """3x3 convolution with padding"""
28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
29 | padding=dilation, groups=groups, bias=False, dilation=dilation)
30 |
31 |
32 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
33 | """1x1 convolution"""
34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
35 |
36 |
37 | class BasicBlock(nn.Module):
38 | expansion: int = 1
39 |
40 | def __init__(
41 | self,
42 | inplanes: int,
43 | planes: int,
44 | stride: int = 1,
45 | downsample: Optional[nn.Module] = None,
46 | groups: int = 1,
47 | base_width: int = 64,
48 | dilation: int = 1,
49 | norm_layer: Optional[Callable[..., nn.Module]] = None
50 | ) -> None:
51 | super(BasicBlock, self).__init__()
52 | if norm_layer is None:
53 | norm_layer = nn.BatchNorm2d
54 | if groups != 1 or base_width != 64:
55 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
56 | if dilation > 1:
57 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
58 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
59 | self.conv1 = conv3x3(inplanes, planes, stride)
60 | self.bn1 = norm_layer(planes)
61 | self.relu = nn.ReLU(inplace=True)
62 | self.conv2 = conv3x3(planes, planes)
63 | self.bn2 = norm_layer(planes)
64 | self.downsample = downsample
65 | self.stride = stride
66 |
67 | def forward(self, x: Tensor) -> Tensor:
68 | identity = x
69 |
70 | out = self.conv1(x)
71 | out = self.bn1(out)
72 | out = self.relu(out)
73 |
74 | out = self.conv2(out)
75 | out = self.bn2(out)
76 |
77 | if self.downsample is not None:
78 | identity = self.downsample(x)
79 |
80 | out += identity
81 | out = self.relu(out)
82 |
83 | return out
84 |
85 |
86 | class Bottleneck(nn.Module):
87 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
88 | # while original implementation places the stride at the first 1x1 convolution(self.conv1)
89 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
90 | # This variant is also known as ResNet V1.5 and improves accuracy according to
91 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
92 |
93 | expansion: int = 4
94 |
95 | def __init__(
96 | self,
97 | inplanes: int,
98 | planes: int,
99 | stride: int = 1,
100 | downsample: Optional[nn.Module] = None,
101 | groups: int = 1,
102 | base_width: int = 64,
103 | dilation: int = 1,
104 | norm_layer: Optional[Callable[..., nn.Module]] = None
105 | ) -> None:
106 | super(Bottleneck, self).__init__()
107 | if norm_layer is None:
108 | norm_layer = nn.BatchNorm2d
109 | width = int(planes * (base_width / 64.)) * groups
110 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1
111 | self.conv1 = conv1x1(inplanes, width)
112 | self.bn1 = norm_layer(width)
113 | self.conv2 = conv3x3(width, width, stride, groups, dilation)
114 | self.bn2 = norm_layer(width)
115 | self.conv3 = conv1x1(width, planes * self.expansion)
116 | self.bn3 = norm_layer(planes * self.expansion)
117 | self.relu = nn.ReLU(inplace=True)
118 | self.downsample = downsample
119 | self.stride = stride
120 |
121 | def forward(self, x: Tensor) -> Tensor:
122 | identity = x
123 |
124 | out = self.conv1(x)
125 | out = self.bn1(out)
126 | out = self.relu(out)
127 |
128 | out = self.conv2(out)
129 | out = self.bn2(out)
130 | out = self.relu(out)
131 |
132 | out = self.conv3(out)
133 | out = self.bn3(out)
134 |
135 | if self.downsample is not None:
136 | identity = self.downsample(x)
137 |
138 | out += identity
139 | out = self.relu(out)
140 |
141 | return out
142 |
143 |
144 | class ResNet(nn.Module):
145 |
146 | def __init__(
147 | self,
148 | block: Type[Union[BasicBlock, Bottleneck]],
149 | layers: List[int],
150 | num_classes: int = 1000,
151 | zero_init_residual: bool = False,
152 | groups: int = 1,
153 | width_per_group: int = 64,
154 | replace_stride_with_dilation: Optional[List[bool]] = None,
155 | norm_layer: Optional[Callable[..., nn.Module]] = None
156 | ) -> None:
157 | super(ResNet, self).__init__()
158 | if norm_layer is None:
159 | norm_layer = nn.BatchNorm2d
160 | self._norm_layer = norm_layer
161 |
162 | self.inplanes = 64
163 | self.dilation = 1
164 | if replace_stride_with_dilation is None:
165 | # each element in the tuple indicates if we should replace
166 | # the 2x2 stride with a dilated convolution instead
167 | replace_stride_with_dilation = [False, False, False]
168 | if len(replace_stride_with_dilation) != 3:
169 | raise ValueError("replace_stride_with_dilation should be None "
170 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
171 | self.groups = groups
172 | self.base_width = width_per_group
173 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
174 | bias=False)
175 | self.bn1 = norm_layer(self.inplanes)
176 | self.relu = nn.ReLU(inplace=True)
177 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
178 | self.layer1 = self._make_layer(block, 64, layers[0])
179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
180 | dilate=replace_stride_with_dilation[0])
181 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
182 | dilate=replace_stride_with_dilation[1])
183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
184 | dilate=replace_stride_with_dilation[2])
185 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
186 | self.fc = nn.Linear(512 * block.expansion, num_classes)
187 |
188 | for m in self.modules():
189 | if isinstance(m, nn.Conv2d):
190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
192 | nn.init.constant_(m.weight, 1)
193 | nn.init.constant_(m.bias, 0)
194 |
195 | # Zero-initialize the last BN in each residual branch,
196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
198 | if zero_init_residual:
199 | for m in self.modules():
200 | if isinstance(m, Bottleneck):
201 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
202 | elif isinstance(m, BasicBlock):
203 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
204 |
205 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
206 | stride: int = 1, dilate: bool = False) -> nn.Sequential:
207 | norm_layer = self._norm_layer
208 | downsample = None
209 | previous_dilation = self.dilation
210 | if dilate:
211 | self.dilation *= stride
212 | stride = 1
213 | if stride != 1 or self.inplanes != planes * block.expansion:
214 | downsample = nn.Sequential(
215 | conv1x1(self.inplanes, planes * block.expansion, stride),
216 | norm_layer(planes * block.expansion),
217 | )
218 |
219 | layers = []
220 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
221 | self.base_width, previous_dilation, norm_layer))
222 | self.inplanes = planes * block.expansion
223 | for _ in range(1, blocks):
224 | layers.append(block(self.inplanes, planes, groups=self.groups,
225 | base_width=self.base_width, dilation=self.dilation,
226 | norm_layer=norm_layer))
227 |
228 | return nn.Sequential(*layers)
229 |
230 | def _forward_impl(self, x: Tensor) -> Tensor:
231 | # See note [TorchScript super()]
232 | x = self.conv1(x)
233 | x = self.bn1(x)
234 | x = self.relu(x)
235 | x = self.maxpool(x)
236 |
237 | x = self.layer1(x)
238 | x = self.layer2(x)
239 | x = self.layer3(x)
240 | x = self.layer4(x)
241 |
242 | x = self.avgpool(x)
243 | x = torch.flatten(x, 1)
244 | x = self.fc(x)
245 |
246 | return x
247 |
248 | def forward(self, x: Tensor) -> Tensor:
249 | return self._forward_impl(x)
250 |
251 |
252 | def _resnet(
253 | arch: str,
254 | block: Type[Union[BasicBlock, Bottleneck]],
255 | layers: List[int],
256 | pretrained: bool,
257 | progress: bool,
258 | **kwargs: Any
259 | ) -> ResNet:
260 | model = ResNet(block, layers, **kwargs)
261 | if pretrained:
262 | state_dict = load_state_dict_from_url(model_urls[arch],
263 | progress=progress)
264 | model.load_state_dict(state_dict)
265 | return model
266 |
267 |
268 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
269 | r"""ResNet-18 model from
270 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
271 |
272 | Args:
273 | pretrained (bool): If True, returns a model pre-trained on ImageNet
274 | progress (bool): If True, displays a progress bar of the download to stderr
275 | """
276 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
277 | **kwargs)
278 |
279 |
280 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
281 | r"""ResNet-34 model from
282 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
283 |
284 | Args:
285 | pretrained (bool): If True, returns a model pre-trained on ImageNet
286 | progress (bool): If True, displays a progress bar of the download to stderr
287 | """
288 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
289 | **kwargs)
290 |
291 |
292 | def amm_resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
293 | r"""ResNet-50 model from
294 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
295 |
296 | Args:
297 | pretrained (bool): If True, returns a model pre-trained on ImageNet
298 | progress (bool): If True, displays a progress bar of the download to stderr
299 | """
300 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
301 | **kwargs)
302 |
303 |
304 | def amm_resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
305 | r"""ResNet-101 model from
306 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
307 |
308 | Args:
309 | pretrained (bool): If True, returns a model pre-trained on ImageNet
310 | progress (bool): If True, displays a progress bar of the download to stderr
311 | """
312 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
313 | **kwargs)
314 |
315 |
316 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
317 | r"""ResNet-152 model from
318 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_.
319 |
320 | Args:
321 | pretrained (bool): If True, returns a model pre-trained on ImageNet
322 | progress (bool): If True, displays a progress bar of the download to stderr
323 | """
324 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
325 | **kwargs)
326 |
327 |
328 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
329 | r"""ResNeXt-50 32x4d model from
330 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
331 |
332 | Args:
333 | pretrained (bool): If True, returns a model pre-trained on ImageNet
334 | progress (bool): If True, displays a progress bar of the download to stderr
335 | """
336 | kwargs['groups'] = 32
337 | kwargs['width_per_group'] = 4
338 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
339 | pretrained, progress, **kwargs)
340 |
341 |
342 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
343 | r"""ResNeXt-101 32x8d model from
344 |     `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/abs/1611.05431>`_.
345 |
346 | Args:
347 | pretrained (bool): If True, returns a model pre-trained on ImageNet
348 | progress (bool): If True, displays a progress bar of the download to stderr
349 | """
350 | kwargs['groups'] = 32
351 | kwargs['width_per_group'] = 8
352 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
353 | pretrained, progress, **kwargs)
354 |
355 |
356 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
357 | r"""Wide ResNet-50-2 model from
358 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
359 |
360 | The model is the same as ResNet except for the bottleneck number of channels
361 | which is twice larger in every block. The number of channels in outer 1x1
362 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
363 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
364 |
365 | Args:
366 | pretrained (bool): If True, returns a model pre-trained on ImageNet
367 | progress (bool): If True, displays a progress bar of the download to stderr
368 | """
369 | kwargs['width_per_group'] = 64 * 2
370 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
371 | pretrained, progress, **kwargs)
372 |
373 |
374 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
375 | r"""Wide ResNet-101-2 model from
376 |     `"Wide Residual Networks" <https://arxiv.org/abs/1605.07146>`_.
377 |
378 | The model is the same as ResNet except for the bottleneck number of channels
379 | which is twice larger in every block. The number of channels in outer 1x1
380 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
381 | channels, and in Wide ResNet-50-2 has 2048-1024-2048.
382 |
383 | Args:
384 | pretrained (bool): If True, returns a model pre-trained on ImageNet
385 | progress (bool): If True, displays a progress bar of the download to stderr
386 | """
387 | kwargs['width_per_group'] = 64 * 2
388 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
389 | pretrained, progress, **kwargs)
390 |
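
A smoke test for the factory functions above, with random weights and a dummy batch (shapes only; set `pretrained=True` only when the torchvision download URLs are reachable):

```python
import torch

model = amm_resnet50(pretrained=False, num_classes=1000)   # defined above
logits = model(torch.randn(2, 3, 224, 224))                # dummy batch
print(logits.shape)                                        # torch.Size([2, 1000])

# dilate the last stage instead of striding (stride-16 features before pooling)
dilated = amm_resnet50(pretrained=False,
                       replace_stride_with_dilation=[False, False, True])
```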
--------------------------------------------------------------------------------
/models/backbone.py:
--------------------------------------------------------------------------------
1 | """
2 | Backbone modules.
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import os
6 | import sys
7 | from collections import OrderedDict
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | import torchvision
12 | from torch import nn
13 | # from torchvision.models._utils import IntermediateLayerGetter
14 | from typing import Dict, List
15 | from einops import rearrange
16 |
17 | from .amm_resnet import amm_resnet50, amm_resnet101
18 | from util.misc import NestedTensor, is_main_process
19 |
20 | from .position_encoding import build_position_encoding
21 |
22 |
23 | class FrozenBatchNorm2d(torch.nn.Module):
24 | """
25 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
26 |
27 |     Copy-paste from torchvision.misc.ops with added eps before rsqrt,
28 | without which any other models than torchvision.models.resnet[18,34,50,101]
29 | produce nans.
30 | """
31 |
32 | def __init__(self, n):
33 | super(FrozenBatchNorm2d, self).__init__()
34 | self.register_buffer("weight", torch.ones(n))
35 | self.register_buffer("bias", torch.zeros(n))
36 | self.register_buffer("running_mean", torch.zeros(n))
37 | self.register_buffer("running_var", torch.ones(n))
38 |
39 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
40 | missing_keys, unexpected_keys, error_msgs):
41 | num_batches_tracked_key = prefix + 'num_batches_tracked'
42 | if num_batches_tracked_key in state_dict:
43 | del state_dict[num_batches_tracked_key]
44 |
45 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
46 | state_dict, prefix, local_metadata, strict,
47 | missing_keys, unexpected_keys, error_msgs)
48 |
49 | def forward(self, x):
50 | # move reshapes to the beginning
51 | # to make it fuser-friendly
52 | w = self.weight.reshape(1, -1, 1, 1)
53 | b = self.bias.reshape(1, -1, 1, 1)
54 | rv = self.running_var.reshape(1, -1, 1, 1)
55 | rm = self.running_mean.reshape(1, -1, 1, 1)
56 | eps = 1e-5
57 | scale = w * (rv + eps).rsqrt()
58 | bias = b - rm * scale
59 | return x * scale + bias
60 |
61 |
62 | class BackboneBase(nn.Module):
63 |
64 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool):
65 | super().__init__()
66 | for name, parameter in backbone.named_parameters():
67 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
68 | parameter.requires_grad_(False)
69 | if return_interm_layers:
70 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
71 | # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} deformable detr
72 | self.strides = [4, 8, 16, 32]
73 | self.num_channels = [256, 512, 1024, 2048]
74 | else:
75 | return_layers = {'layer4': "0"}
76 | self.strides = [32]
77 | self.num_channels = [2048]
78 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
79 |
80 | def forward(self, tensor_list: NestedTensor):
81 | xs = self.body(tensor_list.tensors)
82 | out: Dict[str, NestedTensor] = {}
83 | for name, x in xs.items():
84 | m = tensor_list.mask
85 | assert m is not None
86 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
87 | out[name] = NestedTensor(x, mask)
88 | return out
89 |
90 |
91 | class Backbone(BackboneBase):
92 | """ResNet backbone with frozen BatchNorm."""
93 | def __init__(self, name: str,
94 | train_backbone: bool,
95 | return_interm_layers: bool,
96 | dilation: bool):
97 | backbone = getattr(torchvision.models, name)(
98 | replace_stride_with_dilation=[False, False, dilation],
99 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
100 |         assert name not in ('resnet18', 'resnet34'), "the number of channels is hard-coded"
101 | super().__init__(backbone, train_backbone, return_interm_layers)
102 | if dilation:
103 | self.strides[-1] = self.strides[-1] // 2
104 |
105 | class AMM_Backbone(BackboneBase):
106 | """ResNet backbone with frozen BatchNorm."""
107 | def __init__(self, name: str,
108 | train_backbone: bool,
109 | return_interm_layers: bool,
110 | dilation: bool):
111 | assert name in ['amm_resnet50', 'amm_resnet101'], 'unsupported backbone type'
112 |         net = amm_resnet50 if name == 'amm_resnet50' else amm_resnet101
113 | backbone = net(replace_stride_with_dilation=[False, False, dilation],
114 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
115 |         assert name not in ('resnet18', 'resnet34'), "the number of channels is hard-coded"
116 | super().__init__(backbone, train_backbone, return_interm_layers)
117 | if dilation:
118 | self.strides[-1] = self.strides[-1] // 2
119 |
120 | class Joiner(nn.Sequential):
121 | def __init__(self, backbone, position_embedding):
122 | super().__init__(backbone, position_embedding)
123 | self.strides = backbone.strides
124 | self.num_channels = backbone.num_channels
125 |
126 |
127 | def forward(self, tensor_list: NestedTensor):
128 | tensor_list.tensors = rearrange(tensor_list.tensors, 'b t c h w -> (b t) c h w')
129 | tensor_list.mask = rearrange(tensor_list.mask, 'b t h w -> (b t) h w')
130 |
131 | xs = self[0](tensor_list)
132 | out: List[NestedTensor] = []
133 | pos = []
134 | for name, x in xs.items():
135 | out.append(x)
136 | # position encoding
137 | pos.append(self[1](x).to(x.tensors.dtype))
138 | return out, pos
139 |
140 | class IntermediateLayerGetter(nn.ModuleDict):
141 | """
142 | Module wrapper that returns intermediate layers from a model
143 |
144 | It has a strong assumption that the modules have been registered
145 | into the model in the same order as they are used.
146 | This means that one should **not** reuse the same nn.Module
147 | twice in the forward if you want this to work.
148 |
149 | Additionally, it is only able to query submodules that are directly
150 | assigned to the model. So if `model` is passed, `model.feature1` can
151 | be returned, but not `model.feature1.layer2`.
152 |
153 | Args:
154 | model (nn.Module): model on which we will extract the features
155 | return_layers (Dict[name, new_name]): a dict containing the names
156 | of the modules for which the activations will be returned as
157 | the key of the dict, and the value of the dict is the name
158 | of the returned activation (which the user can specify).
159 | """
160 | _version = 2
161 | __annotations__ = {
162 | "return_layers": Dict[str, str],
163 | }
164 |
165 | def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
166 | if not set(return_layers).issubset([name for name, _ in model.named_children()]):
167 | raise ValueError("return_layers are not present in model")
168 | orig_return_layers = return_layers
169 | return_layers = {str(k): str(v) for k, v in return_layers.items()}
170 | layers = OrderedDict()
171 | for name, module in model.named_children():
172 | layers[name] = module
173 | if name in return_layers:
174 | del return_layers[name]
175 | if not return_layers:
176 | break
177 |
178 | super(IntermediateLayerGetter, self).__init__(layers)
179 | self.return_layers = orig_return_layers
180 |
181 | def forward(self, x):
182 | out = OrderedDict()
183 | for name, module in self.items():
184 | x = module(x)
185 | if name in self.return_layers:
186 | out_name = self.return_layers[name]
187 | out[out_name] = x
188 | return out
189 |
190 |
191 | def build_backbone(args):
192 | position_embedding = build_position_encoding(args)
193 | train_backbone = args.lr_backbone > 0
194 |     return_interm_layers = args.masks or (args.num_feature_levels > 1)
195 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
196 | model = Joiner(backbone, position_embedding)
197 | model.num_channels = backbone.num_channels
198 | return model
199 |
200 | def build_amm_backbone(args):
201 | position_embedding = build_position_encoding(args)
202 | train_backbone = args.lr_backbone > 0
203 |     return_interm_layers = args.masks or (args.num_feature_levels > 1)
204 | backbone = AMM_Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation)
205 | model = Joiner(backbone, position_embedding)
206 | model.num_channels = backbone.num_channels
207 | return model
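
To see what `IntermediateLayerGetter` returns for the `return_interm_layers=True` configuration used in `BackboneBase`, a standalone check on a plain torchvision ResNet-50 with a dummy input (the printed channels and strides match the hard-coded `self.strides` / `self.num_channels` above):

```python
import torch
import torchvision

resnet = torchvision.models.resnet50()   # random weights are enough for a shape check
body = IntermediateLayerGetter(resnet, return_layers={"layer1": "0", "layer2": "1",
                                                      "layer3": "2", "layer4": "3"})
feats = body(torch.randn(1, 3, 224, 224))
for name, f in feats.items():
    print(name, tuple(f.shape))
# 0 (1, 256, 56, 56)    stride 4
# 1 (1, 512, 28, 28)    stride 8
# 2 (1, 1024, 14, 14)   stride 16
# 3 (1, 2048, 7, 7)     stride 32
```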
--------------------------------------------------------------------------------
/models/cycle.py:
--------------------------------------------------------------------------------
1 | from .segmentation import *
2 | from collections import defaultdict
3 | from typing import List, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 | from PIL import Image
10 | import cv2
11 | import numpy as np
12 | from einops import rearrange, repeat
13 |
14 | try:
15 | from panopticapi.utils import id2rgb, rgb2id
16 | except ImportError:
17 | pass
18 |
19 | import fvcore.nn.weight_init as weight_init
20 |
21 | from .position_encoding import PositionEmbeddingSine1D
22 |
23 | BN_MOMENTUM = 0.1
24 |
25 | class S2EPrime(nn.Module):
26 | def __init__(self, feature_channels: List, conv_dim: int, mask_dim: int, dim_feedforward: int = 2048, norm=None,
27 | return_query=False, stage_num=1):
28 | """
29 | Args:
30 | feature_channels: list of fpn feature channel numbers.
31 | conv_dim: number of output channels for the intermediate conv layers.
32 | mask_dim: number of output channels for the final conv layer.
33 | dim_feedforward: number of vision-language fusion module ffn channel numbers.
34 | norm (str or callable): normalization for all conv layers
35 | """
36 | super().__init__()
37 |
38 | self.feature_channels = feature_channels
39 |
40 | lateral_convs = []
41 | output_convs = []
42 |
43 | use_bias = norm == ""
44 | for idx, in_channels in enumerate(feature_channels):
45 | # in_channels: 4x -> 32x
46 | lateral_norm = get_norm(norm, conv_dim)
47 | output_norm = get_norm(norm, conv_dim)
48 |
49 | lateral_conv = Conv2d(
50 | in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm
51 | )
52 | output_conv = Conv2d(
53 | conv_dim,
54 | conv_dim,
55 | kernel_size=3,
56 | stride=1,
57 | padding=1,
58 | bias=use_bias,
59 | norm=output_norm,
60 | activation=F.relu,
61 | )
62 | weight_init.c2_xavier_fill(lateral_conv)
63 | weight_init.c2_xavier_fill(output_conv)
64 |
65 | stage = idx + 1
66 |
67 | self.add_module("adapter_{}".format(stage), lateral_conv)
68 | self.add_module("layer_{}".format(stage), output_conv)
69 | lateral_convs.append(lateral_conv)
70 | output_convs.append(output_conv)
71 |
72 | # Place convs into top-down order (from low to high resolution)
73 | # to make the top-down computation in forward clearer.
74 | self.lateral_convs = lateral_convs[::-1]
75 | self.output_convs = output_convs[::-1]
76 |
77 | self.mask_dim = mask_dim
78 | self.mask_features = Conv2d(
79 | conv_dim,
80 | mask_dim,
81 | kernel_size=3,
82 | stride=1,
83 | padding=1,
84 | )
85 | weight_init.c2_xavier_fill(self.mask_features)
86 |
87 | # vision-language cross-modal fusion
88 | self.text_pos = PositionEmbeddingSine1D(conv_dim, normalize=True)
89 | sr_ratios = [8, 4, 2, 1]
90 |
91 | self.stage_num = stage_num
92 | self.dyn_conv_list = []
93 | self.dyn_filter_list = []
94 | self.cross_attn_list = []
95 | self.pool_conv_list = []
96 | self.out_conv_list = []
97 | for i in range(self.stage_num):
98 | # init dyn conv
99 | output_norm = get_norm(norm, conv_dim)
100 | self.dyn_conv = Conv2d(
101 | conv_dim,
102 | conv_dim,
103 | kernel_size=1,
104 | stride=1,
105 | padding=0,
106 | bias=use_bias,
107 | norm=output_norm,
108 | activation=F.relu,
109 | )
110 | self.dyn_filter = nn.Linear(conv_dim, conv_dim, bias=True)
111 | weight_init.c2_xavier_fill(self.dyn_conv)
112 | weight_init.c2_xavier_fill(self.dyn_filter)
113 |
114 | # init text2video attn
115 | self.cross_attn = DualVisionLanguageBlock(conv_dim, dim_feedforward=dim_feedforward,
116 | nhead=8, sr_ratio=sr_ratios[-1])
117 | for p in self.cross_attn.parameters():
118 | if p.dim() > 1:
119 | nn.init.xavier_uniform_(p)
120 | text_pool_size = 3
121 | self.pool = nn.AdaptiveAvgPool1d(text_pool_size)
122 | self.pool_conv = Conv2d(
123 | text_pool_size*conv_dim,
124 | conv_dim,
125 | kernel_size=1,
126 | stride=1,
127 | padding=0,
128 | bias=use_bias,
129 | norm=output_norm,
130 | activation=F.relu,
131 | )
132 | weight_init.c2_xavier_fill(self.pool_conv)
133 | # init out conv for stage != 0
134 | if i != 0:
135 | output_norm = get_norm(norm, conv_dim)
136 | output_conv = Conv2d(
137 | conv_dim,
138 | conv_dim,
139 | kernel_size=3,
140 | stride=1,
141 | padding=1,
142 | bias=use_bias,
143 | norm=output_norm,
144 | activation=F.relu,
145 | )
146 | weight_init.c2_xavier_fill(output_conv)
147 | self.out_conv_list.append(output_conv)
148 | # add to list
149 | self.dyn_conv_list.append(self.dyn_conv)
150 | self.dyn_filter_list.append(self.dyn_filter)
151 | self.cross_attn_list.append(self.cross_attn)
152 | self.pool_conv_list.append(self.pool_conv)
153 | self.dyn_conv_list = nn.ModuleList(self.dyn_conv_list)
154 | self.dyn_filter_list = nn.ModuleList(self.dyn_filter_list)
155 | self.cross_attn_list = nn.ModuleList(self.cross_attn_list)
156 | self.pool_conv_list = nn.ModuleList(self.pool_conv_list)
157 | self.out_conv_list = nn.ModuleList(self.out_conv_list)
158 | print('use recurrent dual path cross attention')
159 |
160 | def forward_features(self, features, text_features, text_sentence_features, poses, memory, nf):
161 | # nf: num_frames
162 | text_pos = self.text_pos(text_features).permute(2, 0, 1) # [length, batch_size, c]
163 | text_features, text_masks = text_features.decompose()
164 | text_features = text_features.permute(1, 0, 2)
165 |
166 | for idx, (mem, f, pos) in enumerate(zip(memory[::-1], features[1:][::-1], poses[1:][::-1])): # 32x -> 8x
167 | lateral_conv = self.lateral_convs[idx]
168 | output_conv = self.output_convs[idx]
169 |
170 | _, x_mask = f.decompose()
171 | n, c, h, w = pos.shape
172 | b = n // nf
173 | t = nf
174 |
175 | # NOTE: here the (h, w) is the size for current fpn layer
176 | vision_features = lateral_conv(mem) # [b*t, c, h, w]
177 |
178 | # upsample
179 | # TODO: only fuse in high-level and repeat
180 | if idx == 0: # top layer
181 | for stage in range(self.stage_num):
182 | if stage == 0:
183 | vision_features = rearrange(vision_features, '(b t) c h w -> (t h w) b c', b=b, t=t)
184 | else:
185 | vision_features = rearrange(y, '(b t) c h w -> (t h w) b c', b=b, t=t)
186 | vision_pos = rearrange(pos, '(b t) c h w -> (t h w) b c', b=b, t=t)
187 | vision_masks = rearrange(x_mask, '(b t) h w -> b (t h w)', b=b, t=t)
188 | cur_fpn, text_features_ = self.cross_attn_list[stage](tgt=vision_features,
189 | memory=text_features,
190 | t=t, h=h, w=w,
191 | tgt_key_padding_mask=vision_masks,
192 | memory_key_padding_mask=text_masks,
193 | pos=text_pos,
194 | query_pos=vision_pos
195 | ) # [t*h*w, b, c]
196 | # text_features [l, b, c]
197 | cur_fpn = rearrange(cur_fpn, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w)
198 |
199 | # # repeat cls_token to compute filter, only apply once
200 | # if text_sentence_features.shape[0] != b * t:
201 | # text_sentence_features = repeat(text_sentence_features, 'b c -> b t c', t=t)
202 | # text_sentence_features = rearrange(text_sentence_features, 'b t c -> (b t) c')
203 | # filter = self.dyn_filter(text_sentence_features).unsqueeze(1)
204 | # TODO: test only no fusion
205 | text_features_ = repeat(text_features_, 'l b c -> l b t c', t=t)
206 | text_features_ = rearrange(text_features_, 'l b t c -> (b t) c l')
207 | text_features_ = rearrange(self.pool(text_features_), '(b t) c l -> (b t) l c', b=b, t=t)
208 | filter = self.dyn_filter_list[stage](text_features_)
209 | y = cur_fpn
210 | y_ = self.dyn_conv_list[stage](cur_fpn)
211 | y_ = torch.einsum('ichw,ijc -> ijchw', y_, filter)
212 | y_ = rearrange(y_, 'i j c h w -> i (j c) h w')
213 | y_ = self.pool_conv_list[stage](y_)
214 | if stage == 0:
215 | y = output_conv(y+y_)
216 | else:
217 | y = self.out_conv_list[stage-1](y+y_)
218 | else:
219 | # Following FPN implementation, we use nearest upsampling here
220 | cur_fpn = vision_features
221 | y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
222 | y = output_conv(y)
223 |
224 | # 4x level
225 | lateral_conv = self.lateral_convs[-1]
226 | output_conv = self.output_convs[-1]
227 |
228 | x, x_mask = features[0].decompose()
229 | pos = poses[0]
230 | n, c, h, w = pos.shape
231 | b = n // nf
232 | t = nf
233 |
234 | vision_features = lateral_conv(x) # [b*t, c, h, w]
235 | cur_fpn = vision_features
236 | # Following FPN implementation, we use nearest upsampling here
237 | y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest")
238 | y = output_conv(y)
239 | return y
240 |
241 | def forward(self, features, text_features, text_sentence_features, pos, memory, nf):
242 | """The forward function receives the vision and language features,
243 | and outputs the mask features with the spatial stride of 4x.
244 |
245 | Args:
246 | features (list[NestedTensor]): backbone features (vision), length is number of FPN layers
247 | tensors: [b*t, ci, hi, wi], mask: [b*t, hi, wi]
248 | text_features (NestedTensor): text features (language)
249 | tensors: [b, length, c], mask: [b, length]
250 | pos (list[Tensor]): position encoding of vision features, length is number of FPN layers
251 | tensors: [b*t, c, hi, wi]
252 | memory (list[Tensor]): features from encoder output. from 8x -> 32x
253 | NOTE: the layer orders of both features and pos are res2 -> res5
254 |
255 | Returns:
256 | mask_features (Tensor): [b*t, mask_dim, h, w], with the spatial stride of 4x.
257 | """
258 | y = self.forward_features(features, text_features, text_sentence_features, pos, memory, nf)
259 | return self.mask_features(y)
260 |
261 |
262 | # class EPrime2SPrime(nn.Module):
263 | #
264 | #
265 | # class Cycle(nn.Module):
266 |
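
The core of the text-conditioned dynamic filtering in `S2EPrime.forward_features` is the einsum between the per-frame feature map (after `dyn_conv`) and the pooled text filter (after `dyn_filter`). A shape-only sketch with dummy tensors; the dimensions below are illustrative:

```python
import torch
from einops import rearrange

bt, c, h, w = 2, 256, 16, 16    # (batch * frames), channels, spatial size
l = 3                           # text_pool_size

y_ = torch.randn(bt, c, h, w)   # stands in for dyn_conv(cur_fpn)
filt = torch.randn(bt, l, c)    # stands in for dyn_filter(pooled text features)

out = torch.einsum('ichw,ijc -> ijchw', y_, filt)    # [bt, l, c, h, w]
out = rearrange(out, 'i j c h w -> i (j c) h w')     # [bt, l*c, h, w], input to pool_conv
print(out.shape)                                     # torch.Size([2, 768, 16, 16])
```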
--------------------------------------------------------------------------------
/models/matcher.py:
--------------------------------------------------------------------------------
1 | """
2 | Instance Sequence Matching
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import torch
6 | from scipy.optimize import linear_sum_assignment
7 | from torch import nn
8 | import torch.nn.functional as F
9 |
10 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou, multi_iou
11 | from util.misc import nested_tensor_from_tensor_list
12 |
13 | INF = 100000000
14 |
15 | def dice_coef(inputs, targets):
16 | inputs = inputs.sigmoid()
17 | inputs = inputs.flatten(1).unsqueeze(1) # [N, 1, THW]
18 | targets = targets.flatten(1).unsqueeze(0) # [1, M, THW]
19 | numerator = 2 * (inputs * targets).sum(2)
20 | denominator = inputs.sum(-1) + targets.sum(-1)
21 |
22 |     # NOTE: the coefficient is not subtracted from 1, since that is unnecessary for computing matching costs
23 | coef = (numerator + 1) / (denominator + 1)
24 | return coef
25 |
26 | def sigmoid_focal_coef(inputs, targets, alpha: float = 0.25, gamma: float = 2):
27 | N, M = len(inputs), len(targets)
28 | inputs = inputs.flatten(1).unsqueeze(1).expand(-1, M, -1) # [N, M, THW]
29 | targets = targets.flatten(1).unsqueeze(0).expand(N, -1, -1) # [N, M, THW]
30 |
31 | prob = inputs.sigmoid()
32 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
33 | p_t = prob * targets + (1 - prob) * (1 - targets)
34 | coef = ce_loss * ((1 - p_t) ** gamma)
35 |
36 | if alpha >= 0:
37 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
38 | coef = alpha_t * coef
39 |
40 | return coef.mean(2) # [N, M]
41 |
42 |
43 | class HungarianMatcher(nn.Module):
44 | """This class computes an assignment between the targets and the predictions of the network
45 |
46 | For efficiency reasons, the targets don't include the no_object. Because of this, in general,
47 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
48 | while the others are un-matched (and thus treated as non-objects).
49 | """
50 |
51 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1,
52 | cost_mask: float = 1, cost_dice: float = 1, num_classes: int = 1):
53 | """Creates the matcher
54 |
55 | Params:
56 | cost_class: This is the relative weight of the classification error in the matching cost
57 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
58 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
59 | cost_mask: This is the relative weight of the sigmoid focal loss of the mask in the matching cost
60 | cost_dice: This is the relative weight of the dice loss of the mask in the matching cost
61 | """
62 | super().__init__()
63 | self.cost_class = cost_class
64 | self.cost_bbox = cost_bbox
65 | self.cost_giou = cost_giou
66 | self.cost_mask = cost_mask
67 | self.cost_dice = cost_dice
68 | self.num_classes = num_classes
69 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0 \
70 | or cost_mask != 0 or cost_dice != 0, "all costs cant be 0"
71 | self.mask_out_stride = 4
72 |
73 | @torch.no_grad()
74 | def forward(self, outputs, targets):
75 | """ Performs the matching
76 | Params:
77 | outputs: This is a dict that contains at least these entries:
78 | "pred_logits": Tensor of dim [batch_size, num_queries_per_frame, num_frames, num_classes] with the classification logits
79 | "pred_boxes": Tensor of dim [batch_size, num_queries_per_frame, num_frames, 4] with the predicted box coordinates
80 | "pred_masks": Tensor of dim [batch_size, num_queries_per_frame, num_frames, h, w], h,w in 4x size
81 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
82 | NOTE: Since every frame has one object at most
83 | "labels": Tensor of dim [num_frames] (where num_target_boxes is the number of ground-truth
84 | objects in the target) containing the class labels
85 | "boxes": Tensor of dim [num_frames, 4] containing the target box coordinates
86 | "masks": Tensor of dim [num_frames, h, w], h,w in origin size
87 | Returns:
88 | A list of size batch_size, containing tuples of (index_i, index_j) where:
89 | - index_i is the indices of the selected predictions (in order)
90 | - index_j is the indices of the corresponding selected targets (in order)
91 | For each batch element, it holds:
92 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
93 | """
94 | src_logits = outputs["pred_logits"]
95 | src_boxes = outputs["pred_boxes"]
96 | src_masks = outputs["pred_masks"]
97 |
98 | bs, nf, nq, h, w = src_masks.shape
99 |
100 | # handle mask padding issue
101 | target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets],
102 | size_divisibility=32,
103 | split=False).decompose()
104 | target_masks = target_masks.to(src_masks) # [B, T, H, W]
105 |
106 | # downsample ground truth masks with ratio mask_out_stride
107 | start = int(self.mask_out_stride // 2)
108 | im_h, im_w = target_masks.shape[-2:]
109 |
110 | target_masks = target_masks[:, :, start::self.mask_out_stride, start::self.mask_out_stride]
111 | assert target_masks.size(2) * self.mask_out_stride == im_h
112 | assert target_masks.size(3) * self.mask_out_stride == im_w
113 |
114 | indices = []
115 | for i in range(bs):
116 | out_prob = src_logits[i].sigmoid()
117 | out_bbox = src_boxes[i]
118 | out_mask = src_masks[i]
119 |
120 | tgt_ids = targets[i]["labels"]
121 | tgt_bbox = targets[i]["boxes"]
122 | tgt_mask = target_masks[i]
123 | tgt_valid = targets[i]["valid"]
124 |
125 | # class cost
126 | # we average the cost on valid frames
127 | cost_class = []
128 | for t in range(nf):
129 | if tgt_valid[t] == 0:
130 | continue
131 |
132 | out_prob_split = out_prob[t]
133 | tgt_ids_split = tgt_ids[t].unsqueeze(0)
134 |
135 | # Compute the classification cost.
136 | alpha = 0.25
137 | gamma = 2.0
138 | neg_cost_class = (1 - alpha) * (out_prob_split ** gamma) * (-(1 - out_prob_split + 1e-8).log())
139 | pos_cost_class = alpha * ((1 - out_prob_split) ** gamma) * (-(out_prob_split + 1e-8).log())
140 | if self.num_classes == 1: # binary referred
141 | cost_class_split = pos_cost_class[:, [0]] - neg_cost_class[:, [0]]
142 | else:
143 | cost_class_split = pos_cost_class[:, tgt_ids_split] - neg_cost_class[:, tgt_ids_split]
144 |
145 | cost_class.append(cost_class_split)
146 | cost_class = torch.stack(cost_class, dim=0).mean(0) # [q, 1]
147 |
148 | # box cost
149 | # we average the cost on every frame
150 | cost_bbox, cost_giou = [], []
151 | for t in range(nf):
152 | out_bbox_split = out_bbox[t]
153 | tgt_bbox_split = tgt_bbox[t].unsqueeze(0)
154 |
155 | # Compute the L1 cost between boxes
156 | cost_bbox_split = torch.cdist(out_bbox_split, tgt_bbox_split, p=1)
157 |
158 |                 # Compute the giou cost between boxes
159 | cost_giou_split = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox_split),
160 | box_cxcywh_to_xyxy(tgt_bbox_split))
161 |
162 | cost_bbox.append(cost_bbox_split)
163 | cost_giou.append(cost_giou_split)
164 | cost_bbox = torch.stack(cost_bbox, dim=0).mean(0)
165 | cost_giou = torch.stack(cost_giou, dim=0).mean(0)
166 |
167 | # mask cost
168 | # Compute the focal loss between masks
169 | cost_mask = sigmoid_focal_coef(out_mask.transpose(0, 1), tgt_mask.unsqueeze(0))
170 |
171 |             # Compute the dice loss between masks
172 | cost_dice = -dice_coef(out_mask.transpose(0, 1), tgt_mask.unsqueeze(0))
173 |
174 | # Final cost matrix
175 | C = self.cost_class * cost_class + self.cost_bbox * cost_bbox + self.cost_giou * cost_giou + \
176 | self.cost_mask * cost_mask + self.cost_dice * cost_dice # [q, 1]
177 |
178 | # Only has one tgt, MinCost Matcher
179 | _, src_ind = torch.min(C, dim=0)
180 | tgt_ind = torch.arange(1).to(src_ind)
181 | indices.append((src_ind.long(), tgt_ind.long()))
182 |
183 | # list[tuple], length is batch_size
184 | return indices
185 |
186 |
187 | def build_matcher(args):
188 | if args.binary:
189 | num_classes = 1
190 | else:
191 | if args.dataset_file == 'ytvos':
192 | num_classes = 65
193 | elif args.dataset_file == 'davis':
194 | num_classes = 78
195 | elif args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb':
196 | num_classes = 1
197 | else:
198 | num_classes = 91 # for coco
199 | return HungarianMatcher(cost_class=args.set_cost_class,
200 | cost_bbox=args.set_cost_bbox,
201 | cost_giou=args.set_cost_giou,
202 | cost_mask=args.set_cost_mask,
203 | cost_dice=args.set_cost_dice,
204 | num_classes=num_classes)
205 |
206 |
207 |
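
A small shape check of the pairwise cost helpers above: `dice_coef` and `sigmoid_focal_coef` take N predicted mask logit stacks and M target masks and return an [N, M] cost-style matrix (dummy tensors; run from the repo root so `models.matcher` is importable):

```python
import torch
from models.matcher import dice_coef, sigmoid_focal_coef

N, M, T, H, W = 5, 1, 4, 32, 32                # queries, targets, frames, mask size
pred_logits = torch.randn(N, T, H, W)          # per-query mask logits over T frames
tgt_masks = (torch.rand(M, T, H, W) > 0.5).float()

print(dice_coef(pred_logits, tgt_masks).shape)           # torch.Size([5, 1])
print(sigmoid_focal_coef(pred_logits, tgt_masks).shape)  # torch.Size([5, 1])
```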
--------------------------------------------------------------------------------
/models/ops/functions/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn_func import MSDeformAttnFunction
10 |
11 |
--------------------------------------------------------------------------------
/models/ops/functions/ms_deform_attn_func.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import torch
14 | import torch.nn.functional as F
15 | from torch.autograd import Function
16 | from torch.autograd.function import once_differentiable
17 |
18 | import MultiScaleDeformableAttention as MSDA
19 |
20 |
21 | class MSDeformAttnFunction(Function):
22 | @staticmethod
23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step):
24 | ctx.im2col_step = im2col_step
25 | output = MSDA.ms_deform_attn_forward(
26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step)
27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights)
28 | return output
29 |
30 | @staticmethod
31 | @once_differentiable
32 | def backward(ctx, grad_output):
33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors
34 | grad_value, grad_sampling_loc, grad_attn_weight = \
35 | MSDA.ms_deform_attn_backward(
36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step)
37 |
38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None
39 |
40 |
41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights):
42 | # for debug and test only,
43 | # need to use cuda version instead
44 | N_, S_, M_, D_ = value.shape
45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape
46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
47 | sampling_grids = 2 * sampling_locations - 1
48 | sampling_value_list = []
49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes):
50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_)
52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1)
54 | # N_*M_, D_, Lq_, P_
55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_,
56 | mode='bilinear', padding_mode='zeros', align_corners=False)
57 | sampling_value_list.append(sampling_value_l_)
58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_)
59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_)
60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_)
61 | return output.transpose(1, 2).contiguous()
62 |
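
The pure-PyTorch reference `ms_deform_attn_core_pytorch` is convenient for checking the CUDA op on CPU. A dummy invocation with two feature levels; the shapes follow the conventions used inside the function, and the sizes themselves are arbitrary:

```python
import torch

N, M, D = 2, 8, 32                                      # batch, heads, dim per head
spatial_shapes = torch.as_tensor([[16, 16], [8, 8]], dtype=torch.long)
S = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())   # 320 flattened tokens
L, P, Lq = spatial_shapes.shape[0], 4, 100              # levels, points, queries

value = torch.rand(N, S, M, D)
sampling_locations = torch.rand(N, Lq, M, L, P, 2)      # normalized [0, 1] coordinates
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights = attention_weights / attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)

out = ms_deform_attn_core_pytorch(value, spatial_shapes, sampling_locations, attention_weights)
print(out.shape)                                        # torch.Size([2, 100, 256]) = (N, Lq, M*D)
```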
--------------------------------------------------------------------------------
/models/ops/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # ------------------------------------------------------------------------------------------------
3 | # Deformable DETR
4 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------------------------------
7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | # ------------------------------------------------------------------------------------------------
9 |
10 | python setup.py build install
11 |
--------------------------------------------------------------------------------
/models/ops/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from .ms_deform_attn import MSDeformAttn
10 |
--------------------------------------------------------------------------------
/models/ops/modules/ms_deform_attn.py:
--------------------------------------------------------------------------------
1 | # Modify for sample points visualization
2 | # ------------------------------------------------------------------------------------------------
3 | # Deformable DETR
4 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | # ------------------------------------------------------------------------------------------------
7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | # ------------------------------------------------------------------------------------------------
9 |
10 | from __future__ import absolute_import
11 | from __future__ import print_function
12 | from __future__ import division
13 |
14 | import warnings
15 | import math
16 |
17 | import torch
18 | from torch import nn
19 | import torch.nn.functional as F
20 | from torch.nn.init import xavier_uniform_, constant_
21 |
22 | from ..functions import MSDeformAttnFunction
23 |
24 |
25 | def _is_power_of_2(n):
26 | if (not isinstance(n, int)) or (n < 0):
27 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n)))
28 | return (n & (n-1) == 0) and n != 0
29 |
30 |
31 | class MSDeformAttn(nn.Module):
32 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4):
33 | """
34 | Multi-Scale Deformable Attention Module
35 | :param d_model hidden dimension
36 | :param n_levels number of feature levels
37 | :param n_heads number of attention heads
38 | :param n_points number of sampling points per attention head per feature level
39 | """
40 | super().__init__()
41 | if d_model % n_heads != 0:
42 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads))
43 | _d_per_head = d_model // n_heads
44 |         # ideally, _d_per_head should be a power of 2, which is more efficient in our CUDA implementation
45 | if not _is_power_of_2(_d_per_head):
46 |             warnings.warn("Set d_model in MSDeformAttn so that the dimension of each attention head is a power of 2, "
47 |                           "which is more efficient in our CUDA implementation.")
48 |
49 | self.im2col_step = 64
50 |
51 | self.d_model = d_model
52 | self.n_levels = n_levels
53 | self.n_heads = n_heads
54 | self.n_points = n_points
55 |
56 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2)
57 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points)
58 | self.value_proj = nn.Linear(d_model, d_model)
59 | self.output_proj = nn.Linear(d_model, d_model)
60 |
61 | self._reset_parameters()
62 |
63 | def _reset_parameters(self):
64 | constant_(self.sampling_offsets.weight.data, 0.)
65 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads)
66 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
67 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1)
68 | for i in range(self.n_points):
69 | grid_init[:, :, i, :] *= i + 1
70 | with torch.no_grad():
71 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
72 | constant_(self.attention_weights.weight.data, 0.)
73 | constant_(self.attention_weights.bias.data, 0.)
74 | xavier_uniform_(self.value_proj.weight.data)
75 | constant_(self.value_proj.bias.data, 0.)
76 | xavier_uniform_(self.output_proj.weight.data)
77 | constant_(self.output_proj.bias.data, 0.)
78 |
79 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None):
80 | """
81 | :param query (N, Length_{query}, C)
82 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area
83 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes
84 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C)
85 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
86 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}]
87 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements
88 |
89 | :return output (N, Length_{query}, C)
90 | """
91 | N, Len_q, _ = query.shape
92 | N, Len_in, _ = input_flatten.shape
93 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in
94 |
95 | value = self.value_proj(input_flatten)
96 | if input_padding_mask is not None:
97 | value = value.masked_fill(input_padding_mask[..., None], float(0))
98 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads)
99 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2)
100 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points)
101 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points)
102 | # N, Len_q, n_heads, n_levels, n_points, 2
103 | if reference_points.shape[-1] == 2:
104 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1)
105 | sampling_locations = reference_points[:, :, None, :, None, :] \
106 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
107 | elif reference_points.shape[-1] == 4:
108 | sampling_locations = reference_points[:, :, None, :, None, :2] \
109 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
110 | else:
111 | raise ValueError(
112 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1]))
113 | output = MSDeformAttnFunction.apply(
114 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step)
115 | output = self.output_proj(output)
116 |
117 | return output, sampling_locations, attention_weights
118 |
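
A minimal sketch of calling `MSDeformAttn` directly with two feature levels, assuming the CUDA extension is built and a GPU is available; the flattening follows the forward docstring above, and all sizes are illustrative:

```
import torch
from models.ops.modules import MSDeformAttn

attn = MSDeformAttn(d_model=256, n_levels=2, n_heads=8, n_points=4).cuda()

spatial_shapes = torch.as_tensor([[32, 32], [16, 16]], dtype=torch.long).cuda()
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)),
                               spatial_shapes.prod(1).cumsum(0)[:-1]))
len_in = int(spatial_shapes.prod(1).sum())            # 32*32 + 16*16 = 1280

N, len_q = 1, 100
query = torch.rand(N, len_q, 256).cuda()
input_flatten = torch.rand(N, len_in, 256).cuda()
reference_points = torch.rand(N, len_q, 2, 2).cuda()  # (N, Len_q, n_levels, 2), in [0, 1]

output, sampling_locations, attention_weights = attn(
    query, reference_points, input_flatten, spatial_shapes, level_start_index)
print(output.shape)                                   # torch.Size([1, 100, 256])
```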
--------------------------------------------------------------------------------
/models/ops/setup.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | import os
10 | import glob
11 |
12 | import torch
13 |
14 | from torch.utils.cpp_extension import CUDA_HOME
15 | from torch.utils.cpp_extension import CppExtension
16 | from torch.utils.cpp_extension import CUDAExtension
17 |
18 | from setuptools import find_packages
19 | from setuptools import setup
20 |
21 | requirements = ["torch", "torchvision"]
22 |
23 | def get_extensions():
24 | this_dir = os.path.dirname(os.path.abspath(__file__))
25 | extensions_dir = os.path.join(this_dir, "src")
26 |
27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
30 |
31 | sources = main_file + source_cpu
32 | extension = CppExtension
33 | extra_compile_args = {"cxx": []}
34 | define_macros = []
35 |
36 | if torch.cuda.is_available() and CUDA_HOME is not None:
37 | extension = CUDAExtension
38 | sources += source_cuda
39 | define_macros += [("WITH_CUDA", None)]
40 | extra_compile_args["nvcc"] = [
41 | "-DCUDA_HAS_FP16=1",
42 | "-D__CUDA_NO_HALF_OPERATORS__",
43 | "-D__CUDA_NO_HALF_CONVERSIONS__",
44 | "-D__CUDA_NO_HALF2_OPERATORS__",
45 | ]
46 | else:
47 |         raise NotImplementedError('CUDA is not available')
48 |
49 | sources = [os.path.join(extensions_dir, s) for s in sources]
50 | include_dirs = [extensions_dir]
51 | ext_modules = [
52 | extension(
53 | "MultiScaleDeformableAttention",
54 | sources,
55 | include_dirs=include_dirs,
56 | define_macros=define_macros,
57 | extra_compile_args=extra_compile_args,
58 | )
59 | ]
60 | return ext_modules
61 |
62 | setup(
63 | name="MultiScaleDeformableAttention",
64 | version="1.0",
65 | author="Weijie Su",
66 | url="https://github.com/fundamentalvision/Deformable-DETR",
67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention",
68 | packages=find_packages(exclude=("configs", "tests",)),
69 | ext_modules=get_extensions(),
70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
71 | )
72 |
--------------------------------------------------------------------------------
/models/ops/src/cpu/ms_deform_attn_cpu.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 |
13 | #include <ATen/ATen.h>
14 | #include <ATen/cuda/CUDAContext.h>
15 |
16 |
17 | at::Tensor
18 | ms_deform_attn_cpu_forward(
19 | const at::Tensor &value,
20 | const at::Tensor &spatial_shapes,
21 | const at::Tensor &level_start_index,
22 | const at::Tensor &sampling_loc,
23 | const at::Tensor &attn_weight,
24 | const int im2col_step)
25 | {
26 | AT_ERROR("Not implement on cpu");
27 | }
28 |
29 | std::vector<at::Tensor>
30 | ms_deform_attn_cpu_backward(
31 | const at::Tensor &value,
32 | const at::Tensor &spatial_shapes,
33 | const at::Tensor &level_start_index,
34 | const at::Tensor &sampling_loc,
35 | const at::Tensor &attn_weight,
36 | const at::Tensor &grad_output,
37 | const int im2col_step)
38 | {
39 |     AT_ERROR("Not implemented on the CPU");
40 | }
41 |
42 |
--------------------------------------------------------------------------------
/models/ops/src/cpu/ms_deform_attn_cpu.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor
15 | ms_deform_attn_cpu_forward(
16 | const at::Tensor &value,
17 | const at::Tensor &spatial_shapes,
18 | const at::Tensor &level_start_index,
19 | const at::Tensor &sampling_loc,
20 | const at::Tensor &attn_weight,
21 | const int im2col_step);
22 |
23 | std::vector<at::Tensor>
24 | ms_deform_attn_cpu_backward(
25 | const at::Tensor &value,
26 | const at::Tensor &spatial_shapes,
27 | const at::Tensor &level_start_index,
28 | const at::Tensor &sampling_loc,
29 | const at::Tensor &attn_weight,
30 | const at::Tensor &grad_output,
31 | const int im2col_step);
32 |
33 |
34 |
--------------------------------------------------------------------------------
/models/ops/src/cuda/ms_deform_attn_cuda.cu:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include <vector>
12 | #include "cuda/ms_deform_im2col_cuda.cuh"
13 |
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 |
19 |
20 | at::Tensor ms_deform_attn_cuda_forward(
21 | const at::Tensor &value,
22 | const at::Tensor &spatial_shapes,
23 | const at::Tensor &level_start_index,
24 | const at::Tensor &sampling_loc,
25 | const at::Tensor &attn_weight,
26 | const int im2col_step)
27 | {
28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33 |
34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39 |
40 | const int batch = value.size(0);
41 | const int spatial_size = value.size(1);
42 | const int num_heads = value.size(2);
43 | const int channels = value.size(3);
44 |
45 | const int num_levels = spatial_shapes.size(0);
46 |
47 | const int num_query = sampling_loc.size(1);
48 | const int num_point = sampling_loc.size(4);
49 |
50 | const int im2col_step_ = std::min(batch, im2col_step);
51 |
52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53 |
54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55 |
56 | const int batch_n = im2col_step_;
57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58 | auto per_value_size = spatial_size * num_heads * channels;
59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61 | for (int n = 0; n < batch/im2col_step_; ++n)
62 | {
63 | auto columns = output_n.select(0, n);
64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66 |             value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67 |             spatial_shapes.data<int64_t>(),
68 |             level_start_index.data<int64_t>(),
69 |             sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70 |             attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71 |             batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72 |             columns.data<scalar_t>());
73 |
74 | }));
75 | }
76 |
77 | output = output.view({batch, num_query, num_heads*channels});
78 |
79 | return output;
80 | }
81 |
82 |
83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84 | const at::Tensor &value,
85 | const at::Tensor &spatial_shapes,
86 | const at::Tensor &level_start_index,
87 | const at::Tensor &sampling_loc,
88 | const at::Tensor &attn_weight,
89 | const at::Tensor &grad_output,
90 | const int im2col_step)
91 | {
92 |
93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99 |
100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106 |
107 | const int batch = value.size(0);
108 | const int spatial_size = value.size(1);
109 | const int num_heads = value.size(2);
110 | const int channels = value.size(3);
111 |
112 | const int num_levels = spatial_shapes.size(0);
113 |
114 | const int num_query = sampling_loc.size(1);
115 | const int num_point = sampling_loc.size(4);
116 |
117 | const int im2col_step_ = std::min(batch, im2col_step);
118 |
119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120 |
121 | auto grad_value = at::zeros_like(value);
122 | auto grad_sampling_loc = at::zeros_like(sampling_loc);
123 | auto grad_attn_weight = at::zeros_like(attn_weight);
124 |
125 | const int batch_n = im2col_step_;
126 | auto per_value_size = spatial_size * num_heads * channels;
127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130 |
131 | for (int n = 0; n < batch/im2col_step_; ++n)
132 | {
133 | auto grad_output_g = grad_output_n.select(0, n);
134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136 |                             grad_output_g.data<scalar_t>(),
137 |                             value.data<scalar_t>() + n * im2col_step_ * per_value_size,
138 |                             spatial_shapes.data<int64_t>(),
139 |                             level_start_index.data<int64_t>(),
140 |                             sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141 |                             attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142 |                             batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143 |                             grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
144 |                             grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145 |                             grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146 |
147 | }));
148 | }
149 |
150 | return {
151 | grad_value, grad_sampling_loc, grad_attn_weight
152 | };
153 | }
--------------------------------------------------------------------------------
/models/ops/src/cuda/ms_deform_attn_cuda.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 | #include <torch/extension.h>
13 |
14 | at::Tensor ms_deform_attn_cuda_forward(
15 | const at::Tensor &value,
16 | const at::Tensor &spatial_shapes,
17 | const at::Tensor &level_start_index,
18 | const at::Tensor &sampling_loc,
19 | const at::Tensor &attn_weight,
20 | const int im2col_step);
21 |
22 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
23 | const at::Tensor &value,
24 | const at::Tensor &spatial_shapes,
25 | const at::Tensor &level_start_index,
26 | const at::Tensor &sampling_loc,
27 | const at::Tensor &attn_weight,
28 | const at::Tensor &grad_output,
29 | const int im2col_step);
30 |
31 |
--------------------------------------------------------------------------------
/models/ops/src/ms_deform_attn.h:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #pragma once
12 |
13 | #include "cpu/ms_deform_attn_cpu.h"
14 |
15 | #ifdef WITH_CUDA
16 | #include "cuda/ms_deform_attn_cuda.h"
17 | #endif
18 |
19 |
20 | at::Tensor
21 | ms_deform_attn_forward(
22 | const at::Tensor &value,
23 | const at::Tensor &spatial_shapes,
24 | const at::Tensor &level_start_index,
25 | const at::Tensor &sampling_loc,
26 | const at::Tensor &attn_weight,
27 | const int im2col_step)
28 | {
29 | if (value.type().is_cuda())
30 | {
31 | #ifdef WITH_CUDA
32 | return ms_deform_attn_cuda_forward(
33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step);
34 | #else
35 | AT_ERROR("Not compiled with GPU support");
36 | #endif
37 | }
38 | AT_ERROR("Not implemented on the CPU");
39 | }
40 |
41 | std::vector<at::Tensor>
42 | ms_deform_attn_backward(
43 | const at::Tensor &value,
44 | const at::Tensor &spatial_shapes,
45 | const at::Tensor &level_start_index,
46 | const at::Tensor &sampling_loc,
47 | const at::Tensor &attn_weight,
48 | const at::Tensor &grad_output,
49 | const int im2col_step)
50 | {
51 | if (value.type().is_cuda())
52 | {
53 | #ifdef WITH_CUDA
54 | return ms_deform_attn_cuda_backward(
55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step);
56 | #else
57 | AT_ERROR("Not compiled with GPU support");
58 | #endif
59 | }
60 | AT_ERROR("Not implemented on the CPU");
61 | }
62 |
63 |
--------------------------------------------------------------------------------
/models/ops/src/vision.cpp:
--------------------------------------------------------------------------------
1 | /*!
2 | **************************************************************************************************
3 | * Deformable DETR
4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 |
11 | #include "ms_deform_attn.h"
12 |
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
16 | }
17 |
--------------------------------------------------------------------------------
/models/ops/test.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------------------------------
2 | # Deformable DETR
3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5 | # ------------------------------------------------------------------------------------------------
6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
7 | # ------------------------------------------------------------------------------------------------
8 |
9 | from __future__ import absolute_import
10 | from __future__ import print_function
11 | from __future__ import division
12 |
13 | import time
14 | import torch
15 | import torch.nn as nn
16 | from torch.autograd import gradcheck
17 |
18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch
19 |
20 |
21 | N, M, D = 1, 2, 2
22 | Lq, L, P = 2, 2, 2
23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda()
24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1]))
25 | S = sum([(H*W).item() for H, W in shapes])
26 |
27 |
28 | torch.manual_seed(3)
29 |
30 |
31 | @torch.no_grad()
32 | def check_forward_equal_with_pytorch_double():
33 | value = torch.rand(N, S, M, D).cuda() * 0.01
34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
37 | im2col_step = 2
38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu()
39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu()
40 | fwdok = torch.allclose(output_cuda, output_pytorch)
41 | max_abs_err = (output_cuda - output_pytorch).abs().max()
42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
43 |
44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
45 |
46 |
47 | @torch.no_grad()
48 | def check_forward_equal_with_pytorch_float():
49 | value = torch.rand(N, S, M, D).cuda() * 0.01
50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
53 | im2col_step = 2
54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu()
55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu()
56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3)
57 | max_abs_err = (output_cuda - output_pytorch).abs().max()
58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max()
59 |
60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}')
61 |
62 |
63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True):
64 |
65 | value = torch.rand(N, S, M, channels).cuda() * 0.01
66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda()
67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5
68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)
69 | im2col_step = 2
70 | func = MSDeformAttnFunction.apply
71 |
72 | value.requires_grad = grad_value
73 | sampling_locations.requires_grad = grad_sampling_loc
74 | attention_weights.requires_grad = grad_attn_weight
75 |
76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step))
77 |
78 | print(f'* {gradok} check_gradient_numerical(D={channels})')
79 |
80 |
81 | if __name__ == '__main__':
82 | check_forward_equal_with_pytorch_double()
83 | check_forward_equal_with_pytorch_float()
84 |
85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]:
86 | check_gradient_numerical(channels, True, True, True)
87 |
88 |
89 |
90 |
--------------------------------------------------------------------------------
/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | """
2 | Various positional encodings for the transformer.
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 |
9 | from util.misc import NestedTensor
10 |
11 | # dimension == 1
12 | class PositionEmbeddingSine1D(nn.Module):
13 | """
14 | This is a more standard version of the position embedding, very similar to the one
15 | used by the Attention is all you need paper, generalized to work on images.
16 | """
17 | def __init__(self, num_pos_feats=256, temperature=10000, normalize=False, scale=None):
18 | super().__init__()
19 | self.num_pos_feats = num_pos_feats
20 | self.temperature = temperature
21 | self.normalize = normalize
22 | if scale is not None and normalize is False:
23 | raise ValueError("normalize should be True if scale is passed")
24 | if scale is None:
25 | scale = 2 * math.pi
26 | self.scale = scale
27 |
28 | def forward(self, tensor_list: NestedTensor):
29 | x = tensor_list.tensors # [B, C, T]
30 | mask = tensor_list.mask # [B, T]
31 | assert mask is not None
32 | not_mask = ~mask
33 | x_embed = not_mask.cumsum(1, dtype=torch.float32) # [B, T]
34 | if self.normalize:
35 | eps = 1e-6
36 | x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
37 |
38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
40 |
41 | pos_x = x_embed[:, :, None] / dim_t # [B, T, C]
42 | # n,c,t
43 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
44 | pos = pos_x.permute(0, 2, 1) # [B, C, T]
45 | return pos
46 |
47 | # dimension == 2
48 | class PositionEmbeddingSine2D(nn.Module):
49 | """
50 | This is a more standard version of the position embedding, very similar to the one
51 | used by the Attention is all you need paper, generalized to work on images.
52 | """
53 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
54 | super().__init__()
55 | self.num_pos_feats = num_pos_feats
56 | self.temperature = temperature
57 | self.normalize = normalize
58 | if scale is not None and normalize is False:
59 | raise ValueError("normalize should be True if scale is passed")
60 | if scale is None:
61 | scale = 2 * math.pi
62 | self.scale = scale
63 |
64 | def forward(self, tensor_list: NestedTensor):
65 | x = tensor_list.tensors # [B, C, H, W]
66 | mask = tensor_list.mask # [B, H, W]
67 | assert mask is not None
68 | not_mask = ~mask
69 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
70 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
71 | if self.normalize:
72 | eps = 1e-6
73 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
74 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale
75 |
76 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
77 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
78 |
79 | pos_x = x_embed[:, :, :, None] / dim_t
80 | pos_y = y_embed[:, :, :, None] / dim_t
81 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
82 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
83 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
84 | return pos # [B, C, H, W]
85 |
86 |
87 | # dimension == 3
88 | class PositionEmbeddingSine3D(nn.Module):
89 | """
90 | This is a more standard version of the position embedding, very similar to the one
91 | used by the Attention is all you need paper, generalized to work on images.
92 | """
93 | def __init__(self, num_pos_feats=64, num_frames=36, temperature=10000, normalize=False, scale=None):
94 | super().__init__()
95 | self.num_pos_feats = num_pos_feats
96 | self.temperature = temperature
97 | self.normalize = normalize
98 | self.frames = num_frames
99 | if scale is not None and normalize is False:
100 | raise ValueError("normalize should be True if scale is passed")
101 | if scale is None:
102 | scale = 2 * math.pi
103 | self.scale = scale
104 |
105 | def forward(self, tensor_list: NestedTensor):
106 | x = tensor_list.tensors # [B*T, C, H, W]
107 | mask = tensor_list.mask # [B*T, H, W]
108 | n,h,w = mask.shape
109 | mask = mask.reshape(n//self.frames, self.frames,h,w) # [B, T, H, W]
110 | assert mask is not None
111 | not_mask = ~mask
112 | z_embed = not_mask.cumsum(1, dtype=torch.float32) # [B, T, H, W]
113 | y_embed = not_mask.cumsum(2, dtype=torch.float32) # [B, T, H, W]
114 | x_embed = not_mask.cumsum(3, dtype=torch.float32) # [B, T, H, W]
115 | if self.normalize:
116 | eps = 1e-6
117 | z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale
118 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale
119 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale
120 |
121 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) #
122 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
123 |
124 | pos_x = x_embed[:, :, :, :, None] / dim_t # [B, T, H, W, c]
125 | pos_y = y_embed[:, :, :, :, None] / dim_t
126 | pos_z = z_embed[:, :, :, :, None] / dim_t
127 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) # [B, T, H, W, c]
128 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
129 | pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4)
130 | pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3) # [B, T, C, H, W]
131 | return pos
132 |
133 |
134 |
135 | def build_position_encoding(args):
136 | # build 2D position encoding
137 | N_steps = args.hidden_dim // 2 # 256 / 2 = 128
138 | if args.position_embedding in ('v2', 'sine'):
139 | # TODO find a better way of exposing other arguments
140 | position_embedding = PositionEmbeddingSine2D(N_steps, normalize=True)
141 | else:
142 | raise ValueError(f"not supported {args.position_embedding}")
143 |
144 | return position_embedding
145 |
146 |
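
A short sketch of the 2D sine embedding on a padded batch; `NestedTensor` pairs the image tensor with its padding mask, and the shapes below are illustrative:

```
import torch
from models.position_encoding import PositionEmbeddingSine2D
from util.misc import NestedTensor

B, C, H, W = 2, 256, 32, 48
images = torch.rand(B, C, H, W)
mask = torch.zeros(B, H, W, dtype=torch.bool)   # True marks padded pixels
mask[1, :, 40:] = True                          # pretend the second sample is padded on the right

pos_embed = PositionEmbeddingSine2D(num_pos_feats=128, normalize=True)
pos = pos_embed(NestedTensor(images, mask))
print(pos.shape)   # torch.Size([2, 256, 32, 48]): y/x embeddings concatenated along channels
```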
--------------------------------------------------------------------------------
/models/postprocessors.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved
2 | """Postprocessors class to transform MDETR output according to the downstream task"""
3 | from typing import Dict
4 |
5 | import numpy as np
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import nn
9 | import pycocotools.mask as mask_util
10 |
11 | from util import box_ops
12 |
13 |
14 | class A2DSentencesPostProcess(nn.Module):
15 | """
16 | This module converts the model's output into the format expected by the coco api for the given task
17 | """
18 | def __init__(self, threshold=0.5):
19 | super().__init__()
20 | self.threshold = threshold
21 |
22 | @torch.no_grad()
23 | def forward(self, outputs, orig_target_sizes, max_target_sizes):
24 | """ Perform the computation
25 | Parameters:
26 | outputs: raw outputs of the model
27 | orig_target_sizes: original size of the samples (no augmentations or padding)
28 | max_target_sizes: size of samples (input to model) after size augmentation.
29 | NOTE: the max_padding_size is 4x out_masks.shape[-2:]
30 | """
31 | assert len(orig_target_sizes) == len(max_target_sizes)
32 |
33 |         # there is only one valid frame, thus T=1
34 | out_logits = outputs['pred_logits'][:, 0, :, 0] # [B, T, N, 1] -> [B, N]
35 | out_masks = outputs['pred_masks'][:, 0, :, :, :] # [B, T, N, out_h, out_w] -> [B, N, out_h, out_w]
36 | out_h, out_w = out_masks.shape[-2:]
37 |
38 | scores = out_logits.sigmoid()
39 | pred_masks = F.interpolate(out_masks, size=(out_h*4, out_w*4), mode="bilinear", align_corners=False) # [B, N, H, W]
40 | pred_masks = (pred_masks.sigmoid() > 0.5) # [B, N, H, W]
41 | processed_pred_masks, rle_masks = [], []
42 | # for each batch
43 | for f_pred_masks, resized_size, orig_size in zip(pred_masks, max_target_sizes, orig_target_sizes):
44 | f_mask_h, f_mask_w = resized_size # resized shape without padding
45 | f_pred_masks_no_pad = f_pred_masks[:, :f_mask_h, :f_mask_w].unsqueeze(1) # remove the samples' padding
46 | # resize the samples back to their original dataset (target) size for evaluation
47 | f_pred_masks_processed = F.interpolate(f_pred_masks_no_pad.float(), size=tuple(orig_size.tolist()), mode="nearest")
48 | f_pred_rle_masks = [mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
49 | for mask in f_pred_masks_processed.cpu()]
50 | processed_pred_masks.append(f_pred_masks_processed)
51 | rle_masks.append(f_pred_rle_masks)
52 | predictions = [{'scores': s, 'masks': m, 'rle_masks': rle}
53 | for s, m, rle in zip(scores, processed_pred_masks, rle_masks)]
54 | return predictions
55 |
56 |
57 | # PostProcess for pretraining
58 | class PostProcess(nn.Module):
59 | """ This module converts the model's output into the format expected by the coco api"""
60 |
61 | @torch.no_grad()
62 | def forward(self, outputs, target_sizes):
63 | """Perform the computation
64 | Parameters:
65 | outputs: raw outputs of the model
66 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
67 | For evaluation, this must be the original image size (before any data augmentation)
68 | For visualization, this should be the image size after data augment, but before padding
69 |             Returns:
70 |                 a list of dicts (one per batch element) with keys "scores", "labels" and "boxes"
71 | """
72 | out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]
73 |
74 | assert len(out_logits) == len(target_sizes)
75 | assert target_sizes.shape[1] == 2
76 |
77 | # coco, num_frames=1
78 | out_logits = outputs["pred_logits"].flatten(1, 2)
79 | out_boxes = outputs["pred_boxes"].flatten(1, 2)
80 | bs, num_queries = out_logits.shape[:2]
81 |
82 | prob = out_logits.sigmoid() # [bs, num_queries, num_classes]
83 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True)
84 | scores = topk_values # [bs, num_queries]
85 | topk_boxes = topk_indexes // out_logits.shape[2] # [bs, num_queries]
86 | labels = topk_indexes % out_logits.shape[2] # [bs, num_queries]
87 |
88 | boxes = box_ops.box_cxcywh_to_xyxy(out_boxes) # [bs, num_queries, 4]
89 | boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4))
90 |
91 | # and from relative [0, 1] to absolute [0, height] coordinates
92 | img_h, img_w = target_sizes.unbind(1)
93 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
94 | boxes = boxes * scale_fct[:, None, :] # [bs, num_queries, 4]
95 |
96 | assert len(scores) == len(labels) == len(boxes)
97 | results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)]
98 |
99 | return results
100 |
101 |
102 | class PostProcessSegm(nn.Module):
103 | """Similar to PostProcess but for segmentation masks.
104 | This processor is to be called sequentially after PostProcess.
105 | Args:
106 | threshold: threshold that will be applied to binarize the segmentation masks.
107 | """
108 |
109 | def __init__(self, threshold=0.5):
110 | super().__init__()
111 | self.threshold = threshold
112 |
113 | @torch.no_grad()
114 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
115 | """Perform the computation
116 | Parameters:
117 | results: already pre-processed boxes (output of PostProcess) NOTE here
118 | outputs: raw outputs of the model
119 | orig_target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
120 | For evaluation, this must be the original image size (before any data augmentation)
121 | For visualization, this should be the image size after data augment, but before padding
122 | max_target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
123 | after data augmentation.
124 | """
125 | assert len(orig_target_sizes) == len(max_target_sizes)
126 |
127 | out_logits = outputs["pred_logits"].flatten(1, 2)
128 | out_masks = outputs["pred_masks"].flatten(1, 2)
129 | bs, num_queries = out_logits.shape[:2]
130 |
131 | prob = out_logits.sigmoid()
132 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True)
133 | scores = topk_values # [bs, num_queries]
134 | topk_boxes = topk_indexes // out_logits.shape[2] # [bs, num_queries]
135 | labels = topk_indexes % out_logits.shape[2] # [bs, num_queries]
136 |
137 | outputs_masks = [out_m[topk_boxes[i]].unsqueeze(0) for i, out_m, in enumerate(out_masks)] # list[Tensor]
138 | outputs_masks = torch.cat(outputs_masks, dim=0) # [bs, num_queries, H, W]
139 | out_h, out_w = outputs_masks.shape[-2:]
140 |
141 | # max_h, max_w = max_target_sizes.max(0)[0].tolist()
142 | # outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False)
143 | outputs_masks = F.interpolate(outputs_masks, size=(out_h*4, out_w*4), mode="bilinear", align_corners=False)
144 | outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()
145 |
146 | for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
147 | img_h, img_w = t[0], t[1]
148 | results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
149 | results[i]["masks"] = F.interpolate(
150 | results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
151 | ).byte()
152 |
153 | return results
154 |
155 |
156 |
157 | def build_postprocessors(args, dataset_name):
158 | if dataset_name == 'a2d' or dataset_name == 'jhmdb':
159 | postprocessors = A2DSentencesPostProcess(threshold=args.threshold)
160 | else:
161 | # for coco pretrain postprocessor
162 | postprocessors: Dict[str, nn.Module] = {"bbox": PostProcess()}
163 | if args.masks:
164 | postprocessors["segm"] = PostProcessSegm(threshold=args.threshold)
165 | # postprocessors = PostProcessSegm(threshold=args.threshold)
166 | return postprocessors
167 |
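
A minimal sketch of feeding `A2DSentencesPostProcess` randomly generated outputs, only to illustrate the expected shapes (one valid frame, five queries); the sizes are made up and chosen to fit within the 4x upsampled mask resolution noted in the docstring:

```
import torch
from models.postprocessors import A2DSentencesPostProcess

B, T, Nq, out_h, out_w = 2, 1, 5, 64, 96
outputs = {
    'pred_logits': torch.rand(B, T, Nq, 1),
    'pred_masks':  torch.rand(B, T, Nq, out_h, out_w),
}
max_target_sizes  = torch.tensor([[256, 384], [240, 360]])   # resized (padded) input sizes
orig_target_sizes = torch.tensor([[320, 480], [300, 450]])   # original frame sizes

post = A2DSentencesPostProcess(threshold=0.5)
preds = post(outputs, orig_target_sizes, max_target_sizes)
print(preds[0]['scores'].shape, preds[0]['masks'].shape)     # torch.Size([5]) torch.Size([5, 1, 320, 480])
```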
--------------------------------------------------------------------------------
/models/vector_quantitizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class VectorQuantizerEMA(nn.Module):
6 | def __init__(self, num_embeddings=512, embedding_dim=256, commitment_cost=0.25, decay=0.99, epsilon=1e-5):
7 | super(VectorQuantizerEMA, self).__init__()
8 |
9 | self._embedding_dim = embedding_dim
10 | self._num_embeddings = num_embeddings
11 |
12 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
13 | self._embedding.weight.data.normal_()
14 | self._commitment_cost = commitment_cost
15 |
16 | self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
17 | self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
18 | self._ema_w.data.normal_()
19 |
20 | self._decay = decay
21 | self._epsilon = epsilon
22 |
23 | def forward(self, inputs):
24 | # inputs btqc
25 | input_shape = inputs.shape
26 |
27 | # Flatten input
28 | flat_input = inputs.view(-1, self._embedding_dim)
29 |
30 | # Calculate distances
31 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
32 | + torch.sum(self._embedding.weight ** 2, dim=1)
33 | - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
34 |
35 | # Encoding
36 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
37 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
38 | encodings.scatter_(1, encoding_indices, 1)
39 |
40 | # Quantize and unflatten
41 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
42 |
43 | # Use EMA to update the embedding vectors
44 | if self.training:
45 | self._ema_cluster_size = self._ema_cluster_size * self._decay + \
46 | (1 - self._decay) * torch.sum(encodings, 0)
47 |
48 | # Laplace smoothing of the cluster size
49 | n = torch.sum(self._ema_cluster_size.data)
50 | self._ema_cluster_size = (
51 | (self._ema_cluster_size + self._epsilon)
52 | / (n + self._num_embeddings * self._epsilon) * n)
53 |
54 | dw = torch.matmul(encodings.t(), flat_input)
55 | self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
56 |
57 | self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
58 |
59 | # Loss
60 | e_latent_loss = F.mse_loss(quantized.detach(), inputs)
61 | loss = self._commitment_cost * e_latent_loss
62 |
63 | # Straight Through Estimator
64 | quantized = inputs + (quantized - inputs).detach()
65 | avg_probs = torch.mean(encodings, dim=0)
66 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
67 |
68 | return loss, quantized, perplexity, encodings
69 |
70 |
71 | class VectorQuantizer(nn.Module):
72 | def __init__(self, num_embeddings=512, embedding_dim=256, commitment_cost=0.25):
73 | super(VectorQuantizer, self).__init__()
74 |
75 | self._embedding_dim = embedding_dim
76 | self._num_embeddings = num_embeddings
77 |
78 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
79 | self._embedding.weight.data.normal_()
80 | self._commitment_cost = commitment_cost
81 |
82 | def forward(self, inputs):
83 | input_shape = inputs.shape
84 |
85 | # Flatten input
86 | flat_input = inputs.view(-1, self._embedding_dim)
87 |
88 | # Calculate distances
89 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True)
90 | + torch.sum(self._embedding.weight ** 2, dim=1)
91 | - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
92 |
93 | # Encoding
94 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
95 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
96 | encodings.scatter_(1, encoding_indices, 1)
97 |
98 | # Quantize and unflatten
99 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
100 |
101 | # Loss
102 | e_latent_loss = F.mse_loss(quantized.detach(), inputs)
103 | q_latent_loss = F.mse_loss(quantized, inputs.detach())
104 | loss = q_latent_loss + self._commitment_cost * e_latent_loss
105 |
106 | quantized = inputs + (quantized - inputs).detach()
107 | avg_probs = torch.mean(encodings, dim=0)
108 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
109 |
110 | return loss, quantized, perplexity, encodings
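
A small sketch of quantizing a batch of queries with `VectorQuantizer` (the EMA variant exposes the same interface); the `[B, T, Q, C]` layout follows the `# inputs btqc` comment above, and the sizes are illustrative:

```
import torch
from models.vector_quantitizer import VectorQuantizer

vq = VectorQuantizer(num_embeddings=512, embedding_dim=256, commitment_cost=0.25)

B, T, Q, C = 2, 5, 5, 256          # batch, frames, queries, channels (illustrative)
queries = torch.rand(B, T, Q, C, requires_grad=True)

loss, quantized, perplexity, encodings = vq(queries)
print(quantized.shape, float(loss), float(perplexity))
# quantized keeps the input shape; the straight-through estimator lets gradients
# flow back to `queries` even though the codebook lookup itself is non-differentiable.
```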
--------------------------------------------------------------------------------
/opts.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | def get_args_parser():
4 | parser = argparse.ArgumentParser('ReferFormer training and inference scripts.', add_help=False)
5 | parser.add_argument('--lr', default=1e-4, type=float)
6 | parser.add_argument('--lr_backbone', default=5e-5, type=float)
7 | parser.add_argument('--lr_backbone_names', default=['backbone.0'], type=str, nargs='+')
8 | parser.add_argument('--lr_text_encoder', default=1e-5, type=float)
9 | parser.add_argument('--lr_text_encoder_names', default=['text_encoder'], type=str, nargs='+')
10 | parser.add_argument('--lr_linear_proj_names', default=['reference_points', 'sampling_offsets'], type=str, nargs='+')
11 | parser.add_argument('--lr_linear_proj_mult', default=1.0, type=float)
12 | parser.add_argument('--lr_multi', default=1.0, type=float)
13 | parser.add_argument('--batch_size', default=1, type=int)
14 | parser.add_argument('--weight_decay', default=5e-4, type=float)
15 | parser.add_argument('--epochs', default=12, type=int)
16 | parser.add_argument('--lr_drop', default=[8, 10], type=int, nargs='+')
17 | parser.add_argument('--clip_max_norm', default=0.1, type=float,
18 | help='gradient clipping max norm')
19 |
20 | # Model parameters
21 | # load the pretrained weights
22 | parser.add_argument('--pretrained_weights', type=str, default=None,
23 | help="Path to the pretrained model.")
24 |
25 | # Variants of Deformable DETR
26 | parser.add_argument('--with_box_refine', default=False, action='store_true')
27 | parser.add_argument('--two_stage', default=False, action='store_true') # NOTE: must be false
28 |
29 | # * Backbone
30 | # ["resnet50", "resnet101", "swin_t_p4w7", "swin_s_p4w7", "swin_b_p4w7", "swin_l_p4w7"]
31 | # ["video_swin_t_p4w7", "video_swin_s_p4w7", "video_swin_b_p4w7"]
32 | parser.add_argument('--backbone', default='resnet50', type=str,
33 | help="Name of the convolutional backbone to use")
34 | parser.add_argument('--backbone_pretrained', default=None, type=str,
35 | help="if use swin backbone and train from scratch, the path to the pretrained weights")
36 | parser.add_argument('--use_checkpoint', action='store_true', help='whether use checkpoint for swin/video swin backbone')
37 | parser.add_argument('--dilation', action='store_true', # DC5
38 | help="If true, we replace stride with dilation in the last convolutional block (DC5)")
39 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
40 | help="Type of positional embedding to use on top of the image features")
41 | parser.add_argument('--num_feature_levels', default=4, type=int, help='number of feature levels')
42 |
43 | # * Transformer
44 | parser.add_argument('--enc_layers', default=4, type=int,
45 | help="Number of encoding layers in the transformer")
46 | parser.add_argument('--dec_layers', default=4, type=int,
47 | help="Number of decoding layers in the transformer")
48 | parser.add_argument('--dim_feedforward', default=2048, type=int,
49 | help="Intermediate size of the feedforward layers in the transformer blocks")
50 | parser.add_argument('--hidden_dim', default=256, type=int,
51 | help="Size of the embeddings (dimension of the transformer)")
52 | parser.add_argument('--dropout', default=0.1, type=float,
53 | help="Dropout applied in the transformer")
54 | parser.add_argument('--nheads', default=8, type=int,
55 | help="Number of attention heads inside the transformer's attentions")
56 | parser.add_argument('--num_frames', default=5, type=int,
57 | help="Number of clip frames for training")
58 | parser.add_argument('--num_queries', default=5, type=int,
59 | help="Number of query slots, all frames share the same queries")
60 | parser.add_argument('--dec_n_points', default=4, type=int)
61 | parser.add_argument('--enc_n_points', default=4, type=int)
62 | parser.add_argument('--pre_norm', action='store_true')
63 | # for text
64 | parser.add_argument('--freeze_text_encoder', action='store_true') # default: False
65 |
66 | # * Segmentation
67 | parser.add_argument('--masks', action='store_true',
68 | help="Train segmentation head if the flag is provided")
69 | parser.add_argument('--mask_dim', default=256, type=int,
70 | help="Size of the mask embeddings (dimension of the dynamic mask conv)")
71 | parser.add_argument('--controller_layers', default=3, type=int,
72 | help="Dynamic conv layer number")
73 | parser.add_argument('--dynamic_mask_channels', default=8, type=int,
74 | help="Dynamic conv final channel number")
75 | parser.add_argument('--no_rel_coord', dest='rel_coord', action='store_false',
76 | help="Disables relative coordinates")
77 |
78 | # Loss
79 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
80 | help="Disables auxiliary decoding losses (loss at each layer)")
81 | # * Matcher
82 | parser.add_argument('--set_cost_class', default=2, type=float,
83 | help="Class coefficient in the matching cost")
84 | parser.add_argument('--set_cost_bbox', default=5, type=float,
85 | help="L1 box coefficient in the matching cost")
86 | parser.add_argument('--set_cost_giou', default=2, type=float,
87 | help="giou box coefficient in the matching cost")
88 | parser.add_argument('--set_cost_mask', default=2, type=float,
89 | help="mask coefficient in the matching cost")
90 | parser.add_argument('--set_cost_dice', default=5, type=float,
91 | help="mask coefficient in the matching cost")
92 | # * Loss coefficients
93 | parser.add_argument('--mask_loss_coef', default=2, type=float)
94 | parser.add_argument('--dice_loss_coef', default=5, type=float)
95 | parser.add_argument('--cls_loss_coef', default=2, type=float)
96 | parser.add_argument('--bbox_loss_coef', default=5, type=float)
97 | parser.add_argument('--giou_loss_coef', default=2, type=float)
98 | parser.add_argument('--eos_coef', default=0.1, type=float,
99 | help="Relative classification weight of the no-object class")
100 | parser.add_argument('--focal_alpha', default=0.25, type=float)
101 |
102 | # dataset parameters
103 | # ['ytvos', 'davis', 'a2d', 'jhmdb', 'refcoco', 'refcoco+', 'refcocog', 'all']
104 | # 'all': using the three ref datasets for pretraining
105 | parser.add_argument('--dataset_file', default='ytvos', help='Dataset name')
106 | parser.add_argument('--coco_path', type=str, default='data/coco')
107 | parser.add_argument('--ytvos_path', type=str, default='data/ref-youtube-vos')
108 | parser.add_argument('--davis_path', type=str, default='data/ref-davis')
109 | parser.add_argument('--a2d_path', type=str, default='data/a2d_sentences')
110 | parser.add_argument('--jhmdb_path', type=str, default='/mnt/data/jhmdb')
111 | parser.add_argument('--max_skip', default=3, type=int, help="max skip frame number")
112 | parser.add_argument('--max_size', default=640, type=int, help="max size for the frame")
113 | parser.add_argument('--binary', action='store_true')
114 | parser.add_argument('--remove_difficult', action='store_true')
115 |
116 | parser.add_argument('--output_dir', default='output',
117 | help='path where to save, empty for no saving')
118 | parser.add_argument('--device', default='cuda',
119 | help='device to use for training / testing')
120 | parser.add_argument('--seed', default=42, type=int)
121 | parser.add_argument('--resume', default='', help='resume from checkpoint')
122 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
123 | help='start epoch')
124 | parser.add_argument('--eval', action='store_true')
125 | parser.add_argument('--num_workers', default=4, type=int)
126 |
127 | # test setting
128 | parser.add_argument('--threshold', default=0.5, type=float) # binary threshold for mask
129 | parser.add_argument('--ngpu', default=8, type=int, help='gpu number when inference for ref-ytvos and ref-davis')
130 | parser.add_argument('--split', default='valid', type=str, choices=['valid', 'test'])
131 | parser.add_argument('--visualize', action='store_true', help='whether visualize the masks during inference')
132 |
133 | # distributed training parameters
134 | parser.add_argument('--world_size', default=1, type=int,
135 | help='number of distributed processes')
136 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
137 | parser.add_argument('--cache_mode', default=False, action='store_true', help='whether to cache images on memory')
138 |
139 | # additional parameters
140 | parser.add_argument('--fpn_type', default='dual', help='fpn type can be dual, dyn, and default')
141 | parser.add_argument('--val_type', default='', help='validation dataset type for youtube rvos')
142 | parser.add_argument('--as_vos', default=False, help='as semi-supervised vos')
143 | parser.add_argument('--query_feat_dim', default=2048, help='feat_dim of 1/32 visual feature map')
144 | parser.add_argument('--inf_res', default=360, type=int, help='inference size')
145 |     parser.add_argument('--text_enc_type', default='distilroberta-base', help='text encoder type')
146 | parser.add_argument('--use_cycle', action='store_true', help='use cycle consistency')
147 | parser.add_argument('--add_negative', action='store_true', help='add negative sample on gpu 0 for triplet loss')
148 | parser.add_argument('--only_cycle', action='store_true', help='only train cycle consistency part model')
149 | parser.add_argument('--cycle_loss_dist_coef', default=1, type=float)
150 | parser.add_argument('--cycle_loss_angle_coef', default=1, type=float)
151 | parser.add_argument('--cycle_loss_mse_coef', default=0.0, type=float)
152 | parser.add_argument('--cycle_loss_cls_coef', default=1, type=float)
153 | parser.add_argument('--fg_contra_loss_coef', default=1, type=float)
154 | parser.add_argument('--VQ_loss_coef', default=0.5, type=float)
155 | parser.add_argument('--cycle_loss_contrastive_coef', default=0.1, type=float)
156 | parser.add_argument('--loc_loss_coef', default=3, type=float)
157 | parser.add_argument('--lr_anchor_names', default=['negative_anchor'], type=str, nargs='+')
158 | parser.add_argument('--lr_anchor_mult', default=0.1, type=float)
159 | parser.add_argument('--contra_margin', default=0.5, type=float)
160 | parser.add_argument('--is_eval', action='store_true', help='set when running evaluation')
161 | parser.add_argument('--neg_cls', action='store_true', help='add a classifier to identify negative samples')
162 | parser.add_argument('--bert_cycle', action='store_true', help='use the 768-dim BERT output as the positive ground truth')
163 | parser.add_argument('--mix_query', action='store_true', help='mix pseudo-text and text queries fed to the deformable transformer')
164 | parser.add_argument('--quantitize_query', action='store_true', help='quantitize text query')
165 | parser.add_argument('--use_fg_contra', action='store_true', help='use fg contra loss')
166 | parser.add_argument('--freeze_quantitizer', action='store_true', help='freeze quantitizer')
167 | parser.add_argument('--pseudo_label_path', default='', help='pseudo label path')
168 | parser.add_argument('--use_cls', action='store_true', help='use neg cls to filter out negative videos')
169 | parser.add_argument('--use_score', action='store_true', help='use score to filter out negative videos')
170 | parser.add_argument('--save_prob', action='store_true', help='save prob map')
171 | parser.add_argument('--segm_frame', default=5, type=int)
172 | parser.add_argument('--demo_exp', default='a big track on the road', help='referring expression for the demo')
173 | parser.add_argument('--demo_path', default='demo/demo_examples', help='demo frames folder path')
174 | return parser
175 |
176 |
177 |
--------------------------------------------------------------------------------
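The options above are standard argparse flags, so they can also be constructed and parsed programmatically. The sketch below is only an illustration: the `build_demo_args` helper and the reduced flag set are assumptions, not the repo's actual parser entry point.

```
import argparse

def build_demo_args():
    # Hypothetical helper mirroring a small subset of the flags defined in opts.py.
    parser = argparse.ArgumentParser('R2VOS demo options', add_help=False)
    parser.add_argument('--backbone', default='resnet50')
    parser.add_argument('--use_cycle', action='store_true', help='use cycle consistency')
    parser.add_argument('--mix_query', action='store_true', help='mix pseudo-text and text queries')
    parser.add_argument('--demo_exp', default='a big track on the road', help='referring expression for the demo')
    parser.add_argument('--demo_path', default='demo/demo_examples', help='folder containing the demo frames')
    return parser.parse_args(['--use_cycle', '--mix_query'])

if __name__ == '__main__':
    args = build_demo_args()
    print(args.backbone, args.use_cycle, args.demo_exp)
```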
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.12.5
2 | cython
3 | scipy
4 | opencv-python
5 | pillow
6 | scikit-image
7 | timm
8 | einops
9 | pandas
10 | imgaug
11 | h5py
12 | av
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/tools/__init__.py
--------------------------------------------------------------------------------
/tools/colormap.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def colormap(rgb=False):
5 | color_list = np.array(
6 | [
7 | 0.000, 0.447, 0.741,
8 | 0.850, 0.325, 0.098,
9 | 0.929, 0.694, 0.125,
10 | 0.494, 0.184, 0.556,
11 | 0.466, 0.674, 0.188,
12 | 0.301, 0.745, 0.933,
13 | 0.635, 0.078, 0.184,
14 | 0.300, 0.300, 0.300,
15 | 0.600, 0.600, 0.600,
16 | 1.000, 0.000, 0.000,
17 | 1.000, 0.500, 0.000,
18 | 0.749, 0.749, 0.000,
19 | 0.000, 1.000, 0.000,
20 | 0.000, 0.000, 1.000,
21 | 0.667, 0.000, 1.000,
22 | 0.333, 0.333, 0.000,
23 | 0.333, 0.667, 0.000,
24 | 0.333, 1.000, 0.000,
25 | 0.667, 0.333, 0.000,
26 | 0.667, 0.667, 0.000,
27 | 0.667, 1.000, 0.000,
28 | 1.000, 0.333, 0.000,
29 | 1.000, 0.667, 0.000,
30 | 1.000, 1.000, 0.000,
31 | 0.000, 0.333, 0.500,
32 | 0.000, 0.667, 0.500,
33 | 0.000, 1.000, 0.500,
34 | 0.333, 0.000, 0.500,
35 | 0.333, 0.333, 0.500,
36 | 0.333, 0.667, 0.500,
37 | 0.333, 1.000, 0.500,
38 | 0.667, 0.000, 0.500,
39 | 0.667, 0.333, 0.500,
40 | 0.667, 0.667, 0.500,
41 | 0.667, 1.000, 0.500,
42 | 1.000, 0.000, 0.500,
43 | 1.000, 0.333, 0.500,
44 | 1.000, 0.667, 0.500,
45 | 1.000, 1.000, 0.500,
46 | 0.000, 0.333, 1.000,
47 | 0.000, 0.667, 1.000,
48 | 0.000, 1.000, 1.000,
49 | 0.333, 0.000, 1.000,
50 | 0.333, 0.333, 1.000,
51 | 0.333, 0.667, 1.000,
52 | 0.333, 1.000, 1.000,
53 | 0.667, 0.000, 1.000,
54 | 0.667, 0.333, 1.000,
55 | 0.667, 0.667, 1.000,
56 | 0.667, 1.000, 1.000,
57 | 1.000, 0.000, 1.000,
58 | 1.000, 0.333, 1.000,
59 | 1.000, 0.667, 1.000,
60 | 0.167, 0.000, 0.000,
61 | 0.333, 0.000, 0.000,
62 | 0.500, 0.000, 0.000,
63 | 0.667, 0.000, 0.000,
64 | 0.833, 0.000, 0.000,
65 | 1.000, 0.000, 0.000,
66 | 0.000, 0.167, 0.000,
67 | 0.000, 0.333, 0.000,
68 | 0.000, 0.500, 0.000,
69 | 0.000, 0.667, 0.000,
70 | 0.000, 0.833, 0.000,
71 | 0.000, 1.000, 0.000,
72 | 0.000, 0.000, 0.167,
73 | 0.000, 0.000, 0.333,
74 | 0.000, 0.000, 0.500,
75 | 0.000, 0.000, 0.667,
76 | 0.000, 0.000, 0.833,
77 | 0.000, 0.000, 1.000,
78 | 0.000, 0.000, 0.000,
79 | 0.143, 0.143, 0.143,
80 | 0.286, 0.286, 0.286,
81 | 0.429, 0.429, 0.429,
82 | 0.571, 0.571, 0.571,
83 | 0.714, 0.714, 0.714,
84 | 0.857, 0.857, 0.857,
85 | 1.000, 1.000, 1.000
86 | ]
87 | ).astype(np.float32)
88 | color_list = color_list.reshape((-1, 3)) * 255
89 | if not rgb:
90 | color_list = color_list[:, ::-1]
91 | return color_list
--------------------------------------------------------------------------------
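The palette above comes back as an (N, 3) float array scaled to [0, 255]; the channel order is reversed for OpenCV-style BGR unless `rgb=True`. The helper below is a hypothetical sketch of how it could be used to color per-instance masks (the `overlay_masks` function is an assumption, not part of the repo):

```
import numpy as np
from tools.colormap import colormap

def overlay_masks(image, masks, alpha=0.5, rgb=False):
    # Blend binary masks of shape (num_instances, H, W) onto an HxWx3 uint8 image.
    colors = colormap(rgb=rgb)          # (num_colors, 3) palette in [0, 255]
    out = image.astype(np.float32)
    for idx, mask in enumerate(masks):
        m = mask.astype(bool)
        out[m] = (1 - alpha) * out[m] + alpha * colors[idx % len(colors)]
    return out.astype(np.uint8)
```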
/tools/data/convert_davis_to_ytvos.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 |
4 | """
5 | ytvos structure
6 | - train
7 | - Annotations
8 | - video1
9 | - video2
10 | - JPEGImages
11 | - video1
12 | - video2
13 | meta.json
14 | - valid
15 | - Annotations
16 | - JPEGImages
17 | meta.json
18 | - meta_expressions
19 | - train
20 | meta_expressions.json
21 | - valid
22 | meta_expressions.json
23 | """
24 |
25 | def read_split_set(data_root='data/ref-davis'):
26 | set_split_path = os.path.join(data_root, "DAVIS/ImageSets/2017")
27 | # train set
28 | with open(os.path.join(set_split_path, "train.txt"), "r") as f:
29 | train_set = f.readlines()
30 | train_set = [x.strip() for x in train_set] # 60 videos
31 | # val set
32 | with open(os.path.join(set_split_path, "val.txt"), "r") as f:
33 | val_set = f.readlines()
34 | val_set = [x.strip() for x in val_set] # 30 videos
35 | return train_set, val_set # List
36 |
37 |
38 | def mv_images_to_folder(data_root='data/ref-davis', output_root='data/ref-davis'):
39 | train_img_path = os.path.join(output_root, "train/JPEGImages")
40 | train_anno_path = os.path.join(output_root, "train/Annotations")
41 | val_img_path = os.path.join(output_root, "valid/JPEGImages")
42 | val_anno_path = os.path.join(output_root, "valid/Annotations")
43 | meta_train_path = os.path.join(output_root, "meta_expressions/train")
44 | meta_val_path = os.path.join(output_root, "meta_expressions/valid")
45 | paths = [train_img_path, train_anno_path, val_img_path, val_anno_path,
46 | meta_train_path, meta_val_path]
47 | for path in paths:
48 | if not os.path.exists(path):
49 | os.makedirs(path)
50 |
51 | # 1. read the train/val split
52 | train_set, val_set = read_split_set(data_root)
53 |
54 | # 2. move images and annotations
55 | # train set
56 | for video in train_set:
57 | # move images
58 | base_img_path = os.path.join(data_root, "DAVIS/JPEGImages/480p", video)
59 | mv_cmd = f"mv {base_img_path} {train_img_path}"
60 | os.system(mv_cmd)
61 | # move annotations
62 | base_anno_path = os.path.join(data_root, "DAVIS/Annotations_unsupervised/480p", video)
63 | mv_cmd = f"mv {base_anno_path} {train_anno_path}"
64 | os.system(mv_cmd)
65 | # val set
66 | for video in val_set:
67 | # move images
68 | base_img_path = os.path.join(data_root, "DAVIS/JPEGImages/480p", video)
69 | mv_cmd = f"mv {base_img_path} {val_img_path}"
70 | os.system(mv_cmd)
71 | # move annotations
72 | base_anno_path = os.path.join(data_root, "DAVIS/Annotations_unsupervised/480p", video)
73 | mv_cmd = f"mv {base_anno_path} {val_anno_path}"
74 | os.system(mv_cmd)
75 |
76 | def create_meta_expressions(data_root='data/ref-davis', output_root='data/ref-davis'):
77 | """
78 | NOTE: for each object, the four expressions alternate annotators and styles: exp_ids 0/2 are first-frame annotations (annotator 1/2), exp_ids 1/3 are full-video annotations
79 | meta_expression.json format
80 | {
81 | "videos": {
82 | "video1: {
83 | "expressions": {
84 | "0": {
85 | "exp": "xxxxx",
86 | "obj_id": "1" (start from 1)
87 | }
88 | "1": {
89 | "exp": "xxxxx",
90 | "obj_id": "1"
91 | }
92 | }
93 | "frames": [
94 | "00000",
95 | "00001",
96 | ...
97 | ]
98 | }
99 | }
100 | }
101 | """
102 | train_img_path = os.path.join(output_root, "train/JPEGImages")
103 | val_img_path = os.path.join(output_root, "valid/JPEGImages")
104 | meta_train_path = os.path.join(output_root, "meta_expressions/train")
105 | meta_val_path = os.path.join(output_root, "meta_expressions/valid")
106 |
107 | # 1. read the train/val split
108 | train_set, val_set = read_split_set(data_root)
109 |
110 | # 2. create meta_expression.json
111 | # NOTE: there are two annotators, and each annotator provides a first-frame annotation and a full-video annotation
112 | def read_expressions_from_txt(file_path, encoding='utf-8'):
113 | """
114 | videos["video1"] = [
115 | {"obj_id": 1, "exp": "xxxxx"},
116 | {"obj_id": 2, "exp": "xxxxx"},
117 | {"obj_id": 3, "exp": "xxxxx"},
118 | ]
119 | """
120 | videos = {}
121 | with open(file_path, "r", encoding=encoding) as f:
122 | for idx, line in enumerate(f.readlines()):
123 | line = line.strip()
124 | video_name, obj_id = line.split()[:2]
125 | exp = ' '.join(line.split()[2:])[1:-1]
126 | # handle bad case
127 | if video_name == "clasic-car":
128 | video_name = "classic-car"
129 | elif video_name == "dog-scale":
130 | video_name = "dogs-scale"
131 | elif video_name == "motor-bike":
132 | video_name = "motorbike"
133 |
134 |
135 | if not video_name in videos.keys():
136 | videos[video_name] = []
137 | exp_dict = {
138 | "exp": exp,
139 | "obj_id": obj_id
140 | }
141 | videos[video_name].append(exp_dict)
142 |
143 | # sort the order of expressions in each video
144 | for key, value in videos.items():
145 | value = sorted(value, key = lambda e:e.__getitem__('obj_id'))
146 | videos[key] = value
147 | return videos
148 |
149 | anno1_first_path = os.path.join(data_root, "davis_text_annotations/Davis17_annot1.txt")
150 | anno1_full_path = os.path.join(data_root, "davis_text_annotations/Davis17_annot1_full_video.txt")
151 | anno2_first_path = os.path.join(data_root, "davis_text_annotations/Davis17_annot2.txt")
152 | anno2_full_path = os.path.join(data_root, "davis_text_annotations/Davis17_annot2_full_video.txt")
153 | # all videos information
154 | anno1_first = read_expressions_from_txt(anno1_first_path, encoding='utf-8')
155 | anno1_full = read_expressions_from_txt(anno1_full_path, encoding='utf-8')
156 | anno2_first = read_expressions_from_txt(anno2_first_path, encoding='latin-1')
157 | anno2_full = read_expressions_from_txt(anno2_full_path, encoding='latin-1')
158 |
159 | # 2(1). train
160 | train_videos = {} # {"video1": {}, "video2": {}, ...}, the final results to dump
161 | for video in train_set: # 60 videos
162 | video_dict = {} # for each video
163 |
164 | # store the information of video
165 | expressions = {}
166 | exp_id = 0 # start from 0
167 | for anno1_first_video, anno1_full_video, anno2_first_video, anno2_full_video in zip(
168 | anno1_first[video], anno1_full[video], anno2_first[video], anno2_full[video]):
169 | expressions[str(exp_id)] = anno1_first_video
170 | exp_id += 1
171 | expressions[str(exp_id)] = anno1_full_video
172 | exp_id += 1
173 | expressions[str(exp_id)] = anno2_first_video
174 | exp_id += 1
175 | expressions[str(exp_id)] = anno2_full_video
176 | exp_id += 1
177 | video_dict["expressions"] = expressions
178 | # read frame names for each video
179 | video_frames = os.listdir(os.path.join(train_img_path, video))
180 | video_frames = [x.split(".")[0] for x in video_frames] # remove ".jpg"
181 | video_frames.sort()
182 | video_dict["frames"] = video_frames
183 |
184 | train_videos[video] = video_dict
185 |
186 | # 2(2). val
187 | val_videos = {}
188 | for video in val_set:
189 | video_dict = {} # for each video
190 |
191 | # store the information of video
192 | expressions = {}
193 | exp_id = 0 # start from 0
194 | for anno1_first_video, anno1_full_video, anno2_first_video, anno2_full_video in zip(
195 | anno1_first[video], anno1_full[video], anno2_first[video], anno2_full[video]):
196 | expressions[str(exp_id)] = anno1_first_video
197 | exp_id += 1
198 | expressions[str(exp_id)] = anno1_full_video
199 | exp_id += 1
200 | expressions[str(exp_id)] = anno2_first_video
201 | exp_id += 1
202 | expressions[str(exp_id)] = anno2_full_video
203 | exp_id += 1
204 | video_dict["expressions"] = expressions
205 | # read frame names for each video
206 | video_frames = os.listdir(os.path.join(val_img_path, video))
207 | video_frames = [x.split(".")[0] for x in video_frames] # remove ".jpg"
208 | video_frames.sort()
209 | video_dict["frames"] = video_frames
210 |
211 | val_videos[video] = video_dict
212 |
213 | # 3. store the meta_expressions.json
214 | # train
215 | train_meta = {"videos": train_videos}
216 | with open(os.path.join(meta_train_path, "meta_expressions.json"), "w") as out:
217 | json.dump(train_meta, out)
218 | # val
219 | val_meta = {"videos": val_videos}
220 | with open(os.path.join(meta_val_path, "meta_expressions.json"), "w") as out:
221 | json.dump(val_meta, out)
222 |
223 | def create_meta_annotaions(data_root='data/ref-davis', output_root='data/ref-davis'):
224 | """
225 | NOTE: frame names are not stored compared with ytvos
226 | meta.json format
227 | {
228 | "videos": {
229 | "video1: {
230 | "objects": {
231 | "1": {"category": "bike"},
232 | "2": {"category": "person"}
233 | }
234 | }
235 | }
236 | }
237 | """
238 | out_train_path = os.path.join(output_root, "train")
239 | out_val_path = os.path.join(output_root, "valid")
240 |
241 | # read the semantic information
242 | with open(os.path.join(data_root, "DAVIS/davis_semantics.json")) as f:
243 | davis_semantics = json.load(f)
244 |
245 | # 1. read the train/val split
246 | train_set, val_set = read_split_set(data_root)
247 |
248 | # 2. create meta.json
249 | # train
250 | train_videos = {}
251 | for video in train_set:
252 | video_dict = {} # for each video
253 | video_dict["objects"] = {}
254 | num_obj = len(davis_semantics[video].keys())
255 | for obj_id in range(1, num_obj+1): # start from 1
256 | video_dict["objects"][str(obj_id)] = {"category": davis_semantics[video][str(obj_id)]}
257 | train_videos[video] = video_dict
258 |
259 | # val
260 | val_videos = {}
261 | for video in val_set:
262 | video_dict = {}
263 | video_dict["objects"] = {}
264 | num_obj = len(davis_semantics[video].keys())
265 | for obj_id in range(1, num_obj+1): # start from 1
266 | video_dict["objects"][str(obj_id)] = {"category": davis_semantics[video][str(obj_id)]}
267 | val_videos[video] = video_dict
268 |
269 | # store the meta.json file
270 | train_meta = {"videos": train_videos}
271 | with open(os.path.join(out_train_path, "meta.json"), "w") as out:
272 | json.dump(train_meta, out)
273 | val_meta = {"videos": val_videos}
274 | with open(os.path.join(out_val_path, "meta.json"), "w") as out:
275 | json.dump(val_meta, out)
276 |
277 | if __name__ == '__main__':
278 | data_root = "/mnt/data/ref-davis"
279 | output_root = "/mnt/data/ref-davis"
280 | print("Converting ref-davis to ref-youtube-vos format....")
281 | mv_images_to_folder(data_root, output_root)
282 | create_meta_expressions(data_root, output_root)
283 | create_meta_annotaions(data_root, output_root)
284 |
285 |
--------------------------------------------------------------------------------
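For reference, the meta_expressions.json written by the script can be consumed as sketched below; the path is a placeholder and the structure follows the docstring in create_meta_expressions:

```
import json
import os

data_root = 'data/ref-davis'  # placeholder path
meta_file = os.path.join(data_root, 'meta_expressions/valid/meta_expressions.json')
with open(meta_file) as f:
    meta = json.load(f)

for video_name, video in meta['videos'].items():
    frames = video['frames']
    for exp_id, exp in video['expressions'].items():
        # For each object, exp_ids alternate first-frame / full-video annotations
        # from the two annotators.
        print(video_name, exp_id, exp['obj_id'], exp['exp'], len(frames))
```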
/tools/data/convert_refexp_to_coco.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from datasets.refer import REFER
4 | import cv2
5 | from tqdm import tqdm
6 | import json
7 | import pickle
8 | import json
9 |
10 |
11 | def convert_to_coco(data_root='/mnt/data/refcoco', output_root='/mnt/data/refcoco_referformer', dataset='refcoco', dataset_split='unc'):
12 | dataset_dir = os.path.join(data_root, dataset)
13 | output_dir = os.path.join(output_root, dataset) # .json save path
14 | if not os.path.exists(output_dir):
15 | os.makedirs(output_dir)
16 |
17 | # read REFER
18 | refer = REFER(data_root, dataset, dataset_split)
19 | refs = refer.Refs
20 | anns = refer.Anns
21 | imgs = refer.Imgs
22 | cats = refer.Cats
23 | sents = refer.Sents
24 | """
25 | # create sets of mapping
26 | # 1) Refs: {ref_id: ref}
27 | # 2) Anns: {ann_id: ann}
28 | # 3) Imgs: {image_id: image}
29 | # 4) Cats: {category_id: category_name}
30 | # 5) Sents: {sent_id: sent}
31 | # 6) imgToRefs: {image_id: refs}
32 | # 7) imgToAnns: {image_id: anns}
33 | # 8) refToAnn: {ref_id: ann}
34 | # 9) annToRef: {ann_id: ref}
35 | # 10) catToRefs: {category_id: refs}
36 | # 11) sentToRef: {sent_id: ref}
37 | # 12) sentToTokens: {sent_id: tokens}
38 |
39 | Refs: List[Dict], "sent_ids", "file_name", "ann_id", "ref_id", "image_id", "category_id", "split", "sentences"
40 | "sentences": List[Dict], "tokens"(List), "raw", "sent_id", "sent"
41 | Anns: List[Dict], "segmentation", "area", "iscrowd", "image_id", "bbox", "category_id", "id"
42 | Imgs: List[Dict], "license", "file_name", "coco_url", "height", "width", "date_captured", "flickr_url", "id"
43 | Cats: List[Dict], "supercategory", "name", "id"
44 | Sents: List[Dict], "tokens"(List), "raw", "sent_id", "sent", here the "sent_id" is consistent
45 | """
46 | print('Dataset [%s_%s] contains: ' % (dataset, dataset_split))
47 | ref_ids = refer.getRefIds()
48 | image_ids = refer.getImgIds()
49 | print('There are %s expressions for %s referred objects in %s images.' % (
50 | len(refer.Sents), len(ref_ids), len(image_ids)))
51 |
52 | print('\nAmong them:')
53 | if dataset == 'refcoco':
54 | splits = ['train', 'val', 'testA', 'testB']
55 | elif dataset == 'refcoco+':
56 | splits = ['train', 'val', 'testA', 'testB']
57 | elif dataset == 'refcocog':
58 | splits = ['train', 'val', 'test'] # we don't have test split for refcocog right now.
59 |
60 | for split in splits:
61 | ref_ids = refer.getRefIds(split=split)
62 | print(' %s referred objects are in split [%s].' % (len(ref_ids), split))
63 |
64 | with open(os.path.join(dataset_dir, "instances.json"), "r") as f:
65 | ann_json = json.load(f)
66 |
67 | # 1. for each split: train, val...
68 | for split in splits:
69 | max_length = 0 # max length of a sentence
70 |
71 | coco_ann = {
72 | "info": "",
73 | "licenses": "",
74 | "images": [], # each caption is a image sample
75 | "annotations": [],
76 | "categories": []
77 | }
78 | coco_ann['info'], coco_ann['licenses'], coco_ann['categories'] = \
79 | ann_json['info'], ann_json['licenses'], ann_json['categories']
80 |
81 | num_images = 0 # each caption is a sample: create one "images" entry and one "annotations" entry, since each image has one box
82 | ref_ids = refer.getRefIds(split=split)
83 | # 2. for each referred object
84 | for i in tqdm(ref_ids):
85 | ref = refs[i]
86 | # "sent_ids", "file_name", "ann_id", "ref_id", "image_id", "category_id", "split", "sentences"
87 | # "sentences": List[Dict], "tokens"(List), "raw", "sent_id", "sent"
88 | img = imgs[ref["image_id"]]
89 | ann = anns[ref["ann_id"]]
90 |
91 | # 3. for each sentence, which is a sample
92 | for sentence in ref["sentences"]:
93 | num_images += 1
94 | # append image info
95 | image_info = {
96 | "file_name": img["file_name"],
97 | "height": img["height"],
98 | "width": img["width"],
99 | "original_id": img["id"],
100 | "id": num_images,
101 | "caption": sentence["sent"],
102 | "dataset_name": dataset
103 | }
104 | coco_ann["images"].append(image_info)
105 |
106 | # append annotation info
107 | ann_info = {
108 | "segmentation": ann["segmentation"],
109 | "area": ann["area"],
110 | "iscrowd": ann["iscrowd"],
111 | "bbox": ann["bbox"],
112 | "image_id": num_images,
113 | "category_id": ann["category_id"],
114 | "id": num_images,
115 | "original_id": ann["id"]
116 | }
117 | coco_ann["annotations"].append(ann_info)
118 |
119 | max_length = max(max_length, len(sentence["tokens"]))
120 |
121 | print("Total expression: {} in split {}".format(num_images, split))
122 | print("Max sentence length of the split: ", max_length)
123 | # save the json file
124 | save_file = "instances_{}_{}.json".format(dataset, split)
125 | with open(os.path.join(output_dir, save_file), 'w') as f:
126 | json.dump(coco_ann, f)
127 |
128 |
129 | if __name__ == '__main__':
130 | datasets = ["refcoco", "refcoco+", "refcocog"]
131 | datasets_split = ["unc", "unc", "umd"]
132 | for (dataset, dataset_split) in zip(datasets, datasets_split):
133 | convert_to_coco(dataset=dataset, dataset_split=dataset_split)
134 | print("")
135 |
136 | """
137 | # original mapping
138 | {'person': 1, 'bicycle': 2, 'car': 3, 'motorcycle': 4, 'airplane': 5, 'bus': 6, 'train': 7, 'truck': 8, 'boat': 9,
139 | 'traffic light': 10, 'fire hydrant': 11, 'stop sign': 13, 'parking meter': 14, 'bench': 15, 'bird': 16, 'cat': 17,
140 | 'dog': 18, 'horse': 19, 'sheep': 20, 'cow': 21, 'elephant': 22, 'bear': 23, 'zebra': 24, 'giraffe': 25, 'backpack': 27,
141 | 'umbrella': 28, 'handbag': 31, 'tie': 32, 'suitcase': 33, 'frisbee': 34, 'skis': 35, 'snowboard': 36, 'sports ball': 37,
142 | 'kite': 38, 'baseball bat': 39, 'baseball glove': 40, 'skateboard': 41, 'surfboard': 42, 'tennis racket': 43, 'bottle': 44,
143 | 'wine glass': 46, 'cup': 47, 'fork': 48, 'knife': 49, 'spoon': 50, 'bowl': 51, 'banana': 52, 'apple': 53, 'sandwich': 54,
144 | 'orange': 55, 'broccoli': 56, 'carrot': 57, 'hot dog': 58, 'pizza': 59, 'donut': 60, 'cake': 61, 'chair': 62, 'couch': 63,
145 | 'potted plant': 64, 'bed': 65, 'dining table': 67, 'toilet': 70, 'tv': 72, 'laptop': 73, 'mouse': 74, 'remote': 75,
146 | 'keyboard': 76, 'cell phone': 77, 'microwave': 78, 'oven': 79, 'toaster': 80, 'sink': 81, 'refrigerator': 82, 'book': 84,
147 | 'clock': 85, 'vase': 86, 'scissors': 87, 'teddy bear': 88, 'hair drier': 89, 'toothbrush': 90}
148 |
149 | """
--------------------------------------------------------------------------------
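Since each converted split follows the standard COCO layout, it can be sanity-checked with pycocotools. The snippet below is a sketch assuming the script's default output root and the refcoco/val split; adjust the paths to your setup:

```
import os
from pycocotools.coco import COCO

output_root = '/mnt/data/refcoco_referformer'  # default used by convert_to_coco
ann_file = os.path.join(output_root, 'refcoco', 'instances_refcoco_val.json')

coco = COCO(ann_file)
img_ids = coco.getImgIds()
print('samples (one per expression):', len(img_ids))

# Each image entry carries the referring expression in its "caption" field.
first = coco.loadImgs(img_ids[0])[0]
print(first['file_name'], '->', first['caption'])
```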
/tools/load_pretrained_weights.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 |
4 | def pre_trained_model_to_finetune(checkpoint, args):
5 | if 'model' in checkpoint.keys():
6 | checkpoint = checkpoint['model']
7 | # only delete the class_embed since the finetuned dataset has different num_classes
8 | num_layers = args.dec_layers + 1 if args.two_stage else args.dec_layers
9 | for l in range(num_layers):
10 | if "class_embed.{}.weight" in checkpoint.keys():
11 | del checkpoint["class_embed.{}.weight".format(l)]
12 | del checkpoint["class_embed.{}.bias".format(l)]
13 | # # determine backbone.0
14 | # flag = 0
15 | # for key in checkpoint.keys():
16 | # if 'backbone' in key:
17 | # flag = 1
18 | # if flag == 0:
19 | # new_ckpt = OrderedDict()
20 | # for k, v in checkpoint.items():
21 | # if 'patch_embed' in k or 'attn.relative_position_' in k:
22 | # continue
23 | # new_ckpt['backbone.0.body.' + k] = v
24 | # checkpoint = new_ckpt
25 |
26 | else:
27 | checkpoint = checkpoint['state_dict']
28 | new_ckpt = OrderedDict()
29 | for k, v in checkpoint.items():
30 | if 'patch_embed' in k:
31 | continue
32 | new_ckpt[k.replace('backbone', 'backbone.0.body')] = v
33 | checkpoint = new_ckpt
34 | return checkpoint
35 |
--------------------------------------------------------------------------------
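A minimal sketch of exercising this helper; the tiny in-memory checkpoint and the argument values are assumptions for illustration (a real checkpoint would come from torch.load):

```
import argparse
import torch
from tools.load_pretrained_weights import pre_trained_model_to_finetune

# Tiny fake checkpoint standing in for a real pretrained state dict.
fake_ckpt = {'model': {
    'class_embed.0.weight': torch.zeros(1, 4),
    'class_embed.0.bias': torch.zeros(1),
    'backbone.0.body.conv1.weight': torch.zeros(64, 3, 7, 7),
}}
args = argparse.Namespace(dec_layers=1, two_stage=False)  # assumed values

state_dict = pre_trained_model_to_finetune(fake_ckpt, args)
print(sorted(state_dict.keys()))  # the class_embed.* entries have been removed
```

Loading the result with `model.load_state_dict(state_dict, strict=False)` then tolerates the deleted classification heads.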
/tools/warmup_poly_lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | from typing import List
4 |
5 | def _get_warmup_factor_at_iter(
6 | method: str, iter: int, warmup_iters: int, warmup_factor: float
7 | ) -> float:
8 | """
9 | Return the learning rate warmup factor at a specific iteration.
10 | See :paper:`ImageNet in 1h` for more details.
11 | Args:
12 | method (str): warmup method; either "constant" or "linear".
13 | iter (int): iteration at which to calculate the warmup factor.
14 | warmup_iters (int): the number of warmup iterations.
15 | warmup_factor (float): the base warmup factor (the meaning changes according
16 | to the method used).
17 | Returns:
18 | float: the effective warmup factor at the given iteration.
19 | """
20 | if iter >= warmup_iters:
21 | return 1.0
22 |
23 | if method == "constant":
24 | return warmup_factor
25 | elif method == "linear":
26 | alpha = iter / warmup_iters
27 | return warmup_factor * (1 - alpha) + alpha
28 | else:
29 | raise ValueError("Unknown warmup method: {}".format(method))
30 |
31 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler):
32 | """
33 | Poly learning rate schedule used to train DeepLab.
34 | Paper: DeepLab: Semantic Image Segmentation with Deep Convolutional Nets,
35 | Atrous Convolution, and Fully Connected CRFs.
36 | Reference: https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/utils/train_utils.py#L337 # noqa
37 | """
38 |
39 | def __init__(
40 | self,
41 | optimizer: torch.optim.Optimizer,
42 | max_iters: int,
43 | warmup_factor: float = 0.001,
44 | warmup_iters: int = 1000,
45 | warmup_method: str = "linear",
46 | last_epoch: int = -1,
47 | power: float = 0.9,
48 | constant_ending: float = 0.0,
49 | ):
50 | self.max_iters = max_iters
51 | self.warmup_factor = warmup_factor
52 | self.warmup_iters = warmup_iters
53 | self.warmup_method = warmup_method
54 | self.power = power
55 | self.constant_ending = constant_ending
56 | super().__init__(optimizer, last_epoch)
57 |
58 | def get_lr(self) -> List[float]:
59 | warmup_factor = _get_warmup_factor_at_iter(
60 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
61 | )
62 | if self.constant_ending > 0 and warmup_factor == 1.0:
63 | # Constant ending lr.
64 | if (
65 | math.pow((1.0 - self.last_epoch / self.max_iters), self.power)
66 | < self.constant_ending
67 | ):
68 | return [base_lr * self.constant_ending for base_lr in self.base_lrs]
69 | return [
70 | base_lr * warmup_factor * math.pow((1.0 - self.last_epoch / self.max_iters), self.power)
71 | for base_lr in self.base_lrs
72 | ]
73 |
74 | def _compute_values(self) -> List[float]:
75 | # The new interface
76 | return self.get_lr()
77 |
78 |
--------------------------------------------------------------------------------
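A self-contained sketch of stepping the scheduler once per training iteration; the toy model and hyper-parameter values are placeholders, not the repo's training settings:

```
import torch
from tools.warmup_poly_lr_scheduler import WarmupPolyLR

model = torch.nn.Linear(8, 2)                      # stand-in for the real network
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
max_iters = 10000
scheduler = WarmupPolyLR(optimizer, max_iters, warmup_iters=1000, power=0.9)

for it in range(max_iters):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    scheduler.step()                               # linear warmup, then poly decay
```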
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/util/__init__.py
--------------------------------------------------------------------------------
/util/box_ops.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities for bounding box manipulation and GIoU.
3 | """
4 | import torch
5 | from torchvision.ops.boxes import box_area
6 |
7 | def clip_iou(boxes1,boxes2):
8 | area1 = box_area(boxes1)
9 | area2 = box_area(boxes2)
10 | lt = torch.max(boxes1[:, :2], boxes2[:, :2])
11 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:])
12 | wh = (rb - lt).clamp(min=0)
13 | inter = wh[:,0] * wh[:,1]
14 | union = area1 + area2 - inter
15 | iou = (inter + 1e-6) / (union+1e-6)
16 | return iou
17 |
18 | def multi_iou(boxes1, boxes2):
19 | lt = torch.max(boxes1[...,:2], boxes2[...,:2])
20 | rb = torch.min(boxes1[...,2:], boxes2[...,2:])
21 | wh = (rb - lt).clamp(min=0)
22 | wh_1 = boxes1[...,2:] - boxes1[...,:2]
23 | wh_2 = boxes2[...,2:] - boxes2[...,:2]
24 | inter = wh[...,0] * wh[...,1]
25 | union = wh_1[...,0] * wh_1[...,1] + wh_2[...,0] * wh_2[...,1] - inter
26 | iou = (inter + 1e-6) / (union + 1e-6)
27 | return iou
28 |
29 | def box_cxcywh_to_xyxy(x):
30 | x_c, y_c, w, h = x.unbind(-1)
31 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
32 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
33 | return torch.stack(b, dim=-1)
34 |
35 |
36 | def box_xyxy_to_cxcywh(x):
37 | x0, y0, x1, y1 = x.unbind(-1)
38 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
39 | (x1 - x0), (y1 - y0)]
40 | return torch.stack(b, dim=-1)
41 |
42 |
43 | # modified from torchvision to also return the union
44 | def box_iou(boxes1, boxes2):
45 | area1 = box_area(boxes1)
46 | area2 = box_area(boxes2)
47 |
48 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
49 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
50 |
51 | wh = (rb - lt).clamp(min=0) # [N,M,2]
52 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
53 |
54 | union = area1[:, None] + area2 - inter
55 |
56 | iou = (inter+1e-6) / (union+1e-6)
57 | return iou, union
58 |
59 |
60 | def generalized_box_iou(boxes1, boxes2):
61 | """
62 | Generalized IoU from https://giou.stanford.edu/
63 |
64 | The boxes should be in [x0, y0, x1, y1] format
65 |
66 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
67 | and M = len(boxes2)
68 | """
69 | # degenerate boxes gives inf / nan results
70 | # so do an early check
71 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
72 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
73 | iou, union = box_iou(boxes1, boxes2)
74 |
75 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
76 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
77 |
78 | wh = (rb - lt).clamp(min=0) # [N,M,2]
79 | area = wh[:, :, 0] * wh[:, :, 1]
80 |
81 | return iou - ((area - union) + 1e-6) / (area + 1e-6)
82 |
83 |
84 | def masks_to_boxes(masks):
85 | """Compute the bounding boxes around the provided masks
86 |
87 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
88 |
89 | Returns a [N, 4] tensors, with the boxes in xyxy format
90 | """
91 | if masks.numel() == 0:
92 | return torch.zeros((0, 4), device=masks.device)
93 |
94 | h, w = masks.shape[-2:]
95 |
96 | y = torch.arange(0, h, dtype=torch.float)
97 | x = torch.arange(0, w, dtype=torch.float)
98 | y, x = torch.meshgrid(y, x)
99 |
100 | x_mask = (masks * x.unsqueeze(0))
101 | x_max = x_mask.flatten(1).max(-1)[0]
102 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
103 |
104 | y_mask = (masks * y.unsqueeze(0))
105 | y_max = y_mask.flatten(1).max(-1)[0]
106 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
107 |
108 | return torch.stack([x_min, y_min, x_max, y_max], 1)
109 |
--------------------------------------------------------------------------------
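A short worked example of the conversion and GIoU utilities above; the box values are illustrative:

```
import torch
from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

# Two predicted boxes in (cx, cy, w, h), normalized to [0, 1].
pred_cxcywh = torch.tensor([[0.50, 0.50, 0.40, 0.40],
                            [0.20, 0.30, 0.10, 0.20]])
gt_xyxy = torch.tensor([[0.30, 0.30, 0.70, 0.70]])

pred_xyxy = box_cxcywh_to_xyxy(pred_cxcywh)    # -> (x0, y0, x1, y1)
giou = generalized_box_iou(pred_xyxy, gt_xyxy) # [2, 1] pairwise matrix
print(pred_xyxy)
print(giou)  # close to 1 for the matching box, negative for the disjoint one
```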