├── .gitignore
├── LICENSE-Caption-Anything.txt
├── LICENSE-Qwen2-VL.txt
├── LICENSE-SAMURAI.txt
├── LICENSE-VideoLLaMA2.txt
├── README.md
├── assets
│   ├── ball.txt
│   ├── cat-v-framework.png
│   ├── cover.png
│   ├── demo.mp4
│   ├── demo.txt
│   ├── demo_short.mp4
│   └── jump.mp4
├── checkpoints
│   └── download_ckpts.sh
├── environment.yml
├── eval_utils.py
├── gradio_app.py
├── inference.sh
├── init.sh
├── internvl
│   ├── get_acc.py
│   ├── test-batch.sh
│   └── test.py
├── requirements.txt
├── sam2
│   ├── __init__.py
│   ├── automatic_mask_generator.py
│   ├── build_sam.py
│   ├── configs
│   │   ├── sam2.1
│   │   │   ├── sam2.1_hiera_b+.yaml
│   │   │   ├── sam2.1_hiera_l.yaml
│   │   │   ├── sam2.1_hiera_s.yaml
│   │   │   └── sam2.1_hiera_t.yaml
│   │   ├── sam2.1_training
│   │   │   └── sam2.1_hiera_b+_MOSE_finetune.yaml
│   │   ├── sam2
│   │   │   ├── sam2_hiera_b+.yaml
│   │   │   ├── sam2_hiera_l.yaml
│   │   │   ├── sam2_hiera_s.yaml
│   │   │   └── sam2_hiera_t.yaml
│   │   └── samurai
│   │       ├── sam2.1_hiera_b+.yaml
│   │       ├── sam2.1_hiera_l.yaml
│   │       ├── sam2.1_hiera_s.yaml
│   │       └── sam2.1_hiera_t.yaml
│   ├── csrc
│   │   └── connected_components.cu
│   ├── modeling
│   │   ├── __init__.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   ├── hieradet.py
│   │   │   ├── image_encoder.py
│   │   │   └── utils.py
│   │   ├── memory_attention.py
│   │   ├── memory_encoder.py
│   │   ├── position_encoding.py
│   │   ├── sam
│   │   │   ├── __init__.py
│   │   │   ├── mask_decoder.py
│   │   │   ├── prompt_encoder.py
│   │   │   └── transformer.py
│   │   ├── sam2_base.py
│   │   └── sam2_utils.py
│   ├── sam2_hiera_b+.yaml
│   ├── sam2_hiera_l.yaml
│   ├── sam2_hiera_s.yaml
│   ├── sam2_hiera_t.yaml
│   ├── sam2_image_predictor.py
│   ├── sam2_video_predictor.py
│   └── utils
│       ├── __init__.py
│       ├── amg.py
│       ├── kalman_filter.py
│       ├── misc.py
│       └── transforms.py
├── scripts
│   ├── dog.txt
│   ├── get_boundary.py
│   ├── get_caption.py
│   ├── get_masks.py
│   ├── get_vis.py
│   ├── inference
│   │   └── inference.py
│   └── main_inference.py
├── setup.py
└── trace
    ├── __init__.py
    ├── constants.py
    ├── conversation.py
    ├── eval
    │   ├── eval.sh
    │   ├── evaluate.py
    │   ├── mvbench
    │   │   ├── eval.sh
    │   │   └── evaluate.py
    │   ├── reformat_dvc.py
    │   ├── reformat_tvg.py
    │   ├── reformat_vhd.py
    │   └── videomme
    │       ├── eval.sh
    │       └── evaluate.py
    ├── metrics
    │   ├── README.md
    │   ├── dvc
    │   │   ├── SODA
    │   │   │   ├── LICENSE
    │   │   │   ├── README.md
    │   │   │   ├── dataset.py
    │   │   │   ├── nlpeval
    │   │   │   │   ├── bert_f_score.py
    │   │   │   │   ├── bert_r_score.py
    │   │   │   │   └── mover.py
    │   │   │   ├── requirements.txt
    │   │   │   ├── soda.py
    │   │   │   └── utils.py
    │   │   ├── __init__.py
    │   │   ├── eval_dvc.py
    │   │   ├── eval_dvc_anet.py
    │   │   ├── eval_soda.py
    │   │   └── metrics
    │   │       ├── README.md
    │   │       ├── cider.py
    │   │       ├── cider_scorer.py
    │   │       ├── eval_soda.py
    │   │       ├── meteor-1.5.jar
    │   │       ├── meteor.py
    │   │       ├── ptbtokenizer.py
    │   │       └── stanford-corenlp-3.4.1.jar
    │   ├── tvg
    │   │   ├── eval_tvg.py
    │   │   └── eval_tvg.sh
    │   └── vhd
    │       ├── eval_highlights.sh
    │       ├── eval_vhd.py
    │       └── utils.py
    ├── mm_utils.py
    ├── model
    │   ├── __init__.py
    │   ├── builder.py
    │   ├── language_model
    │   │   └── trace_mistral.py
    │   ├── multimodal_encoder
    │   │   ├── builder.py
    │   │   ├── clip_encoder.py
    │   │   ├── score_encoder.py
    │   │   ├── sync_encoder.py
    │   │   └── time_encoder.py
    │   ├── multimodal_projector
    │   │   ├── __init__.py
    │   │   └── builder.py
    │   └── trace_arch.py
    ├── prompts
    │   ├── dvc-anet-ft.txt
    │   ├── dvc-anet.txt
    │   ├── dvc.txt
    │   ├── mr.txt
    │   └── vhd.txt
    ├── trace_trainer.py
    ├── train_mt.py
    ├── train_mt_npu.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .vscode/
2 | .DS_Store
3 | __pycache__/
4 | *-checkpoint.ipynb
5 | .venv
6 | *.egg*
7 | build/*
8 | _C.*
9 | outputs/*
10 | checkpoints/*.pt
11 |
12 |
13 | results/*.mp4
14 |
15 | # Python
16 | __pycache__
17 | *.pyc
18 | *.egg-info
19 | dist
20 |
21 | # Log
22 | *.log
23 | *.log.*
24 | *.json
25 | *.jsonl
26 | log_dir*/
27 |
28 | # Data
29 | !**/alpaca-data-conversation.json
30 |
31 | # Editor
32 | .idea
33 | *.swp
34 |
35 | # Other
36 | .DS_Store
37 |
38 | # jupyter
39 | .ipynb_checkpoints
40 | *.ipynb
41 |
42 | # DevContainer
43 | !.devcontainer/*
44 |
45 | # Demo
46 | serve_images/
47 |
48 | # data folder
49 | data/
50 | dataset/
51 | datasets/
52 | results/
53 |
54 | # training folder
55 | wandb
56 | ckpts*
57 | output
58 | output/
59 | # checkpoints
60 | # checkpoints/
61 | work_dirs*/
62 |
63 | # evaluation folder
64 | eval_results/
65 |
66 | # pretrained weights
67 | pretrained/
68 | publish_models/
69 |
70 |
--------------------------------------------------------------------------------
/LICENSE-Caption-Anything.txt:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2023, Teng Wang
4 |
5 | Redistribution and use in source and binary forms, with or without
6 | modification, are permitted provided that the following conditions are met:
7 |
8 | 1. Redistributions of source code must retain the above copyright notice, this
9 | list of conditions and the following disclaimer.
10 |
11 | 2. Redistributions in binary form must reproduce the above copyright notice,
12 | this list of conditions and the following disclaimer in the documentation
13 | and/or other materials provided with the distribution.
14 |
15 | 3. Neither the name of the copyright holder nor the names of its
16 | contributors may be used to endorse or promote products derived from
17 | this software without specific prior written permission.
18 |
19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [AAAI-26 Demo] Caption Anything in Video: Object-centric Dense Video Captioning with Spatiotemporal Multimodal Prompting
2 | 
3 |
4 |
5 |
6 | Official PyTorch implementation of [Caption Anything in Video: Fine-grained Object-centric Captioning via Spatiotemporal Multimodal Prompting](https://arxiv.org/abs/2504.05541)
7 |
8 | 
9 |
10 | ## 🚀 Updates
11 |
12 | ## 🕹️ Demo
13 |
14 | YouTube: [https://youtu.be/2eiPVKXEoxw](https://youtu.be/2eiPVKXEoxw)
15 |
16 | ## 🛠️ Getting Started
17 |
18 | 1. Set up a conda environment (Python >= 3.10) using:
19 |
20 | ```bash
21 | conda create -n cat2 python=3.10 -y
22 | conda activate cat2
23 | ```
24 |
25 | 2. Install the requirements:
26 |
27 | ```bash
28 | pip install -e .
29 | ```
30 |
31 | 3. Download checkpoints:
32 |
33 | ```bash
34 | cd checkpoints && \
35 | ./download_ckpts.sh && \
36 | cd ..
37 | ```
38 |
39 | ## 🏃 RUN
40 |
41 | ```bash
42 | bash inference.sh
43 | ```
44 |
45 |
46 | ## 📖 Citation
47 | If you find this work useful for your research or applications, please cite using this BibTeX:
48 |
49 | ```bibtex
50 | @inproceedings{tang2025cat-v,
51 | title={Caption Anything in Video: Fine-grained Object-centric Captioning via Spatiotemporal Multimodal Prompting},
52 | author={Tang, Yunlong and Bi, Jing and Hua, Hang and Xiao, Yunzhong and Song, Yizhi and Liu, Pinxin and Huang, Chao and Feng, Mingqian and Guo, Junjia and Liu, Zhuo and Song, Luchuan and Liang, Susan and Shimada, Daiki and Vosoughi, Ali and He, Jinxi and He, Liu and Zhang, Zeliang and Luo, Jiebo and Xu, Chenliang},
53 | journal={arXiv},
54 | year={2025}
55 | }
56 | ```
57 |
58 |
59 | ## 🙏 Acknowledgements
60 | This work was supported by Sony Group Corporation. We would like to thank Sayaka Nakamura and Jerry Jun Yokono for their insightful discussion.
61 |
62 | We are also grateful for the following awesome projects that CAT-V builds upon:
63 |
64 | - [Caption Anything](https://github.com/ttengwang/Caption-Anything)
65 | - [SAM 2](https://github.com/facebookresearch/sam2)
66 | - [SAMURAI](https://github.com/yangchris11/samurai)
67 | - [TRACE-uni](https://github.com/gyxxyg/TRACE)
68 | - [VideoLLaMA2](https://github.com/DAMO-NLP-SG/VideoLLaMA2)
69 | - [Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL)
70 | - [InternVL-2.5](https://internvl.github.io/blog/2024-12-05-InternVL-2.5/)
71 |
72 |
73 | ## 👩‍💻 Contributors
74 | Our project wouldn't be possible without the contributions of these amazing people! Thank you all for making this project better.
75 |
76 | - [Yunlong Tang](https://yunlong10.github.io/) @ University of Rochester
77 | - [Jing Bi](https://scholar.google.com/citations?user=ZyCYhUkAAAAJ) @ University of Rochester
78 | - [Chao Huang](https://wikichao.github.io/) @ University of Rochester
79 | - [Susan Liang](https://liangsusan-git.github.io/) @ University of Rochester
80 | - [Daiki Shimada](https://scholar.google.co.jp/citations?user=1uAwouQAAAAJ) @ Sony Group Corporation
81 | - [Hang Hua](https://hanghuacs.notion.site/Hang-Hua-151c5b68f62980e8884febf1b5c1d4a9) @ University of Rochester
82 | - [Yunzhong Xiao](https://shawn-yzxiao.github.io/) @ Carnegie Mellon University
83 | - [Yizhi Song](https://song630.github.io/yizhisong.github.io/) @ Purdue University
84 | - [Pinxin Liu](https://andypinxinliu.github.io/) @ University of Rochester
85 | - [Mingqian Feng](https://fmmarkmq.github.io/) @ University of Rochester
86 | - [Junjia Guo](https://doujiangter.github.io/JunjiaGuo.github.io/) @ University of Rochester
87 | - [Zhuo Liu](https://joeliuz6.github.io/) @ University of Rochester
88 | - [Luchuan Song](https://songluchuan.github.io/) @ University of Rochester
89 | - [Ali Vosoughi](https://alivosoughi.com/) @ University of Rochester
90 | - [Jinxi He](https://gingin520.github.io/) @ University of Rochester
91 | - [Liu He](https://arking1995.github.io/) @ Purdue University
92 | - [Zeliang Zhang](https://zhangaipi.github.io/) @ University of Rochester
93 | - [Jiebo Luo](https://www.cs.rochester.edu/u/jluo/) @ University of Rochester
94 | - [Chenliang Xu](https://www.cs.rochester.edu/~cxu22/index.html) @ University of Rochester
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 | ### 🌟 Star History
103 |
104 | [](https://star-history.com/#yunlong10/CAT-V&Date)
105 |
--------------------------------------------------------------------------------
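For reference, the setup and inference steps in the README above chain together into a single session; the commands below are the same ones listed in the Getting Started and RUN sections, just shown end-to-end.

```bash
# End-to-end quick start assembled from the README steps above
conda create -n cat2 python=3.10 -y
conda activate cat2
pip install -e .                                  # install CAT-V and its dependencies
cd checkpoints && ./download_ckpts.sh && cd ..    # fetch the SAM 2.1 checkpoints
bash inference.sh                                 # run the demo pipeline on assets/demo.mp4
```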
/assets/ball.txt:
--------------------------------------------------------------------------------
1 | 390, 435, 60, 60
--------------------------------------------------------------------------------
/assets/cat-v-framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/cat-v-framework.png
--------------------------------------------------------------------------------
/assets/cover.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/cover.png
--------------------------------------------------------------------------------
/assets/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/demo.mp4
--------------------------------------------------------------------------------
/assets/demo.txt:
--------------------------------------------------------------------------------
1 | 720, 250, 750, 300
--------------------------------------------------------------------------------
/assets/demo_short.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/demo_short.mp4
--------------------------------------------------------------------------------
/assets/jump.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/assets/jump.mp4
--------------------------------------------------------------------------------
/checkpoints/download_ckpts.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Copyright (c) Meta Platforms, Inc. and affiliates.
4 | # All rights reserved.
5 |
6 | # This source code is licensed under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 |
9 | # Use either wget or curl to download the checkpoints
10 | if command -v wget &> /dev/null; then
11 | CMD="wget"
12 | elif command -v curl &> /dev/null; then
13 | CMD="curl -L -O"
14 | else
15 | echo "Please install wget or curl to download the checkpoints."
16 | exit 1
17 | fi
18 |
19 | # Define the URLs for SAM 2 checkpoints
20 | # SAM2_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/072824"
21 | # sam2_hiera_t_url="${SAM2_BASE_URL}/sam2_hiera_tiny.pt"
22 | # sam2_hiera_s_url="${SAM2_BASE_URL}/sam2_hiera_small.pt"
23 | # sam2_hiera_b_plus_url="${SAM2_BASE_URL}/sam2_hiera_base_plus.pt"
24 | # sam2_hiera_l_url="${SAM2_BASE_URL}/sam2_hiera_large.pt"
25 |
26 | # Download each of the four checkpoints using wget
27 | # echo "Downloading sam2_hiera_tiny.pt checkpoint..."
28 | # $CMD $sam2_hiera_t_url || { echo "Failed to download checkpoint from $sam2_hiera_t_url"; exit 1; }
29 |
30 | # echo "Downloading sam2_hiera_small.pt checkpoint..."
31 | # $CMD $sam2_hiera_s_url || { echo "Failed to download checkpoint from $sam2_hiera_s_url"; exit 1; }
32 |
33 | # echo "Downloading sam2_hiera_base_plus.pt checkpoint..."
34 | # $CMD $sam2_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2_hiera_b_plus_url"; exit 1; }
35 |
36 | # echo "Downloading sam2_hiera_large.pt checkpoint..."
37 | # $CMD $sam2_hiera_l_url || { echo "Failed to download checkpoint from $sam2_hiera_l_url"; exit 1; }
38 |
39 | # Define the URLs for SAM 2.1 checkpoints
40 | SAM2p1_BASE_URL="https://dl.fbaipublicfiles.com/segment_anything_2/092824"
41 | sam2p1_hiera_t_url="${SAM2p1_BASE_URL}/sam2.1_hiera_tiny.pt"
42 | sam2p1_hiera_s_url="${SAM2p1_BASE_URL}/sam2.1_hiera_small.pt"
43 | sam2p1_hiera_b_plus_url="${SAM2p1_BASE_URL}/sam2.1_hiera_base_plus.pt"
44 | sam2p1_hiera_l_url="${SAM2p1_BASE_URL}/sam2.1_hiera_large.pt"
45 |
46 | # SAM 2.1 checkpoints
47 | echo "Downloading sam2.1_hiera_tiny.pt checkpoint..."
48 | $CMD $sam2p1_hiera_t_url || { echo "Failed to download checkpoint from $sam2p1_hiera_t_url"; exit 1; }
49 |
50 | echo "Downloading sam2.1_hiera_small.pt checkpoint..."
51 | $CMD $sam2p1_hiera_s_url || { echo "Failed to download checkpoint from $sam2p1_hiera_s_url"; exit 1; }
52 |
53 | echo "Downloading sam2.1_hiera_base_plus.pt checkpoint..."
54 | $CMD $sam2p1_hiera_b_plus_url || { echo "Failed to download checkpoint from $sam2p1_hiera_b_plus_url"; exit 1; }
55 |
56 | echo "Downloading sam2.1_hiera_large.pt checkpoint..."
57 | $CMD $sam2p1_hiera_l_url || { echo "Failed to download checkpoint from $sam2p1_hiera_l_url"; exit 1; }
58 |
59 | echo "All checkpoints are downloaded successfully."
60 |
--------------------------------------------------------------------------------
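If the script finishes without errors, the four SAM 2.1 checkpoints are saved next to it; a quick sanity check (the filenames follow from the URLs defined above):

```bash
ls checkpoints/
# download_ckpts.sh           sam2.1_hiera_large.pt   sam2.1_hiera_tiny.pt
# sam2.1_hiera_base_plus.pt   sam2.1_hiera_small.pt
```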
/gradio_app.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import gradio as gr
4 | import subprocess
5 | import json
6 | import cv2
7 | import tempfile
8 | from PIL import Image
9 | import numpy as np
10 |
11 | CONFIG = {
12 | "model_path": "OpenGVLab/InternVL2-8B",
13 | "get_boundary_model_path": "Yongxin-Guo/trace-uni",
14 | "get_mask_model_path": "./checkpoints/sam2.1_hiera_base_plus.pt",
15 | "output_folder": "./results/",
16 | "frame_count": 16,
17 | }
18 |
19 | def extract_first_frame(video_path):
20 | cap = cv2.VideoCapture(video_path)
21 | ret, image = cap.read()
22 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
23 | cap.release()
24 | height, width = image.shape[:2]
25 | if height > 750:
26 | scale = 750 / height
27 | new_width = int(width * scale)
28 | new_height = 750
29 | image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_AREA)
30 | return image
31 |
32 | def run_inference_pipeline(video_path, bbox):
33 | """
34 | Run the entire inference pipeline for video processing
35 | """
36 | # Ensure output folder exists
37 | os.makedirs(CONFIG["output_folder"], exist_ok=True)
38 |
39 | # Prepare file paths
40 | video_name = os.path.basename(video_path)
41 | qa_file_path = os.path.join(
42 | CONFIG["output_folder"], f"{os.path.splitext(video_name)[0]}_boundary.json"
43 | )
44 | final_json_path = os.path.join(
45 | CONFIG["output_folder"],
46 | f"{os.path.splitext(video_name)[0]}_boundary_caption.json",
47 | )
48 | final_video_path = os.path.join(
49 | CONFIG["output_folder"],
50 | f"{os.path.splitext(video_name)[0]}_boundary_caption.mp4",
51 | )
52 | masked_video_path = os.path.join(
53 | CONFIG["output_folder"], f"{os.path.splitext(video_name)[0]}_mask.mp4"
54 | )
55 | print(f"Final JSON Path: {final_json_path}")
56 | video = cv2.VideoCapture(video_path)
57 | ret, frame = video.read()
58 | h,w = frame.shape[:2]
59 | print(h,w)
60 | video.release()
61 | bbox = [int(bbox[0]*w), int(bbox[1]*h), int(bbox[2]*w), int(bbox[3]*h)]
62 | print(bbox)
63 | object_bbox_path = Path(CONFIG['output_folder'])/f"{os.path.splitext(video_name)[0]}_bbox.txt"
64 | with open(object_bbox_path, "w") as f:
65 | f.write(','.join(map(str, bbox)))
66 | commands = [
67 | # Step 1: Parsing/Boundary Detection
68 | f"python -m scripts.get_boundary "
69 | f"--video_paths {video_path} "
70 | f"--questions 'Localize a series of activity events in the video, output the start and end timestamp for each event, and describe each event with sentences.' "
71 | f"--model_path {CONFIG['get_boundary_model_path']}",
72 | ]
73 |
74 |
75 | commands.append(
76 | f"python scripts/get_masks.py "
77 | f"--video_path {video_path} "
78 | f"--txt_path {object_bbox_path} "
79 | f"--model_path {CONFIG['get_mask_model_path']} "
80 | f"--video_output_path {CONFIG['output_folder']} "
81 | f"--save_to_video True"
82 | )
83 |
84 | # Step 2: Captioning
85 | commands.append(
86 | f"python scripts/get_caption.py "
87 | f"--model_path {CONFIG['model_path']} "
88 | f"--QA_file_path {qa_file_path} "
89 | f"--video_folder {CONFIG['output_folder']} "
90 | f"--answers_output_folder {CONFIG['output_folder']} "
91 | f"--extract_frames_method max_frames_num "
92 | f"--max_frames_num {CONFIG['frame_count']} "
93 | f"--frames_from video "
94 | f"--final_json_path {final_json_path} "
95 | f"--provide_boundaries"
96 | )
97 |
98 | # Step 3: Generate Visualization
99 | commands.append(
100 | f"python scripts/get_vis.py {masked_video_path if object_bbox_path else video_path} {final_json_path} {final_video_path}"
101 | )
102 |
103 | # Execute commands
104 | for cmd in commands:
105 | try:
106 | subprocess.run(cmd, shell=True, check=True)
107 | except subprocess.CalledProcessError as e:
108 | print(f"Error in command: {cmd}")
109 | print(f"Error details: {e}")
110 | return None
111 |
112 | try:
113 | with open(final_json_path, "r") as f:
114 | results = json.load(f)
115 | return {"captions": results, "final_video": final_video_path}
116 | except Exception as e:
117 | print(f"Error reading results: {e}")
118 | return None
119 |
120 | def get_bounding_box(image):
121 | alpha_channel = image[:, :, 3]
122 | y_coords, x_coords = np.where(alpha_channel > 0)
123 |
124 | if y_coords.size == 0 or x_coords.size == 0:
125 | return None
126 | x_min, x_max = x_coords.min(), x_coords.max()
127 | y_min, y_max = y_coords.min(), y_coords.max()
128 | x_min_ratio = x_min / image.shape[1]
129 | x_max_ratio = x_max / image.shape[1]
130 | y_min_ratio = y_min / image.shape[0]
131 | y_max_ratio = y_max / image.shape[0]
132 | return x_min_ratio, y_min_ratio, x_max_ratio, y_max_ratio
133 | def caption_video(video, edited_image):
134 | """
135 | Gradio-friendly wrapper for inference pipeline
136 | video: path to the uploaded video
137 |     edited_image: dict returned by gr.ImageEditor; the rectangle drawn on its
138 |         first layer is converted into a normalized bounding box
139 | """
140 | layer_0 = edited_image['layers'][0]
141 | bbox = get_bounding_box(layer_0)
142 |
143 | if video is None:
144 | return "Please upload a video first.", None
145 | results = run_inference_pipeline(video, bbox)
146 |
147 | if results is None:
148 | return "Processing failed. Please check the logs.", None
149 |
150 | # Format captions nicely
151 | captions_text = "\n\n".join(
152 | [
153 | f"Event {i+1} (Time: {event.get('timestamp', 'N/A')}):\n{event.get('model_answer', 'No caption')}"
154 | for i, event in enumerate(results.get("captions", []))
155 | ]
156 | )
157 |
158 | return captions_text, results.get("final_video")
159 |
160 |
161 |
162 |
163 | def create_demo():
164 | """
165 | Create Gradio interface
166 | """
167 |
168 | DESCRIPTION = """# CAT2:
169 | This is a demo for our 'CAT2' [paper](https://github.com/yunlong10/CAT-2).
170 | Code is available [here](https://github.com/yunlong10/CAT-2).
171 | This demo performs captioning with optional object bounding box annotation.
172 | """
173 |
174 | with gr.Blocks() as demo:
175 | gr.Markdown("# Caption Anything Demo")
176 | gr.Markdown(DESCRIPTION)
177 | gr.Markdown(
178 | "Upload a video and optionally a bounding box file. Or draw a rectangle on the first frame of the video to provide a bounding box. (Note: The ImageEditor does not return bounding box coordinates directly. Further processing may be required.)"
179 | )
180 |
181 | with gr.Row():
182 | video_input = gr.Video(label="Upload Video",height=800)
183 | first_frame_editor = gr.ImageEditor(label="Draw a rectangle on the First Frame",height=800)
184 |
185 | video_input.change(fn=extract_first_frame, inputs=video_input, outputs=first_frame_editor)
186 |
187 | caption_button = gr.Button("Generate Captions")
188 |
189 | output_text = gr.Textbox(label="Video Captions")
190 | output_video = gr.Video(label="Processed Video")
191 |
192 | caption_button.click(
193 | fn=caption_video,
194 | inputs=[video_input, first_frame_editor],
195 | outputs=[output_text, output_video],
196 | )
197 |
198 | return demo
199 |
200 |
201 | if __name__ == "__main__":
202 | demo = create_demo()
203 | demo.launch(
204 | server_name="0.0.0.0", # Make accessible from other machines
205 | server_port=8889,
206 | debug=True,
207 | )
--------------------------------------------------------------------------------
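To try the demo interactively, launching the script directly should suffice once the SAM 2.1 base-plus checkpoint referenced in CONFIG has been downloaded; as set in demo.launch() above, the app listens on all interfaces on port 8889.

```bash
python gradio_app.py
# then open http://localhost:8889, upload a video, and draw a rectangle on the
# extracted first frame; the drawn region becomes the object prompt for the pipeline
```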
/inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # conda activate /home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/env/cat-2
4 | export TRANSFORMERS_CACHE=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/transformers_cache
5 | export TORCH_HOME=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/torch_home
6 | export HF_HOME=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/hf_home
7 | export PIP_CACHE_DIR=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/pip
8 | export OPENAI_CACHE_DIR=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/openai
9 |
10 | set -e
11 |
12 | GREEN="\033[32m"
13 | RESET="\033[0m"
14 | FRAME_COUNT=32
15 | OUTPUT_FOLDER="./results"
16 | mkdir -p $OUTPUT_FOLDER
17 | MODEL_PATH="OpenGVLab/InternVL2_5-8B-MPO" # "OpenGVLab/InternVL2-8B"
18 | GET_BOUNDARY_MODEL_PATH="Yongxin-Guo/trace-uni"
19 | GET_MASK_MODEL_PATH="./checkpoints/sam2.1_hiera_base_plus.pt"
20 |
21 | ############################################################################################################
22 | VIDEO_NAME="demo.mp4"
23 | VIDEO_FOLDER="./assets/"
24 | OBJECT_BBOX="demo.txt"
25 | QA_FILE_PATH="$OUTPUT_FOLDER/demo_boundary.json"
26 | FINAL_JSON_PATH="$OUTPUT_FOLDER/demo_boundary_caption.json"
27 | FINAL_VIDEO_PATH="$OUTPUT_FOLDER/demo_boundary_caption.mp4"
28 | MASKED_VIDEO_PATH="$OUTPUT_FOLDER/demo_mask.mp4"
29 | ############################################################################################################
30 |
31 | VIDEO_PATH="$VIDEO_FOLDER$VIDEO_NAME"
32 | OBJECT_BBOX_PATH="$VIDEO_FOLDER$OBJECT_BBOX"
33 |
34 | START_TIME=$(date +%s)
35 |
36 | echo -e "${GREEN}Step 1: Parsing...${RESET}"
37 |
38 | # python -m scripts.get_boundary \
39 | # --video_paths $VIDEO_PATH \
40 | # --questions "Localize a series of activity events in the video, output the start and end timestamp for each event, and describe each event with sentences." \
41 | # --model_path $GET_BOUNDARY_MODEL_PATH
42 |
43 | echo -e "${GREEN}Step 2: Segmentation...${RESET}"
44 |
45 |
46 | # python scripts/get_masks.py \
47 | # --video_path "$VIDEO_PATH" \
48 | # --txt_path "$OBJECT_BBOX_PATH" \
49 | # --model_path "$GET_MASK_MODEL_PATH" \
50 | # --video_output_path "$OUTPUT_FOLDER" \
51 | # --save_to_video True
52 |
53 | echo -e "${GREEN}Step 3: Captioning...${RESET}"
54 |
55 | python scripts/get_caption.py \
56 | --model_path "$MODEL_PATH" \
57 | --QA_file_path "$QA_FILE_PATH" \
58 | --video_folder "$OUTPUT_FOLDER" \
59 | --answers_output_folder "$OUTPUT_FOLDER" \
60 | --extract_frames_method "max_frames_num" \
61 | --max_frames_num "$FRAME_COUNT" \
62 | --frames_from "video" \
63 | --final_json_path "$FINAL_JSON_PATH" \
64 | --provide_boundaries
65 |
66 | echo -e "${GREEN}Step 4: Generating visualizations...${RESET}"
67 |
68 | # python scripts/get_vis.py "$MASKED_VIDEO_PATH" "$FINAL_JSON_PATH" "$FINAL_VIDEO_PATH"
69 |
70 | echo -e "${GREEN}Completed in $(($(date +%s) - START_TIME)) seconds.${RESET}"
--------------------------------------------------------------------------------
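Note that in the checked-in script, Steps 1 and 2 and the final get_vis.py call are commented out, so only the captioning stage runs (it expects the boundary JSON and masked video to exist already), and the exported cache paths point at the authors' machine. A hypothetical variable block for running the full pipeline on another bundled clip might look like the sketch below; the jump.mp4/ball.txt pairing is an assumption for illustration, not something the repository states.

```bash
# Hypothetical edit of the variable block in inference.sh; also uncomment Steps 1, 2
# and the get_vis.py line to run parsing -> segmentation -> captioning -> visualization.
VIDEO_NAME="jump.mp4"                 # another clip under ./assets/
VIDEO_FOLDER="./assets/"
OBJECT_BBOX="ball.txt"                # assumed bbox prompt file for that clip
QA_FILE_PATH="$OUTPUT_FOLDER/jump_boundary.json"
FINAL_JSON_PATH="$OUTPUT_FOLDER/jump_boundary_caption.json"
FINAL_VIDEO_PATH="$OUTPUT_FOLDER/jump_boundary_caption.mp4"
MASKED_VIDEO_PATH="$OUTPUT_FOLDER/jump_mask.mp4"
```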
/init.sh:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yunlong10/CAT-V/85ca51152e364256533032954b3592df66e134cd/init.sh
--------------------------------------------------------------------------------
/internvl/get_acc.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 | from pycocoevalcap.bleu.bleu import Bleu
4 | from pycocoevalcap.meteor.meteor import Meteor
5 | from pycocoevalcap.cider.cider import Cider
6 | import os
7 |
8 | # Load the JSON file
9 |
10 | results = "./results"
11 | # read all json files in the directory
12 | for file in os.listdir(results):
13 | if file.endswith(".json"):
14 |         with open(f"{results}/{file}", "r", encoding="utf-8") as fp:
15 |             data = json.load(fp)
16 |
17 | # Prepare references and hypotheses as dictionaries
18 | gts = {} # Ground truth (references)
19 |         res = {}  # Results (hypotheses)
20 |
21 | for i, item in enumerate(tqdm(data)):
22 | gts[i] = [item["correct_answer"]] # Reference list for ID i
23 | res[i] = [item["model_answer"]] # Hypothesis for ID i
24 |
25 | # BLEU Score
26 | def compute_bleu(gts, res):
27 | bleu_scorer = Bleu(4) # Compute BLEU-1 to BLEU-4
28 | score, _ = bleu_scorer.compute_score(gts, res)
29 | return score
30 |
31 | # METEOR Score
32 | def compute_meteor(gts, res):
33 | meteor_scorer = Meteor()
34 | score, _ = meteor_scorer.compute_score(gts, res)
35 | return score
36 |
37 | # CIDEr Score
38 | def compute_cider(gts, res):
39 | cider_scorer = Cider()
40 | score, _ = cider_scorer.compute_score(gts, res)
41 | return score
42 |
43 | # Calculate scores
44 | bleu_score = compute_bleu(gts, res)
45 | meteor_score = compute_meteor(gts, res)
46 | cider_score = compute_cider(gts, res)
47 |
48 | # Print results
49 | print(f"Results for {file}")
50 | print(f"BLEU Scores: {[round(i*100, 2) for i in bleu_score]}")
51 | print(f"METEOR Score: {round(meteor_score*100, 2)}")
52 | print(f"CIDEr Score: {round(cider_score*100, 2)}")
53 |
--------------------------------------------------------------------------------
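get_acc.py scores every .json file it finds under ./results and assumes each file is a list of records carrying a reference and a prediction. A minimal, hypothetical input illustrating that layout (the captions here are made up):

```python
# Hypothetical input for get_acc.py: each record needs "correct_answer" (reference)
# and "model_answer" (hypothesis); the file just has to end in .json under ./results.
import json
import os

example = [
    {
        "correct_answer": "A man throws a ball and the dog jumps to catch it.",
        "model_answer": "A person tosses a ball while a dog leaps after it.",
    },
]

os.makedirs("./results", exist_ok=True)
with open("./results/example.json", "w", encoding="utf-8") as fp:
    json.dump(example, fp, indent=2)
```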
/internvl/test-batch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | export TRANSFORMERS_CACHE=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/transformers_cache
3 | export TORCH_HOME=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/torch_home
4 | export HF_HOME=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/hf_home
5 | export PIP_CACHE_DIR=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/pip
6 | export OPENAI_CACHE_DIR=/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/cache/openai
7 |
8 | # Exit immediately if a command exits with a non-zero status.
9 | set -e
10 |
11 | SCRIPT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
12 | PYTHON_FILE_PATH="${SCRIPT_DIR}/test.py"
13 | ANSWERS_OUTPUT_FOLDER="${SCRIPT_DIR}/results"
14 |
15 | if [ -e "$ANSWERS_OUTPUT_FOLDER" ]; then
16 |     echo "Directory $ANSWERS_OUTPUT_FOLDER already exists."
17 | else
18 |     echo "Directory $ANSWERS_OUTPUT_FOLDER does not exist. Creating it now..."
19 |     mkdir "$ANSWERS_OUTPUT_FOLDER"
20 |     echo "Directory $ANSWERS_OUTPUT_FOLDER created."
21 | fi
22 |
23 | # Variables - Please update these paths according to your setup
24 | MODEL_PATH="OpenGVLab/InternVL2_5-8B-MPO-AWQ" #"OpenGVLab/InternVL2-8B"
25 | VIDEO_FOLDER="/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/samurai/samed_videos"
26 |
27 | # Define task files and frame numbers
28 | TASK_FILES=("example.json")
29 | FRAME_COUNTS=("16")
30 |
31 | # Loop through each task file and each frame count
32 | for TASK in "${TASK_FILES[@]}"; do
33 | for FRAME_COUNT in "${FRAME_COUNTS[@]}"; do
34 | QA_FILE_PATH="/home/cxu-serve/p62/ytang37/projects/Caption-Anything-2/samurai/QAs/$TASK"
35 |
36 | # Execute the Python script with the provided arguments
37 | python "$PYTHON_FILE_PATH" \
38 | --model_path "$MODEL_PATH" \
39 | --QA_file_path "$QA_FILE_PATH" \
40 | --video_folder "$VIDEO_FOLDER" \
41 | --answers_output_folder "$ANSWERS_OUTPUT_FOLDER" \
42 | --extract_frames_method "max_frames_num" \
43 | --max_frames_num "$FRAME_COUNT" \
44 | --frames_from "video" \
45 | # --provide_boundaries
46 | done
47 | done
48 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | matplotlib==3.7
2 | moviepy==1.0.3
3 | accelerate>=0.26.0
4 | numpy==1.26.1
5 | tikzplotlib
6 | jpeg4py
7 | opencv-python
8 | lmdb
9 | pandas
10 | scipy
11 | loguru
12 | einops
13 | transformers==4.40.1
14 | timm
15 | decord
16 | imageio
17 | scenedetect
18 | SentencePiece
19 | gradio
--------------------------------------------------------------------------------
/sam2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from hydra import initialize_config_module
8 | from hydra.core.global_hydra import GlobalHydra
9 |
10 | if not GlobalHydra.instance().is_initialized():
11 | initialize_config_module("sam2", version_base="1.2")
12 |
--------------------------------------------------------------------------------
/sam2/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | import os
9 |
10 | import torch
11 | from hydra import compose
12 | from hydra.utils import instantiate
13 | from omegaconf import OmegaConf
14 |
15 | import sam2
16 |
17 | # Check if the user is running Python from the parent directory of the sam2 repo
18 | # (i.e. the directory where this repo is cloned into) -- this is not supported since
19 | # it could shadow the sam2 package and cause issues.
20 | if os.path.isdir(os.path.join(sam2.__path__[0], "sam2")):
21 |     # If the user has "sam2/sam2" in their path, they are likely importing the repo itself
22 | # as "sam2" rather than importing the "sam2" python package (i.e. "sam2/sam2" directory).
23 | # This typically happens because the user is running Python from the parent directory
24 | # that contains the sam2 repo they cloned.
25 | raise RuntimeError(
26 | "You're likely running Python from the parent directory of the sam2 repository "
27 | "(i.e. the directory where https://github.com/facebookresearch/sam2 is cloned into). "
28 | "This is not supported since the `sam2` Python package could be shadowed by the "
29 | "repository name (the repository is also named `sam2` and contains the Python package "
30 | "in `sam2/sam2`). Please run Python from another directory (e.g. from the repo dir "
31 | "rather than its parent dir, or from your home directory) after installing SAM 2."
32 | )
33 |
34 |
35 | HF_MODEL_ID_TO_FILENAMES = {
36 | "facebook/sam2-hiera-tiny": (
37 | "configs/sam2/sam2_hiera_t.yaml",
38 | "sam2_hiera_tiny.pt",
39 | ),
40 | "facebook/sam2-hiera-small": (
41 | "configs/sam2/sam2_hiera_s.yaml",
42 | "sam2_hiera_small.pt",
43 | ),
44 | "facebook/sam2-hiera-base-plus": (
45 | "configs/sam2/sam2_hiera_b+.yaml",
46 | "sam2_hiera_base_plus.pt",
47 | ),
48 | "facebook/sam2-hiera-large": (
49 | "configs/sam2/sam2_hiera_l.yaml",
50 | "sam2_hiera_large.pt",
51 | ),
52 | "facebook/sam2.1-hiera-tiny": (
53 | "configs/sam2.1/sam2.1_hiera_t.yaml",
54 | "sam2.1_hiera_tiny.pt",
55 | ),
56 | "facebook/sam2.1-hiera-small": (
57 | "configs/sam2.1/sam2.1_hiera_s.yaml",
58 | "sam2.1_hiera_small.pt",
59 | ),
60 | "facebook/sam2.1-hiera-base-plus": (
61 | "configs/sam2.1/sam2.1_hiera_b+.yaml",
62 | "sam2.1_hiera_base_plus.pt",
63 | ),
64 | "facebook/sam2.1-hiera-large": (
65 | "configs/sam2.1/sam2.1_hiera_l.yaml",
66 | "sam2.1_hiera_large.pt",
67 | ),
68 | }
69 |
70 |
71 | def build_sam2(
72 | config_file,
73 | ckpt_path=None,
74 | device="cuda",
75 | mode="eval",
76 | hydra_overrides_extra=[],
77 | apply_postprocessing=True,
78 | **kwargs,
79 | ):
80 |
81 | if apply_postprocessing:
82 | hydra_overrides_extra = hydra_overrides_extra.copy()
83 | hydra_overrides_extra += [
84 | # dynamically fall back to multi-mask if the single mask is not stable
85 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
86 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
87 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
88 | ]
89 | # Read config and init model
90 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
91 | OmegaConf.resolve(cfg)
92 | model = instantiate(cfg.model, _recursive_=True)
93 | _load_checkpoint(model, ckpt_path)
94 | model = model.to(device)
95 | if mode == "eval":
96 | model.eval()
97 | return model
98 |
99 |
100 | def build_sam2_video_predictor(
101 | config_file,
102 | ckpt_path=None,
103 | device="cuda",
104 | mode="eval",
105 | hydra_overrides_extra=[],
106 | apply_postprocessing=True,
107 | **kwargs,
108 | ):
109 | hydra_overrides = [
110 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
111 | ]
112 | if apply_postprocessing:
113 | hydra_overrides_extra = hydra_overrides_extra.copy()
114 | hydra_overrides_extra += [
115 | # dynamically fall back to multi-mask if the single mask is not stable
116 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
117 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
118 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
119 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
120 | "++model.binarize_mask_from_pts_for_mem_enc=true",
121 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
122 | "++model.fill_hole_area=8",
123 | ]
124 | hydra_overrides.extend(hydra_overrides_extra)
125 |
126 | # Read config and init model
127 | cfg = compose(config_name=config_file, overrides=hydra_overrides)
128 | OmegaConf.resolve(cfg)
129 | model = instantiate(cfg.model, _recursive_=True)
130 | _load_checkpoint(model, ckpt_path)
131 | model = model.to(device)
132 | if mode == "eval":
133 | model.eval()
134 | return model
135 |
136 |
137 | def _hf_download(model_id):
138 | from huggingface_hub import hf_hub_download
139 |
140 | config_name, checkpoint_name = HF_MODEL_ID_TO_FILENAMES[model_id]
141 | ckpt_path = hf_hub_download(repo_id=model_id, filename=checkpoint_name)
142 | return config_name, ckpt_path
143 |
144 |
145 | def build_sam2_hf(model_id, **kwargs):
146 | config_name, ckpt_path = _hf_download(model_id)
147 | return build_sam2(config_file=config_name, ckpt_path=ckpt_path, **kwargs)
148 |
149 |
150 | def build_sam2_video_predictor_hf(model_id, **kwargs):
151 | config_name, ckpt_path = _hf_download(model_id)
152 | return build_sam2_video_predictor(
153 | config_file=config_name, ckpt_path=ckpt_path, **kwargs
154 | )
155 |
156 |
157 | def _load_checkpoint(model, ckpt_path):
158 | if ckpt_path is not None:
159 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
160 | missing_keys, unexpected_keys = model.load_state_dict(sd)
161 | if missing_keys:
162 | logging.error(missing_keys)
163 | raise RuntimeError()
164 | if unexpected_keys:
165 | logging.error(unexpected_keys)
166 | raise RuntimeError()
167 |         logging.info("Loaded checkpoint successfully")
168 |
--------------------------------------------------------------------------------
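As a usage sketch (not part of the repository), the builders above can be pointed at the bundled Hydra configs and the weights fetched by checkpoints/download_ckpts.sh; build_sam2_video_predictor_hf is the variant that resolves config and checkpoint from the Hugging Face IDs in HF_MODEL_ID_TO_FILENAMES.

```python
# Sketch: construct a SAM 2.1 video predictor from a local checkpoint.
# Assumes checkpoints/download_ckpts.sh has already been run.
import torch
from sam2.build_sam import build_sam2_video_predictor, build_sam2_video_predictor_hf

device = "cuda" if torch.cuda.is_available() else "cpu"

# Local config (resolved by Hydra from the sam2 package) plus local weights
predictor = build_sam2_video_predictor(
    config_file="configs/sam2.1/sam2.1_hiera_b+.yaml",
    ckpt_path="./checkpoints/sam2.1_hiera_base_plus.pt",
    device=device,
)

# Alternatively, pull both config name and weights from the Hugging Face Hub:
# predictor = build_sam2_video_predictor_hf("facebook/sam2.1-hiera-base-plus", device=device)
```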
/sam2/configs/sam2.1/sam2.1_hiera_b+.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 112
12 | num_heads: 2
13 | neck:
14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15 | position_encoding:
16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17 | num_pos_feats: 256
18 | normalize: true
19 | scale: null
20 | temperature: 10000
21 | d_model: 256
22 | backbone_channel_list: [896, 448, 224, 112]
23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24 | fpn_interp_model: nearest
25 |
26 | memory_attention:
27 | _target_: sam2.modeling.memory_attention.MemoryAttention
28 | d_model: 256
29 | pos_enc_at_input: true
30 | layer:
31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32 | activation: relu
33 | dim_feedforward: 2048
34 | dropout: 0.1
35 | pos_enc_at_attn: false
36 | self_attention:
37 | _target_: sam2.modeling.sam.transformer.RoPEAttention
38 | rope_theta: 10000.0
39 | feat_sizes: [32, 32]
40 | embedding_dim: 256
41 | num_heads: 1
42 | downsample_rate: 1
43 | dropout: 0.1
44 | d_model: 256
45 | pos_enc_at_cross_attn_keys: true
46 | pos_enc_at_cross_attn_queries: false
47 | cross_attention:
48 | _target_: sam2.modeling.sam.transformer.RoPEAttention
49 | rope_theta: 10000.0
50 | feat_sizes: [32, 32]
51 | rope_k_repeat: True
52 | embedding_dim: 256
53 | num_heads: 1
54 | downsample_rate: 1
55 | dropout: 0.1
56 | kv_in_dim: 64
57 | num_layers: 4
58 |
59 | memory_encoder:
60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
61 | out_dim: 64
62 | position_encoding:
63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64 | num_pos_feats: 64
65 | normalize: true
66 | scale: null
67 | temperature: 10000
68 | mask_downsampler:
69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
70 | kernel_size: 3
71 | stride: 2
72 | padding: 1
73 | fuser:
74 | _target_: sam2.modeling.memory_encoder.Fuser
75 | layer:
76 | _target_: sam2.modeling.memory_encoder.CXBlock
77 | dim: 256
78 | kernel_size: 7
79 | padding: 3
80 | layer_scale_init_value: 1e-6
81 | use_dwconv: True # depth-wise convs
82 | num_layers: 2
83 |
84 | num_maskmem: 7
85 | image_size: 1024
86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87 | sigmoid_scale_for_mem_enc: 20.0
88 | sigmoid_bias_for_mem_enc: -10.0
89 | use_mask_input_as_output_without_sam: true
90 | # Memory
91 | directly_add_no_mem_embed: true
92 | no_obj_embed_spatial: true
93 | # use high-resolution feature map in the SAM mask decoder
94 | use_high_res_features_in_sam: true
95 | # output 3 masks on the first click on initial conditioning frames
96 | multimask_output_in_sam: true
97 | # SAM heads
98 | iou_prediction_use_sigmoid: True
99 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
100 | use_obj_ptrs_in_encoder: true
101 | add_tpos_enc_to_obj_ptrs: true
102 | proj_tpos_enc_in_obj_ptrs: true
103 | use_signed_tpos_enc_to_obj_ptrs: true
104 | only_obj_ptrs_in_the_past_for_eval: true
105 | # object occlusion prediction
106 | pred_obj_scores: true
107 | pred_obj_scores_mlp: true
108 | fixed_no_obj_ptr: true
109 | # multimask tracking settings
110 | multimask_output_for_tracking: true
111 | use_multimask_token_for_obj_ptr: true
112 | multimask_min_pt_num: 0
113 | multimask_max_pt_num: 1
114 | use_mlp_for_obj_ptr_proj: true
115 | # Compilation flag
116 | compile_image_encoder: False
117 |
--------------------------------------------------------------------------------
/sam2/configs/sam2.1/sam2.1_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 144
12 | num_heads: 2
13 | stages: [2, 6, 36, 4]
14 | global_att_blocks: [23, 33, 43]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | window_spec: [8, 4, 16, 8]
17 | neck:
18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19 | position_encoding:
20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21 | num_pos_feats: 256
22 | normalize: true
23 | scale: null
24 | temperature: 10000
25 | d_model: 256
26 | backbone_channel_list: [1152, 576, 288, 144]
27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 | fpn_interp_model: nearest
29 |
30 | memory_attention:
31 | _target_: sam2.modeling.memory_attention.MemoryAttention
32 | d_model: 256
33 | pos_enc_at_input: true
34 | layer:
35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36 | activation: relu
37 | dim_feedforward: 2048
38 | dropout: 0.1
39 | pos_enc_at_attn: false
40 | self_attention:
41 | _target_: sam2.modeling.sam.transformer.RoPEAttention
42 | rope_theta: 10000.0
43 | feat_sizes: [32, 32]
44 | embedding_dim: 256
45 | num_heads: 1
46 | downsample_rate: 1
47 | dropout: 0.1
48 | d_model: 256
49 | pos_enc_at_cross_attn_keys: true
50 | pos_enc_at_cross_attn_queries: false
51 | cross_attention:
52 | _target_: sam2.modeling.sam.transformer.RoPEAttention
53 | rope_theta: 10000.0
54 | feat_sizes: [32, 32]
55 | rope_k_repeat: True
56 | embedding_dim: 256
57 | num_heads: 1
58 | downsample_rate: 1
59 | dropout: 0.1
60 | kv_in_dim: 64
61 | num_layers: 4
62 |
63 | memory_encoder:
64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
65 | out_dim: 64
66 | position_encoding:
67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68 | num_pos_feats: 64
69 | normalize: true
70 | scale: null
71 | temperature: 10000
72 | mask_downsampler:
73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
74 | kernel_size: 3
75 | stride: 2
76 | padding: 1
77 | fuser:
78 | _target_: sam2.modeling.memory_encoder.Fuser
79 | layer:
80 | _target_: sam2.modeling.memory_encoder.CXBlock
81 | dim: 256
82 | kernel_size: 7
83 | padding: 3
84 | layer_scale_init_value: 1e-6
85 | use_dwconv: True # depth-wise convs
86 | num_layers: 2
87 |
88 | num_maskmem: 7
89 | image_size: 1024
90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | no_obj_embed_spatial: true
97 | # use high-resolution feature map in the SAM mask decoder
98 | use_high_res_features_in_sam: true
99 | # output 3 masks on the first click on initial conditioning frames
100 | multimask_output_in_sam: true
101 | # SAM heads
102 | iou_prediction_use_sigmoid: True
103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104 | use_obj_ptrs_in_encoder: true
105 | add_tpos_enc_to_obj_ptrs: true
106 | proj_tpos_enc_in_obj_ptrs: true
107 | use_signed_tpos_enc_to_obj_ptrs: true
108 | only_obj_ptrs_in_the_past_for_eval: true
109 | # object occlusion prediction
110 | pred_obj_scores: true
111 | pred_obj_scores_mlp: true
112 | fixed_no_obj_ptr: true
113 | # multimask tracking settings
114 | multimask_output_for_tracking: true
115 | use_multimask_token_for_obj_ptr: true
116 | multimask_min_pt_num: 0
117 | multimask_max_pt_num: 1
118 | use_mlp_for_obj_ptr_proj: true
119 | # Compilation flag
120 | compile_image_encoder: False
121 |
--------------------------------------------------------------------------------
/sam2/configs/sam2.1/sam2.1_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 11, 2]
14 | global_att_blocks: [7, 10, 13]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | sigmoid_scale_for_mem_enc: 20.0
91 | sigmoid_bias_for_mem_enc: -10.0
92 | use_mask_input_as_output_without_sam: true
93 | # Memory
94 | directly_add_no_mem_embed: true
95 | no_obj_embed_spatial: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: true
105 | proj_tpos_enc_in_obj_ptrs: true
106 | use_signed_tpos_enc_to_obj_ptrs: true
107 | only_obj_ptrs_in_the_past_for_eval: true
108 | # object occlusion prediction
109 | pred_obj_scores: true
110 | pred_obj_scores_mlp: true
111 | fixed_no_obj_ptr: true
112 | # multimask tracking settings
113 | multimask_output_for_tracking: true
114 | use_multimask_token_for_obj_ptr: true
115 | multimask_min_pt_num: 0
116 | multimask_max_pt_num: 1
117 | use_mlp_for_obj_ptr_proj: true
118 | # Compilation flag
119 | compile_image_encoder: False
120 |
--------------------------------------------------------------------------------
/sam2/configs/sam2.1/sam2.1_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 7, 2]
14 | global_att_blocks: [5, 7, 9]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | # SAM decoder
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | no_obj_embed_spatial: true
97 | # use high-resolution feature map in the SAM mask decoder
98 | use_high_res_features_in_sam: true
99 | # output 3 masks on the first click on initial conditioning frames
100 | multimask_output_in_sam: true
101 | # SAM heads
102 | iou_prediction_use_sigmoid: True
103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104 | use_obj_ptrs_in_encoder: true
105 | add_tpos_enc_to_obj_ptrs: true
106 | proj_tpos_enc_in_obj_ptrs: true
107 | use_signed_tpos_enc_to_obj_ptrs: true
108 | only_obj_ptrs_in_the_past_for_eval: true
109 | # object occlusion prediction
110 | pred_obj_scores: true
111 | pred_obj_scores_mlp: true
112 | fixed_no_obj_ptr: true
113 | # multimask tracking settings
114 | multimask_output_for_tracking: true
115 | use_multimask_token_for_obj_ptr: true
116 | multimask_min_pt_num: 0
117 | multimask_max_pt_num: 1
118 | use_mlp_for_obj_ptr_proj: true
119 | # Compilation flag
120 | # HieraT does not currently support compilation, should always be set to False
121 | compile_image_encoder: False
122 |
--------------------------------------------------------------------------------
/sam2/configs/sam2/sam2_hiera_b+.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 112
12 | num_heads: 2
13 | neck:
14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15 | position_encoding:
16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17 | num_pos_feats: 256
18 | normalize: true
19 | scale: null
20 | temperature: 10000
21 | d_model: 256
22 | backbone_channel_list: [896, 448, 224, 112]
23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24 | fpn_interp_model: nearest
25 |
26 | memory_attention:
27 | _target_: sam2.modeling.memory_attention.MemoryAttention
28 | d_model: 256
29 | pos_enc_at_input: true
30 | layer:
31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32 | activation: relu
33 | dim_feedforward: 2048
34 | dropout: 0.1
35 | pos_enc_at_attn: false
36 | self_attention:
37 | _target_: sam2.modeling.sam.transformer.RoPEAttention
38 | rope_theta: 10000.0
39 | feat_sizes: [32, 32]
40 | embedding_dim: 256
41 | num_heads: 1
42 | downsample_rate: 1
43 | dropout: 0.1
44 | d_model: 256
45 | pos_enc_at_cross_attn_keys: true
46 | pos_enc_at_cross_attn_queries: false
47 | cross_attention:
48 | _target_: sam2.modeling.sam.transformer.RoPEAttention
49 | rope_theta: 10000.0
50 | feat_sizes: [32, 32]
51 | rope_k_repeat: True
52 | embedding_dim: 256
53 | num_heads: 1
54 | downsample_rate: 1
55 | dropout: 0.1
56 | kv_in_dim: 64
57 | num_layers: 4
58 |
59 | memory_encoder:
60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
61 | out_dim: 64
62 | position_encoding:
63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64 | num_pos_feats: 64
65 | normalize: true
66 | scale: null
67 | temperature: 10000
68 | mask_downsampler:
69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
70 | kernel_size: 3
71 | stride: 2
72 | padding: 1
73 | fuser:
74 | _target_: sam2.modeling.memory_encoder.Fuser
75 | layer:
76 | _target_: sam2.modeling.memory_encoder.CXBlock
77 | dim: 256
78 | kernel_size: 7
79 | padding: 3
80 | layer_scale_init_value: 1e-6
81 | use_dwconv: True # depth-wise convs
82 | num_layers: 2
83 |
84 | num_maskmem: 7
85 | image_size: 1024
86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87 | sigmoid_scale_for_mem_enc: 20.0
88 | sigmoid_bias_for_mem_enc: -10.0
89 | use_mask_input_as_output_without_sam: true
90 | # Memory
91 | directly_add_no_mem_embed: true
92 | # use high-resolution feature map in the SAM mask decoder
93 | use_high_res_features_in_sam: true
94 | # output 3 masks on the first click on initial conditioning frames
95 | multimask_output_in_sam: true
96 | # SAM heads
97 | iou_prediction_use_sigmoid: True
98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99 | use_obj_ptrs_in_encoder: true
100 | add_tpos_enc_to_obj_ptrs: false
101 | only_obj_ptrs_in_the_past_for_eval: true
102 | # object occlusion prediction
103 | pred_obj_scores: true
104 | pred_obj_scores_mlp: true
105 | fixed_no_obj_ptr: true
106 | # multimask tracking settings
107 | multimask_output_for_tracking: true
108 | use_multimask_token_for_obj_ptr: true
109 | multimask_min_pt_num: 0
110 | multimask_max_pt_num: 1
111 | use_mlp_for_obj_ptr_proj: true
112 | # Compilation flag
113 | compile_image_encoder: False
114 |
--------------------------------------------------------------------------------
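Editor's note: the Hydra configs under sam2/configs are not loaded directly; they are resolved by the builder in sam2/build_sam.py. Below is a minimal sketch of instantiating the base-plus model from the config above. It assumes the upstream build_sam2() and SAM2ImagePredictor entry points keep their usual signatures, and the checkpoint path is illustrative (see checkpoints/download_ckpts.sh).

# Minimal sketch (not part of the repo): load SAM 2 base-plus from this config.
# Assumes build_sam2() from sam2/build_sam.py and SAM2ImagePredictor keep their
# upstream signatures; the checkpoint path below is illustrative.
import torch
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

device = "cuda" if torch.cuda.is_available() else "cpu"
model = build_sam2("sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt", device=device)
predictor = SAM2ImagePredictor(model)

--------------------------------------------------------------------------------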
/sam2/configs/sam2/sam2_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 144
12 | num_heads: 2
13 | stages: [2, 6, 36, 4]
14 | global_att_blocks: [23, 33, 43]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | window_spec: [8, 4, 16, 8]
17 | neck:
18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19 | position_encoding:
20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21 | num_pos_feats: 256
22 | normalize: true
23 | scale: null
24 | temperature: 10000
25 | d_model: 256
26 | backbone_channel_list: [1152, 576, 288, 144]
27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 | fpn_interp_model: nearest
29 |
30 | memory_attention:
31 | _target_: sam2.modeling.memory_attention.MemoryAttention
32 | d_model: 256
33 | pos_enc_at_input: true
34 | layer:
35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36 | activation: relu
37 | dim_feedforward: 2048
38 | dropout: 0.1
39 | pos_enc_at_attn: false
40 | self_attention:
41 | _target_: sam2.modeling.sam.transformer.RoPEAttention
42 | rope_theta: 10000.0
43 | feat_sizes: [32, 32]
44 | embedding_dim: 256
45 | num_heads: 1
46 | downsample_rate: 1
47 | dropout: 0.1
48 | d_model: 256
49 | pos_enc_at_cross_attn_keys: true
50 | pos_enc_at_cross_attn_queries: false
51 | cross_attention:
52 | _target_: sam2.modeling.sam.transformer.RoPEAttention
53 | rope_theta: 10000.0
54 | feat_sizes: [32, 32]
55 | rope_k_repeat: True
56 | embedding_dim: 256
57 | num_heads: 1
58 | downsample_rate: 1
59 | dropout: 0.1
60 | kv_in_dim: 64
61 | num_layers: 4
62 |
63 | memory_encoder:
64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
65 | out_dim: 64
66 | position_encoding:
67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68 | num_pos_feats: 64
69 | normalize: true
70 | scale: null
71 | temperature: 10000
72 | mask_downsampler:
73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
74 | kernel_size: 3
75 | stride: 2
76 | padding: 1
77 | fuser:
78 | _target_: sam2.modeling.memory_encoder.Fuser
79 | layer:
80 | _target_: sam2.modeling.memory_encoder.CXBlock
81 | dim: 256
82 | kernel_size: 7
83 | padding: 3
84 | layer_scale_init_value: 1e-6
85 | use_dwconv: True # depth-wise convs
86 | num_layers: 2
87 |
88 | num_maskmem: 7
89 | image_size: 1024
90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | compile_image_encoder: False
118 |
--------------------------------------------------------------------------------
/sam2/configs/sam2/sam2_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 11, 2]
14 | global_att_blocks: [7, 10, 13]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | sigmoid_scale_for_mem_enc: 20.0
91 | sigmoid_bias_for_mem_enc: -10.0
92 | use_mask_input_as_output_without_sam: true
93 | # Memory
94 | directly_add_no_mem_embed: true
95 | # use high-resolution feature map in the SAM mask decoder
96 | use_high_res_features_in_sam: true
97 | # output 3 masks on the first click on initial conditioning frames
98 | multimask_output_in_sam: true
99 | # SAM heads
100 | iou_prediction_use_sigmoid: True
101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102 | use_obj_ptrs_in_encoder: true
103 | add_tpos_enc_to_obj_ptrs: false
104 | only_obj_ptrs_in_the_past_for_eval: true
105 | # object occlusion prediction
106 | pred_obj_scores: true
107 | pred_obj_scores_mlp: true
108 | fixed_no_obj_ptr: true
109 | # multimask tracking settings
110 | multimask_output_for_tracking: true
111 | use_multimask_token_for_obj_ptr: true
112 | multimask_min_pt_num: 0
113 | multimask_max_pt_num: 1
114 | use_mlp_for_obj_ptr_proj: true
115 | # Compilation flag
116 | compile_image_encoder: False
117 |
--------------------------------------------------------------------------------
/sam2/configs/sam2/sam2_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 7, 2]
14 | global_att_blocks: [5, 7, 9]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | # SAM decoder
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | # HieraT does not currently support compilation, should always be set to False
118 | compile_image_encoder: False
119 |
--------------------------------------------------------------------------------
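Editor's note: the four sam2_hiera_* configs above differ only in the Hiera trunk hyper-parameters (embed_dim, num_heads, stages, global_att_blocks, window settings); everything from the neck onward is identical. In particular, backbone_channel_list is just embed_dim doubled per stage, listed deepest (lowest-resolution) level first, as the quick check below confirms using only the values copied from the YAMLs above.

# The FPN channel list in each config is the trunk embed_dim doubled per Hiera stage,
# deepest level first; the expected values are copied from the YAMLs above.
def backbone_channels(embed_dim: int, num_stages: int = 4) -> list:
    return [embed_dim * 2 ** i for i in reversed(range(num_stages))]

assert backbone_channels(112) == [896, 448, 224, 112]   # sam2_hiera_b+
assert backbone_channels(144) == [1152, 576, 288, 144]  # sam2_hiera_l
assert backbone_channels(96)  == [768, 384, 192, 96]    # sam2_hiera_s / sam2_hiera_t

--------------------------------------------------------------------------------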
/sam2/configs/samurai/sam2.1_hiera_b+.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 112
12 | num_heads: 2
13 | neck:
14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15 | position_encoding:
16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17 | num_pos_feats: 256
18 | normalize: true
19 | scale: null
20 | temperature: 10000
21 | d_model: 256
22 | backbone_channel_list: [896, 448, 224, 112]
23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24 | fpn_interp_model: nearest
25 |
26 | memory_attention:
27 | _target_: sam2.modeling.memory_attention.MemoryAttention
28 | d_model: 256
29 | pos_enc_at_input: true
30 | layer:
31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32 | activation: relu
33 | dim_feedforward: 2048
34 | dropout: 0.1
35 | pos_enc_at_attn: false
36 | self_attention:
37 | _target_: sam2.modeling.sam.transformer.RoPEAttention
38 | rope_theta: 10000.0
39 | feat_sizes: [32, 32]
40 | embedding_dim: 256
41 | num_heads: 1
42 | downsample_rate: 1
43 | dropout: 0.1
44 | d_model: 256
45 | pos_enc_at_cross_attn_keys: true
46 | pos_enc_at_cross_attn_queries: false
47 | cross_attention:
48 | _target_: sam2.modeling.sam.transformer.RoPEAttention
49 | rope_theta: 10000.0
50 | feat_sizes: [32, 32]
51 | rope_k_repeat: True
52 | embedding_dim: 256
53 | num_heads: 1
54 | downsample_rate: 1
55 | dropout: 0.1
56 | kv_in_dim: 64
57 | num_layers: 4
58 |
59 | memory_encoder:
60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
61 | out_dim: 64
62 | position_encoding:
63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64 | num_pos_feats: 64
65 | normalize: true
66 | scale: null
67 | temperature: 10000
68 | mask_downsampler:
69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
70 | kernel_size: 3
71 | stride: 2
72 | padding: 1
73 | fuser:
74 | _target_: sam2.modeling.memory_encoder.Fuser
75 | layer:
76 | _target_: sam2.modeling.memory_encoder.CXBlock
77 | dim: 256
78 | kernel_size: 7
79 | padding: 3
80 | layer_scale_init_value: 1e-6
81 | use_dwconv: True # depth-wise convs
82 | num_layers: 2
83 |
84 | num_maskmem: 7
85 | image_size: 1024
86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87 | sigmoid_scale_for_mem_enc: 20.0
88 | sigmoid_bias_for_mem_enc: -10.0
89 | use_mask_input_as_output_without_sam: true
90 | # Memory
91 | directly_add_no_mem_embed: true
92 | no_obj_embed_spatial: true
93 | # use high-resolution feature map in the SAM mask decoder
94 | use_high_res_features_in_sam: true
95 | # output 3 masks on the first click on initial conditioning frames
96 | multimask_output_in_sam: true
97 | # SAM heads
98 | iou_prediction_use_sigmoid: True
99 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
100 | use_obj_ptrs_in_encoder: true
101 | add_tpos_enc_to_obj_ptrs: true
102 | proj_tpos_enc_in_obj_ptrs: true
103 | use_signed_tpos_enc_to_obj_ptrs: true
104 | only_obj_ptrs_in_the_past_for_eval: true
105 | # object occlusion prediction
106 | pred_obj_scores: true
107 | pred_obj_scores_mlp: true
108 | fixed_no_obj_ptr: true
109 | # multimask tracking settings
110 | multimask_output_for_tracking: true
111 | use_multimask_token_for_obj_ptr: true
112 | multimask_min_pt_num: 0
113 | multimask_max_pt_num: 1
114 | use_mlp_for_obj_ptr_proj: true
115 | # Compilation flag
116 | compile_image_encoder: False
117 | # SAMURAI
118 | samurai_mode: true
119 | stable_frames_threshold: 15
120 | stable_ious_threshold: 0.3
121 | min_obj_score_logits: -1
122 | kf_score_weight: 0.25
123 | memory_bank_iou_threshold: 0.5
124 | memory_bank_obj_score_threshold: 0.0
125 | memory_bank_kf_score_threshold: 0.0
126 |
--------------------------------------------------------------------------------
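Editor's note: relative to the plain SAM 2 configs, the samurai/ configs add the SAMURAI block at the bottom: motion-aware candidate scoring (kf_score_weight) and memory-bank gating thresholds. The sketch below is a hedged illustration of what the kf_score_weight blend implies; the names and exact formula are illustrative only, and the real logic lives in the samurai_mode branches of the predictor code (see sam2/utils/kalman_filter.py and sam2/sam2_video_predictor.py).

# Hedged sketch (illustrative, not the repo's code): kf_score_weight blends the mask
# head's predicted IoU with a Kalman-filter motion-consistency score when ranking
# multimask candidates.
def fused_score(pred_iou: float, kf_iou: float, kf_score_weight: float = 0.25) -> float:
    return (1.0 - kf_score_weight) * pred_iou + kf_score_weight * kf_iou

# A candidate with slightly lower mask IoU but better motion agreement can win:
print(fused_score(0.80, 0.90))  # 0.825
print(fused_score(0.82, 0.40))  # 0.715

--------------------------------------------------------------------------------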
/sam2/configs/samurai/sam2.1_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 144
12 | num_heads: 2
13 | stages: [2, 6, 36, 4]
14 | global_att_blocks: [23, 33, 43]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | window_spec: [8, 4, 16, 8]
17 | neck:
18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19 | position_encoding:
20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21 | num_pos_feats: 256
22 | normalize: true
23 | scale: null
24 | temperature: 10000
25 | d_model: 256
26 | backbone_channel_list: [1152, 576, 288, 144]
27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 | fpn_interp_model: nearest
29 |
30 | memory_attention:
31 | _target_: sam2.modeling.memory_attention.MemoryAttention
32 | d_model: 256
33 | pos_enc_at_input: true
34 | layer:
35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36 | activation: relu
37 | dim_feedforward: 2048
38 | dropout: 0.1
39 | pos_enc_at_attn: false
40 | self_attention:
41 | _target_: sam2.modeling.sam.transformer.RoPEAttention
42 | rope_theta: 10000.0
43 | feat_sizes: [32, 32]
44 | embedding_dim: 256
45 | num_heads: 1
46 | downsample_rate: 1
47 | dropout: 0.1
48 | d_model: 256
49 | pos_enc_at_cross_attn_keys: true
50 | pos_enc_at_cross_attn_queries: false
51 | cross_attention:
52 | _target_: sam2.modeling.sam.transformer.RoPEAttention
53 | rope_theta: 10000.0
54 | feat_sizes: [32, 32]
55 | rope_k_repeat: True
56 | embedding_dim: 256
57 | num_heads: 1
58 | downsample_rate: 1
59 | dropout: 0.1
60 | kv_in_dim: 64
61 | num_layers: 4
62 |
63 | memory_encoder:
64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
65 | out_dim: 64
66 | position_encoding:
67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68 | num_pos_feats: 64
69 | normalize: true
70 | scale: null
71 | temperature: 10000
72 | mask_downsampler:
73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
74 | kernel_size: 3
75 | stride: 2
76 | padding: 1
77 | fuser:
78 | _target_: sam2.modeling.memory_encoder.Fuser
79 | layer:
80 | _target_: sam2.modeling.memory_encoder.CXBlock
81 | dim: 256
82 | kernel_size: 7
83 | padding: 3
84 | layer_scale_init_value: 1e-6
85 | use_dwconv: True # depth-wise convs
86 | num_layers: 2
87 |
88 | num_maskmem: 7
89 | image_size: 1024
90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | no_obj_embed_spatial: true
97 | # use high-resolution feature map in the SAM mask decoder
98 | use_high_res_features_in_sam: true
99 | # output 3 masks on the first click on initial conditioning frames
100 | multimask_output_in_sam: true
101 | # SAM heads
102 | iou_prediction_use_sigmoid: True
103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104 | use_obj_ptrs_in_encoder: true
105 | add_tpos_enc_to_obj_ptrs: true
106 | proj_tpos_enc_in_obj_ptrs: true
107 | use_signed_tpos_enc_to_obj_ptrs: true
108 | only_obj_ptrs_in_the_past_for_eval: true
109 | # object occlusion prediction
110 | pred_obj_scores: true
111 | pred_obj_scores_mlp: true
112 | fixed_no_obj_ptr: true
113 | # multimask tracking settings
114 | multimask_output_for_tracking: true
115 | use_multimask_token_for_obj_ptr: true
116 | multimask_min_pt_num: 0
117 | multimask_max_pt_num: 1
118 | use_mlp_for_obj_ptr_proj: true
119 | # Compilation flag
120 | compile_image_encoder: False
121 | # SAMURAI
122 | samurai_mode: true
123 | stable_frames_threshold: 15
124 | stable_ious_threshold: 0.3
125 | min_obj_score_logits: -1
126 | kf_score_weight: 0.15
127 | memory_bank_iou_threshold: 0.5
128 | memory_bank_obj_score_threshold: 0.0
129 | memory_bank_kf_score_threshold: 0.0
--------------------------------------------------------------------------------
/sam2/configs/samurai/sam2.1_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 11, 2]
14 | global_att_blocks: [7, 10, 13]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | sigmoid_scale_for_mem_enc: 20.0
91 | sigmoid_bias_for_mem_enc: -10.0
92 | use_mask_input_as_output_without_sam: true
93 | # Memory
94 | directly_add_no_mem_embed: true
95 | no_obj_embed_spatial: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: true
105 | proj_tpos_enc_in_obj_ptrs: true
106 | use_signed_tpos_enc_to_obj_ptrs: true
107 | only_obj_ptrs_in_the_past_for_eval: true
108 | # object occlusion prediction
109 | pred_obj_scores: true
110 | pred_obj_scores_mlp: true
111 | fixed_no_obj_ptr: true
112 | # multimask tracking settings
113 | multimask_output_for_tracking: true
114 | use_multimask_token_for_obj_ptr: true
115 | multimask_min_pt_num: 0
116 | multimask_max_pt_num: 1
117 | use_mlp_for_obj_ptr_proj: true
118 | # Compilation flag
119 | compile_image_encoder: False
120 | # SAMURAI
121 | samurai_mode: true
122 | stable_frames_threshold: 15
123 | stable_ious_threshold: 0.3
124 | min_obj_score_logits: -1
125 | kf_score_weight: 0.25
126 | memory_bank_iou_threshold: 0.5
127 | memory_bank_obj_score_threshold: 0.0
128 | memory_bank_kf_score_threshold: 0.0
--------------------------------------------------------------------------------
/sam2/configs/samurai/sam2.1_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 7, 2]
14 | global_att_blocks: [5, 7, 9]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | # SAM decoder
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | no_obj_embed_spatial: true
97 | # use high-resolution feature map in the SAM mask decoder
98 | use_high_res_features_in_sam: true
99 | # output 3 masks on the first click on initial conditioning frames
100 | multimask_output_in_sam: true
101 | # SAM heads
102 | iou_prediction_use_sigmoid: True
103 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
104 | use_obj_ptrs_in_encoder: true
105 | add_tpos_enc_to_obj_ptrs: true
106 | proj_tpos_enc_in_obj_ptrs: true
107 | use_signed_tpos_enc_to_obj_ptrs: true
108 | only_obj_ptrs_in_the_past_for_eval: true
109 | # object occlusion prediction
110 | pred_obj_scores: true
111 | pred_obj_scores_mlp: true
112 | fixed_no_obj_ptr: true
113 | # multimask tracking settings
114 | multimask_output_for_tracking: true
115 | use_multimask_token_for_obj_ptr: true
116 | multimask_min_pt_num: 0
117 | multimask_max_pt_num: 1
118 | use_mlp_for_obj_ptr_proj: true
119 | # Compilation flag
120 | # HieraT does not currently support compilation, should always be set to False
121 | compile_image_encoder: False
122 | # SAMURAI
123 | samurai_mode: true
124 | stable_frames_threshold: 15
125 | stable_ious_threshold: 0.3
126 | min_obj_score_logits: -1
127 | kf_score_weight: 0.25
128 | memory_bank_iou_threshold: 0.5
129 | memory_bank_obj_score_threshold: 0.0
130 | memory_bank_kf_score_threshold: 0.0
--------------------------------------------------------------------------------
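Editor's note: for video tracking, the samurai/ configs above are consumed through the video predictor builder. A minimal sketch follows, assuming build_sam2_video_predictor() from sam2/build_sam.py and the predictor's init_state() API keep their upstream behaviour; the config, checkpoint, and video paths are illustrative.

# Minimal sketch (assumptions noted above): build a SAMURAI-mode video predictor.
import torch
from sam2.build_sam import build_sam2_video_predictor

device = "cuda" if torch.cuda.is_available() else "cpu"
predictor = build_sam2_video_predictor(
    "configs/samurai/sam2.1_hiera_t.yaml",   # config name, resolved by Hydra
    "checkpoints/sam2.1_hiera_tiny.pt",      # illustrative checkpoint path
    device=device,
)
state = predictor.init_state(video_path="demo.mp4")  # a video file or frame folder

--------------------------------------------------------------------------------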
/sam2/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/modeling/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/modeling/backbones/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class ImageEncoder(nn.Module):
15 | def __init__(
16 | self,
17 | trunk: nn.Module,
18 | neck: nn.Module,
19 | scalp: int = 0,
20 | ):
21 | super().__init__()
22 | self.trunk = trunk
23 | self.neck = neck
24 | self.scalp = scalp
25 | assert (
26 | self.trunk.channel_list == self.neck.backbone_channel_list
27 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28 |
29 | def forward(self, sample: torch.Tensor):
30 | # Forward through backbone
31 | features, pos = self.neck(self.trunk(sample))
32 | if self.scalp > 0:
33 | # Discard the lowest resolution features
34 | features, pos = features[: -self.scalp], pos[: -self.scalp]
35 |
36 | src = features[-1]
37 | output = {
38 | "vision_features": src,
39 | "vision_pos_enc": pos,
40 | "backbone_fpn": features,
41 | }
42 | return output
43 |
44 |
45 | class FpnNeck(nn.Module):
46 | """
47 | A modified variant of Feature Pyramid Network (FPN) neck
48 | (we remove the output conv and upsample top-down features with the mode set by
49 | fpn_interp_model, similar to ViT pos embed interpolation)
50 | """
51 |
52 | def __init__(
53 | self,
54 | position_encoding: nn.Module,
55 | d_model: int,
56 | backbone_channel_list: List[int],
57 | kernel_size: int = 1,
58 | stride: int = 1,
59 | padding: int = 0,
60 | fpn_interp_model: str = "bilinear",
61 | fuse_type: str = "sum",
62 | fpn_top_down_levels: Optional[List[int]] = None,
63 | ):
64 | """Initialize the neck
65 | :param trunk: the backbone
66 | :param position_encoding: the positional encoding to use
67 | :param d_model: the dimension of the model
68 | :param neck_norm: the normalization to use
69 | """
70 | super().__init__()
71 | self.position_encoding = position_encoding
72 | self.convs = nn.ModuleList()
73 | self.backbone_channel_list = backbone_channel_list
74 | self.d_model = d_model
75 | for dim in backbone_channel_list:
76 | current = nn.Sequential()
77 | current.add_module(
78 | "conv",
79 | nn.Conv2d(
80 | in_channels=dim,
81 | out_channels=d_model,
82 | kernel_size=kernel_size,
83 | stride=stride,
84 | padding=padding,
85 | ),
86 | )
87 |
88 | self.convs.append(current)
89 | self.fpn_interp_model = fpn_interp_model
90 | assert fuse_type in ["sum", "avg"]
91 | self.fuse_type = fuse_type
92 |
93 | # levels to have top-down features in its outputs
94 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
95 | # have top-down propagation, while outputs of level 0 and level 1 have only
96 | # lateral features from the same backbone level.
97 | if fpn_top_down_levels is None:
98 | # default is to have top-down features on all levels
99 | fpn_top_down_levels = range(len(self.convs))
100 | self.fpn_top_down_levels = list(fpn_top_down_levels)
101 |
102 | def forward(self, xs: List[torch.Tensor]):
103 |
104 | out = [None] * len(self.convs)
105 | pos = [None] * len(self.convs)
106 | assert len(xs) == len(self.convs)
107 | # fpn forward pass
108 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
109 | prev_features = None
110 | # forward in top-down order (from low to high resolution)
111 | n = len(self.convs) - 1
112 | for i in range(n, -1, -1):
113 | x = xs[i]
114 | lateral_features = self.convs[n - i](x)
115 | if i in self.fpn_top_down_levels and prev_features is not None:
116 | top_down_features = F.interpolate(
117 | prev_features.to(dtype=torch.float32),
118 | scale_factor=2.0,
119 | mode=self.fpn_interp_model,
120 | align_corners=(
121 | None if self.fpn_interp_model == "nearest" else False
122 | ),
123 | antialias=False,
124 | )
125 | prev_features = lateral_features + top_down_features
126 | if self.fuse_type == "avg":
127 | prev_features /= 2
128 | else:
129 | prev_features = lateral_features
130 | x_out = prev_features
131 | out[i] = x_out
132 | pos[i] = self.position_encoding(x_out).to(x_out.dtype)
133 |
134 | return out, pos
135 |
--------------------------------------------------------------------------------
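Editor's note: a small sketch of how FpnNeck is wired in the b+ config and what it returns, assuming the sam2 package from this repo is importable; the dummy spatial sizes are illustrative. It shows the input ordering (highest-resolution level first, matching convs[n - i] in forward) and that only levels 2 and 3 receive top-down features when fpn_top_down_levels is [2, 3].

# Sketch (assumes the sam2 package is importable; shapes are illustrative).
import torch
from sam2.modeling.backbones.image_encoder import FpnNeck
from sam2.modeling.position_encoding import PositionEmbeddingSine

neck = FpnNeck(
    position_encoding=PositionEmbeddingSine(num_pos_feats=256, normalize=True, scale=None, temperature=10000),
    d_model=256,
    backbone_channel_list=[896, 448, 224, 112],  # deepest (lowest-resolution) level first
    fpn_interp_model="nearest",
    fpn_top_down_levels=[2, 3],                  # levels 0 and 1 keep lateral features only
)
# xs[i] is pyramid level i, highest resolution first: channels [112, 224, 448, 896].
xs = [torch.randn(1, c, s, s) for c, s in zip([112, 224, 448, 896], [64, 32, 16, 8])]
out, pos = neck(xs)
print([tuple(o.shape) for o in out])  # all 256-channel, one map per input level

--------------------------------------------------------------------------------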
/sam2/modeling/backbones/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Some utilities for backbones, in particular for windowing"""
8 |
9 | from typing import Tuple
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | def window_partition(x, window_size):
17 | """
18 | Partition into non-overlapping windows with padding if needed.
19 | Args:
20 | x (tensor): input tokens with [B, H, W, C].
21 | window_size (int): window size.
22 | Returns:
23 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
24 | (Hp, Wp): padded height and width before partition
25 | """
26 | B, H, W, C = x.shape
27 |
28 | pad_h = (window_size - H % window_size) % window_size
29 | pad_w = (window_size - W % window_size) % window_size
30 | if pad_h > 0 or pad_w > 0:
31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32 | Hp, Wp = H + pad_h, W + pad_w
33 |
34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35 | windows = (
36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37 | )
38 | return windows, (Hp, Wp)
39 |
40 |
41 | def window_unpartition(windows, window_size, pad_hw, hw):
42 | """
43 | Reverse window_partition back to the original feature map, removing padding.
44 | Args:
45 | windows (tensor): windowed tokens with [B * num_windows, window_size, window_size, C].
46 | window_size (int): window size.
47 | pad_hw (Tuple): padded height and width (Hp, Wp).
48 | hw (Tuple): original height and width (H, W) before padding.
49 | Returns:
50 | x: unpartitioned sequences with [B, H, W, C].
51 | """
52 | Hp, Wp = pad_hw
53 | H, W = hw
54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55 | x = windows.view(
56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57 | )
58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59 |
60 | if Hp > H or Wp > W:
61 | x = x[:, :H, :W, :].contiguous()
62 | return x
63 |
64 |
65 | class PatchEmbed(nn.Module):
66 | """
67 | Image to Patch Embedding.
68 | """
69 |
70 | def __init__(
71 | self,
72 | kernel_size: Tuple[int, ...] = (7, 7),
73 | stride: Tuple[int, ...] = (4, 4),
74 | padding: Tuple[int, ...] = (3, 3),
75 | in_chans: int = 3,
76 | embed_dim: int = 768,
77 | ):
78 | """
79 | Args:
80 | kernel_size (Tuple): kernel size of the projection layer.
81 | stride (Tuple): stride of the projection layer.
82 | padding (Tuple): padding size of the projection layer.
83 | in_chans (int): Number of input image channels.
84 | embed_dim (int): Patch embedding dimension.
85 | """
86 | super().__init__()
87 | self.proj = nn.Conv2d(
88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89 | )
90 |
91 | def forward(self, x: torch.Tensor) -> torch.Tensor:
92 | x = self.proj(x)
93 | # B C H W -> B H W C
94 | x = x.permute(0, 2, 3, 1)
95 | return x
96 |
--------------------------------------------------------------------------------
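Editor's note: window_partition pads the feature map up to a multiple of the window size, and window_unpartition removes that padding again. The short round-trip below illustrates the bookkeeping; it only assumes the sam2 package from this repo is importable.

# Sketch: round-trip window_partition / window_unpartition on a map whose side
# is not a multiple of the window size.
import torch
from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 14, 14, 96)                         # (B, H, W, C); 14 is not a multiple of 8
windows, (Hp, Wp) = window_partition(x, window_size=8)
print(windows.shape, (Hp, Wp))                         # torch.Size([8, 8, 8, 96]) (16, 16)
y = window_unpartition(windows, 8, (Hp, Wp), (14, 14))
assert torch.equal(x, y)                               # padding is stripped exactly

--------------------------------------------------------------------------------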
/sam2/modeling/memory_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional
8 |
9 | import torch
10 | from torch import nn, Tensor
11 |
12 | from sam2.modeling.sam.transformer import RoPEAttention
13 |
14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones
15 |
16 |
17 | class MemoryAttentionLayer(nn.Module):
18 |
19 | def __init__(
20 | self,
21 | activation: str,
22 | cross_attention: nn.Module,
23 | d_model: int,
24 | dim_feedforward: int,
25 | dropout: float,
26 | pos_enc_at_attn: bool,
27 | pos_enc_at_cross_attn_keys: bool,
28 | pos_enc_at_cross_attn_queries: bool,
29 | self_attention: nn.Module,
30 | ):
31 | super().__init__()
32 | self.d_model = d_model
33 | self.dim_feedforward = dim_feedforward
34 | self.dropout_value = dropout
35 | self.self_attn = self_attention
36 | self.cross_attn_image = cross_attention
37 |
38 | # Implementation of Feedforward model
39 | self.linear1 = nn.Linear(d_model, dim_feedforward)
40 | self.dropout = nn.Dropout(dropout)
41 | self.linear2 = nn.Linear(dim_feedforward, d_model)
42 |
43 | self.norm1 = nn.LayerNorm(d_model)
44 | self.norm2 = nn.LayerNorm(d_model)
45 | self.norm3 = nn.LayerNorm(d_model)
46 | self.dropout1 = nn.Dropout(dropout)
47 | self.dropout2 = nn.Dropout(dropout)
48 | self.dropout3 = nn.Dropout(dropout)
49 |
50 | self.activation_str = activation
51 | self.activation = get_activation_fn(activation)
52 |
53 | # Where to add pos enc
54 | self.pos_enc_at_attn = pos_enc_at_attn
55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
57 |
58 | def _forward_sa(self, tgt, query_pos):
59 | # Self-Attention
60 | tgt2 = self.norm1(tgt)
61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
62 | tgt2 = self.self_attn(q, k, v=tgt2)
63 | tgt = tgt + self.dropout1(tgt2)
64 | return tgt
65 |
66 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
67 | kwds = {}
68 | if num_k_exclude_rope > 0:
69 | assert isinstance(self.cross_attn_image, RoPEAttention)
70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope}
71 |
72 | # Cross-Attention
73 | tgt2 = self.norm2(tgt)
74 | tgt2 = self.cross_attn_image(
75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
77 | v=memory,
78 | **kwds,
79 | )
80 | tgt = tgt + self.dropout2(tgt2)
81 | return tgt
82 |
83 | def forward(
84 | self,
85 | tgt,
86 | memory,
87 | pos: Optional[Tensor] = None,
88 | query_pos: Optional[Tensor] = None,
89 | num_k_exclude_rope: int = 0,
90 | ) -> torch.Tensor:
91 |
92 | # Self-Attn, Cross-Attn
93 | tgt = self._forward_sa(tgt, query_pos)
94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
95 | # MLP
96 | tgt2 = self.norm3(tgt)
97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
98 | tgt = tgt + self.dropout3(tgt2)
99 | return tgt
100 |
101 |
102 | class MemoryAttention(nn.Module):
103 | def __init__(
104 | self,
105 | d_model: int,
106 | pos_enc_at_input: bool,
107 | layer: nn.Module,
108 | num_layers: int,
109 | batch_first: bool = True, # Do layers expect batch first input?
110 | ):
111 | super().__init__()
112 | self.d_model = d_model
113 | self.layers = get_clones(layer, num_layers)
114 | self.num_layers = num_layers
115 | self.norm = nn.LayerNorm(d_model)
116 | self.pos_enc_at_input = pos_enc_at_input
117 | self.batch_first = batch_first
118 |
119 | def forward(
120 | self,
121 | curr: torch.Tensor, # self-attention inputs
122 | memory: torch.Tensor, # cross-attention inputs
123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
126 | ):
127 | if isinstance(curr, list):
128 | assert isinstance(curr_pos, list)
129 | assert len(curr) == len(curr_pos) == 1
130 | curr, curr_pos = (
131 | curr[0],
132 | curr_pos[0],
133 | )
134 |
135 | assert (
136 | curr.shape[1] == memory.shape[1]
137 | ), "Batch size must be the same for curr and memory"
138 |
139 | output = curr
140 | if self.pos_enc_at_input and curr_pos is not None:
141 | output = output + 0.1 * curr_pos
142 |
143 | if self.batch_first:
144 | # Convert to batch first
145 | output = output.transpose(0, 1)
146 | curr_pos = curr_pos.transpose(0, 1)
147 | memory = memory.transpose(0, 1)
148 | memory_pos = memory_pos.transpose(0, 1)
149 |
150 | for layer in self.layers:
151 | kwds = {}
152 | if isinstance(layer.cross_attn_image, RoPEAttention):
153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
154 |
155 | output = layer(
156 | tgt=output,
157 | memory=memory,
158 | pos=memory_pos,
159 | query_pos=curr_pos,
160 | **kwds,
161 | )
162 | normed_output = self.norm(output)
163 |
164 | if self.batch_first:
165 | # Convert back to seq first
166 | normed_output = normed_output.transpose(0, 1)
167 | curr_pos = curr_pos.transpose(0, 1)
168 |
169 | return normed_output
170 |
--------------------------------------------------------------------------------
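Editor's note: the YAML configs above instantiate this stack through Hydra; the sketch below builds it directly with the same values, assuming RoPEAttention in sam2/modeling/sam/transformer.py accepts the keyword arguments the configs pass to it. Construction only; at run time SAM2Base feeds sequence-first (HW, B, C) tensors, as the batch_first handling in MemoryAttention.forward suggests.

# Sketch: wire up MemoryAttention exactly as the configs above do (values copied from the YAMLs).
import torch
from sam2.modeling.memory_attention import MemoryAttention, MemoryAttentionLayer
from sam2.modeling.sam.transformer import RoPEAttention

self_attn = RoPEAttention(
    rope_theta=10000.0, feat_sizes=[32, 32],
    embedding_dim=256, num_heads=1, downsample_rate=1, dropout=0.1,
)
cross_attn = RoPEAttention(
    rope_theta=10000.0, feat_sizes=[32, 32], rope_k_repeat=True,
    embedding_dim=256, num_heads=1, downsample_rate=1, dropout=0.1, kv_in_dim=64,
)
layer = MemoryAttentionLayer(
    activation="relu", d_model=256, dim_feedforward=2048, dropout=0.1,
    pos_enc_at_attn=False, pos_enc_at_cross_attn_keys=True, pos_enc_at_cross_attn_queries=False,
    self_attention=self_attn, cross_attention=cross_attn,
)
memory_attention = MemoryAttention(d_model=256, pos_enc_at_input=True, layer=layer, num_layers=4)
print(sum(p.numel() for p in memory_attention.parameters()))  # parameter count of the 4-layer stack

--------------------------------------------------------------------------------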
/sam2/modeling/memory_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import Tuple
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d
15 |
16 |
17 | class MaskDownSampler(nn.Module):
18 | """
19 | Progressively downsample a mask by total_stride, each time by stride.
20 | Note that LayerNorm is applied per *token*, like in ViT.
21 |
22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23 | In the end, we linearly project to embed_dim channels.
24 | """
25 |
26 | def __init__(
27 | self,
28 | embed_dim=256,
29 | kernel_size=4,
30 | stride=4,
31 | padding=0,
32 | total_stride=16,
33 | activation=nn.GELU,
34 | ):
35 | super().__init__()
36 | num_layers = int(math.log2(total_stride) // math.log2(stride))
37 | assert stride**num_layers == total_stride
38 | self.encoder = nn.Sequential()
39 | mask_in_chans, mask_out_chans = 1, 1
40 | for _ in range(num_layers):
41 | mask_out_chans = mask_in_chans * (stride**2)
42 | self.encoder.append(
43 | nn.Conv2d(
44 | mask_in_chans,
45 | mask_out_chans,
46 | kernel_size=kernel_size,
47 | stride=stride,
48 | padding=padding,
49 | )
50 | )
51 | self.encoder.append(LayerNorm2d(mask_out_chans))
52 | self.encoder.append(activation())
53 | mask_in_chans = mask_out_chans
54 |
55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56 |
57 | def forward(self, x):
58 | return self.encoder(x)
59 |
60 |
61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62 | class CXBlock(nn.Module):
63 | r"""ConvNeXt Block. There are two equivalent implementations:
64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66 | We use (2) as we find it slightly faster in PyTorch
67 |
68 | Args:
69 | dim (int): Number of input channels.
70 | drop_path (float): Stochastic depth rate. Default: 0.0
71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72 | """
73 |
74 | def __init__(
75 | self,
76 | dim,
77 | kernel_size=7,
78 | padding=3,
79 | drop_path=0.0,
80 | layer_scale_init_value=1e-6,
81 | use_dwconv=True,
82 | ):
83 | super().__init__()
84 | self.dwconv = nn.Conv2d(
85 | dim,
86 | dim,
87 | kernel_size=kernel_size,
88 | padding=padding,
89 | groups=dim if use_dwconv else 1,
90 | ) # depthwise conv
91 | self.norm = LayerNorm2d(dim, eps=1e-6)
92 | self.pwconv1 = nn.Linear(
93 | dim, 4 * dim
94 | ) # pointwise/1x1 convs, implemented with linear layers
95 | self.act = nn.GELU()
96 | self.pwconv2 = nn.Linear(4 * dim, dim)
97 | self.gamma = (
98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99 | if layer_scale_init_value > 0
100 | else None
101 | )
102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103 |
104 | def forward(self, x):
105 | input = x
106 | x = self.dwconv(x)
107 | x = self.norm(x)
108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109 | x = self.pwconv1(x)
110 | x = self.act(x)
111 | x = self.pwconv2(x)
112 | if self.gamma is not None:
113 | x = self.gamma * x
114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115 |
116 | x = input + self.drop_path(x)
117 | return x
118 |
119 |
120 | class Fuser(nn.Module):
121 | def __init__(self, layer, num_layers, dim=None, input_projection=False):
122 | super().__init__()
123 | self.proj = nn.Identity()
124 | self.layers = get_clones(layer, num_layers)
125 |
126 | if input_projection:
127 | assert dim is not None
128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129 |
130 | def forward(self, x):
131 | # normally x: (N, C, H, W)
132 | x = self.proj(x)
133 | for layer in self.layers:
134 | x = layer(x)
135 | return x
136 |
137 |
138 | class MemoryEncoder(nn.Module):
139 | def __init__(
140 | self,
141 | out_dim,
142 | mask_downsampler,
143 | fuser,
144 | position_encoding,
145 | in_dim=256, # in_dim of pix_feats
146 | ):
147 | super().__init__()
148 |
149 | self.mask_downsampler = mask_downsampler
150 |
151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152 | self.fuser = fuser
153 | self.position_encoding = position_encoding
154 | self.out_proj = nn.Identity()
155 | if out_dim != in_dim:
156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157 |
158 | def forward(
159 | self,
160 | pix_feat: torch.Tensor,
161 | masks: torch.Tensor,
162 | skip_mask_sigmoid: bool = False,
163 | ) -> dict:
164 | ## Process masks
165 | # apply sigmoid so that there is less domain shift from GT masks, which are boolean
166 | if not skip_mask_sigmoid:
167 | masks = F.sigmoid(masks)
168 | masks = self.mask_downsampler(masks)
169 |
170 | ## Fuse pix_feats and downsampled masks
171 | # in case the visual features are on a different device (e.g. offloaded to CPU), move them onto the masks' device
172 | pix_feat = pix_feat.to(masks.device)
173 |
174 | x = self.pix_feat_proj(pix_feat)
175 | x = x + masks
176 | x = self.fuser(x)
177 | x = self.out_proj(x)
178 |
179 | pos = self.position_encoding(x).to(x.dtype)
180 |
181 | return {"vision_features": x, "vision_pos_enc": [pos]}
182 |
--------------------------------------------------------------------------------
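Editor's note: with the settings used in the configs above (kernel_size 3, stride 2, padding 1, and the default total_stride of 16), MaskDownSampler stacks log2(16)/log2(2) = 4 strided convs, taking a 1024x1024 mask down to the 64x64 memory resolution with embed_dim = 256 channels. A quick check, assuming the sam2 package is importable:

# Sketch: MaskDownSampler with the config values; output shape matches the memory resolution.
import torch
from sam2.modeling.memory_encoder import MaskDownSampler

downsampler = MaskDownSampler(kernel_size=3, stride=2, padding=1)  # embed_dim=256, total_stride=16
mask = torch.randn(1, 1, 1024, 1024)
print(downsampler(mask).shape)  # torch.Size([1, 256, 64, 64])

--------------------------------------------------------------------------------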
/sam2/modeling/sam/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/modeling/sam/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional, Tuple, Type
8 |
9 | import torch
10 | from torch import nn
11 |
12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom
13 |
14 | from sam2.modeling.sam2_utils import LayerNorm2d
15 |
16 |
17 | class PromptEncoder(nn.Module):
18 | def __init__(
19 | self,
20 | embed_dim: int,
21 | image_embedding_size: Tuple[int, int],
22 | input_image_size: Tuple[int, int],
23 | mask_in_chans: int,
24 | activation: Type[nn.Module] = nn.GELU,
25 | ) -> None:
26 | """
27 | Encodes prompts for input to SAM's mask decoder.
28 |
29 | Arguments:
30 | embed_dim (int): The prompts' embedding dimension
31 | image_embedding_size (tuple(int, int)): The spatial size of the
32 | image embedding, as (H, W).
33 | input_image_size (tuple(int, int)): The padded size of the image as input
34 | to the image encoder, as (H, W).
35 | mask_in_chans (int): The number of hidden channels used for
36 | encoding input masks.
37 | activation (nn.Module): The activation to use when encoding
38 | input masks.
39 | """
40 | super().__init__()
41 | self.embed_dim = embed_dim
42 | self.input_image_size = input_image_size
43 | self.image_embedding_size = image_embedding_size
44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
45 |
46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
47 | point_embeddings = [
48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
49 | ]
50 | self.point_embeddings = nn.ModuleList(point_embeddings)
51 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
52 |
53 | self.mask_input_size = (
54 | 4 * image_embedding_size[0],
55 | 4 * image_embedding_size[1],
56 | )
57 | self.mask_downscaling = nn.Sequential(
58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
59 | LayerNorm2d(mask_in_chans // 4),
60 | activation(),
61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
62 | LayerNorm2d(mask_in_chans),
63 | activation(),
64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
65 | )
66 | self.no_mask_embed = nn.Embedding(1, embed_dim)
67 |
68 | def get_dense_pe(self) -> torch.Tensor:
69 | """
70 | Returns the positional encoding used to encode point prompts,
71 | applied to a dense set of points the shape of the image encoding.
72 |
73 | Returns:
74 | torch.Tensor: Positional encoding with shape
75 | 1x(embed_dim)x(embedding_h)x(embedding_w)
76 | """
77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
78 |
79 | def _embed_points(
80 | self,
81 | points: torch.Tensor,
82 | labels: torch.Tensor,
83 | pad: bool,
84 | ) -> torch.Tensor:
85 | """Embeds point prompts."""
86 | points = points + 0.5 # Shift to center of pixel
87 | if pad:
88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
90 | points = torch.cat([points, padding_point], dim=1)
91 | labels = torch.cat([labels, padding_label], dim=1)
92 | point_embedding = self.pe_layer.forward_with_coords(
93 | points, self.input_image_size
94 | )
95 | point_embedding[labels == -1] = 0.0
96 | point_embedding[labels == -1] += self.not_a_point_embed.weight
97 | point_embedding[labels == 0] += self.point_embeddings[0].weight
98 | point_embedding[labels == 1] += self.point_embeddings[1].weight
99 | point_embedding[labels == 2] += self.point_embeddings[2].weight
100 | point_embedding[labels == 3] += self.point_embeddings[3].weight
101 | return point_embedding
102 |
103 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
104 | """Embeds box prompts."""
105 | boxes = boxes + 0.5 # Shift to center of pixel
106 | coords = boxes.reshape(-1, 2, 2)
107 | corner_embedding = self.pe_layer.forward_with_coords(
108 | coords, self.input_image_size
109 | )
110 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
111 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
112 | return corner_embedding
113 |
114 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
115 | """Embeds mask inputs."""
116 | mask_embedding = self.mask_downscaling(masks)
117 | return mask_embedding
118 |
119 | def _get_batch_size(
120 | self,
121 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
122 | boxes: Optional[torch.Tensor],
123 | masks: Optional[torch.Tensor],
124 | ) -> int:
125 | """
126 | Gets the batch size of the output given the batch size of the input prompts.
127 | """
128 | if points is not None:
129 | return points[0].shape[0]
130 | elif boxes is not None:
131 | return boxes.shape[0]
132 | elif masks is not None:
133 | return masks.shape[0]
134 | else:
135 | return 1
136 |
137 | def _get_device(self) -> torch.device:
138 | return self.point_embeddings[0].weight.device
139 |
140 | def forward(
141 | self,
142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
143 | boxes: Optional[torch.Tensor],
144 | masks: Optional[torch.Tensor],
145 | ) -> Tuple[torch.Tensor, torch.Tensor]:
146 | """
147 | Embeds different types of prompts, returning both sparse and dense
148 | embeddings.
149 |
150 | Arguments:
151 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
152 | and labels to embed.
153 | boxes (torch.Tensor or none): boxes to embed
154 | masks (torch.Tensor or none): masks to embed
155 |
156 | Returns:
157 | torch.Tensor: sparse embeddings for the points and boxes, with shape
158 | BxNx(embed_dim), where N is determined by the number of input points
159 | and boxes.
160 | torch.Tensor: dense embeddings for the masks, in the shape
161 | Bx(embed_dim)x(embed_H)x(embed_W)
162 | """
163 | bs = self._get_batch_size(points, boxes, masks)
164 | sparse_embeddings = torch.empty(
165 | (bs, 0, self.embed_dim), device=self._get_device()
166 | )
167 | if points is not None:
168 | coords, labels = points
169 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
170 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
171 | if boxes is not None:
172 | box_embeddings = self._embed_boxes(boxes)
173 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
174 |
175 | if masks is not None:
176 | dense_embeddings = self._embed_masks(masks)
177 | else:
178 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
179 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
180 | )
181 |
182 | return sparse_embeddings, dense_embeddings
183 |
--------------------------------------------------------------------------------
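Editor's note: a short usage sketch for PromptEncoder with a single positive click, using the geometry implied by the configs above (1024-pixel input, 64x64 image embedding, 256-dim prompts). mask_in_chans = 16 is an illustrative value, and the sketch assumes the sam2 package is importable.

# Sketch: encode one positive point prompt and inspect the sparse/dense embedding shapes.
import torch
from sam2.modeling.sam.prompt_encoder import PromptEncoder

prompt_encoder = PromptEncoder(
    embed_dim=256,
    image_embedding_size=(64, 64),
    input_image_size=(1024, 1024),
    mask_in_chans=16,
)
points = torch.tensor([[[512.0, 384.0]]])  # (B=1, N=1, 2) in input-image pixels
labels = torch.tensor([[1]])               # 1 = positive click
sparse, dense = prompt_encoder(points=(points, labels), boxes=None, masks=None)
print(sparse.shape, dense.shape)           # (1, 2, 256) -- point + pad token -- and (1, 256, 64, 64)

--------------------------------------------------------------------------------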
/sam2/sam2_hiera_b+.yaml:
--------------------------------------------------------------------------------
1 | configs/sam2/sam2_hiera_b+.yaml
--------------------------------------------------------------------------------
/sam2/sam2_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | configs/sam2/sam2_hiera_l.yaml
--------------------------------------------------------------------------------
/sam2/sam2_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | configs/sam2/sam2_hiera_s.yaml
--------------------------------------------------------------------------------
/sam2/sam2_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | configs/sam2/sam2_hiera_t.yaml
--------------------------------------------------------------------------------
/sam2/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/sam2/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import warnings
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | from torchvision.transforms import Normalize, Resize, ToTensor
13 |
14 |
15 | class SAM2Transforms(nn.Module):
16 | def __init__(
17 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
18 | ):
19 | """
20 | Transforms for SAM2.
21 | """
22 | super().__init__()
23 | self.resolution = resolution
24 | self.mask_threshold = mask_threshold
25 | self.max_hole_area = max_hole_area
26 | self.max_sprinkle_area = max_sprinkle_area
27 | self.mean = [0.485, 0.456, 0.406]
28 | self.std = [0.229, 0.224, 0.225]
29 | self.to_tensor = ToTensor()
30 | self.transforms = torch.jit.script(
31 | nn.Sequential(
32 | Resize((self.resolution, self.resolution)),
33 | Normalize(self.mean, self.std),
34 | )
35 | )
36 |
37 | def __call__(self, x):
38 | x = self.to_tensor(x)
39 | return self.transforms(x)
40 |
41 | def forward_batch(self, img_list):
42 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
43 | img_batch = torch.stack(img_batch, dim=0)
44 | return img_batch
45 |
46 | def transform_coords(
47 | self, coords: torch.Tensor, normalize=False, orig_hw=None
48 | ) -> torch.Tensor:
49 | """
50 |         Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized form.
51 |         If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
52 |
53 |         Returns
54 |             Coordinates scaled to the model input resolution (range [0, resolution]), as expected by the SAM2 model.
55 | """
56 | if normalize:
57 | assert orig_hw is not None
58 | h, w = orig_hw
59 | coords = coords.clone()
60 | coords[..., 0] = coords[..., 0] / w
61 | coords[..., 1] = coords[..., 1] / h
62 |
63 | coords = coords * self.resolution # unnormalize coords
64 | return coords
65 |
66 | def transform_boxes(
67 | self, boxes: torch.Tensor, normalize=False, orig_hw=None
68 | ) -> torch.Tensor:
69 | """
70 |         Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized form.
71 |         If the coords are in absolute image coordinates, normalize should be set to True and the original image size is required.
72 | """
73 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
74 | return boxes
75 |
76 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
77 | """
78 | Perform PostProcessing on output masks.
79 | """
80 | from sam2.utils.misc import get_connected_components
81 |
82 | masks = masks.float()
83 | input_masks = masks
84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
85 | try:
86 | if self.max_hole_area > 0:
87 |                 # Holes are those connected components in the background with area <= self.max_hole_area
88 | # (background regions are those with mask scores <= self.mask_threshold)
89 | labels, areas = get_connected_components(
90 | mask_flat <= self.mask_threshold
91 | )
92 | is_hole = (labels > 0) & (areas <= self.max_hole_area)
93 | is_hole = is_hole.reshape_as(masks)
94 | # We fill holes with a small positive mask score (10.0) to change them to foreground.
95 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
96 |
97 | if self.max_sprinkle_area > 0:
98 | labels, areas = get_connected_components(
99 | mask_flat > self.mask_threshold
100 | )
101 |                 is_sprinkle = (labels > 0) & (areas <= self.max_sprinkle_area)
102 |                 is_sprinkle = is_sprinkle.reshape_as(masks)
103 |                 # We assign small foreground sprinkles a negative mask score (-10.0) to change them to background.
104 |                 masks = torch.where(is_sprinkle, self.mask_threshold - 10.0, masks)
105 | except Exception as e:
106 | # Skip the post-processing step if the CUDA kernel fails
107 | warnings.warn(
108 | f"{e}\n\nSkipping the post-processing step due to the error above. You can "
109 | "still use SAM 2 and it's OK to ignore the error above, although some post-processing "
110 | "functionality may be limited (which doesn't affect the results in most cases; see "
111 | "https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).",
112 | category=UserWarning,
113 | stacklevel=2,
114 | )
115 | masks = input_masks
116 |
117 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
118 | return masks
119 |
--------------------------------------------------------------------------------
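A minimal sketch of how SAM2Transforms is used (editor's addition): resolution and mask_threshold below are illustrative values, not read from any checkpoint config.

import numpy as np
import torch
from sam2.utils.transforms import SAM2Transforms

transforms = SAM2Transforms(resolution=1024, mask_threshold=0.0)

# Preprocess an HWC uint8 frame into a normalized (3, 1024, 1024) tensor.
frame = np.zeros((720, 1280, 3), dtype=np.uint8)
x = transforms(frame)

# Map an XYXY box from original-image pixels into model-input coordinates.
box = torch.tensor([[100.0, 50.0, 600.0, 400.0]])
box_model = transforms.transform_boxes(box, normalize=True, orig_hw=(720, 1280))
print(box_model.shape)  # torch.Size([1, 2, 2]) -> two corner points scaled to [0, 1024]
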
/scripts/dog.txt:
--------------------------------------------------------------------------------
1 | 450, 350, 250, 200
2 | 200, 100, 250, 400
--------------------------------------------------------------------------------
/scripts/get_boundary.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 | import json
4 | import sys
5 | import os
6 | import argparse
7 |
8 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
9 | from trace.conversation import conv_templates, SeparatorStyle
10 | from trace.constants import DEFAULT_MMODAL_TOKEN, MMODAL_TOKEN_INDEX
11 | from trace.mm_utils import get_model_name_from_path, tokenizer_MMODAL_token_all, process_video, process_image, KeywordsStoppingCriteria
12 | from trace.model.builder import load_pretrained_model
13 |
14 |
15 | def inference(args):
16 | # Video Inference
17 | paths = args.video_paths
18 | questions = args.questions
19 | modal_list = ['video']
20 |
21 | # 1. Initialize the model.
22 | model_path = args.model_path
23 | model_name = get_model_name_from_path(model_path)
24 | tokenizer, model, processor, context_len = load_pretrained_model(model_path, None, model_name)
25 | model = model.to('cuda')
26 | conv_mode = 'llama_2'
27 |
28 | # 2. Visual preprocess (load & transform image or video).
29 | if modal_list[0] == 'video':
30 | tensor, video_timestamps = process_video(paths[0], processor, model.config.image_aspect_ratio, num_frames=64)
31 | tensor = tensor.to(dtype=torch.float16, device='cuda', non_blocking=True)
32 | default_mm_token = DEFAULT_MMODAL_TOKEN["VIDEO"]
33 | modal_token_index = MMODAL_TOKEN_INDEX["VIDEO"]
34 | else:
35 | tensor = process_image(paths[0], processor, model.config.image_aspect_ratio)[0].to(dtype=torch.float16, device='cuda', non_blocking=True)
36 | default_mm_token = DEFAULT_MMODAL_TOKEN["IMAGE"]
37 | modal_token_index = MMODAL_TOKEN_INDEX["IMAGE"]
38 |
39 | tensor = [tensor]
40 | video_timestamps = [video_timestamps]
41 | heads = [1]
42 |
43 | # 3. Text preprocess (tag process & generate prompt).
44 | question = default_mm_token + "\n" + questions[0]
45 | conv = conv_templates[conv_mode].copy()
46 | conv.append_message(conv.roles[0], question)
47 | conv.append_message(conv.roles[1], None)
48 | prompt = conv.get_prompt()
49 | prompt += ''
50 | print(prompt)
51 | input_ids = tokenizer_MMODAL_token_all(prompt, tokenizer, return_tensors='pt').unsqueeze(0).to('cuda')
52 | attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda()
53 | stop_str = conv.sep if conv.sep_style in [SeparatorStyle.SINGLE] else conv.sep2
54 | keywords = [stop_str]
55 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
56 | do_sample = True
57 |
58 | with torch.inference_mode():
59 | output_ids = model.generate(
60 | input_ids,
61 | attention_mask=attention_masks,
62 | images_or_videos=tensor,
63 | modal_list=modal_list,
64 | do_sample=do_sample,
65 | temperature=0.2 if do_sample else 0.0,
66 | max_new_tokens=1024,
67 | use_cache=True,
68 | pad_token_id=tokenizer.eos_token_id,
69 | video_timestamps=video_timestamps,
70 | heads=heads
71 | )
72 |
73 | outputs = {
74 | 'timestamps': [],
75 | 'scores': [],
76 | 'captions': [],
77 | }
78 | cur_timestamps = []
79 | cur_timestamp = []
80 | cur_scores = []
81 | cur_score = []
82 | cur_caption = []
83 | for idx in output_ids[0]:
84 |         if idx <= 32000:  # text tokens; 32000 marks the end of a caption
85 | if idx == 32000:
86 | new_caption = tokenizer.decode(cur_caption, skip_special_tokens=True)
87 | outputs['captions'].append(new_caption)
88 | cur_caption = []
89 | else:
90 | cur_caption.append(idx)
91 |         elif idx <= 32013:  # time tokens: 32001 closes a timestamp pair, 32002 closes a single timestamp, the remaining ids decode to timestamp characters
92 | if idx == 32001:
93 | if len(cur_timestamp) > 0:
94 | cur_timestamps.append(float(''.join(cur_timestamp)))
95 | outputs['timestamps'].append(cur_timestamps)
96 | cur_timestamps = []
97 | cur_timestamp = []
98 | elif idx == 32002:
99 | if len(cur_timestamp) > 0:
100 | cur_timestamps.append(float(''.join(cur_timestamp)))
101 | cur_timestamp = []
102 | else:
103 | cur_timestamp.append(model.get_model().time_tokenizer.decode(idx - 32001))
104 |         else:  # score tokens: 32014 closes a score list, 32015 closes a single score, the remaining ids decode to score characters
105 | if idx == 32014:
106 | if len(cur_score) > 0:
107 | cur_scores.append(float(''.join(cur_score)))
108 | outputs['scores'].append(cur_scores)
109 | cur_scores = []
110 | cur_score = []
111 | elif idx == 32015:
112 | if len(cur_score) > 0:
113 | cur_scores.append(float(''.join(cur_score)))
114 | cur_score = []
115 | else:
116 | cur_score.append(model.get_model().score_tokenizer.decode(idx - 32014))
117 | if len(cur_caption):
118 | outputs['captions'].append(tokenizer.decode(cur_caption, skip_special_tokens=True))
119 |
120 | try:
121 | results = []
122 | for i in range(len(outputs['timestamps'])):
123 | output = {
124 | 'video': paths[0].split("/")[-1][:-4] + "_mask.mp4",
125 | 'segment': f"{outputs['timestamps'][i][0]}_{outputs['timestamps'][i][1]}",
126 | 'question': "",
127 | 'answer': outputs['captions'][i],
128 | }
129 | results.append(output)
130 |
131 | with open(f'./results/{paths[0].split("/")[-1].split(".")[0]}_boundary.json', 'w') as f:
132 | json.dump(results, f)
133 |
134 | except Exception as e:
135 | print(e)
136 | print("Failed to save the output to a json file.")
137 | with open(f'./results/{paths[0].split("/")[-1].split(".")[0]}_boundary.json', 'w') as f:
138 | json.dump([{"video": paths[0].split("/")[-1], "segment": f"0.0_{video_timestamps[0][1]}", "question": "", "answer": ""}], f)
139 |
140 |
141 | if __name__ == "__main__":
142 | parser = argparse.ArgumentParser(description="Inference script for boundary detection.")
143 | parser.add_argument("--video_paths", nargs='+', required=True, help="Paths to the input video files.")
144 | parser.add_argument("--questions", nargs='+', required=True, help="Questions for video inference.")
145 | parser.add_argument("--model_path", required=True, help="Path to the pretrained model.")
146 | args = parser.parse_args()
147 |
148 | inference(args)
149 |
--------------------------------------------------------------------------------
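What downstream scripts rely on is the JSON this script writes to ./results/<video>_boundary.json: a list with one entry per detected event. A sketch of its structure (editor's addition; the caption and times are purely illustrative):

# Structure of the <video>_boundary.json file produced above (illustrative values).
boundary_events = [
    {
        "video": "demo_mask.mp4",                      # input basename with a "_mask.mp4" suffix
        "segment": "0.0_12.5",                         # "<start>_<end>" in seconds
        "question": "",
        "answer": "A red ball bounces across the yard.",
    },
]
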
/scripts/get_masks.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import os.path as osp
4 | import numpy as np
5 | import cv2
6 | import torch
7 | import gc
8 | from tqdm import tqdm
9 | import sys
10 | sys.path.append("./")
11 | from sam2.build_sam import build_sam2_video_predictor
12 |
13 | color = [(255, 0, 0)]
14 |
15 | def load_txt(gt_path):
16 | with open(gt_path, 'r') as f:
17 | gt = f.readlines()
18 | prompts = {}
19 | for fid, line in enumerate(gt):
20 | x_min, y_min, x_max, y_max = line.strip().split(",")
21 | # x, y, w, h = int(x), int(y), int(w), int(h)
22 | x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
23 | prompts[fid] = ((x_min, y_min, x_max, y_max), 0)
24 | return prompts
25 |
26 | def determine_model_cfg(model_path):
27 | if "large" in model_path:
28 | return "configs/samurai/sam2.1_hiera_l.yaml"
29 | elif "base_plus" in model_path:
30 | return "configs/samurai/sam2.1_hiera_b+.yaml"
31 | elif "small" in model_path:
32 | return "configs/samurai/sam2.1_hiera_s.yaml"
33 | elif "tiny" in model_path:
34 | return "configs/samurai/sam2.1_hiera_t.yaml"
35 | else:
36 | raise ValueError("Unknown model size in path!")
37 |
38 | def prepare_frames_or_path(video_path):
39 | if video_path.endswith(".mp4") or osp.isdir(video_path):
40 | return video_path
41 | else:
42 | raise ValueError("Invalid video_path format. Should be .mp4 or a directory of jpg frames.")
43 |
44 | def main(args):
45 | model_cfg = determine_model_cfg(args.model_path)
46 | predictor = build_sam2_video_predictor(model_cfg, args.model_path, device="cuda:0")
47 | frames_or_path = prepare_frames_or_path(args.video_path)
48 | prompts = load_txt(args.txt_path)
49 | print(prompts)
50 |
51 | if args.save_to_video:
52 | if osp.isdir(args.video_path):
53 | frames = sorted([osp.join(args.video_path, f) for f in os.listdir(args.video_path) if f.endswith(".jpg")])
54 | loaded_frames = [cv2.imread(frame_path) for frame_path in frames]
55 | height, width = loaded_frames[0].shape[:2]
56 | else:
57 | cap = cv2.VideoCapture(args.video_path)
58 | loaded_frames = []
59 | while True:
60 | ret, frame = cap.read()
61 | if not ret:
62 | break
63 | loaded_frames.append(frame)
64 | cap.release()
65 |             if len(loaded_frames) == 0:
66 |                 raise ValueError("No frames were loaded from the video.")
67 |             height, width = loaded_frames[0].shape[:2]
68 |
69 |
70 |
71 |
72 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
73 | out = cv2.VideoWriter(args.video_output_path+f"/{osp.basename(args.video_path).split('.')[0]}_mask.mp4", fourcc, 30, (width, height))
74 |
75 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
76 | state = predictor.init_state(frames_or_path, offload_video_to_cpu=True)
77 | bbox, track_label = prompts[0]
78 | _, _, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)
79 |
80 |
81 | for frame_idx, object_ids, masks in tqdm(predictor.propagate_in_video(state)):
82 | # if frame_idx >= len(loaded_frames):
83 | # print(f"Frame index {frame_idx} out of range. Skipping.")
84 | # continue
85 |
86 | # img = loaded_frames[frame_idx]
87 | mask_to_vis = {}
88 | bbox_to_vis = {}
89 |
90 | for obj_id, mask in zip(object_ids, masks):
91 | mask = mask[0].cpu().numpy()
92 | mask = mask > 0.0
93 | non_zero_indices = np.argwhere(mask)
94 | if len(non_zero_indices) == 0:
95 | bbox = [0, 0, 0, 0]
96 | else:
97 | y_min, x_min = non_zero_indices.min(axis=0).tolist()
98 | y_max, x_max = non_zero_indices.max(axis=0).tolist()
99 | bbox = [x_min, y_min, x_max - x_min, y_max - y_min]
100 | bbox_to_vis[obj_id] = bbox
101 | mask_to_vis[obj_id] = mask
102 |
103 | if args.save_to_video:
104 | img = loaded_frames[frame_idx]
105 | for obj_id, mask in mask_to_vis.items():
106 | mask_img = np.zeros((height, width, 3), np.uint8)
107 | mask_img[mask] = color[(obj_id + 1) % len(color)]
108 | img = cv2.addWeighted(img, 1, mask_img, 0.2, 0)
109 |
110 |
111 | for obj_id, bbox in bbox_to_vis.items():
112 | cv2.rectangle(img, (bbox[0], bbox[1]), (bbox[0] + bbox[2], bbox[1] + bbox[3]), color[obj_id % len(color)], 2)
113 |
114 |                 # Optionally overlay the current timestamp on each frame, formatted with
115 |                 # one decimal place in seconds (e.g., "103.5s"), drawn in red and centered
116 |                 # horizontally along the bottom of the frame like a subtitle, occupying
117 |                 # roughly 1/5 of the frame height:
118 | # time = frame_idx / 30
119 | # time_text = f"{time:.1f}s"
120 | # font = cv2.FONT_HERSHEY_SIMPLEX
121 | # font_scale = 4
122 | # font_thickness = 8
123 | # font_color = (0, 0, 255)
124 | # text_size = cv2.getTextSize(time_text, font, font_scale, font_thickness)[0]
125 | # text_x = (width - text_size[0]) // 2
126 | # text_y = height - 5
127 | # cv2.putText(img, time_text, (text_x, text_y), font, font_scale, font_color, font_thickness)
128 |
129 | out.write(img)
130 |
131 | if args.save_to_video:
132 | out.release()
133 |
134 | del predictor, state
135 | gc.collect()
136 | torch.clear_autocast_cache()
137 | torch.cuda.empty_cache()
138 |
139 | if __name__ == "__main__":
140 | parser = argparse.ArgumentParser()
141 | parser.add_argument("--video_path", default="./assets/demo.mp4", help="Input video path or directory of frames.")
142 | parser.add_argument("--txt_path", default="./assets/demo.txt", help="Path to ground truth text file.")
143 | parser.add_argument("--model_path", default="./checkpoints/sam2.1_hiera_base_plus.pt", help="Path to the model checkpoint.")
144 | parser.add_argument("--video_output_path", default="./results/", help="Path to save the output video.")
145 | parser.add_argument("--save_to_video", default=True, help="Save results to a video.")
146 | args = parser.parse_args()
147 | main(args)
148 |
--------------------------------------------------------------------------------
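The box-prompt text file read by load_txt holds one "x_min, y_min, x_max, y_max" line per frame id, and only the frame-0 box is used to prompt the tracker. The script can also be driven programmatically with the same values as its argparse defaults (editor's addition; the main() import is hypothetical since scripts/ is normally run as plain scripts):

from argparse import Namespace

# Mirrors the CLI defaults of scripts/get_masks.py.
args = Namespace(
    video_path="./assets/demo.mp4",
    txt_path="./assets/demo.txt",
    model_path="./checkpoints/sam2.1_hiera_base_plus.pt",
    video_output_path="./results/",
    save_to_video=True,
)
# main(args)  # assumes `from get_masks import main` with scripts/ on sys.path
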
/scripts/get_vis.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import json
3 | import sys
4 | from tqdm import tqdm
5 |
6 | def add_captions_to_video(video_input_path, json_path, video_output_path):
7 | # Load JSON data
8 | with open(json_path, 'r') as f:
9 | captions_data = json.load(f)
10 |
11 | # Open the input video
12 | cap = cv2.VideoCapture(video_input_path)
13 | fps = int(cap.get(cv2.CAP_PROP_FPS))
14 | width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
15 | height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
16 | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
17 |
18 | # Initialize video writer
19 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
20 | out = cv2.VideoWriter(video_output_path, fourcc, fps, (width, height))
21 |
22 | # Helper function: Check if a frame is within a time segment
23 | def is_frame_in_segment(frame_idx, start, end, fps):
24 | timestamp = frame_idx / fps
25 | return start <= timestamp <= end
26 |
27 | # Helper function: Wrap text to fit within the video width
28 | def wrap_text(text, font, font_scale, thickness, max_width):
29 | words = text.split()
30 | lines = []
31 | current_line = ""
32 | for word in words:
33 | test_line = current_line + " " + word if current_line else word
34 | text_size = cv2.getTextSize(test_line, font, font_scale, thickness)[0]
35 | if text_size[0] <= max_width:
36 | current_line = test_line
37 | else:
38 | lines.append(current_line)
39 | current_line = word
40 | if current_line:
41 | lines.append(current_line)
42 | return lines
43 |
44 | # Dynamic scaling based on video resolution
45 | def get_font_scale_and_thickness(width, height):
46 | base_width = 1280.0 # Reference width for scaling
47 | scale_factor = width / base_width
48 |         font_scale = max(0.5 * scale_factor, 0.4)  # Font scale proportional to width, floored at 0.4
49 |         thickness = max(int(1.5 * scale_factor), 1)  # Stroke thickness scaled with resolution, at least 1
50 | return font_scale, thickness
51 |
52 | # Process video frames
53 | frame_idx = 0
54 | with tqdm(total=frame_count, desc="Processing video") as pbar:
55 | while cap.isOpened():
56 | ret, frame = cap.read()
57 | if not ret:
58 | break
59 |
60 | # Find captions for the current frame
61 | for caption in captions_data:
62 | start_time = float(caption['segment'][0])
63 | end_time = float(caption['segment'][1])
64 | text = caption['model_answer']
65 |
66 | if is_frame_in_segment(frame_idx, start_time, end_time, fps):
67 | # Get font scale and thickness dynamically
68 | font = cv2.FONT_HERSHEY_SIMPLEX
69 | font_scale, font_thickness = get_font_scale_and_thickness(width, height)
70 | text_color = (255, 255, 255) # White
71 | bg_color = (0, 0, 0, 150) # Black with alpha for transparency
72 | margin = int(10 * (height / 720)) # Adjust margin proportionally
73 |                     max_width = int(width * 0.95)  # Wrap text at 95% of video width
74 |
75 | # Wrap text into multiple lines
76 | lines = wrap_text(text, font, font_scale, font_thickness, max_width)
77 | line_height = cv2.getTextSize("Test", font, font_scale, font_thickness)[0][1] + margin
78 |
79 | # Determine the text box position
80 | total_text_height = len(lines) * line_height
81 | text_x = (width - max_width) // 2
82 | text_y = height - margin - total_text_height
83 |
84 | # Create a transparent overlay
85 | overlay = frame.copy()
86 | cv2.rectangle(
87 | overlay,
88 | (text_x - margin, text_y - margin),
89 | (text_x + max_width + margin, text_y + total_text_height + margin),
90 | (0, 0, 0), # Black background
91 | -1
92 | )
93 |
94 | # Add the transparent overlay to the frame
95 | alpha = 0.6 # Transparency factor for the background
96 | frame = cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0)
97 |
98 | # Draw each line of text
99 | for i, line in enumerate(lines):
100 | line_y = text_y + (i * line_height) + line_height
101 | cv2.putText(
102 | frame,
103 | line,
104 | (text_x, line_y),
105 | font,
106 | font_scale,
107 | text_color,
108 | font_thickness,
109 | lineType=cv2.LINE_AA
110 | )
111 |
112 | out.write(frame)
113 | frame_idx += 1
114 | pbar.update(1)
115 |
116 | # Release resources
117 | cap.release()
118 | out.release()
119 | print(f"Captioned video saved to: {video_output_path}")
120 |
121 |
122 | if __name__ == "__main__":
123 | # Read arguments
124 | video_input_path = sys.argv[1]
125 | json_path = sys.argv[2]
126 | video_output_path = sys.argv[3]
127 |
128 | add_captions_to_video(video_input_path, json_path, video_output_path)
129 |
--------------------------------------------------------------------------------
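The caption JSON consumed above expects a numeric [start, end] pair under "segment" and the caption text under "model_answer". A minimal sketch (editor's addition; file names and captions are placeholders):

import json

captions = [
    {"segment": [0.0, 4.2], "model_answer": "A cat jumps onto the table."},
    {"segment": [4.2, 9.0], "model_answer": "The cat knocks a cup off the edge."},
]
with open("captions.json", "w") as f:
    json.dump(captions, f)

# Then: add_captions_to_video("input.mp4", "captions.json", "output.mp4"),
# or equivalently run: python scripts/get_vis.py input.mp4 captions.json output.mp4
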
/scripts/main_inference.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import gc
3 | import numpy as np
4 | import os
5 | import os.path as osp
6 | import pdb
7 | import torch
8 | from sam2.build_sam import build_sam2_video_predictor
9 | from tqdm import tqdm
10 |
11 |
12 | def load_lasot_gt(gt_path):
13 | with open(gt_path, 'r') as f:
14 | gt = f.readlines()
15 |
16 | # bbox in first frame are prompts
17 | prompts = {}
18 | fid = 0
19 | for line in gt:
20 | x, y, w, h = map(int, line.split(','))
21 | prompts[fid] = ((x, y, x+w, y+h), 0)
22 | fid += 1
23 |
24 | return prompts
25 |
26 | color = [
27 | (255, 0, 0),
28 | ]
29 |
30 | testing_set = "data/LaSOT/testing_set.txt"
31 | with open(testing_set, 'r') as f:
32 | test_videos = f.readlines()
33 |
34 | exp_name = "samurai"
35 | model_name = "base_plus"
36 |
37 | checkpoint = f"sam2/checkpoints/sam2.1_hiera_{model_name}.pt"
38 | if model_name == "base_plus":
39 | model_cfg = "configs/samurai/sam2.1_hiera_b+.yaml"
40 | else:
41 | model_cfg = f"configs/samurai/sam2.1_hiera_{model_name[0]}.yaml"
42 |
43 | video_folder = "data/LaSOT"
44 | pred_folder = f"results/{exp_name}/{exp_name}_{model_name}"
45 |
46 | save_to_video = True
47 | if save_to_video:
48 | vis_folder = f"visualization/{exp_name}/{model_name}"
49 | os.makedirs(vis_folder, exist_ok=True)
50 | vis_mask = {}
51 | vis_bbox = {}
52 |
53 | test_videos = sorted(test_videos)
54 | for vid, video in enumerate(test_videos):
55 |
56 | cat_name = video.split('-')[0]
57 | cid_name = video.split('-')[1]
58 | video_basename = video.strip()
59 | frame_folder = osp.join(video_folder, cat_name, video.strip(), "img")
60 |
61 | num_frames = len(os.listdir(osp.join(video_folder, cat_name, video.strip(), "img")))
62 |
63 | print(f"\033[91mRunning video [{vid+1}/{len(test_videos)}]: {video} with {num_frames} frames\033[0m")
64 |
65 | height, width = cv2.imread(osp.join(frame_folder, "00000001.jpg")).shape[:2]
66 |
67 | predictor = build_sam2_video_predictor(model_cfg, checkpoint, device="cuda:0")
68 |
69 | predictions = []
70 |
71 | if save_to_video:
72 | fourcc = cv2.VideoWriter_fourcc(*'mp4v')
73 | out = cv2.VideoWriter(osp.join(vis_folder, f'{video_basename}.mp4'), fourcc, 30, (width, height))
74 |
75 | # Start processing frames
76 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
77 | state = predictor.init_state(frame_folder, offload_video_to_cpu=True, offload_state_to_cpu=True, async_loading_frames=True)
78 |
79 | prompts = load_lasot_gt(osp.join(video_folder, cat_name, video.strip(), "groundtruth.txt"))
80 |
81 | bbox, track_label = prompts[0]
82 | frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, box=bbox, frame_idx=0, obj_id=0)
83 |
84 | for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
85 | mask_to_vis = {}
86 | bbox_to_vis = {}
87 |
88 | assert len(masks) == 1 and len(object_ids) == 1, "Only one object is supported right now"
89 | for obj_id, mask in zip(object_ids, masks):
90 | mask = mask[0].cpu().numpy()
91 | mask = mask > 0.0
92 | non_zero_indices = np.argwhere(mask)
93 | if len(non_zero_indices) == 0:
94 | bbox = [0, 0, 0, 0]
95 | else:
96 | y_min, x_min = non_zero_indices.min(axis=0).tolist()
97 | y_max, x_max = non_zero_indices.max(axis=0).tolist()
98 | bbox = [x_min, y_min, x_max-x_min, y_max-y_min]
99 | bbox_to_vis[obj_id] = bbox
100 | mask_to_vis[obj_id] = mask
101 |
102 | if save_to_video:
103 |
104 | img = cv2.imread(f'{frame_folder}/{frame_idx+1:08d}.jpg')
105 | if img is None:
106 | break
107 |
108 | for obj_id in mask_to_vis.keys():
109 | mask_img = np.zeros((height, width, 3), np.uint8)
110 | mask_img[mask_to_vis[obj_id]] = color[(obj_id+1)%len(color)]
111 | img = cv2.addWeighted(img, 1, mask_img, 0.75, 0)
112 |
113 | for obj_id in bbox_to_vis.keys():
114 | cv2.rectangle(img, (bbox_to_vis[obj_id][0], bbox_to_vis[obj_id][1]), (bbox_to_vis[obj_id][0]+bbox_to_vis[obj_id][2], bbox_to_vis[obj_id][1]+bbox_to_vis[obj_id][3]), color[(obj_id)%len(color)], 2)
115 |
116 | x1, y1, x2, y2 = prompts[frame_idx][0]
117 | cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
118 | out.write(img)
119 |
120 | predictions.append(bbox_to_vis)
121 |
122 | os.makedirs(pred_folder, exist_ok=True)
123 | with open(osp.join(pred_folder, f'{video_basename}.txt'), 'w') as f:
124 | for pred in predictions:
125 | x, y, w, h = pred[0]
126 | f.write(f"{x},{y},{w},{h}\n")
127 |
128 | if save_to_video:
129 | out.release()
130 |
131 | del predictor
132 | del state
133 | gc.collect()
134 | torch.clear_autocast_cache()
135 | torch.cuda.empty_cache()
136 |
--------------------------------------------------------------------------------
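For reference, the LaSOT groundtruth.txt parsed by load_lasot_gt stores one "x,y,w,h" box per line; only the frame-0 box prompts the tracker, while the per-frame boxes are drawn in green for comparison. A tiny sketch of the conversion (editor's addition; box values are made up):

# load_lasot_gt turns each "x,y,w,h" line into ((x, y, x + w, y + h), track_label).
def to_xyxy(line: str):
    x, y, w, h = map(int, line.split(','))
    return (x, y, x + w, y + h)

assert to_xyxy("367,101,40,57") == (367, 101, 407, 158)
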
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 |
8 | from setuptools import find_packages, setup
9 |
10 | # Package metadata
11 | NAME = "CAT-2"
12 | VERSION = "1.0"
13 | DESCRIPTION = "Caption Anything in Video: Object-centric Dense Video Captioning with Multimodal Controls"
14 | URL = "https://github.com/yunlong10/CAT-2"
15 | AUTHOR = "Tang, Yunlong and Bi, Jing and Hua, Hang and Xiao, Yunzhong and Song, Yizhi and Wang, Teng and Huang, Chao and Feng, Mingqian and Guo, Junjia and Liu, Zhuo and Song, Luchuan and Liang, Susan and Wang, Bingjie and Shimada, Daiki and Vosoughi, Ali and Zhang, Zeliang and Luo, Jiebo and Xu, Chenliang"
16 | AUTHOR_EMAIL = "yunlong.tang@rochester.edu"
17 | LICENSE = "Apache 2.0"
18 |
19 | # Read the contents of README file
20 | with open("README.md", "r", encoding="utf-8") as f:
21 | LONG_DESCRIPTION = f.read()
22 |
23 | # Required dependencies
24 | REQUIRED_PACKAGES = [
25 | "torch>=2.3.1",
26 | "torchvision>=0.18.1",
28 | "tqdm>=4.66.1",
29 | "hydra-core>=1.3.2",
30 | "iopath>=0.1.10",
31 | "pillow>=9.4.0",
32 | "matplotlib>=3.9.1",
33 | "moviepy==1.0.3",
34 | "accelerate>=0.26.0",
35 | "numpy==1.26.1",
36 | "tikzplotlib",
37 | "jpeg4py",
38 | "opencv-python",
39 | "lmdb",
40 | "pandas",
41 | "scipy",
42 | "loguru",
43 | "einops",
44 | "transformers==4.40.1",
45 | "timm",
46 | "decord",
47 | "imageio",
48 | "scenedetect",
49 | "SentencePiece",
50 | "gradio",
51 | ]
52 |
53 | EXTRA_PACKAGES = {
54 | "notebooks": [
55 | "matplotlib>=3.9.1",
56 | "jupyter>=1.0.0",
57 | "opencv-python>=4.7.0",
58 | "eva-decord>=0.6.1",
59 | ],
60 | "interactive-demo": [
61 | "Flask>=3.0.3",
62 | "Flask-Cors>=5.0.0",
63 | "av>=13.0.0",
64 | "dataclasses-json>=0.6.7",
65 | "eva-decord>=0.6.1",
66 | "gunicorn>=23.0.0",
67 | "imagesize>=1.4.1",
68 | "pycocotools>=2.0.8",
69 | "strawberry-graphql>=0.243.0",
70 | ],
71 | "dev": [
72 | "black==24.2.0",
73 | "usort==1.0.2",
74 | "ufmt==2.0.0b2",
75 | "fvcore>=0.1.5.post20221221",
76 | "pandas>=2.2.2",
77 | "scikit-image>=0.24.0",
78 | "tensorboard>=2.17.0",
79 | "pycocotools>=2.0.8",
80 | "tensordict>=0.5.0",
81 | "opencv-python>=4.7.0",
82 | "submitit>=1.5.1",
83 | ],
84 | }
85 |
86 | # By default, we also build the SAM 2 CUDA extension.
87 | # You may turn off CUDA build with `export SAM2_BUILD_CUDA=0`.
88 | BUILD_CUDA = os.getenv("SAM2_BUILD_CUDA", "1") == "1"
89 | # By default, we allow SAM 2 installation to proceed even with build errors.
90 | # You may force stopping on errors with `export SAM2_BUILD_ALLOW_ERRORS=0`.
91 | BUILD_ALLOW_ERRORS = os.getenv("SAM2_BUILD_ALLOW_ERRORS", "1") == "1"
92 |
93 | # Catch and skip errors during extension building and print a warning message
94 | # (note that this message only shows up under verbose build mode
95 | # "pip install -v -e ." or "python setup.py build_ext -v")
96 | CUDA_ERROR_MSG = (
97 | "{}\n\n"
98 | "Failed to build the SAM 2 CUDA extension due to the error above. "
99 | "You can still use SAM 2 and it's OK to ignore the error above, although some "
100 |     "post-processing functionality may be limited (which doesn't affect the results in most cases; "
101 |     "see https://github.com/facebookresearch/sam2/blob/main/INSTALL.md).\n"
102 | )
103 |
104 |
105 | def get_extensions():
106 | if not BUILD_CUDA:
107 | return []
108 |
109 | try:
110 | from torch.utils.cpp_extension import CUDAExtension
111 |
112 | srcs = ["sam2/csrc/connected_components.cu"]
113 | compile_args = {
114 | "cxx": [],
115 | "nvcc": [
116 | "-DCUDA_HAS_FP16=1",
117 | "-D__CUDA_NO_HALF_OPERATORS__",
118 | "-D__CUDA_NO_HALF_CONVERSIONS__",
119 | "-D__CUDA_NO_HALF2_OPERATORS__",
120 | ],
121 | }
122 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)]
123 | except Exception as e:
124 | if BUILD_ALLOW_ERRORS:
125 | print(CUDA_ERROR_MSG.format(e))
126 | ext_modules = []
127 | else:
128 | raise e
129 |
130 | return ext_modules
131 |
132 |
133 | try:
134 | from torch.utils.cpp_extension import BuildExtension
135 |
136 | class BuildExtensionIgnoreErrors(BuildExtension):
137 |
138 | def finalize_options(self):
139 | try:
140 | super().finalize_options()
141 | except Exception as e:
142 | print(CUDA_ERROR_MSG.format(e))
143 | self.extensions = []
144 |
145 | def build_extensions(self):
146 | try:
147 | super().build_extensions()
148 | except Exception as e:
149 | print(CUDA_ERROR_MSG.format(e))
150 | self.extensions = []
151 |
152 | def get_ext_filename(self, ext_name):
153 | try:
154 | return super().get_ext_filename(ext_name)
155 | except Exception as e:
156 | print(CUDA_ERROR_MSG.format(e))
157 | self.extensions = []
158 | return "_C.so"
159 |
160 | cmdclass = {
161 | "build_ext": (
162 | BuildExtensionIgnoreErrors.with_options(no_python_abi_suffix=True)
163 | if BUILD_ALLOW_ERRORS
164 | else BuildExtension.with_options(no_python_abi_suffix=True)
165 | )
166 | }
167 | except Exception as e:
168 | cmdclass = {}
169 | if BUILD_ALLOW_ERRORS:
170 | print(CUDA_ERROR_MSG.format(e))
171 | else:
172 | raise e
173 |
174 |
175 | # Setup configuration
176 | setup(
177 | name=NAME,
178 | version=VERSION,
179 | description=DESCRIPTION,
180 | long_description=LONG_DESCRIPTION,
181 | long_description_content_type="text/markdown",
182 | url=URL,
183 | author=AUTHOR,
184 | author_email=AUTHOR_EMAIL,
185 | license=LICENSE,
186 |     packages=find_packages(exclude=("notebooks",)),
187 | include_package_data=True,
188 | install_requires=REQUIRED_PACKAGES,
189 | extras_require=EXTRA_PACKAGES,
190 | python_requires=">=3.10.0",
191 | ext_modules=get_extensions(),
192 | cmdclass=cmdclass,
193 | )
194 |
--------------------------------------------------------------------------------
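The SAM2_BUILD_CUDA and SAM2_BUILD_ALLOW_ERRORS switches above are read from the environment at install time. A sketch of a CPU-only, fail-fast editable install driven from Python (editor's addition; exporting the variables in the shell before `pip install -e .` works the same way):

import os
import subprocess
import sys

# Skip the CUDA extension entirely and stop on any remaining build error.
env = dict(os.environ, SAM2_BUILD_CUDA="0", SAM2_BUILD_ALLOW_ERRORS="0")
subprocess.check_call([sys.executable, "-m", "pip", "install", "-e", "."], env=env)
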
/trace/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | import copy
3 | from functools import partial
4 |
5 | import torch
6 |
7 | from .model import TraceMistralForCausalLM
8 | from .model.builder import load_pretrained_model
9 | from .conversation import conv_templates, SeparatorStyle
10 | from .mm_utils import process_video, tokenizer_MMODAL_token, get_model_name_from_path, KeywordsStoppingCriteria
11 | from .constants import NUM_FRAMES, DEFAULT_MMODAL_TOKEN, DEFAULT_MMODAL_START_TOKEN, DEFAULT_MMODAL_END_TOKEN, MMODAL_TOKEN_INDEX
12 |
13 | def model_init(model_path=None):
14 | model_path = "DAMO-NLP-SG/Trace-7B" if model_path is None else model_path
15 | model_name = get_model_name_from_path(model_path)
16 | tokenizer, model, processor, context_len = load_pretrained_model(model_path, None, model_name)
17 |
18 | num_frames = model.config.num_frames if hasattr(model.config, "num_frames") else NUM_FRAMES
19 |
20 | return model, partial(process_video, aspect_ratio=None, processor=processor, num_frames=num_frames), tokenizer
21 |
22 |
23 | def infer(model, video, instruct, tokenizer, do_sample=False):
24 | """inference api of Trace for video understanding.
25 |
26 | Args:
27 | model: Trace model.
28 | video (torch.Tensor): video tensor (T, C, H, W).
29 | instruct (str): text instruction for understanding video.
30 | tokenizer: tokenizer.
31 | do_sample (bool): whether to sample.
32 | Returns:
33 | str: response of the model.
34 | """
35 |
36 | # 1. vision preprocess (load & transform image or video).
37 | tensor = [video.half().cuda()]
38 | modals = ["video"]
39 |
40 | # 2. text preprocess (tag process & generate prompt).
41 | modal_token = DEFAULT_MMODAL_TOKEN['VIDEO']
42 | modal_index = MMODAL_TOKEN_INDEX["VIDEO"]
43 | instruct = modal_token + '\n' + instruct
44 |
45 | conv = conv_templates["llama_2"].copy()
46 | conv.append_message(conv.roles[0], instruct)
47 | conv.append_message(conv.roles[1], None)
48 | prompt = conv.get_prompt()
49 |
50 | input_ids = tokenizer_MMODAL_token(prompt, tokenizer, modal_index, return_tensors='pt').unsqueeze(0).cuda()
51 | attention_masks = input_ids.ne(tokenizer.pad_token_id).long().cuda()
52 |
53 | # 3. generate response according to visual signals and prompts.
54 | stop_str = conv.sep if conv.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.QWEN] else conv.sep2
55 | # keywords = ["", ""]
56 | keywords = [stop_str]
57 | stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
58 |
59 | with torch.inference_mode():
60 | output_ids = model.generate(
61 | input_ids,
62 | attention_mask=attention_masks,
63 | images_or_videos=tensor,
64 | modal_list=modals,
65 | do_sample=do_sample,
66 | temperature=0.2 if do_sample else 0.0,
67 | max_new_tokens=1024,
68 | use_cache=True,
69 | stopping_criteria=[stopping_criteria],
70 | pad_token_id=tokenizer.eos_token_id,
71 | )
72 |
73 | outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
74 |
75 | return outputs
76 |
77 |
78 | def x_infer(video, question, model, tokenizer, mode='vanilla', do_sample=False):
79 | if mode == 'mcqa':
80 | instruction = f'{question}\nAnswer with the option\'s letter from the given choices directly and only give the best option.'
81 | return infer(model=model, tokenizer=tokenizer, video=video, instruct=instruction, do_sample=do_sample)
82 | elif mode == 'openend':
83 | instruction = f'{question}\nAnswer the question using a single word or a short phrase with multiple words.'
84 | return infer(model=model, tokenizer=tokenizer, video=video, instruct=instruction, do_sample=do_sample)
85 | elif mode == 'vanilla':
86 | instruction = question
87 | return infer(model=model, tokenizer=tokenizer, video=video, instruct=instruction, do_sample=do_sample)
--------------------------------------------------------------------------------
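A minimal end-to-end sketch of the helper API defined in this file (editor's addition; the checkpoint path and video file are placeholders, and the frame processor is assumed to return the video tensor first, as in scripts/get_boundary.py):

from trace import model_init, x_infer

model, processor, tokenizer = model_init("./checkpoints/trace")  # defaults to DAMO-NLP-SG/Trace-7B when None

out = processor("./assets/demo.mp4")
video = out[0] if isinstance(out, tuple) else out  # some call sites also receive frame timestamps

answer = x_infer(video, "What happens in this video?", model=model, tokenizer=tokenizer, mode="vanilla")
print(answer)
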
/trace/constants.py:
--------------------------------------------------------------------------------
1 | CONTROLLER_HEART_BEAT_EXPIRATION = 30
2 | WORKER_HEART_BEAT_INTERVAL = 15
3 |
4 | LOGDIR = "./log_dir"
5 |
6 | NUM_FRAMES = 8
7 | MAX_FRAMES = 128
8 | NUM_FRAMES_PER_SECOND = 1
9 | Grids = [(2, 2), (1, 2), (1, 3), (1, 4), (2, 1), (3, 1), (4, 1)]
10 |
11 | # Model Constants
12 | IGNORE_INDEX = -100
13 | IMAGE_TOKEN_INDEX = -200
14 | DEFAULT_IMAGE_TOKEN = "<image>"
15 | DEFAULT_VIDEO_TOKEN = "<video>"