├── README.md
├── dift_util.py
├── extract_semantic_point.py
└── image.jpg

/README.md:
--------------------------------------------------------------------------------
1 | # ICLR25: Incorporating Visual Correspondence into Diffusion Model for Virtual Try-On
2 | 
3 | Siqi Wan¹,
4 | Jingwen Chen²,
5 | Yingwei Pan²,
6 | Ting Yao²,
7 | Tao Mei²
8 | 
9 | 
10 | ¹University of Science and Technology of China; ²HiDream.ai Inc
11 | 
12 | 
13 | 
14 | This is the official repository for the
15 | [Paper](*)
16 | "Incorporating Visual Correspondence into Diffusion Model for Virtual Try-On".
17 | 
18 | ## Overview
19 | ![](image.jpg "Overview of our approach")
20 | We propose to explicitly capitalize
21 | on visual correspondence as the prior to tame the diffusion process, instead of simply
22 | feeding the whole garment into the UNet as the appearance reference.
23 | 
24 | ## Installation
25 | Create a conda environment and install the requirements:
26 | ```
27 | conda create -n SPM-Diff python==3.9.0
28 | conda activate SPM-Diff
29 | cd SPM-Diff-main
30 | pip install -r requirements.txt
31 | ```
32 | ## Semantic Point Matching
33 | In SPM, a set of semantic points on the garment is first sampled and matched to the
34 | corresponding points on the target person via local flow warping. Then, these 2D cues are augmented
35 | into 3D-aware cues with depth/normal maps, which serve as semantic point matching supervision for the
36 | diffusion model.
37 | 
38 | You can directly download the [Semantic Point Feature](*) or follow the instructions in [preprocessing.md](*) to extract the Semantic Point Feature yourself.
39 | 
40 | ## Dataset
41 | You can download the VITON-HD dataset from [here](https://github.com/sangyun884/HR-VITON).
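Each file in the `point` directory shown in the structure below is a plain-text annotation that stores the sampled semantic points, one `x,y` coordinate pair per line; `extract_semantic_point.py` reads these files into a `pred_tracks` tensor before extracting the point features. If you extract the Semantic Point Feature yourself, a typical invocation following that script's argparse defaults might look like this (the paths here are only an example, adjust them to your setup):
```
python extract_semantic_point.py \
    --dataroot datasets/vitonhd/ \
    --phase test \
    --image_type cloth \
    --save_dir point_embeds_feature_cloth
```
The script reads images from `<dataroot>/<phase>/<image_type>/`, the matching point files from `<dataroot>/<phase>/point/<image_type>/`, and saves one `.pth` embedding file per image into `<dataroot>/<phase>/<save_dir>/`.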
42 | For inference, the following dataset structure is required:
43 | ``` 44 | test 45 | |-- image 46 | |-- masked_vton_img 47 | |-- warp-cloth 48 | |-- cloth 49 | |-- cloth_mask 50 | |-- point 51 | ``` 52 | ## Inference 53 | Please download the pre-trained model from [Link](https://huggingface.co/HiDream-ai/SPM-Diff). 54 | ``` 55 | sh inference.sh 56 | ``` 57 | ## Acknowledgement 58 | Thanks the contribution of [LaDI-VTON](https://github.com/miccunifi/ladi-vton) and [GP-VTON](https://github.com/xiezhy6/GP-VTON). 59 | ## Citation 60 | ``` 61 | @inproceedings{ 62 | wan2025incorporating, 63 | title={Incorporating Visual Correspondence into Diffusion Model for Virtual Try-On}, 64 | author={Siqi Wan and Jingwen Chen and Yingwei Pan and Ting Yao and Tao Mei}, 65 | booktitle={ICLR}, 66 | year={2025}, 67 | } 68 | ``` 69 | 70 | 71 | -------------------------------------------------------------------------------- /dift_util.py: -------------------------------------------------------------------------------- 1 | import gc 2 | from typing import Any, Dict, Optional, Union 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | from diffusers import DDIMScheduler, StableDiffusionPipeline 9 | from diffusers.models.unet_2d_condition import UNet2DConditionModel 10 | from PIL import Image, ImageDraw 11 | from typing import Any, Callable, Dict, List, Optional, Union 12 | 13 | import numpy as np 14 | import PIL 15 | import torch 16 | from diffusers import DDIMInverseScheduler, StableDiffusionPipeline 17 | 18 | from diffusers.models import AutoencoderKL, T2IAdapter 19 | 20 | from diffusers.schedulers import KarrasDiffusionSchedulers 21 | from diffusers.utils import deprecate 22 | from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin 23 | from transformers import AutoProcessor, CLIPVisionModelWithProjection 24 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer 25 | import os 26 | import time 27 | MODEL_PATH = "stable-diffusion-v1-5/stable-diffusion-v1-5" 28 | tokenizer = CLIPTokenizer.from_pretrained( 29 | MODEL_PATH, 30 | subfolder="tokenizer", 31 | ) 32 | def tokenize_captions( captions, max_length): 33 | inputs = tokenizer( 34 | captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" 35 | ) 36 | return inputs.input_ids 37 | VIT_PATH = "openai/clip-vit-large-patch14" 38 | 39 | 40 | from diffusers.utils import ( 41 | USE_PEFT_BACKEND, 42 | deprecate, 43 | logging, 44 | replace_example_docstring, 45 | scale_lora_layers, 46 | unscale_lora_layers, 47 | ) 48 | 49 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 50 | 51 | 52 | class MyUNet2DConditionModel(UNet2DConditionModel): 53 | def forward( 54 | self, 55 | sample: torch.FloatTensor, 56 | timestep: Union[torch.Tensor, float, int], 57 | up_ft_indices, 58 | encoder_hidden_states: torch.Tensor, 59 | class_labels: Optional[torch.Tensor] = None, 60 | timestep_cond: Optional[torch.Tensor] = None, 61 | attention_mask: Optional[torch.Tensor] = None, 62 | cross_attention_kwargs: Optional[Dict[str, Any]] = None 63 | ): 64 | r""" 65 | Args: 66 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 67 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 68 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 69 | cross_attention_kwargs (`dict`, *optional*): 70 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as 
defined under 71 | `self.processor` in 72 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 73 | """ 74 | # By default samples have to be AT least a multiple of the overall upsampling factor. 75 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 76 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 77 | # on the fly if necessary. 78 | default_overall_up_factor = 2**self.num_upsamplers 79 | 80 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 81 | forward_upsample_size = False 82 | upsample_size = None 83 | 84 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 85 | # logger.info("Forward upsample size to force interpolation output size.") 86 | forward_upsample_size = True 87 | 88 | # prepare attention_mask 89 | if attention_mask is not None: 90 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 91 | attention_mask = attention_mask.unsqueeze(1) 92 | 93 | # 0. center input if necessary 94 | if self.config.center_input_sample: 95 | sample = 2 * sample - 1.0 96 | 97 | # 1. time 98 | timesteps = timestep 99 | if not torch.is_tensor(timesteps): 100 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 101 | # This would be a good case for the `match` statement (Python 3.10+) 102 | is_mps = sample.device.type == 'mps' 103 | if isinstance(timestep, float): 104 | dtype = torch.float32 if is_mps else torch.float64 105 | else: 106 | dtype = torch.int32 if is_mps else torch.int64 107 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 108 | elif len(timesteps.shape) == 0: 109 | timesteps = timesteps[None].to(sample.device) 110 | 111 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 112 | timesteps = timesteps.expand(sample.shape[0]) 113 | 114 | t_emb = self.time_proj(timesteps) 115 | 116 | # timesteps does not contain any weights and will always return f32 tensors 117 | # but time_embedding might actually be running in fp16. so we need to cast here. 118 | # there might be better ways to encapsulate this. 119 | t_emb = t_emb.to(dtype=self.dtype) 120 | 121 | emb = self.time_embedding(t_emb, timestep_cond) 122 | 123 | if self.class_embedding is not None: 124 | if class_labels is None: 125 | raise ValueError('class_labels should be provided when num_class_embeds > 0') 126 | 127 | if self.config.class_embed_type == 'timestep': 128 | class_labels = self.time_proj(class_labels) 129 | 130 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 131 | emb = emb + class_emb 132 | 133 | # 2. pre-process 134 | sample = self.conv_in(sample) 135 | 136 | # 3. down 137 | down_block_res_samples = (sample,) 138 | for downsample_block in self.down_blocks: 139 | if hasattr(downsample_block, 'has_cross_attention') and downsample_block.has_cross_attention: 140 | sample, res_samples = downsample_block( 141 | hidden_states=sample, 142 | temb=emb, 143 | encoder_hidden_states=encoder_hidden_states, 144 | attention_mask=attention_mask, 145 | cross_attention_kwargs=cross_attention_kwargs, 146 | ) 147 | else: 148 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 149 | 150 | down_block_res_samples += res_samples 151 | 152 | # 4. 
mid 153 | if self.mid_block is not None: 154 | sample = self.mid_block( 155 | sample, 156 | emb, 157 | encoder_hidden_states=encoder_hidden_states, 158 | attention_mask=attention_mask, 159 | cross_attention_kwargs=cross_attention_kwargs, 160 | ) 161 | 162 | # 5. up 163 | up_ft = {} 164 | 165 | for i, upsample_block in enumerate(self.up_blocks): 166 | 167 | if i > np.max(up_ft_indices): 168 | break 169 | 170 | is_final_block = i == len(self.up_blocks) - 1 171 | 172 | res_samples = down_block_res_samples[-len(upsample_block.resnets):] 173 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 174 | 175 | # if we have not reached the final block and need to forward the 176 | # upsample size, we do it here 177 | if not is_final_block and forward_upsample_size: 178 | upsample_size = down_block_res_samples[-1].shape[2:] 179 | 180 | if hasattr(upsample_block, 'has_cross_attention') and upsample_block.has_cross_attention: 181 | sample = upsample_block( 182 | hidden_states=sample, 183 | temb=emb, 184 | res_hidden_states_tuple=res_samples, 185 | encoder_hidden_states=encoder_hidden_states, 186 | cross_attention_kwargs=cross_attention_kwargs, 187 | upsample_size=upsample_size, 188 | attention_mask=attention_mask, 189 | ) 190 | else: 191 | sample = upsample_block( 192 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 193 | ) 194 | 195 | if i in up_ft_indices: 196 | up_ft[i] = sample.detach() 197 | 198 | output = {} 199 | output['up_ft'] = up_ft 200 | 201 | return output 202 | 203 | 204 | class OneStepSDPipeline(StableDiffusionPipeline): 205 | @torch.no_grad() 206 | def __call__( 207 | self, 208 | img_tensor, 209 | t, 210 | up_ft_indices, 211 | prompt_embeds: Optional[torch.FloatTensor] = None, 212 | cross_attention_kwargs: Optional[Dict[str, Any]] = None 213 | ): 214 | 215 | device = self._execution_device 216 | latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor 217 | t = torch.tensor(t, dtype=torch.long, device=device) 218 | noise = torch.randn_like(latents).to(device) 219 | latents_noisy = self.scheduler.add_noise(latents, noise, t) 220 | unet_output = self.unet(latents, t, up_ft_indices, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs) 221 | return unet_output 222 | 223 | def _encode_prompt( 224 | self, 225 | prompt, 226 | device, 227 | num_images_per_prompt, 228 | do_classifier_free_guidance, 229 | negative_prompt=None, 230 | prompt_embeds: Optional[torch.FloatTensor] = None, 231 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 232 | lora_scale: Optional[float] = None, 233 | **kwargs, 234 | ): 235 | deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
236 | deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) 237 | 238 | prompt_embeds_tuple = self.encode_prompt( 239 | prompt=prompt, 240 | device=device, 241 | num_images_per_prompt=num_images_per_prompt, 242 | do_classifier_free_guidance=do_classifier_free_guidance, 243 | negative_prompt=negative_prompt, 244 | prompt_embeds=prompt_embeds, 245 | negative_prompt_embeds=negative_prompt_embeds, 246 | lora_scale=lora_scale, 247 | **kwargs, 248 | ) 249 | 250 | # concatenate for backwards comp 251 | # prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) 252 | prompt_embeds = prompt_embeds_tuple[0] 253 | 254 | return prompt_embeds 255 | 256 | def encode_prompt( 257 | self, 258 | prompt, 259 | device, 260 | num_images_per_prompt, 261 | do_classifier_free_guidance, 262 | negative_prompt=None, 263 | prompt_embeds: Optional[torch.FloatTensor] = None, 264 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 265 | lora_scale: Optional[float] = None, 266 | clip_skip: Optional[int] = None, 267 | ): 268 | r""" 269 | Encodes the prompt into text encoder hidden states. 270 | 271 | Args: 272 | prompt (`str` or `List[str]`, *optional*): 273 | prompt to be encoded 274 | device: (`torch.device`): 275 | torch device 276 | num_images_per_prompt (`int`): 277 | number of images that should be generated per prompt 278 | do_classifier_free_guidance (`bool`): 279 | whether to use classifier free guidance or not 280 | negative_prompt (`str` or `List[str]`, *optional*): 281 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 282 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 283 | less than `1`). 284 | prompt_embeds (`torch.FloatTensor`, *optional*): 285 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 286 | provided, text embeddings will be generated from `prompt` input argument. 287 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 288 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 289 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 290 | argument. 291 | lora_scale (`float`, *optional*): 292 | A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 293 | clip_skip (`int`, *optional*): 294 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 295 | the output of the pre-final layer will be used for computing the prompt embeddings. 
296 | """ 297 | # set lora scale so that monkey patched LoRA 298 | # function of text encoder can correctly access it 299 | if lora_scale is not None and isinstance(self, LoraLoaderMixin): 300 | self._lora_scale = lora_scale 301 | 302 | # dynamically adjust the LoRA scale 303 | if not USE_PEFT_BACKEND: 304 | adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) 305 | else: 306 | scale_lora_layers(self.text_encoder, lora_scale) 307 | 308 | if prompt is not None and isinstance(prompt, str): 309 | batch_size = 1 310 | elif prompt is not None and isinstance(prompt, list): 311 | batch_size = len(prompt) 312 | else: 313 | batch_size = prompt_embeds.shape[0] 314 | 315 | if prompt_embeds is None: 316 | # textual inversion: procecss multi-vector tokens if necessary 317 | if isinstance(self, TextualInversionLoaderMixin): 318 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 319 | 320 | text_inputs = self.tokenizer( 321 | prompt, 322 | padding="max_length", 323 | max_length=self.tokenizer.model_max_length, 324 | truncation=True, 325 | return_tensors="pt", 326 | ) 327 | text_input_ids = text_inputs.input_ids 328 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 329 | 330 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 331 | text_input_ids, untruncated_ids 332 | ): 333 | removed_text = self.tokenizer.batch_decode( 334 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 335 | ) 336 | logger.warning( 337 | "The following part of your input was truncated because CLIP can only handle sequences up to" 338 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 339 | ) 340 | 341 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 342 | attention_mask = text_inputs.attention_mask.to(device) 343 | else: 344 | attention_mask = None 345 | 346 | if clip_skip is None: 347 | prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) 348 | prompt_embeds = prompt_embeds[0] 349 | else: 350 | prompt_embeds = self.text_encoder( 351 | text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True 352 | ) 353 | # Access the `hidden_states` first, that contains a tuple of 354 | # all the hidden states from the encoder layers. Then index into 355 | # the tuple to access the hidden states from the desired layer. 356 | prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] 357 | # We also need to apply the final LayerNorm here to not mess with the 358 | # representations. The `last_hidden_states` that we typically use for 359 | # obtaining the final prompt representations passes through the LayerNorm 360 | # layer. 
361 | prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) 362 | 363 | if self.text_encoder is not None: 364 | prompt_embeds_dtype = self.text_encoder.dtype 365 | elif self.unet is not None: 366 | prompt_embeds_dtype = self.unet.dtype 367 | else: 368 | prompt_embeds_dtype = prompt_embeds.dtype 369 | 370 | prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 371 | 372 | bs_embed, seq_len, _ = prompt_embeds.shape 373 | # duplicate text embeddings for each generation per prompt, using mps friendly method 374 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 375 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 376 | 377 | # get unconditional embeddings for classifier free guidance 378 | if do_classifier_free_guidance and negative_prompt_embeds is None: 379 | uncond_tokens: List[str] 380 | if negative_prompt is None: 381 | uncond_tokens = [""] * batch_size 382 | elif prompt is not None and type(prompt) is not type(negative_prompt): 383 | raise TypeError( 384 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 385 | f" {type(prompt)}." 386 | ) 387 | elif isinstance(negative_prompt, str): 388 | uncond_tokens = [negative_prompt] 389 | elif batch_size != len(negative_prompt): 390 | raise ValueError( 391 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 392 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 393 | " the batch size of `prompt`." 394 | ) 395 | else: 396 | uncond_tokens = negative_prompt 397 | 398 | # textual inversion: procecss multi-vector tokens if necessary 399 | if isinstance(self, TextualInversionLoaderMixin): 400 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) 401 | 402 | max_length = prompt_embeds.shape[1] 403 | uncond_input = self.tokenizer( 404 | uncond_tokens, 405 | padding="max_length", 406 | max_length=max_length, 407 | truncation=True, 408 | return_tensors="pt", 409 | ) 410 | 411 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 412 | attention_mask = uncond_input.attention_mask.to(device) 413 | else: 414 | attention_mask = None 415 | 416 | negative_prompt_embeds = self.text_encoder( 417 | uncond_input.input_ids.to(device), 418 | attention_mask=attention_mask, 419 | ) 420 | negative_prompt_embeds = negative_prompt_embeds[0] 421 | 422 | if do_classifier_free_guidance: 423 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 424 | seq_len = negative_prompt_embeds.shape[1] 425 | 426 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 427 | 428 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 429 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 430 | 431 | if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: 432 | # Retrieve the original scale by scaling back the LoRA layers 433 | unscale_lora_layers(self.text_encoder, lora_scale) 434 | 435 | return prompt_embeds, negative_prompt_embeds 436 | 437 | 438 | class SDFeaturizer: 439 | def __init__(self, sd_id='stable-diffusion-v1-5/stable-diffusion-v1-5'): 440 | unet = MyUNet2DConditionModel.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5/unet') 441 | # onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, 
safety_checker=None) 442 | onestep_pipe = OneStepSDPipeline.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5', unet=unet, safety_checker=None) 443 | onestep_pipe.vae.decoder = None 444 | onestep_pipe.scheduler = DDIMScheduler.from_pretrained('stable-diffusion-v1-5/stable-diffusion-v1-5', subfolder='scheduler') 445 | gc.collect() 446 | onestep_pipe = onestep_pipe.to('cuda') 447 | onestep_pipe.enable_attention_slicing() 448 | self.pipe = onestep_pipe 449 | self.image_encode = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH) 450 | self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH) 451 | 452 | 453 | self.text_encoder = CLIPTextModel.from_pretrained( 454 | MODEL_PATH, 455 | subfolder="text_encoder", 456 | ).cuda() 457 | 458 | @torch.no_grad() 459 | def forward(self, 460 | img_tensor, 461 | prompt, 462 | out_dir, 463 | t=0, 464 | up_ft_index=1, 465 | ensemble_size=1, 466 | ): 467 | ''' 468 | Args: 469 | img_tensor: should be a single torch tensor in the shape of [1, C, H, W] or [C, H, W] 470 | prompt: the prompt to use, a string 471 | t: the time step to use, should be an int in the range of [0, 1000] 472 | up_ft_index: which upsampling block of the U-Net to extract feature, you can choose [0, 1, 2, 3] 473 | ensemble_size: the number of repeated images used in the batch to extract features 474 | Return: 475 | unet_ft: a torch tensor in the shape of [1, c, h, w] 476 | ''' 477 | img_tensor = img_tensor.repeat(1, 1, 1, 1).cuda() # ensem, c, h, w 478 | 479 | 480 | prompt_image = self.auto_processor(images=(img_tensor+1)/2, return_tensors="pt").to(img_tensor.device) 481 | prompt_image = self.image_encode.to(img_tensor.device)(prompt_image.data['pixel_values']).image_embeds.to(img_tensor.device) 482 | prompt_image = prompt_image.unsqueeze(1) 483 | 484 | prompt_embeds = self.text_encoder(tokenize_captions([''], 2).to(img_tensor.device))[0] 485 | prompt_embeds[:, 1:] = prompt_image[:] 486 | 487 | 488 | unet_ft_all = self.pipe( 489 | img_tensor=img_tensor, 490 | t=t, 491 | up_ft_indices=[up_ft_index], 492 | prompt_embeds=prompt_embeds) 493 | 494 | unet_ft = unet_ft_all['up_ft'][up_ft_index] # ensem, c, h, w 495 | unet_ft = unet_ft.mean(0, keepdim=True) # 1,c,h,w 496 | return unet_ft 497 | 498 | 499 | -------------------------------------------------------------------------------- /extract_semantic_point.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import os.path as osp 5 | import sys 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from PIL import Image 12 | from torchvision.transforms import PILToTensor 13 | from tqdm import tqdm 14 | from PIL import Image, ImageDraw 15 | from dift_util import SDFeaturizer 16 | import time 17 | 18 | def read_frames_from_folder(frames_folder): 19 | frames = [] 20 | for filename in sorted(os.listdir(frames_folder)): 21 | if filename.endswith('.jpg') or filename.endswith('.png'): 22 | frame_path = os.path.join(frames_folder, filename) 23 | frame = Image.open(frame_path).convert('RGB') 24 | frames.append(np.array(frame)) 25 | return np.stack(frames) 26 | 27 | 28 | def progagate_human_keypoint(file_path): 29 | 30 | with open(file_path, 'r') as file: 31 | 32 | pred_tracks = torch.zeros((len(file.readlines()), 2)) 33 | 34 | file.seek(0) 35 | idx = 0 36 | for line in file: 37 | x, y = map(float, line.strip().split(',')) 38 | pred_tracks[idx, 0] = int(x) 39 | pred_tracks[idx, 1] = int(y) 40 | idx 
=idx+1 41 | 42 | tap_dict = {'pred_tracks': pred_tracks.cpu(), } 43 | return tap_dict 44 | 45 | 46 | 47 | def extract_dift_feature(image, dift_model,out_dir): 48 | if isinstance(image, Image.Image): 49 | image = image 50 | else: 51 | image = Image.open(image).convert('RGB') 52 | prompt = f'a photo of an upper body garment' 53 | 54 | img_tensor = (PILToTensor()(image) / 255.0 - 0.5) * 2 55 | dift_feature = dift_model.forward(img_tensor, prompt=prompt,out_dir = out_dir, ensemble_size=8) 56 | 57 | return dift_feature 58 | 59 | 60 | def extract_point_embedding(tap_dict, image_path, model_id,out_dir): 61 | 62 | num_point = tap_dict['pred_tracks'].shape[0] 63 | init_embedding = torch.zeros((num_point, 1280)) 64 | init_count = torch.zeros(num_point) 65 | 66 | dift_model = SDFeaturizer(sd_id=model_id) 67 | 68 | img = Image.open(image_path).convert('RGB') 69 | x, y = img.size 70 | img_dift = extract_dift_feature(img,dift_model=dift_model,out_dir = out_dir) 71 | img_dift = nn.Upsample(size=(y, x), mode='bilinear')(img_dift) 72 | 73 | point_all = tap_dict['pred_tracks'] 74 | 75 | for point_idx, point in enumerate(point_all): 76 | point_x, point_y = int(np.round(point[0])), int(np.round(point[1])) 77 | 78 | if point_x >= 0 and point_y >= 0: 79 | point_embedding = img_dift[0, :, point_y//2, point_x//2] 80 | init_embedding[point_idx] += point_embedding.cpu() 81 | init_count[point_idx] += 1 82 | 83 | for point_idx in range(init_count.shape[0]): 84 | if init_count[point_idx] != 0: 85 | init_embedding[point_idx] /= init_count[point_idx] 86 | 87 | tap_dict['point_embedding'] = init_embedding 88 | 89 | return tap_dict 90 | 91 | 92 | if __name__ == '__main__': 93 | parser = argparse.ArgumentParser() 94 | 95 | parser.add_argument('--save_dir', type=str, default='point_embeds_feature_cloth') 96 | parser.add_argument('--image_type', type=str, choices=["cloth", "model"]) 97 | parser.add_argument('--dataroot', type=str, default='datasets/vitonhd/') 98 | parser.add_argument('--phase', type=str, default='test') 99 | 100 | args = parser.parse_args() 101 | 102 | folder_path = os.path.join(args.dataroot,args.phase) 103 | for filename in sorted(os.listdir(os.path.join(folder_path,args.image_type))): 104 | 105 | image_path = os.path.join(folder_path,args.image_type,filename) 106 | file_path = os.path.join(folder_path, 'point',args.image_type,filename.replace('.jpg','.txt')) 107 | tap_dict = progagate_human_keypoint(file_path) 108 | 109 | 110 | tap_dict = extract_point_embedding( 111 | tap_dict, image_path, model_id='sd1.5',out_dir=args.save_dir) 112 | 113 | torch.save(tap_dict, os.path.join(folder_path, args.save_dir,filename.replace('.jpg','.pth'))) 114 | 115 | 116 | -------------------------------------------------------------------------------- /image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HiDream-ai/SPM-Diff/b70a288e707c9116c342c384b5f1ae88bc298016/image.jpg --------------------------------------------------------------------------------