├── .gitignore ├── LICENSE ├── README.md ├── assets ├── architecture.jpg └── banner.jpg ├── flux_nag_demo.ipynb ├── nag ├── __init__.py ├── attention_flux_nag.py ├── attention_joint_nag.py ├── attention_nag.py ├── attention_wan_nag.py ├── normalization.py ├── pipeline_flux_kontext_nag.py ├── pipeline_flux_nag.py ├── pipeline_sd3_nag.py ├── pipeline_sdxl_nag.py ├── pipeline_wan_nag.py ├── transformer_flux.py └── transformer_wan_nag.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | .idea/ 169 | 170 | # Abstra 171 | # Abstra is an AI-powered process automation framework. 172 | # Ignore directories containing user credentials, local state, and settings. 173 | # Learn more at https://abstra.io/docs 174 | .abstra/ 175 | 176 | # Visual Studio Code 177 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore 178 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore 179 | # and can be added to the global gitignore or merged into this file. However, if you prefer, 180 | # you could uncomment the following to ignore the enitre vscode folder 181 | # .vscode/ 182 | 183 | # Ruff stuff: 184 | .ruff_cache/ 185 | 186 | # PyPI configuration file 187 | .pypirc 188 | 189 | # Cursor 190 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to 191 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data 192 | # refer to https://docs.cursor.com/context/ignore-files 193 | .cursorignore 194 | .cursorindexingignore 195 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Dar-Yen Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Normalized Attention Guidance: Universal Negative Guidance for Diffusion Models 2 | 3 | [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-Kontext-Dev) 4 | [![Project Page](https://img.shields.io/badge/Project-Page-green.svg)](https://chendaryen.github.io/NAG.github.io/) 5 | [![arXiv](https://img.shields.io/badge/arXiv-2505.21179-b31b1b.svg)](https://arxiv.org/abs/2505.21179) 6 | [![Page Views Count](https://badges.toozhao.com/badges/01JWNDV5JQ2XT69RCZ5KQBCY0E/blue.svg)](https://badges.toozhao.com/stats/01JWNDV5JQ2XT69RCZ5KQBCY0E "Get your own page views count badge on badges.toozhao.com") 7 | 8 | 9 | ![](./assets/banner.jpg) 10 | Negative prompting on 4-step Flux-Schnell: 11 | CFG fails in few-step models. NAG restores effective negative prompting, enabling direct suppression of visual, semantic, and stylistic attributes, such as ``glasses``, ``tiger``, ``realistic``, or ``blurry``. This enhances controllability and expands creative freedom across composition, style, and quality—including prompt-based debiasing. 12 | 13 | 14 | ## News 15 | 16 | 17 | **2025-06-30:** 🤗 Code and [demo](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-Kontext-Dev) for `Flux Kontext` is now available! 18 | 19 | **2025-06-28:** 🎉 Our [ComfyUI implementation](https://github.com/ChenDarYen/ComfyUI-NAG) now supports `Flux Kontext`, `Wan2.1`, and `Hunyuan Video`! 20 | 21 | **2025-06-24:** 🎉 A [ComfyUI node](https://github.com/kijai/ComfyUI-KJNodes/blob/f7eb33abc80a2aded1b46dff0dd14d07856a7d50/nodes/model_optimization_nodes.py#L1568) for Wan is now available! Big thanks to [Kijai](https://github.com/kijai)! 22 | 23 | **2025-06-24:** 🤗 Demo for [LTX Video Fast](https://huggingface.co/spaces/ChenDY/NAG_ltx-video-distilled) is now available! 24 | 25 | **2025-06-22:** 🚀 SD3.5 pipeline is released! 26 | 27 | **2025-06-22:** 🎉 Play with the [ComfyUI implementation](https://github.com/ChenDarYen/ComfyUI-NAG) now! 28 | 29 | **2025-06-19:** 🚀 Wan2.1 and the SDXL pipeline are released! 30 | 31 | **2025-06-09:** 🤗 Demo for [4-step Wan2.1 with CausVid](https://huggingface.co/spaces/ChenDY/NAG_wan2-1-fast) video generation is now available! 32 | 33 | **2025-06-01:** 🤗 Demo for [Flux-Schnell](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-schnell) and [Flux-Dev](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-dev) are now available! 34 | 35 | 36 | ## Approach 37 | 38 | The prevailing approach to diffusion model control, Classifier-Free Guidance (CFG), enables negative guidance by extrapolating between positive and negative conditional outputs at each denoising step. However, in few-step regimes, CFG's assumption of consistent structure between diffusion branches breaks down, as these branches diverge dramatically at early steps. This divergence causes severe artifacts rather than controlled guidance. 
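For reference, the per-step CFG extrapolation described above can be written schematically as follows (an illustrative sketch, not part of this repository's code; `eps_pos` / `eps_neg` stand for the positive- and negative-conditioned model outputs):

```python
import torch

def cfg_extrapolate(eps_pos: torch.Tensor, eps_neg: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # Classifier-Free Guidance: extrapolate from the negative branch toward the positive one.
    # The extrapolation is unbounded, so when the two branches diverge (as they do early in
    # few-step sampling) the guided output can be pushed far off the data manifold.
    return eps_neg + cfg_scale * (eps_pos - eps_neg)
```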
39 | 
40 | Normalized Attention Guidance (NAG) operates in attention space by extrapolating positive and negative features Z+ and Z-, followed by L1-based normalization and α-blending. This constrains feature deviation, suppresses out-of-manifold drift, and achieves stable, controllable guidance. 
41 | 
42 | ![](./assets/architecture.jpg) 
43 | 
44 | ## Installation 
45 | 
46 | Install directly from GitHub: 
47 | 
48 | ```bash 
49 | pip install git+https://github.com/ChenDarYen/Normalized-Attention-Guidance.git 
50 | ``` 
51 | 
52 | ## Usage 
53 | 
54 | ### Flux 
55 | You can try NAG in `flux_nag_demo.ipynb`, or the 🤗 Hugging Face demos for [Flux-Schnell](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-schnell) and [Flux-Dev](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-dev)! 
56 | 
57 | Loading the custom pipeline: 
58 | 
59 | ```python 
60 | import torch 
61 | from nag import NAGFluxPipeline 
62 | from nag import NAGFluxTransformer2DModel 
63 | 
64 | transformer = NAGFluxTransformer2DModel.from_pretrained( 
65 |     "black-forest-labs/FLUX.1-schnell", 
66 |     subfolder="transformer", 
67 |     torch_dtype=torch.bfloat16, 
68 |     token="hf_token", 
69 | ) 
70 | pipe = NAGFluxPipeline.from_pretrained( 
71 |     "black-forest-labs/FLUX.1-schnell", 
72 |     transformer=transformer, 
73 |     torch_dtype=torch.bfloat16, 
74 |     token="hf_token", 
75 | ) 
76 | pipe.to("cuda") 
77 | ``` 
78 | 
79 | Sampling with NAG: 
80 | 
81 | ```python 
82 | prompt = "Portrait of AI researcher." 
83 | nag_negative_prompt = "Glasses." 
84 | # prompt = "A baby phoenix made of fire and flames is born from the smoking ashes." 
85 | # nag_negative_prompt = "Low resolution, blurry, lack of details, illustration, cartoon, painting." 
86 | 
87 | image = pipe( 
88 |     prompt, 
89 |     nag_negative_prompt=nag_negative_prompt, 
90 |     guidance_scale=0.0, 
91 |     nag_scale=5.0, 
92 |     num_inference_steps=4, 
93 |     max_sequence_length=256, 
94 | ).images[0] 
95 | ``` 
96 | 
97 | ### Flux Kontext 
98 | 
99 | ```python 
100 | import torch 
101 | from diffusers.utils import load_image 
102 | from nag import NAGFluxKontextPipeline 
103 | from nag import NAGFluxTransformer2DModel 
104 | 
105 | transformer = NAGFluxTransformer2DModel.from_pretrained( 
106 |     "black-forest-labs/FLUX.1-Kontext-dev", 
107 |     subfolder="transformer", 
108 |     torch_dtype=torch.bfloat16, 
109 |     token="hf_token", 
110 | ) 
111 | pipe = NAGFluxKontextPipeline.from_pretrained( 
112 |     "black-forest-labs/FLUX.1-Kontext-dev", 
113 |     transformer=transformer, 
114 |     torch_dtype=torch.bfloat16, 
115 |     token="hf_token", 
116 | ) 
117 | pipe.to("cuda") 
118 | 
119 | input_image = load_image( 
120 |     "https://raw.githubusercontent.com/Comfy-Org/example_workflows/main/flux/kontext/dev/rabbit.jpg") 
121 | prompt = "Using this elegant style, create a portrait of a cute Godzilla wearing a pearl tiara and lace collar, maintaining the same refined quality and soft color tones." 
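# `nag_negative_prompt` below is consumed by the NAG attention processors and is separate from
# the pipeline's CFG-style `negative_prompt` argument; "hf_token" above is a placeholder for
# your own Hugging Face access token.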
122 | nag_negative_prompt = "Low resolution, blurry, lack of details" 123 | 124 | image = pipe( 125 | prompt=prompt, 126 | image=input_image, 127 | nag_negative_prompt=nag_negative_prompt, 128 | guidance_scale=2.5, 129 | nag_scale=5.0, 130 | num_inference_steps=25, 131 | width=input_image.size[0], 132 | height=input_image.size[1], 133 | ).images[0] 134 | ``` 135 | 136 | ### Wan2.1 137 | 138 | ```python 139 | import torch 140 | from diffusers import AutoencoderKLWan, UniPCMultistepScheduler 141 | from nag import NagWanTransformer3DModel 142 | from nag import NAGWanPipeline 143 | 144 | model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" 145 | vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) 146 | transformer = NagWanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16) 147 | pipe = NAGWanPipeline.from_pretrained( 148 | model_id, 149 | vae=vae, 150 | transformer=transformer, 151 | torch_dtype=torch.bfloat16, 152 | ) 153 | pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0) 154 | pipe.to("cuda") 155 | 156 | prompt = "An origami fox running in the forest. The fox is made of polygons. speed and passion. realistic." 157 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 158 | nag_negative_prompt = "static, low resolution, blurry" 159 | 160 | output = pipe( 161 | prompt=prompt, 162 | negative_prompt=negative_prompt, 163 | nag_negative_prompt=nag_negative_prompt, 164 | guidance_scale=5.0, 165 | nag_scale=9, 166 | height=480, 167 | width=832, 168 | num_inference_steps=25, 169 | num_frames=81, 170 | ).frames[0] 171 | ``` 172 | 173 | For 4-step inference with CausVid, please refer to the [demo](https://huggingface.co/spaces/ChenDY/NAG_wan2-1-fast/blob/main/app.py). 
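The Wan pipeline returns raw frames rather than a video file. A minimal way to save them, assuming the standard `export_to_video` helper from `diffusers.utils` (the file name and frame rate here are only suggestions):

```python
from diffusers.utils import export_to_video

# Write the generated frames to disk; Wan2.1 clips are commonly exported at 16 fps.
export_to_video(output, "wan_nag_output.mp4", fps=16)
```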
174 | 175 | ### SD3.5 176 | 177 | ```python 178 | import torch 179 | from nag import NAGStableDiffusion3Pipeline 180 | 181 | model_id = "stabilityai/stable-diffusion-3.5-large-turbo" 182 | pipe = NAGStableDiffusion3Pipeline.from_pretrained( 183 | model_id, 184 | torch_dtype=torch.bfloat16, 185 | token="hf_token", 186 | ) 187 | pipe.to("cuda") 188 | 189 | prompt = "A beautiful cyborg" 190 | nag_negative_prompt = "robot" 191 | 192 | image = pipe( 193 | prompt, 194 | nag_negative_prompt=nag_negative_prompt, 195 | guidance_scale=0., 196 | nag_scale=5, 197 | num_inference_steps=8, 198 | ).images[0] 199 | ``` 200 | 201 | ### SDXL 202 | 203 | ```python 204 | import torch 205 | from diffusers import UNet2DConditionModel, LCMScheduler 206 | from huggingface_hub import hf_hub_download 207 | from nag import NAGStableDiffusionXLPipeline 208 | 209 | base_model_id = "stabilityai/stable-diffusion-xl-base-1.0" 210 | repo_name = "tianweiy/DMD2" 211 | ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin" 212 | 213 | unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.bfloat16) 214 | unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda")) 215 | pipe = NAGStableDiffusionXLPipeline.from_pretrained( 216 | base_model_id, 217 | unet=unet, 218 | torch_dtype=torch.bfloat16, 219 | variant="fp16", 220 | ).to("cuda") 221 | pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, original_inference_steps=4) 222 | 223 | prompt = "A beautiful cyborg" 224 | nag_negative_prompt = "robot" 225 | 226 | image = pipe( 227 | prompt, 228 | nag_negative_prompt=nag_negative_prompt, 229 | guidance_scale=0, 230 | nag_scale=3, 231 | num_inference_steps=4, 232 | ).images[0] 233 | ``` 234 | 235 | ## Citation 236 | 237 | If you find NAG is useful or relevant to your research, please kindly cite our work: 238 | 239 | ```bib 240 | @article{chen2025normalizedattentionguidanceuniversal, 241 | title={Normalized Attention Guidance: Universal Negative Guidance for Diffusion Model}, 242 | author={Dar-Yen Chen and Hmrishav Bandyopadhyay and Kai Zou and Yi-Zhe Song}, 243 | journal={arXiv preprint arxiv:2505.21179}, 244 | year={2025} 245 | } 246 | ``` 247 | 248 | -------------------------------------------------------------------------------- /assets/architecture.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenDarYen/Normalized-Attention-Guidance/8bb34a16e517692743fab99bc60838c20f975ded/assets/architecture.jpg -------------------------------------------------------------------------------- /assets/banner.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChenDarYen/Normalized-Attention-Guidance/8bb34a16e517692743fab99bc60838c20f975ded/assets/banner.jpg -------------------------------------------------------------------------------- /nag/__init__.py: -------------------------------------------------------------------------------- 1 | from .pipeline_flux_nag import NAGFluxPipeline 2 | from .pipeline_flux_kontext_nag import NAGFluxKontextPipeline 3 | from .pipeline_wan_nag import NAGWanPipeline 4 | from .pipeline_sd3_nag import NAGStableDiffusion3Pipeline 5 | from .pipeline_sdxl_nag import NAGStableDiffusionXLPipeline 6 | from .transformer_flux import NAGFluxTransformer2DModel 7 | from .transformer_wan_nag import NagWanTransformer3DModel 8 | -------------------------------------------------------------------------------- 
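The attention processors in the `nag/attention_*.py` files below all implement the same core update on attention outputs: extrapolate the positive and negative features, cap the norm ratio at `nag_tau`, then blend the result back with the positive branch via `nag_alpha`. A condensed sketch of that shared step (the individual files differ slightly, e.g. p=1 vs p=2 norms and epsilon handling):

```python
import torch

def nag_feature_update(z_pos: torch.Tensor, z_neg: torch.Tensor,
                       nag_scale: float = 5.0, nag_tau: float = 2.5,
                       nag_alpha: float = 0.25) -> torch.Tensor:
    # Extrapolate the positive (Z+) and negative (Z-) attention features.
    z_guidance = z_pos * nag_scale - z_neg * (nag_scale - 1)
    # L1-based normalization: limit how far the guided features may drift from Z+.
    norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True)
    norm_guidance = torch.norm(z_guidance, p=1, dim=-1, keepdim=True)
    scale = norm_guidance / (norm_pos + 1e-7)
    z_guidance = z_guidance * torch.minimum(scale, scale.new_ones(1) * nag_tau) / (scale + 1e-7)
    # Alpha-blend the normalized guidance back with the positive features.
    return z_guidance * nag_alpha + z_pos * (1 - nag_alpha)
```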
/nag/attention_flux_nag.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from diffusers.models.attention_processor import Attention 8 | from diffusers.models.embeddings import apply_rotary_emb 9 | 10 | 11 | class NAGFluxAttnProcessor2_0: 12 | """Attention processor used typically in processing the SD3-like self-attention projections.""" 13 | 14 | def __init__( 15 | self, 16 | nag_scale: float = 1.0, 17 | nag_tau=2.5, 18 | nag_alpha=0.25, 19 | encoder_hidden_states_length: int = None, 20 | ): 21 | if not hasattr(F, "scaled_dot_product_attention"): 22 | raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 23 | self.nag_scale = nag_scale 24 | self.nag_tau = nag_tau 25 | self.nag_alpha = nag_alpha 26 | self.encoder_hidden_states_length = encoder_hidden_states_length 27 | 28 | def __call__( 29 | self, 30 | attn: Attention, 31 | hidden_states: torch.FloatTensor, 32 | encoder_hidden_states: torch.FloatTensor = None, 33 | attention_mask: Optional[torch.FloatTensor] = None, 34 | image_rotary_emb: Optional[torch.Tensor] = None, 35 | ) -> torch.FloatTensor: 36 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 37 | 38 | if self.nag_scale > 1.: 39 | if encoder_hidden_states is not None: 40 | assert len(hidden_states) == batch_size * 0.5 41 | apply_guidance = True 42 | else: 43 | apply_guidance = False 44 | 45 | # `sample` projections. 46 | query = attn.to_q(hidden_states) 47 | key = attn.to_k(hidden_states) 48 | value = attn.to_v(hidden_states) 49 | 50 | # attention 51 | if apply_guidance and encoder_hidden_states is not None: 52 | query = query.tile(2, 1, 1) 53 | key = key.tile(2, 1, 1) 54 | value = value.tile(2, 1, 1) 55 | 56 | inner_dim = key.shape[-1] 57 | head_dim = inner_dim // attn.heads 58 | 59 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 60 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 61 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 62 | 63 | if attn.norm_q is not None: 64 | query = attn.norm_q(query) 65 | if attn.norm_k is not None: 66 | key = attn.norm_k(key) 67 | 68 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` 69 | if encoder_hidden_states is not None: 70 | # `context` projections. 
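# When NAG is active, `encoder_hidden_states` carries the positive and NAG-negative prompt
# embeddings concatenated along the batch dimension, so the projections below produce both
# text streams in a single pass (the image-token Q/K/V were tiled above to match).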
71 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 72 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 73 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 74 | 75 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 76 | batch_size, -1, attn.heads, head_dim 77 | ).transpose(1, 2) 78 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 79 | batch_size, -1, attn.heads, head_dim 80 | ).transpose(1, 2) 81 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 82 | batch_size, -1, attn.heads, head_dim 83 | ).transpose(1, 2) 84 | 85 | if attn.norm_added_q is not None: 86 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) 87 | if attn.norm_added_k is not None: 88 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) 89 | 90 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) 91 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) 92 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) 93 | 94 | encoder_hidden_states_length = encoder_hidden_states.shape[1] 95 | 96 | else: 97 | assert self.encoder_hidden_states_length is not None 98 | encoder_hidden_states_length = self.encoder_hidden_states_length 99 | 100 | if image_rotary_emb is not None: 101 | query = apply_rotary_emb(query, image_rotary_emb) 102 | key = apply_rotary_emb(key, image_rotary_emb) 103 | 104 | if not apply_guidance: 105 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) 106 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 107 | hidden_states = hidden_states.to(query.dtype) 108 | 109 | else: 110 | origin_batch_size = batch_size // 2 111 | query, query_negative = torch.chunk(query, 2, dim=0) 112 | key, key_negative = torch.chunk(key, 2, dim=0) 113 | value, value_negative = torch.chunk(value, 2, dim=0) 114 | 115 | hidden_states_negative = F.scaled_dot_product_attention(query_negative, key_negative, value_negative, dropout_p=0.0, is_causal=False) 116 | hidden_states_negative = hidden_states_negative.transpose(1, 2).reshape(origin_batch_size, -1, attn.heads * head_dim) 117 | hidden_states_negative = hidden_states_negative.to(query.dtype) 118 | 119 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) 120 | hidden_states = hidden_states.transpose(1, 2).reshape(origin_batch_size, -1, attn.heads * head_dim) 121 | hidden_states = hidden_states.to(query.dtype) 122 | 123 | if encoder_hidden_states is not None: 124 | encoder_hidden_states, hidden_states = ( 125 | hidden_states[:, : encoder_hidden_states.shape[1]], 126 | hidden_states[:, encoder_hidden_states.shape[1] :], 127 | ) 128 | 129 | if apply_guidance: 130 | encoder_hidden_states_negative, hidden_states_negative = ( 131 | hidden_states_negative[:, : encoder_hidden_states.shape[1]], 132 | hidden_states_negative[:, encoder_hidden_states.shape[1]:], 133 | ) 134 | hidden_states_positive = hidden_states 135 | hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1) 136 | norm_positive = torch.norm(hidden_states_positive, p=2, dim=-1, keepdim=True).expand(*hidden_states_positive.shape) 137 | norm_guidance = torch.norm(hidden_states_guidance, p=2, dim=-1, keepdim=True).expand(*hidden_states_positive.shape) 138 | 139 | scale = norm_guidance / 
norm_positive 140 | hidden_states_guidance = hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / scale 141 | 142 | hidden_states = hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha) 143 | 144 | encoder_hidden_states = torch.cat((encoder_hidden_states, encoder_hidden_states_negative), dim=0) 145 | 146 | # linear proj 147 | hidden_states = attn.to_out[0](hidden_states) 148 | # dropout 149 | hidden_states = attn.to_out[1](hidden_states) 150 | 151 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 152 | 153 | return hidden_states, encoder_hidden_states 154 | 155 | else: 156 | if apply_guidance: 157 | image_hidden_states_negative = hidden_states_negative[:, encoder_hidden_states_length:] 158 | image_hidden_states = hidden_states[:, encoder_hidden_states_length:] 159 | 160 | image_hidden_states_positive = image_hidden_states 161 | image_hidden_states_guidance = image_hidden_states_positive * self.nag_scale - image_hidden_states_negative * (self.nag_scale - 1) 162 | norm_positive = torch.norm(image_hidden_states_positive, p=2, dim=-1, keepdim=True).expand(*image_hidden_states_positive.shape) 163 | norm_guidance = torch.norm(image_hidden_states_guidance, p=2, dim=-1, keepdim=True).expand(*image_hidden_states_positive.shape) 164 | 165 | scale = norm_guidance / norm_positive 166 | image_hidden_states_guidance = image_hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / scale 167 | # scale = torch.nan_to_num(scale, 10) 168 | # image_hidden_states_guidance[scale > self.nag_tau] = image_hidden_states_guidance[scale > self.nag_tau] / (norm_guidance[scale > self.nag_tau] + 1e-7) * norm_positive[scale > self.nag_tau] * self.nag_tau 169 | 170 | image_hidden_states = image_hidden_states_guidance * self.nag_alpha + image_hidden_states_positive * (1 - self.nag_alpha) 171 | 172 | hidden_states_negative[:, encoder_hidden_states_length:] = image_hidden_states 173 | hidden_states[:, encoder_hidden_states_length:] = image_hidden_states 174 | hidden_states = torch.cat((hidden_states, hidden_states_negative), dim=0) 175 | 176 | return hidden_states 177 | -------------------------------------------------------------------------------- /nag/attention_joint_nag.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from diffusers.models.attention_processor import Attention 7 | 8 | 9 | class NAGJointAttnProcessor2_0: 10 | def __init__(self, nag_scale: float = 1.0, nag_tau: float = 2.5, nag_alpha:float = 0.125): 11 | if not hasattr(F, "scaled_dot_product_attention"): 12 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 13 | self.nag_scale = nag_scale 14 | self.nag_tau = nag_tau 15 | self.nag_alpha = nag_alpha 16 | 17 | def __call__( 18 | self, 19 | attn: Attention, 20 | hidden_states: torch.FloatTensor, 21 | encoder_hidden_states: torch.FloatTensor = None, 22 | attention_mask: Optional[torch.FloatTensor] = None, 23 | *args, 24 | **kwargs, 25 | ) -> torch.FloatTensor: 26 | residual = hidden_states 27 | 28 | batch_size = hidden_states.shape[0] 29 | 30 | apply_guidance = self.nag_scale > 1 and encoder_hidden_states is not None 31 | if apply_guidance: 32 | origin_batch_size = len(encoder_hidden_states) - batch_size 33 | assert len(encoder_hidden_states) / origin_batch_size in [2, 3, 4] 34 | 35 | # `sample` projections. 
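# `hidden_states` (the latent tokens) keeps its own batch size, while `encoder_hidden_states`
# arrives with one extra chunk of NAG-negative prompt embeddings appended; the latent Q/K/V
# are tiled or extended below so that every text chunk attends over matching latent tokens.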
36 | query = attn.to_q(hidden_states) 37 | key = attn.to_k(hidden_states) 38 | value = attn.to_v(hidden_states) 39 | 40 | inner_dim = key.shape[-1] 41 | head_dim = inner_dim // attn.heads 42 | 43 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 44 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 45 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 46 | 47 | if attn.norm_q is not None: 48 | query = attn.norm_q(query) 49 | if attn.norm_k is not None: 50 | key = attn.norm_k(key) 51 | 52 | if apply_guidance: 53 | batch_size += origin_batch_size 54 | if batch_size == 2 * origin_batch_size: 55 | query = query.tile(2, 1, 1, 1) 56 | key = key.tile(2, 1, 1, 1) 57 | value = value.tile(2, 1, 1, 1) 58 | else: 59 | query = torch.cat([query, query[origin_batch_size:2 * origin_batch_size]], dim=0) 60 | key = torch.cat([key, key[origin_batch_size:2 * origin_batch_size]], dim=0) 61 | value = torch.cat([value, value[origin_batch_size:2 * origin_batch_size]], dim=0) 62 | 63 | # `context` projections. 64 | if encoder_hidden_states is not None: 65 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) 66 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 67 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 68 | 69 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( 70 | batch_size, -1, attn.heads, head_dim 71 | ).transpose(1, 2) 72 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( 73 | batch_size, -1, attn.heads, head_dim 74 | ).transpose(1, 2) 75 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( 76 | batch_size, -1, attn.heads, head_dim 77 | ).transpose(1, 2) 78 | 79 | if attn.norm_added_q is not None: 80 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) 81 | if attn.norm_added_k is not None: 82 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) 83 | 84 | query = torch.cat([query, encoder_hidden_states_query_proj], dim=2) 85 | key = torch.cat([key, encoder_hidden_states_key_proj], dim=2) 86 | value = torch.cat([value, encoder_hidden_states_value_proj], dim=2) 87 | 88 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) 89 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 90 | hidden_states = hidden_states.to(query.dtype) 91 | 92 | if encoder_hidden_states is not None: 93 | # Split the attention outputs. 
94 | hidden_states, encoder_hidden_states = ( 95 | hidden_states[:, : residual.shape[1]], 96 | hidden_states[:, residual.shape[1] :], 97 | ) 98 | if not attn.context_pre_only: 99 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states) 100 | 101 | if apply_guidance: 102 | hidden_states_negative = hidden_states[-origin_batch_size:] 103 | if batch_size == 2 * origin_batch_size: 104 | hidden_states_positive = hidden_states[:origin_batch_size] 105 | else: 106 | hidden_states_positive = hidden_states[origin_batch_size:2 * origin_batch_size] 107 | hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1) 108 | norm_positive = torch.norm(hidden_states_positive, p=1, dim=-1, keepdim=True).expand(*hidden_states_positive.shape) 109 | norm_guidance = torch.norm(hidden_states_guidance, p=1, dim=-1, keepdim=True).expand(*hidden_states_guidance.shape) 110 | 111 | scale = norm_guidance / (norm_positive + 1e-7) 112 | hidden_states_guidance = hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / (scale + 1e-7) 113 | 114 | hidden_states_guidance = hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha) 115 | 116 | if batch_size == 2 * origin_batch_size: 117 | hidden_states = hidden_states_guidance 118 | elif batch_size == 3 * origin_batch_size: 119 | hidden_states = torch.cat((hidden_states[:origin_batch_size], hidden_states_guidance), dim=0) 120 | elif batch_size == 4 * origin_batch_size: 121 | hidden_states = torch.cat((hidden_states[:origin_batch_size], hidden_states_guidance, hidden_states[2 * origin_batch_size:3 * origin_batch_size]), dim=0) 122 | 123 | # linear proj 124 | hidden_states = attn.to_out[0](hidden_states) 125 | # dropout 126 | hidden_states = attn.to_out[1](hidden_states) 127 | 128 | if encoder_hidden_states is not None: 129 | return hidden_states, encoder_hidden_states 130 | else: 131 | return hidden_states 132 | 133 | 134 | class NAGPAGCFGJointAttnProcessor2_0: 135 | """Attention processor used typically in processing the SD3-like self-attention projections.""" 136 | def __init__(self, nag_scale: float = 1.0, nag_tau: float = 2.5, nag_alpha:float = 0.125): 137 | if not hasattr(F, "scaled_dot_product_attention"): 138 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 139 | self.nag_scale = nag_scale 140 | self.nag_tau = nag_tau 141 | self.nag_alpha = nag_alpha 142 | 143 | def __call__( 144 | self, 145 | attn: Attention, 146 | hidden_states: torch.FloatTensor, 147 | encoder_hidden_states: torch.FloatTensor = None, 148 | attention_mask: Optional[torch.FloatTensor] = None, 149 | *args, 150 | **kwargs, 151 | ) -> torch.FloatTensor: 152 | residual = hidden_states 153 | 154 | input_ndim = hidden_states.ndim 155 | if input_ndim == 4: 156 | batch_size, channel, height, width = hidden_states.shape 157 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 158 | context_input_ndim = encoder_hidden_states.ndim 159 | if context_input_ndim == 4: 160 | batch_size, channel, height, width = encoder_hidden_states.shape 161 | encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 162 | 163 | identity_block_size = hidden_states.shape[ 164 | 1 165 | ] # patch embeddings width * height (correspond to self-attention map width or height) 166 | 167 | # chunk 168 | hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) 169 | 
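# Batch layout for CFG + PAG + NAG: the latents arrive in three chunks (unconditional /
# conditional / perturbed), while the text embeddings carry a fourth NAG-negative chunk,
# split off just below and routed through the 'original' attention path.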
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) 170 | 171 | ( 172 | encoder_hidden_states_uncond, 173 | encoder_hidden_states_org, 174 | encoder_hidden_states_ptb, 175 | encoder_hidden_states_nag, 176 | ) = encoder_hidden_states.chunk(4) 177 | encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org, encoder_hidden_states_nag]) 178 | 179 | ################## original path ################## 180 | batch_size = encoder_hidden_states_org.shape[0] 181 | origin_batch_size = batch_size // 3 182 | 183 | # `sample` projections. 184 | query_org = attn.to_q(hidden_states_org) 185 | key_org = attn.to_k(hidden_states_org) 186 | value_org = attn.to_v(hidden_states_org) 187 | 188 | query_org = torch.cat([query_org, query_org[-origin_batch_size:]], dim=0) 189 | key_org = torch.cat([key_org, key_org[-origin_batch_size:]], dim=0) 190 | value_org = torch.cat([value_org, value_org[-origin_batch_size:]], dim=0) 191 | 192 | # `context` projections. 193 | encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org) 194 | encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org) 195 | encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org) 196 | 197 | # attention 198 | query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1) 199 | key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1) 200 | value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1) 201 | 202 | inner_dim = key_org.shape[-1] 203 | head_dim = inner_dim // attn.heads 204 | query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 205 | key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 206 | value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 207 | 208 | hidden_states_org = F.scaled_dot_product_attention( 209 | query_org, key_org, value_org, dropout_p=0.0, is_causal=False 210 | ) 211 | hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 212 | hidden_states_org = hidden_states_org.to(query_org.dtype) 213 | 214 | # Split the attention outputs. 
215 | hidden_states_org, encoder_hidden_states_org = ( 216 | hidden_states_org[:, : residual.shape[1]], 217 | hidden_states_org[:, residual.shape[1] :], 218 | ) 219 | 220 | hidden_states_org_negative = hidden_states_org[-origin_batch_size:] 221 | hidden_states_org_positive = hidden_states_org[-2 * origin_batch_size:-origin_batch_size] 222 | hidden_states_org_guidance = hidden_states_org_positive * self.nag_scale - hidden_states_org_negative * (self.nag_scale - 1) 223 | norm_positive = torch.norm(hidden_states_org_positive, p=1, dim=-1, keepdim=True).expand(*hidden_states_org_positive.shape) 224 | norm_guidance = torch.norm(hidden_states_org_guidance, p=1, dim=-1, keepdim=True).expand(*hidden_states_org_guidance.shape) 225 | 226 | scale = norm_guidance / (norm_positive + 1e-7) 227 | hidden_states_org_guidance = hidden_states_org_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / (scale + 1e-7) 228 | 229 | hidden_states_org_guidance = hidden_states_org_guidance * self.nag_alpha + hidden_states_org_positive * (1 - self.nag_alpha) 230 | 231 | hidden_states_org = torch.cat((hidden_states_org[:origin_batch_size], hidden_states_org_guidance), dim=0) 232 | 233 | # linear proj 234 | hidden_states_org = attn.to_out[0](hidden_states_org) 235 | # dropout 236 | hidden_states_org = attn.to_out[1](hidden_states_org) 237 | if not attn.context_pre_only: 238 | encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org) 239 | 240 | if input_ndim == 4: 241 | hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) 242 | if context_input_ndim == 4: 243 | encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape( 244 | batch_size, channel, height, width 245 | ) 246 | 247 | ################## perturbed path ################## 248 | 249 | batch_size = encoder_hidden_states_ptb.shape[0] 250 | 251 | # `sample` projections. 252 | query_ptb = attn.to_q(hidden_states_ptb) 253 | key_ptb = attn.to_k(hidden_states_ptb) 254 | value_ptb = attn.to_v(hidden_states_ptb) 255 | 256 | # `context` projections. 
257 | encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb) 258 | encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb) 259 | encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb) 260 | 261 | # attention 262 | query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1) 263 | key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1) 264 | value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1) 265 | 266 | inner_dim = key_ptb.shape[-1] 267 | head_dim = inner_dim // attn.heads 268 | query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 269 | key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 270 | value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 271 | 272 | # create a full mask with all entries set to 0 273 | seq_len = query_ptb.size(2) 274 | full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype) 275 | 276 | # set the attention value between image patches to -inf 277 | full_mask[:identity_block_size, :identity_block_size] = float("-inf") 278 | 279 | # set the diagonal of the attention value between image patches to 0 280 | full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0) 281 | 282 | # expand the mask to match the attention weights shape 283 | full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions 284 | 285 | hidden_states_ptb = F.scaled_dot_product_attention( 286 | query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False 287 | ) 288 | hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 289 | hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype) 290 | 291 | # split the attention outputs. 
292 | hidden_states_ptb, encoder_hidden_states_ptb = ( 293 | hidden_states_ptb[:, : residual.shape[1]], 294 | hidden_states_ptb[:, residual.shape[1] :], 295 | ) 296 | 297 | # linear proj 298 | hidden_states_ptb = attn.to_out[0](hidden_states_ptb) 299 | # dropout 300 | hidden_states_ptb = attn.to_out[1](hidden_states_ptb) 301 | if not attn.context_pre_only: 302 | encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb) 303 | 304 | if input_ndim == 4: 305 | hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) 306 | if context_input_ndim == 4: 307 | encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape( 308 | batch_size, channel, height, width 309 | ) 310 | 311 | ################ concat ############### 312 | hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) 313 | encoder_hidden_states = torch.cat([encoder_hidden_states_org[:2 * origin_batch_size], encoder_hidden_states_ptb, encoder_hidden_states_org[-origin_batch_size:]]) 314 | 315 | return hidden_states, encoder_hidden_states -------------------------------------------------------------------------------- /nag/attention_nag.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from diffusers.utils import deprecate 7 | from diffusers.models.attention_processor import Attention 8 | 9 | 10 | class NAGAttnProcessor2_0: 11 | def __init__(self, nag_scale: float = 1.0, nag_tau: float = 2.5, nag_alpha:float = 0.5): 12 | if not hasattr(F, "scaled_dot_product_attention"): 13 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 14 | self.nag_scale = nag_scale 15 | self.nag_tau = nag_tau 16 | self.nag_alpha = nag_alpha 17 | 18 | def __call__( 19 | self, 20 | attn: Attention, 21 | hidden_states: torch.Tensor, 22 | encoder_hidden_states: Optional[torch.Tensor] = None, 23 | attention_mask: Optional[torch.Tensor] = None, 24 | temb: Optional[torch.Tensor] = None, 25 | *args, 26 | **kwargs, 27 | ) -> torch.Tensor: 28 | if len(args) > 0 or kwargs.get("scale", None) is not None: 29 | deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
30 | deprecate("scale", "1.0.0", deprecation_message) 31 | 32 | residual = hidden_states 33 | if attn.spatial_norm is not None: 34 | hidden_states = attn.spatial_norm(hidden_states, temb) 35 | 36 | input_ndim = hidden_states.ndim 37 | 38 | if input_ndim == 4: 39 | batch_size, channel, height, width = hidden_states.shape 40 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 41 | 42 | batch_size, sequence_length, _ = ( 43 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 44 | ) 45 | 46 | apply_guidance = self.nag_scale > 1 and encoder_hidden_states is not None 47 | if apply_guidance: 48 | origin_batch_size = batch_size - len(hidden_states) 49 | assert batch_size / origin_batch_size in [2, 3, 4] 50 | 51 | if attention_mask is not None: 52 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 53 | # scaled_dot_product_attention expects attention_mask shape to be 54 | # (batch, heads, source_length, target_length) 55 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 56 | 57 | if attn.group_norm is not None: 58 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 59 | 60 | query = attn.to_q(hidden_states) 61 | 62 | if encoder_hidden_states is None: 63 | encoder_hidden_states = hidden_states 64 | elif attn.norm_cross: 65 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 66 | 67 | key = attn.to_k(encoder_hidden_states) 68 | value = attn.to_v(encoder_hidden_states) 69 | 70 | inner_dim = key.shape[-1] 71 | head_dim = inner_dim // attn.heads 72 | 73 | if apply_guidance: 74 | if batch_size == 2 * origin_batch_size: 75 | query = query.tile(2, 1, 1) 76 | else: 77 | query = torch.cat((query, query[origin_batch_size:2 * origin_batch_size]), dim=0) 78 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 79 | 80 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 81 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 82 | 83 | if attn.norm_q is not None: 84 | query = attn.norm_q(query) 85 | if attn.norm_k is not None: 86 | key = attn.norm_k(key) 87 | 88 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 89 | # TODO: add support for attn.scale when we move to Torch 2.1 90 | hidden_states = F.scaled_dot_product_attention( 91 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 92 | ) 93 | 94 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 95 | hidden_states = hidden_states.to(query.dtype) 96 | 97 | if apply_guidance: 98 | hidden_states_negative = hidden_states[-origin_batch_size:] 99 | if batch_size == 2 * origin_batch_size: 100 | hidden_states_positive = hidden_states[:origin_batch_size] 101 | else: 102 | hidden_states_positive = hidden_states[origin_batch_size:2 * origin_batch_size] 103 | hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1) 104 | norm_positive = torch.norm(hidden_states_positive, p=1, dim=-1, keepdim=True).expand(*hidden_states_positive.shape) 105 | norm_guidance = torch.norm(hidden_states_guidance, p=1, dim=-1, keepdim=True).expand(*hidden_states_guidance.shape) 106 | 107 | scale = norm_guidance / norm_positive 108 | hidden_states_guidance = hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / scale 109 | 110 | hidden_states_guidance = 
hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha) 111 | 112 | if batch_size == 2 * origin_batch_size: 113 | hidden_states = hidden_states_guidance 114 | elif batch_size == 3 * origin_batch_size: 115 | hidden_states = torch.cat((hidden_states[:origin_batch_size], hidden_states_guidance), dim=0) 116 | elif batch_size == 4 * origin_batch_size: 117 | hidden_states = torch.cat((hidden_states[:origin_batch_size], hidden_states_guidance, hidden_states[2 * origin_batch_size:3 * origin_batch_size]), dim=0) 118 | 119 | # linear proj 120 | hidden_states = attn.to_out[0](hidden_states) 121 | # dropout 122 | hidden_states = attn.to_out[1](hidden_states) 123 | 124 | if input_ndim == 4: 125 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 126 | 127 | if attn.residual_connection: 128 | hidden_states = hidden_states + residual 129 | 130 | hidden_states = hidden_states / attn.rescale_output_factor 131 | 132 | return hidden_states 133 | -------------------------------------------------------------------------------- /nag/attention_wan_nag.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from diffusers.models.attention_processor import Attention 7 | from ftfy import apply_plan 8 | 9 | 10 | class NAGWanAttnProcessor2_0: 11 | def __init__(self, nag_scale=1.0, nag_tau=2.5, nag_alpha=0.25): 12 | if not hasattr(F, "scaled_dot_product_attention"): 13 | raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") 14 | self.nag_scale = nag_scale 15 | self.nag_tau = nag_tau 16 | self.nag_alpha = nag_alpha 17 | 18 | def __call__( 19 | self, 20 | attn: Attention, 21 | hidden_states: torch.Tensor, 22 | encoder_hidden_states: Optional[torch.Tensor] = None, 23 | attention_mask: Optional[torch.Tensor] = None, 24 | rotary_emb: Optional[torch.Tensor] = None, 25 | ) -> torch.Tensor: 26 | apply_guidance = self.nag_scale > 1 and encoder_hidden_states is not None 27 | if apply_guidance: 28 | if len(encoder_hidden_states) == 2 * len(hidden_states): 29 | batch_size = len(hidden_states) 30 | else: 31 | apply_guidance = False 32 | 33 | encoder_hidden_states_img = None 34 | if attn.add_k_proj is not None: 35 | encoder_hidden_states_img = encoder_hidden_states[:, :257] 36 | encoder_hidden_states = encoder_hidden_states[:, 257:] 37 | if apply_guidance: 38 | encoder_hidden_states_img = encoder_hidden_states_img[:batch_size] 39 | if encoder_hidden_states is None: 40 | encoder_hidden_states = hidden_states 41 | 42 | query = attn.to_q(hidden_states) 43 | key = attn.to_k(encoder_hidden_states) 44 | value = attn.to_v(encoder_hidden_states) 45 | 46 | if attn.norm_q is not None: 47 | query = attn.norm_q(query) 48 | if attn.norm_k is not None: 49 | key = attn.norm_k(key) 50 | 51 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) 52 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) 53 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) 54 | 55 | if rotary_emb is not None: 56 | 57 | def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): 58 | x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) 59 | x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) 60 | return x_out.type_as(hidden_states) 61 | 62 | query = apply_rotary_emb(query, rotary_emb) 63 | key = apply_rotary_emb(key, rotary_emb) 64 | 65 | # 
I2V task 66 | hidden_states_img = None 67 | if encoder_hidden_states_img is not None: 68 | key_img = attn.add_k_proj(encoder_hidden_states_img) 69 | key_img = attn.norm_added_k(key_img) 70 | value_img = attn.add_v_proj(encoder_hidden_states_img) 71 | 72 | key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) 73 | value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) 74 | 75 | hidden_states_img = F.scaled_dot_product_attention( 76 | query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False 77 | ) 78 | hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) 79 | hidden_states_img = hidden_states_img.type_as(query) 80 | 81 | if apply_guidance: 82 | key, key_negative = torch.chunk(key, 2, dim=0) 83 | value, value_negative = torch.chunk(value, 2, dim=0) 84 | hidden_states = F.scaled_dot_product_attention( 85 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 86 | ) 87 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) 88 | hidden_states = hidden_states.type_as(query) 89 | if apply_guidance: 90 | hidden_states_negative = F.scaled_dot_product_attention( 91 | query, key_negative, value_negative, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 92 | ) 93 | hidden_states_negative = hidden_states_negative.transpose(1, 2).flatten(2, 3) 94 | hidden_states_negative = hidden_states_negative.type_as(query) 95 | 96 | hidden_states_positive = hidden_states 97 | 98 | hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1) 99 | norm_positive = torch.norm(hidden_states_positive, p=1, dim=-1, keepdim=True).expand(*hidden_states_positive.shape) 100 | norm_guidance = torch.norm(hidden_states_guidance, p=1, dim=-1, keepdim=True).expand(*hidden_states_guidance.shape) 101 | 102 | scale = norm_guidance / norm_positive 103 | scale = torch.nan_to_num(scale, 10) 104 | hidden_states_guidance[scale > self.nag_tau] = \ 105 | hidden_states_guidance[scale > self.nag_tau] / (norm_guidance[scale > self.nag_tau] + 1e-7) * norm_positive[scale > self.nag_tau] * self.nag_tau 106 | 107 | hidden_states = hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha) 108 | 109 | if hidden_states_img is not None: 110 | hidden_states = hidden_states + hidden_states_img 111 | 112 | hidden_states = attn.to_out[0](hidden_states) 113 | hidden_states = attn.to_out[1](hidden_states) 114 | return hidden_states 115 | -------------------------------------------------------------------------------- /nag/normalization.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX 5 | 6 | 7 | class TruncAdaLayerNorm(AdaLayerNorm): 8 | def forward( 9 | self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None 10 | ) -> torch.Tensor: 11 | batch_size = x.shape[0] 12 | return self.forward_old( 13 | x, 14 | temb[:batch_size] if temb is not None else None, 15 | ) 16 | 17 | 18 | class TruncAdaLayerNormContinuous(AdaLayerNormContinuous): 19 | def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor: 20 | batch_size = x.shape[0] 21 | return self.forward_old(x, conditioning_embedding[:batch_size]) 22 | 23 | 24 | class TruncAdaLayerNormZero(AdaLayerNormZero): 25 | def forward( 26 | self, 27 | x: 
torch.Tensor, 28 | timestep: Optional[torch.Tensor] = None, 29 | class_labels: Optional[torch.LongTensor] = None, 30 | hidden_dtype: Optional[torch.dtype] = None, 31 | emb: Optional[torch.Tensor] = None, 32 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 33 | batch_size = x.shape[0] 34 | return self.forward_old( 35 | x, 36 | timestep[:batch_size] if timestep is not None else None, 37 | class_labels[:batch_size] if class_labels is not None else None, 38 | hidden_dtype, 39 | emb[:batch_size] if emb is not None else None, 40 | ) 41 | 42 | 43 | class TruncSD35AdaLayerNormZeroX(SD35AdaLayerNormZeroX): 44 | def forward( 45 | self, 46 | hidden_states: torch.Tensor, 47 | emb: Optional[torch.Tensor] = None, 48 | ) -> Tuple[torch.Tensor, ...]: 49 | batch_size = hidden_states.shape[0] 50 | return self.forward_old( 51 | hidden_states, 52 | emb[:batch_size] if emb is not None else None, 53 | ) 54 | -------------------------------------------------------------------------------- /nag/pipeline_flux_kontext_nag.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union, List, Optional, Dict, Any, Callable 3 | import types 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from diffusers import FluxKontextPipeline 9 | from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps 10 | from diffusers.image_processor import PipelineImageInput 11 | from diffusers.utils import is_torch_xla_available, logging 12 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 13 | from diffusers.pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS 14 | from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormContinuous 15 | 16 | from nag.attention_flux_nag import NAGFluxAttnProcessor2_0 17 | from nag.normalization import TruncAdaLayerNormZero, TruncAdaLayerNormContinuous 18 | 19 | if is_torch_xla_available(): 20 | import torch_xla.core.xla_model as xm 21 | 22 | XLA_AVAILABLE = True 23 | else: 24 | XLA_AVAILABLE = False 25 | 26 | 27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 28 | 29 | 30 | class NAGFluxKontextPipeline(FluxKontextPipeline): 31 | @property 32 | def do_normalized_attention_guidance(self): 33 | return self._nag_scale > 1 34 | 35 | def _set_nag_attn_processor( 36 | self, 37 | nag_scale, 38 | encoder_hidden_states_length, 39 | nag_tau=2.5, 40 | nag_alpha=0.25, 41 | ): 42 | attn_procs = {} 43 | for name in self.transformer.attn_processors.keys(): 44 | attn_procs[name] = NAGFluxAttnProcessor2_0( 45 | nag_scale=nag_scale, 46 | nag_tau=nag_tau, 47 | nag_alpha=nag_alpha, 48 | encoder_hidden_states_length=encoder_hidden_states_length, 49 | ) 50 | self.transformer.set_attn_processor(attn_procs) 51 | 52 | @torch.no_grad() 53 | def __call__( 54 | self, 55 | image: Optional[PipelineImageInput] = None, 56 | prompt: Union[str, List[str]] = None, 57 | prompt_2: Optional[Union[str, List[str]]] = None, 58 | negative_prompt: Union[str, List[str]] = None, 59 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 60 | true_cfg_scale: float = 1.0, 61 | height: Optional[int] = None, 62 | width: Optional[int] = None, 63 | num_inference_steps: int = 28, 64 | sigmas: Optional[List[float]] = None, 65 | guidance_scale: float = 3.5, 66 | num_images_per_prompt: Optional[int] = 1, 67 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 68 | latents: Optional[torch.FloatTensor] = None, 69 | 
prompt_embeds: Optional[torch.FloatTensor] = None, 70 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 71 | ip_adapter_image: Optional[PipelineImageInput] = None, 72 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 73 | negative_ip_adapter_image: Optional[PipelineImageInput] = None, 74 | negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 75 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 76 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 77 | output_type: Optional[str] = "pil", 78 | return_dict: bool = True, 79 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 80 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 81 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 82 | max_sequence_length: int = 512, 83 | max_area: int = 1024 ** 2, 84 | _auto_resize: bool = True, 85 | 86 | nag_scale: float = 1.0, 87 | nag_tau: float = 2.5, 88 | nag_alpha: float = 0.25, 89 | nag_end: float = 0.25, 90 | nag_negative_prompt: str = None, 91 | nag_negative_prompt_2: str = None, 92 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None, 93 | nag_negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, 94 | ): 95 | r""" 96 | Function invoked when calling the pipeline for generation. 97 | 98 | Args: 99 | image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): 100 | `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both 101 | numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list 102 | or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a 103 | list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image 104 | latents as `image`, but if passing latents directly it is not encoded again. 105 | prompt (`str` or `List[str]`, *optional*): 106 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 107 | instead. 108 | prompt_2 (`str` or `List[str]`, *optional*): 109 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 110 | will be used instead. 111 | negative_prompt (`str` or `List[str]`, *optional*): 112 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 113 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is 114 | not greater than `1`). 115 | negative_prompt_2 (`str` or `List[str]`, *optional*): 116 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 117 | `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. 118 | true_cfg_scale (`float`, *optional*, defaults to 1.0): 119 | When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. 120 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 121 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 122 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 123 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 124 | num_inference_steps (`int`, *optional*, defaults to 50): 125 | The number of denoising steps. 
More denoising steps usually lead to a higher quality image at the 126 | expense of slower inference. 127 | sigmas (`List[float]`, *optional*): 128 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 129 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 130 | will be used. 131 | guidance_scale (`float`, *optional*, defaults to 3.5): 132 | Guidance scale as defined in [Classifier-Free Diffusion 133 | Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. 134 | of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting 135 | `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to 136 | the text `prompt`, usually at the expense of lower image quality. 137 | num_images_per_prompt (`int`, *optional*, defaults to 1): 138 | The number of images to generate per prompt. 139 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 140 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 141 | to make generation deterministic. 142 | latents (`torch.FloatTensor`, *optional*): 143 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 144 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 145 | tensor will ge generated by sampling using the supplied random `generator`. 146 | prompt_embeds (`torch.FloatTensor`, *optional*): 147 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 148 | provided, text embeddings will be generated from `prompt` input argument. 149 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 150 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 151 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 152 | ip_adapter_image: (`PipelineImageInput`, *optional*): 153 | Optional image input to work with IP Adapters. 154 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): 155 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of 156 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not 157 | provided, embeddings are computed from the `ip_adapter_image` input argument. 158 | negative_ip_adapter_image: 159 | (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 160 | negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): 161 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of 162 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not 163 | provided, embeddings are computed from the `ip_adapter_image` input argument. 164 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 165 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 166 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 167 | argument. 168 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 169 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 170 | weighting. 
If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 171 | input argument. 172 | output_type (`str`, *optional*, defaults to `"pil"`): 173 | The output format of the generate image. Choose between 174 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 175 | return_dict (`bool`, *optional*, defaults to `True`): 176 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. 177 | joint_attention_kwargs (`dict`, *optional*): 178 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 179 | `self.processor` in 180 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 181 | callback_on_step_end (`Callable`, *optional*): 182 | A function that calls at the end of each denoising steps during the inference. The function is called 183 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 184 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 185 | `callback_on_step_end_tensor_inputs`. 186 | callback_on_step_end_tensor_inputs (`List`, *optional*): 187 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 188 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 189 | `._callback_tensor_inputs` attribute of your pipeline class. 190 | max_sequence_length (`int` defaults to 512): 191 | Maximum sequence length to use with the `prompt`. 192 | max_area (`int`, defaults to `1024 ** 2`): 193 | The maximum area of the generated image in pixels. The height and width will be adjusted to fit this 194 | area while maintaining the aspect ratio. 195 | 196 | Examples: 197 | 198 | Returns: 199 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 200 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated 201 | images. 202 | """ 203 | 204 | height = height or self.default_sample_size * self.vae_scale_factor 205 | width = width or self.default_sample_size * self.vae_scale_factor 206 | 207 | original_height, original_width = height, width 208 | aspect_ratio = width / height 209 | width = round((max_area * aspect_ratio) ** 0.5) 210 | height = round((max_area / aspect_ratio) ** 0.5) 211 | 212 | multiple_of = self.vae_scale_factor * 2 213 | width = width // multiple_of * multiple_of 214 | height = height // multiple_of * multiple_of 215 | 216 | if height != original_height or width != original_width: 217 | logger.warning( 218 | f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements." 219 | ) 220 | 221 | # 1. Check inputs. 
Raise error if not correct 222 | self.check_inputs( 223 | prompt, 224 | prompt_2, 225 | height, 226 | width, 227 | negative_prompt=negative_prompt, 228 | negative_prompt_2=negative_prompt_2, 229 | prompt_embeds=prompt_embeds, 230 | negative_prompt_embeds=negative_prompt_embeds, 231 | pooled_prompt_embeds=pooled_prompt_embeds, 232 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 233 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 234 | max_sequence_length=max_sequence_length, 235 | ) 236 | 237 | self._guidance_scale = guidance_scale 238 | self._joint_attention_kwargs = joint_attention_kwargs 239 | self._current_timestep = None 240 | self._interrupt = False 241 | self._nag_scale = nag_scale 242 | 243 | # 2. Define call parameters 244 | if prompt is not None and isinstance(prompt, str): 245 | batch_size = 1 246 | elif prompt is not None and isinstance(prompt, list): 247 | batch_size = len(prompt) 248 | else: 249 | batch_size = prompt_embeds.shape[0] 250 | 251 | device = self._execution_device 252 | 253 | lora_scale = ( 254 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 255 | ) 256 | has_neg_prompt = negative_prompt is not None or ( 257 | negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None 258 | ) 259 | do_true_cfg = true_cfg_scale > 1 and has_neg_prompt 260 | ( 261 | prompt_embeds, 262 | pooled_prompt_embeds, 263 | text_ids, 264 | ) = self.encode_prompt( 265 | prompt=prompt, 266 | prompt_2=prompt_2, 267 | prompt_embeds=prompt_embeds, 268 | pooled_prompt_embeds=pooled_prompt_embeds, 269 | device=device, 270 | num_images_per_prompt=num_images_per_prompt, 271 | max_sequence_length=max_sequence_length, 272 | lora_scale=lora_scale, 273 | ) 274 | if do_true_cfg: 275 | ( 276 | negative_prompt_embeds, 277 | negative_pooled_prompt_embeds, 278 | negative_text_ids, 279 | ) = self.encode_prompt( 280 | prompt=negative_prompt, 281 | prompt_2=negative_prompt_2, 282 | prompt_embeds=negative_prompt_embeds, 283 | pooled_prompt_embeds=negative_pooled_prompt_embeds, 284 | device=device, 285 | num_images_per_prompt=num_images_per_prompt, 286 | max_sequence_length=max_sequence_length, 287 | lora_scale=lora_scale, 288 | ) 289 | 290 | if self.do_normalized_attention_guidance: 291 | if nag_negative_prompt_embeds is None or nag_negative_pooled_prompt_embeds is None: 292 | if nag_negative_prompt is None: 293 | if negative_prompt is not None: 294 | if do_true_cfg: 295 | nag_negative_prompt_embeds = negative_prompt_embeds 296 | nag_negative_pooled_prompt_embeds = negative_pooled_prompt_embeds 297 | else: 298 | nag_negative_prompt = negative_prompt 299 | nag_negative_prompt_2 = negative_prompt_2 300 | else: 301 | nag_negative_prompt = "" 302 | 303 | if nag_negative_prompt is not None: 304 | nag_negative_prompt_embeds, nag_negative_pooled_prompt_embeds = self.encode_prompt( 305 | prompt=nag_negative_prompt, 306 | prompt_2=nag_negative_prompt_2, 307 | device=device, 308 | num_images_per_prompt=num_images_per_prompt, 309 | max_sequence_length=max_sequence_length, 310 | lora_scale=lora_scale, 311 | )[:2] 312 | 313 | if self.do_normalized_attention_guidance: 314 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, nag_negative_pooled_prompt_embeds], dim=0) 315 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0) 316 | 317 | # 3. 
Preprocess image 318 | if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels): 319 | img = image[0] if isinstance(image, list) else image 320 | image_height, image_width = self.image_processor.get_default_height_width(img) 321 | aspect_ratio = image_width / image_height 322 | if _auto_resize: 323 | # Kontext is trained on specific resolutions, using one of them is recommended 324 | _, image_width, image_height = min( 325 | (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS 326 | ) 327 | image_width = image_width // multiple_of * multiple_of 328 | image_height = image_height // multiple_of * multiple_of 329 | image = self.image_processor.resize(image, image_height, image_width) 330 | image = self.image_processor.preprocess(image, image_height, image_width) 331 | 332 | # 4. Prepare latent variables 333 | num_channels_latents = self.transformer.config.in_channels // 4 334 | latents, image_latents, latent_ids, image_ids = self.prepare_latents( 335 | image, 336 | batch_size * num_images_per_prompt, 337 | num_channels_latents, 338 | height, 339 | width, 340 | prompt_embeds.dtype, 341 | device, 342 | generator, 343 | latents, 344 | ) 345 | if image_ids is not None: 346 | latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension 347 | 348 | # 5. Prepare timesteps 349 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas 350 | image_seq_len = latents.shape[1] 351 | mu = calculate_shift( 352 | image_seq_len, 353 | self.scheduler.config.get("base_image_seq_len", 256), 354 | self.scheduler.config.get("max_image_seq_len", 4096), 355 | self.scheduler.config.get("base_shift", 0.5), 356 | self.scheduler.config.get("max_shift", 1.15), 357 | ) 358 | timesteps, num_inference_steps = retrieve_timesteps( 359 | self.scheduler, 360 | num_inference_steps, 361 | device, 362 | sigmas=sigmas, 363 | mu=mu, 364 | ) 365 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 366 | self._num_timesteps = len(timesteps) 367 | 368 | # handle guidance 369 | if self.transformer.config.guidance_embeds: 370 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 371 | guidance = guidance.expand(prompt_embeds.shape[0]) 372 | else: 373 | guidance = None 374 | 375 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( 376 | negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None 377 | ): 378 | negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) 379 | negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters 380 | 381 | elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( 382 | negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None 383 | ): 384 | ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) 385 | ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters 386 | 387 | if self.joint_attention_kwargs is None: 388 | self._joint_attention_kwargs = {} 389 | 390 | image_embeds = None 391 | negative_image_embeds = None 392 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 393 | image_embeds = self.prepare_ip_adapter_image_embeds( 394 | ip_adapter_image, 395 | ip_adapter_image_embeds, 396 | device, 397 | batch_size * num_images_per_prompt, 398 | ) 399 | if negative_ip_adapter_image is not None 
or negative_ip_adapter_image_embeds is not None: 400 | negative_image_embeds = self.prepare_ip_adapter_image_embeds( 401 | negative_ip_adapter_image, 402 | negative_ip_adapter_image_embeds, 403 | device, 404 | batch_size * num_images_per_prompt, 405 | ) 406 | 407 | origin_attn_procs = self.transformer.attn_processors 408 | if self.do_normalized_attention_guidance: 409 | self._set_nag_attn_processor(nag_scale, prompt_embeds.shape[1], nag_tau, nag_alpha) 410 | attn_procs_recovered = False 411 | 412 | for sub_mod in self.transformer.modules(): 413 | if not hasattr(sub_mod, "forward_old") : 414 | sub_mod.forward_old = sub_mod.forward 415 | if isinstance(sub_mod, AdaLayerNormZero): 416 | sub_mod.forward = types.MethodType(TruncAdaLayerNormZero.forward, sub_mod) 417 | elif isinstance(sub_mod, AdaLayerNormContinuous): 418 | sub_mod.forward = types.MethodType(TruncAdaLayerNormContinuous.forward, sub_mod) 419 | 420 | 421 | # 6. Denoising loop 422 | # We set the index here to remove DtoH sync, helpful especially during compilation. 423 | # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 424 | self.scheduler.set_begin_index(0) 425 | with self.progress_bar(total=num_inference_steps) as progress_bar: 426 | for i, t in enumerate(timesteps): 427 | if self.interrupt: 428 | continue 429 | 430 | if t < (1 - nag_end) * 1000 and self.do_normalized_attention_guidance and not attn_procs_recovered: 431 | self.transformer.set_attn_processor(origin_attn_procs) 432 | if guidance is not None: 433 | guidance = guidance[:len(latents)] 434 | pooled_prompt_embeds = pooled_prompt_embeds[:len(latents)] 435 | prompt_embeds = prompt_embeds[:len(latents)] 436 | attn_procs_recovered = True 437 | 438 | self._current_timestep = t 439 | if image_embeds is not None: 440 | self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds 441 | 442 | latent_model_input = latents 443 | if image_latents is not None: 444 | latent_model_input = torch.cat([latents, image_latents], dim=1) 445 | timestep = t.expand(prompt_embeds.shape[0]).to(latents.dtype) 446 | 447 | noise_pred = self.transformer( 448 | hidden_states=latent_model_input, 449 | timestep=timestep / 1000, 450 | guidance=guidance, 451 | pooled_projections=pooled_prompt_embeds, 452 | encoder_hidden_states=prompt_embeds, 453 | txt_ids=text_ids, 454 | img_ids=latent_ids, 455 | joint_attention_kwargs=self.joint_attention_kwargs, 456 | return_dict=False, 457 | )[0] 458 | noise_pred = noise_pred[:, : latents.size(1)] 459 | 460 | if do_true_cfg: 461 | if negative_image_embeds is not None: 462 | self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds 463 | neg_noise_pred = self.transformer( 464 | hidden_states=latent_model_input, 465 | timestep=timestep / 1000, 466 | guidance=guidance, 467 | pooled_projections=negative_pooled_prompt_embeds, 468 | encoder_hidden_states=negative_prompt_embeds, 469 | txt_ids=negative_text_ids, 470 | img_ids=latent_ids, 471 | joint_attention_kwargs=self.joint_attention_kwargs, 472 | return_dict=False, 473 | )[0] 474 | neg_noise_pred = neg_noise_pred[:, : latents.size(1)] 475 | noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) 476 | 477 | # compute the previous noisy sample x_t -> x_t-1 478 | latents_dtype = latents.dtype 479 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 480 | 481 | if latents.dtype != latents_dtype: 482 | if torch.backends.mps.is_available(): 483 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 484 | latents = latents.to(latents_dtype) 485 | 486 | if callback_on_step_end is not None: 487 | callback_kwargs = {} 488 | for k in callback_on_step_end_tensor_inputs: 489 | callback_kwargs[k] = locals()[k] 490 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 491 | 492 | latents = callback_outputs.pop("latents", latents) 493 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 494 | 495 | # call the callback, if provided 496 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 497 | progress_bar.update() 498 | 499 | if XLA_AVAILABLE: 500 | xm.mark_step() 501 | 502 | self._current_timestep = None 503 | 504 | if output_type == "latent": 505 | image = latents 506 | else: 507 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 508 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 509 | image = self.vae.decode(latents, return_dict=False)[0] 510 | image = self.image_processor.postprocess(image, output_type=output_type) 511 | 512 | if self.do_normalized_attention_guidance and not attn_procs_recovered: 513 | self.transformer.set_attn_processor(origin_attn_procs) 514 | 515 | # Offload all models 516 | self.maybe_free_model_hooks() 517 | 518 | if not return_dict: 519 | return (image,) 520 | 521 | return FluxPipelineOutput(images=image) 522 | -------------------------------------------------------------------------------- /nag/pipeline_flux_nag.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Union, List, Optional, Dict, Any, Callable 3 | import types 4 | 5 | import numpy as np 6 | import torch 7 | 8 | from diffusers import FluxPipeline 9 | from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps 10 | from diffusers.image_processor import PipelineImageInput 11 | from diffusers.utils import is_torch_xla_available 12 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 13 | from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormContinuous 14 | 15 | from nag.attention_flux_nag import NAGFluxAttnProcessor2_0 16 | from nag.normalization import TruncAdaLayerNormZero, TruncAdaLayerNormContinuous 17 | 18 | if is_torch_xla_available(): 19 | import torch_xla.core.xla_model as xm 20 | 21 | XLA_AVAILABLE = True 22 | else: 23 | XLA_AVAILABLE = False 24 | 25 | 26 | class NAGFluxPipeline(FluxPipeline): 27 | @property 28 | def do_normalized_attention_guidance(self): 29 | return self._nag_scale > 1 30 | 31 | def _set_nag_attn_processor( 32 | self, 33 | nag_scale, 34 | encoder_hidden_states_length, 35 | nag_tau=2.5, 36 | nag_alpha=0.25, 37 | ): 38 | attn_procs = {} 39 | for name in self.transformer.attn_processors.keys(): 40 | attn_procs[name] = NAGFluxAttnProcessor2_0( 41 | nag_scale=nag_scale, 42 | nag_tau=nag_tau, 43 | nag_alpha=nag_alpha, 44 | encoder_hidden_states_length=encoder_hidden_states_length, 45 | ) 46 | self.transformer.set_attn_processor(attn_procs) 47 | 48 | @torch.no_grad() 49 | def __call__( 50 | self, 51 | prompt: Union[str, List[str]] = None, 52 | prompt_2: Optional[Union[str, List[str]]] = None, 53 | negative_prompt: Union[str, List[str]] = None, 54 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 55 | true_cfg_scale: float = 1.0, 56 | height: Optional[int] = None, 57 | width: Optional[int] = None, 58 | 
num_inference_steps: int = 28, 59 | sigmas: Optional[List[float]] = None, 60 | guidance_scale: float = 3.5, 61 | num_images_per_prompt: Optional[int] = 1, 62 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 63 | latents: Optional[torch.FloatTensor] = None, 64 | prompt_embeds: Optional[torch.FloatTensor] = None, 65 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 66 | ip_adapter_image: Optional[PipelineImageInput] = None, 67 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 68 | negative_ip_adapter_image: Optional[PipelineImageInput] = None, 69 | negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 70 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 71 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 72 | output_type: Optional[str] = "pil", 73 | return_dict: bool = True, 74 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 75 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 76 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 77 | max_sequence_length: int = 512, 78 | 79 | nag_scale: float = 1.0, 80 | nag_tau: float = 2.5, 81 | nag_alpha: float = 0.25, 82 | nag_end: float = 1.0, 83 | nag_negative_prompt: str = None, 84 | nag_negative_prompt_2: str = None, 85 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None, 86 | nag_negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, 87 | ): 88 | r""" 89 | Function invoked when calling the pipeline for generation. 90 | 91 | Args: 92 | prompt (`str` or `List[str]`, *optional*): 93 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 94 | instead. 95 | prompt_2 (`str` or `List[str]`, *optional*): 96 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 97 | will be used instead. 98 | negative_prompt (`str` or `List[str]`, *optional*): 99 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 100 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is 101 | not greater than `1`). 102 | negative_prompt_2 (`str` or `List[str]`, *optional*): 103 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 104 | `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. 105 | true_cfg_scale (`float`, *optional*, defaults to 1.0): 106 | When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. 107 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 108 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 109 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 110 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 111 | num_inference_steps (`int`, *optional*, defaults to 50): 112 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 113 | expense of slower inference. 114 | sigmas (`List[float]`, *optional*): 115 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 116 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 117 | will be used. 
118 | guidance_scale (`float`, *optional*, defaults to 7.0): 119 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 120 | `guidance_scale` is defined as `w` of equation 2 of [Imagen 121 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 122 | 1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`, 123 | usually at the expense of lower image quality. 124 | num_images_per_prompt (`int`, *optional*, defaults to 1): 125 | The number of images to generate per prompt. 126 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 127 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 128 | to make generation deterministic. 129 | latents (`torch.FloatTensor`, *optional*): 130 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 131 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 132 | tensor will be generated by sampling using the supplied random `generator`. 133 | prompt_embeds (`torch.FloatTensor`, *optional*): 134 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 135 | provided, text embeddings will be generated from `prompt` input argument. 136 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 137 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 138 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 139 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 140 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): 141 | Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of 142 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not 143 | provided, embeddings are computed from the `ip_adapter_image` input argument. 144 | negative_ip_adapter_image: 145 | (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 146 | negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): 147 | Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of 148 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not 149 | provided, embeddings are computed from the `ip_adapter_image` input argument. 150 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 151 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 152 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 153 | argument. 154 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 155 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 156 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 157 | input argument. 158 | output_type (`str`, *optional*, defaults to `"pil"`): 159 | The output format of the generated image. Choose between 160 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
161 | return_dict (`bool`, *optional*, defaults to `True`): 162 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. 163 | joint_attention_kwargs (`dict`, *optional*): 164 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 165 | `self.processor` in 166 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 167 | callback_on_step_end (`Callable`, *optional*): 168 | A function that calls at the end of each denoising steps during the inference. The function is called 169 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 170 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 171 | `callback_on_step_end_tensor_inputs`. 172 | callback_on_step_end_tensor_inputs (`List`, *optional*): 173 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 174 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 175 | `._callback_tensor_inputs` attribute of your pipeline class. 176 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 177 | 178 | Examples: 179 | 180 | Returns: 181 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 182 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated 183 | images. 184 | """ 185 | 186 | height = height or self.default_sample_size * self.vae_scale_factor 187 | width = width or self.default_sample_size * self.vae_scale_factor 188 | 189 | # 1. Check inputs. Raise error if not correct 190 | self.check_inputs( 191 | prompt, 192 | prompt_2, 193 | height, 194 | width, 195 | negative_prompt=negative_prompt, 196 | negative_prompt_2=negative_prompt_2, 197 | prompt_embeds=prompt_embeds, 198 | negative_prompt_embeds=negative_prompt_embeds, 199 | pooled_prompt_embeds=pooled_prompt_embeds, 200 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 201 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 202 | max_sequence_length=max_sequence_length, 203 | ) 204 | 205 | self._guidance_scale = guidance_scale 206 | self._joint_attention_kwargs = joint_attention_kwargs 207 | self._current_timestep = None 208 | self._interrupt = False 209 | self._nag_scale = nag_scale 210 | 211 | # 2. 
Define call parameters 212 | if prompt is not None and isinstance(prompt, str): 213 | batch_size = 1 214 | elif prompt is not None and isinstance(prompt, list): 215 | batch_size = len(prompt) 216 | else: 217 | batch_size = prompt_embeds.shape[0] 218 | 219 | device = self._execution_device 220 | 221 | lora_scale = ( 222 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 223 | ) 224 | has_neg_prompt = negative_prompt is not None or ( 225 | negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None 226 | ) 227 | do_true_cfg = true_cfg_scale > 1 and has_neg_prompt 228 | # do_true_cfg = do_true_cfg or self.do_normalized_attention_guidance 229 | 230 | ( 231 | prompt_embeds, 232 | pooled_prompt_embeds, 233 | text_ids, 234 | ) = self.encode_prompt( 235 | prompt=prompt, 236 | prompt_2=prompt_2, 237 | prompt_embeds=prompt_embeds, 238 | pooled_prompt_embeds=pooled_prompt_embeds, 239 | device=device, 240 | num_images_per_prompt=num_images_per_prompt, 241 | max_sequence_length=max_sequence_length, 242 | lora_scale=lora_scale, 243 | ) 244 | if do_true_cfg: 245 | ( 246 | negative_prompt_embeds, 247 | negative_pooled_prompt_embeds, 248 | _, 249 | ) = self.encode_prompt( 250 | prompt=negative_prompt, 251 | prompt_2=negative_prompt_2, 252 | prompt_embeds=negative_prompt_embeds, 253 | pooled_prompt_embeds=negative_pooled_prompt_embeds, 254 | device=device, 255 | num_images_per_prompt=num_images_per_prompt, 256 | max_sequence_length=max_sequence_length, 257 | lora_scale=lora_scale, 258 | ) 259 | 260 | if self.do_normalized_attention_guidance: 261 | if nag_negative_prompt_embeds is None or nag_negative_pooled_prompt_embeds is None: 262 | if nag_negative_prompt is None: 263 | if negative_prompt is not None: 264 | if do_true_cfg: 265 | nag_negative_prompt_embeds = negative_prompt_embeds 266 | nag_negative_pooled_prompt_embeds = negative_pooled_prompt_embeds 267 | else: 268 | nag_negative_prompt = negative_prompt 269 | nag_negative_prompt_2 = negative_prompt_2 270 | else: 271 | nag_negative_prompt = "" 272 | 273 | if nag_negative_prompt is not None: 274 | nag_negative_prompt_embeds, nag_negative_pooled_prompt_embeds = self.encode_prompt( 275 | prompt=nag_negative_prompt, 276 | prompt_2=nag_negative_prompt_2, 277 | device=device, 278 | num_images_per_prompt=num_images_per_prompt, 279 | max_sequence_length=max_sequence_length, 280 | lora_scale=lora_scale, 281 | )[:2] 282 | 283 | if self.do_normalized_attention_guidance: 284 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, nag_negative_pooled_prompt_embeds], dim=0) 285 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0) 286 | 287 | # 4. Prepare latent variables 288 | num_channels_latents = self.transformer.config.in_channels // 4 289 | latents, latent_image_ids = self.prepare_latents( 290 | batch_size * num_images_per_prompt, 291 | num_channels_latents, 292 | height, 293 | width, 294 | prompt_embeds.dtype, 295 | device, 296 | generator, 297 | latents, 298 | ) 299 | 300 | # 5. 
Prepare timesteps 301 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas 302 | image_seq_len = latents.shape[1] 303 | mu = calculate_shift( 304 | image_seq_len, 305 | self.scheduler.config.get("base_image_seq_len", 256), 306 | self.scheduler.config.get("max_image_seq_len", 4096), 307 | self.scheduler.config.get("base_shift", 0.5), 308 | self.scheduler.config.get("max_shift", 1.16), 309 | ) 310 | timesteps, num_inference_steps = retrieve_timesteps( 311 | self.scheduler, 312 | num_inference_steps, 313 | device, 314 | sigmas=sigmas, 315 | mu=mu, 316 | ) 317 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 318 | self._num_timesteps = len(timesteps) 319 | 320 | # handle guidance 321 | if self.transformer.config.guidance_embeds: 322 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 323 | guidance = guidance.expand(prompt_embeds.shape[0]) 324 | else: 325 | guidance = None 326 | 327 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and ( 328 | negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None 329 | ): 330 | negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) 331 | elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and ( 332 | negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None 333 | ): 334 | ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8) 335 | 336 | if self.joint_attention_kwargs is None: 337 | self._joint_attention_kwargs = {} 338 | 339 | image_embeds = None 340 | negative_image_embeds = None 341 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 342 | image_embeds = self.prepare_ip_adapter_image_embeds( 343 | ip_adapter_image, 344 | ip_adapter_image_embeds, 345 | device, 346 | batch_size * num_images_per_prompt, 347 | ) 348 | if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None: 349 | negative_image_embeds = self.prepare_ip_adapter_image_embeds( 350 | negative_ip_adapter_image, 351 | negative_ip_adapter_image_embeds, 352 | device, 353 | batch_size * num_images_per_prompt, 354 | ) 355 | 356 | origin_attn_procs = self.transformer.attn_processors 357 | if self.do_normalized_attention_guidance: 358 | self._set_nag_attn_processor(nag_scale, prompt_embeds.shape[1], nag_tau, nag_alpha) 359 | attn_procs_recovered = False 360 | 361 | for sub_mod in self.transformer.modules(): 362 | if not hasattr(sub_mod, "forward_old") : 363 | sub_mod.forward_old = sub_mod.forward 364 | if isinstance(sub_mod, AdaLayerNormZero): 365 | sub_mod.forward = types.MethodType(TruncAdaLayerNormZero.forward, sub_mod) 366 | elif isinstance(sub_mod, AdaLayerNormContinuous): 367 | sub_mod.forward = types.MethodType(TruncAdaLayerNormContinuous.forward, sub_mod) 368 | 369 | # 6. 
Denoising loop 370 | with self.progress_bar(total=num_inference_steps) as progress_bar: 371 | for i, t in enumerate(timesteps): 372 | if self.interrupt: 373 | continue 374 | 375 | if t < (1 - nag_end) * 1000 and self.do_normalized_attention_guidance and not attn_procs_recovered: 376 | self.transformer.set_attn_processor(origin_attn_procs) 377 | if guidance is not None: 378 | guidance = guidance[:len(latents)] 379 | pooled_prompt_embeds = pooled_prompt_embeds[:len(latents)] 380 | prompt_embeds = prompt_embeds[:len(latents)] 381 | attn_procs_recovered = True 382 | 383 | self._current_timestep = t 384 | if image_embeds is not None: 385 | self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds 386 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 387 | timestep = t.expand(prompt_embeds.shape[0]).to(latents.dtype) 388 | 389 | noise_pred = self.transformer( 390 | hidden_states=latents, 391 | timestep=timestep / 1000, 392 | guidance=guidance, 393 | pooled_projections=pooled_prompt_embeds, 394 | encoder_hidden_states=prompt_embeds, 395 | txt_ids=text_ids, 396 | img_ids=latent_image_ids, 397 | joint_attention_kwargs=self.joint_attention_kwargs, 398 | return_dict=False, 399 | )[0] 400 | 401 | if do_true_cfg: 402 | if negative_image_embeds is not None: 403 | self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds 404 | neg_noise_pred = self.transformer( 405 | hidden_states=latents, 406 | timestep=timestep / 1000, 407 | guidance=guidance, 408 | pooled_projections=negative_pooled_prompt_embeds, 409 | encoder_hidden_states=negative_prompt_embeds, 410 | txt_ids=text_ids, 411 | img_ids=latent_image_ids, 412 | joint_attention_kwargs=self.joint_attention_kwargs, 413 | return_dict=False, 414 | )[0] 415 | noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) 416 | 417 | # compute the previous noisy sample x_t -> x_t-1 418 | latents_dtype = latents.dtype 419 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 420 | 421 | if latents.dtype != latents_dtype: 422 | if torch.backends.mps.is_available(): 423 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 424 | latents = latents.to(latents_dtype) 425 | 426 | if callback_on_step_end is not None: 427 | callback_kwargs = {} 428 | for k in callback_on_step_end_tensor_inputs: 429 | callback_kwargs[k] = locals()[k] 430 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 431 | 432 | latents = callback_outputs.pop("latents", latents) 433 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 434 | 435 | # call the callback, if provided 436 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 437 | progress_bar.update() 438 | 439 | if XLA_AVAILABLE: 440 | xm.mark_step() 441 | 442 | self._current_timestep = None 443 | 444 | if output_type == "latent": 445 | image = latents 446 | else: 447 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 448 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 449 | image = self.vae.decode(latents, return_dict=False)[0] 450 | image = self.image_processor.postprocess(image, output_type=output_type) 451 | 452 | if self.do_normalized_attention_guidance and not attn_procs_recovered: 453 | self.transformer.set_attn_processor(origin_attn_procs) 454 | 455 | # Offload all models 456 | self.maybe_free_model_hooks() 457 | 458 | if not return_dict: 459 | return (image,) 460 | 461 | return FluxPipelineOutput(images=image) -------------------------------------------------------------------------------- /nag/pipeline_sd3_nag.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | import types 3 | 4 | import torch 5 | 6 | from diffusers.image_processor import PipelineImageInput 7 | from diffusers.utils import ( 8 | is_torch_xla_available, 9 | logging, 10 | ) 11 | from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput 12 | from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import ( 13 | StableDiffusion3Pipeline, 14 | retrieve_timesteps, 15 | calculate_shift, 16 | ) 17 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX 18 | 19 | from nag.attention_joint_nag import NAGJointAttnProcessor2_0 20 | from nag.normalization import TruncAdaLayerNorm, TruncAdaLayerNormZero, TruncAdaLayerNormContinuous, TruncSD35AdaLayerNormZeroX 21 | 22 | 23 | if is_torch_xla_available(): 24 | import torch_xla.core.xla_model as xm 25 | 26 | XLA_AVAILABLE = True 27 | else: 28 | XLA_AVAILABLE = False 29 | 30 | 31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 32 | 33 | 34 | class NAGStableDiffusion3Pipeline(StableDiffusion3Pipeline): 35 | @property 36 | def do_normalized_attention_guidance(self): 37 | return self._nag_scale > 1 38 | 39 | def _set_nag_attn_processor(self, nag_scale, nag_tau, nag_alpha): 40 | attn_procs = {} 41 | for name in self.transformer.attn_processors.keys(): 42 | attn_procs[name] = NAGJointAttnProcessor2_0(nag_scale=nag_scale, nag_tau=nag_tau, nag_alpha=nag_alpha) 43 | self.transformer.set_attn_processor(attn_procs) 44 | 45 | @torch.no_grad() 46 | def __call__( 47 | self, 48 | prompt: Union[str, List[str]] = None, 49 | prompt_2: Optional[Union[str, List[str]]] = None, 50 | prompt_3: Optional[Union[str, List[str]]] = None, 51 | height: Optional[int] = None, 52 | width: Optional[int] = None, 53 | 
num_inference_steps: int = 28, 54 | sigmas: Optional[List[float]] = None, 55 | guidance_scale: float = 7.0, 56 | negative_prompt: Optional[Union[str, List[str]]] = None, 57 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 58 | negative_prompt_3: Optional[Union[str, List[str]]] = None, 59 | num_images_per_prompt: Optional[int] = 1, 60 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 61 | latents: Optional[torch.FloatTensor] = None, 62 | prompt_embeds: Optional[torch.FloatTensor] = None, 63 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 64 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 65 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 66 | ip_adapter_image: Optional[PipelineImageInput] = None, 67 | ip_adapter_image_embeds: Optional[torch.Tensor] = None, 68 | output_type: Optional[str] = "pil", 69 | return_dict: bool = True, 70 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 71 | clip_skip: Optional[int] = None, 72 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 73 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 74 | max_sequence_length: int = 256, 75 | skip_guidance_layers: List[int] = None, 76 | skip_layer_guidance_scale: float = 2.8, 77 | skip_layer_guidance_stop: float = 0.2, 78 | skip_layer_guidance_start: float = 0.01, 79 | mu: Optional[float] = None, 80 | 81 | nag_scale: float = 1.0, 82 | nag_tau: float = 2.5, 83 | nag_alpha: float = 0.125, 84 | nag_negative_prompt: str = None, 85 | nag_negative_prompt_2: str = None, 86 | nag_negative_prompt_3: str = None, 87 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None, 88 | nag_negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, 89 | ): 90 | r""" 91 | Function invoked when calling the pipeline for generation. 92 | 93 | Args: 94 | prompt (`str` or `List[str]`, *optional*): 95 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 96 | instead. 97 | prompt_2 (`str` or `List[str]`, *optional*): 98 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 99 | will be used instead 100 | prompt_3 (`str` or `List[str]`, *optional*): 101 | The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt` is 102 | will be used instead 103 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 104 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 105 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 106 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 107 | num_inference_steps (`int`, *optional*, defaults to 50): 108 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 109 | expense of slower inference. 110 | sigmas (`List[float]`, *optional*): 111 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 112 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 113 | will be used. 114 | guidance_scale (`float`, *optional*, defaults to 7.0): 115 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 116 | `guidance_scale` is defined as `w` of equation 2. 
of [Imagen 117 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 118 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 119 | usually at the expense of lower image quality. 120 | negative_prompt (`str` or `List[str]`, *optional*): 121 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 122 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 123 | less than `1`). 124 | negative_prompt_2 (`str` or `List[str]`, *optional*): 125 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 126 | `text_encoder_2`. If not defined, `negative_prompt` is used instead 127 | negative_prompt_3 (`str` or `List[str]`, *optional*): 128 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and 129 | `text_encoder_3`. If not defined, `negative_prompt` is used instead 130 | num_images_per_prompt (`int`, *optional*, defaults to 1): 131 | The number of images to generate per prompt. 132 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 133 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 134 | to make generation deterministic. 135 | latents (`torch.FloatTensor`, *optional*): 136 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 137 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 138 | tensor will ge generated by sampling using the supplied random `generator`. 139 | prompt_embeds (`torch.FloatTensor`, *optional*): 140 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 141 | provided, text embeddings will be generated from `prompt` input argument. 142 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 143 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 144 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 145 | argument. 146 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 147 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 148 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 149 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 150 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 151 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 152 | input argument. 153 | ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 154 | ip_adapter_image_embeds (`torch.Tensor`, *optional*): 155 | Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images, 156 | emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to 157 | `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument. 158 | output_type (`str`, *optional*, defaults to `"pil"`): 159 | The output format of the generate image. Choose between 160 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 
161 | return_dict (`bool`, *optional*, defaults to `True`): 162 | Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of 163 | a plain tuple. 164 | joint_attention_kwargs (`dict`, *optional*): 165 | A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under 166 | `self.processor` in 167 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 168 | callback_on_step_end (`Callable`, *optional*): 169 | A function that is called at the end of each denoising step during inference. The function is called 170 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 171 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 172 | `callback_on_step_end_tensor_inputs`. 173 | callback_on_step_end_tensor_inputs (`List`, *optional*): 174 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 175 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 176 | `._callback_tensor_inputs` attribute of your pipeline class. 177 | max_sequence_length (`int`, defaults to 256): Maximum sequence length to use with the `prompt`. 178 | skip_guidance_layers (`List[int]`, *optional*): 179 | A list of integers that specify layers to skip during guidance. If not provided, all layers will be 180 | used for guidance. If provided, the guidance will only be applied to the layers specified in the list. 181 | Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9]. 182 | skip_layer_guidance_scale (`float`, *optional*): The scale of the guidance for the layers specified in 183 | `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers` 184 | with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers 185 | with a scale of `1`. 186 | skip_layer_guidance_stop (`float`, *optional*): The step at which the guidance for the layers specified in 187 | `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in 188 | `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by 189 | StabilityAI for Stable Diffusion 3.5 Medium is 0.2. 190 | skip_layer_guidance_start (`float`, *optional*): The step at which the guidance for the layers specified in 191 | `skip_guidance_layers` will start. The guidance will be applied to the layers specified in 192 | `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by 193 | StabilityAI for Stable Diffusion 3.5 Medium is 0.01. 194 | mu (`float`, *optional*): `mu` value used for `dynamic_shifting`. 195 | 196 | Examples: 197 | 198 | Returns: 199 | [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`: 200 | [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a 201 | `tuple`. When returning a tuple, the first element is a list with the generated images. 202 | """ 203 | 204 | height = height or self.default_sample_size * self.vae_scale_factor 205 | width = width or self.default_sample_size * self.vae_scale_factor 206 | 207 | # 1. Check inputs.
Raise error if not correct 208 | self.check_inputs( 209 | prompt, 210 | prompt_2, 211 | prompt_3, 212 | height, 213 | width, 214 | negative_prompt=negative_prompt, 215 | negative_prompt_2=negative_prompt_2, 216 | negative_prompt_3=negative_prompt_3, 217 | prompt_embeds=prompt_embeds, 218 | negative_prompt_embeds=negative_prompt_embeds, 219 | pooled_prompt_embeds=pooled_prompt_embeds, 220 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 221 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 222 | max_sequence_length=max_sequence_length, 223 | ) 224 | 225 | self._guidance_scale = guidance_scale 226 | self._skip_layer_guidance_scale = skip_layer_guidance_scale 227 | self._clip_skip = clip_skip 228 | self._joint_attention_kwargs = joint_attention_kwargs 229 | self._interrupt = False 230 | self._nag_scale = nag_scale 231 | 232 | # 2. Define call parameters 233 | if prompt is not None and isinstance(prompt, str): 234 | batch_size = 1 235 | elif prompt is not None and isinstance(prompt, list): 236 | batch_size = len(prompt) 237 | else: 238 | batch_size = prompt_embeds.shape[0] 239 | 240 | device = self._execution_device 241 | 242 | lora_scale = ( 243 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 244 | ) 245 | ( 246 | prompt_embeds, 247 | negative_prompt_embeds, 248 | pooled_prompt_embeds, 249 | negative_pooled_prompt_embeds, 250 | ) = self.encode_prompt( 251 | prompt=prompt, 252 | prompt_2=prompt_2, 253 | prompt_3=prompt_3, 254 | negative_prompt=negative_prompt, 255 | negative_prompt_2=negative_prompt_2, 256 | negative_prompt_3=negative_prompt_3, 257 | do_classifier_free_guidance=self.do_classifier_free_guidance, 258 | prompt_embeds=prompt_embeds, 259 | negative_prompt_embeds=negative_prompt_embeds, 260 | pooled_prompt_embeds=pooled_prompt_embeds, 261 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 262 | device=device, 263 | clip_skip=self.clip_skip, 264 | num_images_per_prompt=num_images_per_prompt, 265 | max_sequence_length=max_sequence_length, 266 | lora_scale=lora_scale, 267 | ) 268 | if self.do_normalized_attention_guidance: 269 | if nag_negative_prompt_embeds is None or nag_negative_pooled_prompt_embeds is None: 270 | if nag_negative_prompt is None: 271 | if negative_prompt is not None: 272 | if self.do_classifier_free_guidance: 273 | nag_negative_prompt_embeds = negative_prompt_embeds 274 | nag_negative_pooled_prompt_embeds = negative_pooled_prompt_embeds 275 | else: 276 | nag_negative_prompt = negative_prompt 277 | nag_negative_prompt_2 = negative_prompt_2 278 | nag_negative_prompt_3 = negative_prompt_3 279 | else: 280 | nag_negative_prompt = "" 281 | 282 | if nag_negative_prompt is not None: 283 | nag_negative_prompt_embeds, _, nag_negative_pooled_prompt_embeds, _ = self.encode_prompt( 284 | prompt=nag_negative_prompt, 285 | prompt_2=nag_negative_prompt_2, 286 | prompt_3=nag_negative_prompt_3, 287 | do_classifier_free_guidance=False, 288 | device=device, 289 | clip_skip=self.clip_skip, 290 | num_images_per_prompt=num_images_per_prompt, 291 | max_sequence_length=max_sequence_length, 292 | lora_scale=lora_scale, 293 | ) 294 | 295 | if self.do_classifier_free_guidance: 296 | original_prompt_embeds = prompt_embeds 297 | original_pooled_prompt_embeds = pooled_prompt_embeds 298 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 299 | pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) 300 | 301 | if 
self.do_normalized_attention_guidance: 302 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0) 303 | if self.do_classifier_free_guidance: 304 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, original_pooled_prompt_embeds], dim=0) 305 | else: 306 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds], dim=0) 307 | 308 | # 4. Prepare latent variables 309 | num_channels_latents = self.transformer.config.in_channels 310 | latents = self.prepare_latents( 311 | batch_size * num_images_per_prompt, 312 | num_channels_latents, 313 | height, 314 | width, 315 | prompt_embeds.dtype, 316 | device, 317 | generator, 318 | latents, 319 | ) 320 | 321 | # 5. Prepare timesteps 322 | scheduler_kwargs = {} 323 | if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None: 324 | _, _, height, width = latents.shape 325 | image_seq_len = (height // self.transformer.config.patch_size) * ( 326 | width // self.transformer.config.patch_size 327 | ) 328 | mu = calculate_shift( 329 | image_seq_len, 330 | self.scheduler.config.base_image_seq_len, 331 | self.scheduler.config.max_image_seq_len, 332 | self.scheduler.config.base_shift, 333 | self.scheduler.config.max_shift, 334 | ) 335 | scheduler_kwargs["mu"] = mu 336 | elif mu is not None: 337 | scheduler_kwargs["mu"] = mu 338 | timesteps, num_inference_steps = retrieve_timesteps( 339 | self.scheduler, 340 | num_inference_steps, 341 | device, 342 | sigmas=sigmas, 343 | **scheduler_kwargs, 344 | ) 345 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 346 | self._num_timesteps = len(timesteps) 347 | 348 | # 6. Prepare image embeddings 349 | if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None: 350 | ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds( 351 | ip_adapter_image, 352 | ip_adapter_image_embeds, 353 | device, 354 | batch_size * num_images_per_prompt, 355 | self.do_classifier_free_guidance, 356 | ) 357 | 358 | if self.joint_attention_kwargs is None: 359 | self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds} 360 | else: 361 | self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds) 362 | 363 | origin_attn_procs = self.transformer.attn_processors 364 | if self.do_normalized_attention_guidance: 365 | self._set_nag_attn_processor(nag_scale, nag_tau, nag_alpha) 366 | 367 | for sub_mod in self.transformer.modules(): 368 | if not hasattr(sub_mod, "forward_old") : 369 | sub_mod.forward_old = sub_mod.forward 370 | if isinstance(sub_mod, AdaLayerNorm): 371 | sub_mod.forward = types.MethodType(TruncAdaLayerNorm.forward, sub_mod) 372 | elif isinstance(sub_mod, AdaLayerNormContinuous): 373 | sub_mod.forward = types.MethodType(TruncAdaLayerNormContinuous.forward, sub_mod) 374 | elif isinstance(sub_mod, AdaLayerNormZero): 375 | sub_mod.forward = types.MethodType(TruncAdaLayerNormZero.forward, sub_mod) 376 | elif isinstance(sub_mod, SD35AdaLayerNormZeroX): 377 | sub_mod.forward = types.MethodType(TruncSD35AdaLayerNormZeroX.forward, sub_mod) 378 | 379 | # 7. 
Denoising loop 380 | with self.progress_bar(total=num_inference_steps) as progress_bar: 381 | for i, t in enumerate(timesteps): 382 | if self.interrupt: 383 | continue 384 | 385 | # expand the latents if we are doing classifier free guidance 386 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 387 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 388 | timestep = t.expand(prompt_embeds.shape[0]) 389 | 390 | noise_pred = self.transformer( 391 | hidden_states=latent_model_input, 392 | timestep=timestep, 393 | encoder_hidden_states=prompt_embeds, 394 | pooled_projections=pooled_prompt_embeds, 395 | joint_attention_kwargs=self.joint_attention_kwargs, 396 | return_dict=False, 397 | )[0] 398 | 399 | # perform guidance 400 | if self.do_classifier_free_guidance: 401 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 402 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 403 | should_skip_layers = ( 404 | True 405 | if i > num_inference_steps * skip_layer_guidance_start 406 | and i < num_inference_steps * skip_layer_guidance_stop 407 | else False 408 | ) 409 | if skip_guidance_layers is not None and should_skip_layers: 410 | timestep = t.expand(latents.shape[0]) 411 | latent_model_input = latents 412 | noise_pred_skip_layers = self.transformer( 413 | hidden_states=latent_model_input, 414 | timestep=timestep, 415 | encoder_hidden_states=original_prompt_embeds, 416 | pooled_projections=original_pooled_prompt_embeds, 417 | joint_attention_kwargs=self.joint_attention_kwargs, 418 | return_dict=False, 419 | skip_layers=skip_guidance_layers, 420 | )[0] 421 | noise_pred = ( 422 | noise_pred + ( 423 | noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale 424 | ) 425 | 426 | # compute the previous noisy sample x_t -> x_t-1 427 | latents_dtype = latents.dtype 428 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 429 | 430 | if latents.dtype != latents_dtype: 431 | if torch.backends.mps.is_available(): 432 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 433 | latents = latents.to(latents_dtype) 434 | 435 | if callback_on_step_end is not None: 436 | callback_kwargs = {} 437 | for k in callback_on_step_end_tensor_inputs: 438 | callback_kwargs[k] = locals()[k] 439 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 440 | 441 | latents = callback_outputs.pop("latents", latents) 442 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 443 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 444 | negative_pooled_prompt_embeds = callback_outputs.pop( 445 | "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds 446 | ) 447 | 448 | # call the callback, if provided 449 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 450 | progress_bar.update() 451 | 452 | if XLA_AVAILABLE: 453 | xm.mark_step() 454 | 455 | if output_type == "latent": 456 | image = latents 457 | 458 | else: 459 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 460 | 461 | image = self.vae.decode(latents, return_dict=False)[0] 462 | image = self.image_processor.postprocess(image, output_type=output_type) 463 | 464 | if self.do_normalized_attention_guidance: 465 | self.transformer.set_attn_processor(origin_attn_procs) 466 | 467 | # Offload all models 468 | self.maybe_free_model_hooks() 469 | 470 | if not return_dict: 471 | return (image,) 472 | 473 | return StableDiffusion3PipelineOutput(images=image) 474 | -------------------------------------------------------------------------------- /nag/pipeline_sdxl_nag.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 2 | import math 3 | 4 | import torch 5 | 6 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 7 | from diffusers.image_processor import PipelineImageInput 8 | from diffusers.utils import ( 9 | deprecate, 10 | is_torch_xla_available, 11 | ) 12 | from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput 13 | 14 | if is_torch_xla_available(): 15 | import torch_xla.core.xla_model as xm 16 | 17 | XLA_AVAILABLE = True 18 | else: 19 | XLA_AVAILABLE = False 20 | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import ( 21 | StableDiffusionXLPipeline, 22 | retrieve_timesteps, 23 | rescale_noise_cfg, 24 | ) 25 | 26 | from nag.attention_nag import NAGAttnProcessor2_0 27 | 28 | 29 | class NAGStableDiffusionXLPipeline(StableDiffusionXLPipeline): 30 | @property 31 | def do_normalized_attention_guidance(self): 32 | return self._nag_scale > 1 33 | 34 | def _set_nag_attn_processor(self, nag_scale, nag_tau=2.5, nag_alpha=0.5): 35 | if self.do_normalized_attention_guidance: 36 | attn_procs = {} 37 | for name, origin_attn_processor in self.unet.attn_processors.items(): 38 | if "attn2" in name: 39 | attn_procs[name] = NAGAttnProcessor2_0(nag_scale=nag_scale, nag_tau=nag_tau, nag_alpha=nag_alpha) 40 | else: 41 | attn_procs[name] = origin_attn_processor 42 | self.unet.set_attn_processor(attn_procs) 43 | 44 | @torch.no_grad() 45 | def __call__( 46 | self, 47 | prompt: Union[str, List[str]] = None, 48 | prompt_2: Optional[Union[str, List[str]]] = None, 49 | height: Optional[int] = None, 50 | width: Optional[int] = None, 51 | num_inference_steps: int = 50, 52 | timesteps: List[int] = None, 53 | 
sigmas: List[float] = None, 54 | denoising_end: Optional[float] = None, 55 | guidance_scale: float = 5.0, 56 | negative_prompt: Optional[Union[str, List[str]]] = None, 57 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 58 | num_images_per_prompt: Optional[int] = 1, 59 | eta: float = 0.0, 60 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 61 | latents: Optional[torch.Tensor] = None, 62 | prompt_embeds: Optional[torch.Tensor] = None, 63 | negative_prompt_embeds: Optional[torch.Tensor] = None, 64 | pooled_prompt_embeds: Optional[torch.Tensor] = None, 65 | negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, 66 | ip_adapter_image: Optional[PipelineImageInput] = None, 67 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, 68 | output_type: Optional[str] = "pil", 69 | return_dict: bool = True, 70 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 71 | guidance_rescale: float = 0.0, 72 | original_size: Optional[Tuple[int, int]] = None, 73 | crops_coords_top_left: Tuple[int, int] = (0, 0), 74 | target_size: Optional[Tuple[int, int]] = None, 75 | negative_original_size: Optional[Tuple[int, int]] = None, 76 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0), 77 | negative_target_size: Optional[Tuple[int, int]] = None, 78 | clip_skip: Optional[int] = None, 79 | callback_on_step_end: Optional[ 80 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 81 | ] = None, 82 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 83 | 84 | nag_scale: float = 1.0, 85 | nag_tau: float = 2.5, 86 | nag_alpha: float = 0.5, 87 | nag_negative_prompt: str = None, 88 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None, 89 | nag_end: float = 1.0, 90 | 91 | **kwargs, 92 | ): 93 | r""" 94 | Function invoked when calling the pipeline for generation. 95 | 96 | Args: 97 | prompt (`str` or `List[str]`, *optional*): 98 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 99 | instead. 100 | prompt_2 (`str` or `List[str]`, *optional*): 101 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 102 | used in both text-encoders 103 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 104 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 105 | Anything below 512 pixels won't work well for 106 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 107 | and checkpoints that are not specifically fine-tuned on low resolutions. 108 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 109 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 110 | Anything below 512 pixels won't work well for 111 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 112 | and checkpoints that are not specifically fine-tuned on low resolutions. 113 | num_inference_steps (`int`, *optional*, defaults to 50): 114 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 115 | expense of slower inference. 116 | timesteps (`List[int]`, *optional*): 117 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument 118 | in their `set_timesteps` method. 
If not defined, the default behavior when `num_inference_steps` is 119 | passed will be used. Must be in descending order. 120 | sigmas (`List[float]`, *optional*): 121 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 122 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 123 | will be used. 124 | denoising_end (`float`, *optional*): 125 | When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be 126 | completed before it is intentionally prematurely terminated. As a result, the returned sample will 127 | still retain a substantial amount of noise as determined by the discrete timesteps selected by the 128 | scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 129 | "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image 130 | Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) 131 | guidance_scale (`float`, *optional*, defaults to 5.0): 132 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 133 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 134 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 135 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 136 | usually at the expense of lower image quality. 137 | negative_prompt (`str` or `List[str]`, *optional*): 138 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 139 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 140 | less than `1`). 141 | negative_prompt_2 (`str` or `List[str]`, *optional*): 142 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 143 | `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders 144 | num_images_per_prompt (`int`, *optional*, defaults to 1): 145 | The number of images to generate per prompt. 146 | eta (`float`, *optional*, defaults to 0.0): 147 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 148 | [`schedulers.DDIMScheduler`], will be ignored for others. 149 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 150 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 151 | to make generation deterministic. 152 | latents (`torch.Tensor`, *optional*): 153 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 154 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 155 | tensor will ge generated by sampling using the supplied random `generator`. 156 | prompt_embeds (`torch.Tensor`, *optional*): 157 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 158 | provided, text embeddings will be generated from `prompt` input argument. 159 | negative_prompt_embeds (`torch.Tensor`, *optional*): 160 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 161 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 162 | argument. 
163 | pooled_prompt_embeds (`torch.Tensor`, *optional*): 164 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 165 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 166 | negative_pooled_prompt_embeds (`torch.Tensor`, *optional*): 167 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 168 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 169 | input argument. 170 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. 171 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): 172 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of 173 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should 174 | contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not 175 | provided, embeddings are computed from the `ip_adapter_image` input argument. 176 | output_type (`str`, *optional*, defaults to `"pil"`): 177 | The output format of the generate image. Choose between 178 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 179 | return_dict (`bool`, *optional*, defaults to `True`): 180 | Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead 181 | of a plain tuple. 182 | cross_attention_kwargs (`dict`, *optional*): 183 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 184 | `self.processor` in 185 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 186 | guidance_rescale (`float`, *optional*, defaults to 0.0): 187 | Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are 188 | Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of 189 | [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 190 | Guidance rescale factor should fix overexposure when using zero terminal SNR. 191 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 192 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. 193 | `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as 194 | explained in section 2.2 of 195 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 196 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 197 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position 198 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting 199 | `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of 200 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 201 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 202 | For most cases, `target_size` should be set to the desired height and width of the generated image. If 203 | not specified it will default to `(height, width)`. 
Part of SDXL's micro-conditioning as explained in 204 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 205 | negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 206 | To negatively condition the generation process based on a specific image resolution. Part of SDXL's 207 | micro-conditioning as explained in section 2.2 of 208 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 209 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 210 | negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 211 | To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's 212 | micro-conditioning as explained in section 2.2 of 213 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 214 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 215 | negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 216 | To negatively condition the generation process based on a target image resolution. It should be as same 217 | as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of 218 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 219 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 220 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 221 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 222 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self: 223 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 224 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 225 | callback_on_step_end_tensor_inputs (`List`, *optional*): 226 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 227 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 228 | `._callback_tensor_inputs` attribute of your pipeline class. 229 | 230 | Examples: 231 | 232 | Returns: 233 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: 234 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a 235 | `tuple`. When returning a tuple, the first element is a list with the generated images. 236 | """ 237 | 238 | callback = kwargs.pop("callback", None) 239 | callback_steps = kwargs.pop("callback_steps", None) 240 | 241 | if callback is not None: 242 | deprecate( 243 | "callback", 244 | "1.0.0", 245 | "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 246 | ) 247 | if callback_steps is not None: 248 | deprecate( 249 | "callback_steps", 250 | "1.0.0", 251 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`", 252 | ) 253 | 254 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 255 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 256 | 257 | # 0. 
Default height and width to unet 258 | height = height or self.default_sample_size * self.vae_scale_factor 259 | width = width or self.default_sample_size * self.vae_scale_factor 260 | 261 | original_size = original_size or (height, width) 262 | target_size = target_size or (height, width) 263 | 264 | # 1. Check inputs. Raise error if not correct 265 | self.check_inputs( 266 | prompt, 267 | prompt_2, 268 | height, 269 | width, 270 | callback_steps, 271 | negative_prompt, 272 | negative_prompt_2, 273 | prompt_embeds, 274 | negative_prompt_embeds, 275 | pooled_prompt_embeds, 276 | negative_pooled_prompt_embeds, 277 | ip_adapter_image, 278 | ip_adapter_image_embeds, 279 | callback_on_step_end_tensor_inputs, 280 | ) 281 | 282 | self._guidance_scale = guidance_scale 283 | self._guidance_rescale = guidance_rescale 284 | self._clip_skip = clip_skip 285 | self._cross_attention_kwargs = cross_attention_kwargs 286 | self._denoising_end = denoising_end 287 | self._interrupt = False 288 | self._nag_scale = nag_scale 289 | 290 | # 2. Define call parameters 291 | if prompt is not None and isinstance(prompt, str): 292 | batch_size = 1 293 | elif prompt is not None and isinstance(prompt, list): 294 | batch_size = len(prompt) 295 | else: 296 | batch_size = prompt_embeds.shape[0] 297 | 298 | device = self._execution_device 299 | 300 | # 3. Encode input prompt 301 | lora_scale = ( 302 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 303 | ) 304 | 305 | ( 306 | prompt_embeds, 307 | negative_prompt_embeds, 308 | pooled_prompt_embeds, 309 | negative_pooled_prompt_embeds, 310 | ) = self.encode_prompt( 311 | prompt=prompt, 312 | prompt_2=prompt_2, 313 | device=device, 314 | num_images_per_prompt=num_images_per_prompt, 315 | do_classifier_free_guidance=self.do_classifier_free_guidance or self.do_normalized_attention_guidance, 316 | negative_prompt=negative_prompt, 317 | negative_prompt_2=negative_prompt_2, 318 | prompt_embeds=prompt_embeds, 319 | negative_prompt_embeds=negative_prompt_embeds, 320 | pooled_prompt_embeds=pooled_prompt_embeds, 321 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 322 | lora_scale=lora_scale, 323 | clip_skip=self.clip_skip, 324 | ) 325 | if self.do_normalized_attention_guidance: 326 | if nag_negative_prompt_embeds is None: 327 | if nag_negative_prompt is None: 328 | if negative_prompt is not None: 329 | if self.do_classifier_free_guidance: 330 | nag_negative_prompt_embeds = negative_prompt_embeds 331 | else: 332 | nag_negative_prompt = negative_prompt  # reuse the CFG negative prompt for NAG when CFG is disabled 333 | else: 334 | nag_negative_prompt = "" 335 | 336 | if nag_negative_prompt is not None: 337 | nag_negative_prompt_embeds = self.encode_prompt( 338 | prompt=nag_negative_prompt, 339 | device=device, 340 | num_images_per_prompt=num_images_per_prompt, 341 | do_classifier_free_guidance=False, 342 | lora_scale=lora_scale, 343 | clip_skip=self.clip_skip, 344 | )[0] 345 | 346 | # 4. Prepare timesteps 347 | timesteps, num_inference_steps = retrieve_timesteps( 348 | self.scheduler, num_inference_steps, device, timesteps, sigmas 349 | ) 350 | 351 | # 5. Prepare latent variables 352 | num_channels_latents = self.unet.config.in_channels 353 | latents = self.prepare_latents( 354 | batch_size * num_images_per_prompt, 355 | num_channels_latents, 356 | height, 357 | width, 358 | prompt_embeds.dtype, 359 | device, 360 | generator, 361 | latents, 362 | ) 363 | 364 | # 6.
TODO: Logic should ideally just be moved out of the pipeline 365 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 366 | 367 | # 7. Prepare added time ids & embeddings 368 | add_text_embeds = pooled_prompt_embeds 369 | if self.text_encoder_2 is None: 370 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) 371 | else: 372 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim 373 | 374 | add_time_ids = self._get_add_time_ids( 375 | original_size, 376 | crops_coords_top_left, 377 | target_size, 378 | dtype=prompt_embeds.dtype, 379 | text_encoder_projection_dim=text_encoder_projection_dim, 380 | ) 381 | if negative_original_size is not None and negative_target_size is not None: 382 | negative_add_time_ids = self._get_add_time_ids( 383 | negative_original_size, 384 | negative_crops_coords_top_left, 385 | negative_target_size, 386 | dtype=prompt_embeds.dtype, 387 | text_encoder_projection_dim=text_encoder_projection_dim, 388 | ) 389 | else: 390 | negative_add_time_ids = add_time_ids 391 | 392 | if self.do_classifier_free_guidance: 393 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 394 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 395 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) 396 | 397 | if self.do_normalized_attention_guidance: 398 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0) 399 | 400 | prompt_embeds = prompt_embeds.to(device) 401 | add_text_embeds = add_text_embeds.to(device) 402 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) 403 | 404 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 405 | image_embeds = self.prepare_ip_adapter_image_embeds( 406 | ip_adapter_image, 407 | ip_adapter_image_embeds, 408 | device, 409 | batch_size * num_images_per_prompt, 410 | self.do_classifier_free_guidance, 411 | ) 412 | 413 | # 8. Denoising loop 414 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 415 | 416 | # 8.1 Apply denoising_end 417 | if ( 418 | self.denoising_end is not None 419 | and isinstance(self.denoising_end, float) 420 | and self.denoising_end > 0 421 | and self.denoising_end < 1 422 | ): 423 | discrete_timestep_cutoff = int( 424 | round( 425 | self.scheduler.config.num_train_timesteps 426 | - (self.denoising_end * self.scheduler.config.num_train_timesteps) 427 | ) 428 | ) 429 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) 430 | timesteps = timesteps[:num_inference_steps] 431 | 432 | # 9. 
Optionally get Guidance Scale Embedding 433 | timestep_cond = None 434 | if self.unet.config.time_cond_proj_dim is not None: 435 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 436 | timestep_cond = self.get_guidance_scale_embedding( 437 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 438 | ).to(device=device, dtype=latents.dtype) 439 | 440 | if self.do_normalized_attention_guidance: 441 | origin_attn_procs = self.unet.attn_processors 442 | self._set_nag_attn_processor(nag_scale, nag_tau, nag_alpha) 443 | attn_procs_recovered = False 444 | 445 | self._num_timesteps = len(timesteps) 446 | with self.progress_bar(total=num_inference_steps) as progress_bar: 447 | for i, t in enumerate(timesteps): 448 | if self.interrupt: 449 | continue 450 | 451 | # expand the latents if we are doing classifier free guidance 452 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 453 | 454 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 455 | 456 | # predict the noise residual 457 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 458 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None: 459 | added_cond_kwargs["image_embeds"] = image_embeds 460 | 461 | if t < math.floor((1 - nag_end) * 999) and self.do_normalized_attention_guidance and not attn_procs_recovered: 462 | self.unet.set_attn_processor(origin_attn_procs) 463 | prompt_embeds = prompt_embeds[:len(latent_model_input)] 464 | attn_procs_recovered = True 465 | 466 | noise_pred = self.unet( 467 | latent_model_input, 468 | t, 469 | encoder_hidden_states=prompt_embeds, 470 | timestep_cond=timestep_cond, 471 | cross_attention_kwargs=self.cross_attention_kwargs, 472 | added_cond_kwargs=added_cond_kwargs, 473 | return_dict=False, 474 | )[0] 475 | 476 | # perform guidance 477 | if self.do_classifier_free_guidance: 478 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 479 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) 480 | 481 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: 482 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 483 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) 484 | 485 | # compute the previous noisy sample x_t -> x_t-1 486 | latents_dtype = latents.dtype 487 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 488 | if latents.dtype != latents_dtype: 489 | if torch.backends.mps.is_available(): 490 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 491 | latents = latents.to(latents_dtype) 492 | 493 | if callback_on_step_end is not None: 494 | callback_kwargs = {} 495 | for k in callback_on_step_end_tensor_inputs: 496 | callback_kwargs[k] = locals()[k] 497 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 498 | 499 | latents = callback_outputs.pop("latents", latents) 500 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 501 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 502 | add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) 503 | negative_pooled_prompt_embeds = callback_outputs.pop( 504 | "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds 505 | ) 506 | add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids) 507 | negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids) 508 | 509 | # call the callback, if provided 510 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 511 | progress_bar.update() 512 | if callback is not None and i % callback_steps == 0: 513 | step_idx = i // getattr(self.scheduler, "order", 1) 514 | callback(step_idx, t, latents) 515 | 516 | if XLA_AVAILABLE: 517 | xm.mark_step() 518 | 519 | if not output_type == "latent": 520 | # make sure the VAE is in float32 mode, as it overflows in float16 521 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 522 | 523 | if needs_upcasting: 524 | self.upcast_vae() 525 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 526 | elif latents.dtype != self.vae.dtype: 527 | if torch.backends.mps.is_available(): 528 | # some platforms (eg. 
apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 529 | self.vae = self.vae.to(latents.dtype) 530 | 531 | # unscale/denormalize the latents 532 | # denormalize with the mean and std if available and not None 533 | has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None 534 | has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None 535 | if has_latents_mean and has_latents_std: 536 | latents_mean = ( 537 | torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype) 538 | ) 539 | latents_std = ( 540 | torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype) 541 | ) 542 | latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean 543 | else: 544 | latents = latents / self.vae.config.scaling_factor 545 | 546 | image = self.vae.decode(latents, return_dict=False)[0] 547 | 548 | # cast back to fp16 if needed 549 | if needs_upcasting: 550 | self.vae.to(dtype=torch.float16) 551 | else: 552 | image = latents 553 | 554 | if not output_type == "latent": 555 | # apply watermark if available 556 | if self.watermark is not None: 557 | image = self.watermark.apply_watermark(image) 558 | 559 | image = self.image_processor.postprocess(image, output_type=output_type) 560 | 561 | if self.do_normalized_attention_guidance and not attn_procs_recovered: 562 | self.unet.set_attn_processor(origin_attn_procs) 563 | 564 | # Offload all models 565 | self.maybe_free_model_hooks() 566 | 567 | if not return_dict: 568 | return (image,) 569 | 570 | return StableDiffusionXLPipelineOutput(images=image) 571 | -------------------------------------------------------------------------------- /nag/pipeline_wan_nag.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | 3 | import torch 4 | 5 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback 6 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring 7 | from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput 8 | from diffusers.pipelines.wan.pipeline_wan import WanPipeline 9 | 10 | from nag.attention_wan_nag import NAGWanAttnProcessor2_0 11 | 12 | if is_torch_xla_available(): 13 | import torch_xla.core.xla_model as xm 14 | 15 | XLA_AVAILABLE = True 16 | else: 17 | XLA_AVAILABLE = False 18 | 19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 20 | 21 | 22 | class NAGWanPipeline(WanPipeline): 23 | @property 24 | def do_normalized_attention_guidance(self): 25 | return self._nag_scale > 1 26 | 27 | def _set_nag_attn_processor(self, nag_scale, nag_tau, nag_alpha): 28 | attn_procs = {} 29 | for name, origin_attn_proc in self.transformer.attn_processors.items(): 30 | if "attn2" in name: 31 | attn_procs[name] = NAGWanAttnProcessor2_0(nag_scale=nag_scale, nag_tau=nag_tau, nag_alpha=nag_alpha) 32 | else: 33 | attn_procs[name] = origin_attn_proc 34 | self.transformer.set_attn_processor(attn_procs) 35 | 36 | @torch.no_grad() 37 | def __call__( 38 | self, 39 | prompt: Union[str, List[str]] = None, 40 | negative_prompt: Union[str, List[str]] = None, 41 | height: int = 480, 42 | width: int = 832, 43 | num_frames: int = 81, 44 | num_inference_steps: int = 50, 45 | guidance_scale: float = 5.0, 46 | num_videos_per_prompt: Optional[int] = 1, 47 | generator: Optional[Union[torch.Generator, 
List[torch.Generator]]] = None, 48 | latents: Optional[torch.Tensor] = None, 49 | prompt_embeds: Optional[torch.Tensor] = None, 50 | negative_prompt_embeds: Optional[torch.Tensor] = None, 51 | output_type: Optional[str] = "np", 52 | return_dict: bool = True, 53 | attention_kwargs: Optional[Dict[str, Any]] = None, 54 | callback_on_step_end: Optional[ 55 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] 56 | ] = None, 57 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 58 | max_sequence_length: int = 512, 59 | 60 | nag_scale: float = 1.0, 61 | nag_tau: float = 2.5, 62 | nag_alpha: float = 0.25, 63 | nag_negative_prompt: str = None, 64 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None, 65 | ): 66 | r""" 67 | The call function to the pipeline for generation. 68 | 69 | Args: 70 | prompt (`str` or `List[str]`, *optional*): 71 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 72 | instead. 73 | height (`int`, defaults to `480`): 74 | The height in pixels of the generated image. 75 | width (`int`, defaults to `832`): 76 | The width in pixels of the generated image. 77 | num_frames (`int`, defaults to `81`): 78 | The number of frames in the generated video. 79 | num_inference_steps (`int`, defaults to `50`): 80 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 81 | expense of slower inference. 82 | guidance_scale (`float`, defaults to `5.0`): 83 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 84 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 85 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 86 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 87 | usually at the expense of lower image quality. 88 | num_videos_per_prompt (`int`, *optional*, defaults to 1): 89 | The number of images to generate per prompt. 90 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 91 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 92 | generation deterministic. 93 | latents (`torch.Tensor`, *optional*): 94 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 95 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 96 | tensor is generated by sampling using the supplied random `generator`. 97 | prompt_embeds (`torch.Tensor`, *optional*): 98 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 99 | provided, text embeddings are generated from the `prompt` input argument. 100 | output_type (`str`, *optional*, defaults to `"pil"`): 101 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 102 | return_dict (`bool`, *optional*, defaults to `True`): 103 | Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple. 104 | attention_kwargs (`dict`, *optional*): 105 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 106 | `self.processor` in 107 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
108 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): 109 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of 110 | each denoising step during inference with the following arguments: `callback_on_step_end(self: 111 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a 112 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`. 113 | callback_on_step_end_tensor_inputs (`List`, *optional*): 114 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 115 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 116 | `._callback_tensor_inputs` attribute of your pipeline class. 117 | nag_scale (`float`, *optional*, defaults to `1.0`): 118 | Strength of the normalized attention guidance. NAG is only applied when `nag_scale > 1`. 119 | 120 | Examples: 121 | 122 | Returns: 123 | [`~WanPipelineOutput`] or `tuple`: 124 | If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where 125 | the first element is a list with the generated video frames. 126 | 127 | """ 128 | 129 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): 130 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs 131 | 132 | # 1. Check inputs. Raise error if not correct 133 | self.check_inputs( 134 | prompt, 135 | negative_prompt, 136 | height, 137 | width, 138 | prompt_embeds, 139 | negative_prompt_embeds, 140 | callback_on_step_end_tensor_inputs, 141 | ) 142 | 143 | self._guidance_scale = guidance_scale 144 | self._attention_kwargs = attention_kwargs 145 | self._current_timestep = None 146 | self._interrupt = False 147 | self._nag_scale = nag_scale 148 | 149 | device = self._execution_device 150 | 151 | # 2. Define call parameters 152 | if prompt is not None and isinstance(prompt, str): 153 | batch_size = 1 154 | elif prompt is not None and isinstance(prompt, list): 155 | batch_size = len(prompt) 156 | else: 157 | batch_size = prompt_embeds.shape[0] 158 | 159 | # 3.
Encode input prompt 160 | prompt_embeds, negative_prompt_embeds = self.encode_prompt( 161 | prompt=prompt, 162 | negative_prompt=negative_prompt, 163 | do_classifier_free_guidance=self.do_classifier_free_guidance, 164 | num_videos_per_prompt=num_videos_per_prompt, 165 | prompt_embeds=prompt_embeds, 166 | negative_prompt_embeds=negative_prompt_embeds, 167 | max_sequence_length=max_sequence_length, 168 | device=device, 169 | ) 170 | if self.do_normalized_attention_guidance: 171 | if nag_negative_prompt_embeds is None: 172 | if nag_negative_prompt is None: 173 | if self.do_classifier_free_guidance: 174 | nag_negative_prompt_embeds = negative_prompt_embeds 175 | else: 176 | nag_negative_prompt = negative_prompt or "" 177 | 178 | if nag_negative_prompt is not None: 179 | nag_negative_prompt_embeds = self.encode_prompt( 180 | prompt=nag_negative_prompt, 181 | do_classifier_free_guidance=False, 182 | num_videos_per_prompt=num_videos_per_prompt, 183 | max_sequence_length=max_sequence_length, 184 | device=device, 185 | )[0] 186 | 187 | if self.do_normalized_attention_guidance: 188 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0) 189 | 190 | transformer_dtype = self.transformer.dtype 191 | prompt_embeds = prompt_embeds.to(transformer_dtype) 192 | if negative_prompt_embeds is not None: 193 | negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype) 194 | 195 | # 4. Prepare timesteps 196 | self.scheduler.set_timesteps(num_inference_steps, device=device) 197 | timesteps = self.scheduler.timesteps 198 | 199 | # 5. Prepare latent variables 200 | num_channels_latents = self.transformer.config.in_channels 201 | latents = self.prepare_latents( 202 | batch_size * num_videos_per_prompt, 203 | num_channels_latents, 204 | height, 205 | width, 206 | num_frames, 207 | torch.float32, 208 | device, 209 | generator, 210 | latents, 211 | ) 212 | 213 | # 6. 
Denoising loop 214 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 215 | self._num_timesteps = len(timesteps) 216 | 217 | if self.do_normalized_attention_guidance: 218 | origin_attn_procs = self.transformer.attn_processors 219 | self._set_nag_attn_processor(nag_scale, nag_tau, nag_alpha) 220 | 221 | with self.progress_bar(total=num_inference_steps) as progress_bar: 222 | for i, t in enumerate(timesteps): 223 | if self.interrupt: 224 | continue 225 | 226 | self._current_timestep = t 227 | latent_model_input = latents.to(transformer_dtype) 228 | timestep = t.expand(latents.shape[0]) 229 | 230 | noise_pred = self.transformer( 231 | hidden_states=latent_model_input, 232 | timestep=timestep, 233 | encoder_hidden_states=prompt_embeds, 234 | attention_kwargs=attention_kwargs, 235 | return_dict=False, 236 | )[0] 237 | 238 | if self.do_classifier_free_guidance: 239 | noise_uncond = self.transformer( 240 | hidden_states=latent_model_input, 241 | timestep=timestep, 242 | encoder_hidden_states=negative_prompt_embeds, 243 | attention_kwargs=attention_kwargs, 244 | return_dict=False, 245 | )[0] 246 | noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) 247 | 248 | # compute the previous noisy sample x_t -> x_t-1 249 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 250 | 251 | if callback_on_step_end is not None: 252 | callback_kwargs = {} 253 | for k in callback_on_step_end_tensor_inputs: 254 | callback_kwargs[k] = locals()[k] 255 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 256 | 257 | latents = callback_outputs.pop("latents", latents) 258 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 259 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 260 | 261 | # call the callback, if provided 262 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 263 | progress_bar.update() 264 | 265 | if XLA_AVAILABLE: 266 | xm.mark_step() 267 | 268 | self._current_timestep = None 269 | 270 | if not output_type == "latent": 271 | latents = latents.to(self.vae.dtype) 272 | latents_mean = ( 273 | torch.tensor(self.vae.config.latents_mean) 274 | .view(1, self.vae.config.z_dim, 1, 1, 1) 275 | .to(latents.device, latents.dtype) 276 | ) 277 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to( 278 | latents.device, latents.dtype 279 | ) 280 | latents = latents / latents_std + latents_mean 281 | video = self.vae.decode(latents, return_dict=False)[0] 282 | video = self.video_processor.postprocess_video(video, output_type=output_type) 283 | else: 284 | video = latents 285 | 286 | if self.do_normalized_attention_guidance: 287 | self.transformer.set_attn_processor(origin_attn_procs) 288 | 289 | # Offload all models 290 | self.maybe_free_model_hooks() 291 | 292 | if not return_dict: 293 | return (video,) 294 | 295 | return WanPipelineOutput(frames=video) 296 | -------------------------------------------------------------------------------- /nag/transformer_flux.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers 7 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 8 | from 
diffusers.models.transformers.transformer_flux import FluxTransformer2DModel 9 | 10 | 11 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 12 | 13 | 14 | class NAGFluxTransformer2DModel(FluxTransformer2DModel): 15 | def forward( 16 | self, 17 | hidden_states: torch.Tensor, 18 | encoder_hidden_states: torch.Tensor = None, 19 | pooled_projections: torch.Tensor = None, 20 | timestep: torch.LongTensor = None, 21 | img_ids: torch.Tensor = None, 22 | txt_ids: torch.Tensor = None, 23 | guidance: torch.Tensor = None, 24 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 25 | controlnet_block_samples=None, 26 | controlnet_single_block_samples=None, 27 | return_dict: bool = True, 28 | controlnet_blocks_repeat: bool = False, 29 | ) -> Union[torch.Tensor, Transformer2DModelOutput]: 30 | """ 31 | The [`FluxTransformer2DModel`] forward method. 32 | 33 | Args: 34 | hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): 35 | Input `hidden_states`. 36 | encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): 37 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 38 | pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected 39 | from the embeddings of input conditions. 40 | timestep ( `torch.LongTensor`): 41 | Used to indicate denoising step. 42 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 43 | A list of tensors that if specified are added to the residuals of transformer blocks. 44 | joint_attention_kwargs (`dict`, *optional*): 45 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 46 | `self.processor` in 47 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 48 | return_dict (`bool`, *optional*, defaults to `True`): 49 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 50 | tuple. 51 | 52 | Returns: 53 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 54 | `tuple` where the first element is the sample tensor. 55 | """ 56 | if joint_attention_kwargs is not None: 57 | joint_attention_kwargs = joint_attention_kwargs.copy() 58 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 59 | else: 60 | lora_scale = 1.0 61 | 62 | if USE_PEFT_BACKEND: 63 | # weight the lora layers by setting `lora_scale` for each PEFT layer 64 | scale_lora_layers(self, lora_scale) 65 | else: 66 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 67 | logger.warning( 68 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 69 | ) 70 | 71 | do_nag = hidden_states.shape[0] != encoder_hidden_states.shape[0] 72 | 73 | hidden_states = self.x_embedder(hidden_states) 74 | 75 | timestep = timestep.to(hidden_states.dtype) * 1000 76 | if guidance is not None: 77 | guidance = guidance.to(hidden_states.dtype) * 1000 78 | else: 79 | guidance = None 80 | 81 | temb = ( 82 | self.time_text_embed(timestep, pooled_projections) 83 | if guidance is None 84 | else self.time_text_embed(timestep, guidance, pooled_projections) 85 | ) 86 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 87 | 88 | if txt_ids.ndim == 3: 89 | logger.warning( 90 | "Passing `txt_ids` 3d torch.Tensor is deprecated." 
91 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 92 | ) 93 | txt_ids = txt_ids[0] 94 | if img_ids.ndim == 3: 95 | logger.warning( 96 | "Passing `img_ids` 3d torch.Tensor is deprecated." 97 | "Please remove the batch dimension and pass it as a 2d torch Tensor" 98 | ) 99 | img_ids = img_ids[0] 100 | 101 | ids = torch.cat((txt_ids, img_ids), dim=0) 102 | image_rotary_emb = self.pos_embed(ids) 103 | 104 | if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: 105 | ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") 106 | ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) 107 | joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) 108 | 109 | for index_block, block in enumerate(self.transformer_blocks): 110 | if torch.is_grad_enabled() and self.gradient_checkpointing: 111 | encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( 112 | block, 113 | hidden_states, 114 | encoder_hidden_states, 115 | temb, 116 | image_rotary_emb, 117 | ) 118 | 119 | else: 120 | encoder_hidden_states, hidden_states = block( 121 | hidden_states=hidden_states, 122 | encoder_hidden_states=encoder_hidden_states, 123 | temb=temb, 124 | image_rotary_emb=image_rotary_emb, 125 | joint_attention_kwargs=joint_attention_kwargs, 126 | ) 127 | 128 | # controlnet residual 129 | if controlnet_block_samples is not None: 130 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) 131 | interval_control = int(np.ceil(interval_control)) 132 | # For Xlabs ControlNet. 133 | if controlnet_blocks_repeat: 134 | hidden_states = ( 135 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] 136 | ) 137 | else: 138 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 139 | 140 | if do_nag: 141 | hidden_states = hidden_states.tile(2, 1, 1) 142 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 143 | 144 | for index_block, block in enumerate(self.single_transformer_blocks): 145 | if torch.is_grad_enabled() and self.gradient_checkpointing: 146 | hidden_states = self._gradient_checkpointing_func( 147 | block, 148 | hidden_states, 149 | temb, 150 | image_rotary_emb, 151 | ) 152 | 153 | else: 154 | hidden_states = block( 155 | hidden_states=hidden_states, 156 | temb=temb, 157 | image_rotary_emb=image_rotary_emb, 158 | joint_attention_kwargs=joint_attention_kwargs, 159 | ) 160 | 161 | # controlnet residual 162 | if controlnet_single_block_samples is not None: 163 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 164 | interval_control = int(np.ceil(interval_control)) 165 | controlnet_single_block_sample = controlnet_single_block_samples[index_block // interval_control] 166 | if do_nag: 167 | controlnet_single_block_sample = controlnet_single_block_sample.tile(2, 1, 1) 168 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 169 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] + controlnet_single_block_sample 170 | ) 171 | 172 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 
173 | 174 | if do_nag: 175 | hidden_states = torch.chunk(hidden_states, 2, dim=0)[0] 176 | 177 | hidden_states = self.norm_out(hidden_states, temb) 178 | output = self.proj_out(hidden_states) 179 | 180 | if USE_PEFT_BACKEND: 181 | # remove `lora_scale` from each PEFT layer 182 | unscale_lora_layers(self, lora_scale) 183 | 184 | if not return_dict: 185 | return (output,) 186 | 187 | return Transformer2DModelOutput(sample=output) 188 | -------------------------------------------------------------------------------- /nag/transformer_wan_nag.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Union 2 | 3 | import torch 4 | 5 | from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers 6 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 7 | from diffusers.models.transformers.transformer_wan import WanTransformer3DModel 8 | from diffusers.models.attention_processor import AttentionProcessor 9 | 10 | 11 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 12 | 13 | 14 | class NagWanTransformer3DModel(WanTransformer3DModel): 15 | @property 16 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors 17 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 18 | r""" 19 | Returns: 20 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 21 | indexed by its weight name. 22 | """ 23 | # set recursively 24 | processors = {} 25 | 26 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 27 | if hasattr(module, "get_processor"): 28 | processors[f"{name}.processor"] = module.get_processor() 29 | 30 | for sub_name, child in module.named_children(): 31 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 32 | 33 | return processors 34 | 35 | for name, module in self.named_children(): 36 | fn_recursive_add_processors(name, module, processors) 37 | 38 | return processors 39 | 40 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor 41 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 42 | r""" 43 | Sets the attention processor to use to compute attention. 44 | 45 | Parameters: 46 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 47 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 48 | for **all** `Attention` layers. 49 | 50 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 51 | processor. This is strongly recommended when setting trainable attention processors. 52 | 53 | """ 54 | count = len(self.attn_processors.keys()) 55 | 56 | if isinstance(processor, dict) and len(processor) != count: 57 | raise ValueError( 58 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 59 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
60 |             )
61 | 
62 |         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
63 |             if hasattr(module, "set_processor"):
64 |                 if not isinstance(processor, dict):
65 |                     module.set_processor(processor)
66 |                 else:
67 |                     module.set_processor(processor.pop(f"{name}.processor"))
68 | 
69 |             for sub_name, child in module.named_children():
70 |                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
71 | 
72 |         for name, module in self.named_children():
73 |             fn_recursive_attn_processor(name, module, processor)
74 | 
75 |     def forward(
76 |         self,
77 |         hidden_states: torch.Tensor,
78 |         timestep: torch.LongTensor,
79 |         encoder_hidden_states: torch.Tensor,
80 |         encoder_hidden_states_image: Optional[torch.Tensor] = None,
81 |         return_dict: bool = True,
82 |         attention_kwargs: Optional[Dict[str, Any]] = None,
83 |     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
84 |         if attention_kwargs is not None:
85 |             attention_kwargs = attention_kwargs.copy()
86 |             lora_scale = attention_kwargs.pop("scale", 1.0)
87 |         else:
88 |             lora_scale = 1.0
89 | 
90 |         if USE_PEFT_BACKEND:
91 |             # weight the lora layers by setting `lora_scale` for each PEFT layer
92 |             scale_lora_layers(self, lora_scale)
93 |         else:
94 |             if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
95 |                 logger.warning(
96 |                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
97 |                 )
98 | 
99 |         batch_size, num_channels, num_frames, height, width = hidden_states.shape
100 |         p_t, p_h, p_w = self.config.patch_size
101 |         post_patch_num_frames = num_frames // p_t
102 |         post_patch_height = height // p_h
103 |         post_patch_width = width // p_w
104 | 
105 |         rotary_emb = self.rope(hidden_states)
106 | 
107 |         hidden_states = self.patch_embedding(hidden_states)
108 |         hidden_states = hidden_states.flatten(2).transpose(1, 2)
109 | 
110 |         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
111 |             timestep, encoder_hidden_states, encoder_hidden_states_image
112 |         )
113 |         timestep_proj = timestep_proj.unflatten(1, (6, -1))
114 | 
115 |         if encoder_hidden_states_image is not None:
116 |             bs_encoder_hidden_states = len(encoder_hidden_states)
117 |             bs_encoder_hidden_states_image = len(encoder_hidden_states_image)
118 |             bs_scale = bs_encoder_hidden_states / bs_encoder_hidden_states_image
119 |             assert bs_scale in [1, 2, 3]
120 |             if bs_scale != 1:
121 |                 encoder_hidden_states_image = encoder_hidden_states_image.tile(int(bs_scale), 1, 1)
122 |             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
123 | 
124 |         # 4. Transformer blocks
125 |         if torch.is_grad_enabled() and self.gradient_checkpointing:
126 |             for block in self.blocks:
127 |                 hidden_states = self._gradient_checkpointing_func(
128 |                     block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
129 |                 )
130 |         else:
131 |             for block in self.blocks:
132 |                 hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
133 | 
134 |         # 5. Output norm, projection & unpatchify
135 |         shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
136 | 
137 |         # Move the shift and scale tensors to the same device as hidden_states.
138 |         # When using multi-GPU inference via accelerate these will be on the
139 |         # first device rather than the last device, which hidden_states ends up
140 |         # on.
141 |         shift = shift.to(hidden_states.device)
142 |         scale = scale.to(hidden_states.device)
143 | 
144 |         hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
145 |         hidden_states = self.proj_out(hidden_states)
146 | 
147 |         hidden_states = hidden_states.reshape(
148 |             batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
149 |         )
150 |         hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
151 |         output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
152 | 
153 |         if USE_PEFT_BACKEND:
154 |             # remove `lora_scale` from each PEFT layer
155 |             unscale_lora_layers(self, lora_scale)
156 | 
157 |         if not return_dict:
158 |             return (output,)
159 | 
160 |         return Transformer2DModelOutput(sample=output)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | diffusers
3 | torch
4 | transformers
5 | sentencepiece
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | 
3 | # Read requirements.txt
4 | with open("requirements.txt", "r") as f:
5 |     requirements = f.read().splitlines()
6 | 
7 | setup(
8 |     name="nag",
9 |     version="0.0.0",
10 |     description="Normalized Attention Guidance for Diffusion Models",
11 |     author="ChenDarYen",
12 |     packages=find_packages(include=["nag", "nag.*"]),
13 |     install_requires=requirements,
14 |     python_requires=">=3.10",
15 | )
--------------------------------------------------------------------------------
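Note: the snippet below is a minimal, illustrative sketch of how the NagWanTransformer3DModel defined above might be loaded in place of the stock Wan transformer after installing the package (e.g. `pip install -e .`). The import path mirrors this repository's layout, but the checkpoint id and the plain `WanPipeline` wiring are assumptions for illustration only; the NAG-specific sampling logic lives in the repository's own pipelines (e.g. `nag/pipeline_wan_nag.py`), which are the intended entry points.

import torch
from diffusers import WanPipeline

from nag.transformer_wan_nag import NagWanTransformer3DModel

# Illustrative checkpoint id (any diffusers-format Wan 2.1 checkpoint with a
# `transformer` subfolder should load the same way).
model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# NagWanTransformer3DModel subclasses WanTransformer3DModel, so the stock
# transformer weights load into it unchanged.
transformer = NagWanTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)

# Swapping the transformer into a standard pipeline only wires up the model;
# normalized attention guidance itself additionally requires the NAG attention
# processors and pipelines provided by this repository.
pipe = WanPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")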