├── .gitignore ├── LICENSE-Caption-Anything.txt ├── LICENSE-Qwen2-VL.txt ├── LICENSE-SAMURAI.txt ├── LICENSE-VideoLLaMA2.txt ├── README.md ├── assets ├── ball.txt ├── cat-v-framework.png ├── cover.png ├── demo.mp4 ├── demo.txt ├── demo_short.mp4 └── jump.mp4 ├── checkpoints └── download_ckpts.sh ├── environment.yml ├── eval_utils.py ├── gradio_app.py ├── inference.sh ├── init.sh ├── internvl ├── get_acc.py ├── test-batch.sh └── test.py ├── requirements.txt ├── sam2 ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── configs │ ├── sam2.1 │ │ ├── sam2.1_hiera_b+.yaml │ │ ├── sam2.1_hiera_l.yaml │ │ ├── sam2.1_hiera_s.yaml │ │ └── sam2.1_hiera_t.yaml │ ├── sam2.1_training │ │ └── sam2.1_hiera_b+_MOSE_finetune.yaml │ ├── sam2 │ │ ├── sam2_hiera_b+.yaml │ │ ├── sam2_hiera_l.yaml │ │ ├── sam2_hiera_s.yaml │ │ └── sam2_hiera_t.yaml │ └── samurai │ │ ├── sam2.1_hiera_b+.yaml │ │ ├── sam2.1_hiera_l.yaml │ │ ├── sam2.1_hiera_s.yaml │ │ └── sam2.1_hiera_t.yaml ├── csrc │ └── connected_components.cu ├── modeling │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── hieradet.py │ │ ├── image_encoder.py │ │ └── utils.py │ ├── memory_attention.py │ ├── memory_encoder.py │ ├── position_encoding.py │ ├── sam │ │ ├── __init__.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ └── transformer.py │ ├── sam2_base.py │ └── sam2_utils.py ├── sam2_hiera_b+.yaml ├── sam2_hiera_l.yaml ├── sam2_hiera_s.yaml ├── sam2_hiera_t.yaml ├── sam2_image_predictor.py ├── sam2_video_predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── kalman_filter.py │ ├── misc.py │ └── transforms.py ├── scripts ├── dog.txt ├── get_boundary.py ├── get_caption.py ├── get_masks.py ├── get_vis.py ├── inference │ └── inference.py └── main_inference.py ├── setup.py └── trace ├── __init__.py ├── constants.py ├── conversation.py ├── eval ├── eval.sh ├── evaluate.py ├── mvbench │ ├── eval.sh │ └── evaluate.py ├── reformat_dvc.py ├── reformat_tvg.py ├── reformat_vhd.py └── videomme │ ├── eval.sh │ └── evaluate.py ├── metrics ├── README.md ├── dvc │ ├── SODA │ │ ├── LICENSE │ │ ├── README.md │ │ ├── dataset.py │ │ ├── nlpeval │ │ │ ├── bert_f_score.py │ │ │ ├── bert_r_score.py │ │ │ └── mover.py │ │ ├── requirements.txt │ │ ├── soda.py │ │ └── utils.py │ ├── __init__.py │ ├── eval_dvc.py │ ├── eval_dvc_anet.py │ ├── eval_soda.py │ └── metrics │ │ ├── README.md │ │ ├── cider.py │ │ ├── cider_scorer.py │ │ ├── eval_soda.py │ │ ├── meteor-1.5.jar │ │ ├── meteor.py │ │ ├── ptbtokenizer.py │ │ └── stanford-corenlp-3.4.1.jar ├── tvg │ ├── eval_tvg.py │ └── eval_tvg.sh └── vhd │ ├── eval_highlights.sh │ ├── eval_vhd.py │ └── utils.py ├── mm_utils.py ├── model ├── __init__.py ├── builder.py ├── language_model │ └── trace_mistral.py ├── multimodal_encoder │ ├── builder.py │ ├── clip_encoder.py │ ├── score_encoder.py │ ├── sync_encoder.py │ └── time_encoder.py ├── multimodal_projector │ ├── __init__.py │ └── builder.py └── trace_arch.py ├── prompts ├── dvc-anet-ft.txt ├── dvc-anet.txt ├── dvc.txt ├── mr.txt └── vhd.txt ├── trace_trainer.py ├── train_mt.py ├── train_mt_npu.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | .DS_Store 3 | __pycache__/ 4 | *-checkpoint.ipynb 5 | .venv 6 | *.egg* 7 | build/* 8 | _C.* 9 | outputs/* 10 | checkpoints/*.pt 11 | 12 | 13 | results/*.mp4 14 | 15 | # Python 16 | __pycache__ 17 | *.pyc 18 | *.egg-info 19 | dist 20 | 21 | # Log 22 | *.log 23 | *.log.* 24 | *.json 25 | *.jsonl 26 | log_dir*/ 27 | 28 | # Data 29 | 
!**/alpaca-data-conversation.json 30 | 31 | # Editor 32 | .idea 33 | *.swp 34 | 35 | # Other 36 | .DS_Store 37 | 38 | # jupyter 39 | .ipynb_checkpoints 40 | *.ipynb 41 | 42 | # DevContainer 43 | !.devcontainer/* 44 | 45 | # Demo 46 | serve_images/ 47 | 48 | # data folder 49 | data/ 50 | dataset/ 51 | datasets/ 52 | results/ 53 | 54 | # training folder 55 | wandb 56 | ckpts* 57 | output 58 | output/ 59 | # checkpoints 60 | # checkpoints/ 61 | work_dirs*/ 62 | 63 | # evaluation folder 64 | eval_results/ 65 | 66 | # pretrained weights 67 | pretrained/ 68 | publish_models/ 69 | 70 | -------------------------------------------------------------------------------- /LICENSE-Caption-Anything.txt: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Teng Wang 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [AAAI-26 Demo] Caption Anything in Video: Object-centric Dense Video Captioning with Spatiotemporal Multimodal Prompting 2 | ![Cover](assets/cover.png) 3 | 4 | 5 | 6 | Official PyTorch implementation of [Caption Anything in Video: Fine-grained Object-centric Captioning via Spatiotemporal Multimodal Prompting](https://arxiv.org/abs/2504.05541) 7 | 8 | ![cat-v-framework](assets/cat-v-framework.png) 9 | 10 | ## 🚀 Updates 11 | 12 | ## 🕹️ Demo 13 | 14 | YouTube: [https://youtu.be/2eiPVKXEoxw](https://youtu.be/2eiPVKXEoxw) 15 | 16 | ## 🛠️ Getting Started 17 | 18 | 1. Set up a conda environment (python>= 3.10) using: 19 | 20 | ```bash 21 | conda create -n cat2 python=3.10 -y 22 | conda activate cat2 23 | ``` 24 | 25 | 2. Install the requirements: 26 | 27 | ```bash 28 | pip install -e . 29 | ``` 30 | 31 | 3. Download checkpoints: 32 | 33 | ```bash 34 | cd checkpoints && \ 35 | ./download_ckpts.sh && \ 36 | cd .. 
37 | ``` 38 | 39 | ## 🏃 RUN 40 | 41 | ```bash 42 | bash inference.sh 43 | ``` 44 | 45 | 46 | ## 📖 Citation 47 | If you find this work useful for your research or applications, please cite using this BibTeX: 48 | 49 | ```bibtex 50 | @inproceedings{tang2025cat-v, 51 | title={Caption Anything in Video: Fine-grained Object-centric Captioning via Spatiotemporal Multimodal Prompting}, 52 | author={Tang, Yunlong and Bi, Jing and Hua, Hang and Xiao, Yunzhong and Song, Yizhi and Liu, Pinxin and Huang, Chao and Feng, Mingqian and Guo, Junjia and Liu, Zhuo and Song, Luchuan and Liang, Susan and Shimada, Daiki and Vosoughi, Ali and He, Jinxi and He, Liu and Zhang, Zeliang and Luo, Jiebo and Xu, Chenliang}, 53 | journal={arXiv}, 54 | year={2025} 55 | } 56 | ``` 57 | 58 | 59 | ## 🙏 Acknowledgements 60 | This work was supported by Sony Group Corporation. We would like to thank Sayaka Nakamura and Jerry Jun Yokono for their insightful discussion. 61 | 62 | We are also grateful for the following awesome projects that CAT-V builds upon: 63 | 64 | - [Caption Anything](https://github.com/ttengwang/Caption-Anything) 65 | - [SAM 2](https://github.com/facebookresearch/sam2) 66 | - [SAMURAI](https://github.com/yangchris11/samurai) 67 | - [TRACE-uni](https://github.com/gyxxyg/TRACE) 68 | - [VideoLLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2) 69 | - [Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL) 70 | - [InternVL-2.5](https://internvl.github.io/blog/2024-12-05-InternVL-2.5/) 71 | 72 | 73 | ## 👩‍💻 Contributors 74 | Our project wouldn't be possible without the contributions of these amazing people! Thank you all for making this project better. 75 | 76 | - [Yunlong Tang](https://yunlong10.github.io/) @ University of Rochester 77 | - [Jing Bi](https://scholar.google.com/citations?user=ZyCYhUkAAAAJ) @ University of Rochester 78 | - [Chao Huang](https://wikichao.github.io/) @ University of Rochester 79 | - [Susan Liang](https://liangsusan-git.github.io/) @ University of Rochester 80 | - [Daiki Shimada](https://scholar.google.co.jp/citations?user=1uAwouQAAAAJ) @ Sony Group Corporation 81 | - [Hang Hua](https://hanghuacs.notion.site/Hang-Hua-151c5b68f62980e8884febf1b5c1d4a9) @ University of Rochester 82 | - [Yunzhong Xiao](https://shawn-yzxiao.github.io/) @ Carnegie Mellon University 83 | - [Yizhi Song](https://song630.github.io/yizhisong.github.io/) @ Purdue University 84 | - [Pinxin Liu](https://andypinxinliu.github.io/) @ University of Rochester 85 | - [Mingqian Feng](https://fmmarkmq.github.io/) @ University of Rochester 86 | - [Junjia Guo](https://doujiangter.github.io/JunjiaGuo.github.io/) @ University of Rochester 87 | - [Zhuo Liu](https://joeliuz6.github.io/) @ University of Rochester 88 | - [Luchuan Song](https://songluchuan.github.io/) @ University of Rochester 89 | - [Ali Vosoughi](https://alivosoughi.com/) @ University of Rochester 90 | - [Jinxi He](https://gingin520.github.io/) @ University of Rochester 91 | - [Liu He](https://arking1995.github.io/) @ Purdue University 92 | - [Zeliang Zhang](https://zhangaipi.github.io/) @ University of Rochester 93 | - [Jiebo Luo](https://www.cs.rochester.edu/u/jluo/) @ University of Rochester 94 | - [Chenliang Xu](https://www.cs.rochester.edu/~cxu22/index.html) @ University of Rochester 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | ### 🌟 Star History 103 | 104 | [![Star History Chart](https://api.star-history.com/svg?repos=yunlong10/CAT-V&type=Date)](https://star-history.com/#yunlong10/CAT-V&Date) 105 |
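### 🎯 Running on Your Own Video

The commands chained in `inference.sh` (boundary detection → mask tracking → captioning → visualization) can be pointed at any video plus a bounding-box `.txt` for the target object (see `assets/demo.txt` for the comma-separated format). The sketch below is a minimal example, not the only supported invocation: `my_video.mp4` and `my_bbox.txt` are placeholder names, and the intermediate file names assume the `<video>_boundary*.json` / `<video>_mask.mp4` convention used in `inference.sh` and `gradio_app.py`.

```bash
VIDEO=./assets/my_video.mp4   # placeholder: your input video
BBOX=./assets/my_bbox.txt     # placeholder: bounding box of the object to caption
OUT=./results && mkdir -p "$OUT"

# 1) Temporal event boundaries (TRACE-uni)
python -m scripts.get_boundary \
    --video_paths "$VIDEO" \
    --questions "Localize a series of activity events in the video, output the start and end timestamp for each event, and describe each event with sentences." \
    --model_path Yongxin-Guo/trace-uni

# 2) Object masks for the whole clip (SAM 2.1)
python scripts/get_masks.py \
    --video_path "$VIDEO" --txt_path "$BBOX" \
    --model_path ./checkpoints/sam2.1_hiera_base_plus.pt \
    --video_output_path "$OUT" --save_to_video True

# 3) Object-centric captions per event (InternVL), then burn them into the video
python scripts/get_caption.py \
    --model_path OpenGVLab/InternVL2_5-8B-MPO \
    --QA_file_path "$OUT/my_video_boundary.json" \
    --video_folder "$OUT" --answers_output_folder "$OUT" \
    --extract_frames_method max_frames_num --max_frames_num 32 \
    --frames_from video \
    --final_json_path "$OUT/my_video_boundary_caption.json" \
    --provide_boundaries
python scripts/get_vis.py "$OUT/my_video_mask.mp4" "$OUT/my_video_boundary_caption.json" "$OUT/my_video_boundary_caption.mp4"
```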
-------------------------------------------------------------------------------- /assets/ball.txt: -------------------------------------------------------------------------------- 1 | 390, 435, 60, 60 -------------------------------------------------------------------------------- /assets/cat-v-framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/cat-v-framework.png -------------------------------------------------------------------------------- /assets/cover.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/cover.png -------------------------------------------------------------------------------- /assets/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/demo.mp4 -------------------------------------------------------------------------------- /assets/demo.txt: -------------------------------------------------------------------------------- 1 | 720, 250, 750, 300 -------------------------------------------------------------------------------- /assets/demo_short.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/demo_short.mp4 -------------------------------------------------------------------------------- /assets/jump.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/jump.mp4 -------------------------------------------------------------------------------- /checkpoints/download_ckpts.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | # Use either wget or curl to download the checkpoints 10 | if command -v wget &> /dev/null; then 11 | CMD="wget" 12 | elif command -v curl &> /dev/null; then 13 | CMD="curl -L -O" 14 | else 15 | echo "Please install wget or curl to download the checkpoints." 16 | exit 1 17 | fi 18 | 19 | # Define the URLs for SAM 2 checkpoints 20 | # SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824" 21 | # sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt" 22 | # sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt" 23 | # sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt" 24 | # sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt" 25 | 26 | # Download each of the four checkpoints using wget 27 | # echo "Downloading sam2_hiera_tiny.pt checkpoint..." 28 | # $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; } 29 | 30 | # echo "Downloading sam2_hiera_small.pt checkpoint..." 31 | # $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; } 32 | 33 | # echo "Downloading sam2_hiera_base_plus.pt checkpoint..." 
34 | # $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; } 35 | 36 | # echo "Downloading sam2_hiera_large.pt checkpoint..." 37 | # $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; } 38 | 39 | # Define the URLs for SAM 2.1 checkpoints 40 | SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824" 41 | sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt" 42 | sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt" 43 | sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt" 44 | sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt" 45 | 46 | # SAM 2.1 checkpoints 47 | echo "Downloading sam2.1_hiera_tiny.pt checkpoint..." 48 | $CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; } 49 | 50 | echo "Downloading sam2.1_hiera_small.pt checkpoint..." 51 | $CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; } 52 | 53 | echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..." 54 | $CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; } 55 | 56 | echo "Downloading sam2.1_hiera_large.pt checkpoint..." 57 | $CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; } 58 | 59 | echo "All checkpoints are downloaded successfully." 60 | -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import gradio as gr 4 | import subprocess 5 | import json 6 | import cv2 7 | import tempfile 8 | from PIL import Image 9 | import numpy as np 10 | 11 | CONFIG = { 12 | "model_path": "OpenGVLab/InternVL2-8B", 13 | "get_boundary_model_path": "Yongxin-Guo/trace-uni", 14 | "get_mask_model_path": "./checkpoints/sam2.1_hiera_base_plus.pt", 15 | "output_folder": "./results/", 16 | "frame_count": 16, 17 | } 18 | 19 | def extract_first_frame(video_path): 20 | cap = cv2.VideoCapture(video_path) 21 | ret, image = cap.read() 22 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 23 | cap.release() 24 | height, width = image.shape[:2] 25 | if height > 750: 26 | scale = 750 / height 27 | new_width = int(width * scale) 28 | new_height = 750 29 | image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA) 30 | return image 31 | 32 | def run_inference_pipeline(video_path, bbox): 33 | """ 34 | Run the entire inference pipeline for video processing 35 | """ 36 | # Ensure output folder exists 37 | os.makedirs(CONFIG["output_folder"], exist_ok=True) 38 | 39 | # Prepare file paths 40 | video_name = os.path.basename(video_path) 41 | qa_file_path = os.path.join( 42 | CONFIG["output_folder"], f"{os.path.splitext(video_name)[0]}_boundary.json" 43 | ) 44 | final_json_path = os.path.join( 45 | CONFIG["output_folder"], 46 | f"{os.path.splitext(video_name)[0]}_boundary_caption.json", 47 | ) 48 | final_video_path = os.path.join( 49 | CONFIG["output_folder"], 50 | f"{os.path.splitext(video_name)[0]}_boundary_caption.mp4", 51 | ) 52 | masked_video_path = os.path.join( 53 | CONFIG["output_folder"], f"{os.path.splitext(video_name)[0]}_mask.mp4" 54 | ) 55 | print(f"Final JSON Path: {final_json_path}") 56 | video = cv2.VideoCapture(video_path) 57 | ret, frame = video.read() 58 | h,w = frame.shape[:2] 59 | print(h,w) 60 | 
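    # The drawn box arrives as fractions of the first frame's width/height (from the
    # Gradio ImageEditor); the lines below convert it to absolute pixel coordinates and
    # write them to a .txt file that scripts/get_masks.py reads via --txt_path.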
video.release() 61 | bbox = [int(bbox[0]*w), int(bbox[1]*h), int(bbox[2]*w), int(bbox[3]*h)] 62 | print(bbox) 63 | object_bbox_path = Path(CONFIG['output_folder'])/f"{os.path.splitext(video_name)[0]}_bbox.txt" 64 | with open(object_bbox_path, "w") as f: 65 | f.write(','.join(map(str, bbox))) 66 | commands = [ 67 | # Step 1: Parsing/Boundary Detection 68 | f"python -m scripts.get_boundary " 69 | f"--video_paths {video_path} " 70 | f"--questions 'Localize a series of activity events in the video, output the start and end timestamp for each event, and describe each event with sentences.' " 71 | f"--model_path {CONFIG['get_boundary_model_path']}", 72 | ] 73 | 74 | 75 | commands.append( 76 | f"python scripts/get_masks.py " 77 | f"--video_path {video_path} " 78 | f"--txt_path {object_bbox_path} " 79 | f"--model_path {CONFIG['get_mask_model_path']} " 80 | f"--video_output_path {CONFIG['output_folder']} " 81 | f"--save_to_video True" 82 | ) 83 | 84 | # Step 2: Captioning 85 | commands.append( 86 | f"python scripts/get_caption.py " 87 | f"--model_path {CONFIG['model_path']} " 88 | f"--QA_file_path {qa_file_path} " 89 | f"--video_folder {CONFIG['output_folder']} " 90 | f"--answers_output_folder {CONFIG['output_folder']} " 91 | f"--extract_frames_method max_frames_num " 92 | f"--max_frames_num {CONFIG['frame_count']} " 93 | f"--frames_from video " 94 | f"--final_json_path {final_json_path} " 95 | f"--provide_boundaries" 96 | ) 97 | 98 | # Step 3: Generate Visualization 99 | commands.append( 100 | f"python scripts/get_vis.py {masked_video_path if object_bbox_path else video_path} {final_json_path} {final_video_path}" 101 | ) 102 | 103 | # Execute commands 104 | for cmd in commands: 105 | try: 106 | subprocess.run(cmd, shell=True, check=True) 107 | except subprocess.CalledProcessError as e: 108 | print(f"Error in command: {cmd}") 109 | print(f"Error details: {e}") 110 | return None 111 | 112 | try: 113 | with open(final_json_path, "r") as f: 114 | results = json.load(f) 115 | return {"captions": results, "final_video": final_video_path} 116 | except Exception as e: 117 | print(f"Error reading results: {e}") 118 | return None 119 | 120 | def get_bounding_box(image): 121 | alpha_channel = image[:, :, 3] 122 | y_coords, x_coords = np.where(alpha_channel > 0) 123 | 124 | if y_coords.size == 0 or x_coords.size == 0: 125 | return None 126 | x_min, x_max = x_coords.min(), x_coords.max() 127 | y_min, y_max = y_coords.min(), y_coords.max() 128 | x_min_ratio = x_min / image.shape[1] 129 | x_max_ratio = x_max / image.shape[1] 130 | y_min_ratio = y_min / image.shape[0] 131 | y_max_ratio = y_max / image.shape[0] 132 | return x_min_ratio, y_min_ratio, x_max_ratio, y_max_ratio 133 | def caption_video(video, edited_image): 134 | """ 135 | Gradio-friendly wrapper for inference pipeline 136 | video: path to the uploaded video 137 | bbox_file: path to the uploaded bounding box file (optional) 138 | edited_image: the edited first frame image returned by ImageEditor (PIL Image) 139 | """ 140 | layer_0 = edited_image['layers'][0] 141 | bbox = get_bounding_box(layer_0) 142 | 143 | if video is None: 144 | return "Please upload a video first.", None 145 | results = run_inference_pipeline(video, bbox) 146 | 147 | if results is None: 148 | return "Processing failed. 
Please check the logs.", None 149 | 150 | # Format captions nicely 151 | captions_text = "\n\n".join( 152 | [ 153 | f"Event {i+1} (Time: {event.get('timestamp', 'N/A')}):\n{event.get('model_answer', 'No caption')}" 154 | for i, event in enumerate(results.get("captions", [])) 155 | ] 156 | ) 157 | 158 | return captions_text, results.get("final_video") 159 | 160 | 161 | 162 | 163 | def create_demo(): 164 | """ 165 | Create Gradio interface 166 | """ 167 | 168 | DESCRIPTION = """# CAT2: 169 | This is a demo for our 'CAT2' [paper](https://github.com/yunlong10/CAT-2). 170 | Code is available [here](https://github.com/yunlong10/CAT-2). 171 | This demo performs captioning with optional object bounding box annotation. 172 | """ 173 | 174 | with gr.Blocks() as demo: 175 | gr.Markdown("# Caption Anything Demo") 176 | gr.Markdown(DESCRIPTION) 177 | gr.Markdown( 178 | "Upload a video and optionally a bounding box file. Or draw a rectangle on the first frame of the video to provide a bounding box. (Note: The ImageEditor does not return bounding box coordinates directly. Further processing may be required.)" 179 | ) 180 | 181 | with gr.Row(): 182 | video_input = gr.Video(label="Upload Video",height=800) 183 | first_frame_editor = gr.ImageEditor(label="Draw a rectangle on the First Frame",height=800) 184 | 185 | video_input.change(fn=extract_first_frame, inputs=video_input, outputs=first_frame_editor) 186 | 187 | caption_button = gr.Button("Generate Captions") 188 | 189 | output_text = gr.Textbox(label="Video Captions") 190 | output_video = gr.Video(label="Processed Video") 191 | 192 | caption_button.click( 193 | fn=caption_video, 194 | inputs=[video_input, first_frame_editor], 195 | outputs=[output_text, output_video], 196 | ) 197 | 198 | return demo 199 | 200 | 201 | if __name__ == "__main__": 202 | demo = create_demo() 203 | demo.launch( 204 | server_name="0.0.0.0", # Make accessible from other machines 205 | server_port=8889, 206 | debug=True, 207 | ) -------------------------------------------------------------------------------- /inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # conda activate /home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/env/cat-2 4 | export TRANSFORMERS_CACHE=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/transformers_cache 5 | export TORCH_HOME=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/torch_home 6 | export HF_HOME=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/hf_home 7 | export PIP_CACHE_DIR=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/pip 8 | export OPENAI_CACHE_DIR=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/openai 9 | 10 | set -e 11 | 12 | GREEN="\033[32m" 13 | RESET="\033[0m" 14 | FRAME_COUNT=32 15 | OUTPUT_FOLDER="./results" 16 | mkdir -p $OUTPUT_FOLDER 17 | MODEL_PATH="OpenGVLab/InternVL2_5-8B-MPO" # "OpenGVLab/InternVL2-8B" 18 | GET_BOUNDARY_MODEL_PATH="Yongxin-Guo/trace-uni" 19 | GET_MASK_MODEL_PATH="./checkpoints/sam2.1_hiera_base_plus.pt" 20 | 21 | ############################################################################################################ 22 | VIDEO_NAME="demo.mp4" 23 | VIDEO_FOLDER="./assets/" 24 | OBJECT_BBOX="demo.txt" 25 | QA_FILE_PATH="$OUTPUT_FOLDER/demo_boundary.json" 26 | FINAL_JSON_PATH="$OUTPUT_FOLDER/demo_boundary_caption.json" 27 | FINAL_VIDEO_PATH="$OUTPUT_FOLDER/demo_boundary_caption.mp4" 28 | MASKED_VIDEO_PATH="$OUTPUT_FOLDER/demo_mask.mp4" 29 | 
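# To process a different clip, change VIDEO_NAME/OBJECT_BBOX above and update the
# derived paths accordingly. OBJECT_BBOX is a comma-separated box for the target object
# (see assets/demo.txt); the Gradio app writes it as x_min,y_min,x_max,y_max in pixels.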
############################################################################################################ 30 | 31 | VIDEO_PATH="$VIDEO_FOLDER$VIDEO_NAME" 32 | OBJECT_BBOX_PATH="$VIDEO_FOLDER$OBJECT_BBOX" 33 | 34 | START_TIME=$(date +%s) 35 | 36 | echo -e "${GREEN}Step 1: Parsing...${RESET}" 37 | 38 | # python -m scripts.get_boundary \ 39 | # --video_paths $VIDEO_PATH \ 40 | # --questions "Localize a series of activity events in the video, output the start and end timestamp for each event, and describe each event with sentences." \ 41 | # --model_path $GET_BOUNDARY_MODEL_PATH 42 | 43 | echo -e "${GREEN}Step 2: Segmentation...${RESET}" 44 | 45 | 46 | # python scripts/get_masks.py \ 47 | # --video_path "$VIDEO_PATH" \ 48 | # --txt_path "$OBJECT_BBOX_PATH" \ 49 | # --model_path "$GET_MASK_MODEL_PATH" \ 50 | # --video_output_path "$OUTPUT_FOLDER" \ 51 | # --save_to_video True 52 | 53 | echo -e "${GREEN}Step 3: Captioning...${RESET}" 54 | 55 | python scripts/get_caption.py \ 56 | --model_path "$MODEL_PATH" \ 57 | --QA_file_path "$QA_FILE_PATH" \ 58 | --video_folder "$OUTPUT_FOLDER" \ 59 | --answers_output_folder "$OUTPUT_FOLDER" \ 60 | --extract_frames_method "max_frames_num" \ 61 | --max_frames_num "$FRAME_COUNT" \ 62 | --frames_from "video" \ 63 | --final_json_path "$FINAL_JSON_PATH" \ 64 | --provide_boundaries 65 | 66 | echo -e "${GREEN}Step 4: Generate visualizations...${RESET}" 67 | 68 | # python scripts/get_vis.py "$MASKED_VIDEO_PATH" "$FINAL_JSON_PATH" "$FINAL_VIDEO_PATH" 69 | 70 | echo -e "${GREEN}Completed in $(($(date +%s) - START_TIME)) seconds.${RESET}" -------------------------------------------------------------------------------- /init.sh: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/init.sh -------------------------------------------------------------------------------- /internvl/get_acc.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from pycocoevalcap.bleu.bleu import Bleu 4 | from pycocoevalcap.meteor.meteor import Meteor 5 | from pycocoevalcap.cider.cider import Cider 6 | import os 7 | 8 | # Load the JSON files 9 | 10 | results = "./results" 11 | # read all json files in the directory 12 | for file in os.listdir(results): 13 | if file.endswith(".json"): 14 | with open(f"{results}/{file}", "r", encoding="utf-8") as f: 15 | data = json.load(f) 16 | 17 | # Prepare references and hypotheses as dictionaries 18 | gts = {} # Ground truth (references) 19 | res = {} # Results (hypotheses) 20 | 21 | for i, item in enumerate(tqdm(data)): 22 | gts[i] = [item["correct_answer"]] # Reference list for ID i 23 | res[i] = [item["model_answer"]] # Hypothesis for ID i 24 | 25 | # BLEU Score 26 | def compute_bleu(gts, res): 27 | bleu_scorer = Bleu(4) # Compute BLEU-1 to BLEU-4 28 | score, _ = bleu_scorer.compute_score(gts, res) 29 | return score 30 | 31 | # METEOR Score 32 | def compute_meteor(gts, res): 33 | meteor_scorer = Meteor() 34 | score, _ = meteor_scorer.compute_score(gts, res) 35 | return score 36 | 37 | # CIDEr Score 38 | def compute_cider(gts, res): 39 | cider_scorer = Cider() 40 | score, _ = cider_scorer.compute_score(gts, res) 41 | return score 42 | 43 | # Calculate scores 44 | bleu_score = compute_bleu(gts, res) 45 | meteor_score = compute_meteor(gts, res) 46 | cider_score = compute_cider(gts, res) 47 | 48 | # Print results 49 | print(f"Results for {file}") 50 |
print(f"BLEU Scores: {[round(i*100, 2) for i in bleu_score]}") 51 | print(f"METEOR Score: {round(meteor_score*100, 2)}") 52 | print(f"CIDEr Score: {round(cider_score*100, 2)}") 53 | -------------------------------------------------------------------------------- /internvl/test-batch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | export TRANSFORMERS_CACHE=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/transformers_cache 3 | export TORCH_HOME=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/torch_home 4 | export HF_HOME=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/hf_home 5 | export PIP_CACHE_DIR=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/pip 6 | export OPENAI_CACHE_DIR=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/openai 7 | 8 | # Exit immediately if a command exits with a non-zero status. 9 | set -e 10 | 11 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) 12 | PYTHON_FILE_PATH="${SCRIPT_DIR}/test.py" 13 | ANSWERS_OUTPUT_FOLDER="${SCRIPT_DIR}/results" 14 | 15 | if [ -e "$ANSWERS_OUTPUT_FOLDER" ]; then 16 | echo "File $ANSWERS_OUTPUT_FOLDER already exists." 17 | else 18 | echo "File $ANSWERS_OUTPUT_FOLDER does not exist. Creating it now..." 19 | mkdir "$ANSWERS_OUTPUT_FOLDER" 20 | echo "File $ANSWERS_OUTPUT_FOLDER created." 21 | fi 22 | 23 | # Variables - Please update these paths according to your setup 24 | MODEL_PATH="OpenGVLab/InternVL2_5-8B-MPO-AWQ" #"OpenGVLab/InternVL2-8B" 25 | VIDEO_FOLDER="/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/samurai/samed_videos" 26 | 27 | # Define task files and frame numbers 28 | TASK_FILES=("example.json") 29 | FRAME_COUNTS=("16") 30 | 31 | # Loop through each task file and each frame count 32 | for TASK in "${TASK_FILES[@]}"; do 33 | for FRAME_COUNT in "${FRAME_COUNTS[@]}"; do 34 | QA_FILE_PATH="/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/samurai/QAs/$TASK" 35 | 36 | # Execute the Python script with the provided arguments 37 | python "$PYTHON_FILE_PATH" \ 38 | --model_path "$MODEL_PATH" \ 39 | --QA_file_path "$QA_FILE_PATH" \ 40 | --video_folder "$VIDEO_FOLDER" \ 41 | --answers_output_folder "$ANSWERS_OUTPUT_FOLDER" \ 42 | --extract_frames_method "max_frames_num" \ 43 | --max_frames_num "$FRAME_COUNT" \ 44 | --frames_from "video" \ 45 | # --provide_boundaries 46 | done 47 | done 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib==3.7 2 | moviepy==1.0.3 3 | accelerate>=0.26.0 4 | numpy==1.26.1 5 | tikzplotlib 6 | jpeg4py 7 | opencv-python 8 | lmdb 9 | pandas 10 | scipy 11 | loguru 12 | einops 13 | transformers==4.40.1 14 | timm 15 | decord 16 | imageio 17 | scenedetect 18 | SentencePiece 19 | gradio -------------------------------------------------------------------------------- /sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from hydra import initialize_config_module 8 | from hydra.core.global_hydra import GlobalHydra 9 | 10 | if not GlobalHydra.instance().is_initialized(): 11 | initialize_config_module("sam2", version_base="1.2") 12 | -------------------------------------------------------------------------------- /sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import os 9 | 10 | import torch 11 | from hydra import compose 12 | from hydra.utils import instantiate 13 | from omegaconf import OmegaConf 14 | 15 | import sam2 16 | 17 | # Check if the user is running Python from the parent directory of the sam2 repo 18 | # (i.e. the directory where this repo is cloned into) -- this is not supported since 19 | # it could shadow the sam2 package and cause issues. 20 | if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")): 21 | # If the user has "sam2/sam2" in their path, they are likely importing the repo itself 22 | # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory). 23 | # This typically happens because the user is running Python from the parent directory 24 | # that contains the sam2 repo they cloned. 25 | raise RuntimeError( 26 | "You're likely running Python from the parent directory of the sam2 repository " 27 | "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). " 28 | "This is not supported since the `sam2` Python package could be shadowed by the " 29 | "repository name (the repository is also named `sam2` and contains the Python package " 30 | "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir " 31 | "rather than its parent dir, or from your home directory) after installing SAM 2."
32 | ) 33 | 34 | 35 | HF_MODEL_ID_TO_FILENAMES = { 36 | "facebook/sam2-hiera-tiny": ( 37 | "configs/sam2/sam2_hiera_t.yaml", 38 | "sam2_hiera_tiny.pt", 39 | ), 40 | "facebook/sam2-hiera-small": ( 41 | "configs/sam2/sam2_hiera_s.yaml", 42 | "sam2_hiera_small.pt", 43 | ), 44 | "facebook/sam2-hiera-base-plus": ( 45 | "configs/sam2/sam2_hiera_b+.yaml", 46 | "sam2_hiera_base_plus.pt", 47 | ), 48 | "facebook/sam2-hiera-large": ( 49 | "configs/sam2/sam2_hiera_l.yaml", 50 | "sam2_hiera_large.pt", 51 | ), 52 | "facebook/sam2.1-hiera-tiny": ( 53 | "configs/sam2.1/sam2.1_hiera_t.yaml", 54 | "sam2.1_hiera_tiny.pt", 55 | ), 56 | "facebook/sam2.1-hiera-small": ( 57 | "configs/sam2.1/sam2.1_hiera_s.yaml", 58 | "sam2.1_hiera_small.pt", 59 | ), 60 | "facebook/sam2.1-hiera-base-plus": ( 61 | "configs/sam2.1/sam2.1_hiera_b+.yaml", 62 | "sam2.1_hiera_base_plus.pt", 63 | ), 64 | "facebook/sam2.1-hiera-large": ( 65 | "configs/sam2.1/sam2.1_hiera_l.yaml", 66 | "sam2.1_hiera_large.pt", 67 | ), 68 | } 69 | 70 | 71 | def build_sam2( 72 | config_file, 73 | ckpt_path=None, 74 | device="cuda", 75 | mode="eval", 76 | hydra_overrides_extra=[], 77 | apply_postprocessing=True, 78 | **kwargs, 79 | ): 80 | 81 | if apply_postprocessing: 82 | hydra_overrides_extra = hydra_overrides_extra.copy() 83 | hydra_overrides_extra += [ 84 | # dynamically fall back to multi-mask if the single mask is not stable 85 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 86 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 87 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 88 | ] 89 | # Read config and init model 90 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 91 | OmegaConf.resolve(cfg) 92 | model = instantiate(cfg.model, _recursive_=True) 93 | _load_checkpoint(model, ckpt_path) 94 | model = model.to(device) 95 | if mode == "eval": 96 | model.eval() 97 | return model 98 | 99 | 100 | def build_sam2_video_predictor( 101 | config_file, 102 | ckpt_path=None, 103 | device="cuda", 104 | mode="eval", 105 | hydra_overrides_extra=[], 106 | apply_postprocessing=True, 107 | **kwargs, 108 | ): 109 | hydra_overrides = [ 110 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", 111 | ] 112 | if apply_postprocessing: 113 | hydra_overrides_extra = hydra_overrides_extra.copy() 114 | hydra_overrides_extra += [ 115 | # dynamically fall back to multi-mask if the single mask is not stable 116 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 117 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 118 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 119 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 120 | "++model.binarize_mask_from_pts_for_mem_enc=true", 121 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 122 | "++model.fill_hole_area=8", 123 | ] 124 | hydra_overrides.extend(hydra_overrides_extra) 125 | 126 | # Read config and init model 127 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 128 | OmegaConf.resolve(cfg) 129 | model = instantiate(cfg.model, _recursive_=True) 130 | _load_checkpoint(model, ckpt_path) 131 | model = model.to(device) 132 | if mode == "eval": 133 | model.eval() 134 | return model 135 | 136 | 137 | def _hf_download(model_id): 138 
| from huggingface_hub import hf_hub_download 139 | 140 | config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id] 141 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name) 142 | return config_name, ckpt_path 143 | 144 | 145 | def build_sam2_hf(model_id, **kwargs): 146 | config_name, ckpt_path = _hf_download(model_id) 147 | return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs) 148 | 149 | 150 | def build_sam2_video_predictor_hf(model_id, **kwargs): 151 | config_name, ckpt_path = _hf_download(model_id) 152 | return build_sam2_video_predictor( 153 | config_file=config_name, ckpt_path=ckpt_path, **kwargs 154 | ) 155 | 156 | 157 | def _load_checkpoint(model, ckpt_path): 158 | if ckpt_path is not None: 159 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] 160 | missing_keys, unexpected_keys = model.load_state_dict(sd) 161 | if missing_keys: 162 | logging.error(missing_keys) 163 | raise RuntimeError() 164 | if unexpected_keys: 165 | logging.error(unexpected_keys) 166 | raise RuntimeError() 167 | logging.info("Loaded checkpoint sucessfully") 168 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 
| padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | no_obj_embed_spatial: true 93 | # use high-resolution feature map in the SAM mask decoder 94 | use_high_res_features_in_sam: true 95 | # output 3 masks on the first click on initial conditioning frames 96 | multimask_output_in_sam: true 97 | # SAM heads 98 | iou_prediction_use_sigmoid: True 99 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 100 | use_obj_ptrs_in_encoder: true 101 | add_tpos_enc_to_obj_ptrs: true 102 | proj_tpos_enc_in_obj_ptrs: true 103 | use_signed_tpos_enc_to_obj_ptrs: true 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | 
position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | compile_image_encoder: False 121 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | 
pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | no_obj_embed_spatial: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: true 105 | proj_tpos_enc_in_obj_ptrs: true 106 | use_signed_tpos_enc_to_obj_ptrs: true 107 | only_obj_ptrs_in_the_past_for_eval: true 108 | # object occlusion prediction 109 | pred_obj_scores: true 110 | pred_obj_scores_mlp: true 111 | fixed_no_obj_ptr: true 112 | # multimask tracking settings 113 | multimask_output_for_tracking: true 114 | use_multimask_token_for_obj_ptr: true 115 | multimask_min_pt_num: 0 116 | multimask_max_pt_num: 1 117 | use_mlp_for_obj_ptr_proj: true 118 | # Compilation flag 119 | compile_image_encoder: False 120 | -------------------------------------------------------------------------------- /sam2/configs/sam2.1/sam2.1_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: 
sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | # HieraT does not currently support compilation, should always be set to False 121 | compile_image_encoder: False 122 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 
18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 
2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_s.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | 
use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2/configs/sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | 
use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /sam2/configs/samurai/sam2.1_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | no_obj_embed_spatial: true 93 | # use high-resolution feature map in the SAM mask 
decoder 94 | use_high_res_features_in_sam: true 95 | # output 3 masks on the first click on initial conditioning frames 96 | multimask_output_in_sam: true 97 | # SAM heads 98 | iou_prediction_use_sigmoid: True 99 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 100 | use_obj_ptrs_in_encoder: true 101 | add_tpos_enc_to_obj_ptrs: true 102 | proj_tpos_enc_in_obj_ptrs: true 103 | use_signed_tpos_enc_to_obj_ptrs: true 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | # SAMURAI 118 | samurai_mode: true 119 | stable_frames_threshold: 15 120 | stable_ious_threshold: 0.3 121 | min_obj_score_logits: -1 122 | kf_score_weight: 0.25 123 | memory_bank_iou_threshold: 0.5 124 | memory_bank_obj_score_threshold: 0.0 125 | memory_bank_kf_score_threshold: 0.0 126 | -------------------------------------------------------------------------------- /sam2/configs/samurai/sam2.1_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: 
sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | compile_image_encoder: False 121 | # SAMURAI 122 | samurai_mode: true 123 | stable_frames_threshold: 15 124 | stable_ious_threshold: 0.3 125 | min_obj_score_logits: -1 126 | kf_score_weight: 0.15 127 | memory_bank_iou_threshold: 0.5 128 | memory_bank_obj_score_threshold: 0.0 129 | memory_bank_kf_score_threshold: 0.0 -------------------------------------------------------------------------------- /sam2/configs/samurai/sam2.1_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | 
dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | no_obj_embed_spatial: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: true 105 | proj_tpos_enc_in_obj_ptrs: true 106 | use_signed_tpos_enc_to_obj_ptrs: true 107 | only_obj_ptrs_in_the_past_for_eval: true 108 | # object occlusion prediction 109 | pred_obj_scores: true 110 | pred_obj_scores_mlp: true 111 | fixed_no_obj_ptr: true 112 | # multimask tracking settings 113 | multimask_output_for_tracking: true 114 | use_multimask_token_for_obj_ptr: true 115 | multimask_min_pt_num: 0 116 | multimask_max_pt_num: 1 117 | use_mlp_for_obj_ptr_proj: true 118 | # Compilation flag 119 | compile_image_encoder: False 120 | # SAMURAI 121 | samurai_mode: true 122 | stable_frames_threshold: 15 123 | stable_ious_threshold: 0.3 124 | min_obj_score_logits: -1 125 | kf_score_weight: 0.25 126 | memory_bank_iou_threshold: 0.5 127 | memory_bank_obj_score_threshold: 0.0 128 | memory_bank_kf_score_threshold: 0.0 -------------------------------------------------------------------------------- /sam2/configs/samurai/sam2.1_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 
10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | no_obj_embed_spatial: true 97 | # use high-resolution feature map in the SAM mask decoder 98 | use_high_res_features_in_sam: true 99 | # output 3 masks on the first click on initial conditioning frames 100 | multimask_output_in_sam: true 101 | # SAM heads 102 | iou_prediction_use_sigmoid: True 103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 104 | use_obj_ptrs_in_encoder: true 105 | add_tpos_enc_to_obj_ptrs: true 106 | proj_tpos_enc_in_obj_ptrs: true 107 | use_signed_tpos_enc_to_obj_ptrs: true 108 | only_obj_ptrs_in_the_past_for_eval: true 109 | # object occlusion prediction 110 | pred_obj_scores: true 111 | pred_obj_scores_mlp: true 112 | fixed_no_obj_ptr: true 113 | # multimask tracking settings 114 | multimask_output_for_tracking: true 115 | use_multimask_token_for_obj_ptr: true 116 | multimask_min_pt_num: 0 117 | multimask_max_pt_num: 1 118 | use_mlp_for_obj_ptr_proj: true 119 | # Compilation flag 120 | # HieraT does not currently support compilation, should always be set to False 121 | compile_image_encoder: False 122 | # SAMURAI 123 | samurai_mode: true 124 | stable_frames_threshold: 15 125 | stable_ious_threshold: 0.3 126 | min_obj_score_logits: -1 127 | kf_score_weight: 0.25 128 | memory_bank_iou_threshold: 0.5 129 | memory_bank_obj_score_threshold: 0.0 130 | memory_bank_kf_score_threshold: 0.0 
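# Example usage (sketch): a config like this is consumed through Hydra by the SAM 2
# builders; the config name and checkpoint filename below are assumptions and may need
# adjusting to the local search path and downloaded weights.
#   from sam2.build_sam import build_sam2_video_predictor
#   predictor = build_sam2_video_predictor(
#       "configs/samurai/sam2.1_hiera_t.yaml",
#       ckpt_path="checkpoints/sam2.1_hiera_tiny.pt",
#   )
# The SAMURAI-specific keys above (samurai_mode, stable_* thresholds, kf_score_weight,
# memory_bank_* thresholds) configure the Kalman-filter motion scoring and memory-bank
# selection layered on top of the base SAM 2 tracker.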
-------------------------------------------------------------------------------- /sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | self.d_model = d_model 75 | for dim in backbone_channel_list: 76 | current = nn.Sequential() 77 | current.add_module( 78 | "conv", 79 | nn.Conv2d( 80 | in_channels=dim, 81 | out_channels=d_model, 82 | kernel_size=kernel_size, 83 | stride=stride, 84 | padding=padding, 85 | ), 86 | ) 87 | 88 | self.convs.append(current) 89 | 
self.fpn_interp_model = fpn_interp_model 90 | assert fuse_type in ["sum", "avg"] 91 | self.fuse_type = fuse_type 92 | 93 | # levels to have top-down features in its outputs 94 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 95 | # have top-down propagation, while outputs of level 0 and level 1 have only 96 | # lateral features from the same backbone level. 97 | if fpn_top_down_levels is None: 98 | # default is to have top-down features on all levels 99 | fpn_top_down_levels = range(len(self.convs)) 100 | self.fpn_top_down_levels = list(fpn_top_down_levels) 101 | 102 | def forward(self, xs: List[torch.Tensor]): 103 | 104 | out = [None] * len(self.convs) 105 | pos = [None] * len(self.convs) 106 | assert len(xs) == len(self.convs) 107 | # fpn forward pass 108 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 109 | prev_features = None 110 | # forward in top-down order (from low to high resolution) 111 | n = len(self.convs) - 1 112 | for i in range(n, -1, -1): 113 | x = xs[i] 114 | lateral_features = self.convs[n - i](x) 115 | if i in self.fpn_top_down_levels and prev_features is not None: 116 | top_down_features = F.interpolate( 117 | prev_features.to(dtype=torch.float32), 118 | scale_factor=2.0, 119 | mode=self.fpn_interp_model, 120 | align_corners=( 121 | None if self.fpn_interp_model == "nearest" else False 122 | ), 123 | antialias=False, 124 | ) 125 | prev_features = lateral_features + top_down_features 126 | if self.fuse_type == "avg": 127 | prev_features /= 2 128 | else: 129 | prev_features = lateral_features 130 | x_out = prev_features 131 | out[i] = x_out 132 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 133 | 134 | return out, pos 135 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 
47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /sam2/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
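# Overview (sketch): with the hyper-parameters used in the configs above (d_model=256,
# num_layers=4, dim_feedforward=2048, memory key/value tokens of dim kv_in_dim=64),
# MemoryAttention stacks num_layers clones of MemoryAttentionLayer; each layer runs
# RoPE self-attention over the current-frame tokens, then RoPE cross-attention from
# those tokens to the memory (and object-pointer) tokens, then a feed-forward MLP,
# with a LayerNorm applied to the final output.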
6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from sam2.modeling.sam.transformer import RoPEAttention 13 | 14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones 15 | 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | return tgt 65 | 66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | tgt2 = self.cross_attn_image( 75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 77 | v=memory, 78 | **kwds, 79 | ) 80 | tgt = tgt + self.dropout2(tgt2) 81 | return tgt 82 | 83 | def forward( 84 | self, 85 | tgt, 86 | memory, 87 | pos: Optional[Tensor] = None, 88 | query_pos: Optional[Tensor] = None, 89 | num_k_exclude_rope: int = 0, 90 | ) -> torch.Tensor: 91 | 92 | # Self-Attn, Cross-Attn 93 | tgt = self._forward_sa(tgt, query_pos) 94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 95 | # MLP 96 | tgt2 = self.norm3(tgt) 97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 98 | tgt = tgt + self.dropout3(tgt2) 99 | return tgt 100 | 101 | 102 | class MemoryAttention(nn.Module): 103 | def __init__( 104 | self, 105 | d_model: int, 106 | pos_enc_at_input: bool, 107 | layer: nn.Module, 108 | num_layers: int, 109 | batch_first: bool = True, # Do layers expect batch first input? 
110 | ): 111 | super().__init__() 112 | self.d_model = d_model 113 | self.layers = get_clones(layer, num_layers) 114 | self.num_layers = num_layers 115 | self.norm = nn.LayerNorm(d_model) 116 | self.pos_enc_at_input = pos_enc_at_input 117 | self.batch_first = batch_first 118 | 119 | def forward( 120 | self, 121 | curr: torch.Tensor, # self-attention inputs 122 | memory: torch.Tensor, # cross-attention inputs 123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 126 | ): 127 | if isinstance(curr, list): 128 | assert isinstance(curr_pos, list) 129 | assert len(curr) == len(curr_pos) == 1 130 | curr, curr_pos = ( 131 | curr[0], 132 | curr_pos[0], 133 | ) 134 | 135 | assert ( 136 | curr.shape[1] == memory.shape[1] 137 | ), "Batch size must be the same for curr and memory" 138 | 139 | output = curr 140 | if self.pos_enc_at_input and curr_pos is not None: 141 | output = output + 0.1 * curr_pos 142 | 143 | if self.batch_first: 144 | # Convert to batch first 145 | output = output.transpose(0, 1) 146 | curr_pos = curr_pos.transpose(0, 1) 147 | memory = memory.transpose(0, 1) 148 | memory_pos = memory_pos.transpose(0, 1) 149 | 150 | for layer in self.layers: 151 | kwds = {} 152 | if isinstance(layer.cross_attn_image, RoPEAttention): 153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 154 | 155 | output = layer( 156 | tgt=output, 157 | memory=memory, 158 | pos=memory_pos, 159 | query_pos=curr_pos, 160 | **kwds, 161 | ) 162 | normed_output = self.norm(output) 163 | 164 | if self.batch_first: 165 | # Convert back to seq first 166 | normed_output = normed_output.transpose(0, 1) 167 | curr_pos = curr_pos.transpose(0, 1) 168 | 169 | return normed_output 170 | -------------------------------------------------------------------------------- /sam2/modeling/memory_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 23 | In the end, we linearly project to embed_dim channels. 
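For example, with the values used in the memory_encoder configs above (stride=2 with the default total_stride=16, hence 4 strided convs), the channel count grows 1 -> 4 -> 16 -> 64 -> 256 before the final 1x1 projection to embed_dim.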
24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | 
self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> Tuple[torch.Tensor, torch.Tensor]: 164 | ## Process masks 165 | # sigmoid, so that less domain shift from gt masks which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, cast them to CUDA 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} 182 | -------------------------------------------------------------------------------- /sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/sam/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom 13 | 14 | from sam2.modeling.sam2_utils import LayerNorm2d 15 | 16 | 17 | class PromptEncoder(nn.Module): 18 | def __init__( 19 | self, 20 | embed_dim: int, 21 | image_embedding_size: Tuple[int, int], 22 | input_image_size: Tuple[int, int], 23 | mask_in_chans: int, 24 | activation: Type[nn.Module] = nn.GELU, 25 | ) -> None: 26 | """ 27 | Encodes prompts for input to SAM's mask decoder. 28 | 29 | Arguments: 30 | embed_dim (int): The prompts' embedding dimension 31 | image_embedding_size (tuple(int, int)): The spatial size of the 32 | image embedding, as (H, W). 33 | input_image_size (int): The padded size of the image as input 34 | to the image encoder, as (H, W). 35 | mask_in_chans (int): The number of hidden channels used for 36 | encoding input masks. 37 | activation (nn.Module): The activation to use when encoding 38 | input masks. 
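For the 1024x1024 image size configured above, these are typically embed_dim=256, image_embedding_size=(64, 64), input_image_size=(1024, 1024) and mask_in_chans=16 (values assumed from the standard SAM2 build, not set in this file).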
39 | """ 40 | super().__init__() 41 | self.embed_dim = embed_dim 42 | self.input_image_size = input_image_size 43 | self.image_embedding_size = image_embedding_size 44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 45 | 46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 47 | point_embeddings = [ 48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 49 | ] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = ( 54 | 4 * image_embedding_size[0], 55 | 4 * image_embedding_size[1], 56 | ) 57 | self.mask_downscaling = nn.Sequential( 58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans // 4), 60 | activation(), 61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 62 | LayerNorm2d(mask_in_chans), 63 | activation(), 64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 65 | ) 66 | self.no_mask_embed = nn.Embedding(1, embed_dim) 67 | 68 | def get_dense_pe(self) -> torch.Tensor: 69 | """ 70 | Returns the positional encoding used to encode point prompts, 71 | applied to a dense set of points the shape of the image encoding. 72 | 73 | Returns: 74 | torch.Tensor: Positional encoding with shape 75 | 1x(embed_dim)x(embedding_h)x(embedding_w) 76 | """ 77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 78 | 79 | def _embed_points( 80 | self, 81 | points: torch.Tensor, 82 | labels: torch.Tensor, 83 | pad: bool, 84 | ) -> torch.Tensor: 85 | """Embeds point prompts.""" 86 | points = points + 0.5 # Shift to center of pixel 87 | if pad: 88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 90 | points = torch.cat([points, padding_point], dim=1) 91 | labels = torch.cat([labels, padding_label], dim=1) 92 | point_embedding = self.pe_layer.forward_with_coords( 93 | points, self.input_image_size 94 | ) 95 | point_embedding[labels == -1] = 0.0 96 | point_embedding[labels == -1] += self.not_a_point_embed.weight 97 | point_embedding[labels == 0] += self.point_embeddings[0].weight 98 | point_embedding[labels == 1] += self.point_embeddings[1].weight 99 | point_embedding[labels == 2] += self.point_embeddings[2].weight 100 | point_embedding[labels == 3] += self.point_embeddings[3].weight 101 | return point_embedding 102 | 103 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 104 | """Embeds box prompts.""" 105 | boxes = boxes + 0.5 # Shift to center of pixel 106 | coords = boxes.reshape(-1, 2, 2) 107 | corner_embedding = self.pe_layer.forward_with_coords( 108 | coords, self.input_image_size 109 | ) 110 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 111 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 112 | return corner_embedding 113 | 114 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 115 | """Embeds mask inputs.""" 116 | mask_embedding = self.mask_downscaling(masks) 117 | return mask_embedding 118 | 119 | def _get_batch_size( 120 | self, 121 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 122 | boxes: Optional[torch.Tensor], 123 | masks: Optional[torch.Tensor], 124 | ) -> int: 125 | """ 126 | Gets the batch size of the output given the batch size of the input prompts. 
127 | """ 128 | if points is not None: 129 | return points[0].shape[0] 130 | elif boxes is not None: 131 | return boxes.shape[0] 132 | elif masks is not None: 133 | return masks.shape[0] 134 | else: 135 | return 1 136 | 137 | def _get_device(self) -> torch.device: 138 | return self.point_embeddings[0].weight.device 139 | 140 | def forward( 141 | self, 142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 143 | boxes: Optional[torch.Tensor], 144 | masks: Optional[torch.Tensor], 145 | ) -> Tuple[torch.Tensor, torch.Tensor]: 146 | """ 147 | Embeds different types of prompts, returning both sparse and dense 148 | embeddings. 149 | 150 | Arguments: 151 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 152 | and labels to embed. 153 | boxes (torch.Tensor or none): boxes to embed 154 | masks (torch.Tensor or none): masks to embed 155 | 156 | Returns: 157 | torch.Tensor: sparse embeddings for the points and boxes, with shape 158 | BxNx(embed_dim), where N is determined by the number of input points 159 | and boxes. 160 | torch.Tensor: dense embeddings for the masks, in the shape 161 | Bx(embed_dim)x(embed_H)x(embed_W) 162 | """ 163 | bs = self._get_batch_size(points, boxes, masks) 164 | sparse_embeddings = torch.empty( 165 | (bs, 0, self.embed_dim), device=self._get_device() 166 | ) 167 | if points is not None: 168 | coords, labels = points 169 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 170 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 171 | if boxes is not None: 172 | box_embeddings = self._embed_boxes(boxes) 173 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 174 | 175 | if masks is not None: 176 | dense_embeddings = self._embed_masks(masks) 177 | else: 178 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 179 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 180 | ) 181 | 182 | return sparse_embeddings, dense_embeddings 183 | -------------------------------------------------------------------------------- /sam2/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_b+.yaml -------------------------------------------------------------------------------- /sam2/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_l.yaml -------------------------------------------------------------------------------- /sam2/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_s.yaml -------------------------------------------------------------------------------- /sam2/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | configs/sam2/sam2_hiera_t.yaml -------------------------------------------------------------------------------- /sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import warnings 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from torchvision.transforms import Normalize, Resize, ToTensor 13 | 14 | 15 | class SAM2Transforms(nn.Module): 16 | def __init__( 17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 18 | ): 19 | """ 20 | Transforms for SAM2. 21 | """ 22 | super().__init__() 23 | self.resolution = resolution 24 | self.mask_threshold = mask_threshold 25 | self.max_hole_area = max_hole_area 26 | self.max_sprinkle_area = max_sprinkle_area 27 | self.mean = [0.485, 0.456, 0.406] 28 | self.std = [0.229, 0.224, 0.225] 29 | self.to_tensor = ToTensor() 30 | self.transforms = torch.jit.script( 31 | nn.Sequential( 32 | Resize((self.resolution, self.resolution)), 33 | Normalize(self.mean, self.std), 34 | ) 35 | ) 36 | 37 | def __call__(self, x): 38 | x = self.to_tensor(x) 39 | return self.transforms(x) 40 | 41 | def forward_batch(self, img_list): 42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 43 | img_batch = torch.stack(img_batch, dim=0) 44 | return img_batch 45 | 46 | def transform_coords( 47 | self, coords: torch.Tensor, normalize=False, orig_hw=None 48 | ) -> torch.Tensor: 49 | """ 50 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates. 51 | If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required. 52 | 53 | Returns 54 | Coordinates scaled to the model's input resolution (multiplied by self.resolution), as expected by the SAM2 model. 55 | """ 56 | if normalize: 57 | assert orig_hw is not None 58 | h, w = orig_hw 59 | coords = coords.clone() 60 | coords[..., 0] = coords[..., 0] / w 61 | coords[..., 1] = coords[..., 1] / h 62 | 63 | coords = coords * self.resolution # unnormalize coords 64 | return coords 65 | 66 | def transform_boxes( 67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 68 | ) -> torch.Tensor: 69 | """ 70 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates. 71 | If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required. 72 | """ 73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 74 | return boxes 75 | 76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 77 | """ 78 | Perform post-processing on output masks. 79 | """ 80 | from sam2.utils.misc import get_connected_components 81 | 82 | masks = masks.float() 83 | input_masks = masks 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | try: 86 | if self.max_hole_area > 0: 87 | # Holes are those connected components in background with area <= self.max_hole_area 88 | # (background regions are those with mask scores <= self.mask_threshold) 89 | labels, areas = get_connected_components( 90 | mask_flat <= self.mask_threshold 91 | ) 92 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 93 | is_hole = is_hole.reshape_as(masks) 94 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 
95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 96 | 97 | if self.max_sprinkle_area > 0: 98 | labels, areas = get_connected_components( 99 | mask_flat > self.mask_threshold 100 | ) 101 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 102 | is_hole = is_hole.reshape_as(masks) 103 | # We fill holes with negative mask score (-10.0) to change them to background. 104 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 105 | except Exception as e: 106 | # Skip the post-processing step if the CUDA kernel fails 107 | warnings.warn( 108 | f"{e}\n\nSkipping the post-processing step due to the error above. You can " 109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing " 110 | "functionality may be limited (which doesn't affect the results in most cases; see " 111 | "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).", 112 | category=UserWarning, 113 | stacklevel=2, 114 | ) 115 | masks = input_masks 116 | 117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 118 | return masks 119 | -------------------------------------------------------------------------------- /scripts/dog.txt: -------------------------------------------------------------------------------- 1 | 450, 350, 250, 200 2 | 200, 100, 250, 400 -------------------------------------------------------------------------------- /scripts/get_boundary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | import json 4 | import sys 5 | import os 6 | import argparse 7 | 8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) 9 | from trace.conversation import conv_templates, SeparatorStyle 10 | from trace.constants import DEFAULT_MMODAL_TOKEN, MMODAL_TOKEN_INDEX 11 | from trace.mm_utils import get_model_name_from_path, tokenizer_MMODAL_token_all, process_video, process_image, KeywordsStoppingCriteria 12 | from trace.model.builder import load_pretrained_model 13 | 14 | 15 | def inference(args): 16 | # Video Inference 17 | paths = args.video_paths 18 | questions = args.questions 19 | modal_list = ['video'] 20 | 21 | # 1. Initialize the model. 22 | model_path = args.model_path 23 | model_name = get_model_name_from_path(model_path) 24 | tokenizer, model, processor, context_len = load_pretrained_model(model_path, None, model_name) 25 | model = model.to('cuda') 26 | conv_mode = 'llama_2' 27 | 28 | # 2. Visual preprocess (load & transform image or video). 29 | if modal_list[0] == 'video': 30 | tensor, video_timestamps = process_video(paths[0], processor, model.config.image_aspect_ratio, num_frames=64) 31 | tensor = tensor.to(dtype=torch.float16, device='cuda', non_blocking=True) 32 | default_mm_token = DEFAULT_MMODAL_TOKEN["VIDEO"] 33 | modal_token_index = MMODAL_TOKEN_INDEX["VIDEO"] 34 | else: 35 | tensor = process_image(paths[0], processor, model.config.image_aspect_ratio)[0].to(dtype=torch.float16, device='cuda', non_blocking=True) 36 | default_mm_token = DEFAULT_MMODAL_TOKEN["IMAGE"] 37 | modal_token_index = MMODAL_TOKEN_INDEX["IMAGE"] 38 | 39 | tensor = [tensor] 40 | video_timestamps = [video_timestamps] 41 | heads = [1] 42 | 43 | # 3. Text preprocess (tag process & generate prompt). 
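# The video placeholder token is prepended to the user question and then wrapped by the
# 'llama_2' conversation template from trace/conversation.py; the resulting prompt is
# roughly of the form "[INST] <<SYS>> ... <</SYS>> {video_token}\n{question} [/INST]"
# (the exact system prompt and separators come from that template definition).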
44 | question = default_mm_token + "\n" + questions[0] 45 | conv = conv_templates[conv_mode].copy() 46 | conv.append_message(conv.roles[0], question) 47 | conv.append_message(conv.roles[1], None) 48 | prompt = conv.get_prompt() 49 | prompt += '' 50 | print(prompt) 51 | input_ids = tokenizer_MMODAL_token_all(prompt, tokenizer, return_tensors='pt').unsqueeze(0).to('cuda') 52 | attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda() 53 | stop_str = conv.sep if conv.sep_style in [SeparatorStyle.SINGLE] else conv.sep2 54 | keywords = [stop_str] 55 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 56 | do_sample = True 57 | 58 | with torch.inference_mode(): 59 | output_ids = model.generate( 60 | input_ids, 61 | attention_mask=attention_masks, 62 | images_or_videos=tensor, 63 | modal_list=modal_list, 64 | do_sample=do_sample, 65 | temperature=0.2 if do_sample else 0.0, 66 | max_new_tokens=1024, 67 | use_cache=True, 68 | pad_token_id=tokenizer.eos_token_id, 69 | video_timestamps=video_timestamps, 70 | heads=heads 71 | ) 72 | 73 | outputs = { 74 | 'timestamps': [], 75 | 'scores': [], 76 | 'captions': [], 77 | } 78 | cur_timestamps = [] 79 | cur_timestamp = [] 80 | cur_scores = [] 81 | cur_score = [] 82 | cur_caption = [] 83 | for idx in output_ids[0]: 84 | if idx <= 32000: 85 | if idx == 32000: 86 | new_caption = tokenizer.decode(cur_caption, skip_special_tokens=True) 87 | outputs['captions'].append(new_caption) 88 | cur_caption = [] 89 | else: 90 | cur_caption.append(idx) 91 | elif idx <= 32013: # 32001 ; 32002 92 | if idx == 32001: 93 | if len(cur_timestamp) > 0: 94 | cur_timestamps.append(float(''.join(cur_timestamp))) 95 | outputs['timestamps'].append(cur_timestamps) 96 | cur_timestamps = [] 97 | cur_timestamp = [] 98 | elif idx == 32002: 99 | if len(cur_timestamp) > 0: 100 | cur_timestamps.append(float(''.join(cur_timestamp))) 101 | cur_timestamp = [] 102 | else: 103 | cur_timestamp.append(model.get_model().time_tokenizer.decode(idx - 32001)) 104 | else: # 32014 ; 32015 105 | if idx == 32014: 106 | if len(cur_score) > 0: 107 | cur_scores.append(float(''.join(cur_score))) 108 | outputs['scores'].append(cur_scores) 109 | cur_scores = [] 110 | cur_score = [] 111 | elif idx == 32015: 112 | if len(cur_score) > 0: 113 | cur_scores.append(float(''.join(cur_score))) 114 | cur_score = [] 115 | else: 116 | cur_score.append(model.get_model().score_tokenizer.decode(idx - 32014)) 117 | if len(cur_caption): 118 | outputs['captions'].append(tokenizer.decode(cur_caption, skip_special_tokens=True)) 119 | 120 | try: 121 | results = [] 122 | for i in range(len(outputs['timestamps'])): 123 | output = { 124 | 'video': paths[0].split("/")[-1][:-4] + "_mask.mp4", 125 | 'segment': f"{outputs['timestamps'][i][0]}_{outputs['timestamps'][i][1]}", 126 | 'question': "", 127 | 'answer': outputs['captions'][i], 128 | } 129 | results.append(output) 130 | 131 | with open(f'./results/{paths[0].split("/")[-1].split(".")[0]}_boundary.json', 'w') as f: 132 | json.dump(results, f) 133 | 134 | except Exception as e: 135 | print(e) 136 | print("Failed to save the output to a json file.") 137 | with open(f'./results/{paths[0].split("/")[-1].split(".")[0]}_boundary.json', 'w') as f: 138 | json.dump([{"video": paths[0].split("/")[-1], "segment": f"0.0_{video_timestamps[0][1]}", "question": "", "answer": ""}], f) 139 | 140 | 141 | if __name__ == "__main__": 142 | parser = argparse.ArgumentParser(description="Inference script for boundary detection.") 143 | parser.add_argument("--video_paths", 
nargs='+', required=True, help="Paths to the input video files.") 144 | parser.add_argument("--questions", nargs='+', required=True, help="Questions for video inference.") 145 | parser.add_argument("--model_path", required=True, help="Path to the pretrained model.") 146 | args = parser.parse_args() 147 | 148 | inference(args) 149 | -------------------------------------------------------------------------------- /scripts/get_masks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import cv2 6 | import torch 7 | import gc 8 | from tqdm import tqdm 9 | import sys 10 | sys.path.append("./") 11 | from sam2.build_sam import build_sam2_video_predictor 12 | 13 | color = [(255, 0, 0)] 14 | 15 | def load_txt(gt_path): 16 | with open(gt_path, 'r') as f: 17 | gt = f.readlines() 18 | prompts = {} 19 | for fid, line in enumerate(gt): 20 | x_min, y_min, x_max, y_max = line.strip().split(",") 21 | # x, y, w, h = int(x), int(y), int(w), int(h) 22 | x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max) 23 | prompts[fid] = ((x_min, y_min, x_max, y_max), 0) 24 | return prompts 25 | 26 | def determine_model_cfg(model_path): 27 | if "large" in model_path: 28 | return "configs/samurai/sam2.1_hiera_l.yaml" 29 | elif "base_plus" in model_path: 30 | return "configs/samurai/sam2.1_hiera_b+.yaml" 31 | elif "small" in model_path: 32 | return "configs/samurai/sam2.1_hiera_s.yaml" 33 | elif "tiny" in model_path: 34 | return "configs/samurai/sam2.1_hiera_t.yaml" 35 | else: 36 | raise ValueError("Unknown model size in path!") 37 | 38 | def prepare_frames_or_path(video_path): 39 | if video_path.endswith(".mp4") or osp.isdir(video_path): 40 | return video_path 41 | else: 42 | raise ValueError("Invalid video_path format. Should be .mp4 or a directory of jpg frames.") 43 | 44 | def main(args): 45 | model_cfg = determine_model_cfg(args.model_path) 46 | predictor = build_sam2_video_predictor(model_cfg, args.model_path, device="cuda:0") 47 | frames_or_path = prepare_frames_or_path(args.video_path) 48 | prompts = load_txt(args.txt_path) 49 | print(prompts) 50 | 51 | if args.save_to_video: 52 | if osp.isdir(args.video_path): 53 | frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith(".jpg")]) 54 | loaded_frames = [cv2.imread(frame_path) for frame_path in frames] 55 | height, width = loaded_frames[0].shape[:2] 56 | else: 57 | cap = cv2.VideoCapture(args.video_path) 58 | loaded_frames = [] 59 | while True: 60 | ret, frame = cap.read() 61 | if not ret: 62 | break 63 | loaded_frames.append(frame) 64 | cap.release() 65 | height, width = loaded_frames[0].shape[:2] 66 | if len(loaded_frames) == 0: 67 | raise ValueError("No frames were loaded from the video.") 68 | 69 | 70 | 71 | 72 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 73 | out = cv2.VideoWriter(args.video_output_path+f"/{osp.basename(args.video_path).split('.')[0]}_mask.mp4", fourcc, 30, (width, height)) 74 | 75 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): 76 | state = predictor.init_state(frames_or_path, offload_video_to_cpu=True) 77 | bbox, track_label = prompts[0] 78 | _, _, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0) 79 | 80 | 81 | for frame_idx, object_ids, masks in tqdm(predictor.propagate_in_video(state)): 82 | # if frame_idx >= len(loaded_frames): 83 | # print(f"Frame index {frame_idx} out of range. 
Skipping.") 84 | # continue 85 | 86 | # img = loaded_frames[frame_idx] 87 | mask_to_vis = {} 88 | bbox_to_vis = {} 89 | 90 | for obj_id, mask in zip(object_ids, masks): 91 | mask = mask[0].cpu().numpy() 92 | mask = mask > 0.0 93 | non_zero_indices = np.argwhere(mask) 94 | if len(non_zero_indices) == 0: 95 | bbox = [0, 0, 0, 0] 96 | else: 97 | y_min, x_min = non_zero_indices.min(axis=0).tolist() 98 | y_max, x_max = non_zero_indices.max(axis=0).tolist() 99 | bbox = [x_min, y_min, x_max - x_min, y_max - y_min] 100 | bbox_to_vis[obj_id] = bbox 101 | mask_to_vis[obj_id] = mask 102 | 103 | if args.save_to_video: 104 | img = loaded_frames[frame_idx] 105 | for obj_id, mask in mask_to_vis.items(): 106 | mask_img = np.zeros((height, width, 3), np.uint8) 107 | mask_img[mask] = color[(obj_id + 1) % len(color)] 108 | img = cv2.addWeighted(img, 1, mask_img, 0.2, 0) 109 | 110 | 111 | for obj_id, bbox in bbox_to_vis.items(): 112 | cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), color[obj_id % len(color)], 2) 113 | 114 | # add the text to the bottom of EACH frame 115 | # The text depends on the current frame's position in the video, with one decimal place reserved for seconds 116 | # The font color is red, and the text is centered (so you should calculate the len of text) on the bottom like subtitle, occupying 1/5 of the frame height. 117 | # The text is displayed in the format "103.5s" 118 | # time = frame_idx / 30 119 | # time_text = f"{time:.1f}s" 120 | # font = cv2.FONT_HERSHEY_SIMPLEX 121 | # font_scale = 4 122 | # font_thickness = 8 123 | # font_color = (0, 0, 255) 124 | # text_size = cv2.getTextSize(time_text, font, font_scale, font_thickness)[0] 125 | # text_x = (width - text_size[0]) // 2 126 | # text_y = height - 5 127 | # cv2.putText(img, time_text, (text_x, text_y), font, font_scale, font_color, font_thickness) 128 | 129 | out.write(img) 130 | 131 | if args.save_to_video: 132 | out.release() 133 | 134 | del predictor, state 135 | gc.collect() 136 | torch.clear_autocast_cache() 137 | torch.cuda.empty_cache() 138 | 139 | if __name__ == "__main__": 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument("--video_path", default="./assets/demo.mp4", help="Input video path or directory of frames.") 142 | parser.add_argument("--txt_path", default="./assets/demo.txt", help="Path to ground truth text file.") 143 | parser.add_argument("--model_path", default="./checkpoints/sam2.1_hiera_base_plus.pt", help="Path to the model checkpoint.") 144 | parser.add_argument("--video_output_path", default="./results/", help="Path to save the output video.") 145 | parser.add_argument("--save_to_video", default=True, help="Save results to a video.") 146 | args = parser.parse_args() 147 | main(args) 148 | -------------------------------------------------------------------------------- /scripts/get_vis.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import json 3 | import sys 4 | from tqdm import tqdm 5 | 6 | def add_captions_to_video(video_input_path, json_path, video_output_path): 7 | # Load JSON data 8 | with open(json_path, 'r') as f: 9 | captions_data = json.load(f) 10 | 11 | # Open the input video 12 | cap = cv2.VideoCapture(video_input_path) 13 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 14 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 15 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 16 | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 17 | 18 | # Initialize video writer 19 | fourcc = 
cv2.VideoWriter_fourcc(*'mp4v') 20 | out = cv2.VideoWriter(video_output_path, fourcc, fps, (width, height)) 21 | 22 | # Helper function: Check if a frame is within a time segment 23 | def is_frame_in_segment(frame_idx, start, end, fps): 24 | timestamp = frame_idx / fps 25 | return start <= timestamp <= end 26 | 27 | # Helper function: Wrap text to fit within the video width 28 | def wrap_text(text, font, font_scale, thickness, max_width): 29 | words = text.split() 30 | lines = [] 31 | current_line = "" 32 | for word in words: 33 | test_line = current_line + " " + word if current_line else word 34 | text_size = cv2.getTextSize(test_line, font, font_scale, thickness)[0] 35 | if text_size[0] <= max_width: 36 | current_line = test_line 37 | else: 38 | lines.append(current_line) 39 | current_line = word 40 | if current_line: 41 | lines.append(current_line) 42 | return lines 43 | 44 | # Dynamic scaling based on video resolution 45 | def get_font_scale_and_thickness(width, height): 46 | base_width = 1280.0 # Reference width for scaling 47 | scale_factor = width / base_width 48 | font_scale = max(0.5 * scale_factor, 0.4) # Reduced default font size 49 | thickness = max(int(1.5 * scale_factor), 1) # Slightly thinner font 50 | return font_scale, thickness 51 | 52 | # Process video frames 53 | frame_idx = 0 54 | with tqdm(total=frame_count, desc="Processing video") as pbar: 55 | while cap.isOpened(): 56 | ret, frame = cap.read() 57 | if not ret: 58 | break 59 | 60 | # Find captions for the current frame 61 | for caption in captions_data: 62 | start_time = float(caption['segment'][0]) 63 | end_time = float(caption['segment'][1]) 64 | text = caption['model_answer'] 65 | 66 | if is_frame_in_segment(frame_idx, start_time, end_time, fps): 67 | # Get font scale and thickness dynamically 68 | font = cv2.FONT_HERSHEY_SIMPLEX 69 | font_scale, font_thickness = get_font_scale_and_thickness(width, height) 70 | text_color = (255, 255, 255) # White 71 | bg_color = (0, 0, 0, 150) # Black with alpha for transparency 72 | margin = int(10 * (height / 720)) # Adjust margin proportionally 73 | max_width = int(width * 0.95) # Wrap text at 85% of video width 74 | 75 | # Wrap text into multiple lines 76 | lines = wrap_text(text, font, font_scale, font_thickness, max_width) 77 | line_height = cv2.getTextSize("Test", font, font_scale, font_thickness)[0][1] + margin 78 | 79 | # Determine the text box position 80 | total_text_height = len(lines) * line_height 81 | text_x = (width - max_width) // 2 82 | text_y = height - margin - total_text_height 83 | 84 | # Create a transparent overlay 85 | overlay = frame.copy() 86 | cv2.rectangle( 87 | overlay, 88 | (text_x - margin, text_y - margin), 89 | (text_x + max_width + margin, text_y + total_text_height + margin), 90 | (0, 0, 0), # Black background 91 | -1 92 | ) 93 | 94 | # Add the transparent overlay to the frame 95 | alpha = 0.6 # Transparency factor for the background 96 | frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0) 97 | 98 | # Draw each line of text 99 | for i, line in enumerate(lines): 100 | line_y = text_y + (i * line_height) + line_height 101 | cv2.putText( 102 | frame, 103 | line, 104 | (text_x, line_y), 105 | font, 106 | font_scale, 107 | text_color, 108 | font_thickness, 109 | lineType=cv2.LINE_AA 110 | ) 111 | 112 | out.write(frame) 113 | frame_idx += 1 114 | pbar.update(1) 115 | 116 | # Release resources 117 | cap.release() 118 | out.release() 119 | print(f"Captioned video saved to: {video_output_path}") 120 | 121 | 122 | if __name__ == "__main__": 123 
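# Usage (three positional arguments, matching the sys.argv reads below):
#   python scripts/get_vis.py <input_video.mp4> <captions.json> <output_video.mp4>
# The JSON file is expected to be a list of entries, each with a 'segment'
# (start, end) pair in seconds and a 'model_answer' caption string.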
| # Read arguments 124 | video_input_path = sys.argv[1] 125 | json_path = sys.argv[2] 126 | video_output_path = sys.argv[3] 127 | 128 | add_captions_to_video(video_input_path, json_path, video_output_path) 129 | -------------------------------------------------------------------------------- /scripts/main_inference.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import gc 3 | import numpy as np 4 | import os 5 | import os.path as osp 6 | import pdb 7 | import torch 8 | from sam2.build_sam import build_sam2_video_predictor 9 | from tqdm import tqdm 10 | 11 | 12 | def load_lasot_gt(gt_path): 13 | with open(gt_path, 'r') as f: 14 | gt = f.readlines() 15 | 16 | # bbox in first frame are prompts 17 | prompts = {} 18 | fid = 0 19 | for line in gt: 20 | x, y, w, h = map(int, line.split(',')) 21 | prompts[fid] = ((x, y, x+w, y+h), 0) 22 | fid += 1 23 | 24 | return prompts 25 | 26 | color = [ 27 | (255, 0, 0), 28 | ] 29 | 30 | testing_set = "data/LaSOT/testing_set.txt" 31 | with open(testing_set, 'r') as f: 32 | test_videos = f.readlines() 33 | 34 | exp_name = "samurai" 35 | model_name = "base_plus" 36 | 37 | checkpoint = f"sam2/checkpoints/sam2.1_hiera_{model_name}.pt" 38 | if model_name == "base_plus": 39 | model_cfg = "configs/samurai/sam2.1_hiera_b+.yaml" 40 | else: 41 | model_cfg = f"configs/samurai/sam2.1_hiera_{model_name[0]}.yaml" 42 | 43 | video_folder= "data/LaSOT" 44 | pred_folder = f"results/{exp_name}/{exp_name}_{model_name}" 45 | 46 | save_to_video = True 47 | if save_to_video: 48 | vis_folder = f"visualization/{exp_name}/{model_name}" 49 | os.makedirs(vis_folder, exist_ok=True) 50 | vis_mask = {} 51 | vis_bbox = {} 52 | 53 | test_videos = sorted(test_videos) 54 | for vid, video in enumerate(test_videos): 55 | 56 | cat_name = video.split('-')[0] 57 | cid_name = video.split('-')[1] 58 | video_basename = video.strip() 59 | frame_folder = osp.join(video_folder, cat_name, video.strip(), "img") 60 | 61 | num_frames = len(os.listdir(osp.join(video_folder, cat_name, video.strip(), "img"))) 62 | 63 | print(f"\033[91mRunning video [{vid+1}/{len(test_videos)}]: {video} with {num_frames} frames\033[0m") 64 | 65 | height, width = cv2.imread(osp.join(frame_folder, "00000001.jpg")).shape[:2] 66 | 67 | predictor = build_sam2_video_predictor(model_cfg, checkpoint, device="cuda:0") 68 | 69 | predictions = [] 70 | 71 | if save_to_video: 72 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 73 | out = cv2.VideoWriter(osp.join(vis_folder, f'{video_basename}.mp4'), fourcc, 30, (width, height)) 74 | 75 | # Start processing frames 76 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16): 77 | state = predictor.init_state(frame_folder, offload_video_to_cpu=True, offload_state_to_cpu=True, async_loading_frames=True) 78 | 79 | prompts = load_lasot_gt(osp.join(video_folder, cat_name, video.strip(), "groundtruth.txt")) 80 | 81 | bbox, track_label = prompts[0] 82 | frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0) 83 | 84 | for frame_idx, object_ids, masks in predictor.propagate_in_video(state): 85 | mask_to_vis = {} 86 | bbox_to_vis = {} 87 | 88 | assert len(masks) == 1 and len(object_ids) == 1, "Only one object is supported right now" 89 | for obj_id, mask in zip(object_ids, masks): 90 | mask = mask[0].cpu().numpy() 91 | mask = mask > 0.0 92 | non_zero_indices = np.argwhere(mask) 93 | if len(non_zero_indices) == 0: 94 | bbox = [0, 0, 0, 0] 95 | else: 96 | y_min, x_min = 
non_zero_indices.min(axis=0).tolist() 97 | y_max, x_max = non_zero_indices.max(axis=0).tolist() 98 | bbox = [x_min, y_min, x_max-x_min, y_max-y_min] 99 | bbox_to_vis[obj_id] = bbox 100 | mask_to_vis[obj_id] = mask 101 | 102 | if save_to_video: 103 | 104 | img = cv2.imread(f'{frame_folder}/{frame_idx+1:08d}.jpg') 105 | if img is None: 106 | break 107 | 108 | for obj_id in mask_to_vis.keys(): 109 | mask_img = np.zeros((height, width, 3), np.uint8) 110 | mask_img[mask_to_vis[obj_id]] = color[(obj_id+1)%len(color)] 111 | img = cv2.addWeighted(img, 1, mask_img, 0.75, 0) 112 | 113 | for obj_id in bbox_to_vis.keys(): 114 | cv2.rectangle(img, (bbox_to_vis[obj_id][0], bbox_to_vis[obj_id][1]), (bbox_to_vis[obj_id][0]+bbox_to_vis[obj_id][2], bbox_to_vis[obj_id][1]+bbox_to_vis[obj_id][3]), color[(obj_id)%len(color)], 2) 115 | 116 | x1, y1, x2, y2 = prompts[frame_idx][0] 117 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2) 118 | out.write(img) 119 | 120 | predictions.append(bbox_to_vis) 121 | 122 | os.makedirs(pred_folder, exist_ok=True) 123 | with open(osp.join(pred_folder, f'{video_basename}.txt'), 'w') as f: 124 | for pred in predictions: 125 | x, y, w, h = pred[0] 126 | f.write(f"{x},{y},{w},{h}\n") 127 | 128 | if save_to_video: 129 | out.release() 130 | 131 | del predictor 132 | del state 133 | gc.collect() 134 | torch.clear_autocast_cache() 135 | torch.cuda.empty_cache() 136 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
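# Packaging script for CAT-2. In addition to the Python package, it optionally builds
# the SAM 2 CUDA extension (sam2._C); the build can be turned off with SAM2_BUILD_CUDA=0,
# and by default build errors are tolerated (SAM2_BUILD_ALLOW_ERRORS=1), as configured
# further down in this file.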
6 | import os 7 | 8 | from setuptools import find_packages, setup 9 | 10 | # Package metadata 11 | NAME = "CAT-2" 12 | VERSION = "1.0" 13 | DESCRIPTION = "Caption Anything in Video: Object-centric Dense Video Captioning with Multimodal Controls" 14 | URL = "https://github.com/yunlong10/CAT-2" 15 | AUTHOR = "Tang, Yunlong and Bi, Jing and Hua, Hang and Xiao, Yunzhong and Song, Yizhi and Wang, Teng and Huang, Chao and Feng, Mingqian and Guo, Junjia and Liu, Zhuo and Song, Luchuan and Liang, Susan and Wang, Bingjie and Shimada, Daiki and Vosoughi, Ali and Zhang, Zeliang and Luo, Jiebo and Xu, Chenliang" 16 | AUTHOR_EMAIL = "yunlong.tang@rochester.edu" 17 | LICENSE = "Apache 2.0" 18 | 19 | # Read the contents of README file 20 | with open("README.md", "r", encoding="utf-8") as f: 21 | LONG_DESCRIPTION = f.read() 22 | 23 | # Required dependencies 24 | REQUIRED_PACKAGES = [ 25 | "torch>=2.3.1", 26 | "torchvision>=0.18.1", 27 | "numpy>=1.24.4", 28 | "tqdm>=4.66.1", 29 | "hydra-core>=1.3.2", 30 | "iopath>=0.1.10", 31 | "pillow>=9.4.0", 32 | "matplotlib>=3.9.1", 33 | "moviepy==1.0.3", 34 | "accelerate>=0.26.0", 35 | "numpy==1.26.1", 36 | "tikzplotlib", 37 | "jpeg4py", 38 | "opencv-python", 39 | "lmdb", 40 | "pandas", 41 | "scipy", 42 | "loguru", 43 | "einops", 44 | "transformers==4.40.1", 45 | "timm", 46 | "decord", 47 | "imageio", 48 | "scenedetect", 49 | "SentencePiece", 50 | "gradio", 51 | ] 52 | 53 | EXTRA_PACKAGES = { 54 | "notebooks": [ 55 | "matplotlib>=3.9.1", 56 | "jupyter>=1.0.0", 57 | "opencv-python>=4.7.0", 58 | "eva-decord>=0.6.1", 59 | ], 60 | "interactive-demo": [ 61 | "Flask>=3.0.3", 62 | "Flask-Cors>=5.0.0", 63 | "av>=13.0.0", 64 | "dataclasses-json>=0.6.7", 65 | "eva-decord>=0.6.1", 66 | "gunicorn>=23.0.0", 67 | "imagesize>=1.4.1", 68 | "pycocotools>=2.0.8", 69 | "strawberry-graphql>=0.243.0", 70 | ], 71 | "dev": [ 72 | "black==24.2.0", 73 | "usort==1.0.2", 74 | "ufmt==2.0.0b2", 75 | "fvcore>=0.1.5.post20221221", 76 | "pandas>=2.2.2", 77 | "scikit-image>=0.24.0", 78 | "tensorboard>=2.17.0", 79 | "pycocotools>=2.0.8", 80 | "tensordict>=0.5.0", 81 | "opencv-python>=4.7.0", 82 | "submitit>=1.5.1", 83 | ], 84 | } 85 | 86 | # By default, we also build the SAM 2 CUDA extension. 87 | # You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`. 88 | BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1" 89 | # By default, we allow SAM 2 installation to proceed even with build errors. 90 | # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`. 91 | BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1" 92 | 93 | # Catch and skip errors during extension building and print a warning message 94 | # (note that this message only shows up under verbose build mode 95 | # "pip install -v -e ." or "python setup.py build_ext -v") 96 | CUDA_ERROR_MSG = ( 97 | "{}\n\n" 98 | "Failed to build the SAM 2 CUDA extension due to the error above. 
" 99 | "You can still use SAM 2 and it's OK to ignore the error above, although some " 100 | "post-processing functionality may be limited (which doesn't affect the results in most cases; " 101 | "(see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n" 102 | ) 103 | 104 | 105 | def get_extensions(): 106 | if not BUILD_CUDA: 107 | return [] 108 | 109 | try: 110 | from torch.utils.cpp_extension import CUDAExtension 111 | 112 | srcs = ["sam2/csrc/connected_components.cu"] 113 | compile_args = { 114 | "cxx": [], 115 | "nvcc": [ 116 | "-DCUDA_HAS_FP16=1", 117 | "-D__CUDA_NO_HALF_OPERATORS__", 118 | "-D__CUDA_NO_HALF_CONVERSIONS__", 119 | "-D__CUDA_NO_HALF2_OPERATORS__", 120 | ], 121 | } 122 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] 123 | except Exception as e: 124 | if BUILD_ALLOW_ERRORS: 125 | print(CUDA_ERROR_MSG.format(e)) 126 | ext_modules = [] 127 | else: 128 | raise e 129 | 130 | return ext_modules 131 | 132 | 133 | try: 134 | from torch.utils.cpp_extension import BuildExtension 135 | 136 | class BuildExtensionIgnoreErrors(BuildExtension): 137 | 138 | def finalize_options(self): 139 | try: 140 | super().finalize_options() 141 | except Exception as e: 142 | print(CUDA_ERROR_MSG.format(e)) 143 | self.extensions = [] 144 | 145 | def build_extensions(self): 146 | try: 147 | super().build_extensions() 148 | except Exception as e: 149 | print(CUDA_ERROR_MSG.format(e)) 150 | self.extensions = [] 151 | 152 | def get_ext_filename(self, ext_name): 153 | try: 154 | return super().get_ext_filename(ext_name) 155 | except Exception as e: 156 | print(CUDA_ERROR_MSG.format(e)) 157 | self.extensions = [] 158 | return "_C.so" 159 | 160 | cmdclass = { 161 | "build_ext": ( 162 | BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True) 163 | if BUILD_ALLOW_ERRORS 164 | else BuildExtension.with_options(no_python_abi_suffix=True) 165 | ) 166 | } 167 | except Exception as e: 168 | cmdclass = {} 169 | if BUILD_ALLOW_ERRORS: 170 | print(CUDA_ERROR_MSG.format(e)) 171 | else: 172 | raise e 173 | 174 | 175 | # Setup configuration 176 | setup( 177 | name=NAME, 178 | version=VERSION, 179 | description=DESCRIPTION, 180 | long_description=LONG_DESCRIPTION, 181 | long_description_content_type="text/markdown", 182 | url=URL, 183 | author=AUTHOR, 184 | author_email=AUTHOR_EMAIL, 185 | license=LICENSE, 186 | packages=find_packages(exclude="notebooks"), 187 | include_package_data=True, 188 | install_requires=REQUIRED_PACKAGES, 189 | extras_require=EXTRA_PACKAGES, 190 | python_requires=">=3.10.0", 191 | ext_modules=get_extensions(), 192 | cmdclass=cmdclass, 193 | ) 194 | -------------------------------------------------------------------------------- /trace/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import copy 3 | from functools import partial 4 | 5 | import torch 6 | 7 | from .model import TraceMistralForCausalLM 8 | from .model.builder import load_pretrained_model 9 | from .conversation import conv_templates, SeparatorStyle 10 | from .mm_utils import process_video, tokenizer_MMODAL_token, get_model_name_from_path, KeywordsStoppingCriteria 11 | from .constants import NUM_FRAMES, DEFAULT_MMODAL_TOKEN, DEFAULT_MMODAL_START_TOKEN, DEFAULT_MMODAL_END_TOKEN, MMODAL_TOKEN_INDEX 12 | 13 | def model_init(model_path=None): 14 | model_path = "DAMO-NLP-SG/Trace-7B" if model_path is None else model_path 15 | model_name = get_model_name_from_path(model_path) 16 | tokenizer, model, processor, context_len = 
load_pretrained_model(model_path, None, model_name) 17 | 18 | num_frames = model.config.num_frames if hasattr(model.config, "num_frames") else NUM_FRAMES 19 | 20 | return model, partial(process_video, aspect_ratio=None, processor=processor, num_frames=num_frames), tokenizer 21 | 22 | 23 | def infer(model, video, instruct, tokenizer, do_sample=False): 24 | """inference api of Trace for video understanding. 25 | 26 | Args: 27 | model: Trace model. 28 | video (torch.Tensor): video tensor (T, C, H, W). 29 | instruct (str): text instruction for understanding video. 30 | tokenizer: tokenizer. 31 | do_sample (bool): whether to sample. 32 | Returns: 33 | str: response of the model. 34 | """ 35 | 36 | # 1. vision preprocess (load & transform image or video). 37 | tensor = [video.half().cuda()] 38 | modals = ["video"] 39 | 40 | # 2. text preprocess (tag process & generate prompt). 41 | modal_token = DEFAULT_MMODAL_TOKEN['VIDEO'] 42 | modal_index = MMODAL_TOKEN_INDEX["VIDEO"] 43 | instruct = modal_token + '\n' + instruct 44 | 45 | conv = conv_templates["llama_2"].copy() 46 | conv.append_message(conv.roles[0], instruct) 47 | conv.append_message(conv.roles[1], None) 48 | prompt = conv.get_prompt() 49 | 50 | input_ids = tokenizer_MMODAL_token(prompt, tokenizer, modal_index, return_tensors='pt').unsqueeze(0).cuda() 51 | attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda() 52 | 53 | # 3. generate response according to visual signals and prompts. 54 | stop_str = conv.sep if conv.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.QWEN] else conv.sep2 55 | # keywords = ["", ""] 56 | keywords = [stop_str] 57 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids) 58 | 59 | with torch.inference_mode(): 60 | output_ids = model.generate( 61 | input_ids, 62 | attention_mask=attention_masks, 63 | images_or_videos=tensor, 64 | modal_list=modals, 65 | do_sample=do_sample, 66 | temperature=0.2 if do_sample else 0.0, 67 | max_new_tokens=1024, 68 | use_cache=True, 69 | stopping_criteria=[stopping_criteria], 70 | pad_token_id=tokenizer.eos_token_id, 71 | ) 72 | 73 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip() 74 | 75 | return outputs 76 | 77 | 78 | def x_infer(video, question, model, tokenizer, mode='vanilla', do_sample=False): 79 | if mode == 'mcqa': 80 | instruction = f'{question}\nAnswer with the option\'s letter from the given choices directly and only give the best option.' 81 | return infer(model=model, tokenizer=tokenizer, video=video, instruct=instruction, do_sample=do_sample) 82 | elif mode == 'openend': 83 | instruction = f'{question}\nAnswer the question using a single word or a short phrase with multiple words.' 
84 | return infer(model=model, tokenizer=tokenizer, video=video, instruct=instruction, do_sample=do_sample) 85 | elif mode == 'vanilla': 86 | instruction = question 87 | return infer(model=model, tokenizer=tokenizer, video=video, instruct=instruction, do_sample=do_sample) -------------------------------------------------------------------------------- /trace/constants.py: -------------------------------------------------------------------------------- 1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 2 | WORKER_HEART_BEAT_INTERVAL = 15 3 | 4 | LOGDIR = "./log_dir" 5 | 6 | NUM_FRAMES = 8 7 | MAX_FRAMES = 128 8 | NUM_FRAMES_PER_SECOND = 1 9 | Grids = [(2, 2), (1, 2), (1, 3), (1, 4), (2, 1), (3, 1), (4, 1)] 10 | 11 | # Model Constants 12 | IGNORE_INDEX = -100 13 | IMAGE_TOKEN_INDEX = -200 14 | DEFAULT_IMAGE_TOKEN = "<image>" 15 | DEFAULT_VIDEO_TOKEN = "