├── .gitignore
├── AttrConcenTrainableSDPipeline.py
├── AttrConcenTrainableSDXLPipeline.py
├── README.md
├── TrainableSDPipeline.py
├── attn_utils
│   ├── tc_attn_utils.py
│   ├── tc_loss_utils.py
│   └── tc_sdxl_attn_utils.py
├── attr_concen_utils
│   ├── gsam_interface.py
│   └── load_segmodel.py
├── attribute_concen_utils.py
├── collected_data
│   ├── abc5k.txt
│   ├── hrs_collected_10k.txt
│   └── train_all.txt
├── concept_mat_utils
│   ├── caption_blip.py
│   ├── load_captionmodel.py
│   └── processing_blip.py
├── fig
│   └── demo.png
├── merged_data
│   └── abc5k_hrs10k_t2icompall_20k.txt
├── node8.yaml
├── requirements.txt
├── scripts
│   ├── sd15.sh
│   └── sdxl.sh
├── tools
│   └── gan_gt_generate.py
├── training_script.py
├── training_utils
│   ├── arguments.py
│   ├── dataset.py
│   ├── gan_dataset.py
│   ├── gan_sd_model.py
│   ├── gan_sdxl.py
│   ├── logging.py
│   └── pipeline.py
└── valid.txt

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | 
--------------------------------------------------------------------------------
/AttrConcenTrainableSDPipeline.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Optional, Union
2 | import torch
3 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg
4 | 
5 | from typing import Any, Callable, Dict, List, Optional, Union
6 | 
7 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
8 | 
9 | from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
10 | from diffusers.models import AutoencoderKL, UNet2DConditionModel
11 | from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
12 | from diffusers.models.lora import adjust_lora_scale_text_encoder
13 | from diffusers.schedulers import KarrasDiffusionSchedulers
14 | from diffusers.utils import (
15 |     USE_PEFT_BACKEND,
16 |     scale_lora_layers,
17 |     unscale_lora_layers,
18 | )
19 | from torch.nn.parallel import DistributedDataParallel as DDP
20 | from TrainableSDPipeline import TrainableSDPipeline
21 | 
22 | import spacy
23 | from attribute_concen_utils import *
24 | 
25 | from attn_utils.tc_attn_utils import get_cross_attn_map_from_unet
26 | 
27 | 
28 | class AttrConcenTrainableSDPipeline(TrainableSDPipeline):
29 |     def __init__(self, vae: AutoencoderKL, text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPImageProcessor, requires_safety_checker: bool = True):
30 |         super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker)
31 |         self.parser = spacy.load("en_core_web_trf")
32 |         self.subtrees_indices = None
33 |         self.doc = None
34 |         self.contrast_loss = False
35 |         self.attn_dict = {}
36 | 
37 |     # this is modified from the __call__ method of the StableDiffusionPipeline class
38 |     def forward(
39 |         self,
40 |         prompt: Union[str, List[str]] = None,
41 |         height: int = 512,
42 |         width: int = 512,
43 |         training_timesteps: Optional[List[int]] = [], # new training parameter
44 |         early_exit: bool = False, # new training parameter
45 |         detach_gradient: bool = True, # new training parameter
46 |         train_text_encoder: bool = False, # new training parameter
47 |         double_laststep: bool = False, # new training parameter
48 |         bp_on_trained: bool = False, # new training parameter
49 |         fast_training: bool = False, # new training parameter
50 | 
num_inference_steps: int = 50, 51 | guidance_scale: float = 7.5, 52 | negative_prompt: Optional[Union[str, List[str]]] = None, 53 | num_images_per_prompt: Optional[int] = 1, 54 | eta: float = 0.0, 55 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 56 | latents: Optional[torch.FloatTensor] = None, 57 | prompt_embeds: Optional[torch.FloatTensor] = None, 58 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 59 | output_type: Optional[str] = "image", 60 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 61 | callback_steps: int = 1, 62 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 63 | guidance_rescale: float = 0.0, 64 | return_latents: bool = False, # new training parameter 65 | batch=None, 66 | attrcon_train_steps=None, 67 | ): 68 | # SYNGEN: NEW - use parsed_prompt instead of prompt 69 | self.doc = {} 70 | for p in prompt: 71 | self.doc[p] = self.parser(p) 72 | syn_gen_loss = negative_prompt_embeds.new_zeros(()) 73 | self.subtrees_indices = dict() 74 | 75 | # 2. Define call parameters 76 | if prompt is not None and isinstance(prompt, str): 77 | batch_size = 1 78 | elif prompt is not None and isinstance(prompt, list): 79 | batch_size = len(prompt) 80 | else: 81 | batch_size = prompt_embeds.shape[0] 82 | 83 | device = self._execution_device 84 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 85 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 86 | # corresponds to doing no classifier free guidance. 87 | do_classifier_free_guidance = guidance_scale > 1.0 88 | 89 | torch.set_grad_enabled(train_text_encoder) 90 | # 3. Encode input prompt 91 | text_encoder_lora_scale = ( 92 | cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None 93 | ) 94 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 95 | prompt, 96 | device, 97 | num_images_per_prompt, 98 | do_classifier_free_guidance, 99 | negative_prompt, 100 | prompt_embeds=prompt_embeds, 101 | negative_prompt_embeds=negative_prompt_embeds, 102 | lora_scale=text_encoder_lora_scale, 103 | ) 104 | # For classifier free guidance, we need to do two forward passes. 105 | # Here we concatenate the unconditional and text embeddings into a single batch 106 | # to avoid doing two forward passes 107 | if do_classifier_free_guidance: 108 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 109 | prompt_embeds = prompt_embeds 110 | 111 | # 4. Prepare timesteps 112 | self.scheduler.set_timesteps(num_inference_steps, device=device) 113 | if fast_training: 114 | self.scheduler.timesteps = [self.scheduler.timesteps[i] for i in training_timesteps] 115 | training_timesteps = range(len(self.scheduler.timesteps)) 116 | timesteps = self.scheduler.timesteps 117 | 118 | # 5. Prepare latent variables 119 | latents = self.prepare_latents( 120 | batch_size * num_images_per_prompt, 121 | 4, 122 | height, 123 | width, 124 | prompt_embeds.dtype, 125 | device, 126 | generator, 127 | latents, 128 | ) 129 | 130 | # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline 131 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 132 | 133 | # Added 134 | text_embeddings = ( 135 | prompt_embeds[batch_size * num_images_per_prompt:] if do_classifier_free_guidance else prompt_embeds 136 | ) 137 | 138 | torch.set_grad_enabled(True) 139 | 140 | # training_steps 141 | attrcon_train_steps = attrcon_train_steps if attrcon_train_steps is not None else \ 142 | training_timesteps 143 | 144 | # 7. Denoising loop 145 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 146 | for i, t in enumerate(timesteps): 147 | torch.set_grad_enabled((len(training_timesteps) == 0 or i > min(training_timesteps)) and not double_laststep) 148 | # expand the latents if we are doing classifier free guidance 149 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 150 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 151 | 152 | torch.set_grad_enabled(i in training_timesteps and not double_laststep) 153 | 154 | do_detach = detach_gradient 155 | if i in training_timesteps and bp_on_trained: 156 | do_detach = False 157 | 158 | # predict the noise residual 159 | if i in training_timesteps and bp_on_trained and i in attrcon_train_steps: 160 | noise_pred = self._attrcon_forward( 161 | latent_model_input.detach() if do_detach else latent_model_input, 162 | t, 163 | prompt_embeds=prompt_embeds, 164 | cross_attention_kwargs=cross_attention_kwargs, 165 | prompt=prompt, 166 | text_embeddings=text_embeddings 167 | ) 168 | else: 169 | noise_pred = self.unet( 170 | latent_model_input.detach() if do_detach else latent_model_input, 171 | t, 172 | encoder_hidden_states=prompt_embeds, 173 | cross_attention_kwargs=cross_attention_kwargs, 174 | return_dict=False, 175 | )[0] 176 | 177 | noise_pred = noise_pred.to(prompt_embeds.dtype) 178 | 179 | # perform guidance 180 | if do_classifier_free_guidance: 181 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 182 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 183 | 184 | if do_classifier_free_guidance and guidance_rescale > 0.0: 185 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf
186 |                 noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
187 | 
188 |             torch.set_grad_enabled((len(training_timesteps) == 0 or i >= min(training_timesteps)) and not double_laststep)
189 | 
190 |             # compute the previous noisy sample x_t -> x_t-1
191 |             out = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True)
192 |             latents = out.prev_sample
193 |             intermediate_result = out.pred_original_sample if hasattr(out, "pred_original_sample") else None
194 | 
195 |             # call the callback, if provided
196 |             if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
197 |                 if callback is not None and i % callback_steps == 0:
198 |                     callback(i, t, latents)
199 | 
200 |             if len(training_timesteps) > 0 and i == max(training_timesteps) and early_exit:
201 |                 latents = intermediate_result
202 |                 break
203 | 
204 |         if double_laststep:
205 |             torch.set_grad_enabled(double_laststep)
206 |             t = timesteps[training_timesteps[0]]
207 |             noise = torch.randn_like(latents)
208 |             noisy_latents = self.scheduler.add_noise(latents, noise, t)
209 | 
210 |             noisy_model_input = torch.cat([noisy_latents] * 2) if do_classifier_free_guidance else noisy_latents
211 |             noisy_model_input = self.scheduler.scale_model_input(noisy_model_input, t)
212 |             noise_pred = self.unet(
213 |                 noisy_model_input.detach() if do_detach else noisy_model_input,
214 |                 t,
215 |                 encoder_hidden_states=prompt_embeds,
216 |                 cross_attention_kwargs=cross_attention_kwargs,
217 |                 return_dict=False,
218 |             )[0]
219 |             # perform guidance
220 |             if do_classifier_free_guidance:
221 |                 noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
222 |                 noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
223 | 
224 |             if do_classifier_free_guidance and guidance_rescale > 0.0:
225 |                 # Based on 3.4.
in https://arxiv.org/pdf/2305.08891.pdf 226 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 227 | 228 | out = self.scheduler.step(noise_pred, t, noisy_latents, **extra_step_kwargs, return_dict=True) 229 | latents = out.prev_sample 230 | 231 | if output_type == "image": 232 | image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0] 233 | if return_latents: 234 | return (image / 2 + 0.5), latents 235 | return (image / 2 + 0.5) # .clamp(0, 1) 236 | elif output_type == "latent": 237 | return latents 238 | 239 | def _attrcon_forward( 240 | self, 241 | latents, # this should not detach 242 | t, 243 | prompt_embeds, 244 | cross_attention_kwargs=None, 245 | prompt=None, 246 | text_embeddings=None, 247 | ): 248 | 249 | # forward with the whole 250 | self.controller.reset() 251 | noise_cond = self.unet( 252 | latents[latents.shape[0]//2:], 253 | t, 254 | encoder_hidden_states=prompt_embeds[latents.shape[0]//2:], 255 | cross_attention_kwargs=cross_attention_kwargs, 256 | return_dict=False, 257 | )[0] 258 | 259 | # Get attention maps 260 | attn_dict = get_cross_attn_map_from_unet( 261 | attention_store=self.controller, 262 | is_training_sd21=False 263 | ) 264 | self.attn_dict[str(t.cpu().item())] = attn_dict 265 | self.controller.reset() 266 | 267 | # the first half(uncond) can be directly forward 268 | noise_uncond = self.unet( 269 | latents[:latents.shape[0]//2], 270 | t, 271 | encoder_hidden_states=prompt_embeds[:latents.shape[0]//2], 272 | cross_attention_kwargs=cross_attention_kwargs, 273 | return_dict=False, 274 | )[0] 275 | self.controller.reset() 276 | 277 | noise_pred = torch.cat([noise_uncond, noise_cond], dim=0) 278 | 279 | return noise_pred 280 | 281 | def _extract_attribution_indices(self, prompt): 282 | # extract standard attribution indices 283 | 284 | pairs = extract_attribution_indices(self.doc[prompt]) or [] 285 | 286 | # extract attribution indices with verbs in between 287 | pairs_2 = extract_attribution_indices_with_verb_root(self.doc[prompt]) or [] 288 | pairs_3 = extract_attribution_indices_with_verbs(self.doc[prompt]) or [] 289 | # make sure there are no duplicates 290 | pairs = unify_lists(pairs, pairs_2, pairs_3) 291 | 292 | # MOD: filter too long pairs 293 | pairs = [p for p in pairs if len(p)<4] 294 | 295 | paired_indices = self._align_indices(prompt, pairs) 296 | return paired_indices 297 | 298 | def _align_indices(self, prompt, spacy_pairs): 299 | wordpieces2indices = get_indices(self.tokenizer, prompt) 300 | paired_indices = [] 301 | collected_spacy_indices = ( 302 | set() 303 | ) # helps track recurring nouns across different relations (i.e., cases where there is more than one instance of the same word) 304 | 305 | for pair in spacy_pairs: 306 | curr_collected_wp_indices = ( 307 | [] 308 | ) # helps track which nouns and amods were added to the current pair (this is useful in sentences with repeating amod on the same relation (e.g., "a red red red bear")) 309 | for member in pair: 310 | for idx, wp in wordpieces2indices.items(): 311 | if wp in [start_token, end_token]: 312 | continue 313 | 314 | wp = wp.replace("", "") 315 | if member.text == wp: 316 | if idx not in curr_collected_wp_indices and idx not in collected_spacy_indices: 317 | curr_collected_wp_indices.append(idx) 318 | break 319 | # take care of wordpieces that are split up 320 | elif member.text.startswith(wp) and wp != member.text: # can maybe be while loop 321 | wp_indices = align_wordpieces_indices( 322 | 
wordpieces2indices, idx, member.text 323 | ) 324 | # check if all wp_indices are not already in collected_spacy_indices 325 | if wp_indices and (wp_indices not in curr_collected_wp_indices) and all([wp_idx not in collected_spacy_indices for wp_idx in wp_indices]): 326 | curr_collected_wp_indices.append(wp_indices) 327 | break 328 | 329 | for collected_idx in curr_collected_wp_indices: 330 | if isinstance(collected_idx, list): 331 | for idx in collected_idx: 332 | collected_spacy_indices.add(idx) 333 | else: 334 | collected_spacy_indices.add(collected_idx) 335 | 336 | paired_indices.append(curr_collected_wp_indices) 337 | 338 | return paired_indices 339 | 340 | 341 | def encode_prompt( 342 | self, 343 | prompt, 344 | device, 345 | num_images_per_prompt, 346 | do_classifier_free_guidance, 347 | negative_prompt=None, 348 | prompt_embeds: Optional[torch.FloatTensor] = None, 349 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 350 | lora_scale: Optional[float] = None, 351 | clip_skip: Optional[int] = None, 352 | ): 353 | r""" 354 | Encodes the prompt into text encoder hidden states. 355 | 356 | Args: 357 | prompt (`str` or `List[str]`, *optional*): 358 | prompt to be encoded 359 | device: (`torch.device`): 360 | torch device 361 | num_images_per_prompt (`int`): 362 | number of images that should be generated per prompt 363 | do_classifier_free_guidance (`bool`): 364 | whether to use classifier free guidance or not 365 | negative_prompt (`str` or `List[str]`, *optional*): 366 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 367 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 368 | less than `1`). 369 | prompt_embeds (`torch.FloatTensor`, *optional*): 370 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 371 | provided, text embeddings will be generated from `prompt` input argument. 372 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 373 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 374 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 375 | argument. 376 | lora_scale (`float`, *optional*): 377 | A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 378 | clip_skip (`int`, *optional*): 379 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 380 | the output of the pre-final layer will be used for computing the prompt embeddings. 
381 | """ 382 | # set lora scale so that monkey patched LoRA 383 | # function of text encoder can correctly access it 384 | if lora_scale is not None and isinstance(self, LoraLoaderMixin): 385 | self._lora_scale = lora_scale 386 | 387 | # dynamically adjust the LoRA scale 388 | if not USE_PEFT_BACKEND: 389 | adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) 390 | else: 391 | scale_lora_layers(self.text_encoder, lora_scale) 392 | 393 | if prompt is not None and isinstance(prompt, str): 394 | batch_size = 1 395 | elif prompt is not None and isinstance(prompt, list): 396 | batch_size = len(prompt) 397 | else: 398 | batch_size = prompt_embeds.shape[0] 399 | 400 | if prompt_embeds is None: 401 | # textual inversion: procecss multi-vector tokens if necessary 402 | if isinstance(self, TextualInversionLoaderMixin): 403 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer) 404 | 405 | text_inputs = self.tokenizer( 406 | prompt, 407 | padding="max_length", 408 | max_length=self.tokenizer.model_max_length, 409 | truncation=True, 410 | return_tensors="pt", 411 | ) 412 | text_input_ids = text_inputs.input_ids 413 | untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 414 | 415 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 416 | text_input_ids, untruncated_ids 417 | ): 418 | removed_text = self.tokenizer.batch_decode( 419 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 420 | ) 421 | logger.warning( 422 | "The following part of your input was truncated because CLIP can only handle sequences up to" 423 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 424 | ) 425 | 426 | # MOD: accomodate for training text encoder 427 | if isinstance(self.text_encoder, DDP): 428 | if hasattr(self.text_encoder.module.config, "use_attention_mask") and self.text_encoder.module.config.use_attention_mask: 429 | attention_mask = text_inputs.attention_mask.to(device) 430 | else: 431 | attention_mask = None 432 | else: 433 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 434 | attention_mask = text_inputs.attention_mask.to(device) 435 | else: 436 | attention_mask = None 437 | 438 | if clip_skip is None: 439 | prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) 440 | prompt_embeds = prompt_embeds[0] 441 | else: 442 | prompt_embeds = self.text_encoder( 443 | text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True 444 | ) 445 | # Access the `hidden_states` first, that contains a tuple of 446 | # all the hidden states from the encoder layers. Then index into 447 | # the tuple to access the hidden states from the desired layer. 448 | prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] 449 | # We also need to apply the final LayerNorm here to not mess with the 450 | # representations. The `last_hidden_states` that we typically use for 451 | # obtaining the final prompt representations passes through the LayerNorm 452 | # layer. 
453 | prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) 454 | 455 | if self.text_encoder is not None and isinstance(self.text_encoder, DDP): 456 | prompt_embeds_dtype = self.text_encoder.module.dtype 457 | elif self.text_encoder is not None: 458 | prompt_embeds_dtype = self.text_encoder.dtype 459 | elif self.unet is not None: 460 | prompt_embeds_dtype = self.unet.dtype 461 | else: 462 | prompt_embeds_dtype = prompt_embeds.dtype 463 | 464 | prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 465 | 466 | bs_embed, seq_len, _ = prompt_embeds.shape 467 | # duplicate text embeddings for each generation per prompt, using mps friendly method 468 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 469 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 470 | 471 | # get unconditional embeddings for classifier free guidance 472 | if do_classifier_free_guidance and negative_prompt_embeds is None: 473 | uncond_tokens: List[str] 474 | if negative_prompt is None: 475 | uncond_tokens = [""] * batch_size 476 | elif prompt is not None and type(prompt) is not type(negative_prompt): 477 | raise TypeError( 478 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 479 | f" {type(prompt)}." 480 | ) 481 | elif isinstance(negative_prompt, str): 482 | uncond_tokens = [negative_prompt] 483 | elif batch_size != len(negative_prompt): 484 | raise ValueError( 485 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 486 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 487 | " the batch size of `prompt`." 488 | ) 489 | else: 490 | uncond_tokens = negative_prompt 491 | 492 | # textual inversion: procecss multi-vector tokens if necessary 493 | if isinstance(self, TextualInversionLoaderMixin): 494 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) 495 | 496 | max_length = prompt_embeds.shape[1] 497 | uncond_input = self.tokenizer( 498 | uncond_tokens, 499 | padding="max_length", 500 | max_length=max_length, 501 | truncation=True, 502 | return_tensors="pt", 503 | ) 504 | 505 | # MOD: accomodate for training text encoder 506 | if isinstance(self.text_encoder, DDP): 507 | if hasattr(self.text_encoder.module.config, "use_attention_mask") and self.text_encoder.module.config.use_attention_mask: 508 | attention_mask = text_inputs.attention_mask.to(device) 509 | else: 510 | attention_mask = None 511 | else: 512 | if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: 513 | attention_mask = text_inputs.attention_mask.to(device) 514 | else: 515 | attention_mask = None 516 | 517 | 518 | negative_prompt_embeds = self.text_encoder( 519 | uncond_input.input_ids.to(device), 520 | attention_mask=attention_mask, 521 | ) 522 | negative_prompt_embeds = negative_prompt_embeds[0] 523 | 524 | if do_classifier_free_guidance: 525 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 526 | seq_len = negative_prompt_embeds.shape[1] 527 | 528 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 529 | 530 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 531 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 532 | 533 | if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: 
534 | # Retrieve the original scale by scaling back the LoRA layers 535 | unscale_lora_layers(self.text_encoder, lora_scale) 536 | 537 | return prompt_embeds, negative_prompt_embeds 538 | 539 | def is_sublist(sub, main): 540 | # This function checks if 'sub' is a sublist of 'main' 541 | return len(sub) < len(main) and all(item in main for item in sub) 542 | 543 | def unify_lists(lists_1, lists_2, lists_3): 544 | unified_list = lists_1 + lists_2 + lists_3 545 | sorted_list = sorted(unified_list, key=len) 546 | seen = set() 547 | 548 | result = [] 549 | 550 | for i in range(len(sorted_list)): 551 | if tuple(sorted_list[i]) in seen: # Skip if already added 552 | continue 553 | 554 | sublist_to_add = True 555 | for j in range(i + 1, len(sorted_list)): 556 | if is_sublist(sorted_list[i], sorted_list[j]): 557 | sublist_to_add = False 558 | break 559 | 560 | if sublist_to_add: 561 | result.append(sorted_list[i]) 562 | seen.add(tuple(sorted_list[i])) 563 | 564 | return result -------------------------------------------------------------------------------- /AttrConcenTrainableSDXLPipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 2 | import torch 3 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import rescale_noise_cfg 4 | 5 | from typing import Any, Callable, Dict, List, Optional, Union 6 | 7 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection 8 | 9 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 10 | from diffusers.schedulers import KarrasDiffusionSchedulers 11 | 12 | from TrainableSDPipeline import TrainableSDXLPipeline 13 | 14 | import spacy 15 | from attribute_concen_utils import * 16 | 17 | from attn_utils.tc_attn_utils import get_cross_attn_map_from_unet 18 | 19 | 20 | class AttrConcenTrainableSDXLPipeline(TrainableSDXLPipeline): 21 | def __init__(self, vae: AutoencoderKL, text_encoder: CLIPTextModel, text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, tokenizer_2: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, force_zeros_for_empty_prompt: bool = True, add_watermarker: Optional[bool] = None): 22 | super().__init__(vae, text_encoder, text_encoder_2, tokenizer, tokenizer_2, unet, scheduler, force_zeros_for_empty_prompt, add_watermarker) 23 | self.parser = spacy.load("en_core_web_trf") 24 | self.doc = None 25 | 26 | self.attn_dict = {} 27 | 28 | @torch.no_grad() 29 | def __call__( 30 | self, 31 | prompt: Union[str, List[str]] = None, 32 | prompt_2: Optional[Union[str, List[str]]] = None, 33 | height: Optional[int] = None, 34 | width: Optional[int] = None, 35 | num_inference_steps: int = 50, 36 | denoising_end: Optional[float] = None, 37 | guidance_scale: float = 5.0, 38 | negative_prompt: Optional[Union[str, List[str]]] = None, 39 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 40 | num_images_per_prompt: Optional[int] = 1, 41 | eta: float = 0.0, 42 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 43 | latents: Optional[torch.FloatTensor] = None, 44 | prompt_embeds: Optional[torch.FloatTensor] = None, 45 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 46 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 47 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 48 | output_type: Optional[str] = "pil", 49 | return_dict: bool = True, 50 | callback: Optional[Callable[[int, int, 
torch.FloatTensor], None]] = None, 51 | callback_steps: int = 1, 52 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 53 | guidance_rescale: float = 0.0, 54 | original_size: Optional[Tuple[int, int]] = None, 55 | crops_coords_top_left: Tuple[int, int] = (0, 0), 56 | target_size: Optional[Tuple[int, int]] = None, 57 | negative_original_size: Optional[Tuple[int, int]] = None, 58 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0), 59 | negative_target_size: Optional[Tuple[int, int]] = None, 60 | ): 61 | from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput 62 | 63 | # 0. Default height and width to unet 64 | height = height or self.default_sample_size * self.vae_scale_factor 65 | width = width or self.default_sample_size * self.vae_scale_factor 66 | 67 | original_size = original_size or (height, width) 68 | target_size = target_size or (height, width) 69 | 70 | # 2. Define call parameters 71 | if prompt is not None and isinstance(prompt, str): 72 | batch_size = 1 73 | elif prompt is not None and isinstance(prompt, list): 74 | batch_size = len(prompt) 75 | else: 76 | batch_size = prompt_embeds.shape[0] 77 | 78 | device = self._execution_device 79 | 80 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 81 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 82 | # corresponds to doing no classifier free guidance. 83 | do_classifier_free_guidance = guidance_scale > 1.0 84 | 85 | # 3. Encode input prompt 86 | text_encoder_lora_scale = ( 87 | cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None 88 | ) 89 | ( 90 | prompt_embeds, 91 | negative_prompt_embeds, 92 | pooled_prompt_embeds, 93 | negative_pooled_prompt_embeds, 94 | ) = self.encode_prompt( 95 | prompt=prompt, 96 | prompt_2=prompt_2, 97 | device=device, 98 | num_images_per_prompt=num_images_per_prompt, 99 | do_classifier_free_guidance=do_classifier_free_guidance, 100 | negative_prompt=negative_prompt, 101 | negative_prompt_2=negative_prompt_2, 102 | prompt_embeds=prompt_embeds, 103 | negative_prompt_embeds=negative_prompt_embeds, 104 | pooled_prompt_embeds=pooled_prompt_embeds, 105 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 106 | lora_scale=text_encoder_lora_scale, 107 | ) 108 | 109 | # 4. Prepare timesteps 110 | self.scheduler.set_timesteps(num_inference_steps, device=device) 111 | 112 | timesteps = self.scheduler.timesteps 113 | 114 | # 5. Prepare latent variables 115 | num_channels_latents = self.unet.config.in_channels 116 | latents = self.prepare_latents( 117 | batch_size * num_images_per_prompt, 118 | num_channels_latents, 119 | height, 120 | width, 121 | prompt_embeds.dtype, 122 | device, 123 | generator, 124 | latents, 125 | ) 126 | 127 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 128 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 129 | 130 | # 7. 
Prepare added time ids & embeddings 131 | add_text_embeds = pooled_prompt_embeds 132 | add_time_ids = self._get_add_time_ids( 133 | original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype 134 | ) 135 | if negative_original_size is not None and negative_target_size is not None: 136 | negative_add_time_ids = self._get_add_time_ids( 137 | negative_original_size, 138 | negative_crops_coords_top_left, 139 | negative_target_size, 140 | dtype=prompt_embeds.dtype, 141 | ) 142 | else: 143 | negative_add_time_ids = add_time_ids 144 | 145 | if do_classifier_free_guidance: 146 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 147 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 148 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) 149 | 150 | prompt_embeds = prompt_embeds.to(device) 151 | add_text_embeds = add_text_embeds.to(device) 152 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) 153 | 154 | # 8. Denoising loop 155 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 156 | 157 | # 7.1 Apply denoising_end 158 | if denoising_end is not None and isinstance(denoising_end, float) and denoising_end > 0 and denoising_end < 1: 159 | discrete_timestep_cutoff = int( 160 | round( 161 | self.scheduler.config.num_train_timesteps 162 | - (denoising_end * self.scheduler.config.num_train_timesteps) 163 | ) 164 | ) 165 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) 166 | timesteps = timesteps[:num_inference_steps] 167 | 168 | 169 | for i, t in enumerate(timesteps): 170 | # expand the latents if we are doing classifier free guidance 171 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 172 | 173 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 174 | 175 | # predict the noise residual 176 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 177 | noise_pred = self.unet( 178 | latent_model_input, 179 | t, 180 | encoder_hidden_states=prompt_embeds, 181 | cross_attention_kwargs=cross_attention_kwargs, 182 | added_cond_kwargs=added_cond_kwargs, 183 | return_dict=False, 184 | )[0] 185 | 186 | # perform guidance 187 | if do_classifier_free_guidance: 188 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 189 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 190 | 191 | if do_classifier_free_guidance and guidance_rescale > 0.0: 192 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf 193 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 194 | 195 | # compute the previous noisy sample x_t -> x_t-1 196 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 197 | 198 | # call the callback, if provided 199 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 200 | if callback is not None and i % callback_steps == 0: 201 | callback(i, t, latents) 202 | 203 | if not output_type == "latent": 204 | # make sure the VAE is in float32 mode, as it overflows in float16 205 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 206 | 207 | if needs_upcasting: 208 | self.upcast_vae() 209 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 210 | 211 | image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0] 212 | 213 | # cast back to fp16 if needed 214 | if needs_upcasting: 215 | self.vae.to(dtype=torch.float16) 216 | else: 217 | image = latents 218 | 219 | if not output_type == "latent": 220 | # apply watermark if available 221 | # if self.watermark is not None: 222 | # image = self.watermark.apply_watermark(image) 223 | 224 | image = self.image_processor.postprocess(image, output_type=output_type) 225 | 226 | # Offload all models 227 | self.maybe_free_model_hooks() 228 | 229 | if not return_dict: 230 | return (image,) 231 | 232 | return StableDiffusionXLPipelineOutput(images=image) 233 | 234 | def forward( 235 | self, 236 | prompt: Union[str, List[str]] = None, 237 | prompt_2: Optional[Union[str, List[str]]] = None, 238 | height: int = 512, 239 | width: int = 512, 240 | training_timesteps: Optional[List[int]] = [], 241 | early_exit: bool = False, 242 | detach_gradient: bool = True, 243 | train_text_encoder: bool = False, 244 | num_inference_steps: int = 50, 245 | denoising_end: Optional[float] = None, 246 | guidance_scale: float = 5.0, 247 | negative_prompt: Optional[Union[str, List[str]]] = None, 248 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 249 | num_images_per_prompt: Optional[int] = 1, 250 | eta: float = 0.0, 251 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 252 | latents: Optional[torch.FloatTensor] = None, 253 | prompt_embeds: Optional[torch.FloatTensor] = None, 254 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 255 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 256 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 257 | output_type: Optional[str] = "image", 258 | return_dict: bool = True, 259 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 260 | callback_steps: int = 1, 261 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 262 | guidance_rescale: float = 0.0, 263 | original_size: Optional[Tuple[int, int]] = None, 264 | crops_coords_top_left: Tuple[int, int] = (0, 0), 265 | target_size: Optional[Tuple[int, int]] = None, 266 | return_latents: bool = False, # new training parameter 267 | batch=None, 268 | attrcon_train_steps=None 269 | ): 270 | 271 | self.doc = {} 272 | for p in prompt: 273 | self.doc[p] = self.parser(p) 274 | 275 | original_size = original_size or (height, width) 276 | target_size = target_size or (height, width) 277 | 278 | # 1. Check inputs. 
Raise error if not correct 279 | # self.check_inputs( 280 | # prompt, 281 | # prompt_2, 282 | # height, 283 | # width, 284 | # callback_steps, 285 | # negative_prompt, 286 | # negative_prompt_2, 287 | # prompt_embeds, 288 | # negative_prompt_embeds, 289 | # pooled_prompt_embeds, 290 | # negative_pooled_prompt_embeds, 291 | # ) 292 | 293 | # 2. Define call parameters 294 | if prompt is not None and isinstance(prompt, str): 295 | batch_size = 1 296 | elif prompt is not None and isinstance(prompt, list): 297 | batch_size = len(prompt) 298 | else: 299 | batch_size = prompt_embeds.shape[0] 300 | 301 | device = self._execution_device 302 | 303 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 304 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 305 | # corresponds to doing no classifier free guidance. 306 | do_classifier_free_guidance = guidance_scale > 1.0 307 | torch.set_grad_enabled(train_text_encoder) 308 | # 3. Encode input prompt 309 | text_encoder_lora_scale = ( 310 | cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None 311 | ) 312 | ( 313 | prompt_embeds, 314 | negative_prompt_embeds, 315 | pooled_prompt_embeds, 316 | negative_pooled_prompt_embeds, 317 | ) = self.encode_prompt( 318 | prompt=prompt, 319 | prompt_2=prompt_2, 320 | device=device, 321 | num_images_per_prompt=num_images_per_prompt, 322 | do_classifier_free_guidance=do_classifier_free_guidance, 323 | negative_prompt=negative_prompt, 324 | negative_prompt_2=negative_prompt_2, 325 | prompt_embeds=prompt_embeds, 326 | negative_prompt_embeds=negative_prompt_embeds, 327 | pooled_prompt_embeds=pooled_prompt_embeds, 328 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 329 | lora_scale=text_encoder_lora_scale, 330 | ) 331 | # 4. Prepare timesteps 332 | self.scheduler.set_timesteps(num_inference_steps, device=device) 333 | timesteps = self.scheduler.timesteps 334 | 335 | # 5. Prepare latent variables 336 | latents = self.prepare_latents( 337 | batch_size * num_images_per_prompt, 338 | 4, 339 | height, 340 | width, 341 | prompt_embeds.dtype, 342 | device, 343 | generator, 344 | latents, 345 | ) 346 | 347 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 348 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 349 | torch.set_grad_enabled(True) 350 | # 7. Prepare added time ids & embeddings 351 | add_text_embeds = pooled_prompt_embeds 352 | add_time_ids = self._get_add_time_ids( 353 | original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype 354 | ) 355 | 356 | if do_classifier_free_guidance: 357 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 358 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 359 | add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) 360 | 361 | # mask training_steps: 362 | attrcon_train_steps = attrcon_train_steps if attrcon_train_steps is not None else \ 363 | training_timesteps 364 | 365 | prompt_embeds = prompt_embeds.to(device) 366 | add_text_embeds = add_text_embeds.to(device) 367 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) 368 | # 8. 
Denoising loop 369 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 370 | 371 | # 7.1 Apply denoising_end 372 | if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: 373 | discrete_timestep_cutoff = int( 374 | round( 375 | self.scheduler.config.num_train_timesteps 376 | - (denoising_end * self.scheduler.config.num_train_timesteps) 377 | ) 378 | ) 379 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) 380 | timesteps = timesteps[:num_inference_steps] 381 | 382 | for i, t in enumerate(timesteps): 383 | # expand the latents if we are doing classifier free guidance 384 | torch.set_grad_enabled(i > min(training_timesteps)) 385 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 386 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 387 | 388 | torch.set_grad_enabled(i in training_timesteps) 389 | # predict the noise residual 390 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 391 | 392 | # add mask loss 393 | if i in training_timesteps and i in attrcon_train_steps: 394 | noise_pred = self._attrcon_forward( 395 | latent_model_input.detach() if detach_gradient else latent_model_input, # this should not detach 396 | t, 397 | prompt_embeds=prompt_embeds, 398 | cross_attention_kwargs=cross_attention_kwargs, 399 | added_cond_kwargs=added_cond_kwargs, 400 | return_dict=False, 401 | ) 402 | else: 403 | noise_pred = self.unet( 404 | latent_model_input.detach() if detach_gradient else latent_model_input, 405 | t, 406 | encoder_hidden_states=prompt_embeds, 407 | cross_attention_kwargs=cross_attention_kwargs, 408 | added_cond_kwargs=added_cond_kwargs, 409 | return_dict=False, 410 | )[0] 411 | 412 | 413 | # perform guidance 414 | if do_classifier_free_guidance: 415 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 416 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 417 | 418 | if do_classifier_free_guidance and guidance_rescale > 0.0: 419 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf 420 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 421 | 422 | torch.set_grad_enabled(i >= min(training_timesteps)) 423 | # compute the previous noisy sample x_t -> x_t-1 424 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 425 | 426 | # call the callback, if provided 427 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 428 | if callback is not None and i % callback_steps == 0: 429 | callback(i, t, latents) 430 | 431 | if i == max(training_timesteps) and early_exit: 432 | break 433 | 434 | # make sure the VAE is in float32 mode, as it overflows in float16 435 | #if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: 436 | # self.upcast_vae() 437 | # latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 438 | 439 | if return_latents: 440 | image = self.vae.decode(latents.half() / self.vae.config.scaling_factor, return_dict=False)[0] 441 | return image, latents.half() 442 | 443 | if output_type == "image": 444 | image = self.vae.decode(latents.half() / self.vae.config.scaling_factor, return_dict=False)[0] 445 | return (image / 2 + 0.5) # .clamp(0, 1) 446 | elif output_type == "latent": 447 | return latents 448 | 449 | def _attrcon_forward( 450 | self, 451 | latents, # this should not detach 452 | t, 453 | prompt_embeds, 454 | cross_attention_kwargs, 455 | added_cond_kwargs, 456 | return_dict=False, 457 | prompts=None 458 | ): 459 | added_cond_kwargs_front, added_cond_kwargs_back = {}, {} 460 | for k, v in added_cond_kwargs.items(): 461 | v_front, v_back = v.chunk(2) 462 | added_cond_kwargs_front[k] = v_front 463 | added_cond_kwargs_back[k] = v_back 464 | # forward with the whole 465 | self.controller.reset() 466 | noise_cond = self.unet( 467 | latents[latents.shape[0]//2:], 468 | t, 469 | encoder_hidden_states=prompt_embeds[latents.shape[0]//2:], 470 | cross_attention_kwargs=cross_attention_kwargs, 471 | added_cond_kwargs=added_cond_kwargs_back, 472 | return_dict=False, 473 | )[0] 474 | 475 | # Get attention maps 476 | attn_dict = get_cross_attn_map_from_unet( 477 | attention_store=self.controller, 478 | is_training_sd21=False 479 | ) 480 | self.attn_dict[str(t.cpu().item())] = attn_dict 481 | self.controller.reset() 482 | 483 | # the first half(uncond) can be directly forward 484 | noise_uncond = self.unet( 485 | latents[:latents.shape[0]//2], 486 | t, 487 | encoder_hidden_states=prompt_embeds[:latents.shape[0]//2], 488 | cross_attention_kwargs=cross_attention_kwargs, 489 | added_cond_kwargs=added_cond_kwargs_front, 490 | return_dict=False, 491 | )[0] 492 | self.controller.reset() 493 | 494 | noise_pred = torch.cat([noise_uncond, noise_cond], dim=0) 495 | 496 | return noise_pred 497 | 498 | def _extract_attribution_indices(self, prompt): 499 | # extract standard attribution indices 500 | 501 | pairs = extract_attribution_indices(self.doc[prompt]) or [] 502 | 503 | # extract attribution indices with verbs in between 504 | pairs_2 = extract_attribution_indices_with_verb_root(self.doc[prompt]) or [] 505 | pairs_3 = extract_attribution_indices_with_verbs(self.doc[prompt]) or [] 506 | # make sure there are no duplicates 507 | pairs = unify_lists(pairs, pairs_2, pairs_3) 508 | 509 | # MOD: filter too long pairs 510 | pairs = [p for p in pairs if len(p)<4] 511 | 512 | paired_indices = self._align_indices(prompt, pairs) 513 | return paired_indices 514 | 515 | 516 | def 
_align_indices(self, prompt, spacy_pairs): 517 | wordpieces2indices = get_indices(self.tokenizer, prompt) 518 | paired_indices = [] 519 | collected_spacy_indices = ( 520 | set() 521 | ) # helps track recurring nouns across different relations (i.e., cases where there is more than one instance of the same word) 522 | 523 | for pair in spacy_pairs: 524 | curr_collected_wp_indices = ( 525 | [] 526 | ) # helps track which nouns and amods were added to the current pair (this is useful in sentences with repeating amod on the same relation (e.g., "a red red red bear")) 527 | for member in pair: 528 | for idx, wp in wordpieces2indices.items(): 529 | if wp in [start_token, end_token]: 530 | continue 531 | 532 | wp = wp.replace("", "") 533 | if member.text == wp: 534 | if idx not in curr_collected_wp_indices and idx not in collected_spacy_indices: 535 | curr_collected_wp_indices.append(idx) 536 | break 537 | # take care of wordpieces that are split up 538 | elif member.text.startswith(wp) and wp != member.text: # can maybe be while loop 539 | wp_indices = align_wordpieces_indices( 540 | wordpieces2indices, idx, member.text 541 | ) 542 | # check if all wp_indices are not already in collected_spacy_indices 543 | if wp_indices and (wp_indices not in curr_collected_wp_indices) and all([wp_idx not in collected_spacy_indices for wp_idx in wp_indices]): 544 | curr_collected_wp_indices.append(wp_indices) 545 | break 546 | 547 | for collected_idx in curr_collected_wp_indices: 548 | if isinstance(collected_idx, list): 549 | for idx in collected_idx: 550 | collected_spacy_indices.add(idx) 551 | else: 552 | collected_spacy_indices.add(collected_idx) 553 | 554 | paired_indices.append(curr_collected_wp_indices) 555 | 556 | return paired_indices 557 | 558 | # mask loss 559 | 560 | def is_sublist(sub, main): 561 | # This function checks if 'sub' is a sublist of 'main' 562 | return len(sub) < len(main) and all(item in main for item in sub) 563 | 564 | 565 | def unify_lists(lists_1, lists_2, lists_3): 566 | unified_list = lists_1 + lists_2 + lists_3 567 | sorted_list = sorted(unified_list, key=len) 568 | seen = set() 569 | 570 | result = [] 571 | 572 | for i in range(len(sorted_list)): 573 | if tuple(sorted_list[i]) in seen: # Skip if already added 574 | continue 575 | 576 | sublist_to_add = True 577 | for j in range(i + 1, len(sorted_list)): 578 | if is_sublist(sorted_list[i], sorted_list[j]): 579 | sublist_to_add = False 580 | break 581 | 582 | if sublist_to_add: 583 | result.append(sorted_list[i]) 584 | seen.add(tuple(sorted_list[i])) 585 | 586 | return result -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 💫CoMat: Aligning Text-to-Image Diffusion Model with Image-to-Text Concept Matching 2 | 3 | Official repository for the paper "[CoMat: Aligning Text-to-Image Diffusion Model with Image-to-Text Concept Matching](https://arxiv.org/pdf/2404.03653.pdf)". 4 | 5 | 🌟 For more details, please refer to the project page: [https://caraj7.github.io/comat/](https://caraj7.github.io/comat/). 6 | 7 | [[🌐 Webpage](https://caraj7.github.io/comat/)] [[📖 Paper](https://arxiv.org/pdf/2404.03653.pdf)] 8 | 9 | 10 | 11 | 12 | ## 💥 News 13 | 14 | - **[2024.09.26]** 🎉 CoMat is accepted by Neurips 2024! 15 | 16 | - **[2024.04.30]** 🔥 We release the training code of CoMat. 17 | 18 | - **[2024.04.05]** 🚀 We release our paper on [arXiv](https://arxiv.org/pdf/2404.03653.pdf). 
19 | 
20 | 
21 | 
22 | ## 👀 About CoMat
23 | 
24 | We propose 💫CoMat, an end-to-end diffusion model fine-tuning strategy with an image-to-text concept matching mechanism. We leverage an image captioning model to measure image-to-text alignment and guide the diffusion model to revisit ignored tokens.
25 | 
26 | ![demo](fig/demo.png)
27 | 
28 | ## 🔨Usage
29 | 
30 | ### Install
31 | 
32 | Install the requirements first. We have verified the environment with the versions in `requirements.txt`, but newer versions are also expected to work.
33 | 
34 | ```bash
35 | pip install -r requirements.txt
36 | ```
37 | 
38 | The Attribute Concentration module requires [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything) to find the masks of the entities. Please run the following commands to install Grounded-SAM.
39 | 
40 | ```bash
41 | mkdir seg_model
42 | cd seg_model
43 | git clone https://github.com/IDEA-Research/Grounded-Segment-Anything.git
44 | mv Grounded-Segment-Anything gsam
45 | cd gsam/GroundingDINO
46 | pip install -e .
47 | ```
48 | 
49 | 
50 | 
51 | ### Training
52 | 
53 | We currently support [SD1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5) and [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0). Other Stable Diffusion v1 variants, e.g., [SD1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4), should also be supported, as they share the same architecture as SD1.5.
54 | 
55 | #### SD1.5
56 | 
57 | SD1.5 can be trained directly.
58 | 
59 | First, we need to generate the latents used in the Fidelity Preservation module.
60 | 
61 | ```bash
62 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
63 | python tools/gan_gt_generate.py \
64 | --prompt-path merged_data/abc5k_hrs10k_t2icompall_20k.txt \
65 | --save-prompt-path train_data_sd15/gan_gt_data.jsonl \
66 | --model-path runwayml/stable-diffusion-v1-5 \
67 | --model-type sd_1_5
68 | ```
69 | 
70 | Then we start training.
71 | 
72 | ```bash
73 | bash scripts/sd15.sh
74 | ```
75 | 
76 | #### SDXL
77 | 
78 | We recommend first fine-tuning the UNet of SDXL at a resolution of 512\*512 to enable fast convergence, since the original SDXL generates images of poor quality at 512\*512.
79 | 
80 | For the fine-tuning data, we directly use SDXL to generate 1024\*1024 images from the training prompts. We then resize the generated images to 512\*512 and use them to fine-tune SDXL for only 100 steps with the [script](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_sdxl.py) from diffusers. We will release the fine-tuned UNet later.
81 | 
82 | Then we generate the latents with the fine-tuned UNet.
83 | 
84 | ```bash
85 | PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
86 | python tools/gan_gt_generate.py \
87 | --prompt-path merged_data/abc5k_hrs10k_t2icompall_20k.txt \
88 | --save-prompt-path train_data_sdxl/gan_gt_data.jsonl \
89 | --unet-path FINETUNED_UNET_PATH \
90 | --model-path stabilityai/stable-diffusion-xl-base-1.0 \
91 | --model-type sdxl_unet
92 | ```
93 | 
94 | Finally, we start training:
95 | 
96 | ```bash
97 | bash scripts/sdxl.sh
98 | ```
99 | 
100 | 
101 | 
102 | ## 📌 TODO
103 | 
104 | - [ ] Release the checkpoints.
105 | 
106 | - [x] Release training code in April.
107 | 108 | 109 | 110 | ## :white_check_mark: Citation 111 | 112 | If you find **CoMat** useful for your research and applications, please kindly cite using this BibTeX: 113 | 114 | ```latex 115 | @article{jiang2024comat, 116 | title={CoMat: Aligning Text-to-Image Diffusion Model with Image-to-Text Concept Matching}, 117 | author={Jiang, Dongzhi and Song, Guanglu and Wu, Xiaoshi and Zhang, Renrui and Shen, Dazhong and Zong, Zhuofan and Liu, Yu and Li, Hongsheng}, 118 | journal={arXiv preprint arXiv:2404.03653}, 119 | year={2024} 120 | } 121 | ``` 122 | 123 | 124 | 125 | ## 👍Thanks 126 | 127 | We would like to thank [Grounded-SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything), [TokenCompose](https://github.com/mlpc-ucsd/TokenCompose), and [Diffusers](https://github.com/huggingface/diffusers). 128 | -------------------------------------------------------------------------------- /attn_utils/tc_attn_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | class AttentionControl and class AttentionStore are modified from 3 | https://github.com/mlpc-ucsd/TokenCompose/blob/main/train/src/attn_utils.py 4 | https://github.com/google/prompt-to-prompt/blob/main/prompt-to-prompt_stable.ipynb 5 | https://github.com/google/prompt-to-prompt/blob/main/ptp_utils.py 6 | ''' 7 | 8 | 9 | import abc 10 | import torch 11 | import pdb 12 | 13 | 14 | LOW_RESOURCE = False 15 | SD14_TO_SD21_RATIO = 1.5 16 | 17 | class AttentionControl(abc.ABC): 18 | 19 | def step_callback(self, x_t): 20 | return x_t 21 | 22 | def between_steps(self): 23 | return 24 | 25 | @property 26 | def num_uncond_att_layers(self): 27 | return self.num_att_layers if LOW_RESOURCE else 0 28 | 29 | @abc.abstractmethod 30 | def forward (self, attn, is_cross: bool, place_in_unet: str): 31 | raise NotImplementedError 32 | 33 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 34 | attn = self.forward(attn, is_cross, place_in_unet) 35 | # attn: 8 * res * res 36 | 37 | self.cur_att_layer += 1 38 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 39 | self.cur_att_layer = 0 40 | self.cur_step += 1 41 | self.between_steps() 42 | return attn 43 | 44 | def reset(self): 45 | self.cur_step = 0 46 | self.cur_att_layer = 0 47 | 48 | def __init__(self): 49 | self.cur_step = 0 50 | self.num_att_layers = -1 51 | self.cur_att_layer = 0 52 | 53 | class AttentionStore(AttentionControl): 54 | 55 | @staticmethod 56 | def get_empty_store(): 57 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 58 | "down_self": [], "mid_self": [], "up_self": []} 59 | 60 | def forward(self, attn, is_cross: bool, place_in_unet: str): 61 | # if is_cross: 62 | 63 | # only store cross attn 64 | if is_cross and place_in_unet in self.train_layer_place: 65 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 66 | self.step_store[key].append(attn.clone()) 67 | 68 | return attn 69 | 70 | def between_steps(self): 71 | 72 | self.cur_step = 0 73 | self.attention_store = self.step_store 74 | self.step_store = self.get_empty_store() 75 | 76 | def get_average_attention(self): 77 | # debug: make sure the attn map will not be changed without clone 78 | average_attention = {key: [item for item in self.attention_store[key]] for key in self.attention_store} 79 | return average_attention 80 | 81 | def reset(self): 82 | super(AttentionStore, self).reset() 83 | self.step_store = self.get_empty_store() 84 | self.attention_store = {} 85 | 86 | def __init__(self, train_layer_ls): 87 | 
super(AttentionStore, self).__init__() 88 | self.step_store = self.get_empty_store() 89 | self.attention_store = {} 90 | # only store training layer 91 | train_layer_place = [] 92 | for inst in train_layer_ls: 93 | train_layer_place.append(inst.split('_')[0]) 94 | self.train_layer_place = list(set(train_layer_place)) 95 | 96 | def register_attention_control(unet_model, controller): 97 | def ca_forward(self, place_in_unet): 98 | to_out = self.to_out 99 | if type(to_out) is torch.nn.modules.container.ModuleList: 100 | to_out = self.to_out[0] 101 | else: 102 | to_out = self.to_out 103 | 104 | def forward(hidden_states, encoder_hidden_states=None, attention_mask=None,temb=None,): 105 | is_cross = encoder_hidden_states is not None 106 | 107 | residual = hidden_states 108 | 109 | if self.spatial_norm is not None: 110 | hidden_states = self.spatial_norm(hidden_states, temb) 111 | 112 | input_ndim = hidden_states.ndim 113 | 114 | if input_ndim == 4: 115 | batch_size, channel, height, width = hidden_states.shape 116 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 117 | 118 | batch_size, sequence_length, _ = ( 119 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 120 | ) 121 | attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) 122 | 123 | if self.group_norm is not None: 124 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 125 | 126 | query = self.to_q(hidden_states) 127 | 128 | if encoder_hidden_states is None: 129 | encoder_hidden_states = hidden_states 130 | elif self.norm_cross: 131 | encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) 132 | 133 | key = self.to_k(encoder_hidden_states) 134 | value = self.to_v(encoder_hidden_states) 135 | 136 | query = self.head_to_batch_dim(query) 137 | key = self.head_to_batch_dim(key) 138 | value = self.head_to_batch_dim(value) 139 | 140 | attention_probs = self.get_attention_scores(query, key, attention_mask) 141 | 142 | if attention_probs.requires_grad: 143 | attention_probs = controller(attention_probs, is_cross, place_in_unet) 144 | 145 | hidden_states = torch.bmm(attention_probs, value) 146 | hidden_states = self.batch_to_head_dim(hidden_states) 147 | 148 | # linear proj 149 | hidden_states = to_out(hidden_states) 150 | # all drop out in diffusers are 0.0 151 | # so we here ignore dropout 152 | 153 | if input_ndim == 4: 154 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 155 | 156 | if self.residual_connection: 157 | hidden_states = hidden_states + residual 158 | 159 | hidden_states = hidden_states / self.rescale_output_factor 160 | 161 | return hidden_states 162 | return forward 163 | 164 | assert controller is not None, "controller must be specified" 165 | 166 | def register_recr(net_, count, place_in_unet): 167 | if net_.__class__.__name__ == 'Attention': 168 | net_.forward = ca_forward(net_, place_in_unet) 169 | return count + 1 170 | elif hasattr(net_, 'children'): 171 | for net__ in net_.children(): 172 | count = register_recr(net__, count, place_in_unet) 173 | return count 174 | 175 | down_count = 0 176 | up_count = 0 177 | mid_count = 0 178 | 179 | cross_att_count = 0 180 | sub_nets = unet_model.named_children() 181 | for net in sub_nets: 182 | 183 | if "down" in net[0]: 184 | down_temp = register_recr(net[1], 0, "down") 185 | cross_att_count += down_temp 186 | down_count += down_temp 187 | elif "up" in net[0]: 188 | up_temp = 
register_recr(net[1], 0, "up") 189 | cross_att_count += up_temp 190 | up_count += up_temp 191 | elif "mid" in net[0]: 192 | mid_temp = register_recr(net[1], 0, "mid") 193 | cross_att_count += mid_temp 194 | mid_count += mid_temp 195 | 196 | controller.num_att_layers = cross_att_count 197 | 198 | def get_cross_attn_map_from_unet(attention_store: AttentionStore, is_training_sd21, 199 | reses=[64, 32, 16, 8], poses=["down", "mid", "up"]): 200 | attention_maps = attention_store.get_average_attention() 201 | 202 | attn_dict = {} 203 | 204 | if is_training_sd21: 205 | reses = [int(SD14_TO_SD21_RATIO * item) for item in reses] 206 | 207 | for pos in poses: 208 | for res in reses: 209 | temp_list = [] 210 | for item in attention_maps[f"{pos}_cross"]: 211 | if item.shape[1] == res ** 2: 212 | cross_maps = item.reshape(-1, res, res, item.shape[-1]) 213 | temp_list.append(cross_maps) 214 | # if such resolution exists 215 | if len(temp_list) > 0: 216 | attn_dict[f"{pos}_{res}"] = temp_list 217 | return attn_dict -------------------------------------------------------------------------------- /attn_utils/tc_loss_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Modified from 3 | https://github.com/mlpc-ucsd/TokenCompose/blob/main/train/src/loss_utils.py 4 | 5 | ''' 6 | import torch 7 | import torch.nn as nn 8 | from torchvision import transforms 9 | from copy import deepcopy 10 | import pdb 11 | 12 | SD14_TO_SD21_RATIO = 1.5 13 | 14 | # get token index in text 15 | def get_word_idx(text: str, tgt_word, tokenizer): 16 | 17 | tgt_word = tgt_word.lower() 18 | 19 | # ignore the first and last token 20 | encoded_text = tokenizer.encode(text)[1:-1] 21 | encoded_tgt_word = tokenizer.encode(tgt_word)[1:-1] 22 | 23 | # find the idx of target word in text 24 | first_token_idx = -1 25 | for i in range(len(encoded_text)): 26 | if encoded_text[i] == encoded_tgt_word[0]: 27 | 28 | if len(encoded_text) > 0: 29 | # check the following 30 | following_match = True 31 | for j in range(1, len(encoded_tgt_word)): 32 | if encoded_text[i + j] != encoded_tgt_word[j]: 33 | following_match = False 34 | if not following_match: 35 | continue 36 | # for a single encoded idx, just take it 37 | first_token_idx = i 38 | 39 | break 40 | 41 | assert first_token_idx != -1, "word not in text" 42 | 43 | # add 1 for sot token 44 | tgt_word_tokens_idx_ls = [i + 1 + first_token_idx for i in range(len(encoded_tgt_word))] 45 | 46 | # sanity check 47 | encoded_text = tokenizer.encode(text) 48 | 49 | decoded_token_ls = [] 50 | 51 | for word_idx in tgt_word_tokens_idx_ls: 52 | text_decode = tokenizer.decode([encoded_text[word_idx]]).strip("#") 53 | decoded_token_ls.append(text_decode) 54 | 55 | decoded_tgt_word = "".join(decoded_token_ls) 56 | 57 | tgt_word_ls = tgt_word.split(" ") 58 | striped_tgt_word = "".join(tgt_word_ls).strip("#") 59 | 60 | assert decoded_tgt_word == striped_tgt_word, "decode_text != striped_tar_wd" 61 | 62 | return tgt_word_tokens_idx_ls 63 | 64 | # get attn loss by resolution 65 | 66 | def get_grounding_loss_by_layer(_gt_seg_list, word_token_idx_ls, res, 67 | input_attn_map_ls, is_training_sd21): 68 | """ 69 | _gt_seg_list (List[Tensor]): (1, 1, res, res) 70 | input_attn_map_ls (List[Tensor]): (b*head, res, res, 77), sum = b*head*res*res 71 | 72 | """ 73 | if is_training_sd21: 74 | # training with sd21, using resolution 768 = 512 * 1.5 75 | res = int(SD14_TO_SD21_RATIO * res) 76 | 77 | if len(word_token_idx_ls) == 0: 78 | return { 79 | "token_loss" : 0, 80 | 
"pixel_loss": 0, 81 | } 82 | 83 | gt_seg_list = deepcopy(_gt_seg_list) 84 | 85 | # reszie gt seg map to the same size with attn map 86 | # assert gt_seg_list[0].min() == 0 and gt_seg_list[0].max() == 1 87 | 88 | resize_transform = transforms.Resize((res, res), antialias=True) 89 | 90 | valid_gt_seg = 0 91 | for i in range(len(gt_seg_list)): 92 | gt_seg_list[i] = resize_transform(gt_seg_list[i]) 93 | gt_seg_list[i] = gt_seg_list[i].squeeze(0) # 1, 1, res, res => 1, res, res 94 | # add binary 95 | gt_seg_list[i] = (gt_seg_list[i] > 0.0).float() 96 | # if there is truly a mask 97 | if gt_seg_list[i].sum() > 0.1: 98 | valid_gt_seg += 1 99 | 100 | 101 | ################### token loss start ################### 102 | # Following code is adapted from 103 | # https://github.com/silent-chen/layout-guidance/blob/08b687470f911c7f57937012bdf55194836d693e/utils.py#L27 104 | token_loss = 0.0 105 | for attn_map in input_attn_map_ls: 106 | b, H, W, j = attn_map.shape 107 | for i in range(len(word_token_idx_ls)): # [[word1 token_idx1, word1 token_idx2, ...], [word2 token_idx1, word2 token_idx2, ...]] 108 | obj_loss = 0.0 109 | single_word_idx_ls = word_token_idx_ls[i] #[token_idx1, token_idx2, ...] 110 | mask = gt_seg_list[i] 111 | 112 | for obj_position in single_word_idx_ls: 113 | # ca map obj shape 8 * 16 * 16 114 | try: 115 | ca_map_obj = attn_map[:, :, :, obj_position].reshape(b, H, W) 116 | except: 117 | print(_gt_seg_list, word_token_idx_ls, res, input_attn_map_ls, attn_map.shape) 118 | 119 | # calculate pixel value inside mask 120 | activation_value = (ca_map_obj * mask).reshape(b, -1).sum(dim=-1)/ca_map_obj.reshape(b, -1).sum(dim=-1) 121 | obj_loss += (1.0 - torch.mean(activation_value)) ** 2 122 | 123 | token_loss += (obj_loss/len(single_word_idx_ls)) 124 | 125 | token_loss = token_loss / len(word_token_idx_ls) 126 | 127 | ################## token loss end ########################## 128 | 129 | ################## pixel loss start ###################### 130 | # average cross attention map on different layers 131 | avg_attn_map_ls = [] 132 | for i in range(len(input_attn_map_ls)): 133 | avg_attn_map_ls.append( 134 | input_attn_map_ls[i].reshape(-1, res, res, input_attn_map_ls[i].shape[-1]).mean(0) 135 | ) # avg on each head 136 | avg_attn_map = torch.stack(avg_attn_map_ls, dim=0) 137 | avg_attn_map = avg_attn_map.sum(0) / avg_attn_map.shape[0] 138 | avg_attn_map = avg_attn_map.unsqueeze(0) # (1, res, res, 77) 139 | 140 | bce_loss_func = nn.BCELoss() 141 | pixel_loss = 0.0 142 | 143 | for i in range(len(word_token_idx_ls)): 144 | 145 | word_cross_attn_ls = [] 146 | for token_idx in word_token_idx_ls[i]: 147 | word_cross_attn_ls.append( 148 | avg_attn_map[..., token_idx] 149 | ) 150 | 151 | word_cross_attn_ls = torch.stack(word_cross_attn_ls, dim=0).sum(dim=0) # this should be sum, tokens of one object should belong to one meta class 152 | try: 153 | pixel_loss += bce_loss_func( 154 | word_cross_attn_ls, 155 | gt_seg_list[i] 156 | ) 157 | except: 158 | print(word_cross_attn_ls.shape, word_cross_attn_ls.device) 159 | print(gt_seg_list[i].shape, gt_seg_list[i].device) 160 | pdb.set_trace() 161 | print(gt_seg_list) 162 | # contrastive loss 163 | 164 | 165 | # average with len word_token_idx_ls 166 | # pixel_loss = pixel_loss / valid_gt_seg 167 | pixel_loss = pixel_loss / len(word_token_idx_ls) 168 | ################## pixel loss end ######################### 169 | 170 | return { 171 | "token_loss" : token_loss, 172 | "pixel_loss": pixel_loss, 173 | } 174 | 
-------------------------------------------------------------------------------- /attn_utils/tc_sdxl_attn_utils.py: -------------------------------------------------------------------------------- 1 | ''' 2 | class AttentionControl and class AttentionStore are modified from 3 | https://github.com/mlpc-ucsd/TokenCompose/blob/main/train/src/attn_utils.py 4 | https://github.com/google/prompt-to-prompt/blob/main/prompt-to-prompt_stable.ipynb 5 | https://github.com/google/prompt-to-prompt/blob/main/ptp_utils.py 6 | ''' 7 | 8 | 9 | import abc 10 | import torch 11 | import pdb 12 | 13 | 14 | LOW_RESOURCE = False 15 | SD14_TO_SD21_RATIO = 1.5 16 | 17 | class AttentionControl(abc.ABC): 18 | 19 | def step_callback(self, x_t): 20 | return x_t 21 | 22 | def between_steps(self): 23 | return 24 | 25 | @property 26 | def num_uncond_att_layers(self): 27 | return self.num_att_layers if LOW_RESOURCE else 0 28 | 29 | @abc.abstractmethod 30 | def forward (self, attn, is_cross: bool, place_in_unet: str): 31 | raise NotImplementedError 32 | 33 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 34 | attn = self.forward(attn, is_cross, place_in_unet) 35 | # attn: 8 * res * res 36 | 37 | self.cur_att_layer += 1 38 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 39 | self.cur_att_layer = 0 40 | self.cur_step += 1 41 | self.between_steps() 42 | return attn 43 | 44 | def reset(self): 45 | self.cur_step = 0 46 | self.cur_att_layer = 0 47 | 48 | def __init__(self): 49 | self.cur_step = 0 50 | self.num_att_layers = -1 51 | self.cur_att_layer = 0 52 | 53 | class AttentionStore(AttentionControl): 54 | 55 | @staticmethod 56 | def get_empty_store(): 57 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 58 | "down_self": [], "mid_self": [], "up_self": []} 59 | 60 | def forward(self, attn, is_cross: bool, place_in_unet: str): 61 | # only store cross attn 62 | if is_cross and place_in_unet in self.train_layer_place: 63 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 64 | self.step_store[key].append(attn.clone()) 65 | 66 | return attn 67 | 68 | def between_steps(self): 69 | self.cur_step = 0 70 | self.attention_store = self.step_store 71 | self.step_store = self.get_empty_store() 72 | 73 | def get_average_attention(self): 74 | # debug: make sure the attn map will not be changed without clone 75 | average_attention = {key: [item for item in self.attention_store[key]] for key in self.attention_store} 76 | return average_attention 77 | 78 | def reset(self): 79 | super(AttentionStore, self).reset() 80 | self.step_store = self.get_empty_store() 81 | self.attention_store = {} 82 | 83 | def __init__(self, train_layer_ls): 84 | super(AttentionStore, self).__init__() 85 | self.step_store = self.get_empty_store() 86 | self.attention_store = {} 87 | train_layer_place = [] 88 | for inst in train_layer_ls: 89 | train_layer_place.append(inst.split('_')[0]) 90 | self.train_layer_place = list(set(train_layer_place)) 91 | 92 | def register_attention_control(unet_model, controller): 93 | def ca_forward(self, place_in_unet): 94 | to_out = self.to_out 95 | if type(to_out) is torch.nn.modules.container.ModuleList: 96 | to_out = self.to_out[0] 97 | else: 98 | to_out = self.to_out 99 | 100 | def forward(hidden_states, encoder_hidden_states=None, attention_mask=None,temb=None,): 101 | is_cross = encoder_hidden_states is not None 102 | 103 | residual = hidden_states 104 | 105 | if self.spatial_norm is not None: 106 | hidden_states = self.spatial_norm(hidden_states, temb) 107 | 108 | 
input_ndim = hidden_states.ndim 109 | 110 | if input_ndim == 4: 111 | batch_size, channel, height, width = hidden_states.shape 112 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 113 | 114 | batch_size, sequence_length, _ = ( 115 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 116 | ) 117 | attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) 118 | 119 | if self.group_norm is not None: 120 | hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 121 | 122 | query = self.to_q(hidden_states) 123 | 124 | if encoder_hidden_states is None: 125 | encoder_hidden_states = hidden_states 126 | elif self.norm_cross: 127 | encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) 128 | 129 | key = self.to_k(encoder_hidden_states) 130 | value = self.to_v(encoder_hidden_states) 131 | 132 | query = self.head_to_batch_dim(query) 133 | key = self.head_to_batch_dim(key) 134 | value = self.head_to_batch_dim(value) 135 | 136 | attention_probs = self.get_attention_scores(query, key, attention_mask) 137 | 138 | if attention_probs.requires_grad: 139 | attention_probs = controller(attention_probs, is_cross, place_in_unet) 140 | 141 | hidden_states = torch.bmm(attention_probs, value) 142 | hidden_states = self.batch_to_head_dim(hidden_states) 143 | 144 | # linear proj 145 | hidden_states = to_out(hidden_states) 146 | # all drop out in diffusers are 0.0 147 | # so we here ignore dropout 148 | 149 | if input_ndim == 4: 150 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 151 | 152 | if self.residual_connection: 153 | hidden_states = hidden_states + residual 154 | 155 | hidden_states = hidden_states / self.rescale_output_factor 156 | 157 | return hidden_states 158 | return forward 159 | 160 | assert controller is not None, "controller must be specified" 161 | 162 | def register_recr(net_, count, place_in_unet): 163 | if net_.__class__.__name__ == 'Attention': 164 | net_.forward = ca_forward(net_, place_in_unet) 165 | return count + 1 166 | elif hasattr(net_, 'children'): 167 | for net__ in net_.children(): 168 | count = register_recr(net__, count, place_in_unet) 169 | return count 170 | 171 | down_count = 0 172 | up_count = 0 173 | mid_count = 0 174 | 175 | cross_att_count = 0 176 | sub_nets = unet_model.named_children() 177 | for net in sub_nets: 178 | 179 | if "down" in net[0]: 180 | down_temp = register_recr(net[1], 0, "down") 181 | cross_att_count += down_temp 182 | down_count += down_temp 183 | elif "up" in net[0]: 184 | up_temp = register_recr(net[1], 0, "up") 185 | cross_att_count += up_temp 186 | up_count += up_temp 187 | elif "mid" in net[0]: 188 | mid_temp = register_recr(net[1], 0, "mid") 189 | cross_att_count += mid_temp 190 | mid_count += mid_temp 191 | 192 | controller.num_att_layers = cross_att_count 193 | 194 | def get_cross_attn_map_from_unet(attention_store: AttentionStore, is_training_sd21, 195 | reses=[64, 32, 16, 8], poses=["down", "mid", "up"]): 196 | attention_maps = attention_store.get_average_attention() 197 | 198 | attn_dict = {} 199 | 200 | if is_training_sd21: 201 | reses = [int(SD14_TO_SD21_RATIO * item) for item in reses] 202 | 203 | for pos in poses: 204 | for res in reses: 205 | temp_list = [] 206 | for item in attention_maps[f"{pos}_cross"]: 207 | if item.shape[1] == res ** 2: 208 | cross_maps = item.reshape(-1, res, res, item.shape[-1]) 209 | temp_list.append(cross_maps) 210 | # 
if such resolution exists 211 | if len(temp_list) > 0: 212 | attn_dict[f"{pos}_{res}"] = temp_list 213 | return attn_dict -------------------------------------------------------------------------------- /attr_concen_utils/gsam_interface.py: -------------------------------------------------------------------------------- 1 | from ultralytics import YOLO 2 | from collections import defaultdict 3 | 4 | from PIL import Image 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn as nn 9 | from seg_model.gsam.EfficientSAM.FastSAM.tools import * 10 | from seg_model.gsam.GroundingDINO.groundingdino.util.inference import load_model, predict 11 | from torchvision.ops import box_convert 12 | from attn_utils.tc_loss_utils import get_grounding_loss_by_layer 13 | import groundingdino.datasets.transforms as T 14 | 15 | class GsamSegModel(nn.Module): 16 | def __init__(self, args, device, train_layer_ls=['mid_8', 'up_16', 'up_32', 'up_64']) -> None: 17 | super().__init__() 18 | 19 | self.device = device 20 | 21 | self.train_layer_ls = train_layer_ls 22 | self.args = args 23 | # load sam model 24 | sam_model_path = 'pretrained_model/FastSAM-x.pt' 25 | self.sam = YOLO(sam_model_path) 26 | self.sam.model.to(device) 27 | for p in self.sam.model.parameters(): 28 | p.requires_grad = False 29 | 30 | # load gdino 31 | groundingdino_config = "seg_model/gsam/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py" 32 | groundingdino_ckpt_path = "pretrained_model/groundingdino_swint_ogc.pth" 33 | 34 | self.gdino = load_model(groundingdino_config, groundingdino_ckpt_path, device='cpu') 35 | self.gdino.to(device) 36 | for n, p in self.gdino.named_parameters(): 37 | p.requires_grad = False 38 | 39 | self.gdino_transform_np = T.Compose( 40 | [ 41 | T.RandomResize([800], max_size=1333), 42 | T.ToTensor(), 43 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 44 | ] 45 | ) 46 | 47 | self.gdino_transform_tensor = T.Compose( 48 | [ 49 | T.RandomResize([800], max_size=1333), 50 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 51 | ] 52 | ) 53 | 54 | @torch.no_grad() 55 | def get_mask(self, image, nouns): 56 | width = image.shape[1] 57 | height = image.shape[2] 58 | # use np_image 59 | image=image.mul(255).clamp(0,255).to(torch.uint8) 60 | np_image = image.cpu().permute(1,2,0).numpy() 61 | # use tensor image 62 | # tensor_image = image.unsqueeze(0) 63 | 64 | seg_results = self.sam( 65 | np_image, 66 | # tensor_image, 67 | imgsz=(width, height), 68 | device=image.device, 69 | retina_masks=True, # draw high-resolution segmentation masks 70 | iou=0.9, # iou threshold for filtering the annotations 71 | conf=0.4, # object confidence threshold 72 | max_det=100, 73 | verbose=False 74 | ) 75 | 76 | # sometimes it will have no mask 77 | if seg_results[0].masks is None: 78 | print("No mask is detected") 79 | return None 80 | 81 | 82 | # use np_image 83 | pil_image = Image.fromarray(np_image.astype(np.uint8)).convert("RGB") 84 | image_transformed, _ = self.gdino_transform_np(pil_image, None) 85 | image_transformed = image_transformed.to(image.device) 86 | 87 | # use tensor image 88 | # image_transformed, _ = self.gdino_transform_tensor(tensor_image, None) 89 | 90 | caption = ' . 
'.join(nouns) 91 | 92 | boxes, logits, phrases = predict( 93 | model=self.gdino, 94 | image=image_transformed.squeeze(), 95 | caption=caption, 96 | box_threshold=0.3, 97 | text_threshold=0.25, 98 | # device=image.device, 99 | device=image.device 100 | ) 101 | 102 | ori_h = width # due to square image, this is fine 103 | ori_w = height 104 | 105 | # Save each frame due to the post process from FastSAM 106 | boxes = boxes * torch.Tensor([ori_w, ori_h, ori_w, ori_h]) 107 | 108 | boxes = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").cpu().numpy().tolist() 109 | 110 | noun_box_dict = defaultdict(list) 111 | for box_idx in range(len(boxes)): 112 | try: 113 | noun_idx = nouns.index(phrases[box_idx].strip()) 114 | except: # the box does not belong to single noun, just pass 115 | # print(f"not detected: {phrases[box_idx]}; nouns: {nouns}") 116 | continue 117 | 118 | mask, _ = box_prompt( 119 | seg_results[0].masks.data, 120 | boxes[box_idx], 121 | ori_h, 122 | ori_w, 123 | to_numpy=False, 124 | ) 125 | noun_box_dict[str(noun_idx)].append(mask.squeeze()) 126 | 127 | mask_list = [] 128 | for i in range(len(nouns)): 129 | if str(i) in noun_box_dict: # if the object is detected 130 | noun_mask = torch.sum(torch.stack(noun_box_dict[str(i)]), dim=0).squeeze() 131 | noun_mask = noun_mask > 0 # make this to be a [0,1] mask 132 | else: # no object is detected 133 | noun_mask = image.new_zeros((width, height)) > 0 134 | 135 | mask_list.append(noun_mask.unsqueeze(0).unsqueeze(0)) 136 | 137 | return mask_list 138 | 139 | 140 | def get_mask_loss(self, images, prompt, all_subtree_indices, attn_map_idx_to_wp_all, attn_map): 141 | images = images.detach().requires_grad_(False) 142 | 143 | token_loss = images.new_zeros(()) # could not directly be 0, int has no detach 144 | pixel_loss = images.new_zeros(()) 145 | grounding_loss_dict = defaultdict(int) 146 | 147 | # re-organize the attn map, split each batch 148 | bs = images.shape[0] 149 | attn_map_per_sample = [] 150 | for i in range(bs): 151 | attn_map_per_timestep = {} 152 | 153 | for timestep in attn_map: 154 | attn_map_per_timestep[timestep] = {} 155 | for place in attn_map[timestep]: 156 | attn_map_per_timestep[timestep][place] = [] 157 | for inst in attn_map[timestep][place]: 158 | inst = inst.reshape(bs, inst.shape[0]//bs, inst.shape[1], inst.shape[2], inst.shape[3]) 159 | attn_map_per_timestep[timestep][place].append(inst[i]) 160 | 161 | attn_map_per_sample.append(attn_map_per_timestep) 162 | 163 | for idx, subtree_indices in enumerate(all_subtree_indices): 164 | attn_map_idx_to_wp = attn_map_idx_to_wp_all[idx] 165 | image = images[idx] 166 | 167 | nouns = [] 168 | attributes = [] 169 | for subtree in subtree_indices: 170 | if len(subtree) < 1: # should also collect single object 171 | continue 172 | noun_indices = subtree[-1] if isinstance(subtree[-1], list) else [subtree[-1]] 173 | noun_char = [attn_map_idx_to_wp[i] for i in noun_indices] 174 | noun = ''.join(noun_char) 175 | nouns.append(noun) 176 | # attribute: [[17, 18], 19] -> [17, 18, 19] 177 | attribute = [] 178 | for attribute_char in subtree[:-1]: 179 | if isinstance(attribute_char, list): 180 | attribute.extend(attribute_char) 181 | else: 182 | attribute.append(attribute_char) 183 | # also add nouns into mask loss 184 | attribute.extend(noun_indices) 185 | 186 | attributes.append(attribute) 187 | 188 | if len(nouns) == 0: 189 | continue 190 | 191 | # nouns: [noun1, noun2, noun3] 192 | # attributes: [[1,2,3], [4,5], [18,19]] 193 | nouns, attributes = self.update_nouns_attributes(nouns, 
attributes) 194 | 195 | if len(nouns) == 0: 196 | continue 197 | 198 | with torch.autocast(device_type='cuda'): 199 | mask = self.get_mask(image, nouns) 200 | 201 | if mask == None: 202 | continue 203 | 204 | for timestep in attn_map_per_sample[idx]: 205 | for train_layer in self.train_layer_ls: 206 | layer_res = int(train_layer.split("_")[1]) 207 | attn_loss_dict = \ 208 | get_grounding_loss_by_layer( 209 | _gt_seg_list=mask, 210 | word_token_idx_ls=attributes, 211 | res=layer_res, 212 | input_attn_map_ls=attn_map_per_sample[idx][timestep][train_layer], 213 | is_training_sd21=False, 214 | ) 215 | 216 | layer_token_loss = attn_loss_dict["token_loss"] 217 | layer_pixel_loss = attn_loss_dict["pixel_loss"] 218 | 219 | grounding_loss_dict[f"token/{timestep}/{train_layer}"] += layer_token_loss 220 | grounding_loss_dict[f"pixel/{timestep}/{train_layer}"] += layer_pixel_loss 221 | 222 | token_loss += layer_token_loss 223 | pixel_loss += layer_pixel_loss 224 | 225 | token_loss = token_loss / len(all_subtree_indices) 226 | pixel_loss = pixel_loss / len(all_subtree_indices) 227 | 228 | return token_loss, pixel_loss, grounding_loss_dict 229 | 230 | 231 | # rm not qualified nouns 232 | def update_nouns_attributes(self, nouns, attributes): 233 | new_nouns, new_attributes = [], [] 234 | # rm duplicate nouns, do not calculate loss if there is duplicate nouns 235 | nouns2idx = defaultdict(list) 236 | for idx, n in enumerate(nouns): 237 | nouns2idx[n].append(idx) 238 | for n in nouns2idx: 239 | if len(nouns2idx[n]) > 1: 240 | continue 241 | else: 242 | new_nouns.append(n) 243 | new_attributes.append(attributes[nouns2idx[n][0]]) 244 | 245 | # rm invalid nouns 246 | filtered_nouns, filtered_attributes = [], [] 247 | invalid_nouns = set(['scene', 'surface', 'area', 'atmosphere', 'noise', 'place', 'kitchen', 'dream', 'interior', 'exterior', 248 | 'meal', 'background', 'bathroom', 'room', 'scent', 'street', 'hillside', 'mountain', 'sky', 'sea', 'ocean', 'lost', 249 | 'language', 'skill', 'one', 'night', 'day', 'morning', 'space', 'environment', 'conditions', 'field', 'shore', 'restroom', 250 | 'party', 'grass', 'snow', 'meadow', 'water', 'shadow', 'waves', 'song', 'cycle', 'sunlight', 'mysteries', 'wall', 'salon', 251 | 'range', 'cry', 'speech', 'tone', 'thing', 'about', 'activity', 'air', 'advertisement', 'airport', 'also']) 252 | 253 | for idx, n in enumerate(new_nouns): 254 | if n in invalid_nouns or n[:-1] in invalid_nouns: 255 | continue 256 | else: 257 | filtered_nouns.append(n) 258 | filtered_attributes.append(new_attributes[idx]) 259 | 260 | 261 | return filtered_nouns, filtered_attributes 262 | -------------------------------------------------------------------------------- /attr_concen_utils/load_segmodel.py: -------------------------------------------------------------------------------- 1 | from attr_concen_utils.gsam_interface import GsamSegModel 2 | 3 | def load_seg_model(args, device, train_layer_ls): 4 | if 'gsam' in args.seg_model: 5 | seg_model = GsamSegModel(args, device, train_layer_ls) 6 | else: 7 | raise NotImplementedError 8 | 9 | return seg_model 10 | -------------------------------------------------------------------------------- /attribute_concen_utils.py: -------------------------------------------------------------------------------- 1 | import torch.distributions as dist 2 | from typing import List, Dict 3 | import itertools 4 | import pdb 5 | import torch 6 | 7 | start_token = "<|startoftext|>" 8 | end_token = "<|endoftext|>" 9 | 10 | 11 | def align_wordpieces_indices( 12 | 
wordpieces2indices, start_idx, target_word 13 | ): 14 | """ 15 | Aligns a `target_word` that contains more than one wordpiece (the first wordpiece is `start_idx`) 16 | """ 17 | 18 | wp_indices = [start_idx] 19 | wp = wordpieces2indices[start_idx].replace("</w>", "") # strip the CLIP end-of-word marker 20 | 21 | # Run over the next wordpieces in the sequence (which is why we use +1) 22 | for wp_idx in range(start_idx + 1, len(wordpieces2indices)): 23 | if wp == target_word: 24 | break 25 | 26 | wp2 = wordpieces2indices[wp_idx].replace("</w>", "") 27 | if target_word.startswith(wp + wp2) and wp2 != target_word: 28 | wp += wordpieces2indices[wp_idx].replace("</w>", "") 29 | wp_indices.append(wp_idx) 30 | else: 31 | wp_indices = ( 32 | [] 33 | ) # if there's no match, you want to clear the list and finish 34 | break 35 | 36 | return wp_indices 37 | 38 | 39 | def extract_attribution_indices(doc): 40 | # doc = parser(prompt) 41 | subtrees = [] 42 | modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"] 43 | 44 | for w in doc: 45 | if w.pos_ not in ["NOUN", "PROPN"] or w.dep_ in modifiers: 46 | continue 47 | subtree = [] 48 | stack = [] 49 | for child in w.children: 50 | if child.dep_ in modifiers: 51 | subtree.append(child) 52 | stack.extend(child.children) 53 | 54 | while stack: 55 | node = stack.pop() 56 | if node.dep_ in modifiers or node.dep_ == "conj": 57 | subtree.append(node) 58 | stack.extend(node.children) 59 | if subtree: 60 | subtree.append(w) 61 | subtrees.append(subtree) 62 | return subtrees 63 | 64 | def extract_attribution_indices_with_verbs(doc): 65 | '''This function specifically addresses cases where a verb is between 66 | a noun and its modifier. For instance: "a dog that is red" 67 | here, the aux is between 'dog' and 'red'. ''' 68 | 69 | subtrees = [] 70 | modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp", 71 | 'relcl'] 72 | for w in doc: 73 | if w.pos_ not in ["NOUN", "PROPN"] or w.dep_ in modifiers: 74 | continue 75 | subtree = [] 76 | stack = [] 77 | for child in w.children: 78 | if child.dep_ in modifiers: 79 | if child.pos_ not in ['AUX', 'VERB']: 80 | subtree.append(child) 81 | stack.extend(child.children) 82 | 83 | while stack: 84 | node = stack.pop() 85 | if node.dep_ in modifiers or node.dep_ == "conj": 86 | # we don't want to add 'is' or other verbs to the loss, we want their children 87 | if node.pos_ not in ['AUX', 'VERB']: 88 | subtree.append(node) 89 | stack.extend(node.children) 90 | if subtree: 91 | subtree.append(w) 92 | subtrees.append(subtree) 93 | return subtrees 94 | 95 | def extract_attribution_indices_with_verb_root(doc): 96 | '''This function specifically addresses cases where a verb is between 97 | a noun and its modifier. For instance: "a dog that is red" 98 | here, the aux is between 'dog' and 'red'. 
''' 99 | 100 | subtrees = [] 101 | modifiers = ["amod", "nmod", "compound", "npadvmod", "advmod", "acomp"] 102 | for w in doc: 103 | subtree = [] 104 | stack = [] 105 | 106 | # if w is a verb/aux and has a noun child and a modifier child, add them to the stack 107 | if w.pos_ != 'AUX' or w.dep_ in modifiers: 108 | continue 109 | 110 | for child in w.children: 111 | if child.dep_ in modifiers or child.pos_ in ['NOUN', 'PROPN']: 112 | if child.pos_ not in ['AUX', 'VERB']: 113 | subtree.append(child) 114 | stack.extend(child.children) 115 | # did not find a pair of noun and modifier 116 | if len(subtree) < 2: 117 | continue 118 | 119 | while stack: 120 | node = stack.pop() 121 | if node.dep_ in modifiers or node.dep_ == "conj": 122 | # we don't want to add 'is' or other verbs to the loss, we want their children 123 | if node.pos_ not in ['AUX']: 124 | subtree.append(node) 125 | stack.extend(node.children) 126 | 127 | if subtree: 128 | if w.pos_ not in ['AUX']: 129 | subtree.append(w) 130 | subtrees.append(subtree) 131 | return subtrees 132 | 133 | 134 | def get_indices(tokenizer, prompt: str) -> Dict[str, int]: 135 | """Utility function to list the indices of the tokens you wish to alter""" 136 | ids = tokenizer(prompt).input_ids 137 | indices = { 138 | i: tok 139 | for tok, i in zip( 140 | tokenizer.convert_ids_to_tokens(ids), range(len(ids)) 141 | ) 142 | } 143 | return indices 144 | 145 | def get_attention_map_index_to_wordpiece(tokenizer, prompt): 146 | attn_map_idx_to_wp = {} 147 | 148 | wordpieces2indices = get_indices(tokenizer, prompt) 149 | 150 | # Ignore `start_token` and `end_token` 151 | for i in list(wordpieces2indices.keys())[1:-1]: 152 | wordpiece = wordpieces2indices[i] 153 | wordpiece = wordpiece.replace("</w>", "") # drop the CLIP end-of-word marker 154 | attn_map_idx_to_wp[i] = wordpiece 155 | 156 | return attn_map_idx_to_wp -------------------------------------------------------------------------------- /concept_mat_utils/caption_blip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from dataclasses import dataclass, field 4 | from PIL import Image 5 | 6 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, \ 7 | InterpolationMode, ToTensor, Resize, CenterCrop 8 | 9 | from transformers import BlipForConditionalGeneration 10 | from concept_mat_utils.processing_blip import BlipProcessor 11 | 12 | IGNORE_INDEX = -100 13 | 14 | class Blip(torch.nn.Module): 15 | def __init__(self, model_path, device, args=None, score_qa=False) -> None: 16 | super(Blip, self).__init__() 17 | self.processor = BlipProcessor.from_pretrained(model_path) 18 | self.model = BlipForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16).to(device) 19 | 20 | for n, p in self.model.named_parameters(): 21 | p.requires_grad = False 22 | 23 | self.mean = [ 24 | 0.48145466, 25 | 0.4578275, 26 | 0.40821073 27 | ] 28 | self.std = [ 29 | 0.26862954, 30 | 0.26130258, 31 | 0.27577711 32 | ] 33 | self.transforms = Compose([ 34 | Resize(size=(384,384), interpolation=InterpolationMode.BICUBIC, antialias=True), 35 | Normalize(mean=self.mean, std=self.std), 36 | ]) 37 | 38 | self.prompt = 'a photography of' 39 | self.prompt_length = len(self.processor.tokenizer(self.prompt).input_ids) - 1 40 | 41 | self.args = args 42 | 43 | def score(self, images, prompts, **kwargs): 44 | 45 | images = torch.stack([self.transforms(image) for image in images]) 46 | 47 | text = [self.prompt + ' ' + prompt.lower() for prompt in prompts] 48 | inputs = 
self.processor(images=images, text=text, return_tensors="pt", padding='longest') 49 | device = images.device 50 | inputs = {key: inputs[key].to(device) for key in inputs.keys()} 51 | inputs['labels'] = inputs['input_ids'].masked_fill( 52 | inputs['input_ids'] == self.processor.tokenizer.pad_token_id, -100 53 | ) 54 | inputs['labels'][:, :self.prompt_length] = -100 55 | 56 | with torch.autocast(device_type="cuda"): 57 | outputs = self.model(**inputs) 58 | reward = -outputs.loss 59 | return reward 60 | -------------------------------------------------------------------------------- /concept_mat_utils/load_captionmodel.py: -------------------------------------------------------------------------------- 1 | from concept_mat_utils.caption_blip import Blip 2 | 3 | def load_model(reward_wrapper, caption_model, load_device, args): 4 | if 'Blip' in caption_model: 5 | model_path = 'Salesforce/blip-image-captioning-large' 6 | reward_wrapper.blip_model = Blip(model_path, load_device, args) 7 | reward_wrapper.caption_model_dict[str(caption_model.index('Blip'))] = reward_wrapper.blip_model 8 | else: 9 | raise NotImplementedError -------------------------------------------------------------------------------- /concept_mat_utils/processing_blip.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | Processor class for Blip. 17 | """ 18 | 19 | from typing import List, Optional, Union 20 | 21 | from transformers.image_utils import ImageFeatureExtractionMixin 22 | 23 | from transformers.image_utils import ImageInput 24 | from transformers.image_processing_utils import BatchFeature 25 | from transformers.processing_utils import ProcessorMixin 26 | from transformers.tokenization_utils_base import BatchEncoding, PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy 27 | from transformers.utils import TensorType 28 | 29 | 30 | class BlipProcessor(ProcessorMixin): 31 | r""" 32 | Constructs a BLIP processor which wraps a BERT tokenizer and BLIP image processor into a single processor. 33 | 34 | [`BlipProcessor`] offers all the functionalities of [`BlipImageProcessor`] and [`BertTokenizerFast`]. See the 35 | docstring of [`~BlipProcessor.__call__`] and [`~BlipProcessor.decode`] for more information. 36 | 37 | Args: 38 | image_processor (`BlipImageProcessor`): 39 | An instance of [`BlipImageProcessor`]. The image processor is a required input. 40 | tokenizer (`BertTokenizerFast`): 41 | An instance of ['BertTokenizerFast`]. The tokenizer is a required input. 
42 | """ 43 | attributes = ["image_processor", "tokenizer"] 44 | image_processor_class = "BlipImageProcessor" 45 | tokenizer_class = ("BertTokenizer", "BertTokenizerFast") 46 | 47 | def __init__(self, image_processor, tokenizer): 48 | tokenizer.return_token_type_ids = False 49 | super().__init__(image_processor, tokenizer) 50 | self.current_processor = self.image_processor 51 | 52 | def __call__( 53 | self, 54 | images: ImageInput = None, 55 | text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, 56 | add_special_tokens: bool = True, 57 | padding: Union[bool, str, PaddingStrategy] = False, 58 | truncation: Union[bool, str, TruncationStrategy] = None, 59 | max_length: Optional[int] = None, 60 | stride: int = 0, 61 | pad_to_multiple_of: Optional[int] = None, 62 | return_attention_mask: Optional[bool] = None, 63 | return_overflowing_tokens: bool = False, 64 | return_special_tokens_mask: bool = False, 65 | return_offsets_mapping: bool = False, 66 | return_token_type_ids: bool = False, 67 | return_length: bool = False, 68 | verbose: bool = True, 69 | return_tensors: Optional[Union[str, TensorType]] = None, 70 | **kwargs, 71 | ) -> BatchEncoding: 72 | """ 73 | This method uses [`BlipImageProcessor.__call__`] method to prepare image(s) for the model, and 74 | [`BertTokenizerFast.__call__`] to prepare text for the model. 75 | 76 | Please refer to the docstring of the above two methods for more information. 77 | """ 78 | if images is None and text is None: 79 | raise ValueError("You have to specify either images or text.") 80 | 81 | # Get only text 82 | if images is None: 83 | self.current_processor = self.tokenizer 84 | text_encoding = self.tokenizer( 85 | text=text, 86 | add_special_tokens=add_special_tokens, 87 | padding=padding, 88 | truncation=truncation, 89 | max_length=max_length, 90 | stride=stride, 91 | pad_to_multiple_of=pad_to_multiple_of, 92 | return_attention_mask=return_attention_mask, 93 | return_overflowing_tokens=return_overflowing_tokens, 94 | return_special_tokens_mask=return_special_tokens_mask, 95 | return_offsets_mapping=return_offsets_mapping, 96 | return_token_type_ids=return_token_type_ids, 97 | return_length=return_length, 98 | verbose=verbose, 99 | return_tensors=return_tensors, 100 | **kwargs, 101 | ) 102 | return text_encoding 103 | 104 | # add pixel_values 105 | # encoding_image_processor = self.image_processor(images, return_tensors=return_tensors) 106 | encoding_image_processor = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors) 107 | 108 | if text is not None: 109 | text_encoding = self.tokenizer( 110 | text=text, 111 | add_special_tokens=add_special_tokens, 112 | padding=padding, 113 | truncation=truncation, 114 | max_length=max_length, 115 | stride=stride, 116 | pad_to_multiple_of=pad_to_multiple_of, 117 | return_attention_mask=return_attention_mask, 118 | return_overflowing_tokens=return_overflowing_tokens, 119 | return_special_tokens_mask=return_special_tokens_mask, 120 | return_offsets_mapping=return_offsets_mapping, 121 | return_token_type_ids=return_token_type_ids, 122 | return_length=return_length, 123 | verbose=verbose, 124 | return_tensors=return_tensors, 125 | **kwargs, 126 | ) 127 | else: 128 | text_encoding = None 129 | 130 | if text_encoding is not None: 131 | encoding_image_processor.update(text_encoding) 132 | 133 | return encoding_image_processor 134 | 135 | def batch_decode(self, *args, **kwargs): 136 | """ 137 | This method forwards all its arguments to BertTokenizerFast's 
[`~PreTrainedTokenizer.batch_decode`]. Please 138 | refer to the docstring of this method for more information. 139 | """ 140 | return self.tokenizer.batch_decode(*args, **kwargs) 141 | 142 | def decode(self, *args, **kwargs): 143 | """ 144 | This method forwards all its arguments to BertTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to 145 | the docstring of this method for more information. 146 | """ 147 | return self.tokenizer.decode(*args, **kwargs) 148 | 149 | @property 150 | def model_input_names(self): 151 | tokenizer_input_names = self.tokenizer.model_input_names 152 | image_processor_input_names = self.image_processor.model_input_names 153 | return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) 154 | -------------------------------------------------------------------------------- /fig/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CaraJ7/CoMat/595cf4ad11fd241c04dd51ea6ba41a16981d79cb/fig/demo.png -------------------------------------------------------------------------------- /node8.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: fp16 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | main_process_port: 18905 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.31.0 2 | ultralytics 3 | timm==0.6.12 4 | torch==2.0.0 5 | torchvision==0.15.1 6 | diffusers>=0.22.1 7 | xformers==0.0.24 8 | tensorboard 9 | spacy 10 | accelerate 11 | shortuuid 12 | xformers==0.0.24 -------------------------------------------------------------------------------- /scripts/sd15.sh: -------------------------------------------------------------------------------- 1 | python -u -m accelerate.commands.launch --config_file node8.yaml --main_process_port 12213 \ 2 | training_script.py \ 3 | --pretrain_model runwayml/stable-diffusion-v1-5 --resolution 512 \ 4 | --train_batch_size 4 --gradient_accumulation_steps 1 --max_train_steps 2000 \ 5 | --learning_rate 5e-5 --max_grad_norm 0.1 --lr_scheduler constant --lr_warmup_steps 0 \ 6 | --output_dir output/sd15 \ 7 | --caption_model "Blip" --gradient_checkpointing \ 8 | --mixed_precision=fp16 --validation_prompts "A man walking on street" \ 9 | --seed 42 --K 5 --lora_rank 128 --training_prompts train_data/gan_abc5k_t2icomp_hrs_20k_sd15.jsonl \ 10 | --total_step 50 --scheduler DDPM \ 11 | --validation_prompts_file valid_15k.txt \ 12 | --gan_loss --gan_loss_weight 1 --learning_rate_D 2e-5 --adam_beta1_D 0 --max_grad_norm_D 1 \ 13 | --validation_steps 200 --pretrain_model_name sd_1_5_attrcon \ 14 | --mask_token_loss_weight 1e-3 --mask_pixel_loss_weight 5e-5 --attrcon_train_steps 2 \ 15 | --gan_model_arch gansd_1_5 --seg_model gsam -------------------------------------------------------------------------------- /scripts/sdxl.sh: -------------------------------------------------------------------------------- 1 | python -u -m accelerate.commands.launch --config_file node8.yaml --main_process_port 12213 \ 2 | training_script.py \ 3 | 
--pretrain_model stabilityai/stable-diffusion-xl-base-1.0 --resolution 512 \ 4 | --train_batch_size 6 --gradient_accumulation_steps 1 --max_train_steps 2000 \ 5 | --learning_rate 2e-5 --max_grad_norm 0.1 --lr_scheduler constant --lr_warmup_steps 0 \ 6 | --output_dir output/sdxl \ 7 | --caption_model "Blip" --gradient_checkpointing \ 8 | --mixed_precision=fp16 --validation_prompts "A man walking on street" \ 9 | --seed 42 --K 5 --lora_rank 128 \ 10 | --training_prompts train_data/gan_abc5k_t2icomp_hrs_20k_sdxl_unet.jsonl \ 11 | --total_step 50 --scheduler DDPM \ 12 | --validation_prompts_file valid_15k.txt \ 13 | --gan_loss --gan_loss_weight 5e-1 --learning_rate_D 5e-5 --adam_beta1_D 0 --max_grad_norm_D 1 \ 14 | --validation_steps 200 --pretrain_model_name sdxl_attrcon_unet \ 15 | --mask_token_loss_weight 1e-3 --mask_pixel_loss_weight 5e-5 --attrcon_train_steps 2 \ 16 | --sdxl_unet_path FINETUNED_UNET_PATH \ 17 | --gan_model_arch gansd_1_5 --seg_model gsam --num_validation_images 0 -------------------------------------------------------------------------------- /tools/gan_gt_generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:18000" 4 | import random 5 | import json 6 | 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | import torch.utils.checkpoint 11 | from packaging import version 12 | from tqdm.auto import tqdm 13 | 14 | from PIL import Image 15 | 16 | from torchvision.transforms import Compose, Resize, CenterCrop, Normalize 17 | from torchvision.transforms.functional import pil_to_tensor 18 | try: 19 | from torchvision.transforms import InterpolationMode 20 | BICUBIC = InterpolationMode.BICUBIC 21 | except ImportError: 22 | BICUBIC = Image.BICUBIC 23 | 24 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler 25 | from diffusers.utils.import_utils import is_xformers_available 26 | 27 | from TrainableSDPipeline import TrainableSDPipeline, TrainableSDXLPipeline 28 | from torchvision import transforms as T 29 | 30 | import shortuuid 31 | import time 32 | 33 | import threading 34 | 35 | # build a lock 36 | lock = threading.Lock() 37 | 38 | def write_data(file_name, data): 39 | # obtain the lock 40 | with lock: 41 | with open(file_name, 'a') as f: 42 | f.write(data + '\n') 43 | 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("--start", type=int, default="0") 46 | parser.add_argument("--end", type=int, default="10000") 47 | parser.add_argument("--unet-path", type=str) 48 | parser.add_argument("--use-cache", action='store_true') 49 | parser.add_argument("--prompt-path", type=str, default='merged_data/abc5k_hrs10k_t2icompall_20k.txt') 50 | parser.add_argument("--save-prompt-path", type=str, default='train_data/gan_train_data.jsonl') 51 | parser.add_argument("--model-path", type=str, default="runwayml/stable-diffusion-v1-5") 52 | parser.add_argument("--model-type", type=str, default='sd_1_5') 53 | parser.add_argument("--batch-size", type=int, default=8) 54 | args = parser.parse_args() 55 | 56 | def setup_seed(seed): 57 | torch.manual_seed(seed) 58 | torch.cuda.manual_seed_all(seed) 59 | np.random.seed(seed) 60 | random.seed(seed) 61 | torch.backends.cudnn.deterministic = True 62 | # random seed 63 | timestamp = int(time.time()) 64 | setup_seed(timestamp) 65 | 66 | 67 | def load_diffusion_pipeline(model_path, revision, weight_dtype, args): 
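    # Select the trainable pipeline variant from --model-type: 'sd_1_5' uses TrainableSDPipeline,
    # 'sdxl' uses TrainableSDXLPipeline with the fp16-fix VAE, and 'sdxl_unet' additionally swaps in
    # a UNet fine-tuned at 512 resolution loaded from --unet-path.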
68 | if args.model_type == 'sd_1_5': 69 | pipeline = TrainableSDPipeline.from_pretrained(model_path, revision=revision, torch_type=weight_dtype) 70 | elif args.model_type == 'sdxl': 71 | vae_path = "madebyollin/sdxl-vae-fp16-fix" 72 | vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16) 73 | pipeline = TrainableSDXLPipeline.from_pretrained(model_path, revision=revision, vae=vae, torch_type=weight_dtype) 74 | elif args.model_type == 'sdxl_unet': # for a UNet fine-tuned on 512 75 | vae_path = "madebyollin/sdxl-vae-fp16-fix" 76 | vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16) 77 | unet_path = args.unet_path 78 | unet = UNet2DConditionModel.from_pretrained(unet_path) 79 | pipeline = TrainableSDXLPipeline.from_pretrained(model_path, revision=revision, vae=vae, unet=unet, torch_type=weight_dtype) 80 | else: 81 | raise NotImplementedError("This model is not supported yet") 82 | return pipeline 83 | 84 | def read_jsonl(save_path): 85 | ret_list = [] 86 | if os.path.exists(save_path): 87 | with open(save_path, 'r') as f: 88 | for line in f: 89 | ret_list.append(json.loads(line)) 90 | return ret_list 91 | 92 | class Prompt_dataset(torch.utils.data.Dataset): 93 | def __init__(self, data_path, save_path, start, end) -> None: 94 | super().__init__() 95 | if 'txt' in data_path: 96 | self.ann = list() 97 | with open(data_path, 'r') as f: 98 | for line in f: 99 | self.ann.append(line.strip()) 100 | elif 'json' in data_path: 101 | self.ann = json.load(open(data_path, 'r')) 102 | self.ann = self.ann[start:end] 103 | 104 | if args.use_cache: 105 | ann_done = read_jsonl(save_path) 106 | prompts_done = [inst['prompt'] for inst in ann_done] 107 | self.ann = list(set(self.ann)-set(prompts_done)) 108 | 109 | def __getitem__(self, idx): 110 | return self.ann[idx] 111 | 112 | def __len__(self): 113 | return len(self.ann) 114 | 115 | # load dataset 116 | prompt_path = args.prompt_path 117 | save_prompt_path = args.save_prompt_path 118 | 119 | save_dir = os.path.dirname(save_prompt_path) 120 | os.makedirs(save_dir, exist_ok=True) 121 | os.makedirs(os.path.join(save_dir, 'latents'), exist_ok=True) 122 | 123 | prompt_dataset = Prompt_dataset(prompt_path, save_prompt_path, args.start, args.end) 124 | batch_size = args.batch_size 125 | prompt_loader = torch.utils.data.DataLoader( 126 | prompt_dataset, 127 | shuffle=False, 128 | batch_size=batch_size, 129 | num_workers=4 130 | ) 131 | 132 | # load model 133 | model_path = args.model_path 134 | pipeline = load_diffusion_pipeline(model_path, None, torch.float16, args) 135 | pipeline.to('cuda') 136 | pipeline.vae.requires_grad_(False) 137 | pipeline.text_encoder.requires_grad_(False) 138 | pipeline.unet.requires_grad_(False) 139 | pipeline.text_encoder.to(torch.float16) 140 | pipeline.unet.to(torch.float16) 141 | if isinstance(pipeline, TrainableSDXLPipeline): 142 | pipeline.text_encoder_2.requires_grad_(False) 143 | pipeline.text_encoder_2.to(torch.float16) 144 | 145 | 146 | if is_xformers_available(): 147 | import xformers 148 | 149 | xformers_version = version.parse(xformers.__version__) 150 | if xformers_version == version.parse("0.0.16"): 151 | print( 152 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 153 | ) 154 | pipeline.enable_xformers_memory_efficient_attention() 155 | else: 156 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 157 | 158 | 159 | scheduler_args = {} 160 | if "variance_type" in pipeline.scheduler.config: 161 | variance_type = pipeline.scheduler.config.variance_type 162 | 163 | if variance_type in ["learned", "learned_range"]: 164 | variance_type = "fixed_small" 165 | 166 | scheduler_args["variance_type"] = variance_type 167 | 168 | pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config, **scheduler_args) 169 | generator = torch.Generator(device='cuda').manual_seed(int(time.time())) 170 | 171 | with torch.no_grad(): 172 | for idx, prompts in tqdm(enumerate(prompt_loader)): 173 | latents = pipeline( 174 | prompts, 175 | height=512, 176 | width=512, 177 | num_inference_steps=50, 178 | generator=generator, 179 | guidance_scale=7.5, 180 | guidance_rescale=0.0, 181 | output_type='latent').images 182 | 183 | save_list = [] 184 | for idx in range(latents.shape[0]): 185 | uid = shortuuid.uuid() 186 | torch.save(latents[idx].cpu().float(), os.path.join(save_dir, 'latents', f"{uid}.pt")) 187 | save_dict = json.dumps({ 188 | 'prompt': prompts[idx], 189 | 'file_path': os.path.join(save_dir, 'latents', f"{uid}.pt") 190 | }) 191 | save_list.append(save_dict) 192 | 193 | write_data(save_prompt_path, '\n'.join(save_list)) 194 | 195 | 196 | 197 | 198 | -------------------------------------------------------------------------------- /training_script.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import json 5 | 6 | import numpy as np 7 | import torch 8 | import torch.utils.checkpoint 9 | from accelerate import Accelerator 10 | from accelerate.logging import get_logger 11 | from accelerate.utils import ProjectConfiguration, set_seed 12 | from packaging import version 13 | from tqdm.auto import tqdm 14 | 15 | from PIL import Image 16 | 17 | try: 18 | from torchvision.transforms import InterpolationMode 19 | BICUBIC = InterpolationMode.BICUBIC 20 | except ImportError: 21 | BICUBIC = Image.BICUBIC 22 | 23 | from diffusers import DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler 24 | from diffusers.optimization import get_scheduler 25 | from diffusers.utils import check_min_version 26 | from diffusers.utils.import_utils import is_xformers_available 27 | from diffusers.loaders import ( 28 | LoraLoaderMixin, 29 | text_encoder_lora_state_dict, 30 | ) 31 | 32 | from TrainableSDPipeline import TrainableSDPipeline, TrainableSDXLPipeline 33 | 34 | # import training_utils 35 | from training_utils.arguments import parse_args 36 | from training_utils.logging import set_logger 37 | from training_utils.pipeline import * 38 | from training_utils.gan_sd_model import load_discriminator 39 | from training_utils.dataset import get_dataset_dataloader 40 | 41 | from attribute_concen_utils import get_attention_map_index_to_wordpiece 42 | from concept_mat_utils.load_captionmodel import load_model 43 | 44 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 45 | check_min_version("0.16.0.dev0") 46 | 47 | logger = get_logger(__name__, log_level="INFO") 48 | 49 | 50 | def unet_lora_state_dict(unet: UNet2DConditionModel): 51 | r""" 52 | Returns: 53 | A state dict containing just the LoRA parameters. 
54 | """ 55 | lora_state_dict = {} 56 | 57 | for name, module in unet.named_modules(): 58 | if hasattr(module, "set_lora_layer"): 59 | lora_layer = getattr(module, "lora_layer") 60 | if lora_layer is not None: 61 | current_lora_layer_sd = lora_layer.state_dict() 62 | for lora_layer_matrix_name, lora_param in current_lora_layer_sd.items(): 63 | # The matrix name can either be "down" or "up". 64 | lora_state_dict[f"unet.{name}.lora.{lora_layer_matrix_name}"] = lora_param 65 | 66 | return lora_state_dict 67 | 68 | 69 | class CaptionModelWrapper(torch.nn.Module): 70 | def __init__(self, caption_model, weights, device, args, dtype): 71 | super().__init__() 72 | 73 | self.caption_model = caption_model 74 | self.model_name = caption_model 75 | 76 | self.device = device 77 | self.dtype = dtype 78 | self.caption_model_dict = {} 79 | load_device = device 80 | 81 | self.weights = {} 82 | self.args = args 83 | for model, weight in zip(caption_model, weights): 84 | self.weights[model] = weight 85 | 86 | 87 | load_model(self, caption_model, load_device, args) 88 | 89 | def forward(self, images, prompts, text_encoder=None, return_feature=False, step=-1, batch=None): 90 | caption_rewards = {} 91 | 92 | if 'Blip' in self.model_name: 93 | caption_reward = self.blip_model.score(images, prompts, **batch) 94 | caption_rewards['Blip'] = caption_reward * self.weights['Blip'] 95 | 96 | caption_rewards["total"] = sum([caption_rewards[k] for k in self.model_name]) 97 | return caption_rewards 98 | 99 | class Trainer(object): 100 | 101 | def __init__(self, pretrained_model_name_or_path, args): 102 | 103 | self.pretrained_model_name_or_path = pretrained_model_name_or_path 104 | 105 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 106 | 107 | accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit, logging_dir=logging_dir) 108 | 109 | self.accelerator = Accelerator( 110 | gradient_accumulation_steps=args.gradient_accumulation_steps, 111 | mixed_precision=args.mixed_precision, 112 | log_with=args.report_to, 113 | project_config=accelerator_project_config, 114 | ) 115 | 116 | # If passed along, set the training seed now. 117 | if args.seed is not None: 118 | set_seed(args.seed) 119 | 120 | set_logger(args, self.accelerator, logger) 121 | 122 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 123 | # as these models are only used for inference, keeping weights in full precision is not required. 124 | self.weight_dtype = torch.float32 125 | if self.accelerator.mixed_precision == "fp16": 126 | self.weight_dtype = torch.float16 127 | elif self.accelerator.mixed_precision == "bf16": 128 | self.weight_dtype = torch.bfloat16 129 | 130 | self.model_name = args.pretrain_model_name 131 | self.pipeline = load_pipeline(args, self.model_name, self.weight_dtype) 132 | 133 | self.caption_model = CaptionModelWrapper(args.caption_model, weights=args.reward_weights, device=self.accelerator.device, args=args, dtype=self.weight_dtype) 134 | 135 | if args.enable_xformers_memory_efficient_attention: 136 | if is_xformers_available(): 137 | import xformers 138 | 139 | xformers_version = version.parse(xformers.__version__) 140 | if xformers_version == version.parse("0.0.16"): 141 | logger.warn( 142 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 
143 | ) 144 | self.pipeline.enable_xformers_memory_efficient_attention() 145 | else: 146 | raise ValueError("xformers is not available. Make sure it is installed correctly") 147 | 148 | if args.gan_loss: 149 | self.D = load_discriminator(args, self.weight_dtype, device=self.accelerator.device) 150 | 151 | self.global_step = 0 152 | self.first_epoch = 0 153 | self.resume_step = 0 154 | resume_global_step = 0 155 | # Potentially load in the weights and states from a previous save 156 | if args.resume_from_checkpoint: 157 | path = None 158 | if args.resume_from_checkpoint != "latest": 159 | # assert False, "not implemented" 160 | path = os.path.basename(args.resume_from_checkpoint) 161 | load_dir = os.path.dirname(args.resume_from_checkpoint) 162 | else: 163 | # Get the most recent checkpoint 164 | dirs = os.listdir(args.output_dir) 165 | dirs = [d for d in dirs if d.startswith("checkpoint")] 166 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 167 | path = dirs[-1] if len(dirs) > 0 else None 168 | load_dir = args.output_dir 169 | 170 | if path is None: 171 | self.accelerator.print( 172 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 173 | ) 174 | args.resume_from_checkpoint = None 175 | else: 176 | self.accelerator.print(f"Resuming from checkpoint {path}") 177 | if args.full_finetuning: 178 | self.pipeline.unet.load_state_dict(torch.load(os.path.join(load_dir, path, "unet.pt"), map_location="cpu")) 179 | else: 180 | lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(os.path.join(load_dir, path, "pytorch_lora_weights.safetensors")) 181 | LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=self.accelerator.unwrap_model(self.pipeline.unet)) 182 | if args.train_text_encoder_lora: 183 | LoraLoaderMixin.load_lora_into_text_encoder( 184 | lora_state_dict, network_alphas=network_alphas, text_encoder=self.accelerator.unwrap_model(self.pipeline.text_encoder) 185 | ) 186 | 187 | if args.tune_vae: 188 | self.pipeline.vae.load_state_dict(torch.load(os.path.join(load_dir, path, "vae.pt"), map_location="cpu")) 189 | if args.tune_text_encoder: 190 | self.pipeline.text_encoder.load_state_dict(torch.load(os.path.join(load_dir, path, "text_encoder.pt"), map_location="cpu")) 191 | if args.gan_loss and args.resume_from_checkpoint == 'latest': 192 | print("Loading D_SD Lora") 193 | lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(os.path.join(load_dir, path, 'D_sd', "pytorch_lora_weights.safetensors")) 194 | LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=self.accelerator.unwrap_model(self.D.unet)) 195 | print("Loading D_SD MLP") 196 | self.D.mlp.load_state_dict(torch.load(os.path.join(load_dir, path, 'D_sd', "mlp.pt"), map_location="cpu")) 197 | for p in self.D.mlp.parameters(): 198 | p.data = p.data.float() 199 | if args.gan_unet_lastlayer_cls: 200 | self.D.unet.conv_out = self.mlp 201 | 202 | 203 | self.global_step = int(path.split("-")[1]) 204 | 205 | resume_global_step = self.global_step * args.gradient_accumulation_steps 206 | 207 | # load_trainable parameters, should be done after resume 208 | G_parameters, text_lora_parameters = get_trainable_parameters(args, self.pipeline, is_D=False) 209 | 210 | # Enable TF32 for faster training on Ampere GPUs, 211 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 212 | if args.allow_tf32: 213 | torch.backends.cuda.matmul.allow_tf32 = True 214 | 215 | # Initialize 
the optimizer 216 | if args.use_8bit_adam: 217 | try: 218 | import bitsandbytes as bnb 219 | except ImportError: 220 | raise ImportError( 221 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" 222 | ) 223 | optimizer_cls = bnb.optim.AdamW8bit 224 | elif args.optimizer_class == 'AdamW': 225 | optimizer_cls = torch.optim.AdamW 226 | 227 | if args.train_text_encoder_lora: 228 | if args.textenc_lora_lr is None: # share lr with text lora 229 | G_parameters.extend(text_lora_parameters) 230 | self.G_parameters = G_parameters 231 | self.optimizer = optimizer_cls( 232 | self.G_parameters, 233 | lr=args.learning_rate, 234 | betas=(args.adam_beta1, args.adam_beta2), 235 | weight_decay=args.adam_weight_decay, 236 | eps=args.adam_epsilon, 237 | ) 238 | else: 239 | self.optimizer = optimizer_cls([ 240 | dict( 241 | params=G_parameters, 242 | lr=args.learning_rate, 243 | betas=(args.adam_beta1, args.adam_beta2), 244 | weight_decay=args.adam_weight_decay, 245 | eps=args.adam_epsilon), 246 | dict( 247 | params=text_lora_parameters, 248 | lr=args.textenc_lora_lr, 249 | betas=(args.adam_beta1, args.adam_beta2), 250 | weight_decay=args.adam_weight_decay, 251 | eps=args.adam_epsilon) 252 | ]) 253 | self.G_parameters = G_parameters 254 | self.G_parameters.extend(text_lora_parameters) 255 | 256 | else: 257 | self.G_parameters = G_parameters 258 | self.optimizer = optimizer_cls( 259 | self.G_parameters, 260 | lr=args.learning_rate, 261 | betas=(args.adam_beta1, args.adam_beta2), 262 | weight_decay=args.adam_weight_decay, 263 | eps=args.adam_epsilon, 264 | ) 265 | 266 | self.D_optimizer = None 267 | if args.gan_loss: 268 | self.D_parameters = self.D.get_trainable_parameters() 269 | self.D_optimizer = optimizer_cls( 270 | self.D_parameters, 271 | lr=args.learning_rate_D, 272 | betas=(args.adam_beta1_D, args.adam_beta2_D), 273 | weight_decay=args.adam_weight_decay, 274 | eps=args.adam_epsilon, 275 | ) 276 | 277 | # get train_dataset and train_dataloader 278 | self.train_dataset, self.train_dataloader = get_dataset_dataloader(args, self.accelerator) 279 | 280 | # Scheduler and math around the number of training steps. 
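# Illustration of the step bookkeeping below (hypothetical numbers, not taken from the repo):
# with len(self.train_dataloader) == 1000 micro-batches and gradient_accumulation_steps == 4,
# num_update_steps_per_epoch == ceil(1000 / 4) == 250. Resuming from "checkpoint-600" gives
# global_step == 600, so first_epoch == 600 // 250 == 2 and
# resume_step == (600 * 4) % (250 * 4) == 400, i.e. the first 400 micro-batches of that
# epoch are skipped by the resume logic in train().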
281 | overrode_max_train_steps = False 282 | self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps) 283 | if args.max_train_steps is None: 284 | args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch 285 | overrode_max_train_steps = True 286 | 287 | self.first_epoch = self.global_step // self.num_update_steps_per_epoch 288 | self.resume_step = resume_global_step % (self.num_update_steps_per_epoch * args.gradient_accumulation_steps) 289 | 290 | self.lr_scheduler = get_scheduler( 291 | args.lr_scheduler, 292 | optimizer=self.optimizer, 293 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 294 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 295 | ) 296 | 297 | self.pipeline.to(torch_device=self.accelerator.device) 298 | 299 | self.caption_model.to(self.accelerator.device, dtype=self.weight_dtype) 300 | 301 | if args.tune_vae: 302 | self.pipeline.vae.to(dtype=torch.float) 303 | 304 | if args.tune_text_encoder: 305 | self.pipeline.text_encoder.to(dtype=torch.float) 306 | 307 | 308 | if 'attrcon' in args.pretrain_model_name : 309 | from attr_concen_utils.load_segmodel import load_seg_model 310 | if 'sdxl' in args.pretrain_model_name: 311 | from attn_utils.tc_sdxl_attn_utils import AttentionStore, register_attention_control 312 | train_layer_ls = ['mid_16', 'up_16', 'up_32'] 313 | else: 314 | from attn_utils.tc_attn_utils import AttentionStore, register_attention_control 315 | train_layer_ls = ['mid_8', 'up_16', 'up_32', 'up_64'] 316 | 317 | 318 | self.seg_model = load_seg_model(args, self.accelerator.device, train_layer_ls) 319 | self.pipeline.controller = AttentionStore(train_layer_ls) 320 | register_attention_control(self.pipeline.unet, self.pipeline.controller) 321 | 322 | # Prepare everything with our `self.accelerator`. 323 | if not args.gan_loss: 324 | self.pipeline.unet, self.optimizer, self.train_dataloader, self.lr_scheduler, self.D_optimizer, self.pipeline.text_encoder = self.accelerator.prepare( 325 | self.pipeline.unet, self.optimizer, self.train_dataloader, self.lr_scheduler, self.D_optimizer, self.pipeline.text_encoder 326 | ) 327 | else: # add D_sd_pipeline 328 | self.pipeline.unet, self.optimizer, self.train_dataloader, self.lr_scheduler, self.D_optimizer, self.pipeline.text_encoder, self.D.unet, self.D.mlp = self.accelerator.prepare( 329 | self.pipeline.unet, self.optimizer, self.train_dataloader, self.lr_scheduler, self.D_optimizer, self.pipeline.text_encoder, self.D.unet, self.D.mlp 330 | ) 331 | 332 | 333 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 334 | self.num_update_steps_per_epoch = math.ceil(len(self.train_dataloader) / args.gradient_accumulation_steps) 335 | if overrode_max_train_steps: 336 | args.max_train_steps = args.num_train_epochs * self.num_update_steps_per_epoch 337 | # Afterwards we recalculate our number of training epochs 338 | args.num_train_epochs = math.ceil(args.max_train_steps / self.num_update_steps_per_epoch) 339 | 340 | # We need to initialize the trackers we use, and also store our configuration. 341 | # The trackers initializes automatically on the main process. 
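# The sanitization below drops list-valued and None-valued args before init_trackers,
# presumably because tracker hparams (e.g. TensorBoard) only accept primitive types
# (int, float, str, bool); validation_prompts / caption_model / reward_weights / seg_model
# are lists and would otherwise fail to serialize.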
342 | if self.accelerator.is_main_process: 343 | tracker_config = dict(vars(args)) 344 | tracker_config.pop("validation_prompts") 345 | tracker_config.pop("caption_model") 346 | tracker_config.pop("reward_weights") 347 | tracker_config.pop("seg_model") 348 | none_keys = [] 349 | for k, v in tracker_config.items(): 350 | if v is None: 351 | none_keys.append(k) 352 | # print(f"{k}: {type(v)}") 353 | 354 | for k in none_keys: 355 | tracker_config.pop(k) 356 | for k, v in tracker_config.items(): 357 | print(f"{k}: {type(v)}") 358 | 359 | self.accelerator.init_trackers(args.tracker_project_name, tracker_config) 360 | 361 | def train(self, args): 362 | 363 | # Train! 364 | total_batch_size = args.train_batch_size * self.accelerator.num_processes * args.gradient_accumulation_steps 365 | 366 | logger.info("***** Running training *****") 367 | logger.info(f" Num examples = {len(self.train_dataset)}") 368 | logger.info(f" Num Epochs = {args.num_train_epochs}") 369 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 370 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 371 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 372 | logger.info(f" Total optimization steps = {args.max_train_steps}") 373 | 374 | global_step = self.global_step 375 | first_epoch = self.first_epoch 376 | resume_step = self.resume_step 377 | 378 | # Only show the progress bar once on each machine. 379 | progress_bar = tqdm(range(args.max_train_steps), disable=not self.accelerator.is_local_main_process) 380 | progress_bar.set_description("Steps") 381 | 382 | def save_and_evaluate(output_dir, n_iter, save=True): 383 | unet = self.accelerator.unwrap_model(self.pipeline.unet) 384 | unet_lora_layers = unet_lora_state_dict(unet) 385 | 386 | text_encoder_lora_layers = None 387 | if args.train_text_encoder_lora: 388 | text_encoder = self.accelerator.unwrap_model(self.pipeline.text_encoder) 389 | text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder) 390 | 391 | if save: 392 | if args.full_finetuning: 393 | if not os.path.exists(output_dir): 394 | os.makedirs(output_dir) 395 | torch.save(unet.state_dict(), os.path.join(output_dir, "unet.pt")) 396 | else: 397 | LoraLoaderMixin.save_lora_weights( 398 | save_directory=output_dir, 399 | unet_lora_layers=unet_lora_layers, 400 | text_encoder_lora_layers=text_encoder_lora_layers, 401 | ) 402 | if args.tune_vae: 403 | torch.save(self.accelerator.unwrap_model(self.pipeline.vae).state_dict(), os.path.join(output_dir, "vae.pt")) 404 | 405 | if args.tune_text_encoder: 406 | torch.save(self.accelerator.unwrap_model(self.pipeline.text_encoder).state_dict(), os.path.join(output_dir, "text_encoder.pt")) 407 | 408 | 409 | if args.gan_loss: 410 | os.makedirs(os.path.join(output_dir, 'D_sd'), exist_ok=True) 411 | D_sd_unet = self.accelerator.unwrap_model(self.D.unet) 412 | D_sd_unet_lora_layers = unet_lora_state_dict(D_sd_unet) 413 | 414 | D_sd_text_encoder_lora_layers = None 415 | if args.full_finetuning: 416 | if not os.path.exists(output_dir): 417 | os.makedirs(output_dir) 418 | torch.save(D_sd_unet.state_dict(), os.path.join(output_dir, 'D_sd', "unet.pt")) 419 | else: 420 | LoraLoaderMixin.save_lora_weights( 421 | save_directory=os.path.join(output_dir, 'D_sd'), 422 | unet_lora_layers=D_sd_unet_lora_layers, 423 | text_encoder_lora_layers=D_sd_text_encoder_lora_layers, 424 | ) 425 | # save mlp 426 | torch.save(self.accelerator.unwrap_model(self.D.mlp).state_dict(), 
os.path.join(output_dir, 'D_sd', "mlp.pt")) 427 | 428 | def dummy_checker(image, device, dtype): 429 | return image, None 430 | 431 | # Load previous pipeline 432 | self.pipeline.run_safety_checker = dummy_checker 433 | ori_scheduler = self.pipeline.scheduler 434 | ori_unet = self.pipeline.unet 435 | self.pipeline.unet = unet 436 | if args.train_text_encoder_lora: 437 | ori_text_encoder = self.pipeline.text_encoder 438 | self.pipeline.text_encoder = self.accelerator.unwrap_model(self.pipeline.text_encoder) 439 | 440 | if args.scheduler == "DPM++": 441 | self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config) 442 | elif args.scheduler == "DDPM": 443 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it 444 | scheduler_args = {} 445 | 446 | if "variance_type" in self.pipeline.scheduler.config: 447 | variance_type = self.pipeline.scheduler.config.variance_type 448 | 449 | if variance_type in ["learned", "learned_range"]: 450 | variance_type = "fixed_small" 451 | 452 | scheduler_args["variance_type"] = variance_type 453 | 454 | self.pipeline.scheduler = DDPMScheduler.from_config(self.pipeline.scheduler.config, **scheduler_args) 455 | 456 | images = [] 457 | if args.validation_prompts and args.num_validation_images > 0: 458 | if args.validation_prompts_file is not None: 459 | with open(args.validation_prompts_file, 'r') as f: 460 | val_prompts_from_file = f.readlines() 461 | validation_prompts = args.validation_prompts + val_prompts_from_file 462 | validation_prompts = [p.strip() for p in validation_prompts] 463 | else: 464 | validation_prompts = args.validation_prompts 465 | generator = torch.Generator(device=self.accelerator.device).manual_seed(args.seed) if args.seed else None 466 | # avoid oom by shrinking bs 467 | all_images = [[] for _ in range(args.num_validation_images)] 468 | for start in range(0, len(validation_prompts), 1): 469 | prompts = validation_prompts[start: start+1] 470 | with torch.autocast(device_type='cuda'): 471 | images = [ 472 | self.pipeline(prompts, num_inference_steps=args.total_step, generator=generator, guidance_scale=args.cfg_scale, guidance_rescale=args.cfg_rescale).images 473 | for _ in range(args.num_validation_images) 474 | ] 475 | for i, img in enumerate(images): 476 | all_images[i].extend(img) 477 | 478 | images = all_images 479 | 480 | new_images = [[] for _ in validation_prompts] 481 | for image in images: 482 | for i, img in enumerate(image): 483 | new_images[i].append(img) 484 | 485 | for tracker in self.accelerator.trackers: 486 | if tracker.name == "tensorboard": 487 | for i, image in enumerate(new_images): 488 | np_images = np.stack([np.asarray(img) for img in image]) 489 | tracker.writer.add_images(f"test_{i}", np_images, n_iter, dataformats="NHWC") 490 | 491 | self.pipeline.scheduler = ori_scheduler 492 | self.pipeline.unet = ori_unet 493 | if args.train_text_encoder_lora: 494 | self.pipeline.text_encoder = ori_text_encoder 495 | 496 | # evaluate before training 497 | if self.accelerator.is_main_process and not self.accelerator.is_last_process and global_step == 0 and resume_step == 0: 498 | with torch.no_grad(): 499 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 500 | save_and_evaluate(save_path, global_step) 501 | torch.cuda.empty_cache() 502 | self.accelerator.wait_for_everyone() 503 | 504 | # evaluate after resume 505 | if self.accelerator.is_main_process and not self.accelerator.is_last_process and 
global_step%100 == 0: 506 | with torch.no_grad(): 507 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 508 | save_and_evaluate(save_path, global_step, save=False) 509 | torch.cuda.empty_cache() 510 | self.accelerator.wait_for_everyone() 511 | 512 | 513 | if args.do_classifier_free_guidance: 514 | if args.gan_loss: 515 | with torch.no_grad(): 516 | gan_null_embed, gan_pooled_null_embed = self.D.encode_prompt("", self.accelerator.device, args.train_batch_size, do_classifier_free_guidance=False) 517 | 518 | # embed for pipeline 519 | with torch.no_grad(): 520 | if isinstance(self.pipeline, TrainableSDPipeline): 521 | null_embed = self.pipeline.encode_prompt("", self.accelerator.device, args.train_batch_size, do_classifier_free_guidance=False)[0] 522 | elif isinstance(self.pipeline, TrainableSDXLPipeline): 523 | null_embed, _, pooled_null_embed, _ = self.pipeline.encode_prompt("", device=self.accelerator.device, num_images_per_prompt=args.train_batch_size, do_classifier_free_guidance=False) 524 | else: 525 | raise NotImplementedError("This model is not supported yet") 526 | 527 | # remove unnecessary D pipeline vae and text encoders 528 | if args.gan_loss: 529 | del self.D.D_sd_pipeline.vae 530 | del self.D.D_sd_pipeline.text_encoder 531 | if isinstance(self.D.D_sd_pipeline, TrainableSDXLPipeline): 532 | del self.D.D_sd_pipeline.text_encoder_2 533 | 534 | torch.cuda.empty_cache() 535 | 536 | step_count = 0 537 | for epoch in range(first_epoch, args.num_train_epochs): 538 | self.pipeline.unet.train() 539 | if args.tune_text_encoder or args.train_text_encoder_lora: 540 | self.pipeline.text_encoder.train() 541 | self.pipeline.text_encoder_2.train() if hasattr(self.pipeline, "text_encoder_2") else None 542 | train_loss = 0.0 543 | for step, batch in enumerate(self.train_dataloader): 544 | # Skip steps until we reach the resumed step 545 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 546 | if step % args.gradient_accumulation_steps == 0: 547 | progress_bar.update(1) 548 | continue 549 | 550 | if args.batch_repeat > 1: 551 | batch['text'] = batch['text'] * args.batch_repeat 552 | 553 | total_step = args.total_step 554 | 555 | # train diffusion model 556 | with self.accelerator.accumulate(self.pipeline.unet): 557 | # setting of backward 558 | bp_on_trained = True 559 | early_exit = False 560 | double_laststep = False 561 | fast_training = False 562 | 563 | interval = total_step // args.K 564 | max_start = total_step - interval * (args.K - 1) - 1 565 | start = random.randint(0, max_start) 566 | training_steps = list(range(start, total_step, interval)) 567 | detach_gradient = True 568 | 569 | if args.tune_text_encoder or args.train_text_encoder_lora: 570 | if isinstance(self.pipeline, TrainableSDPipeline): 571 | null_embed = self.pipeline.encode_prompt("", self.accelerator.device, args.train_batch_size, do_classifier_free_guidance=False)[0] 572 | elif isinstance(self.pipeline, TrainableSDXLPipeline): 573 | null_embed, _, pooled_null_embed, _ = self.pipeline.encode_prompt("", device=self.accelerator.device, num_images_per_prompt=args.train_batch_size, do_classifier_free_guidance=False) 574 | 575 | kwargs = dict( 576 | prompt=batch["text"], 577 | height=args.resolution, 578 | width=args.resolution, 579 | training_timesteps=training_steps, 580 | detach_gradient=detach_gradient, 581 | train_text_encoder=args.tune_text_encoder or args.train_text_encoder_lora, 582 | num_inference_steps=total_step, 583 | guidance_scale=args.cfg_scale, 584 | 
guidance_rescale=args.cfg_rescale, 585 | negative_prompt_embeds=null_embed if args.do_classifier_free_guidance else None, 586 | early_exit=early_exit, 587 | return_latents=True if args.gan_loss else False, 588 | ) 589 | if 'attrcon' in args.pretrain_model_name: 590 | kwargs['attrcon_train_steps'] = random.choices(training_steps, k=min(args.attrcon_train_steps, len(training_steps))) 591 | 592 | if isinstance(self.pipeline, TrainableSDPipeline): 593 | if args.gan_loss: 594 | image, training_latents = self.pipeline.forward(bp_on_trained=bp_on_trained, double_laststep=double_laststep, fast_training=fast_training, **kwargs) 595 | else: 596 | image = self.pipeline.forward(bp_on_trained=bp_on_trained, double_laststep=double_laststep, fast_training=fast_training, **kwargs) 597 | elif isinstance(self.pipeline, TrainableSDXLPipeline): 598 | if args.gan_loss: 599 | image, training_latents = self.pipeline.forward(negative_pooled_prompt_embeds=pooled_null_embed if args.do_classifier_free_guidance else None, **kwargs) 600 | else: 601 | image = self.pipeline.forward(negative_pooled_prompt_embeds=pooled_null_embed if args.do_classifier_free_guidance else None, **kwargs) 602 | else: 603 | raise NotImplementedError("This model is not supported yet") 604 | 605 | # reward 606 | offset_range = args.resolution // 224 607 | random_offset_x = random.randint(0, offset_range) 608 | random_offset_y = random.randint(0, offset_range) 609 | size = args.resolution - offset_range 610 | caption_rewards = self.caption_model( 611 | image[:,:,random_offset_x:random_offset_x + size, random_offset_y:random_offset_y + size].to(self.weight_dtype), 612 | batch['text'], 613 | step=step_count, 614 | text_encoder=self.pipeline.text_encoder, 615 | batch=batch) 616 | step_count += 1 617 | 618 | loss = - caption_rewards["total"].mean() 619 | 620 | if args.gan_loss: 621 | kwargs['negative_prompt_embeds'] = gan_null_embed # used in D_sd 622 | kwargs['negative_pooled_prompt_embeds'] = gan_pooled_null_embed 623 | 624 | G_loss = self.D.D_sd_pipeline_forward(training_latents, side='G' ,**kwargs) 625 | loss += args.gan_loss_weight * G_loss 626 | 627 | if 'attrcon' in args.pretrain_model_name: 628 | 629 | all_subtree_indices = [self.pipeline._extract_attribution_indices(p) for p in batch['text']] 630 | attn_map_idx_to_wp_all = [get_attention_map_index_to_wordpiece(self.pipeline.tokenizer, p) for p in batch['text'] ] 631 | attn_map = self.pipeline.attn_dict 632 | 633 | token_loss, pixel_loss, grounding_loss_dict = self.seg_model.get_mask_loss( 634 | image.clamp(0, 1), # this does not require grad, so clamp(0,1) is better 635 | batch['text'], 636 | all_subtree_indices, 637 | attn_map_idx_to_wp_all, 638 | attn_map) 639 | loss += args.mask_token_loss_weight * token_loss 640 | loss += args.mask_pixel_loss_weight * pixel_loss 641 | 642 | self.pipeline.attn_dict = {} # clear the attn_dict after usage, or it will cause error in next iter 643 | 644 | norm = {} 645 | def record_grad(grad): 646 | norm['reward_norm'] = grad.norm(2).item() 647 | if args.norm_grad: 648 | grad = grad / (norm['reward_norm'] / 1e4) # 1e4 for numerical stability 649 | return grad 650 | 651 | image.register_hook(record_grad) 652 | 653 | # Gather the losses across all processes for logging (if we use distributed training). 
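# Hypothetical illustration of the logging math below (example numbers only): with 2 processes,
# train_batch_size == 2 and gradient_accumulation_steps == 4, loss.repeat(2) followed by
# gather() yields 4 values per micro-step, and train_loss accumulates the mean of 4 such
# micro-steps before being logged and reset at the next sync point.
# For the reward crop above: resolution == 512 gives offset_range == 512 // 224 == 2, so the
# caption model scores a 510 x 510 crop jittered by up to 2 px, and record_grad rescales the
# image gradient to an L2 norm of roughly 1e4 when --norm_grad is set.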
654 | avg_loss = self.accelerator.gather(loss.repeat(args.train_batch_size)).mean() 655 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 656 | 657 | # Backpropagate 658 | self.optimizer.zero_grad() 659 | self.accelerator.backward(loss) 660 | 661 | if self.accelerator.sync_gradients: 662 | self.accelerator.clip_grad_norm_(self.G_parameters, args.max_grad_norm) 663 | self.optimizer.step() 664 | self.lr_scheduler.step() 665 | 666 | 667 | logs = {"step_loss": loss.detach().item(), "lr": self.lr_scheduler.get_last_lr()[0]} 668 | logs.update({k: self.accelerator.gather(v.detach()).mean().item() for k, v in caption_rewards.items()}) 669 | 670 | if args.gan_loss: 671 | logs.update({'G_loss': self.accelerator.gather(G_loss.detach()).mean().item()}) 672 | 673 | if 'attrcon' in args.pretrain_model_name: 674 | logs.update({'token_loss': self.accelerator.gather(token_loss.detach()).mean().item()}) 675 | logs.update({'pixel_loss': self.accelerator.gather(pixel_loss.detach()).mean().item()}) 676 | 677 | logs.update(norm) 678 | 679 | if args.gan_loss: 680 | with self.accelerator.accumulate(self.D): 681 | kwargs['batch'] = batch 682 | 683 | D_loss = self.D.D_sd_pipeline_forward(training_latents.detach(), side='D', **kwargs) 684 | avg_D_loss = self.accelerator.gather(D_loss.detach()).mean().item() 685 | logs.update(dict( 686 | D_loss=avg_D_loss, 687 | )) 688 | 689 | self.D_optimizer.zero_grad() 690 | self.accelerator.backward(D_loss) 691 | 692 | if self.accelerator.sync_gradients: 693 | self.accelerator.clip_grad_norm_(self.D_parameters, args.max_grad_norm_D) 694 | self.D_optimizer.step() 695 | 696 | 697 | # Checks if the self.accelerator has performed an optimization step behind the scenes 698 | if self.accelerator.sync_gradients: 699 | 700 | progress_bar.update(1) 701 | global_step += 1 702 | self.accelerator.log({"train_loss": train_loss}, step=global_step) 703 | self.accelerator.log(logs, step=global_step) 704 | train_loss = 0.0 705 | 706 | logger.info(f"{global_step}: {json.dumps(logs, sort_keys=False, indent=4)}") 707 | 708 | progress_bar.set_postfix(**logs) 709 | 710 | if global_step % args.validation_steps == 0 and self.accelerator.sync_gradients: 711 | if self.accelerator.is_main_process: 712 | with torch.no_grad(): 713 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 714 | save_and_evaluate(save_path, global_step) 715 | torch.cuda.empty_cache() 716 | self.accelerator.wait_for_everyone() 717 | 718 | if global_step >= args.max_train_steps: 719 | break 720 | 721 | 722 | # Save the lora layers 723 | self.accelerator.wait_for_everyone() 724 | self.accelerator.end_training() 725 | 726 | 727 | if __name__ == "__main__": 728 | args = parse_args() 729 | trainer = Trainer(args.pretrain_model, args=args) 730 | trainer.train(args=args) -------------------------------------------------------------------------------- /training_utils/arguments.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def parse_args(): 4 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 5 | parser.add_argument( 6 | "--pretrain_model", type=str, default="runwayml/stable-diffusion-v1-5", help="The pretrained model to use." 
7 | ) 8 | parser.add_argument( 9 | "--pretrain_model_name", type=str, default="sd_1_5", help="the name of pretrained model", 10 | choices=["sd_1_5", "sdxl", "sd_1_5_attrcon", "sdxl_unet", "sdxl_attrcon_unet", "sdxl_attrcon"] 11 | ) 12 | parser.add_argument( 13 | "--caption_model", 14 | type=str, 15 | choices=['Blip'], 16 | default=["Blip"], 17 | nargs="+", 18 | help="The reward model to use.", 19 | ) 20 | parser.add_argument( 21 | "--reward_weights", 22 | type=float, 23 | default=None, 24 | nargs="+", 25 | help="The weight of each reward model.", 26 | ) 27 | parser.add_argument( 28 | "--revision", 29 | type=str, 30 | default=None, 31 | required=False, 32 | help="Revision of pretrained model identifier from huggingface.co/models.", 33 | ) 34 | parser.add_argument( 35 | "--max_train_samples", 36 | type=int, 37 | default=None, 38 | help=( 39 | "For debugging purposes or quicker training, truncate the number of training examples to this " 40 | "value if set." 41 | ), 42 | ) 43 | parser.add_argument( 44 | "--validation_prompts", 45 | type=str, 46 | default=None, 47 | nargs="+", 48 | help=("A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."), 49 | ) 50 | parser.add_argument( 51 | "--validation_prompts_file", 52 | type=str, 53 | default=None, 54 | help=("A set of prompts evaluated every `--validation_steps` and logged to `--report_to`."), 55 | ) 56 | parser.add_argument( 57 | "--output_dir", 58 | type=str, 59 | default="checkpoint/refl", 60 | help="The output directory where the model predictions and checkpoints will be written.", 61 | ) 62 | parser.add_argument( 63 | "--cache_dir", 64 | type=str, 65 | default=None, 66 | help="The directory where the downloaded models and datasets will be stored.", 67 | ) 68 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 69 | parser.add_argument( 70 | "--resolution", 71 | type=int, 72 | default=512, 73 | help=( 74 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 75 | " resolution" 76 | ), 77 | ) 78 | parser.add_argument( 79 | "--center_crop", 80 | default=False, 81 | action="store_true", 82 | help=( 83 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 84 | " cropped. The images will be resized to the resolution first before cropping." 85 | ), 86 | ) 87 | 88 | parser.add_argument( 89 | "--train_batch_size", type=int, default=2, help="Batch size (per device) for the training dataloader." 90 | ) 91 | parser.add_argument( 92 | "--batch_repeat", type=int, default=1, help="Repeat the batch." 93 | ) 94 | parser.add_argument("--num_train_epochs", type=int, default=100) 95 | parser.add_argument( 96 | "--max_train_steps", 97 | type=int, 98 | default=100, 99 | help="Total number of training steps to perform. 
If provided, overrides num_train_epochs.", 100 | ) 101 | parser.add_argument( 102 | "--gradient_accumulation_steps", 103 | type=int, 104 | default=4, 105 | help="Number of updates steps to accumulate before performing a backward/update pass.", 106 | ) 107 | parser.add_argument( 108 | "--gradient_checkpointing", 109 | action="store_true", 110 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 111 | ) 112 | parser.add_argument( 113 | "--learning_rate", 114 | type=float, 115 | default=1e-5, 116 | help="Initial learning rate (after the potential warmup period) to use.", 117 | ) 118 | parser.add_argument( 119 | "--learning_rate_D", 120 | type=float, 121 | default=1e-5, 122 | help="Initial learning rate (after the potential warmup period) to use.", 123 | ) 124 | parser.add_argument( 125 | "--lr_scheduler", 126 | type=str, 127 | default="constant", 128 | help=( 129 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 130 | ' "constant", "constant_with_warmup"]' 131 | ), 132 | ) 133 | parser.add_argument( 134 | "--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler." 135 | ) 136 | parser.add_argument( 137 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 138 | ) 139 | parser.add_argument( 140 | "--allow_tf32", 141 | action="store_true", 142 | help=( 143 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 144 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 145 | ), 146 | ) 147 | parser.add_argument( 148 | "--dataloader_num_workers", 149 | type=int, 150 | default=4, 151 | help=( 152 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 153 | ), 154 | ) 155 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 156 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 157 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 158 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 159 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 160 | parser.add_argument( 161 | "--logging_dir", 162 | type=str, 163 | default="logs", 164 | help=( 165 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 166 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 167 | ), 168 | ) 169 | parser.add_argument( 170 | "--mixed_precision", 171 | type=str, 172 | default=None, 173 | choices=["no", "fp16", "bf16"], 174 | help=( 175 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 176 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 177 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 178 | ), 179 | ) 180 | parser.add_argument( 181 | "--report_to", 182 | type=str, 183 | default="tensorboard", 184 | help=( 185 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 186 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 
187 | ), 188 | ) 189 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 190 | parser.add_argument( 191 | "--checkpoints_total_limit", 192 | type=int, 193 | default=None, 194 | help=( 195 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 196 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 197 | " for more docs" 198 | ), 199 | ) 200 | parser.add_argument( 201 | "--resume_from_checkpoint", 202 | type=str, 203 | default="latest", 204 | help=( 205 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 206 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 207 | ), 208 | ) 209 | parser.add_argument( 210 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 211 | ) 212 | parser.add_argument( 213 | "--enable_torch2_product", action="store_true", help="Whether or not to use torch2 product. \ 214 | see https://huggingface.co/docs/diffusers/optimization/torch2.0" 215 | ) 216 | 217 | parser.add_argument( 218 | "--validation_steps", 219 | type=int, 220 | default=5, 221 | help="Run validation every X epochs.", 222 | ) 223 | parser.add_argument( 224 | "--tracker_project_name", 225 | type=str, 226 | default="comat", 227 | help=( 228 | "The `project_name` argument passed to Accelerator.init_trackers for" 229 | " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" 230 | ), 231 | ) 232 | 233 | parser.add_argument( 234 | "--num_validation_images", 235 | type=int, 236 | default=4, 237 | ) 238 | 239 | parser.add_argument( 240 | "--lora_rank", 241 | type=int, 242 | default=32, 243 | help=("The dimension of the LoRA update matrices."), 244 | ) 245 | 246 | parser.add_argument( 247 | "--K", 248 | type=int, 249 | default=1, 250 | ) 251 | 252 | parser.add_argument( 253 | "--cfg_scale", 254 | type=float, 255 | default=7.5, 256 | help=("classifier free guidance scale for training and inference."), 257 | ) 258 | 259 | parser.add_argument( 260 | "--cfg_rescale", 261 | type=float, 262 | default=0.0, 263 | help=("cfg rescale."), 264 | ) 265 | 266 | parser.add_argument( 267 | "--training_prompts", 268 | type=str, 269 | default="refl_data.txt", 270 | ) 271 | 272 | parser.add_argument( 273 | "--image_folder", 274 | type=str 275 | ) 276 | 277 | parser.add_argument( 278 | "--total_step", 279 | type=int, 280 | default=40, 281 | ) 282 | 283 | parser.add_argument( 284 | "--scheduler", 285 | type=str, 286 | default='DDPM', 287 | choices=["DDPM"] 288 | ) 289 | 290 | parser.add_argument( 291 | "--bp_on_trained", 292 | action="store_true" 293 | ) 294 | 295 | parser.add_argument( 296 | "--full_finetuning", 297 | action="store_true" 298 | ) 299 | 300 | parser.add_argument( 301 | "--tune_vae", 302 | action="store_true" 303 | ) 304 | 305 | parser.add_argument( 306 | "--tune_text_encoder", 307 | action="store_true" 308 | ) 309 | parser.add_argument( 310 | "--train_text_encoder_lora", 311 | action="store_true" 312 | ) 313 | parser.add_argument( 314 | "--textenc_lora_lr", 315 | type=float, 316 | default=None 317 | ) 318 | parser.add_argument( 319 | "--norm_grad", 320 | action="store_true" 321 | ) 322 | parser.add_argument( 323 | "--prediction_type", 324 | type=str, 325 | default=None, 326 | ) 327 | 328 | parser.add_argument( 329 | 
"--gan_model_arch", 330 | type=str, 331 | default='gan_sd_1_5', 332 | ) 333 | parser.add_argument( 334 | "--gan_loss", 335 | action='store_true' 336 | ) 337 | parser.add_argument( 338 | "--condition_discriminator", 339 | action='store_true' 340 | ) 341 | parser.add_argument( 342 | "--gan_loss_weight", 343 | type=float, 344 | default=1.0, 345 | ) 346 | parser.add_argument("--adam_beta1_D", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 347 | parser.add_argument("--adam_beta2_D", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 348 | parser.add_argument( 349 | "--max_grad_norm_D", 350 | type=float, 351 | default=1.0, 352 | ) 353 | parser.add_argument( 354 | "--gan_unet_lastlayer_cls", 355 | action='store_true' 356 | ) 357 | parser.add_argument( 358 | "--mask_token_loss_weight", 359 | type=float, 360 | default=1e-3 361 | ) 362 | parser.add_argument( 363 | "--mask_pixel_loss_weight", 364 | type=float, 365 | default=5e-5 366 | ) 367 | parser.add_argument( 368 | "--attrcon_train_steps", 369 | type=int, 370 | default=5 371 | ) 372 | parser.add_argument( 373 | "--sdxl_unet_path", 374 | type=str, 375 | default=None 376 | ) 377 | parser.add_argument( 378 | "--seg_model", 379 | type=str, 380 | choices=['gsam'], 381 | default=["gsam"], 382 | nargs="+", 383 | help="The reward model to use.", 384 | ) 385 | parser.add_argument( 386 | "--optimizer_class", 387 | type=str, 388 | default='AdamW' 389 | ) 390 | 391 | args = parser.parse_args() 392 | 393 | args.do_classifier_free_guidance = args.cfg_scale > 1.0 394 | 395 | if args.reward_weights is None: 396 | args.reward_weights = [1.0] * len(args.caption_model) 397 | 398 | return args -------------------------------------------------------------------------------- /training_utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision import transforms 4 | from datasets import load_dataset 5 | from PIL import Image 6 | 7 | from training_utils.gan_dataset import Gan_Dataset 8 | 9 | 10 | def get_dataset_dataloader(args, accelerator): 11 | if args.gan_loss: 12 | dataset = Gan_Dataset(args) 13 | elif args.training_prompts.endswith("txt"): 14 | dataset = load_dataset("text", data_files=dict(train=args.training_prompts)) 15 | elif args.training_prompts.endswith("json"): 16 | dataset = load_dataset("json", data_files=dict(train=args.training_prompts)) 17 | 18 | image_transforms = transforms.Compose( 19 | [ 20 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 21 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 22 | transforms.ToTensor() 23 | ] 24 | ) 25 | 26 | def preprocess_train(instance): 27 | if 'file_name' in instance: 28 | filenames = instance.pop('file_name') 29 | images = [Image.open(os.path.join(args.image_folder, filename)).convert("RGB") for filename in filenames] 30 | images = [image_transforms(image) for image in images] 31 | instance['image'] = torch.stack(images) 32 | return instance 33 | 34 | 35 | with accelerator.main_process_first(): 36 | if args.gan_loss: 37 | train_dataset = dataset 38 | else: 39 | dataset["train"] = dataset["train"].shuffle(seed=args.seed + accelerator.process_index) 40 | if args.max_train_samples is not None: 41 | dataset["train"] = dataset["train"].select(range(args.max_train_samples)) 42 | # Set the training transforms 43 | train_dataset = 
dataset["train"].with_transform(preprocess_train) 44 | 45 | 46 | # DataLoaders creation: 47 | train_dataloader = torch.utils.data.DataLoader( 48 | train_dataset, 49 | shuffle=True, 50 | batch_size=args.train_batch_size, 51 | num_workers=args.dataloader_num_workers, 52 | ) 53 | 54 | return train_dataset, train_dataloader -------------------------------------------------------------------------------- /training_utils/gan_dataset.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy 3 | import json 4 | import matplotlib.pyplot as plt 5 | from PIL import Image 6 | import os 7 | # from petrel_client.client import Client 8 | # from aoss_client.client import Client 9 | 10 | import torch 11 | from torch.utils.data import Dataset 12 | from pathlib import Path 13 | from torchvision import transforms 14 | from PIL import Image 15 | from transformers import CLIPTextModel, CLIPTokenizer 16 | import os 17 | import io 18 | import json 19 | import random 20 | 21 | def read_jsonl(save_path): 22 | ret_list = [] 23 | with open(save_path, 'r') as f: 24 | for line in f: 25 | ret_list.append(json.loads(line)) 26 | return ret_list 27 | 28 | 29 | class Gan_Dataset(Dataset): 30 | """ 31 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 32 | It pre-processes the images and the tokenizes prompts. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | args, 38 | ): 39 | self.args = args 40 | 41 | if 'txt' in args.training_prompts: 42 | self.ann = list() 43 | with open(args.training_prompts, 'r') as f: 44 | for line in f: 45 | self.ann.append(line.strip()) 46 | elif 'jsonl' in args.training_prompts: 47 | self.ann = read_jsonl(args.training_prompts) 48 | elif 'json' in args.training_prompts: 49 | self.ann = json.load(open(args.training_prompts, 'r')) 50 | 51 | # self.client = Client('~/aoss.conf') 52 | 53 | def __len__(self): 54 | return len(self.ann) 55 | 56 | def __getitem__(self, index): 57 | example = {} 58 | example['text'] = self.ann[index]['prompt'] 59 | latent_path = self.ann[index]['file_path'] if not isinstance(self.ann[index]['file_path'], list) else random.choice(self.ann[index]['file_path']) 60 | # use ceph 61 | # with io.BytesIO(self.client.get(latent_path)) as f: 62 | # latents = torch.load(f) 63 | # use disk 64 | latents = torch.load(latent_path) 65 | 66 | example['latents'] = latents 67 | 68 | # add another potential key in example 69 | eliminate_keys = ['prompt', 'file_path', 'image'] 70 | for k in self.ann[index].keys(): 71 | if k not in eliminate_keys: 72 | example[k] = self.ann[index][k] 73 | 74 | return example 75 | 76 | 77 | -------------------------------------------------------------------------------- /training_utils/gan_sd_model.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from training_utils.pipeline import * 3 | from training_utils.pipeline import get_trainable_parameters as _get_trainable_parameters 4 | from training_utils.gan_sdxl import D_sd, D_sdxl 5 | import copy 6 | import pdb 7 | 8 | def load_discriminator(args, weight_dtype, device): 9 | args.gan_model_arch = args.gan_model_arch.replace('gan', '') 10 | 11 | if args.gan_model_arch == 'sd_1_5': 12 | return D_sd(args, weight_dtype, device) 13 | elif 'sdxl' in args.gan_model_arch : 14 | return D_sdxl(args, weight_dtype, device) 15 | 16 | -------------------------------------------------------------------------------- /training_utils/gan_sdxl.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from training_utils.pipeline import * 3 | from training_utils.pipeline import get_trainable_parameters as _get_trainable_parameters 4 | import copy 5 | 6 | class D_sd(nn.Module): 7 | def __init__(self, args, weight_dtype, device=None) -> None: 8 | super().__init__() 9 | self.D_args = copy.deepcopy(args) 10 | self.D_args.train_text_encoder_lora = False 11 | self.D_args.tune_text_encoder = False 12 | self.D_args.pretrain_model = "runwayml/stable-diffusion-v1-5" 13 | self.D_sd_pipeline = load_pipeline(self.D_args, args.gan_model_arch, weight_dtype, is_D=True).to(device) 14 | print("D_sd pipeline", self.D_sd_pipeline) 15 | 16 | self.weight_dtype = weight_dtype 17 | 18 | self.unet = self.D_sd_pipeline.unet 19 | 20 | if args.train_text_encoder_lora or args.tune_text_encoder: 21 | self.D_sd_pipeline.text_encoder.to(device) 22 | self.text_encoder = self.D_sd_pipeline.text_encoder 23 | 24 | self.ori_scheduler = copy.deepcopy(self.D_sd_pipeline.scheduler) 25 | 26 | # for classification 27 | if args.gan_unet_lastlayer_cls: 28 | ori_last_conv = self.D_sd_pipeline.unet.conv_out 29 | self.mlp = nn.Conv2d(ori_last_conv.in_channels, 1, ori_last_conv.kernel_size, ori_last_conv.padding, ori_last_conv.stride) 30 | self.unet.conv_out = self.mlp # to hack in accelerator.prepare 31 | else: # MLP 32 | self.mlp = nn.Sequential( 33 | nn.Linear(4, 1) 34 | ) 35 | self.cls_loss_fn = nn.BCEWithLogitsLoss() 36 | 37 | def get_trainable_parameters(self): 38 | self.D_parameters = _get_trainable_parameters(self.D_args, self.D_sd_pipeline, is_D=True) 39 | self.D_parameters.extend([p for p in self.mlp.parameters()]) 40 | return self.D_parameters 41 | 42 | def set_D_sd_pipeline_lora(self, requires_grad=True): 43 | for p in self.D_parameters: 44 | p.requires_grad = requires_grad 45 | 46 | def get_D_gt_noise(self, device, **kwargs): 47 | ori_latents = kwargs['batch']['latents'].to(device, dtype=self.weight_dtype) 48 | return ori_latents 49 | 50 | def D_sd_pipeline_forward(self, training_latents, side='G',**kwargs): 51 | device = training_latents.device 52 | if side == 'G': 53 | 54 | # set D_sd_pipeline no grad 55 | self.unet.eval() 56 | self.set_D_sd_pipeline_lora(requires_grad=False) 57 | 58 | # get discriminator condition 59 | if self.D_args.condition_discriminator: 60 | D_cond = self.pipeline.encode_prompt( 61 | prompt=kwargs['prompt'], 62 | device=device, 63 | num_images_per_prompt=1, 64 | do_classifier_free_guidance=False)[0] 65 | else: 66 | D_cond = kwargs['negative_prompt_embeds'] 67 | 68 | self.ori_scheduler.set_timesteps(kwargs['num_inference_steps'], device=device) 69 | timesteps = self.ori_scheduler.timesteps 70 | training_latents = self.ori_scheduler.scale_model_input(training_latents, timesteps[-1]) # identity in ddpm 71 | 72 | noise_pred = self.unet( 73 | training_latents, 74 | timesteps[-1], # marking its own step, could be out of domain 75 | encoder_hidden_states=D_cond, 76 | cross_attention_kwargs=None, 77 | return_dict=False, 78 | )[0] 79 | 80 | noise_pred = noise_pred.permute(0,2,3,1) # -> (bs, h, w, 4) 81 | if self.D_args.gan_unet_lastlayer_cls: 82 | pred = noise_pred # noise_pred -> (bs, h, w, 1) 83 | else: 84 | pred = self.mlp(noise_pred) # -> (bs, h, w, 2) 85 | target = torch.ones_like(pred) # -> (bs, h, w, 2) 86 | 87 | with torch.autocast('cuda'): 88 | gan_loss = self.cls_loss_fn(pred, target) 89 | return gan_loss 90 | 91 | 92 | elif side == 'D': 93 | 94 | self.unet.train() 95 | 
self.set_D_sd_pipeline_lora(requires_grad=True) 96 | training_latents.requires_grad_(False) 97 | 98 | D_cond = torch.cat( 99 | [kwargs['negative_prompt_embeds'], kwargs['negative_prompt_embeds']] 100 | ) 101 | 102 | with torch.no_grad(): 103 | ori_latents = self.get_D_gt_noise(device, **kwargs) 104 | 105 | 106 | self.ori_scheduler.set_timesteps(kwargs['num_inference_steps'], device=device) 107 | timesteps = self.ori_scheduler.timesteps 108 | 109 | input_latents = torch.cat([training_latents, ori_latents]) 110 | input_latents = self.ori_scheduler.scale_model_input(input_latents, timesteps[-1]) # identity in ddpm 111 | 112 | noise_pred = self.unet( 113 | input_latents, 114 | timesteps[-1], # marking its own step, could be out of domain 115 | encoder_hidden_states=D_cond, 116 | cross_attention_kwargs=None, 117 | return_dict=False, 118 | )[0] # (2*bs, 4, h, w) 119 | 120 | 121 | noise_pred = noise_pred.permute(0,2,3,1) # -> (2*bs, h, w, 4) 122 | if self.D_args.gan_unet_lastlayer_cls: 123 | pred = noise_pred # noise_pred -> (bs, h, w, 1) 124 | else: 125 | pred = self.mlp(noise_pred) # -> (bs, h, w, 2) 126 | target = torch.ones_like(pred) # -> (2*bs, h, w, 1) 127 | 128 | target[:target.shape[0]//2] = 0 # label = 0 for generated_image 129 | 130 | with torch.autocast('cuda'): 131 | gan_loss = self.cls_loss_fn(pred, target) 132 | return gan_loss 133 | 134 | @torch.no_grad() 135 | def encode_prompt(self, prompt, device, batch_size, do_classifier_free_guidance=False): 136 | if isinstance(self.D_sd_pipeline, TrainableSDPipeline) or isinstance(self.D_sd_pipeline, SynTrainableSDPipeline): 137 | null_embed = self.D_sd_pipeline.encode_prompt( 138 | prompt, 139 | device, 140 | batch_size, 141 | do_classifier_free_guidance=do_classifier_free_guidance 142 | )[0] 143 | pooled_null_embed = None 144 | elif isinstance(self.D_sd_pipeline, TrainableSDXLPipeline): 145 | null_embed, _, pooled_null_embed, _ = self.D_sd_pipeline.encode_prompt( 146 | prompt, 147 | device=device, 148 | num_images_per_prompt=batch_size, 149 | do_classifier_free_guidance=do_classifier_free_guidance 150 | ) 151 | # assume this function will be only called once 152 | self.D_sd_pipeline.text_encoder.to('cpu') 153 | self.D_sd_pipeline.vae.to('cpu') 154 | 155 | return null_embed, pooled_null_embed 156 | 157 | 158 | class D_sdxl(D_sd): 159 | def __init__(self, args, weight_dtype, device=None) -> None: 160 | super().__init__() 161 | self.D_args = copy.deepcopy(args) 162 | self.D_args.train_text_encoder_lora = False 163 | self.D_args.tune_text_encoder = False 164 | self.D_args.pretrain_model = 'stabilityai/stable-diffusion-xl-base-1.0' 165 | self.D_sd_pipeline = load_pipeline(self.D_args, args.gan_model_arch, weight_dtype, is_D=True).to(device) 166 | 167 | self.weight_dtype = weight_dtype 168 | 169 | self.unet = self.D_sd_pipeline.unet 170 | 171 | if args.train_text_encoder_lora or args.tune_text_encoder: 172 | self.D_sd_pipeline.text_encoder.to(device) 173 | self.text_encoder = self.D_sd_pipeline.text_encoder 174 | 175 | self.ori_scheduler = copy.deepcopy(self.D_sd_pipeline.scheduler) 176 | 177 | # for classification 178 | if args.gan_unet_lastlayer_cls: 179 | ori_last_conv = self.D_sd_pipeline.unet.conv_out 180 | self.mlp = nn.Conv2d(ori_last_conv.in_channels, 1, ori_last_conv.kernel_size, ori_last_conv.padding, ori_last_conv.stride) 181 | self.unet.conv_out = self.mlp # to hack in accelerator.prepare 182 | else: # MLP 183 | self.mlp = nn.Sequential( 184 | nn.Linear(4, 1) 185 | ) 186 | self.cls_loss_fn = nn.BCEWithLogitsLoss() 187 | 188 | if 
args.train_text_encoder_lora or args.tune_text_encoder: 189 | self.D_sd_pipeline.text_encoder_2.to(device) 190 | self.text_encoder_2 = self.D_sd_pipeline.text_encoder_2 191 | 192 | # used in add_time_ids 193 | height = args.resolution or self.D_sd_pipeline.default_sample_size * self.D_sd_pipeline.vae_scale_factor 194 | width = args.resolution or self.D_sd_pipeline.default_sample_size * self.D_sd_pipeline.vae_scale_factor 195 | 196 | self.original_size = (height, width) 197 | self.target_size = (height, width) 198 | self.crops_coords_top_left = (0, 0) 199 | self.add_time_ids = self._get_add_time_ids( 200 | self.original_size, 201 | self.crops_coords_top_left, 202 | self.target_size, 203 | weight_dtype 204 | ).to(device) 205 | 206 | 207 | def D_sd_pipeline_forward(self, training_latents, side='G',**kwargs): 208 | device = training_latents.device 209 | bs = training_latents.shape[0] 210 | 211 | if side == 'G': 212 | 213 | # set D_sd_pipeline no grad 214 | self.unet.eval() 215 | self.set_D_sd_pipeline_lora(requires_grad=False) 216 | 217 | D_cond = kwargs['negative_prompt_embeds'] 218 | 219 | add_time_ids = self.add_time_ids.repeat(bs, 1) 220 | # Predict the noise residual 221 | unet_added_conditions = {"time_ids": add_time_ids} 222 | unet_added_conditions.update({"text_embeds": kwargs['negative_pooled_prompt_embeds']}) 223 | 224 | self.ori_scheduler.set_timesteps(kwargs['num_inference_steps'], device=device) 225 | timesteps = self.ori_scheduler.timesteps 226 | training_latents = self.ori_scheduler.scale_model_input(training_latents, timesteps[-1]) # identity in ddpm 227 | 228 | noise_pred = self.unet( 229 | training_latents, 230 | timesteps[-1], # marking its own step, could be out of domain 231 | encoder_hidden_states=D_cond, 232 | added_cond_kwargs=unet_added_conditions, 233 | return_dict=False, 234 | )[0] 235 | 236 | # noise_pred = noise_pred.to(kwargs['negative_prompt_embeds'].dtype).flatten(-2).permute(0,2,1) # -> (2*bs,h*w,4) 237 | noise_pred = noise_pred.permute(0,2,3,1) # -> (bs, h, w, 4) 238 | if self.D_args.gan_unet_lastlayer_cls: 239 | pred = noise_pred # noise_pred -> (bs, h, w, 1) 240 | else: 241 | pred = self.mlp(noise_pred) # -> (bs, h, w, 2) 242 | target = torch.ones_like(pred) # -> (bs, h, w, 2) 243 | 244 | with torch.autocast('cuda'): 245 | gan_loss = self.cls_loss_fn(pred, target) 246 | return gan_loss 247 | 248 | 249 | elif side == 'D': 250 | 251 | self.unet.train() 252 | self.set_D_sd_pipeline_lora(requires_grad=True) 253 | training_latents.requires_grad_(False) 254 | 255 | D_cond = torch.cat( 256 | [kwargs['negative_prompt_embeds'], kwargs['negative_prompt_embeds']] 257 | ) 258 | 259 | with torch.no_grad(): 260 | ori_latents = self.get_D_gt_noise(device, **kwargs) 261 | 262 | add_time_ids = self.add_time_ids.repeat(2*bs, 1) 263 | # Predict the noise residual 264 | unet_added_conditions = {"time_ids": add_time_ids} 265 | unet_added_conditions.update( 266 | {"text_embeds": torch.cat([ 267 | kwargs['negative_pooled_prompt_embeds'], kwargs['negative_pooled_prompt_embeds'] 268 | ], dim=0)} 269 | ) 270 | 271 | self.ori_scheduler.set_timesteps(kwargs['num_inference_steps'], device=device) 272 | timesteps = self.ori_scheduler.timesteps 273 | 274 | input_latents = torch.cat([training_latents, ori_latents]) 275 | input_latents = self.ori_scheduler.scale_model_input(input_latents, timesteps[-1]) # identity in ddpm 276 | 277 | noise_pred = self.unet( 278 | input_latents, 279 | timesteps[-1], # marking its own step, could be out of domain 280 | encoder_hidden_states=D_cond, 281 | 
added_cond_kwargs=unet_added_conditions, 282 | return_dict=False, 283 | )[0] # (2*bs, 4, h, w) 284 | 285 | noise_pred = noise_pred.permute(0,2,3,1) # -> (2*bs, h, w, 4) 286 | if self.D_args.gan_unet_lastlayer_cls: 287 | pred = noise_pred # noise_pred -> (bs, h, w, 1) 288 | else: 289 | pred = self.mlp(noise_pred) # -> (bs, h, w, 2) 290 | target = torch.ones_like(pred) # -> (2*bs, h, w, 1) 291 | target[:target.shape[0]//2] = 0 # label = 0 for generated_image 292 | 293 | with torch.autocast('cuda'): 294 | gan_loss = self.cls_loss_fn(pred, target) 295 | return gan_loss 296 | 297 | @torch.no_grad() 298 | def encode_prompt(self, prompt, device, batch_size, do_classifier_free_guidance=False): 299 | null_embed, pooled_null_embed = super().encode_prompt(prompt, device, batch_size, do_classifier_free_guidance) 300 | self.D_sd_pipeline.text_encoder_2.to('cpu') 301 | 302 | return null_embed, pooled_null_embed 303 | 304 | def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): 305 | from torch.nn.parallel import DistributedDataParallel 306 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 307 | 308 | if isinstance(self.unet, DistributedDataParallel): 309 | passed_add_embed_dim = ( 310 | self.unet.module.config.addition_time_embed_dim * len(add_time_ids) + self.D_sd_pipeline.text_encoder_2.config.projection_dim 311 | ) 312 | expected_add_embed_dim = self.unet.module.add_embedding.linear_1.in_features 313 | else: 314 | passed_add_embed_dim = ( 315 | self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.D_sd_pipeline.text_encoder_2.config.projection_dim 316 | ) 317 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features 318 | 319 | if expected_add_embed_dim != passed_add_embed_dim: 320 | raise ValueError( 321 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 322 | ) 323 | 324 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 325 | return add_time_ids -------------------------------------------------------------------------------- /training_utils/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | import transformers 5 | import diffusers 6 | 7 | def set_logger(args, accelerator, logger): 8 | # Make one log on every process with the configuration for debugging. 
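# Note: `logger` here is expected to be accelerate's MultiProcessAdapter (e.g. from
# accelerate.logging.get_logger), which is why main_process_only=False below makes every
# rank print its accelerator state and why the file handler is attached to logger.logger.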
9 | logging.basicConfig( 10 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 11 | datefmt="%m/%d/%Y %H:%M:%S", 12 | level=logging.INFO, 13 | ) 14 | logger.info(accelerator.state, main_process_only=False) 15 | if accelerator.is_local_main_process: 16 | transformers.utils.logging.set_verbosity_warning() 17 | diffusers.utils.logging.set_verbosity_info() 18 | else: 19 | transformers.utils.logging.set_verbosity_error() 20 | diffusers.utils.logging.set_verbosity_error() 21 | 22 | # Handle the repository creation 23 | if accelerator.is_main_process: 24 | if args.output_dir is not None: 25 | os.makedirs(args.output_dir, exist_ok=True) 26 | 27 | # Add log file 28 | formatter = logging.Formatter('%(asctime)s %(filename)s %(levelname)s %(message)s') 29 | fh = logging.FileHandler(os.path.join(args.output_dir, 'log.txt'), mode='a', encoding='utf8') 30 | fh.setFormatter(formatter) 31 | logger.logger.addHandler(fh) 32 | logger.info(args) -------------------------------------------------------------------------------- /training_utils/pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from TrainableSDPipeline import TrainableSDPipeline, TrainableSDXLPipeline 4 | from AttrConcenTrainableSDPipeline import AttrConcenTrainableSDPipeline 5 | from AttrConcenTrainableSDXLPipeline import AttrConcenTrainableSDXLPipeline 6 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, DPMSolverMultistepScheduler 7 | from diffusers.models.lora import LoRALinearLayer 8 | from diffusers.loaders import ( 9 | LoraLoaderMixin, 10 | text_encoder_lora_state_dict, 11 | ) 12 | from diffusers.models.attention_processor import ( 13 | AttnAddedKVProcessor, 14 | AttnAddedKVProcessor2_0, 15 | SlicedAttnAddedKVProcessor, 16 | ) 17 | 18 | 19 | def _load_diffusion_pipeline(model_path, model_name, revision, weight_dtype, args=None): 20 | if model_name == 'sd_1_5': 21 | pipeline = TrainableSDPipeline.from_pretrained(model_path, revision=revision, torch_dtype=weight_dtype) 22 | elif model_name == 'sd_1_5_attrcon': 23 | pipeline = AttrConcenTrainableSDPipeline.from_pretrained(model_path, revision=revision, torch_dtype=weight_dtype) 24 | elif 'sdxl' in model_name: 25 | vae_path = "madebyollin/sdxl-vae-fp16-fix" 26 | vae = AutoencoderKL.from_pretrained(vae_path, torch_dtype=torch.float16) 27 | if 'unet' in model_name: 28 | unet = UNet2DConditionModel.from_pretrained(args.sdxl_unet_path, revision=revision) 29 | else: 30 | unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet", revision=revision) 31 | if 'attrcon' in model_name: 32 | PIPELINE_NAME = AttrConcenTrainableSDXLPipeline 33 | else: 34 | PIPELINE_NAME = TrainableSDXLPipeline 35 | pipeline = PIPELINE_NAME.from_pretrained(model_path, revision=revision, vae=vae, unet=unet, torch_dtype=weight_dtype) 36 | 37 | else: 38 | raise NotImplementedError("This model is not supported yet") 39 | return pipeline 40 | 41 | 42 | def load_pipeline(args, model_name, weight_dtype, is_D=False): 43 | # Load pipeline 44 | pipeline = _load_diffusion_pipeline(args.pretrain_model, model_name, args.revision, weight_dtype, args) 45 | 46 | scheduler_args = {} 47 | 48 | if args.scheduler == "DPM++": 49 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 50 | elif args.scheduler == "DDPM": 51 | if "variance_type" in pipeline.scheduler.config: 52 | variance_type = pipeline.scheduler.config.variance_type 53 | 54 | if variance_type in
["learned", "learned_range"]: 55 | variance_type = "fixed_small" 56 | 57 | scheduler_args["variance_type"] = variance_type 58 | 59 | pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config, **scheduler_args) 60 | if args.full_finetuning: 61 | pipeline.unet.to(dtype=torch.float) 62 | else: 63 | pipeline.unet.to(dtype=weight_dtype) 64 | pipeline.vae.to(dtype=weight_dtype) 65 | pipeline.text_encoder.to(dtype=weight_dtype) 66 | # set grad 67 | # Freeze vae and text_encoder 68 | pipeline.vae.requires_grad_(args.tune_vae) 69 | pipeline.text_encoder.requires_grad_(args.tune_text_encoder) 70 | pipeline.unet.requires_grad_(False) 71 | 72 | # gradient checkpoint 73 | if args.gradient_checkpointing: 74 | pipeline.unet.enable_gradient_checkpointing() 75 | if args.tune_text_encoder or args.train_text_encoder_lora: 76 | pipeline.text_encoder.gradient_checkpointing_enable() 77 | pipeline.text_encoder_2.gradient_checkpointing_enable() if hasattr(pipeline, "text_encoder_2") else None 78 | 79 | # set trainable lora 80 | pipeline = set_pipeline_trainable_module(args, pipeline, is_D=is_D) 81 | 82 | return pipeline 83 | 84 | def set_pipeline_trainable_module(args, pipeline, is_D=False): 85 | if not args.full_finetuning: 86 | # Set correct lora layers 87 | for attn_processor_name, attn_processor in pipeline.unet.attn_processors.items(): 88 | # Parse the attention module. 89 | attn_module = pipeline.unet 90 | for n in attn_processor_name.split(".")[:-1]: 91 | attn_module = getattr(attn_module, n) 92 | 93 | # Set the `lora_layer` attribute of the attention-related matrices. 94 | attn_module.to_q.set_lora_layer( 95 | LoRALinearLayer( 96 | in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.lora_rank 97 | ) 98 | ) 99 | attn_module.to_k.set_lora_layer( 100 | LoRALinearLayer( 101 | in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.lora_rank 102 | ) 103 | ) 104 | attn_module.to_v.set_lora_layer( 105 | LoRALinearLayer( 106 | in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.lora_rank 107 | ) 108 | ) 109 | attn_module.to_out[0].set_lora_layer( 110 | LoRALinearLayer( 111 | in_features=attn_module.to_out[0].in_features, 112 | out_features=attn_module.to_out[0].out_features, 113 | rank=args.lora_rank, 114 | ) 115 | ) 116 | 117 | if args.train_text_encoder_lora and not is_D: 118 | # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16 119 | text_lora_parameters = LoraLoaderMixin._modify_text_encoder(pipeline.text_encoder, dtype=torch.float32, rank=args.lora_rank) 120 | 121 | return pipeline 122 | 123 | def get_trainable_parameters(args, pipeline, is_D=False): 124 | # load unet parameters 125 | if args.full_finetuning: 126 | G_parameters = list(pipeline.unet.parameters()) 127 | else: 128 | G_parameters = [] 129 | for attn_processor_name, attn_processor in pipeline.unet.attn_processors.items(): 130 | 131 | attn_module = pipeline.unet 132 | for n in attn_processor_name.split(".")[:-1]: 133 | attn_module = getattr(attn_module, n) 134 | 135 | attn_module.to_q.lora_layer.to(torch.float) 136 | attn_module.to_k.lora_layer.to(torch.float) 137 | attn_module.to_v.lora_layer.to(torch.float) 138 | attn_module.to_out[0].lora_layer.to(torch.float) 139 | 140 | # Accumulate the LoRA params to optimize. 
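# Each lora_layer holds a rank-decomposed pair of Linear projections ("down"/"up"), cast to
# float32 above so optimizer state stays in full precision even when the frozen UNet runs in
# fp16/bf16. Rough size illustration (hypothetical shapes): a rank-32 LoRA on a 320 -> 320
# to_q projection adds 320 * 32 + 32 * 320 = 20480 trainable parameters.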
141 | G_parameters.extend(attn_module.to_q.lora_layer.parameters()) 142 | G_parameters.extend(attn_module.to_k.lora_layer.parameters()) 143 | G_parameters.extend(attn_module.to_v.lora_layer.parameters()) 144 | G_parameters.extend(attn_module.to_out[0].lora_layer.parameters()) 145 | 146 | if isinstance(attn_processor, (AttnAddedKVProcessor, SlicedAttnAddedKVProcessor, AttnAddedKVProcessor2_0)): 147 | import pdb 148 | pdb.set_trace() # should not enter here, otherwise the resume process is useless 149 | attn_module.add_k_proj.set_lora_layer( 150 | LoRALinearLayer( 151 | in_features=attn_module.add_k_proj.in_features, 152 | out_features=attn_module.add_k_proj.out_features, 153 | rank=args.lora_rank, 154 | ) 155 | ) 156 | attn_module.add_v_proj.set_lora_layer( 157 | LoRALinearLayer( 158 | in_features=attn_module.add_v_proj.in_features, 159 | out_features=attn_module.add_v_proj.out_features, 160 | rank=args.lora_rank, 161 | ) 162 | ) 163 | G_parameters.extend(attn_module.add_k_proj.lora_layer.parameters()) 164 | G_parameters.extend(attn_module.add_v_proj.lora_layer.parameters()) 165 | 166 | 167 | # load other parameters, not for D 168 | if not is_D: 169 | text_lora_parameters = [] 170 | if args.tune_vae: 171 | G_parameters.extend(pipeline.vae.parameters()) 172 | 173 | if args.tune_text_encoder: 174 | G_parameters.extend(pipeline.text_encoder.parameters()) 175 | 176 | if args.train_text_encoder_lora: 177 | for n, p in pipeline.text_encoder.named_parameters(): 178 | if 'lora' in n: 179 | if not p.requires_grad: 180 | import pdb 181 | pdb.set_trace() 182 | if p.dtype != torch.float: 183 | p.data = p.data.to(torch.float) 184 | text_lora_parameters.append(p) 185 | 186 | return G_parameters, text_lora_parameters 187 | else: 188 | return G_parameters -------------------------------------------------------------------------------- /valid.txt: -------------------------------------------------------------------------------- 1 | a red dog and a pink cat 2 | a traffic light next to a bed 3 | The black apple was on right of the red plate. 4 | a wooden cup and a plastic fork 5 | A black truck has a red dog in the drivers chair. --------------------------------------------------------------------------------