├── README.md
├── assets
│   └── teaser.png
├── gen2seg_mae_pipeline.py
├── gen2seg_sd_pipeline.py
├── inference_mae.py
├── inference_sd.py
├── prompting.py
└── sam.py
/README.md:
--------------------------------------------------------------------------------
1 | # gen2seg: Generative Models Enable Generalizable Instance Segmentation
2 | [Hugging Face Demo](https://huggingface.co/spaces/reachomk/gen2seg)
3 |
4 | ### [Project Page](https://reachomk.github.io/gen2seg) | [Paper](https://arxiv.org/abs/2505.15263)
5 |
6 | [**gen2seg: Generative Models Enable Generalizable Instance Segmentation**](https://reachomk.github.io/gen2seg)
7 | [Om Khangaonkar](https://reachomk.github.io),
8 | [Hamed Pirsiavash](https://web.cs.ucdavis.edu/~hpirsiav/)
9 | UC Davis
10 |
11 |
12 | ## Pretrained Models
13 | Stable Diffusion 2 (SD): https://huggingface.co/reachomk/gen2seg-sd
14 |
15 | ImageNet-1K-pretrained Masked Autoencoder-Huge (MAE-H): https://huggingface.co/reachomk/gen2seg-mae-h
16 |
17 | If you want any of our other models, send me an email. If there is sufficient demand, I will also release them publicly.
18 |
19 | ## Inference
20 | We have released inference code for our SD and MAE models. To run either one, set the `image_path` variable (the path to your input image) in the corresponding file, then run `python inference_sd.py` or `python inference_mae.py`.
21 |
22 | You will need to have `transformers` and `diffusers` installed, along with standard machine learning packages such as `pytorch` and `numpy`. More details on our specific environment will be released with the training code.
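
If you prefer to call the pipeline from your own script, the snippet below is a minimal sketch of SD inference mirroring `inference_sd.py` (the input path is a placeholder):
```
import numpy as np
import torch
from PIL import Image
from gen2seg_sd_pipeline import gen2segSDPipeline

image = Image.open("your_image.png").convert("RGB")  # placeholder path
pipe = gen2segSDPipeline.from_pretrained("reachomk/gen2seg-sd", use_safetensors=True).to("cuda")

with torch.no_grad():
    seg = pipe(image).prediction.squeeze()  # (H, W, 3) instance map with values in [0, 255]

Image.fromarray(seg.astype(np.uint8)).resize(image.size, Image.LANCZOS).save("seg.png")
```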
23 |
24 | We have also released code for point prompting (`prompting.py`). Please run `pip install opencv-contrib-python` before running it.
25 |
26 | Here is how you run it:
27 | ```
28 | python prompting.py \
29 | --feature_image /path/to/your/feature_image.png \
30 | --prompt_x 150 \
31 | --prompt_y 200
32 | ```
33 | The feature image is the one generated by our model, NOT the original image.
34 |
35 |
36 | The following optional arguments are also available:
37 | ```
38 | --output_mask /path/to/save/output_mask.png \
39 | --sigma 0.02 \
40 | --threshold 10
41 | ```
42 |
43 | `--threshold` controls the mask binarization threshold (out of 255; default 3), and `--sigma` controls how much spatial averaging is applied when computing the query vector (default 0.01). See our paper for more details.
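
For example, a full invocation with all arguments might look like this (paths are placeholders):
```
python prompting.py \
--feature_image /path/to/your/feature_image.png \
--prompt_x 150 \
--prompt_y 200 \
--output_mask /path/to/save/output_mask.png \
--sigma 0.01 \
--threshold 3
```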
44 |
45 | We have also provided our inference script for SAM (`sam.py`) to enable qualitative comparison. Please download the SAM checkpoint and set `CHECKPOINT_PATH` in the script, and edit the `INPUT_IMAGE` variable to point to your input image.
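
For reference, a minimal setup of the user parameters at the top of `sam.py` might look like this (the paths are placeholders); after editing them, run `python sam.py`:
```
INPUT_IMAGE = "your_image.png"              # input image
MODEL_TYPE = "vit_h"                        # or "vit_l", "vit_b"
CHECKPOINT_PATH = "sam_vit_h_4b8939.pth"    # downloaded SAM checkpoint
DEVICE = "cuda"                             # or "cpu"
```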
46 |
47 | ## Training
48 | I will release all training code by June 7 (likely earlier). Make sure to star and watch our repository so you're notified when we update it! Send me an email if you need it before that.
49 |
50 | ## Citation
51 | Please cite our paper if you found it helpful.
52 | ```
53 | @article{khangaonkar2025gen2seg,
54 | title={gen2seg: Generative Models Enable Generalizable Instance Segmentation},
55 | author={Om Khangaonkar and Hamed Pirsiavash},
56 | year={2025},
57 | journal={arXiv preprint arXiv:2505.15263}
58 | }
59 | ```
60 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/gen2seg/72e01e93feba569f580e1299edf21a7e08303ea7/assets/teaser.png
--------------------------------------------------------------------------------
/gen2seg_mae_pipeline.py:
--------------------------------------------------------------------------------
1 | # gen2seg official inference pipeline code for the MAE model
2 | #
3 | # Please see our project website at https://reachomk.github.io/gen2seg
4 | #
5 | # Additionally, if you use our code please cite our paper.
6 |
7 | from dataclasses import dataclass
8 | from typing import Union, List, Optional
9 |
10 | import torch
11 | import numpy as np
12 | from PIL import Image
13 | from einops import rearrange
14 |
15 | from diffusers import DiffusionPipeline
16 | from diffusers.utils import BaseOutput, logging
17 | from transformers import AutoImageProcessor
18 |
19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20 |
21 |
22 | @dataclass
23 | class gen2segMAEInstanceOutput(BaseOutput):
24 | """
25 | Output class for the ViTMAE Instance Segmentation Pipeline.
26 |
27 | Args:
28 | prediction (`np.ndarray` or `torch.Tensor`):
29 | Predicted instance segmentation maps. The output has shape
30 | `(batch_size, 3, height, width)` with pixel values scaled to [0, 255].
31 | """
32 | prediction: Union[np.ndarray, torch.Tensor]
33 |
34 |
35 | class gen2segMAEInstancePipeline(DiffusionPipeline):
36 | r"""
37 | Pipeline for Instance Segmentation using a fine-tuned ViTMAEForPreTraining model.
38 |
39 | This pipeline takes one or more input images and returns an instance segmentation
40 | prediction for each image. The model is assumed to have been fine-tuned using an instance
41 | segmentation loss, and the reconstruction is performed by rearranging the model’s
42 | patch logits into an image.
43 |
44 | Args:
45 | model (`ViTMAEForPreTraining`):
46 | The fine-tuned ViTMAE model.
47 | image_processor (`AutoImageProcessor`):
48 | The image processor responsible for preprocessing input images.
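
Example (a minimal sketch mirroring `inference_mae.py`; the image path is a placeholder):

>>> from PIL import Image
>>> from transformers import AutoImageProcessor, ViTMAEForPreTraining
>>> processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")
>>> model = ViTMAEForPreTraining.from_pretrained("reachomk/gen2seg-mae-h")
>>> pipe = gen2segMAEInstancePipeline(model=model, image_processor=processor).to("cuda")
>>> image = Image.open("input.png").convert("RGB").resize((224, 224))
>>> prediction = pipe([image]).prediction[0]  # (224, 224, 3), values in [0, 255]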
49 | """
50 | def __init__(self, model, image_processor):
51 | super().__init__()
52 | self.register_modules(model=model, image_processor=image_processor)
53 | self.model = model
54 | self.image_processor = image_processor
55 |
56 | def check_inputs(
57 | self,
58 | image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]]
59 | ) -> List:
60 | if not isinstance(image, list):
61 | image = [image]
62 | # Additional input validations can be added here if desired.
63 | return image
64 |
65 | @torch.no_grad()
66 | def __call__(
67 | self,
68 | image: Union[Image.Image, np.ndarray, torch.Tensor, List[Union[Image.Image, np.ndarray, torch.Tensor]]],
69 | output_type: str = "np",
70 | **kwargs
71 | ) -> gen2segMAEInstanceOutput:
72 | r"""
73 | The call method of the pipeline.
74 |
75 | Args:
76 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, or a list of these):
77 | The input image(s) for instance segmentation. For arrays/tensors, expected values are in [0, 1].
78 | output_type (`str`, optional, defaults to `"np"`):
79 | The format of the output prediction. Choose `"np"` for a NumPy array or `"pt"` for a PyTorch tensor.
80 | **kwargs:
81 | Additional keyword arguments passed to the image processor.
82 |
83 | Returns:
84 | [`gen2segMAEInstanceOutput`]:
85 | An output object containing the predicted instance segmentation maps.
86 | """
87 | # 1. Check and prepare input images.
88 | images = self.check_inputs(image)
89 | inputs = self.image_processor(images=images, return_tensors="pt", **kwargs)
90 | pixel_values = inputs["pixel_values"].to(self.device)
91 |
92 | # 2. Forward pass through the model.
93 | outputs = self.model(pixel_values=pixel_values)
94 | logits = outputs.logits # Expected shape: (B, num_patches, patch_dim)
95 |
96 | # 3. Retrieve patch size and image size from the model configuration.
97 | patch_size = self.model.config.patch_size # e.g., 16
98 | image_size = self.model.config.image_size # e.g., 224
99 | grid_size = image_size // patch_size
100 |
101 | # 4. Rearrange logits into the reconstructed image.
102 | # The logits are reshaped from (B, num_patches, patch_dim) to (B, 3, H, W).
103 | reconstructed = rearrange(
104 | logits,
105 | "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
106 | h=grid_size,
107 | p1=patch_size,
108 | p2=patch_size,
109 | c=3,
110 | )
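# For example, with image_size=224 and patch_size=16: grid_size=14, so logits of shape
# (B, 196, 768) are reassembled into a (B, 3, 224, 224) prediction.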
111 |
112 | # 5. Post-process the reconstructed output.
113 | # For each sample, shift and scale the prediction to [0, 255].
114 | predictions = []
115 | for i in range(reconstructed.shape[0]):
116 | sample = reconstructed[i]
117 | min_val = torch.abs(sample.min())
118 | max_val = torch.abs(sample.max())
119 | sample = (sample + min_val) / (max_val + min_val + 1e-5)
120 | # sometimes the image is very dark so we perform gamma correction to "brighten" it
121 | # in practice we can set this value to whatever we want or disable it entirely.
122 | sample = sample**0.7
123 | sample = sample * 255.0
124 | predictions.append(sample)
125 | prediction_tensor = torch.stack(predictions, dim=0).permute(0, 2, 3, 1)
126 |
127 | # 6. Format the output.
128 | if output_type == "np":
129 | prediction = prediction_tensor.cpu().numpy()
130 | else:
131 | prediction = prediction_tensor
132 | return gen2segMAEInstanceOutput(prediction=prediction)
133 |
--------------------------------------------------------------------------------
/gen2seg_sd_pipeline.py:
--------------------------------------------------------------------------------
1 | # gen2seg official inference pipeline code for Stable Diffusion model
2 | #
3 | # This code was adapted from Marigold and Diffusion E2E Finetuning.
4 | #
5 | # Please see our project website at https://reachomk.github.io/gen2seg
6 | #
7 | # Additionally, if you use our code please cite our paper, along with the two works above.
8 |
9 | from dataclasses import dataclass
10 | from typing import List, Optional, Tuple, Union
11 |
12 | import numpy as np
13 | import torch
14 | from PIL import Image
15 | from tqdm.auto import tqdm
16 | from transformers import CLIPTextModel, CLIPTokenizer
17 |
18 | from diffusers.image_processor import PipelineImageInput
19 | from diffusers.models import (
20 | AutoencoderKL,
21 | UNet2DConditionModel,
22 | )
23 | from diffusers.schedulers import (
24 | DDIMScheduler,
25 | )
26 | from diffusers.utils import (
27 | BaseOutput,
28 | logging,
29 | )
30 | from diffusers import DiffusionPipeline
31 | from diffusers.pipelines.marigold.marigold_image_processing import MarigoldImageProcessor
32 |
33 | # add
34 | def zeros_tensor(
35 | shape: Union[Tuple, List],
36 | device: Optional["torch.device"] = None,
37 | dtype: Optional["torch.dtype"] = None,
38 | layout: Optional["torch.layout"] = None,
39 | ):
40 | """
41 | A helper function to create tensors of zeros on the desired `device`.
42 | Mirrors randn_tensor from diffusers.utils.torch_utils.
43 | """
44 | layout = layout or torch.strided
45 | device = device or torch.device("cpu")
46 | latents = torch.zeros(list(shape), dtype=dtype, layout=layout).to(device)
47 | return latents
48 |
49 |
50 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
51 |
52 | @dataclass
53 | class gen2segSDSegOutput(BaseOutput):
54 | """
55 | Output class for gen2seg Instance Segmentation prediction pipeline.
56 |
57 | Args:
58 | prediction (`np.ndarray`, `torch.Tensor`):
59 | Predicted instance segmentation with values in the range [0, 255]. The shape is always $numimages \times 3 \times height
60 | \times width$, regardless of whether the images were passed as a 4D array or a list.
61 | latent (`None`, `torch.Tensor`):
62 | Latent features corresponding to the predictions, compatible with the `latents` argument of the pipeline.
63 | The shape is $numimages * numensemble \times 4 \times latentheight \times latentwidth$.
64 | """
65 |
66 | prediction: Union[np.ndarray, torch.Tensor]
67 | latent: Union[None, torch.Tensor]
68 |
69 |
70 | class gen2segSDPipeline(DiffusionPipeline):
71 | """
72 | # add
73 | Pipeline for Instance Segmentation prediction using our Stable Diffusion model.
74 | Implementation is built upon Marigold: https://marigoldmonodepth.github.io and E2E FT: https://gonzalomartingarcia.github.io/diffusion-e2e-ft/
75 |
76 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
77 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
78 |
79 | Args:
80 | unet (`UNet2DConditionModel`):
81 | Conditional U-Net to denoise the segmentation latent, synthesized from image latent.
82 | vae (`AutoencoderKL`):
83 | Variational Auto-Encoder (VAE) Model to encode and decode images and predictions to and from latent
84 | representations.
85 | scheduler (`DDIMScheduler`):
86 | A scheduler to be used in combination with `unet` to denoise the encoded image latent.
87 | text_encoder (`CLIPTextModel`):
88 | Text-encoder, for empty text embedding.
89 | tokenizer (`CLIPTokenizer`):
90 | CLIP tokenizer.
91 | default_processing_resolution (`int`, *optional*):
92 | The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in
93 | the model config. When the pipeline is called without explicitly setting `processing_resolution`, the
94 | default value is used. This is required to ensure reasonable results with various model flavors trained
95 | with varying optimal processing resolution values.
96 | """
97 |
98 | model_cpu_offload_seq = "text_encoder->unet->vae"
99 |
100 | def __init__(
101 | self,
102 | unet: UNet2DConditionModel,
103 | vae: AutoencoderKL,
104 | scheduler: Union[DDIMScheduler],
105 | text_encoder: CLIPTextModel,
106 | tokenizer: CLIPTokenizer,
107 | default_processing_resolution: Optional[int] = 768, # add
108 | ):
109 | super().__init__()
110 |
111 | self.register_modules(
112 | unet=unet,
113 | vae=vae,
114 | scheduler=scheduler,
115 | text_encoder=text_encoder,
116 | tokenizer=tokenizer,
117 | )
118 | self.register_to_config(
119 | default_processing_resolution=default_processing_resolution,
120 | )
121 |
122 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
123 | self.default_processing_resolution = default_processing_resolution
124 | self.empty_text_embedding = None
125 |
126 | self.image_processor = MarigoldImageProcessor(vae_scale_factor=self.vae_scale_factor)
127 |
128 | def check_inputs(
129 | self,
130 | image: PipelineImageInput,
131 | processing_resolution: int,
132 | resample_method_input: str,
133 | resample_method_output: str,
134 | batch_size: int,
135 | output_type: str,
136 | ) -> int:
137 | if processing_resolution is None:
138 | raise ValueError(
139 | "`processing_resolution` is not specified and could not be resolved from the model config."
140 | )
141 | if processing_resolution < 0:
142 | raise ValueError(
143 | "`processing_resolution` must be non-negative: 0 for native resolution, or any positive value for "
144 | "downsampled processing."
145 | )
146 | if processing_resolution % self.vae_scale_factor != 0:
147 | raise ValueError(f"`processing_resolution` must be a multiple of {self.vae_scale_factor}.")
148 | if resample_method_input not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
149 | raise ValueError(
150 | "`resample_method_input` takes string values compatible with PIL library: "
151 | "nearest, nearest-exact, bilinear, bicubic, area."
152 | )
153 | if resample_method_output not in ("nearest", "nearest-exact", "bilinear", "bicubic", "area"):
154 | raise ValueError(
155 | "`resample_method_output` takes string values compatible with PIL library: "
156 | "nearest, nearest-exact, bilinear, bicubic, area."
157 | )
158 | if batch_size < 1:
159 | raise ValueError("`batch_size` must be positive.")
160 | if output_type not in ["pt", "np"]:
161 | raise ValueError("`output_type` must be one of `pt` or `np`.")
162 |
163 | # image checks
164 | num_images = 0
165 | W, H = None, None
166 | if not isinstance(image, list):
167 | image = [image]
168 | for i, img in enumerate(image):
169 | if isinstance(img, np.ndarray) or torch.is_tensor(img):
170 | if img.ndim not in (2, 3, 4):
171 | raise ValueError(f"`image[{i}]` has unsupported dimensions or shape: {img.shape}.")
172 | H_i, W_i = img.shape[-2:]
173 | N_i = 1
174 | if img.ndim == 4:
175 | N_i = img.shape[0]
176 | elif isinstance(img, Image.Image):
177 | W_i, H_i = img.size
178 | N_i = 1
179 | else:
180 | raise ValueError(f"Unsupported `image[{i}]` type: {type(img)}.")
181 | if W is None:
182 | W, H = W_i, H_i
183 | elif (W, H) != (W_i, H_i):
184 | raise ValueError(
185 | f"Input `image[{i}]` has incompatible dimensions {(W_i, H_i)} with the previous images {(W, H)}"
186 | )
187 | num_images += N_i
188 |
189 | return num_images
190 |
191 | def progress_bar(self, iterable=None, total=None, desc=None, leave=True):
192 | if not hasattr(self, "_progress_bar_config"):
193 | self._progress_bar_config = {}
194 | elif not isinstance(self._progress_bar_config, dict):
195 | raise ValueError(
196 | f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
197 | )
198 |
199 | progress_bar_config = dict(**self._progress_bar_config)
200 | progress_bar_config["desc"] = progress_bar_config.get("desc", desc)
201 | progress_bar_config["leave"] = progress_bar_config.get("leave", leave)
202 | if iterable is not None:
203 | return tqdm(iterable, **progress_bar_config)
204 | elif total is not None:
205 | return tqdm(total=total, **progress_bar_config)
206 | else:
207 | raise ValueError("Either `total` or `iterable` has to be defined.")
208 |
209 | @torch.no_grad()
210 | def __call__(
211 | self,
212 | image: PipelineImageInput,
213 | processing_resolution: Optional[int] = None,
214 | match_input_resolution: bool = False,
215 | resample_method_input: str = "bilinear",
216 | resample_method_output: str = "bilinear",
217 | batch_size: int = 1,
218 | output_type: str = "np",
219 | output_latent: bool = False,
220 | return_dict: bool = True,
221 | ):
222 | """
223 | Function invoked when calling the pipeline.
224 |
225 | Args:
226 | image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`),
227 | `List[torch.Tensor]`: An input image or images used as an input for the instance segmentation task. For
228 | arrays and tensors, the expected value range is between `[0, 1]`. Passing a batch of images is possible
229 | by providing a four-dimensional array or a tensor. Additionally, a list of images of two- or
230 | three-dimensional arrays or tensors can be passed. In the latter case, all list elements must have the
231 | same width and height.
232 | processing_resolution (`int`, *optional*, defaults to `None`):
233 | Effective processing resolution. When set to `0`, matches the larger input image dimension. This
234 | produces crisper predictions, but may also lead to the overall loss of global context. The default
235 | value `None` resolves to the optimal value from the model config.
236 | match_input_resolution (`bool`, *optional*, defaults to `False`):
237 | When enabled, the output prediction is resized to match the input dimensions. When disabled, the longer
238 | side of the output will equal `processing_resolution`.
239 | resample_method_input (`str`, *optional*, defaults to `"bilinear"`):
240 | Resampling method used to resize input images to `processing_resolution`. The accepted values are:
241 | `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
242 | resample_method_output (`str`, *optional*, defaults to `"bilinear"`):
243 | Resampling method used to resize output predictions to match the input resolution. The accepted values
244 | are `"nearest"`, `"nearest-exact"`, `"bilinear"`, `"bicubic"`, or `"area"`.
245 | batch_size (`int`, *optional*, defaults to `1`):
246 | Batch size; only matters when passing a tensor of images.
247 | output_type (`str`, *optional*, defaults to `"np"`):
248 | Preferred format of the output's `prediction`. The accepted values are: `"np"` (numpy array) or `"pt"` (torch tensor).
249 | output_latent (`bool`, *optional*, defaults to `False`):
250 | When enabled, the output's `latent` field contains the latent codes corresponding to the predictions
251 | within the ensemble. These codes can be saved, modified, and used for subsequent calls with the
252 | `latents` argument.
253 | return_dict (`bool`, *optional*, defaults to `True`):
254 | Whether or not to return a [`gen2segSDSegOutput`] instead of a plain tuple.
255 |
256 | # add
257 | E2E FT models are deterministic single step models involving no ensembling, i.e. E=1.
258 | """
259 |
260 | # 0. Resolving variables.
261 | device = self._execution_device
262 | dtype = self.dtype
263 |
264 | # Model-specific optimal default values leading to fast and reasonable results.
265 | if processing_resolution is None:
266 | processing_resolution = self.default_processing_resolution
267 |
268 | #print(image[0].size)
269 | #processing_resolution = 8 * round(max(image[0].size) / 8)
270 |
271 | # 1. Check inputs.
272 | num_images = self.check_inputs(
273 | image,
274 | processing_resolution,
275 | resample_method_input,
276 | resample_method_output,
277 | batch_size,
278 | output_type,
279 | )
280 |
281 | # 2. Prepare empty text conditioning.
282 | # Model invocation: self.tokenizer, self.text_encoder.
283 | prompt = ""
284 | text_inputs = self.tokenizer(
285 | prompt,
286 | padding="do_not_pad",
287 | max_length=self.tokenizer.model_max_length,
288 | truncation=True,
289 | return_tensors="pt",
290 | )
291 | text_input_ids = text_inputs.input_ids.to(device)
292 | self.empty_text_embedding = self.text_encoder(text_input_ids)[0] # [1,2,1024]
293 |
294 | # 3. Preprocess input images. This function loads input image or images of compatible dimensions `(H, W)`,
295 | # optionally downsamples them to the `processing_resolution` `(PH, PW)`, where
296 | # `max(PH, PW) == processing_resolution`, and pads the dimensions to `(PPH, PPW)` such that these values are
297 | # divisible by the latent space downscaling factor (typically 8 in Stable Diffusion). The default value `None`
298 | # of `processing_resolution` resolves to the optimal value from the model config. It is a recommended mode of
299 | # operation and leads to the most reasonable results. Using the native image resolution or any other processing
300 | # resolution can lead to loss of either fine details or global context in the output predictions.
301 | image, padding, original_resolution = self.image_processor.preprocess(
302 | image, processing_resolution, resample_method_input, device, dtype
303 | ) # [N,3,PPH,PPW]
304 | # image =(image+torch.abs(image.min()))
305 | # image = image/(torch.abs(image.max())+torch.abs(image.min()))
306 | # # prediction = prediction**0.5
307 | # #prediction = torch.clip(prediction, min=-1, max=1)+1
308 | # image = (image) * 2
309 | # image = image - 1
310 | # 4. Encode input image into latent space. At this step, each of the `N` input images is represented with `E`
311 | # ensemble members. Each ensemble member is an independent diffused prediction, just initialized independently.
312 | # Latents of each such predictions across all input images and all ensemble members are represented in the
313 | # `pred_latent` variable. The variable `image_latent` is of the same shape: it contains each input image encoded
314 | # into latent space and replicated `E` times. Encoding into latent space happens in batches of size `batch_size`.
315 | # Model invocation: self.vae.encoder.
316 | image_latent, pred_latent = self.prepare_latents(
317 | image, batch_size
318 | ) # [N*E,4,h,w], [N*E,4,h,w]
319 |
320 | del image
321 |
322 | batch_empty_text_embedding = self.empty_text_embedding.to(device=device, dtype=dtype).repeat(
323 | batch_size, 1, 1
324 | ) # [B,2,1024]
325 |
326 | # 5. Process the denoising loop. All `N * E` latents are processed sequentially in batches of size `batch_size`.
327 | # The unet model takes concatenated latent spaces of the input image and the predicted modality as an input, and
328 | # outputs noise for the predicted modality's latent space.
329 | # Model invocation: self.unet.
330 | pred_latents = []
331 |
332 | for i in range(0, num_images, batch_size):
333 | batch_image_latent = image_latent[i : i + batch_size] # [B,4,h,w]
334 | batch_pred_latent = pred_latent[i : i + batch_size] # [B,4,h,w]
335 | effective_batch_size = batch_image_latent.shape[0]
336 | text = batch_empty_text_embedding[:effective_batch_size] # [B,2,1024]
337 |
338 | # add
339 | # Single step inference for E2E FT models
340 | self.scheduler.set_timesteps(1, device=device)
341 | for t in self.scheduler.timesteps:
342 | batch_latent = batch_image_latent # torch.cat([batch_image_latent, batch_pred_latent], dim=1) # [B,8,h,w]
343 | noise = self.unet(batch_latent, t, encoder_hidden_states=text, return_dict=False)[0] # [B,4,h,w]
344 | batch_pred_latent = self.scheduler.step(
345 | noise, t, batch_image_latent
346 | ).pred_original_sample # [B,4,h,w], # add
347 | # directly take pred_original_sample rather than prev_sample
348 |
349 | pred_latents.append(batch_pred_latent)
350 |
351 | pred_latent = torch.cat(pred_latents, dim=0) # [N*E,4,h,w]
352 |
353 | del (
354 | pred_latents,
355 | image_latent,
356 | batch_empty_text_embedding,
357 | batch_image_latent,
358 | # batch_pred_latent,
359 | text,
360 | batch_latent,
361 | noise,
362 | )
363 |
364 | # 6. Decode predictions from latent into pixel space. The resulting `N * E` predictions have shape `(PPH, PPW)`,
365 | # which requires slight postprocessing. Decoding into pixel space happens in batches of size `batch_size`.
366 | # Model invocation: self.vae.decoder.
367 | prediction = torch.cat(
368 | [
369 | self.decode_prediction(pred_latent[i : i + batch_size])
370 | for i in range(0, pred_latent.shape[0], batch_size)
371 | ],
372 | dim=0,
373 | ) # [N*E,3,PPH,PPW]
374 |
375 | if not output_latent:
376 | pred_latent = None
377 |
378 | # 7. Remove padding. The output shape is (PH, PW).
379 | prediction = self.image_processor.unpad_image(prediction, padding) # [N*E,3,PH,PW]
380 |
381 | # 9. If `match_input_resolution` is set, the output predictions are upsampled to match the
382 | # input resolution `(H, W)`. This step may introduce upsampling artifacts, and therefore can be disabled.
383 | # Depending on the downstream use-case, upsampling can be also chosen based on the tolerated artifacts by
384 | # setting the `resample_method_output` parameter (e.g., to `"nearest"`).
385 | if match_input_resolution:
386 | prediction = self.image_processor.resize_antialias(
387 | prediction, original_resolution, resample_method_output, is_aa=False
388 | ) # [N,3,H,W]
389 |
390 | # 10. Prepare the final outputs.
391 | if output_type == "np":
392 | prediction = self.image_processor.pt_to_numpy(prediction) # [N,H,W,3]
393 |
394 | # 11. Offload all models
395 | self.maybe_free_model_hooks()
396 |
397 | if not return_dict:
398 | return (prediction, pred_latent)
399 |
400 | return gen2segSDSegOutput(
401 | prediction=prediction,
402 | latent=pred_latent,
403 | )
404 |
405 | def prepare_latents(
406 | self,
407 | image: torch.Tensor,
408 | batch_size: int,
409 | ) -> Tuple[torch.Tensor, torch.Tensor]:
410 | def retrieve_latents(encoder_output):
411 | if hasattr(encoder_output, "latent_dist"):
412 | return encoder_output.latent_dist.mode()
413 | elif hasattr(encoder_output, "latents"):
414 | return encoder_output.latents
415 | else:
416 | raise AttributeError("Could not access latents of provided encoder_output")
417 |
418 | image_latent = torch.cat(
419 | [
420 | retrieve_latents(self.vae.encode(image[i : i + batch_size]))
421 | for i in range(0, image.shape[0], batch_size)
422 | ],
423 | dim=0,
424 | ) # [N,4,h,w]
425 | image_latent = image_latent * self.vae.config.scaling_factor # [N*E,4,h,w]
426 |
427 | # add
428 | # provide zeros as noised latent
429 | pred_latent = zeros_tensor(
430 | image_latent.shape,
431 | device=image_latent.device,
432 | dtype=image_latent.dtype,
433 | ) # [N*E,4,h,w]
434 |
435 | return image_latent, pred_latent
436 |
437 | def decode_prediction(self, pred_latent: torch.Tensor) -> torch.Tensor:
438 | if pred_latent.dim() != 4 or pred_latent.shape[1] != self.vae.config.latent_channels:
439 | raise ValueError(
440 | f"Expecting 4D tensor of shape [B,{self.vae.config.latent_channels},H,W]; got {pred_latent.shape}."
441 | )
442 |
443 | prediction = self.vae.decode(pred_latent / self.vae.config.scaling_factor, return_dict=False)[0] # [B,3,H,W]
444 | #print(prediction.max())
445 | #print(prediction.min())
446 |
447 | prediction = prediction + torch.abs(prediction.min())
448 | prediction = prediction / (torch.abs(prediction.max()) + torch.abs(prediction.min()))
449 | #prediction = prediction**0.5
450 | #prediction = torch.clip(prediction, min=-1, max=1)+1
451 | prediction = prediction * 255.0
452 | #print(prediction.max())
453 | #print(prediction.min())
454 | return prediction  # [B,3,H,W]
--------------------------------------------------------------------------------
/inference_mae.py:
--------------------------------------------------------------------------------
1 | # gen2seg official inference pipeline code for the MAE model
2 | #
3 | # Please see our project website at https://reachomk.github.io/gen2seg
4 | #
5 | # Additionally, if you use our code please cite our paper.
6 |
7 |
8 | import os
9 | import time
10 | import torch
11 | from gen2seg_mae_pipeline import gen2segMAEInstancePipeline # Custom pipeline for MAE
12 | from transformers import AutoImageProcessor, ViTMAEForPreTraining
13 | from PIL import Image
14 | import numpy as np
15 |
16 |
17 | # Example usage: Update these paths as needed.
18 | image_path = "FILL THIS OUT" # Path to the input image.
19 | output_path = "seg_mae.png" # Path to save the output image.
20 | device = "cuda:0" # Change to "cpu" if no GPU is available.
21 |
22 | print(f"Loading MAE pipeline on {device} for single image inference...")
23 |
24 | # Load the image processor (using a pretrained processor from facebook/vit-mae-huge).
25 | image_processor = AutoImageProcessor.from_pretrained("facebook/vit-mae-huge")
26 |
27 | # Load the fine-tuned MAE model, instantiate the pipeline, and move it to the desired device.
28 | pipe = gen2segMAEInstancePipeline(model=ViTMAEForPreTraining.from_pretrained("reachomk/gen2seg-mae-h"), image_processor=image_processor).to(device)
29 |
30 | # Load the image, storing the original size, then resize for inference.
31 | orig_image = Image.open(image_path).convert("RGB")
32 | orig_size = orig_image.size # (width, height)
33 | image = orig_image.resize((224, 224))
34 |
35 | # Run inference.
36 | start_time = time.time()
37 | with torch.no_grad():
38 | pipe_output = pipe([image])
39 | end_time = time.time()
40 | print(f"Inference completed in {end_time - start_time:.2f} seconds.")
41 | prediction = pipe_output.prediction[0]
42 |
43 | # Convert the prediction to an image.
44 | seg = np.array(prediction.squeeze()).astype(np.uint8)
45 | seg_img = Image.fromarray(seg)
46 |
47 | # Resize the segmentation output back to the original image size.
48 | seg_img = seg_img.resize(orig_size, Image.LANCZOS)
49 |
50 | # Save the output image.
51 | seg_img.save(output_path)
52 | print(f"Saved output image to {output_path}")
--------------------------------------------------------------------------------
/inference_sd.py:
--------------------------------------------------------------------------------
1 | # gen2seg official inference pipeline code for Stable Diffusion model
2 | #
3 | # Please see our project website at https://reachomk.github.io/gen2seg
4 | #
5 | # Additionally, if you use our code please cite our paper.
6 |
7 |
8 | import torch
9 | from gen2seg_sd_pipeline import gen2segSDPipeline # Import your custom pipeline
10 | from PIL import Image
11 | import numpy as np
12 | import time
13 |
14 | # Load the image
15 | image_path = 'FILL THIS IN'
16 | image = Image.open(image_path).convert("RGB")
17 | orig_res = image.size
18 | output_path = "seg.png"
19 |
20 | pipe = gen2segSDPipeline.from_pretrained(
21 | "reachomk/gen2seg-sd",
22 | use_safetensors=True, # Use safetensors if available
23 | ).to("cuda") # Ensure the pipeline is moved to CUDA
24 |
25 | # Run the pipeline to generate the segmentation map
26 | with torch.no_grad():
27 |
28 | start_time = time.time()
29 | # Generate segmentation map
30 | seg = pipe(image).prediction.squeeze()
31 | end_time = time.time()
32 | print(f"Inference completed in {end_time - start_time:.2f} seconds.")
33 |
34 | seg = np.array(seg).astype(np.uint8)
35 | Image.fromarray(seg).resize(orig_res, Image.LANCZOS).save(output_path)
36 | print(f"Saved output image to {output_path}")
--------------------------------------------------------------------------------
/prompting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import cv2
5 | import torch
6 | from PIL import Image
7 |
8 | #############################################
9 | # Gaussian Heatmap Generator
10 | #############################################
11 | def create_gaussian_heatmap(H, W, point, sigma, device='cpu'):
12 | """
13 | Creates a Gaussian heatmap of size (H, W) centered at the given normalized point.
14 | The heatmap is scaled from [0,1] to [-1,1] and returned with an extra channel.
15 | """
16 | ys = torch.linspace(0, 1, H, device=device)
17 | xs = torch.linspace(0, 1, W, device=device)
18 | y_grid, x_grid = torch.meshgrid(ys, xs, indexing='ij')
19 | # point is (normalized_y, normalized_x)
20 | dist_sq = (x_grid - point[1])**2 + (y_grid - point[0])**2
21 | heatmap = torch.exp(-dist_sq / (2 * sigma**2))
22 | heatmap = heatmap * 2 - 1 # scale from [0,1] -> [-1,1]
23 | return heatmap.unsqueeze(0) # shape: (1, H, W)
24 |
25 | #############################################
26 | # Bilateral Solver (Refinement) Function
27 | #############################################
28 | def refine_with_bilateral_solver(sim_map, guidance_img, d=9, sigmaColor=75, sigmaSpace=75):
29 | """
30 | Refines a similarity map using cv2.ximgproc.jointBilateralFilter with the guidance image.
31 | """
32 | if not (hasattr(cv2, 'ximgproc') and hasattr(cv2.ximgproc, 'jointBilateralFilter')):
33 | print("WARNING: cv2.ximgproc.jointBilateralFilter is not available. Install opencv-contrib-python.")
34 | print("Skipping refinement.")
35 | return sim_map # Return unrefined map if filter is not available
36 |
37 | sim_map_8u = np.clip(sim_map * 255, 0, 255).astype(np.uint8)
38 | refined = cv2.ximgproc.jointBilateralFilter(guidance_img, sim_map_8u, d, sigmaColor, sigmaSpace)
39 | refined_float = refined.astype(np.float32) / 255.0
40 | return refined_float
41 |
42 | ##################################
43 | # Process a Single Prompt Point
44 | ##################################
45 | def generate_mask_for_single_prompt(feature_image_path, prompt_point_xy, output_mask_path,
46 | gaussian_sigma=0.01, manual_threshold=3, epsilon=1e-6):
47 | """
48 | Generates and saves a binary mask for a single prompt point on a feature image.
49 | The image is processed at its original dimensions.
50 | """
51 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
52 | print(f"Using device: {device}")
53 |
54 | # 1. Load the feature image
55 | try:
56 | img = Image.open(feature_image_path).convert("RGB")
57 | except FileNotFoundError:
58 | print(f"Error: Feature image not found at {feature_image_path}")
59 | return
60 | except Exception as e:
61 | print(f"Error opening image {feature_image_path}: {e}")
62 | return
63 |
64 | W_orig, H_orig = img.size # Get original dimensions
65 |
66 | img_np = np.array(img).astype(np.float32)
67 | features_tensor = torch.from_numpy(img_np).permute(2, 0, 1).to(device) # (3, H, W)
68 | _, H, W = features_tensor.shape # H, W will be H_orig, W_orig
69 |
70 | # Create guidance image for bilateral filter (original RGB image)
71 | guidance_img_np = np.array(img).astype(np.uint8) # cv2 expects uint8 BGR or grayscale
72 | if guidance_img_np.shape[2] == 3: # RGB
73 | guidance_img_bgr = cv2.cvtColor(guidance_img_np, cv2.COLOR_RGB2BGR)
74 | else: # Grayscale
75 | guidance_img_bgr = guidance_img_np
76 |
77 |
78 | print(f"Image dimensions: H={H}, W={W}")
79 | print(f"Prompt (x,y): {prompt_point_xy}")
80 |
81 | # 2. Normalize prompt point and create Gaussian heatmap
82 | # prompt_point_xy is (x, y), create_gaussian_heatmap expects (norm_y, norm_x)
83 | # Ensure prompt point is within image bounds
84 | if not (0 <= prompt_point_xy[0] < W and 0 <= prompt_point_xy[1] < H):
85 | print(f"Error: Prompt point ({prompt_point_xy[0]},{prompt_point_xy[1]}) is outside image bounds ({W-1},{H-1}).")
86 | return
87 |
88 | norm_y = prompt_point_xy[1] / (H - 1) if H > 1 else 0.5
89 | norm_x = prompt_point_xy[0] / (W - 1) if W > 1 else 0.5
90 | norm_point = (norm_y, norm_x)
91 |
92 | prompt_heatmap = create_gaussian_heatmap(H, W, norm_point, sigma=gaussian_sigma, device=device)
93 | prompt_weights = (prompt_heatmap + 1) / 2 # Convert from [-1,1] to [0,1].
94 |
95 | # 3. Compute weighted color and similarity map
96 | weighted_sum_rgb = torch.sum(features_tensor * prompt_weights, dim=(1, 2))
97 | total_weight = torch.sum(prompt_weights)
98 | query_color_rgb = weighted_sum_rgb / (total_weight + epsilon)
99 |
100 | diff_rgb = features_tensor - query_color_rgb[:, None, None]
101 | distance_map_rgb = torch.norm(diff_rgb, dim=0)
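# Convert distance to similarity: pixels whose predicted instance color is close to the query color score high.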
102 | similarity_map_rgb = 1.0 / (distance_map_rgb + epsilon)
103 |
104 | min_val_rgb = similarity_map_rgb.min()
105 | max_val_rgb = similarity_map_rgb.max()
106 | # Add epsilon to prevent division by zero if max_val_rgb == min_val_rgb
107 | normalized_similarity_rgb = (similarity_map_rgb - min_val_rgb) / (max_val_rgb - min_val_rgb + epsilon)
108 | normalized_similarity_rgb = normalized_similarity_rgb.view(H, W)
109 |
110 | # 4. Refine similarity map
111 | # Use guidance_img_bgr which is the original color image in BGR uint8 format
112 | refined_sim_map_np = refine_with_bilateral_solver(
113 | normalized_similarity_rgb.cpu().numpy().astype(np.float32), guidance_img_bgr
114 | )
115 |
116 | # 5. Threshold to produce binary mask
117 | binary_mask = ((refined_sim_map_np * 255) > manual_threshold).astype(np.uint8) * 255
118 |
119 | # 6. Save the binary mask
120 | try:
121 | cv2.imwrite(output_mask_path, binary_mask)
122 | print(f"Successfully saved binary mask to {output_mask_path}")
123 | except Exception as e:
124 | print(f"Error saving mask to {output_mask_path}: {e}")
125 |
126 | #############################################
127 | # Main Execution
128 | #############################################
129 | if __name__ == "__main__":
130 | parser = argparse.ArgumentParser(description="Generate a binary mask for a single prompt point.")
131 | parser.add_argument("--feature_image", type=str, required=True,
132 | help="Path to the input feature PNG file.")
133 | parser.add_argument("--prompt_x", type=int, required=True,
134 | help="X-coordinate of the prompt point.")
135 | parser.add_argument("--prompt_y", type=int, required=True,
136 | help="Y-coordinate of the prompt point.")
137 | parser.add_argument("--output_mask", type=str, default="mask.png",
138 | help="Path to save the output binary mask PNG file.")
139 | parser.add_argument("--sigma", type=float, default=0.01, help="Gaussian sigma for the heatmap.")
140 | parser.add_argument("--threshold", type=int, default=3,
141 | help="Manual threshold for binarization (applied to 0-255 scale map).")
142 |
143 | args = parser.parse_args()
144 |
145 | # Check if cv2.ximgproc is available early
146 | if not (hasattr(cv2, 'ximgproc') and hasattr(cv2.ximgproc, 'jointBilateralFilter')):
147 | print("Warning: opencv-contrib-python might not be installed or cv2.ximgproc is not found.")
148 | print("The 'refine_with_bilateral_solver' function will skip refinement if the filter is unavailable.")
149 | print("To install: pip install opencv-contrib-python")
150 |
151 |
152 | generate_mask_for_single_prompt(
153 | feature_image_path=args.feature_image,
154 | prompt_point_xy=(args.prompt_x, args.prompt_y),
155 | output_mask_path=args.output_mask,
156 | gaussian_sigma=args.sigma,
157 | manual_threshold=args.threshold
158 | )
--------------------------------------------------------------------------------
/sam.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Single Image AutomaskGen Script for SAM (Segment Anything Model)
4 | with masks overlaid on a black background.
5 |
6 | This script:
7 | - Loads a pre-trained SAM model from a checkpoint.
8 | - Creates an automatic mask generator using SAM.
9 | - Processes a single input image to generate segmentation masks.
10 | - Overlays each mask (with a unique random color) on a black background.
11 | - Saves the resulting image.
12 |
13 | User Parameters:
14 | - INPUT_IMAGE: Path to the input image.
15 | - OUTPUT_IMAGE: Path where the output image will be saved.
16 | - MODEL_TYPE: Type of SAM model to use (e.g., "vit_h", "vit_l", "vit_b").
17 | - CHECKPOINT_PATH: Path to the SAM model checkpoint.
18 | - DEVICE: Computation device ("cuda" or "cpu").
19 | """
20 |
21 | import os
22 | import cv2
23 | import numpy as np
24 | import random
25 |
26 | # Import SAM model components. Adjust the import based on your SAM repository installation.
27 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
28 |
29 | #########################################
30 | # USER PARAMETERS (edit these variables)
31 | #########################################
32 | INPUT_IMAGE = "FILL THIS IN" # Path to the input image.
33 | OUTPUT_IMAGE = "sam.png" # Path where the output image will be saved.
34 | MODEL_TYPE = "vit_h" # Options typically include "vit_h", "vit_l", or "vit_b".
35 | CHECKPOINT_PATH = "PATH TO sam_vit_h_4b8939.pth" # Path to the SAM model checkpoint.
36 | DEVICE = "cuda" # Device for inference ("cuda" or "cpu").
37 |
38 | print("Loading SAM model...")
39 | sam = sam_model_registry[MODEL_TYPE](checkpoint=CHECKPOINT_PATH)
40 | sam.to(DEVICE)
41 |
42 | # Instantiate the automatic mask generator with default parameters.
43 | mask_generator = SamAutomaticMaskGenerator(sam)
44 |
45 | # Load the input image using OpenCV.
46 | image_bgr = cv2.imread(INPUT_IMAGE)
47 | if image_bgr is None:
48 | print(f"Error: Unable to load image at {INPUT_IMAGE}")
49 | exit()
50 | # Convert the image to RGB since SAM expects RGB images.
51 | image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
52 |
53 | # Create a black background with the same size as the input image.
54 | black_background = np.zeros_like(image_rgb)
55 |
56 | print("Generating masks...")
57 | # Generate masks automatically.
58 | masks = mask_generator.generate(image_rgb)
59 | num_masks = len(masks)
60 | print(f"Generated {num_masks} masks.")
61 |
62 | # Pre-generate a unique random color for each mask
63 | unique_colors = []
64 | for _ in range(num_masks):
65 | # generate a random RGB tuple in [5,250]
66 | color = tuple(random.randint(5, 250) for _ in range(3))
67 | # ensure uniqueness
68 | while color in unique_colors:
69 | color = tuple(random.randint(5, 250) for _ in range(3))
70 | unique_colors.append(color)
71 |
72 | # Overlay each mask onto the black background with its unique color.
73 | for mask, color in zip(masks, unique_colors):
74 | segmentation = mask["segmentation"] # boolean mask
75 | color_arr = np.array(color, dtype=np.uint8)
76 | black_background[segmentation] = color_arr
77 |
78 | # Convert the resulting image from RGB to BGR for saving via OpenCV.
79 | output_bgr = cv2.cvtColor(black_background, cv2.COLOR_RGB2BGR)
80 |
81 | cv2.imwrite(OUTPUT_IMAGE, output_bgr)
82 | print(f"Output image saved to: {OUTPUT_IMAGE}")
83 |
84 |
--------------------------------------------------------------------------------