├── .gitignore ├── LICENSE ├── README.md ├── datasets ├── __init__.py ├── a2d.py ├── a2d_eval.py ├── categories.py ├── coco.py ├── coco_eval.py ├── concat_dataset.py ├── davis.py ├── image_to_seq_augmenter.py ├── jhmdb.py ├── refer.py ├── refexp.py ├── refexp2seq.py ├── refexp_eval.py ├── samplers.py ├── transforms_image.py ├── transforms_video.py └── ytvos.py ├── davis2017 ├── __init__.py ├── davis.py ├── evaluation.py ├── metrics.py ├── results.py └── utils.py ├── docs ├── data.md ├── framework.png └── install.md ├── engine.py ├── eval_davis.py ├── inference_davis.py ├── inference_ytvos.py ├── main.py ├── main_pretrain.py ├── models ├── __init__.py ├── backbone.py ├── criterion.py ├── decoder.py ├── deformable_transformer.py ├── matcher.py ├── modules.py ├── ops │ ├── functions │ │ ├── __init__.py │ │ └── ms_deform_attn_func.py │ ├── make.sh │ ├── modules │ │ ├── __init__.py │ │ └── ms_deform_attn.py │ ├── setup.py │ ├── src │ │ ├── cpu │ │ │ ├── ms_deform_attn_cpu.cpp │ │ │ └── ms_deform_attn_cpu.h │ │ ├── cuda │ │ │ ├── ms_deform_attn_cuda.cu │ │ │ ├── ms_deform_attn_cuda.h │ │ │ └── ms_deform_im2col_cuda.cuh │ │ ├── ms_deform_attn.h │ │ └── vision.cpp │ └── test.py ├── position_encoding.py ├── postprocessors.py ├── segmentation.py ├── sgmg.py ├── swin_transformer.py ├── text_encoder │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── text_encoder.py │ └── tokenizer.py └── video_swin_transformer.py ├── opts.py ├── requirements.txt ├── scripts ├── dist_test_davis_videoswinb.sh ├── dist_test_ytvos_videoswinb.sh ├── dist_train_a2d_videoswinb.sh ├── dist_train_scratch_ytvos_videoswin.sh ├── dist_train_ytvos_videoswin.sh └── dist_train_ytvos_videoswinb.sh ├── tools ├── colormap.py ├── data │ ├── convert_davis_to_ytvos.py │ └── convert_refexp_to_coco.py └── load_pretrained_weights.py ├── util ├── __init__.py ├── box_ops.py ├── logger.py └── misc.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | log/ 2 | saves/ 3 | .vscode/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | *idea 135 | *runs 136 | *nohup* 137 | checkpoints -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License](https://img.shields.io/badge/license-CC--BY--NC%204.0-green)](https://creativecommons.org/licenses/by-nc/4.0/) 2 | [![arXiv](https://img.shields.io/badge/cs.CV-%09arXiv%3A2205.00823-red)](https://arxiv.org/abs/2307.13537) 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/spectrum-guided-multi-granularity-referring/referring-expression-segmentation-on-a2d)](https://paperswithcode.com/sota/referring-expression-segmentation-on-a2d?p=spectrum-guided-multi-granularity-referring) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/spectrum-guided-multi-granularity-referring/referring-expression-segmentation-on-j-hmdb)](https://paperswithcode.com/sota/referring-expression-segmentation-on-j-hmdb?p=spectrum-guided-multi-granularity-referring) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/spectrum-guided-multi-granularity-referring/referring-expression-segmentation-on-refer-1)](https://paperswithcode.com/sota/referring-expression-segmentation-on-refer-1?p=spectrum-guided-multi-granularity-referring) 7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/spectrum-guided-multi-granularity-referring/referring-expression-segmentation-on-davis)](https://paperswithcode.com/sota/referring-expression-segmentation-on-davis?p=spectrum-guided-multi-granularity-referring) 8 | 9 | ## New: see the new work [HTR](https://github.com/bo-miao/HTR) 
(TCSVT 2024), the first end-to-end decoupled framework that improves the baseline by 4.7%. It identifies aligned frames for text-conditioned segmentation and builds memory, then propagates mask features to segment the remaining frames for temporally consistent R-VOS. A new metric for evaluating temporal consistency is also introduced. 10 | 11 | ## New: see our latest work [RefHuman](https://github.com/bo-miao/RefHuman) (NeurIPS 2024), which introduces a unified model for referring to any person in the wild using text, clicks, or scribbles! 12 | 13 | The official implementation of the **ICCV 2023** paper: 14 | 15 |
16 | 17 | 18 | Spectrum-guided Multi-granularity Referring Video Object Segmentation 19 | 20 | 21 | 22 | 23 | 24 |
25 | 26 | > [**Spectrum-guided Multi-granularity Referring Video Object Segmentation**](https://arxiv.org/abs/2307.13537) 27 | > 28 | > Bo Miao, Mohammed Bennamoun, Yongsheng Gao, Ajmal Mian 29 | > 30 | > ICCV 2023 31 | 32 | ## Introduction 33 | 34 | We propose a Spectrum-guided Multi-granularity (SgMg) approach that follows a segment-and-optimize pipeline to tackle the feature drift issue found in previous decode-and-segment approaches. Extensive experiments show that SgMg achieves state-of-the-art overall performance on multiple benchmark datasets, outperforming the closest competitor by 2.8 percentage points on Ref-YouTube-VOS with faster inference. 35 | 36 | ## Setup 37 | 38 | The main setup of our code follows [ReferFormer](https://github.com/wjn922/ReferFormer). 39 | 40 | Please refer to [install.md](docs/install.md) for installation. 41 | 42 | Please refer to [data.md](docs/data.md) for data preparation. 43 | 44 | ## Training and Evaluation 45 | 46 | All models are trained on 2 RTX 3090 GPUs. If you encounter an OOM error, please add the `--use_checkpoint` flag. 47 | 48 | The training and evaluation scripts are included in the `scripts` folder. To train or evaluate SgMg, run the following commands: 49 | 50 | ``` 51 | sh dist_train_ytvos_videoswinb.sh 52 | ``` 53 | 54 | ``` 55 | sh dist_test_ytvos_videoswinb.sh 56 | ``` 57 | 58 | Note: You can modify `--backbone` and `--backbone_pretrained` to specify a different backbone. 59 | 60 | ## Model Zoo 61 | 62 | We provide pretrained models for different [visual backbones](https://drive.google.com/drive/folders/13XFkNtYFIcTgEc3d7-8wQA-Ovi0T_z2v?usp=sharing) and the checkpoints for SgMg (see the tables below). 63 | 64 | You can put the models in the `checkpoints` folder to start training/inference. 65 | 66 | ### Results (Ref-YouTube-VOS & Ref-DAVIS) 67 | 68 | To evaluate the results, please upload the zip file to the [competition server](https://codalab.lisn.upsaclay.fr/competitions/3282#participate). 69 | 70 | | Backbone | Ref-YouTube-VOS J&F | Ref-DAVIS J&F | Model | Submission | 71 | | :----: | :----: | :----: | :----: | :----: | 72 | | Video-Swin-T | 62.0 | 61.9 | [model](https://drive.google.com/file/d/1SiHl7oYqBabaN28nsrNOJeiZrJyhRixl/view?usp=sharing) | [link](https://drive.google.com/file/d/1jEVlgPzAuNJxOrcy83r0jbsGcRwPao3-/view?usp=sharing) | 73 | | Video-Swin-B | 65.7 | 63.3 | [model](https://drive.google.com/file/d/1sZngZ_7JlgZWX2bEQ7Xw36_VbBOlLJU8/view?usp=sharing) | [link](https://drive.google.com/file/d/1t5XqyqEsIvh0D92oSn-Pct4bfLsxcz73/view?usp=sharing) | 74 | 75 | ### Results (A2D-Sentences & JHMDB-Sentences) 76 | 77 | | Backbone | (A2D) mAP | Mean IoU | Overall IoU | (JHMDB) mAP | Mean IoU | Overall IoU | Model | 78 | | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | 79 | | Video-Swin-T | 56.1 | 78.0 | 70.4 | 44.4 | 72.8 | 71.7 | [model](https://drive.google.com/file/d/1LKjaMOBrpGT7tWLS3CmhDl_QQfUQPglJ/view?usp=sharing) | 80 | | Video-Swin-B | 58.5 | 79.9 | 72.0 | 45.0 | 73.7 | 72.5 | [model](https://drive.google.com/file/d/1PQh0QSWqWUUnWf9WtvHgZ7plrORQvjzN/view?usp=sharing) | 81 | 82 | ### Results (RefCOCO/+/g) 83 | 84 | The overall IoU is used as the metric, and the model is obtained from the pre-training stage mentioned in the paper.
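For clarity on this metric: overall IoU accumulates intersection and union areas over the entire evaluation split before taking their ratio, whereas mean IoU averages the per-sample IoUs; both are computed this way in `datasets/a2d_eval.py`. A minimal sketch of the distinction, assuming binary masks stored as `torch` tensors (not part of the repo):

```
import torch

def overall_and_mean_iou(pred_masks, gt_masks, eps=1e-6):
    # pred_masks / gt_masks: lists of binary [H, W] tensors covering the whole split
    total_inter, total_union, per_sample = 0.0, 0.0, []
    for pred, gt in zip(pred_masks, gt_masks):
        inter = (pred.int() & gt.int()).float().sum().item()
        union = (pred.int() | gt.int()).float().sum().item()
        per_sample.append((inter + eps) / (union + eps))
        total_inter += inter
        total_union += union
    overall_iou = total_inter / total_union        # dataset-level ratio
    mean_iou = sum(per_sample) / len(per_sample)   # average of per-sample IoUs
    return overall_iou, mean_iou
```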
85 | 86 | | Backbone | RefCOCO | RefCOCO+ | RefCOCOg | Model | 87 | | :----: | :----: | :----: | :----: | :----: | 88 | | Video-Swin-B | 76.3 | 66.4 | 70.0 | [model](https://drive.google.com/file/d/1URnBMpZh0J7mBg6H2b1pdqywMM8vOopG/view?usp=sharing) | 89 | 90 | ## Acknowledgements 91 | 92 | - [Deformable DETR](https://github.com/fundamentalvision/Deformable-DETR) 93 | - [ReferFormer](https://github.com/wjn922/ReferFormer) 94 | - [MTTR](https://github.com/mttr2021/MTTR) 95 | 96 | ## Citation 97 | 98 | ``` 99 | @InProceedings{Miao_2023_ICCV, 100 | author = {Miao, Bo and Bennamoun, Mohammed and Gao, Yongsheng and Mian, Ajmal}, 101 | title = {Spectrum-guided Multi-granularity Referring Video Object Segmentation}, 102 | booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)}, 103 | month = {October}, 104 | year = {2023}, 105 | pages = {920-930} 106 | } 107 | ``` 108 | 109 | ## Contact 110 | If you have any questions about this project, please feel free to contact bomiaobbb@gmail.com. 111 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import torchvision 3 | 4 | from .ytvos import build as build_ytvos 5 | from .davis import build as build_davis 6 | from .a2d import build as build_a2d 7 | from .jhmdb import build as build_jhmdb 8 | from .refexp import build as build_refexp 9 | from .concat_dataset import build as build_joint 10 | from .concat_dataset import build_coco as build_joint_coco 11 | from .concat_dataset import build_joint_ytb_dvs 12 | 13 | def get_coco_api_from_dataset(dataset): 14 | for _ in range(10): 15 | # if isinstance(dataset, torchvision.datasets.CocoDetection): 16 | # break 17 | if isinstance(dataset, torch.utils.data.Subset): 18 | dataset = dataset.dataset 19 | if isinstance(dataset, torchvision.datasets.CocoDetection): 20 | return dataset.coco 21 | 22 | 23 | def build_dataset(dataset_file: str, image_set: str, args): 24 | if dataset_file == 'ytvos': 25 | print("\n **** Start to build dataset {}. **** \n".format("build_ytvos")) 26 | return build_ytvos(image_set, args) 27 | if dataset_file == 'davis': 28 | print("\n **** Start to build dataset {}. **** \n".format("build_davis")) 29 | return build_davis(image_set, args) 30 | if dataset_file == 'a2d': 31 | print("\n **** Start to build dataset {}. **** \n".format("build_a2d")) 32 | return build_a2d(image_set, args) 33 | if dataset_file == 'jhmdb': 34 | print("\n **** Start to build dataset {}. **** \n".format("build_jhmdb")) 35 | return build_jhmdb(image_set, args) 36 | # for pretraining 37 | if dataset_file == "refcoco" or dataset_file == "refcoco+" or dataset_file == "refcocog": 38 | print("\n **** Start to build dataset {}. **** \n".format("build_refexp")) 39 | return build_refexp(dataset_file, image_set, args) 40 | 41 | # for joint training of refcoco and ytvos, not used. 42 | if dataset_file == 'joint': 43 | print("\n **** Start to build dataset {}. **** \n".format("build_joint")) 44 | return build_joint(image_set, args) 45 | if dataset_file == 'joint_coco': 46 | print("\n **** Start to build dataset {}. **** \n".format("build_joint_coco")) 47 | return build_joint_coco(image_set, args) 48 | if dataset_file == 'ytvos_joint_davis': 49 | print("\n **** Start to build dataset {}. 
**** \n".format("build_joint_ytb_dvs")) 50 | return build_joint_ytb_dvs(image_set, args) 51 | raise ValueError(f'dataset {dataset_file} not supported') 52 | -------------------------------------------------------------------------------- /datasets/a2d_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains implementations for the precision@k and IoU (mean, overall) evaluation metrics. 3 | copy-paste from https://github.com/mttr2021/MTTR/blob/main/metrics.py 4 | """ 5 | import torch 6 | from tqdm import tqdm 7 | from pycocotools.coco import COCO 8 | from pycocotools.mask import decode 9 | import numpy as np 10 | 11 | from torchvision.ops.boxes import box_area 12 | 13 | def compute_bbox_iou(boxes1: torch.Tensor, boxes2: torch.Tensor): 14 | # both boxes: xyxy 15 | area1 = box_area(boxes1) 16 | area2 = box_area(boxes2) 17 | 18 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 19 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 20 | 21 | wh = (rb - lt).clamp(min=0) # [N,M,2] 22 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 23 | 24 | union = area1[:, None] + area2 - inter 25 | 26 | iou = (inter+1e-6) / (union+1e-6) 27 | return iou, inter, union 28 | 29 | def compute_mask_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6): 30 | outputs = outputs.int() 31 | intersection = (outputs & labels).float().sum((1, 2)) # Will be zero if Truth=0 or Prediction=0 32 | union = (outputs | labels).float().sum((1, 2)) # Will be zero if both are 0 33 | iou = (intersection + EPS) / (union + EPS) # EPS is used to avoid division by zero 34 | return iou, intersection, union 35 | 36 | # mask 37 | def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO): 38 | print('evaluating mask precision@k & iou metrics...') 39 | counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]} 40 | total_intersection_area = 0 41 | total_union_area = 0 42 | ious_list = [] 43 | for instance in tqdm(coco_gt.imgs.keys()): # each image_id contains exactly one instance 44 | gt_annot = coco_gt.imgToAnns[instance][0] 45 | gt_mask = decode(gt_annot['segmentation']) 46 | pred_annots = coco_pred.imgToAnns[instance] 47 | pred_annot = sorted(pred_annots, key=lambda a: a['score'])[-1] # choose pred with highest score 48 | pred_mask = decode(pred_annot['segmentation']) 49 | iou, intersection, union = compute_mask_iou(torch.tensor(pred_mask).unsqueeze(0), 50 | torch.tensor(gt_mask).unsqueeze(0)) 51 | iou, intersection, union = iou.item(), intersection.item(), union.item() 52 | for iou_threshold in counters_by_iou.keys(): 53 | if iou > iou_threshold: 54 | counters_by_iou[iou_threshold] += 1 55 | total_intersection_area += intersection 56 | total_union_area += union 57 | ious_list.append(iou) 58 | num_samples = len(ious_list) 59 | precision_at_k = np.array(list(counters_by_iou.values())) / num_samples 60 | overall_iou = total_intersection_area / total_union_area 61 | mean_iou = np.mean(ious_list) 62 | return precision_at_k, overall_iou, mean_iou 63 | 64 | # bbox 65 | def calculate_bbox_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO): 66 | print('evaluating bbox precision@k & iou metrics...') 67 | counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]} 68 | total_intersection_area = 0 69 | total_union_area = 0 70 | ious_list = [] 71 | for instance in tqdm(coco_gt.imgs.keys()): # each image_id contains exactly one instance 72 | gt_annot = coco_gt.imgToAnns[instance][0] 73 | gt_bbox = gt_annot['bbox'] # xywh 74 | 
gt_bbox = [ 75 | gt_bbox[0], 76 | gt_bbox[1], 77 | gt_bbox[2] + gt_bbox[0], 78 | gt_bbox[3] + gt_bbox[1], 79 | ] 80 | pred_annots = coco_pred.imgToAnns[instance] 81 | pred_annot = sorted(pred_annots, key=lambda a: a['score'])[-1] # choose pred with highest score 82 | pred_bbox = pred_annot['bbox'] # xyxy 83 | iou, intersection, union = compute_bbox_iou(torch.tensor(pred_bbox).unsqueeze(0), 84 | torch.tensor(gt_bbox).unsqueeze(0)) 85 | iou, intersection, union = iou.item(), intersection.item(), union.item() 86 | for iou_threshold in counters_by_iou.keys(): 87 | if iou > iou_threshold: 88 | counters_by_iou[iou_threshold] += 1 89 | total_intersection_area += intersection 90 | total_union_area += union 91 | ious_list.append(iou) 92 | num_samples = len(ious_list) 93 | precision_at_k = np.array(list(counters_by_iou.values())) / num_samples 94 | overall_iou = total_intersection_area / total_union_area 95 | mean_iou = np.mean(ious_list) 96 | return precision_at_k, overall_iou, mean_iou 97 | -------------------------------------------------------------------------------- /datasets/categories.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------- 2 | # 1. refer_youtube_vos 3 | ytvos_category_dict = { 4 | 'airplane': 0, 'ape': 1, 'bear': 2, 'bike': 3, 'bird': 4, 'boat': 5, 'bucket': 6, 'bus': 7, 'camel': 8, 'cat': 9, 5 | 'cow': 10, 'crocodile': 11, 'deer': 12, 'dog': 13, 'dolphin': 14, 'duck': 15, 'eagle': 16, 'earless_seal': 17, 6 | 'elephant': 18, 'fish': 19, 'fox': 20, 'frisbee': 21, 'frog': 22, 'giant_panda': 23, 'giraffe': 24, 'hand': 25, 7 | 'hat': 26, 'hedgehog': 27, 'horse': 28, 'knife': 29, 'leopard': 30, 'lion': 31, 'lizard': 32, 'monkey': 33, 8 | 'motorbike': 34, 'mouse': 35, 'others': 36, 'owl': 37, 'paddle': 38, 'parachute': 39, 'parrot': 40, 'penguin': 41, 9 | 'person': 42, 'plant': 43, 'rabbit': 44, 'raccoon': 45, 'sedan': 46, 'shark': 47, 'sheep': 48, 'sign': 49, 10 | 'skateboard': 50, 'snail': 51, 'snake': 52, 'snowboard': 53, 'squirrel': 54, 'surfboard': 55, 'tennis_racket': 56, 11 | 'tiger': 57, 'toilet': 58, 'train': 59, 'truck': 60, 'turtle': 61, 'umbrella': 62, 'whale': 63, 'zebra': 64 12 | } 13 | 14 | ytvos_category_list = [ 15 | 'airplane', 'ape', 'bear', 'bike', 'bird', 'boat', 'bucket', 'bus', 'camel', 'cat', 'cow', 'crocodile', 16 | 'deer', 'dog', 'dolphin', 'duck', 'eagle', 'earless_seal', 'elephant', 'fish', 'fox', 'frisbee', 'frog', 17 | 'giant_panda', 'giraffe', 'hand', 'hat', 'hedgehog', 'horse', 'knife', 'leopard', 'lion', 'lizard', 18 | 'monkey', 'motorbike', 'mouse', 'others', 'owl', 'paddle', 'parachute', 'parrot', 'penguin', 'person', 19 | 'plant', 'rabbit', 'raccoon', 'sedan', 'shark', 'sheep', 'sign', 'skateboard', 'snail', 'snake', 'snowboard', 20 | 'squirrel', 'surfboard', 'tennis_racket', 'tiger', 'toilet', 'train', 'truck', 'turtle', 'umbrella', 'whale', 'zebra' 21 | ] 22 | 23 | # ------------------------------------------------------------------------------------------------------------------- 24 | # 2. 
refer_davis17 25 | davis_category_dict = { 26 | 'airplane': 0, 'backpack': 1, 'ball': 2, 'bear': 3, 'bicycle': 4, 'bird': 5, 'boat': 6, 'bottle': 7, 'box': 8, 'bus': 9, 27 | 'camel': 10, 'car': 11, 'carriage': 12, 'cat': 13, 'cellphone': 14, 'chamaleon': 15, 'cow': 16, 'deer': 17, 'dog': 18, 28 | 'dolphin': 19, 'drone': 20, 'elephant': 21, 'excavator': 22, 'fish': 23, 'goat': 24, 'golf cart': 25, 'golf club': 26, 29 | 'grass': 27, 'guitar': 28, 'gun': 29, 'helicopter': 30, 'horse': 31, 'hoverboard': 32, 'kart': 33, 'key': 34, 'kite': 35, 30 | 'koala': 36, 'leash': 37, 'lion': 38, 'lock': 39, 'mask': 40, 'microphone': 41, 'monkey': 42, 'motorcycle': 43, 'oar': 44, 31 | 'paper': 45, 'paraglide': 46, 'person': 47, 'pig': 48, 'pole': 49, 'potted plant': 50, 'puck': 51, 'rack': 52, 'rhino': 53, 32 | 'rope': 54, 'sail': 55, 'scale': 56, 'scooter': 57, 'selfie stick': 58, 'sheep': 59, 'skateboard': 60, 'ski': 61, 'ski poles': 62, 33 | 'snake': 63, 'snowboard': 64, 'stick': 65, 'stroller': 66, 'surfboard': 67, 'swing': 68, 'tennis racket': 69, 'tractor': 70, 34 | 'trailer': 71, 'train': 72, 'truck': 73, 'turtle': 74, 'varanus': 75, 'violin': 76, 'wheelchair': 77 35 | } 36 | 37 | davis_category_list = [ 38 | 'airplane', 'backpack', 'ball', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'box', 'bus', 'camel', 'car', 'carriage', 39 | 'cat', 'cellphone', 'chamaleon', 'cow', 'deer', 'dog', 'dolphin', 'drone', 'elephant', 'excavator', 'fish', 'goat', 40 | 'golf cart', 'golf club', 'grass', 'guitar', 'gun', 'helicopter', 'horse', 'hoverboard', 'kart', 'key', 'kite', 'koala', 41 | 'leash', 'lion', 'lock', 'mask', 'microphone', 'monkey', 'motorcycle', 'oar', 'paper', 'paraglide', 'person', 'pig', 42 | 'pole', 'potted plant', 'puck', 'rack', 'rhino', 'rope', 'sail', 'scale', 'scooter', 'selfie stick', 'sheep', 'skateboard', 43 | 'ski', 'ski poles', 'snake', 'snowboard', 'stick', 'stroller', 'surfboard', 'swing', 'tennis racket', 'tractor', 'trailer', 44 | 'train', 'truck', 'turtle', 'varanus', 'violin', 'wheelchair' 45 | ] -------------------------------------------------------------------------------- /datasets/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | COCO dataset which returns image_id for evaluation. 
3 | 4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 5 | """ 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.utils.data 10 | import torchvision 11 | from pycocotools import mask as coco_mask 12 | 13 | import datasets.transforms as T 14 | 15 | 16 | class CocoDetection(torchvision.datasets.CocoDetection): 17 | def __init__(self, img_folder, ann_file, transforms, return_masks): 18 | super(CocoDetection, self).__init__(img_folder, ann_file) 19 | self._transforms = transforms 20 | self.prepare = ConvertCocoPolysToMask(return_masks) 21 | 22 | def __getitem__(self, idx): 23 | img, target = super(CocoDetection, self).__getitem__(idx) 24 | image_id = self.ids[idx] 25 | target = {'image_id': image_id, 'annotations': target} 26 | 27 | img, target = self.prepare(img, target) 28 | if self._transforms is not None: 29 | img, target = self._transforms(img, target) 30 | return img, target 31 | 32 | 33 | def convert_coco_poly_to_mask(segmentations, height, width): 34 | masks = [] 35 | for polygons in segmentations: 36 | rles = coco_mask.frPyObjects(polygons, height, width) 37 | mask = coco_mask.decode(rles) 38 | if len(mask.shape) < 3: 39 | mask = mask[..., None] 40 | mask = torch.as_tensor(mask, dtype=torch.uint8) 41 | mask = mask.any(dim=2) 42 | masks.append(mask) 43 | if masks: 44 | masks = torch.stack(masks, dim=0) 45 | else: 46 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 47 | return masks 48 | 49 | 50 | class ConvertCocoPolysToMask(object): 51 | def __init__(self, return_masks=False): 52 | self.return_masks = return_masks 53 | 54 | def __call__(self, image, target): 55 | w, h = image.size 56 | 57 | image_id = target["image_id"] 58 | image_id = torch.tensor([image_id]) 59 | 60 | anno = target["annotations"] 61 | 62 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 63 | 64 | boxes = [obj["bbox"] for obj in anno] 65 | # guard against no boxes via resizing 66 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 67 | boxes[:, 2:] += boxes[:, :2] 68 | boxes[:, 0::2].clamp_(min=0, max=w) 69 | boxes[:, 1::2].clamp_(min=0, max=h) 70 | 71 | classes = [obj["category_id"] for obj in anno] 72 | classes = torch.tensor(classes, dtype=torch.int64) 73 | 74 | if self.return_masks: 75 | segmentations = [obj["segmentation"] for obj in anno] 76 | masks = convert_coco_poly_to_mask(segmentations, h, w) 77 | 78 | keypoints = None 79 | if anno and "keypoints" in anno[0]: 80 | keypoints = [obj["keypoints"] for obj in anno] 81 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 82 | num_keypoints = keypoints.shape[0] 83 | if num_keypoints: 84 | keypoints = keypoints.view(num_keypoints, -1, 3) 85 | 86 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 87 | boxes = boxes[keep] 88 | classes = classes[keep] 89 | if self.return_masks: 90 | masks = masks[keep] 91 | if keypoints is not None: 92 | keypoints = keypoints[keep] 93 | 94 | target = {} 95 | target["boxes"] = boxes 96 | target["labels"] = classes 97 | if self.return_masks: 98 | target["masks"] = masks 99 | target["image_id"] = image_id 100 | if keypoints is not None: 101 | target["keypoints"] = keypoints 102 | 103 | # for conversion to coco api 104 | area = torch.tensor([obj["area"] for obj in anno]) 105 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 106 | target["area"] = area[keep] 107 | target["iscrowd"] = iscrowd[keep] 108 | 109 | target["orig_size"] = 
torch.as_tensor([int(h), int(w)]) 110 | target["size"] = torch.as_tensor([int(h), int(w)]) 111 | 112 | return image, target 113 | 114 | 115 | def make_coco_transforms(image_set): 116 | 117 | normalize = T.Compose([ 118 | T.ToTensor(), 119 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 120 | ]) 121 | 122 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 123 | 124 | if image_set == 'train': 125 | return T.Compose([ 126 | T.RandomHorizontalFlip(), 127 | T.RandomSelect( 128 | T.RandomResize(scales, max_size=1333), 129 | T.Compose([ 130 | T.RandomResize([400, 500, 600]), 131 | T.RandomSizeCrop(384, 600), 132 | T.RandomResize(scales, max_size=1333), 133 | ]) 134 | ), 135 | normalize, 136 | ]) 137 | 138 | if image_set == 'val': 139 | return T.Compose([ 140 | T.RandomResize([800], max_size=1333), 141 | normalize, 142 | ]) 143 | 144 | raise ValueError(f'unknown {image_set}') 145 | 146 | 147 | def build(image_set, args): 148 | root = Path(args.coco_path) 149 | assert root.exists(), f'provided COCO path {root} does not exist' 150 | mode = 'instances' 151 | PATHS = { 152 | "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'), 153 | "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'), 154 | } 155 | img_folder, ann_file = PATHS[image_set] 156 | dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks) 157 | return dataset 158 | -------------------------------------------------------------------------------- /datasets/concat_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.utils.data 10 | 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from .refexp2seq import build as build_seq_refexp 13 | from .ytvos import build as build_ytvs 14 | from .davis import build as build_davis 15 | from datasets import ytvos 16 | 17 | 18 | # join ref coco and ytvos 19 | def build(image_set, args): 20 | concat_data = [] 21 | 22 | print('preparing coco2seq dataset ....') 23 | coco_names = ["refcoco", "refcoco+", "refcocog"] 24 | for name in coco_names: 25 | coco_seq = build_seq_refexp(name, image_set, args) 26 | concat_data.append(coco_seq) 27 | 28 | print('preparing ytvos dataset .... ') 29 | ytvos_dataset = build_ytvs(image_set, args) 30 | concat_data.append(ytvos_dataset) 31 | 32 | concat_data = ConcatDataset(concat_data) 33 | 34 | return concat_data 35 | 36 | def build_coco(image_set, args): 37 | concat_data = [] 38 | 39 | print('preparing coco2seq dataset ....') 40 | coco_names = ["refcoco", "refcoco+", "refcocog"] 41 | for name in coco_names: 42 | coco_seq = build_seq_refexp(name, image_set, args) 43 | concat_data.append(coco_seq) 44 | 45 | concat_data = ConcatDataset(concat_data) 46 | return concat_data 47 | 48 | def build_joint_ytb_dvs(image_set, args): 49 | concat_data = [] 50 | 51 | print('preparing davis dataset ....') 52 | dvs_dataset = build_davis(image_set, args) 53 | for i in range(5): 54 | concat_data.append(dvs_dataset) 55 | 56 | print('preparing ytvos dataset .... 
') 57 | ytvos_dataset = build_ytvs(image_set, args) 58 | concat_data.append(ytvos_dataset) 59 | 60 | concat_data = ConcatDataset(concat_data) 61 | 62 | return concat_data 63 | -------------------------------------------------------------------------------- /datasets/image_to_seq_augmenter.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from SeqFormer (https://github.com/wjf5203/SeqFormer) 3 | # ------------------------------------------------------------------------ 4 | # Modified from STEm-Seg (https://github.com/sabarim/STEm-Seg) 5 | # ------------------------------------------------------------------------ 6 | 7 | 8 | import imgaug 9 | import imgaug.augmenters as iaa 10 | import numpy as np 11 | 12 | from datetime import datetime 13 | 14 | from imgaug.augmentables.segmaps import SegmentationMapsOnImage 15 | from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage 16 | 17 | 18 | class ImageToSeqAugmenter(object): 19 | def __init__(self, perspective=True, affine=True, motion_blur=True, 20 | brightness_range=(-50, 50), hue_saturation_range=(-15, 15), perspective_magnitude=0.12, 21 | scale_range=1.0, translate_range={"x": (-0.15, 0.15), "y": (-0.15, 0.15)}, rotation_range=(-20, 20), 22 | motion_blur_kernel_sizes=(7, 9), motion_blur_prob=0.5): 23 | 24 | self.basic_augmenter = iaa.SomeOf((1, None), [ 25 | iaa.Add(brightness_range), 26 | iaa.AddToHueAndSaturation(hue_saturation_range) 27 | ] 28 | ) 29 | 30 | transforms = [] 31 | if perspective: 32 | transforms.append(iaa.PerspectiveTransform(perspective_magnitude)) 33 | if affine: 34 | transforms.append(iaa.Affine(scale=scale_range, 35 | translate_percent=translate_range, 36 | rotate=rotation_range, 37 | order=1, # cv2.INTER_LINEAR 38 | backend='auto')) 39 | transforms = iaa.Sequential(transforms) 40 | transforms = [transforms] 41 | 42 | if motion_blur: 43 | blur = iaa.Sometimes(motion_blur_prob, iaa.OneOf( 44 | [ 45 | iaa.MotionBlur(ksize) 46 | for ksize in motion_blur_kernel_sizes 47 | ] 48 | )) 49 | transforms.append(blur) 50 | 51 | self.frame_shift_augmenter = iaa.Sequential(transforms) 52 | 53 | @staticmethod 54 | def condense_masks(instance_masks): 55 | condensed_mask = np.zeros_like(instance_masks[0], dtype=np.int8) 56 | for instance_id, mask in enumerate(instance_masks, 1): 57 | condensed_mask = np.where(mask, instance_id, condensed_mask) 58 | 59 | return condensed_mask 60 | 61 | @staticmethod 62 | def expand_masks(condensed_mask, num_instances): 63 | return [(condensed_mask == instance_id).astype(np.uint8) for instance_id in range(1, num_instances + 1)] 64 | 65 | def __call__(self, image, masks=None, boxes=None): 66 | det_augmenter = self.frame_shift_augmenter.to_deterministic() 67 | 68 | 69 | if masks is not None: 70 | masks_np, is_binary_mask = [], [] 71 | boxs_np = [] 72 | 73 | for mask in masks: 74 | 75 | if isinstance(mask, np.ndarray): 76 | masks_np.append(mask.astype(np.bool)) 77 | is_binary_mask.append(False) 78 | else: 79 | raise ValueError("Invalid mask type: {}".format(type(mask))) 80 | 81 | num_instances = len(masks_np) 82 | masks_np = SegmentationMapsOnImage(self.condense_masks(masks_np), shape=image.shape[:2]) 83 | # boxs_np = BoundingBoxesOnImage(boxs_np, shape=image.shape[:2]) 84 | 85 | seed = int(datetime.now().strftime('%M%S%f')[-8:]) 86 | imgaug.seed(seed) 87 | aug_image, aug_masks = det_augmenter(image=self.basic_augmenter(image=image) , segmentation_maps=masks_np) 88 | 
imgaug.seed(seed) 89 | # invalid_pts_mask = det_augmenter(image=np.ones(image.shape[:2] + (1,), np.uint8)).squeeze(2) 90 | aug_masks = self.expand_masks(aug_masks.get_arr(), num_instances) 91 | # aug_boxes = aug_boxes.remove_out_of_image().clip_out_of_image() 92 | aug_masks = [mask for mask, is_bm in zip(aug_masks, is_binary_mask)] 93 | # (427, 640, 3) (427, 640) 94 | return aug_image, aug_masks #, aug_boxes.to_xyxy_array() 95 | 96 | else: 97 | # if no mask is provided, random generate and delete the mask. 98 | masks = [SegmentationMapsOnImage(np.ones(image.shape[:2], np.bool), shape=image.shape[:2])] 99 | aug_image, invalid_pts_mask = det_augmenter(image=image, segmentation_maps=masks) 100 | return aug_image, invalid_pts_mask.get_arr() == 0 101 | -------------------------------------------------------------------------------- /datasets/jhmdb.py: -------------------------------------------------------------------------------- 1 | """ 2 | JHMDB-Sentences data loader 3 | modified from https://github.com/mttr2021/MTTR/blob/main/datasets/jhmdb_sentences/jhmdb_sentences_dataset.py 4 | """ 5 | from pathlib import Path 6 | 7 | import torch 8 | from torchvision.io import read_video 9 | import torchvision.transforms.functional as F 10 | 11 | from torch.utils.data import Dataset 12 | import datasets.transforms_video as T 13 | 14 | import os 15 | from PIL import Image 16 | import json 17 | import numpy as np 18 | import random 19 | 20 | import scipy.io 21 | 22 | def get_image_id(video_id, frame_idx): 23 | image_id = f'v_{video_id}_f_{frame_idx}' 24 | return image_id 25 | 26 | class JHMDBSentencesDataset(Dataset): 27 | """ 28 | A Torch dataset for JHMDB-Sentences. 29 | For more information check out: https://kgavrilyuk.github.io/publication/actor_action/ or the original paper at: 30 | https://arxiv.org/abs/1803.07485 31 | """ 32 | def __init__(self, image_folder: Path, ann_file: Path, transforms, return_masks: bool, 33 | num_frames: int, max_skip: int, subset): 34 | super(JHMDBSentencesDataset, self).__init__() 35 | self.dataset_path = '../datasets' 36 | self.ann_file = ann_file 37 | self.samples_metadata = self.get_samples_metadata() 38 | 39 | self._transforms = transforms 40 | self.return_masks = return_masks # not used 41 | self.num_frames = num_frames 42 | self.max_skip = max_skip 43 | self.subset = subset 44 | 45 | print(f'\n {subset} sample num: ', len(self.samples_metadata)) 46 | print('\n') 47 | 48 | def get_samples_metadata(self): 49 | with open(str(self.ann_file), 'r') as f: 50 | samples_metadata = [tuple(a) for a in json.load(f)] 51 | return samples_metadata 52 | 53 | @staticmethod 54 | def bounding_box(img): 55 | rows = np.any(img, axis=1) 56 | cols = np.any(img, axis=0) 57 | rmin, rmax = np.where(rows)[0][[0, -1]] 58 | cmin, cmax = np.where(cols)[0][[0, -1]] 59 | return rmin, rmax, cmin, cmax # y1, y2, x1, x2 60 | 61 | def __len__(self): 62 | return len(self.samples_metadata) 63 | 64 | def __getitem__(self, idx): 65 | # only support for evaluation 66 | video_id, chosen_frame_path, video_masks_path, video_total_frames, text_query = self.samples_metadata[idx] 67 | text_query = " ".join(text_query.lower().split()) # clean up the text query 68 | 69 | # read the source window frames: 70 | chosen_frame_idx = int(chosen_frame_path.split('/')[-1].split('.')[0]) 71 | # get a window of window_size frames with frame chosen_frame_idx in the middle. 
72 | start_idx, end_idx = chosen_frame_idx - self.num_frames // 2, chosen_frame_idx + (self.num_frames + 1) // 2 73 | frame_indices = list(range(start_idx, end_idx)) # note that jhmdb-sentences frames are 1-indexed 74 | # extract the window source frames: 75 | sample_indx = [] 76 | for i in frame_indices: 77 | i = min(max(i, 1), video_total_frames) # pad out of range indices with edge frames 78 | sample_indx.append(i) 79 | sample_indx.sort() 80 | # find the valid frame index in sampled frame list, there is only one valid frame 81 | valid_indices = sample_indx.index(chosen_frame_idx) 82 | 83 | # read frames 84 | imgs, boxes, masks, valid = [], [], [], [] 85 | for i in sample_indx: 86 | p = '/'.join(chosen_frame_path.split('/')[:-1]) + f'/{i:05d}.png' 87 | frame_path = os.path.join(self.dataset_path, p) 88 | imgs.append(Image.open(frame_path).convert('RGB')) 89 | 90 | # read the instance masks: 91 | video_masks_path = os.path.join(self.dataset_path, video_masks_path) 92 | all_video_masks = scipy.io.loadmat(video_masks_path)['part_mask'].transpose(2, 0, 1) # [T, H, W] 93 | # note that to take the center-frame corresponding mask we switch to 0-indexing: 94 | instance_mask = torch.tensor(all_video_masks[chosen_frame_idx - 1]) # [H, W] 95 | mask = instance_mask.numpy() 96 | if (mask > 0).any(): 97 | y1, y2, x1, x2 = self.bounding_box(mask) 98 | box = torch.tensor([x1, y1, x2, y2]).to(torch.float) 99 | valid.append(1) 100 | else: # some frame didn't contain the instance 101 | box = torch.tensor([0, 0, 0, 0]).to(torch.float) 102 | valid.append(0) 103 | mask = torch.from_numpy(mask) 104 | boxes.append(box) 105 | masks.append(mask) 106 | 107 | # transform 108 | h, w = instance_mask.shape[-2:] 109 | boxes = torch.stack(boxes, dim=0) 110 | boxes[:, 0::2].clamp_(min=0, max=w) 111 | boxes[:, 1::2].clamp_(min=0, max=h) 112 | masks = torch.stack(masks, dim=0) 113 | # there is only one valid frame 114 | target = { 115 | 'frames_idx': torch.tensor(sample_indx), # [T,] 116 | 'valid_indices': torch.tensor([valid_indices]), 117 | 'boxes': boxes, # [1, 4], xyxy 118 | 'masks': masks, # [1, H, W] 119 | 'valid': torch.tensor(valid), # [1,] 120 | 'caption': text_query, 121 | 'orig_size': torch.as_tensor([int(h), int(w)]), 122 | 'size': torch.as_tensor([int(h), int(w)]), 123 | 'image_id': get_image_id(video_id, chosen_frame_idx) 124 | } 125 | 126 | # "boxes" normalize to [0, 1] and transform from xyxy to cxcywh in self._transform 127 | imgs, target = self._transforms(imgs, target) 128 | imgs = torch.stack(imgs, dim=0) # [T, 3, H, W] 129 | 130 | # in 'val', valid always satisfies 131 | return imgs, target 132 | 133 | 134 | def make_coco_transforms(image_set, max_size=640): 135 | normalize = T.Compose([ 136 | T.ToTensor(), 137 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 138 | ]) 139 | 140 | scales = [288, 320, 352, 392, 416, 448, 480, 512] 141 | 142 | if image_set == 'train': 143 | return T.Compose([ 144 | T.RandomHorizontalFlip(), 145 | T.PhotometricDistort(), 146 | T.RandomSelect( 147 | T.Compose([ 148 | T.RandomResize(scales, max_size=max_size), 149 | T.Check(), 150 | ]), 151 | T.Compose([ 152 | T.RandomResize([400, 500, 600]), 153 | T.RandomSizeCrop(384, 600), 154 | T.RandomResize(scales, max_size=max_size), 155 | T.Check(), 156 | ]) 157 | ), 158 | normalize, 159 | ]) 160 | 161 | # we do not use the 'val' set since the annotations are inaccessible 162 | if image_set == 'val': 163 | return T.Compose([ 164 | T.RandomResize([360], max_size=640), 165 | normalize, 166 | ]) 167 | 168 | raise 
ValueError(f'unknown {image_set}') 169 | 170 | 171 | def build(image_set, args): 172 | root = Path(args.jhmdb_path) 173 | assert root.exists(), f'provided JHMDB-Sentences path {root} does not exist' 174 | PATHS = { 175 | "train": (root, root / "jhmdb_sentences_samples_metadata.json"), # not used 176 | "val": (root, root / "jhmdb_sentences_samples_metadata.json"), 177 | } 178 | img_folder, ann_file = PATHS[image_set] 179 | dataset = JHMDBSentencesDataset(img_folder, ann_file, transforms=make_coco_transforms(image_set, max_size=args.max_size), 180 | return_masks=args.masks, num_frames=args.num_frames, max_skip=args.max_skip, subset=image_set) 181 | return dataset -------------------------------------------------------------------------------- /datasets/refexp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | """ 4 | COCO dataset which returns image_id for evaluation. 5 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 6 | """ 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.utils.data 11 | import torchvision 12 | from pycocotools import mask as coco_mask 13 | 14 | import datasets.transforms_image as T 15 | 16 | 17 | class ModulatedDetection(torchvision.datasets.CocoDetection): 18 | def __init__(self, img_folder, ann_file, transforms, return_masks): 19 | super(ModulatedDetection, self).__init__(img_folder, ann_file) 20 | self._transforms = transforms 21 | self.prepare = ConvertCocoPolysToMask(return_masks) 22 | 23 | def __getitem__(self, idx): 24 | instance_check = False 25 | while not instance_check: 26 | img, target = super(ModulatedDetection, self).__getitem__(idx) 27 | image_id = self.ids[idx] 28 | coco_img = self.coco.loadImgs(image_id)[0] 29 | caption = coco_img["caption"] 30 | dataset_name = coco_img["dataset_name"] if "dataset_name" in coco_img else None 31 | target = {"image_id": image_id, "annotations": target, "caption": caption} 32 | img, target = self.prepare(img, target) 33 | if self._transforms is not None: 34 | img, target = self._transforms(img, target) 35 | target["dataset_name"] = dataset_name 36 | for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]: 37 | if extra_key in coco_img: 38 | target[extra_key] = coco_img[extra_key] # box xyxy -> cxcywh 39 | # FIXME: handle "valid", since some box may be removed due to random crop 40 | target["valid"] = torch.tensor([1]) if len(target["area"]) != 0 else torch.tensor([0]) 41 | 42 | if torch.any(target['valid'] == 1): # at leatst one instance 43 | instance_check = True 44 | else: 45 | import random 46 | idx = random.randint(0, self.__len__() - 1) 47 | return img.unsqueeze(0), target 48 | # return img: [1, 3, H, W], the first dimension means T = 1. 
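To make the single-frame convention above concrete (each image is returned as a 1-frame clip so the same video pipeline can consume it during pretraining), here is an illustrative usage sketch; the image folder and annotation paths are hypothetical placeholders, not files shipped with this repo:

```
# Illustrative only: the paths below are hypothetical placeholders.
from datasets.refexp import ModulatedDetection, make_coco_transforms

dataset = ModulatedDetection(
    img_folder="data/coco/train2014",                           # hypothetical path
    ann_file="data/coco/refcoco/instances_refcoco_train.json",  # hypothetical path
    transforms=make_coco_transforms("val", cautious=False),
    return_masks=True,
)
img, target = dataset[0]
print(img.shape)           # torch.Size([1, 3, H, W]); the leading dimension is T = 1
print(target["caption"])   # the referring expression paired with this sample
```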
49 | 50 | 51 | def convert_coco_poly_to_mask(segmentations, height, width): 52 | masks = [] 53 | for polygons in segmentations: 54 | rles = coco_mask.frPyObjects(polygons, height, width) 55 | mask = coco_mask.decode(rles) 56 | if len(mask.shape) < 3: 57 | mask = mask[..., None] 58 | mask = torch.as_tensor(mask, dtype=torch.uint8) 59 | mask = mask.any(dim=2) 60 | masks.append(mask) 61 | if masks: 62 | masks = torch.stack(masks, dim=0) 63 | else: 64 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 65 | return masks 66 | 67 | 68 | class ConvertCocoPolysToMask(object): 69 | def __init__(self, return_masks=False): 70 | self.return_masks = return_masks 71 | 72 | def __call__(self, image, target): 73 | w, h = image.size 74 | 75 | image_id = target["image_id"] 76 | image_id = torch.tensor([image_id]) 77 | 78 | anno = target["annotations"] 79 | caption = target["caption"] if "caption" in target else None 80 | 81 | anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] 82 | 83 | boxes = [obj["bbox"] for obj in anno] 84 | # guard against no boxes via resizing 85 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 86 | boxes[:, 2:] += boxes[:, :2] # xminyminwh -> xyxy 87 | boxes[:, 0::2].clamp_(min=0, max=w) 88 | boxes[:, 1::2].clamp_(min=0, max=h) 89 | 90 | classes = [obj["category_id"] for obj in anno] 91 | classes = torch.tensor(classes, dtype=torch.int64) 92 | 93 | if self.return_masks: 94 | segmentations = [obj["segmentation"] for obj in anno] 95 | masks = convert_coco_poly_to_mask(segmentations, h, w) 96 | 97 | # keep the valid boxes 98 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 99 | boxes = boxes[keep] 100 | classes = classes[keep] 101 | if self.return_masks: 102 | masks = masks[keep] 103 | 104 | target = {} 105 | target["boxes"] = boxes 106 | target["labels"] = classes 107 | if caption is not None: 108 | target["caption"] = caption 109 | if self.return_masks: 110 | target["masks"] = masks 111 | target["image_id"] = image_id 112 | 113 | # for conversion to coco api 114 | area = torch.tensor([obj["area"] for obj in anno]) 115 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 116 | target["area"] = area[keep] 117 | target["iscrowd"] = iscrowd[keep] 118 | target["valid"] = torch.tensor([1]) 119 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 120 | target["size"] = torch.as_tensor([int(h), int(w)]) 121 | return image, target 122 | 123 | 124 | def make_coco_transforms(image_set, cautious): 125 | 126 | normalize = T.Compose([T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 127 | 128 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768] 129 | final_scales = [296, 328, 360, 392, 416, 448, 480, 512] 130 | 131 | max_size = 800 132 | if image_set == "train": 133 | horizontal = [] if cautious else [T.RandomHorizontalFlip()] 134 | return T.Compose( 135 | horizontal 136 | + [ 137 | T.RandomSelect( 138 | T.RandomResize(scales, max_size=max_size), 139 | T.Compose( 140 | [ 141 | T.RandomResize([400, 500, 600]), 142 | T.RandomSizeCrop(384, 600, respect_boxes=cautious), 143 | T.RandomResize(final_scales, max_size=640), 144 | ] 145 | ), 146 | ), 147 | normalize, 148 | ] 149 | ) 150 | 151 | if image_set == "val": 152 | return T.Compose( 153 | [ 154 | T.RandomResize([360], max_size=640), 155 | normalize, 156 | ] 157 | ) 158 | 159 | raise ValueError(f"unknown {image_set}") 160 | 161 | 162 | def build(dataset_file, image_set, args): 163 | root = Path(args.coco_path) 
164 | assert root.exists(), f"provided COCO path {root} does not exist" 165 | mode = "instances" 166 | dataset = dataset_file 167 | PATHS = { 168 | "train": (root / "train2014", root / dataset / f"{mode}_{dataset}_train.json"), 169 | "val": (root / "train2014", root / dataset / f"{mode}_{dataset}_val.json"), 170 | } 171 | 172 | img_folder, ann_file = PATHS[image_set] 173 | dataset = ModulatedDetection( 174 | img_folder, 175 | ann_file, 176 | transforms=make_coco_transforms(image_set, False), 177 | return_masks=args.masks, 178 | ) 179 | return dataset -------------------------------------------------------------------------------- /datasets/refexp_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 2 | import copy 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import torch 7 | import torch.utils.data 8 | 9 | import util.misc as utils 10 | from util.box_ops import generalized_box_iou 11 | 12 | 13 | class RefExpEvaluator(object): 14 | def __init__(self, refexp_gt, iou_types, k=(1, 5, 10), thresh_iou=0.5): 15 | assert isinstance(k, (list, tuple)) 16 | refexp_gt = copy.deepcopy(refexp_gt) 17 | self.refexp_gt = refexp_gt 18 | self.iou_types = iou_types 19 | self.img_ids = self.refexp_gt.imgs.keys() 20 | self.predictions = {} 21 | self.k = k 22 | self.thresh_iou = thresh_iou 23 | 24 | def accumulate(self): 25 | pass 26 | 27 | def update(self, predictions): 28 | self.predictions.update(predictions) 29 | 30 | def synchronize_between_processes(self): 31 | all_predictions = utils.all_gather(self.predictions) 32 | merged_predictions = {} 33 | for p in all_predictions: 34 | merged_predictions.update(p) 35 | self.predictions = merged_predictions 36 | 37 | def summarize(self): 38 | if utils.is_main_process(): 39 | dataset2score = { 40 | "refcoco": {k: 0.0 for k in self.k}, 41 | "refcoco+": {k: 0.0 for k in self.k}, 42 | "refcocog": {k: 0.0 for k in self.k}, 43 | } 44 | dataset2count = {"refcoco": 0.0, "refcoco+": 0.0, "refcocog": 0.0} 45 | for image_id in self.img_ids: 46 | ann_ids = self.refexp_gt.getAnnIds(imgIds=image_id) 47 | assert len(ann_ids) == 1 48 | img_info = self.refexp_gt.loadImgs(image_id)[0] 49 | 50 | target = self.refexp_gt.loadAnns(ann_ids[0]) 51 | prediction = self.predictions[image_id] 52 | assert prediction is not None 53 | sorted_scores_boxes = sorted( 54 | zip(prediction["scores"].tolist(), prediction["boxes"].tolist()), reverse=True 55 | ) 56 | sorted_scores, sorted_boxes = zip(*sorted_scores_boxes) 57 | sorted_boxes = torch.cat([torch.as_tensor(x).view(1, 4) for x in sorted_boxes]) 58 | target_bbox = target[0]["bbox"] 59 | converted_bbox = [ 60 | target_bbox[0], 61 | target_bbox[1], 62 | target_bbox[2] + target_bbox[0], 63 | target_bbox[3] + target_bbox[1], 64 | ] 65 | giou = generalized_box_iou(sorted_boxes, torch.as_tensor(converted_bbox).view(-1, 4)) 66 | for k in self.k: 67 | if max(giou[:k]) >= self.thresh_iou: 68 | dataset2score[img_info["dataset_name"]][k] += 1.0 69 | dataset2count[img_info["dataset_name"]] += 1.0 70 | 71 | for key, value in dataset2score.items(): 72 | for k in self.k: 73 | try: 74 | value[k] /= dataset2count[key] 75 | except: 76 | pass 77 | results = {} 78 | for key, value in dataset2score.items(): 79 | results[key] = sorted([v for k, v in value.items()]) 80 | print(f" Dataset: {key} - Precision @ 1, 5, 10: {results[key]} \n") 81 | 82 | return results 83 | return None 84 | 85 | 
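As a usage note for the evaluator above: it follows an update / synchronize / summarize protocol, where each process registers its predictions keyed by image id, predictions are gathered across processes, and precision@k is then computed against the single ground-truth box per image. A rough sketch under assumed inputs (the annotation path and the constant dummy boxes are made up for illustration):

```
# Illustrative sketch; the annotation file path and the dummy predictions are made up.
import torch
from pycocotools.coco import COCO
from datasets.refexp_eval import RefExpEvaluator

refexp_gt = COCO("data/coco/refcoco/instances_refcoco_val.json")   # hypothetical path
evaluator = RefExpEvaluator(refexp_gt, iou_types=("bbox",), k=(1, 5, 10), thresh_iou=0.5)

# One prediction entry per image id: descending-score xyxy boxes.
predictions = {
    image_id: {
        "scores": torch.tensor([0.9, 0.4]),
        "boxes": torch.tensor([[10.0, 20.0, 110.0, 220.0],
                               [15.0, 25.0, 90.0, 200.0]]),
    }
    for image_id in refexp_gt.imgs.keys()
}
evaluator.update(predictions)
evaluator.synchronize_between_processes()   # gathers predictions from all ranks
results = evaluator.summarize()             # precision @ 1, 5, 10 per dataset (main process only)
```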
86 | -------------------------------------------------------------------------------- /datasets/samplers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from codes in torch.utils.data.distributed 7 | # ------------------------------------------------------------------------ 8 | 9 | import os 10 | import math 11 | import torch 12 | import torch.distributed as dist 13 | from torch.utils.data.sampler import Sampler 14 | 15 | 16 | class DistributedSampler(Sampler): 17 | """Sampler that restricts data loading to a subset of the dataset. 18 | It is especially useful in conjunction with 19 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 20 | process can pass a DistributedSampler instance as a DataLoader sampler, 21 | and load a subset of the original dataset that is exclusive to it. 22 | .. note:: 23 | Dataset is assumed to be of constant size. 24 | Arguments: 25 | dataset: Dataset used for sampling. 26 | num_replicas (optional): Number of processes participating in 27 | distributed training. 28 | rank (optional): Rank of the current process within num_replicas. 29 | """ 30 | 31 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): 32 | if num_replicas is None: 33 | if not dist.is_available(): 34 | raise RuntimeError("Requires distributed package to be available") 35 | num_replicas = dist.get_world_size() 36 | if rank is None: 37 | if not dist.is_available(): 38 | raise RuntimeError("Requires distributed package to be available") 39 | rank = dist.get_rank() 40 | self.dataset = dataset 41 | self.num_replicas = num_replicas 42 | self.rank = rank 43 | self.epoch = 0 44 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 45 | self.total_size = self.num_samples * self.num_replicas 46 | self.shuffle = shuffle 47 | 48 | def __iter__(self): 49 | if self.shuffle: 50 | # deterministically shuffle based on epoch 51 | g = torch.Generator() 52 | g.manual_seed(self.epoch) 53 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 54 | else: 55 | indices = torch.arange(len(self.dataset)).tolist() 56 | 57 | # add extra samples to make it evenly divisible 58 | indices += indices[: (self.total_size - len(indices))] 59 | assert len(indices) == self.total_size 60 | 61 | # subsample 62 | offset = self.num_samples * self.rank 63 | indices = indices[offset : offset + self.num_samples] 64 | assert len(indices) == self.num_samples 65 | 66 | return iter(indices) 67 | 68 | def __len__(self): 69 | return self.num_samples 70 | 71 | def set_epoch(self, epoch): 72 | self.epoch = epoch 73 | 74 | 75 | class NodeDistributedSampler(Sampler): 76 | """Sampler that restricts data loading to a subset of the dataset. 77 | It is especially useful in conjunction with 78 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 79 | process can pass a DistributedSampler instance as a DataLoader sampler, 80 | and load a subset of the original dataset that is exclusive to it. 81 | .. note:: 82 | Dataset is assumed to be of constant size. 83 | Arguments: 84 | dataset: Dataset used for sampling. 
85 | num_replicas (optional): Number of processes participating in 86 | distributed training. 87 | rank (optional): Rank of the current process within num_replicas. 88 | """ 89 | 90 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): 91 | if num_replicas is None: 92 | if not dist.is_available(): 93 | raise RuntimeError("Requires distributed package to be available") 94 | num_replicas = dist.get_world_size() 95 | if rank is None: 96 | if not dist.is_available(): 97 | raise RuntimeError("Requires distributed package to be available") 98 | rank = dist.get_rank() 99 | if local_rank is None: 100 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 101 | if local_size is None: 102 | local_size = int(os.environ.get('LOCAL_SIZE', 1)) 103 | self.dataset = dataset 104 | self.shuffle = shuffle 105 | self.num_replicas = num_replicas 106 | self.num_parts = local_size 107 | self.rank = rank 108 | self.local_rank = local_rank 109 | self.epoch = 0 110 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 111 | self.total_size = self.num_samples * self.num_replicas 112 | 113 | self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts 114 | 115 | def __iter__(self): 116 | if self.shuffle: 117 | # deterministically shuffle based on epoch 118 | g = torch.Generator() 119 | g.manual_seed(self.epoch) 120 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 121 | else: 122 | indices = torch.arange(len(self.dataset)).tolist() 123 | indices = [i for i in indices if i % self.num_parts == self.local_rank] 124 | 125 | # add extra samples to make it evenly divisible 126 | indices += indices[:(self.total_size_parts - len(indices))] 127 | assert len(indices) == self.total_size_parts 128 | 129 | # subsample 130 | indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts] 131 | assert len(indices) == self.num_samples 132 | 133 | return iter(indices) 134 | 135 | def __len__(self): 136 | return self.num_samples 137 | 138 | def set_epoch(self, epoch): 139 | self.epoch = epoch 140 | -------------------------------------------------------------------------------- /davis2017/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | __version__ = '0.1.0' 4 | -------------------------------------------------------------------------------- /davis2017/davis.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from collections import defaultdict 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | class DAVIS(object): 9 | SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge'] 10 | TASKS = ['semi-supervised', 'unsupervised'] 11 | DATASET_WEB = 'https://davischallenge.org/davis2017/code.html' 12 | VOID_LABEL = 255 13 | 14 | def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False): 15 | """ 16 | Class to read the DAVIS dataset 17 | :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. 18 | :param task: Task to load the annotations, choose between semi-supervised or unsupervised. 19 | :param subset: Set to load the annotations 20 | :param sequences: Sequences to consider, 'all' to use all the sequences in a set. 
21 | :param resolution: Specify the resolution to use the dataset, choose between '480' and 'Full-Resolution' 22 | """ 23 | if subset not in self.SUBSET_OPTIONS: 24 | raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}') 25 | if task not in self.TASKS: 26 | raise ValueError(f'The only tasks that are supported are {self.TASKS}') 27 | 28 | self.task = task 29 | self.subset = subset 30 | self.root = root 31 | self.img_path = os.path.join(self.root, 'JPEGImages', resolution) 32 | annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised' 33 | self.mask_path = os.path.join(self.root, annotations_folder, resolution) 34 | year = '2019' if task == 'unsupervised' and (subset == 'test-dev' or subset == 'test-challenge') else '2017' 35 | self.imagesets_path = os.path.join(self.root, 'ImageSets', year) 36 | 37 | self._check_directories() 38 | 39 | if sequences == 'all': 40 | with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f: 41 | tmp = f.readlines() 42 | sequences_names = [x.strip() for x in tmp] 43 | else: 44 | sequences_names = sequences if isinstance(sequences, list) else [sequences] 45 | self.sequences = defaultdict(dict) 46 | 47 | for seq in sequences_names: 48 | images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist() 49 | if len(images) == 0 and not codalab: 50 | raise FileNotFoundError(f'Images for sequence {seq} not found.') 51 | self.sequences[seq]['images'] = images 52 | masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist() 53 | masks.extend([-1] * (len(images) - len(masks))) 54 | self.sequences[seq]['masks'] = masks 55 | 56 | def _check_directories(self): 57 | if not os.path.exists(self.root): 58 | raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}') 59 | if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')): 60 | raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset ' 61 | f'for the {self.task} task from {self.DATASET_WEB}') 62 | if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path): 63 | raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}') 64 | 65 | def get_frames(self, sequence): 66 | for img, msk in zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']): 67 | image = np.array(Image.open(img)) 68 | mask = None if msk is None else np.array(Image.open(msk)) 69 | yield image, mask 70 | 71 | def _get_all_elements(self, sequence, obj_type): 72 | obj = np.array(Image.open(self.sequences[sequence][obj_type][0])) 73 | all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape)) 74 | obj_id = [] 75 | for i, obj in enumerate(self.sequences[sequence][obj_type]): 76 | all_objs[i, ...] = np.array(Image.open(obj)) 77 | obj_id.append(''.join(obj.split('/')[-1].split('.')[:-1])) 78 | return all_objs, obj_id 79 | 80 | def get_all_images(self, sequence): 81 | return self._get_all_elements(sequence, 'images') 82 | 83 | def get_all_masks(self, sequence, separate_objects_masks=False): 84 | masks, masks_id = self._get_all_elements(sequence, 'masks') 85 | masks_void = np.zeros_like(masks) 86 | 87 | # Separate void and object masks 88 | for i in range(masks.shape[0]): 89 | masks_void[i, ...] = masks[i, ...] == 255 90 | masks[i, masks[i, ...] 
== 255] = 0 91 | 92 | if separate_objects_masks: 93 | num_objects = int(np.max(masks[0, ...])) 94 | tmp = np.ones((num_objects, *masks.shape)) 95 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] 96 | masks = (tmp == masks[None, ...]) 97 | masks = masks > 0 98 | return masks, masks_void, masks_id 99 | 100 | def get_sequences(self): 101 | for seq in self.sequences: 102 | yield seq 103 | 104 | 105 | if __name__ == '__main__': 106 | from matplotlib import pyplot as plt 107 | 108 | only_first_frame = True 109 | subsets = ['train', 'val'] 110 | 111 | for s in subsets: 112 | dataset = DAVIS(root='/home/csergi/scratch2/Databases/DAVIS2017_private', subset=s) 113 | for seq in dataset.get_sequences(): 114 | g = dataset.get_frames(seq) 115 | img, mask = next(g) 116 | plt.subplot(2, 1, 1) 117 | plt.title(seq) 118 | plt.imshow(img) 119 | plt.subplot(2, 1, 2) 120 | plt.imshow(mask) 121 | plt.show(block=True) 122 | 123 | -------------------------------------------------------------------------------- /davis2017/evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from tqdm import tqdm 3 | import warnings 4 | warnings.filterwarnings("ignore", category=RuntimeWarning) 5 | 6 | import numpy as np 7 | from davis2017.davis import DAVIS 8 | from davis2017.metrics import db_eval_boundary, db_eval_iou 9 | from davis2017 import utils 10 | from davis2017.results import Results 11 | from scipy.optimize import linear_sum_assignment 12 | 13 | 14 | class DAVISEvaluation(object): 15 | def __init__(self, davis_root, task, gt_set, sequences='all', codalab=False): 16 | """ 17 | Class to evaluate DAVIS sequences from a certain set and for a certain task 18 | :param davis_root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. 19 | :param task: Task to compute the evaluation, chose between semi-supervised or unsupervised. 20 | :param gt_set: Set to compute the evaluation 21 | :param sequences: Sequences to consider for the evaluation, 'all' to use all the sequences in a set. 
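:param codalab: If True, sequences with missing JPEG frames do not raise an error (useful when only the annotations are available, e.g. on the evaluation server).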
22 | """ 23 | self.davis_root = davis_root 24 | self.task = task 25 | self.dataset = DAVIS(root=davis_root, task=task, subset=gt_set, sequences=sequences, codalab=codalab) 26 | 27 | @staticmethod 28 | def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric): 29 | if all_res_masks.shape[0] > all_gt_masks.shape[0]: 30 | sys.stdout.write("\nIn your PNG files there is an index higher than the number of objects in the sequence!") 31 | sys.exit() 32 | elif all_res_masks.shape[0] < all_gt_masks.shape[0]: 33 | zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) 34 | all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) 35 | j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2]) 36 | for ii in range(all_gt_masks.shape[0]): 37 | if 'J' in metric: 38 | j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) 39 | if 'F' in metric: 40 | f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) 41 | return j_metrics_res, f_metrics_res 42 | 43 | @staticmethod 44 | def _evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric, max_n_proposals=20): 45 | if all_res_masks.shape[0] > max_n_proposals: 46 | sys.stdout.write(f"\nIn your PNG files there is an index higher than the maximum number ({max_n_proposals}) of proposals allowed!") 47 | sys.exit() 48 | elif all_res_masks.shape[0] < all_gt_masks.shape[0]: 49 | zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) 50 | all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) 51 | j_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1])) 52 | f_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1])) 53 | for ii in range(all_gt_masks.shape[0]): 54 | for jj in range(all_res_masks.shape[0]): 55 | if 'J' in metric: 56 | j_metrics_res[jj, ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks) 57 | if 'F' in metric: 58 | f_metrics_res[jj, ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks) 59 | if 'J' in metric and 'F' in metric: 60 | all_metrics = (np.mean(j_metrics_res, axis=2) + np.mean(f_metrics_res, axis=2)) / 2 61 | else: 62 | all_metrics = np.mean(j_metrics_res, axis=2) if 'J' in metric else np.mean(f_metrics_res, axis=2) 63 | row_ind, col_ind = linear_sum_assignment(-all_metrics) 64 | return j_metrics_res[row_ind, col_ind, :], f_metrics_res[row_ind, col_ind, :] 65 | 66 | def evaluate(self, res_path, metric=('J', 'F'), debug=False): 67 | metric = metric if isinstance(metric, tuple) or isinstance(metric, list) else [metric] 68 | if 'T' in metric: 69 | raise ValueError('Temporal metric not supported!') 70 | if 'J' not in metric and 'F' not in metric: 71 | raise ValueError('Metric possible values are J for IoU or F for Boundary') 72 | 73 | # Containers 74 | metrics_res = {} 75 | if 'J' in metric: 76 | metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}} 77 | if 'F' in metric: 78 | metrics_res['F'] = {"M": [], "R": [], "D": [], "M_per_object": {}} 79 | 80 | # Sweep all sequences 81 | results = Results(root_dir=res_path) 82 | for seq in tqdm(list(self.dataset.get_sequences())): 83 | all_gt_masks, all_void_masks, all_masks_id = self.dataset.get_all_masks(seq, True) 84 | if self.task == 'semi-supervised': 85 | 
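# DAVIS semi-supervised protocol: the first frame (whose ground-truth mask is given to the method) and the last frame are not scored, hence the [:, 1:-1] slice over the frame axis below.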
all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] 86 | all_res_masks = results.read_masks(seq, all_masks_id) 87 | if self.task == 'unsupervised': 88 | j_metrics_res, f_metrics_res = self._evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric) 89 | elif self.task == 'semi-supervised': 90 | j_metrics_res, f_metrics_res = self._evaluate_semisupervised(all_gt_masks, all_res_masks, None, metric) 91 | for ii in range(all_gt_masks.shape[0]): 92 | seq_name = f'{seq}_{ii+1}' 93 | if 'J' in metric: 94 | [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii]) 95 | metrics_res['J']["M"].append(JM) 96 | metrics_res['J']["R"].append(JR) 97 | metrics_res['J']["D"].append(JD) 98 | metrics_res['J']["M_per_object"][seq_name] = JM 99 | if 'F' in metric: 100 | [FM, FR, FD] = utils.db_statistics(f_metrics_res[ii]) 101 | metrics_res['F']["M"].append(FM) 102 | metrics_res['F']["R"].append(FR) 103 | metrics_res['F']["D"].append(FD) 104 | metrics_res['F']["M_per_object"][seq_name] = FM 105 | 106 | # Show progress 107 | if debug: 108 | sys.stdout.write(seq + '\n') 109 | sys.stdout.flush() 110 | return metrics_res 111 | -------------------------------------------------------------------------------- /davis2017/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def db_eval_iou(annotation, segmentation, void_pixels=None): 7 | """ Compute region similarity as the Jaccard Index. 8 | Arguments: 9 | annotation (ndarray): binary annotation map. 10 | segmentation (ndarray): binary segmentation map. 11 | void_pixels (ndarray): optional mask with void pixels 12 | 13 | Return: 14 | jaccard (float): region similarity 15 | """ 16 | assert annotation.shape == segmentation.shape, \ 17 | f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.' 18 | annotation = annotation.astype(np.bool) 19 | segmentation = segmentation.astype(np.bool) 20 | 21 | if void_pixels is not None: 22 | assert annotation.shape == void_pixels.shape, \ 23 | f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.' 
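# Note: np.bool (used below) was deprecated in NumPy 1.20 and removed in 1.24; with a recent NumPy, the builtin bool (or np.bool_) is the drop-in replacement.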
24 | void_pixels = void_pixels.astype(np.bool) 25 | else: 26 | void_pixels = np.zeros_like(segmentation) 27 | 28 | # Intersection between all sets 29 | inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 30 | union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 31 | 32 | j = inters / union 33 | if j.ndim == 0: 34 | j = 1 if np.isclose(union, 0) else j 35 | else: 36 | j[np.isclose(union, 0)] = 1 37 | return j 38 | 39 | 40 | def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008): 41 | assert annotation.shape == segmentation.shape 42 | if void_pixels is not None: 43 | assert annotation.shape == void_pixels.shape 44 | if annotation.ndim == 3: 45 | n_frames = annotation.shape[0] 46 | f_res = np.zeros(n_frames) 47 | for frame_id in range(n_frames): 48 | void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ] 49 | f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th) 50 | elif annotation.ndim == 2: 51 | f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th) 52 | else: 53 | raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions') 54 | return f_res 55 | 56 | 57 | def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008): 58 | """ 59 | Compute mean,recall and decay from per-frame evaluation. 60 | Calculates precision/recall for boundaries between foreground_mask and 61 | gt_mask using morphological operators to speed it up. 62 | 63 | Arguments: 64 | foreground_mask (ndarray): binary segmentation image. 65 | gt_mask (ndarray): binary annotated image. 66 | void_pixels (ndarray): optional mask with void pixels 67 | 68 | Returns: 69 | F (float): boundaries F-measure 70 | """ 71 | assert np.atleast_3d(foreground_mask).shape[2] == 1 72 | if void_pixels is not None: 73 | void_pixels = void_pixels.astype(np.bool) 74 | else: 75 | void_pixels = np.zeros_like(foreground_mask).astype(np.bool) 76 | 77 | bound_pix = bound_th if bound_th >= 1 else \ 78 | np.ceil(bound_th * np.linalg.norm(foreground_mask.shape)) 79 | 80 | # Get the pixel boundaries of both masks 81 | fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels)) 82 | gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels)) 83 | 84 | from skimage.morphology import disk 85 | 86 | # fg_dil = binary_dilation(fg_boundary, disk(bound_pix)) 87 | fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 88 | # gt_dil = binary_dilation(gt_boundary, disk(bound_pix)) 89 | gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 90 | 91 | # Get the intersection 92 | gt_match = gt_boundary * fg_dil 93 | fg_match = fg_boundary * gt_dil 94 | 95 | # Area of the intersection 96 | n_fg = np.sum(fg_boundary) 97 | n_gt = np.sum(gt_boundary) 98 | 99 | # % Compute precision and recall 100 | if n_fg == 0 and n_gt > 0: 101 | precision = 1 102 | recall = 0 103 | elif n_fg > 0 and n_gt == 0: 104 | precision = 0 105 | recall = 1 106 | elif n_fg == 0 and n_gt == 0: 107 | precision = 1 108 | recall = 1 109 | else: 110 | precision = np.sum(fg_match) / float(n_fg) 111 | recall = np.sum(gt_match) / float(n_gt) 112 | 113 | # Compute F measure 114 | if precision + recall == 0: 115 | F = 0 116 | else: 117 | F = 2 * precision * recall / (precision + recall) 118 | 119 | return F 120 | 121 | 122 | def _seg2bmap(seg, width=None, height=None): 123 | 
""" 124 | From a segmentation, compute a binary boundary map with 1 pixel wide 125 | boundaries. The boundary pixels are offset by 1/2 pixel towards the 126 | origin from the actual segment boundary. 127 | Arguments: 128 | seg : Segments labeled from 1..k. 129 | width : Width of desired bmap <= seg.shape[1] 130 | height : Height of desired bmap <= seg.shape[0] 131 | Returns: 132 | bmap (ndarray): Binary boundary map. 133 | David Martin 134 | January 2003 135 | """ 136 | 137 | seg = seg.astype(np.bool) 138 | seg[seg > 0] = 1 139 | 140 | assert np.atleast_3d(seg).shape[2] == 1 141 | 142 | width = seg.shape[1] if width is None else width 143 | height = seg.shape[0] if height is None else height 144 | 145 | h, w = seg.shape[:2] 146 | 147 | ar1 = float(width) / float(height) 148 | ar2 = float(w) / float(h) 149 | 150 | assert not ( 151 | width > w | height > h | abs(ar1 - ar2) > 0.01 152 | ), "Can" "t convert %dx%d seg to %dx%d bmap." % (w, h, width, height) 153 | 154 | e = np.zeros_like(seg) 155 | s = np.zeros_like(seg) 156 | se = np.zeros_like(seg) 157 | 158 | e[:, :-1] = seg[:, 1:] 159 | s[:-1, :] = seg[1:, :] 160 | se[:-1, :-1] = seg[1:, 1:] 161 | 162 | b = seg ^ e | seg ^ s | seg ^ se 163 | b[-1, :] = seg[-1, :] ^ e[-1, :] 164 | b[:, -1] = seg[:, -1] ^ s[:, -1] 165 | b[-1, -1] = 0 166 | 167 | if w == width and h == height: 168 | bmap = b 169 | else: 170 | bmap = np.zeros((height, width)) 171 | for x in range(w): 172 | for y in range(h): 173 | if b[y, x]: 174 | j = 1 + math.floor((y - 1) + height / h) 175 | i = 1 + math.floor((x - 1) + width / h) 176 | bmap[j, i] = 1 177 | 178 | return bmap 179 | 180 | 181 | if __name__ == '__main__': 182 | from davis2017.davis import DAVIS 183 | from davis2017.results import Results 184 | 185 | dataset = DAVIS(root='input_dir/ref', subset='val', sequences='aerobatics') 186 | results = Results(root_dir='examples/osvos') 187 | # Test timing F measure 188 | for seq in dataset.get_sequences(): 189 | all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True) 190 | all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] 191 | all_res_masks = results.read_masks(seq, all_masks_id) 192 | f_metrics_res = np.zeros(all_gt_masks.shape[:2]) 193 | for ii in range(all_gt_masks.shape[0]): 194 | f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...]) 195 | 196 | # Run using to profile code: python -m cProfile -o f_measure.prof metrics.py 197 | # snakeviz f_measure.prof 198 | -------------------------------------------------------------------------------- /davis2017/results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import sys 5 | 6 | 7 | class Results(object): 8 | def __init__(self, root_dir): 9 | self.root_dir = root_dir 10 | 11 | def _read_mask(self, sequence, frame_id): 12 | try: 13 | mask_path = os.path.join(self.root_dir, sequence, f'{frame_id}.png') 14 | return np.array(Image.open(mask_path)) 15 | except IOError as err: 16 | sys.stdout.write(sequence + " frame %s not found!\n" % frame_id) 17 | sys.stdout.write("The frames have to be indexed PNG files placed inside the corespondent sequence " 18 | "folder.\nThe indexes have to match with the initial frame.\n") 19 | sys.stderr.write("IOError: " + err.strerror + "\n") 20 | sys.exit() 21 | 22 | def read_masks(self, sequence, masks_id): 23 | mask_0 = self._read_mask(sequence, masks_id[0]) 24 | masks = np.zeros((len(masks_id), *mask_0.shape)) 25 | for ii, m in 
enumerate(masks_id): 26 | masks[ii, ...] = self._read_mask(sequence, m) 27 | num_objects = int(np.max(masks)) 28 | tmp = np.ones((num_objects, *masks.shape)) 29 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] 30 | masks = (tmp == masks[None, ...]) > 0 31 | return masks 32 | -------------------------------------------------------------------------------- /davis2017/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from PIL import Image 5 | import warnings 6 | from davis2017.davis import DAVIS 7 | 8 | 9 | def _pascal_color_map(N=256, normalized=False): 10 | """ 11 | Python implementation of the color map function for the PASCAL VOC data set. 12 | Official Matlab version can be found in the PASCAL VOC devkit 13 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 14 | """ 15 | 16 | def bitget(byteval, idx): 17 | return (byteval & (1 << idx)) != 0 18 | 19 | dtype = 'float32' if normalized else 'uint8' 20 | cmap = np.zeros((N, 3), dtype=dtype) 21 | for i in range(N): 22 | r = g = b = 0 23 | c = i 24 | for j in range(8): 25 | r = r | (bitget(c, 0) << 7 - j) 26 | g = g | (bitget(c, 1) << 7 - j) 27 | b = b | (bitget(c, 2) << 7 - j) 28 | c = c >> 3 29 | 30 | cmap[i] = np.array([r, g, b]) 31 | 32 | cmap = cmap / 255 if normalized else cmap 33 | return cmap 34 | 35 | 36 | def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): 37 | im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=np.int) 38 | if im.shape[:-1] != ann.shape: 39 | raise ValueError('First two dimensions of `im` and `ann` must match') 40 | if im.shape[-1] != 3: 41 | raise ValueError('im must have three channels at the 3 dimension') 42 | 43 | colors = colors or _pascal_color_map() 44 | colors = np.asarray(colors, dtype=np.uint8) 45 | 46 | mask = colors[ann] 47 | fg = im * alpha + (1 - alpha) * mask 48 | 49 | img = im.copy() 50 | img[ann > 0] = fg[ann > 0] 51 | 52 | if contour_thickness: # pragma: no cover 53 | import cv2 54 | for obj_id in np.unique(ann[ann > 0]): 55 | contours = cv2.findContours((ann == obj_id).astype( 56 | np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 57 | cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), 58 | contour_thickness) 59 | return img 60 | 61 | 62 | def generate_obj_proposals(davis_root, subset, num_proposals, save_path): 63 | dataset = DAVIS(davis_root, subset=subset, codalab=True) 64 | for seq in dataset.get_sequences(): 65 | save_dir = os.path.join(save_path, seq) 66 | if os.path.exists(save_dir): 67 | continue 68 | all_gt_masks, all_masks_id = dataset.get_all_masks(seq, True) 69 | img_size = all_gt_masks.shape[2:] 70 | num_rows = int(np.ceil(np.sqrt(num_proposals))) 71 | proposals = np.zeros((num_proposals, len(all_masks_id), *img_size)) 72 | height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist() 73 | width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist() 74 | ii = 0 75 | prev_h, prev_w = 0, 0 76 | for h in height_slices[1:]: 77 | for w in width_slices[1:]: 78 | proposals[ii, :, prev_h:h, prev_w:w] = 1 79 | prev_w = w 80 | ii += 1 81 | if ii == num_proposals: 82 | break 83 | prev_h, prev_w = h, 0 84 | if ii == num_proposals: 85 | break 86 | 87 | os.makedirs(save_dir, exist_ok=True) 88 | for i, mask_id in enumerate(all_masks_id): 89 | mask = np.sum(proposals[:, i, ...] 
* np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0) 90 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 91 | 92 | 93 | def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path): 94 | dataset = DAVIS(davis_root, subset=subset, codalab=True) 95 | for seq in dataset.get_sequences(): 96 | gt_masks, all_masks_id = dataset.get_all_masks(seq, True) 97 | obj_swap = np.random.permutation(np.arange(gt_masks.shape[0])) 98 | gt_masks = gt_masks[obj_swap, ...] 99 | save_dir = os.path.join(save_path, seq) 100 | os.makedirs(save_dir, exist_ok=True) 101 | for i, mask_id in enumerate(all_masks_id): 102 | mask = np.sum(gt_masks[:, i, ...] * np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0) 103 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 104 | 105 | 106 | def color_map(N=256, normalized=False): 107 | def bitget(byteval, idx): 108 | return ((byteval & (1 << idx)) != 0) 109 | 110 | dtype = 'float32' if normalized else 'uint8' 111 | cmap = np.zeros((N, 3), dtype=dtype) 112 | for i in range(N): 113 | r = g = b = 0 114 | c = i 115 | for j in range(8): 116 | r = r | (bitget(c, 0) << 7-j) 117 | g = g | (bitget(c, 1) << 7-j) 118 | b = b | (bitget(c, 2) << 7-j) 119 | c = c >> 3 120 | 121 | cmap[i] = np.array([r, g, b]) 122 | 123 | cmap = cmap/255 if normalized else cmap 124 | return cmap 125 | 126 | 127 | def save_mask(mask, img_path): 128 | if np.max(mask) > 255: 129 | raise ValueError('Maximum id pixel value is 255') 130 | mask_img = Image.fromarray(mask.astype(np.uint8)) 131 | mask_img.putpalette(color_map().flatten().tolist()) 132 | mask_img.save(img_path) 133 | 134 | 135 | def db_statistics(per_frame_values): 136 | """ Compute mean,recall and decay from per-frame evaluation. 137 | Arguments: 138 | per_frame_values (ndarray): per-frame evaluation 139 | 140 | Returns: 141 | M,O,D (float,float,float): 142 | return evaluation statistics: mean,recall,decay. 143 | """ 144 | 145 | # strip off nan values 146 | with warnings.catch_warnings(): 147 | warnings.simplefilter("ignore", category=RuntimeWarning) 148 | M = np.nanmean(per_frame_values) 149 | O = np.nanmean(per_frame_values > 0.5) 150 | 151 | N_bins = 4 152 | ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1 153 | ids = ids.astype(np.uint8) 154 | 155 | D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)] 156 | 157 | with warnings.catch_warnings(): 158 | warnings.simplefilter("ignore", category=RuntimeWarning) 159 | D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3]) 160 | 161 | return M, O, D 162 | 163 | 164 | def list_files(dir, extension=".png"): 165 | return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)] 166 | 167 | 168 | def force_symlink(file1, file2): 169 | try: 170 | os.symlink(file1, file2) 171 | except OSError as e: 172 | if e.errno == errno.EEXIST: 173 | os.remove(file2) 174 | os.symlink(file1, file2) 175 | -------------------------------------------------------------------------------- /docs/data.md: -------------------------------------------------------------------------------- 1 | # Data Preparation 2 | 3 | After the organization, we expect the directory structure to be the following: 4 | 5 | Notice: `SgMg` and `datasets` are in the same folder. 
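If the data is stored elsewhere, a symbolic link is one way to reproduce this layout without copying files (run from inside the `SgMg` directory; the source path below is only a placeholder):

```
ln -s /path/to/your/datasets ../datasets
```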
6 | 7 | ``` 8 | ├── SgMg 9 | ├── datasets 10 | │ ├── coco 11 | │ ├── train2014 12 | │ ├── refcoco 13 | │ ├── instances_refcoco_train.json 14 | │ ├── instances_refcoco_val.json 15 | │ ├── refcoco+ 16 | │ ├── instances_refcoco+_train.json 17 | │ ├── instances_refcoco+_val.json 18 | │ ├── refcocog 19 | │ ├── instances_refcocog_train.json 20 | │ ├── instances_refcocog_val.json 21 | │ ├── refer_youtube_vos 22 | │ ├── meta_expressions 23 | │ ├── train 24 | │ ├── JPEGImages 25 | │ ├── Annotations 26 | │ ├── meta.json 27 | │ ├── valid 28 | │ ├── JPEGImages 29 | │ ├── refer_davis 30 | │ ├── meta_expressions 31 | │ ├── valid 32 | │ ├── JPEGImages 33 | │ ├── 480p 34 | │ ├── Annotations 35 | │ ├── ImageSets 36 | │ ├── meta.json 37 | │ ├── a2d_sentences 38 | │ ├── Release 39 | │ ├── text_annotations 40 | │ ├── a2d_annotation_with_instances 41 | │ ├── a2d_annotation.txt 42 | │ ├── a2d_missed_videos.txt 43 | │ ├── a2d_sentences_single_frame_test_annotations.json 44 | │ ├── a2d_sentences_single_frame_train_annotations.json 45 | │ ├── a2d_sentences_test_annotations_in_coco_format.json 46 | │ ├── jhmdb_sentences 47 | │ ├── Rename_Images 48 | │ ├── puppet_mask 49 | │ ├── jhmdb_annotation.txt 50 | │ ├── jhmdb_sentences_samples_metadata.json 51 | │ ├── jhmdb_sentences_gt_annotations_in_coco_format.json 52 | ... 53 | ``` 54 | 55 | ## Ref-COCO 56 | 57 | Download the dataset from the official website [COCO](https://cocodataset.org/#download). 58 | RefCOCO/+/g use the COCO2014 train split. 59 | Download the annotation files from [github](https://github.com/lichengunc/refer). 60 | 61 | Convert the annotation files: 62 | 63 | ``` 64 | python3 tools/data/convert_refexp_to_coco.py 65 | ``` 66 | 67 | Finally, we expect the directory structure to be the following: 68 | 69 | ``` 70 | ├── datasets 71 | │ ├── coco 72 | │ ├── train2014 73 | │ ├── refcoco 74 | │ ├── instances_refcoco_train.json 75 | │ ├── instances_refcoco_val.json 76 | │ ├── refcoco+ 77 | │ ├── instances_refcoco+_train.json 78 | │ ├── instances_refcoco+_val.json 79 | │ ├── refcocog 80 | │ ├── instances_refcocog_train.json 81 | │ ├── instances_refcocog_val.json 82 | ``` 83 | 84 | ## refer_youtube_vos 85 | 86 | Download the dataset from the competition's website [here](https://competitions.codalab.org/competitions/29139#participate-get_data). 87 | Then, extract and organize the file. We expect the directory structure to be the following: 88 | 89 | ``` 90 | ├── datasets 91 | │ ├── refer_youtube_vos 92 | │ ├── meta_expressions 93 | │ ├── train 94 | │ ├── JPEGImages 95 | │ ├── Annotations 96 | │ ├── meta.json 97 | │ ├── valid 98 | │ ├── JPEGImages 99 | ``` 100 | 101 | ## refer_davis17 102 | 103 | **Notice: We recommend to directly download the parsed Ref-DAVIS dataset from the** [Google Drive](https://drive.google.com/file/d/1W0RsdxMK3VkNL80H1OWNmia-2asdCyYF/view?usp=sharing) **to avoid the following steps.** 104 | 105 | Download the DAVIS2017 dataset from the [website](https://davischallenge.org/davis2017/code.html). Note that you only need to download the two zip files `DAVIS-2017-Unsupervised-trainval-480p.zip` and `DAVIS-2017_semantics-480p.zip`. 106 | Download the text annotations from the [website](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/video-segmentation/video-object-segmentation-with-language-referring-expressions). 107 | Then, put the zip files in the directory as follows. 
108 | 109 | ``` 110 | ├── datasets 111 | │ ├── refer_davis 112 | │ │ ├── DAVIS-2017_semantics-480p.zip 113 | │ │ ├── DAVIS-2017-Unsupervised-trainval-480p.zip 114 | │ │ ├── davis_text_annotations.zip 115 | ``` 116 | 117 | Unzip these zip files. 118 | ``` 119 | unzip -o davis_text_annotations.zip 120 | unzip -o DAVIS-2017_semantics-480p.zip 121 | unzip -o DAVIS-2017-Unsupervised-trainval-480p.zip 122 | ``` 123 | 124 | Preprocess the dataset to refer_youtube_vos format. (Make sure you are in the main directory) 125 | 126 | ``` 127 | python tools/data/convert_davis_to_ytvos.py 128 | ``` 129 | 130 | Finally, unzip the file `DAVIS-2017-Unsupervised-trainval-480p.zip` again (since we use `mv` in preprocess for efficiency). 131 | 132 | ``` 133 | unzip -o DAVIS-2017-Unsupervised-trainval-480p.zip 134 | ``` 135 | 136 | ## A2D-Sentences 137 | 138 | Follow the instructions and download the dataset from the website [here](https://kgavrilyuk.github.io/publication/actor_action/). 139 | Then, extract the files. Additionally, we use the same json annotation files generated by [MTTR](https://github.com/mttr2021/MTTR). Please download these files from [google drive](https://drive.google.com/drive/u/0/folders/1daTuACcZUKuzgl0iqzwCfKm_tSISarFl). 140 | We expect the directory structure to be the following: 141 | 142 | ``` 143 | ├── datasets 144 | │ ├── a2d_sentences 145 | │ │ ├── Release 146 | │ │ ├── text_annotations 147 | │ │ │ ├── a2d_annotation_with_instances 148 | │ │ │ ├── a2d_annotation.txt 149 | │ │ │ ├── a2d_missed_videos.txt 150 | │ │ ├── a2d_sentences_single_frame_test_annotations.json 151 | │ │ ├── a2d_sentences_single_frame_train_annotations.json 152 | │ │ ├── a2d_sentences_test_annotations_in_coco_format.json 153 | ``` 154 | 155 | ## JHMDB-Sentences 156 | 157 | Follow the instructions and download the dataset from the website [here](https://kgavrilyuk.github.io/publication/actor_action/). 158 | Then, extract the files. Additionally, we use the same json annotation files generated by [MTTR](https://github.com/mttr2021/MTTR). Please download these files from [google drive](https://drive.google.com/drive/u/0/folders/1sXmjpWmc0GxYIz-EFLw5S9dJvmGJAPqx). 159 | We expect the directory structure to be the following: 160 | 161 | ``` 162 | ├── datasets 163 | │ ├── jhmdb_sentences 164 | │ │ ├── Rename_Images 165 | │ │ ├── puppet_mask 166 | │ │ ├── jhmdb_annotation.txt 167 | │ │ ├── jhmdb_sentences_samples_metadata.json 168 | │ │ ├── jhmdb_sentences_gt_annotations_in_coco_format.json 169 | ``` 170 | -------------------------------------------------------------------------------- /docs/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-miao/SgMg/90fd3c476858218b1ed0c21ec28e64e762ca2c84/docs/framework.png -------------------------------------------------------------------------------- /docs/install.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | We provide the instructions to install the dependency packages. 4 | 5 | 6 | ## Setup 7 | 8 | First, clone the repository locally. 9 | 10 | ``` 11 | git clone https://github.com/bo-miao/SgMg 12 | ``` 13 | 14 | Then, install Pytorch==1.11.0 (CUDA 11.3) torchvision==0.12.0 and the necessary packages as well as pycocotools. 
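For example, the matching CUDA 11.3 wheels can typically be installed with:

```
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
```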
15 | ``` 16 | pip install -r requirements.txt 17 | pip install 'git+https://github.com/facebookresearch/fvcore' 18 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI' 19 | ``` 20 | 21 | Finally, compile CUDA operators. 22 | ``` 23 | cd models/ops 24 | python setup.py build install 25 | cd ../.. 26 | ``` 27 | -------------------------------------------------------------------------------- /eval_davis.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | from time import time 5 | import argparse 6 | 7 | import numpy as np 8 | import pandas as pd 9 | from davis2017.evaluation import DAVISEvaluation 10 | 11 | default_davis_path = '../datasets/refer_davis/valid' 12 | 13 | time_start = time() 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--davis_path', type=str, help='Path to the DAVIS folder containing the JPEGImages, Annotations, ' 16 | 'ImageSets, Annotations_unsupervised folders', 17 | required=False, default=default_davis_path) 18 | parser.add_argument('--set', type=str, help='Subset to evaluate the results', default='val') # val subset 19 | parser.add_argument('--task', type=str, help='Task to evaluate the results', default='unsupervised', 20 | choices=['semi-supervised', 'unsupervised']) 21 | parser.add_argument('--results_path', type=str, help='Path to the folder containing the sequences folders', 22 | required=True) 23 | args, _ = parser.parse_known_args() 24 | csv_name_global = f'global_results-{args.set}.csv' 25 | csv_name_per_sequence = f'per-sequence_results-{args.set}.csv' 26 | 27 | # Check if the method has been evaluated before, if so read the results, otherwise compute the results 28 | csv_name_global_path = os.path.join(args.results_path, csv_name_global) 29 | csv_name_per_sequence_path = os.path.join(args.results_path, csv_name_per_sequence) 30 | if os.path.exists(csv_name_global_path) and os.path.exists(csv_name_per_sequence_path): 31 | print('Using precomputed results...') 32 | table_g = pd.read_csv(csv_name_global_path) 33 | table_seq = pd.read_csv(csv_name_per_sequence_path) 34 | else: 35 | print(f'Evaluating sequences for the {args.task} task...') 36 | # Create dataset and evaluate 37 | dataset_eval = DAVISEvaluation(davis_root=args.davis_path, task=args.task, gt_set=args.set) 38 | metrics_res = dataset_eval.evaluate(args.results_path) 39 | J, F = metrics_res['J'], metrics_res['F'] 40 | 41 | # Generate dataframe for the general results 42 | g_measures = ['J&F-Mean', 'J-Mean', 'J-Recall', 'J-Decay', 'F-Mean', 'F-Recall', 'F-Decay'] 43 | final_mean = (np.mean(J["M"]) + np.mean(F["M"])) / 2. 
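# J&F-Mean is the overall score: the average of the mean region similarity (J, Jaccard) and the mean contour accuracy (F); the R and D entries below are the corresponding recall and decay statistics.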
44 | g_res = np.array([final_mean, np.mean(J["M"]), np.mean(J["R"]), np.mean(J["D"]), np.mean(F["M"]), np.mean(F["R"]), 45 | np.mean(F["D"])]) 46 | g_res = np.reshape(g_res, [1, len(g_res)]) 47 | table_g = pd.DataFrame(data=g_res, columns=g_measures) 48 | with open(csv_name_global_path, 'w') as f: 49 | table_g.to_csv(f, index=False, float_format="%.5f") 50 | print(f'Global results saved in {csv_name_global_path}') 51 | 52 | # Generate a dataframe for the per sequence results 53 | seq_names = list(J['M_per_object'].keys()) 54 | seq_measures = ['Sequence', 'J-Mean', 'F-Mean'] 55 | J_per_object = [J['M_per_object'][x] for x in seq_names] 56 | F_per_object = [F['M_per_object'][x] for x in seq_names] 57 | table_seq = pd.DataFrame(data=list(zip(seq_names, J_per_object, F_per_object)), columns=seq_measures) 58 | with open(csv_name_per_sequence_path, 'w') as f: 59 | table_seq.to_csv(f, index=False, float_format="%.5f") 60 | print(f'Per-sequence results saved in {csv_name_per_sequence_path}') 61 | 62 | # Print the results 63 | sys.stdout.write(f"--------------------------- Global results for {args.set} ---------------------------\n") 64 | print(table_g.to_string(index=False)) 65 | sys.stdout.write(f"\n---------- Per sequence results for {args.set} ----------\n") 66 | print(table_seq.to_string(index=False)) 67 | total_time = time() - time_start 68 | sys.stdout.write('\nTotal time:' + str(total_time)) 69 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.sgmg import build 2 | 3 | def build_model(args): 4 | print("\n **** BUILD MODEL FOR SgMg. **** \n") 5 | return build(args) 6 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Backbone modules. 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | from einops import rearrange 14 | 15 | from util.misc import NestedTensor, is_main_process 16 | 17 | from .position_encoding import build_position_encoding 18 | 19 | 20 | class FrozenBatchNorm2d(torch.nn.Module): 21 | """ 22 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 23 | 24 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 25 | without which any other models than torchvision.models.resnet[18,34,50,101] 26 | produce nans. 
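Freezing the statistics and affine parameters is common practice when fine-tuning detection/segmentation backbones with small per-GPU batch sizes, where updating BatchNorm statistics would be unreliable.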
27 | """ 28 | 29 | def __init__(self, n): 30 | super(FrozenBatchNorm2d, self).__init__() 31 | self.register_buffer("weight", torch.ones(n)) 32 | self.register_buffer("bias", torch.zeros(n)) 33 | self.register_buffer("running_mean", torch.zeros(n)) 34 | self.register_buffer("running_var", torch.ones(n)) 35 | 36 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 37 | missing_keys, unexpected_keys, error_msgs): 38 | num_batches_tracked_key = prefix + 'num_batches_tracked' 39 | if num_batches_tracked_key in state_dict: 40 | del state_dict[num_batches_tracked_key] 41 | 42 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 43 | state_dict, prefix, local_metadata, strict, 44 | missing_keys, unexpected_keys, error_msgs) 45 | 46 | def forward(self, x): 47 | # move reshapes to the beginning 48 | # to make it fuser-friendly 49 | w = self.weight.reshape(1, -1, 1, 1) 50 | b = self.bias.reshape(1, -1, 1, 1) 51 | rv = self.running_var.reshape(1, -1, 1, 1) 52 | rm = self.running_mean.reshape(1, -1, 1, 1) 53 | eps = 1e-5 54 | scale = w * (rv + eps).rsqrt() 55 | bias = b - rm * scale 56 | return x * scale + bias 57 | 58 | 59 | class BackboneBase(nn.Module): 60 | 61 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): 62 | super().__init__() 63 | for name, parameter in backbone.named_parameters(): 64 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 65 | parameter.requires_grad_(False) 66 | if return_interm_layers: 67 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 68 | # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} deformable detr 69 | self.strides = [4, 8, 16, 32] 70 | self.num_channels = [256, 512, 1024, 2048] 71 | else: 72 | return_layers = {'layer4': "0"} 73 | self.strides = [32] 74 | self.num_channels = [2048] 75 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 76 | 77 | def forward(self, tensor_list: NestedTensor): 78 | xs = self.body(tensor_list.tensors) 79 | out: Dict[str, NestedTensor] = {} 80 | for name, x in xs.items(): 81 | m = tensor_list.mask 82 | assert m is not None 83 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 84 | out[name] = NestedTensor(x, mask) 85 | return out 86 | 87 | 88 | class Backbone(BackboneBase): 89 | """ResNet backbone with frozen BatchNorm.""" 90 | def __init__(self, name: str, 91 | train_backbone: bool, 92 | return_interm_layers: bool, 93 | dilation: bool): # True 94 | backbone = getattr(torchvision.models, name)( 95 | replace_stride_with_dilation=[False, False, dilation], 96 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 97 | assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" 98 | super().__init__(backbone, train_backbone, return_interm_layers) 99 | if dilation: 100 | self.strides[-1] = self.strides[-1] // 2 101 | 102 | 103 | class Joiner(nn.Sequential): 104 | def __init__(self, backbone, position_embedding): 105 | super().__init__(backbone, position_embedding) 106 | self.strides = backbone.strides 107 | self.num_channels = backbone.num_channels 108 | 109 | 110 | def forward(self, tensor_list: NestedTensor): 111 | tensor_list.tensors = rearrange(tensor_list.tensors, 'b t c h w -> (b t) c h w') 112 | tensor_list.mask = rearrange(tensor_list.mask, 'b t h w -> (b t) h w') 113 | 114 | xs = self[0](tensor_list) 115 | out: List[NestedTensor] = [] 116 | pos = [] 117 | for name, x in xs.items(): 118 | 
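# Frames were flattened into the batch dimension above ('b t c h w -> (b t) c h w'), so each backbone level here covers every frame of the clip; the matching positional encoding is computed per level below.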
out.append(x) 119 | # position encoding 120 | pos.append(self[1](x).to(x.tensors.dtype)) 121 | return out, pos 122 | 123 | 124 | def build_backbone(args): 125 | position_embedding = build_position_encoding(args) 126 | train_backbone = args.lr_backbone > 0 127 | return_interm_layers = args.masks or (args.num) 128 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 129 | model = Joiner(backbone, position_embedding) 130 | model.num_channels = backbone.num_channels 131 | return model 132 | 133 | -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | from typing import Optional, List 4 | import math 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn, Tensor 9 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ 10 | from util.misc import inverse_sigmoid 11 | from einops import rearrange 12 | 13 | 14 | class MSO(nn.Module): 15 | def __init__(self, mask_dim=16, img_dim=[96, 192], out_dim=16): 16 | super().__init__() 17 | 18 | self.mask_dim = mask_dim 19 | self.img_dim = img_dim 20 | self.out_dim = out_dim 21 | 22 | self.conv1_1div8 = nn.Conv2d(mask_dim+img_dim[1], mask_dim, kernel_size=3, padding=1) 23 | self.conv2_1div8 = nn.Conv2d(mask_dim, mask_dim, kernel_size=3, padding=1) 24 | 25 | self.conv1_1div4 = nn.Conv2d(mask_dim + img_dim[0], mask_dim, kernel_size=3, padding=1) 26 | self.conv2_1div4 = nn.Conv2d(mask_dim, mask_dim, kernel_size=3, padding=1) 27 | 28 | # TODO: add image on channel. deconv to upsample 29 | def forward(self, pred_masks, image_features): 30 | image_features = [x.tensors for x in image_features] # 1/4 & 1/8 31 | 32 | # merge with 1/8 image 33 | assert pred_masks.shape[-1] == image_features[-1].shape[-1], "First size wrong." 34 | x = torch.cat([pred_masks, image_features[-1]], dim=1) 35 | pred_masks += self.conv2_1div8(F.relu(self.conv1_1div8(F.relu(x)))) 36 | 37 | # merge with 1/4 image 38 | pred_masks = F.interpolate(pred_masks, size=(image_features[-2].shape[-2], image_features[-2].shape[-1]), mode='bilinear', align_corners=False) 39 | assert pred_masks.shape[-1] == image_features[-2].shape[-1], "Second size wrong." 
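# Second fusion stage: the masks, upsampled to 1/4 resolution above, are refined with the 1/4-scale image features through the same residual conv block used at 1/8 scale.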
40 | x = torch.cat([pred_masks, image_features[-2]], dim=1) 41 | pred_masks += self.conv2_1div4(F.relu(self.conv1_1div4(F.relu(x)))) 42 | 43 | return pred_masks 44 | 45 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instance Sequence Matching 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou, multi_iou 11 | from util.misc import nested_tensor_from_tensor_list 12 | 13 | INF = 100000000 14 | 15 | def dice_coef(inputs, targets): 16 | inputs = inputs.sigmoid() 17 | inputs = inputs.flatten(1).unsqueeze(1) # [N, 1, THW] 18 | targets = targets.flatten(1).unsqueeze(0) # [1, M, THW] 19 | numerator = 2 * (inputs * targets).sum(2) 20 | denominator = inputs.sum(-1) + targets.sum(-1) 21 | 22 | coef = (numerator + 1) / (denominator + 1) 23 | return coef 24 | 25 | def sigmoid_focal_coef(inputs, targets, alpha: float = 0.25, gamma: float = 2): 26 | N, M = len(inputs), len(targets) 27 | inputs = inputs.flatten(1).unsqueeze(1).expand(-1, M, -1) 28 | targets = targets.flatten(1).unsqueeze(0).expand(N, -1, -1) 29 | 30 | prob = inputs.sigmoid() 31 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 32 | p_t = prob * targets + (1 - prob) * (1 - targets) 33 | coef = ce_loss * ((1 - p_t) ** gamma) 34 | 35 | if alpha >= 0: 36 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 37 | coef = alpha_t * coef 38 | 39 | return coef.mean(2) 40 | 41 | 42 | class HungarianMatcher(nn.Module): 43 | """This class computes an assignment between the targets and the predictions of the network 44 | 45 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 46 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 47 | while the others are un-matched (and thus treated as non-objects). 
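In this referring setting each sample contains a single referred target, so the assignment reduces to selecting the lowest-cost query per sample (a torch.min over the cost matrix in forward), with the costs averaged over the frames of the clip.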
48 | """ 49 | 50 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, 51 | cost_mask: float = 1, cost_dice: float = 1, cost_boundary: float = 1, num_classes: int = 1): 52 | """Creates the matcher 53 | 54 | Params: 55 | cost_class: This is the relative weight of the classification error in the matching cost 56 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 57 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 58 | cost_mask: This is the relative weight of the sigmoid focal loss of the mask in the matching cost 59 | cost_dice: This is the relative weight of the dice loss of the mask in the matching cost 60 | """ 61 | super().__init__() 62 | self.cost_class = cost_class 63 | self.cost_bbox = cost_bbox 64 | self.cost_giou = cost_giou 65 | self.cost_mask = cost_mask 66 | self.cost_dice = cost_dice 67 | self.cost_boundary = cost_boundary 68 | self.num_classes = num_classes 69 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0 \ 70 | or cost_mask != 0 or cost_dice != 0, "all costs cant be 0" 71 | self.mask_out_stride = 2 72 | 73 | @torch.no_grad() 74 | def forward(self, outputs, targets): 75 | """ Performs the matching 76 | Params: 77 | outputs: This is a dict that contains at least these entries: 78 | "pred_logits": Tensor of dim [batch_size, num_queries_per_frame, num_frames, num_classes] with the classification logits 79 | "pred_boxes": Tensor of dim [batch_size, num_queries_per_frame, num_frames, 4] with the predicted box coordinates 80 | "pred_masks": Tensor of dim [batch_size, num_queries_per_frame, num_frames, h, w], h,w in 4x size 81 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 82 | NOTE: Since every frame has one object at most 83 | "labels": Tensor of dim [num_frames] (where num_target_boxes is the number of ground-truth 84 | objects in the target) containing the class labels 85 | "boxes": Tensor of dim [num_frames, 4] containing the target box coordinates 86 | "masks": Tensor of dim [num_frames, h, w], h,w in origin size 87 | Returns: 88 | A list of size batch_size, containing tuples of (index_i, index_j) where: 89 | - index_i is the indices of the selected predictions (in order) 90 | - index_j is the indices of the corresponding selected targets (in order) 91 | For each batch element, it holds: 92 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 93 | """ 94 | src_logits = outputs["pred_logits"] 95 | src_boxes = outputs["pred_boxes"] 96 | src_masks = outputs["pred_masks"] 97 | 98 | bs, nf, nq, h, w = src_masks.shape 99 | 100 | # handle mask padding issue 101 | target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets], 102 | size_divisibility=32, 103 | split=False).decompose() 104 | target_masks = target_masks.to(src_masks) # [B, T, H, W] 105 | 106 | # downsample ground truth masks with ratio mask_out_stride 107 | start = int(self.mask_out_stride // 2) 108 | im_h, im_w = target_masks.shape[-2:] 109 | target_masks = target_masks[:, :, start::self.mask_out_stride, start::self.mask_out_stride] 110 | assert target_masks.size(2) * self.mask_out_stride == im_h 111 | assert target_masks.size(3) * self.mask_out_stride == im_w 112 | 113 | indices = [] 114 | for i in range(bs): 115 | out_prob = src_logits[i].sigmoid() 116 | out_bbox = src_boxes[i] 117 | out_mask = src_masks[i] 118 | 119 | tgt_ids = targets[i]["labels"] 120 | tgt_bbox = 
targets[i]["boxes"] 121 | tgt_mask = target_masks[i] 122 | tgt_valid = targets[i]["valid"] 123 | 124 | cost_class = [] 125 | for t in range(nf): 126 | # filter invalid frames 127 | if tgt_valid[t] == 0: 128 | continue 129 | 130 | out_prob_split = out_prob[t] 131 | tgt_ids_split = tgt_ids[t].unsqueeze(0) 132 | 133 | alpha = 0.25 134 | gamma = 2.0 135 | neg_cost_class = (1 - alpha) * (out_prob_split ** gamma) * (-(1 - out_prob_split + 1e-8).log()) 136 | pos_cost_class = alpha * ((1 - out_prob_split) ** gamma) * (-(out_prob_split + 1e-8).log()) 137 | if self.num_classes == 1: 138 | cost_class_split = pos_cost_class[:, [0]] - neg_cost_class[:, [0]] 139 | else: 140 | cost_class_split = pos_cost_class[:, tgt_ids_split] - neg_cost_class[:, tgt_ids_split] 141 | 142 | cost_class.append(cost_class_split) 143 | cost_class = torch.stack(cost_class, dim=0).mean(0) 144 | 145 | cost_bbox, cost_giou = [], [] 146 | for t in range(nf): 147 | out_bbox_split = out_bbox[t] 148 | tgt_bbox_split = tgt_bbox[t].unsqueeze(0) 149 | 150 | cost_bbox_split = torch.cdist(out_bbox_split, tgt_bbox_split, p=1) 151 | cost_giou_split = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox_split), 152 | box_cxcywh_to_xyxy(tgt_bbox_split)) 153 | 154 | cost_bbox.append(cost_bbox_split) 155 | cost_giou.append(cost_giou_split) 156 | cost_bbox = torch.stack(cost_bbox, dim=0).mean(0) 157 | cost_giou = torch.stack(cost_giou, dim=0).mean(0) 158 | 159 | cost_mask = sigmoid_focal_coef(out_mask.transpose(0, 1), tgt_mask.unsqueeze(0)) 160 | cost_dice = -dice_coef(out_mask.transpose(0, 1), tgt_mask.unsqueeze(0)) 161 | 162 | # Final cost matrix 163 | C = self.cost_class * cost_class + self.cost_bbox * cost_bbox + self.cost_giou * cost_giou + \ 164 | self.cost_mask * cost_mask + self.cost_dice * cost_dice # [q, 1] 165 | 166 | _, src_ind = torch.min(C, dim=0) 167 | tgt_ind = torch.arange(1).to(src_ind) 168 | indices.append((src_ind.long(), tgt_ind.long())) 169 | 170 | # list[tuple], length is batch_size 171 | return indices 172 | 173 | 174 | def build_matcher(args): 175 | if args.binary: 176 | num_classes = 1 177 | else: 178 | if args.dataset_file == 'ytvos': 179 | num_classes = 65 180 | elif args.dataset_file == 'davis': 181 | num_classes = 78 182 | elif args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb': 183 | num_classes = 1 184 | else: 185 | num_classes = 91 # for coco 186 | 187 | return HungarianMatcher(cost_class=args.set_cost_class, 188 | cost_bbox=args.set_cost_bbox, 189 | cost_giou=args.set_cost_giou, 190 | cost_mask=args.set_cost_mask, 191 | cost_dice=args.set_cost_dice, 192 | cost_boundary=args.set_cost_boundary, 193 | num_classes=num_classes) 194 | 195 | 196 | 197 | -------------------------------------------------------------------------------- /models/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from typing import Optional, List 5 | from torch import Tensor 6 | from einops import rearrange 7 | 8 | 9 | class LFMResizeAdaptive(nn.Module): 10 | def __init__(self, num_channels, sigma): 11 | super(LFMResizeAdaptive, self).__init__() 12 | self.conv1 = nn.Conv2d(2 * num_channels, 2 * num_channels, kernel_size=1, stride=1, padding=0) 13 | self.conv2 = nn.Conv2d(2 * num_channels, 2 * num_channels, kernel_size=1, stride=1, padding=0) 14 | self.sigma = sigma 15 | 16 | self.laplace = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=0) 17 | self.pool = nn.AdaptiveAvgPool2d(1) 18 | self.fc = 
nn.Sequential( 19 | nn.Linear(num_channels, num_channels, bias=False), 20 | nn.ReLU(inplace=True), 21 | nn.Linear(num_channels, 1, bias=False), 22 | nn.Sigmoid() 23 | ) 24 | 25 | def make_gaussian(self, y_idx, x_idx, height, width, sigma=7, device='cpu'): 26 | yv, xv = torch.meshgrid([torch.arange(0, height), torch.arange(0, width)]) 27 | 28 | yv = yv.unsqueeze(0).float().to(device) 29 | xv = xv.unsqueeze(0).float().to(device) 30 | g = torch.exp(- ((yv - y_idx) ** 2 + (xv - x_idx) ** 2) / (2 * sigma ** 2)) 31 | return g.unsqueeze(0) 32 | 33 | def forward(self, x, gauss_map=None): 34 | b, c, h, w = x.shape 35 | x = x.float() 36 | 37 | # compute coef for gaussian 0~1 38 | coef = self.laplace(x) 39 | coef = self.fc(self.pool(coef).view(b, c)).view(b, 1, 1, 1) 40 | 41 | y = torch.fft.fft2(x) 42 | 43 | h_idx, w_idx = h // 2, w // 2 44 | if gauss_map is None: 45 | high_filter = self.make_gaussian(h_idx, w_idx, h, w, self.sigma, device=x.device) 46 | else: 47 | high_filter = F.interpolate(gauss_map, size=(h, w), mode='bilinear', align_corners=False) 48 | 49 | y = y * (1 - coef * high_filter) 50 | 51 | y_imag = y.imag 52 | y_real = y.real 53 | y_f = torch.cat([y_real, y_imag], dim=1) 54 | y = F.relu(self.conv1(y_f)) 55 | 56 | y = self.conv2(y).float() 57 | y_real, y_imag = torch.chunk(y, 2, dim=1) 58 | y = torch.complex(y_real, y_imag) 59 | 60 | y = torch.fft.ifft2(y, s=(h, w)).float() 61 | return x + y, high_filter 62 | 63 | 64 | if __name__ == "__main__": 65 | model = LFMResizeAdaptive(256, 3) 66 | data = torch.rand(2,256,8,8) 67 | res = model(data) 68 | -------------------------------------------------------------------------------- /models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | -------------------------------------------------------------------------------- /models/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # Modify for sample points visualization 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import print_function 12 | from __future__ import division 13 | 14 | import warnings 15 | import math 16 | 17 | import torch 18 | from torch import nn 19 | import torch.nn.functional as F 20 | from torch.nn.init import xavier_uniform_, constant_ 21 | 22 | from ..functions import MSDeformAttnFunction 23 | 24 | 25 | def _is_power_of_2(n): 26 | if (not isinstance(n, int)) or (n < 0): 27 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 28 | return (n & (n-1) == 0) and n != 0 29 | 30 | 31 | class MSDeformAttn(nn.Module): 32 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 33 | """ 34 | Multi-Scale Deformable Attention Module 35 | :param d_model hidden dimension 36 | :param n_levels number of feature levels 37 | :param n_heads number of attention heads 38 | :param n_points number of sampling points per attention head per feature level 39 | """ 40 | super().__init__() 41 | if d_model % n_heads != 0: 42 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 43 | _d_per_head = d_model // n_heads 44 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 45 | if not _is_power_of_2(_d_per_head): 46 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 47 | "which is more efficient in our CUDA 
implementation.") 48 | 49 | self.im2col_step = 64 50 | 51 | self.d_model = d_model 52 | self.n_levels = n_levels 53 | self.n_heads = n_heads 54 | self.n_points = n_points 55 | 56 | # res = sum(attn * W*(delta p)) 57 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) # delta p 58 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) # attn 59 | self.value_proj = nn.Linear(d_model, d_model) 60 | self.output_proj = nn.Linear(d_model, d_model) 61 | 62 | self._reset_parameters() 63 | 64 | def _reset_parameters(self): 65 | constant_(self.sampling_offsets.weight.data, 0.) 66 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 67 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 68 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 69 | for i in range(self.n_points): 70 | grid_init[:, :, i, :] *= i + 1 71 | with torch.no_grad(): 72 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 73 | constant_(self.attention_weights.weight.data, 0.) 74 | constant_(self.attention_weights.bias.data, 0.) 75 | xavier_uniform_(self.value_proj.weight.data) 76 | constant_(self.value_proj.bias.data, 0.) 77 | xavier_uniform_(self.output_proj.weight.data) 78 | constant_(self.output_proj.bias.data, 0.) 79 | 80 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 81 | """ 82 | :param query (N, Length_{query}, C) 83 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 84 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 85 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 86 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 87 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 88 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 89 | 90 | :return output (N, Length_{query}, C) 91 | """ 92 | N, Len_q, _ = query.shape 93 | N, Len_in, _ = input_flatten.shape 94 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 95 | 96 | value = self.value_proj(input_flatten) 97 | if input_padding_mask is not None: 98 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 99 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 100 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 101 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 102 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 103 | # N, Len_q, n_heads, n_levels, n_points, 2 104 | if reference_points.shape[-1] == 2: 105 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 106 | sampling_locations = reference_points[:, :, None, :, None, :] \ 107 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 108 | elif reference_points.shape[-1] == 4: 109 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 110 | + sampling_offsets / self.n_points * 
reference_points[:, :, None, :, None, 2:] * 0.5 111 | else: 112 | raise ValueError( 113 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 114 | output = MSDeformAttnFunction.apply( 115 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 116 | output = self.output_proj(output) 117 | 118 | return output, sampling_locations, attention_weights 119 | -------------------------------------------------------------------------------- /models/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=1", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 
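 * Note: the CPU forward/backward below are stubs that only raise AT_ERROR; the op is effectively
 * CUDA-only, consistent with setup.py refusing to build when CUDA is unavailable.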
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
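 * Note: the forward/backward functions below process the batch in chunks of im2col_step
 * (capped at the batch size) and dispatch to the kernels declared in ms_deform_im2col_cuda.cuh.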
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 66 | value.data() + n * im2col_step_ * per_value_size, 67 | spatial_shapes.data(), 68 | level_start_index.data(), 69 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 70 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 71 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 72 | columns.data()); 73 | 74 | })); 75 | } 76 | 77 | output = output.view({batch, num_query, num_heads*channels}); 78 | 79 | return output; 80 | } 81 | 82 | 83 | std::vector ms_deform_attn_cuda_backward( 84 | const at::Tensor &value, 85 | const at::Tensor &spatial_shapes, 86 | const at::Tensor &level_start_index, 87 | 
const at::Tensor &sampling_loc, 88 | const at::Tensor &attn_weight, 89 | const at::Tensor &grad_output, 90 | const int im2col_step) 91 | { 92 | 93 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 94 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 95 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 96 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 97 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 98 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 99 | 100 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 101 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 102 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 103 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 104 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 105 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 106 | 107 | const int batch = value.size(0); 108 | const int spatial_size = value.size(1); 109 | const int num_heads = value.size(2); 110 | const int channels = value.size(3); 111 | 112 | const int num_levels = spatial_shapes.size(0); 113 | 114 | const int num_query = sampling_loc.size(1); 115 | const int num_point = sampling_loc.size(4); 116 | 117 | const int im2col_step_ = std::min(batch, im2col_step); 118 | 119 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 120 | 121 | auto grad_value = at::zeros_like(value); 122 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 123 | auto grad_attn_weight = at::zeros_like(attn_weight); 124 | 125 | const int batch_n = im2col_step_; 126 | auto per_value_size = spatial_size * num_heads * channels; 127 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 128 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 129 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 130 | 131 | for (int n = 0; n < batch/im2col_step_; ++n) 132 | { 133 | auto grad_output_g = grad_output_n.select(0, n); 134 | AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 135 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 136 | grad_output_g.data(), 137 | value.data() + n * im2col_step_ * per_value_size, 138 | spatial_shapes.data(), 139 | level_start_index.data(), 140 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 141 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 142 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 143 | grad_value.data() + n * im2col_step_ * per_value_size, 144 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 145 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 146 | 147 | })); 148 | } 149 | 150 | return { 151 | grad_value, grad_sampling_loc, grad_attn_weight 152 | }; 153 | } -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 
2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /models/ops/src/vision.cpp: 
-------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = torch.rand(N, Lq, 
M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various positional encodings for the transformer. 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from util.misc import NestedTensor 10 | 11 | # dimension == 1 12 | class PositionEmbeddingSine1D(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 
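    Expects a NestedTensor whose tensors have shape [B, C, T] and whose mask has shape [B, T]
    (see the shape comments in forward below); the returned encoding has shape [B, C, T].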
16 | """ 17 | def __init__(self, num_pos_feats=256, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, tensor_list: NestedTensor): 29 | x = tensor_list.tensors # [B, C, T] 30 | mask = tensor_list.mask # [B, T] 31 | assert mask is not None 32 | not_mask = ~mask 33 | x_embed = not_mask.cumsum(1, dtype=torch.float32) # [B, T] 34 | if self.normalize: 35 | eps = 1e-6 36 | x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, None] / dim_t # [B, T, C] 42 | # n,c,t 43 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 44 | pos = pos_x.permute(0, 2, 1) # [B, C, T] 45 | return pos 46 | 47 | # dimension == 2 48 | class PositionEmbeddingSine2D(nn.Module): 49 | """ 50 | This is a more standard version of the position embedding, very similar to the one 51 | used by the Attention is all you need paper, generalized to work on images. 52 | """ 53 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 54 | super().__init__() 55 | self.num_pos_feats = num_pos_feats 56 | self.temperature = temperature 57 | self.normalize = normalize 58 | if scale is not None and normalize is False: 59 | raise ValueError("normalize should be True if scale is passed") 60 | if scale is None: 61 | scale = 2 * math.pi 62 | self.scale = scale 63 | 64 | def forward(self, tensor_list: NestedTensor): 65 | x = tensor_list.tensors # [B, C, H, W] 66 | mask = tensor_list.mask # [B, H, W] 67 | assert mask is not None 68 | not_mask = ~mask 69 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 70 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 71 | if self.normalize: 72 | eps = 1e-6 73 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 74 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 75 | 76 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 77 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 78 | 79 | pos_x = x_embed[:, :, :, None] / dim_t 80 | pos_y = y_embed[:, :, :, None] / dim_t 81 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 82 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 83 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 84 | return pos # [B, C, H, W] 85 | 86 | 87 | # dimension == 3 88 | class PositionEmbeddingSine3D(nn.Module): 89 | """ 90 | This is a more standard version of the position embedding, very similar to the one 91 | used by the Attention is all you need paper, generalized to work on images. 
92 | """ 93 | def __init__(self, num_pos_feats=64, num_frames=36, temperature=10000, normalize=False, scale=None): 94 | super().__init__() 95 | self.num_pos_feats = num_pos_feats 96 | self.temperature = temperature 97 | self.normalize = normalize 98 | self.frames = num_frames 99 | if scale is not None and normalize is False: 100 | raise ValueError("normalize should be True if scale is passed") 101 | if scale is None: 102 | scale = 2 * math.pi 103 | self.scale = scale 104 | 105 | def forward(self, tensor_list: NestedTensor): 106 | x = tensor_list.tensors # [B*T, C, H, W] 107 | mask = tensor_list.mask # [B*T, H, W] 108 | n,h,w = mask.shape 109 | mask = mask.reshape(n//self.frames, self.frames,h,w) # [B, T, H, W] 110 | assert mask is not None 111 | not_mask = ~mask 112 | z_embed = not_mask.cumsum(1, dtype=torch.float32) # [B, T, H, W] 113 | y_embed = not_mask.cumsum(2, dtype=torch.float32) # [B, T, H, W] 114 | x_embed = not_mask.cumsum(3, dtype=torch.float32) # [B, T, H, W] 115 | if self.normalize: 116 | eps = 1e-6 117 | z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale 118 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale 119 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale 120 | 121 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) # 122 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 123 | 124 | pos_x = x_embed[:, :, :, :, None] / dim_t # [B, T, H, W, c] 125 | pos_y = y_embed[:, :, :, :, None] / dim_t 126 | pos_z = z_embed[:, :, :, :, None] / dim_t 127 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) # [B, T, H, W, c] 128 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 129 | pos_z = torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 130 | pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3) # [B, T, C, H, W] 131 | return pos 132 | 133 | 134 | 135 | def build_position_encoding(args): 136 | # build 2D position encoding 137 | N_steps = args.hidden_dim // 2 # 256 / 2 = 128 138 | if args.position_embedding in ('v2', 'sine'): 139 | # TODO find a better way of exposing other arguments 140 | position_embedding = PositionEmbeddingSine2D(N_steps, normalize=True) 141 | else: 142 | raise ValueError(f"not supported {args.position_embedding}") 143 | 144 | return position_embedding 145 | -------------------------------------------------------------------------------- /models/postprocessors.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. 
All Rights Reserved 2 | """Postprocessors class to transform MDETR output according to the downstream task""" 3 | from typing import Dict 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | import pycocotools.mask as mask_util 10 | 11 | from util import box_ops 12 | import os 13 | 14 | class A2DSentencesPostProcess(nn.Module): 15 | """ 16 | This module converts the model's output into the format expected by the coco api for the given task 17 | """ 18 | def __init__(self, threshold=0.5): 19 | super().__init__() 20 | self.threshold = threshold 21 | 22 | @torch.no_grad() 23 | def forward(self, outputs, orig_target_sizes, max_target_sizes): 24 | """ Perform the computation 25 | Parameters: 26 | outputs: raw outputs of the model 27 | orig_target_sizes: original size of the samples (no augmentations or padding) 28 | max_target_sizes: size of samples (input to model) after size augmentation. 29 | NOTE: the max_padding_size is 4x out_masks.shape[-2:] 30 | """ 31 | assert len(orig_target_sizes) == len(max_target_sizes) 32 | # there is only one valid frames, thus T=1 33 | out_logits = outputs['pred_logits'][:, 0, :, 0] 34 | out_masks = outputs['pred_masks'][:, 0, :, :, :] 35 | 36 | # TODO: rerank mask to get better results. 37 | scores = out_logits.sigmoid() 38 | pred_masks = out_masks 39 | processed_pred_masks, rle_masks = [], [] 40 | # for each batch 41 | for f_pred_masks, resized_size, orig_size in zip(pred_masks, max_target_sizes, orig_target_sizes): 42 | f_mask_h, f_mask_w = resized_size # resized shape without padding 43 | f_pred_masks_no_pad = f_pred_masks[:, :f_mask_h, :f_mask_w].unsqueeze(1) 44 | # resize the samples back to their original dataset (target) size for evaluation 45 | f_pred_masks_processed = F.interpolate(f_pred_masks_no_pad, size=tuple(orig_size.tolist()), mode="bilinear", align_corners=False) 46 | f_pred_masks_processed = (f_pred_masks_processed.sigmoid() > 0.5) # [B, N, H, W] 47 | f_pred_rle_masks = [mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 48 | for mask in f_pred_masks_processed.cpu()] 49 | processed_pred_masks.append(f_pred_masks_processed) 50 | rle_masks.append(f_pred_rle_masks) 51 | predictions = [{'scores': s, 'masks': m, 'rle_masks': rle} 52 | for s, m, rle in zip(scores, processed_pred_masks, rle_masks)] 53 | return predictions 54 | 55 | 56 | # PostProcess for pretraining 57 | class PostProcess(nn.Module): 58 | """ This module converts the model's output into the format expected by the coco api""" 59 | 60 | @torch.no_grad() 61 | def forward(self, outputs, target_sizes): 62 | """Perform the computation 63 | Parameters: 64 | outputs: raw outputs of the model 65 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 66 | For evaluation, this must be the original image size (before any data augmentation) 67 | For visualization, this should be the image size after data augment, but before padding 68 | Returns: 69 | 70 | """ 71 | out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] # b t q c 72 | assert len(out_logits) == len(target_sizes) 73 | assert target_sizes.shape[1] == 2 74 | 75 | out_logits = outputs["pred_logits"].flatten(0,1) 76 | out_boxes = outputs["pred_boxes"].flatten(0,1) 77 | bs, num_queries = out_logits.shape[:2] 78 | 79 | prob = out_logits.sigmoid() 80 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True) 81 | scores = topk_values 82 | 
topk_boxes = topk_indexes // out_logits.shape[2] 83 | labels = topk_indexes % out_logits.shape[2] 84 | 85 | boxes = box_ops.box_cxcywh_to_xyxy(out_boxes) 86 | boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) 87 | 88 | img_h, img_w = target_sizes.unbind(1) 89 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 90 | boxes = boxes * scale_fct[:, None, :] 91 | assert len(scores) == len(labels) == len(boxes) 92 | results = [{"scores": s, "labels": torch.ones_like(l), "boxes": b} for s, l, b in zip(scores, labels, boxes)] 93 | return results 94 | 95 | 96 | # For Ref-COCO 97 | class PostProcessSegm(nn.Module): 98 | """Similar to PostProcess but for segmentation masks. 99 | This processor is to be called sequentially after PostProcess. 100 | Args: 101 | threshold: threshold that will be applied to binarize the segmentation masks. 102 | """ 103 | 104 | def __init__(self, threshold=0.5): 105 | super().__init__() 106 | self.threshold = threshold 107 | 108 | @torch.no_grad() 109 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes): 110 | """Perform the computation 111 | Parameters: 112 | results: already pre-processed boxes (output of PostProcess) NOTE here 113 | outputs: raw outputs of the model 114 | orig_target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 115 | For evaluation, this must be the original image size (before any data augmentation) 116 | For visualization, this should be the image size after data augment, but before padding 117 | max_target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 118 | after data augmentation. 119 | """ 120 | assert len(orig_target_sizes) == len(max_target_sizes) 121 | out_logits = outputs["pred_logits"].flatten(0, 1) # bt q 1 122 | out_masks = outputs["pred_masks"].flatten(0, 1) # bt q h w 123 | bs, num_queries = out_logits.shape[:2] 124 | 125 | # rerank based on score 126 | prob = out_logits.sigmoid() 127 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True) 128 | topk_boxes = topk_indexes // out_logits.shape[2] 129 | outputs_masks = [out_m[topk_boxes[i]].unsqueeze(0) for i, out_m, in enumerate(out_masks)] 130 | outputs_masks = torch.cat(outputs_masks, dim=0) 131 | 132 | for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): # for each b 133 | img_h, img_w = t[0], t[1] 134 | msk = cur_mask[:, :img_h, :img_w].unsqueeze(1).cpu() # q 1 h w unpad 135 | # resize to raw resolution 136 | msk = F.interpolate(msk, size=tuple(tt.tolist()), mode="bilinear", align_corners=False) # # resize to init resolution 137 | msk = (msk.sigmoid() > 0.5).cpu() # q 1 h w 138 | results[i]["masks"] = msk.byte() 139 | results[i]["rle_masks"] = [mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 140 | for mask in results[i]["masks"].cpu()] 141 | 142 | return results 143 | 144 | def build_postprocessors(args, dataset_name): 145 | print("\n **** BUILD POSTPROCESSOR FOR {}. 
**** \n".format(dataset_name)) 146 | if dataset_name == 'a2d' or dataset_name == 'jhmdb': 147 | postprocessors = A2DSentencesPostProcess(threshold=args.threshold) 148 | else: 149 | postprocessors: Dict[str, nn.Module] = {"bbox": PostProcess()} 150 | if args.masks: 151 | postprocessors["segm"] = PostProcessSegm(threshold=args.threshold) 152 | return postprocessors 153 | -------------------------------------------------------------------------------- /models/segmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Segmentaion Part 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | from collections import defaultdict 6 | from typing import List, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch import Tensor 12 | from PIL import Image 13 | 14 | from einops import rearrange, repeat 15 | 16 | try: 17 | from panopticapi.utils import id2rgb, rgb2id 18 | except ImportError: 19 | pass 20 | 21 | import fvcore.nn.weight_init as weight_init 22 | 23 | from .position_encoding import PositionEmbeddingSine1D 24 | 25 | BN_MOMENTUM = 0.1 26 | 27 | def get_norm(norm, out_channels): # only support GN or LN 28 | """ 29 | Args: 30 | norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; 31 | or a callable that takes a channel number and returns 32 | the normalization layer as a nn.Module. 33 | 34 | Returns: 35 | nn.Module or None: the normalization layer 36 | """ 37 | if norm is None: 38 | return None 39 | if isinstance(norm, str): 40 | if len(norm) == 0: 41 | return None 42 | norm = { 43 | "GN": lambda channels: nn.GroupNorm(8, channels), 44 | "LN": lambda channels: nn.LayerNorm(channels) 45 | }[norm] 46 | return norm(out_channels) 47 | 48 | class Conv2d(torch.nn.Conv2d): 49 | """ 50 | A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. 51 | """ 52 | 53 | def __init__(self, *args, **kwargs): 54 | """ 55 | Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: 56 | 57 | Args: 58 | norm (nn.Module, optional): a normalization layer 59 | activation (callable(Tensor) -> Tensor): a callable activation function 60 | 61 | It assumes that norm layer is used before activation. 62 | """ 63 | norm = kwargs.pop("norm", None) 64 | activation = kwargs.pop("activation", None) 65 | super().__init__(*args, **kwargs) 66 | 67 | self.norm = norm 68 | self.activation = activation 69 | 70 | def forward(self, x): 71 | # torchscript does not support SyncBatchNorm yet 72 | # https://github.com/pytorch/pytorch/issues/40507 73 | # and we skip these codes in torchscript since: 74 | # 1. currently we only support torchscript in evaluation mode 75 | # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or 76 | # later version, `Conv2d` in these PyTorch versions has already supported empty inputs. 77 | if not torch.jit.is_scripting(): 78 | if x.numel() == 0 and self.training: 79 | # https://github.com/pytorch/pytorch/issues/12013 80 | assert not isinstance( 81 | self.norm, torch.nn.SyncBatchNorm 82 | ), "SyncBatchNorm does not support empty inputs!" 
83 | 84 | x = F.conv2d( 85 | x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups 86 | ) 87 | if self.norm is not None: 88 | x = self.norm(x) 89 | if self.activation is not None: 90 | x = self.activation(x) 91 | return x 92 | 93 | 94 | class VisionLanguageFusionModule(nn.Module): 95 | def __init__(self, d_model, nhead, dropout=0.0): 96 | super().__init__() 97 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 98 | 99 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 100 | return tensor if pos is None else tensor + pos 101 | 102 | def forward(self, visual, text, 103 | text_key_padding_mask: Optional[Tensor] = None, 104 | text_pos: Optional[Tensor] = None, 105 | visual_pos: Optional[Tensor] = None): 106 | visual = rearrange(visual, 't h w b c -> (t h w) b c') 107 | visual2 = self.multihead_attn(query=self.with_pos_embed(visual, visual_pos), 108 | key=self.with_pos_embed(text, text_pos), 109 | value=text, attn_mask=None, 110 | key_padding_mask=text_key_padding_mask)[0] 111 | visual = visual * visual2 112 | return visual 113 | 114 | 115 | def dice_loss(inputs, targets, num_boxes): 116 | """ 117 | Compute the DICE loss, similar to generalized IOU for masks 118 | Args: 119 | inputs: A float tensor of arbitrary shape. 120 | The predictions for each example. 121 | targets: A float tensor with the same shape as inputs. Stores the binary 122 | classification label for each element in inputs 123 | (0 for the negative class and 1 for the positive class). 124 | """ 125 | inputs = inputs.sigmoid() 126 | inputs = inputs.flatten(1) 127 | numerator = 2 * (inputs * targets).sum(1) 128 | denominator = inputs.sum(-1) + targets.sum(-1) 129 | loss = 1 - (numerator + 1) / (denominator + 1) 130 | return loss.sum() / num_boxes 131 | 132 | 133 | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): 134 | """ 135 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 136 | Args: 137 | inputs: A float tensor of arbitrary shape. 138 | The predictions for each example. 139 | targets: A float tensor with the same shape as inputs. Stores the binary 140 | classification label for each element in inputs 141 | (0 for the negative class and 1 for the positive class). 142 | alpha: (optional) Weighting factor in range (0,1) to balance 143 | positive vs negative examples. Default = -1 (no weighting). 144 | gamma: Exponent of the modulating factor (1 - p_t) to 145 | balance easy vs hard examples. 
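        Concretely, the code below computes alpha_t * (1 - p_t)**gamma * BCE_with_logits(inputs, targets)
        element-wise, averages over dimension 1, and divides the sum over predictions by num_boxes.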
146 | Returns: 147 | Loss tensor 148 | """ 149 | prob = inputs.sigmoid() 150 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 151 | p_t = prob * targets + (1 - prob) * (1 - targets) 152 | loss = ce_loss * ((1 - p_t) ** gamma) 153 | 154 | if alpha >= 0: 155 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 156 | loss = alpha_t * loss 157 | 158 | return loss.mean(1).sum() / num_boxes 159 | 160 | 161 | def _get_activation_fn(activation): 162 | """Return an activation function given a string""" 163 | if activation == "relu": 164 | return F.relu 165 | if activation == "gelu": 166 | return F.gelu 167 | if activation == "glu": 168 | return F.glu 169 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 170 | 171 | 172 | -------------------------------------------------------------------------------- /models/text_encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-miao/SgMg/90fd3c476858218b1ed0c21ec28e64e762ca2c84/models/text_encoder/__init__.py -------------------------------------------------------------------------------- /models/text_encoder/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-miao/SgMg/90fd3c476858218b1ed0c21ec28e64e762ca2c84/models/text_encoder/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /models/text_encoder/text_encoder.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains a wrapper for Video-Swin-Transformer so it can be properly used as a temporal encoder for MTTR. 3 | """ 4 | import torch 5 | import os 6 | from torch import nn, Tensor 7 | from einops import rearrange, repeat 8 | 9 | from transformers import RobertaModel, RobertaTokenizerFast 10 | from models.text_encoder.tokenizer import RobertaTokenizer 11 | 12 | import warnings 13 | warnings.filterwarnings("ignore") 14 | 15 | 16 | class FeatureResizer(nn.Module): 17 | def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True): 18 | super().__init__() 19 | self.do_ln = do_ln 20 | self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True) 21 | self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12) 22 | self.dropout = nn.Dropout(dropout) 23 | 24 | def forward(self, encoder_features): 25 | x = self.fc(encoder_features) 26 | if self.do_ln: 27 | x = self.layer_norm(x) 28 | output = self.dropout(x) 29 | return output 30 | 31 | 32 | class TextEncoder(nn.Module): 33 | def __init__(self, args): 34 | super(TextEncoder, self).__init__() 35 | self.args = args 36 | self.hidden_dim = args.hidden_dim 37 | self.text_backbone_name = args.text_backbone 38 | self.token_size = 32 39 | if self.text_backbone_name == "Roberta": 40 | self.text_backbone = RobertaModel.from_pretrained("roberta-base") 41 | # self.text_backbone.pooler = None # this pooler is never used, this is a hack to avoid DDP problems... 
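            # With RoBERTa-base, forward() below returns:
            #   text_features          = last_hidden_state, shape [B, L, 768]
            #   text_sentence_features = pooler_output,     shape [B, 768]
            #   text_pad_mask          = attention_mask.ne(1), True at padded positions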
42 | self.tokenizer = RobertaTokenizer() 43 | self.feat_dim = 768 44 | else: 45 | assert False, f'error: Text Encoder "{self.text_backbone_name}" is not supported' 46 | 47 | self.freeze_text_encoder = args.freeze_text_encoder 48 | if self.freeze_text_encoder: 49 | # self.text_backbone.eval() 50 | for p in self.text_backbone.parameters(): 51 | p.requires_grad_(False) 52 | for p in self.tokenizer.parameters(): 53 | p.requires_grad_(False) 54 | print("Use {} as text encoder. Freeze: {}".format(self.text_backbone_name, self.freeze_text_encoder)) 55 | 56 | self.target_len = None 57 | 58 | def forward(self, texts, device): 59 | if self.freeze_text_encoder: 60 | with torch.no_grad(): 61 | tokenized_queries = self.tokenizer(texts).to(device) 62 | if self.text_backbone_name == "Roberta": 63 | encoded_text = self.text_backbone(**tokenized_queries) 64 | text_pad_mask = tokenized_queries.attention_mask.ne(1).bool() 65 | text_features = encoded_text.last_hidden_state 66 | text_sentence_features = encoded_text.pooler_output 67 | else: 68 | raise NotImplementedError 69 | else: 70 | tokenized_queries = self.tokenizer(texts).to(device) 71 | if self.text_backbone_name == "Roberta": 72 | encoded_text = self.text_backbone(**tokenized_queries) 73 | text_pad_mask = tokenized_queries.attention_mask.ne(1).bool() 74 | text_features = encoded_text.last_hidden_state 75 | text_sentence_features = encoded_text.pooler_output 76 | else: 77 | raise NotImplementedError 78 | 79 | return text_features, text_sentence_features, text_pad_mask 80 | 81 | def num_parameters(self): 82 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 83 | 84 | -------------------------------------------------------------------------------- /models/text_encoder/tokenizer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Union 3 | import gzip 4 | import html 5 | import os 6 | from functools import lru_cache 7 | import ftfy 8 | import regex as re 9 | 10 | import torch 11 | import torch.nn as nn 12 | from transformers import RobertaTokenizerFast 13 | 14 | @lru_cache() 15 | def default_bpe(): 16 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 17 | 18 | 19 | @lru_cache() 20 | def bytes_to_unicode(): 21 | """ 22 | Returns list of utf-8 byte and a corresponding list of unicode strings. 23 | The reversible bpe codes work on unicode strings. 24 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 25 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 26 | This is a signficant percentage of your normal, say, 32K bpe vocab. 27 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 28 | And avoids mapping to whitespace/control characters the bpe code barfs on. 29 | """ 30 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 31 | cs = bs[:] 32 | n = 0 33 | for b in range(2**8): 34 | if b not in bs: 35 | bs.append(b) 36 | cs.append(2**8+n) 37 | n += 1 38 | cs = [chr(n) for n in cs] 39 | return dict(zip(bs, cs)) 40 | 41 | 42 | def get_pairs(word): 43 | """Return set of symbol pairs in a word. 44 | Word is represented as tuple of symbols (symbols being variable-length strings). 
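    For example, get_pairs(('h', 'e', 'l', 'l', 'o')) returns {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}.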
45 | """ 46 | pairs = set() 47 | prev_char = word[0] 48 | for char in word[1:]: 49 | pairs.add((prev_char, char)) 50 | prev_char = char 51 | return pairs 52 | 53 | 54 | def basic_clean(text): 55 | text = ftfy.fix_text(text) 56 | text = html.unescape(html.unescape(text)) 57 | return text.strip() 58 | 59 | 60 | def whitespace_clean(text): 61 | text = re.sub(r'\s+', ' ', text) 62 | text = text.strip() 63 | return text 64 | 65 | 66 | class SimpleTokenizer(object): 67 | def __init__(self, bpe_path: str = default_bpe()): 68 | self.byte_encoder = bytes_to_unicode() 69 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 70 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 71 | merges = merges[1:49152-256-2+1] 72 | merges = [tuple(merge.split()) for merge in merges] 73 | vocab = list(bytes_to_unicode().values()) 74 | vocab = vocab + [v+'' for v in vocab] 75 | for merge in merges: 76 | vocab.append(''.join(merge)) 77 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 82 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 83 | 84 | def bpe(self, token): 85 | if token in self.cache: 86 | return self.cache[token] 87 | word = tuple(token[:-1]) + ( token[-1] + '',) 88 | pairs = get_pairs(word) 89 | 90 | if not pairs: 91 | return token+'' 92 | 93 | while True: 94 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 95 | if bigram not in self.bpe_ranks: 96 | break 97 | first, second = bigram 98 | new_word = [] 99 | i = 0 100 | while i < len(word): 101 | try: 102 | j = word.index(first, i) 103 | new_word.extend(word[i:j]) 104 | i = j 105 | except: 106 | new_word.extend(word[i:]) 107 | break 108 | 109 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 110 | new_word.append(first+second) 111 | i += 2 112 | else: 113 | new_word.append(word[i]) 114 | i += 1 115 | new_word = tuple(new_word) 116 | word = new_word 117 | if len(word) == 1: 118 | break 119 | else: 120 | pairs = get_pairs(word) 121 | word = ' '.join(word) 122 | self.cache[token] = word 123 | return word 124 | 125 | def encode(self, text): 126 | bpe_tokens = [] 127 | text = whitespace_clean(basic_clean(text)).lower() 128 | for token in re.findall(self.pat, text): 129 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 130 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 131 | return bpe_tokens 132 | 133 | def decode(self, tokens): 134 | text = ''.join([self.decoder[token] for token in tokens]) 135 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 136 | return text 137 | 138 | 139 | # ***************************** Tokenize function ************************** 140 | class RobertaTokenizer(nn.Module): 141 | def __init__(self, ): 142 | super(RobertaTokenizer, self).__init__() 143 | self.tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base") 144 | 145 | def forward(self, texts): 146 | return self.tokenizer.batch_encode_plus(texts, padding='longest', return_tensors='pt') 147 | 148 | 149 | _tokenizer = SimpleTokenizer() 150 | # example: word_vec = tokenize(["a photo of cat", "a dog"], self.word_length, 
True).squeeze(0) 151 | def tokenize(texts: Union[str, List[str]], 152 | context_length: int = 77, 153 | truncate: bool = False) -> torch.LongTensor: 154 | """ 155 | Returns the tokenized representation of given input string(s) 156 | 157 | Parameters 158 | ---------- 159 | texts : Union[str, List[str]] 160 | An input string or a list of input strings to tokenize 161 | 162 | context_length : int 163 | The context length to use; all CLIP models use 77 as the context length 164 | 165 | truncate: bool 166 | Whether to truncate the text in case its encoding is longer than the context length 167 | 168 | Returns 169 | ------- 170 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 171 | """ 172 | if isinstance(texts, str): 173 | texts = [texts] 174 | 175 | sot_token = _tokenizer.encoder["<|startoftext|>"] 176 | eot_token = _tokenizer.encoder["<|endoftext|>"] 177 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] 178 | for text in texts] 179 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 180 | 181 | for i, tokens in enumerate(all_tokens): 182 | if len(tokens) > context_length: 183 | if truncate: 184 | tokens = tokens[:context_length] 185 | tokens[-1] = eot_token 186 | else: 187 | raise RuntimeError( 188 | f"Input {texts[i]} is too long for context length {context_length}" 189 | ) 190 | result[i, :len(tokens)] = torch.tensor(tokens) 191 | 192 | return result 193 | 194 | 195 | def single_tokenize(texts: Union[str, List[str]], 196 | context_length: int = 77, 197 | truncate: bool = False) -> torch.LongTensor: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 215 | """ 216 | if isinstance(texts, str): 217 | texts = [texts] 218 | assert len(texts) == 1, "Only accept one text." 
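# Note: `tokenize` and `single_tokenize` are CLIP-style helpers built on the BPE
# vocabulary of `SimpleTokenizer`: `tokenize` zero-pads every sequence to a fixed
# context length (77 by default), while `single_tokenize` returns a single,
# unpadded sequence and only accepts one text at a time. The `RobertaTokenizer`
# module defined above pads each batch to its longest sequence instead
# (padding='longest') and is the tokenizer instantiated by `TextEncoder` in
# models/text_encoder/text_encoder.py.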
219 | sot_token = _tokenizer.encoder["<|startoftext|>"] 220 | eot_token = _tokenizer.encoder["<|endoftext|>"] 221 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] 222 | for text in texts] 223 | 224 | tokens = all_tokens[0] 225 | if truncate: 226 | tokens = tokens[:context_length] 227 | else: 228 | raise RuntimeError( 229 | f"Input {texts[0]} is too long for context length {context_length}" 230 | ) 231 | 232 | result = torch.tensor(tokens, dtype=torch.long).unsqueeze(0) 233 | return result 234 | 235 | 236 | if __name__ == "__main__": 237 | text = ["a big dog on the table has a big nose", 'a big dog'] 238 | model = RobertaTokenizer() 239 | 240 | res = model(text) 241 | msk = res.attention_mask.ne(1).bool() 242 | print(res, msk, res['input_ids'].shape, msk.shape) -------------------------------------------------------------------------------- /opts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args_parser(): 4 | parser = argparse.ArgumentParser('SgMg training and inference scripts.', add_help=False) 5 | parser.add_argument('--lr', default=1e-4, type=float) # MTTR 1e-4 1e-5 1e-5 wd 1e-4 6 | parser.add_argument('--lr_backbone', default=5e-5, type=float) 7 | parser.add_argument('--lr_backbone_names', default=['backbone.0'], type=str, nargs='+') 8 | parser.add_argument('--lr_text_encoder', default=1e-5, type=float) 9 | parser.add_argument('--lr_text_encoder_names', default=['text_encoder'], type=str, nargs='+') 10 | parser.add_argument('--lr_linear_proj_names', default=['reference_points', 'sampling_offsets'], type=str, nargs='+') 11 | parser.add_argument('--lr_linear_proj_mult', default=1.0, type=float) 12 | parser.add_argument('--batch_size', default=1, type=int) 13 | parser.add_argument('--weight_decay', default=5e-4, type=float) 14 | parser.add_argument('--epochs', default=10, type=int) 15 | parser.add_argument('--lr_drop', default=[6, 8], type=int, nargs='+') 16 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 17 | help='gradient clipping max norm') 18 | 19 | # New Parameters 20 | parser.add_argument('--amp', default=False, action='store_true') 21 | parser.add_argument('--exp_name', default='main', type=str) 22 | parser.add_argument('--current_epoch', default=0, type=int) 23 | 24 | # load the pretrained weights 25 | parser.add_argument('--pretrained_weights', type=str, default=None, 26 | help="Path to the pretrained model.") 27 | 28 | # Variants of Deformable DETR 29 | parser.add_argument('--with_box_refine', default=False, action='store_true') 30 | parser.add_argument('--two_stage', default=False, action='store_true') # NOTE: must be false 31 | 32 | # * Backbone 33 | # ["resnet50", "resnet101", "swin_t_p4w7", "swin_s_p4w7", "swin_b_p4w7", "swin_l_p4w7"] 34 | # ["video_swin_t_p4w7", "video_swin_s_p4w7", "video_swin_b_p4w7"] 35 | parser.add_argument('--backbone', default='resnet50', type=str, 36 | help="Name of the convolutional backbone to use") 37 | parser.add_argument('--text_backbone', default='Roberta', type=str, 38 | help="Name of the convolutional backbone to use") 39 | parser.add_argument('--backbone_pretrained', default=None, type=str, 40 | help="if use swin backbone and train from scratch, the path to the pretrained weights") 41 | parser.add_argument('--use_checkpoint', action='store_true', help='whether use checkpoint for swin/video swin backbone') 42 | parser.add_argument('--dilation', action='store_true', # DC5 43 | help="If true, we replace stride with dilation in the 
last convolutional block (DC5)") 44 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 45 | help="Type of positional embedding to use on top of the image features") 46 | parser.add_argument('--num_feature_levels', default=4, type=int, help='number of feature levels') 47 | parser.add_argument('--output_levels', default=4, type=int, help='number of feature levels') 48 | 49 | # * Transformer 50 | parser.add_argument('--enc_layers', default=4, type=int, 51 | help="Number of encoding layers in the transformer") 52 | parser.add_argument('--dec_layers', default=4, type=int, 53 | help="Number of decoding layers in the transformer") 54 | parser.add_argument('--dim_feedforward', default=2048, type=int, 55 | help="Intermediate size of the feedforward layers in the transformer blocks") 56 | parser.add_argument('--hidden_dim', default=256, type=int, 57 | help="Size of the embeddings (dimension of the transformer)") 58 | parser.add_argument('--dropout', default=0.1, type=float, 59 | help="Dropout applied in the transformer") 60 | parser.add_argument('--nheads', default=8, type=int, 61 | help="Number of attention heads inside the transformer's attentions") 62 | parser.add_argument('--num_frames', default=5, type=int, 63 | help="Number of clip frames for training") 64 | parser.add_argument('--num_queries', default=5, type=int, 65 | help="Number of query slots, all frames share the same queries") 66 | parser.add_argument('--dec_n_points', default=4, type=int) 67 | parser.add_argument('--enc_n_points', default=4, type=int) 68 | parser.add_argument('--pre_norm', action='store_true') 69 | parser.add_argument('--freeze_text_encoder', action='store_true') # default: False 70 | 71 | # * Segmentation 72 | parser.add_argument('--masks', action='store_true', 73 | help="Train segmentation head if the flag is provided") 74 | parser.add_argument('--mask_dim', default=256, type=int, 75 | help="Size of the mask embeddings (dimension of the dynamic mask conv)") 76 | parser.add_argument('--controller_layers', default=2, type=int, 77 | help="Dynamic conv layer number") 78 | parser.add_argument('--dynamic_mask_channels', default=16, type=int, 79 | help="Dynamic conv final channel number") 80 | parser.add_argument('--no_rel_coord', dest='rel_coord', action='store_false', 81 | help="Disables relative coordinates") 82 | 83 | # Loss 84 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 85 | help="Disables auxiliary decoding losses (loss at each layer)") 86 | # * Matcher 87 | parser.add_argument('--set_cost_class', default=2, type=float, 88 | help="Class coefficient in the matching cost") 89 | parser.add_argument('--set_cost_bbox', default=5, type=float, 90 | help="L1 box coefficient in the matching cost") 91 | parser.add_argument('--set_cost_giou', default=2, type=float, 92 | help="giou box coefficient in the matching cost") 93 | parser.add_argument('--set_cost_mask', default=2, type=float, 94 | help="mask coefficient in the matching cost") 95 | parser.add_argument('--set_cost_boundary', default=2, type=float, 96 | help="mask coefficient in the matching cost") 97 | parser.add_argument('--set_cost_dice', default=5, type=float, 98 | help="mask coefficient in the matching cost") 99 | # * Loss coefficients 100 | parser.add_argument('--mask_loss_coef', default=2, type=float) 101 | parser.add_argument('--boundary_loss_coef', default=2, type=float) 102 | parser.add_argument('--dice_loss_coef', default=5, type=float) 103 | 
parser.add_argument('--cls_loss_coef', default=2, type=float) 104 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 105 | parser.add_argument('--giou_loss_coef', default=2, type=float) 106 | parser.add_argument('--eos_coef', default=0.1, type=float, 107 | help="Relative classification weight of the no-object class") 108 | parser.add_argument('--focal_alpha', default=0.25, type=float) 109 | 110 | # dataset parameters 111 | # ['ytvos', 'davis', 'a2d', 'jhmdb', 'refcoco', 'refcoco+', 'refcocog', 'all'] 112 | # 'all': using the three ref datasets for pretraining 113 | parser.add_argument('--dataset_file', default='ytvos', help='Dataset name') 114 | parser.add_argument('--coco_path', type=str, default='../datasets/coco') 115 | parser.add_argument('--ytvos_path', type=str, default='../datasets/refer_youtube_vos') 116 | parser.add_argument('--davis_path', type=str, default='../datasets/refer_davis') 117 | parser.add_argument('--a2d_path', type=str, default='../datasets/a2d_sentences') 118 | parser.add_argument('--jhmdb_path', type=str, default='../datasets/jhmdb_sentences') 119 | parser.add_argument('--max_skip', default=3, type=int, help="max skip frame number") 120 | parser.add_argument('--max_size', default=640, type=int, help="max size for the frame") 121 | parser.add_argument('--binary', action='store_true') 122 | parser.add_argument('--remove_difficult', action='store_true') 123 | 124 | parser.add_argument('--output_dir', default='output', 125 | help='path where to save, empty for no saving') 126 | parser.add_argument('--device', default='cuda', 127 | help='device to use for training / testing') 128 | parser.add_argument('--seed', default=42, type=int) 129 | parser.add_argument('--resume', default='', help='resume from checkpoint') 130 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 131 | help='start epoch') 132 | parser.add_argument('--eval', default=False, action='store_true') 133 | parser.add_argument('--num_workers', default=4, type=int) 134 | 135 | # test setting 136 | parser.add_argument('--threshold', default=0.5, type=float) # binary threshold for mask 137 | parser.add_argument('--ngpu', default=8, type=int, help='gpu number when inference for ref-ytvos and refer_davis') 138 | parser.add_argument('--split', default='valid', type=str, choices=['valid', 'test']) 139 | parser.add_argument('--visualize', action='store_true', help='whether visualize the masks during inference') 140 | 141 | # distributed training parameters 142 | parser.add_argument('--world_size', default=1, type=int, 143 | help='number of distributed processes') 144 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 145 | parser.add_argument('--cache_mode', default=False, action='store_true', help='whether to cache images on memory') 146 | return parser 147 | 148 | 149 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.12.5 2 | cython 3 | scipy 4 | opencv-python 5 | pillow 6 | scikit-image 7 | timm 8 | einops 9 | pandas 10 | imgaug 11 | h5py 12 | av -------------------------------------------------------------------------------- /scripts/dist_test_davis_videoswinb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | cd .. 
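# This script runs Ref-DAVIS17 inference with the Video Swin-B backbone on the GPUs
# listed in ${GPUS}, then scores the saved predictions against each of the four DAVIS
# annotator ground truths (anno_0 ... anno_3) with eval_davis.py. Adjust GPUS and
# GPUS_PER_NODE to match your machine; CHECKPOINT should point to the SgMg weights
# you want to evaluate (here, a model trained on Ref-YouTube-VOS).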
4 | 5 | GPUS='0,1' 6 | GPUS_PER_NODE=2 7 | CPUS_PER_TASK=6 8 | PORT=29500 9 | export CUDA_VISIBLE_DEVICES=${GPUS} 10 | echo "using gpus ${GPUS}, master port ${PORT}." 11 | now=$(date +"%T") 12 | echo "Current time : $now" 13 | echo "Current path : $PWD" 14 | 15 | BACKBONE="video_swin_b_p4w7" 16 | BACKBONE_PRETRAINED="./checkpoints/backbones/swin_base_patch244_window877_kinetics600_22k.pth" 17 | OUTPUT_DIR="./checkpoints/results/SgMg_${BACKBONE}_eval" 18 | CHECKPOINT="./checkpoints/sgmg_videosiwnb_ytvos.pth" 19 | python inference_davis.py --with_box_refine --binary --freeze_text_encoder \ 20 | --eval \ 21 | --ngpu=${GPUS_PER_NODE} \ 22 | --output_dir=${OUTPUT_DIR} \ 23 | --resume=${CHECKPOINT} \ 24 | --backbone=${BACKBONE} \ 25 | --backbone_pretrained=${BACKBONE_PRETRAINED} \ 26 | --amp \ 27 | 28 | 29 | # evaluation 30 | ANNO0_DIR=${OUTPUT_DIR}/"DVS_Annotations"/"anno_0" 31 | ANNO1_DIR=${OUTPUT_DIR}/"DVS_Annotations"/"anno_1" 32 | ANNO2_DIR=${OUTPUT_DIR}/"DVS_Annotations"/"anno_2" 33 | ANNO3_DIR=${OUTPUT_DIR}/"DVS_Annotations"/"anno_3" 34 | echo "Annotations store at : ${ANNO0_DIR}" 35 | rm ${ANNO0_DIR}"/global_results-val.csv" 36 | rm ${ANNO0_DIR}"/per-sequence_results-val.csv" 37 | rm ${ANNO1_DIR}"/global_results-val.csv" 38 | rm ${ANNO1_DIR}"/per-sequence_results-val.csv" 39 | rm ${ANNO2_DIR}"/global_results-val.csv" 40 | rm ${ANNO2_DIR}"/per-sequence_results-val.csv" 41 | rm ${ANNO3_DIR}"/global_results-val.csv" 42 | rm ${ANNO3_DIR}"/per-sequence_results-val.csv" 43 | 44 | python3 eval_davis.py --results_path=${ANNO0_DIR} 45 | python3 eval_davis.py --results_path=${ANNO1_DIR} 46 | python3 eval_davis.py --results_path=${ANNO2_DIR} 47 | python3 eval_davis.py --results_path=${ANNO3_DIR} 48 | 49 | echo "Working path is: ${OUTPUT_DIR}" 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /scripts/dist_test_ytvos_videoswinb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | cd .. 4 | 5 | GPUS='0,1' 6 | GPUS_PER_NODE=2 7 | CPUS_PER_TASK=6 8 | PORT=29500 9 | export CUDA_VISIBLE_DEVICES=${GPUS} 10 | echo "using gpus ${GPUS}, master port ${PORT}." 11 | now=$(date +"%T") 12 | echo "Current time : $now" 13 | echo "Current path : $PWD" 14 | 15 | BACKBONE="video_swin_b_p4w7" 16 | BACKBONE_PRETRAINED="./checkpoints/backbones/swin_base_patch244_window877_kinetics600_22k.pth" 17 | OUTPUT_DIR="./checkpoints/results/SgMg_${BACKBONE}_eval" 18 | CHECKPOINT="./checkpoints/sgmg_videosiwnb_ytvos.pth" 19 | python inference_ytvos.py --with_box_refine --binary --freeze_text_encoder \ 20 | --eval \ 21 | --ngpu=${GPUS_PER_NODE} \ 22 | --output_dir=${OUTPUT_DIR} \ 23 | --resume=${CHECKPOINT} \ 24 | --backbone=${BACKBONE} \ 25 | --backbone_pretrained=${BACKBONE_PRETRAINED} \ 26 | --amp \ 27 | 28 | 29 | -------------------------------------------------------------------------------- /scripts/dist_train_a2d_videoswinb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | cd .. 4 | 5 | GPUS='0,1' 6 | PORT=25503 7 | GPUS_PER_NODE=2 8 | CPUS_PER_TASK=6 9 | export CUDA_VISIBLE_DEVICES=${GPUS} 10 | echo "using gpus ${GPUS}, master port ${PORT}." 
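# Fine-tunes SgMg on A2D-Sentences via torchrun. Before running, replace the
# PRETRAINED_WEIGHTS placeholder below with the path to an SgMg checkpoint to
# initialize from; the run uses batch size 2 for 6 epochs with learning-rate drops
# at epochs 3 and 5, and enables activation checkpointing for the Video Swin-B backbone.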
11 | now=$(date +"%T") 12 | echo "Current time : $now" 13 | echo "Current path : $PWD" 14 | 15 | BACKBONE="video_swin_b_p4w7" 16 | BACKBONE_PRETRAINED="./checkpoints/backbones/swin_base_patch244_window877_kinetics600_22k.pth" 17 | OUTPUT_DIR="./checkpoints/results/SgMg_${BACKBONE}_finetune_a2d" 18 | EXP_NAME="SgMg_${BACKBONE}_finetune_a2d" 19 | PRETRAINED_WEIGHTS="xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" 20 | CUDA_VISIBLE_DEVICES=${GPUS} OMP_NUM_THREADS=${CPUS_PER_TASK} torchrun --master_port ${PORT} --nproc_per_node=${GPUS_PER_NODE} main.py \ 21 | --with_box_refine --binary --freeze_text_encoder \ 22 | --exp_name=${EXP_NAME} \ 23 | --output_dir=${OUTPUT_DIR} \ 24 | --backbone=${BACKBONE} \ 25 | --backbone_pretrained=${BACKBONE_PRETRAINED} \ 26 | --dataset_file a2d \ 27 | --batch_size 2 \ 28 | --epochs 6 --lr_drop 3 5 \ 29 | --pretrained_weights=${PRETRAINED_WEIGHTS} \ 30 | --use_checkpoint \ 31 | 32 | 33 | -------------------------------------------------------------------------------- /scripts/dist_train_scratch_ytvos_videoswin.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | cd .. 4 | 5 | GPUS='0,1' 6 | PORT=25500 7 | GPUS_PER_NODE=2 8 | CPUS_PER_TASK=6 9 | export CUDA_VISIBLE_DEVICES=${GPUS} 10 | echo "using gpus ${GPUS}, master port ${PORT}." 11 | now=$(date +"%T") 12 | echo "Current time : $now" 13 | echo "Current path : $PWD" 14 | 15 | BACKBONE="video_swin_t_p4w7" 16 | BACKBONE_PRETRAINED="./checkpoints/backbones/swin_tiny_patch244_window877_kinetics400_1k.pth" 17 | OUTPUT_DIR="./checkpoints/results/SgMg_${BACKBONE}_scratch" 18 | EXP_NAME="SgMg_${BACKBONE}_scratch" 19 | CUDA_VISIBLE_DEVICES=${GPUS} OMP_NUM_THREADS=${CPUS_PER_TASK} torchrun --master_port ${PORT} --nproc_per_node=${GPUS_PER_NODE} main.py \ 20 | --with_box_refine --binary --freeze_text_encoder \ 21 | --output_dir=${OUTPUT_DIR} \ 22 | --exp_name=${EXP_NAME} \ 23 | --backbone=${BACKBONE} \ 24 | --backbone_pretrained=${BACKBONE_PRETRAINED} \ 25 | --dataset_file ytvos \ 26 | --amp 27 | -------------------------------------------------------------------------------- /scripts/dist_train_ytvos_videoswin.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | cd .. 4 | 5 | GPUS='0,1' 6 | PORT=25500 7 | GPUS_PER_NODE=2 8 | CPUS_PER_TASK=6 9 | export CUDA_VISIBLE_DEVICES=${GPUS} 10 | echo "using gpus ${GPUS}, master port ${PORT}." 
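# Two-stage training with the Video Swin-T backbone: main_pretrain.py first pretrains
# on the combined RefCOCO/RefCOCO+/RefCOCOg image datasets (--dataset_file all, with
# single-frame clips), then main.py fine-tunes on Ref-YouTube-VOS starting from the
# stage-one checkpoint (checkpoint0010.pth).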
11 | now=$(date +"%T") 12 | echo "Current time : $now" 13 | echo "Current path : $PWD" 14 | 15 | BACKBONE="video_swin_t_p4w7" 16 | BACKBONE_PRETRAINED="./checkpoints/backbones/swin_tiny_patch244_window877_kinetics400_1k.pth" 17 | OUTPUT_DIR1="./checkpoints/results/SgMg_${BACKBONE}_pretrain" 18 | EXP_NAME1="SgMg_${BACKBONE}_pretrain" 19 | CUDA_VISIBLE_DEVICES=${GPUS} OMP_NUM_THREADS=${CPUS_PER_TASK} torchrun --master_port ${PORT} --nproc_per_node=${GPUS_PER_NODE} main_pretrain.py \ 20 | --dataset_file all \ 21 | --with_box_refine --binary \ 22 | --output_dir=${OUTPUT_DIR1} \ 23 | --exp_name=${EXP_NAME1} \ 24 | --backbone=${BACKBONE} \ 25 | --backbone_pretrained=${BACKBONE_PRETRAINED} \ 26 | --batch_size 2 \ 27 | --num_frames 1 \ 28 | --epochs 11 --lr_drop 8 10 \ 29 | 30 | 31 | OUTPUT_DIR2="./checkpoints/results/SgMg_${BACKBONE}_finetune" 32 | EXP_NAME2="SgMg_${BACKBONE}_finetune" 33 | CUDA_VISIBLE_DEVICES=${GPUS} OMP_NUM_THREADS=${CPUS_PER_TASK} torchrun --master_port ${PORT} --nproc_per_node=${GPUS_PER_NODE} main.py \ 34 | --with_box_refine --binary --freeze_text_encoder \ 35 | --output_dir=${OUTPUT_DIR2} \ 36 | --exp_name=${EXP_NAME2} \ 37 | --backbone=${BACKBONE} \ 38 | --backbone_pretrained=${BACKBONE_PRETRAINED} \ 39 | --epochs 6 --lr_drop 3 5 \ 40 | --dataset_file ytvos \ 41 | --pretrained_weights ${OUTPUT_DIR1}"/checkpoint0010.pth" \ 42 | -------------------------------------------------------------------------------- /scripts/dist_train_ytvos_videoswinb.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | cd .. 4 | 5 | GPUS='0,1' 6 | PORT=25501 7 | GPUS_PER_NODE=2 8 | CPUS_PER_TASK=6 9 | export CUDA_VISIBLE_DEVICES=${GPUS} 10 | echo "using gpus ${GPUS}, master port ${PORT}." 
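# Same two-stage recipe as dist_train_ytvos_videoswin.sh, but with the Video Swin-B
# backbone initialized from the Kinetics-600 (22K) pretrained weights referenced by
# BACKBONE_PRETRAINED below.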
11 | now=$(date +"%T") 12 | echo "Current time : $now" 13 | echo "Current path : $PWD" 14 | 15 | BACKBONE="video_swin_b_p4w7" 16 | BACKBONE_PRETRAINED="./checkpoints/backbones/swin_base_patch244_window877_kinetics600_22k.pth" 17 | OUTPUT_DIR1="./checkpoints/results/SgMg_${BACKBONE}_pretrain" 18 | EXP_NAME1="SgMg_${BACKBONE}_pretrain" 19 | CUDA_VISIBLE_DEVICES=${GPUS} OMP_NUM_THREADS=${CPUS_PER_TASK} torchrun --master_port ${PORT} --nproc_per_node=${GPUS_PER_NODE} main_pretrain.py \ 20 | --dataset_file all \ 21 | --with_box_refine --binary \ 22 | --output_dir=${OUTPUT_DIR1} \ 23 | --exp_name=${EXP_NAME1} \ 24 | --backbone=${BACKBONE} \ 25 | --backbone_pretrained=${BACKBONE_PRETRAINED} \ 26 | --batch_size 2 \ 27 | --num_frames 1 \ 28 | --epochs 11 --lr_drop 8 10 \ 29 | 30 | 31 | OUTPUT_DIR2="./checkpoints/results/SgMg_${BACKBONE}_finetune" 32 | EXP_NAME2="SgMg_${BACKBONE}_finetune" 33 | CUDA_VISIBLE_DEVICES=${GPUS} OMP_NUM_THREADS=${CPUS_PER_TASK} torchrun --master_port ${PORT} --nproc_per_node=${GPUS_PER_NODE} main.py \ 34 | --with_box_refine --binary --freeze_text_encoder \ 35 | --output_dir=${OUTPUT_DIR2} \ 36 | --exp_name=${EXP_NAME2} \ 37 | --backbone=${BACKBONE} \ 38 | --backbone_pretrained=${BACKBONE_PRETRAINED} \ 39 | --epochs 6 --lr_drop 3 5 \ 40 | --dataset_file ytvos \ 41 | --pretrained_weights ${OUTPUT_DIR1}"/checkpoint0010.pth" \ 42 | 43 | -------------------------------------------------------------------------------- /tools/colormap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | -------------------------------------------------------------------------------- /tools/data/convert_refexp_to_coco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from datasets.refer import REFER 4 | import cv2 5 | from tqdm import tqdm 6 | import json 7 | import pickle 8 | import json 9 | 10 | 11 | def convert_to_coco(data_root='data/coco', output_root='data/coco', dataset='refcoco', dataset_split='unc'): 12 | dataset_dir = os.path.join(data_root, dataset) 13 | output_dir = os.path.join(output_root, dataset) # .json save path 14 | if not os.path.exists(output_dir): 15 | os.makedirs(output_dir) 16 | 17 | # read REFER 18 | refer = REFER(data_root, dataset, dataset_split) 19 | refs = refer.Refs 20 | anns = refer.Anns 21 | imgs = refer.Imgs 22 | cats = refer.Cats 23 | sents = refer.Sents 24 | """ 25 | # create sets of mapping 26 | # 1) Refs: {ref_id: ref} 27 | # 2) Anns: {ann_id: ann} 28 | # 3) Imgs: {image_id: image} 29 | # 4) Cats: {category_id: category_name} 30 | # 5) Sents: {sent_id: sent} 31 | # 6) imgToRefs: {image_id: refs} 32 | # 7) imgToAnns: {image_id: anns} 33 | # 8) refToAnn: {ref_id: ann} 34 | # 9) annToRef: {ann_id: ref} 35 | # 10) catToRefs: {category_id: refs} 36 | # 11) sentToRef: {sent_id: ref} 37 | # 12) sentToTokens: {sent_id: tokens} 38 | 39 | Refs: List[Dict], "sent_ids", "file_name", "ann_id", "ref_id", "image_id", "category_id", "split", "sentences" 40 | "sentences": List[Dict], "tokens"(List), "raw", "sent_id", "sent" 41 | Anns: List[Dict], "segmentation", "area", "iscrowd", "image_id", "bbox", "category_id", "id" 42 | Imgs: List[Dict], "license", "file_name", "coco_url", "height", "width", "date_captured", "flickr_url", "id" 43 | Cats: List[Dict], "supercategory", "name", "id" 44 | Sents: List[Dict], "tokens"(List), "raw", "sent_id", "sent", here the "sent_id" is consistent 45 | """ 46 | print('Dataset [%s_%s] contains: ' % (dataset, 
dataset_split)) 47 | ref_ids = refer.getRefIds() 48 | image_ids = refer.getImgIds() 49 | print('There are %s expressions for %s refereed objects in %s images.' % (len(refer.Sents), len(ref_ids), len(image_ids))) 50 | 51 | print('\nAmong them:') 52 | if dataset == 'refcoco': 53 | splits = ['train', 'val', 'testA', 'testB'] 54 | elif dataset == 'refcoco+': 55 | splits = ['train', 'val', 'testA', 'testB'] 56 | elif dataset == 'refcocog': 57 | splits = ['train', 'val', 'test'] # we don't have test split for refcocog right now. 58 | 59 | for split in splits: 60 | ref_ids = refer.getRefIds(split=split) 61 | print(' %s referred objects are in split [%s].' % (len(ref_ids), split)) 62 | 63 | with open(os.path.join(dataset_dir, "instances.json"), "r") as f: 64 | ann_json = json.load(f) 65 | 66 | 67 | # 1. for each split: train, val... 68 | for split in splits: 69 | max_length = 0 # max length of a sentence 70 | 71 | coco_ann = { 72 | "info": "", 73 | "licenses": "", 74 | "images": [], # each caption is a image sample 75 | "annotations": [], 76 | "categories": [] 77 | } 78 | coco_ann['info'], coco_ann['licenses'], coco_ann['categories'] = \ 79 | ann_json['info'], ann_json['licenses'], ann_json['categories'] 80 | 81 | num_images = 0 # each caption is a sample, create a "images" and a "annotations", since each image has one box 82 | ref_ids = refer.getRefIds(split=split) 83 | # 2. for each referred object 84 | for i in tqdm(ref_ids): 85 | ref = refs[i] 86 | # "sent_ids", "file_name", "ann_id", "ref_id", "image_id", "category_id", "split", "sentences" 87 | # "sentences": List[Dict], "tokens"(List), "raw", "sent_id", "sent" 88 | img = imgs[ref["image_id"]] 89 | ann = anns[ref["ann_id"]] 90 | 91 | # 3. for each sentence, which is a sample 92 | for sentence in ref["sentences"]: 93 | num_images += 1 94 | # append image info 95 | image_info = { 96 | "file_name": img["file_name"], 97 | "height": img["height"], 98 | "width": img["width"], 99 | "original_id": img["id"], 100 | "id": num_images, 101 | "caption": sentence["sent"], 102 | "dataset_name": dataset 103 | } 104 | coco_ann["images"].append(image_info) 105 | 106 | # append annotation info 107 | ann_info = { 108 | "segmentation": ann["segmentation"], 109 | "area": ann["area"], 110 | "iscrowd": ann["iscrowd"], 111 | "bbox": ann["bbox"], 112 | "image_id": num_images, 113 | "category_id": ann["category_id"], 114 | "id": num_images, 115 | "original_id": ann["id"] 116 | } 117 | coco_ann["annotations"].append(ann_info) 118 | 119 | max_length = max(max_length, len(sentence["tokens"])) 120 | 121 | print("Total expression: {} in split {}".format(num_images, split)) 122 | print("Max sentence length of the split: ", max_length) 123 | # save the json file 124 | save_file = "instances_{}_{}.json".format(dataset, split) 125 | with open(os.path.join(output_dir, save_file), 'w') as f: 126 | json.dump(coco_ann, f, indent=4) 127 | 128 | if __name__ == '__main__': 129 | datasets = ["refcoco", "refcoco+", "refcocog"] 130 | datasets_split = ["unc", "unc", "umd"] 131 | for (dataset, dataset_split) in zip(datasets, datasets_split): 132 | convert_to_coco(data_root='/home2/useradmin/Documents/dataset/coco', output_root='/home2/useradmin/Documents/dataset/coco2', dataset=dataset, dataset_split=dataset_split) 133 | print("Done.") 134 | 135 | """ 136 | # original mapping 137 | {'person': 1, 'bicycle': 2, 'car': 3, 'motorcycle': 4, 'airplane': 5, 'bus': 6, 'train': 7, 'truck': 8, 'boat': 9, 138 | 'traffic light': 10, 'fire hydrant': 11, 'stop sign': 13, 'parking meter': 14, 'bench': 
15, 'bird': 16, 'cat': 17, 139 | 'dog': 18, 'horse': 19, 'sheep': 20, 'cow': 21, 'elephant': 22, 'bear': 23, 'zebra': 24, 'giraffe': 25, 'backpack': 27, 140 | 'umbrella': 28, 'handbag': 31, 'tie': 32, 'suitcase': 33, 'frisbee': 34, 'skis': 35, 'snowboard': 36, 'sports ball': 37, 141 | 'kite': 38, 'baseball bat': 39, 'baseball glove': 40, 'skateboard': 41, 'surfboard': 42, 'tennis racket': 43, 'bottle': 44, 142 | 'wine glass': 46, 'cup': 47, 'fork': 48, 'knife': 49, 'spoon': 50, 'bowl': 51, 'banana': 52, 'apple': 53, 'sandwich': 54, 143 | 'orange': 55, 'broccoli': 56, 'carrot': 57, 'hot dog': 58, 'pizza': 59, 'donut': 60, 'cake': 61, 'chair': 62, 'couch': 63, 144 | 'potted plant': 64, 'bed': 65, 'dining table': 67, 'toilet': 70, 'tv': 72, 'laptop': 73, 'mouse': 74, 'remote': 75, 145 | 'keyboard': 76, 'cell phone': 77, 'microwave': 78, 'oven': 79, 'toaster': 80, 'sink': 81, 'refrigerator': 82, 'book': 84, 146 | 'clock': 85, 'vase': 86, 'scissors': 87, 'teddy bear': 88, 'hair drier': 89, 'toothbrush': 90} 147 | 148 | """ 149 | -------------------------------------------------------------------------------- /tools/load_pretrained_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def pre_trained_model_to_finetune(checkpoint, args): 4 | checkpoint = checkpoint['model'] 5 | # only delete the class_embed since the finetuned dataset has different num_classes 6 | num_layers = args.dec_layers + 1 if args.two_stage else args.dec_layers 7 | for l in range(num_layers): 8 | del checkpoint["class_embed.{}.weight".format(l)] 9 | del checkpoint["class_embed.{}.bias".format(l)] 10 | 11 | return checkpoint 12 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bo-miao/SgMg/90fd3c476858218b1ed0c21ec28e64e762ca2c84/util/__init__.py -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for bounding box manipulation and GIoU. 
3 | """ 4 | import torch 5 | from torchvision.ops.boxes import box_area 6 | 7 | def clip_iou(boxes1,boxes2): 8 | area1 = box_area(boxes1) 9 | area2 = box_area(boxes2) 10 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) 11 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) 12 | wh = (rb - lt).clamp(min=0) 13 | inter = wh[:,0] * wh[:,1] 14 | union = area1 + area2 - inter 15 | iou = (inter + 1e-6) / (union+1e-6) 16 | return iou 17 | 18 | def multi_iou(boxes1, boxes2): 19 | lt = torch.max(boxes1[...,:2], boxes2[...,:2]) 20 | rb = torch.min(boxes1[...,2:], boxes2[...,2:]) 21 | wh = (rb - lt).clamp(min=0) 22 | wh_1 = boxes1[...,2:] - boxes1[...,:2] 23 | wh_2 = boxes2[...,2:] - boxes2[...,:2] 24 | inter = wh[...,0] * wh[...,1] 25 | union = wh_1[...,0] * wh_1[...,1] + wh_2[...,0] * wh_2[...,1] - inter 26 | iou = (inter + 1e-6) / (union + 1e-6) 27 | return iou 28 | 29 | def box_cxcywh_to_xyxy(x): 30 | x_c, y_c, w, h = x.unbind(-1) 31 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 32 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 33 | return torch.stack(b, dim=-1) 34 | 35 | 36 | def box_xyxy_to_cxcywh(x): 37 | x0, y0, x1, y1 = x.unbind(-1) 38 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 39 | (x1 - x0), (y1 - y0)] 40 | return torch.stack(b, dim=-1) 41 | 42 | 43 | # modified from torchvision to also return the union 44 | def box_iou(boxes1, boxes2): 45 | area1 = box_area(boxes1) 46 | area2 = box_area(boxes2) 47 | 48 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 49 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 50 | 51 | wh = (rb - lt).clamp(min=0) # [N,M,2] 52 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 53 | 54 | union = area1[:, None] + area2 - inter 55 | 56 | iou = (inter+1e-6) / (union+1e-6) 57 | return iou, union 58 | 59 | 60 | def generalized_box_iou(boxes1, boxes2): 61 | """ 62 | Generalized IoU from https://giou.stanford.edu/ 63 | 64 | The boxes should be in [x0, y0, x1, y1] format 65 | 66 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 67 | and M = len(boxes2) 68 | """ 69 | # degenerate boxes gives inf / nan results 70 | # so do an early check 71 | # if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): 72 | # for i in range(boxes1.shape[0]): 73 | # if not (boxes1[i, 2:] >= boxes1[i, :2]).all(): 74 | # boxes1[i] = torch.zeros_like(boxes1[i]) 75 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all(), "error boxes: {} vs {}.".format(boxes1, boxes2) 76 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all(), "error boxes: {} vs {}.".format(boxes1, boxes2) 77 | iou, union = box_iou(boxes1, boxes2) 78 | 79 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 80 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 81 | 82 | wh = (rb - lt).clamp(min=0) # [N,M,2] 83 | area = wh[:, :, 0] * wh[:, :, 1] 84 | 85 | return iou - ((area - union) + 1e-6) / (area + 1e-6) 86 | 87 | 88 | def masks_to_boxes(masks): 89 | """Compute the bounding boxes around the provided masks 90 | 91 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
92 | 93 | Returns a [N, 4] tensors, with the boxes in xyxy format 94 | """ 95 | if masks.numel() == 0: 96 | return torch.zeros((0, 4), device=masks.device) 97 | 98 | h, w = masks.shape[-2:] 99 | 100 | y = torch.arange(0, h, dtype=torch.float) 101 | x = torch.arange(0, w, dtype=torch.float) 102 | y, x = torch.meshgrid(y, x) 103 | 104 | x_mask = (masks * x.unsqueeze(0)) 105 | x_max = x_mask.flatten(1).max(-1)[0] 106 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 107 | 108 | y_mask = (masks * y.unsqueeze(0)) 109 | y_max = y_mask.flatten(1).max(-1)[0] 110 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 111 | 112 | return torch.stack([x_min, y_min, x_max, y_max], 1) 113 | -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dumps things to tensorboard and console 3 | """ 4 | 5 | import os 6 | import warnings 7 | import git 8 | 9 | import torchvision.transforms as transforms 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | 13 | def tensor_to_numpy(image): 14 | image_np = (image.numpy() * 255).astype('uint8') 15 | return image_np 16 | 17 | def detach_to_cpu(x): 18 | return x.detach().cpu() 19 | 20 | def fix_width_trunc(x): 21 | return ('{:.9s}'.format('{:0.9f}'.format(x))) 22 | 23 | class TensorboardLogger: 24 | def __init__(self, short_id, id, local_rank): 25 | self.short_id = short_id 26 | if self.short_id == 'NULL': 27 | self.short_id = 'DEBUG' 28 | 29 | if id is None: 30 | self.no_log = True 31 | warnings.warn('Logging has been disbaled.') 32 | else: 33 | self.no_log = False 34 | 35 | self.inv_im_trans = transforms.Normalize( 36 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 37 | std=[1/0.229, 1/0.224, 1/0.225]) 38 | 39 | self.inv_seg_trans = transforms.Normalize( 40 | mean=[-0.5/0.5], 41 | std=[1/0.5]) 42 | 43 | log_path = os.path.join('..', 'log', '%s' % id) 44 | os.makedirs(log_path, exist_ok=True) 45 | self.logger = SummaryWriter(log_path) 46 | 47 | self.local_rank = local_rank 48 | self.values = {} 49 | self.counts = {} 50 | 51 | def log_scalar(self, tag, x, step): 52 | if self.no_log: 53 | warnings.warn('Logging has been disabled.') 54 | return 55 | self.logger.add_scalar(tag, x, step) 56 | 57 | def log_metrics(self, l1_tag, l2_tag, val, step, f=None): 58 | tag = l1_tag + '/' + l2_tag 59 | text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val)) 60 | if f is not None: 61 | f.write(text + '\n') 62 | f.flush() 63 | self.log_scalar(tag, val, step) 64 | 65 | def log_im(self, tag, x, step): 66 | if self.no_log: 67 | warnings.warn('Logging has been disabled.') 68 | return 69 | x = detach_to_cpu(x) 70 | x = self.inv_im_trans(x) 71 | x = tensor_to_numpy(x) 72 | self.logger.add_image(tag, x, step) 73 | 74 | def log_cv2(self, tag, x, step): 75 | if self.no_log: 76 | warnings.warn('Logging has been disabled.') 77 | return 78 | x = x.transpose((2, 0, 1)) 79 | self.logger.add_image(tag, x, step) 80 | 81 | def log_seg(self, tag, x, step): 82 | if self.no_log: 83 | warnings.warn('Logging has been disabled.') 84 | return 85 | x = detach_to_cpu(x) 86 | x = self.inv_seg_trans(x) 87 | x = tensor_to_numpy(x) 88 | self.logger.add_image(tag, x, step) 89 | 90 | def log_gray(self, tag, x, step): 91 | if self.no_log: 92 | warnings.warn('Logging has been disabled.') 93 | return 94 | x = detach_to_cpu(x) 95 | x = tensor_to_numpy(x) 96 | 
self.logger.add_image(tag, x, step) 97 | 98 | def log_string(self, tag, x): 99 | print(tag, x) 100 | if self.no_log: 101 | warnings.warn('Logging has been disabled.') 102 | return 103 | self.logger.add_text(tag, x) 104 | 105 | def add_dict(self, tensor_dict, itr): 106 | for k, v in tensor_dict.items(): 107 | self.add_tensor(k, v, itr) 108 | 109 | def add_tensor(self, key, tensor, itr): 110 | if len(key.split("_")) == 3: 111 | self.log_scalar("sublayer_loss/" + key, tensor, itr) 112 | else: 113 | self.log_scalar("main_loss/" + key, tensor, itr) 114 | 115 | 116 | # def add_tensor(self, key, tensor, itr): 117 | # if key not in self.values: 118 | # self.counts[key] = 1 119 | # if type(tensor) == float or type(tensor) == int: 120 | # self.values[key] = tensor 121 | # else: 122 | # self.values[key] = tensor.mean().item() 123 | # else: 124 | # self.counts[key] += 1 125 | # if type(tensor) == float or type(tensor) == int: 126 | # self.values[key] += tensor 127 | # else: 128 | # self.values[key] += tensor.mean().item() 129 | # 130 | # for k, v in self.values.items(): 131 | # if len(k.split("_")) == 3: 132 | # self.log_scalar("sublayer_loss/" + k, v, itr) 133 | # else: 134 | # self.log_scalar("main_loss/"+k, v, itr) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def pre_trained_model_to_finetune(checkpoint, args): 6 | checkpoint = checkpoint['model'] 7 | # only delete the class_embed since the finetuned dataset has different num_classes 8 | num_layers = args.dec_layers + 1 if args.two_stage else args.dec_layers 9 | for l in range(num_layers): 10 | del checkpoint["class_embed.{}.weight".format(l)] 11 | del checkpoint["class_embed.{}.bias".format(l)] 12 | 13 | return checkpoint 14 | 15 | 16 | 17 | def colormap(rgb=False): 18 | color_list = np.array( 19 | [ 20 | 0.000, 0.447, 0.741, 21 | 0.850, 0.325, 0.098, 22 | 0.929, 0.694, 0.125, 23 | 0.494, 0.184, 0.556, 24 | 0.466, 0.674, 0.188, 25 | 0.301, 0.745, 0.933, 26 | 0.635, 0.078, 0.184, 27 | 0.300, 0.300, 0.300, 28 | 0.600, 0.600, 0.600, 29 | 1.000, 0.000, 0.000, 30 | 1.000, 0.500, 0.000, 31 | 0.749, 0.749, 0.000, 32 | 0.000, 1.000, 0.000, 33 | 0.000, 0.000, 1.000, 34 | 0.667, 0.000, 1.000, 35 | 0.333, 0.333, 0.000, 36 | 0.333, 0.667, 0.000, 37 | 0.333, 1.000, 0.000, 38 | 0.667, 0.333, 0.000, 39 | 0.667, 0.667, 0.000, 40 | 0.667, 1.000, 0.000, 41 | 1.000, 0.333, 0.000, 42 | 1.000, 0.667, 0.000, 43 | 1.000, 1.000, 0.000, 44 | 0.000, 0.333, 0.500, 45 | 0.000, 0.667, 0.500, 46 | 0.000, 1.000, 0.500, 47 | 0.333, 0.000, 0.500, 48 | 0.333, 0.333, 0.500, 49 | 0.333, 0.667, 0.500, 50 | 0.333, 1.000, 0.500, 51 | 0.667, 0.000, 0.500, 52 | 0.667, 0.333, 0.500, 53 | 0.667, 0.667, 0.500, 54 | 0.667, 1.000, 0.500, 55 | 1.000, 0.000, 0.500, 56 | 1.000, 0.333, 0.500, 57 | 1.000, 0.667, 0.500, 58 | 1.000, 1.000, 0.500, 59 | 0.000, 0.333, 1.000, 60 | 0.000, 0.667, 1.000, 61 | 0.000, 1.000, 1.000, 62 | 0.333, 0.000, 1.000, 63 | 0.333, 0.333, 1.000, 64 | 0.333, 0.667, 1.000, 65 | 0.333, 1.000, 1.000, 66 | 0.667, 0.000, 1.000, 67 | 0.667, 0.333, 1.000, 68 | 0.667, 0.667, 1.000, 69 | 0.667, 1.000, 1.000, 70 | 1.000, 0.000, 1.000, 71 | 1.000, 0.333, 1.000, 72 | 1.000, 0.667, 1.000, 73 | 0.167, 0.000, 0.000, 74 | 0.333, 0.000, 0.000, 75 | 0.500, 0.000, 0.000, 76 | 0.667, 0.000, 0.000, 77 | 0.833, 0.000, 0.000, 78 | 1.000, 0.000, 0.000, 79 | 0.000, 0.167, 0.000, 80 | 0.000, 0.333, 0.000, 81 | 0.000, 
0.500, 0.000, 82 | 0.000, 0.667, 0.000, 83 | 0.000, 0.833, 0.000, 84 | 0.000, 1.000, 0.000, 85 | 0.000, 0.000, 0.167, 86 | 0.000, 0.000, 0.333, 87 | 0.000, 0.000, 0.500, 88 | 0.000, 0.000, 0.667, 89 | 0.000, 0.000, 0.833, 90 | 0.000, 0.000, 1.000, 91 | 0.000, 0.000, 0.000, 92 | 0.143, 0.143, 0.143, 93 | 0.286, 0.286, 0.286, 94 | 0.429, 0.429, 0.429, 95 | 0.571, 0.571, 0.571, 96 | 0.714, 0.714, 0.714, 97 | 0.857, 0.857, 0.857, 98 | 1.000, 1.000, 1.000 99 | ] 100 | ).astype(np.float32) 101 | color_list = color_list.reshape((-1, 3)) * 255 102 | if not rgb: 103 | color_list = color_list[:, ::-1] 104 | return color_list --------------------------------------------------------------------------------
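The `colormap` helper above returns a fixed palette scaled to 0-255 (BGR by default; pass `rgb=True` for RGB). Below is a minimal sketch, not part of the repository, of how such a palette can be used to colorize an instance-id mask for visualization. The toy `mask` array is a made-up example, and the import assumes the snippet is run from the repository root so that the top-level utils.py is importable.

import numpy as np
from utils import colormap  # top-level utils.py of this repository

# Hypothetical H x W integer mask: 0 = background, 1..N = instance ids.
mask = np.zeros((4, 6), dtype=np.int64)
mask[1:3, 2:5] = 1

palette = colormap(rgb=True).astype(np.uint8)      # (num_colors, 3) RGB palette
overlay = np.zeros((*mask.shape, 3), dtype=np.uint8)
for inst_id in np.unique(mask):
    if inst_id == 0:
        continue                                   # keep background black
    overlay[mask == inst_id] = palette[(inst_id - 1) % len(palette)]
print(overlay.shape)                               # (4, 6, 3)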