├── README.md ├── assets └── teaser.png ├── gen2seg_mae_pipeline.py ├── gen2seg_sd_pipeline.py ├── inference_mae.py ├── inference_sd.py ├── prompting.py └── sam.py /README.md: -------------------------------------------------------------------------------- 1 | # gen2seg: Generative Models Enable Generalizable Instance Segmentation 2 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/reachomk/gen2seg) 3 | 4 | ### [Project Page](https://reachomk.github.io/gen2seg) | [Paper](https://arxiv.org/abs/2505.15263) 5 | 6 | [**gen2seg: Generative Models Enable Generalizable Instance Segmentation**](https://reachomk.github.io/gen2seg) 7 | [Om Khangaonkar](https://reachomk.github.io), 8 | [Hamed Pirsiavash](https://web.cs.ucdavis.edu/~hpirsiav/)
9 | UC Davis
10 | 11 | 12 | ## Pretrained Models 13 | Stable Diffusion 2 (SD): https://huggingface.co/reachomk/gen2seg-sd 14 | 15 | ImageNet-1K-pretrained Masked Autoencoder-Huge (MAE-H): https://huggingface.co/reachomk/gen2seg-mae-h 16 | 17 | If you want any of our other models, send me an email. If there is sufficient demand, I will also release them publicly. 18 | 19 | ## Inference 20 | Currently, we have released inference code for our SD and MAE models. You can run them by editing the `image_path` variable (your input image) in each file and then running `python inference_sd.py` or `python inference_mae.py`. 21 | 22 | You will need `transformers` and `diffusers` installed, along with standard machine learning packages such as `torch` and `numpy`. More details on our specific environment will be released with the training code. 23 | 24 | We have also released code for prompting (`prompting.py`). Please run `pip install opencv-contrib-python` before running it. 25 | 26 | Here is how you run it: 27 | ``` 28 | python prompting.py \ 29 | --feature_image /path/to/your/feature_image.png \ 30 | --prompt_x 150 \ 31 | --prompt_y 200 32 | ``` 33 | The feature image is the one generated by our model, NOT the original image. 34 | 35 | 36 | There are also the following optional arguments: 37 | ``` 38 | --output_mask /path/to/save/output_mask.png \ 39 | --sigma 0.02 \ 40 | --threshold 10 41 | ``` 42 | 43 | `--threshold` controls the mask threshold (out of 255), and `--sigma` controls the amount of averaging for the query vector. By default they are 3 and 0.01, respectively. See our paper for more details. 44 | 45 | We have also provided our inference script for SAM (`sam.py`) to enable qualitative comparison. Please make sure you download the SAM checkpoint and set `CHECKPOINT_PATH` in the script. You should also edit the `INPUT_IMAGE` variable (your input image). 46 | 47 | ## Training 48 | I will release all training code by June 7 (likely earlier). Make sure to star and watch our repository so you're notified when we update it! Send me an email if you need it before then. 49 | 50 | ## Citation 51 | Please cite our paper if you found it helpful. 52 | ``` 53 | @article{khangaonkar2025gen2seg, 54 | title={gen2seg: Generative Models Enable Generalizable Instance Segmentation}, 55 | author={Om Khangaonkar and Hamed Pirsiavash}, 56 | year={2025}, 57 | journal={arXiv preprint arXiv:2505.15263} 58 | } 59 | ``` 60 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/gen2seg/72e01e93feba569f580e1299edf21a7e08303ea7/assets/teaser.png -------------------------------------------------------------------------------- /gen2seg_mae_pipeline.py: -------------------------------------------------------------------------------- 1 | # gen2seg official inference pipeline code for the MAE model 2 | # 3 | # Please see our project website at https://reachomk.github.io/gen2seg 4 | # 5 | # Additionally, if you use our code please cite our paper.
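# Minimal usage sketch (for orientation only; it mirrors inference_mae.py and assumes the fine-tuned
# weights published at "reachomk/gen2seg-mae-h" together with the facebook/vit-mae-huge processor):
#
#     from PIL import Image
#     from transformers import AutoImageProcessor, ViTMAEForPreTraining
#     processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")
#     model = ViTMAEForPreTraining.from_pretrained("reachomk/gen2seg-mae-h")
#     pipe = gen2segMAEInstancePipeline(model=model, image_processor=processor).to("cuda")
#     seg = pipe([Image.open("input.png").convert("RGB").resize((224, 224))]).prediction[0]  # (224, 224, 3), values in [0, 255]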
6 | 7 | from dataclasses import dataclass 8 | from typing import Union, List, Optional 9 | 10 | import torch 11 | import numpy as np 12 | from PIL import Image 13 | from einops import rearrange 14 | 15 | from diffusers import DiffusionPipeline 16 | from diffusers.utils import BaseOutput, logging 17 | from transformers import AutoImageProcessor 18 | 19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 20 | 21 | 22 | @dataclass 23 | class gen2segMAEInstanceOutput(BaseOutput): 24 | """ 25 | Output class for the ViTMAE Instance Segmentation Pipeline. 26 | 27 | Args: 28 | prediction (`np.ndarray` or `torch.Tensor`): 29 | Predicted instance segmentation maps. The output has shape 30 | `(batch_size, 3, height, width)` with pixel values scaled to [0, 255]. 31 | """ 32 | prediction: Union[np.ndarray, torch.Tensor] 33 | 34 | 35 | class gen2segMAEInstancePipeline(DiffusionPipeline): 36 | r""" 37 | Pipeline for Instance Segmentation using a fine-tuned ViTMAEForPreTraining model. 38 | 39 | This pipeline takes one or more input images and returns an instance segmentation 40 | prediction for each image. The model is assumed to have been fine-tuned using an instance 41 | segmentation loss, and the reconstruction is performed by rearranging the model’s 42 | patch logits into an image. 43 | 44 | Args: 45 | model (`ViTMAEForPreTraining`): 46 | The fine-tuned ViTMAE model. 47 | image_processor (`AutoImageProcessor`): 48 | The image processor responsible for preprocessing input images. 49 | """ 50 | def __init__(self, model, image_processor): 51 | super().__init__() 52 | self.register_modules(model=model, image_processor=image_processor) 53 | self.model = model 54 | self.image_processor = image_processor 55 | 56 | def check_inputs( 57 | self, 58 | image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]] 59 | ) -> List: 60 | if not isinstance(image, list): 61 | image = [image] 62 | # Additional input validations can be added here if desired. 63 | return image 64 | 65 | @torch.no_grad() 66 | def __call__( 67 | self, 68 | image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]], 69 | output_type: str = "np", 70 | **kwargs 71 | ) -> gen2segMAEInstanceOutput: 72 | r""" 73 | The call method of the pipeline. 74 | 75 | Args: 76 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, or a list of these): 77 | The input image(s) for instance segmentation. For arrays/tensors, expected values are in [0, 1]. 78 | output_type (`str`, optional, defaults to `"np"`): 79 | The format of the output prediction. Choose `"np"` for a NumPy array or `"pt"` for a PyTorch tensor. 80 | **kwargs: 81 | Additional keyword arguments passed to the image processor. 82 | 83 | Returns: 84 | [`gen2segMAEInstanceOutput`]: 85 | An output object containing the predicted instance segmentation maps. 86 | """ 87 | # 1. Check and prepare input images. 88 | images = self.check_inputs(image) 89 | inputs = self.image_processor(images=images, return_tensors="pt", **kwargs) 90 | pixel_values = inputs["pixel_values"].to(self.device) 91 | 92 | # 2. Forward pass through the model. 93 | outputs = self.model(pixel_values=pixel_values) 94 | logits = outputs.logits # Expected shape: (B, num_patches, patch_dim) 95 | 96 | # 3. Retrieve patch size and image size from the model configuration. 
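        # (The reshape in step 4 below assumes the standard ViT-MAE layout read from the config:
        # `num_patches == grid_size ** 2` and `patch_dim == patch_size * patch_size * 3`, i.e. one
        # flattened RGB patch per token. A checkpoint that deviates from this would require adapting
        # the rearrange pattern.)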
97 | patch_size = self.model.config.patch_size # e.g., 16 98 | image_size = self.model.config.image_size # e.g., 224 99 | grid_size = image_size // patch_size 100 | 101 | # 4. Rearrange logits into the reconstructed image. 102 | # The logits are reshaped from (B, num_patches, patch_dim) to (B, 3, H, W). 103 | reconstructed = rearrange( 104 | logits, 105 | "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", 106 | h=grid_size, 107 | p1=patch_size, 108 | p2=patch_size, 109 | c=3, 110 | ) 111 | 112 | # 5. Post-process the reconstructed output. 113 | # For each sample, shift and scale the prediction to [0, 255]. 114 | predictions = [] 115 | for i in range(reconstructed.shape[0]): 116 | sample = reconstructed[i] 117 | min_val = torch.abs(sample.min()) 118 | max_val = torch.abs(sample.max()) 119 | sample = (sample + min_val) / (max_val + min_val + 1e-5) 120 | # sometimes the image is very dark so we perform gamma correction to "brighten" it 121 | # in practice we can set this value to whatever we want or disable it entirely. 122 | sample = sample**0.7 123 | sample = sample * 255.0 124 | predictions.append(sample) 125 | prediction_tensor = torch.stack(predictions, dim=0).permute(0, 2, 3, 1) 126 | 127 | # 6. Format the output. 128 | if output_type == "np": 129 | prediction = prediction_tensor.cpu().numpy() 130 | else: 131 | prediction = prediction_tensor 132 | return gen2segMAEInstanceOutput(prediction=prediction) 133 | -------------------------------------------------------------------------------- /gen2seg_sd_pipeline.py: -------------------------------------------------------------------------------- 1 | # gen2seg official inference pipeline code for Stable Diffusion model 2 | # 3 | # This code was adapted from Marigold and Diffusion E2E Finetuning. 4 | # 5 | # Please see our project website at https://reachomk.github.io/gen2seg 6 | # 7 | # Additionally, if you use our code please cite our paper, along with the two works above. 8 | 9 | from dataclasses import dataclass 10 | from typing import List, Optional, Tuple, Union 11 | 12 | import numpy as np 13 | import torch 14 | from PIL import Image 15 | from tqdm.auto import tqdm 16 | from transformers import CLIPTextModel, CLIPTokenizer 17 | 18 | from diffusers.image_processor import PipelineImageInput 19 | from diffusers.models import ( 20 | AutoencoderKL, 21 | UNet2DConditionModel, 22 | ) 23 | from diffusers.schedulers import ( 24 | DDIMScheduler, 25 | ) 26 | from diffusers.utils import ( 27 | BaseOutput, 28 | logging, 29 | ) 30 | from diffusers import DiffusionPipeline 31 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor 32 | 33 | # add 34 | def zeros_tensor( 35 | shape: Union[Tuple, List], 36 | device: Optional["torch.device"] = None, 37 | dtype: Optional["torch.dtype"] = None, 38 | layout: Optional["torch.layout"] = None, 39 | ): 40 | """ 41 | A helper function to create tensors of zeros on the desired `device`. 42 | Mirrors randn_tensor from diffusers.utils.torch_utils. 43 | """ 44 | layout = layout or torch.strided 45 | device = device or torch.device("cpu") 46 | latents = torch.zeros(list(shape), dtype=dtype, layout=layout).to(device) 47 | return latents 48 | 49 | 50 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 51 | 52 | @dataclass 53 | class gen2segSDSegOutput(BaseOutput): 54 | """ 55 | Output class for gen2seg Instance Segmentation prediction pipeline. 
56 | 57 | Args: 58 | prediction (`np.ndarray`, `torch.Tensor`): 59 | Predicted instance segmentation with values in the range [0, 255]. The shape is always $numimages \times 1 \times height 60 | \times width$, regardless of whether the images were passed as a 4D array or a list. 61 | latent (`None`, `torch.Tensor`): 62 | Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline. 63 | The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$. 64 | """ 65 | 66 | prediction: Union[np.ndarray, torch.Tensor] 67 | latent: Union[None, torch.Tensor] 68 | 69 | 70 | class gen2segSDPipeline(DiffusionPipeline): 71 | """ 72 | # add 73 | Pipeline for Instance Segmentation prediction using our Stable Diffusion model. 74 | Implementation is built upon Marigold: https://marigoldmonodepth.github.io and E2E FThttps://gonzalomartingarcia.github.io/diffusion-e2e-ft/ 75 | 76 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 77 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 78 | 79 | Args: 80 | unet (`UNet2DConditionModel`): 81 | Conditional U-Net to denoise the segmentation latent, synthesized from image latent. 82 | vae (`AutoencoderKL`): 83 | Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent 84 | representations. 85 | scheduler (`DDIMScheduler`): 86 | A scheduler to be used in combination with `unet` to denoise the encoded image latent. 87 | text_encoder (`CLIPTextModel`): 88 | Text-encoder, for empty text embedding. 89 | tokenizer (`CLIPTokenizer`): 90 | CLIP tokenizer. 91 | default_processing_resolution (`int`, *optional*): 92 | The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in 93 | the model config. When the pipeline is called without explicitly setting `processing_resolution`, the 94 | default value is used. This is required to ensure reasonable results with various model flavors trained 95 | with varying optimal processing resolution values. 96 | """ 97 | 98 | model_cpu_offload_seq = "text_encoder->unet->vae" 99 | 100 | def __init__( 101 | self, 102 | unet: UNet2DConditionModel, 103 | vae: AutoencoderKL, 104 | scheduler: Union[DDIMScheduler], 105 | text_encoder: CLIPTextModel, 106 | tokenizer: CLIPTokenizer, 107 | default_processing_resolution: Optional[int] = 768, # add 108 | ): 109 | super().__init__() 110 | 111 | self.register_modules( 112 | unet=unet, 113 | vae=vae, 114 | scheduler=scheduler, 115 | text_encoder=text_encoder, 116 | tokenizer=tokenizer, 117 | ) 118 | self.register_to_config( 119 | default_processing_resolution=default_processing_resolution, 120 | ) 121 | 122 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 123 | self.default_processing_resolution = default_processing_resolution 124 | self.empty_text_embedding = None 125 | 126 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor) 127 | 128 | def check_inputs( 129 | self, 130 | image: PipelineImageInput, 131 | processing_resolution: int, 132 | resample_method_input: str, 133 | resample_method_output: str, 134 | batch_size: int, 135 | output_type: str, 136 | ) -> int: 137 | if processing_resolution is None: 138 | raise ValueError( 139 | "`processing_resolution` is not specified and could not be resolved from the model config." 
140 | ) 141 | if processing_resolution < 0: 142 | raise ValueError( 143 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for " 144 | "downsampled processing." 145 | ) 146 | if processing_resolution % self.vae_scale_factor != 0: 147 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.") 148 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 149 | raise ValueError( 150 | "`resample_method_input` takes string values compatible with PIL library: " 151 | "nearest, nearest-exact, bilinear, bicubic, area." 152 | ) 153 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"): 154 | raise ValueError( 155 | "`resample_method_output` takes string values compatible with PIL library: " 156 | "nearest, nearest-exact, bilinear, bicubic, area." 157 | ) 158 | if batch_size < 1: 159 | raise ValueError("`batch_size` must be positive.") 160 | if output_type not in ["pt", "np"]: 161 | raise ValueError("`output_type` must be one of `pt` or `np`.") 162 | 163 | # image checks 164 | num_images = 0 165 | W, H = None, None 166 | if not isinstance(image, list): 167 | image = [image] 168 | for i, img in enumerate(image): 169 | if isinstance(img, np.ndarray) or torch.is_tensor(img): 170 | if img.ndim not in (2, 3, 4): 171 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.") 172 | H_i, W_i = img.shape[-2:] 173 | N_i = 1 174 | if img.ndim == 4: 175 | N_i = img.shape[0] 176 | elif isinstance(img, Image.Image): 177 | W_i, H_i = img.size 178 | N_i = 1 179 | else: 180 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.") 181 | if W is None: 182 | W, H = W_i, H_i 183 | elif (W, H) != (W_i, H_i): 184 | raise ValueError( 185 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}" 186 | ) 187 | num_images += N_i 188 | 189 | return num_images 190 | 191 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True): 192 | if not hasattr(self, "_progress_bar_config"): 193 | self._progress_bar_config = {} 194 | elif not isinstance(self._progress_bar_config, dict): 195 | raise ValueError( 196 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}." 197 | ) 198 | 199 | progress_bar_config = dict(**self._progress_bar_config) 200 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc) 201 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave) 202 | if iterable is not None: 203 | return tqdm(iterable, **progress_bar_config) 204 | elif total is not None: 205 | return tqdm(total=total, **progress_bar_config) 206 | else: 207 | raise ValueError("Either `total` or `iterable` has to be defined.") 208 | 209 | @torch.no_grad() 210 | def __call__( 211 | self, 212 | image: PipelineImageInput, 213 | processing_resolution: Optional[int] = None, 214 | match_input_resolution: bool = False, 215 | resample_method_input: str = "bilinear", 216 | resample_method_output: str = "bilinear", 217 | batch_size: int = 1, 218 | output_type: str = "np", 219 | output_latent: bool = False, 220 | return_dict: bool = True, 221 | ): 222 | """ 223 | Function invoked when calling the pipeline. 224 | 225 | Args: 226 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`), 227 | `List[torch.Tensor]`: An input image or images used as an input for the instance segmentation task. 
For 228 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible 229 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or 230 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the 231 | same width and height. 232 | processing_resolution (`int`, *optional*, defaults to `None`): 233 | Effective processing resolution. When set to `0`, matches the larger input image dimension. This 234 | produces crisper predictions, but may also lead to the overall loss of global context. The default 235 | value `None` resolves to the optimal value from the model config. 236 | match_input_resolution (`bool`, *optional*, defaults to `True`): 237 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer 238 | side of the output will equal to `processing_resolution`. 239 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`): 240 | Resampling method used to resize input images to `processing_resolution`. The accepted values are: 241 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 242 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`): 243 | Resampling method used to resize output predictions to match the input resolution. The accepted values 244 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`. 245 | batch_size (`int`, *optional*, defaults to `1`): 246 | Batch size; only matters passing a tensor of images. 247 | output_type (`str`, *optional*, defaults to `"np"`): 248 | Preferred format of the output's `prediction`. The accepted ßvalues are: `"np"` (numpy array) or `"pt"` (torch tensor). 249 | output_latent (`bool`, *optional*, defaults to `False`): 250 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions 251 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the 252 | `latents` argument. 253 | return_dict (`bool`, *optional*, defaults to `True`): 254 | Whether or not to return a [`gen2segSDSegOutput`] instead of a plain tuple. 255 | 256 | # add 257 | E2E FT models are deterministic single step models involving no ensembling, i.e. E=1. 258 | """ 259 | 260 | # 0. Resolving variables. 261 | device = self._execution_device 262 | dtype = self.dtype 263 | 264 | # Model-specific optimal default values leading to fast and reasonable results. 265 | if processing_resolution is None: 266 | processing_resolution = self.default_processing_resolution 267 | 268 | #print(image[0].size) 269 | #processing_resolution = 8 * round(max(image[0].size) / 8) 270 | 271 | # 1. Check inputs. 272 | num_images = self.check_inputs( 273 | image, 274 | processing_resolution, 275 | resample_method_input, 276 | resample_method_output, 277 | batch_size, 278 | output_type, 279 | ) 280 | 281 | # 2. Prepare empty text conditioning. 282 | # Model invocation: self.tokenizer, self.text_encoder. 283 | prompt = "" 284 | text_inputs = self.tokenizer( 285 | prompt, 286 | padding="do_not_pad", 287 | max_length=self.tokenizer.model_max_length, 288 | truncation=True, 289 | return_tensors="pt", 290 | ) 291 | text_input_ids = text_inputs.input_ids.to(device) 292 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024] 293 | 294 | # 3. Preprocess input images. 
This function loads input image or images of compatible dimensions `(H, W)`, 295 | # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where 296 | # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are 297 | # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None` 298 | # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of 299 | # operation and leads to the most reasonable results. Using the native image resolution or any other processing 300 | # resolution can lead to loss of either fine details or global context in the output predictions. 301 | image, padding, original_resolution = self.image_processor.preprocess( 302 | image, processing_resolution, resample_method_input, device, dtype 303 | ) # [N,3,PPH,PPW] 304 | # image =(image+torch.abs(image.min())) 305 | # image = image/(torch.abs(image.max())+torch.abs(image.min())) 306 | # # prediction = prediction**0.5 307 | # #prediction = torch.clip(prediction, min=-1, max=1)+1 308 | # image = (image) * 2 309 | # image = image - 1 310 | # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E` 311 | # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently. 312 | # Latents of each such predictions across all input images and all ensemble members are represented in the 313 | # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded 314 | # into latent space and replicated `E` times. Encoding into latent space happens in batches of size `batch_size`. 315 | # Model invocation: self.vae.encoder. 316 | image_latent, pred_latent = self.prepare_latents( 317 | image, batch_size 318 | ) # [N*E,4,h,w], [N*E,4,h,w] 319 | 320 | del image 321 | 322 | batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat( 323 | batch_size, 1, 1 324 | ) # [B,1024,2] 325 | 326 | # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`. 327 | # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and 328 | # outputs noise for the predicted modality's latent space. 329 | # Model invocation: self.unet. 
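        # (Because `prepare_latents` initializes `pred_latent` with zeros and the scheduler is set to a single
        # timestep below, this loop reduces to one UNet forward pass per batch; the prediction is taken from
        # `pred_original_sample`, i.e. the x0 estimate, rather than from an iterative denoising trajectory.)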
330 | pred_latents = [] 331 | 332 | for i in range(0, num_images, batch_size): 333 | batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w] 334 | batch_pred_latent = batch_image_latent[i : i + batch_size] # [B,4,h,w] 335 | effective_batch_size = batch_image_latent.shape[0] 336 | text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024] 337 | 338 | # add 339 | # Single step inference for E2E FT models 340 | self.scheduler.set_timesteps(1, device=device) 341 | for t in self.scheduler.timesteps: 342 | batch_latent = batch_image_latent # torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w] 343 | noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w] 344 | batch_pred_latent = self.scheduler.step( 345 | noise, t, batch_image_latent 346 | ).pred_original_sample # [B,4,h,w], # add 347 | # directly take pred_original_sample rather than prev_sample 348 | 349 | pred_latents.append(batch_pred_latent) 350 | 351 | pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w] 352 | 353 | del ( 354 | pred_latents, 355 | image_latent, 356 | batch_empty_text_embedding, 357 | batch_image_latent, 358 | # batch_pred_latent, 359 | text, 360 | batch_latent, 361 | noise, 362 | ) 363 | 364 | # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`, 365 | # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`. 366 | # Model invocation: self.vae.decoder. 367 | prediction = torch.cat( 368 | [ 369 | self.decode_prediction(pred_latent[i : i + batch_size]) 370 | for i in range(0, pred_latent.shape[0], batch_size) 371 | ], 372 | dim=0, 373 | ) # [N*E,1,PPH,PPW] 374 | 375 | if not output_latent: 376 | pred_latent = None 377 | 378 | # 7. Remove padding. The output shape is (PH, PW). 379 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,1,PH,PW] 380 | 381 | # 9. If `match_input_resolution` is set, the output prediction are upsampled to match the 382 | # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled. 383 | # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by 384 | # setting the `resample_method_output` parameter (e.g., to `"nearest"`). 385 | if match_input_resolution: 386 | prediction = self.image_processor.resize_antialias( 387 | prediction, original_resolution, resample_method_output, is_aa=False 388 | ) # [N,1,H,W] 389 | 390 | # 10. Prepare the final outputs. 391 | if output_type == "np": 392 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,1] 393 | 394 | # 11. 
Offload all models 395 | self.maybe_free_model_hooks() 396 | 397 | if not return_dict: 398 | return (prediction, pred_latent) 399 | 400 | return gen2segSDSegOutput( 401 | prediction=prediction, 402 | latent=pred_latent, 403 | ) 404 | 405 | def prepare_latents( 406 | self, 407 | image: torch.Tensor, 408 | batch_size: int, 409 | ) -> Tuple[torch.Tensor, torch.Tensor]: 410 | def retrieve_latents(encoder_output): 411 | if hasattr(encoder_output, "latent_dist"): 412 | return encoder_output.latent_dist.mode() 413 | elif hasattr(encoder_output, "latents"): 414 | return encoder_output.latents 415 | else: 416 | raise AttributeError("Could not access latents of provided encoder_output") 417 | 418 | image_latent = torch.cat( 419 | [ 420 | retrieve_latents(self.vae.encode(image[i : i + batch_size])) 421 | for i in range(0, image.shape[0], batch_size) 422 | ], 423 | dim=0, 424 | ) # [N,4,h,w] 425 | image_latent = image_latent * self.vae.config.scaling_factor # [N*E,4,h,w] 426 | 427 | # add 428 | # provide zeros as noised latent 429 | pred_latent = zeros_tensor( 430 | image_latent.shape, 431 | device=image_latent.device, 432 | dtype=image_latent.dtype, 433 | ) # [N*E,4,h,w] 434 | 435 | return image_latent, pred_latent 436 | 437 | def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor: 438 | if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels: 439 | raise ValueError( 440 | f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}." 441 | ) 442 | 443 | prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W] 444 | #print(prediction.max()) 445 | #print(prediction.min()) 446 | 447 | prediction = (prediction + torch.abs(prediction.min())) # shift so the minimum is 0 448 | prediction = prediction / (torch.abs(prediction.max()) + torch.abs(prediction.min())) # rescale to [0, 1] 449 | #prediction = prediction**0.5 450 | #prediction = torch.clip(prediction, min=-1, max=1)+1 451 | prediction = prediction * 255.0 # map to [0, 255] 452 | #print(prediction.max()) 453 | #print(prediction.min()) 454 | return prediction # [B,3,H,W] -------------------------------------------------------------------------------- /inference_mae.py: -------------------------------------------------------------------------------- 1 | # gen2seg official inference script for the MAE model 2 | # 3 | # Please see our project website at https://reachomk.github.io/gen2seg 4 | # 5 | # Additionally, if you use our code please cite our paper. 6 | 7 | 8 | import os 9 | import time 10 | import torch 11 | from gen2seg_mae_pipeline import gen2segMAEInstancePipeline # Custom pipeline for MAE 12 | from transformers import AutoImageProcessor, ViTMAEForPreTraining 13 | from PIL import Image 14 | import numpy as np 15 | 16 | 17 | # Example usage: Update these paths as needed. 18 | image_path = "FILL THIS OUT" # Path to the input image. 19 | output_path = "seg_mae.png" # Path to save the output image. 20 | device = "cuda:0" # Change to "cpu" if no GPU is available. 21 | 22 | print(f"Loading MAE pipeline on {device} for single image inference...") 23 | 24 | # Load the image processor (using a pretrained processor from facebook/vit-mae-huge). 25 | image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge") 26 | 27 | # Load the fine-tuned MAE weights, instantiate the pipeline, and move it to the desired device.
28 | pipe = gen2segMAEInstancePipeline(model=ViTMAEForPreTraining.from_pretrained("reachomk/gen2seg-mae-h"), image_processor=image_processor).to(device) 29 | 30 | # Load the image, storing the original size, then resize for inference. 31 | orig_image = Image.open(image_path).convert("RGB") 32 | orig_size = orig_image.size # (width, height) 33 | image = orig_image.resize((224, 224)) 34 | 35 | # Run inference. 36 | start_time = time.time() 37 | with torch.no_grad(): 38 | pipe_output = pipe([image]) 39 | end_time = time.time() 40 | print(f"Inference completed in {end_time - start_time:.2f} seconds.") 41 | prediction = pipe_output.prediction[0] 42 | 43 | # Convert the prediction to an image. 44 | seg = np.array(prediction.squeeze()).astype(np.uint8) 45 | seg_img = Image.fromarray(seg) 46 | 47 | # Resize the segmentation output back to the original image size. 48 | seg_img = seg_img.resize(orig_size, Image.LANCZOS) 49 | 50 | # Save the output image. 51 | seg_img.save(output_path) 52 | print(f"Saved output image to {output_path}") -------------------------------------------------------------------------------- /inference_sd.py: -------------------------------------------------------------------------------- 1 | # gen2seg official inference script for the Stable Diffusion (SD) model 2 | # 3 | # Please see our project website at https://reachomk.github.io/gen2seg 4 | # 5 | # Additionally, if you use our code please cite our paper. 6 | 7 | 8 | import torch 9 | from gen2seg_sd_pipeline import gen2segSDPipeline # Import your custom pipeline 10 | from PIL import Image 11 | import numpy as np 12 | import time 13 | 14 | # Load the image 15 | image_path = 'FILL THIS IN' 16 | image = Image.open(image_path).convert("RGB") 17 | orig_res = image.size 18 | output_path = "seg.png" 19 | 20 | pipe = gen2segSDPipeline.from_pretrained( 21 | "reachomk/gen2seg-sd", 22 | use_safetensors=True, # Use safetensors if available 23 | ).to("cuda") # Ensure the pipeline is moved to CUDA 24 | 25 | # Generate the segmentation map 26 | with torch.no_grad(): 27 | 28 | start_time = time.time() 29 | # Generate segmentation map 30 | seg = pipe(image).prediction.squeeze() 31 | end_time = time.time() 32 | print(f"Inference completed in {end_time - start_time:.2f} seconds.") 33 | 34 | seg = np.array(seg).astype(np.uint8) 35 | Image.fromarray(seg).resize(orig_res, Image.LANCZOS).save(output_path) 36 | print(f"Saved output image to {output_path}") -------------------------------------------------------------------------------- /prompting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import numpy as np 4 | import cv2 5 | import torch 6 | from PIL import Image 7 | 8 | ############################################# 9 | # Gaussian Heatmap Generator 10 | ############################################# 11 | def create_gaussian_heatmap(H, W, point, sigma, device='cpu'): 12 | """ 13 | Creates a Gaussian heatmap of size (H, W) centered at the given normalized point. 14 | The heatmap is scaled from [0,1] to [-1,1] and returned with an extra channel.
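    `point` is given as (normalized_y, normalized_x) with both coordinates in [0, 1], and `sigma` is
    expressed in the same normalized units (so sigma=0.01 corresponds to roughly 1% of the image extent).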
15 | """ 16 | ys = torch.linspace(0, 1, H, device=device) 17 | xs = torch.linspace(0, 1, W, device=device) 18 | y_grid, x_grid = torch.meshgrid(ys, xs, indexing='ij') 19 | # point is (normalized_y, normalized_x) 20 | dist_sq = (x_grid - point[1])**2 + (y_grid - point[0])**2 21 | heatmap = torch.exp(-dist_sq / (2 * sigma**2)) 22 | heatmap = heatmap * 2 - 1 # scale from [0,1] -> [-1,1] 23 | return heatmap.unsqueeze(0) # shape: (1, H, W) 24 | 25 | ############################################# 26 | # Bilateral Solver (Refinement) Function 27 | ############################################# 28 | def refine_with_bilateral_solver(sim_map, guidance_img, d=9, sigmaColor=75, sigmaSpace=75): 29 | """ 30 | Refines a similarity map using cv2.ximgproc.jointBilateralFilter with the guidance image. 31 | """ 32 | if not (hasattr(cv2, 'ximgproc') and hasattr(cv2.ximgproc, 'jointBilateralFilter')): 33 | print("WARNING: cv2.ximgproc.jointBilateralFilter is not available. Install opencv-contrib-python.") 34 | print("Skipping refinement.") 35 | return sim_map # Return unrefined map if filter is not available 36 | 37 | sim_map_8u = np.clip(sim_map * 255, 0, 255).astype(np.uint8) 38 | refined = cv2.ximgproc.jointBilateralFilter(guidance_img, sim_map_8u, d, sigmaColor, sigmaSpace) 39 | refined_float = refined.astype(np.float32) / 255.0 40 | return refined_float 41 | 42 | ################################## 43 | # Process a Single Prompt Point 44 | ################################## 45 | def generate_mask_for_single_prompt(feature_image_path, prompt_point_xy, output_mask_path, 46 | gaussian_sigma=0.01, manual_threshold=3, epsilon=1e-6): 47 | """ 48 | Generates and saves a binary mask for a single prompt point on a feature image. 49 | The image is processed at its original dimensions. 50 | """ 51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 52 | print(f"Using device: {device}") 53 | 54 | # 1. Load the feature image 55 | try: 56 | img = Image.open(feature_image_path).convert("RGB") 57 | except FileNotFoundError: 58 | print(f"Error: Feature image not found at {feature_image_path}") 59 | return 60 | except Exception as e: 61 | print(f"Error opening image {feature_image_path}: {e}") 62 | return 63 | 64 | W_orig, H_orig = img.size # Get original dimensions 65 | 66 | img_np = np.array(img).astype(np.float32) 67 | features_tensor = torch.from_numpy(img_np).permute(2, 0, 1).to(device) # (3, H, W) 68 | _, H, W = features_tensor.shape # H, W will be H_orig, W_orig 69 | 70 | # Create guidance image for bilateral filter (original RGB image) 71 | guidance_img_np = np.array(img).astype(np.uint8) # cv2 expects uint8 BGR or grayscale 72 | if guidance_img_np.shape[2] == 3: # RGB 73 | guidance_img_bgr = cv2.cvtColor(guidance_img_np, cv2.COLOR_RGB2BGR) 74 | else: # Grayscale 75 | guidance_img_bgr = guidance_img_np 76 | 77 | 78 | print(f"Image dimensions: H={H}, W={W}") 79 | print(f"Prompt (x,y): {prompt_point_xy}") 80 | 81 | # 2. 
Normalize prompt point and create Gaussian heatmap 82 | # prompt_point_xy is (x, y), create_gaussian_heatmap expects (norm_y, norm_x) 83 | # Ensure prompt point is within image bounds 84 | if not (0 <= prompt_point_xy[0] < W and 0 <= prompt_point_xy[1] < H): 85 | print(f"Error: Prompt point ({prompt_point_xy[0]},{prompt_point_xy[1]}) is outside image bounds ({W-1},{H-1}).") 86 | return 87 | 88 | norm_y = prompt_point_xy[1] / (H - 1) if H > 1 else 0.5 89 | norm_x = prompt_point_xy[0] / (W - 1) if W > 1 else 0.5 90 | norm_point = (norm_y, norm_x) 91 | 92 | prompt_heatmap = create_gaussian_heatmap(H, W, norm_point, sigma=gaussian_sigma, device=device) 93 | prompt_weights = (prompt_heatmap + 1) / 2 # Convert from [-1,1] to [0,1]. 94 | 95 | # 3. Compute weighted color and similarity map 96 | weighted_sum_rgb = torch.sum(features_tensor * prompt_weights, dim=(1, 2)) 97 | total_weight = torch.sum(prompt_weights) 98 | query_color_rgb = weighted_sum_rgb / (total_weight + epsilon) 99 | 100 | diff_rgb = features_tensor - query_color_rgb[:, None, None] 101 | distance_map_rgb = torch.norm(diff_rgb, dim=0) 102 | similarity_map_rgb = 1.0 / (distance_map_rgb + epsilon) 103 | 104 | min_val_rgb = similarity_map_rgb.min() 105 | max_val_rgb = similarity_map_rgb.max() 106 | # Add epsilon to prevent division by zero if max_val_rgb == min_val_rgb 107 | normalized_similarity_rgb = (similarity_map_rgb - min_val_rgb) / (max_val_rgb - min_val_rgb + epsilon) 108 | normalized_similarity_rgb = normalized_similarity_rgb.view(H, W) 109 | 110 | # 4. Refine similarity map 111 | # Use guidance_img_bgr which is the original color image in BGR uint8 format 112 | refined_sim_map_np = refine_with_bilateral_solver( 113 | normalized_similarity_rgb.cpu().numpy().astype(np.float32), guidance_img_bgr 114 | ) 115 | 116 | # 5. Threshold to produce binary mask 117 | binary_mask = ((refined_sim_map_np * 255) > manual_threshold).astype(np.uint8) * 255 118 | 119 | # 6. 
Save the binary mask 120 | try: 121 | cv2.imwrite(output_mask_path, binary_mask) 122 | print(f"Successfully saved binary mask to {output_mask_path}") 123 | except Exception as e: 124 | print(f"Error saving mask to {output_mask_path}: {e}") 125 | 126 | ############################################# 127 | # Main Execution 128 | ############################################# 129 | if __name__ == "__main__": 130 | parser = argparse.ArgumentParser(description="Generate a binary mask for a single prompt point.") 131 | parser.add_argument("--feature_image", type=str, required=True, 132 | help="Path to the input feature PNG file.") 133 | parser.add_argument("--prompt_x", type=int, required=True, 134 | help="X-coordinate of the prompt point.") 135 | parser.add_argument("--prompt_y", type=int, required=True, 136 | help="Y-coordinate of the prompt point.") 137 | parser.add_argument("--output_mask", type=str, default="mask.png", 138 | help="Path to save the output binary mask PNG file.") 139 | parser.add_argument("--sigma", type=float, default=0.01, help="Gaussian sigma for the heatmap.") 140 | parser.add_argument("--threshold", type=int, default=3, 141 | help="Manual threshold for binarization (applied to 0-255 scale map).") 142 | 143 | args = parser.parse_args() 144 | 145 | # Check if cv2.ximgproc is available early 146 | if not (hasattr(cv2, 'ximgproc') and hasattr(cv2.ximgproc, 'jointBilateralFilter')): 147 | print("Warning: opencv-contrib-python might not be installed or cv2.ximgproc is not found.") 148 | print("The 'refine_with_bilateral_solver' function will skip refinement if the filter is unavailable.") 149 | print("To install: pip install opencv-contrib-python") 150 | 151 | 152 | generate_mask_for_single_prompt( 153 | feature_image_path=args.feature_image, 154 | prompt_point_xy=(args.prompt_x, args.prompt_y), 155 | output_mask_path=args.output_mask, 156 | gaussian_sigma=args.sigma, 157 | manual_threshold=args.threshold 158 | ) -------------------------------------------------------------------------------- /sam.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Single Image AutomaskGen Script for SAM (Segment Anything Model) 4 | with masks overlaid on a black background. 5 | 6 | This script: 7 | - Loads a pre-trained SAM model from a checkpoint. 8 | - Creates an automatic mask generator using SAM. 9 | - Processes a single input image to generate segmentation masks. 10 | - Overlays each mask (with a unique random color) on a black background. 11 | - Saves the resulting image. 12 | 13 | User Parameters: 14 | - INPUT_IMAGE: Path to the input image. 15 | - OUTPUT_IMAGE: Path where the output image will be saved. 16 | - MODEL_TYPE: Type of SAM model to use (e.g., "vit_h", "vit_l", "vit_b"). 17 | - CHECKPOINT_PATH: Path to the SAM model checkpoint. 18 | - DEVICE: Computation device ("cuda" or "cpu"). 19 | """ 20 | 21 | import os 22 | import cv2 23 | import numpy as np 24 | import random 25 | 26 | # Import SAM model components. Adjust the import based on your SAM repository installation. 27 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator 28 | 29 | ######################################### 30 | # USER PARAMETERS (edit these variables) 31 | ######################################### 32 | INPUT_IMAGE = "FILL THIS IN" # Path to the input image. 33 | OUTPUT_IMAGE = "sam.png" # Path where the output image will be saved. 34 | MODEL_TYPE = "vit_h" # Options typically include "vit_h", "vit_l", or "vit_b". 
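# The SAM checkpoints (e.g. sam_vit_h_4b8939.pth for "vit_h") can be downloaded from the official
# Segment Anything repository: https://github.com/facebookresearch/segment-anything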
35 | CHECKPOINT_PATH = "PATH TO sam_vit_h_4b8939.pth" # Path to the SAM model checkpoint. 36 | DEVICE = "cuda" # Device for inference ("cuda" or "cpu"). 37 | 38 | print("Loading SAM model...") 39 | sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH) 40 | sam.to(DEVICE) 41 | 42 | # Instantiate the automatic mask generator with default parameters. 43 | mask_generator = SamAutomaticMaskGenerator(sam) 44 | 45 | # Load the input image using OpenCV. 46 | image_bgr = cv2.imread(INPUT_IMAGE) 47 | if image_bgr is None: 48 | print(f"Error: Unable to load image at {INPUT_IMAGE}") 49 | exit() 50 | # Convert the image to RGB since SAM expects RGB images. 51 | image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) 52 | 53 | # Create a black background with the same size as the input image. 54 | black_background = np.zeros_like(image_rgb) 55 | 56 | print("Generating masks...") 57 | # Generate masks automatically. 58 | masks = mask_generator.generate(image_rgb) 59 | num_masks = len(masks) 60 | print(f"Generated {num_masks} masks.") 61 | 62 | # Pre-generate a unique random color for each mask 63 | unique_colors = [] 64 | for _ in range(num_masks): 65 | # generate a random RGB tuple in [5,250] 66 | color = tuple(random.randint(5, 250) for _ in range(3)) 67 | # ensure uniqueness 68 | while color in unique_colors: 69 | color = tuple(random.randint(5, 250) for _ in range(3)) 70 | unique_colors.append(color) 71 | 72 | # Overlay each mask onto the black background with its unique color. 73 | for mask, color in zip(masks, unique_colors): 74 | segmentation = mask["segmentation"] # boolean mask 75 | color_arr = np.array(color, dtype=np.uint8) 76 | black_background[segmentation] = color_arr 77 | 78 | # Convert the resulting image from RGB to BGR for saving via OpenCV. 79 | output_bgr = cv2.cvtColor(black_background, cv2.COLOR_RGB2BGR) 80 | 81 | cv2.imwrite(OUTPUT_IMAGE, output_bgr) 82 | print(f"Output image saved to: {OUTPUT_IMAGE}") 83 | 84 | --------------------------------------------------------------------------------