├── .gitignore
├── LICENSE.md
├── README.md
├── demo
│   ├── __init__.py
│   ├── demo_automatic.py
│   ├── demo_gradio.py
│   └── demo_with_text.py
├── deva
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── static_dataset.py
│   │   ├── tps.py
│   │   ├── utils.py
│   │   └── vos_dataset.py
│   ├── ext
│   │   ├── LightHQSAM
│   │   │   ├── __init__.py
│   │   │   ├── setup_light_hqsam.py
│   │   │   └── tiny_vit_sam.py
│   │   ├── MobileSAM
│   │   │   ├── __init__.py
│   │   │   ├── setup_mobile_sam.py
│   │   │   └── tiny_vit_sam.py
│   │   ├── SAM
│   │   │   ├── __init__.py
│   │   │   └── automatic_mask_generator.py
│   │   ├── __init__.py
│   │   ├── automatic_processor.py
│   │   ├── automatic_sam.py
│   │   ├── ext_eval_args.py
│   │   ├── grounding_dino.py
│   │   └── with_text_processor.py
│   ├── inference
│   │   ├── __init__.py
│   │   ├── consensus_associated.py
│   │   ├── consensus_automatic.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── detection_video_reader.py
│   │   │   ├── referring_test_datasets.py
│   │   │   ├── saliency_test_datasets.py
│   │   │   ├── simple_video_reader.py
│   │   │   ├── video_reader.py
│   │   │   ├── vos_test_datasets.py
│   │   │   └── vps_test_datasets.py
│   │   ├── demo_utils.py
│   │   ├── eval_args.py
│   │   ├── frame_utils.py
│   │   ├── image_feature_store.py
│   │   ├── inference_core.py
│   │   ├── kv_memory_store.py
│   │   ├── memory_manager.py
│   │   ├── object_info.py
│   │   ├── object_manager.py
│   │   ├── object_utils.py
│   │   ├── postprocess_unsup_davis17.py
│   │   ├── result_utils.py
│   │   └── segment_merging.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── big_modules.py
│   │   ├── cbam.py
│   │   ├── group_modules.py
│   │   ├── losses.py
│   │   ├── memory_utils.py
│   │   ├── modules.py
│   │   ├── network.py
│   │   ├── resnet.py
│   │   └── trainer.py
│   ├── train.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── burst_test.txt
│   │   ├── burst_val.txt
│   │   ├── configuration.py
│   │   ├── davis_subset.txt
│   │   ├── image_saver.py
│   │   ├── load_subset.py
│   │   ├── log_integrator.py
│   │   ├── logger.py
│   │   ├── palette.py
│   │   ├── pano_utils.py
│   │   ├── referring-youtubevos-val.txt
│   │   ├── tensor_utils.py
│   │   ├── vipseg_categories.py
│   │   └── yv_subset.txt
│   └── vps_metrics
│       ├── README.md
│       ├── __init__.py
│       ├── eval_stq_vipseg.py
│       ├── eval_vpq_vipseg.py
│       ├── segmentation_and_tracking_quality.py
│       └── stuff_merging.py
├── docs
│   ├── CUSTOM.md
│   ├── DEMO.md
│   ├── EVALUATION.md
│   ├── TRAINING.md
│   ├── index.html
│   └── style.css
├── evaluation
│   ├── __init__.py
│   ├── eval_ref_davis.py
│   ├── eval_ref_youtubevos.py
│   ├── eval_saliency.py
│   ├── eval_vos.py
│   └── eval_with_detections.py
├── example
│   ├── vipseg
│   │   ├── images
│   │   │   └── 12_1mWNahzcsAc
│   │   │       ├── 00001255.jpg
│   │   │       ├── 00001258.jpg
│   │   │       ├── 00001261.jpg
│   │   │       └── 00001264.jpg
│   │   └── source
│   │       └── 12_1mWNahzcsAc
│   │           ├── 00001255.json
│   │           ├── 00001255.png
│   │           ├── 00001258.json
│   │           ├── 00001258.png
│   │           ├── 00001261.json
│   │           ├── 00001261.png
│   │           ├── 00001264.json
│   │           └── 00001264.png
│   └── vos
│       ├── Annotations
│       │   └── bmx-trees
│       │       └── 00000.png
│       └── JPEGImages
│           └── bmx-trees
│               ├── 00000.jpg
│               ├── 00001.jpg
│               ├── 00002.jpg
│               └── 00003.jpg
├── pyproject.toml
└── scripts
    ├── __init__.py
    ├── download_datasets.py
    ├── download_models.sh
    ├── merge_burst_json.py
    ├── merge_multi_scale.py
    └── vipseg
        ├── change2_720p.py
        └── create_panoptic_video_labels.py
/.gitignore:
--------------------------------------------------------------------------------
1 | run_*.sh
2 | log/
3 | saves
4 | saves/
5 | output/
6 | .vscode/
7 | example/*/output
8 | example/*/postprocessed
9 | example/videos
10 | example/output
11 | gradio_cached_examples/
12 | flagged/
13 | *.pth
14 | *.pt
15 |
16 | # Byte-compiled / optimized / DLL files
17 | __pycache__/
18 | *.py[cod]
19 | *$py.class
20 |
21 | # C extensions
22 | *.so
23 |
24 | # Distribution / packaging
25 | .Python
26 | build/
27 | develop-eggs/
28 | dist/
29 | downloads/
30 | eggs/
31 | .eggs/
32 | lib/
33 | lib64/
34 | parts/
35 | sdist/
36 | var/
37 | wheels/
38 | pip-wheel-metadata/
39 | share/python-wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | MANIFEST
44 |
45 | # PyInstaller
46 | # Usually these files are written by a python script from a template
47 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
48 | *.manifest
49 | *.spec
50 |
51 | # Installer logs
52 | pip-log.txt
53 | pip-delete-this-directory.txt
54 |
55 | # Unit test / coverage reports
56 | htmlcov/
57 | .tox/
58 | .nox/
59 | .coverage
60 | .coverage.*
61 | .cache
62 | nosetests.xml
63 | coverage.xml
64 | *.cover
65 | *.py,cover
66 | .hypothesis/
67 | .pytest_cache/
68 |
69 | # Translations
70 | *.mo
71 | *.pot
72 |
73 | # Django stuff:
74 | *.log
75 | local_settings.py
76 | db.sqlite3
77 | db.sqlite3-journal
78 |
79 | # Flask stuff:
80 | instance/
81 | .webassets-cache
82 |
83 | # Scrapy stuff:
84 | .scrapy
85 |
86 | # Sphinx documentation
87 | docs/_build/
88 |
89 | # PyBuilder
90 | target/
91 |
92 | # Jupyter Notebook
93 | .ipynb_checkpoints
94 |
95 | # IPython
96 | profile_default/
97 | ipython_config.py
98 |
99 | # pyenv
100 | .python-version
101 |
102 | # pipenv
103 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
104 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
105 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
106 | # install all needed dependencies.
107 | #Pipfile.lock
108 |
109 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
110 | __pypackages__/
111 |
112 | # Celery stuff
113 | celerybeat-schedule
114 | celerybeat.pid
115 |
116 | # SageMath parsed files
117 | *.sage.py
118 |
119 | # Environments
120 | .env
121 | .venv
122 | env/
123 | venv/
124 | ENV/
125 | env.bak/
126 | venv.bak/
127 |
128 | # Spyder project settings
129 | .spyderproject
130 | .spyproject
131 |
132 | # Rope project settings
133 | .ropeproject
134 |
135 | # mkdocs documentation
136 | /site
137 |
138 | # mypy
139 | .mypy_cache/
140 | .dmypy.json
141 | dmypy.json
142 |
143 | # Pyre type checker
144 | .pyre/
145 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | 
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
2 |
--------------------------------------------------------------------------------
/demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/demo/__init__.py
--------------------------------------------------------------------------------
/demo/demo_automatic.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | from argparse import ArgumentParser
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | import numpy as np
8 |
9 | from deva.inference.inference_core import DEVAInferenceCore
10 | from deva.inference.data.simple_video_reader import SimpleVideoReader, no_collate
11 | from deva.inference.result_utils import ResultSaver
12 | from deva.inference.eval_args import add_common_eval_args, get_model_and_config
13 | from deva.inference.demo_utils import flush_buffer
14 | from deva.ext.ext_eval_args import add_ext_eval_args, add_auto_default_args
15 | from deva.ext.automatic_sam import get_sam_model
16 | from deva.ext.automatic_processor import process_frame_automatic as process_frame
17 |
18 | from tqdm import tqdm
19 | import json
20 |
21 | if __name__ == '__main__':
22 | torch.autograd.set_grad_enabled(False)
23 |
24 | # for id2rgb
25 | np.random.seed(42)
26 | """
27 | Arguments loading
28 | """
29 | parser = ArgumentParser()
30 |
31 | add_common_eval_args(parser)
32 | add_ext_eval_args(parser)
33 | add_auto_default_args(parser)
34 | deva_model, cfg, args = get_model_and_config(parser)
35 | sam_model = get_sam_model(cfg, 'cuda')
36 | """
37 | Temporal setting
38 | """
39 | cfg['temporal_setting'] = args.temporal_setting.lower()
40 | assert cfg['temporal_setting'] in ['semionline', 'online']
41 |
42 | # get data
43 | video_reader = SimpleVideoReader(cfg['img_path'])
44 | loader = DataLoader(video_reader, batch_size=None, collate_fn=no_collate, num_workers=8)
45 | out_path = cfg['output']
46 |
47 | # Start eval
48 | vid_length = len(loader)
49 | # no need to count usage for LT if the video is not that long anyway
50 | cfg['enable_long_term_count_usage'] = (
51 | cfg['enable_long_term']
52 | and (vid_length / (cfg['max_mid_term_frames'] - cfg['min_mid_term_frames']) *
53 | cfg['num_prototypes']) >= cfg['max_long_term_elements'])
54 |
55 | print('Configuration:', cfg)
56 |
57 | deva = DEVAInferenceCore(deva_model, config=cfg)
58 | deva.next_voting_frame = args.num_voting_frames - 1
59 | deva.enabled_long_id()
60 | result_saver = ResultSaver(out_path, None, dataset='demo', object_manager=deva.object_manager)
61 |
62 | with torch.cuda.amp.autocast(enabled=args.amp):
63 | for ti, (frame, im_path) in enumerate(tqdm(loader)):
64 | process_frame(deva, sam_model, im_path, result_saver, ti, image_np=frame)
65 | flush_buffer(deva, result_saver)
66 | result_saver.end()
67 |
68 | # save this as a video-level json
69 | with open(path.join(out_path, 'pred.json'), 'w') as f:
70 | json.dump(result_saver.video_json, f, indent=4) # prettier json
71 |
--------------------------------------------------------------------------------
/demo/demo_with_text.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | from argparse import ArgumentParser
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | import numpy as np
8 |
9 | from deva.inference.inference_core import DEVAInferenceCore
10 | from deva.inference.data.simple_video_reader import SimpleVideoReader, no_collate
11 | from deva.inference.result_utils import ResultSaver
12 | from deva.inference.eval_args import add_common_eval_args, get_model_and_config
13 | from deva.inference.demo_utils import flush_buffer
14 | from deva.ext.ext_eval_args import add_ext_eval_args, add_text_default_args
15 | from deva.ext.grounding_dino import get_grounding_dino_model
16 | from deva.ext.with_text_processor import process_frame_with_text as process_frame
17 |
18 | from tqdm import tqdm
19 | import json
20 |
21 | if __name__ == '__main__':
22 | torch.autograd.set_grad_enabled(False)
23 |
24 | # for id2rgb
25 | np.random.seed(42)
26 | """
27 | Arguments loading
28 | """
29 | parser = ArgumentParser()
30 |
31 | add_common_eval_args(parser)
32 | add_ext_eval_args(parser)
33 | add_text_default_args(parser)
34 | deva_model, cfg, args = get_model_and_config(parser)
35 | gd_model, sam_model = get_grounding_dino_model(cfg, 'cuda')
36 | """
37 | Temporal setting
38 | """
39 | cfg['temporal_setting'] = args.temporal_setting.lower()
40 | assert cfg['temporal_setting'] in ['semionline', 'online']
41 |
42 | # get data
43 | video_reader = SimpleVideoReader(cfg['img_path'])
44 | loader = DataLoader(video_reader, batch_size=None, collate_fn=no_collate, num_workers=8)
45 | out_path = cfg['output']
46 |
47 | # Start eval
48 | vid_length = len(loader)
49 | # no need to count usage for LT if the video is not that long anyway
50 | cfg['enable_long_term_count_usage'] = (
51 | cfg['enable_long_term']
52 | and (vid_length / (cfg['max_mid_term_frames'] - cfg['min_mid_term_frames']) *
53 | cfg['num_prototypes']) >= cfg['max_long_term_elements'])
54 |
55 | print('Configuration:', cfg)
56 |
57 | deva = DEVAInferenceCore(deva_model, config=cfg)
58 | deva.next_voting_frame = cfg['num_voting_frames'] - 1
59 | deva.enabled_long_id()
60 | result_saver = ResultSaver(out_path, None, dataset='demo', object_manager=deva.object_manager)
61 |
62 | with torch.cuda.amp.autocast(enabled=cfg['amp']):
63 | for ti, (frame, im_path) in enumerate(tqdm(loader)):
64 | process_frame(deva, gd_model, sam_model, im_path, result_saver, ti, image_np=frame)
65 | flush_buffer(deva, result_saver)
66 | result_saver.end()
67 |
68 | # save this as a video-level json
69 | with open(path.join(out_path, 'pred.json'), 'w') as f:
70 | json.dump(result_saver.video_json, f, indent=4) # prettier json
71 |
--------------------------------------------------------------------------------
/deva/__init__.py:
--------------------------------------------------------------------------------
1 | from deva.inference.inference_core import DEVAInferenceCore
2 | from deva.model.network import DEVA
--------------------------------------------------------------------------------
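
A minimal sketch of how these two exports are wired together, following demo/demo_automatic.py. It assumes the common evaluation arguments alone are enough to build the config and load a checkpoint (the demo scripts also add the extension arguments first), and that the default model path used by add_common_eval_args points to a downloaded DEVA checkpoint:

from argparse import ArgumentParser

import torch

from deva import DEVAInferenceCore
from deva.inference.eval_args import add_common_eval_args, get_model_and_config

# build the config dict and load the propagation network from the argument defaults
torch.autograd.set_grad_enabled(False)
parser = ArgumentParser()
add_common_eval_args(parser)
deva_model, cfg, args = get_model_and_config(parser)

# wrap the network for frame-by-frame inference, as the demo scripts do
deva = DEVAInferenceCore(deva_model, config=cfg)
# per frame, the processors call deva.step(image, None, None) to propagate, or
# deva.incorporate_detection(image, mask, segments_info) when a new detection arrives
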
/deva/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/dataset/__init__.py
--------------------------------------------------------------------------------
/deva/dataset/static_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 |
4 | import torch
5 | from torch.utils.data.dataset import Dataset
6 | from torchvision import transforms
7 | from torchvision.transforms import InterpolationMode
8 | from PIL import Image
9 | import numpy as np
10 |
11 | from deva.dataset.utils import im_normalization, im_mean, reseed
12 | from deva.dataset.tps import random_tps_warp
13 |
14 |
15 | class StaticTransformDataset(Dataset):
16 | """
17 | Generate pseudo VOS data by applying random transforms on static images.
18 | Single-object only.
19 |
20 | parameters is a list of tuples, each of the form (data_root, data layout method (0 or 1), sample multiplier)
21 |
22 | Method 0 - FSS style (class/1.jpg class/1.png)
23 | Method 1 - Others style (XXX.jpg XXX.png)
24 | """
25 | def __init__(self, parameters, *, size=384, num_frames=3, max_num_obj=1):
26 | self.num_frames = num_frames
27 | self.max_num_obj = max_num_obj
28 | self.size = size
29 |
30 | self.im_list = []
31 | for parameter in parameters:
32 | root, method, multiplier = parameter
33 | if method == 0:
34 | # Get images
35 | classes = os.listdir(root)
36 | for c in classes:
37 | imgs = os.listdir(path.join(root, c))
38 | jpg_list = [im for im in imgs if 'jpg' in im[-3:].lower()]
39 |
40 | joint_list = [path.join(root, c, im) for im in jpg_list]
41 | self.im_list.extend(joint_list * multiplier)
42 |
43 | elif method == 1:
44 | self.im_list.extend(
45 | [path.join(root, im) for im in os.listdir(root) if '.jpg' in im] * multiplier)
46 |
47 | print(f'{len(self.im_list)} images found.')
48 |
49 | # This set of transforms is the same for each im/gt pair, but differs across the sampled frames
50 | self.pair_im_lone_transform = transforms.Compose([
51 | transforms.ColorJitter(0.1, 0.05, 0.05, 0),
52 | ])
53 |
54 | self.pair_im_dual_transform = transforms.Compose([
55 | transforms.RandomAffine(degrees=20,
56 | scale=(0.5, 2.0),
57 | shear=10,
58 | interpolation=InterpolationMode.BICUBIC,
59 | fill=im_mean),
60 | transforms.Resize(self.size, InterpolationMode.BICUBIC, antialias=True),
61 | transforms.RandomCrop((self.size, self.size), pad_if_needed=True, fill=im_mean),
62 | ])
63 |
64 | self.pair_gt_dual_transform = transforms.Compose([
65 | transforms.RandomAffine(
66 | degrees=20,
67 | scale=(0.5, 2.0),
68 | shear=10,
69 | # don't know why I used bicubic here.
70 | # Since GT is binary it shouldn't matter much
71 | interpolation=InterpolationMode.BICUBIC,
72 | fill=0),
73 | transforms.Resize(self.size, InterpolationMode.NEAREST),
74 | transforms.RandomCrop((self.size, self.size), pad_if_needed=True, fill=0),
75 | ])
76 |
77 | # These transforms are the same for all pairs in the sampled sequence
78 | self.all_im_lone_transform = transforms.Compose([
79 | transforms.ColorJitter(0.1, 0.05, 0.05, 0.05),
80 | transforms.RandomGrayscale(0.05),
81 | ])
82 |
83 | self.all_im_dual_transform = transforms.Compose([
84 | transforms.RandomAffine(degrees=0, scale=(0.5, 2.0), fill=im_mean),
85 | transforms.RandomHorizontalFlip(),
86 | ])
87 |
88 | self.all_gt_dual_transform = transforms.Compose([
89 | transforms.RandomAffine(degrees=0, scale=(0.5, 2.0), fill=0),
90 | transforms.RandomHorizontalFlip(),
91 | ])
92 |
93 | # Final transform without randomness
94 | self.final_im_transform = transforms.Compose([
95 | transforms.ToTensor(),
96 | im_normalization,
97 | ])
98 |
99 | self.final_gt_transform = transforms.Compose([
100 | transforms.ToTensor(),
101 | ])
102 |
103 | def _get_sample(self, idx):
104 | im = Image.open(self.im_list[idx]).convert('RGB')
105 | gt = Image.open(self.im_list[idx][:-3] + 'png').convert('L')
106 |
107 | sequence_seed = np.random.randint(2147483647)
108 |
109 | images = []
110 | masks = []
111 | for _ in range(self.num_frames):
112 | reseed(sequence_seed)
113 | this_im = self.all_im_dual_transform(im)
114 | this_im = self.all_im_lone_transform(this_im)
115 | reseed(sequence_seed)
116 | this_gt = self.all_gt_dual_transform(gt)
117 |
118 | pairwise_seed = np.random.randint(2147483647)
119 | reseed(pairwise_seed)
120 | this_im = self.pair_im_dual_transform(this_im)
121 | this_im = self.pair_im_lone_transform(this_im)
122 | reseed(pairwise_seed)
123 | this_gt = self.pair_gt_dual_transform(this_gt)
124 |
125 | # Use TPS only some of the time
126 | # Not because TPS is bad -- just that it is too slow and I need to speed up data loading
127 | if np.random.rand() < 0.33:
128 | this_im, this_gt = random_tps_warp(this_im, this_gt, scale=0.02)
129 |
130 | this_im = self.final_im_transform(this_im)
131 | this_gt = self.final_gt_transform(this_gt)
132 |
133 | images.append(this_im)
134 | masks.append(this_gt)
135 |
136 | images = torch.stack(images, 0)
137 | masks = torch.stack(masks, 0)
138 |
139 | return images, masks.numpy()
140 |
141 | def __getitem__(self, idx):
142 | additional_objects = np.random.randint(self.max_num_obj)
143 | indices = [idx, *np.random.randint(self.__len__(), size=additional_objects)]
144 |
145 | merged_images = None
146 | merged_masks = np.zeros((self.num_frames, self.size, self.size), dtype=np.int64)
147 |
148 | for i, list_id in enumerate(indices):
149 | images, masks = self._get_sample(list_id)
150 | if merged_images is None:
151 | merged_images = images
152 | else:
153 | merged_images = merged_images * (1 - masks) + images * masks
154 | merged_masks[masks[:, 0] > 0.5] = (i + 1)
155 |
156 | masks = merged_masks
157 |
158 | labels = np.unique(masks[0])
159 | # Remove background
160 | labels = labels[labels != 0]
161 | target_objects = labels.tolist()
162 |
163 | # Generate one-hot ground-truth
164 | cls_gt = np.zeros((self.num_frames, self.size, self.size), dtype=np.int64)
165 | first_frame_gt = np.zeros((1, self.max_num_obj, self.size, self.size), dtype=np.int64)
166 | for i, l in enumerate(target_objects):
167 | this_mask = (masks == l)
168 | cls_gt[this_mask] = i + 1
169 | first_frame_gt[0, i] = (this_mask[0])
170 | cls_gt = np.expand_dims(cls_gt, 1)
171 |
172 | info = {}
173 | info['name'] = self.im_list[idx]
174 | info['num_objects'] = max(1, len(target_objects))
175 |
176 | # 1 if the object exists, 0 otherwise
177 | selector = [1 if i < info['num_objects'] else 0 for i in range(self.max_num_obj)]
178 | selector = torch.FloatTensor(selector)
179 |
180 | data = {
181 | 'rgb': merged_images,
182 | 'first_frame_gt': first_frame_gt,
183 | 'cls_gt': cls_gt,
184 | 'selector': selector,
185 | 'info': info
186 | }
187 |
188 | return data
189 |
190 | def __len__(self):
191 | return len(self.im_list)
192 |
--------------------------------------------------------------------------------
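
A short usage sketch for the two directory layouts described in the class docstring above. The dataset roots below are placeholders; Method 0 expects class/1.jpg paired with class/1.png, and Method 1 expects XXX.jpg paired with XXX.png in the same folder:

from torch.utils.data import DataLoader

from deva.dataset.static_dataset import StaticTransformDataset

# each entry is (data_root, layout method, sample multiplier); the paths are placeholders
parameters = [
    ('path/to/fss_style_data', 0, 1),  # Method 0: class/1.jpg + class/1.png
    ('path/to/flat_data', 1, 2),       # Method 1: XXX.jpg + XXX.png, sampled twice as often
]
dataset = StaticTransformDataset(parameters, size=384, num_frames=3, max_num_obj=1)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

batch = next(iter(loader))
# batch['rgb']: B x num_frames x 3 x 384 x 384
# batch['first_frame_gt']: B x 1 x max_num_obj x 384 x 384
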
/deva/dataset/tps.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | import cv2
4 | import thinplate as tps
5 |
6 | cv2.setNumThreads(0)
7 |
8 | def pick_random_points(h, w, n_samples):
9 | y_idx = np.random.choice(np.arange(h), size=n_samples, replace=False)
10 | x_idx = np.random.choice(np.arange(w), size=n_samples, replace=False)
11 | return y_idx/h, x_idx/w
12 |
13 |
14 | def warp_dual_cv(img, mask, c_src, c_dst):
15 | dshape = img.shape
16 | theta = tps.tps_theta_from_points(c_src, c_dst, reduced=True)
17 | grid = tps.tps_grid(theta, c_dst, dshape)
18 | mapx, mapy = tps.tps_grid_to_remap(grid, img.shape)
19 | return cv2.remap(img, mapx, mapy, cv2.INTER_LINEAR), cv2.remap(mask, mapx, mapy, cv2.INTER_NEAREST)
20 |
21 |
22 | def random_tps_warp(img, mask, scale, n_ctrl_pts=12):
23 | """
24 | Apply a random TPS warp to the input image and mask.
25 | Uses randomness from numpy.
26 | """
27 | img = np.asarray(img)
28 | mask = np.asarray(mask)
29 |
30 | h, w = mask.shape
31 | points = pick_random_points(h, w, n_ctrl_pts)
32 | c_src = np.stack(points, 1)
33 | c_dst = c_src + np.random.normal(scale=scale, size=c_src.shape)
34 | warp_im, warp_gt = warp_dual_cv(img, mask, c_src, c_dst)
35 |
36 | return Image.fromarray(warp_im), Image.fromarray(warp_gt)
37 |
38 |
--------------------------------------------------------------------------------
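
A brief sketch of random_tps_warp with synthetic placeholder inputs; the scale value matches the call in static_dataset.py:

import numpy as np
from PIL import Image

from deva.dataset.tps import random_tps_warp

# placeholder inputs: a random RGB image and a binary mask of the same spatial size
img = Image.fromarray(np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8))
mask = Image.fromarray((np.random.rand(256, 256) > 0.5).astype(np.uint8) * 255)

# returns warped PIL images; scale controls the magnitude of the control-point jitter
warp_img, warp_mask = random_tps_warp(img, mask, scale=0.02)
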
/deva/dataset/utils.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | import torchvision.transforms as transforms
5 |
6 | im_mean = (124, 116, 104)
7 |
8 | im_normalization = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
9 |
10 | inv_im_trans = transforms.Normalize(mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
11 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
12 |
13 |
14 | def reseed(seed):
15 | random.seed(seed)
16 | torch.manual_seed(seed)
17 |
18 |
19 | def all_to_onehot(masks, labels):
20 | if len(masks.shape) == 3:
21 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1], masks.shape[2]), dtype=np.uint8)
22 | else:
23 | Ms = np.zeros((len(labels), masks.shape[0], masks.shape[1]), dtype=np.uint8)
24 |
25 | for ni, l in enumerate(labels):
26 | Ms[ni] = (masks == l).astype(np.uint8)
27 |
28 | return Ms
29 |
--------------------------------------------------------------------------------
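
A quick sketch of all_to_onehot, which expands an integer label map into one binary channel per label (toy inputs):

import numpy as np

from deva.dataset.utils import all_to_onehot

# a toy 4x4 label map with objects 1 and 2 on a zero background
masks = np.array([[0, 1, 1, 0],
                  [0, 1, 1, 0],
                  [2, 2, 0, 0],
                  [2, 2, 0, 0]], dtype=np.uint8)
labels = np.array([1, 2])

onehot = all_to_onehot(masks, labels)  # shape (2, 4, 4), uint8
assert onehot[0].sum() == 4 and onehot[1].sum() == 4
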
/deva/ext/LightHQSAM/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/ext/LightHQSAM/__init__.py
--------------------------------------------------------------------------------
/deva/ext/LightHQSAM/setup_light_hqsam.py:
--------------------------------------------------------------------------------
1 | from deva.ext.LightHQSAM.tiny_vit_sam import TinyViT
2 | from segment_anything.modeling import MaskDecoderHQ, PromptEncoder, Sam, TwoWayTransformer
3 |
4 | def setup_model():
5 | prompt_embed_dim = 256
6 | image_size = 1024
7 | vit_patch_size = 16
8 | image_embedding_size = image_size // vit_patch_size
9 | mobile_sam = Sam(
10 | image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
11 | embed_dims=[64, 128, 160, 320],
12 | depths=[2, 2, 6, 2],
13 | num_heads=[2, 4, 5, 10],
14 | window_sizes=[7, 7, 14, 7],
15 | mlp_ratio=4.,
16 | drop_rate=0.,
17 | drop_path_rate=0.0,
18 | use_checkpoint=False,
19 | mbconv_expand_ratio=4.0,
20 | local_conv_size=3,
21 | layer_lr_decay=0.8
22 | ),
23 | prompt_encoder=PromptEncoder(
24 | embed_dim=prompt_embed_dim,
25 | image_embedding_size=(image_embedding_size, image_embedding_size),
26 | input_image_size=(image_size, image_size),
27 | mask_in_chans=16,
28 | ),
29 | mask_decoder=MaskDecoderHQ(
30 | num_multimask_outputs=3,
31 | transformer=TwoWayTransformer(
32 | depth=2,
33 | embedding_dim=prompt_embed_dim,
34 | mlp_dim=2048,
35 | num_heads=8,
36 | ),
37 | transformer_dim=prompt_embed_dim,
38 | iou_head_depth=3,
39 | iou_head_hidden_dim=256,
40 | vit_dim=160,
41 | ),
42 | pixel_mean=[123.675, 116.28, 103.53],
43 | pixel_std=[58.395, 57.12, 57.375],
44 | )
45 | return mobile_sam
--------------------------------------------------------------------------------
/deva/ext/MobileSAM/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/ext/MobileSAM/__init__.py
--------------------------------------------------------------------------------
/deva/ext/MobileSAM/setup_mobile_sam.py:
--------------------------------------------------------------------------------
1 | # https://github.com/IDEA-Research/Grounded-Segment-Anything/blob/main/EfficientSAM/MobileSAM/setup_mobile_sam.py
2 |
3 | from deva.ext.MobileSAM.tiny_vit_sam import TinyViT
4 | from segment_anything.modeling import MaskDecoder, PromptEncoder, Sam, TwoWayTransformer
5 |
6 |
7 | def setup_model():
8 | prompt_embed_dim = 256
9 | image_size = 1024
10 | vit_patch_size = 16
11 | image_embedding_size = image_size // vit_patch_size
12 | mobile_sam = Sam(
13 | image_encoder=TinyViT(img_size=1024,
14 | in_chans=3,
15 | num_classes=1000,
16 | embed_dims=[64, 128, 160, 320],
17 | depths=[2, 2, 6, 2],
18 | num_heads=[2, 4, 5, 10],
19 | window_sizes=[7, 7, 14, 7],
20 | mlp_ratio=4.,
21 | drop_rate=0.,
22 | drop_path_rate=0.0,
23 | use_checkpoint=False,
24 | mbconv_expand_ratio=4.0,
25 | local_conv_size=3,
26 | layer_lr_decay=0.8),
27 | prompt_encoder=PromptEncoder(
28 | embed_dim=prompt_embed_dim,
29 | image_embedding_size=(image_embedding_size, image_embedding_size),
30 | input_image_size=(image_size, image_size),
31 | mask_in_chans=16,
32 | ),
33 | mask_decoder=MaskDecoder(
34 | num_multimask_outputs=3,
35 | transformer=TwoWayTransformer(
36 | depth=2,
37 | embedding_dim=prompt_embed_dim,
38 | mlp_dim=2048,
39 | num_heads=8,
40 | ),
41 | transformer_dim=prompt_embed_dim,
42 | iou_head_depth=3,
43 | iou_head_hidden_dim=256,
44 | ),
45 | pixel_mean=[123.675, 116.28, 103.53],
46 | pixel_std=[58.395, 57.12, 57.375],
47 | )
48 | return mobile_sam
--------------------------------------------------------------------------------
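
For reference, a sketch of how a Mobile SAM checkpoint is loaded into the model returned by setup_model(), mirroring get_sam_model and get_grounding_dino_model elsewhere in deva/ext. The checkpoint path is the default from ext_eval_args.py and must already exist locally:

import torch
from segment_anything import SamPredictor

from deva.ext.MobileSAM.setup_mobile_sam import setup_model as setup_mobile_sam

# default path from deva/ext/ext_eval_args.py; download the weights there first
checkpoint = torch.load('./saves/mobile_sam.pt')
mobile_sam = setup_mobile_sam()
mobile_sam.load_state_dict(checkpoint, strict=True)
mobile_sam.to(device='cuda')

# the text-prompted pipeline wraps it in SamPredictor;
# the automatic pipeline wraps it in SamAutomaticMaskGenerator instead
predictor = SamPredictor(mobile_sam)
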
/deva/ext/SAM/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/ext/SAM/__init__.py
--------------------------------------------------------------------------------
/deva/ext/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/ext/__init__.py
--------------------------------------------------------------------------------
/deva/ext/automatic_processor.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | from typing import Dict, List, Optional
3 |
4 | import cv2
5 | import torch
6 | import numpy as np
7 |
8 | from deva.inference.object_info import ObjectInfo
9 | from deva.inference.inference_core import DEVAInferenceCore
10 | from deva.inference.frame_utils import FrameInfo
11 | from deva.inference.result_utils import ResultSaver
12 | from deva.inference.demo_utils import get_input_frame_for_deva
13 | from deva.ext.automatic_sam import auto_segment
14 | from deva.utils.tensor_utils import pad_divide_by, unpad
15 |
16 | from segment_anything import SamAutomaticMaskGenerator
17 |
18 |
19 | def make_segmentation(cfg: Dict, image_np: np.ndarray, forward_mask: Optional[torch.Tensor],
20 | sam_model: SamAutomaticMaskGenerator, min_side: int,
21 | suppress_small_mask: bool) -> (torch.Tensor, List[ObjectInfo]):
22 | mask, segments_info = auto_segment(cfg, sam_model, image_np, forward_mask, min_side,
23 | suppress_small_mask)
24 | return mask, segments_info
25 |
26 |
27 | @torch.inference_mode()
28 | def process_frame_automatic(deva: DEVAInferenceCore,
29 | sam_model: SamAutomaticMaskGenerator,
30 | frame_path: str,
31 | result_saver: ResultSaver,
32 | ti: int,
33 | image_np: np.ndarray = None) -> None:
34 | # image_np, if given, should be in RGB
35 | if image_np is None:
36 | image_np = cv2.imread(frame_path)
37 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
38 | cfg = deva.config
39 |
40 | h, w = image_np.shape[:2]
41 | new_min_side = cfg['size']
42 | suppress_small_mask = cfg['suppress_small_objects']
43 | need_resize = new_min_side > 0
44 | image = get_input_frame_for_deva(image_np, new_min_side)
45 |
46 | frame_name = path.basename(frame_path)
47 | frame_info = FrameInfo(image, None, None, ti, {
48 | 'frame': [frame_name],
49 | 'shape': [h, w],
50 | })
51 |
52 | if cfg['temporal_setting'] == 'semionline':
53 | if ti + cfg['num_voting_frames'] > deva.next_voting_frame:
54 | # getting a forward mask
55 | if deva.memory.engaged:
56 | forward_mask = estimate_forward_mask(deva, image)
57 | else:
58 | forward_mask = None
59 |
60 | mask, segments_info = make_segmentation(cfg, image_np, forward_mask, sam_model,
61 | new_min_side, suppress_small_mask)
62 | frame_info.mask = mask
63 | frame_info.segments_info = segments_info
64 | frame_info.image_np = image_np # for visualization only
65 | # wait for more frames before proceeding
66 | deva.add_to_temporary_buffer(frame_info)
67 |
68 | if ti == deva.next_voting_frame:
69 | # process this clip
70 | this_image = deva.frame_buffer[0].image
71 | this_frame_name = deva.frame_buffer[0].name
72 | this_image_np = deva.frame_buffer[0].image_np
73 |
74 | _, mask, new_segments_info = deva.vote_in_temporary_buffer(
75 | keyframe_selection='first')
76 | prob = deva.incorporate_detection(this_image,
77 | mask,
78 | new_segments_info,
79 | incremental=True)
80 | deva.next_voting_frame += cfg['detection_every']
81 |
82 | result_saver.save_mask(prob,
83 | this_frame_name,
84 | need_resize=need_resize,
85 | shape=(h, w),
86 | image_np=this_image_np)
87 |
88 | for frame_info in deva.frame_buffer[1:]:
89 | this_image = frame_info.image
90 | this_frame_name = frame_info.name
91 | this_image_np = frame_info.image_np
92 | prob = deva.step(this_image, None, None)
93 | result_saver.save_mask(prob,
94 | this_frame_name,
95 | need_resize,
96 | shape=(h, w),
97 | image_np=this_image_np)
98 |
99 | deva.clear_buffer()
100 | else:
101 | # standard propagation
102 | prob = deva.step(image, None, None)
103 | result_saver.save_mask(prob,
104 | frame_name,
105 | need_resize=need_resize,
106 | shape=(h, w),
107 | image_np=image_np)
108 |
109 | elif cfg['temporal_setting'] == 'online':
110 | if ti % cfg['detection_every'] == 0:
111 | # incorporate new detections
112 | if deva.memory.engaged:
113 | forward_mask = estimate_forward_mask(deva, image)
114 | else:
115 | forward_mask = None
116 |
117 | mask, segments_info = make_segmentation(cfg, image_np, forward_mask, sam_model,
118 | new_min_side, suppress_small_mask)
119 | frame_info.segments_info = segments_info
120 | prob = deva.incorporate_detection(image, mask, segments_info, incremental=True)
121 | else:
122 | # Run the model on this frame
123 | prob = deva.step(image, None, None)
124 | result_saver.save_mask(prob,
125 | frame_name,
126 | need_resize=need_resize,
127 | shape=(h, w),
128 | image_np=image_np)
129 |
130 |
131 | def estimate_forward_mask(deva: DEVAInferenceCore, image: torch.Tensor):
132 | image, pad = pad_divide_by(image, 16)
133 | image = image.unsqueeze(0) # add the batch dimension
134 |
135 | ms_features = deva.image_feature_store.get_ms_features(deva.curr_ti + 1, image)
136 | key, _, selection = deva.image_feature_store.get_key(deva.curr_ti + 1, image)
137 | prob = deva._segment(key, selection, ms_features)
138 | forward_mask = torch.argmax(prob, dim=0)
139 | forward_mask = unpad(forward_mask, pad)
140 | return forward_mask
141 |
--------------------------------------------------------------------------------
/deva/ext/automatic_sam.py:
--------------------------------------------------------------------------------
1 | # Reference: https://github.com/IDEA-Research/Grounded-Segment-Anything
2 |
3 | from typing import Dict, List, Optional
4 | import numpy as np
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from segment_anything import sam_model_registry
11 | from deva.ext.SAM.automatic_mask_generator import SamAutomaticMaskGenerator
12 | from deva.ext.MobileSAM.setup_mobile_sam import setup_model as setup_mobile_sam
13 | from deva.inference.object_info import ObjectInfo
14 |
15 |
16 | def get_sam_model(config: Dict, device: str) -> SamAutomaticMaskGenerator:
17 | variant = config['sam_variant'].lower()
18 | if variant == 'mobile':
19 | MOBILE_SAM_CHECKPOINT_PATH = config['MOBILE_SAM_CHECKPOINT_PATH']
20 |
21 | # Building Mobile SAM model
22 | checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH)
23 | mobile_sam = setup_mobile_sam()
24 | mobile_sam.load_state_dict(checkpoint, strict=True)
25 | mobile_sam.to(device=device)
26 | auto_sam = SamAutomaticMaskGenerator(mobile_sam,
27 | points_per_side=config['SAM_NUM_POINTS_PER_SIDE'],
28 | points_per_batch=config['SAM_NUM_POINTS_PER_BATCH'],
29 | pred_iou_thresh=config['SAM_PRED_IOU_THRESHOLD'])
30 | elif variant == 'original':
31 | SAM_ENCODER_VERSION = config['SAM_ENCODER_VERSION']
32 | SAM_CHECKPOINT_PATH = config['SAM_CHECKPOINT_PATH']
33 |
34 | # Building SAM Model and SAM Predictor
35 | sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to(
36 | device=device)
37 | auto_sam = SamAutomaticMaskGenerator(sam,
38 | points_per_side=config['SAM_NUM_POINTS_PER_SIDE'],
39 | points_per_batch=config['SAM_NUM_POINTS_PER_BATCH'],
40 | pred_iou_thresh=config['SAM_PRED_IOU_THRESHOLD'])
41 | else:
42 | raise ValueError(f'Unknown SAM variant: {config["SAM_VARIANT"]}')
43 |
44 | return auto_sam
45 |
46 |
47 | def auto_segment(config: Dict, auto_sam: SamAutomaticMaskGenerator, image: np.ndarray,
48 | forward_mask: Optional[torch.Tensor], min_side: int,
49 | suppress_small_mask: bool) -> (torch.Tensor, List[ObjectInfo]):
50 | """
51 | config: the global configuration dictionary
52 | image: the image to segment; should be a numpy array; H*W*3; unnormalized (0~255)
53 | forward_mask: the mask used to determine positive/negative points; H*W
54 |
55 | Returns: a torch index mask of the same size as image; H*W
56 | a list of segment info; see deva/inference/object_info.py for the ObjectInfo definition
57 | """
58 | device = auto_sam.predictor.device
59 |
60 | h, w = image.shape[:2]
61 | if min_side > 0:
62 | scale = min_side / min(h, w)
63 | new_h, new_w = int(h * scale), int(w * scale)
64 | else:
65 | new_h, new_w = h, w
66 |
67 | if forward_mask is not None:
68 | # sample a grid of candidate points; points not covered by the forward mask become positive prompts
69 | foreground_mask = (forward_mask > 0).float().unsqueeze(0).unsqueeze(0)
70 | foreground_mask = F.interpolate(foreground_mask,
71 | scale_factor=1 / 16,
72 | mode='bilinear',
73 | antialias=True) # blurring
74 | n_per_side = config['SAM_NUM_POINTS_PER_SIDE']
75 | offset = 1 / (2 * n_per_side)
76 | points_one_side = torch.linspace(offset, 1 - offset, n_per_side, device=device)
77 | points_x = points_one_side.unsqueeze(0).repeat(n_per_side, 1)
78 | points_y = points_one_side.unsqueeze(1).repeat(1, n_per_side)
79 | points = torch.stack([points_x, points_y], dim=-1).unsqueeze(0)
80 | points_label = F.grid_sample(foreground_mask, points * 2 - 1, align_corners=False).view(-1)
81 | points = points.view(-1, 2)
82 | positive_points = points[points_label < 0.01].cpu().numpy()
83 | if len(positive_points) == 0:
84 | output_mask = torch.zeros((new_h, new_w), dtype=torch.int64, device=device)
85 | segments_info = []
86 | return output_mask, segments_info
87 | # negative_points = points[points_label >= 0.5].cpu().numpy()
88 | negative_points = None # no negative points
89 | mask_data = auto_sam.generate(image, positive_points, negative_points)
90 | else:
91 | mask_data = auto_sam.generate(image)
92 |
93 | curr_id = 1
94 | segments_info = []
95 |
96 | pred_masks = mask_data['masks'].float() # num masks * H * W
97 | predicted_iou = mask_data["iou_preds"]
98 |
99 | # score masks by their areas
100 | if pred_masks.shape[0] == 0:
101 | output_mask = torch.zeros((new_h, new_w), dtype=torch.int64, device=device)
102 | else:
103 | pred_masks = F.interpolate(pred_masks.unsqueeze(0), (new_h, new_w), mode='bilinear')[0]
104 |
105 | curr_id = 1
106 | if suppress_small_mask:
107 | areas = pred_masks.flatten(-2).sum(-1)
108 | scores = areas.unsqueeze(-1).unsqueeze(-1)
109 |
110 | scored_masks = pred_masks * scores
111 | scored_masks_with_bg = torch.cat(
112 | [torch.zeros((1, *pred_masks.shape[1:]), device=device) + 0.1, scored_masks], dim=0)
113 | output_mask = torch.zeros((new_h, new_w), dtype=torch.int64, device=device)
114 |
115 | # let large masks eat small masks (small/tiny/incomplete masks are common in SAM output)
116 | hard_mask = torch.argmax(scored_masks_with_bg, dim=0)
117 | for k in range(scores.shape[0]):
118 | mask_area = (hard_mask == (k + 1)).sum()
119 | original_area = (pred_masks[k] > 0.5).sum()
120 | mask = (hard_mask == (k + 1)) & (pred_masks[k] >= 0.5)
121 |
122 | if mask_area > 0 and original_area > 0 and mask.sum() > 0:
123 | if mask_area / original_area < config['SAM_OVERLAP_THRESHOLD']:
124 | continue
125 | output_mask[mask] = curr_id
126 | segments_info.append(ObjectInfo(id=curr_id, score=predicted_iou[k].item()))
127 | curr_id += 1
128 | else:
129 | # prefer smaller objects
130 | areas = pred_masks.flatten(-2).sum(-1)
131 | scores = (areas.max() * 2 - areas).unsqueeze(-1).unsqueeze(-1)
132 | scored_masks = pred_masks * scores
133 |
134 | # add background channel
135 | scored_masks_with_bg = torch.cat(
136 | [torch.zeros((1, *scored_masks.shape[1:]), device=device) + 0.1, scored_masks],
137 | dim=0)
138 | output_mask = torch.argmax(scored_masks_with_bg, dim=0)
139 | for k in range(scored_masks.shape[0]):
140 | mask = (output_mask == (k + 1))
141 | if mask.sum() > 0:
142 | segments_info.append(ObjectInfo(id=curr_id, score=predicted_iou[k].item()))
143 | curr_id += 1
144 |
145 | return output_mask, segments_info
146 |
--------------------------------------------------------------------------------
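
A condensed sketch of how get_sam_model and auto_segment fit together outside the full demo loop. The config keys and values follow the defaults in deva/ext/ext_eval_args.py; the frame path and the min_side value of 480 are placeholders, and the SAM checkpoint must be downloaded first:

import cv2
import torch

from deva.ext.automatic_sam import get_sam_model, auto_segment

# keys follow deva/ext/ext_eval_args.py; the values are its documented defaults
config = {
    'sam_variant': 'original',
    'SAM_ENCODER_VERSION': 'vit_h',
    'SAM_CHECKPOINT_PATH': './saves/sam_vit_h_4b8939.pth',
    'SAM_NUM_POINTS_PER_SIDE': 64,
    'SAM_NUM_POINTS_PER_BATCH': 64,
    'SAM_PRED_IOU_THRESHOLD': 0.88,
    'SAM_OVERLAP_THRESHOLD': 0.8,
}
auto_sam = get_sam_model(config, 'cuda')

# placeholder frame; auto_segment expects an RGB uint8 array of shape H*W*3
image = cv2.cvtColor(cv2.imread('frame.jpg'), cv2.COLOR_BGR2RGB)
with torch.inference_mode():
    # forward_mask=None segments everything; min_side=480 resizes the returned index mask
    mask, segments_info = auto_segment(config, auto_sam, image, None, 480, False)
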
/deva/ext/ext_eval_args.py:
--------------------------------------------------------------------------------
1 | # Evaluation arguments for extensions
2 | from argparse import ArgumentParser
3 |
4 |
5 | def add_ext_eval_args(parser: ArgumentParser):
6 |
7 | # Grounded Segment Anything
8 | parser.add_argument('--GROUNDING_DINO_CONFIG_PATH',
9 | default='./saves/GroundingDINO_SwinT_OGC.py')
10 |
11 | parser.add_argument('--GROUNDING_DINO_CHECKPOINT_PATH',
12 | default='./saves/groundingdino_swint_ogc.pth')
13 |
14 | parser.add_argument('--DINO_THRESHOLD', default=0.35, type=float)
15 | parser.add_argument('--DINO_NMS_THRESHOLD', default=0.8, type=float)
16 |
17 | # Segment Anything (SAM) models
18 | parser.add_argument('--SAM_ENCODER_VERSION', default='vit_h')
19 | parser.add_argument('--SAM_CHECKPOINT_PATH', default='./saves/sam_vit_h_4b8939.pth')
20 |
21 | # HQ-SAM
22 | parser.add_argument('--HQ_SAM_CHECKPOINT_PATH', default='./saves/sam_hq_vit_h.pth')
23 |
24 | # Light HQ-SAM
25 | parser.add_argument('--LIGHT_HQ_SAM_CHECKPOINT_PATH', default='./saves/sam_hq_vit_tiny.pth')
26 |
27 | # Mobile SAM
28 | parser.add_argument('--MOBILE_SAM_CHECKPOINT_PATH', default='./saves/mobile_sam.pt')
29 |
30 | # Segment Anything (SAM) parameters
31 | parser.add_argument('--SAM_NUM_POINTS_PER_SIDE',
32 | type=int,
33 | help='Number of points per side for prompting SAM',
34 | default=64)
35 | parser.add_argument('--SAM_NUM_POINTS_PER_BATCH',
36 | type=int,
37 | help='Number of points computed per batch',
38 | default=64)
39 | parser.add_argument('--SAM_PRED_IOU_THRESHOLD',
40 | type=float,
41 | help='(Predicted) IoU threshold for SAM',
42 | default=0.88)
43 | parser.add_argument('--SAM_OVERLAP_THRESHOLD',
44 | type=float,
45 | help='Overlap threshold for overlapped mask suppression in SAM',
46 | default=0.8)
47 |
48 |
49 | def add_text_default_args(parser):
50 | parser.add_argument('--img_path', default='./example/vipseg')
51 | parser.add_argument('--detection_every', type=int, default=5)
52 | parser.add_argument('--num_voting_frames',
53 | default=3,
54 | type=int,
55 | help='Number of frames selected for voting. only valid in semionline')
56 |
57 | parser.add_argument('--temporal_setting', default='semionline', help='semionline/online')
58 | parser.add_argument('--max_missed_detection_count', type=int, default=10)
59 | parser.add_argument('--max_num_objects',
60 | default=-1,
61 | type=int,
62 | help='Max. num of objects to keep in memory. -1 for no limit')
63 | parser.add_argument('--prompt', type=str, help='Separate classes with a single full stop')
64 | parser.add_argument('--sam_variant', default='original', help='mobile/original')
65 | return parser
66 |
67 |
68 | def add_auto_default_args(parser):
69 | parser.add_argument('--img_path', default='./example/vipseg')
70 | parser.add_argument('--detection_every', type=int, default=5)
71 | parser.add_argument('--num_voting_frames',
72 | default=3,
73 | type=int,
74 | help='Number of frames selected for voting. only valid in semionline')
75 |
76 | parser.add_argument('--temporal_setting', default='semionline', help='semionline/online')
77 | parser.add_argument('--max_missed_detection_count', type=int, default=5)
78 | parser.add_argument('--max_num_objects',
79 | default=200,
80 | type=int,
81 | help='Max. num of objects to keep in memory. -1 for no limit')
82 |
83 | parser.add_argument('--sam_variant', default='original', help='mobile/original')
84 | parser.add_argument('--suppress_small_objects', action='store_true')
85 |
86 | return parser
--------------------------------------------------------------------------------
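
A small, self-contained sketch of how these argument helpers compose; it only parses defaults and does not load any model. The prompt value is a placeholder showing the full-stop-separated class format consumed by with_text_processor.py:

from argparse import ArgumentParser

from deva.ext.ext_eval_args import add_ext_eval_args, add_text_default_args

parser = ArgumentParser()
add_ext_eval_args(parser)
add_text_default_args(parser)
args = parser.parse_args(['--prompt', 'person.dog'])  # placeholder classes

print(args.SAM_ENCODER_VERSION)  # 'vit_h'
print(args.DINO_THRESHOLD)       # 0.35
print(args.prompt.split('.'))    # ['person', 'dog'], as split in with_text_processor.py
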
/deva/ext/grounding_dino.py:
--------------------------------------------------------------------------------
1 | # Reference: https://github.com/IDEA-Research/Grounded-Segment-Anything
2 |
3 | from typing import Dict, List
4 | import numpy as np
5 | import cv2
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision
10 |
11 | try:
12 | from groundingdino.util.inference import Model as GroundingDINOModel
13 | except ImportError:
14 | # not sure why this happens sometimes
15 | from GroundingDINO.groundingdino.util.inference import Model as GroundingDINOModel
16 | from segment_anything import sam_model_registry, SamPredictor
17 | try:
18 | from segment_anything import sam_hq_model_registry
19 | except ImportError:
20 | print("HQ-SAM not found, please install it from https://github.com/SysCV/sam-hq")
21 | from deva.ext.MobileSAM.setup_mobile_sam import setup_model as setup_mobile_sam
22 | try:
23 | from deva.ext.LightHQSAM.setup_light_hqsam import setup_model as setup_light_hqsam
24 | except ImportError:
25 | print("Light HQ-SAM not found, please install it from https://github.com/SysCV/sam-hq")
26 | import numpy as np
27 | import torch
28 |
29 | from deva.inference.object_info import ObjectInfo
30 |
31 |
32 | def get_grounding_dino_model(config: Dict, device: str) -> (GroundingDINOModel, SamPredictor):
33 | GROUNDING_DINO_CONFIG_PATH = config['GROUNDING_DINO_CONFIG_PATH']
34 | GROUNDING_DINO_CHECKPOINT_PATH = config['GROUNDING_DINO_CHECKPOINT_PATH']
35 |
36 | gd_model = GroundingDINOModel(model_config_path=GROUNDING_DINO_CONFIG_PATH,
37 | model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
38 | device=device)
39 |
40 | # Building SAM Model and SAM Predictor
41 | variant = config['sam_variant'].lower()
42 | if variant == 'mobile':
43 | MOBILE_SAM_CHECKPOINT_PATH = config['MOBILE_SAM_CHECKPOINT_PATH']
44 |
45 | # Building Mobile SAM model
46 | checkpoint = torch.load(MOBILE_SAM_CHECKPOINT_PATH)
47 | mobile_sam = setup_mobile_sam()
48 | mobile_sam.load_state_dict(checkpoint, strict=True)
49 | mobile_sam.to(device=device)
50 | sam = SamPredictor(mobile_sam)
51 | elif variant == 'original':
52 | SAM_ENCODER_VERSION = config['SAM_ENCODER_VERSION']
53 | SAM_CHECKPOINT_PATH = config['SAM_CHECKPOINT_PATH']
54 |
55 | sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH).to(
56 | device=device)
57 | sam = SamPredictor(sam)
58 | elif variant == 'sam_hq':
59 | # Building HQ-SAM model with better mask quality
60 | SAM_ENCODER_VERSION = config['SAM_ENCODER_VERSION']
61 | HQ_SAM_CHECKPOINT_PATH = config['HQ_SAM_CHECKPOINT_PATH']
62 | sam_hq = sam_hq_model_registry[SAM_ENCODER_VERSION](checkpoint=HQ_SAM_CHECKPOINT_PATH).to(
63 | device=device)
64 | sam = SamPredictor(sam_hq)
65 | elif variant == 'sam_hq_light':
66 | LIGHT_HQ_SAM_CHECKPOINT_PATH = config['LIGHT_HQ_SAM_CHECKPOINT_PATH']
67 |
68 | # Building Light HQ-SAM model with good mask quality and efficiency
69 | checkpoint = torch.load(LIGHT_HQ_SAM_CHECKPOINT_PATH)
70 | light_hq_sam = setup_light_hqsam()
71 | light_hq_sam.load_state_dict(checkpoint, strict=True)
72 | light_hq_sam.to(device=device)
73 | sam = SamPredictor(light_hq_sam)
74 |
75 | return gd_model, sam
76 |
77 |
78 | def segment_with_text(config: Dict, gd_model: GroundingDINOModel, sam: SamPredictor,
79 | image: np.ndarray, prompts: List[str],
80 | min_side: int) -> (torch.Tensor, List[ObjectInfo]):
81 | """
82 | config: the global configuration dictionary
83 | image: the image to segment; should be a numpy array; H*W*3; unnormalized (0~255)
84 | prompts: list of class names
85 |
86 | Returns: a torch index mask of the same size as image; H*W
87 | a list of segment info; see deva/inference/object_info.py for the ObjectInfo definition
88 | """
89 |
90 | BOX_THRESHOLD = TEXT_THRESHOLD = config['DINO_THRESHOLD']
91 | NMS_THRESHOLD = config['DINO_NMS_THRESHOLD']
92 |
93 | sam.set_image(image, image_format='RGB')
94 |
95 | # detect objects
96 | # GroundingDINO uses BGR
97 | detections = gd_model.predict_with_classes(image=cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
98 | classes=prompts,
99 | box_threshold=BOX_THRESHOLD,
100 | text_threshold=TEXT_THRESHOLD)
101 | nms_idx = torchvision.ops.nms(torch.from_numpy(detections.xyxy),
102 | torch.from_numpy(detections.confidence),
103 | NMS_THRESHOLD).numpy().tolist()
104 |
105 | detections.xyxy = detections.xyxy[nms_idx]
106 | detections.confidence = detections.confidence[nms_idx]
107 | detections.class_id = detections.class_id[nms_idx]
108 |
109 | result_masks = []
110 | for box in detections.xyxy:
111 | masks, scores, _ = sam.predict(box=box, multimask_output=True)
112 | index = np.argmax(scores)
113 | result_masks.append(masks[index])
114 |
115 | detections.mask = np.array(result_masks)
116 |
117 | h, w = image.shape[:2]
118 | if min_side > 0:
119 | scale = min_side / min(h, w)
120 | new_h, new_w = int(h * scale), int(w * scale)
121 | else:
122 | new_h, new_w = h, w
123 |
124 | output_mask = torch.zeros((new_h, new_w), dtype=torch.int64, device=gd_model.device)
125 | curr_id = 1
126 | segments_info = []
127 |
128 | # sort by descending area to preserve the smallest object
129 | for i in np.flip(np.argsort(detections.area)):
130 | mask = detections.mask[i]
131 | confidence = detections.confidence[i]
132 | class_id = detections.class_id[i]
133 | mask = torch.from_numpy(mask.astype(np.float32))
134 | mask = F.interpolate(mask.unsqueeze(0).unsqueeze(0), (new_h, new_w), mode='bilinear')[0, 0]
135 | mask = (mask > 0.5).float()
136 |
137 | if mask.sum() > 0:
138 | output_mask[mask > 0] = curr_id
139 | segments_info.append(ObjectInfo(id=curr_id, category_id=class_id, score=confidence))
140 | curr_id += 1
141 |
142 | return output_mask, segments_info
143 |
--------------------------------------------------------------------------------
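
A condensed sketch of the text-prompted path in isolation. The config keys and values follow the defaults in deva/ext/ext_eval_args.py; the frame path, prompt list, and the min_side value of 480 are placeholders, and the GroundingDINO/SAM checkpoints must already be downloaded:

import cv2
import torch

from deva.ext.grounding_dino import get_grounding_dino_model, segment_with_text

# keys follow deva/ext/ext_eval_args.py; the values are its documented defaults
config = {
    'GROUNDING_DINO_CONFIG_PATH': './saves/GroundingDINO_SwinT_OGC.py',
    'GROUNDING_DINO_CHECKPOINT_PATH': './saves/groundingdino_swint_ogc.pth',
    'DINO_THRESHOLD': 0.35,
    'DINO_NMS_THRESHOLD': 0.8,
    'sam_variant': 'original',
    'SAM_ENCODER_VERSION': 'vit_h',
    'SAM_CHECKPOINT_PATH': './saves/sam_vit_h_4b8939.pth',
}
gd_model, sam = get_grounding_dino_model(config, 'cuda')

image = cv2.cvtColor(cv2.imread('frame.jpg'), cv2.COLOR_BGR2RGB)  # placeholder frame, RGB
prompts = ['person', 'dog']  # placeholder class names
with torch.inference_mode():
    mask, segments_info = segment_with_text(config, gd_model, sam, image, prompts, 480)
# mask: H*W integer id mask; segments_info: list of ObjectInfo with class ids and scores
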
/deva/ext/with_text_processor.py:
--------------------------------------------------------------------------------
1 | from os import path
2 | from typing import Dict, List
3 |
4 | import cv2
5 | import torch
6 | import numpy as np
7 |
8 | from deva.inference.object_info import ObjectInfo
9 | from deva.inference.inference_core import DEVAInferenceCore
10 | from deva.inference.frame_utils import FrameInfo
11 | from deva.inference.result_utils import ResultSaver
12 | from deva.inference.demo_utils import get_input_frame_for_deva
13 | from deva.ext.grounding_dino import segment_with_text
14 | try:
15 | from groundingdino.util.inference import Model as GroundingDINOModel
16 | except ImportError:
17 | # not sure why this happens sometimes
18 | from GroundingDINO.groundingdino.util.inference import Model as GroundingDINOModel
19 | from segment_anything import SamPredictor
20 |
21 |
22 | def make_segmentation_with_text(cfg: Dict, image_np: np.ndarray, gd_model: GroundingDINOModel,
23 | sam_model: SamPredictor, prompts: List[str],
24 | min_side: int) -> (torch.Tensor, List[ObjectInfo]):
25 | mask, segments_info = segment_with_text(cfg, gd_model, sam_model, image_np, prompts, min_side)
26 | return mask, segments_info
27 |
28 |
29 | @torch.inference_mode()
30 | def process_frame_with_text(deva: DEVAInferenceCore,
31 | gd_model: GroundingDINOModel,
32 | sam_model: SamPredictor,
33 | frame_path: str,
34 | result_saver: ResultSaver,
35 | ti: int,
36 | image_np: np.ndarray = None) -> None:
37 | # image_np, if given, should be in RGB
38 | if image_np is None:
39 | image_np = cv2.imread(frame_path)
40 | image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)
41 | cfg = deva.config
42 | raw_prompt = cfg['prompt']
43 | prompts = raw_prompt.split('.')
44 |
45 | h, w = image_np.shape[:2]
46 | new_min_side = cfg['size']
47 | need_resize = new_min_side > 0
48 | image = get_input_frame_for_deva(image_np, new_min_side)
49 |
50 | frame_name = path.basename(frame_path)
51 | frame_info = FrameInfo(image, None, None, ti, {
52 | 'frame': [frame_name],
53 | 'shape': [h, w],
54 | })
55 |
56 | if cfg['temporal_setting'] == 'semionline':
57 | if ti + cfg['num_voting_frames'] > deva.next_voting_frame:
58 | mask, segments_info = make_segmentation_with_text(cfg, image_np, gd_model, sam_model,
59 | prompts, new_min_side)
60 | frame_info.mask = mask
61 | frame_info.segments_info = segments_info
62 | frame_info.image_np = image_np # for visualization only
63 | # wait for more frames before proceeding
64 | deva.add_to_temporary_buffer(frame_info)
65 |
66 | if ti == deva.next_voting_frame:
67 | # process this clip
68 | this_image = deva.frame_buffer[0].image
69 | this_frame_name = deva.frame_buffer[0].name
70 | this_image_np = deva.frame_buffer[0].image_np
71 |
72 | _, mask, new_segments_info = deva.vote_in_temporary_buffer(
73 | keyframe_selection='first')
74 | prob = deva.incorporate_detection(this_image, mask, new_segments_info)
75 | deva.next_voting_frame += cfg['detection_every']
76 |
77 | result_saver.save_mask(prob,
78 | this_frame_name,
79 | need_resize=need_resize,
80 | shape=(h, w),
81 | image_np=this_image_np,
82 | prompts=prompts)
83 |
84 | for frame_info in deva.frame_buffer[1:]:
85 | this_image = frame_info.image
86 | this_frame_name = frame_info.name
87 | this_image_np = frame_info.image_np
88 | prob = deva.step(this_image, None, None)
89 | result_saver.save_mask(prob,
90 | this_frame_name,
91 | need_resize,
92 | shape=(h, w),
93 | image_np=this_image_np,
94 | prompts=prompts)
95 |
96 | deva.clear_buffer()
97 | else:
98 | # standard propagation
99 | prob = deva.step(image, None, None)
100 | result_saver.save_mask(prob,
101 | frame_name,
102 | need_resize=need_resize,
103 | shape=(h, w),
104 | image_np=image_np,
105 | prompts=prompts)
106 |
107 | elif cfg['temporal_setting'] == 'online':
108 | if ti % cfg['detection_every'] == 0:
109 | # incorporate new detections
110 | mask, segments_info = make_segmentation_with_text(cfg, image_np, gd_model, sam_model,
111 | prompts, new_min_side)
112 | frame_info.segments_info = segments_info
113 | prob = deva.incorporate_detection(image, mask, segments_info)
114 | else:
115 | # Run the model on this frame
116 | prob = deva.step(image, None, None)
117 | result_saver.save_mask(prob,
118 | frame_name,
119 | need_resize=need_resize,
120 | shape=(h, w),
121 | image_np=image_np,
122 | prompts=prompts)
123 |
--------------------------------------------------------------------------------
/deva/inference/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/inference/__init__.py
--------------------------------------------------------------------------------
/deva/inference/consensus_associated.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the implementation of the consensus when the association is already established.
3 | E.g., when we know which mask in frame 1 corresponds to which mask in frame 2.
4 | There is no need to use integer programming for matching.
5 | """
6 |
7 | from typing import List, Dict
8 | import torch
9 |
10 | from deva.model.memory_utils import *
11 | from deva.model.network import DEVA
12 | from deva.inference.image_feature_store import ImageFeatureStore
13 | from deva.utils.tensor_utils import pad_divide_by, unpad
14 |
15 |
16 | def spatial_alignment(src_ti: int, src_image: torch.Tensor, src_mask: torch.Tensor, tar_ti: int,
17 | tar_image: torch.Tensor, network: DEVA, store: ImageFeatureStore,
18 | config: Dict) -> torch.Tensor:
19 | """
20 | src_image/tar_image: 3*H*W
21 | src_mask: num_objects*H*W
22 |
23 | returns: a segmentation mask of the target image: num_objects*H*W
24 | """
25 | num_objects, h, w = src_mask.shape
26 | src_image = src_image.unsqueeze(0)
27 | tar_image = tar_image.unsqueeze(0)
28 | src_mask = src_mask.unsqueeze(0)
29 |
30 | # get source features
31 | src_ms_features = store.get_ms_features(src_ti, src_image)
32 | src_key, src_shrinkage, _ = store.get_key(src_ti, src_image)
33 | # get target features
34 | tar_ms_features = store.get_ms_features(tar_ti, tar_image)
35 | tar_key, _, tar_selection = store.get_key(tar_ti, tar_image)
36 |
37 | # encode memory from the source frame
38 | sensory = torch.zeros((1, num_objects, config['value_dim'], h // 16, w // 16),
39 | device=src_key.device)
40 | value, sensory = network.encode_mask(src_image,
41 | src_ms_features,
42 | sensory,
43 | src_mask,
44 | is_deep_update=True,
45 | chunk_size=config['chunk_size'])
46 |
47 | # key matching
48 | src_key = src_key.flatten(start_dim=2)
49 | src_shrinkage = src_shrinkage.flatten(start_dim=2)
50 | tar_key = tar_key.flatten(start_dim=2)
51 | tar_selection = tar_selection.flatten(start_dim=2)
52 | # 1*num_objects*C*H*W -> 1*(num_objects*C)*(H*W)
53 | value = value.flatten(start_dim=1, end_dim=2).flatten(start_dim=2)
54 |
55 | similarity = get_similarity(src_key, src_shrinkage, tar_key, tar_selection)
56 | affinity = do_softmax(similarity, top_k=config['top_k'])
57 | # readout
58 | memory_readout = value @ affinity
59 | memory_readout = memory_readout.view(1, num_objects, config['value_dim'], h // 16, w // 16)
60 |
61 | # segmentation
62 | _, _, tar_mask = network.segment(tar_ms_features,
63 | memory_readout,
64 | sensory,
65 | src_mask,
66 | chunk_size=config['chunk_size'],
67 | update_sensory=False)
68 |
69 | return tar_mask
70 |
71 |
72 | def _keyframe_objective_from_mask(mask, score, method='high_foreground') -> float:
73 | # compute a good-to-be-keyframe score of a mask
74 | if method == 'high_foreground':
75 | return (mask > 0.8).float().mean()
76 | elif method == 'score':
77 | return score
78 | else:
79 | raise NotImplementedError
80 |
81 |
82 | def find_consensus_with_established_association(time_indices: List[int],
83 | images: List[torch.Tensor],
84 | masks: List[torch.Tensor],
85 | network: DEVA,
86 | store: ImageFeatureStore,
87 | config: Dict,
88 | scores: List[float] = None) -> (int, torch.Tensor):
89 |
90 | # apply padding to all images and masks
91 | for i, (image, mask) in enumerate(zip(images, masks)):
92 | images[i], pads = pad_divide_by(image, 16)
93 | masks[i], _ = pad_divide_by(mask, 16)
94 |
95 | # if scores is None, assume uniform (for averaging later on)
96 | if scores is None:
97 | scores = [1 for _ in time_indices]
98 | use_score = False
99 | else:
100 | use_score = True
101 | scores = torch.softmax(torch.Tensor(scores) * 2, dim=0).tolist()
102 |
103 | # first, find a keyframe
104 | keyframe_objective = float('-inf')
105 | keyframe_ti = None
106 | keyframe_image = None
107 | keyframe_mask = None
108 | keyframe_score = None
109 |
110 | if use_score:
111 | # ranking with score
112 | for ti, image, mask, score in zip(time_indices, images, masks, scores):
113 | objective = _keyframe_objective_from_mask(mask, score, method='score')
114 | if objective > keyframe_objective:
115 | keyframe_objective = objective
116 | keyframe_ti = ti
117 | keyframe_image = image
118 | keyframe_mask = mask
119 | keyframe_score = score
120 | else:
121 | # score-less ranking
122 | score = None
123 | for ti, image, mask in zip(time_indices, images, masks):
124 | objective = _keyframe_objective_from_mask(mask, score, method='high_foreground')
125 | if objective > keyframe_objective:
126 | keyframe_objective = objective
127 | keyframe_ti = ti
128 | keyframe_image = image
129 | keyframe_mask = mask
130 | keyframe_score = score
131 |
132 | if keyframe_score is None:
133 | keyframe_score = scores[0]
134 |
135 |     # then, project all frames onto the keyframe
136 |     # the keyframe itself is not re-projected; its mask is added directly below, weighted by its score
137 | total_projected_mask = keyframe_mask * keyframe_score
138 | for ti, image, mask, score in zip(time_indices, images, masks, scores):
139 | # the keyframe is already added
140 | if ti == keyframe_ti:
141 | continue
142 | projected_mask = spatial_alignment(ti, image, mask, keyframe_ti, keyframe_image, network,
143 | store, config)
144 | total_projected_mask += projected_mask[0, 1:] * score
145 |
146 | total_projected_mask = unpad(total_projected_mask, pads)
147 | return keyframe_ti, total_projected_mask
148 |
--------------------------------------------------------------------------------
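Note on the consensus routine above (illustrative, not part of the repository): when per-frame scores are available, each projected mask is weighted by a softmax over the scores scaled by a temperature of 2 before summing. Below is a minimal sketch of just that weighting step with made-up scores and dummy, already-aligned soft masks; the real code obtains the aligned masks from spatial_alignment, which needs a trained DEVA network.

    import torch

    # made-up per-frame confidence scores and dummy pre-aligned soft masks (1 x H x W)
    scores = [0.8, 0.5, 0.9]
    masks = [torch.rand(1, 4, 4) for _ in scores]

    # same weighting as above: softmax over temperature-scaled scores
    weights = torch.softmax(torch.Tensor(scores) * 2, dim=0).tolist()

    # weighted sum of the aligned masks forms the consensus
    consensus = sum(w * m for w, m in zip(weights, masks))
    print(weights, consensus.shape)  # weights sum to 1; shape stays (1, 4, 4)
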
/deva/inference/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/inference/data/__init__.py
--------------------------------------------------------------------------------
/deva/inference/data/detection_video_reader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 |
4 | from torch.utils.data.dataset import Dataset
5 | from torchvision import transforms
6 | from torchvision.transforms import InterpolationMode
7 | import torch.nn.functional as F
8 | from PIL import Image
9 | import numpy as np
10 |
11 | from deva.dataset.utils import im_normalization
12 |
13 |
14 | class DetectionVideoReader(Dataset):
15 | """
16 | This class is used to read a video, one frame at a time
17 | """
18 | def __init__(self,
19 | vid_name,
20 | image_dir,
21 | mask_dir,
22 | size=-1,
23 | to_save=None,
24 | size_dir=None,
25 | start=-1,
26 | end=-1,
27 | reverse=False):
28 | """
29 | image_dir - points to a directory of jpg images
30 | mask_dir - points to a directory of png masks
31 | size - resize min. side to size. Does nothing if <0.
32 | to_save - optionally contains a list of file names without extensions
33 | where the segmentation mask is required
34 | """
35 |         # TODO: determine is_rgb automatically
36 | self.vid_name = vid_name
37 | self.image_dir = image_dir
38 | self.mask_dir = mask_dir
39 | self.to_save = to_save
40 | if size_dir is None:
41 | self.size_dir = self.image_dir
42 | else:
43 | self.size_dir = size_dir
44 |
45 | self.frames = sorted(os.listdir(self.image_dir))
46 | if start > 0:
47 | self.frames = self.frames[start:]
48 | if end > 0:
49 | self.frames = self.frames[:end]
50 | if reverse:
51 |             self.frames = list(reversed(self.frames))
52 |
53 | self.palette = Image.open(path.join(mask_dir, self.frames[0].replace('.jpg',
54 | '.png'))).getpalette()
55 | self.first_gt_path = path.join(self.mask_dir, self.frames[0].replace('.jpg', '.png'))
56 |
57 | if size < 0:
58 | self.im_transform = transforms.Compose([
59 | transforms.ToTensor(),
60 | im_normalization,
61 | ])
62 | self.mask_transform = transforms.Compose([])
63 | else:
64 | self.im_transform = transforms.Compose([
65 | transforms.ToTensor(),
66 | im_normalization,
67 | transforms.Resize(size, interpolation=InterpolationMode.BILINEAR, antialias=True),
68 | ])
69 | self.mask_transform = transforms.Compose([
70 | transforms.Resize(size, interpolation=InterpolationMode.NEAREST),
71 | ])
72 | self.size = size
73 | self.is_rgb = None
74 |
75 | def __getitem__(self, idx):
76 | frame = self.frames[idx]
77 | info = {}
78 | data = {}
79 | info['frame'] = frame
80 | info['save'] = (self.to_save is None) or (frame[:-4] in self.to_save)
81 |
82 | im_path = path.join(self.image_dir, frame)
83 | img = Image.open(im_path).convert('RGB')
84 |
85 | if self.image_dir == self.size_dir:
86 | shape = np.array(img).shape[:2]
87 | else:
88 | size_path = path.join(self.size_dir, frame)
89 | size_im = Image.open(size_path).convert('RGB')
90 | shape = np.array(size_im).shape[:2]
91 |
92 | mask_path = path.join(self.mask_dir, frame[:-4] + '.png')
93 | img = self.im_transform(img)
94 |
95 | if path.exists(mask_path):
96 | mask = Image.open(mask_path)
97 | mask = self.mask_transform(mask)
98 | if mask.mode == 'RGB':
99 | mask = np.array(mask, dtype=np.int32)
100 | mask = mask[:, :, 0] + mask[:, :, 1] * 256 + mask[:, :, 2] * 256 * 256
101 | self.is_rgb = True
102 | else:
103 | mask = mask.convert('P')
104 | mask = np.array(mask, dtype=np.int32)
105 | self.is_rgb = False
106 | data['mask'] = mask
107 |
108 | # defer json loading to the model
109 | json_path = path.join(self.mask_dir, frame[:-4] + '.json')
110 | if path.exists(json_path):
111 | info['json'] = json_path
112 |
113 | info['is_rgb'] = self.is_rgb
114 | info['shape'] = shape
115 | info['need_resize'] = not (self.size < 0)
116 | info['path_to_image'] = im_path
117 | data['rgb'] = img
118 | data['info'] = info
119 |
120 | return data
121 |
122 | def get_palette(self):
123 | return self.palette
124 |
125 | def __len__(self):
126 | return len(self.frames)
--------------------------------------------------------------------------------
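Usage sketch for deva/inference/data/detection_video_reader.py (illustrative, not part of the repository): the reader is meant to be wrapped in a standard PyTorch DataLoader. The directory names below are hypothetical; every .jpg frame is assumed to have a matching .png mask, as in the VIPSeg-style layout this reader targets.

    from torch.utils.data import DataLoader
    from deva.inference.data.detection_video_reader import DetectionVideoReader

    # hypothetical paths; image_dir and mask_dir must contain matching frame names
    reader = DetectionVideoReader('my_video',
                                  image_dir='example/my_video/images',
                                  mask_dir='example/my_video/masks',
                                  size=480)
    loader = DataLoader(reader, batch_size=1, shuffle=False)
    for data in loader:
        print(data['info']['frame'], data['rgb'].shape)  # shorter side resized to 480
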
/deva/inference/data/referring_test_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | import json
4 | from collections import defaultdict
5 | import numpy as np
6 |
7 | from deva.inference.data.video_reader import VideoReader
8 |
9 |
10 | class ReferringDAVISTestDataset:
11 | def __init__(self, image_dir, mask_dir, size=-1):
12 | self.image_dir = image_dir
13 | self.mask_dir = mask_dir
14 | self.size = size
15 |
16 | self.vid_list = sorted(os.listdir(self.mask_dir))
17 |
18 | def get_videos(self):
19 | return self.vid_list
20 |
21 | def get_offline_sampled_frames(self, video, num_sampled_frames):
22 | return VideoReader(
23 | video,
24 | path.join(self.image_dir, video),
25 | path.join(self.mask_dir, video),
26 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))],
27 | size=self.size,
28 | soft_mask=True,
29 | num_sampled_frames=num_sampled_frames,
30 | use_all_masks=True,
31 | )
32 |
33 | def get_partial_video_loader(self, video, *, start, end, reverse):
34 | return VideoReader(
35 | video,
36 | path.join(self.image_dir, video),
37 | path.join(self.mask_dir, video),
38 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))],
39 | size=self.size,
40 | soft_mask=True,
41 | start=start,
42 | end=end,
43 | reverse=reverse,
44 | )
45 |
46 | def get_scores(self, video):
47 | with open(path.join(self.mask_dir, video, 'scores.csv')) as f:
48 | lines = f.read().splitlines()
49 | scores = defaultdict(dict)
50 | for l in lines:
51 | frame, obj, score = l.split(',')
52 | scores[frame[:-4]][obj] = float(score)
53 |
54 | average_scores = {}
55 | for frame, all_objects in scores.items():
56 | average_scores[frame] = np.array(list(all_objects.values())).mean()
57 |
58 | return average_scores
59 |
60 | def __len__(self):
61 | return len(self.vid_list)
62 |
63 |
64 | class ReferringYouTubeVOSTestDataset:
65 | def __init__(self, image_dir, mask_dir, json_dir, size=-1):
66 | self.image_dir = image_dir
67 | self.mask_dir = mask_dir
68 | self.size = size
69 |
70 | self.vid_list = sorted(os.listdir(self.mask_dir))
71 | self.req_frame_list = {}
72 |
73 | with open(json_dir) as f:
74 | # read meta.json to know which frame is required for evaluation
75 | meta = json.load(f)['videos']
76 |
77 | for vid in self.vid_list:
78 | req_frames = []
79 | req_frames.extend(meta[vid]['frames'])
80 |
81 | req_frames = list(set(req_frames))
82 | self.req_frame_list[vid] = req_frames
83 |
84 | def get_videos(self):
85 | return self.vid_list
86 |
87 | def get_objects(self, video):
88 | return [
89 | obj for obj in sorted(os.listdir(path.join(self.mask_dir, video))) if '.csv' not in obj
90 | ]
91 |
92 | def _get_to_save_list(self, video, object_name):
93 | return self.req_frame_list[video]
94 |
95 | def get_offline_sampled_frames(self, video, object_name, num_sampled_frames):
96 | return VideoReader(
97 | video,
98 | path.join(self.image_dir, video),
99 | path.join(self.mask_dir, video),
100 | size=self.size,
101 | soft_mask=True,
102 | num_sampled_frames=num_sampled_frames,
103 | use_all_masks=True,
104 | to_save=self._get_to_save_list(video, object_name),
105 | object_name=object_name,
106 | enabled_frame_list=self._get_enabled_frame_list(video, object_name),
107 | )
108 |
109 | def get_partial_video_loader(self, video, object_name, *, start, end, reverse):
110 | return VideoReader(
111 | video,
112 | path.join(self.image_dir, video),
113 | path.join(self.mask_dir, video),
114 | size=self.size,
115 | soft_mask=True,
116 | start=start,
117 | end=end,
118 | reverse=reverse,
119 | to_save=self._get_to_save_list(video, object_name),
120 | object_name=object_name,
121 | enabled_frame_list=self._get_enabled_frame_list(video, object_name),
122 | )
123 |
124 | def get_scores(self, video):
125 | with open(path.join(self.mask_dir, video, 'scores.csv')) as f:
126 | lines = f.read().splitlines()
127 | scores = defaultdict(dict)
128 | enabled_frame_list = self._get_enabled_frame_list(video, None)
129 | for l in lines:
130 | frame, obj, score = l.split(',')
131 | if enabled_frame_list is not None and frame[:-4] not in enabled_frame_list:
132 | continue
133 | scores[obj][frame[:-4]] = float(score)
134 | return scores
135 |
136 | def _get_enabled_frame_list(self, video, object_name):
137 | # None -> enable all
138 | return None
139 |
140 | def __len__(self):
141 | return len(self.vid_list)
142 |
--------------------------------------------------------------------------------
/deva/inference/data/saliency_test_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 |
4 | from deva.inference.data.video_reader import VideoReader
5 |
6 |
7 | class DAVISSaliencyTestDataset:
8 | def __init__(self, image_dir, mask_dir, imset=None, size=-1):
9 | self.image_dir = image_dir
10 | self.mask_dir = mask_dir
11 | self.size = size
12 |
13 | if imset is None:
14 | self.vid_list = sorted(os.listdir(self.mask_dir))
15 | else:
16 | with open(imset) as f:
17 | self.vid_list = sorted([line.strip() for line in f])
18 |
19 | def get_datasets(self):
20 | for video in self.vid_list:
21 | yield VideoReader(
22 | video,
23 | path.join(self.image_dir, video),
24 | path.join(self.mask_dir, video),
25 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))],
26 | size=self.size,
27 | soft_mask=True,
28 | use_all_masks=True,
29 | multi_object=False,
30 | )
31 |
32 | def get_videos(self):
33 | return self.vid_list
34 |
35 | def get_offline_sampled_frames(self, video, num_sampled_frames):
36 | return VideoReader(
37 | video,
38 | path.join(self.image_dir, video),
39 | path.join(self.mask_dir, video),
40 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))],
41 | size=self.size,
42 | soft_mask=True,
43 | num_sampled_frames=num_sampled_frames,
44 | use_all_masks=True,
45 | multi_object=False,
46 | )
47 |
48 | def get_partial_video_loader(self, video, *, start, end, reverse):
49 | return VideoReader(
50 | video,
51 | path.join(self.image_dir, video),
52 | path.join(self.mask_dir, video),
53 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))],
54 | size=self.size,
55 | soft_mask=True,
56 | start=start,
57 | end=end,
58 | reverse=reverse,
59 | multi_object=False,
60 | )
61 |
62 | def __len__(self):
63 | return len(self.vid_list)
--------------------------------------------------------------------------------
/deva/inference/data/simple_video_reader.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | from torch.utils.data.dataset import Dataset
4 | from PIL import Image
5 | import numpy as np
6 |
7 |
8 | class SimpleVideoReader(Dataset):
9 | """
10 | This class is used to read a video, one frame at a time
11 | This simple version:
12 | 1. Does not load the mask/json
13 | 2. Does not normalize the input
14 | 3. Does not resize
15 | """
16 | def __init__(
17 | self,
18 | image_dir,
19 | ):
20 | """
21 | image_dir - points to a directory of jpg images
22 | """
23 | self.image_dir = image_dir
24 | self.frames = sorted(os.listdir(self.image_dir))
25 |
26 | def __getitem__(self, idx):
27 | frame = self.frames[idx]
28 |
29 | im_path = path.join(self.image_dir, frame)
30 | img = Image.open(im_path).convert('RGB')
31 | img = np.array(img)
32 |
33 | return img, im_path
34 |
35 | def __len__(self):
36 | return len(self.frames)
37 |
38 |
39 | def no_collate(x):
40 | return x
--------------------------------------------------------------------------------
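Usage sketch for deva/inference/data/simple_video_reader.py (illustrative, not part of the repository): with batch_size=None and no_collate, the DataLoader yields each (numpy image, path) pair untouched instead of stacking samples into batches. The frame directory is hypothetical.

    from torch.utils.data import DataLoader
    from deva.inference.data.simple_video_reader import SimpleVideoReader, no_collate

    loader = DataLoader(SimpleVideoReader('example/my_video/images'),  # hypothetical path
                        batch_size=None,
                        collate_fn=no_collate)
    for img, im_path in loader:
        print(im_path, img.shape)  # raw H x W x 3 uint8 array; no normalization or resizing
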
/deva/inference/data/vos_test_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | import json
4 |
5 | from deva.inference.data.video_reader import VideoReader
6 |
7 |
8 | class GeneralVOSTestDataset:
9 | def __init__(self, data_root, size=-1, use_all_masks=False):
10 | self.image_dir = path.join(data_root, 'JPEGImages')
11 | self.mask_dir = path.join(data_root, 'Annotations')
12 | self.size = size
13 | self.use_all_masks = use_all_masks
14 |
15 | self.vid_list = sorted(os.listdir(self.mask_dir))
16 |
17 | def get_datasets(self):
18 | for video in self.vid_list:
19 | yield VideoReader(
20 | video,
21 | path.join(self.image_dir, video),
22 | path.join(self.mask_dir, video),
23 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))],
24 | size=self.size,
25 | use_all_masks=self.use_all_masks,
26 | )
27 |
28 | def __len__(self):
29 | return len(self.vid_list)
30 |
31 |
32 | class DAVISTestDataset:
33 | def __init__(self, data_root, imset='2017/val.txt', size=-1):
34 | if size != 480:
35 | self.image_dir = path.join(data_root, 'JPEGImages', 'Full-Resolution')
36 | self.mask_dir = path.join(data_root, 'Annotations', 'Full-Resolution')
37 | if not path.exists(self.image_dir):
38 | print(f'{self.image_dir} not found. Looking at .../1080p instead')
39 | self.image_dir = path.join(data_root, 'JPEGImages', '1080p')
40 | self.mask_dir = path.join(data_root, 'Annotations', '1080p')
41 | assert path.exists(self.image_dir), 'path not found'
42 | else:
43 | self.image_dir = path.join(data_root, 'JPEGImages', '480p')
44 | self.mask_dir = path.join(data_root, 'Annotations', '480p')
45 | self.size_dir = path.join(data_root, 'JPEGImages', '480p')
46 | self.size = size
47 |
48 | with open(path.join(data_root, 'ImageSets', imset)) as f:
49 | self.vid_list = sorted([line.strip() for line in f])
50 |
51 | def get_datasets(self):
52 | for video in self.vid_list:
53 | yield VideoReader(
54 | video,
55 | path.join(self.image_dir, video),
56 | path.join(self.mask_dir, video),
57 | size=self.size,
58 | size_dir=path.join(self.size_dir, video),
59 | )
60 |
61 | def __len__(self):
62 | return len(self.vid_list)
63 |
64 |
65 | class YouTubeVOSTestDataset:
66 | def __init__(self, data_root, split, size=480):
67 | self.image_dir = path.join(data_root, 'all_frames', split + '_all_frames', 'JPEGImages')
68 | self.mask_dir = path.join(data_root, split, 'Annotations')
69 | self.size = size
70 |
71 | self.vid_list = sorted(os.listdir(self.image_dir))
72 | self.req_frame_list = {}
73 |
74 | with open(path.join(data_root, split, 'meta.json')) as f:
75 | # read meta.json to know which frame is required for evaluation
76 | meta = json.load(f)['videos']
77 |
78 | for vid in self.vid_list:
79 | req_frames = []
80 | objects = meta[vid]['objects']
81 | for value in objects.values():
82 | req_frames.extend(value['frames'])
83 |
84 | req_frames = list(set(req_frames))
85 | self.req_frame_list[vid] = req_frames
86 |
87 | def get_datasets(self):
88 | for video in self.vid_list:
89 | yield VideoReader(video,
90 | path.join(self.image_dir, video),
91 | path.join(self.mask_dir, video),
92 | size=self.size,
93 | to_save=self.req_frame_list[video],
94 | use_all_masks=True)
95 |
96 | def __len__(self):
97 | return len(self.vid_list)
98 |
--------------------------------------------------------------------------------
/deva/inference/data/vps_test_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | import json
4 |
5 | from deva.inference.data.detection_video_reader import DetectionVideoReader
6 |
7 |
8 | class VIPSegDetectionTestDataset:
9 | def __init__(self, image_dir, mask_dir, size=-1):
10 | self.image_dir = image_dir
11 | self.mask_dir = mask_dir
12 | self.size = size
13 | self.vid_list = sorted(os.listdir(self.mask_dir))
14 | self.vid_list = [v for v in self.vid_list if not v.endswith('.json')]
15 |
16 | def get_datasets(self):
17 | for video in self.vid_list:
18 | yield DetectionVideoReader(
19 | video,
20 | path.join(self.image_dir, video),
21 | path.join(self.mask_dir, video),
22 | to_save=[name[:-4] for name in os.listdir(path.join(self.mask_dir, video))],
23 | size=self.size,
24 | )
25 |
26 | def __len__(self):
27 | return len(self.vid_list)
28 |
29 |
30 | class BURSTDetectionTestDataset:
31 | def __init__(self, image_dir, mask_dir, gt_json_dir, size=-1, *, start=None, count=None):
32 | self.image_dir = image_dir
33 | self.mask_dir = mask_dir
34 | self.size = size
35 |
36 | # read the json file to get a list of videos and frames to save
37 | with open(gt_json_dir, 'r') as f:
38 | json_file = json.load(f)
39 | sequences = json_file['sequences']
40 | split = json_file['split']
41 |
42 | assert split == 'test' or split == 'val'
43 |
44 | # load a randomized ordering of BURST videos for a balanced load
45 | with open(f'./deva/utils/burst_{split}.txt', mode='r') as f:
46 | randomized_videos = list(f.read().splitlines())
47 |
48 | # subsample a list of videos for processing
49 | if start is not None and count is not None:
50 | randomized_videos = randomized_videos[start:start + count]
51 | print(f'Start: {start}, Count: {count}, End: {start+count}')
52 |
53 | self.vid_list = []
54 | self.frames_to_save = {}
55 | for sequence in sequences:
56 | dataset = sequence['dataset']
57 | seq_name = sequence['seq_name']
58 | video_name = path.join(dataset, seq_name)
59 | if video_name not in randomized_videos:
60 | continue
61 | self.vid_list.append(video_name)
62 |
63 | annotated_image_paths = sequence['annotated_image_paths']
64 | self.frames_to_save[video_name] = [p[:-4] for p in annotated_image_paths]
65 | assert path.exists(path.join(image_dir, video_name))
66 | assert path.exists(path.join(mask_dir, video_name))
67 |
68 | assert len(self.vid_list) == len(randomized_videos)
69 | # to use the random ordering
70 | self.vid_list = randomized_videos
71 |
72 | print(f'Actual total: {len(self.vid_list)}')
73 |
74 | def get_datasets(self):
75 | for video in self.vid_list:
76 | yield DetectionVideoReader(
77 | video,
78 | path.join(self.image_dir, video),
79 | path.join(self.mask_dir, video),
80 | to_save=self.frames_to_save[video],
81 | size=self.size,
82 | )
83 |
84 | def __len__(self):
85 | return len(self.vid_list)
--------------------------------------------------------------------------------
/deva/inference/demo_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from deva.dataset.utils import im_normalization
6 | from deva.inference.inference_core import DEVAInferenceCore
7 | from deva.inference.result_utils import ResultSaver
8 |
9 |
10 | def get_input_frame_for_deva(image_np: np.ndarray, min_side: int) -> torch.Tensor:
11 | image = torch.from_numpy(image_np).permute(2, 0, 1).float() / 255
12 | image = im_normalization(image)
13 | if min_side > 0:
14 | h, w = image_np.shape[:2]
15 | scale = min_side / min(h, w)
16 | new_h, new_w = int(h * scale), int(w * scale)
17 | image = image.unsqueeze(0)
18 | image = F.interpolate(image, (new_h, new_w), mode='bilinear', align_corners=False)[0]
19 | return image.cuda()
20 |
21 |
22 | @torch.inference_mode()
23 | def flush_buffer(deva: DEVAInferenceCore, result_saver: ResultSaver) -> None:
24 | # process all the remaining frames in the buffer
25 | cfg = deva.config
26 | new_min_side = cfg['size']
27 | need_resize = new_min_side > 0
28 |
29 | if 'prompt' in cfg:
30 | raw_prompt = cfg['prompt']
31 | prompts = raw_prompt.split('.')
32 | else:
33 | prompts = None
34 |
35 | for frame_info in deva.frame_buffer:
36 | this_image = frame_info.image
37 | this_frame_name = frame_info.name
38 | this_image_np = frame_info.image_np
39 | h, w = this_image_np.shape[:2]
40 | prob = deva.step(this_image, None, None)
41 | result_saver.save_mask(prob,
42 | this_frame_name,
43 | need_resize=need_resize,
44 | shape=(h, w),
45 | image_np=this_image_np,
46 | prompts=prompts)
47 |
--------------------------------------------------------------------------------
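Note on get_input_frame_for_deva above (illustrative, not part of the repository): the shorter image side is scaled to min_side and the longer side is scaled by the same factor. The sketch below only reproduces that arithmetic on a dummy array; the function itself additionally normalizes the tensor and moves it to the GPU, so calling it requires a CUDA device.

    import numpy as np

    # dummy 720 x 1280 frame; min_side = 480 matches the default eval size
    image_np = np.zeros((720, 1280, 3), dtype=np.uint8)
    min_side = 480

    h, w = image_np.shape[:2]
    scale = min_side / min(h, w)           # 480 / 720 = 2/3
    new_h, new_w = int(h * scale), int(w * scale)
    print(new_h, new_w)                    # 480 853 -> shorter side becomes 480
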
/deva/inference/eval_args.py:
--------------------------------------------------------------------------------
1 | # Common evaluation arguments
2 | from argparse import ArgumentParser
3 | import torch
4 | from deva.model.network import DEVA
5 |
6 |
7 | def add_common_eval_args(parser: ArgumentParser):
8 | parser.add_argument('--model', default='./saves/DEVA-propagation.pth')
9 |
10 | parser.add_argument('--output', default=None)
11 | parser.add_argument(
12 | '--save_all',
13 | action='store_true',
14 | help='Save all frames',
15 | )
16 |
17 | parser.add_argument('--amp', action='store_true')
18 |
19 | # Model parameters
20 | parser.add_argument('--key_dim', type=int, default=64)
21 | parser.add_argument('--value_dim', type=int, default=512)
22 | parser.add_argument('--pix_feat_dim', type=int, default=512)
23 |
24 | # Long-term memory options
25 | parser.add_argument('--disable_long_term', action='store_true')
26 | parser.add_argument('--max_mid_term_frames',
27 | help='T_max in XMem, decrease to save memory',
28 | type=int,
29 | default=10)
30 | parser.add_argument('--min_mid_term_frames',
31 | help='T_min in XMem, decrease to save memory',
32 | type=int,
33 | default=5)
34 | parser.add_argument('--max_long_term_elements',
35 | help='LT_max in XMem, increase if objects disappear for a long time',
36 | type=int,
37 | default=10000)
38 | parser.add_argument('--num_prototypes', help='P in XMem', type=int, default=128)
39 |
40 | parser.add_argument('--top_k', type=int, default=30)
41 | parser.add_argument('--mem_every',
42 | help='r in XMem. Increase to improve running speed.',
43 | type=int,
44 | default=5)
45 | parser.add_argument(
46 | '--chunk_size',
47 | default=-1,
48 | type=int,
49 | help='''Number of objects to process in parallel as a batch; -1 for unlimited.
50 | Set to a small number to save memory.''')
51 |
52 | parser.add_argument(
53 | '--size',
54 | default=480,
55 | type=int,
56 | help='Resize the shorter side to this size. -1 to use original resolution. ')
57 |
58 |
59 | def get_model_and_config(parser: ArgumentParser):
60 | args = parser.parse_args()
61 | config = vars(args)
62 | config['enable_long_term'] = not config['disable_long_term']
63 |
64 | # Load our checkpoint
65 | network = DEVA(config).cuda().eval()
66 | if args.model is not None:
67 | model_weights = torch.load(args.model)
68 | network.load_weights(model_weights)
69 | else:
70 | print('No model loaded.')
71 |
72 | return network, config, args
--------------------------------------------------------------------------------
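Usage sketch for deva/inference/eval_args.py (illustrative, not part of the repository): this only builds the parser and inspects the defaults, since get_model_and_config additionally loads checkpoint weights onto a GPU.

    from argparse import ArgumentParser
    from deva.inference.eval_args import add_common_eval_args

    parser = ArgumentParser()
    add_common_eval_args(parser)
    args = parser.parse_args([])           # take the defaults
    config = vars(args)
    config['enable_long_term'] = not config['disable_long_term']
    print(config['size'], config['mem_every'], config['top_k'])  # 480 5 30
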
/deva/inference/frame_utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List
2 | import torch
3 |
4 | from deva.inference.object_info import ObjectInfo
5 |
6 |
7 | class FrameInfo:
8 | def __init__(self, image: torch.Tensor, mask: torch.Tensor, segments_info: List[ObjectInfo],
9 | ti: int, info: Dict):
10 | self.image = image
11 | self.mask = mask
12 | self.segments_info = segments_info
13 | self.ti = ti
14 | self.info = info
15 |
16 | @property
17 | def name(self):
18 | return self.info['frame'][0]
19 |
20 | @property
21 | def shape(self):
22 | return self.info['shape']
23 |
24 | @property
25 | def save_needed(self):
26 | return self.info['save'][0]
27 |
28 | @property
29 | def path_to_image(self):
30 | return self.info['path_to_image'][0]
31 |
--------------------------------------------------------------------------------
/deva/inference/image_feature_store.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable
2 | import warnings
3 | import torch
4 | from deva.model.network import DEVA
5 |
6 |
7 | class ImageFeatureStore:
8 | """
9 | A cache for image features.
10 | These features might be reused at different parts of the inference pipeline.
11 |     This class provides an easy interface for reusing these features.
12 | It is the user's responsibility to delete features.
13 |
14 |     Features of a frame should be associated with a unique index -- typically the frame id.
15 | """
16 | def __init__(self, network: DEVA, no_warning: bool = False):
17 | self.network = network
18 | self._store = {}
19 | self.no_warning = no_warning
20 |
21 | def _encode_feature(self, index: int, image: torch.Tensor) -> None:
22 | ms_features, feat = self.network.encode_image(image)
23 | key, shrinkage, selection = self.network.transform_key(feat)
24 |
25 | self._store[index] = (ms_features, feat, key, shrinkage, selection)
26 |
27 | def get_ms_features(self, index, image) -> Iterable[torch.Tensor]:
28 | if index not in self._store:
29 | self._encode_feature(index, image)
30 |
31 | return self._store[index][0]
32 |
33 | def get_key(self, index, image) -> (torch.Tensor, torch.Tensor, torch.Tensor):
34 | if index not in self._store:
35 | self._encode_feature(index, image)
36 |
37 | return self._store[index][2:]
38 |
39 | def delete(self, index) -> None:
40 | if index in self._store:
41 | del self._store[index]
42 |
43 | def __len__(self):
44 | return len(self._store)
45 |
46 | def __del__(self):
47 | if len(self._store) > 0 and not self.no_warning:
48 | warnings.warn(f'Leaking {self._store.keys()} in the image feature store')
49 |
--------------------------------------------------------------------------------
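Note on ImageFeatureStore above (illustrative, not part of the repository): features are computed lazily on first access, cached by frame index, and must be deleted by the caller. The sketch below demonstrates that contract with a tiny stand-in network exposing only the two methods the store calls; the real class expects a DEVA model.

    import torch
    from deva.inference.image_feature_store import ImageFeatureStore

    class _StubNetwork:
        # stand-in exposing only the methods ImageFeatureStore uses
        def encode_image(self, image):
            feat = image.mean(dim=1, keepdim=True)
            return [feat], feat                  # (ms_features, feat)

        def transform_key(self, feat):
            return feat, feat, feat              # (key, shrinkage, selection)

    store = ImageFeatureStore(_StubNetwork(), no_warning=True)
    image = torch.rand(1, 3, 32, 32)
    key, shrinkage, selection = store.get_key(0, image)   # computed and cached
    ms_features = store.get_ms_features(0, image)          # served from the cache
    store.delete(0)                                        # the caller cleans up
    print(len(store))                                      # 0
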
/deva/inference/object_info.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import numpy as np
3 | from scipy import stats
4 | from deva.utils.pano_utils import id_to_rgb
5 |
6 |
7 | class ObjectInfo:
8 | """
9 | Stores meta information for an object
10 | """
11 | def __init__(self,
12 | id: int,
13 | category_id: Optional[int] = None,
14 | isthing: Optional[bool] = None,
15 | score: Optional[float] = None):
16 | self.id = id
17 | self.category_ids = [category_id]
18 | self.scores = [score]
19 | self.isthing = isthing
20 |         self.poke_count = 0  # number of detections since this object was last seen
21 |
22 | def poke(self) -> None:
23 | self.poke_count += 1
24 |
25 | def unpoke(self) -> None:
26 | self.poke_count = 0
27 |
28 | def merge(self, other) -> None:
29 | self.category_ids.extend(other.category_ids)
30 | self.scores.extend(other.scores)
31 |
32 | def vote_category_id(self) -> Optional[int]:
33 | category_ids = [c for c in self.category_ids if c is not None]
34 | if len(category_ids) == 0:
35 | return None
36 | else:
37 | return int(stats.mode(category_ids, keepdims=False)[0])
38 |
39 | def vote_score(self) -> Optional[float]:
40 | scores = [c for c in self.scores if c is not None]
41 | if len(scores) == 0:
42 | return None
43 | else:
44 | return float(np.mean(scores))
45 |
46 | def get_rgb(self) -> np.ndarray:
47 | # this is valid for panoptic segmentation-style id only (0~255**3)
48 | return id_to_rgb(self.id)
49 |
50 | def copy_meta_info(self, other) -> None:
51 | self.category_ids = other.category_ids
52 | self.scores = other.scores
53 | self.isthing = other.isthing
54 |
55 | def __hash__(self):
56 | return hash(self.id)
57 |
58 | def __eq__(self, other):
59 | return self.id == other.id
60 |
61 | def __repr__(self):
62 | return f'(ID: {self.id}, cat: {self.category_ids}, isthing: {self.isthing}, score: {self.scores})'
63 |
--------------------------------------------------------------------------------
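Usage sketch for deva/inference/object_info.py (illustrative, not part of the repository; assumes the deva package is importable): category ids of merged detections are combined by majority vote and scores by their mean. All values below are made up.

    from deva.inference.object_info import ObjectInfo

    a = ObjectInfo(id=7, category_id=3, isthing=True, score=0.9)
    b = ObjectInfo(id=7, category_id=3, isthing=True, score=0.7)
    c = ObjectInfo(id=7, category_id=5, isthing=True, score=0.5)

    a.merge(b)
    a.merge(c)
    print(a.vote_category_id())  # 3 (majority over [3, 3, 5])
    print(a.vote_score())        # ~0.7 (mean of the three scores)
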
/deva/inference/object_manager.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Dict, Set
2 |
3 | import torch
4 | import numpy as np
5 | from deva.inference.object_info import ObjectInfo
6 |
7 |
8 | class ObjectManager:
9 | """
10 |     Object (real) IDs are immutable. The same ID always represents the same object.
11 |     Temporary IDs are the positions of each object in the tensor. They change as objects get removed.
12 | """
13 | def __init__(self):
14 | self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
15 | self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
16 | self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
17 |
18 | # We keep track of all historical object IDs to avoid collision
19 | # even removed object IDs stay in this set
20 | self.all_historical_object_ids: Set[int] = set()
21 | self.use_long_id = False
22 |
23 | def _recompute_obj_id_to_obj_mapping(self) -> None:
24 | self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}
25 |
26 | def add_new_objects(
27 | self, objects: Union[List[ObjectInfo], ObjectInfo,
28 | List[int]]) -> (List[int], List[int]):
29 | if not isinstance(objects, list):
30 | objects = [objects]
31 |
32 | corresponding_tmp_ids = []
33 | corresponding_obj_ids = []
34 | for obj in objects:
35 | if isinstance(obj, int):
36 | obj = ObjectInfo(id=obj)
37 |
38 | # create new id if id collides with existing ones
39 | # simple heuristic to determine RGB/0~255; works most of the time except when it doesn't
40 | count = 0
41 | new_obj = ObjectInfo(id=obj.id)
42 | while (new_obj.id in self.all_historical_object_ids) or (self.use_long_id
43 | and new_obj.id < 256):
44 | if self.use_long_id:
45 | new_id = np.random.randint(256, 256**3)
46 | else:
47 | new_id = np.random.randint(1, 256)
48 | new_obj = ObjectInfo(id=new_id)
49 | count += 1
50 |
51 | if count > 5000:
52 |                     raise ValueError("We cannot find a new ID for this object. " +
53 |                                      "Perhaps you should use long IDs?")
54 | new_obj.copy_meta_info(obj)
55 |
56 | # new object
57 | new_tmp_id = len(self.obj_to_tmp_id) + 1
58 | self.obj_to_tmp_id[new_obj] = new_tmp_id
59 | self.tmp_id_to_obj[new_tmp_id] = new_obj
60 | self.all_historical_object_ids.add(new_obj.id)
61 | corresponding_tmp_ids.append(new_tmp_id)
62 | corresponding_obj_ids.append(new_obj.id)
63 |
64 | self._recompute_obj_id_to_obj_mapping()
65 | assert corresponding_tmp_ids == sorted(corresponding_tmp_ids), 'tmp id assignment bugged'
66 | return corresponding_tmp_ids, corresponding_obj_ids
67 |
68 | def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
69 | # delete an object or a list of objects
70 | # re-sort the tmp ids
71 | if isinstance(obj_ids_to_remove, int):
72 | obj_ids_to_remove = [obj_ids_to_remove]
73 |
74 | new_tmp_id = 1
75 | total_num_id = len(self.obj_to_tmp_id)
76 |
77 | local_obj_to_tmp_id = {}
78 | local_tmp_to_obj_id = {}
79 |
80 | for tmp_iter in range(1, total_num_id + 1):
81 | obj = self.tmp_id_to_obj[tmp_iter]
82 | if obj.id not in obj_ids_to_remove:
83 | local_obj_to_tmp_id[obj] = new_tmp_id
84 | local_tmp_to_obj_id[new_tmp_id] = obj
85 | new_tmp_id += 1
86 |
87 | self.obj_to_tmp_id = local_obj_to_tmp_id
88 | self.tmp_id_to_obj = local_tmp_to_obj_id
89 | self._recompute_obj_id_to_obj_mapping()
90 |
91 | def purge_inactive_objects(self,
92 | max_missed_detection_count: int) -> (bool, List[int], List[int]):
93 | # remove tmp ids of objects that are removed
94 | obj_id_to_be_deleted = []
95 | tmp_id_to_be_deleted = []
96 | tmp_id_to_keep = []
97 | obj_id_to_keep = []
98 |
99 | for obj in self.obj_to_tmp_id:
100 | if obj.poke_count > max_missed_detection_count:
101 | obj_id_to_be_deleted.append(obj.id)
102 | tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
103 | else:
104 | tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
105 | obj_id_to_keep.append(obj.id)
106 |
107 | purge_activated = len(obj_id_to_be_deleted) > 0
108 | if purge_activated:
109 | self.delete_object(obj_id_to_be_deleted)
110 | return purge_activated, tmp_id_to_keep, obj_id_to_keep
111 |
112 | def tmp_to_obj_cls(self, mask) -> torch.Tensor:
113 | # remap tmp id cls representation to the true object id representation
114 | new_mask = torch.zeros_like(mask)
115 | for tmp_id, obj in self.tmp_id_to_obj.items():
116 | new_mask[mask == tmp_id] = obj.id
117 | return new_mask
118 |
119 | def get_tmp_to_obj_mapping(self) -> Dict[int, int]:
120 | # returns the mapping in a dict format for saving it with pickle
121 |         return {obj.id: tmp_id for obj, tmp_id in self.obj_to_tmp_id.items()}
122 |
123 | def realize_dict(self, obj_dict: Dict[int, torch.Tensor]) -> torch.Tensor:
124 | # turns a dict indexed by obj id into a tensor, ordered by tmp IDs
125 | output = []
126 | for _, obj in self.tmp_id_to_obj.items():
127 | if obj.id not in obj_dict:
128 | raise NotImplementedError
129 | output.append(obj_dict[obj.id])
130 | output = torch.stack(output, dim=0)
131 | return output
132 |
133 | def make_one_hot(self, cls_mask: torch.Tensor) -> torch.Tensor:
134 | output = []
135 | for _, obj in self.tmp_id_to_obj.items():
136 | output.append(cls_mask == obj.id)
137 | if len(output) == 0:
138 | output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
139 | else:
140 | output = torch.stack(output, dim=0)
141 | return output
142 |
143 | def get_current_segments_info(self) -> List[Dict]:
144 | segments_info = []
145 | for obj in self.obj_to_tmp_id:
146 | segments_info.append({
147 | 'category_id': obj.vote_category_id(),
148 | 'id': int(obj.id),
149 | 'score': obj.vote_score(),
150 | })
151 | return segments_info
152 |
153 | @property
154 | def all_obj_ids(self) -> List[int]:
155 | return [k.id for k in self.obj_to_tmp_id]
156 |
157 | @property
158 | def num_obj(self) -> int:
159 | return len(self.obj_to_tmp_id)
160 |
161 | def has_all(self, objects: List[int]) -> bool:
162 | for obj in objects:
163 | if obj not in self.obj_to_tmp_id:
164 | return False
165 | return True
166 |
167 | def find_object_by_id(self, obj_id) -> ObjectInfo:
168 | return self.obj_id_to_obj[obj_id]
169 |
--------------------------------------------------------------------------------
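Usage sketch for deva/inference/object_manager.py (illustrative, not part of the repository): temporary ids are 1-based positions in the mask tensor and are re-packed after deletion, while real object ids never change. The ids below are arbitrary.

    import torch
    from deva.inference.object_manager import ObjectManager

    manager = ObjectManager()
    tmp_ids, obj_ids = manager.add_new_objects([10, 20, 30])   # plain int ids
    print(tmp_ids, obj_ids)      # [1, 2, 3] [10, 20, 30] (no collisions, so the ids are kept)

    manager.delete_object(obj_ids[1])      # remove the middle object
    print(manager.all_obj_ids)             # [10, 30]
    print(manager.obj_to_tmp_id)           # temporary ids re-packed to 1 and 2

    # one-hot conversion of a class-id mask, ordered by temporary id
    cls_mask = torch.tensor([[obj_ids[0], 0], [obj_ids[2], obj_ids[2]]])
    print(manager.make_one_hot(cls_mask).shape)   # (2, 2, 2): one channel per object
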
/deva/inference/object_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional
2 |
3 | import torch
4 | from deva.inference.object_info import ObjectInfo
5 | from deva.utils.pano_utils import vipseg_cat_to_isthing
6 |
7 |
8 | def convert_json_dict_to_objects_info(mask: torch.Tensor,
9 | segments_info: Optional[List],
10 | dataset: str = None) -> List[ObjectInfo]:
11 | """
12 | Convert a json dict to a list of object info
13 | If segments_info is None, we use the unique elements in mask to construct the list
14 | Otherwise mask is ignored
15 | """
16 | if segments_info is not None:
17 | output = [
18 | ObjectInfo(
19 | id=segment['id'],
20 | category_id=segment.get('category_id'),
21 | isthing=vipseg_cat_to_isthing[segment.get('category_id')]
22 | if dataset == 'vipseg' else None,
23 | score=float(segment['score']) if
24 | ((dataset == 'burst' or dataset == 'demo') and 'score' in segment) else None)
25 | for segment in segments_info
26 | ]
27 | else:
28 | # use the mask
29 | labels = torch.unique(mask)
30 | labels = labels[labels != 0]
31 | output = [ObjectInfo(l.item()) for l in labels]
32 |
33 | return output
34 |
--------------------------------------------------------------------------------
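Usage sketch for deva/inference/object_utils.py (illustrative, not part of the repository; assumes the deva package is importable): without segments_info, one ObjectInfo is created per non-zero label in the mask. The mask values are arbitrary.

    import torch
    from deva.inference.object_utils import convert_json_dict_to_objects_info

    mask = torch.tensor([[0, 1, 1],
                         [2, 2, 0],
                         [0, 0, 5]])
    objects = convert_json_dict_to_objects_info(mask, segments_info=None)
    print(objects)   # one ObjectInfo for each of the labels 1, 2 and 5
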
/deva/inference/postprocess_unsup_davis17.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import os
3 | from os import path
4 | import sys
5 | import numpy as np
6 | import tqdm
7 |
8 | from deva.utils.palette import davis_palette
9 |
10 |
11 | def limit_max_id(input_path, output_path, max_num_objects=20):
12 | videos = sorted(os.listdir(input_path))
13 | for video in tqdm.tqdm(videos):
14 | existing_objects = []
15 |
16 | video_path = path.join(input_path, video)
17 | frames = sorted(os.listdir(video_path))
18 |
19 | # determine the objects to keep
20 | for frame in frames:
21 | mask = Image.open(path.join(video_path, frame))
22 | mask = np.array(mask).astype(np.int32)
23 | if len(mask.shape) == 3:
24 | mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2]
25 | labels = np.unique(mask)
26 | labels = labels[labels != 0]
27 | labels_area = [np.sum(mask == label) for label in labels]
28 |
29 | labels_sorted_by_area = [x for _, x in sorted(zip(labels_area, labels), reverse=True)]
30 | if len(labels_sorted_by_area) + len(existing_objects) <= max_num_objects:
31 | existing_objects += labels_sorted_by_area
32 | else:
33 | existing_objects += labels_sorted_by_area[:max_num_objects - len(existing_objects)]
34 |
35 | if len(existing_objects) == max_num_objects:
36 | break
37 |
38 | assert len(existing_objects) <= max_num_objects
39 |
40 | # remove the objects that are not in the existing_objects list
41 | for frame in frames:
42 | mask = Image.open(path.join(video_path, frame))
43 | mask = np.array(mask).astype(np.int32)
44 | if len(mask.shape) == 3:
45 | mask = mask[:, :, 0] * 256 * 256 + mask[:, :, 1] * 256 + mask[:, :, 2]
46 | labels = np.unique(mask)
47 | labels = labels[labels != 0]
48 |
49 | new_mask = np.zeros_like(mask, dtype=np.uint8)
50 | for new_idx, label in enumerate(existing_objects):
51 | new_mask[mask == label] = new_idx + 1
52 |
53 | mask = Image.fromarray(new_mask)
54 | mask.putpalette(davis_palette)
55 | os.makedirs(path.join(output_path, video), exist_ok=True)
56 | mask.save(path.join(output_path, video, frame))
57 |
58 |
59 | if __name__ == '__main__':
60 | input_path = sys.argv[1]
61 | output_path = sys.argv[2]
62 | limit_max_id(input_path, output_path)
--------------------------------------------------------------------------------
/deva/inference/segment_merging.py:
--------------------------------------------------------------------------------
1 | """
2 | This file contains the implementation of segment matching and merging
3 | (Section 3.2.2, Merging Propagation and Consensus, in the paper).
4 |
5 | It matches & merges the objects as discussed in the paper
6 | and updates the object manager as a side effect.
7 | """
8 |
9 | import warnings
10 | from typing import List, Literal, Dict, Optional
11 |
12 | import torch
13 | from deva.inference.object_info import ObjectInfo
14 | from deva.inference.object_manager import ObjectManager
15 |
16 |
17 | def _get_iou(m1, m2, m1_sum, m2_sum) -> (float, float, float):
18 | intersection = (m1 * m2).sum()
19 | if intersection < 1e-3:
20 | return 0, None, None
21 | union = (m1_sum + m2_sum - intersection)
22 | return intersection / union, intersection, union
23 |
24 |
25 | def merge_by_iou(our_masks: Dict[ObjectInfo, torch.Tensor], new_masks: Dict[ObjectInfo,
26 | torch.Tensor],
27 | our_sums: Dict[ObjectInfo, torch.Tensor], new_sums: Dict[ObjectInfo, torch.Tensor],
28 | merged_mask: torch.Tensor, object_manager: ObjectManager,
29 | new_segments_info: List[ObjectInfo], isthing_status: Optional[bool],
30 | incremental_mode: bool) -> torch.Tensor:
31 |     # merged_mask is edited in-place
32 | our_to_new_matching = {}
33 | matched_area = {}
34 | new_objects = []
35 |
36 | for new_obj in new_segments_info:
37 | if new_obj.isthing != isthing_status:
38 | continue
39 | for our_obj in object_manager.obj_to_tmp_id:
40 | if (our_obj.isthing != isthing_status) or (our_obj in our_to_new_matching):
41 | continue
42 | iou, _, union = _get_iou(new_masks[new_obj], our_masks[our_obj], new_sums[new_obj],
43 | our_sums[our_obj])
44 | matched = (iou > 0.5)
45 | if matched:
46 | our_to_new_matching[our_obj] = new_obj
47 | matched_area[(our_obj, False)] = union
48 | break
49 | else:
50 | new_objects.append(new_obj)
51 | matched_area[(new_obj, True)] = new_sums[new_obj]
52 |
53 |     # for all of our segments that were not matched
54 | for our_obj in object_manager.obj_to_tmp_id:
55 | if (our_obj.isthing != isthing_status) or (our_obj in our_to_new_matching):
56 | continue
57 | matched_area[(our_obj, False)] = our_sums[our_obj]
58 |
59 |     # render in descending order of area so that smaller segments are drawn on top
60 | sorted_by_area = sorted(matched_area.items(), key=lambda x: x[1], reverse=True)
61 | for (obj, is_new), _ in sorted_by_area:
62 | if is_new:
63 | # obj is a new object
64 | _, corresponding_obj_ids = object_manager.add_new_objects(obj)
65 | merged_mask[new_masks[obj]] = corresponding_obj_ids[0]
66 | else:
67 | # obj is not a new object
68 | if obj in our_to_new_matching:
69 | # merge
70 | new_obj = our_to_new_matching[obj]
71 | merged_mask[our_masks[obj]] = obj.id
72 | merged_mask[new_masks[new_obj]] = obj.id
73 | obj.merge(new_obj)
74 | obj.unpoke()
75 | else:
76 | # copy from our forward mask
77 | merged_mask[our_masks[obj]] = obj.id
78 | if incremental_mode:
79 | if our_sums[obj] < 1:
80 | obj.poke()
81 | else:
82 | obj.unpoke()
83 | else:
84 | obj.poke()
85 |
86 | return merged_mask
87 |
88 |
89 | def match_and_merge(our_mask: torch.Tensor,
90 | new_mask: torch.Tensor,
91 | object_manager: ObjectManager,
92 | new_segments_info: List[ObjectInfo],
93 | mode: Literal['iou'] = 'iou',
94 | max_num_objects: int = -1,
95 | incremental_mode: bool = False) -> torch.Tensor:
96 | """
97 | our_mask is in temporary ids (consecutive)
98 | new_mask is in object ids (real ids from json)
99 |
100 | Updates the object manager as a side effect
101 | mode: 'iou' only
102 | max_num_objects: maximum number of objects allowed in memory (-1 for no limit)
103 | incremental_mode: existing masks are not expected to be supported by new masks,
104 | thus we only delete masks when they are not visible for too long,
105 | not when they are unsupported for too long
106 | """
107 | mode = mode.lower()
108 |
109 | # separate the masks into one-hot format
110 | our_mask = our_mask.long()
111 | new_mask = new_mask.long()
112 | our_masks = {obj: (our_mask == tmp) for obj, tmp in object_manager.obj_to_tmp_id.items()}
113 | new_masks = {obj: (new_mask == obj.id) for obj in new_segments_info}
114 |
115 | if max_num_objects > 0 and len(
116 | object_manager.obj_to_tmp_id) + len(new_segments_info) > max_num_objects:
117 | # too many objects; forcibly deny all new objects
118 | warnings.warn(
119 | 'Number of objects exceeded maximum (--max_num_objects); discarding new objects')
120 | new_masks = {}
121 | new_segments_info = []
122 |
123 | # pre-compute mask sums for IoU computation
124 |     our_sums = {obj: m.sum() for obj, m in our_masks.items()}
125 |     new_sums = {obj: m.sum() for obj, m in new_masks.items()}
126 |
127 | # matching
128 | merged_mask = torch.zeros_like(our_mask)
129 | match_isthing = [None, False, True] # for isthing
130 | # we merge stuff/things/others separately
131 | for isthing_status in match_isthing:
132 | if mode == 'iou':
133 | merged_mask = merge_by_iou(our_masks, new_masks, our_sums, new_sums, merged_mask,
134 | object_manager, new_segments_info, isthing_status,
135 | incremental_mode)
136 |         elif mode == 'engulf':
137 |             # engulf mode has been removed; only 'iou' matching is supported
138 |             raise NotImplementedError('Engulf mode is deprecated')
139 |         else:
140 |             raise NotImplementedError(f'Unknown matching mode: {mode}')
141 |
142 | merged_mask = object_manager.make_one_hot(merged_mask)
143 | return merged_mask
144 |
--------------------------------------------------------------------------------
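Toy sketch for deva/inference/segment_merging.py (illustrative, not part of the repository; assumes the deva package is importable): a tracked object that overlaps a new detection with IoU > 0.5 is merged rather than duplicated, while a non-overlapping detection is registered as a new object. All ids and masks below are made up.

    import torch
    from deva.inference.object_info import ObjectInfo
    from deva.inference.object_manager import ObjectManager
    from deva.inference.segment_merging import match_and_merge

    manager = ObjectManager()
    manager.add_new_objects(ObjectInfo(id=1))     # tracked object: real id 1, temporary id 1

    # propagated mask in temporary ids (our object occupies the left half)
    our_mask = torch.tensor([[1, 1, 0, 0],
                             [1, 1, 0, 0]])
    # new detections in real ids: 50 overlaps our object, 60 does not
    new_mask = torch.tensor([[50, 50, 0, 60],
                             [50, 50, 0, 60]])
    new_segments = [ObjectInfo(id=50), ObjectInfo(id=60)]

    merged = match_and_merge(our_mask, new_mask, manager, new_segments, incremental_mode=True)
    print(manager.all_obj_ids)   # [1, 60]: detection 50 was merged into object 1, 60 was added
    print(merged.shape)          # (2, 2, 4): one-hot over the two tracked objects
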
/deva/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/model/__init__.py
--------------------------------------------------------------------------------
/deva/model/cbam.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | class BasicConv(nn.Module):
8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
9 | super(BasicConv, self).__init__()
10 | self.out_channels = out_planes
11 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
12 |
13 | def forward(self, x):
14 | x = self.conv(x)
15 | return x
16 |
17 | class Flatten(nn.Module):
18 | def forward(self, x):
19 | return x.view(x.size(0), -1)
20 |
21 | class ChannelGate(nn.Module):
22 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
23 | super(ChannelGate, self).__init__()
24 | self.gate_channels = gate_channels
25 | self.mlp = nn.Sequential(
26 | Flatten(),
27 | nn.Linear(gate_channels, gate_channels // reduction_ratio),
28 | nn.ReLU(),
29 | nn.Linear(gate_channels // reduction_ratio, gate_channels)
30 | )
31 | self.pool_types = pool_types
32 | def forward(self, x):
33 | channel_att_sum = None
34 | for pool_type in self.pool_types:
35 | if pool_type=='avg':
36 | avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
37 | channel_att_raw = self.mlp( avg_pool )
38 | elif pool_type=='max':
39 | max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
40 | channel_att_raw = self.mlp( max_pool )
41 |
42 | if channel_att_sum is None:
43 | channel_att_sum = channel_att_raw
44 | else:
45 | channel_att_sum = channel_att_sum + channel_att_raw
46 |
47 | scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x)
48 | return x * scale
49 |
50 | class ChannelPool(nn.Module):
51 | def forward(self, x):
52 | return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
53 |
54 | class SpatialGate(nn.Module):
55 | def __init__(self):
56 | super(SpatialGate, self).__init__()
57 | kernel_size = 7
58 | self.compress = ChannelPool()
59 | self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2)
60 | def forward(self, x):
61 | x_compress = self.compress(x)
62 | x_out = self.spatial(x_compress)
63 | scale = torch.sigmoid(x_out) # broadcasting
64 | return x * scale
65 |
66 | class CBAM(nn.Module):
67 | def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
68 | super(CBAM, self).__init__()
69 | self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
70 | self.no_spatial=no_spatial
71 | if not no_spatial:
72 | self.SpatialGate = SpatialGate()
73 | def forward(self, x):
74 | x_out = self.ChannelGate(x)
75 | if not self.no_spatial:
76 | x_out = self.SpatialGate(x_out)
77 | return x_out
78 |
--------------------------------------------------------------------------------
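Usage sketch for deva/model/cbam.py (illustrative, not part of the repository): CBAM applies channel attention followed by spatial attention and preserves the input shape. Dimensions are arbitrary.

    import torch
    from deva.model.cbam import CBAM

    block = CBAM(gate_channels=64, reduction_ratio=16)
    x = torch.randn(2, 64, 32, 32)       # batch x channels x H x W
    print(block(x).shape)                # torch.Size([2, 64, 32, 32])
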
/deva/model/group_modules.py:
--------------------------------------------------------------------------------
1 | """
2 | Group-specific modules
3 | They handle features that also depend on the mask.
4 | Features are typically of shape
5 | batch_size * num_objects * num_channels * H * W
6 |
7 | All of them are permutation equivariant w.r.t. the num_objects dimension
8 | """
9 | from typing import Optional
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from deva.model.cbam import CBAM
15 |
16 |
17 | def interpolate_groups(g: torch.Tensor, ratio: float, mode: str, align_corners: bool) -> torch.Tensor:
18 | batch_size, num_objects = g.shape[:2]
19 | g = F.interpolate(g.flatten(start_dim=0, end_dim=1),
20 | scale_factor=ratio,
21 | mode=mode,
22 | align_corners=align_corners)
23 | g = g.view(batch_size, num_objects, *g.shape[1:])
24 | return g
25 |
26 |
27 | def upsample_groups(g: torch.Tensor,
28 | ratio: float = 2,
29 | mode: str = 'bilinear',
30 | align_corners: bool = False) -> torch.Tensor:
31 | return interpolate_groups(g, ratio, mode, align_corners)
32 |
33 |
34 | def downsample_groups(g: torch.Tensor,
35 | ratio: float = 1 / 2,
36 | mode: str = 'area',
37 | align_corners: bool = None) -> torch.Tensor:
38 | return interpolate_groups(g, ratio, mode, align_corners)
39 |
40 |
41 | class GConv2D(nn.Conv2d):
42 | def forward(self, g: torch.Tensor) -> torch.Tensor:
43 | batch_size, num_objects = g.shape[:2]
44 | g = super().forward(g.flatten(start_dim=0, end_dim=1))
45 | return g.view(batch_size, num_objects, *g.shape[1:])
46 |
47 |
48 | class GroupResBlock(nn.Module):
49 | def __init__(self, in_dim: int, out_dim: int):
50 | super().__init__()
51 |
52 | if in_dim == out_dim:
53 | self.downsample = None
54 | else:
55 | self.downsample = GConv2D(in_dim, out_dim, kernel_size=1)
56 |
57 | self.conv1 = GConv2D(in_dim, out_dim, kernel_size=3, padding=1)
58 | self.conv2 = GConv2D(out_dim, out_dim, kernel_size=3, padding=1)
59 |
60 | def forward(self, g: torch.Tensor) -> torch.Tensor:
61 | out_g = self.conv1(F.relu(g))
62 | out_g = self.conv2(F.relu(out_g))
63 |
64 | if self.downsample is not None:
65 | g = self.downsample(g)
66 |
67 | return out_g + g
68 |
69 |
70 | class ResMLP(nn.Module):
71 | def __init__(self, in_dim: int, out_dim: int):
72 | super().__init__()
73 |
74 | if in_dim == out_dim:
75 | self.downsample = None
76 | else:
77 | self.downsample = nn.Linear(in_dim, out_dim)
78 |
79 | self.conv1 = nn.Linear(in_dim, out_dim)
80 | self.conv2 = nn.Linear(out_dim, out_dim)
81 |
82 | def forward(self, g: torch.Tensor) -> torch.Tensor:
83 | out_g = self.conv1(F.relu(g))
84 | out_g = self.conv2(F.relu(out_g))
85 |
86 | if self.downsample is not None:
87 | g = self.downsample(g)
88 |
89 | return out_g + g
90 |
91 |
92 | class MainToGroupDistributor(nn.Module):
93 | def __init__(self,
94 | x_transform: Optional[nn.Module] = None,
95 | g_transform: Optional[nn.Module] = None,
96 | method: str = 'cat',
97 | reverse_order: bool = False):
98 | super().__init__()
99 |
100 | self.x_transform = x_transform
101 | self.g_transform = g_transform
102 | self.method = method
103 | self.reverse_order = reverse_order
104 |
105 | def forward(self, x: torch.Tensor, g: torch.Tensor, skip_expand: bool = False) -> torch.Tensor:
106 | num_objects = g.shape[1]
107 |
108 | if self.x_transform is not None:
109 | x = self.x_transform(x)
110 |
111 | if self.g_transform is not None:
112 | g = self.g_transform(g)
113 |
114 | if not skip_expand:
115 | x = x.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
116 | if self.method == 'cat':
117 | if self.reverse_order:
118 | g = torch.cat([g, x], 2)
119 | else:
120 | g = torch.cat([x, g], 2)
121 | elif self.method == 'add':
122 | g = x + g
123 | elif self.method == 'mulcat':
124 | g = torch.cat([x * g, g], dim=2)
125 | elif self.method == 'muladd':
126 | g = x * g + g
127 | else:
128 | raise NotImplementedError
129 |
130 | return g
131 |
132 |
133 | class GroupFeatureFusionBlock(nn.Module):
134 | def __init__(self, x_in_dim: int, g_in_dim: int, g_mid_dim: int, g_out_dim: int):
135 | super().__init__()
136 |
137 | self.distributor = MainToGroupDistributor()
138 | self.block1 = GroupResBlock(x_in_dim + g_in_dim, g_mid_dim)
139 | self.attention = CBAM(g_mid_dim)
140 | self.block2 = GroupResBlock(g_mid_dim, g_out_dim)
141 |
142 | def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
143 | batch_size, num_objects = g.shape[:2]
144 |
145 | g = self.distributor(x, g)
146 | g = self.block1(g)
147 | r = self.attention(g.flatten(start_dim=0, end_dim=1))
148 | r = r.view(batch_size, num_objects, *r.shape[1:])
149 |
150 | g = self.block2(g + r)
151 |
152 | return g
--------------------------------------------------------------------------------
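Usage sketch for deva/model/group_modules.py (illustrative, not part of the repository): group tensors follow the batch x num_objects x channels x H x W convention, and the object dimension is folded into the batch before the underlying 2-D operations. Dimensions are arbitrary.

    import torch
    from deva.model.group_modules import GConv2D, GroupFeatureFusionBlock, upsample_groups

    g = torch.randn(2, 3, 16, 8, 8)      # batch x num_objects x C x H x W
    print(GConv2D(16, 32, kernel_size=3, padding=1)(g).shape)   # (2, 3, 32, 8, 8)
    print(upsample_groups(g).shape)                             # (2, 3, 16, 16, 16)

    x = torch.randn(2, 64, 8, 8)         # image-only features, no object dimension
    fuse = GroupFeatureFusionBlock(x_in_dim=64, g_in_dim=16, g_mid_dim=32, g_out_dim=32)
    print(fuse(x, g).shape)              # (2, 3, 32, 8, 8)
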
/deva/model/losses.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from collections import defaultdict
7 |
8 |
9 | def dice_loss(input_mask, cls_gt) -> torch.Tensor:
10 | num_objects = input_mask.shape[1]
11 | losses = []
12 | for i in range(num_objects):
13 | mask = input_mask[:, i].flatten(start_dim=1)
14 | # background not in mask, so we add one to cls_gt
15 | gt = (cls_gt == (i + 1)).float().flatten(start_dim=1)
16 | numerator = 2 * (mask * gt).sum(-1)
17 | denominator = mask.sum(-1) + gt.sum(-1)
18 | loss = 1 - (numerator + 1) / (denominator + 1)
19 | losses.append(loss)
20 | return torch.cat(losses).mean()
21 |
22 |
23 | # https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch
24 | class BootstrappedCE(nn.Module):
25 | def __init__(self, start_warm, end_warm, top_p=0.3):
26 | super().__init__()
27 |
28 | self.start_warm = start_warm
29 | self.end_warm = end_warm
30 | self.top_p = top_p
31 |
32 | def forward(self, input, target, it) -> (torch.Tensor, float):
33 | if it < self.start_warm:
34 | return F.cross_entropy(input, target), 1.0
35 |
36 | raw_loss = F.cross_entropy(input, target, reduction='none').view(-1)
37 | num_pixels = raw_loss.numel()
38 |
39 | if it > self.end_warm:
40 | this_p = self.top_p
41 | else:
42 | this_p = self.top_p + (1 - self.top_p) * ((self.end_warm - it) /
43 | (self.end_warm - self.start_warm))
44 | loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False)
45 | return loss.mean(), this_p
46 |
47 |
48 | class LossComputer:
49 | def __init__(self, config):
50 | super().__init__()
51 | self.config = config
52 | self.bce = BootstrappedCE(config['start_warm'], config['end_warm'])
53 |
54 | def compute(self, data, num_objects, it) -> Dict[str, torch.Tensor]:
55 | losses = defaultdict(int)
56 |
57 | b, t = data['rgb'].shape[:2]
58 |
59 | losses['total_loss'] = 0
60 | for ti in range(1, t):
61 | for bi in range(b):
62 | loss, p = self.bce(data[f'logits_{ti}'][bi:bi + 1, :num_objects[bi] + 1],
63 | data['cls_gt'][bi:bi + 1, ti, 0], it)
64 |
65 | aux_loss = F.cross_entropy(
66 | data[f'aux_logits_{ti}'][bi:bi + 1, :num_objects[bi] + 1, 0],
67 | data['cls_gt'][bi:bi + 1, ti, 0])
68 |
69 | losses['p'] += p / b / (t - 1)
70 | losses[f'ce_loss_{ti}'] += loss / b
71 | losses[f'aux_loss_{ti}'] += aux_loss / b
72 |
73 | losses['total_loss'] += losses['ce_loss_%d' % ti]
74 | losses['total_loss'] += losses['aux_loss_%d' % ti] * 0.1
75 | losses[f'dice_loss_{ti}'] = dice_loss(data[f'masks_{ti}'], data['cls_gt'][:, ti, 0])
76 | losses['total_loss'] += losses[f'dice_loss_{ti}']
77 |
78 | return losses
79 |
--------------------------------------------------------------------------------
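Usage sketch for deva/model/losses.py (illustrative, not part of the repository): dice_loss takes per-object soft masks plus an integer ground-truth map in which label i+1 marks object i, and BootstrappedCE falls back to plain cross-entropy before start_warm iterations. Shapes and warm-up values below are made up.

    import torch
    from deva.model.losses import dice_loss, BootstrappedCE

    # 1 sample, 2 objects, 4x4 soft masks; ground truth uses labels 0 (bg), 1 and 2
    soft_masks = torch.rand(1, 2, 4, 4)
    cls_gt = torch.randint(0, 3, (1, 4, 4))
    print(dice_loss(soft_masks, cls_gt))        # scalar dice loss

    bce = BootstrappedCE(start_warm=20000, end_warm=70000)
    logits = torch.randn(1, 3, 4, 4)            # background + 2 objects
    loss, p = bce(logits, cls_gt, it=0)         # it < start_warm -> plain CE, p == 1.0
    print(loss, p)
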
/deva/model/memory_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from typing import Optional, Union, Tuple
4 |
5 |
6 | def get_similarity(mk: torch.Tensor,
7 | ms: torch.Tensor,
8 | qk: torch.Tensor,
9 | qe: torch.Tensor,
10 | add_batch_dim=False) -> torch.Tensor:
11 | # used for training/inference and memory reading/memory potentiation
12 | # mk: B x CK x [N] - Memory keys
13 | # ms: B x 1 x [N] - Memory shrinkage
14 | # qk: B x CK x [HW/P] - Query keys
15 | # qe: B x CK x [HW/P] - Query selection
16 | # Dimensions in [] are flattened
17 | if add_batch_dim:
18 | mk, ms = mk.unsqueeze(0), ms.unsqueeze(0)
19 | qk, qe = qk.unsqueeze(0), qe.unsqueeze(0)
20 |
21 | CK = mk.shape[1]
22 | mk = mk.flatten(start_dim=2)
23 | ms = ms.flatten(start_dim=1).unsqueeze(2) if ms is not None else None
24 | qk = qk.flatten(start_dim=2)
25 | qe = qe.flatten(start_dim=2) if qe is not None else None
26 |
27 | if qe is not None:
28 | # See XMem's appendix for derivation
29 | mk = mk.transpose(1, 2)
30 | a_sq = (mk.pow(2) @ qe)
31 | two_ab = 2 * (mk @ (qk * qe))
32 | b_sq = (qe * qk.pow(2)).sum(1, keepdim=True)
33 | similarity = (-a_sq + two_ab - b_sq)
34 | else:
35 | # similar to STCN if we don't have the selection term
36 | a_sq = mk.pow(2).sum(1).unsqueeze(2)
37 | two_ab = 2 * (mk.transpose(1, 2) @ qk)
38 | similarity = (-a_sq + two_ab)
39 |
40 | if ms is not None:
41 | similarity = similarity * ms / math.sqrt(CK) # B*N*HW
42 | else:
43 | similarity = similarity / math.sqrt(CK) # B*N*HW
44 |
45 | return similarity
46 |
47 |
48 | def do_softmax(
49 | similarity: torch.Tensor,
50 | top_k: Optional[int] = None,
51 | inplace: bool = False,
52 | return_usage: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
53 | # normalize similarity with top-k softmax
54 | # similarity: B x N x [HW/P]
55 | # use inplace with care
56 | if top_k is not None:
57 | values, indices = torch.topk(similarity, k=top_k, dim=1)
58 |
59 | x_exp = values.exp_()
60 | x_exp /= torch.sum(x_exp, dim=1, keepdim=True)
61 | if inplace:
62 | similarity.zero_().scatter_(1, indices, x_exp) # B*N*HW
63 | affinity = similarity
64 | else:
65 | affinity = torch.zeros_like(similarity).scatter_(1, indices, x_exp) # B*N*HW
66 | else:
67 | maxes = torch.max(similarity, dim=1, keepdim=True)[0]
68 | x_exp = torch.exp(similarity - maxes)
69 | x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True)
70 | affinity = x_exp / x_exp_sum
71 | indices = None
72 |
73 | if return_usage:
74 | return affinity, affinity.sum(dim=2)
75 |
76 | return affinity
77 |
78 |
79 | def get_affinity(mk: torch.Tensor, ms: torch.Tensor, qk: torch.Tensor,
80 | qe: torch.Tensor) -> torch.Tensor:
81 | # shorthand used in training with no top-k
82 | similarity = get_similarity(mk, ms, qk, qe)
83 | affinity = do_softmax(similarity)
84 | return affinity
85 |
86 |
87 | def readout(affinity: torch.Tensor, mv: torch.Tensor) -> torch.Tensor:
88 | B, CV, T, H, W = mv.shape
89 |
90 | mo = mv.view(B, CV, T * H * W)
91 | mem = torch.bmm(mo, affinity)
92 | mem = mem.view(B, CV, H, W)
93 |
94 | return mem
95 |
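96 | # A quick usage sketch of the functions above, showing expected tensor shapes.
97 | # The concrete dimensions here are hypothetical placeholders, and the memory
98 | # and query spatial sizes are assumed to match.
99 | #
100 | #   B, CK, CV, T, H, W = 1, 64, 512, 3, 30, 54
101 | #   mk = torch.randn(B, CK, T, H, W)    # memory keys
102 | #   ms = torch.rand(B, 1, T, H, W) + 1  # memory shrinkage (>= 1)
103 | #   qk = torch.randn(B, CK, H, W)       # query keys
104 | #   qe = torch.rand(B, CK, H, W)        # query selection, in [0, 1]
105 | #   mv = torch.randn(B, CV, T, H, W)    # memory values
106 | #
107 | #   affinity = get_affinity(mk, ms, qk, qe)  # B x (T*H*W) x (H*W)
108 | #   mem = readout(affinity, mv)              # B x CV x H x W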
--------------------------------------------------------------------------------
/deva/model/modules.py:
--------------------------------------------------------------------------------
1 | """
2 | modules.py - This file stores low-level network blocks.
3 |
4 | x - usually means features that only depend on the image
5 | g - usually means features that also depend on the mask.
6 | They might have an extra "group" or "num_objects" dimension, hence
7 | batch_size * num_objects * num_channels * H * W
8 |
9 | The trailing number of a variable usually denotes the stride
10 |
11 | """
12 |
13 | from typing import List, Iterable
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 |
18 | from deva.model.group_modules import *
19 | from deva.model.cbam import CBAM
20 |
21 |
22 | class ResBlock(nn.Module):
23 | def __init__(self, in_dim: int, out_dim: int):
24 | super().__init__()
25 |
26 | if in_dim == out_dim:
27 | self.downsample = None
28 | else:
29 | self.downsample = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
30 |
31 | self.conv1 = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
32 | self.conv2 = nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1)
33 |
34 | def forward(self, f: torch.Tensor) -> torch.Tensor:
35 | out_f = self.conv1(F.relu(f))
36 | out_f = self.conv2(F.relu(out_f))
37 |
38 | if self.downsample is not None:
39 | f = self.downsample(f)
40 |
41 | return out_f + f
42 |
43 |
44 | class FeatureFusionBlock(nn.Module):
45 | def __init__(self, in_dim: int, mid_dim: int, out_dim: int):
46 | super().__init__()
47 |
48 | self.block1 = ResBlock(in_dim, mid_dim)
49 | self.attention = CBAM(mid_dim)
50 | self.block2 = ResBlock(mid_dim, out_dim)
51 |
52 | def forward(self, x: torch.Tensor) -> torch.Tensor:
53 | x = self.block1(x)
54 | r = self.attention(x)
55 | x = self.block2(x + r)
56 |
57 | return x
58 |
59 |
60 | class KeyProjection(nn.Module):
61 | def __init__(self, in_dim: int, keydim: int):
62 | super().__init__()
63 |
64 | self.key_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
65 | # shrinkage
66 | self.d_proj = nn.Conv2d(in_dim, 1, kernel_size=3, padding=1)
67 | # selection
68 | self.e_proj = nn.Conv2d(in_dim, keydim, kernel_size=3, padding=1)
69 |
70 | nn.init.orthogonal_(self.key_proj.weight.data)
71 | nn.init.zeros_(self.key_proj.bias.data)
72 |
73 | def forward(self, x: torch.Tensor, *, need_s: bool,
74 | need_e: bool) -> (torch.Tensor, torch.Tensor, torch.Tensor):
75 | shrinkage = self.d_proj(x)**2 + 1 if (need_s) else None
76 | selection = torch.sigmoid(self.e_proj(x)) if (need_e) else None
77 |
78 | return self.key_proj(x), shrinkage, selection
79 |
80 |
81 | class MaskUpsampleBlock(nn.Module):
82 | def __init__(self, up_dim: int, out_dim: int, scale_factor: int = 2):
83 | super().__init__()
84 | self.distributor = MainToGroupDistributor(method='add')
85 | self.out_conv = GroupResBlock(up_dim, out_dim)
86 | self.scale_factor = scale_factor
87 |
88 | def forward(self, skip_f: torch.Tensor, up_g: torch.Tensor) -> torch.Tensor:
89 | g = upsample_groups(up_g, ratio=self.scale_factor)
90 | g = self.distributor(skip_f, g)
91 | g = self.out_conv(g)
92 | return g
93 |
94 |
95 | class DecoderFeatureProcessor(nn.Module):
96 | def __init__(self, decoder_dims: List[int], out_dims: List[int]):
97 | super().__init__()
98 | self.transforms = nn.ModuleList([
99 | nn.Conv2d(d_dim, p_dim, kernel_size=1) for d_dim, p_dim in zip(decoder_dims, out_dims)
100 | ])
101 |
102 | def forward(self, multi_scale_features: Iterable[torch.Tensor]) -> List[torch.Tensor]:
103 | outputs = [func(x) for x, func in zip(multi_scale_features, self.transforms)]
104 | return outputs
105 |
106 |
107 | class LinearPredictor(nn.Module):
108 | def __init__(self, in_dim: int, pred_dim: int):
109 | super().__init__()
110 | self.projection = GConv2D(in_dim, pred_dim + 1, kernel_size=1)
111 |
112 | def forward(self, im_feat: torch.Tensor, pred_feat: torch.Tensor) -> torch.Tensor:
113 | num_objects = pred_feat.shape[1]
114 | parameters = self.projection(pred_feat)
115 |
116 | im_feat = im_feat.unsqueeze(1).expand(-1, num_objects, -1, -1, -1)
117 | x = (im_feat * parameters[:, :, :-1]).sum(dim=2, keepdim=True) + parameters[:, :, -1:]
118 | return x
119 |
120 |
121 | class SensoryUpdater(nn.Module):
122 | # Used in the decoder, multi-scale feature + GRU
123 | def __init__(self, g_dims: List[int], mid_dim: int, sensory_dim: int):
124 | super().__init__()
125 | self.sensory_dim = sensory_dim
126 |
127 | self.g16_conv = GConv2D(g_dims[0], mid_dim, kernel_size=1)
128 | self.g8_conv = GConv2D(g_dims[1], mid_dim, kernel_size=1)
129 | self.g4_conv = GConv2D(g_dims[2], mid_dim, kernel_size=1)
130 |
131 | self.transform = GConv2D(mid_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
132 |
133 | nn.init.xavier_normal_(self.transform.weight)
134 |
135 | def forward(self, g: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
136 | g = self.g16_conv(g[0]) + self.g8_conv(downsample_groups(g[1], ratio=1/2)) + \
137 | self.g4_conv(downsample_groups(g[2], ratio=1/4))
138 |
139 | g = torch.cat([g, h], 2)
140 |
141 | # defined slightly differently than standard GRU,
142 | # namely the new value is generated before the forget gate.
143 | # might provide better gradient but frankly it was initially just an
144 | # implementation error that I never bothered fixing
145 | values = self.transform(g)
146 | forget_gate = torch.sigmoid(values[:, :, :self.sensory_dim])
147 | update_gate = torch.sigmoid(values[:, :, self.sensory_dim:self.sensory_dim * 2])
148 | new_value = torch.tanh(values[:, :, self.sensory_dim * 2:])
149 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
150 |
151 | return new_h
152 |
153 |
154 | class SensoryDeepUpdater(nn.Module):
155 | def __init__(self, f_dim: int, sensory_dim: int):
156 | super().__init__()
157 | self.sensory_dim = sensory_dim
158 | self.transform = GConv2D(f_dim + sensory_dim, sensory_dim * 3, kernel_size=3, padding=1)
159 |
160 | nn.init.xavier_normal_(self.transform.weight)
161 |
162 | def forward(self, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
163 | values = self.transform(torch.cat([f, h], dim=2))
164 | forget_gate = torch.sigmoid(values[:, :, :self.sensory_dim])
165 | update_gate = torch.sigmoid(values[:, :, self.sensory_dim:self.sensory_dim * 2])
166 | new_value = torch.tanh(values[:, :, self.sensory_dim * 2:])
167 | new_h = forget_gate * h * (1 - update_gate) + update_gate * new_value
168 |
169 | return new_h
170 |
--------------------------------------------------------------------------------
/deva/model/resnet.py:
--------------------------------------------------------------------------------
1 | """
2 | resnet.py - A modified ResNet structure
3 | We append extra channels to the first conv by some network surgery
4 | """
5 |
6 | from collections import OrderedDict
7 | import math
8 |
9 | import torch
10 | import torch.nn as nn
11 | from torch.utils import model_zoo
12 |
13 |
14 | def load_weights_add_extra_dim(target, source_state, extra_dim=1):
15 | new_dict = OrderedDict()
16 |
17 | for k1, v1 in target.state_dict().items():
18 | if not 'num_batches_tracked' in k1:
19 | if k1 in source_state:
20 | tar_v = source_state[k1]
21 |
22 | if v1.shape != tar_v.shape:
23 |                     # Initialize the extra (mask) input channels with an orthogonal init
24 | # print(v1.shape, tar_v.shape)
25 | c, _, w, h = v1.shape
26 | pads = torch.zeros((c,extra_dim,w,h), device=tar_v.device)
27 | nn.init.orthogonal_(pads)
28 | tar_v = torch.cat([tar_v, pads], 1)
29 |
30 | new_dict[k1] = tar_v
31 |
32 | target.load_state_dict(new_dict)
33 |
34 |
35 | model_urls = {
36 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
37 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
38 | }
39 |
40 |
41 | def conv3x3(in_planes, out_planes, stride=1, dilation=1):
42 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
43 | padding=dilation, dilation=dilation, bias=False)
44 |
45 |
46 | class BasicBlock(nn.Module):
47 | expansion = 1
48 |
49 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
50 | super(BasicBlock, self).__init__()
51 | self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation)
52 | self.bn1 = nn.BatchNorm2d(planes)
53 | self.relu = nn.ReLU(inplace=True)
54 | self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation)
55 | self.bn2 = nn.BatchNorm2d(planes)
56 | self.downsample = downsample
57 | self.stride = stride
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 |
69 | if self.downsample is not None:
70 | residual = self.downsample(x)
71 |
72 | out += residual
73 | out = self.relu(out)
74 |
75 | return out
76 |
77 |
78 | class Bottleneck(nn.Module):
79 | expansion = 4
80 |
81 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
82 | super(Bottleneck, self).__init__()
83 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
84 | self.bn1 = nn.BatchNorm2d(planes)
85 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation,
86 | padding=dilation, bias=False)
87 | self.bn2 = nn.BatchNorm2d(planes)
88 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
89 | self.bn3 = nn.BatchNorm2d(planes * 4)
90 | self.relu = nn.ReLU(inplace=True)
91 | self.downsample = downsample
92 | self.stride = stride
93 |
94 | def forward(self, x):
95 | residual = x
96 |
97 | out = self.conv1(x)
98 | out = self.bn1(out)
99 | out = self.relu(out)
100 |
101 | out = self.conv2(out)
102 | out = self.bn2(out)
103 | out = self.relu(out)
104 |
105 | out = self.conv3(out)
106 | out = self.bn3(out)
107 |
108 | if self.downsample is not None:
109 | residual = self.downsample(x)
110 |
111 | out += residual
112 | out = self.relu(out)
113 |
114 | return out
115 |
116 |
117 | class ResNet(nn.Module):
118 | def __init__(self, block, layers=(3, 4, 23, 3), extra_dim=0):
119 | self.inplanes = 64
120 | super(ResNet, self).__init__()
121 | self.conv1 = nn.Conv2d(3+extra_dim, 64, kernel_size=7, stride=2, padding=3, bias=False)
122 | self.bn1 = nn.BatchNorm2d(64)
123 | self.relu = nn.ReLU(inplace=True)
124 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
125 | self.layer1 = self._make_layer(block, 64, layers[0])
126 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
127 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
128 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
129 |
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
133 | m.weight.data.normal_(0, math.sqrt(2. / n))
134 | elif isinstance(m, nn.BatchNorm2d):
135 | m.weight.data.fill_(1)
136 | m.bias.data.zero_()
137 |
138 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
139 | downsample = None
140 | if stride != 1 or self.inplanes != planes * block.expansion:
141 | downsample = nn.Sequential(
142 | nn.Conv2d(self.inplanes, planes * block.expansion,
143 | kernel_size=1, stride=stride, bias=False),
144 | nn.BatchNorm2d(planes * block.expansion),
145 | )
146 |
147 | layers = [block(self.inplanes, planes, stride, downsample)]
148 | self.inplanes = planes * block.expansion
149 | for i in range(1, blocks):
150 | layers.append(block(self.inplanes, planes, dilation=dilation))
151 |
152 | return nn.Sequential(*layers)
153 |
154 | def resnet18(pretrained=True, extra_dim=0):
155 | model = ResNet(BasicBlock, [2, 2, 2, 2], extra_dim)
156 | if pretrained:
157 | load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet18']), extra_dim)
158 | return model
159 |
160 | def resnet50(pretrained=True, extra_dim=0):
161 | model = ResNet(Bottleneck, [3, 4, 6, 3], extra_dim)
162 | if pretrained:
163 | load_weights_add_extra_dim(model, model_zoo.load_url(model_urls['resnet50']), extra_dim)
164 | return model
165 |
166 |
--------------------------------------------------------------------------------
/deva/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/utils/__init__.py
--------------------------------------------------------------------------------
/deva/utils/configuration.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 |
3 |
4 | class Configuration():
5 | def parse(self, unknown_arg_ok=False):
6 | parser = ArgumentParser()
7 |
8 | # Enable torch.backends.cudnn.benchmark -- Faster in some cases, test in your own environment
9 | parser.add_argument('--benchmark', action='store_true')
10 | # AMP causes problems in training. Not recommended.
11 | parser.add_argument('--amp', action='store_true')
12 |
13 | # Data parameters
14 | parser.add_argument('--static_root', help='Static training data root', default='../static')
15 | parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K')
16 | parser.add_argument('--yv_root', help='YouTubeVOS data root', default='../YouTube')
17 | parser.add_argument('--davis_root', help='DAVIS data root', default='../DAVIS')
18 | parser.add_argument('--ovis_root', help='OVIS data root', default='../OVIS-VOS-train')
19 | parser.add_argument('--num_workers',
20 |                             help='Total number of dataloader workers across all GPU processes',
21 | type=int,
22 | default=16)
23 | parser.add_argument('--video_data_ratio', default=1.0, type=float)
24 |
25 | parser.add_argument('--pix_feat_dim', default=512, type=int)
26 | parser.add_argument('--key_dim', default=64, type=int)
27 | parser.add_argument('--value_dim', default=512, type=int)
28 |
29 | parser.add_argument('--deep_update_prob', default=0.2, type=float)
30 |
31 | # stage 1 and 2 are used previously in MiVOS/STCN/XMem, and not used here
32 | # we do not train on BL30K for this project, but we keep the naming (s03) for consistency
33 | parser.add_argument('--stages',
34 | help='Training stage (0-static images, 3-DAVIS+YouTubeVOS+OVIS)',
35 | default='03')
36 | parser.add_argument('--clip_grad_norm',
37 | help='Clip norm of global gradient at',
38 | default=3.0,
39 | type=float)
40 | """
41 | Stage-specific learning parameters
42 |         Batch sizes are effective -- you don't have to scale them when you scale the number of processes
43 | """
44 | # Stage 0, static images
45 | parser.add_argument('--s0_batch_size', default=16, type=int)
46 | parser.add_argument('--s0_iterations', default=80000, type=int)
47 | parser.add_argument('--s0_steps', nargs="*", default=[], type=int)
48 | parser.add_argument('--s0_lr', help='Initial learning rate', default=2e-5, type=float)
49 | parser.add_argument('--s0_num_ref_frames', default=2, type=int)
50 | parser.add_argument('--s0_num_frames', default=3, type=int)
51 | parser.add_argument('--s0_start_warm', default=10000, type=int)
52 | parser.add_argument('--s0_end_warm', default=35000, type=int)
53 | parser.add_argument('--s0_schedule', default='constant')
54 |
55 | # Stage 3, DAVIS+YoutubeVOS+OVIS
56 | parser.add_argument('--s3_batch_size', default=16, type=int)
57 | parser.add_argument('--s3_iterations', default=150000, type=int)
58 | parser.add_argument('--s3_steps', nargs="*", default=[120000, 140000], type=int)
59 | parser.add_argument('--s3_lr', help='Initial learning rate', default=1e-5, type=float)
60 | parser.add_argument('--s3_num_ref_frames', default=3, type=int)
61 | parser.add_argument('--s3_num_frames', default=8, type=int)
62 | parser.add_argument('--s3_start_warm', default=10000, type=int)
63 | parser.add_argument('--s3_end_warm', default=35000, type=int)
64 | parser.add_argument('--s3_schedule', default='step')
65 |
66 | parser.add_argument('--gamma',
67 | help='LR := LR*gamma at every decay step',
68 | default=0.1,
69 | type=float)
70 | parser.add_argument('--weight_decay', default=0.001, type=float)
71 |
72 | # Loading
73 | parser.add_argument('--load_network', help='Path to a pretrained network weights file. ')
74 | parser.add_argument(
75 | '--load_checkpoint',
76 | help='Path to the checkpoint file which contains network weights, optimizer, scheduler, '
77 | 'and the number of iterations. This is used for resuming interrupted training.')
78 |
79 | # Logging information
80 | parser.add_argument('--log_text_interval', default=100, type=int)
81 | parser.add_argument('--log_image_interval', default=1500, type=int)
82 | parser.add_argument('--save_network_interval', default=50000, type=int)
83 | parser.add_argument('--save_checkpoint_interval', default=50000, type=int)
84 | parser.add_argument('--exp_id',
85 | help='UNIQUE for a training run, set to NULL to disable logging',
86 | default='NULL')
87 | parser.add_argument('--debug',
88 | help='Debug mode; logs info at every iteration',
89 | action='store_true')
90 |
91 | if unknown_arg_ok:
92 | args, _ = parser.parse_known_args()
93 | self.args = vars(args)
94 | else:
95 | self.args = vars(parser.parse_args())
96 |
97 | # check if the stages are valid
98 | stage_to_perform = list(self.args['stages'])
99 | for s in stage_to_perform:
100 | if s not in ['0', '3']:
101 | raise NotImplementedError
102 |
103 | def get_stage_parameters(self, stage):
104 | parameters = {
105 | 'batch_size': self.args['s%s_batch_size' % stage],
106 | 'iterations': self.args['s%s_iterations' % stage],
107 | 'steps': self.args['s%s_steps' % stage],
108 | 'schedule': self.args['s%s_schedule' % stage],
109 | 'lr': self.args['s%s_lr' % stage],
110 | 'num_ref_frames': self.args['s%s_num_ref_frames' % stage],
111 | 'num_frames': self.args['s%s_num_frames' % stage],
112 | 'start_warm': self.args['s%s_start_warm' % stage],
113 | 'end_warm': self.args['s%s_end_warm' % stage],
114 | }
115 |
116 | return parameters
117 |
118 | def __getitem__(self, key):
119 | return self.args[key]
120 |
121 | def __setitem__(self, key, value):
122 | self.args[key] = value
123 |
124 | def __str__(self):
125 | return str(self.args)
126 |
--------------------------------------------------------------------------------
/deva/utils/davis_subset.txt:
--------------------------------------------------------------------------------
1 | bear
2 | bmx-bumps
3 | boat
4 | boxing-fisheye
5 | breakdance-flare
6 | bus
7 | car-turn
8 | cat-girl
9 | classic-car
10 | color-run
11 | crossing
12 | dance-jump
13 | dancing
14 | disc-jockey
15 | dog-agility
16 | dog-gooses
17 | dogs-scale
18 | drift-turn
19 | drone
20 | elephant
21 | flamingo
22 | hike
23 | hockey
24 | horsejump-low
25 | kid-football
26 | kite-walk
27 | koala
28 | lady-running
29 | lindy-hop
30 | longboard
31 | lucia
32 | mallard-fly
33 | mallard-water
34 | miami-surf
35 | motocross-bumps
36 | motorbike
37 | night-race
38 | paragliding
39 | planes-water
40 | rallye
41 | rhino
42 | rollerblade
43 | schoolgirls
44 | scooter-board
45 | scooter-gray
46 | sheep
47 | skate-park
48 | snowboard
49 | soccerball
50 | stroller
51 | stunt
52 | surf
53 | swing
54 | tennis
55 | tractor-sand
56 | train
57 | tuk-tuk
58 | upside-down
59 | varanus-cage
60 | walking
--------------------------------------------------------------------------------
/deva/utils/image_saver.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | import torch
5 | from deva.dataset.utils import inv_im_trans
6 | from collections import defaultdict
7 |
8 |
9 | def tensor_to_numpy(image):
10 | image_np = (image.numpy() * 255).astype('uint8')
11 | return image_np
12 |
13 |
14 | def tensor_to_np_float(image):
15 | image_np = image.numpy().astype('float32')
16 | return image_np
17 |
18 |
19 | def detach_to_cpu(x):
20 | return x.detach().cpu()
21 |
22 |
23 | def transpose_np(x):
24 | return np.transpose(x, [1, 2, 0])
25 |
26 |
27 | def tensor_to_gray_im(x):
28 | x = detach_to_cpu(x)
29 | x = tensor_to_numpy(x)
30 | x = transpose_np(x)
31 | return x
32 |
33 |
34 | def tensor_to_im(x):
35 | x = detach_to_cpu(x)
36 | x = inv_im_trans(x).clamp(0, 1)
37 | x = tensor_to_numpy(x)
38 | x = transpose_np(x)
39 | return x
40 |
41 |
42 | # Predefined key <-> caption dict
43 | key_captions = {
44 | 'im': 'Image',
45 | 'gt': 'GT',
46 | }
47 | """
48 | Return an image array with captions
49 | keys in dictionary will be used as caption if not provided
50 | values should contain lists of cv2 images
51 | """
52 |
53 |
54 | def get_image_array(images, grid_shape, captions={}):
55 | h, w = grid_shape
56 | cate_counts = len(images)
57 | rows_counts = len(next(iter(images.values())))
58 |
59 | font = cv2.FONT_HERSHEY_SIMPLEX
60 |
61 | output_image = np.zeros([w * cate_counts, h * (rows_counts + 1), 3], dtype=np.uint8)
62 | col_cnt = 0
63 | for k, v in images.items():
64 |
65 | # Default as key value itself
66 | caption = captions.get(k, k)
67 |
68 | # Handles new line character
69 | dy = 40
70 | for i, line in enumerate(caption.split('\n')):
71 | cv2.putText(output_image, line, (10, col_cnt * w + 100 + i * dy), font, 0.8,
72 | (255, 255, 255), 2, cv2.LINE_AA)
73 |
74 | # Put images
75 | for row_cnt, img in enumerate(v):
76 | im_shape = img.shape
77 | if len(im_shape) == 2:
78 | img = img[..., np.newaxis]
79 |
80 | img = (img * 255).astype('uint8')
81 |
82 | output_image[(col_cnt + 0) * w:(col_cnt + 1) * w,
83 | (row_cnt + 1) * h:(row_cnt + 2) * h, :] = img
84 |
85 | col_cnt += 1
86 |
87 | return output_image
88 |
89 |
90 | def base_transform(im, size):
91 | im = tensor_to_np_float(im)
92 | if len(im.shape) == 3:
93 | im = im.transpose((1, 2, 0))
94 | else:
95 | im = im[:, :, None]
96 |
97 | # Resize
98 | if im.shape[1] != size:
99 | im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST)
100 |
101 | return im.clip(0, 1)
102 |
103 |
104 | def im_transform(im, size):
105 | return base_transform(inv_im_trans(detach_to_cpu(im)), size=size)
106 |
107 |
108 | def mask_transform(mask, size):
109 | return base_transform(detach_to_cpu(mask), size=size)
110 |
111 |
112 | def logits_transform(mask, size):
113 | return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size)
114 |
115 |
116 | def pool_pairs(images, size, num_objects):
117 | req_images = defaultdict(list)
118 |
119 | b, t = images['rgb'].shape[:2]
120 |
121 | # limit the number of images saved
122 | b = min(2, b)
123 |
124 | # find max num objects
125 | max_num_objects = max(num_objects[:b])
126 |
127 | GT_suffix = ''
128 | for bi in range(b):
129 | GT_suffix += ' \n%s' % images['info']['name'][bi][-25:-4]
130 |
131 | for bi in range(b):
132 | for ti in range(t):
133 | req_images['RGB'].append(im_transform(images['rgb'][bi, ti], size))
134 | for oi in range(max_num_objects):
135 | if ti == 0 or oi >= num_objects[bi]:
136 | req_images[f'Mask_{oi}'].append(
137 | mask_transform(images['first_frame_gt'][bi][0, oi], size))
138 | req_images[f'Aux Mask_{oi}'].append(
139 | mask_transform(images['first_frame_gt'][bi][0, oi], size))
140 | else:
141 | req_images[f'Mask_{oi}'].append(
142 | mask_transform(images[f'masks_{ti}'][bi][oi], size))
143 | req_images[f'Aux Mask_{oi}'].append(
144 | mask_transform(images[f'aux_masks_{ti}'][bi][oi][0], size))
145 | req_images[f'GT_{oi}_{GT_suffix}'].append(
146 | mask_transform(images['cls_gt'][bi, ti, 0] == (oi + 1), size))
147 |
148 | return get_image_array(req_images, size, key_captions)
--------------------------------------------------------------------------------
/deva/utils/load_subset.py:
--------------------------------------------------------------------------------
1 | """
2 | load_subset.py - Presents a subset of data
3 | DAVIS - only the training set
4 | YouTubeVOS - I manually filtered some erroneous ones out, but I haven't checked all of them
5 | """
6 |
7 |
8 | def load_sub_davis(path='deva/utils/davis_subset.txt'):
9 | with open(path, mode='r') as f:
10 | subset = set(f.read().splitlines())
11 | return subset
12 |
13 |
14 | def load_sub_yv(path='deva/utils/yv_subset.txt'):
15 | with open(path, mode='r') as f:
16 | subset = set(f.read().splitlines())
17 | return subset
18 |
19 |
20 | def load_referring_yv_val(path='deva/utils/referring-youtubevos-val.txt'):
21 | with open(path, mode='r') as f:
22 | subset = set(f.read().splitlines())
23 | return subset
24 |
--------------------------------------------------------------------------------
/deva/utils/log_integrator.py:
--------------------------------------------------------------------------------
1 | """
2 | Integrate numerical values for some iterations
3 | Typically used for loss computation / logging to tensorboard
4 | Call finalize and create a new Integrator when you want to display/log
5 | """
6 | from typing import Dict, Callable, Tuple
7 | import torch
8 | from deva.utils.logger import TensorboardLogger
9 |
10 |
11 | class Integrator:
12 | def __init__(self, logger: TensorboardLogger, distributed: bool = True):
13 | self.values = {}
14 | self.counts = {}
15 | self.hooks = [] # List is used here to maintain insertion order
16 |
17 | self.logger = logger
18 |
19 | self.distributed = distributed
20 | self.local_rank = torch.distributed.get_rank()
21 | self.world_size = torch.distributed.get_world_size()
22 |
23 | def add_tensor(self, key: str, tensor: torch.Tensor):
24 | if key not in self.values:
25 | self.counts[key] = 1
26 | if type(tensor) == float or type(tensor) == int:
27 | self.values[key] = tensor
28 | else:
29 | self.values[key] = tensor.mean().item()
30 | else:
31 | self.counts[key] += 1
32 | if type(tensor) == float or type(tensor) == int:
33 | self.values[key] += tensor
34 | else:
35 | self.values[key] += tensor.mean().item()
36 |
37 | def add_dict(self, tensor_dict: Dict[str, torch.Tensor]):
38 | for k, v in tensor_dict.items():
39 | self.add_tensor(k, v)
40 |
41 | def add_hook(self, hook: Callable[[torch.Tensor], Tuple[str, torch.Tensor]]):
42 | """
43 | Adds a custom hook, i.e. compute new metrics using values in the dict
44 | The hook takes the dict as argument, and returns a (k, v) tuple
45 | e.g. for computing IoU
46 | """
47 | if type(hook) == list:
48 | self.hooks.extend(hook)
49 | else:
50 | self.hooks.append(hook)
51 |
52 | def reset_except_hooks(self):
53 | self.values = {}
54 | self.counts = {}
55 |
56 | # Average and output the metrics
57 | def finalize(self, prefix: str, it: int, f=None) -> None:
58 |
59 | for hook in self.hooks:
60 | k, v = hook(self.values)
61 | self.add_tensor(k, v)
62 |
63 | for k, v in self.values.items():
64 |
65 | if k[:4] == 'hide':
66 | continue
67 |
68 | avg = v / self.counts[k]
69 |
70 | if self.distributed:
71 | # Inplace operation
72 | avg = torch.tensor(avg).cuda()
73 | torch.distributed.reduce(avg, dst=0)
74 |
75 | if self.local_rank == 0:
76 | avg = (avg / self.world_size).cpu().item()
77 | self.logger.log_metrics(prefix, k, avg, it, f)
78 | else:
79 | # Simple does it
80 | self.logger.log_metrics(prefix, k, avg, it, f)
--------------------------------------------------------------------------------
/deva/utils/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Dumps things to tensorboard and console
3 | """
4 |
5 | import os
6 | import warnings
7 |
8 | import torchvision.transforms as transforms
9 | from torch.utils.tensorboard import SummaryWriter
10 |
11 |
12 | def tensor_to_numpy(image):
13 | image_np = (image.numpy() * 255).astype('uint8')
14 | return image_np
15 |
16 |
17 | def detach_to_cpu(x):
18 | return x.detach().cpu()
19 |
20 |
21 | def fix_width_trunc(x):
22 | return ('{:.9s}'.format('{:0.9f}'.format(x)))
23 |
24 |
25 | class TensorboardLogger:
26 | def __init__(self, short_id, id):
27 | self.short_id = short_id
28 | if self.short_id == 'NULL':
29 | self.short_id = 'DEBUG'
30 |
31 | if id is None:
32 | self.no_log = True
33 | warnings.warn('Logging has been disabled.')
34 | else:
35 | self.no_log = False
36 |
37 | self.inv_im_trans = transforms.Normalize(
38 | mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
39 | std=[1 / 0.229, 1 / 0.224, 1 / 0.225])
40 |
41 | self.inv_seg_trans = transforms.Normalize(mean=[-0.5 / 0.5], std=[1 / 0.5])
42 |
43 | log_path = os.path.join('.', 'saves', '%s' % id)
44 | self.logger = SummaryWriter(log_path)
45 |
46 | # Get current git info for logging
47 | try:
48 | import git
49 | repo = git.Repo(".")
50 | git_info = str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)
51 | except (ImportError, RuntimeError) as e:
52 | print('Failed to fetch git info. Defaulting to None')
53 | git_info = 'None'
54 |
55 | self.log_string('git', git_info)
56 |
57 | def log_scalar(self, tag, x, step):
58 | if self.no_log:
59 | warnings.warn('Logging has been disabled.')
60 | return
61 | self.logger.add_scalar(tag, x, step)
62 |
63 | def log_metrics(self, l1_tag, l2_tag, val, step, f=None):
64 | tag = l1_tag + '/' + l2_tag
65 | text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(),
66 | l2_tag, fix_width_trunc(val))
67 | print(text)
68 | if f is not None:
69 | f.write(text + '\n')
70 | f.flush()
71 | self.log_scalar(tag, val, step)
72 |
73 | def log_im(self, tag, x, step):
74 | if self.no_log:
75 | warnings.warn('Logging has been disabled.')
76 | return
77 | x = detach_to_cpu(x)
78 | x = self.inv_im_trans(x)
79 | x = tensor_to_numpy(x)
80 | self.logger.add_image(tag, x, step)
81 |
82 | def log_cv2(self, tag, x, step):
83 | if self.no_log:
84 | warnings.warn('Logging has been disabled.')
85 | return
86 | x = x.transpose((2, 0, 1))
87 | self.logger.add_image(tag, x, step)
88 |
89 | def log_seg(self, tag, x, step):
90 | if self.no_log:
91 | warnings.warn('Logging has been disabled.')
92 | return
93 | x = detach_to_cpu(x)
94 | x = self.inv_seg_trans(x)
95 | x = tensor_to_numpy(x)
96 | self.logger.add_image(tag, x, step)
97 |
98 | def log_gray(self, tag, x, step):
99 | if self.no_log:
100 | warnings.warn('Logging has been disabled.')
101 | return
102 | x = detach_to_cpu(x)
103 | x = tensor_to_numpy(x)
104 | self.logger.add_image(tag, x, step)
105 |
106 | def log_string(self, tag, x):
107 | print(tag, x)
108 | if self.no_log:
109 | warnings.warn('Logging has been disabled.')
110 | return
111 | self.logger.add_text(tag, x)
112 |
--------------------------------------------------------------------------------
/deva/utils/palette.py:
--------------------------------------------------------------------------------
1 | davis_palette = b'\x00\x00\x00\x80\x00\x00\x00\x80\x00\x80\x80\x00\x00\x00\x80\x80\x00\x80\x00\x80\x80\x80\x80\x80@\x00\x00\xc0\x00\x00@\x80\x00\xc0\x80\x00@\x00\x80\xc0\x00\x80@\x80\x80\xc0\x80\x80\x00@\x00\x80@\x00\x00\xc0\x00\x80\xc0\x00\x00@\x80\x80@\x80\x00\xc0\x80\x80\xc0\x80@@\x00\xc0@\x00@\xc0\x00\xc0\xc0\x00@@\x80\xc0@\x80@\xc0\x80\xc0\xc0\x80\x00\x00@\x80\x00@\x00\x80@\x80\x80@\x00\x00\xc0\x80\x00\xc0\x00\x80\xc0\x80\x80\xc0@\x00@\xc0\x00@@\x80@\xc0\x80@@\x00\xc0\xc0\x00\xc0@\x80\xc0\xc0\x80\xc0\x00@@\x80@@\x00\xc0@\x80\xc0@\x00@\xc0\x80@\xc0\x00\xc0\xc0\x80\xc0\xc0@@@\xc0@@@\xc0@\xc0\xc0@@@\xc0\xc0@\xc0@\xc0\xc0\xc0\xc0\xc0 \x00\x00\xa0\x00\x00 \x80\x00\xa0\x80\x00 \x00\x80\xa0\x00\x80 \x80\x80\xa0\x80\x80`\x00\x00\xe0\x00\x00`\x80\x00\xe0\x80\x00`\x00\x80\xe0\x00\x80`\x80\x80\xe0\x80\x80 @\x00\xa0@\x00 \xc0\x00\xa0\xc0\x00 @\x80\xa0@\x80 \xc0\x80\xa0\xc0\x80`@\x00\xe0@\x00`\xc0\x00\xe0\xc0\x00`@\x80\xe0@\x80`\xc0\x80\xe0\xc0\x80 \x00@\xa0\x00@ \x80@\xa0\x80@ \x00\xc0\xa0\x00\xc0 \x80\xc0\xa0\x80\xc0`\x00@\xe0\x00@`\x80@\xe0\x80@`\x00\xc0\xe0\x00\xc0`\x80\xc0\xe0\x80\xc0 @@\xa0@@ \xc0@\xa0\xc0@ @\xc0\xa0@\xc0 \xc0\xc0\xa0\xc0\xc0`@@\xe0@@`\xc0@\xe0\xc0@`@\xc0\xe0@\xc0`\xc0\xc0\xe0\xc0\xc0\x00 \x00\x80 \x00\x00\xa0\x00\x80\xa0\x00\x00 \x80\x80 \x80\x00\xa0\x80\x80\xa0\x80@ \x00\xc0 \x00@\xa0\x00\xc0\xa0\x00@ \x80\xc0 \x80@\xa0\x80\xc0\xa0\x80\x00`\x00\x80`\x00\x00\xe0\x00\x80\xe0\x00\x00`\x80\x80`\x80\x00\xe0\x80\x80\xe0\x80@`\x00\xc0`\x00@\xe0\x00\xc0\xe0\x00@`\x80\xc0`\x80@\xe0\x80\xc0\xe0\x80\x00 @\x80 @\x00\xa0@\x80\xa0@\x00 \xc0\x80 \xc0\x00\xa0\xc0\x80\xa0\xc0@ @\xc0 @@\xa0@\xc0\xa0@@ \xc0\xc0 \xc0@\xa0\xc0\xc0\xa0\xc0\x00`@\x80`@\x00\xe0@\x80\xe0@\x00`\xc0\x80`\xc0\x00\xe0\xc0\x80\xe0\xc0@`@\xc0`@@\xe0@\xc0\xe0@@`\xc0\xc0`\xc0@\xe0\xc0\xc0\xe0\xc0 \x00\xa0 \x00 \xa0\x00\xa0\xa0\x00 \x80\xa0 \x80 \xa0\x80\xa0\xa0\x80` \x00\xe0 \x00`\xa0\x00\xe0\xa0\x00` \x80\xe0 \x80`\xa0\x80\xe0\xa0\x80 `\x00\xa0`\x00 \xe0\x00\xa0\xe0\x00 `\x80\xa0`\x80 \xe0\x80\xa0\xe0\x80``\x00\xe0`\x00`\xe0\x00\xe0\xe0\x00``\x80\xe0`\x80`\xe0\x80\xe0\xe0\x80 @\xa0 @ \xa0@\xa0\xa0@ \xc0\xa0 \xc0 \xa0\xc0\xa0\xa0\xc0` @\xe0 @`\xa0@\xe0\xa0@` \xc0\xe0 \xc0`\xa0\xc0\xe0\xa0\xc0 `@\xa0`@ \xe0@\xa0\xe0@ `\xc0\xa0`\xc0 \xe0\xc0\xa0\xe0\xc0``@\xe0`@`\xe0@\xe0\xe0@``\xc0\xe0`\xc0`\xe0\xc0\xe0\xe0\xc0'
2 |
3 | youtube_palette = b'\x00\x00\x00\xec_g\xf9\x91W\xfa\xc8c\x99\xc7\x94b\xb3\xb2f\x99\xcc\xc5\x94\xc5\xabyg\xff\xff\xffes~\x0b\x0b\x0b\x0c\x0c\x0c\r\r\r\x0e\x0e\x0e\x0f\x0f\x0f'
4 |
--------------------------------------------------------------------------------
/deva/utils/pano_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from deva.utils.vipseg_categories import VIPSEG_CATEGORIES
3 |
4 | vipseg_cat_to_isthing = {d['id']: d['isthing'] == 1 for d in VIPSEG_CATEGORIES}
5 |
6 |
7 | def id_to_rgb(id: np.ndarray) -> np.ndarray:
8 | h, w = id.shape
9 | rgb = np.zeros((h, w, 3), dtype=np.uint8)
10 |
11 | for i in range(3):
12 | rgb[:, :, i] = id % 256
13 | id = id // 256
14 |
15 | return rgb
16 |
17 |
18 | class ID2RGBConverter:
19 | def __init__(self):
20 | self.all_id = []
21 | self.obj_to_id = {}
22 |
23 | def _id_to_rgb(self, id: int):
24 | rgb = np.zeros((3, ), dtype=np.uint8)
25 | for i in range(3):
26 | rgb[i] = id % 256
27 | id = id // 256
28 | return rgb
29 |
30 | def convert(self, obj: int):
31 | if obj in self.obj_to_id:
32 | id = self.obj_to_id[obj]
33 | else:
34 | while True:
35 | id = np.random.randint(255, 256**3)
36 | if id not in self.all_id:
37 | break
38 | self.obj_to_id[obj] = id
39 | self.all_id.append(id)
40 |
41 | return id, self._id_to_rgb(id)
42 |
43 |
44 | class IDPostprocessor:
45 | def __init__(self):
46 | self.all_id = []
47 | self.thing_obj_to_id = {}
48 | self.stuff_to_id = {}
49 |
50 | def id_to_rgb(self, id):
51 | rgb = np.zeros((3, ), dtype=np.uint8)
52 | for i in range(3):
53 | rgb[i] = id % 256
54 | id = id // 256
55 | return rgb
56 |
57 | def _find_new_id(self, default_id):
58 | id = default_id
59 | while True:
60 | if id not in self.all_id:
61 | return id
62 | id = np.random.randint(256, 256**3)
63 |
64 | def convert(self, obj, category_id, isthing):
65 | if isthing:
66 | # is thing
67 | if (obj, category_id) in self.thing_obj_to_id:
68 | id = self.thing_obj_to_id[(obj, category_id)]
69 | else:
70 | id = self._find_new_id(obj)
71 | self.thing_obj_to_id[(obj, category_id)] = id
72 | self.all_id.append(id)
73 | else:
74 | # is stuff
75 | if category_id in self.stuff_to_id:
76 | id = self.stuff_to_id[category_id]
77 | else:
78 | id = self._find_new_id(obj)
79 | self.stuff_to_id[category_id] = id
80 | self.all_id.append(id)
81 |
82 | return id
83 |
--------------------------------------------------------------------------------
/deva/utils/referring-youtubevos-val.txt:
--------------------------------------------------------------------------------
1 | 0062f687f1
2 | 01c88b5b60
3 | 0390fabe58
4 | 03fe6115d4
5 | 0620b43a31
6 | 06a5dfb511
7 | 06cd94d38d
8 | 0723d7d4fe
9 | 0782a6df7e
10 | 0788b4033d
11 | 0a598e18a8
12 | 0b0c90e21a
13 | 0c04834d61
14 | 0daaddc9da
15 | 0e8a6b63bb
16 | 0f3f8b2b2f
17 | 1335b16cf9
18 | 13c3cea202
19 | 13ca7bbcfd
20 | 152fe4902a
21 | 17cba76927
22 | 182dbfd6ba
23 | 188cb4e03d
24 | 19cde15c4b
25 | 1a1dbe153e
26 | 1a609fa7ee
27 | 1a894a8f98
28 | 1ab5f4bbc5
29 | 1e0257109e
30 | 1e20ceafae
31 | 1f390d22ea
32 | 20a93b4c54
33 | 218ac81c2d
34 | 226f1e10f7
35 | 246e38963b
36 | 257f7fd5b8
37 | 29c06df0f2
38 | 2b904b76c9
39 | 30fe0ed0ce
40 | 31d3a7d2ee
41 | 31e0beaf99
42 | 332dabe378
43 | 335fc10235
44 | 33c8dcbe09
45 | 33e8066265
46 | 34564d26d8
47 | 352ad66724
48 | 35948a7fca
49 | 35d5e5149d
50 | 3674b2c70a
51 | 369919ef49
52 | 37b4ec2e1a
53 | 39b7491321
54 | 39bce09d8d
55 | 3b72dc1941
56 | 3be852ed44
57 | 3dd327ab4e
58 | 3e03f623bb
59 | 3f4bacb16a
60 | 4037d8305d
61 | 411774e9ff
62 | 4307020e0f
63 | 43115c42b2
64 | 44e5d1a969
65 | 450bd2e238
66 | 45dc90f558
67 | 466734bc5c
68 | 47d01d34c8
69 | 48d2909d9e
70 | 4b783f1fc5
71 | 4ee0105885
72 | 4f5b3310e3
73 | 4f6662e4e0
74 | 4fe6619a47
75 | 541ccb0844
76 | 54526e3c66
77 | 5460cc540a
78 | 547416bda1
79 | 559a611d86
80 | 5d2020eff8
81 | 6031809500
82 | 60362df585
83 | 61fca8cbf1
84 | 621487be65
85 | 623d24ce2b
86 | 62bf7630b3
87 | 63883da4f5
88 | 64c6f2ed76
89 | 65350fd60a
90 | 65e0640a2a
91 | 68dab8f80c
92 | 696e01387c
93 | 69c0f7494e
94 | 6a75316e99
95 | 6cb5b08d93
96 | 6cc8bce61a
97 | 72d613f21a
98 | 749f1abdf9
99 | 7741a0fbce
100 | 7775043b5e
101 | 77df215672
102 | 7836afc0c2
103 | 7a19a80b19
104 | 7a72130f21
105 | 7daa6343e6
106 | 7f26b553ae
107 | 822c31928a
108 | 8273b59141
109 | 853ca85618
110 | 85968ae408
111 | 8939473ea7
112 | 8b7b57b94d
113 | 8c60938d92
114 | 8d803e87f7
115 | 8dea7458de
116 | 8e2e5af6a8
117 | 92fde455eb
118 | 94fa9bd3b5
119 | 975be70866
120 | 9787f452bf
121 | 97b38cabcc
122 | 9a38b8e463
123 | 9c0b55cae5
124 | 9ce299a510
125 | 9da2156a73
126 | 9f16d17e42
127 | 9f21474aca
128 | 9f429af409
129 | 9fd2d2782b
130 | a00c3fa88e
131 | a0fc95d8fc
132 | a1251195e7
133 | a2948d4116
134 | a46012c642
135 | a4bce691c6
136 | a7462d6aaf
137 | a806e58451
138 | a9f23c9150
139 | ab9a7583f1
140 | abae1ce57d
141 | aceb34fcbe
142 | b00ff71889
143 | b05faf54f7
144 | b205d868e6
145 | b2256e265c
146 | b3b92781d9
147 | b5514f75d8
148 | b58a97176b
149 | b772ac822a
150 | b7928ea5c0
151 | b7b7e52e02
152 | b83923fd72
153 | b90f8c11db
154 | ba8823f2d2
155 | bc9ba8917e
156 | bf2d38aefe
157 | bf4cc89b18
158 | c16d9a4ade
159 | c280d21988
160 | c2bbd6d121
161 | c42fdedcdd
162 | c74fc37224
163 | c9ef04fe59
164 | cb06f84b6e
165 | cbea8f6bea
166 | cc1a82ac2a
167 | cc7c3138ff
168 | cd69993923
169 | cd896a9bee
170 | cdcfd9f93a
171 | d1ac0d8b81
172 | d1dd586cfd
173 | d59c093632
174 | d69812339e
175 | d7a38bf258
176 | d7ff44ea97
177 | d975e5f4a9
178 | dab44991de
179 | dc197289ef
180 | dce363032d
181 | dea0160a12
182 | deed0ab4fc
183 | e027ebc228
184 | e10236eb37
185 | e11254d3b9
186 | e633eec195
187 | eb263ef128
188 | eb49ce8027
189 | ebe7138e58
190 | ee9415c553
191 | eea1a45e49
192 | eeb18f9d47
193 | f054e28786
194 | f143fede6f
195 | f2a45acf1c
196 | f3678388a7
197 | f39c805b54
198 | f7255a57d0
199 | f7d7fb16d0
200 | fb104c286f
201 | fd8cf868b2
202 | fef7e84268
203 |
--------------------------------------------------------------------------------
/deva/utils/tensor_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List, Iterable
2 | import torch
3 | import torch.nn.functional as F
4 |
5 |
6 | # STM
7 | def pad_divide_by(in_img: torch.Tensor, d: int) -> (torch.Tensor, Iterable[int]):
8 | h, w = in_img.shape[-2:]
9 |
10 | if h % d > 0:
11 | new_h = h + d - h % d
12 | else:
13 | new_h = h
14 | if w % d > 0:
15 | new_w = w + d - w % d
16 | else:
17 | new_w = w
18 | lh, uh = int((new_h - h) / 2), int(new_h - h) - int((new_h - h) / 2)
19 | lw, uw = int((new_w - w) / 2), int(new_w - w) - int((new_w - w) / 2)
20 | pad_array = (int(lw), int(uw), int(lh), int(uh))
21 | out = F.pad(in_img, pad_array)
22 | return out, pad_array
23 |
24 |
25 | def unpad(img: torch.Tensor, pad: Iterable[int]) -> torch.Tensor:
26 | if len(img.shape) == 4:
27 | if pad[2] + pad[3] > 0:
28 | img = img[:, :, pad[2]:-pad[3], :]
29 | if pad[0] + pad[1] > 0:
30 | img = img[:, :, :, pad[0]:-pad[1]]
31 | elif len(img.shape) == 3:
32 | if pad[2] + pad[3] > 0:
33 | img = img[:, pad[2]:-pad[3], :]
34 | if pad[0] + pad[1] > 0:
35 | img = img[:, :, pad[0]:-pad[1]]
36 | elif len(img.shape) == 5:
37 | if pad[2] + pad[3] > 0:
38 | img = img[:, :, :, pad[2]:-pad[3], :]
39 | if pad[0] + pad[1] > 0:
40 | img = img[:, :, :, :, pad[0]:-pad[1]]
41 | elif len(img.shape) == 2:
42 | if pad[2] + pad[3] > 0:
43 | img = img[pad[2]:-pad[3], :]
44 | if pad[0] + pad[1] > 0:
45 | img = img[:, pad[0]:-pad[1]]
46 | else:
47 | raise NotImplementedError
48 | return img
49 |
--------------------------------------------------------------------------------
/deva/vps_metrics/README.md:
--------------------------------------------------------------------------------
1 | `eval_stq_vipseg.py` and `eval_vpq_vipseg.py` are modified from the original evaluation scripts in https://github.com/VIPSeg-Dataset/VIPSeg-Dataset to generate the quantitative results in the GMP paper.
2 |
3 | There are a few main modifications:
4 | 1. Multiprocessing is implemented to speed up evaluation (by a lot).
5 | 2. Masks are loaded on-the-fly to reduce RAM usage.
6 | 3. There is no longer a `pan_pred` directory that stores the masks separately. The current script expects a single folder that contains all the video folders with the json file.
7 | 4. The print message is modified to be consistent with the notation in the paper and the output text file. `0-frame vpq_stat` becomes `1-frame vpq_stat`; `5-frame vpq_stat` becomes `2-frame vpq_stat`, etc.
8 |
9 |
10 | I noticed that there has been an update to the original VIPSeg evaluation script that covers some of these modifications (i.e., points 1 and 2). I developed this version independently of that update and keep the current version for documentation. Both generate the same results in my testing.
11 |
--------------------------------------------------------------------------------
/deva/vps_metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/deva/vps_metrics/__init__.py
--------------------------------------------------------------------------------
/deva/vps_metrics/eval_stq_vipseg.py:
--------------------------------------------------------------------------------
1 | # -------------------------------------------------------------------
2 | # Video Panoptic Segmentation
3 | #
4 | # STQ (Segmentation and Tracking Quality) evaluation code
5 | # Inference on every frame; evaluation uses all frames.
6 | # ------------------------------------------------------------------
7 |
8 | # Modified by Rex Cheng Oct 2022 - save results as a .txt file
9 | # Feb 2023 - Ported the update from the official code patch to here
10 | # - added a functional interface
11 |
12 | import argparse
13 | import os
14 | import os.path
15 | import numpy as np
16 | from PIL import Image
17 | import json
18 | from tqdm import tqdm
19 | import deva.vps_metrics.segmentation_and_tracking_quality as numpy_stq
20 |
21 |
22 | def parse_args():
23 | parser = argparse.ArgumentParser(description='VPSNet eval')
24 | parser.add_argument('--submit_dir', '-i', type=str, help='test output directory')
25 |
26 | parser.add_argument(
27 | '--truth_dir',
28 | type=str,
29 | default='../VIPSeg/VIPSeg_720P/panomasksRGB',
30 | help='ground truth directory. Point this to /VIPSeg/VIPSeg_720P/panomasksRGB '
31 | 'after running the conversion script')
32 |
33 | parser.add_argument(
34 | '--pan_gt_json_file',
35 | type=str,
36 | default='../VIPSeg/VIPSeg_720P/panoptic_gt_VIPSeg_val.json',
37 | help='ground truth JSON file. Point this to /VIPSeg/VIPSeg_720P/panoptic_gt_'
38 | 'VIPSeg_val.json after running the conversion script')
39 |
40 | args = parser.parse_args()
41 | return args
42 |
43 |
44 | # constants
45 | n_classes = 124
46 | ignore_label = 255
47 | bit_shift = 16
48 |
49 |
50 | def eval_stq(submit_dir, truth_dir, pan_gt_json_file):
51 | output_dir = submit_dir
52 | if not os.path.isdir(submit_dir):
53 | print("%s doesn't exist" % submit_dir)
54 | if os.path.isdir(submit_dir) and os.path.isdir(truth_dir):
55 | if not os.path.exists(output_dir):
56 | os.makedirs(output_dir)
57 |
58 | pan_pred_json_file = os.path.join(submit_dir, 'pred.json')
59 | with open(pan_pred_json_file, 'r') as f:
60 | pred_jsons = json.load(f)
61 | with open(pan_gt_json_file, 'r') as f:
62 | gt_jsons = json.load(f)
63 |
64 | categories = gt_jsons['categories']
65 |
66 | thing_list_ = []
67 | for cate_ in categories:
68 | cat_id = cate_['id']
69 | isthing = cate_['isthing']
70 | if isthing:
71 | thing_list_.append(cat_id)
72 |
73 | stq_metric = numpy_stq.STQuality(n_classes, thing_list_, ignore_label, bit_shift, 2**24)
74 |
75 | pred_annos = pred_jsons['annotations']
76 | pred_j = {}
77 | for p_a in pred_annos:
78 | pred_j[p_a['video_id']] = p_a['annotations']
79 | gt_annos = gt_jsons['annotations']
80 | gt_j = {}
81 | for g_a in gt_annos:
82 | gt_j[g_a['video_id']] = g_a['annotations']
83 |
84 | pbar = tqdm(gt_jsons['videos'])
85 | for seq_id, video_images in enumerate(pbar):
86 | video_id = video_images['video_id']
87 | pbar.set_description(video_id)
88 |
89 | # print('processing video:{}'.format(video_id))
90 | gt_image_jsons = video_images['images']
91 | gt_js = gt_j[video_id]
92 | pred_js = pred_j[video_id]
93 | assert len(gt_js) == len(pred_js)
94 |
95 | gt_pans = []
96 | pred_pans = []
97 | for imgname_j in gt_image_jsons:
98 | imgname = imgname_j['file_name']
99 | image = np.array(Image.open(os.path.join(submit_dir, 'pan_pred', video_id, imgname)))
100 | pred_pans.append(image)
101 | image = np.array(Image.open(os.path.join(truth_dir, video_id, imgname)))
102 | gt_pans.append(image)
103 | gt_id_to_ins_num_dic = {}
104 | list_tmp = []
105 | for segm in gt_js:
106 | for img_info in segm['segments_info']:
107 | id_tmp_ = img_info['id']
108 | if id_tmp_ not in list_tmp:
109 | list_tmp.append(id_tmp_)
110 | for ii, id_tmp_ in enumerate(list_tmp):
111 | gt_id_to_ins_num_dic[id_tmp_] = ii
112 |
113 | pred_id_to_ins_num_dic = {}
114 | list_tmp = []
115 | for segm in pred_js:
116 | for img_info in segm['segments_info']:
117 | id_tmp_ = img_info['id']
118 | if id_tmp_ not in list_tmp:
119 | list_tmp.append(id_tmp_)
120 | for ii, id_tmp_ in enumerate(list_tmp):
121 | pred_id_to_ins_num_dic[id_tmp_] = ii
122 |
123 | for i, (gt_json, pred_json, gt_pan, pred_pan, gt_image_json) in enumerate(
124 | list(zip(gt_js, pred_js, gt_pans, pred_pans, gt_image_jsons))):
125 | #### Step1. Collect frame-level pan_gt, pan_pred, etc.
126 | gt_pan, pred_pan = np.uint32(gt_pan), np.uint32(pred_pan)
127 | pan_gt = gt_pan[:, :, 0] + gt_pan[:, :, 1] * 256 + gt_pan[:, :, 2] * 256 * 256
128 | pan_pred = pred_pan[:, :, 0] + pred_pan[:, :, 1] * 256 + pred_pan[:, :, 2] * 256 * 256
129 |
130 | ground_truth_instance = np.ones_like(pan_gt) * 255
131 | ground_truth_semantic = np.ones_like(pan_gt) * 255
132 | for el in gt_json['segments_info']:
133 | id_ = el['id']
134 | cate_id = el['category_id']
135 | ground_truth_semantic[pan_gt == id_] = cate_id
136 | ground_truth_instance[pan_gt == id_] = gt_id_to_ins_num_dic[id_]
137 |
138 | ground_truth = ((ground_truth_semantic << bit_shift) + ground_truth_instance)
139 |
140 | prediction_instance = np.ones_like(pan_pred) * 255
141 | prediction_semantic = np.ones_like(pan_pred) * 255
142 |
143 | for el in pred_json['segments_info']:
144 | id_ = el['id']
145 | cate_id = el['category_id']
146 | prediction_semantic[pan_pred == id_] = cate_id
147 | prediction_instance[pan_pred == id_] = pred_id_to_ins_num_dic[id_]
148 | prediction = ((prediction_semantic << bit_shift) + prediction_instance)
149 |
150 | stq_metric.update_state(ground_truth.astype(dtype=np.int32),
151 | prediction.astype(dtype=np.int32), seq_id)
152 | result = stq_metric.result()
153 | print('*' * 100)
154 | print('STQ : {}'.format(result['STQ']))
155 | print('AQ :{}'.format(result['AQ']))
156 | print('IoU:{}'.format(result['IoU']))
157 | print('STQ_per_seq')
158 | print(result['STQ_per_seq'])
159 | print('AQ_per_seq')
160 | print(result['AQ_per_seq'])
161 | print('ID_per_seq')
162 | print(result['ID_per_seq'])
163 | print('Length_per_seq')
164 | print(result['Length_per_seq'])
165 | print('*' * 100)
166 |
167 | with open(os.path.join(submit_dir, 'stq.txt'), 'w') as f:
168 | f.write(f'{result["STQ"]*100:.1f},{result["AQ"]*100:.1f},{result["IoU"]*100:.1f}\n')
169 |
170 |
171 | if __name__ == "__main__":
172 | args = parse_args()
173 | submit_dir = args.submit_dir
174 | truth_dir = args.truth_dir
175 | pan_gt_json_file = args.pan_gt_json_file
176 | eval_stq(submit_dir, truth_dir, pan_gt_json_file)
177 |
--------------------------------------------------------------------------------
/deva/vps_metrics/stuff_merging.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 |
4 | import json
5 | from argparse import ArgumentParser
6 | import numpy as np
7 | from PIL import Image
8 | from functools import partial
9 | from tqdm import tqdm
10 |
11 | from deva.utils.vipseg_categories import VIPSEG_CATEGORIES
12 | from deva.utils.pano_utils import IDPostprocessor, id_to_rgb
13 | from multiprocessing import Pool
14 | """
15 | Post-processing is done online, so technically it can be run with the original evaluation.
16 | But that introduces additional complexities that are just tailored for VPS/VIPSeg.
17 | So I make this part into a post-processing script.
18 |
19 | Specifically, it does the following:
20 | 1. For every "thing" segment, whenever its category changes, we give it a new object id.
21 | This is because the evaluation script assumes that objects with the same object id
22 | have the same category. This change aligns with the original formula in
23 | Video Panoptic Segmentation, CVPR 2020.
24 | 2. It stores a mapping table from "stuff" category to object id. Whenever we encounter a segment
25 | with a "stuff" category, we apply this mapping.
26 | """
27 |
28 |
29 | def process_single_video(vid_ann, input_path, output_path):
30 | video_id = vid_ann['video_id']
31 | video_output_annotation = []
32 | video_output = {'video_id': video_id, 'annotations': video_output_annotation}
33 | output_path = path.join(output_path, 'pan_pred', video_id)
34 | os.makedirs(output_path, exist_ok=True)
35 |
36 | converter = IDPostprocessor()
37 |
38 | for ann in vid_ann['annotations']:
39 | file_name = ann['file_name']
40 | segments_info = ann['segments_info']
41 | output_segments_info = []
42 | output_annotation = {'file_name': ann['file_name'], 'segments_info': output_segments_info}
43 | video_output_annotation.append(output_annotation)
44 |
45 | mask = np.array(
46 | Image.open(
47 | path.join(input_path, 'pan_pred', video_id,
48 | file_name.replace('.jpg', '.png')))).astype(np.int32)
49 | mask = mask[:, :, 0] + mask[:, :, 1] * 256 + mask[:, :, 2] * 256 * 256
50 | output_mask = np.zeros_like(mask)
51 |
52 | for segment in segments_info:
53 | id = segment['id']
54 | category_id = segment['category_id']
55 | isthing = vipseg_cat_to_isthing[category_id]
56 | new_id = converter.convert(id, category_id, isthing)
57 | output_mask[mask == id] = new_id
58 |
59 | if isthing:
60 | # is not stuff
61 | output_segment = {
62 | 'id': new_id,
63 | 'category_id': segment['category_id'],
64 | 'isthing': 1,
65 | }
66 | output_segments_info.append(output_segment)
67 |
68 | # a pass for the merged stuff objects
69 | for cat, new_id in converter.stuff_to_id.items():
70 | area = int((output_mask == new_id).sum())
71 | assert not vipseg_cat_to_isthing[cat]
72 | if area > 0:
73 | output_segment = {
74 | 'id': new_id,
75 | 'category_id': cat,
76 | 'isthing': 0,
77 | }
78 | output_segments_info.append(output_segment)
79 |
80 | # save the new output mask
81 | output_mask = id_to_rgb(output_mask)
82 | output_mask = Image.fromarray(output_mask)
83 | output_mask.save(path.join(output_path, file_name.replace('.jpg', '.png')))
84 |
85 | return video_output
86 |
87 |
88 | vipseg_cat_to_isthing = {d['id']: d['isthing'] == 1 for d in VIPSEG_CATEGORIES}
89 |
90 |
91 | def merge_stuff(input_path, output_path):
92 | with open(path.join(input_path, 'pred.json')) as f:
93 | annotations = json.load(f)['annotations']
94 |
95 | output_annotations = []
96 | pool = Pool(16)
97 | for out_vid_ann in tqdm(pool.imap(
98 | partial(process_single_video, input_path=input_path, output_path=output_path),
99 | annotations),
100 |                      total=len(annotations)):
101 | output_annotations.append(out_vid_ann)
102 |
103 | output_json = {'annotations': output_annotations}
104 | with open(path.join(output_path, 'pred.json'), 'w') as f:
105 | json.dump(output_json, f)
106 |
107 |
108 | if __name__ == '__main__':
109 | parser = ArgumentParser()
110 | parser.add_argument('--input_path')
111 | parser.add_argument('--output_path')
112 | args = parser.parse_args()
113 |
114 | input_path = args.input_path
115 | output_path = args.output_path
116 |
117 | merge_stuff(input_path, output_path)
118 |
--------------------------------------------------------------------------------
/docs/CUSTOM.md:
--------------------------------------------------------------------------------
1 | # Using DEVA with Your Custom Detection Models
2 |
3 | There are two main ways to use your own detection models:
4 | 1. Online integration. Use a single script that queries the image detection model when needed. This is how the demo works (by querying Grounded Segment Anything or Segment Anything).
5 | 2. Offline integration. Run the image detection model on all frames beforehand and save the results. Then, run DEVA on the saved results. This is how we evaluate DEVA with the benchmark detection models.
6 |
7 | From an algorithmic point of view, both approaches are online/semi-online; they differ only in implementation.
8 |
9 | For (1), look at `demo/demo_automatic` and `demo/demo_with_text`.
10 |
11 | For (2), generate the detections following the data format in `example/vipseg`. There is a json file associated with every segmentation that contains object IDs and meta information. "category_id" and "score" can be optionally included in the json. Then follow the "demo" command listed in [EVALUATION.md](EVALUATION.md).
12 |
13 | You can also "mock" your data as BURST/VIPSeg and run the models as if you are evaluating on BURST/VIPSeg.
14 |
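15 | For reference, a minimal sketch of saving one frame's detection in the format described above might look like the following. The JSON keys and the RGB-encoded ID mask are assumptions for illustration only; treat the files in `example/vipseg/source` as the authoritative format.
16 | 
17 | ```python
18 | import json
19 | import numpy as np
20 | from PIL import Image
21 | from deva.utils.pano_utils import id_to_rgb
22 | 
23 | 
24 | def save_frame_detection(id_mask: np.ndarray, segments_info: list, frame_name: str, out_dir: str):
25 |     # id_mask: 2D (H, W) integer array; every pixel holds the ID of the object it belongs to (0 = background)
26 |     # segments_info: one dict per object, e.g. {'id': 1, 'category_id': 5, 'score': 0.9};
27 |     #                'category_id' and 'score' are optional, as noted above
28 |     # NOTE: the exact JSON schema and PNG encoding are assumptions -- check example/vipseg/source
29 |     Image.fromarray(id_to_rgb(id_mask)).save(f'{out_dir}/{frame_name}.png')
30 |     with open(f'{out_dir}/{frame_name}.json', 'w') as f:
31 |         json.dump({'file_name': f'{frame_name}.png', 'segments_info': segments_info}, f, indent=2)
32 | ```
33 | 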
--------------------------------------------------------------------------------
/docs/DEMO.md:
--------------------------------------------------------------------------------
1 | # Details on the Demo
2 |
3 | ## Pipeline
4 |
5 | The high-level description below is for the online setting. In the semi-online setting, the detections are first merged across a small clip.
6 | The first frame is always initialized from the detections, without propagation.
7 |
8 | ### Text-prompted mode (recommended)
9 | 1. DEVA propagates masks from memory to the current frame
10 | 2. If this is a detection frame, go to the next step. Otherwise, no further processing is needed for this frame.
11 | 3. Grounding DINO takes the text prompt and generates some bounding boxes
12 | 4. Segment Anything takes the bounding boxes and generates corresponding segmentation masks
13 | 5. The propagated masks are compared to and merged with the segmentation from Segment Anything
14 |
15 | ### Automatic mode
16 | 1. DEVA propagates masks from memory to the current frame.
17 | 2. If this is a detection frame, go to the next step. Otherwise, no further processing is needed for this frame.
18 | 3. We generate a grid of points on the unsegmented regions.
19 | 4. Segment Anything takes the points and generates corresponding segmentation masks.
20 | 5. The propagated masks are compared to and merged with the segmentation from Segment Anything.
21 |
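Putting the steps together, one online step of either mode looks roughly like the sketch below. `memory`, `image_model`, and `merge` are placeholders for the propagation memory, the detection model (Grounding DINO + SAM, or automatic SAM), and the merging step; this is not the actual DEVA API.

```python
def online_step(frame, frame_index, memory, image_model, merge, detection_every):
    # `memory`, `image_model`, and `merge` are placeholders, not the actual DEVA API
    propagated = memory.propagate(frame)     # 1. propagate masks from memory
    if frame_index % detection_every != 0:   # 2. not a detection frame:
        return propagated                    #    the propagated masks are the output
    detections = image_model(frame)          # 3-4. DINO boxes / point grid -> SAM masks
    merged = merge(propagated, detections)   # 5. compare and merge,
    memory.add(frame, merged)                #    then update memory with the merged result
    return merged
```
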
22 | ## Tips on Speeding up Inference
23 |
24 | **General Tips:**
25 |
26 | - Though it looks innocuous, reading frames from disk, visualizing the output, and encoding the output as videos can be slow, especially at high resolutions. The script version runs faster than the gradio version because it uses threaded I/O.
27 | - Specifying `--amp` (automatic mixed precision) makes things run faster on most modern GPUs.
28 | - In general, text-prompted inference is faster and more robust than "automatic" inference.
29 | - To speed up the actual processing, we need to speed up either the image model or the propagation model.
30 |
31 | **Speed up the image model:**
32 |
33 | - The most efficient way is to use the image model less often. This can be achieved by:
34 | - Using `online` instead of `semionline`, or,
35 | - Increasing `detection_every`.
36 | - Use a faster image model. For example, Mobile-SAM is faster than SAM. Grounded-Segment-Anything (text-prompt) is faster than automatic SAM. In automatic mode, you can reduce the number of prompting points (`SAM_NUM_POINTS_PER_SIDE`) to reduce the number of queries to SAM.
37 | - In automatic mode, increasing `SAM_NUM_POINTS_PER_BATCH` improves parallelism.
38 |
39 | **Speeding up the propagation model:**
40 |
41 | - In general, the running time of the propagation model scales linearly with the number of objects (linearly, not directly proportionally, since there is also a fixed per-frame cost). The best strategy is thus to reduce the number of objects:
42 |   - Using a text prompt typically generates more relevant objects and fewer objects overall.
43 | - Increasing the thresholds `SAM_PRED_IOU_THRESHOLD` or `DINO_THRESHOLD` reduces the number of detected objects.
44 | - Reduce `max_missed_detection_count` to delete objects more readily.
45 | - In automatic mode, enable `suppress_small_objects` to get larger and fewer segments. Note this option has its own overhead.
46 | - Reduce the internal processing resolution `size`. Note this does not affect the image model.
47 | - Increasing `chunk_size` improves parallelism.
48 |
49 | ## Explanation of arguments
50 |
51 | General:
52 | - `detection_every`: number of frames between two consecutive detections; a higher number means faster inference but slower responses to new objects
53 | - `amp`: enable automatic mixed precision; faster inference with lower memory usage
54 | - `chunk_size`: number of objects to be processed in parallel; a higher number means faster inference but higher memory usage
55 | - `size`: internal processing resolution for the propagation module; defaults to 480
56 | - `max_missed_detection_count`: maximum number of consecutive detections that can be missed before an object is deleted from memory.
57 | - `max_num_objects`: maximum number of objects that can be tracked at the same time; new objects are ignored if this is exceeded
58 |
59 | Text-prompted mode only:
60 | - `DINO_THRESHOLD`: threshold for DINO to consider a detection as valid
61 | - `prompt`: the text prompt to use; separate object classes with a full stop, e.g., "people.trees". The wording of the prompt and minor details like pluralization might affect the results.
62 |
63 | Automatic mode only:
64 | - `SAM_NUM_POINTS_PER_SIDE`: number of points per side to use for automatic grid-based prompting in SAM
65 | - `SAM_NUM_POINTS_PER_BATCH`: number of point prompts to process in parallel in SAM
66 | - `SAM_PRED_IOU_THRESHOLD`: threshold of predicted IoU to be considered as a valid segmentation for SAM
67 | - `suppress_small_objects`: if enabled, small objects that overlap with large objects are suppressed in automatic mode; has no effect in text-prompted mode
68 | - `SAM_OVERLAP_THRESHOLD`: if `suppress_small_objects` is enabled, this is the IoU threshold used for suppression; a lower threshold means more segmentation masks (less suppression)
69 |
70 | ## Source videos
71 |
72 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/4a00cd0d-f712-447f-82c4-6152addffd6b
73 |
74 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/c556d398-44dd-423b-9ff3-49763eaecd94
75 |
76 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/72e7495c-d5f9-4a8b-b7e8-8714b269e98d
77 |
78 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/337dd073-07eb-4392-9610-c5f6c6b94832
79 |
80 | https://github.com/hkchengrex/Tracking-Anything-with-DEVA/assets/7107196/e5f6df87-9fd0-4178-8490-00c4b8dc613b
81 |
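82 | ## Detection scheduling sketch (illustrative)
83 | 
84 | The snippet below is not the actual demo code. It is a minimal, self-contained sketch of how the pipeline steps and the `detection_every` / `max_num_objects` arguments described above interact in the online setting; `propagate`, `detect`, and `merge` are placeholder functions standing in for DEVA propagation, the image model (Grounding DINO + SAM, or automatic SAM), and the merging step.
85 | 
86 | ```python
87 | # Illustrative only: the helpers below are placeholders, not DEVA's actual API.
88 | detection_every = 5     # frames between two consecutive detections
89 | max_num_objects = 50    # new objects are ignored beyond this cap
90 | tracked_objects = []
91 | 
92 | def propagate(frame):
93 |     # stands in for DEVA propagating masks from memory to the current frame
94 |     return {obj: f'mask_of_{obj}' for obj in tracked_objects}
95 | 
96 | def detect(frame):
97 |     # stands in for Grounding DINO + SAM (text-prompted) or point prompts + SAM (automatic)
98 |     return [len(tracked_objects) + 1]    # pretend one new object was found
99 | 
100 | def merge(propagated, detections):
101 |     # stands in for comparing and merging the propagated masks with the new detections
102 |     for obj in detections:
103 |         if len(tracked_objects) < max_num_objects:
104 |             tracked_objects.append(obj)    # objects beyond the cap are ignored
105 | 
106 | for ti, frame in enumerate(range(20)):     # pretend video frames
107 |     propagated = propagate(frame)
108 |     if ti % detection_every == 0:          # detection frame; the first frame always is
109 |         merge(propagated, detect(frame))
110 | print(f'{len(tracked_objects)} objects tracked')
111 | ```
112 | 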
--------------------------------------------------------------------------------
/docs/TRAINING.md:
--------------------------------------------------------------------------------
1 | # Training DEVA
2 |
3 | Note that this repository only supports the training of the temporal propagation module. For the image module, please refer to the individual projects.
4 |
5 | ## Setting Up Data
6 |
7 | We put datasets out-of-source, as in XMem. You do not need BL30K. The directory structure should look like this:
8 | ```bash
9 | ├── Tracking-Anything-with-DEVA
10 | ├── DAVIS
11 | │ ├── 2016
12 | │ │ ├── Annotations
13 | │ │ └── ...
14 | │ └── 2017
15 | │ ├── test-dev
16 | │ │ ├── Annotations
17 | │ │ └── ...
18 | │ └── trainval
19 | │ ├── Annotations
20 | │ └── ...
21 | ├── static
22 | │ ├── BIG_small
23 | │ └── ...
24 | ├── YouTube
25 | │   ├── all_frames
26 | │   │   └── valid_all_frames
27 | │   ├── train
28 | │   └── valid
29 | └── OVIS-VOS-train
30 |     ├── JPEGImages
31 |     └── Annotations
32 | ```
33 |
34 | You can try our script `python -m scripts.download_datasets`, which might not always work due to Google Drive's blocking of automated downloads. If it fails, please download the datasets manually; the links can be found in the script. (A small sanity check for the final layout is sketched at the end of this document.)
35 |
36 | To generate OVIS-VOS-train, use something like https://github.com/youtubevos/vis2vos or download our preprocessed version from https://drive.google.com/uc?id=1AZPyyqVqOl6j8THgZ1UdNJY9R1VGEFrX.
37 |
38 | ## Training Command
39 | The training command is the same as in XMem. We have trained with either 4 or 8 GPUs.
40 | With 8 GPUs,
41 | ```
42 | python -m torch.distributed.run --master_port 25763 --nproc_per_node=8 deva/train.py --exp_id deva_retrain --stage 03
43 | ```
44 | - Change `nproc_per_node` to change the number of GPUs.
45 | - Prepend `CUDA_VISIBLE_DEVICES=...` if you want to use specific GPUs.
46 | - Change `master_port` if you encounter port collision.
47 | - `exp_id` is a unique experiment identifier that does not affect how the training is done.
48 | - Models will be saved in `./saves/`.
49 | - We simply use the last trained model without model selection.
50 |
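51 | As a final sanity check before launching training, the sketch below verifies that the dataset layout described in "Setting Up Data" is in place. It simply mirrors the directory tree above; run it from the Tracking-Anything-with-DEVA root and adjust the list if your layout differs.
52 | 
53 | ```python
54 | # Checks for the out-of-source dataset directories listed in "Setting Up Data".
55 | from pathlib import Path
56 | 
57 | expected = [
58 |     '../DAVIS/2016/Annotations',
59 |     '../DAVIS/2017/trainval/Annotations',
60 |     '../DAVIS/2017/test-dev/Annotations',
61 |     '../static/BIG_small',
62 |     '../YouTube/train',
63 |     '../YouTube/valid',
64 |     '../YouTube/all_frames/valid_all_frames',
65 |     '../OVIS-VOS-train/JPEGImages',
66 |     '../OVIS-VOS-train/Annotations',
67 | ]
68 | 
69 | missing = [p for p in expected if not Path(p).is_dir()]
70 | if missing:
71 |     print('Missing dataset directories:', *missing, sep='\n  ')
72 | else:
73 |     print('All expected dataset directories were found.')
74 | ```
75 | 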
--------------------------------------------------------------------------------
/docs/style.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: 'Source Sans 3', sans-serif;
3 | font-size: 18px;
4 | margin-left: auto;
5 | margin-right: auto;
6 | font-weight: 300;
7 | height: 100%;
8 | max-width: 1000px;
9 | }
10 |
11 | .light {
12 | font-weight: 100;
13 | }
14 |
15 | .heavy {
16 | font-weight: 400;
17 | }
18 |
19 | .column {
20 | float: left;
21 | }
22 |
23 | .metric_table {
24 | border-collapse: collapse;
25 | margin-left: 15px;
26 | margin-right: auto;
27 | }
28 |
29 | .metric_table th {
30 | border-bottom: 1px solid #555;
31 | padding-left: 15px;
32 | padding-right: 15px;
33 | }
34 |
35 | .metric_table td {
36 | padding-left: 15px;
37 | padding-right: 15px;
38 | }
39 |
40 | .metric_table .left_align {
41 | text-align: left;
42 | }
43 |
44 | a:link,
45 | a:visited {
46 | color: #05538f;
47 | text-decoration: none;
48 | }
49 |
50 | a:hover {
51 | color: #63cbdd;
52 | }
53 |
54 | hr {
55 | border: 0;
56 | height: 1px;
57 | background-image: linear-gradient(to right, rgba(0, 0, 0, 0), rgba(0, 0, 0, 0.75), rgba(0, 0, 0, 0));
58 | }
--------------------------------------------------------------------------------
/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/evaluation/__init__.py
--------------------------------------------------------------------------------
/evaluation/eval_ref_davis.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | from argparse import ArgumentParser
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 | from tqdm import tqdm
8 |
9 | from deva.inference.data.referring_test_datasets import ReferringDAVISTestDataset
10 | from deva.inference.image_feature_store import ImageFeatureStore
11 | from deva.inference.inference_core import DEVAInferenceCore
12 | from deva.inference.consensus_associated import find_consensus_with_established_association
13 | from deva.utils.palette import davis_palette
14 | from deva.inference.result_utils import ResultSaver
15 | from deva.inference.eval_args import add_common_eval_args, get_model_and_config
16 |
17 |
18 | def main():
19 | """
20 | Arguments loading
21 | """
22 | parser = ArgumentParser()
23 | parser.add_argument('--img_path', default='../DAVIS/2017/trainval/JPEGImages/480p')
24 | parser.add_argument('--mask_path')
25 | parser.add_argument('--num_voting_frames',
26 | default=5,
27 | type=int,
28 | help='Number of frames selected for the initial consensus voting')
29 | add_common_eval_args(parser)
30 | network, config, args = get_model_and_config(parser)
31 | """
32 | Data preparation
33 | """
34 | out_path = args.output
35 | meta_dataset = ReferringDAVISTestDataset(args.img_path, args.mask_path)
36 | torch.autograd.set_grad_enabled(False)
37 |
38 | videos = meta_dataset.get_videos()
39 |
40 | total_process_time = 0
41 | total_frames = 0
42 |
43 | # Start eval
44 | pbar = tqdm(videos, total=len(videos))
45 | for vid_name in pbar:
46 | pbar.set_description(vid_name)
47 | video_scores = meta_dataset.get_scores(vid_name)
48 | try:
49 | """
50 | initial pass, perform consensus voting and get a keyframe
51 | """
52 | image_feature_store = ImageFeatureStore(network)
53 | vid_reader = meta_dataset.get_offline_sampled_frames(vid_name,
54 | config['num_voting_frames'])
55 | loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2)
56 |
57 | time_indices = []
58 | images = []
59 | masks = []
60 | scores = []
61 | for ti, data in enumerate(loader):
62 | time_indices.append(data['info']['time_index'][0].item())
63 | image = data['rgb'].cuda()[0]
64 | mask = data['mask'].cuda()[0]
65 | images.append(image)
66 | masks.append(mask)
67 |
68 | frame_name = data['info']['frame'][0][:-4]
69 | scores.append(video_scores[frame_name])
70 |
71 | torch.cuda.synchronize()
72 | start = torch.cuda.Event(enable_timing=True)
73 | end = torch.cuda.Event(enable_timing=True)
74 | start.record()
75 | keyframe_ti, projected_mask = find_consensus_with_established_association(
76 | time_indices,
77 | images,
78 | masks,
79 | scores=scores,
80 | network=network,
81 | store=image_feature_store,
82 | config=config)
83 | end.record()
84 | torch.cuda.synchronize()
85 | total_process_time += (start.elapsed_time(end) / 1000)
86 | """
87 | Backward pass video reader
88 | """
89 | backward_vid_reader = meta_dataset.get_partial_video_loader(vid_name,
90 | start=-1,
91 | end=keyframe_ti + 1,
92 | reverse=True)
93 | """
94 | Forward pass video reader
95 | """
96 | forward_vid_reader = meta_dataset.get_partial_video_loader(vid_name,
97 | start=keyframe_ti,
98 | end=-1,
99 | reverse=False)
100 | """
101 | Running them in combination
102 | """
103 | vid_readers = [backward_vid_reader, forward_vid_reader]
104 | for vid_reader in vid_readers:
105 |
106 | loader = DataLoader(vid_reader, batch_size=1, shuffle=False, num_workers=2)
107 | vid_length = len(loader)
108 | # no need to count usage for LT if the video is not that long anyway
109 | config['enable_long_term_count_usage'] = (
110 | config['enable_long_term'] and
111 | (vid_length / (config['max_mid_term_frames'] - config['min_mid_term_frames']) *
112 | config['num_prototypes']) >= config['max_long_term_elements'])
113 |
114 | processor = DEVAInferenceCore(network,
115 | config=config,
116 | image_feature_store=image_feature_store)
117 | result_saver = ResultSaver(out_path,
118 | vid_name,
119 | dataset='ref_davis',
120 | palette=davis_palette,
121 | object_manager=processor.object_manager)
122 |
123 | for ti, data in enumerate(loader):
124 | with torch.cuda.amp.autocast(enabled=args.amp):
125 | image = data['rgb'].cuda()[0]
126 | info = data['info']
127 | frame = info['frame'][0]
128 | shape = info['shape']
129 | need_resize = info['need_resize'][0]
130 | image_ti = info['time_index'][0].item()
131 |
132 | if image_ti == keyframe_ti:
133 | mask = projected_mask
134 | else:
135 | mask = None
136 |
137 | start = torch.cuda.Event(enable_timing=True)
138 | end = torch.cuda.Event(enable_timing=True)
139 | start.record()
140 |
141 | # Run the model on this frame
142 | prob = processor.step(image,
143 | mask,
144 | end=(ti == vid_length - 1),
145 | hard_mask=False,
146 | image_ti_override=image_ti)
147 |
148 | end.record()
149 | torch.cuda.synchronize()
150 | total_process_time += (start.elapsed_time(end) / 1000)
151 | total_frames += 1
152 |
153 | result_saver.save_mask(prob, frame, need_resize=need_resize, shape=shape)
154 |
155 | result_saver.end()
156 | with open(path.join(out_path, vid_name, 'key.txt'), 'w') as f:
157 | f.write(f'options: {time_indices}; keyframe: {keyframe_ti}')
158 |
159 | except Exception as e:
160 | print(f'Runtime error at {vid_name}')
161 | print(e)
162 | raise e
163 |
164 | print(f'Total processing time: {total_process_time}')
165 | print(f'Total processed frames: {total_frames}')
166 | print(f'FPS: {total_frames / total_process_time}')
167 | print(f'Max allocated memory (MB): {torch.cuda.max_memory_allocated() / (2**20)}')
168 |
169 |
170 |
171 | if __name__ == '__main__':
172 | main()
173 |
--------------------------------------------------------------------------------
/example/vipseg/images/12_1mWNahzcsAc/00001255.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/images/12_1mWNahzcsAc/00001255.jpg
--------------------------------------------------------------------------------
/example/vipseg/images/12_1mWNahzcsAc/00001258.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/images/12_1mWNahzcsAc/00001258.jpg
--------------------------------------------------------------------------------
/example/vipseg/images/12_1mWNahzcsAc/00001261.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/images/12_1mWNahzcsAc/00001261.jpg
--------------------------------------------------------------------------------
/example/vipseg/images/12_1mWNahzcsAc/00001264.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/images/12_1mWNahzcsAc/00001264.jpg
--------------------------------------------------------------------------------
/example/vipseg/source/12_1mWNahzcsAc/00001255.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "id": 1,
4 | "isthing": false,
5 | "category_id": 7
6 | },
7 | {
8 | "id": 2,
9 | "isthing": true,
10 | "category_id": 63
11 | },
12 | {
13 | "id": 3,
14 | "isthing": false,
15 | "category_id": 66
16 | },
17 | {
18 | "id": 4,
19 | "isthing": true,
20 | "category_id": 60
21 | },
22 | {
23 | "id": 5,
24 | "isthing": false,
25 | "category_id": 14
26 | },
27 | {
28 | "id": 6,
29 | "isthing": true,
30 | "category_id": 63
31 | },
32 | {
33 | "id": 7,
34 | "isthing": true,
35 | "category_id": 60
36 | }
37 | ]
--------------------------------------------------------------------------------
/example/vipseg/source/12_1mWNahzcsAc/00001255.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/source/12_1mWNahzcsAc/00001255.png
--------------------------------------------------------------------------------
/example/vipseg/source/12_1mWNahzcsAc/00001258.json:
--------------------------------------------------------------------------------
1 | [{"id": 1, "isthing": true, "category_id": 63}, {"id": 2, "isthing": false, "category_id": 66}, {"id": 3, "isthing": false, "category_id": 0}, {"id": 4, "isthing": true, "category_id": 63}, {"id": 5, "isthing": true, "category_id": 60}, {"id": 6, "isthing": false, "category_id": 14}, {"id": 7, "isthing": false, "category_id": 7}, {"id": 8, "isthing": true, "category_id": 60}]
--------------------------------------------------------------------------------
/example/vipseg/source/12_1mWNahzcsAc/00001258.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/source/12_1mWNahzcsAc/00001258.png
--------------------------------------------------------------------------------
/example/vipseg/source/12_1mWNahzcsAc/00001261.json:
--------------------------------------------------------------------------------
1 | [{"id": 1, "isthing": false, "category_id": 7}, {"id": 2, "isthing": true, "category_id": 63}, {"id": 3, "isthing": true, "category_id": 63}, {"id": 4, "isthing": false, "category_id": 66}, {"id": 5, "isthing": true, "category_id": 60}, {"id": 6, "isthing": true, "category_id": 60}, {"id": 7, "isthing": false, "category_id": 14}]
--------------------------------------------------------------------------------
/example/vipseg/source/12_1mWNahzcsAc/00001261.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/source/12_1mWNahzcsAc/00001261.png
--------------------------------------------------------------------------------
/example/vipseg/source/12_1mWNahzcsAc/00001264.json:
--------------------------------------------------------------------------------
1 | [{"id": 1, "isthing": false, "category_id": 7}, {"id": 2, "isthing": true, "category_id": 63}, {"id": 3, "isthing": true, "category_id": 63}, {"id": 4, "isthing": false, "category_id": 66}, {"id": 5, "isthing": false, "category_id": 0}, {"id": 6, "isthing": true, "category_id": 60}, {"id": 7, "isthing": true, "category_id": 60}, {"id": 8, "isthing": false, "category_id": 14}]
--------------------------------------------------------------------------------
/example/vipseg/source/12_1mWNahzcsAc/00001264.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vipseg/source/12_1mWNahzcsAc/00001264.png
--------------------------------------------------------------------------------
/example/vos/Annotations/bmx-trees/00000.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/Annotations/bmx-trees/00000.png
--------------------------------------------------------------------------------
/example/vos/JPEGImages/bmx-trees/00000.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/JPEGImages/bmx-trees/00000.jpg
--------------------------------------------------------------------------------
/example/vos/JPEGImages/bmx-trees/00001.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/JPEGImages/bmx-trees/00001.jpg
--------------------------------------------------------------------------------
/example/vos/JPEGImages/bmx-trees/00002.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/JPEGImages/bmx-trees/00002.jpg
--------------------------------------------------------------------------------
/example/vos/JPEGImages/bmx-trees/00003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/example/vos/JPEGImages/bmx-trees/00003.jpg
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = ["hatchling"]
3 | build-backend = "hatchling.build"
4 |
5 | [tool.hatch.metadata]
6 | allow-direct-references = true
7 |
8 | [tool.yapf]
9 | based_on_style = "pep8"
10 | indent_width = 4
11 | column_limit = 100
12 |
13 | [project]
14 | name = "deva"
15 | version = "1.0.0"
16 | authors = [{ name = "Rex Cheng", email = "hkchengrex@gmail.com" }]
17 | description = "Tracking Anything with Decoupled Video Segmentation (DEVA), ICCV 2023"
18 | readme = "README.md"
19 | requires-python = ">=3.7"
20 | classifiers = [
21 | "Programming Language :: Python :: 3",
22 | "Operating System :: OS Independent",
23 | ]
24 | dependencies = [
25 | 'gitpython >= 3.1',
26 | 'thinplate@git+https://github.com/cheind/py-thin-plate-spline',
27 | 'hickle >= 5.0',
28 | 'tensorboard >= 2.12',
29 | 'numpy >= 1.22',
30 | 'Pillow >= 9.5',
31 | 'opencv-python >= 4.8',
32 | 'scipy >= 1.11.2',
33 | 'pycocotools >= 2.0.7',
34 | 'supervision >= 0.18',
35 | 'tqdm >= 4.66.1',
36 | 'gurobipy >= 10.0.3',
37 | 'PuLP >= 2.7',
38 | 'gradio >= 3.44',
39 | 'gdown >= 4.7.1',
40 | ]
41 |
42 | [tool.hatch.build.targets.wheel]
43 | packages = ["deva"]
44 |
45 | [project.urls]
46 | "Homepage" = "https://github.com/hkchengrex/Tracking-Anything-with-DEVA"
47 | "Bug Tracker" = "https://github.com/hkchengrex/Tracking-Anything-with-DEVA/issues"
48 |
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkchengrex/Tracking-Anything-with-DEVA/404a112df77f9644d5c7211811329ccd8174b8c3/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/download_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gdown
3 | import zipfile
4 |
5 | LICENSE = """
6 | These are either re-distribution of the original datasets or derivatives (through simple processing) of the original datasets.
7 | Please read and respect their licenses and terms before use.
8 | You should cite the original papers if you use any of the datasets.
9 |
10 | Links:
11 | DUTS: http://saliencydetection.net/duts
12 | HRSOD: https://github.com/yi94code/HRSOD
13 | FSS: https://github.com/HKUSTCV/FSS-1000
14 | ECSSD: https://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html
15 | BIG: https://github.com/hkchengrex/CascadePSP
16 |
17 | YouTubeVOS: https://youtube-vos.org
18 | DAVIS: https://davischallenge.org/
19 | """
20 |
21 | print(LICENSE)
22 | print('Datasets will be downloaded and extracted to ../YouTube, ../static, ../DAVIS')
23 | reply = input('[y] to confirm, others to exit: ')
24 | if reply != 'y':
25 | exit()
26 | """
27 | Static image data
28 | """
29 | os.makedirs('../static', exist_ok=True)
30 | print('Downloading static datasets...')
31 | gdown.download('https://drive.google.com/uc?id=1wUJq3HcLdN-z1t4CsUhjeZ9BVDb9YKLd',
32 | output='../static/static_data.zip',
33 | quiet=False)
34 | print('Extracting static datasets...')
35 | with zipfile.ZipFile('../static/static_data.zip', 'r') as zip_file:
36 | zip_file.extractall('../static/')
37 | print('Cleaning up static datasets...')
38 | os.remove('../static/static_data.zip')
39 | """
40 | DAVIS dataset
41 | """
42 | # Google drive mirror: https://drive.google.com/drive/folders/1hEczGHw7qcMScbCJukZsoOW4Q9byx16A?usp=sharing
43 | os.makedirs('../DAVIS/2017', exist_ok=True)
44 |
45 | print('Downloading DAVIS 2016...')
46 | gdown.download('https://drive.google.com/uc?id=198aRlh5CpAoFz0hfRgYbiNenn_K8DxWD',
47 | output='../DAVIS/DAVIS-data.zip',
48 | quiet=False)
49 |
50 | print('Downloading DAVIS 2017 trainval...')
51 | gdown.download('https://drive.google.com/uc?id=1kiaxrX_4GuW6NmiVuKGSGVoKGWjOdp6d',
52 | output='../DAVIS/2017/DAVIS-2017-trainval-480p.zip',
53 | quiet=False)
54 |
55 | print('Downloading DAVIS 2017 testdev...')
56 | gdown.download('https://drive.google.com/uc?id=1fmkxU2v9cQwyb62Tj1xFDdh2p4kDsUzD',
57 | output='../DAVIS/2017/DAVIS-2017-test-dev-480p.zip',
58 | quiet=False)
59 |
60 | print('Extracting DAVIS datasets...')
61 | with zipfile.ZipFile('../DAVIS/DAVIS-data.zip', 'r') as zip_file:
62 | zip_file.extractall('../DAVIS/')
63 | os.rename('../DAVIS/DAVIS', '../DAVIS/2016')
64 |
65 | with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-trainval-480p.zip', 'r') as zip_file:
66 | zip_file.extractall('../DAVIS/2017/')
67 | os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/trainval')
68 |
69 | with zipfile.ZipFile('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip', 'r') as zip_file:
70 | zip_file.extractall('../DAVIS/2017/')
71 | os.rename('../DAVIS/2017/DAVIS', '../DAVIS/2017/test-dev')
72 |
73 | print('Cleaning up DAVIS datasets...')
74 | os.remove('../DAVIS/2017/DAVIS-2017-trainval-480p.zip')
75 | os.remove('../DAVIS/2017/DAVIS-2017-test-dev-480p.zip')
76 | os.remove('../DAVIS/DAVIS-data.zip')
77 | """
78 | YouTubeVOS dataset
79 | """
80 | os.makedirs('../YouTube', exist_ok=True)
81 | os.makedirs('../YouTube/all_frames', exist_ok=True)
82 |
83 | print('Downloading YouTubeVOS train...')
84 | gdown.download('https://drive.google.com/uc?id=13Eqw0gVK-AO5B-cqvJ203mZ2vzWck9s4',
85 | output='../YouTube/train.zip',
86 | quiet=False)
87 | print('Downloading YouTubeVOS val...')
88 | gdown.download('https://drive.google.com/uc?id=1o586Wjya-f2ohxYf9C1RlRH-gkrzGS8t',
89 | output='../YouTube/valid.zip',
90 | quiet=False)
91 | print('Downloading YouTubeVOS all frames valid...')
92 | gdown.download('https://drive.google.com/uc?id=1rWQzZcMskgpEQOZdJPJ7eTmLCBEIIpEN',
93 | output='../YouTube/all_frames/valid.zip',
94 | quiet=False)
95 |
96 | print('Extracting YouTube datasets...')
97 | with zipfile.ZipFile('../YouTube/train.zip', 'r') as zip_file:
98 | zip_file.extractall('../YouTube/')
99 | with zipfile.ZipFile('../YouTube/valid.zip', 'r') as zip_file:
100 | zip_file.extractall('../YouTube/')
101 | with zipfile.ZipFile('../YouTube/all_frames/valid.zip', 'r') as zip_file:
102 | zip_file.extractall('../YouTube/all_frames')
103 |
104 | print('Cleaning up YouTubeVOS datasets...')
105 | os.remove('../YouTube/train.zip')
106 | os.remove('../YouTube/valid.zip')
107 | os.remove('../YouTube/all_frames/valid.zip')
108 | """
109 | OVIS dataset
110 | """
111 | os.makedirs('../OVIS-VOS-train', exist_ok=True)
112 | print('Downloading OVIS-VOS train...')
113 | gdown.download('https://drive.google.com/uc?id=1AZPyyqVqOl6j8THgZ1UdNJY9R1VGEFrX',
114 | output='../OVIS-VOS-train/OVIS-VOS-train.zip',
115 | quiet=False)
116 | print('Extracting OVIS...')
117 | with zipfile.ZipFile('../OVIS-VOS-train/OVIS-VOS-train.zip', 'r') as zip_file:
118 | zip_file.extractall('../OVIS-VOS-train/')
119 | print('Cleaning up OVIS...')
120 | os.remove('../OVIS-VOS-train/OVIS-VOS-train.zip')
121 |
122 | print('Done.')
--------------------------------------------------------------------------------
/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | wget -P ./saves/ https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/DEVA-propagation.pth
2 | wget -P ./saves/ https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
3 | wget -P ./saves/ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
4 | wget -O ./saves/sam_hq_vit_h.pth https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_h.pth?download=true
5 | wget -O ./saves/sam_hq_vit_tiny.pth https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth?download=true
6 | wget -P ./saves/ https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/mobile_sam.pt
7 | wget -P ./saves/ https://github.com/hkchengrex/Tracking-Anything-with-DEVA/releases/download/v1.0/GroundingDINO_SwinT_OGC.py
--------------------------------------------------------------------------------
/scripts/merge_burst_json.py:
--------------------------------------------------------------------------------
1 | import json
2 | import sys
3 | import os
4 | from os import path
5 | import tqdm
6 |
7 | gt_json_path = sys.argv[1]
8 | pred_path = sys.argv[2]
9 | out_path = sys.argv[3]
10 |
11 | with open(gt_json_path) as f:
12 | json_file = json.load(f)
13 |
14 | for sequence in tqdm.tqdm(json_file['sequences']):
15 | dataset = sequence['dataset']
16 | seq_name = sequence['seq_name']
17 |
18 | sequence['segmentations'] = []
19 |
20 | with open(path.join(pred_path, dataset, seq_name, 'pred.json')) as f:
21 | pred_json = json.load(f)
22 | track_category_id = {}
23 | for frame_segmentation in pred_json['segmentations']:
24 | this_frame_segmentation = {}
25 |
26 | for segmentation_dict in frame_segmentation['segmentations']:
27 | this_frame_segmentation[segmentation_dict['id']] = {
28 | 'rle': segmentation_dict['rle']['counts']
29 | }
30 | track_category_id[segmentation_dict['id']] = 0
31 | sequence['segmentations'].append(this_frame_segmentation)
32 |
33 | sequence['track_category_ids'] = track_category_id
34 |
35 |
36 | with open(out_path, 'w') as f:
37 | json.dump(json_file, f)
38 |
--------------------------------------------------------------------------------
/scripts/merge_multi_scale.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os import path
3 | from argparse import ArgumentParser
4 | import glob
5 | from collections import defaultdict
6 |
7 | import numpy as np
8 | import hickle as hkl
9 | from PIL import Image, ImagePalette
10 |
11 | from tqdm import tqdm
12 | from multiprocessing import Pool
13 | from deva.utils import palette
14 |
15 | from deva.utils.palette import davis_palette, youtube_palette
16 | import shutil
17 |
18 |
19 | def search_options(options, name):
20 | for option in options:
21 | if path.exists(path.join(option, name)):
22 | return path.join(option, name)
23 | else:
24 | return None
25 |
26 |
27 | def process_vid(vid):
28 | vid_path = search_options(all_options, vid)
29 | if vid_path is not None:
30 | backward_mapping = hkl.load(path.join(vid_path, 'backward.hkl'))
31 | else:
32 | backward_mapping = None
33 |
34 | frames = os.listdir(path.join(all_options[0], vid))
35 | frames = [f for f in frames if 'backward' not in f]
36 |
37 | print(vid)
38 | if 'Y' in args.dataset:
39 | this_out_path = path.join(out_path, 'Annotations', vid)
40 | else:
41 | this_out_path = path.join(out_path, vid)
42 | os.makedirs(this_out_path, exist_ok=True)
43 |
44 | for f in frames:
45 | result_sum = None
46 |
47 | for option in all_options:
48 | if not path.exists(path.join(option, vid, f)):
49 | continue
50 |
51 | result = hkl.load(path.join(option, vid, f))
52 | if result_sum is None:
53 | result_sum = result.astype(np.float32)
54 | else:
55 | result_sum += result
56 |
57 | # argmax and to idx
58 | result_sum = np.argmax(result_sum, axis=0)
59 |
60 | # Remap the indices to the original domain
61 | if backward_mapping is not None:
62 | idx_mask = np.zeros_like(result_sum, dtype=np.uint8)
63 | for l, i in backward_mapping.items():
64 | idx_mask[result_sum == i] = l
65 | else:
66 | idx_mask = result_sum.astype(np.uint8)
67 |
68 | # Save the results
69 | img_E = Image.fromarray(idx_mask)
70 | img_E.putpalette(palette)
71 | img_E.save(path.join(this_out_path, f[:-4] + '.png'))
72 |
73 |
74 | if __name__ == '__main__':
75 | """
76 | Arguments loading
77 | """
78 | parser = ArgumentParser()
79 | parser.add_argument('--dataset', default='Y', help='D/Y, D for DAVIS; Y for YouTubeVOS')
80 | parser.add_argument('--list', nargs="+")
81 | parser.add_argument('--pattern',
82 | default=None,
83 | help='Glob pattern. Can be used in place of list.')
84 | parser.add_argument('--output')
85 | parser.add_argument('--num_proc', default=4, type=int)
86 | args = parser.parse_args()
87 |
88 | out_path = args.output
89 |
90 | # Find the input candidates
91 | if args.pattern is None:
92 | all_options = args.list
93 | else:
94 | assert args.list is None, 'cannot specify both list and pattern'
95 | all_options = glob.glob(args.pattern)
96 |
97 | # Get the correct palette
98 | if 'D' in args.dataset:
99 | palette = ImagePalette.ImagePalette(mode='P', palette=davis_palette)
100 | elif 'Y' in args.dataset:
101 | palette = ImagePalette.ImagePalette(mode='P', palette=youtube_palette)
102 | else:
103 | raise NotImplementedError
104 |
105 | # Count of the number of videos in each candidate
106 | all_options = [path.join(o, 'Scores') for o in all_options]
107 | vid_count = defaultdict(int)
108 | for option in all_options:
109 | vid_in_here = sorted(os.listdir(option))
110 | for vid in vid_in_here:
111 | vid_count[vid] += 1
112 |
113 | all_vid = []
114 | count_to_vid = defaultdict(int)
115 | for k, v in vid_count.items():
116 | count_to_vid[v] += 1
117 | all_vid.append(k)
118 |
119 | for k, v in count_to_vid.items():
120 | print('Videos with count %d: %d' % (k, v))
121 |
122 | all_vid = sorted(all_vid)
123 | print('Total number of videos: ', len(all_vid))
124 |
125 | pool = Pool(processes=args.num_proc)
126 |     # tqdm uses `total=`, not progressbar's `max_value=`
127 |     for _ in tqdm(pool.imap_unordered(process_vid, all_vid), total=len(all_vid)):
127 | pass
128 |
129 | pool.close()
130 | pool.join()
131 |
132 | if 'D' in args.dataset:
133 | print('Making zip for DAVIS test-dev...')
134 | shutil.make_archive(args.output, 'zip', args.output)
135 |
136 | if 'Y' in args.dataset:
137 | print('Making zip for YouTubeVOS...')
138 | shutil.make_archive(path.join(args.output, path.basename(args.output)), 'zip', args.output,
139 | 'Annotations')
140 |
--------------------------------------------------------------------------------
/scripts/vipseg/change2_720p.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from multiprocessing import Pool
4 |
5 |
6 | DIR='imgs'
7 | DIR2='panomasks'
8 |
9 | Target_Dir = 'VIPSeg_720P'
10 |
11 |
12 | def change(DIR,video,image):
13 | if os.path.isfile(os.path.join(Target_Dir,'images',video,image)) and os.path.isfile(os.path.join(Target_Dir,'panomasks',video,image.split('.')[0]+'.png')):
14 | return
15 |
16 | img = Image.open(os.path.join(DIR,video,image))
17 | w,h = img.size
18 | img = img.resize((int(720*w/h),720),Image.BILINEAR)
19 |
20 | if not os.path.isfile(os.path.join(DIR2,video,image.split('.')[0]+'.png')):
21 | # print('this is the test set')
22 | # print(os.path.join(DIR2,video,image.split('.')[0]+'.png'))
23 | return
24 |
25 |
26 | mask = Image.open(os.path.join(DIR2,video,image.split('.')[0]+'.png'))
27 | mask = mask.resize((int(720*w/h),720),Image.NEAREST)
28 |
29 | if not os.path.exists(os.path.join(Target_Dir,'images',video)):
30 | os.makedirs(os.path.join(Target_Dir,'images',video))
31 | if not os.path.exists(os.path.join(Target_Dir,'panomasks',video)):
32 | os.makedirs(os.path.join(Target_Dir,'panomasks',video))
33 |
34 | img.save(os.path.join(Target_Dir,'images',video,image))
35 | mask.save(os.path.join(Target_Dir,'panomasks',video,image.split('.')[0]+'.png'))
36 | # print('Processing video {} image {}'.format(video,image))
37 |
38 | p = Pool(16)
39 | for video in sorted(os.listdir(DIR)):
40 | print(video)
41 | if video[0]=='.':
42 | continue
43 | for image in sorted(os.listdir(os.path.join(DIR,video))):
44 | if image[0]=='.':
45 | continue
46 | p.apply_async(change,args=(DIR,video,image))
47 | #change(DIR,video,image)
48 | p.close()
49 | p.join()
50 | print('finish')
51 |
52 |
--------------------------------------------------------------------------------
/scripts/vipseg/create_panoptic_video_labels.py:
--------------------------------------------------------------------------------
1 |
2 | import os, sys
3 | import json
4 | import glob
5 | import numpy as np
6 | import PIL.Image as Image
7 | from tqdm import trange
8 | from panopticapi.utils import IdGenerator, save_json
9 | import argparse
10 |
11 | from multiprocessing import Pool
12 |
13 | ROOT_DIR='VIPSeg_720P/panomasks'
14 | TARGET_DIR='VIPSeg_720P/panomasksRGB'
15 | with open('VIPSeg_720P/panoVIPSeg_categories.json','r') as f:
16 | CATEGORIES = json.load(f)
17 |
18 |
19 | if __name__ == "__main__":
20 |
21 | def conversion_worker(video):
22 | # print('processing video:{}'.format(video))
23 | videos_dic={}
24 | video_id = video
25 | videos_dic['video_id'] = video_id
26 |
27 | images = []
28 | annotations = []
29 | id_generator = IdGenerator(categories_dict)
30 | instid2color = {}
31 |
32 |
33 | for imgname in sorted(os.listdir(os.path.join(original_format_folder,video))):
34 |
35 | original_format = np.array(Image.open(os.path.join(original_format_folder,video,imgname)))
36 | image_id = imgname.split('.')[0]
37 | image_filename = imgname
38 | images.append({"id": image_id,
39 | "width": original_format.shape[1],
40 | "height": original_format.shape[0],
41 | "file_name": image_filename})
42 | pan_format = np.zeros((original_format.shape[0], original_format.shape[1], 3), dtype=np.uint8)
43 |
44 | l = np.unique(original_format)
45 |
46 | segm_info = {}
47 |
48 | for el in l:
49 | if el==0:
50 | continue
51 | if el < 125:
52 | semantic_id = el
53 | is_crowd = 0
54 | else:
55 | semantic_id = el // 100
56 | is_crowd = 0
57 | semantic_id = semantic_id -1
58 | if semantic_id not in categories_dict:
59 | print(semantic_id, video, l, imgname, list(categories_dict.keys()))
60 | if semantic_id > 255:
61 | print(semantic_id, video, l, imgname)
62 | if categories_dict[semantic_id]['isthing'] == 0:
63 | is_crowd = 0
64 | mask = (original_format == el)
65 |
66 | if el not in instid2color:
67 | segment_id, color = id_generator.get_id_and_color(semantic_id)
68 | instid2color[el] = (segment_id, color)
69 | else:
70 | segment_id, color = instid2color[el]
71 |
72 | pan_format[mask] = color
73 | segm_info[int(segment_id)] = \
74 | {"id": int(segment_id),
75 | "category_id": int(semantic_id),
76 | # "area": int(area),
77 | "iscrowd": is_crowd}
78 | if not os.path.exists(os.path.join(out_folder,video)):
79 | os.makedirs(os.path.join(out_folder,video))
80 |
81 | Image.fromarray(pan_format).save(os.path.join(out_folder,video, image_filename))
82 | # print('image saved {}'.format(os.path.join(out_folder,video, image_filename)))
83 |
84 | gt_pan = np.uint32(pan_format)
85 | pan_gt = gt_pan[:, :, 0] + gt_pan[:, :, 1] * 256 + gt_pan[:, :, 2] * 256 * 256
86 | # print(np.unique(pan_gt))
87 | # exit()
88 | labels, labels_cnt = np.unique(pan_gt, return_counts=True)
89 | gt_labels = [_ for _ in segm_info.keys()]
90 | gt_labels_set = set(gt_labels)
91 | for label, area in zip(labels, labels_cnt):
92 | if label == 0:
93 | continue
94 | if label not in gt_labels and label > 0:
95 | print('png label not in json labels.', label, video, gt_labels, l)
96 | segm_info[label]["area"] = int(area)
97 | gt_labels_set.remove(label)
98 | if len(gt_labels_set) != 0:
99 | raise KeyError('remaining gt_labels json')
100 |
101 | segm_info = [v for k,v in segm_info.items()]
102 | annotations.append({'image_id': image_id,
103 | 'file_name': image_filename,
104 | "segments_info": segm_info})
105 | v_anno ={'video_id':video_id,'annotations':annotations}
106 | videos_dic['images'] = images
107 | return v_anno, videos_dic
108 | # return None
109 |
110 |
111 | original_format_folder = ROOT_DIR
112 | # folder to store panoptic PNGs
113 | out_folder = os.path.join(TARGET_DIR)
114 | out_file = os.path.join('VIPSeg_720P/', 'panoptic_gt_VIPSeg.json')
115 | if not os.path.isdir(out_folder):
116 | os.makedirs(out_folder)
117 |
118 | categories = CATEGORIES
119 | categories_dict = {el['id']: el for el in CATEGORIES}
120 |
121 | v_videos=[]
122 | v_annotations=[]
123 | pool = Pool(16)
124 |
125 | results = pool.map(conversion_worker, sorted(os.listdir(original_format_folder)), chunksize=8)
126 | # results = [conversion_worker(p) for p in sorted(os.listdir(original_format_folder))]
127 |
128 | for v_anno, videos_dic in results:
129 | v_videos.append(videos_dic)
130 | v_annotations.append(v_anno)
131 |
132 | d = {'videos': v_videos,
133 | 'annotations': v_annotations,
134 | 'categories': categories,
135 | }
136 |
137 | save_json(d, out_file)
138 |
139 |
140 | print('==> Saved json file at %s'%(out_file))
141 |
--------------------------------------------------------------------------------