├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── figure4.png
│   └── result.png
├── requirements.txt
├── scedit_pytorch
│   ├── __init__.py
│   ├── diffusers_modules
│   │   ├── __init__.py
│   │   ├── unet_2d_blocks.py
│   │   └── unet_2d_condition.py
│   ├── scedit.py
│   └── utils.py
├── scripts
│   ├── app.py
│   └── scedit_pytorch.ipynb
└── train_dreambooth_scedit_sdxl.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | tests
163 | test.py
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 mkshing
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SCEdit-pytorch
2 |
3 |
4 |
5 | This is an implementation of [SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing](https://scedit.github.io/) by [mkshing](https://twitter.com/mk1stats).
6 |
7 | 
8 |
9 | * Beyond the paper, this implementation can use SDXL as the pre-trained model.
10 | * The weight scale of the SC-Tuner output can be set via `scale`.
11 | * As the paper says, the architecture of SCEdit is very flexible. The `SCTunerLinearLayer` I implemented seems smaller than what the paper describes, so please let me know if you find a better one (a rough sketch of the idea is shown below).
12 |
13 | 
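For reference, a rough sketch of the kind of layer meant above: a residual linear (1×1) edit of the skip-connection features, weighted by `scale`. The class name, initialization, and exact wiring here are illustrative assumptions, not the actual code in `scedit_pytorch/scedit.py`.

```python
import torch
from torch import nn


class SCTunerLinearLayerSketch(nn.Module):
    """Illustrative SC-Tuner: residually edits skip-connection feature maps."""

    def __init__(self, dim: int, scale: float = 1.0):
        super().__init__()
        # A 1x1 convolution acts as a per-position linear layer over the channel dimension.
        self.tuner = nn.Conv2d(dim, dim, kernel_size=1)
        # Zero init so the tuner starts out as an identity mapping of the skip connection.
        nn.init.zeros_(self.tuner.weight)
        nn.init.zeros_(self.tuner.bias)
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Edited skip connection: original features plus a scaled, learned edit.
        return x + self.scale * self.tuner(x)
```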
14 |
15 | ## Installation
16 |
17 | ```bash
18 | git clone https://github.com/mkshing/scedit-pytorch.git
19 | cd scedit-pytorch
20 | pip install -r requirements.txt
21 | ```
22 |
23 | ## SC-Tuner
24 |
25 | ### Training
26 | The training script is pretty much the same as the [LoRA DreamBooth SDXL script from diffusers](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md).
27 |
28 |
29 | ```bash
30 | MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
31 | INSTANCE_DIR="path-to-dataset"
32 | OUTPUT_DIR="scedit-trained-xl"
33 |
34 | accelerate launch train_dreambooth_scedit_sdxl.py \
35 | --pretrained_model_name_or_path=$MODEL_NAME \
36 | --instance_data_dir=$INSTANCE_DIR \
37 | --output_dir=$OUTPUT_DIR \
38 | --mixed_precision="fp16" \
39 | --instance_prompt="a photo of sbu dog" \
40 | --resolution=1024 \
41 | --train_batch_size=1 \
42 | --gradient_accumulation_steps=8 \
43 | --learning_rate=5e-5 \
44 | --lr_scheduler="constant" \
45 | --lr_warmup_steps=0 \
46 | --max_train_steps=1000 \
47 | --checkpointing_steps=200 \
48 | --validation_prompt="A photo of sbu dog in a bucket" \
49 | --validation_epochs=100 \
50 | --use_8bit_adam \
51 | --report_to="wandb" \
52 | --seed="0" \
53 | --push_to_hub
54 |
55 | ```
56 |
57 | ### Inference
58 |
59 | #### Python example:
60 | ```python
61 | from diffusers import DiffusionPipeline
62 | import torch
63 | from scedit_pytorch import UNet2DConditionModel, load_scedit_into_unet
64 |
65 |
66 | base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
67 | scedit_model_id = "path-to-scedit"
68 |
69 | # load unet with sctuner
70 | unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet")
71 | unet.set_sctuner(scale=1.0)
72 | unet = load_scedit_into_unet(scedit_model_id, unet)
73 | # load pipeline
74 | pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet)
75 | pipe = pipe.to(device="cuda", dtype=torch.float16)
76 | ```
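Once loaded, the pipeline is called like any other diffusers pipeline. For example (prompt and filename are illustrative):

```python
image = pipe(prompt="A photo of sbu dog in a bucket").images[0]
image.save("sbu_dog.png")
```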
77 |
78 | #### Gradio Demo:
79 | ```bash
80 | MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
81 | SCEDIT_NAME="mkshing/scedit-trained-xl"
82 |
83 | python scripts/app.py \
84 | --pretrained_model_name_or_path $MODEL_NAME \
85 | --scedit_name_or_path $SCEDIT_NAME
86 | ```
87 |
88 | ## TODO
89 | - [x] SC-Tuner
90 | - [ ] CSC-Tuner
--------------------------------------------------------------------------------
/assets/figure4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/scedit-pytorch/41b219515161fd78872254696c1212573867f372/assets/figure4.png
--------------------------------------------------------------------------------
/assets/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/scedit-pytorch/41b219515161fd78872254696c1212573867f372/assets/result.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers>=0.24.0
2 | accelerate>=0.16.0
3 | torchvision
4 | transformers>=4.25.1
5 | ftfy
6 | tensorboard
7 | Jinja2
8 | gradio
9 |
--------------------------------------------------------------------------------
/scedit_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from .diffusers_modules import UNet2DConditionModel
2 | from .utils import save_scedit, load_scedit_into_unet
3 |
--------------------------------------------------------------------------------
/scedit_pytorch/diffusers_modules/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | All files in this folder were modified for SCEdit based on https://github.com/huggingface/diffusers/tree/v0.24.0/src/diffusers/models.
3 | In particular, SC-Tuners are inserted into the up blocks (CrossAttnUpBlock2D and UpBlock2D).
4 | """
5 | from .unet_2d_condition import UNet2DConditionModel
6 |
--------------------------------------------------------------------------------
/scedit_pytorch/diffusers_modules/unet_2d_blocks.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional, Tuple, Union
2 |
3 | import torch
4 | from torch import nn
5 |
6 | from diffusers.utils import is_torch_version, logging
7 | from diffusers.utils.torch_utils import apply_freeu
8 | from diffusers.models.dual_transformer_2d import DualTransformer2DModel
9 | from diffusers.models.resnet import ResnetBlock2D, Upsample2D
10 | from diffusers.models.transformer_2d import Transformer2DModel
11 | from ..scedit import SCTuner
12 |
13 |
14 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
15 |
16 |
17 | def get_up_block(
18 | up_block_type: str,
19 | num_layers: int,
20 | in_channels: int,
21 | out_channels: int,
22 | prev_output_channel: int,
23 | temb_channels: int,
24 | add_upsample: bool,
25 | resnet_eps: float,
26 | resnet_act_fn: str,
27 | resolution_idx: Optional[int] = None,
28 | transformer_layers_per_block: int = 1,
29 | num_attention_heads: Optional[int] = None,
30 | resnet_groups: Optional[int] = None,
31 | cross_attention_dim: Optional[int] = None,
32 | dual_cross_attention: bool = False,
33 | use_linear_projection: bool = False,
34 | only_cross_attention: bool = False,
35 | upcast_attention: bool = False,
36 | resnet_time_scale_shift: str = "default",
37 | attention_type: str = "default",
38 | resnet_skip_time_act: bool = False,
39 | resnet_out_scale_factor: float = 1.0,
40 | cross_attention_norm: Optional[str] = None,
41 | attention_head_dim: Optional[int] = None,
42 | upsample_type: Optional[str] = None,
43 | dropout: float = 0.0,
44 | ) -> nn.Module:
45 | # If attn head dim is not defined, we default it to the number of heads
46 | if attention_head_dim is None:
47 |         logger.warning(
48 | f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
49 | )
50 | attention_head_dim = num_attention_heads
51 |
52 | up_block_type = (
53 | up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
54 | )
55 | if up_block_type == "UpBlock2D":
56 | return UpBlock2D(
57 | num_layers=num_layers,
58 | in_channels=in_channels,
59 | out_channels=out_channels,
60 | prev_output_channel=prev_output_channel,
61 | temb_channels=temb_channels,
62 | resolution_idx=resolution_idx,
63 | dropout=dropout,
64 | add_upsample=add_upsample,
65 | resnet_eps=resnet_eps,
66 | resnet_act_fn=resnet_act_fn,
67 | resnet_groups=resnet_groups,
68 | resnet_time_scale_shift=resnet_time_scale_shift,
69 | )
70 | elif up_block_type == "CrossAttnUpBlock2D":
71 | if cross_attention_dim is None:
72 | raise ValueError(
73 | "cross_attention_dim must be specified for CrossAttnUpBlock2D"
74 | )
75 | return CrossAttnUpBlock2D(
76 | num_layers=num_layers,
77 | transformer_layers_per_block=transformer_layers_per_block,
78 | in_channels=in_channels,
79 | out_channels=out_channels,
80 | prev_output_channel=prev_output_channel,
81 | temb_channels=temb_channels,
82 | resolution_idx=resolution_idx,
83 | dropout=dropout,
84 | add_upsample=add_upsample,
85 | resnet_eps=resnet_eps,
86 | resnet_act_fn=resnet_act_fn,
87 | resnet_groups=resnet_groups,
88 | cross_attention_dim=cross_attention_dim,
89 | num_attention_heads=num_attention_heads,
90 | dual_cross_attention=dual_cross_attention,
91 | use_linear_projection=use_linear_projection,
92 | only_cross_attention=only_cross_attention,
93 | upcast_attention=upcast_attention,
94 | resnet_time_scale_shift=resnet_time_scale_shift,
95 | attention_type=attention_type,
96 | )
97 |
98 |     # Only UpBlock2D and CrossAttnUpBlock2D are implemented in this SCEdit port.
99 |     raise NotImplementedError(f"{up_block_type} is not supported.")
100 |
101 |
102 | class CrossAttnUpBlock2D(SCTuner):
103 | def __init__(
104 | self,
105 | in_channels: int,
106 | out_channels: int,
107 | prev_output_channel: int,
108 | temb_channels: int,
109 | resolution_idx: Optional[int] = None,
110 | dropout: float = 0.0,
111 | num_layers: int = 1,
112 | transformer_layers_per_block: Union[int, Tuple[int]] = 1,
113 | resnet_eps: float = 1e-6,
114 | resnet_time_scale_shift: str = "default",
115 | resnet_act_fn: str = "swish",
116 | resnet_groups: int = 32,
117 | resnet_pre_norm: bool = True,
118 | num_attention_heads: int = 1,
119 | cross_attention_dim: int = 1280,
120 | output_scale_factor: float = 1.0,
121 | add_upsample: bool = True,
122 | dual_cross_attention: bool = False,
123 | use_linear_projection: bool = False,
124 | only_cross_attention: bool = False,
125 | upcast_attention: bool = False,
126 | attention_type: str = "default",
127 | ):
128 | super().__init__()
129 | resnets = []
130 | attentions = []
131 |         self.res_skip_channels = []  # SCEdit: record the channel size of each skip connection
132 |
133 | self.has_cross_attention = True
134 | self.num_attention_heads = num_attention_heads
135 |
136 | if isinstance(transformer_layers_per_block, int):
137 | transformer_layers_per_block = [transformer_layers_per_block] * num_layers
138 |
139 | for i in range(num_layers):
140 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
141 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
142 | self.res_skip_channels.append(res_skip_channels)
143 | resnets.append(
144 | ResnetBlock2D(
145 | in_channels=resnet_in_channels + res_skip_channels,
146 | out_channels=out_channels,
147 | temb_channels=temb_channels,
148 | eps=resnet_eps,
149 | groups=resnet_groups,
150 | dropout=dropout,
151 | time_embedding_norm=resnet_time_scale_shift,
152 | non_linearity=resnet_act_fn,
153 | output_scale_factor=output_scale_factor,
154 | pre_norm=resnet_pre_norm,
155 | )
156 | )
157 | if not dual_cross_attention:
158 | attentions.append(
159 | Transformer2DModel(
160 | num_attention_heads,
161 | out_channels // num_attention_heads,
162 | in_channels=out_channels,
163 | num_layers=transformer_layers_per_block[i],
164 | cross_attention_dim=cross_attention_dim,
165 | norm_num_groups=resnet_groups,
166 | use_linear_projection=use_linear_projection,
167 | only_cross_attention=only_cross_attention,
168 | upcast_attention=upcast_attention,
169 | attention_type=attention_type,
170 | )
171 | )
172 | else:
173 | attentions.append(
174 | DualTransformer2DModel(
175 | num_attention_heads,
176 | out_channels // num_attention_heads,
177 | in_channels=out_channels,
178 | num_layers=1,
179 | cross_attention_dim=cross_attention_dim,
180 | norm_num_groups=resnet_groups,
181 | )
182 | )
183 | self.attentions = nn.ModuleList(attentions)
184 | self.resnets = nn.ModuleList(resnets)
185 |
186 | if add_upsample:
187 | self.upsamplers = nn.ModuleList(
188 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
189 | )
190 | else:
191 | self.upsamplers = None
192 |
193 | self.gradient_checkpointing = False
194 | self.resolution_idx = resolution_idx
195 |
196 | def forward(
197 | self,
198 | hidden_states: torch.FloatTensor,
199 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
200 | temb: Optional[torch.FloatTensor] = None,
201 | encoder_hidden_states: Optional[torch.FloatTensor] = None,
202 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
203 | upsample_size: Optional[int] = None,
204 | attention_mask: Optional[torch.FloatTensor] = None,
205 | encoder_attention_mask: Optional[torch.FloatTensor] = None,
206 | ) -> torch.FloatTensor:
207 | lora_scale = (
208 | cross_attention_kwargs.get("scale", 1.0)
209 | if cross_attention_kwargs is not None
210 | else 1.0
211 | )
212 | is_freeu_enabled = (
213 | getattr(self, "s1", None)
214 | and getattr(self, "s2", None)
215 | and getattr(self, "b1", None)
216 | and getattr(self, "b2", None)
217 | )
218 |
219 | for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)):
220 | # pop res hidden states
221 | res_hidden_states = res_hidden_states_tuple[-1]
222 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
223 |
224 | # FreeU: Only operate on the first two stages
225 | if is_freeu_enabled:
226 | hidden_states, res_hidden_states = apply_freeu(
227 | self.resolution_idx,
228 | hidden_states,
229 | res_hidden_states,
230 | s1=self.s1,
231 | s2=self.s2,
232 | b1=self.b1,
233 | b2=self.b2,
234 | )
235 |             if self.sc_tuners is not None:  # SCEdit: edit the skip-connection features before concatenation
236 | res_hidden_states = self.sc_tuners[i](res_hidden_states)
237 |
238 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
239 |
240 | if self.training and self.gradient_checkpointing:
241 |
242 | def create_custom_forward(module, return_dict=None):
243 | def custom_forward(*inputs):
244 | if return_dict is not None:
245 | return module(*inputs, return_dict=return_dict)
246 | else:
247 | return module(*inputs)
248 |
249 | return custom_forward
250 |
251 | ckpt_kwargs: Dict[str, Any] = (
252 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
253 | )
254 | hidden_states = torch.utils.checkpoint.checkpoint(
255 | create_custom_forward(resnet),
256 | hidden_states,
257 | temb,
258 | **ckpt_kwargs,
259 | )
260 | hidden_states = attn(
261 | hidden_states,
262 | encoder_hidden_states=encoder_hidden_states,
263 | cross_attention_kwargs=cross_attention_kwargs,
264 | attention_mask=attention_mask,
265 | encoder_attention_mask=encoder_attention_mask,
266 | return_dict=False,
267 | )[0]
268 | else:
269 | hidden_states = resnet(hidden_states, temb, scale=lora_scale)
270 | hidden_states = attn(
271 | hidden_states,
272 | encoder_hidden_states=encoder_hidden_states,
273 | cross_attention_kwargs=cross_attention_kwargs,
274 | attention_mask=attention_mask,
275 | encoder_attention_mask=encoder_attention_mask,
276 | return_dict=False,
277 | )[0]
278 |
279 | if self.upsamplers is not None:
280 | for upsampler in self.upsamplers:
281 | hidden_states = upsampler(
282 | hidden_states, upsample_size, scale=lora_scale
283 | )
284 |
285 | return hidden_states
286 |
287 |
288 | class UpBlock2D(SCTuner):
289 | def __init__(
290 | self,
291 | in_channels: int,
292 | prev_output_channel: int,
293 | out_channels: int,
294 | temb_channels: int,
295 | resolution_idx: Optional[int] = None,
296 | dropout: float = 0.0,
297 | num_layers: int = 1,
298 | resnet_eps: float = 1e-6,
299 | resnet_time_scale_shift: str = "default",
300 | resnet_act_fn: str = "swish",
301 | resnet_groups: int = 32,
302 | resnet_pre_norm: bool = True,
303 | output_scale_factor: float = 1.0,
304 | add_upsample: bool = True,
305 | ):
306 | super().__init__()
307 | resnets = []
308 |         self.res_skip_channels = []  # SCEdit: record the channel size of each skip connection
309 |
310 | for i in range(num_layers):
311 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
312 | resnet_in_channels = prev_output_channel if i == 0 else out_channels
313 | self.res_skip_channels.append(res_skip_channels)
314 | resnets.append(
315 | ResnetBlock2D(
316 | in_channels=resnet_in_channels + res_skip_channels,
317 | out_channels=out_channels,
318 | temb_channels=temb_channels,
319 | eps=resnet_eps,
320 | groups=resnet_groups,
321 | dropout=dropout,
322 | time_embedding_norm=resnet_time_scale_shift,
323 | non_linearity=resnet_act_fn,
324 | output_scale_factor=output_scale_factor,
325 | pre_norm=resnet_pre_norm,
326 | )
327 | )
328 |
329 | self.resnets = nn.ModuleList(resnets)
330 |
331 | if add_upsample:
332 | self.upsamplers = nn.ModuleList(
333 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
334 | )
335 | else:
336 | self.upsamplers = None
337 |
338 | self.gradient_checkpointing = False
339 | self.resolution_idx = resolution_idx
340 |
341 | def forward(
342 | self,
343 | hidden_states: torch.FloatTensor,
344 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
345 | temb: Optional[torch.FloatTensor] = None,
346 | upsample_size: Optional[int] = None,
347 | scale: float = 1.0,
348 | ) -> torch.FloatTensor:
349 | is_freeu_enabled = (
350 | getattr(self, "s1", None)
351 | and getattr(self, "s2", None)
352 | and getattr(self, "b1", None)
353 | and getattr(self, "b2", None)
354 | )
355 |
356 | for i, resnet in enumerate(self.resnets):
357 | # pop res hidden states
358 | res_hidden_states = res_hidden_states_tuple[-1]
359 | res_hidden_states_tuple = res_hidden_states_tuple[:-1]
360 |
361 | # FreeU: Only operate on the first two stages
362 | if is_freeu_enabled:
363 | hidden_states, res_hidden_states = apply_freeu(
364 | self.resolution_idx,
365 | hidden_states,
366 | res_hidden_states,
367 | s1=self.s1,
368 | s2=self.s2,
369 | b1=self.b1,
370 | b2=self.b2,
371 | )
372 |             if self.sc_tuners is not None:  # SCEdit: edit the skip-connection features before concatenation
373 | res_hidden_states = self.sc_tuners[i](res_hidden_states)
374 |
375 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
376 |
377 | if self.training and self.gradient_checkpointing:
378 |
379 | def create_custom_forward(module):
380 | def custom_forward(*inputs):
381 | return module(*inputs)
382 |
383 | return custom_forward
384 |
385 | if is_torch_version(">=", "1.11.0"):
386 | hidden_states = torch.utils.checkpoint.checkpoint(
387 | create_custom_forward(resnet),
388 | hidden_states,
389 | temb,
390 | use_reentrant=False,
391 | )
392 | else:
393 | hidden_states = torch.utils.checkpoint.checkpoint(
394 | create_custom_forward(resnet), hidden_states, temb
395 | )
396 | else:
397 | hidden_states = resnet(hidden_states, temb, scale=scale)
398 |
399 | if self.upsamplers is not None:
400 | for upsampler in self.upsamplers:
401 | hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
402 |
403 | return hidden_states
404 |
--------------------------------------------------------------------------------
/scedit_pytorch/diffusers_modules/unet_2d_condition.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import Any, Dict, List, Optional, Tuple, Union
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.utils.checkpoint
20 |
21 | from diffusers.configuration_utils import ConfigMixin, register_to_config
22 | from diffusers.loaders import UNet2DConditionLoadersMixin
23 | from diffusers.utils import (
24 | USE_PEFT_BACKEND,
25 | BaseOutput,
26 | deprecate,
27 | logging,
28 | scale_lora_layers,
29 | unscale_lora_layers,
30 | )
31 | from diffusers.models.activations import get_activation
32 | from diffusers.models.attention_processor import (
33 | ADDED_KV_ATTENTION_PROCESSORS,
34 | CROSS_ATTENTION_PROCESSORS,
35 | AttentionProcessor,
36 | AttnAddedKVProcessor,
37 | AttnProcessor,
38 | )
39 | from diffusers.models.embeddings import (
40 | GaussianFourierProjection,
41 | ImageHintTimeEmbedding,
42 | ImageProjection,
43 | ImageTimeEmbedding,
44 | PositionNet,
45 | TextImageProjection,
46 | TextImageTimeEmbedding,
47 | TextTimeEmbedding,
48 | TimestepEmbedding,
49 | Timesteps,
50 | )
51 | from diffusers.models.modeling_utils import ModelMixin
52 | from diffusers.models.unet_2d_blocks import (
53 | UNetMidBlock2D,
54 | UNetMidBlock2DCrossAttn,
55 | UNetMidBlock2DSimpleCrossAttn,
56 | get_down_block,
57 | )
58 | from .unet_2d_blocks import get_up_block
59 | from ..scedit import SCTunerLinearLayer
60 |
61 |
62 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
63 |
64 |
65 | @dataclass
66 | class UNet2DConditionOutput(BaseOutput):
67 | """
68 | The output of [`UNet2DConditionModel`].
69 |
70 | Args:
71 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
72 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
73 | """
74 |
75 | sample: torch.FloatTensor = None
76 |
77 |
78 | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
79 | r"""
80 | A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
81 | shaped output.
82 |
83 |     This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
84 | for all models (such as downloading or saving).
85 |
86 | Parameters:
87 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
88 | Height and width of input/output sample.
89 | in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
90 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
91 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
92 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
93 | Whether to flip the sin to cos in the time embedding.
94 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
95 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
96 | The tuple of downsample blocks to use.
97 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
98 | Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
99 | `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
100 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
101 | The tuple of upsample blocks to use.
102 | only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
103 | Whether to include self-attention in the basic transformer blocks, see
104 | [`~models.attention.BasicTransformerBlock`].
105 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
106 | The tuple of output channels for each block.
107 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
108 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
109 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
110 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
111 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
112 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
113 |             If `None`, normalization and activation layers are skipped in post-processing.
114 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
115 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
116 | The dimension of the cross attention features.
117 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
118 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
119 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
120 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
121 | reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
122 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
123 | blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
124 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
125 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
126 | encoder_hid_dim (`int`, *optional*, defaults to None):
127 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
128 | dimension to `cross_attention_dim`.
129 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
130 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
131 | embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
132 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
133 | num_attention_heads (`int`, *optional*):
134 | The number of attention heads. If not defined, defaults to `attention_head_dim`
135 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
136 | for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
137 | class_embed_type (`str`, *optional*, defaults to `None`):
138 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
139 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
140 | addition_embed_type (`str`, *optional*, defaults to `None`):
141 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
142 | "text". "text" will use the `TextTimeEmbedding` layer.
143 | addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
144 | Dimension for the timestep embeddings.
145 | num_class_embeds (`int`, *optional*, defaults to `None`):
146 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
147 | class conditioning with `class_embed_type` equal to `None`.
148 | time_embedding_type (`str`, *optional*, defaults to `positional`):
149 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
150 | time_embedding_dim (`int`, *optional*, defaults to `None`):
151 | An optional override for the dimension of the projected time embedding.
152 | time_embedding_act_fn (`str`, *optional*, defaults to `None`):
153 | Optional activation function to use only once on the time embeddings before they are passed to the rest of
154 | the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
155 | timestep_post_act (`str`, *optional*, defaults to `None`):
156 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
157 | time_cond_proj_dim (`int`, *optional*, defaults to `None`):
158 | The dimension of `cond_proj` layer in the timestep embedding.
159 |         conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer.
160 |         conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer.
161 |         projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input
162 |             when `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
163 | class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
164 | embeddings with the class embeddings.
165 | mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
166 | Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
167 | `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
168 | `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
169 | otherwise.
170 | """
171 |
172 | _supports_gradient_checkpointing = True
173 |
174 | @register_to_config
175 | def __init__(
176 | self,
177 | sample_size: Optional[int] = None,
178 | in_channels: int = 4,
179 | out_channels: int = 4,
180 | center_input_sample: bool = False,
181 | flip_sin_to_cos: bool = True,
182 | freq_shift: int = 0,
183 | down_block_types: Tuple[str] = (
184 | "CrossAttnDownBlock2D",
185 | "CrossAttnDownBlock2D",
186 | "CrossAttnDownBlock2D",
187 | "DownBlock2D",
188 | ),
189 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
190 | up_block_types: Tuple[str] = (
191 | "UpBlock2D",
192 | "CrossAttnUpBlock2D",
193 | "CrossAttnUpBlock2D",
194 | "CrossAttnUpBlock2D",
195 | ),
196 | only_cross_attention: Union[bool, Tuple[bool]] = False,
197 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
198 | layers_per_block: Union[int, Tuple[int]] = 2,
199 | downsample_padding: int = 1,
200 | mid_block_scale_factor: float = 1,
201 | dropout: float = 0.0,
202 | act_fn: str = "silu",
203 | norm_num_groups: Optional[int] = 32,
204 | norm_eps: float = 1e-5,
205 | cross_attention_dim: Union[int, Tuple[int]] = 1280,
206 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
207 | reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
208 | encoder_hid_dim: Optional[int] = None,
209 | encoder_hid_dim_type: Optional[str] = None,
210 | attention_head_dim: Union[int, Tuple[int]] = 8,
211 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
212 | dual_cross_attention: bool = False,
213 | use_linear_projection: bool = False,
214 | class_embed_type: Optional[str] = None,
215 | addition_embed_type: Optional[str] = None,
216 | addition_time_embed_dim: Optional[int] = None,
217 | num_class_embeds: Optional[int] = None,
218 | upcast_attention: bool = False,
219 | resnet_time_scale_shift: str = "default",
220 | resnet_skip_time_act: bool = False,
221 |         resnet_out_scale_factor: float = 1.0,
222 | time_embedding_type: str = "positional",
223 | time_embedding_dim: Optional[int] = None,
224 | time_embedding_act_fn: Optional[str] = None,
225 | timestep_post_act: Optional[str] = None,
226 | time_cond_proj_dim: Optional[int] = None,
227 | conv_in_kernel: int = 3,
228 | conv_out_kernel: int = 3,
229 | projection_class_embeddings_input_dim: Optional[int] = None,
230 | attention_type: str = "default",
231 | class_embeddings_concat: bool = False,
232 | mid_block_only_cross_attention: Optional[bool] = None,
233 | cross_attention_norm: Optional[str] = None,
234 | addition_embed_type_num_heads=64,
235 | ):
236 | super().__init__()
237 |
238 | self.sample_size = sample_size
239 |
240 | if num_attention_heads is not None:
241 | raise ValueError(
242 | "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
243 | )
244 |
245 | # If `num_attention_heads` is not defined (which is the case for most models)
246 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
247 | # The reason for this behavior is to correct for incorrectly named variables that were introduced
248 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
249 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
250 | # which is why we correct for the naming here.
251 | num_attention_heads = num_attention_heads or attention_head_dim
252 |
253 | # Check inputs
254 | if len(down_block_types) != len(up_block_types):
255 | raise ValueError(
256 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
257 | )
258 |
259 | if len(block_out_channels) != len(down_block_types):
260 | raise ValueError(
261 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
262 | )
263 |
264 | if not isinstance(only_cross_attention, bool) and len(
265 | only_cross_attention
266 | ) != len(down_block_types):
267 | raise ValueError(
268 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
269 | )
270 |
271 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
272 | down_block_types
273 | ):
274 | raise ValueError(
275 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
276 | )
277 |
278 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
279 | down_block_types
280 | ):
281 | raise ValueError(
282 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
283 | )
284 |
285 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
286 | down_block_types
287 | ):
288 | raise ValueError(
289 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
290 | )
291 |
292 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
293 | down_block_types
294 | ):
295 | raise ValueError(
296 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
297 | )
298 | if (
299 | isinstance(transformer_layers_per_block, list)
300 | and reverse_transformer_layers_per_block is None
301 | ):
302 | for layer_number_per_block in transformer_layers_per_block:
303 | if isinstance(layer_number_per_block, list):
304 | raise ValueError(
305 | "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet."
306 | )
307 |
308 | # input
309 | conv_in_padding = (conv_in_kernel - 1) // 2
310 | self.conv_in = nn.Conv2d(
311 | in_channels,
312 | block_out_channels[0],
313 | kernel_size=conv_in_kernel,
314 | padding=conv_in_padding,
315 | )
316 |
317 | # time
318 | if time_embedding_type == "fourier":
319 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
320 | if time_embed_dim % 2 != 0:
321 | raise ValueError(
322 | f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
323 | )
324 | self.time_proj = GaussianFourierProjection(
325 | time_embed_dim // 2,
326 | set_W_to_weight=False,
327 | log=False,
328 | flip_sin_to_cos=flip_sin_to_cos,
329 | )
330 | timestep_input_dim = time_embed_dim
331 | elif time_embedding_type == "positional":
332 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
333 |
334 | self.time_proj = Timesteps(
335 | block_out_channels[0], flip_sin_to_cos, freq_shift
336 | )
337 | timestep_input_dim = block_out_channels[0]
338 | else:
339 | raise ValueError(
340 | f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
341 | )
342 |
343 | self.time_embedding = TimestepEmbedding(
344 | timestep_input_dim,
345 | time_embed_dim,
346 | act_fn=act_fn,
347 | post_act_fn=timestep_post_act,
348 | cond_proj_dim=time_cond_proj_dim,
349 | )
350 |
351 | if encoder_hid_dim_type is None and encoder_hid_dim is not None:
352 | encoder_hid_dim_type = "text_proj"
353 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
354 | logger.info(
355 | "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
356 | )
357 |
358 | if encoder_hid_dim is None and encoder_hid_dim_type is not None:
359 | raise ValueError(
360 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
361 | )
362 |
363 | if encoder_hid_dim_type == "text_proj":
364 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
365 | elif encoder_hid_dim_type == "text_image_proj":
366 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
367 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
368 |             # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
369 | self.encoder_hid_proj = TextImageProjection(
370 | text_embed_dim=encoder_hid_dim,
371 | image_embed_dim=cross_attention_dim,
372 | cross_attention_dim=cross_attention_dim,
373 | )
374 | elif encoder_hid_dim_type == "image_proj":
375 | # Kandinsky 2.2
376 | self.encoder_hid_proj = ImageProjection(
377 | image_embed_dim=encoder_hid_dim,
378 | cross_attention_dim=cross_attention_dim,
379 | )
380 | elif encoder_hid_dim_type is not None:
381 | raise ValueError(
382 |                 f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
383 | )
384 | else:
385 | self.encoder_hid_proj = None
386 |
387 | # class embedding
388 | if class_embed_type is None and num_class_embeds is not None:
389 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
390 | elif class_embed_type == "timestep":
391 | self.class_embedding = TimestepEmbedding(
392 | timestep_input_dim, time_embed_dim, act_fn=act_fn
393 | )
394 | elif class_embed_type == "identity":
395 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
396 | elif class_embed_type == "projection":
397 | if projection_class_embeddings_input_dim is None:
398 | raise ValueError(
399 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
400 | )
401 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
402 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
403 | # 2. it projects from an arbitrary input dimension.
404 | #
405 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
406 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
407 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
408 | self.class_embedding = TimestepEmbedding(
409 | projection_class_embeddings_input_dim, time_embed_dim
410 | )
411 | elif class_embed_type == "simple_projection":
412 | if projection_class_embeddings_input_dim is None:
413 | raise ValueError(
414 | "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
415 | )
416 | self.class_embedding = nn.Linear(
417 | projection_class_embeddings_input_dim, time_embed_dim
418 | )
419 | else:
420 | self.class_embedding = None
421 |
422 | if addition_embed_type == "text":
423 | if encoder_hid_dim is not None:
424 | text_time_embedding_from_dim = encoder_hid_dim
425 | else:
426 | text_time_embedding_from_dim = cross_attention_dim
427 |
428 | self.add_embedding = TextTimeEmbedding(
429 | text_time_embedding_from_dim,
430 | time_embed_dim,
431 | num_heads=addition_embed_type_num_heads,
432 | )
433 | elif addition_embed_type == "text_image":
434 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
435 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
436 |             # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
437 | self.add_embedding = TextImageTimeEmbedding(
438 | text_embed_dim=cross_attention_dim,
439 | image_embed_dim=cross_attention_dim,
440 | time_embed_dim=time_embed_dim,
441 | )
442 | elif addition_embed_type == "text_time":
443 | self.add_time_proj = Timesteps(
444 | addition_time_embed_dim, flip_sin_to_cos, freq_shift
445 | )
446 | self.add_embedding = TimestepEmbedding(
447 | projection_class_embeddings_input_dim, time_embed_dim
448 | )
449 | elif addition_embed_type == "image":
450 | # Kandinsky 2.2
451 | self.add_embedding = ImageTimeEmbedding(
452 | image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
453 | )
454 | elif addition_embed_type == "image_hint":
455 | # Kandinsky 2.2 ControlNet
456 | self.add_embedding = ImageHintTimeEmbedding(
457 | image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
458 | )
459 | elif addition_embed_type is not None:
460 | raise ValueError(
461 |                 f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'."
462 | )
463 |
464 | if time_embedding_act_fn is None:
465 | self.time_embed_act = None
466 | else:
467 | self.time_embed_act = get_activation(time_embedding_act_fn)
468 |
469 | self.down_blocks = nn.ModuleList([])
470 | self.up_blocks = nn.ModuleList([])
471 |
472 | if isinstance(only_cross_attention, bool):
473 | if mid_block_only_cross_attention is None:
474 | mid_block_only_cross_attention = only_cross_attention
475 |
476 | only_cross_attention = [only_cross_attention] * len(down_block_types)
477 |
478 | if mid_block_only_cross_attention is None:
479 | mid_block_only_cross_attention = False
480 |
481 | if isinstance(num_attention_heads, int):
482 | num_attention_heads = (num_attention_heads,) * len(down_block_types)
483 |
484 | if isinstance(attention_head_dim, int):
485 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
486 |
487 | if isinstance(cross_attention_dim, int):
488 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
489 |
490 | if isinstance(layers_per_block, int):
491 | layers_per_block = [layers_per_block] * len(down_block_types)
492 |
493 | if isinstance(transformer_layers_per_block, int):
494 | transformer_layers_per_block = [transformer_layers_per_block] * len(
495 | down_block_types
496 | )
497 |
498 | if class_embeddings_concat:
499 | # The time embeddings are concatenated with the class embeddings. The dimension of the
500 | # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
501 | # regular time embeddings
502 | blocks_time_embed_dim = time_embed_dim * 2
503 | else:
504 | blocks_time_embed_dim = time_embed_dim
505 |
506 | # down
507 | output_channel = block_out_channels[0]
508 | for i, down_block_type in enumerate(down_block_types):
509 | input_channel = output_channel
510 | output_channel = block_out_channels[i]
511 | is_final_block = i == len(block_out_channels) - 1
512 |
513 | down_block = get_down_block(
514 | down_block_type,
515 | num_layers=layers_per_block[i],
516 | transformer_layers_per_block=transformer_layers_per_block[i],
517 | in_channels=input_channel,
518 | out_channels=output_channel,
519 | temb_channels=blocks_time_embed_dim,
520 | add_downsample=not is_final_block,
521 | resnet_eps=norm_eps,
522 | resnet_act_fn=act_fn,
523 | resnet_groups=norm_num_groups,
524 | cross_attention_dim=cross_attention_dim[i],
525 | num_attention_heads=num_attention_heads[i],
526 | downsample_padding=downsample_padding,
527 | dual_cross_attention=dual_cross_attention,
528 | use_linear_projection=use_linear_projection,
529 | only_cross_attention=only_cross_attention[i],
530 | upcast_attention=upcast_attention,
531 | resnet_time_scale_shift=resnet_time_scale_shift,
532 | attention_type=attention_type,
533 | resnet_skip_time_act=resnet_skip_time_act,
534 | resnet_out_scale_factor=resnet_out_scale_factor,
535 | cross_attention_norm=cross_attention_norm,
536 | attention_head_dim=attention_head_dim[i]
537 | if attention_head_dim[i] is not None
538 | else output_channel,
539 | dropout=dropout,
540 | )
541 | self.down_blocks.append(down_block)
542 |
543 | # mid
544 | if mid_block_type == "UNetMidBlock2DCrossAttn":
545 | self.mid_block = UNetMidBlock2DCrossAttn(
546 | transformer_layers_per_block=transformer_layers_per_block[-1],
547 | in_channels=block_out_channels[-1],
548 | temb_channels=blocks_time_embed_dim,
549 | dropout=dropout,
550 | resnet_eps=norm_eps,
551 | resnet_act_fn=act_fn,
552 | output_scale_factor=mid_block_scale_factor,
553 | resnet_time_scale_shift=resnet_time_scale_shift,
554 | cross_attention_dim=cross_attention_dim[-1],
555 | num_attention_heads=num_attention_heads[-1],
556 | resnet_groups=norm_num_groups,
557 | dual_cross_attention=dual_cross_attention,
558 | use_linear_projection=use_linear_projection,
559 | upcast_attention=upcast_attention,
560 | attention_type=attention_type,
561 | )
562 | elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
563 | self.mid_block = UNetMidBlock2DSimpleCrossAttn(
564 | in_channels=block_out_channels[-1],
565 | temb_channels=blocks_time_embed_dim,
566 | dropout=dropout,
567 | resnet_eps=norm_eps,
568 | resnet_act_fn=act_fn,
569 | output_scale_factor=mid_block_scale_factor,
570 | cross_attention_dim=cross_attention_dim[-1],
571 | attention_head_dim=attention_head_dim[-1],
572 | resnet_groups=norm_num_groups,
573 | resnet_time_scale_shift=resnet_time_scale_shift,
574 | skip_time_act=resnet_skip_time_act,
575 | only_cross_attention=mid_block_only_cross_attention,
576 | cross_attention_norm=cross_attention_norm,
577 | )
578 | elif mid_block_type == "UNetMidBlock2D":
579 | self.mid_block = UNetMidBlock2D(
580 | in_channels=block_out_channels[-1],
581 | temb_channels=blocks_time_embed_dim,
582 | dropout=dropout,
583 | num_layers=0,
584 | resnet_eps=norm_eps,
585 | resnet_act_fn=act_fn,
586 | output_scale_factor=mid_block_scale_factor,
587 | resnet_groups=norm_num_groups,
588 | resnet_time_scale_shift=resnet_time_scale_shift,
589 | add_attention=False,
590 | )
591 | elif mid_block_type is None:
592 | self.mid_block = None
593 | else:
594 | raise ValueError(f"unknown mid_block_type : {mid_block_type}")
595 |
596 | # count how many layers upsample the images
597 | self.num_upsamplers = 0
598 |
599 | # up
600 | reversed_block_out_channels = list(reversed(block_out_channels))
601 | reversed_num_attention_heads = list(reversed(num_attention_heads))
602 | reversed_layers_per_block = list(reversed(layers_per_block))
603 | reversed_cross_attention_dim = list(reversed(cross_attention_dim))
604 | reversed_transformer_layers_per_block = (
605 | list(reversed(transformer_layers_per_block))
606 | if reverse_transformer_layers_per_block is None
607 | else reverse_transformer_layers_per_block
608 | )
609 | only_cross_attention = list(reversed(only_cross_attention))
610 |
611 | output_channel = reversed_block_out_channels[0]
612 | for i, up_block_type in enumerate(up_block_types):
613 | is_final_block = i == len(block_out_channels) - 1
614 |
615 | prev_output_channel = output_channel
616 | output_channel = reversed_block_out_channels[i]
617 | input_channel = reversed_block_out_channels[
618 | min(i + 1, len(block_out_channels) - 1)
619 | ]
620 |
621 | # add upsample block for all BUT final layer
622 | if not is_final_block:
623 | add_upsample = True
624 | self.num_upsamplers += 1
625 | else:
626 | add_upsample = False
627 |
628 | up_block = get_up_block(
629 | up_block_type,
630 | num_layers=reversed_layers_per_block[i] + 1,
631 | transformer_layers_per_block=reversed_transformer_layers_per_block[i],
632 | in_channels=input_channel,
633 | out_channels=output_channel,
634 | prev_output_channel=prev_output_channel,
635 | temb_channels=blocks_time_embed_dim,
636 | add_upsample=add_upsample,
637 | resnet_eps=norm_eps,
638 | resnet_act_fn=act_fn,
639 | resolution_idx=i,
640 | resnet_groups=norm_num_groups,
641 | cross_attention_dim=reversed_cross_attention_dim[i],
642 | num_attention_heads=reversed_num_attention_heads[i],
643 | dual_cross_attention=dual_cross_attention,
644 | use_linear_projection=use_linear_projection,
645 | only_cross_attention=only_cross_attention[i],
646 | upcast_attention=upcast_attention,
647 | resnet_time_scale_shift=resnet_time_scale_shift,
648 | attention_type=attention_type,
649 | resnet_skip_time_act=resnet_skip_time_act,
650 | resnet_out_scale_factor=resnet_out_scale_factor,
651 | cross_attention_norm=cross_attention_norm,
652 | attention_head_dim=attention_head_dim[i]
653 | if attention_head_dim[i] is not None
654 | else output_channel,
655 | dropout=dropout,
656 | )
657 | self.up_blocks.append(up_block)
658 | prev_output_channel = output_channel
659 |
660 | # out
661 | if norm_num_groups is not None:
662 | self.conv_norm_out = nn.GroupNorm(
663 | num_channels=block_out_channels[0],
664 | num_groups=norm_num_groups,
665 | eps=norm_eps,
666 | )
667 |
668 | self.conv_act = get_activation(act_fn)
669 |
670 | else:
671 | self.conv_norm_out = None
672 | self.conv_act = None
673 |
674 | conv_out_padding = (conv_out_kernel - 1) // 2
675 | self.conv_out = nn.Conv2d(
676 | block_out_channels[0],
677 | out_channels,
678 | kernel_size=conv_out_kernel,
679 | padding=conv_out_padding,
680 | )
681 |
682 | if attention_type in ["gated", "gated-text-image"]:
683 | positive_len = 768
684 | if isinstance(cross_attention_dim, int):
685 | positive_len = cross_attention_dim
686 | elif isinstance(cross_attention_dim, tuple) or isinstance(
687 | cross_attention_dim, list
688 | ):
689 | positive_len = cross_attention_dim[0]
690 |
691 | feature_type = "text-only" if attention_type == "gated" else "text-image"
692 | self.position_net = PositionNet(
693 | positive_len=positive_len,
694 | out_dim=cross_attention_dim,
695 | feature_type=feature_type,
696 | )
697 |         self.has_sctuner = False  # SCEdit: whether SC-Tuners have been attached to the up blocks
698 |
699 | @property
700 | def attn_processors(self) -> Dict[str, AttentionProcessor]:
701 | r"""
702 | Returns:
703 | `dict` of attention processors: A dictionary containing all attention processors used in the model with
704 | indexed by its weight name.
705 | """
706 | # set recursively
707 | processors = {}
708 |
709 | def fn_recursive_add_processors(
710 | name: str,
711 | module: torch.nn.Module,
712 | processors: Dict[str, AttentionProcessor],
713 | ):
714 | if hasattr(module, "get_processor"):
715 | processors[f"{name}.processor"] = module.get_processor(
716 | return_deprecated_lora=True
717 | )
718 |
719 | for sub_name, child in module.named_children():
720 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
721 |
722 | return processors
723 |
724 | for name, module in self.named_children():
725 | fn_recursive_add_processors(name, module, processors)
726 |
727 | return processors
728 |
729 | def set_attn_processor(
730 | self,
731 | processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
732 | _remove_lora=False,
733 | ):
734 | r"""
735 | Sets the attention processor to use to compute attention.
736 |
737 | Parameters:
738 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
739 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
740 | for **all** `Attention` layers.
741 |
742 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention
743 | processor. This is strongly recommended when setting trainable attention processors.
744 |
745 | """
746 | count = len(self.attn_processors.keys())
747 |
748 | if isinstance(processor, dict) and len(processor) != count:
749 | raise ValueError(
750 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
751 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
752 | )
753 |
754 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
755 | if hasattr(module, "set_processor"):
756 | if not isinstance(processor, dict):
757 | module.set_processor(processor, _remove_lora=_remove_lora)
758 | else:
759 | module.set_processor(
760 | processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
761 | )
762 |
763 | for sub_name, child in module.named_children():
764 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
765 |
766 | for name, module in self.named_children():
767 | fn_recursive_attn_processor(name, module, processor)
768 |
769 | def set_default_attn_processor(self):
770 | """
771 | Disables custom attention processors and sets the default attention implementation.
772 | """
773 | if all(
774 | proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
775 | for proc in self.attn_processors.values()
776 | ):
777 | processor = AttnAddedKVProcessor()
778 | elif all(
779 | proc.__class__ in CROSS_ATTENTION_PROCESSORS
780 | for proc in self.attn_processors.values()
781 | ):
782 | processor = AttnProcessor()
783 | else:
784 | raise ValueError(
785 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
786 | )
787 |
788 | self.set_attn_processor(processor, _remove_lora=True)
789 |
790 | def set_attention_slice(self, slice_size):
791 | r"""
792 | Enable sliced attention computation.
793 |
794 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in
795 | several steps. This is useful for saving some memory in exchange for a small decrease in speed.
796 |
797 | Args:
798 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
799 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
800 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
801 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
802 | must be a multiple of `slice_size`.
803 | """
804 | sliceable_head_dims = []
805 |
806 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
807 | if hasattr(module, "set_attention_slice"):
808 | sliceable_head_dims.append(module.sliceable_head_dim)
809 |
810 | for child in module.children():
811 | fn_recursive_retrieve_sliceable_dims(child)
812 |
813 | # retrieve number of attention layers
814 | for module in self.children():
815 | fn_recursive_retrieve_sliceable_dims(module)
816 |
817 | num_sliceable_layers = len(sliceable_head_dims)
818 |
819 | if slice_size == "auto":
820 | # half the attention head size is usually a good trade-off between
821 | # speed and memory
822 | slice_size = [dim // 2 for dim in sliceable_head_dims]
823 | elif slice_size == "max":
824 | # make smallest slice possible
825 | slice_size = num_sliceable_layers * [1]
826 |
827 | slice_size = (
828 | num_sliceable_layers * [slice_size]
829 | if not isinstance(slice_size, list)
830 | else slice_size
831 | )
832 |
833 | if len(slice_size) != len(sliceable_head_dims):
834 | raise ValueError(
835 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
836 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
837 | )
838 |
839 | for i in range(len(slice_size)):
840 | size = slice_size[i]
841 | dim = sliceable_head_dims[i]
842 | if size is not None and size > dim:
843 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
844 |
845 | # Recursively walk through all the children.
846 |         # Any child that exposes the set_attention_slice method
847 |         # gets the message.
848 | def fn_recursive_set_attention_slice(
849 | module: torch.nn.Module, slice_size: List[int]
850 | ):
851 | if hasattr(module, "set_attention_slice"):
852 | module.set_attention_slice(slice_size.pop())
853 |
854 | for child in module.children():
855 | fn_recursive_set_attention_slice(child, slice_size)
856 |
857 | reversed_slice_size = list(reversed(slice_size))
858 | for module in self.children():
859 | fn_recursive_set_attention_slice(module, reversed_slice_size)
860 |
861 | def _set_gradient_checkpointing(self, module, value=False):
862 | if hasattr(module, "gradient_checkpointing"):
863 | module.gradient_checkpointing = value
864 |
865 | def enable_freeu(self, s1, s2, b1, b2):
866 | r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
867 |
868 | The suffixes after the scaling factors represent the stage blocks where they are being applied.
869 |
870 | Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
871 | are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
872 |
873 | Args:
874 | s1 (`float`):
875 | Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
876 | mitigate the "oversmoothing effect" in the enhanced denoising process.
877 | s2 (`float`):
878 | Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
879 | mitigate the "oversmoothing effect" in the enhanced denoising process.
880 | b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
881 | b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
882 | """
883 | for i, upsample_block in enumerate(self.up_blocks):
884 | setattr(upsample_block, "s1", s1)
885 | setattr(upsample_block, "s2", s2)
886 | setattr(upsample_block, "b1", b1)
887 | setattr(upsample_block, "b2", b2)
888 |
889 | def disable_freeu(self):
890 | """Disables the FreeU mechanism."""
891 | freeu_keys = {"s1", "s2", "b1", "b2"}
892 | for i, upsample_block in enumerate(self.up_blocks):
893 | for k in freeu_keys:
894 | if (
895 | hasattr(upsample_block, k)
896 | or getattr(upsample_block, k, None) is not None
897 | ):
898 | setattr(upsample_block, k, None)
899 |
900 | def set_sctuner(self, sctuner_module=SCTunerLinearLayer, **kwargs):
901 | for upsample_block in self.up_blocks:
902 | upsample_block.set_sctuner(sctuner_module=sctuner_module, **kwargs)
903 | self.has_sctuner = True
904 |
905 | def forward(
906 | self,
907 | sample: torch.FloatTensor,
908 | timestep: Union[torch.Tensor, float, int],
909 | encoder_hidden_states: torch.Tensor,
910 | class_labels: Optional[torch.Tensor] = None,
911 | timestep_cond: Optional[torch.Tensor] = None,
912 | attention_mask: Optional[torch.Tensor] = None,
913 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
914 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
915 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
916 | mid_block_additional_residual: Optional[torch.Tensor] = None,
917 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
918 | encoder_attention_mask: Optional[torch.Tensor] = None,
919 | return_dict: bool = True,
920 | ) -> Union[UNet2DConditionOutput, Tuple]:
921 | r"""
922 | The [`UNet2DConditionModel`] forward method.
923 |
924 | Args:
925 | sample (`torch.FloatTensor`):
926 | The noisy input tensor with the following shape `(batch, channel, height, width)`.
927 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
928 | encoder_hidden_states (`torch.FloatTensor`):
929 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
930 | class_labels (`torch.Tensor`, *optional*, defaults to `None`):
931 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
932 | timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
933 | Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
934 | through the `self.time_embedding` layer to obtain the timestep embeddings.
935 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
936 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
937 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
938 | negative values to the attention scores corresponding to "discard" tokens.
939 | cross_attention_kwargs (`dict`, *optional*):
940 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
941 | `self.processor` in
942 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
943 | added_cond_kwargs: (`dict`, *optional*):
944 | A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
945 | are passed along to the UNet blocks.
946 | down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
947 | A tuple of tensors that if specified are added to the residuals of down unet blocks.
948 | mid_block_additional_residual: (`torch.Tensor`, *optional*):
949 | A tensor that if specified is added to the residual of the middle unet block.
950 | encoder_attention_mask (`torch.Tensor`):
951 | A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
952 | `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
953 | which adds large negative values to the attention scores corresponding to "discard" tokens.
954 | return_dict (`bool`, *optional*, defaults to `True`):
955 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
956 | tuple.
957 | cross_attention_kwargs (`dict`, *optional*):
958 | A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
959 | added_cond_kwargs: (`dict`, *optional*):
960 |                 A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
961 | are passed along to the UNet blocks.
962 | down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
963 | additional residuals to be added to UNet long skip connections from down blocks to up blocks for
964 | example from ControlNet side model(s)
965 | mid_block_additional_residual (`torch.Tensor`, *optional*):
966 | additional residual to be added to UNet mid block output, for example from ControlNet side model
967 | down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
968 | additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
969 |
970 | Returns:
971 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
972 | If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
973 | a `tuple` is returned where the first element is the sample tensor.
974 | """
975 |         # By default samples have to be at least a multiple of the overall upsampling factor.
976 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
977 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
978 | # on the fly if necessary.
979 | default_overall_up_factor = 2**self.num_upsamplers
980 |
981 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
982 | forward_upsample_size = False
983 | upsample_size = None
984 |
985 | for dim in sample.shape[-2:]:
986 | if dim % default_overall_up_factor != 0:
987 | # Forward upsample size to force interpolation output size.
988 | forward_upsample_size = True
989 | break
990 |
991 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
992 | # expects mask of shape:
993 | # [batch, key_tokens]
994 | # adds singleton query_tokens dimension:
995 | # [batch, 1, key_tokens]
996 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
997 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
998 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
999 | if attention_mask is not None:
1000 | # assume that mask is expressed as:
1001 | # (1 = keep, 0 = discard)
1002 | # convert mask into a bias that can be added to attention scores:
1003 | # (keep = +0, discard = -10000.0)
1004 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1005 | attention_mask = attention_mask.unsqueeze(1)
1006 |
1007 | # convert encoder_attention_mask to a bias the same way we do for attention_mask
1008 | if encoder_attention_mask is not None:
1009 | encoder_attention_mask = (
1010 | 1 - encoder_attention_mask.to(sample.dtype)
1011 | ) * -10000.0
1012 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
1013 |
1014 | # 0. center input if necessary
1015 | if self.config.center_input_sample:
1016 | sample = 2 * sample - 1.0
1017 |
1018 | # 1. time
1019 | timesteps = timestep
1020 | if not torch.is_tensor(timesteps):
1021 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
1022 | # This would be a good case for the `match` statement (Python 3.10+)
1023 | is_mps = sample.device.type == "mps"
1024 | if isinstance(timestep, float):
1025 | dtype = torch.float32 if is_mps else torch.float64
1026 | else:
1027 | dtype = torch.int32 if is_mps else torch.int64
1028 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1029 | elif len(timesteps.shape) == 0:
1030 | timesteps = timesteps[None].to(sample.device)
1031 |
1032 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1033 | timesteps = timesteps.expand(sample.shape[0])
1034 |
1035 | t_emb = self.time_proj(timesteps)
1036 |
1037 | # `Timesteps` does not contain any weights and will always return f32 tensors
1038 |         # but time_embedding might actually be running in fp16, so we need to cast here.
1039 | # there might be better ways to encapsulate this.
1040 | t_emb = t_emb.to(dtype=sample.dtype)
1041 |
1042 | emb = self.time_embedding(t_emb, timestep_cond)
1043 | aug_emb = None
1044 |
1045 | if self.class_embedding is not None:
1046 | if class_labels is None:
1047 | raise ValueError(
1048 | "class_labels should be provided when num_class_embeds > 0"
1049 | )
1050 |
1051 | if self.config.class_embed_type == "timestep":
1052 | class_labels = self.time_proj(class_labels)
1053 |
1054 | # `Timesteps` does not contain any weights and will always return f32 tensors
1055 | # there might be better ways to encapsulate this.
1056 | class_labels = class_labels.to(dtype=sample.dtype)
1057 |
1058 | class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1059 |
1060 | if self.config.class_embeddings_concat:
1061 | emb = torch.cat([emb, class_emb], dim=-1)
1062 | else:
1063 | emb = emb + class_emb
1064 |
1065 | if self.config.addition_embed_type == "text":
1066 | aug_emb = self.add_embedding(encoder_hidden_states)
1067 | elif self.config.addition_embed_type == "text_image":
1068 | # Kandinsky 2.1 - style
1069 | if "image_embeds" not in added_cond_kwargs:
1070 | raise ValueError(
1071 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1072 | )
1073 |
1074 | image_embs = added_cond_kwargs.get("image_embeds")
1075 | text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1076 | aug_emb = self.add_embedding(text_embs, image_embs)
1077 | elif self.config.addition_embed_type == "text_time":
1078 | # SDXL - style
1079 | if "text_embeds" not in added_cond_kwargs:
1080 | raise ValueError(
1081 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1082 | )
1083 | text_embeds = added_cond_kwargs.get("text_embeds")
1084 | if "time_ids" not in added_cond_kwargs:
1085 | raise ValueError(
1086 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1087 | )
1088 | time_ids = added_cond_kwargs.get("time_ids")
1089 | time_embeds = self.add_time_proj(time_ids.flatten())
1090 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1091 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1092 | add_embeds = add_embeds.to(emb.dtype)
1093 | aug_emb = self.add_embedding(add_embeds)
1094 | elif self.config.addition_embed_type == "image":
1095 | # Kandinsky 2.2 - style
1096 | if "image_embeds" not in added_cond_kwargs:
1097 | raise ValueError(
1098 | f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1099 | )
1100 | image_embs = added_cond_kwargs.get("image_embeds")
1101 | aug_emb = self.add_embedding(image_embs)
1102 | elif self.config.addition_embed_type == "image_hint":
1103 | # Kandinsky 2.2 - style
1104 | if (
1105 | "image_embeds" not in added_cond_kwargs
1106 | or "hint" not in added_cond_kwargs
1107 | ):
1108 | raise ValueError(
1109 | f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1110 | )
1111 | image_embs = added_cond_kwargs.get("image_embeds")
1112 | hint = added_cond_kwargs.get("hint")
1113 | aug_emb, hint = self.add_embedding(image_embs, hint)
1114 | sample = torch.cat([sample, hint], dim=1)
1115 |
1116 | emb = emb + aug_emb if aug_emb is not None else emb
1117 |
1118 | if self.time_embed_act is not None:
1119 | emb = self.time_embed_act(emb)
1120 |
1121 | if (
1122 | self.encoder_hid_proj is not None
1123 | and self.config.encoder_hid_dim_type == "text_proj"
1124 | ):
1125 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1126 | elif (
1127 | self.encoder_hid_proj is not None
1128 | and self.config.encoder_hid_dim_type == "text_image_proj"
1129 | ):
1130 |             # Kandinsky 2.1 - style
1131 | if "image_embeds" not in added_cond_kwargs:
1132 | raise ValueError(
1133 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1134 | )
1135 |
1136 | image_embeds = added_cond_kwargs.get("image_embeds")
1137 | encoder_hidden_states = self.encoder_hid_proj(
1138 | encoder_hidden_states, image_embeds
1139 | )
1140 | elif (
1141 | self.encoder_hid_proj is not None
1142 | and self.config.encoder_hid_dim_type == "image_proj"
1143 | ):
1144 | # Kandinsky 2.2 - style
1145 | if "image_embeds" not in added_cond_kwargs:
1146 | raise ValueError(
1147 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1148 | )
1149 | image_embeds = added_cond_kwargs.get("image_embeds")
1150 | encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1151 | elif (
1152 | self.encoder_hid_proj is not None
1153 | and self.config.encoder_hid_dim_type == "ip_image_proj"
1154 | ):
1155 | if "image_embeds" not in added_cond_kwargs:
1156 | raise ValueError(
1157 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1158 | )
1159 | image_embeds = added_cond_kwargs.get("image_embeds")
1160 | image_embeds = self.encoder_hid_proj(image_embeds).to(
1161 | encoder_hidden_states.dtype
1162 | )
1163 | encoder_hidden_states = torch.cat(
1164 | [encoder_hidden_states, image_embeds], dim=1
1165 | )
1166 |
1167 | # 2. pre-process
1168 | sample = self.conv_in(sample)
1169 |
1170 | # 2.5 GLIGEN position net
1171 | if (
1172 | cross_attention_kwargs is not None
1173 | and cross_attention_kwargs.get("gligen", None) is not None
1174 | ):
1175 | cross_attention_kwargs = cross_attention_kwargs.copy()
1176 | gligen_args = cross_attention_kwargs.pop("gligen")
1177 | cross_attention_kwargs["gligen"] = {
1178 | "objs": self.position_net(**gligen_args)
1179 | }
1180 |
1181 | # 3. down
1182 | lora_scale = (
1183 | cross_attention_kwargs.get("scale", 1.0)
1184 | if cross_attention_kwargs is not None
1185 | else 1.0
1186 | )
1187 | if USE_PEFT_BACKEND:
1188 | # weight the lora layers by setting `lora_scale` for each PEFT layer
1189 | scale_lora_layers(self, lora_scale)
1190 |
1191 | is_controlnet = (
1192 | mid_block_additional_residual is not None
1193 | and down_block_additional_residuals is not None
1194 | )
1195 | # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1196 | is_adapter = down_intrablock_additional_residuals is not None
1197 | # maintain backward compatibility for legacy usage, where
1198 | # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1199 | # but can only use one or the other
1200 | if (
1201 | not is_adapter
1202 | and mid_block_additional_residual is None
1203 | and down_block_additional_residuals is not None
1204 | ):
1205 | deprecate(
1206 | "T2I should not use down_block_additional_residuals",
1207 | "1.3.0",
1208 | "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1209 | and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1210 |                 for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1211 | standard_warn=False,
1212 | )
1213 | down_intrablock_additional_residuals = down_block_additional_residuals
1214 | is_adapter = True
1215 |
1216 | down_block_res_samples = (sample,)
1217 | for downsample_block in self.down_blocks:
1218 | if (
1219 | hasattr(downsample_block, "has_cross_attention")
1220 | and downsample_block.has_cross_attention
1221 | ):
1222 | # For t2i-adapter CrossAttnDownBlock2D
1223 | additional_residuals = {}
1224 | if is_adapter and len(down_intrablock_additional_residuals) > 0:
1225 | additional_residuals[
1226 | "additional_residuals"
1227 | ] = down_intrablock_additional_residuals.pop(0)
1228 |
1229 | sample, res_samples = downsample_block(
1230 | hidden_states=sample,
1231 | temb=emb,
1232 | encoder_hidden_states=encoder_hidden_states,
1233 | attention_mask=attention_mask,
1234 | cross_attention_kwargs=cross_attention_kwargs,
1235 | encoder_attention_mask=encoder_attention_mask,
1236 | **additional_residuals,
1237 | )
1238 | else:
1239 | sample, res_samples = downsample_block(
1240 | hidden_states=sample, temb=emb, scale=lora_scale
1241 | )
1242 | if is_adapter and len(down_intrablock_additional_residuals) > 0:
1243 | sample += down_intrablock_additional_residuals.pop(0)
1244 |
1245 | down_block_res_samples += res_samples
1246 |
1247 | if is_controlnet:
1248 | new_down_block_res_samples = ()
1249 |
1250 | for down_block_res_sample, down_block_additional_residual in zip(
1251 | down_block_res_samples, down_block_additional_residuals
1252 | ):
1253 | down_block_res_sample = (
1254 | down_block_res_sample + down_block_additional_residual
1255 | )
1256 | new_down_block_res_samples = new_down_block_res_samples + (
1257 | down_block_res_sample,
1258 | )
1259 |
1260 | down_block_res_samples = new_down_block_res_samples
1261 |
1262 | # 4. mid
1263 | if self.mid_block is not None:
1264 | if (
1265 | hasattr(self.mid_block, "has_cross_attention")
1266 | and self.mid_block.has_cross_attention
1267 | ):
1268 | sample = self.mid_block(
1269 | sample,
1270 | emb,
1271 | encoder_hidden_states=encoder_hidden_states,
1272 | attention_mask=attention_mask,
1273 | cross_attention_kwargs=cross_attention_kwargs,
1274 | encoder_attention_mask=encoder_attention_mask,
1275 | )
1276 | else:
1277 | sample = self.mid_block(sample, emb)
1278 |
1279 | # To support T2I-Adapter-XL
1280 | if (
1281 | is_adapter
1282 | and len(down_intrablock_additional_residuals) > 0
1283 | and sample.shape == down_intrablock_additional_residuals[0].shape
1284 | ):
1285 | sample += down_intrablock_additional_residuals.pop(0)
1286 |
1287 | if is_controlnet:
1288 | sample = sample + mid_block_additional_residual
1289 |
1290 | # 5. up
1291 | for i, upsample_block in enumerate(self.up_blocks):
1292 | is_final_block = i == len(self.up_blocks) - 1
1293 |
1294 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1295 | down_block_res_samples = down_block_res_samples[
1296 | : -len(upsample_block.resnets)
1297 | ]
1298 |
1299 | # if we have not reached the final block and need to forward the
1300 | # upsample size, we do it here
1301 | if not is_final_block and forward_upsample_size:
1302 | upsample_size = down_block_res_samples[-1].shape[2:]
1303 |
1304 | if (
1305 | hasattr(upsample_block, "has_cross_attention")
1306 | and upsample_block.has_cross_attention
1307 | ):
1308 | sample = upsample_block(
1309 | hidden_states=sample,
1310 | temb=emb,
1311 | res_hidden_states_tuple=res_samples,
1312 | encoder_hidden_states=encoder_hidden_states,
1313 | cross_attention_kwargs=cross_attention_kwargs,
1314 | upsample_size=upsample_size,
1315 | attention_mask=attention_mask,
1316 | encoder_attention_mask=encoder_attention_mask,
1317 | )
1318 | else:
1319 | sample = upsample_block(
1320 | hidden_states=sample,
1321 | temb=emb,
1322 | res_hidden_states_tuple=res_samples,
1323 | upsample_size=upsample_size,
1324 | scale=lora_scale,
1325 | )
1326 |
1327 | # 6. post-process
1328 | if self.conv_norm_out:
1329 | sample = self.conv_norm_out(sample)
1330 | sample = self.conv_act(sample)
1331 | sample = self.conv_out(sample)
1332 |
1333 | if USE_PEFT_BACKEND:
1334 | # remove `lora_scale` from each PEFT layer
1335 | unscale_lora_layers(self, lora_scale)
1336 |
1337 | if not return_dict:
1338 | return (sample,)
1339 |
1340 | return UNet2DConditionOutput(sample=sample)
1341 |
--------------------------------------------------------------------------------
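The patched `UNet2DConditionModel` above tracks an extra `has_sctuner` flag and exposes `set_sctuner`, which asks every up block to attach SC-Tuner layers to its skip connections. A minimal training-setup sketch (the `"sc_tuners"` parameter-name filter is an assumption based on the module names defined in `scedit.py` below):

import torch
from scedit_pytorch import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.requires_grad_(False)   # freeze the SDXL backbone
unet.set_sctuner(scale=1.0)  # attach SC-Tuner layers to the up-block skip connections

# The SC-Tuner modules are created after the freeze, so their parameters are the
# only trainable ones; hand just those to the optimizer.
tuner_params = [p for n, p in unet.named_parameters() if "sc_tuners" in n]
optimizer = torch.optim.AdamW(tuner_params, lr=1e-4)
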
/scedit_pytorch/scedit.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 | import torch
3 | from torch import nn
4 | from diffusers.models.modeling_utils import get_parameter_device, get_parameter_dtype
5 |
6 |
7 | class AbstractSCTunerLayer(nn.Module):
8 | def __init__(
9 | self,
10 | dim: int,
11 | device: Optional[Union[torch.device, str]] = None,
12 | dtype: Optional[torch.dtype] = None,
13 | ):
14 | super().__init__()
15 | self.dim = dim
16 |
17 |
18 | class SCTunerLinearLayer(AbstractSCTunerLayer):
19 | r"""
20 | A linear layer that is used with SCEdit.
21 |
22 | Parameters:
23 |         dim (`int`):
24 |             Number of channels of the skip-connection feature map the tuner operates on.
25 |         rank (`int`, `optional`, defaults to `None`):
26 |             The rank of the down/up projection. If `None`, the rank defaults to `dim`.
27 |         scale (`float`, `optional`, defaults to 1.0):
28 |             Scaling factor applied to the tuner branch before the residual addition.
29 | device (`torch.device`, `optional`, defaults to `None`):
30 | The device to use for the layer's weights.
31 | dtype (`torch.dtype`, `optional`, defaults to `None`):
32 | The dtype to use for the layer's weights.
33 | """
34 |
35 | def __init__(
36 | self,
37 | dim: int,
38 | rank: Optional[int] = None,
39 | scale: float = 1.0,
40 | device: Optional[Union[torch.device, str]] = None,
41 | dtype: Optional[torch.dtype] = None,
42 | ):
43 | super().__init__(dim=dim)
44 | if rank is None:
45 | rank = dim
46 | self.down = nn.Linear(dim, rank, device=device, dtype=dtype)
47 | self.up = nn.Linear(rank, dim, device=device, dtype=dtype)
48 | self.act = nn.GELU()
49 | self.rank = rank
50 | self.scale = scale
51 |
52 | nn.init.normal_(self.down.weight, std=1 / rank)
53 | nn.init.zeros_(self.up.weight)
54 |
55 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
56 | orig_dtype = hidden_states.dtype
57 | dtype = self.down.weight.dtype
58 |
59 | hidden_states_input = hidden_states.permute(0, 2, 3, 1)
60 | down_hidden_states = self.down(hidden_states_input.to(dtype))
61 | up_hidden_states = self.up(self.act(down_hidden_states))
62 | up_hidden_states = up_hidden_states.to(orig_dtype).permute(0, 3, 1, 2)
63 | return self.scale * up_hidden_states + hidden_states
64 |
65 |
66 | class SCTunerLinearLayer2(AbstractSCTunerLayer):
67 | def __init__(
68 | self,
69 | dim: int,
70 | rank: Optional[int] = None,
71 | scale: float = 1.0,
72 | device: Optional[Union[torch.device, str]] = None,
73 | dtype: Optional[torch.dtype] = None,
74 | ):
75 | super().__init__(dim=dim)
76 | if rank is None:
77 | rank = dim
78 | self.model = nn.Sequential(
79 | nn.Linear(dim, rank, device=device, dtype=dtype),
80 | nn.GELU(),
81 | nn.Linear(rank, rank, device=device, dtype=dtype),
82 | nn.GELU(),
83 | nn.Linear(rank, dim, device=device, dtype=dtype),
84 | )
85 | self.rank = rank
86 | self.scale = scale
87 |
88 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
89 | orig_dtype = hidden_states.dtype
90 | dtype = self.model[0].weight.dtype
91 |
92 | hidden_states_input = hidden_states.permute(0, 2, 3, 1)
93 | hidden_states_output = self.model(hidden_states_input.to(dtype))
94 | hidden_states_output = hidden_states_output.to(orig_dtype).permute(0, 3, 1, 2)
95 | return self.scale * hidden_states_output + hidden_states
96 |
97 |
98 | class SCTuner(nn.Module):
99 | def __init__(self):
100 | super().__init__()
101 | self.sc_tuners = None
102 | self.res_skip_channels = None
103 |
104 | @property
105 | def device(self) -> torch.device:
106 | """
107 | `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
108 | device).
109 | """
110 | return get_parameter_device(self)
111 |
112 | @property
113 | def dtype(self) -> torch.dtype:
114 | """
115 | `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
116 | """
117 | return get_parameter_dtype(self)
118 |
119 | def set_sctuner(self, sctuner_module=SCTunerLinearLayer, **kwargs):
120 | assert isinstance(self.res_skip_channels, list)
121 | sc_tuners = []
122 | for c in self.res_skip_channels:
123 | sc_tuners.append(
124 | sctuner_module(dim=c, device=self.device, dtype=self.dtype, **kwargs)
125 | )
126 | self.sc_tuners = nn.ModuleList(sc_tuners)
127 |
--------------------------------------------------------------------------------
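`SCTunerLinearLayer` works channel-wise on a `(batch, channels, height, width)` skip feature: the tensor is permuted to channels-last, projected down to `rank`, passed through GELU, projected back up, and added to the input as a residual scaled by `scale`. The down projection is initialized with std `1/rank` and the up projection's weight starts at zero, so the learned branch contributes only its bias at initialization. A standalone shape check, assuming the package is importable as `scedit_pytorch`:

import torch
from scedit_pytorch.scedit import SCTunerLinearLayer

tuner = SCTunerLinearLayer(dim=320, rank=64, scale=1.0)
skip_feat = torch.randn(2, 320, 64, 64)  # (batch, channels, height, width) skip-connection feature
out = tuner(skip_feat)                   # permute -> down -> GELU -> up -> permute back -> scaled residual add
assert out.shape == skip_feat.shape
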
/scedit_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | from typing import Dict, Union
4 |
5 | import torch
6 | import safetensors
7 | from diffusers.utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file
8 | from . import UNet2DConditionModel
9 |
10 | logger = logging.getLogger(__name__)
11 |
12 | SCEDIT_WEIGHT_NAME_SAFE = "pytorch_scedit_weights.safetensors"
13 |
14 |
15 | def save_function(weights, filename):
16 | return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
17 |
18 |
19 | def save_scedit(state_dict: Dict[str, torch.Tensor], save_directory: str):
20 | if os.path.isfile(save_directory):
21 | logger.error(
22 | f"Provided path ({save_directory}) should be a directory, not a file"
23 | )
24 | return
25 | os.makedirs(save_directory, exist_ok=True)
26 | save_function(state_dict, os.path.join(save_directory, SCEDIT_WEIGHT_NAME_SAFE))
27 | logger.info(
28 | f"Model weights saved in {os.path.join(save_directory, SCEDIT_WEIGHT_NAME_SAFE)}"
29 | )
30 |
31 |
32 | def scedit_state_dict(
33 | pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
34 | ) -> Dict[str, torch.Tensor]:
35 | cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
36 | force_download = kwargs.pop("force_download", False)
37 | resume_download = kwargs.pop("resume_download", False)
38 | proxies = kwargs.pop("proxies", None)
39 | local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
40 | use_auth_token = kwargs.pop("use_auth_token", None)
41 | revision = kwargs.pop("revision", None)
42 | subfolder = kwargs.pop("subfolder", None)
43 | weight_name = kwargs.pop("weight_name", None)
44 | user_agent = {
45 | "file_type": "attn_procs_weights",
46 | "framework": "pytorch",
47 | }
48 |
49 | model_file = None
50 | if not isinstance(pretrained_model_name_or_path_or_dict, dict):
51 | # Here we're relaxing the loading check to enable more Inference API
52 | # friendliness where sometimes, it's not at all possible to automatically
53 | # determine `weight_name`.
54 | if weight_name is None:
55 | weight_name = SCEDIT_WEIGHT_NAME_SAFE
56 | model_file = _get_model_file(
57 | pretrained_model_name_or_path_or_dict,
58 | weights_name=weight_name,
59 | cache_dir=cache_dir,
60 | force_download=force_download,
61 | resume_download=resume_download,
62 | proxies=proxies,
63 | local_files_only=local_files_only,
64 | use_auth_token=use_auth_token,
65 | revision=revision,
66 | subfolder=subfolder,
67 | user_agent=user_agent,
68 | )
69 | state_dict = safetensors.torch.load_file(model_file, device="cpu")
70 | else:
71 | state_dict = pretrained_model_name_or_path_or_dict
72 |
73 | return state_dict
74 |
75 |
76 | def load_scedit_into_unet(
77 | state_dict: Union[str, Dict[str, torch.Tensor]],
78 | unet: UNet2DConditionModel,
79 | **kwargs,
80 | ) -> UNet2DConditionModel:
81 | if isinstance(state_dict, str):
82 | state_dict = scedit_state_dict(state_dict, **kwargs)
83 | assert unet.has_sctuner, "Make sure to call `set_sctuner` before!"
84 | unet.load_state_dict(state_dict, strict=False)
85 | return unet
86 |
--------------------------------------------------------------------------------
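`save_scedit` writes a flat state dict to `pytorch_scedit_weights.safetensors` inside the given directory, and `load_scedit_into_unet` loads such a file (or an in-memory dict) non-strictly into a UNet that already has SC-Tuner layers attached. A hedged round-trip sketch; the `"sc_tuners"` key filter used to keep the checkpoint small is an assumption based on the module names in `scedit.py`:

from scedit_pytorch import UNet2DConditionModel, save_scedit, load_scedit_into_unet

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet.set_sctuner()

# Save only the SC-Tuner weights after training.
state_dict = {k: v for k, v in unet.state_dict().items() if "sc_tuners" in k}
save_scedit(state_dict, "scedit-dreambooth-model")

# Reload them later; `load_scedit_into_unet` requires `set_sctuner` to have been called first.
unet = load_scedit_into_unet("scedit-dreambooth-model", unet)
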
/scripts/app.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | sys.path.append(".")
4 | import argparse
5 | import gradio as gr
6 | import torch
7 | from huggingface_hub.repocard import RepoCard
8 | from diffusers import DiffusionPipeline
9 | from scedit_pytorch import UNet2DConditionModel, load_scedit_into_unet
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument(
15 | "--pretrained_model_name_or_path",
16 | type=str,
17 | default="stabilityai/stable-diffusion-xl-base-1.0",
18 | help="pretrained model path",
19 | )
20 | parser.add_argument(
21 |         "--scedit_name_or_path", type=str, required=True, help="path to trained SCEdit (SC-Tuner) weights"
22 | )
23 | parser.add_argument("--scale", type=float, default=1.0, help="weight scale")
24 | return parser.parse_args()
25 |
26 |
27 | args = parse_args()
28 | device = "cuda" if torch.cuda.is_available() else "cpu"
29 | # load unet with sctuner
30 | unet = UNet2DConditionModel.from_pretrained(
31 | args.pretrained_model_name_or_path, subfolder="unet"
32 | )
33 | unet.set_sctuner(scale=args.scale)
34 | unet = load_scedit_into_unet(args.scedit_name_or_path, unet)
35 | # load pipeline
36 | pipeline = DiffusionPipeline.from_pretrained(
37 | args.pretrained_model_name_or_path, unet=unet
38 | )
39 | pipeline = pipeline.to(device=device, dtype=torch.float16)
40 |
41 |
42 | def run(prompt: str):
43 | # generator = torch.Generator(device="cuda").manual_seed(42)
44 | generator = None
45 | image = pipeline(prompt=prompt, generator=generator).images[0]
46 | return image
47 |
48 |
49 | with gr.Blocks() as demo:
50 | with gr.Row():
51 | with gr.Column():
52 | prompt = gr.Text(label="prompt", value="A picture of a sbu dog in a bucket")
53 | bttn = gr.Button(value="Run")
54 | with gr.Column():
55 | out = gr.Image(label="out")
56 | prompt.submit(fn=run, inputs=[prompt], outputs=[out])
57 | bttn.click(fn=run, inputs=[prompt], outputs=[out])
58 |
59 | demo.launch(share=True, debug=True, show_error=True)
60 |
--------------------------------------------------------------------------------
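`scripts/app.py` ties the pieces together into a small Gradio demo: it loads the SDXL UNet, attaches SC-Tuner layers with the requested scale, loads the trained SCEdit weights, and serves a single prompt box. Assuming it is launched from the repository root (the script appends "." to `sys.path`), an invocation might look like `python scripts/app.py --scedit_name_or_path ./scedit-dreambooth-model --scale 1.0`; the weights directory name here is only an example.
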
/train_dreambooth_scedit_sdxl.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import gc
3 | import itertools
4 | import logging
5 | import math
6 | import os
7 | import shutil
8 | import warnings
9 | from pathlib import Path
10 |
11 | import numpy as np
12 | import torch
13 | import torch.nn.functional as F
14 | import torch.utils.checkpoint
15 | import transformers
16 | from accelerate import Accelerator
17 | from accelerate.logging import get_logger
18 | from accelerate.utils import (
19 | DistributedDataParallelKwargs,
20 | ProjectConfiguration,
21 | set_seed,
22 | )
23 | from huggingface_hub import create_repo, upload_folder
24 | from huggingface_hub.utils import insecure_hashlib
25 | from packaging import version
26 | from PIL import Image
27 | from PIL.ImageOps import exif_transpose
28 | from torch.utils.data import Dataset
29 | from torchvision import transforms
30 | from tqdm.auto import tqdm
31 | from transformers import AutoTokenizer, PretrainedConfig
32 |
33 | import diffusers
34 | from diffusers import (
35 | AutoencoderKL,
36 | DDPMScheduler,
37 | DPMSolverMultistepScheduler,
38 | StableDiffusionXLPipeline,
39 | # UNet2DConditionModel,
40 | )
41 | from diffusers.optimization import get_scheduler
42 | from diffusers.training_utils import compute_snr
43 | from diffusers.utils import check_min_version, is_wandb_available
44 | from diffusers.utils.import_utils import is_xformers_available
45 | from scedit_pytorch import UNet2DConditionModel, save_scedit, load_scedit_into_unet
46 |
47 |
48 | logger = get_logger(__name__)
49 |
50 |
51 | def save_model_card(
52 | repo_id: str,
53 | images=None,
54 |     base_model: str = "",
55 |     instance_prompt: str = "",
56 |     validation_prompt: str = "",
57 | repo_folder=None,
58 | vae_path=None,
59 | ):
60 | img_str = "widget:\n" if images else ""
61 |     for i, image in enumerate(images or []):
62 | image.save(os.path.join(repo_folder, f"image_{i}.png"))
63 | img_str += f"""
64 | - text: '{validation_prompt if validation_prompt else ' ' }'
65 | output:
66 | url:
67 | "image_{i}.png"
68 | """
69 |
70 | yaml = f"""
71 | ---
72 | tags:
73 | - stable-diffusion-xl
74 | - stable-diffusion-xl-diffusers
75 | - text-to-image
76 | - diffusers
77 | - scedit
78 | - template:sd-lora
79 | {img_str}
80 | base_model: {base_model}
81 | instance_prompt: {instance_prompt}
82 | license: openrail++
83 | ---
84 | """
85 |
86 | model_card = f"""
87 | # SDXL SCEdit DreamBooth - {repo_id}
88 |
89 |
90 |
91 | ## Model description
92 |
93 | These are {repo_id} SC-Tuner adaptation weights for {base_model}.
94 |
95 | The weights were trained using [DreamBooth](https://dreambooth.github.io/).
96 |
97 | Special VAE used for training: {vae_path}.
98 |
99 | ## Trigger words
100 |
101 | You should use {instance_prompt} to trigger the image generation.
102 |
103 | ## Download model
104 |
105 | Weights for this model are available in Safetensors format.
106 |
107 | [Download]({repo_id}/tree/main) them in the Files & versions tab.
108 |
109 | """
110 | with open(os.path.join(repo_folder, "README.md"), "w") as f:
111 | f.write(yaml + model_card)
112 |
113 |
114 | def import_model_class_from_model_name_or_path(
115 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
116 | ):
117 | text_encoder_config = PretrainedConfig.from_pretrained(
118 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision
119 | )
120 | model_class = text_encoder_config.architectures[0]
121 |
122 | if model_class == "CLIPTextModel":
123 | from transformers import CLIPTextModel
124 |
125 | return CLIPTextModel
126 | elif model_class == "CLIPTextModelWithProjection":
127 | from transformers import CLIPTextModelWithProjection
128 |
129 | return CLIPTextModelWithProjection
130 | else:
131 | raise ValueError(f"{model_class} is not supported.")
132 |
133 |
134 | def parse_args(input_args=None):
135 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
136 | parser.add_argument(
137 | "--pretrained_model_name_or_path",
138 | type=str,
139 | default=None,
140 | required=True,
141 | help="Path to pretrained model or model identifier from huggingface.co/models.",
142 | )
143 | parser.add_argument(
144 | "--pretrained_vae_model_name_or_path",
145 | type=str,
146 | default=None,
147 | help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
148 | )
149 | parser.add_argument(
150 | "--revision",
151 | type=str,
152 | default=None,
153 | required=False,
154 | help="Revision of pretrained model identifier from huggingface.co/models.",
155 | )
156 | parser.add_argument(
157 | "--variant",
158 | type=str,
159 | default=None,
160 |         help="Variant of the model files of the pretrained model identifier from huggingface.co/models, e.g. fp16",
161 | )
162 | parser.add_argument(
163 | "--dataset_name",
164 | type=str,
165 | default=None,
166 | help=(
167 | "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
168 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
169 | " or to a folder containing files that 🤗 Datasets can understand."
170 | ),
171 | )
172 | parser.add_argument(
173 | "--dataset_config_name",
174 | type=str,
175 | default=None,
176 | help="The config of the Dataset, leave as None if there's only one config.",
177 | )
178 | parser.add_argument(
179 | "--instance_data_dir",
180 | type=str,
181 | default=None,
182 | help=("A folder containing the training data. "),
183 | )
184 |
185 | parser.add_argument(
186 | "--cache_dir",
187 | type=str,
188 | default=None,
189 | help="The directory where the downloaded models and datasets will be stored.",
190 | )
191 |
192 | parser.add_argument(
193 | "--image_column",
194 | type=str,
195 | default="image",
196 | help="The column of the dataset containing the target image. By "
197 |         "default, the standard Image Dataset maps 'file_name' "
198 | "to 'image'.",
199 | )
200 | parser.add_argument(
201 | "--caption_column",
202 | type=str,
203 | default=None,
204 | help="The column of the dataset containing the instance prompt for each image",
205 | )
206 |
207 | parser.add_argument(
208 | "--repeats",
209 | type=int,
210 | default=1,
211 | help="How many times to repeat the training data.",
212 | )
213 |
214 | parser.add_argument(
215 | "--class_data_dir",
216 | type=str,
217 | default=None,
218 | required=False,
219 | help="A folder containing the training data of class images.",
220 | )
221 | parser.add_argument(
222 | "--instance_prompt",
223 | type=str,
224 | default=None,
225 | required=True,
226 | help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
227 | )
228 | parser.add_argument(
229 | "--class_prompt",
230 | type=str,
231 | default=None,
232 | help="The prompt to specify images in the same class as provided instance images.",
233 | )
234 | parser.add_argument(
235 | "--validation_prompt",
236 | type=str,
237 | default=None,
238 | help="A prompt that is used during validation to verify that the model is learning.",
239 | )
240 | parser.add_argument(
241 | "--num_validation_images",
242 | type=int,
243 | default=4,
244 | help="Number of images that should be generated during validation with `validation_prompt`.",
245 | )
246 | parser.add_argument(
247 | "--validation_epochs",
248 | type=int,
249 | default=50,
250 | help=(
251 | "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
252 | " `args.validation_prompt` multiple times: `args.num_validation_images`."
253 | ),
254 | )
255 | parser.add_argument(
256 | "--with_prior_preservation",
257 | default=False,
258 | action="store_true",
259 | help="Flag to add prior preservation loss.",
260 | )
261 | parser.add_argument(
262 | "--prior_loss_weight",
263 | type=float,
264 | default=1.0,
265 | help="The weight of prior preservation loss.",
266 | )
267 | parser.add_argument(
268 | "--num_class_images",
269 | type=int,
270 | default=100,
271 | help=(
272 | "Minimal class images for prior preservation loss. If there are not enough images already present in"
273 | " class_data_dir, additional images will be sampled with class_prompt."
274 | ),
275 | )
276 | parser.add_argument(
277 | "--output_dir",
278 | type=str,
279 | default="scedit-dreambooth-model",
280 | help="The output directory where the model predictions and checkpoints will be written.",
281 | )
282 | parser.add_argument(
283 | "--seed", type=int, default=None, help="A seed for reproducible training."
284 | )
285 | parser.add_argument(
286 | "--resolution",
287 | type=int,
288 | default=1024,
289 | help=(
290 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
291 | " resolution"
292 | ),
293 | )
294 | parser.add_argument(
295 | "--crops_coords_top_left_h",
296 | type=int,
297 | default=0,
298 | help=(
299 | "Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."
300 | ),
301 | )
302 | parser.add_argument(
303 | "--crops_coords_top_left_w",
304 | type=int,
305 | default=0,
306 | help=(
307 |             "Coordinate for (the width) to be included in the crop coordinate embeddings needed by SDXL UNet."
308 | ),
309 | )
310 | parser.add_argument(
311 | "--center_crop",
312 | default=False,
313 | action="store_true",
314 | help=(
315 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
316 | " cropped. The images will be resized to the resolution first before cropping."
317 | ),
318 | )
319 | parser.add_argument(
320 | "--train_batch_size",
321 | type=int,
322 | default=4,
323 | help="Batch size (per device) for the training dataloader.",
324 | )
325 | parser.add_argument(
326 | "--sample_batch_size",
327 | type=int,
328 | default=4,
329 | help="Batch size (per device) for sampling images.",
330 | )
331 | parser.add_argument("--num_train_epochs", type=int, default=1)
332 | parser.add_argument(
333 | "--max_train_steps",
334 | type=int,
335 | default=None,
336 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
337 | )
338 | parser.add_argument(
339 | "--checkpointing_steps",
340 | type=int,
341 | default=500,
342 | help=(
343 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
344 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
345 | " training using `--resume_from_checkpoint`."
346 | ),
347 | )
348 | parser.add_argument(
349 | "--checkpoints_total_limit",
350 | type=int,
351 | default=None,
352 | help=("Max number of checkpoints to store."),
353 | )
354 | parser.add_argument(
355 | "--resume_from_checkpoint",
356 | type=str,
357 | default=None,
358 | help=(
359 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
360 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
361 | ),
362 | )
363 | parser.add_argument(
364 | "--gradient_accumulation_steps",
365 | type=int,
366 | default=1,
367 |         help="Number of update steps to accumulate before performing a backward/update pass.",
368 | )
369 | parser.add_argument(
370 | "--gradient_checkpointing",
371 | action="store_true",
372 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
373 | )
374 | parser.add_argument(
375 | "--learning_rate",
376 | type=float,
377 | default=1e-4,
378 | help="Initial learning rate (after the potential warmup period) to use.",
379 | )
380 |
381 | parser.add_argument(
382 | "--text_encoder_lr",
383 | type=float,
384 | default=5e-6,
385 | help="Text encoder learning rate to use.",
386 | )
387 | parser.add_argument(
388 | "--scale_lr",
389 | action="store_true",
390 | default=False,
391 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
392 | )
393 | parser.add_argument(
394 | "--lr_scheduler",
395 | type=str,
396 | default="constant",
397 | help=(
398 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
399 | ' "constant", "constant_with_warmup"]'
400 | ),
401 | )
402 |
403 | parser.add_argument(
404 | "--snr_gamma",
405 | type=float,
406 | default=None,
407 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
408 | "More details here: https://arxiv.org/abs/2303.09556.",
409 | )
410 | parser.add_argument(
411 | "--lr_warmup_steps",
412 | type=int,
413 | default=500,
414 | help="Number of steps for the warmup in the lr scheduler.",
415 | )
416 | parser.add_argument(
417 | "--lr_num_cycles",
418 | type=int,
419 | default=1,
420 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
421 | )
422 | parser.add_argument(
423 | "--lr_power",
424 | type=float,
425 | default=1.0,
426 | help="Power factor of the polynomial scheduler.",
427 | )
428 | parser.add_argument(
429 | "--dataloader_num_workers",
430 | type=int,
431 | default=0,
432 | help=(
433 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
434 | ),
435 | )
436 |
437 | parser.add_argument(
438 | "--optimizer",
439 | type=str,
440 | default="AdamW",
441 | help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
442 | )
443 |
444 | parser.add_argument(
445 | "--use_8bit_adam",
446 | action="store_true",
447 | help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
448 | )
449 |
450 | parser.add_argument(
451 | "--adam_beta1",
452 | type=float,
453 | default=0.9,
454 | help="The beta1 parameter for the Adam and Prodigy optimizers.",
455 | )
456 | parser.add_argument(
457 | "--adam_beta2",
458 | type=float,
459 | default=0.999,
460 | help="The beta2 parameter for the Adam and Prodigy optimizers.",
461 | )
462 | parser.add_argument(
463 | "--prodigy_beta3",
464 | type=float,
465 | default=None,
466 |         help="Coefficients for computing the Prodigy stepsize using running averages. If set to None, "
467 |         "uses the square root of beta2. Ignored if optimizer is AdamW.",
468 | )
469 | parser.add_argument(
470 | "--prodigy_decouple",
471 | type=bool,
472 | default=True,
473 | help="Use AdamW style decoupled weight decay",
474 | )
475 | parser.add_argument(
476 | "--adam_weight_decay",
477 | type=float,
478 | default=1e-04,
479 | help="Weight decay to use for unet params",
480 | )
481 | parser.add_argument(
482 | "--adam_weight_decay_text_encoder",
483 | type=float,
484 | default=1e-03,
485 | help="Weight decay to use for text_encoder",
486 | )
487 |
488 | parser.add_argument(
489 | "--adam_epsilon",
490 | type=float,
491 | default=1e-08,
492 | help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
493 | )
494 |
495 | parser.add_argument(
496 | "--prodigy_use_bias_correction",
497 | type=bool,
498 | default=True,
499 | help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
500 | )
501 | parser.add_argument(
502 | "--prodigy_safeguard_warmup",
503 | type=bool,
504 | default=True,
505 | help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
506 | "Ignored if optimizer is adamW",
507 | )
508 | parser.add_argument(
509 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
510 | )
511 | parser.add_argument(
512 | "--push_to_hub",
513 | action="store_true",
514 | help="Whether or not to push the model to the Hub.",
515 | )
516 | parser.add_argument(
517 | "--hub_token",
518 | type=str,
519 | default=None,
520 | help="The token to use to push to the Model Hub.",
521 | )
522 | parser.add_argument(
523 | "--hub_model_id",
524 | type=str,
525 | default=None,
526 | help="The name of the repository to keep in sync with the local `output_dir`.",
527 | )
528 | parser.add_argument(
529 | "--logging_dir",
530 | type=str,
531 | default="logs",
532 | help=(
533 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
534 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
535 | ),
536 | )
537 | parser.add_argument(
538 | "--allow_tf32",
539 | action="store_true",
540 | help=(
541 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
542 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
543 | ),
544 | )
545 | parser.add_argument(
546 | "--report_to",
547 | type=str,
548 | default="tensorboard",
549 | help=(
550 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
551 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
552 | ),
553 | )
554 | parser.add_argument(
555 | "--mixed_precision",
556 | type=str,
557 | default=None,
558 | choices=["no", "fp16", "bf16"],
559 | help=(
560 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
561 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
562 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
563 | ),
564 | )
565 | parser.add_argument(
566 | "--prior_generation_precision",
567 | type=str,
568 | default=None,
569 | choices=["no", "fp32", "fp16", "bf16"],
570 | help=(
571 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
572 |             " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
573 | ),
574 | )
575 | parser.add_argument(
576 | "--local_rank",
577 | type=int,
578 | default=-1,
579 | help="For distributed training: local_rank",
580 | )
581 | parser.add_argument(
582 | "--enable_xformers_memory_efficient_attention",
583 | action="store_true",
584 | help="Whether or not to use xformers.",
585 | )
586 |
587 | if input_args is not None:
588 | args = parser.parse_args(input_args)
589 | else:
590 | args = parser.parse_args()
591 |
592 | if args.dataset_name is None and args.instance_data_dir is None:
593 | raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
594 |
595 | if args.dataset_name is not None and args.instance_data_dir is not None:
596 | raise ValueError(
597 | "Specify only one of `--dataset_name` or `--instance_data_dir`"
598 | )
599 |
600 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
601 | if env_local_rank != -1 and env_local_rank != args.local_rank:
602 | args.local_rank = env_local_rank
603 |
604 | if args.with_prior_preservation:
605 | if args.class_data_dir is None:
606 | raise ValueError("You must specify a data directory for class images.")
607 | if args.class_prompt is None:
608 | raise ValueError("You must specify prompt for class images.")
609 | else:
610 | # logger is not available yet
611 | if args.class_data_dir is not None:
612 | warnings.warn(
613 | "You need not use --class_data_dir without --with_prior_preservation."
614 | )
615 | if args.class_prompt is not None:
616 | warnings.warn(
617 | "You need not use --class_prompt without --with_prior_preservation."
618 | )
619 |
620 | return args
621 |
622 |
623 | class DreamBoothDataset(Dataset):
624 | """
625 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
626 | It pre-processes the images.
627 | """
628 |
629 | def __init__(
630 | self,
631 | instance_data_root,
632 | instance_prompt,
633 | class_prompt,
634 | class_data_root=None,
635 | class_num=None,
636 | size=1024,
637 | repeats=1,
638 | center_crop=False,
639 | ):
640 | self.size = size
641 | self.center_crop = center_crop
642 |
643 | self.instance_prompt = instance_prompt
644 | self.custom_instance_prompts = None
645 | self.class_prompt = class_prompt
646 |
647 | # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
648 | # we load the training data using load_dataset
649 | if args.dataset_name is not None:
650 | try:
651 | from datasets import load_dataset
652 | except ImportError:
653 | raise ImportError(
654 | "You are trying to load your data using the datasets library. If you wish to train using custom "
655 | "captions please install the datasets library: `pip install datasets`. If you wish to load a "
656 | "local folder containing images only, specify --instance_data_dir instead."
657 | )
658 | # Downloading and loading a dataset from the hub.
659 | # See more about loading custom images at
660 | # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
661 | dataset = load_dataset(
662 | args.dataset_name,
663 | args.dataset_config_name,
664 | cache_dir=args.cache_dir,
665 | )
666 | # Preprocessing the datasets.
667 | column_names = dataset["train"].column_names
668 |
669 |             # Get the column names for input/target.
670 | if args.image_column is None:
671 | image_column = column_names[0]
672 | logger.info(f"image column defaulting to {image_column}")
673 | else:
674 | image_column = args.image_column
675 | if image_column not in column_names:
676 | raise ValueError(
677 | f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
678 | )
679 | instance_images = dataset["train"][image_column]
680 |
681 | if args.caption_column is None:
682 | logger.info(
683 | "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
684 | "contains captions/prompts for the images, make sure to specify the "
685 | "column as --caption_column"
686 | )
687 | self.custom_instance_prompts = None
688 | else:
689 | if args.caption_column not in column_names:
690 | raise ValueError(
691 | f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
692 | )
693 | custom_instance_prompts = dataset["train"][args.caption_column]
694 | # create final list of captions according to --repeats
695 | self.custom_instance_prompts = []
696 | for caption in custom_instance_prompts:
697 | self.custom_instance_prompts.extend(
698 | itertools.repeat(caption, repeats)
699 | )
700 | else:
701 | self.instance_data_root = Path(instance_data_root)
702 | if not self.instance_data_root.exists():
703 |                 raise ValueError("Instance images root doesn't exist.")
704 |
705 | instance_images = [
706 | Image.open(path) for path in list(Path(instance_data_root).iterdir())
707 | ]
708 | self.custom_instance_prompts = None
709 |
710 | self.instance_images = []
711 | for img in instance_images:
712 | self.instance_images.extend(itertools.repeat(img, repeats))
713 | self.num_instance_images = len(self.instance_images)
714 | self._length = self.num_instance_images
715 |
716 | if class_data_root is not None:
717 | self.class_data_root = Path(class_data_root)
718 | self.class_data_root.mkdir(parents=True, exist_ok=True)
719 | self.class_images_path = list(self.class_data_root.iterdir())
720 | if class_num is not None:
721 | self.num_class_images = min(len(self.class_images_path), class_num)
722 | else:
723 | self.num_class_images = len(self.class_images_path)
724 | self._length = max(self.num_class_images, self.num_instance_images)
725 | else:
726 | self.class_data_root = None
727 |
728 | self.image_transforms = transforms.Compose(
729 | [
730 | transforms.Resize(
731 | size, interpolation=transforms.InterpolationMode.BILINEAR
732 | ),
733 | transforms.CenterCrop(size)
734 | if center_crop
735 | else transforms.RandomCrop(size),
736 | transforms.ToTensor(),
737 | transforms.Normalize([0.5], [0.5]),
738 | ]
739 | )
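        # A rough sketch of the preprocessing above: Resize scales the shorter side to `size`,
        # the (center or random) crop yields a `size` x `size` image, ToTensor maps pixels to
        # [0, 1], and Normalize([0.5], [0.5]) shifts them to [-1, 1], the range the VAE expects.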
740 |
741 | def __len__(self):
742 | return self._length
743 |
744 | def __getitem__(self, index):
745 | example = {}
746 | instance_image = self.instance_images[index % self.num_instance_images]
747 | instance_image = exif_transpose(instance_image)
748 |
749 |         if instance_image.mode != "RGB":
750 | instance_image = instance_image.convert("RGB")
751 | example["instance_images"] = self.image_transforms(instance_image)
752 |
753 | if self.custom_instance_prompts:
754 | caption = self.custom_instance_prompts[index % self.num_instance_images]
755 | if caption:
756 | example["instance_prompt"] = caption
757 | else:
758 | example["instance_prompt"] = self.instance_prompt
759 |
760 |         else:  # no custom prompts were provided; fall back to the shared instance prompt
761 | example["instance_prompt"] = self.instance_prompt
762 |
763 | if self.class_data_root:
764 | class_image = Image.open(
765 | self.class_images_path[index % self.num_class_images]
766 | )
767 | class_image = exif_transpose(class_image)
768 |
769 |             if class_image.mode != "RGB":
770 | class_image = class_image.convert("RGB")
771 | example["class_images"] = self.image_transforms(class_image)
772 | example["class_prompt"] = self.class_prompt
773 |
774 | return example
775 |
776 |
777 | def collate_fn(examples, with_prior_preservation=False):
778 | pixel_values = [example["instance_images"] for example in examples]
779 | prompts = [example["instance_prompt"] for example in examples]
780 |
781 | # Concat class and instance examples for prior preservation.
782 | # We do this to avoid doing two forward passes.
783 | if with_prior_preservation:
784 | pixel_values += [example["class_images"] for example in examples]
785 | prompts += [example["class_prompt"] for example in examples]
786 |
787 | pixel_values = torch.stack(pixel_values)
788 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
789 |
790 | batch = {"pixel_values": pixel_values, "prompts": prompts}
791 | return batch
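
# Note on batch layout: with prior preservation enabled, collate_fn stacks the instance examples
# first and the class examples second, so a batch of size 2*B is ordered as
# [instance_0 .. instance_{B-1}, class_0 .. class_{B-1}]. The training loop below relies on this
# ordering when it splits predictions via torch.chunk(model_pred, 2, dim=0).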
792 |
793 |
794 | class PromptDataset(Dataset):
795 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
796 |
797 | def __init__(self, prompt, num_samples):
798 | self.prompt = prompt
799 | self.num_samples = num_samples
800 |
801 | def __len__(self):
802 | return self.num_samples
803 |
804 | def __getitem__(self, index):
805 | example = {}
806 | example["prompt"] = self.prompt
807 | example["index"] = index
808 | return example
809 |
810 |
811 | def tokenize_prompt(tokenizer, prompt):
812 | text_inputs = tokenizer(
813 | prompt,
814 | padding="max_length",
815 | max_length=tokenizer.model_max_length,
816 | truncation=True,
817 | return_tensors="pt",
818 | )
819 | text_input_ids = text_inputs.input_ids
820 | return text_input_ids
821 |
822 |
823 | # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
824 | def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
825 | prompt_embeds_list = []
826 |
827 | for i, text_encoder in enumerate(text_encoders):
828 | if tokenizers is not None:
829 | tokenizer = tokenizers[i]
830 | text_input_ids = tokenize_prompt(tokenizer, prompt)
831 | else:
832 | assert text_input_ids_list is not None
833 | text_input_ids = text_input_ids_list[i]
834 |
835 | prompt_embeds = text_encoder(
836 | text_input_ids.to(text_encoder.device),
837 | output_hidden_states=True,
838 | )
839 |
840 |         # We are always interested only in the pooled output of the final text encoder
841 | pooled_prompt_embeds = prompt_embeds[0]
842 | prompt_embeds = prompt_embeds.hidden_states[-2]
843 | bs_embed, seq_len, _ = prompt_embeds.shape
844 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
845 | prompt_embeds_list.append(prompt_embeds)
846 |
847 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
848 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
849 | return prompt_embeds, pooled_prompt_embeds
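
# Shape sketch (assuming the standard SDXL text encoders): the per-token embeddings of both
# encoders are concatenated along the feature dimension and the pooled embedding comes from the
# last encoder in the list, so for the base SDXL checkpoints prompt_embeds is typically
# (batch, 77, 2048) and pooled_prompt_embeds is (batch, 1280).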
850 |
851 |
852 | def main(args):
853 | logging_dir = Path(args.output_dir, args.logging_dir)
854 |
855 | accelerator_project_config = ProjectConfiguration(
856 | project_dir=args.output_dir, logging_dir=logging_dir
857 | )
858 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
859 | accelerator = Accelerator(
860 | gradient_accumulation_steps=args.gradient_accumulation_steps,
861 | mixed_precision=args.mixed_precision,
862 | log_with=args.report_to,
863 | project_config=accelerator_project_config,
864 | kwargs_handlers=[kwargs],
865 | )
866 |
867 | if args.report_to == "wandb":
868 | if not is_wandb_available():
869 | raise ImportError(
870 | "Make sure to install wandb if you want to use it for logging during training."
871 | )
872 | import wandb
873 |
874 | # Make one log on every process with the configuration for debugging.
875 | logging.basicConfig(
876 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
877 | datefmt="%m/%d/%Y %H:%M:%S",
878 | level=logging.INFO,
879 | )
880 | logger.info(accelerator.state, main_process_only=False)
881 | if accelerator.is_local_main_process:
882 | transformers.utils.logging.set_verbosity_warning()
883 | diffusers.utils.logging.set_verbosity_info()
884 | else:
885 | transformers.utils.logging.set_verbosity_error()
886 | diffusers.utils.logging.set_verbosity_error()
887 |
888 | # If passed along, set the training seed now.
889 | if args.seed is not None:
890 | set_seed(args.seed)
891 |
892 | # Generate class images if prior preservation is enabled.
893 | if args.with_prior_preservation:
894 | class_images_dir = Path(args.class_data_dir)
895 | if not class_images_dir.exists():
896 | class_images_dir.mkdir(parents=True)
897 | cur_class_images = len(list(class_images_dir.iterdir()))
898 |
899 | if cur_class_images < args.num_class_images:
900 | torch_dtype = (
901 | torch.float16 if accelerator.device.type == "cuda" else torch.float32
902 | )
903 | if args.prior_generation_precision == "fp32":
904 | torch_dtype = torch.float32
905 | elif args.prior_generation_precision == "fp16":
906 | torch_dtype = torch.float16
907 | elif args.prior_generation_precision == "bf16":
908 | torch_dtype = torch.bfloat16
909 | pipeline = StableDiffusionXLPipeline.from_pretrained(
910 | args.pretrained_model_name_or_path,
911 | torch_dtype=torch_dtype,
912 | revision=args.revision,
913 | variant=args.variant,
914 | )
915 | pipeline.set_progress_bar_config(disable=True)
916 |
917 | num_new_images = args.num_class_images - cur_class_images
918 | logger.info(f"Number of class images to sample: {num_new_images}.")
919 |
920 | sample_dataset = PromptDataset(args.class_prompt, num_new_images)
921 | sample_dataloader = torch.utils.data.DataLoader(
922 | sample_dataset, batch_size=args.sample_batch_size
923 | )
924 |
925 | sample_dataloader = accelerator.prepare(sample_dataloader)
926 | pipeline.to(accelerator.device)
927 |
928 | for example in tqdm(
929 | sample_dataloader,
930 | desc="Generating class images",
931 | disable=not accelerator.is_local_main_process,
932 | ):
933 | images = pipeline(example["prompt"]).images
934 |
935 | for i, image in enumerate(images):
936 | hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
937 | image_filename = (
938 | class_images_dir
939 | / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
940 | )
941 | image.save(image_filename)
942 |
943 | del pipeline
944 | if torch.cuda.is_available():
945 | torch.cuda.empty_cache()
946 |
947 | # Handle the repository creation
948 | if accelerator.is_main_process:
949 | if args.output_dir is not None:
950 | os.makedirs(args.output_dir, exist_ok=True)
951 |
952 | if args.push_to_hub:
953 | repo_id = create_repo(
954 | repo_id=args.hub_model_id or Path(args.output_dir).name,
955 | exist_ok=True,
956 | token=args.hub_token,
957 | ).repo_id
958 |
959 | # Load the tokenizers
960 | tokenizer_one = AutoTokenizer.from_pretrained(
961 | args.pretrained_model_name_or_path,
962 | subfolder="tokenizer",
963 | revision=args.revision,
964 | use_fast=False,
965 | )
966 | tokenizer_two = AutoTokenizer.from_pretrained(
967 | args.pretrained_model_name_or_path,
968 | subfolder="tokenizer_2",
969 | revision=args.revision,
970 | use_fast=False,
971 | )
972 |
973 | # import correct text encoder classes
974 | text_encoder_cls_one = import_model_class_from_model_name_or_path(
975 | args.pretrained_model_name_or_path, args.revision
976 | )
977 | text_encoder_cls_two = import_model_class_from_model_name_or_path(
978 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
979 | )
980 |
981 | # Load scheduler and models
982 | noise_scheduler = DDPMScheduler.from_pretrained(
983 | args.pretrained_model_name_or_path, subfolder="scheduler"
984 | )
985 | text_encoder_one = text_encoder_cls_one.from_pretrained(
986 | args.pretrained_model_name_or_path,
987 | subfolder="text_encoder",
988 | revision=args.revision,
989 | variant=args.variant,
990 | )
991 | text_encoder_two = text_encoder_cls_two.from_pretrained(
992 | args.pretrained_model_name_or_path,
993 | subfolder="text_encoder_2",
994 | revision=args.revision,
995 | variant=args.variant,
996 | )
997 | vae_path = (
998 | args.pretrained_model_name_or_path
999 | if args.pretrained_vae_model_name_or_path is None
1000 | else args.pretrained_vae_model_name_or_path
1001 | )
1002 | vae = AutoencoderKL.from_pretrained(
1003 | vae_path,
1004 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
1005 | revision=args.revision,
1006 | variant=args.variant,
1007 | )
1008 | unet = UNet2DConditionModel.from_pretrained(
1009 | args.pretrained_model_name_or_path,
1010 | subfolder="unet",
1011 | revision=args.revision,
1012 | variant=args.variant,
1013 | )
1014 |
1015 | # We only train the additional adapter layers
1016 | vae.requires_grad_(False)
1017 | text_encoder_one.requires_grad_(False)
1018 | text_encoder_two.requires_grad_(False)
1019 | unet.requires_grad_(False)
1020 |
1021 |     # For mixed precision training we cast all non-trainable weights (vae, text encoders and the frozen unet) to half-precision,
1022 |     # as these weights are only used for inference; keeping them in full precision is not required.
1023 | weight_dtype = torch.float32
1024 | if accelerator.mixed_precision == "fp16":
1025 | weight_dtype = torch.float16
1026 | elif accelerator.mixed_precision == "bf16":
1027 | weight_dtype = torch.bfloat16
1028 |
1029 | # Move unet, vae and text_encoder to device and cast to weight_dtype
1030 | unet.to(accelerator.device, dtype=weight_dtype)
1031 |
1032 | # The VAE is always in float32 to avoid NaN losses.
1033 | vae.to(accelerator.device, dtype=torch.float32)
1034 |
1035 | text_encoder_one.to(accelerator.device, dtype=weight_dtype)
1036 | text_encoder_two.to(accelerator.device, dtype=weight_dtype)
1037 |
1038 | if args.enable_xformers_memory_efficient_attention:
1039 | if is_xformers_available():
1040 | import xformers
1041 |
1042 | xformers_version = version.parse(xformers.__version__)
1043 | if xformers_version == version.parse("0.0.16"):
1044 |                 logger.warning(
1045 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
1046 | "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1047 | )
1048 | unet.enable_xformers_memory_efficient_attention()
1049 | else:
1050 | raise ValueError(
1051 | "xformers is not available. Make sure it is installed correctly"
1052 | )
1053 |
1054 | if args.gradient_checkpointing:
1055 | unet.enable_gradient_checkpointing()
1056 |
1057 | # now we will add SC-Tuner
1058 | unet.set_sctuner()
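    # set_sctuner is provided by this repo's patched UNet2DConditionModel (scedit_pytorch): per the
    # SCEdit approach it injects lightweight SC-Tuner adapters on the UNet skip connections, and
    # since the base UNet was frozen above, these adapter parameters are the only trainable ones.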
1059 |
1060 | # Make sure the trainable params are in float32.
1061 | if args.mixed_precision == "fp16":
1062 | models = [unet]
1063 | for model in models:
1064 | for param in model.parameters():
1065 |                 # only upcast trainable parameters (the SC-Tuner layers) into fp32
1066 | if param.requires_grad:
1067 | param.data = param.to(torch.float32)
1068 |
1069 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
1070 | def save_model_hook(models, weights, output_dir):
1071 | if accelerator.is_main_process:
1072 | unet_sctuner_layers_to_save = None
1073 | for model in models:
1074 | if isinstance(model, type(accelerator.unwrap_model(unet))):
1075 | unet_sctuner_layers_to_save = {
1076 | k: v for k, v in model.state_dict().items() if "sc_tuners" in k
1077 | }
1078 | else:
1079 | raise ValueError(f"unexpected save model: {model.__class__}")
1080 |
1081 | # make sure to pop weight so that corresponding model is not saved again
1082 | weights.pop()
1083 | save_scedit(
1084 | save_directory=output_dir, state_dict=unet_sctuner_layers_to_save
1085 | )
1086 |
1087 | def load_model_hook(models, input_dir):
1088 | unet_ = None
1089 |
1090 | while len(models) > 0:
1091 | model = models.pop()
1092 |
1093 | if isinstance(model, type(accelerator.unwrap_model(unet))):
1094 | unet_ = model
1095 | else:
1096 |                 raise ValueError(f"unexpected model to load: {model.__class__}")
1097 | load_scedit_into_unet(state_dict=input_dir, unet=unet_)
1098 |
1099 | accelerator.register_save_state_pre_hook(save_model_hook)
1100 | accelerator.register_load_state_pre_hook(load_model_hook)
1101 |
1102 | # Enable TF32 for faster training on Ampere GPUs,
1103 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1104 | if args.allow_tf32:
1105 | torch.backends.cuda.matmul.allow_tf32 = True
1106 |
1107 | if args.scale_lr:
1108 | args.learning_rate = (
1109 | args.learning_rate
1110 | * args.gradient_accumulation_steps
1111 | * args.train_batch_size
1112 | * accelerator.num_processes
1113 | )
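    # With --scale_lr the effective learning rate becomes
    #   base_lr * gradient_accumulation_steps * train_batch_size * num_processes,
    # i.e. it grows linearly with the total (global) batch size.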
1114 |
1115 | unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters()))
1116 | total_params = sum(p.numel() for p in unet_lora_parameters)
1117 |
1118 | # Optimization parameters
1119 | unet_lora_parameters_with_lr = {
1120 | "params": unet_lora_parameters,
1121 | "lr": args.learning_rate,
1122 | }
1123 | params_to_optimize = [unet_lora_parameters_with_lr]
1124 |
1125 | # Optimizer creation
1126 | if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
1127 |         logger.warning(
1128 |             f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
1129 |             "Defaulting to adamW."
1130 | )
1131 | args.optimizer = "adamw"
1132 |
1133 | if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
1134 |         logger.warning(
1135 | f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
1136 | f"set to {args.optimizer.lower()}"
1137 | )
1138 |
1139 | if args.optimizer.lower() == "adamw":
1140 | if args.use_8bit_adam:
1141 | try:
1142 | import bitsandbytes as bnb
1143 | except ImportError:
1144 | raise ImportError(
1145 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1146 | )
1147 |
1148 | optimizer_class = bnb.optim.AdamW8bit
1149 | else:
1150 | optimizer_class = torch.optim.AdamW
1151 |
1152 | optimizer = optimizer_class(
1153 | params_to_optimize,
1154 | betas=(args.adam_beta1, args.adam_beta2),
1155 | weight_decay=args.adam_weight_decay,
1156 | eps=args.adam_epsilon,
1157 | )
1158 |
1159 | if args.optimizer.lower() == "prodigy":
1160 | try:
1161 | import prodigyopt
1162 | except ImportError:
1163 | raise ImportError(
1164 | "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`"
1165 | )
1166 |
1167 | optimizer_class = prodigyopt.Prodigy
1168 |
1169 | if args.learning_rate <= 0.1:
1170 |             logger.warning(
1171 | "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
1172 | )
1173 |
1174 | optimizer = optimizer_class(
1175 | params_to_optimize,
1176 | lr=args.learning_rate,
1177 | betas=(args.adam_beta1, args.adam_beta2),
1178 | beta3=args.prodigy_beta3,
1179 | weight_decay=args.adam_weight_decay,
1180 | eps=args.adam_epsilon,
1181 | decouple=args.prodigy_decouple,
1182 | use_bias_correction=args.prodigy_use_bias_correction,
1183 | safeguard_warmup=args.prodigy_safeguard_warmup,
1184 | )
1185 |
1186 | # Dataset and DataLoaders creation:
1187 | train_dataset = DreamBoothDataset(
1188 | instance_data_root=args.instance_data_dir,
1189 | instance_prompt=args.instance_prompt,
1190 | class_prompt=args.class_prompt,
1191 | class_data_root=args.class_data_dir if args.with_prior_preservation else None,
1192 | class_num=args.num_class_images,
1193 | size=args.resolution,
1194 | repeats=args.repeats,
1195 | center_crop=args.center_crop,
1196 | )
1197 |
1198 | train_dataloader = torch.utils.data.DataLoader(
1199 | train_dataset,
1200 | batch_size=args.train_batch_size,
1201 | shuffle=True,
1202 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1203 | num_workers=args.dataloader_num_workers,
1204 | )
1205 |
1206 |     # Compute additional embeddings/ids required by the SDXL UNet:
1207 |     # - regular text embeddings
1208 |     # - pooled text embeddings
1209 |     # - time ids
1210 |
1211 | def compute_time_ids():
1212 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1213 | original_size = (args.resolution, args.resolution)
1214 | target_size = (args.resolution, args.resolution)
1215 | crops_coords_top_left = (
1216 | args.crops_coords_top_left_h,
1217 | args.crops_coords_top_left_w,
1218 | )
1219 | add_time_ids = list(original_size + crops_coords_top_left + target_size)
1220 | add_time_ids = torch.tensor([add_time_ids])
1221 | add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1222 | return add_time_ids
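
    # The six time ids follow the SDXL micro-conditioning layout of
    # StableDiffusionXLPipeline._get_add_time_ids:
    #   (original_height, original_width, crop_top, crop_left, target_height, target_width).
    # Here original and target size are both set to --resolution, so only the crop offsets vary.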
1223 |
1224 | tokenizers = [tokenizer_one, tokenizer_two]
1225 | text_encoders = [text_encoder_one, text_encoder_two]
1226 |
1227 | def compute_text_embeddings(prompt, text_encoders, tokenizers):
1228 | with torch.no_grad():
1229 | prompt_embeds, pooled_prompt_embeds = encode_prompt(
1230 | text_encoders, tokenizers, prompt
1231 | )
1232 | prompt_embeds = prompt_embeds.to(accelerator.device)
1233 | pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
1234 | return prompt_embeds, pooled_prompt_embeds
1235 |
1236 | # Handle instance prompt.
1237 | instance_time_ids = compute_time_ids()
1238 |
1239 |     # Since the text encoders are not tuned in this script, if custom instance prompts are NOT
1240 |     # provided (i.e. --instance_prompt is used for all images), we encode the instance prompt once
1241 |     # to avoid redundant encoding.
1242 | if not train_dataset.custom_instance_prompts:
1243 | (
1244 | instance_prompt_hidden_states,
1245 | instance_pooled_prompt_embeds,
1246 | ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers)
1247 |
1248 | # Handle class prompt for prior-preservation.
1249 | if args.with_prior_preservation:
1250 | class_time_ids = compute_time_ids()
1251 | (
1252 | class_prompt_hidden_states,
1253 | class_pooled_prompt_embeds,
1254 | ) = compute_text_embeddings(args.class_prompt, text_encoders, tokenizers)
1255 |
1256 | # Clear the memory here
1257 | if not train_dataset.custom_instance_prompts:
1258 | del tokenizers, text_encoders
1259 | gc.collect()
1260 | torch.cuda.empty_cache()
1261 |
1262 | # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
1263 | # pack the statically computed variables appropriately here. This is so that we don't
1264 | # have to pass them to the dataloader.
1265 | add_time_ids = instance_time_ids
1266 | if args.with_prior_preservation:
1267 | add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
1268 |
1269 | if not train_dataset.custom_instance_prompts:
1270 | prompt_embeds = instance_prompt_hidden_states
1271 | unet_add_text_embeds = instance_pooled_prompt_embeds
1272 | if args.with_prior_preservation:
1273 | prompt_embeds = torch.cat(
1274 | [prompt_embeds, class_prompt_hidden_states], dim=0
1275 | )
1276 | unet_add_text_embeds = torch.cat(
1277 | [unet_add_text_embeds, class_pooled_prompt_embeds], dim=0
1278 | )
1279 |
1280 | # Scheduler and math around the number of training steps.
1281 | overrode_max_train_steps = False
1282 | num_update_steps_per_epoch = math.ceil(
1283 | len(train_dataloader) / args.gradient_accumulation_steps
1284 | )
1285 | if args.max_train_steps is None:
1286 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1287 | overrode_max_train_steps = True
1288 |
1289 | lr_scheduler = get_scheduler(
1290 | args.lr_scheduler,
1291 | optimizer=optimizer,
1292 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1293 | num_training_steps=args.max_train_steps * accelerator.num_processes,
1294 | num_cycles=args.lr_num_cycles,
1295 | power=args.lr_power,
1296 | )
1297 |
1298 | # Prepare everything with our `accelerator`.
1299 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1300 | unet, optimizer, train_dataloader, lr_scheduler
1301 | )
1302 |
1303 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1304 | num_update_steps_per_epoch = math.ceil(
1305 | len(train_dataloader) / args.gradient_accumulation_steps
1306 | )
1307 | if overrode_max_train_steps:
1308 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1309 | # Afterwards we recalculate our number of training epochs
1310 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1311 |
1312 | # We need to initialize the trackers we use, and also store our configuration.
1313 |     # The trackers initialize automatically on the main process.
1314 | if accelerator.is_main_process:
1315 | accelerator.init_trackers("dreambooth-scedit-sd-xl", config=vars(args))
1316 |
1317 | # Train!
1318 | total_batch_size = (
1319 | args.train_batch_size
1320 | * accelerator.num_processes
1321 | * args.gradient_accumulation_steps
1322 | )
1323 |
1324 | logger.info("***** Running training *****")
1325 | logger.info(f" Num examples = {len(train_dataset)}")
1326 | logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1327 | logger.info(f" Num Epochs = {args.num_train_epochs}")
1328 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1329 | logger.info(
1330 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
1331 | )
1332 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1333 | logger.info(f" Total optimization steps = {args.max_train_steps}")
1334 | logger.info(f" Number of Trainable Parameters = {total_params * 1.e-6:.2f} M")
1335 |
1336 | global_step = 0
1337 | first_epoch = 0
1338 |
1339 | # Potentially load in the weights and states from a previous save
1340 | if args.resume_from_checkpoint:
1341 | if args.resume_from_checkpoint != "latest":
1342 | path = os.path.basename(args.resume_from_checkpoint)
1343 | else:
1344 |             # Get the most recent checkpoint
1345 | dirs = os.listdir(args.output_dir)
1346 | dirs = [d for d in dirs if d.startswith("checkpoint")]
1347 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1348 | path = dirs[-1] if len(dirs) > 0 else None
1349 |
1350 | if path is None:
1351 | accelerator.print(
1352 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1353 | )
1354 | args.resume_from_checkpoint = None
1355 | initial_global_step = 0
1356 | else:
1357 | accelerator.print(f"Resuming from checkpoint {path}")
1358 | accelerator.load_state(os.path.join(args.output_dir, path))
1359 | global_step = int(path.split("-")[1])
1360 |
1361 | initial_global_step = global_step
1362 | first_epoch = global_step // num_update_steps_per_epoch
1363 |
1364 | else:
1365 | initial_global_step = 0
1366 |
1367 | progress_bar = tqdm(
1368 | range(0, args.max_train_steps),
1369 | initial=initial_global_step,
1370 | desc="Steps",
1371 | # Only show the progress bar once on each machine.
1372 | disable=not accelerator.is_local_main_process,
1373 | )
1374 |
1375 | for epoch in range(first_epoch, args.num_train_epochs):
1376 | unet.train()
1377 | for step, batch in enumerate(train_dataloader):
1378 | with accelerator.accumulate(unet):
1379 | pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1380 | prompts = batch["prompts"]
1381 |
1382 |                 # encode batch prompts when custom prompts are provided for each image
1383 | if train_dataset.custom_instance_prompts:
1384 | prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
1385 | prompts, text_encoders, tokenizers
1386 | )
1387 |
1388 | # Convert images to latent space
1389 | model_input = vae.encode(pixel_values).latent_dist.sample()
1390 | model_input = model_input * vae.config.scaling_factor
1391 | if args.pretrained_vae_model_name_or_path is None:
1392 | model_input = model_input.to(weight_dtype)
1393 |
1394 | # Sample noise that we'll add to the latents
1395 | noise = torch.randn_like(model_input)
1396 | bsz = model_input.shape[0]
1397 | # Sample a random timestep for each image
1398 | timesteps = torch.randint(
1399 | 0,
1400 | noise_scheduler.config.num_train_timesteps,
1401 | (bsz,),
1402 | device=model_input.device,
1403 | )
1404 | timesteps = timesteps.long()
1405 |
1406 | # Add noise to the model input according to the noise magnitude at each timestep
1407 | # (this is the forward diffusion process)
1408 | noisy_model_input = noise_scheduler.add_noise(
1409 | model_input, noise, timesteps
1410 | )
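                # For a DDPM-style scheduler this is the closed-form forward process, roughly:
                #   noisy_x = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
                # where alpha_bar_t is the cumulative product of (1 - beta_s) for s <= t.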
1411 |
1412 | # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
1413 | if not train_dataset.custom_instance_prompts:
1414 | elems_to_repeat_text_embeds = (
1415 | bsz // 2 if args.with_prior_preservation else bsz
1416 | )
1417 | elems_to_repeat_time_ids = (
1418 | bsz // 2 if args.with_prior_preservation else bsz
1419 | )
1420 | else:
1421 | elems_to_repeat_text_embeds = 1
1422 | elems_to_repeat_time_ids = (
1423 | bsz // 2 if args.with_prior_preservation else bsz
1424 | )
1425 |
1426 | # Predict the noise residual
1427 | unet_added_conditions = {
1428 | "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
1429 | "text_embeds": unet_add_text_embeds.repeat(
1430 | elems_to_repeat_text_embeds, 1
1431 | ),
1432 | }
1433 | prompt_embeds_input = prompt_embeds.repeat(
1434 | elems_to_repeat_text_embeds, 1, 1
1435 | )
1436 | model_pred = unet(
1437 | noisy_model_input,
1438 | timesteps,
1439 | prompt_embeds_input,
1440 | added_cond_kwargs=unet_added_conditions,
1441 | ).sample
1442 |
1443 | # Get the target for loss depending on the prediction type
1444 | if noise_scheduler.config.prediction_type == "epsilon":
1445 | target = noise
1446 | elif noise_scheduler.config.prediction_type == "v_prediction":
1447 | target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1448 | else:
1449 | raise ValueError(
1450 | f"Unknown prediction type {noise_scheduler.config.prediction_type}"
1451 | )
1452 |
1453 | if args.with_prior_preservation:
1454 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1455 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1456 | target, target_prior = torch.chunk(target, 2, dim=0)
1457 |
1458 | # Compute prior loss
1459 | prior_loss = F.mse_loss(
1460 | model_pred_prior.float(), target_prior.float(), reduction="mean"
1461 | )
1462 |
1463 | if args.snr_gamma is None:
1464 | loss = F.mse_loss(
1465 | model_pred.float(), target.float(), reduction="mean"
1466 | )
1467 | else:
1468 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1469 | # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1470 | # This is discussed in Section 4.2 of the same paper.
1471 | snr = compute_snr(noise_scheduler, timesteps)
1472 | base_weight = (
1473 | torch.stack(
1474 | [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1
1475 | ).min(dim=1)[0]
1476 | / snr
1477 | )
1478 |
1479 | if noise_scheduler.config.prediction_type == "v_prediction":
1480 | # Velocity objective needs to be floored to an SNR weight of one.
1481 | mse_loss_weights = base_weight + 1
1482 | else:
1483 | # Epsilon and sample both use the same loss weights.
1484 | mse_loss_weights = base_weight
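                    # Min-SNR weighting sketch: base_weight = min(SNR_t, snr_gamma) / SNR_t, which
                    # down-weights easy (high-SNR) timesteps; for v-prediction the extra +1 keeps
                    # the effective weight floored at one.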
1485 |
1486 | loss = F.mse_loss(
1487 | model_pred.float(), target.float(), reduction="none"
1488 | )
1489 | loss = (
1490 | loss.mean(dim=list(range(1, len(loss.shape))))
1491 | * mse_loss_weights
1492 | )
1493 | loss = loss.mean()
1494 |
1495 | if args.with_prior_preservation:
1496 | # Add the prior loss to the instance loss.
1497 | loss = loss + args.prior_loss_weight * prior_loss
1498 |
1499 | accelerator.backward(loss)
1500 | if accelerator.sync_gradients:
1501 | params_to_clip = unet_lora_parameters
1502 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1503 | optimizer.step()
1504 | lr_scheduler.step()
1505 | optimizer.zero_grad()
1506 |
1507 | # Checks if the accelerator has performed an optimization step behind the scenes
1508 | if accelerator.sync_gradients:
1509 | progress_bar.update(1)
1510 | global_step += 1
1511 |
1512 | if accelerator.is_main_process:
1513 | if global_step % args.checkpointing_steps == 0:
1514 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1515 | if args.checkpoints_total_limit is not None:
1516 | checkpoints = os.listdir(args.output_dir)
1517 | checkpoints = [
1518 | d for d in checkpoints if d.startswith("checkpoint")
1519 | ]
1520 | checkpoints = sorted(
1521 | checkpoints, key=lambda x: int(x.split("-")[1])
1522 | )
1523 |
1524 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1525 | if len(checkpoints) >= args.checkpoints_total_limit:
1526 | num_to_remove = (
1527 | len(checkpoints) - args.checkpoints_total_limit + 1
1528 | )
1529 | removing_checkpoints = checkpoints[0:num_to_remove]
1530 |
1531 | logger.info(
1532 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1533 | )
1534 | logger.info(
1535 | f"removing checkpoints: {', '.join(removing_checkpoints)}"
1536 | )
1537 |
1538 | for removing_checkpoint in removing_checkpoints:
1539 | removing_checkpoint = os.path.join(
1540 | args.output_dir, removing_checkpoint
1541 | )
1542 | shutil.rmtree(removing_checkpoint)
1543 |
1544 | save_path = os.path.join(
1545 | args.output_dir, f"checkpoint-{global_step}"
1546 | )
1547 | accelerator.save_state(save_path)
1548 | logger.info(f"Saved state to {save_path}")
1549 |
1550 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1551 | progress_bar.set_postfix(**logs)
1552 | accelerator.log(logs, step=global_step)
1553 |
1554 | if global_step >= args.max_train_steps:
1555 | break
1556 |
1557 | if accelerator.is_main_process:
1558 | if (
1559 | args.validation_prompt is not None
1560 | and epoch % args.validation_epochs == 0
1561 | ):
1562 | logger.info(
1563 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1564 | f" {args.validation_prompt}."
1565 | )
1566 | # create pipeline
1567 | text_encoder_one = text_encoder_cls_one.from_pretrained(
1568 | args.pretrained_model_name_or_path,
1569 | subfolder="text_encoder",
1570 | revision=args.revision,
1571 | variant=args.variant,
1572 | )
1573 | text_encoder_two = text_encoder_cls_two.from_pretrained(
1574 | args.pretrained_model_name_or_path,
1575 | subfolder="text_encoder_2",
1576 | revision=args.revision,
1577 | variant=args.variant,
1578 | )
1579 | pipeline = StableDiffusionXLPipeline.from_pretrained(
1580 | args.pretrained_model_name_or_path,
1581 | vae=vae,
1582 | text_encoder=accelerator.unwrap_model(text_encoder_one),
1583 | text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1584 | unet=accelerator.unwrap_model(unet),
1585 | revision=args.revision,
1586 | variant=args.variant,
1587 | torch_dtype=weight_dtype,
1588 | )
1589 |
1590 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1591 | scheduler_args = {}
1592 |
1593 | if "variance_type" in pipeline.scheduler.config:
1594 | variance_type = pipeline.scheduler.config.variance_type
1595 |
1596 | if variance_type in ["learned", "learned_range"]:
1597 | variance_type = "fixed_small"
1598 |
1599 | scheduler_args["variance_type"] = variance_type
1600 |
1601 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1602 | pipeline.scheduler.config, **scheduler_args
1603 | )
1604 |
1605 | pipeline = pipeline.to(accelerator.device)
1606 | pipeline.set_progress_bar_config(disable=True)
1607 |
1608 | # run inference
1609 | generator = (
1610 | torch.Generator(device=accelerator.device).manual_seed(args.seed)
1611 | if args.seed
1612 | else None
1613 | )
1614 | pipeline_args = {"prompt": args.validation_prompt}
1615 |
1616 | images = [
1617 | pipeline(**pipeline_args, generator=generator).images[0]
1618 | for _ in range(args.num_validation_images)
1619 | ]
1620 |
1621 | for tracker in accelerator.trackers:
1622 | if tracker.name == "tensorboard":
1623 | np_images = np.stack([np.asarray(img) for img in images])
1624 | tracker.writer.add_images(
1625 | "validation", np_images, epoch, dataformats="NHWC"
1626 | )
1627 | if tracker.name == "wandb":
1628 | tracker.log(
1629 | {
1630 | "validation": [
1631 | wandb.Image(
1632 | image, caption=f"{i}: {args.validation_prompt}"
1633 | )
1634 | for i, image in enumerate(images)
1635 | ]
1636 | }
1637 | )
1638 |
1639 | del pipeline
1640 | torch.cuda.empty_cache()
1641 |
1642 |     # Save the SC-Tuner layers
1643 | accelerator.wait_for_everyone()
1644 | if accelerator.is_main_process:
1645 | del optimizer
1646 | torch.cuda.empty_cache()
1647 | unet = accelerator.unwrap_model(unet)
1648 | unet = unet.to(torch.float32)
1649 | unet_sctuner_layers_to_save = {
1650 |         k: v for k, v in unet.state_dict().items() if "sc_tuners" in k
1651 | }
1652 | save_scedit(
1653 | save_directory=args.output_dir, state_dict=unet_sctuner_layers_to_save
1654 | )
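    # Presumably the saved SC-Tuner weights can be restored the same way load_model_hook does it
    # above, e.g. by calling unet.set_sctuner() on a freshly loaded UNet and then
    # load_scedit_into_unet(state_dict=args.output_dir, unet=unet); see scedit_pytorch for the
    # exact loading API.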
1655 |
1656 | # Final inference
1657 | pipeline = StableDiffusionXLPipeline.from_pretrained(
1658 | args.pretrained_model_name_or_path,
1659 | unet=unet,
1660 | vae=vae,
1661 | revision=args.revision,
1662 | variant=args.variant,
1663 | torch_dtype=weight_dtype,
1664 | )
1665 |
1666 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
1667 | scheduler_args = {}
1668 |
1669 | if "variance_type" in pipeline.scheduler.config:
1670 | variance_type = pipeline.scheduler.config.variance_type
1671 |
1672 | if variance_type in ["learned", "learned_range"]:
1673 | variance_type = "fixed_small"
1674 |
1675 | scheduler_args["variance_type"] = variance_type
1676 |
1677 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1678 | pipeline.scheduler.config, **scheduler_args
1679 | )
1680 |
1681 | # run inference
1682 | images = []
1683 | if args.validation_prompt and args.num_validation_images > 0:
1684 | pipeline = pipeline.to(accelerator.device, dtype=weight_dtype)
1685 | generator = (
1686 | torch.Generator(device=accelerator.device).manual_seed(args.seed)
1687 | if args.seed
1688 | else None
1689 | )
1690 | images = [
1691 | pipeline(
1692 | args.validation_prompt, num_inference_steps=25, generator=generator
1693 | ).images[0]
1694 | for _ in range(args.num_validation_images)
1695 | ]
1696 |
1697 | for tracker in accelerator.trackers:
1698 | if tracker.name == "tensorboard":
1699 | np_images = np.stack([np.asarray(img) for img in images])
1700 | tracker.writer.add_images(
1701 | "test", np_images, epoch, dataformats="NHWC"
1702 | )
1703 | if tracker.name == "wandb":
1704 | tracker.log(
1705 | {
1706 | "test": [
1707 | wandb.Image(
1708 | image, caption=f"{i}: {args.validation_prompt}"
1709 | )
1710 | for i, image in enumerate(images)
1711 | ]
1712 | }
1713 | )
1714 |
1715 | if args.push_to_hub:
1716 | save_model_card(
1717 | repo_id,
1718 | images=images,
1719 | base_model=args.pretrained_model_name_or_path,
1720 | instance_prompt=args.instance_prompt,
1721 | validation_prompt=args.validation_prompt,
1722 | repo_folder=args.output_dir,
1723 | vae_path=args.pretrained_vae_model_name_or_path,
1724 | )
1725 | upload_folder(
1726 | repo_id=repo_id,
1727 | folder_path=args.output_dir,
1728 | commit_message="End of training",
1729 | ignore_patterns=["step_*", "epoch_*"],
1730 | )
1731 |
1732 | accelerator.end_training()
1733 |
1734 |
1735 | if __name__ == "__main__":
1736 | args = parse_args()
1737 | main(args)
1738 |
--------------------------------------------------------------------------------