├── README.md
├── demo.py
├── demo
│   └── demo_examples
│       ├── 00000.jpg
│       ├── 00005.jpg
│       ├── 00010.jpg
│       ├── 00015.jpg
│       ├── 00020.jpg
│       ├── 00025.jpg
│       ├── 00030.jpg
│       ├── 00035.jpg
│       ├── 00040.jpg
│       ├── 00045.jpg
│       ├── 00050.jpg
│       ├── 00055.jpg
│       ├── 00060.jpg
│       ├── 00065.jpg
│       ├── 00070.jpg
│       ├── 00075.jpg
│       ├── 00080.jpg
│       ├── 00085.jpg
│       └── 00090.jpg
├── engine.py
├── illustration.jpg
├── inference_ytvos.py
├── inference_ytvos_segm.py
├── main.py
├── models
│   ├── __init__.py
│   ├── amm_resnet.py
│   ├── backbone.py
│   ├── criterion.py
│   ├── cycle.py
│   ├── deformable_transformer.py
│   ├── matcher.py
│   ├── ops
│   │   ├── functions
│   │   │   ├── __init__.py
│   │   │   └── ms_deform_attn_func.py
│   │   ├── make.sh
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   └── ms_deform_attn.py
│   │   ├── setup.py
│   │   ├── src
│   │   │   ├── cpu
│   │   │   │   ├── ms_deform_attn_cpu.cpp
│   │   │   │   └── ms_deform_attn_cpu.h
│   │   │   ├── cuda
│   │   │   │   ├── ms_deform_attn_cuda.cu
│   │   │   │   ├── ms_deform_attn_cuda.h
│   │   │   │   └── ms_deform_im2col_cuda.cuh
│   │   │   ├── ms_deform_attn.h
│   │   │   └── vision.cpp
│   │   └── test.py
│   ├── position_encoding.py
│   ├── postprocessors.py
│   ├── referformer.py
│   ├── segmentation.py
│   ├── swin_transformer.py
│   ├── vector_quantitizer.py
│   └── video_swin_transformer.py
├── opts.py
├── requirements.txt
├── tools
│   ├── __init__.py
│   ├── colormap.py
│   ├── data
│   │   ├── convert_davis_to_ytvos.py
│   │   └── convert_refexp_to_coco.py
│   ├── load_pretrained_weights.py
│   └── warmup_poly_lr_scheduler.py
└── util
    ├── __init__.py
    ├── box_ops.py
    └── misc.py

/README.md:
--------------------------------------------------------------------------------
1 | > [**Towards Robust Referring Video Object Segmentation with Cyclic Relational Consistency**](https://arxiv.org/abs/2207.01203)
2 | >
3 | > Xiang Li, Jinglu Wang, Xiaohao Xu, Xiao Li, Bhiksha Raj, Yan Lu
4 | 
5 | 

6 | 
7 | # Updates
8 | - **(2023-05-30)** Code released.
9 | - **(2023-07-13)** R2VOS is accepted to ICCV 2023!
10 | 
11 | # Install
12 | 
13 | ```
14 | conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 -c pytorch
15 | pip install -r requirements.txt
16 | pip install 'git+https://github.com/facebookresearch/fvcore'
17 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
18 | cd models/ops
19 | python setup.py build install
20 | cd ../..
21 | ```
22 | 
23 | # Docker
24 | You may try the [docker](https://hub.docker.com/r/ang9867/refer) image for a quick start.
25 | 
26 | # Weights
27 | Please download [checkpoint.pth](https://drive.google.com/file/d/1gknDDMxWKqZ7yPuTh1fome1-Ba4f_G9K/view?usp=share_link) and put it in the main folder.
28 | 
29 | # Run demo:
30 | Run inference on the images in demo/demo_examples.
31 | ```
32 | python demo.py --with_box_refine --binary --freeze_text_encoder --output_dir=output/demo --resume=checkpoint.pth --backbone resnet50 --ngpu 1 --use_cycle --mix_query --neg_cls --is_eval --use_cls --demo_exp 'a big track on the road' --demo_path 'demo/demo_examples'
33 | ```
34 | 
35 | # Inference:
36 | To evaluate on Ref-YTVOS, use inference_ytvos.py, or inference_ytvos_segm.py if you encounter OOM when running inference on an entire video at once (it splits each video into shorter segments).
37 | ```
38 | python inference_ytvos.py --with_box_refine --binary --freeze_text_encoder --output_dir=output/eval --resume=checkpoint.pth --backbone resnet50 --ngpu 1 --use_cycle --mix_query --neg_cls --is_eval --use_cls --ytvos_path=/data/ref-ytvos
39 | ```
40 | # Related works for robust multimodal video segmentation:
41 | > [R2-Bench: Benchmarking the Robustness of Referring Perception Models under Perturbations
42 | ](https://arxiv.org/abs/2403.04924), arXiv 2024
43 | 
44 | > [Towards Robust Audiovisual Segmentation in Complex Environments with Quantization-based Semantic Decomposition](https://arxiv.org/abs/2310.00132), CVPR 2024
45 | ## Citation
46 | ```
47 | @inproceedings{li2023robust,
48 |   title={Robust referring video object segmentation with cyclic structural consensus},
49 |   author={Li, Xiang and Wang, Jinglu and Xu, Xiaohao and Li, Xiao and Raj, Bhiksha and Lu, Yan},
50 |   booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
51 |   pages={22236--22245},
52 |   year={2023}
53 | }
54 | ```
55 | 
56 | 

--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | '''
2 | Inference code for ReferFormer, on Ref-Youtube-VOS
3 | Modified from DETR (https://github.com/facebookresearch/detr)
4 | '''
5 | import argparse
6 | import json
7 | import random
8 | import sys
9 | import time
10 | from pathlib import Path
11 | 
12 | import numpy as np
13 | import torch
14 | 
15 | import util.misc as utils
16 | from models import build_model
17 | import torchvision.transforms as T
18 | import matplotlib.pyplot as plt
19 | import os
20 | import cv2
21 | from PIL import Image, ImageDraw
22 | import math
23 | import torch.nn.functional as F
24 | import json
25 | 
26 | import opts
27 | from tqdm import tqdm
28 | 
29 | import multiprocessing as mp
30 | import threading
31 | import glob
32 | 
33 | 
34 | from tools.colormap import colormap
35 | 
36 | 
37 | # colormap
38 | color_list = colormap()
39 | color_list = color_list.astype('uint8').tolist()
40 | 
41 | def main(args):
42 |     args.masks = True
43 |     args.batch_size = 1
44 |     print("Inference only supports for batch size = 1")
45 | 
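# Frame preprocessing used below: resize the shorter side to args.inf_res, convert to a
# [0, 1] tensor, and normalize with the ImageNet mean/std used by the pretrained backbone.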
46 | global transform 47 | transform = T.Compose([ 48 | T.Resize(args.inf_res), 49 | T.ToTensor(), 50 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 51 | ]) 52 | 53 | # fix the seed for reproducibility 54 | seed = args.seed + utils.get_rank() 55 | torch.manual_seed(seed) 56 | np.random.seed(seed) 57 | random.seed(seed) 58 | 59 | # save path 60 | output_dir = args.output_dir 61 | save_path_prefix = os.path.join(output_dir) 62 | if not os.path.exists(save_path_prefix): 63 | os.makedirs(save_path_prefix) 64 | 65 | global result_dict 66 | result_dict = mp.Manager().dict() 67 | frames = sorted(glob.glob(args.demo_path+'/*')) 68 | sub_processor(0, args, args.demo_exp, frames, save_path_prefix) 69 | 70 | result_dict = dict(result_dict) 71 | num_all_frames_gpus = 0 72 | for pid, num_all_frames in result_dict.items(): 73 | num_all_frames_gpus += num_all_frames 74 | 75 | def sub_processor(pid, args, exp, frames, save_path_prefix): 76 | torch.cuda.set_device(pid) 77 | 78 | # model 79 | model, criterion, _ = build_model(args) 80 | device = args.device 81 | model.to(device) 82 | 83 | model_without_ddp = model 84 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 85 | 86 | if pid == 0: 87 | print('number of params:', n_parameters) 88 | 89 | if args.resume: 90 | checkpoint = torch.load(args.resume, map_location='cpu') 91 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 92 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))] 93 | if len(missing_keys) > 0: 94 | print('Missing Keys: {}'.format(missing_keys)) 95 | if len(unexpected_keys) > 0: 96 | print('Unexpected Keys: {}'.format(unexpected_keys)) 97 | else: 98 | raise ValueError('Please specify the checkpoint for inference.') 99 | 100 | 101 | # start inference 102 | num_all_frames = 0 103 | model.eval() 104 | 105 | sentence_features = [] 106 | pseudo_sentence_features = [] 107 | video_name = 'demo' 108 | # exp = meta[i]["exp"] 109 | # # exp = 'a dog is with its puppies on the cloth' 110 | # # TODO: temp 111 | # frames = meta[i]["frames"] 112 | # frames = [f'/home/mcg/ReferFormer/demo/frames_{fid}.jpg' for fid in range(1,2)] 113 | 114 | video_len = len(frames) 115 | # store images 116 | imgs = [] 117 | for t in range(video_len): 118 | frame = frames[t] 119 | img_path = os.path.join(frame) 120 | img = Image.open(img_path).convert('RGB') 121 | origin_w, origin_h = img.size 122 | imgs.append(transform(img)) # list[img] 123 | 124 | imgs = torch.stack(imgs, dim=0).to(args.device) # [video_len, 3, h, w] 125 | img_h, img_w = imgs.shape[-2:] 126 | size = torch.as_tensor([int(img_h), int(img_w)]).to(args.device) 127 | target = {"size": size} 128 | 129 | with torch.no_grad(): 130 | outputs = model([imgs], [exp], [target]) 131 | 132 | pred_logits = outputs["pred_logits"][0] 133 | pred_boxes = outputs["pred_boxes"][0] 134 | pred_masks = outputs["pred_masks"][0] 135 | pred_ref_points = outputs["reference_points"][0] 136 | text_sentence_features = outputs['sentence_feature'] 137 | if args.use_cycle: 138 | pseudo_text_sentence_features = outputs['pseudo_sentence_feature'] 139 | # anchor = outputs['negative_anchor'] 140 | sentence_features.append(text_sentence_features) 141 | pseudo_sentence_features.append(pseudo_text_sentence_features) 142 | # print(F.pairwise_distance(text_sentence_features, pseudo_text_sentence_features.squeeze(0), p=2)) 143 | # print(anchor) 144 | # according to pred_logits, select the query index 
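# pred_logits has shape [t, q, k] (frames, object queries, classes). Below, the sigmoid scores
# are averaged over time, the best class score is taken per query, and the single highest-scoring
# query is selected; its masks are gathered for every frame and upsampled to the original size.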
145 | pred_scores = pred_logits.sigmoid() # [t, q, k] 146 | pred_score = pred_scores 147 | pred_scores = pred_scores.mean(0) # [q, k] 148 | max_scores, _ = pred_scores.max(-1) # [q,] 149 | # print(max_scores) 150 | _, max_ind = max_scores.max(-1) # [1,] 151 | max_inds = max_ind.repeat(video_len) 152 | pred_masks = pred_masks[range(video_len), max_inds, ...] # [t, h, w] 153 | pred_masks = pred_masks.unsqueeze(0) 154 | pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False) 155 | if args.save_prob: 156 | pred_masks = pred_masks.sigmoid().squeeze(0).detach().cpu().numpy() 157 | else: 158 | pred_masks = (pred_masks.sigmoid() > args.threshold).squeeze(0).detach().cpu().numpy() 159 | if args.use_score: 160 | pred_score = pred_score[range(video_len), max_inds, 0].unsqueeze(-1).unsqueeze(-1) 161 | pred_masks *= (pred_score > 0.3).cpu().numpy() * pred_masks 162 | 163 | # store the video results 164 | all_pred_logits = pred_logits[range(video_len), max_inds].sigmoid().cpu().numpy() 165 | all_pred_boxes = pred_boxes[range(video_len), max_inds] 166 | all_pred_ref_points = pred_ref_points[range(video_len), max_inds] 167 | all_pred_masks = pred_masks 168 | 169 | save_path = os.path.join(save_path_prefix) 170 | if not os.path.exists(save_path): 171 | os.makedirs(save_path) 172 | for j in range(video_len): 173 | frame_name = frames[j] 174 | confidence = all_pred_logits[j] 175 | mask = all_pred_masks[j].astype(np.float32) 176 | save_file = os.path.join(save_path, f"{j}" + ".png") 177 | # print(save_file) 178 | if 'pair_logits' in outputs.keys() and args.use_cls: 179 | if outputs['pair_logits'].cpu().numpy() >= 0.5: 180 | print('This is a negative pair, disalignment degree:', outputs['pair_logits'].cpu().numpy().item()) 181 | else: 182 | print('This is a positive pair, disalignment degree:', outputs['pair_logits'].cpu().numpy().item()) 183 | mask *= 0 if outputs['pair_logits'].cpu().numpy() >= 0.5 else 1 184 | mask = Image.fromarray(mask * 255).convert('L') 185 | mask.save(save_file) 186 | print(f'Results saved to {save_path}') 187 | result_dict[str(pid)] = num_all_frames 188 | 189 | 190 | # visuaize functions 191 | def box_cxcywh_to_xyxy(x): 192 | x_c, y_c, w, h = x.unbind(1) 193 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 194 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 195 | return torch.stack(b, dim=1) 196 | 197 | def rescale_bboxes(out_bbox, size): 198 | img_w, img_h = size 199 | b = box_cxcywh_to_xyxy(out_bbox) 200 | b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) 201 | return b 202 | 203 | 204 | # Visualization functions 205 | def draw_reference_points(draw, reference_points, img_size, color): 206 | W, H = img_size 207 | for i, ref_point in enumerate(reference_points): 208 | init_x, init_y = ref_point 209 | x, y = W * init_x, H * init_y 210 | cur_color = color 211 | draw.line((x-10, y, x+10, y), tuple(cur_color), width=4) 212 | draw.line((x, y-10, x, y+10), tuple(cur_color), width=4) 213 | 214 | def draw_sample_points(draw, sample_points, img_size, color_list): 215 | alpha = 255 216 | for i, samples in enumerate(sample_points): 217 | for sample in samples: 218 | x, y = sample 219 | cur_color = color_list[i % len(color_list)][::-1] 220 | cur_color += [alpha] 221 | draw.ellipse((x-2, y-2, x+2, y+2), 222 | fill=tuple(cur_color), outline=tuple(cur_color), width=1) 223 | 224 | def vis_add_mask(img, mask, color): 225 | origin_img = np.asarray(img.convert('RGB')).copy() 226 | color = np.array(color) 227 | 228 | mask = 
mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8') # np 229 | mask = mask > 0.5 230 | 231 | origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5 232 | origin_img = Image.fromarray(origin_img) 233 | return origin_img 234 | 235 | 236 | 237 | if __name__ == '__main__': 238 | parser = argparse.ArgumentParser('ReferFormer inference script', parents=[opts.get_args_parser()]) 239 | args = parser.parse_args() 240 | main(args) 241 | -------------------------------------------------------------------------------- /demo/demo_examples/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00000.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00005.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00010.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00015.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00020.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00020.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00025.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00030.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00030.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00035.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00035.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00040.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00040.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00045.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00045.jpg 
-------------------------------------------------------------------------------- /demo/demo_examples/00050.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00050.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00055.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00055.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00060.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00060.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00065.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00065.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00070.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00070.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00075.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00075.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00080.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00080.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00085.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00085.jpg -------------------------------------------------------------------------------- /demo/demo_examples/00090.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/demo/demo_examples/00090.jpg -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | """ 2 | Train and eval functions used in main.py 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import math 6 | from models import postprocessors 7 | import os 8 | import sys 9 | from typing import Iterable 10 | 11 | import torch 12 | import torch.distributed as dist 13 | 14 | import util.misc as utils 15 | from datasets.coco_eval import CocoEvaluator 16 | from datasets.refexp_eval import RefExpEvaluator 17 | 18 | from pycocotools.coco import COCO 19 | from pycocotools.cocoeval import COCOeval 20 | from datasets.a2d_eval import 
calculate_precision_at_k_and_iou_metrics 21 | 22 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 23 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 24 | device: torch.device, lr_scheduler, epoch: int, max_norm: float = 0): 25 | model.train() 26 | criterion.train() 27 | metric_logger = utils.MetricLogger(delimiter=" ") 28 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 10 31 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 32 | samples = samples.to(device) 33 | captions = [t["caption"] for t in targets] 34 | targets = utils.targets_to(targets, device) 35 | 36 | outputs = model(samples, captions, targets) 37 | loss_dict = criterion(outputs, targets) 38 | weight_dict = criterion.weight_dict 39 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 40 | 41 | # reduce losses over all GPUs for logging purposes 42 | loss_dict_reduced = utils.reduce_dict(loss_dict) 43 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 44 | for k, v in loss_dict_reduced.items()} 45 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 46 | for k, v in loss_dict_reduced.items() if k in weight_dict} 47 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 48 | loss_value = losses_reduced_scaled.item() 49 | 50 | if not math.isfinite(loss_value): 51 | print("Loss is {}, stopping training".format(loss_value)) 52 | print(loss_dict_reduced) 53 | sys.exit(1) 54 | optimizer.zero_grad() 55 | losses.backward() 56 | if max_norm > 0: 57 | grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 58 | else: 59 | grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm) 60 | optimizer.step() 61 | # lr_scheduler.step() 62 | 63 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 64 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 65 | metric_logger.update(grad_norm=grad_total_norm) 66 | 67 | # gather the stats from all processes 68 | metric_logger.synchronize_between_processes() 69 | print("Averaged stats:", metric_logger) 70 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 71 | 72 | 73 | @torch.no_grad() 74 | def evaluate(model, criterion, postprocessors, data_loader, evaluator_list, device, args): 75 | model.eval() 76 | criterion.eval() 77 | 78 | metric_logger = utils.MetricLogger(delimiter=" ") 79 | header = 'Test:' 80 | 81 | for samples, targets in metric_logger.log_every(data_loader, 10, header): 82 | samples = samples.to(device) 83 | captions = [t["caption"] for t in targets] 84 | targets = utils.targets_to(targets, device) 85 | outputs = model(samples, captions, targets) 86 | loss_dict = criterion(outputs, targets) 87 | weight_dict = criterion.weight_dict 88 | 89 | # reduce losses over all GPUs for logging purposes 90 | loss_dict_reduced = utils.reduce_dict(loss_dict) 91 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 92 | for k, v in loss_dict_reduced.items() if k in weight_dict} 93 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 94 | for k, v in loss_dict_reduced.items()} 95 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), 96 | **loss_dict_reduced_scaled, 97 | **loss_dict_reduced_unscaled) 98 | 99 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 100 | results = postprocessors['bbox'](outputs, orig_target_sizes) 101 | if 'segm' in 
postprocessors.keys(): 102 | target_sizes = torch.stack([t["size"] for t in targets], dim=0) 103 | results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes) 104 | res = {target['image_id']: output for target, output in zip(targets, results)} 105 | 106 | for evaluator in evaluator_list: 107 | evaluator.update(res) 108 | 109 | # gather the stats from all processes 110 | metric_logger.synchronize_between_processes() 111 | print("Averaged stats:", metric_logger) 112 | for evaluator in evaluator_list: 113 | evaluator.synchronize_between_processes() 114 | 115 | # accumulate predictions from all images 116 | refexp_res = None 117 | for evaluator in evaluator_list: 118 | if isinstance(evaluator, CocoEvaluator): 119 | evaluator.accumulate() 120 | evaluator.summarize() 121 | elif isinstance(evaluator, RefExpEvaluator): 122 | refexp_res = evaluator.summarize() 123 | 124 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 125 | 126 | # update stats 127 | for evaluator in evaluator_list: 128 | if isinstance(evaluator, CocoEvaluator): 129 | if "bbox" in postprocessors.keys(): 130 | stats["coco_eval_bbox"] = evaluator.coco_eval["bbox"].stats.tolist() 131 | if "segm" in postprocessors.keys(): 132 | stats["coco_eval_masks"] = evaluator.coco_eval["segm"].stats.tolist() 133 | if refexp_res is not None: 134 | stats.update(refexp_res) 135 | 136 | return stats 137 | 138 | 139 | @torch.no_grad() 140 | def evaluate_a2d(model, data_loader, postprocessor, device, args): 141 | model.eval() 142 | predictions = [] 143 | metric_logger = utils.MetricLogger(delimiter=" ") 144 | header = 'Test:' 145 | 146 | for samples, targets in metric_logger.log_every(data_loader, 10, header): 147 | image_ids = [t['image_id'] for t in targets] 148 | 149 | samples = samples.to(device) 150 | captions = [t["caption"] for t in targets] 151 | targets = utils.targets_to(targets, device) 152 | 153 | outputs = model(samples, captions, targets) 154 | 155 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 156 | target_sizes = torch.stack([t["size"] for t in targets], dim=0) 157 | processed_outputs = postprocessor(outputs, orig_target_sizes, target_sizes) 158 | 159 | for p, image_id in zip(processed_outputs, image_ids): 160 | for s, m in zip(p['scores'], p['rle_masks']): 161 | predictions.append({'image_id': image_id, 162 | 'category_id': 1, # dummy label, as categories are not predicted in ref-vos 163 | 'segmentation': m, 164 | 'score': s.item()}) 165 | 166 | # gather and merge predictions from all gpus 167 | gathered_pred_lists = utils.all_gather(predictions) 168 | predictions = [p for p_list in gathered_pred_lists for p in p_list] 169 | # evaluation 170 | eval_metrics = {} 171 | if utils.is_main_process(): 172 | if args.dataset_file == 'a2d': 173 | coco_gt = COCO(os.path.join(args.a2d_path, 'a2d_sentences_test_annotations_in_coco_format.json')) 174 | elif args.dataset_file == 'jhmdb': 175 | coco_gt = COCO(os.path.join(args.jhmdb_path, 'jhmdb_sentences_gt_annotations_in_coco_format.json')) 176 | elif args.dataset_file == 'refcocoVideo': 177 | coco_gt = COCO(os.path.join(args.coco_path, 'refcocoVideo/finetune_refcoco_val.json')) 178 | else: 179 | raise NotImplementedError 180 | coco_pred = coco_gt.loadRes(predictions) 181 | coco_eval = COCOeval(coco_gt, coco_pred, iouType='segm') 182 | coco_eval.params.useCats = 0 # ignore categories as they are not predicted in ref-vos task 183 | coco_eval.evaluate() 184 | coco_eval.accumulate() 185 | coco_eval.summarize() 186 | 
ap_labels = ['mAP 0.5:0.95', 'AP 0.5', 'AP 0.75', 'AP 0.5:0.95 S', 'AP 0.5:0.95 M', 'AP 0.5:0.95 L'] 187 | ap_metrics = coco_eval.stats[:6] 188 | eval_metrics = {l: m for l, m in zip(ap_labels, ap_metrics)} 189 | # Precision and IOU 190 | precision_at_k, overall_iou, mean_iou = calculate_precision_at_k_and_iou_metrics(coco_gt, coco_pred) 191 | eval_metrics.update({f'P@{k}': m for k, m in zip([0.5, 0.6, 0.7, 0.8, 0.9], precision_at_k)}) 192 | eval_metrics.update({'overall_iou': overall_iou, 'mean_iou': mean_iou}) 193 | print(eval_metrics) 194 | 195 | # sync all processes before starting a new epoch or exiting 196 | dist.barrier() 197 | return eval_metrics 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | -------------------------------------------------------------------------------- /illustration.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/illustration.jpg -------------------------------------------------------------------------------- /inference_ytvos.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Inference code for ReferFormer, on Ref-Youtube-VOS 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | ''' 5 | import argparse 6 | import json 7 | import random 8 | import sys 9 | import time 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | 15 | import util.misc as utils 16 | from models import build_model 17 | import torchvision.transforms as T 18 | import matplotlib.pyplot as plt 19 | import os 20 | import cv2 21 | from PIL import Image, ImageDraw 22 | import math 23 | import torch.nn.functional as F 24 | import json 25 | 26 | import opts 27 | from tqdm import tqdm 28 | 29 | import multiprocessing as mp 30 | import threading 31 | 32 | 33 | from tools.colormap import colormap 34 | 35 | 36 | # colormap 37 | color_list = colormap() 38 | color_list = color_list.astype('uint8').tolist() 39 | 40 | def main(args): 41 | args.masks = True 42 | args.batch_size == 1 43 | print("Inference only supports for batch size = 1") 44 | 45 | global transform 46 | transform = T.Compose([ 47 | T.Resize(args.inf_res), 48 | T.ToTensor(), 49 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 50 | ]) 51 | 52 | # fix the seed for reproducibility 53 | seed = args.seed + utils.get_rank() 54 | torch.manual_seed(seed) 55 | np.random.seed(seed) 56 | random.seed(seed) 57 | 58 | split = args.split 59 | # save path 60 | output_dir = args.output_dir 61 | save_path_prefix = os.path.join(output_dir, split) 62 | if not os.path.exists(save_path_prefix): 63 | os.makedirs(save_path_prefix) 64 | 65 | save_visualize_path_prefix = os.path.join(output_dir, split + '_images') 66 | if args.visualize: 67 | if not os.path.exists(save_visualize_path_prefix): 68 | os.makedirs(save_visualize_path_prefix) 69 | if args.val_type: 70 | print(f'Warning: use robust rvos validation set, type: {args.val_type}') 71 | assert args.val_type in ['_random', '_color', '_old_color', '_object', '_action', '_dyn', '_1', '_3', '_5'], \ 72 | f'Unknown val type for robust rvos, {args.val_type}' 73 | # load data 74 | root = Path(args.ytvos_path) # data/ref-youtube-vos 75 | img_folder = os.path.join(root, split, f"JPEGImages{args.val_type}") 76 | meta_file = os.path.join(root, "meta_expressions", split, f"meta_expressions{args.val_type}.json") 77 | with open(meta_file, "r") as f: 78 | data = json.load(f)["videos"] 79 | 
valid_test_videos = set(data.keys()) 80 | # for some reasons the competition's validation expressions dict contains both the validation (202) & 81 | # test videos (305). so we simply load the test expressions dict and use it to filter out the test videos from 82 | # the validation expressions dict: 83 | test_meta_file = os.path.join(root, "meta_expressions", "test", "meta_expressions.json") 84 | with open(test_meta_file, 'r') as f: 85 | test_data = json.load(f)['videos'] 86 | test_videos = set(test_data.keys()) 87 | valid_videos = valid_test_videos - test_videos 88 | video_list = sorted([video for video in valid_videos]) 89 | if 'color' not in args.val_type: 90 | assert len(video_list) == 202 or len(video_list) == 202+305, 'error: incorrect number of validation videos' 91 | 92 | # create subprocess 93 | thread_num = args.ngpu 94 | global result_dict 95 | result_dict = mp.Manager().dict() 96 | 97 | processes = [] 98 | lock = threading.Lock() 99 | 100 | video_num = len(video_list) 101 | per_thread_video_num = video_num // thread_num 102 | 103 | start_time = time.time() 104 | print('Start inference') 105 | for i in range(thread_num): 106 | if i == thread_num - 1: 107 | sub_video_list = video_list[i * per_thread_video_num:] 108 | else: 109 | sub_video_list = video_list[i * per_thread_video_num: (i + 1) * per_thread_video_num] 110 | p = mp.Process(target=sub_processor, args=(lock, i, args, data, 111 | save_path_prefix, save_visualize_path_prefix, 112 | img_folder, sub_video_list)) 113 | p.start() 114 | processes.append(p) 115 | 116 | for p in processes: 117 | p.join() 118 | 119 | end_time = time.time() 120 | total_time = end_time - start_time 121 | 122 | result_dict = dict(result_dict) 123 | num_all_frames_gpus = 0 124 | for pid, num_all_frames in result_dict.items(): 125 | num_all_frames_gpus += num_all_frames 126 | 127 | print("Total inference time: %.4f s" %(total_time)) 128 | 129 | def sub_processor(lock, pid, args, data, save_path_prefix, save_visualize_path_prefix, img_folder, video_list): 130 | text = 'processor %d' % pid 131 | with lock: 132 | progress = tqdm( 133 | total=len(video_list), 134 | position=pid, 135 | desc=text, 136 | ncols=0 137 | ) 138 | torch.cuda.set_device(pid) 139 | 140 | # model 141 | model, criterion, _ = build_model(args) 142 | device = args.device 143 | model.to(device) 144 | 145 | model_without_ddp = model 146 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 147 | 148 | if pid == 0: 149 | print('number of params:', n_parameters) 150 | 151 | if args.resume: 152 | checkpoint = torch.load(args.resume, map_location='cpu') 153 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 154 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))] 155 | if len(missing_keys) > 0: 156 | print('Missing Keys: {}'.format(missing_keys)) 157 | if len(unexpected_keys) > 0: 158 | print('Unexpected Keys: {}'.format(unexpected_keys)) 159 | else: 160 | raise ValueError('Please specify the checkpoint for inference.') 161 | 162 | 163 | # start inference 164 | num_all_frames = 0 165 | model.eval() 166 | 167 | sentence_features = [] 168 | pseudo_sentence_features = [] 169 | # 1. 
For each video 170 | for vid, video in enumerate(video_list): 171 | metas = [] # list[dict], length is number of expressions 172 | 173 | expressions = data[video]["expressions"] 174 | expression_list = list(expressions.keys()) 175 | num_expressions = len(expression_list) 176 | video_len = len(data[video]["frames"]) 177 | 178 | # read all the anno meta 179 | for i in range(num_expressions): 180 | meta = {} 181 | meta["video"] = video 182 | meta["exp"] = expressions[expression_list[i]]["exp"] 183 | meta["exp_id"] = expression_list[i] 184 | meta["frames"] = data[video]["frames"] 185 | metas.append(meta) 186 | meta = metas 187 | 188 | # if vid < 8: # 46, 81 189 | # print(video, metas[0]["exp"]) 190 | # continue 191 | 192 | vis = False 193 | # 2. For each expression 194 | for i in range(num_expressions): 195 | video_name = meta[i]["video"] 196 | exp = meta[i]["exp"] 197 | # exp = 'a dog is with its puppies on the cloth' 198 | # TODO: temp 199 | exp_id = meta[i]["exp_id"] 200 | frames = meta[i]["frames"] 201 | # frames = [f'/home/mcg/ReferFormer/demo/frames_{fid}.jpg' for fid in range(1,2)] 202 | if vis: 203 | os.makedirs('vis/tmp', exist_ok=True) 204 | 205 | video_len = len(frames) 206 | # store images 207 | imgs = [] 208 | for t in range(video_len): 209 | frame = frames[t] 210 | img_path = os.path.join(img_folder, video_name, frame + ".jpg") 211 | # img_path = frame 212 | img = Image.open(img_path).convert('RGB') 213 | if vis: 214 | if t < 1: 215 | img.save(f'vis/tmp/{t}.png') 216 | origin_w, origin_h = img.size 217 | imgs.append(transform(img)) # list[img] 218 | 219 | imgs = torch.stack(imgs, dim=0).to(args.device) # [video_len, 3, h, w] 220 | img_h, img_w = imgs.shape[-2:] 221 | size = torch.as_tensor([int(img_h), int(img_w)]).to(args.device) 222 | target = {"size": size} 223 | 224 | # start = time.time() 225 | with torch.no_grad(): 226 | outputs = model([imgs], [exp], [target]) 227 | # end = time.time() 228 | # print((end-start)/video_len) 229 | 230 | if vis: 231 | os.system(f'mv vis/tmp vis/{vid}-{exp.replace(" ", "_")}') 232 | 233 | pred_logits = outputs["pred_logits"][0] 234 | pred_boxes = outputs["pred_boxes"][0] 235 | pred_masks = outputs["pred_masks"][0] 236 | pred_ref_points = outputs["reference_points"][0] 237 | text_sentence_features = outputs['sentence_feature'] 238 | if args.use_cycle: 239 | pseudo_text_sentence_features = outputs['pseudo_sentence_feature'] 240 | # anchor = outputs['negative_anchor'] 241 | sentence_features.append(text_sentence_features) 242 | pseudo_sentence_features.append(pseudo_text_sentence_features) 243 | # print(F.pairwise_distance(text_sentence_features, pseudo_text_sentence_features.squeeze(0), p=2)) 244 | # print(anchor) 245 | # according to pred_logits, select the query index 246 | pred_scores = pred_logits.sigmoid() # [t, q, k] 247 | pred_score = pred_scores 248 | pred_scores = pred_scores.mean(0) # [q, k] 249 | max_scores, _ = pred_scores.max(-1) # [q,] 250 | # print(max_scores) 251 | _, max_ind = max_scores.max(-1) # [1,] 252 | max_inds = max_ind.repeat(video_len) 253 | pred_masks = pred_masks[range(video_len), max_inds, ...] 
# [t, h, w] 254 | pred_masks = pred_masks.unsqueeze(0) 255 | pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False) 256 | if args.save_prob: 257 | pred_masks = pred_masks.sigmoid().squeeze(0).detach().cpu().numpy() 258 | else: 259 | pred_masks = (pred_masks.sigmoid() > args.threshold).squeeze(0).detach().cpu().numpy() 260 | if args.use_score: 261 | pred_score = pred_score[range(video_len), max_inds, 0].unsqueeze(-1).unsqueeze(-1) 262 | pred_masks *= (pred_score > 0.3).cpu().numpy() * pred_masks 263 | 264 | # store the video results 265 | all_pred_logits = pred_logits[range(video_len), max_inds].sigmoid().cpu().numpy() 266 | all_pred_boxes = pred_boxes[range(video_len), max_inds] 267 | all_pred_ref_points = pred_ref_points[range(video_len), max_inds] 268 | all_pred_masks = pred_masks 269 | 270 | if args.visualize: 271 | for t, frame in enumerate(frames): 272 | # original 273 | img_path = os.path.join(img_folder, video_name, frame + '.jpg') 274 | source_img = Image.open(img_path).convert('RGBA') # PIL image 275 | 276 | draw = ImageDraw.Draw(source_img) 277 | draw_boxes = all_pred_boxes[t].unsqueeze(0) 278 | draw_boxes = rescale_bboxes(draw_boxes.detach(), (origin_w, origin_h)).tolist() 279 | 280 | # draw boxes 281 | xmin, ymin, xmax, ymax = draw_boxes[0] 282 | draw.rectangle(((xmin, ymin), (xmax, ymax)), outline=tuple(color_list[i%len(color_list)]), width=2) 283 | 284 | # draw reference point 285 | ref_points = all_pred_ref_points[t].unsqueeze(0).detach().cpu().tolist() 286 | draw_reference_points(draw, ref_points, source_img.size, color=color_list[i%len(color_list)]) 287 | 288 | # draw mask 289 | source_img = vis_add_mask(source_img, all_pred_masks[t], color_list[i%len(color_list)]) 290 | 291 | # save 292 | save_visualize_path_dir = os.path.join(save_visualize_path_prefix, video, str(i)) 293 | if not os.path.exists(save_visualize_path_dir): 294 | os.makedirs(save_visualize_path_dir) 295 | save_visualize_path = os.path.join(save_visualize_path_dir, frame + '.png') 296 | source_img.save(save_visualize_path) 297 | 298 | 299 | # save binary image 300 | save_path = os.path.join(save_path_prefix, video_name, exp_id) 301 | if not os.path.exists(save_path): 302 | os.makedirs(save_path) 303 | for j in range(video_len): 304 | frame_name = frames[j] 305 | confidence = all_pred_logits[j] 306 | mask = all_pred_masks[j].astype(np.float32) 307 | save_file = os.path.join(save_path, frame_name + ".png") 308 | # print(save_file) 309 | if 'pair_logits' in outputs.keys() and args.use_cls: 310 | mask *= 0 if outputs['pair_logits'].cpu().numpy() >= 0.5 else 1 311 | mask = Image.fromarray(mask * 255).convert('L') 312 | mask.save(save_file) 313 | 314 | # print(torch.nn.functional.mse_loss(text_sentence_features, pseudo_text_sentence_features)) 315 | # if vid == 10: 316 | # bert_embed = torch.cat(sentence_features, dim=0) 317 | # np.save(f'output/bert_embed{args.val_type}.npy', bert_embed.cpu().numpy()) 318 | # pseudo_bert_embed = torch.cat(pseudo_sentence_features, dim=0) 319 | # np.save(f'output/pseudo_bert_embed{args.val_type}.npy', pseudo_bert_embed.cpu().numpy()) 320 | 321 | with lock: 322 | progress.update(1) 323 | # bert_embed = torch.cat(sentence_features, dim=0) 324 | # np.save('output/bert_embed.npy', bert_embed.cpu().numpy()) 325 | 326 | result_dict[str(pid)] = num_all_frames 327 | with lock: 328 | progress.close() 329 | 330 | 331 | # visuaize functions 332 | def box_cxcywh_to_xyxy(x): 333 | x_c, y_c, w, h = x.unbind(1) 334 | b = [(x_c - 0.5 * w), (y_c - 
0.5 * h), 335 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 336 | return torch.stack(b, dim=1) 337 | 338 | def rescale_bboxes(out_bbox, size): 339 | img_w, img_h = size 340 | b = box_cxcywh_to_xyxy(out_bbox) 341 | b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) 342 | return b 343 | 344 | 345 | # Visualization functions 346 | def draw_reference_points(draw, reference_points, img_size, color): 347 | W, H = img_size 348 | for i, ref_point in enumerate(reference_points): 349 | init_x, init_y = ref_point 350 | x, y = W * init_x, H * init_y 351 | cur_color = color 352 | draw.line((x-10, y, x+10, y), tuple(cur_color), width=4) 353 | draw.line((x, y-10, x, y+10), tuple(cur_color), width=4) 354 | 355 | def draw_sample_points(draw, sample_points, img_size, color_list): 356 | alpha = 255 357 | for i, samples in enumerate(sample_points): 358 | for sample in samples: 359 | x, y = sample 360 | cur_color = color_list[i % len(color_list)][::-1] 361 | cur_color += [alpha] 362 | draw.ellipse((x-2, y-2, x+2, y+2), 363 | fill=tuple(cur_color), outline=tuple(cur_color), width=1) 364 | 365 | def vis_add_mask(img, mask, color): 366 | origin_img = np.asarray(img.convert('RGB')).copy() 367 | color = np.array(color) 368 | 369 | mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8') # np 370 | mask = mask > 0.5 371 | 372 | origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5 373 | origin_img = Image.fromarray(origin_img) 374 | return origin_img 375 | 376 | 377 | 378 | if __name__ == '__main__': 379 | parser = argparse.ArgumentParser('ReferFormer inference script', parents=[opts.get_args_parser()]) 380 | args = parser.parse_args() 381 | main(args) 382 | -------------------------------------------------------------------------------- /inference_ytvos_segm.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Inference code for ReferFormer, on Ref-Youtube-VOS 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | ''' 5 | import argparse 6 | import json 7 | import random 8 | import sys 9 | import time 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | 15 | import util.misc as utils 16 | from models import build_model 17 | import torchvision.transforms as T 18 | import matplotlib.pyplot as plt 19 | import os 20 | import cv2 21 | from PIL import Image, ImageDraw 22 | import math 23 | import torch.nn.functional as F 24 | import json 25 | 26 | import opts 27 | from tqdm import tqdm 28 | 29 | import multiprocessing as mp 30 | import threading 31 | 32 | 33 | from tools.colormap import colormap 34 | 35 | 36 | # colormap 37 | color_list = colormap() 38 | color_list = color_list.astype('uint8').tolist() 39 | 40 | def main(args): 41 | args.masks = True 42 | args.batch_size == 1 43 | print("Inference only supports for batch size = 1") 44 | 45 | global transform 46 | transform = T.Compose([ 47 | T.Resize(args.inf_res), 48 | T.ToTensor(), 49 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 50 | ]) 51 | 52 | # fix the seed for reproducibility 53 | seed = args.seed + utils.get_rank() 54 | torch.manual_seed(seed) 55 | np.random.seed(seed) 56 | random.seed(seed) 57 | 58 | split = args.split 59 | # save path 60 | output_dir = args.output_dir 61 | save_path_prefix = os.path.join(output_dir, split) 62 | if not os.path.exists(save_path_prefix): 63 | os.makedirs(save_path_prefix) 64 | 65 | save_visualize_path_prefix = os.path.join(output_dir, split + '_images') 66 | if args.visualize: 67 | if not 
os.path.exists(save_visualize_path_prefix): 68 | os.makedirs(save_visualize_path_prefix) 69 | 70 | if args.val_type: 71 | print(f'Warning: use robust rvos validation set, type: {args.val_type}') 72 | assert args.val_type in ['_random', '_color', '_old_color', '_object', '_action', '_dyn', '_1', '_3', '_5'], \ 73 | f'Unknown val type for robust rvos, {args.val_type}' 74 | # load data 75 | root = Path(args.ytvos_path) # data/ref-youtube-vos 76 | img_folder = os.path.join(root, split, f"JPEGImages{args.val_type}") 77 | meta_file = os.path.join(root, "meta_expressions", split, f"meta_expressions{args.val_type}.json") 78 | with open(meta_file, "r") as f: 79 | data = json.load(f)["videos"] 80 | valid_test_videos = set(data.keys()) 81 | # for some reasons the competition's validation expressions dict contains both the validation (202) & 82 | # test videos (305). so we simply load the test expressions dict and use it to filter out the test videos from 83 | # the validation expressions dict: 84 | test_meta_file = os.path.join(root, "meta_expressions", "test", "meta_expressions.json") 85 | with open(test_meta_file, 'r') as f: 86 | test_data = json.load(f)['videos'] 87 | test_videos = set(test_data.keys()) 88 | valid_videos = valid_test_videos - test_videos 89 | video_list = sorted([video for video in valid_videos]) 90 | if 'color' not in args.val_type: 91 | assert len(video_list) == 202 or len(video_list) == 202 + 305, 'error: incorrect number of validation videos' 92 | 93 | # create subprocess 94 | thread_num = args.ngpu 95 | global result_dict 96 | result_dict = mp.Manager().dict() 97 | 98 | processes = [] 99 | lock = threading.Lock() 100 | 101 | video_num = len(video_list) 102 | per_thread_video_num = video_num // thread_num 103 | 104 | start_time = time.time() 105 | print('Start inference') 106 | for i in range(thread_num): 107 | if i == thread_num - 1: 108 | sub_video_list = video_list[i * per_thread_video_num:] 109 | else: 110 | sub_video_list = video_list[i * per_thread_video_num: (i + 1) * per_thread_video_num] 111 | p = mp.Process(target=sub_processor, args=(lock, i, args, data, 112 | save_path_prefix, save_visualize_path_prefix, 113 | img_folder, sub_video_list)) 114 | p.start() 115 | processes.append(p) 116 | 117 | for p in processes: 118 | p.join() 119 | 120 | end_time = time.time() 121 | total_time = end_time - start_time 122 | 123 | result_dict = dict(result_dict) 124 | num_all_frames_gpus = 0 125 | for pid, num_all_frames in result_dict.items(): 126 | num_all_frames_gpus += num_all_frames 127 | 128 | print("Total inference time: %.4f s" %(total_time)) 129 | 130 | def sub_processor(lock, pid, args, data, save_path_prefix, save_visualize_path_prefix, img_folder, video_list): 131 | text = 'processor %d' % pid 132 | with lock: 133 | progress = tqdm( 134 | total=len(video_list), 135 | position=pid, 136 | desc=text, 137 | ncols=0 138 | ) 139 | torch.cuda.set_device(pid) 140 | 141 | # model 142 | model, criterion, _ = build_model(args) 143 | device = args.device 144 | model.to(device) 145 | 146 | model_without_ddp = model 147 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 148 | 149 | if pid == 0: 150 | print('number of params:', n_parameters) 151 | 152 | if args.resume: 153 | checkpoint = torch.load(args.resume, map_location='cpu') 154 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 155 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))] 156 | if 
len(missing_keys) > 0: 157 | print('Missing Keys: {}'.format(missing_keys)) 158 | if len(unexpected_keys) > 0: 159 | print('Unexpected Keys: {}'.format(unexpected_keys)) 160 | else: 161 | raise ValueError('Please specify the checkpoint for inference.') 162 | 163 | 164 | # start inference 165 | num_all_frames = 0 166 | model.eval() 167 | 168 | sentence_features = [] 169 | pseudo_sentence_features = [] 170 | # 1. For each video 171 | for vid, video in enumerate(video_list): 172 | metas = [] # list[dict], length is number of expressions 173 | 174 | expressions = data[video]["expressions"] 175 | expression_list = list(expressions.keys()) 176 | num_expressions = len(expression_list) 177 | video_len = len(data[video]["frames"]) 178 | 179 | # read all the anno meta 180 | for i in range(num_expressions): 181 | meta = {} 182 | meta["video"] = video 183 | meta["exp"] = expressions[expression_list[i]]["exp"] 184 | meta["exp_id"] = expression_list[i] 185 | meta["frames"] = data[video]["frames"] 186 | metas.append(meta) 187 | meta = metas 188 | 189 | vis = False 190 | # 2. For each expression 191 | for i in range(num_expressions): 192 | video_name = meta[i]["video"] 193 | exp = meta[i]["exp"] 194 | # TODO: temp 195 | exp_id = meta[i]["exp_id"] 196 | frames = meta[i]["frames"] 197 | if vis: 198 | os.makedirs('vis/tmp', exist_ok=True) 199 | 200 | video_len = len(frames) 201 | # store images 202 | imgs = [] 203 | for t in range(video_len): 204 | frame = frames[t] 205 | img_path = os.path.join(img_folder, video_name, frame + ".jpg") 206 | # img_path = frame 207 | img = Image.open(img_path).convert('RGB') 208 | if vis: 209 | if t < 1: 210 | img.save(f'vis/tmp/{t}.png') 211 | origin_w, origin_h = img.size 212 | imgs.append(transform(img)) # list[img] 213 | 214 | segm_logits = [] 215 | segm_masks = [] 216 | 217 | for segm_idx in range(video_len // args.segm_frame): 218 | if segm_idx != (video_len // args.segm_frame) - 1: 219 | segm = imgs[segm_idx * args.segm_frame:(segm_idx+1) * args.segm_frame] 220 | else: 221 | segm = imgs[segm_idx * args.segm_frame:] 222 | 223 | segm_len = len(segm) 224 | segm = torch.stack(segm, dim=0).to(args.device) # [video_len, 3, h, w] 225 | img_h, img_w = segm.shape[-2:] 226 | size = torch.as_tensor([int(img_h), int(img_w)]).to(args.device) 227 | target = {"size": size} 228 | 229 | # start = time.time() 230 | with torch.no_grad(): 231 | outputs = model([segm], [exp], [target]) 232 | # end = time.time() 233 | # print((end-start)/video_len) 234 | 235 | if vis: 236 | os.system(f'mv vis/tmp vis/{vid}-{exp.replace(" ", "_")}') 237 | 238 | pred_logits = outputs["pred_logits"][0] 239 | pred_boxes = outputs["pred_boxes"][0] 240 | pred_masks = outputs["pred_masks"][0] 241 | pred_ref_points = outputs["reference_points"][0] 242 | text_sentence_features = outputs['sentence_feature'] 243 | if args.use_cycle: 244 | pseudo_text_sentence_features = outputs['pseudo_sentence_feature'] 245 | # anchor = outputs['negative_anchor'] 246 | sentence_features.append(text_sentence_features) 247 | pseudo_sentence_features.append(pseudo_text_sentence_features) 248 | # print(F.pairwise_distance(text_sentence_features, pseudo_text_sentence_features.squeeze(0), p=2)) 249 | # print(anchor) 250 | # according to pred_logits, select the query index 251 | pred_scores = pred_logits.sigmoid() # [t, q, k] 252 | pred_score = pred_scores 253 | pred_scores = pred_scores.mean(0) # [q, k] 254 | max_scores, _ = pred_scores.max(-1) # [q,] 255 | # print(max_scores) 256 | _, max_ind = max_scores.max(-1) # [1,] 257 | max_inds = 
max_ind.repeat(segm_len) 258 | pred_masks = pred_masks[range(segm_len), max_inds, ...] # [t, h, w] 259 | pred_masks = pred_masks.unsqueeze(0) 260 | pred_masks = F.interpolate(pred_masks, size=(origin_h, origin_w), mode='bilinear', align_corners=False) 261 | if args.save_prob: 262 | pred_masks = pred_masks.sigmoid().squeeze(0).detach().cpu().numpy() 263 | else: 264 | pred_masks = (pred_masks.sigmoid() > args.threshold).squeeze(0).detach().cpu().numpy() 265 | if args.use_score: 266 | pred_score = pred_score[range(segm_len), max_inds, 0].unsqueeze(-1).unsqueeze(-1) 267 | pred_masks *= (pred_score > 0.3).cpu().numpy() * pred_masks 268 | 269 | # store the video results 270 | all_pred_logits = pred_logits[range(segm_len), max_inds] 271 | all_pred_boxes = pred_boxes[range(segm_len), max_inds] 272 | all_pred_ref_points = pred_ref_points[range(segm_len), max_inds] 273 | all_pred_masks = pred_masks 274 | segm_logits.append(all_pred_logits.cpu().numpy()) 275 | segm_masks.append(all_pred_masks) 276 | 277 | all_pred_logits = np.concatenate(segm_logits, axis=0) 278 | all_pred_masks = np.concatenate(segm_masks, axis=0) 279 | assert all_pred_masks.shape[0] == video_len 280 | # save binary image 281 | save_path = os.path.join(save_path_prefix, video_name, exp_id) 282 | if not os.path.exists(save_path): 283 | os.makedirs(save_path) 284 | for j in range(video_len): 285 | frame_name = frames[j] 286 | mask = all_pred_masks[j].astype(np.float32) 287 | save_file = os.path.join(save_path, frame_name + ".png") 288 | if 'pair_logits' in outputs.keys() and args.use_cls: 289 | mask *= 0 if outputs['pair_logits'].cpu().numpy() >= 0.5 else 1 290 | mask = Image.fromarray(mask * 255).convert('L') 291 | mask.save(save_file) 292 | 293 | with lock: 294 | progress.update(1) 295 | # bert_embed = torch.cat(sentence_features, dim=0) 296 | # np.save('output/bert_embed.npy', bert_embed.cpu().numpy()) 297 | 298 | result_dict[str(pid)] = num_all_frames 299 | with lock: 300 | progress.close() 301 | 302 | 303 | # visuaize functions 304 | def box_cxcywh_to_xyxy(x): 305 | x_c, y_c, w, h = x.unbind(1) 306 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 307 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 308 | return torch.stack(b, dim=1) 309 | 310 | def rescale_bboxes(out_bbox, size): 311 | img_w, img_h = size 312 | b = box_cxcywh_to_xyxy(out_bbox) 313 | b = b.cpu() * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32) 314 | return b 315 | 316 | 317 | # Visualization functions 318 | def draw_reference_points(draw, reference_points, img_size, color): 319 | W, H = img_size 320 | for i, ref_point in enumerate(reference_points): 321 | init_x, init_y = ref_point 322 | x, y = W * init_x, H * init_y 323 | cur_color = color 324 | draw.line((x-10, y, x+10, y), tuple(cur_color), width=4) 325 | draw.line((x, y-10, x, y+10), tuple(cur_color), width=4) 326 | 327 | def draw_sample_points(draw, sample_points, img_size, color_list): 328 | alpha = 255 329 | for i, samples in enumerate(sample_points): 330 | for sample in samples: 331 | x, y = sample 332 | cur_color = color_list[i % len(color_list)][::-1] 333 | cur_color += [alpha] 334 | draw.ellipse((x-2, y-2, x+2, y+2), 335 | fill=tuple(cur_color), outline=tuple(cur_color), width=1) 336 | 337 | def vis_add_mask(img, mask, color): 338 | origin_img = np.asarray(img.convert('RGB')).copy() 339 | color = np.array(color) 340 | 341 | mask = mask.reshape(mask.shape[0], mask.shape[1]).astype('uint8') # np 342 | mask = mask > 0.5 343 | 344 | origin_img[mask] = origin_img[mask] * 0.5 + color * 0.5 345 | origin_img = 
Image.fromarray(origin_img) 346 | return origin_img 347 | 348 | 349 | 350 | if __name__ == '__main__': 351 | parser = argparse.ArgumentParser('ReferFormer inference script', parents=[opts.get_args_parser()]) 352 | args = parser.parse_args() 353 | main(args) 354 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Training script of ReferFormer 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import argparse 6 | import datetime 7 | import json 8 | import random 9 | import time 10 | from pathlib import Path 11 | 12 | import numpy as np 13 | import torch 14 | from torch.utils.data import DataLoader, DistributedSampler 15 | import torch.distributed as dist 16 | 17 | import util.misc as utils 18 | import datasets.samplers as samplers 19 | from datasets import build_dataset, get_coco_api_from_dataset 20 | from engine import train_one_epoch, evaluate, evaluate_a2d 21 | from models import build_model 22 | 23 | from tools.load_pretrained_weights import pre_trained_model_to_finetune 24 | from tools.warmup_poly_lr_scheduler import WarmupPolyLR 25 | import opts 26 | 27 | 28 | 29 | def main(args): 30 | args.masks = True 31 | 32 | utils.init_distributed_mode(args) 33 | print("git:\n {}\n".format(utils.get_sha())) 34 | print(args) 35 | 36 | print(f'\n Run on {args.dataset_file} dataset.') 37 | print('\n') 38 | 39 | device = torch.device(args.device) 40 | 41 | # fix the seed for reproducibility 42 | seed = args.seed + utils.get_rank() 43 | torch.manual_seed(seed) 44 | np.random.seed(seed) 45 | random.seed(seed) 46 | 47 | model, criterion, postprocessor = build_model(args) 48 | model.to(device) 49 | 50 | model_without_ddp = model 51 | if args.distributed: 52 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 53 | model_without_ddp = model.module 54 | 55 | # for n, p in model_without_ddp.named_parameters(): 56 | # if p.requires_grad: 57 | # print(n) 58 | 59 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 60 | print('number of params:', n_parameters) 61 | 62 | def match_name_keywords(n, name_keywords): 63 | out = False 64 | for b in name_keywords: 65 | if b in n: 66 | out = True 67 | break 68 | return out 69 | 70 | if args.only_cycle: 71 | for n, p in model_without_ddp.named_parameters(): 72 | if not (match_name_keywords(n, 'text_query') or match_name_keywords(n, 'text_decoder')): 73 | p.requires_grad = False 74 | 75 | param_dicts = [ 76 | { 77 | "params": 78 | [p for n, p in model_without_ddp.named_parameters() 79 | if not match_name_keywords(n, args.lr_backbone_names) and not match_name_keywords(n, args.lr_text_encoder_names) 80 | and not match_name_keywords(n, args.lr_linear_proj_names) and not match_name_keywords(n, args.lr_anchor_names) and p.requires_grad], 81 | "lr": args.lr * args.lr_multi, 82 | }, 83 | { 84 | "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad], 85 | "lr": args.lr_backbone * args.lr_multi, 86 | }, 87 | { 88 | "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_text_encoder_names) and p.requires_grad], 89 | "lr": args.lr_text_encoder * args.lr_multi, 90 | }, 91 | { 92 | "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad], 93 | "lr": args.lr * 
args.lr_linear_proj_mult * args.lr_multi, 94 | }, 95 | { 96 | "params": [p for n, p in model_without_ddp.named_parameters() if 97 | match_name_keywords(n, args.lr_anchor_names) and p.requires_grad], 98 | "lr": args.lr * args.lr_anchor_mult * args.lr_multi, 99 | } 100 | ] 101 | 102 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 103 | weight_decay=args.weight_decay) 104 | 105 | # no validation ground truth for ytvos dataset 106 | dataset_train = build_dataset(args.dataset_file, image_set='train', args=args) 107 | 108 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.lr_drop) 109 | 110 | # max_iter = len(dataset_train) / dist.get_world_size() / args.batch_size * args.epochs 111 | # lr_scheduler = WarmupPolyLR(optimizer, max_iter, warmup_iters=100, warmup_factor=0.01) 112 | # print(f"dataset size: {len(dataset_train)}, max_iter: {max_iter}") 113 | 114 | if args.distributed: 115 | if args.cache_mode: 116 | sampler_train = samplers.NodeDistributedSampler(dataset_train) 117 | else: 118 | sampler_train = samplers.DistributedSampler(dataset_train) 119 | else: 120 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 121 | 122 | batch_sampler_train = torch.utils.data.BatchSampler( 123 | sampler_train, args.batch_size, drop_last=True) 124 | 125 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, 126 | collate_fn=utils.collate_fn, num_workers=args.num_workers) 127 | 128 | # A2D-Sentences 129 | if args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb' or args.dataset_file == 'refcocoa2d' or args.dataset_file=='refcocoVideo': 130 | dataset_val = build_dataset(args.dataset_file, image_set='val', args=args) 131 | if args.distributed: 132 | if args.cache_mode: 133 | sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False) 134 | else: 135 | sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False) 136 | else: 137 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 138 | data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, 139 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers, 140 | pin_memory=True) 141 | 142 | 143 | if args.dataset_file == "davis": 144 | assert args.pretrained_weights is not None, "Please provide the pretrained weight to finetune for Ref-DAVIS17" 145 | print("============================================>") 146 | print("Ref-DAVIS17 are finetuned using the checkpoint trained on Ref-Youtube-VOS") 147 | print("Load checkpoint weights from {} ...".format(args.pretrained_weights)) 148 | checkpoint = torch.load(args.pretrained_weights, map_location="cpu") 149 | checkpoint_dict = pre_trained_model_to_finetune(checkpoint, args) 150 | model_without_ddp.load_state_dict(checkpoint_dict, strict=False) 151 | print("============================================>") 152 | 153 | if args.dataset_file == "jhmdb": 154 | assert args.resume is not None, "Please provide the checkpoint to resume for JHMDB-Sentences" 155 | print("============================================>") 156 | print("JHMDB-Sentences are directly evaluated using the checkpoint trained on A2D-Sentences") 157 | print("Load checkpoint weights from {} ...".format(args.pretrained_weights)) 158 | # load checkpoint in the args.resume 159 | print("============================================>") 160 | 161 | # for Ref-Youtube-VOS and A2D-Sentences 162 | # finetune using the pretrained weights on Ref-COCO 163 | if args.dataset_file != "davis" and args.dataset_file != "jhmdb" and 
args.pretrained_weights is not None: 164 | print("============================================>") 165 | print("Load pretrained weights from {} ...".format(args.pretrained_weights)) 166 | checkpoint = torch.load(args.pretrained_weights, map_location="cpu") 167 | checkpoint_dict = pre_trained_model_to_finetune(checkpoint, args) 168 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint_dict, strict=False) 169 | print(checkpoint_dict.keys()) 170 | print("============================================>") 171 | print(missing_keys) 172 | print("============================================>") 173 | print(unexpected_keys) 174 | print("============================================>") 175 | 176 | 177 | output_dir = Path(args.output_dir) 178 | if args.resume: 179 | if args.resume.startswith('https'): 180 | checkpoint = torch.hub.load_state_dict_from_url( 181 | args.resume, map_location='cpu', check_hash=True) 182 | else: 183 | checkpoint = torch.load(args.resume, map_location='cpu') 184 | missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 185 | unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))] 186 | if len(missing_keys) > 0: 187 | print('Missing Keys: {}'.format(missing_keys)) 188 | if len(unexpected_keys) > 0: 189 | print('Unexpected Keys: {}'.format(unexpected_keys)) 190 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 191 | import copy 192 | p_groups = copy.deepcopy(optimizer.param_groups) 193 | optimizer.load_state_dict(checkpoint['optimizer']) 194 | for pg, pg_old in zip(optimizer.param_groups, p_groups): 195 | pg['lr'] = pg_old['lr'] 196 | pg['initial_lr'] = pg_old['initial_lr'] 197 | # print(optimizer.param_groups) 198 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 199 | # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance). 
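# NOTE (editorial sketch, not part of the original script): the todo above concerns resuming a run while
# changing the LR schedule from the command line. Assuming the DETR-style flag names used in opts.py
# (--resume, --lr_drop), an invocation such as the following is the intended use case:
#   python main.py --resume output/checkpoint.pth --lr_drop 6 ...
# The lines below then force the resumed scheduler to adopt the drop step and base LRs derived from the
# current arguments instead of the values stored in the checkpoint.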
200 | args.override_resumed_lr_drop = True 201 | if args.override_resumed_lr_drop: 202 | print('Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler.') 203 | lr_scheduler.step_size = args.lr_drop 204 | lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups)) 205 | lr_scheduler.step(lr_scheduler.last_epoch) 206 | args.start_epoch = checkpoint['epoch'] + 1 207 | 208 | if args.eval: 209 | assert args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb' or args.dataset_file == 'refcocoVideo', \ 210 | 'Only A2D-Sentences and JHMDB-Sentences datasets support evaluation' 211 | test_stats = evaluate_a2d(model, data_loader_val, postprocessor, device, args) 212 | return 213 | 214 | 215 | print("Start training") 216 | start_time = time.time() 217 | for epoch in range(args.start_epoch, args.epochs): 218 | if args.distributed: 219 | sampler_train.set_epoch(epoch) 220 | train_stats = train_one_epoch( 221 | model, criterion, data_loader_train, optimizer, device, lr_scheduler, epoch, 222 | args.clip_max_norm) 223 | lr_scheduler.step() 224 | if args.output_dir: 225 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 226 | # extra checkpoint before LR drop and every epochs 227 | # if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 1 == 0: 228 | if (epoch + 1) % 1 == 0: 229 | checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') 230 | for checkpoint_path in checkpoint_paths: 231 | utils.save_on_master({ 232 | 'model': model_without_ddp.state_dict(), 233 | 'optimizer': optimizer.state_dict(), 234 | 'lr_scheduler': lr_scheduler.state_dict(), 235 | 'epoch': epoch, 236 | 'args': args, 237 | }, checkpoint_path) 238 | 239 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 240 | 'epoch': epoch, 241 | 'n_parameters': n_parameters} 242 | 243 | if args.dataset_file == 'a2d' or args.dataset_file == 'refcocoa2d': 244 | test_stats = evaluate_a2d(model, data_loader_val, postprocessor, device, args) 245 | log_stats.update({**{f'{k}': v for k, v in test_stats.items()}}) 246 | 247 | if args.output_dir and utils.is_main_process(): 248 | with (output_dir / "log.txt").open("a") as f: 249 | f.write(json.dumps(log_stats) + "\n") 250 | 251 | 252 | total_time = time.time() - start_time 253 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 254 | print('Training time {}'.format(total_time_str)) 255 | 256 | 257 | if __name__ == '__main__': 258 | parser = argparse.ArgumentParser('RVOSNet training and evaluation script', parents=[opts.get_args_parser()]) 259 | args = parser.parse_args() 260 | if args.output_dir: 261 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 262 | main(args) 263 | 264 | 265 | 266 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .referformer import build 2 | 3 | def build_model(args): 4 | return build(args) 5 | -------------------------------------------------------------------------------- /models/amm_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | import torch.nn as nn 4 | from torchvision.models.utils import load_state_dict_from_url 5 | from typing import Type, Any, Callable, Union, List, Optional 6 | 7 | 8 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'amm_resnet50', 'amm_resnet101', 9 | 'resnet152', 
'resnext50_32x4d', 'resnext101_32x8d', 10 | 'wide_resnet50_2', 'wide_resnet101_2'] 11 | 12 | 13 | model_urls = { 14 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 15 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 16 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 17 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 18 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 19 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 20 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 21 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 22 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 23 | } 24 | 25 | 26 | def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: 27 | """3x3 convolution with padding""" 28 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 29 | padding=dilation, groups=groups, bias=False, dilation=dilation) 30 | 31 | 32 | def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: 33 | """1x1 convolution""" 34 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 35 | 36 | 37 | class BasicBlock(nn.Module): 38 | expansion: int = 1 39 | 40 | def __init__( 41 | self, 42 | inplanes: int, 43 | planes: int, 44 | stride: int = 1, 45 | downsample: Optional[nn.Module] = None, 46 | groups: int = 1, 47 | base_width: int = 64, 48 | dilation: int = 1, 49 | norm_layer: Optional[Callable[..., nn.Module]] = None 50 | ) -> None: 51 | super(BasicBlock, self).__init__() 52 | if norm_layer is None: 53 | norm_layer = nn.BatchNorm2d 54 | if groups != 1 or base_width != 64: 55 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 56 | if dilation > 1: 57 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 58 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 59 | self.conv1 = conv3x3(inplanes, planes, stride) 60 | self.bn1 = norm_layer(planes) 61 | self.relu = nn.ReLU(inplace=True) 62 | self.conv2 = conv3x3(planes, planes) 63 | self.bn2 = norm_layer(planes) 64 | self.downsample = downsample 65 | self.stride = stride 66 | 67 | def forward(self, x: Tensor) -> Tensor: 68 | identity = x 69 | 70 | out = self.conv1(x) 71 | out = self.bn1(out) 72 | out = self.relu(out) 73 | 74 | out = self.conv2(out) 75 | out = self.bn2(out) 76 | 77 | if self.downsample is not None: 78 | identity = self.downsample(x) 79 | 80 | out += identity 81 | out = self.relu(out) 82 | 83 | return out 84 | 85 | 86 | class Bottleneck(nn.Module): 87 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 88 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 89 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 90 | # This variant is also known as ResNet V1.5 and improves accuracy according to 91 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
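# Editorial sketch of the two variants described in the comment above (illustrative only, not from torchvision):
#   ResNet v1  : conv1x1(stride=2) -> conv3x3(stride=1) -> conv1x1(stride=1)   # downsamples in the first 1x1
#   ResNet v1.5: conv1x1(stride=1) -> conv3x3(stride=2) -> conv1x1(stride=1)   # downsamples in the 3x3, as below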
92 | 93 | expansion: int = 4 94 | 95 | def __init__( 96 | self, 97 | inplanes: int, 98 | planes: int, 99 | stride: int = 1, 100 | downsample: Optional[nn.Module] = None, 101 | groups: int = 1, 102 | base_width: int = 64, 103 | dilation: int = 1, 104 | norm_layer: Optional[Callable[..., nn.Module]] = None 105 | ) -> None: 106 | super(Bottleneck, self).__init__() 107 | if norm_layer is None: 108 | norm_layer = nn.BatchNorm2d 109 | width = int(planes * (base_width / 64.)) * groups 110 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 111 | self.conv1 = conv1x1(inplanes, width) 112 | self.bn1 = norm_layer(width) 113 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 114 | self.bn2 = norm_layer(width) 115 | self.conv3 = conv1x1(width, planes * self.expansion) 116 | self.bn3 = norm_layer(planes * self.expansion) 117 | self.relu = nn.ReLU(inplace=True) 118 | self.downsample = downsample 119 | self.stride = stride 120 | 121 | def forward(self, x: Tensor) -> Tensor: 122 | identity = x 123 | 124 | out = self.conv1(x) 125 | out = self.bn1(out) 126 | out = self.relu(out) 127 | 128 | out = self.conv2(out) 129 | out = self.bn2(out) 130 | out = self.relu(out) 131 | 132 | out = self.conv3(out) 133 | out = self.bn3(out) 134 | 135 | if self.downsample is not None: 136 | identity = self.downsample(x) 137 | 138 | out += identity 139 | out = self.relu(out) 140 | 141 | return out 142 | 143 | 144 | class ResNet(nn.Module): 145 | 146 | def __init__( 147 | self, 148 | block: Type[Union[BasicBlock, Bottleneck]], 149 | layers: List[int], 150 | num_classes: int = 1000, 151 | zero_init_residual: bool = False, 152 | groups: int = 1, 153 | width_per_group: int = 64, 154 | replace_stride_with_dilation: Optional[List[bool]] = None, 155 | norm_layer: Optional[Callable[..., nn.Module]] = None 156 | ) -> None: 157 | super(ResNet, self).__init__() 158 | if norm_layer is None: 159 | norm_layer = nn.BatchNorm2d 160 | self._norm_layer = norm_layer 161 | 162 | self.inplanes = 64 163 | self.dilation = 1 164 | if replace_stride_with_dilation is None: 165 | # each element in the tuple indicates if we should replace 166 | # the 2x2 stride with a dilated convolution instead 167 | replace_stride_with_dilation = [False, False, False] 168 | if len(replace_stride_with_dilation) != 3: 169 | raise ValueError("replace_stride_with_dilation should be None " 170 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 171 | self.groups = groups 172 | self.base_width = width_per_group 173 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 174 | bias=False) 175 | self.bn1 = norm_layer(self.inplanes) 176 | self.relu = nn.ReLU(inplace=True) 177 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 178 | self.layer1 = self._make_layer(block, 64, layers[0]) 179 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 180 | dilate=replace_stride_with_dilation[0]) 181 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 182 | dilate=replace_stride_with_dilation[1]) 183 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 184 | dilate=replace_stride_with_dilation[2]) 185 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 186 | self.fc = nn.Linear(512 * block.expansion, num_classes) 187 | 188 | for m in self.modules(): 189 | if isinstance(m, nn.Conv2d): 190 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 191 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 192 | nn.init.constant_(m.weight, 1) 193 
| nn.init.constant_(m.bias, 0) 194 | 195 | # Zero-initialize the last BN in each residual branch, 196 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 197 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 198 | if zero_init_residual: 199 | for m in self.modules(): 200 | if isinstance(m, Bottleneck): 201 | nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] 202 | elif isinstance(m, BasicBlock): 203 | nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] 204 | 205 | def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int, 206 | stride: int = 1, dilate: bool = False) -> nn.Sequential: 207 | norm_layer = self._norm_layer 208 | downsample = None 209 | previous_dilation = self.dilation 210 | if dilate: 211 | self.dilation *= stride 212 | stride = 1 213 | if stride != 1 or self.inplanes != planes * block.expansion: 214 | downsample = nn.Sequential( 215 | conv1x1(self.inplanes, planes * block.expansion, stride), 216 | norm_layer(planes * block.expansion), 217 | ) 218 | 219 | layers = [] 220 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 221 | self.base_width, previous_dilation, norm_layer)) 222 | self.inplanes = planes * block.expansion 223 | for _ in range(1, blocks): 224 | layers.append(block(self.inplanes, planes, groups=self.groups, 225 | base_width=self.base_width, dilation=self.dilation, 226 | norm_layer=norm_layer)) 227 | 228 | return nn.Sequential(*layers) 229 | 230 | def _forward_impl(self, x: Tensor) -> Tensor: 231 | # See note [TorchScript super()] 232 | x = self.conv1(x) 233 | x = self.bn1(x) 234 | x = self.relu(x) 235 | x = self.maxpool(x) 236 | 237 | x = self.layer1(x) 238 | x = self.layer2(x) 239 | x = self.layer3(x) 240 | x = self.layer4(x) 241 | 242 | x = self.avgpool(x) 243 | x = torch.flatten(x, 1) 244 | x = self.fc(x) 245 | 246 | return x 247 | 248 | def forward(self, x: Tensor) -> Tensor: 249 | return self._forward_impl(x) 250 | 251 | 252 | def _resnet( 253 | arch: str, 254 | block: Type[Union[BasicBlock, Bottleneck]], 255 | layers: List[int], 256 | pretrained: bool, 257 | progress: bool, 258 | **kwargs: Any 259 | ) -> ResNet: 260 | model = ResNet(block, layers, **kwargs) 261 | if pretrained: 262 | state_dict = load_state_dict_from_url(model_urls[arch], 263 | progress=progress) 264 | model.load_state_dict(state_dict) 265 | return model 266 | 267 | 268 | def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 269 | r"""ResNet-18 model from 270 | `"Deep Residual Learning for Image Recognition" `_. 271 | 272 | Args: 273 | pretrained (bool): If True, returns a model pre-trained on ImageNet 274 | progress (bool): If True, displays a progress bar of the download to stderr 275 | """ 276 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 277 | **kwargs) 278 | 279 | 280 | def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 281 | r"""ResNet-34 model from 282 | `"Deep Residual Learning for Image Recognition" `_. 
283 | 284 | Args: 285 | pretrained (bool): If True, returns a model pre-trained on ImageNet 286 | progress (bool): If True, displays a progress bar of the download to stderr 287 | """ 288 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 289 | **kwargs) 290 | 291 | 292 | def amm_resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 293 | r"""ResNet-50 model from 294 | `"Deep Residual Learning for Image Recognition" `_. 295 | 296 | Args: 297 | pretrained (bool): If True, returns a model pre-trained on ImageNet 298 | progress (bool): If True, displays a progress bar of the download to stderr 299 | """ 300 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 301 | **kwargs) 302 | 303 | 304 | def amm_resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 305 | r"""ResNet-101 model from 306 | `"Deep Residual Learning for Image Recognition" `_. 307 | 308 | Args: 309 | pretrained (bool): If True, returns a model pre-trained on ImageNet 310 | progress (bool): If True, displays a progress bar of the download to stderr 311 | """ 312 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 313 | **kwargs) 314 | 315 | 316 | def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 317 | r"""ResNet-152 model from 318 | `"Deep Residual Learning for Image Recognition" `_. 319 | 320 | Args: 321 | pretrained (bool): If True, returns a model pre-trained on ImageNet 322 | progress (bool): If True, displays a progress bar of the download to stderr 323 | """ 324 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 325 | **kwargs) 326 | 327 | 328 | def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 329 | r"""ResNeXt-50 32x4d model from 330 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 331 | 332 | Args: 333 | pretrained (bool): If True, returns a model pre-trained on ImageNet 334 | progress (bool): If True, displays a progress bar of the download to stderr 335 | """ 336 | kwargs['groups'] = 32 337 | kwargs['width_per_group'] = 4 338 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 339 | pretrained, progress, **kwargs) 340 | 341 | 342 | def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 343 | r"""ResNeXt-101 32x8d model from 344 | `"Aggregated Residual Transformation for Deep Neural Networks" `_. 345 | 346 | Args: 347 | pretrained (bool): If True, returns a model pre-trained on ImageNet 348 | progress (bool): If True, displays a progress bar of the download to stderr 349 | """ 350 | kwargs['groups'] = 32 351 | kwargs['width_per_group'] = 8 352 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 353 | pretrained, progress, **kwargs) 354 | 355 | 356 | def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 357 | r"""Wide ResNet-50-2 model from 358 | `"Wide Residual Networks" `_. 359 | 360 | The model is the same as ResNet except for the bottleneck number of channels 361 | which is twice larger in every block. The number of channels in outer 1x1 362 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 363 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
364 | 365 | Args: 366 | pretrained (bool): If True, returns a model pre-trained on ImageNet 367 | progress (bool): If True, displays a progress bar of the download to stderr 368 | """ 369 | kwargs['width_per_group'] = 64 * 2 370 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 371 | pretrained, progress, **kwargs) 372 | 373 | 374 | def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet: 375 | r"""Wide ResNet-101-2 model from 376 | `"Wide Residual Networks" `_. 377 | 378 | The model is the same as ResNet except for the bottleneck number of channels 379 | which is twice larger in every block. The number of channels in outer 1x1 380 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 381 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 382 | 383 | Args: 384 | pretrained (bool): If True, returns a model pre-trained on ImageNet 385 | progress (bool): If True, displays a progress bar of the download to stderr 386 | """ 387 | kwargs['width_per_group'] = 64 * 2 388 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 389 | pretrained, progress, **kwargs) 390 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Backbone modules. 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import os 6 | import sys 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import torchvision 12 | from torch import nn 13 | # from torchvision.models._utils import IntermediateLayerGetter 14 | from typing import Dict, List 15 | from einops import rearrange 16 | 17 | from .amm_resnet import amm_resnet50, amm_resnet101 18 | from util.misc import NestedTensor, is_main_process 19 | 20 | from .position_encoding import build_position_encoding 21 | 22 | 23 | class FrozenBatchNorm2d(torch.nn.Module): 24 | """ 25 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 26 | 27 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 28 | without which any other models than torchvision.models.resnet[18,34,50,101] 29 | produce nans. 
30 | """ 31 | 32 | def __init__(self, n): 33 | super(FrozenBatchNorm2d, self).__init__() 34 | self.register_buffer("weight", torch.ones(n)) 35 | self.register_buffer("bias", torch.zeros(n)) 36 | self.register_buffer("running_mean", torch.zeros(n)) 37 | self.register_buffer("running_var", torch.ones(n)) 38 | 39 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 40 | missing_keys, unexpected_keys, error_msgs): 41 | num_batches_tracked_key = prefix + 'num_batches_tracked' 42 | if num_batches_tracked_key in state_dict: 43 | del state_dict[num_batches_tracked_key] 44 | 45 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 46 | state_dict, prefix, local_metadata, strict, 47 | missing_keys, unexpected_keys, error_msgs) 48 | 49 | def forward(self, x): 50 | # move reshapes to the beginning 51 | # to make it fuser-friendly 52 | w = self.weight.reshape(1, -1, 1, 1) 53 | b = self.bias.reshape(1, -1, 1, 1) 54 | rv = self.running_var.reshape(1, -1, 1, 1) 55 | rm = self.running_mean.reshape(1, -1, 1, 1) 56 | eps = 1e-5 57 | scale = w * (rv + eps).rsqrt() 58 | bias = b - rm * scale 59 | return x * scale + bias 60 | 61 | 62 | class BackboneBase(nn.Module): 63 | 64 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): 65 | super().__init__() 66 | for name, parameter in backbone.named_parameters(): 67 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 68 | parameter.requires_grad_(False) 69 | if return_interm_layers: 70 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 71 | # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} deformable detr 72 | self.strides = [4, 8, 16, 32] 73 | self.num_channels = [256, 512, 1024, 2048] 74 | else: 75 | return_layers = {'layer4': "0"} 76 | self.strides = [32] 77 | self.num_channels = [2048] 78 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 79 | 80 | def forward(self, tensor_list: NestedTensor): 81 | xs = self.body(tensor_list.tensors) 82 | out: Dict[str, NestedTensor] = {} 83 | for name, x in xs.items(): 84 | m = tensor_list.mask 85 | assert m is not None 86 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 87 | out[name] = NestedTensor(x, mask) 88 | return out 89 | 90 | 91 | class Backbone(BackboneBase): 92 | """ResNet backbone with frozen BatchNorm.""" 93 | def __init__(self, name: str, 94 | train_backbone: bool, 95 | return_interm_layers: bool, 96 | dilation: bool): 97 | backbone = getattr(torchvision.models, name)( 98 | replace_stride_with_dilation=[False, False, dilation], 99 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 100 | assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" 101 | super().__init__(backbone, train_backbone, return_interm_layers) 102 | if dilation: 103 | self.strides[-1] = self.strides[-1] // 2 104 | 105 | class AMM_Backbone(BackboneBase): 106 | """ResNet backbone with frozen BatchNorm.""" 107 | def __init__(self, name: str, 108 | train_backbone: bool, 109 | return_interm_layers: bool, 110 | dilation: bool): 111 | assert name in ['amm_resnet50', 'amm_resnet101'], 'unsupported backbone type' 112 | net = amm_resnet50 if name == 'amm_resnet50' else amm_resnet101() 113 | backbone = net(replace_stride_with_dilation=[False, False, dilation], 114 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 115 | assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" 116 | 
super().__init__(backbone, train_backbone, return_interm_layers) 117 | if dilation: 118 | self.strides[-1] = self.strides[-1] // 2 119 | 120 | class Joiner(nn.Sequential): 121 | def __init__(self, backbone, position_embedding): 122 | super().__init__(backbone, position_embedding) 123 | self.strides = backbone.strides 124 | self.num_channels = backbone.num_channels 125 | 126 | 127 | def forward(self, tensor_list: NestedTensor): 128 | tensor_list.tensors = rearrange(tensor_list.tensors, 'b t c h w -> (b t) c h w') 129 | tensor_list.mask = rearrange(tensor_list.mask, 'b t h w -> (b t) h w') 130 | 131 | xs = self[0](tensor_list) 132 | out: List[NestedTensor] = [] 133 | pos = [] 134 | for name, x in xs.items(): 135 | out.append(x) 136 | # position encoding 137 | pos.append(self[1](x).to(x.tensors.dtype)) 138 | return out, pos 139 | 140 | class IntermediateLayerGetter(nn.ModuleDict): 141 | """ 142 | Module wrapper that returns intermediate layers from a model 143 | 144 | It has a strong assumption that the modules have been registered 145 | into the model in the same order as they are used. 146 | This means that one should **not** reuse the same nn.Module 147 | twice in the forward if you want this to work. 148 | 149 | Additionally, it is only able to query submodules that are directly 150 | assigned to the model. So if `model` is passed, `model.feature1` can 151 | be returned, but not `model.feature1.layer2`. 152 | 153 | Args: 154 | model (nn.Module): model on which we will extract the features 155 | return_layers (Dict[name, new_name]): a dict containing the names 156 | of the modules for which the activations will be returned as 157 | the key of the dict, and the value of the dict is the name 158 | of the returned activation (which the user can specify). 
159 | """ 160 | _version = 2 161 | __annotations__ = { 162 | "return_layers": Dict[str, str], 163 | } 164 | 165 | def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None: 166 | if not set(return_layers).issubset([name for name, _ in model.named_children()]): 167 | raise ValueError("return_layers are not present in model") 168 | orig_return_layers = return_layers 169 | return_layers = {str(k): str(v) for k, v in return_layers.items()} 170 | layers = OrderedDict() 171 | for name, module in model.named_children(): 172 | layers[name] = module 173 | if name in return_layers: 174 | del return_layers[name] 175 | if not return_layers: 176 | break 177 | 178 | super(IntermediateLayerGetter, self).__init__(layers) 179 | self.return_layers = orig_return_layers 180 | 181 | def forward(self, x): 182 | out = OrderedDict() 183 | for name, module in self.items(): 184 | x = module(x) 185 | if name in self.return_layers: 186 | out_name = self.return_layers[name] 187 | out[out_name] = x 188 | return out 189 | 190 | 191 | def build_backbone(args): 192 | position_embedding = build_position_encoding(args) 193 | train_backbone = args.lr_backbone > 0 194 | return_interm_layers = args.masks or (args.num) 195 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 196 | model = Joiner(backbone, position_embedding) 197 | model.num_channels = backbone.num_channels 198 | return model 199 | 200 | def build_amm_backbone(args): 201 | position_embedding = build_position_encoding(args) 202 | train_backbone = args.lr_backbone > 0 203 | return_interm_layers = args.masks or (args.num) 204 | backbone = AMM_Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 205 | model = Joiner(backbone, position_embedding) 206 | model.num_channels = backbone.num_channels 207 | return model -------------------------------------------------------------------------------- /models/cycle.py: -------------------------------------------------------------------------------- 1 | from segmentation import * 2 | from collections import defaultdict 3 | from typing import List, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | from PIL import Image 10 | import cv2 11 | import numpy as np 12 | from einops import rearrange, repeat 13 | 14 | try: 15 | from panopticapi.utils import id2rgb, rgb2id 16 | except ImportError: 17 | pass 18 | 19 | import fvcore.nn.weight_init as weight_init 20 | 21 | from .position_encoding import PositionEmbeddingSine1D 22 | 23 | BN_MOMENTUM = 0.1 24 | 25 | class S2EPrime(nn.Module): 26 | def __init__(self, feature_channels: List, conv_dim: int, mask_dim: int, dim_feedforward: int = 2048, norm=None, 27 | return_query=False, stage_num=1): 28 | """ 29 | Args: 30 | feature_channels: list of fpn feature channel numbers. 31 | conv_dim: number of output channels for the intermediate conv layers. 32 | mask_dim: number of output channels for the final conv layer. 33 | dim_feedforward: number of vision-language fusion module ffn channel numbers. 
34 | norm (str or callable): normalization for all conv layers 35 | """ 36 | super().__init__() 37 | 38 | self.feature_channels = feature_channels 39 | 40 | lateral_convs = [] 41 | output_convs = [] 42 | 43 | use_bias = norm == "" 44 | for idx, in_channels in enumerate(feature_channels): 45 | # in_channels: 4x -> 32x 46 | lateral_norm = get_norm(norm, conv_dim) 47 | output_norm = get_norm(norm, conv_dim) 48 | 49 | lateral_conv = Conv2d( 50 | in_channels, conv_dim, kernel_size=1, bias=use_bias, norm=lateral_norm 51 | ) 52 | output_conv = Conv2d( 53 | conv_dim, 54 | conv_dim, 55 | kernel_size=3, 56 | stride=1, 57 | padding=1, 58 | bias=use_bias, 59 | norm=output_norm, 60 | activation=F.relu, 61 | ) 62 | weight_init.c2_xavier_fill(lateral_conv) 63 | weight_init.c2_xavier_fill(output_conv) 64 | 65 | stage = idx + 1 66 | 67 | self.add_module("adapter_{}".format(stage), lateral_conv) 68 | self.add_module("layer_{}".format(stage), output_conv) 69 | lateral_convs.append(lateral_conv) 70 | output_convs.append(output_conv) 71 | 72 | # Place convs into top-down order (from low to high resolution) 73 | # to make the top-down computation in forward clearer. 74 | self.lateral_convs = lateral_convs[::-1] 75 | self.output_convs = output_convs[::-1] 76 | 77 | self.mask_dim = mask_dim 78 | self.mask_features = Conv2d( 79 | conv_dim, 80 | mask_dim, 81 | kernel_size=3, 82 | stride=1, 83 | padding=1, 84 | ) 85 | weight_init.c2_xavier_fill(self.mask_features) 86 | 87 | # vision-language cross-modal fusion 88 | self.text_pos = PositionEmbeddingSine1D(conv_dim, normalize=True) 89 | sr_ratios = [8, 4, 2, 1] 90 | 91 | self.stage_num = stage_num 92 | self.dyn_conv_list = [] 93 | self.dyn_filter_list = [] 94 | self.cross_attn_list = [] 95 | self.pool_conv_list = [] 96 | self.out_conv_list = [] 97 | for i in range(self.stage_num): 98 | # init dyn conv 99 | output_norm = get_norm(norm, conv_dim) 100 | self.dyn_conv = Conv2d( 101 | conv_dim, 102 | conv_dim, 103 | kernel_size=1, 104 | stride=1, 105 | padding=0, 106 | bias=use_bias, 107 | norm=output_norm, 108 | activation=F.relu, 109 | ) 110 | self.dyn_filter = nn.Linear(conv_dim, conv_dim, bias=True) 111 | weight_init.c2_xavier_fill(self.dyn_conv) 112 | weight_init.c2_xavier_fill(self.dyn_filter) 113 | 114 | # init text2video attn 115 | self.cross_attn = DualVisionLanguageBlock(conv_dim, dim_feedforward=dim_feedforward, 116 | nhead=8, sr_ratio=sr_ratios[-1]) 117 | for p in self.cross_attn.parameters(): 118 | if p.dim() > 1: 119 | nn.init.xavier_uniform_(p) 120 | text_pool_size = 3 121 | self.pool = nn.AdaptiveAvgPool1d(text_pool_size) 122 | self.pool_conv = Conv2d( 123 | text_pool_size*conv_dim, 124 | conv_dim, 125 | kernel_size=1, 126 | stride=1, 127 | padding=0, 128 | bias=use_bias, 129 | norm=output_norm, 130 | activation=F.relu, 131 | ) 132 | weight_init.c2_xavier_fill(self.pool_conv) 133 | # init out conv for stage != 0 134 | if i != 0: 135 | output_norm = get_norm(norm, conv_dim) 136 | output_conv = Conv2d( 137 | conv_dim, 138 | conv_dim, 139 | kernel_size=3, 140 | stride=1, 141 | padding=1, 142 | bias=use_bias, 143 | norm=output_norm, 144 | activation=F.relu, 145 | ) 146 | weight_init.c2_xavier_fill(output_conv) 147 | self.out_conv_list.append(output_conv) 148 | # add to list 149 | self.dyn_conv_list.append(self.dyn_conv) 150 | self.dyn_filter_list.append(self.dyn_filter) 151 | self.cross_attn_list.append(self.cross_attn) 152 | self.pool_conv_list.append(self.pool_conv) 153 | self.dyn_conv_list = nn.ModuleList(self.dyn_conv_list) 154 | self.dyn_filter_list = 
nn.ModuleList(self.dyn_filter_list) 155 | self.cross_attn_list = nn.ModuleList(self.cross_attn_list) 156 | self.pool_conv_list = nn.ModuleList(self.pool_conv_list) 157 | self.out_conv_list = nn.ModuleList(self.out_conv_list) 158 | print('use recurrent dual path cross attention') 159 | 160 | def forward_features(self, features, text_features, text_sentence_features, poses, memory, nf): 161 | # nf: num_frames 162 | text_pos = self.text_pos(text_features).permute(2, 0, 1) # [length, batch_size, c] 163 | text_features, text_masks = text_features.decompose() 164 | text_features = text_features.permute(1, 0, 2) 165 | 166 | for idx, (mem, f, pos) in enumerate(zip(memory[::-1], features[1:][::-1], poses[1:][::-1])): # 32x -> 8x 167 | lateral_conv = self.lateral_convs[idx] 168 | output_conv = self.output_convs[idx] 169 | 170 | _, x_mask = f.decompose() 171 | n, c, h, w = pos.shape 172 | b = n // nf 173 | t = nf 174 | 175 | # NOTE: here the (h, w) is the size for current fpn layer 176 | vision_features = lateral_conv(mem) # [b*t, c, h, w] 177 | 178 | # upsample 179 | # TODO: only fuse in high-level and repeat 180 | if idx == 0: # top layer 181 | for stage in range(self.stage_num): 182 | if stage == 0: 183 | vision_features = rearrange(vision_features, '(b t) c h w -> (t h w) b c', b=b, t=t) 184 | else: 185 | vision_features = rearrange(y, '(b t) c h w -> (t h w) b c', b=b, t=t) 186 | vision_pos = rearrange(pos, '(b t) c h w -> (t h w) b c', b=b, t=t) 187 | vision_masks = rearrange(x_mask, '(b t) h w -> b (t h w)', b=b, t=t) 188 | cur_fpn, text_features_ = self.cross_attn_list[stage](tgt=vision_features, 189 | memory=text_features, 190 | t=t, h=h, w=w, 191 | tgt_key_padding_mask=vision_masks, 192 | memory_key_padding_mask=text_masks, 193 | pos=text_pos, 194 | query_pos=vision_pos 195 | ) # [t*h*w, b, c] 196 | # text_features [l, b, c] 197 | cur_fpn = rearrange(cur_fpn, '(t h w) b c -> (b t) c h w', t=t, h=h, w=w) 198 | 199 | # # repeat cls_token to compute filter, only apply once 200 | # if text_sentence_features.shape[0] != b * t: 201 | # text_sentence_features = repeat(text_sentence_features, 'b c -> b t c', t=t) 202 | # text_sentence_features = rearrange(text_sentence_features, 'b t c -> (b t) c') 203 | # filter = self.dyn_filter(text_sentence_features).unsqueeze(1) 204 | # TODO: test only no fusion 205 | text_features_ = repeat(text_features_, 'l b c -> l b t c', t=t) 206 | text_features_ = rearrange(text_features_, 'l b t c -> (b t) c l') 207 | text_features_ = rearrange(self.pool(text_features_), '(b t) c l -> (b t) l c', b=b, t=t) 208 | filter = self.dyn_filter_list[stage](text_features_) 209 | y = cur_fpn 210 | y_ = self.dyn_conv_list[stage](cur_fpn) 211 | y_ = torch.einsum('ichw,ijc -> ijchw', y_, filter) 212 | y_ = rearrange(y_, 'i j c h w -> i (j c) h w') 213 | y_ = self.pool_conv_list[stage](y_) 214 | if stage == 0: 215 | y = output_conv(y+y_) 216 | else: 217 | y = self.out_conv_list[stage-1](y+y_) 218 | else: 219 | # Following FPN implementation, we use nearest upsampling here 220 | cur_fpn = vision_features 221 | y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") 222 | y = output_conv(y) 223 | 224 | # 4x level 225 | lateral_conv = self.lateral_convs[-1] 226 | output_conv = self.output_convs[-1] 227 | 228 | x, x_mask = features[0].decompose() 229 | pos = poses[0] 230 | n, c, h, w = pos.shape 231 | b = n // nf 232 | t = nf 233 | 234 | vision_features = lateral_conv(x) # [b*t, c, h, w] 235 | cur_fpn = vision_features 236 | # Following FPN implementation, we use nearest 
upsampling here 237 | y = cur_fpn + F.interpolate(y, size=cur_fpn.shape[-2:], mode="nearest") 238 | y = output_conv(y) 239 | return y 240 | 241 | def forward(self, features, text_features, text_sentence_features, pos, memory, nf): 242 | """The forward function receives the vision and language features, 243 | and outputs the mask features with the spatial stride of 4x. 244 | 245 | Args: 246 | features (list[NestedTensor]): backbone features (vision), length is number of FPN layers 247 | tensors: [b*t, ci, hi, wi], mask: [b*t, hi, wi] 248 | text_features (NestedTensor): text features (language) 249 | tensors: [b, length, c], mask: [b, length] 250 | pos (list[Tensor]): position encoding of vision features, length is number of FPN layers 251 | tensors: [b*t, c, hi, wi] 252 | memory (list[Tensor]): features from encoder output. from 8x -> 32x 253 | NOTE: the layer orders of both features and pos are res2 -> res5 254 | 255 | Returns: 256 | mask_features (Tensor): [b*t, mask_dim, h, w], with the spatial stride of 4x. 257 | """ 258 | y = self.forward_features(features, text_features, text_sentence_features, pos, memory, nf) 259 | return self.mask_features(y) 260 | 261 | 262 | # class EPrime2SPrime(nn.Module): 263 | # 264 | # 265 | # class Cycle(nn.Module): 266 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instance Sequence Matching 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou, multi_iou 11 | from util.misc import nested_tensor_from_tensor_list 12 | 13 | INF = 100000000 14 | 15 | def dice_coef(inputs, targets): 16 | inputs = inputs.sigmoid() 17 | inputs = inputs.flatten(1).unsqueeze(1) # [N, 1, THW] 18 | targets = targets.flatten(1).unsqueeze(0) # [1, M, THW] 19 | numerator = 2 * (inputs * targets).sum(2) 20 | denominator = inputs.sum(-1) + targets.sum(-1) 21 | 22 | # NOTE coef doesn't be subtracted to 1 as it is not necessary for computing costs 23 | coef = (numerator + 1) / (denominator + 1) 24 | return coef 25 | 26 | def sigmoid_focal_coef(inputs, targets, alpha: float = 0.25, gamma: float = 2): 27 | N, M = len(inputs), len(targets) 28 | inputs = inputs.flatten(1).unsqueeze(1).expand(-1, M, -1) # [N, M, THW] 29 | targets = targets.flatten(1).unsqueeze(0).expand(N, -1, -1) # [N, M, THW] 30 | 31 | prob = inputs.sigmoid() 32 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 33 | p_t = prob * targets + (1 - prob) * (1 - targets) 34 | coef = ce_loss * ((1 - p_t) ** gamma) 35 | 36 | if alpha >= 0: 37 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 38 | coef = alpha_t * coef 39 | 40 | return coef.mean(2) # [N, M] 41 | 42 | 43 | class HungarianMatcher(nn.Module): 44 | """This class computes an assignment between the targets and the predictions of the network 45 | 46 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 47 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 48 | while the others are un-matched (and thus treated as non-objects). 
49 | """ 50 | 51 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, 52 | cost_mask: float = 1, cost_dice: float = 1, num_classes: int = 1): 53 | """Creates the matcher 54 | 55 | Params: 56 | cost_class: This is the relative weight of the classification error in the matching cost 57 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 58 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 59 | cost_mask: This is the relative weight of the sigmoid focal loss of the mask in the matching cost 60 | cost_dice: This is the relative weight of the dice loss of the mask in the matching cost 61 | """ 62 | super().__init__() 63 | self.cost_class = cost_class 64 | self.cost_bbox = cost_bbox 65 | self.cost_giou = cost_giou 66 | self.cost_mask = cost_mask 67 | self.cost_dice = cost_dice 68 | self.num_classes = num_classes 69 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0 \ 70 | or cost_mask != 0 or cost_dice != 0, "all costs cant be 0" 71 | self.mask_out_stride = 4 72 | 73 | @torch.no_grad() 74 | def forward(self, outputs, targets): 75 | """ Performs the matching 76 | Params: 77 | outputs: This is a dict that contains at least these entries: 78 | "pred_logits": Tensor of dim [batch_size, num_queries_per_frame, num_frames, num_classes] with the classification logits 79 | "pred_boxes": Tensor of dim [batch_size, num_queries_per_frame, num_frames, 4] with the predicted box coordinates 80 | "pred_masks": Tensor of dim [batch_size, num_queries_per_frame, num_frames, h, w], h,w in 4x size 81 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 82 | NOTE: Since every frame has one object at most 83 | "labels": Tensor of dim [num_frames] (where num_target_boxes is the number of ground-truth 84 | objects in the target) containing the class labels 85 | "boxes": Tensor of dim [num_frames, 4] containing the target box coordinates 86 | "masks": Tensor of dim [num_frames, h, w], h,w in origin size 87 | Returns: 88 | A list of size batch_size, containing tuples of (index_i, index_j) where: 89 | - index_i is the indices of the selected predictions (in order) 90 | - index_j is the indices of the corresponding selected targets (in order) 91 | For each batch element, it holds: 92 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 93 | """ 94 | src_logits = outputs["pred_logits"] 95 | src_boxes = outputs["pred_boxes"] 96 | src_masks = outputs["pred_masks"] 97 | 98 | bs, nf, nq, h, w = src_masks.shape 99 | 100 | # handle mask padding issue 101 | target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets], 102 | size_divisibility=32, 103 | split=False).decompose() 104 | target_masks = target_masks.to(src_masks) # [B, T, H, W] 105 | 106 | # downsample ground truth masks with ratio mask_out_stride 107 | start = int(self.mask_out_stride // 2) 108 | im_h, im_w = target_masks.shape[-2:] 109 | 110 | target_masks = target_masks[:, :, start::self.mask_out_stride, start::self.mask_out_stride] 111 | assert target_masks.size(2) * self.mask_out_stride == im_h 112 | assert target_masks.size(3) * self.mask_out_stride == im_w 113 | 114 | indices = [] 115 | for i in range(bs): 116 | out_prob = src_logits[i].sigmoid() 117 | out_bbox = src_boxes[i] 118 | out_mask = src_masks[i] 119 | 120 | tgt_ids = targets[i]["labels"] 121 | tgt_bbox = targets[i]["boxes"] 122 | tgt_mask = target_masks[i] 123 | 
tgt_valid = targets[i]["valid"] 124 | 125 | # class cost 126 | # we average the cost on valid frames 127 | cost_class = [] 128 | for t in range(nf): 129 | if tgt_valid[t] == 0: 130 | continue 131 | 132 | out_prob_split = out_prob[t] 133 | tgt_ids_split = tgt_ids[t].unsqueeze(0) 134 | 135 | # Compute the classification cost. 136 | alpha = 0.25 137 | gamma = 2.0 138 | neg_cost_class = (1 - alpha) * (out_prob_split ** gamma) * (-(1 - out_prob_split + 1e-8).log()) 139 | pos_cost_class = alpha * ((1 - out_prob_split) ** gamma) * (-(out_prob_split + 1e-8).log()) 140 | if self.num_classes == 1: # binary referred 141 | cost_class_split = pos_cost_class[:, [0]] - neg_cost_class[:, [0]] 142 | else: 143 | cost_class_split = pos_cost_class[:, tgt_ids_split] - neg_cost_class[:, tgt_ids_split] 144 | 145 | cost_class.append(cost_class_split) 146 | cost_class = torch.stack(cost_class, dim=0).mean(0) # [q, 1] 147 | 148 | # box cost 149 | # we average the cost on every frame 150 | cost_bbox, cost_giou = [], [] 151 | for t in range(nf): 152 | out_bbox_split = out_bbox[t] 153 | tgt_bbox_split = tgt_bbox[t].unsqueeze(0) 154 | 155 | # Compute the L1 cost between boxes 156 | cost_bbox_split = torch.cdist(out_bbox_split, tgt_bbox_split, p=1) 157 | 158 | # Compute the giou cost between boxes 159 | cost_giou_split = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox_split), 160 | box_cxcywh_to_xyxy(tgt_bbox_split)) 161 | 162 | cost_bbox.append(cost_bbox_split) 163 | cost_giou.append(cost_giou_split) 164 | cost_bbox = torch.stack(cost_bbox, dim=0).mean(0) 165 | cost_giou = torch.stack(cost_giou, dim=0).mean(0) 166 | 167 | # mask cost 168 | # Compute the focal loss between masks 169 | cost_mask = sigmoid_focal_coef(out_mask.transpose(0, 1), tgt_mask.unsqueeze(0)) 170 | 171 | # Compute the dice loss between masks 172 | cost_dice = -dice_coef(out_mask.transpose(0, 1), tgt_mask.unsqueeze(0)) 173 | 174 | # Final cost matrix 175 | C = self.cost_class * cost_class + self.cost_bbox * cost_bbox + self.cost_giou * cost_giou + \ 176 | self.cost_mask * cost_mask + self.cost_dice * cost_dice # [q, 1] 177 | 178 | # Only has one tgt, MinCost Matcher 179 | _, src_ind = torch.min(C, dim=0) 180 | tgt_ind = torch.arange(1).to(src_ind) 181 | indices.append((src_ind.long(), tgt_ind.long())) 182 | 183 | # list[tuple], length is batch_size 184 | return indices 185 | 186 | 187 | def build_matcher(args): 188 | if args.binary: 189 | num_classes = 1 190 | else: 191 | if args.dataset_file == 'ytvos': 192 | num_classes = 65 193 | elif args.dataset_file == 'davis': 194 | num_classes = 78 195 | elif args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb': 196 | num_classes = 1 197 | else: 198 | num_classes = 91 # for coco 199 | return HungarianMatcher(cost_class=args.set_cost_class, 200 | cost_bbox=args.set_cost_bbox, 201 | cost_giou=args.set_cost_giou, 202 | cost_mask=args.set_cost_mask, 203 | cost_dice=args.set_cost_dice, 204 | num_classes=num_classes) 205 | 206 | 207 | -------------------------------------------------------------------------------- /models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | 
mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | -------------------------------------------------------------------------------- /models/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # Modify for sample points visualization 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import print_function 12 | from __future__ import division 13 | 14 | import warnings 15 | import math 16 | 17 | import torch 18 | from torch import nn 19 | import torch.nn.functional as F 20 | from torch.nn.init import xavier_uniform_, constant_ 21 | 22 | from ..functions import MSDeformAttnFunction 23 | 24 | 25 | def _is_power_of_2(n): 26 | if (not isinstance(n, int)) or (n < 0): 27 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 28 | return (n & (n-1) == 0) and n != 0 29 | 30 | 31 | class MSDeformAttn(nn.Module): 32 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 33 | """ 34 | Multi-Scale Deformable Attention Module 35 | :param d_model hidden dimension 36 | :param n_levels number of feature levels 37 | :param n_heads number of attention heads 38 | :param n_points number of sampling points per attention head per feature level 39 | """ 40 | super().__init__() 41 | if d_model % n_heads != 0: 42 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 43 | _d_per_head = d_model // n_heads 44 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 45 | if not _is_power_of_2(_d_per_head): 46 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 47 | "which is more efficient in our CUDA implementation.") 48 | 49 | self.im2col_step = 64 50 | 51 | self.d_model = d_model 52 | self.n_levels = n_levels 53 | self.n_heads = n_heads 54 | self.n_points = n_points 55 | 56 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 57 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 58 | self.value_proj = nn.Linear(d_model, d_model) 59 | self.output_proj = nn.Linear(d_model, d_model) 60 | 61 | self._reset_parameters() 62 | 63 | def _reset_parameters(self): 64 | constant_(self.sampling_offsets.weight.data, 0.) 65 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 66 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 67 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 68 | for i in range(self.n_points): 69 | grid_init[:, :, i, :] *= i + 1 70 | with torch.no_grad(): 71 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 72 | constant_(self.attention_weights.weight.data, 0.) 73 | constant_(self.attention_weights.bias.data, 0.) 74 | xavier_uniform_(self.value_proj.weight.data) 75 | constant_(self.value_proj.bias.data, 0.) 76 | xavier_uniform_(self.output_proj.weight.data) 77 | constant_(self.output_proj.bias.data, 0.) 
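# Editorial note (added for clarity, not part of the upstream Deformable-DETR code): because
# sampling_offsets.weight is zero-initialised above, the initial offsets are query-independent and equal
# to the bias, i.e. for head h and point k the sampling point starts at roughly
#   (k + 1) * (cos(2*pi*h/n_heads), sin(2*pi*h/n_heads)),
# normalised so the larger coordinate is +/-1. Each head therefore begins by sampling along its own
# direction at increasing radius, and the offsets only become input-dependent as the weight is learned.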
78 | 79 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 80 | """ 81 | :param query (N, Length_{query}, C) 82 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 83 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 84 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 85 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 86 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 87 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 88 | 89 | :return output (N, Length_{query}, C) 90 | """ 91 | N, Len_q, _ = query.shape 92 | N, Len_in, _ = input_flatten.shape 93 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 94 | 95 | value = self.value_proj(input_flatten) 96 | if input_padding_mask is not None: 97 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 98 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 99 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 100 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 101 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 102 | # N, Len_q, n_heads, n_levels, n_points, 2 103 | if reference_points.shape[-1] == 2: 104 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 105 | sampling_locations = reference_points[:, :, None, :, None, :] \ 106 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 107 | elif reference_points.shape[-1] == 4: 108 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 109 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 110 | else: 111 | raise ValueError( 112 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 113 | output = MSDeformAttnFunction.apply( 114 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 115 | output = self.output_proj(output) 116 | 117 | return output, sampling_locations, attention_weights 118 | -------------------------------------------------------------------------------- /models/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=1", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | 
const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /models/ops/src/vision.cpp: 
-------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = torch.rand(N, Lq, 
M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various positional encodings for the transformer. 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from util.misc import NestedTensor 10 | 11 | # dimension == 1 12 | class PositionEmbeddingSine1D(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 
16 | """ 17 | def __init__(self, num_pos_feats=256, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, tensor_list: NestedTensor): 29 | x = tensor_list.tensors # [B, C, T] 30 | mask = tensor_list.mask # [B, T] 31 | assert mask is not None 32 | not_mask = ~mask 33 | x_embed = not_mask.cumsum(1, dtype=torch.float32) # [B, T] 34 | if self.normalize: 35 | eps = 1e-6 36 | x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, None] / dim_t # [B, T, C] 42 | # n,c,t 43 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 44 | pos = pos_x.permute(0, 2, 1) # [B, C, T] 45 | return pos 46 | 47 | # dimension == 2 48 | class PositionEmbeddingSine2D(nn.Module): 49 | """ 50 | This is a more standard version of the position embedding, very similar to the one 51 | used by the Attention is all you need paper, generalized to work on images. 52 | """ 53 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 54 | super().__init__() 55 | self.num_pos_feats = num_pos_feats 56 | self.temperature = temperature 57 | self.normalize = normalize 58 | if scale is not None and normalize is False: 59 | raise ValueError("normalize should be True if scale is passed") 60 | if scale is None: 61 | scale = 2 * math.pi 62 | self.scale = scale 63 | 64 | def forward(self, tensor_list: NestedTensor): 65 | x = tensor_list.tensors # [B, C, H, W] 66 | mask = tensor_list.mask # [B, H, W] 67 | assert mask is not None 68 | not_mask = ~mask 69 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 70 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 71 | if self.normalize: 72 | eps = 1e-6 73 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 74 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 75 | 76 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 77 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 78 | 79 | pos_x = x_embed[:, :, :, None] / dim_t 80 | pos_y = y_embed[:, :, :, None] / dim_t 81 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 82 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 83 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 84 | return pos # [B, C, H, W] 85 | 86 | 87 | # dimension == 3 88 | class PositionEmbeddingSine3D(nn.Module): 89 | """ 90 | This is a more standard version of the position embedding, very similar to the one 91 | used by the Attention is all you need paper, generalized to work on images. 
92 | """ 93 | def __init__(self, num_pos_feats=64, num_frames=36, temperature=10000, normalize=False, scale=None): 94 | super().__init__() 95 | self.num_pos_feats = num_pos_feats 96 | self.temperature = temperature 97 | self.normalize = normalize 98 | self.frames = num_frames 99 | if scale is not None and normalize is False: 100 | raise ValueError("normalize should be True if scale is passed") 101 | if scale is None: 102 | scale = 2 * math.pi 103 | self.scale = scale 104 | 105 | def forward(self, tensor_list: NestedTensor): 106 | x = tensor_list.tensors # [B*T, C, H, W] 107 | mask = tensor_list.mask # [B*T, H, W] 108 | n,h,w = mask.shape 109 | mask = mask.reshape(n//self.frames, self.frames,h,w) # [B, T, H, W] 110 | assert mask is not None 111 | not_mask = ~mask 112 | z_embed = not_mask.cumsum(1, dtype=torch.float32) # [B, T, H, W] 113 | y_embed = not_mask.cumsum(2, dtype=torch.float32) # [B, T, H, W] 114 | x_embed = not_mask.cumsum(3, dtype=torch.float32) # [B, T, H, W] 115 | if self.normalize: 116 | eps = 1e-6 117 | z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale 118 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale 119 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale 120 | 121 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) # 122 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 123 | 124 | pos_x = x_embed[:, :, :, :, None] / dim_t # [B, T, H, W, c] 125 | pos_y = y_embed[:, :, :, :, None] / dim_t 126 | pos_z = z_embed[:, :, :, :, None] / dim_t 127 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) # [B, T, H, W, c] 128 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 129 | pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 130 | pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3) # [B, T, C, H, W] 131 | return pos 132 | 133 | 134 | 135 | def build_position_encoding(args): 136 | # build 2D position encoding 137 | N_steps = args.hidden_dim // 2 # 256 / 2 = 128 138 | if args.position_embedding in ('v2', 'sine'): 139 | # TODO find a better way of exposing other arguments 140 | position_embedding = PositionEmbeddingSine2D(N_steps, normalize=True) 141 | else: 142 | raise ValueError(f"not supported {args.position_embedding}") 143 | 144 | return position_embedding 145 | 146 | -------------------------------------------------------------------------------- /models/postprocessors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. 
All Rights Reserved 2 | """Postprocessors class to transform MDETR output according to the downstream task""" 3 | from typing import Dict 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | import pycocotools.mask as mask_util 10 | 11 | from util import box_ops 12 | 13 | 14 | class A2DSentencesPostProcess(nn.Module): 15 | """ 16 | This module converts the model's output into the format expected by the coco api for the given task 17 | """ 18 | def __init__(self, threshold=0.5): 19 | super().__init__() 20 | self.threshold = threshold 21 | 22 | @torch.no_grad() 23 | def forward(self, outputs, orig_target_sizes, max_target_sizes): 24 | """ Perform the computation 25 | Parameters: 26 | outputs: raw outputs of the model 27 | orig_target_sizes: original size of the samples (no augmentations or padding) 28 | max_target_sizes: size of samples (input to model) after size augmentation. 29 | NOTE: the max_padding_size is 4x out_masks.shape[-2:] 30 | """ 31 | assert len(orig_target_sizes) == len(max_target_sizes) 32 | 33 | # there is only one valid frames, thus T=1 34 | out_logits = outputs['pred_logits'][:, 0, :, 0] # [B, T, N, 1] -> [B, N] 35 | out_masks = outputs['pred_masks'][:, 0, :, :, :] # [B, T, N, out_h, out_w] -> [B, N, out_h, out_w] 36 | out_h, out_w = out_masks.shape[-2:] 37 | 38 | scores = out_logits.sigmoid() 39 | pred_masks = F.interpolate(out_masks, size=(out_h*4, out_w*4), mode="bilinear", align_corners=False) # [B, N, H, W] 40 | pred_masks = (pred_masks.sigmoid() > 0.5) # [B, N, H, W] 41 | processed_pred_masks, rle_masks = [], [] 42 | # for each batch 43 | for f_pred_masks, resized_size, orig_size in zip(pred_masks, max_target_sizes, orig_target_sizes): 44 | f_mask_h, f_mask_w = resized_size # resized shape without padding 45 | f_pred_masks_no_pad = f_pred_masks[:, :f_mask_h, :f_mask_w].unsqueeze(1) # remove the samples' padding 46 | # resize the samples back to their original dataset (target) size for evaluation 47 | f_pred_masks_processed = F.interpolate(f_pred_masks_no_pad.float(), size=tuple(orig_size.tolist()), mode="nearest") 48 | f_pred_rle_masks = [mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 49 | for mask in f_pred_masks_processed.cpu()] 50 | processed_pred_masks.append(f_pred_masks_processed) 51 | rle_masks.append(f_pred_rle_masks) 52 | predictions = [{'scores': s, 'masks': m, 'rle_masks': rle} 53 | for s, m, rle in zip(scores, processed_pred_masks, rle_masks)] 54 | return predictions 55 | 56 | 57 | # PostProcess for pretraining 58 | class PostProcess(nn.Module): 59 | """ This module converts the model's output into the format expected by the coco api""" 60 | 61 | @torch.no_grad() 62 | def forward(self, outputs, target_sizes): 63 | """Perform the computation 64 | Parameters: 65 | outputs: raw outputs of the model 66 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 67 | For evaluation, this must be the original image size (before any data augmentation) 68 | For visualization, this should be the image size after data augment, but before padding 69 | Returns: 70 | 71 | """ 72 | out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] 73 | 74 | assert len(out_logits) == len(target_sizes) 75 | assert target_sizes.shape[1] == 2 76 | 77 | # coco, num_frames=1 78 | out_logits = outputs["pred_logits"].flatten(1, 2) 79 | out_boxes = outputs["pred_boxes"].flatten(1, 2) 80 | bs, num_queries = out_logits.shape[:2] 81 | 82 | prob = 
out_logits.sigmoid() # [bs, num_queries, num_classes] 83 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True) 84 | scores = topk_values # [bs, num_queries] 85 | topk_boxes = topk_indexes // out_logits.shape[2] # [bs, num_queries] 86 | labels = topk_indexes % out_logits.shape[2] # [bs, num_queries] 87 | 88 | boxes = box_ops.box_cxcywh_to_xyxy(out_boxes) # [bs, num_queries, 4] 89 | boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) 90 | 91 | # and from relative [0, 1] to absolute [0, height] coordinates 92 | img_h, img_w = target_sizes.unbind(1) 93 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 94 | boxes = boxes * scale_fct[:, None, :] # [bs, num_queries, 4] 95 | 96 | assert len(scores) == len(labels) == len(boxes) 97 | results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] 98 | 99 | return results 100 | 101 | 102 | class PostProcessSegm(nn.Module): 103 | """Similar to PostProcess but for segmentation masks. 104 | This processor is to be called sequentially after PostProcess. 105 | Args: 106 | threshold: threshold that will be applied to binarize the segmentation masks. 107 | """ 108 | 109 | def __init__(self, threshold=0.5): 110 | super().__init__() 111 | self.threshold = threshold 112 | 113 | @torch.no_grad() 114 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes): 115 | """Perform the computation 116 | Parameters: 117 | results: already pre-processed boxes (output of PostProcess) NOTE here 118 | outputs: raw outputs of the model 119 | orig_target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 120 | For evaluation, this must be the original image size (before any data augmentation) 121 | For visualization, this should be the image size after data augment, but before padding 122 | max_target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 123 | after data augmentation. 
124 | """ 125 | assert len(orig_target_sizes) == len(max_target_sizes) 126 | 127 | out_logits = outputs["pred_logits"].flatten(1, 2) 128 | out_masks = outputs["pred_masks"].flatten(1, 2) 129 | bs, num_queries = out_logits.shape[:2] 130 | 131 | prob = out_logits.sigmoid() 132 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True) 133 | scores = topk_values # [bs, num_queries] 134 | topk_boxes = topk_indexes // out_logits.shape[2] # [bs, num_queries] 135 | labels = topk_indexes % out_logits.shape[2] # [bs, num_queries] 136 | 137 | outputs_masks = [out_m[topk_boxes[i]].unsqueeze(0) for i, out_m, in enumerate(out_masks)] # list[Tensor] 138 | outputs_masks = torch.cat(outputs_masks, dim=0) # [bs, num_queries, H, W] 139 | out_h, out_w = outputs_masks.shape[-2:] 140 | 141 | # max_h, max_w = max_target_sizes.max(0)[0].tolist() 142 | # outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) 143 | outputs_masks = F.interpolate(outputs_masks, size=(out_h*4, out_w*4), mode="bilinear", align_corners=False) 144 | outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() 145 | 146 | for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): 147 | img_h, img_w = t[0], t[1] 148 | results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) 149 | results[i]["masks"] = F.interpolate( 150 | results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" 151 | ).byte() 152 | 153 | return results 154 | 155 | 156 | 157 | def build_postprocessors(args, dataset_name): 158 | if dataset_name == 'a2d' or dataset_name == 'jhmdb': 159 | postprocessors = A2DSentencesPostProcess(threshold=args.threshold) 160 | else: 161 | # for coco pretrain postprocessor 162 | postprocessors: Dict[str, nn.Module] = {"bbox": PostProcess()} 163 | if args.masks: 164 | postprocessors["segm"] = PostProcessSegm(threshold=args.threshold) 165 | # postprocessors = PostProcessSegm(threshold=args.threshold) 166 | return postprocessors 167 | -------------------------------------------------------------------------------- /models/vector_quantitizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class VectorQuantizerEMA(nn.Module): 6 | def __init__(self, num_embeddings=512, embedding_dim=256, commitment_cost=0.25, decay=0.99, epsilon=1e-5): 7 | super(VectorQuantizerEMA, self).__init__() 8 | 9 | self._embedding_dim = embedding_dim 10 | self._num_embeddings = num_embeddings 11 | 12 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) 13 | self._embedding.weight.data.normal_() 14 | self._commitment_cost = commitment_cost 15 | 16 | self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings)) 17 | self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim)) 18 | self._ema_w.data.normal_() 19 | 20 | self._decay = decay 21 | self._epsilon = epsilon 22 | 23 | def forward(self, inputs): 24 | # inputs btqc 25 | input_shape = inputs.shape 26 | 27 | # Flatten input 28 | flat_input = inputs.view(-1, self._embedding_dim) 29 | 30 | # Calculate distances 31 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) 32 | + torch.sum(self._embedding.weight ** 2, dim=1) 33 | - 2 * torch.matmul(flat_input, self._embedding.weight.t())) 34 | 35 | # Encoding 36 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 37 | encodings = 
torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) 38 | encodings.scatter_(1, encoding_indices, 1) 39 | 40 | # Quantize and unflatten 41 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) 42 | 43 | # Use EMA to update the embedding vectors 44 | if self.training: 45 | self._ema_cluster_size = self._ema_cluster_size * self._decay + \ 46 | (1 - self._decay) * torch.sum(encodings, 0) 47 | 48 | # Laplace smoothing of the cluster size 49 | n = torch.sum(self._ema_cluster_size.data) 50 | self._ema_cluster_size = ( 51 | (self._ema_cluster_size + self._epsilon) 52 | / (n + self._num_embeddings * self._epsilon) * n) 53 | 54 | dw = torch.matmul(encodings.t(), flat_input) 55 | self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw) 56 | 57 | self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1)) 58 | 59 | # Loss 60 | e_latent_loss = F.mse_loss(quantized.detach(), inputs) 61 | loss = self._commitment_cost * e_latent_loss 62 | 63 | # Straight Through Estimator 64 | quantized = inputs + (quantized - inputs).detach() 65 | avg_probs = torch.mean(encodings, dim=0) 66 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 67 | 68 | return loss, quantized, perplexity, encodings 69 | 70 | 71 | class VectorQuantizer(nn.Module): 72 | def __init__(self, num_embeddings=512, embedding_dim=256, commitment_cost=0.25): 73 | super(VectorQuantizer, self).__init__() 74 | 75 | self._embedding_dim = embedding_dim 76 | self._num_embeddings = num_embeddings 77 | 78 | self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim) 79 | self._embedding.weight.data.normal_() 80 | self._commitment_cost = commitment_cost 81 | 82 | def forward(self, inputs): 83 | input_shape = inputs.shape 84 | 85 | # Flatten input 86 | flat_input = inputs.view(-1, self._embedding_dim) 87 | 88 | # Calculate distances 89 | distances = (torch.sum(flat_input ** 2, dim=1, keepdim=True) 90 | + torch.sum(self._embedding.weight ** 2, dim=1) 91 | - 2 * torch.matmul(flat_input, self._embedding.weight.t())) 92 | 93 | # Encoding 94 | encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1) 95 | encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device) 96 | encodings.scatter_(1, encoding_indices, 1) 97 | 98 | # Quantize and unflatten 99 | quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape) 100 | 101 | # Loss 102 | e_latent_loss = F.mse_loss(quantized.detach(), inputs) 103 | q_latent_loss = F.mse_loss(quantized, inputs.detach()) 104 | loss = q_latent_loss + self._commitment_cost * e_latent_loss 105 | 106 | quantized = inputs + (quantized - inputs).detach() 107 | avg_probs = torch.mean(encodings, dim=0) 108 | perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) 109 | 110 | return loss, quantized, perplexity, encodings -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args_parser(): 4 | parser = argparse.ArgumentParser('ReferFormer training and inference scripts.', add_help=False) 5 | parser.add_argument('--lr', default=1e-4, type=float) 6 | parser.add_argument('--lr_backbone', default=5e-5, type=float) 7 | parser.add_argument('--lr_backbone_names', default=['backbone.0'], type=str, nargs='+') 8 | parser.add_argument('--lr_text_encoder', default=1e-5, type=float) 9 | 
parser.add_argument('--lr_text_encoder_names', default=['text_encoder'], type=str, nargs='+') 10 | parser.add_argument('--lr_linear_proj_names', default=['reference_points', 'sampling_offsets'], type=str, nargs='+') 11 | parser.add_argument('--lr_linear_proj_mult', default=1.0, type=float) 12 | parser.add_argument('--lr_multi', default=1.0, type=float) 13 | parser.add_argument('--batch_size', default=1, type=int) 14 | parser.add_argument('--weight_decay', default=5e-4, type=float) 15 | parser.add_argument('--epochs', default=12, type=int) 16 | parser.add_argument('--lr_drop', default=[8, 10], type=int, nargs='+') 17 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 18 | help='gradient clipping max norm') 19 | 20 | # Model parameters 21 | # load the pretrained weights 22 | parser.add_argument('--pretrained_weights', type=str, default=None, 23 | help="Path to the pretrained model.") 24 | 25 | # Variants of Deformable DETR 26 | parser.add_argument('--with_box_refine', default=False, action='store_true') 27 | parser.add_argument('--two_stage', default=False, action='store_true') # NOTE: must be false 28 | 29 | # * Backbone 30 | # ["resnet50", "resnet101", "swin_t_p4w7", "swin_s_p4w7", "swin_b_p4w7", "swin_l_p4w7"] 31 | # ["video_swin_t_p4w7", "video_swin_s_p4w7", "video_swin_b_p4w7"] 32 | parser.add_argument('--backbone', default='resnet50', type=str, 33 | help="Name of the convolutional backbone to use") 34 | parser.add_argument('--backbone_pretrained', default=None, type=str, 35 | help="if use swin backbone and train from scratch, the path to the pretrained weights") 36 | parser.add_argument('--use_checkpoint', action='store_true', help='whether use checkpoint for swin/video swin backbone') 37 | parser.add_argument('--dilation', action='store_true', # DC5 38 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 39 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 40 | help="Type of positional embedding to use on top of the image features") 41 | parser.add_argument('--num_feature_levels', default=4, type=int, help='number of feature levels') 42 | 43 | # * Transformer 44 | parser.add_argument('--enc_layers', default=4, type=int, 45 | help="Number of encoding layers in the transformer") 46 | parser.add_argument('--dec_layers', default=4, type=int, 47 | help="Number of decoding layers in the transformer") 48 | parser.add_argument('--dim_feedforward', default=2048, type=int, 49 | help="Intermediate size of the feedforward layers in the transformer blocks") 50 | parser.add_argument('--hidden_dim', default=256, type=int, 51 | help="Size of the embeddings (dimension of the transformer)") 52 | parser.add_argument('--dropout', default=0.1, type=float, 53 | help="Dropout applied in the transformer") 54 | parser.add_argument('--nheads', default=8, type=int, 55 | help="Number of attention heads inside the transformer's attentions") 56 | parser.add_argument('--num_frames', default=5, type=int, 57 | help="Number of clip frames for training") 58 | parser.add_argument('--num_queries', default=5, type=int, 59 | help="Number of query slots, all frames share the same queries") 60 | parser.add_argument('--dec_n_points', default=4, type=int) 61 | parser.add_argument('--enc_n_points', default=4, type=int) 62 | parser.add_argument('--pre_norm', action='store_true') 63 | # for text 64 | parser.add_argument('--freeze_text_encoder', action='store_true') # default: False 65 | 66 | # * Segmentation 67 | 
parser.add_argument('--masks', action='store_true', 68 | help="Train segmentation head if the flag is provided") 69 | parser.add_argument('--mask_dim', default=256, type=int, 70 | help="Size of the mask embeddings (dimension of the dynamic mask conv)") 71 | parser.add_argument('--controller_layers', default=3, type=int, 72 | help="Dynamic conv layer number") 73 | parser.add_argument('--dynamic_mask_channels', default=8, type=int, 74 | help="Dynamic conv final channel number") 75 | parser.add_argument('--no_rel_coord', dest='rel_coord', action='store_false', 76 | help="Disables relative coordinates") 77 | 78 | # Loss 79 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 80 | help="Disables auxiliary decoding losses (loss at each layer)") 81 | # * Matcher 82 | parser.add_argument('--set_cost_class', default=2, type=float, 83 | help="Class coefficient in the matching cost") 84 | parser.add_argument('--set_cost_bbox', default=5, type=float, 85 | help="L1 box coefficient in the matching cost") 86 | parser.add_argument('--set_cost_giou', default=2, type=float, 87 | help="giou box coefficient in the matching cost") 88 | parser.add_argument('--set_cost_mask', default=2, type=float, 89 | help="mask coefficient in the matching cost") 90 | parser.add_argument('--set_cost_dice', default=5, type=float, 91 | help="mask coefficient in the matching cost") 92 | # * Loss coefficients 93 | parser.add_argument('--mask_loss_coef', default=2, type=float) 94 | parser.add_argument('--dice_loss_coef', default=5, type=float) 95 | parser.add_argument('--cls_loss_coef', default=2, type=float) 96 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 97 | parser.add_argument('--giou_loss_coef', default=2, type=float) 98 | parser.add_argument('--eos_coef', default=0.1, type=float, 99 | help="Relative classification weight of the no-object class") 100 | parser.add_argument('--focal_alpha', default=0.25, type=float) 101 | 102 | # dataset parameters 103 | # ['ytvos', 'davis', 'a2d', 'jhmdb', 'refcoco', 'refcoco+', 'refcocog', 'all'] 104 | # 'all': using the three ref datasets for pretraining 105 | parser.add_argument('--dataset_file', default='ytvos', help='Dataset name') 106 | parser.add_argument('--coco_path', type=str, default='data/coco') 107 | parser.add_argument('--ytvos_path', type=str, default='data/ref-youtube-vos') 108 | parser.add_argument('--davis_path', type=str, default='data/ref-davis') 109 | parser.add_argument('--a2d_path', type=str, default='data/a2d_sentences') 110 | parser.add_argument('--jhmdb_path', type=str, default='/mnt/data/jhmdb') 111 | parser.add_argument('--max_skip', default=3, type=int, help="max skip frame number") 112 | parser.add_argument('--max_size', default=640, type=int, help="max size for the frame") 113 | parser.add_argument('--binary', action='store_true') 114 | parser.add_argument('--remove_difficult', action='store_true') 115 | 116 | parser.add_argument('--output_dir', default='output', 117 | help='path where to save, empty for no saving') 118 | parser.add_argument('--device', default='cuda', 119 | help='device to use for training / testing') 120 | parser.add_argument('--seed', default=42, type=int) 121 | parser.add_argument('--resume', default='', help='resume from checkpoint') 122 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 123 | help='start epoch') 124 | parser.add_argument('--eval', action='store_true') 125 | parser.add_argument('--num_workers', default=4, type=int) 126 | 127 | # test setting 128 | 
parser.add_argument('--threshold', default=0.5, type=float) # binary threshold for mask 129 | parser.add_argument('--ngpu', default=8, type=int, help='number of GPUs for inference on Ref-YTVOS and Ref-DAVIS') 130 | parser.add_argument('--split', default='valid', type=str, choices=['valid', 'test']) 131 | parser.add_argument('--visualize', action='store_true', help='whether to visualize the masks during inference') 132 | 133 | # distributed training parameters 134 | parser.add_argument('--world_size', default=1, type=int, 135 | help='number of distributed processes') 136 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 137 | parser.add_argument('--cache_mode', default=False, action='store_true', help='whether to cache images in memory') 138 | 139 | # additional parameters 140 | parser.add_argument('--fpn_type', default='dual', help='FPN type: dual, dyn, or default') 141 | parser.add_argument('--val_type', default='', help='validation dataset type for Ref-YouTube-VOS') 142 | parser.add_argument('--as_vos', default=False, help='run as semi-supervised VOS') 143 | parser.add_argument('--query_feat_dim', default=2048, help='feature dim of the 1/32 visual feature map') 144 | parser.add_argument('--inf_res', default=360, type=int, help='inference size') 145 | parser.add_argument('--text_enc_type', default='distilroberta-base', help='text encoder type, e.g. distilroberta-base') 146 | parser.add_argument('--use_cycle', action='store_true', help='use cycle consistency') 147 | parser.add_argument('--add_negative', action='store_true', help='add a negative sample on GPU 0 for the triplet loss') 148 | parser.add_argument('--only_cycle', action='store_true', help='only train the cycle consistency part of the model') 149 | parser.add_argument('--cycle_loss_dist_coef', default=1, type=float) 150 | parser.add_argument('--cycle_loss_angle_coef', default=1, type=float) 151 | parser.add_argument('--cycle_loss_mse_coef', default=0.0, type=float) 152 | parser.add_argument('--cycle_loss_cls_coef', default=1, type=float) 153 | parser.add_argument('--fg_contra_loss_coef', default=1, type=float) 154 | parser.add_argument('--VQ_loss_coef', default=0.5, type=float) 155 | parser.add_argument('--cycle_loss_contrastive_coef', default=0.1, type=float) 156 | parser.add_argument('--loc_loss_coef', default=3, type=float) 157 | parser.add_argument('--lr_anchor_names', default=['negative_anchor'], type=str, nargs='+') 158 | parser.add_argument('--lr_anchor_mult', default=0.1, type=float) 159 | parser.add_argument('--contra_margin', default=0.5, type=float) 160 | parser.add_argument('--is_eval', action='store_true', help='set during evaluation/inference') 161 | parser.add_argument('--neg_cls', action='store_true', help='add a classifier to classify negative samples') 162 | parser.add_argument('--bert_cycle', action='store_true', help='use the 768-dim BERT output as the positive ground truth') 163 | parser.add_argument('--mix_query', action='store_true', help='mix pseudo-text and text queries for the deformable transformer') 164 | parser.add_argument('--quantitize_query', action='store_true', help='quantize the text query') 165 | parser.add_argument('--use_fg_contra', action='store_true', help='use foreground contrastive loss') 166 | parser.add_argument('--freeze_quantitizer', action='store_true', help='freeze the vector quantizer') 167 | parser.add_argument('--pseudo_label_path', default='', help='pseudo label path') 168 | parser.add_argument('--use_cls', action='store_true', help='use the negative classifier to filter out negative videos') 169 | parser.add_argument('--use_score', action='store_true', 
help='use score to filter out negative videos') 170 | parser.add_argument('--save_prob', action='store_true', help='save prob map') 171 | parser.add_argument('--segm_frame', default=5, type=int) 172 | parser.add_argument('--demo_exp', default='a big track on the road', help='demo exp') 173 | parser.add_argument('--demo_path', default='demo/demo_examples', help='demo frames folder path') 174 | return parser 175 | 176 | 177 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.12.5 2 | cython 3 | scipy 4 | opencv-python 5 | pillow 6 | scikit-image 7 | timm 8 | einops 9 | pandas 10 | imgaug 11 | h5py 12 | av -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/tools/__init__.py -------------------------------------------------------------------------------- /tools/colormap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def colormap(rgb=False): 5 | color_list = np.array( 6 | [ 7 | 0.000, 0.447, 0.741, 8 | 0.850, 0.325, 0.098, 9 | 0.929, 0.694, 0.125, 10 | 0.494, 0.184, 0.556, 11 | 0.466, 0.674, 0.188, 12 | 0.301, 0.745, 0.933, 13 | 0.635, 0.078, 0.184, 14 | 0.300, 0.300, 0.300, 15 | 0.600, 0.600, 0.600, 16 | 1.000, 0.000, 0.000, 17 | 1.000, 0.500, 0.000, 18 | 0.749, 0.749, 0.000, 19 | 0.000, 1.000, 0.000, 20 | 0.000, 0.000, 1.000, 21 | 0.667, 0.000, 1.000, 22 | 0.333, 0.333, 0.000, 23 | 0.333, 0.667, 0.000, 24 | 0.333, 1.000, 0.000, 25 | 0.667, 0.333, 0.000, 26 | 0.667, 0.667, 0.000, 27 | 0.667, 1.000, 0.000, 28 | 1.000, 0.333, 0.000, 29 | 1.000, 0.667, 0.000, 30 | 1.000, 1.000, 0.000, 31 | 0.000, 0.333, 0.500, 32 | 0.000, 0.667, 0.500, 33 | 0.000, 1.000, 0.500, 34 | 0.333, 0.000, 0.500, 35 | 0.333, 0.333, 0.500, 36 | 0.333, 0.667, 0.500, 37 | 0.333, 1.000, 0.500, 38 | 0.667, 0.000, 0.500, 39 | 0.667, 0.333, 0.500, 40 | 0.667, 0.667, 0.500, 41 | 0.667, 1.000, 0.500, 42 | 1.000, 0.000, 0.500, 43 | 1.000, 0.333, 0.500, 44 | 1.000, 0.667, 0.500, 45 | 1.000, 1.000, 0.500, 46 | 0.000, 0.333, 1.000, 47 | 0.000, 0.667, 1.000, 48 | 0.000, 1.000, 1.000, 49 | 0.333, 0.000, 1.000, 50 | 0.333, 0.333, 1.000, 51 | 0.333, 0.667, 1.000, 52 | 0.333, 1.000, 1.000, 53 | 0.667, 0.000, 1.000, 54 | 0.667, 0.333, 1.000, 55 | 0.667, 0.667, 1.000, 56 | 0.667, 1.000, 1.000, 57 | 1.000, 0.000, 1.000, 58 | 1.000, 0.333, 1.000, 59 | 1.000, 0.667, 1.000, 60 | 0.167, 0.000, 0.000, 61 | 0.333, 0.000, 0.000, 62 | 0.500, 0.000, 0.000, 63 | 0.667, 0.000, 0.000, 64 | 0.833, 0.000, 0.000, 65 | 1.000, 0.000, 0.000, 66 | 0.000, 0.167, 0.000, 67 | 0.000, 0.333, 0.000, 68 | 0.000, 0.500, 0.000, 69 | 0.000, 0.667, 0.000, 70 | 0.000, 0.833, 0.000, 71 | 0.000, 1.000, 0.000, 72 | 0.000, 0.000, 0.167, 73 | 0.000, 0.000, 0.333, 74 | 0.000, 0.000, 0.500, 75 | 0.000, 0.000, 0.667, 76 | 0.000, 0.000, 0.833, 77 | 0.000, 0.000, 1.000, 78 | 0.000, 0.000, 0.000, 79 | 0.143, 0.143, 0.143, 80 | 0.286, 0.286, 0.286, 81 | 0.429, 0.429, 0.429, 82 | 0.571, 0.571, 0.571, 83 | 0.714, 0.714, 0.714, 84 | 0.857, 0.857, 0.857, 85 | 1.000, 1.000, 1.000 86 | ] 87 | ).astype(np.float32) 88 | color_list = color_list.reshape((-1, 3)) * 255 89 | if not rgb: 90 | color_list = color_list[:, ::-1] 91 | return color_list 
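[Editorial sketch, not code from the repository] The palette returned by colormap() is typically consumed when drawing predicted masks onto frames. The hypothetical helper below (overlay_mask is an assumed name, all shapes are dummies) tints a binary mask onto a BGR frame with one of the palette colors, using OpenCV from requirements.txt.

```
import cv2
import numpy as np
from tools.colormap import colormap

color_list = colormap()  # (79, 3) float32, BGR order by default, values in [0, 255]

def overlay_mask(frame_bgr, mask, obj_id=0, alpha=0.5):
    """Blend a binary mask (H, W) into a BGR frame using the obj_id-th palette color."""
    color = color_list[obj_id % len(color_list)].astype(np.uint8)
    tinted = frame_bgr.copy()
    tinted[mask.astype(bool)] = color
    return cv2.addWeighted(tinted, alpha, frame_bgr, 1.0 - alpha, 0)

# Dummy data, just to show the call:
frame = np.zeros((360, 640, 3), dtype=np.uint8)
mask = np.zeros((360, 640), dtype=np.uint8)
mask[100:200, 200:400] = 1
cv2.imwrite('overlay.jpg', overlay_mask(frame, mask, obj_id=3))
```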
-------------------------------------------------------------------------------- /tools/data/convert_davis_to_ytvos.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | 4 | """ 5 | ytvos structure 6 | - train 7 | - Annotations 8 | - video1 9 | - video2 10 | - JPEGImages 11 | - video1 12 | -video2 13 | meta.json 14 | - valid 15 | - Annotations 16 | - JPEGImages 17 | meta.json 18 | - meta_expressions 19 | - train 20 | meta_expressions.json 21 | - valid 22 | meta_expressions.json 23 | """ 24 | 25 | def read_split_set(data_root='data/ref-davis'): 26 | set_split_path = os.path.join(data_root, "DAVIS/ImageSets/2017") 27 | # train set 28 | with open(os.path.join(set_split_path, "train.txt"), "r") as f: 29 | train_set = f.readlines() 30 | train_set = [x.strip() for x in train_set] # 60 videos 31 | # val set 32 | with open(os.path.join(set_split_path, "val.txt"), "r") as f: 33 | val_set = f.readlines() 34 | val_set = [x.strip() for x in val_set] # 30 videos 35 | return train_set, val_set # List 36 | 37 | 38 | def mv_images_to_folder(data_root='data/ref-davis', output_root='data/ref-davis'): 39 | train_img_path = os.path.join(output_root, "train/JPEGImages") 40 | train_anno_path = os.path.join(output_root, "train/Annotations") 41 | val_img_path = os.path.join(output_root, "valid/JPEGImages") 42 | val_anno_path = os.path.join(output_root, "valid/Annotations") 43 | meta_train_path = os.path.join(output_root, "meta_expressions/train") 44 | meta_val_path = os.path.join(output_root, "meta_expressions/valid") 45 | paths = [train_img_path, train_anno_path, val_img_path, val_anno_path, 46 | meta_train_path, meta_val_path] 47 | for path in paths: 48 | if not os.path.exists(path): 49 | os.makedirs(path) 50 | 51 | # 1. read the train/val split 52 | train_set, val_set = read_split_set(data_root) 53 | 54 | # 2. move images and annotations 55 | # train set 56 | for video in train_set: 57 | # move images 58 | base_img_path = os.path.join(data_root, "DAVIS/JPEGImages/480p", video) 59 | mv_cmd = f"mv {base_img_path} {train_img_path}" 60 | os.system(mv_cmd) 61 | # move annotations 62 | base_anno_path = os.path.join(data_root, "DAVIS/Annotations_unsupervised/480p", video) 63 | mv_cmd = f"mv {base_anno_path} {train_anno_path}" 64 | os.system(mv_cmd) 65 | # val set 66 | for video in val_set: 67 | # move images 68 | base_img_path = os.path.join(data_root, "DAVIS/JPEGImages/480p", video) 69 | mv_cmd = f"mv {base_img_path} {val_img_path}" 70 | os.system(mv_cmd) 71 | # move annotations 72 | base_anno_path = os.path.join(data_root, "DAVIS/Annotations_unsupervised/480p", video) 73 | mv_cmd = f"mv {base_anno_path} {val_anno_path}" 74 | os.system(mv_cmd) 75 | 76 | def create_meta_expressions(data_root='data/ref-davis', output_root='data/ref-davis'): 77 | """ 78 | NOTE: expressions odd: first anno, even: full anno 79 | meta_expression.json format 80 | { 81 | "videos": { 82 | "video1: { 83 | "expressions": { 84 | "0": { 85 | "exp": "xxxxx", 86 | "obj_id": "1" (start from 1) 87 | } 88 | "1": { 89 | "exp": "xxxxx", 90 | "obj_id": "1" 91 | } 92 | } 93 | "frames": [ 94 | "00000", 95 | "00001", 96 | ... 97 | ] 98 | } 99 | } 100 | } 101 | """ 102 | train_img_path = os.path.join(output_root, "train/JPEGImages") 103 | val_img_path = os.path.join(output_root, "valid/JPEGImages") 104 | meta_train_path = os.path.join(output_root, "meta_expressions/train") 105 | meta_val_path = os.path.join(output_root, "meta_expressions/valid") 106 | 107 | # 1. 
read the train/val split 108 | train_set, val_set = read_split_set(data_root) 109 | 110 | # 2. create meta_expression.json 111 | # NOTE: there are two annotators, and each annotator have first anno and full anno, respectively 112 | def read_expressions_from_txt(file_path, encoding='utf-8'): 113 | """ 114 | videos["video1"] = [ 115 | {"obj_id": 1, "exp": "xxxxx"}, 116 | {"obj_id": 2, "exp": "xxxxx"}, 117 | {"obj_id": 3, "exp": "xxxxx"}, 118 | ] 119 | """ 120 | videos = {} 121 | with open(file_path, "r", encoding=encoding) as f: 122 | for idx, line in enumerate(f.readlines()): 123 | line = line.strip() 124 | video_name, obj_id = line.split()[:2] 125 | exp = ' '.join(line.split()[2:])[1:-1] 126 | # handle bad case 127 | if video_name == "clasic-car": 128 | video_name = "classic-car" 129 | elif video_name == "dog-scale": 130 | video_name = "dogs-scale" 131 | elif video_name == "motor-bike": 132 | video_name = "motorbike" 133 | 134 | 135 | if not video_name in videos.keys(): 136 | videos[video_name] = [] 137 | exp_dict = { 138 | "exp": exp, 139 | "obj_id": obj_id 140 | } 141 | videos[video_name].append(exp_dict) 142 | 143 | # sort the order of expressions in each video 144 | for key, value in videos.items(): 145 | value = sorted(value, key = lambda e:e.__getitem__('obj_id')) 146 | videos[key] = value 147 | return videos 148 | 149 | anno1_first_path = os.path.join(data_root, "davis_text_annotations/Davis17_annot1.txt") 150 | anno1_full_path = os.path.join(data_root, "davis_text_annotations/Davis17_annot1_full_video.txt") 151 | anno2_first_path = os.path.join(data_root, "davis_text_annotations/Davis17_annot2.txt") 152 | anno2_full_path = os.path.join(data_root, "davis_text_annotations/Davis17_annot2_full_video.txt") 153 | # all videos information 154 | anno1_first = read_expressions_from_txt(anno1_first_path, encoding='utf-8') 155 | anno1_full = read_expressions_from_txt(anno1_full_path, encoding='utf-8') 156 | anno2_first = read_expressions_from_txt(anno2_first_path, encoding='latin-1') 157 | anno2_full = read_expressions_from_txt(anno2_full_path, encoding='latin-1') 158 | 159 | # 2(1). train 160 | train_videos = {} # {"video1": {}, "video2": {}, ...}, the final results to dump 161 | for video in train_set: # 60 videos 162 | video_dict = {} # for each video 163 | 164 | # store the information of video 165 | expressions = {} 166 | exp_id = 0 # start from 0 167 | for anno1_first_video, anno1_full_video, anno2_first_video, anno2_full_video in zip( 168 | anno1_first[video], anno1_full[video], anno2_first[video], anno2_full[video]): 169 | expressions[str(exp_id)] = anno1_first_video 170 | exp_id += 1 171 | expressions[str(exp_id)] = anno1_full_video 172 | exp_id += 1 173 | expressions[str(exp_id)] = anno2_first_video 174 | exp_id += 1 175 | expressions[str(exp_id)] = anno2_full_video 176 | exp_id += 1 177 | video_dict["expressions"] = expressions 178 | # read frame names for each video 179 | video_frames = os.listdir(os.path.join(train_img_path, video)) 180 | video_frames = [x.split(".")[0] for x in video_frames] # remove ".jpg" 181 | video_frames.sort() 182 | video_dict["frames"] = video_frames 183 | 184 | train_videos[video] = video_dict 185 | 186 | # 2(2). 
val 187 | val_videos = {} 188 | for video in val_set: 189 | video_dict = {} # for each video 190 | 191 | # store the information of video 192 | expressions = {} 193 | exp_id = 0 # start from 0 194 | for anno1_first_video, anno1_full_video, anno2_first_video, anno2_full_video in zip( 195 | anno1_first[video], anno1_full[video], anno2_first[video], anno2_full[video]): 196 | expressions[str(exp_id)] = anno1_first_video 197 | exp_id += 1 198 | expressions[str(exp_id)] = anno1_full_video 199 | exp_id += 1 200 | expressions[str(exp_id)] = anno2_first_video 201 | exp_id += 1 202 | expressions[str(exp_id)] = anno2_full_video 203 | exp_id += 1 204 | video_dict["expressions"] = expressions 205 | # read frame names for each video 206 | video_frames = os.listdir(os.path.join(val_img_path, video)) 207 | video_frames = [x.split(".")[0] for x in video_frames] # remove ".jpg" 208 | video_frames.sort() 209 | video_dict["frames"] = video_frames 210 | 211 | val_videos[video] = video_dict 212 | 213 | # 3. store the meta_expressions.json 214 | # train 215 | train_meta = {"videos": train_videos} 216 | with open(os.path.join(meta_train_path, "meta_expressions.json"), "w") as out: 217 | json.dump(train_meta, out) 218 | # val 219 | val_meta = {"videos": val_videos} 220 | with open(os.path.join(meta_val_path, "meta_expressions.json"), "w") as out: 221 | json.dump(val_meta, out) 222 | 223 | def create_meta_annotaions(data_root='data/ref-davis', output_root='data/ref-davis'): 224 | """ 225 | NOTE: frame names are not stored compared with ytvos 226 | meta.json format 227 | { 228 | "videos": { 229 | "video1: { 230 | "objects": { 231 | "1": {"category": "bike"}, 232 | "2": {"category": "person"} 233 | } 234 | } 235 | } 236 | } 237 | """ 238 | out_train_path = os.path.join(output_root, "train") 239 | out_val_path = os.path.join(output_root, "valid") 240 | 241 | # read the semantic information 242 | with open(os.path.join(data_root, "DAVIS/davis_semantics.json")) as f: 243 | davis_semantics = json.load(f) 244 | 245 | # 1. read the train/val split 246 | train_set, val_set = read_split_set(data_root) 247 | 248 | # 2. 
create meta.json 249 | # train 250 | train_videos = {} 251 | for video in train_set: 252 | video_dict = {} # for each video 253 | video_dict["objects"] = {} 254 | num_obj = len(davis_semantics[video].keys()) 255 | for obj_id in range(1, num_obj+1): # start from 1 256 | video_dict["objects"][str(obj_id)] = {"category": davis_semantics[video][str(obj_id)]} 257 | train_videos[video] = video_dict 258 | 259 | # val 260 | val_videos = {} 261 | for video in val_set: 262 | video_dict = {} 263 | video_dict["objects"] = {} 264 | num_obj = len(davis_semantics[video].keys()) 265 | for obj_id in range(1, num_obj+1): # start from 1 266 | video_dict["objects"][str(obj_id)] = {"category": davis_semantics[video][str(obj_id)]} 267 | val_videos[video] = video_dict 268 | 269 | # store the meta.json file 270 | train_meta = {"videos": train_videos} 271 | with open(os.path.join(out_train_path, "meta.json"), "w") as out: 272 | json.dump(train_meta, out) 273 | val_meta = {"videos": val_videos} 274 | with open(os.path.join(out_val_path, "meta.json"), "w") as out: 275 | json.dump(val_meta, out) 276 | 277 | if __name__ == '__main__': 278 | data_root = "/mnt/data/ref-davis" 279 | output_root = "/mnt/data/ref-davis" 280 | print("Converting ref-davis to ref-youtube-vos format....") 281 | mv_images_to_folder(data_root, output_root) 282 | create_meta_expressions(data_root, output_root) 283 | create_meta_annotaions(data_root, output_root) 284 | 285 | -------------------------------------------------------------------------------- /tools/data/convert_refexp_to_coco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from datasets.refer import REFER 4 | import cv2 5 | from tqdm import tqdm 6 | import json 7 | import pickle 8 | import json 9 | 10 | 11 | def convert_to_coco(data_root='/mnt/data/refcoco', output_root='/mnt/data/refcoco_referformer', dataset='refcoco', dataset_split='unc'): 12 | dataset_dir = os.path.join(data_root, dataset) 13 | output_dir = os.path.join(output_root, dataset) # .json save path 14 | if not os.path.exists(output_dir): 15 | os.makedirs(output_dir) 16 | 17 | # read REFER 18 | refer = REFER(data_root, dataset, dataset_split) 19 | refs = refer.Refs 20 | anns = refer.Anns 21 | imgs = refer.Imgs 22 | cats = refer.Cats 23 | sents = refer.Sents 24 | """ 25 | # create sets of mapping 26 | # 1) Refs: {ref_id: ref} 27 | # 2) Anns: {ann_id: ann} 28 | # 3) Imgs: {image_id: image} 29 | # 4) Cats: {category_id: category_name} 30 | # 5) Sents: {sent_id: sent} 31 | # 6) imgToRefs: {image_id: refs} 32 | # 7) imgToAnns: {image_id: anns} 33 | # 8) refToAnn: {ref_id: ann} 34 | # 9) annToRef: {ann_id: ref} 35 | # 10) catToRefs: {category_id: refs} 36 | # 11) sentToRef: {sent_id: ref} 37 | # 12) sentToTokens: {sent_id: tokens} 38 | 39 | Refs: List[Dict], "sent_ids", "file_name", "ann_id", "ref_id", "image_id", "category_id", "split", "sentences" 40 | "sentences": List[Dict], "tokens"(List), "raw", "sent_id", "sent" 41 | Anns: List[Dict], "segmentation", "area", "iscrowd", "image_id", "bbox", "category_id", "id" 42 | Imgs: List[Dict], "license", "file_name", "coco_url", "height", "width", "date_captured", "flickr_url", "id" 43 | Cats: List[Dict], "supercategory", "name", "id" 44 | Sents: List[Dict], "tokens"(List), "raw", "sent_id", "sent", here the "sent_id" is consistent 45 | """ 46 | print('Dataset [%s_%s] contains: ' % (dataset, dataset_split)) 47 | ref_ids = refer.getRefIds() 48 | image_ids = refer.getImgIds() 49 | print('There are %s 
expressions for %s refereed objects in %s images.' % ( 50 | len(refer.Sents), len(ref_ids), len(image_ids))) 51 | 52 | print('\nAmong them:') 53 | if dataset == 'refcoco': 54 | splits = ['train', 'val', 'testA', 'testB'] 55 | elif dataset == 'refcoco+': 56 | splits = ['train', 'val', 'testA', 'testB'] 57 | elif dataset == 'refcocog': 58 | splits = ['train', 'val', 'test'] # we don't have test split for refcocog right now. 59 | 60 | for split in splits: 61 | ref_ids = refer.getRefIds(split=split) 62 | print(' %s referred objects are in split [%s].' % (len(ref_ids), split)) 63 | 64 | with open(os.path.join(dataset_dir, "instances.json"), "r") as f: 65 | ann_json = json.load(f) 66 | 67 | # 1. for each split: train, val... 68 | for split in splits: 69 | max_length = 0 # max length of a sentence 70 | 71 | coco_ann = { 72 | "info": "", 73 | "licenses": "", 74 | "images": [], # each caption is a image sample 75 | "annotations": [], 76 | "categories": [] 77 | } 78 | coco_ann['info'], coco_ann['licenses'], coco_ann['categories'] = \ 79 | ann_json['info'], ann_json['licenses'], ann_json['categories'] 80 | 81 | num_images = 0 # each caption is a sample, create a "images" and a "annotations", since each image has one box 82 | ref_ids = refer.getRefIds(split=split) 83 | # 2. for each referred object 84 | for i in tqdm(ref_ids): 85 | ref = refs[i] 86 | # "sent_ids", "file_name", "ann_id", "ref_id", "image_id", "category_id", "split", "sentences" 87 | # "sentences": List[Dict], "tokens"(List), "raw", "sent_id", "sent" 88 | img = imgs[ref["image_id"]] 89 | ann = anns[ref["ann_id"]] 90 | 91 | # 3. for each sentence, which is a sample 92 | for sentence in ref["sentences"]: 93 | num_images += 1 94 | # append image info 95 | image_info = { 96 | "file_name": img["file_name"], 97 | "height": img["height"], 98 | "width": img["width"], 99 | "original_id": img["id"], 100 | "id": num_images, 101 | "caption": sentence["sent"], 102 | "dataset_name": dataset 103 | } 104 | coco_ann["images"].append(image_info) 105 | 106 | # append annotation info 107 | ann_info = { 108 | "segmentation": ann["segmentation"], 109 | "area": ann["area"], 110 | "iscrowd": ann["iscrowd"], 111 | "bbox": ann["bbox"], 112 | "image_id": num_images, 113 | "category_id": ann["category_id"], 114 | "id": num_images, 115 | "original_id": ann["id"] 116 | } 117 | coco_ann["annotations"].append(ann_info) 118 | 119 | max_length = max(max_length, len(sentence["tokens"])) 120 | 121 | print("Total expression: {} in split {}".format(num_images, split)) 122 | print("Max sentence length of the split: ", max_length) 123 | # save the json file 124 | save_file = "instances_{}_{}.json".format(dataset, split) 125 | with open(os.path.join(output_dir, save_file), 'w') as f: 126 | json.dump(coco_ann, f) 127 | 128 | 129 | if __name__ == '__main__': 130 | datasets = ["refcoco", "refcoco+", "refcocog"] 131 | datasets_split = ["unc", "unc", "umd"] 132 | for (dataset, dataset_split) in zip(datasets, datasets_split): 133 | convert_to_coco(dataset=dataset, dataset_split=dataset_split) 134 | print("") 135 | 136 | """ 137 | # original mapping 138 | {'person': 1, 'bicycle': 2, 'car': 3, 'motorcycle': 4, 'airplane': 5, 'bus': 6, 'train': 7, 'truck': 8, 'boat': 9, 139 | 'traffic light': 10, 'fire hydrant': 11, 'stop sign': 13, 'parking meter': 14, 'bench': 15, 'bird': 16, 'cat': 17, 140 | 'dog': 18, 'horse': 19, 'sheep': 20, 'cow': 21, 'elephant': 22, 'bear': 23, 'zebra': 24, 'giraffe': 25, 'backpack': 27, 141 | 'umbrella': 28, 'handbag': 31, 'tie': 32, 'suitcase': 33, 'frisbee': 
34, 'skis': 35, 'snowboard': 36, 'sports ball': 37, 142 | 'kite': 38, 'baseball bat': 39, 'baseball glove': 40, 'skateboard': 41, 'surfboard': 42, 'tennis racket': 43, 'bottle': 44, 143 | 'wine glass': 46, 'cup': 47, 'fork': 48, 'knife': 49, 'spoon': 50, 'bowl': 51, 'banana': 52, 'apple': 53, 'sandwich': 54, 144 | 'orange': 55, 'broccoli': 56, 'carrot': 57, 'hot dog': 58, 'pizza': 59, 'donut': 60, 'cake': 61, 'chair': 62, 'couch': 63, 145 | 'potted plant': 64, 'bed': 65, 'dining table': 67, 'toilet': 70, 'tv': 72, 'laptop': 73, 'mouse': 74, 'remote': 75, 146 | 'keyboard': 76, 'cell phone': 77, 'microwave': 78, 'oven': 79, 'toaster': 80, 'sink': 81, 'refrigerator': 82, 'book': 84, 147 | 'clock': 85, 'vase': 86, 'scissors': 87, 'teddy bear': 88, 'hair drier': 89, 'toothbrush': 90} 148 | 149 | """ -------------------------------------------------------------------------------- /tools/load_pretrained_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | def pre_trained_model_to_finetune(checkpoint, args): 5 | if 'model' in checkpoint.keys(): 6 | checkpoint = checkpoint['model'] 7 | # only delete the class_embed since the finetuned dataset has different num_classes 8 | num_layers = args.dec_layers + 1 if args.two_stage else args.dec_layers 9 | for l in range(num_layers): 10 | if "class_embed.{}.weight".format(l) in checkpoint.keys(): 11 | del checkpoint["class_embed.{}.weight".format(l)] 12 | del checkpoint["class_embed.{}.bias".format(l)] 13 | # # determine backbone.0 14 | # flag = 0 15 | # for key in checkpoint.keys(): 16 | # if 'backbone' in key: 17 | # flag = 1 18 | # if flag == 0: 19 | # new_ckpt = OrderedDict() 20 | # for k, v in checkpoint.items(): 21 | # if 'patch_embed' in k or 'attn.relative_position_' in k: 22 | # continue 23 | # new_ckpt['backbone.0.body.' + k] = v 24 | # checkpoint = new_ckpt 25 | 26 | else: 27 | checkpoint = checkpoint['state_dict'] 28 | new_ckpt = OrderedDict() 29 | for k, v in checkpoint.items(): 30 | if 'patch_embed' in k: 31 | continue 32 | new_ckpt[k.replace('backbone', 'backbone.0.body')] = v 33 | checkpoint = new_ckpt 34 | return checkpoint 35 | -------------------------------------------------------------------------------- /tools/warmup_poly_lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import math 3 | from typing import List 4 | 5 | def _get_warmup_factor_at_iter( 6 | method: str, iter: int, warmup_iters: int, warmup_factor: float 7 | ) -> float: 8 | """ 9 | Return the learning rate warmup factor at a specific iteration. 10 | See :paper:`ImageNet in 1h` for more details. 11 | Args: 12 | method (str): warmup method; either "constant" or "linear". 13 | iter (int): iteration at which to calculate the warmup factor. 14 | warmup_iters (int): the number of warmup iterations. 15 | warmup_factor (float): the base warmup factor (the meaning changes according 16 | to the method used). 17 | Returns: 18 | float: the effective warmup factor at the given iteration. 19 | """ 20 | if iter >= warmup_iters: 21 | return 1.0 22 | 23 | if method == "constant": 24 | return warmup_factor 25 | elif method == "linear": 26 | alpha = iter / warmup_iters 27 | return warmup_factor * (1 - alpha) + alpha 28 | else: 29 | raise ValueError("Unknown warmup method: {}".format(method)) 30 | 31 | class WarmupPolyLR(torch.optim.lr_scheduler._LRScheduler): 32 | """ 33 | Poly learning rate schedule used to train DeepLab. 
34 | Paper: DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, 35 | Atrous Convolution, and Fully Connected CRFs. 36 | Reference: https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/utils/train_utils.py#L337 # noqa 37 | """ 38 | 39 | def __init__( 40 | self, 41 | optimizer: torch.optim.Optimizer, 42 | max_iters: int, 43 | warmup_factor: float = 0.001, 44 | warmup_iters: int = 1000, 45 | warmup_method: str = "linear", 46 | last_epoch: int = -1, 47 | power: float = 0.9, 48 | constant_ending: float = 0.0, 49 | ): 50 | self.max_iters = max_iters 51 | self.warmup_factor = warmup_factor 52 | self.warmup_iters = warmup_iters 53 | self.warmup_method = warmup_method 54 | self.power = power 55 | self.constant_ending = constant_ending 56 | super().__init__(optimizer, last_epoch) 57 | 58 | def get_lr(self) -> List[float]: 59 | warmup_factor = _get_warmup_factor_at_iter( 60 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 61 | ) 62 | if self.constant_ending > 0 and warmup_factor == 1.0: 63 | # Constant ending lr. 64 | if ( 65 | math.pow((1.0 - self.last_epoch / self.max_iters), self.power) 66 | < self.constant_ending 67 | ): 68 | return [base_lr * self.constant_ending for base_lr in self.base_lrs] 69 | return [ 70 | base_lr * warmup_factor * math.pow((1.0 - self.last_epoch / self.max_iters), self.power) 71 | for base_lr in self.base_lrs 72 | ] 73 | 74 | def _compute_values(self) -> List[float]: 75 | # The new interface 76 | return self.get_lr() 77 | 78 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lxa9867/R2VOS/88801cfba43cd0ca14f5da024678dd504f6737dc/util/__init__.py -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for bounding box manipulation and GIoU. 
3 | """ 4 | import torch 5 | from torchvision.ops.boxes import box_area 6 | 7 | def clip_iou(boxes1,boxes2): 8 | area1 = box_area(boxes1) 9 | area2 = box_area(boxes2) 10 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) 11 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) 12 | wh = (rb - lt).clamp(min=0) 13 | inter = wh[:,0] * wh[:,1] 14 | union = area1 + area2 - inter 15 | iou = (inter + 1e-6) / (union+1e-6) 16 | return iou 17 | 18 | def multi_iou(boxes1, boxes2): 19 | lt = torch.max(boxes1[...,:2], boxes2[...,:2]) 20 | rb = torch.min(boxes1[...,2:], boxes2[...,2:]) 21 | wh = (rb - lt).clamp(min=0) 22 | wh_1 = boxes1[...,2:] - boxes1[...,:2] 23 | wh_2 = boxes2[...,2:] - boxes2[...,:2] 24 | inter = wh[...,0] * wh[...,1] 25 | union = wh_1[...,0] * wh_1[...,1] + wh_2[...,0] * wh_2[...,1] - inter 26 | iou = (inter + 1e-6) / (union + 1e-6) 27 | return iou 28 | 29 | def box_cxcywh_to_xyxy(x): 30 | x_c, y_c, w, h = x.unbind(-1) 31 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 32 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 33 | return torch.stack(b, dim=-1) 34 | 35 | 36 | def box_xyxy_to_cxcywh(x): 37 | x0, y0, x1, y1 = x.unbind(-1) 38 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 39 | (x1 - x0), (y1 - y0)] 40 | return torch.stack(b, dim=-1) 41 | 42 | 43 | # modified from torchvision to also return the union 44 | def box_iou(boxes1, boxes2): 45 | area1 = box_area(boxes1) 46 | area2 = box_area(boxes2) 47 | 48 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 49 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 50 | 51 | wh = (rb - lt).clamp(min=0) # [N,M,2] 52 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 53 | 54 | union = area1[:, None] + area2 - inter 55 | 56 | iou = (inter+1e-6) / (union+1e-6) 57 | return iou, union 58 | 59 | 60 | def generalized_box_iou(boxes1, boxes2): 61 | """ 62 | Generalized IoU from https://giou.stanford.edu/ 63 | 64 | The boxes should be in [x0, y0, x1, y1] format 65 | 66 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 67 | and M = len(boxes2) 68 | """ 69 | # degenerate boxes gives inf / nan results 70 | # so do an early check 71 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 72 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 73 | iou, union = box_iou(boxes1, boxes2) 74 | 75 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 76 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 77 | 78 | wh = (rb - lt).clamp(min=0) # [N,M,2] 79 | area = wh[:, :, 0] * wh[:, :, 1] 80 | 81 | return iou - ((area - union) + 1e-6) / (area + 1e-6) 82 | 83 | 84 | def masks_to_boxes(masks): 85 | """Compute the bounding boxes around the provided masks 86 | 87 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 88 | 89 | Returns a [N, 4] tensors, with the boxes in xyxy format 90 | """ 91 | if masks.numel() == 0: 92 | return torch.zeros((0, 4), device=masks.device) 93 | 94 | h, w = masks.shape[-2:] 95 | 96 | y = torch.arange(0, h, dtype=torch.float) 97 | x = torch.arange(0, w, dtype=torch.float) 98 | y, x = torch.meshgrid(y, x) 99 | 100 | x_mask = (masks * x.unsqueeze(0)) 101 | x_max = x_mask.flatten(1).max(-1)[0] 102 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 103 | 104 | y_mask = (masks * y.unsqueeze(0)) 105 | y_max = y_mask.flatten(1).max(-1)[0] 106 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 107 | 108 | return torch.stack([x_min, y_min, x_max, y_max], 1) 109 | --------------------------------------------------------------------------------