├── README.md
├── assets
│   └── main.png
├── generation_with_nulltext_model.py
├── nulltext_attention.py
├── nulltext_unet.py
└── prepare_nulltext_weight.py

/README.md:
--------------------------------------------------------------------------------
## In-domain Generation with Diffusion Models
The official code of "Image is All You Need to Empower Large-scale Diffusion Models for In-Domain Generation".

![overview](assets/main.png)

### 1. Code for Null-text UNet

#### 1.1 Prepare null-text checkpoint

```bash
python prepare_nulltext_weight.py
```

This runs SD1.5 once on the empty prompt, caches each cross-attention layer's output on the null-text embedding, and saves the result as a null-text checkpoint.

#### 1.2 Construct null-text UNet

```bash
python nulltext_unet.py
```

This builds a UNet whose cross-attention layers replay the cached null-text features, so it no longer depends on a text prompt.

### 2. In-domain Generation with Multi-Guidance

See `generation_with_nulltext_model.py` for details.

The core code of GCFG is as follows:

```python
noise_pred_text = self.unet(
    latent_model_input[1:],
    t,
    encoder_hidden_states=prompt_embeds[2:3],
    cross_attention_kwargs=cross_attention_kwargs,
    return_dict=False,
)[0]

noise_pred_text_ori = self.unet1(
    latent_model_input[1:],
    t,
    encoder_hidden_states=prompt_embeds[3:4],
    cross_attention_kwargs=cross_attention_kwargs,
    return_dict=False,
)[0]

noise_pred_uncond = self.unet0(
    latent_model_input[:1],
    t,
    encoder_hidden_states=prompt_embeds[:1],
    cross_attention_kwargs=cross_attention_kwargs,
    return_dict=False,
)[0]

# perform guidance
if do_classifier_free_guidance:
    noise_pred = noise_pred_uncond + \
        guidance_scale * (noise_pred_text - noise_pred_uncond) + \
        guidance_scale_ori * (noise_pred_text_ori - noise_pred_uncond)
```

where `self.unet0` is SD1.5 for unconditional guidance, `self.unet` is the in-domain diffusion model for domain guidance, and `self.unet1` is SD1.5 or a customized SD for control guidance.
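
In symbols, the combined prediction is (notation ours; $w$ is `guidance_scale` and $w_{\text{ori}}$ is `guidance_scale_ori`):

$$\hat{\epsilon} = \epsilon_{\text{uncond}} + w\,(\epsilon_{\text{domain}} - \epsilon_{\text{uncond}}) + w_{\text{ori}}\,(\epsilon_{\text{control}} - \epsilon_{\text{uncond}})$$

A minimal usage sketch, assuming an in-domain null-text checkpoint (here an FFHQ example path) has already been prepared (see the `__main__` of `generation_with_nulltext_model.py` for the full script):

```python
import torch
from diffusers import DPMSolverMultistepScheduler
from generation_with_nulltext_model import StableDiffusionPipeline_
from nulltext_unet import UNet2DConditionModel_Nulltext

model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline_.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
pipe.unet0 = pipe.unet.to(device="cuda", dtype=torch.float16)  # unconditional branch: plain SD1.5
pipe.unet1 = pipe.unet.to(device="cuda", dtype=torch.float16)  # control branch: SD1.5 here, or a customized SD
pipe.unet = UNet2DConditionModel_Nulltext.from_pretrained(
    pretrained_model_name_or_path=model_id,
    subfolder="unet",
    null_text_weight=torch.load("ffhq/unet/diffusion_pytorch_model.bin"),  # example in-domain weight
).to(device="cuda", dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

image = pipe.gcfg_call(
    ["a photo of ", "wearing glasses"],  # [domain prompt (ignored by the null-text UNet), control prompt]
    guidance_scale=2.5,      # domain guidance weight
    guidance_scale_ori=2.5,  # control guidance weight
    num_inference_steps=25,
).images[0]
image.save("wearing_glasses.png")
```
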
### TODO
- [ ] Update training code in [UniDiffusion](https://github.com/PRIV-Creation/UniDiffusion).
- [ ] Results on SDXL and SD3.

--------------------------------------------------------------------------------
/assets/main.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PRIV-Creation/In-domain-Generation-Diffusion/b0c6e9d55167fc764742976cc9dcd87ab1c56a73/assets/main.png

--------------------------------------------------------------------------------
/generation_with_nulltext_model.py:
--------------------------------------------------------------------------------
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import *
import torch
from diffusers import DPMSolverMultistepScheduler
from nulltext_unet import UNet2DConditionModel_Nulltext


class StableDiffusionPipeline_(StableDiffusionPipeline):
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def gcfg_call(
        self,
        prompt: Union[str, List[str]] = None,
        guidance_scale_ori: float = 1.0,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass
                `prompt_embeds` instead.
            guidance_scale_ori (`float`, *optional*, defaults to 1.0):
                Guidance scale for the control branch (`self.unet1`) in GCFG, applied analogously to
                `guidance_scale`.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. A higher guidance scale encourages generating images that are closely linked to the text
                `prompt`, usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if
                `guidance_scale` is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a
                latents tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from the `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, `negative_prompt_embeds` will be generated from the `negative_prompt`
                input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            guidance_rescale (`float`, *optional*, defaults to 0.0):
                Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16.
                of the same paper. Guidance rescale should fix overexposure when using zero terminal SNR.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
                [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a
                `tuple`. When returning a tuple, the first element is a list with the generated images, and the
                second element is a list of `bool`s denoting whether the corresponding generated image likely
                represents "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # GCFG always combines an unconditional branch with two conditional branches,
        # so classifier-free guidance is enabled unconditionally here.
        do_classifier_free_guidance = True

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )
        # two prompts are passed for the guidance branches, but only one image is generated
        latents = latents[:1]

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
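
        # Embedding layout for GCFG (inferred from the slicing in the denoising loop below):
        # with two prompts and classifier-free guidance, `_encode_prompt` stacks four embeddings:
        #   prompt_embeds[0:1] -> null/negative embedding for prompt 0 -> unconditional branch (self.unet0)
        #   prompt_embeds[1:2] -> null/negative embedding for prompt 1 -> unused
        #   prompt_embeds[2:3] -> embedding of prompt[0]               -> domain branch (self.unet); the null-text
        #                         UNet ignores it and replays cached cross-attention features
        #   prompt_embeds[3:4] -> embedding of prompt[1]               -> control branch (self.unet1)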
        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred_text = self.unet(
                    latent_model_input[1:],
                    t,
                    encoder_hidden_states=prompt_embeds[2:3],
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                noise_pred_text_ori = self.unet1(
                    latent_model_input[1:],
                    t,
                    encoder_hidden_states=prompt_embeds[3:4],
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                noise_pred_uncond = self.unet0(
                    latent_model_input[:1],
                    t,
                    encoder_hidden_states=prompt_embeds[:1],
                    cross_attention_kwargs=cross_attention_kwargs,
                    return_dict=False,
                )[0]

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred = noise_pred_uncond + \
                        guidance_scale * (noise_pred_text - noise_pred_uncond) + \
                        guidance_scale_ori * (noise_pred_text_ori - noise_pred_uncond)

                # guidance rescale is intentionally disabled for GCFG; kept for reference
                if False and do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if not output_type == "latent":
            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
        else:
            image = latents
            has_nsfw_concept = None

        if has_nsfw_concept is None:
            do_denormalize = [True] * image.shape[0]
        else:
            do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

        image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)


if __name__ == '__main__':
    model_id = "runwayml/stable-diffusion-v1-5"
    weight = torch.load("ffhq/unet/diffusion_pytorch_model.bin")  # concept-centric (in-domain) diffusion model weight

    pipe = StableDiffusionPipeline_.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
    # unet0 (unconditional) and unet1 (control) share the plain SD1.5 UNet here;
    # load a customized SD UNet into `pipe.unet1` to use a different control model.
    pipe.unet0 = pipe.unet.to(device="cuda", dtype=torch.float16)
    pipe.unet1 = pipe.unet.to(device="cuda", dtype=torch.float16)

    pipe.unet = UNet2DConditionModel_Nulltext.from_pretrained(
        pretrained_model_name_or_path=model_id,
        subfolder="unet",
        null_text_weight=weight,
    ).to(device="cuda", dtype=torch.float16)

    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.to("cuda")

    prompt_ = [
        "a photo of ",      # domain prompt; not used in null-text mode
        "wearing glasses",  # control prompt
    ]

    generator = torch.Generator(device="cuda").manual_seed(0)
    image1 = pipe.gcfg_call(
        prompt_,
        guidance_scale_ori=2.5,
        generator=generator,
        guidance_scale=2.5,
        num_inference_steps=25,
    ).images[0]

    image1.save(f"{prompt_[1]}.png")

--------------------------------------------------------------------------------
/nulltext_attention.py:
--------------------------------------------------------------------------------
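# How NullTextAttention works, summarized from the code below. It wraps a
# cross-attention (attn2) module and runs in two phases:
#   1. initial=False: the first forward computes ordinary cross-attention on the
#      null-text embedding, caches the mean output in a `null_text_feature`
#      buffer, deletes the q/k/v/out projections, and switches to phase 2.
#   2. initial=True: every subsequent forward replays the cached feature,
#      broadcast to the incoming batch and sequence length, so the text prompt
#      no longer influences the output.
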
import torch
import torch.nn.functional as F


class NullTextAttention(torch.nn.Module):
    def __init__(self, attention, initial=False):
        super().__init__()
        self.attention = attention
        self.initial = initial
        self.apply_to()

    def apply_to(self):
        self.attention.forward = self.forward
        self.attention.set_use_memory_efficient_attention_xformers = lambda *args, **kwargs: None

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        if self.initial:
            return self.attention.null_text_feature.repeat([hidden_states.shape[0], hidden_states.shape[1], 1])
        else:
            attn = self.attention
            residual = hidden_states

            if attn.spatial_norm is not None:
                hidden_states = attn.spatial_norm(hidden_states, None)

            input_ndim = hidden_states.ndim

            if input_ndim == 4:
                batch_size, channel, height, width = hidden_states.shape
                hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

            batch_size, sequence_length, _ = (
                hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
            )
            inner_dim = hidden_states.shape[-1]

            if attention_mask is not None:
                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
                # scaled_dot_product_attention expects attention_mask shape to be
                # (batch, heads, source_length, target_length)
                attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

            if attn.group_norm is not None:
                hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            query = attn.to_q(hidden_states)

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            elif attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

            key = attn.to_k(encoder_hidden_states)
            value = attn.to_v(encoder_hidden_states)

            head_dim = inner_dim // attn.heads
            query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

            # the output of sdp = (batch, num_heads, seq_len, head_dim)
            # TODO: add support for attn.scale when we move to Torch 2.1
            hidden_states = F.scaled_dot_product_attention(
                query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
            )

            hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
            hidden_states = hidden_states.to(query.dtype)

            # linear proj
            hidden_states = attn.to_out[0](hidden_states)
            # dropout
            hidden_states = attn.to_out[1](hidden_states)

            if input_ndim == 4:
                hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

            if attn.residual_connection:
                hidden_states = hidden_states + residual

            hidden_states = hidden_states / attn.rescale_output_factor

            attn.register_buffer('null_text_feature', hidden_states.mean(dim=[0, 1], keepdim=True))

            del attn.to_out
            del attn.to_q
            del attn.to_k
            del attn.to_v

            self.initial = True
            return hidden_states

--------------------------------------------------------------------------------
/nulltext_unet.py:
--------------------------------------------------------------------------------
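# Builds a null-text UNet, summarized from the code below: load a standard SD UNet,
# strip the q/k/v/out projections from every attn2 module, register a per-layer
# `null_text_feature` buffer, replace the now-unused norm2 with Identity, and
# strictly load a checkpoint produced by prepare_nulltext_weight.py.
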
from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock
from nulltext_attention import NullTextAttention
import torch


class UNet2DConditionModel_Nulltext(UNet2DConditionModel):
    @classmethod
    def from_pretrained(cls, null_text_weight, *args, **kwargs):
        # accept either a state dict or a path to a checkpoint saved by prepare_nulltext_weight.py
        if isinstance(null_text_weight, str):
            null_text_weight = torch.load(null_text_weight, map_location="cpu")
        unet = super().from_pretrained(*args, **kwargs)
        for name, module in unet.named_modules():
            if name.endswith("attn2"):
                _ = NullTextAttention(module, initial=True)
                channel = module.to_q.in_features
                del module.to_q
                del module.to_k
                del module.to_v
                del module.to_out
                module.register_buffer("null_text_feature", torch.zeros([1, 1, channel]))
            if isinstance(module, BasicTransformerBlock):
                # attn2 now returns a cached constant, so its input norm is unused
                module.norm2 = torch.nn.Identity()
        unet.load_state_dict(null_text_weight, strict=True)
        return unet


if __name__ == '__main__':
    unet = UNet2DConditionModel_Nulltext.from_pretrained(
        pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
        subfolder="unet",
        null_text_weight="stable-diffusion-v1-5_null-text.pt",
    )

--------------------------------------------------------------------------------
/prepare_nulltext_weight.py:
--------------------------------------------------------------------------------
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import *
import torch
from diffusers.models.attention import BasicTransformerBlock
from nulltext_attention import NullTextAttention


model_id = "runwayml/stable-diffusion-v1-5"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
pipe.to("cuda")

for name, module in pipe.unet.named_modules():
    if name.endswith("attn2"):
        _ = NullTextAttention(module)

# A short run with the empty prompt triggers the caching pass in NullTextAttention:
# each attn2 computes its output on the null-text embedding, stores it as a
# `null_text_feature` buffer, and drops its q/k/v/out projections. Without this
# forward pass the saved state dict would contain neither change.
pipe("", num_inference_steps=1)

# Mirror UNet2DConditionModel_Nulltext, which replaces norm2 with Identity before
# strictly loading the checkpoint, so the now-unused norm2 parameters are dropped here too.
for module in pipe.unet.modules():
    if isinstance(module, BasicTransformerBlock):
        module.norm2 = torch.nn.Identity()

torch.save(pipe.unet.state_dict(), "stable-diffusion-v1-5_null-text.pt")
--------------------------------------------------------------------------------