├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── architecture.jpg
│   └── banner.jpg
├── flux_nag_demo.ipynb
├── nag
│   ├── __init__.py
│   ├── attention_flux_nag.py
│   ├── attention_joint_nag.py
│   ├── attention_nag.py
│   ├── attention_wan_nag.py
│   ├── normalization.py
│   ├── pipeline_flux_kontext_nag.py
│   ├── pipeline_flux_nag.py
│   ├── pipeline_sd3_nag.py
│   ├── pipeline_sdxl_nag.py
│   ├── pipeline_wan_nag.py
│   ├── transformer_flux.py
│   └── transformer_wan_nag.py
├── requirements.txt
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | .idea/
169 |
170 | # Abstra
171 | # Abstra is an AI-powered process automation framework.
172 | # Ignore directories containing user credentials, local state, and settings.
173 | # Learn more at https://abstra.io/docs
174 | .abstra/
175 |
176 | # Visual Studio Code
177 | # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
178 | # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
179 | # and can be added to the global gitignore or merged into this file. However, if you prefer,
180 | # you could uncomment the following to ignore the entire vscode folder
181 | # .vscode/
182 |
183 | # Ruff stuff:
184 | .ruff_cache/
185 |
186 | # PyPI configuration file
187 | .pypirc
188 |
189 | # Cursor
190 | # Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
191 | # exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
192 | # refer to https://docs.cursor.com/context/ignore-files
193 | .cursorignore
194 | .cursorindexingignore
195 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Dar-Yen Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Normalized Attention Guidance: Universal Negative Guidance for Diffusion Models
2 |
3 | [🤗 Demo](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-Kontext-Dev)
4 | [Project Page](https://chendaryen.github.io/NAG.github.io/)
5 | [Paper (arXiv)](https://arxiv.org/abs/2505.21179)
6 | [Page Views](https://badges.toozhao.com/stats/01JWNDV5JQ2XT69RCZ5KQBCY0E "Get your own page views count badge on badges.toozhao.com")
7 |
8 |
9 | ![NAG banner](assets/banner.jpg)
10 | Negative prompting on 4-step Flux-Schnell:
11 | CFG fails in few-step models. NAG restores effective negative prompting, enabling direct suppression of visual, semantic, and stylistic attributes, such as ``glasses``, ``tiger``, ``realistic``, or ``blurry``. This enhances controllability and expands creative freedom across composition, style, and quality—including prompt-based debiasing.
12 |
13 |
14 | ## News
15 |
16 |
17 | **2025-06-30:** 🤗 Code and [demo](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-Kontext-Dev) for `Flux Kontext` is now available!
18 |
19 | **2025-06-28:** 🎉 Our [ComfyUI implementation](https://github.com/ChenDarYen/ComfyUI-NAG) now supports `Flux Kontext`, `Wan2.1`, and `Hunyuan Video`!
20 |
21 | **2025-06-24:** 🎉 A [ComfyUI node](https://github.com/kijai/ComfyUI-KJNodes/blob/f7eb33abc80a2aded1b46dff0dd14d07856a7d50/nodes/model_optimization_nodes.py#L1568) for Wan is now available! Big thanks to [Kijai](https://github.com/kijai)!
22 |
23 | **2025-06-24:** 🤗 Demo for [LTX Video Fast](https://huggingface.co/spaces/ChenDY/NAG_ltx-video-distilled) is now available!
24 |
25 | **2025-06-22:** 🚀 SD3.5 pipeline is released!
26 |
27 | **2025-06-22:** 🎉 Play with the [ComfyUI implementation](https://github.com/ChenDarYen/ComfyUI-NAG) now!
28 |
29 | **2025-06-19:** 🚀 The Wan2.1 and SDXL pipelines are released!
30 |
31 | **2025-06-09:** 🤗 Demo for [4-step Wan2.1 with CausVid](https://huggingface.co/spaces/ChenDY/NAG_wan2-1-fast) video generation is now available!
32 |
33 | **2025-06-01:** 🤗 Demos for [Flux-Schnell](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-schnell) and [Flux-Dev](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-dev) are now available!
34 |
35 |
36 | ## Approach
37 |
38 | The prevailing approach to diffusion model control, Classifier-Free Guidance (CFG), enables negative guidance by extrapolating between positive and negative conditional outputs at each denoising step. However, in few-step regimes, CFG's assumption of consistent structure between diffusion branches breaks down, as these branches diverge dramatically at early steps. This divergence causes severe artifacts rather than controlled guidance.
39 |
40 | Normalized Attention Guidance (NAG) operates in attention space by extrapolating positive and negative features Z+ and Z-, followed by L1-based normalization and α-blending. This constrains feature deviation, suppresses out-of-manifold drift, and achieves stable, controllable guidance.
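
For intuition, here is a minimal sketch of the core operation, mirroring the attention processors in `nag/` (the exact norm, default hyperparameters, and batch handling differ slightly between processors; the tensor and function names here are only illustrative):

```python
import torch

def nag(z_positive, z_negative, nag_scale=5.0, nag_tau=2.5, nag_alpha=0.25):
    # Extrapolate between the positive and negative attention features
    # (CFG-style extrapolation, but applied in attention space).
    z_guidance = z_positive * nag_scale - z_negative * (nag_scale - 1)

    # Normalization: cap the per-token norm ratio at tau to constrain feature deviation.
    norm_positive = torch.norm(z_positive, p=1, dim=-1, keepdim=True)
    norm_guidance = torch.norm(z_guidance, p=1, dim=-1, keepdim=True)
    scale = norm_guidance / (norm_positive + 1e-7)
    z_guidance = z_guidance * torch.minimum(scale, scale.new_ones(1) * nag_tau) / (scale + 1e-7)

    # Alpha-blend the constrained guidance back toward the positive branch.
    return z_guidance * nag_alpha + z_positive * (1 - nag_alpha)
```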
41 |
42 | ![NAG architecture](assets/architecture.jpg)
43 |
44 | ## Installation
45 |
46 | Install directly from GitHub:
47 |
48 | ```bash
49 | pip install git+https://github.com/ChenDarYen/Normalized-Attention-Guidance.git
50 | ```
51 |
52 | ## Usage
53 |
54 | ### Flux
55 | You can try NAG in `flux_nag_demo.ipynb`, or in the 🤗 Hugging Face demos for [Flux-Schnell](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-schnell) and [Flux-Dev](https://huggingface.co/spaces/ChenDY/NAG_FLUX.1-dev)!
56 |
57 | Loading the custom pipeline:
58 |
59 | ```python
60 | import torch
61 | from nag import NAGFluxPipeline
62 | from nag import NAGFluxTransformer2DModel
63 |
64 | transformer = NAGFluxTransformer2DModel.from_pretrained(
65 | "black-forest-labs/FLUX.1-schnell",
66 | subfolder="transformer",
67 | torch_dtype=torch.bfloat16,
68 | token="hf_token",
69 | )
70 | pipe = NAGFluxPipeline.from_pretrained(
71 | "black-forest-labs/FLUX.1-schnell",
72 | transformer=transformer,
73 | torch_dtype=torch.bfloat16,
74 | token="hf_token",
75 | )
76 | pipe.to("cuda")
77 | ```
78 |
79 | Sampling with NAG:
80 |
81 | ```python
82 | prompt = "Portrait of AI researcher."
83 | nag_negative_prompt = "Glasses."
84 | # prompt = "A baby phoenix made of fire and flames is born from the smoking ashes."
85 | # nag_negative_prompt = "Low resolution, blurry, lack of details, illustration, cartoon, painting."
86 |
87 | image = pipe(
88 | prompt,
89 | nag_negative_prompt=nag_negative_prompt,
90 | guidance_scale=0.0,
91 | nag_scale=5.0,
92 | num_inference_steps=4,
93 | max_sequence_length=256,
94 | ).images[0]
95 | ```
96 |
97 | ### Flux Kontext
98 |
99 | ```python
100 | import torch
101 | from diffusers.utils import load_image
102 | from nag import NAGFluxKontextPipeline
103 | from nag import NAGFluxTransformer2DModel
104 |
105 | transformer = NAGFluxTransformer2DModel.from_pretrained(
106 | "black-forest-labs/FLUX.1-schnell",
107 | subfolder="transformer",
108 | torch_dtype=torch.bfloat16,
109 | token="hf_token",
110 | )
111 | pipe = NAGFluxKontextPipeline.from_pretrained(
112 | "black-forest-labs/FLUX.1-schnell",
113 | transformer=transformer,
114 | torch_dtype=torch.bfloat16,
115 | token="hf_token",
116 | )
117 | pipe.to("cuda")
118 |
119 | input_image = load_image(
120 | "https://raw.githubusercontent.com/Comfy-Org/example_workflows/main/flux/kontext/dev/rabbit.jpg")
121 | prompt = "Using this elegant style, create a portrait of a cute Godzilla wearing a pearl tiara and lace collar, maintaining the same refined quality and soft color tones."
122 | nag_negative_prompt = "Low resolution, blurry, lack of details"
123 |
124 | image = pipe(
125 | prompt=prompt,
126 | image=input_image,
127 | nag_negative_prompt=nag_negative_prompt,
128 | guidance_scale=2.5,
129 | nag_scale=5.0,
130 | num_inference_steps=25,
131 | width=input_image.size[0],
132 | height=input_image.size[1],
133 | ).images[0]
134 | ```
135 |
136 | ### Wan2.1
137 |
138 | ```python
139 | import torch
140 | from diffusers import AutoencoderKLWan, UniPCMultistepScheduler
141 | from nag import NagWanTransformer3DModel
142 | from nag import NAGWanPipeline
143 |
144 | model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
145 | vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
146 | transformer = NagWanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
147 | pipe = NAGWanPipeline.from_pretrained(
148 | model_id,
149 | vae=vae,
150 | transformer=transformer,
151 | torch_dtype=torch.bfloat16,
152 | )
153 | pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=5.0)
154 | pipe.to("cuda")
155 |
156 | prompt = "An origami fox running in the forest. The fox is made of polygons. speed and passion. realistic."
157 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
158 | nag_negative_prompt = "static, low resolution, blurry"
159 |
160 | output = pipe(
161 | prompt=prompt,
162 | negative_prompt=negative_prompt,
163 | nag_negative_prompt=nag_negative_prompt,
164 | guidance_scale=5.0,
165 | nag_scale=9,
166 | height=480,
167 | width=832,
168 | num_inference_steps=25,
169 | num_frames=81,
170 | ).frames[0]
171 | ```
172 |
173 | For 4-step inference with CausVid, please refer to the [demo](https://huggingface.co/spaces/ChenDY/NAG_wan2-1-fast/blob/main/app.py).
174 |
175 | ### SD3.5
176 |
177 | ```python
178 | import torch
179 | from nag import NAGStableDiffusion3Pipeline
180 |
181 | model_id = "stabilityai/stable-diffusion-3.5-large-turbo"
182 | pipe = NAGStableDiffusion3Pipeline.from_pretrained(
183 | model_id,
184 | torch_dtype=torch.bfloat16,
185 | token="hf_token",
186 | )
187 | pipe.to("cuda")
188 |
189 | prompt = "A beautiful cyborg"
190 | nag_negative_prompt = "robot"
191 |
192 | image = pipe(
193 | prompt,
194 | nag_negative_prompt=nag_negative_prompt,
195 | guidance_scale=0.,
196 | nag_scale=5,
197 | num_inference_steps=8,
198 | ).images[0]
199 | ```
200 |
201 | ### SDXL
202 |
203 | ```python
204 | import torch
205 | from diffusers import UNet2DConditionModel, LCMScheduler
206 | from huggingface_hub import hf_hub_download
207 | from nag import NAGStableDiffusionXLPipeline
208 |
209 | base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
210 | repo_name = "tianweiy/DMD2"
211 | ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
212 |
213 | unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.bfloat16)
214 | unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))
215 | pipe = NAGStableDiffusionXLPipeline.from_pretrained(
216 | base_model_id,
217 | unet=unet,
218 | torch_dtype=torch.bfloat16,
219 | variant="fp16",
220 | ).to("cuda")
221 | pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, original_inference_steps=4)
222 |
223 | prompt = "A beautiful cyborg"
224 | nag_negative_prompt = "robot"
225 |
226 | image = pipe(
227 | prompt,
228 | nag_negative_prompt=nag_negative_prompt,
229 | guidance_scale=0,
230 | nag_scale=3,
231 | num_inference_steps=4,
232 | ).images[0]
233 | ```
234 |
235 | ## Citation
236 |
237 | If you find NAG useful or relevant to your research, please cite our work:
238 |
239 | ```bibtex
240 | @article{chen2025normalizedattentionguidanceuniversal,
241 | title={Normalized Attention Guidance: Universal Negative Guidance for Diffusion Model},
242 | author={Dar-Yen Chen and Hmrishav Bandyopadhyay and Kai Zou and Yi-Zhe Song},
243 | journal={arXiv preprint arXiv:2505.21179},
244 | year={2025}
245 | }
246 | ```
247 |
248 |
--------------------------------------------------------------------------------
/assets/architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenDarYen/Normalized-Attention-Guidance/8bb34a16e517692743fab99bc60838c20f975ded/assets/architecture.jpg
--------------------------------------------------------------------------------
/assets/banner.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ChenDarYen/Normalized-Attention-Guidance/8bb34a16e517692743fab99bc60838c20f975ded/assets/banner.jpg
--------------------------------------------------------------------------------
/nag/__init__.py:
--------------------------------------------------------------------------------
1 | from .pipeline_flux_nag import NAGFluxPipeline
2 | from .pipeline_flux_kontext_nag import NAGFluxKontextPipeline
3 | from .pipeline_wan_nag import NAGWanPipeline
4 | from .pipeline_sd3_nag import NAGStableDiffusion3Pipeline
5 | from .pipeline_sdxl_nag import NAGStableDiffusionXLPipeline
6 | from .transformer_flux import NAGFluxTransformer2DModel
7 | from .transformer_wan_nag import NagWanTransformer3DModel
8 |
--------------------------------------------------------------------------------
/nag/attention_flux_nag.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | from diffusers.models.attention_processor import Attention
8 | from diffusers.models.embeddings import apply_rotary_emb
9 |
10 |
11 | class NAGFluxAttnProcessor2_0:
12 | """Attention processor used typically in processing the SD3-like self-attention projections."""
13 |
14 | def __init__(
15 | self,
16 | nag_scale: float = 1.0,
17 | nag_tau=2.5,
18 | nag_alpha=0.25,
19 | encoder_hidden_states_length: Optional[int] = None,
20 | ):
21 | if not hasattr(F, "scaled_dot_product_attention"):
22 | raise ImportError("NAGFluxAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
23 | self.nag_scale = nag_scale
24 | self.nag_tau = nag_tau
25 | self.nag_alpha = nag_alpha
26 | self.encoder_hidden_states_length = encoder_hidden_states_length
27 |
28 | def __call__(
29 | self,
30 | attn: Attention,
31 | hidden_states: torch.FloatTensor,
32 | encoder_hidden_states: torch.FloatTensor = None,
33 | attention_mask: Optional[torch.FloatTensor] = None,
34 | image_rotary_emb: Optional[torch.Tensor] = None,
35 | ) -> torch.FloatTensor:
36 | batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
37 |
38 | if self.nag_scale > 1.:
39 | if encoder_hidden_states is not None:
40 | assert len(hidden_states) == batch_size * 0.5
41 | apply_guidance = True
42 | else:
43 | apply_guidance = False
44 |
45 | # `sample` projections.
46 | query = attn.to_q(hidden_states)
47 | key = attn.to_k(hidden_states)
48 | value = attn.to_v(hidden_states)
49 |
50 | # attention
51 | if apply_guidance and encoder_hidden_states is not None:
52 | query = query.tile(2, 1, 1)
53 | key = key.tile(2, 1, 1)
54 | value = value.tile(2, 1, 1)
55 |
56 | inner_dim = key.shape[-1]
57 | head_dim = inner_dim // attn.heads
58 |
59 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
60 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
61 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
62 |
63 | if attn.norm_q is not None:
64 | query = attn.norm_q(query)
65 | if attn.norm_k is not None:
66 | key = attn.norm_k(key)
67 |
68 | # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
69 | if encoder_hidden_states is not None:
70 | # `context` projections.
71 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
72 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
73 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
74 |
75 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
76 | batch_size, -1, attn.heads, head_dim
77 | ).transpose(1, 2)
78 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
79 | batch_size, -1, attn.heads, head_dim
80 | ).transpose(1, 2)
81 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
82 | batch_size, -1, attn.heads, head_dim
83 | ).transpose(1, 2)
84 |
85 | if attn.norm_added_q is not None:
86 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
87 | if attn.norm_added_k is not None:
88 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
89 |
90 | query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
91 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
92 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
93 |
94 | encoder_hidden_states_length = encoder_hidden_states.shape[1]
95 |
96 | else:
97 | assert self.encoder_hidden_states_length is not None
98 | encoder_hidden_states_length = self.encoder_hidden_states_length
99 |
100 | if image_rotary_emb is not None:
101 | query = apply_rotary_emb(query, image_rotary_emb)
102 | key = apply_rotary_emb(key, image_rotary_emb)
103 |
104 | if not apply_guidance:
105 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
106 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
107 | hidden_states = hidden_states.to(query.dtype)
108 |
109 | else:
110 | origin_batch_size = batch_size // 2
111 | query, query_negative = torch.chunk(query, 2, dim=0)
112 | key, key_negative = torch.chunk(key, 2, dim=0)
113 | value, value_negative = torch.chunk(value, 2, dim=0)
114 |
115 | hidden_states_negative = F.scaled_dot_product_attention(query_negative, key_negative, value_negative, dropout_p=0.0, is_causal=False)
116 | hidden_states_negative = hidden_states_negative.transpose(1, 2).reshape(origin_batch_size, -1, attn.heads * head_dim)
117 | hidden_states_negative = hidden_states_negative.to(query.dtype)
118 |
119 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
120 | hidden_states = hidden_states.transpose(1, 2).reshape(origin_batch_size, -1, attn.heads * head_dim)
121 | hidden_states = hidden_states.to(query.dtype)
122 |
123 | if encoder_hidden_states is not None:
124 | encoder_hidden_states, hidden_states = (
125 | hidden_states[:, : encoder_hidden_states.shape[1]],
126 | hidden_states[:, encoder_hidden_states.shape[1] :],
127 | )
128 |
129 | if apply_guidance:
130 | encoder_hidden_states_negative, hidden_states_negative = (
131 | hidden_states_negative[:, : encoder_hidden_states.shape[1]],
132 | hidden_states_negative[:, encoder_hidden_states.shape[1]:],
133 | )
134 | hidden_states_positive = hidden_states
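# NAG: extrapolate the positive and negative attention outputs, cap the norm ratio
# against the positive branch at nag_tau (an L2 norm in this processor), then alpha-blend back.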
135 | hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1)
136 | norm_positive = torch.norm(hidden_states_positive, p=2, dim=-1, keepdim=True).expand(*hidden_states_positive.shape)
137 | norm_guidance = torch.norm(hidden_states_guidance, p=2, dim=-1, keepdim=True).expand(*hidden_states_positive.shape)
138 |
139 | scale = norm_guidance / norm_positive
140 | hidden_states_guidance = hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / scale
141 |
142 | hidden_states = hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha)
143 |
144 | encoder_hidden_states = torch.cat((encoder_hidden_states, encoder_hidden_states_negative), dim=0)
145 |
146 | # linear proj
147 | hidden_states = attn.to_out[0](hidden_states)
148 | # dropout
149 | hidden_states = attn.to_out[1](hidden_states)
150 |
151 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
152 |
153 | return hidden_states, encoder_hidden_states
154 |
155 | else:
156 | if apply_guidance:
157 | image_hidden_states_negative = hidden_states_negative[:, encoder_hidden_states_length:]
158 | image_hidden_states = hidden_states[:, encoder_hidden_states_length:]
159 |
160 | image_hidden_states_positive = image_hidden_states
161 | image_hidden_states_guidance = image_hidden_states_positive * self.nag_scale - image_hidden_states_negative * (self.nag_scale - 1)
162 | norm_positive = torch.norm(image_hidden_states_positive, p=2, dim=-1, keepdim=True).expand(*image_hidden_states_positive.shape)
163 | norm_guidance = torch.norm(image_hidden_states_guidance, p=2, dim=-1, keepdim=True).expand(*image_hidden_states_positive.shape)
164 |
165 | scale = norm_guidance / norm_positive
166 | image_hidden_states_guidance = image_hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / scale
167 | # scale = torch.nan_to_num(scale, 10)
168 | # image_hidden_states_guidance[scale > self.nag_tau] = image_hidden_states_guidance[scale > self.nag_tau] / (norm_guidance[scale > self.nag_tau] + 1e-7) * norm_positive[scale > self.nag_tau] * self.nag_tau
169 |
170 | image_hidden_states = image_hidden_states_guidance * self.nag_alpha + image_hidden_states_positive * (1 - self.nag_alpha)
171 |
172 | hidden_states_negative[:, encoder_hidden_states_length:] = image_hidden_states
173 | hidden_states[:, encoder_hidden_states_length:] = image_hidden_states
174 | hidden_states = torch.cat((hidden_states, hidden_states_negative), dim=0)
175 |
176 | return hidden_states
177 |
--------------------------------------------------------------------------------
/nag/attention_joint_nag.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from diffusers.models.attention_processor import Attention
7 |
8 |
9 | class NAGJointAttnProcessor2_0:
10 | def __init__(self, nag_scale: float = 1.0, nag_tau: float = 2.5, nag_alpha:float = 0.125):
11 | if not hasattr(F, "scaled_dot_product_attention"):
12 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
13 | self.nag_scale = nag_scale
14 | self.nag_tau = nag_tau
15 | self.nag_alpha = nag_alpha
16 |
17 | def __call__(
18 | self,
19 | attn: Attention,
20 | hidden_states: torch.FloatTensor,
21 | encoder_hidden_states: torch.FloatTensor = None,
22 | attention_mask: Optional[torch.FloatTensor] = None,
23 | *args,
24 | **kwargs,
25 | ) -> torch.FloatTensor:
26 | residual = hidden_states
27 |
28 | batch_size = hidden_states.shape[0]
29 |
30 | apply_guidance = self.nag_scale > 1 and encoder_hidden_states is not None
31 | if apply_guidance:
32 | origin_batch_size = len(encoder_hidden_states) - batch_size
33 | assert len(encoder_hidden_states) / origin_batch_size in [2, 3, 4]
34 |
35 | # `sample` projections.
36 | query = attn.to_q(hidden_states)
37 | key = attn.to_k(hidden_states)
38 | value = attn.to_v(hidden_states)
39 |
40 | inner_dim = key.shape[-1]
41 | head_dim = inner_dim // attn.heads
42 |
43 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
44 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
45 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
46 |
47 | if attn.norm_q is not None:
48 | query = attn.norm_q(query)
49 | if attn.norm_k is not None:
50 | key = attn.norm_k(key)
51 |
52 | if apply_guidance:
53 | batch_size += origin_batch_size
54 | if batch_size == 2 * origin_batch_size:
55 | query = query.tile(2, 1, 1, 1)
56 | key = key.tile(2, 1, 1, 1)
57 | value = value.tile(2, 1, 1, 1)
58 | else:
59 | query = torch.cat([query, query[origin_batch_size:2 * origin_batch_size]], dim=0)
60 | key = torch.cat([key, key[origin_batch_size:2 * origin_batch_size]], dim=0)
61 | value = torch.cat([value, value[origin_batch_size:2 * origin_batch_size]], dim=0)
62 |
63 | # `context` projections.
64 | if encoder_hidden_states is not None:
65 | encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
66 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
67 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
68 |
69 | encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
70 | batch_size, -1, attn.heads, head_dim
71 | ).transpose(1, 2)
72 | encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
73 | batch_size, -1, attn.heads, head_dim
74 | ).transpose(1, 2)
75 | encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
76 | batch_size, -1, attn.heads, head_dim
77 | ).transpose(1, 2)
78 |
79 | if attn.norm_added_q is not None:
80 | encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
81 | if attn.norm_added_k is not None:
82 | encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
83 |
84 | query = torch.cat([query, encoder_hidden_states_query_proj], dim=2)
85 | key = torch.cat([key, encoder_hidden_states_key_proj], dim=2)
86 | value = torch.cat([value, encoder_hidden_states_value_proj], dim=2)
87 |
88 | hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
89 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
90 | hidden_states = hidden_states.to(query.dtype)
91 |
92 | if encoder_hidden_states is not None:
93 | # Split the attention outputs.
94 | hidden_states, encoder_hidden_states = (
95 | hidden_states[:, : residual.shape[1]],
96 | hidden_states[:, residual.shape[1] :],
97 | )
98 | if not attn.context_pre_only:
99 | encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
100 |
101 | if apply_guidance:
102 | hidden_states_negative = hidden_states[-origin_batch_size:]
103 | if batch_size == 2 * origin_batch_size:
104 | hidden_states_positive = hidden_states[:origin_batch_size]
105 | else:
106 | hidden_states_positive = hidden_states[origin_batch_size:2 * origin_batch_size]
107 | hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1)
108 | norm_positive = torch.norm(hidden_states_positive, p=1, dim=-1, keepdim=True).expand(*hidden_states_positive.shape)
109 | norm_guidance = torch.norm(hidden_states_guidance, p=1, dim=-1, keepdim=True).expand(*hidden_states_guidance.shape)
110 |
111 | scale = norm_guidance / (norm_positive + 1e-7)
112 | hidden_states_guidance = hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / (scale + 1e-7)
113 |
114 | hidden_states_guidance = hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha)
115 |
116 | if batch_size == 2 * origin_batch_size:
117 | hidden_states = hidden_states_guidance
118 | elif batch_size == 3 * origin_batch_size:
119 | hidden_states = torch.cat((hidden_states[:origin_batch_size], hidden_states_guidance), dim=0)
120 | elif batch_size == 4 * origin_batch_size:
121 | hidden_states = torch.cat((hidden_states[:origin_batch_size], hidden_states_guidance, hidden_states[2 * origin_batch_size:3 * origin_batch_size]), dim=0)
122 |
123 | # linear proj
124 | hidden_states = attn.to_out[0](hidden_states)
125 | # dropout
126 | hidden_states = attn.to_out[1](hidden_states)
127 |
128 | if encoder_hidden_states is not None:
129 | return hidden_states, encoder_hidden_states
130 | else:
131 | return hidden_states
132 |
133 |
134 | class NAGPAGCFGJointAttnProcessor2_0:
135 | """Attention processor used typically in processing the SD3-like self-attention projections."""
136 | def __init__(self, nag_scale: float = 1.0, nag_tau: float = 2.5, nag_alpha:float = 0.125):
137 | if not hasattr(F, "scaled_dot_product_attention"):
138 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
139 | self.nag_scale = nag_scale
140 | self.nag_tau = nag_tau
141 | self.nag_alpha = nag_alpha
142 |
143 | def __call__(
144 | self,
145 | attn: Attention,
146 | hidden_states: torch.FloatTensor,
147 | encoder_hidden_states: torch.FloatTensor = None,
148 | attention_mask: Optional[torch.FloatTensor] = None,
149 | *args,
150 | **kwargs,
151 | ) -> torch.FloatTensor:
152 | residual = hidden_states
153 |
154 | input_ndim = hidden_states.ndim
155 | if input_ndim == 4:
156 | batch_size, channel, height, width = hidden_states.shape
157 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
158 | context_input_ndim = encoder_hidden_states.ndim
159 | if context_input_ndim == 4:
160 | batch_size, channel, height, width = encoder_hidden_states.shape
161 | encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
162 |
163 | identity_block_size = hidden_states.shape[
164 | 1
165 | ] # patch embeddings width * height (correspond to self-attention map width or height)
166 |
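# Expected batch layout: the latents arrive as 3 chunks (uncond, cond, perturbed),
# while the encoder states arrive as 4 chunks (uncond, cond, perturbed, NAG-negative).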
167 | # chunk
168 | hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
169 | hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])
170 |
171 | (
172 | encoder_hidden_states_uncond,
173 | encoder_hidden_states_org,
174 | encoder_hidden_states_ptb,
175 | encoder_hidden_states_nag,
176 | ) = encoder_hidden_states.chunk(4)
177 | encoder_hidden_states_org = torch.cat([encoder_hidden_states_uncond, encoder_hidden_states_org, encoder_hidden_states_nag])
178 |
179 | ################## original path ##################
180 | batch_size = encoder_hidden_states_org.shape[0]
181 | origin_batch_size = batch_size // 3
182 |
183 | # `sample` projections.
184 | query_org = attn.to_q(hidden_states_org)
185 | key_org = attn.to_k(hidden_states_org)
186 | value_org = attn.to_v(hidden_states_org)
187 |
188 | query_org = torch.cat([query_org, query_org[-origin_batch_size:]], dim=0)
189 | key_org = torch.cat([key_org, key_org[-origin_batch_size:]], dim=0)
190 | value_org = torch.cat([value_org, value_org[-origin_batch_size:]], dim=0)
191 |
192 | # `context` projections.
193 | encoder_hidden_states_org_query_proj = attn.add_q_proj(encoder_hidden_states_org)
194 | encoder_hidden_states_org_key_proj = attn.add_k_proj(encoder_hidden_states_org)
195 | encoder_hidden_states_org_value_proj = attn.add_v_proj(encoder_hidden_states_org)
196 |
197 | # attention
198 | query_org = torch.cat([query_org, encoder_hidden_states_org_query_proj], dim=1)
199 | key_org = torch.cat([key_org, encoder_hidden_states_org_key_proj], dim=1)
200 | value_org = torch.cat([value_org, encoder_hidden_states_org_value_proj], dim=1)
201 |
202 | inner_dim = key_org.shape[-1]
203 | head_dim = inner_dim // attn.heads
204 | query_org = query_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
205 | key_org = key_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
206 | value_org = value_org.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
207 |
208 | hidden_states_org = F.scaled_dot_product_attention(
209 | query_org, key_org, value_org, dropout_p=0.0, is_causal=False
210 | )
211 | hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
212 | hidden_states_org = hidden_states_org.to(query_org.dtype)
213 |
214 | # Split the attention outputs.
215 | hidden_states_org, encoder_hidden_states_org = (
216 | hidden_states_org[:, : residual.shape[1]],
217 | hidden_states_org[:, residual.shape[1] :],
218 | )
219 |
220 | hidden_states_org_negative = hidden_states_org[-origin_batch_size:]
221 | hidden_states_org_positive = hidden_states_org[-2 * origin_batch_size:-origin_batch_size]
222 | hidden_states_org_guidance = hidden_states_org_positive * self.nag_scale - hidden_states_org_negative * (self.nag_scale - 1)
223 | norm_positive = torch.norm(hidden_states_org_positive, p=1, dim=-1, keepdim=True).expand(*hidden_states_org_positive.shape)
224 | norm_guidance = torch.norm(hidden_states_org_guidance, p=1, dim=-1, keepdim=True).expand(*hidden_states_org_guidance.shape)
225 |
226 | scale = norm_guidance / (norm_positive + 1e-7)
227 | hidden_states_org_guidance = hidden_states_org_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / (scale + 1e-7)
228 |
229 | hidden_states_org_guidance = hidden_states_org_guidance * self.nag_alpha + hidden_states_org_positive * (1 - self.nag_alpha)
230 |
231 | hidden_states_org = torch.cat((hidden_states_org[:origin_batch_size], hidden_states_org_guidance), dim=0)
232 |
233 | # linear proj
234 | hidden_states_org = attn.to_out[0](hidden_states_org)
235 | # dropout
236 | hidden_states_org = attn.to_out[1](hidden_states_org)
237 | if not attn.context_pre_only:
238 | encoder_hidden_states_org = attn.to_add_out(encoder_hidden_states_org)
239 |
240 | if input_ndim == 4:
241 | hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
242 | if context_input_ndim == 4:
243 | encoder_hidden_states_org = encoder_hidden_states_org.transpose(-1, -2).reshape(
244 | batch_size, channel, height, width
245 | )
246 |
247 | ################## perturbed path ##################
248 |
249 | batch_size = encoder_hidden_states_ptb.shape[0]
250 |
251 | # `sample` projections.
252 | query_ptb = attn.to_q(hidden_states_ptb)
253 | key_ptb = attn.to_k(hidden_states_ptb)
254 | value_ptb = attn.to_v(hidden_states_ptb)
255 |
256 | # `context` projections.
257 | encoder_hidden_states_ptb_query_proj = attn.add_q_proj(encoder_hidden_states_ptb)
258 | encoder_hidden_states_ptb_key_proj = attn.add_k_proj(encoder_hidden_states_ptb)
259 | encoder_hidden_states_ptb_value_proj = attn.add_v_proj(encoder_hidden_states_ptb)
260 |
261 | # attention
262 | query_ptb = torch.cat([query_ptb, encoder_hidden_states_ptb_query_proj], dim=1)
263 | key_ptb = torch.cat([key_ptb, encoder_hidden_states_ptb_key_proj], dim=1)
264 | value_ptb = torch.cat([value_ptb, encoder_hidden_states_ptb_value_proj], dim=1)
265 |
266 | inner_dim = key_ptb.shape[-1]
267 | head_dim = inner_dim // attn.heads
268 | query_ptb = query_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
269 | key_ptb = key_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
270 | value_ptb = value_ptb.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
271 |
272 | # create a full mask with all entries set to 0
273 | seq_len = query_ptb.size(2)
274 | full_mask = torch.zeros((seq_len, seq_len), device=query_ptb.device, dtype=query_ptb.dtype)
275 |
276 | # set the attention value between image patches to -inf
277 | full_mask[:identity_block_size, :identity_block_size] = float("-inf")
278 |
279 | # set the diagonal of the attention value between image patches to 0
280 | full_mask[:identity_block_size, :identity_block_size].fill_diagonal_(0)
281 |
282 | # expand the mask to match the attention weights shape
283 | full_mask = full_mask.unsqueeze(0).unsqueeze(0) # Add batch and num_heads dimensions
284 |
285 | hidden_states_ptb = F.scaled_dot_product_attention(
286 | query_ptb, key_ptb, value_ptb, attn_mask=full_mask, dropout_p=0.0, is_causal=False
287 | )
288 | hidden_states_ptb = hidden_states_ptb.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
289 | hidden_states_ptb = hidden_states_ptb.to(query_ptb.dtype)
290 |
291 | # split the attention outputs.
292 | hidden_states_ptb, encoder_hidden_states_ptb = (
293 | hidden_states_ptb[:, : residual.shape[1]],
294 | hidden_states_ptb[:, residual.shape[1] :],
295 | )
296 |
297 | # linear proj
298 | hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
299 | # dropout
300 | hidden_states_ptb = attn.to_out[1](hidden_states_ptb)
301 | if not attn.context_pre_only:
302 | encoder_hidden_states_ptb = attn.to_add_out(encoder_hidden_states_ptb)
303 |
304 | if input_ndim == 4:
305 | hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
306 | if context_input_ndim == 4:
307 | encoder_hidden_states_ptb = encoder_hidden_states_ptb.transpose(-1, -2).reshape(
308 | batch_size, channel, height, width
309 | )
310 |
311 | ################ concat ###############
312 | hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])
313 | encoder_hidden_states = torch.cat([encoder_hidden_states_org[:2 * origin_batch_size], encoder_hidden_states_ptb, encoder_hidden_states_org[-origin_batch_size:]])
314 |
315 | return hidden_states, encoder_hidden_states
--------------------------------------------------------------------------------
/nag/attention_nag.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from diffusers.utils import deprecate
7 | from diffusers.models.attention_processor import Attention
8 |
9 |
10 | class NAGAttnProcessor2_0:
11 | def __init__(self, nag_scale: float = 1.0, nag_tau: float = 2.5, nag_alpha:float = 0.5):
12 | if not hasattr(F, "scaled_dot_product_attention"):
13 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
14 | self.nag_scale = nag_scale
15 | self.nag_tau = nag_tau
16 | self.nag_alpha = nag_alpha
17 |
18 | def __call__(
19 | self,
20 | attn: Attention,
21 | hidden_states: torch.Tensor,
22 | encoder_hidden_states: Optional[torch.Tensor] = None,
23 | attention_mask: Optional[torch.Tensor] = None,
24 | temb: Optional[torch.Tensor] = None,
25 | *args,
26 | **kwargs,
27 | ) -> torch.Tensor:
28 | if len(args) > 0 or kwargs.get("scale", None) is not None:
29 | deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
30 | deprecate("scale", "1.0.0", deprecation_message)
31 |
32 | residual = hidden_states
33 | if attn.spatial_norm is not None:
34 | hidden_states = attn.spatial_norm(hidden_states, temb)
35 |
36 | input_ndim = hidden_states.ndim
37 |
38 | if input_ndim == 4:
39 | batch_size, channel, height, width = hidden_states.shape
40 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
41 |
42 | batch_size, sequence_length, _ = (
43 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
44 | )
45 |
46 | apply_guidance = self.nag_scale > 1 and encoder_hidden_states is not None
47 | if apply_guidance:
48 | origin_batch_size = batch_size - len(hidden_states)
49 | assert batch_size / origin_batch_size in [2, 3, 4]
50 |
51 | if attention_mask is not None:
52 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
53 | # scaled_dot_product_attention expects attention_mask shape to be
54 | # (batch, heads, source_length, target_length)
55 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
56 |
57 | if attn.group_norm is not None:
58 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
59 |
60 | query = attn.to_q(hidden_states)
61 |
62 | if encoder_hidden_states is None:
63 | encoder_hidden_states = hidden_states
64 | elif attn.norm_cross:
65 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
66 |
67 | key = attn.to_k(encoder_hidden_states)
68 | value = attn.to_v(encoder_hidden_states)
69 |
70 | inner_dim = key.shape[-1]
71 | head_dim = inner_dim // attn.heads
72 |
73 | if apply_guidance:
74 | if batch_size == 2 * origin_batch_size:
75 | query = query.tile(2, 1, 1)
76 | else:
77 | query = torch.cat((query, query[origin_batch_size:2 * origin_batch_size]), dim=0)
78 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
79 |
80 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
81 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
82 |
83 | if attn.norm_q is not None:
84 | query = attn.norm_q(query)
85 | if attn.norm_k is not None:
86 | key = attn.norm_k(key)
87 |
88 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
89 | # TODO: add support for attn.scale when we move to Torch 2.1
90 | hidden_states = F.scaled_dot_product_attention(
91 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
92 | )
93 |
94 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
95 | hidden_states = hidden_states.to(query.dtype)
96 |
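# NAG: the last origin_batch_size rows of the batch are the NAG-negative branch;
# blend them with the positive branch, then reassemble the original batch layout.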
97 | if apply_guidance:
98 | hidden_states_negative = hidden_states[-origin_batch_size:]
99 | if batch_size == 2 * origin_batch_size:
100 | hidden_states_positive = hidden_states[:origin_batch_size]
101 | else:
102 | hidden_states_positive = hidden_states[origin_batch_size:2 * origin_batch_size]
103 | hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1)
104 | norm_positive = torch.norm(hidden_states_positive, p=1, dim=-1, keepdim=True).expand(*hidden_states_positive.shape)
105 | norm_guidance = torch.norm(hidden_states_guidance, p=1, dim=-1, keepdim=True).expand(*hidden_states_guidance.shape)
106 |
107 | scale = norm_guidance / norm_positive
108 | hidden_states_guidance = hidden_states_guidance * torch.minimum(scale, scale.new_ones(1) * self.nag_tau) / scale
109 |
110 | hidden_states_guidance = hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha)
111 |
112 | if batch_size == 2 * origin_batch_size:
113 | hidden_states = hidden_states_guidance
114 | elif batch_size == 3 * origin_batch_size:
115 | hidden_states = torch.cat((hidden_states[:origin_batch_size], hidden_states_guidance), dim=0)
116 | elif batch_size == 4 * origin_batch_size:
117 | hidden_states = torch.cat((hidden_states[:origin_batch_size], hidden_states_guidance, hidden_states[2 * origin_batch_size:3 * origin_batch_size]), dim=0)
118 |
119 | # linear proj
120 | hidden_states = attn.to_out[0](hidden_states)
121 | # dropout
122 | hidden_states = attn.to_out[1](hidden_states)
123 |
124 | if input_ndim == 4:
125 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
126 |
127 | if attn.residual_connection:
128 | hidden_states = hidden_states + residual
129 |
130 | hidden_states = hidden_states / attn.rescale_output_factor
131 |
132 | return hidden_states
133 |
--------------------------------------------------------------------------------
/nag/attention_wan_nag.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from diffusers.models.attention_processor import Attention
7 |
8 |
9 |
10 | class NAGWanAttnProcessor2_0:
11 | def __init__(self, nag_scale=1.0, nag_tau=2.5, nag_alpha=0.25):
12 | if not hasattr(F, "scaled_dot_product_attention"):
13 | raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
14 | self.nag_scale = nag_scale
15 | self.nag_tau = nag_tau
16 | self.nag_alpha = nag_alpha
17 |
18 | def __call__(
19 | self,
20 | attn: Attention,
21 | hidden_states: torch.Tensor,
22 | encoder_hidden_states: Optional[torch.Tensor] = None,
23 | attention_mask: Optional[torch.Tensor] = None,
24 | rotary_emb: Optional[torch.Tensor] = None,
25 | ) -> torch.Tensor:
26 | apply_guidance = self.nag_scale > 1 and encoder_hidden_states is not None
27 | if apply_guidance:
28 | if len(encoder_hidden_states) == 2 * len(hidden_states):
29 | batch_size = len(hidden_states)
30 | else:
31 | apply_guidance = False
32 |
33 | encoder_hidden_states_img = None
34 | if attn.add_k_proj is not None:
35 | encoder_hidden_states_img = encoder_hidden_states[:, :257]
36 | encoder_hidden_states = encoder_hidden_states[:, 257:]
37 | if apply_guidance:
38 | encoder_hidden_states_img = encoder_hidden_states_img[:batch_size]
39 | if encoder_hidden_states is None:
40 | encoder_hidden_states = hidden_states
41 |
42 | query = attn.to_q(hidden_states)
43 | key = attn.to_k(encoder_hidden_states)
44 | value = attn.to_v(encoder_hidden_states)
45 |
46 | if attn.norm_q is not None:
47 | query = attn.norm_q(query)
48 | if attn.norm_k is not None:
49 | key = attn.norm_k(key)
50 |
51 | query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
52 | key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
53 | value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
54 |
55 | if rotary_emb is not None:
56 |
57 | def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
58 | x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
59 | x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
60 | return x_out.type_as(hidden_states)
61 |
62 | query = apply_rotary_emb(query, rotary_emb)
63 | key = apply_rotary_emb(key, rotary_emb)
64 |
65 | # I2V task
66 | hidden_states_img = None
67 | if encoder_hidden_states_img is not None:
68 | key_img = attn.add_k_proj(encoder_hidden_states_img)
69 | key_img = attn.norm_added_k(key_img)
70 | value_img = attn.add_v_proj(encoder_hidden_states_img)
71 |
72 | key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
73 | value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
74 |
75 | hidden_states_img = F.scaled_dot_product_attention(
76 | query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
77 | )
78 | hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
79 | hidden_states_img = hidden_states_img.type_as(query)
80 |
81 | if apply_guidance:
82 | key, key_negative = torch.chunk(key, 2, dim=0)
83 | value, value_negative = torch.chunk(value, 2, dim=0)
84 | hidden_states = F.scaled_dot_product_attention(
85 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
86 | )
87 | hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
88 | hidden_states = hidden_states.type_as(query)
89 | if apply_guidance:
90 | hidden_states_negative = F.scaled_dot_product_attention(
91 | query, key_negative, value_negative, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
92 | )
93 | hidden_states_negative = hidden_states_negative.transpose(1, 2).flatten(2, 3)
94 | hidden_states_negative = hidden_states_negative.type_as(query)
95 |
96 | hidden_states_positive = hidden_states
97 |
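# NAG: extrapolate the positive/negative attention outputs, renormalize tokens whose
# L1-norm ratio exceeds nag_tau, then alpha-blend with the positive branch.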
98 | hidden_states_guidance = hidden_states_positive * self.nag_scale - hidden_states_negative * (self.nag_scale - 1)
99 | norm_positive = torch.norm(hidden_states_positive, p=1, dim=-1, keepdim=True).expand(*hidden_states_positive.shape)
100 | norm_guidance = torch.norm(hidden_states_guidance, p=1, dim=-1, keepdim=True).expand(*hidden_states_guidance.shape)
101 |
102 | scale = norm_guidance / norm_positive
103 | scale = torch.nan_to_num(scale, 10)
104 | hidden_states_guidance[scale > self.nag_tau] = \
105 | hidden_states_guidance[scale > self.nag_tau] / (norm_guidance[scale > self.nag_tau] + 1e-7) * norm_positive[scale > self.nag_tau] * self.nag_tau
106 |
107 | hidden_states = hidden_states_guidance * self.nag_alpha + hidden_states_positive * (1 - self.nag_alpha)
108 |
109 | if hidden_states_img is not None:
110 | hidden_states = hidden_states + hidden_states_img
111 |
112 | hidden_states = attn.to_out[0](hidden_states)
113 | hidden_states = attn.to_out[1](hidden_states)
114 | return hidden_states
115 |
--------------------------------------------------------------------------------
/nag/normalization.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Tuple
2 |
3 | import torch
4 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX
5 |
6 |
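# These wrappers truncate the conditioning inputs (timestep / class / pooled embeddings) to the
# batch size of the hidden states, since the NAG pipelines enlarge only the text-conditioning
# batch with the NAG-negative prompt while the latent batch stays unchanged. `forward_old` is
# expected to be bound to the original `forward` when a NAG pipeline swaps these classes in.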
7 | class TruncAdaLayerNorm(AdaLayerNorm):
8 | def forward(
9 | self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
10 | ) -> torch.Tensor:
11 | batch_size = x.shape[0]
12 | return self.forward_old(
13 | x,
14 | temb[:batch_size] if temb is not None else None,
15 | )
16 |
17 |
18 | class TruncAdaLayerNormContinuous(AdaLayerNormContinuous):
19 | def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
20 | batch_size = x.shape[0]
21 | return self.forward_old(x, conditioning_embedding[:batch_size])
22 |
23 |
24 | class TruncAdaLayerNormZero(AdaLayerNormZero):
25 | def forward(
26 | self,
27 | x: torch.Tensor,
28 | timestep: Optional[torch.Tensor] = None,
29 | class_labels: Optional[torch.LongTensor] = None,
30 | hidden_dtype: Optional[torch.dtype] = None,
31 | emb: Optional[torch.Tensor] = None,
32 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
33 | batch_size = x.shape[0]
34 | return self.forward_old(
35 | x,
36 | timestep[:batch_size] if timestep is not None else None,
37 | class_labels[:batch_size] if class_labels is not None else None,
38 | hidden_dtype,
39 | emb[:batch_size] if emb is not None else None,
40 | )
41 |
42 |
43 | class TruncSD35AdaLayerNormZeroX(SD35AdaLayerNormZeroX):
44 | def forward(
45 | self,
46 | hidden_states: torch.Tensor,
47 | emb: Optional[torch.Tensor] = None,
48 | ) -> Tuple[torch.Tensor, ...]:
49 | batch_size = hidden_states.shape[0]
50 | return self.forward_old(
51 | hidden_states,
52 | emb[:batch_size] if emb is not None else None,
53 | )
54 |
--------------------------------------------------------------------------------
/nag/pipeline_flux_kontext_nag.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Union, List, Optional, Dict, Any, Callable
3 | import types
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from diffusers import FluxKontextPipeline
9 | from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
10 | from diffusers.image_processor import PipelineImageInput
11 | from diffusers.utils import is_torch_xla_available, logging
12 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
13 | from diffusers.pipelines.flux.pipeline_flux_kontext import PREFERRED_KONTEXT_RESOLUTIONS
14 | from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormContinuous
15 |
16 | from nag.attention_flux_nag import NAGFluxAttnProcessor2_0
17 | from nag.normalization import TruncAdaLayerNormZero, TruncAdaLayerNormContinuous
18 |
19 | if is_torch_xla_available():
20 | import torch_xla.core.xla_model as xm
21 |
22 | XLA_AVAILABLE = True
23 | else:
24 | XLA_AVAILABLE = False
25 |
26 |
27 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28 |
29 |
30 | class NAGFluxKontextPipeline(FluxKontextPipeline):
31 | @property
32 | def do_normalized_attention_guidance(self):
33 | return self._nag_scale > 1
34 |
35 | def _set_nag_attn_processor(
36 | self,
37 | nag_scale,
38 | encoder_hidden_states_length,
39 | nag_tau=2.5,
40 | nag_alpha=0.25,
41 | ):
42 | attn_procs = {}
43 | for name in self.transformer.attn_processors.keys():
44 | attn_procs[name] = NAGFluxAttnProcessor2_0(
45 | nag_scale=nag_scale,
46 | nag_tau=nag_tau,
47 | nag_alpha=nag_alpha,
48 | encoder_hidden_states_length=encoder_hidden_states_length,
49 | )
50 | self.transformer.set_attn_processor(attn_procs)
51 |
52 | @torch.no_grad()
53 | def __call__(
54 | self,
55 | image: Optional[PipelineImageInput] = None,
56 | prompt: Union[str, List[str]] = None,
57 | prompt_2: Optional[Union[str, List[str]]] = None,
58 | negative_prompt: Union[str, List[str]] = None,
59 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
60 | true_cfg_scale: float = 1.0,
61 | height: Optional[int] = None,
62 | width: Optional[int] = None,
63 | num_inference_steps: int = 28,
64 | sigmas: Optional[List[float]] = None,
65 | guidance_scale: float = 3.5,
66 | num_images_per_prompt: Optional[int] = 1,
67 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
68 | latents: Optional[torch.FloatTensor] = None,
69 | prompt_embeds: Optional[torch.FloatTensor] = None,
70 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
71 | ip_adapter_image: Optional[PipelineImageInput] = None,
72 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
73 | negative_ip_adapter_image: Optional[PipelineImageInput] = None,
74 | negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
75 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
76 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
77 | output_type: Optional[str] = "pil",
78 | return_dict: bool = True,
79 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
80 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
81 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
82 | max_sequence_length: int = 512,
83 | max_area: int = 1024 ** 2,
84 | _auto_resize: bool = True,
85 |
86 | nag_scale: float = 1.0,
87 | nag_tau: float = 2.5,
88 | nag_alpha: float = 0.25,
89 | nag_end: float = 0.25,
90 | nag_negative_prompt: str = None,
91 | nag_negative_prompt_2: str = None,
92 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None,
93 | nag_negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
94 | ):
95 | r"""
96 | Function invoked when calling the pipeline for generation.
97 |
98 | Args:
99 | image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
100 | `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
101 |                 numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a list
102 |                 of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
103 |                 list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
104 |                 latents as `image`, but if passing latents directly they are not encoded again.
105 | prompt (`str` or `List[str]`, *optional*):
106 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
107 | instead.
108 | prompt_2 (`str` or `List[str]`, *optional*):
109 |                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
110 | will be used instead.
111 | negative_prompt (`str` or `List[str]`, *optional*):
112 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
113 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
114 | not greater than `1`).
115 | negative_prompt_2 (`str` or `List[str]`, *optional*):
116 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
117 | `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
118 | true_cfg_scale (`float`, *optional*, defaults to 1.0):
119 |                 When greater than 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
120 |             height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
121 |                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
122 |             width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
123 |                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
124 |             num_inference_steps (`int`, *optional*, defaults to 28):
125 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
126 | expense of slower inference.
127 | sigmas (`List[float]`, *optional*):
128 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
129 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
130 | will be used.
131 | guidance_scale (`float`, *optional*, defaults to 3.5):
132 | Guidance scale as defined in [Classifier-Free Diffusion
133 |                 Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` in equation 2
134 |                 of the [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
135 |                 `guidance_scale > 1`. A higher guidance scale encourages the model to generate images closely linked to
136 |                 the text `prompt`, usually at the expense of lower image quality.
137 | num_images_per_prompt (`int`, *optional*, defaults to 1):
138 | The number of images to generate per prompt.
139 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
140 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
141 | to make generation deterministic.
142 | latents (`torch.FloatTensor`, *optional*):
143 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
144 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
145 |                 tensor will be generated by sampling using the supplied random `generator`.
146 | prompt_embeds (`torch.FloatTensor`, *optional*):
147 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
148 | provided, text embeddings will be generated from `prompt` input argument.
149 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
150 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
151 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
152 |             ip_adapter_image (`PipelineImageInput`, *optional*):
153 | Optional image input to work with IP Adapters.
154 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
155 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
156 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
157 | provided, embeddings are computed from the `ip_adapter_image` input argument.
158 |             negative_ip_adapter_image (`PipelineImageInput`, *optional*):
159 |                 Optional image input to work with IP Adapters.
160 | negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
161 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
162 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
163 | provided, embeddings are computed from the `ip_adapter_image` input argument.
164 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
165 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
166 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
167 | argument.
168 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
169 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
170 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
171 | input argument.
172 | output_type (`str`, *optional*, defaults to `"pil"`):
173 |                 The output format of the generated image. Choose between
174 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
175 | return_dict (`bool`, *optional*, defaults to `True`):
176 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
177 | joint_attention_kwargs (`dict`, *optional*):
178 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
179 | `self.processor` in
180 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
181 | callback_on_step_end (`Callable`, *optional*):
182 |                 A function that is called at the end of each denoising step during inference. The function is called
183 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
184 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
185 | `callback_on_step_end_tensor_inputs`.
186 | callback_on_step_end_tensor_inputs (`List`, *optional*):
187 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
188 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
189 | `._callback_tensor_inputs` attribute of your pipeline class.
190 |             max_sequence_length (`int`, *optional*, defaults to 512):
191 | Maximum sequence length to use with the `prompt`.
192 | max_area (`int`, defaults to `1024 ** 2`):
193 | The maximum area of the generated image in pixels. The height and width will be adjusted to fit this
194 | area while maintaining the aspect ratio.
195 |
196 | Examples:
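            A minimal, illustrative sketch of NAG-guided editing with this pipeline (the checkpoint id,
            dtype, image URL, and NAG settings below are assumptions, not taken from this repository):

            ```py
            >>> import torch
            >>> from diffusers.utils import load_image
            >>> from nag.pipeline_flux_kontext_nag import NAGFluxKontextPipeline

            >>> pipe = NAGFluxKontextPipeline.from_pretrained(
            ...     "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
            ... ).to("cuda")
            >>> input_image = load_image("https://example.com/input.png")  # placeholder URL
            >>> image = pipe(
            ...     image=input_image,
            ...     prompt="Turn the scene into a watercolor painting",
            ...     nag_scale=5.0,  # NAG is active only when nag_scale > 1
            ...     nag_negative_prompt="low quality, blurry",
            ... ).images[0]
            ```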
197 |
198 | Returns:
199 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
200 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
201 | images.
202 | """
203 |
204 | height = height or self.default_sample_size * self.vae_scale_factor
205 | width = width or self.default_sample_size * self.vae_scale_factor
206 |
207 | original_height, original_width = height, width
208 | aspect_ratio = width / height
209 | width = round((max_area * aspect_ratio) ** 0.5)
210 | height = round((max_area / aspect_ratio) ** 0.5)
211 |
212 | multiple_of = self.vae_scale_factor * 2
213 | width = width // multiple_of * multiple_of
214 | height = height // multiple_of * multiple_of
215 |
216 | if height != original_height or width != original_width:
217 | logger.warning(
218 | f"Generation `height` and `width` have been adjusted to {height} and {width} to fit the model requirements."
219 | )
220 |
221 | # 1. Check inputs. Raise error if not correct
222 | self.check_inputs(
223 | prompt,
224 | prompt_2,
225 | height,
226 | width,
227 | negative_prompt=negative_prompt,
228 | negative_prompt_2=negative_prompt_2,
229 | prompt_embeds=prompt_embeds,
230 | negative_prompt_embeds=negative_prompt_embeds,
231 | pooled_prompt_embeds=pooled_prompt_embeds,
232 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
233 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
234 | max_sequence_length=max_sequence_length,
235 | )
236 |
237 | self._guidance_scale = guidance_scale
238 | self._joint_attention_kwargs = joint_attention_kwargs
239 | self._current_timestep = None
240 | self._interrupt = False
241 | self._nag_scale = nag_scale
242 |
243 | # 2. Define call parameters
244 | if prompt is not None and isinstance(prompt, str):
245 | batch_size = 1
246 | elif prompt is not None and isinstance(prompt, list):
247 | batch_size = len(prompt)
248 | else:
249 | batch_size = prompt_embeds.shape[0]
250 |
251 | device = self._execution_device
252 |
253 | lora_scale = (
254 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
255 | )
256 | has_neg_prompt = negative_prompt is not None or (
257 | negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
258 | )
259 | do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
260 | (
261 | prompt_embeds,
262 | pooled_prompt_embeds,
263 | text_ids,
264 | ) = self.encode_prompt(
265 | prompt=prompt,
266 | prompt_2=prompt_2,
267 | prompt_embeds=prompt_embeds,
268 | pooled_prompt_embeds=pooled_prompt_embeds,
269 | device=device,
270 | num_images_per_prompt=num_images_per_prompt,
271 | max_sequence_length=max_sequence_length,
272 | lora_scale=lora_scale,
273 | )
274 | if do_true_cfg:
275 | (
276 | negative_prompt_embeds,
277 | negative_pooled_prompt_embeds,
278 | negative_text_ids,
279 | ) = self.encode_prompt(
280 | prompt=negative_prompt,
281 | prompt_2=negative_prompt_2,
282 | prompt_embeds=negative_prompt_embeds,
283 | pooled_prompt_embeds=negative_pooled_prompt_embeds,
284 | device=device,
285 | num_images_per_prompt=num_images_per_prompt,
286 | max_sequence_length=max_sequence_length,
287 | lora_scale=lora_scale,
288 | )
289 |
290 | if self.do_normalized_attention_guidance:
291 | if nag_negative_prompt_embeds is None or nag_negative_pooled_prompt_embeds is None:
292 | if nag_negative_prompt is None:
293 | if negative_prompt is not None:
294 | if do_true_cfg:
295 | nag_negative_prompt_embeds = negative_prompt_embeds
296 | nag_negative_pooled_prompt_embeds = negative_pooled_prompt_embeds
297 | else:
298 | nag_negative_prompt = negative_prompt
299 | nag_negative_prompt_2 = negative_prompt_2
300 | else:
301 | nag_negative_prompt = ""
302 |
303 | if nag_negative_prompt is not None:
304 | nag_negative_prompt_embeds, nag_negative_pooled_prompt_embeds = self.encode_prompt(
305 | prompt=nag_negative_prompt,
306 | prompt_2=nag_negative_prompt_2,
307 | device=device,
308 | num_images_per_prompt=num_images_per_prompt,
309 | max_sequence_length=max_sequence_length,
310 | lora_scale=lora_scale,
311 | )[:2]
312 |
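        # When NAG is active, the NAG negative embeddings are appended after the positive ones,
        # so the transformer receives an enlarged text batch while the image latents keep their
        # original batch size.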
313 | if self.do_normalized_attention_guidance:
314 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, nag_negative_pooled_prompt_embeds], dim=0)
315 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0)
316 |
317 | # 3. Preprocess image
318 | if image is not None and not (isinstance(image, torch.Tensor) and image.size(1) == self.latent_channels):
319 | img = image[0] if isinstance(image, list) else image
320 | image_height, image_width = self.image_processor.get_default_height_width(img)
321 | aspect_ratio = image_width / image_height
322 | if _auto_resize:
323 | # Kontext is trained on specific resolutions, using one of them is recommended
324 | _, image_width, image_height = min(
325 | (abs(aspect_ratio - w / h), w, h) for w, h in PREFERRED_KONTEXT_RESOLUTIONS
326 | )
327 | image_width = image_width // multiple_of * multiple_of
328 | image_height = image_height // multiple_of * multiple_of
329 | image = self.image_processor.resize(image, image_height, image_width)
330 | image = self.image_processor.preprocess(image, image_height, image_width)
331 |
332 | # 4. Prepare latent variables
333 | num_channels_latents = self.transformer.config.in_channels // 4
334 | latents, image_latents, latent_ids, image_ids = self.prepare_latents(
335 | image,
336 | batch_size * num_images_per_prompt,
337 | num_channels_latents,
338 | height,
339 | width,
340 | prompt_embeds.dtype,
341 | device,
342 | generator,
343 | latents,
344 | )
345 | if image_ids is not None:
346 | latent_ids = torch.cat([latent_ids, image_ids], dim=0) # dim 0 is sequence dimension
347 |
348 | # 5. Prepare timesteps
349 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
350 | image_seq_len = latents.shape[1]
351 | mu = calculate_shift(
352 | image_seq_len,
353 | self.scheduler.config.get("base_image_seq_len", 256),
354 | self.scheduler.config.get("max_image_seq_len", 4096),
355 | self.scheduler.config.get("base_shift", 0.5),
356 | self.scheduler.config.get("max_shift", 1.15),
357 | )
358 | timesteps, num_inference_steps = retrieve_timesteps(
359 | self.scheduler,
360 | num_inference_steps,
361 | device,
362 | sigmas=sigmas,
363 | mu=mu,
364 | )
365 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
366 | self._num_timesteps = len(timesteps)
367 |
368 | # handle guidance
369 | if self.transformer.config.guidance_embeds:
370 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
371 | guidance = guidance.expand(prompt_embeds.shape[0])
372 | else:
373 | guidance = None
374 |
375 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
376 | negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
377 | ):
378 | negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
379 | negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
380 |
381 | elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
382 | negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
383 | ):
384 | ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
385 | ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters
386 |
387 | if self.joint_attention_kwargs is None:
388 | self._joint_attention_kwargs = {}
389 |
390 | image_embeds = None
391 | negative_image_embeds = None
392 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
393 | image_embeds = self.prepare_ip_adapter_image_embeds(
394 | ip_adapter_image,
395 | ip_adapter_image_embeds,
396 | device,
397 | batch_size * num_images_per_prompt,
398 | )
399 | if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
400 | negative_image_embeds = self.prepare_ip_adapter_image_embeds(
401 | negative_ip_adapter_image,
402 | negative_ip_adapter_image_embeds,
403 | device,
404 | batch_size * num_images_per_prompt,
405 | )
406 |
407 | origin_attn_procs = self.transformer.attn_processors
408 | if self.do_normalized_attention_guidance:
409 | self._set_nag_attn_processor(nag_scale, prompt_embeds.shape[1], nag_tau, nag_alpha)
410 | attn_procs_recovered = False
411 |
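        # The enlarged NAG text batch means the conditioning embedding has more entries than the
        # hidden states. The Trunc* forwards from nag/normalization.py slice the conditioning
        # tensors down to the hidden-state batch size so the norm layers keep matching shapes.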
412 | for sub_mod in self.transformer.modules():
413 |             if not hasattr(sub_mod, "forward_old"):
414 | sub_mod.forward_old = sub_mod.forward
415 | if isinstance(sub_mod, AdaLayerNormZero):
416 | sub_mod.forward = types.MethodType(TruncAdaLayerNormZero.forward, sub_mod)
417 | elif isinstance(sub_mod, AdaLayerNormContinuous):
418 | sub_mod.forward = types.MethodType(TruncAdaLayerNormContinuous.forward, sub_mod)
419 |
420 |
421 | # 6. Denoising loop
422 | # We set the index here to remove DtoH sync, helpful especially during compilation.
423 | # Check out more details here: https://github.com/huggingface/diffusers/pull/11696
424 | self.scheduler.set_begin_index(0)
425 | with self.progress_bar(total=num_inference_steps) as progress_bar:
426 | for i, t in enumerate(timesteps):
427 | if self.interrupt:
428 | continue
429 |
430 | if t < (1 - nag_end) * 1000 and self.do_normalized_attention_guidance and not attn_procs_recovered:
431 | self.transformer.set_attn_processor(origin_attn_procs)
432 | if guidance is not None:
433 | guidance = guidance[:len(latents)]
434 | pooled_prompt_embeds = pooled_prompt_embeds[:len(latents)]
435 | prompt_embeds = prompt_embeds[:len(latents)]
436 | attn_procs_recovered = True
437 |
438 | self._current_timestep = t
439 | if image_embeds is not None:
440 | self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
441 |
442 | latent_model_input = latents
443 | if image_latents is not None:
444 | latent_model_input = torch.cat([latents, image_latents], dim=1)
445 | timestep = t.expand(prompt_embeds.shape[0]).to(latents.dtype)
446 |
447 | noise_pred = self.transformer(
448 | hidden_states=latent_model_input,
449 | timestep=timestep / 1000,
450 | guidance=guidance,
451 | pooled_projections=pooled_prompt_embeds,
452 | encoder_hidden_states=prompt_embeds,
453 | txt_ids=text_ids,
454 | img_ids=latent_ids,
455 | joint_attention_kwargs=self.joint_attention_kwargs,
456 | return_dict=False,
457 | )[0]
458 | noise_pred = noise_pred[:, : latents.size(1)]
459 |
460 | if do_true_cfg:
461 | if negative_image_embeds is not None:
462 | self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
463 | neg_noise_pred = self.transformer(
464 | hidden_states=latent_model_input,
465 | timestep=timestep / 1000,
466 | guidance=guidance,
467 | pooled_projections=negative_pooled_prompt_embeds,
468 | encoder_hidden_states=negative_prompt_embeds,
469 | txt_ids=negative_text_ids,
470 | img_ids=latent_ids,
471 | joint_attention_kwargs=self.joint_attention_kwargs,
472 | return_dict=False,
473 | )[0]
474 | neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
475 | noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
476 |
477 | # compute the previous noisy sample x_t -> x_t-1
478 | latents_dtype = latents.dtype
479 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
480 |
481 | if latents.dtype != latents_dtype:
482 | if torch.backends.mps.is_available():
483 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
484 | latents = latents.to(latents_dtype)
485 |
486 | if callback_on_step_end is not None:
487 | callback_kwargs = {}
488 | for k in callback_on_step_end_tensor_inputs:
489 | callback_kwargs[k] = locals()[k]
490 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
491 |
492 | latents = callback_outputs.pop("latents", latents)
493 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
494 |
495 | # call the callback, if provided
496 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
497 | progress_bar.update()
498 |
499 | if XLA_AVAILABLE:
500 | xm.mark_step()
501 |
502 | self._current_timestep = None
503 |
504 | if output_type == "latent":
505 | image = latents
506 | else:
507 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
508 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
509 | image = self.vae.decode(latents, return_dict=False)[0]
510 | image = self.image_processor.postprocess(image, output_type=output_type)
511 |
512 | if self.do_normalized_attention_guidance and not attn_procs_recovered:
513 | self.transformer.set_attn_processor(origin_attn_procs)
514 |
515 | # Offload all models
516 | self.maybe_free_model_hooks()
517 |
518 | if not return_dict:
519 | return (image,)
520 |
521 | return FluxPipelineOutput(images=image)
522 |
--------------------------------------------------------------------------------
/nag/pipeline_flux_nag.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Union, List, Optional, Dict, Any, Callable
3 | import types
4 |
5 | import numpy as np
6 | import torch
7 |
8 | from diffusers import FluxPipeline
9 | from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
10 | from diffusers.image_processor import PipelineImageInput
11 | from diffusers.utils import is_torch_xla_available
12 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
13 | from diffusers.models.normalization import AdaLayerNormZero, AdaLayerNormContinuous
14 |
15 | from nag.attention_flux_nag import NAGFluxAttnProcessor2_0
16 | from nag.normalization import TruncAdaLayerNormZero, TruncAdaLayerNormContinuous
17 |
18 | if is_torch_xla_available():
19 | import torch_xla.core.xla_model as xm
20 |
21 | XLA_AVAILABLE = True
22 | else:
23 | XLA_AVAILABLE = False
24 |
25 |
26 | class NAGFluxPipeline(FluxPipeline):
27 | @property
28 | def do_normalized_attention_guidance(self):
29 | return self._nag_scale > 1
30 |
31 | def _set_nag_attn_processor(
32 | self,
33 | nag_scale,
34 | encoder_hidden_states_length,
35 | nag_tau=2.5,
36 | nag_alpha=0.25,
37 | ):
38 | attn_procs = {}
39 | for name in self.transformer.attn_processors.keys():
40 | attn_procs[name] = NAGFluxAttnProcessor2_0(
41 | nag_scale=nag_scale,
42 | nag_tau=nag_tau,
43 | nag_alpha=nag_alpha,
44 | encoder_hidden_states_length=encoder_hidden_states_length,
45 | )
46 | self.transformer.set_attn_processor(attn_procs)
47 |
48 | @torch.no_grad()
49 | def __call__(
50 | self,
51 | prompt: Union[str, List[str]] = None,
52 | prompt_2: Optional[Union[str, List[str]]] = None,
53 | negative_prompt: Union[str, List[str]] = None,
54 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
55 | true_cfg_scale: float = 1.0,
56 | height: Optional[int] = None,
57 | width: Optional[int] = None,
58 | num_inference_steps: int = 28,
59 | sigmas: Optional[List[float]] = None,
60 | guidance_scale: float = 3.5,
61 | num_images_per_prompt: Optional[int] = 1,
62 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
63 | latents: Optional[torch.FloatTensor] = None,
64 | prompt_embeds: Optional[torch.FloatTensor] = None,
65 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
66 | ip_adapter_image: Optional[PipelineImageInput] = None,
67 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
68 | negative_ip_adapter_image: Optional[PipelineImageInput] = None,
69 | negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
70 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
71 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
72 | output_type: Optional[str] = "pil",
73 | return_dict: bool = True,
74 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
75 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
76 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
77 | max_sequence_length: int = 512,
78 |
79 | nag_scale: float = 1.0,
80 | nag_tau: float = 2.5,
81 | nag_alpha: float = 0.25,
82 | nag_end: float = 1.0,
83 | nag_negative_prompt: str = None,
84 | nag_negative_prompt_2: str = None,
85 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None,
86 | nag_negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
87 | ):
88 | r"""
89 | Function invoked when calling the pipeline for generation.
90 |
91 | Args:
92 | prompt (`str` or `List[str]`, *optional*):
 93 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
94 | instead.
95 | prompt_2 (`str` or `List[str]`, *optional*):
 96 |                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
97 | will be used instead.
98 | negative_prompt (`str` or `List[str]`, *optional*):
99 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
100 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
101 | not greater than `1`).
102 | negative_prompt_2 (`str` or `List[str]`, *optional*):
103 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
104 | `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
105 | true_cfg_scale (`float`, *optional*, defaults to 1.0):
106 |                 When greater than 1.0 and a `negative_prompt` is provided, true classifier-free guidance is enabled.
107 |             height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
108 |                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
109 |             width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
110 |                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
111 |             num_inference_steps (`int`, *optional*, defaults to 28):
112 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
113 | expense of slower inference.
114 | sigmas (`List[float]`, *optional*):
115 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
116 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
117 | will be used.
118 |             guidance_scale (`float`, *optional*, defaults to 3.5):
119 |                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
120 |                 `guidance_scale` is defined as `w` in equation 2 of the [Imagen
121 |                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
122 |                 1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
123 |                 usually at the expense of lower image quality.
124 | num_images_per_prompt (`int`, *optional*, defaults to 1):
125 | The number of images to generate per prompt.
126 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
127 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
128 | to make generation deterministic.
129 | latents (`torch.FloatTensor`, *optional*):
130 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
131 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
132 |                 tensor will be generated by sampling using the supplied random `generator`.
133 | prompt_embeds (`torch.FloatTensor`, *optional*):
134 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
135 | provided, text embeddings will be generated from `prompt` input argument.
136 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
137 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
138 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
139 |             ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
140 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
141 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
142 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
143 | provided, embeddings are computed from the `ip_adapter_image` input argument.
144 |             negative_ip_adapter_image (`PipelineImageInput`, *optional*):
145 |                 Optional image input to work with IP Adapters.
146 | negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
147 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
148 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
149 | provided, embeddings are computed from the `ip_adapter_image` input argument.
150 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
151 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
152 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
153 | argument.
154 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
155 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
156 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
157 | input argument.
158 | output_type (`str`, *optional*, defaults to `"pil"`):
159 |                 The output format of the generated image. Choose between
160 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
161 | return_dict (`bool`, *optional*, defaults to `True`):
162 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
163 | joint_attention_kwargs (`dict`, *optional*):
164 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
165 | `self.processor` in
166 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
167 | callback_on_step_end (`Callable`, *optional*):
168 |                 A function that is called at the end of each denoising step during inference. The function is called
169 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
170 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
171 | `callback_on_step_end_tensor_inputs`.
172 | callback_on_step_end_tensor_inputs (`List`, *optional*):
173 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
174 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
175 | `._callback_tensor_inputs` attribute of your pipeline class.
176 |             max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length to use with the `prompt`.
177 |
178 | Examples:
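            A minimal, illustrative sketch of NAG-guided text-to-image generation (the checkpoint id,
            dtype, and NAG settings below are assumptions, not taken from this repository):

            ```py
            >>> import torch
            >>> from nag.pipeline_flux_nag import NAGFluxPipeline

            >>> pipe = NAGFluxPipeline.from_pretrained(
            ...     "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
            ... ).to("cuda")
            >>> image = pipe(
            ...     prompt="A cinematic photo of a red fox in the snow",
            ...     guidance_scale=3.5,
            ...     num_inference_steps=28,
            ...     nag_scale=5.0,  # NAG is active only when nag_scale > 1
            ...     nag_negative_prompt="low resolution, blurry",
            ... ).images[0]
            ```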
179 |
180 | Returns:
181 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
182 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
183 | images.
184 | """
185 |
186 | height = height or self.default_sample_size * self.vae_scale_factor
187 | width = width or self.default_sample_size * self.vae_scale_factor
188 |
189 | # 1. Check inputs. Raise error if not correct
190 | self.check_inputs(
191 | prompt,
192 | prompt_2,
193 | height,
194 | width,
195 | negative_prompt=negative_prompt,
196 | negative_prompt_2=negative_prompt_2,
197 | prompt_embeds=prompt_embeds,
198 | negative_prompt_embeds=negative_prompt_embeds,
199 | pooled_prompt_embeds=pooled_prompt_embeds,
200 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
201 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
202 | max_sequence_length=max_sequence_length,
203 | )
204 |
205 | self._guidance_scale = guidance_scale
206 | self._joint_attention_kwargs = joint_attention_kwargs
207 | self._current_timestep = None
208 | self._interrupt = False
209 | self._nag_scale = nag_scale
210 |
211 | # 2. Define call parameters
212 | if prompt is not None and isinstance(prompt, str):
213 | batch_size = 1
214 | elif prompt is not None and isinstance(prompt, list):
215 | batch_size = len(prompt)
216 | else:
217 | batch_size = prompt_embeds.shape[0]
218 |
219 | device = self._execution_device
220 |
221 | lora_scale = (
222 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
223 | )
224 | has_neg_prompt = negative_prompt is not None or (
225 | negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None
226 | )
227 | do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
228 | # do_true_cfg = do_true_cfg or self.do_normalized_attention_guidance
229 |
230 | (
231 | prompt_embeds,
232 | pooled_prompt_embeds,
233 | text_ids,
234 | ) = self.encode_prompt(
235 | prompt=prompt,
236 | prompt_2=prompt_2,
237 | prompt_embeds=prompt_embeds,
238 | pooled_prompt_embeds=pooled_prompt_embeds,
239 | device=device,
240 | num_images_per_prompt=num_images_per_prompt,
241 | max_sequence_length=max_sequence_length,
242 | lora_scale=lora_scale,
243 | )
244 | if do_true_cfg:
245 | (
246 | negative_prompt_embeds,
247 | negative_pooled_prompt_embeds,
248 | _,
249 | ) = self.encode_prompt(
250 | prompt=negative_prompt,
251 | prompt_2=negative_prompt_2,
252 | prompt_embeds=negative_prompt_embeds,
253 | pooled_prompt_embeds=negative_pooled_prompt_embeds,
254 | device=device,
255 | num_images_per_prompt=num_images_per_prompt,
256 | max_sequence_length=max_sequence_length,
257 | lora_scale=lora_scale,
258 | )
259 |
260 | if self.do_normalized_attention_guidance:
261 | if nag_negative_prompt_embeds is None or nag_negative_pooled_prompt_embeds is None:
262 | if nag_negative_prompt is None:
263 | if negative_prompt is not None:
264 | if do_true_cfg:
265 | nag_negative_prompt_embeds = negative_prompt_embeds
266 | nag_negative_pooled_prompt_embeds = negative_pooled_prompt_embeds
267 | else:
268 | nag_negative_prompt = negative_prompt
269 | nag_negative_prompt_2 = negative_prompt_2
270 | else:
271 | nag_negative_prompt = ""
272 |
273 | if nag_negative_prompt is not None:
274 | nag_negative_prompt_embeds, nag_negative_pooled_prompt_embeds = self.encode_prompt(
275 | prompt=nag_negative_prompt,
276 | prompt_2=nag_negative_prompt_2,
277 | device=device,
278 | num_images_per_prompt=num_images_per_prompt,
279 | max_sequence_length=max_sequence_length,
280 | lora_scale=lora_scale,
281 | )[:2]
282 |
283 | if self.do_normalized_attention_guidance:
284 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, nag_negative_pooled_prompt_embeds], dim=0)
285 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0)
286 |
287 | # 4. Prepare latent variables
288 | num_channels_latents = self.transformer.config.in_channels // 4
289 | latents, latent_image_ids = self.prepare_latents(
290 | batch_size * num_images_per_prompt,
291 | num_channels_latents,
292 | height,
293 | width,
294 | prompt_embeds.dtype,
295 | device,
296 | generator,
297 | latents,
298 | )
299 |
300 | # 5. Prepare timesteps
301 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
302 | image_seq_len = latents.shape[1]
303 | mu = calculate_shift(
304 | image_seq_len,
305 | self.scheduler.config.get("base_image_seq_len", 256),
306 | self.scheduler.config.get("max_image_seq_len", 4096),
307 | self.scheduler.config.get("base_shift", 0.5),
308 | self.scheduler.config.get("max_shift", 1.16),
309 | )
310 | timesteps, num_inference_steps = retrieve_timesteps(
311 | self.scheduler,
312 | num_inference_steps,
313 | device,
314 | sigmas=sigmas,
315 | mu=mu,
316 | )
317 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
318 | self._num_timesteps = len(timesteps)
319 |
320 | # handle guidance
321 | if self.transformer.config.guidance_embeds:
322 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
323 | guidance = guidance.expand(prompt_embeds.shape[0])
324 | else:
325 | guidance = None
326 |
327 | if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
328 | negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
329 | ):
330 | negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
331 | elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
332 | negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
333 | ):
334 | ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
335 |
336 | if self.joint_attention_kwargs is None:
337 | self._joint_attention_kwargs = {}
338 |
339 | image_embeds = None
340 | negative_image_embeds = None
341 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
342 | image_embeds = self.prepare_ip_adapter_image_embeds(
343 | ip_adapter_image,
344 | ip_adapter_image_embeds,
345 | device,
346 | batch_size * num_images_per_prompt,
347 | )
348 | if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
349 | negative_image_embeds = self.prepare_ip_adapter_image_embeds(
350 | negative_ip_adapter_image,
351 | negative_ip_adapter_image_embeds,
352 | device,
353 | batch_size * num_images_per_prompt,
354 | )
355 |
356 | origin_attn_procs = self.transformer.attn_processors
357 | if self.do_normalized_attention_guidance:
358 | self._set_nag_attn_processor(nag_scale, prompt_embeds.shape[1], nag_tau, nag_alpha)
359 | attn_procs_recovered = False
360 |
361 | for sub_mod in self.transformer.modules():
362 |             if not hasattr(sub_mod, "forward_old"):
363 | sub_mod.forward_old = sub_mod.forward
364 | if isinstance(sub_mod, AdaLayerNormZero):
365 | sub_mod.forward = types.MethodType(TruncAdaLayerNormZero.forward, sub_mod)
366 | elif isinstance(sub_mod, AdaLayerNormContinuous):
367 | sub_mod.forward = types.MethodType(TruncAdaLayerNormContinuous.forward, sub_mod)
368 |
369 | # 6. Denoising loop
370 | with self.progress_bar(total=num_inference_steps) as progress_bar:
371 | for i, t in enumerate(timesteps):
372 | if self.interrupt:
373 | continue
374 |
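                # Once the timestep drops below (1 - nag_end) * 1000, NAG is switched off: the
                # original attention processors are restored and the extra NAG-negative
                # embeddings are trimmed away.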
375 | if t < (1 - nag_end) * 1000 and self.do_normalized_attention_guidance and not attn_procs_recovered:
376 | self.transformer.set_attn_processor(origin_attn_procs)
377 | if guidance is not None:
378 | guidance = guidance[:len(latents)]
379 | pooled_prompt_embeds = pooled_prompt_embeds[:len(latents)]
380 | prompt_embeds = prompt_embeds[:len(latents)]
381 | attn_procs_recovered = True
382 |
383 | self._current_timestep = t
384 | if image_embeds is not None:
385 | self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
386 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
387 | timestep = t.expand(prompt_embeds.shape[0]).to(latents.dtype)
388 |
389 | noise_pred = self.transformer(
390 | hidden_states=latents,
391 | timestep=timestep / 1000,
392 | guidance=guidance,
393 | pooled_projections=pooled_prompt_embeds,
394 | encoder_hidden_states=prompt_embeds,
395 | txt_ids=text_ids,
396 | img_ids=latent_image_ids,
397 | joint_attention_kwargs=self.joint_attention_kwargs,
398 | return_dict=False,
399 | )[0]
400 |
401 | if do_true_cfg:
402 | if negative_image_embeds is not None:
403 | self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
404 | neg_noise_pred = self.transformer(
405 | hidden_states=latents,
406 | timestep=timestep / 1000,
407 | guidance=guidance,
408 | pooled_projections=negative_pooled_prompt_embeds,
409 | encoder_hidden_states=negative_prompt_embeds,
410 | txt_ids=text_ids,
411 | img_ids=latent_image_ids,
412 | joint_attention_kwargs=self.joint_attention_kwargs,
413 | return_dict=False,
414 | )[0]
415 | noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
416 |
417 | # compute the previous noisy sample x_t -> x_t-1
418 | latents_dtype = latents.dtype
419 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
420 |
421 | if latents.dtype != latents_dtype:
422 | if torch.backends.mps.is_available():
423 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
424 | latents = latents.to(latents_dtype)
425 |
426 | if callback_on_step_end is not None:
427 | callback_kwargs = {}
428 | for k in callback_on_step_end_tensor_inputs:
429 | callback_kwargs[k] = locals()[k]
430 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
431 |
432 | latents = callback_outputs.pop("latents", latents)
433 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
434 |
435 | # call the callback, if provided
436 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
437 | progress_bar.update()
438 |
439 | if XLA_AVAILABLE:
440 | xm.mark_step()
441 |
442 | self._current_timestep = None
443 |
444 | if output_type == "latent":
445 | image = latents
446 | else:
447 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
448 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
449 | image = self.vae.decode(latents, return_dict=False)[0]
450 | image = self.image_processor.postprocess(image, output_type=output_type)
451 |
452 | if self.do_normalized_attention_guidance and not attn_procs_recovered:
453 | self.transformer.set_attn_processor(origin_attn_procs)
454 |
455 | # Offload all models
456 | self.maybe_free_model_hooks()
457 |
458 | if not return_dict:
459 | return (image,)
460 |
461 | return FluxPipelineOutput(images=image)
--------------------------------------------------------------------------------
/nag/pipeline_sd3_nag.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Optional, Union
2 | import types
3 |
4 | import torch
5 |
6 | from diffusers.image_processor import PipelineImageInput
7 | from diffusers.utils import (
8 | is_torch_xla_available,
9 | logging,
10 | )
11 | from diffusers.pipelines.stable_diffusion_3.pipeline_output import StableDiffusion3PipelineOutput
12 | from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import (
13 | StableDiffusion3Pipeline,
14 | retrieve_timesteps,
15 | calculate_shift,
16 | )
17 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, SD35AdaLayerNormZeroX
18 |
19 | from nag.attention_joint_nag import NAGJointAttnProcessor2_0
20 | from nag.normalization import TruncAdaLayerNorm, TruncAdaLayerNormZero, TruncAdaLayerNormContinuous, TruncSD35AdaLayerNormZeroX
21 |
22 |
23 | if is_torch_xla_available():
24 | import torch_xla.core.xla_model as xm
25 |
26 | XLA_AVAILABLE = True
27 | else:
28 | XLA_AVAILABLE = False
29 |
30 |
31 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32 |
33 |
34 | class NAGStableDiffusion3Pipeline(StableDiffusion3Pipeline):
35 | @property
36 | def do_normalized_attention_guidance(self):
37 | return self._nag_scale > 1
38 |
39 | def _set_nag_attn_processor(self, nag_scale, nag_tau, nag_alpha):
40 | attn_procs = {}
41 | for name in self.transformer.attn_processors.keys():
42 | attn_procs[name] = NAGJointAttnProcessor2_0(nag_scale=nag_scale, nag_tau=nag_tau, nag_alpha=nag_alpha)
43 | self.transformer.set_attn_processor(attn_procs)
44 |
45 | @torch.no_grad()
46 | def __call__(
47 | self,
48 | prompt: Union[str, List[str]] = None,
49 | prompt_2: Optional[Union[str, List[str]]] = None,
50 | prompt_3: Optional[Union[str, List[str]]] = None,
51 | height: Optional[int] = None,
52 | width: Optional[int] = None,
53 | num_inference_steps: int = 28,
54 | sigmas: Optional[List[float]] = None,
55 | guidance_scale: float = 7.0,
56 | negative_prompt: Optional[Union[str, List[str]]] = None,
57 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
58 | negative_prompt_3: Optional[Union[str, List[str]]] = None,
59 | num_images_per_prompt: Optional[int] = 1,
60 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
61 | latents: Optional[torch.FloatTensor] = None,
62 | prompt_embeds: Optional[torch.FloatTensor] = None,
63 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
64 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
65 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
66 | ip_adapter_image: Optional[PipelineImageInput] = None,
67 | ip_adapter_image_embeds: Optional[torch.Tensor] = None,
68 | output_type: Optional[str] = "pil",
69 | return_dict: bool = True,
70 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
71 | clip_skip: Optional[int] = None,
72 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
73 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
74 | max_sequence_length: int = 256,
75 | skip_guidance_layers: List[int] = None,
76 | skip_layer_guidance_scale: float = 2.8,
77 | skip_layer_guidance_stop: float = 0.2,
78 | skip_layer_guidance_start: float = 0.01,
79 | mu: Optional[float] = None,
80 |
81 | nag_scale: float = 1.0,
82 | nag_tau: float = 2.5,
83 | nag_alpha: float = 0.125,
84 | nag_negative_prompt: str = None,
85 | nag_negative_prompt_2: str = None,
86 | nag_negative_prompt_3: str = None,
87 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None,
88 | nag_negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
89 | ):
90 | r"""
91 | Function invoked when calling the pipeline for generation.
92 |
93 | Args:
94 | prompt (`str` or `List[str]`, *optional*):
 95 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
96 | instead.
97 | prompt_2 (`str` or `List[str]`, *optional*):
 98 |                 The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
 99 |                 will be used instead.
100 | prompt_3 (`str` or `List[str]`, *optional*):
101 |                 The prompt or prompts to be sent to `tokenizer_3` and `text_encoder_3`. If not defined, `prompt`
102 |                 will be used instead.
103 |             height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
104 |                 The height in pixels of the generated image. This is set to 1024 by default for the best results.
105 |             width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
106 |                 The width in pixels of the generated image. This is set to 1024 by default for the best results.
107 |             num_inference_steps (`int`, *optional*, defaults to 28):
108 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
109 | expense of slower inference.
110 | sigmas (`List[float]`, *optional*):
111 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
112 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
113 | will be used.
114 | guidance_scale (`float`, *optional*, defaults to 7.0):
115 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
116 |                 `guidance_scale` is defined as `w` in equation 2 of the [Imagen
117 |                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
118 |                 1`. A higher guidance scale encourages the model to generate images closely linked to the text `prompt`,
119 |                 usually at the expense of lower image quality.
120 | negative_prompt (`str` or `List[str]`, *optional*):
121 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
122 |                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
123 |                 not greater than `1`).
124 | negative_prompt_2 (`str` or `List[str]`, *optional*):
125 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
126 | `text_encoder_2`. If not defined, `negative_prompt` is used instead
127 | negative_prompt_3 (`str` or `List[str]`, *optional*):
128 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_3` and
129 | `text_encoder_3`. If not defined, `negative_prompt` is used instead
130 | num_images_per_prompt (`int`, *optional*, defaults to 1):
131 | The number of images to generate per prompt.
132 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
133 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
134 | to make generation deterministic.
135 | latents (`torch.FloatTensor`, *optional*):
136 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
137 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
138 |                 tensor will be generated by sampling using the supplied random `generator`.
139 | prompt_embeds (`torch.FloatTensor`, *optional*):
140 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
141 | provided, text embeddings will be generated from `prompt` input argument.
142 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
143 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
144 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
145 | argument.
146 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
147 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
148 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
149 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
150 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
151 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
152 | input argument.
153 | ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
154 | ip_adapter_image_embeds (`torch.Tensor`, *optional*):
155 | Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
156 | emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
157 | `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
158 | output_type (`str`, *optional*, defaults to `"pil"`):
159 |                 The output format of the generated image. Choose between
160 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
161 | return_dict (`bool`, *optional*, defaults to `True`):
162 | Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
163 | a plain tuple.
164 | joint_attention_kwargs (`dict`, *optional*):
165 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
166 | `self.processor` in
167 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
168 | callback_on_step_end (`Callable`, *optional*):
169 |                 A function that is called at the end of each denoising step during inference. The function is called
170 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
171 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
172 | `callback_on_step_end_tensor_inputs`.
173 | callback_on_step_end_tensor_inputs (`List`, *optional*):
174 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
175 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
176 | `._callback_tensor_inputs` attribute of your pipeline class.
177 |             max_sequence_length (`int`, *optional*, defaults to 256): Maximum sequence length to use with the `prompt`.
178 | skip_guidance_layers (`List[int]`, *optional*):
179 | A list of integers that specify layers to skip during guidance. If not provided, all layers will be
180 | used for guidance. If provided, the guidance will only be applied to the layers specified in the list.
181 |                 Recommended value by StabilityAI for Stable Diffusion 3.5 Medium is [7, 8, 9].
182 |             skip_layer_guidance_scale (`float`, *optional*, defaults to 2.8): The scale of the guidance for the layers specified in
183 | `skip_guidance_layers`. The guidance will be applied to the layers specified in `skip_guidance_layers`
184 | with a scale of `skip_layer_guidance_scale`. The guidance will be applied to the rest of the layers
185 | with a scale of `1`.
186 |             skip_layer_guidance_stop (`float`, *optional*, defaults to 0.2): The fraction of the denoising schedule at which the guidance for the layers specified in
187 | `skip_guidance_layers` will stop. The guidance will be applied to the layers specified in
188 | `skip_guidance_layers` until the fraction specified in `skip_layer_guidance_stop`. Recommended value by
189 |                 StabilityAI for Stable Diffusion 3.5 Medium is 0.2.
190 |             skip_layer_guidance_start (`float`, *optional*, defaults to 0.01): The fraction of the denoising schedule at which the guidance for the layers specified in
191 | `skip_guidance_layers` will start. The guidance will be applied to the layers specified in
192 | `skip_guidance_layers` from the fraction specified in `skip_layer_guidance_start`. Recommended value by
193 |                 StabilityAI for Stable Diffusion 3.5 Medium is 0.01.
194 | mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
195 |
196 | Examples:
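            A minimal, illustrative sketch of NAG with Stable Diffusion 3 (the checkpoint id, dtype,
            and NAG settings below are assumptions, not taken from this repository):

            ```py
            >>> import torch
            >>> from nag.pipeline_sd3_nag import NAGStableDiffusion3Pipeline

            >>> pipe = NAGStableDiffusion3Pipeline.from_pretrained(
            ...     "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
            ... ).to("cuda")
            >>> image = pipe(
            ...     prompt="A photo of a cat wearing sunglasses",
            ...     guidance_scale=7.0,
            ...     num_inference_steps=28,
            ...     nag_scale=5.0,  # NAG is active only when nag_scale > 1
            ...     nag_negative_prompt="low quality",
            ... ).images[0]
            ```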
197 |
198 | Returns:
199 | [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] or `tuple`:
200 | [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] if `return_dict` is True, otherwise a
201 | `tuple`. When returning a tuple, the first element is a list with the generated images.
202 | """
203 |
204 | height = height or self.default_sample_size * self.vae_scale_factor
205 | width = width or self.default_sample_size * self.vae_scale_factor
206 |
207 | # 1. Check inputs. Raise error if not correct
208 | self.check_inputs(
209 | prompt,
210 | prompt_2,
211 | prompt_3,
212 | height,
213 | width,
214 | negative_prompt=negative_prompt,
215 | negative_prompt_2=negative_prompt_2,
216 | negative_prompt_3=negative_prompt_3,
217 | prompt_embeds=prompt_embeds,
218 | negative_prompt_embeds=negative_prompt_embeds,
219 | pooled_prompt_embeds=pooled_prompt_embeds,
220 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
221 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
222 | max_sequence_length=max_sequence_length,
223 | )
224 |
225 | self._guidance_scale = guidance_scale
226 | self._skip_layer_guidance_scale = skip_layer_guidance_scale
227 | self._clip_skip = clip_skip
228 | self._joint_attention_kwargs = joint_attention_kwargs
229 | self._interrupt = False
230 | self._nag_scale = nag_scale
231 |
232 | # 2. Define call parameters
233 | if prompt is not None and isinstance(prompt, str):
234 | batch_size = 1
235 | elif prompt is not None and isinstance(prompt, list):
236 | batch_size = len(prompt)
237 | else:
238 | batch_size = prompt_embeds.shape[0]
239 |
240 | device = self._execution_device
241 |
242 | lora_scale = (
243 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
244 | )
245 | (
246 | prompt_embeds,
247 | negative_prompt_embeds,
248 | pooled_prompt_embeds,
249 | negative_pooled_prompt_embeds,
250 | ) = self.encode_prompt(
251 | prompt=prompt,
252 | prompt_2=prompt_2,
253 | prompt_3=prompt_3,
254 | negative_prompt=negative_prompt,
255 | negative_prompt_2=negative_prompt_2,
256 | negative_prompt_3=negative_prompt_3,
257 | do_classifier_free_guidance=self.do_classifier_free_guidance,
258 | prompt_embeds=prompt_embeds,
259 | negative_prompt_embeds=negative_prompt_embeds,
260 | pooled_prompt_embeds=pooled_prompt_embeds,
261 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
262 | device=device,
263 | clip_skip=self.clip_skip,
264 | num_images_per_prompt=num_images_per_prompt,
265 | max_sequence_length=max_sequence_length,
266 | lora_scale=lora_scale,
267 | )
268 | if self.do_normalized_attention_guidance:
269 | if nag_negative_prompt_embeds is None or nag_negative_pooled_prompt_embeds is None:
270 | if nag_negative_prompt is None:
271 | if negative_prompt is not None:
272 | if self.do_classifier_free_guidance:
273 | nag_negative_prompt_embeds = negative_prompt_embeds
274 | nag_negative_pooled_prompt_embeds = negative_pooled_prompt_embeds
275 | else:
276 | nag_negative_prompt = negative_prompt
277 | nag_negative_prompt_2 = negative_prompt_2
278 | nag_negative_prompt_3 = negative_prompt_3
279 | else:
280 | nag_negative_prompt = ""
281 |
282 | if nag_negative_prompt is not None:
283 | nag_negative_prompt_embeds, _, nag_negative_pooled_prompt_embeds, _ = self.encode_prompt(
284 | prompt=nag_negative_prompt,
285 | prompt_2=nag_negative_prompt_2,
286 | prompt_3=nag_negative_prompt_3,
287 | do_classifier_free_guidance=False,
288 | device=device,
289 | clip_skip=self.clip_skip,
290 | num_images_per_prompt=num_images_per_prompt,
291 | max_sequence_length=max_sequence_length,
292 | lora_scale=lora_scale,
293 | )
294 |
295 | if self.do_classifier_free_guidance:
296 | original_prompt_embeds = prompt_embeds
297 | original_pooled_prompt_embeds = pooled_prompt_embeds
298 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
299 | pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
300 |
301 | if self.do_normalized_attention_guidance:
302 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0)
303 | if self.do_classifier_free_guidance:
304 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, original_pooled_prompt_embeds], dim=0)
305 | else:
306 | pooled_prompt_embeds = torch.cat([pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
307 |
308 | # 4. Prepare latent variables
309 | num_channels_latents = self.transformer.config.in_channels
310 | latents = self.prepare_latents(
311 | batch_size * num_images_per_prompt,
312 | num_channels_latents,
313 | height,
314 | width,
315 | prompt_embeds.dtype,
316 | device,
317 | generator,
318 | latents,
319 | )
320 |
321 | # 5. Prepare timesteps
322 | scheduler_kwargs = {}
323 | if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
324 | _, _, height, width = latents.shape
325 | image_seq_len = (height // self.transformer.config.patch_size) * (
326 | width // self.transformer.config.patch_size
327 | )
328 | mu = calculate_shift(
329 | image_seq_len,
330 | self.scheduler.config.base_image_seq_len,
331 | self.scheduler.config.max_image_seq_len,
332 | self.scheduler.config.base_shift,
333 | self.scheduler.config.max_shift,
334 | )
335 | scheduler_kwargs["mu"] = mu
336 | elif mu is not None:
337 | scheduler_kwargs["mu"] = mu
338 | timesteps, num_inference_steps = retrieve_timesteps(
339 | self.scheduler,
340 | num_inference_steps,
341 | device,
342 | sigmas=sigmas,
343 | **scheduler_kwargs,
344 | )
345 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
346 | self._num_timesteps = len(timesteps)
347 |
348 | # 6. Prepare image embeddings
349 | if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
350 | ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
351 | ip_adapter_image,
352 | ip_adapter_image_embeds,
353 | device,
354 | batch_size * num_images_per_prompt,
355 | self.do_classifier_free_guidance,
356 | )
357 |
358 | if self.joint_attention_kwargs is None:
359 | self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
360 | else:
361 | self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
362 |
363 | origin_attn_procs = self.transformer.attn_processors
364 | if self.do_normalized_attention_guidance:
365 | self._set_nag_attn_processor(nag_scale, nag_tau, nag_alpha)
366 |
367 | for sub_mod in self.transformer.modules():
368 | if not hasattr(sub_mod, "forward_old") :
369 | sub_mod.forward_old = sub_mod.forward
370 | if isinstance(sub_mod, AdaLayerNorm):
371 | sub_mod.forward = types.MethodType(TruncAdaLayerNorm.forward, sub_mod)
372 | elif isinstance(sub_mod, AdaLayerNormContinuous):
373 | sub_mod.forward = types.MethodType(TruncAdaLayerNormContinuous.forward, sub_mod)
374 | elif isinstance(sub_mod, AdaLayerNormZero):
375 | sub_mod.forward = types.MethodType(TruncAdaLayerNormZero.forward, sub_mod)
376 | elif isinstance(sub_mod, SD35AdaLayerNormZeroX):
377 | sub_mod.forward = types.MethodType(TruncSD35AdaLayerNormZeroX.forward, sub_mod)
378 |
379 | # 7. Denoising loop
380 | with self.progress_bar(total=num_inference_steps) as progress_bar:
381 | for i, t in enumerate(timesteps):
382 | if self.interrupt:
383 | continue
384 |
385 | # expand the latents if we are doing classifier free guidance
386 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
387 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
388 | timestep = t.expand(prompt_embeds.shape[0])
389 |
390 | noise_pred = self.transformer(
391 | hidden_states=latent_model_input,
392 | timestep=timestep,
393 | encoder_hidden_states=prompt_embeds,
394 | pooled_projections=pooled_prompt_embeds,
395 | joint_attention_kwargs=self.joint_attention_kwargs,
396 | return_dict=False,
397 | )[0]
398 |
399 | # perform guidance
400 | if self.do_classifier_free_guidance:
401 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
402 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
403 | should_skip_layers = (
404 | True
405 | if i > num_inference_steps * skip_layer_guidance_start
406 | and i < num_inference_steps * skip_layer_guidance_stop
407 | else False
408 | )
409 | if skip_guidance_layers is not None and should_skip_layers:
410 | timestep = t.expand(latents.shape[0])
411 | latent_model_input = latents
412 | noise_pred_skip_layers = self.transformer(
413 | hidden_states=latent_model_input,
414 | timestep=timestep,
415 | encoder_hidden_states=original_prompt_embeds,
416 | pooled_projections=original_pooled_prompt_embeds,
417 | joint_attention_kwargs=self.joint_attention_kwargs,
418 | return_dict=False,
419 | skip_layers=skip_guidance_layers,
420 | )[0]
421 | noise_pred = (
422 | noise_pred + (
423 | noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
424 | )
425 |
426 | # compute the previous noisy sample x_t -> x_t-1
427 | latents_dtype = latents.dtype
428 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
429 |
430 | if latents.dtype != latents_dtype:
431 | if torch.backends.mps.is_available():
432 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
433 | latents = latents.to(latents_dtype)
434 |
435 | if callback_on_step_end is not None:
436 | callback_kwargs = {}
437 | for k in callback_on_step_end_tensor_inputs:
438 | callback_kwargs[k] = locals()[k]
439 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
440 |
441 | latents = callback_outputs.pop("latents", latents)
442 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
443 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
444 | negative_pooled_prompt_embeds = callback_outputs.pop(
445 | "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
446 | )
447 |
448 | # call the callback, if provided
449 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
450 | progress_bar.update()
451 |
452 | if XLA_AVAILABLE:
453 | xm.mark_step()
454 |
455 | if output_type == "latent":
456 | image = latents
457 |
458 | else:
459 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
460 |
461 | image = self.vae.decode(latents, return_dict=False)[0]
462 | image = self.image_processor.postprocess(image, output_type=output_type)
463 |
464 | if self.do_normalized_attention_guidance:
465 | self.transformer.set_attn_processor(origin_attn_procs)
466 |
467 | # Offload all models
468 | self.maybe_free_model_hooks()
469 |
470 | if not return_dict:
471 | return (image,)
472 |
473 | return StableDiffusion3PipelineOutput(images=image)
474 |
--------------------------------------------------------------------------------
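
Note on the embedding layout used by pipeline_sd3_nag.py above: classifier-free guidance stacks the batch as [negative, positive] along dim 0, and normalized attention guidance then appends one extra NAG-negative batch so the NAG attention processors can contrast the positive and NAG-negative attention. A minimal, self-contained sketch with dummy tensors (the shapes below are illustrative only, not the real SD3 dimensions):

import torch

batch, seq_len, dim = 1, 333, 4096  # illustrative shapes only
prompt_embeds = torch.randn(batch, seq_len, dim)
negative_prompt_embeds = torch.randn(batch, seq_len, dim)
nag_negative_prompt_embeds = torch.randn(batch, seq_len, dim)

# Classifier-free guidance: unconditional (negative) batch first, conditional batch second.
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

# Normalized attention guidance: append the NAG-negative batch, so the transformer
# receives [negative, positive, nag_negative] along the batch dimension.
prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0)

print(prompt_embeds.shape)  # torch.Size([3, 333, 4096])
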
/nag/pipeline_sdxl_nag.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
2 | import math
3 |
4 | import torch
5 |
6 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
7 | from diffusers.image_processor import PipelineImageInput
8 | from diffusers.utils import (
9 | deprecate,
10 | is_torch_xla_available,
11 | )
12 | from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
13 |
14 | if is_torch_xla_available():
15 | import torch_xla.core.xla_model as xm
16 |
17 | XLA_AVAILABLE = True
18 | else:
19 | XLA_AVAILABLE = False
20 | from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
21 | StableDiffusionXLPipeline,
22 | retrieve_timesteps,
23 | rescale_noise_cfg,
24 | )
25 |
26 | from nag.attention_nag import NAGAttnProcessor2_0
27 |
28 |
29 | class NAGStableDiffusionXLPipeline(StableDiffusionXLPipeline):
30 | @property
31 | def do_normalized_attention_guidance(self):
32 | return self._nag_scale > 1
33 |
34 | def _set_nag_attn_processor(self, nag_scale, nag_tau=2.5, nag_alpha=0.5):
35 | if self.do_normalized_attention_guidance:
36 | attn_procs = {}
37 | for name, origin_attn_processor in self.unet.attn_processors.items():
38 | if "attn2" in name:
39 | attn_procs[name] = NAGAttnProcessor2_0(nag_scale=nag_scale, nag_tau=nag_tau, nag_alpha=nag_alpha)
40 | else:
41 | attn_procs[name] = origin_attn_processor
42 | self.unet.set_attn_processor(attn_procs)
43 |
44 | @torch.no_grad()
45 | def __call__(
46 | self,
47 | prompt: Union[str, List[str]] = None,
48 | prompt_2: Optional[Union[str, List[str]]] = None,
49 | height: Optional[int] = None,
50 | width: Optional[int] = None,
51 | num_inference_steps: int = 50,
52 | timesteps: List[int] = None,
53 | sigmas: List[float] = None,
54 | denoising_end: Optional[float] = None,
55 | guidance_scale: float = 5.0,
56 | negative_prompt: Optional[Union[str, List[str]]] = None,
57 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
58 | num_images_per_prompt: Optional[int] = 1,
59 | eta: float = 0.0,
60 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
61 | latents: Optional[torch.Tensor] = None,
62 | prompt_embeds: Optional[torch.Tensor] = None,
63 | negative_prompt_embeds: Optional[torch.Tensor] = None,
64 | pooled_prompt_embeds: Optional[torch.Tensor] = None,
65 | negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
66 | ip_adapter_image: Optional[PipelineImageInput] = None,
67 | ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
68 | output_type: Optional[str] = "pil",
69 | return_dict: bool = True,
70 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
71 | guidance_rescale: float = 0.0,
72 | original_size: Optional[Tuple[int, int]] = None,
73 | crops_coords_top_left: Tuple[int, int] = (0, 0),
74 | target_size: Optional[Tuple[int, int]] = None,
75 | negative_original_size: Optional[Tuple[int, int]] = None,
76 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
77 | negative_target_size: Optional[Tuple[int, int]] = None,
78 | clip_skip: Optional[int] = None,
79 | callback_on_step_end: Optional[
80 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
81 | ] = None,
82 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
83 |
84 | nag_scale: float = 1.0,
85 | nag_tau: float = 2.5,
86 | nag_alpha: float = 0.5,
87 | nag_negative_prompt: str = None,
88 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None,
89 | nag_end: float = 1.0,
90 |
91 | **kwargs,
92 | ):
93 | r"""
94 | Function invoked when calling the pipeline for generation.
95 |
96 | Args:
97 | prompt (`str` or `List[str]`, *optional*):
98 |                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
99 |                 instead.
100 | prompt_2 (`str` or `List[str]`, *optional*):
101 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
102 | used in both text-encoders
103 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
104 | The height in pixels of the generated image. This is set to 1024 by default for the best results.
105 | Anything below 512 pixels won't work well for
106 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
107 | and checkpoints that are not specifically fine-tuned on low resolutions.
108 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
109 | The width in pixels of the generated image. This is set to 1024 by default for the best results.
110 | Anything below 512 pixels won't work well for
111 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
112 | and checkpoints that are not specifically fine-tuned on low resolutions.
113 | num_inference_steps (`int`, *optional*, defaults to 50):
114 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
115 | expense of slower inference.
116 | timesteps (`List[int]`, *optional*):
117 | Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
118 | in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
119 | passed will be used. Must be in descending order.
120 | sigmas (`List[float]`, *optional*):
121 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
122 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
123 | will be used.
124 | denoising_end (`float`, *optional*):
125 | When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
126 | completed before it is intentionally prematurely terminated. As a result, the returned sample will
127 | still retain a substantial amount of noise as determined by the discrete timesteps selected by the
128 | scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
129 | "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
130 | Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
131 | guidance_scale (`float`, *optional*, defaults to 5.0):
132 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
133 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
134 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
135 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
136 | usually at the expense of lower image quality.
137 | negative_prompt (`str` or `List[str]`, *optional*):
138 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
139 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
140 | less than `1`).
141 | negative_prompt_2 (`str` or `List[str]`, *optional*):
142 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
143 | `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
144 | num_images_per_prompt (`int`, *optional*, defaults to 1):
145 | The number of images to generate per prompt.
146 | eta (`float`, *optional*, defaults to 0.0):
147 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
148 | [`schedulers.DDIMScheduler`], will be ignored for others.
149 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
150 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
151 | to make generation deterministic.
152 | latents (`torch.Tensor`, *optional*):
153 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
154 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
155 |                 tensor will be generated by sampling using the supplied random `generator`.
156 | prompt_embeds (`torch.Tensor`, *optional*):
157 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
158 | provided, text embeddings will be generated from `prompt` input argument.
159 | negative_prompt_embeds (`torch.Tensor`, *optional*):
160 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
161 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
162 | argument.
163 | pooled_prompt_embeds (`torch.Tensor`, *optional*):
164 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
165 | If not provided, pooled text embeddings will be generated from `prompt` input argument.
166 | negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
167 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
168 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
169 | input argument.
170 | ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
171 | ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
172 | Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
173 | IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
174 | contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
175 | provided, embeddings are computed from the `ip_adapter_image` input argument.
176 | output_type (`str`, *optional*, defaults to `"pil"`):
177 |                 The output format of the generated image. Choose between
178 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
179 | return_dict (`bool`, *optional*, defaults to `True`):
180 | Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
181 | of a plain tuple.
182 | cross_attention_kwargs (`dict`, *optional*):
183 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
184 | `self.processor` in
185 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
186 | guidance_rescale (`float`, *optional*, defaults to 0.0):
187 | Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
188 | Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
189 | [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
190 | Guidance rescale factor should fix overexposure when using zero terminal SNR.
191 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
192 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
193 | `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
194 | explained in section 2.2 of
195 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
196 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
197 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
198 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
199 | `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
200 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
201 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
202 | For most cases, `target_size` should be set to the desired height and width of the generated image. If
203 | not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
204 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
205 | negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
206 | To negatively condition the generation process based on a specific image resolution. Part of SDXL's
207 | micro-conditioning as explained in section 2.2 of
208 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
209 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
210 | negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
211 |                 To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
212 | micro-conditioning as explained in section 2.2 of
213 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
214 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
215 | negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
216 |                 To negatively condition the generation process based on a target image resolution. It should be the same
217 |                 as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
218 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
219 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
220 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
221 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
222 |                 each denoising step during inference, with the following arguments: `callback_on_step_end(self:
223 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
224 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
225 | callback_on_step_end_tensor_inputs (`List`, *optional*):
226 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
227 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
228 | `._callback_tensor_inputs` attribute of your pipeline class.
229 |
230 | Examples:
231 |
232 | Returns:
233 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
234 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
235 | `tuple`. When returning a tuple, the first element is a list with the generated images.
236 | """
237 |
238 | callback = kwargs.pop("callback", None)
239 | callback_steps = kwargs.pop("callback_steps", None)
240 |
241 | if callback is not None:
242 | deprecate(
243 | "callback",
244 | "1.0.0",
245 |                 "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
246 | )
247 | if callback_steps is not None:
248 | deprecate(
249 | "callback_steps",
250 | "1.0.0",
251 |                 "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
252 | )
253 |
254 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
255 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
256 |
257 | # 0. Default height and width to unet
258 | height = height or self.default_sample_size * self.vae_scale_factor
259 | width = width or self.default_sample_size * self.vae_scale_factor
260 |
261 | original_size = original_size or (height, width)
262 | target_size = target_size or (height, width)
263 |
264 | # 1. Check inputs. Raise error if not correct
265 | self.check_inputs(
266 | prompt,
267 | prompt_2,
268 | height,
269 | width,
270 | callback_steps,
271 | negative_prompt,
272 | negative_prompt_2,
273 | prompt_embeds,
274 | negative_prompt_embeds,
275 | pooled_prompt_embeds,
276 | negative_pooled_prompt_embeds,
277 | ip_adapter_image,
278 | ip_adapter_image_embeds,
279 | callback_on_step_end_tensor_inputs,
280 | )
281 |
282 | self._guidance_scale = guidance_scale
283 | self._guidance_rescale = guidance_rescale
284 | self._clip_skip = clip_skip
285 | self._cross_attention_kwargs = cross_attention_kwargs
286 | self._denoising_end = denoising_end
287 | self._interrupt = False
288 | self._nag_scale = nag_scale
289 |
290 | # 2. Define call parameters
291 | if prompt is not None and isinstance(prompt, str):
292 | batch_size = 1
293 | elif prompt is not None and isinstance(prompt, list):
294 | batch_size = len(prompt)
295 | else:
296 | batch_size = prompt_embeds.shape[0]
297 |
298 | device = self._execution_device
299 |
300 | # 3. Encode input prompt
301 | lora_scale = (
302 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
303 | )
304 |
305 | (
306 | prompt_embeds,
307 | negative_prompt_embeds,
308 | pooled_prompt_embeds,
309 | negative_pooled_prompt_embeds,
310 | ) = self.encode_prompt(
311 | prompt=prompt,
312 | prompt_2=prompt_2,
313 | device=device,
314 | num_images_per_prompt=num_images_per_prompt,
315 | do_classifier_free_guidance=self.do_classifier_free_guidance or self.do_normalized_attention_guidance,
316 | negative_prompt=negative_prompt,
317 | negative_prompt_2=negative_prompt_2,
318 | prompt_embeds=prompt_embeds,
319 | negative_prompt_embeds=negative_prompt_embeds,
320 | pooled_prompt_embeds=pooled_prompt_embeds,
321 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
322 | lora_scale=lora_scale,
323 | clip_skip=self.clip_skip,
324 | )
325 | if self.do_normalized_attention_guidance:
326 | if nag_negative_prompt_embeds is None:
327 | if nag_negative_prompt is None:
328 | if negative_prompt is not None:
329 | if self.do_classifier_free_guidance:
330 | nag_negative_prompt_embeds = negative_prompt_embeds
331 | else:
332 |                             nag_negative_prompt = negative_prompt
333 | else:
334 | nag_negative_prompt = ""
335 |
336 | if nag_negative_prompt is not None:
337 | nag_negative_prompt_embeds = self.encode_prompt(
338 | prompt=nag_negative_prompt,
339 | device=device,
340 | num_images_per_prompt=num_images_per_prompt,
341 | do_classifier_free_guidance=False,
342 | lora_scale=lora_scale,
343 | clip_skip=self.clip_skip,
344 | )[0]
345 |
346 | # 4. Prepare timesteps
347 | timesteps, num_inference_steps = retrieve_timesteps(
348 | self.scheduler, num_inference_steps, device, timesteps, sigmas
349 | )
350 |
351 | # 5. Prepare latent variables
352 | num_channels_latents = self.unet.config.in_channels
353 | latents = self.prepare_latents(
354 | batch_size * num_images_per_prompt,
355 | num_channels_latents,
356 | height,
357 | width,
358 | prompt_embeds.dtype,
359 | device,
360 | generator,
361 | latents,
362 | )
363 |
364 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
365 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
366 |
367 | # 7. Prepare added time ids & embeddings
368 | add_text_embeds = pooled_prompt_embeds
369 | if self.text_encoder_2 is None:
370 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
371 | else:
372 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
373 |
374 | add_time_ids = self._get_add_time_ids(
375 | original_size,
376 | crops_coords_top_left,
377 | target_size,
378 | dtype=prompt_embeds.dtype,
379 | text_encoder_projection_dim=text_encoder_projection_dim,
380 | )
381 | if negative_original_size is not None and negative_target_size is not None:
382 | negative_add_time_ids = self._get_add_time_ids(
383 | negative_original_size,
384 | negative_crops_coords_top_left,
385 | negative_target_size,
386 | dtype=prompt_embeds.dtype,
387 | text_encoder_projection_dim=text_encoder_projection_dim,
388 | )
389 | else:
390 | negative_add_time_ids = add_time_ids
391 |
392 | if self.do_classifier_free_guidance:
393 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
394 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
395 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
396 |
397 | if self.do_normalized_attention_guidance:
398 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0)
399 |
400 | prompt_embeds = prompt_embeds.to(device)
401 | add_text_embeds = add_text_embeds.to(device)
402 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
403 |
404 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
405 | image_embeds = self.prepare_ip_adapter_image_embeds(
406 | ip_adapter_image,
407 | ip_adapter_image_embeds,
408 | device,
409 | batch_size * num_images_per_prompt,
410 | self.do_classifier_free_guidance,
411 | )
412 |
413 | # 8. Denoising loop
414 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
415 |
416 | # 8.1 Apply denoising_end
417 | if (
418 | self.denoising_end is not None
419 | and isinstance(self.denoising_end, float)
420 | and self.denoising_end > 0
421 | and self.denoising_end < 1
422 | ):
423 | discrete_timestep_cutoff = int(
424 | round(
425 | self.scheduler.config.num_train_timesteps
426 | - (self.denoising_end * self.scheduler.config.num_train_timesteps)
427 | )
428 | )
429 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
430 | timesteps = timesteps[:num_inference_steps]
431 |
432 | # 9. Optionally get Guidance Scale Embedding
433 | timestep_cond = None
434 | if self.unet.config.time_cond_proj_dim is not None:
435 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
436 | timestep_cond = self.get_guidance_scale_embedding(
437 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
438 | ).to(device=device, dtype=latents.dtype)
439 |
440 | if self.do_normalized_attention_guidance:
441 | origin_attn_procs = self.unet.attn_processors
442 | self._set_nag_attn_processor(nag_scale, nag_tau, nag_alpha)
443 | attn_procs_recovered = False
444 |
445 | self._num_timesteps = len(timesteps)
446 | with self.progress_bar(total=num_inference_steps) as progress_bar:
447 | for i, t in enumerate(timesteps):
448 | if self.interrupt:
449 | continue
450 |
451 | # expand the latents if we are doing classifier free guidance
452 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
453 |
454 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
455 |
456 | # predict the noise residual
457 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
458 | if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
459 | added_cond_kwargs["image_embeds"] = image_embeds
460 |
461 | if t < math.floor((1 - nag_end) * 999) and self.do_normalized_attention_guidance and not attn_procs_recovered:
462 | self.unet.set_attn_processor(origin_attn_procs)
463 | prompt_embeds = prompt_embeds[:len(latent_model_input)]
464 | attn_procs_recovered = True
465 |
466 | noise_pred = self.unet(
467 | latent_model_input,
468 | t,
469 | encoder_hidden_states=prompt_embeds,
470 | timestep_cond=timestep_cond,
471 | cross_attention_kwargs=self.cross_attention_kwargs,
472 | added_cond_kwargs=added_cond_kwargs,
473 | return_dict=False,
474 | )[0]
475 |
476 | # perform guidance
477 | if self.do_classifier_free_guidance:
478 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
479 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
480 |
481 | if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
482 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
483 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
484 |
485 | # compute the previous noisy sample x_t -> x_t-1
486 | latents_dtype = latents.dtype
487 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
488 | if latents.dtype != latents_dtype:
489 | if torch.backends.mps.is_available():
490 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
491 | latents = latents.to(latents_dtype)
492 |
493 | if callback_on_step_end is not None:
494 | callback_kwargs = {}
495 | for k in callback_on_step_end_tensor_inputs:
496 | callback_kwargs[k] = locals()[k]
497 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
498 |
499 | latents = callback_outputs.pop("latents", latents)
500 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
501 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
502 | add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
503 | negative_pooled_prompt_embeds = callback_outputs.pop(
504 | "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
505 | )
506 | add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
507 | negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
508 |
509 | # call the callback, if provided
510 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
511 | progress_bar.update()
512 | if callback is not None and i % callback_steps == 0:
513 | step_idx = i // getattr(self.scheduler, "order", 1)
514 | callback(step_idx, t, latents)
515 |
516 | if XLA_AVAILABLE:
517 | xm.mark_step()
518 |
519 | if not output_type == "latent":
520 | # make sure the VAE is in float32 mode, as it overflows in float16
521 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
522 |
523 | if needs_upcasting:
524 | self.upcast_vae()
525 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
526 | elif latents.dtype != self.vae.dtype:
527 | if torch.backends.mps.is_available():
528 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
529 | self.vae = self.vae.to(latents.dtype)
530 |
531 | # unscale/denormalize the latents
532 | # denormalize with the mean and std if available and not None
533 | has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
534 | has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
535 | if has_latents_mean and has_latents_std:
536 | latents_mean = (
537 | torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
538 | )
539 | latents_std = (
540 | torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
541 | )
542 | latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
543 | else:
544 | latents = latents / self.vae.config.scaling_factor
545 |
546 | image = self.vae.decode(latents, return_dict=False)[0]
547 |
548 | # cast back to fp16 if needed
549 | if needs_upcasting:
550 | self.vae.to(dtype=torch.float16)
551 | else:
552 | image = latents
553 |
554 | if not output_type == "latent":
555 | # apply watermark if available
556 | if self.watermark is not None:
557 | image = self.watermark.apply_watermark(image)
558 |
559 | image = self.image_processor.postprocess(image, output_type=output_type)
560 |
561 | if self.do_normalized_attention_guidance and not attn_procs_recovered:
562 | self.unet.set_attn_processor(origin_attn_procs)
563 |
564 | # Offload all models
565 | self.maybe_free_model_hooks()
566 |
567 | if not return_dict:
568 | return (image,)
569 |
570 | return StableDiffusionXLPipelineOutput(images=image)
571 |
--------------------------------------------------------------------------------
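
A minimal usage sketch for the NAGStableDiffusionXLPipeline defined above; the prompt strings and guidance values are illustrative, and the checkpoint id is the one referenced in the docstring. NAG is only active while `nag_scale > 1`, and setting `nag_end < 1.0` restores the original attention processors once `t < floor((1 - nag_end) * 999)`, so NAG then only steers the early, high-noise steps.

import torch
from nag.pipeline_sdxl_nag import NAGStableDiffusionXLPipeline

pipe = NAGStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a photo of an astronaut riding a horse on the moon",  # illustrative prompt
    guidance_scale=5.0,                  # regular classifier-free guidance
    nag_scale=3.0,                       # > 1 enables normalized attention guidance
    nag_tau=2.5,
    nag_alpha=0.5,
    nag_negative_prompt="blurry, low quality",
    nag_end=1.0,                         # keep NAG active for the whole schedule
).images[0]
image.save("nag_sdxl.png")
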
/nag/pipeline_wan_nag.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Optional, Union
2 |
3 | import torch
4 |
5 | from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
6 | from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
7 | from diffusers.pipelines.wan.pipeline_output import WanPipelineOutput
8 | from diffusers.pipelines.wan.pipeline_wan import WanPipeline
9 |
10 | from nag.attention_wan_nag import NAGWanAttnProcessor2_0
11 |
12 | if is_torch_xla_available():
13 | import torch_xla.core.xla_model as xm
14 |
15 | XLA_AVAILABLE = True
16 | else:
17 | XLA_AVAILABLE = False
18 |
19 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20 |
21 |
22 | class NAGWanPipeline(WanPipeline):
23 | @property
24 | def do_normalized_attention_guidance(self):
25 | return self._nag_scale > 1
26 |
27 | def _set_nag_attn_processor(self, nag_scale, nag_tau, nag_alpha):
28 | attn_procs = {}
29 | for name, origin_attn_proc in self.transformer.attn_processors.items():
30 | if "attn2" in name:
31 | attn_procs[name] = NAGWanAttnProcessor2_0(nag_scale=nag_scale, nag_tau=nag_tau, nag_alpha=nag_alpha)
32 | else:
33 | attn_procs[name] = origin_attn_proc
34 | self.transformer.set_attn_processor(attn_procs)
35 |
36 | @torch.no_grad()
37 | def __call__(
38 | self,
39 | prompt: Union[str, List[str]] = None,
40 | negative_prompt: Union[str, List[str]] = None,
41 | height: int = 480,
42 | width: int = 832,
43 | num_frames: int = 81,
44 | num_inference_steps: int = 50,
45 | guidance_scale: float = 5.0,
46 | num_videos_per_prompt: Optional[int] = 1,
47 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
48 | latents: Optional[torch.Tensor] = None,
49 | prompt_embeds: Optional[torch.Tensor] = None,
50 | negative_prompt_embeds: Optional[torch.Tensor] = None,
51 | output_type: Optional[str] = "np",
52 | return_dict: bool = True,
53 | attention_kwargs: Optional[Dict[str, Any]] = None,
54 | callback_on_step_end: Optional[
55 | Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
56 | ] = None,
57 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
58 | max_sequence_length: int = 512,
59 |
60 | nag_scale: float = 1.0,
61 | nag_tau: float = 2.5,
62 | nag_alpha: float = 0.25,
63 | nag_negative_prompt: str = None,
64 | nag_negative_prompt_embeds: Optional[torch.Tensor] = None,
65 | ):
66 | r"""
67 | The call function to the pipeline for generation.
68 |
69 | Args:
70 | prompt (`str` or `List[str]`, *optional*):
71 |                 The prompt or prompts to guide the video generation. If not defined, one has to pass `prompt_embeds`
72 |                 instead.
73 | height (`int`, defaults to `480`):
74 | The height in pixels of the generated image.
75 | width (`int`, defaults to `832`):
76 | The width in pixels of the generated image.
77 | num_frames (`int`, defaults to `81`):
78 | The number of frames in the generated video.
79 | num_inference_steps (`int`, defaults to `50`):
80 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
81 | expense of slower inference.
82 | guidance_scale (`float`, defaults to `5.0`):
83 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
84 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
85 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
86 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
87 | usually at the expense of lower image quality.
88 | num_videos_per_prompt (`int`, *optional*, defaults to 1):
89 |                 The number of videos to generate per prompt.
90 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
91 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
92 | generation deterministic.
93 | latents (`torch.Tensor`, *optional*):
94 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
95 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
96 | tensor is generated by sampling using the supplied random `generator`.
97 | prompt_embeds (`torch.Tensor`, *optional*):
98 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
99 | provided, text embeddings are generated from the `prompt` input argument.
100 |             output_type (`str`, *optional*, defaults to `"np"`):
101 |                 The output format of the generated video. Choose between `PIL.Image` or `np.array`.
102 | return_dict (`bool`, *optional*, defaults to `True`):
103 | Whether or not to return a [`WanPipelineOutput`] instead of a plain tuple.
104 | attention_kwargs (`dict`, *optional*):
105 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
106 | `self.processor` in
107 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
108 | callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
109 | A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
110 | each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
111 | DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
112 | list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
113 | callback_on_step_end_tensor_inputs (`List`, *optional*):
114 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
115 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
116 | `._callback_tensor_inputs` attribute of your pipeline class.
117 | autocast_dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
118 | The dtype to use for the torch.amp.autocast.
119 |
120 | Examples:
121 |
122 | Returns:
123 | [`~WanPipelineOutput`] or `tuple`:
124 |                 If `return_dict` is `True`, [`WanPipelineOutput`] is returned, otherwise a `tuple` is returned where
125 |                 the first element is a list with the generated video frames.
126 |
127 | """
128 |
129 | if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
130 | callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
131 |
132 | # 1. Check inputs. Raise error if not correct
133 | self.check_inputs(
134 | prompt,
135 | negative_prompt,
136 | height,
137 | width,
138 | prompt_embeds,
139 | negative_prompt_embeds,
140 | callback_on_step_end_tensor_inputs,
141 | )
142 |
143 | self._guidance_scale = guidance_scale
144 | self._attention_kwargs = attention_kwargs
145 | self._current_timestep = None
146 | self._interrupt = False
147 | self._nag_scale = nag_scale
148 |
149 | device = self._execution_device
150 |
151 | # 2. Define call parameters
152 | if prompt is not None and isinstance(prompt, str):
153 | batch_size = 1
154 | elif prompt is not None and isinstance(prompt, list):
155 | batch_size = len(prompt)
156 | else:
157 | batch_size = prompt_embeds.shape[0]
158 |
159 | # 3. Encode input prompt
160 | prompt_embeds, negative_prompt_embeds = self.encode_prompt(
161 | prompt=prompt,
162 | negative_prompt=negative_prompt,
163 | do_classifier_free_guidance=self.do_classifier_free_guidance,
164 | num_videos_per_prompt=num_videos_per_prompt,
165 | prompt_embeds=prompt_embeds,
166 | negative_prompt_embeds=negative_prompt_embeds,
167 | max_sequence_length=max_sequence_length,
168 | device=device,
169 | )
170 | if self.do_normalized_attention_guidance:
171 | if nag_negative_prompt_embeds is None:
172 | if nag_negative_prompt is None:
173 | if self.do_classifier_free_guidance:
174 | nag_negative_prompt_embeds = negative_prompt_embeds
175 | else:
176 | nag_negative_prompt = negative_prompt or ""
177 |
178 | if nag_negative_prompt is not None:
179 | nag_negative_prompt_embeds = self.encode_prompt(
180 | prompt=nag_negative_prompt,
181 | do_classifier_free_guidance=False,
182 | num_videos_per_prompt=num_videos_per_prompt,
183 | max_sequence_length=max_sequence_length,
184 | device=device,
185 | )[0]
186 |
187 | if self.do_normalized_attention_guidance:
188 | prompt_embeds = torch.cat([prompt_embeds, nag_negative_prompt_embeds], dim=0)
189 |
190 | transformer_dtype = self.transformer.dtype
191 | prompt_embeds = prompt_embeds.to(transformer_dtype)
192 | if negative_prompt_embeds is not None:
193 | negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
194 |
195 | # 4. Prepare timesteps
196 | self.scheduler.set_timesteps(num_inference_steps, device=device)
197 | timesteps = self.scheduler.timesteps
198 |
199 | # 5. Prepare latent variables
200 | num_channels_latents = self.transformer.config.in_channels
201 | latents = self.prepare_latents(
202 | batch_size * num_videos_per_prompt,
203 | num_channels_latents,
204 | height,
205 | width,
206 | num_frames,
207 | torch.float32,
208 | device,
209 | generator,
210 | latents,
211 | )
212 |
213 | # 6. Denoising loop
214 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
215 | self._num_timesteps = len(timesteps)
216 |
217 | if self.do_normalized_attention_guidance:
218 | origin_attn_procs = self.transformer.attn_processors
219 | self._set_nag_attn_processor(nag_scale, nag_tau, nag_alpha)
220 |
221 | with self.progress_bar(total=num_inference_steps) as progress_bar:
222 | for i, t in enumerate(timesteps):
223 | if self.interrupt:
224 | continue
225 |
226 | self._current_timestep = t
227 | latent_model_input = latents.to(transformer_dtype)
228 | timestep = t.expand(latents.shape[0])
229 |
230 | noise_pred = self.transformer(
231 | hidden_states=latent_model_input,
232 | timestep=timestep,
233 | encoder_hidden_states=prompt_embeds,
234 | attention_kwargs=attention_kwargs,
235 | return_dict=False,
236 | )[0]
237 |
238 | if self.do_classifier_free_guidance:
239 | noise_uncond = self.transformer(
240 | hidden_states=latent_model_input,
241 | timestep=timestep,
242 | encoder_hidden_states=negative_prompt_embeds,
243 | attention_kwargs=attention_kwargs,
244 | return_dict=False,
245 | )[0]
246 | noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
247 |
248 | # compute the previous noisy sample x_t -> x_t-1
249 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
250 |
251 | if callback_on_step_end is not None:
252 | callback_kwargs = {}
253 | for k in callback_on_step_end_tensor_inputs:
254 | callback_kwargs[k] = locals()[k]
255 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
256 |
257 | latents = callback_outputs.pop("latents", latents)
258 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
259 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
260 |
261 | # call the callback, if provided
262 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
263 | progress_bar.update()
264 |
265 | if XLA_AVAILABLE:
266 | xm.mark_step()
267 |
268 | self._current_timestep = None
269 |
270 | if not output_type == "latent":
271 | latents = latents.to(self.vae.dtype)
272 | latents_mean = (
273 | torch.tensor(self.vae.config.latents_mean)
274 | .view(1, self.vae.config.z_dim, 1, 1, 1)
275 | .to(latents.device, latents.dtype)
276 | )
277 | latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
278 | latents.device, latents.dtype
279 | )
280 | latents = latents / latents_std + latents_mean
281 | video = self.vae.decode(latents, return_dict=False)[0]
282 | video = self.video_processor.postprocess_video(video, output_type=output_type)
283 | else:
284 | video = latents
285 |
286 | if self.do_normalized_attention_guidance:
287 | self.transformer.set_attn_processor(origin_attn_procs)
288 |
289 | # Offload all models
290 | self.maybe_free_model_hooks()
291 |
292 | if not return_dict:
293 | return (video,)
294 |
295 | return WanPipelineOutput(frames=video)
296 |
--------------------------------------------------------------------------------
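
A corresponding usage sketch for the NAGWanPipeline defined above. The checkpoint path is a placeholder, since no specific Wan checkpoint is named in this excerpt; resolution and frame count follow the call-signature defaults, and the prompt and NAG values are illustrative.

import torch
from diffusers.utils import export_to_video
from nag.pipeline_wan_nag import NAGWanPipeline

# Placeholder path: substitute a Wan text-to-video checkpoint in Diffusers format.
pipe = NAGWanPipeline.from_pretrained("path/to/wan-t2v-checkpoint", torch_dtype=torch.bfloat16).to("cuda")

output = pipe(
    prompt="a corgi running on a beach at sunset",  # illustrative prompt
    height=480,
    width=832,
    num_frames=81,
    guidance_scale=5.0,
    nag_scale=3.0,                 # > 1 enables normalized attention guidance
    nag_tau=2.5,
    nag_alpha=0.25,
    nag_negative_prompt="static, blurry, low quality",
    output_type="pil",
)
export_to_video(output.frames[0], "nag_wan.mp4", fps=16)
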
/nag/transformer_flux.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Union
2 |
3 | import numpy as np
4 | import torch
5 |
6 | from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
7 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
8 | from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel
9 |
10 |
11 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
12 |
13 |
14 | class NAGFluxTransformer2DModel(FluxTransformer2DModel):
15 | def forward(
16 | self,
17 | hidden_states: torch.Tensor,
18 | encoder_hidden_states: torch.Tensor = None,
19 | pooled_projections: torch.Tensor = None,
20 | timestep: torch.LongTensor = None,
21 | img_ids: torch.Tensor = None,
22 | txt_ids: torch.Tensor = None,
23 | guidance: torch.Tensor = None,
24 | joint_attention_kwargs: Optional[Dict[str, Any]] = None,
25 | controlnet_block_samples=None,
26 | controlnet_single_block_samples=None,
27 | return_dict: bool = True,
28 | controlnet_blocks_repeat: bool = False,
29 | ) -> Union[torch.Tensor, Transformer2DModelOutput]:
30 | """
31 | The [`FluxTransformer2DModel`] forward method.
32 |
33 | Args:
34 | hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
35 | Input `hidden_states`.
36 | encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
37 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
38 | pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
39 | from the embeddings of input conditions.
40 | timestep ( `torch.LongTensor`):
41 | Used to indicate denoising step.
42 | block_controlnet_hidden_states: (`list` of `torch.Tensor`):
43 | A list of tensors that if specified are added to the residuals of transformer blocks.
44 | joint_attention_kwargs (`dict`, *optional*):
45 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
46 | `self.processor` in
47 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
48 | return_dict (`bool`, *optional*, defaults to `True`):
49 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
50 | tuple.
51 |
52 | Returns:
53 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
54 | `tuple` where the first element is the sample tensor.
55 | """
56 | if joint_attention_kwargs is not None:
57 | joint_attention_kwargs = joint_attention_kwargs.copy()
58 | lora_scale = joint_attention_kwargs.pop("scale", 1.0)
59 | else:
60 | lora_scale = 1.0
61 |
62 | if USE_PEFT_BACKEND:
63 | # weight the lora layers by setting `lora_scale` for each PEFT layer
64 | scale_lora_layers(self, lora_scale)
65 | else:
66 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
67 | logger.warning(
68 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
69 | )
70 |
71 | do_nag = hidden_states.shape[0] != encoder_hidden_states.shape[0]
72 |
73 | hidden_states = self.x_embedder(hidden_states)
74 |
75 | timestep = timestep.to(hidden_states.dtype) * 1000
76 | if guidance is not None:
77 | guidance = guidance.to(hidden_states.dtype) * 1000
78 | else:
79 | guidance = None
80 |
81 | temb = (
82 | self.time_text_embed(timestep, pooled_projections)
83 | if guidance is None
84 | else self.time_text_embed(timestep, guidance, pooled_projections)
85 | )
86 | encoder_hidden_states = self.context_embedder(encoder_hidden_states)
87 |
88 | if txt_ids.ndim == 3:
89 | logger.warning(
90 | "Passing `txt_ids` 3d torch.Tensor is deprecated."
91 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
92 | )
93 | txt_ids = txt_ids[0]
94 | if img_ids.ndim == 3:
95 | logger.warning(
96 | "Passing `img_ids` 3d torch.Tensor is deprecated."
97 | "Please remove the batch dimension and pass it as a 2d torch Tensor"
98 | )
99 | img_ids = img_ids[0]
100 |
101 | ids = torch.cat((txt_ids, img_ids), dim=0)
102 | image_rotary_emb = self.pos_embed(ids)
103 |
104 | if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
105 | ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
106 | ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
107 | joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
108 |
109 | for index_block, block in enumerate(self.transformer_blocks):
110 | if torch.is_grad_enabled() and self.gradient_checkpointing:
111 | encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
112 | block,
113 | hidden_states,
114 | encoder_hidden_states,
115 | temb,
116 | image_rotary_emb,
117 | )
118 |
119 | else:
120 | encoder_hidden_states, hidden_states = block(
121 | hidden_states=hidden_states,
122 | encoder_hidden_states=encoder_hidden_states,
123 | temb=temb,
124 | image_rotary_emb=image_rotary_emb,
125 | joint_attention_kwargs=joint_attention_kwargs,
126 | )
127 |
128 | # controlnet residual
129 | if controlnet_block_samples is not None:
130 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
131 | interval_control = int(np.ceil(interval_control))
132 | # For Xlabs ControlNet.
133 | if controlnet_blocks_repeat:
134 | hidden_states = (
135 | hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
136 | )
137 | else:
138 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
139 |
140 | if do_nag:
141 | hidden_states = hidden_states.tile(2, 1, 1)
142 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
143 |
144 | for index_block, block in enumerate(self.single_transformer_blocks):
145 | if torch.is_grad_enabled() and self.gradient_checkpointing:
146 | hidden_states = self._gradient_checkpointing_func(
147 | block,
148 | hidden_states,
149 | temb,
150 | image_rotary_emb,
151 | )
152 |
153 | else:
154 | hidden_states = block(
155 | hidden_states=hidden_states,
156 | temb=temb,
157 | image_rotary_emb=image_rotary_emb,
158 | joint_attention_kwargs=joint_attention_kwargs,
159 | )
160 |
161 | # controlnet residual
162 | if controlnet_single_block_samples is not None:
163 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
164 | interval_control = int(np.ceil(interval_control))
165 | controlnet_single_block_sample = controlnet_single_block_samples[index_block // interval_control]
166 | if do_nag:
167 | controlnet_single_block_sample = controlnet_single_block_sample.tile(2, 1, 1)
168 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
169 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] + controlnet_single_block_sample
170 | )
171 |
172 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
173 |
174 | if do_nag:
175 |             hidden_states = torch.chunk(hidden_states, 2, dim=0)[0]  # keep only the positive-prompt half of the batch
176 |
177 | hidden_states = self.norm_out(hidden_states, temb)
178 | output = self.proj_out(hidden_states)
179 |
180 | if USE_PEFT_BACKEND:
181 | # remove `lora_scale` from each PEFT layer
182 | unscale_lora_layers(self, lora_scale)
183 |
184 | if not return_dict:
185 | return (output,)
186 |
187 | return Transformer2DModelOutput(sample=output)
188 |
--------------------------------------------------------------------------------
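
A minimal, self-contained sketch (dummy tensors only, not the real model) of the NAG batch bookkeeping used in the single-stream pass above: the prompt batch arrives doubled (positive + NAG-negative), so the image tokens are tiled to match, concatenated with the text tokens, and only the positive half is kept at the end.

import torch

batch, img_tokens, txt_tokens, dim = 1, 16, 8, 4
hidden_states = torch.randn(batch, img_tokens, dim)                   # image stream
encoder_hidden_states = torch.randn(2 * batch, txt_tokens, dim)       # positive + NAG-negative prompts

do_nag = hidden_states.shape[0] != encoder_hidden_states.shape[0]
if do_nag:
    hidden_states = hidden_states.tile(2, 1, 1)                       # duplicate image tokens
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

# ... the single-stream transformer blocks would run here ...

hidden_states = hidden_states[:, txt_tokens:, ...]                    # drop the text tokens
if do_nag:
    hidden_states = torch.chunk(hidden_states, 2, dim=0)[0]           # keep the positive-prompt half
print(hidden_states.shape)                                            # torch.Size([1, 16, 4])
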
/nag/transformer_wan_nag.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Union
2 |
3 | import torch
4 |
5 | from diffusers.utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
6 | from diffusers.models.modeling_outputs import Transformer2DModelOutput
7 | from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
8 | from diffusers.models.attention_processor import AttentionProcessor
9 |
10 |
11 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
12 |
13 |
14 | class NagWanTransformer3DModel(WanTransformer3DModel):
15 | @property
16 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
17 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
18 | r"""
19 | Returns:
20 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
21 |             indexed by their weight names.
22 | """
23 | # set recursively
24 | processors = {}
25 |
26 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
27 | if hasattr(module, "get_processor"):
28 | processors[f"{name}.processor"] = module.get_processor()
29 |
30 | for sub_name, child in module.named_children():
31 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
32 |
33 | return processors
34 |
35 | for name, module in self.named_children():
36 | fn_recursive_add_processors(name, module, processors)
37 |
38 | return processors
39 |
40 | # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
41 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
42 | r"""
43 | Sets the attention processor to use to compute attention.
44 |
45 | Parameters:
46 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
47 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
48 | for **all** `Attention` layers.
49 |
50 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
51 | processor. This is strongly recommended when setting trainable attention processors.
52 |
53 | """
54 | count = len(self.attn_processors.keys())
55 |
56 | if isinstance(processor, dict) and len(processor) != count:
57 | raise ValueError(
58 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
59 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
60 | )
61 |
62 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
63 | if hasattr(module, "set_processor"):
64 | if not isinstance(processor, dict):
65 | module.set_processor(processor)
66 | else:
67 | module.set_processor(processor.pop(f"{name}.processor"))
68 |
69 | for sub_name, child in module.named_children():
70 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
71 |
72 | for name, module in self.named_children():
73 | fn_recursive_attn_processor(name, module, processor)
74 |
75 | def forward(
76 | self,
77 | hidden_states: torch.Tensor,
78 | timestep: torch.LongTensor,
79 | encoder_hidden_states: torch.Tensor,
80 | encoder_hidden_states_image: Optional[torch.Tensor] = None,
81 | return_dict: bool = True,
82 | attention_kwargs: Optional[Dict[str, Any]] = None,
83 |     ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
84 | if attention_kwargs is not None:
85 | attention_kwargs = attention_kwargs.copy()
86 | lora_scale = attention_kwargs.pop("scale", 1.0)
87 | else:
88 | lora_scale = 1.0
89 |
90 | if USE_PEFT_BACKEND:
91 | # weight the lora layers by setting `lora_scale` for each PEFT layer
92 | scale_lora_layers(self, lora_scale)
93 | else:
94 |             if attention_kwargs is not None and lora_scale != 1.0:  # `scale` was popped above, so check the popped value
95 | logger.warning(
96 | "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
97 | )
98 |
99 | batch_size, num_channels, num_frames, height, width = hidden_states.shape
100 | p_t, p_h, p_w = self.config.patch_size
101 | post_patch_num_frames = num_frames // p_t
102 | post_patch_height = height // p_h
103 | post_patch_width = width // p_w
104 |
105 | rotary_emb = self.rope(hidden_states)
106 |
107 | hidden_states = self.patch_embedding(hidden_states)
108 | hidden_states = hidden_states.flatten(2).transpose(1, 2)
109 |
110 | temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
111 | timestep, encoder_hidden_states, encoder_hidden_states_image
112 | )
113 | timestep_proj = timestep_proj.unflatten(1, (6, -1))
114 |
115 | if encoder_hidden_states_image is not None:
116 | bs_encoder_hidden_states = len(encoder_hidden_states)
117 | bs_encoder_hidden_states_image = len(encoder_hidden_states_image)
118 | bs_scale = bs_encoder_hidden_states / bs_encoder_hidden_states_image
119 |             assert bs_scale in [1, 2, 3]  # prompt batch may be 1x, 2x or 3x the image-embed batch (guidance branches)
120 |             if bs_scale != 1:
121 |                 encoder_hidden_states_image = encoder_hidden_states_image.tile(int(bs_scale), 1, 1)  # repeat image embeds to match
122 | encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
123 |
124 | # 4. Transformer blocks
125 | if torch.is_grad_enabled() and self.gradient_checkpointing:
126 | for block in self.blocks:
127 | hidden_states = self._gradient_checkpointing_func(
128 | block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
129 | )
130 | else:
131 | for block in self.blocks:
132 | hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
133 |
134 | # 5. Output norm, projection & unpatchify
135 | shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
136 |
137 | # Move the shift and scale tensors to the same device as hidden_states.
138 | # When using multi-GPU inference via accelerate these will be on the
139 | # first device rather than the last device, which hidden_states ends up
140 | # on.
141 | shift = shift.to(hidden_states.device)
142 | scale = scale.to(hidden_states.device)
143 |
144 | hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
145 | hidden_states = self.proj_out(hidden_states)
146 |
147 | hidden_states = hidden_states.reshape(
148 | batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
149 | )
150 | hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
151 | output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
152 |
153 | if USE_PEFT_BACKEND:
154 | # remove `lora_scale` from each PEFT layer
155 | unscale_lora_layers(self, lora_scale)
156 |
157 | if not return_dict:
158 | return (output,)
159 |
160 | return Transformer2DModelOutput(sample=output)
--------------------------------------------------------------------------------
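
A hedged usage sketch of the `attn_processors` / `set_attn_processor` pair above (which mirrors the UNet2DConditionModel API). The checkpoint id below is only an example of a Wan model in Diffusers layout; the dict form must use the exact "<module path>.processor" keys that `attn_processors` returns.

import torch
from nag.transformer_wan_nag import NagWanTransformer3DModel

transformer = NagWanTransformer3DModel.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",   # example checkpoint id
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Keys follow the module tree, e.g. "blocks.0.attn1.processor", "blocks.0.attn2.processor", ...
processors = transformer.attn_processors
print(len(processors), sorted(processors)[:2])

# Pass either one processor instance for every Attention layer, or a dict keyed
# exactly like `attn_processors`; here the current processors are simply
# re-installed (set_attn_processor pops entries from the dict, so pass a copy).
transformer.set_attn_processor(dict(processors))
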
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | diffusers
3 | torch
4 | transformers
5 | sentencepiece
--------------------------------------------------------------------------------
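
The dependency list is unpinned. A quick sanity check (a sketch, not part of the repo) that the runtime requirements resolve in the current environment:

import importlib

for name in ("accelerate", "diffusers", "torch", "transformers", "sentencepiece"):
    module = importlib.import_module(name)
    print(name, getattr(module, "__version__", "unknown"))
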
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | # Read requirements.txt
4 | with open("requirements.txt", "r") as f:
5 | requirements = f.read().splitlines()
6 |
7 | setup(
8 | name="nag",
9 | version="0.0.0",
10 | description="Normalized Attention Guidance for Diffusion Models",
11 | author="ChenDarYen",
12 | packages=find_packages(include=["nag", "nag.*"]),
13 | install_requires=requirements,
14 | python_requires=">=3.10",
15 | )
--------------------------------------------------------------------------------
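
setup.py feeds requirements.txt to install_requires via a bare splitlines(), so any blank line or comment added to requirements.txt later would be forwarded to setuptools as-is. A slightly more defensive variant (a sketch, not a required change) filters those out:

with open("requirements.txt", "r") as f:
    requirements = [
        line.strip()
        for line in f
        if line.strip() and not line.strip().startswith("#")
    ]

After installation (for example with `pip install -e .`), the modules are importable under the paths shown in the tree, e.g. nag.transformer_wan_nag.NagWanTransformer3DModel.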