├── README.md
├── gradio_demo
│   ├── cartoon
│   │   ├── 0.mp4
│   │   ├── 1.mp4
│   │   ├── 2.mp4
│   │   ├── 3.mp4
│   │   └── 4.mp4
│   ├── image_demo.py
│   ├── normal_videos
│   │   ├── 0.mp4
│   │   ├── 1.mp4
│   │   ├── 2.mp4
│   │   ├── 3.mp4
│   │   ├── 4.mp4
│   │   ├── 5.mp4
│   │   └── 6.mp4
│   ├── pipeline_minimax_remover.py
│   ├── sam2
│   │   ├── SAM2-Video-Predictor
│   │   │   └── checkpoints
│   │   │       ├── README.md
│   │   │       └── sam2_hiera_l.yaml
│   │   ├── __init__.py
│   │   ├── automatic_mask_generator.py
│   │   ├── build_sam.py
│   │   ├── configs
│   │   │   ├── __init__.py
│   │   │   ├── sam2_hiera_b+.yaml
│   │   │   ├── sam2_hiera_l.yaml
│   │   │   ├── sam2_hiera_s.yaml
│   │   │   └── sam2_hiera_t.yaml
│   │   ├── modeling
│   │   │   ├── __init__.py
│   │   │   ├── backbones
│   │   │   │   ├── __init__.py
│   │   │   │   ├── hieradet.py
│   │   │   │   ├── image_encoder.py
│   │   │   │   └── utils.py
│   │   │   ├── memory_attention.py
│   │   │   ├── memory_encoder.py
│   │   │   ├── position_encoding.py
│   │   │   ├── sam
│   │   │   │   ├── __init__.py
│   │   │   │   ├── mask_decoder.py
│   │   │   │   ├── prompt_encoder.py
│   │   │   │   └── transformer.py
│   │   │   ├── sam2_base.py
│   │   │   └── sam2_utils.py
│   │   ├── sam2_image_predictor.py
│   │   ├── sam2_video_predictor.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── amg.py
│   │       ├── download.py
│   │       ├── misc.py
│   │       ├── transforms.py
│   │       └── visualization.py
│   ├── test.py
│   └── transformer_minimax_remover.py
├── imgs
│   └── gradio_demo.gif
├── pipeline_minimax_remover.py
├── requirements.txt
├── test_minimax_remover.py
└── transformer_minimax_remover.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | MiniMax-Remover: Taming Bad Noise Helps Video Object Removal
3 |
4 |
5 |
6 | Bojia Zi*,
7 | Weixuan Peng*,
8 | Xianbiao Qi†,
9 | Jianan Wang, Shihao Zhao, Rong Xiao, Kam-Fai Wong
10 | * Equal contribution. † Corresponding author.
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 | ---
24 |
25 | ## 🚀 Overview
26 |
27 | **MiniMax-Remover** is a fast and effective video object remover based on minimax optimization. It operates in two stages: the first stage trains a remover with a simplified DiT architecture, and the second stage distills a robust remover that runs without CFG and with fewer inference steps.
28 |
29 | ---
30 |
31 | ## ✨ Features:
32 |
33 | * **Fast:** Requires only 6 inference steps and does not use CFG, making it highly efficient.
34 |
35 | * **Effective:** Seamlessly removes objects from videos and generates high-quality visual content.
36 |
37 | * **Robust:** Maintains robustness by preventing the regeneration of undesired objects or artifacts within the masked region, even under varying noise conditions.
38 |
39 | ---
40 |
41 | ## 🛠️ Installation
42 |
43 | All dependencies are listed in `requirements.txt`.
44 |
45 | ```bash
46 | pip install -r requirements.txt
47 | ```
48 |
49 | ---
50 |
51 | ## 🏃‍♂️ Gradio Demo
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 | You can use this Gradio demo to remove objects from your own videos. Note that you don't need to compile SAM2.
60 | ```bash
61 | cd gradio_demo
62 | python3 test.py
63 | ```
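The same folder also provides a single-image demo, `image_demo.py`, which downloads the SAM2 and remover checkpoints on startup and serves its UI on port 8000:

```bash
cd gradio_demo
python3 image_demo.py
```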
64 |
65 | ---
66 |
67 | ## 📂 Download
68 |
69 | ```shell
70 | huggingface-cli download zibojia/minimax-remover --include vae transformer scheduler --local-dir .
71 | ```
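Alternatively, the same weights can be fetched from Python with `huggingface_hub`, which is what the Gradio demo does on startup (note that this variant saves into `./model/` rather than the current directory):

```python
from huggingface_hub import snapshot_download

# Download the remover weights (vae/, transformer/, scheduler/) into ./model/
snapshot_download(repo_id="zibojia/minimax-remover", local_dir="./model/")
```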
72 |
73 | ---
74 |
75 | ## ⚡ Quick Start
76 |
77 | ### Minimal Example
78 |
79 | ```python
80 | import torch
81 | from diffusers.utils import export_to_video
82 | from decord import VideoReader
83 | from diffusers.models import AutoencoderKLWan
84 | from transformer_minimax_remover import Transformer3DModel
85 | from diffusers.schedulers import UniPCMultistepScheduler
86 | from pipeline_minimax_remover import Minimax_Remover_Pipeline
87 |
88 | random_seed = 42
89 | video_length = 81
90 | device = torch.device("cuda:0")
91 |
92 | # Load model weights separately
93 | vae = AutoencoderKLWan.from_pretrained("./vae", torch_dtype=torch.float16)
94 | transformer = Transformer3DModel.from_pretrained("./transformer", torch_dtype=torch.float16)
95 | scheduler = UniPCMultistepScheduler.from_pretrained("./scheduler")
96 |
97 | images = ...  # torch.Tensor of video frames, shape [f, h, w, c], values in [-1, 1]
98 | masks = ...   # torch.Tensor of binary masks, shape [f, h, w, 1], values in [0, 1]
99 |
100 | # Initialize the pipeline (pass the loaded weights as objects)
101 | pipe = Minimax_Remover_Pipeline(vae=vae, transformer=transformer,
102 |     scheduler=scheduler
103 | ).to(device)
104 |
105 | result = pipe(images=images, masks=masks, num_frames=video_length, height=480, width=832,
106 |     num_inference_steps=12, generator=torch.Generator(device=device).manual_seed(random_seed), iterations=6
107 | ).frames[0]
108 | export_to_video(result, "./output.mp4")
109 | ```
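In the snippet above, `images` and `masks` are left as placeholders. Below is a minimal sketch of one way to fill them with `decord` (already imported above), assuming the input clip lives at `./input.mp4` and a per-frame mask video at `./masks.mp4`; both paths, and the idea of storing masks as a video, are illustrative — in practice the masks typically come from SAM2, as in the Gradio demo.

```python
import numpy as np

def load_frames(path, num_frames):
    # Read the first `num_frames` frames as a float32 array of shape [f, h, w, c]
    vr = VideoReader(path)
    idx = list(range(min(num_frames, len(vr))))
    return vr.get_batch(idx).asnumpy().astype(np.float32)

# Frames rescaled from [0, 255] to [-1, 1]
images = torch.from_numpy(load_frames("./input.mp4", video_length) / 127.5 - 1.0)

# Masks binarized to {0, 1}, keeping a single channel -> shape [f, h, w, 1]
mask_frames = load_frames("./masks.mp4", video_length) / 255.0
masks = torch.from_numpy((mask_frames[..., :1] > 0.5).astype(np.float32))
```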
110 | ---
111 |
112 | ## 📧 Contact
113 |
114 | Feel free to send an email to [19210240030@fudan.edu.cn](mailto:19210240030@fudan.edu.cn) if you have any questions or suggestions.
115 |
--------------------------------------------------------------------------------
/gradio_demo/cartoon/0.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/0.mp4
--------------------------------------------------------------------------------
/gradio_demo/cartoon/1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/1.mp4
--------------------------------------------------------------------------------
/gradio_demo/cartoon/2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/2.mp4
--------------------------------------------------------------------------------
/gradio_demo/cartoon/3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/3.mp4
--------------------------------------------------------------------------------
/gradio_demo/cartoon/4.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/4.mp4
--------------------------------------------------------------------------------
/gradio_demo/image_demo.py:
--------------------------------------------------------------------------------
1 | # Gradio demo: take one image as input and return the image with the selected object removed
2 | import os
3 | import gradio as gr
4 | import cv2
5 | import numpy as np
6 | from PIL import Image
7 | import torch
8 | import time
9 | import random
10 | from huggingface_hub import snapshot_download
11 | from diffusers.models import AutoencoderKLWan
12 | from transformer_minimax_remover import Transformer3DModel
13 | from diffusers.schedulers import UniPCMultistepScheduler
14 | from pipeline_minimax_remover import Minimax_Remover_Pipeline
15 | from sam2.build_sam import build_sam2
16 | from sam2.sam2_image_predictor import SAM2ImagePredictor
17 |
18 | # Create directories for models
19 | os.makedirs("./SAM2-Video-Predictor/checkpoints/", exist_ok=True)
20 | os.makedirs("./model/", exist_ok=True)
21 |
22 | # Download models from Hugging Face Hub
23 | def download_sam2():
24 | snapshot_download(repo_id="facebook/sam2-hiera-large", local_dir="./SAM2-Video-Predictor/checkpoints/")
25 | print("Download sam2 completed")
26 |
27 | def download_remover():
28 | snapshot_download(repo_id="zibojia/minimax-remover", local_dir="./model/")
29 | print("Download minimax remover completed")
30 |
31 | download_sam2()
32 | download_remover()
33 |
34 | COLOR_PALETTE = [
35 | (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
36 | (0, 255, 255), (255, 128, 0), (128, 0, 255), (0, 128, 255), (128, 255, 0)
37 | ]
38 |
39 | random_seed = 42
40 | W = 1024
41 | H = W
42 | device = "cuda" if torch.cuda.is_available() else "cpu"
43 |
44 | def get_pipe_and_predictor():
45 | vae = AutoencoderKLWan.from_pretrained("./model/vae", torch_dtype=torch.float16)
46 | transformer = Transformer3DModel.from_pretrained("./model/transformer", torch_dtype=torch.float16)
47 | scheduler = UniPCMultistepScheduler.from_pretrained("./model/scheduler")
48 |
49 | pipe = Minimax_Remover_Pipeline(transformer=transformer, vae=vae, scheduler=scheduler)
50 | pipe.to(device)
51 |
52 | sam2_checkpoint = "./SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
53 | config = "sam2_hiera_l.yaml"
54 |
55 | model = build_sam2(config, sam2_checkpoint, device=device)
56 | model.image_size = 1024
57 | image_predictor = SAM2ImagePredictor(sam_model=model)
58 |
59 | return pipe, image_predictor
60 |
61 | def get_image_info(image_pil, state):
62 | state["input_points"] = []
63 | state["scaled_points"] = []
64 | state["input_labels"] = []
65 |
66 | image_np = np.array(image_pil)
67 |
68 | if image_np.shape[0] > image_np.shape[1]:
69 | W_ = W
70 | H_ = int(W_ * image_np.shape[0] / image_np.shape[1])
71 | else:
72 | H_ = H
73 | W_ = int(H_ * image_np.shape[1] / image_np.shape[0])
74 |
75 | image_np = cv2.resize(image_np, (W_, H_))
76 | state["origin_image"] = image_np
77 | state["mask"] = None
78 | state["painted_image"] = None
79 | return Image.fromarray(image_np)
80 |
81 | def segment_frame(evt: gr.SelectData, label, state):
82 | if state["origin_image"] is None:
83 | return None
84 | x, y = evt.index
85 | new_point = [x, y]
86 | label_value = 1 if label == "Positive" else 0
87 |
88 | state["input_points"].append(new_point)
89 | state["input_labels"].append(label_value)
90 | height, width = state["origin_image"].shape[0:2]
91 | scaled_points = []
92 | for pt in state["input_points"]:
93 | sx = pt[0]
94 | sy = pt[1]
95 | scaled_points.append([sx, sy])
96 |
97 | state["scaled_points"] = scaled_points
98 |
99 | image_predictor.set_image(state["origin_image"])
100 | mask, _, _ = image_predictor.predict(
101 | point_coords=np.array(state["scaled_points"]),
102 | point_labels=np.array(state["input_labels"]),
103 | multimask_output=False,
104 | )
105 |
106 | mask = np.squeeze(mask)
107 | mask = cv2.resize(mask, (width, height))
108 | mask = mask[:,:,None]
109 |
110 | color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
111 | color = color[None, None, :]
112 | org_image = state["origin_image"].astype(np.float32) / 255.0
113 | painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color
114 | painted_image = np.uint8(np.clip(painted_image * 255, 0, 255))
115 | state["painted_image"] = painted_image
116 | state["mask"] = mask[:,:,0]
117 |
118 | for i in range(len(state["input_points"])):
119 | point = state["input_points"][i]
120 | if state["input_labels"][i] == 0:
121 | cv2.circle(painted_image, tuple(point), radius=5, color=(0, 0, 255), thickness=-1)
122 | else:
123 | cv2.circle(painted_image, tuple(point), radius=5, color=(255, 0, 0), thickness=-1)
124 |
125 | return Image.fromarray(painted_image)
126 |
127 | def clear_clicks(state):
128 | state["input_points"] = []
129 | state["input_labels"] = []
130 | state["scaled_points"] = []
131 | state["mask"] = None
132 | state["painted_image"] = None
133 | return Image.fromarray(state["origin_image"]) if state["origin_image"] is not None else None
134 |
135 | def preprocess_for_removal(image, mask):
136 | if image.shape[0] > image.shape[1]:
137 | img_resized = cv2.resize(image, (480, 832), interpolation=cv2.INTER_LINEAR)
138 | else:
139 | img_resized = cv2.resize(image, (832, 480), interpolation=cv2.INTER_LINEAR)
140 | img_resized = img_resized.astype(np.float32) / 127.5 - 1.0 # [-1, 1]
141 |
142 | if mask.shape[0] > mask.shape[1]:
143 | msk_resized = cv2.resize(mask, (480, 832), interpolation=cv2.INTER_NEAREST)
144 | else:
145 | msk_resized = cv2.resize(mask, (832, 480), interpolation=cv2.INTER_NEAREST)
146 | msk_resized = msk_resized.astype(np.float32)
147 | msk_resized = (msk_resized > 0.5).astype(np.float32)
148 |
149 | return torch.from_numpy(img_resized).half().to(device), torch.from_numpy(msk_resized).half().to(device)
150 |
151 | def inference_and_return_image(dilation_iterations, num_inference_steps, state=None):
152 | if state["origin_image"] is None or state["mask"] is None:
153 | return None
154 | image = state["origin_image"]
155 | mask = state["mask"]
156 |
157 | img_tensor, mask_tensor = preprocess_for_removal(image, mask)
158 | img_tensor = img_tensor.unsqueeze(0)
159 | mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(-1)
160 |
161 | if mask_tensor.shape[1] < mask_tensor.shape[2]:
162 | height = 480
163 | width = 832
164 | else:
165 | height = 832
166 | width = 480
167 |
168 | with torch.no_grad():
169 | out = pipe(
170 | images=img_tensor,
171 | masks=mask_tensor,
172 | num_frames=1,
173 | height=height,
174 | width=width,
175 | num_inference_steps=int(num_inference_steps),
176 | generator=torch.Generator(device=device).manual_seed(random_seed),
177 | iterations=int(dilation_iterations)
178 | ).frames[0]
179 |
180 | out = np.uint8(out * 255)
181 |
182 | return Image.fromarray(out[0])
183 |
184 | pipe, image_predictor = get_pipe_and_predictor()
185 |
186 | with gr.Blocks() as demo:
187 | state = gr.State({
188 | "origin_image": None,
189 | "mask": None,
190 | "painted_image": None,
191 | "input_points": [],
192 | "scaled_points": [],
193 | "input_labels": [],
194 | })
195 | gr.Markdown("Minimax-Remover: Image Object Removal")
196 |
197 | with gr.Row():
198 | with gr.Column():
199 | image_input = gr.Image(label="Upload Image", type="pil")
200 | get_info_btn = gr.Button("Load Image")
201 |
202 | point_prompt = gr.Radio(["Positive", "Negative"], label="Click Type", value="Positive")
203 | clear_btn = gr.Button("Clear All Clicks")
204 |
205 | dilation_slider = gr.Slider(minimum=1, maximum=20, value=6, step=1, label="Mask Dilation")
206 | inference_steps_slider = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Inference Steps")
207 |
208 | remove_btn = gr.Button("Remove Object")
209 |
210 | with gr.Column():
211 | image_output_segmentation = gr.Image(label="Segmentation", interactive=True)
212 | image_output_removed = gr.Image(label="Object Removed")
213 |
214 | get_info_btn.click(get_image_info, inputs=[image_input, state], outputs=image_output_segmentation)
215 | image_output_segmentation.select(fn=segment_frame, inputs=[point_prompt, state], outputs=image_output_segmentation)
216 | clear_btn.click(clear_clicks, inputs=state, outputs=image_output_segmentation)
217 | remove_btn.click(
218 | inference_and_return_image,
219 | inputs=[dilation_slider, inference_steps_slider, state],
220 | outputs=image_output_removed
221 | )
222 |
223 | demo.launch(server_name="0.0.0.0", server_port=8000)
--------------------------------------------------------------------------------
/gradio_demo/normal_videos/0.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/0.mp4
--------------------------------------------------------------------------------
/gradio_demo/normal_videos/1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/1.mp4
--------------------------------------------------------------------------------
/gradio_demo/normal_videos/2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/2.mp4
--------------------------------------------------------------------------------
/gradio_demo/normal_videos/3.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/3.mp4
--------------------------------------------------------------------------------
/gradio_demo/normal_videos/4.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/4.mp4
--------------------------------------------------------------------------------
/gradio_demo/normal_videos/5.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/5.mp4
--------------------------------------------------------------------------------
/gradio_demo/normal_videos/6.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/6.mp4
--------------------------------------------------------------------------------
/gradio_demo/pipeline_minimax_remover.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Optional, Union
2 |
3 | import torch
4 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
5 | from diffusers.models import AutoencoderKLWan
6 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
7 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
8 | from diffusers.utils.torch_utils import randn_tensor
9 | from diffusers.video_processor import VideoProcessor
10 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11 | from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
12 |
13 | import scipy
14 | import numpy as np
15 | import torch.nn.functional as F
16 | from transformer_minimax_remover import Transformer3DModel
17 | from einops import rearrange
18 |
19 | if is_torch_xla_available():
20 | import torch_xla.core.xla_model as xm
21 |
22 | XLA_AVAILABLE = True
23 | else:
24 | XLA_AVAILABLE = False
25 |
26 | class Minimax_Remover_Pipeline(DiffusionPipeline):
27 |
28 | model_cpu_offload_seq = "transformer->vae"
29 | _callback_tensor_inputs = ["latents"]
30 |
31 | def __init__(
32 | self,
33 | transformer: Transformer3DModel,
34 | vae: AutoencoderKLWan,
35 | scheduler: FlowMatchEulerDiscreteScheduler
36 | ):
37 | super().__init__()
38 |
39 | self.register_modules(
40 | vae=vae,
41 | transformer=transformer,
42 | scheduler=scheduler,
43 | )
44 |
45 | self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
46 | self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
47 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
48 |
49 | def prepare_latents(
50 | self,
51 | batch_size: int,
52 | num_channels_latents: int = 16,
53 | height: int = 720,
54 | width: int = 1280,
55 | num_latent_frames: int = 21,
56 | dtype: Optional[torch.dtype] = None,
57 | device: Optional[torch.device] = None,
58 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
59 | latents: Optional[torch.Tensor] = None,
60 | ) -> torch.Tensor:
61 | if latents is not None:
62 | return latents.to(device=device, dtype=dtype)
63 |
64 | shape = (
65 | batch_size,
66 | num_channels_latents,
67 | num_latent_frames,
68 | int(height) // self.vae_scale_factor_spatial,
69 | int(width) // self.vae_scale_factor_spatial,
70 | )
71 |
72 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
73 | return latents
74 |
75 | def expand_masks(self, masks, iterations):
76 | masks = masks.cpu().detach().numpy()
77 | # numpy array, masks [0,1], f h w c
78 | masks2 = []
79 | for i in range(len(masks)):
80 | mask = masks[i]
81 | mask = mask > 0
82 | mask = scipy.ndimage.binary_dilation(mask, iterations=iterations)
83 | masks2.append(mask)
84 | masks = np.array(masks2).astype(np.float32)
85 | masks = torch.from_numpy(masks)
86 | masks = masks.repeat(1,1,1,3)
87 | masks = rearrange(masks, "f h w c -> c f h w")
88 | masks = masks[None,...]
89 | return masks
90 |
91 | def resize(self, images, w, h):
92 | bsz,_,_,_,_ = images.shape
93 | images = rearrange(images, "b c f w h -> (b f) c w h")
94 | images = F.interpolate(images, (w,h), mode='bilinear')
95 | images = rearrange(images, "(b f) c w h -> b c f w h", b=bsz)
96 | return images
97 |
98 | @property
99 | def num_timesteps(self):
100 | return self._num_timesteps
101 |
102 | @property
103 | def current_timestep(self):
104 | return self._current_timestep
105 |
106 | @property
107 | def interrupt(self):
108 | return self._interrupt
109 |
110 | @torch.no_grad()
111 | def __call__(
112 | self,
113 | height: int = 720,
114 | width: int = 1280,
115 | num_frames: int = 81,
116 | num_inference_steps: int = 50,
117 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
118 | images: Optional[torch.Tensor] = None,
119 | masks: Optional[torch.Tensor] = None,
120 | latents: Optional[torch.Tensor] = None,
121 | output_type: Optional[str] = "np",
122 | iterations: int = 16
123 | ):
124 |
125 | self._current_timestep = None
126 | self._interrupt = False
127 | device = self._execution_device
128 | batch_size = 1
129 | transformer_dtype = torch.float16
130 |
131 | self.scheduler.set_timesteps(num_inference_steps, device=device)
132 | timesteps = self.scheduler.timesteps
133 |
134 | num_channels_latents = 16
135 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
136 |
137 | latents = self.prepare_latents(
138 | batch_size,
139 | num_channels_latents,
140 | height,
141 | width,
142 | num_latent_frames,
143 | torch.float16,
144 | device,
145 | generator,
146 | latents,
147 | )
148 |
149 | masks = self.expand_masks(masks, iterations)
150 | masks = self.resize(masks, height, width).to(device).half()
151 | masks[masks>0] = 1
152 | images = rearrange(images, "f h w c -> c f h w")
153 | images = self.resize(images[None,...], height, width).to(device).half()
154 |
155 | masked_images = images * (1-masks)
156 |
157 | latents_mean = (
158 | torch.tensor(self.vae.config.latents_mean)
159 | .view(1, self.vae.config.z_dim, 1, 1, 1)
160 | .to(self.vae.device, torch.float16)
161 | )
162 |
163 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
164 | self.vae.device, torch.float16
165 | )
166 |
167 | with torch.no_grad():
168 | masked_latents = self.vae.encode(masked_images.half()).latent_dist.mode()
169 | masks_latents = self.vae.encode(2*masks.half()-1.0).latent_dist.mode()
170 |
171 | masked_latents = (masked_latents - latents_mean) * latents_std
172 | masks_latents = (masks_latents - latents_mean) * latents_std
173 |
174 | self._num_timesteps = len(timesteps)
175 |
176 | with self.progress_bar(total=num_inference_steps) as progress_bar:
177 | for i, t in enumerate(timesteps):
178 |
179 | latent_model_input = latents.to(transformer_dtype)
180 |
181 | #print("latent_model_input, masked_latents, masks_latents", latent_model_input.shape, masked_latents.shape, masks_latents.shape)
182 | latent_model_input = torch.cat([latent_model_input, masked_latents, masks_latents], dim=1)
183 | timestep = t.expand(latents.shape[0])
184 |
185 | noise_pred = self.transformer(
186 | hidden_states=latent_model_input.half(),
187 | timestep=timestep
188 | )[0]
189 |
190 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
191 |
192 | progress_bar.update()
193 |
194 | latents = latents.half() / latents_std + latents_mean
195 | video = self.vae.decode(latents, return_dict=False)[0]
196 | video = self.video_processor.postprocess_video(video, output_type=output_type)
197 |
198 | return WanPipelineOutput(frames=video)
199 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/SAM2-Video-Predictor/checkpoints/README.md:
--------------------------------------------------------------------------------
1 | ---
2 | license: apache-2.0
3 | pipeline_tag: mask-generation
4 | library_name: sam2
5 | ---
6 |
7 | Repository for SAM 2: Segment Anything in Images and Videos, a foundation model towards solving promptable visual segmentation in images and videos from FAIR. See the [SAM 2 paper](https://arxiv.org/abs/2408.00714) for more information.
8 |
9 | The official code is publicly released in this [repo](https://github.com/facebookresearch/segment-anything-2/).
10 |
11 | ## Usage
12 |
13 | For image prediction:
14 |
15 | ```python
16 | import torch
17 | from sam2.sam2_image_predictor import SAM2ImagePredictor
18 |
19 | predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
20 |
21 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
22 | predictor.set_image(<your_image>)
23 | masks, _, _ = predictor.predict(<input_prompts>)
24 | ```
25 |
26 | For video prediction:
27 |
28 | ```python
29 | import torch
30 | from sam2.sam2_video_predictor import SAM2VideoPredictor
31 |
32 | predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large")
33 |
34 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
35 | state = predictor.init_state(<your_video>)
36 |
37 | # add new prompts and instantly get the output on the same frame
38 | frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>)
39 |
40 | # propagate the prompts to get masklets throughout the video
41 | for frame_idx, object_ids, masks in predictor.propagate_in_video(state):
42 | ...
43 | ```
44 |
45 | Refer to the [demo notebooks](https://github.com/facebookresearch/segment-anything-2/tree/main/notebooks) for details.
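In this repository's Gradio demo, the image predictor is instead built from the checkpoint downloaded into this folder; a rough sketch following `gradio_demo/image_demo.py`:

```python
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

sam2_checkpoint = "./SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
model = build_sam2("sam2_hiera_l.yaml", sam2_checkpoint, device="cuda")
model.image_size = 1024
image_predictor = SAM2ImagePredictor(sam_model=model)
```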
46 |
47 | ### Citation
48 |
49 | To cite the paper, model, or software, please use the below:
50 | ```
51 | @article{ravi2024sam2,
52 | title={SAM 2: Segment Anything in Images and Videos},
53 | author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph},
54 | journal={arXiv preprint arXiv:2408.00714},
55 | url={https://arxiv.org/abs/2408.00714},
56 | year={2024}
57 | }
58 | ```
--------------------------------------------------------------------------------
/gradio_demo/sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 144
12 | num_heads: 2
13 | stages: [2, 6, 36, 4]
14 | global_att_blocks: [23, 33, 43]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | window_spec: [8, 4, 16, 8]
17 | neck:
18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19 | position_encoding:
20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21 | num_pos_feats: 256
22 | normalize: true
23 | scale: null
24 | temperature: 10000
25 | d_model: 256
26 | backbone_channel_list: [1152, 576, 288, 144]
27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 | fpn_interp_model: nearest
29 |
30 | memory_attention:
31 | _target_: sam2.modeling.memory_attention.MemoryAttention
32 | d_model: 256
33 | pos_enc_at_input: true
34 | layer:
35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36 | activation: relu
37 | dim_feedforward: 2048
38 | dropout: 0.1
39 | pos_enc_at_attn: false
40 | self_attention:
41 | _target_: sam2.modeling.sam.transformer.RoPEAttention
42 | rope_theta: 10000.0
43 | feat_sizes: [32, 32]
44 | embedding_dim: 256
45 | num_heads: 1
46 | downsample_rate: 1
47 | dropout: 0.1
48 | d_model: 256
49 | pos_enc_at_cross_attn_keys: true
50 | pos_enc_at_cross_attn_queries: false
51 | cross_attention:
52 | _target_: sam2.modeling.sam.transformer.RoPEAttention
53 | rope_theta: 10000.0
54 | feat_sizes: [32, 32]
55 | rope_k_repeat: True
56 | embedding_dim: 256
57 | num_heads: 1
58 | downsample_rate: 1
59 | dropout: 0.1
60 | kv_in_dim: 64
61 | num_layers: 4
62 |
63 | memory_encoder:
64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
65 | out_dim: 64
66 | position_encoding:
67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68 | num_pos_feats: 64
69 | normalize: true
70 | scale: null
71 | temperature: 10000
72 | mask_downsampler:
73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
74 | kernel_size: 3
75 | stride: 2
76 | padding: 1
77 | fuser:
78 | _target_: sam2.modeling.memory_encoder.Fuser
79 | layer:
80 | _target_: sam2.modeling.memory_encoder.CXBlock
81 | dim: 256
82 | kernel_size: 7
83 | padding: 3
84 | layer_scale_init_value: 1e-6
85 | use_dwconv: True # depth-wise convs
86 | num_layers: 2
87 |
88 | num_maskmem: 7
89 | image_size: 1024
90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | compile_image_encoder: False
--------------------------------------------------------------------------------
/gradio_demo/sam2/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from hydra import initialize
8 |
9 | from .build_sam import load_model
10 |
11 | initialize("configs", version_base="1.2")
12 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/build_sam.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 |
9 | import torch
10 | from hydra import compose
11 | from hydra.utils import instantiate
12 | from omegaconf import OmegaConf
13 |
14 | from .utils.misc import VARIANTS, variant_to_config_mapping
15 |
16 |
17 | def load_model(
18 | variant: str,
19 | ckpt_path=None,
20 | device="cpu",
21 | mode="eval",
22 | hydra_overrides_extra=[],
23 | apply_postprocessing=True,
24 | ) -> torch.nn.Module:
25 | assert variant in VARIANTS, f"only accepted variants are {VARIANTS}"
26 |
27 | return build_sam2(
28 | config_file=variant_to_config_mapping[variant],
29 | ckpt_path=ckpt_path,
30 | device=device,
31 | mode=mode,
32 | hydra_overrides_extra=hydra_overrides_extra,
33 | apply_postprocessing=apply_postprocessing,
34 | )
35 |
36 |
37 | def build_sam2(
38 | config_file,
39 | ckpt_path=None,
40 | device="cpu",
41 | mode="eval",
42 | hydra_overrides_extra=[],
43 | apply_postprocessing=True,
44 | ):
45 |
46 | if apply_postprocessing:
47 | hydra_overrides_extra = hydra_overrides_extra.copy()
48 | hydra_overrides_extra += [
49 | # dynamically fall back to multi-mask if the single mask is not stable
50 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
51 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
52 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
53 | ]
54 | # Read config and init model
55 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra)
56 | OmegaConf.resolve(cfg)
57 | model = instantiate(cfg.model, _recursive_=True)
58 | _load_checkpoint(model, ckpt_path)
59 | model = model.to(device)
60 | if mode == "eval":
61 | model.eval()
62 | return model
63 |
64 |
65 | def build_sam2_video_predictor(
66 | config_file,
67 | ckpt_path=None,
68 | device="cpu",
69 | mode="eval",
70 | hydra_overrides_extra=[],
71 | apply_postprocessing=True,
72 | ):
73 | hydra_overrides = [
74 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor",
75 | ]
76 | if apply_postprocessing:
77 | hydra_overrides_extra = hydra_overrides_extra.copy()
78 | hydra_overrides_extra += [
79 | # dynamically fall back to multi-mask if the single mask is not stable
80 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true",
81 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05",
82 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98",
83 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking
84 | "++model.binarize_mask_from_pts_for_mem_enc=true",
85 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution)
86 | # "++model.fill_hole_area=8",
87 | ]
88 | hydra_overrides.extend(hydra_overrides_extra)
89 |
90 | # Read config and init model
91 | cfg = compose(config_name=config_file, overrides=hydra_overrides)
92 | OmegaConf.resolve(cfg)
93 | model = instantiate(cfg.model, _recursive_=True)
94 | _load_checkpoint(model, ckpt_path)
95 | model = model.to(device)
96 | if mode == "eval":
97 | model.eval()
98 | return model
99 |
100 |
101 | def _load_checkpoint(model, ckpt_path):
102 | if ckpt_path is not None:
103 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"]
104 | missing_keys, unexpected_keys = model.load_state_dict(sd)
105 | if missing_keys:
106 | logging.error(missing_keys)
107 | raise RuntimeError()
108 | if unexpected_keys:
109 | logging.error(unexpected_keys)
110 | raise RuntimeError()
111 | logging.info("Loaded checkpoint successfully")
112 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/configs/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/configs/sam2_hiera_b+.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 112
12 | num_heads: 2
13 | neck:
14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
15 | position_encoding:
16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
17 | num_pos_feats: 256
18 | normalize: true
19 | scale: null
20 | temperature: 10000
21 | d_model: 256
22 | backbone_channel_list: [896, 448, 224, 112]
23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
24 | fpn_interp_model: nearest
25 |
26 | memory_attention:
27 | _target_: sam2.modeling.memory_attention.MemoryAttention
28 | d_model: 256
29 | pos_enc_at_input: true
30 | layer:
31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
32 | activation: relu
33 | dim_feedforward: 2048
34 | dropout: 0.1
35 | pos_enc_at_attn: false
36 | self_attention:
37 | _target_: sam2.modeling.sam.transformer.RoPEAttention
38 | rope_theta: 10000.0
39 | feat_sizes: [32, 32]
40 | embedding_dim: 256
41 | num_heads: 1
42 | downsample_rate: 1
43 | dropout: 0.1
44 | d_model: 256
45 | pos_enc_at_cross_attn_keys: true
46 | pos_enc_at_cross_attn_queries: false
47 | cross_attention:
48 | _target_: sam2.modeling.sam.transformer.RoPEAttention
49 | rope_theta: 10000.0
50 | feat_sizes: [32, 32]
51 | rope_k_repeat: True
52 | embedding_dim: 256
53 | num_heads: 1
54 | downsample_rate: 1
55 | dropout: 0.1
56 | kv_in_dim: 64
57 | num_layers: 4
58 |
59 | memory_encoder:
60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
61 | out_dim: 64
62 | position_encoding:
63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
64 | num_pos_feats: 64
65 | normalize: true
66 | scale: null
67 | temperature: 10000
68 | mask_downsampler:
69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
70 | kernel_size: 3
71 | stride: 2
72 | padding: 1
73 | fuser:
74 | _target_: sam2.modeling.memory_encoder.Fuser
75 | layer:
76 | _target_: sam2.modeling.memory_encoder.CXBlock
77 | dim: 256
78 | kernel_size: 7
79 | padding: 3
80 | layer_scale_init_value: 1e-6
81 | use_dwconv: True # depth-wise convs
82 | num_layers: 2
83 |
84 | num_maskmem: 7
85 | image_size: 1024
86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
87 | sigmoid_scale_for_mem_enc: 20.0
88 | sigmoid_bias_for_mem_enc: -10.0
89 | use_mask_input_as_output_without_sam: true
90 | # Memory
91 | directly_add_no_mem_embed: true
92 | # use high-resolution feature map in the SAM mask decoder
93 | use_high_res_features_in_sam: true
94 | # output 3 masks on the first click on initial conditioning frames
95 | multimask_output_in_sam: true
96 | # SAM heads
97 | iou_prediction_use_sigmoid: True
98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
99 | use_obj_ptrs_in_encoder: true
100 | add_tpos_enc_to_obj_ptrs: false
101 | only_obj_ptrs_in_the_past_for_eval: true
102 | # object occlusion prediction
103 | pred_obj_scores: true
104 | pred_obj_scores_mlp: true
105 | fixed_no_obj_ptr: true
106 | # multimask tracking settings
107 | multimask_output_for_tracking: true
108 | use_multimask_token_for_obj_ptr: true
109 | multimask_min_pt_num: 0
110 | multimask_max_pt_num: 1
111 | use_mlp_for_obj_ptr_proj: true
112 | # Compilation flag
113 | compile_image_encoder: False
114 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/configs/sam2_hiera_l.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 144
12 | num_heads: 2
13 | stages: [2, 6, 36, 4]
14 | global_att_blocks: [23, 33, 43]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | window_spec: [8, 4, 16, 8]
17 | neck:
18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
19 | position_encoding:
20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
21 | num_pos_feats: 256
22 | normalize: true
23 | scale: null
24 | temperature: 10000
25 | d_model: 256
26 | backbone_channel_list: [1152, 576, 288, 144]
27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
28 | fpn_interp_model: nearest
29 |
30 | memory_attention:
31 | _target_: sam2.modeling.memory_attention.MemoryAttention
32 | d_model: 256
33 | pos_enc_at_input: true
34 | layer:
35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
36 | activation: relu
37 | dim_feedforward: 2048
38 | dropout: 0.1
39 | pos_enc_at_attn: false
40 | self_attention:
41 | _target_: sam2.modeling.sam.transformer.RoPEAttention
42 | rope_theta: 10000.0
43 | feat_sizes: [32, 32]
44 | embedding_dim: 256
45 | num_heads: 1
46 | downsample_rate: 1
47 | dropout: 0.1
48 | d_model: 256
49 | pos_enc_at_cross_attn_keys: true
50 | pos_enc_at_cross_attn_queries: false
51 | cross_attention:
52 | _target_: sam2.modeling.sam.transformer.RoPEAttention
53 | rope_theta: 10000.0
54 | feat_sizes: [32, 32]
55 | rope_k_repeat: True
56 | embedding_dim: 256
57 | num_heads: 1
58 | downsample_rate: 1
59 | dropout: 0.1
60 | kv_in_dim: 64
61 | num_layers: 4
62 |
63 | memory_encoder:
64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
65 | out_dim: 64
66 | position_encoding:
67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
68 | num_pos_feats: 64
69 | normalize: true
70 | scale: null
71 | temperature: 10000
72 | mask_downsampler:
73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
74 | kernel_size: 3
75 | stride: 2
76 | padding: 1
77 | fuser:
78 | _target_: sam2.modeling.memory_encoder.Fuser
79 | layer:
80 | _target_: sam2.modeling.memory_encoder.CXBlock
81 | dim: 256
82 | kernel_size: 7
83 | padding: 3
84 | layer_scale_init_value: 1e-6
85 | use_dwconv: True # depth-wise convs
86 | num_layers: 2
87 |
88 | num_maskmem: 7
89 | image_size: 1024
90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | compile_image_encoder: False
118 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/configs/sam2_hiera_s.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 11, 2]
14 | global_att_blocks: [7, 10, 13]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | sigmoid_scale_for_mem_enc: 20.0
91 | sigmoid_bias_for_mem_enc: -10.0
92 | use_mask_input_as_output_without_sam: true
93 | # Memory
94 | directly_add_no_mem_embed: true
95 | # use high-resolution feature map in the SAM mask decoder
96 | use_high_res_features_in_sam: true
97 | # output 3 masks on the first click on initial conditioning frames
98 | multimask_output_in_sam: true
99 | # SAM heads
100 | iou_prediction_use_sigmoid: True
101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
102 | use_obj_ptrs_in_encoder: true
103 | add_tpos_enc_to_obj_ptrs: false
104 | only_obj_ptrs_in_the_past_for_eval: true
105 | # object occlusion prediction
106 | pred_obj_scores: true
107 | pred_obj_scores_mlp: true
108 | fixed_no_obj_ptr: true
109 | # multimask tracking settings
110 | multimask_output_for_tracking: true
111 | use_multimask_token_for_obj_ptr: true
112 | multimask_min_pt_num: 0
113 | multimask_max_pt_num: 1
114 | use_mlp_for_obj_ptr_proj: true
115 | # Compilation flag
116 | compile_image_encoder: False
117 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/configs/sam2_hiera_t.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 |
3 | # Model
4 | model:
5 | _target_: sam2.modeling.sam2_base.SAM2Base
6 | image_encoder:
7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder
8 | scalp: 1
9 | trunk:
10 | _target_: sam2.modeling.backbones.hieradet.Hiera
11 | embed_dim: 96
12 | num_heads: 1
13 | stages: [1, 2, 7, 2]
14 | global_att_blocks: [5, 7, 9]
15 | window_pos_embed_bkg_spatial_size: [7, 7]
16 | neck:
17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck
18 | position_encoding:
19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
20 | num_pos_feats: 256
21 | normalize: true
22 | scale: null
23 | temperature: 10000
24 | d_model: 256
25 | backbone_channel_list: [768, 384, 192, 96]
26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features
27 | fpn_interp_model: nearest
28 |
29 | memory_attention:
30 | _target_: sam2.modeling.memory_attention.MemoryAttention
31 | d_model: 256
32 | pos_enc_at_input: true
33 | layer:
34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer
35 | activation: relu
36 | dim_feedforward: 2048
37 | dropout: 0.1
38 | pos_enc_at_attn: false
39 | self_attention:
40 | _target_: sam2.modeling.sam.transformer.RoPEAttention
41 | rope_theta: 10000.0
42 | feat_sizes: [32, 32]
43 | embedding_dim: 256
44 | num_heads: 1
45 | downsample_rate: 1
46 | dropout: 0.1
47 | d_model: 256
48 | pos_enc_at_cross_attn_keys: true
49 | pos_enc_at_cross_attn_queries: false
50 | cross_attention:
51 | _target_: sam2.modeling.sam.transformer.RoPEAttention
52 | rope_theta: 10000.0
53 | feat_sizes: [32, 32]
54 | rope_k_repeat: True
55 | embedding_dim: 256
56 | num_heads: 1
57 | downsample_rate: 1
58 | dropout: 0.1
59 | kv_in_dim: 64
60 | num_layers: 4
61 |
62 | memory_encoder:
63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder
64 | out_dim: 64
65 | position_encoding:
66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine
67 | num_pos_feats: 64
68 | normalize: true
69 | scale: null
70 | temperature: 10000
71 | mask_downsampler:
72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler
73 | kernel_size: 3
74 | stride: 2
75 | padding: 1
76 | fuser:
77 | _target_: sam2.modeling.memory_encoder.Fuser
78 | layer:
79 | _target_: sam2.modeling.memory_encoder.CXBlock
80 | dim: 256
81 | kernel_size: 7
82 | padding: 3
83 | layer_scale_init_value: 1e-6
84 | use_dwconv: True # depth-wise convs
85 | num_layers: 2
86 |
87 | num_maskmem: 7
88 | image_size: 1024
89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask
90 | # SAM decoder
91 | sigmoid_scale_for_mem_enc: 20.0
92 | sigmoid_bias_for_mem_enc: -10.0
93 | use_mask_input_as_output_without_sam: true
94 | # Memory
95 | directly_add_no_mem_embed: true
96 | # use high-resolution feature map in the SAM mask decoder
97 | use_high_res_features_in_sam: true
98 | # output 3 masks on the first click on initial conditioning frames
99 | multimask_output_in_sam: true
100 | # SAM heads
101 | iou_prediction_use_sigmoid: True
102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder
103 | use_obj_ptrs_in_encoder: true
104 | add_tpos_enc_to_obj_ptrs: false
105 | only_obj_ptrs_in_the_past_for_eval: true
106 | # object occlusion prediction
107 | pred_obj_scores: true
108 | pred_obj_scores_mlp: true
109 | fixed_no_obj_ptr: true
110 | # multimask tracking settings
111 | multimask_output_for_tracking: true
112 | use_multimask_token_for_obj_ptr: true
113 | multimask_min_pt_num: 0
114 | multimask_max_pt_num: 1
115 | use_mlp_for_obj_ptr_proj: true
116 | # Compilation flag
117 | # HieraT does not currently support compilation, should always be set to False
118 | compile_image_encoder: False
119 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/modeling/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/modeling/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/modeling/backbones/hieradet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from functools import partial
8 | from typing import List, Tuple, Union
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from sam2.modeling.backbones.utils import (
15 | PatchEmbed,
16 | window_partition,
17 | window_unpartition,
18 | )
19 | from sam2.modeling.sam2_utils import MLP, DropPath
20 |
21 |
22 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
23 | if pool is None:
24 | return x
25 | # (B, H, W, C) -> (B, C, H, W)
26 | x = x.permute(0, 3, 1, 2)
27 | x = pool(x)
28 | # (B, C, H', W') -> (B, H', W', C)
29 | x = x.permute(0, 2, 3, 1)
30 | if norm:
31 | x = norm(x)
32 |
33 | return x
34 |
35 |
36 | class MultiScaleAttention(nn.Module):
37 | def __init__(
38 | self,
39 | dim: int,
40 | dim_out: int,
41 | num_heads: int,
42 | q_pool: nn.Module = None,
43 | ):
44 | super().__init__()
45 |
46 | self.dim = dim
47 | self.dim_out = dim_out
48 |
49 | self.num_heads = num_heads
50 | head_dim = dim_out // num_heads
51 | self.scale = head_dim**-0.5
52 |
53 | self.q_pool = q_pool
54 | self.qkv = nn.Linear(dim, dim_out * 3)
55 | self.proj = nn.Linear(dim_out, dim_out)
56 |
57 | def forward(self, x: torch.Tensor) -> torch.Tensor:
58 | B, H, W, _ = x.shape
59 | # qkv with shape (B, H * W, 3, nHead, C)
60 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
61 | # q, k, v with shape (B, H * W, nheads, C)
62 | q, k, v = torch.unbind(qkv, 2)
63 |
64 | # Q pooling (for downsample at stage changes)
65 | if self.q_pool:
66 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
67 | H, W = q.shape[1:3] # downsampled shape
68 | q = q.reshape(B, H * W, self.num_heads, -1)
69 |
70 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
71 | x = F.scaled_dot_product_attention(
72 | q.transpose(1, 2),
73 | k.transpose(1, 2),
74 | v.transpose(1, 2),
75 | )
76 | # Transpose back
77 | x = x.transpose(1, 2)
78 | x = x.reshape(B, H, W, -1)
79 |
80 | x = self.proj(x)
81 |
82 | return x
83 |
84 |
85 | class MultiScaleBlock(nn.Module):
86 | def __init__(
87 | self,
88 | dim: int,
89 | dim_out: int,
90 | num_heads: int,
91 | mlp_ratio: float = 4.0,
92 | drop_path: float = 0.0,
93 | norm_layer: Union[nn.Module, str] = "LayerNorm",
94 | q_stride: Tuple[int, int] = None,
95 | act_layer: nn.Module = nn.GELU,
96 | window_size: int = 0,
97 | ):
98 | super().__init__()
99 |
100 | if isinstance(norm_layer, str):
101 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
102 |
103 | self.dim = dim
104 | self.dim_out = dim_out
105 | self.norm1 = norm_layer(dim)
106 |
107 | self.window_size = window_size
108 |
109 | self.pool, self.q_stride = None, q_stride
110 | if self.q_stride:
111 | self.pool = nn.MaxPool2d(
112 | kernel_size=q_stride, stride=q_stride, ceil_mode=False
113 | )
114 |
115 | self.attn = MultiScaleAttention(
116 | dim,
117 | dim_out,
118 | num_heads=num_heads,
119 | q_pool=self.pool,
120 | )
121 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
122 |
123 | self.norm2 = norm_layer(dim_out)
124 | self.mlp = MLP(
125 | dim_out,
126 | int(dim_out * mlp_ratio),
127 | dim_out,
128 | num_layers=2,
129 | activation=act_layer,
130 | )
131 |
132 | if dim != dim_out:
133 | self.proj = nn.Linear(dim, dim_out)
134 |
135 | def forward(self, x: torch.Tensor) -> torch.Tensor:
136 | shortcut = x # B, H, W, C
137 | x = self.norm1(x)
138 |
139 | # Skip connection
140 | if self.dim != self.dim_out:
141 | shortcut = do_pool(self.proj(x), self.pool)
142 |
143 | # Window partition
144 | window_size = self.window_size
145 | if window_size > 0:
146 | H, W = x.shape[1], x.shape[2]
147 | x, pad_hw = window_partition(x, window_size)
148 |
149 | # Window Attention + Q Pooling (if stage change)
150 | x = self.attn(x)
151 | if self.q_stride:
152 | # Shapes have changed due to Q pooling
153 | window_size = self.window_size // self.q_stride[0]
154 | H, W = shortcut.shape[1:3]
155 |
156 | pad_h = (window_size - H % window_size) % window_size
157 | pad_w = (window_size - W % window_size) % window_size
158 | pad_hw = (H + pad_h, W + pad_w)
159 |
160 | # Reverse window partition
161 | if self.window_size > 0:
162 | x = window_unpartition(x, window_size, pad_hw, (H, W))
163 |
164 | x = shortcut + self.drop_path(x)
165 | # MLP
166 | x = x + self.drop_path(self.mlp(self.norm2(x)))
167 | return x
168 |
169 |
170 | class Hiera(nn.Module):
171 | """
172 | Reference: https://arxiv.org/abs/2306.00989
173 | """
174 |
175 | def __init__(
176 | self,
177 | embed_dim: int = 96, # initial embed dim
178 | num_heads: int = 1, # initial number of heads
179 | drop_path_rate: float = 0.0, # stochastic depth
180 | q_pool: int = 3, # number of q_pool stages
181 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages
182 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage
183 | dim_mul: float = 2.0, # dim_mul factor at stage shift
184 | head_mul: float = 2.0, # head_mul factor at stage shift
185 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
186 | # window size per stage, when not using global att.
187 | window_spec: Tuple[int, ...] = (
188 | 8,
189 | 4,
190 | 14,
191 | 7,
192 | ),
193 | # global attn in these blocks
194 | global_att_blocks: Tuple[int, ...] = (
195 | 12,
196 | 16,
197 | 20,
198 | ),
199 | return_interm_layers=True, # return feats from every stage
200 | ):
201 | super().__init__()
202 |
203 | assert len(stages) == len(window_spec)
204 | self.window_spec = window_spec
205 |
206 | depth = sum(stages)
207 | self.q_stride = q_stride
208 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
209 | assert 0 <= q_pool <= len(self.stage_ends[:-1])
210 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
211 | self.return_interm_layers = return_interm_layers
212 |
213 | self.patch_embed = PatchEmbed(
214 | embed_dim=embed_dim,
215 | )
216 | # Which blocks have global att?
217 | self.global_att_blocks = global_att_blocks
218 |
219 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
220 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
221 | self.pos_embed = nn.Parameter(
222 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
223 | )
224 | self.pos_embed_window = nn.Parameter(
225 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
226 | )
227 |
228 | dpr = [
229 | x.item() for x in torch.linspace(0, drop_path_rate, depth)
230 | ] # stochastic depth decay rule
231 |
232 | cur_stage = 1
233 | self.blocks = nn.ModuleList()
234 |
235 | for i in range(depth):
236 | dim_out = embed_dim
237 |             # window size is chosen before the stage change below, so the first
238 |             # block of a new stage still uses the previous stage's window size
239 |             # (it is halved inside the block when Q-pooling is applied)
240 | window_size = self.window_spec[cur_stage - 1]
241 |
242 | if self.global_att_blocks is not None:
243 | window_size = 0 if i in self.global_att_blocks else window_size
244 |
245 | if i - 1 in self.stage_ends:
246 | dim_out = int(embed_dim * dim_mul)
247 | num_heads = int(num_heads * head_mul)
248 | cur_stage += 1
249 |
250 | block = MultiScaleBlock(
251 | dim=embed_dim,
252 | dim_out=dim_out,
253 | num_heads=num_heads,
254 | drop_path=dpr[i],
255 | q_stride=self.q_stride if i in self.q_pool_blocks else None,
256 | window_size=window_size,
257 | )
258 |
259 | embed_dim = dim_out
260 | self.blocks.append(block)
261 |
262 | self.channel_list = (
263 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
264 | if return_interm_layers
265 | else [self.blocks[-1].dim_out]
266 | )
267 |
268 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
269 | h, w = hw
270 | window_embed = self.pos_embed_window
271 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
272 | pos_embed = pos_embed + window_embed.tile(
273 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
274 | )
275 | pos_embed = pos_embed.permute(0, 2, 3, 1)
276 | return pos_embed
277 |
278 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
279 | x = self.patch_embed(x)
280 | # x: (B, H, W, C)
281 |
282 | # Add pos embed
283 | x = x + self._get_pos_embed(x.shape[1:3])
284 |
285 | outputs = []
286 | for i, blk in enumerate(self.blocks):
287 | x = blk(x)
288 | if (i == self.stage_ends[-1]) or (
289 | i in self.stage_ends and self.return_interm_layers
290 | ):
291 | feats = x.permute(0, 3, 1, 2)
292 | outputs.append(feats)
293 |
294 | return outputs
295 |
--------------------------------------------------------------------------------
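Note on the Hiera trunk above: with `return_interm_layers=True` it returns one feature map per stage, with spatial resolution halved and channel width doubled at each stage change. The snippet below is a minimal sketch, assuming the file is importable as `sam2.modeling.backbones.hieradet` and using the default arguments with an illustrative 224x224 input; the printed shapes follow from those assumptions, not from the repo's config.

```python
import torch
from sam2.modeling.backbones.hieradet import Hiera

trunk = Hiera()                              # default config: 4 stages, 3 Q-pooling steps
feats = trunk(torch.randn(1, 3, 224, 224))   # patch embed stride 4 -> 56x56 tokens
for f in feats:
    print(tuple(f.shape))
# Expected under these assumptions:
# (1, 96, 56, 56)    stride 4
# (1, 192, 28, 28)   stride 8
# (1, 384, 14, 14)   stride 16
# (1, 768, 7, 7)     stride 32
```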
/gradio_demo/sam2/modeling/backbones/image_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 |
14 | class ImageEncoder(nn.Module):
15 | def __init__(
16 | self,
17 | trunk: nn.Module,
18 | neck: nn.Module,
19 | scalp: int = 0,
20 | ):
21 | super().__init__()
22 | self.trunk = trunk
23 | self.neck = neck
24 | self.scalp = scalp
25 | assert (
26 | self.trunk.channel_list == self.neck.backbone_channel_list
27 | ), f"Channel dims of trunk and neck do not match. Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}"
28 |
29 | def forward(self, sample: torch.Tensor):
30 | # Forward through backbone
31 | features, pos = self.neck(self.trunk(sample))
32 | if self.scalp > 0:
33 | # Discard the lowest resolution features
34 | features, pos = features[: -self.scalp], pos[: -self.scalp]
35 |
36 | src = features[-1]
37 | output = {
38 | "vision_features": src,
39 | "vision_pos_enc": pos,
40 | "backbone_fpn": features,
41 | }
42 | return output
43 |
44 |
45 | class FpnNeck(nn.Module):
46 | """
47 | A modified variant of Feature Pyramid Network (FPN) neck
48 | (we remove output conv and also do bicubic interpolation similar to ViT
49 | pos embed interpolation)
50 | """
51 |
52 | def __init__(
53 | self,
54 | position_encoding: nn.Module,
55 | d_model: int,
56 | backbone_channel_list: List[int],
57 | kernel_size: int = 1,
58 | stride: int = 1,
59 | padding: int = 0,
60 | fpn_interp_model: str = "bilinear",
61 | fuse_type: str = "sum",
62 | fpn_top_down_levels: Optional[List[int]] = None,
63 | ):
64 | """Initialize the neck
65 | :param trunk: the backbone
66 | :param position_encoding: the positional encoding to use
67 | :param d_model: the dimension of the model
68 | :param neck_norm: the normalization to use
69 | """
70 | super().__init__()
71 | self.position_encoding = position_encoding
72 | self.convs = nn.ModuleList()
73 | self.backbone_channel_list = backbone_channel_list
74 | for dim in backbone_channel_list:
75 | current = nn.Sequential()
76 | current.add_module(
77 | "conv",
78 | nn.Conv2d(
79 | in_channels=dim,
80 | out_channels=d_model,
81 | kernel_size=kernel_size,
82 | stride=stride,
83 | padding=padding,
84 | ),
85 | )
86 |
87 | self.convs.append(current)
88 | self.fpn_interp_model = fpn_interp_model
89 | assert fuse_type in ["sum", "avg"]
90 | self.fuse_type = fuse_type
91 |
92 | # levels to have top-down features in its outputs
93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3
94 | # have top-down propagation, while outputs of level 0 and level 1 have only
95 | # lateral features from the same backbone level.
96 | if fpn_top_down_levels is None:
97 | # default is to have top-down features on all levels
98 | fpn_top_down_levels = range(len(self.convs))
99 | self.fpn_top_down_levels = list(fpn_top_down_levels)
100 |
101 | def forward(self, xs: List[torch.Tensor]):
102 |
103 | out = [None] * len(self.convs)
104 | pos = [None] * len(self.convs)
105 | assert len(xs) == len(self.convs)
106 | # fpn forward pass
107 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py
108 | prev_features = None
109 | # forward in top-down order (from low to high resolution)
110 | n = len(self.convs) - 1
111 | for i in range(n, -1, -1):
112 | x = xs[i]
113 | lateral_features = self.convs[n - i](x)
114 | if i in self.fpn_top_down_levels and prev_features is not None:
115 | top_down_features = F.interpolate(
116 | prev_features.to(dtype=torch.float32),
117 | scale_factor=2.0,
118 | mode=self.fpn_interp_model,
119 | align_corners=(
120 | None if self.fpn_interp_model == "nearest" else False
121 | ),
122 | antialias=False,
123 | )
124 | prev_features = lateral_features + top_down_features
125 | if self.fuse_type == "avg":
126 | prev_features /= 2
127 | else:
128 | prev_features = lateral_features
129 | x_out = prev_features
130 | out[i] = x_out
131 | pos[i] = self.position_encoding(x_out).to(x_out.dtype)
132 |
133 | return out, pos
134 |
--------------------------------------------------------------------------------
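FpnNeck above fuses the trunk's multi-scale features top-down: each level gets a 1x1 lateral projection plus a 2x-upsampled copy of the coarser level above it. The following is a standalone toy sketch of that loop; the channel sizes, input shapes, and bilinear upsampling here are illustrative assumptions, not the repo's configuration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model = 256
backbone_channel_list = [768, 384, 192, 96]        # lowest resolution (most channels) first
convs = nn.ModuleList(nn.Conv2d(c, d_model, kernel_size=1) for c in backbone_channel_list)

# xs are ordered highest resolution first, matching the reversed indexing in forward()
xs = [torch.randn(1, 96, 64, 64), torch.randn(1, 192, 32, 32),
      torch.randn(1, 384, 16, 16), torch.randn(1, 768, 8, 8)]

out, prev, n = [None] * len(xs), None, len(xs) - 1
for i in range(n, -1, -1):                         # top-down: coarsest level first
    lateral = convs[n - i](xs[i])                  # 1x1 lateral projection
    if prev is not None:                           # add upsampled top-down features
        lateral = lateral + F.interpolate(prev, scale_factor=2.0,
                                          mode="bilinear", align_corners=False)
    prev = lateral
    out[i] = lateral
print([tuple(o.shape) for o in out])               # every level ends up with d_model channels
```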
/gradio_demo/sam2/modeling/backbones/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | """Some utilities for backbones, in particular for windowing"""
8 |
9 | from typing import Tuple
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | def window_partition(x, window_size):
17 | """
18 | Partition into non-overlapping windows with padding if needed.
19 | Args:
20 | x (tensor): input tokens with [B, H, W, C].
21 | window_size (int): window size.
22 | Returns:
23 | windows: windows after partition with [B * num_windows, window_size, window_size, C].
24 | (Hp, Wp): padded height and width before partition
25 | """
26 | B, H, W, C = x.shape
27 |
28 | pad_h = (window_size - H % window_size) % window_size
29 | pad_w = (window_size - W % window_size) % window_size
30 | if pad_h > 0 or pad_w > 0:
31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
32 | Hp, Wp = H + pad_h, W + pad_w
33 |
34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
35 | windows = (
36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
37 | )
38 | return windows, (Hp, Wp)
39 |
40 |
41 | def window_unpartition(windows, window_size, pad_hw, hw):
42 | """
43 |     Window unpartition into original sequences, removing padding.
44 |     Args:
45 |         windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
46 | window_size (int): window size.
47 | pad_hw (Tuple): padded height and width (Hp, Wp).
48 | hw (Tuple): original height and width (H, W) before padding.
49 | Returns:
50 | x: unpartitioned sequences with [B, H, W, C].
51 | """
52 | Hp, Wp = pad_hw
53 | H, W = hw
54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size)
55 | x = windows.view(
56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1
57 | )
58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)
59 |
60 | if Hp > H or Wp > W:
61 | x = x[:, :H, :W, :].contiguous()
62 | return x
63 |
64 |
65 | class PatchEmbed(nn.Module):
66 | """
67 | Image to Patch Embedding.
68 | """
69 |
70 | def __init__(
71 | self,
72 | kernel_size: Tuple[int, ...] = (7, 7),
73 | stride: Tuple[int, ...] = (4, 4),
74 | padding: Tuple[int, ...] = (3, 3),
75 | in_chans: int = 3,
76 | embed_dim: int = 768,
77 | ):
78 | """
79 | Args:
80 | kernel_size (Tuple): kernel size of the projection layer.
81 | stride (Tuple): stride of the projection layer.
82 | padding (Tuple): padding size of the projection layer.
83 | in_chans (int): Number of input image channels.
84 |             embed_dim (int): Patch embedding dimension.
85 | """
86 | super().__init__()
87 | self.proj = nn.Conv2d(
88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
89 | )
90 |
91 | def forward(self, x: torch.Tensor) -> torch.Tensor:
92 | x = self.proj(x)
93 | # B C H W -> B H W C
94 | x = x.permute(0, 2, 3, 1)
95 | return x
96 |
--------------------------------------------------------------------------------
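A quick round-trip sketch for the windowing helpers above, assuming they are importable as `sam2.modeling.backbones.utils`: partitioning pads H and W up to multiples of the window size, and unpartitioning strips the padding again, so the original tensor is recovered exactly.

```python
import torch
from sam2.modeling.backbones.utils import window_partition, window_unpartition

x = torch.randn(2, 30, 22, 96)                    # H=30, W=22 are not multiples of 8
windows, (Hp, Wp) = window_partition(x, window_size=8)
print(tuple(windows.shape), (Hp, Wp))             # (24, 8, 8, 96), padded to (32, 24)

y = window_unpartition(windows, 8, (Hp, Wp), (30, 22))
assert torch.equal(x, y)                          # padding removed, values unchanged
```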
/gradio_demo/sam2/modeling/memory_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional
8 |
9 | import torch
10 | from torch import Tensor, nn
11 |
12 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones
13 | from sam2.modeling.sam.transformer import RoPEAttention
14 |
15 |
16 | class MemoryAttentionLayer(nn.Module):
17 |
18 | def __init__(
19 | self,
20 | activation: str,
21 | cross_attention: nn.Module,
22 | d_model: int,
23 | dim_feedforward: int,
24 | dropout: float,
25 | pos_enc_at_attn: bool,
26 | pos_enc_at_cross_attn_keys: bool,
27 | pos_enc_at_cross_attn_queries: bool,
28 | self_attention: nn.Module,
29 | ):
30 | super().__init__()
31 | self.d_model = d_model
32 | self.dim_feedforward = dim_feedforward
33 | self.dropout_value = dropout
34 | self.self_attn = self_attention
35 | self.cross_attn_image = cross_attention
36 |
37 | # Implementation of Feedforward model
38 | self.linear1 = nn.Linear(d_model, dim_feedforward)
39 | self.dropout = nn.Dropout(dropout)
40 | self.linear2 = nn.Linear(dim_feedforward, d_model)
41 |
42 | self.norm1 = nn.LayerNorm(d_model)
43 | self.norm2 = nn.LayerNorm(d_model)
44 | self.norm3 = nn.LayerNorm(d_model)
45 | self.dropout1 = nn.Dropout(dropout)
46 | self.dropout2 = nn.Dropout(dropout)
47 | self.dropout3 = nn.Dropout(dropout)
48 |
49 | self.activation_str = activation
50 | self.activation = get_activation_fn(activation)
51 |
52 | # Where to add pos enc
53 | self.pos_enc_at_attn = pos_enc_at_attn
54 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries
55 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys
56 |
57 | def _forward_sa(self, tgt, query_pos):
58 | # Self-Attention
59 | tgt2 = self.norm1(tgt)
60 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2
61 | tgt2 = self.self_attn(q, k, v=tgt2)
62 | tgt = tgt + self.dropout1(tgt2)
63 | return tgt
64 |
65 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0):
66 | kwds = {}
67 | if num_k_exclude_rope > 0:
68 | assert isinstance(self.cross_attn_image, RoPEAttention)
69 | kwds = {"num_k_exclude_rope": num_k_exclude_rope}
70 |
71 | # Cross-Attention
72 | tgt2 = self.norm2(tgt)
73 | tgt2 = self.cross_attn_image(
74 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2,
75 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory,
76 | v=memory,
77 | **kwds,
78 | )
79 | tgt = tgt + self.dropout2(tgt2)
80 | return tgt
81 |
82 | def forward(
83 | self,
84 | tgt,
85 | memory,
86 | pos: Optional[Tensor] = None,
87 | query_pos: Optional[Tensor] = None,
88 | num_k_exclude_rope: int = 0,
89 | ) -> torch.Tensor:
90 |
91 | # Self-Attn, Cross-Attn
92 | tgt = self._forward_sa(tgt, query_pos)
93 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope)
94 | # MLP
95 | tgt2 = self.norm3(tgt)
96 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
97 | tgt = tgt + self.dropout3(tgt2)
98 | return tgt
99 |
100 |
101 | class MemoryAttention(nn.Module):
102 | def __init__(
103 | self,
104 | d_model: int,
105 | pos_enc_at_input: bool,
106 | layer: nn.Module,
107 | num_layers: int,
108 | batch_first: bool = True, # Do layers expect batch first input?
109 | ):
110 | super().__init__()
111 | self.d_model = d_model
112 | self.layers = get_clones(layer, num_layers)
113 | self.num_layers = num_layers
114 | self.norm = nn.LayerNorm(d_model)
115 | self.pos_enc_at_input = pos_enc_at_input
116 | self.batch_first = batch_first
117 |
118 | def forward(
119 | self,
120 | curr: torch.Tensor, # self-attention inputs
121 | memory: torch.Tensor, # cross-attention inputs
122 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs
123 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs
124 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens*
125 | ):
126 | if isinstance(curr, list):
127 | assert isinstance(curr_pos, list)
128 | assert len(curr) == len(curr_pos) == 1
129 | curr, curr_pos = (
130 | curr[0],
131 | curr_pos[0],
132 | )
133 |
134 | assert (
135 | curr.shape[1] == memory.shape[1]
136 | ), "Batch size must be the same for curr and memory"
137 |
138 | output = curr
139 | if self.pos_enc_at_input and curr_pos is not None:
140 | output = output + 0.1 * curr_pos
141 |
142 | if self.batch_first:
143 | # Convert to batch first
144 | output = output.transpose(0, 1)
145 | curr_pos = curr_pos.transpose(0, 1)
146 | memory = memory.transpose(0, 1)
147 | memory_pos = memory_pos.transpose(0, 1)
148 |
149 | for layer in self.layers:
150 | kwds = {}
151 | if isinstance(layer.cross_attn_image, RoPEAttention):
152 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens}
153 |
154 | output = layer(
155 | tgt=output,
156 | memory=memory,
157 | pos=memory_pos,
158 | query_pos=curr_pos,
159 | **kwds,
160 | )
161 | normed_output = self.norm(output)
162 |
163 | if self.batch_first:
164 | # Convert back to seq first
165 | normed_output = normed_output.transpose(0, 1)
166 | curr_pos = curr_pos.transpose(0, 1)
167 |
168 | return normed_output
169 |
--------------------------------------------------------------------------------
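MemoryAttentionLayer above applies, in order, pre-norm self-attention over the current frame's tokens, pre-norm cross-attention from those tokens to the memory tokens, and a pre-norm MLP, each with a residual connection. Below is a toy, standalone illustration of that ordering only; `nn.MultiheadAttention` stands in for the RoPE attention modules, and the token counts are arbitrary.

```python
import torch
import torch.nn as nn

d_model, dim_ff = 256, 2048
norm1, norm2, norm3 = (nn.LayerNorm(d_model) for _ in range(3))
self_attn = nn.MultiheadAttention(d_model, num_heads=1, batch_first=True)
cross_attn = nn.MultiheadAttention(d_model, num_heads=1, batch_first=True)
mlp = nn.Sequential(nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, d_model))

tgt = torch.randn(1, 4096, d_model)               # current-frame tokens
memory = torch.randn(1, 3 * 4096, d_model)        # tokens from past frames

x = norm1(tgt); tgt = tgt + self_attn(x, x, x)[0]              # self-attention
x = norm2(tgt); tgt = tgt + cross_attn(x, memory, memory)[0]   # cross-attention to memory
tgt = tgt + mlp(norm3(tgt))                                    # feed-forward
print(tuple(tgt.shape))
```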
/gradio_demo/sam2/modeling/memory_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import Tuple
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 | from sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones
15 |
16 |
17 | class MaskDownSampler(nn.Module):
18 | """
19 | Progressively downsample a mask by total_stride, each time by stride.
20 | Note that LayerNorm is applied per *token*, like in ViT.
21 |
22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor.
23 | In the end, we linearly project to embed_dim channels.
24 | """
25 |
26 | def __init__(
27 | self,
28 | embed_dim=256,
29 | kernel_size=4,
30 | stride=4,
31 | padding=0,
32 | total_stride=16,
33 | activation=nn.GELU,
34 | ):
35 | super().__init__()
36 | num_layers = int(math.log2(total_stride) // math.log2(stride))
37 | assert stride**num_layers == total_stride
38 | self.encoder = nn.Sequential()
39 | mask_in_chans, mask_out_chans = 1, 1
40 | for _ in range(num_layers):
41 | mask_out_chans = mask_in_chans * (stride**2)
42 | self.encoder.append(
43 | nn.Conv2d(
44 | mask_in_chans,
45 | mask_out_chans,
46 | kernel_size=kernel_size,
47 | stride=stride,
48 | padding=padding,
49 | )
50 | )
51 | self.encoder.append(LayerNorm2d(mask_out_chans))
52 | self.encoder.append(activation())
53 | mask_in_chans = mask_out_chans
54 |
55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1))
56 |
57 | def forward(self, x):
58 | return self.encoder(x)
59 |
60 |
61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt)
62 | class CXBlock(nn.Module):
63 | r"""ConvNeXt Block. There are two equivalent implementations:
64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
66 | We use (2) as we find it slightly faster in PyTorch
67 |
68 | Args:
69 | dim (int): Number of input channels.
70 | drop_path (float): Stochastic depth rate. Default: 0.0
71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
72 | """
73 |
74 | def __init__(
75 | self,
76 | dim,
77 | kernel_size=7,
78 | padding=3,
79 | drop_path=0.0,
80 | layer_scale_init_value=1e-6,
81 | use_dwconv=True,
82 | ):
83 | super().__init__()
84 | self.dwconv = nn.Conv2d(
85 | dim,
86 | dim,
87 | kernel_size=kernel_size,
88 | padding=padding,
89 | groups=dim if use_dwconv else 1,
90 | ) # depthwise conv
91 | self.norm = LayerNorm2d(dim, eps=1e-6)
92 | self.pwconv1 = nn.Linear(
93 | dim, 4 * dim
94 | ) # pointwise/1x1 convs, implemented with linear layers
95 | self.act = nn.GELU()
96 | self.pwconv2 = nn.Linear(4 * dim, dim)
97 | self.gamma = (
98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
99 | if layer_scale_init_value > 0
100 | else None
101 | )
102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
103 |
104 | def forward(self, x):
105 | input = x
106 | x = self.dwconv(x)
107 | x = self.norm(x)
108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
109 | x = self.pwconv1(x)
110 | x = self.act(x)
111 | x = self.pwconv2(x)
112 | if self.gamma is not None:
113 | x = self.gamma * x
114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
115 |
116 | x = input + self.drop_path(x)
117 | return x
118 |
119 |
120 | class Fuser(nn.Module):
121 | def __init__(self, layer, num_layers, dim=None, input_projection=False):
122 | super().__init__()
123 | self.proj = nn.Identity()
124 | self.layers = get_clones(layer, num_layers)
125 |
126 | if input_projection:
127 | assert dim is not None
128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1)
129 |
130 | def forward(self, x):
131 | # normally x: (N, C, H, W)
132 | x = self.proj(x)
133 | for layer in self.layers:
134 | x = layer(x)
135 | return x
136 |
137 |
138 | class MemoryEncoder(nn.Module):
139 | def __init__(
140 | self,
141 | out_dim,
142 | mask_downsampler,
143 | fuser,
144 | position_encoding,
145 | in_dim=256, # in_dim of pix_feats
146 | ):
147 | super().__init__()
148 |
149 | self.mask_downsampler = mask_downsampler
150 |
151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1)
152 | self.fuser = fuser
153 | self.position_encoding = position_encoding
154 | self.out_proj = nn.Identity()
155 | if out_dim != in_dim:
156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1)
157 |
158 | def forward(
159 | self,
160 | pix_feat: torch.Tensor,
161 | masks: torch.Tensor,
162 | skip_mask_sigmoid: bool = False,
163 |     ) -> dict:
164 | ## Process masks
165 |         # apply sigmoid so there is less domain shift from GT masks, which are binary
166 | if not skip_mask_sigmoid:
167 | masks = F.sigmoid(masks)
168 | masks = self.mask_downsampler(masks)
169 |
170 | ## Fuse pix_feats and downsampled masks
171 |         # in case the visual features are on a different device, move them to the masks' device
172 | pix_feat = pix_feat.to(masks.device)
173 |
174 | x = self.pix_feat_proj(pix_feat)
175 | x = x + masks
176 | x = self.fuser(x)
177 | x = self.out_proj(x)
178 |
179 | pos = self.position_encoding(x).to(x.dtype)
180 |
181 | return {"vision_features": x, "vision_pos_enc": [pos]}
182 |
--------------------------------------------------------------------------------
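The MaskDownSampler docstring above says the mask is shrunk by `total_stride` while channel capacity grows by `stride**2` per step, before a final 1x1 projection to `embed_dim`. A minimal shape check, assuming the module is importable as `sam2.modeling.memory_encoder` and using an illustrative 1024-pixel mask:

```python
import torch
from sam2.modeling.memory_encoder import MaskDownSampler

downsampler = MaskDownSampler(embed_dim=256, kernel_size=4, stride=4, total_stride=16)
mask_logits = torch.randn(1, 1, 1024, 1024)   # full-resolution mask logits
print(tuple(downsampler(mask_logits).shape))  # (1, 256, 64, 64): 1024 / 16 = 64
```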
/gradio_demo/sam2/modeling/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from typing import Any, Optional, Tuple
9 |
10 | import numpy as np
11 | import torch
12 | from torch import nn
13 |
14 |
15 | class PositionEmbeddingSine(nn.Module):
16 | """
17 | This is a more standard version of the position embedding, very similar to the one
18 | used by the Attention is all you need paper, generalized to work on images.
19 | """
20 |
21 | def __init__(
22 | self,
23 | num_pos_feats,
24 | temperature: int = 10000,
25 | normalize: bool = True,
26 | scale: Optional[float] = None,
27 | ):
28 | super().__init__()
29 | assert num_pos_feats % 2 == 0, "Expecting even model width"
30 | self.num_pos_feats = num_pos_feats // 2
31 | self.temperature = temperature
32 | self.normalize = normalize
33 | if scale is not None and normalize is False:
34 | raise ValueError("normalize should be True if scale is passed")
35 | if scale is None:
36 | scale = 2 * math.pi
37 | self.scale = scale
38 |
39 | self.cache = {}
40 |
41 | def _encode_xy(self, x, y):
42 | # The positions are expected to be normalized
43 | assert len(x) == len(y) and x.ndim == y.ndim == 1
44 | x_embed = x * self.scale
45 | y_embed = y * self.scale
46 |
47 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
48 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
49 |
50 | pos_x = x_embed[:, None] / dim_t
51 | pos_y = y_embed[:, None] / dim_t
52 | pos_x = torch.stack(
53 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2
54 | ).flatten(1)
55 | pos_y = torch.stack(
56 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2
57 | ).flatten(1)
58 | return pos_x, pos_y
59 |
60 | @torch.no_grad()
61 | def encode_boxes(self, x, y, w, h):
62 | pos_x, pos_y = self._encode_xy(x, y)
63 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1)
64 | return pos
65 |
66 | encode = encode_boxes # Backwards compatibility
67 |
68 | @torch.no_grad()
69 | def encode_points(self, x, y, labels):
70 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape
71 | assert bx == by and nx == ny and bx == bl and nx == nl
72 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten())
73 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1)
74 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2)
75 | return pos
76 |
77 | @torch.no_grad()
78 | def forward(self, x: torch.Tensor):
79 | cache_key = (x.shape[-2], x.shape[-1])
80 | if cache_key in self.cache:
81 | return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1)
82 | y_embed = (
83 | torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device)
84 | .view(1, -1, 1)
85 | .repeat(x.shape[0], 1, x.shape[-1])
86 | )
87 | x_embed = (
88 | torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
89 | .view(1, 1, -1)
90 | .repeat(x.shape[0], x.shape[-2], 1)
91 | )
92 |
93 | if self.normalize:
94 | eps = 1e-6
95 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
96 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
97 |
98 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
99 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
100 |
101 | pos_x = x_embed[:, :, :, None] / dim_t
102 | pos_y = y_embed[:, :, :, None] / dim_t
103 | pos_x = torch.stack(
104 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
105 | ).flatten(3)
106 | pos_y = torch.stack(
107 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
108 | ).flatten(3)
109 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
110 | self.cache[cache_key] = pos[0]
111 | return pos
112 |
113 |
114 | class PositionEmbeddingRandom(nn.Module):
115 | """
116 | Positional encoding using random spatial frequencies.
117 | """
118 |
119 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None:
120 | super().__init__()
121 | if scale is None or scale <= 0.0:
122 | scale = 1.0
123 | self.register_buffer(
124 | "positional_encoding_gaussian_matrix",
125 | scale * torch.randn((2, num_pos_feats)),
126 | )
127 |
128 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
129 | """Positionally encode points that are normalized to [0,1]."""
130 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
131 | coords = 2 * coords - 1
132 | coords = coords @ self.positional_encoding_gaussian_matrix
133 | coords = 2 * np.pi * coords
134 | # outputs d_1 x ... x d_n x C shape
135 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
136 |
137 | def forward(self, size: Tuple[int, int]) -> torch.Tensor:
138 | """Generate positional encoding for a grid of the specified size."""
139 | h, w = size
140 | device: Any = self.positional_encoding_gaussian_matrix.device
141 | grid = torch.ones((h, w), device=device, dtype=torch.float32)
142 | y_embed = grid.cumsum(dim=0) - 0.5
143 | x_embed = grid.cumsum(dim=1) - 0.5
144 | y_embed = y_embed / h
145 | x_embed = x_embed / w
146 |
147 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1))
148 | return pe.permute(2, 0, 1) # C x H x W
149 |
150 | def forward_with_coords(
151 | self, coords_input: torch.Tensor, image_size: Tuple[int, int]
152 | ) -> torch.Tensor:
153 | """Positionally encode points that are not normalized to [0,1]."""
154 | coords = coords_input.clone()
155 | coords[:, :, 0] = coords[:, :, 0] / image_size[1]
156 | coords[:, :, 1] = coords[:, :, 1] / image_size[0]
157 | return self._pe_encoding(coords.to(torch.float)) # B x N x C
158 |
159 |
160 | # Rotary Positional Encoding, adapted from:
161 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py
162 | # 2. https://github.com/naver-ai/rope-vit
163 | # 3. https://github.com/lucidrains/rotary-embedding-torch
164 |
165 |
166 | def init_t_xy(end_x: int, end_y: int):
167 | t = torch.arange(end_x * end_y, dtype=torch.float32)
168 | t_x = (t % end_x).float()
169 | t_y = torch.div(t, end_x, rounding_mode="floor").float()
170 | return t_x, t_y
171 |
172 |
173 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0):
174 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
175 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
176 |
177 | t_x, t_y = init_t_xy(end_x, end_y)
178 | freqs_x = torch.outer(t_x, freqs_x)
179 | freqs_y = torch.outer(t_y, freqs_y)
180 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x)
181 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y)
182 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1)
183 |
184 |
185 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
186 | ndim = x.ndim
187 | assert 0 <= 1 < ndim
188 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1])
189 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)]
190 | return freqs_cis.view(*shape)
191 |
192 |
193 | def apply_rotary_enc(
194 | xq: torch.Tensor,
195 | xk: torch.Tensor,
196 | freqs_cis: torch.Tensor,
197 | repeat_freqs_k: bool = False,
198 | ):
199 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
200 | xk_ = (
201 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
202 | if xk.shape[-2] != 0
203 | else None
204 | )
205 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
206 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
207 | if xk_ is None:
208 | # no keys to rotate, due to dropout
209 | return xq_out.type_as(xq).to(xq.device), xk
210 | # repeat freqs along seq_len dim to match k seq_len
211 | if repeat_freqs_k:
212 | r = xk_.shape[-2] // xq_.shape[-2]
213 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1)
214 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
215 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device)
216 |
--------------------------------------------------------------------------------
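The axial RoPE helpers above build complex rotation factors for a 2D token grid (`compute_axial_cis`) and apply them to query/key heads (`apply_rotary_enc`). A minimal shape sketch, assuming the file is importable as `sam2.modeling.position_encoding`; the head dimension and grid size are illustrative.

```python
import torch
from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis

head_dim, H = 64, 32                          # head_dim must be divisible by 4
freqs_cis = compute_axial_cis(dim=head_dim, end_x=H, end_y=H)
print(tuple(freqs_cis.shape))                 # (1024, 32) -> (H*W, head_dim // 2)

q = torch.randn(1, 2, H * H, head_dim)        # (B, n_heads, seq_len, head_dim)
k = torch.randn(1, 2, H * H, head_dim)
q_rot, k_rot = apply_rotary_enc(q, k, freqs_cis)
print(tuple(q_rot.shape), tuple(k_rot.shape))  # same shapes as the inputs
```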
/gradio_demo/sam2/modeling/sam/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/modeling/sam/mask_decoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import List, Optional, Tuple, Type
8 |
9 | import torch
10 | from torch import nn
11 |
12 | from sam2.modeling.sam2_utils import MLP, LayerNorm2d
13 |
14 |
15 | class MaskDecoder(nn.Module):
16 | def __init__(
17 | self,
18 | *,
19 | transformer_dim: int,
20 | transformer: nn.Module,
21 | num_multimask_outputs: int = 3,
22 | activation: Type[nn.Module] = nn.GELU,
23 | iou_head_depth: int = 3,
24 | iou_head_hidden_dim: int = 256,
25 | use_high_res_features: bool = False,
26 | iou_prediction_use_sigmoid=False,
27 | dynamic_multimask_via_stability=False,
28 | dynamic_multimask_stability_delta=0.05,
29 | dynamic_multimask_stability_thresh=0.98,
30 | pred_obj_scores: bool = False,
31 | pred_obj_scores_mlp: bool = False,
32 | use_multimask_token_for_obj_ptr: bool = False,
33 | ) -> None:
34 | """
35 | Predicts masks given an image and prompt embeddings, using a
36 | transformer architecture.
37 |
38 | Arguments:
39 | transformer_dim (int): the channel dimension of the transformer
40 | transformer (nn.Module): the transformer used to predict masks
41 | num_multimask_outputs (int): the number of masks to predict
42 | when disambiguating masks
43 | activation (nn.Module): the type of activation to use when
44 | upscaling masks
45 | iou_head_depth (int): the depth of the MLP used to predict
46 | mask quality
47 | iou_head_hidden_dim (int): the hidden dimension of the MLP
48 | used to predict mask quality
49 | """
50 | super().__init__()
51 | self.transformer_dim = transformer_dim
52 | self.transformer = transformer
53 |
54 | self.num_multimask_outputs = num_multimask_outputs
55 |
56 | self.iou_token = nn.Embedding(1, transformer_dim)
57 | self.num_mask_tokens = num_multimask_outputs + 1
58 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
59 |
60 | self.pred_obj_scores = pred_obj_scores
61 | if self.pred_obj_scores:
62 | self.obj_score_token = nn.Embedding(1, transformer_dim)
63 | self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr
64 |
65 | self.output_upscaling = nn.Sequential(
66 | nn.ConvTranspose2d(
67 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2
68 | ),
69 | LayerNorm2d(transformer_dim // 4),
70 | activation(),
71 | nn.ConvTranspose2d(
72 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2
73 | ),
74 | activation(),
75 | )
76 | self.use_high_res_features = use_high_res_features
77 | if use_high_res_features:
78 | self.conv_s0 = nn.Conv2d(
79 | transformer_dim, transformer_dim // 8, kernel_size=1, stride=1
80 | )
81 | self.conv_s1 = nn.Conv2d(
82 | transformer_dim, transformer_dim // 4, kernel_size=1, stride=1
83 | )
84 |
85 | self.output_hypernetworks_mlps = nn.ModuleList(
86 | [
87 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3)
88 | for i in range(self.num_mask_tokens)
89 | ]
90 | )
91 |
92 | self.iou_prediction_head = MLP(
93 | transformer_dim,
94 | iou_head_hidden_dim,
95 | self.num_mask_tokens,
96 | iou_head_depth,
97 | sigmoid_output=iou_prediction_use_sigmoid,
98 | )
99 | if self.pred_obj_scores:
100 | self.pred_obj_score_head = nn.Linear(transformer_dim, 1)
101 | if pred_obj_scores_mlp:
102 | self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3)
103 |
104 | # When outputting a single mask, optionally we can dynamically fall back to the best
105 | # multimask output token if the single mask output token gives low stability scores.
106 | self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
107 | self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta
108 | self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh
109 |
110 | def forward(
111 | self,
112 | image_embeddings: torch.Tensor,
113 | image_pe: torch.Tensor,
114 | sparse_prompt_embeddings: torch.Tensor,
115 | dense_prompt_embeddings: torch.Tensor,
116 | multimask_output: bool,
117 | repeat_image: bool,
118 | high_res_features: Optional[List[torch.Tensor]] = None,
119 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
120 | """
121 | Predict masks given image and prompt embeddings.
122 |
123 | Arguments:
124 | image_embeddings (torch.Tensor): the embeddings from the image encoder
125 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
126 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
127 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
128 | multimask_output (bool): Whether to return multiple masks or a single
129 | mask.
130 |
131 | Returns:
132 | torch.Tensor: batched predicted masks
133 | torch.Tensor: batched predictions of mask quality
134 | torch.Tensor: batched SAM token for mask output
135 | """
136 | masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
137 | image_embeddings=image_embeddings,
138 | image_pe=image_pe,
139 | sparse_prompt_embeddings=sparse_prompt_embeddings,
140 | dense_prompt_embeddings=dense_prompt_embeddings,
141 | repeat_image=repeat_image,
142 | high_res_features=high_res_features,
143 | )
144 |
145 | # Select the correct mask or masks for output
146 | if multimask_output:
147 | masks = masks[:, 1:, :, :]
148 | iou_pred = iou_pred[:, 1:]
149 | elif self.dynamic_multimask_via_stability and not self.training:
150 | masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
151 | else:
152 | masks = masks[:, 0:1, :, :]
153 | iou_pred = iou_pred[:, 0:1]
154 |
155 | if multimask_output and self.use_multimask_token_for_obj_ptr:
156 | sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
157 | else:
158 | # Take the mask output token. Here we *always* use the token for single mask output.
159 | # At test time, even if we track after 1-click (and using multimask_output=True),
160 | # we still take the single mask token here. The rationale is that we always track
161 | # after multiple clicks during training, so the past tokens seen during training
162 | # are always the single mask token (and we'll let it be the object-memory token).
163 | sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
164 |
165 | # Prepare output
166 | return masks, iou_pred, sam_tokens_out, object_score_logits
167 |
168 | def predict_masks(
169 | self,
170 | image_embeddings: torch.Tensor,
171 | image_pe: torch.Tensor,
172 | sparse_prompt_embeddings: torch.Tensor,
173 | dense_prompt_embeddings: torch.Tensor,
174 | repeat_image: bool,
175 | high_res_features: Optional[List[torch.Tensor]] = None,
176 |     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
177 | """Predicts masks. See 'forward' for more details."""
178 | # Concatenate output tokens
179 | s = 0
180 | if self.pred_obj_scores:
181 | output_tokens = torch.cat(
182 | [
183 | self.obj_score_token.weight,
184 | self.iou_token.weight,
185 | self.mask_tokens.weight,
186 | ],
187 | dim=0,
188 | )
189 | s = 1
190 | else:
191 | output_tokens = torch.cat(
192 | [self.iou_token.weight, self.mask_tokens.weight], dim=0
193 | )
194 | output_tokens = output_tokens.unsqueeze(0).expand(
195 | sparse_prompt_embeddings.size(0), -1, -1
196 | )
197 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
198 |
199 | # Expand per-image data in batch direction to be per-mask
200 | if repeat_image:
201 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
202 | else:
203 | assert image_embeddings.shape[0] == tokens.shape[0]
204 | src = image_embeddings
205 | src = src + dense_prompt_embeddings
206 | assert (
207 | image_pe.size(0) == 1
208 | ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)"
209 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
210 | b, c, h, w = src.shape
211 |
212 | # Run the transformer
213 | hs, src = self.transformer(src, pos_src, tokens)
214 | iou_token_out = hs[:, s, :]
215 | mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :]
216 |
217 | # Upscale mask embeddings and predict masks using the mask tokens
218 | src = src.transpose(1, 2).view(b, c, h, w)
219 | if not self.use_high_res_features:
220 | upscaled_embedding = self.output_upscaling(src)
221 | else:
222 | dc1, ln1, act1, dc2, act2 = self.output_upscaling
223 | feat_s0, feat_s1 = high_res_features
224 | upscaled_embedding = act1(ln1(dc1(src) + feat_s1))
225 | upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0)
226 |
227 | hyper_in_list: List[torch.Tensor] = []
228 | for i in range(self.num_mask_tokens):
229 | hyper_in_list.append(
230 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
231 | )
232 | hyper_in = torch.stack(hyper_in_list, dim=1)
233 | b, c, h, w = upscaled_embedding.shape
234 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)
235 |
236 | # Generate mask quality predictions
237 | iou_pred = self.iou_prediction_head(iou_token_out)
238 | if self.pred_obj_scores:
239 | assert s == 1
240 | object_score_logits = self.pred_obj_score_head(hs[:, 0, :])
241 | else:
242 | # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1
243 | object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1)
244 |
245 | return masks, iou_pred, mask_tokens_out, object_score_logits
246 |
247 | def _get_stability_scores(self, mask_logits):
248 | """
249 | Compute stability scores of the mask logits based on the IoU between upper and
250 | lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568.
251 | """
252 | mask_logits = mask_logits.flatten(-2)
253 | stability_delta = self.dynamic_multimask_stability_delta
254 | area_i = torch.sum(mask_logits > stability_delta, dim=-1).float()
255 | area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float()
256 | stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0)
257 | return stability_scores
258 |
259 | def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores):
260 | """
261 | When outputting a single mask, if the stability score from the current single-mask
262 | output (based on output token 0) falls below a threshold, we instead select from
263 | multi-mask outputs (based on output token 1~3) the mask with the highest predicted
264 | IoU score. This is intended to ensure a valid mask for both clicking and tracking.
265 | """
266 | # The best mask from multimask output tokens (1~3)
267 | multimask_logits = all_mask_logits[:, 1:, :, :]
268 | multimask_iou_scores = all_iou_scores[:, 1:]
269 | best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1)
270 | batch_inds = torch.arange(
271 | multimask_iou_scores.size(0), device=all_iou_scores.device
272 | )
273 | best_multimask_logits = multimask_logits[batch_inds, best_scores_inds]
274 | best_multimask_logits = best_multimask_logits.unsqueeze(1)
275 | best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds]
276 | best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1)
277 |
278 | # The mask from singlemask output token 0 and its stability score
279 | singlemask_logits = all_mask_logits[:, 0:1, :, :]
280 | singlemask_iou_scores = all_iou_scores[:, 0:1]
281 | stability_scores = self._get_stability_scores(singlemask_logits)
282 | is_stable = stability_scores >= self.dynamic_multimask_stability_thresh
283 |
284 | # Dynamically fall back to best multimask output upon low stability scores.
285 | mask_logits_out = torch.where(
286 | is_stable[..., None, None].expand_as(singlemask_logits),
287 | singlemask_logits,
288 | best_multimask_logits,
289 | )
290 | iou_scores_out = torch.where(
291 | is_stable.expand_as(singlemask_iou_scores),
292 | singlemask_iou_scores,
293 | best_multimask_iou_scores,
294 | )
295 | return mask_logits_out, iou_scores_out
296 |
--------------------------------------------------------------------------------
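The stability score in `_get_stability_scores` above is the IoU between the mask thresholded at `+delta` and at `-delta`; masks with many near-zero logits score low and trigger the multimask fallback. A small standalone numeric sketch (the logit values and 2x2 mask are illustrative only):

```python
import torch

mask_logits = torch.tensor([[[[-3.0, 0.02], [0.4, 5.0]]]])   # (B=1, 1, H=2, W=2)
delta = 0.05
flat = mask_logits.flatten(-2)
area_i = torch.sum(flat > delta, dim=-1).float()    # 2 pixels above +delta
area_u = torch.sum(flat > -delta, dim=-1).float()   # 3 pixels above -delta
print((area_i / area_u).item())                     # 0.6667: the borderline pixel lowers stability
```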
/gradio_demo/sam2/modeling/sam/prompt_encoder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional, Tuple, Type
8 |
9 | import torch
10 | from torch import nn
11 |
12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom
13 | from sam2.modeling.sam2_utils import LayerNorm2d
14 |
15 |
16 | class PromptEncoder(nn.Module):
17 | def __init__(
18 | self,
19 | embed_dim: int,
20 | image_embedding_size: Tuple[int, int],
21 | input_image_size: Tuple[int, int],
22 | mask_in_chans: int,
23 | activation: Type[nn.Module] = nn.GELU,
24 | ) -> None:
25 | """
26 | Encodes prompts for input to SAM's mask decoder.
27 |
28 | Arguments:
29 | embed_dim (int): The prompts' embedding dimension
30 | image_embedding_size (tuple(int, int)): The spatial size of the
31 | image embedding, as (H, W).
32 | input_image_size (int): The padded size of the image as input
33 | to the image encoder, as (H, W).
34 | mask_in_chans (int): The number of hidden channels used for
35 | encoding input masks.
36 | activation (nn.Module): The activation to use when encoding
37 | input masks.
38 | """
39 | super().__init__()
40 | self.embed_dim = embed_dim
41 | self.input_image_size = input_image_size
42 | self.image_embedding_size = image_embedding_size
43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2)
44 |
45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners
46 | point_embeddings = [
47 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)
48 | ]
49 | self.point_embeddings = nn.ModuleList(point_embeddings)
50 | self.not_a_point_embed = nn.Embedding(1, embed_dim)
51 |
52 | self.mask_input_size = (
53 | 4 * image_embedding_size[0],
54 | 4 * image_embedding_size[1],
55 | )
56 | self.mask_downscaling = nn.Sequential(
57 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
58 | LayerNorm2d(mask_in_chans // 4),
59 | activation(),
60 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
61 | LayerNorm2d(mask_in_chans),
62 | activation(),
63 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
64 | )
65 | self.no_mask_embed = nn.Embedding(1, embed_dim)
66 |
67 | def get_dense_pe(self) -> torch.Tensor:
68 | """
69 | Returns the positional encoding used to encode point prompts,
70 | applied to a dense set of points the shape of the image encoding.
71 |
72 | Returns:
73 | torch.Tensor: Positional encoding with shape
74 | 1x(embed_dim)x(embedding_h)x(embedding_w)
75 | """
76 | return self.pe_layer(self.image_embedding_size).unsqueeze(0)
77 |
78 | def _embed_points(
79 | self,
80 | points: torch.Tensor,
81 | labels: torch.Tensor,
82 | pad: bool,
83 | ) -> torch.Tensor:
84 | """Embeds point prompts."""
85 | points = points + 0.5 # Shift to center of pixel
86 | if pad:
87 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
88 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
89 | points = torch.cat([points, padding_point], dim=1)
90 | labels = torch.cat([labels, padding_label], dim=1)
91 | point_embedding = self.pe_layer.forward_with_coords(
92 | points, self.input_image_size
93 | )
94 | point_embedding[labels == -1] = 0.0
95 | point_embedding[labels == -1] += self.not_a_point_embed.weight
96 | point_embedding[labels == 0] += self.point_embeddings[0].weight
97 | point_embedding[labels == 1] += self.point_embeddings[1].weight
98 | point_embedding[labels == 2] += self.point_embeddings[2].weight
99 | point_embedding[labels == 3] += self.point_embeddings[3].weight
100 | return point_embedding
101 |
102 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
103 | """Embeds box prompts."""
104 | boxes = boxes + 0.5 # Shift to center of pixel
105 | coords = boxes.reshape(-1, 2, 2)
106 | corner_embedding = self.pe_layer.forward_with_coords(
107 | coords, self.input_image_size
108 | )
109 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight
110 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight
111 | return corner_embedding
112 |
113 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
114 | """Embeds mask inputs."""
115 | mask_embedding = self.mask_downscaling(masks)
116 | return mask_embedding
117 |
118 | def _get_batch_size(
119 | self,
120 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
121 | boxes: Optional[torch.Tensor],
122 | masks: Optional[torch.Tensor],
123 | ) -> int:
124 | """
125 | Gets the batch size of the output given the batch size of the input prompts.
126 | """
127 | if points is not None:
128 | return points[0].shape[0]
129 | elif boxes is not None:
130 | return boxes.shape[0]
131 | elif masks is not None:
132 | return masks.shape[0]
133 | else:
134 | return 1
135 |
136 | def _get_device(self) -> torch.device:
137 | return self.point_embeddings[0].weight.device
138 |
139 | def forward(
140 | self,
141 | points: Optional[Tuple[torch.Tensor, torch.Tensor]],
142 | boxes: Optional[torch.Tensor],
143 | masks: Optional[torch.Tensor],
144 | ) -> Tuple[torch.Tensor, torch.Tensor]:
145 | """
146 | Embeds different types of prompts, returning both sparse and dense
147 | embeddings.
148 |
149 | Arguments:
150 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
151 | and labels to embed.
152 | boxes (torch.Tensor or none): boxes to embed
153 | masks (torch.Tensor or none): masks to embed
154 |
155 | Returns:
156 | torch.Tensor: sparse embeddings for the points and boxes, with shape
157 | BxNx(embed_dim), where N is determined by the number of input points
158 | and boxes.
159 | torch.Tensor: dense embeddings for the masks, in the shape
160 | Bx(embed_dim)x(embed_H)x(embed_W)
161 | """
162 | bs = self._get_batch_size(points, boxes, masks)
163 | sparse_embeddings = torch.empty(
164 | (bs, 0, self.embed_dim), device=self._get_device()
165 | )
166 | if points is not None:
167 | coords, labels = points
168 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
169 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
170 | if boxes is not None:
171 | box_embeddings = self._embed_boxes(boxes)
172 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)
173 |
174 | if masks is not None:
175 | dense_embeddings = self._embed_masks(masks)
176 | else:
177 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
178 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
179 | )
180 |
181 | return sparse_embeddings, dense_embeddings
182 |
--------------------------------------------------------------------------------
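As a minimal sketch of the shapes PromptEncoder produces (assuming it is importable as `sam2.modeling.sam.prompt_encoder`; the 1024-pixel input and 64x64 embedding grid are illustrative, not read from the repo's config): a single positive click yields two sparse tokens (the click plus a padding point, since no box is given), and the dense embedding falls back to the learned no-mask embedding.

```python
import torch
from sam2.modeling.sam.prompt_encoder import PromptEncoder

enc = PromptEncoder(embed_dim=256, image_embedding_size=(64, 64),
                    input_image_size=(1024, 1024), mask_in_chans=16)

points = torch.tensor([[[500.0, 375.0]]])   # (B=1, N=1, 2) pixel coordinates
labels = torch.tensor([[1.0]])              # 1 = positive click
sparse, dense = enc(points=(points, labels), boxes=None, masks=None)
print(tuple(sparse.shape), tuple(dense.shape))   # (1, 2, 256) and (1, 256, 64, 64)
```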
/gradio_demo/sam2/modeling/sam/transformer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | import warnings
9 | from functools import partial
10 | from typing import Tuple, Type
11 |
12 | import torch
13 | import torch.nn.functional as F
14 | from torch import Tensor, nn
15 |
16 | from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis
17 | from sam2.modeling.sam2_utils import MLP
18 | from sam2.utils.misc import get_sdp_backends
19 |
20 | warnings.simplefilter(action="ignore", category=FutureWarning)
21 | # OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()
22 |
23 |
24 | class TwoWayTransformer(nn.Module):
25 | def __init__(
26 | self,
27 | depth: int,
28 | embedding_dim: int,
29 | num_heads: int,
30 | mlp_dim: int,
31 | activation: Type[nn.Module] = nn.ReLU,
32 | attention_downsample_rate: int = 2,
33 | ) -> None:
34 | """
35 | A transformer decoder that attends to an input image using
36 | queries whose positional embedding is supplied.
37 |
38 | Args:
39 | depth (int): number of layers in the transformer
40 | embedding_dim (int): the channel dimension for the input embeddings
41 | num_heads (int): the number of heads for multihead attention. Must
42 | divide embedding_dim
43 | mlp_dim (int): the channel dimension internal to the MLP block
44 | activation (nn.Module): the activation to use in the MLP block
45 | """
46 | super().__init__()
47 | self.depth = depth
48 | self.embedding_dim = embedding_dim
49 | self.num_heads = num_heads
50 | self.mlp_dim = mlp_dim
51 | self.layers = nn.ModuleList()
52 |
53 | for i in range(depth):
54 | self.layers.append(
55 | TwoWayAttentionBlock(
56 | embedding_dim=embedding_dim,
57 | num_heads=num_heads,
58 | mlp_dim=mlp_dim,
59 | activation=activation,
60 | attention_downsample_rate=attention_downsample_rate,
61 | skip_first_layer_pe=(i == 0),
62 | )
63 | )
64 |
65 | self.final_attn_token_to_image = Attention(
66 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
67 | )
68 | self.norm_final_attn = nn.LayerNorm(embedding_dim)
69 |
70 | def forward(
71 | self,
72 | image_embedding: Tensor,
73 | image_pe: Tensor,
74 | point_embedding: Tensor,
75 | ) -> Tuple[Tensor, Tensor]:
76 | """
77 | Args:
78 | image_embedding (torch.Tensor): image to attend to. Should be shape
79 | B x embedding_dim x h x w for any h and w.
80 | image_pe (torch.Tensor): the positional encoding to add to the image. Must
81 | have the same shape as image_embedding.
82 | point_embedding (torch.Tensor): the embedding to add to the query points.
83 | Must have shape B x N_points x embedding_dim for any N_points.
84 |
85 | Returns:
86 | torch.Tensor: the processed point_embedding
87 | torch.Tensor: the processed image_embedding
88 | """
89 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C
90 | bs, c, h, w = image_embedding.shape
91 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
92 | image_pe = image_pe.flatten(2).permute(0, 2, 1)
93 |
94 | # Prepare queries
95 | queries = point_embedding
96 | keys = image_embedding
97 |
98 | # Apply transformer blocks and final layernorm
99 | for layer in self.layers:
100 | queries, keys = layer(
101 | queries=queries,
102 | keys=keys,
103 | query_pe=point_embedding,
104 | key_pe=image_pe,
105 | )
106 |
107 | # Apply the final attention layer from the points to the image
108 | q = queries + point_embedding
109 | k = keys + image_pe
110 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
111 | queries = queries + attn_out
112 | queries = self.norm_final_attn(queries)
113 |
114 | return queries, keys
115 |
116 |
117 | class TwoWayAttentionBlock(nn.Module):
118 | def __init__(
119 | self,
120 | embedding_dim: int,
121 | num_heads: int,
122 | mlp_dim: int = 2048,
123 | activation: Type[nn.Module] = nn.ReLU,
124 | attention_downsample_rate: int = 2,
125 | skip_first_layer_pe: bool = False,
126 | ) -> None:
127 | """
128 | A transformer block with four layers: (1) self-attention of sparse
129 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
130 | block on sparse inputs, and (4) cross attention of dense inputs to sparse
131 | inputs.
132 |
133 | Arguments:
134 | embedding_dim (int): the channel dimension of the embeddings
135 | num_heads (int): the number of heads in the attention layers
136 | mlp_dim (int): the hidden dimension of the mlp block
137 | activation (nn.Module): the activation of the mlp block
138 | skip_first_layer_pe (bool): skip the PE on the first layer
139 | """
140 | super().__init__()
141 | self.self_attn = Attention(embedding_dim, num_heads)
142 | self.norm1 = nn.LayerNorm(embedding_dim)
143 |
144 | self.cross_attn_token_to_image = Attention(
145 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
146 | )
147 | self.norm2 = nn.LayerNorm(embedding_dim)
148 |
149 | self.mlp = MLP(
150 | embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
151 | )
152 | self.norm3 = nn.LayerNorm(embedding_dim)
153 |
154 | self.norm4 = nn.LayerNorm(embedding_dim)
155 | self.cross_attn_image_to_token = Attention(
156 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate
157 | )
158 |
159 | self.skip_first_layer_pe = skip_first_layer_pe
160 |
161 | def forward(
162 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
163 | ) -> Tuple[Tensor, Tensor]:
164 | # Self attention block
165 | if self.skip_first_layer_pe:
166 | queries = self.self_attn(q=queries, k=queries, v=queries)
167 | else:
168 | q = queries + query_pe
169 | attn_out = self.self_attn(q=q, k=q, v=queries)
170 | queries = queries + attn_out
171 | queries = self.norm1(queries)
172 |
173 | # Cross attention block, tokens attending to image embedding
174 | q = queries + query_pe
175 | k = keys + key_pe
176 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
177 | queries = queries + attn_out
178 | queries = self.norm2(queries)
179 |
180 | # MLP block
181 | mlp_out = self.mlp(queries)
182 | queries = queries + mlp_out
183 | queries = self.norm3(queries)
184 |
185 | # Cross attention block, image embedding attending to tokens
186 | q = queries + query_pe
187 | k = keys + key_pe
188 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
189 | keys = keys + attn_out
190 | keys = self.norm4(keys)
191 |
192 | return queries, keys
193 |
194 |
195 | class Attention(nn.Module):
196 | """
197 | An attention layer that allows for downscaling the size of the embedding
198 | after projection to queries, keys, and values.
199 | """
200 |
201 | def __init__(
202 | self,
203 | embedding_dim: int,
204 | num_heads: int,
205 | downsample_rate: int = 1,
206 | dropout: float = 0.0,
207 | kv_in_dim: int = None,
208 | ) -> None:
209 | super().__init__()
210 | self.embedding_dim = embedding_dim
211 | self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
212 | self.internal_dim = embedding_dim // downsample_rate
213 | self.num_heads = num_heads
214 | assert (
215 | self.internal_dim % num_heads == 0
216 | ), "num_heads must divide embedding_dim."
217 |
218 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
219 | self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
220 | self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
221 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
222 |
223 | self.dropout_p = dropout
224 |
225 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
226 | b, n, c = x.shape
227 | x = x.reshape(b, n, num_heads, c // num_heads)
228 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head
229 |
230 | def _recombine_heads(self, x: Tensor) -> Tensor:
231 | b, n_heads, n_tokens, c_per_head = x.shape
232 | x = x.transpose(1, 2)
233 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C
234 |
235 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
236 | # Input projections
237 | q = self.q_proj(q)
238 | k = self.k_proj(k)
239 | v = self.v_proj(v)
240 |
241 | # Separate into heads
242 | q = self._separate_heads(q, self.num_heads)
243 | k = self._separate_heads(k, self.num_heads)
244 | v = self._separate_heads(v, self.num_heads)
245 |
246 | dropout_p = self.dropout_p if self.training else 0.0
247 | # Attention
248 |
249 | #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
250 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
251 |
252 | out = self._recombine_heads(out)
253 | out = self.out_proj(out)
254 |
255 | return out
256 |
257 |
258 | class RoPEAttention(Attention):
259 | """Attention with rotary position encoding."""
260 |
261 | def __init__(
262 | self,
263 | *args,
264 | rope_theta=10000.0,
265 | # whether to repeat q rope to match k length
266 | # this is needed for cross-attention to memories
267 | rope_k_repeat=False,
268 | feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution
269 | **kwargs,
270 | ):
271 | super().__init__(*args, **kwargs)
272 |
273 | self.compute_cis = partial(
274 | compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta
275 | )
276 | freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1])
277 | self.freqs_cis = freqs_cis
278 | self.rope_k_repeat = rope_k_repeat
279 |
280 | def forward(
281 | self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0
282 | ) -> Tensor:
283 | # Input projections
284 | q = self.q_proj(q)
285 | k = self.k_proj(k)
286 | v = self.v_proj(v)
287 |
288 | # Separate into heads
289 | q = self._separate_heads(q, self.num_heads)
290 | k = self._separate_heads(k, self.num_heads)
291 | v = self._separate_heads(v, self.num_heads)
292 |
293 | # Apply rotary position encoding
294 | w = h = math.sqrt(q.shape[-2])
295 | self.freqs_cis = self.freqs_cis.to(q.device)
296 | if self.freqs_cis.shape[0] != q.shape[-2]:
297 | self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device)
298 | if q.shape[-2] != k.shape[-2]:
299 | assert self.rope_k_repeat
300 |
301 | num_k_rope = k.size(-2) - num_k_exclude_rope
302 | q, k[:, :, :num_k_rope] = apply_rotary_enc(
303 | q,
304 | k[:, :, :num_k_rope],
305 | freqs_cis=self.freqs_cis,
306 | repeat_freqs_k=self.rope_k_repeat,
307 | )
308 |
309 | dropout_p = self.dropout_p if self.training else 0.0
310 |
311 | #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)):
312 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
313 |
314 | out = self._recombine_heads(out)
315 | out = self.out_proj(out)
316 |
317 | return out
318 |
--------------------------------------------------------------------------------
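The attention blocks above are easiest to sanity-check with dummy tensors. Below is a minimal sketch (not part of the repository) showing the shape contract of `Attention` with a `downsample_rate` and of `TwoWayAttentionBlock`; the import path is assumed to mirror the tree layout under `gradio_demo/sam2/modeling/sam/transformer.py`.

```python
import torch
from sam2.modeling.sam.transformer import Attention, TwoWayAttentionBlock  # assumed import path

# Internally projects 256 -> 128 (downsample_rate=2), runs 8-head SDPA, projects back to 256.
attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)
sparse = torch.randn(1, 5, 256)        # 5 point/prompt tokens
dense = torch.randn(1, 64 * 64, 256)   # flattened 64x64 image embedding
print(attn(q=sparse, k=dense, v=dense).shape)  # torch.Size([1, 5, 256])

# One two-way block: token self-attn, token->image cross-attn, MLP, image->token cross-attn.
block = TwoWayAttentionBlock(embedding_dim=256, num_heads=8)
queries, keys = block(
    queries=sparse, keys=dense,
    query_pe=torch.zeros_like(sparse), key_pe=torch.zeros_like(dense),
)
print(queries.shape, keys.shape)  # torch.Size([1, 5, 256]) torch.Size([1, 4096, 256])
```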
/gradio_demo/sam2/modeling/sam2_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 |
8 | import copy
9 |
10 | import torch
11 | import torch.nn as nn
12 | import torch.nn.functional as F
13 |
14 |
15 | def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num):
16 | """
17 | Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs`
18 | that are temporally closest to the current frame at `frame_idx`. Here, we take
19 | - a) the closest conditioning frame before `frame_idx` (if any);
20 | - b) the closest conditioning frame after `frame_idx` (if any);
21 | - c) any other temporally closest conditioning frames until reaching a total
22 | of `max_cond_frame_num` conditioning frames.
23 |
24 | Outputs:
25 | - selected_outputs: selected items (keys & values) from `cond_frame_outputs`.
26 | - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`.
27 | """
28 | if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num:
29 | selected_outputs = cond_frame_outputs
30 | unselected_outputs = {}
31 | else:
32 | assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames"
33 | selected_outputs = {}
34 |
35 | # the closest conditioning frame before `frame_idx` (if any)
36 | idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None)
37 | if idx_before is not None:
38 | selected_outputs[idx_before] = cond_frame_outputs[idx_before]
39 |
40 | # the closest conditioning frame after `frame_idx` (if any)
41 | idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None)
42 | if idx_after is not None:
43 | selected_outputs[idx_after] = cond_frame_outputs[idx_after]
44 |
45 | # add other temporally closest conditioning frames until reaching a total
46 | # of `max_cond_frame_num` conditioning frames.
47 | num_remain = max_cond_frame_num - len(selected_outputs)
48 | inds_remain = sorted(
49 | (t for t in cond_frame_outputs if t not in selected_outputs),
50 | key=lambda x: abs(x - frame_idx),
51 | )[:num_remain]
52 | selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain)
53 | unselected_outputs = {
54 | t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs
55 | }
56 |
57 | return selected_outputs, unselected_outputs
58 |
59 |
60 | def get_1d_sine_pe(pos_inds, dim, temperature=10000):
61 | """
62 | Get 1D sine positional embedding as in the original Transformer paper.
63 | """
64 | pe_dim = dim // 2
65 | dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device)
66 | dim_t = temperature ** (2 * (dim_t // 2) / pe_dim)
67 |
68 | pos_embed = pos_inds.unsqueeze(-1) / dim_t
69 | pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1)
70 | return pos_embed
71 |
72 |
73 | def get_activation_fn(activation):
74 | """Return an activation function given a string"""
75 | if activation == "relu":
76 | return F.relu
77 | if activation == "gelu":
78 | return F.gelu
79 | if activation == "glu":
80 | return F.glu
81 | raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
82 |
83 |
84 | def get_clones(module, N):
85 | return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
86 |
87 |
88 | class DropPath(nn.Module):
89 | # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
90 | def __init__(self, drop_prob=0.0, scale_by_keep=True):
91 | super(DropPath, self).__init__()
92 | self.drop_prob = drop_prob
93 | self.scale_by_keep = scale_by_keep
94 |
95 | def forward(self, x):
96 | if self.drop_prob == 0.0 or not self.training:
97 | return x
98 | keep_prob = 1 - self.drop_prob
99 | shape = (x.shape[0],) + (1,) * (x.ndim - 1)
100 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
101 | if keep_prob > 0.0 and self.scale_by_keep:
102 | random_tensor.div_(keep_prob)
103 | return x * random_tensor
104 |
105 |
106 | # Lightly adapted from
107 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa
108 | class MLP(nn.Module):
109 | def __init__(
110 | self,
111 | input_dim: int,
112 | hidden_dim: int,
113 | output_dim: int,
114 | num_layers: int,
115 | activation: nn.Module = nn.ReLU,
116 | sigmoid_output: bool = False,
117 | ) -> None:
118 | super().__init__()
119 | self.num_layers = num_layers
120 | h = [hidden_dim] * (num_layers - 1)
121 | self.layers = nn.ModuleList(
122 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
123 | )
124 | self.sigmoid_output = sigmoid_output
125 | self.act = activation()
126 |
127 | def forward(self, x):
128 | for i, layer in enumerate(self.layers):
129 | x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
130 | if self.sigmoid_output:
131 | x = F.sigmoid(x)
132 | return x
133 |
134 |
135 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa
136 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa
137 | class LayerNorm2d(nn.Module):
138 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
139 | super().__init__()
140 | self.weight = nn.Parameter(torch.ones(num_channels))
141 | self.bias = nn.Parameter(torch.zeros(num_channels))
142 | self.eps = eps
143 |
144 | def forward(self, x: torch.Tensor) -> torch.Tensor:
145 | u = x.mean(1, keepdim=True)
146 | s = (x - u).pow(2).mean(1, keepdim=True)
147 | x = (x - u) / torch.sqrt(s + self.eps)
148 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
149 | return x
150 |
--------------------------------------------------------------------------------
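As a quick reference for the helpers above, here is a small sketch (not part of the repository; the package is assumed importable as `sam2` per the tree) exercising `get_1d_sine_pe` and `MLP`:

```python
import torch
from sam2.modeling.sam2_utils import get_1d_sine_pe, MLP  # assumed import path

# Sine/cosine positional embedding for three frame indices, 256 channels.
pe = get_1d_sine_pe(torch.tensor([0, 1, 2]), dim=256)
print(pe.shape)  # torch.Size([3, 256])

# Two-layer MLP: 256 -> 512 -> 4 with ReLU between the layers.
mlp = MLP(input_dim=256, hidden_dim=512, output_dim=4, num_layers=2)
print(mlp(torch.randn(3, 256)).shape)  # torch.Size([3, 4])
```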
/gradio_demo/sam2/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from .download import download_weights
8 |
--------------------------------------------------------------------------------
/gradio_demo/sam2/utils/amg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import math
8 | from copy import deepcopy
9 | from itertools import product
10 | from typing import Any, Dict, Generator, ItemsView, List, Tuple
11 |
12 | import numpy as np
13 | import torch
14 |
15 | # Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py
16 |
17 |
18 | class MaskData:
19 | """
20 | A structure for storing masks and their related data in batched format.
21 | Implements basic filtering and concatenation.
22 | """
23 |
24 | def __init__(self, **kwargs) -> None:
25 | for v in kwargs.values():
26 | assert isinstance(
27 | v, (list, np.ndarray, torch.Tensor)
28 | ), "MaskData only supports list, numpy arrays, and torch tensors."
29 | self._stats = dict(**kwargs)
30 |
31 | def __setitem__(self, key: str, item: Any) -> None:
32 | assert isinstance(
33 | item, (list, np.ndarray, torch.Tensor)
34 | ), "MaskData only supports list, numpy arrays, and torch tensors."
35 | self._stats[key] = item
36 |
37 | def __delitem__(self, key: str) -> None:
38 | del self._stats[key]
39 |
40 | def __getitem__(self, key: str) -> Any:
41 | return self._stats[key]
42 |
43 | def items(self) -> ItemsView[str, Any]:
44 | return self._stats.items()
45 |
46 | def filter(self, keep: torch.Tensor) -> None:
47 | for k, v in self._stats.items():
48 | if v is None:
49 | self._stats[k] = None
50 | elif isinstance(v, torch.Tensor):
51 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
52 | elif isinstance(v, np.ndarray):
53 | self._stats[k] = v[keep.detach().cpu().numpy()]
54 | elif isinstance(v, list) and keep.dtype == torch.bool:
55 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
56 | elif isinstance(v, list):
57 | self._stats[k] = [v[i] for i in keep]
58 | else:
59 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
60 |
61 | def cat(self, new_stats: "MaskData") -> None:
62 | for k, v in new_stats.items():
63 | if k not in self._stats or self._stats[k] is None:
64 | self._stats[k] = deepcopy(v)
65 | elif isinstance(v, torch.Tensor):
66 | self._stats[k] = torch.cat([self._stats[k], v], dim=0)
67 | elif isinstance(v, np.ndarray):
68 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
69 | elif isinstance(v, list):
70 | self._stats[k] = self._stats[k] + deepcopy(v)
71 | else:
72 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
73 |
74 | def to_numpy(self) -> None:
75 | for k, v in self._stats.items():
76 | if isinstance(v, torch.Tensor):
77 | self._stats[k] = v.float().detach().cpu().numpy()
78 |
79 |
80 | def is_box_near_crop_edge(
81 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
82 | ) -> torch.Tensor:
83 | """Filter masks at the edge of a crop, but not at the edge of the original image."""
84 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
85 | orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
86 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
87 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
88 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
89 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
90 | return torch.any(near_crop_edge, dim=1)
91 |
92 |
93 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
94 | box_xywh = deepcopy(box_xyxy)
95 | box_xywh[2] = box_xywh[2] - box_xywh[0]
96 | box_xywh[3] = box_xywh[3] - box_xywh[1]
97 | return box_xywh
98 |
99 |
100 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
101 | assert len(args) > 0 and all(
102 | len(a) == len(args[0]) for a in args
103 | ), "Batched iteration must have inputs of all the same size."
104 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
105 | for b in range(n_batches):
106 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
107 |
108 |
109 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
110 | """
111 | Encodes masks to an uncompressed RLE, in the format expected by
112 | pycoco tools.
113 | """
114 | # Put in fortran order and flatten h,w
115 | b, h, w = tensor.shape
116 | tensor = tensor.permute(0, 2, 1).flatten(1)
117 |
118 | # Compute change indices
119 | diff = tensor[:, 1:] ^ tensor[:, :-1]
120 | change_indices = diff.nonzero()
121 |
122 | # Encode run length
123 | out = []
124 | for i in range(b):
125 | cur_idxs = change_indices[change_indices[:, 0] == i, 1]
126 | cur_idxs = torch.cat(
127 | [
128 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
129 | cur_idxs + 1,
130 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
131 | ]
132 | )
133 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
134 | counts = [] if tensor[i, 0] == 0 else [0]
135 | counts.extend(btw_idxs.detach().cpu().tolist())
136 | out.append({"size": [h, w], "counts": counts})
137 | return out
138 |
139 |
140 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
141 | """Compute a binary mask from an uncompressed RLE."""
142 | h, w = rle["size"]
143 | mask = np.empty(h * w, dtype=bool)
144 | idx = 0
145 | parity = False
146 | for count in rle["counts"]:
147 | mask[idx : idx + count] = parity
148 | idx += count
149 | parity ^= True
150 | mask = mask.reshape(w, h)
151 | return mask.transpose() # Put in C order
152 |
153 |
154 | def area_from_rle(rle: Dict[str, Any]) -> int:
155 | return sum(rle["counts"][1::2])
156 |
157 |
158 | def calculate_stability_score(
159 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float
160 | ) -> torch.Tensor:
161 | """
162 | Computes the stability score for a batch of masks. The stability
163 | score is the IoU between the binary masks obtained by thresholding
164 | the predicted mask logits at high and low values.
165 | """
166 | # One mask is always contained inside the other.
167 | # Save memory by preventing unnecessary cast to torch.int64
168 | intersections = (
169 | (masks > (mask_threshold + threshold_offset))
170 | .sum(-1, dtype=torch.int16)
171 | .sum(-1, dtype=torch.int32)
172 | )
173 | unions = (
174 | (masks > (mask_threshold - threshold_offset))
175 | .sum(-1, dtype=torch.int16)
176 | .sum(-1, dtype=torch.int32)
177 | )
178 | return intersections / unions
179 |
180 |
181 | def build_point_grid(n_per_side: int) -> np.ndarray:
182 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
183 | offset = 1 / (2 * n_per_side)
184 | points_one_side = np.linspace(offset, 1 - offset, n_per_side)
185 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
186 | points_y = np.tile(points_one_side[:, None], (1, n_per_side))
187 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
188 | return points
189 |
190 |
191 | def build_all_layer_point_grids(
192 | n_per_side: int, n_layers: int, scale_per_layer: int
193 | ) -> List[np.ndarray]:
194 | """Generates point grids for all crop layers."""
195 | points_by_layer = []
196 | for i in range(n_layers + 1):
197 | n_points = int(n_per_side / (scale_per_layer**i))
198 | points_by_layer.append(build_point_grid(n_points))
199 | return points_by_layer
200 |
201 |
202 | def generate_crop_boxes(
203 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
204 | ) -> Tuple[List[List[int]], List[int]]:
205 | """
206 | Generates a list of crop boxes of different sizes. Each layer
207 | has (2**i)**2 boxes for the ith layer.
208 | """
209 | crop_boxes, layer_idxs = [], []
210 | im_h, im_w = im_size
211 | short_side = min(im_h, im_w)
212 |
213 | # Original image
214 | crop_boxes.append([0, 0, im_w, im_h])
215 | layer_idxs.append(0)
216 |
217 | def crop_len(orig_len, n_crops, overlap):
218 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
219 |
220 | for i_layer in range(n_layers):
221 | n_crops_per_side = 2 ** (i_layer + 1)
222 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
223 |
224 | crop_w = crop_len(im_w, n_crops_per_side, overlap)
225 | crop_h = crop_len(im_h, n_crops_per_side, overlap)
226 |
227 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
228 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
229 |
230 | # Crops in XYWH format
231 | for x0, y0 in product(crop_box_x0, crop_box_y0):
232 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
233 | crop_boxes.append(box)
234 | layer_idxs.append(i_layer + 1)
235 |
236 | return crop_boxes, layer_idxs
237 |
238 |
239 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
240 | x0, y0, _, _ = crop_box
241 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
242 | # Check if boxes has a channel dimension
243 | if len(boxes.shape) == 3:
244 | offset = offset.unsqueeze(1)
245 | return boxes + offset
246 |
247 |
248 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
249 | x0, y0, _, _ = crop_box
250 | offset = torch.tensor([[x0, y0]], device=points.device)
251 | # Check if points has a channel dimension
252 | if len(points.shape) == 3:
253 | offset = offset.unsqueeze(1)
254 | return points + offset
255 |
256 |
257 | def uncrop_masks(
258 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
259 | ) -> torch.Tensor:
260 | x0, y0, x1, y1 = crop_box
261 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
262 | return masks
263 | # Coordinate transform masks
264 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
265 | pad = (x0, pad_x - x0, y0, pad_y - y0)
266 | return torch.nn.functional.pad(masks, pad, value=0)
267 |
268 |
269 | def remove_small_regions(
270 | mask: np.ndarray, area_thresh: float, mode: str
271 | ) -> Tuple[np.ndarray, bool]:
272 | """
273 | Removes small disconnected regions and holes in a mask. Returns the
274 | mask and an indicator of if the mask has been modified.
275 | """
276 | import cv2 # type: ignore
277 |
278 | assert mode in ["holes", "islands"]
279 | correct_holes = mode == "holes"
280 | working_mask = (correct_holes ^ mask).astype(np.uint8)
281 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
282 | sizes = stats[:, -1][1:] # Row 0 is background label
283 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
284 | if len(small_regions) == 0:
285 | return mask, False
286 | fill_labels = [0] + small_regions
287 | if not correct_holes:
288 | fill_labels = [i for i in range(n_labels) if i not in fill_labels]
289 | # If every region is below threshold, keep largest
290 | if len(fill_labels) == 0:
291 | fill_labels = [int(np.argmax(sizes)) + 1]
292 | mask = np.isin(regions, fill_labels)
293 | return mask, True
294 |
295 |
296 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
297 | from pycocotools import mask as mask_utils # type: ignore
298 |
299 | h, w = uncompressed_rle["size"]
300 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
301 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
302 | return rle
303 |
304 |
305 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
306 | """
307 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
308 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
309 | """
310 | # torch.max below raises an error on empty inputs, just skip in this case
311 | if torch.numel(masks) == 0:
312 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
313 |
314 | # Normalize shape to CxHxW
315 | shape = masks.shape
316 | h, w = shape[-2:]
317 | if len(shape) > 2:
318 | masks = masks.flatten(0, -3)
319 | else:
320 | masks = masks.unsqueeze(0)
321 |
322 | # Get top and bottom edges
323 | in_height, _ = torch.max(masks, dim=-1)
324 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
325 | bottom_edges, _ = torch.max(in_height_coords, dim=-1)
326 | in_height_coords = in_height_coords + h * (~in_height)
327 | top_edges, _ = torch.min(in_height_coords, dim=-1)
328 |
329 | # Get left and right edges
330 | in_width, _ = torch.max(masks, dim=-2)
331 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
332 | right_edges, _ = torch.max(in_width_coords, dim=-1)
333 | in_width_coords = in_width_coords + w * (~in_width)
334 | left_edges, _ = torch.min(in_width_coords, dim=-1)
335 |
336 | # If the mask is empty the right edge will be to the left of the left edge.
337 | # Replace these boxes with [0, 0, 0, 0]
338 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
339 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
340 | out = out * (~empty_filter).unsqueeze(-1)
341 |
342 | # Return to original shape
343 | if len(shape) > 2:
344 | out = out.reshape(*shape[:-2], 4)
345 | else:
346 | out = out[0]
347 |
348 | return out
349 |
--------------------------------------------------------------------------------
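The RLE helpers above round-trip losslessly; a minimal sketch (import path assumed from the tree):

```python
import torch
from sam2.utils.amg import mask_to_rle_pytorch, rle_to_mask, area_from_rle  # assumed import path

mask = torch.zeros(1, 4, 4, dtype=torch.bool)
mask[0, 1:3, 1:3] = True  # a 2x2 foreground square

rle = mask_to_rle_pytorch(mask)[0]      # uncompressed RLE in the pycocotools layout
print(rle["size"], rle["counts"])       # [4, 4] [5, 2, 2, 2, 5]
print(area_from_rle(rle))               # 4
print((rle_to_mask(rle) == mask[0].numpy()).all())  # True -- lossless round trip
```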
/gradio_demo/sam2/utils/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | import subprocess
3 | from typing import List
4 |
5 | import pytest
6 |
7 |
8 | @pytest.fixture
9 | def download_weights(output_directory: str = "artifacts") -> None:
10 | base_url: str = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/"
11 | file_names: List[str] = [
12 | "sam2_hiera_tiny.pt",
13 | "sam2_hiera_small.pt",
14 | "sam2_hiera_base_plus.pt",
15 | "sam2_hiera_large.pt",
16 | ]
17 |
18 | if not os.path.exists(output_directory):
19 | os.makedirs(output_directory)
20 |
21 | for file_name in file_names:
22 | file_path = os.path.join(output_directory, file_name)
23 | if not os.path.exists(file_path):
24 | url = f"{base_url}{file_name}"
25 | command = ["wget", url, "-P", output_directory]
26 | try:
27 | result = subprocess.run(
28 | command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
29 | )
30 | print(f"Download of {file_name} completed successfully.")
31 | print(result.stdout.decode())
32 | except subprocess.CalledProcessError as e:
33 | print(f"An error occurred during the download of {file_name}.")
34 | print(e.stderr.decode())
35 | else:
36 | print(f"{file_name} already exists. Skipping download.")
37 |
--------------------------------------------------------------------------------
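Note that `download_weights` is wrapped in `@pytest.fixture`, so it is intended to be consumed by tests rather than called directly; the Gradio demo instead fetches `facebook/sam2-hiera-large` from the Hub (see `gradio_demo/test.py`). The same checkpoints can also be fetched manually using the URL pattern from the script, for example:

```bash
wget -P artifacts https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt
```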
/gradio_demo/sam2/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import warnings
9 | from threading import Thread
10 | from typing import Dict, List, Union
11 |
12 | import numpy as np
13 | import torch
14 | from PIL import Image
15 | from torch.nn.attention import SDPBackend
16 | from einops import rearrange
17 | from tqdm import tqdm
18 | import torch.nn.functional as F
19 |
20 | VARIANTS: List[str] = ["tiny", "small", "base_plus", "large"]
21 |
22 | variant_to_config_mapping: Dict[str, str] = {
23 | "tiny": "sam2_hiera_t.yaml",
24 | "small": "sam2_hiera_s.yaml",
25 | "base_plus": "sam2_hiera_b+.yaml",
26 | "large": "sam2_hiera_l.yaml",
27 | }
28 |
29 |
30 | def get_sdp_backends(dropout_p: float) -> Union[List[SDPBackend], SDPBackend]:
31 | backends = []
32 | if torch.cuda.is_available():
33 | use_flash_attn = torch.cuda.get_device_properties(0).major >= 8
34 | pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2])
35 |
36 | if torch.cuda.get_device_properties(0).major < 7:
37 | backends.append(SDPBackend.EFFICIENT_ATTENTION)
38 |
39 | if use_flash_attn:
40 | backends.append(SDPBackend.FLASH_ATTENTION)
41 |
42 | if pytorch_version < (2, 2) or not use_flash_attn:
43 | backends.append(SDPBackend.MATH)
44 |
45 | if (
46 | SDPBackend.EFFICIENT_ATTENTION in backends and dropout_p > 0.0
47 | ) and SDPBackend.MATH not in backends:
48 | backends.append(SDPBackend.MATH)
49 |
50 | else:
51 | backends.extend([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH])
52 |
53 | return backends
54 |
55 |
56 | def get_connected_components(mask):
57 | """
58 | Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W).
59 |
60 | Inputs:
61 | - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is
62 | background.
63 |
64 | Outputs:
65 | - labels: A tensor of shape (N, 1, H, W) containing the connected component labels
66 | for foreground pixels and 0 for background pixels.
67 | - counts: A tensor of shape (N, 1, H, W) containing the area of the connected
68 | components for foreground pixels and 0 for background pixels.
69 | """
70 | from sam2 import _C
71 |
72 | return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())
73 |
74 |
75 | def mask_to_box(masks: torch.Tensor):
76 | """
77 | compute bounding box given an input mask
78 |
79 | Inputs:
80 | - masks: [B, 1, H, W] boxes, dtype=torch.Tensor
81 |
82 | Returns:
83 | - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor
84 | """
85 | B, _, h, w = masks.shape
86 | device = masks.device
87 | xs = torch.arange(w, device=device, dtype=torch.int32)
88 | ys = torch.arange(h, device=device, dtype=torch.int32)
89 | grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy")
90 | grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w)
91 | grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w)
92 | min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1)
93 | max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1)
94 | min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1)
95 | max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1)
96 | bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1)
97 |
98 | return bbox_coords
99 |
100 |
101 | def _load_img_as_tensor(img_path, image_size):
102 | img_pil = Image.open(img_path)
103 | img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size)))
104 | if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images
105 | img_np = img_np / 255.0
106 | else:
107 | raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}")
108 | img = torch.from_numpy(img_np).permute(2, 0, 1)
109 | video_width, video_height = img_pil.size # the original video size
110 | return img, video_height, video_width
111 |
112 |
113 | class AsyncVideoFrameLoader:
114 | """
115 | A list of video frames to be loaded asynchronously without blocking session start.
116 | """
117 |
118 | def __init__(self, img_paths, image_size, img_mean, img_std, device):
119 | self.img_paths = img_paths
120 | self.image_size = image_size
121 | self.img_mean = img_mean
122 | self.img_std = img_std
123 | self.device = device
124 | # items in `self.images` will be loaded asynchronously
125 | self.images = [None] * len(img_paths)
126 | # catch and raise any exceptions in the async loading thread
127 | self.exception = None
128 | # video_height and video_width will be filled when loading the first image
129 | self.video_height = None
130 | self.video_width = None
131 |
132 | # load the first frame to fill video_height and video_width and also
133 | # to cache it (since it's most likely where the user will click)
134 | self.__getitem__(0)
135 |
136 | # load the rest of frames asynchronously without blocking the session start
137 | def _load_frames():
138 | try:
139 | for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"):
140 | self.__getitem__(n)
141 | except Exception as e:
142 | self.exception = e
143 |
144 | self.thread = Thread(target=_load_frames, daemon=True)
145 | self.thread.start()
146 |
147 | def __getitem__(self, index):
148 | if self.exception is not None:
149 | raise RuntimeError("Failure in frame loading thread") from self.exception
150 |
151 | img = self.images[index]
152 | if img is not None:
153 | return img
154 |
155 | img, video_height, video_width = _load_img_as_tensor(
156 | self.img_paths[index], self.image_size
157 | )
158 | self.video_height = video_height
159 | self.video_width = video_width
160 | # normalize by mean and std
161 | img -= self.img_mean
162 | img /= self.img_std
163 | img = img.to(self.device)
164 | self.images[index] = img
165 | return img
166 |
167 | def __len__(self):
168 | return len(self.images)
169 |
170 |
171 | def load_video_frames(
172 | video_path,
173 | image_size,
174 | images=None,
175 | img_mean=(0.485, 0.456, 0.406),
176 | img_std=(0.229, 0.224, 0.225),
177 | async_loading_frames=False,
178 | device="cpu",
179 | ):
180 | """
181 | Load the video frames from a directory of JPEG files (".jpg" format), or from an in-memory array passed via `images`.
182 |
183 | You can load frames asynchronously by setting `async_loading_frames` to `True`.
184 | """
185 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None]
186 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None]
187 |
188 | if images is not None:
189 | images = torch.from_numpy(images).float()
190 | images = rearrange(images, "f h w c -> f c h w")
191 | images = F.interpolate(images, (image_size, image_size), mode="bilinear")
192 | video_height, video_width = images.shape[2:]
193 | else:
194 | if isinstance(video_path, str) and os.path.isdir(video_path):
195 | jpg_folder = video_path
196 | else:
197 | raise NotImplementedError("Only JPEG frames are supported at this moment")
198 |
199 | frame_names = [
200 | p
201 | for p in os.listdir(jpg_folder)
202 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"]
203 | ]
204 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0]))
205 | num_frames = len(frame_names)
206 | if num_frames == 0:
207 | raise RuntimeError(f"no images found in {jpg_folder}")
208 | img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names]
209 |
210 | images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32)
211 | for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")):
212 | images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size)
213 | images = images.to(device)
214 | img_mean = img_mean.to(device)
215 | img_std = img_std.to(device)
216 | # normalize by mean and std
217 | images -= img_mean
218 | images /= img_std
219 | return images, video_height, video_width
220 |
221 |
222 | def fill_holes_in_mask_scores(mask, max_area):
223 | """
224 | A post processor to fill small holes in mask scores with area under `max_area`.
225 | """
226 | # Holes are those connected components in background with area <= self.max_area
227 | # (background regions are those with mask scores <= 0)
228 | assert max_area > 0, "max_area must be positive"
229 | labels, areas = get_connected_components(mask <= 0)
230 | is_hole = (labels > 0) & (areas <= max_area)
231 | # We fill holes with a small positive mask score (0.1) to change them to foreground.
232 | mask = torch.where(is_hole, 0.1, mask)
233 | return mask
234 |
235 |
236 | def concat_points(old_point_inputs, new_points, new_labels):
237 | """Add new points and labels to previous point inputs (add at the end)."""
238 | if old_point_inputs is None:
239 | points, labels = new_points, new_labels
240 | else:
241 | points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1)
242 | labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1)
243 |
244 | return {"point_coords": points, "point_labels": labels}
245 |
--------------------------------------------------------------------------------
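A small sketch of the two helpers most relevant to the demo: `load_video_frames` with an in-memory `(f, h, w, c)` array, and `mask_to_box`. The import path and input values are illustrative, not part of the repository.

```python
import numpy as np
import torch
from sam2.utils.misc import load_video_frames, mask_to_box  # assumed import path

# In-memory frames bypass the JPEG-folder branch; they are resized to image_size
# and normalized with the ImageNet mean/std.
frames = np.random.rand(4, 480, 832, 3).astype(np.float32)
images, h, w = load_video_frames(video_path=None, image_size=512, images=frames)
print(images.shape, (h, w))  # torch.Size([4, 3, 512, 512]) (512, 512)

# Bounding box of a binary mask, one box per object.
mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool)
mask[0, 0, 2:5, 3:7] = True
print(mask_to_box(mask))  # (x0, y0, x1, y1) = (3, 2, 6, 4)
```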
/gradio_demo/sam2/utils/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | from torchvision.transforms import Normalize, Resize, ToTensor
11 |
12 |
13 | class SAM2Transforms(nn.Module):
14 | def __init__(
15 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0
16 | ):
17 | """
18 | Transforms for SAM2.
19 | """
20 | super().__init__()
21 | self.resolution = resolution
22 | self.mask_threshold = mask_threshold
23 | self.max_hole_area = max_hole_area
24 | self.max_sprinkle_area = max_sprinkle_area
25 | self.mean = [0.485, 0.456, 0.406]
26 | self.std = [0.229, 0.224, 0.225]
27 | self.to_tensor = ToTensor()
28 | self.transforms = torch.jit.script(
29 | nn.Sequential(
30 | Resize((self.resolution, self.resolution)),
31 | Normalize(self.mean, self.std),
32 | )
33 | )
34 |
35 | def __call__(self, x):
36 | x = self.to_tensor(x)
37 | return self.transforms(x)
38 |
39 | def forward_batch(self, img_list):
40 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list]
41 | img_batch = torch.stack(img_batch, dim=0)
42 | return img_batch
43 |
44 | def transform_coords(
45 | self, coords: torch.Tensor, normalize=False, orig_hw=None
46 | ) -> torch.Tensor:
47 | """
48 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates;
49 | if the coords are in absolute image coordinates, `normalize` should be set to True and the original image size is required.
50 |
51 | Returns
52 | Coordinates scaled to the model input resolution ([0, self.resolution]), which is what the SAM2 model expects.
53 | """
54 | if normalize:
55 | assert orig_hw is not None
56 | h, w = orig_hw
57 | coords = coords.clone()
58 | coords[..., 0] = coords[..., 0] / w
59 | coords[..., 1] = coords[..., 1] / h
60 |
61 | coords = coords * self.resolution # unnormalize coords
62 | return coords
63 |
64 | def transform_boxes(
65 | self, boxes: torch.Tensor, normalize=False, orig_hw=None
66 | ) -> torch.Tensor:
67 | """
68 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates;
69 | if the coords are in absolute image coordinates, `normalize` should be set to True and the original image size is required.
70 | """
71 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw)
72 | return boxes
73 |
74 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor:
75 | """
76 | Perform PostProcessing on output masks.
77 | """
78 | from sam2.utils.misc import get_connected_components
79 |
80 | masks = masks.float()
81 | if self.max_hole_area > 0:
82 | # Holes are those connected components in background with area <= self.max_hole_area
83 | # (background regions are those with mask scores <= self.mask_threshold)
84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image
85 | labels, areas = get_connected_components(mask_flat <= self.mask_threshold)
86 | is_hole = (labels > 0) & (areas <= self.max_hole_area)
87 | is_hole = is_hole.reshape_as(masks)
88 | # We fill holes with a small positive mask score (10.0) to change them to foreground.
89 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks)
90 |
91 | if self.max_sprinkle_area > 0:
92 | labels, areas = get_connected_components(mask_flat > self.mask_threshold)
93 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area)
94 | is_hole = is_hole.reshape_as(masks)
95 | # We remove small sprinkles with a negative mask score (-10.0) to change them to background.
96 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks)
97 |
98 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False)
99 | return masks
100 |
--------------------------------------------------------------------------------
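A short sketch of the coordinate transform used when feeding clicks to SAM2 (import path assumed; the resolution matches the `model.image_size = 1024` used by the demo):

```python
import torch
from sam2.utils.transforms import SAM2Transforms  # assumed import path

t = SAM2Transforms(resolution=1024, mask_threshold=0.0)

# A click at pixel (x=200, y=100) on a 480x832 (h, w) frame, mapped to model space.
click = torch.tensor([[200.0, 100.0]])
print(t.transform_coords(click, normalize=True, orig_hw=(480, 832)))
# tensor([[246.1538, 213.3333]])  i.e. (200/832, 100/480) * 1024
```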
/gradio_demo/sam2/utils/visualization.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import numpy as np
4 | from PIL import Image
5 |
6 |
7 | def show_masks(
8 | image: np.ndarray,
9 | masks: np.ndarray,
10 | scores: Optional[np.ndarray],
11 | alpha: Optional[float] = 0.5,
12 | display_image: Optional[bool] = False,
13 | only_best: Optional[bool] = True,
14 | autogenerated_mask: Optional[bool] = False,
15 | ) -> Image.Image:
16 | if scores is not None:
17 | # sort masks by their scores
18 | sorted_ind = np.argsort(scores)[::-1]
19 | masks = masks[sorted_ind]
20 |
21 | if autogenerated_mask:
22 | masks = sorted(masks, key=(lambda x: x["area"]), reverse=True)
23 | else:
24 | # get mask dimensions
25 | h, w = masks.shape[-2:]
26 |
27 | if display_image:
28 | output_image = Image.fromarray(image)
29 | else:
30 | # create a new blank image to superimpose masks
31 | if autogenerated_mask:
32 | output_image = Image.new(
33 | mode="RGBA",
34 | size=(
35 | masks[0]["segmentation"].shape[1],
36 | masks[0]["segmentation"].shape[0],
37 | ),
38 | color=(0, 0, 0),
39 | )
40 | else:
41 | output_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0))
42 |
43 | for i, mask in enumerate(masks):
44 | if not autogenerated_mask:
45 | if mask.ndim > 2: # type: ignore
46 | mask = mask.squeeze() # type: ignore
47 | else:
48 | mask = mask["segmentation"]
49 | # Generate a random color with specified alpha value
50 | color = np.concatenate(
51 | (np.random.randint(0, 256, size=3), [int(alpha * 255)]), axis=0
52 | )
53 |
54 | # Create an RGBA image for the mask
55 | mask_image = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
56 | mask_colored = Image.new("RGBA", mask_image.size, tuple(color))
57 | mask_image = Image.composite(
58 | mask_colored, Image.new("RGBA", mask_image.size), mask_image
59 | )
60 |
61 | # Overlay mask on the output image
62 | output_image = Image.alpha_composite(output_image, mask_image)
63 |
64 | # Exit if specified to only display the best mask
65 | if only_best:
66 | break
67 |
68 | return output_image
69 |
--------------------------------------------------------------------------------
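A minimal usage sketch for `show_masks` (import path assumed, values illustrative); it returns an RGBA PIL image with the best-scoring mask drawn in a random translucent color:

```python
import numpy as np
from sam2.utils.visualization import show_masks  # assumed import path

frame = np.zeros((64, 64, 3), dtype=np.uint8)
masks = np.zeros((1, 64, 64), dtype=np.float32)
masks[0, 16:48, 16:48] = 1.0  # one square mask

overlay = show_masks(frame, masks, scores=np.array([1.0]))
overlay.save("mask_overlay.png")
```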
/gradio_demo/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gradio as gr
3 | import cv2
4 | import numpy as np
5 | from PIL import Image
6 |
7 | os.makedirs("./SAM2-Video-Predictor/checkpoints/", exist_ok=True)
8 | os.makedirs("./model/", exist_ok=True)
9 |
10 | from huggingface_hub import snapshot_download
11 |
12 | def download_sam2():
13 | snapshot_download(repo_id="facebook/sam2-hiera-large", local_dir="./SAM2-Video-Predictor/checkpoints/")
14 | print("Download sam2 completed")
15 |
16 | def download_remover():
17 | snapshot_download(repo_id="zibojia/minimax-remover", local_dir="./model/")
18 | print("Download minimax remover completed")
19 |
20 | download_sam2()
21 | download_remover()
22 |
23 | import torch
24 | import argparse
25 | import random
26 |
27 | import torch.nn.functional as F
28 | import time
29 | import random
30 | from omegaconf import OmegaConf
31 | from einops import rearrange
32 | from diffusers.models import AutoencoderKLWan
33 | import scipy
34 | from transformer_minimax_remover import Transformer3DModel
35 | from einops import rearrange
36 | from diffusers.schedulers import UniPCMultistepScheduler
37 | from pipeline_minimax_remover import Minimax_Remover_Pipeline
38 |
39 | from diffusers.utils import export_to_video
40 | from decord import VideoReader, cpu
41 | from moviepy.editor import ImageSequenceClip
42 |
43 | from sam2 import load_model
44 |
45 | from sam2.build_sam import build_sam2, build_sam2_video_predictor
46 | from sam2.sam2_image_predictor import SAM2ImagePredictor
47 |
48 | COLOR_PALETTE = [
49 | (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255),
50 | (0, 255, 255), (255, 128, 0), (128, 0, 255), (0, 128, 255), (128, 255, 0)
51 | ]
52 |
53 | random_seed = 42
54 | video_length = 201
55 | W = 1024
56 | H = W
57 | device = "cuda" if torch.cuda.is_available() else "cpu"
58 |
59 | def get_pipe_image_and_video_predictor():
60 | vae = AutoencoderKLWan.from_pretrained("./model/vae", torch_dtype=torch.float16)
61 | transformer = Transformer3DModel.from_pretrained("./model/transformer", torch_dtype=torch.float16)
62 | scheduler = UniPCMultistepScheduler.from_pretrained("./model/scheduler")
63 |
64 | pipe = Minimax_Remover_Pipeline(transformer=transformer, vae=vae, scheduler=scheduler)
65 | pipe.to(device)
66 |
67 | sam2_checkpoint = "./SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt"
68 | config = "sam2_hiera_l.yaml"
69 |
70 | video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device=device)
71 | model = build_sam2(config, sam2_checkpoint, device=device)
72 | model.image_size = 1024
73 | image_predictor = SAM2ImagePredictor(sam_model=model)
74 |
75 | return pipe, image_predictor, video_predictor
76 |
77 | def get_video_info(video_path, video_state):
78 | video_state["input_points"] = []
79 | video_state["scaled_points"] = []
80 | video_state["input_labels"] = []
81 | video_state["frame_idx"] = 0
82 | vr = VideoReader(video_path, ctx=cpu(0))
83 | first_frame = vr[0].asnumpy()
84 | del vr
85 |
86 | if first_frame.shape[0] > first_frame.shape[1]:
87 | W_ = W
88 | H_ = int(W_ * first_frame.shape[0] / first_frame.shape[1])
89 | else:
90 | H_ = H
91 | W_ = int(H_ * first_frame.shape[1] / first_frame.shape[0])
92 |
93 | first_frame = cv2.resize(first_frame, (W_, H_))
94 | video_state["origin_images"] = np.expand_dims(first_frame, axis=0)
95 | video_state["inference_state"] = None
96 | video_state["video_path"] = video_path
97 | video_state["masks"] = None
98 | video_state["painted_images"] = None
99 | image = Image.fromarray(first_frame)
100 | return image
101 |
102 | def segment_frame(evt: gr.SelectData, label, video_state):
103 | if video_state["origin_images"] is None:
104 | gr.Warning("Please click \"Extract First Frame\" to extract the first frame, then click on it to add annotation points")
105 | return None
106 | x, y = evt.index
107 | new_point = [x, y]
108 | label_value = 1 if label == "Positive" else 0
109 |
110 | video_state["input_points"].append(new_point)
111 | video_state["input_labels"].append(label_value)
112 | height, width = video_state["origin_images"][0].shape[0:2]
113 | scaled_points = []
114 | for pt in video_state["input_points"]:
115 | sx = pt[0] / width
116 | sy = pt[1] / height
117 | scaled_points.append([sx, sy])
118 |
119 | video_state["scaled_points"] = scaled_points
120 |
121 | image_predictor.set_image(video_state["origin_images"][0])
122 | mask, _, _ = image_predictor.predict(
123 | point_coords=video_state["scaled_points"],
124 | point_labels=video_state["input_labels"],
125 | multimask_output=False,
126 | normalize_coords=False,
127 | )
128 |
129 | mask = np.squeeze(mask)
130 | mask = cv2.resize(mask, (width, height))
131 | mask = mask[:,:,None]
132 |
133 | color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
134 | color = color[None, None, :]
135 | org_image = video_state["origin_images"][0].astype(np.float32) / 255.0
136 | painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color
137 | painted_image = np.uint8(np.clip(painted_image * 255, 0, 255))
138 | video_state["painted_images"] = np.expand_dims(painted_image, axis=0)
139 | video_state["masks"] = np.expand_dims(mask[:,:,0], axis=0)
140 |
141 | for i in range(len(video_state["input_points"])):
142 | point = video_state["input_points"][i]
143 | if video_state["input_labels"][i] == 0:
144 | cv2.circle(painted_image, point, radius=3, color=(0, 0, 255), thickness=-1) # negative click: blue dot (RGB frame), radius 3
145 | else:
146 | cv2.circle(painted_image, point, radius=3, color=(255, 0, 0), thickness=-1) # positive click: red dot, radius 3
147 |
148 | return Image.fromarray(painted_image)
149 |
150 | def clear_clicks(video_state):
151 | video_state["input_points"] = []
152 | video_state["input_labels"] = []
153 | video_state["scaled_points"] = []
154 | video_state["inference_state"] = None
155 | video_state["masks"] = None
156 | video_state["painted_images"] = None
157 | return Image.fromarray(video_state["origin_images"][0]) if video_state["origin_images"] is not None else None
158 |
159 |
160 | def preprocess_for_removal(images, masks):
161 | out_images = []
162 | out_masks = []
163 | for img, msk in zip(images, masks):
164 | if img.shape[0] > img.shape[1]:
165 | img_resized = cv2.resize(img, (480, 832), interpolation=cv2.INTER_LINEAR)
166 | else:
167 | img_resized = cv2.resize(img, (832, 480), interpolation=cv2.INTER_LINEAR)
168 | img_resized = img_resized.astype(np.float32) / 127.5 - 1.0 # [-1, 1]
169 | out_images.append(img_resized)
170 | if msk.shape[0] > msk.shape[1]:
171 | msk_resized = cv2.resize(msk, (480, 832), interpolation=cv2.INTER_NEAREST)
172 | else:
173 | msk_resized = cv2.resize(msk, (832, 480), interpolation=cv2.INTER_NEAREST)
174 | msk_resized = msk_resized.astype(np.float32)
175 | msk_resized = (msk_resized > 0.5).astype(np.float32)
176 | out_masks.append(msk_resized)
177 | arr_images = np.stack(out_images)
178 | arr_masks = np.stack(out_masks)
179 | return torch.from_numpy(arr_images).half().to(device), torch.from_numpy(arr_masks).half().to(device)
180 |
181 |
182 | def inference_and_return_video(dilation_iterations, num_inference_steps, video_state=None):
183 | if video_state["origin_images"] is None or video_state["masks"] is None:
184 | return None
185 | images = video_state["origin_images"]
186 | masks = video_state["masks"]
187 |
188 | images = np.array(images)
189 | masks = np.array(masks)
190 | img_tensor, mask_tensor = preprocess_for_removal(images, masks)
191 | mask_tensor = mask_tensor[:,:,:,:1]
192 |
193 | if mask_tensor.shape[1] < mask_tensor.shape[2]:
194 | height = 480
195 | width = 832
196 | else:
197 | height = 832
198 | width = 480
199 |
200 | with torch.no_grad():
201 | out = pipe(
202 | images=img_tensor,
203 | masks=mask_tensor,
204 | num_frames=mask_tensor.shape[0],
205 | height=height,
206 | width=width,
207 | num_inference_steps=int(num_inference_steps),
208 | generator=torch.Generator(device=device).manual_seed(random_seed),
209 | iterations=int(dilation_iterations)
210 | ).frames[0]
211 |
212 | out = np.uint8(out * 255)
213 | output_frames = [img for img in out]
214 |
215 | video_file = f"/tmp/{time.time()}-{random.random()}-removed_output.mp4"
216 | clip = ImageSequenceClip(output_frames, fps=15)
217 | clip.write_videofile(video_file, codec='libx264', audio=False, verbose=False, logger=None)
218 | return video_file
219 |
220 |
221 | def track_video(n_frames, video_state):
222 | if video_state["origin_images"] is None or video_state["masks"] is None:
223 | gr.Warning("Please complete target segmentation on the first frame first, then click Tracking")
224 | return None
225 |
226 | input_points = video_state["input_points"]
227 | input_labels = video_state["input_labels"]
228 | frame_idx = video_state["frame_idx"]
229 | obj_id = video_state["obj_id"]
230 | scaled_points = video_state["scaled_points"]
231 |
232 | vr = VideoReader(video_state["video_path"], ctx=cpu(0))
233 | height, width = vr[0].shape[0:2]
234 | images = [vr[i].asnumpy() for i in range(min(len(vr), n_frames))]
235 | del vr
236 |
237 | if images[0].shape[0] > images[0].shape[1]:
238 | W_ = W
239 | H_ = int(W_ * images[0].shape[0] / images[0].shape[1])
240 | else:
241 | H_ = H
242 | W_ = int(H_ * images[0].shape[1] / images[0].shape[0])
243 |
244 | images = [cv2.resize(img, (W_, H_)) for img in images]
245 | video_state["origin_images"] = images
246 | images = np.array(images)
247 | inference_state = video_predictor.init_state(images=images/255, device=device)
248 | video_state["inference_state"] = inference_state
249 |
250 | if len(torch.from_numpy(video_state["masks"][0]).shape) == 3:
251 | mask = torch.from_numpy(video_state["masks"][0])[:,:,0]
252 | else:
253 | mask = torch.from_numpy(video_state["masks"][0])
254 |
255 | video_predictor.add_new_mask(
256 | inference_state=inference_state,
257 | frame_idx=0,
258 | obj_id=obj_id,
259 | mask=mask
260 | )
261 |
262 | output_frames = []
263 | mask_frames = []
264 | color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0
265 | color = color[None, None, :]
266 | for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state):
267 | frame = images[out_frame_idx].astype(np.float32) / 255.0
268 | mask = np.zeros((H, W, 3), dtype=np.float32)
269 | for i, logit in enumerate(out_mask_logits):
270 | out_mask = logit.cpu().squeeze().detach().numpy()
271 | out_mask = (out_mask[:,:,None] > 0).astype(np.float32)
272 | mask += out_mask
273 | mask = np.clip(mask, 0, 1)
274 | mask = cv2.resize(mask, (W_, H_))
275 | mask_frames.append(mask)
276 | painted = (1 - mask * 0.5) * frame + mask * 0.5 * color
277 | painted = np.uint8(np.clip(painted * 255, 0, 255))
278 | output_frames.append(painted)
279 | video_state["masks"] = mask_frames
280 | video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4"
281 | clip = ImageSequenceClip(output_frames, fps=15)
282 | clip.write_videofile(video_file, codec='libx264', audio=False, verbose=False, logger=None)
283 | return video_file
284 |
285 | text = """
286 |
287 | Minimax-Remover: Taming Bad Noise Helps Video Object Removal
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 | Bojia Zi*, Weixuan Peng*, Xianbiao Qi†, Jianan Wang, Shihao Zhao, Rong Xiao, Kam-Fai Wong
299 |
300 |
301 | * Equal contribution † Corresponding author
302 |
303 | """
304 |
305 | pipe, image_predictor, video_predictor = get_pipe_image_and_video_predictor()
306 |
307 | with gr.Blocks() as demo:
308 | video_state = gr.State({
309 | "origin_images": None,
310 | "inference_state": None,
311 | "masks": None, # Store user-generated masks
312 | "painted_images": None,
313 | "video_path": None,
314 | "input_points": [],
315 | "scaled_points": [],
316 | "input_labels": [],
317 | "frame_idx": 0,
318 | "obj_id": 1
319 | })
320 | gr.Markdown(f"{text}")
321 |
322 | with gr.Column():
323 | video_input = gr.Video(label="Upload Video", elem_id="my-video1")
324 | get_info_btn = gr.Button("Extract First Frame", elem_id="my-btn")
325 |
326 | gr.Examples(
327 | examples=[
328 | ["./cartoon/0.mp4"],
329 | ["./cartoon/1.mp4"],
330 | ["./cartoon/2.mp4"],
331 | ["./cartoon/3.mp4"],
332 | ["./cartoon/4.mp4"],
333 | ["./normal_videos/0.mp4"],
334 | ["./normal_videos/1.mp4"],
335 | ["./normal_videos/3.mp4"],
336 | ["./normal_videos/4.mp4"],
337 | ["./normal_videos/5.mp4"],
338 | ],
339 | inputs=[video_input],
340 | label="Choose a video to remove.",
341 | elem_id="my-btn2"
342 | )
343 |
344 | image_output = gr.Image(label="First Frame Segmentation", interactive=True, elem_id="my-video")#, height="35%", width="60%")
345 | demo.css = """
346 | #my-btn {
347 | width: 60% !important;
348 | margin: 0 auto;
349 | }
350 |
351 | #my-video1 {
352 | width: 60% !important;
353 | height: 35% !important;
354 | margin: 0 auto;
355 | }
356 | #my-video {
357 | width: 60% !important;
358 | height: 35% !important;
359 | margin: 0 auto;
360 | }
361 | #my-md {
362 | margin: 0 auto;
363 | }
364 | #my-btn2 {
365 | width: 60% !important;
366 | margin: 0 auto;
367 | }
368 | #my-btn2 button {
369 | width: 120px !important;
370 | max-width: 120px !important;
371 | min-width: 120px !important;
372 | height: 70px !important;
373 | max-height: 70px !important;
374 | min-height: 70px !important;
375 | margin: 8px !important;
376 | border-radius: 8px !important;
377 | overflow: hidden !important;
378 | white-space: normal !important;
379 | }
380 | """
381 | with gr.Row(elem_id="my-btn"):
382 | point_prompt = gr.Radio(["Positive", "Negative"], label="Click Type", value="Positive")
383 | clear_btn = gr.Button("Clear All Clicks")
384 |
385 | with gr.Row(elem_id="my-btn"):
386 | n_frames_slider = gr.Slider(minimum=1, maximum=201, value=81, step=1, label="Tracking Frames N")
387 | track_btn = gr.Button("Tracking")
388 | video_output = gr.Video(label="Tracking Result", elem_id="my-video")
389 |
390 | with gr.Column(elem_id="my-btn"):
391 | dilation_slider = gr.Slider(minimum=1, maximum=20, value=6, step=1, label="Mask Dilation")
392 | inference_steps_slider = gr.Slider(minimum=1, maximum=100, value=6, step=1, label="Num Inference Steps")
393 |
394 | remove_btn = gr.Button("Remove", elem_id="my-btn")
395 | remove_video = gr.Video(label="Remove Results", elem_id="my-video")
396 | remove_btn.click(
397 | inference_and_return_video,
398 | inputs=[dilation_slider, inference_steps_slider, video_state],
399 | outputs=remove_video
400 | )
401 | get_info_btn.click(get_video_info, inputs=[video_input, video_state], \
402 | outputs=image_output)
403 | image_output.select(fn=segment_frame, inputs=[point_prompt, video_state], outputs=image_output)
404 | clear_btn.click(clear_clicks, inputs=video_state, outputs=image_output)
405 | track_btn.click(track_video, inputs=[n_frames_slider, video_state], outputs=video_output)
406 |
407 | demo.launch(server_name="0.0.0.0", server_port=8000)
408 |
--------------------------------------------------------------------------------
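For reference, `preprocess_for_removal` above establishes the tensor contract handed to the remover pipeline: landscape frames are resized to 832x480 (w, h), scaled to [-1, 1], and masks are nearest-resized and binarized. A stand-alone sanity check of that contract (illustrative values, no demo state required):

```python
import cv2
import numpy as np

frame = np.random.randint(0, 255, (720, 1280, 3), dtype=np.uint8)  # landscape source frame
mask = (np.random.rand(720, 1280) > 0.9).astype(np.float32)

img = cv2.resize(frame, (832, 480)).astype(np.float32) / 127.5 - 1.0  # scaled to [-1, 1]
msk = (cv2.resize(mask, (832, 480), interpolation=cv2.INTER_NEAREST) > 0.5).astype(np.float32)

print(img.shape, float(img.min()) >= -1.0, float(img.max()) <= 1.0)  # (480, 832, 3) True True
print(msk.shape, np.unique(msk))                                     # (480, 832) [0. 1.]
```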
/gradio_demo/transformer_minimax_remover.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Optional, Tuple, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config
9 | from diffusers.utils import logging
10 | from diffusers.models.attention import FeedForward
11 | from diffusers.models.attention_processor import Attention
12 | from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
13 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
14 | from diffusers.models.modeling_utils import ModelMixin
15 | from diffusers.models.normalization import FP32LayerNorm
16 |
17 | class AttnProcessor2_0:
18 | def __init__(self):
19 | if not hasattr(F, "scaled_dot_product_attention"):
20 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
21 |
22 | def __call__(
23 | self,
24 | attn: Attention,
25 | hidden_states: torch.Tensor,
26 | rotary_emb: Optional[torch.Tensor] = None,
27 | attention_mask: Optional[torch.Tensor] = None,
28 | encoder_hidden_states: Optional[torch.Tensor] = None
29 | ) -> torch.Tensor:
30 |
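# Self-attention only: any encoder_hidden_states passed in is ignored and q, k, v are all projected from hidden_states.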
31 | encoder_hidden_states = hidden_states
32 | query = attn.to_q(hidden_states)
33 | key = attn.to_k(encoder_hidden_states)
34 | value = attn.to_v(encoder_hidden_states)
35 |
36 | if attn.norm_q is not None:
37 | query = attn.norm_q(query)
38 | if attn.norm_k is not None:
39 | key = attn.norm_k(key)
40 |
41 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
42 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
43 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
44 |
45 | if rotary_emb is not None:
46 |
47 | def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
48 | x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
49 | x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
50 | return x_out.type_as(hidden_states)
51 |
52 | query = apply_rotary_emb(query, rotary_emb)
53 | key = apply_rotary_emb(key, rotary_emb)
54 |
55 | hidden_states = F.scaled_dot_product_attention(
56 | query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
57 | )
58 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
59 | hidden_states = hidden_states.type_as(query)
60 |
61 | hidden_states = attn.to_out[0](hidden_states)
62 | hidden_states = attn.to_out[1](hidden_states)
63 | return hidden_states
64 |
65 | class TimeEmbedding(nn.Module):
66 | def __init__(
67 | self,
68 | dim: int,
69 | time_freq_dim: int,
70 | time_proj_dim: int
71 | ):
72 | super().__init__()
73 |
74 | self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
75 | self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
76 |
77 | self.act_fn = nn.SiLU()
78 | self.time_proj = nn.Linear(dim, time_proj_dim)
79 |
80 | def forward(
81 | self,
82 | timestep: torch.Tensor,
83 | ):
84 | timestep = self.timesteps_proj(timestep)
85 |
86 | time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
87 | if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
88 | timestep = timestep.to(time_embedder_dtype)
89 | temb = self.time_embedder(timestep).type_as(self.time_proj.weight.data)
90 | timestep_proj = self.time_proj(self.act_fn(temb))
91 |
92 | return temb, timestep_proj
93 |
94 |
95 | class RotaryPosEmbed(nn.Module):
96 | def __init__(
97 | self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
98 | ):
99 | super().__init__()
100 |
101 | self.attention_head_dim = attention_head_dim
102 | self.patch_size = patch_size
103 | self.max_seq_len = max_seq_len
104 |
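# Split the per-head channels across the (t, h, w) axes for 3D rotary embeddings: height and width each take 2 * (dim // 6) channels and the temporal axis takes the remainder.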
105 | h_dim = w_dim = 2 * (attention_head_dim // 6)
106 | t_dim = attention_head_dim - h_dim - w_dim
107 |
108 | freqs = []
109 | for dim in [t_dim, h_dim, w_dim]:
110 | freq = get_1d_rotary_pos_embed(
111 | dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
112 | )
113 | freqs.append(freq)
114 | self.freqs = torch.cat(freqs, dim=1)
115 |
116 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
117 | batch_size, num_channels, num_frames, height, width = hidden_states.shape
118 | p_t, p_h, p_w = self.patch_size
119 | ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
120 |
121 | self.freqs = self.freqs.to(hidden_states.device)
122 | freqs = self.freqs.split_with_sizes(
123 | [
124 | self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
125 | self.attention_head_dim // 6,
126 | self.attention_head_dim // 6,
127 | ],
128 | dim=1,
129 | )
130 |
131 | freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
132 | freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
133 | freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
134 | freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
135 | return freqs
136 |
137 |
138 | class TransformerBlock(nn.Module):
139 | def __init__(
140 | self,
141 | dim: int,
142 | ffn_dim: int,
143 | num_heads: int,
144 | qk_norm: str = "rms_norm_across_heads",
145 | cross_attn_norm: bool = False,
146 | eps: float = 1e-6,
147 | added_kv_proj_dim: Optional[int] = None,
148 | ):
149 | super().__init__()
150 |
151 | self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
152 | self.attn1 = Attention(
153 | query_dim=dim,
154 | heads=num_heads,
155 | kv_heads=num_heads,
156 | dim_head=dim // num_heads,
157 | qk_norm=qk_norm,
158 | eps=eps,
159 | bias=True,
160 | cross_attention_dim=None,
161 | out_bias=True,
162 | processor=AttnProcessor2_0(),
163 | )
164 |
165 | self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
166 | self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=False)
167 |
168 | self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
169 |
170 | def forward(
171 | self,
172 | hidden_states: torch.Tensor,
173 | temb: torch.Tensor,
174 | rotary_emb: torch.Tensor,
175 | ) -> torch.Tensor:
176 | shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
177 | self.scale_shift_table + temb.float()
178 | ).chunk(6, dim=1)
179 |
180 | norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
181 | attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
182 | hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
183 |
184 | norm_hidden_states = (self.norm2(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
185 | hidden_states
186 | )
187 | ff_output = self.ffn(norm_hidden_states)
188 | hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
189 |
190 | return hidden_states
191 |
192 |
193 | class Transformer3DModel(ModelMixin, ConfigMixin):
194 |
195 | _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
196 | _no_split_modules = ["TransformerBlock"]
197 | _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2"]
198 |
199 | @register_to_config
200 | def __init__(
201 | self,
202 | patch_size: Tuple[int] = (1, 2, 2),
203 | num_attention_heads: int = 40,
204 | attention_head_dim: int = 128,
205 | in_channels: int = 16,
206 | out_channels: int = 16,
207 | freq_dim: int = 256,
208 | ffn_dim: int = 13824,
209 | num_layers: int = 40,
210 | cross_attn_norm: bool = True,
211 | qk_norm: Optional[str] = "rms_norm_across_heads",
212 | eps: float = 1e-6,
213 | added_kv_proj_dim: Optional[int] = None,
214 | rope_max_seq_len: int = 1024
215 | ) -> None:
216 | super().__init__()
217 |
218 | inner_dim = num_attention_heads * attention_head_dim
219 | out_channels = out_channels or in_channels
220 |
221 | # 1. Patch & position embedding
222 | self.rope = RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
223 | self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
224 |
225 | # 2. Condition embeddings
226 | self.condition_embedder = TimeEmbedding(
227 | dim=inner_dim,
228 | time_freq_dim=freq_dim,
229 | time_proj_dim=inner_dim * 6,
230 | )
231 |
232 | # 3. Transformer blocks
233 | self.blocks = nn.ModuleList(
234 | [
235 | TransformerBlock(
236 | inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
237 | )
238 | for _ in range(num_layers)
239 | ]
240 | )
241 |
242 | # 4. Output norm & projection
243 | self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
244 | self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
245 | self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
246 |
247 | def forward(
248 | self,
249 | hidden_states: torch.Tensor,
250 | timestep: torch.LongTensor
251 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
252 | batch_size, num_channels, num_frames, height, width = hidden_states.shape
253 | p_t, p_h, p_w = self.config.patch_size
254 | post_patch_num_frames = num_frames // p_t
255 | post_patch_height = height // p_h
256 | post_patch_width = width // p_w
257 |
258 | rotary_emb = self.rope(hidden_states)
259 |
260 | hidden_states = self.patch_embedding(hidden_states)
261 | hidden_states = hidden_states.flatten(2).transpose(1, 2)
262 |
263 | temb, timestep_proj = self.condition_embedder(
264 | timestep
265 | )
266 | timestep_proj = timestep_proj.unflatten(1, (6, -1))
267 |
268 | for block in self.blocks:
269 | hidden_states = block(hidden_states, timestep_proj, rotary_emb)
270 |
271 | shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
272 | hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
273 | hidden_states = self.proj_out(hidden_states)
274 |
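# Un-patchify: fold the per-patch predictions back into a (batch, channels, frames, height, width) latent video.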
275 | hidden_states = hidden_states.reshape(
276 | batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
277 | )
278 | hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
279 | output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
280 |
281 | return Transformer2DModelOutput(sample=output)
282 |
--------------------------------------------------------------------------------
/imgs/gradio_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/imgs/gradio_demo.gif
--------------------------------------------------------------------------------
/pipeline_minimax_remover.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Optional, Union
2 |
3 | import torch
4 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
5 | from diffusers.models import AutoencoderKLWan
6 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
7 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
8 | from diffusers.utils.torch_utils import randn_tensor
9 | from diffusers.video_processor import VideoProcessor
10 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11 | from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
12 |
13 | import scipy
14 | import numpy as np
15 | import torch.nn.functional as F
16 | from transformer_minimax_remover import Transformer3DModel
17 | from einops import rearrange
18 |
19 | if is_torch_xla_available():
20 | import torch_xla.core.xla_model as xm
21 |
22 | XLA_AVAILABLE = True
23 | else:
24 | XLA_AVAILABLE = False
25 |
26 | class Minimax_Remover_Pipeline(DiffusionPipeline):
27 |
28 | model_cpu_offload_seq = "transformer->vae"
29 | _callback_tensor_inputs = ["latents"]
30 |
31 | def __init__(
32 | self,
33 | transformer: Transformer3DModel,
34 | vae: AutoencoderKLWan,
35 | scheduler: FlowMatchEulerDiscreteScheduler
36 | ):
37 | super().__init__()
38 |
39 | self.register_modules(
40 | vae=vae,
41 | transformer=transformer,
42 | scheduler=scheduler,
43 | )
44 |
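# For the Wan VAE used here this gives a temporal downsampling factor of 4 and a spatial factor of 8; 4 and 8 are also the hard-coded fallbacks.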
45 | self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
46 | self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
47 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
48 |
49 | def prepare_latents(
50 | self,
51 | batch_size: int,
52 | num_channels_latents: int = 16,
53 | height: int = 720,
54 | width: int = 1280,
55 | num_latent_frames: int = 21,
56 | dtype: Optional[torch.dtype] = None,
57 | device: Optional[torch.device] = None,
58 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
59 | latents: Optional[torch.Tensor] = None,
60 | ) -> torch.Tensor:
61 | if latents is not None:
62 | return latents.to(device=device, dtype=dtype)
63 |
64 | shape = (
65 | batch_size,
66 | num_channels_latents,
67 | num_latent_frames,
68 | int(height) // self.vae_scale_factor_spatial,
69 | int(width) // self.vae_scale_factor_spatial,
70 | )
71 |
72 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
73 | return latents
74 |
75 | def expand_masks(self, masks, iterations):
76 | masks = masks.cpu().detach().numpy()
77 | # masks: numpy array in [0, 1] with shape (f, h, w, c)
78 | masks2 = []
79 | for i in range(len(masks)):
80 | mask = masks[i]
81 | mask = mask > 0
82 | mask = scipy.ndimage.binary_dilation(mask, iterations=iterations)
83 | masks2.append(mask)
84 | masks = np.array(masks2).astype(np.float32)
85 | masks = torch.from_numpy(masks)
86 | masks = masks.repeat(1,1,1,3)
87 | masks = rearrange(masks, "f h w c -> c f h w")
88 | masks = masks[None,...]
89 | return masks
90 |
91 | def resize(self, images, h, w):
92 | bsz, _, _, _, _ = images.shape
93 | images = rearrange(images, "b c f h w -> (b f) c h w")
94 | images = F.interpolate(images, (h, w), mode='bilinear')
95 | images = rearrange(images, "(b f) c h w -> b c f h w", b=bsz)
96 | return images
97 |
98 | @property
99 | def num_timesteps(self):
100 | return self._num_timesteps
101 |
102 | @property
103 | def current_timestep(self):
104 | return self._current_timestep
105 |
106 | @property
107 | def interrupt(self):
108 | return self._interrupt
109 |
110 | @torch.no_grad()
111 | def __call__(
112 | self,
113 | height: int = 720,
114 | width: int = 1280,
115 | num_frames: int = 81,
116 | num_inference_steps: int = 50,
117 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
118 | images: Optional[torch.Tensor] = None,
119 | masks: Optional[torch.Tensor] = None,
120 | latents: Optional[torch.Tensor] = None,
121 | output_type: Optional[str] = "np",
122 | iterations: int = 16
123 | ):
124 |
125 | self._current_timestep = None
126 | self._interrupt = False
127 | device = self._execution_device
128 | batch_size = 1
129 | transformer_dtype = torch.float16
130 |
131 | self.scheduler.set_timesteps(num_inference_steps, device=device)
132 | timesteps = self.scheduler.timesteps
133 |
134 | num_channels_latents = 16
135 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
136 |
137 | latents = self.prepare_latents(
138 | batch_size,
139 | num_channels_latents,
140 | height,
141 | width,
142 | num_latent_frames,
143 | torch.float16,
144 | device,
145 | generator,
146 | latents,
147 | )
148 |
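# Dilate the masks `iterations` times, resize frames and masks to the target resolution, and re-binarize the masks after bilinear interpolation.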
149 | masks = self.expand_masks(masks, iterations)
150 | masks = self.resize(masks, height, width).to(device).half()
151 | masks[masks>0] = 1
152 | images = rearrange(images, "f h w c -> c f h w")
153 | images = self.resize(images[None,...], height, width).to(device).half()
154 |
155 | masked_images = images * (1-masks)
156 |
157 | latents_mean = (
158 | torch.tensor(self.vae.config.latents_mean)
159 | .view(1, self.vae.config.z_dim, 1, 1, 1)
160 | .to(self.vae.device, torch.float16)
161 | )
162 |
163 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
164 | self.vae.device, torch.float16
165 | )
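# Note: latents_std stores the reciprocal of the VAE latent std, so multiplying by it normalizes and dividing by it (at decode time) denormalizes.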
166 |
167 | with torch.no_grad():
168 | masked_latents = self.vae.encode(masked_images.half()).latent_dist.mode()
169 | masks_latents = self.vae.encode(2*masks.half()-1.0).latent_dist.mode()
170 |
171 | masked_latents = (masked_latents - latents_mean) * latents_std
172 | masks_latents = (masks_latents - latents_mean) * latents_std
173 |
174 | self._num_timesteps = len(timesteps)
175 |
176 | with self.progress_bar(total=num_inference_steps) as progress_bar:
177 | for i, t in enumerate(timesteps):
178 |
179 | latent_model_input = latents.to(transformer_dtype)
180 |
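# Condition the transformer by channel-concatenating the noisy latents with the masked-video latents and the mask latents; no CFG branch is needed.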
181 | latent_model_input = torch.cat([latent_model_input, masked_latents, masks_latents], dim=1)
182 | timestep = t.expand(latents.shape[0])
183 |
184 | noise_pred = self.transformer(
185 | hidden_states=latent_model_input.half(),
186 | timestep=timestep
187 | )[0]
188 |
189 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
190 |
191 | progress_bar.update()
192 |
193 | latents = latents.half() / latents_std + latents_mean
194 | video = self.vae.decode(latents, return_dict=False)[0]
195 | video = self.video_processor.postprocess_video(video, output_type=output_type)
196 |
197 | return WanPipelineOutput(frames=video)
198 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.7.1
2 | torchvision==0.22.1
3 | decord==0.6
4 | diffusers==0.33.1
5 | Pillow==9.2
6 | einops==0.8.0
7 | scipy==1.14.0
8 | numpy==1.26.4
9 | opencv-python==4.10.0.84
10 | gradio_client==0.7.0
11 | huggingface_hub==0.32.4
12 | omegaconf
14 | accelerate==0.30.1
15 | gradio==3.40.0
16 | moviepy==1.0.3
17 | hydra-core
18 | pytest
19 |
--------------------------------------------------------------------------------
/test_minimax_remover.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers.utils import export_to_video
3 | from decord import VideoReader
4 | from diffusers.models import AutoencoderKLWan
5 | from transformer_minimax_remover import Transformer3DModel
6 | from diffusers.schedulers import UniPCMultistepScheduler
7 | from pipeline_minimax_remover import Minimax_Remover_Pipeline
8 |
9 | random_seed = 42
10 | video_length = 81
11 | device = torch.device("cuda:0")
12 |
13 | vae = AutoencoderKLWan.from_pretrained("./vae", torch_dtype=torch.float16)
14 | transformer = Transformer3DModel.from_pretrained("./transformer", torch_dtype=torch.float16)
15 | scheduler = UniPCMultistepScheduler.from_pretrained("./scheduler")
16 |
17 | pipe = Minimax_Remover_Pipeline(transformer=transformer, vae=vae, scheduler=scheduler)
18 | pipe.to(device)
19 |
20 | # `iterations` controls how many binary-dilation steps are applied to the masks (see expand_masks in the pipeline)
21 | def inference(pixel_values, masks, iterations=6):
22 | video = pipe(
23 | images=pixel_values,
24 | masks=masks,
25 | num_frames=video_length,
26 | height=480,
27 | width=832,
28 | num_inference_steps=12,
29 | generator=torch.Generator(device=device).manual_seed(random_seed),
30 | iterations=iterations
31 | ).frames[0]
32 | export_to_video(video, "./output.mp4")
33 |
34 | def load_video(video_path):
35 | vr = VideoReader(video_path)
36 | images = vr.get_batch(list(range(video_length))).asnumpy()
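# scale uint8 frames from [0, 255] to [-1, 1] before handing them to the pipeline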
37 | images = torch.from_numpy(images)/127.5 - 1.0
38 | return images
39 |
40 | def load_mask(mask_path):
41 | vr = VideoReader(mask_path)
42 | masks = vr.get_batch(list(range(video_length))).asnumpy()
43 | masks = torch.from_numpy(masks)
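# binarize: values above 20 become 255 (foreground), everything else 0, then rescale to {0, 1}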
44 | masks = masks[:, :, :, :1]
45 | masks[masks > 20] = 255
46 | masks[masks < 255] = 0
47 | masks = masks / 255.0
48 | return masks
49 |
50 | video_path = "./video.mp4"
51 | mask_path = "./mask.mp4"
52 |
53 | images = load_video(video_path)
54 | masks = load_mask(mask_path)
55 |
56 | inference(images, masks)
--------------------------------------------------------------------------------
/transformer_minimax_remover.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Optional, Tuple, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config
9 | from diffusers.utils import logging
10 | from diffusers.models.attention import FeedForward
11 | from diffusers.models.attention_processor import Attention
12 | from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
13 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
14 | from diffusers.models.modeling_utils import ModelMixin
15 | from diffusers.models.normalization import FP32LayerNorm
16 |
17 | class AttnProcessor2_0:
18 | def __init__(self):
19 | if not hasattr(F, "scaled_dot_product_attention"):
20 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
21 |
22 | def __call__(
23 | self,
24 | attn: Attention,
25 | hidden_states: torch.Tensor,
26 | rotary_emb: Optional[torch.Tensor] = None,
27 | attention_mask: Optional[torch.Tensor] = None,
28 | encoder_hidden_states: Optional[torch.Tensor] = None
29 | ) -> torch.Tensor:
30 |
31 | encoder_hidden_states = hidden_states
32 | query = attn.to_q(hidden_states)
33 | key = attn.to_k(encoder_hidden_states)
34 | value = attn.to_v(encoder_hidden_states)
35 |
36 | if attn.norm_q is not None:
37 | query = attn.norm_q(query)
38 | if attn.norm_k is not None:
39 | key = attn.norm_k(key)
40 |
41 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
42 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
43 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
44 |
45 | if rotary_emb is not None:
46 |
47 | def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
48 | x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
49 | x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
50 | return x_out.type_as(hidden_states)
51 |
52 | query = apply_rotary_emb(query, rotary_emb)
53 | key = apply_rotary_emb(key, rotary_emb)
54 |
55 | hidden_states = F.scaled_dot_product_attention(
56 | query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
57 | )
58 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
59 | hidden_states = hidden_states.type_as(query)
60 |
61 | hidden_states = attn.to_out[0](hidden_states)
62 | hidden_states = attn.to_out[1](hidden_states)
63 | return hidden_states
64 |
65 | class TimeEmbedding(nn.Module):
66 | def __init__(
67 | self,
68 | dim: int,
69 | time_freq_dim: int,
70 | time_proj_dim: int
71 | ):
72 | super().__init__()
73 |
74 | self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
75 | self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
76 |
77 | self.act_fn = nn.SiLU()
78 | self.time_proj = nn.Linear(dim, time_proj_dim)
79 |
80 | def forward(
81 | self,
82 | timestep: torch.Tensor,
83 | ):
84 | timestep = self.timesteps_proj(timestep)
85 |
86 | time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
87 | if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
88 | timestep = timestep.to(time_embedder_dtype)
89 | temb = self.time_embedder(timestep).type_as(self.time_proj.weight.data)
90 | timestep_proj = self.time_proj(self.act_fn(temb))
91 |
92 | return temb, timestep_proj
93 |
94 |
95 | class RotaryPosEmbed(nn.Module):
96 | def __init__(
97 | self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
98 | ):
99 | super().__init__()
100 |
101 | self.attention_head_dim = attention_head_dim
102 | self.patch_size = patch_size
103 | self.max_seq_len = max_seq_len
104 |
105 | h_dim = w_dim = 2 * (attention_head_dim // 6)
106 | t_dim = attention_head_dim - h_dim - w_dim
107 |
108 | freqs = []
109 | for dim in [t_dim, h_dim, w_dim]:
110 | freq = get_1d_rotary_pos_embed(
111 | dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
112 | )
113 | freqs.append(freq)
114 | self.freqs = torch.cat(freqs, dim=1)
115 |
116 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
117 | batch_size, num_channels, num_frames, height, width = hidden_states.shape
118 | p_t, p_h, p_w = self.patch_size
119 | ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
120 |
121 | self.freqs = self.freqs.to(hidden_states.device)
122 | freqs = self.freqs.split_with_sizes(
123 | [
124 | self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
125 | self.attention_head_dim // 6,
126 | self.attention_head_dim // 6,
127 | ],
128 | dim=1,
129 | )
130 |
131 | freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
132 | freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
133 | freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
134 | freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
135 | return freqs
136 |
137 |
138 | class TransformerBlock(nn.Module):
139 | def __init__(
140 | self,
141 | dim: int,
142 | ffn_dim: int,
143 | num_heads: int,
144 | qk_norm: str = "rms_norm_across_heads",
145 | cross_attn_norm: bool = False,
146 | eps: float = 1e-6,
147 | added_kv_proj_dim: Optional[int] = None,
148 | ):
149 | super().__init__()
150 |
151 | self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
152 | self.attn1 = Attention(
153 | query_dim=dim,
154 | heads=num_heads,
155 | kv_heads=num_heads,
156 | dim_head=dim // num_heads,
157 | qk_norm=qk_norm,
158 | eps=eps,
159 | bias=True,
160 | cross_attention_dim=None,
161 | out_bias=True,
162 | processor=AttnProcessor2_0(),
163 | )
164 |
165 | self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
166 | self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=False)
167 |
168 | self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
169 |
170 | def forward(
171 | self,
172 | hidden_states: torch.Tensor,
173 | temb: torch.Tensor,
174 | rotary_emb: torch.Tensor,
175 | ) -> torch.Tensor:
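# adaLN-style modulation: the learned scale_shift_table plus the time embedding is split into shift/scale/gate triples for the attention branch and the FFN branch.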
176 | shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
177 | self.scale_shift_table + temb.float()
178 | ).chunk(6, dim=1)
179 |
180 | norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
181 | attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
182 | hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
183 |
184 | norm_hidden_states = (self.norm2(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
185 | hidden_states
186 | )
187 | ff_output = self.ffn(norm_hidden_states)
188 | hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
189 |
190 | return hidden_states
191 |
192 |
193 | class Transformer3DModel(ModelMixin, ConfigMixin):
194 |
195 | _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
196 | _no_split_modules = ["TransformerBlock"]
197 | _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2"]
198 |
199 | @register_to_config
200 | def __init__(
201 | self,
202 | patch_size: Tuple[int] = (1, 2, 2),
203 | num_attention_heads: int = 40,
204 | attention_head_dim: int = 128,
205 | in_channels: int = 16,
206 | out_channels: int = 16,
207 | freq_dim: int = 256,
208 | ffn_dim: int = 13824,
209 | num_layers: int = 40,
210 | cross_attn_norm: bool = True,
211 | qk_norm: Optional[str] = "rms_norm_across_heads",
212 | eps: float = 1e-6,
213 | added_kv_proj_dim: Optional[int] = None,
214 | rope_max_seq_len: int = 1024
215 | ) -> None:
216 | super().__init__()
217 |
218 | inner_dim = num_attention_heads * attention_head_dim
219 | out_channels = out_channels or in_channels
220 |
221 | # 1. Patch & position embedding
222 | self.rope = RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
223 | self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
224 |
225 | # 2. Condition embeddings
226 | self.condition_embedder = TimeEmbedding(
227 | dim=inner_dim,
228 | time_freq_dim=freq_dim,
229 | time_proj_dim=inner_dim * 6,
230 | )
231 |
232 | # 3. Transformer blocks
233 | self.blocks = nn.ModuleList(
234 | [
235 | TransformerBlock(
236 | inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
237 | )
238 | for _ in range(num_layers)
239 | ]
240 | )
241 |
242 | # 4. Output norm & projection
243 | self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
244 | self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
245 | self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
246 |
247 | def forward(
248 | self,
249 | hidden_states: torch.Tensor,
250 | timestep: torch.LongTensor
251 | ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
252 | batch_size, num_channels, num_frames, height, width = hidden_states.shape
253 | p_t, p_h, p_w = self.config.patch_size
254 | post_patch_num_frames = num_frames // p_t
255 | post_patch_height = height // p_h
256 | post_patch_width = width // p_w
257 |
258 | rotary_emb = self.rope(hidden_states)
259 |
260 | hidden_states = self.patch_embedding(hidden_states)
261 | hidden_states = hidden_states.flatten(2).transpose(1, 2)
262 |
263 | temb, timestep_proj = self.condition_embedder(
264 | timestep
265 | )
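# Split the projected time embedding into the 6 per-block modulation vectors (shift/scale/gate for attention and FFN) consumed by every TransformerBlock.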
266 | timestep_proj = timestep_proj.unflatten(1, (6, -1))
267 |
268 | for block in self.blocks:
269 | hidden_states = block(hidden_states, timestep_proj, rotary_emb)
270 |
271 | shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
272 | hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
273 | hidden_states = self.proj_out(hidden_states)
274 |
275 | hidden_states = hidden_states.reshape(
276 | batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
277 | )
278 | hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
279 | output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
280 |
281 | return Transformer2DModelOutput(sample=output)
282 |
--------------------------------------------------------------------------------