├── tools ├── __init__.py ├── load_pretrained_weights.py ├── colormap.py └── data │ └── convert_refexp_to_coco.py ├── util ├── __init__.py └── box_ops.py ├── docs ├── network.png ├── install.md ├── Ref-DAVIS2017.md ├── Ref-YouTube-VOS.md └── data.md ├── davis2017 ├── __init__.py ├── results.py ├── davis.py ├── evaluation.py ├── utils.py └── metrics.py ├── models ├── __init__.py ├── ops │ ├── make.sh │ ├── modules │ │ ├── __init__.py │ │ └── ms_deform_attn.py │ ├── functions │ │ ├── __init__.py │ │ └── ms_deform_attn_func.py │ ├── src │ │ ├── vision.cpp │ │ ├── cuda │ │ │ ├── ms_deform_attn_cuda.h │ │ │ └── ms_deform_attn_cuda.cu │ │ ├── cpu │ │ │ ├── ms_deform_attn_cpu.h │ │ │ └── ms_deform_attn_cpu.cpp │ │ └── ms_deform_attn.h │ ├── setup.py │ └── test.py ├── backbone.py ├── position_encoding.py ├── postprocessors.py ├── convnext.py ├── matcher.py └── criterion.py ├── requirements.txt ├── scripts ├── dist_test_davis.sh ├── dist_test_ytvos.sh └── dist_train.sh ├── LICENSE.txt ├── datasets ├── concat_dataset.py ├── __init__.py ├── refexp_eval.py ├── categories.py ├── image_to_seq_augmenter.py ├── a2d_eval.py ├── coco.py ├── samplers.py ├── refexp.py ├── coco_eval.py ├── refexp2seq.py └── davis.py ├── README.md └── opts.py /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenGVLab/MUTR/HEAD/docs/network.png -------------------------------------------------------------------------------- /davis2017/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | __version__ = '0.1.0' 4 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .mutr import build 2 | 3 | 4 | def build_model(args): 5 | return build(args) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | cython 3 | scipy 4 | opencv-python 5 | pillow 6 | scikit-image 7 | timm 8 | einops 9 | pandas 10 | imgaug 11 | h5py 12 | av -------------------------------------------------------------------------------- /scripts/dist_test_davis.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | 4 | GPUS=1 5 | 6 | python3 inference_davis.py --with_box_refine --binary --freeze_text_encoder --ngpu ${GPUS} --backbone $1 7 | 8 | -------------------------------------------------------------------------------- /scripts/dist_test_ytvos.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | 4 | GPUS=1 5 | 6 | python3 inference_ytvos.py --with_box_refine --binary --freeze_text_encoder --ngpu ${GPUS} --backbone $1 7 | 8 | -------------------------------------------------------------------------------- /tools/load_pretrained_weights.py: 
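The helper defined in this file (its body follows below) prepares a pretrained checkpoint for finetuning by dropping the classification head, since the finetuning dataset has a different number of classes. A minimal usage sketch under stated assumptions: the `args` namespace only carries the `dec_layers` / `two_stage` fields the helper reads, the checkpoint path is a placeholder, and the model being finetuned is constructed elsewhere.

```python
import argparse

import torch

from tools.load_pretrained_weights import pre_trained_model_to_finetune

# Placeholder config: only the fields the helper reads are set here; match your training options.
args = argparse.Namespace(dec_layers=6, two_stage=False)

checkpoint = torch.load("path/to/pretrained_checkpoint.pth", map_location="cpu")  # placeholder path
state_dict = pre_trained_model_to_finetune(checkpoint, args)  # removes class_embed.{l}.weight / .bias

# Load into the model being finetuned; strict=False keeps its freshly initialized class_embed.
# model.load_state_dict(state_dict, strict=False)
```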
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def pre_trained_model_to_finetune(checkpoint, args): 4 | checkpoint = checkpoint['model'] 5 | # only delete the class_embed since the finetuned dataset has different num_classes 6 | num_layers = args.dec_layers + 1 if args.two_stage else args.dec_layers 7 | for l in range(num_layers): 8 | del checkpoint["class_embed.{}.weight".format(l)] 9 | del checkpoint["class_embed.{}.bias".format(l)] 10 | 11 | return checkpoint 12 | -------------------------------------------------------------------------------- /scripts/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | set -x 3 | 4 | GPUS=${GPUS:-8} 5 | PORT=${PORT:-29500} 6 | if [ $GPUS -lt 8 ]; then 7 | GPUS_PER_NODE=${GPUS_PER_NODE:-$GPUS} 8 | else 9 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 10 | fi 11 | 12 | PY_ARGS=${@:1} # Any other arguments 13 | python3 -m torch.distributed.launch --nproc_per_node=${GPUS_PER_NODE} --master_port=${PORT} --use_env \ 14 | train.py \ 15 | --with_box_refine \ 16 | --dataset_file all \ 17 | --binary \ 18 | --batch_size 2 \ 19 | --epochs 12 \ 20 | --lr_drop 8 10 \ 21 | ${PY_ARGS} 22 | -------------------------------------------------------------------------------- /models/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /models/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /models/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 5 | # ------------------------------------------------------------------------------------------------
 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 7 | # ------------------------------------------------------------------------------------------------
 8 | 
 9 | from .ms_deform_attn_func import MSDeformAttnFunction
10 | 
11 | 
--------------------------------------------------------------------------------
/models/ops/src/vision.cpp:
--------------------------------------------------------------------------------
 1 | /*!
 2 | **************************************************************************************************
 3 | * Deformable DETR
 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved.
 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
 6 | **************************************************************************************************
 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
 8 | **************************************************************************************************
 9 | */
10 | 
11 | #include "ms_deform_attn.h"
12 | 
13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
14 |   m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward");
15 |   m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward");
16 | }
17 | 
--------------------------------------------------------------------------------
/docs/install.md:
--------------------------------------------------------------------------------
 1 | # Installation
 2 | 
 3 | This document provides instructions for installing the dependency packages.
 4 | 
 5 | ## Requirements
 6 | 
 7 | We tested the code in the following environment; other versions may also be compatible:
 8 | 
 9 | - CUDA 11.1
10 | - Python 3.7
11 | - PyTorch 1.8.1
12 | 
13 | 
14 | 
15 | ## Setup
16 | 
17 | First, clone the repository locally.
18 | 
19 | ```
20 | git clone https://github.com/OpenGVLab/MUTR.git
21 | ```
22 | 
23 | Then, install PyTorch 1.8.1 in the conda environment.
24 | ```
25 | conda install pytorch==1.8.1 torchvision==0.9.1 torchaudio==0.8.1 -c pytorch
26 | ```
27 | 
28 | Install the necessary packages and pycocotools.
29 | 
30 | ```
31 | pip install -r requirements.txt
32 | pip install 'git+https://github.com/facebookresearch/fvcore'
33 | pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
34 | ```
35 | 
36 | Finally, compile the CUDA operators.
37 | 
38 | ```
39 | cd models/ops
40 | python setup.py build install
41 | cd ../..
42 | ```
43 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
 1 | MIT License
 2 | 
 3 | Copyright (c) 2022 Renrui Zhang
 4 | 
 5 | Permission is hereby granted, free of charge, to any person obtaining a copy
 6 | of this software and associated documentation files (the "Software"), to deal
 7 | in the Software without restriction, including without limitation the rights
 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /datasets/concat_dataset.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR) 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # ------------------------------------------------------------------------ 5 | 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.utils.data 10 | 11 | from torch.utils.data import Dataset, ConcatDataset 12 | from .refexp2seq import build as build_seq_refexp 13 | from .ytvos import build as build_ytvs 14 | from datasets import ytvos 15 | 16 | 17 | 18 | def build(image_set, args): 19 | concat_data = [] 20 | 21 | print('preparing coco2seq dataset ....') 22 | coco_names = ["refcoco", "refcoco+", "refcocog"] 23 | for name in coco_names: 24 | coco_seq = build_seq_refexp(name, image_set, args) 25 | concat_data.append(coco_seq) 26 | 27 | print('preparing ytvos dataset .... ') 28 | ytvos_dataset = build_ytvs(image_set, args) 29 | concat_data.append(ytvos_dataset) 30 | 31 | concat_data = ConcatDataset(concat_data) 32 | 33 | return concat_data 34 | -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import torchvision 3 | 4 | from .ytvos import build as build_ytvos 5 | from .davis import build as build_davis 6 | from .refexp import build as build_refexp 7 | from .refexp2seq import build as build_seq_refexp 8 | from .concat_dataset import build as build_joint 9 | 10 | 11 | def get_coco_api_from_dataset(dataset): 12 | for _ in range(10): 13 | # if isinstance(dataset, torchvision.datasets.CocoDetection): 14 | # break 15 | if isinstance(dataset, torch.utils.data.Subset): 16 | dataset = dataset.dataset 17 | if isinstance(dataset, torchvision.datasets.CocoDetection): 18 | return dataset.coco 19 | 20 | 21 | def build_dataset(dataset_file: str, image_set: str, args): 22 | if dataset_file == 'ytvos': 23 | return build_ytvos(image_set, args) 24 | if dataset_file == 'davis': 25 | return build_davis(image_set, args) 26 | if dataset_file == "refcoco" or dataset_file == "refcoco+" or dataset_file == "refcocog": 27 | return build_seq_refexp(dataset_file, image_set, args) 28 | # for joint training of refcoco and ytvos 29 | if dataset_file == 'joint': 30 | return build_joint(image_set, args) 31 | raise ValueError(f'dataset {dataset_file} not supported') 32 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /davis2017/results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import sys 5 | 6 | 7 | class Results(object): 8 | def __init__(self, root_dir): 9 | self.root_dir = root_dir 10 | 11 | def _read_mask(self, sequence, frame_id): 12 | try: 13 | mask_path = os.path.join(self.root_dir, sequence, f'{frame_id}.png') 14 | return np.array(Image.open(mask_path)) 15 | except IOError as err: 16 | sys.stdout.write(sequence + " frame %s not found!\n" % frame_id) 17 | sys.stdout.write("The frames have to be indexed PNG files placed inside the corespondent sequence " 18 | "folder.\nThe indexes have to match with the initial frame.\n") 19 | sys.stderr.write("IOError: " + err.strerror + "\n") 20 | sys.exit() 21 | 22 | def read_masks(self, sequence, masks_id): 23 | mask_0 = self._read_mask(sequence, masks_id[0]) 24 | masks = np.zeros((len(masks_id), *mask_0.shape)) 25 | for ii, m in enumerate(masks_id): 26 | masks[ii, ...] = self._read_mask(sequence, m) 27 | num_objects = int(np.max(masks)) 28 | tmp = np.ones((num_objects, *masks.shape)) 29 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] 30 | masks = (tmp == masks[None, ...]) > 0 31 | return masks 32 | -------------------------------------------------------------------------------- /models/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /docs/Ref-DAVIS2017.md: -------------------------------------------------------------------------------- 1 | ## Ref-DAVIS 2017 2 | 3 | ### Model Zoo 4 | 5 | As described in the paper, we report the results using the model trained on Ref-Youtube-VOS without finetune. 6 | 7 | | Backbone| J&F | J | F | Model | 8 | | :----: | :----: | :----: | :----: | :----: | 9 | | ResNet-50 | 65.3 | 62.4 | 68.2 | [model](https://drive.google.com/file/d/1bNkR4n7be3hYwtaYp75c2WNrbmyh-Ik2/view?usp=sharing) | 10 | | ResNet-101 | 65.3 | 61.9 | 68.6 | [model](https://drive.google.com/file/d/1ZOev9AZM_GRpnsKjg0_gpRFKGrJcm0S5/view?usp=sharing) | 11 | | Swin-L | 68.0 | 64.8 | 71.3 | [model](https://drive.google.com/file/d/1e2-BXV3HGxPxWFKO-z34PZDBShCzEmz9/view?usp=sharing) | 12 | | Video-Swin-T | 66.5 | 63.0 | 70.0 | [model](https://drive.google.com/file/d/1-TkdQksTrmB253ao99NgnmsrsQkous2V/view?usp=sharing) | 13 | | Video-Swin-S | 66.1 | 62.6 | 69.8 | [model](https://drive.google.com/file/d/1gVeOE20nmZzONTQSdBhPHBg_hBZnXgxI/view?usp=sharing) | 14 | | Video-Swin-B | 66.4 | 62.8 | 70.0 | [model](https://drive.google.com/file/d/11poAYPbJDB2R_DlsDhRrSYvgOzaihpTN/view?usp=sharing) | 15 | | ConvNext-L | 69.0 | 65.6 | 72.4 | [model](https://drive.google.com/file/d/1d6C73EmSpQZBIuhBDu1gnzibDXYxCYDz/view?usp=sharing) | 16 | | ConvMAE-B | 69.2 | 65.6 | 72.8 | [model](https://drive.google.com/file/d/1kM9VLjdzl_YKN09WD6iSzmvtxVYU_NiE/view?usp=sharing) | 17 | 18 | 19 | ### Inference & Evaluation 20 | 21 | ``` 22 | ./scripts/dist_test_davis.sh --backbone [backbone] 23 | ``` 24 | 25 | For example, evaluating the Swin-Large model, run the following command: 26 | 27 | ``` 28 | ./scripts/dist_test_davis.sh --backbone swin_l_p4w7 29 | ``` 30 | Note that, if you use the weights we provide, you should put the weights in the corresponding path. ./results/[backbone]/ckpt/backbone_weight.pth 31 | -------------------------------------------------------------------------------- /models/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /docs/Ref-YouTube-VOS.md: -------------------------------------------------------------------------------- 1 | ### Ref-YouTube-VOS 2 | 3 | To evaluate the results, please upload the zip file to the [competition server](https://codalab.lisn.upsaclay.fr/competitions/3282#participate-submit_results). 
4 | 5 | 6 | | Backbone| J&F | J | F | Model | Submission | 7 | | :----: | :----: | :----: | :----: | :----: | :----: | 8 | | ResNet-50 | 61.9 | 60.4 | 63.4 | [model](https://drive.google.com/file/d/1bNkR4n7be3hYwtaYp75c2WNrbmyh-Ik2/view?usp=sharing) | [link]() | 9 | | ResNet-101 | 63.6 | 61.8 | 65.4 | [model](https://drive.google.com/file/d/1ZOev9AZM_GRpnsKjg0_gpRFKGrJcm0S5/view?usp=sharing) | [link]() | 10 | | Swin-L | 68.4 | 66.4 | 70.4 | [model](https://drive.google.com/file/d/1e2-BXV3HGxPxWFKO-z34PZDBShCzEmz9/view?usp=sharing) | [link]() | 11 | | Video-Swin-T | 64.0 | 62.2 | 65.8 | [model](https://drive.google.com/file/d/1-TkdQksTrmB253ao99NgnmsrsQkous2V/view?usp=sharing) | [link]() | 12 | | Video-Swin-S | 65.1 | 63.0 | 67.1 | [model](https://drive.google.com/file/d/1gVeOE20nmZzONTQSdBhPHBg_hBZnXgxI/view?usp=sharing) | [link]() | 13 | | Video-Swin-B | 67.5 | 65.4 | 69.6 | [model](https://drive.google.com/file/d/11poAYPbJDB2R_DlsDhRrSYvgOzaihpTN/view?usp=sharing) | [link]() | 14 | | ConvNext-L | 66.7 | 64.8 | 68.7 | [model](https://drive.google.com/file/d/1d6C73EmSpQZBIuhBDu1gnzibDXYxCYDz/view?usp=sharing) | [link]() | 15 | | ConvMAE-B | 66.9 | 64.7 | 69.1 | [model](https://drive.google.com/file/d/1kM9VLjdzl_YKN09WD6iSzmvtxVYU_NiE/view?usp=sharing) | [link]() | 16 | 17 | ### Training 18 | 19 | ``` 20 | ./scripts/dist_train.sh --backbone [backbone] --backbone_pretrained [/path/to/backbone_pretrained_weight] [other args] 21 | ``` 22 | 23 | For example, training the Video-Swin-Tiny model, run the following command: 24 | 25 | ``` 26 | ./scripts/dist_train.sh --backbone video_swin_t_p4w7 --backbone_pretrained video_swin_pretrained/swin_tiny_patch244_window877_kinetics400_1k.pth 27 | ``` 28 | 29 | ### Inference & Evaluation 30 | 31 | Inference using the trained model. 32 | ``` 33 | ./scripts/dist_test_ytvos.sh [backbone] 34 | ``` 35 | 36 | For example, evaluating the Swin-Large model, run the following command: 37 | 38 | ``` 39 | ./scripts/dist_test_ytvos.sh swin_l_p4w7 40 | ``` 41 | 42 | To evaluate the results, please upload the zip file to the [competition server](https://codalab.lisn.upsaclay.fr/competitions/3282#participate-submit_results). 43 | 44 | Note that, if you use the weights we provide, you should put the weights in the corresponding path. ./results/[backbone]/ckpt/backbone_weight.pth 45 | 46 | 47 | -------------------------------------------------------------------------------- /models/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=1", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /docs/data.md: -------------------------------------------------------------------------------- 1 | # Data Preparation 2 | 3 | Create a new directory `data` to store all the datasets. 4 | 5 | ## Ref-COCO 6 | 7 | Download the dataset from the official website [COCO](https://cocodataset.org/#download). 8 | RefCOCO/+/g use the COCO2014 train split. 9 | Download the annotation files from [github](https://github.com/lichengunc/refer). 
10 | 
11 | Convert the annotation files:
12 | 
13 | ```
14 | python3 tools/data/convert_refexp_to_coco.py
15 | ```
16 | 
17 | Finally, we expect the directory structure to be the following:
18 | 
19 | ```
20 | MUTR
21 | ├── data
22 | │   ├── coco
23 | │   │   ├── train2014
24 | │   │   ├── refcoco
25 | │   │   │   ├── instances_refcoco_train.json
26 | │   │   │   ├── instances_refcoco_val.json
27 | │   │   ├── refcoco+
28 | │   │   │   ├── instances_refcoco+_train.json
29 | │   │   │   ├── instances_refcoco+_val.json
30 | │   │   ├── refcocog
31 | │   │   │   ├── instances_refcocog_train.json
32 | │   │   │   ├── instances_refcocog_val.json
33 | ```
34 | 
35 | 
36 | ## Ref-YouTube-VOS
37 | 
38 | Download the dataset from the competition's website [here](https://competitions.codalab.org/competitions/29139#participate-get_data).
39 | Then, extract and organize the files. We expect the directory structure to be the following:
40 | 
41 | ```
42 | MUTR
43 | ├── data
44 | │   ├── ref-youtube-vos
45 | │   │   ├── meta_expressions
46 | │   │   ├── train
47 | │   │   │   ├── JPEGImages
48 | │   │   │   ├── Annotations
49 | │   │   │   ├── meta.json
50 | │   │   ├── valid
51 | │   │   │   ├── JPEGImages
52 | ```
53 | 
54 | ## Ref-DAVIS 2017
55 | 
56 | Download the DAVIS2017 dataset from the [website](https://davischallenge.org/davis2017/code.html). Note that you only need to download the two zip files `DAVIS-2017-Unsupervised-trainval-480p.zip` and `DAVIS-2017_semantics-480p.zip`.
57 | Download the text annotations from the [website](https://www.mpi-inf.mpg.de/departments/computer-vision-and-machine-learning/research/video-segmentation/video-object-segmentation-with-language-referring-expressions).
58 | Then, put the zip files in the directory as follows.
59 | 
60 | 
61 | ```
62 | MUTR
63 | ├── data
64 | │   ├── ref-davis
65 | │   │   ├── DAVIS-2017_semantics-480p.zip
66 | │   │   ├── DAVIS-2017-Unsupervised-trainval-480p.zip
67 | │   │   ├── davis_text_annotations.zip
68 | ```
69 | 
70 | Unzip these zip files.
71 | ```
72 | unzip -o davis_text_annotations.zip
73 | unzip -o DAVIS-2017_semantics-480p.zip
74 | unzip -o DAVIS-2017-Unsupervised-trainval-480p.zip
75 | ```
76 | 
77 | Preprocess the dataset into Ref-Youtube-VOS format (make sure you are in the main directory).
78 | 
79 | ```
80 | python tools/data/convert_davis_to_ytvos.py
81 | ```
82 | 
83 | Finally, unzip the file `DAVIS-2017-Unsupervised-trainval-480p.zip` again (since we use `mv` during preprocessing for efficiency).
84 | 85 | ``` 86 | unzip -o DAVIS-2017-Unsupervised-trainval-480p.zip 87 | ``` 88 | -------------------------------------------------------------------------------- /tools/colormap.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def colormap(rgb=False): 5 | color_list = np.array( 6 | [ 7 | 0.000, 0.447, 0.741, 8 | 0.850, 0.325, 0.098, 9 | 0.929, 0.694, 0.125, 10 | 0.494, 0.184, 0.556, 11 | 0.466, 0.674, 0.188, 12 | 0.301, 0.745, 0.933, 13 | 0.635, 0.078, 0.184, 14 | 0.300, 0.300, 0.300, 15 | 0.600, 0.600, 0.600, 16 | 1.000, 0.000, 0.000, 17 | 1.000, 0.500, 0.000, 18 | 0.749, 0.749, 0.000, 19 | 0.000, 1.000, 0.000, 20 | 0.000, 0.000, 1.000, 21 | 0.667, 0.000, 1.000, 22 | 0.333, 0.333, 0.000, 23 | 0.333, 0.667, 0.000, 24 | 0.333, 1.000, 0.000, 25 | 0.667, 0.333, 0.000, 26 | 0.667, 0.667, 0.000, 27 | 0.667, 1.000, 0.000, 28 | 1.000, 0.333, 0.000, 29 | 1.000, 0.667, 0.000, 30 | 1.000, 1.000, 0.000, 31 | 0.000, 0.333, 0.500, 32 | 0.000, 0.667, 0.500, 33 | 0.000, 1.000, 0.500, 34 | 0.333, 0.000, 0.500, 35 | 0.333, 0.333, 0.500, 36 | 0.333, 0.667, 0.500, 37 | 0.333, 1.000, 0.500, 38 | 0.667, 0.000, 0.500, 39 | 0.667, 0.333, 0.500, 40 | 0.667, 0.667, 0.500, 41 | 0.667, 1.000, 0.500, 42 | 1.000, 0.000, 0.500, 43 | 1.000, 0.333, 0.500, 44 | 1.000, 0.667, 0.500, 45 | 1.000, 1.000, 0.500, 46 | 0.000, 0.333, 1.000, 47 | 0.000, 0.667, 1.000, 48 | 0.000, 1.000, 1.000, 49 | 0.333, 0.000, 1.000, 50 | 0.333, 0.333, 1.000, 51 | 0.333, 0.667, 1.000, 52 | 0.333, 1.000, 1.000, 53 | 0.667, 0.000, 1.000, 54 | 0.667, 0.333, 1.000, 55 | 0.667, 0.667, 1.000, 56 | 0.667, 1.000, 1.000, 57 | 1.000, 0.000, 1.000, 58 | 1.000, 0.333, 1.000, 59 | 1.000, 0.667, 1.000, 60 | 0.167, 0.000, 0.000, 61 | 0.333, 0.000, 0.000, 62 | 0.500, 0.000, 0.000, 63 | 0.667, 0.000, 0.000, 64 | 0.833, 0.000, 0.000, 65 | 1.000, 0.000, 0.000, 66 | 0.000, 0.167, 0.000, 67 | 0.000, 0.333, 0.000, 68 | 0.000, 0.500, 0.000, 69 | 0.000, 0.667, 0.000, 70 | 0.000, 0.833, 0.000, 71 | 0.000, 1.000, 0.000, 72 | 0.000, 0.000, 0.167, 73 | 0.000, 0.000, 0.333, 74 | 0.000, 0.000, 0.500, 75 | 0.000, 0.000, 0.667, 76 | 0.000, 0.000, 0.833, 77 | 0.000, 0.000, 1.000, 78 | 0.000, 0.000, 0.000, 79 | 0.143, 0.143, 0.143, 80 | 0.286, 0.286, 0.286, 81 | 0.429, 0.429, 0.429, 82 | 0.571, 0.571, 0.571, 83 | 0.714, 0.714, 0.714, 84 | 0.857, 0.857, 0.857, 85 | 1.000, 1.000, 1.000 86 | ] 87 | ).astype(np.float32) 88 | color_list = color_list.reshape((-1, 3)) * 255 89 | if not rgb: 90 | color_list = color_list[:, ::-1] 91 | return color_list -------------------------------------------------------------------------------- /models/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
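The reference implementation `ms_deform_attn_core_pytorch` in this file (below) documents the tensor layout that both the CUDA kernel and the pure-PyTorch path expect. A minimal shape-convention sketch, mirroring the constants used in models/ops/test.py; it assumes the MultiScaleDeformableAttention extension has already been built (see docs/install.md), since importing this module pulls it in, and that it is run from models/ops as test.py is.

```python
import torch

from functions.ms_deform_attn_func import ms_deform_attn_core_pytorch  # run from models/ops, as in test.py

N, M, D = 1, 2, 2    # batch size, attention heads, channels per head
Lq, L, P = 2, 2, 2   # number of queries, feature levels, sampling points per level
shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long)  # (H, W) of each feature level
S = sum((H * W).item() for H, W in shapes)                    # flattened spatial length across levels

value = torch.rand(N, S, M, D) * 0.01
sampling_locations = torch.rand(N, Lq, M, L, P, 2)            # normalized to [0, 1]
attention_weights = torch.rand(N, Lq, M, L, P) + 1e-5
attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True)

output = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights)
print(output.shape)  # torch.Size([1, 2, 4]), i.e. (N, Lq, M * D)
```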
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | ctx.im2col_step = im2col_step 25 | output = MSDA.ms_deform_attn_forward( 26 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 27 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 28 | return output 29 | 30 | @staticmethod 31 | @once_differentiable 32 | def backward(ctx, grad_output): 33 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 34 | grad_value, grad_sampling_loc, grad_attn_weight = \ 35 | MSDA.ms_deform_attn_backward( 36 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 37 | 38 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 39 | 40 | 41 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 42 | # for debug and test only, 43 | # need to use cuda version instead 44 | N_, S_, M_, D_ = value.shape 45 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 46 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 47 | sampling_grids = 2 * sampling_locations - 1 48 | sampling_value_list = [] 49 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 50 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 51 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 52 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 53 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 54 | # N_*M_, D_, Lq_, P_ 55 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 56 | mode='bilinear', padding_mode='zeros', align_corners=False) 57 | sampling_value_list.append(sampling_value_l_) 58 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 59 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 60 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 61 | return output.transpose(1, 2).contiguous() 62 | -------------------------------------------------------------------------------- /datasets/refexp_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. 
All Rights Reserved 2 | import copy 3 | from collections import defaultdict 4 | from pathlib import Path 5 | 6 | import torch 7 | import torch.utils.data 8 | 9 | import util.misc as utils 10 | from util.box_ops import generalized_box_iou 11 | 12 | 13 | class RefExpEvaluator(object): 14 | def __init__(self, refexp_gt, iou_types, k=(1, 5, 10), thresh_iou=0.5): 15 | assert isinstance(k, (list, tuple)) 16 | refexp_gt = copy.deepcopy(refexp_gt) 17 | self.refexp_gt = refexp_gt 18 | self.iou_types = iou_types 19 | self.img_ids = self.refexp_gt.imgs.keys() 20 | self.predictions = {} 21 | self.k = k 22 | self.thresh_iou = thresh_iou 23 | 24 | def accumulate(self): 25 | pass 26 | 27 | def update(self, predictions): 28 | self.predictions.update(predictions) 29 | 30 | def synchronize_between_processes(self): 31 | all_predictions = utils.all_gather(self.predictions) 32 | merged_predictions = {} 33 | for p in all_predictions: 34 | merged_predictions.update(p) 35 | self.predictions = merged_predictions 36 | 37 | def summarize(self): 38 | if utils.is_main_process(): 39 | dataset2score = { 40 | "refcoco": {k: 0.0 for k in self.k}, 41 | "refcoco+": {k: 0.0 for k in self.k}, 42 | "refcocog": {k: 0.0 for k in self.k}, 43 | } 44 | dataset2count = {"refcoco": 0.0, "refcoco+": 0.0, "refcocog": 0.0} 45 | for image_id in self.img_ids: 46 | ann_ids = self.refexp_gt.getAnnIds(imgIds=image_id) 47 | assert len(ann_ids) == 1 48 | img_info = self.refexp_gt.loadImgs(image_id)[0] 49 | 50 | target = self.refexp_gt.loadAnns(ann_ids[0]) 51 | prediction = self.predictions[image_id] 52 | assert prediction is not None 53 | sorted_scores_boxes = sorted( 54 | zip(prediction["scores"].tolist(), prediction["boxes"].tolist()), reverse=True 55 | ) 56 | sorted_scores, sorted_boxes = zip(*sorted_scores_boxes) 57 | sorted_boxes = torch.cat([torch.as_tensor(x).view(1, 4) for x in sorted_boxes]) 58 | target_bbox = target[0]["bbox"] 59 | converted_bbox = [ 60 | target_bbox[0], 61 | target_bbox[1], 62 | target_bbox[2] + target_bbox[0], 63 | target_bbox[3] + target_bbox[1], 64 | ] 65 | giou = generalized_box_iou(sorted_boxes, torch.as_tensor(converted_bbox).view(-1, 4)) 66 | for k in self.k: 67 | if max(giou[:k]) >= self.thresh_iou: 68 | dataset2score[img_info["dataset_name"]][k] += 1.0 69 | dataset2count[img_info["dataset_name"]] += 1.0 70 | 71 | for key, value in dataset2score.items(): 72 | for k in self.k: 73 | try: 74 | value[k] /= dataset2count[key] 75 | except: 76 | pass 77 | results = {} 78 | for key, value in dataset2score.items(): 79 | results[key] = sorted([v for k, v in value.items()]) 80 | print(f" Dataset: {key} - Precision @ 1, 5, 10: {results[key]} \n") 81 | 82 | return results 83 | return None 84 | 85 | 86 | -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for bounding box manipulation and GIoU. 
3 | """ 4 | import torch 5 | from torchvision.ops.boxes import box_area 6 | 7 | def clip_iou(boxes1,boxes2): 8 | area1 = box_area(boxes1) 9 | area2 = box_area(boxes2) 10 | lt = torch.max(boxes1[:, :2], boxes2[:, :2]) 11 | rb = torch.min(boxes1[:, 2:], boxes2[:, 2:]) 12 | wh = (rb - lt).clamp(min=0) 13 | inter = wh[:,0] * wh[:,1] 14 | union = area1 + area2 - inter 15 | iou = (inter + 1e-6) / (union+1e-6) 16 | return iou 17 | 18 | def multi_iou(boxes1, boxes2): 19 | lt = torch.max(boxes1[...,:2], boxes2[...,:2]) 20 | rb = torch.min(boxes1[...,2:], boxes2[...,2:]) 21 | wh = (rb - lt).clamp(min=0) 22 | wh_1 = boxes1[...,2:] - boxes1[...,:2] 23 | wh_2 = boxes2[...,2:] - boxes2[...,:2] 24 | inter = wh[...,0] * wh[...,1] 25 | union = wh_1[...,0] * wh_1[...,1] + wh_2[...,0] * wh_2[...,1] - inter 26 | iou = (inter + 1e-6) / (union + 1e-6) 27 | return iou 28 | 29 | def box_cxcywh_to_xyxy(x): 30 | x_c, y_c, w, h = x.unbind(-1) 31 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 32 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 33 | return torch.stack(b, dim=-1) 34 | 35 | 36 | def box_xyxy_to_cxcywh(x): 37 | x0, y0, x1, y1 = x.unbind(-1) 38 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 39 | (x1 - x0), (y1 - y0)] 40 | return torch.stack(b, dim=-1) 41 | 42 | 43 | # modified from torchvision to also return the union 44 | def box_iou(boxes1, boxes2): 45 | area1 = box_area(boxes1) 46 | area2 = box_area(boxes2) 47 | 48 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 49 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 50 | 51 | wh = (rb - lt).clamp(min=0) # [N,M,2] 52 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 53 | 54 | union = area1[:, None] + area2 - inter 55 | 56 | iou = (inter+1e-6) / (union+1e-6) 57 | return iou, union 58 | 59 | 60 | def generalized_box_iou(boxes1, boxes2): 61 | """ 62 | Generalized IoU from https://giou.stanford.edu/ 63 | 64 | The boxes should be in [x0, y0, x1, y1] format 65 | 66 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 67 | and M = len(boxes2) 68 | """ 69 | # degenerate boxes gives inf / nan results 70 | # so do an early check 71 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 72 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 73 | iou, union = box_iou(boxes1, boxes2) 74 | 75 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 76 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 77 | 78 | wh = (rb - lt).clamp(min=0) # [N,M,2] 79 | area = wh[:, :, 0] * wh[:, :, 1] 80 | 81 | return iou - ((area - union) + 1e-6) / (area + 1e-6) 82 | 83 | 84 | def masks_to_boxes(masks): 85 | """Compute the bounding boxes around the provided masks 86 | 87 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
88 | 89 | Returns a [N, 4] tensors, with the boxes in xyxy format 90 | """ 91 | if masks.numel() == 0: 92 | return torch.zeros((0, 4), device=masks.device) 93 | 94 | h, w = masks.shape[-2:] 95 | 96 | y = torch.arange(0, h, dtype=torch.float) 97 | x = torch.arange(0, w, dtype=torch.float) 98 | y, x = torch.meshgrid(y, x) 99 | 100 | x_mask = (masks * x.unsqueeze(0)) 101 | x_max = x_mask.flatten(1).max(-1)[0] 102 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 103 | 104 | y_mask = (masks * y.unsqueeze(0)) 105 | y_max = y_mask.flatten(1).max(-1)[0] 106 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 107 | 108 | return torch.stack([x_min, y_min, x_max, y_max], 1) 109 | -------------------------------------------------------------------------------- /datasets/categories.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------------------------- 2 | # 1. Ref-Youtube-VOS 3 | ytvos_category_dict = { 4 | 'airplane': 0, 'ape': 1, 'bear': 2, 'bike': 3, 'bird': 4, 'boat': 5, 'bucket': 6, 'bus': 7, 'camel': 8, 'cat': 9, 5 | 'cow': 10, 'crocodile': 11, 'deer': 12, 'dog': 13, 'dolphin': 14, 'duck': 15, 'eagle': 16, 'earless_seal': 17, 6 | 'elephant': 18, 'fish': 19, 'fox': 20, 'frisbee': 21, 'frog': 22, 'giant_panda': 23, 'giraffe': 24, 'hand': 25, 7 | 'hat': 26, 'hedgehog': 27, 'horse': 28, 'knife': 29, 'leopard': 30, 'lion': 31, 'lizard': 32, 'monkey': 33, 8 | 'motorbike': 34, 'mouse': 35, 'others': 36, 'owl': 37, 'paddle': 38, 'parachute': 39, 'parrot': 40, 'penguin': 41, 9 | 'person': 42, 'plant': 43, 'rabbit': 44, 'raccoon': 45, 'sedan': 46, 'shark': 47, 'sheep': 48, 'sign': 49, 10 | 'skateboard': 50, 'snail': 51, 'snake': 52, 'snowboard': 53, 'squirrel': 54, 'surfboard': 55, 'tennis_racket': 56, 11 | 'tiger': 57, 'toilet': 58, 'train': 59, 'truck': 60, 'turtle': 61, 'umbrella': 62, 'whale': 63, 'zebra': 64 12 | } 13 | 14 | ytvos_category_list = [ 15 | 'airplane', 'ape', 'bear', 'bike', 'bird', 'boat', 'bucket', 'bus', 'camel', 'cat', 'cow', 'crocodile', 16 | 'deer', 'dog', 'dolphin', 'duck', 'eagle', 'earless_seal', 'elephant', 'fish', 'fox', 'frisbee', 'frog', 17 | 'giant_panda', 'giraffe', 'hand', 'hat', 'hedgehog', 'horse', 'knife', 'leopard', 'lion', 'lizard', 18 | 'monkey', 'motorbike', 'mouse', 'others', 'owl', 'paddle', 'parachute', 'parrot', 'penguin', 'person', 19 | 'plant', 'rabbit', 'raccoon', 'sedan', 'shark', 'sheep', 'sign', 'skateboard', 'snail', 'snake', 'snowboard', 20 | 'squirrel', 'surfboard', 'tennis_racket', 'tiger', 'toilet', 'train', 'truck', 'turtle', 'umbrella', 'whale', 'zebra' 21 | ] 22 | 23 | # ------------------------------------------------------------------------------------------------------------------- 24 | # 2. 
Ref-DAVIS17 25 | davis_category_dict = { 26 | 'airplane': 0, 'backpack': 1, 'ball': 2, 'bear': 3, 'bicycle': 4, 'bird': 5, 'boat': 6, 'bottle': 7, 'box': 8, 'bus': 9, 27 | 'camel': 10, 'car': 11, 'carriage': 12, 'cat': 13, 'cellphone': 14, 'chamaleon': 15, 'cow': 16, 'deer': 17, 'dog': 18, 28 | 'dolphin': 19, 'drone': 20, 'elephant': 21, 'excavator': 22, 'fish': 23, 'goat': 24, 'golf cart': 25, 'golf club': 26, 29 | 'grass': 27, 'guitar': 28, 'gun': 29, 'helicopter': 30, 'horse': 31, 'hoverboard': 32, 'kart': 33, 'key': 34, 'kite': 35, 30 | 'koala': 36, 'leash': 37, 'lion': 38, 'lock': 39, 'mask': 40, 'microphone': 41, 'monkey': 42, 'motorcycle': 43, 'oar': 44, 31 | 'paper': 45, 'paraglide': 46, 'person': 47, 'pig': 48, 'pole': 49, 'potted plant': 50, 'puck': 51, 'rack': 52, 'rhino': 53, 32 | 'rope': 54, 'sail': 55, 'scale': 56, 'scooter': 57, 'selfie stick': 58, 'sheep': 59, 'skateboard': 60, 'ski': 61, 'ski poles': 62, 33 | 'snake': 63, 'snowboard': 64, 'stick': 65, 'stroller': 66, 'surfboard': 67, 'swing': 68, 'tennis racket': 69, 'tractor': 70, 34 | 'trailer': 71, 'train': 72, 'truck': 73, 'turtle': 74, 'varanus': 75, 'violin': 76, 'wheelchair': 77 35 | } 36 | 37 | davis_category_list = [ 38 | 'airplane', 'backpack', 'ball', 'bear', 'bicycle', 'bird', 'boat', 'bottle', 'box', 'bus', 'camel', 'car', 'carriage', 39 | 'cat', 'cellphone', 'chamaleon', 'cow', 'deer', 'dog', 'dolphin', 'drone', 'elephant', 'excavator', 'fish', 'goat', 40 | 'golf cart', 'golf club', 'grass', 'guitar', 'gun', 'helicopter', 'horse', 'hoverboard', 'kart', 'key', 'kite', 'koala', 41 | 'leash', 'lion', 'lock', 'mask', 'microphone', 'monkey', 'motorcycle', 'oar', 'paper', 'paraglide', 'person', 'pig', 42 | 'pole', 'potted plant', 'puck', 'rack', 'rhino', 'rope', 'sail', 'scale', 'scooter', 'selfie stick', 'sheep', 'skateboard', 43 | 'ski', 'ski poles', 'snake', 'snowboard', 'stick', 'stroller', 'surfboard', 'swing', 'tennis racket', 'tractor', 'trailer', 44 | 'train', 'truck', 'turtle', 'varanus', 'violin', 'wheelchair' 45 | ] -------------------------------------------------------------------------------- /models/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
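The category tables in `datasets/categories.py` above keep the dict (name to id) and list (id to name) forms in the same order, so a predicted label index can be mapped back to a readable name. A small sanity-check sketch, assuming the repository root is on `PYTHONPATH` and the Python dependencies from requirements.txt are installed:

```python
from datasets.categories import ytvos_category_dict, ytvos_category_list

# The list is ordered by id, so indexing it inverts the dict's name -> id mapping.
assert ytvos_category_list[ytvos_category_dict['person']] == 'person'
assert all(ytvos_category_list[idx] == name for name, idx in ytvos_category_dict.items())
print(len(ytvos_category_list), "Ref-YouTube-VOS categories")  # 65
```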
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = 
grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /datasets/image_to_seq_augmenter.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from SeqFormer (https://github.com/wjf5203/SeqFormer) 3 | # ------------------------------------------------------------------------ 4 | # Modified from STEm-Seg (https://github.com/sabarim/STEm-Seg) 5 | # ------------------------------------------------------------------------ 6 | 7 | 8 | import imgaug 9 | import imgaug.augmenters as iaa 10 | import numpy as np 11 | 12 | from datetime import datetime 13 | 14 | from imgaug.augmentables.segmaps import SegmentationMapsOnImage 15 | from imgaug.augmentables.bbs import BoundingBox, BoundingBoxesOnImage 16 | 17 | 18 | class ImageToSeqAugmenter(object): 19 | def __init__(self, perspective=True, affine=True, motion_blur=True, 20 | brightness_range=(-50, 50), hue_saturation_range=(-15, 15), perspective_magnitude=0.12, 21 | scale_range=1.0, translate_range={"x": (-0.15, 0.15), "y": (-0.15, 0.15)}, rotation_range=(-20, 20), 22 | motion_blur_kernel_sizes=(7, 9), motion_blur_prob=0.5): 23 | 24 | self.basic_augmenter = iaa.SomeOf((1, None), [ 25 | iaa.Add(brightness_range), 26 | iaa.AddToHueAndSaturation(hue_saturation_range) 27 | ] 28 | ) 29 | 30 | transforms = [] 31 | if perspective: 32 | transforms.append(iaa.PerspectiveTransform(perspective_magnitude)) 33 | if affine: 34 | transforms.append(iaa.Affine(scale=scale_range, 35 | translate_percent=translate_range, 36 | rotate=rotation_range, 37 | order=1, # cv2.INTER_LINEAR 38 | backend='auto')) 39 | transforms = iaa.Sequential(transforms) 40 | transforms = [transforms] 41 | 42 | if motion_blur: 43 | blur = iaa.Sometimes(motion_blur_prob, iaa.OneOf( 44 | [ 45 | iaa.MotionBlur(ksize) 46 | for ksize in motion_blur_kernel_sizes 47 | ] 48 | )) 49 | transforms.append(blur) 50 | 51 | self.frame_shift_augmenter = iaa.Sequential(transforms) 52 | 53 | @staticmethod 54 | def condense_masks(instance_masks): 55 | condensed_mask = np.zeros_like(instance_masks[0], dtype=np.int8) 56 | for instance_id, mask in enumerate(instance_masks, 1): 57 | condensed_mask = np.where(mask, instance_id, condensed_mask) 58 | 59 | return condensed_mask 60 | 61 | @staticmethod 62 | def expand_masks(condensed_mask, num_instances): 63 | return [(condensed_mask == instance_id).astype(np.uint8) for instance_id in range(1, num_instances + 1)] 64 | 65 | def __call__(self, image, masks=None, boxes=None): 66 | det_augmenter = self.frame_shift_augmenter.to_deterministic() 67 | 68 | 69 | if masks is not None: 70 | masks_np, is_binary_mask = [], [] 71 | boxs_np = [] 72 | 73 | for mask in masks: 74 | 75 | if isinstance(mask, np.ndarray): 76 | masks_np.append(mask.astype(np.bool)) 77 | is_binary_mask.append(False) 78 | else: 79 | raise ValueError("Invalid 
mask type: {}".format(type(mask))) 80 | 81 | num_instances = len(masks_np) 82 | masks_np = SegmentationMapsOnImage(self.condense_masks(masks_np), shape=image.shape[:2]) 83 | # boxs_np = BoundingBoxesOnImage(boxs_np, shape=image.shape[:2]) 84 | 85 | seed = int(datetime.now().strftime('%M%S%f')[-8:]) 86 | imgaug.seed(seed) 87 | aug_image, aug_masks = det_augmenter(image=self.basic_augmenter(image=image) , segmentation_maps=masks_np) 88 | imgaug.seed(seed) 89 | invalid_pts_mask = det_augmenter(image=np.ones(image.shape[:2] + (1,), np.uint8)).squeeze(2) 90 | aug_masks = self.expand_masks(aug_masks.get_arr(), num_instances) 91 | # aug_boxes = aug_boxes.remove_out_of_image().clip_out_of_image() 92 | aug_masks = [mask for mask, is_bm in zip(aug_masks, is_binary_mask)] 93 | return aug_image, aug_masks #, aug_boxes.to_xyxy_array() 94 | 95 | else: 96 | masks = [SegmentationMapsOnImage(np.ones(image.shape[:2], np.bool), shape=image.shape[:2])] 97 | aug_image, invalid_pts_mask = det_augmenter(image=image, segmentation_maps=masks) 98 | return aug_image, invalid_pts_mask.get_arr() == 0 99 | -------------------------------------------------------------------------------- /datasets/a2d_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains implementations for the precision@k and IoU (mean, overall) evaluation metrics. 3 | copy-paste from https://github.com/mttr2021/MTTR/blob/main/metrics.py 4 | """ 5 | import torch 6 | from tqdm import tqdm 7 | from pycocotools.coco import COCO 8 | from pycocotools.mask import decode 9 | import numpy as np 10 | 11 | from torchvision.ops.boxes import box_area 12 | 13 | def compute_bbox_iou(boxes1: torch.Tensor, boxes2: torch.Tensor): 14 | # both boxes: xyxy 15 | area1 = box_area(boxes1) 16 | area2 = box_area(boxes2) 17 | 18 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 19 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 20 | 21 | wh = (rb - lt).clamp(min=0) # [N,M,2] 22 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 23 | 24 | union = area1[:, None] + area2 - inter 25 | 26 | iou = (inter+1e-6) / (union+1e-6) 27 | return iou, inter, union 28 | 29 | def compute_mask_iou(outputs: torch.Tensor, labels: torch.Tensor, EPS=1e-6): 30 | outputs = outputs.int() 31 | intersection = (outputs & labels).float().sum((1, 2)) # Will be zero if Truth=0 or Prediction=0 32 | union = (outputs | labels).float().sum((1, 2)) # Will be zero if both are 0 33 | iou = (intersection + EPS) / (union + EPS) # EPS is used to avoid division by zero 34 | return iou, intersection, union 35 | 36 | # mask 37 | def calculate_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO): 38 | print('evaluating mask precision@k & iou metrics...') 39 | counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]} 40 | total_intersection_area = 0 41 | total_union_area = 0 42 | ious_list = [] 43 | for instance in tqdm(coco_gt.imgs.keys()): # each image_id contains exactly one instance 44 | gt_annot = coco_gt.imgToAnns[instance][0] 45 | gt_mask = decode(gt_annot['segmentation']) 46 | pred_annots = coco_pred.imgToAnns[instance] 47 | pred_annot = sorted(pred_annots, key=lambda a: a['score'])[-1] # choose pred with highest score 48 | pred_mask = decode(pred_annot['segmentation']) 49 | iou, intersection, union = compute_mask_iou(torch.tensor(pred_mask).unsqueeze(0), 50 | torch.tensor(gt_mask).unsqueeze(0)) 51 | iou, intersection, union = iou.item(), intersection.item(), union.item() 52 | for iou_threshold in 
counters_by_iou.keys(): 53 | if iou > iou_threshold: 54 | counters_by_iou[iou_threshold] += 1 55 | total_intersection_area += intersection 56 | total_union_area += union 57 | ious_list.append(iou) 58 | num_samples = len(ious_list) 59 | precision_at_k = np.array(list(counters_by_iou.values())) / num_samples 60 | overall_iou = total_intersection_area / total_union_area 61 | mean_iou = np.mean(ious_list) 62 | return precision_at_k, overall_iou, mean_iou 63 | 64 | # bbox 65 | def calculate_bbox_precision_at_k_and_iou_metrics(coco_gt: COCO, coco_pred: COCO): 66 | print('evaluating bbox precision@k & iou metrics...') 67 | counters_by_iou = {iou: 0 for iou in [0.5, 0.6, 0.7, 0.8, 0.9]} 68 | total_intersection_area = 0 69 | total_union_area = 0 70 | ious_list = [] 71 | for instance in tqdm(coco_gt.imgs.keys()): # each image_id contains exactly one instance 72 | gt_annot = coco_gt.imgToAnns[instance][0] 73 | gt_bbox = gt_annot['bbox'] # xywh 74 | gt_bbox = [ 75 | gt_bbox[0], 76 | gt_bbox[1], 77 | gt_bbox[2] + gt_bbox[0], 78 | gt_bbox[3] + gt_bbox[1], 79 | ] 80 | pred_annots = coco_pred.imgToAnns[instance] 81 | pred_annot = sorted(pred_annots, key=lambda a: a['score'])[-1] # choose pred with highest score 82 | pred_bbox = pred_annot['bbox'] # xyxy 83 | iou, intersection, union = compute_bbox_iou(torch.tensor(pred_bbox).unsqueeze(0), 84 | torch.tensor(gt_bbox).unsqueeze(0)) 85 | iou, intersection, union = iou.item(), intersection.item(), union.item() 86 | for iou_threshold in counters_by_iou.keys(): 87 | if iou > iou_threshold: 88 | counters_by_iou[iou_threshold] += 1 89 | total_intersection_area += intersection 90 | total_union_area += union 91 | ious_list.append(iou) 92 | num_samples = len(ious_list) 93 | precision_at_k = np.array(list(counters_by_iou.values())) / num_samples 94 | overall_iou = total_intersection_area / total_union_area 95 | mean_iou = np.mean(ious_list) 96 | return precision_at_k, overall_iou, mean_iou 97 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | """ 2 | Backbone modules. 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | from einops import rearrange 14 | 15 | from util.misc import NestedTensor, is_main_process 16 | 17 | from .position_encoding import build_position_encoding 18 | 19 | 20 | class FrozenBatchNorm2d(torch.nn.Module): 21 | """ 22 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 23 | 24 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 25 | without which any other models than torchvision.models.resnet[18,34,50,101] 26 | produce nans. 
27 | """ 28 | 29 | def __init__(self, n): 30 | super(FrozenBatchNorm2d, self).__init__() 31 | self.register_buffer("weight", torch.ones(n)) 32 | self.register_buffer("bias", torch.zeros(n)) 33 | self.register_buffer("running_mean", torch.zeros(n)) 34 | self.register_buffer("running_var", torch.ones(n)) 35 | 36 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 37 | missing_keys, unexpected_keys, error_msgs): 38 | num_batches_tracked_key = prefix + 'num_batches_tracked' 39 | if num_batches_tracked_key in state_dict: 40 | del state_dict[num_batches_tracked_key] 41 | 42 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 43 | state_dict, prefix, local_metadata, strict, 44 | missing_keys, unexpected_keys, error_msgs) 45 | 46 | def forward(self, x): 47 | # move reshapes to the beginning 48 | # to make it fuser-friendly 49 | w = self.weight.reshape(1, -1, 1, 1) 50 | b = self.bias.reshape(1, -1, 1, 1) 51 | rv = self.running_var.reshape(1, -1, 1, 1) 52 | rm = self.running_mean.reshape(1, -1, 1, 1) 53 | eps = 1e-5 54 | scale = w * (rv + eps).rsqrt() 55 | bias = b - rm * scale 56 | return x * scale + bias 57 | 58 | 59 | class BackboneBase(nn.Module): 60 | 61 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): 62 | super().__init__() 63 | for name, parameter in backbone.named_parameters(): 64 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 65 | parameter.requires_grad_(False) 66 | if return_interm_layers: 67 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 68 | # return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} deformable detr 69 | self.strides = [4, 8, 16, 32] 70 | self.num_channels = [256, 512, 1024, 2048] 71 | else: 72 | return_layers = {'layer4': "0"} 73 | self.strides = [32] 74 | self.num_channels = [2048] 75 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 76 | 77 | def forward(self, tensor_list: NestedTensor): 78 | xs = self.body(tensor_list.tensors) 79 | out: Dict[str, NestedTensor] = {} 80 | for name, x in xs.items(): 81 | m = tensor_list.mask 82 | assert m is not None 83 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 84 | out[name] = NestedTensor(x, mask) 85 | return out 86 | 87 | 88 | class Backbone(BackboneBase): 89 | """ResNet backbone with frozen BatchNorm.""" 90 | def __init__(self, name: str, 91 | train_backbone: bool, 92 | return_interm_layers: bool, 93 | dilation: bool): 94 | backbone = getattr(torchvision.models, name)( 95 | replace_stride_with_dilation=[False, False, dilation], 96 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 97 | assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" 98 | super().__init__(backbone, train_backbone, return_interm_layers) 99 | if dilation: 100 | self.strides[-1] = self.strides[-1] // 2 101 | 102 | 103 | class Joiner(nn.Sequential): 104 | def __init__(self, backbone, position_embedding): 105 | super().__init__(backbone, position_embedding) 106 | self.strides = backbone.strides 107 | self.num_channels = backbone.num_channels 108 | 109 | 110 | def forward(self, tensor_list: NestedTensor): 111 | tensor_list.tensors = rearrange(tensor_list.tensors, 'b t c h w -> (b t) c h w') 112 | tensor_list.mask = rearrange(tensor_list.mask, 'b t h w -> (b t) h w') 113 | 114 | xs = self[0](tensor_list) 115 | out: List[NestedTensor] = [] 116 | pos = [] 117 | for name, x in xs.items(): 118 | out.append(x) 119 
| # position encoding 120 | pos.append(self[1](x).to(x.tensors.dtype)) 121 | return out, pos 122 | 123 | 124 | def build_backbone(args): 125 | position_embedding = build_position_encoding(args) 126 | train_backbone = args.lr_backbone > 0 127 | return_interm_layers = args.masks or (args.num_feature_levels > 1) 128 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 129 | model = Joiner(backbone, position_embedding) 130 | model.num_channels = backbone.num_channels 131 | return model 132 | 133 | -------------------------------------------------------------------------------- /datasets/coco.py: -------------------------------------------------------------------------------- 1 | """ 2 | COCO dataset which returns image_id for evaluation. 3 | 4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 5 | """ 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.utils.data 10 | import torchvision 11 | from pycocotools import mask as coco_mask 12 | 13 | import datasets.transforms as T 14 | 15 | 16 | class CocoDetection(torchvision.datasets.CocoDetection): 17 | def __init__(self, img_folder, ann_file, transforms, return_masks): 18 | super(CocoDetection, self).__init__(img_folder, ann_file) 19 | self._transforms = transforms 20 | self.prepare = ConvertCocoPolysToMask(return_masks) 21 | 22 | def __getitem__(self, idx): 23 | img, target = super(CocoDetection, self).__getitem__(idx) 24 | image_id = self.ids[idx] 25 | target = {'image_id': image_id, 'annotations': target} 26 | 27 | img, target = self.prepare(img, target) 28 | if self._transforms is not None: 29 | img, target = self._transforms(img, target) 30 | return img, target 31 | 32 | 33 | def convert_coco_poly_to_mask(segmentations, height, width): 34 | masks = [] 35 | for polygons in segmentations: 36 | rles = coco_mask.frPyObjects(polygons, height, width) 37 | mask = coco_mask.decode(rles) 38 | if len(mask.shape) < 3: 39 | mask = mask[..., None] 40 | mask = torch.as_tensor(mask, dtype=torch.uint8) 41 | mask = mask.any(dim=2) 42 | masks.append(mask) 43 | if masks: 44 | masks = torch.stack(masks, dim=0) 45 | else: 46 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 47 | return masks 48 | 49 | 50 | class ConvertCocoPolysToMask(object): 51 | def __init__(self, return_masks=False): 52 | self.return_masks = return_masks 53 | 54 | def __call__(self, image, target): 55 | w, h = image.size 56 | 57 | image_id = target["image_id"] 58 | image_id = torch.tensor([image_id]) 59 | 60 | anno = target["annotations"] 61 | 62 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 63 | 64 | boxes = [obj["bbox"] for obj in anno] 65 | # guard against no boxes via resizing 66 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 67 | boxes[:, 2:] += boxes[:, :2] 68 | boxes[:, 0::2].clamp_(min=0, max=w) 69 | boxes[:, 1::2].clamp_(min=0, max=h) 70 | 71 | classes = [obj["category_id"] for obj in anno] 72 | classes = torch.tensor(classes, dtype=torch.int64) 73 | 74 | if self.return_masks: 75 | segmentations = [obj["segmentation"] for obj in anno] 76 | masks = convert_coco_poly_to_mask(segmentations, h, w) 77 | 78 | keypoints = None 79 | if anno and "keypoints" in anno[0]: 80 | keypoints = [obj["keypoints"] for obj in anno] 81 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 82 | num_keypoints = keypoints.shape[0] 83 | if num_keypoints: 84 | keypoints = keypoints.view(num_keypoints, -1, 3) 85 | 86 | keep = 
(boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 87 | boxes = boxes[keep] 88 | classes = classes[keep] 89 | if self.return_masks: 90 | masks = masks[keep] 91 | if keypoints is not None: 92 | keypoints = keypoints[keep] 93 | 94 | target = {} 95 | target["boxes"] = boxes 96 | target["labels"] = classes 97 | if self.return_masks: 98 | target["masks"] = masks 99 | target["image_id"] = image_id 100 | if keypoints is not None: 101 | target["keypoints"] = keypoints 102 | 103 | # for conversion to coco api 104 | area = torch.tensor([obj["area"] for obj in anno]) 105 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 106 | target["area"] = area[keep] 107 | target["iscrowd"] = iscrowd[keep] 108 | 109 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 110 | target["size"] = torch.as_tensor([int(h), int(w)]) 111 | 112 | return image, target 113 | 114 | 115 | def make_coco_transforms(image_set): 116 | 117 | normalize = T.Compose([ 118 | T.ToTensor(), 119 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 120 | ]) 121 | 122 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 123 | 124 | if image_set == 'train': 125 | return T.Compose([ 126 | T.RandomHorizontalFlip(), 127 | T.RandomSelect( 128 | T.RandomResize(scales, max_size=1333), 129 | T.Compose([ 130 | T.RandomResize([400, 500, 600]), 131 | T.RandomSizeCrop(384, 600), 132 | T.RandomResize(scales, max_size=1333), 133 | ]) 134 | ), 135 | normalize, 136 | ]) 137 | 138 | if image_set == 'val': 139 | return T.Compose([ 140 | T.RandomResize([800], max_size=1333), 141 | normalize, 142 | ]) 143 | 144 | raise ValueError(f'unknown {image_set}') 145 | 146 | 147 | def build(image_set, args): 148 | root = Path(args.coco_path) 149 | assert root.exists(), f'provided COCO path {root} does not exist' 150 | mode = 'instances' 151 | PATHS = { 152 | "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'), 153 | "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'), 154 | } 155 | img_folder, ann_file = PATHS[image_set] 156 | dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks) 157 | return dataset 158 | -------------------------------------------------------------------------------- /davis2017/davis.py: -------------------------------------------------------------------------------- 1 | import os 2 | from glob import glob 3 | from collections import defaultdict 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | class DAVIS(object): 9 | SUBSET_OPTIONS = ['train', 'val', 'test-dev', 'test-challenge'] 10 | TASKS = ['semi-supervised', 'unsupervised'] 11 | DATASET_WEB = 'https://davischallenge.org/davis2017/code.html' 12 | VOID_LABEL = 255 13 | 14 | def __init__(self, root, task='unsupervised', subset='val', sequences='all', resolution='480p', codalab=False): 15 | """ 16 | Class to read the DAVIS dataset 17 | :param root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. 18 | :param task: Task to load the annotations, choose between semi-supervised or unsupervised. 19 | :param subset: Set to load the annotations 20 | :param sequences: Sequences to consider, 'all' to use all the sequences in a set. 
21 | :param resolution: Specify the resolution to use the dataset, choose between '480' and 'Full-Resolution' 22 | """ 23 | if subset not in self.SUBSET_OPTIONS: 24 | raise ValueError(f'Subset should be in {self.SUBSET_OPTIONS}') 25 | if task not in self.TASKS: 26 | raise ValueError(f'The only tasks that are supported are {self.TASKS}') 27 | 28 | self.task = task 29 | self.subset = subset 30 | self.root = root 31 | self.img_path = os.path.join(self.root, 'JPEGImages', resolution) 32 | annotations_folder = 'Annotations' if task == 'semi-supervised' else 'Annotations_unsupervised' 33 | self.mask_path = os.path.join(self.root, annotations_folder, resolution) 34 | year = '2019' if task == 'unsupervised' and (subset == 'test-dev' or subset == 'test-challenge') else '2017' 35 | self.imagesets_path = os.path.join(self.root, 'ImageSets', year) 36 | 37 | self._check_directories() 38 | 39 | if sequences == 'all': 40 | with open(os.path.join(self.imagesets_path, f'{self.subset}.txt'), 'r') as f: 41 | tmp = f.readlines() 42 | sequences_names = [x.strip() for x in tmp] 43 | else: 44 | sequences_names = sequences if isinstance(sequences, list) else [sequences] 45 | self.sequences = defaultdict(dict) 46 | 47 | for seq in sequences_names: 48 | images = np.sort(glob(os.path.join(self.img_path, seq, '*.jpg'))).tolist() 49 | if len(images) == 0 and not codalab: 50 | raise FileNotFoundError(f'Images for sequence {seq} not found.') 51 | self.sequences[seq]['images'] = images 52 | masks = np.sort(glob(os.path.join(self.mask_path, seq, '*.png'))).tolist() 53 | masks.extend([-1] * (len(images) - len(masks))) 54 | self.sequences[seq]['masks'] = masks 55 | 56 | def _check_directories(self): 57 | if not os.path.exists(self.root): 58 | raise FileNotFoundError(f'DAVIS not found in the specified directory, download it from {self.DATASET_WEB}') 59 | if not os.path.exists(os.path.join(self.imagesets_path, f'{self.subset}.txt')): 60 | raise FileNotFoundError(f'Subset sequences list for {self.subset} not found, download the missing subset ' 61 | f'for the {self.task} task from {self.DATASET_WEB}') 62 | if self.subset in ['train', 'val'] and not os.path.exists(self.mask_path): 63 | raise FileNotFoundError(f'Annotations folder for the {self.task} task not found, download it from {self.DATASET_WEB}') 64 | 65 | def get_frames(self, sequence): 66 | for img, msk in zip(self.sequences[sequence]['images'], self.sequences[sequence]['masks']): 67 | image = np.array(Image.open(img)) 68 | mask = None if msk is None else np.array(Image.open(msk)) 69 | yield image, mask 70 | 71 | def _get_all_elements(self, sequence, obj_type): 72 | obj = np.array(Image.open(self.sequences[sequence][obj_type][0])) 73 | all_objs = np.zeros((len(self.sequences[sequence][obj_type]), *obj.shape)) 74 | obj_id = [] 75 | for i, obj in enumerate(self.sequences[sequence][obj_type]): 76 | all_objs[i, ...] = np.array(Image.open(obj)) 77 | obj_id.append(''.join(obj.split('/')[-1].split('.')[:-1])) 78 | return all_objs, obj_id 79 | 80 | def get_all_images(self, sequence): 81 | return self._get_all_elements(sequence, 'images') 82 | 83 | def get_all_masks(self, sequence, separate_objects_masks=False): 84 | masks, masks_id = self._get_all_elements(sequence, 'masks') 85 | masks_void = np.zeros_like(masks) 86 | 87 | # Separate void and object masks 88 | for i in range(masks.shape[0]): 89 | masks_void[i, ...] = masks[i, ...] == 255 90 | masks[i, masks[i, ...] 
== 255] = 0 91 | 92 | if separate_objects_masks: 93 | num_objects = int(np.max(masks[0, ...])) 94 | tmp = np.ones((num_objects, *masks.shape)) 95 | tmp = tmp * np.arange(1, num_objects + 1)[:, None, None, None] 96 | masks = (tmp == masks[None, ...]) 97 | masks = masks > 0 98 | return masks, masks_void, masks_id 99 | 100 | def get_sequences(self): 101 | for seq in self.sequences: 102 | yield seq 103 | 104 | 105 | if __name__ == '__main__': 106 | from matplotlib import pyplot as plt 107 | 108 | only_first_frame = True 109 | subsets = ['train', 'val'] 110 | 111 | for s in subsets: 112 | dataset = DAVIS(root='/home/csergi/scratch2/Databases/DAVIS2017_private', subset=s) 113 | for seq in dataset.get_sequences(): 114 | g = dataset.get_frames(seq) 115 | img, mask = next(g) 116 | plt.subplot(2, 1, 1) 117 | plt.title(seq) 118 | plt.imshow(img) 119 | plt.subplot(2, 1, 2) 120 | plt.imshow(mask) 121 | plt.show(block=True) 122 | 123 | -------------------------------------------------------------------------------- /datasets/samplers.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from codes in torch.utils.data.distributed 7 | # ------------------------------------------------------------------------ 8 | 9 | import os 10 | import math 11 | import torch 12 | import torch.distributed as dist 13 | from torch.utils.data.sampler import Sampler 14 | 15 | 16 | class DistributedSampler(Sampler): 17 | """Sampler that restricts data loading to a subset of the dataset. 18 | It is especially useful in conjunction with 19 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 20 | process can pass a DistributedSampler instance as a DataLoader sampler, 21 | and load a subset of the original dataset that is exclusive to it. 22 | .. note:: 23 | Dataset is assumed to be of constant size. 24 | Arguments: 25 | dataset: Dataset used for sampling. 26 | num_replicas (optional): Number of processes participating in 27 | distributed training. 28 | rank (optional): Rank of the current process within num_replicas. 
29 | """ 30 | 31 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): 32 | if num_replicas is None: 33 | if not dist.is_available(): 34 | raise RuntimeError("Requires distributed package to be available") 35 | num_replicas = dist.get_world_size() 36 | if rank is None: 37 | if not dist.is_available(): 38 | raise RuntimeError("Requires distributed package to be available") 39 | rank = dist.get_rank() 40 | self.dataset = dataset 41 | self.num_replicas = num_replicas 42 | self.rank = rank 43 | self.epoch = 0 44 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 45 | self.total_size = self.num_samples * self.num_replicas 46 | self.shuffle = shuffle 47 | 48 | def __iter__(self): 49 | if self.shuffle: 50 | # deterministically shuffle based on epoch 51 | g = torch.Generator() 52 | g.manual_seed(self.epoch) 53 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 54 | else: 55 | indices = torch.arange(len(self.dataset)).tolist() 56 | 57 | # add extra samples to make it evenly divisible 58 | indices += indices[: (self.total_size - len(indices))] 59 | assert len(indices) == self.total_size 60 | 61 | # subsample 62 | offset = self.num_samples * self.rank 63 | indices = indices[offset : offset + self.num_samples] 64 | assert len(indices) == self.num_samples 65 | 66 | return iter(indices) 67 | 68 | def __len__(self): 69 | return self.num_samples 70 | 71 | def set_epoch(self, epoch): 72 | self.epoch = epoch 73 | 74 | 75 | class NodeDistributedSampler(Sampler): 76 | """Sampler that restricts data loading to a subset of the dataset. 77 | It is especially useful in conjunction with 78 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 79 | process can pass a DistributedSampler instance as a DataLoader sampler, 80 | and load a subset of the original dataset that is exclusive to it. 81 | .. note:: 82 | Dataset is assumed to be of constant size. 83 | Arguments: 84 | dataset: Dataset used for sampling. 85 | num_replicas (optional): Number of processes participating in 86 | distributed training. 87 | rank (optional): Rank of the current process within num_replicas. 
88 | """ 89 | 90 | def __init__(self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True): 91 | if num_replicas is None: 92 | if not dist.is_available(): 93 | raise RuntimeError("Requires distributed package to be available") 94 | num_replicas = dist.get_world_size() 95 | if rank is None: 96 | if not dist.is_available(): 97 | raise RuntimeError("Requires distributed package to be available") 98 | rank = dist.get_rank() 99 | if local_rank is None: 100 | local_rank = int(os.environ.get('LOCAL_RANK', 0)) 101 | if local_size is None: 102 | local_size = int(os.environ.get('LOCAL_SIZE', 1)) 103 | self.dataset = dataset 104 | self.shuffle = shuffle 105 | self.num_replicas = num_replicas 106 | self.num_parts = local_size 107 | self.rank = rank 108 | self.local_rank = local_rank 109 | self.epoch = 0 110 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 111 | self.total_size = self.num_samples * self.num_replicas 112 | 113 | self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts 114 | 115 | def __iter__(self): 116 | if self.shuffle: 117 | # deterministically shuffle based on epoch 118 | g = torch.Generator() 119 | g.manual_seed(self.epoch) 120 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 121 | else: 122 | indices = torch.arange(len(self.dataset)).tolist() 123 | indices = [i for i in indices if i % self.num_parts == self.local_rank] 124 | 125 | # add extra samples to make it evenly divisible 126 | indices += indices[:(self.total_size_parts - len(indices))] 127 | assert len(indices) == self.total_size_parts 128 | 129 | # subsample 130 | indices = indices[self.rank // self.num_parts:self.total_size_parts:self.num_replicas // self.num_parts] 131 | assert len(indices) == self.num_samples 132 | 133 | return iter(indices) 134 | 135 | def __len__(self): 136 | return self.num_samples 137 | 138 | def set_epoch(self, epoch): 139 | self.epoch = epoch 140 | -------------------------------------------------------------------------------- /davis2017/evaluation.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from tqdm import tqdm 3 | import warnings 4 | warnings.filterwarnings("ignore", category=RuntimeWarning) 5 | 6 | import numpy as np 7 | from davis2017.davis import DAVIS 8 | from davis2017.metrics import db_eval_boundary, db_eval_iou 9 | from davis2017 import utils 10 | from davis2017.results import Results 11 | from scipy.optimize import linear_sum_assignment 12 | 13 | 14 | class DAVISEvaluation(object): 15 | def __init__(self, davis_root, task, gt_set, sequences='all', codalab=False): 16 | """ 17 | Class to evaluate DAVIS sequences from a certain set and for a certain task 18 | :param davis_root: Path to the DAVIS folder that contains JPEGImages, Annotations, etc. folders. 19 | :param task: Task to compute the evaluation, chose between semi-supervised or unsupervised. 20 | :param gt_set: Set to compute the evaluation 21 | :param sequences: Sequences to consider for the evaluation, 'all' to use all the sequences in a set. 
22 | """ 23 | self.davis_root = davis_root 24 | self.task = task 25 | self.dataset = DAVIS(root=davis_root, task=task, subset=gt_set, sequences=sequences, codalab=codalab) 26 | 27 | @staticmethod 28 | def _evaluate_semisupervised(all_gt_masks, all_res_masks, all_void_masks, metric): 29 | if all_res_masks.shape[0] > all_gt_masks.shape[0]: 30 | sys.stdout.write("\nIn your PNG files there is an index higher than the number of objects in the sequence!") 31 | sys.exit() 32 | elif all_res_masks.shape[0] < all_gt_masks.shape[0]: 33 | zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) 34 | all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) 35 | j_metrics_res, f_metrics_res = np.zeros(all_gt_masks.shape[:2]), np.zeros(all_gt_masks.shape[:2]) 36 | for ii in range(all_gt_masks.shape[0]): 37 | if 'J' in metric: 38 | j_metrics_res[ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) 39 | if 'F' in metric: 40 | f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...], all_void_masks) 41 | return j_metrics_res, f_metrics_res 42 | 43 | @staticmethod 44 | def _evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric, max_n_proposals=20): 45 | if all_res_masks.shape[0] > max_n_proposals: 46 | sys.stdout.write(f"\nIn your PNG files there is an index higher than the maximum number ({max_n_proposals}) of proposals allowed!") 47 | sys.exit() 48 | elif all_res_masks.shape[0] < all_gt_masks.shape[0]: 49 | zero_padding = np.zeros((all_gt_masks.shape[0] - all_res_masks.shape[0], *all_res_masks.shape[1:])) 50 | all_res_masks = np.concatenate([all_res_masks, zero_padding], axis=0) 51 | j_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1])) 52 | f_metrics_res = np.zeros((all_res_masks.shape[0], all_gt_masks.shape[0], all_gt_masks.shape[1])) 53 | for ii in range(all_gt_masks.shape[0]): 54 | for jj in range(all_res_masks.shape[0]): 55 | if 'J' in metric: 56 | j_metrics_res[jj, ii, :] = db_eval_iou(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks) 57 | if 'F' in metric: 58 | f_metrics_res[jj, ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[jj, ...], all_void_masks) 59 | if 'J' in metric and 'F' in metric: 60 | all_metrics = (np.mean(j_metrics_res, axis=2) + np.mean(f_metrics_res, axis=2)) / 2 61 | else: 62 | all_metrics = np.mean(j_metrics_res, axis=2) if 'J' in metric else np.mean(f_metrics_res, axis=2) 63 | row_ind, col_ind = linear_sum_assignment(-all_metrics) 64 | return j_metrics_res[row_ind, col_ind, :], f_metrics_res[row_ind, col_ind, :] 65 | 66 | def evaluate(self, res_path, metric=('J', 'F'), debug=False): 67 | metric = metric if isinstance(metric, tuple) or isinstance(metric, list) else [metric] 68 | if 'T' in metric: 69 | raise ValueError('Temporal metric not supported!') 70 | if 'J' not in metric and 'F' not in metric: 71 | raise ValueError('Metric possible values are J for IoU or F for Boundary') 72 | 73 | # Containers 74 | metrics_res = {} 75 | if 'J' in metric: 76 | metrics_res['J'] = {"M": [], "R": [], "D": [], "M_per_object": {}} 77 | if 'F' in metric: 78 | metrics_res['F'] = {"M": [], "R": [], "D": [], "M_per_object": {}} 79 | 80 | # Sweep all sequences 81 | results = Results(root_dir=res_path) 82 | for seq in tqdm(list(self.dataset.get_sequences())): 83 | all_gt_masks, all_void_masks, all_masks_id = self.dataset.get_all_masks(seq, True) 84 | if self.task == 'semi-supervised': 85 | 
all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] 86 | all_res_masks = results.read_masks(seq, all_masks_id) 87 | if self.task == 'unsupervised': 88 | j_metrics_res, f_metrics_res = self._evaluate_unsupervised(all_gt_masks, all_res_masks, all_void_masks, metric) 89 | elif self.task == 'semi-supervised': 90 | j_metrics_res, f_metrics_res = self._evaluate_semisupervised(all_gt_masks, all_res_masks, None, metric) 91 | for ii in range(all_gt_masks.shape[0]): 92 | seq_name = f'{seq}_{ii+1}' 93 | if 'J' in metric: 94 | [JM, JR, JD] = utils.db_statistics(j_metrics_res[ii]) 95 | metrics_res['J']["M"].append(JM) 96 | metrics_res['J']["R"].append(JR) 97 | metrics_res['J']["D"].append(JD) 98 | metrics_res['J']["M_per_object"][seq_name] = JM 99 | if 'F' in metric: 100 | [FM, FR, FD] = utils.db_statistics(f_metrics_res[ii]) 101 | metrics_res['F']["M"].append(FM) 102 | metrics_res['F']["R"].append(FR) 103 | metrics_res['F']["D"].append(FD) 104 | metrics_res['F']["M_per_object"][seq_name] = FM 105 | 106 | # Show progress 107 | if debug: 108 | sys.stdout.write(seq + '\n') 109 | sys.stdout.flush() 110 | return metrics_res 111 | -------------------------------------------------------------------------------- /davis2017/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import errno 3 | import numpy as np 4 | from PIL import Image 5 | import warnings 6 | from davis2017.davis import DAVIS 7 | 8 | 9 | def _pascal_color_map(N=256, normalized=False): 10 | """ 11 | Python implementation of the color map function for the PASCAL VOC data set. 12 | Official Matlab version can be found in the PASCAL VOC devkit 13 | http://host.robots.ox.ac.uk/pascal/VOC/voc2012/index.html#devkit 14 | """ 15 | 16 | def bitget(byteval, idx): 17 | return (byteval & (1 << idx)) != 0 18 | 19 | dtype = 'float32' if normalized else 'uint8' 20 | cmap = np.zeros((N, 3), dtype=dtype) 21 | for i in range(N): 22 | r = g = b = 0 23 | c = i 24 | for j in range(8): 25 | r = r | (bitget(c, 0) << 7 - j) 26 | g = g | (bitget(c, 1) << 7 - j) 27 | b = b | (bitget(c, 2) << 7 - j) 28 | c = c >> 3 29 | 30 | cmap[i] = np.array([r, g, b]) 31 | 32 | cmap = cmap / 255 if normalized else cmap 33 | return cmap 34 | 35 | 36 | def overlay_semantic_mask(im, ann, alpha=0.5, colors=None, contour_thickness=None): 37 | im, ann = np.asarray(im, dtype=np.uint8), np.asarray(ann, dtype=np.int) 38 | if im.shape[:-1] != ann.shape: 39 | raise ValueError('First two dimensions of `im` and `ann` must match') 40 | if im.shape[-1] != 3: 41 | raise ValueError('im must have three channels at the 3 dimension') 42 | 43 | colors = colors or _pascal_color_map() 44 | colors = np.asarray(colors, dtype=np.uint8) 45 | 46 | mask = colors[ann] 47 | fg = im * alpha + (1 - alpha) * mask 48 | 49 | img = im.copy() 50 | img[ann > 0] = fg[ann > 0] 51 | 52 | if contour_thickness: # pragma: no cover 53 | import cv2 54 | for obj_id in np.unique(ann[ann > 0]): 55 | contours = cv2.findContours((ann == obj_id).astype( 56 | np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:] 57 | cv2.drawContours(img, contours[0], -1, colors[obj_id].tolist(), 58 | contour_thickness) 59 | return img 60 | 61 | 62 | def generate_obj_proposals(davis_root, subset, num_proposals, save_path): 63 | dataset = DAVIS(davis_root, subset=subset, codalab=True) 64 | for seq in dataset.get_sequences(): 65 | save_dir = os.path.join(save_path, seq) 66 | if os.path.exists(save_dir): 67 | continue 68 | all_gt_masks, all_masks_id = 
dataset.get_all_masks(seq, True) 69 | img_size = all_gt_masks.shape[2:] 70 | num_rows = int(np.ceil(np.sqrt(num_proposals))) 71 | proposals = np.zeros((num_proposals, len(all_masks_id), *img_size)) 72 | height_slices = np.floor(np.arange(0, img_size[0] + 1, img_size[0]/num_rows)).astype(np.uint).tolist() 73 | width_slices = np.floor(np.arange(0, img_size[1] + 1, img_size[1]/num_rows)).astype(np.uint).tolist() 74 | ii = 0 75 | prev_h, prev_w = 0, 0 76 | for h in height_slices[1:]: 77 | for w in width_slices[1:]: 78 | proposals[ii, :, prev_h:h, prev_w:w] = 1 79 | prev_w = w 80 | ii += 1 81 | if ii == num_proposals: 82 | break 83 | prev_h, prev_w = h, 0 84 | if ii == num_proposals: 85 | break 86 | 87 | os.makedirs(save_dir, exist_ok=True) 88 | for i, mask_id in enumerate(all_masks_id): 89 | mask = np.sum(proposals[:, i, ...] * np.arange(1, proposals.shape[0] + 1)[:, None, None], axis=0) 90 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 91 | 92 | 93 | def generate_random_permutation_gt_obj_proposals(davis_root, subset, save_path): 94 | dataset = DAVIS(davis_root, subset=subset, codalab=True) 95 | for seq in dataset.get_sequences(): 96 | gt_masks, all_masks_id = dataset.get_all_masks(seq, True) 97 | obj_swap = np.random.permutation(np.arange(gt_masks.shape[0])) 98 | gt_masks = gt_masks[obj_swap, ...] 99 | save_dir = os.path.join(save_path, seq) 100 | os.makedirs(save_dir, exist_ok=True) 101 | for i, mask_id in enumerate(all_masks_id): 102 | mask = np.sum(gt_masks[:, i, ...] * np.arange(1, gt_masks.shape[0] + 1)[:, None, None], axis=0) 103 | save_mask(mask, os.path.join(save_dir, f'{mask_id}.png')) 104 | 105 | 106 | def color_map(N=256, normalized=False): 107 | def bitget(byteval, idx): 108 | return ((byteval & (1 << idx)) != 0) 109 | 110 | dtype = 'float32' if normalized else 'uint8' 111 | cmap = np.zeros((N, 3), dtype=dtype) 112 | for i in range(N): 113 | r = g = b = 0 114 | c = i 115 | for j in range(8): 116 | r = r | (bitget(c, 0) << 7-j) 117 | g = g | (bitget(c, 1) << 7-j) 118 | b = b | (bitget(c, 2) << 7-j) 119 | c = c >> 3 120 | 121 | cmap[i] = np.array([r, g, b]) 122 | 123 | cmap = cmap/255 if normalized else cmap 124 | return cmap 125 | 126 | 127 | def save_mask(mask, img_path): 128 | if np.max(mask) > 255: 129 | raise ValueError('Maximum id pixel value is 255') 130 | mask_img = Image.fromarray(mask.astype(np.uint8)) 131 | mask_img.putpalette(color_map().flatten().tolist()) 132 | mask_img.save(img_path) 133 | 134 | 135 | def db_statistics(per_frame_values): 136 | """ Compute mean,recall and decay from per-frame evaluation. 137 | Arguments: 138 | per_frame_values (ndarray): per-frame evaluation 139 | 140 | Returns: 141 | M,O,D (float,float,float): 142 | return evaluation statistics: mean,recall,decay. 
143 | """ 144 | 145 | # strip off nan values 146 | with warnings.catch_warnings(): 147 | warnings.simplefilter("ignore", category=RuntimeWarning) 148 | M = np.nanmean(per_frame_values) 149 | O = np.nanmean(per_frame_values > 0.5) 150 | 151 | N_bins = 4 152 | ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1 153 | ids = ids.astype(np.uint8) 154 | 155 | D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)] 156 | 157 | with warnings.catch_warnings(): 158 | warnings.simplefilter("ignore", category=RuntimeWarning) 159 | D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3]) 160 | 161 | return M, O, D 162 | 163 | 164 | def list_files(dir, extension=".png"): 165 | return [os.path.splitext(file_)[0] for file_ in os.listdir(dir) if file_.endswith(extension)] 166 | 167 | 168 | def force_symlink(file1, file2): 169 | try: 170 | os.symlink(file1, file2) 171 | except OSError as e: 172 | if e.errno == errno.EEXIST: 173 | os.remove(file2) 174 | os.symlink(file1, file2) 175 | -------------------------------------------------------------------------------- /models/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # Modify for sample points visualization 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | from __future__ import absolute_import 11 | from __future__ import print_function 12 | from __future__ import division 13 | 14 | import warnings 15 | import math 16 | 17 | import torch 18 | from torch import nn 19 | import torch.nn.functional as F 20 | from torch.nn.init import xavier_uniform_, constant_ 21 | 22 | from ..functions import MSDeformAttnFunction 23 | 24 | 25 | def _is_power_of_2(n): 26 | if (not isinstance(n, int)) or (n < 0): 27 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 28 | return (n & (n-1) == 0) and n != 0 29 | 30 | 31 | class MSDeformAttn(nn.Module): 32 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 33 | """ 34 | Multi-Scale Deformable Attention Module 35 | :param d_model hidden dimension 36 | :param n_levels number of feature levels 37 | :param n_heads number of attention heads 38 | :param n_points number of sampling points per attention head per feature level 39 | """ 40 | super().__init__() 41 | if d_model % n_heads != 0: 42 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 43 | _d_per_head = d_model // n_heads 44 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 45 | if not _is_power_of_2(_d_per_head): 46 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 47 | "which is more efficient in our CUDA implementation.") 48 | 49 | self.im2col_step = 64 50 | 51 | self.d_model = d_model 52 | self.n_levels = n_levels 53 | self.n_heads = n_heads 54 | self.n_points = n_points 55 | 56 | self.sampling_offsets = nn.Linear(d_model, n_heads * 
n_levels * n_points * 2) 57 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 58 | self.value_proj = nn.Linear(d_model, d_model) 59 | self.output_proj = nn.Linear(d_model, d_model) 60 | 61 | self._reset_parameters() 62 | 63 | def _reset_parameters(self): 64 | constant_(self.sampling_offsets.weight.data, 0.) 65 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 66 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 67 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 68 | for i in range(self.n_points): 69 | grid_init[:, :, i, :] *= i + 1 70 | with torch.no_grad(): 71 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 72 | constant_(self.attention_weights.weight.data, 0.) 73 | constant_(self.attention_weights.bias.data, 0.) 74 | xavier_uniform_(self.value_proj.weight.data) 75 | constant_(self.value_proj.bias.data, 0.) 76 | xavier_uniform_(self.output_proj.weight.data) 77 | constant_(self.output_proj.bias.data, 0.) 78 | 79 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 80 | """ 81 | :param query (N, Length_{query}, C) 82 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 83 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 84 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 85 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 86 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 87 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 88 | 89 | :return output (N, Length_{query}, C) 90 | """ 91 | N, Len_q, _ = query.shape 92 | N, Len_in, _ = input_flatten.shape 93 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 94 | 95 | value = self.value_proj(input_flatten) 96 | if input_padding_mask is not None: 97 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 98 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 99 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 100 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 101 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 102 | # N, Len_q, n_heads, n_levels, n_points, 2 103 | if reference_points.shape[-1] == 2: 104 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 105 | sampling_locations = reference_points[:, :, None, :, None, :] \ 106 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 107 | elif reference_points.shape[-1] == 4: 108 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 109 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 110 | else: 111 | raise ValueError( 112 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 113 | output = MSDeformAttnFunction.apply( 114 | value, input_spatial_shapes, 
input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 115 | output = self.output_proj(output) 116 | 117 | return output, sampling_locations, attention_weights 118 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various positional encodings for the transformer. 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from util.misc import NestedTensor 10 | 11 | # dimension == 1 12 | class PositionEmbeddingSine1D(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=256, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, tensor_list: NestedTensor): 29 | x = tensor_list.tensors # [B, C, T] 30 | mask = tensor_list.mask # [B, T] 31 | assert mask is not None 32 | not_mask = ~mask 33 | x_embed = not_mask.cumsum(1, dtype=torch.float32) # [B, T] 34 | if self.normalize: 35 | eps = 1e-6 36 | x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, None] / dim_t # [B, T, C] 42 | # n,c,t 43 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 44 | pos = pos_x.permute(0, 2, 1) # [B, C, T] 45 | return pos 46 | 47 | # dimension == 2 48 | class PositionEmbeddingSine2D(nn.Module): 49 | """ 50 | This is a more standard version of the position embedding, very similar to the one 51 | used by the Attention is all you need paper, generalized to work on images. 
52 | """ 53 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 54 | super().__init__() 55 | self.num_pos_feats = num_pos_feats 56 | self.temperature = temperature 57 | self.normalize = normalize 58 | if scale is not None and normalize is False: 59 | raise ValueError("normalize should be True if scale is passed") 60 | if scale is None: 61 | scale = 2 * math.pi 62 | self.scale = scale 63 | 64 | def forward(self, tensor_list: NestedTensor): 65 | x = tensor_list.tensors # [B, C, H, W] 66 | mask = tensor_list.mask # [B, H, W] 67 | assert mask is not None 68 | not_mask = ~mask 69 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 70 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 71 | if self.normalize: 72 | eps = 1e-6 73 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 74 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 75 | 76 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 77 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 78 | 79 | pos_x = x_embed[:, :, :, None] / dim_t 80 | pos_y = y_embed[:, :, :, None] / dim_t 81 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 82 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 83 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 84 | return pos # [B, C, H, W] 85 | 86 | 87 | # dimension == 3 88 | class PositionEmbeddingSine3D(nn.Module): 89 | """ 90 | This is a more standard version of the position embedding, very similar to the one 91 | used by the Attention is all you need paper, generalized to work on images. 92 | """ 93 | def __init__(self, num_pos_feats=64, num_frames=36, temperature=10000, normalize=False, scale=None): 94 | super().__init__() 95 | self.num_pos_feats = num_pos_feats 96 | self.temperature = temperature 97 | self.normalize = normalize 98 | self.frames = num_frames 99 | if scale is not None and normalize is False: 100 | raise ValueError("normalize should be True if scale is passed") 101 | if scale is None: 102 | scale = 2 * math.pi 103 | self.scale = scale 104 | 105 | def forward(self, tensor_list: NestedTensor): 106 | x = tensor_list.tensors # [B*T, C, H, W] 107 | mask = tensor_list.mask # [B*T, H, W] 108 | n,h,w = mask.shape 109 | mask = mask.reshape(n//self.frames, self.frames,h,w) # [B, T, H, W] 110 | assert mask is not None 111 | not_mask = ~mask 112 | z_embed = not_mask.cumsum(1, dtype=torch.float32) # [B, T, H, W] 113 | y_embed = not_mask.cumsum(2, dtype=torch.float32) # [B, T, H, W] 114 | x_embed = not_mask.cumsum(3, dtype=torch.float32) # [B, T, H, W] 115 | if self.normalize: 116 | eps = 1e-6 117 | z_embed = z_embed / (z_embed[:, -1:, :, :] + eps) * self.scale 118 | y_embed = y_embed / (y_embed[:, :, -1:, :] + eps) * self.scale 119 | x_embed = x_embed / (x_embed[:, :, :, -1:] + eps) * self.scale 120 | 121 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) # 122 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 123 | 124 | pos_x = x_embed[:, :, :, :, None] / dim_t # [B, T, H, W, c] 125 | pos_y = y_embed[:, :, :, :, None] / dim_t 126 | pos_z = z_embed[:, :, :, :, None] / dim_t 127 | pos_x = torch.stack((pos_x[:, :, :, :, 0::2].sin(), pos_x[:, :, :, :, 1::2].cos()), dim=5).flatten(4) # [B, T, H, W, c] 128 | pos_y = torch.stack((pos_y[:, :, :, :, 0::2].sin(), pos_y[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 129 | pos_z = 
torch.stack((pos_z[:, :, :, :, 0::2].sin(), pos_z[:, :, :, :, 1::2].cos()), dim=5).flatten(4) 130 | pos = torch.cat((pos_z, pos_y, pos_x), dim=4).permute(0, 1, 4, 2, 3) # [B, T, C, H, W] 131 | return pos 132 | 133 | 134 | 135 | def build_position_encoding(args): 136 | # build 2D position encoding 137 | N_steps = args.hidden_dim // 2 # 256 / 2 = 128 138 | if args.position_embedding in ('v2', 'sine'): 139 | # TODO find a better way of exposing other arguments 140 | position_embedding = PositionEmbeddingSine2D(N_steps, normalize=True) 141 | else: 142 | raise ValueError(f"not supported {args.position_embedding}") 143 | 144 | return position_embedding 145 | 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MUTR: A Unified Temporal Transformer for Multi-Modal Video Object Segmentation 2 | 3 | Official implementation of ['Referred by Multi-Modality: A Unified Temporal Transformer for Video Object Segmentation'](https://arxiv.org/abs/2305.16318). 4 | 5 | The paper has been accepted by **AAAI 2024** 🔥. 6 | 13 | 14 | ## Introduction 15 | We propose **MUTR**, a **M**ulti-modal **U**nified **T**emporal transformer for **R**eferring video object segmentation. With a unified framework for the first time, MUTR adopts a DETR-style transformer and is capable of segmenting video objects designated by either text or audio reference. Specifically, we introduce two strategies to fully explore the temporal relations between videos and multi-modal signals, which are low-level temporal aggregation (MTA) and high-level temporal interaction (MTI). 16 | On Ref-YouTube-VOS and AVSBench with respective text and audio references, MUTR achieves **+4.2\%** and **+4.2\%** J&F improvements to *state-of-the-art* methods, demonstrating our significance for unified multi-modal VOS. 17 | 18 |
![MUTR overview](docs/network.png)
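For readers who want a concrete picture of the interface before diving into the code, the snippet below is a purely illustrative sketch of the input/output contract described above: a video clip plus a pooled text or audio reference embedding goes in, per-frame masks come out. All module and tensor names here are hypothetical placeholders; this is not the actual MUTR implementation, which lives under `models/`.

```python
# Illustrative stand-in only: it mimics the interface (clip + multi-modal
# reference -> per-frame masks), not MUTR's real MTA/MTI modules.
import torch
import torch.nn as nn


class ToyReferringSegmenter(nn.Module):
    def __init__(self, ref_dim=256, hidden_dim=256):
        super().__init__()
        self.visual_proj = nn.Conv2d(3, hidden_dim, kernel_size=1)   # toy "backbone"
        self.ref_proj = nn.Linear(ref_dim, hidden_dim)               # toy reference encoder
        self.mask_head = nn.Conv2d(hidden_dim, 1, kernel_size=1)     # toy segmentation head

    def forward(self, frames, reference):
        # frames: [B, T, 3, H, W]; reference: [B, ref_dim] (pooled text or audio embedding)
        b, t, _, h, w = frames.shape
        feats = self.visual_proj(frames.flatten(0, 1))               # [B*T, D, H, W]
        ref = self.ref_proj(reference).repeat_interleave(t, dim=0)   # [B*T, D]
        masks = self.mask_head(feats * ref[..., None, None])         # [B*T, 1, H, W]
        return masks.view(b, t, h, w).sigmoid()                      # per-frame soft masks


if __name__ == "__main__":
    model = ToyReferringSegmenter()
    clip = torch.randn(1, 4, 3, 64, 64)      # a 4-frame clip
    ref = torch.randn(1, 256)                # hypothetical pooled reference embedding
    print(model(clip, ref).shape)            # torch.Size([1, 4, 64, 64])
```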
19 | 20 | ## Update 21 | * **TODO**: Release the code and checkpoints on AV-VOS with audio reference 📌. 22 | * We release the code and checkpoints of MUTR on RVOS with language reference 🔥. 23 | 24 | ## Requirements 25 | 26 | We have tested the code in the following environment; other versions may also be compatible: 27 | 28 | - CUDA 11.1 29 | - Python 3.7 30 | - PyTorch 1.8.1 31 | 32 | 33 | ## Installation 34 | 35 | Please refer to [install.md](docs/install.md) for installation. 36 | 37 | 38 | 39 | ## Data Preparation 40 | 41 | Please refer to [data.md](docs/data.md) for data preparation. 42 | 43 | After the data is organized, we expect the directory structure to be the following: 44 | 45 | ``` 46 | MUTR/ 47 | ├── data/ 48 | │ ├── ref-youtube-vos/ 49 | │ ├── ref-davis/ 50 | ├── davis2017/ 51 | ├── datasets/ 52 | ├── models/ 53 | ├── scripts/ 54 | ├── tools/ 55 | ├── util/ 56 | ├── train.py 57 | ├── engine.py 58 | ├── inference_ytvos.py 59 | ├── inference_davis.py 60 | ├── opts.py 61 | ... 62 | ``` 63 | 64 | ## Get Started 65 | 66 | Please see [Ref-YouTube-VOS](docs/Ref-YouTube-VOS.md) and [Ref-DAVIS 2017](docs/Ref-DAVIS2017.md) for details. 67 | 68 | 69 | ## Model Zoo and Results 70 | 71 | **Note:** 72 | 73 | `--backbone` denotes the backbone architecture to use (see [here](https://github.com/OpenGVLab/MUTR/blob/c4d8901e0fca1da667922d453a004259ffb1a5cd/opts.py#L31)). 74 | 75 | `--backbone_pretrained` denotes the path to the backbone's pretrained weights (see [here](https://github.com/OpenGVLab/MUTR/blob/c4d8901e0fca1da667922d453a004259ffb1a5cd/opts.py#L33)). 76 | 77 | 78 | 79 | 80 | ### Ref-YouTube-VOS 81 | 82 | To evaluate the results, please upload the zip file to the [competition server](https://codalab.lisn.upsaclay.fr/competitions/3282#participate-submit_results).
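As a convenience, here is a minimal, hedged sketch of how the predicted masks might be packaged for upload. It assumes the inference step has written the per-sequence PNG masks under an `output/Annotations` folder, which is only a placeholder; point it at your own output directory and double-check the layout required by the competition page before uploading.

```python
# Hedged sketch: package predicted masks into a zip for the CodaLab submission.
# "output/Annotations" is a placeholder path, not a fixed convention of this repo.
import shutil
from pathlib import Path

pred_dir = Path("output/Annotations")
assert pred_dir.is_dir(), f"expected predicted masks under {pred_dir}"

# Writes submission.zip in the current directory, with the folder's contents at the archive root.
shutil.make_archive("submission", "zip", root_dir=str(pred_dir))
print("wrote submission.zip")
```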
83 | 84 | 85 | | Backbone| J&F | J | F | Model | Submission | 86 | | :----: | :----: | :----: | :----: | :----: | :----: | 87 | | ResNet-50 | 61.9 | 60.4 | 63.4 | [model](https://drive.google.com/file/d/1W1hSYd1DDFdhl46rpE1Y1OgsG1N5Zh7B/view?usp=sharing) | [link](https://drive.google.com/file/d/1ORmyM8cNgnjnXSy6SBC27wKRsORAc8Wu/view?usp=sharing) | 88 | | ResNet-101 | 63.6 | 61.8 | 65.4 | [model](https://drive.google.com/file/d/1tIX6jmM9MjCxbMDh89e2LugY2ul12GD6/view?usp=sharing) | [link](https://drive.google.com/file/d/1JAG6u_U5c5w0K0z3D5_r3UseN2Fmk9_y/view?usp=sharing) | 89 | | Swin-L | 68.4 | 66.4 | 70.4 | [model](https://drive.google.com/file/d/1PrWZjppjxEvJe2wQ7a3augG4iRQX1pLJ/view?usp=sharing) | [link](https://drive.google.com/file/d/1EYh82Ij30IJTO4Kn1-jvbbpARJybJzdj/view?usp=sharing) | 90 | | Video-Swin-T | 64.0 | 62.2 | 65.8 | [model](https://drive.google.com/file/d/1-TkdQksTrmB253ao99NgnmsrsQkous2V/view?usp=sharing) | [link](https://drive.google.com/file/d/14bNF3WsPResaUrB0NWmJ8GQ1eaE-Fw_7/view?usp=sharing) | 91 | | Video-Swin-S | 65.1 | 63.0 | 67.1 | [model](https://drive.google.com/file/d/1Z4ENlWAKIEp44HC0OH4CjsZXgQTMTvDK/view?usp=sharing) | [link](https://drive.google.com/file/d/19kWvu1fc-5hhkI1Ibzzps3pYQA4N42JU/view?usp=sharing) | 92 | | Video-Swin-B | 67.5 | 65.4 | 69.6 | [model](https://drive.google.com/file/d/1-ezn8H2GPTc7o6cUGN1r3DI6sDLF2J5s/view?usp=sharing) | [link](https://drive.google.com/file/d/1aYFs_DDsEFHo7Dd8pOG24O2rwyjjpEMN/view?usp=sharing) | 93 | | ConvNext-L | 66.7 | 64.8 | 68.7 | [model](https://drive.google.com/file/d/1w4o392nrKEDd2JqlBcu5r1ZqoSAMtFwD/view?usp=sharing) | [link](https://drive.google.com/file/d/1jASGNhitDozzN9trIlAVWsmjio7GjsA0/view?usp=sharing) | 94 | | ConvMAE-B | 66.9 | 64.7 | 69.1 | [model](https://drive.google.com/file/d/1_hHPVici-RIcn7ocvn6RPDSMG4gJA5Pj/view?usp=sharing) | [link](https://drive.google.com/file/d/1CORTnxJo4hWRCR4eSTcgPjxTi_5ZOlPV/view?usp=sharing) | 95 | 96 | 97 | 98 | 99 | ### Ref-DAVIS17 100 | 101 | As described in the paper, we report the results using the model trained on Ref-Youtube-VOS without finetune. 102 | 103 | | Backbone| J&F | J | F | Model | 104 | | :----: | :----: | :----: | :----: | :----: | 105 | | ResNet-50 | 65.3 | 62.4 | 68.2 | [model](https://drive.google.com/file/d/1W1hSYd1DDFdhl46rpE1Y1OgsG1N5Zh7B/view?usp=sharing) | 106 | | ResNet-101 | 65.3 | 61.9 | 68.6 | [model](https://drive.google.com/file/d/1tIX6jmM9MjCxbMDh89e2LugY2ul12GD6/view?usp=sharing) | 107 | | Swin-L | 68.0 | 64.8 | 71.3 | [model](https://drive.google.com/file/d/1PrWZjppjxEvJe2wQ7a3augG4iRQX1pLJ/view?usp=sharing) | 108 | | Video-Swin-T | 66.5 | 63.0 | 70.0 | [model](https://drive.google.com/file/d/1-TkdQksTrmB253ao99NgnmsrsQkous2V/view?usp=sharing) | 109 | | Video-Swin-S | 66.1 | 62.6 | 69.8 | [model](https://drive.google.com/file/d/1Z4ENlWAKIEp44HC0OH4CjsZXgQTMTvDK/view?usp=sharing) | 110 | | Video-Swin-B | 66.4 | 62.8 | 70.0 | [model](https://drive.google.com/file/d/1-ezn8H2GPTc7o6cUGN1r3DI6sDLF2J5s/view?usp=sharing) | 111 | | ConvNext-L | 69.0 | 65.6 | 72.4 | [model](https://drive.google.com/file/d/1w4o392nrKEDd2JqlBcu5r1ZqoSAMtFwD/view?usp=sharing) | 112 | | ConvMAE-B | 69.2 | 65.6 | 72.8 | [model](https://drive.google.com/file/d/1_hHPVici-RIcn7ocvn6RPDSMG4gJA5Pj/view?usp=sharing) | 113 | 114 | 115 | ## Acknowledgement 116 | 117 | This repo is based on [ReferFormer](https://github.com/wjn922/ReferFormer/tree/main). 
We also refer to the repositories [Deformable DETR](https://github.com/ashkamath/mdetr) and [MTTR](https://github.com/fundamentalvision/Deformable-DETR). Thanks for their wonderful works. 118 | 119 | 120 | ## Citation 121 | 122 | ``` 123 | @inproceedings{yan2024referred, 124 | title={Referred by multi-modality: A unified temporal transformer for video object segmentation}, 125 | author={Yan, Shilin and Zhang, Renrui and Guo, Ziyu and Chen, Wenchao and Zhang, Wei and Li, Hongyang and Qiao, Yu and Dong, Hao and He, Zhongjiang and Gao, Peng}, 126 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 127 | volume={38}, 128 | number={6}, 129 | pages={6449--6457}, 130 | year={2024} 131 | } 132 | ``` 133 | 134 | ## Contact 135 | If you have any question about this project, please feel free to contact tattoo.ysl@gmail.com. 136 | -------------------------------------------------------------------------------- /tools/data/convert_refexp_to_coco.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from datasets.refer import REFER 4 | import cv2 5 | from tqdm import tqdm 6 | import json 7 | import pickle 8 | import json 9 | 10 | 11 | def convert_to_coco(data_root='/mnt/petrelfs/yanshilin/DATASETS/coco', output_root='/mnt/petrelfs/yanshilin/DATASETS/coco', dataset='refcoco', dataset_split='unc'): 12 | dataset_dir = os.path.join(data_root, dataset) 13 | output_dir = os.path.join(output_root, dataset) # .json save path 14 | if not os.path.exists(output_dir): 15 | os.makedirs(output_dir) 16 | 17 | # read REFER 18 | refer = REFER(data_root, dataset, dataset_split) 19 | refs = refer.Refs 20 | anns = refer.Anns 21 | imgs = refer.Imgs 22 | cats = refer.Cats 23 | sents = refer.Sents 24 | """ 25 | # create sets of mapping 26 | # 1) Refs: {ref_id: ref} 27 | # 2) Anns: {ann_id: ann} 28 | # 3) Imgs: {image_id: image} 29 | # 4) Cats: {category_id: category_name} 30 | # 5) Sents: {sent_id: sent} 31 | # 6) imgToRefs: {image_id: refs} 32 | # 7) imgToAnns: {image_id: anns} 33 | # 8) refToAnn: {ref_id: ann} 34 | # 9) annToRef: {ann_id: ref} 35 | # 10) catToRefs: {category_id: refs} 36 | # 11) sentToRef: {sent_id: ref} 37 | # 12) sentToTokens: {sent_id: tokens} 38 | 39 | Refs: List[Dict], "sent_ids", "file_name", "ann_id", "ref_id", "image_id", "category_id", "split", "sentences" 40 | "sentences": List[Dict], "tokens"(List), "raw", "sent_id", "sent" 41 | Anns: List[Dict], "segmentation", "area", "iscrowd", "image_id", "bbox", "category_id", "id" 42 | Imgs: List[Dict], "license", "file_name", "coco_url", "height", "width", "date_captured", "flickr_url", "id" 43 | Cats: List[Dict], "supercategory", "name", "id" 44 | Sents: List[Dict], "tokens"(List), "raw", "sent_id", "sent", here the "sent_id" is consistent 45 | """ 46 | print('Dataset [%s_%s] contains: ' % (dataset, dataset_split)) 47 | ref_ids = refer.getRefIds() 48 | image_ids = refer.getImgIds() 49 | print('There are %s expressions for %s refereed objects in %s images.' % (len(refer.Sents), len(ref_ids), len(image_ids))) 50 | 51 | print('\nAmong them:') 52 | if dataset == 'refcoco': 53 | splits = ['train', 'val', 'testA', 'testB'] 54 | elif dataset == 'refcoco+': 55 | splits = ['train', 'val', 'testA', 'testB'] 56 | elif dataset == 'refcocog': 57 | splits = ['train', 'val', 'test'] # we don't have test split for refcocog right now. 58 | 59 | for split in splits: 60 | ref_ids = refer.getRefIds(split=split) 61 | print(' %s referred objects are in split [%s].' 
% (len(ref_ids), split)) 62 | 63 | with open(os.path.join(dataset_dir, "instances.json"), "r") as f: 64 | ann_json = json.load(f) 65 | 66 | 67 | # 1. for each split: train, val... 68 | for split in splits: 69 | max_length = 0 # max length of a sentence 70 | 71 | coco_ann = { 72 | "info": "", 73 | "licenses": "", 74 | "images": [], # each caption is a image sample 75 | "annotations": [], 76 | "categories": [] 77 | } 78 | coco_ann['info'], coco_ann['licenses'], coco_ann['categories'] = \ 79 | ann_json['info'], ann_json['licenses'], ann_json['categories'] 80 | 81 | num_images = 0 # each caption is a sample, create a "images" and a "annotations", since each image has one box 82 | ref_ids = refer.getRefIds(split=split) 83 | # 2. for each referred object 84 | for i in tqdm(ref_ids): 85 | ref = refs[i] 86 | # "sent_ids", "file_name", "ann_id", "ref_id", "image_id", "category_id", "split", "sentences" 87 | # "sentences": List[Dict], "tokens"(List), "raw", "sent_id", "sent" 88 | img = imgs[ref["image_id"]] 89 | ann = anns[ref["ann_id"]] 90 | 91 | # 3. for each sentence, which is a sample 92 | for sentence in ref["sentences"]: 93 | num_images += 1 94 | # append image info 95 | image_info = { 96 | "file_name": img["file_name"], 97 | "height": img["height"], 98 | "width": img["width"], 99 | "original_id": img["id"], 100 | "id": num_images, 101 | "caption": sentence["sent"], 102 | "dataset_name": dataset 103 | } 104 | coco_ann["images"].append(image_info) 105 | 106 | # append annotation info 107 | ann_info = { 108 | "segmentation": ann["segmentation"], 109 | "area": ann["area"], 110 | "iscrowd": ann["iscrowd"], 111 | "bbox": ann["bbox"], 112 | "image_id": num_images, 113 | "category_id": ann["category_id"], 114 | "id": num_images, 115 | "original_id": ann["id"] 116 | } 117 | coco_ann["annotations"].append(ann_info) 118 | 119 | max_length = max(max_length, len(sentence["tokens"])) 120 | 121 | print("Total expression: {} in split {}".format(num_images, split)) 122 | print("Max sentence length of the split: ", max_length) 123 | # save the json file 124 | save_file = "instances_{}_{}.json".format(dataset, split) 125 | with open(os.path.join(output_dir, save_file), 'w') as f: 126 | json.dump(coco_ann, f) 127 | 128 | if __name__ == '__main__': 129 | datasets = ["refcoco", "refcoco+", "refcocog"] 130 | datasets_split = ["unc", "unc", "umd"] 131 | for (dataset, dataset_split) in zip(datasets, datasets_split): 132 | convert_to_coco(dataset=dataset, dataset_split=dataset_split) 133 | print("") 134 | 135 | 136 | """ 137 | # original mapping 138 | {'person': 1, 'bicycle': 2, 'car': 3, 'motorcycle': 4, 'airplane': 5, 'bus': 6, 'train': 7, 'truck': 8, 'boat': 9, 139 | 'traffic light': 10, 'fire hydrant': 11, 'stop sign': 13, 'parking meter': 14, 'bench': 15, 'bird': 16, 'cat': 17, 140 | 'dog': 18, 'horse': 19, 'sheep': 20, 'cow': 21, 'elephant': 22, 'bear': 23, 'zebra': 24, 'giraffe': 25, 'backpack': 27, 141 | 'umbrella': 28, 'handbag': 31, 'tie': 32, 'suitcase': 33, 'frisbee': 34, 'skis': 35, 'snowboard': 36, 'sports ball': 37, 142 | 'kite': 38, 'baseball bat': 39, 'baseball glove': 40, 'skateboard': 41, 'surfboard': 42, 'tennis racket': 43, 'bottle': 44, 143 | 'wine glass': 46, 'cup': 47, 'fork': 48, 'knife': 49, 'spoon': 50, 'bowl': 51, 'banana': 52, 'apple': 53, 'sandwich': 54, 144 | 'orange': 55, 'broccoli': 56, 'carrot': 57, 'hot dog': 58, 'pizza': 59, 'donut': 60, 'cake': 61, 'chair': 62, 'couch': 63, 145 | 'potted plant': 64, 'bed': 65, 'dining table': 67, 'toilet': 70, 'tv': 72, 'laptop': 73, 'mouse': 74, 
'remote': 75, 146 | 'keyboard': 76, 'cell phone': 77, 'microwave': 78, 'oven': 79, 'toaster': 80, 'sink': 81, 'refrigerator': 82, 'book': 84, 147 | 'clock': 85, 'vase': 86, 'scissors': 87, 'teddy bear': 88, 'hair drier': 89, 'toothbrush': 90} 148 | 149 | """ 150 | -------------------------------------------------------------------------------- /datasets/refexp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | """ 4 | COCO dataset which returns image_id for evaluation. 5 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 6 | """ 7 | from pathlib import Path 8 | 9 | import torch 10 | import torch.utils.data 11 | import torchvision 12 | from pycocotools import mask as coco_mask 13 | 14 | import datasets.transforms_image as T 15 | 16 | 17 | class ModulatedDetection(torchvision.datasets.CocoDetection): 18 | def __init__(self, img_folder, ann_file, transforms, return_masks): 19 | super(ModulatedDetection, self).__init__(img_folder, ann_file) 20 | self._transforms = transforms 21 | self.prepare = ConvertCocoPolysToMask(return_masks) 22 | 23 | def __getitem__(self, idx): 24 | instance_check = False 25 | while not instance_check: 26 | img, target = super(ModulatedDetection, self).__getitem__(idx) 27 | image_id = self.ids[idx] 28 | coco_img = self.coco.loadImgs(image_id)[0] 29 | caption = coco_img["caption"] 30 | dataset_name = coco_img["dataset_name"] if "dataset_name" in coco_img else None 31 | target = {"image_id": image_id, "annotations": target, "caption": caption} 32 | img, target = self.prepare(img, target) 33 | if self._transforms is not None: 34 | img, target = self._transforms(img, target) 35 | target["dataset_name"] = dataset_name 36 | for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]: 37 | if extra_key in coco_img: 38 | target[extra_key] = coco_img[extra_key] # box xyxy -> cxcywh 39 | # FIXME: handle "valid", since some box may be removed due to random crop 40 | target["valid"] = torch.tensor([1]) if len(target["area"]) != 0 else torch.tensor([0]) 41 | 42 | if torch.any(target['valid'] == 1): # at leatst one instance 43 | instance_check = True 44 | else: 45 | import random 46 | idx = random.randint(0, self.__len__() - 1) 47 | return img.unsqueeze(0), target 48 | # return img: [1, 3, H, W], the first dimension means T = 1. 
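        # After the transforms, `target` is the dict built by ConvertCocoPolysToMask below,
        # e.g. "boxes" [N, 4], "labels" [N], "masks" [N, H, W] when return_masks=True,
        # "caption" (the referring expression), "area", "iscrowd", "valid", "orig_size",
        # "size", plus "dataset_name" and the optional keys copied above.
        # Samples left with no instance after the random crop are re-drawn through the
        # `instance_check` loop instead of being returned empty.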
49 | 50 | 51 | def convert_coco_poly_to_mask(segmentations, height, width): 52 | masks = [] 53 | for polygons in segmentations: 54 | rles = coco_mask.frPyObjects(polygons, height, width) 55 | mask = coco_mask.decode(rles) 56 | if len(mask.shape) < 3: 57 | mask = mask[..., None] 58 | mask = torch.as_tensor(mask, dtype=torch.uint8) 59 | mask = mask.any(dim=2) 60 | masks.append(mask) 61 | if masks: 62 | masks = torch.stack(masks, dim=0) 63 | else: 64 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 65 | return masks 66 | 67 | 68 | class ConvertCocoPolysToMask(object): 69 | def __init__(self, return_masks=False): 70 | self.return_masks = return_masks 71 | 72 | def __call__(self, image, target): 73 | w, h = image.size 74 | 75 | image_id = target["image_id"] 76 | image_id = torch.tensor([image_id]) 77 | 78 | anno = target["annotations"] 79 | caption = target["caption"] if "caption" in target else None 80 | 81 | anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] 82 | 83 | boxes = [obj["bbox"] for obj in anno] 84 | # guard against no boxes via resizing 85 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 86 | boxes[:, 2:] += boxes[:, :2] # xminyminwh -> xyxy 87 | boxes[:, 0::2].clamp_(min=0, max=w) 88 | boxes[:, 1::2].clamp_(min=0, max=h) 89 | 90 | classes = [obj["category_id"] for obj in anno] 91 | classes = torch.tensor(classes, dtype=torch.int64) 92 | 93 | if self.return_masks: 94 | segmentations = [obj["segmentation"] for obj in anno] 95 | masks = convert_coco_poly_to_mask(segmentations, h, w) 96 | 97 | # keep the valid boxes 98 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 99 | boxes = boxes[keep] 100 | classes = classes[keep] 101 | if self.return_masks: 102 | masks = masks[keep] 103 | 104 | target = {} 105 | target["boxes"] = boxes 106 | target["labels"] = classes 107 | if caption is not None: 108 | target["caption"] = caption 109 | if self.return_masks: 110 | target["masks"] = masks 111 | target["image_id"] = image_id 112 | 113 | # for conversion to coco api 114 | area = torch.tensor([obj["area"] for obj in anno]) 115 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 116 | target["area"] = area[keep] 117 | target["iscrowd"] = iscrowd[keep] 118 | target["valid"] = torch.tensor([1]) 119 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 120 | target["size"] = torch.as_tensor([int(h), int(w)]) 121 | return image, target 122 | 123 | 124 | def make_coco_transforms(image_set, cautious): 125 | 126 | normalize = T.Compose([T.ToTensor(), T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 127 | 128 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768] 129 | final_scales = [296, 328, 360, 392, 416, 448, 480, 512] 130 | 131 | max_size = 800 132 | if image_set == "train": 133 | horizontal = [] if cautious else [T.RandomHorizontalFlip()] 134 | return T.Compose( 135 | horizontal 136 | + [ 137 | T.RandomSelect( 138 | T.RandomResize(scales, max_size=max_size), 139 | T.Compose( 140 | [ 141 | T.RandomResize([400, 500, 600]), 142 | T.RandomSizeCrop(384, 600, respect_boxes=cautious), 143 | T.RandomResize(final_scales, max_size=640), 144 | ] 145 | ), 146 | ), 147 | normalize, 148 | ] 149 | ) 150 | 151 | if image_set == "val": 152 | return T.Compose( 153 | [ 154 | T.RandomResize([360], max_size=640), 155 | normalize, 156 | ] 157 | ) 158 | 159 | raise ValueError(f"unknown {image_set}") 160 | 161 | 162 | def build(dataset_file, image_set, args): 163 | root = Path(args.coco_path) 
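    # args.coco_path is expected to contain the train2014/ image folder together with
    # {dataset}/instances_{dataset}_{split}.json annotation files, i.e. the output of
    # tools/data/convert_refexp_to_coco.py (see PATHS below).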
164 | assert root.exists(), f"provided COCO path {root} does not exist" 165 | mode = "instances" 166 | dataset = dataset_file 167 | PATHS = { 168 | "train": (root / "train2014", root / dataset / f"{mode}_{dataset}_train.json"), 169 | "val": (root / "train2014", root / dataset / f"{mode}_{dataset}_val.json"), 170 | } 171 | 172 | img_folder, ann_file = PATHS[image_set] 173 | dataset = ModulatedDetection( 174 | img_folder, 175 | ann_file, 176 | transforms=make_coco_transforms(image_set, False), 177 | return_masks=args.masks, 178 | ) 179 | return dataset -------------------------------------------------------------------------------- /davis2017/metrics.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import cv2 4 | 5 | 6 | def db_eval_iou(annotation, segmentation, void_pixels=None): 7 | """ Compute region similarity as the Jaccard Index. 8 | Arguments: 9 | annotation (ndarray): binary annotation map. 10 | segmentation (ndarray): binary segmentation map. 11 | void_pixels (ndarray): optional mask with void pixels 12 | 13 | Return: 14 | jaccard (float): region similarity 15 | """ 16 | assert annotation.shape == segmentation.shape, \ 17 | f'Annotation({annotation.shape}) and segmentation:{segmentation.shape} dimensions do not match.' 18 | annotation = annotation.astype(np.bool) 19 | segmentation = segmentation.astype(np.bool) 20 | 21 | if void_pixels is not None: 22 | assert annotation.shape == void_pixels.shape, \ 23 | f'Annotation({annotation.shape}) and void pixels:{void_pixels.shape} dimensions do not match.' 24 | void_pixels = void_pixels.astype(np.bool) 25 | else: 26 | void_pixels = np.zeros_like(segmentation) 27 | 28 | # Intersection between all sets 29 | inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 30 | union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 31 | 32 | j = inters / union 33 | if j.ndim == 0: 34 | j = 1 if np.isclose(union, 0) else j 35 | else: 36 | j[np.isclose(union, 0)] = 1 37 | return j 38 | 39 | 40 | def db_eval_boundary(annotation, segmentation, void_pixels=None, bound_th=0.008): 41 | assert annotation.shape == segmentation.shape 42 | if void_pixels is not None: 43 | assert annotation.shape == void_pixels.shape 44 | if annotation.ndim == 3: 45 | n_frames = annotation.shape[0] 46 | f_res = np.zeros(n_frames) 47 | for frame_id in range(n_frames): 48 | void_pixels_frame = None if void_pixels is None else void_pixels[frame_id, :, :, ] 49 | f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], void_pixels_frame, bound_th=bound_th) 50 | elif annotation.ndim == 2: 51 | f_res = f_measure(segmentation, annotation, void_pixels, bound_th=bound_th) 52 | else: 53 | raise ValueError(f'db_eval_boundary does not support tensors with {annotation.ndim} dimensions') 54 | return f_res 55 | 56 | 57 | def f_measure(foreground_mask, gt_mask, void_pixels=None, bound_th=0.008): 58 | """ 59 | Compute mean,recall and decay from per-frame evaluation. 60 | Calculates precision/recall for boundaries between foreground_mask and 61 | gt_mask using morphological operators to speed it up. 62 | 63 | Arguments: 64 | foreground_mask (ndarray): binary segmentation image. 65 | gt_mask (ndarray): binary annotated image. 
66 | void_pixels (ndarray): optional mask with void pixels 67 | 68 | Returns: 69 | F (float): boundaries F-measure 70 | """ 71 | assert np.atleast_3d(foreground_mask).shape[2] == 1 72 | if void_pixels is not None: 73 | void_pixels = void_pixels.astype(np.bool) 74 | else: 75 | void_pixels = np.zeros_like(foreground_mask).astype(np.bool) 76 | 77 | bound_pix = bound_th if bound_th >= 1 else \ 78 | np.ceil(bound_th * np.linalg.norm(foreground_mask.shape)) 79 | 80 | # Get the pixel boundaries of both masks 81 | fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels)) 82 | gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels)) 83 | 84 | from skimage.morphology import disk 85 | 86 | # fg_dil = binary_dilation(fg_boundary, disk(bound_pix)) 87 | fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 88 | # gt_dil = binary_dilation(gt_boundary, disk(bound_pix)) 89 | gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 90 | 91 | # Get the intersection 92 | gt_match = gt_boundary * fg_dil 93 | fg_match = fg_boundary * gt_dil 94 | 95 | # Area of the intersection 96 | n_fg = np.sum(fg_boundary) 97 | n_gt = np.sum(gt_boundary) 98 | 99 | # % Compute precision and recall 100 | if n_fg == 0 and n_gt > 0: 101 | precision = 1 102 | recall = 0 103 | elif n_fg > 0 and n_gt == 0: 104 | precision = 0 105 | recall = 1 106 | elif n_fg == 0 and n_gt == 0: 107 | precision = 1 108 | recall = 1 109 | else: 110 | precision = np.sum(fg_match) / float(n_fg) 111 | recall = np.sum(gt_match) / float(n_gt) 112 | 113 | # Compute F measure 114 | if precision + recall == 0: 115 | F = 0 116 | else: 117 | F = 2 * precision * recall / (precision + recall) 118 | 119 | return F 120 | 121 | 122 | def _seg2bmap(seg, width=None, height=None): 123 | """ 124 | From a segmentation, compute a binary boundary map with 1 pixel wide 125 | boundaries. The boundary pixels are offset by 1/2 pixel towards the 126 | origin from the actual segment boundary. 127 | Arguments: 128 | seg : Segments labeled from 1..k. 129 | width : Width of desired bmap <= seg.shape[1] 130 | height : Height of desired bmap <= seg.shape[0] 131 | Returns: 132 | bmap (ndarray): Binary boundary map. 133 | David Martin 134 | January 2003 135 | """ 136 | 137 | seg = seg.astype(np.bool) 138 | seg[seg > 0] = 1 139 | 140 | assert np.atleast_3d(seg).shape[2] == 1 141 | 142 | width = seg.shape[1] if width is None else width 143 | height = seg.shape[0] if height is None else height 144 | 145 | h, w = seg.shape[:2] 146 | 147 | ar1 = float(width) / float(height) 148 | ar2 = float(w) / float(h) 149 | 150 | assert not ( 151 | width > w | height > h | abs(ar1 - ar2) > 0.01 152 | ), "Can" "t convert %dx%d seg to %dx%d bmap." 
% (w, h, width, height) 153 | 154 | e = np.zeros_like(seg) 155 | s = np.zeros_like(seg) 156 | se = np.zeros_like(seg) 157 | 158 | e[:, :-1] = seg[:, 1:] 159 | s[:-1, :] = seg[1:, :] 160 | se[:-1, :-1] = seg[1:, 1:] 161 | 162 | b = seg ^ e | seg ^ s | seg ^ se 163 | b[-1, :] = seg[-1, :] ^ e[-1, :] 164 | b[:, -1] = seg[:, -1] ^ s[:, -1] 165 | b[-1, -1] = 0 166 | 167 | if w == width and h == height: 168 | bmap = b 169 | else: 170 | bmap = np.zeros((height, width)) 171 | for x in range(w): 172 | for y in range(h): 173 | if b[y, x]: 174 | j = 1 + math.floor((y - 1) + height / h) 175 | i = 1 + math.floor((x - 1) + width / h) 176 | bmap[j, i] = 1 177 | 178 | return bmap 179 | 180 | 181 | if __name__ == '__main__': 182 | from davis2017.davis import DAVIS 183 | from davis2017.results import Results 184 | 185 | dataset = DAVIS(root='input_dir/ref', subset='val', sequences='aerobatics') 186 | results = Results(root_dir='examples/osvos') 187 | # Test timing F measure 188 | for seq in dataset.get_sequences(): 189 | all_gt_masks, _, all_masks_id = dataset.get_all_masks(seq, True) 190 | all_gt_masks, all_masks_id = all_gt_masks[:, 1:-1, :, :], all_masks_id[1:-1] 191 | all_res_masks = results.read_masks(seq, all_masks_id) 192 | f_metrics_res = np.zeros(all_gt_masks.shape[:2]) 193 | for ii in range(all_gt_masks.shape[0]): 194 | f_metrics_res[ii, :] = db_eval_boundary(all_gt_masks[ii, ...], all_res_masks[ii, ...]) 195 | 196 | # Run using to profile code: python -m cProfile -o f_measure.prof metrics.py 197 | # snakeviz f_measure.prof 198 | -------------------------------------------------------------------------------- /models/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details]
6 | **************************************************************************************************
7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0
8 | **************************************************************************************************
9 | */
10 | 
11 | #include <vector>
12 | #include "cuda/ms_deform_im2col_cuda.cuh"
13 | 
14 | #include <ATen/ATen.h>
15 | #include <ATen/cuda/CUDAContext.h>
16 | #include <cuda.h>
17 | #include <cuda_runtime.h>
18 | 
19 | 
20 | at::Tensor ms_deform_attn_cuda_forward(
21 |     const at::Tensor &value, 
22 |     const at::Tensor &spatial_shapes,
23 |     const at::Tensor &level_start_index,
24 |     const at::Tensor &sampling_loc,
25 |     const at::Tensor &attn_weight,
26 |     const int im2col_step)
27 | {
28 |     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
29 |     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
30 |     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
31 |     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
32 |     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
33 | 
34 |     AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
35 |     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
36 |     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
37 |     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
38 |     AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
39 | 
40 |     const int batch = value.size(0);
41 |     const int spatial_size = value.size(1);
42 |     const int num_heads = value.size(2);
43 |     const int channels = value.size(3);
44 | 
45 |     const int num_levels = spatial_shapes.size(0);
46 | 
47 |     const int num_query = sampling_loc.size(1);
48 |     const int num_point = sampling_loc.size(4);
49 | 
50 |     const int im2col_step_ = std::min(batch, im2col_step);
51 | 
52 |     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
53 | 
54 |     auto output = at::zeros({batch, num_query, num_heads, channels}, value.options());
55 | 
56 |     const int batch_n = im2col_step_;
57 |     auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
58 |     auto per_value_size = spatial_size * num_heads * channels;
59 |     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
60 |     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
61 |     for (int n = 0; n < batch/im2col_step_; ++n)
62 |     {
63 |         auto columns = output_n.select(0, n);
64 |         AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] {
65 |             ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(),
66 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
67 |                 spatial_shapes.data<int64_t>(),
68 |                 level_start_index.data<int64_t>(),
69 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
70 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
71 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
72 |                 columns.data<scalar_t>());
73 | 
74 |         }));
75 |     }
76 | 
77 |     output = output.view({batch, num_query, num_heads*channels});
78 | 
79 |     return output;
80 | }
81 | 
82 | 
83 | std::vector<at::Tensor> ms_deform_attn_cuda_backward(
84 |     const at::Tensor &value, 
85 |     const at::Tensor &spatial_shapes,
86 |     const at::Tensor &level_start_index,
87 |     const at::Tensor &sampling_loc,
88 |     const at::Tensor &attn_weight,
89 |     const at::Tensor &grad_output,
90 |     const int im2col_step)
91 | {
92 | 
93 |     AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous");
94 |     AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous");
95 |     AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous");
96 |     AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous");
97 |     AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous");
98 |     AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous");
99 | 
100 |     AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor");
101 |     AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor");
102 |     AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor");
103 |     AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor");
104 |     AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor");
105 |     AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor");
106 | 
107 |     const int batch = value.size(0);
108 |     const int spatial_size = value.size(1);
109 |     const int num_heads = value.size(2);
110 |     const int channels = value.size(3);
111 | 
112 |     const int num_levels = spatial_shapes.size(0);
113 | 
114 |     const int num_query = sampling_loc.size(1);
115 |     const int num_point = sampling_loc.size(4);
116 | 
117 |     const int im2col_step_ = std::min(batch, im2col_step);
118 | 
119 |     AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_);
120 | 
121 |     auto grad_value = at::zeros_like(value);
122 |     auto grad_sampling_loc = at::zeros_like(sampling_loc);
123 |     auto grad_attn_weight = at::zeros_like(attn_weight);
124 | 
125 |     const int batch_n = im2col_step_;
126 |     auto per_value_size = spatial_size * num_heads * channels;
127 |     auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2;
128 |     auto per_attn_weight_size = num_query * num_heads * num_levels * num_point;
129 |     auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels});
130 | 
131 |     for (int n = 0; n < batch/im2col_step_; ++n)
132 |     {
133 |         auto grad_output_g = grad_output_n.select(0, n);
134 |         AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] {
135 |             ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(),
136 |                 grad_output_g.data<scalar_t>(),
137 |                 value.data<scalar_t>() + n * im2col_step_ * per_value_size,
138 |                 spatial_shapes.data<int64_t>(),
139 |                 level_start_index.data<int64_t>(),
140 |                 sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
141 |                 attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size,
142 |                 batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point,
143 |                 grad_value.data<scalar_t>() + n * im2col_step_ * per_value_size,
144 |                 grad_sampling_loc.data<scalar_t>() + n * im2col_step_ * per_sample_loc_size,
145 |                 grad_attn_weight.data<scalar_t>() + n * im2col_step_ * per_attn_weight_size);
146 | 
147 |         }));
148 |     }
149 | 
150 |     return {
151 |         grad_value, grad_sampling_loc, grad_attn_weight
152 |     };
153 | }
--------------------------------------------------------------------------------
/models/postprocessors.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0.
All Rights Reserved 2 | """Postprocessors class to transform MDETR output according to the downstream task""" 3 | from typing import Dict 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn 9 | import pycocotools.mask as mask_util 10 | 11 | from util import box_ops 12 | 13 | 14 | class A2DSentencesPostProcess(nn.Module): 15 | """ 16 | This module converts the model's output into the format expected by the coco api for the given task 17 | """ 18 | def __init__(self, threshold=0.5): 19 | super().__init__() 20 | self.threshold = threshold 21 | 22 | @torch.no_grad() 23 | def forward(self, outputs, orig_target_sizes, max_target_sizes): 24 | """ Perform the computation 25 | Parameters: 26 | outputs: raw outputs of the model 27 | orig_target_sizes: original size of the samples (no augmentations or padding) 28 | max_target_sizes: size of samples (input to model) after size augmentation. 29 | NOTE: the max_padding_size is 4x out_masks.shape[-2:] 30 | """ 31 | assert len(orig_target_sizes) == len(max_target_sizes) 32 | 33 | # there is only one valid frames, thus T=1 34 | out_logits = outputs['pred_logits'][:, 0, :, 0] # [B, T, N, 1] -> [B, N] 35 | out_masks = outputs['pred_masks'][:, 0, :, :, :] # [B, T, N, out_h, out_w] -> [B, N, out_h, out_w] 36 | out_h, out_w = out_masks.shape[-2:] 37 | 38 | scores = out_logits.sigmoid() 39 | pred_masks = F.interpolate(out_masks, size=(out_h*4, out_w*4), mode="bilinear", align_corners=False) # [B, N, H, W] 40 | pred_masks = (pred_masks.sigmoid() > 0.5) # [B, N, H, W] 41 | processed_pred_masks, rle_masks = [], [] 42 | # for each batch 43 | for f_pred_masks, resized_size, orig_size in zip(pred_masks, max_target_sizes, orig_target_sizes): 44 | f_mask_h, f_mask_w = resized_size # resized shape without padding 45 | f_pred_masks_no_pad = f_pred_masks[:, :f_mask_h, :f_mask_w].unsqueeze(1) # remove the samples' padding, [:, 1, h, w] 46 | # resize the samples back to their original dataset (target) size for evaluation 47 | f_pred_masks_processed = F.interpolate(f_pred_masks_no_pad.float(), size=tuple(orig_size.tolist()), mode="nearest") # origin size, [:, 1, h, w] 48 | f_pred_rle_masks = [mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 49 | for mask in f_pred_masks_processed.cpu()] 50 | processed_pred_masks.append(f_pred_masks_processed) 51 | rle_masks.append(f_pred_rle_masks) 52 | predictions = [{'scores': s, 'masks': m, 'rle_masks': rle} 53 | for s, m, rle in zip(scores, processed_pred_masks, rle_masks)] 54 | return predictions 55 | 56 | 57 | # PostProcess for pretraining 58 | class PostProcess(nn.Module): 59 | """ This module converts the model's output into the format expected by the coco api""" 60 | 61 | @torch.no_grad() 62 | def forward(self, outputs, target_sizes): 63 | """Perform the computation 64 | Parameters: 65 | outputs: raw outputs of the model 66 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 67 | For evaluation, this must be the original image size (before any data augmentation) 68 | For visualization, this should be the image size after data augment, but before padding 69 | Returns: 70 | 71 | """ 72 | out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] 73 | 74 | assert len(out_logits) == len(target_sizes) 75 | assert target_sizes.shape[1] == 2 76 | 77 | # coco, num_frames=1 78 | out_logits = outputs["pred_logits"].flatten(1, 2) 79 | out_boxes = outputs["pred_boxes"].flatten(1, 2) 80 | bs, 
num_queries = out_logits.shape[:2] 81 | 82 | prob = out_logits.sigmoid() # [bs, num_queries, num_classes] 83 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True) 84 | scores = topk_values # [bs, num_queries] 85 | topk_boxes = topk_indexes // out_logits.shape[2] # [bs, num_queries] 86 | labels = topk_indexes % out_logits.shape[2] # [bs, num_queries] 87 | 88 | boxes = box_ops.box_cxcywh_to_xyxy(out_boxes) # [bs, num_queries, 4] 89 | boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1,1,4)) 90 | 91 | # and from relative [0, 1] to absolute [0, height] coordinates 92 | img_h, img_w = target_sizes.unbind(1) 93 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 94 | boxes = boxes * scale_fct[:, None, :] # [bs, num_queries, 4] 95 | 96 | assert len(scores) == len(labels) == len(boxes) 97 | # binary for the pretraining 98 | results = [{"scores": s, "labels": torch.ones_like(l), "boxes": b} for s, l, b in zip(scores, labels, boxes)] 99 | 100 | return results 101 | 102 | 103 | class PostProcessSegm(nn.Module): 104 | """Similar to PostProcess but for segmentation masks. 105 | This processor is to be called sequentially after PostProcess. 106 | Args: 107 | threshold: threshold that will be applied to binarize the segmentation masks. 108 | """ 109 | 110 | def __init__(self, threshold=0.5): 111 | super().__init__() 112 | self.threshold = threshold 113 | 114 | @torch.no_grad() 115 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes): 116 | """Perform the computation 117 | Parameters: 118 | results: already pre-processed boxes (output of PostProcess) NOTE here 119 | outputs: raw outputs of the model 120 | orig_target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 121 | For evaluation, this must be the original image size (before any data augmentation) 122 | For visualization, this should be the image size after data augment, but before padding 123 | max_target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 124 | after data augmentation. 
125 | """ 126 | assert len(orig_target_sizes) == len(max_target_sizes) 127 | 128 | out_logits = outputs["pred_logits"].flatten(1, 2) 129 | out_masks = outputs["pred_masks"].flatten(1, 2) 130 | bs, num_queries = out_logits.shape[:2] 131 | 132 | prob = out_logits.sigmoid() 133 | topk_values, topk_indexes = torch.topk(prob.view(out_logits.shape[0], -1), k=num_queries, dim=1, sorted=True) 134 | scores = topk_values # [bs, num_queries] 135 | topk_boxes = topk_indexes // out_logits.shape[2] # [bs, num_queries] 136 | labels = topk_indexes % out_logits.shape[2] # [bs, num_queries] 137 | 138 | outputs_masks = [out_m[topk_boxes[i]].unsqueeze(0) for i, out_m, in enumerate(out_masks)] # list[Tensor] 139 | outputs_masks = torch.cat(outputs_masks, dim=0) # [bs, num_queries, H, W] 140 | out_h, out_w = outputs_masks.shape[-2:] 141 | 142 | # max_h, max_w = max_target_sizes.max(0)[0].tolist() 143 | # outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) 144 | outputs_masks = F.interpolate(outputs_masks, size=(out_h*4, out_w*4), mode="bilinear", align_corners=False) 145 | outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() 146 | 147 | for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): 148 | img_h, img_w = t[0], t[1] 149 | results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) # [:, 1, h, w] 150 | results[i]["masks"] = F.interpolate( 151 | results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" 152 | ).byte() 153 | results[i]["rle_masks"] = [mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 154 | for mask in results[i]["masks"].cpu()] 155 | 156 | return results 157 | 158 | 159 | 160 | def build_postprocessors(args, dataset_name): 161 | if dataset_name == 'a2d' or dataset_name == 'jhmdb': 162 | postprocessors = A2DSentencesPostProcess(threshold=args.threshold) 163 | else: 164 | # for coco pretrain postprocessor 165 | postprocessors: Dict[str, nn.Module] = {"bbox": PostProcess()} 166 | if args.masks: 167 | postprocessors["segm"] = PostProcessSegm(threshold=args.threshold) 168 | return postprocessors 169 | -------------------------------------------------------------------------------- /models/convnext.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # H-DETR 3 | # Copyright (c) 2022 Peking University & Microsoft Research Asia. All Rights Reserved. 4 | # Licensed under the MIT-style license found in the LICENSE file in the root directory 5 | # ------------------------------------------------------------------------ 6 | # Deformable DETR 7 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 8 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 9 | # ------------------------------------------------------------------------ 10 | # Modified from DETR (https://github.com/facebookresearch/detr) 11 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 12 | # ------------------------------------------------------------------------ 13 | 14 | """ 15 | Backbone modules. 
16 | """ 17 | from collections import OrderedDict 18 | 19 | import torch 20 | from torch import nn 21 | from typing import Dict, List 22 | import torch.nn.functional as F 23 | 24 | from util.misc import NestedTensor 25 | from torch import Tensor 26 | import math 27 | 28 | from .position_encoding import build_position_encoding 29 | 30 | from timm.models import convnext 31 | from einops import rearrange 32 | 33 | def checkpoint_filter_fn(state_dict, model): 34 | """ Remap FB checkpoints -> timm """ 35 | if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: 36 | return state_dict # non-FB checkpoint 37 | if 'model' in state_dict: 38 | state_dict = state_dict['model'] 39 | 40 | out_dict = {} 41 | if 'visual.trunk.stem.0.weight' in state_dict: 42 | out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')} 43 | if 'visual.head.proj.weight' in state_dict: 44 | out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight'] 45 | out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0]) 46 | elif 'visual.head.mlp.fc1.weight' in state_dict: 47 | out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight'] 48 | out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias'] 49 | out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight'] 50 | out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0]) 51 | return out_dict 52 | 53 | import re 54 | for k, v in state_dict.items(): 55 | k = k.replace('downsample_layers.0.', 'stem.') 56 | k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) 57 | k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) 58 | k = k.replace('dwconv', 'conv_dw') 59 | k = k.replace('pwconv', 'mlp.fc') 60 | if 'grn' in k: 61 | k = k.replace('grn.beta', 'mlp.grn.bias') 62 | k = k.replace('grn.gamma', 'mlp.grn.weight') 63 | v = v.reshape(v.shape[-1]) 64 | k = k.replace('head.', 'head.fc.') 65 | if k.startswith('norm.'): 66 | k = k.replace('norm', 'head.norm') 67 | if v.ndim == 2 and 'head' not in k: 68 | model_shape = model.state_dict()[k].shape 69 | v = v.reshape(model_shape) 70 | out_dict[k] = v 71 | 72 | return out_dict 73 | 74 | 75 | class ConvNeXt(convnext.ConvNeXt): 76 | def __init__(self, out_indices, **kwargs): 77 | super(ConvNeXt, self).__init__(**kwargs) 78 | self.out_indices = out_indices 79 | del self.norm_pre 80 | del self.head 81 | 82 | def forward_features(self, x): 83 | x = self.stem(x) 84 | outputs = {} 85 | for stage_idx in range(len(self.stages)): 86 | x = self.stages[stage_idx](x) 87 | if stage_idx in self.out_indices: 88 | outputs[stage_idx] = x 89 | return outputs 90 | 91 | 92 | def convnext_large(pretrained=False, **kwargs): 93 | model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) 94 | if pretrained: 95 | ckpt = torch.load(pretrained, map_location='cpu') 96 | ckpt = checkpoint_filter_fn(ckpt, model) 97 | load_logs = model.load_state_dict(ckpt, strict=False) 98 | print(load_logs) 99 | return model 100 | 101 | 102 | def convnext_xxlarge(pretrained=False, **kwargs): 103 | model = ConvNeXt(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], norm_eps=kwargs.pop('norm_eps', 1e-5), 104 | **kwargs) 105 | if pretrained: 106 | ckpt = torch.load(pretrained, map_location='cpu') 107 | ckpt = checkpoint_filter_fn(ckpt, model) 108 | load_logs = model.load_state_dict(ckpt, strict=False) 109 | print(load_logs) 110 | return model 111 | 112 | 
113 | class ConvnextBackbone(nn.Module): 114 | def __init__( 115 | self, backbone: str, train_backbone: bool, return_interm_layers: bool, args 116 | ): 117 | super().__init__() 118 | if args.num_feature_levels == 4: 119 | out_indices = (0, 1, 2, 3) 120 | else: 121 | out_indices = (1, 2, 3) 122 | 123 | if backbone == 'convnext_large': 124 | backbone = convnext_large(args.backbone_pretrained, out_indices=out_indices, 125 | drop_path_rate=0.5) 126 | embed_dim = 192 127 | elif backbone == 'convnext_xxlarge': 128 | backbone = convnext_xxlarge(args.backbone_pretrained, out_indices=out_indices, 129 | drop_path_rate=0.5) 130 | embed_dim = 384 131 | else: 132 | raise NotImplementedError 133 | 134 | self.train_backbone = train_backbone 135 | for name, parameter in backbone.named_parameters(): 136 | if not train_backbone: 137 | parameter.requires_grad_(False) 138 | 139 | if return_interm_layers: 140 | 141 | if args.num_feature_levels == 4: 142 | self.strides = [4, 8, 16, 32] 143 | self.num_channels = [ 144 | embed_dim, 145 | embed_dim * 2, 146 | embed_dim * 4, 147 | embed_dim * 8, 148 | ] 149 | else: 150 | self.strides = [8, 16, 32] 151 | self.num_channels = [ 152 | embed_dim * 2, 153 | embed_dim * 4, 154 | embed_dim * 8, 155 | ] 156 | else: 157 | self.strides = [32] 158 | self.num_channels = [embed_dim * 8] 159 | 160 | self.norm_layers = nn.ModuleList([nn.LayerNorm(ndim) for ndim in self.num_channels]) 161 | 162 | self.body = backbone 163 | 164 | def forward(self, tensor_list: NestedTensor): 165 | 166 | if self.train_backbone: 167 | xs = self.body.forward_features(tensor_list.tensors) 168 | else: 169 | with torch.no_grad(): 170 | xs = self.body.forward_features(tensor_list.tensors) 171 | 172 | out: Dict[str, NestedTensor] = {} 173 | for layer_idx, (name, x) in enumerate(xs.items()): 174 | m = tensor_list.mask 175 | assert m is not None 176 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 177 | b, c, h, w = x.shape 178 | x = self.norm_layers[layer_idx](x.view(b, c, -1).transpose(1, 2)) 179 | x = x.transpose(1, 2).view(b, c, h, w) 180 | out[name] = NestedTensor(x, mask) 181 | return out 182 | 183 | 184 | 185 | class Joiner(nn.Sequential): 186 | def __init__(self, backbone, position_embedding): 187 | super().__init__(backbone, position_embedding) 188 | self.strides = backbone.strides 189 | self.num_channels = backbone.num_channels 190 | 191 | def forward(self, tensor_list: NestedTensor, no_norm=None): 192 | 193 | tensor_list.tensors = rearrange(tensor_list.tensors, 'b t c h w -> (b t) c h w') 194 | tensor_list.mask = rearrange(tensor_list.mask, 'b t h w -> (b t) h w') 195 | 196 | if no_norm is None: 197 | xs = self[0](tensor_list) 198 | else: 199 | xs = self[0](tensor_list, no_norm) 200 | out: List[NestedTensor] = [] 201 | pos = [] 202 | for name, x in sorted(xs.items()): 203 | out.append(x) 204 | 205 | # position encoding 206 | for x in out: 207 | pos.append(self[1](x).to(x.tensors.dtype)) 208 | 209 | return out, pos 210 | 211 | 212 | def build_convnext_backbone(args): 213 | position_embedding = build_position_encoding(args) 214 | train_backbone = args.lr_backbone > 0 215 | 216 | return_interm_layers = args.masks or (args.num_feature_levels > 1) 217 | backbone = ConvnextBackbone(args.backbone, train_backbone, return_interm_layers, args) 218 | model = Joiner(backbone, position_embedding) 219 | return model 220 | -------------------------------------------------------------------------------- /opts.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def get_args_parser(): 4 | parser = argparse.ArgumentParser('ReferFormer training and inference scripts.', add_help=False) 5 | parser.add_argument('--lr', default=1e-4, type=float) 6 | parser.add_argument('--lr_backbone', default=6e-6, type=float) 7 | parser.add_argument('--lr_backbone_names', default=['backbone.0'], type=str, nargs='+') 8 | parser.add_argument('--lr_text_encoder', default=1e-5, type=float) 9 | parser.add_argument('--lr_text_encoder_names', default=['text_encoder'], type=str, nargs='+') 10 | parser.add_argument('--lr_linear_proj_names', default=['reference_points', 'sampling_offsets'], type=str, nargs='+') 11 | parser.add_argument('--lr_linear_proj_mult', default=1.0, type=float) 12 | parser.add_argument('--batch_size', default=1, type=int) 13 | parser.add_argument('--weight_decay', default=5e-4, type=float) 14 | parser.add_argument('--epochs', default=10, type=int) 15 | parser.add_argument('--lr_drop', default=[6, 8], type=int, nargs='+') 16 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 17 | help='gradient clipping max norm') 18 | 19 | # Model parameters 20 | # load the pretrained weights 21 | parser.add_argument('--pretrained_weights', type=str, default=None, 22 | help="Path to the pretrained model.") 23 | 24 | # Variants of Deformable DETR 25 | parser.add_argument('--with_box_refine', default=False, action='store_true') 26 | parser.add_argument('--two_stage', default=False, action='store_true') # NOTE: must be false 27 | 28 | # * Backbone 29 | # ["resnet50", "resnet101", "swin_t_p4w7", "swin_s_p4w7", "swin_b_p4w7", "swin_l_p4w7"] 30 | # ["video_swin_t_p4w7", "video_swin_s_p4w7", "video_swin_b_p4w7"] 31 | parser.add_argument('--backbone', default='resnet50', type=str, 32 | help="Name of the convolutional backbone to use") 33 | parser.add_argument('--backbone_pretrained', default=None, type=str, 34 | help="if use swin backbone and train from scratch, the path to the pretrained weights") 35 | parser.add_argument('--use_checkpoint', action='store_true', help='whether use checkpoint for swin/video swin backbone') 36 | parser.add_argument('--dilation', action='store_true', # DC5 37 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 38 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 39 | help="Type of positional embedding to use on top of the image features") 40 | parser.add_argument('--num_feature_levels', default=4, type=int, help='number of feature levels') 41 | 42 | # * Transformer 43 | parser.add_argument('--enc_layers', default=4, type=int, 44 | help="Number of encoding layers in the transformer") 45 | parser.add_argument('--dec_layers', default=4, type=int, 46 | help="Number of decoding layers in the transformer") 47 | parser.add_argument('--dim_feedforward', default=2048, type=int, 48 | help="Intermediate size of the feedforward layers in the transformer blocks") 49 | parser.add_argument('--hidden_dim', default=256, type=int, 50 | help="Size of the embeddings (dimension of the transformer)") 51 | parser.add_argument('--dropout', default=0.1, type=float, 52 | help="Dropout applied in the transformer") 53 | parser.add_argument('--nheads', default=8, type=int, 54 | help="Number of attention heads inside the transformer's attentions") 55 | parser.add_argument('--num_frames', default=5, type=int, 56 | help="Number of clip frames for training") 57 | 
parser.add_argument('--num_queries', default=5, type=int, 58 | help="Number of query slots, all frames share the same queries") 59 | parser.add_argument('--dec_n_points', default=4, type=int) 60 | parser.add_argument('--enc_n_points', default=4, type=int) 61 | parser.add_argument('--pre_norm', action='store_true') 62 | # for text 63 | parser.add_argument('--freeze_text_encoder', action='store_true') # default: False 64 | 65 | # * Segmentation 66 | parser.add_argument('--masks', action='store_true', 67 | help="Train segmentation head if the flag is provided") 68 | parser.add_argument('--mask_dim', default=256, type=int, 69 | help="Size of the mask embeddings (dimension of the dynamic mask conv)") 70 | parser.add_argument('--controller_layers', default=3, type=int, 71 | help="Dynamic conv layer number") 72 | parser.add_argument('--dynamic_mask_channels', default=8, type=int, 73 | help="Dynamic conv final channel number") 74 | parser.add_argument('--no_rel_coord', dest='rel_coord', action='store_false', 75 | help="Disables relative coordinates") 76 | 77 | # Loss 78 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 79 | help="Disables auxiliary decoding losses (loss at each layer)") 80 | # * Matcher 81 | parser.add_argument('--set_cost_class', default=2, type=float, 82 | help="Class coefficient in the matching cost") 83 | parser.add_argument('--set_cost_bbox', default=5, type=float, 84 | help="L1 box coefficient in the matching cost") 85 | parser.add_argument('--set_cost_giou', default=2, type=float, 86 | help="giou box coefficient in the matching cost") 87 | parser.add_argument('--set_cost_mask', default=2, type=float, 88 | help="mask coefficient in the matching cost") 89 | parser.add_argument('--set_cost_dice', default=5, type=float, 90 | help="mask coefficient in the matching cost") 91 | # * Loss coefficients 92 | parser.add_argument('--mask_loss_coef', default=2, type=float) 93 | parser.add_argument('--dice_loss_coef', default=5, type=float) 94 | parser.add_argument('--cls_loss_coef', default=2, type=float) 95 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 96 | parser.add_argument('--giou_loss_coef', default=2, type=float) 97 | parser.add_argument('--eos_coef', default=0.1, type=float, 98 | help="Relative classification weight of the no-object class") 99 | parser.add_argument('--focal_alpha', default=0.25, type=float) 100 | 101 | # dataset parameters 102 | # ['ytvos', 'davis', 'refcoco', 'refcoco+', 'refcocog', 'all'] 103 | # 'all': using the three ref datasets for pretraining 104 | parser.add_argument('--dataset_file', default='ytvos', help='Dataset name') 105 | parser.add_argument('--coco_path', type=str, default='data/coco') 106 | parser.add_argument('--ytvos_path', type=str, default='data/ref-youtube-vos') 107 | parser.add_argument('--davis_path', type=str, default='data/ref-davis') 108 | parser.add_argument('--max_skip', default=3, type=int, help="max skip frame number") 109 | parser.add_argument('--max_size', default=640, type=int, help="max size for the frame") 110 | parser.add_argument('--binary', action='store_true') 111 | parser.add_argument('--remove_difficult', action='store_true') 112 | 113 | parser.add_argument('--output_dir', default='output', 114 | help='path where to save, empty for no saving') 115 | parser.add_argument('--device', default='cuda', 116 | help='device to use for training / testing') 117 | parser.add_argument('--seed', default=42, type=int) 118 | parser.add_argument('--resume', default='', help='resume from 
checkpoint') 119 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 120 | help='start epoch') 121 | parser.add_argument('--eval', action='store_true') 122 | parser.add_argument('--num_workers', default=4, type=int) 123 | 124 | # test setting 125 | parser.add_argument('--threshold', default=0.5, type=float) # binary threshold for mask 126 | parser.add_argument('--ngpu', default=1, type=int, help='gpu number when inference for ref-ytvos and ref-davis') 127 | parser.add_argument('--split', default='valid', type=str, choices=['valid', 'test']) 128 | parser.add_argument('--visualize', action='store_true', help='whether visualize the masks during inference') 129 | 130 | # distributed training parameters 131 | parser.add_argument('--world_size', default=1, type=int, 132 | help='number of distributed processes') 133 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 134 | parser.add_argument('--cache_mode', default=False, action='store_true', help='whether to cache images on memory') 135 | 136 | parser.add_argument('--model_ckpt', type=str, default='checkpoint.pth') 137 | parser.add_argument('--test_num_ckpt', default=-2, type=int) 138 | 139 | # training technologies 140 | parser.add_argument("--use_fp16", default=False, action="store_true") 141 | 142 | return parser -------------------------------------------------------------------------------- /datasets/coco_eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | COCO evaluator that works in distributed mode. 3 | 4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 5 | The difference is that there is less copy-pasting from pycocotools 6 | in the end of the file, as python3 can suppress prints with contextlib 7 | """ 8 | import os 9 | import contextlib 10 | import copy 11 | import numpy as np 12 | import torch 13 | 14 | from pycocotools.cocoeval import COCOeval 15 | from pycocotools.coco import COCO 16 | import pycocotools.mask as mask_util 17 | 18 | from util.misc import all_gather 19 | 20 | 21 | class CocoEvaluator(object): 22 | def __init__(self, coco_gt, iou_types, useCats=False): 23 | assert isinstance(iou_types, (list, tuple)) 24 | coco_gt = copy.deepcopy(coco_gt) 25 | self.coco_gt = coco_gt 26 | 27 | self.iou_types = iou_types 28 | self.coco_eval = {} 29 | for iou_type in iou_types: 30 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 31 | self.coco_eval[iou_type].params.useCats = useCats 32 | 33 | self.img_ids = [] 34 | self.eval_imgs = {k: [] for k in iou_types} 35 | self.useCats = useCats 36 | 37 | def update(self, predictions): 38 | img_ids = list(np.unique(list(predictions.keys()))) 39 | self.img_ids.extend(img_ids) 40 | 41 | for iou_type in self.iou_types: 42 | results = self.prepare(predictions, iou_type) 43 | 44 | # suppress pycocotools prints 45 | with open(os.devnull, 'w') as devnull: 46 | with contextlib.redirect_stdout(devnull): 47 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 48 | coco_eval = self.coco_eval[iou_type] 49 | 50 | coco_eval.cocoDt = coco_dt 51 | coco_eval.params.imgIds = list(img_ids) 52 | coco_eval.params.useCats = self.useCats 53 | img_ids, eval_imgs = evaluate(coco_eval) 54 | 55 | self.eval_imgs[iou_type].append(eval_imgs) 56 | 57 | def synchronize_between_processes(self): 58 | for iou_type in self.iou_types: 59 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 60 | 
create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 61 | 62 | def accumulate(self): 63 | for coco_eval in self.coco_eval.values(): 64 | coco_eval.accumulate() 65 | 66 | def summarize(self): 67 | for iou_type, coco_eval in self.coco_eval.items(): 68 | print("IoU metric: {}".format(iou_type)) 69 | coco_eval.summarize() 70 | 71 | def prepare(self, predictions, iou_type): 72 | if iou_type == "bbox": 73 | return self.prepare_for_coco_detection(predictions) 74 | elif iou_type == "segm": 75 | return self.prepare_for_coco_segmentation(predictions) 76 | elif iou_type == "keypoints": 77 | return self.prepare_for_coco_keypoint(predictions) 78 | else: 79 | raise ValueError("Unknown iou type {}".format(iou_type)) 80 | 81 | def prepare_for_coco_detection(self, predictions): 82 | coco_results = [] 83 | for original_id, prediction in predictions.items(): 84 | if len(prediction) == 0: 85 | continue 86 | 87 | boxes = prediction["boxes"] 88 | boxes = convert_to_xywh(boxes).tolist() 89 | scores = prediction["scores"].tolist() 90 | labels = prediction["labels"].tolist() 91 | 92 | coco_results.extend( 93 | [ 94 | { 95 | "image_id": original_id, 96 | "category_id": labels[k], 97 | "bbox": box, 98 | "score": scores[k], 99 | } 100 | for k, box in enumerate(boxes) 101 | ] 102 | ) 103 | return coco_results 104 | 105 | def prepare_for_coco_segmentation(self, predictions): 106 | coco_results = [] 107 | for original_id, prediction in predictions.items(): 108 | if len(prediction) == 0: 109 | continue 110 | 111 | scores = prediction["scores"] 112 | labels = prediction["labels"] 113 | masks = prediction["masks"] 114 | 115 | masks = masks > 0.5 116 | 117 | scores = prediction["scores"].tolist() 118 | labels = prediction["labels"].tolist() 119 | 120 | rles = [ 121 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 122 | for mask in masks 123 | ] 124 | for rle in rles: 125 | rle["counts"] = rle["counts"].decode("utf-8") 126 | 127 | coco_results.extend( 128 | [ 129 | { 130 | "image_id": original_id, 131 | "category_id": labels[k], 132 | "segmentation": rle, 133 | "score": scores[k], 134 | } 135 | for k, rle in enumerate(rles) 136 | ] 137 | ) 138 | return coco_results 139 | 140 | def prepare_for_coco_keypoint(self, predictions): 141 | coco_results = [] 142 | for original_id, prediction in predictions.items(): 143 | if len(prediction) == 0: 144 | continue 145 | 146 | boxes = prediction["boxes"] 147 | boxes = convert_to_xywh(boxes).tolist() 148 | scores = prediction["scores"].tolist() 149 | labels = prediction["labels"].tolist() 150 | keypoints = prediction["keypoints"] 151 | keypoints = keypoints.flatten(start_dim=1).tolist() 152 | 153 | coco_results.extend( 154 | [ 155 | { 156 | "image_id": original_id, 157 | "category_id": labels[k], 158 | 'keypoints': keypoint, 159 | "score": scores[k], 160 | } 161 | for k, keypoint in enumerate(keypoints) 162 | ] 163 | ) 164 | return coco_results 165 | 166 | 167 | def convert_to_xywh(boxes): 168 | xmin, ymin, xmax, ymax = boxes.unbind(1) 169 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 170 | 171 | 172 | def merge(img_ids, eval_imgs): 173 | all_img_ids = all_gather(img_ids) 174 | all_eval_imgs = all_gather(eval_imgs) 175 | 176 | merged_img_ids = [] 177 | for p in all_img_ids: 178 | merged_img_ids.extend(p) 179 | 180 | merged_eval_imgs = [] 181 | for p in all_eval_imgs: 182 | merged_eval_imgs.append(p) 183 | 184 | merged_img_ids = np.array(merged_img_ids) 185 | merged_eval_imgs = 
np.concatenate(merged_eval_imgs, 2) 186 | 187 | # keep only unique (and in sorted order) images 188 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 189 | merged_eval_imgs = merged_eval_imgs[..., idx] 190 | 191 | return merged_img_ids, merged_eval_imgs 192 | 193 | 194 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 195 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 196 | img_ids = list(img_ids) 197 | eval_imgs = list(eval_imgs.flatten()) 198 | 199 | coco_eval.evalImgs = eval_imgs 200 | coco_eval.params.imgIds = img_ids 201 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 202 | 203 | 204 | ################################################################# 205 | # From pycocotools, just removed the prints and fixed 206 | # a Python3 bug about unicode not defined 207 | ################################################################# 208 | 209 | 210 | def evaluate(self): 211 | ''' 212 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 213 | :return: None 214 | ''' 215 | # tic = time.time() 216 | # print('Running per image evaluation...') 217 | p = self.params 218 | # add backward compatibility if useSegm is specified in params 219 | if p.useSegm is not None: 220 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 221 | print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) 222 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 223 | p.imgIds = list(np.unique(p.imgIds)) 224 | if p.useCats: 225 | p.catIds = list(np.unique(p.catIds)) 226 | p.maxDets = sorted(p.maxDets) 227 | self.params = p 228 | 229 | self._prepare() 230 | # loop through images, area range, max detection number 231 | catIds = p.catIds if p.useCats else [-1] 232 | 233 | if p.iouType == 'segm' or p.iouType == 'bbox': 234 | computeIoU = self.computeIoU 235 | elif p.iouType == 'keypoints': 236 | computeIoU = self.computeOks 237 | self.ious = { 238 | (imgId, catId): computeIoU(imgId, catId) 239 | for imgId in p.imgIds 240 | for catId in catIds} 241 | 242 | evaluateImg = self.evaluateImg 243 | maxDet = p.maxDets[-1] 244 | evalImgs = [ 245 | evaluateImg(imgId, catId, areaRng, maxDet) 246 | for catId in catIds 247 | for areaRng in p.areaRng 248 | for imgId in p.imgIds 249 | ] 250 | # this is NOT in the pycocotools code, but could be done outside 251 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 252 | self._paramsEval = copy.deepcopy(self.params) 253 | # toc = time.time() 254 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 255 | return p.imgIds, evalImgs 256 | 257 | ################################################################# 258 | # end of straight copy from pycocotools, just removing the prints 259 | ################################################################# 260 | -------------------------------------------------------------------------------- /datasets/refexp2seq.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Aishwarya Kamath & Nicolas Carion. Licensed under the Apache License 2.0. All Rights Reserved 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | # For building refcoco, refcoco+, refcocog datasets 4 | """ 5 | COCO dataset which returns image_id for evaluation. 
6 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 7 | """ 8 | from pathlib import Path 9 | 10 | import torch 11 | import torch.utils.data 12 | import torchvision 13 | from pycocotools import mask as coco_mask 14 | 15 | import random 16 | import numpy as np 17 | from PIL import Image 18 | 19 | import datasets.transforms_video as T 20 | from datasets.image_to_seq_augmenter import ImageToSeqAugmenter 21 | 22 | from util.box_ops import masks_to_boxes 23 | 24 | 25 | class ModulatedDetection(torchvision.datasets.CocoDetection): 26 | def __init__(self, img_folder, ann_file, num_frames, transforms, return_masks): 27 | super(ModulatedDetection, self).__init__(img_folder, ann_file) 28 | self._transforms = transforms 29 | self.prepare = ConvertCocoPolysToMask(return_masks) 30 | self.num_frames = num_frames 31 | self.augmenter = ImageToSeqAugmenter(perspective=True, affine=True, motion_blur=True, 32 | rotation_range=(-20, 20), perspective_magnitude=0.08, 33 | hue_saturation_range=(-5, 5), brightness_range=(-40, 40), 34 | motion_blur_prob=0.25, motion_blur_kernel_sizes=(9, 11), 35 | translate_range=(-0.1, 0.1)) 36 | 37 | def apply_random_sequence_shuffle(self, images, instance_masks): 38 | perm = list(range(self.num_frames)) 39 | random.shuffle(perm) 40 | images = [images[i] for i in perm] 41 | instance_masks = [instance_masks[i] for i in perm] 42 | return images, instance_masks 43 | 44 | def __getitem__(self, idx): 45 | instance_check = False 46 | while not instance_check: 47 | img, target = super(ModulatedDetection, self).__getitem__(idx) 48 | image_id = self.ids[idx] 49 | coco_img = self.coco.loadImgs(image_id)[0] 50 | caption = coco_img["caption"] 51 | dataset_name = coco_img["dataset_name"] if "dataset_name" in coco_img else None 52 | target = {"image_id": image_id, "annotations": target, "caption": caption} 53 | img, target = self.prepare(img, target) 54 | 55 | # for an image, we augment it repeatedly to form a clip 56 | seq_images, seq_instance_masks = [img], [target['masks'].numpy()] 57 | numpy_masks = target['masks'].numpy() # [1, H, W] 58 | 59 | numinst = len(numpy_masks) 60 | assert numinst == 1 61 | for t in range(self.num_frames - 1): 62 | im_trafo, instance_masks_trafo = self.augmenter(np.asarray(img), numpy_masks) 63 | im_trafo = Image.fromarray(np.uint8(im_trafo)) 64 | seq_images.append(im_trafo) 65 | seq_instance_masks.append(np.stack(instance_masks_trafo, axis=0)) 66 | seq_images, seq_instance_masks = self.apply_random_sequence_shuffle(seq_images, seq_instance_masks) 67 | output_inst_masks = [] 68 | for inst_i in range(numinst): 69 | inst_i_mask = [] 70 | for f_i in range(self.num_frames): 71 | inst_i_mask.append(seq_instance_masks[f_i][inst_i]) 72 | output_inst_masks.append( np.stack(inst_i_mask, axis=0) ) 73 | 74 | output_inst_masks = torch.from_numpy( np.stack(output_inst_masks, axis=0) ) 75 | target['masks'] = output_inst_masks.flatten(0,1) # [t, h, w] 76 | target['boxes'] = masks_to_boxes(target['masks']) # [t, 4] 77 | target['labels'] = target['labels'].repeat(self.num_frames) # [t,] 78 | 79 | if self._transforms is not None: 80 | img, target = self._transforms(seq_images, target) 81 | target["dataset_name"] = dataset_name 82 | for extra_key in ["sentence_id", "original_img_id", "original_id", "task_id"]: 83 | if extra_key in coco_img: 84 | target[extra_key] = coco_img[extra_key] # box xyxy -> cxcywh 85 | # FIXME: handle "valid", since some box may be removed due to random crop 86 | if torch.any(target['valid'] == 1): # at least one
instance 87 | instance_check = True 88 | else: 89 | idx = random.randint(0, self.__len__() - 1) 90 | 91 | # set the gt box of empty mask to [0, 0, 0, 0] 92 | for inst_id in range(len(target['boxes'])): 93 | if target['masks'][inst_id].max()<1: 94 | target['boxes'][inst_id] = torch.zeros(4).to(target['boxes'][inst_id]) 95 | 96 | target['boxes']=target['boxes'].clamp(1e-6) 97 | return torch.stack(img,dim=0), target 98 | 99 | 100 | def convert_coco_poly_to_mask(segmentations, height, width): 101 | masks = [] 102 | for polygons in segmentations: 103 | rles = coco_mask.frPyObjects(polygons, height, width) 104 | mask = coco_mask.decode(rles) 105 | if len(mask.shape) < 3: 106 | mask = mask[..., None] 107 | mask = torch.as_tensor(mask, dtype=torch.uint8) 108 | mask = mask.any(dim=2) 109 | masks.append(mask) 110 | if masks: 111 | masks = torch.stack(masks, dim=0) 112 | else: 113 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 114 | return masks 115 | 116 | 117 | class ConvertCocoPolysToMask(object): 118 | def __init__(self, return_masks=False): 119 | self.return_masks = return_masks 120 | 121 | def __call__(self, image, target): 122 | w, h = image.size 123 | 124 | image_id = target["image_id"] 125 | image_id = torch.tensor([image_id]) 126 | 127 | anno = target["annotations"] 128 | caption = target["caption"] if "caption" in target else None 129 | 130 | anno = [obj for obj in anno if "iscrowd" not in obj or obj["iscrowd"] == 0] 131 | 132 | boxes = [obj["bbox"] for obj in anno] 133 | # guard against no boxes via resizing 134 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 135 | boxes[:, 2:] += boxes[:, :2] # xminyminwh -> xyxy 136 | boxes[:, 0::2].clamp_(min=0, max=w) 137 | boxes[:, 1::2].clamp_(min=0, max=h) 138 | 139 | classes = [obj["category_id"] for obj in anno] 140 | classes = torch.tensor(classes, dtype=torch.int64) 141 | 142 | if self.return_masks: 143 | segmentations = [obj["segmentation"] for obj in anno] 144 | masks = convert_coco_poly_to_mask(segmentations, h, w) 145 | 146 | # keep the valid boxes 147 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 148 | boxes = boxes[keep] 149 | classes = classes[keep] 150 | if self.return_masks: 151 | masks = masks[keep] 152 | 153 | target = {} 154 | target["boxes"] = boxes 155 | target["labels"] = classes 156 | if caption is not None: 157 | target["caption"] = caption 158 | if self.return_masks: 159 | target["masks"] = masks 160 | target["image_id"] = image_id 161 | 162 | # for conversion to coco api 163 | area = torch.tensor([obj["area"] for obj in anno]) 164 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 165 | target["area"] = area[keep] 166 | target["iscrowd"] = iscrowd[keep] 167 | target["valid"] = torch.tensor([1]) 168 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 169 | target["size"] = torch.as_tensor([int(h), int(w)]) 170 | return image, target 171 | 172 | 173 | def make_coco_transforms(image_set, max_size): 174 | normalize = T.Compose([ 175 | T.ToTensor(), 176 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 177 | ]) 178 | 179 | scales = [288, 320, 352, 392, 416, 448, 480, 512] 180 | 181 | if image_set == 'train': 182 | return T.Compose([ 183 | T.RandomHorizontalFlip(), 184 | T.PhotometricDistort(), 185 | T.RandomSelect( 186 | T.Compose([ 187 | T.RandomResize(scales, max_size=max_size), 188 | T.Check(), 189 | ]), 190 | T.Compose([ 191 | T.RandomResize([400, 500, 600]), 192 | T.RandomSizeCrop(384, 600), 193 | T.RandomResize(scales, 
max_size=max_size), 194 | T.Check(), 195 | ]) 196 | ), 197 | normalize, 198 | ]) 199 | 200 | if image_set == "val": 201 | return T.Compose( 202 | [ 203 | T.RandomResize([360], max_size=640), 204 | normalize, 205 | ] 206 | ) 207 | 208 | raise ValueError(f"unknown {image_set}") 209 | 210 | 211 | def build(dataset_file, image_set, args): 212 | root = Path(args.coco_path) 213 | assert root.exists(), f"provided COCO path {root} does not exist" 214 | mode = "instances" 215 | dataset = dataset_file 216 | PATHS = { 217 | "train": (root / "train2014", root / dataset / f"{mode}_{dataset}_train.json"), 218 | "val": (root / "train2014", root / dataset / f"{mode}_{dataset}_val.json"), 219 | } 220 | 221 | img_folder, ann_file = PATHS[image_set] 222 | dataset = ModulatedDetection( 223 | img_folder, 224 | ann_file, 225 | num_frames=args.num_frames, 226 | transforms=make_coco_transforms(image_set, args.max_size), 227 | return_masks=args.masks, 228 | ) 229 | return dataset 230 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | """ 2 | Instance Sequence Matching 3 | Modified from DETR (https://github.com/facebookresearch/detr) 4 | """ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou, multi_iou 11 | from util.misc import nested_tensor_from_tensor_list 12 | 13 | INF = 100000000 14 | 15 | def dice_coef(inputs, targets): 16 | inputs = inputs.sigmoid() 17 | inputs = inputs.flatten(1).unsqueeze(1) # [N, 1, THW] 18 | targets = targets.flatten(1).unsqueeze(0) # [1, M, THW] 19 | numerator = 2 * (inputs * targets).sum(2) 20 | denominator = inputs.sum(-1) + targets.sum(-1) 21 | 22 | # NOTE coef doesn't be subtracted to 1 as it is not necessary for computing costs 23 | coef = (numerator + 1) / (denominator + 1) 24 | return coef 25 | 26 | def sigmoid_focal_coef(inputs, targets, alpha: float = 0.25, gamma: float = 2): 27 | N, M = len(inputs), len(targets) 28 | inputs = inputs.flatten(1).unsqueeze(1).expand(-1, M, -1) # [N, M, THW] 29 | targets = targets.flatten(1).unsqueeze(0).expand(N, -1, -1) # [N, M, THW] 30 | 31 | prob = inputs.sigmoid() 32 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 33 | p_t = prob * targets + (1 - prob) * (1 - targets) 34 | coef = ce_loss * ((1 - p_t) ** gamma) 35 | 36 | if alpha >= 0: 37 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 38 | coef = alpha_t * coef 39 | 40 | return coef.mean(2) # [N, M] 41 | 42 | 43 | class HungarianMatcher(nn.Module): 44 | """This class computes an assignment between the targets and the predictions of the network 45 | 46 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 47 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 48 | while the others are un-matched (and thus treated as non-objects). 
49 | """ 50 | 51 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1, 52 | cost_mask: float = 1, cost_dice: float = 1, num_classes: int = 1): 53 | """Creates the matcher 54 | 55 | Params: 56 | cost_class: This is the relative weight of the classification error in the matching cost 57 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 58 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 59 | cost_mask: This is the relative weight of the sigmoid focal loss of the mask in the matching cost 60 | cost_dice: This is the relative weight of the dice loss of the mask in the matching cost 61 | """ 62 | super().__init__() 63 | self.cost_class = cost_class 64 | self.cost_bbox = cost_bbox 65 | self.cost_giou = cost_giou 66 | self.cost_mask = cost_mask 67 | self.cost_dice = cost_dice 68 | self.num_classes = num_classes 69 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0 \ 70 | or cost_mask != 0 or cost_dice != 0, "all costs cant be 0" 71 | self.mask_out_stride = 4 72 | 73 | @torch.no_grad() 74 | def forward(self, outputs, targets): 75 | """ Performs the matching 76 | Params: 77 | outputs: This is a dict that contains at least these entries: 78 | "pred_logits": Tensor of dim [batch_size, num_queries_per_frame, num_frames, num_classes] with the classification logits 79 | "pred_boxes": Tensor of dim [batch_size, num_queries_per_frame, num_frames, 4] with the predicted box coordinates 80 | "pred_masks": Tensor of dim [batch_size, num_queries_per_frame, num_frames, h, w], h,w in 4x size 81 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 82 | NOTE: Since every frame has one object at most 83 | "labels": Tensor of dim [num_frames] (where num_target_boxes is the number of ground-truth 84 | objects in the target) containing the class labels 85 | "boxes": Tensor of dim [num_frames, 4] containing the target box coordinates 86 | "masks": Tensor of dim [num_frames, h, w], h,w in origin size 87 | Returns: 88 | A list of size batch_size, containing tuples of (index_i, index_j) where: 89 | - index_i is the indices of the selected predictions (in order) 90 | - index_j is the indices of the corresponding selected targets (in order) 91 | For each batch element, it holds: 92 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 93 | """ 94 | src_logits = outputs["pred_logits"] 95 | src_boxes = outputs["pred_boxes"] 96 | src_masks = outputs["pred_masks"] 97 | 98 | bs, nf, nq, h, w = src_masks.shape 99 | 100 | # handle mask padding issue 101 | target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets], 102 | size_divisibility=32, 103 | split=False).decompose() 104 | target_masks = target_masks.to(src_masks) # [B, T, H, W] 105 | 106 | # downsample ground truth masks with ratio mask_out_stride 107 | start = int(self.mask_out_stride // 2) 108 | im_h, im_w = target_masks.shape[-2:] 109 | 110 | target_masks = target_masks[:, :, start::self.mask_out_stride, start::self.mask_out_stride] 111 | assert target_masks.size(2) * self.mask_out_stride == im_h 112 | assert target_masks.size(3) * self.mask_out_stride == im_w 113 | 114 | indices = [] 115 | for i in range(bs): 116 | out_prob = src_logits[i].sigmoid() 117 | out_bbox = src_boxes[i] 118 | out_mask = src_masks[i] 119 | 120 | tgt_ids = targets[i]["labels"] 121 | tgt_bbox = targets[i]["boxes"] 122 | tgt_mask = target_masks[i] 123 | 
tgt_valid = targets[i]["valid"] 124 | 125 | # class cost 126 | # we average the cost on valid frames 127 | cost_class = [] 128 | for t in range(nf): 129 | if tgt_valid[t] == 0: 130 | continue 131 | 132 | out_prob_split = out_prob[t] 133 | tgt_ids_split = tgt_ids[t].unsqueeze(0) 134 | 135 | # Compute the classification cost. 136 | alpha = 0.25 137 | gamma = 2.0 138 | neg_cost_class = (1 - alpha) * (out_prob_split ** gamma) * (-(1 - out_prob_split + 1e-8).log()) 139 | pos_cost_class = alpha * ((1 - out_prob_split) ** gamma) * (-(out_prob_split + 1e-8).log()) 140 | if self.num_classes == 1: # binary referred 141 | cost_class_split = pos_cost_class[:, [0]] - neg_cost_class[:, [0]] 142 | else: 143 | cost_class_split = pos_cost_class[:, tgt_ids_split] - neg_cost_class[:, tgt_ids_split] 144 | 145 | cost_class.append(cost_class_split) 146 | cost_class = torch.stack(cost_class, dim=0).mean(0) # [q, 1] 147 | 148 | # box cost 149 | # we average the cost on every frame 150 | cost_bbox, cost_giou = [], [] 151 | for t in range(nf): 152 | out_bbox_split = out_bbox[t] 153 | tgt_bbox_split = tgt_bbox[t].unsqueeze(0) 154 | 155 | # Compute the L1 cost between boxes 156 | cost_bbox_split = torch.cdist(out_bbox_split, tgt_bbox_split, p=1) 157 | 158 | # Compute the giou cost between boxes 159 | cost_giou_split = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox_split), 160 | box_cxcywh_to_xyxy(tgt_bbox_split)) 161 | 162 | cost_bbox.append(cost_bbox_split) 163 | cost_giou.append(cost_giou_split) 164 | cost_bbox = torch.stack(cost_bbox, dim=0).mean(0) 165 | cost_giou = torch.stack(cost_giou, dim=0).mean(0) 166 | 167 | # mask cost 168 | # Compute the focal loss between masks 169 | cost_mask = sigmoid_focal_coef(out_mask.transpose(0, 1), tgt_mask.unsqueeze(0)) 170 | 171 | # Compute the dice loss between masks 172 | cost_dice = -dice_coef(out_mask.transpose(0, 1), tgt_mask.unsqueeze(0)) 173 | 174 | # Final cost matrix 175 | C = self.cost_class * cost_class + self.cost_bbox * cost_bbox + self.cost_giou * cost_giou + \ 176 | self.cost_mask * cost_mask + self.cost_dice * cost_dice # [q, 1] 177 | 178 | # There is only one tgt, so use min-cost matching 179 | _, src_ind = torch.min(C, dim=0) 180 | tgt_ind = torch.arange(1).to(src_ind) 181 | indices.append((src_ind.long(), tgt_ind.long())) 182 | 183 | # list[tuple], length is batch_size 184 | return indices 185 | 186 | 187 | def build_matcher(args): 188 | if args.binary: 189 | num_classes = 1 190 | else: 191 | if args.dataset_file == 'ytvos': 192 | num_classes = 65 193 | elif args.dataset_file == 'davis': 194 | num_classes = 78 195 | elif args.dataset_file == 'a2d' or args.dataset_file == 'jhmdb': 196 | num_classes = 1 197 | else: 198 | num_classes = 91 # for coco 199 | return HungarianMatcher(cost_class=args.set_cost_class, 200 | cost_bbox=args.set_cost_bbox, 201 | cost_giou=args.set_cost_giou, 202 | cost_mask=args.set_cost_mask, 203 | cost_dice=args.set_cost_dice, 204 | num_classes=num_classes) 205 | 206 | 207 | -------------------------------------------------------------------------------- /datasets/davis.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ref-Davis17 data loader 3 | """ 4 | from pathlib import Path 5 | 6 | import torch 7 | from torch.autograd.grad_mode import F 8 | from torch.utils.data import Dataset 9 | import datasets.transforms_video as T 10 | import os 11 | from PIL import Image 12 | 13 | import json 14 | import numpy as np 15 | import random 16 | 17 | from datasets.categories import davis_category_dict as
category_dict 18 | 19 | 20 | class DAVIS17Dataset(Dataset): 21 | """ 22 | A dataset class for the Refer-DAVIS17 dataset which was first introduced in the paper: 23 | "Video Object Segmentation with Language Referring Expressions" 24 | (see https://arxiv.org/pdf/1803.08006.pdf). 25 | There are 60/30 videos in train/validation set, respectively. 26 | """ 27 | def __init__(self, img_folder: Path, ann_file: Path, transforms, return_masks: bool, 28 | num_frames: int, max_skip: int): 29 | self.img_folder = img_folder 30 | self.ann_file = ann_file 31 | self._transforms = transforms 32 | self.return_masks = return_masks # not used 33 | self.num_frames = num_frames 34 | self.max_skip = max_skip 35 | # create video meta data 36 | self.prepare_metas() 37 | 38 | print('\n video num: ', len(self.videos), ' clip num: ', len(self.metas)) 39 | print('\n') 40 | 41 | 42 | def prepare_metas(self): 43 | # read object information 44 | with open(os.path.join(str(self.img_folder), 'meta.json'), 'r') as f: 45 | subset_metas_by_video = json.load(f)['videos'] 46 | 47 | # read expression data 48 | with open(str(self.ann_file), 'r') as f: 49 | subset_expressions_by_video = json.load(f)['videos'] 50 | self.videos = list(subset_expressions_by_video.keys()) 51 | 52 | self.metas = [] 53 | for vid in self.videos: 54 | vid_meta = subset_metas_by_video[vid] 55 | vid_data = subset_expressions_by_video[vid] 56 | vid_frames = sorted(vid_data['frames']) 57 | vid_len = len(vid_frames) 58 | for exp_id, exp_dict in vid_data['expressions'].items(): 59 | for frame_id in range(0, vid_len, self.num_frames): 60 | meta = {} 61 | meta['video'] = vid 62 | meta['exp'] = exp_dict['exp'] 63 | meta['obj_id'] = int(exp_dict['obj_id']) 64 | meta['frames'] = vid_frames 65 | meta['frame_id'] = frame_id 66 | # get object category 67 | obj_id = exp_dict['obj_id'] 68 | meta['category'] = vid_meta['objects'][obj_id]['category'] 69 | self.metas.append(meta) 70 | 71 | @staticmethod 72 | def bounding_box(img): 73 | rows = np.any(img, axis=1) 74 | cols = np.any(img, axis=0) 75 | rmin, rmax = np.where(rows)[0][[0, -1]] 76 | cmin, cmax = np.where(cols)[0][[0, -1]] 77 | return rmin, rmax, cmin, cmax # y1, y2, x1, x2 78 | 79 | 80 | def __len__(self): 81 | return len(self.metas) 82 | 83 | 84 | def __getitem__(self, idx): 85 | instance_check = False 86 | while not instance_check: 87 | meta = self.metas[idx] # dict 88 | 89 | video, exp, obj_id, category, frames, frame_id = \ 90 | meta['video'], meta['exp'], meta['obj_id'], meta['category'], meta['frames'], meta['frame_id'] 91 | # clean up the caption 92 | exp = " ".join(exp.lower().split()) 93 | category_id = category_dict[category] 94 | vid_len = len(frames) 95 | 96 | num_frames = self.num_frames 97 | # random sparse sample 98 | sample_indx = [frame_id] 99 | # local sample 100 | sample_id_before = random.randint(1, 3) 101 | sample_id_after = random.randint(1, 3) 102 | local_indx = [max(0, frame_id - sample_id_before), min(vid_len - 1, frame_id + sample_id_after)] 103 | sample_indx.extend(local_indx) 104 | 105 | # global sampling 106 | if num_frames > 3: 107 | all_inds = list(range(vid_len)) 108 | global_inds = all_inds[:min(sample_indx)] + all_inds[max(sample_indx):] 109 | global_n = num_frames - len(sample_indx) 110 | if len(global_inds) > global_n: 111 | select_id = random.sample(range(len(global_inds)), global_n) 112 | for s_id in select_id: 113 | sample_indx.append(global_inds[s_id]) 114 | elif vid_len >=global_n: # sample long range global frames 115 | select_id = random.sample(range(vid_len), global_n) 
116 | for s_id in select_id: 117 | sample_indx.append(all_inds[s_id]) 118 | else: 119 | select_id = random.sample(range(vid_len), global_n - vid_len) + list(range(vid_len)) 120 | for s_id in select_id: 121 | sample_indx.append(all_inds[s_id]) 122 | sample_indx.sort() 123 | 124 | # read frames and masks 125 | imgs, labels, boxes, masks, valid = [], [], [], [], [] 126 | for j in range(self.num_frames): 127 | frame_indx = sample_indx[j] 128 | frame_name = frames[frame_indx] 129 | img_path = os.path.join(str(self.img_folder), 'JPEGImages', video, frame_name + '.jpg') 130 | mask_path = os.path.join(str(self.img_folder), 'Annotations', video, frame_name + '.png') 131 | img = Image.open(img_path).convert('RGB') 132 | mask = Image.open(mask_path).convert('P') 133 | 134 | # create the target 135 | label = torch.tensor(category_id) 136 | mask = np.array(mask) 137 | mask = (mask==obj_id).astype(np.float32) # 0,1 binary 138 | if (mask > 0).any(): 139 | y1, y2, x1, x2 = self.bounding_box(mask) 140 | box = torch.tensor([x1, y1, x2, y2]).to(torch.float) 141 | valid.append(1) 142 | else: # some frames do not contain the instance 143 | box = torch.tensor([0, 0, 0, 0]).to(torch.float) 144 | valid.append(0) 145 | mask = torch.from_numpy(mask) 146 | 147 | # append 148 | imgs.append(img) 149 | labels.append(label) 150 | masks.append(mask) 151 | boxes.append(box) 152 | 153 | # transform 154 | w, h = img.size 155 | labels = torch.stack(labels, dim=0) 156 | boxes = torch.stack(boxes, dim=0) 157 | boxes[:, 0::2].clamp_(min=0, max=w) 158 | boxes[:, 1::2].clamp_(min=0, max=h) 159 | masks = torch.stack(masks, dim=0) 160 | target = { 161 | 'frames_idx': torch.tensor(sample_indx), # [T,] 162 | 'labels': labels, # [T,] 163 | 'boxes': boxes, # [T, 4], xyxy 164 | 'masks': masks, # [T, H, W] 165 | 'valid': torch.tensor(valid), # [T,] 166 | 'caption': exp, 167 | 'orig_size': torch.as_tensor([int(h), int(w)]), 168 | 'size': torch.as_tensor([int(h), int(w)]) 169 | } 170 | 171 | # "boxes" are normalized to [0, 1] and converted from xyxy to cxcywh in self._transforms 172 | imgs, target = self._transforms(imgs, target) 173 | imgs = torch.stack(imgs, dim=0) # [T, 3, H, W] 174 | 175 | # FIXME: handle "valid", since some box may be removed due to random crop 176 | if torch.any(target['valid'] == 1): # at least one instance 177 | instance_check = True 178 | else: 179 | idx = random.randint(0, self.__len__() - 1) 180 | 181 | return imgs, target 182 | 183 | 184 | 185 | def make_coco_transforms(image_set, max_size=640): 186 | normalize = T.Compose([ 187 | T.ToTensor(), 188 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 189 | ]) 190 | 191 | scales = [288, 320, 352, 392, 416, 448, 480, 512] 192 | 193 | if image_set == 'train': 194 | return T.Compose([ 195 | T.RandomHorizontalFlip(), 196 | T.PhotometricDistort(), 197 | T.RandomSelect( 198 | T.Compose([ 199 | T.RandomResize(scales, max_size=max_size), 200 | T.Check(), 201 | ]), 202 | T.Compose([ 203 | T.RandomResize([400, 500, 600]), 204 | T.RandomSizeCrop(384, 600), 205 | T.RandomResize(scales, max_size=max_size), 206 | T.Check(), 207 | ]) 208 | ), 209 | normalize, 210 | ]) 211 | 212 | # we do not use the 'val' set since the annotations are inaccessible 213 | if image_set == 'val': 214 | return T.Compose([ 215 | T.RandomResize([360], max_size=640), 216 | normalize, 217 | ]) 218 | 219 | raise ValueError(f'unknown {image_set}') 220 | 221 | 222 | def build(image_set, args): 223 | root = Path(args.davis_path) 224 | assert root.exists(), f'provided DAVIS path {root} does not exist' 225
| PATHS = { 226 | "train": (root / "train", root / "meta_expressions" / "train" / "meta_expressions.json"), 227 | "val": (root / "valid", root / "meta_expressions" / "val" / "meta_expressions.json"), # not used actually 228 | } 229 | img_folder, ann_file = PATHS[image_set] 230 | dataset = DAVIS17Dataset(img_folder, ann_file, transforms=make_coco_transforms(image_set, max_size=args.max_size), 231 | return_masks=args.masks, num_frames=args.num_frames, max_skip=args.max_skip) 232 | return dataset 233 | 234 | 235 | -------------------------------------------------------------------------------- /models/criterion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from util import box_ops 6 | from util.misc import (NestedTensor, nested_tensor_from_tensor_list, 7 | accuracy, get_world_size, interpolate, 8 | is_dist_avail_and_initialized, inverse_sigmoid) 9 | 10 | from .segmentation import (dice_loss, sigmoid_focal_loss) 11 | 12 | from einops import rearrange 13 | 14 | class SetCriterion(nn.Module): 15 | """ This class computes the loss for ReferFormer. 16 | The process happens in two steps: 17 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 18 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 19 | """ 20 | def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, focal_alpha=0.25): 21 | """ Create the criterion. 22 | Parameters: 23 | num_classes: number of object categories, omitting the special no-object category 24 | matcher: module able to compute a matching between targets and proposals 25 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 26 | eos_coef: relative classification weight applied to the no-object category 27 | losses: list of all the losses to be applied. See get_loss for list of available losses. 
28 | """ 29 | super().__init__() 30 | self.num_classes = num_classes 31 | self.matcher = matcher 32 | self.weight_dict = weight_dict 33 | self.eos_coef = eos_coef 34 | self.losses = losses 35 | empty_weight = torch.ones(self.num_classes + 1) 36 | empty_weight[-1] = self.eos_coef 37 | self.register_buffer('empty_weight', empty_weight) 38 | self.focal_alpha = focal_alpha 39 | self.mask_out_stride = 4 40 | 41 | def loss_labels(self, outputs, targets, indices, num_boxes, log=True): 42 | """Classification loss (NLL) 43 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 44 | """ 45 | assert 'pred_logits' in outputs 46 | src_logits = outputs['pred_logits'] 47 | _, nf, nq = src_logits.shape[:3] 48 | src_logits = rearrange(src_logits, 'b t q k -> b (t q) k') 49 | 50 | # judge the valid frames 51 | valid_indices = [] 52 | valids = [target['valid'] for target in targets] 53 | for valid, (indice_i, indice_j) in zip(valids, indices): 54 | valid_ind = valid.nonzero().flatten() 55 | valid_i = valid_ind * nq + indice_i 56 | valid_j = valid_ind + indice_j * nf 57 | valid_indices.append((valid_i, valid_j)) 58 | 59 | idx = self._get_src_permutation_idx(valid_indices) # NOTE: use valid indices 60 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, valid_indices)]) 61 | target_classes = torch.full(src_logits.shape[:2], self.num_classes, 62 | dtype=torch.int64, device=src_logits.device) 63 | if self.num_classes == 1: # binary referred 64 | target_classes[idx] = 0 65 | else: 66 | target_classes[idx] = target_classes_o 67 | 68 | target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1], 69 | dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device) 70 | target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1) 71 | 72 | target_classes_onehot = target_classes_onehot[:,:,:-1] 73 | loss_ce = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * src_logits.shape[1] 74 | losses = {'loss_ce': loss_ce} 75 | 76 | if log: 77 | # TODO this should probably be a separate loss, not hacked in this one here 78 | pass 79 | return losses 80 | 81 | 82 | def loss_boxes(self, outputs, targets, indices, num_boxes): 83 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss 84 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 85 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 86 | """ 87 | assert 'pred_boxes' in outputs 88 | src_boxes = outputs['pred_boxes'] 89 | bs, nf, nq = src_boxes.shape[:3] 90 | src_boxes = src_boxes.transpose(1, 2) 91 | 92 | idx = self._get_src_permutation_idx(indices) 93 | src_boxes = src_boxes[idx] 94 | src_boxes = src_boxes.flatten(0, 1) # [b*t, 4] 95 | 96 | target_boxes = torch.cat([t['boxes'] for t in targets], dim=0) # [b*t, 4] 97 | 98 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') 99 | 100 | losses = {} 101 | losses['loss_bbox'] = loss_bbox.sum() / num_boxes 102 | 103 | loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( 104 | box_ops.box_cxcywh_to_xyxy(src_boxes), 105 | box_ops.box_cxcywh_to_xyxy(target_boxes))) 106 | losses['loss_giou'] = loss_giou.sum() / num_boxes 107 | return losses 108 | 109 | 110 | def loss_masks(self, outputs, targets, indices, num_boxes): 111 | """Compute the losses related to the masks: the focal loss and the dice loss. 
112 | targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] 113 | """ 114 | assert "pred_masks" in outputs 115 | 116 | src_idx = self._get_src_permutation_idx(indices) 117 | # tgt_idx = self._get_tgt_permutation_idx(indices) 118 | 119 | src_masks = outputs["pred_masks"] 120 | src_masks = src_masks.transpose(1, 2) 121 | 122 | # TODO use valid to mask invalid areas due to padding in loss 123 | target_masks, valid = nested_tensor_from_tensor_list([t["masks"] for t in targets], 124 | size_divisibility=32, split=False).decompose() 125 | target_masks = target_masks.to(src_masks) 126 | 127 | # downsample ground truth masks with ratio mask_out_stride 128 | start = int(self.mask_out_stride // 2) 129 | im_h, im_w = target_masks.shape[-2:] 130 | 131 | target_masks = target_masks[:, :, start::self.mask_out_stride, start::self.mask_out_stride] 132 | assert target_masks.size(2) * self.mask_out_stride == im_h 133 | assert target_masks.size(3) * self.mask_out_stride == im_w 134 | 135 | src_masks = src_masks[src_idx] 136 | # upsample predictions to the target size 137 | # src_masks = interpolate(src_masks, size=target_masks.shape[-2:], mode="bilinear", align_corners=False) 138 | src_masks = src_masks.flatten(1) # [b, thw] 139 | 140 | target_masks = target_masks.flatten(1) # [b, thw] 141 | 142 | losses = { 143 | "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_boxes), 144 | "loss_dice": dice_loss(src_masks, target_masks, num_boxes), 145 | } 146 | return losses 147 | 148 | def _get_src_permutation_idx(self, indices): 149 | # permute predictions following indices 150 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 151 | src_idx = torch.cat([src for (src, _) in indices]) 152 | return batch_idx, src_idx 153 | 154 | def _get_tgt_permutation_idx(self, indices): 155 | # permute targets following indices 156 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 157 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 158 | return batch_idx, tgt_idx 159 | 160 | def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): 161 | loss_map = { 162 | 'labels': self.loss_labels, 163 | 'boxes': self.loss_boxes, 164 | 'masks': self.loss_masks 165 | } 166 | assert loss in loss_map, f'do you really want to compute {loss} loss?' 167 | return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) 168 | 169 | def forward(self, outputs, targets): 170 | """ This performs the loss computation. 171 | Parameters: 172 | outputs: dict of tensors, see the output specification of the model for the format 173 | targets: list of dicts, such that len(targets) == batch_size. 
174 | The expected keys in each dict depend on the losses applied, see each loss's doc 175 | """ 176 | outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} 177 | # Retrieve the matching between the outputs of the last layer and the targets 178 | indices = self.matcher(outputs_without_aux, targets) 179 | 180 | # Compute the average number of target boxes across all nodes, for normalization purposes 181 | target_valid = torch.stack([t["valid"] for t in targets], dim=0).reshape(-1) # [B, T] -> [B*T] 182 | num_boxes = target_valid.sum().item() 183 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) 184 | if is_dist_avail_and_initialized(): 185 | torch.distributed.all_reduce(num_boxes) 186 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 187 | 188 | # Compute all the requested losses 189 | losses = {} 190 | for loss in self.losses: 191 | losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) 192 | 193 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 194 | if 'aux_outputs' in outputs: 195 | for i, aux_outputs in enumerate(outputs['aux_outputs']): 196 | indices = self.matcher(aux_outputs, targets) 197 | for loss in self.losses: 198 | kwargs = {} 199 | if loss == 'labels': 200 | # Logging is enabled only for the last layer 201 | kwargs = {'log': False} 202 | l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs) 203 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()} 204 | losses.update(l_dict) 205 | 206 | return losses 207 | 208 | 209 | --------------------------------------------------------------------------------
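The merge() helper in datasets/coco_eval.py above gathers per-process image ids and per-image evaluation arrays, concatenates them along the image axis, and de-duplicates the ids so that pycocotools sees every image exactly once. Below is a minimal, self-contained sketch of that idea; the helper name and the toy shapes are illustrative and are not the repository's code.

import numpy as np

def merge_eval_results(per_rank_img_ids, per_rank_eval_imgs):
    # concatenate the image ids gathered from every process
    img_ids = np.concatenate([np.asarray(ids) for ids in per_rank_img_ids])
    # eval arrays have shape [num_cats, num_area_rngs, num_imgs]; join along the image axis
    eval_imgs = np.concatenate(per_rank_eval_imgs, axis=2)
    # keep each image exactly once, in sorted order, and select the matching columns
    img_ids, idx = np.unique(img_ids, return_index=True)
    return img_ids, eval_imgs[..., idx]

# toy example: two "ranks", 1 category x 1 area range, one duplicated image id (3)
a = np.zeros((1, 1, 3))
b = np.ones((1, 1, 3))
ids, evals = merge_eval_results([[1, 2, 3], [3, 4, 5]], [a, b])
print(ids)           # [1 2 3 4 5]
print(evals.shape)   # (1, 1, 5)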
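datasets/refexp2seq.py turns a single annotated COCO image into a pseudo video clip: the image/mask pair is augmented num_frames - 1 times and the resulting frames are shuffled. The following is a minimal sketch of that loop, assuming any callable augmenter in place of ImageToSeqAugmenter; the helper name below is hypothetical.

import random
import numpy as np

def image_to_pseudo_clip(img, masks, augment, num_frames):
    # img: [H, W, 3] array, masks: [1, H, W] array for the single referred object;
    # `augment` is any callable returning a jittered (image, masks) pair, standing in
    # for ImageToSeqAugmenter in the repository.
    seq_images, seq_masks = [img], [masks]
    for _ in range(num_frames - 1):
        im_aug, masks_aug = augment(img, masks)
        seq_images.append(im_aug)
        seq_masks.append(masks_aug)
    # shuffle the frames so the "clip" does not always start from the raw image
    perm = list(range(num_frames))
    random.shuffle(perm)
    return [seq_images[i] for i in perm], [seq_masks[i] for i in perm]

# toy usage with an identity "augmentation"
img = np.zeros((8, 8, 3), dtype=np.uint8)
masks = np.zeros((1, 8, 8), dtype=np.uint8)
clip, clip_masks = image_to_pseudo_clip(img, masks, augment=lambda i, m: (i, m), num_frames=5)
print(len(clip), len(clip_masks))   # 5 5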
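Because every sample has at most one referred object, models/matcher.py does not need a full Hungarian assignment: it sums the frame-averaged class, box, and mask costs and picks the single query with the lowest total cost. A minimal sketch of that selection step follows; the default weights here are only illustrative, since the real values come from the set_cost_* command-line options.

import torch

def min_cost_match(cost_class, cost_bbox, cost_giou, cost_mask, cost_dice,
                   w_class=2.0, w_bbox=5.0, w_giou=2.0, w_mask=2.0, w_dice=5.0):
    # every cost tensor has shape [num_queries, 1] and is already averaged over frames
    C = (w_class * cost_class + w_bbox * cost_bbox + w_giou * cost_giou
         + w_mask * cost_mask + w_dice * cost_dice)
    _, src_ind = torch.min(C, dim=0)   # query with the lowest total cost
    tgt_ind = torch.arange(1)          # there is only one referred target
    return src_ind.long(), tgt_ind.long()

# toy example: 5 queries, one target
torch.manual_seed(0)
costs = [torch.rand(5, 1) for _ in range(5)]
print(min_cost_match(*costs))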
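The __getitem__ of datasets/davis.py builds each training clip from an anchor frame, one or two temporally local neighbours, and globally sampled frames until num_frames indices are collected. The sketch below is a simplified version of that sampling and does not reproduce the exact long-range branching of the original.

import random

def sample_clip_indices(frame_id, vid_len, num_frames):
    # anchor frame plus one neighbour on each side, at a random offset of 1-3 frames
    sample_indx = [frame_id,
                   max(0, frame_id - random.randint(1, 3)),
                   min(vid_len - 1, frame_id + random.randint(1, 3))]
    # pad the clip with frames drawn from the whole video (the original code prefers
    # frames outside the local window when enough of them exist)
    while len(sample_indx) < num_frames:
        sample_indx.append(random.randrange(vid_len))
    sample_indx.sort()
    return sample_indx

print(sample_clip_indices(frame_id=10, vid_len=40, num_frames=5))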
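Both models/matcher.py and models/criterion.py align full-resolution ground-truth masks with predictions made at 1/4 of the padded input resolution by subsampling every mask_out_stride-th pixel, offset by mask_out_stride // 2. A minimal sketch of that downsampling, assuming the masks are already padded to a multiple of the stride:

import torch

def downsample_gt_masks(target_masks, mask_out_stride=4):
    # target_masks: [B, T, H, W], with H and W already padded to a multiple of the stride
    start = mask_out_stride // 2
    out = target_masks[:, :, start::mask_out_stride, start::mask_out_stride]
    assert out.size(2) * mask_out_stride == target_masks.size(2)
    assert out.size(3) * mask_out_stride == target_masks.size(3)
    return out

masks = torch.rand(2, 5, 64, 96) > 0.5           # batch of 2 clips, 5 frames each
print(downsample_gt_masks(masks.float()).shape)  # torch.Size([2, 5, 16, 24])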