├── README.md
├── dift_util.py
├── extract_semantic_point.py
└── image.jpg
/README.md:
--------------------------------------------------------------------------------
1 | # ICLR25: Incorporating Visual Correspondence into Diffusion Model for Virtual Try-On
2 |
10 | ¹University of Science and Technology of China; ²HiDream.ai Inc.
11 |
13 |
14 | This is the official repository for the
15 | [Paper](*)
16 | "Incorporating Visual Correspondence into Diffusion Model for Virtual Try-On".
17 |
17 |
18 | ## Overview
19 | 
20 | We propose a novel approach that explicitly capitalizes on visual correspondence as a prior to tame
21 | the diffusion process, instead of simply feeding the whole garment into the UNet as an appearance
22 | reference.
23 |
24 | ## Installation
25 | Create a conda environment and install the requirements:
26 | ```
27 | conda create -n SPM-Diff python==3.9.0
28 | conda activate SPM-Diff
29 | cd SPM-Diff-main
30 | pip install -r requirements.txt
31 | ```
32 | ## Semantic Point Matching
33 | In SPM, a set of semantic points on the garment is first sampled and matched to the
34 | corresponding points on the target person via local flow warping. These 2D cues are then
35 | augmented into 3D-aware cues with the depth/normal map, which serve as semantic point
36 | matching supervision for the diffusion model.
37 |
38 | You can directly download the [Semantic Point Feature](*) or follow the instructions in [preprocessing.md](*) to extract the Semantic Point Feature yourself.
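If you extract the features yourself, `extract_semantic_point.py` in this repository computes a DIFT point embedding for every sampled point coordinate. A minimal invocation sketch, assuming the VITON-HD layout described below and that the point coordinate files already exist under `point/`:
```
python extract_semantic_point.py \
    --dataroot datasets/vitonhd/ \
    --phase test \
    --image_type cloth \
    --save_dir point_embeds_feature_cloth
```
The script writes one `.pth` file per image (containing the point coordinates and their embeddings) to `<dataroot>/<phase>/<save_dir>/`; run it again with `--image_type model` and a matching `--save_dir` if point features for the person images are also needed.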
39 |
40 | ## Dataset
41 | You can download the VITON-HD dataset from [here](https://github.com/sangyun884/HR-VITON).
42 | For inference, the following dataset structure is required:
43 | ```
44 | test
45 | |-- image
46 | |-- masked_vton_img
47 | |-- warp-cloth
48 | |-- cloth
49 | |-- cloth_mask
50 | |-- point
51 | ```
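Based on `extract_semantic_point.py`, the `point` folder is expected to hold one text file of sampled point coordinates per image (one `x,y` pair per line), organized by image type; the file names below are only illustrative:
```
point
|-- cloth
|   |-- 00001_00.txt
|-- model
|   |-- 00001_00.txt
```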
52 | ## Inference
53 | Please download the pre-trained model from [Link](https://huggingface.co/HiDream-ai/SPM-Diff).
54 | ```
55 | sh inference.sh
56 | ```
57 | ## Acknowledgement
58 | Thanks to the contributions of [LaDI-VTON](https://github.com/miccunifi/ladi-vton) and [GP-VTON](https://github.com/xiezhy6/GP-VTON).
59 | ## Citation
60 | ```
61 | @inproceedings{wan2025incorporating,
62 |   title={Incorporating Visual Correspondence into Diffusion Model for Virtual Try-On},
63 |   author={Siqi Wan and Jingwen Chen and Yingwei Pan and Ting Yao and Tao Mei},
64 |   booktitle={ICLR},
65 |   year={2025}
66 | }
68 | ```
69 |
70 |
71 |
--------------------------------------------------------------------------------
/dift_util.py:
--------------------------------------------------------------------------------
1 | import gc
2 | from typing import Any, Dict, List, Optional, Union
3 |
4 | import numpy as np
5 | import torch
6 | from diffusers import DDIMScheduler, StableDiffusionPipeline
7 | from diffusers.models.unet_2d_condition import UNet2DConditionModel
8 | # adjust_lora_scale_text_encoder is used by encode_prompt() below when the PEFT backend is not available
9 | from diffusers.models.lora import adjust_lora_scale_text_encoder
10 | from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
11 | from transformers import AutoProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
12 |
13 | MODEL_PATH = "stable-diffusion-v1-5/stable-diffusion-v1-5"
14 | VIT_PATH = "openai/clip-vit-large-patch14"
15 |
16 | tokenizer = CLIPTokenizer.from_pretrained(
17 |     MODEL_PATH,
18 |     subfolder="tokenizer",
19 | )
20 |
21 |
22 | def tokenize_captions(captions, max_length):
23 |     inputs = tokenizer(
24 |         captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
25 |     )
26 |     return inputs.input_ids
38 |
39 |
40 | from diffusers.utils import (
41 | USE_PEFT_BACKEND,
42 | deprecate,
43 | logging,
44 | replace_example_docstring,
45 | scale_lora_layers,
46 | unscale_lora_layers,
47 | )
48 |
49 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
50 |
51 |
52 | class MyUNet2DConditionModel(UNet2DConditionModel):
53 | def forward(
54 | self,
55 | sample: torch.FloatTensor,
56 | timestep: Union[torch.Tensor, float, int],
57 | up_ft_indices,
58 | encoder_hidden_states: torch.Tensor,
59 | class_labels: Optional[torch.Tensor] = None,
60 | timestep_cond: Optional[torch.Tensor] = None,
61 | attention_mask: Optional[torch.Tensor] = None,
62 | cross_attention_kwargs: Optional[Dict[str, Any]] = None
63 | ):
64 | r"""
65 | Args:
66 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
67 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
68 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
69 | cross_attention_kwargs (`dict`, *optional*):
70 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
71 | `self.processor` in
72 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
73 | """
74 |         # By default, samples have to be at least a multiple of the overall upsampling factor.
75 |         # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
76 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
77 | # on the fly if necessary.
78 | default_overall_up_factor = 2**self.num_upsamplers
79 |
80 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
81 | forward_upsample_size = False
82 | upsample_size = None
83 |
84 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
85 | # logger.info("Forward upsample size to force interpolation output size.")
86 | forward_upsample_size = True
87 |
88 | # prepare attention_mask
89 | if attention_mask is not None:
90 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
91 | attention_mask = attention_mask.unsqueeze(1)
92 |
93 | # 0. center input if necessary
94 | if self.config.center_input_sample:
95 | sample = 2 * sample - 1.0
96 |
97 | # 1. time
98 | timesteps = timestep
99 | if not torch.is_tensor(timesteps):
100 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
101 | # This would be a good case for the `match` statement (Python 3.10+)
102 | is_mps = sample.device.type == 'mps'
103 | if isinstance(timestep, float):
104 | dtype = torch.float32 if is_mps else torch.float64
105 | else:
106 | dtype = torch.int32 if is_mps else torch.int64
107 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
108 | elif len(timesteps.shape) == 0:
109 | timesteps = timesteps[None].to(sample.device)
110 |
111 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
112 | timesteps = timesteps.expand(sample.shape[0])
113 |
114 | t_emb = self.time_proj(timesteps)
115 |
116 | # timesteps does not contain any weights and will always return f32 tensors
117 | # but time_embedding might actually be running in fp16. so we need to cast here.
118 | # there might be better ways to encapsulate this.
119 | t_emb = t_emb.to(dtype=self.dtype)
120 |
121 | emb = self.time_embedding(t_emb, timestep_cond)
122 |
123 | if self.class_embedding is not None:
124 | if class_labels is None:
125 | raise ValueError('class_labels should be provided when num_class_embeds > 0')
126 |
127 | if self.config.class_embed_type == 'timestep':
128 | class_labels = self.time_proj(class_labels)
129 |
130 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
131 | emb = emb + class_emb
132 |
133 | # 2. pre-process
134 | sample = self.conv_in(sample)
135 |
136 | # 3. down
137 | down_block_res_samples = (sample,)
138 | for downsample_block in self.down_blocks:
139 | if hasattr(downsample_block, 'has_cross_attention') and downsample_block.has_cross_attention:
140 | sample, res_samples = downsample_block(
141 | hidden_states=sample,
142 | temb=emb,
143 | encoder_hidden_states=encoder_hidden_states,
144 | attention_mask=attention_mask,
145 | cross_attention_kwargs=cross_attention_kwargs,
146 | )
147 | else:
148 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
149 |
150 | down_block_res_samples += res_samples
151 |
152 | # 4. mid
153 | if self.mid_block is not None:
154 | sample = self.mid_block(
155 | sample,
156 | emb,
157 | encoder_hidden_states=encoder_hidden_states,
158 | attention_mask=attention_mask,
159 | cross_attention_kwargs=cross_attention_kwargs,
160 | )
161 |
162 | # 5. up
163 | up_ft = {}
164 |
165 | for i, upsample_block in enumerate(self.up_blocks):
166 |
167 | if i > np.max(up_ft_indices):
168 | break
169 |
170 | is_final_block = i == len(self.up_blocks) - 1
171 |
172 | res_samples = down_block_res_samples[-len(upsample_block.resnets):]
173 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
174 |
175 | # if we have not reached the final block and need to forward the
176 | # upsample size, we do it here
177 | if not is_final_block and forward_upsample_size:
178 | upsample_size = down_block_res_samples[-1].shape[2:]
179 |
180 | if hasattr(upsample_block, 'has_cross_attention') and upsample_block.has_cross_attention:
181 | sample = upsample_block(
182 | hidden_states=sample,
183 | temb=emb,
184 | res_hidden_states_tuple=res_samples,
185 | encoder_hidden_states=encoder_hidden_states,
186 | cross_attention_kwargs=cross_attention_kwargs,
187 | upsample_size=upsample_size,
188 | attention_mask=attention_mask,
189 | )
190 | else:
191 | sample = upsample_block(
192 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
193 | )
194 |
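            # Cache the feature maps of the requested up-blocks; these are the features returned to the caller.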
195 | if i in up_ft_indices:
196 | up_ft[i] = sample.detach()
197 |
198 | output = {}
199 | output['up_ft'] = up_ft
200 |
201 | return output
202 |
203 |
204 | class OneStepSDPipeline(StableDiffusionPipeline):
205 | @torch.no_grad()
206 | def __call__(
207 | self,
208 | img_tensor,
209 | t,
210 | up_ft_indices,
211 | prompt_embeds: Optional[torch.FloatTensor] = None,
212 | cross_attention_kwargs: Optional[Dict[str, Any]] = None
213 | ):
214 |
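        # One-step DIFT-style feature extraction: encode the image to VAE latents and run a single
        # UNet forward pass at timestep t to collect the requested up-block features (note that the
        # clean latents, not latents_noisy, are fed to the UNet here).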
215 | device = self._execution_device
216 | latents = self.vae.encode(img_tensor).latent_dist.sample() * self.vae.config.scaling_factor
217 | t = torch.tensor(t, dtype=torch.long, device=device)
218 | noise = torch.randn_like(latents).to(device)
219 | latents_noisy = self.scheduler.add_noise(latents, noise, t)
220 | unet_output = self.unet(latents, t, up_ft_indices, encoder_hidden_states=prompt_embeds, cross_attention_kwargs=cross_attention_kwargs)
221 | return unet_output
222 |
223 | def _encode_prompt(
224 | self,
225 | prompt,
226 | device,
227 | num_images_per_prompt,
228 | do_classifier_free_guidance,
229 | negative_prompt=None,
230 | prompt_embeds: Optional[torch.FloatTensor] = None,
231 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
232 | lora_scale: Optional[float] = None,
233 | **kwargs,
234 | ):
235 | deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
236 | deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
237 |
238 | prompt_embeds_tuple = self.encode_prompt(
239 | prompt=prompt,
240 | device=device,
241 | num_images_per_prompt=num_images_per_prompt,
242 | do_classifier_free_guidance=do_classifier_free_guidance,
243 | negative_prompt=negative_prompt,
244 | prompt_embeds=prompt_embeds,
245 | negative_prompt_embeds=negative_prompt_embeds,
246 | lora_scale=lora_scale,
247 | **kwargs,
248 | )
249 |
250 | # concatenate for backwards comp
251 | # prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
252 | prompt_embeds = prompt_embeds_tuple[0]
253 |
254 | return prompt_embeds
255 |
256 | def encode_prompt(
257 | self,
258 | prompt,
259 | device,
260 | num_images_per_prompt,
261 | do_classifier_free_guidance,
262 | negative_prompt=None,
263 | prompt_embeds: Optional[torch.FloatTensor] = None,
264 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
265 | lora_scale: Optional[float] = None,
266 | clip_skip: Optional[int] = None,
267 | ):
268 | r"""
269 | Encodes the prompt into text encoder hidden states.
270 |
271 | Args:
272 | prompt (`str` or `List[str]`, *optional*):
273 | prompt to be encoded
274 | device: (`torch.device`):
275 | torch device
276 | num_images_per_prompt (`int`):
277 | number of images that should be generated per prompt
278 | do_classifier_free_guidance (`bool`):
279 | whether to use classifier free guidance or not
280 | negative_prompt (`str` or `List[str]`, *optional*):
281 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
282 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
283 | less than `1`).
284 | prompt_embeds (`torch.FloatTensor`, *optional*):
285 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
286 | provided, text embeddings will be generated from `prompt` input argument.
287 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
288 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
289 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
290 | argument.
291 | lora_scale (`float`, *optional*):
292 | A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
293 | clip_skip (`int`, *optional*):
294 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
295 | the output of the pre-final layer will be used for computing the prompt embeddings.
296 | """
297 | # set lora scale so that monkey patched LoRA
298 | # function of text encoder can correctly access it
299 | if lora_scale is not None and isinstance(self, LoraLoaderMixin):
300 | self._lora_scale = lora_scale
301 |
302 | # dynamically adjust the LoRA scale
303 | if not USE_PEFT_BACKEND:
304 | adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
305 | else:
306 | scale_lora_layers(self.text_encoder, lora_scale)
307 |
308 | if prompt is not None and isinstance(prompt, str):
309 | batch_size = 1
310 | elif prompt is not None and isinstance(prompt, list):
311 | batch_size = len(prompt)
312 | else:
313 | batch_size = prompt_embeds.shape[0]
314 |
315 | if prompt_embeds is None:
316 |             # textual inversion: process multi-vector tokens if necessary
317 | if isinstance(self, TextualInversionLoaderMixin):
318 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
319 |
320 | text_inputs = self.tokenizer(
321 | prompt,
322 | padding="max_length",
323 | max_length=self.tokenizer.model_max_length,
324 | truncation=True,
325 | return_tensors="pt",
326 | )
327 | text_input_ids = text_inputs.input_ids
328 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
329 |
330 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
331 | text_input_ids, untruncated_ids
332 | ):
333 | removed_text = self.tokenizer.batch_decode(
334 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
335 | )
336 | logger.warning(
337 | "The following part of your input was truncated because CLIP can only handle sequences up to"
338 | f" {self.tokenizer.model_max_length} tokens: {removed_text}"
339 | )
340 |
341 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
342 | attention_mask = text_inputs.attention_mask.to(device)
343 | else:
344 | attention_mask = None
345 |
346 | if clip_skip is None:
347 | prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
348 | prompt_embeds = prompt_embeds[0]
349 | else:
350 | prompt_embeds = self.text_encoder(
351 | text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
352 | )
353 | # Access the `hidden_states` first, that contains a tuple of
354 | # all the hidden states from the encoder layers. Then index into
355 | # the tuple to access the hidden states from the desired layer.
356 | prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
357 | # We also need to apply the final LayerNorm here to not mess with the
358 | # representations. The `last_hidden_states` that we typically use for
359 | # obtaining the final prompt representations passes through the LayerNorm
360 | # layer.
361 | prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
362 |
363 | if self.text_encoder is not None:
364 | prompt_embeds_dtype = self.text_encoder.dtype
365 | elif self.unet is not None:
366 | prompt_embeds_dtype = self.unet.dtype
367 | else:
368 | prompt_embeds_dtype = prompt_embeds.dtype
369 |
370 | prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
371 |
372 | bs_embed, seq_len, _ = prompt_embeds.shape
373 | # duplicate text embeddings for each generation per prompt, using mps friendly method
374 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
375 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
376 |
377 | # get unconditional embeddings for classifier free guidance
378 | if do_classifier_free_guidance and negative_prompt_embeds is None:
379 | uncond_tokens: List[str]
380 | if negative_prompt is None:
381 | uncond_tokens = [""] * batch_size
382 | elif prompt is not None and type(prompt) is not type(negative_prompt):
383 | raise TypeError(
384 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
385 | f" {type(prompt)}."
386 | )
387 | elif isinstance(negative_prompt, str):
388 | uncond_tokens = [negative_prompt]
389 | elif batch_size != len(negative_prompt):
390 | raise ValueError(
391 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
392 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
393 | " the batch size of `prompt`."
394 | )
395 | else:
396 | uncond_tokens = negative_prompt
397 |
398 |             # textual inversion: process multi-vector tokens if necessary
399 | if isinstance(self, TextualInversionLoaderMixin):
400 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
401 |
402 | max_length = prompt_embeds.shape[1]
403 | uncond_input = self.tokenizer(
404 | uncond_tokens,
405 | padding="max_length",
406 | max_length=max_length,
407 | truncation=True,
408 | return_tensors="pt",
409 | )
410 |
411 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
412 | attention_mask = uncond_input.attention_mask.to(device)
413 | else:
414 | attention_mask = None
415 |
416 | negative_prompt_embeds = self.text_encoder(
417 | uncond_input.input_ids.to(device),
418 | attention_mask=attention_mask,
419 | )
420 | negative_prompt_embeds = negative_prompt_embeds[0]
421 |
422 | if do_classifier_free_guidance:
423 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
424 | seq_len = negative_prompt_embeds.shape[1]
425 |
426 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
427 |
428 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
429 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
430 |
431 | if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
432 | # Retrieve the original scale by scaling back the LoRA layers
433 | unscale_lora_layers(self.text_encoder, lora_scale)
434 |
435 | return prompt_embeds, negative_prompt_embeds
436 |
437 |
438 | class SDFeaturizer:
439 |     def __init__(self, sd_id=MODEL_PATH):
440 |         unet = MyUNet2DConditionModel.from_pretrained(sd_id, subfolder='unet')
441 |         onestep_pipe = OneStepSDPipeline.from_pretrained(sd_id, unet=unet, safety_checker=None)
442 |         onestep_pipe.vae.decoder = None
443 |         onestep_pipe.scheduler = DDIMScheduler.from_pretrained(sd_id, subfolder='scheduler')
444 |         gc.collect()
445 |         onestep_pipe = onestep_pipe.to('cuda')
446 |         onestep_pipe.enable_attention_slicing()
447 |         self.pipe = onestep_pipe
448 |         self.image_encode = CLIPVisionModelWithProjection.from_pretrained(VIT_PATH)
449 |         self.auto_processor = AutoProcessor.from_pretrained(VIT_PATH)
450 |
451 |         self.text_encoder = CLIPTextModel.from_pretrained(
452 |             MODEL_PATH,
453 |             subfolder="text_encoder",
454 |         ).cuda()
455 |
456 |     @torch.no_grad()
457 |     def forward(self,
458 |                 img_tensor,
459 |                 prompt,
460 |                 out_dir,
461 |                 t=0,
462 |                 up_ft_index=1,
463 |                 ensemble_size=1,
464 |                 ):
465 |         '''
466 |         Args:
467 |             img_tensor: a single torch tensor of shape [1, C, H, W] or [C, H, W]
468 |             prompt: the text prompt to use, a string
469 |             t: the time step to use, an int in the range [0, 1000]
470 |             up_ft_index: which upsampling block of the U-Net to extract features from, one of [0, 1, 2, 3]
471 |             ensemble_size: the number of repeated images used in the batch to extract features
472 |         Return:
473 |             unet_ft: a torch tensor of shape [1, c, h, w]
474 |         '''
475 |         img_tensor = img_tensor.repeat(1, 1, 1, 1).cuda()  # ensem, c, h, w
476 |
477 |         # Encode the input image with the CLIP vision tower and splice its projected image embedding
478 |         # into the (empty) text prompt embeddings right after the BOS token, so the UNet is
479 |         # conditioned on the image itself.
480 |         prompt_image = self.auto_processor(images=(img_tensor + 1) / 2, return_tensors="pt").to(img_tensor.device)
481 |         prompt_image = self.image_encode.to(img_tensor.device)(prompt_image.data['pixel_values']).image_embeds.to(img_tensor.device)
482 |         prompt_image = prompt_image.unsqueeze(1)
483 |
484 |         prompt_embeds = self.text_encoder(tokenize_captions([''], 2).to(img_tensor.device))[0]
485 |         prompt_embeds[:, 1:] = prompt_image[:]
486 |
487 |         unet_ft_all = self.pipe(
488 |             img_tensor=img_tensor,
489 |             t=t,
490 |             up_ft_indices=[up_ft_index],
491 |             prompt_embeds=prompt_embeds)
492 |
493 |         unet_ft = unet_ft_all['up_ft'][up_ft_index]  # ensem, c, h, w
494 |         unet_ft = unet_ft.mean(0, keepdim=True)  # 1, c, h, w
495 |         return unet_ft
497 |
498 |
499 |
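# Example usage (a minimal sketch; the image path is only illustrative, and t / up_ft_index
# follow the defaults of SDFeaturizer.forward above):
#
#   from PIL import Image
#   from torchvision.transforms import PILToTensor
#
#   featurizer = SDFeaturizer()
#   img = Image.open('image.jpg').convert('RGB')
#   img_tensor = (PILToTensor()(img) / 255.0 - 0.5) * 2   # scale pixel values to [-1, 1]
#   ft = featurizer.forward(img_tensor, prompt='a photo of an upper body garment', out_dir='.')
#   print(ft.shape)  # [1, c, h, w] feature map from the chosen up-block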
--------------------------------------------------------------------------------
/extract_semantic_point.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import os.path as osp
5 | import sys
6 | from pathlib import Path
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn as nn
11 | from PIL import Image
12 | from torchvision.transforms import PILToTensor
13 | from tqdm import tqdm
14 | from PIL import Image, ImageDraw
15 | from dift_util import SDFeaturizer
16 | import time
17 |
18 | def read_frames_from_folder(frames_folder):
19 | frames = []
20 | for filename in sorted(os.listdir(frames_folder)):
21 | if filename.endswith('.jpg') or filename.endswith('.png'):
22 | frame_path = os.path.join(frames_folder, filename)
23 | frame = Image.open(frame_path).convert('RGB')
24 | frames.append(np.array(frame))
25 | return np.stack(frames)
26 |
27 |
28 | def propagate_human_keypoint(file_path):
29 |     # Read one "x,y" point coordinate per line into an [N, 2] tensor.
30 |     with open(file_path, 'r') as file:
31 |         lines = file.readlines()
32 |
33 |     pred_tracks = torch.zeros((len(lines), 2))
34 |     for idx, line in enumerate(lines):
35 |         x, y = map(float, line.strip().split(','))
36 |         pred_tracks[idx, 0] = int(x)
37 |         pred_tracks[idx, 1] = int(y)
38 |
39 |     tap_dict = {'pred_tracks': pred_tracks.cpu()}
40 |     return tap_dict
44 |
45 |
46 |
47 | def extract_dift_feature(image, dift_model, out_dir):
48 |     if not isinstance(image, Image.Image):
49 |         image = Image.open(image).convert('RGB')
50 |     prompt = 'a photo of an upper body garment'
51 |
52 |     # Scale pixel values to [-1, 1] before feeding the image to the SD feature extractor.
53 |     img_tensor = (PILToTensor()(image) / 255.0 - 0.5) * 2
54 |     dift_feature = dift_model.forward(img_tensor, prompt=prompt, out_dir=out_dir, ensemble_size=8)
55 |
56 |     return dift_feature
58 |
59 |
60 | def extract_point_embedding(tap_dict, image_path, model_id,out_dir):
61 |
62 | num_point = tap_dict['pred_tracks'].shape[0]
63 | init_embedding = torch.zeros((num_point, 1280))
64 | init_count = torch.zeros(num_point)
65 |
66 | dift_model = SDFeaturizer(sd_id=model_id)
67 |
68 | img = Image.open(image_path).convert('RGB')
69 | x, y = img.size
70 | img_dift = extract_dift_feature(img,dift_model=dift_model,out_dir = out_dir)
71 | img_dift = nn.Upsample(size=(y, x), mode='bilinear')(img_dift)
72 |
73 | point_all = tap_dict['pred_tracks']
74 |
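    # Sample the DIFT feature at each point location; points with negative coordinates are skipped,
    # and the accumulated embeddings are averaged per point below.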
75 | for point_idx, point in enumerate(point_all):
76 | point_x, point_y = int(np.round(point[0])), int(np.round(point[1]))
77 |
78 | if point_x >= 0 and point_y >= 0:
79 | point_embedding = img_dift[0, :, point_y//2, point_x//2]
80 | init_embedding[point_idx] += point_embedding.cpu()
81 | init_count[point_idx] += 1
82 |
83 | for point_idx in range(init_count.shape[0]):
84 | if init_count[point_idx] != 0:
85 | init_embedding[point_idx] /= init_count[point_idx]
86 |
87 | tap_dict['point_embedding'] = init_embedding
88 |
89 | return tap_dict
90 |
91 |
92 | if __name__ == '__main__':
93 |     parser = argparse.ArgumentParser()
94 |
95 |     parser.add_argument('--save_dir', type=str, default='point_embeds_feature_cloth')
96 |     parser.add_argument('--image_type', type=str, choices=["cloth", "model"])
97 |     parser.add_argument('--dataroot', type=str, default='datasets/vitonhd/')
98 |     parser.add_argument('--phase', type=str, default='test')
99 |
100 |     args = parser.parse_args()
101 |
102 |     folder_path = os.path.join(args.dataroot, args.phase)
103 |     os.makedirs(os.path.join(folder_path, args.save_dir), exist_ok=True)
104 |
105 |     for filename in sorted(os.listdir(os.path.join(folder_path, args.image_type))):
106 |         image_path = os.path.join(folder_path, args.image_type, filename)
107 |         file_path = os.path.join(folder_path, 'point', args.image_type, filename.replace('.jpg', '.txt'))
108 |         tap_dict = propagate_human_keypoint(file_path)
109 |
110 |         tap_dict = extract_point_embedding(
111 |             tap_dict, image_path, model_id='stable-diffusion-v1-5/stable-diffusion-v1-5', out_dir=args.save_dir)
112 |
113 |         torch.save(tap_dict, os.path.join(folder_path, args.save_dir, filename.replace('.jpg', '.pth')))
114 |
115 |
116 |
--------------------------------------------------------------------------------
/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HiDream-ai/SPM-Diff/b70a288e707c9116c342c384b5f1ae88bc298016/image.jpg
--------------------------------------------------------------------------------