├── README.md
├── gradio_demo
│   ├── cartoon
│   │   ├── 0.mp4
│   │   ├── 1.mp4
│   │   ├── 2.mp4
│   │   ├── 3.mp4
│   │   └── 4.mp4
│   ├── image_demo.py
│   ├── normal_videos
│   │   ├── 0.mp4
│   │   ├── 1.mp4
│   │   ├── 2.mp4
│   │   ├── 3.mp4
│   │   ├── 4.mp4
│   │   ├── 5.mp4
│   │   └── 6.mp4
│   ├── pipeline_minimax_remover.py
│   ├── sam2
│   │   ├── SAM2-Video-Predictor
│   │   │   └── checkpoints
│   │   │       ├── README.md
│   │   │       └── sam2_hiera_l.yaml
│   │   ├── __init__.py
│   │   ├── automatic_mask_generator.py
│   │   ├── build_sam.py
│   │   ├── configs
│   │   │   ├── __init__.py
│   │   │   ├── sam2_hiera_b+.yaml
│   │   │   ├── sam2_hiera_l.yaml
│   │   │   ├── sam2_hiera_s.yaml
│   │   │   └── sam2_hiera_t.yaml
│   │   ├── modeling
│   │   │   ├── __init__.py
│   │   │   ├── backbones
│   │   │   │   ├── __init__.py
│   │   │   │   ├── hieradet.py
│   │   │   │   ├── image_encoder.py
│   │   │   │   └── utils.py
│   │   │   ├── memory_attention.py
│   │   │   ├── memory_encoder.py
│   │   │   ├── position_encoding.py
│   │   │   ├── sam
│   │   │   │   ├── __init__.py
│   │   │   │   ├── mask_decoder.py
│   │   │   │   ├── prompt_encoder.py
│   │   │   │   └── transformer.py
│   │   │   ├── sam2_base.py
│   │   │   └── sam2_utils.py
│   │   ├── sam2_image_predictor.py
│   │   ├── sam2_video_predictor.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── amg.py
│   │       ├── download.py
│   │       ├── misc.py
│   │       ├── transforms.py
│   │       └── visualization.py
│   ├── test.py
│   └── transformer_minimax_remover.py
├── imgs
│   └── gradio_demo.gif
├── pipeline_minimax_remover.py
├── requirements.txt
├── test_minimax_remover.py
└── transformer_minimax_remover.py

/README.md:
--------------------------------------------------------------------------------
 1 | 

2 | MiniMax-Remover: Taming Bad Noise Helps Video Object Removal 3 |

4 | 5 |

6 | Bojia Zi*, 7 | Weixuan Peng*, 8 | Xianbiao Qi, 9 | Jianan Wang, Shihao Zhao, Rong Xiao, Kam-Fai Wong
10 | * Equal contribution. Corresponding author. 11 |

12 | 13 |

14 | Huggingface Model 15 | Github 16 | Huggingface Space 17 | arXiv 18 | YouTube 19 | Demo Page 20 | Replicate 21 |

22 | 23 | --- 24 | 25 | ## 🚀 Overview 26 | 27 | **MiniMax-Remover** is a fast and effective video object remover based on minimax optimization. It operates in two stages: the first stage trains a remover on a simplified DiT architecture, and the second stage distills a more robust remover that removes the need for CFG and runs with fewer inference steps. 28 | 29 | --- 30 | 31 | ## ✨ Features 32 | 33 | * **Fast:** Requires only 6 inference steps and does not use CFG, making it highly efficient. 34 | 35 | * **Effective:** Seamlessly removes objects from videos and generates high-quality visual content. 36 | 37 | * **Robust:** Avoids regenerating undesired objects or artifacts within the masked region, even under varying noise conditions. 38 | 39 | --- 40 | 41 | ## 🛠️ Installation 42 | 43 | All dependencies are listed in `requirements.txt`. 44 | 45 | ```bash 46 | pip install -r requirements.txt 47 | ``` 48 | 49 | --- 50 | 51 | ## 🏃‍♂️ Gradio Demo 52 | 53 |

54 | 55 | firstpage 56 | 57 |

58 | 59 | You can use this Gradio demo to remove objects. Note that you do not need to compile SAM2. 60 | ```bash 61 | cd gradio_demo 62 | python3 test.py 63 | ``` 64 | 65 | --- 66 | 67 | ## 📂 Download 68 | 69 | ```shell 70 | huggingface-cli download zibojia/minimax-remover --include vae transformer scheduler --local-dir . 71 | ``` 72 | 73 | --- 74 | 75 | ## ⚡ Quick Start 76 | 77 | ### Minimal Example 78 | 79 | ```python 80 | import torch 81 | from diffusers.utils import export_to_video 82 | from decord import VideoReader 83 | from diffusers.models import AutoencoderKLWan 84 | from transformer_minimax_remover import Transformer3DModel 85 | from diffusers.schedulers import UniPCMultistepScheduler 86 | from pipeline_minimax_remover import Minimax_Remover_Pipeline 87 | 88 | random_seed = 42 89 | video_length = 81 90 | device = torch.device("cuda:0") 91 | 92 | # Load model weights separately 93 | vae = AutoencoderKLWan.from_pretrained("./vae", torch_dtype=torch.float16) 94 | transformer = Transformer3DModel.from_pretrained("./transformer", torch_dtype=torch.float16) 95 | scheduler = UniPCMultistepScheduler.from_pretrained("./scheduler") 96 | 97 | images = ...  # video frames as a float tensor in range [-1, 1], shape (num_frames, height, width, 3) 98 | masks = ...  # binary masks in range [0, 1], shape (num_frames, height, width, 1) 99 | 100 | # Initialize the pipeline (pass the loaded weights as objects; they are already in float16) 101 | pipe = Minimax_Remover_Pipeline(vae=vae, transformer=transformer, 102 | scheduler=scheduler 103 | ).to(device) 104 | 105 | result = pipe(images=images, masks=masks, num_frames=video_length, height=480, width=832, 106 | num_inference_steps=12, generator=torch.Generator(device=device).manual_seed(random_seed), iterations=6 107 | ).frames[0] 108 | export_to_video(result, "./output.mp4") 109 | ``` 110 | --- 111 | 112 | ## 📧 Contact 113 | 114 | Feel free to send an email to [19210240030@fudan.edu.cn](mailto:19210240030@fudan.edu.cn) if you have any questions or suggestions.
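The Quick Start above leaves `images` and `masks` as placeholders (and imports `decord.VideoReader` without using it). Below is a minimal, hedged sketch of one way to build these inputs with `decord`; the file names `./video.mp4` and `./mask.mp4` and the helper function are illustrative assumptions rather than files shipped with this repo, while the expected shapes and value ranges follow `pipeline_minimax_remover.py` and `gradio_demo/image_demo.py`.

```python
import numpy as np
import torch
from decord import VideoReader

video_length = 81  # same as the Quick Start above

def load_frames(path, num_frames):
    # Read up to `num_frames` RGB frames as a float tensor of shape (f, h, w, 3).
    vr = VideoReader(path)
    idx = list(range(min(num_frames, len(vr))))
    frames = vr.get_batch(idx).asnumpy().astype(np.float32)
    return torch.from_numpy(frames)

# Video frames normalized to [-1, 1], shape (f, h, w, 3)
images = load_frames("./video.mp4", video_length) / 127.5 - 1.0

# Per-frame object masks, binarized to {0, 1}, shape (f, h, w, 1)
mask_frames = load_frames("./mask.mp4", video_length) / 255.0
masks = (mask_frames.mean(dim=-1, keepdim=True) > 0.5).float()
```

These tensors can then be passed directly as the `images=` and `masks=` arguments of the pipeline call shown in the Quick Start.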
115 | -------------------------------------------------------------------------------- /gradio_demo/cartoon/0.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/0.mp4 -------------------------------------------------------------------------------- /gradio_demo/cartoon/1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/1.mp4 -------------------------------------------------------------------------------- /gradio_demo/cartoon/2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/2.mp4 -------------------------------------------------------------------------------- /gradio_demo/cartoon/3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/3.mp4 -------------------------------------------------------------------------------- /gradio_demo/cartoon/4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/cartoon/4.mp4 -------------------------------------------------------------------------------- /gradio_demo/image_demo.py: -------------------------------------------------------------------------------- 1 | # create gradio demo to input one image and output one image 2 | import os 3 | import gradio as gr 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | import torch 8 | import time 9 | import random 10 | from huggingface_hub import snapshot_download 11 | from diffusers.models import AutoencoderKLWan 12 | from transformer_minimax_remover import Transformer3DModel 13 | from diffusers.schedulers import UniPCMultistepScheduler 14 | from pipeline_minimax_remover import Minimax_Remover_Pipeline 15 | from sam2.build_sam import build_sam2 16 | from sam2.sam2_image_predictor import SAM2ImagePredictor 17 | 18 | # Create directories for models 19 | os.makedirs("./SAM2-Video-Predictor/checkpoints/", exist_ok=True) 20 | os.makedirs("./model/", exist_ok=True) 21 | 22 | # Download models from Hugging Face Hub 23 | def download_sam2(): 24 | snapshot_download(repo_id="facebook/sam2-hiera-large", local_dir="./SAM2-Video-Predictor/checkpoints/") 25 | print("Download sam2 completed") 26 | 27 | def download_remover(): 28 | snapshot_download(repo_id="zibojia/minimax-remover", local_dir="./model/") 29 | print("Download minimax remover completed") 30 | 31 | download_sam2() 32 | download_remover() 33 | 34 | COLOR_PALETTE = [ 35 | (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), 36 | (0, 255, 255), (255, 128, 0), (128, 0, 255), (0, 128, 255), (128, 255, 0) 37 | ] 38 | 39 | random_seed = 42 40 | W = 1024 41 | H = W 42 | device = "cuda" if torch.cuda.is_available() else "cpu" 43 | 44 | def get_pipe_and_predictor(): 45 | vae = AutoencoderKLWan.from_pretrained("./model/vae", torch_dtype=torch.float16) 46 | transformer = Transformer3DModel.from_pretrained("./model/transformer", torch_dtype=torch.float16) 47 | scheduler = 
UniPCMultistepScheduler.from_pretrained("./model/scheduler") 48 | 49 | pipe = Minimax_Remover_Pipeline(transformer=transformer, vae=vae, scheduler=scheduler) 50 | pipe.to(device) 51 | 52 | sam2_checkpoint = "./SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt" 53 | config = "sam2_hiera_l.yaml" 54 | 55 | model = build_sam2(config, sam2_checkpoint, device=device) 56 | model.image_size = 1024 57 | image_predictor = SAM2ImagePredictor(sam_model=model) 58 | 59 | return pipe, image_predictor 60 | 61 | def get_image_info(image_pil, state): 62 | state["input_points"] = [] 63 | state["scaled_points"] = [] 64 | state["input_labels"] = [] 65 | 66 | image_np = np.array(image_pil) 67 | 68 | if image_np.shape[0] > image_np.shape[1]: 69 | W_ = W 70 | H_ = int(W_ * image_np.shape[0] / image_np.shape[1]) 71 | else: 72 | H_ = H 73 | W_ = int(H_ * image_np.shape[1] / image_np.shape[0]) 74 | 75 | image_np = cv2.resize(image_np, (W_, H_)) 76 | state["origin_image"] = image_np 77 | state["mask"] = None 78 | state["painted_image"] = None 79 | return Image.fromarray(image_np) 80 | 81 | def segment_frame(evt: gr.SelectData, label, state): 82 | if state["origin_image"] is None: 83 | return None 84 | x, y = evt.index 85 | new_point = [x, y] 86 | label_value = 1 if label == "Positive" else 0 87 | 88 | state["input_points"].append(new_point) 89 | state["input_labels"].append(label_value) 90 | height, width = state["origin_image"].shape[0:2] 91 | scaled_points = [] 92 | for pt in state["input_points"]: 93 | sx = pt[0] 94 | sy = pt[1] 95 | scaled_points.append([sx, sy]) 96 | 97 | state["scaled_points"] = scaled_points 98 | 99 | image_predictor.set_image(state["origin_image"]) 100 | mask, _, _ = image_predictor.predict( 101 | point_coords=np.array(state["scaled_points"]), 102 | point_labels=np.array(state["input_labels"]), 103 | multimask_output=False, 104 | ) 105 | 106 | mask = np.squeeze(mask) 107 | mask = cv2.resize(mask, (width, height)) 108 | mask = mask[:,:,None] 109 | 110 | color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0 111 | color = color[None, None, :] 112 | org_image = state["origin_image"].astype(np.float32) / 255.0 113 | painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color 114 | painted_image = np.uint8(np.clip(painted_image * 255, 0, 255)) 115 | state["painted_image"] = painted_image 116 | state["mask"] = mask[:,:,0] 117 | 118 | for i in range(len(state["input_points"])): 119 | point = state["input_points"][i] 120 | if state["input_labels"][i] == 0: 121 | cv2.circle(painted_image, tuple(point), radius=5, color=(0, 0, 255), thickness=-1) 122 | else: 123 | cv2.circle(painted_image, tuple(point), radius=5, color=(255, 0, 0), thickness=-1) 124 | 125 | return Image.fromarray(painted_image) 126 | 127 | def clear_clicks(state): 128 | state["input_points"] = [] 129 | state["input_labels"] = [] 130 | state["scaled_points"] = [] 131 | state["mask"] = None 132 | state["painted_image"] = None 133 | return Image.fromarray(state["origin_image"]) if state["origin_image"] is not None else None 134 | 135 | def preprocess_for_removal(image, mask): 136 | if image.shape[0] > image.shape[1]: 137 | img_resized = cv2.resize(image, (480, 832), interpolation=cv2.INTER_LINEAR) 138 | else: 139 | img_resized = cv2.resize(image, (832, 480), interpolation=cv2.INTER_LINEAR) 140 | img_resized = img_resized.astype(np.float32) / 127.5 - 1.0 # [-1, 1] 141 | 142 | if mask.shape[0] > mask.shape[1]: 143 | msk_resized = cv2.resize(mask, (480, 832), interpolation=cv2.INTER_NEAREST) 
144 | else: 145 | msk_resized = cv2.resize(mask, (832, 480), interpolation=cv2.INTER_NEAREST) 146 | msk_resized = msk_resized.astype(np.float32) 147 | msk_resized = (msk_resized > 0.5).astype(np.float32) 148 | 149 | return torch.from_numpy(img_resized).half().to(device), torch.from_numpy(msk_resized).half().to(device) 150 | 151 | def inference_and_return_image(dilation_iterations, num_inference_steps, state=None): 152 | if state["origin_image"] is None or state["mask"] is None: 153 | return None 154 | image = state["origin_image"] 155 | mask = state["mask"] 156 | 157 | img_tensor, mask_tensor = preprocess_for_removal(image, mask) 158 | img_tensor = img_tensor.unsqueeze(0) 159 | mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(-1) 160 | 161 | if mask_tensor.shape[1] < mask_tensor.shape[2]: 162 | height = 480 163 | width = 832 164 | else: 165 | height = 832 166 | width = 480 167 | 168 | with torch.no_grad(): 169 | out = pipe( 170 | images=img_tensor, 171 | masks=mask_tensor, 172 | num_frames=1, 173 | height=height, 174 | width=width, 175 | num_inference_steps=int(num_inference_steps), 176 | generator=torch.Generator(device=device).manual_seed(random_seed), 177 | iterations=int(dilation_iterations) 178 | ).frames[0] 179 | 180 | out = np.uint8(out * 255) 181 | 182 | return Image.fromarray(out[0]) 183 | 184 | pipe, image_predictor = get_pipe_and_predictor() 185 | 186 | with gr.Blocks() as demo: 187 | state = gr.State({ 188 | "origin_image": None, 189 | "mask": None, 190 | "painted_image": None, 191 | "input_points": [], 192 | "scaled_points": [], 193 | "input_labels": [], 194 | }) 195 | gr.Markdown("

Minimax-Remover: Image Object Removal

") 196 | 197 | with gr.Row(): 198 | with gr.Column(): 199 | image_input = gr.Image(label="Upload Image", type="pil") 200 | get_info_btn = gr.Button("Load Image") 201 | 202 | point_prompt = gr.Radio(["Positive", "Negative"], label="Click Type", value="Positive") 203 | clear_btn = gr.Button("Clear All Clicks") 204 | 205 | dilation_slider = gr.Slider(minimum=1, maximum=20, value=6, step=1, label="Mask Dilation") 206 | inference_steps_slider = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Num Inference Steps") 207 | 208 | remove_btn = gr.Button("Remove Object") 209 | 210 | with gr.Column(): 211 | image_output_segmentation = gr.Image(label="Segmentation", interactive=True) 212 | image_output_removed = gr.Image(label="Object Removed") 213 | 214 | get_info_btn.click(get_image_info, inputs=[image_input, state], outputs=image_output_segmentation) 215 | image_output_segmentation.select(fn=segment_frame, inputs=[point_prompt, state], outputs=image_output_segmentation) 216 | clear_btn.click(clear_clicks, inputs=state, outputs=image_output_segmentation) 217 | remove_btn.click( 218 | inference_and_return_image, 219 | inputs=[dilation_slider, inference_steps_slider, state], 220 | outputs=image_output_removed 221 | ) 222 | 223 | demo.launch(server_name="0.0.0.0", server_port=8000) -------------------------------------------------------------------------------- /gradio_demo/normal_videos/0.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/0.mp4 -------------------------------------------------------------------------------- /gradio_demo/normal_videos/1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/1.mp4 -------------------------------------------------------------------------------- /gradio_demo/normal_videos/2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/2.mp4 -------------------------------------------------------------------------------- /gradio_demo/normal_videos/3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/3.mp4 -------------------------------------------------------------------------------- /gradio_demo/normal_videos/4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/4.mp4 -------------------------------------------------------------------------------- /gradio_demo/normal_videos/5.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/5.mp4 -------------------------------------------------------------------------------- /gradio_demo/normal_videos/6.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/gradio_demo/normal_videos/6.mp4 -------------------------------------------------------------------------------- /gradio_demo/pipeline_minimax_remover.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Dict, List, Optional, Union 2 | 3 | import torch 4 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 5 | from diffusers.models import AutoencoderKLWan 6 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler 7 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring 8 | from diffusers.utils.torch_utils import randn_tensor 9 | from diffusers.video_processor import VideoProcessor 10 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 11 | from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput 12 | 13 | import scipy 14 | import numpy as np 15 | import torch.nn.functional as F 16 | from transformer_minimax_remover import Transformer3DModel 17 | from einops import rearrange 18 | 19 | if is_torch_xla_available(): 20 | import torch_xla.core.xla_model as xm 21 | 22 | XLA_AVAILABLE = True 23 | else: 24 | XLA_AVAILABLE = False 25 | 26 | class Minimax_Remover_Pipeline(DiffusionPipeline): 27 | 28 | model_cpu_offload_seq = "transformer->vae" 29 | _callback_tensor_inputs = ["latents"] 30 | 31 | def __init__( 32 | self, 33 | transformer: Transformer3DModel, 34 | vae: AutoencoderKLWan, 35 | scheduler: FlowMatchEulerDiscreteScheduler 36 | ): 37 | super().__init__() 38 | 39 | self.register_modules( 40 | vae=vae, 41 | transformer=transformer, 42 | scheduler=scheduler, 43 | ) 44 | 45 | self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4 46 | self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8 47 | self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial) 48 | 49 | def prepare_latents( 50 | self, 51 | batch_size: int, 52 | num_channels_latents: 16, 53 | height: int = 720, 54 | width: int = 1280, 55 | num_latent_frames: int = 21, 56 | dtype: Optional[torch.dtype] = None, 57 | device: Optional[torch.device] = None, 58 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 59 | latents: Optional[torch.Tensor] = None, 60 | ) -> torch.Tensor: 61 | if latents is not None: 62 | return latents.to(device=device, dtype=dtype) 63 | 64 | shape = ( 65 | batch_size, 66 | num_channels_latents, 67 | num_latent_frames, 68 | int(height) // self.vae_scale_factor_spatial, 69 | int(width) // self.vae_scale_factor_spatial, 70 | ) 71 | 72 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 73 | return latents 74 | 75 | def expand_masks(self, masks, iterations): 76 | masks = masks.cpu().detach().numpy() 77 | # numpy array, masks [0,1], f h w c 78 | masks2 = [] 79 | for i in range(len(masks)): 80 | mask = masks[i] 81 | mask = mask > 0 82 | mask = scipy.ndimage.binary_dilation(mask, iterations=iterations) 83 | masks2.append(mask) 84 | masks = np.array(masks2).astype(np.float32) 85 | masks = torch.from_numpy(masks) 86 | masks = masks.repeat(1,1,1,3) 87 | masks = rearrange(masks, "f h w c -> c f h w") 88 | masks = masks[None,...] 
89 | return masks 90 | 91 | def resize(self, images, w, h): 92 | bsz,_,_,_,_ = images.shape 93 | images = rearrange(images, "b c f w h -> (b f) c w h") 94 | images = F.interpolate(images, (w,h), mode='bilinear') 95 | images = rearrange(images, "(b f) c w h -> b c f w h", b=bsz) 96 | return images 97 | 98 | @property 99 | def num_timesteps(self): 100 | return self._num_timesteps 101 | 102 | @property 103 | def current_timestep(self): 104 | return self._current_timestep 105 | 106 | @property 107 | def interrupt(self): 108 | return self._interrupt 109 | 110 | @torch.no_grad() 111 | def __call__( 112 | self, 113 | height: int = 720, 114 | width: int = 1280, 115 | num_frames: int = 81, 116 | num_inference_steps: int = 50, 117 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 118 | images: Optional[torch.Tensor] = None, 119 | masks: Optional[torch.Tensor] = None, 120 | latents: Optional[torch.Tensor] = None, 121 | output_type: Optional[str] = "np", 122 | iterations: int = 16 123 | ): 124 | 125 | self._current_timestep = None 126 | self._interrupt = False 127 | device = self._execution_device 128 | batch_size = 1 129 | transformer_dtype = torch.float16 130 | 131 | self.scheduler.set_timesteps(num_inference_steps, device=device) 132 | timesteps = self.scheduler.timesteps 133 | 134 | num_channels_latents = 16 135 | num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1 136 | 137 | latents = self.prepare_latents( 138 | batch_size, 139 | num_channels_latents, 140 | height, 141 | width, 142 | num_latent_frames, 143 | torch.float16, 144 | device, 145 | generator, 146 | latents, 147 | ) 148 | 149 | masks = self.expand_masks(masks, iterations) 150 | masks = self.resize(masks, height, width).to("cuda:0").half() 151 | masks[masks>0] = 1 152 | images = rearrange(images, "f h w c -> c f h w") 153 | images = self.resize(images[None,...], height, width).to("cuda:0").half() 154 | 155 | masked_images = images * (1-masks) 156 | 157 | latents_mean = ( 158 | torch.tensor(self.vae.config.latents_mean) 159 | .view(1, self.vae.config.z_dim, 1, 1, 1) 160 | .to(self.vae.device, torch.float16) 161 | ) 162 | 163 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 164 | self.vae.device, torch.float16 165 | ) 166 | 167 | with torch.no_grad(): 168 | masked_latents = self.vae.encode(masked_images.half()).latent_dist.mode() 169 | masks_latents = self.vae.encode(2*masks.half()-1.0).latent_dist.mode() 170 | 171 | masked_latents = (masked_latents - latents_mean) * latents_std 172 | masks_latents = (masks_latents - latents_mean) * latents_std 173 | 174 | self._num_timesteps = len(timesteps) 175 | 176 | with self.progress_bar(total=num_inference_steps) as progress_bar: 177 | for i, t in enumerate(timesteps): 178 | 179 | latent_model_input = latents.to(transformer_dtype) 180 | 181 | #print("latent_model_input, masked_latents, masks_latents", latent_model_input.shape, masked_latents.shape, masks_latents.shape) 182 | latent_model_input = torch.cat([latent_model_input, masked_latents, masks_latents], dim=1) 183 | timestep = t.expand(latents.shape[0]) 184 | 185 | noise_pred = self.transformer( 186 | hidden_states=latent_model_input.half(), 187 | timestep=timestep 188 | )[0] 189 | 190 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 191 | 192 | progress_bar.update() 193 | 194 | latents = latents.half() / latents_std + latents_mean 195 | video = self.vae.decode(latents, return_dict=False)[0] 196 | video = 
self.video_processor.postprocess_video(video, output_type=output_type) 197 | 198 | return WanPipelineOutput(frames=video) 199 | -------------------------------------------------------------------------------- /gradio_demo/sam2/SAM2-Video-Predictor/checkpoints/README.md: -------------------------------------------------------------------------------- 1 | --- 2 | license: apache-2.0 3 | pipeline_tag: mask-generation 4 | library_name: sam2 5 | --- 6 | 7 | Repository for SAM 2: Segment Anything in Images and Videos, a foundation model towards solving promptable visual segmentation in images and videos from FAIR. See the [SAM 2 paper](https://arxiv.org/abs/2408.00714) for more information. 8 | 9 | The official code is publicly released in this [repo](https://github.com/facebookresearch/segment-anything-2/). 10 | 11 | ## Usage 12 | 13 | For image prediction: 14 | 15 | ```python 16 | import torch 17 | from sam2.sam2_image_predictor import SAM2ImagePredictor 18 | 19 | predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large") 20 | 21 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): 22 | predictor.set_image(<your_image>) 23 | masks, _, _ = predictor.predict(<input_prompts>) 24 | ``` 25 | 26 | For video prediction: 27 | 28 | ```python 29 | import torch 30 | from sam2.sam2_video_predictor import SAM2VideoPredictor 31 | 32 | predictor = SAM2VideoPredictor.from_pretrained("facebook/sam2-hiera-large") 33 | 34 | with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16): 35 | state = predictor.init_state(<your_video>) 36 | 37 | # add new prompts and instantly get the output on the same frame 38 | frame_idx, object_ids, masks = predictor.add_new_points_or_box(state, <your_prompts>) 39 | 40 | # propagate the prompts to get masklets throughout the video 41 | for frame_idx, object_ids, masks in predictor.propagate_in_video(state): 42 | ... 43 | ``` 44 | 45 | Refer to the [demo notebooks](https://github.com/facebookresearch/segment-anything-2/tree/main/notebooks) for details.
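The Usage section above only shows the predictors with placeholder prompts. Below is a hedged, concrete point-prompt example for `SAM2ImagePredictor`, mirroring the call pattern used in `gradio_demo/image_demo.py`; the dummy image array and the click coordinates are illustrative assumptions.

```python
import numpy as np
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

image = np.zeros((1024, 1024, 3), dtype=np.uint8)  # replace with your RGB image, shape (H, W, 3)
point_coords = np.array([[512, 400], [300, 700]])  # (x, y) clicks on the object / background
point_labels = np.array([1, 0])                    # 1 = positive click, 0 = negative click

with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False,
    )
# masks: (1, H, W) mask for the clicked object, which can be dilated and fed to the remover pipeline
```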
46 | 47 | ### Citation 48 | 49 | To cite the paper, model, or software, please use the below: 50 | ``` 51 | @article{ravi2024sam2, 52 | title={SAM 2: Segment Anything in Images and Videos}, 53 | author={Ravi, Nikhila and Gabeur, Valentin and Hu, Yuan-Ting and Hu, Ronghang and Ryali, Chaitanya and Ma, Tengyu and Khedr, Haitham and R{\"a}dle, Roman and Rolland, Chloe and Gustafson, Laura and Mintun, Eric and Pan, Junting and Alwala, Kalyan Vasudev and Carion, Nicolas and Wu, Chao-Yuan and Girshick, Ross and Doll{\'a}r, Piotr and Feichtenhofer, Christoph}, 54 | journal={arXiv preprint arXiv:2408.00714}, 55 | url={https://arxiv.org/abs/2408.00714}, 56 | year={2024} 57 | } 58 | ``` -------------------------------------------------------------------------------- /gradio_demo/sam2/SAM2-Video-Predictor/checkpoints/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 91 | 
sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False -------------------------------------------------------------------------------- /gradio_demo/sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from hydra import initialize 8 | 9 | from .build_sam import load_model 10 | 11 | initialize("configs", version_base="1.2") 12 | -------------------------------------------------------------------------------- /gradio_demo/sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import logging 8 | 9 | import torch 10 | from hydra import compose 11 | from hydra.utils import instantiate 12 | from omegaconf import OmegaConf 13 | 14 | from .utils.misc import VARIANTS, variant_to_config_mapping 15 | 16 | 17 | def load_model( 18 | variant: str, 19 | ckpt_path=None, 20 | device="cpu", 21 | mode="eval", 22 | hydra_overrides_extra=[], 23 | apply_postprocessing=True, 24 | ) -> torch.nn.Module: 25 | assert variant in VARIANTS, f"only accepted variants are {VARIANTS}" 26 | 27 | return build_sam2( 28 | config_file=variant_to_config_mapping[variant], 29 | ckpt_path=ckpt_path, 30 | device=device, 31 | mode=mode, 32 | hydra_overrides_extra=hydra_overrides_extra, 33 | apply_postprocessing=apply_postprocessing, 34 | ) 35 | 36 | 37 | def build_sam2( 38 | config_file, 39 | ckpt_path=None, 40 | device="cpu", 41 | mode="eval", 42 | hydra_overrides_extra=[], 43 | apply_postprocessing=True, 44 | ): 45 | 46 | if apply_postprocessing: 47 | hydra_overrides_extra = hydra_overrides_extra.copy() 48 | hydra_overrides_extra += [ 49 | # dynamically fall back to multi-mask if the single mask is not stable 50 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 51 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 52 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 53 | ] 54 | # Read config and init model 55 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 56 | OmegaConf.resolve(cfg) 57 | model = instantiate(cfg.model, _recursive_=True) 58 | _load_checkpoint(model, ckpt_path) 59 | model = model.to(device) 60 | if mode == "eval": 61 | model.eval() 62 | return model 63 | 64 | 65 | def build_sam2_video_predictor( 66 | config_file, 67 | ckpt_path=None, 68 | device="cpu", 69 | mode="eval", 70 | hydra_overrides_extra=[], 71 | apply_postprocessing=True, 72 | ): 73 | hydra_overrides = [ 74 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", 75 | ] 76 | if apply_postprocessing: 77 | hydra_overrides_extra = hydra_overrides_extra.copy() 78 | hydra_overrides_extra += [ 79 | # dynamically fall back to multi-mask if the single mask is not stable 80 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 81 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 82 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 83 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 84 | "++model.binarize_mask_from_pts_for_mem_enc=true", 85 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 86 | # "++model.fill_hole_area=8", 87 | ] 88 | hydra_overrides.extend(hydra_overrides_extra) 89 | 90 | # Read config and init model 91 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 92 | OmegaConf.resolve(cfg) 93 | model = instantiate(cfg.model, _recursive_=True) 94 | _load_checkpoint(model, ckpt_path) 95 | model = model.to(device) 96 | if mode == "eval": 97 | model.eval() 98 | return model 99 | 100 | 101 | def _load_checkpoint(model, ckpt_path): 102 | if ckpt_path is not None: 103 | sd = torch.load(ckpt_path, map_location="cpu", weights_only=True)["model"] 104 | missing_keys, unexpected_keys = model.load_state_dict(sd) 105 | if missing_keys: 106 | logging.error(missing_keys) 107 | raise RuntimeError() 108 | if unexpected_keys: 109 | 
logging.error(unexpected_keys) 110 | raise RuntimeError() 111 | logging.info("Loaded checkpoint sucessfully") 112 | -------------------------------------------------------------------------------- /gradio_demo/sam2/configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /gradio_demo/sam2/configs/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 masks on the first click on initial 
conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /gradio_demo/sam2/configs/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask 
as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /gradio_demo/sam2/configs/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: 
sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /gradio_demo/sam2/configs/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 
67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/backbones/hieradet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from functools import partial 8 | from typing import List, Tuple, Union 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.backbones.utils import ( 15 | PatchEmbed, 16 | window_partition, 17 | window_unpartition, 18 | ) 19 | from sam2.modeling.sam2_utils import MLP, DropPath 20 | 21 | 22 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: 23 | if pool is None: 24 | return x 25 | # (B, H, W, C) -> (B, C, H, W) 26 | x = x.permute(0, 3, 1, 2) 27 | x = pool(x) 28 | # (B, C, H', W') -> (B, H', W', C) 29 | x = x.permute(0, 2, 3, 1) 30 | if norm: 31 | x = norm(x) 32 | 33 | return x 34 | 35 | 36 | class MultiScaleAttention(nn.Module): 37 | def __init__( 38 | self, 39 | dim: int, 40 | dim_out: int, 41 | num_heads: int, 42 | q_pool: nn.Module = None, 43 | ): 44 | super().__init__() 45 | 46 | self.dim = dim 47 | self.dim_out = dim_out 48 | 49 | self.num_heads = num_heads 50 | head_dim = dim_out // num_heads 51 | self.scale = head_dim**-0.5 52 | 53 | self.q_pool = q_pool 54 | self.qkv = nn.Linear(dim, dim_out * 3) 55 | self.proj = nn.Linear(dim_out, dim_out) 56 | 57 | def forward(self, x: torch.Tensor) -> torch.Tensor: 58 | B, H, W, _ = x.shape 59 | # qkv with shape (B, H * W, 3, nHead, C) 60 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) 61 | # q, k, v with shape (B, H * W, nheads, C) 62 | q, k, v = torch.unbind(qkv, 2) 63 | 64 | # Q pooling (for downsample at stage changes) 65 | if self.q_pool: 66 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool) 67 | H, W = q.shape[1:3] # downsampled shape 68 | q = q.reshape(B, H * W, self.num_heads, -1) 69 | 70 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose 71 | x = F.scaled_dot_product_attention( 72 | q.transpose(1, 2), 73 | k.transpose(1, 2), 74 | v.transpose(1, 2), 75 | ) 76 | # Transpose back 77 | x = x.transpose(1, 2) 78 | x = x.reshape(B, H, W, -1) 79 | 80 | x = self.proj(x) 81 | 82 | return x 83 | 84 | 85 | class MultiScaleBlock(nn.Module): 86 | def __init__( 87 | self, 88 | dim: int, 89 | dim_out: int, 90 | num_heads: int, 91 | mlp_ratio: float = 4.0, 92 | drop_path: float = 0.0, 93 | norm_layer: Union[nn.Module, str] = "LayerNorm", 94 | q_stride: Tuple[int, int] = None, 95 | act_layer: nn.Module = nn.GELU, 96 | window_size: int = 0, 97 | ): 98 | super().__init__() 99 | 100 | if isinstance(norm_layer, str): 101 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) 102 | 103 | self.dim = dim 104 | self.dim_out = dim_out 105 | self.norm1 = norm_layer(dim) 106 | 107 | self.window_size = window_size 108 | 109 | self.pool, self.q_stride = None, q_stride 110 | if self.q_stride: 111 | self.pool = nn.MaxPool2d( 112 | kernel_size=q_stride, stride=q_stride, ceil_mode=False 113 | ) 114 | 115 | self.attn = MultiScaleAttention( 116 | dim, 117 | dim_out, 118 | num_heads=num_heads, 119 | q_pool=self.pool, 120 | ) 121 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 122 | 123 | self.norm2 = norm_layer(dim_out) 124 | self.mlp = MLP( 125 | dim_out, 126 | int(dim_out * mlp_ratio), 127 | dim_out, 128 | num_layers=2, 129 | activation=act_layer, 130 | ) 131 | 132 | if dim != dim_out: 133 | self.proj = nn.Linear(dim, dim_out) 134 | 135 | def forward(self, x: torch.Tensor) -> torch.Tensor: 136 | shortcut = x # B, H, W, C 137 | x = self.norm1(x) 138 | 139 | # Skip connection 140 | if self.dim != self.dim_out: 141 | shortcut = do_pool(self.proj(x), self.pool) 142 | 143 | # Window partition 144 | window_size = 
self.window_size 145 | if window_size > 0: 146 | H, W = x.shape[1], x.shape[2] 147 | x, pad_hw = window_partition(x, window_size) 148 | 149 | # Window Attention + Q Pooling (if stage change) 150 | x = self.attn(x) 151 | if self.q_stride: 152 | # Shapes have changed due to Q pooling 153 | window_size = self.window_size // self.q_stride[0] 154 | H, W = shortcut.shape[1:3] 155 | 156 | pad_h = (window_size - H % window_size) % window_size 157 | pad_w = (window_size - W % window_size) % window_size 158 | pad_hw = (H + pad_h, W + pad_w) 159 | 160 | # Reverse window partition 161 | if self.window_size > 0: 162 | x = window_unpartition(x, window_size, pad_hw, (H, W)) 163 | 164 | x = shortcut + self.drop_path(x) 165 | # MLP 166 | x = x + self.drop_path(self.mlp(self.norm2(x))) 167 | return x 168 | 169 | 170 | class Hiera(nn.Module): 171 | """ 172 | Reference: https://arxiv.org/abs/2306.00989 173 | """ 174 | 175 | def __init__( 176 | self, 177 | embed_dim: int = 96, # initial embed dim 178 | num_heads: int = 1, # initial number of heads 179 | drop_path_rate: float = 0.0, # stochastic depth 180 | q_pool: int = 3, # number of q_pool stages 181 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages 182 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage 183 | dim_mul: float = 2.0, # dim_mul factor at stage shift 184 | head_mul: float = 2.0, # head_mul factor at stage shift 185 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 186 | # window size per stage, when not using global att. 187 | window_spec: Tuple[int, ...] = ( 188 | 8, 189 | 4, 190 | 14, 191 | 7, 192 | ), 193 | # global attn in these blocks 194 | global_att_blocks: Tuple[int, ...] = ( 195 | 12, 196 | 16, 197 | 20, 198 | ), 199 | return_interm_layers=True, # return feats from every stage 200 | ): 201 | super().__init__() 202 | 203 | assert len(stages) == len(window_spec) 204 | self.window_spec = window_spec 205 | 206 | depth = sum(stages) 207 | self.q_stride = q_stride 208 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] 209 | assert 0 <= q_pool <= len(self.stage_ends[:-1]) 210 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] 211 | self.return_interm_layers = return_interm_layers 212 | 213 | self.patch_embed = PatchEmbed( 214 | embed_dim=embed_dim, 215 | ) 216 | # Which blocks have global att? 
217 | self.global_att_blocks = global_att_blocks 218 | 219 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613) 220 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size 221 | self.pos_embed = nn.Parameter( 222 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) 223 | ) 224 | self.pos_embed_window = nn.Parameter( 225 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) 226 | ) 227 | 228 | dpr = [ 229 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 230 | ] # stochastic depth decay rule 231 | 232 | cur_stage = 1 233 | self.blocks = nn.ModuleList() 234 | 235 | for i in range(depth): 236 | dim_out = embed_dim 237 | # lags by a block, so first block of 238 | # next stage uses an initial window size 239 | # of previous stage and final window size of current stage 240 | window_size = self.window_spec[cur_stage - 1] 241 | 242 | if self.global_att_blocks is not None: 243 | window_size = 0 if i in self.global_att_blocks else window_size 244 | 245 | if i - 1 in self.stage_ends: 246 | dim_out = int(embed_dim * dim_mul) 247 | num_heads = int(num_heads * head_mul) 248 | cur_stage += 1 249 | 250 | block = MultiScaleBlock( 251 | dim=embed_dim, 252 | dim_out=dim_out, 253 | num_heads=num_heads, 254 | drop_path=dpr[i], 255 | q_stride=self.q_stride if i in self.q_pool_blocks else None, 256 | window_size=window_size, 257 | ) 258 | 259 | embed_dim = dim_out 260 | self.blocks.append(block) 261 | 262 | self.channel_list = ( 263 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]] 264 | if return_interm_layers 265 | else [self.blocks[-1].dim_out] 266 | ) 267 | 268 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: 269 | h, w = hw 270 | window_embed = self.pos_embed_window 271 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") 272 | pos_embed = pos_embed + window_embed.tile( 273 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] 274 | ) 275 | pos_embed = pos_embed.permute(0, 2, 3, 1) 276 | return pos_embed 277 | 278 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 279 | x = self.patch_embed(x) 280 | # x: (B, H, W, C) 281 | 282 | # Add pos embed 283 | x = x + self._get_pos_embed(x.shape[1:3]) 284 | 285 | outputs = [] 286 | for i, blk in enumerate(self.blocks): 287 | x = blk(x) 288 | if (i == self.stage_ends[-1]) or ( 289 | i in self.stage_ends and self.return_interm_layers 290 | ): 291 | feats = x.permute(0, 3, 1, 2) 292 | outputs.append(feats) 293 | 294 | return outputs 295 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. 
Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | for dim in backbone_channel_list: 75 | current = nn.Sequential() 76 | current.add_module( 77 | "conv", 78 | nn.Conv2d( 79 | in_channels=dim, 80 | out_channels=d_model, 81 | kernel_size=kernel_size, 82 | stride=stride, 83 | padding=padding, 84 | ), 85 | ) 86 | 87 | self.convs.append(current) 88 | self.fpn_interp_model = fpn_interp_model 89 | assert fuse_type in ["sum", "avg"] 90 | self.fuse_type = fuse_type 91 | 92 | # levels to have top-down features in its outputs 93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 94 | # have top-down propagation, while outputs of level 0 and level 1 have only 95 | # lateral features from the same backbone level. 
96 | if fpn_top_down_levels is None: 97 | # default is to have top-down features on all levels 98 | fpn_top_down_levels = range(len(self.convs)) 99 | self.fpn_top_down_levels = list(fpn_top_down_levels) 100 | 101 | def forward(self, xs: List[torch.Tensor]): 102 | 103 | out = [None] * len(self.convs) 104 | pos = [None] * len(self.convs) 105 | assert len(xs) == len(self.convs) 106 | # fpn forward pass 107 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 108 | prev_features = None 109 | # forward in top-down order (from low to high resolution) 110 | n = len(self.convs) - 1 111 | for i in range(n, -1, -1): 112 | x = xs[i] 113 | lateral_features = self.convs[n - i](x) 114 | if i in self.fpn_top_down_levels and prev_features is not None: 115 | top_down_features = F.interpolate( 116 | prev_features.to(dtype=torch.float32), 117 | scale_factor=2.0, 118 | mode=self.fpn_interp_model, 119 | align_corners=( 120 | None if self.fpn_interp_model == "nearest" else False 121 | ), 122 | antialias=False, 123 | ) 124 | prev_features = lateral_features + top_down_features 125 | if self.fuse_type == "avg": 126 | prev_features /= 2 127 | else: 128 | prev_features = lateral_features 129 | x_out = prev_features 130 | out[i] = x_out 131 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 132 | 133 | return out, pos 134 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 
51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import Tensor, nn 11 | 12 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones 13 | from sam2.modeling.sam.transformer import RoPEAttention 14 | 15 | 16 | class MemoryAttentionLayer(nn.Module): 17 | 18 | def __init__( 19 | self, 20 | activation: str, 21 | cross_attention: nn.Module, 22 | d_model: int, 23 | dim_feedforward: int, 24 | dropout: float, 25 | pos_enc_at_attn: bool, 26 | pos_enc_at_cross_attn_keys: bool, 27 | pos_enc_at_cross_attn_queries: bool, 28 | self_attention: nn.Module, 29 | ): 30 | super().__init__() 31 | self.d_model = d_model 32 | self.dim_feedforward = dim_feedforward 33 | self.dropout_value = dropout 34 | self.self_attn = self_attention 35 | self.cross_attn_image = cross_attention 36 | 37 | # Implementation of Feedforward model 38 | self.linear1 = nn.Linear(d_model, dim_feedforward) 39 | self.dropout = nn.Dropout(dropout) 40 | self.linear2 = nn.Linear(dim_feedforward, d_model) 41 | 42 | self.norm1 = nn.LayerNorm(d_model) 43 | self.norm2 = nn.LayerNorm(d_model) 44 | self.norm3 = nn.LayerNorm(d_model) 45 | self.dropout1 = nn.Dropout(dropout) 46 | self.dropout2 = nn.Dropout(dropout) 47 | self.dropout3 = nn.Dropout(dropout) 48 | 49 | self.activation_str = activation 50 | self.activation = get_activation_fn(activation) 51 | 52 | # Where to add pos enc 53 | self.pos_enc_at_attn = pos_enc_at_attn 54 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 55 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 56 | 57 | def _forward_sa(self, tgt, query_pos): 58 | # Self-Attention 59 | tgt2 = self.norm1(tgt) 60 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 61 | tgt2 = self.self_attn(q, k, v=tgt2) 62 | tgt = tgt + self.dropout1(tgt2) 63 | return tgt 64 | 
65 | def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 66 | kwds = {} 67 | if num_k_exclude_rope > 0: 68 | assert isinstance(self.cross_attn_image, RoPEAttention) 69 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 70 | 71 | # Cross-Attention 72 | tgt2 = self.norm2(tgt) 73 | tgt2 = self.cross_attn_image( 74 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 75 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 76 | v=memory, 77 | **kwds, 78 | ) 79 | tgt = tgt + self.dropout2(tgt2) 80 | return tgt 81 | 82 | def forward( 83 | self, 84 | tgt, 85 | memory, 86 | pos: Optional[Tensor] = None, 87 | query_pos: Optional[Tensor] = None, 88 | num_k_exclude_rope: int = 0, 89 | ) -> torch.Tensor: 90 | 91 | # Self-Attn, Cross-Attn 92 | tgt = self._forward_sa(tgt, query_pos) 93 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 94 | # MLP 95 | tgt2 = self.norm3(tgt) 96 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 97 | tgt = tgt + self.dropout3(tgt2) 98 | return tgt 99 | 100 | 101 | class MemoryAttention(nn.Module): 102 | def __init__( 103 | self, 104 | d_model: int, 105 | pos_enc_at_input: bool, 106 | layer: nn.Module, 107 | num_layers: int, 108 | batch_first: bool = True, # Do layers expect batch first input? 109 | ): 110 | super().__init__() 111 | self.d_model = d_model 112 | self.layers = get_clones(layer, num_layers) 113 | self.num_layers = num_layers 114 | self.norm = nn.LayerNorm(d_model) 115 | self.pos_enc_at_input = pos_enc_at_input 116 | self.batch_first = batch_first 117 | 118 | def forward( 119 | self, 120 | curr: torch.Tensor, # self-attention inputs 121 | memory: torch.Tensor, # cross-attention inputs 122 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 123 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 124 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 125 | ): 126 | if isinstance(curr, list): 127 | assert isinstance(curr_pos, list) 128 | assert len(curr) == len(curr_pos) == 1 129 | curr, curr_pos = ( 130 | curr[0], 131 | curr_pos[0], 132 | ) 133 | 134 | assert ( 135 | curr.shape[1] == memory.shape[1] 136 | ), "Batch size must be the same for curr and memory" 137 | 138 | output = curr 139 | if self.pos_enc_at_input and curr_pos is not None: 140 | output = output + 0.1 * curr_pos 141 | 142 | if self.batch_first: 143 | # Convert to batch first 144 | output = output.transpose(0, 1) 145 | curr_pos = curr_pos.transpose(0, 1) 146 | memory = memory.transpose(0, 1) 147 | memory_pos = memory_pos.transpose(0, 1) 148 | 149 | for layer in self.layers: 150 | kwds = {} 151 | if isinstance(layer.cross_attn_image, RoPEAttention): 152 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 153 | 154 | output = layer( 155 | tgt=output, 156 | memory=memory, 157 | pos=memory_pos, 158 | query_pos=curr_pos, 159 | **kwds, 160 | ) 161 | normed_output = self.norm(output) 162 | 163 | if self.batch_first: 164 | # Convert back to seq first 165 | normed_output = normed_output.transpose(0, 1) 166 | curr_pos = curr_pos.transpose(0, 1) 167 | 168 | return normed_output 169 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/memory_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.sam2_utils import DropPath, LayerNorm2d, get_clones 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 23 | In the end, we linearly project to embed_dim channels. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 
72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> Tuple[torch.Tensor, torch.Tensor]: 164 | ## Process masks 165 | # sigmoid, so that less domain shift from gt masks which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, cast them to CUDA 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} 182 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import math 8 | from typing import Any, Optional, Tuple 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn 13 | 14 | 15 | class PositionEmbeddingSine(nn.Module): 16 | """ 17 | This is a more standard version of the position embedding, very similar to the one 18 | used by the Attention is all you need paper, generalized to work on images. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | num_pos_feats, 24 | temperature: int = 10000, 25 | normalize: bool = True, 26 | scale: Optional[float] = None, 27 | ): 28 | super().__init__() 29 | assert num_pos_feats % 2 == 0, "Expecting even model width" 30 | self.num_pos_feats = num_pos_feats // 2 31 | self.temperature = temperature 32 | self.normalize = normalize 33 | if scale is not None and normalize is False: 34 | raise ValueError("normalize should be True if scale is passed") 35 | if scale is None: 36 | scale = 2 * math.pi 37 | self.scale = scale 38 | 39 | self.cache = {} 40 | 41 | def _encode_xy(self, x, y): 42 | # The positions are expected to be normalized 43 | assert len(x) == len(y) and x.ndim == y.ndim == 1 44 | x_embed = x * self.scale 45 | y_embed = y * self.scale 46 | 47 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 48 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 49 | 50 | pos_x = x_embed[:, None] / dim_t 51 | pos_y = y_embed[:, None] / dim_t 52 | pos_x = torch.stack( 53 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 54 | ).flatten(1) 55 | pos_y = torch.stack( 56 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 57 | ).flatten(1) 58 | return pos_x, pos_y 59 | 60 | @torch.no_grad() 61 | def encode_boxes(self, x, y, w, h): 62 | pos_x, pos_y = self._encode_xy(x, y) 63 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) 64 | return pos 65 | 66 | encode = encode_boxes # Backwards compatibility 67 | 68 | @torch.no_grad() 69 | def encode_points(self, x, y, labels): 70 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape 71 | assert bx == by and nx == ny and bx == bl and nx == nl 72 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) 73 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) 74 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) 75 | return pos 76 | 77 | @torch.no_grad() 78 | def forward(self, x: torch.Tensor): 79 | cache_key = (x.shape[-2], x.shape[-1]) 80 | if cache_key in self.cache: 81 | return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) 82 | y_embed = ( 83 | torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) 84 | .view(1, -1, 1) 85 | .repeat(x.shape[0], 1, x.shape[-1]) 86 | ) 87 | x_embed = ( 88 | torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) 89 | .view(1, 1, -1) 90 | .repeat(x.shape[0], x.shape[-2], 1) 91 | ) 92 | 93 | if self.normalize: 94 | eps = 1e-6 95 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 96 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 97 | 98 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 99 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 100 | 101 | pos_x = x_embed[:, :, :, None] / dim_t 102 | pos_y = y_embed[:, :, :, None] / dim_t 103 | pos_x = torch.stack( 104 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 105 | ).flatten(3) 106 | pos_y = torch.stack( 107 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 108 | ).flatten(3) 109 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 
3, 1, 2) 110 | self.cache[cache_key] = pos[0] 111 | return pos 112 | 113 | 114 | class PositionEmbeddingRandom(nn.Module): 115 | """ 116 | Positional encoding using random spatial frequencies. 117 | """ 118 | 119 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 120 | super().__init__() 121 | if scale is None or scale <= 0.0: 122 | scale = 1.0 123 | self.register_buffer( 124 | "positional_encoding_gaussian_matrix", 125 | scale * torch.randn((2, num_pos_feats)), 126 | ) 127 | 128 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 129 | """Positionally encode points that are normalized to [0,1].""" 130 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 131 | coords = 2 * coords - 1 132 | coords = coords @ self.positional_encoding_gaussian_matrix 133 | coords = 2 * np.pi * coords 134 | # outputs d_1 x ... x d_n x C shape 135 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 136 | 137 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 138 | """Generate positional encoding for a grid of the specified size.""" 139 | h, w = size 140 | device: Any = self.positional_encoding_gaussian_matrix.device 141 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 142 | y_embed = grid.cumsum(dim=0) - 0.5 143 | x_embed = grid.cumsum(dim=1) - 0.5 144 | y_embed = y_embed / h 145 | x_embed = x_embed / w 146 | 147 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 148 | return pe.permute(2, 0, 1) # C x H x W 149 | 150 | def forward_with_coords( 151 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 152 | ) -> torch.Tensor: 153 | """Positionally encode points that are not normalized to [0,1].""" 154 | coords = coords_input.clone() 155 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 156 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 157 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 158 | 159 | 160 | # Rotary Positional Encoding, adapted from: 161 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py 162 | # 2. https://github.com/naver-ai/rope-vit 163 | # 3. 
https://github.com/lucidrains/rotary-embedding-torch 164 | 165 | 166 | def init_t_xy(end_x: int, end_y: int): 167 | t = torch.arange(end_x * end_y, dtype=torch.float32) 168 | t_x = (t % end_x).float() 169 | t_y = torch.div(t, end_x, rounding_mode="floor").float() 170 | return t_x, t_y 171 | 172 | 173 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): 174 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 175 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 176 | 177 | t_x, t_y = init_t_xy(end_x, end_y) 178 | freqs_x = torch.outer(t_x, freqs_x) 179 | freqs_y = torch.outer(t_y, freqs_y) 180 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) 181 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) 182 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) 183 | 184 | 185 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 186 | ndim = x.ndim 187 | assert 0 <= 1 < ndim 188 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) 189 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] 190 | return freqs_cis.view(*shape) 191 | 192 | 193 | def apply_rotary_enc( 194 | xq: torch.Tensor, 195 | xk: torch.Tensor, 196 | freqs_cis: torch.Tensor, 197 | repeat_freqs_k: bool = False, 198 | ): 199 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 200 | xk_ = ( 201 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 202 | if xk.shape[-2] != 0 203 | else None 204 | ) 205 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 206 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 207 | if xk_ is None: 208 | # no keys to rotate, due to dropout 209 | return xq_out.type_as(xq).to(xq.device), xk 210 | # repeat freqs along seq_len dim to match k seq_len 211 | if repeat_freqs_k: 212 | r = xk_.shape[-2] // xq_.shape[-2] 213 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) 214 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 215 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) 216 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/sam/mask_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from typing import List, Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2.modeling.sam2_utils import MLP, LayerNorm2d 13 | 14 | 15 | class MaskDecoder(nn.Module): 16 | def __init__( 17 | self, 18 | *, 19 | transformer_dim: int, 20 | transformer: nn.Module, 21 | num_multimask_outputs: int = 3, 22 | activation: Type[nn.Module] = nn.GELU, 23 | iou_head_depth: int = 3, 24 | iou_head_hidden_dim: int = 256, 25 | use_high_res_features: bool = False, 26 | iou_prediction_use_sigmoid=False, 27 | dynamic_multimask_via_stability=False, 28 | dynamic_multimask_stability_delta=0.05, 29 | dynamic_multimask_stability_thresh=0.98, 30 | pred_obj_scores: bool = False, 31 | pred_obj_scores_mlp: bool = False, 32 | use_multimask_token_for_obj_ptr: bool = False, 33 | ) -> None: 34 | """ 35 | Predicts masks given an image and prompt embeddings, using a 36 | transformer architecture. 37 | 38 | Arguments: 39 | transformer_dim (int): the channel dimension of the transformer 40 | transformer (nn.Module): the transformer used to predict masks 41 | num_multimask_outputs (int): the number of masks to predict 42 | when disambiguating masks 43 | activation (nn.Module): the type of activation to use when 44 | upscaling masks 45 | iou_head_depth (int): the depth of the MLP used to predict 46 | mask quality 47 | iou_head_hidden_dim (int): the hidden dimension of the MLP 48 | used to predict mask quality 49 | """ 50 | super().__init__() 51 | self.transformer_dim = transformer_dim 52 | self.transformer = transformer 53 | 54 | self.num_multimask_outputs = num_multimask_outputs 55 | 56 | self.iou_token = nn.Embedding(1, transformer_dim) 57 | self.num_mask_tokens = num_multimask_outputs + 1 58 | self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) 59 | 60 | self.pred_obj_scores = pred_obj_scores 61 | if self.pred_obj_scores: 62 | self.obj_score_token = nn.Embedding(1, transformer_dim) 63 | self.use_multimask_token_for_obj_ptr = use_multimask_token_for_obj_ptr 64 | 65 | self.output_upscaling = nn.Sequential( 66 | nn.ConvTranspose2d( 67 | transformer_dim, transformer_dim // 4, kernel_size=2, stride=2 68 | ), 69 | LayerNorm2d(transformer_dim // 4), 70 | activation(), 71 | nn.ConvTranspose2d( 72 | transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2 73 | ), 74 | activation(), 75 | ) 76 | self.use_high_res_features = use_high_res_features 77 | if use_high_res_features: 78 | self.conv_s0 = nn.Conv2d( 79 | transformer_dim, transformer_dim // 8, kernel_size=1, stride=1 80 | ) 81 | self.conv_s1 = nn.Conv2d( 82 | transformer_dim, transformer_dim // 4, kernel_size=1, stride=1 83 | ) 84 | 85 | self.output_hypernetworks_mlps = nn.ModuleList( 86 | [ 87 | MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) 88 | for i in range(self.num_mask_tokens) 89 | ] 90 | ) 91 | 92 | self.iou_prediction_head = MLP( 93 | transformer_dim, 94 | iou_head_hidden_dim, 95 | self.num_mask_tokens, 96 | iou_head_depth, 97 | sigmoid_output=iou_prediction_use_sigmoid, 98 | ) 99 | if self.pred_obj_scores: 100 | self.pred_obj_score_head = nn.Linear(transformer_dim, 1) 101 | if pred_obj_scores_mlp: 102 | self.pred_obj_score_head = MLP(transformer_dim, transformer_dim, 1, 3) 103 | 104 | # When outputting a single mask, optionally we can dynamically fall back to the best 105 | # multimask output token if the single mask output token gives low stability scores. 
106 | self.dynamic_multimask_via_stability = dynamic_multimask_via_stability 107 | self.dynamic_multimask_stability_delta = dynamic_multimask_stability_delta 108 | self.dynamic_multimask_stability_thresh = dynamic_multimask_stability_thresh 109 | 110 | def forward( 111 | self, 112 | image_embeddings: torch.Tensor, 113 | image_pe: torch.Tensor, 114 | sparse_prompt_embeddings: torch.Tensor, 115 | dense_prompt_embeddings: torch.Tensor, 116 | multimask_output: bool, 117 | repeat_image: bool, 118 | high_res_features: Optional[List[torch.Tensor]] = None, 119 | ) -> Tuple[torch.Tensor, torch.Tensor]: 120 | """ 121 | Predict masks given image and prompt embeddings. 122 | 123 | Arguments: 124 | image_embeddings (torch.Tensor): the embeddings from the image encoder 125 | image_pe (torch.Tensor): positional encoding with the shape of image_embeddings 126 | sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes 127 | dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs 128 | multimask_output (bool): Whether to return multiple masks or a single 129 | mask. 130 | 131 | Returns: 132 | torch.Tensor: batched predicted masks 133 | torch.Tensor: batched predictions of mask quality 134 | torch.Tensor: batched SAM token for mask output 135 | """ 136 | masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks( 137 | image_embeddings=image_embeddings, 138 | image_pe=image_pe, 139 | sparse_prompt_embeddings=sparse_prompt_embeddings, 140 | dense_prompt_embeddings=dense_prompt_embeddings, 141 | repeat_image=repeat_image, 142 | high_res_features=high_res_features, 143 | ) 144 | 145 | # Select the correct mask or masks for output 146 | if multimask_output: 147 | masks = masks[:, 1:, :, :] 148 | iou_pred = iou_pred[:, 1:] 149 | elif self.dynamic_multimask_via_stability and not self.training: 150 | masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred) 151 | else: 152 | masks = masks[:, 0:1, :, :] 153 | iou_pred = iou_pred[:, 0:1] 154 | 155 | if multimask_output and self.use_multimask_token_for_obj_ptr: 156 | sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape 157 | else: 158 | # Take the mask output token. Here we *always* use the token for single mask output. 159 | # At test time, even if we track after 1-click (and using multimask_output=True), 160 | # we still take the single mask token here. The rationale is that we always track 161 | # after multiple clicks during training, so the past tokens seen during training 162 | # are always the single mask token (and we'll let it be the object-memory token). 163 | sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape 164 | 165 | # Prepare output 166 | return masks, iou_pred, sam_tokens_out, object_score_logits 167 | 168 | def predict_masks( 169 | self, 170 | image_embeddings: torch.Tensor, 171 | image_pe: torch.Tensor, 172 | sparse_prompt_embeddings: torch.Tensor, 173 | dense_prompt_embeddings: torch.Tensor, 174 | repeat_image: bool, 175 | high_res_features: Optional[List[torch.Tensor]] = None, 176 | ) -> Tuple[torch.Tensor, torch.Tensor]: 177 | """Predicts masks. 
See 'forward' for more details.""" 178 | # Concatenate output tokens 179 | s = 0 180 | if self.pred_obj_scores: 181 | output_tokens = torch.cat( 182 | [ 183 | self.obj_score_token.weight, 184 | self.iou_token.weight, 185 | self.mask_tokens.weight, 186 | ], 187 | dim=0, 188 | ) 189 | s = 1 190 | else: 191 | output_tokens = torch.cat( 192 | [self.iou_token.weight, self.mask_tokens.weight], dim=0 193 | ) 194 | output_tokens = output_tokens.unsqueeze(0).expand( 195 | sparse_prompt_embeddings.size(0), -1, -1 196 | ) 197 | tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) 198 | 199 | # Expand per-image data in batch direction to be per-mask 200 | if repeat_image: 201 | src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) 202 | else: 203 | assert image_embeddings.shape[0] == tokens.shape[0] 204 | src = image_embeddings 205 | src = src + dense_prompt_embeddings 206 | assert ( 207 | image_pe.size(0) == 1 208 | ), "image_pe should have size 1 in batch dim (from `get_dense_pe()`)" 209 | pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) 210 | b, c, h, w = src.shape 211 | 212 | # Run the transformer 213 | hs, src = self.transformer(src, pos_src, tokens) 214 | iou_token_out = hs[:, s, :] 215 | mask_tokens_out = hs[:, s + 1 : (s + 1 + self.num_mask_tokens), :] 216 | 217 | # Upscale mask embeddings and predict masks using the mask tokens 218 | src = src.transpose(1, 2).view(b, c, h, w) 219 | if not self.use_high_res_features: 220 | upscaled_embedding = self.output_upscaling(src) 221 | else: 222 | dc1, ln1, act1, dc2, act2 = self.output_upscaling 223 | feat_s0, feat_s1 = high_res_features 224 | upscaled_embedding = act1(ln1(dc1(src) + feat_s1)) 225 | upscaled_embedding = act2(dc2(upscaled_embedding) + feat_s0) 226 | 227 | hyper_in_list: List[torch.Tensor] = [] 228 | for i in range(self.num_mask_tokens): 229 | hyper_in_list.append( 230 | self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]) 231 | ) 232 | hyper_in = torch.stack(hyper_in_list, dim=1) 233 | b, c, h, w = upscaled_embedding.shape 234 | masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) 235 | 236 | # Generate mask quality predictions 237 | iou_pred = self.iou_prediction_head(iou_token_out) 238 | if self.pred_obj_scores: 239 | assert s == 1 240 | object_score_logits = self.pred_obj_score_head(hs[:, 0, :]) 241 | else: 242 | # Obj scores logits - default to 10.0, i.e. assuming the object is present, sigmoid(10)=1 243 | object_score_logits = 10.0 * iou_pred.new_ones(iou_pred.shape[0], 1) 244 | 245 | return masks, iou_pred, mask_tokens_out, object_score_logits 246 | 247 | def _get_stability_scores(self, mask_logits): 248 | """ 249 | Compute stability scores of the mask logits based on the IoU between upper and 250 | lower thresholds, similar to https://github.com/fairinternal/onevision/pull/568. 
251 | """ 252 | mask_logits = mask_logits.flatten(-2) 253 | stability_delta = self.dynamic_multimask_stability_delta 254 | area_i = torch.sum(mask_logits > stability_delta, dim=-1).float() 255 | area_u = torch.sum(mask_logits > -stability_delta, dim=-1).float() 256 | stability_scores = torch.where(area_u > 0, area_i / area_u, 1.0) 257 | return stability_scores 258 | 259 | def _dynamic_multimask_via_stability(self, all_mask_logits, all_iou_scores): 260 | """ 261 | When outputting a single mask, if the stability score from the current single-mask 262 | output (based on output token 0) falls below a threshold, we instead select from 263 | multi-mask outputs (based on output token 1~3) the mask with the highest predicted 264 | IoU score. This is intended to ensure a valid mask for both clicking and tracking. 265 | """ 266 | # The best mask from multimask output tokens (1~3) 267 | multimask_logits = all_mask_logits[:, 1:, :, :] 268 | multimask_iou_scores = all_iou_scores[:, 1:] 269 | best_scores_inds = torch.argmax(multimask_iou_scores, dim=-1) 270 | batch_inds = torch.arange( 271 | multimask_iou_scores.size(0), device=all_iou_scores.device 272 | ) 273 | best_multimask_logits = multimask_logits[batch_inds, best_scores_inds] 274 | best_multimask_logits = best_multimask_logits.unsqueeze(1) 275 | best_multimask_iou_scores = multimask_iou_scores[batch_inds, best_scores_inds] 276 | best_multimask_iou_scores = best_multimask_iou_scores.unsqueeze(1) 277 | 278 | # The mask from singlemask output token 0 and its stability score 279 | singlemask_logits = all_mask_logits[:, 0:1, :, :] 280 | singlemask_iou_scores = all_iou_scores[:, 0:1] 281 | stability_scores = self._get_stability_scores(singlemask_logits) 282 | is_stable = stability_scores >= self.dynamic_multimask_stability_thresh 283 | 284 | # Dynamically fall back to best multimask output upon low stability scores. 285 | mask_logits_out = torch.where( 286 | is_stable[..., None, None].expand_as(singlemask_logits), 287 | singlemask_logits, 288 | best_multimask_logits, 289 | ) 290 | iou_scores_out = torch.where( 291 | is_stable.expand_as(singlemask_iou_scores), 292 | singlemask_iou_scores, 293 | best_multimask_iou_scores, 294 | ) 295 | return mask_logits_out, iou_scores_out 296 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/sam/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom 13 | from sam2.modeling.sam2_utils import LayerNorm2d 14 | 15 | 16 | class PromptEncoder(nn.Module): 17 | def __init__( 18 | self, 19 | embed_dim: int, 20 | image_embedding_size: Tuple[int, int], 21 | input_image_size: Tuple[int, int], 22 | mask_in_chans: int, 23 | activation: Type[nn.Module] = nn.GELU, 24 | ) -> None: 25 | """ 26 | Encodes prompts for input to SAM's mask decoder. 27 | 28 | Arguments: 29 | embed_dim (int): The prompts' embedding dimension 30 | image_embedding_size (tuple(int, int)): The spatial size of the 31 | image embedding, as (H, W). 32 | input_image_size (int): The padded size of the image as input 33 | to the image encoder, as (H, W). 
34 | mask_in_chans (int): The number of hidden channels used for 35 | encoding input masks. 36 | activation (nn.Module): The activation to use when encoding 37 | input masks. 38 | """ 39 | super().__init__() 40 | self.embed_dim = embed_dim 41 | self.input_image_size = input_image_size 42 | self.image_embedding_size = image_embedding_size 43 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 44 | 45 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 46 | point_embeddings = [ 47 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 48 | ] 49 | self.point_embeddings = nn.ModuleList(point_embeddings) 50 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 51 | 52 | self.mask_input_size = ( 53 | 4 * image_embedding_size[0], 54 | 4 * image_embedding_size[1], 55 | ) 56 | self.mask_downscaling = nn.Sequential( 57 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 58 | LayerNorm2d(mask_in_chans // 4), 59 | activation(), 60 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 61 | LayerNorm2d(mask_in_chans), 62 | activation(), 63 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 64 | ) 65 | self.no_mask_embed = nn.Embedding(1, embed_dim) 66 | 67 | def get_dense_pe(self) -> torch.Tensor: 68 | """ 69 | Returns the positional encoding used to encode point prompts, 70 | applied to a dense set of points the shape of the image encoding. 71 | 72 | Returns: 73 | torch.Tensor: Positional encoding with shape 74 | 1x(embed_dim)x(embedding_h)x(embedding_w) 75 | """ 76 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 77 | 78 | def _embed_points( 79 | self, 80 | points: torch.Tensor, 81 | labels: torch.Tensor, 82 | pad: bool, 83 | ) -> torch.Tensor: 84 | """Embeds point prompts.""" 85 | points = points + 0.5 # Shift to center of pixel 86 | if pad: 87 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 88 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 89 | points = torch.cat([points, padding_point], dim=1) 90 | labels = torch.cat([labels, padding_label], dim=1) 91 | point_embedding = self.pe_layer.forward_with_coords( 92 | points, self.input_image_size 93 | ) 94 | point_embedding[labels == -1] = 0.0 95 | point_embedding[labels == -1] += self.not_a_point_embed.weight 96 | point_embedding[labels == 0] += self.point_embeddings[0].weight 97 | point_embedding[labels == 1] += self.point_embeddings[1].weight 98 | point_embedding[labels == 2] += self.point_embeddings[2].weight 99 | point_embedding[labels == 3] += self.point_embeddings[3].weight 100 | return point_embedding 101 | 102 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 103 | """Embeds box prompts.""" 104 | boxes = boxes + 0.5 # Shift to center of pixel 105 | coords = boxes.reshape(-1, 2, 2) 106 | corner_embedding = self.pe_layer.forward_with_coords( 107 | coords, self.input_image_size 108 | ) 109 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 110 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 111 | return corner_embedding 112 | 113 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 114 | """Embeds mask inputs.""" 115 | mask_embedding = self.mask_downscaling(masks) 116 | return mask_embedding 117 | 118 | def _get_batch_size( 119 | self, 120 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 121 | boxes: Optional[torch.Tensor], 122 | masks: Optional[torch.Tensor], 123 | ) -> int: 124 | """ 125 | Gets the batch size of the output given the batch size of the 
input prompts. 126 | """ 127 | if points is not None: 128 | return points[0].shape[0] 129 | elif boxes is not None: 130 | return boxes.shape[0] 131 | elif masks is not None: 132 | return masks.shape[0] 133 | else: 134 | return 1 135 | 136 | def _get_device(self) -> torch.device: 137 | return self.point_embeddings[0].weight.device 138 | 139 | def forward( 140 | self, 141 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 142 | boxes: Optional[torch.Tensor], 143 | masks: Optional[torch.Tensor], 144 | ) -> Tuple[torch.Tensor, torch.Tensor]: 145 | """ 146 | Embeds different types of prompts, returning both sparse and dense 147 | embeddings. 148 | 149 | Arguments: 150 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 151 | and labels to embed. 152 | boxes (torch.Tensor or none): boxes to embed 153 | masks (torch.Tensor or none): masks to embed 154 | 155 | Returns: 156 | torch.Tensor: sparse embeddings for the points and boxes, with shape 157 | BxNx(embed_dim), where N is determined by the number of input points 158 | and boxes. 159 | torch.Tensor: dense embeddings for the masks, in the shape 160 | Bx(embed_dim)x(embed_H)x(embed_W) 161 | """ 162 | bs = self._get_batch_size(points, boxes, masks) 163 | sparse_embeddings = torch.empty( 164 | (bs, 0, self.embed_dim), device=self._get_device() 165 | ) 166 | if points is not None: 167 | coords, labels = points 168 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 169 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 170 | if boxes is not None: 171 | box_embeddings = self._embed_boxes(boxes) 172 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 173 | 174 | if masks is not None: 175 | dense_embeddings = self._embed_masks(masks) 176 | else: 177 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 178 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 179 | ) 180 | 181 | return sparse_embeddings, dense_embeddings 182 | -------------------------------------------------------------------------------- /gradio_demo/sam2/modeling/sam/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import warnings 9 | from functools import partial 10 | from typing import Tuple, Type 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch import Tensor, nn 15 | 16 | from sam2.modeling.position_encoding import apply_rotary_enc, compute_axial_cis 17 | from sam2.modeling.sam2_utils import MLP 18 | from sam2.utils.misc import get_sdp_backends 19 | 20 | warnings.simplefilter(action="ignore", category=FutureWarning) 21 | # OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings() 22 | 23 | 24 | class TwoWayTransformer(nn.Module): 25 | def __init__( 26 | self, 27 | depth: int, 28 | embedding_dim: int, 29 | num_heads: int, 30 | mlp_dim: int, 31 | activation: Type[nn.Module] = nn.ReLU, 32 | attention_downsample_rate: int = 2, 33 | ) -> None: 34 | """ 35 | A transformer decoder that attends to an input image using 36 | queries whose positional embedding is supplied. 
37 | 38 | Args: 39 | depth (int): number of layers in the transformer 40 | embedding_dim (int): the channel dimension for the input embeddings 41 | num_heads (int): the number of heads for multihead attention. Must 42 | divide embedding_dim 43 | mlp_dim (int): the channel dimension internal to the MLP block 44 | activation (nn.Module): the activation to use in the MLP block 45 | """ 46 | super().__init__() 47 | self.depth = depth 48 | self.embedding_dim = embedding_dim 49 | self.num_heads = num_heads 50 | self.mlp_dim = mlp_dim 51 | self.layers = nn.ModuleList() 52 | 53 | for i in range(depth): 54 | self.layers.append( 55 | TwoWayAttentionBlock( 56 | embedding_dim=embedding_dim, 57 | num_heads=num_heads, 58 | mlp_dim=mlp_dim, 59 | activation=activation, 60 | attention_downsample_rate=attention_downsample_rate, 61 | skip_first_layer_pe=(i == 0), 62 | ) 63 | ) 64 | 65 | self.final_attn_token_to_image = Attention( 66 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 67 | ) 68 | self.norm_final_attn = nn.LayerNorm(embedding_dim) 69 | 70 | def forward( 71 | self, 72 | image_embedding: Tensor, 73 | image_pe: Tensor, 74 | point_embedding: Tensor, 75 | ) -> Tuple[Tensor, Tensor]: 76 | """ 77 | Args: 78 | image_embedding (torch.Tensor): image to attend to. Should be shape 79 | B x embedding_dim x h x w for any h and w. 80 | image_pe (torch.Tensor): the positional encoding to add to the image. Must 81 | have the same shape as image_embedding. 82 | point_embedding (torch.Tensor): the embedding to add to the query points. 83 | Must have shape B x N_points x embedding_dim for any N_points. 84 | 85 | Returns: 86 | torch.Tensor: the processed point_embedding 87 | torch.Tensor: the processed image_embedding 88 | """ 89 | # BxCxHxW -> BxHWxC == B x N_image_tokens x C 90 | bs, c, h, w = image_embedding.shape 91 | image_embedding = image_embedding.flatten(2).permute(0, 2, 1) 92 | image_pe = image_pe.flatten(2).permute(0, 2, 1) 93 | 94 | # Prepare queries 95 | queries = point_embedding 96 | keys = image_embedding 97 | 98 | # Apply transformer blocks and final layernorm 99 | for layer in self.layers: 100 | queries, keys = layer( 101 | queries=queries, 102 | keys=keys, 103 | query_pe=point_embedding, 104 | key_pe=image_pe, 105 | ) 106 | 107 | # Apply the final attention layer from the points to the image 108 | q = queries + point_embedding 109 | k = keys + image_pe 110 | attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) 111 | queries = queries + attn_out 112 | queries = self.norm_final_attn(queries) 113 | 114 | return queries, keys 115 | 116 | 117 | class TwoWayAttentionBlock(nn.Module): 118 | def __init__( 119 | self, 120 | embedding_dim: int, 121 | num_heads: int, 122 | mlp_dim: int = 2048, 123 | activation: Type[nn.Module] = nn.ReLU, 124 | attention_downsample_rate: int = 2, 125 | skip_first_layer_pe: bool = False, 126 | ) -> None: 127 | """ 128 | A transformer block with four layers: (1) self-attention of sparse 129 | inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp 130 | block on sparse inputs, and (4) cross attention of dense inputs to sparse 131 | inputs. 
132 | 133 | Arguments: 134 | embedding_dim (int): the channel dimension of the embeddings 135 | num_heads (int): the number of heads in the attention layers 136 | mlp_dim (int): the hidden dimension of the mlp block 137 | activation (nn.Module): the activation of the mlp block 138 | skip_first_layer_pe (bool): skip the PE on the first layer 139 | """ 140 | super().__init__() 141 | self.self_attn = Attention(embedding_dim, num_heads) 142 | self.norm1 = nn.LayerNorm(embedding_dim) 143 | 144 | self.cross_attn_token_to_image = Attention( 145 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 146 | ) 147 | self.norm2 = nn.LayerNorm(embedding_dim) 148 | 149 | self.mlp = MLP( 150 | embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation 151 | ) 152 | self.norm3 = nn.LayerNorm(embedding_dim) 153 | 154 | self.norm4 = nn.LayerNorm(embedding_dim) 155 | self.cross_attn_image_to_token = Attention( 156 | embedding_dim, num_heads, downsample_rate=attention_downsample_rate 157 | ) 158 | 159 | self.skip_first_layer_pe = skip_first_layer_pe 160 | 161 | def forward( 162 | self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor 163 | ) -> Tuple[Tensor, Tensor]: 164 | # Self attention block 165 | if self.skip_first_layer_pe: 166 | queries = self.self_attn(q=queries, k=queries, v=queries) 167 | else: 168 | q = queries + query_pe 169 | attn_out = self.self_attn(q=q, k=q, v=queries) 170 | queries = queries + attn_out 171 | queries = self.norm1(queries) 172 | 173 | # Cross attention block, tokens attending to image embedding 174 | q = queries + query_pe 175 | k = keys + key_pe 176 | attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) 177 | queries = queries + attn_out 178 | queries = self.norm2(queries) 179 | 180 | # MLP block 181 | mlp_out = self.mlp(queries) 182 | queries = queries + mlp_out 183 | queries = self.norm3(queries) 184 | 185 | # Cross attention block, image embedding attending to tokens 186 | q = queries + query_pe 187 | k = keys + key_pe 188 | attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) 189 | keys = keys + attn_out 190 | keys = self.norm4(keys) 191 | 192 | return queries, keys 193 | 194 | 195 | class Attention(nn.Module): 196 | """ 197 | An attention layer that allows for downscaling the size of the embedding 198 | after projection to queries, keys, and values. 199 | """ 200 | 201 | def __init__( 202 | self, 203 | embedding_dim: int, 204 | num_heads: int, 205 | downsample_rate: int = 1, 206 | dropout: float = 0.0, 207 | kv_in_dim: int = None, 208 | ) -> None: 209 | super().__init__() 210 | self.embedding_dim = embedding_dim 211 | self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim 212 | self.internal_dim = embedding_dim // downsample_rate 213 | self.num_heads = num_heads 214 | assert ( 215 | self.internal_dim % num_heads == 0 216 | ), "num_heads must divide embedding_dim." 
217 | 218 | self.q_proj = nn.Linear(embedding_dim, self.internal_dim) 219 | self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim) 220 | self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim) 221 | self.out_proj = nn.Linear(self.internal_dim, embedding_dim) 222 | 223 | self.dropout_p = dropout 224 | 225 | def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: 226 | b, n, c = x.shape 227 | x = x.reshape(b, n, num_heads, c // num_heads) 228 | return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head 229 | 230 | def _recombine_heads(self, x: Tensor) -> Tensor: 231 | b, n_heads, n_tokens, c_per_head = x.shape 232 | x = x.transpose(1, 2) 233 | return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C 234 | 235 | def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: 236 | # Input projections 237 | q = self.q_proj(q) 238 | k = self.k_proj(k) 239 | v = self.v_proj(v) 240 | 241 | # Separate into heads 242 | q = self._separate_heads(q, self.num_heads) 243 | k = self._separate_heads(k, self.num_heads) 244 | v = self._separate_heads(v, self.num_heads) 245 | 246 | dropout_p = self.dropout_p if self.training else 0.0 247 | # Attention 248 | 249 | #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)): 250 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 251 | 252 | out = self._recombine_heads(out) 253 | out = self.out_proj(out) 254 | 255 | return out 256 | 257 | 258 | class RoPEAttention(Attention): 259 | """Attention with rotary position encoding.""" 260 | 261 | def __init__( 262 | self, 263 | *args, 264 | rope_theta=10000.0, 265 | # whether to repeat q rope to match k length 266 | # this is needed for cross-attention to memories 267 | rope_k_repeat=False, 268 | feat_sizes=(32, 32), # [w, h] for stride 16 feats at 512 resolution 269 | **kwargs, 270 | ): 271 | super().__init__(*args, **kwargs) 272 | 273 | self.compute_cis = partial( 274 | compute_axial_cis, dim=self.internal_dim // self.num_heads, theta=rope_theta 275 | ) 276 | freqs_cis = self.compute_cis(end_x=feat_sizes[0], end_y=feat_sizes[1]) 277 | self.freqs_cis = freqs_cis 278 | self.rope_k_repeat = rope_k_repeat 279 | 280 | def forward( 281 | self, q: Tensor, k: Tensor, v: Tensor, num_k_exclude_rope: int = 0 282 | ) -> Tensor: 283 | # Input projections 284 | q = self.q_proj(q) 285 | k = self.k_proj(k) 286 | v = self.v_proj(v) 287 | 288 | # Separate into heads 289 | q = self._separate_heads(q, self.num_heads) 290 | k = self._separate_heads(k, self.num_heads) 291 | v = self._separate_heads(v, self.num_heads) 292 | 293 | # Apply rotary position encoding 294 | w = h = math.sqrt(q.shape[-2]) 295 | self.freqs_cis = self.freqs_cis.to(q.device) 296 | if self.freqs_cis.shape[0] != q.shape[-2]: 297 | self.freqs_cis = self.compute_cis(end_x=w, end_y=h).to(q.device) 298 | if q.shape[-2] != k.shape[-2]: 299 | assert self.rope_k_repeat 300 | 301 | num_k_rope = k.size(-2) - num_k_exclude_rope 302 | q, k[:, :, :num_k_rope] = apply_rotary_enc( 303 | q, 304 | k[:, :, :num_k_rope], 305 | freqs_cis=self.freqs_cis, 306 | repeat_freqs_k=self.rope_k_repeat, 307 | ) 308 | 309 | dropout_p = self.dropout_p if self.training else 0.0 310 | 311 | #with torch.nn.attention.sdpa_kernel(get_sdp_backends(dropout_p)): 312 | out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 313 | 314 | out = self._recombine_heads(out) 315 | out = self.out_proj(out) 316 | 317 | return out 318 | -------------------------------------------------------------------------------- 
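For orientation, the `Attention` module defined in `transformer.py` above can be exercised on its own. The snippet below is a usage sketch only: the shapes are arbitrary, and it assumes the bundled `sam2` package is importable (for example when running from `gradio_demo/`), matching the import style used throughout these files.

```python
import torch
from sam2.modeling.sam.transformer import Attention  # assumes `sam2` is on the import path

attn = Attention(embedding_dim=256, num_heads=8, downsample_rate=2)

q = torch.randn(1, 100, 256)   # B x N_query_tokens x C
k = torch.randn(1, 4096, 256)  # B x N_key_tokens x C (e.g. flattened 64x64 image features)
v = torch.randn(1, 4096, 256)

out = attn(q=q, k=k, v=v)      # queries attend to keys/values -> (1, 100, 256)
```

With `downsample_rate=2`, the projections run attention in a 128-dimensional internal space and project back to 256 channels on output, which is how the two-way transformer keeps its cross-attention layers cheaper than full-width attention.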
/gradio_demo/sam2/modeling/sam2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import copy 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): 16 | """ 17 | Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` 18 | that are temporally closest to the current frame at `frame_idx`. Here, we take 19 | - a) the closest conditioning frame before `frame_idx` (if any); 20 | - b) the closest conditioning frame after `frame_idx` (if any); 21 | - c) any other temporally closest conditioning frames until reaching a total 22 | of `max_cond_frame_num` conditioning frames. 23 | 24 | Outputs: 25 | - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. 26 | - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. 27 | """ 28 | if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: 29 | selected_outputs = cond_frame_outputs 30 | unselected_outputs = {} 31 | else: 32 | assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" 33 | selected_outputs = {} 34 | 35 | # the closest conditioning frame before `frame_idx` (if any) 36 | idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) 37 | if idx_before is not None: 38 | selected_outputs[idx_before] = cond_frame_outputs[idx_before] 39 | 40 | # the closest conditioning frame after `frame_idx` (if any) 41 | idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) 42 | if idx_after is not None: 43 | selected_outputs[idx_after] = cond_frame_outputs[idx_after] 44 | 45 | # add other temporally closest conditioning frames until reaching a total 46 | # of `max_cond_frame_num` conditioning frames. 47 | num_remain = max_cond_frame_num - len(selected_outputs) 48 | inds_remain = sorted( 49 | (t for t in cond_frame_outputs if t not in selected_outputs), 50 | key=lambda x: abs(x - frame_idx), 51 | )[:num_remain] 52 | selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) 53 | unselected_outputs = { 54 | t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs 55 | } 56 | 57 | return selected_outputs, unselected_outputs 58 | 59 | 60 | def get_1d_sine_pe(pos_inds, dim, temperature=10000): 61 | """ 62 | Get 1D sine positional embedding as in the original Transformer paper. 
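    Concretely, for position p the first `dim // 2` channels are
    sin(p / temperature^(2 * (i // 2) / (dim // 2))) for i = 0..dim//2 - 1,
    and the remaining channels are the matching cosines.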
63 | """ 64 | pe_dim = dim // 2 65 | dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) 66 | dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) 67 | 68 | pos_embed = pos_inds.unsqueeze(-1) / dim_t 69 | pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) 70 | return pos_embed 71 | 72 | 73 | def get_activation_fn(activation): 74 | """Return an activation function given a string""" 75 | if activation == "relu": 76 | return F.relu 77 | if activation == "gelu": 78 | return F.gelu 79 | if activation == "glu": 80 | return F.glu 81 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 82 | 83 | 84 | def get_clones(module, N): 85 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 86 | 87 | 88 | class DropPath(nn.Module): 89 | # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py 90 | def __init__(self, drop_prob=0.0, scale_by_keep=True): 91 | super(DropPath, self).__init__() 92 | self.drop_prob = drop_prob 93 | self.scale_by_keep = scale_by_keep 94 | 95 | def forward(self, x): 96 | if self.drop_prob == 0.0 or not self.training: 97 | return x 98 | keep_prob = 1 - self.drop_prob 99 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 100 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 101 | if keep_prob > 0.0 and self.scale_by_keep: 102 | random_tensor.div_(keep_prob) 103 | return x * random_tensor 104 | 105 | 106 | # Lightly adapted from 107 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 108 | class MLP(nn.Module): 109 | def __init__( 110 | self, 111 | input_dim: int, 112 | hidden_dim: int, 113 | output_dim: int, 114 | num_layers: int, 115 | activation: nn.Module = nn.ReLU, 116 | sigmoid_output: bool = False, 117 | ) -> None: 118 | super().__init__() 119 | self.num_layers = num_layers 120 | h = [hidden_dim] * (num_layers - 1) 121 | self.layers = nn.ModuleList( 122 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 123 | ) 124 | self.sigmoid_output = sigmoid_output 125 | self.act = activation() 126 | 127 | def forward(self, x): 128 | for i, layer in enumerate(self.layers): 129 | x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) 130 | if self.sigmoid_output: 131 | x = F.sigmoid(x) 132 | return x 133 | 134 | 135 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 136 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 137 | class LayerNorm2d(nn.Module): 138 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 139 | super().__init__() 140 | self.weight = nn.Parameter(torch.ones(num_channels)) 141 | self.bias = nn.Parameter(torch.zeros(num_channels)) 142 | self.eps = eps 143 | 144 | def forward(self, x: torch.Tensor) -> torch.Tensor: 145 | u = x.mean(1, keepdim=True) 146 | s = (x - u).pow(2).mean(1, keepdim=True) 147 | x = (x - u) / torch.sqrt(s + self.eps) 148 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 149 | return x 150 | -------------------------------------------------------------------------------- /gradio_demo/sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from .download import download_weights 8 | -------------------------------------------------------------------------------- /gradio_demo/sam2/utils/amg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from copy import deepcopy 9 | from itertools import product 10 | from typing import Any, Dict, Generator, ItemsView, List, Tuple 11 | 12 | import numpy as np 13 | import torch 14 | 15 | # Very lightly adapted from https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/amg.py 16 | 17 | 18 | class MaskData: 19 | """ 20 | A structure for storing masks and their related data in batched format. 21 | Implements basic filtering and concatenation. 22 | """ 23 | 24 | def __init__(self, **kwargs) -> None: 25 | for v in kwargs.values(): 26 | assert isinstance( 27 | v, (list, np.ndarray, torch.Tensor) 28 | ), "MaskData only supports list, numpy arrays, and torch tensors." 29 | self._stats = dict(**kwargs) 30 | 31 | def __setitem__(self, key: str, item: Any) -> None: 32 | assert isinstance( 33 | item, (list, np.ndarray, torch.Tensor) 34 | ), "MaskData only supports list, numpy arrays, and torch tensors." 35 | self._stats[key] = item 36 | 37 | def __delitem__(self, key: str) -> None: 38 | del self._stats[key] 39 | 40 | def __getitem__(self, key: str) -> Any: 41 | return self._stats[key] 42 | 43 | def items(self) -> ItemsView[str, Any]: 44 | return self._stats.items() 45 | 46 | def filter(self, keep: torch.Tensor) -> None: 47 | for k, v in self._stats.items(): 48 | if v is None: 49 | self._stats[k] = None 50 | elif isinstance(v, torch.Tensor): 51 | self._stats[k] = v[torch.as_tensor(keep, device=v.device)] 52 | elif isinstance(v, np.ndarray): 53 | self._stats[k] = v[keep.detach().cpu().numpy()] 54 | elif isinstance(v, list) and keep.dtype == torch.bool: 55 | self._stats[k] = [a for i, a in enumerate(v) if keep[i]] 56 | elif isinstance(v, list): 57 | self._stats[k] = [v[i] for i in keep] 58 | else: 59 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 60 | 61 | def cat(self, new_stats: "MaskData") -> None: 62 | for k, v in new_stats.items(): 63 | if k not in self._stats or self._stats[k] is None: 64 | self._stats[k] = deepcopy(v) 65 | elif isinstance(v, torch.Tensor): 66 | self._stats[k] = torch.cat([self._stats[k], v], dim=0) 67 | elif isinstance(v, np.ndarray): 68 | self._stats[k] = np.concatenate([self._stats[k], v], axis=0) 69 | elif isinstance(v, list): 70 | self._stats[k] = self._stats[k] + deepcopy(v) 71 | else: 72 | raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") 73 | 74 | def to_numpy(self) -> None: 75 | for k, v in self._stats.items(): 76 | if isinstance(v, torch.Tensor): 77 | self._stats[k] = v.float().detach().cpu().numpy() 78 | 79 | 80 | def is_box_near_crop_edge( 81 | boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 82 | ) -> torch.Tensor: 83 | """Filter masks at the edge of a crop, but not at the edge of the original image.""" 84 | crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) 85 | orig_box_torch = torch.as_tensor(orig_box, 
dtype=torch.float, device=boxes.device) 86 | boxes = uncrop_boxes_xyxy(boxes, crop_box).float() 87 | near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) 88 | near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) 89 | near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) 90 | return torch.any(near_crop_edge, dim=1) 91 | 92 | 93 | def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: 94 | box_xywh = deepcopy(box_xyxy) 95 | box_xywh[2] = box_xywh[2] - box_xywh[0] 96 | box_xywh[3] = box_xywh[3] - box_xywh[1] 97 | return box_xywh 98 | 99 | 100 | def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: 101 | assert len(args) > 0 and all( 102 | len(a) == len(args[0]) for a in args 103 | ), "Batched iteration must have inputs of all the same size." 104 | n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) 105 | for b in range(n_batches): 106 | yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] 107 | 108 | 109 | def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: 110 | """ 111 | Encodes masks to an uncompressed RLE, in the format expected by 112 | pycoco tools. 113 | """ 114 | # Put in fortran order and flatten h,w 115 | b, h, w = tensor.shape 116 | tensor = tensor.permute(0, 2, 1).flatten(1) 117 | 118 | # Compute change indices 119 | diff = tensor[:, 1:] ^ tensor[:, :-1] 120 | change_indices = diff.nonzero() 121 | 122 | # Encode run length 123 | out = [] 124 | for i in range(b): 125 | cur_idxs = change_indices[change_indices[:, 0] == i, 1] 126 | cur_idxs = torch.cat( 127 | [ 128 | torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), 129 | cur_idxs + 1, 130 | torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), 131 | ] 132 | ) 133 | btw_idxs = cur_idxs[1:] - cur_idxs[:-1] 134 | counts = [] if tensor[i, 0] == 0 else [0] 135 | counts.extend(btw_idxs.detach().cpu().tolist()) 136 | out.append({"size": [h, w], "counts": counts}) 137 | return out 138 | 139 | 140 | def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: 141 | """Compute a binary mask from an uncompressed RLE.""" 142 | h, w = rle["size"] 143 | mask = np.empty(h * w, dtype=bool) 144 | idx = 0 145 | parity = False 146 | for count in rle["counts"]: 147 | mask[idx : idx + count] = parity 148 | idx += count 149 | parity ^= True 150 | mask = mask.reshape(w, h) 151 | return mask.transpose() # Put in C order 152 | 153 | 154 | def area_from_rle(rle: Dict[str, Any]) -> int: 155 | return sum(rle["counts"][1::2]) 156 | 157 | 158 | def calculate_stability_score( 159 | masks: torch.Tensor, mask_threshold: float, threshold_offset: float 160 | ) -> torch.Tensor: 161 | """ 162 | Computes the stability score for a batch of masks. The stability 163 | score is the IoU between the binary masks obtained by thresholding 164 | the predicted mask logits at high and low values. 165 | """ 166 | # One mask is always contained inside the other. 
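    # stability = |logits > (thr + offset)| / |logits > (thr - offset)|,
    # i.e. the IoU of the tighter and looser binarizations.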
167 | # Save memory by preventing unnecessary cast to torch.int64 168 | intersections = ( 169 | (masks > (mask_threshold + threshold_offset)) 170 | .sum(-1, dtype=torch.int16) 171 | .sum(-1, dtype=torch.int32) 172 | ) 173 | unions = ( 174 | (masks > (mask_threshold - threshold_offset)) 175 | .sum(-1, dtype=torch.int16) 176 | .sum(-1, dtype=torch.int32) 177 | ) 178 | return intersections / unions 179 | 180 | 181 | def build_point_grid(n_per_side: int) -> np.ndarray: 182 | """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" 183 | offset = 1 / (2 * n_per_side) 184 | points_one_side = np.linspace(offset, 1 - offset, n_per_side) 185 | points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) 186 | points_y = np.tile(points_one_side[:, None], (1, n_per_side)) 187 | points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) 188 | return points 189 | 190 | 191 | def build_all_layer_point_grids( 192 | n_per_side: int, n_layers: int, scale_per_layer: int 193 | ) -> List[np.ndarray]: 194 | """Generates point grids for all crop layers.""" 195 | points_by_layer = [] 196 | for i in range(n_layers + 1): 197 | n_points = int(n_per_side / (scale_per_layer**i)) 198 | points_by_layer.append(build_point_grid(n_points)) 199 | return points_by_layer 200 | 201 | 202 | def generate_crop_boxes( 203 | im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float 204 | ) -> Tuple[List[List[int]], List[int]]: 205 | """ 206 | Generates a list of crop boxes of different sizes. Each layer 207 | has (2**i)**2 boxes for the ith layer. 208 | """ 209 | crop_boxes, layer_idxs = [], [] 210 | im_h, im_w = im_size 211 | short_side = min(im_h, im_w) 212 | 213 | # Original image 214 | crop_boxes.append([0, 0, im_w, im_h]) 215 | layer_idxs.append(0) 216 | 217 | def crop_len(orig_len, n_crops, overlap): 218 | return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) 219 | 220 | for i_layer in range(n_layers): 221 | n_crops_per_side = 2 ** (i_layer + 1) 222 | overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) 223 | 224 | crop_w = crop_len(im_w, n_crops_per_side, overlap) 225 | crop_h = crop_len(im_h, n_crops_per_side, overlap) 226 | 227 | crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] 228 | crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] 229 | 230 | # Crops in XYWH format 231 | for x0, y0 in product(crop_box_x0, crop_box_y0): 232 | box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] 233 | crop_boxes.append(box) 234 | layer_idxs.append(i_layer + 1) 235 | 236 | return crop_boxes, layer_idxs 237 | 238 | 239 | def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 240 | x0, y0, _, _ = crop_box 241 | offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) 242 | # Check if boxes has a channel dimension 243 | if len(boxes.shape) == 3: 244 | offset = offset.unsqueeze(1) 245 | return boxes + offset 246 | 247 | 248 | def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: 249 | x0, y0, _, _ = crop_box 250 | offset = torch.tensor([[x0, y0]], device=points.device) 251 | # Check if points has a channel dimension 252 | if len(points.shape) == 3: 253 | offset = offset.unsqueeze(1) 254 | return points + offset 255 | 256 | 257 | def uncrop_masks( 258 | masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int 259 | ) -> torch.Tensor: 260 | x0, y0, x1, y1 = crop_box 261 | if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: 262 | return masks 263 | # Coordinate 
transform masks 264 | pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) 265 | pad = (x0, pad_x - x0, y0, pad_y - y0) 266 | return torch.nn.functional.pad(masks, pad, value=0) 267 | 268 | 269 | def remove_small_regions( 270 | mask: np.ndarray, area_thresh: float, mode: str 271 | ) -> Tuple[np.ndarray, bool]: 272 | """ 273 | Removes small disconnected regions and holes in a mask. Returns the 274 | mask and an indicator of if the mask has been modified. 275 | """ 276 | import cv2 # type: ignore 277 | 278 | assert mode in ["holes", "islands"] 279 | correct_holes = mode == "holes" 280 | working_mask = (correct_holes ^ mask).astype(np.uint8) 281 | n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) 282 | sizes = stats[:, -1][1:] # Row 0 is background label 283 | small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] 284 | if len(small_regions) == 0: 285 | return mask, False 286 | fill_labels = [0] + small_regions 287 | if not correct_holes: 288 | fill_labels = [i for i in range(n_labels) if i not in fill_labels] 289 | # If every region is below threshold, keep largest 290 | if len(fill_labels) == 0: 291 | fill_labels = [int(np.argmax(sizes)) + 1] 292 | mask = np.isin(regions, fill_labels) 293 | return mask, True 294 | 295 | 296 | def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: 297 | from pycocotools import mask as mask_utils # type: ignore 298 | 299 | h, w = uncompressed_rle["size"] 300 | rle = mask_utils.frPyObjects(uncompressed_rle, h, w) 301 | rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json 302 | return rle 303 | 304 | 305 | def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: 306 | """ 307 | Calculates boxes in XYXY format around masks. Return [0,0,0,0] for 308 | an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. 309 | """ 310 | # torch.max below raises an error on empty inputs, just skip in this case 311 | if torch.numel(masks) == 0: 312 | return torch.zeros(*masks.shape[:-2], 4, device=masks.device) 313 | 314 | # Normalize shape to CxHxW 315 | shape = masks.shape 316 | h, w = shape[-2:] 317 | if len(shape) > 2: 318 | masks = masks.flatten(0, -3) 319 | else: 320 | masks = masks.unsqueeze(0) 321 | 322 | # Get top and bottom edges 323 | in_height, _ = torch.max(masks, dim=-1) 324 | in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] 325 | bottom_edges, _ = torch.max(in_height_coords, dim=-1) 326 | in_height_coords = in_height_coords + h * (~in_height) 327 | top_edges, _ = torch.min(in_height_coords, dim=-1) 328 | 329 | # Get left and right edges 330 | in_width, _ = torch.max(masks, dim=-2) 331 | in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] 332 | right_edges, _ = torch.max(in_width_coords, dim=-1) 333 | in_width_coords = in_width_coords + w * (~in_width) 334 | left_edges, _ = torch.min(in_width_coords, dim=-1) 335 | 336 | # If the mask is empty the right edge will be to the left of the left edge. 
337 | # Replace these boxes with [0, 0, 0, 0] 338 | empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) 339 | out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) 340 | out = out * (~empty_filter).unsqueeze(-1) 341 | 342 | # Return to original shape 343 | if len(shape) > 2: 344 | out = out.reshape(*shape[:-2], 4) 345 | else: 346 | out = out[0] 347 | 348 | return out 349 | -------------------------------------------------------------------------------- /gradio_demo/sam2/utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from typing import List 4 | 5 | import pytest 6 | 7 | 8 | @pytest.fixture 9 | def download_weights(output_directory: str = "artifacts") -> None: 10 | base_url: str = "https://dl.fbaipublicfiles.com/segment_anything_2/072824/" 11 | file_names: List[str] = [ 12 | "sam2_hiera_tiny.pt", 13 | "sam2_hiera_small.pt", 14 | "sam2_hiera_base_plus.pt", 15 | "sam2_hiera_large.pt", 16 | ] 17 | 18 | if not os.path.exists(output_directory): 19 | os.makedirs(output_directory) 20 | 21 | for file_name in file_names: 22 | file_path = os.path.join(output_directory, file_name) 23 | if not os.path.exists(file_path): 24 | url = f"{base_url}{file_name}" 25 | command = ["wget", url, "-P", output_directory] 26 | try: 27 | result = subprocess.run( 28 | command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE 29 | ) 30 | print(f"Download of {file_name} completed successfully.") 31 | print(result.stdout.decode()) 32 | except subprocess.CalledProcessError as e: 33 | print(f"An error occurred during the download of {file_name}.") 34 | print(e.stderr.decode()) 35 | else: 36 | print(f"{file_name} already exists. Skipping download.") 37 | -------------------------------------------------------------------------------- /gradio_demo/sam2/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import warnings 9 | from threading import Thread 10 | from typing import Dict, List, Union 11 | 12 | import numpy as np 13 | import torch 14 | from PIL import Image 15 | from torch.nn.attention import SDPBackend 16 | from einops import rearrange 17 | from tqdm import tqdm 18 | import torch.nn.functional as F 19 | 20 | VARIANTS: List[str] = ["tiny", "small", "base_plus", "large"] 21 | 22 | variant_to_config_mapping: Dict[str, str] = { 23 | "tiny": "sam2_hiera_t.yaml", 24 | "small": "sam2_hiera_s.yaml", 25 | "base_plus": "sam2_hiera_b+.yaml", 26 | "large": "sam2_hiera_l.yaml", 27 | } 28 | 29 | 30 | def get_sdp_backends(dropout_p: float) -> Union[List[SDPBackend], SDPBackend]: 31 | backends = [] 32 | if torch.cuda.is_available(): 33 | use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 34 | pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) 35 | 36 | if torch.cuda.get_device_properties(0).major < 7: 37 | backends.append(SDPBackend.EFFICIENT_ATTENTION) 38 | 39 | if use_flash_attn: 40 | backends.append(SDPBackend.FLASH_ATTENTION) 41 | 42 | if pytorch_version < (2, 2) or not use_flash_attn: 43 | backends.append(SDPBackend.MATH) 44 | 45 | if ( 46 | SDPBackend.EFFICIENT_ATTENTION in backends and dropout_p > 0.0 47 | ) and SDPBackend.MATH not in backends: 48 | backends.append(SDPBackend.MATH) 49 | 50 | else: 51 | backends.extend([SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]) 52 | 53 | return backends 54 | 55 | 56 | def get_connected_components(mask): 57 | """ 58 | Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). 59 | 60 | Inputs: 61 | - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is 62 | background. 63 | 64 | Outputs: 65 | - labels: A tensor of shape (N, 1, H, W) containing the connected component labels 66 | for foreground pixels and 0 for background pixels. 67 | - counts: A tensor of shape (N, 1, H, W) containing the area of the connected 68 | components for foreground pixels and 0 for background pixels. 
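    Note: this relies on the compiled `sam2._C` extension, so it only works
    when SAM2 has been built with its custom CUDA ops.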
69 | """ 70 | from sam2 import _C 71 | 72 | return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) 73 | 74 | 75 | def mask_to_box(masks: torch.Tensor): 76 | """ 77 | compute bounding box given an input mask 78 | 79 | Inputs: 80 | - masks: [B, 1, H, W] boxes, dtype=torch.Tensor 81 | 82 | Returns: 83 | - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor 84 | """ 85 | B, _, h, w = masks.shape 86 | device = masks.device 87 | xs = torch.arange(w, device=device, dtype=torch.int32) 88 | ys = torch.arange(h, device=device, dtype=torch.int32) 89 | grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") 90 | grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) 91 | grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) 92 | min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) 93 | max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) 94 | min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) 95 | max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) 96 | bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) 97 | 98 | return bbox_coords 99 | 100 | 101 | def _load_img_as_tensor(img_path, image_size): 102 | img_pil = Image.open(img_path) 103 | img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) 104 | if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images 105 | img_np = img_np / 255.0 106 | else: 107 | raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") 108 | img = torch.from_numpy(img_np).permute(2, 0, 1) 109 | video_width, video_height = img_pil.size # the original video size 110 | return img, video_height, video_width 111 | 112 | 113 | class AsyncVideoFrameLoader: 114 | """ 115 | A list of video frames to be load asynchronously without blocking session start. 
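    The first frame is loaded eagerly (to fill `video_height`/`video_width`);
    the remaining frames are fetched in a background daemon thread.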
116 | """ 117 | 118 | def __init__(self, img_paths, image_size, img_mean, img_std, device): 119 | self.img_paths = img_paths 120 | self.image_size = image_size 121 | self.img_mean = img_mean 122 | self.img_std = img_std 123 | self.device = device 124 | # items in `self._images` will be loaded asynchronously 125 | self.images = [None] * len(img_paths) 126 | # catch and raise any exceptions in the async loading thread 127 | self.exception = None 128 | # video_height and video_width be filled when loading the first image 129 | self.video_height = None 130 | self.video_width = None 131 | 132 | # load the first frame to fill video_height and video_width and also 133 | # to cache it (since it's most likely where the user will click) 134 | self.__getitem__(0) 135 | 136 | # load the rest of frames asynchronously without blocking the session start 137 | def _load_frames(): 138 | try: 139 | for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): 140 | self.__getitem__(n) 141 | except Exception as e: 142 | self.exception = e 143 | 144 | self.thread = Thread(target=_load_frames, daemon=True) 145 | self.thread.start() 146 | 147 | def __getitem__(self, index): 148 | if self.exception is not None: 149 | raise RuntimeError("Failure in frame loading thread") from self.exception 150 | 151 | img = self.images[index] 152 | if img is not None: 153 | return img 154 | 155 | img, video_height, video_width = _load_img_as_tensor( 156 | self.img_paths[index], self.image_size 157 | ) 158 | self.video_height = video_height 159 | self.video_width = video_width 160 | # normalize by mean and std 161 | img -= self.img_mean 162 | img /= self.img_std 163 | img = img.to(self.device) 164 | self.images[index] = img 165 | return img 166 | 167 | def __len__(self): 168 | return len(self.images) 169 | 170 | 171 | def load_video_frames( 172 | video_path, 173 | image_size, 174 | images=None, 175 | img_mean=(0.485, 0.456, 0.406), 176 | img_std=(0.229, 0.224, 0.225), 177 | async_loading_frames=False, 178 | device="cpu", 179 | ): 180 | """ 181 | Load the video frames from a directory of JPEG files (".jpg" format). 182 | 183 | You can load a frame asynchronously by setting `async_loading_frames` to `True`. 
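    Alternatively, pre-decoded frames can be passed via `images` as a float
    array of shape (num_frames, H, W, C); they are resized to `image_size`
    and normalized with `img_mean`/`img_std` like the JPEG path.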
184 | """ 185 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] 186 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] 187 | 188 | if images is not None: 189 | images = torch.from_numpy(images).float() 190 | images = rearrange(images, "f h w c -> f c h w") 191 | images = F.interpolate(images, (image_size, image_size), mode="bilinear") 192 | video_height, video_width = images.shape[2:] 193 | else: 194 | if isinstance(video_path, str) and os.path.isdir(video_path): 195 | jpg_folder = video_path 196 | else: 197 | raise NotImplementedError("Only JPEG frames are supported at this moment") 198 | 199 | frame_names = [ 200 | p 201 | for p in os.listdir(jpg_folder) 202 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] 203 | ] 204 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 205 | num_frames = len(frame_names) 206 | if num_frames == 0: 207 | raise RuntimeError(f"no images found in {jpg_folder}") 208 | img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] 209 | 210 | images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) 211 | for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): 212 | images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) 213 | images = images.to(device) 214 | img_mean = img_mean.to(device) 215 | img_std = img_std.to(device) 216 | # normalize by mean and std 217 | images -= img_mean 218 | images /= img_std 219 | return images, video_height, video_width 220 | 221 | 222 | def fill_holes_in_mask_scores(mask, max_area): 223 | """ 224 | A post processor to fill small holes in mask scores with area under `max_area`. 225 | """ 226 | # Holes are those connected components in background with area <= self.max_area 227 | # (background regions are those with mask scores <= 0) 228 | assert max_area > 0, "max_area must be positive" 229 | labels, areas = get_connected_components(mask <= 0) 230 | is_hole = (labels > 0) & (areas <= max_area) 231 | # We fill holes with a small positive mask score (0.1) to change them to foreground. 232 | mask = torch.where(is_hole, 0.1, mask) 233 | return mask 234 | 235 | 236 | def concat_points(old_point_inputs, new_points, new_labels): 237 | """Add new points and labels to previous point inputs (add at the end).""" 238 | if old_point_inputs is None: 239 | points, labels = new_points, new_labels 240 | else: 241 | points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) 242 | labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) 243 | 244 | return {"point_coords": points, "point_labels": labels} 245 | -------------------------------------------------------------------------------- /gradio_demo/sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision.transforms import Normalize, Resize, ToTensor 11 | 12 | 13 | class SAM2Transforms(nn.Module): 14 | def __init__( 15 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 16 | ): 17 | """ 18 | Transforms for SAM2. 
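        Resizes inputs to `resolution` x `resolution` and normalizes them with
        the ImageNet mean/std; `max_hole_area` and `max_sprinkle_area` control
        the mask clean-up performed in `postprocess_masks`.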
19 | """ 20 | super().__init__() 21 | self.resolution = resolution 22 | self.mask_threshold = mask_threshold 23 | self.max_hole_area = max_hole_area 24 | self.max_sprinkle_area = max_sprinkle_area 25 | self.mean = [0.485, 0.456, 0.406] 26 | self.std = [0.229, 0.224, 0.225] 27 | self.to_tensor = ToTensor() 28 | self.transforms = torch.jit.script( 29 | nn.Sequential( 30 | Resize((self.resolution, self.resolution)), 31 | Normalize(self.mean, self.std), 32 | ) 33 | ) 34 | 35 | def __call__(self, x): 36 | x = self.to_tensor(x) 37 | return self.transforms(x) 38 | 39 | def forward_batch(self, img_list): 40 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 41 | img_batch = torch.stack(img_batch, dim=0) 42 | return img_batch 43 | 44 | def transform_coords( 45 | self, coords: torch.Tensor, normalize=False, orig_hw=None 46 | ) -> torch.Tensor: 47 | """ 48 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, 49 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 50 | 51 | Returns 52 | Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. 53 | """ 54 | if normalize: 55 | assert orig_hw is not None 56 | h, w = orig_hw 57 | coords = coords.clone() 58 | coords[..., 0] = coords[..., 0] / w 59 | coords[..., 1] = coords[..., 1] / h 60 | 61 | coords = coords * self.resolution # unnormalize coords 62 | return coords 63 | 64 | def transform_boxes( 65 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 66 | ) -> torch.Tensor: 67 | """ 68 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, 69 | if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 70 | """ 71 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 72 | return boxes 73 | 74 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 75 | """ 76 | Perform PostProcessing on output masks. 77 | """ 78 | from sam2.utils.misc import get_connected_components 79 | 80 | masks = masks.float() 81 | if self.max_hole_area > 0: 82 | # Holes are those connected components in background with area <= self.fill_hole_area 83 | # (background regions are those with mask scores <= self.mask_threshold) 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | labels, areas = get_connected_components(mask_flat <= self.mask_threshold) 86 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 87 | is_hole = is_hole.reshape_as(masks) 88 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 89 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 90 | 91 | if self.max_sprinkle_area > 0: 92 | labels, areas = get_connected_components(mask_flat > self.mask_threshold) 93 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 94 | is_hole = is_hole.reshape_as(masks) 95 | # We fill holes with negative mask score (-10.0) to change them to background. 
96 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 97 | 98 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 99 | return masks 100 | -------------------------------------------------------------------------------- /gradio_demo/sam2/utils/visualization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import numpy as np 4 | from PIL import Image 5 | 6 | 7 | def show_masks( 8 | image: np.ndarray, 9 | masks: np.ndarray, 10 | scores: Optional[np.ndarray], 11 | alpha: Optional[float] = 0.5, 12 | display_image: Optional[bool] = False, 13 | only_best: Optional[bool] = True, 14 | autogenerated_mask: Optional[bool] = False, 15 | ) -> Image.Image: 16 | if scores is not None: 17 | # sort masks by their scores 18 | sorted_ind = np.argsort(scores)[::-1] 19 | masks = masks[sorted_ind] 20 | 21 | if autogenerated_mask: 22 | masks = sorted(masks, key=(lambda x: x["area"]), reverse=True) 23 | else: 24 | # get mask dimensions 25 | h, w = masks.shape[-2:] 26 | 27 | if display_image: 28 | output_image = Image.fromarray(image) 29 | else: 30 | # create a new blank image to superimpose masks 31 | if autogenerated_mask: 32 | output_image = Image.new( 33 | mode="RGBA", 34 | size=( 35 | masks[0]["segmentation"].shape[1], 36 | masks[0]["segmentation"].shape[0], 37 | ), 38 | color=(0, 0, 0), 39 | ) 40 | else: 41 | output_image = Image.new(mode="RGBA", size=(w, h), color=(0, 0, 0)) 42 | 43 | for i, mask in enumerate(masks): 44 | if not autogenerated_mask: 45 | if mask.ndim > 2: # type: ignore 46 | mask = mask.squeeze() # type: ignore 47 | else: 48 | mask = mask["segmentation"] 49 | # Generate a random color with specified alpha value 50 | color = np.concatenate( 51 | (np.random.randint(0, 256, size=3), [int(alpha * 255)]), axis=0 52 | ) 53 | 54 | # Create an RGBA image for the mask 55 | mask_image = Image.fromarray((mask * 255).astype(np.uint8)).convert("L") 56 | mask_colored = Image.new("RGBA", mask_image.size, tuple(color)) 57 | mask_image = Image.composite( 58 | mask_colored, Image.new("RGBA", mask_image.size), mask_image 59 | ) 60 | 61 | # Overlay mask on the output image 62 | output_image = Image.alpha_composite(output_image, mask_image) 63 | 64 | # Exit if specified to only display the best mask 65 | if only_best: 66 | break 67 | 68 | return output_image 69 | -------------------------------------------------------------------------------- /gradio_demo/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gradio as gr 3 | import cv2 4 | import numpy as np 5 | from PIL import Image 6 | 7 | os.makedirs("./SAM2-Video-Predictor/checkpoints/", exist_ok=True) 8 | os.makedirs("./model/", exist_ok=True) 9 | 10 | from huggingface_hub import snapshot_download 11 | 12 | def download_sam2(): 13 | snapshot_download(repo_id="facebook/sam2-hiera-large", local_dir="./SAM2-Video-Predictor/checkpoints/") 14 | print("Download sam2 completed") 15 | 16 | def download_remover(): 17 | snapshot_download(repo_id="zibojia/minimax-remover", local_dir="./model/") 18 | print("Download minimax remover completed") 19 | 20 | download_sam2() 21 | download_remover() 22 | 23 | import torch 24 | import argparse 25 | import random 26 | 27 | import torch.nn.functional as F 28 | import time 29 | import random 30 | from omegaconf import OmegaConf 31 | from einops import rearrange 32 | from diffusers.models import AutoencoderKLWan 33 | import scipy 34 | from 
transformer_minimax_remover import Transformer3DModel 35 | from einops import rearrange 36 | from diffusers.schedulers import UniPCMultistepScheduler 37 | from pipeline_minimax_remover import Minimax_Remover_Pipeline 38 | 39 | from diffusers.utils import export_to_video 40 | from decord import VideoReader, cpu 41 | from moviepy.editor import ImageSequenceClip 42 | 43 | from sam2 import load_model 44 | 45 | from sam2.build_sam import build_sam2, build_sam2_video_predictor 46 | from sam2.sam2_image_predictor import SAM2ImagePredictor 47 | 48 | COLOR_PALETTE = [ 49 | (255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255), 50 | (0, 255, 255), (255, 128, 0), (128, 0, 255), (0, 128, 255), (128, 255, 0) 51 | ] 52 | 53 | random_seed = 42 54 | video_length = 201 55 | W = 1024 56 | H = W 57 | device = "cuda" if torch.cuda.is_available() else "cpu" 58 | 59 | def get_pipe_image_and_video_predictor(): 60 | vae = AutoencoderKLWan.from_pretrained("./model/vae", torch_dtype=torch.float16) 61 | transformer = Transformer3DModel.from_pretrained("./model/transformer", torch_dtype=torch.float16) 62 | scheduler = UniPCMultistepScheduler.from_pretrained("./model/scheduler") 63 | 64 | pipe = Minimax_Remover_Pipeline(transformer=transformer, vae=vae, scheduler=scheduler) 65 | pipe.to(device) 66 | 67 | sam2_checkpoint = "./SAM2-Video-Predictor/checkpoints/sam2_hiera_large.pt" 68 | config = "sam2_hiera_l.yaml" 69 | 70 | video_predictor = build_sam2_video_predictor(config, sam2_checkpoint, device=device) 71 | model = build_sam2(config, sam2_checkpoint, device=device) 72 | model.image_size = 1024 73 | image_predictor = SAM2ImagePredictor(sam_model=model) 74 | 75 | return pipe, image_predictor, video_predictor 76 | 77 | def get_video_info(video_path, video_state): 78 | video_state["input_points"] = [] 79 | video_state["scaled_points"] = [] 80 | video_state["input_labels"] = [] 81 | video_state["frame_idx"] = 0 82 | vr = VideoReader(video_path, ctx=cpu(0)) 83 | first_frame = vr[0].asnumpy() 84 | del vr 85 | 86 | if first_frame.shape[0] > first_frame.shape[1]: 87 | W_ = W 88 | H_ = int(W_ * first_frame.shape[0] / first_frame.shape[1]) 89 | else: 90 | H_ = H 91 | W_ = int(H_ * first_frame.shape[1] / first_frame.shape[0]) 92 | 93 | first_frame = cv2.resize(first_frame, (W_, H_)) 94 | video_state["origin_images"] = np.expand_dims(first_frame, axis=0) 95 | video_state["inference_state"] = None 96 | video_state["video_path"] = video_path 97 | video_state["masks"] = None 98 | video_state["painted_images"] = None 99 | image = Image.fromarray(first_frame) 100 | return image 101 | 102 | def segment_frame(evt: gr.SelectData, label, video_state): 103 | if video_state["origin_images"] is None: 104 | gr.Warning("Please click \"Extract First Frame\" to extract the first frame first, then click the annotation") 105 | return None 106 | x, y = evt.index 107 | new_point = [x, y] 108 | label_value = 1 if label == "Positive" else 0 109 | 110 | video_state["input_points"].append(new_point) 111 | video_state["input_labels"].append(label_value) 112 | height, width = video_state["origin_images"][0].shape[0:2] 113 | scaled_points = [] 114 | for pt in video_state["input_points"]: 115 | sx = pt[0] / width 116 | sy = pt[1] / height 117 | scaled_points.append([sx, sy]) 118 | 119 | video_state["scaled_points"] = scaled_points 120 | 121 | image_predictor.set_image(video_state["origin_images"][0]) 122 | mask, _, _ = image_predictor.predict( 123 | point_coords=video_state["scaled_points"], 124 | point_labels=video_state["input_labels"], 125 
| multimask_output=False, 126 | normalize_coords=False, 127 | ) 128 | 129 | mask = np.squeeze(mask) 130 | mask = cv2.resize(mask, (width, height)) 131 | mask = mask[:,:,None] 132 | 133 | color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0 134 | color = color[None, None, :] 135 | org_image = video_state["origin_images"][0].astype(np.float32) / 255.0 136 | painted_image = (1 - mask * 0.5) * org_image + mask * 0.5 * color 137 | painted_image = np.uint8(np.clip(painted_image * 255, 0, 255)) 138 | video_state["painted_images"] = np.expand_dims(painted_image, axis=0) 139 | video_state["masks"] = np.expand_dims(mask[:,:,0], axis=0) 140 | 141 | for i in range(len(video_state["input_points"])): 142 | point = video_state["input_points"][i] 143 | if video_state["input_labels"][i] == 0: 144 | cv2.circle(painted_image, point, radius=3, color=(0, 0, 255), thickness=-1) # 红色点,半径为3 145 | else: 146 | cv2.circle(painted_image, point, radius=3, color=(255, 0, 0), thickness=-1) 147 | 148 | return Image.fromarray(painted_image) 149 | 150 | def clear_clicks(video_state): 151 | video_state["input_points"] = [] 152 | video_state["input_labels"] = [] 153 | video_state["scaled_points"] = [] 154 | video_state["inference_state"] = None 155 | video_state["masks"] = None 156 | video_state["painted_images"] = None 157 | return Image.fromarray(video_state["origin_images"][0]) if video_state["origin_images"] is not None else None 158 | 159 | 160 | def preprocess_for_removal(images, masks): 161 | out_images = [] 162 | out_masks = [] 163 | for img, msk in zip(images, masks): 164 | if img.shape[0] > img.shape[1]: 165 | img_resized = cv2.resize(img, (480, 832), interpolation=cv2.INTER_LINEAR) 166 | else: 167 | img_resized = cv2.resize(img, (832, 480), interpolation=cv2.INTER_LINEAR) 168 | img_resized = img_resized.astype(np.float32) / 127.5 - 1.0 # [-1, 1] 169 | out_images.append(img_resized) 170 | if msk.shape[0] > msk.shape[1]: 171 | msk_resized = cv2.resize(msk, (480, 832), interpolation=cv2.INTER_NEAREST) 172 | else: 173 | msk_resized = cv2.resize(msk, (832, 480), interpolation=cv2.INTER_NEAREST) 174 | msk_resized = msk_resized.astype(np.float32) 175 | msk_resized = (msk_resized > 0.5).astype(np.float32) 176 | out_masks.append(msk_resized) 177 | arr_images = np.stack(out_images) 178 | arr_masks = np.stack(out_masks) 179 | return torch.from_numpy(arr_images).half().to(device), torch.from_numpy(arr_masks).half().to(device) 180 | 181 | 182 | def inference_and_return_video(dilation_iterations, num_inference_steps, video_state=None): 183 | if video_state["origin_images"] is None or video_state["masks"] is None: 184 | return None 185 | images = video_state["origin_images"] 186 | masks = video_state["masks"] 187 | 188 | images = np.array(images) 189 | masks = np.array(masks) 190 | img_tensor, mask_tensor = preprocess_for_removal(images, masks) 191 | mask_tensor = mask_tensor[:,:,:,:1] 192 | 193 | if mask_tensor.shape[1] < mask_tensor.shape[2]: 194 | height = 480 195 | width = 832 196 | else: 197 | height = 832 198 | width = 480 199 | 200 | with torch.no_grad(): 201 | out = pipe( 202 | images=img_tensor, 203 | masks=mask_tensor, 204 | num_frames=mask_tensor.shape[0], 205 | height=height, 206 | width=width, 207 | num_inference_steps=int(num_inference_steps), 208 | generator=torch.Generator(device=device).manual_seed(random_seed), 209 | iterations=int(dilation_iterations) 210 | ).frames[0] 211 | 212 | out = np.uint8(out * 255) 213 | output_frames = [img for img in out] 214 | 215 | 
video_file = f"/tmp/{time.time()}-{random.random()}-removed_output.mp4" 216 | clip = ImageSequenceClip(output_frames, fps=15) 217 | clip.write_videofile(video_file, codec='libx264', audio=False, verbose=False, logger=None) 218 | return video_file 219 | 220 | 221 | def track_video(n_frames, video_state): 222 | if video_state["origin_images"] is None or video_state["masks"] is None: 223 | gr.Warning("Please complete target segmentation on the first frame first, then click Tracking") 224 | return None 225 | 226 | input_points = video_state["input_points"] 227 | input_labels = video_state["input_labels"] 228 | frame_idx = video_state["frame_idx"] 229 | obj_id = video_state["obj_id"] 230 | scaled_points = video_state["scaled_points"] 231 | 232 | vr = VideoReader(video_state["video_path"], ctx=cpu(0)) 233 | height, width = vr[0].shape[0:2] 234 | images = [vr[i].asnumpy() for i in range(min(len(vr), n_frames))] 235 | del vr 236 | 237 | if images[0].shape[0] > images[0].shape[1]: 238 | W_ = W 239 | H_ = int(W_ * images[0].shape[0] / images[0].shape[1]) 240 | else: 241 | H_ = H 242 | W_ = int(H_ * images[0].shape[1] / images[0].shape[0]) 243 | 244 | images = [cv2.resize(img, (W_, H_)) for img in images] 245 | video_state["origin_images"] = images 246 | images = np.array(images) 247 | inference_state = video_predictor.init_state(images=images/255, device=device) 248 | video_state["inference_state"] = inference_state 249 | 250 | if len(torch.from_numpy(video_state["masks"][0]).shape) == 3: 251 | mask = torch.from_numpy(video_state["masks"][0])[:,:,0] 252 | else: 253 | mask = torch.from_numpy(video_state["masks"][0]) 254 | 255 | video_predictor.add_new_mask( 256 | inference_state=inference_state, 257 | frame_idx=0, 258 | obj_id=obj_id, 259 | mask=mask 260 | ) 261 | 262 | output_frames = [] 263 | mask_frames = [] 264 | color = np.array(COLOR_PALETTE[int(time.time()) % len(COLOR_PALETTE)], dtype=np.float32) / 255.0 265 | color = color[None, None, :] 266 | for out_frame_idx, out_obj_ids, out_mask_logits in video_predictor.propagate_in_video(inference_state): 267 | frame = images[out_frame_idx].astype(np.float32) / 255.0 268 | mask = np.zeros((H, W, 3), dtype=np.float32) 269 | for i, logit in enumerate(out_mask_logits): 270 | out_mask = logit.cpu().squeeze().detach().numpy() 271 | out_mask = (out_mask[:,:,None] > 0).astype(np.float32) 272 | mask += out_mask 273 | mask = np.clip(mask, 0, 1) 274 | mask = cv2.resize(mask, (W_, H_)) 275 | mask_frames.append(mask) 276 | painted = (1 - mask * 0.5) * frame + mask * 0.5 * color 277 | painted = np.uint8(np.clip(painted * 255, 0, 255)) 278 | output_frames.append(painted) 279 | video_state["masks"] = mask_frames 280 | video_file = f"/tmp/{time.time()}-{random.random()}-tracked_output.mp4" 281 | clip = ImageSequenceClip(output_frames, fps=15) 282 | clip.write_videofile(video_file, codec='libx264', audio=False, verbose=False, logger=None) 283 | return video_file 284 | 285 | text = """ 286 |
287 | Minimax-Remover: Taming Bad Noise Helps Video Object Removal 288 |
289 |
290 | Huggingface Model 291 | Github 292 | Huggingface Space 293 | arXiv 294 | YouTube 295 | Demo Page 296 |
297 |
298 | Bojia Zi*, Weixuan Peng*, Xianbiao Qi, Jianan Wang, Shihao Zhao, Rong Xiao, Kam-Fai Wong 299 |
300 |
301 | * Equal contribution     Corresponding author 302 |
303 | """ 304 | 305 | pipe, image_predictor, video_predictor = get_pipe_image_and_video_predictor() 306 | 307 | with gr.Blocks() as demo: 308 | video_state = gr.State({ 309 | "origin_images": None, 310 | "inference_state": None, 311 | "masks": None, # Store user-generated masks 312 | "painted_images": None, 313 | "video_path": None, 314 | "input_points": [], 315 | "scaled_points": [], 316 | "input_labels": [], 317 | "frame_idx": 0, 318 | "obj_id": 1 319 | }) 320 | gr.Markdown(f"
{text}
") 321 | 322 | with gr.Column(): 323 | video_input = gr.Video(label="Upload Video", elem_id="my-video1") 324 | get_info_btn = gr.Button("Extract First Frame", elem_id="my-btn") 325 | 326 | gr.Examples( 327 | examples=[ 328 | ["./cartoon/0.mp4"], 329 | ["./cartoon/1.mp4"], 330 | ["./cartoon/2.mp4"], 331 | ["./cartoon/3.mp4"], 332 | ["./cartoon/4.mp4"], 333 | ["./normal_videos/0.mp4"], 334 | ["./normal_videos/1.mp4"], 335 | ["./normal_videos/3.mp4"], 336 | ["./normal_videos/4.mp4"], 337 | ["./normal_videos/5.mp4"], 338 | ], 339 | inputs=[video_input], 340 | label="Choose a video to remove.", 341 | elem_id="my-btn2" 342 | ) 343 | 344 | image_output = gr.Image(label="First Frame Segmentation", interactive=True, elem_id="my-video")#, height="35%", width="60%") 345 | demo.css = """ 346 | #my-btn { 347 | width: 60% !important; 348 | margin: 0 auto; 349 | } 350 | 351 | #my-video1 { 352 | width: 60% !important; 353 | height: 35% !important; 354 | margin: 0 auto; 355 | } 356 | #my-video { 357 | width: 60% !important; 358 | height: 35% !important; 359 | margin: 0 auto; 360 | } 361 | #my-md { 362 | margin: 0 auto; 363 | } 364 | #my-btn2 { 365 | width: 60% !important; 366 | margin: 0 auto; 367 | } 368 | #my-btn2 button { 369 | width: 120px !important; 370 | max-width: 120px !important; 371 | min-width: 120px !important; 372 | height: 70px !important; 373 | max-height: 70px !important; 374 | min-height: 70px !important; 375 | margin: 8px !important; 376 | border-radius: 8px !important; 377 | overflow: hidden !important; 378 | white-space: normal !important; 379 | } 380 | """ 381 | with gr.Row(elem_id="my-btn"): 382 | point_prompt = gr.Radio(["Positive", "Negative"], label="Click Type", value="Positive") 383 | clear_btn = gr.Button("Clear All Clicks") 384 | 385 | with gr.Row(elem_id="my-btn"): 386 | n_frames_slider = gr.Slider(minimum=1, maximum=201, value=81, step=1, label="Tracking Frames N") 387 | track_btn = gr.Button("Tracking") 388 | video_output = gr.Video(label="Tracking Result", elem_id="my-video") 389 | 390 | with gr.Column(elem_id="my-btn"): 391 | dilation_slider = gr.Slider(minimum=1, maximum=20, value=6, step=1, label="Mask Dilation") 392 | inference_steps_slider = gr.Slider(minimum=1, maximum=100, value=6, step=1, label="Num Inference Steps") 393 | 394 | remove_btn = gr.Button("Remove", elem_id="my-btn") 395 | remove_video = gr.Video(label="Remove Results", elem_id="my-video") 396 | remove_btn.click( 397 | inference_and_return_video, 398 | inputs=[dilation_slider, inference_steps_slider, video_state], 399 | outputs=remove_video 400 | ) 401 | get_info_btn.click(get_video_info, inputs=[video_input, video_state], \ 402 | outputs=image_output) 403 | image_output.select(fn=segment_frame, inputs=[point_prompt, video_state], outputs=image_output) 404 | clear_btn.click(clear_clicks, inputs=video_state, outputs=image_output) 405 | track_btn.click(track_video, inputs=[n_frames_slider, video_state], outputs=video_output) 406 | 407 | demo.launch(server_name="0.0.0.0", server_port=8000) 408 | -------------------------------------------------------------------------------- /gradio_demo/transformer_minimax_remover.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from diffusers.configuration_utils import ConfigMixin, register_to_config 9 | from diffusers.utils import logging 10 | from diffusers.models.attention 
import FeedForward
11 | from diffusers.models.attention_processor import Attention
12 | from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
13 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
14 | from diffusers.models.modeling_utils import ModelMixin
15 | from diffusers.models.normalization import FP32LayerNorm
16 |
17 | class AttnProcessor2_0:
18 |     def __init__(self):
19 |         if not hasattr(F, "scaled_dot_product_attention"):
20 |             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
21 |
22 |     def __call__(
23 |         self,
24 |         attn: Attention,
25 |         hidden_states: torch.Tensor,
26 |         rotary_emb: Optional[torch.Tensor] = None,
27 |         attention_mask: Optional[torch.Tensor] = None,
28 |         encoder_hidden_states: Optional[torch.Tensor] = None
29 |     ) -> torch.Tensor:
30 |
31 |         encoder_hidden_states = hidden_states
32 |         query = attn.to_q(hidden_states)
33 |         key = attn.to_k(encoder_hidden_states)
34 |         value = attn.to_v(encoder_hidden_states)
35 |
36 |         if attn.norm_q is not None:
37 |             query = attn.norm_q(query)
38 |         if attn.norm_k is not None:
39 |             key = attn.norm_k(key)
40 |
41 |         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
42 |         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
43 |         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
44 |
45 |         if rotary_emb is not None:
46 |
47 |             def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
48 |                 x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
49 |                 x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
50 |                 return x_out.type_as(hidden_states)
51 |
52 |             query = apply_rotary_emb(query, rotary_emb)
53 |             key = apply_rotary_emb(key, rotary_emb)
54 |
55 |         hidden_states = F.scaled_dot_product_attention(
56 |             query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
57 |         )
58 |         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
59 |         hidden_states = hidden_states.type_as(query)
60 |
61 |         hidden_states = attn.to_out[0](hidden_states)
62 |         hidden_states = attn.to_out[1](hidden_states)
63 |         return hidden_states
64 |
65 | class TimeEmbedding(nn.Module):
66 |     def __init__(
67 |         self,
68 |         dim: int,
69 |         time_freq_dim: int,
70 |         time_proj_dim: int
71 |     ):
72 |         super().__init__()
73 |
74 |         self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
75 |         self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
76 |
77 |         self.act_fn = nn.SiLU()
78 |         self.time_proj = nn.Linear(dim, time_proj_dim)
79 |
80 |     def forward(
81 |         self,
82 |         timestep: torch.Tensor,
83 |     ):
84 |         timestep = self.timesteps_proj(timestep)
85 |
86 |         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
87 |         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
88 |             timestep = timestep.to(time_embedder_dtype)
89 |         temb = self.time_embedder(timestep).type_as(self.time_proj.weight.data)
90 |         timestep_proj = self.time_proj(self.act_fn(temb))
91 |
92 |         return temb, timestep_proj
93 |
94 |
95 | class RotaryPosEmbed(nn.Module):
96 |     def __init__(
97 |         self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
98 |     ):
99 |         super().__init__()
100 |
101 |         self.attention_head_dim = attention_head_dim
102 |         self.patch_size = patch_size
103 |         self.max_seq_len = max_seq_len
104 |
105 |         h_dim = w_dim = 2 * (attention_head_dim // 6)
106 |         t_dim = attention_head_dim - h_dim - w_dim
107 |
108 |         freqs = []
109 |         for dim in [t_dim, h_dim, w_dim]:
110 |             freq = get_1d_rotary_pos_embed(
111 |                 dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
112 |             )
113 |             freqs.append(freq)
114 |         self.freqs = torch.cat(freqs, dim=1)
115 |
116 |     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
117 |         batch_size, num_channels, num_frames, height, width = hidden_states.shape
118 |         p_t, p_h, p_w = self.patch_size
119 |         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
120 |
121 |         self.freqs = self.freqs.to(hidden_states.device)
122 |         freqs = self.freqs.split_with_sizes(
123 |             [
124 |                 self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
125 |                 self.attention_head_dim // 6,
126 |                 self.attention_head_dim // 6,
127 |             ],
128 |             dim=1,
129 |         )
130 |
131 |         freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
132 |         freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
133 |         freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
134 |         freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
135 |         return freqs
136 |
137 |
138 | class TransformerBlock(nn.Module):
139 |     def __init__(
140 |         self,
141 |         dim: int,
142 |         ffn_dim: int,
143 |         num_heads: int,
144 |         qk_norm: str = "rms_norm_across_heads",
145 |         cross_attn_norm: bool = False,
146 |         eps: float = 1e-6,
147 |         added_kv_proj_dim: Optional[int] = None,
148 |     ):
149 |         super().__init__()
150 |
151 |         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
152 |         self.attn1 = Attention(
153 |             query_dim=dim,
154 |             heads=num_heads,
155 |             kv_heads=num_heads,
156 |             dim_head=dim // num_heads,
157 |             qk_norm=qk_norm,
158 |             eps=eps,
159 |             bias=True,
160 |             cross_attention_dim=None,
161 |             out_bias=True,
162 |             processor=AttnProcessor2_0(),
163 |         )
164 |
165 |         self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
166 |         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=False)
167 |
168 |         self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
169 |
170 |     def forward(
171 |         self,
172 |         hidden_states: torch.Tensor,
173 |         temb: torch.Tensor,
174 |         rotary_emb: torch.Tensor,
175 |     ) -> torch.Tensor:
176 |         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
177 |             self.scale_shift_table + temb.float()
178 |         ).chunk(6, dim=1)
179 |
180 |         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
181 |         attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
182 |         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
183 |
184 |         norm_hidden_states = (self.norm2(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
185 |             hidden_states
186 |         )
187 |         ff_output = self.ffn(norm_hidden_states)
188 |         hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
189 |
190 |         return hidden_states
191 |
192 |
193 | class Transformer3DModel(ModelMixin, ConfigMixin):
194 |
195 |     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
196 |     _no_split_modules = ["TransformerBlock"]
197 |     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2"]
198 |
199 |     @register_to_config
200 |     def __init__(
201 |         self,
202 |         patch_size: Tuple[int] = (1, 2, 2),
203 |         num_attention_heads: int = 40,
204 |         attention_head_dim: int = 128,
205 |         in_channels: int = 16,
206 |         out_channels: int = 16,
207 |         freq_dim: int = 256,
208 |         ffn_dim: int = 13824,
209 |         num_layers: int = 40,
210 |         cross_attn_norm: bool = True,
211 |         qk_norm: Optional[str] = "rms_norm_across_heads",
212 |         eps: float = 1e-6,
213 |         added_kv_proj_dim: Optional[int] = None,
214 |         rope_max_seq_len: int = 1024
215 |     ) -> None:
216 |         super().__init__()
217 |
218 |         inner_dim = num_attention_heads * attention_head_dim
219 |         out_channels = out_channels or in_channels
220 |
221 |         # 1. Patch & position embedding
222 |         self.rope = RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
223 |         self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
224 |
225 |         # 2. Condition embeddings
226 |         self.condition_embedder = TimeEmbedding(
227 |             dim=inner_dim,
228 |             time_freq_dim=freq_dim,
229 |             time_proj_dim=inner_dim * 6,
230 |         )
231 |
232 |         # 3. Transformer blocks
233 |         self.blocks = nn.ModuleList(
234 |             [
235 |                 TransformerBlock(
236 |                     inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
237 |                 )
238 |                 for _ in range(num_layers)
239 |             ]
240 |         )
241 |
242 |         # 4. Output norm & projection
243 |         self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
244 |         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
245 |         self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
246 |
247 |     def forward(
248 |         self,
249 |         hidden_states: torch.Tensor,
250 |         timestep: torch.LongTensor
251 |     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
252 |         batch_size, num_channels, num_frames, height, width = hidden_states.shape
253 |         p_t, p_h, p_w = self.config.patch_size
254 |         post_patch_num_frames = num_frames // p_t
255 |         post_patch_height = height // p_h
256 |         post_patch_width = width // p_w
257 |
258 |         rotary_emb = self.rope(hidden_states)
259 |
260 |         hidden_states = self.patch_embedding(hidden_states)
261 |         hidden_states = hidden_states.flatten(2).transpose(1, 2)
262 |
263 |         temb, timestep_proj = self.condition_embedder(
264 |             timestep
265 |         )
266 |         timestep_proj = timestep_proj.unflatten(1, (6, -1))
267 |
268 |         for block in self.blocks:
269 |             hidden_states = block(hidden_states, timestep_proj, rotary_emb)
270 |
271 |         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
272 |         hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
273 |         hidden_states = self.proj_out(hidden_states)
274 |
275 |         hidden_states = hidden_states.reshape(
276 |             batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
277 |         )
278 |         hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
279 |         output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
280 |
281 |         return Transformer2DModelOutput(sample=output)
282 |
--------------------------------------------------------------------------------
/imgs/gradio_demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zibojia/MiniMax-Remover/28e12b450d8a72a7547b86940a4985e6ad90d75b/imgs/gradio_demo.gif
--------------------------------------------------------------------------------
/pipeline_minimax_remover.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Dict, List, Optional, Union
2 |
3 | import torch
4 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
5 | from diffusers.models import AutoencoderKLWan
6 | from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
7 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
8 | from diffusers.utils.torch_utils import randn_tensor
9 | from diffusers.video_processor import VideoProcessor
10 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
11 | from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
12 |
13 | import scipy
14 | import numpy as np
15 | import torch.nn.functional as F
16 | from transformer_minimax_remover import Transformer3DModel
17 | from einops import rearrange
18 |
19 | if is_torch_xla_available():
20 |     import torch_xla.core.xla_model as xm
21 |
22 |     XLA_AVAILABLE = True
23 | else:
24 |     XLA_AVAILABLE = False
25 |
26 | class Minimax_Remover_Pipeline(DiffusionPipeline):
27 |
28 |     model_cpu_offload_seq = "transformer->vae"
29 |     _callback_tensor_inputs = ["latents"]
30 |
31 |     def __init__(
32 |         self,
33 |         transformer: Transformer3DModel,
34 |         vae: AutoencoderKLWan,
35 |         scheduler: FlowMatchEulerDiscreteScheduler
36 |     ):
37 |         super().__init__()
38 |
39 |         self.register_modules(
40 |             vae=vae,
41 |             transformer=transformer,
42 |             scheduler=scheduler,
43 |         )
44 |
45 |         self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
46 |         self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
47 |         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
48 |
49 |     def prepare_latents(
50 |         self,
51 |         batch_size: int,
52 |         num_channels_latents: int = 16,
53 |         height: int = 720,
54 |         width: int = 1280,
55 |         num_latent_frames: int = 21,
56 |         dtype: Optional[torch.dtype] = None,
57 |         device: Optional[torch.device] = None,
58 |         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
59 |         latents: Optional[torch.Tensor] = None,
60 |     ) -> torch.Tensor:
61 |         if latents is not None:
62 |             return latents.to(device=device, dtype=dtype)
63 |
64 |         shape = (
65 |             batch_size,
66 |             num_channels_latents,
67 |             num_latent_frames,
68 |             int(height) // self.vae_scale_factor_spatial,
69 |             int(width) // self.vae_scale_factor_spatial,
70 |         )
71 |
72 |         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
73 |         return latents
74 |
75 |     def expand_masks(self, masks, iterations):
76 |         masks = masks.cpu().detach().numpy()
77 |         # numpy array, masks [0,1], f h w c
78 |         masks2 = []
79 |         for i in range(len(masks)):
80 |             mask = masks[i]
81 |             mask = mask > 0
82 |             mask = scipy.ndimage.binary_dilation(mask, iterations=iterations)
83 |             masks2.append(mask)
84 |         masks = np.array(masks2).astype(np.float32)
85 |         masks = torch.from_numpy(masks)
86 |         masks = masks.repeat(1,1,1,3)
87 |         masks = rearrange(masks, "f h w c -> c f h w")
88 |         masks = masks[None,...]
89 |         return masks
90 |
91 |     def resize(self, images, w, h):  # called as resize(x, height, width)
92 |         bsz,_,_,_,_ = images.shape
93 |         images = rearrange(images, "b c f w h -> (b f) c w h")
94 |         images = F.interpolate(images, (w,h), mode='bilinear')
95 |         images = rearrange(images, "(b f) c w h -> b c f w h", b=bsz)
96 |         return images
97 |
98 |     @property
99 |     def num_timesteps(self):
100 |         return self._num_timesteps
101 |
102 |     @property
103 |     def current_timestep(self):
104 |         return self._current_timestep
105 |
106 |     @property
107 |     def interrupt(self):
108 |         return self._interrupt
109 |
110 |     @torch.no_grad()
111 |     def __call__(
112 |         self,
113 |         height: int = 720,
114 |         width: int = 1280,
115 |         num_frames: int = 81,
116 |         num_inference_steps: int = 50,
117 |         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
118 |         images: Optional[torch.Tensor] = None,
119 |         masks: Optional[torch.Tensor] = None,
120 |         latents: Optional[torch.Tensor] = None,
121 |         output_type: Optional[str] = "np",
122 |         iterations: int = 16
123 |     ):
124 |
125 |         self._current_timestep = None
126 |         self._interrupt = False
127 |         device = self._execution_device
128 |         batch_size = 1
129 |         transformer_dtype = torch.float16
130 |
131 |         self.scheduler.set_timesteps(num_inference_steps, device=device)
132 |         timesteps = self.scheduler.timesteps
133 |
134 |         num_channels_latents = 16
135 |         num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
136 |
137 |         latents = self.prepare_latents(
138 |             batch_size,
139 |             num_channels_latents,
140 |             height,
141 |             width,
142 |             num_latent_frames,
143 |             torch.float16,
144 |             device,
145 |             generator,
146 |             latents,
147 |         )
148 |
149 |         masks = self.expand_masks(masks, iterations)  # dilate masks so the object boundary is fully covered
150 |         masks = self.resize(masks, height, width).to(device).half()
151 |         masks[masks>0] = 1
152 |         images = rearrange(images, "f h w c -> c f h w")
153 |         images = self.resize(images[None,...], height, width).to(device).half()
154 |
155 |         masked_images = images * (1-masks)
156 |
157 |         latents_mean = (
158 |             torch.tensor(self.vae.config.latents_mean)
159 |             .view(1, self.vae.config.z_dim, 1, 1, 1)
160 |             .to(self.vae.device, torch.float16)
161 |         )
162 |
163 |         latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
164 |             self.vae.device, torch.float16
165 |         )
166 |
167 |         with torch.no_grad():
168 |             masked_latents = self.vae.encode(masked_images.half()).latent_dist.mode()
169 |             masks_latents = self.vae.encode(2*masks.half()-1.0).latent_dist.mode()
170 |
171 |         masked_latents = (masked_latents - latents_mean) * latents_std
172 |         masks_latents = (masks_latents - latents_mean) * latents_std
173 |
174 |         self._num_timesteps = len(timesteps)
175 |
176 |         with self.progress_bar(total=num_inference_steps) as progress_bar:
177 |             for i, t in enumerate(timesteps):
178 |
179 |                 latent_model_input = latents.to(transformer_dtype)
180 |
181 |                 latent_model_input = torch.cat([latent_model_input, masked_latents, masks_latents], dim=1)  # concat noisy, masked-video and mask latents along channels
182 |                 timestep = t.expand(latents.shape[0])
183 |
184 |                 noise_pred = self.transformer(
185 |                     hidden_states=latent_model_input.half(),
186 |                     timestep=timestep
187 |                 )[0]
188 |
189 |                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
190 |
191 |                 progress_bar.update()
192 |
193 |         latents = latents.half() / latents_std + latents_mean  # un-normalize latents before decoding
194 |         video = self.vae.decode(latents, return_dict=False)[0]
195 |         video = self.video_processor.postprocess_video(video, output_type=output_type)
196 |
197 |         return WanPipelineOutput(frames=video)
198 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.7.1
2 | torchvision==0.22.1
3 | decord==0.6
4 | diffusers==0.33.1
5 | Pillow==9.2
6 | einops==0.8.0
7 | scipy==1.14.0
8 | numpy==1.26.4
9 | opencv-python==4.10.0.84
10 | gradio_client==0.7.0
11 | huggingface_hub==0.32.4
12 | omegaconf
13 | einops
14 | accelerate==0.30.1
15 | gradio==3.40.0
16 | moviepy==1.0.3
17 | hydra-core
18 | pytest
--------------------------------------------------------------------------------
/test_minimax_remover.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers.utils import export_to_video
3 | from decord import VideoReader
4 | from diffusers.models import AutoencoderKLWan
5 | from transformer_minimax_remover import Transformer3DModel
6 | from diffusers.schedulers import UniPCMultistepScheduler
7 | from pipeline_minimax_remover import Minimax_Remover_Pipeline
8 |
9 | random_seed = 42
10 | video_length = 81
11 | device = torch.device("cuda:0")
12 |
13 | vae = AutoencoderKLWan.from_pretrained("./vae", torch_dtype=torch.float16)
14 | transformer = Transformer3DModel.from_pretrained("./transformer", torch_dtype=torch.float16)
15 | scheduler = UniPCMultistepScheduler.from_pretrained("./scheduler")
16 |
17 | pipe = Minimax_Remover_Pipeline(transformer=transformer, vae=vae, scheduler=scheduler)
18 | pipe.to(device)
19 |
20 | # `iterations` is the hyperparameter for mask dilation
21 | def inference(pixel_values, masks, iterations=6):
22 |     video = pipe(
23 |         images=pixel_values,
24 |         masks=masks,
25 |         num_frames=video_length,
26 |         height=480,
27 |         width=832,
28 |         num_inference_steps=12,
29 |         generator=torch.Generator(device=device).manual_seed(random_seed),
30 |         iterations=iterations
31 |     ).frames[0]
32 |     export_to_video(video, "./output.mp4")
33 |
34 | def load_video(video_path):
35 |     vr = VideoReader(video_path)
36 |     images = vr.get_batch(list(range(video_length))).asnumpy()
37 |     images = torch.from_numpy(images)/127.5 - 1.0
38 |     return images
39 |
40 | def load_mask(mask_path):
41 |     vr = VideoReader(mask_path)
42 |     masks = vr.get_batch(list(range(video_length))).asnumpy()
43 |     masks = torch.from_numpy(masks)
44 |     masks = masks[:, :, :, :1]
45 |     masks[masks > 20] = 255
46 |     masks[masks < 255] = 0
47 |     masks = masks / 255.0
48 |     return masks
49 |
50 | video_path = "./video.mp4"
51 | mask_path = "./mask.mp4"
52 |
53 | images = load_video(video_path)
54 | masks = load_mask(mask_path)
55 |
56 | inference(images, masks)
--------------------------------------------------------------------------------
/transformer_minimax_remover.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Dict, Optional, Tuple, Union
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from diffusers.configuration_utils import ConfigMixin, register_to_config
9 | from diffusers.utils import logging
10 | from diffusers.models.attention import FeedForward
11 | from diffusers.models.attention_processor import Attention
12 | from diffusers.models.embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
13 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
14 | from diffusers.models.modeling_utils import ModelMixin
15 | from diffusers.models.normalization import FP32LayerNorm
16 |
17 | class AttnProcessor2_0:
18 |     def __init__(self):
19 |         if not hasattr(F, "scaled_dot_product_attention"):
20 |             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
21 |
22 |     def __call__(
23 |         self,
24 |         attn: Attention,
25 |         hidden_states: torch.Tensor,
26 |         rotary_emb: Optional[torch.Tensor] = None,
27 |         attention_mask: Optional[torch.Tensor] = None,
28 |         encoder_hidden_states: Optional[torch.Tensor] = None
29 |     ) -> torch.Tensor:
30 |
31 |         encoder_hidden_states = hidden_states
32 |         query = attn.to_q(hidden_states)
33 |         key = attn.to_k(encoder_hidden_states)
34 |         value = attn.to_v(encoder_hidden_states)
35 |
36 |         if attn.norm_q is not None:
37 |             query = attn.norm_q(query)
38 |         if attn.norm_k is not None:
39 |             key = attn.norm_k(key)
40 |
41 |         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
42 |         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
43 |         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
44 |
45 |         if rotary_emb is not None:
46 |
47 |             def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
48 |                 x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
49 |                 x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
50 |                 return x_out.type_as(hidden_states)
51 |
52 |             query = apply_rotary_emb(query, rotary_emb)
53 |             key = apply_rotary_emb(key, rotary_emb)
54 |
55 |         hidden_states = F.scaled_dot_product_attention(
56 |             query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
57 |         )
58 |         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
59 |         hidden_states = hidden_states.type_as(query)
60 |
61 |         hidden_states = attn.to_out[0](hidden_states)
62 |         hidden_states = attn.to_out[1](hidden_states)
63 |         return hidden_states
64 |
65 | class TimeEmbedding(nn.Module):
66 |     def __init__(
67 |         self,
68 |         dim: int,
69 |         time_freq_dim: int,
70 |         time_proj_dim: int
71 |     ):
72 |         super().__init__()
73 |
74 |         self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
75 |         self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
76 |
77 |         self.act_fn = nn.SiLU()
78 |         self.time_proj = nn.Linear(dim, time_proj_dim)
79 |
80 |     def forward(
81 |         self,
82 |         timestep: torch.Tensor,
83 |     ):
84 |         timestep = self.timesteps_proj(timestep)
85 |
86 |         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
87 |         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
88 |             timestep = timestep.to(time_embedder_dtype)
89 |         temb = self.time_embedder(timestep).type_as(self.time_proj.weight.data)
90 |         timestep_proj = self.time_proj(self.act_fn(temb))
91 |
92 |         return temb, timestep_proj
93 |
94 |
95 | class RotaryPosEmbed(nn.Module):
96 |     def __init__(
97 |         self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
98 |     ):
99 |         super().__init__()
100 |
101 |         self.attention_head_dim = attention_head_dim
102 |         self.patch_size = patch_size
103 |         self.max_seq_len = max_seq_len
104 |
105 |         h_dim = w_dim = 2 * (attention_head_dim // 6)
106 |         t_dim = attention_head_dim - h_dim - w_dim
107 |
108 |         freqs = []
109 |         for dim in [t_dim, h_dim, w_dim]:
110 |             freq = get_1d_rotary_pos_embed(
111 |                 dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
112 |             )
113 |             freqs.append(freq)
114 |         self.freqs = torch.cat(freqs, dim=1)
115 |
116 |     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
117 |         batch_size, num_channels, num_frames, height, width = hidden_states.shape
118 |         p_t, p_h, p_w = self.patch_size
119 |         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
120 |
121 |         self.freqs = self.freqs.to(hidden_states.device)
122 |         freqs = self.freqs.split_with_sizes(
123 |             [
124 |                 self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
125 |                 self.attention_head_dim // 6,
126 |                 self.attention_head_dim // 6,
127 |             ],
128 |             dim=1,
129 |         )
130 |
131 |         freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
132 |         freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
133 |         freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
134 |         freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
135 |         return freqs
136 |
137 |
138 | class TransformerBlock(nn.Module):
139 |     def __init__(
140 |         self,
141 |         dim: int,
142 |         ffn_dim: int,
143 |         num_heads: int,
144 |         qk_norm: str = "rms_norm_across_heads",
145 |         cross_attn_norm: bool = False,
146 |         eps: float = 1e-6,
147 |         added_kv_proj_dim: Optional[int] = None,
148 |     ):
149 |         super().__init__()
150 |
151 |         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
152 |         self.attn1 = Attention(
153 |             query_dim=dim,
154 |             heads=num_heads,
155 |             kv_heads=num_heads,
156 |             dim_head=dim // num_heads,
157 |             qk_norm=qk_norm,
158 |             eps=eps,
159 |             bias=True,
160 |             cross_attention_dim=None,
161 |             out_bias=True,
162 |             processor=AttnProcessor2_0(),
163 |         )
164 |
165 |         self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
166 |         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=False)
167 |
168 |         self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
169 |
170 |     def forward(
171 |         self,
172 |         hidden_states: torch.Tensor,
173 |         temb: torch.Tensor,
174 |         rotary_emb: torch.Tensor,
175 |     ) -> torch.Tensor:
176 |         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
177 |             self.scale_shift_table + temb.float()
178 |         ).chunk(6, dim=1)
179 |
180 |         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
181 |         attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
182 |         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
183 |
184 |         norm_hidden_states = (self.norm2(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
185 |             hidden_states
186 |         )
187 |         ff_output = self.ffn(norm_hidden_states)
188 |         hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
189 |
190 |         return hidden_states
191 |
192 |
193 | class Transformer3DModel(ModelMixin, ConfigMixin):
194 |
195 |     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
196 |     _no_split_modules = ["TransformerBlock"]
197 |     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2"]
198 |
199 |     @register_to_config
200 |     def __init__(
201 |         self,
202 |         patch_size: Tuple[int] = (1, 2, 2),
203 |         num_attention_heads: int = 40,
204 |         attention_head_dim: int = 128,
205 |         in_channels: int = 16,
206 |         out_channels: int = 16,
207 |         freq_dim: int = 256,
208 |         ffn_dim: int = 13824,
209 |         num_layers: int = 40,
210 |         cross_attn_norm: bool = True,
211 |         qk_norm: Optional[str] = "rms_norm_across_heads",
212 |         eps: float = 1e-6,
213 |         added_kv_proj_dim: Optional[int] = None,
214 |         rope_max_seq_len: int = 1024
215 |     ) -> None:
216 |         super().__init__()
217 |
218 |         inner_dim = num_attention_heads * attention_head_dim
219 |         out_channels = out_channels or in_channels
220 |
221 |         # 1. Patch & position embedding
222 |         self.rope = RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
223 |         self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
224 |
225 |         # 2. Condition embeddings
226 |         self.condition_embedder = TimeEmbedding(
227 |             dim=inner_dim,
228 |             time_freq_dim=freq_dim,
229 |             time_proj_dim=inner_dim * 6,
230 |         )
231 |
232 |         # 3. Transformer blocks
233 |         self.blocks = nn.ModuleList(
234 |             [
235 |                 TransformerBlock(
236 |                     inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
237 |                 )
238 |                 for _ in range(num_layers)
239 |             ]
240 |         )
241 |
242 |         # 4. Output norm & projection
243 |         self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
244 |         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
245 |         self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
246 |
247 |     def forward(
248 |         self,
249 |         hidden_states: torch.Tensor,
250 |         timestep: torch.LongTensor
251 |     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
252 |         batch_size, num_channels, num_frames, height, width = hidden_states.shape
253 |         p_t, p_h, p_w = self.config.patch_size
254 |         post_patch_num_frames = num_frames // p_t
255 |         post_patch_height = height // p_h
256 |         post_patch_width = width // p_w
257 |
258 |         rotary_emb = self.rope(hidden_states)
259 |
260 |         hidden_states = self.patch_embedding(hidden_states)
261 |         hidden_states = hidden_states.flatten(2).transpose(1, 2)
262 |
263 |         temb, timestep_proj = self.condition_embedder(
264 |             timestep
265 |         )
266 |         timestep_proj = timestep_proj.unflatten(1, (6, -1))
267 |
268 |         for block in self.blocks:
269 |             hidden_states = block(hidden_states, timestep_proj, rotary_emb)
270 |
271 |         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
272 |         hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
273 |         hidden_states = self.proj_out(hidden_states)
274 |
275 |         hidden_states = hidden_states.reshape(
276 |             batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
277 |         )
278 |         hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
279 |         output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
280 |
281 |         return Transformer2DModelOutput(sample=output)
282 |
--------------------------------------------------------------------------------