├── .gitignore ├── .gitmodules ├── INSTALL.md ├── MODEL_PREP.md ├── README.md ├── aot_example.ipynb ├── aot_timing.py ├── demo.py ├── grounding_SAM_example.ipynb ├── grounding_dino_example.ipynb ├── sam_example.ipynb ├── sam_timing.py ├── sample_data └── DAVIS_bear │ ├── images │ ├── 00000.jpg │ ├── 00001.jpg │ ├── 00002.jpg │ ├── 00003.jpg │ ├── 00004.jpg │ ├── 00005.jpg │ ├── 00006.jpg │ ├── 00007.jpg │ ├── 00008.jpg │ ├── 00009.jpg │ ├── 00010.jpg │ ├── 00011.jpg │ ├── 00012.jpg │ ├── 00013.jpg │ ├── 00014.jpg │ ├── 00015.jpg │ ├── 00016.jpg │ ├── 00017.jpg │ ├── 00018.jpg │ ├── 00019.jpg │ ├── 00020.jpg │ ├── 00021.jpg │ ├── 00022.jpg │ ├── 00023.jpg │ ├── 00024.jpg │ ├── 00025.jpg │ ├── 00026.jpg │ ├── 00027.jpg │ ├── 00028.jpg │ ├── 00029.jpg │ ├── 00030.jpg │ ├── 00031.jpg │ ├── 00032.jpg │ ├── 00033.jpg │ ├── 00034.jpg │ ├── 00035.jpg │ ├── 00036.jpg │ ├── 00037.jpg │ ├── 00038.jpg │ ├── 00039.jpg │ ├── 00040.jpg │ ├── 00041.jpg │ ├── 00042.jpg │ ├── 00043.jpg │ ├── 00044.jpg │ ├── 00045.jpg │ ├── 00046.jpg │ ├── 00047.jpg │ ├── 00048.jpg │ ├── 00049.jpg │ ├── 00050.jpg │ ├── 00051.jpg │ ├── 00052.jpg │ ├── 00053.jpg │ ├── 00054.jpg │ ├── 00055.jpg │ ├── 00056.jpg │ ├── 00057.jpg │ ├── 00058.jpg │ ├── 00059.jpg │ ├── 00060.jpg │ ├── 00061.jpg │ ├── 00062.jpg │ ├── 00063.jpg │ ├── 00064.jpg │ ├── 00065.jpg │ ├── 00066.jpg │ ├── 00067.jpg │ ├── 00068.jpg │ ├── 00069.jpg │ ├── 00070.jpg │ ├── 00071.jpg │ ├── 00072.jpg │ ├── 00073.jpg │ ├── 00074.jpg │ ├── 00075.jpg │ ├── 00076.jpg │ ├── 00077.jpg │ ├── 00078.jpg │ ├── 00079.jpg │ ├── 00080.jpg │ └── 00081.jpg │ └── init_mask │ └── 00000.png └── tracking_SAM ├── __init__.py ├── aott.py ├── plt_clicker.py └── tracking_SAM.py /.gitignore: -------------------------------------------------------------------------------- 1 | pretrained_weights/ 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # poetry 100 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 104 | #poetry.lock 105 | 106 | # pdm 107 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 108 | #pdm.lock 109 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 110 | # in version control. 111 | # https://pdm.fming.dev/#use-with-ide 112 | .pdm.toml 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/Pytorch-Correlation-extension"] 2 | path = tracking_SAM/third_party/Pytorch-Correlation-extension 3 | url = https://github.com/ClementPinard/Pytorch-Correlation-extension.git 4 | [submodule "third_party/aot-benchmark"] 5 | path = tracking_SAM/third_party/aot_benchmark 6 | url = https://github.com/yoxu515/aot-benchmark 7 | [submodule "tracking_SAM/third_party/GroundingDINO"] 8 | path = tracking_SAM/third_party/GroundingDINO 9 | url = https://github.com/IDEA-Research/GroundingDINO 10 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | # Prepare env 2 | 3 | ## Clone codebase 4 | 5 | The repo contains submodule. So clone needs to be recursive. 6 | 7 | ```bash 8 | git clone --recursive https://github.com/RogerQi/Tracking_SAM 9 | ``` 10 | 11 | ## Install dependencies 12 | 13 | Tested on Anaconda with mamba solver. 14 | 15 | **It's known that sometimes default solver behaves different than mamba. 
So be cautious.** 16 | 17 | ### General packages 18 | 19 | ```bash 20 | conda create -y -n tracking_SAM python=3.8 && conda activate tracking_SAM 21 | # you may need these two lines if you run into compile-time issues 22 | # conda install -c conda-forge gcc=10.3.0 --strict-channel-priority 23 | # conda install -c conda-forge cxx-compiler --strict-channel-priority 24 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.7 cudatoolkit=11.7 cudatoolkit-dev=11.7 -c pytorch -c conda-forge -c nvidia 25 | pip install opencv-python Pillow tqdm matplotlib 26 | ``` 27 | 28 | ### Compile PyTorch Correlation (for efficient VOS inference) 29 | 30 | ```bash 31 | cd tracking_SAM/third_party/Pytorch-Correlation-extension 32 | pip install . 33 | ``` 34 | 35 | ### Install SAM 36 | 37 | ```bash 38 | pip install git+https://github.com/facebookresearch/segment-anything.git 39 | ``` 40 | 41 | ### Install GroundingDINO 42 | 43 | ```bash 44 | cd tracking_SAM/third_party/GroundingDINO 45 | pip install . 46 | ``` 47 | -------------------------------------------------------------------------------- /MODEL_PREP.md: -------------------------------------------------------------------------------- 1 | # Pre-trained model weights 2 | 3 | Create a folder. All the pre-trained weights will be saved here. 4 | 5 | ```bash 6 | mkdir pretrained_weights 7 | cd pretrained_weights 8 | ``` 9 | 10 | ## Getting SAM weights 11 | 12 | ```bash 13 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 14 | ``` 15 | 16 | ## Getting GroundingDINO weights 17 | 18 | ```bash 19 | wget https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth 20 | ``` 21 | 22 | ## Getting AOT weights 23 | 24 | Get pre-trained AOT VOS weights from [here](https://drive.google.com/file/d/1owPmwV4owd_ll6GuilzklqTyAd0ZvbCu/view?usp=sharing) 25 | and save them to `./pretrained_weights`. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Tracking_SAM 2 | Language/Clicking grounded SAM + VOS for real-time video object tracking 3 | 4 | ## Prepare env 5 | 6 | See [INSTALL.md](INSTALL.md) to prepare the environment. 7 | 8 | Follow [MODEL_PREP.md](MODEL_PREP.md) to prepare the pre-trained model weights. 9 | 10 | ## Run main demo 11 | 12 | You can run the main tracking_SAM example on the sample data with: 13 | 14 | ```bash 15 | python demo.py 16 | ``` 17 | 18 | **NOTE: if you do nothing after launching the script, it will simply play the video.** 19 | 20 | For the tracking to happen, an initial frame needs to be annotated via clicking or language. 21 | 22 | ### Interactive clicking annotation 23 | 24 | Press 'a' on the keyboard to open the annotator. Inside the annotator, left-click with your mouse to add points. 25 | Press 'enter' once a satisfactory mask is generated, and mask tracking will start automatically. 26 | 27 | If you are unhappy with your clicks, you can also press `ctrl+r` to reset all annotation progress. 28 | 29 | ### Language-based detection 30 | 31 | Press 'd' on the keyboard to send a pre-defined language query (bear) to the model. GroundingDINO will generate a bounding box 32 | for SAM to refine, and mask tracking will start automatically. 33 | 34 | ## Check intermediate processes 35 | 36 | See the IPython notebooks in the root project directory.
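For quick reference, here is a minimal sketch of driving the tracker from Python rather than through the interactive demo. It mirrors the calls made in `demo.py`; the checkpoint paths assume the `MODEL_PREP.md` layout and the frames come from the bundled sample sequence, so adjust both if your files live elsewhere.

```python
import os
import numpy as np
from PIL import Image
import tracking_SAM

# Checkpoint paths follow MODEL_PREP.md; change them if you saved the weights elsewhere.
tracker = tracking_SAM.main_tracker(
    "./pretrained_weights/sam_vit_h_4b8939.pth",        # SAM
    "./pretrained_weights/AOTT_PRE_YTB_DAV.pth",        # AOT (VOS)
    "./pretrained_weights/groundingdino_swint_ogc.pth"  # GroundingDINO
)

image_dir = "./sample_data/DAVIS_bear/images"
frame_paths = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir))

first_rgb = np.array(Image.open(frame_paths[0]))
# Either open the interactive clicker (a blocking matplotlib window) ...
# tracker.annotate_init_frame(first_rgb)
# ... or ground the initial mask with a text query (GroundingDINO box -> SAM mask).
tracker.annotate_init_frame(first_rgb, method='dino', category_name='bear')

for path in frame_paths[1:]:
    rgb = np.array(Image.open(path))
    mask_hw = tracker.propagate_one_frame(rgb)  # (H, W) uint8 label map; 0 is background

tracker.reset_engine()  # clear VOS memory before annotating a new target
```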
37 | 38 | ## TODOs 39 | 40 | - [x] Add VOS 41 | - [x] Add SAM 42 | - [x] Add VOS+SAM 43 | - [x] Add Clicking 44 | - [x] Add Grounding DINO for languaged-conditioned mask generation 45 | - [ ] Support multiple objects 46 | - [ ] Save memory by loading models only when necessary and offloading when not used 47 | - [ ] Serialize used models to ONNX for deployment (easy access and potential TRT optimization) 48 | - [ ] Switch to FastSAM for faster inference 49 | - [ ] Add support for SAM2 50 | 51 | ## Citation 52 | 53 | If you find our tool useful in your research, please consider citing VBC (for which TrackingSAM was originally developed) 54 | 55 | ``` 56 | @inproceedings{liu2024-vbc, 57 | title={Visual Whole-Body Control for Legged Loco-Manipulation}, 58 | author={Liu, Minghuan and Chen, Zixuan and Cheng, Xuxin and Ji, Yandong and Qiu, Ri-Zhao and Yang, Ruihan and Wang, Xiaolong}, 59 | booktitle={CoRL}, 60 | year={2024} 61 | } 62 | ``` 63 | -------------------------------------------------------------------------------- /aot_timing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import time 5 | import torch 6 | from tqdm import tqdm, trange 7 | 8 | # Ours 9 | import tracking_SAM.aott 10 | 11 | @torch.no_grad() 12 | def time_video_seq(my_vos_tracker, test_data_base_dir, num_trials=100): 13 | # Call this to clear VOS internal states 14 | my_vos_tracker.reset_engine() 15 | # Load images 16 | image_paths_list = sorted([os.path.join(test_data_base_dir, 'images', x) for x in os.listdir(os.path.join(test_data_base_dir, 'images'))]) 17 | 18 | image_np_list = [np.array(Image.open(x)) for x in image_paths_list] 19 | 20 | # Load the initial provided mask 21 | init_mask_path = os.path.join(test_data_base_dir, 'init_mask', '00000.png') 22 | init_mask = np.array(Image.open(init_mask_path)).astype(np.uint8) 23 | 24 | init_mask[init_mask > 0] = 1 25 | 26 | all_mask_list = [init_mask] # The first mask is the initial mask 27 | 28 | reference_frame_time = [] 29 | for _ in trange(num_trials): 30 | start_cp = time.time() 31 | my_vos_tracker.add_reference_frame(image_np_list[0], init_mask) 32 | end_cp = time.time() 33 | reference_frame_time.append(end_cp - start_cp) 34 | 35 | print("Using {} trials".format(num_trials)) 36 | print('Reference frame time: {:.4f}. Std: {:.4f}'.format(np.mean(reference_frame_time), np.std(reference_frame_time))) 37 | 38 | # Reset reference frames; we will add them again 39 | my_vos_tracker.reset_engine() 40 | my_vos_tracker.add_reference_frame(image_np_list[0], init_mask) 41 | 42 | propage_frame_time = [] 43 | for i in trange(1, len(image_np_list)): 44 | start_cp = time.time() 45 | cur_frame_np = image_np_list[i] 46 | cur_mask_np = my_vos_tracker.propagate_one_frame(cur_frame_np) 47 | all_mask_list.append(cur_mask_np) 48 | end_cp = time.time() 49 | propage_frame_time.append(end_cp - start_cp) 50 | 51 | print('Propagate frame time: {:.4f}. 
Std: {:.4f}'.format(np.mean(propage_frame_time), np.std(propage_frame_time))) 52 | 53 | def main(): 54 | vos_weight_path = './pretrained_weights/AOTT_PRE_YTB_DAV.pth' 55 | 56 | my_vos_tracker = tracking_SAM.aott.aot_segmenter(vos_weight_path) 57 | 58 | test_folder_list = ['./sample_data/DAVIS_bear'] 59 | 60 | for folder in test_folder_list: 61 | time_video_seq(my_vos_tracker, folder) 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import cv2 5 | import time 6 | import argparse 7 | 8 | import tracking_SAM 9 | 10 | def main(sam_checkpoint, aot_checkpoint, grounding_dino_checkpoint, play_delay): 11 | 12 | test_data_base_dir = './sample_data/DAVIS_bear' 13 | 14 | # Load images 15 | image_paths_list = sorted([os.path.join(test_data_base_dir, 'images', x) for x in os.listdir(os.path.join(test_data_base_dir, 'images'))]) 16 | 17 | image_np_list = [np.array(Image.open(x)) for x in image_paths_list] 18 | 19 | my_tracking_SAM = tracking_SAM.main_tracker(sam_checkpoint, aot_checkpoint, grounding_dino_checkpoint) 20 | 21 | for i in range(len(image_np_list)): 22 | image_np_rgb = image_np_list[i] 23 | image_np_bgr = cv2.cvtColor(image_np_rgb, cv2.COLOR_RGB2BGR) 24 | 25 | cv2.imshow('Video', image_np_bgr) 26 | 27 | if my_tracking_SAM.is_tracking(): 28 | start_cp = time.time() 29 | pred_np_hw = my_tracking_SAM.propagate_one_frame(image_np_rgb) 30 | time_elapsed = time.time() - start_cp 31 | pred_np_hw = pred_np_hw.astype(np.uint8) 32 | pred_np_hw[pred_np_hw > 0] = 255 33 | 34 | viz_img = image_np_bgr.copy() 35 | # Alpha blending to add red mask 36 | red_overlay = np.dstack((np.zeros_like(pred_np_hw), np.zeros_like(pred_np_hw), pred_np_hw)) 37 | viz_img = cv2.addWeighted(viz_img, 0.5, red_overlay, 0.5, 0) 38 | # Show time_elapsed on the screen 39 | str_to_show = f'VOS Latency {time_elapsed:.2f} s' 40 | cv2.putText(viz_img, str_to_show, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2, cv2.LINE_AA) 41 | cv2.imshow('Tracked', viz_img) 42 | 43 | # Press Q on keyboard to exit 44 | key_pressed = cv2.waitKey(play_delay) & 0xFF 45 | if key_pressed == ord('q'): 46 | break 47 | elif key_pressed == ord('a'): 48 | if my_tracking_SAM.is_tracking(): 49 | my_tracking_SAM.reset_engine() 50 | my_tracking_SAM.annotate_init_frame(image_np_rgb) 51 | elif key_pressed == ord('d'): 52 | if my_tracking_SAM.is_tracking(): 53 | my_tracking_SAM.reset_engine() 54 | my_tracking_SAM.annotate_init_frame(image_np_rgb, method='dino', category_name='bear') 55 | 56 | cv2.destroyAllWindows() 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | 61 | parser.add_argument('--sam_checkpoint', type=str, default="./pretrained_weights/sam_vit_h_4b8939.pth") 62 | parser.add_argument('--aot_checkpoint', type=str, default="./pretrained_weights/AOTT_PRE_YTB_DAV.pth") 63 | parser.add_argument('--ground_dino_checkpoint', type=str, default="./pretrained_weights/groundingdino_swint_ogc.pth") 64 | 65 | # delay in ms for each image to stay on screen. Low values (e.g., 1) causes the video to pass by quickly. 
66 | parser.add_argument('--play_delay', type=int, default=200) 67 | 68 | args = parser.parse_args() 69 | 70 | main(args.sam_checkpoint, args.aot_checkpoint, args.ground_dino_checkpoint, args.play_delay) 71 | -------------------------------------------------------------------------------- /grounding_SAM_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "6aa12104", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "from PIL import Image\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "import torch\n", 14 | "import cv2" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": null, 20 | "id": "bb696373-fae1-45a4-a05e-c93d4e8cae6c", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import groundingdino.datasets.transforms as T\n", 25 | "from groundingdino.models import build_model\n", 26 | "from groundingdino.util import box_ops\n", 27 | "from groundingdino.util.slconfig import SLConfig\n", 28 | "from groundingdino.util.utils import clean_state_dict\n", 29 | "from groundingdino.util.inference import predict" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "id": "2e2cea10", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "config_file_path = './tracking_SAM/third_party/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py'\n", 40 | "model_path = './pretrained_weights/groundingdino_swint_ogc.pth'\n", 41 | "\n", 42 | "args = SLConfig.fromfile(config_file_path) \n", 43 | "device = 'cpu'\n", 44 | "\n", 45 | "dino_model = build_model(args)\n", 46 | "\n", 47 | "checkpoint = torch.load(model_path, map_location='cpu')\n", 48 | "log = dino_model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)\n", 49 | "dino_model.eval()\n", 50 | "dino_model = dino_model.to(device)\n", 51 | "\n", 52 | "print(log)" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "1a15e128", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "test_img_path = \"./sample_data/DAVIS_bear/images/00000.jpg\"\n", 63 | "image_np = np.asarray(Image.open(test_img_path).convert(\"RGB\"))\n", 64 | "\n", 65 | "transform = T.Compose(\n", 66 | " [\n", 67 | " T.RandomResize([800], max_size=1333),\n", 68 | " T.ToTensor(),\n", 69 | " T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n", 70 | " ]\n", 71 | ")\n", 72 | "\n", 73 | "img_chw, _ = transform(Image.fromarray(image_np), None)" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "id": "af08a8e5", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "text_prompt = 'bear'\n", 84 | "\n", 85 | "BOX_TRESHOLD = 0.3\n", 86 | "TEXT_TRESHOLD = 0.25" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "id": "89eabf0a", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "boxes, logits, phrases = predict(\n", 97 | " model=dino_model, \n", 98 | " image=img_chw, \n", 99 | " caption=text_prompt, \n", 100 | " box_threshold=BOX_TRESHOLD, \n", 101 | " text_threshold=TEXT_TRESHOLD,\n", 102 | " device=device\n", 103 | ")" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "3f4efed1", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "H, W, _ = image_np.shape\n", 114 | "\n", 115 | "boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * 
torch.Tensor([W, H, W, H])" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": null, 121 | "id": "80dbd1b0", 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "viz_img = image_np.copy()\n", 126 | "\n", 127 | "for box, phrase in zip(boxes_xyxy, phrases):\n", 128 | " box = box.cpu().numpy().astype(np.int32)\n", 129 | " viz_img = cv2.rectangle(viz_img, (box[0], box[1]), (box[2], box[3]), (0, 255, 0), 2)\n", 130 | " viz_img = cv2.putText(viz_img, phrase, (box[0], box[1]), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)\n", 131 | "\n", 132 | "plt.imshow(viz_img)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "ecb99a52-2643-4c4a-8d6c-95d65e34f71c", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "from segment_anything import sam_model_registry, SamPredictor\n", 143 | "\n", 144 | "sam_checkpoint = \"./pretrained_weights/sam_vit_h_4b8939.pth\" # default model\n", 145 | "\n", 146 | "model_type = \"vit_h\"\n", 147 | "\n", 148 | "device = \"cuda\"\n", 149 | "\n", 150 | "sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)\n", 151 | "sam.to(device=device)\n", 152 | "\n", 153 | "predictor = SamPredictor(sam)" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "3c2f7226-f4b0-4bf1-bf33-957c082d07ea", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "predictor.set_image(image_np)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": null, 169 | "id": "d93114d2-f464-4fcb-8ff5-b658e4a3271f", 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "assert len(boxes_xyxy) == 1" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "5092a9de-5b8b-489b-b828-0b38d43ed21c", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "input_box = boxes_xyxy[0].cpu().numpy()" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": null, 189 | "id": "1536b2b6-ab79-43d3-a33b-705472afafe4", 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "masks, _, _ = predictor.predict(\n", 194 | " point_coords=None,\n", 195 | " point_labels=None,\n", 196 | " box=input_box[None, :],\n", 197 | " multimask_output=False,\n", 198 | ")" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": null, 204 | "id": "c21e1542-cfe8-4318-b0f5-1f5d6886c4ef", 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "plt.imshow(masks[0])" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": null, 214 | "id": "eca7b69a-6bec-4b55-9e75-4908b18429bd", 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [] 218 | } 219 | ], 220 | "metadata": { 221 | "kernelspec": { 222 | "display_name": "Python 3 (ipykernel)", 223 | "language": "python", 224 | "name": "python3" 225 | }, 226 | "language_info": { 227 | "codemirror_mode": { 228 | "name": "ipython", 229 | "version": 3 230 | }, 231 | "file_extension": ".py", 232 | "mimetype": "text/x-python", 233 | "name": "python", 234 | "nbconvert_exporter": "python", 235 | "pygments_lexer": "ipython3", 236 | "version": "3.8.18" 237 | } 238 | }, 239 | "nbformat": 4, 240 | "nbformat_minor": 5 241 | } 242 | -------------------------------------------------------------------------------- /sam_timing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import cv2 
5 | import time 6 | from PIL import Image 7 | from tqdm import trange 8 | from segment_anything import sam_model_registry, SamPredictor 9 | 10 | def time_sam_clicking_time(test_img_path, predictor, num_trials=100): 11 | image_np = np.array(Image.open(test_img_path).convert("RGB")) 12 | 13 | input_x = np.random.randint(0, image_np.shape[1]) 14 | input_y = np.random.randint(0, image_np.shape[0]) 15 | 16 | input_point = np.array([[input_x, input_y]]) # (x, y); DIFFERENT FROM OPENCV COORDINATE 17 | input_label = np.ones((input_point.shape[0], )) 18 | 19 | sam_time_list = [] 20 | for _ in trange(num_trials): 21 | start_cp = time.time() 22 | 23 | predictor.set_image(image_np) 24 | 25 | masks, scores, logits = predictor.predict( 26 | point_coords=input_point, 27 | point_labels=input_label, 28 | multimask_output=False, 29 | ) 30 | 31 | end_cp = time.time() 32 | sam_time_list.append(end_cp - start_cp) 33 | 34 | print("Using {} trials".format(num_trials)) 35 | print(f"Average SAM clicking time: {np.mean(sam_time_list):.4f} sec") 36 | print(f"STD: {np.std(sam_time_list):.4f} sec") 37 | 38 | 39 | def main(): 40 | sam_checkpoint = "./pretrained_weights/sam_vit_h_4b8939.pth" # default model 41 | test_img_paths = [ 42 | "./sample_data/DAVIS_bear/images/00000.jpg" 43 | ] 44 | 45 | model_type = "vit_h" 46 | 47 | device = "cuda" 48 | 49 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 50 | sam.to(device=device) 51 | 52 | predictor = SamPredictor(sam) 53 | 54 | for test_img_path in test_img_paths: 55 | time_sam_clicking_time(test_img_path, predictor) 56 | 57 | if __name__ == "__main__": 58 | main() -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00000.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00001.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00002.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00003.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00004.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00005.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00005.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00006.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00007.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00008.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00008.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00009.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00010.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00011.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00012.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00012.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00013.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00013.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00014.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00015.jpg 
-------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00016.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00016.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00017.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00017.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00018.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00018.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00019.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00019.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00020.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00020.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00021.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00022.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00022.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00023.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00023.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00024.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00024.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00025.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00025.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00026.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00026.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00027.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00027.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00028.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00028.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00029.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00029.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00030.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00030.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00031.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00031.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00032.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00032.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00033.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00033.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00034.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00034.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00035.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00035.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00036.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00036.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00037.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00037.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00038.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00038.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00039.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00039.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00040.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00040.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00041.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00041.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00042.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00042.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00043.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00043.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00044.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00044.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00045.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00045.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00046.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00046.jpg 
-------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00047.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00047.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00048.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00048.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00049.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00049.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00050.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00050.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00051.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00051.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00052.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00052.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00053.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00053.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00054.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00054.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00055.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00055.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00056.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00056.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00057.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00057.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00058.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00058.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00059.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00059.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00060.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00060.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00061.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00062.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00062.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00063.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00063.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00064.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00064.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00065.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00065.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00066.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00066.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00067.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00067.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00068.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00068.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00069.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00069.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00070.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00070.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00071.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00071.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00072.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00072.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00073.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00073.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00074.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00074.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00075.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00075.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00076.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00076.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00077.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00077.jpg 
-------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00078.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00078.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00079.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00079.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00080.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00080.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/images/00081.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/images/00081.jpg -------------------------------------------------------------------------------- /sample_data/DAVIS_bear/init_mask/00000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RogerQi/Tracking_SAM/d5e5e81a361dce4ddbe3264a2b8027b2b43ea3fd/sample_data/DAVIS_bear/init_mask/00000.png -------------------------------------------------------------------------------- /tracking_SAM/__init__.py: -------------------------------------------------------------------------------- 1 | from .tracking_SAM import main_tracker -------------------------------------------------------------------------------- /tracking_SAM/aott.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import sys 3 | import os 4 | 5 | base_dir = os.path.join(os.path.dirname(__file__), 'third_party/aot_benchmark') 6 | 7 | sys.path.insert(0, base_dir) 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | from torchvision import transforms 13 | 14 | from utils.checkpoint import load_network 15 | from networks.models import build_vos_model 16 | from networks.engines import build_engine 17 | 18 | import dataloaders.video_transforms as tr 19 | 20 | class aot_segmenter: 21 | def __init__(self, ckpt_path, gpu_id=0): 22 | cfg = self.get_config(gpu_id, ckpt_path) 23 | # Pad AOT CFG 24 | self.gpu_id = gpu_id 25 | # Load pre-trained model 26 | self.model = build_vos_model(cfg.MODEL_VOS, cfg).cuda(gpu_id) 27 | self.model, _ = load_network(self.model, cfg.TEST_CKPT_PATH, gpu_id) 28 | self.engine = build_engine(cfg.MODEL_ENGINE, 29 | phase='eval', 30 | aot_model=self.model, 31 | gpu_id=gpu_id, 32 | long_term_mem_gap=cfg.TEST_LONG_TERM_MEM_GAP) 33 | 34 | # Prepare datasets for each sequence 35 | self.transform = transforms.Compose([ 36 | tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE, 37 | cfg.TEST_FLIP, cfg.TEST_MULTISCALE, 38 | cfg.MODEL_ALIGN_CORNERS), 39 | tr.MultiToTensor() 40 | ]) 41 | self.cfg = cfg 42 | 43 | self.reset_engine() 44 | 45 | @staticmethod 46 | def get_config(gpu_id, ckpt_path): 47 | exp_name = 'AOT Tool' 48 | stage = 'pre_ytb_dav' 49 
| model = 'aott' 50 | gpu_id = gpu_id 51 | data_path = os.path.join(base_dir, './datasets/Demo') 52 | output_path = os.path.join(base_dir, './demo_output') 53 | max_resolution = 480*1.3 54 | 55 | engine_config = importlib.import_module('configs.' + stage) 56 | cfg = engine_config.EngineConfig(exp_name, model) 57 | 58 | cfg.TEST_GPU_ID = gpu_id 59 | 60 | cfg.TEST_CKPT_PATH = ckpt_path 61 | cfg.TEST_DATA_PATH = data_path 62 | cfg.TEST_OUTPUT_PATH = output_path 63 | 64 | cfg.TEST_MIN_SIZE = None 65 | cfg.TEST_MAX_SIZE = max_resolution * 800. / 480. 66 | 67 | return cfg 68 | 69 | def preprocess_sample(self, img, label=None): 70 | """ 71 | Parameters 72 | - img: (H, W, 3) np.array (RGB ordering; 0-255) 73 | - label (optional): (H, W) np.array in np.int of same size as img 74 | 75 | Return 76 | - ret_dict: ret_dict ready to be fed into AOT engine 77 | """ 78 | assert len(img.shape) == 3 and img.shape[2] == 3 79 | H, W, _ = img.shape 80 | 81 | if label is not None: 82 | obj_idx_list = np.unique(label) 83 | valid_obj_idx_list = [i for i in obj_idx_list if i != 0] 84 | assert len(valid_obj_idx_list) > 0, "Not valid label provided" 85 | novel_obj_idx_list = [i for i in valid_obj_idx_list if i not in self.tracked_obj_idx_list] 86 | self.obj_num += len(novel_obj_idx_list) 87 | 88 | meta_dict = { 89 | 'obj_num': self.obj_num, 90 | 'height': H, 91 | 'width': W, 92 | 'flip': False # no flipping 93 | } 94 | 95 | ret_dict = { 96 | 'current_img': img, 97 | 'meta': meta_dict 98 | } 99 | 100 | if label is not None: 101 | ret_dict['current_label'] = label 102 | 103 | ret_dict = self.transform(ret_dict)[0] # return a list of length 1 104 | 105 | ret_dict['current_img'] = ret_dict['current_img'].cuda(self.gpu_id, 106 | non_blocking=True).float() 107 | ret_dict['current_img'] = ret_dict['current_img'].reshape((1,) + ret_dict['current_img'].shape) 108 | 109 | if 'current_label' in ret_dict: 110 | ret_dict['current_label'] = ret_dict['current_label'].cuda(self.gpu_id, 111 | non_blocking=True) 112 | ret_dict['current_label'] = ret_dict['current_label'].reshape((1,) + ret_dict['current_label'].shape) 113 | ret_dict['current_label'] = F.interpolate(ret_dict['current_label'].float(), 114 | ret_dict['current_img'].shape[2:], 115 | mode='nearest') 116 | 117 | return ret_dict 118 | 119 | def add_reference_frame(self, img, label): 120 | data_dict = self.preprocess_sample(img, label) 121 | with torch.no_grad(): 122 | self.engine.add_reference_frame(data_dict['current_img'], 123 | data_dict['current_label'], 124 | frame_step=self.frame_cnt, 125 | obj_nums=self.obj_num) 126 | self.frame_cnt += 1 127 | 128 | def propagate_one_frame(self, img): 129 | data_dict = self.preprocess_sample(img) 130 | with torch.no_grad(): 131 | # predict segmentation 132 | self.engine.match_propogate_one_frame(data_dict['current_img']) 133 | pred_logit = self.engine.decode_current_logits( 134 | (data_dict['meta']['height'], data_dict['meta']['width'])) 135 | pred_prob = torch.softmax(pred_logit, dim=1) 136 | pred_label = torch.argmax(pred_prob, dim=1, 137 | keepdim=True).float() 138 | _pred_label = F.interpolate(pred_label, 139 | size=self.engine.input_size_2d, 140 | mode="nearest") 141 | # update memory 142 | self.engine.update_memory(_pred_label) 143 | self.frame_cnt += 1 144 | 145 | return pred_label.squeeze(0).squeeze(0).cpu().numpy().astype(np.uint8) 146 | 147 | def reset_engine(self): 148 | self.tracked_obj_idx_list = [] 149 | self.obj_num = 0 150 | self.frame_cnt = 0 151 | self.model.eval() 152 | self.engine.restart_engine() 
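For orientation, here is a minimal usage sketch of the `aot_segmenter` wrapper defined above. It mirrors the flow in `aot_timing.py`; the checkpoint path assumes the `MODEL_PREP.md` layout and the frame/mask paths point at the bundled `sample_data` sequence, so treat them as placeholders for your own data.

```python
import numpy as np
from PIL import Image
from tracking_SAM.aott import aot_segmenter

# Checkpoint path follows MODEL_PREP.md; adjust if your weights live elsewhere.
segmenter = aot_segmenter('./pretrained_weights/AOTT_PRE_YTB_DAV.pth')

first_frame = np.array(Image.open('./sample_data/DAVIS_bear/images/00000.jpg'))
init_mask = np.array(Image.open('./sample_data/DAVIS_bear/init_mask/00000.png')).astype(np.uint8)
init_mask[init_mask > 0] = 1  # object ids start at 1; 0 is background

# Register the annotated first frame, then propagate the mask through later frames.
segmenter.add_reference_frame(first_frame, init_mask)

next_frame = np.array(Image.open('./sample_data/DAVIS_bear/images/00001.jpg'))
pred_mask = segmenter.propagate_one_frame(next_frame)  # (H, W) uint8 label map

segmenter.reset_engine()  # clear internal VOS memory before starting a new sequence
```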
--------------------------------------------------------------------------------
/tracking_SAM/plt_clicker.py:
--------------------------------------------------------------------------------
1 | 
2 | import cv2
3 | import argparse
4 | import numpy as np
5 | from PIL import Image
6 | from pathlib import Path
7 | import matplotlib as mpl
8 | import matplotlib.pyplot as plt
9 | 
10 | # Forbidden keys (reserved by matplotlib's default key bindings): q, s, f, k, l
11 | 
12 | 
13 | class Annotator(object):
14 |     def __init__(self, img_np, sam_predictor, save_path=None):
15 |         self.sam_predictor = sam_predictor
16 |         self.save_path = save_path
17 |         self.img = img_np.copy()
18 |         self.sam_predictor.set_image(self.img)
19 |         self.clicks = np.empty([0, 2], dtype=np.int64)
20 |         self.pred = np.zeros(self.img.shape[:2], dtype=np.uint8)
21 |         self.merge = self.__gene_merge(self.pred, self.img, self.clicks)
22 | 
23 |     def __gene_merge(self, pred, img, clicks, r=9, cb=2, b=2, if_first=True):
24 |         pred_mask = cv2.merge([pred * 255, pred * 255, np.zeros_like(pred)])
25 |         result = np.uint8(np.clip(img * 0.7 + pred_mask * 0.3, 0, 255))
26 |         if b > 0:
27 |             contours, _ = cv2.findContours(pred, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
28 |             cv2.drawContours(result, contours, -1, (255, 255, 255), b)
29 |         for pt in clicks:
30 |             cv2.circle(result, tuple(pt), r, (255, 0, 0), -1)
31 |             cv2.circle(result, tuple(pt), r, (255, 255, 255), cb)
32 |         if if_first and len(clicks) != 0:
33 |             cv2.circle(result, tuple(clicks[0, :2]), r, (0, 255, 0), cb)
34 |         return result
35 | 
36 |     def __update(self):
37 |         self.ax1.imshow(self.merge)
38 |         self.fig.canvas.draw()
39 | 
40 |     def __reset(self):
41 |         self.clicks = np.empty([0, 2], dtype=np.int64)
42 |         self.pred = np.zeros(self.img.shape[:2], dtype=np.uint8)
43 |         self.merge = self.__gene_merge(self.pred, self.img, self.clicks)
44 |         self.__update()
45 | 
46 |     def __predict(self):
47 |         # TODO(roger): support multiple instances and negative clicks
48 |         input_label = np.ones((self.clicks.shape[0], ))
49 |         masks, scores, logits = self.sam_predictor.predict(
50 |             point_coords=self.clicks,
51 |             point_labels=input_label,
52 |             multimask_output=False,
53 |         )
54 |         self.pred = masks[0].astype(np.uint8)
55 |         self.merge = self.__gene_merge(self.pred, self.img, self.clicks)
56 |         self.__update()
57 | 
58 |     def __on_key_press(self, event):
59 |         if event.key == 'ctrl+z':
60 |             self.clicks = self.clicks[:-1, :]
61 |             if len(self.clicks) != 0:
62 |                 self.__predict()
63 |             else:
64 |                 self.__reset()
65 |         elif event.key == 'ctrl+r':
66 |             self.__reset()
67 |         elif event.key == 'escape':
68 |             plt.close()
69 |         elif event.key == 'enter':
70 |             if self.save_path is not None:
71 |                 Image.fromarray(self.pred * 255).save(self.save_path)
72 |                 print('saved mask to [{}]!'.format(self.save_path))
73 |             plt.close()
74 | 
75 |     def __on_button_press(self, event):
76 |         if (event.xdata is None) or (event.ydata is None):
77 |             return
78 |         if event.button == 1:  # 1 for left click; 3 for right click
79 |             x, y = int(event.xdata + 0.5), int(event.ydata + 0.5)
80 |             self.clicks = np.append(self.clicks, np.array(
81 |                 [[x, y]], dtype=np.int64), axis=0)
82 |             self.__predict()
83 | 
84 |     def main(self):
85 |         self.fig = plt.figure('Annotator', figsize=(10, 7))
86 |         self.fig.canvas.mpl_connect('key_press_event', self.__on_key_press)
87 |         self.fig.canvas.mpl_connect("button_press_event", self.__on_button_press)
88 |         self.fig.suptitle('[RESET]: ctrl+r; [REVOKE]: ctrl+z; [EXIT]: esc; [DONE]: enter', fontsize=14)
89 |         self.ax1 = self.fig.add_subplot(1, 1, 1)
90 |         self.ax1.axis('off')
91 |         self.ax1.imshow(self.merge)
92 |         plt.show()
93 | 
94 |     def get_mask(self):
95 |         return self.pred
96 | 
97 | if __name__ == "__main__":
98 |     from segment_anything import sam_model_registry, SamPredictor
99 |     sam_checkpoint = "../pretrained_weights/sam_vit_h_4b8939.pth"  # default model
100 | 
101 |     model_type = "vit_h"
102 | 
103 |     device = "cuda"
104 | 
105 |     sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
106 |     sam.to(device=device)
107 | 
108 |     predictor = SamPredictor(sam)
109 |     img_path = "../sample_data/DAVIS_bear/images/00000.jpg"
110 |     img_np = np.array(Image.open(img_path))
111 |     anno = Annotator(img_np, predictor, save_path="/tmp/00000.png")
112 |     anno.main()
113 | 
114 |     print("Done!")
115 |     print(anno.get_mask().shape)
116 | 
--------------------------------------------------------------------------------
/tracking_SAM/tracking_SAM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 | import torch
5 | import tracking_SAM.aott
6 | import tracking_SAM.plt_clicker
7 | from segment_anything import sam_model_registry, SamPredictor
8 | 
9 | import groundingdino.datasets.transforms as T
10 | from groundingdino.models import build_model
11 | from groundingdino.util import box_ops
12 | from groundingdino.util.slconfig import SLConfig
13 | from groundingdino.util.utils import clean_state_dict
14 | from groundingdino.util.inference import predict
15 | 
16 | class main_tracker:
17 |     def __init__(self, sam_checkpoint, aot_checkpoint, ground_dino_checkpoint,
18 |                  sam_model_type="vit_h", device="cuda"):
19 |         self.device = device
20 |         self.tracking = False
21 | 
22 |         self.sam = sam_model_registry[sam_model_type](checkpoint=sam_checkpoint)
23 |         self.sam.to(device=device)
24 |         self.sam_predictor = SamPredictor(self.sam)
25 | 
26 |         # Custom wrapper for AOTT
27 |         self.vos_tracker = tracking_SAM.aott.aot_segmenter(aot_checkpoint)
28 | 
29 |         self.reset_engine()
30 | 
31 |         cur_dir = os.path.dirname(os.path.abspath(__file__))
32 |         config_file_path = os.path.join(cur_dir,
33 |                                         'third_party',
34 |                                         'GroundingDINO',
35 |                                         'groundingdino',
36 |                                         'config',
37 |                                         'GroundingDINO_SwinT_OGC.py')
38 | 
39 |         args = SLConfig.fromfile(config_file_path)
40 | 
41 |         self.dino_model = build_model(args)
42 | 
43 |         checkpoint = torch.load(ground_dino_checkpoint, map_location='cpu')
44 |         self.dino_model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
45 |         self.dino_model.eval()
46 |         self.dino_model = self.dino_model.to(device)
47 | 
48 |     def annotate_init_frame(self, img, method='clicking', category_name='background'):
49 |         """
50 |         Annotate the first frame of the video.
51 | 
52 |         Args:
53 |             img: numpy array of shape (H, W, 3) and dtype uint8, in RGB format.
54 |             method: 'clicking' or 'dino'. 'clicking' is the default method; 'dino' requires category_name to be set.
55 |         """
56 |         if method == 'clicking':
57 |             anno = tracking_SAM.plt_clicker.Annotator(img, self.sam_predictor)
58 |             anno.main()  # blocking call
59 |             mask_np_hw = anno.get_mask()
60 |         elif method == 'dino':
61 |             assert category_name != 'background', "Category name must be specified!"
62 |             transform = T.Compose(
63 |                 [
64 |                     T.RandomResize([800], max_size=1333),  # not actually random. It selects from [800].
65 | T.ToTensor(), 66 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 67 | ] 68 | ) 69 | 70 | img_chw, _ = transform(Image.fromarray(img), None) 71 | 72 | # From official groundingdino demo 73 | BOX_TRESHOLD = 0.3 74 | TEXT_TRESHOLD = 0.25 75 | 76 | boxes, logits, phrases = predict( 77 | model=self.dino_model, 78 | image=img_chw, 79 | caption=category_name, 80 | box_threshold=BOX_TRESHOLD, 81 | text_threshold=TEXT_TRESHOLD, 82 | device=self.device 83 | ) 84 | 85 | H, W, _ = img.shape 86 | 87 | boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H]) 88 | 89 | self.sam_predictor.set_image(img) 90 | assert len(boxes_xyxy) == 1 91 | input_box = boxes_xyxy[0].cpu().numpy() 92 | 93 | masks, _, _ = self.sam_predictor.predict( 94 | point_coords=None, 95 | point_labels=None, 96 | box=input_box[None, :], 97 | multimask_output=False, 98 | ) 99 | 100 | mask_np_hw = masks[0].astype(np.uint8) 101 | else: 102 | raise NotImplementedError(f"method {method} not implemented!") 103 | 104 | mask_np_hw = mask_np_hw.astype(np.uint8) 105 | mask_np_hw[mask_np_hw > 0] = 1 # TODO(roger): support multiple objects? 106 | 107 | self.vos_tracker.add_reference_frame(img, mask_np_hw) 108 | 109 | self.tracking = True 110 | 111 | def propagate_one_frame(self, img): 112 | assert self.tracking, "Please call annotate_init_frame() first!" 113 | pred_np_hw = self.vos_tracker.propagate_one_frame(img) 114 | return pred_np_hw 115 | 116 | def reset_engine(self): 117 | self.vos_tracker.reset_engine() 118 | self.tracking = False 119 | 120 | def is_tracking(self): 121 | return self.tracking 122 | --------------------------------------------------------------------------------
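Usage note: the main_tracker class above ties the pieces together -- annotate the first frame (interactive SAM clicks or a GroundingDINO text prompt), which seeds the AOT engine, then propagate a mask for every subsequent frame. The following is a minimal end-to-end sketch, not part of the file; the checkpoint filenames and the DAVIS_bear sample paths are illustrative assumptions (see MODEL_PREP.md for the actual weight preparation).

import glob
import numpy as np
from PIL import Image
from tracking_SAM import main_tracker

# Assumed checkpoint locations -- adjust to wherever the pretrained weights were downloaded.
tracker = main_tracker(
    sam_checkpoint="pretrained_weights/sam_vit_h_4b8939.pth",
    aot_checkpoint="pretrained_weights/AOTT_PRE_YTB_DAV.pth",
    ground_dino_checkpoint="pretrained_weights/groundingdino_swint_ogc.pth",
)

frame_paths = sorted(glob.glob("sample_data/DAVIS_bear/images/*.jpg"))
first_frame = np.array(Image.open(frame_paths[0]))  # (H, W, 3) RGB uint8

# Either click on the object in the pop-up window (method='clicking'),
# or let GroundingDINO localize it from a text prompt (method='dino').
tracker.annotate_init_frame(first_frame, method='dino', category_name='bear')

# Track through the rest of the sequence; each call returns an (H, W) uint8 mask
# where 1 marks the tracked object.
for frame_path in frame_paths[1:]:
    frame = np.array(Image.open(frame_path))
    mask = tracker.propagate_one_frame(frame)

tracker.reset_engine()  # clear the AOT memory before starting a new video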