├── .gitignore
├── LICENSE.md
├── README.md
├── demo
│   ├── __init__.py
│   ├── demo_automatic.py
│   ├── demo_gradio.py
│   └── demo_with_text.py
├── deva
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── static_dataset.py
│   │   ├── tps.py
│   │   ├── utils.py
│   │   └── vos_dataset.py
│   ├── ext
│   │   ├── LightHQSAM
│   │   │   ├── __init__.py
│   │   │   ├── setup_light_hqsam.py
│   │   │   └── tiny_vit_sam.py
│   │   ├── MobileSAM
│   │   │   ├── __init__.py
│   │   │   ├── setup_mobile_sam.py
│   │   │   └── tiny_vit_sam.py
│   │   ├── SAM
│   │   │   ├── __init__.py
│   │   │   └── automatic_mask_generator.py
│   │   ├── __init__.py
│   │   ├── automatic_processor.py
│   │   ├── automatic_sam.py
│   │   ├── ext_eval_args.py
│   │   ├── grounding_dino.py
│   │   └── with_text_processor.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── consensus_associated.py
│   │   ├── consensus_automatic.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── detection_video_reader.py
│   │   │   ├── referring_test_datasets.py
│   │   │   ├── saliency_test_datasets.py
│   │   │   ├── simple_video_reader.py
│   │   │   ├── video_reader.py
│   │   │   ├── vos_test_datasets.py
│   │   │   └── vps_test_datasets.py
│   │   ├── demo_utils.py
│   │   ├── eval_args.py
│   │   ├── frame_utils.py
│   │   ├── image_feature_store.py
│   │   ├── inference_core.py
│   │   ├── kv_memory_store.py
│   │   ├── memory_manager.py
│   │   ├── object_info.py
│   │   ├── object_manager.py
│   │   ├── object_utils.py
│   │   ├── postprocess_unsup_davis17.py
│   │   ├── result_utils.py
│   │   └── segment_merging.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── big_modules.py
│   │   ├── cbam.py
│   │   ├── group_modules.py
│   │   ├── losses.py
│   │   ├── memory_utils.py
│   │   ├── modules.py
│   │   ├── network.py
│   │   ├── resnet.py
│   │   └── trainer.py
│   ├── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── burst_test.txt
│   │   ├── burst_val.txt
│   │   ├── configuration.py
│   │   ├── davis_subset.txt
│   │   ├── image_saver.py
│   │   ├── load_subset.py
│   │   ├── log_integrator.py
│   │   ├── logger.py
│   │   ├── palette.py
│   │   ├── pano_utils.py
│   │   ├── referring-youtubevos-val.txt
│   │   ├── tensor_utils.py
│   │   ├── vipseg_categories.py
│   │   └── yv_subset.txt
│   └── vps_metrics
│       ├── README.md
│       ├── __init__.py
│       ├── eval_stq_vipseg.py
│       ├── eval_vpq_vipseg.py
│       ├── segmentation_and_tracking_quality.py
│       └── stuff_merging.py
├── docs
│   ├── CUSTOM.md
│   ├── DEMO.md
│   ├── EVALUATION.md
│   ├── TRAINING.md
│   ├── index.html
│   └── style.css
├── evaluation
│   ├── __init__.py
│   ├── eval_ref_davis.py
│   ├── eval_ref_youtubevos.py
│   ├── eval_saliency.py
│   ├── eval_vos.py
│   └── eval_with_detections.py
├── example
│   ├── vipseg
│   │   ├── images
│   │   │   └── 12_1mWNahzcsAc
│   │   │       ├── 00001255.jpg
│   │   │       ├── 00001258.jpg
│   │   │       ├── 00001261.jpg
│   │   │       └── 00001264.jpg
│   │   └── source
│   │       └── 12_1mWNahzcsAc
│   │           ├── 00001255.json
│   │           ├── 00001255.png
│   │           ├── 00001258.json
│   │           ├── 00001258.png
│   │           ├── 00001261.json
│   │           ├── 00001261.png
│   │           ├── 00001264.json
│   │           └── 00001264.png
│   └── vos
│       ├── Annotations
│       │   └── bmx-trees
│       │       └── 00000.png
│       └── JPEGImages
│           └── bmx-trees
│               ├── 00000.jpg
│               ├── 00001.jpg
│               ├── 00002.jpg
│               └── 00003.jpg
├── pyproject.toml
└── scripts
    ├── __init__.py
    ├── download_datasets.py
    ├── download_models.sh
    ├── merge_burst_json.py
    ├── merge_multi_scale.py
    └── vipseg
        ├── change2_720p.py
        └── create_panoptic_video_labels.py
/.gitignore: -------------------------------------------------------------------------------- 1 | run_*.sh 2 | log/ 3 | saves 4 | saves/ 5 | output/ 6 | .vscode/ 7 | example/*/output 8 | example/*/postprocessed 9 | example/videos 10 | example/output 11 | gradio_cached_examples/ 12 | flagged/ 13 | *.pth 14 | *.pt 15 | 16 | # Byte-compiled / optimized / DLL files 17 | __pycache__/ 18 | *.py[cod] 19 | *$py.class 20 | 21 | # C extensions 22 | *.so 23 | 24 | # Distribution / packaging 25 | .Python 26 | build/ 27 | develop-eggs/ 28 | dist/ 29 | downloads/ 30 | eggs/ 31 
| .eggs/ 32 | lib/ 33 | lib64/ 34 | parts/ 35 | sdist/ 36 | var/ 37 | wheels/ 38 | pip-wheel-metadata/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # PyInstaller 46 | # Usually these files are written by a python script from a template 47 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 48 | *.manifest 49 | *.spec 50 | 51 | # Installer logs 52 | pip-log.txt 53 | pip-delete-this-directory.txt 54 | 55 | # Unit test / coverage reports 56 | htmlcov/ 57 | .tox/ 58 | .nox/ 59 | .coverage 60 | .coverage.* 61 | .cache 62 | nosetests.xml 63 | coverage.xml 64 | *.cover 65 | *.py,cover 66 | .hypothesis/ 67 | .pytest_cache/ 68 | 69 | # Translations 70 | *.mo 71 | *.pot 72 | 73 | # Django stuff: 74 | *.log 75 | local_settings.py 76 | db.sqlite3 77 | db.sqlite3-journal 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # IPython 96 | profile_default/ 97 | ipython_config.py 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # pipenv 103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 106 | # install all needed dependencies. 107 | #Pipfile.lock 108 | 109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 110 | __pypackages__/ 111 | 112 | # Celery stuff 113 | celerybeat-schedule 114 | celerybeat.pid 115 | 116 | # SageMath parsed files 117 | *.sage.py 118 | 119 | # Environments 120 | .env 121 | .venv 122 | env/ 123 | venv/ 124 | ENV/ 125 | env.bak/ 126 | venv.bak/ 127 | 128 | # Spyder project settings 129 | .spyderproject 130 | .spyproject 131 | 132 | # Rope project settings 133 | .ropeproject 134 | 135 | # mkdocs documentation 136 | /site 137 | 138 | # mypy 139 | .mypy_cache/ 140 | .dmypy.json 141 | dmypy.json 142 | 143 | # Pyre type checker 144 | .pyre/ 145 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Creative Commons License
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License. 2 | -------------------------------------------------------------------------------- /demo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/demo/__init__.py -------------------------------------------------------------------------------- /demo/demo_automatic.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | from argparse import ArgumentParser 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import numpy as np 8 | 9 | from deva.inference.inference_core import DEVAInferenceCore 10 | from deva.inference.data.simple_video_reader import SimpleVideoReader, no_collate 11 | from deva.inference.result_utils import ResultSaver 12 | from deva.inference.eval_args import add_common_eval_args, get_model_and_config 13 | from deva.inference.demo_utils import flush_buffer 14 | from deva.ext.ext_eval_args import add_ext_eval_args, add_auto_default_args 15 | from deva.ext.automatic_sam import get_sam_model 16 | from deva.ext.automatic_processor import process_frame_automatic as process_frame 17 | 18 | from tqdm import tqdm 19 | import json 20 | 21 | if __name__ == '__main__': 22 | torch.autograd.set_grad_enabled(False) 23 | 24 | # for id2rgb 25 | np.random.seed(42) 26 | """ 27 | Arguments loading 28 | """ 29 | parser = ArgumentParser() 30 | 31 | add_common_eval_args(parser) 32 | add_ext_eval_args(parser) 33 | add_auto_default_args(parser) 34 | deva_model, cfg, args = get_model_and_config(parser) 35 | sam_model = get_sam_model(cfg, 'cuda') 36 | """ 37 | Temporal setting 38 | """ 39 | cfg['temporal_setting'] = args.temporal_setting.lower() 40 | assert cfg['temporal_setting'] in ['semionline', 'online'] 41 | 42 | # get data 43 | video_reader = SimpleVideoReader(cfg['img_path']) 44 | loader = DataLoader(video_reader, batch_size=None, collate_fn=no_collate, num_workers=8) 45 | out_path = cfg['output'] 46 | 47 | # Start eval 48 | vid_length = len(loader) 49 | # no need to count usage for LT if the video is not that long anyway 50 | cfg['enable_long_term_count_usage'] = ( 51 | cfg['enable_long_term'] 52 | and (vid_length / (cfg['max_mid_term_frames'] - cfg['min_mid_term_frames']) * 53 | cfg['num_prototypes']) >= cfg['max_long_term_elements']) 54 | 55 | print('Configuration:', cfg) 56 | 57 | deva = DEVAInferenceCore(deva_model, config=cfg) 58 | deva.next_voting_frame = args.num_voting_frames - 1 59 | deva.enabled_long_id() 60 | result_saver = ResultSaver(out_path, None, dataset='demo', object_manager=deva.object_manager) 61 | 62 | with torch.cuda.amp.autocast(enabled=args.amp): 63 | for ti, (frame, im_path) in enumerate(tqdm(loader)): 64 | process_frame(deva, sam_model, im_path, result_saver, ti, image_np=frame) 65 | flush_buffer(deva, result_saver) 66 | result_saver.end() 67 | 68 | # save this as a video-level json 69 | with open(path.join(out_path, 'pred.json'), 'w') as f: 70 | json.dump(result_saver.video_json, f, indent=4) # prettier json 71 | -------------------------------------------------------------------------------- /demo/demo_with_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | from argparse import ArgumentParser 4 | 5 | import torch 6 | from torch.utils.data 
import DataLoader 7 | import numpy as np 8 | 9 | from deva.inference.inference_core import DEVAInferenceCore 10 | from deva.inference.data.simple_video_reader import SimpleVideoReader, no_collate 11 | from deva.inference.result_utils import ResultSaver 12 | from deva.inference.eval_args import add_common_eval_args, get_model_and_config 13 | from deva.inference.demo_utils import flush_buffer 14 | from deva.ext.ext_eval_args import add_ext_eval_args, add_text_default_args 15 | from deva.ext.grounding_dino import get_grounding_dino_model 16 | from deva.ext.with_text_processor import process_frame_with_text as process_frame 17 | 18 | from tqdm import tqdm 19 | import json 20 | 21 | if __name__ == '__main__': 22 | torch.autograd.set_grad_enabled(False) 23 | 24 | # for id2rgb 25 | np.random.seed(42) 26 | """ 27 | Arguments loading 28 | """ 29 | parser = ArgumentParser() 30 | 31 | add_common_eval_args(parser) 32 | add_ext_eval_args(parser) 33 | add_text_default_args(parser) 34 | deva_model, cfg, args = get_model_and_config(parser) 35 | gd_model, sam_model = get_grounding_dino_model(cfg, 'cuda') 36 | """ 37 | Temporal setting 38 | """ 39 | cfg['temporal_setting'] = args.temporal_setting.lower() 40 | assert cfg['temporal_setting'] in ['semionline', 'online'] 41 | 42 | # get data 43 | video_reader = SimpleVideoReader(cfg['img_path']) 44 | loader = DataLoader(video_reader, batch_size=None, collate_fn=no_collate, num_workers=8) 45 | out_path = cfg['output'] 46 | 47 | # Start eval 48 | vid_length = len(loader) 49 | # no need to count usage for LT if the video is not that long anyway 50 | cfg['enable_long_term_count_usage'] = ( 51 | cfg['enable_long_term'] 52 | and (vid_length / (cfg['max_mid_term_frames'] - cfg['min_mid_term_frames']) * 53 | cfg['num_prototypes']) >= cfg['max_long_term_elements']) 54 | 55 | print('Configuration:', cfg) 56 | 57 | deva = DEVAInferenceCore(deva_model, config=cfg) 58 | deva.next_voting_frame = cfg['num_voting_frames'] - 1 59 | deva.enabled_long_id() 60 | result_saver = ResultSaver(out_path, None, dataset='demo', object_manager=deva.object_manager) 61 | 62 | with torch.cuda.amp.autocast(enabled=cfg['amp']): 63 | for ti, (frame, im_path) in enumerate(tqdm(loader)): 64 | process_frame(deva, gd_model, sam_model, im_path, result_saver, ti, image_np=frame) 65 | flush_buffer(deva, result_saver) 66 | result_saver.end() 67 | 68 | # save this as a video-level json 69 | with open(path.join(out_path, 'pred.json'), 'w') as f: 70 | json.dump(result_saver.video_json, f, indent=4) # prettier json 71 | -------------------------------------------------------------------------------- /deva/__init__.py: -------------------------------------------------------------------------------- 1 | from deva.inference.inference_core import DEVAInferenceCore 2 | from deva.model.network import DEVA -------------------------------------------------------------------------------- /deva/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/dataset/__init__.py -------------------------------------------------------------------------------- /deva/dataset/static_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import torch 5 | from torch.utils.data.dataset import Dataset 6 | from torchvision import transforms 7 | from torchvision.transforms import 
InterpolationMode 8 | from PIL import Image 9 | import numpy as np 10 | 11 | from deva.dataset.utils import im_normalization, im_mean, reseed 12 | from deva.dataset.tps import random_tps_warp 13 | 14 | 15 | class StaticTransformDataset(Dataset): 16 | """ 17 | Generate pseudo VOS data by applying random transforms on static images. 18 | Single-object only. 19 | 20 | parameters is a list of tuples (data_root, how data is structured (0 or 1), and sample multiplier) 21 | 22 | Method 0 - FSS style (class/1.jpg class/1.png) 23 | Method 1 - Others style (XXX.jpg XXX.png) 24 | """ 25 | def __init__(self, parameters, *, size=384, num_frames=3, max_num_obj=1): 26 | self.num_frames = num_frames 27 | self.max_num_obj = max_num_obj 28 | self.size = size 29 | 30 | self.im_list = [] 31 | for parameter in parameters: 32 | root, method, multiplier = parameter 33 | if method == 0: 34 | # Get images 35 | classes = os.listdir(root) 36 | for c in classes: 37 | imgs = os.listdir(path.join(root, c)) 38 | jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()] 39 | 40 | joint_list = [path.join(root, c, im) for im in jpg_list] 41 | self.im_list.extend(joint_list * multiplier) 42 | 43 | elif method == 1: 44 | self.im_list.extend( 45 | [path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier) 46 | 47 | print(f'{len(self.im_list)} images found.') 48 | 49 | # These set of transform is the same for im/gt pairs, but different among the 3 sampled frames 50 | self.pair_im_lone_transform = transforms.Compose([ 51 | transforms.ColorJitter(0.1, 0.05, 0.05, 0), 52 | ]) 53 | 54 | self.pair_im_dual_transform = transforms.Compose([ 55 | transforms.RandomAffine(degrees=20, 56 | scale=(0.5, 2.0), 57 | shear=10, 58 | interpolation=InterpolationMode.BICUBIC, 59 | fill=im_mean), 60 | transforms.Resize(self.size, InterpolationMode.BICUBIC, antialias=True), 61 | transforms.RandomCrop((self.size, self.size), pad_if_needed=True, fill=im_mean), 62 | ]) 63 | 64 | self.pair_gt_dual_transform = transforms.Compose([ 65 | transforms.RandomAffine( 66 | degrees=20, 67 | scale=(0.5, 2.0), 68 | shear=10, 69 | # don't know why I used bicubic here. 
70 | # Since GT is binary it shouldn't matter much 71 | interpolation=InterpolationMode.BICUBIC, 72 | fill=0), 73 | transforms.Resize(self.size, InterpolationMode.NEAREST), 74 | transforms.RandomCrop((self.size, self.size), pad_if_needed=True, fill=0), 75 | ]) 76 | 77 | # These transform are the same for all pairs in the sampled sequence 78 | self.all_im_lone_transform = transforms.Compose([ 79 | transforms.ColorJitter(0.1, 0.05, 0.05, 0.05), 80 | transforms.RandomGrayscale(0.05), 81 | ]) 82 | 83 | self.all_im_dual_transform = transforms.Compose([ 84 | transforms.RandomAffine(degrees=0, scale=(0.5, 2.0), fill=im_mean), 85 | transforms.RandomHorizontalFlip(), 86 | ]) 87 | 88 | self.all_gt_dual_transform = transforms.Compose([ 89 | transforms.RandomAffine(degrees=0, scale=(0.5, 2.0), fill=0), 90 | transforms.RandomHorizontalFlip(), 91 | ]) 92 | 93 | # Final transform without randomness 94 | self.final_im_transform = transforms.Compose([ 95 | transforms.ToTensor(), 96 | im_normalization, 97 | ]) 98 | 99 | self.final_gt_transform = transforms.Compose([ 100 | transforms.ToTensor(), 101 | ]) 102 | 103 | def _get_sample(self, idx): 104 | im = Image.open(self.im_list[idx]).convert('RGB') 105 | gt = Image.open(self.im_list[idx][:-3] + 'png').convert('L') 106 | 107 | sequence_seed = np.random.randint(2147483647) 108 | 109 | images = [] 110 | masks = [] 111 | for _ in range(self.num_frames): 112 | reseed(sequence_seed) 113 | this_im = self.all_im_dual_transform(im) 114 | this_im = self.all_im_lone_transform(this_im) 115 | reseed(sequence_seed) 116 | this_gt = self.all_gt_dual_transform(gt) 117 | 118 | pairwise_seed = np.random.randint(2147483647) 119 | reseed(pairwise_seed) 120 | this_im = self.pair_im_dual_transform(this_im) 121 | this_im = self.pair_im_lone_transform(this_im) 122 | reseed(pairwise_seed) 123 | this_gt = self.pair_gt_dual_transform(this_gt) 124 | 125 | # Use TPS only some of the times 126 | # Not because TPS is bad -- just that it is too slow and I need to speed up data loading 127 | if np.random.rand() < 0.33: 128 | this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02) 129 | 130 | this_im = self.final_im_transform(this_im) 131 | this_gt = self.final_gt_transform(this_gt) 132 | 133 | images.append(this_im) 134 | masks.append(this_gt) 135 | 136 | images = torch.stack(images, 0) 137 | masks = torch.stack(masks, 0) 138 | 139 | return images, masks.numpy() 140 | 141 | def __getitem__(self, idx): 142 | additional_objects = np.random.randint(self.max_num_obj) 143 | indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)] 144 | 145 | merged_images = None 146 | merged_masks = np.zeros((self.num_frames, self.size, self.size), dtype=np.int64) 147 | 148 | for i, list_id in enumerate(indices): 149 | images, masks = self._get_sample(list_id) 150 | if merged_images is None: 151 | merged_images = images 152 | else: 153 | merged_images = merged_images * (1 - masks) + images * masks 154 | merged_masks[masks[:, 0] > 0.5] = (i + 1) 155 | 156 | masks = merged_masks 157 | 158 | labels = np.unique(masks[0]) 159 | # Remove background 160 | labels = labels[labels != 0] 161 | target_objects = labels.tolist() 162 | 163 | # Generate one-hot ground-truth 164 | cls_gt = np.zeros((self.num_frames, self.size, self.size), dtype=np.int64) 165 | first_frame_gt = np.zeros((1, self.max_num_obj, self.size, self.size), dtype=np.int64) 166 | for i, l in enumerate(target_objects): 167 | this_mask = (masks == l) 168 | cls_gt[this_mask] = i + 1 169 | first_frame_gt[0, i] = (this_mask[0]) 170 | 
cls_gt = np.expand_dims(cls_gt, 1) 171 | 172 | info = {} 173 | info['name'] = self.im_list[idx] 174 | info['num_objects'] = max(1, len(target_objects)) 175 | 176 | # 1 if object exist, 0 otherwise 177 | selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)] 178 | selector = torch.FloatTensor(selector) 179 | 180 | data = { 181 | 'rgb': merged_images, 182 | 'first_frame_gt': first_frame_gt, 183 | 'cls_gt': cls_gt, 184 | 'selector': selector, 185 | 'info': info 186 | } 187 | 188 | return data 189 | 190 | def __len__(self): 191 | return len(self.im_list) 192 | -------------------------------------------------------------------------------- /deva/dataset/tps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import cv2 4 | import thinplate as tps 5 | 6 | cv2.setNumThreads(0) 7 | 8 | def pick_random_points(h, w, n_samples): 9 | y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False) 10 | x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False) 11 | return y_idx/h, x_idx/w 12 | 13 | 14 | def warp_dual_cv(img, mask, c_src, c_dst): 15 | dshape = img.shape 16 | theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True) 17 | grid = tps.tps_grid(theta, c_dst, dshape) 18 | mapx, mapy = tps.tps_grid_to_remap(grid, img.shape) 19 | return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST) 20 | 21 | 22 | def random_tps_warp(img, mask, scale, n_ctrl_pts=12): 23 | """ 24 | Apply a random TPS warp of the input image and mask 25 | Uses randomness from numpy 26 | """ 27 | img = np.asarray(img) 28 | mask = np.asarray(mask) 29 | 30 | h, w = mask.shape 31 | points = pick_random_points(h, w, n_ctrl_pts) 32 | c_src = np.stack(points, 1) 33 | c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape) 34 | warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst) 35 | 36 | return Image.fromarray(warp_im), Image.fromarray(warp_gt) 37 | 38 | -------------------------------------------------------------------------------- /deva/dataset/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torchvision.transforms as transforms 5 | 6 | im_mean = (124, 116, 104) 7 | 8 | im_normalization = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 9 | 10 | inv_im_trans = transforms.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], 11 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225]) 12 | 13 | 14 | def reseed(seed): 15 | random.seed(seed) 16 | torch.manual_seed(seed) 17 | 18 | 19 | def all_to_onehot(masks, labels): 20 | if len(masks.shape) == 3: 21 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8) 22 | else: 23 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8) 24 | 25 | for ni, l in enumerate(labels): 26 | Ms[ni] = (masks == l).astype(np.uint8) 27 | 28 | return Ms 29 | -------------------------------------------------------------------------------- /deva/ext/LightHQSAM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/ext/LightHQSAM/__init__.py -------------------------------------------------------------------------------- /deva/ext/LightHQSAM/setup_light_hqsam.py: 
-------------------------------------------------------------------------------- 1 | from deva.ext.LightHQSAM.tiny_vit_sam import TinyViT 2 | from segment_anything.modeling import MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer 3 | 4 | def setup_model(): 5 | prompt_embed_dim = 256 6 | image_size = 1024 7 | vit_patch_size = 16 8 | image_embedding_size = image_size // vit_patch_size 9 | mobile_sam = Sam( 10 | image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, 11 | embed_dims=[64, 128, 160, 320], 12 | depths=[2, 2, 6, 2], 13 | num_heads=[2, 4, 5, 10], 14 | window_sizes=[7, 7, 14, 7], 15 | mlp_ratio=4., 16 | drop_rate=0., 17 | drop_path_rate=0.0, 18 | use_checkpoint=False, 19 | mbconv_expand_ratio=4.0, 20 | local_conv_size=3, 21 | layer_lr_decay=0.8 22 | ), 23 | prompt_encoder=PromptEncoder( 24 | embed_dim=prompt_embed_dim, 25 | image_embedding_size=(image_embedding_size, image_embedding_size), 26 | input_image_size=(image_size, image_size), 27 | mask_in_chans=16, 28 | ), 29 | mask_decoder=MaskDecoderHQ( 30 | num_multimask_outputs=3, 31 | transformer=TwoWayTransformer( 32 | depth=2, 33 | embedding_dim=prompt_embed_dim, 34 | mlp_dim=2048, 35 | num_heads=8, 36 | ), 37 | transformer_dim=prompt_embed_dim, 38 | iou_head_depth=3, 39 | iou_head_hidden_dim=256, 40 | vit_dim=160, 41 | ), 42 | pixel_mean=[123.675, 116.28, 103.53], 43 | pixel_std=[58.395, 57.12, 57.375], 44 | ) 45 | return mobile_sam -------------------------------------------------------------------------------- /deva/ext/MobileSAM/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/ext/MobileSAM/__init__.py -------------------------------------------------------------------------------- /deva/ext/MobileSAM/setup_mobile_sam.py: -------------------------------------------------------------------------------- 1 | # https://github.com/IDEA-Research/Grounded-Segment-Anything/blob/main/EfficientSAM/MobileSAM/setup_mobile_sam.py 2 | 3 | from deva.ext.MobileSAM.tiny_vit_sam import TinyViT 4 | from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer 5 | 6 | 7 | def setup_model(): 8 | prompt_embed_dim = 256 9 | image_size = 1024 10 | vit_patch_size = 16 11 | image_embedding_size = image_size // vit_patch_size 12 | mobile_sam = Sam( 13 | image_encoder=TinyViT(img_size=1024, 14 | in_chans=3, 15 | num_classes=1000, 16 | embed_dims=[64, 128, 160, 320], 17 | depths=[2, 2, 6, 2], 18 | num_heads=[2, 4, 5, 10], 19 | window_sizes=[7, 7, 14, 7], 20 | mlp_ratio=4., 21 | drop_rate=0., 22 | drop_path_rate=0.0, 23 | use_checkpoint=False, 24 | mbconv_expand_ratio=4.0, 25 | local_conv_size=3, 26 | layer_lr_decay=0.8), 27 | prompt_encoder=PromptEncoder( 28 | embed_dim=prompt_embed_dim, 29 | image_embedding_size=(image_embedding_size, image_embedding_size), 30 | input_image_size=(image_size, image_size), 31 | mask_in_chans=16, 32 | ), 33 | mask_decoder=MaskDecoder( 34 | num_multimask_outputs=3, 35 | transformer=TwoWayTransformer( 36 | depth=2, 37 | embedding_dim=prompt_embed_dim, 38 | mlp_dim=2048, 39 | num_heads=8, 40 | ), 41 | transformer_dim=prompt_embed_dim, 42 | iou_head_depth=3, 43 | iou_head_hidden_dim=256, 44 | ), 45 | pixel_mean=[123.675, 116.28, 103.53], 46 | pixel_std=[58.395, 57.12, 57.375], 47 | ) 48 | return mobile_sam -------------------------------------------------------------------------------- /deva/ext/SAM/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/ext/SAM/__init__.py -------------------------------------------------------------------------------- /deva/ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/ext/__init__.py -------------------------------------------------------------------------------- /deva/ext/automatic_processor.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from typing import Dict, List, Optional 3 | 4 | import cv2 5 | import torch 6 | import numpy as np 7 | 8 | from deva.inference.object_info import ObjectInfo 9 | from deva.inference.inference_core import DEVAInferenceCore 10 | from deva.inference.frame_utils import FrameInfo 11 | from deva.inference.result_utils import ResultSaver 12 | from deva.inference.demo_utils import get_input_frame_for_deva 13 | from deva.ext.automatic_sam import auto_segment 14 | from deva.utils.tensor_utils import pad_divide_by, unpad 15 | 16 | from segment_anything import SamAutomaticMaskGenerator 17 | 18 | 19 | def make_segmentation(cfg: Dict, image_np: np.ndarray, forward_mask: Optional[torch.Tensor], 20 | sam_model: SamAutomaticMaskGenerator, min_side: int, 21 | suppress_small_mask: bool) -> (torch.Tensor, List[ObjectInfo]): 22 | mask, segments_info = auto_segment(cfg, sam_model, image_np, forward_mask, min_side, 23 | suppress_small_mask) 24 | return mask, segments_info 25 | 26 | 27 | @torch.inference_mode() 28 | def process_frame_automatic(deva: DEVAInferenceCore, 29 | sam_model: SamAutomaticMaskGenerator, 30 | frame_path: str, 31 | result_saver: ResultSaver, 32 | ti: int, 33 | image_np: np.ndarray = None) -> None: 34 | # image_np, if given, should be in RGB 35 | if image_np is None: 36 | image_np = cv2.imread(frame_path) 37 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) 38 | cfg = deva.config 39 | 40 | h, w = image_np.shape[:2] 41 | new_min_side = cfg['size'] 42 | suppress_small_mask = cfg['suppress_small_objects'] 43 | need_resize = new_min_side > 0 44 | image = get_input_frame_for_deva(image_np, new_min_side) 45 | 46 | frame_name = path.basename(frame_path) 47 | frame_info = FrameInfo(image, None, None, ti, { 48 | 'frame': [frame_name], 49 | 'shape': [h, w], 50 | }) 51 | 52 | if cfg['temporal_setting'] == 'semionline': 53 | if ti + cfg['num_voting_frames'] > deva.next_voting_frame: 54 | # getting a forward mask 55 | if deva.memory.engaged: 56 | forward_mask = estimate_forward_mask(deva, image) 57 | else: 58 | forward_mask = None 59 | 60 | mask, segments_info = make_segmentation(cfg, image_np, forward_mask, sam_model, 61 | new_min_side, suppress_small_mask) 62 | frame_info.mask = mask 63 | frame_info.segments_info = segments_info 64 | frame_info.image_np = image_np # for visualization only 65 | # wait for more frames before proceeding 66 | deva.add_to_temporary_buffer(frame_info) 67 | 68 | if ti == deva.next_voting_frame: 69 | # process this clip 70 | this_image = deva.frame_buffer[0].image 71 | this_frame_name = deva.frame_buffer[0].name 72 | this_image_np = deva.frame_buffer[0].image_np 73 | 74 | _, mask, new_segments_info = deva.vote_in_temporary_buffer( 75 | keyframe_selection='first') 76 | prob = deva.incorporate_detection(this_image, 77 | mask, 
78 | new_segments_info, 79 | incremental=True) 80 | deva.next_voting_frame += cfg['detection_every'] 81 | 82 | result_saver.save_mask(prob, 83 | this_frame_name, 84 | need_resize=need_resize, 85 | shape=(h, w), 86 | image_np=this_image_np) 87 | 88 | for frame_info in deva.frame_buffer[1:]: 89 | this_image = frame_info.image 90 | this_frame_name = frame_info.name 91 | this_image_np = frame_info.image_np 92 | prob = deva.step(this_image, None, None) 93 | result_saver.save_mask(prob, 94 | this_frame_name, 95 | need_resize, 96 | shape=(h, w), 97 | image_np=this_image_np) 98 | 99 | deva.clear_buffer() 100 | else: 101 | # standard propagation 102 | prob = deva.step(image, None, None) 103 | result_saver.save_mask(prob, 104 | frame_name, 105 | need_resize=need_resize, 106 | shape=(h, w), 107 | image_np=image_np) 108 | 109 | elif cfg['temporal_setting'] == 'online': 110 | if ti % cfg['detection_every'] == 0: 111 | # incorporate new detections 112 | if deva.memory.engaged: 113 | forward_mask = estimate_forward_mask(deva, image) 114 | else: 115 | forward_mask = None 116 | 117 | mask, segments_info = make_segmentation(cfg, image_np, forward_mask, sam_model, 118 | new_min_side, suppress_small_mask) 119 | frame_info.segments_info = segments_info 120 | prob = deva.incorporate_detection(image, mask, segments_info, incremental=True) 121 | else: 122 | # Run the model on this frame 123 | prob = deva.step(image, None, None) 124 | result_saver.save_mask(prob, 125 | frame_name, 126 | need_resize=need_resize, 127 | shape=(h, w), 128 | image_np=image_np) 129 | 130 | 131 | def estimate_forward_mask(deva: DEVAInferenceCore, image: torch.Tensor): 132 | image, pad = pad_divide_by(image, 16) 133 | image = image.unsqueeze(0) # add the batch dimension 134 | 135 | ms_features = deva.image_feature_store.get_ms_features(deva.curr_ti + 1, image) 136 | key, _, selection = deva.image_feature_store.get_key(deva.curr_ti + 1, image) 137 | prob = deva._segment(key, selection, ms_features) 138 | forward_mask = torch.argmax(prob, dim=0) 139 | forward_mask = unpad(forward_mask, pad) 140 | return forward_mask 141 | -------------------------------------------------------------------------------- /deva/ext/automatic_sam.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/IDEA-Research/Grounded-Segment-Anything 2 | 3 | from typing import Dict, List, Optional 4 | import numpy as np 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from segment_anything import sam_model_registry 11 | from deva.ext.SAM.automatic_mask_generator import SamAutomaticMaskGenerator 12 | from deva.ext.MobileSAM.setup_mobile_sam import setup_model as setup_mobile_sam 13 | from deva.inference.object_info import ObjectInfo 14 | 15 | 16 | def get_sam_model(config: Dict, device: str) -> SamAutomaticMaskGenerator: 17 | variant = config['sam_variant'].lower() 18 | if variant == 'mobile': 19 | MOBILE_SAM_CHECKPOINT_PATH = config['MOBILE_SAM_CHECKPOINT_PATH'] 20 | 21 | # Building Mobile SAM model 22 | checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH) 23 | mobile_sam = setup_mobile_sam() 24 | mobile_sam.load_state_dict(checkpoint, strict=True) 25 | mobile_sam.to(device=device) 26 | auto_sam = SamAutomaticMaskGenerator(mobile_sam, 27 | points_per_side=config['SAM_NUM_POINTS_PER_SIDE'], 28 | points_per_batch=config['SAM_NUM_POINTS_PER_BATCH'], 29 | pred_iou_thresh=config['SAM_PRED_IOU_THRESHOLD']) 30 | elif variant == 'original': 31 | SAM_ENCODER_VERSION = 
config['SAM_ENCODER_VERSION'] 32 | SAM_CHECKPOINT_PATH = config['SAM_CHECKPOINT_PATH'] 33 | 34 | # Building SAM Model and SAM Predictor 35 | sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to( 36 | device=device) 37 | auto_sam = SamAutomaticMaskGenerator(sam, 38 | points_per_side=config['SAM_NUM_POINTS_PER_SIDE'], 39 | points_per_batch=config['SAM_NUM_POINTS_PER_BATCH'], 40 | pred_iou_thresh=config['SAM_PRED_IOU_THRESHOLD']) 41 | else: 42 | raise ValueError(f'Unknown SAM variant: {config["SAM_VARIANT"]}') 43 | 44 | return auto_sam 45 | 46 | 47 | def auto_segment(config: Dict, auto_sam: SamAutomaticMaskGenerator, image: np.ndarray, 48 | forward_mask: Optional[torch.Tensor], min_side: int, 49 | suppress_small_mask: bool) -> (torch.Tensor, List[ObjectInfo]): 50 | """ 51 | config: the global configuration dictionary 52 | image: the image to segment; should be a numpy array; H*W*3; unnormalized (0~255) 53 | forward_mask: the mask used to determine positive/negative points; H*W 54 | 55 | Returns: a torch index mask of the same size as image; H*W 56 | a list of segment info, see object_utils.py for definition 57 | """ 58 | device = auto_sam.predictor.device 59 | 60 | h, w = image.shape[:2] 61 | if min_side > 0: 62 | scale = min_side / min(h, w) 63 | new_h, new_w = int(h * scale), int(w * scale) 64 | else: 65 | new_h, new_w = h, w 66 | 67 | if forward_mask is not None: 68 | # compute positive and negative points 69 | foreground_mask = (forward_mask > 0).float().unsqueeze(0).unsqueeze(0) 70 | foreground_mask = F.interpolate(foreground_mask, 71 | scale_factor=1 / 16, 72 | mode='bilinear', 73 | antialias=True) # blurring 74 | n_per_side = config['SAM_NUM_POINTS_PER_SIDE'] 75 | offset = 1 / (2 * n_per_side) 76 | points_one_side = torch.linspace(offset, 1 - offset, n_per_side, device=device) 77 | points_x = points_one_side.unsqueeze(0).repeat(n_per_side, 1) 78 | points_y = points_one_side.unsqueeze(1).repeat(1, n_per_side) 79 | points = torch.stack([points_x, points_y], dim=-1).unsqueeze(0) 80 | points_label = F.grid_sample(foreground_mask, points * 2 - 1, align_corners=False).view(-1) 81 | points = points.view(-1, 2) 82 | positive_points = points[points_label < 0.01].cpu().numpy() 83 | if len(positive_points) == 0: 84 | output_mask = torch.zeros((new_h, new_w), dtype=torch.int64, device=device) 85 | segments_info = [] 86 | return output_mask, segments_info 87 | # negative_points = points[points_label >= 0.5].cpu().numpy() 88 | negative_points = None # no negative points 89 | mask_data = auto_sam.generate(image, positive_points, negative_points) 90 | else: 91 | mask_data = auto_sam.generate(image) 92 | 93 | curr_id = 1 94 | segments_info = [] 95 | 96 | pred_masks = mask_data['masks'].float() # num masks * H * W 97 | predicted_iou = mask_data["iou_preds"] 98 | 99 | # score mask by their areas 100 | if pred_masks.shape[0] == 0: 101 | output_mask = torch.zeros((new_h, new_w), dtype=torch.int64, device=device) 102 | else: 103 | pred_masks = F.interpolate(pred_masks.unsqueeze(0), (new_h, new_w), mode='bilinear')[0] 104 | 105 | curr_id = 1 106 | if suppress_small_mask: 107 | areas = pred_masks.flatten(-2).sum(-1) 108 | scores = areas.unsqueeze(-1).unsqueeze(-1) 109 | 110 | scored_masks = pred_masks * scores 111 | scored_masks_with_bg = torch.cat( 112 | [torch.zeros((1, *pred_masks.shape[1:]), device=device) + 0.1, scored_masks], dim=0) 113 | output_mask = torch.zeros((new_h, new_w), dtype=torch.int64, device=device) 114 | 115 | # let large mask eats small masks 
(small/tiny/incomplete masks are too common in SAM) 116 | hard_mask = torch.argmax(scored_masks_with_bg, dim=0) 117 | for k in range(scores.shape[0]): 118 | mask_area = (hard_mask == (k + 1)).sum() 119 | original_area = (pred_masks[k] > 0.5).sum() 120 | mask = (hard_mask == (k + 1)) & (pred_masks[k] >= 0.5) 121 | 122 | if mask_area > 0 and original_area > 0 and mask.sum() > 0: 123 | if mask_area / original_area < config['SAM_OVERLAP_THRESHOLD']: 124 | continue 125 | output_mask[mask] = curr_id 126 | segments_info.append(ObjectInfo(id=curr_id, score=predicted_iou[k].item())) 127 | curr_id += 1 128 | else: 129 | # prefer smaller objects 130 | areas = pred_masks.flatten(-2).sum(-1) 131 | scores = (areas.max() * 2 - areas).unsqueeze(-1).unsqueeze(-1) 132 | scored_masks = pred_masks * scores 133 | 134 | # add background channel 135 | scored_masks_with_bg = torch.cat( 136 | [torch.zeros((1, *scored_masks.shape[1:]), device=device) + 0.1, scored_masks], 137 | dim=0) 138 | output_mask = torch.argmax(scored_masks_with_bg, dim=0) 139 | for k in range(scored_masks.shape[0]): 140 | mask = (output_mask == (k + 1)) 141 | if mask.sum() > 0: 142 | segments_info.append(ObjectInfo(id=curr_id, score=predicted_iou[k].item())) 143 | curr_id += 1 144 | 145 | return output_mask, segments_info 146 | -------------------------------------------------------------------------------- /deva/ext/ext_eval_args.py: -------------------------------------------------------------------------------- 1 | # Evaluation arguments for extensions 2 | from argparse import ArgumentParser 3 | 4 | 5 | def add_ext_eval_args(parser: ArgumentParser): 6 | 7 | # Grounded Segment Anything 8 | parser.add_argument('--GROUNDING_DINO_CONFIG_PATH', 9 | default='./saves/GroundingDINO_SwinT_OGC.py') 10 | 11 | parser.add_argument('--GROUNDING_DINO_CHECKPOINT_PATH', 12 | default='./saves/groundingdino_swint_ogc.pth') 13 | 14 | parser.add_argument('--DINO_THRESHOLD', default=0.35, type=float) 15 | parser.add_argument('--DINO_NMS_THRESHOLD', default=0.8, type=float) 16 | 17 | # Segment Anything (SAM) models 18 | parser.add_argument('--SAM_ENCODER_VERSION', default='vit_h') 19 | parser.add_argument('--SAM_CHECKPOINT_PATH', default='./saves/sam_vit_h_4b8939.pth') 20 | 21 | # HQ-SAM 22 | parser.add_argument('--HQ_SAM_CHECKPOINT_PATH', default='./saves/sam_hq_vit_h.pth') 23 | 24 | # Light HQ-SAM 25 | parser.add_argument('--LIGHT_HQ_SAM_CHECKPOINT_PATH', default='./saves/sam_hq_vit_tiny.pth') 26 | 27 | # Mobile SAM 28 | parser.add_argument('--MOBILE_SAM_CHECKPOINT_PATH', default='./saves/mobile_sam.pt') 29 | 30 | # Segment Anything (SAM) parameters 31 | parser.add_argument('--SAM_NUM_POINTS_PER_SIDE', 32 | type=int, 33 | help='Number of points per side for prompting SAM', 34 | default=64) 35 | parser.add_argument('--SAM_NUM_POINTS_PER_BATCH', 36 | type=int, 37 | help='Number of points computed per batch', 38 | default=64) 39 | parser.add_argument('--SAM_PRED_IOU_THRESHOLD', 40 | type=float, 41 | help='(Predicted) IoU threshold for SAM', 42 | default=0.88) 43 | parser.add_argument('--SAM_OVERLAP_THRESHOLD', 44 | type=float, 45 | help='Overlap threshold for overlapped mask suppression in SAM', 46 | default=0.8) 47 | 48 | 49 | def add_text_default_args(parser): 50 | parser.add_argument('--img_path', default='./example/vipseg') 51 | parser.add_argument('--detection_every', type=int, default=5) 52 | parser.add_argument('--num_voting_frames', 53 | default=3, 54 | type=int, 55 | help='Number of frames selected for voting. 
only valid in semionline') 56 | 57 | parser.add_argument('--temporal_setting', default='semionline', help='semionline/online') 58 | parser.add_argument('--max_missed_detection_count', type=int, default=10) 59 | parser.add_argument('--max_num_objects', 60 | default=-1, 61 | type=int, 62 | help='Max. num of objects to keep in memory. -1 for no limit') 63 | parser.add_argument('--prompt', type=str, help='Separate classes with a single fullstop') 64 | parser.add_argument('--sam_variant', default='original', help='mobile/original') 65 | return parser 66 | 67 | 68 | def add_auto_default_args(parser): 69 | parser.add_argument('--img_path', default='./example/vipseg') 70 | parser.add_argument('--detection_every', type=int, default=5) 71 | parser.add_argument('--num_voting_frames', 72 | default=3, 73 | type=int, 74 | help='Number of frames selected for voting. only valid in semionline') 75 | 76 | parser.add_argument('--temporal_setting', default='semionline', help='semionline/online') 77 | parser.add_argument('--max_missed_detection_count', type=int, default=5) 78 | parser.add_argument('--max_num_objects', 79 | default=200, 80 | type=int, 81 | help='Max. num of objects to keep in memory. -1 for no limit') 82 | 83 | parser.add_argument('--sam_variant', default='original', help='mobile/original') 84 | parser.add_argument('--suppress_small_objects', action='store_true') 85 | 86 | return parser -------------------------------------------------------------------------------- /deva/ext/grounding_dino.py: -------------------------------------------------------------------------------- 1 | # Reference: https://github.com/IDEA-Research/Grounded-Segment-Anything 2 | 3 | from typing import Dict, List 4 | import numpy as np 5 | import cv2 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | 11 | try: 12 | from groundingdino.util.inference import Model as GroundingDINOModel 13 | except ImportError: 14 | # not sure why this happens sometimes 15 | from GroundingDINO.groundingdino.util.inference import Model as GroundingDINOModel 16 | from segment_anything import sam_model_registry, SamPredictor 17 | try: 18 | from segment_anything import sam_hq_model_registry 19 | except ImportError: 20 | print("HQ-SAM not found, please install it from https://github.com/SysCV/sam-hq") 21 | from deva.ext.MobileSAM.setup_mobile_sam import setup_model as setup_mobile_sam 22 | try: 23 | from deva.ext.LightHQSAM.setup_light_hqsam import setup_model as setup_light_hqsam 24 | except ImportError: 25 | print("Light HQ-SAM not found, please install it from https://github.com/SysCV/sam-hq") 26 | import numpy as np 27 | import torch 28 | 29 | from deva.inference.object_info import ObjectInfo 30 | 31 | 32 | def get_grounding_dino_model(config: Dict, device: str) -> (GroundingDINOModel, SamPredictor): 33 | GROUNDING_DINO_CONFIG_PATH = config['GROUNDING_DINO_CONFIG_PATH'] 34 | GROUNDING_DINO_CHECKPOINT_PATH = config['GROUNDING_DINO_CHECKPOINT_PATH'] 35 | 36 | gd_model = GroundingDINOModel(model_config_path=GROUNDING_DINO_CONFIG_PATH, 37 | model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, 38 | device=device) 39 | 40 | # Building SAM Model and SAM Predictor 41 | variant = config['sam_variant'].lower() 42 | if variant == 'mobile': 43 | MOBILE_SAM_CHECKPOINT_PATH = config['MOBILE_SAM_CHECKPOINT_PATH'] 44 | 45 | # Building Mobile SAM model 46 | checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH) 47 | mobile_sam = setup_mobile_sam() 48 | mobile_sam.load_state_dict(checkpoint, strict=True) 49 | 
mobile_sam.to(device=device) 50 | sam = SamPredictor(mobile_sam) 51 | elif variant == 'original': 52 | SAM_ENCODER_VERSION = config['SAM_ENCODER_VERSION'] 53 | SAM_CHECKPOINT_PATH = config['SAM_CHECKPOINT_PATH'] 54 | 55 | sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to( 56 | device=device) 57 | sam = SamPredictor(sam) 58 | elif variant == 'sam_hq': 59 | # Building HQ-SAM model with better Mask quality 60 | SAM_ENCODER_VERSION = config['SAM_ENCODER_VERSION'] 61 | HQ_SAM_CHECKPOINT_PATH = config['HQ_SAM_CHECKPOINT_PATH'] 62 | sam_hq = sam_hq_model_registry[SAM_ENCODER_VERSION](checkpoint=HQ_SAM_CHECKPOINT_PATH).to( 63 | device=device) 64 | sam = SamPredictor(sam_hq) 65 | elif variant == 'sam_hq_light': 66 | LIGHT_HQ_SAM_CHECKPOINT_PATH = config['LIGHT_HQ_SAM_CHECKPOINT_PATH'] 67 | 68 | # Building Light HQ-SAM model with good Mask quality and efficiency 69 | checkpoint = torch.load(LIGHT_HQ_SAM_CHECKPOINT_PATH) 70 | light_hq_sam = setup_light_hqsam() 71 | light_hq_sam.load_state_dict(checkpoint, strict=True) 72 | light_hq_sam.to(device=device) 73 | sam = SamPredictor(light_hq_sam) 74 | 75 | return gd_model, sam 76 | 77 | 78 | def segment_with_text(config: Dict, gd_model: GroundingDINOModel, sam: SamPredictor, 79 | image: np.ndarray, prompts: List[str], 80 | min_side: int) -> (torch.Tensor, List[ObjectInfo]): 81 | """ 82 | config: the global configuration dictionary 83 | image: the image to segment; should be a numpy array; H*W*3; unnormalized (0~255) 84 | prompts: list of class names 85 | 86 | Returns: a torch index mask of the same size as image; H*W 87 | a list of segment info, see object_utils.py for definition 88 | """ 89 | 90 | BOX_THRESHOLD = TEXT_THRESHOLD = config['DINO_THRESHOLD'] 91 | NMS_THRESHOLD = config['DINO_NMS_THRESHOLD'] 92 | 93 | sam.set_image(image, image_format='RGB') 94 | 95 | # detect objects 96 | # GroundingDINO uses BGR 97 | detections = gd_model.predict_with_classes(image=cv2.cvtColor(image, cv2.COLOR_RGB2BGR), 98 | classes=prompts, 99 | box_threshold=BOX_THRESHOLD, 100 | text_threshold=TEXT_THRESHOLD) 101 | nms_idx = torchvision.ops.nms(torch.from_numpy(detections.xyxy), 102 | torch.from_numpy(detections.confidence), 103 | NMS_THRESHOLD).numpy().tolist() 104 | 105 | detections.xyxy = detections.xyxy[nms_idx] 106 | detections.confidence = detections.confidence[nms_idx] 107 | detections.class_id = detections.class_id[nms_idx] 108 | 109 | result_masks = [] 110 | for box in detections.xyxy: 111 | masks, scores, _ = sam.predict(box=box, multimask_output=True) 112 | index = np.argmax(scores) 113 | result_masks.append(masks[index]) 114 | 115 | detections.mask = np.array(result_masks) 116 | 117 | h, w = image.shape[:2] 118 | if min_side > 0: 119 | scale = min_side / min(h, w) 120 | new_h, new_w = int(h * scale), int(w * scale) 121 | else: 122 | new_h, new_w = h, w 123 | 124 | output_mask = torch.zeros((new_h, new_w), dtype=torch.int64, device=gd_model.device) 125 | curr_id = 1 126 | segments_info = [] 127 | 128 | # sort by descending area to preserve the smallest object 129 | for i in np.flip(np.argsort(detections.area)): 130 | mask = detections.mask[i] 131 | confidence = detections.confidence[i] 132 | class_id = detections.class_id[i] 133 | mask = torch.from_numpy(mask.astype(np.float32)) 134 | mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), (new_h, new_w), mode='bilinear')[0, 0] 135 | mask = (mask > 0.5).float() 136 | 137 | if mask.sum() > 0: 138 | output_mask[mask > 0] = curr_id 139 | segments_info.append(ObjectInfo(id=curr_id, 
category_id=class_id, score=confidence)) 140 | curr_id += 1 141 | 142 | return output_mask, segments_info 143 | -------------------------------------------------------------------------------- /deva/ext/with_text_processor.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from typing import Dict, List 3 | 4 | import cv2 5 | import torch 6 | import numpy as np 7 | 8 | from deva.inference.object_info import ObjectInfo 9 | from deva.inference.inference_core import DEVAInferenceCore 10 | from deva.inference.frame_utils import FrameInfo 11 | from deva.inference.result_utils import ResultSaver 12 | from deva.inference.demo_utils import get_input_frame_for_deva 13 | from deva.ext.grounding_dino import segment_with_text 14 | try: 15 | from groundingdino.util.inference import Model as GroundingDINOModel 16 | except ImportError: 17 | # not sure why this happens sometimes 18 | from GroundingDINO.groundingdino.util.inference import Model as GroundingDINOModel 19 | from segment_anything import SamPredictor 20 | 21 | 22 | def make_segmentation_with_text(cfg: Dict, image_np: np.ndarray, gd_model: GroundingDINOModel, 23 | sam_model: SamPredictor, prompts: List[str], 24 | min_side: int) -> (torch.Tensor, List[ObjectInfo]): 25 | mask, segments_info = segment_with_text(cfg, gd_model, sam_model, image_np, prompts, min_side) 26 | return mask, segments_info 27 | 28 | 29 | @torch.inference_mode() 30 | def process_frame_with_text(deva: DEVAInferenceCore, 31 | gd_model: GroundingDINOModel, 32 | sam_model: SamPredictor, 33 | frame_path: str, 34 | result_saver: ResultSaver, 35 | ti: int, 36 | image_np: np.ndarray = None) -> None: 37 | # image_np, if given, should be in RGB 38 | if image_np is None: 39 | image_np = cv2.imread(frame_path) 40 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB) 41 | cfg = deva.config 42 | raw_prompt = cfg['prompt'] 43 | prompts = raw_prompt.split('.') 44 | 45 | h, w = image_np.shape[:2] 46 | new_min_side = cfg['size'] 47 | need_resize = new_min_side > 0 48 | image = get_input_frame_for_deva(image_np, new_min_side) 49 | 50 | frame_name = path.basename(frame_path) 51 | frame_info = FrameInfo(image, None, None, ti, { 52 | 'frame': [frame_name], 53 | 'shape': [h, w], 54 | }) 55 | 56 | if cfg['temporal_setting'] == 'semionline': 57 | if ti + cfg['num_voting_frames'] > deva.next_voting_frame: 58 | mask, segments_info = make_segmentation_with_text(cfg, image_np, gd_model, sam_model, 59 | prompts, new_min_side) 60 | frame_info.mask = mask 61 | frame_info.segments_info = segments_info 62 | frame_info.image_np = image_np # for visualization only 63 | # wait for more frames before proceeding 64 | deva.add_to_temporary_buffer(frame_info) 65 | 66 | if ti == deva.next_voting_frame: 67 | # process this clip 68 | this_image = deva.frame_buffer[0].image 69 | this_frame_name = deva.frame_buffer[0].name 70 | this_image_np = deva.frame_buffer[0].image_np 71 | 72 | _, mask, new_segments_info = deva.vote_in_temporary_buffer( 73 | keyframe_selection='first') 74 | prob = deva.incorporate_detection(this_image, mask, new_segments_info) 75 | deva.next_voting_frame += cfg['detection_every'] 76 | 77 | result_saver.save_mask(prob, 78 | this_frame_name, 79 | need_resize=need_resize, 80 | shape=(h, w), 81 | image_np=this_image_np, 82 | prompts=prompts) 83 | 84 | for frame_info in deva.frame_buffer[1:]: 85 | this_image = frame_info.image 86 | this_frame_name = frame_info.name 87 | this_image_np = frame_info.image_np 88 | prob = deva.step(this_image, 
None, None) 89 | result_saver.save_mask(prob, 90 | this_frame_name, 91 | need_resize, 92 | shape=(h, w), 93 | image_np=this_image_np, 94 | prompts=prompts) 95 | 96 | deva.clear_buffer() 97 | else: 98 | # standard propagation 99 | prob = deva.step(image, None, None) 100 | result_saver.save_mask(prob, 101 | frame_name, 102 | need_resize=need_resize, 103 | shape=(h, w), 104 | image_np=image_np, 105 | prompts=prompts) 106 | 107 | elif cfg['temporal_setting'] == 'online': 108 | if ti % cfg['detection_every'] == 0: 109 | # incorporate new detections 110 | mask, segments_info = make_segmentation_with_text(cfg, image_np, gd_model, sam_model, 111 | prompts, new_min_side) 112 | frame_info.segments_info = segments_info 113 | prob = deva.incorporate_detection(image, mask, segments_info) 114 | else: 115 | # Run the model on this frame 116 | prob = deva.step(image, None, None) 117 | result_saver.save_mask(prob, 118 | frame_name, 119 | need_resize=need_resize, 120 | shape=(h, w), 121 | image_np=image_np, 122 | prompts=prompts) 123 | -------------------------------------------------------------------------------- /deva/inference/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/inference/__init__.py -------------------------------------------------------------------------------- /deva/inference/consensus_associated.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the implementation of the consensus when the association is already established. 3 | E.g., when we know which mask in frame 1 corresponds to which mask in frame 2. 4 | There is no need to use integer programming for matching. 
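Instead, each frame's mask is projected onto a chosen keyframe (selected by foreground confidence or, if available, detection score) and the projections are combined with a score-weighted sum.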
5 | """ 6 | 7 | from typing import List, Dict 8 | import torch 9 | 10 | from deva.model.memory_utils import * 11 | from deva.model.network import DEVA 12 | from deva.inference.image_feature_store import ImageFeatureStore 13 | from deva.utils.tensor_utils import pad_divide_by, unpad 14 | 15 | 16 | def spatial_alignment(src_ti: int, src_image: torch.Tensor, src_mask: torch.Tensor, tar_ti: int, 17 | tar_image: torch.Tensor, network: DEVA, store: ImageFeatureStore, 18 | config: Dict) -> torch.Tensor: 19 | """ 20 | src_image/tar_image: 3*H*W 21 | src_mask: num_objects*H*W 22 | 23 | returns: a segmentation mask of the target image: num_objects*H*W 24 | """ 25 | num_objects, h, w = src_mask.shape 26 | src_image = src_image.unsqueeze(0) 27 | tar_image = tar_image.unsqueeze(0) 28 | src_mask = src_mask.unsqueeze(0) 29 | 30 | # get source features 31 | src_ms_features = store.get_ms_features(src_ti, src_image) 32 | src_key, src_shrinkage, _ = store.get_key(src_ti, src_image) 33 | # get target features 34 | tar_ms_features = store.get_ms_features(tar_ti, tar_image) 35 | tar_key, _, tar_selection = store.get_key(tar_ti, tar_image) 36 | 37 | # encode memory from the source frame 38 | sensory = torch.zeros((1, num_objects, config['value_dim'], h // 16, w // 16), 39 | device=src_key.device) 40 | value, sensory = network.encode_mask(src_image, 41 | src_ms_features, 42 | sensory, 43 | src_mask, 44 | is_deep_update=True, 45 | chunk_size=config['chunk_size']) 46 | 47 | # key matching 48 | src_key = src_key.flatten(start_dim=2) 49 | src_shrinkage = src_shrinkage.flatten(start_dim=2) 50 | tar_key = tar_key.flatten(start_dim=2) 51 | tar_selection = tar_selection.flatten(start_dim=2) 52 | # 1*num_objects*C*H*W -> 1*(num_objects*C)*(H*W) 53 | value = value.flatten(start_dim=1, end_dim=2).flatten(start_dim=2) 54 | 55 | similarity = get_similarity(src_key, src_shrinkage, tar_key, tar_selection) 56 | affinity = do_softmax(similarity, top_k=config['top_k']) 57 | # readout 58 | memory_readout = value @ affinity 59 | memory_readout = memory_readout.view(1, num_objects, config['value_dim'], h // 16, w // 16) 60 | 61 | # segmentation 62 | _, _, tar_mask = network.segment(tar_ms_features, 63 | memory_readout, 64 | sensory, 65 | src_mask, 66 | chunk_size=config['chunk_size'], 67 | update_sensory=False) 68 | 69 | return tar_mask 70 | 71 | 72 | def _keyframe_objective_from_mask(mask, score, method='high_foreground') -> float: 73 | # compute a good-to-be-keyframe score of a mask 74 | if method == 'high_foreground': 75 | return (mask > 0.8).float().mean() 76 | elif method == 'score': 77 | return score 78 | else: 79 | raise NotImplementedError 80 | 81 | 82 | def find_consensus_with_established_association(time_indices: List[int], 83 | images: List[torch.Tensor], 84 | masks: List[torch.Tensor], 85 | network: DEVA, 86 | store: ImageFeatureStore, 87 | config: Dict, 88 | scores: List[float] = None) -> (int, torch.Tensor): 89 | 90 | # apply padding to all images and masks 91 | for i, (image, mask) in enumerate(zip(images, masks)): 92 | images[i], pads = pad_divide_by(image, 16) 93 | masks[i], _ = pad_divide_by(mask, 16) 94 | 95 | # if scores is None, assume uniform (for averaging later on) 96 | if scores is None: 97 | scores = [1 for _ in time_indices] 98 | use_score = False 99 | else: 100 | use_score = True 101 | scores = torch.softmax(torch.Tensor(scores) * 2, dim=0).tolist() 102 | 103 | # first, find a keyframe 104 | keyframe_objective = float('-inf') 105 | keyframe_ti = None 106 | keyframe_image = None 107 | keyframe_mask = None 
108 | keyframe_score = None 109 | 110 | if use_score: 111 | # ranking with score 112 | for ti, image, mask, score in zip(time_indices, images, masks, scores): 113 | objective = _keyframe_objective_from_mask(mask, score, method='score') 114 | if objective > keyframe_objective: 115 | keyframe_objective = objective 116 | keyframe_ti = ti 117 | keyframe_image = image 118 | keyframe_mask = mask 119 | keyframe_score = score 120 | else: 121 | # score-less ranking 122 | score = None 123 | for ti, image, mask in zip(time_indices, images, masks): 124 | objective = _keyframe_objective_from_mask(mask, score, method='high_foreground') 125 | if objective > keyframe_objective: 126 | keyframe_objective = objective 127 | keyframe_ti = ti 128 | keyframe_image = image 129 | keyframe_mask = mask 130 | keyframe_score = score 131 | 132 | if keyframe_score is None: 133 | keyframe_score = scores[0] 134 | 135 | # then, project all frames onto the keyframe 136 | # we also project the keyframe onto the keyframe itself for mask refinement 137 | total_projected_mask = keyframe_mask * keyframe_score 138 | for ti, image, mask, score in zip(time_indices, images, masks, scores): 139 | # the keyframe is already added 140 | if ti == keyframe_ti: 141 | continue 142 | projected_mask = spatial_alignment(ti, image, mask, keyframe_ti, keyframe_image, network, 143 | store, config) 144 | total_projected_mask += projected_mask[0, 1:] * score 145 | 146 | total_projected_mask = unpad(total_projected_mask, pads) 147 | return keyframe_ti, total_projected_mask 148 | -------------------------------------------------------------------------------- /deva/inference/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/inference/data/__init__.py -------------------------------------------------------------------------------- /deva/inference/data/detection_video_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | from torch.utils.data.dataset import Dataset 5 | from torchvision import transforms 6 | from torchvision.transforms import InterpolationMode 7 | import torch.nn.functional as F 8 | from PIL import Image 9 | import numpy as np 10 | 11 | from deva.dataset.utils import im_normalization 12 | 13 | 14 | class DetectionVideoReader(Dataset): 15 | """ 16 | This class is used to read a video, one frame at a time 17 | """ 18 | def __init__(self, 19 | vid_name, 20 | image_dir, 21 | mask_dir, 22 | size=-1, 23 | to_save=None, 24 | size_dir=None, 25 | start=-1, 26 | end=-1, 27 | reverse=False): 28 | """ 29 | image_dir - points to a directory of jpg images 30 | mask_dir - points to a directory of png masks 31 | size - resize min. side to size. Does nothing if <0. 
32 | to_save - optionally contains a list of file names without extensions 33 | where the segmentation mask is required 34 | """ 35 | # TODO: determine if_rgb automatically 36 | self.vid_name = vid_name 37 | self.image_dir = image_dir 38 | self.mask_dir = mask_dir 39 | self.to_save = to_save 40 | if size_dir is None: 41 | self.size_dir = self.image_dir 42 | else: 43 | self.size_dir = size_dir 44 | 45 | self.frames = sorted(os.listdir(self.image_dir)) 46 | if start > 0: 47 | self.frames = self.frames[start:] 48 | if end > 0: 49 | self.frames = self.frames[:end] 50 | if reverse: 51 | self.frames = reversed(self.frames) 52 | 53 | self.palette = Image.open(path.join(mask_dir, self.frames[0].replace('.jpg', 54 | '.png'))).getpalette() 55 | self.first_gt_path = path.join(self.mask_dir, self.frames[0].replace('.jpg', '.png')) 56 | 57 | if size < 0: 58 | self.im_transform = transforms.Compose([ 59 | transforms.ToTensor(), 60 | im_normalization, 61 | ]) 62 | self.mask_transform = transforms.Compose([]) 63 | else: 64 | self.im_transform = transforms.Compose([ 65 | transforms.ToTensor(), 66 | im_normalization, 67 | transforms.Resize(size, interpolation=InterpolationMode.BILINEAR, antialias=True), 68 | ]) 69 | self.mask_transform = transforms.Compose([ 70 | transforms.Resize(size, interpolation=InterpolationMode.NEAREST), 71 | ]) 72 | self.size = size 73 | self.is_rgb = None 74 | 75 | def __getitem__(self, idx): 76 | frame = self.frames[idx] 77 | info = {} 78 | data = {} 79 | info['frame'] = frame 80 | info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save) 81 | 82 | im_path = path.join(self.image_dir, frame) 83 | img = Image.open(im_path).convert('RGB') 84 | 85 | if self.image_dir == self.size_dir: 86 | shape = np.array(img).shape[:2] 87 | else: 88 | size_path = path.join(self.size_dir, frame) 89 | size_im = Image.open(size_path).convert('RGB') 90 | shape = np.array(size_im).shape[:2] 91 | 92 | mask_path = path.join(self.mask_dir, frame[:-4] + '.png') 93 | img = self.im_transform(img) 94 | 95 | if path.exists(mask_path): 96 | mask = Image.open(mask_path) 97 | mask = self.mask_transform(mask) 98 | if mask.mode == 'RGB': 99 | mask = np.array(mask, dtype=np.int32) 100 | mask = mask[:, :, 0] + mask[:, :, 1] * 256 + mask[:, :, 2] * 256 * 256 101 | self.is_rgb = True 102 | else: 103 | mask = mask.convert('P') 104 | mask = np.array(mask, dtype=np.int32) 105 | self.is_rgb = False 106 | data['mask'] = mask 107 | 108 | # defer json loading to the model 109 | json_path = path.join(self.mask_dir, frame[:-4] + '.json') 110 | if path.exists(json_path): 111 | info['json'] = json_path 112 | 113 | info['is_rgb'] = self.is_rgb 114 | info['shape'] = shape 115 | info['need_resize'] = not (self.size < 0) 116 | info['path_to_image'] = im_path 117 | data['rgb'] = img 118 | data['info'] = info 119 | 120 | return data 121 | 122 | def get_palette(self): 123 | return self.palette 124 | 125 | def __len__(self): 126 | return len(self.frames) -------------------------------------------------------------------------------- /deva/inference/data/referring_test_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | import json 4 | from collections import defaultdict 5 | import numpy as np 6 | 7 | from deva.inference.data.video_reader import VideoReader 8 | 9 | 10 | class ReferringDAVISTestDataset: 11 | def __init__(self, image_dir, mask_dir, size=-1): 12 | self.image_dir = image_dir 13 | self.mask_dir = mask_dir 14 | self.size = size 15 | 
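        # each sub-directory of mask_dir is treated as one video sequence;
        # sorting keeps the video iteration order deterministic across runs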
16 | self.vid_list = sorted(os.listdir(self.mask_dir)) 17 | 18 | def get_videos(self): 19 | return self.vid_list 20 | 21 | def get_offline_sampled_frames(self, video, num_sampled_frames): 22 | return VideoReader( 23 | video, 24 | path.join(self.image_dir, video), 25 | path.join(self.mask_dir, video), 26 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))], 27 | size=self.size, 28 | soft_mask=True, 29 | num_sampled_frames=num_sampled_frames, 30 | use_all_masks=True, 31 | ) 32 | 33 | def get_partial_video_loader(self, video, *, start, end, reverse): 34 | return VideoReader( 35 | video, 36 | path.join(self.image_dir, video), 37 | path.join(self.mask_dir, video), 38 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))], 39 | size=self.size, 40 | soft_mask=True, 41 | start=start, 42 | end=end, 43 | reverse=reverse, 44 | ) 45 | 46 | def get_scores(self, video): 47 | with open(path.join(self.mask_dir, video, 'scores.csv')) as f: 48 | lines = f.read().splitlines() 49 | scores = defaultdict(dict) 50 | for l in lines: 51 | frame, obj, score = l.split(',') 52 | scores[frame[:-4]][obj] = float(score) 53 | 54 | average_scores = {} 55 | for frame, all_objects in scores.items(): 56 | average_scores[frame] = np.array(list(all_objects.values())).mean() 57 | 58 | return average_scores 59 | 60 | def __len__(self): 61 | return len(self.vid_list) 62 | 63 | 64 | class ReferringYouTubeVOSTestDataset: 65 | def __init__(self, image_dir, mask_dir, json_dir, size=-1): 66 | self.image_dir = image_dir 67 | self.mask_dir = mask_dir 68 | self.size = size 69 | 70 | self.vid_list = sorted(os.listdir(self.mask_dir)) 71 | self.req_frame_list = {} 72 | 73 | with open(json_dir) as f: 74 | # read meta.json to know which frame is required for evaluation 75 | meta = json.load(f)['videos'] 76 | 77 | for vid in self.vid_list: 78 | req_frames = [] 79 | req_frames.extend(meta[vid]['frames']) 80 | 81 | req_frames = list(set(req_frames)) 82 | self.req_frame_list[vid] = req_frames 83 | 84 | def get_videos(self): 85 | return self.vid_list 86 | 87 | def get_objects(self, video): 88 | return [ 89 | obj for obj in sorted(os.listdir(path.join(self.mask_dir, video))) if '.csv' not in obj 90 | ] 91 | 92 | def _get_to_save_list(self, video, object_name): 93 | return self.req_frame_list[video] 94 | 95 | def get_offline_sampled_frames(self, video, object_name, num_sampled_frames): 96 | return VideoReader( 97 | video, 98 | path.join(self.image_dir, video), 99 | path.join(self.mask_dir, video), 100 | size=self.size, 101 | soft_mask=True, 102 | num_sampled_frames=num_sampled_frames, 103 | use_all_masks=True, 104 | to_save=self._get_to_save_list(video, object_name), 105 | object_name=object_name, 106 | enabled_frame_list=self._get_enabled_frame_list(video, object_name), 107 | ) 108 | 109 | def get_partial_video_loader(self, video, object_name, *, start, end, reverse): 110 | return VideoReader( 111 | video, 112 | path.join(self.image_dir, video), 113 | path.join(self.mask_dir, video), 114 | size=self.size, 115 | soft_mask=True, 116 | start=start, 117 | end=end, 118 | reverse=reverse, 119 | to_save=self._get_to_save_list(video, object_name), 120 | object_name=object_name, 121 | enabled_frame_list=self._get_enabled_frame_list(video, object_name), 122 | ) 123 | 124 | def get_scores(self, video): 125 | with open(path.join(self.mask_dir, video, 'scores.csv')) as f: 126 | lines = f.read().splitlines() 127 | scores = defaultdict(dict) 128 | enabled_frame_list = self._get_enabled_frame_list(video, None) 129 | 
for l in lines: 130 | frame, obj, score = l.split(',') 131 | if enabled_frame_list is not None and frame[:-4] not in enabled_frame_list: 132 | continue 133 | scores[obj][frame[:-4]] = float(score) 134 | return scores 135 | 136 | def _get_enabled_frame_list(self, video, object_name): 137 | # None -> enable all 138 | return None 139 | 140 | def __len__(self): 141 | return len(self.vid_list) 142 | -------------------------------------------------------------------------------- /deva/inference/data/saliency_test_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | from deva.inference.data.video_reader import VideoReader 5 | 6 | 7 | class DAVISSaliencyTestDataset: 8 | def __init__(self, image_dir, mask_dir, imset=None, size=-1): 9 | self.image_dir = image_dir 10 | self.mask_dir = mask_dir 11 | self.size = size 12 | 13 | if imset is None: 14 | self.vid_list = sorted(os.listdir(self.mask_dir)) 15 | else: 16 | with open(imset) as f: 17 | self.vid_list = sorted([line.strip() for line in f]) 18 | 19 | def get_datasets(self): 20 | for video in self.vid_list: 21 | yield VideoReader( 22 | video, 23 | path.join(self.image_dir, video), 24 | path.join(self.mask_dir, video), 25 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))], 26 | size=self.size, 27 | soft_mask=True, 28 | use_all_masks=True, 29 | multi_object=False, 30 | ) 31 | 32 | def get_videos(self): 33 | return self.vid_list 34 | 35 | def get_offline_sampled_frames(self, video, num_sampled_frames): 36 | return VideoReader( 37 | video, 38 | path.join(self.image_dir, video), 39 | path.join(self.mask_dir, video), 40 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))], 41 | size=self.size, 42 | soft_mask=True, 43 | num_sampled_frames=num_sampled_frames, 44 | use_all_masks=True, 45 | multi_object=False, 46 | ) 47 | 48 | def get_partial_video_loader(self, video, *, start, end, reverse): 49 | return VideoReader( 50 | video, 51 | path.join(self.image_dir, video), 52 | path.join(self.mask_dir, video), 53 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))], 54 | size=self.size, 55 | soft_mask=True, 56 | start=start, 57 | end=end, 58 | reverse=reverse, 59 | multi_object=False, 60 | ) 61 | 62 | def __len__(self): 63 | return len(self.vid_list) -------------------------------------------------------------------------------- /deva/inference/data/simple_video_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | from torch.utils.data.dataset import Dataset 4 | from PIL import Image 5 | import numpy as np 6 | 7 | 8 | class SimpleVideoReader(Dataset): 9 | """ 10 | This class is used to read a video, one frame at a time 11 | This simple version: 12 | 1. Does not load the mask/json 13 | 2. Does not normalize the input 14 | 3. 
Does not resize 15 | """ 16 | def __init__( 17 | self, 18 | image_dir, 19 | ): 20 | """ 21 | image_dir - points to a directory of jpg images 22 | """ 23 | self.image_dir = image_dir 24 | self.frames = sorted(os.listdir(self.image_dir)) 25 | 26 | def __getitem__(self, idx): 27 | frame = self.frames[idx] 28 | 29 | im_path = path.join(self.image_dir, frame) 30 | img = Image.open(im_path).convert('RGB') 31 | img = np.array(img) 32 | 33 | return img, im_path 34 | 35 | def __len__(self): 36 | return len(self.frames) 37 | 38 | 39 | def no_collate(x): 40 | return x -------------------------------------------------------------------------------- /deva/inference/data/vos_test_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | import json 4 | 5 | from deva.inference.data.video_reader import VideoReader 6 | 7 | 8 | class GeneralVOSTestDataset: 9 | def __init__(self, data_root, size=-1, use_all_masks=False): 10 | self.image_dir = path.join(data_root, 'JPEGImages') 11 | self.mask_dir = path.join(data_root, 'Annotations') 12 | self.size = size 13 | self.use_all_masks = use_all_masks 14 | 15 | self.vid_list = sorted(os.listdir(self.mask_dir)) 16 | 17 | def get_datasets(self): 18 | for video in self.vid_list: 19 | yield VideoReader( 20 | video, 21 | path.join(self.image_dir, video), 22 | path.join(self.mask_dir, video), 23 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))], 24 | size=self.size, 25 | use_all_masks=self.use_all_masks, 26 | ) 27 | 28 | def __len__(self): 29 | return len(self.vid_list) 30 | 31 | 32 | class DAVISTestDataset: 33 | def __init__(self, data_root, imset='2017/val.txt', size=-1): 34 | if size != 480: 35 | self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution') 36 | self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution') 37 | if not path.exists(self.image_dir): 38 | print(f'{self.image_dir} not found. 
Looking at .../1080p instead') 39 | self.image_dir = path.join(data_root, 'JPEGImages', '1080p') 40 | self.mask_dir = path.join(data_root, 'Annotations', '1080p') 41 | assert path.exists(self.image_dir), 'path not found' 42 | else: 43 | self.image_dir = path.join(data_root, 'JPEGImages', '480p') 44 | self.mask_dir = path.join(data_root, 'Annotations', '480p') 45 | self.size_dir = path.join(data_root, 'JPEGImages', '480p') 46 | self.size = size 47 | 48 | with open(path.join(data_root, 'ImageSets', imset)) as f: 49 | self.vid_list = sorted([line.strip() for line in f]) 50 | 51 | def get_datasets(self): 52 | for video in self.vid_list: 53 | yield VideoReader( 54 | video, 55 | path.join(self.image_dir, video), 56 | path.join(self.mask_dir, video), 57 | size=self.size, 58 | size_dir=path.join(self.size_dir, video), 59 | ) 60 | 61 | def __len__(self): 62 | return len(self.vid_list) 63 | 64 | 65 | class YouTubeVOSTestDataset: 66 | def __init__(self, data_root, split, size=480): 67 | self.image_dir = path.join(data_root, 'all_frames', split + '_all_frames', 'JPEGImages') 68 | self.mask_dir = path.join(data_root, split, 'Annotations') 69 | self.size = size 70 | 71 | self.vid_list = sorted(os.listdir(self.image_dir)) 72 | self.req_frame_list = {} 73 | 74 | with open(path.join(data_root, split, 'meta.json')) as f: 75 | # read meta.json to know which frame is required for evaluation 76 | meta = json.load(f)['videos'] 77 | 78 | for vid in self.vid_list: 79 | req_frames = [] 80 | objects = meta[vid]['objects'] 81 | for value in objects.values(): 82 | req_frames.extend(value['frames']) 83 | 84 | req_frames = list(set(req_frames)) 85 | self.req_frame_list[vid] = req_frames 86 | 87 | def get_datasets(self): 88 | for video in self.vid_list: 89 | yield VideoReader(video, 90 | path.join(self.image_dir, video), 91 | path.join(self.mask_dir, video), 92 | size=self.size, 93 | to_save=self.req_frame_list[video], 94 | use_all_masks=True) 95 | 96 | def __len__(self): 97 | return len(self.vid_list) 98 | -------------------------------------------------------------------------------- /deva/inference/data/vps_test_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | import json 4 | 5 | from deva.inference.data.detection_video_reader import DetectionVideoReader 6 | 7 | 8 | class VIPSegDetectionTestDataset: 9 | def __init__(self, image_dir, mask_dir, size=-1): 10 | self.image_dir = image_dir 11 | self.mask_dir = mask_dir 12 | self.size = size 13 | self.vid_list = sorted(os.listdir(self.mask_dir)) 14 | self.vid_list = [v for v in self.vid_list if not v.endswith('.json')] 15 | 16 | def get_datasets(self): 17 | for video in self.vid_list: 18 | yield DetectionVideoReader( 19 | video, 20 | path.join(self.image_dir, video), 21 | path.join(self.mask_dir, video), 22 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))], 23 | size=self.size, 24 | ) 25 | 26 | def __len__(self): 27 | return len(self.vid_list) 28 | 29 | 30 | class BURSTDetectionTestDataset: 31 | def __init__(self, image_dir, mask_dir, gt_json_dir, size=-1, *, start=None, count=None): 32 | self.image_dir = image_dir 33 | self.mask_dir = mask_dir 34 | self.size = size 35 | 36 | # read the json file to get a list of videos and frames to save 37 | with open(gt_json_dir, 'r') as f: 38 | json_file = json.load(f) 39 | sequences = json_file['sequences'] 40 | split = json_file['split'] 41 | 42 | assert split == 'test' or split == 'val' 43 | 44 | # load a randomized 
ordering of BURST videos for a balanced load 45 | with open(f'./deva/utils/burst_{split}.txt', mode='r') as f: 46 | randomized_videos = list(f.read().splitlines()) 47 | 48 | # subsample a list of videos for processing 49 | if start is not None and count is not None: 50 | randomized_videos = randomized_videos[start:start + count] 51 | print(f'Start: {start}, Count: {count}, End: {start+count}') 52 | 53 | self.vid_list = [] 54 | self.frames_to_save = {} 55 | for sequence in sequences: 56 | dataset = sequence['dataset'] 57 | seq_name = sequence['seq_name'] 58 | video_name = path.join(dataset, seq_name) 59 | if video_name not in randomized_videos: 60 | continue 61 | self.vid_list.append(video_name) 62 | 63 | annotated_image_paths = sequence['annotated_image_paths'] 64 | self.frames_to_save[video_name] = [p[:-4] for p in annotated_image_paths] 65 | assert path.exists(path.join(image_dir, video_name)) 66 | assert path.exists(path.join(mask_dir, video_name)) 67 | 68 | assert len(self.vid_list) == len(randomized_videos) 69 | # to use the random ordering 70 | self.vid_list = randomized_videos 71 | 72 | print(f'Actual total: {len(self.vid_list)}') 73 | 74 | def get_datasets(self): 75 | for video in self.vid_list: 76 | yield DetectionVideoReader( 77 | video, 78 | path.join(self.image_dir, video), 79 | path.join(self.mask_dir, video), 80 | to_save=self.frames_to_save[video], 81 | size=self.size, 82 | ) 83 | 84 | def __len__(self): 85 | return len(self.vid_list) -------------------------------------------------------------------------------- /deva/inference/demo_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from deva.dataset.utils import im_normalization 6 | from deva.inference.inference_core import DEVAInferenceCore 7 | from deva.inference.result_utils import ResultSaver 8 | 9 | 10 | def get_input_frame_for_deva(image_np: np.ndarray, min_side: int) -> torch.Tensor: 11 | image = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255 12 | image = im_normalization(image) 13 | if min_side > 0: 14 | h, w = image_np.shape[:2] 15 | scale = min_side / min(h, w) 16 | new_h, new_w = int(h * scale), int(w * scale) 17 | image = image.unsqueeze(0) 18 | image = F.interpolate(image, (new_h, new_w), mode='bilinear', align_corners=False)[0] 19 | return image.cuda() 20 | 21 | 22 | @torch.inference_mode() 23 | def flush_buffer(deva: DEVAInferenceCore, result_saver: ResultSaver) -> None: 24 | # process all the remaining frames in the buffer 25 | cfg = deva.config 26 | new_min_side = cfg['size'] 27 | need_resize = new_min_side > 0 28 | 29 | if 'prompt' in cfg: 30 | raw_prompt = cfg['prompt'] 31 | prompts = raw_prompt.split('.') 32 | else: 33 | prompts = None 34 | 35 | for frame_info in deva.frame_buffer: 36 | this_image = frame_info.image 37 | this_frame_name = frame_info.name 38 | this_image_np = frame_info.image_np 39 | h, w = this_image_np.shape[:2] 40 | prob = deva.step(this_image, None, None) 41 | result_saver.save_mask(prob, 42 | this_frame_name, 43 | need_resize=need_resize, 44 | shape=(h, w), 45 | image_np=this_image_np, 46 | prompts=prompts) 47 | -------------------------------------------------------------------------------- /deva/inference/eval_args.py: -------------------------------------------------------------------------------- 1 | # Common evaluation arguments 2 | from argparse import ArgumentParser 3 | import torch 4 | from deva.model.network import DEVA 5 | 6 | 7 | def 
add_common_eval_args(parser: ArgumentParser): 8 | parser.add_argument('--model', default='./saves/DEVA-propagation.pth') 9 | 10 | parser.add_argument('--output', default=None) 11 | parser.add_argument( 12 | '--save_all', 13 | action='store_true', 14 | help='Save all frames', 15 | ) 16 | 17 | parser.add_argument('--amp', action='store_true') 18 | 19 | # Model parameters 20 | parser.add_argument('--key_dim', type=int, default=64) 21 | parser.add_argument('--value_dim', type=int, default=512) 22 | parser.add_argument('--pix_feat_dim', type=int, default=512) 23 | 24 | # Long-term memory options 25 | parser.add_argument('--disable_long_term', action='store_true') 26 | parser.add_argument('--max_mid_term_frames', 27 | help='T_max in XMem, decrease to save memory', 28 | type=int, 29 | default=10) 30 | parser.add_argument('--min_mid_term_frames', 31 | help='T_min in XMem, decrease to save memory', 32 | type=int, 33 | default=5) 34 | parser.add_argument('--max_long_term_elements', 35 | help='LT_max in XMem, increase if objects disappear for a long time', 36 | type=int, 37 | default=10000) 38 | parser.add_argument('--num_prototypes', help='P in XMem', type=int, default=128) 39 | 40 | parser.add_argument('--top_k', type=int, default=30) 41 | parser.add_argument('--mem_every', 42 | help='r in XMem. Increase to improve running speed.', 43 | type=int, 44 | default=5) 45 | parser.add_argument( 46 | '--chunk_size', 47 | default=-1, 48 | type=int, 49 | help='''Number of objects to process in parallel as a batch; -1 for unlimited. 50 | Set to a small number to save memory.''') 51 | 52 | parser.add_argument( 53 | '--size', 54 | default=480, 55 | type=int, 56 | help='Resize the shorter side to this size. -1 to use original resolution. ') 57 | 58 | 59 | def get_model_and_config(parser: ArgumentParser): 60 | args = parser.parse_args() 61 | config = vars(args) 62 | config['enable_long_term'] = not config['disable_long_term'] 63 | 64 | # Load our checkpoint 65 | network = DEVA(config).cuda().eval() 66 | if args.model is not None: 67 | model_weights = torch.load(args.model) 68 | network.load_weights(model_weights) 69 | else: 70 | print('No model loaded.') 71 | 72 | return network, config, args -------------------------------------------------------------------------------- /deva/inference/frame_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | import torch 3 | 4 | from deva.inference.object_info import ObjectInfo 5 | 6 | 7 | class FrameInfo: 8 | def __init__(self, image: torch.Tensor, mask: torch.Tensor, segments_info: List[ObjectInfo], 9 | ti: int, info: Dict): 10 | self.image = image 11 | self.mask = mask 12 | self.segments_info = segments_info 13 | self.ti = ti 14 | self.info = info 15 | 16 | @property 17 | def name(self): 18 | return self.info['frame'][0] 19 | 20 | @property 21 | def shape(self): 22 | return self.info['shape'] 23 | 24 | @property 25 | def save_needed(self): 26 | return self.info['save'][0] 27 | 28 | @property 29 | def path_to_image(self): 30 | return self.info['path_to_image'][0] 31 | -------------------------------------------------------------------------------- /deva/inference/image_feature_store.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable 2 | import warnings 3 | import torch 4 | from deva.model.network import DEVA 5 | 6 | 7 | class ImageFeatureStore: 8 | """ 9 | A cache for image features. 
10 | These features might be reused at different parts of the inference pipeline. 11 | This class provide a easy interface for reusing these features. 12 | It is the user's responsibility to delete features. 13 | 14 | Feature of a frame should be associated with a unique index -- typically the frame id. 15 | """ 16 | def __init__(self, network: DEVA, no_warning: bool = False): 17 | self.network = network 18 | self._store = {} 19 | self.no_warning = no_warning 20 | 21 | def _encode_feature(self, index: int, image: torch.Tensor) -> None: 22 | ms_features, feat = self.network.encode_image(image) 23 | key, shrinkage, selection = self.network.transform_key(feat) 24 | 25 | self._store[index] = (ms_features, feat, key, shrinkage, selection) 26 | 27 | def get_ms_features(self, index, image) -> Iterable[torch.Tensor]: 28 | if index not in self._store: 29 | self._encode_feature(index, image) 30 | 31 | return self._store[index][0] 32 | 33 | def get_key(self, index, image) -> (torch.Tensor, torch.Tensor, torch.Tensor): 34 | if index not in self._store: 35 | self._encode_feature(index, image) 36 | 37 | return self._store[index][2:] 38 | 39 | def delete(self, index) -> None: 40 | if index in self._store: 41 | del self._store[index] 42 | 43 | def __len__(self): 44 | return len(self._store) 45 | 46 | def __del__(self): 47 | if len(self._store) > 0 and not self.no_warning: 48 | warnings.warn(f'Leaking {self._store.keys()} in the image feature store') 49 | -------------------------------------------------------------------------------- /deva/inference/object_info.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import numpy as np 3 | from scipy import stats 4 | from deva.utils.pano_utils import id_to_rgb 5 | 6 | 7 | class ObjectInfo: 8 | """ 9 | Stores meta information for an object 10 | """ 11 | def __init__(self, 12 | id: int, 13 | category_id: Optional[int] = None, 14 | isthing: Optional[bool] = None, 15 | score: Optional[float] = None): 16 | self.id = id 17 | self.category_ids = [category_id] 18 | self.scores = [score] 19 | self.isthing = isthing 20 | self.poke_count = 0 # number of detections since last this object was last seen 21 | 22 | def poke(self) -> None: 23 | self.poke_count += 1 24 | 25 | def unpoke(self) -> None: 26 | self.poke_count = 0 27 | 28 | def merge(self, other) -> None: 29 | self.category_ids.extend(other.category_ids) 30 | self.scores.extend(other.scores) 31 | 32 | def vote_category_id(self) -> Optional[int]: 33 | category_ids = [c for c in self.category_ids if c is not None] 34 | if len(category_ids) == 0: 35 | return None 36 | else: 37 | return int(stats.mode(category_ids, keepdims=False)[0]) 38 | 39 | def vote_score(self) -> Optional[float]: 40 | scores = [c for c in self.scores if c is not None] 41 | if len(scores) == 0: 42 | return None 43 | else: 44 | return float(np.mean(scores)) 45 | 46 | def get_rgb(self) -> np.ndarray: 47 | # this is valid for panoptic segmentation-style id only (0~255**3) 48 | return id_to_rgb(self.id) 49 | 50 | def copy_meta_info(self, other) -> None: 51 | self.category_ids = other.category_ids 52 | self.scores = other.scores 53 | self.isthing = other.isthing 54 | 55 | def __hash__(self): 56 | return hash(self.id) 57 | 58 | def __eq__(self, other): 59 | return self.id == other.id 60 | 61 | def __repr__(self): 62 | return f'(ID: {self.id}, cat: {self.category_ids}, isthing: {self.isthing}, score: {self.scores})' 63 | 
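# --- illustrative usage sketch; not part of the original object_info.py ---
# A minimal, hypothetical example of how ObjectInfo aggregates per-detection metadata
# across merges. The values are invented; the __main__ guard keeps the sketch from
# running when the module is imported.
if __name__ == '__main__':
    det_a = ObjectInfo(id=7, category_id=3, isthing=True, score=0.9)
    det_b = ObjectInfo(id=7, category_id=3, isthing=True, score=0.7)
    det_a.merge(det_b)                # category_ids -> [3, 3], scores -> [0.9, 0.7]
    print(det_a.vote_category_id())   # mode of the observed categories -> 3
    print(det_a.vote_score())         # mean of the observed scores -> ~0.8
    print(det_a == det_b)             # equality and hashing use only the id -> True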
-------------------------------------------------------------------------------- /deva/inference/object_manager.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Dict, Set 2 | 3 | import torch 4 | import numpy as np 5 | from deva.inference.object_info import ObjectInfo 6 | 7 | 8 | class ObjectManager: 9 | """ 10 | Object (real) IDs are immutable. The same ID always represent the same object. 11 | Temporary IDs are the positions of each object in the tensor. It changes as objects get removed. 12 | """ 13 | def __init__(self): 14 | self.obj_to_tmp_id: Dict[ObjectInfo, int] = {} 15 | self.tmp_id_to_obj: Dict[int, ObjectInfo] = {} 16 | self.obj_id_to_obj: Dict[int, ObjectInfo] = {} 17 | 18 | # We keep track of all historical object IDs to avoid collision 19 | # even removed object IDs stay in this set 20 | self.all_historical_object_ids: Set[int] = set() 21 | self.use_long_id = False 22 | 23 | def _recompute_obj_id_to_obj_mapping(self) -> None: 24 | self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id} 25 | 26 | def add_new_objects( 27 | self, objects: Union[List[ObjectInfo], ObjectInfo, 28 | List[int]]) -> (List[int], List[int]): 29 | if not isinstance(objects, list): 30 | objects = [objects] 31 | 32 | corresponding_tmp_ids = [] 33 | corresponding_obj_ids = [] 34 | for obj in objects: 35 | if isinstance(obj, int): 36 | obj = ObjectInfo(id=obj) 37 | 38 | # create new id if id collides with existing ones 39 | # simple heuristic to determine RGB/0~255; works most of the time except when it doesn't 40 | count = 0 41 | new_obj = ObjectInfo(id=obj.id) 42 | while (new_obj.id in self.all_historical_object_ids) or (self.use_long_id 43 | and new_obj.id < 256): 44 | if self.use_long_id: 45 | new_id = np.random.randint(256, 256**3) 46 | else: 47 | new_id = np.random.randint(1, 256) 48 | new_obj = ObjectInfo(id=new_id) 49 | count += 1 50 | 51 | if count > 5000: 52 | raise ValueError("We cannot find a new ID for this object." 
+ 53 | "Perhaps you should use long ID?") 54 | new_obj.copy_meta_info(obj) 55 | 56 | # new object 57 | new_tmp_id = len(self.obj_to_tmp_id) + 1 58 | self.obj_to_tmp_id[new_obj] = new_tmp_id 59 | self.tmp_id_to_obj[new_tmp_id] = new_obj 60 | self.all_historical_object_ids.add(new_obj.id) 61 | corresponding_tmp_ids.append(new_tmp_id) 62 | corresponding_obj_ids.append(new_obj.id) 63 | 64 | self._recompute_obj_id_to_obj_mapping() 65 | assert corresponding_tmp_ids == sorted(corresponding_tmp_ids), 'tmp id assignment bugged' 66 | return corresponding_tmp_ids, corresponding_obj_ids 67 | 68 | def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None: 69 | # delete an object or a list of objects 70 | # re-sort the tmp ids 71 | if isinstance(obj_ids_to_remove, int): 72 | obj_ids_to_remove = [obj_ids_to_remove] 73 | 74 | new_tmp_id = 1 75 | total_num_id = len(self.obj_to_tmp_id) 76 | 77 | local_obj_to_tmp_id = {} 78 | local_tmp_to_obj_id = {} 79 | 80 | for tmp_iter in range(1, total_num_id + 1): 81 | obj = self.tmp_id_to_obj[tmp_iter] 82 | if obj.id not in obj_ids_to_remove: 83 | local_obj_to_tmp_id[obj] = new_tmp_id 84 | local_tmp_to_obj_id[new_tmp_id] = obj 85 | new_tmp_id += 1 86 | 87 | self.obj_to_tmp_id = local_obj_to_tmp_id 88 | self.tmp_id_to_obj = local_tmp_to_obj_id 89 | self._recompute_obj_id_to_obj_mapping() 90 | 91 | def purge_inactive_objects(self, 92 | max_missed_detection_count: int) -> (bool, List[int], List[int]): 93 | # remove tmp ids of objects that are removed 94 | obj_id_to_be_deleted = [] 95 | tmp_id_to_be_deleted = [] 96 | tmp_id_to_keep = [] 97 | obj_id_to_keep = [] 98 | 99 | for obj in self.obj_to_tmp_id: 100 | if obj.poke_count > max_missed_detection_count: 101 | obj_id_to_be_deleted.append(obj.id) 102 | tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj]) 103 | else: 104 | tmp_id_to_keep.append(self.obj_to_tmp_id[obj]) 105 | obj_id_to_keep.append(obj.id) 106 | 107 | purge_activated = len(obj_id_to_be_deleted) > 0 108 | if purge_activated: 109 | self.delete_object(obj_id_to_be_deleted) 110 | return purge_activated, tmp_id_to_keep, obj_id_to_keep 111 | 112 | def tmp_to_obj_cls(self, mask) -> torch.Tensor: 113 | # remap tmp id cls representation to the true object id representation 114 | new_mask = torch.zeros_like(mask) 115 | for tmp_id, obj in self.tmp_id_to_obj.items(): 116 | new_mask[mask == tmp_id] = obj.id 117 | return new_mask 118 | 119 | def get_tmp_to_obj_mapping(self) -> Dict[int, int]: 120 | # returns the mapping in a dict format for saving it with pickle 121 | return {obj.id: tmp_id for obj, tmp_id in self.tmp_id_to_obj.items()} 122 | 123 | def realize_dict(self, obj_dict: Dict[int, torch.Tensor]) -> torch.Tensor: 124 | # turns a dict indexed by obj id into a tensor, ordered by tmp IDs 125 | output = [] 126 | for _, obj in self.tmp_id_to_obj.items(): 127 | if obj.id not in obj_dict: 128 | raise NotImplementedError 129 | output.append(obj_dict[obj.id]) 130 | output = torch.stack(output, dim=0) 131 | return output 132 | 133 | def make_one_hot(self, cls_mask: torch.Tensor) -> torch.Tensor: 134 | output = [] 135 | for _, obj in self.tmp_id_to_obj.items(): 136 | output.append(cls_mask == obj.id) 137 | if len(output) == 0: 138 | output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device) 139 | else: 140 | output = torch.stack(output, dim=0) 141 | return output 142 | 143 | def get_current_segments_info(self) -> List[Dict]: 144 | segments_info = [] 145 | for obj in self.obj_to_tmp_id: 146 | segments_info.append({ 147 | 
'category_id': obj.vote_category_id(), 148 | 'id': int(obj.id), 149 | 'score': obj.vote_score(), 150 | }) 151 | return segments_info 152 | 153 | @property 154 | def all_obj_ids(self) -> List[int]: 155 | return [k.id for k in self.obj_to_tmp_id] 156 | 157 | @property 158 | def num_obj(self) -> int: 159 | return len(self.obj_to_tmp_id) 160 | 161 | def has_all(self, objects: List[int]) -> bool: 162 | for obj in objects: 163 | if obj not in self.obj_to_tmp_id: 164 | return False 165 | return True 166 | 167 | def find_object_by_id(self, obj_id) -> ObjectInfo: 168 | return self.obj_id_to_obj[obj_id] 169 | -------------------------------------------------------------------------------- /deva/inference/object_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | from deva.inference.object_info import ObjectInfo 5 | from deva.utils.pano_utils import vipseg_cat_to_isthing 6 | 7 | 8 | def convert_json_dict_to_objects_info(mask: torch.Tensor, 9 | segments_info: Optional[List], 10 | dataset: str = None) -> List[ObjectInfo]: 11 | """ 12 | Convert a json dict to a list of object info 13 | If segments_info is None, we use the unique elements in mask to construct the list 14 | Otherwise mask is ignored 15 | """ 16 | if segments_info is not None: 17 | output = [ 18 | ObjectInfo( 19 | id=segment['id'], 20 | category_id=segment.get('category_id'), 21 | isthing=vipseg_cat_to_isthing[segment.get('category_id')] 22 | if dataset == 'vipseg' else None, 23 | score=float(segment['score']) if 24 | ((dataset == 'burst' or dataset == 'demo') and 'score' in segment) else None) 25 | for segment in segments_info 26 | ] 27 | else: 28 | # use the mask 29 | labels = torch.unique(mask) 30 | labels = labels[labels != 0] 31 | output = [ObjectInfo(l.item()) for l in labels] 32 | 33 | return output 34 | -------------------------------------------------------------------------------- /deva/inference/postprocess_unsup_davis17.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import os 3 | from os import path 4 | import sys 5 | import numpy as np 6 | import tqdm 7 | 8 | from deva.utils.palette import davis_palette 9 | 10 | 11 | def limit_max_id(input_path, output_path, max_num_objects=20): 12 | videos = sorted(os.listdir(input_path)) 13 | for video in tqdm.tqdm(videos): 14 | existing_objects = [] 15 | 16 | video_path = path.join(input_path, video) 17 | frames = sorted(os.listdir(video_path)) 18 | 19 | # determine the objects to keep 20 | for frame in frames: 21 | mask = Image.open(path.join(video_path, frame)) 22 | mask = np.array(mask).astype(np.int32) 23 | if len(mask.shape) == 3: 24 | mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2] 25 | labels = np.unique(mask) 26 | labels = labels[labels != 0] 27 | labels_area = [np.sum(mask == label) for label in labels] 28 | 29 | labels_sorted_by_area = [x for _, x in sorted(zip(labels_area, labels), reverse=True)] 30 | if len(labels_sorted_by_area) + len(existing_objects) <= max_num_objects: 31 | existing_objects += labels_sorted_by_area 32 | else: 33 | existing_objects += labels_sorted_by_area[:max_num_objects - len(existing_objects)] 34 | 35 | if len(existing_objects) == max_num_objects: 36 | break 37 | 38 | assert len(existing_objects) <= max_num_objects 39 | 40 | # remove the objects that are not in the existing_objects list 41 | for frame in frames: 42 | mask = Image.open(path.join(video_path, frame)) 
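            # decode the stored mask into integer object ids; RGB-encoded masks pack the id
            # as R * 65536 + G * 256 + B, which the 3-channel branch below reconstructs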
43 | mask = np.array(mask).astype(np.int32) 44 | if len(mask.shape) == 3: 45 | mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2] 46 | labels = np.unique(mask) 47 | labels = labels[labels != 0] 48 | 49 | new_mask = np.zeros_like(mask, dtype=np.uint8) 50 | for new_idx, label in enumerate(existing_objects): 51 | new_mask[mask == label] = new_idx + 1 52 | 53 | mask = Image.fromarray(new_mask) 54 | mask.putpalette(davis_palette) 55 | os.makedirs(path.join(output_path, video), exist_ok=True) 56 | mask.save(path.join(output_path, video, frame)) 57 | 58 | 59 | if __name__ == '__main__': 60 | input_path = sys.argv[1] 61 | output_path = sys.argv[2] 62 | limit_max_id(input_path, output_path) -------------------------------------------------------------------------------- /deva/inference/segment_merging.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains the implementation of segment matching and merging (Section 3.2.2). 3 | 4 | Match & merge the objects as discussed in the paper 5 | (Section 3.2.2 Merging Propagation and Consensus) 6 | Also update the object manager 7 | """ 8 | 9 | import warnings 10 | from typing import List, Literal, Dict, Optional 11 | 12 | import torch 13 | from deva.inference.object_info import ObjectInfo 14 | from deva.inference.object_manager import ObjectManager 15 | 16 | 17 | def _get_iou(m1, m2, m1_sum, m2_sum) -> (float, float, float): 18 | intersection = (m1 * m2).sum() 19 | if intersection < 1e-3: 20 | return 0, None, None 21 | union = (m1_sum + m2_sum - intersection) 22 | return intersection / union, intersection, union 23 | 24 | 25 | def merge_by_iou(our_masks: Dict[ObjectInfo, torch.Tensor], new_masks: Dict[ObjectInfo, 26 | torch.Tensor], 27 | our_sums: Dict[ObjectInfo, torch.Tensor], new_sums: Dict[ObjectInfo, torch.Tensor], 28 | merged_mask: torch.Tensor, object_manager: ObjectManager, 29 | new_segments_info: List[ObjectInfo], isthing_status: Optional[bool], 30 | incremental_mode: bool) -> torch.Tensor: 31 | # meged_mask is edited in-place 32 | our_to_new_matching = {} 33 | matched_area = {} 34 | new_objects = [] 35 | 36 | for new_obj in new_segments_info: 37 | if new_obj.isthing != isthing_status: 38 | continue 39 | for our_obj in object_manager.obj_to_tmp_id: 40 | if (our_obj.isthing != isthing_status) or (our_obj in our_to_new_matching): 41 | continue 42 | iou, _, union = _get_iou(new_masks[new_obj], our_masks[our_obj], new_sums[new_obj], 43 | our_sums[our_obj]) 44 | matched = (iou > 0.5) 45 | if matched: 46 | our_to_new_matching[our_obj] = new_obj 47 | matched_area[(our_obj, False)] = union 48 | break 49 | else: 50 | new_objects.append(new_obj) 51 | matched_area[(new_obj, True)] = new_sums[new_obj] 52 | 53 | # for all unmatched our segment 54 | for our_obj in object_manager.obj_to_tmp_id: 55 | if (our_obj.isthing != isthing_status) or (our_obj in our_to_new_matching): 56 | continue 57 | matched_area[(our_obj, False)] = our_sums[our_obj] 58 | 59 | # rendering by reversed order of areas 60 | sorted_by_area = sorted(matched_area.items(), key=lambda x: x[1], reverse=True) 61 | for (obj, is_new), _ in sorted_by_area: 62 | if is_new: 63 | # obj is a new object 64 | _, corresponding_obj_ids = object_manager.add_new_objects(obj) 65 | merged_mask[new_masks[obj]] = corresponding_obj_ids[0] 66 | else: 67 | # obj is not a new object 68 | if obj in our_to_new_matching: 69 | # merge 70 | new_obj = our_to_new_matching[obj] 71 | merged_mask[our_masks[obj]] = obj.id 72 | 
merged_mask[new_masks[new_obj]] = obj.id 73 | obj.merge(new_obj) 74 | obj.unpoke() 75 | else: 76 | # copy from our forward mask 77 | merged_mask[our_masks[obj]] = obj.id 78 | if incremental_mode: 79 | if our_sums[obj] < 1: 80 | obj.poke() 81 | else: 82 | obj.unpoke() 83 | else: 84 | obj.poke() 85 | 86 | return merged_mask 87 | 88 | 89 | def match_and_merge(our_mask: torch.Tensor, 90 | new_mask: torch.Tensor, 91 | object_manager: ObjectManager, 92 | new_segments_info: List[ObjectInfo], 93 | mode: Literal['iou'] = 'iou', 94 | max_num_objects: int = -1, 95 | incremental_mode: bool = False) -> torch.Tensor: 96 | """ 97 | our_mask is in temporary ids (consecutive) 98 | new_mask is in object ids (real ids from json) 99 | 100 | Updates the object manager as a side effect 101 | mode: 'iou' only 102 | max_num_objects: maximum number of objects allowed in memory (-1 for no limit) 103 | incremental_mode: existing masks are not expected to be supported by new masks, 104 | thus we only delete masks when they are not visible for too long, 105 | not when they are unsupported for too long 106 | """ 107 | mode = mode.lower() 108 | 109 | # separate the masks into one-hot format 110 | our_mask = our_mask.long() 111 | new_mask = new_mask.long() 112 | our_masks = {obj: (our_mask == tmp) for obj, tmp in object_manager.obj_to_tmp_id.items()} 113 | new_masks = {obj: (new_mask == obj.id) for obj in new_segments_info} 114 | 115 | if max_num_objects > 0 and len( 116 | object_manager.obj_to_tmp_id) + len(new_segments_info) > max_num_objects: 117 | # too many objects; forcibly deny all new objects 118 | warnings.warn( 119 | 'Number of objects exceeded maximum (--max_num_objects); discarding new objects') 120 | new_masks = {} 121 | new_segments_info = [] 122 | 123 | # pre-compute mask sums for IoU computation 124 | our_sums = {obj: m.sum() for m in our_masks for obj, m in our_masks.items()} 125 | new_sums = {obj: m.sum() for m in new_masks for obj, m in new_masks.items()} 126 | 127 | # matching 128 | merged_mask = torch.zeros_like(our_mask) 129 | match_isthing = [None, False, True] # for isthing 130 | # we merge stuff/things/others separately 131 | for isthing_status in match_isthing: 132 | if mode == 'iou': 133 | merged_mask = merge_by_iou(our_masks, new_masks, our_sums, new_sums, merged_mask, 134 | object_manager, new_segments_info, isthing_status, 135 | incremental_mode) 136 | elif mode == 'engulf': 137 | raise NotImplementedError('Engulf mode is deprecated') 138 | merged_mask = merge_by_engulf(our_masks, new_masks, our_sums, new_sums, merged_mask, 139 | object_manager, new_segments_info, isthing_status, 140 | engulf_threshold) 141 | 142 | merged_mask = object_manager.make_one_hot(merged_mask) 143 | return merged_mask 144 | -------------------------------------------------------------------------------- /deva/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/model/__init__.py -------------------------------------------------------------------------------- /deva/model/cbam.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class BasicConv(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, 
dilation=1, groups=1, bias=True): 9 | super(BasicConv, self).__init__() 10 | self.out_channels = out_planes 11 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) 12 | 13 | def forward(self, x): 14 | x = self.conv(x) 15 | return x 16 | 17 | class Flatten(nn.Module): 18 | def forward(self, x): 19 | return x.view(x.size(0), -1) 20 | 21 | class ChannelGate(nn.Module): 22 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): 23 | super(ChannelGate, self).__init__() 24 | self.gate_channels = gate_channels 25 | self.mlp = nn.Sequential( 26 | Flatten(), 27 | nn.Linear(gate_channels, gate_channels // reduction_ratio), 28 | nn.ReLU(), 29 | nn.Linear(gate_channels // reduction_ratio, gate_channels) 30 | ) 31 | self.pool_types = pool_types 32 | def forward(self, x): 33 | channel_att_sum = None 34 | for pool_type in self.pool_types: 35 | if pool_type=='avg': 36 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 37 | channel_att_raw = self.mlp( avg_pool ) 38 | elif pool_type=='max': 39 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) 40 | channel_att_raw = self.mlp( max_pool ) 41 | 42 | if channel_att_sum is None: 43 | channel_att_sum = channel_att_raw 44 | else: 45 | channel_att_sum = channel_att_sum + channel_att_raw 46 | 47 | scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) 48 | return x * scale 49 | 50 | class ChannelPool(nn.Module): 51 | def forward(self, x): 52 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) 53 | 54 | class SpatialGate(nn.Module): 55 | def __init__(self): 56 | super(SpatialGate, self).__init__() 57 | kernel_size = 7 58 | self.compress = ChannelPool() 59 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2) 60 | def forward(self, x): 61 | x_compress = self.compress(x) 62 | x_out = self.spatial(x_compress) 63 | scale = torch.sigmoid(x_out) # broadcasting 64 | return x * scale 65 | 66 | class CBAM(nn.Module): 67 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): 68 | super(CBAM, self).__init__() 69 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) 70 | self.no_spatial=no_spatial 71 | if not no_spatial: 72 | self.SpatialGate = SpatialGate() 73 | def forward(self, x): 74 | x_out = self.ChannelGate(x) 75 | if not self.no_spatial: 76 | x_out = self.SpatialGate(x_out) 77 | return x_out 78 | -------------------------------------------------------------------------------- /deva/model/group_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Group-specific modules 3 | They handle features that also depends on the mask. 4 | Features are typically of shape 5 | batch_size * num_objects * num_channels * H * W 6 | 7 | All of them are permutation equivariant w.r.t. 
to the num_objects dimension 8 | """ 9 | from typing import Optional 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from deva.model.cbam import CBAM 15 | 16 | 17 | def interpolate_groups(g: torch.Tensor, ratio: float, mode: str, align_corners: bool) -> torch.Tensor: 18 | batch_size, num_objects = g.shape[:2] 19 | g = F.interpolate(g.flatten(start_dim=0, end_dim=1), 20 | scale_factor=ratio, 21 | mode=mode, 22 | align_corners=align_corners) 23 | g = g.view(batch_size, num_objects, *g.shape[1:]) 24 | return g 25 | 26 | 27 | def upsample_groups(g: torch.Tensor, 28 | ratio: float = 2, 29 | mode: str = 'bilinear', 30 | align_corners: bool = False) -> torch.Tensor: 31 | return interpolate_groups(g, ratio, mode, align_corners) 32 | 33 | 34 | def downsample_groups(g: torch.Tensor, 35 | ratio: float = 1 / 2, 36 | mode: str = 'area', 37 | align_corners: bool = None) -> torch.Tensor: 38 | return interpolate_groups(g, ratio, mode, align_corners) 39 | 40 | 41 | class GConv2D(nn.Conv2d): 42 | def forward(self, g: torch.Tensor) -> torch.Tensor: 43 | batch_size, num_objects = g.shape[:2] 44 | g = super().forward(g.flatten(start_dim=0, end_dim=1)) 45 | return g.view(batch_size, num_objects, *g.shape[1:]) 46 | 47 | 48 | class GroupResBlock(nn.Module): 49 | def __init__(self, in_dim: int, out_dim: int): 50 | super().__init__() 51 | 52 | if in_dim == out_dim: 53 | self.downsample = None 54 | else: 55 | self.downsample = GConv2D(in_dim, out_dim, kernel_size=1) 56 | 57 | self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1) 58 | self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1) 59 | 60 | def forward(self, g: torch.Tensor) -> torch.Tensor: 61 | out_g = self.conv1(F.relu(g)) 62 | out_g = self.conv2(F.relu(out_g)) 63 | 64 | if self.downsample is not None: 65 | g = self.downsample(g) 66 | 67 | return out_g + g 68 | 69 | 70 | class ResMLP(nn.Module): 71 | def __init__(self, in_dim: int, out_dim: int): 72 | super().__init__() 73 | 74 | if in_dim == out_dim: 75 | self.downsample = None 76 | else: 77 | self.downsample = nn.Linear(in_dim, out_dim) 78 | 79 | self.conv1 = nn.Linear(in_dim, out_dim) 80 | self.conv2 = nn.Linear(out_dim, out_dim) 81 | 82 | def forward(self, g: torch.Tensor) -> torch.Tensor: 83 | out_g = self.conv1(F.relu(g)) 84 | out_g = self.conv2(F.relu(out_g)) 85 | 86 | if self.downsample is not None: 87 | g = self.downsample(g) 88 | 89 | return out_g + g 90 | 91 | 92 | class MainToGroupDistributor(nn.Module): 93 | def __init__(self, 94 | x_transform: Optional[nn.Module] = None, 95 | g_transform: Optional[nn.Module] = None, 96 | method: str = 'cat', 97 | reverse_order: bool = False): 98 | super().__init__() 99 | 100 | self.x_transform = x_transform 101 | self.g_transform = g_transform 102 | self.method = method 103 | self.reverse_order = reverse_order 104 | 105 | def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor: 106 | num_objects = g.shape[1] 107 | 108 | if self.x_transform is not None: 109 | x = self.x_transform(x) 110 | 111 | if self.g_transform is not None: 112 | g = self.g_transform(g) 113 | 114 | if not skip_expand: 115 | x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) 116 | if self.method == 'cat': 117 | if self.reverse_order: 118 | g = torch.cat([g, x], 2) 119 | else: 120 | g = torch.cat([x, g], 2) 121 | elif self.method == 'add': 122 | g = x + g 123 | elif self.method == 'mulcat': 124 | g = torch.cat([x * g, g], dim=2) 125 | elif self.method == 'muladd': 126 | g = x * g + g 127 
| else: 128 | raise NotImplementedError 129 | 130 | return g 131 | 132 | 133 | class GroupFeatureFusionBlock(nn.Module): 134 | def __init__(self, x_in_dim: int, g_in_dim: int, g_mid_dim: int, g_out_dim: int): 135 | super().__init__() 136 | 137 | self.distributor = MainToGroupDistributor() 138 | self.block1 = GroupResBlock(x_in_dim + g_in_dim, g_mid_dim) 139 | self.attention = CBAM(g_mid_dim) 140 | self.block2 = GroupResBlock(g_mid_dim, g_out_dim) 141 | 142 | def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: 143 | batch_size, num_objects = g.shape[:2] 144 | 145 | g = self.distributor(x, g) 146 | g = self.block1(g) 147 | r = self.attention(g.flatten(start_dim=0, end_dim=1)) 148 | r = r.view(batch_size, num_objects, *r.shape[1:]) 149 | 150 | g = self.block2(g + r) 151 | 152 | return g -------------------------------------------------------------------------------- /deva/model/losses.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from collections import defaultdict 7 | 8 | 9 | def dice_loss(input_mask, cls_gt) -> torch.Tensor: 10 | num_objects = input_mask.shape[1] 11 | losses = [] 12 | for i in range(num_objects): 13 | mask = input_mask[:, i].flatten(start_dim=1) 14 | # background not in mask, so we add one to cls_gt 15 | gt = (cls_gt == (i + 1)).float().flatten(start_dim=1) 16 | numerator = 2 * (mask * gt).sum(-1) 17 | denominator = mask.sum(-1) + gt.sum(-1) 18 | loss = 1 - (numerator + 1) / (denominator + 1) 19 | losses.append(loss) 20 | return torch.cat(losses).mean() 21 | 22 | 23 | # https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch 24 | class BootstrappedCE(nn.Module): 25 | def __init__(self, start_warm, end_warm, top_p=0.3): 26 | super().__init__() 27 | 28 | self.start_warm = start_warm 29 | self.end_warm = end_warm 30 | self.top_p = top_p 31 | 32 | def forward(self, input, target, it) -> (torch.Tensor, float): 33 | if it < self.start_warm: 34 | return F.cross_entropy(input, target), 1.0 35 | 36 | raw_loss = F.cross_entropy(input, target, reduction='none').view(-1) 37 | num_pixels = raw_loss.numel() 38 | 39 | if it > self.end_warm: 40 | this_p = self.top_p 41 | else: 42 | this_p = self.top_p + (1 - self.top_p) * ((self.end_warm - it) / 43 | (self.end_warm - self.start_warm)) 44 | loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False) 45 | return loss.mean(), this_p 46 | 47 | 48 | class LossComputer: 49 | def __init__(self, config): 50 | super().__init__() 51 | self.config = config 52 | self.bce = BootstrappedCE(config['start_warm'], config['end_warm']) 53 | 54 | def compute(self, data, num_objects, it) -> Dict[str, torch.Tensor]: 55 | losses = defaultdict(int) 56 | 57 | b, t = data['rgb'].shape[:2] 58 | 59 | losses['total_loss'] = 0 60 | for ti in range(1, t): 61 | for bi in range(b): 62 | loss, p = self.bce(data[f'logits_{ti}'][bi:bi + 1, :num_objects[bi] + 1], 63 | data['cls_gt'][bi:bi + 1, ti, 0], it) 64 | 65 | aux_loss = F.cross_entropy( 66 | data[f'aux_logits_{ti}'][bi:bi + 1, :num_objects[bi] + 1, 0], 67 | data['cls_gt'][bi:bi + 1, ti, 0]) 68 | 69 | losses['p'] += p / b / (t - 1) 70 | losses[f'ce_loss_{ti}'] += loss / b 71 | losses[f'aux_loss_{ti}'] += aux_loss / b 72 | 73 | losses['total_loss'] += losses['ce_loss_%d' % ti] 74 | losses['total_loss'] += losses['aux_loss_%d' % ti] * 0.1 75 | losses[f'dice_loss_{ti}'] = 
dice_loss(data[f'masks_{ti}'], data['cls_gt'][:, ti, 0]) 76 | losses['total_loss'] += losses[f'dice_loss_{ti}'] 77 | 78 | return losses 79 | -------------------------------------------------------------------------------- /deva/model/memory_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from typing import Optional, Union, Tuple 4 | 5 | 6 | def get_similarity(mk: torch.Tensor, 7 | ms: torch.Tensor, 8 | qk: torch.Tensor, 9 | qe: torch.Tensor, 10 | add_batch_dim=False) -> torch.Tensor: 11 | # used for training/inference and memory reading/memory potentiation 12 | # mk: B x CK x [N] - Memory keys 13 | # ms: B x 1 x [N] - Memory shrinkage 14 | # qk: B x CK x [HW/P] - Query keys 15 | # qe: B x CK x [HW/P] - Query selection 16 | # Dimensions in [] are flattened 17 | if add_batch_dim: 18 | mk, ms = mk.unsqueeze(0), ms.unsqueeze(0) 19 | qk, qe = qk.unsqueeze(0), qe.unsqueeze(0) 20 | 21 | CK = mk.shape[1] 22 | mk = mk.flatten(start_dim=2) 23 | ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None 24 | qk = qk.flatten(start_dim=2) 25 | qe = qe.flatten(start_dim=2) if qe is not None else None 26 | 27 | if qe is not None: 28 | # See XMem's appendix for derivation 29 | mk = mk.transpose(1, 2) 30 | a_sq = (mk.pow(2) @ qe) 31 | two_ab = 2 * (mk @ (qk * qe)) 32 | b_sq = (qe * qk.pow(2)).sum(1, keepdim=True) 33 | similarity = (-a_sq + two_ab - b_sq) 34 | else: 35 | # similar to STCN if we don't have the selection term 36 | a_sq = mk.pow(2).sum(1).unsqueeze(2) 37 | two_ab = 2 * (mk.transpose(1, 2) @ qk) 38 | similarity = (-a_sq + two_ab) 39 | 40 | if ms is not None: 41 | similarity = similarity * ms / math.sqrt(CK) # B*N*HW 42 | else: 43 | similarity = similarity / math.sqrt(CK) # B*N*HW 44 | 45 | return similarity 46 | 47 | 48 | def do_softmax( 49 | similarity: torch.Tensor, 50 | top_k: Optional[int] = None, 51 | inplace: bool = False, 52 | return_usage: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: 53 | # normalize similarity with top-k softmax 54 | # similarity: B x N x [HW/P] 55 | # use inplace with care 56 | if top_k is not None: 57 | values, indices = torch.topk(similarity, k=top_k, dim=1) 58 | 59 | x_exp = values.exp_() 60 | x_exp /= torch.sum(x_exp, dim=1, keepdim=True) 61 | if inplace: 62 | similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW 63 | affinity = similarity 64 | else: 65 | affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW 66 | else: 67 | maxes = torch.max(similarity, dim=1, keepdim=True)[0] 68 | x_exp = torch.exp(similarity - maxes) 69 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) 70 | affinity = x_exp / x_exp_sum 71 | indices = None 72 | 73 | if return_usage: 74 | return affinity, affinity.sum(dim=2) 75 | 76 | return affinity 77 | 78 | 79 | def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor, 80 | qe: torch.Tensor) -> torch.Tensor: 81 | # shorthand used in training with no top-k 82 | similarity = get_similarity(mk, ms, qk, qe) 83 | affinity = do_softmax(similarity) 84 | return affinity 85 | 86 | 87 | def readout(affinity: torch.Tensor, mv: torch.Tensor) -> torch.Tensor: 88 | B, CV, T, H, W = mv.shape 89 | 90 | mo = mv.view(B, CV, T * H * W) 91 | mem = torch.bmm(mo, affinity) 92 | mem = mem.view(B, CV, H, W) 93 | 94 | return mem 95 | -------------------------------------------------------------------------------- /deva/model/modules.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | modules.py - This file stores low-level network blocks. 3 | 4 | x - usually means features that only depends on the image 5 | g - usually means features that also depends on the mask. 6 | They might have an extra "group" or "num_objects" dimension, hence 7 | batch_size * num_objects * num_channels * H * W 8 | 9 | The trailing number of a variable usually denote the stride 10 | 11 | """ 12 | 13 | from typing import List, Iterable 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | 18 | from deva.model.group_modules import * 19 | from deva.model.cbam import CBAM 20 | 21 | 22 | class ResBlock(nn.Module): 23 | def __init__(self, in_dim: int, out_dim: int): 24 | super().__init__() 25 | 26 | if in_dim == out_dim: 27 | self.downsample = None 28 | else: 29 | self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) 30 | 31 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1) 32 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1) 33 | 34 | def forward(self, f: torch.Tensor) -> torch.Tensor: 35 | out_f = self.conv1(F.relu(f)) 36 | out_f = self.conv2(F.relu(out_f)) 37 | 38 | if self.downsample is not None: 39 | f = self.downsample(f) 40 | 41 | return out_f + f 42 | 43 | 44 | class FeatureFusionBlock(nn.Module): 45 | def __init__(self, in_dim: int, mid_dim: int, out_dim: int): 46 | super().__init__() 47 | 48 | self.block1 = ResBlock(in_dim, mid_dim) 49 | self.attention = CBAM(mid_dim) 50 | self.block2 = ResBlock(mid_dim, out_dim) 51 | 52 | def forward(self, x: torch.Tensor) -> torch.Tensor: 53 | x = self.block1(x) 54 | r = self.attention(x) 55 | x = self.block2(x + r) 56 | 57 | return x 58 | 59 | 60 | class KeyProjection(nn.Module): 61 | def __init__(self, in_dim: int, keydim: int): 62 | super().__init__() 63 | 64 | self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) 65 | # shrinkage 66 | self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1) 67 | # selection 68 | self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1) 69 | 70 | nn.init.orthogonal_(self.key_proj.weight.data) 71 | nn.init.zeros_(self.key_proj.bias.data) 72 | 73 | def forward(self, x: torch.Tensor, *, need_s: bool, 74 | need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor): 75 | shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None 76 | selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None 77 | 78 | return self.key_proj(x), shrinkage, selection 79 | 80 | 81 | class MaskUpsampleBlock(nn.Module): 82 | def __init__(self, up_dim: int, out_dim: int, scale_factor: int = 2): 83 | super().__init__() 84 | self.distributor = MainToGroupDistributor(method='add') 85 | self.out_conv = GroupResBlock(up_dim, out_dim) 86 | self.scale_factor = scale_factor 87 | 88 | def forward(self, skip_f: torch.Tensor, up_g: torch.Tensor) -> torch.Tensor: 89 | g = upsample_groups(up_g, ratio=self.scale_factor) 90 | g = self.distributor(skip_f, g) 91 | g = self.out_conv(g) 92 | return g 93 | 94 | 95 | class DecoderFeatureProcessor(nn.Module): 96 | def __init__(self, decoder_dims: List[int], out_dims: List[int]): 97 | super().__init__() 98 | self.transforms = nn.ModuleList([ 99 | nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims) 100 | ]) 101 | 102 | def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]: 103 | outputs = [func(x) for x, func in zip(multi_scale_features, 
self.transforms)] 104 | return outputs 105 | 106 | 107 | class LinearPredictor(nn.Module): 108 | def __init__(self, in_dim: int, pred_dim: int): 109 | super().__init__() 110 | self.projection = GConv2D(in_dim, pred_dim + 1, kernel_size=1) 111 | 112 | def forward(self, im_feat: torch.Tensor, pred_feat: torch.Tensor) -> torch.Tensor: 113 | num_objects = pred_feat.shape[1] 114 | parameters = self.projection(pred_feat) 115 | 116 | im_feat = im_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1) 117 | x = (im_feat * parameters[:, :, :-1]).sum(dim=2, keepdim=True) + parameters[:, :, -1:] 118 | return x 119 | 120 | 121 | class SensoryUpdater(nn.Module): 122 | # Used in the decoder, multi-scale feature + GRU 123 | def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int): 124 | super().__init__() 125 | self.sensory_dim = sensory_dim 126 | 127 | self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1) 128 | self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1) 129 | self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1) 130 | 131 | self.transform = GConv2D(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) 132 | 133 | nn.init.xavier_normal_(self.transform.weight) 134 | 135 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor: 136 | g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \ 137 | self.g4_conv(downsample_groups(g[2], ratio=1/4)) 138 | 139 | g = torch.cat([g, h], 2) 140 | 141 | # defined slightly differently than standard GRU, 142 | # namely the new value is generated before the forget gate. 143 | # might provide better gradient but frankly it was initially just an 144 | # implementation error that I never bothered fixing 145 | values = self.transform(g) 146 | forget_gate = torch.sigmoid(values[:, :, :self.sensory_dim]) 147 | update_gate = torch.sigmoid(values[:, :, self.sensory_dim:self.sensory_dim * 2]) 148 | new_value = torch.tanh(values[:, :, self.sensory_dim * 2:]) 149 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value 150 | 151 | return new_h 152 | 153 | 154 | class SensoryDeepUpdater(nn.Module): 155 | def __init__(self, f_dim: int, sensory_dim: int): 156 | super().__init__() 157 | self.sensory_dim = sensory_dim 158 | self.transform = GConv2D(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1) 159 | 160 | nn.init.xavier_normal_(self.transform.weight) 161 | 162 | def forward(self, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor: 163 | values = self.transform(torch.cat([f, h], dim=2)) 164 | forget_gate = torch.sigmoid(values[:, :, :self.sensory_dim]) 165 | update_gate = torch.sigmoid(values[:, :, self.sensory_dim:self.sensory_dim * 2]) 166 | new_value = torch.tanh(values[:, :, self.sensory_dim * 2:]) 167 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value 168 | 169 | return new_h 170 | -------------------------------------------------------------------------------- /deva/model/resnet.py: -------------------------------------------------------------------------------- 1 | """ 2 | resnet.py - A modified ResNet structure 3 | We append extra channels to the first conv by some network surgery 4 | """ 5 | 6 | from collections import OrderedDict 7 | import math 8 | 9 | import torch 10 | import torch.nn as nn 11 | from torch.utils import model_zoo 12 | 13 | 14 | def load_weights_add_extra_dim(target, source_state, extra_dim=1): 15 | new_dict = OrderedDict() 16 | 17 | for k1, v1 in target.state_dict().items(): 18 | if not 'num_batches_tracked' in k1: 19 | if k1 
in source_state: 20 | tar_v = source_state[k1] 21 | 22 | if v1.shape != tar_v.shape: 23 | # Init the new segmentation channel with zeros 24 | # print(v1.shape, tar_v.shape) 25 | c, _, w, h = v1.shape 26 | pads = torch.zeros((c,extra_dim,w,h), device=tar_v.device) 27 | nn.init.orthogonal_(pads) 28 | tar_v = torch.cat([tar_v, pads], 1) 29 | 30 | new_dict[k1] = tar_v 31 | 32 | target.load_state_dict(new_dict) 33 | 34 | 35 | model_urls = { 36 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 37 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 38 | } 39 | 40 | 41 | def conv3x3(in_planes, out_planes, stride=1, dilation=1): 42 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 43 | padding=dilation, dilation=dilation, bias=False) 44 | 45 | 46 | class BasicBlock(nn.Module): 47 | expansion = 1 48 | 49 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 50 | super(BasicBlock, self).__init__() 51 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) 52 | self.bn1 = nn.BatchNorm2d(planes) 53 | self.relu = nn.ReLU(inplace=True) 54 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) 55 | self.bn2 = nn.BatchNorm2d(planes) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | 69 | if self.downsample is not None: 70 | residual = self.downsample(x) 71 | 72 | out += residual 73 | out = self.relu(out) 74 | 75 | return out 76 | 77 | 78 | class Bottleneck(nn.Module): 79 | expansion = 4 80 | 81 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 82 | super(Bottleneck, self).__init__() 83 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 84 | self.bn1 = nn.BatchNorm2d(planes) 85 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, 86 | padding=dilation, bias=False) 87 | self.bn2 = nn.BatchNorm2d(planes) 88 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 89 | self.bn3 = nn.BatchNorm2d(planes * 4) 90 | self.relu = nn.ReLU(inplace=True) 91 | self.downsample = downsample 92 | self.stride = stride 93 | 94 | def forward(self, x): 95 | residual = x 96 | 97 | out = self.conv1(x) 98 | out = self.bn1(out) 99 | out = self.relu(out) 100 | 101 | out = self.conv2(out) 102 | out = self.bn2(out) 103 | out = self.relu(out) 104 | 105 | out = self.conv3(out) 106 | out = self.bn3(out) 107 | 108 | if self.downsample is not None: 109 | residual = self.downsample(x) 110 | 111 | out += residual 112 | out = self.relu(out) 113 | 114 | return out 115 | 116 | 117 | class ResNet(nn.Module): 118 | def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0): 119 | self.inplanes = 64 120 | super(ResNet, self).__init__() 121 | self.conv1 = nn.Conv2d(3+extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False) 122 | self.bn1 = nn.BatchNorm2d(64) 123 | self.relu = nn.ReLU(inplace=True) 124 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 125 | self.layer1 = self._make_layer(block, 64, layers[0]) 126 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 127 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 128 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | n = 
m.kernel_size[0] * m.kernel_size[1] * m.out_channels 133 | m.weight.data.normal_(0, math.sqrt(2. / n)) 134 | elif isinstance(m, nn.BatchNorm2d): 135 | m.weight.data.fill_(1) 136 | m.bias.data.zero_() 137 | 138 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1): 139 | downsample = None 140 | if stride != 1 or self.inplanes != planes * block.expansion: 141 | downsample = nn.Sequential( 142 | nn.Conv2d(self.inplanes, planes * block.expansion, 143 | kernel_size=1, stride=stride, bias=False), 144 | nn.BatchNorm2d(planes * block.expansion), 145 | ) 146 | 147 | layers = [block(self.inplanes, planes, stride, downsample)] 148 | self.inplanes = planes * block.expansion 149 | for i in range(1, blocks): 150 | layers.append(block(self.inplanes, planes, dilation=dilation)) 151 | 152 | return nn.Sequential(*layers) 153 | 154 | def resnet18(pretrained=True, extra_dim=0): 155 | model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim) 156 | if pretrained: 157 | load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim) 158 | return model 159 | 160 | def resnet50(pretrained=True, extra_dim=0): 161 | model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim) 162 | if pretrained: 163 | load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim) 164 | return model 165 | 166 | -------------------------------------------------------------------------------- /deva/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/utils/__init__.py -------------------------------------------------------------------------------- /deva/utils/configuration.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | class Configuration(): 5 | def parse(self, unknown_arg_ok=False): 6 | parser = ArgumentParser() 7 | 8 | # Enable torch.backends.cudnn.benchmark -- Faster in some cases, test in your own environment 9 | parser.add_argument('--benchmark', action='store_true') 10 | # AMP causes problems in training. Not recommended. 
11 | parser.add_argument('--amp', action='store_true') 12 | 13 | # Data parameters 14 | parser.add_argument('--static_root', help='Static training data root', default='../static') 15 | parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K') 16 | parser.add_argument('--yv_root', help='YouTubeVOS data root', default='../YouTube') 17 | parser.add_argument('--davis_root', help='DAVIS data root', default='../DAVIS') 18 | parser.add_argument('--ovis_root', help='OVIS data root', default='../OVIS-VOS-train') 19 | parser.add_argument('--num_workers', 20 | help='Total number of dataloader workers across all GPUs processes', 21 | type=int, 22 | default=16) 23 | parser.add_argument('--video_data_ratio', default=1.0, type=float) 24 | 25 | parser.add_argument('--pix_feat_dim', default=512, type=int) 26 | parser.add_argument('--key_dim', default=64, type=int) 27 | parser.add_argument('--value_dim', default=512, type=int) 28 | 29 | parser.add_argument('--deep_update_prob', default=0.2, type=float) 30 | 31 | # stage 1 and 2 are used previously in MiVOS/STCN/XMem, and not used here 32 | # we do not train on BL30K for this project, but we keep the naming (s03) for consistency 33 | parser.add_argument('--stages', 34 | help='Training stage (0-static images, 3-DAVIS+YouTubeVOS+OVIS)', 35 | default='03') 36 | parser.add_argument('--clip_grad_norm', 37 | help='Clip norm of global gradient at', 38 | default=3.0, 39 | type=float) 40 | """ 41 | Stage-specific learning parameters 42 | Batch sizes are effective -- you don't have to scale them when you scale the number processes 43 | """ 44 | # Stage 0, static images 45 | parser.add_argument('--s0_batch_size', default=16, type=int) 46 | parser.add_argument('--s0_iterations', default=80000, type=int) 47 | parser.add_argument('--s0_steps', nargs="*", default=[], type=int) 48 | parser.add_argument('--s0_lr', help='Initial learning rate', default=2e-5, type=float) 49 | parser.add_argument('--s0_num_ref_frames', default=2, type=int) 50 | parser.add_argument('--s0_num_frames', default=3, type=int) 51 | parser.add_argument('--s0_start_warm', default=10000, type=int) 52 | parser.add_argument('--s0_end_warm', default=35000, type=int) 53 | parser.add_argument('--s0_schedule', default='constant') 54 | 55 | # Stage 3, DAVIS+YoutubeVOS+OVIS 56 | parser.add_argument('--s3_batch_size', default=16, type=int) 57 | parser.add_argument('--s3_iterations', default=150000, type=int) 58 | parser.add_argument('--s3_steps', nargs="*", default=[120000, 140000], type=int) 59 | parser.add_argument('--s3_lr', help='Initial learning rate', default=1e-5, type=float) 60 | parser.add_argument('--s3_num_ref_frames', default=3, type=int) 61 | parser.add_argument('--s3_num_frames', default=8, type=int) 62 | parser.add_argument('--s3_start_warm', default=10000, type=int) 63 | parser.add_argument('--s3_end_warm', default=35000, type=int) 64 | parser.add_argument('--s3_schedule', default='step') 65 | 66 | parser.add_argument('--gamma', 67 | help='LR := LR*gamma at every decay step', 68 | default=0.1, 69 | type=float) 70 | parser.add_argument('--weight_decay', default=0.001, type=float) 71 | 72 | # Loading 73 | parser.add_argument('--load_network', help='Path to a pretrained network weights file. ') 74 | parser.add_argument( 75 | '--load_checkpoint', 76 | help='Path to the checkpoint file which contains network weights, optimizer, scheduler, ' 77 | 'and the number of iterations. 
This is used for resuming interrupted training.') 78 | 79 | # Logging information 80 | parser.add_argument('--log_text_interval', default=100, type=int) 81 | parser.add_argument('--log_image_interval', default=1500, type=int) 82 | parser.add_argument('--save_network_interval', default=50000, type=int) 83 | parser.add_argument('--save_checkpoint_interval', default=50000, type=int) 84 | parser.add_argument('--exp_id', 85 | help='UNIQUE for a training run, set to NULL to disable logging', 86 | default='NULL') 87 | parser.add_argument('--debug', 88 | help='Debug mode; logs info at every iteration', 89 | action='store_true') 90 | 91 | if unknown_arg_ok: 92 | args, _ = parser.parse_known_args() 93 | self.args = vars(args) 94 | else: 95 | self.args = vars(parser.parse_args()) 96 | 97 | # check if the stages are valid 98 | stage_to_perform = list(self.args['stages']) 99 | for s in stage_to_perform: 100 | if s not in ['0', '3']: 101 | raise NotImplementedError 102 | 103 | def get_stage_parameters(self, stage): 104 | parameters = { 105 | 'batch_size': self.args['s%s_batch_size' % stage], 106 | 'iterations': self.args['s%s_iterations' % stage], 107 | 'steps': self.args['s%s_steps' % stage], 108 | 'schedule': self.args['s%s_schedule' % stage], 109 | 'lr': self.args['s%s_lr' % stage], 110 | 'num_ref_frames': self.args['s%s_num_ref_frames' % stage], 111 | 'num_frames': self.args['s%s_num_frames' % stage], 112 | 'start_warm': self.args['s%s_start_warm' % stage], 113 | 'end_warm': self.args['s%s_end_warm' % stage], 114 | } 115 | 116 | return parameters 117 | 118 | def __getitem__(self, key): 119 | return self.args[key] 120 | 121 | def __setitem__(self, key, value): 122 | self.args[key] = value 123 | 124 | def __str__(self): 125 | return str(self.args) 126 | -------------------------------------------------------------------------------- /deva/utils/davis_subset.txt: -------------------------------------------------------------------------------- 1 | bear 2 | bmx-bumps 3 | boat 4 | boxing-fisheye 5 | breakdance-flare 6 | bus 7 | car-turn 8 | cat-girl 9 | classic-car 10 | color-run 11 | crossing 12 | dance-jump 13 | dancing 14 | disc-jockey 15 | dog-agility 16 | dog-gooses 17 | dogs-scale 18 | drift-turn 19 | drone 20 | elephant 21 | flamingo 22 | hike 23 | hockey 24 | horsejump-low 25 | kid-football 26 | kite-walk 27 | koala 28 | lady-running 29 | lindy-hop 30 | longboard 31 | lucia 32 | mallard-fly 33 | mallard-water 34 | miami-surf 35 | motocross-bumps 36 | motorbike 37 | night-race 38 | paragliding 39 | planes-water 40 | rallye 41 | rhino 42 | rollerblade 43 | schoolgirls 44 | scooter-board 45 | scooter-gray 46 | sheep 47 | skate-park 48 | snowboard 49 | soccerball 50 | stroller 51 | stunt 52 | surf 53 | swing 54 | tennis 55 | tractor-sand 56 | train 57 | tuk-tuk 58 | upside-down 59 | varanus-cage 60 | walking -------------------------------------------------------------------------------- /deva/utils/image_saver.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | from deva.dataset.utils import inv_im_trans 6 | from collections import defaultdict 7 | 8 | 9 | def tensor_to_numpy(image): 10 | image_np = (image.numpy() * 255).astype('uint8') 11 | return image_np 12 | 13 | 14 | def tensor_to_np_float(image): 15 | image_np = image.numpy().astype('float32') 16 | return image_np 17 | 18 | 19 | def detach_to_cpu(x): 20 | return x.detach().cpu() 21 | 22 | 23 | def transpose_np(x): 24 | return np.transpose(x, [1, 2, 
0]) 25 | 26 | 27 | def tensor_to_gray_im(x): 28 | x = detach_to_cpu(x) 29 | x = tensor_to_numpy(x) 30 | x = transpose_np(x) 31 | return x 32 | 33 | 34 | def tensor_to_im(x): 35 | x = detach_to_cpu(x) 36 | x = inv_im_trans(x).clamp(0, 1) 37 | x = tensor_to_numpy(x) 38 | x = transpose_np(x) 39 | return x 40 | 41 | 42 | # Predefined key <-> caption dict 43 | key_captions = { 44 | 'im': 'Image', 45 | 'gt': 'GT', 46 | } 47 | """ 48 | Return an image array with captions 49 | keys in dictionary will be used as caption if not provided 50 | values should contain lists of cv2 images 51 | """ 52 | 53 | 54 | def get_image_array(images, grid_shape, captions={}): 55 | h, w = grid_shape 56 | cate_counts = len(images) 57 | rows_counts = len(next(iter(images.values()))) 58 | 59 | font = cv2.FONT_HERSHEY_SIMPLEX 60 | 61 | output_image = np.zeros([w * cate_counts, h * (rows_counts + 1), 3], dtype=np.uint8) 62 | col_cnt = 0 63 | for k, v in images.items(): 64 | 65 | # Default as key value itself 66 | caption = captions.get(k, k) 67 | 68 | # Handles new line character 69 | dy = 40 70 | for i, line in enumerate(caption.split('\n')): 71 | cv2.putText(output_image, line, (10, col_cnt * w + 100 + i * dy), font, 0.8, 72 | (255, 255, 255), 2, cv2.LINE_AA) 73 | 74 | # Put images 75 | for row_cnt, img in enumerate(v): 76 | im_shape = img.shape 77 | if len(im_shape) == 2: 78 | img = img[..., np.newaxis] 79 | 80 | img = (img * 255).astype('uint8') 81 | 82 | output_image[(col_cnt + 0) * w:(col_cnt + 1) * w, 83 | (row_cnt + 1) * h:(row_cnt + 2) * h, :] = img 84 | 85 | col_cnt += 1 86 | 87 | return output_image 88 | 89 | 90 | def base_transform(im, size): 91 | im = tensor_to_np_float(im) 92 | if len(im.shape) == 3: 93 | im = im.transpose((1, 2, 0)) 94 | else: 95 | im = im[:, :, None] 96 | 97 | # Resize 98 | if im.shape[1] != size: 99 | im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST) 100 | 101 | return im.clip(0, 1) 102 | 103 | 104 | def im_transform(im, size): 105 | return base_transform(inv_im_trans(detach_to_cpu(im)), size=size) 106 | 107 | 108 | def mask_transform(mask, size): 109 | return base_transform(detach_to_cpu(mask), size=size) 110 | 111 | 112 | def logits_transform(mask, size): 113 | return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size) 114 | 115 | 116 | def pool_pairs(images, size, num_objects): 117 | req_images = defaultdict(list) 118 | 119 | b, t = images['rgb'].shape[:2] 120 | 121 | # limit the number of images saved 122 | b = min(2, b) 123 | 124 | # find max num objects 125 | max_num_objects = max(num_objects[:b]) 126 | 127 | GT_suffix = '' 128 | for bi in range(b): 129 | GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4] 130 | 131 | for bi in range(b): 132 | for ti in range(t): 133 | req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size)) 134 | for oi in range(max_num_objects): 135 | if ti == 0 or oi >= num_objects[bi]: 136 | req_images[f'Mask_{oi}'].append( 137 | mask_transform(images['first_frame_gt'][bi][0, oi], size)) 138 | req_images[f'Aux Mask_{oi}'].append( 139 | mask_transform(images['first_frame_gt'][bi][0, oi], size)) 140 | else: 141 | req_images[f'Mask_{oi}'].append( 142 | mask_transform(images[f'masks_{ti}'][bi][oi], size)) 143 | req_images[f'Aux Mask_{oi}'].append( 144 | mask_transform(images[f'aux_masks_{ti}'][bi][oi][0], size)) 145 | req_images[f'GT_{oi}_{GT_suffix}'].append( 146 | mask_transform(images['cls_gt'][bi, ti, 0] == (oi + 1), size)) 147 | 148 | return get_image_array(req_images, size, key_captions) 
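# --- Hypothetical usage sketch (illustration only; the sizes, names, and output file below are arbitrary) ---
# get_image_array lays out one horizontal band per dictionary key: the first cell of each
# band holds the caption text and the remaining cells hold the images in list order.
# Values are expected as float images in [0, 1]; the returned grid is uint8.
if __name__ == '__main__':
    _h = _w = 128
    _ims = [np.random.rand(_w, _h, 3) for _ in range(3)]
    _grid = get_image_array({'im': _ims, 'gt': _ims}, (_h, _w), key_captions)
    assert _grid.shape == (2 * _w, (3 + 1) * _h, 3)  # 2 keys; 3 images + 1 caption column per band
    cv2.imwrite('grid_preview.png', _grid)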
-------------------------------------------------------------------------------- /deva/utils/load_subset.py: -------------------------------------------------------------------------------- 1 | """ 2 | load_subset.py - Presents a subset of data 3 | DAVIS - only the training set 4 | YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all 5 | """ 6 | 7 | 8 | def load_sub_davis(path='deva/utils/davis_subset.txt'): 9 | with open(path, mode='r') as f: 10 | subset = set(f.read().splitlines()) 11 | return subset 12 | 13 | 14 | def load_sub_yv(path='deva/utils/yv_subset.txt'): 15 | with open(path, mode='r') as f: 16 | subset = set(f.read().splitlines()) 17 | return subset 18 | 19 | 20 | def load_referring_yv_val(path='deva/utils/referring-youtubevos-val.txt'): 21 | with open(path, mode='r') as f: 22 | subset = set(f.read().splitlines()) 23 | return subset 24 | -------------------------------------------------------------------------------- /deva/utils/log_integrator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Integrate numerical values for some iterations 3 | Typically used for loss computation / logging to tensorboard 4 | Call finalize and create a new Integrator when you want to display/log 5 | """ 6 | from typing import Dict, Callable, Tuple 7 | import torch 8 | from deva.utils.logger import TensorboardLogger 9 | 10 | 11 | class Integrator: 12 | def __init__(self, logger: TensorboardLogger, distributed: bool = True): 13 | self.values = {} 14 | self.counts = {} 15 | self.hooks = [] # List is used here to maintain insertion order 16 | 17 | self.logger = logger 18 | 19 | self.distributed = distributed 20 | self.local_rank = torch.distributed.get_rank() 21 | self.world_size = torch.distributed.get_world_size() 22 | 23 | def add_tensor(self, key: str, tensor: torch.Tensor): 24 | if key not in self.values: 25 | self.counts[key] = 1 26 | if type(tensor) == float or type(tensor) == int: 27 | self.values[key] = tensor 28 | else: 29 | self.values[key] = tensor.mean().item() 30 | else: 31 | self.counts[key] += 1 32 | if type(tensor) == float or type(tensor) == int: 33 | self.values[key] += tensor 34 | else: 35 | self.values[key] += tensor.mean().item() 36 | 37 | def add_dict(self, tensor_dict: Dict[str, torch.Tensor]): 38 | for k, v in tensor_dict.items(): 39 | self.add_tensor(k, v) 40 | 41 | def add_hook(self, hook: Callable[[torch.Tensor], Tuple[str, torch.Tensor]]): 42 | """ 43 | Adds a custom hook, i.e. compute new metrics using values in the dict 44 | The hook takes the dict as argument, and returns a (k, v) tuple 45 | e.g. 
for computing IoU 46 | """ 47 | if type(hook) == list: 48 | self.hooks.extend(hook) 49 | else: 50 | self.hooks.append(hook) 51 | 52 | def reset_except_hooks(self): 53 | self.values = {} 54 | self.counts = {} 55 | 56 | # Average and output the metrics 57 | def finalize(self, prefix: str, it: int, f=None) -> None: 58 | 59 | for hook in self.hooks: 60 | k, v = hook(self.values) 61 | self.add_tensor(k, v) 62 | 63 | for k, v in self.values.items(): 64 | 65 | if k[:4] == 'hide': 66 | continue 67 | 68 | avg = v / self.counts[k] 69 | 70 | if self.distributed: 71 | # Inplace operation 72 | avg = torch.tensor(avg).cuda() 73 | torch.distributed.reduce(avg, dst=0) 74 | 75 | if self.local_rank == 0: 76 | avg = (avg / self.world_size).cpu().item() 77 | self.logger.log_metrics(prefix, k, avg, it, f) 78 | else: 79 | # Simple does it 80 | self.logger.log_metrics(prefix, k, avg, it, f) -------------------------------------------------------------------------------- /deva/utils/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dumps things to tensorboard and console 3 | """ 4 | 5 | import os 6 | import warnings 7 | 8 | import torchvision.transforms as transforms 9 | from torch.utils.tensorboard import SummaryWriter 10 | 11 | 12 | def tensor_to_numpy(image): 13 | image_np = (image.numpy() * 255).astype('uint8') 14 | return image_np 15 | 16 | 17 | def detach_to_cpu(x): 18 | return x.detach().cpu() 19 | 20 | 21 | def fix_width_trunc(x): 22 | return ('{:.9s}'.format('{:0.9f}'.format(x))) 23 | 24 | 25 | class TensorboardLogger: 26 | def __init__(self, short_id, id): 27 | self.short_id = short_id 28 | if self.short_id == 'NULL': 29 | self.short_id = 'DEBUG' 30 | 31 | if id is None: 32 | self.no_log = True 33 | warnings.warn('Logging has been disabled.') 34 | else: 35 | self.no_log = False 36 | 37 | self.inv_im_trans = transforms.Normalize( 38 | mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225], 39 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225]) 40 | 41 | self.inv_seg_trans = transforms.Normalize(mean=[-0.5 / 0.5], std=[1 / 0.5]) 42 | 43 | log_path = os.path.join('.', 'saves', '%s' % id) 44 | self.logger = SummaryWriter(log_path) 45 | 46 | # Get current git info for logging 47 | try: 48 | import git 49 | repo = git.Repo(".") 50 | git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha) 51 | except (ImportError, RuntimeError) as e: 52 | print('Failed to fetch git info. 
Defaulting to None') 53 | git_info = 'None' 54 | 55 | self.log_string('git', git_info) 56 | 57 | def log_scalar(self, tag, x, step): 58 | if self.no_log: 59 | warnings.warn('Logging has been disabled.') 60 | return 61 | self.logger.add_scalar(tag, x, step) 62 | 63 | def log_metrics(self, l1_tag, l2_tag, val, step, f=None): 64 | tag = l1_tag + '/' + l2_tag 65 | text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), 66 | l2_tag, fix_width_trunc(val)) 67 | print(text) 68 | if f is not None: 69 | f.write(text + '\n') 70 | f.flush() 71 | self.log_scalar(tag, val, step) 72 | 73 | def log_im(self, tag, x, step): 74 | if self.no_log: 75 | warnings.warn('Logging has been disabled.') 76 | return 77 | x = detach_to_cpu(x) 78 | x = self.inv_im_trans(x) 79 | x = tensor_to_numpy(x) 80 | self.logger.add_image(tag, x, step) 81 | 82 | def log_cv2(self, tag, x, step): 83 | if self.no_log: 84 | warnings.warn('Logging has been disabled.') 85 | return 86 | x = x.transpose((2, 0, 1)) 87 | self.logger.add_image(tag, x, step) 88 | 89 | def log_seg(self, tag, x, step): 90 | if self.no_log: 91 | warnings.warn('Logging has been disabled.') 92 | return 93 | x = detach_to_cpu(x) 94 | x = self.inv_seg_trans(x) 95 | x = tensor_to_numpy(x) 96 | self.logger.add_image(tag, x, step) 97 | 98 | def log_gray(self, tag, x, step): 99 | if self.no_log: 100 | warnings.warn('Logging has been disabled.') 101 | return 102 | x = detach_to_cpu(x) 103 | x = tensor_to_numpy(x) 104 | self.logger.add_image(tag, x, step) 105 | 106 | def log_string(self, tag, x): 107 | print(tag, x) 108 | if self.no_log: 109 | warnings.warn('Logging has been disabled.') 110 | return 111 | self.logger.add_text(tag, x) 112 | -------------------------------------------------------------------------------- /deva/utils/palette.py: -------------------------------------------------------------------------------- 1 | davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 
\xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0' 2 | 3 | youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f' 4 | -------------------------------------------------------------------------------- /deva/utils/pano_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from deva.utils.vipseg_categories import VIPSEG_CATEGORIES 3 | 4 | vipseg_cat_to_isthing = {d['id']: d['isthing'] == 1 for d in VIPSEG_CATEGORIES} 5 | 6 | 7 | def id_to_rgb(id: np.ndarray) -> np.ndarray: 8 | h, w = id.shape 9 | rgb = np.zeros((h, w, 3), dtype=np.uint8) 10 | 11 | for i in range(3): 12 | rgb[:, :, i] = id % 256 13 | id = id // 256 14 | 15 | return rgb 16 | 17 | 18 | class ID2RGBConverter: 19 | def __init__(self): 20 | self.all_id = [] 21 | self.obj_to_id = {} 22 | 23 | def _id_to_rgb(self, id: int): 24 | rgb = np.zeros((3, ), dtype=np.uint8) 25 | for i in range(3): 26 | rgb[i] = id % 256 27 | id = id // 256 28 | return rgb 29 | 30 | def convert(self, obj: int): 31 | if obj in self.obj_to_id: 32 | id = self.obj_to_id[obj] 33 | else: 34 | while True: 35 | id = np.random.randint(255, 256**3) 36 | if id not in self.all_id: 37 | break 38 | self.obj_to_id[obj] = id 39 | self.all_id.append(id) 40 | 41 | return id, self._id_to_rgb(id) 42 | 43 | 44 | class IDPostprocessor: 45 | def __init__(self): 46 | self.all_id = [] 47 | self.thing_obj_to_id = {} 48 | self.stuff_to_id = {} 49 | 50 | def id_to_rgb(self, id): 51 | rgb = np.zeros((3, ), dtype=np.uint8) 52 | for i in range(3): 53 | rgb[i] = id % 256 54 | id = id // 256 55 | return rgb 56 | 57 | def _find_new_id(self, default_id): 58 | id = default_id 59 | while True: 60 | if id not in self.all_id: 61 | return id 62 | id = np.random.randint(256, 256**3) 63 | 64 | def convert(self, obj, category_id, isthing): 65 | if isthing: 66 | # is thing 67 | if (obj, category_id) in self.thing_obj_to_id: 68 | id = self.thing_obj_to_id[(obj, category_id)] 69 | else: 70 | id = self._find_new_id(obj) 71 | self.thing_obj_to_id[(obj, category_id)] = id 72 | self.all_id.append(id) 73 | else: 74 | # is stuff 75 | if category_id in self.stuff_to_id: 76 | id = self.stuff_to_id[category_id] 77 | else: 78 | id = self._find_new_id(obj) 79 | self.stuff_to_id[category_id] = id 80 | self.all_id.append(id) 81 | 82 | return id 83 | -------------------------------------------------------------------------------- /deva/utils/referring-youtubevos-val.txt: -------------------------------------------------------------------------------- 1 | 0062f687f1 2 | 01c88b5b60 3 | 0390fabe58 4 | 03fe6115d4 5 | 0620b43a31 6 | 06a5dfb511 7 | 06cd94d38d 8 | 0723d7d4fe 9 | 0782a6df7e 10 | 0788b4033d 11 | 0a598e18a8 12 | 0b0c90e21a 13 | 0c04834d61 14 | 0daaddc9da 15 | 0e8a6b63bb 16 | 0f3f8b2b2f 17 | 1335b16cf9 18 | 
13c3cea202 19 | 13ca7bbcfd 20 | 152fe4902a 21 | 17cba76927 22 | 182dbfd6ba 23 | 188cb4e03d 24 | 19cde15c4b 25 | 1a1dbe153e 26 | 1a609fa7ee 27 | 1a894a8f98 28 | 1ab5f4bbc5 29 | 1e0257109e 30 | 1e20ceafae 31 | 1f390d22ea 32 | 20a93b4c54 33 | 218ac81c2d 34 | 226f1e10f7 35 | 246e38963b 36 | 257f7fd5b8 37 | 29c06df0f2 38 | 2b904b76c9 39 | 30fe0ed0ce 40 | 31d3a7d2ee 41 | 31e0beaf99 42 | 332dabe378 43 | 335fc10235 44 | 33c8dcbe09 45 | 33e8066265 46 | 34564d26d8 47 | 352ad66724 48 | 35948a7fca 49 | 35d5e5149d 50 | 3674b2c70a 51 | 369919ef49 52 | 37b4ec2e1a 53 | 39b7491321 54 | 39bce09d8d 55 | 3b72dc1941 56 | 3be852ed44 57 | 3dd327ab4e 58 | 3e03f623bb 59 | 3f4bacb16a 60 | 4037d8305d 61 | 411774e9ff 62 | 4307020e0f 63 | 43115c42b2 64 | 44e5d1a969 65 | 450bd2e238 66 | 45dc90f558 67 | 466734bc5c 68 | 47d01d34c8 69 | 48d2909d9e 70 | 4b783f1fc5 71 | 4ee0105885 72 | 4f5b3310e3 73 | 4f6662e4e0 74 | 4fe6619a47 75 | 541ccb0844 76 | 54526e3c66 77 | 5460cc540a 78 | 547416bda1 79 | 559a611d86 80 | 5d2020eff8 81 | 6031809500 82 | 60362df585 83 | 61fca8cbf1 84 | 621487be65 85 | 623d24ce2b 86 | 62bf7630b3 87 | 63883da4f5 88 | 64c6f2ed76 89 | 65350fd60a 90 | 65e0640a2a 91 | 68dab8f80c 92 | 696e01387c 93 | 69c0f7494e 94 | 6a75316e99 95 | 6cb5b08d93 96 | 6cc8bce61a 97 | 72d613f21a 98 | 749f1abdf9 99 | 7741a0fbce 100 | 7775043b5e 101 | 77df215672 102 | 7836afc0c2 103 | 7a19a80b19 104 | 7a72130f21 105 | 7daa6343e6 106 | 7f26b553ae 107 | 822c31928a 108 | 8273b59141 109 | 853ca85618 110 | 85968ae408 111 | 8939473ea7 112 | 8b7b57b94d 113 | 8c60938d92 114 | 8d803e87f7 115 | 8dea7458de 116 | 8e2e5af6a8 117 | 92fde455eb 118 | 94fa9bd3b5 119 | 975be70866 120 | 9787f452bf 121 | 97b38cabcc 122 | 9a38b8e463 123 | 9c0b55cae5 124 | 9ce299a510 125 | 9da2156a73 126 | 9f16d17e42 127 | 9f21474aca 128 | 9f429af409 129 | 9fd2d2782b 130 | a00c3fa88e 131 | a0fc95d8fc 132 | a1251195e7 133 | a2948d4116 134 | a46012c642 135 | a4bce691c6 136 | a7462d6aaf 137 | a806e58451 138 | a9f23c9150 139 | ab9a7583f1 140 | abae1ce57d 141 | aceb34fcbe 142 | b00ff71889 143 | b05faf54f7 144 | b205d868e6 145 | b2256e265c 146 | b3b92781d9 147 | b5514f75d8 148 | b58a97176b 149 | b772ac822a 150 | b7928ea5c0 151 | b7b7e52e02 152 | b83923fd72 153 | b90f8c11db 154 | ba8823f2d2 155 | bc9ba8917e 156 | bf2d38aefe 157 | bf4cc89b18 158 | c16d9a4ade 159 | c280d21988 160 | c2bbd6d121 161 | c42fdedcdd 162 | c74fc37224 163 | c9ef04fe59 164 | cb06f84b6e 165 | cbea8f6bea 166 | cc1a82ac2a 167 | cc7c3138ff 168 | cd69993923 169 | cd896a9bee 170 | cdcfd9f93a 171 | d1ac0d8b81 172 | d1dd586cfd 173 | d59c093632 174 | d69812339e 175 | d7a38bf258 176 | d7ff44ea97 177 | d975e5f4a9 178 | dab44991de 179 | dc197289ef 180 | dce363032d 181 | dea0160a12 182 | deed0ab4fc 183 | e027ebc228 184 | e10236eb37 185 | e11254d3b9 186 | e633eec195 187 | eb263ef128 188 | eb49ce8027 189 | ebe7138e58 190 | ee9415c553 191 | eea1a45e49 192 | eeb18f9d47 193 | f054e28786 194 | f143fede6f 195 | f2a45acf1c 196 | f3678388a7 197 | f39c805b54 198 | f7255a57d0 199 | f7d7fb16d0 200 | fb104c286f 201 | fd8cf868b2 202 | fef7e84268 203 | -------------------------------------------------------------------------------- /deva/utils/tensor_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List, Iterable 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | # STM 7 | def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]): 8 | h, w = in_img.shape[-2:] 9 | 10 | if h % d > 0: 11 | new_h = h + d - h % d 12 | else: 13 | new_h = h 
14 | if w % d > 0: 15 | new_w = w + d - w % d 16 | else: 17 | new_w = w 18 | lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2) 19 | lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2) 20 | pad_array = (int(lw), int(uw), int(lh), int(uh)) 21 | out = F.pad(in_img, pad_array) 22 | return out, pad_array 23 | 24 | 25 | def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor: 26 | if len(img.shape) == 4: 27 | if pad[2] + pad[3] > 0: 28 | img = img[:, :, pad[2]:-pad[3], :] 29 | if pad[0] + pad[1] > 0: 30 | img = img[:, :, :, pad[0]:-pad[1]] 31 | elif len(img.shape) == 3: 32 | if pad[2] + pad[3] > 0: 33 | img = img[:, pad[2]:-pad[3], :] 34 | if pad[0] + pad[1] > 0: 35 | img = img[:, :, pad[0]:-pad[1]] 36 | elif len(img.shape) == 5: 37 | if pad[2] + pad[3] > 0: 38 | img = img[:, :, :, pad[2]:-pad[3], :] 39 | if pad[0] + pad[1] > 0: 40 | img = img[:, :, :, :, pad[0]:-pad[1]] 41 | elif len(img.shape) == 2: 42 | if pad[2] + pad[3] > 0: 43 | img = img[pad[2]:-pad[3], :] 44 | if pad[0] + pad[1] > 0: 45 | img = img[:, pad[0]:-pad[1]] 46 | else: 47 | raise NotImplementedError 48 | return img 49 | -------------------------------------------------------------------------------- /deva/vps_metrics/README.md: -------------------------------------------------------------------------------- 1 | `eval_stq_vspw.py` and `eval_vpq_vspw.py` are modified from the original evaluation scripts in https://github.com/VIPSeg-Dataset/VIPSeg-Dataset to generate the quantitative results in the GMP paper. 2 | 3 | There are a few main modifications: 4 | 1. Multiprocessing is implemented to speed up evaluation (by a lot). 5 | 2. Masks are loaded on-the-fly to reduce RAM usage. 6 | 3. There is no longer a `pan_pred` directory that stores the masks separately. The current script expects a single folder that contains all the video folders with the json file. 7 | 4. The print message is modified to be consistent with the notations in the paper and the output text file. `0-frame vpq_stat` becomes `1-frame vpq_stat`;`5-frame vpq_stat` becomes `2-frame vpq_stat`, etc. 8 | 9 | 10 | I noticed that there has been an update to the origin VIPSeg evaluation script that covers some of these modifications (i.e., point 1 and 2). I developed this version independently with that update and I keep the current version for documentation. They generate the same results in my testing. 11 | -------------------------------------------------------------------------------- /deva/vps_metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/vps_metrics/__init__.py -------------------------------------------------------------------------------- /deva/vps_metrics/eval_stq_vipseg.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------- 2 | # Video Panoptic Segmentation 3 | # 4 | # VPQ evaluation code by tube (video segment) matching 5 | # Inference on every frames and evaluation on every 5 frames. 
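# Note on the label format used below: every pixel of both the ground truth and the
# prediction is packed as (semantic_category_id << bit_shift) + instance_index, with
# bit_shift = 16, before being passed to STQuality.update_state().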
6 | # ------------------------------------------------------------------ 7 | 8 | # Modified by Rex Cheng Oct 2022 - save results as a .txt file 9 | # Feb 2023 - Ported the update from the official code patch to here 10 | # - added a functional interface 11 | 12 | import argparse 13 | import os 14 | import os.path 15 | import numpy as np 16 | from PIL import Image 17 | import json 18 | from tqdm import tqdm 19 | import deva.vps_metrics.segmentation_and_tracking_quality as numpy_stq 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description='VPSNet eval') 24 | parser.add_argument('--submit_dir', '-i', type=str, help='test output directory') 25 | 26 | parser.add_argument( 27 | '--truth_dir', 28 | type=str, 29 | default='../VIPSeg/VIPSeg_720P/panomasksRGB', 30 | help='ground truth directory. Point this to /VIPSeg/VIPSeg_720P/panomasksRGB ' 31 | 'after running the conversion script') 32 | 33 | parser.add_argument( 34 | '--pan_gt_json_file', 35 | type=str, 36 | default='../VIPSeg/VIPSeg_720P/panoptic_gt_VIPSeg_val.json', 37 | help='ground truth JSON file. Point this to /VIPSeg/VIPSeg_720P/panoptic_gt_' 38 | 'VIPSeg_val.json after running the conversion script') 39 | 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | # constants 45 | n_classes = 124 46 | ignore_label = 255 47 | bit_shift = 16 48 | 49 | 50 | def eval_stq(submit_dir, truth_dir, pan_gt_json_file): 51 | output_dir = submit_dir 52 | if not os.path.isdir(submit_dir): 53 | print("%s doesn't exist" % submit_dir) 54 | if os.path.isdir(submit_dir) and os.path.isdir(truth_dir): 55 | if not os.path.exists(output_dir): 56 | os.makedirs(output_dir) 57 | 58 | pan_pred_json_file = os.path.join(submit_dir, 'pred.json') 59 | with open(pan_pred_json_file, 'r') as f: 60 | pred_jsons = json.load(f) 61 | with open(pan_gt_json_file, 'r') as f: 62 | gt_jsons = json.load(f) 63 | 64 | categories = gt_jsons['categories'] 65 | 66 | thing_list_ = [] 67 | for cate_ in categories: 68 | cat_id = cate_['id'] 69 | isthing = cate_['isthing'] 70 | if isthing: 71 | thing_list_.append(cat_id) 72 | 73 | stq_metric = numpy_stq.STQuality(n_classes, thing_list_, ignore_label, bit_shift, 2**24) 74 | 75 | pred_annos = pred_jsons['annotations'] 76 | pred_j = {} 77 | for p_a in pred_annos: 78 | pred_j[p_a['video_id']] = p_a['annotations'] 79 | gt_annos = gt_jsons['annotations'] 80 | gt_j = {} 81 | for g_a in gt_annos: 82 | gt_j[g_a['video_id']] = g_a['annotations'] 83 | 84 | pbar = tqdm(gt_jsons['videos']) 85 | for seq_id, video_images in enumerate(pbar): 86 | video_id = video_images['video_id'] 87 | pbar.set_description(video_id) 88 | 89 | # print('processing video:{}'.format(video_id)) 90 | gt_image_jsons = video_images['images'] 91 | gt_js = gt_j[video_id] 92 | pred_js = pred_j[video_id] 93 | assert len(gt_js) == len(pred_js) 94 | 95 | gt_pans = [] 96 | pred_pans = [] 97 | for imgname_j in gt_image_jsons: 98 | imgname = imgname_j['file_name'] 99 | image = np.array(Image.open(os.path.join(submit_dir, 'pan_pred', video_id, imgname))) 100 | pred_pans.append(image) 101 | image = np.array(Image.open(os.path.join(truth_dir, video_id, imgname))) 102 | gt_pans.append(image) 103 | gt_id_to_ins_num_dic = {} 104 | list_tmp = [] 105 | for segm in gt_js: 106 | for img_info in segm['segments_info']: 107 | id_tmp_ = img_info['id'] 108 | if id_tmp_ not in list_tmp: 109 | list_tmp.append(id_tmp_) 110 | for ii, id_tmp_ in enumerate(list_tmp): 111 | gt_id_to_ins_num_dic[id_tmp_] = ii 112 | 113 | pred_id_to_ins_num_dic = {} 114 | list_tmp = [] 115 | for 
segm in pred_js: 116 | for img_info in segm['segments_info']: 117 | id_tmp_ = img_info['id'] 118 | if id_tmp_ not in list_tmp: 119 | list_tmp.append(id_tmp_) 120 | for ii, id_tmp_ in enumerate(list_tmp): 121 | pred_id_to_ins_num_dic[id_tmp_] = ii 122 | 123 | for i, (gt_json, pred_json, gt_pan, pred_pan, gt_image_json) in enumerate( 124 | list(zip(gt_js, pred_js, gt_pans, pred_pans, gt_image_jsons))): 125 | #### Step1. Collect frame-level pan_gt, pan_pred, etc. 126 | gt_pan, pred_pan = np.uint32(gt_pan), np.uint32(pred_pan) 127 | pan_gt = gt_pan[:, :, 0] + gt_pan[:, :, 1] * 256 + gt_pan[:, :, 2] * 256 * 256 128 | pan_pred = pred_pan[:, :, 0] + pred_pan[:, :, 1] * 256 + pred_pan[:, :, 2] * 256 * 256 129 | 130 | ground_truth_instance = np.ones_like(pan_gt) * 255 131 | ground_truth_semantic = np.ones_like(pan_gt) * 255 132 | for el in gt_json['segments_info']: 133 | id_ = el['id'] 134 | cate_id = el['category_id'] 135 | ground_truth_semantic[pan_gt == id_] = cate_id 136 | ground_truth_instance[pan_gt == id_] = gt_id_to_ins_num_dic[id_] 137 | 138 | ground_truth = ((ground_truth_semantic << bit_shift) + ground_truth_instance) 139 | 140 | prediction_instance = np.ones_like(pan_pred) * 255 141 | prediction_semantic = np.ones_like(pan_pred) * 255 142 | 143 | for el in pred_json['segments_info']: 144 | id_ = el['id'] 145 | cate_id = el['category_id'] 146 | prediction_semantic[pan_pred == id_] = cate_id 147 | prediction_instance[pan_pred == id_] = pred_id_to_ins_num_dic[id_] 148 | prediction = ((prediction_semantic << bit_shift) + prediction_instance) 149 | 150 | stq_metric.update_state(ground_truth.astype(dtype=np.int32), 151 | prediction.astype(dtype=np.int32), seq_id) 152 | result = stq_metric.result() 153 | print('*' * 100) 154 | print('STQ : {}'.format(result['STQ'])) 155 | print('AQ :{}'.format(result['AQ'])) 156 | print('IoU:{}'.format(result['IoU'])) 157 | print('STQ_per_seq') 158 | print(result['STQ_per_seq']) 159 | print('AQ_per_seq') 160 | print(result['AQ_per_seq']) 161 | print('ID_per_seq') 162 | print(result['ID_per_seq']) 163 | print('Length_per_seq') 164 | print(result['Length_per_seq']) 165 | print('*' * 100) 166 | 167 | with open(os.path.join(submit_dir, 'stq.txt'), 'w') as f: 168 | f.write(f'{result["STQ"]*100:.1f},{result["AQ"]*100:.1f},{result["IoU"]*100:.1f}\n') 169 | 170 | 171 | if __name__ == "__main__": 172 | args = parse_args() 173 | submit_dir = args.submit_dir 174 | truth_dir = args.truth_dir 175 | pan_gt_json_file = args.pan_gt_json_file 176 | eval_stq(submit_dir, truth_dir, pan_gt_json_file) 177 | -------------------------------------------------------------------------------- /deva/vps_metrics/stuff_merging.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | 4 | import json 5 | from argparse import ArgumentParser 6 | import numpy as np 7 | from PIL import Image 8 | from functools import partial 9 | from tqdm import tqdm 10 | 11 | from deva.utils.vipseg_categories import VIPSEG_CATEGORIES 12 | from deva.utils.pano_utils import IDPostprocessor, id_to_rgb 13 | from multiprocessing import Pool 14 | """ 15 | Post-processing is done online, so technically it can be run with the original evaluation. 16 | But that introduces additional complexities that are just tailored for VPS/VIPSeg. 17 | So I make this part into a post-processing script. 18 | 19 | Specifically, it does the following: 20 | 1. For every "thing" segment, whenever its category changes, we give it a new object id. 
21 | This is because the evaluation script assumes that objects with the same object id 22 | have the same category. This change aligns with the original formula in 23 | Video Panoptic Segmentation, CVPR 2020. 24 | 2. It stores a mapping table from "stuff" category to object id. Whenever we encounter a segment 25 | with a "stuff" category, we apply this mapping. 26 | """ 27 | 28 | 29 | def process_single_video(vid_ann, input_path, output_path): 30 | video_id = vid_ann['video_id'] 31 | video_output_annotation = [] 32 | video_output = {'video_id': video_id, 'annotations': video_output_annotation} 33 | output_path = path.join(output_path, 'pan_pred', video_id) 34 | os.makedirs(output_path, exist_ok=True) 35 | 36 | converter = IDPostprocessor() 37 | 38 | for ann in vid_ann['annotations']: 39 | file_name = ann['file_name'] 40 | segments_info = ann['segments_info'] 41 | output_segments_info = [] 42 | output_annotation = {'file_name': ann['file_name'], 'segments_info': output_segments_info} 43 | video_output_annotation.append(output_annotation) 44 | 45 | mask = np.array( 46 | Image.open( 47 | path.join(input_path, 'pan_pred', video_id, 48 | file_name.replace('.jpg', '.png')))).astype(np.int32) 49 | mask = mask[:, :, 0] + mask[:, :, 1] * 256 + mask[:, :, 2] * 256 * 256 50 | output_mask = np.zeros_like(mask) 51 | 52 | for segment in segments_info: 53 | id = segment['id'] 54 | category_id = segment['category_id'] 55 | isthing = vipseg_cat_to_isthing[category_id] 56 | new_id = converter.convert(id, category_id, isthing) 57 | output_mask[mask == id] = new_id 58 | 59 | if isthing: 60 | # is not stuff 61 | output_segment = { 62 | 'id': new_id, 63 | 'category_id': segment['category_id'], 64 | 'isthing': 1, 65 | } 66 | output_segments_info.append(output_segment) 67 | 68 | # a pass for the merged stuff objects 69 | for cat, new_id in converter.stuff_to_id.items(): 70 | area = int((output_mask == new_id).sum()) 71 | assert not vipseg_cat_to_isthing[cat] 72 | if area > 0: 73 | output_segment = { 74 | 'id': new_id, 75 | 'category_id': cat, 76 | 'isthing': 0, 77 | } 78 | output_segments_info.append(output_segment) 79 | 80 | # save the new output mask 81 | output_mask = id_to_rgb(output_mask) 82 | output_mask = Image.fromarray(output_mask) 83 | output_mask.save(path.join(output_path, file_name.replace('.jpg', '.png'))) 84 | 85 | return video_output 86 | 87 | 88 | vipseg_cat_to_isthing = {d['id']: d['isthing'] == 1 for d in VIPSEG_CATEGORIES} 89 | 90 | 91 | def merge_stuff(input_path, output_path): 92 | with open(path.join(input_path, 'pred.json')) as f: 93 | annotations = json.load(f)['annotations'] 94 | 95 | output_annotations = [] 96 | pool = Pool(16) 97 | for out_vid_ann in tqdm(pool.imap( 98 | partial(process_single_video, input_path=input_path, output_path=output_path), 99 | annotations), 100 | max_value=len(annotations)): 101 | output_annotations.append(out_vid_ann) 102 | 103 | output_json = {'annotations': output_annotations} 104 | with open(path.join(output_path, 'pred.json'), 'w') as f: 105 | json.dump(output_json, f) 106 | 107 | 108 | if __name__ == '__main__': 109 | parser = ArgumentParser() 110 | parser.add_argument('--input_path') 111 | parser.add_argument('--output_path') 112 | args = parser.parse_args() 113 | 114 | input_path = args.input_path 115 | output_path = args.output_path 116 | 117 | merge_stuff(input_path, output_path) 118 | -------------------------------------------------------------------------------- /docs/CUSTOM.md: 
-------------------------------------------------------------------------------- 1 | # Using DEVA with Your Custom Detection Models 2 | 3 | There are two main ways to use your own detection models: 4 | 1. Online integration. Use a single script that queries the image detection model when needed. This is how the demo works (by querying Grounded Segment Anything or Segment Anything). 5 | 2. Offline integration. Run the image detection model on all frames beforehand and save the results. Then, run DEVA on the saved results. This is how we evaluate DEVA with the benchmark detection models. 6 | 7 | From an algorithm point-of-view, both approaches are online/semi-online. There is only an implementation difference. 8 | 9 | For (1), look at `demo/demo_automatic` and `demo/demo_with_text`. 10 | 11 | For (2), generate the detections following the data format in `example/vipseg`. There is a json file associated with every segmentation that contains object IDs and meta information. "category_id" and "score" can be optionally included in the json. Then follow the "demo" command listed in [EVALUATION.md](EVALUATION.md). 12 | 13 | You can also "mock" your data as BURST/VIPSeg and run the models as if you are evaluating on BURST/VIPSeg. 14 | -------------------------------------------------------------------------------- /docs/DEMO.md: -------------------------------------------------------------------------------- 1 | # Details on the Demo 2 | 3 | ## Pipeline 4 | 5 | The high-level description below is for the online setting. In the semi-online setting, the detections are first merged across a small clip. 6 | The first frame is always initialized with detection without propagation. 7 | 8 | ### Text-prompted mode (recommended) 9 | 1. DEVA propagates masks from memory to the current frame 10 | 2. If this is a detection frame, go to the next step. Otherwise, no further processing is needed for this frame. 11 | 3. Grounding DINO takes the text prompt and generates some bounding boxes 12 | 4. Segment Anything takes the bounding boxes and generates corresponding segmentation masks 13 | 5. The propagated masks are compared to and merged with the segmentation from Segment Anything 14 | 15 | ### Automatic mode 16 | 1. DEVA propagates masks from memory to the current frame. 17 | 2. If this is a detection frame, go to the next step. Otherwise, no further processing is needed for this frame. 18 | 3. We generate a grid of points on the unsegmented regions. 19 | 4. Segment Anything takes the points and generates corresponding segmentation masks. 20 | 5. The propagated masks are compared to and merged with the segmentation from Segment Anything. 21 | 22 | ## Tips on Speeding up Inference 23 | 24 | **General Tips:** 25 | 26 | - Though innocently looking, reading frames from disk, visualizing the output, and encoding the output as videos can be slow, especially at high resolutions. The script version runs faster than the gradio version because it uses threaded I/O. 27 | - Specifying `--amp` (automatic fixed precision) makes things run faster on most modern GPUs. 28 | - In general, text-prompted inference is faster and more robust than "automatic" inference. 29 | - To speed up the actual processing, we need to speed up either the image model or the propagation model. 30 | 31 | **Speed up the image model:** 32 | 33 | - The most efficient way is to use the image model less often. This can be achieved by: 34 | - Using `online` instead of `semionline`, or, 35 | - Increasing `detection_every`. 36 | - Use a faster image model. 
For example, Mobile-SAM is faster than SAM. Grounded-Segment-Anything (text-prompt) is faster than automatic SAM. In automatic mode, you can reduce the number of prompting points (`SAM_NUM_POINTS_PER_SIDE`) to reduce the number of queries to SAM. 37 | - In automatic mode, increasing `SAM_NUM_POINTS_PER_BATCH` improves parallelism. 38 | 39 | **Speeding up the propagation model:** 40 | 41 | - In general, the running time of the propagation model scales linearly with the number of objects (not to be confused with direct proportionality). The best play is thus to reduce the number of objects: 42 | - Using text-prompt typically generates more relevant objects and fewer overall number of objects. 43 | - Increasing the thresholds `SAM_PRED_IOU_THRESHOLD` or `DINO_THRESHOLD` reduces the number of detected objects. 44 | - Reduce `max_missed_detection_count` to delete objects more readily. 45 | - In automatic mode, enable `suppress_small_objects` to get larger and fewer segments. Note this option has its own overhead. 46 | - Reduce the internal processing resolution `size`. Note this does not affect the image model. 47 | - Increasing `chunk_size` improves parallelism. 48 | 49 | ## Explanation of arguments 50 | 51 | General: 52 | - `detection_every`: number of frames between two consecutive detections; a higher number means faster inference but slower responses to new objects 53 | - `amp`: enable mixed precision; is faster and has a lower memory usage 54 | - `chunk_size`: number of objects to be processed in parallel; a higher number means faster inference but higher memory usage 55 | - `size`: internal processing resolution for the propagation module; defaults to 480 56 | - `max_missed_detection_count`: maximum number of consecutive detections that can be missed before an object is deleted from memory. 57 | - `max_num_objects`: maximum number of objects that can be tracked at the same time; new objects are ignored if this is exceeded 58 | 59 | Text-prompted mode only: 60 | - `DINO_THRESHOLD`: threshold for DINO to consider a detection as valid 61 | - `prompt`: text prompt to use, separate by a full stop; e.g. "people.trees". The wording of the prompt and minor details like pluralization might affect the results. 62 | 63 | Automatic mode only: 64 | - `SAM_NUM_POINTS_PER_SIDE`: number of points per side to use for automatic grid-based prompting in SAM 65 | - `SAM_NUM_POINTS_PER_BATCH`: number of points prompts to process in parallel in SAM 66 | - `SAM_PRED_IOU_THRESHOLD`: threshold of predicted IoU to be considered as a valid segmentation for SAM 67 | - `suppress_small_objects`: if enabled, small objects that overlap with large objects are suppressed during the automatic mode; does not matter in the text-prompted mode 68 | - `SAM_OVERLAP_THRESHOLD`: if suppress_small_objects are enabled, this is the IoU threshold for the suppression. 
A lower threshold means more segmentation masks (less suppression) 69 | 70 | ## Source videos 71 | 72 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/4a00cd0d-f712-447f-82c4-6152addffd6b 73 | 74 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/c556d398-44dd-423b-9ff3-49763eaecd94 75 | 76 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/72e7495c-d5f9-4a8b-b7e8-8714b269e98d 77 | 78 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/337dd073-07eb-4392-9610-c5f6c6b94832 79 | 80 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/e5f6df87-9fd0-4178-8490-00c4b8dc613b 81 | -------------------------------------------------------------------------------- /docs/TRAINING.md: -------------------------------------------------------------------------------- 1 | # Training DEVA 2 | 3 | Note that this repository only supports the training of the temporal propagation module. For the image module, please refer to the individual projects. 4 | 5 | ## Setting Up Data 6 | 7 | We put datasets out-of-source, as in XMem. You do not need BL30K. The directory structure should look like this: 8 | ```bash 9 | ├── Tracking-Anything-with-DEVA 10 | ├── DAVIS 11 | │ ├── 2016 12 | │ │ ├── Annotations 13 | │ │ └── ... 14 | │ └── 2017 15 | │ ├── test-dev 16 | │ │ ├── Annotations 17 | │ │ └── ... 18 | │ └── trainval 19 | │ ├── Annotations 20 | │ └── ... 21 | ├── static 22 | │ ├── BIG_small 23 | │ └── ... 24 | └── YouTube 25 | │ ├── all_frames 26 | │ │ └── valid_all_frames 27 | │ ├── train 28 | │ └── valid 29 | └── OVIS-VOS-train 30 | ├── JPEGImages 31 | └── Annotations 32 | ``` 33 | 34 | You can try our script `python -m scripts.download_dataset` which might not work 100% of the time due to Google Drive's blocking. If it fails, please download the datasets manually. The links can be found in the script. 35 | 36 | To generate OVIS-VOS-train, use something like https://github.com/youtubevos/vis2vos or download our preprocessed version from https://drive.google.com/uc?id=1AZPyyqVqOl6j8THgZ1UdNJY9R1VGEFrX. 37 | 38 | ## Training Command 39 | The training command is the same as in XMem. We tried training with 4/8 GPUs. 40 | With 8 GPUs, 41 | ``` 42 | python -m torch.distributed.run --master_port 25763 --nproc_per_node=8 deva/train.py --exp_id deva_retrain --stage 03 43 | ``` 44 | - Change `nproc_per_node` to change the number of GPUs. 45 | - Prepend `CUDA_VISIBLE_DEVICES=...` if you want to use specific GPUs. 46 | - Change `master_port` if you encounter port collision. 47 | - `exp_id` is a unique experiment identifier that does not affect how the training is done. 48 | - Models will be saved in `./saves/`. 49 | - We simply use the last trained model without model selection. 
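For example (the GPU ids here are arbitrary; pick a port and experiment id that suit your setup), a run restricted to four specific GPUs looks like:
```
CUDA_VISIBLE_DEVICES=0,1,2,3 python -m torch.distributed.run --master_port 25763 --nproc_per_node=4 deva/train.py --exp_id deva_retrain --stage 03
```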
50 | -------------------------------------------------------------------------------- /docs/style.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: 'Source Sans 3', sans-serif; 3 | font-size: 18px; 4 | margin-left: auto; 5 | margin-right: auto; 6 | font-weight: 300; 7 | height: 100%; 8 | max-width: 1000px; 9 | } 10 | 11 | .light { 12 | font-weight: 100; 13 | } 14 | 15 | .heavy { 16 | font-weight: 400; 17 | } 18 | 19 | .column { 20 | float: left; 21 | } 22 | 23 | .metric_table { 24 | border-collapse: collapse; 25 | margin-left: 15px; 26 | margin-right: auto; 27 | } 28 | 29 | .metric_table th { 30 | border-bottom: 1px solid #555; 31 | padding-left: 15px; 32 | padding-right: 15px; 33 | } 34 | 35 | .metric_table td { 36 | padding-left: 15px; 37 | padding-right: 15px; 38 | } 39 | 40 | .metric_table .left_align { 41 | text-align: left; 42 | } 43 | 44 | a:link, 45 | a:visited { 46 | color: #05538f; 47 | text-decoration: none; 48 | } 49 | 50 | a:hover { 51 | color: #63cbdd; 52 | } 53 | 54 | hr { 55 | border: 0; 56 | height: 1px; 57 | background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0)); 58 | } -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/eval_ref_davis.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | from argparse import ArgumentParser 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | from tqdm import tqdm 8 | 9 | from deva.inference.data.referring_test_datasets import ReferringDAVISTestDataset 10 | from deva.inference.image_feature_store import ImageFeatureStore 11 | from deva.inference.inference_core import DEVAInferenceCore 12 | from deva.inference.consensus_associated import find_consensus_with_established_association 13 | from deva.utils.palette import davis_palette 14 | from deva.inference.result_utils import ResultSaver 15 | from deva.inference.eval_args import add_common_eval_args, get_model_and_config 16 | 17 | 18 | def main(): 19 | """ 20 | Arguments loading 21 | """ 22 | parser = ArgumentParser() 23 | parser.add_argument('--img_path', default='../DAVIS/2017/trainval/JPEGImages/480p') 24 | parser.add_argument('--mask_path') 25 | parser.add_argument('--num_voting_frames', 26 | default=5, 27 | type=int, 28 | help='Number of frames selected for the initial consensus voting') 29 | add_common_eval_args(parser) 30 | network, config, args = get_model_and_config(parser) 31 | """ 32 | Data preparation 33 | """ 34 | out_path = args.output 35 | meta_dataset = ReferringDAVISTestDataset(args.img_path, args.mask_path) 36 | torch.autograd.set_grad_enabled(False) 37 | 38 | videos = meta_dataset.get_videos() 39 | 40 | total_process_time = 0 41 | total_frames = 0 42 | 43 | # Start eval 44 | pbar = tqdm(videos, total=len(videos)) 45 | for vid_name in pbar: 46 | pbar.set_description(vid_name) 47 | video_scores = meta_dataset.get_scores(vid_name) 48 | try: 49 | """ 50 | initial pass, perform consensus voting and get a keyframe 51 | """ 52 | image_feature_store = ImageFeatureStore(network) 53 | vid_reader = 
meta_dataset.get_offline_sampled_frames(vid_name, 54 | config['num_voting_frames']) 55 | loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2) 56 | 57 | time_indices = [] 58 | images = [] 59 | masks = [] 60 | scores = [] 61 | for ti, data in enumerate(loader): 62 | time_indices.append(data['info']['time_index'][0].item()) 63 | image = data['rgb'].cuda()[0] 64 | mask = data['mask'].cuda()[0] 65 | images.append(image) 66 | masks.append(mask) 67 | 68 | frame_name = data['info']['frame'][0][:-4] 69 | scores.append(video_scores[frame_name]) 70 | 71 | torch.cuda.synchronize() 72 | start = torch.cuda.Event(enable_timing=True) 73 | end = torch.cuda.Event(enable_timing=True) 74 | start.record() 75 | keyframe_ti, projected_mask = find_consensus_with_established_association( 76 | time_indices, 77 | images, 78 | masks, 79 | scores=scores, 80 | network=network, 81 | store=image_feature_store, 82 | config=config) 83 | end.record() 84 | torch.cuda.synchronize() 85 | total_process_time += (start.elapsed_time(end) / 1000) 86 | """ 87 | Backward pass video reader 88 | """ 89 | backward_vid_reader = meta_dataset.get_partial_video_loader(vid_name, 90 | start=-1, 91 | end=keyframe_ti + 1, 92 | reverse=True) 93 | """ 94 | Forward pass video reader 95 | """ 96 | forward_vid_reader = meta_dataset.get_partial_video_loader(vid_name, 97 | start=keyframe_ti, 98 | end=-1, 99 | reverse=False) 100 | """ 101 | Running them in combination 102 | """ 103 | vid_readers = [backward_vid_reader, forward_vid_reader] 104 | for vid_reader in vid_readers: 105 | 106 | loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2) 107 | vid_length = len(loader) 108 | # no need to count usage for LT if the video is not that long anyway 109 | config['enable_long_term_count_usage'] = ( 110 | config['enable_long_term'] and 111 | (vid_length / (config['max_mid_term_frames'] - config['min_mid_term_frames']) * 112 | config['num_prototypes']) >= config['max_long_term_elements']) 113 | 114 | processor = DEVAInferenceCore(network, 115 | config=config, 116 | image_feature_store=image_feature_store) 117 | result_saver = ResultSaver(out_path, 118 | vid_name, 119 | dataset='ref_davis', 120 | palette=davis_palette, 121 | object_manager=processor.object_manager) 122 | 123 | for ti, data in enumerate(loader): 124 | with torch.cuda.amp.autocast(enabled=args.amp): 125 | image = data['rgb'].cuda()[0] 126 | info = data['info'] 127 | frame = info['frame'][0] 128 | shape = info['shape'] 129 | need_resize = info['need_resize'][0] 130 | image_ti = info['time_index'][0].item() 131 | 132 | if image_ti == keyframe_ti: 133 | mask = projected_mask 134 | else: 135 | mask = None 136 | 137 | start = torch.cuda.Event(enable_timing=True) 138 | end = torch.cuda.Event(enable_timing=True) 139 | start.record() 140 | 141 | # Run the model on this frame 142 | prob = processor.step(image, 143 | mask, 144 | end=(ti == vid_length - 1), 145 | hard_mask=False, 146 | image_ti_override=image_ti) 147 | 148 | end.record() 149 | torch.cuda.synchronize() 150 | total_process_time += (start.elapsed_time(end) / 1000) 151 | total_frames += 1 152 | 153 | result_saver.save_mask(prob, frame, need_resize=need_resize, shape=shape) 154 | 155 | result_saver.end() 156 | with open(path.join(out_path, vid_name, 'key.txt'), 'w') as f: 157 | f.write(f'options: {time_indices}; keyframe: {keyframe_ti}') 158 | 159 | except Exception as e: 160 | print(f'Runtime error at {vid_name}') 161 | print(e) 162 | raise e 163 | 164 | print(f'Total processing time: 
{total_process_time}') 165 | print(f'Total processed frames: {total_frames}') 166 | print(f'FPS: {total_frames / total_process_time}') 167 | print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}') 168 | 169 | 170 | 171 | if __name__ == '__main__': 172 | main() 173 | -------------------------------------------------------------------------------- /example/vipseg/images/12_1mWNahzcsAc/00001255.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/images/12_1mWNahzcsAc/00001255.jpg -------------------------------------------------------------------------------- /example/vipseg/images/12_1mWNahzcsAc/00001258.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/images/12_1mWNahzcsAc/00001258.jpg -------------------------------------------------------------------------------- /example/vipseg/images/12_1mWNahzcsAc/00001261.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/images/12_1mWNahzcsAc/00001261.jpg -------------------------------------------------------------------------------- /example/vipseg/images/12_1mWNahzcsAc/00001264.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/images/12_1mWNahzcsAc/00001264.jpg -------------------------------------------------------------------------------- /example/vipseg/source/12_1mWNahzcsAc/00001255.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "id": 1, 4 | "isthing": false, 5 | "category_id": 7 6 | }, 7 | { 8 | "id": 2, 9 | "isthing": true, 10 | "category_id": 63 11 | }, 12 | { 13 | "id": 3, 14 | "isthing": false, 15 | "category_id": 66 16 | }, 17 | { 18 | "id": 4, 19 | "isthing": true, 20 | "category_id": 60 21 | }, 22 | { 23 | "id": 5, 24 | "isthing": false, 25 | "category_id": 14 26 | }, 27 | { 28 | "id": 6, 29 | "isthing": true, 30 | "category_id": 63 31 | }, 32 | { 33 | "id": 7, 34 | "isthing": true, 35 | "category_id": 60 36 | } 37 | ] -------------------------------------------------------------------------------- /example/vipseg/source/12_1mWNahzcsAc/00001255.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/source/12_1mWNahzcsAc/00001255.png -------------------------------------------------------------------------------- /example/vipseg/source/12_1mWNahzcsAc/00001258.json: -------------------------------------------------------------------------------- 1 | [{"id": 1, "isthing": true, "category_id": 63}, {"id": 2, "isthing": false, "category_id": 66}, {"id": 3, "isthing": false, "category_id": 0}, {"id": 4, "isthing": true, "category_id": 63}, {"id": 5, "isthing": true, "category_id": 60}, {"id": 6, "isthing": false, "category_id": 14}, {"id": 7, "isthing": false, "category_id": 7}, {"id": 8, "isthing": true, "category_id": 60}] 
-------------------------------------------------------------------------------- /example/vipseg/source/12_1mWNahzcsAc/00001258.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/source/12_1mWNahzcsAc/00001258.png -------------------------------------------------------------------------------- /example/vipseg/source/12_1mWNahzcsAc/00001261.json: -------------------------------------------------------------------------------- 1 | [{"id": 1, "isthing": false, "category_id": 7}, {"id": 2, "isthing": true, "category_id": 63}, {"id": 3, "isthing": true, "category_id": 63}, {"id": 4, "isthing": false, "category_id": 66}, {"id": 5, "isthing": true, "category_id": 60}, {"id": 6, "isthing": true, "category_id": 60}, {"id": 7, "isthing": false, "category_id": 14}] -------------------------------------------------------------------------------- /example/vipseg/source/12_1mWNahzcsAc/00001261.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/source/12_1mWNahzcsAc/00001261.png -------------------------------------------------------------------------------- /example/vipseg/source/12_1mWNahzcsAc/00001264.json: -------------------------------------------------------------------------------- 1 | [{"id": 1, "isthing": false, "category_id": 7}, {"id": 2, "isthing": true, "category_id": 63}, {"id": 3, "isthing": true, "category_id": 63}, {"id": 4, "isthing": false, "category_id": 66}, {"id": 5, "isthing": false, "category_id": 0}, {"id": 6, "isthing": true, "category_id": 60}, {"id": 7, "isthing": true, "category_id": 60}, {"id": 8, "isthing": false, "category_id": 14}] -------------------------------------------------------------------------------- /example/vipseg/source/12_1mWNahzcsAc/00001264.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/source/12_1mWNahzcsAc/00001264.png -------------------------------------------------------------------------------- /example/vos/Annotations/bmx-trees/00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/Annotations/bmx-trees/00000.png -------------------------------------------------------------------------------- /example/vos/JPEGImages/bmx-trees/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/JPEGImages/bmx-trees/00000.jpg -------------------------------------------------------------------------------- /example/vos/JPEGImages/bmx-trees/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/JPEGImages/bmx-trees/00001.jpg -------------------------------------------------------------------------------- /example/vos/JPEGImages/bmx-trees/00002.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/JPEGImages/bmx-trees/00002.jpg -------------------------------------------------------------------------------- /example/vos/JPEGImages/bmx-trees/00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/JPEGImages/bmx-trees/00003.jpg -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["hatchling"] 3 | build-backend = "hatchling.build" 4 | 5 | [tool.hatch.metadata] 6 | allow-direct-references = true 7 | 8 | [tool.yapf] 9 | based_on_style = "pep8" 10 | indent_width = 4 11 | column_limit = 100 12 | 13 | [project] 14 | name = "deva" 15 | version = "1.0.0" 16 | authors = [{ name = "Rex Cheng", email = "hkchengrex@gmail.com" }] 17 | description = "Tracking Anything with Decoupled Video Segmentation (DEVA), ICCV 2023" 18 | readme = "README.md" 19 | requires-python = ">=3.7" 20 | classifiers = [ 21 | "Programming Language :: Python :: 3", 22 | "Operating System :: OS Independent", 23 | ] 24 | dependencies = [ 25 | 'gitpython >= 3.1', 26 | 'thinplate@git+https://github.com/cheind/py-thin-plate-spline', 27 | 'hickle >= 5.0', 28 | 'tensorboard >= 2.12', 29 | 'numpy >= 1.22', 30 | 'Pillow >= 9.5', 31 | 'opencv-python >= 4.8', 32 | 'scipy >= 1.11.2', 33 | 'pycocotools >= 2.0.7', 34 | 'supervision >= 0.18', 35 | 'tqdm >= 4.66.1', 36 | 'gurobipy >= 10.0.3', 37 | 'PuLP >= 2.7', 38 | 'gradio >= 3.44', 39 | 'gdown >= 4.7.1', 40 | ] 41 | 42 | [tool.hatch.build.targets.wheel] 43 | packages = ["deva"] 44 | 45 | [project.urls] 46 | "Homepage" = "https://github.com/hkchengrex/Tracking-Anything-with-DEVA" 47 | "Bug Tracker" = "https://github.com/hkchengrex/Tracking-Anything-with-DEVA/issues" 48 | -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/download_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | import zipfile 4 | 5 | LICENSE = """ 6 | These are either redistributions of the original datasets or derivatives (through simple processing) of the original datasets. 7 | Please read and respect their licenses and terms before use. 8 | You should cite the original papers if you use any of the datasets.
9 | 10 | Links: 11 | DUTS: http://saliencydetection.net/duts 12 | HRSOD: https://github.com/yi94code/HRSOD 13 | FSS: https://github.com/HKUSTCV/FSS-1000 14 | ECSSD: https://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html 15 | BIG: https://github.com/hkchengrex/CascadePSP 16 | 17 | YouTubeVOS: https://youtube-vos.org 18 | DAVIS: https://davischallenge.org/ 19 | """ 20 | 21 | print(LICENSE) 22 | print('Datasets will be downloaded and extracted to ../YouTube, ../static, ../DAVIS') 23 | reply = input('[y] to confirm, others to exit: ') 24 | if reply != 'y': 25 | exit() 26 | """ 27 | Static image data 28 | """ 29 | os.makedirs('../static', exist_ok=True) 30 | print('Downloading static datasets...') 31 | gdown.download('https://drive.google.com/uc?id=1wUJq3HcLdN-z1t4CsUhjeZ9BVDb9YKLd', 32 | output='../static/static_data.zip', 33 | quiet=False) 34 | print('Extracting static datasets...') 35 | with zipfile.ZipFile('../static/static_data.zip', 'r') as zip_file: 36 | zip_file.extractall('../static/') 37 | print('Cleaning up static datasets...') 38 | os.remove('../static/static_data.zip') 39 | """ 40 | DAVIS dataset 41 | """ 42 | # Google drive mirror: https://drive.google.com/drive/folders/1hEczGHw7qcMScbCJukZsoOW4Q9byx16A?usp=sharing 43 | os.makedirs('../DAVIS/2017', exist_ok=True) 44 | 45 | print('Downloading DAVIS 2016...') 46 | gdown.download('https://drive.google.com/uc?id=198aRlh5CpAoFz0hfRgYbiNenn_K8DxWD', 47 | output='../DAVIS/DAVIS-data.zip', 48 | quiet=False) 49 | 50 | print('Downloading DAVIS 2017 trainval...') 51 | gdown.download('https://drive.google.com/uc?id=1kiaxrX_4GuW6NmiVuKGSGVoKGWjOdp6d', 52 | output='../DAVIS/2017/DAVIS-2017-trainval-480p.zip', 53 | quiet=False) 54 | 55 | print('Downloading DAVIS 2017 testdev...') 56 | gdown.download('https://drive.google.com/uc?id=1fmkxU2v9cQwyb62Tj1xFDdh2p4kDsUzD', 57 | output='../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', 58 | quiet=False) 59 | 60 | print('Extracting DAVIS datasets...') 61 | with zipfile.ZipFile('../DAVIS/DAVIS-data.zip', 'r') as zip_file: 62 | zip_file.extractall('../DAVIS/') 63 | os.rename('../DAVIS/DAVIS', '../DAVIS/2016') 64 | 65 | with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-trainval-480p.zip', 'r') as zip_file: 66 | zip_file.extractall('../DAVIS/2017/') 67 | os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/trainval') 68 | 69 | with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', 'r') as zip_file: 70 | zip_file.extractall('../DAVIS/2017/') 71 | os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/test-dev') 72 | 73 | print('Cleaning up DAVIS datasets...') 74 | os.remove('../DAVIS/2017/DAVIS-2017-trainval-480p.zip') 75 | os.remove('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip') 76 | os.remove('../DAVIS/DAVIS-data.zip') 77 | """ 78 | YouTubeVOS dataset 79 | """ 80 | os.makedirs('../YouTube', exist_ok=True) 81 | os.makedirs('../YouTube/all_frames', exist_ok=True) 82 | 83 | print('Downloading YouTubeVOS train...') 84 | gdown.download('https://drive.google.com/uc?id=13Eqw0gVK-AO5B-cqvJ203mZ2vzWck9s4', 85 | output='../YouTube/train.zip', 86 | quiet=False) 87 | print('Downloading YouTubeVOS val...') 88 | gdown.download('https://drive.google.com/uc?id=1o586Wjya-f2ohxYf9C1RlRH-gkrzGS8t', 89 | output='../YouTube/valid.zip', 90 | quiet=False) 91 | print('Downloading YouTubeVOS all frames valid...') 92 | gdown.download('https://drive.google.com/uc?id=1rWQzZcMskgpEQOZdJPJ7eTmLCBEIIpEN', 93 | output='../YouTube/all_frames/valid.zip', 94 | quiet=False) 95 | 96 | print('Extracting YouTube datasets...') 97 | with 
zipfile.ZipFile('../YouTube/train.zip', 'r') as zip_file: 98 | zip_file.extractall('../YouTube/') 99 | with zipfile.ZipFile('../YouTube/valid.zip', 'r') as zip_file: 100 | zip_file.extractall('../YouTube/') 101 | with zipfile.ZipFile('../YouTube/all_frames/valid.zip', 'r') as zip_file: 102 | zip_file.extractall('../YouTube/all_frames') 103 | 104 | print('Cleaning up YouTubeVOS datasets...') 105 | os.remove('../YouTube/train.zip') 106 | os.remove('../YouTube/valid.zip') 107 | os.remove('../YouTube/all_frames/valid.zip') 108 | """ 109 | OVIS dataset 110 | """ 111 | os.makedirs('../OVIS-VOS-train', exist_ok=True) 112 | print('Downloading OVIS-VOS train...') 113 | gdown.download('https://drive.google.com/uc?id=1AZPyyqVqOl6j8THgZ1UdNJY9R1VGEFrX', 114 | output='../OVIS-VOS-train/OVIS-VOS-train.zip', 115 | quiet=False) 116 | print('Extracting OVIS...') 117 | with zipfile.ZipFile('../OVIS-VOS-train/OVIS-VOS-train.zip', 'r') as zip_file: 118 | zip_file.extractall('../OVIS-VOS-train/') 119 | print('Cleaning up OVIS...') 120 | os.remove('../OVIS-VOS-train/OVIS-VOS-train.zip') 121 | 122 | print('Done.') -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | wget -P ./saves/ https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/DEVA-propagation.pth 2 | wget -P ./saves/ https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth 3 | wget -P ./saves/ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 4 | wget -O ./saves/sam_hq_vit_h.pth https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth?download=true 5 | wget -O ./saves/sam_hq_vit_tiny.pth https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth?download=true 6 | wget -P ./saves/ https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/mobile_sam.pt 7 | wget -P ./saves/ https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/GroundingDINO_SwinT_OGC.py -------------------------------------------------------------------------------- /scripts/merge_burst_json.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import os 4 | from os import path 5 | import tqdm 6 | 7 | gt_json_path = sys.argv[1] 8 | pred_path = sys.argv[2] 9 | out_path = sys.argv[3] 10 | 11 | with open(gt_json_path) as f: 12 | json_file = json.load(f) 13 | 14 | for sequence in tqdm.tqdm(json_file['sequences']): 15 | dataset = sequence['dataset'] 16 | seq_name = sequence['seq_name'] 17 | 18 | sequence['segmentations'] = [] 19 | 20 | with open(path.join(pred_path, dataset, seq_name, 'pred.json')) as f: 21 | pred_json = json.load(f) 22 | track_category_id = {} 23 | for frame_segmentation in pred_json['segmentations']: 24 | this_frame_segmentation = {} 25 | 26 | for segmentation_dict in frame_segmentation['segmentations']: 27 | this_frame_segmentation[segmentation_dict['id']] = { 28 | 'rle': segmentation_dict['rle']['counts'] 29 | } 30 | track_category_id[segmentation_dict['id']] = 0 31 | sequence['segmentations'].append(this_frame_segmentation) 32 | 33 | sequence['track_category_ids'] = track_category_id 34 | 35 | 36 | with open(out_path, 'w') as f: 37 | json.dump(json_file, f) 38 | -------------------------------------------------------------------------------- /scripts/merge_multi_scale.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | from os import path 3 | from argparse import ArgumentParser 4 | import glob 5 | from collections import defaultdict 6 | 7 | import numpy as np 8 | import hickle as hkl 9 | from PIL import Image, ImagePalette 10 | 11 | from tqdm import tqdm 12 | from multiprocessing import Pool 13 | from deva.utils import palette 14 | 15 | from deva.utils.palette import davis_palette, youtube_palette 16 | import shutil 17 | 18 | 19 | def search_options(options, name): 20 | for option in options: 21 | if path.exists(path.join(option, name)): 22 | return path.join(option, name) 23 | else: 24 | return None 25 | 26 | 27 | def process_vid(vid): 28 | vid_path = search_options(all_options, vid) 29 | if vid_path is not None: 30 | backward_mapping = hkl.load(path.join(vid_path, 'backward.hkl')) 31 | else: 32 | backward_mapping = None 33 | 34 | frames = os.listdir(path.join(all_options[0], vid)) 35 | frames = [f for f in frames if 'backward' not in f] 36 | 37 | print(vid) 38 | if 'Y' in args.dataset: 39 | this_out_path = path.join(out_path, 'Annotations', vid) 40 | else: 41 | this_out_path = path.join(out_path, vid) 42 | os.makedirs(this_out_path, exist_ok=True) 43 | 44 | for f in frames: 45 | result_sum = None 46 | 47 | for option in all_options: 48 | if not path.exists(path.join(option, vid, f)): 49 | continue 50 | 51 | result = hkl.load(path.join(option, vid, f)) 52 | if result_sum is None: 53 | result_sum = result.astype(np.float32) 54 | else: 55 | result_sum += result 56 | 57 | # argmax and to idx 58 | result_sum = np.argmax(result_sum, axis=0) 59 | 60 | # Remap the indices to the original domain 61 | if backward_mapping is not None: 62 | idx_mask = np.zeros_like(result_sum, dtype=np.uint8) 63 | for l, i in backward_mapping.items(): 64 | idx_mask[result_sum == i] = l 65 | else: 66 | idx_mask = result_sum.astype(np.uint8) 67 | 68 | # Save the results 69 | img_E = Image.fromarray(idx_mask) 70 | img_E.putpalette(palette) 71 | img_E.save(path.join(this_out_path, f[:-4] + '.png')) 72 | 73 | 74 | if __name__ == '__main__': 75 | """ 76 | Arguments loading 77 | """ 78 | parser = ArgumentParser() 79 | parser.add_argument('--dataset', default='Y', help='D/Y, D for DAVIS; Y for YouTubeVOS') 80 | parser.add_argument('--list', nargs="+") 81 | parser.add_argument('--pattern', 82 | default=None, 83 | help='Glob pattern. 
Can be used in place of list.') 84 | parser.add_argument('--output') 85 | parser.add_argument('--num_proc', default=4, type=int) 86 | args = parser.parse_args() 87 | 88 | out_path = args.output 89 | 90 | # Find the input candidates 91 | if args.pattern is None: 92 | all_options = args.list 93 | else: 94 | assert args.list is None, 'cannot specify both list and pattern' 95 | all_options = glob.glob(args.pattern) 96 | 97 | # Get the correct palette 98 | if 'D' in args.dataset: 99 | palette = ImagePalette.ImagePalette(mode='P', palette=davis_palette) 100 | elif 'Y' in args.dataset: 101 | palette = ImagePalette.ImagePalette(mode='P', palette=youtube_palette) 102 | else: 103 | raise NotImplementedError 104 | 105 | # Count of the number of videos in each candidate 106 | all_options = [path.join(o, 'Scores') for o in all_options] 107 | vid_count = defaultdict(int) 108 | for option in all_options: 109 | vid_in_here = sorted(os.listdir(option)) 110 | for vid in vid_in_here: 111 | vid_count[vid] += 1 112 | 113 | all_vid = [] 114 | count_to_vid = defaultdict(int) 115 | for k, v in vid_count.items(): 116 | count_to_vid[v] += 1 117 | all_vid.append(k) 118 | 119 | for k, v in count_to_vid.items(): 120 | print('Videos with count %d: %d' % (k, v)) 121 | 122 | all_vid = sorted(all_vid) 123 | print('Total number of videos: ', len(all_vid)) 124 | 125 | pool = Pool(processes=args.num_proc) 126 | for _ in tqdm(pool.imap_unordered(process_vid, all_vid), total=len(all_vid)): 127 | pass 128 | 129 | pool.close() 130 | pool.join() 131 | 132 | if 'D' in args.dataset: 133 | print('Making zip for DAVIS test-dev...') 134 | shutil.make_archive(args.output, 'zip', args.output) 135 | 136 | if 'Y' in args.dataset: 137 | print('Making zip for YouTubeVOS...') 138 | shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output, 139 | 'Annotations') 140 | -------------------------------------------------------------------------------- /scripts/vipseg/change2_720p.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from multiprocessing import Pool 4 | 5 | 6 | DIR='imgs' 7 | DIR2='panomasks' 8 | 9 | Target_Dir = 'VIPSeg_720P' 10 | 11 | 12 | def change(DIR,video,image): 13 | if os.path.isfile(os.path.join(Target_Dir,'images',video,image)) and os.path.isfile(os.path.join(Target_Dir,'panomasks',video,image.split('.')[0]+'.png')): 14 | return 15 | 16 | img = Image.open(os.path.join(DIR,video,image)) 17 | w,h = img.size 18 | img = img.resize((int(720*w/h),720),Image.BILINEAR) 19 | 20 | if not os.path.isfile(os.path.join(DIR2,video,image.split('.')[0]+'.png')): 21 | # print('this is the test set') 22 | # print(os.path.join(DIR2,video,image.split('.')[0]+'.png')) 23 | return 24 | 25 | 26 | mask = Image.open(os.path.join(DIR2,video,image.split('.')[0]+'.png')) 27 | mask = mask.resize((int(720*w/h),720),Image.NEAREST) 28 | 29 | if not os.path.exists(os.path.join(Target_Dir,'images',video)): 30 | os.makedirs(os.path.join(Target_Dir,'images',video)) 31 | if not os.path.exists(os.path.join(Target_Dir,'panomasks',video)): 32 | os.makedirs(os.path.join(Target_Dir,'panomasks',video)) 33 | 34 | img.save(os.path.join(Target_Dir,'images',video,image)) 35 | mask.save(os.path.join(Target_Dir,'panomasks',video,image.split('.')[0]+'.png')) 36 | # print('Processing video {} image {}'.format(video,image)) 37 | 38 | p = Pool(16) 39 | for video in sorted(os.listdir(DIR)): 40 | print(video) 41 | if video[0]=='.': 42 | continue 43 | for image in
sorted(os.listdir(os.path.join(DIR,video))): 44 | if image[0]=='.': 45 | continue 46 | p.apply_async(change,args=(DIR,video,image)) 47 | #change(DIR,video,image) 48 | p.close() 49 | p.join() 50 | print('finish') 51 | 52 | -------------------------------------------------------------------------------- /scripts/vipseg/create_panoptic_video_labels.py: -------------------------------------------------------------------------------- 1 | 2 | import os, sys 3 | import json 4 | import glob 5 | import numpy as np 6 | import PIL.Image as Image 7 | from tqdm import trange 8 | from panopticapi.utils import IdGenerator, save_json 9 | import argparse 10 | 11 | from multiprocessing import Pool 12 | 13 | ROOT_DIR='VIPSeg_720P/panomasks' 14 | TARGET_DIR='VIPSeg_720P/panomasksRGB' 15 | with open('VIPSeg_720P/panoVIPSeg_categories.json','r') as f: 16 | CATEGORIES = json.load(f) 17 | 18 | 19 | if __name__ == "__main__": 20 | 21 | def conversion_worker(video): 22 | # print('processing video:{}'.format(video)) 23 | videos_dic={} 24 | video_id = video 25 | videos_dic['video_id'] = video_id 26 | 27 | images = [] 28 | annotations = [] 29 | id_generator = IdGenerator(categories_dict) 30 | instid2color = {} 31 | 32 | 33 | for imgname in sorted(os.listdir(os.path.join(original_format_folder,video))): 34 | 35 | original_format = np.array(Image.open(os.path.join(original_format_folder,video,imgname))) 36 | image_id = imgname.split('.')[0] 37 | image_filename = imgname 38 | images.append({"id": image_id, 39 | "width": original_format.shape[1], 40 | "height": original_format.shape[0], 41 | "file_name": image_filename}) 42 | pan_format = np.zeros((original_format.shape[0], original_format.shape[1], 3), dtype=np.uint8) 43 | 44 | l = np.unique(original_format) 45 | 46 | segm_info = {} 47 | 48 | for el in l: 49 | if el==0: 50 | continue 51 | if el < 125: 52 | semantic_id = el 53 | is_crowd = 0 54 | else: 55 | semantic_id = el // 100 56 | is_crowd = 0 57 | semantic_id = semantic_id -1 58 | if semantic_id not in categories_dict: 59 | print(semantic_id, video, l, imgname, list(categories_dict.keys())) 60 | if semantic_id > 255: 61 | print(semantic_id, video, l, imgname) 62 | if categories_dict[semantic_id]['isthing'] == 0: 63 | is_crowd = 0 64 | mask = (original_format == el) 65 | 66 | if el not in instid2color: 67 | segment_id, color = id_generator.get_id_and_color(semantic_id) 68 | instid2color[el] = (segment_id, color) 69 | else: 70 | segment_id, color = instid2color[el] 71 | 72 | pan_format[mask] = color 73 | segm_info[int(segment_id)] = \ 74 | {"id": int(segment_id), 75 | "category_id": int(semantic_id), 76 | # "area": int(area), 77 | "iscrowd": is_crowd} 78 | if not os.path.exists(os.path.join(out_folder,video)): 79 | os.makedirs(os.path.join(out_folder,video)) 80 | 81 | Image.fromarray(pan_format).save(os.path.join(out_folder,video, image_filename)) 82 | # print('image saved {}'.format(os.path.join(out_folder,video, image_filename))) 83 | 84 | gt_pan = np.uint32(pan_format) 85 | pan_gt = gt_pan[:, :, 0] + gt_pan[:, :, 1] * 256 + gt_pan[:, :, 2] * 256 * 256 86 | # print(np.unique(pan_gt)) 87 | # exit() 88 | labels, labels_cnt = np.unique(pan_gt, return_counts=True) 89 | gt_labels = [_ for _ in segm_info.keys()] 90 | gt_labels_set = set(gt_labels) 91 | for label, area in zip(labels, labels_cnt): 92 | if label == 0: 93 | continue 94 | if label not in gt_labels and label > 0: 95 | print('png label not in json labels.', label, video, gt_labels, l) 96 | segm_info[label]["area"] = int(area) 97 | gt_labels_set.remove(label) 98 
| if len(gt_labels_set) != 0: 99 | raise KeyError('remaining gt_labels json') 100 | 101 | segm_info = [v for k,v in segm_info.items()] 102 | annotations.append({'image_id': image_id, 103 | 'file_name': image_filename, 104 | "segments_info": segm_info}) 105 | v_anno ={'video_id':video_id,'annotations':annotations} 106 | videos_dic['images'] = images 107 | return v_anno, videos_dic 108 | # return None 109 | 110 | 111 | original_format_folder = ROOT_DIR 112 | # folder to store panoptic PNGs 113 | out_folder = os.path.join(TARGET_DIR) 114 | out_file = os.path.join('VIPSeg_720P/', 'panoptic_gt_VIPSeg.json') 115 | if not os.path.isdir(out_folder): 116 | os.makedirs(out_folder) 117 | 118 | categories = CATEGORIES 119 | categories_dict = {el['id']: el for el in CATEGORIES} 120 | 121 | v_videos=[] 122 | v_annotations=[] 123 | pool = Pool(16) 124 | 125 | results = pool.map(conversion_worker, sorted(os.listdir(original_format_folder)), chunksize=8) 126 | # results = [conversion_worker(p) for p in sorted(os.listdir(original_format_folder))] 127 | 128 | for v_anno, videos_dic in results: 129 | v_videos.append(videos_dic) 130 | v_annotations.append(v_anno) 131 | 132 | d = {'videos': v_videos, 133 | 'annotations': v_annotations, 134 | 'categories': categories, 135 | } 136 | 137 | save_json(d, out_file) 138 | 139 | 140 | print('==> Saved json file at %s'%(out_file)) 141 | --------------------------------------------------------------------------------