├── README.md ├── __asset__ ├── COCOCO.PNG ├── DEMO.PNG ├── bocchi1.gif ├── bocchi2.gif ├── gibuli.gif ├── gibuli_lora_org.gif ├── gibuli_merged1.gif ├── gibuli_merged2.gif ├── mountain1.gif ├── mountain2.gif ├── mountain_org.gif ├── play-in-background.gif ├── river1.gif ├── river2.gif ├── river_org.gif ├── sea1.gif ├── sea2.gif ├── sea_org.gif ├── sky1.gif ├── sky2.gif ├── sky_org.gif ├── task.PNG ├── unmbrella1.gif ├── unmbrella2.gif └── unmbrella_org.gif ├── app.py ├── app_with_T2I_LoRA.py ├── cococo ├── models │ ├── attention.py │ ├── motion_module.py │ ├── resnet.py │ ├── unet.py │ └── unet_blocks.py ├── pipelines │ └── pipeline_animation_inpainting_cross_attention_vae.py └── utils │ ├── convert_from_ckpt.py │ ├── convert_lora_safetensor_to_diffusers.py │ └── util.py ├── configs └── code_release.yaml ├── dist └── .gitattributes ├── images ├── images.npy └── masks.npy ├── outputs └── .gitattributes ├── requirements.txt ├── sam2 ├── __init__.py ├── automatic_mask_generator.py ├── build_sam.py ├── csrc │ └── connected_components.cu ├── modeling │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── hieradet.py │ │ ├── image_encoder.py │ │ └── utils.py │ ├── memory_attention.py │ ├── memory_encoder.py │ ├── position_encoding.py │ ├── sam │ │ ├── __init__.py │ │ ├── mask_decoder.py │ │ ├── prompt_encoder.py │ │ └── transformer.py │ ├── sam2_base.py │ └── sam2_utils.py ├── sam2_image_predictor.py ├── sam2_video_predictor.py └── utils │ ├── __init__.py │ ├── amg.py │ ├── misc.py │ ├── test.txt │ └── transforms.py ├── sam2_configs ├── __init__.py ├── sam2_hiera_b+.yaml ├── sam2_hiera_l.yaml ├── sam2_hiera_s.yaml └── sam2_hiera_t.yaml ├── sav_dataset ├── LICENSE ├── LICENSE_DAVIS ├── LICENSE_VOS_BENCHMARK ├── README.md ├── requirements.txt ├── sav_evaluator.py └── utils │ ├── sav_benchmark.py │ └── sav_utils.py ├── setup.py ├── task_vector ├── convert.py ├── convert_lora.py ├── lora.json └── resources │ ├── source.txt │ ├── target.txt │ ├── text_source.txt │ ├── text_target.txt │ ├── vae_source.txt │ └── vae_target.txt ├── tools ├── README.md └── vos_inference.py ├── utils.py ├── utils_with_T2I_LoRA.py ├── valid_code_release.py └── valid_code_release_with_T2I_LoRA.py /__asset__/COCOCO.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/COCOCO.PNG -------------------------------------------------------------------------------- /__asset__/DEMO.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/DEMO.PNG -------------------------------------------------------------------------------- /__asset__/bocchi1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/bocchi1.gif -------------------------------------------------------------------------------- /__asset__/bocchi2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/bocchi2.gif -------------------------------------------------------------------------------- /__asset__/gibuli.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/gibuli.gif -------------------------------------------------------------------------------- /__asset__/gibuli_lora_org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/gibuli_lora_org.gif -------------------------------------------------------------------------------- /__asset__/gibuli_merged1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/gibuli_merged1.gif -------------------------------------------------------------------------------- /__asset__/gibuli_merged2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/gibuli_merged2.gif -------------------------------------------------------------------------------- /__asset__/mountain1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/mountain1.gif -------------------------------------------------------------------------------- /__asset__/mountain2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/mountain2.gif -------------------------------------------------------------------------------- /__asset__/mountain_org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/mountain_org.gif -------------------------------------------------------------------------------- /__asset__/play-in-background.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/play-in-background.gif -------------------------------------------------------------------------------- /__asset__/river1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/river1.gif -------------------------------------------------------------------------------- /__asset__/river2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/river2.gif -------------------------------------------------------------------------------- /__asset__/river_org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/river_org.gif -------------------------------------------------------------------------------- /__asset__/sea1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/sea1.gif -------------------------------------------------------------------------------- 
/__asset__/sea2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/sea2.gif -------------------------------------------------------------------------------- /__asset__/sea_org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/sea_org.gif -------------------------------------------------------------------------------- /__asset__/sky1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/sky1.gif -------------------------------------------------------------------------------- /__asset__/sky2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/sky2.gif -------------------------------------------------------------------------------- /__asset__/sky_org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/sky_org.gif -------------------------------------------------------------------------------- /__asset__/task.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/task.PNG -------------------------------------------------------------------------------- /__asset__/unmbrella1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/unmbrella1.gif -------------------------------------------------------------------------------- /__asset__/unmbrella2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/unmbrella2.gif -------------------------------------------------------------------------------- /__asset__/unmbrella_org.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/__asset__/unmbrella_org.gif -------------------------------------------------------------------------------- /cococo/models/resnet.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from einops import rearrange 8 | 9 | 10 | class InflatedConv3d(nn.Conv2d): 11 | def forward(self, x): 12 | video_length = x.shape[2] 13 | 14 | x = rearrange(x, "b c f h w -> (b f) c h w") 15 | x = super().forward(x) 16 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 17 | 18 | return x 19 | 20 | 21 | class InflatedGroupNorm(nn.GroupNorm): 22 | def forward(self, x): 23 | video_length = x.shape[2] 24 | 25 | x = rearrange(x, "b c f h w -> (b f) c h w") 26 | x = super().forward(x) 27 | x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) 28 | 29 | 
return x 30 | 31 | 32 | class Upsample3D(nn.Module): 33 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 34 | super().__init__() 35 | self.channels = channels 36 | self.out_channels = out_channels or channels 37 | self.use_conv = use_conv 38 | self.use_conv_transpose = use_conv_transpose 39 | self.name = name 40 | 41 | conv = None 42 | if use_conv_transpose: 43 | raise NotImplementedError 44 | elif use_conv: 45 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) 46 | 47 | def forward(self, hidden_states, output_size=None): 48 | assert hidden_states.shape[1] == self.channels 49 | 50 | if self.use_conv_transpose: 51 | raise NotImplementedError 52 | 53 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 54 | dtype = hidden_states.dtype 55 | if dtype == torch.bfloat16: 56 | hidden_states = hidden_states.to(torch.float32) 57 | 58 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 59 | if hidden_states.shape[0] >= 64: 60 | hidden_states = hidden_states.contiguous() 61 | 62 | # if `output_size` is passed we force the interpolation output 63 | # size and do not make use of `scale_factor=2` 64 | if output_size is None: 65 | hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") 66 | else: 67 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 68 | 69 | # If the input is bfloat16, we cast back to bfloat16 70 | if dtype == torch.bfloat16: 71 | hidden_states = hidden_states.to(dtype) 72 | 73 | # if self.use_conv: 74 | # if self.name == "conv": 75 | # hidden_states = self.conv(hidden_states) 76 | # else: 77 | # hidden_states = self.Conv2d_0(hidden_states) 78 | hidden_states = self.conv(hidden_states) 79 | 80 | return hidden_states 81 | 82 | 83 | class Downsample3D(nn.Module): 84 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 85 | super().__init__() 86 | self.channels = channels 87 | self.out_channels = out_channels or channels 88 | self.use_conv = use_conv 89 | self.padding = padding 90 | stride = 2 91 | self.name = name 92 | 93 | if use_conv: 94 | self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 95 | else: 96 | raise NotImplementedError 97 | 98 | def forward(self, hidden_states): 99 | assert hidden_states.shape[1] == self.channels 100 | if self.use_conv and self.padding == 0: 101 | raise NotImplementedError 102 | 103 | assert hidden_states.shape[1] == self.channels 104 | hidden_states = self.conv(hidden_states) 105 | 106 | return hidden_states 107 | 108 | 109 | class ResnetBlock3D(nn.Module): 110 | def __init__( 111 | self, 112 | *, 113 | in_channels, 114 | out_channels=None, 115 | conv_shortcut=False, 116 | dropout=0.0, 117 | temb_channels=512, 118 | groups=32, 119 | groups_out=None, 120 | pre_norm=True, 121 | eps=1e-6, 122 | non_linearity="swish", 123 | time_embedding_norm="default", 124 | output_scale_factor=1.0, 125 | use_in_shortcut=None, 126 | use_inflated_groupnorm=False, 127 | ): 128 | super().__init__() 129 | self.pre_norm = pre_norm 130 | self.pre_norm = True 131 | self.in_channels = in_channels 132 | out_channels = in_channels if out_channels is None else out_channels 133 | self.out_channels = out_channels 134 | self.use_conv_shortcut = conv_shortcut 135 | self.time_embedding_norm = time_embedding_norm 136 | self.output_scale_factor = output_scale_factor 137 | 138 
| if groups_out is None: 139 | groups_out = groups 140 | 141 | assert use_inflated_groupnorm != None 142 | if use_inflated_groupnorm: 143 | self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 144 | else: 145 | self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 146 | 147 | self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 148 | 149 | if temb_channels is not None: 150 | if self.time_embedding_norm == "default": 151 | time_emb_proj_out_channels = out_channels 152 | elif self.time_embedding_norm == "scale_shift": 153 | time_emb_proj_out_channels = out_channels * 2 154 | else: 155 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 156 | 157 | self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) 158 | else: 159 | self.time_emb_proj = None 160 | 161 | if use_inflated_groupnorm: 162 | self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 163 | else: 164 | self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 165 | 166 | self.dropout = torch.nn.Dropout(dropout) 167 | self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 168 | 169 | if non_linearity == "swish": 170 | self.nonlinearity = lambda x: F.silu(x) 171 | elif non_linearity == "mish": 172 | self.nonlinearity = Mish() 173 | elif non_linearity == "silu": 174 | self.nonlinearity = nn.SiLU() 175 | 176 | self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut 177 | 178 | self.conv_shortcut = None 179 | if self.use_in_shortcut: 180 | self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) 181 | 182 | def forward(self, input_tensor, temb): 183 | hidden_states = input_tensor 184 | 185 | hidden_states = self.norm1(hidden_states) 186 | hidden_states = self.nonlinearity(hidden_states) 187 | 188 | hidden_states = self.conv1(hidden_states) 189 | 190 | if temb is not None: 191 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] 192 | 193 | if temb is not None and self.time_embedding_norm == "default": 194 | hidden_states = hidden_states + temb 195 | 196 | hidden_states = self.norm2(hidden_states) 197 | 198 | if temb is not None and self.time_embedding_norm == "scale_shift": 199 | scale, shift = torch.chunk(temb, 2, dim=1) 200 | hidden_states = hidden_states * (1 + scale) + shift 201 | 202 | hidden_states = self.nonlinearity(hidden_states) 203 | 204 | hidden_states = self.dropout(hidden_states) 205 | hidden_states = self.conv2(hidden_states) 206 | 207 | if self.conv_shortcut is not None: 208 | input_tensor = self.conv_shortcut(input_tensor) 209 | 210 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 211 | 212 | return output_tensor 213 | 214 | 215 | class Mish(torch.nn.Module): 216 | def forward(self, hidden_states): 217 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) -------------------------------------------------------------------------------- /cococo/utils/convert_lora_safetensor_to_diffusers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # 16 | # Changes were made to this source code by Yuwei Guo. 17 | """ Conversion script for the LoRA's safetensors checkpoints. """ 18 | 19 | import argparse 20 | 21 | import torch 22 | from safetensors.torch import load_file 23 | 24 | from diffusers import StableDiffusionPipeline 25 | 26 | 27 | def load_diffusers_lora(pipeline, state_dict, alpha=1.0): 28 | # directly update weight in diffusers model 29 | for key in state_dict: 30 | # only process lora down key 31 | if "up." in key: continue 32 | 33 | up_key = key.replace(".down.", ".up.") 34 | model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") 35 | model_key = model_key.replace("to_out.", "to_out.0.") 36 | layer_infos = model_key.split(".")[:-1] 37 | 38 | curr_layer = pipeline.unet 39 | while len(layer_infos) > 0: 40 | temp_name = layer_infos.pop(0) 41 | curr_layer = curr_layer.__getattr__(temp_name) 42 | 43 | weight_down = state_dict[key] 44 | weight_up = state_dict[up_key] 45 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 46 | 47 | return pipeline 48 | 49 | 50 | def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): 51 | # load base model 52 | # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 53 | 54 | # load LoRA weight from .safetensors 55 | # state_dict = load_file(checkpoint_path) 56 | 57 | visited = [] 58 | 59 | # directly update weight in diffusers model 60 | for key in state_dict: 61 | # it is suggested to print out the key, it usually will be something like below 62 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 63 | 64 | # as we have set the alpha beforehand, so just skip 65 | if ".alpha" in key or key in visited: 66 | continue 67 | 68 | if "text" in key: 69 | layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 70 | curr_layer = pipeline.text_encoder 71 | else: 72 | layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") 73 | curr_layer = pipeline.unet 74 | 75 | # find the target layer 76 | temp_name = layer_infos.pop(0) 77 | while len(layer_infos) > -1: 78 | try: 79 | curr_layer = curr_layer.__getattr__(temp_name) 80 | if len(layer_infos) > 0: 81 | temp_name = layer_infos.pop(0) 82 | elif len(layer_infos) == 0: 83 | break 84 | except Exception: 85 | if len(temp_name) > 0: 86 | temp_name += "_" + layer_infos.pop(0) 87 | else: 88 | temp_name = layer_infos.pop(0) 89 | 90 | pair_keys = [] 91 | if "lora_down" in key: 92 | pair_keys.append(key.replace("lora_down", "lora_up")) 93 | pair_keys.append(key) 94 | else: 95 | pair_keys.append(key) 96 | pair_keys.append(key.replace("lora_up", "lora_down")) 97 | 98 | # update weight 99 | if len(state_dict[pair_keys[0]].shape) == 4: 100 | weight_up = 
state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 101 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 102 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) 103 | else: 104 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 105 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 106 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) 107 | 108 | # update visited list 109 | for item in pair_keys: 110 | visited.append(item) 111 | 112 | return pipeline 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | 118 | parser.add_argument( 119 | "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." 120 | ) 121 | parser.add_argument( 122 | "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." 123 | ) 124 | parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") 125 | parser.add_argument( 126 | "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" 127 | ) 128 | parser.add_argument( 129 | "--lora_prefix_text_encoder", 130 | default="lora_te", 131 | type=str, 132 | help="The prefix of text encoder weight in safetensors", 133 | ) 134 | parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") 135 | parser.add_argument( 136 | "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." 137 | ) 138 | parser.add_argument("--device", type=str, help="Device to use (e.g. 
cpu, cuda:0, cuda:1, etc.)") 139 | 140 | args = parser.parse_args() 141 | 142 | base_model_path = args.base_model_path 143 | checkpoint_path = args.checkpoint_path 144 | dump_path = args.dump_path 145 | lora_prefix_unet = args.lora_prefix_unet 146 | lora_prefix_text_encoder = args.lora_prefix_text_encoder 147 | alpha = args.alpha 148 | 149 | pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) 150 | 151 | pipe = pipe.to(args.device) 152 | pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) 153 | -------------------------------------------------------------------------------- /cococo/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | import torch.distributed as dist 9 | 10 | from safetensors import safe_open 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | from cococo.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 14 | from cococo.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora 15 | 16 | 17 | def zero_rank_print(s): 18 | if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) 19 | 20 | def save_videos_grid2(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 21 | #videos = rearrange(videos, "b c t h w -> t b c h w") 22 | videos = rearrange(videos, "b t c h w -> t b c h w") 23 | outputs = [] 24 | for x in videos: 25 | x = torchvision.utils.make_grid(x, nrow=n_rows) 26 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 27 | if rescale: 28 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 29 | x = torch.clamp(x, 0, 1) 30 | x = (x * 255).numpy().astype(np.uint8) 31 | #print(x.shape) 32 | #x = x[:,:,::-1] 33 | outputs.append(x) 34 | 35 | os.makedirs(os.path.dirname(path), exist_ok=True) 36 | imageio.mimsave(path, outputs, fps=fps) 37 | 38 | def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): 39 | videos = rearrange(videos, "b c t h w -> t b c h w") 40 | #videos = rearrange(videos, "b t c h w -> t b c h w") 41 | outputs = [] 42 | for x in videos: 43 | x = torchvision.utils.make_grid(x, nrow=n_rows) 44 | x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) 45 | if rescale: 46 | x = (x + 1.0) / 2.0 # -1,1 -> 0,1 47 | x = torch.clamp(x, 0, 1) 48 | x = (x * 255).numpy().astype(np.uint8) 49 | outputs.append(x) 50 | 51 | os.makedirs(os.path.dirname(path), exist_ok=True) 52 | imageio.mimsave(path, outputs, fps=fps) 53 | 54 | 55 | # DDIM Inversion 56 | @torch.no_grad() 57 | def init_prompt(prompt, pipeline): 58 | uncond_input = pipeline.tokenizer( 59 | [""], padding="max_length", max_length=pipeline.tokenizer.model_max_length, 60 | return_tensors="pt" 61 | ) 62 | uncond_embeddings = pipeline.text_encoder(uncond_input.input_ids.to(pipeline.device))[0] 63 | text_input = pipeline.tokenizer( 64 | [prompt], 65 | padding="max_length", 66 | max_length=pipeline.tokenizer.model_max_length, 67 | truncation=True, 68 | return_tensors="pt", 69 | ) 70 | text_embeddings = pipeline.text_encoder(text_input.input_ids.to(pipeline.device))[0] 71 | context = torch.cat([uncond_embeddings, text_embeddings]) 72 | 73 | return context 74 | 75 | 76 | def next_step(model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 77 | sample: Union[torch.FloatTensor, np.ndarray], 
ddim_scheduler): 78 | timestep, next_timestep = min( 79 | timestep - ddim_scheduler.config.num_train_timesteps // ddim_scheduler.num_inference_steps, 999), timestep 80 | alpha_prod_t = ddim_scheduler.alphas_cumprod[timestep] if timestep >= 0 else ddim_scheduler.final_alpha_cumprod 81 | alpha_prod_t_next = ddim_scheduler.alphas_cumprod[next_timestep] 82 | beta_prod_t = 1 - alpha_prod_t 83 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 84 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 85 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 86 | return next_sample 87 | 88 | 89 | def get_noise_pred_single(latents, t, context, unet): 90 | noise_pred = unet(latents, t, encoder_hidden_states=context)["sample"] 91 | return noise_pred 92 | 93 | 94 | @torch.no_grad() 95 | def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt): 96 | context = init_prompt(prompt, pipeline) 97 | uncond_embeddings, cond_embeddings = context.chunk(2) 98 | all_latent = [latent] 99 | latent = latent.clone().detach() 100 | for i in tqdm(range(num_inv_steps)): 101 | t = ddim_scheduler.timesteps[len(ddim_scheduler.timesteps) - i - 1] 102 | noise_pred = get_noise_pred_single(latent, t, cond_embeddings, pipeline.unet) 103 | latent = next_step(noise_pred, t, latent, ddim_scheduler) 104 | all_latent.append(latent) 105 | return all_latent 106 | 107 | 108 | @torch.no_grad() 109 | def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""): 110 | ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) 111 | return ddim_latents 112 | 113 | def load_weights( 114 | animation_pipeline, 115 | # motion module 116 | motion_module_path = "", 117 | motion_module_lora_configs = [], 118 | # domain adapter 119 | adapter_lora_path = "", 120 | adapter_lora_scale = 1.0, 121 | # image layers 122 | dreambooth_model_path = "", 123 | lora_model_path = "", 124 | lora_alpha = 0.8, 125 | ): 126 | # motion module 127 | unet_state_dict = {} 128 | if motion_module_path != "": 129 | print(f"load motion module from {motion_module_path}") 130 | motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") 131 | motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict 132 | unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) 133 | unet_state_dict.pop("animatediff_config", "") 134 | 135 | missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) 136 | assert len(unexpected) == 0 137 | del unet_state_dict 138 | 139 | # base model 140 | if dreambooth_model_path != "": 141 | print(f"load dreambooth model from {dreambooth_model_path}") 142 | if dreambooth_model_path.endswith(".safetensors"): 143 | dreambooth_state_dict = {} 144 | with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: 145 | for key in f.keys(): 146 | dreambooth_state_dict[key] = f.get_tensor(key) 147 | elif dreambooth_model_path.endswith(".ckpt"): 148 | dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") 149 | 150 | # 1. vae 151 | converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) 152 | animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) 153 | # 2. 
unet 154 | converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) 155 | animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) 156 | # 3. text_model 157 | animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) 158 | del dreambooth_state_dict 159 | 160 | # lora layers 161 | if lora_model_path != "": 162 | print(f"load lora model from {lora_model_path}") 163 | assert lora_model_path.endswith(".safetensors") 164 | lora_state_dict = {} 165 | with safe_open(lora_model_path, framework="pt", device="cpu") as f: 166 | for key in f.keys(): 167 | lora_state_dict[key] = f.get_tensor(key) 168 | 169 | animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) 170 | del lora_state_dict 171 | 172 | # domain adapter lora 173 | if adapter_lora_path != "": 174 | print(f"load domain lora from {adapter_lora_path}") 175 | domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") 176 | domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict 177 | domain_lora_state_dict.pop("animatediff_config", "") 178 | 179 | animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale) 180 | 181 | # motion module lora 182 | for motion_module_lora_config in motion_module_lora_configs: 183 | path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] 184 | print(f"load motion LoRA from {path}") 185 | motion_lora_state_dict = torch.load(path, map_location="cpu") 186 | motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict 187 | motion_lora_state_dict.pop("animatediff_config", "") 188 | 189 | animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha) 190 | 191 | return animation_pipeline 192 | -------------------------------------------------------------------------------- /configs/code_release.yaml: -------------------------------------------------------------------------------- 1 | unet_additional_kwargs: 2 | use_inflated_groupnorm: true 3 | unet_use_cross_frame_attention: false 4 | unet_use_temporal_attention: false 5 | use_motion_module: true 6 | motion_module_resolutions: [1,2,4,8] 7 | motion_module_mid_block: true 8 | motion_module_decoder_only: false 9 | 10 | motion_module_type: Vanilla 11 | motion_module_kwargs: 12 | num_attention_heads: 8 13 | num_transformer_block: 1 14 | attention_block_types: [ "Temporal_Self", "Temporal_Self", "Temporal_Light_down_resize", "Temporal_Text_Cross" ] 15 | temporal_position_encoding: true 16 | temporal_position_encoding_max_len: 64 17 | temporal_attention_dim_div: 1 18 | text_cross_attention_dim: 768 19 | vision_cross_attention_dim: 768 20 | 21 | noise_scheduler_kwargs: 22 | num_train_timesteps: 1000 23 | beta_start: 0.00085 24 | beta_end: 0.012 25 | beta_schedule: "linear" 26 | steps_offset: 1 27 | clip_sample: false 28 | -------------------------------------------------------------------------------- /dist/.gitattributes: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /images/images.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/images/images.npy -------------------------------------------------------------------------------- /images/masks.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zibojia/COCOCO/9ebe984cf1148470ec67b26e2781e05285f1512b/images/masks.npy -------------------------------------------------------------------------------- /outputs/.gitattributes: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.4.0 2 | torchvision 3 | torchaudio 4 | diffusers==0.11.1 5 | transformers==4.25.1 6 | imageio==2.34.0 7 | decord==0.6.0 8 | gdown 9 | omegaconf 10 | gradio==3.40 11 | xformers==0.0.28.dev864 12 | imageio-ffmpeg==0.4.9 13 | decord==0.6.0 14 | omegaconf==2.3.0 15 | safetensors 16 | einops 17 | wandb 18 | -------------------------------------------------------------------------------- /sam2/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from hydra import initialize_config_module 8 | 9 | initialize_config_module("sam2_configs", version_base="1.2") 10 | -------------------------------------------------------------------------------- /sam2/build_sam.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
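# A minimal usage sketch for the two builders defined below, assuming SAM 2
# checkpoints have been downloaded locally (the ./checkpoints/ path and the
# sam2_hiera_large.pt filename are illustrative, not fixed by this repo):
#
#   from sam2.build_sam import build_sam2, build_sam2_video_predictor
#   from sam2.sam2_image_predictor import SAM2ImagePredictor
#
#   model = build_sam2("sam2_hiera_l.yaml", "./checkpoints/sam2_hiera_large.pt")
#   image_predictor = SAM2ImagePredictor(model)
#   video_predictor = build_sam2_video_predictor("sam2_hiera_l.yaml",
#                                                "./checkpoints/sam2_hiera_large.pt")
#
# The config name is resolved against the sam2_configs package registered in
# sam2/__init__.py, and apply_postprocessing=True adds the Hydra overrides below.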
6 | 7 | import logging 8 | 9 | import torch 10 | from hydra import compose 11 | from hydra.utils import instantiate 12 | from omegaconf import OmegaConf 13 | 14 | 15 | def build_sam2( 16 | config_file, 17 | ckpt_path=None, 18 | device="cuda", 19 | mode="eval", 20 | hydra_overrides_extra=[], 21 | apply_postprocessing=True, 22 | ): 23 | 24 | if apply_postprocessing: 25 | hydra_overrides_extra = hydra_overrides_extra.copy() 26 | hydra_overrides_extra += [ 27 | # dynamically fall back to multi-mask if the single mask is not stable 28 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 29 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 30 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 31 | ] 32 | # Read config and init model 33 | cfg = compose(config_name=config_file, overrides=hydra_overrides_extra) 34 | OmegaConf.resolve(cfg) 35 | model = instantiate(cfg.model, _recursive_=True) 36 | _load_checkpoint(model, ckpt_path) 37 | model = model.to(device) 38 | if mode == "eval": 39 | model.eval() 40 | return model 41 | 42 | 43 | def build_sam2_video_predictor( 44 | config_file, 45 | ckpt_path=None, 46 | device="cuda", 47 | mode="eval", 48 | hydra_overrides_extra=[], 49 | apply_postprocessing=True, 50 | ): 51 | hydra_overrides = [ 52 | "++model._target_=sam2.sam2_video_predictor.SAM2VideoPredictor", 53 | ] 54 | if apply_postprocessing: 55 | hydra_overrides_extra = hydra_overrides_extra.copy() 56 | hydra_overrides_extra += [ 57 | # dynamically fall back to multi-mask if the single mask is not stable 58 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_via_stability=true", 59 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_delta=0.05", 60 | "++model.sam_mask_decoder_extra_args.dynamic_multimask_stability_thresh=0.98", 61 | # the sigmoid mask logits on interacted frames with clicks in the memory encoder so that the encoded masks are exactly as what users see from clicking 62 | "++model.binarize_mask_from_pts_for_mem_enc=true", 63 | # fill small holes in the low-res masks up to `fill_hole_area` (before resizing them to the original video resolution) 64 | "++model.fill_hole_area=8", 65 | ] 66 | hydra_overrides.extend(hydra_overrides_extra) 67 | 68 | # Read config and init model 69 | cfg = compose(config_name=config_file, overrides=hydra_overrides) 70 | OmegaConf.resolve(cfg) 71 | model = instantiate(cfg.model, _recursive_=True) 72 | _load_checkpoint(model, ckpt_path) 73 | model = model.to(device) 74 | if mode == "eval": 75 | model.eval() 76 | return model 77 | 78 | 79 | def _load_checkpoint(model, ckpt_path): 80 | if ckpt_path is not None: 81 | sd = torch.load(ckpt_path, map_location="cpu")["model"] 82 | missing_keys, unexpected_keys = model.load_state_dict(sd) 83 | if missing_keys: 84 | logging.error(missing_keys) 85 | raise RuntimeError() 86 | if unexpected_keys: 87 | logging.error(unexpected_keys) 88 | raise RuntimeError() 89 | logging.info("Loaded checkpoint sucessfully") 90 | -------------------------------------------------------------------------------- /sam2/csrc/connected_components.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) Meta Platforms, Inc. and affiliates. 2 | // All rights reserved. 3 | 4 | // This source code is licensed under the license found in the 5 | // LICENSE file in the root directory of this source tree. 
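// Overview of this extension: it implements block-based connected-component
// labeling (union-find with atomicMin and path compression) on a batch of
// binary uint8 masks of shape [N, 1, H, W], processing pixels in 2x2 blocks
// with 8-connectivity; H and W must both be even (asserted below). The entry
// point get_connected_componnets returns two int32 tensors of the same shape:
// per-pixel component labels and per-pixel component sizes. The size map lets
// callers drop or fill components below an area threshold (cf. the
// `fill_hole_area` override configured in build_sam.py).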
6 | 7 | // adapted from https://github.com/zsef123/Connected_components_PyTorch 8 | // with license found in the LICENSE_cctorch file in the root directory. 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | #include 15 | 16 | // 2d 17 | #define BLOCK_ROWS 16 18 | #define BLOCK_COLS 16 19 | 20 | namespace cc2d { 21 | 22 | template 23 | __device__ __forceinline__ unsigned char hasBit(T bitmap, unsigned char pos) { 24 | return (bitmap >> pos) & 1; 25 | } 26 | 27 | __device__ int32_t find(const int32_t* s_buf, int32_t n) { 28 | while (s_buf[n] != n) 29 | n = s_buf[n]; 30 | return n; 31 | } 32 | 33 | __device__ int32_t find_n_compress(int32_t* s_buf, int32_t n) { 34 | const int32_t id = n; 35 | while (s_buf[n] != n) { 36 | n = s_buf[n]; 37 | s_buf[id] = n; 38 | } 39 | return n; 40 | } 41 | 42 | __device__ void union_(int32_t* s_buf, int32_t a, int32_t b) { 43 | bool done; 44 | do { 45 | a = find(s_buf, a); 46 | b = find(s_buf, b); 47 | 48 | if (a < b) { 49 | int32_t old = atomicMin(s_buf + b, a); 50 | done = (old == b); 51 | b = old; 52 | } else if (b < a) { 53 | int32_t old = atomicMin(s_buf + a, b); 54 | done = (old == a); 55 | a = old; 56 | } else 57 | done = true; 58 | 59 | } while (!done); 60 | } 61 | 62 | __global__ void 63 | init_labeling(int32_t* label, const uint32_t W, const uint32_t H) { 64 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 65 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 66 | const uint32_t idx = row * W + col; 67 | 68 | if (row < H && col < W) 69 | label[idx] = idx; 70 | } 71 | 72 | __global__ void 73 | merge(uint8_t* img, int32_t* label, const uint32_t W, const uint32_t H) { 74 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 75 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 76 | const uint32_t idx = row * W + col; 77 | 78 | if (row >= H || col >= W) 79 | return; 80 | 81 | uint32_t P = 0; 82 | 83 | if (img[idx]) 84 | P |= 0x777; 85 | if (row + 1 < H && img[idx + W]) 86 | P |= 0x777 << 4; 87 | if (col + 1 < W && img[idx + 1]) 88 | P |= 0x777 << 1; 89 | 90 | if (col == 0) 91 | P &= 0xEEEE; 92 | if (col + 1 >= W) 93 | P &= 0x3333; 94 | else if (col + 2 >= W) 95 | P &= 0x7777; 96 | 97 | if (row == 0) 98 | P &= 0xFFF0; 99 | if (row + 1 >= H) 100 | P &= 0xFF; 101 | 102 | if (P > 0) { 103 | // If need check about top-left pixel(if flag the first bit) and hit the 104 | // top-left pixel 105 | if (hasBit(P, 0) && img[idx - W - 1]) { 106 | union_(label, idx, idx - 2 * W - 2); // top left block 107 | } 108 | 109 | if ((hasBit(P, 1) && img[idx - W]) || (hasBit(P, 2) && img[idx - W + 1])) 110 | union_(label, idx, idx - 2 * W); // top bottom block 111 | 112 | if (hasBit(P, 3) && img[idx + 2 - W]) 113 | union_(label, idx, idx - 2 * W + 2); // top right block 114 | 115 | if ((hasBit(P, 4) && img[idx - 1]) || (hasBit(P, 8) && img[idx + W - 1])) 116 | union_(label, idx, idx - 2); // just left block 117 | } 118 | } 119 | 120 | __global__ void compression(int32_t* label, const int32_t W, const int32_t H) { 121 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 122 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 123 | const uint32_t idx = row * W + col; 124 | 125 | if (row < H && col < W) 126 | find_n_compress(label, idx); 127 | } 128 | 129 | __global__ void final_labeling( 130 | const uint8_t* img, 131 | int32_t* label, 132 | const int32_t W, 133 | const int32_t H) { 134 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y) * 2; 135 | 
const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x) * 2; 136 | const uint32_t idx = row * W + col; 137 | 138 | if (row >= H || col >= W) 139 | return; 140 | 141 | int32_t y = label[idx] + 1; 142 | 143 | if (img[idx]) 144 | label[idx] = y; 145 | else 146 | label[idx] = 0; 147 | 148 | if (col + 1 < W) { 149 | if (img[idx + 1]) 150 | label[idx + 1] = y; 151 | else 152 | label[idx + 1] = 0; 153 | 154 | if (row + 1 < H) { 155 | if (img[idx + W + 1]) 156 | label[idx + W + 1] = y; 157 | else 158 | label[idx + W + 1] = 0; 159 | } 160 | } 161 | 162 | if (row + 1 < H) { 163 | if (img[idx + W]) 164 | label[idx + W] = y; 165 | else 166 | label[idx + W] = 0; 167 | } 168 | } 169 | 170 | __global__ void init_counting( 171 | const int32_t* label, 172 | int32_t* count_init, 173 | const int32_t W, 174 | const int32_t H) { 175 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 176 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 177 | const uint32_t idx = row * W + col; 178 | 179 | if (row >= H || col >= W) 180 | return; 181 | 182 | int32_t y = label[idx]; 183 | if (y > 0) { 184 | int32_t count_idx = y - 1; 185 | atomicAdd(count_init + count_idx, 1); 186 | } 187 | } 188 | 189 | __global__ void final_counting( 190 | const int32_t* label, 191 | const int32_t* count_init, 192 | int32_t* count_final, 193 | const int32_t W, 194 | const int32_t H) { 195 | const uint32_t row = (blockIdx.y * blockDim.y + threadIdx.y); 196 | const uint32_t col = (blockIdx.x * blockDim.x + threadIdx.x); 197 | const uint32_t idx = row * W + col; 198 | 199 | if (row >= H || col >= W) 200 | return; 201 | 202 | int32_t y = label[idx]; 203 | if (y > 0) { 204 | int32_t count_idx = y - 1; 205 | count_final[idx] = count_init[count_idx]; 206 | } else { 207 | count_final[idx] = 0; 208 | } 209 | } 210 | 211 | } // namespace cc2d 212 | 213 | std::vector get_connected_componnets( 214 | const torch::Tensor& inputs) { 215 | AT_ASSERTM(inputs.is_cuda(), "inputs must be a CUDA tensor"); 216 | AT_ASSERTM(inputs.ndimension() == 4, "inputs must be [N, 1, H, W] shape"); 217 | AT_ASSERTM( 218 | inputs.scalar_type() == torch::kUInt8, "inputs must be a uint8 type"); 219 | 220 | const uint32_t N = inputs.size(0); 221 | const uint32_t C = inputs.size(1); 222 | const uint32_t H = inputs.size(2); 223 | const uint32_t W = inputs.size(3); 224 | 225 | AT_ASSERTM(C == 1, "inputs must be [N, 1, H, W] shape"); 226 | AT_ASSERTM((H % 2) == 0, "height must be a even number"); 227 | AT_ASSERTM((W % 2) == 0, "width must be a even number"); 228 | 229 | // label must be uint32_t 230 | auto label_options = 231 | torch::TensorOptions().dtype(torch::kInt32).device(inputs.device()); 232 | torch::Tensor labels = torch::zeros({N, C, H, W}, label_options); 233 | torch::Tensor counts_init = torch::zeros({N, C, H, W}, label_options); 234 | torch::Tensor counts_final = torch::zeros({N, C, H, W}, label_options); 235 | 236 | dim3 grid = dim3( 237 | ((W + 1) / 2 + BLOCK_COLS - 1) / BLOCK_COLS, 238 | ((H + 1) / 2 + BLOCK_ROWS - 1) / BLOCK_ROWS); 239 | dim3 block = dim3(BLOCK_COLS, BLOCK_ROWS); 240 | dim3 grid_count = 241 | dim3((W + BLOCK_COLS) / BLOCK_COLS, (H + BLOCK_ROWS) / BLOCK_ROWS); 242 | dim3 block_count = dim3(BLOCK_COLS, BLOCK_ROWS); 243 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(); 244 | 245 | for (int n = 0; n < N; n++) { 246 | uint32_t offset = n * H * W; 247 | 248 | cc2d::init_labeling<<>>( 249 | labels.data_ptr() + offset, W, H); 250 | cc2d::merge<<>>( 251 | inputs.data_ptr() + offset, 252 | labels.data_ptr() + offset, 253 | 
W, 254 | H); 255 | cc2d::compression<<>>( 256 | labels.data_ptr() + offset, W, H); 257 | cc2d::final_labeling<<>>( 258 | inputs.data_ptr() + offset, 259 | labels.data_ptr() + offset, 260 | W, 261 | H); 262 | 263 | // get the counting of each pixel 264 | cc2d::init_counting<<>>( 265 | labels.data_ptr() + offset, 266 | counts_init.data_ptr() + offset, 267 | W, 268 | H); 269 | cc2d::final_counting<<>>( 270 | labels.data_ptr() + offset, 271 | counts_init.data_ptr() + offset, 272 | counts_final.data_ptr() + offset, 273 | W, 274 | H); 275 | } 276 | 277 | // returned values are [labels, counts] 278 | std::vector outputs; 279 | outputs.push_back(labels); 280 | outputs.push_back(counts_final); 281 | return outputs; 282 | } 283 | 284 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 285 | m.def( 286 | "get_connected_componnets", 287 | &get_connected_componnets, 288 | "get_connected_componnets"); 289 | } 290 | -------------------------------------------------------------------------------- /sam2/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/hieradet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
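# This module implements the Hiera trunk used by the SAM 2 image encoder. Two
# notes on the code below, plus a rough shape sketch (the patch-embed stride
# and the exact output shapes depend on the PatchEmbed defaults defined in
# backbones/utils.py, so treat the numbers as indicative, not guaranteed):
#
# * scaled_dot_product_attention() is a plain PyTorch re-implementation of the
#   math behind torch.nn.functional.scaled_dot_product_attention, i.e.
#   softmax(Q K^T * scale + bias) V with optional causal/boolean masking and
#   dropout, presumably kept for environments without the fused kernel.
#
# * With the defaults (embed_dim=96, stages=(2, 3, 16, 3), dim_mul=2.0,
#   q_pool=3, q_stride=(2, 2)) and an input of shape (B, 3, 1024, 1024),
#   Hiera.forward() returns four feature maps, highest resolution first,
#   roughly (B, 96, 256, 256), (B, 192, 128, 128), (B, 384, 64, 64) and
#   (B, 768, 32, 32); channel_list reverses this order for the FPN neck.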
6 | 7 | import math 8 | 9 | from functools import partial 10 | from typing import List, Tuple, Union 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | from sam2.modeling.backbones.utils import ( 17 | PatchEmbed, 18 | window_partition, 19 | window_unpartition, 20 | ) 21 | 22 | from sam2.modeling.sam2_utils import DropPath, MLP 23 | 24 | 25 | def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor: 26 | if pool is None: 27 | return x 28 | # (B, H, W, C) -> (B, C, H, W) 29 | x = x.permute(0, 3, 1, 2) 30 | x = pool(x) 31 | # (B, C, H', W') -> (B, H', W', C) 32 | x = x.permute(0, 2, 3, 1) 33 | if norm: 34 | x = norm(x) 35 | 36 | return x 37 | 38 | def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor: 39 | L, S = query.size(-2), key.size(-2) 40 | scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale 41 | attn_bias = torch.zeros(L, S, dtype=query.dtype) 42 | if is_causal: 43 | assert attn_mask is None 44 | temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) 45 | attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) 46 | attn_bias.to(query.dtype) 47 | 48 | if attn_mask is not None: 49 | if attn_mask.dtype == torch.bool: 50 | attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) 51 | else: 52 | attn_bias += attn_mask 53 | attn_weight = query @ key.transpose(-2, -1) * scale_factor 54 | attn_weight += attn_bias.to(device=attn_weight.device, dtype=attn_weight.dtype) 55 | attn_weight = torch.softmax(attn_weight, dim=-1) 56 | attn_weight = torch.dropout(attn_weight, dropout_p, train=True) 57 | return attn_weight @ value 58 | 59 | class MultiScaleAttention(nn.Module): 60 | def __init__( 61 | self, 62 | dim: int, 63 | dim_out: int, 64 | num_heads: int, 65 | q_pool: nn.Module = None, 66 | ): 67 | super().__init__() 68 | 69 | self.dim = dim 70 | self.dim_out = dim_out 71 | 72 | self.num_heads = num_heads 73 | head_dim = dim_out // num_heads 74 | self.scale = head_dim**-0.5 75 | 76 | self.q_pool = q_pool 77 | self.qkv = nn.Linear(dim, dim_out * 3) 78 | self.proj = nn.Linear(dim_out, dim_out) 79 | 80 | def forward(self, x: torch.Tensor) -> torch.Tensor: 81 | B, H, W, _ = x.shape 82 | # qkv with shape (B, H * W, 3, nHead, C) 83 | qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1) 84 | # q, k, v with shape (B, H * W, nheads, C) 85 | q, k, v = torch.unbind(qkv, 2) 86 | 87 | # Q pooling (for downsample at stage changes) 88 | if self.q_pool: 89 | q = do_pool(q.reshape(B, H, W, -1), self.q_pool) 90 | H, W = q.shape[1:3] # downsampled shape 91 | q = q.reshape(B, H * W, self.num_heads, -1) 92 | 93 | # Torch's SDPA expects [B, nheads, H*W, C] so we transpose 94 | x = scaled_dot_product_attention( 95 | q.transpose(1, 2), 96 | k.transpose(1, 2), 97 | v.transpose(1, 2), 98 | ) 99 | # Transpose back 100 | x = x.transpose(1, 2) 101 | x = x.reshape(B, H, W, -1) 102 | 103 | x = self.proj(x) 104 | 105 | return x 106 | 107 | 108 | class MultiScaleBlock(nn.Module): 109 | def __init__( 110 | self, 111 | dim: int, 112 | dim_out: int, 113 | num_heads: int, 114 | mlp_ratio: float = 4.0, 115 | drop_path: float = 0.0, 116 | norm_layer: Union[nn.Module, str] = "LayerNorm", 117 | q_stride: Tuple[int, int] = None, 118 | act_layer: nn.Module = nn.GELU, 119 | window_size: int = 0, 120 | ): 121 | super().__init__() 122 | 123 | if isinstance(norm_layer, str): 124 | norm_layer = partial(getattr(nn, norm_layer), eps=1e-6) 125 | 126 | 
self.dim = dim 127 | self.dim_out = dim_out 128 | self.norm1 = norm_layer(dim) 129 | 130 | self.window_size = window_size 131 | 132 | self.pool, self.q_stride = None, q_stride 133 | if self.q_stride: 134 | self.pool = nn.MaxPool2d( 135 | kernel_size=q_stride, stride=q_stride, ceil_mode=False 136 | ) 137 | 138 | self.attn = MultiScaleAttention( 139 | dim, 140 | dim_out, 141 | num_heads=num_heads, 142 | q_pool=self.pool, 143 | ) 144 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 145 | 146 | self.norm2 = norm_layer(dim_out) 147 | self.mlp = MLP( 148 | dim_out, 149 | int(dim_out * mlp_ratio), 150 | dim_out, 151 | num_layers=2, 152 | activation=act_layer, 153 | ) 154 | 155 | if dim != dim_out: 156 | self.proj = nn.Linear(dim, dim_out) 157 | 158 | def forward(self, x: torch.Tensor) -> torch.Tensor: 159 | shortcut = x # B, H, W, C 160 | x = self.norm1(x) 161 | 162 | # Skip connection 163 | if self.dim != self.dim_out: 164 | shortcut = do_pool(self.proj(x), self.pool) 165 | 166 | # Window partition 167 | window_size = self.window_size 168 | if window_size > 0: 169 | H, W = x.shape[1], x.shape[2] 170 | x, pad_hw = window_partition(x, window_size) 171 | 172 | # Window Attention + Q Pooling (if stage change) 173 | x = self.attn(x) 174 | if self.q_stride: 175 | # Shapes have changed due to Q pooling 176 | window_size = self.window_size // self.q_stride[0] 177 | H, W = shortcut.shape[1:3] 178 | 179 | pad_h = (window_size - H % window_size) % window_size 180 | pad_w = (window_size - W % window_size) % window_size 181 | pad_hw = (H + pad_h, W + pad_w) 182 | 183 | # Reverse window partition 184 | if self.window_size > 0: 185 | x = window_unpartition(x, window_size, pad_hw, (H, W)) 186 | 187 | x = shortcut + self.drop_path(x) 188 | # MLP 189 | x = x + self.drop_path(self.mlp(self.norm2(x))) 190 | return x 191 | 192 | 193 | class Hiera(nn.Module): 194 | """ 195 | Reference: https://arxiv.org/abs/2306.00989 196 | """ 197 | 198 | def __init__( 199 | self, 200 | embed_dim: int = 96, # initial embed dim 201 | num_heads: int = 1, # initial number of heads 202 | drop_path_rate: float = 0.0, # stochastic depth 203 | q_pool: int = 3, # number of q_pool stages 204 | q_stride: Tuple[int, int] = (2, 2), # downsample stride bet. stages 205 | stages: Tuple[int, ...] = (2, 3, 16, 3), # blocks per stage 206 | dim_mul: float = 2.0, # dim_mul factor at stage shift 207 | head_mul: float = 2.0, # head_mul factor at stage shift 208 | window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14), 209 | # window size per stage, when not using global att. 210 | window_spec: Tuple[int, ...] = ( 211 | 8, 212 | 4, 213 | 14, 214 | 7, 215 | ), 216 | # global attn in these blocks 217 | global_att_blocks: Tuple[int, ...] = ( 218 | 12, 219 | 16, 220 | 20, 221 | ), 222 | return_interm_layers=True, # return feats from every stage 223 | ): 224 | super().__init__() 225 | 226 | assert len(stages) == len(window_spec) 227 | self.window_spec = window_spec 228 | 229 | depth = sum(stages) 230 | self.q_stride = q_stride 231 | self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)] 232 | assert 0 <= q_pool <= len(self.stage_ends[:-1]) 233 | self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool] 234 | self.return_interm_layers = return_interm_layers 235 | 236 | self.patch_embed = PatchEmbed( 237 | embed_dim=embed_dim, 238 | ) 239 | # Which blocks have global att? 
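# Blocks listed here get window_size = 0 in the loop below, i.e. they attend
# over the full feature map instead of local windows. With the default
# stages=(2, 3, 16, 3) the stage ends are blocks 1, 4, 20 and 23, so the
# default global_att_blocks=(12, 16, 20) all fall inside the 16-block third
# stage: that stage interleaves 14x14 windowed attention with full attention,
# while every other block uses the per-stage window size from window_spec.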
240 | self.global_att_blocks = global_att_blocks 241 | 242 | # Windowed positional embedding (https://arxiv.org/abs/2311.05613) 243 | self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size 244 | self.pos_embed = nn.Parameter( 245 | torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size) 246 | ) 247 | self.pos_embed_window = nn.Parameter( 248 | torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0]) 249 | ) 250 | 251 | dpr = [ 252 | x.item() for x in torch.linspace(0, drop_path_rate, depth) 253 | ] # stochastic depth decay rule 254 | 255 | cur_stage = 1 256 | self.blocks = nn.ModuleList() 257 | 258 | for i in range(depth): 259 | dim_out = embed_dim 260 | # lags by a block, so first block of 261 | # next stage uses an initial window size 262 | # of previous stage and final window size of current stage 263 | window_size = self.window_spec[cur_stage - 1] 264 | 265 | if self.global_att_blocks is not None: 266 | window_size = 0 if i in self.global_att_blocks else window_size 267 | 268 | if i - 1 in self.stage_ends: 269 | dim_out = int(embed_dim * dim_mul) 270 | num_heads = int(num_heads * head_mul) 271 | cur_stage += 1 272 | 273 | block = MultiScaleBlock( 274 | dim=embed_dim, 275 | dim_out=dim_out, 276 | num_heads=num_heads, 277 | drop_path=dpr[i], 278 | q_stride=self.q_stride if i in self.q_pool_blocks else None, 279 | window_size=window_size, 280 | ) 281 | 282 | embed_dim = dim_out 283 | self.blocks.append(block) 284 | 285 | self.channel_list = ( 286 | [self.blocks[i].dim_out for i in self.stage_ends[::-1]] 287 | if return_interm_layers 288 | else [self.blocks[-1].dim_out] 289 | ) 290 | 291 | def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor: 292 | h, w = hw 293 | window_embed = self.pos_embed_window 294 | pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic") 295 | pos_embed = pos_embed + window_embed.tile( 296 | [x // y for x, y in zip(pos_embed.shape, window_embed.shape)] 297 | ) 298 | pos_embed = pos_embed.permute(0, 2, 3, 1) 299 | return pos_embed 300 | 301 | def forward(self, x: torch.Tensor) -> List[torch.Tensor]: 302 | x = self.patch_embed(x) 303 | # x: (B, H, W, C) 304 | 305 | # Add pos embed 306 | x = x + self._get_pos_embed(x.shape[1:3]) 307 | 308 | outputs = [] 309 | for i, blk in enumerate(self.blocks): 310 | x = blk(x) 311 | if (i == self.stage_ends[-1]) or ( 312 | i in self.stage_ends and self.return_interm_layers 313 | ): 314 | feats = x.permute(0, 3, 1, 2) 315 | outputs.append(feats) 316 | 317 | return outputs 318 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/image_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List, Optional 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | 14 | class ImageEncoder(nn.Module): 15 | def __init__( 16 | self, 17 | trunk: nn.Module, 18 | neck: nn.Module, 19 | scalp: int = 0, 20 | ): 21 | super().__init__() 22 | self.trunk = trunk 23 | self.neck = neck 24 | self.scalp = scalp 25 | assert ( 26 | self.trunk.channel_list == self.neck.backbone_channel_list 27 | ), f"Channel dims of trunk and neck do not match. 
Trunk: {self.trunk.channel_list}, neck: {self.neck.backbone_channel_list}" 28 | 29 | def forward(self, sample: torch.Tensor): 30 | # Forward through backbone 31 | features, pos = self.neck(self.trunk(sample)) 32 | if self.scalp > 0: 33 | # Discard the lowest resolution features 34 | features, pos = features[: -self.scalp], pos[: -self.scalp] 35 | 36 | src = features[-1] 37 | output = { 38 | "vision_features": src, 39 | "vision_pos_enc": pos, 40 | "backbone_fpn": features, 41 | } 42 | return output 43 | 44 | 45 | class FpnNeck(nn.Module): 46 | """ 47 | A modified variant of Feature Pyramid Network (FPN) neck 48 | (we remove output conv and also do bicubic interpolation similar to ViT 49 | pos embed interpolation) 50 | """ 51 | 52 | def __init__( 53 | self, 54 | position_encoding: nn.Module, 55 | d_model: int, 56 | backbone_channel_list: List[int], 57 | kernel_size: int = 1, 58 | stride: int = 1, 59 | padding: int = 0, 60 | fpn_interp_model: str = "bilinear", 61 | fuse_type: str = "sum", 62 | fpn_top_down_levels: Optional[List[int]] = None, 63 | ): 64 | """Initialize the neck 65 | :param trunk: the backbone 66 | :param position_encoding: the positional encoding to use 67 | :param d_model: the dimension of the model 68 | :param neck_norm: the normalization to use 69 | """ 70 | super().__init__() 71 | self.position_encoding = position_encoding 72 | self.convs = nn.ModuleList() 73 | self.backbone_channel_list = backbone_channel_list 74 | for dim in backbone_channel_list: 75 | current = nn.Sequential() 76 | current.add_module( 77 | "conv", 78 | nn.Conv2d( 79 | in_channels=dim, 80 | out_channels=d_model, 81 | kernel_size=kernel_size, 82 | stride=stride, 83 | padding=padding, 84 | ), 85 | ) 86 | 87 | self.convs.append(current) 88 | self.fpn_interp_model = fpn_interp_model 89 | assert fuse_type in ["sum", "avg"] 90 | self.fuse_type = fuse_type 91 | 92 | # levels to have top-down features in its outputs 93 | # e.g. if fpn_top_down_levels is [2, 3], then only outputs of level 2 and 3 94 | # have top-down propagation, while outputs of level 0 and level 1 have only 95 | # lateral features from the same backbone level. 
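# (added note: the released configs under sam2_configs/ use exactly this setting,
#  fpn_top_down_levels: [2, 3], so levels 0 and 1 output lateral features only)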
96 | if fpn_top_down_levels is None: 97 | # default is to have top-down features on all levels 98 | fpn_top_down_levels = range(len(self.convs)) 99 | self.fpn_top_down_levels = list(fpn_top_down_levels) 100 | 101 | def forward(self, xs: List[torch.Tensor]): 102 | 103 | out = [None] * len(self.convs) 104 | pos = [None] * len(self.convs) 105 | assert len(xs) == len(self.convs) 106 | # fpn forward pass 107 | # see https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/fpn.py 108 | prev_features = None 109 | # forward in top-down order (from low to high resolution) 110 | n = len(self.convs) - 1 111 | for i in range(n, -1, -1): 112 | x = xs[i] 113 | lateral_features = self.convs[n - i](x) 114 | if i in self.fpn_top_down_levels and prev_features is not None: 115 | top_down_features = F.interpolate( 116 | prev_features.to(dtype=torch.float32), 117 | scale_factor=2.0, 118 | mode=self.fpn_interp_model, 119 | align_corners=( 120 | None if self.fpn_interp_model == "nearest" else False 121 | ), 122 | antialias=False, 123 | ) 124 | prev_features = lateral_features + top_down_features 125 | if self.fuse_type == "avg": 126 | prev_features /= 2 127 | else: 128 | prev_features = lateral_features 129 | x_out = prev_features 130 | out[i] = x_out 131 | pos[i] = self.position_encoding(x_out).to(x_out.dtype) 132 | 133 | return out, pos 134 | -------------------------------------------------------------------------------- /sam2/modeling/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Some utilities for backbones, in particular for windowing""" 8 | 9 | from typing import Tuple 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | 16 | def window_partition(x, window_size): 17 | """ 18 | Partition into non-overlapping windows with padding if needed. 19 | Args: 20 | x (tensor): input tokens with [B, H, W, C]. 21 | window_size (int): window size. 22 | Returns: 23 | windows: windows after partition with [B * num_windows, window_size, window_size, C]. 24 | (Hp, Wp): padded height and width before partition 25 | """ 26 | B, H, W, C = x.shape 27 | 28 | pad_h = (window_size - H % window_size) % window_size 29 | pad_w = (window_size - W % window_size) % window_size 30 | if pad_h > 0 or pad_w > 0: 31 | x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) 32 | Hp, Wp = H + pad_h, W + pad_w 33 | 34 | x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) 35 | windows = ( 36 | x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 37 | ) 38 | return windows, (Hp, Wp) 39 | 40 | 41 | def window_unpartition(windows, window_size, pad_hw, hw): 42 | """ 43 | Window unpartition into original sequences and removing padding. 44 | Args: 45 | x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. 46 | window_size (int): window size. 47 | pad_hw (Tuple): padded height and width (Hp, Wp). 48 | hw (Tuple): original height and width (H, W) before padding. 49 | Returns: 50 | x: unpartitioned sequences with [B, H, W, C]. 
51 | """ 52 | Hp, Wp = pad_hw 53 | H, W = hw 54 | B = windows.shape[0] // (Hp * Wp // window_size // window_size) 55 | x = windows.view( 56 | B, Hp // window_size, Wp // window_size, window_size, window_size, -1 57 | ) 58 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) 59 | 60 | if Hp > H or Wp > W: 61 | x = x[:, :H, :W, :].contiguous() 62 | return x 63 | 64 | 65 | class PatchEmbed(nn.Module): 66 | """ 67 | Image to Patch Embedding. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | kernel_size: Tuple[int, ...] = (7, 7), 73 | stride: Tuple[int, ...] = (4, 4), 74 | padding: Tuple[int, ...] = (3, 3), 75 | in_chans: int = 3, 76 | embed_dim: int = 768, 77 | ): 78 | """ 79 | Args: 80 | kernel_size (Tuple): kernel size of the projection layer. 81 | stride (Tuple): stride of the projection layer. 82 | padding (Tuple): padding size of the projection layer. 83 | in_chans (int): Number of input image channels. 84 | embed_dim (int): embed_dim (int): Patch embedding dimension. 85 | """ 86 | super().__init__() 87 | self.proj = nn.Conv2d( 88 | in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding 89 | ) 90 | 91 | def forward(self, x: torch.Tensor) -> torch.Tensor: 92 | x = self.proj(x) 93 | # B C H W -> B H W C 94 | x = x.permute(0, 2, 3, 1) 95 | return x 96 | -------------------------------------------------------------------------------- /sam2/modeling/memory_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch 10 | from torch import nn, Tensor 11 | 12 | from sam2.modeling.sam.transformer import RoPEAttention 13 | 14 | from sam2.modeling.sam2_utils import get_activation_fn, get_clones 15 | 16 | 17 | class MemoryAttentionLayer(nn.Module): 18 | 19 | def __init__( 20 | self, 21 | activation: str, 22 | cross_attention: nn.Module, 23 | d_model: int, 24 | dim_feedforward: int, 25 | dropout: float, 26 | pos_enc_at_attn: bool, 27 | pos_enc_at_cross_attn_keys: bool, 28 | pos_enc_at_cross_attn_queries: bool, 29 | self_attention: nn.Module, 30 | ): 31 | super().__init__() 32 | self.d_model = d_model 33 | self.dim_feedforward = dim_feedforward 34 | self.dropout_value = dropout 35 | self.self_attn = self_attention 36 | self.cross_attn_image = cross_attention 37 | 38 | # Implementation of Feedforward model 39 | self.linear1 = nn.Linear(d_model, dim_feedforward) 40 | self.dropout = nn.Dropout(dropout) 41 | self.linear2 = nn.Linear(dim_feedforward, d_model) 42 | 43 | self.norm1 = nn.LayerNorm(d_model) 44 | self.norm2 = nn.LayerNorm(d_model) 45 | self.norm3 = nn.LayerNorm(d_model) 46 | self.dropout1 = nn.Dropout(dropout) 47 | self.dropout2 = nn.Dropout(dropout) 48 | self.dropout3 = nn.Dropout(dropout) 49 | 50 | self.activation_str = activation 51 | self.activation = get_activation_fn(activation) 52 | 53 | # Where to add pos enc 54 | self.pos_enc_at_attn = pos_enc_at_attn 55 | self.pos_enc_at_cross_attn_queries = pos_enc_at_cross_attn_queries 56 | self.pos_enc_at_cross_attn_keys = pos_enc_at_cross_attn_keys 57 | 58 | def _forward_sa(self, tgt, query_pos): 59 | # Self-Attention 60 | tgt2 = self.norm1(tgt) 61 | q = k = tgt2 + query_pos if self.pos_enc_at_attn else tgt2 62 | tgt2 = self.self_attn(q, k, v=tgt2) 63 | tgt = tgt + self.dropout1(tgt2) 64 | return tgt 65 | 66 | 
def _forward_ca(self, tgt, memory, query_pos, pos, num_k_exclude_rope=0): 67 | kwds = {} 68 | if num_k_exclude_rope > 0: 69 | assert isinstance(self.cross_attn_image, RoPEAttention) 70 | kwds = {"num_k_exclude_rope": num_k_exclude_rope} 71 | 72 | # Cross-Attention 73 | tgt2 = self.norm2(tgt) 74 | tgt2 = self.cross_attn_image( 75 | q=tgt2 + query_pos if self.pos_enc_at_cross_attn_queries else tgt2, 76 | k=memory + pos if self.pos_enc_at_cross_attn_keys else memory, 77 | v=memory, 78 | **kwds, 79 | ) 80 | tgt = tgt + self.dropout2(tgt2) 81 | return tgt 82 | 83 | def forward( 84 | self, 85 | tgt, 86 | memory, 87 | pos: Optional[Tensor] = None, 88 | query_pos: Optional[Tensor] = None, 89 | num_k_exclude_rope: int = 0, 90 | ) -> torch.Tensor: 91 | 92 | # Self-Attn, Cross-Attn 93 | tgt = self._forward_sa(tgt, query_pos) 94 | tgt = self._forward_ca(tgt, memory, query_pos, pos, num_k_exclude_rope) 95 | # MLP 96 | tgt2 = self.norm3(tgt) 97 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 98 | tgt = tgt + self.dropout3(tgt2) 99 | return tgt 100 | 101 | 102 | class MemoryAttention(nn.Module): 103 | def __init__( 104 | self, 105 | d_model: int, 106 | pos_enc_at_input: bool, 107 | layer: nn.Module, 108 | num_layers: int, 109 | batch_first: bool = True, # Do layers expect batch first input? 110 | ): 111 | super().__init__() 112 | self.d_model = d_model 113 | self.layers = get_clones(layer, num_layers) 114 | self.num_layers = num_layers 115 | self.norm = nn.LayerNorm(d_model) 116 | self.pos_enc_at_input = pos_enc_at_input 117 | self.batch_first = batch_first 118 | 119 | def forward( 120 | self, 121 | curr: torch.Tensor, # self-attention inputs 122 | memory: torch.Tensor, # cross-attention inputs 123 | curr_pos: Optional[Tensor] = None, # pos_enc for self-attention inputs 124 | memory_pos: Optional[Tensor] = None, # pos_enc for cross-attention inputs 125 | num_obj_ptr_tokens: int = 0, # number of object pointer *tokens* 126 | ): 127 | if isinstance(curr, list): 128 | assert isinstance(curr_pos, list) 129 | assert len(curr) == len(curr_pos) == 1 130 | curr, curr_pos = ( 131 | curr[0], 132 | curr_pos[0], 133 | ) 134 | 135 | assert ( 136 | curr.shape[1] == memory.shape[1] 137 | ), "Batch size must be the same for curr and memory" 138 | 139 | output = curr 140 | if self.pos_enc_at_input and curr_pos is not None: 141 | output = output + 0.1 * curr_pos 142 | 143 | if self.batch_first: 144 | # Convert to batch first 145 | output = output.transpose(0, 1) 146 | curr_pos = curr_pos.transpose(0, 1) 147 | memory = memory.transpose(0, 1) 148 | memory_pos = memory_pos.transpose(0, 1) 149 | 150 | for layer in self.layers: 151 | kwds = {} 152 | if isinstance(layer.cross_attn_image, RoPEAttention): 153 | kwds = {"num_k_exclude_rope": num_obj_ptr_tokens} 154 | 155 | output = layer( 156 | tgt=output, 157 | memory=memory, 158 | pos=memory_pos, 159 | query_pos=curr_pos, 160 | **kwds, 161 | ) 162 | normed_output = self.norm(output) 163 | 164 | if self.batch_first: 165 | # Convert back to seq first 166 | normed_output = normed_output.transpose(0, 1) 167 | curr_pos = curr_pos.transpose(0, 1) 168 | 169 | return normed_output 170 | -------------------------------------------------------------------------------- /sam2/modeling/memory_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | from typing import Tuple 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | from sam2.modeling.sam2_utils import DropPath, get_clones, LayerNorm2d 15 | 16 | 17 | class MaskDownSampler(nn.Module): 18 | """ 19 | Progressively downsample a mask by total_stride, each time by stride. 20 | Note that LayerNorm is applied per *token*, like in ViT. 21 | 22 | With each downsample (by a factor stride**2), channel capacity increases by the same factor. 23 | In the end, we linearly project to embed_dim channels. 24 | """ 25 | 26 | def __init__( 27 | self, 28 | embed_dim=256, 29 | kernel_size=4, 30 | stride=4, 31 | padding=0, 32 | total_stride=16, 33 | activation=nn.GELU, 34 | ): 35 | super().__init__() 36 | num_layers = int(math.log2(total_stride) // math.log2(stride)) 37 | assert stride**num_layers == total_stride 38 | self.encoder = nn.Sequential() 39 | mask_in_chans, mask_out_chans = 1, 1 40 | for _ in range(num_layers): 41 | mask_out_chans = mask_in_chans * (stride**2) 42 | self.encoder.append( 43 | nn.Conv2d( 44 | mask_in_chans, 45 | mask_out_chans, 46 | kernel_size=kernel_size, 47 | stride=stride, 48 | padding=padding, 49 | ) 50 | ) 51 | self.encoder.append(LayerNorm2d(mask_out_chans)) 52 | self.encoder.append(activation()) 53 | mask_in_chans = mask_out_chans 54 | 55 | self.encoder.append(nn.Conv2d(mask_out_chans, embed_dim, kernel_size=1)) 56 | 57 | def forward(self, x): 58 | return self.encoder(x) 59 | 60 | 61 | # Lightly adapted from ConvNext (https://github.com/facebookresearch/ConvNeXt) 62 | class CXBlock(nn.Module): 63 | r"""ConvNeXt Block. There are two equivalent implementations: 64 | (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) 65 | (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back 66 | We use (2) as we find it slightly faster in PyTorch 67 | 68 | Args: 69 | dim (int): Number of input channels. 70 | drop_path (float): Stochastic depth rate. Default: 0.0 71 | layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. 
72 | """ 73 | 74 | def __init__( 75 | self, 76 | dim, 77 | kernel_size=7, 78 | padding=3, 79 | drop_path=0.0, 80 | layer_scale_init_value=1e-6, 81 | use_dwconv=True, 82 | ): 83 | super().__init__() 84 | self.dwconv = nn.Conv2d( 85 | dim, 86 | dim, 87 | kernel_size=kernel_size, 88 | padding=padding, 89 | groups=dim if use_dwconv else 1, 90 | ) # depthwise conv 91 | self.norm = LayerNorm2d(dim, eps=1e-6) 92 | self.pwconv1 = nn.Linear( 93 | dim, 4 * dim 94 | ) # pointwise/1x1 convs, implemented with linear layers 95 | self.act = nn.GELU() 96 | self.pwconv2 = nn.Linear(4 * dim, dim) 97 | self.gamma = ( 98 | nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True) 99 | if layer_scale_init_value > 0 100 | else None 101 | ) 102 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 103 | 104 | def forward(self, x): 105 | input = x 106 | x = self.dwconv(x) 107 | x = self.norm(x) 108 | x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 109 | x = self.pwconv1(x) 110 | x = self.act(x) 111 | x = self.pwconv2(x) 112 | if self.gamma is not None: 113 | x = self.gamma * x 114 | x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 115 | 116 | x = input + self.drop_path(x) 117 | return x 118 | 119 | 120 | class Fuser(nn.Module): 121 | def __init__(self, layer, num_layers, dim=None, input_projection=False): 122 | super().__init__() 123 | self.proj = nn.Identity() 124 | self.layers = get_clones(layer, num_layers) 125 | 126 | if input_projection: 127 | assert dim is not None 128 | self.proj = nn.Conv2d(dim, dim, kernel_size=1) 129 | 130 | def forward(self, x): 131 | # normally x: (N, C, H, W) 132 | x = self.proj(x) 133 | for layer in self.layers: 134 | x = layer(x) 135 | return x 136 | 137 | 138 | class MemoryEncoder(nn.Module): 139 | def __init__( 140 | self, 141 | out_dim, 142 | mask_downsampler, 143 | fuser, 144 | position_encoding, 145 | in_dim=256, # in_dim of pix_feats 146 | ): 147 | super().__init__() 148 | 149 | self.mask_downsampler = mask_downsampler 150 | 151 | self.pix_feat_proj = nn.Conv2d(in_dim, in_dim, kernel_size=1) 152 | self.fuser = fuser 153 | self.position_encoding = position_encoding 154 | self.out_proj = nn.Identity() 155 | if out_dim != in_dim: 156 | self.out_proj = nn.Conv2d(in_dim, out_dim, kernel_size=1) 157 | 158 | def forward( 159 | self, 160 | pix_feat: torch.Tensor, 161 | masks: torch.Tensor, 162 | skip_mask_sigmoid: bool = False, 163 | ) -> Tuple[torch.Tensor, torch.Tensor]: 164 | ## Process masks 165 | # sigmoid, so that less domain shift from gt masks which are bool 166 | if not skip_mask_sigmoid: 167 | masks = F.sigmoid(masks) 168 | masks = self.mask_downsampler(masks) 169 | 170 | ## Fuse pix_feats and downsampled masks 171 | # in case the visual features are on CPU, cast them to CUDA 172 | pix_feat = pix_feat.to(masks.device) 173 | 174 | x = self.pix_feat_proj(pix_feat) 175 | x = x + masks 176 | x = self.fuser(x) 177 | x = self.out_proj(x) 178 | 179 | pos = self.position_encoding(x).to(x.dtype) 180 | 181 | return {"vision_features": x, "vision_pos_enc": [pos]} 182 | -------------------------------------------------------------------------------- /sam2/modeling/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import math 8 | from typing import Any, Optional, Tuple 9 | 10 | import numpy as np 11 | 12 | import torch 13 | from torch import nn 14 | 15 | 16 | class PositionEmbeddingSine(nn.Module): 17 | """ 18 | This is a more standard version of the position embedding, very similar to the one 19 | used by the Attention is all you need paper, generalized to work on images. 20 | """ 21 | 22 | def __init__( 23 | self, 24 | num_pos_feats, 25 | temperature: int = 10000, 26 | normalize: bool = True, 27 | scale: Optional[float] = None, 28 | ): 29 | super().__init__() 30 | assert num_pos_feats % 2 == 0, "Expecting even model width" 31 | self.num_pos_feats = num_pos_feats // 2 32 | self.temperature = temperature 33 | self.normalize = normalize 34 | if scale is not None and normalize is False: 35 | raise ValueError("normalize should be True if scale is passed") 36 | if scale is None: 37 | scale = 2 * math.pi 38 | self.scale = scale 39 | 40 | self.cache = {} 41 | 42 | def _encode_xy(self, x, y): 43 | # The positions are expected to be normalized 44 | assert len(x) == len(y) and x.ndim == y.ndim == 1 45 | x_embed = x * self.scale 46 | y_embed = y * self.scale 47 | 48 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 49 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 50 | 51 | pos_x = x_embed[:, None] / dim_t 52 | pos_y = y_embed[:, None] / dim_t 53 | pos_x = torch.stack( 54 | (pos_x[:, 0::2].sin(), pos_x[:, 1::2].cos()), dim=2 55 | ).flatten(1) 56 | pos_y = torch.stack( 57 | (pos_y[:, 0::2].sin(), pos_y[:, 1::2].cos()), dim=2 58 | ).flatten(1) 59 | return pos_x, pos_y 60 | 61 | @torch.no_grad() 62 | def encode_boxes(self, x, y, w, h): 63 | pos_x, pos_y = self._encode_xy(x, y) 64 | pos = torch.cat((pos_y, pos_x, h[:, None], w[:, None]), dim=1) 65 | return pos 66 | 67 | encode = encode_boxes # Backwards compatibility 68 | 69 | @torch.no_grad() 70 | def encode_points(self, x, y, labels): 71 | (bx, nx), (by, ny), (bl, nl) = x.shape, y.shape, labels.shape 72 | assert bx == by and nx == ny and bx == bl and nx == nl 73 | pos_x, pos_y = self._encode_xy(x.flatten(), y.flatten()) 74 | pos_x, pos_y = pos_x.reshape(bx, nx, -1), pos_y.reshape(by, ny, -1) 75 | pos = torch.cat((pos_y, pos_x, labels[:, :, None]), dim=2) 76 | return pos 77 | 78 | @torch.no_grad() 79 | def forward(self, x: torch.Tensor): 80 | cache_key = (x.shape[-2], x.shape[-1]) 81 | if cache_key in self.cache: 82 | return self.cache[cache_key][None].repeat(x.shape[0], 1, 1, 1) 83 | y_embed = ( 84 | torch.arange(1, x.shape[-2] + 1, dtype=torch.float32, device=x.device) 85 | .view(1, -1, 1) 86 | .repeat(x.shape[0], 1, x.shape[-1]) 87 | ) 88 | x_embed = ( 89 | torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device) 90 | .view(1, 1, -1) 91 | .repeat(x.shape[0], x.shape[-2], 1) 92 | ) 93 | 94 | if self.normalize: 95 | eps = 1e-6 96 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 97 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 98 | 99 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 100 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 101 | 102 | pos_x = x_embed[:, :, :, None] / dim_t 103 | pos_y = y_embed[:, :, :, None] / dim_t 104 | pos_x = torch.stack( 105 | (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 106 | ).flatten(3) 107 | pos_y = torch.stack( 108 | (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 109 | ).flatten(3) 110 | pos = torch.cat((pos_y, pos_x), 
dim=3).permute(0, 3, 1, 2) 111 | self.cache[cache_key] = pos[0] 112 | return pos 113 | 114 | 115 | class PositionEmbeddingRandom(nn.Module): 116 | """ 117 | Positional encoding using random spatial frequencies. 118 | """ 119 | 120 | def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: 121 | super().__init__() 122 | if scale is None or scale <= 0.0: 123 | scale = 1.0 124 | self.register_buffer( 125 | "positional_encoding_gaussian_matrix", 126 | scale * torch.randn((2, num_pos_feats)), 127 | ) 128 | 129 | def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: 130 | """Positionally encode points that are normalized to [0,1].""" 131 | # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape 132 | coords = 2 * coords - 1 133 | coords = coords @ self.positional_encoding_gaussian_matrix 134 | coords = 2 * np.pi * coords 135 | # outputs d_1 x ... x d_n x C shape 136 | return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) 137 | 138 | def forward(self, size: Tuple[int, int]) -> torch.Tensor: 139 | """Generate positional encoding for a grid of the specified size.""" 140 | h, w = size 141 | device: Any = self.positional_encoding_gaussian_matrix.device 142 | grid = torch.ones((h, w), device=device, dtype=torch.float32) 143 | y_embed = grid.cumsum(dim=0) - 0.5 144 | x_embed = grid.cumsum(dim=1) - 0.5 145 | y_embed = y_embed / h 146 | x_embed = x_embed / w 147 | 148 | pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) 149 | return pe.permute(2, 0, 1) # C x H x W 150 | 151 | def forward_with_coords( 152 | self, coords_input: torch.Tensor, image_size: Tuple[int, int] 153 | ) -> torch.Tensor: 154 | """Positionally encode points that are not normalized to [0,1].""" 155 | coords = coords_input.clone() 156 | coords[:, :, 0] = coords[:, :, 0] / image_size[1] 157 | coords[:, :, 1] = coords[:, :, 1] / image_size[0] 158 | return self._pe_encoding(coords.to(torch.float)) # B x N x C 159 | 160 | 161 | # Rotary Positional Encoding, adapted from: 162 | # 1. https://github.com/meta-llama/codellama/blob/main/llama/model.py 163 | # 2. https://github.com/naver-ai/rope-vit 164 | # 3. 
https://github.com/lucidrains/rotary-embedding-torch 165 | 166 | 167 | def init_t_xy(end_x: int, end_y: int): 168 | t = torch.arange(end_x * end_y, dtype=torch.float32) 169 | t_x = (t % end_x).float() 170 | t_y = torch.div(t, end_x, rounding_mode="floor").float() 171 | return t_x, t_y 172 | 173 | 174 | def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 10000.0): 175 | freqs_x = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 176 | freqs_y = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim)) 177 | 178 | t_x, t_y = init_t_xy(end_x, end_y) 179 | freqs_x = torch.outer(t_x, freqs_x) 180 | freqs_y = torch.outer(t_y, freqs_y) 181 | freqs_cis_x = torch.polar(torch.ones_like(freqs_x), freqs_x) 182 | freqs_cis_y = torch.polar(torch.ones_like(freqs_y), freqs_y) 183 | return torch.cat([freqs_cis_x, freqs_cis_y], dim=-1) 184 | 185 | 186 | def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): 187 | ndim = x.ndim 188 | assert 0 <= 1 < ndim 189 | assert freqs_cis.shape == (x.shape[-2], x.shape[-1]) 190 | shape = [d if i >= ndim - 2 else 1 for i, d in enumerate(x.shape)] 191 | return freqs_cis.view(*shape) 192 | 193 | 194 | def apply_rotary_enc( 195 | xq: torch.Tensor, 196 | xk: torch.Tensor, 197 | freqs_cis: torch.Tensor, 198 | repeat_freqs_k: bool = False, 199 | ): 200 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 201 | xk_ = ( 202 | torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 203 | if xk.shape[-2] != 0 204 | else None 205 | ) 206 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 207 | xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) 208 | if xk_ is None: 209 | # no keys to rotate, due to dropout 210 | return xq_out.type_as(xq).to(xq.device), xk 211 | # repeat freqs along seq_len dim to match k seq_len 212 | if repeat_freqs_k: 213 | r = xk_.shape[-2] // xq_.shape[-2] 214 | freqs_cis = freqs_cis.repeat(*([1] * (freqs_cis.ndim - 2)), r, 1) 215 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 216 | return xq_out.type_as(xq).to(xq.device), xk_out.type_as(xk).to(xk.device) 217 | -------------------------------------------------------------------------------- /sam2/modeling/sam/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/modeling/sam/prompt_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional, Tuple, Type 8 | 9 | import torch 10 | from torch import nn 11 | 12 | from sam2.modeling.position_encoding import PositionEmbeddingRandom 13 | 14 | from sam2.modeling.sam2_utils import LayerNorm2d 15 | 16 | 17 | class PromptEncoder(nn.Module): 18 | def __init__( 19 | self, 20 | embed_dim: int, 21 | image_embedding_size: Tuple[int, int], 22 | input_image_size: Tuple[int, int], 23 | mask_in_chans: int, 24 | activation: Type[nn.Module] = nn.GELU, 25 | ) -> None: 26 | """ 27 | Encodes prompts for input to SAM's mask decoder. 
28 | 29 | Arguments: 30 | embed_dim (int): The prompts' embedding dimension 31 | image_embedding_size (tuple(int, int)): The spatial size of the 32 | image embedding, as (H, W). 33 | input_image_size (int): The padded size of the image as input 34 | to the image encoder, as (H, W). 35 | mask_in_chans (int): The number of hidden channels used for 36 | encoding input masks. 37 | activation (nn.Module): The activation to use when encoding 38 | input masks. 39 | """ 40 | super().__init__() 41 | self.embed_dim = embed_dim 42 | self.input_image_size = input_image_size 43 | self.image_embedding_size = image_embedding_size 44 | self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) 45 | 46 | self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners 47 | point_embeddings = [ 48 | nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings) 49 | ] 50 | self.point_embeddings = nn.ModuleList(point_embeddings) 51 | self.not_a_point_embed = nn.Embedding(1, embed_dim) 52 | 53 | self.mask_input_size = ( 54 | 4 * image_embedding_size[0], 55 | 4 * image_embedding_size[1], 56 | ) 57 | self.mask_downscaling = nn.Sequential( 58 | nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), 59 | LayerNorm2d(mask_in_chans // 4), 60 | activation(), 61 | nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), 62 | LayerNorm2d(mask_in_chans), 63 | activation(), 64 | nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), 65 | ) 66 | self.no_mask_embed = nn.Embedding(1, embed_dim) 67 | 68 | def get_dense_pe(self) -> torch.Tensor: 69 | """ 70 | Returns the positional encoding used to encode point prompts, 71 | applied to a dense set of points the shape of the image encoding. 72 | 73 | Returns: 74 | torch.Tensor: Positional encoding with shape 75 | 1x(embed_dim)x(embedding_h)x(embedding_w) 76 | """ 77 | return self.pe_layer(self.image_embedding_size).unsqueeze(0) 78 | 79 | def _embed_points( 80 | self, 81 | points: torch.Tensor, 82 | labels: torch.Tensor, 83 | pad: bool, 84 | ) -> torch.Tensor: 85 | """Embeds point prompts.""" 86 | points = points + 0.5 # Shift to center of pixel 87 | if pad: 88 | padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) 89 | padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) 90 | points = torch.cat([points, padding_point], dim=1) 91 | labels = torch.cat([labels, padding_label], dim=1) 92 | point_embedding = self.pe_layer.forward_with_coords( 93 | points, self.input_image_size 94 | ) 95 | point_embedding[labels == -1] = 0.0 96 | point_embedding[labels == -1] += self.not_a_point_embed.weight 97 | point_embedding[labels == 0] += self.point_embeddings[0].weight 98 | point_embedding[labels == 1] += self.point_embeddings[1].weight 99 | point_embedding[labels == 2] += self.point_embeddings[2].weight 100 | point_embedding[labels == 3] += self.point_embeddings[3].weight 101 | return point_embedding 102 | 103 | def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: 104 | """Embeds box prompts.""" 105 | boxes = boxes + 0.5 # Shift to center of pixel 106 | coords = boxes.reshape(-1, 2, 2) 107 | corner_embedding = self.pe_layer.forward_with_coords( 108 | coords, self.input_image_size 109 | ) 110 | corner_embedding[:, 0, :] += self.point_embeddings[2].weight 111 | corner_embedding[:, 1, :] += self.point_embeddings[3].weight 112 | return corner_embedding 113 | 114 | def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: 115 | """Embeds mask inputs.""" 116 | mask_embedding = self.mask_downscaling(masks) 117 | return 
mask_embedding 118 | 119 | def _get_batch_size( 120 | self, 121 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 122 | boxes: Optional[torch.Tensor], 123 | masks: Optional[torch.Tensor], 124 | ) -> int: 125 | """ 126 | Gets the batch size of the output given the batch size of the input prompts. 127 | """ 128 | if points is not None: 129 | return points[0].shape[0] 130 | elif boxes is not None: 131 | return boxes.shape[0] 132 | elif masks is not None: 133 | return masks.shape[0] 134 | else: 135 | return 1 136 | 137 | def _get_device(self) -> torch.device: 138 | return self.point_embeddings[0].weight.device 139 | 140 | def forward( 141 | self, 142 | points: Optional[Tuple[torch.Tensor, torch.Tensor]], 143 | boxes: Optional[torch.Tensor], 144 | masks: Optional[torch.Tensor], 145 | ) -> Tuple[torch.Tensor, torch.Tensor]: 146 | """ 147 | Embeds different types of prompts, returning both sparse and dense 148 | embeddings. 149 | 150 | Arguments: 151 | points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates 152 | and labels to embed. 153 | boxes (torch.Tensor or none): boxes to embed 154 | masks (torch.Tensor or none): masks to embed 155 | 156 | Returns: 157 | torch.Tensor: sparse embeddings for the points and boxes, with shape 158 | BxNx(embed_dim), where N is determined by the number of input points 159 | and boxes. 160 | torch.Tensor: dense embeddings for the masks, in the shape 161 | Bx(embed_dim)x(embed_H)x(embed_W) 162 | """ 163 | bs = self._get_batch_size(points, boxes, masks) 164 | sparse_embeddings = torch.empty( 165 | (bs, 0, self.embed_dim), device=self._get_device() 166 | ) 167 | if points is not None: 168 | coords, labels = points 169 | point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) 170 | sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) 171 | if boxes is not None: 172 | box_embeddings = self._embed_boxes(boxes) 173 | sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) 174 | 175 | if masks is not None: 176 | dense_embeddings = self._embed_masks(masks) 177 | else: 178 | dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( 179 | bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] 180 | ) 181 | 182 | return sparse_embeddings, dense_embeddings 183 | -------------------------------------------------------------------------------- /sam2/modeling/sam2_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import copy 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | 15 | def select_closest_cond_frames(frame_idx, cond_frame_outputs, max_cond_frame_num): 16 | """ 17 | Select up to `max_cond_frame_num` conditioning frames from `cond_frame_outputs` 18 | that are temporally closest to the current frame at `frame_idx`. Here, we take 19 | - a) the closest conditioning frame before `frame_idx` (if any); 20 | - b) the closest conditioning frame after `frame_idx` (if any); 21 | - c) any other temporally closest conditioning frames until reaching a total 22 | of `max_cond_frame_num` conditioning frames. 23 | 24 | Outputs: 25 | - selected_outputs: selected items (keys & values) from `cond_frame_outputs`. 
26 | - unselected_outputs: items (keys & values) not selected in `cond_frame_outputs`. 27 | """ 28 | if max_cond_frame_num == -1 or len(cond_frame_outputs) <= max_cond_frame_num: 29 | selected_outputs = cond_frame_outputs 30 | unselected_outputs = {} 31 | else: 32 | assert max_cond_frame_num >= 2, "we should allow using 2+ conditioning frames" 33 | selected_outputs = {} 34 | 35 | # the closest conditioning frame before `frame_idx` (if any) 36 | idx_before = max((t for t in cond_frame_outputs if t < frame_idx), default=None) 37 | if idx_before is not None: 38 | selected_outputs[idx_before] = cond_frame_outputs[idx_before] 39 | 40 | # the closest conditioning frame after `frame_idx` (if any) 41 | idx_after = min((t for t in cond_frame_outputs if t >= frame_idx), default=None) 42 | if idx_after is not None: 43 | selected_outputs[idx_after] = cond_frame_outputs[idx_after] 44 | 45 | # add other temporally closest conditioning frames until reaching a total 46 | # of `max_cond_frame_num` conditioning frames. 47 | num_remain = max_cond_frame_num - len(selected_outputs) 48 | inds_remain = sorted( 49 | (t for t in cond_frame_outputs if t not in selected_outputs), 50 | key=lambda x: abs(x - frame_idx), 51 | )[:num_remain] 52 | selected_outputs.update((t, cond_frame_outputs[t]) for t in inds_remain) 53 | unselected_outputs = { 54 | t: v for t, v in cond_frame_outputs.items() if t not in selected_outputs 55 | } 56 | 57 | return selected_outputs, unselected_outputs 58 | 59 | 60 | def get_1d_sine_pe(pos_inds, dim, temperature=10000): 61 | """ 62 | Get 1D sine positional embedding as in the original Transformer paper. 63 | """ 64 | pe_dim = dim // 2 65 | dim_t = torch.arange(pe_dim, dtype=torch.float32, device=pos_inds.device) 66 | dim_t = temperature ** (2 * (dim_t // 2) / pe_dim) 67 | 68 | pos_embed = pos_inds.unsqueeze(-1) / dim_t 69 | pos_embed = torch.cat([pos_embed.sin(), pos_embed.cos()], dim=-1) 70 | return pos_embed 71 | 72 | 73 | def get_activation_fn(activation): 74 | """Return an activation function given a string""" 75 | if activation == "relu": 76 | return F.relu 77 | if activation == "gelu": 78 | return F.gelu 79 | if activation == "glu": 80 | return F.glu 81 | raise RuntimeError(f"activation should be relu/gelu, not {activation}.") 82 | 83 | 84 | def get_clones(module, N): 85 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 86 | 87 | 88 | class DropPath(nn.Module): 89 | # adapted from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py 90 | def __init__(self, drop_prob=0.0, scale_by_keep=True): 91 | super(DropPath, self).__init__() 92 | self.drop_prob = drop_prob 93 | self.scale_by_keep = scale_by_keep 94 | 95 | def forward(self, x): 96 | if self.drop_prob == 0.0 or not self.training: 97 | return x 98 | keep_prob = 1 - self.drop_prob 99 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 100 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 101 | if keep_prob > 0.0 and self.scale_by_keep: 102 | random_tensor.div_(keep_prob) 103 | return x * random_tensor 104 | 105 | 106 | # Lightly adapted from 107 | # https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa 108 | class MLP(nn.Module): 109 | def __init__( 110 | self, 111 | input_dim: int, 112 | hidden_dim: int, 113 | output_dim: int, 114 | num_layers: int, 115 | activation: nn.Module = nn.ReLU, 116 | sigmoid_output: bool = False, 117 | ) -> None: 118 | super().__init__() 119 | self.num_layers = num_layers 120 | h = 
[hidden_dim] * (num_layers - 1) 121 | self.layers = nn.ModuleList( 122 | nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) 123 | ) 124 | self.sigmoid_output = sigmoid_output 125 | self.act = activation() 126 | 127 | def forward(self, x): 128 | for i, layer in enumerate(self.layers): 129 | x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x) 130 | if self.sigmoid_output: 131 | x = F.sigmoid(x) 132 | return x 133 | 134 | 135 | # From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa 136 | # Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa 137 | class LayerNorm2d(nn.Module): 138 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 139 | super().__init__() 140 | self.weight = nn.Parameter(torch.ones(num_channels)) 141 | self.bias = nn.Parameter(torch.zeros(num_channels)) 142 | self.eps = eps 143 | 144 | def forward(self, x: torch.Tensor) -> torch.Tensor: 145 | u = x.mean(1, keepdim=True) 146 | s = (x - u).pow(2).mean(1, keepdim=True) 147 | x = (x - u) / torch.sqrt(s + self.eps) 148 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 149 | return x 150 | -------------------------------------------------------------------------------- /sam2/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import warnings 9 | from threading import Thread 10 | 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | 17 | def get_sdpa_settings(): 18 | if torch.cuda.is_available(): 19 | old_gpu = torch.cuda.get_device_properties(0).major < 7 20 | # only use Flash Attention on Ampere (8.0) or newer GPUs 21 | use_flash_attn = torch.cuda.get_device_properties(0).major >= 8 22 | if not use_flash_attn: 23 | warnings.warn( 24 | "Flash Attention is disabled as it requires a GPU with Ampere (8.0) CUDA capability.", 25 | category=UserWarning, 26 | stacklevel=2, 27 | ) 28 | # keep math kernel for PyTorch versions before 2.2 (Flash Attention v2 is only 29 | # available on PyTorch 2.2+, while Flash Attention v1 cannot handle all cases) 30 | pytorch_version = tuple(int(v) for v in torch.__version__.split(".")[:2]) 31 | if pytorch_version < (2, 2): 32 | warnings.warn( 33 | f"You are using PyTorch {torch.__version__} without Flash Attention v2 support. 
" 34 | "Consider upgrading to PyTorch 2.2+ for Flash Attention v2 (which could be faster).", 35 | category=UserWarning, 36 | stacklevel=2, 37 | ) 38 | math_kernel_on = pytorch_version < (2, 2) or not use_flash_attn 39 | else: 40 | old_gpu = True 41 | use_flash_attn = False 42 | math_kernel_on = True 43 | 44 | return old_gpu, use_flash_attn, math_kernel_on 45 | 46 | 47 | def get_connected_components(mask): 48 | """ 49 | Get the connected components (8-connectivity) of binary masks of shape (N, 1, H, W). 50 | 51 | Inputs: 52 | - mask: A binary mask tensor of shape (N, 1, H, W), where 1 is foreground and 0 is 53 | background. 54 | 55 | Outputs: 56 | - labels: A tensor of shape (N, 1, H, W) containing the connected component labels 57 | for foreground pixels and 0 for background pixels. 58 | - counts: A tensor of shape (N, 1, H, W) containing the area of the connected 59 | components for foreground pixels and 0 for background pixels. 60 | """ 61 | from sam2 import _C 62 | 63 | return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) 64 | 65 | 66 | def mask_to_box(masks: torch.Tensor): 67 | """ 68 | compute bounding box given an input mask 69 | 70 | Inputs: 71 | - masks: [B, 1, H, W] boxes, dtype=torch.Tensor 72 | 73 | Returns: 74 | - box_coords: [B, 1, 4], contains (x, y) coordinates of top left and bottom right box corners, dtype=torch.Tensor 75 | """ 76 | B, _, h, w = masks.shape 77 | device = masks.device 78 | xs = torch.arange(w, device=device, dtype=torch.int32) 79 | ys = torch.arange(h, device=device, dtype=torch.int32) 80 | grid_xs, grid_ys = torch.meshgrid(xs, ys, indexing="xy") 81 | grid_xs = grid_xs[None, None, ...].expand(B, 1, h, w) 82 | grid_ys = grid_ys[None, None, ...].expand(B, 1, h, w) 83 | min_xs, _ = torch.min(torch.where(masks, grid_xs, w).flatten(-2), dim=-1) 84 | max_xs, _ = torch.max(torch.where(masks, grid_xs, -1).flatten(-2), dim=-1) 85 | min_ys, _ = torch.min(torch.where(masks, grid_ys, h).flatten(-2), dim=-1) 86 | max_ys, _ = torch.max(torch.where(masks, grid_ys, -1).flatten(-2), dim=-1) 87 | bbox_coords = torch.stack((min_xs, min_ys, max_xs, max_ys), dim=-1) 88 | 89 | return bbox_coords 90 | 91 | 92 | def _load_img_as_tensor(img_path, image_size): 93 | img_pil = Image.open(img_path) 94 | img_np = np.array(img_pil.convert("RGB").resize((image_size, image_size))) 95 | if img_np.dtype == np.uint8: # np.uint8 is expected for JPEG images 96 | img_np = img_np / 255.0 97 | else: 98 | raise RuntimeError(f"Unknown image dtype: {img_np.dtype} on {img_path}") 99 | img = torch.from_numpy(img_np).permute(2, 0, 1) 100 | video_width, video_height = img_pil.size # the original video size 101 | return img, video_height, video_width 102 | 103 | 104 | class AsyncVideoFrameLoader: 105 | """ 106 | A list of video frames to be load asynchronously without blocking session start. 
107 | """ 108 | 109 | def __init__(self, img_paths, image_size, offload_video_to_cpu, img_mean, img_std): 110 | self.img_paths = img_paths 111 | self.image_size = image_size 112 | self.offload_video_to_cpu = offload_video_to_cpu 113 | self.img_mean = img_mean 114 | self.img_std = img_std 115 | # items in `self._images` will be loaded asynchronously 116 | self.images = [None] * len(img_paths) 117 | # catch and raise any exceptions in the async loading thread 118 | self.exception = None 119 | # video_height and video_width be filled when loading the first image 120 | self.video_height = None 121 | self.video_width = None 122 | 123 | # load the first frame to fill video_height and video_width and also 124 | # to cache it (since it's most likely where the user will click) 125 | self.__getitem__(0) 126 | 127 | # load the rest of frames asynchronously without blocking the session start 128 | def _load_frames(): 129 | try: 130 | for n in tqdm(range(len(self.images)), desc="frame loading (JPEG)"): 131 | self.__getitem__(n) 132 | except Exception as e: 133 | self.exception = e 134 | 135 | self.thread = Thread(target=_load_frames, daemon=True) 136 | self.thread.start() 137 | 138 | def __getitem__(self, index): 139 | if self.exception is not None: 140 | raise RuntimeError("Failure in frame loading thread") from self.exception 141 | 142 | img = self.images[index] 143 | if img is not None: 144 | return img 145 | 146 | img, video_height, video_width = _load_img_as_tensor( 147 | self.img_paths[index], self.image_size 148 | ) 149 | self.video_height = video_height 150 | self.video_width = video_width 151 | # normalize by mean and std 152 | img -= self.img_mean 153 | img /= self.img_std 154 | if not self.offload_video_to_cpu: 155 | img = img.cuda(non_blocking=True) 156 | self.images[index] = img 157 | return img 158 | 159 | def __len__(self): 160 | return len(self.images) 161 | 162 | 163 | def load_video_frames( 164 | video_path, 165 | image_size, 166 | offload_video_to_cpu, 167 | img_mean=(0.485, 0.456, 0.406), 168 | img_std=(0.229, 0.224, 0.225), 169 | async_loading_frames=False, 170 | ): 171 | """ 172 | Load the video frames from a directory of JPEG files (".jpg" format). 173 | 174 | The frames are resized to image_size x image_size and are loaded to GPU if 175 | `offload_video_to_cpu` is `False` and to CPU if `offload_video_to_cpu` is `True`. 176 | 177 | You can load a frame asynchronously by setting `async_loading_frames` to `True`. 
178 | """ 179 | if isinstance(video_path, str) and os.path.isdir(video_path): 180 | jpg_folder = video_path 181 | else: 182 | raise NotImplementedError("Only JPEG frames are supported at this moment") 183 | 184 | frame_names = [ 185 | p 186 | for p in os.listdir(jpg_folder) 187 | if os.path.splitext(p)[-1] in [".jpg", ".jpeg", ".JPG", ".JPEG"] 188 | ] 189 | frame_names.sort(key=lambda p: int(os.path.splitext(p)[0])) 190 | num_frames = len(frame_names) 191 | if num_frames == 0: 192 | raise RuntimeError(f"no images found in {jpg_folder}") 193 | img_paths = [os.path.join(jpg_folder, frame_name) for frame_name in frame_names] 194 | img_mean = torch.tensor(img_mean, dtype=torch.float32)[:, None, None] 195 | img_std = torch.tensor(img_std, dtype=torch.float32)[:, None, None] 196 | 197 | if async_loading_frames: 198 | lazy_images = AsyncVideoFrameLoader( 199 | img_paths, image_size, offload_video_to_cpu, img_mean, img_std 200 | ) 201 | return lazy_images, lazy_images.video_height, lazy_images.video_width 202 | 203 | images = torch.zeros(num_frames, 3, image_size, image_size, dtype=torch.float32) 204 | for n, img_path in enumerate(tqdm(img_paths, desc="frame loading (JPEG)")): 205 | images[n], video_height, video_width = _load_img_as_tensor(img_path, image_size) 206 | if not offload_video_to_cpu: 207 | images = images.cuda() 208 | img_mean = img_mean.cuda() 209 | img_std = img_std.cuda() 210 | # normalize by mean and std 211 | images -= img_mean 212 | images /= img_std 213 | return images, video_height, video_width 214 | 215 | 216 | def fill_holes_in_mask_scores(mask, max_area): 217 | """ 218 | A post processor to fill small holes in mask scores with area under `max_area`. 219 | """ 220 | # Holes are those connected components in background with area <= self.max_area 221 | # (background regions are those with mask scores <= 0) 222 | assert max_area > 0, "max_area must be positive" 223 | labels, areas = get_connected_components(mask <= 0) 224 | is_hole = (labels > 0) & (areas <= max_area) 225 | # We fill holes with a small positive mask score (0.1) to change them to foreground. 226 | mask = torch.where(is_hole, 0.1, mask) 227 | return mask 228 | 229 | 230 | def concat_points(old_point_inputs, new_points, new_labels): 231 | """Add new points and labels to previous point inputs (add at the end).""" 232 | if old_point_inputs is None: 233 | points, labels = new_points, new_labels 234 | else: 235 | points = torch.cat([old_point_inputs["point_coords"], new_points], dim=1) 236 | labels = torch.cat([old_point_inputs["point_labels"], new_labels], dim=1) 237 | 238 | return {"point_coords": points, "point_labels": labels} 239 | -------------------------------------------------------------------------------- /sam2/utils/test.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /sam2/utils/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from torchvision.transforms import Normalize, Resize, ToTensor 11 | 12 | 13 | class SAM2Transforms(nn.Module): 14 | def __init__( 15 | self, resolution, mask_threshold, max_hole_area=0.0, max_sprinkle_area=0.0 16 | ): 17 | """ 18 | Transforms for SAM2. 19 | """ 20 | super().__init__() 21 | self.resolution = resolution 22 | self.mask_threshold = mask_threshold 23 | self.max_hole_area = max_hole_area 24 | self.max_sprinkle_area = max_sprinkle_area 25 | self.mean = [0.485, 0.456, 0.406] 26 | self.std = [0.229, 0.224, 0.225] 27 | self.to_tensor = ToTensor() 28 | self.transforms = torch.jit.script( 29 | nn.Sequential( 30 | Resize((self.resolution, self.resolution)), 31 | Normalize(self.mean, self.std), 32 | ) 33 | ) 34 | 35 | def __call__(self, x): 36 | x = self.to_tensor(x) 37 | return self.transforms(x) 38 | 39 | def forward_batch(self, img_list): 40 | img_batch = [self.transforms(self.to_tensor(img)) for img in img_list] 41 | img_batch = torch.stack(img_batch, dim=0) 42 | return img_batch 43 | 44 | def transform_coords( 45 | self, coords: torch.Tensor, normalize=False, orig_hw=None 46 | ) -> torch.Tensor: 47 | """ 48 | Expects a torch tensor with length 2 in the last dimension. The coordinates can be in absolute image or normalized coordinates, 49 | If the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 50 | 51 | Returns 52 | Un-normalized coordinates in the range of [0, 1] which is expected by the SAM2 model. 53 | """ 54 | if normalize: 55 | assert orig_hw is not None 56 | h, w = orig_hw 57 | coords = coords.clone() 58 | coords[..., 0] = coords[..., 0] / w 59 | coords[..., 1] = coords[..., 1] / h 60 | 61 | coords = coords * self.resolution # unnormalize coords 62 | return coords 63 | 64 | def transform_boxes( 65 | self, boxes: torch.Tensor, normalize=False, orig_hw=None 66 | ) -> torch.Tensor: 67 | """ 68 | Expects a tensor of shape Bx4. The coordinates can be in absolute image or normalized coordinates, 69 | if the coords are in absolute image coordinates, normalize should be set to True and original image size is required. 70 | """ 71 | boxes = self.transform_coords(boxes.reshape(-1, 2, 2), normalize, orig_hw) 72 | return boxes 73 | 74 | def postprocess_masks(self, masks: torch.Tensor, orig_hw) -> torch.Tensor: 75 | """ 76 | Perform PostProcessing on output masks. 77 | """ 78 | from sam2.utils.misc import get_connected_components 79 | 80 | masks = masks.float() 81 | if self.max_hole_area > 0: 82 | # Holes are those connected components in background with area <= self.fill_hole_area 83 | # (background regions are those with mask scores <= self.mask_threshold) 84 | mask_flat = masks.flatten(0, 1).unsqueeze(1) # flatten as 1-channel image 85 | labels, areas = get_connected_components(mask_flat <= self.mask_threshold) 86 | is_hole = (labels > 0) & (areas <= self.max_hole_area) 87 | is_hole = is_hole.reshape_as(masks) 88 | # We fill holes with a small positive mask score (10.0) to change them to foreground. 89 | masks = torch.where(is_hole, self.mask_threshold + 10.0, masks) 90 | 91 | if self.max_sprinkle_area > 0: 92 | labels, areas = get_connected_components(mask_flat > self.mask_threshold) 93 | is_hole = (labels > 0) & (areas <= self.max_sprinkle_area) 94 | is_hole = is_hole.reshape_as(masks) 95 | # We fill holes with negative mask score (-10.0) to change them to background. 
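# (added note: `mask_flat` is only created inside the `max_hole_area > 0` branch
#  above, so this sprinkle-removal branch implicitly assumes hole filling is
#  enabled whenever `max_sprinkle_area > 0`)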
96 | masks = torch.where(is_hole, self.mask_threshold - 10.0, masks) 97 | 98 | masks = F.interpolate(masks, orig_hw, mode="bilinear", align_corners=False) 99 | return masks 100 | -------------------------------------------------------------------------------- /sam2_configs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_b+.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 112 12 | num_heads: 2 13 | neck: 14 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 15 | position_encoding: 16 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 17 | num_pos_feats: 256 18 | normalize: true 19 | scale: null 20 | temperature: 10000 21 | d_model: 256 22 | backbone_channel_list: [896, 448, 224, 112] 23 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 24 | fpn_interp_model: nearest 25 | 26 | memory_attention: 27 | _target_: sam2.modeling.memory_attention.MemoryAttention 28 | d_model: 256 29 | pos_enc_at_input: true 30 | layer: 31 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 32 | activation: relu 33 | dim_feedforward: 2048 34 | dropout: 0.1 35 | pos_enc_at_attn: false 36 | self_attention: 37 | _target_: sam2.modeling.sam.transformer.RoPEAttention 38 | rope_theta: 10000.0 39 | feat_sizes: [32, 32] 40 | embedding_dim: 256 41 | num_heads: 1 42 | downsample_rate: 1 43 | dropout: 0.1 44 | d_model: 256 45 | pos_enc_at_cross_attn_keys: true 46 | pos_enc_at_cross_attn_queries: false 47 | cross_attention: 48 | _target_: sam2.modeling.sam.transformer.RoPEAttention 49 | rope_theta: 10000.0 50 | feat_sizes: [32, 32] 51 | rope_k_repeat: True 52 | embedding_dim: 256 53 | num_heads: 1 54 | downsample_rate: 1 55 | dropout: 0.1 56 | kv_in_dim: 64 57 | num_layers: 4 58 | 59 | memory_encoder: 60 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 61 | out_dim: 64 62 | position_encoding: 63 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 64 | num_pos_feats: 64 65 | normalize: true 66 | scale: null 67 | temperature: 10000 68 | mask_downsampler: 69 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 70 | kernel_size: 3 71 | stride: 2 72 | padding: 1 73 | fuser: 74 | _target_: sam2.modeling.memory_encoder.Fuser 75 | layer: 76 | _target_: sam2.modeling.memory_encoder.CXBlock 77 | dim: 256 78 | kernel_size: 7 79 | padding: 3 80 | layer_scale_init_value: 1e-6 81 | use_dwconv: True # depth-wise convs 82 | num_layers: 2 83 | 84 | num_maskmem: 7 85 | image_size: 1024 86 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 87 | sigmoid_scale_for_mem_enc: 20.0 88 | sigmoid_bias_for_mem_enc: -10.0 89 | use_mask_input_as_output_without_sam: true 90 | # Memory 91 | directly_add_no_mem_embed: true 92 | # use high-resolution feature map in the SAM mask decoder 93 | use_high_res_features_in_sam: true 94 | # output 3 
masks on the first click on initial conditioning frames 95 | multimask_output_in_sam: true 96 | # SAM heads 97 | iou_prediction_use_sigmoid: True 98 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 99 | use_obj_ptrs_in_encoder: true 100 | add_tpos_enc_to_obj_ptrs: false 101 | only_obj_ptrs_in_the_past_for_eval: true 102 | # object occlusion prediction 103 | pred_obj_scores: true 104 | pred_obj_scores_mlp: true 105 | fixed_no_obj_ptr: true 106 | # multimask tracking settings 107 | multimask_output_for_tracking: true 108 | use_multimask_token_for_obj_ptr: true 109 | multimask_min_pt_num: 0 110 | multimask_max_pt_num: 1 111 | use_mlp_for_obj_ptr_proj: true 112 | # Compilation flag 113 | compile_image_encoder: False 114 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_l.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 144 12 | num_heads: 2 13 | stages: [2, 6, 36, 4] 14 | global_att_blocks: [23, 33, 43] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | window_spec: [8, 4, 16, 8] 17 | neck: 18 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 19 | position_encoding: 20 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 21 | num_pos_feats: 256 22 | normalize: true 23 | scale: null 24 | temperature: 10000 25 | d_model: 256 26 | backbone_channel_list: [1152, 576, 288, 144] 27 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 28 | fpn_interp_model: nearest 29 | 30 | memory_attention: 31 | _target_: sam2.modeling.memory_attention.MemoryAttention 32 | d_model: 256 33 | pos_enc_at_input: true 34 | layer: 35 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 36 | activation: relu 37 | dim_feedforward: 2048 38 | dropout: 0.1 39 | pos_enc_at_attn: false 40 | self_attention: 41 | _target_: sam2.modeling.sam.transformer.RoPEAttention 42 | rope_theta: 10000.0 43 | feat_sizes: [32, 32] 44 | embedding_dim: 256 45 | num_heads: 1 46 | downsample_rate: 1 47 | dropout: 0.1 48 | d_model: 256 49 | pos_enc_at_cross_attn_keys: true 50 | pos_enc_at_cross_attn_queries: false 51 | cross_attention: 52 | _target_: sam2.modeling.sam.transformer.RoPEAttention 53 | rope_theta: 10000.0 54 | feat_sizes: [32, 32] 55 | rope_k_repeat: True 56 | embedding_dim: 256 57 | num_heads: 1 58 | downsample_rate: 1 59 | dropout: 0.1 60 | kv_in_dim: 64 61 | num_layers: 4 62 | 63 | memory_encoder: 64 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 65 | out_dim: 64 66 | position_encoding: 67 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 68 | num_pos_feats: 64 69 | normalize: true 70 | scale: null 71 | temperature: 10000 72 | mask_downsampler: 73 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 74 | kernel_size: 3 75 | stride: 2 76 | padding: 1 77 | fuser: 78 | _target_: sam2.modeling.memory_encoder.Fuser 79 | layer: 80 | _target_: sam2.modeling.memory_encoder.CXBlock 81 | dim: 256 82 | kernel_size: 7 83 | padding: 3 84 | layer_scale_init_value: 1e-6 85 | use_dwconv: True # depth-wise convs 86 | num_layers: 2 87 | 88 | num_maskmem: 7 89 | image_size: 1024 90 | # apply scaled sigmoid on mask logits for memory encoder, and 
directly feed input mask as output mask 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | compile_image_encoder: False 118 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_s.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 11, 2] 14 | global_att_blocks: [7, 10, 13] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 | num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | 
_target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | sigmoid_scale_for_mem_enc: 20.0 91 | sigmoid_bias_for_mem_enc: -10.0 92 | use_mask_input_as_output_without_sam: true 93 | # Memory 94 | directly_add_no_mem_embed: true 95 | # use high-resolution feature map in the SAM mask decoder 96 | use_high_res_features_in_sam: true 97 | # output 3 masks on the first click on initial conditioning frames 98 | multimask_output_in_sam: true 99 | # SAM heads 100 | iou_prediction_use_sigmoid: True 101 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 102 | use_obj_ptrs_in_encoder: true 103 | add_tpos_enc_to_obj_ptrs: false 104 | only_obj_ptrs_in_the_past_for_eval: true 105 | # object occlusion prediction 106 | pred_obj_scores: true 107 | pred_obj_scores_mlp: true 108 | fixed_no_obj_ptr: true 109 | # multimask tracking settings 110 | multimask_output_for_tracking: true 111 | use_multimask_token_for_obj_ptr: true 112 | multimask_min_pt_num: 0 113 | multimask_max_pt_num: 1 114 | use_mlp_for_obj_ptr_proj: true 115 | # Compilation flag 116 | compile_image_encoder: False 117 | -------------------------------------------------------------------------------- /sam2_configs/sam2_hiera_t.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | # Model 4 | model: 5 | _target_: sam2.modeling.sam2_base.SAM2Base 6 | image_encoder: 7 | _target_: sam2.modeling.backbones.image_encoder.ImageEncoder 8 | scalp: 1 9 | trunk: 10 | _target_: sam2.modeling.backbones.hieradet.Hiera 11 | embed_dim: 96 12 | num_heads: 1 13 | stages: [1, 2, 7, 2] 14 | global_att_blocks: [5, 7, 9] 15 | window_pos_embed_bkg_spatial_size: [7, 7] 16 | neck: 17 | _target_: sam2.modeling.backbones.image_encoder.FpnNeck 18 | position_encoding: 19 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 20 | num_pos_feats: 256 21 | normalize: true 22 | scale: null 23 | temperature: 10000 24 | d_model: 256 25 | backbone_channel_list: [768, 384, 192, 96] 26 | fpn_top_down_levels: [2, 3] # output level 0 and 1 directly use the backbone features 27 | fpn_interp_model: nearest 28 | 29 | memory_attention: 30 | _target_: sam2.modeling.memory_attention.MemoryAttention 31 | d_model: 256 32 | pos_enc_at_input: true 33 | layer: 34 | _target_: sam2.modeling.memory_attention.MemoryAttentionLayer 35 | activation: relu 36 | dim_feedforward: 2048 37 | dropout: 0.1 38 | pos_enc_at_attn: false 39 | self_attention: 40 | _target_: sam2.modeling.sam.transformer.RoPEAttention 41 | rope_theta: 10000.0 42 | feat_sizes: [32, 32] 43 | embedding_dim: 256 44 | num_heads: 1 45 | downsample_rate: 1 46 | dropout: 0.1 47 | d_model: 256 48 | pos_enc_at_cross_attn_keys: true 49 | pos_enc_at_cross_attn_queries: false 50 | cross_attention: 51 | _target_: sam2.modeling.sam.transformer.RoPEAttention 52 | rope_theta: 10000.0 53 | feat_sizes: [32, 32] 54 | rope_k_repeat: True 55 | embedding_dim: 256 56 | num_heads: 1 57 | downsample_rate: 1 58 | dropout: 0.1 59 | kv_in_dim: 64 60 | num_layers: 4 61 | 62 | memory_encoder: 63 | _target_: sam2.modeling.memory_encoder.MemoryEncoder 64 | out_dim: 64 65 | position_encoding: 66 | _target_: sam2.modeling.position_encoding.PositionEmbeddingSine 67 
| num_pos_feats: 64 68 | normalize: true 69 | scale: null 70 | temperature: 10000 71 | mask_downsampler: 72 | _target_: sam2.modeling.memory_encoder.MaskDownSampler 73 | kernel_size: 3 74 | stride: 2 75 | padding: 1 76 | fuser: 77 | _target_: sam2.modeling.memory_encoder.Fuser 78 | layer: 79 | _target_: sam2.modeling.memory_encoder.CXBlock 80 | dim: 256 81 | kernel_size: 7 82 | padding: 3 83 | layer_scale_init_value: 1e-6 84 | use_dwconv: True # depth-wise convs 85 | num_layers: 2 86 | 87 | num_maskmem: 7 88 | image_size: 1024 89 | # apply scaled sigmoid on mask logits for memory encoder, and directly feed input mask as output mask 90 | # SAM decoder 91 | sigmoid_scale_for_mem_enc: 20.0 92 | sigmoid_bias_for_mem_enc: -10.0 93 | use_mask_input_as_output_without_sam: true 94 | # Memory 95 | directly_add_no_mem_embed: true 96 | # use high-resolution feature map in the SAM mask decoder 97 | use_high_res_features_in_sam: true 98 | # output 3 masks on the first click on initial conditioning frames 99 | multimask_output_in_sam: true 100 | # SAM heads 101 | iou_prediction_use_sigmoid: True 102 | # cross-attend to object pointers from other frames (based on SAM output tokens) in the encoder 103 | use_obj_ptrs_in_encoder: true 104 | add_tpos_enc_to_obj_ptrs: false 105 | only_obj_ptrs_in_the_past_for_eval: true 106 | # object occlusion prediction 107 | pred_obj_scores: true 108 | pred_obj_scores_mlp: true 109 | fixed_no_obj_ptr: true 110 | # multimask tracking settings 111 | multimask_output_for_tracking: true 112 | use_multimask_token_for_obj_ptr: true 113 | multimask_min_pt_num: 0 114 | multimask_max_pt_num: 1 115 | use_mlp_for_obj_ptr_proj: true 116 | # Compilation flag 117 | # HieraT does not currently support compilation, should always be set to False 118 | compile_image_encoder: False 119 | -------------------------------------------------------------------------------- /sav_dataset/LICENSE: -------------------------------------------------------------------------------- 1 | BSD License 2 | 3 | For SAM 2 Eval software 4 | 5 | Copyright (c) Meta Platforms, Inc. and affiliates. 6 | 7 | Redistribution and use in source and binary forms, with or without modification, 8 | are permitted provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, this 11 | list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright notice, 14 | this list of conditions and the following disclaimer in the documentation 15 | and/or other materials provided with the distribution. 16 | 17 | * Neither the name Meta nor the names of its contributors may be used to 18 | endorse or promote products derived from this software without specific 19 | prior written permission. 20 | 21 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 22 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 23 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 24 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 25 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 26 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 27 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 28 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 29 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 30 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 31 | -------------------------------------------------------------------------------- /sav_dataset/LICENSE_DAVIS: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, DAVIS: Densely Annotated VIdeo Segmentation 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /sav_dataset/LICENSE_VOS_BENCHMARK: -------------------------------------------------------------------------------- 1 | Copyright 2023 Rex Cheng 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -------------------------------------------------------------------------------- /sav_dataset/README.md: -------------------------------------------------------------------------------- 1 | # Segment Anything Video (SA-V) Dataset 2 | 3 | ## Overview 4 | 5 | [Segment Anything Video (SA-V)](https://ai.meta.com/datasets/segment-anything-video/) consists of 51K diverse videos and 643K high-quality spatio-temporal segmentation masks (i.e., masklets). The dataset is released under the CC BY 4.0 license. Browse the dataset [here](https://sam2.metademolab.com/dataset). 6 | 7 | ![SA-V dataset](../assets/sa_v_dataset.jpg?raw=true) 8 | 9 | ## Getting Started 10 | 11 | ### Download the dataset 12 | 13 | Visit [here](https://ai.meta.com/datasets/segment-anything-video-downloads/) to download SA-V, including the training, val and test sets. 14 | 15 | ### Dataset Stats 16 | 17 | | | Num Videos | Num Masklets | 18 | | ---------- | ---------- | ----------------------------------------- | 19 | | SA-V train | 50,583 | 642,036 (auto 451,720 and manual 190,316) | 20 | | SA-V val | 155 | 293 | 21 | | SA-V test | 150 | 278 | 22 | 23 | ### Notebooks 24 | 25 | To load and visualize the SA-V training set annotations, refer to the example [sav_visualization_example.ipynb](./sav_visualization_example.ipynb) notebook. 26 | 27 | ### SA-V train 28 | 29 | For the SA-V training set, we release the mp4 videos and store the masklet annotations per video as JSON files. Automatic masklets and manual masklets are stored separately as two JSON files: `{video_id}_auto.json` and `{video_id}_manual.json`. They can be loaded as dictionaries in Python in the format below. 30 | 31 | ``` 32 | { 33 | "video_id" : str; video id 34 | "video_duration" : float64; the duration in seconds of this video 35 | "video_frame_count" : float64; the number of frames in the video 36 | "video_height" : float64; the height of the video 37 | "video_width" : float64; the width of the video 38 | "video_resolution" : float64; video_height $\times$ video_width 39 | "video_environment" : List[str]; "Indoor" or "Outdoor" 40 | "video_split" : str; "train" for training set 41 | "masklet" : List[List[Dict]]; masklet annotations in list of list of RLEs. 42 | The outer list is over frames in the video and the inner list 43 | is over objects in the video. 44 | "masklet_id" : List[int]; the masklet ids 45 | "masklet_size_rel" : List[float]; the average mask area normalized by resolution 46 | across all the frames where the object is visible 47 | "masklet_size_abs" : List[float]; the average mask area (in pixels) 48 | across all the frames where the object is visible 49 | "masklet_size_bucket" : List[str]; "small": $1$ <= masklet_size_abs < $32^2$, 50 | "medium": $32^2$ <= masklet_size_abs < $96^2$, 51 | and "large": masklet_size_abs > $96^2$ 52 | "masklet_visibility_changes" : List[int]; the number of times where the visibility changes 53 | after the first appearance (e.g., invisible -> visible 54 | or visible -> invisible) 55 | "masklet_first_appeared_frame" : List[int]; the index of the frame where the object appears 56 | the first time in the video. Always 0 for auto masklets. 57 | "masklet_frame_count" : List[int]; the number of frames being annotated.
Note that 58 | videos are annotated at 6 fps (annotated every 4 frames) 59 | while the videos are at 24 fps. 60 | "masklet_edited_frame_count" : List[int]; the number of frames being edited by human annotators. 61 | Always 0 for auto masklets. 62 | "masklet_type" : List[str]; "auto" or "manual" 63 | "masklet_stability_score" : Optional[List[List[float]]]; per-mask stability scores. Auto annotation only. 64 | "masklet_num" : int; the number of manual/auto masklets in the video 65 | 66 | } 67 | ``` 68 | 69 | Note that in SA-V train, there are in total 50,583 videos where all of them have manual annotations. Among the 50,583 videos there are 48,436 videos that also have automatic annotations. 70 | 71 | ### SA-V val and test 72 | 73 | For SA-V val and test sets, we release the extracted frames as jpeg files, and the masks as png files with the following directory structure: 74 | 75 | ``` 76 | sav_val(sav_test) 77 | ├── sav_val.txt (sav_test.txt): a list of video ids in the split 78 | ├── JPEGImages_24fps # videos are extracted at 24 fps 79 | │ ├── {video_id} 80 | │ │ ├── 00000.jpg # video frame 81 | │ │ ├── 00001.jpg # video frame 82 | │ │ ├── 00002.jpg # video frame 83 | │ │ ├── 00003.jpg # video frame 84 | │ │ └── ... 85 | │ ├── {video_id} 86 | │ ├── {video_id} 87 | │ └── ... 88 | └── Annotations_6fps # videos are annotated at 6 fps 89 | ├── {video_id} 90 | │ ├── 000 # obj 000 91 | │ │ ├── 00000.png # mask for object 000 in 00000.jpg 92 | │ │ ├── 00004.png # mask for object 000 in 00004.jpg 93 | │ │ ├── 00008.png # mask for object 000 in 00008.jpg 94 | │ │ ├── 00012.png # mask for object 000 in 00012.jpg 95 | │ │ └── ... 96 | │ ├── 001 # obj 001 97 | │ ├── 002 # obj 002 98 | │ └── ... 99 | ├── {video_id} 100 | ├── {video_id} 101 | └── ... 102 | ``` 103 | 104 | All masklets in val and test sets are manually annotated in every frame by annotators. For each annotated object in a video, we store the annotated masks in a single png. This is because the annotated objects may overlap, e.g., it is possible in our SA-V dataset for there to be a mask for the whole person as well as a separate mask for their hands. 105 | 106 | ## SA-V Val and Test Evaluation 107 | 108 | We provide an evaluator to compute the common J and F metrics on SA-V val and test sets. To run the evaluation, we need to first install a few dependencies as follows: 109 | 110 | ``` 111 | pip install -r requirements.txt 112 | ``` 113 | 114 | Then we can evaluate the predictions as follows: 115 | 116 | ``` 117 | python sav_evaluator.py --gt_root {GT_ROOT} --pred_root {PRED_ROOT} 118 | ``` 119 | 120 | or run 121 | 122 | ``` 123 | python sav_evaluator.py --help 124 | ``` 125 | 126 | to print a complete help message. 127 | 128 | The evaluator expects the `GT_ROOT` to be one of the following folder structures, and `GT_ROOT` and `PRED_ROOT` to have the same structure. 129 | 130 | - Same as SA-V val and test directory structure 131 | 132 | ``` 133 | {GT_ROOT} # gt root folder 134 | ├── {video_id} 135 | │ ├── 000 # all masks associated with obj 000 136 | │ │ ├── 00000.png # mask for object 000 in frame 00000 (binary mask) 137 | │ │ └── ... 138 | │ ├── 001 # all masks associated with obj 001 139 | │ ├── 002 # all masks associated with obj 002 140 | │ └── ... 141 | ├── {video_id} 142 | ├── {video_id} 143 | └── ... 144 | ``` 145 | 146 | In the paper for the experiments on SA-V val and test, we run inference on the 24 fps videos, and evaluate on the subset of frames where we have ground truth annotations (first and last annotated frames dropped). 
The evaluator will ignore the masks in frames where we don't have ground truth annotations. 147 | 148 | - Same as [DAVIS](https://github.com/davisvideochallenge/davis2017-evaluation) directory structure 149 | 150 | ``` 151 | {GT_ROOT} # gt root folder 152 | ├── {video_id} 153 | │ ├── 00000.png # annotations in frame 00000 (may contain multiple objects) 154 | │ └── ... 155 | ├── {video_id} 156 | ├── {video_id} 157 | └── ... 158 | ``` 159 | 160 | ## License 161 | 162 | The evaluation code is licensed under the [BSD 3 license](./LICENSE). Please refer to the paper for more details on the models. The videos and annotations in SA-V Dataset are released under CC BY 4.0. 163 | 164 | Third-party code: the evaluation software is heavily adapted from [`VOS-Benchmark`](https://github.com/hkchengrex/vos-benchmark) and [`DAVIS`](https://github.com/davisvideochallenge/davis2017-evaluation) (with their licenses in [`LICENSE_DAVIS`](./LICENSE_DAVIS) and [`LICENSE_VOS_BENCHMARK`](./LICENSE_VOS_BENCHMARK)). 165 | -------------------------------------------------------------------------------- /sav_dataset/requirements.txt: -------------------------------------------------------------------------------- 1 | pycocoevalcap 2 | scikit-image 3 | opencv-python 4 | tqdm 5 | pillow 6 | numpy 7 | matplotlib -------------------------------------------------------------------------------- /sav_dataset/sav_evaluator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the sav_dataset directory of this source tree. 6 | 7 | # adapted from https://github.com/hkchengrex/vos-benchmark 8 | # and https://github.com/davisvideochallenge/davis2017-evaluation 9 | # with their licenses found in the LICENSE_VOS_BENCHMARK and LICENSE_DAVIS files 10 | # in the sav_dataset directory. 11 | from argparse import ArgumentParser 12 | 13 | from utils.sav_benchmark import benchmark 14 | 15 | """ 16 | The structure of the {GT_ROOT} can be either of the follow two structures. 17 | {GT_ROOT} and {PRED_ROOT} should be of the same format 18 | 19 | 1. SA-V val/test structure 20 | {GT_ROOT} # gt root folder 21 | ├── {video_id} 22 | │ ├── 000 # all masks associated with obj 000 23 | │ │ ├── {frame_id}.png # mask for object 000 in {frame_id} (binary mask) 24 | │ │ └── ... 25 | │ ├── 001 # all masks associated with obj 001 26 | │ ├── 002 # all masks associated with obj 002 27 | │ └── ... 28 | ├── {video_id} 29 | ├── {video_id} 30 | └── ... 31 | 32 | 2. Similar to DAVIS structure: 33 | 34 | {GT_ROOT} # gt root folder 35 | ├── {video_id} 36 | │ ├── {frame_id}.png # annotation in {frame_id} (may contain multiple objects) 37 | │ └── ... 38 | ├── {video_id} 39 | ├── {video_id} 40 | └── ... 41 | """ 42 | 43 | 44 | parser = ArgumentParser() 45 | parser.add_argument( 46 | "--gt_root", 47 | required=True, 48 | help="Path to the GT folder. 
For SA-V, it's sav_val/Annotations_6fps or sav_test/Annotations_6fps", 49 | ) 50 | parser.add_argument( 51 | "--pred_root", 52 | required=True, 53 | help="Path to a folder containing folders of masks to be evaluated, with exactly the same structure as gt_root", 54 | ) 55 | parser.add_argument( 56 | "-n", "--num_processes", default=16, type=int, help="Number of concurrent processes" 57 | ) 58 | parser.add_argument( 59 | "-s", 60 | "--strict", 61 | help="Make sure every video in the gt_root folder has a corresponding video in the prediction", 62 | action="store_true", 63 | ) 64 | parser.add_argument( 65 | "-q", 66 | "--quiet", 67 | help="Quietly run evaluation without printing the information out", 68 | action="store_true", 69 | ) 70 | 71 | # https://github.com/davisvideochallenge/davis2017-evaluation/blob/d34fdef71ce3cb24c1a167d860b707e575b3034c/davis2017/evaluation.py#L85 72 | parser.add_argument( 73 | "--do_not_skip_first_and_last_frame", 74 | help="In SA-V val and test, we skip the first and the last annotated frames in evaluation. " 75 | "Set this to true for evaluation on settings that don't skip first and last frames", 76 | action="store_true", 77 | ) 78 | 79 | 80 | if __name__ == "__main__": 81 | args = parser.parse_args() 82 | benchmark( 83 | [args.gt_root], 84 | [args.pred_root], 85 | args.strict, 86 | args.num_processes, 87 | verbose=not args.quiet, 88 | skip_first_and_last=not args.do_not_skip_first_and_last_frame, 89 | ) 90 | -------------------------------------------------------------------------------- /sav_dataset/utils/sav_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the sav_dataset directory of this source tree.
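The evaluator above can also be driven directly from Python rather than through the CLI. Below is a minimal sketch (separate from the repository files) that mirrors the `benchmark(...)` call in `sav_evaluator.py`; the two paths are placeholders, and the import assumes the code is run from the `sav_dataset` directory.

```python
# Minimal sketch: programmatic equivalent of
#   python sav_evaluator.py --gt_root ... --pred_root ...
# Paths are placeholders; arguments follow the benchmark(...) call in sav_evaluator.py.
from utils.sav_benchmark import benchmark

gt_root = "sav_val/Annotations_6fps"        # placeholder: ground-truth annotation folder
pred_root = "outputs/sav_val_predictions"   # placeholder: predictions with the same layout

benchmark(
    [gt_root],
    [pred_root],
    False,   # strict: don't require a prediction for every GT video
    16,      # number of concurrent processes
    verbose=True,
    skip_first_and_last=True,  # SA-V val/test protocol skips first/last annotated frames
)
```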
6 | import json 7 | import os 8 | from typing import Dict, List, Optional, Tuple 9 | 10 | import cv2 11 | import matplotlib.pyplot as plt 12 | import numpy as np 13 | import pycocotools.mask as mask_util 14 | 15 | 16 | def decode_video(video_path: str) -> List[np.ndarray]: 17 | """ 18 | Decode the video and return the RGB frames 19 | """ 20 | video = cv2.VideoCapture(video_path) 21 | video_frames = [] 22 | while video.isOpened(): 23 | ret, frame = video.read() 24 | if ret: 25 | frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 26 | video_frames.append(frame) 27 | else: 28 | break 29 | return video_frames 30 | 31 | 32 | def show_anns(masks, colors: List, borders=True) -> None: 33 | """ 34 | show the annotations 35 | """ 36 | # return if no masks 37 | if len(masks) == 0: 38 | return 39 | 40 | # sort masks by size 41 | sorted_annot_and_color = sorted( 42 | zip(masks, colors), key=(lambda x: x[0].sum()), reverse=True 43 | ) 44 | H, W = sorted_annot_and_color[0][0].shape[0], sorted_annot_and_color[0][0].shape[1] 45 | 46 | canvas = np.ones((H, W, 4)) 47 | canvas[:, :, 3] = 0 # set the alpha channel 48 | contour_thickness = max(1, int(min(5, 0.01 * min(H, W)))) 49 | for mask, color in sorted_annot_and_color: 50 | canvas[mask] = np.concatenate([color, [0.55]]) 51 | if borders: 52 | contours, _ = cv2.findContours( 53 | np.array(mask, dtype=np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE 54 | ) 55 | cv2.drawContours( 56 | canvas, contours, -1, (0.05, 0.05, 0.05, 1), thickness=contour_thickness 57 | ) 58 | 59 | ax = plt.gca() 60 | ax.imshow(canvas) 61 | 62 | 63 | class SAVDataset: 64 | """ 65 | SAVDataset is a class to load the SAV dataset and visualize the annotations. 66 | """ 67 | 68 | def __init__(self, sav_dir, annot_sample_rate=4): 69 | """ 70 | Args: 71 | sav_dir: the directory of the SAV dataset 72 | annot_sample_rate: the sampling rate of the annotations. 73 | The annotations are aligned with the videos at 6 fps. 74 | """ 75 | self.sav_dir = sav_dir 76 | self.annot_sample_rate = annot_sample_rate 77 | self.manual_mask_colors = np.random.random((256, 3)) 78 | self.auto_mask_colors = np.random.random((256, 3)) 79 | 80 | def read_frames(self, mp4_path: str) -> None: 81 | """ 82 | Read the frames and downsample them to align with the annotations. 83 | """ 84 | if not os.path.exists(mp4_path): 85 | print(f"{mp4_path} doesn't exist.") 86 | return None 87 | else: 88 | # decode the video 89 | frames = decode_video(mp4_path) 90 | print(f"There are {len(frames)} frames decoded from {mp4_path} (24fps).") 91 | 92 | # downsample the frames to align with the annotations 93 | frames = frames[:: self.annot_sample_rate] 94 | print( 95 | f"Videos are annotated every {self.annot_sample_rate} frames. " 96 | "To align with the annotations, " 97 | f"downsample the video to {len(frames)} frames." 98 | ) 99 | return frames 100 | 101 | def get_frames_and_annotations( 102 | self, video_id: str 103 | ) -> Tuple[List | None, Dict | None, Dict | None]: 104 | """ 105 | Get the frames and annotations for video. 106 | """ 107 | # load the video 108 | mp4_path = os.path.join(self.sav_dir, video_id + ".mp4") 109 | frames = self.read_frames(mp4_path) 110 | if frames is None: 111 | return None, None, None 112 | 113 | # load the manual annotations 114 | manual_annot_path = os.path.join(self.sav_dir, video_id + "_manual.json") 115 | if not os.path.exists(manual_annot_path): 116 | print(f"{manual_annot_path} doesn't exist. 
Something might be wrong.") 117 | manual_annot = None 118 | else: 119 | manual_annot = json.load(open(manual_annot_path)) 120 | 121 | # load the manual annotations 122 | auto_annot_path = os.path.join(self.sav_dir, video_id + "_auto.json") 123 | if not os.path.exists(auto_annot_path): 124 | print(f"{auto_annot_path} doesn't exist.") 125 | auto_annot = None 126 | else: 127 | auto_annot = json.load(open(auto_annot_path)) 128 | 129 | return frames, manual_annot, auto_annot 130 | 131 | def visualize_annotation( 132 | self, 133 | frames: List[np.ndarray], 134 | auto_annot: Optional[Dict], 135 | manual_annot: Optional[Dict], 136 | annotated_frame_id: int, 137 | show_auto=True, 138 | show_manual=True, 139 | ) -> None: 140 | """ 141 | Visualize the annotations on the annotated_frame_id. 142 | If show_manual is True, show the manual annotations. 143 | If show_auto is True, show the auto annotations. 144 | By default, show both auto and manual annotations. 145 | """ 146 | 147 | if annotated_frame_id >= len(frames): 148 | print("invalid annotated_frame_id") 149 | return 150 | 151 | rles = [] 152 | colors = [] 153 | if show_manual and manual_annot is not None: 154 | rles.extend(manual_annot["masklet"][annotated_frame_id]) 155 | colors.extend( 156 | self.manual_mask_colors[ 157 | : len(manual_annot["masklet"][annotated_frame_id]) 158 | ] 159 | ) 160 | if show_auto and auto_annot is not None: 161 | rles.extend(auto_annot["masklet"][annotated_frame_id]) 162 | colors.extend( 163 | self.auto_mask_colors[: len(auto_annot["masklet"][annotated_frame_id])] 164 | ) 165 | 166 | plt.imshow(frames[annotated_frame_id]) 167 | 168 | if len(rles) > 0: 169 | masks = [mask_util.decode(rle) > 0 for rle in rles] 170 | show_anns(masks, colors) 171 | else: 172 | print("No annotation will be shown") 173 | 174 | plt.axis("off") 175 | plt.show() 176 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from setuptools import find_packages, setup 8 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 9 | 10 | # Package metadata 11 | NAME = "SAM 2" 12 | VERSION = "1.0" 13 | DESCRIPTION = "SAM 2: Segment Anything in Images and Videos" 14 | URL = "https://github.com/facebookresearch/segment-anything-2" 15 | AUTHOR = "Meta AI" 16 | AUTHOR_EMAIL = "segment-anything@meta.com" 17 | LICENSE = "Apache 2.0" 18 | 19 | # Read the contents of README file 20 | with open("README.md", "r") as f: 21 | LONG_DESCRIPTION = f.read() 22 | 23 | # Required dependencies 24 | REQUIRED_PACKAGES = [ 25 | "torch>=2.3.1", 26 | "torchvision>=0.18.1", 27 | "numpy>=1.24.4", 28 | "tqdm>=4.66.1", 29 | "hydra-core>=1.3.2", 30 | "iopath>=0.1.10", 31 | "pillow>=9.4.0", 32 | ] 33 | 34 | EXTRA_PACKAGES = { 35 | "demo": ["matplotlib>=3.9.1", "jupyter>=1.0.0", "opencv-python>=4.7.0"], 36 | "dev": ["black==24.2.0", "usort==1.0.2", "ufmt==2.0.0b2"], 37 | } 38 | 39 | 40 | def get_extensions(): 41 | srcs = ["sam2/csrc/connected_components.cu"] 42 | compile_args = { 43 | "cxx": [], 44 | "nvcc": [ 45 | "-DCUDA_HAS_FP16=1", 46 | "-D__CUDA_NO_HALF_OPERATORS__", 47 | "-D__CUDA_NO_HALF_CONVERSIONS__", 48 | "-D__CUDA_NO_HALF2_OPERATORS__", 49 | ], 50 | } 51 | ext_modules = [CUDAExtension("sam2._C", srcs, extra_compile_args=compile_args)] 52 | return ext_modules 53 | 54 | 55 | # Setup configuration 56 | setup( 57 | name=NAME, 58 | version=VERSION, 59 | description=DESCRIPTION, 60 | long_description=LONG_DESCRIPTION, 61 | long_description_content_type="text/markdown", 62 | url=URL, 63 | author=AUTHOR, 64 | author_email=AUTHOR_EMAIL, 65 | license=LICENSE, 66 | packages=find_packages(exclude="notebooks"), 67 | install_requires=REQUIRED_PACKAGES, 68 | extras_require=EXTRA_PACKAGES, 69 | python_requires=">=3.10.0", 70 | ext_modules=get_extensions(), 71 | cmdclass={"build_ext": BuildExtension.with_options(no_python_abi_suffix=True)}, 72 | ) 73 | -------------------------------------------------------------------------------- /task_vector/convert.py: -------------------------------------------------------------------------------- 1 | import json 2 | import torch 3 | from safetensors.torch import load_file 4 | import argparse 5 | 6 | # Create ArgumentParser object 7 | parser = argparse.ArgumentParser(description='Process some integers.') 8 | 9 | # Add arguments 10 | parser.add_argument('--tensor_path', type=str, default='', help='') 11 | parser.add_argument('--unet_path', type=str, default='', help='') 12 | parser.add_argument('--text_encoder_path', type=str, default='', help='') 13 | parser.add_argument('--vae_path', type=str, default='', help='') 14 | 15 | parser.add_argument('--source_path', type=str, default='', help='') 16 | parser.add_argument('--target_path', type=str, default='', help='') 17 | 18 | parser.add_argument('--target_prefix', type=str, default='', help='') 19 | 20 | # Parse the arguments 21 | args = parser.parse_args() 22 | 23 | unet_source_path = args.source_path + "/source.txt" 24 | unet_target_path = args.target_path + "/target.txt" 25 | 26 | text_encoder_source_path = args.source_path + "/text_source.txt" 27 | text_encoder_target_path = args.target_path + "/text_target.txt" 28 | 29 | vae_source_path = args.source_path + "/vae_source.txt" 30 | vae_target_path = args.target_path + "/vae_target.txt" 31 | 32 | 33 | tensor_dict = load_file(args.tensor_path) 34 | 35 | state_dict = torch.load(args.unet_path, map_location='cpu') 36 | text_state_dict = torch.load(args.text_encoder_path, map_location='cpu') 37 | 
vae_state_dict = torch.load(args.vae_path, map_location='cpu') 38 | state_dict = {**vae_state_dict, **state_dict, **text_state_dict} 39 | 40 | # Convert diffusion model 41 | f = open(unet_source_path,'r') 42 | source = f.readlines() 43 | f.close() 44 | 45 | f = open(unet_target_path,'r') 46 | target = f.readlines() 47 | f.close() 48 | 49 | state_dict2 = {} 50 | for source_key, target_key in zip(source, target): 51 | source_key = source_key.strip() 52 | target_key = target_key.strip() 53 | 54 | if tensor_dict[source_key].shape == state_dict[target_key].shape and source_key != 'model.diffusion_model.input_blocks.0.0.weight': 55 | state_dict2[target_key] = tensor_dict[source_key] - state_dict[target_key] 56 | elif source_key == 'model.diffusion_model.input_blocks.0.0.weight': 57 | delta_weight = torch.cat([tensor_dict[source_key] - state_dict[target_key], torch.zeros([320,5,3,3])], dim=1) 58 | state_dict2[target_key] = delta_weight 59 | 60 | torch.save(state_dict2, f'{args.target_prefix}_unet_delta.pth') 61 | 62 | # Convert text encoder model 63 | f = open(text_encoder_source_path,'r') 64 | source = f.readlines() 65 | f.close() 66 | 67 | f = open(text_encoder_target_path,'r') 68 | target = f.readlines() 69 | f.close() 70 | 71 | state_dict2 = {} 72 | for source_key, target_key in zip(source, target): 73 | source_key = source_key.strip() 74 | target_key = target_key.strip() 75 | 76 | if tensor_dict[source_key].shape == state_dict[target_key].shape: 77 | state_dict2[target_key] = tensor_dict[source_key] - state_dict[target_key] 78 | else: 79 | print(source_key, tensor_dict[source_key].shape, state_dict[target_key].shape) 80 | 81 | torch.save(state_dict2, f'{args.target_prefix}_text_encoder_delta.pth') 82 | 83 | # Convert vae model 84 | f = open(vae_source_path,'r') 85 | source = f.readlines() 86 | f.close() 87 | 88 | f = open(vae_target_path,'r') 89 | target = f.readlines() 90 | f.close() 91 | 92 | state_dict2 = {} 93 | for source_key, target_key in zip(source, target): 94 | source_key = source_key.strip() 95 | target_key = target_key.strip() 96 | 97 | if tensor_dict[source_key].shape == state_dict[target_key].shape: 98 | state_dict2[target_key] = tensor_dict[source_key] - state_dict[target_key] 99 | else: 100 | state_dict2[target_key] = tensor_dict[source_key].squeeze() - state_dict[target_key] 101 | print(source_key, tensor_dict[source_key].shape, state_dict[target_key].shape) 102 | 103 | torch.save(state_dict2, f'{args.target_prefix}_vae_delta.pth') 104 | -------------------------------------------------------------------------------- /task_vector/convert_lora.py: -------------------------------------------------------------------------------- 1 | import re 2 | import json 3 | import torch 4 | from safetensors.torch import load_file 5 | import argparse 6 | 7 | # Create ArgumentParser object 8 | parser = argparse.ArgumentParser(description='Process some integers.') 9 | 10 | # Add arguments 11 | parser.add_argument('--tensor_path', type=str, default='', help='') 12 | parser.add_argument('--unet_path', type=str, default='', help='') 13 | parser.add_argument('--text_encoder_path', type=str, default='', help='') 14 | parser.add_argument('--vae_path', type=str, default='', help='') 15 | 16 | parser.add_argument('--regulation_path', type=str, default='', help='') 17 | parser.add_argument('--target_prefix', type=str, default='', help='') 18 | 19 | # Parse the arguments 20 | args = parser.parse_args() 21 | 22 | tensor_dict = load_file(args.tensor_path) 23 | 24 | state_dict = torch.load(args.unet_path, 
map_location='cpu') 25 | text_state_dict = torch.load(args.text_encoder_path, map_location='cpu') 26 | vae_state_dict = torch.load(args.vae_path, map_location='cpu') 27 | state_dict = {**vae_state_dict, **state_dict, **text_state_dict} 28 | 29 | json_list = json.load(open(f"{args.regulation_path}","r")) 30 | for item in json_list: 31 | key = item["start"] 32 | value = item["end"] 33 | 34 | state_dict2 = {} 35 | for key in tensor_dict: 36 | print(key) 37 | org_key = key 38 | for it in json_list: 39 | reg = it["regression"] 40 | key2 = it["start"] 41 | value = it["end"] 42 | key = key.replace(key2, value) 43 | if reg: 44 | key = re.sub(key2, value, key) 45 | 46 | key = key.replace(".lora_up","") 47 | key = key.replace(".lora_down", "") 48 | 49 | if 'lora_up' in org_key: 50 | alpha_ = tensor_dict[org_key.split('.lora_up')[0]+'.alpha'] 51 | lora_down = org_key.split('.lora_up')[0]+'.lora_down.weight' 52 | rank = tensor_dict[org_key].shape[1] 53 | if len(tensor_dict[org_key].shape) == 4: 54 | lora1 = tensor_dict[org_key].float() 55 | w1,_,w3,w4 = lora1.shape 56 | lora1 = lora1.permute(0,2,3,1).contiguous() 57 | lora1 = lora1.view(-1, rank) 58 | lora2 = tensor_dict[lora_down].float() 59 | _,w2,w3,w4 = lora2.shape 60 | lora2 = lora2.view(rank, -1) 61 | weight = alpha_*lora1@lora2/rank 62 | weight = weight.view(w1,w2,w3,w4) 63 | else: 64 | lora1 = tensor_dict[org_key].float() 65 | lora2 = tensor_dict[lora_down].float() 66 | weight = lora1@lora2 67 | state_dict2[key] = weight 68 | 69 | if key.endswith('.alpha'): 70 | continue 71 | if key in state_dict: 72 | continue 73 | else: 74 | print('### The key doesn\'t match!') 75 | exit(0) 76 | 77 | text_state_dict = {} 78 | vae_state_dict = {} 79 | unet_state_dict = {} 80 | for key in state_dict2: 81 | if 'text' in key: 82 | text_state_dict[key] = state_dict2[key] 83 | elif key.startswith('encoder.') or key.startswith('decoder.'): 84 | vae_state_dict[key] = state_dict2[key] 85 | else: 86 | unet_state_dict[key] = state_dict2[key] 87 | 88 | if len(text_state_dict) > 0: 89 | torch.save(text_state_dict, f'./text_{args.target_prefix}_delta.pth') 90 | if len(vae_state_dict) > 0: 91 | torch.save(vae_state_dict, f'./vae_{args.target_prefix}_delta.pth') 92 | if len(unet_state_dict): 93 | torch.save(unet_state_dict, f'./unet_{args.target_prefix}_delta.pth') 94 | 95 | -------------------------------------------------------------------------------- /task_vector/lora.json: -------------------------------------------------------------------------------- 1 | [{"regression": false, "start": "lora_unet_", "end": ""}, {"regression": false, "start": "up_blocks_", "end": "up_blocks."}, {"regression": false, "start": "down_blocks_", "end": "down_blocks."}, {"regression": false, "start": "_attentions_", "end": ".attentions."}, {"regression": false, "start": "_transformer_blocks_", "end": ".transformer_blocks."}, {"regression": false, "start": "_attn", "end": ".attn"}, {"regression": false, "start": "_ff", "end": ".ff"}, {"regression": false, "start": "_to", "end": ".to"}, {"regression": false, "start": "ff_", "end": "ff."}, {"regression": false, "start": "net_", "end": "net."}, {"regression": false, "start": "_proj", "end": ".proj"}, {"regression": false, "start": "to_out_", "end": "to_out."}, {"regression": false, "start": "lora_te_", "end": ""}, {"regression": false, "start": "text_model_", "end": "text_model."}, {"regression": false, "start": "text_model.encoder_", "end": "text_model.encoder."}, {"regression": false, "start": "encoder.layers_", "end": "encoder.layers."}, 
{"regression": false, "start": "_self.attn_", "end": ".self_attn."}, {"regression": false, "start": "attn.out.proj", "end": "attn.out_proj"}, {"regression": false, "start": ".q.proj.", "end": ".q_proj."}, {"regression": false, "start": ".v.proj.", "end": ".v_proj."}, {"regression": false, "start": "_mlp_", "end": ".mlp."}, {"regression": false, "start": ".k.proj.", "end": ".k_proj."}, {"regression": false, "start": "_resnets_", "end": ".resnets."}, {"regression": false, "start": "_conv", "end": ".conv"}, {"regression": false, "start": "_time_emb.proj", "end": ".time_emb_proj"}, {"regression": false, "start": "_upsamplers_", "end": ".upsamplers."}, {"regression": false, "start": "_downsamplers_", "end": ".downsamplers."}, {"regression": true, "start": "(\\d+)\\.(\\d+)\\.transformer_blocks", "end": "\\1.attentions.\\2.transformer_blocks"}, {"regression": true, "start": "(\\d+)\\.(\\d+)\\.", "end": "\\1.resnets.\\2."}] 2 | -------------------------------------------------------------------------------- /task_vector/resources/text_target.txt: -------------------------------------------------------------------------------- 1 | text_model.embeddings.position_embedding.weight 2 | text_model.embeddings.position_ids 3 | text_model.embeddings.token_embedding.weight 4 | text_model.encoder.layers.0.self_attn.k_proj.bias 5 | text_model.encoder.layers.0.self_attn.k_proj.weight 6 | text_model.encoder.layers.0.self_attn.v_proj.bias 7 | text_model.encoder.layers.0.self_attn.v_proj.weight 8 | text_model.encoder.layers.0.self_attn.q_proj.bias 9 | text_model.encoder.layers.0.self_attn.q_proj.weight 10 | text_model.encoder.layers.0.self_attn.out_proj.bias 11 | text_model.encoder.layers.0.self_attn.out_proj.weight 12 | text_model.encoder.layers.0.layer_norm1.bias 13 | text_model.encoder.layers.0.layer_norm1.weight 14 | text_model.encoder.layers.0.layer_norm2.bias 15 | text_model.encoder.layers.0.layer_norm2.weight 16 | text_model.encoder.layers.0.mlp.fc1.bias 17 | text_model.encoder.layers.0.mlp.fc1.weight 18 | text_model.encoder.layers.0.mlp.fc2.bias 19 | text_model.encoder.layers.0.mlp.fc2.weight 20 | text_model.encoder.layers.1.self_attn.k_proj.bias 21 | text_model.encoder.layers.1.self_attn.k_proj.weight 22 | text_model.encoder.layers.1.self_attn.v_proj.bias 23 | text_model.encoder.layers.1.self_attn.v_proj.weight 24 | text_model.encoder.layers.1.self_attn.q_proj.bias 25 | text_model.encoder.layers.1.self_attn.q_proj.weight 26 | text_model.encoder.layers.1.self_attn.out_proj.bias 27 | text_model.encoder.layers.1.self_attn.out_proj.weight 28 | text_model.encoder.layers.1.layer_norm1.bias 29 | text_model.encoder.layers.1.layer_norm1.weight 30 | text_model.encoder.layers.1.layer_norm2.bias 31 | text_model.encoder.layers.1.layer_norm2.weight 32 | text_model.encoder.layers.1.mlp.fc1.bias 33 | text_model.encoder.layers.1.mlp.fc1.weight 34 | text_model.encoder.layers.1.mlp.fc2.bias 35 | text_model.encoder.layers.1.mlp.fc2.weight 36 | text_model.encoder.layers.2.self_attn.k_proj.bias 37 | text_model.encoder.layers.2.self_attn.k_proj.weight 38 | text_model.encoder.layers.2.self_attn.v_proj.bias 39 | text_model.encoder.layers.2.self_attn.v_proj.weight 40 | text_model.encoder.layers.2.self_attn.q_proj.bias 41 | text_model.encoder.layers.2.self_attn.q_proj.weight 42 | text_model.encoder.layers.2.self_attn.out_proj.bias 43 | text_model.encoder.layers.2.self_attn.out_proj.weight 44 | text_model.encoder.layers.2.layer_norm1.bias 45 | text_model.encoder.layers.2.layer_norm1.weight 46 | 
text_model.encoder.layers.2.layer_norm2.bias 47 | text_model.encoder.layers.2.layer_norm2.weight 48 | text_model.encoder.layers.2.mlp.fc1.bias 49 | text_model.encoder.layers.2.mlp.fc1.weight 50 | text_model.encoder.layers.2.mlp.fc2.bias 51 | text_model.encoder.layers.2.mlp.fc2.weight 52 | text_model.encoder.layers.3.self_attn.k_proj.bias 53 | text_model.encoder.layers.3.self_attn.k_proj.weight 54 | text_model.encoder.layers.3.self_attn.v_proj.bias 55 | text_model.encoder.layers.3.self_attn.v_proj.weight 56 | text_model.encoder.layers.3.self_attn.q_proj.bias 57 | text_model.encoder.layers.3.self_attn.q_proj.weight 58 | text_model.encoder.layers.3.self_attn.out_proj.bias 59 | text_model.encoder.layers.3.self_attn.out_proj.weight 60 | text_model.encoder.layers.3.layer_norm1.bias 61 | text_model.encoder.layers.3.layer_norm1.weight 62 | text_model.encoder.layers.3.layer_norm2.bias 63 | text_model.encoder.layers.3.layer_norm2.weight 64 | text_model.encoder.layers.3.mlp.fc1.bias 65 | text_model.encoder.layers.3.mlp.fc1.weight 66 | text_model.encoder.layers.3.mlp.fc2.bias 67 | text_model.encoder.layers.3.mlp.fc2.weight 68 | text_model.encoder.layers.4.self_attn.k_proj.bias 69 | text_model.encoder.layers.4.self_attn.k_proj.weight 70 | text_model.encoder.layers.4.self_attn.v_proj.bias 71 | text_model.encoder.layers.4.self_attn.v_proj.weight 72 | text_model.encoder.layers.4.self_attn.q_proj.bias 73 | text_model.encoder.layers.4.self_attn.q_proj.weight 74 | text_model.encoder.layers.4.self_attn.out_proj.bias 75 | text_model.encoder.layers.4.self_attn.out_proj.weight 76 | text_model.encoder.layers.4.layer_norm1.bias 77 | text_model.encoder.layers.4.layer_norm1.weight 78 | text_model.encoder.layers.4.layer_norm2.bias 79 | text_model.encoder.layers.4.layer_norm2.weight 80 | text_model.encoder.layers.4.mlp.fc1.bias 81 | text_model.encoder.layers.4.mlp.fc1.weight 82 | text_model.encoder.layers.4.mlp.fc2.bias 83 | text_model.encoder.layers.4.mlp.fc2.weight 84 | text_model.encoder.layers.5.self_attn.k_proj.bias 85 | text_model.encoder.layers.5.self_attn.k_proj.weight 86 | text_model.encoder.layers.5.self_attn.v_proj.bias 87 | text_model.encoder.layers.5.self_attn.v_proj.weight 88 | text_model.encoder.layers.5.self_attn.q_proj.bias 89 | text_model.encoder.layers.5.self_attn.q_proj.weight 90 | text_model.encoder.layers.5.self_attn.out_proj.bias 91 | text_model.encoder.layers.5.self_attn.out_proj.weight 92 | text_model.encoder.layers.5.layer_norm1.bias 93 | text_model.encoder.layers.5.layer_norm1.weight 94 | text_model.encoder.layers.5.layer_norm2.bias 95 | text_model.encoder.layers.5.layer_norm2.weight 96 | text_model.encoder.layers.5.mlp.fc1.bias 97 | text_model.encoder.layers.5.mlp.fc1.weight 98 | text_model.encoder.layers.5.mlp.fc2.bias 99 | text_model.encoder.layers.5.mlp.fc2.weight 100 | text_model.encoder.layers.6.self_attn.k_proj.bias 101 | text_model.encoder.layers.6.self_attn.k_proj.weight 102 | text_model.encoder.layers.6.self_attn.v_proj.bias 103 | text_model.encoder.layers.6.self_attn.v_proj.weight 104 | text_model.encoder.layers.6.self_attn.q_proj.bias 105 | text_model.encoder.layers.6.self_attn.q_proj.weight 106 | text_model.encoder.layers.6.self_attn.out_proj.bias 107 | text_model.encoder.layers.6.self_attn.out_proj.weight 108 | text_model.encoder.layers.6.layer_norm1.bias 109 | text_model.encoder.layers.6.layer_norm1.weight 110 | text_model.encoder.layers.6.layer_norm2.bias 111 | text_model.encoder.layers.6.layer_norm2.weight 112 | text_model.encoder.layers.6.mlp.fc1.bias 113 | 
text_model.encoder.layers.6.mlp.fc1.weight 114 | text_model.encoder.layers.6.mlp.fc2.bias 115 | text_model.encoder.layers.6.mlp.fc2.weight 116 | text_model.encoder.layers.7.self_attn.k_proj.bias 117 | text_model.encoder.layers.7.self_attn.k_proj.weight 118 | text_model.encoder.layers.7.self_attn.v_proj.bias 119 | text_model.encoder.layers.7.self_attn.v_proj.weight 120 | text_model.encoder.layers.7.self_attn.q_proj.bias 121 | text_model.encoder.layers.7.self_attn.q_proj.weight 122 | text_model.encoder.layers.7.self_attn.out_proj.bias 123 | text_model.encoder.layers.7.self_attn.out_proj.weight 124 | text_model.encoder.layers.7.layer_norm1.bias 125 | text_model.encoder.layers.7.layer_norm1.weight 126 | text_model.encoder.layers.7.layer_norm2.bias 127 | text_model.encoder.layers.7.layer_norm2.weight 128 | text_model.encoder.layers.7.mlp.fc1.bias 129 | text_model.encoder.layers.7.mlp.fc1.weight 130 | text_model.encoder.layers.7.mlp.fc2.bias 131 | text_model.encoder.layers.7.mlp.fc2.weight 132 | text_model.encoder.layers.8.self_attn.k_proj.bias 133 | text_model.encoder.layers.8.self_attn.k_proj.weight 134 | text_model.encoder.layers.8.self_attn.v_proj.bias 135 | text_model.encoder.layers.8.self_attn.v_proj.weight 136 | text_model.encoder.layers.8.self_attn.q_proj.bias 137 | text_model.encoder.layers.8.self_attn.q_proj.weight 138 | text_model.encoder.layers.8.self_attn.out_proj.bias 139 | text_model.encoder.layers.8.self_attn.out_proj.weight 140 | text_model.encoder.layers.8.layer_norm1.bias 141 | text_model.encoder.layers.8.layer_norm1.weight 142 | text_model.encoder.layers.8.layer_norm2.bias 143 | text_model.encoder.layers.8.layer_norm2.weight 144 | text_model.encoder.layers.8.mlp.fc1.bias 145 | text_model.encoder.layers.8.mlp.fc1.weight 146 | text_model.encoder.layers.8.mlp.fc2.bias 147 | text_model.encoder.layers.8.mlp.fc2.weight 148 | text_model.encoder.layers.9.self_attn.k_proj.bias 149 | text_model.encoder.layers.9.self_attn.k_proj.weight 150 | text_model.encoder.layers.9.self_attn.v_proj.bias 151 | text_model.encoder.layers.9.self_attn.v_proj.weight 152 | text_model.encoder.layers.9.self_attn.q_proj.bias 153 | text_model.encoder.layers.9.self_attn.q_proj.weight 154 | text_model.encoder.layers.9.self_attn.out_proj.bias 155 | text_model.encoder.layers.9.self_attn.out_proj.weight 156 | text_model.encoder.layers.9.layer_norm1.bias 157 | text_model.encoder.layers.9.layer_norm1.weight 158 | text_model.encoder.layers.9.layer_norm2.bias 159 | text_model.encoder.layers.9.layer_norm2.weight 160 | text_model.encoder.layers.9.mlp.fc1.bias 161 | text_model.encoder.layers.9.mlp.fc1.weight 162 | text_model.encoder.layers.9.mlp.fc2.bias 163 | text_model.encoder.layers.9.mlp.fc2.weight 164 | text_model.encoder.layers.10.self_attn.k_proj.bias 165 | text_model.encoder.layers.10.self_attn.k_proj.weight 166 | text_model.encoder.layers.10.self_attn.v_proj.bias 167 | text_model.encoder.layers.10.self_attn.v_proj.weight 168 | text_model.encoder.layers.10.self_attn.q_proj.bias 169 | text_model.encoder.layers.10.self_attn.q_proj.weight 170 | text_model.encoder.layers.10.self_attn.out_proj.bias 171 | text_model.encoder.layers.10.self_attn.out_proj.weight 172 | text_model.encoder.layers.10.layer_norm1.bias 173 | text_model.encoder.layers.10.layer_norm1.weight 174 | text_model.encoder.layers.10.layer_norm2.bias 175 | text_model.encoder.layers.10.layer_norm2.weight 176 | text_model.encoder.layers.10.mlp.fc1.bias 177 | text_model.encoder.layers.10.mlp.fc1.weight 178 | text_model.encoder.layers.10.mlp.fc2.bias 179 | 
text_model.encoder.layers.10.mlp.fc2.weight 180 | text_model.encoder.layers.11.self_attn.k_proj.bias 181 | text_model.encoder.layers.11.self_attn.k_proj.weight 182 | text_model.encoder.layers.11.self_attn.v_proj.bias 183 | text_model.encoder.layers.11.self_attn.v_proj.weight 184 | text_model.encoder.layers.11.self_attn.q_proj.bias 185 | text_model.encoder.layers.11.self_attn.q_proj.weight 186 | text_model.encoder.layers.11.self_attn.out_proj.bias 187 | text_model.encoder.layers.11.self_attn.out_proj.weight 188 | text_model.encoder.layers.11.layer_norm1.bias 189 | text_model.encoder.layers.11.layer_norm1.weight 190 | text_model.encoder.layers.11.layer_norm2.bias 191 | text_model.encoder.layers.11.layer_norm2.weight 192 | text_model.encoder.layers.11.mlp.fc1.bias 193 | text_model.encoder.layers.11.mlp.fc1.weight 194 | text_model.encoder.layers.11.mlp.fc2.bias 195 | text_model.encoder.layers.11.mlp.fc2.weight 196 | text_model.final_layer_norm.bias 197 | text_model.final_layer_norm.weight 198 | -------------------------------------------------------------------------------- /task_vector/resources/vae_target.txt: -------------------------------------------------------------------------------- 1 | encoder.conv_in.bias 2 | encoder.conv_in.weight 3 | encoder.down_blocks.0.resnets.0.norm1.bias 4 | encoder.down_blocks.0.resnets.0.norm1.weight 5 | encoder.down_blocks.0.resnets.0.norm2.bias 6 | encoder.down_blocks.0.resnets.0.norm2.weight 7 | encoder.down_blocks.0.resnets.1.norm1.bias 8 | encoder.down_blocks.0.resnets.1.norm1.weight 9 | encoder.down_blocks.0.resnets.1.norm2.bias 10 | encoder.down_blocks.0.resnets.1.norm2.weight 11 | encoder.down_blocks.1.resnets.0.norm1.bias 12 | encoder.down_blocks.1.resnets.0.norm1.weight 13 | encoder.down_blocks.1.resnets.0.norm2.bias 14 | encoder.down_blocks.1.resnets.0.norm2.weight 15 | encoder.down_blocks.1.resnets.1.norm1.bias 16 | encoder.down_blocks.1.resnets.1.norm1.weight 17 | encoder.down_blocks.1.resnets.1.norm2.bias 18 | encoder.down_blocks.1.resnets.1.norm2.weight 19 | encoder.down_blocks.2.resnets.0.norm1.bias 20 | encoder.down_blocks.2.resnets.0.norm1.weight 21 | encoder.down_blocks.2.resnets.0.norm2.bias 22 | encoder.down_blocks.2.resnets.0.norm2.weight 23 | encoder.down_blocks.2.resnets.1.norm1.bias 24 | encoder.down_blocks.2.resnets.1.norm1.weight 25 | encoder.down_blocks.2.resnets.1.norm2.bias 26 | encoder.down_blocks.2.resnets.1.norm2.weight 27 | encoder.down_blocks.3.resnets.0.norm1.bias 28 | encoder.down_blocks.3.resnets.0.norm1.weight 29 | encoder.down_blocks.3.resnets.0.norm2.bias 30 | encoder.down_blocks.3.resnets.0.norm2.weight 31 | encoder.down_blocks.3.resnets.1.norm1.bias 32 | encoder.down_blocks.3.resnets.1.norm1.weight 33 | encoder.down_blocks.3.resnets.1.norm2.bias 34 | encoder.down_blocks.3.resnets.1.norm2.weight 35 | encoder.down_blocks.0.downsamplers.0.conv.bias 36 | encoder.down_blocks.0.downsamplers.0.conv.weight 37 | encoder.down_blocks.1.downsamplers.0.conv.bias 38 | encoder.down_blocks.1.downsamplers.0.conv.weight 39 | encoder.down_blocks.1.resnets.0.conv_shortcut.bias 40 | encoder.down_blocks.1.resnets.0.conv_shortcut.weight 41 | encoder.down_blocks.2.resnets.0.conv_shortcut.bias 42 | encoder.down_blocks.2.resnets.0.conv_shortcut.weight 43 | encoder.down_blocks.2.downsamplers.0.conv.bias 44 | encoder.down_blocks.2.downsamplers.0.conv.weight 45 | encoder.down_blocks.0.resnets.0.conv1.bias 46 | encoder.down_blocks.0.resnets.0.conv1.weight 47 | encoder.down_blocks.0.resnets.0.conv2.bias 48 | 
encoder.down_blocks.0.resnets.0.conv2.weight 49 | encoder.down_blocks.0.resnets.1.conv1.bias 50 | encoder.down_blocks.0.resnets.1.conv1.weight 51 | encoder.down_blocks.0.resnets.1.conv2.bias 52 | encoder.down_blocks.0.resnets.1.conv2.weight 53 | encoder.down_blocks.1.resnets.0.conv1.bias 54 | encoder.down_blocks.1.resnets.0.conv1.weight 55 | encoder.down_blocks.1.resnets.0.conv2.bias 56 | encoder.down_blocks.1.resnets.0.conv2.weight 57 | encoder.down_blocks.1.resnets.1.conv1.bias 58 | encoder.down_blocks.1.resnets.1.conv1.weight 59 | encoder.down_blocks.1.resnets.1.conv2.bias 60 | encoder.down_blocks.1.resnets.1.conv2.weight 61 | encoder.down_blocks.2.resnets.0.conv1.bias 62 | encoder.down_blocks.2.resnets.0.conv1.weight 63 | encoder.down_blocks.2.resnets.0.conv2.bias 64 | encoder.down_blocks.2.resnets.0.conv2.weight 65 | encoder.down_blocks.2.resnets.1.conv1.bias 66 | encoder.down_blocks.2.resnets.1.conv1.weight 67 | encoder.down_blocks.2.resnets.1.conv2.bias 68 | encoder.down_blocks.2.resnets.1.conv2.weight 69 | encoder.down_blocks.3.resnets.0.conv1.bias 70 | encoder.down_blocks.3.resnets.0.conv1.weight 71 | encoder.down_blocks.3.resnets.0.conv2.bias 72 | encoder.down_blocks.3.resnets.0.conv2.weight 73 | encoder.down_blocks.3.resnets.1.conv1.bias 74 | encoder.down_blocks.3.resnets.1.conv1.weight 75 | encoder.down_blocks.3.resnets.1.conv2.bias 76 | encoder.down_blocks.3.resnets.1.conv2.weight 77 | encoder.mid_block.attentions.0.group_norm.bias 78 | encoder.mid_block.attentions.0.group_norm.weight 79 | encoder.mid_block.attentions.0.key.bias 80 | encoder.mid_block.attentions.0.key.weight 81 | encoder.mid_block.attentions.0.query.bias 82 | encoder.mid_block.attentions.0.query.weight 83 | encoder.mid_block.attentions.0.value.bias 84 | encoder.mid_block.attentions.0.value.weight 85 | encoder.mid_block.attentions.0.proj_attn.bias 86 | encoder.mid_block.attentions.0.proj_attn.weight 87 | encoder.mid_block.resnets.0.norm1.bias 88 | encoder.mid_block.resnets.0.norm1.weight 89 | encoder.mid_block.resnets.0.norm2.bias 90 | encoder.mid_block.resnets.0.norm2.weight 91 | encoder.mid_block.resnets.1.norm1.bias 92 | encoder.mid_block.resnets.1.norm1.weight 93 | encoder.mid_block.resnets.1.norm2.bias 94 | encoder.mid_block.resnets.1.norm2.weight 95 | encoder.mid_block.resnets.0.conv1.bias 96 | encoder.mid_block.resnets.0.conv1.weight 97 | encoder.mid_block.resnets.0.conv2.bias 98 | encoder.mid_block.resnets.0.conv2.weight 99 | encoder.mid_block.resnets.1.conv1.bias 100 | encoder.mid_block.resnets.1.conv1.weight 101 | encoder.mid_block.resnets.1.conv2.bias 102 | encoder.mid_block.resnets.1.conv2.weight 103 | encoder.conv_out.bias 104 | encoder.conv_out.weight 105 | encoder.conv_norm_out.bias 106 | encoder.conv_norm_out.weight 107 | decoder.conv_in.bias 108 | decoder.conv_in.weight 109 | decoder.conv_out.bias 110 | decoder.conv_out.weight 111 | decoder.up_blocks.0.resnets.0.norm1.bias 112 | decoder.up_blocks.0.resnets.0.norm1.weight 113 | decoder.up_blocks.0.resnets.0.norm2.bias 114 | decoder.up_blocks.0.resnets.0.norm2.weight 115 | decoder.up_blocks.0.resnets.1.norm1.bias 116 | decoder.up_blocks.0.resnets.1.norm1.weight 117 | decoder.up_blocks.0.resnets.1.norm2.bias 118 | decoder.up_blocks.0.resnets.1.norm2.weight 119 | decoder.up_blocks.0.resnets.2.norm1.bias 120 | decoder.up_blocks.0.resnets.2.norm1.weight 121 | decoder.up_blocks.0.resnets.2.norm2.bias 122 | decoder.up_blocks.0.resnets.2.norm2.weight 123 | decoder.up_blocks.1.resnets.0.norm1.bias 124 | decoder.up_blocks.1.resnets.0.norm1.weight 125 | 
decoder.up_blocks.1.resnets.0.norm2.bias 126 | decoder.up_blocks.1.resnets.0.norm2.weight 127 | decoder.up_blocks.1.resnets.1.norm1.bias 128 | decoder.up_blocks.1.resnets.1.norm1.weight 129 | decoder.up_blocks.1.resnets.1.norm2.bias 130 | decoder.up_blocks.1.resnets.1.norm2.weight 131 | decoder.up_blocks.1.resnets.2.norm1.bias 132 | decoder.up_blocks.1.resnets.2.norm1.weight 133 | decoder.up_blocks.1.resnets.2.norm2.bias 134 | decoder.up_blocks.1.resnets.2.norm2.weight 135 | decoder.up_blocks.2.resnets.0.norm1.bias 136 | decoder.up_blocks.2.resnets.0.norm1.weight 137 | decoder.up_blocks.2.resnets.0.norm2.bias 138 | decoder.up_blocks.2.resnets.0.norm2.weight 139 | decoder.up_blocks.2.resnets.1.norm1.bias 140 | decoder.up_blocks.2.resnets.1.norm1.weight 141 | decoder.up_blocks.2.resnets.1.norm2.bias 142 | decoder.up_blocks.2.resnets.1.norm2.weight 143 | decoder.up_blocks.2.resnets.2.norm1.bias 144 | decoder.up_blocks.2.resnets.2.norm1.weight 145 | decoder.up_blocks.2.resnets.2.norm2.bias 146 | decoder.up_blocks.2.resnets.2.norm2.weight 147 | decoder.up_blocks.3.resnets.0.norm1.bias 148 | decoder.up_blocks.3.resnets.0.norm1.weight 149 | decoder.up_blocks.3.resnets.0.norm2.bias 150 | decoder.up_blocks.3.resnets.0.norm2.weight 151 | decoder.up_blocks.3.resnets.1.norm1.bias 152 | decoder.up_blocks.3.resnets.1.norm1.weight 153 | decoder.up_blocks.3.resnets.1.norm2.bias 154 | decoder.up_blocks.3.resnets.1.norm2.weight 155 | decoder.up_blocks.3.resnets.2.norm1.bias 156 | decoder.up_blocks.3.resnets.2.norm1.weight 157 | decoder.up_blocks.3.resnets.2.norm2.bias 158 | decoder.up_blocks.3.resnets.2.norm2.weight 159 | decoder.up_blocks.0.resnets.0.conv1.bias 160 | decoder.up_blocks.0.resnets.0.conv1.weight 161 | decoder.up_blocks.0.resnets.0.conv2.bias 162 | decoder.up_blocks.0.resnets.0.conv2.weight 163 | decoder.up_blocks.0.resnets.1.conv1.bias 164 | decoder.up_blocks.0.resnets.1.conv1.weight 165 | decoder.up_blocks.0.resnets.1.conv2.bias 166 | decoder.up_blocks.0.resnets.1.conv2.weight 167 | decoder.up_blocks.0.resnets.2.conv1.bias 168 | decoder.up_blocks.0.resnets.2.conv1.weight 169 | decoder.up_blocks.0.resnets.2.conv2.bias 170 | decoder.up_blocks.0.resnets.2.conv2.weight 171 | decoder.up_blocks.1.resnets.0.conv1.bias 172 | decoder.up_blocks.1.resnets.0.conv1.weight 173 | decoder.up_blocks.1.resnets.0.conv2.bias 174 | decoder.up_blocks.1.resnets.0.conv2.weight 175 | decoder.up_blocks.1.resnets.1.conv1.bias 176 | decoder.up_blocks.1.resnets.1.conv1.weight 177 | decoder.up_blocks.1.resnets.1.conv2.bias 178 | decoder.up_blocks.1.resnets.1.conv2.weight 179 | decoder.up_blocks.1.resnets.2.conv1.bias 180 | decoder.up_blocks.1.resnets.2.conv1.weight 181 | decoder.up_blocks.1.resnets.2.conv2.bias 182 | decoder.up_blocks.1.resnets.2.conv2.weight 183 | decoder.up_blocks.2.resnets.0.conv1.bias 184 | decoder.up_blocks.2.resnets.0.conv1.weight 185 | decoder.up_blocks.2.resnets.0.conv2.bias 186 | decoder.up_blocks.2.resnets.0.conv2.weight 187 | decoder.up_blocks.2.resnets.1.conv1.bias 188 | decoder.up_blocks.2.resnets.1.conv1.weight 189 | decoder.up_blocks.2.resnets.1.conv2.bias 190 | decoder.up_blocks.2.resnets.1.conv2.weight 191 | decoder.up_blocks.2.resnets.2.conv1.bias 192 | decoder.up_blocks.2.resnets.2.conv1.weight 193 | decoder.up_blocks.2.resnets.2.conv2.bias 194 | decoder.up_blocks.2.resnets.2.conv2.weight 195 | decoder.up_blocks.3.resnets.0.conv1.bias 196 | decoder.up_blocks.3.resnets.0.conv1.weight 197 | decoder.up_blocks.3.resnets.0.conv2.bias 198 | decoder.up_blocks.3.resnets.0.conv2.weight 199 | 
decoder.up_blocks.3.resnets.1.conv1.bias 200 | decoder.up_blocks.3.resnets.1.conv1.weight 201 | decoder.up_blocks.3.resnets.1.conv2.bias 202 | decoder.up_blocks.3.resnets.1.conv2.weight 203 | decoder.up_blocks.3.resnets.2.conv1.bias 204 | decoder.up_blocks.3.resnets.2.conv1.weight 205 | decoder.up_blocks.3.resnets.2.conv2.bias 206 | decoder.up_blocks.3.resnets.2.conv2.weight 207 | decoder.up_blocks.0.upsamplers.0.conv.bias 208 | decoder.up_blocks.0.upsamplers.0.conv.weight 209 | decoder.up_blocks.1.upsamplers.0.conv.bias 210 | decoder.up_blocks.1.upsamplers.0.conv.weight 211 | decoder.up_blocks.2.upsamplers.0.conv.bias 212 | decoder.up_blocks.2.upsamplers.0.conv.weight 213 | decoder.up_blocks.2.resnets.0.conv_shortcut.bias 214 | decoder.up_blocks.2.resnets.0.conv_shortcut.weight 215 | decoder.up_blocks.3.resnets.0.conv_shortcut.bias 216 | decoder.up_blocks.3.resnets.0.conv_shortcut.weight 217 | decoder.mid_block.attentions.0.query.bias 218 | decoder.mid_block.attentions.0.query.weight 219 | decoder.mid_block.attentions.0.key.bias 220 | decoder.mid_block.attentions.0.key.weight 221 | decoder.mid_block.attentions.0.value.bias 222 | decoder.mid_block.attentions.0.value.weight 223 | decoder.mid_block.attentions.0.group_norm.bias 224 | decoder.mid_block.attentions.0.group_norm.weight 225 | decoder.mid_block.attentions.0.proj_attn.bias 226 | decoder.mid_block.attentions.0.proj_attn.weight 227 | decoder.mid_block.resnets.0.conv1.bias 228 | decoder.mid_block.resnets.0.conv1.weight 229 | decoder.mid_block.resnets.0.conv2.bias 230 | decoder.mid_block.resnets.0.conv2.weight 231 | decoder.mid_block.resnets.1.conv1.bias 232 | decoder.mid_block.resnets.1.conv1.weight 233 | decoder.mid_block.resnets.1.conv2.bias 234 | decoder.mid_block.resnets.1.conv2.weight 235 | decoder.mid_block.resnets.0.norm1.bias 236 | decoder.mid_block.resnets.0.norm1.weight 237 | decoder.mid_block.resnets.0.norm2.bias 238 | decoder.mid_block.resnets.0.norm2.weight 239 | decoder.mid_block.resnets.1.norm1.bias 240 | decoder.mid_block.resnets.1.norm1.weight 241 | decoder.mid_block.resnets.1.norm2.bias 242 | decoder.mid_block.resnets.1.norm2.weight 243 | decoder.conv_norm_out.bias 244 | decoder.conv_norm_out.weight 245 | quant_conv.bias 246 | quant_conv.weight 247 | post_quant_conv.bias 248 | post_quant_conv.weight 249 | -------------------------------------------------------------------------------- /tools/README.md: -------------------------------------------------------------------------------- 1 | ## SAM 2 toolkits 2 | 3 | This directory provides toolkits for additional SAM 2 use cases. 4 | 5 | ### Semi-supervised VOS inference 6 | 7 | The `vos_inference.py` script can be used to generate predictions for semi-supervised video object segmentation (VOS) evaluation on datasets such as [DAVIS](https://davischallenge.org/index.html), [MOSE](https://henghuiding.github.io/MOSE/) or the SA-V dataset. 8 | 9 | After installing SAM 2 and its dependencies, it can be used as follows ([DAVIS 2017 dataset](https://davischallenge.org/davis2017/code.html) as an example). This script saves the prediction PNG files to the `--output_mask_dir`. 
10 | ```bash 11 | python ./tools/vos_inference.py \ 12 | --sam2_cfg sam2_hiera_b+.yaml \ 13 | --sam2_checkpoint ./checkpoints/sam2_hiera_base_plus.pt \ 14 | --base_video_dir /path-to-davis-2017/JPEGImages/480p \ 15 | --input_mask_dir /path-to-davis-2017/Annotations/480p \ 16 | --video_list_file /path-to-davis-2017/ImageSets/2017/val.txt \ 17 | --output_mask_dir ./outputs/davis_2017_pred_pngs 18 | ``` 19 | (replace `/path-to-davis-2017` with the path to DAVIS 2017 dataset) 20 | 21 | To evaluate on the SA-V dataset with per-object PNG files for the object masks, we need to **add the `--per_obj_png_file` flag** as follows (using SA-V val as an example). This script will also save per-object PNG files for the output masks under the `--per_obj_png_file` flag. 22 | ```bash 23 | python ./tools/vos_inference.py \ 24 | --sam2_cfg sam2_hiera_b+.yaml \ 25 | --sam2_checkpoint ./checkpoints/sam2_hiera_base_plus.pt \ 26 | --base_video_dir /path-to-sav-val/JPEGImages_24fps \ 27 | --input_mask_dir /path-to-sav-val/Annotations_6fps \ 28 | --video_list_file /path-to-sav-val/sav_val.txt \ 29 | --per_obj_png_file \ 30 | --output_mask_dir ./outputs/sav_val_pred_pngs 31 | ``` 32 | (replace `/path-to-sav-val` with the path to SA-V val) 33 | 34 | Then, we can use the evaluation tools or servers for each dataset to get the performance of the prediction PNG files above. 35 | 36 | **Note: a limitation of the `vos_inference.py` script above is that currently it only supports VOS datasets where all objects to track already appear on frame 0 in each video** (and therefore it doesn't apply to some datasets such as [LVOS](https://lingyihongfd.github.io/lvos.github.io/) that have objects only appearing in the middle of a video). 37 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import wandb 5 | import random 6 | import logging 7 | import inspect 8 | import argparse 9 | import datetime 10 | import subprocess 11 | 12 | import cv2 13 | from pathlib import Path 14 | import numpy as np 15 | from tqdm.auto import tqdm 16 | from einops import rearrange 17 | from omegaconf import OmegaConf 18 | from safetensors import safe_open 19 | from typing import Dict, Optional, Tuple 20 | 21 | import torch 22 | import torchvision 23 | import torch.nn.functional as F 24 | import torch.distributed as dist 25 | from torch.optim.swa_utils import AveragedModel 26 | from torch.utils.data.distributed import DistributedSampler 27 | from torch.nn.parallel import DistributedDataParallel as DDP 28 | 29 | import diffusers 30 | from diffusers import AutoencoderKL, DDIMScheduler 31 | from diffusers.models import UNet2DConditionModel 32 | from diffusers.pipelines import StableDiffusionInpaintPipeline 33 | from diffusers.optimization import get_scheduler 34 | from diffusers.utils import check_min_version 35 | from diffusers.utils.import_utils import is_xformers_available 36 | 37 | import transformers 38 | from transformers import CLIPTextModel, CLIPProcessor, CLIPVisionModel, CLIPTokenizer 39 | 40 | from cococo.models.unet import UNet3DConditionModel 41 | from cococo.pipelines.pipeline_animation_inpainting_cross_attention_vae import AnimationInpaintPipeline 42 | from cococo.utils.util import save_videos_grid, zero_rank_print 43 | 44 | def load_model( 45 | model_path: str, 46 | 47 | pretrained_model_path: str, 48 | sub_folder: str = "unet", 49 | 50 | text_device: str = "cuda:0", 51 
| unet_device: str = "cuda:1", 52 | 53 | unet_checkpoint_path: str = "", 54 | unet_additional_kwargs: Dict = {}, 55 | noise_scheduler_kwargs = None, 56 | 57 | global_seed: int = 42 58 | ): 59 | 60 | # Load scheduler, tokenizer and models. 61 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) 62 | 63 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 64 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 65 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 66 | 67 | unet = UNet3DConditionModel.from_pretrained_2d( 68 | pretrained_model_path, subfolder=sub_folder, 69 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) 70 | ) 71 | state_dict = {} 72 | for i in range(4): 73 | state_dict2 = torch.load(f'{model_path}/model_{i}.pth', map_location='cpu') 74 | state_dict = {**state_dict, **state_dict2} 75 | 76 | state_dict2 = {} 77 | for key in state_dict: 78 | if 'pe' in key: 79 | continue 80 | state_dict2[key.split('module.')[1]] = state_dict[key] 81 | 82 | m, u = unet.load_state_dict(state_dict2, strict=False) 83 | 84 | vae = vae.to(text_device).half().eval() 85 | text_encoder = text_encoder.to(text_device).half().eval() 86 | unet = unet.to(unet_device).half().eval() 87 | 88 | validation_pipeline = AnimationInpaintPipeline( 89 | unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, 90 | ) 91 | validation_pipeline.enable_vae_slicing() 92 | return validation_pipeline 93 | 94 | 95 | def generate_frames(images, masks, output_dir, validation_pipeline, vae, prompt, negative_prompt, guidance_scale, text_device="cuda:0", unet_device="cuda:1"): 96 | pixel_values = torch.tensor(images).to(device=vae.device, dtype=torch.float16) 97 | test_masks = torch.tensor(masks).to(device=vae.device, dtype=torch.float16) 98 | 99 | height,width = images.shape[-3:-1] 100 | 101 | prefix = 'test_release_'+str(guidance_scale)+'_guidance_scale' 102 | 103 | latents = [] 104 | masks = [] 105 | with torch.no_grad(): 106 | for i in range(len(pixel_values)): 107 | pixel_value = rearrange(pixel_values[i:i+1], "f h w c -> f c h w") 108 | test_mask = rearrange(test_masks[i:i+1], "f h w c -> f c h w") 109 | 110 | masked_image = (1-test_mask)*pixel_value 111 | latent = vae.encode(masked_image).latent_dist.sample() 112 | test_mask = torch.nn.functional.interpolate(test_mask, size=latent.shape[-2:]).cuda() 113 | 114 | latent = rearrange(latent, "f c h w -> c f h w") 115 | test_mask = rearrange(test_mask, "f c h w -> c f h w") 116 | 117 | latent = latent * 0.18215 118 | latents.append(latent) 119 | masks.append(test_mask) 120 | latents = torch.cat(latents,dim=1) 121 | test_masks = torch.cat(masks,dim=1) 122 | 123 | latents = latents[None,...] 124 | masks = test_masks[None,...] 
125 | 126 | generator = torch.Generator(device=latents.device) 127 | generator.manual_seed(int(time.time())) 128 | 129 | with torch.no_grad(): 130 | 131 | videos, masked_videos, recon_videos = validation_pipeline( 132 | prompt, 133 | image = latents, 134 | masked_image = latents, 135 | masked_latents = None, 136 | masks = masks, 137 | generator = generator, 138 | video_length = len(images), 139 | negative_prompt = negative_prompt, 140 | height = height, 141 | width = width, 142 | num_inference_steps = 50, 143 | guidance_scale = guidance_scale, 144 | unet_device=unet_device 145 | ) 146 | 147 | videos = videos.permute(0,2,1,3,4).contiguous()/0.18215 148 | 149 | images = [] 150 | for i in range(len(videos[0])): 151 | image = vae.decode(videos[0][i:i+1].half().to(text_device)).sample 152 | images.append(image) 153 | video = torch.cat(images,dim=0) 154 | video = video/2 + 0.5 155 | video = torch.clamp(video, 0, 1) 156 | video = video.permute(0,2,3,1) 157 | 158 | images = [] 159 | video = 255.0*video.cpu().detach().numpy() 160 | for i in range(len(video)): 161 | image = video[i] 162 | image = np.uint8(image) 163 | #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 164 | #cv2.imwrite(output_dir+'/'+prefix+'_image_'+str(i)+'.png',image) 165 | images.append(image) 166 | return images 167 | 168 | if __name__ == "__main__": 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument("--config", type=str, required=True) 171 | parser.add_argument("--prompt", type=str, default="") 172 | parser.add_argument("--negative_prompt", type=str, default="") 173 | parser.add_argument("--model_path", type=str, default="../") 174 | parser.add_argument("--pretrain_model_path", type=str, default="../") 175 | parser.add_argument("--sub_folder", type=str, default="unet") 176 | parser.add_argument("--guidance_scale", type=float, default=20) 177 | parser.add_argument("--video_path", type=str, default="") 178 | args = parser.parse_args() 179 | 180 | config = OmegaConf.load(args.config) 181 | 182 | validation_pipeline = load_model( \ 183 | model_path=args.model_path, \ 184 | sub_folder=args.sub_folder, \ 185 | pretrained_model_path=args.pretrain_model_path, \ 186 | **config 187 | ) 188 | 189 | video_path = args.video_path+'/images.npy' 190 | mask_path = args.video_path+'/masks.npy' 191 | images = 2*(np.load(video_path)/255.0 - 0.5) 192 | masks = np.load(mask_path)/255.0 193 | 194 | generate_frames(\ 195 | images=images, \ 196 | masks=masks, \ 197 | output_dir = './outputs', \ 198 | validation_pipeline=validation_pipeline, \ 199 | vae = validation_pipeline.vae, \ 200 | prompt=args.prompt, \ 201 | negative_prompt=args.negative_prompt, \ 202 | guidance_scale=args.guidance_scale) 203 | 204 | 205 | -------------------------------------------------------------------------------- /utils_with_T2I_LoRA.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import wandb 5 | import random 6 | import logging 7 | import inspect 8 | import argparse 9 | import datetime 10 | import subprocess 11 | 12 | import cv2 13 | from pathlib import Path 14 | import numpy as np 15 | from tqdm.auto import tqdm 16 | from einops import rearrange 17 | from omegaconf import OmegaConf 18 | from safetensors import safe_open 19 | from typing import Dict, Optional, Tuple 20 | 21 | import torch 22 | import torchvision 23 | import torch.nn.functional as F 24 | import torch.distributed as dist 25 | from torch.optim.swa_utils import AveragedModel 26 | from 
torch.utils.data.distributed import DistributedSampler 27 | from torch.nn.parallel import DistributedDataParallel as DDP 28 | 29 | import diffusers 30 | from diffusers import AutoencoderKL, DDIMScheduler 31 | from diffusers.models import UNet2DConditionModel 32 | from diffusers.pipelines import StableDiffusionInpaintPipeline 33 | from diffusers.optimization import get_scheduler 34 | from diffusers.utils import check_min_version 35 | from diffusers.utils.import_utils import is_xformers_available 36 | 37 | import transformers 38 | from transformers import CLIPTextModel, CLIPProcessor, CLIPVisionModel, CLIPTokenizer 39 | 40 | from cococo.models.unet import UNet3DConditionModel 41 | from cococo.pipelines.pipeline_animation_inpainting_cross_attention_vae import AnimationInpaintPipeline 42 | from cococo.utils.util import save_videos_grid, zero_rank_print 43 | 44 | def load_model( 45 | model_path: str, 46 | 47 | pretrained_model_path: str, 48 | sub_folder: str = "unet", 49 | 50 | text_device: str = "cuda:0", 51 | unet_device: str = "cuda:1", 52 | 53 | unet_checkpoint_path: str = "", 54 | unet_additional_kwargs: Dict = {}, 55 | noise_scheduler_kwargs = None, 56 | 57 | text_model_path = "", 58 | vae_model_path = "", 59 | unet_model_path = "", 60 | 61 | text_lora_path = "", 62 | vae_lora_path = "", 63 | unet_lora_path = "", 64 | beta_text = 0.0, 65 | beta_vae = 0.0, 66 | beta_unet = 0.0, 67 | 68 | global_seed: int = 42 69 | ): 70 | 71 | # Load scheduler, tokenizer and models. 72 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) 73 | 74 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 75 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 76 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 77 | 78 | unet = UNet3DConditionModel.from_pretrained_2d( 79 | pretrained_model_path, subfolder=sub_folder, 80 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) 81 | ) 82 | state_dict = {} 83 | for i in range(4): 84 | state_dict2 = torch.load(f'{model_path}/model_{i}.pth', map_location='cpu') 85 | state_dict = {**state_dict, **state_dict2} 86 | 87 | state_dict2 = {} 88 | for key in state_dict: 89 | if 'pe' in key: 90 | continue 91 | state_dict2[key.split('module.')[1]] = state_dict[key] 92 | 93 | m, u = unet.load_state_dict(state_dict2, strict=False) 94 | 95 | if text_model_path != '': 96 | text_state_dict = torch.load(text_model_path, map_location='cpu') 97 | text_encoder.load_state_dict(text_state_dict) 98 | 99 | if vae_model_path != '': 100 | vae_state_dict = torch.load(vae_model_path, map_location='cpu') 101 | vae.load_state_dict(vae_state_dict) 102 | 103 | if unet_model_path != '': 104 | unet_state_dict = torch.load(unet_model_path, map_location='cpu') 105 | u,m = unet.load_state_dict(unet_state_dict, strict=False) 106 | 107 | if text_lora_path != '': 108 | text_state_dict = text_encoder.state_dict() 109 | text_lora_state_dict = torch.load(text_lora_path, map_location='cpu') 110 | for key in text_lora_state_dict: 111 | text_state_dict[key] += beta_text*text_lora_state_dict[key] 112 | text_encoder.load_state_dict(text_state_dict) 113 | 114 | if vae_lora_path != '': 115 | vae_state_dict = vae.state_dict() 116 | vae_lora_state_dict = torch.load(vae_lora_path, map_location='cpu') 117 | for key in vae_lora_state_dict: 118 | vae_state_dict[key] += beta_vae*vae_lora_state_dict[key] 119 | vae.load_state_dict(vae_state_dict) 120 | 121 | if 
unet_lora_path != '': 122 | unet_state_dict = unet.state_dict() 123 | unet_lora_state_dict = torch.load(unet_lora_path, map_location='cpu') 124 | for key in unet_lora_state_dict: 125 | if unet_state_dict[key].shape != unet_lora_state_dict[key].shape: 126 | unet_state_dict[key] += beta_unet*unet_lora_state_dict[key].view(unet_lora_state_dict[key].shape[0], unet_lora_state_dict[key].shape[1], 1, 1) 127 | else: 128 | unet_state_dict[key] += beta_unet*unet_lora_state_dict[key] 129 | unet.load_state_dict(unet_state_dict) 130 | 131 | vae = vae.to(text_device).half().eval() 132 | text_encoder = text_encoder.to(text_device).half().eval() 133 | unet = unet.to(unet_device).half().eval() 134 | 135 | validation_pipeline = AnimationInpaintPipeline( 136 | unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, 137 | ) 138 | validation_pipeline.enable_vae_slicing() 139 | return validation_pipeline 140 | 141 | 142 | def generate_frames(images, masks, output_dir, validation_pipeline, vae, prompt, negative_prompt, guidance_scale, text_device="cuda:0", unet_device="cuda:1"): 143 | pixel_values = torch.tensor(images).to(device=vae.device, dtype=torch.float16) 144 | test_masks = torch.tensor(masks).to(device=vae.device, dtype=torch.float16) 145 | 146 | height,width = images.shape[-3:-1] 147 | 148 | prefix = 'test_release_'+str(guidance_scale)+'_guidance_scale' 149 | 150 | latents = [] 151 | masks = [] 152 | with torch.no_grad(): 153 | for i in range(len(pixel_values)): 154 | pixel_value = rearrange(pixel_values[i:i+1], "f h w c -> f c h w") 155 | test_mask = rearrange(test_masks[i:i+1], "f h w c -> f c h w") 156 | 157 | masked_image = (1-test_mask)*pixel_value 158 | latent = vae.encode(masked_image).latent_dist.sample() 159 | test_mask = torch.nn.functional.interpolate(test_mask, size=latent.shape[-2:]).cuda() 160 | 161 | latent = rearrange(latent, "f c h w -> c f h w") 162 | test_mask = rearrange(test_mask, "f c h w -> c f h w") 163 | 164 | latent = latent * 0.18215 165 | latents.append(latent) 166 | masks.append(test_mask) 167 | latents = torch.cat(latents,dim=1) 168 | test_masks = torch.cat(masks,dim=1) 169 | 170 | latents = latents[None,...] 171 | masks = test_masks[None,...] 
172 | 173 | generator = torch.Generator(device=latents.device) 174 | generator.manual_seed(int(time.time())) 175 | 176 | with torch.no_grad(): 177 | 178 | videos, masked_videos, recon_videos = validation_pipeline( 179 | prompt, 180 | image = latents, 181 | masked_image = latents, 182 | masked_latents = None, 183 | masks = masks, 184 | generator = generator, 185 | video_length = len(images), 186 | negative_prompt = negative_prompt, 187 | height = height, 188 | width = width, 189 | num_inference_steps = 50, 190 | guidance_scale = guidance_scale, 191 | unet_device=unet_device 192 | ) 193 | 194 | videos = videos.permute(0,2,1,3,4).contiguous()/0.18215 195 | 196 | images = [] 197 | for i in range(len(videos[0])): 198 | image = vae.decode(videos[0][i:i+1].half().to(text_device)).sample 199 | images.append(image) 200 | video = torch.cat(images,dim=0) 201 | video = video/2 + 0.5 202 | video = torch.clamp(video, 0, 1) 203 | video = video.permute(0,2,3,1) 204 | 205 | images = [] 206 | video = 255.0*video.cpu().detach().numpy() 207 | for i in range(len(video)): 208 | image = video[i] 209 | image = np.uint8(image) 210 | #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 211 | #cv2.imwrite(output_dir+'/'+prefix+'_image_'+str(i)+'.png',image) 212 | images.append(image) 213 | return images 214 | 215 | if __name__ == "__main__": 216 | parser = argparse.ArgumentParser() 217 | parser.add_argument("--config", type=str, required=True) 218 | parser.add_argument("--prompt", type=str, default="") 219 | parser.add_argument("--negative_prompt", type=str, default="") 220 | parser.add_argument("--model_path", type=str, default="../") 221 | parser.add_argument("--pretrain_model_path", type=str, default="../") 222 | parser.add_argument("--sub_folder", type=str, default="unet") 223 | parser.add_argument("--guidance_scale", type=float, default=20) 224 | parser.add_argument("--video_path", type=str, default="") 225 | args = parser.parse_args() 226 | 227 | config = OmegaConf.load(args.config) 228 | 229 | validation_pipeline = load_model( \ 230 | model_path=args.model_path, \ 231 | sub_folder=args.sub_folder, \ 232 | pretrained_model_path=args.pretrain_model_path, \ 233 | **config 234 | ) 235 | 236 | video_path = args.video_path+'/images.npy' 237 | mask_path = args.video_path+'/masks.npy' 238 | images = 2*(np.load(video_path)/255.0 - 0.5) 239 | masks = np.load(mask_path)/255.0 240 | 241 | generate_frames(\ 242 | images=images, \ 243 | masks=masks, \ 244 | output_dir = './outputs', \ 245 | validation_pipeline=validation_pipeline, \ 246 | vae = validation_pipeline.vae, \ 247 | prompt=args.prompt, \ 248 | negative_prompt=args.negative_prompt, \ 249 | guidance_scale=args.guidance_scale) 250 | 251 | 252 | -------------------------------------------------------------------------------- /valid_code_release.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import wandb 4 | import random 5 | import logging 6 | import inspect 7 | import argparse 8 | import datetime 9 | import subprocess 10 | 11 | import cv2 12 | from pathlib import Path 13 | import numpy as np 14 | from tqdm.auto import tqdm 15 | from einops import rearrange 16 | from omegaconf import OmegaConf 17 | from safetensors import safe_open 18 | from typing import Dict, Optional, Tuple 19 | 20 | import torch 21 | import torchvision 22 | import torch.nn.functional as F 23 | import torch.distributed as dist 24 | from torch.optim.swa_utils import AveragedModel 25 | from torch.utils.data.distributed import 
DistributedSampler 26 | from torch.nn.parallel import DistributedDataParallel as DDP 27 | 28 | import diffusers 29 | from diffusers import AutoencoderKL, DDIMScheduler 30 | from diffusers.models import UNet2DConditionModel 31 | from diffusers.pipelines import StableDiffusionInpaintPipeline 32 | from diffusers.optimization import get_scheduler 33 | from diffusers.utils import check_min_version 34 | from diffusers.utils.import_utils import is_xformers_available 35 | 36 | import transformers 37 | from transformers import CLIPTextModel, CLIPProcessor, CLIPVisionModel, CLIPTokenizer 38 | 39 | from cococo.models.unet import UNet3DConditionModel 40 | from cococo.pipelines.pipeline_animation_inpainting_cross_attention_vae import AnimationInpaintPipeline 41 | from cococo.utils.util import save_videos_grid, zero_rank_print 42 | 43 | def main( 44 | name: str, 45 | use_wandb: bool, 46 | launcher: str, 47 | 48 | model_path: str, 49 | 50 | prompt: str, 51 | negative_prompt: str, 52 | guidance_scale: float, 53 | 54 | output_dir: str, 55 | pretrained_model_path: str, 56 | sub_folder: str = "unet", 57 | 58 | unet_checkpoint_path: str = "", 59 | unet_additional_kwargs: Dict = {}, 60 | noise_scheduler_kwargs = None, 61 | 62 | num_workers: int = 32, 63 | 64 | enable_xformers_memory_efficient_attention: bool = True, 65 | 66 | image_path: str = '', 67 | mask_path: str = '', 68 | global_seed: int = 42, 69 | is_debug: bool = False, 70 | ): 71 | 72 | seed = global_seed 73 | torch.manual_seed(seed) 74 | 75 | folder_name = name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S") 76 | output_dir = os.path.join(output_dir, folder_name) 77 | 78 | *_, config = inspect.getargvalues(inspect.currentframe()) 79 | 80 | os.makedirs(output_dir, exist_ok=True) 81 | OmegaConf.save(config, os.path.join(output_dir, 'config.yaml')) 82 | 83 | # Load scheduler, tokenizer and models. 
84 | noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs)) 85 | 86 | vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae") 87 | tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer") 88 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder") 89 | 90 | unet = UNet3DConditionModel.from_pretrained_2d( 91 | pretrained_model_path, subfolder=sub_folder, 92 | unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs) 93 | ) 94 | state_dict = {} 95 | for i in range(4): 96 | state_dict2 = torch.load(f'{model_path}/model_{i}.pth', map_location='cpu') 97 | state_dict = {**state_dict, **state_dict2} 98 | 99 | state_dict2 = {} 100 | for key in state_dict: 101 | if 'pe' in key: 102 | continue 103 | state_dict2[key.split('module.')[1]] = state_dict[key] 104 | 105 | m, u = unet.load_state_dict(state_dict2, strict=False) 106 | 107 | vae = vae.cuda().half().eval() 108 | text_encoder = text_encoder.cuda().half().eval() 109 | unet = unet.cuda().half().eval() 110 | 111 | validation_pipeline = AnimationInpaintPipeline( 112 | unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, 113 | ) 114 | validation_pipeline.enable_vae_slicing() 115 | 116 | video_path = image_path 117 | mask_path = mask_path 118 | images = 2*(np.load(video_path)/255.0 - 0.5) 119 | masks = np.load(mask_path)/255.0 120 | pixel_values = torch.tensor(images).to(device=vae.device, dtype=torch.float16) 121 | test_masks = torch.tensor(masks).to(device=vae.device, dtype=torch.float16) 122 | 123 | height,width = images.shape[-3:-1] 124 | 125 | prefix = 'test_release_'+str(guidance_scale)+'_guidance_scale' 126 | 127 | latents = [] 128 | masks = [] 129 | with torch.no_grad(): 130 | for i in range(len(pixel_values)): 131 | pixel_value = rearrange(pixel_values[i:i+1], "f h w c -> f c h w") 132 | test_mask = rearrange(test_masks[i:i+1], "f h w c -> f c h w") 133 | 134 | masked_image = (1-test_mask)*pixel_value 135 | latent = vae.encode(masked_image).latent_dist.sample() 136 | test_mask = torch.nn.functional.interpolate(test_mask, size=latent.shape[-2:]).cuda() 137 | 138 | latent = rearrange(latent, "f c h w -> c f h w") 139 | test_mask = rearrange(test_mask, "f c h w -> c f h w") 140 | 141 | latent = latent * 0.18215 142 | latents.append(latent) 143 | masks.append(test_mask) 144 | latents = torch.cat(latents,dim=1) 145 | test_masks = torch.cat(masks,dim=1) 146 | 147 | latents = latents[None,...] 148 | masks = test_masks[None,...] 
149 | 150 | generator = torch.Generator(device=latents.device) 151 | generator.manual_seed(0) 152 | 153 | for step in range(10): 154 | 155 | with torch.no_grad(): 156 | 157 | videos, masked_videos, recon_videos = validation_pipeline( 158 | prompt, 159 | image = latents, 160 | masked_image = latents, 161 | masked_latents = None, 162 | masks = masks, 163 | generator = generator, 164 | video_length = len(images), 165 | negative_prompt = negative_prompt, 166 | height = height, 167 | width = width, 168 | num_inference_steps = 50, 169 | guidance_scale = guidance_scale 170 | ) 171 | 172 | videos = videos.permute(0,2,1,3,4).contiguous()/0.18215 173 | 174 | with torch.no_grad(): 175 | images = [] 176 | for i in range(len(videos[0])): 177 | image = vae.decode(videos[0][i:i+1].half()).sample 178 | images.append(image) 179 | video = torch.cat(images,dim=0) 180 | video = video/2 + 0.5 181 | video = torch.clamp(video, 0, 1) 182 | video = video.permute(0,2,3,1) 183 | 184 | video = 255.0*video.cpu().detach().numpy() 185 | for i in range(len(video)): 186 | image = video[i] 187 | image = np.uint8(image) 188 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 189 | cv2.imwrite(output_dir+'/'+prefix+'_'+str(step)+'_image_'+str(i)+'.png',image) 190 | 191 | if __name__ == "__main__": 192 | parser = argparse.ArgumentParser() 193 | parser.add_argument("--config", type=str, required=True) 194 | parser.add_argument("--prompt", type=str, default="") 195 | parser.add_argument("--negative_prompt", type=str, default="") 196 | parser.add_argument("--model_path", type=str, default="../") 197 | parser.add_argument("--pretrain_model_path", type=str, default="../") 198 | parser.add_argument("--sub_folder", type=str, default="unet") 199 | parser.add_argument("--guidance_scale", type=float, default=20) 200 | parser.add_argument("--video_path", type=str, default="") 201 | args = parser.parse_args() 202 | 203 | name = Path(args.config).stem 204 | config = OmegaConf.load(args.config) 205 | 206 | main(name=name, \ 207 | launcher=None, \ 208 | use_wandb=False, \ 209 | prompt=args.prompt, \ 210 | model_path=args.model_path, \ 211 | sub_folder=args.sub_folder, \ 212 | pretrained_model_path=args.pretrain_model_path, \ 213 | negative_prompt=args.negative_prompt, \ 214 | guidance_scale=args.guidance_scale, \ 215 | image_path=args.video_path+'/images.npy', \ 216 | mask_path=args.video_path+'/masks.npy', \ 217 | **config 218 | ) 219 | --------------------------------------------------------------------------------
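**Note on the input format expected by the scripts above.** All three entry points (`utils.py`, `utils_with_T2I_LoRA.py`, `valid_code_release.py`) read the video as two NumPy files, `images.npy` and `masks.npy`, and do the rescaling themselves (`2*(images/255.0 - 0.5)` for frames, `masks/255.0` for masks), so the files on disk should hold raw `uint8`, frame-major arrays of shape `(frames, height, width, channels)`. The sketch below is one way to pack a folder of frames and per-frame masks into that layout; the folder names, the 512x512 resolution, and the single-channel masks are illustrative assumptions rather than part of the repository.

```python
import glob
import os

import cv2
import numpy as np


def pack_video_dir(frame_dir: str, mask_dir: str, out_dir: str, size=(512, 512)) -> None:
    """Stack per-frame PNGs into the images.npy / masks.npy arrays the scripts load."""
    frame_paths = sorted(glob.glob(os.path.join(frame_dir, "*.png")))
    mask_paths = sorted(glob.glob(os.path.join(mask_dir, "*.png")))
    assert len(frame_paths) == len(mask_paths) > 0, "expected one mask per frame"

    images, masks = [], []
    for frame_path, mask_path in zip(frame_paths, mask_paths):
        # The scripts only convert RGB->BGR when writing their results with cv2,
        # so the frames are assumed to be stored in RGB order here.
        frame = cv2.cvtColor(cv2.imread(frame_path), cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, size)

        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        # Binarize to {0, 255}; the loaders divide by 255, so 255 marks the region
        # to be inpainted (it is zeroed out before VAE encoding).
        mask = np.where(mask > 127, 255, 0).astype(np.uint8)[..., None]

        images.append(frame.astype(np.uint8))
        masks.append(mask)

    os.makedirs(out_dir, exist_ok=True)
    np.save(os.path.join(out_dir, "images.npy"), np.stack(images))  # (f, h, w, 3), uint8
    np.save(os.path.join(out_dir, "masks.npy"), np.stack(masks))    # (f, h, w, 1), uint8


if __name__ == "__main__":
    # Hypothetical paths; pass the resulting directory as --video_path.
    pack_video_dir("./my_video/frames", "./my_video/masks", "./my_video")
```

The single-channel masks broadcast against the three-channel frames inside `generate_frames`; the resulting directory can then be passed as `--video_path` to `utils.py`, `utils_with_T2I_LoRA.py`, or `valid_code_release.py`, which look for `images.npy` and `masks.npy` inside it.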