├── .gitignore ├── LICENSE ├── README.md ├── assets ├── figure4.png └── result.png ├── requirements.txt ├── scedit_pytorch ├── __init__.py ├── diffusers_modules │ ├── __init__.py │ ├── unet_2d_blocks.py │ └── unet_2d_condition.py ├── scedit.py └── utils.py ├── scripts ├── app.py └── scedit_pytorch.ipynb └── train_dreambooth_scedit_sdxl.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | tests 163 | test.py -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 mkshing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SCEdit-pytorch 2 | Open In Colab 3 | 4 | 5 | This is an implementation of [SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing](https://scedit.github.io/) by [mkshing](https://twitter.com/mk1stats). 6 | 7 | ![result](assets/result.png) 8 | 9 | * Beyond the paper, this implementation can use SDXL as the pre-trained model. 10 | * The weight scale of the tuner can be set via `scale`. 11 | * As the paper notes, the architecture of SCEdit is very flexible. The `SCTunerLinearLayer` I implemented seems smaller than what the paper describes, so please let me know if you find a better variant (a sketch of one possible larger variant is shown below).
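For reference, here is a minimal sketch of one possible larger variant: a bottleneck-style block applied to a skip-connection feature map. This is only an illustration, not the `SCTunerLinearLayer` shipped in `scedit_pytorch/scedit.py`; the class name `SCTunerBottleneckLayer`, the `hidden_dim` bottleneck width, and the zero-initialization are assumptions. Any tuner only needs to satisfy the contract used by the modified up blocks: take a skip-connection tensor of shape `(batch, channels, height, width)` and return a tensor of the same shape.

```python
import torch
from torch import nn


class SCTunerBottleneckLayer(nn.Module):
    """Illustrative residual bottleneck tuner for a U-Net skip-connection feature map."""

    def __init__(self, dim: int, hidden_dim: int = 64, scale: float = 1.0):
        super().__init__()
        self.down = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.up = nn.Linear(hidden_dim, dim)
        self.scale = scale
        # Start as an identity mapping so the frozen UNet's output is unchanged at step 0.
        nn.init.zeros_(self.up.weight)
        nn.init.zeros_(self.up.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, channels, height, width) skip-connection tensor
        residual = x
        x = x.permute(0, 2, 3, 1)  # channels-last so nn.Linear acts on the channel dim
        x = self.up(self.act(self.down(x)))
        x = x.permute(0, 3, 1, 2)  # back to channels-first
        return residual + self.scale * x
```

Zero-initializing the up-projection keeps the edited skip connection identical to the original one at the first training step, which is a common choice for adapter-style modules.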
12 | 13 | ![image](assets/figure4.png) 14 | 15 | ## Installation 16 | 17 | ```bash 18 | git clone https://github.com/mkshing/scedit-pytorch.git 19 | cd scedit-pytorch 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | ## SC-Tuner 24 | 25 | ### Training 26 | The training script is pretty much the same as the [LoRA DreamBooth script from diffusers](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sdxl.md). 27 | 28 | 29 | ```bash 30 | MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" 31 | INSTANCE_DIR="path-to-dataset" 32 | OUTPUT_DIR="scedit-trained-xl" 33 | 34 | accelerate launch train_dreambooth_scedit_sdxl.py \ 35 | --pretrained_model_name_or_path=$MODEL_NAME \ 36 | --instance_data_dir=$INSTANCE_DIR \ 37 | --output_dir=$OUTPUT_DIR \ 38 | --mixed_precision="fp16" \ 39 | --instance_prompt="a photo of sbu dog" \ 40 | --resolution=1024 \ 41 | --train_batch_size=1 \ 42 | --gradient_accumulation_steps=8 \ 43 | --learning_rate=5e-5 \ 44 | --lr_scheduler="constant" \ 45 | --lr_warmup_steps=0 \ 46 | --max_train_steps=1000 \ 47 | --checkpointing_steps=200 \ 48 | --validation_prompt="A photo of sbu dog in a bucket" \ 49 | --validation_epochs=100 \ 50 | --use_8bit_adam \ 51 | --report_to="wandb" \ 52 | --seed="0" \ 53 | --push_to_hub 54 | 55 | ``` 56 | 57 | ### Inference 58 | 59 | #### Python example: 60 | ```python 61 | from diffusers import DiffusionPipeline 62 | import torch 63 | from scedit_pytorch import UNet2DConditionModel, load_scedit_into_unet 64 | 65 | 66 | base_model_id = "stabilityai/stable-diffusion-xl-base-1.0" 67 | scedit_model_id = "path-to-scedit" 68 | 69 | # load unet with sctuner 70 | unet = UNet2DConditionModel.from_pretrained(base_model_id, subfolder="unet") 71 | unet.set_sctuner(scale=1.0) 72 | unet = load_scedit_into_unet(scedit_model_id, unet) 73 | # load pipeline 74 | pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet) 75 | pipe = pipe.to(device="cuda", dtype=torch.float16) 76 | ``` 77 | 78 | #### Gradio Demo: 79 | ```bash 80 | MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" 81 | SCEDIT_NAME="mkshing/scedit-trained-xl" 82 | 83 | python scripts/app.py \ 84 | --pretrained_model_name_or_path $MODEL_NAME \ 85 | --scedit_name_or_path $SCEDIT_NAME 86 | ``` 87 | 88 | ## TODO 89 | - [x] SC-Tuner 90 | - [ ] CSC-Tuner -------------------------------------------------------------------------------- /assets/figure4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/scedit-pytorch/41b219515161fd78872254696c1212573867f372/assets/figure4.png -------------------------------------------------------------------------------- /assets/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/scedit-pytorch/41b219515161fd78872254696c1212573867f372/assets/result.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers>=0.24.0 2 | accelerate>=0.16.0 3 | torchvision 4 | transformers>=4.25.1 5 | ftfy 6 | tensorboard 7 | Jinja2 8 | gradio 9 | -------------------------------------------------------------------------------- /scedit_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusers_modules import UNet2DConditionModel 2 | from .utils import save_scedit, load_scedit_into_unet 3 |
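# Notes on the exported API:
# - `UNet2DConditionModel` is the patched diffusers UNet whose up blocks can carry
#   SC-Tuners; call `unet.set_sctuner(scale=...)` to enable them, as in the README
#   inference example.
# - `load_scedit_into_unet(scedit_model_id, unet)` loads trained SC-Tuner weights
#   into such a UNet.
# - `save_scedit` is presumably the saving counterpart used at training time; its
#   exact signature is defined in `scedit_pytorch/utils.py` and is not shown here.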
-------------------------------------------------------------------------------- /scedit_pytorch/diffusers_modules/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | All files in this folder were modified for SCEdit based on https://github.com/huggingface/diffusers/tree/v0.24.0/src/diffusers/models. 3 | Especially, SC-Tuners are inserted into up blocks (CrossAttnUpBlock2D, UpBlock2D) 4 | """ 5 | from .unet_2d_condition import UNet2DConditionModel 6 | -------------------------------------------------------------------------------- /scedit_pytorch/diffusers_modules/unet_2d_blocks.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional, Tuple, Union 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from diffusers.utils import is_torch_version, logging 7 | from diffusers.utils.torch_utils import apply_freeu 8 | from diffusers.models.dual_transformer_2d import DualTransformer2DModel 9 | from diffusers.models.resnet import ResnetBlock2D, Upsample2D 10 | from diffusers.models.transformer_2d import Transformer2DModel 11 | from ..scedit import SCTuner 12 | 13 | 14 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 15 | 16 | 17 | def get_up_block( 18 | up_block_type: str, 19 | num_layers: int, 20 | in_channels: int, 21 | out_channels: int, 22 | prev_output_channel: int, 23 | temb_channels: int, 24 | add_upsample: bool, 25 | resnet_eps: float, 26 | resnet_act_fn: str, 27 | resolution_idx: Optional[int] = None, 28 | transformer_layers_per_block: int = 1, 29 | num_attention_heads: Optional[int] = None, 30 | resnet_groups: Optional[int] = None, 31 | cross_attention_dim: Optional[int] = None, 32 | dual_cross_attention: bool = False, 33 | use_linear_projection: bool = False, 34 | only_cross_attention: bool = False, 35 | upcast_attention: bool = False, 36 | resnet_time_scale_shift: str = "default", 37 | attention_type: str = "default", 38 | resnet_skip_time_act: bool = False, 39 | resnet_out_scale_factor: float = 1.0, 40 | cross_attention_norm: Optional[str] = None, 41 | attention_head_dim: Optional[int] = None, 42 | upsample_type: Optional[str] = None, 43 | dropout: float = 0.0, 44 | ) -> nn.Module: 45 | # If attn head dim is not defined, we default it to the number of heads 46 | if attention_head_dim is None: 47 | logger.warn( 48 | f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
49 | ) 50 | attention_head_dim = num_attention_heads 51 | 52 | up_block_type = ( 53 | up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type 54 | ) 55 | if up_block_type == "UpBlock2D": 56 | return UpBlock2D( 57 | num_layers=num_layers, 58 | in_channels=in_channels, 59 | out_channels=out_channels, 60 | prev_output_channel=prev_output_channel, 61 | temb_channels=temb_channels, 62 | resolution_idx=resolution_idx, 63 | dropout=dropout, 64 | add_upsample=add_upsample, 65 | resnet_eps=resnet_eps, 66 | resnet_act_fn=resnet_act_fn, 67 | resnet_groups=resnet_groups, 68 | resnet_time_scale_shift=resnet_time_scale_shift, 69 | ) 70 | elif up_block_type == "CrossAttnUpBlock2D": 71 | if cross_attention_dim is None: 72 | raise ValueError( 73 | "cross_attention_dim must be specified for CrossAttnUpBlock2D" 74 | ) 75 | return CrossAttnUpBlock2D( 76 | num_layers=num_layers, 77 | transformer_layers_per_block=transformer_layers_per_block, 78 | in_channels=in_channels, 79 | out_channels=out_channels, 80 | prev_output_channel=prev_output_channel, 81 | temb_channels=temb_channels, 82 | resolution_idx=resolution_idx, 83 | dropout=dropout, 84 | add_upsample=add_upsample, 85 | resnet_eps=resnet_eps, 86 | resnet_act_fn=resnet_act_fn, 87 | resnet_groups=resnet_groups, 88 | cross_attention_dim=cross_attention_dim, 89 | num_attention_heads=num_attention_heads, 90 | dual_cross_attention=dual_cross_attention, 91 | use_linear_projection=use_linear_projection, 92 | only_cross_attention=only_cross_attention, 93 | upcast_attention=upcast_attention, 94 | resnet_time_scale_shift=resnet_time_scale_shift, 95 | attention_type=attention_type, 96 | ) 97 | 98 | # raise ValueError(f"{up_block_type} does not exist.") 99 | raise NotImplementedError 100 | 101 | 102 | class CrossAttnUpBlock2D(SCTuner): 103 | def __init__( 104 | self, 105 | in_channels: int, 106 | out_channels: int, 107 | prev_output_channel: int, 108 | temb_channels: int, 109 | resolution_idx: Optional[int] = None, 110 | dropout: float = 0.0, 111 | num_layers: int = 1, 112 | transformer_layers_per_block: Union[int, Tuple[int]] = 1, 113 | resnet_eps: float = 1e-6, 114 | resnet_time_scale_shift: str = "default", 115 | resnet_act_fn: str = "swish", 116 | resnet_groups: int = 32, 117 | resnet_pre_norm: bool = True, 118 | num_attention_heads: int = 1, 119 | cross_attention_dim: int = 1280, 120 | output_scale_factor: float = 1.0, 121 | add_upsample: bool = True, 122 | dual_cross_attention: bool = False, 123 | use_linear_projection: bool = False, 124 | only_cross_attention: bool = False, 125 | upcast_attention: bool = False, 126 | attention_type: str = "default", 127 | ): 128 | super().__init__() 129 | resnets = [] 130 | attentions = [] 131 | self.res_skip_channels = [] 132 | 133 | self.has_cross_attention = True 134 | self.num_attention_heads = num_attention_heads 135 | 136 | if isinstance(transformer_layers_per_block, int): 137 | transformer_layers_per_block = [transformer_layers_per_block] * num_layers 138 | 139 | for i in range(num_layers): 140 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 141 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 142 | self.res_skip_channels.append(res_skip_channels) 143 | resnets.append( 144 | ResnetBlock2D( 145 | in_channels=resnet_in_channels + res_skip_channels, 146 | out_channels=out_channels, 147 | temb_channels=temb_channels, 148 | eps=resnet_eps, 149 | groups=resnet_groups, 150 | dropout=dropout, 151 | time_embedding_norm=resnet_time_scale_shift, 152 | 
non_linearity=resnet_act_fn, 153 | output_scale_factor=output_scale_factor, 154 | pre_norm=resnet_pre_norm, 155 | ) 156 | ) 157 | if not dual_cross_attention: 158 | attentions.append( 159 | Transformer2DModel( 160 | num_attention_heads, 161 | out_channels // num_attention_heads, 162 | in_channels=out_channels, 163 | num_layers=transformer_layers_per_block[i], 164 | cross_attention_dim=cross_attention_dim, 165 | norm_num_groups=resnet_groups, 166 | use_linear_projection=use_linear_projection, 167 | only_cross_attention=only_cross_attention, 168 | upcast_attention=upcast_attention, 169 | attention_type=attention_type, 170 | ) 171 | ) 172 | else: 173 | attentions.append( 174 | DualTransformer2DModel( 175 | num_attention_heads, 176 | out_channels // num_attention_heads, 177 | in_channels=out_channels, 178 | num_layers=1, 179 | cross_attention_dim=cross_attention_dim, 180 | norm_num_groups=resnet_groups, 181 | ) 182 | ) 183 | self.attentions = nn.ModuleList(attentions) 184 | self.resnets = nn.ModuleList(resnets) 185 | 186 | if add_upsample: 187 | self.upsamplers = nn.ModuleList( 188 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] 189 | ) 190 | else: 191 | self.upsamplers = None 192 | 193 | self.gradient_checkpointing = False 194 | self.resolution_idx = resolution_idx 195 | 196 | def forward( 197 | self, 198 | hidden_states: torch.FloatTensor, 199 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], 200 | temb: Optional[torch.FloatTensor] = None, 201 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 202 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 203 | upsample_size: Optional[int] = None, 204 | attention_mask: Optional[torch.FloatTensor] = None, 205 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 206 | ) -> torch.FloatTensor: 207 | lora_scale = ( 208 | cross_attention_kwargs.get("scale", 1.0) 209 | if cross_attention_kwargs is not None 210 | else 1.0 211 | ) 212 | is_freeu_enabled = ( 213 | getattr(self, "s1", None) 214 | and getattr(self, "s2", None) 215 | and getattr(self, "b1", None) 216 | and getattr(self, "b2", None) 217 | ) 218 | 219 | for i, (resnet, attn) in enumerate(zip(self.resnets, self.attentions)): 220 | # pop res hidden states 221 | res_hidden_states = res_hidden_states_tuple[-1] 222 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 223 | 224 | # FreeU: Only operate on the first two stages 225 | if is_freeu_enabled: 226 | hidden_states, res_hidden_states = apply_freeu( 227 | self.resolution_idx, 228 | hidden_states, 229 | res_hidden_states, 230 | s1=self.s1, 231 | s2=self.s2, 232 | b1=self.b1, 233 | b2=self.b2, 234 | ) 235 | if self.sc_tuners is not None: 236 | res_hidden_states = self.sc_tuners[i](res_hidden_states) 237 | 238 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 239 | 240 | if self.training and self.gradient_checkpointing: 241 | 242 | def create_custom_forward(module, return_dict=None): 243 | def custom_forward(*inputs): 244 | if return_dict is not None: 245 | return module(*inputs, return_dict=return_dict) 246 | else: 247 | return module(*inputs) 248 | 249 | return custom_forward 250 | 251 | ckpt_kwargs: Dict[str, Any] = ( 252 | {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 253 | ) 254 | hidden_states = torch.utils.checkpoint.checkpoint( 255 | create_custom_forward(resnet), 256 | hidden_states, 257 | temb, 258 | **ckpt_kwargs, 259 | ) 260 | hidden_states = attn( 261 | hidden_states, 262 | encoder_hidden_states=encoder_hidden_states, 263 | 
cross_attention_kwargs=cross_attention_kwargs, 264 | attention_mask=attention_mask, 265 | encoder_attention_mask=encoder_attention_mask, 266 | return_dict=False, 267 | )[0] 268 | else: 269 | hidden_states = resnet(hidden_states, temb, scale=lora_scale) 270 | hidden_states = attn( 271 | hidden_states, 272 | encoder_hidden_states=encoder_hidden_states, 273 | cross_attention_kwargs=cross_attention_kwargs, 274 | attention_mask=attention_mask, 275 | encoder_attention_mask=encoder_attention_mask, 276 | return_dict=False, 277 | )[0] 278 | 279 | if self.upsamplers is not None: 280 | for upsampler in self.upsamplers: 281 | hidden_states = upsampler( 282 | hidden_states, upsample_size, scale=lora_scale 283 | ) 284 | 285 | return hidden_states 286 | 287 | 288 | class UpBlock2D(SCTuner): 289 | def __init__( 290 | self, 291 | in_channels: int, 292 | prev_output_channel: int, 293 | out_channels: int, 294 | temb_channels: int, 295 | resolution_idx: Optional[int] = None, 296 | dropout: float = 0.0, 297 | num_layers: int = 1, 298 | resnet_eps: float = 1e-6, 299 | resnet_time_scale_shift: str = "default", 300 | resnet_act_fn: str = "swish", 301 | resnet_groups: int = 32, 302 | resnet_pre_norm: bool = True, 303 | output_scale_factor: float = 1.0, 304 | add_upsample: bool = True, 305 | ): 306 | super().__init__() 307 | resnets = [] 308 | self.res_skip_channels = [] 309 | 310 | for i in range(num_layers): 311 | res_skip_channels = in_channels if (i == num_layers - 1) else out_channels 312 | resnet_in_channels = prev_output_channel if i == 0 else out_channels 313 | self.res_skip_channels.append(res_skip_channels) 314 | resnets.append( 315 | ResnetBlock2D( 316 | in_channels=resnet_in_channels + res_skip_channels, 317 | out_channels=out_channels, 318 | temb_channels=temb_channels, 319 | eps=resnet_eps, 320 | groups=resnet_groups, 321 | dropout=dropout, 322 | time_embedding_norm=resnet_time_scale_shift, 323 | non_linearity=resnet_act_fn, 324 | output_scale_factor=output_scale_factor, 325 | pre_norm=resnet_pre_norm, 326 | ) 327 | ) 328 | 329 | self.resnets = nn.ModuleList(resnets) 330 | 331 | if add_upsample: 332 | self.upsamplers = nn.ModuleList( 333 | [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)] 334 | ) 335 | else: 336 | self.upsamplers = None 337 | 338 | self.gradient_checkpointing = False 339 | self.resolution_idx = resolution_idx 340 | 341 | def forward( 342 | self, 343 | hidden_states: torch.FloatTensor, 344 | res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], 345 | temb: Optional[torch.FloatTensor] = None, 346 | upsample_size: Optional[int] = None, 347 | scale: float = 1.0, 348 | ) -> torch.FloatTensor: 349 | is_freeu_enabled = ( 350 | getattr(self, "s1", None) 351 | and getattr(self, "s2", None) 352 | and getattr(self, "b1", None) 353 | and getattr(self, "b2", None) 354 | ) 355 | 356 | for i, resnet in enumerate(self.resnets): 357 | # pop res hidden states 358 | res_hidden_states = res_hidden_states_tuple[-1] 359 | res_hidden_states_tuple = res_hidden_states_tuple[:-1] 360 | 361 | # FreeU: Only operate on the first two stages 362 | if is_freeu_enabled: 363 | hidden_states, res_hidden_states = apply_freeu( 364 | self.resolution_idx, 365 | hidden_states, 366 | res_hidden_states, 367 | s1=self.s1, 368 | s2=self.s2, 369 | b1=self.b1, 370 | b2=self.b2, 371 | ) 372 | if self.sc_tuners is not None: 373 | res_hidden_states = self.sc_tuners[i](res_hidden_states) 374 | 375 | hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) 376 | 377 | if self.training and 
self.gradient_checkpointing: 378 | 379 | def create_custom_forward(module): 380 | def custom_forward(*inputs): 381 | return module(*inputs) 382 | 383 | return custom_forward 384 | 385 | if is_torch_version(">=", "1.11.0"): 386 | hidden_states = torch.utils.checkpoint.checkpoint( 387 | create_custom_forward(resnet), 388 | hidden_states, 389 | temb, 390 | use_reentrant=False, 391 | ) 392 | else: 393 | hidden_states = torch.utils.checkpoint.checkpoint( 394 | create_custom_forward(resnet), hidden_states, temb 395 | ) 396 | else: 397 | hidden_states = resnet(hidden_states, temb, scale=scale) 398 | 399 | if self.upsamplers is not None: 400 | for upsampler in self.upsamplers: 401 | hidden_states = upsampler(hidden_states, upsample_size, scale=scale) 402 | 403 | return hidden_states 404 | -------------------------------------------------------------------------------- /scedit_pytorch/diffusers_modules/unet_2d_condition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.utils.checkpoint 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders import UNet2DConditionLoadersMixin 23 | from diffusers.utils import ( 24 | USE_PEFT_BACKEND, 25 | BaseOutput, 26 | deprecate, 27 | logging, 28 | scale_lora_layers, 29 | unscale_lora_layers, 30 | ) 31 | from diffusers.models.activations import get_activation 32 | from diffusers.models.attention_processor import ( 33 | ADDED_KV_ATTENTION_PROCESSORS, 34 | CROSS_ATTENTION_PROCESSORS, 35 | AttentionProcessor, 36 | AttnAddedKVProcessor, 37 | AttnProcessor, 38 | ) 39 | from diffusers.models.embeddings import ( 40 | GaussianFourierProjection, 41 | ImageHintTimeEmbedding, 42 | ImageProjection, 43 | ImageTimeEmbedding, 44 | PositionNet, 45 | TextImageProjection, 46 | TextImageTimeEmbedding, 47 | TextTimeEmbedding, 48 | TimestepEmbedding, 49 | Timesteps, 50 | ) 51 | from diffusers.models.modeling_utils import ModelMixin 52 | from diffusers.models.unet_2d_blocks import ( 53 | UNetMidBlock2D, 54 | UNetMidBlock2DCrossAttn, 55 | UNetMidBlock2DSimpleCrossAttn, 56 | get_down_block, 57 | ) 58 | from .unet_2d_blocks import get_up_block 59 | from ..scedit import SCTunerLinearLayer 60 | 61 | 62 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 63 | 64 | 65 | @dataclass 66 | class UNet2DConditionOutput(BaseOutput): 67 | """ 68 | The output of [`UNet2DConditionModel`]. 69 | 70 | Args: 71 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 72 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 
73 | """ 74 | 75 | sample: torch.FloatTensor = None 76 | 77 | 78 | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 79 | r""" 80 | A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample 81 | shaped output. 82 | 83 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 84 | for all models (such as downloading or saving). 85 | 86 | Parameters: 87 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 88 | Height and width of input/output sample. 89 | in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. 90 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 91 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 92 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`): 93 | Whether to flip the sin to cos in the time embedding. 94 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 95 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 96 | The tuple of downsample blocks to use. 97 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): 98 | Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or 99 | `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. 100 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): 101 | The tuple of upsample blocks to use. 102 | only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): 103 | Whether to include self-attention in the basic transformer blocks, see 104 | [`~models.attention.BasicTransformerBlock`]. 105 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 106 | The tuple of output channels for each block. 107 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 108 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 109 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. 110 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 111 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 112 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. 113 | If `None`, normalization and activation layers is skipped in post-processing. 114 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 115 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 116 | The dimension of the cross attention features. 117 | transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): 118 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 119 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], 120 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 
121 | reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): 122 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling 123 | blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for 124 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], 125 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 126 | encoder_hid_dim (`int`, *optional*, defaults to None): 127 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` 128 | dimension to `cross_attention_dim`. 129 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`): 130 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text 131 | embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. 132 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 133 | num_attention_heads (`int`, *optional*): 134 | The number of attention heads. If not defined, defaults to `attention_head_dim` 135 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config 136 | for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. 137 | class_embed_type (`str`, *optional*, defaults to `None`): 138 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, 139 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. 140 | addition_embed_type (`str`, *optional*, defaults to `None`): 141 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or 142 | "text". "text" will use the `TextTimeEmbedding` layer. 143 | addition_time_embed_dim: (`int`, *optional*, defaults to `None`): 144 | Dimension for the timestep embeddings. 145 | num_class_embeds (`int`, *optional*, defaults to `None`): 146 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing 147 | class conditioning with `class_embed_type` equal to `None`. 148 | time_embedding_type (`str`, *optional*, defaults to `positional`): 149 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 150 | time_embedding_dim (`int`, *optional*, defaults to `None`): 151 | An optional override for the dimension of the projected time embedding. 152 | time_embedding_act_fn (`str`, *optional*, defaults to `None`): 153 | Optional activation function to use only once on the time embeddings before they are passed to the rest of 154 | the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. 155 | timestep_post_act (`str`, *optional*, defaults to `None`): 156 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. 157 | time_cond_proj_dim (`int`, *optional*, defaults to `None`): 158 | The dimension of `cond_proj` layer in the timestep embedding. 159 | conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, 160 | *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, 161 | *optional*): The dimension of the `class_labels` input when 162 | `class_embed_type="projection"`. Required when `class_embed_type="projection"`. 
163 | class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time 164 | embeddings with the class embeddings. 165 | mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): 166 | Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If 167 | `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the 168 | `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` 169 | otherwise. 170 | """ 171 | 172 | _supports_gradient_checkpointing = True 173 | 174 | @register_to_config 175 | def __init__( 176 | self, 177 | sample_size: Optional[int] = None, 178 | in_channels: int = 4, 179 | out_channels: int = 4, 180 | center_input_sample: bool = False, 181 | flip_sin_to_cos: bool = True, 182 | freq_shift: int = 0, 183 | down_block_types: Tuple[str] = ( 184 | "CrossAttnDownBlock2D", 185 | "CrossAttnDownBlock2D", 186 | "CrossAttnDownBlock2D", 187 | "DownBlock2D", 188 | ), 189 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", 190 | up_block_types: Tuple[str] = ( 191 | "UpBlock2D", 192 | "CrossAttnUpBlock2D", 193 | "CrossAttnUpBlock2D", 194 | "CrossAttnUpBlock2D", 195 | ), 196 | only_cross_attention: Union[bool, Tuple[bool]] = False, 197 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 198 | layers_per_block: Union[int, Tuple[int]] = 2, 199 | downsample_padding: int = 1, 200 | mid_block_scale_factor: float = 1, 201 | dropout: float = 0.0, 202 | act_fn: str = "silu", 203 | norm_num_groups: Optional[int] = 32, 204 | norm_eps: float = 1e-5, 205 | cross_attention_dim: Union[int, Tuple[int]] = 1280, 206 | transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, 207 | reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, 208 | encoder_hid_dim: Optional[int] = None, 209 | encoder_hid_dim_type: Optional[str] = None, 210 | attention_head_dim: Union[int, Tuple[int]] = 8, 211 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 212 | dual_cross_attention: bool = False, 213 | use_linear_projection: bool = False, 214 | class_embed_type: Optional[str] = None, 215 | addition_embed_type: Optional[str] = None, 216 | addition_time_embed_dim: Optional[int] = None, 217 | num_class_embeds: Optional[int] = None, 218 | upcast_attention: bool = False, 219 | resnet_time_scale_shift: str = "default", 220 | resnet_skip_time_act: bool = False, 221 | resnet_out_scale_factor: int = 1.0, 222 | time_embedding_type: str = "positional", 223 | time_embedding_dim: Optional[int] = None, 224 | time_embedding_act_fn: Optional[str] = None, 225 | timestep_post_act: Optional[str] = None, 226 | time_cond_proj_dim: Optional[int] = None, 227 | conv_in_kernel: int = 3, 228 | conv_out_kernel: int = 3, 229 | projection_class_embeddings_input_dim: Optional[int] = None, 230 | attention_type: str = "default", 231 | class_embeddings_concat: bool = False, 232 | mid_block_only_cross_attention: Optional[bool] = None, 233 | cross_attention_norm: Optional[str] = None, 234 | addition_embed_type_num_heads=64, 235 | ): 236 | super().__init__() 237 | 238 | self.sample_size = sample_size 239 | 240 | if num_attention_heads is not None: 241 | raise ValueError( 242 | "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. 
Passing `num_attention_heads` will only be supported in diffusers v0.19." 243 | ) 244 | 245 | # If `num_attention_heads` is not defined (which is the case for most models) 246 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 247 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 248 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 249 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 250 | # which is why we correct for the naming here. 251 | num_attention_heads = num_attention_heads or attention_head_dim 252 | 253 | # Check inputs 254 | if len(down_block_types) != len(up_block_types): 255 | raise ValueError( 256 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 257 | ) 258 | 259 | if len(block_out_channels) != len(down_block_types): 260 | raise ValueError( 261 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 262 | ) 263 | 264 | if not isinstance(only_cross_attention, bool) and len( 265 | only_cross_attention 266 | ) != len(down_block_types): 267 | raise ValueError( 268 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 269 | ) 270 | 271 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len( 272 | down_block_types 273 | ): 274 | raise ValueError( 275 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 276 | ) 277 | 278 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len( 279 | down_block_types 280 | ): 281 | raise ValueError( 282 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 283 | ) 284 | 285 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len( 286 | down_block_types 287 | ): 288 | raise ValueError( 289 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 290 | ) 291 | 292 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len( 293 | down_block_types 294 | ): 295 | raise ValueError( 296 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 297 | ) 298 | if ( 299 | isinstance(transformer_layers_per_block, list) 300 | and reverse_transformer_layers_per_block is None 301 | ): 302 | for layer_number_per_block in transformer_layers_per_block: 303 | if isinstance(layer_number_per_block, list): 304 | raise ValueError( 305 | "Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet." 
306 | ) 307 | 308 | # input 309 | conv_in_padding = (conv_in_kernel - 1) // 2 310 | self.conv_in = nn.Conv2d( 311 | in_channels, 312 | block_out_channels[0], 313 | kernel_size=conv_in_kernel, 314 | padding=conv_in_padding, 315 | ) 316 | 317 | # time 318 | if time_embedding_type == "fourier": 319 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 320 | if time_embed_dim % 2 != 0: 321 | raise ValueError( 322 | f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}." 323 | ) 324 | self.time_proj = GaussianFourierProjection( 325 | time_embed_dim // 2, 326 | set_W_to_weight=False, 327 | log=False, 328 | flip_sin_to_cos=flip_sin_to_cos, 329 | ) 330 | timestep_input_dim = time_embed_dim 331 | elif time_embedding_type == "positional": 332 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 333 | 334 | self.time_proj = Timesteps( 335 | block_out_channels[0], flip_sin_to_cos, freq_shift 336 | ) 337 | timestep_input_dim = block_out_channels[0] 338 | else: 339 | raise ValueError( 340 | f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." 341 | ) 342 | 343 | self.time_embedding = TimestepEmbedding( 344 | timestep_input_dim, 345 | time_embed_dim, 346 | act_fn=act_fn, 347 | post_act_fn=timestep_post_act, 348 | cond_proj_dim=time_cond_proj_dim, 349 | ) 350 | 351 | if encoder_hid_dim_type is None and encoder_hid_dim is not None: 352 | encoder_hid_dim_type = "text_proj" 353 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) 354 | logger.info( 355 | "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined." 356 | ) 357 | 358 | if encoder_hid_dim is None and encoder_hid_dim_type is not None: 359 | raise ValueError( 360 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 361 | ) 362 | 363 | if encoder_hid_dim_type == "text_proj": 364 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) 365 | elif encoder_hid_dim_type == "text_image_proj": 366 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much 367 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 368 | # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` 369 | self.encoder_hid_proj = TextImageProjection( 370 | text_embed_dim=encoder_hid_dim, 371 | image_embed_dim=cross_attention_dim, 372 | cross_attention_dim=cross_attention_dim, 373 | ) 374 | elif encoder_hid_dim_type == "image_proj": 375 | # Kandinsky 2.2 376 | self.encoder_hid_proj = ImageProjection( 377 | image_embed_dim=encoder_hid_dim, 378 | cross_attention_dim=cross_attention_dim, 379 | ) 380 | elif encoder_hid_dim_type is not None: 381 | raise ValueError( 382 | f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
383 | ) 384 | else: 385 | self.encoder_hid_proj = None 386 | 387 | # class embedding 388 | if class_embed_type is None and num_class_embeds is not None: 389 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 390 | elif class_embed_type == "timestep": 391 | self.class_embedding = TimestepEmbedding( 392 | timestep_input_dim, time_embed_dim, act_fn=act_fn 393 | ) 394 | elif class_embed_type == "identity": 395 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 396 | elif class_embed_type == "projection": 397 | if projection_class_embeddings_input_dim is None: 398 | raise ValueError( 399 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 400 | ) 401 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 402 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 403 | # 2. it projects from an arbitrary input dimension. 404 | # 405 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 406 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 407 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 408 | self.class_embedding = TimestepEmbedding( 409 | projection_class_embeddings_input_dim, time_embed_dim 410 | ) 411 | elif class_embed_type == "simple_projection": 412 | if projection_class_embeddings_input_dim is None: 413 | raise ValueError( 414 | "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" 415 | ) 416 | self.class_embedding = nn.Linear( 417 | projection_class_embeddings_input_dim, time_embed_dim 418 | ) 419 | else: 420 | self.class_embedding = None 421 | 422 | if addition_embed_type == "text": 423 | if encoder_hid_dim is not None: 424 | text_time_embedding_from_dim = encoder_hid_dim 425 | else: 426 | text_time_embedding_from_dim = cross_attention_dim 427 | 428 | self.add_embedding = TextTimeEmbedding( 429 | text_time_embedding_from_dim, 430 | time_embed_dim, 431 | num_heads=addition_embed_type_num_heads, 432 | ) 433 | elif addition_embed_type == "text_image": 434 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much 435 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 436 | # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` 437 | self.add_embedding = TextImageTimeEmbedding( 438 | text_embed_dim=cross_attention_dim, 439 | image_embed_dim=cross_attention_dim, 440 | time_embed_dim=time_embed_dim, 441 | ) 442 | elif addition_embed_type == "text_time": 443 | self.add_time_proj = Timesteps( 444 | addition_time_embed_dim, flip_sin_to_cos, freq_shift 445 | ) 446 | self.add_embedding = TimestepEmbedding( 447 | projection_class_embeddings_input_dim, time_embed_dim 448 | ) 449 | elif addition_embed_type == "image": 450 | # Kandinsky 2.2 451 | self.add_embedding = ImageTimeEmbedding( 452 | image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim 453 | ) 454 | elif addition_embed_type == "image_hint": 455 | # Kandinsky 2.2 ControlNet 456 | self.add_embedding = ImageHintTimeEmbedding( 457 | image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim 458 | ) 459 | elif addition_embed_type is not None: 460 | raise ValueError( 461 | f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'." 
462 | ) 463 | 464 | if time_embedding_act_fn is None: 465 | self.time_embed_act = None 466 | else: 467 | self.time_embed_act = get_activation(time_embedding_act_fn) 468 | 469 | self.down_blocks = nn.ModuleList([]) 470 | self.up_blocks = nn.ModuleList([]) 471 | 472 | if isinstance(only_cross_attention, bool): 473 | if mid_block_only_cross_attention is None: 474 | mid_block_only_cross_attention = only_cross_attention 475 | 476 | only_cross_attention = [only_cross_attention] * len(down_block_types) 477 | 478 | if mid_block_only_cross_attention is None: 479 | mid_block_only_cross_attention = False 480 | 481 | if isinstance(num_attention_heads, int): 482 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 483 | 484 | if isinstance(attention_head_dim, int): 485 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 486 | 487 | if isinstance(cross_attention_dim, int): 488 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 489 | 490 | if isinstance(layers_per_block, int): 491 | layers_per_block = [layers_per_block] * len(down_block_types) 492 | 493 | if isinstance(transformer_layers_per_block, int): 494 | transformer_layers_per_block = [transformer_layers_per_block] * len( 495 | down_block_types 496 | ) 497 | 498 | if class_embeddings_concat: 499 | # The time embeddings are concatenated with the class embeddings. The dimension of the 500 | # time embeddings passed to the down, middle, and up blocks is twice the dimension of the 501 | # regular time embeddings 502 | blocks_time_embed_dim = time_embed_dim * 2 503 | else: 504 | blocks_time_embed_dim = time_embed_dim 505 | 506 | # down 507 | output_channel = block_out_channels[0] 508 | for i, down_block_type in enumerate(down_block_types): 509 | input_channel = output_channel 510 | output_channel = block_out_channels[i] 511 | is_final_block = i == len(block_out_channels) - 1 512 | 513 | down_block = get_down_block( 514 | down_block_type, 515 | num_layers=layers_per_block[i], 516 | transformer_layers_per_block=transformer_layers_per_block[i], 517 | in_channels=input_channel, 518 | out_channels=output_channel, 519 | temb_channels=blocks_time_embed_dim, 520 | add_downsample=not is_final_block, 521 | resnet_eps=norm_eps, 522 | resnet_act_fn=act_fn, 523 | resnet_groups=norm_num_groups, 524 | cross_attention_dim=cross_attention_dim[i], 525 | num_attention_heads=num_attention_heads[i], 526 | downsample_padding=downsample_padding, 527 | dual_cross_attention=dual_cross_attention, 528 | use_linear_projection=use_linear_projection, 529 | only_cross_attention=only_cross_attention[i], 530 | upcast_attention=upcast_attention, 531 | resnet_time_scale_shift=resnet_time_scale_shift, 532 | attention_type=attention_type, 533 | resnet_skip_time_act=resnet_skip_time_act, 534 | resnet_out_scale_factor=resnet_out_scale_factor, 535 | cross_attention_norm=cross_attention_norm, 536 | attention_head_dim=attention_head_dim[i] 537 | if attention_head_dim[i] is not None 538 | else output_channel, 539 | dropout=dropout, 540 | ) 541 | self.down_blocks.append(down_block) 542 | 543 | # mid 544 | if mid_block_type == "UNetMidBlock2DCrossAttn": 545 | self.mid_block = UNetMidBlock2DCrossAttn( 546 | transformer_layers_per_block=transformer_layers_per_block[-1], 547 | in_channels=block_out_channels[-1], 548 | temb_channels=blocks_time_embed_dim, 549 | dropout=dropout, 550 | resnet_eps=norm_eps, 551 | resnet_act_fn=act_fn, 552 | output_scale_factor=mid_block_scale_factor, 553 | resnet_time_scale_shift=resnet_time_scale_shift, 554 | 
cross_attention_dim=cross_attention_dim[-1], 555 | num_attention_heads=num_attention_heads[-1], 556 | resnet_groups=norm_num_groups, 557 | dual_cross_attention=dual_cross_attention, 558 | use_linear_projection=use_linear_projection, 559 | upcast_attention=upcast_attention, 560 | attention_type=attention_type, 561 | ) 562 | elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": 563 | self.mid_block = UNetMidBlock2DSimpleCrossAttn( 564 | in_channels=block_out_channels[-1], 565 | temb_channels=blocks_time_embed_dim, 566 | dropout=dropout, 567 | resnet_eps=norm_eps, 568 | resnet_act_fn=act_fn, 569 | output_scale_factor=mid_block_scale_factor, 570 | cross_attention_dim=cross_attention_dim[-1], 571 | attention_head_dim=attention_head_dim[-1], 572 | resnet_groups=norm_num_groups, 573 | resnet_time_scale_shift=resnet_time_scale_shift, 574 | skip_time_act=resnet_skip_time_act, 575 | only_cross_attention=mid_block_only_cross_attention, 576 | cross_attention_norm=cross_attention_norm, 577 | ) 578 | elif mid_block_type == "UNetMidBlock2D": 579 | self.mid_block = UNetMidBlock2D( 580 | in_channels=block_out_channels[-1], 581 | temb_channels=blocks_time_embed_dim, 582 | dropout=dropout, 583 | num_layers=0, 584 | resnet_eps=norm_eps, 585 | resnet_act_fn=act_fn, 586 | output_scale_factor=mid_block_scale_factor, 587 | resnet_groups=norm_num_groups, 588 | resnet_time_scale_shift=resnet_time_scale_shift, 589 | add_attention=False, 590 | ) 591 | elif mid_block_type is None: 592 | self.mid_block = None 593 | else: 594 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 595 | 596 | # count how many layers upsample the images 597 | self.num_upsamplers = 0 598 | 599 | # up 600 | reversed_block_out_channels = list(reversed(block_out_channels)) 601 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 602 | reversed_layers_per_block = list(reversed(layers_per_block)) 603 | reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 604 | reversed_transformer_layers_per_block = ( 605 | list(reversed(transformer_layers_per_block)) 606 | if reverse_transformer_layers_per_block is None 607 | else reverse_transformer_layers_per_block 608 | ) 609 | only_cross_attention = list(reversed(only_cross_attention)) 610 | 611 | output_channel = reversed_block_out_channels[0] 612 | for i, up_block_type in enumerate(up_block_types): 613 | is_final_block = i == len(block_out_channels) - 1 614 | 615 | prev_output_channel = output_channel 616 | output_channel = reversed_block_out_channels[i] 617 | input_channel = reversed_block_out_channels[ 618 | min(i + 1, len(block_out_channels) - 1) 619 | ] 620 | 621 | # add upsample block for all BUT final layer 622 | if not is_final_block: 623 | add_upsample = True 624 | self.num_upsamplers += 1 625 | else: 626 | add_upsample = False 627 | 628 | up_block = get_up_block( 629 | up_block_type, 630 | num_layers=reversed_layers_per_block[i] + 1, 631 | transformer_layers_per_block=reversed_transformer_layers_per_block[i], 632 | in_channels=input_channel, 633 | out_channels=output_channel, 634 | prev_output_channel=prev_output_channel, 635 | temb_channels=blocks_time_embed_dim, 636 | add_upsample=add_upsample, 637 | resnet_eps=norm_eps, 638 | resnet_act_fn=act_fn, 639 | resolution_idx=i, 640 | resnet_groups=norm_num_groups, 641 | cross_attention_dim=reversed_cross_attention_dim[i], 642 | num_attention_heads=reversed_num_attention_heads[i], 643 | dual_cross_attention=dual_cross_attention, 644 | use_linear_projection=use_linear_projection, 645 | 
only_cross_attention=only_cross_attention[i], 646 | upcast_attention=upcast_attention, 647 | resnet_time_scale_shift=resnet_time_scale_shift, 648 | attention_type=attention_type, 649 | resnet_skip_time_act=resnet_skip_time_act, 650 | resnet_out_scale_factor=resnet_out_scale_factor, 651 | cross_attention_norm=cross_attention_norm, 652 | attention_head_dim=attention_head_dim[i] 653 | if attention_head_dim[i] is not None 654 | else output_channel, 655 | dropout=dropout, 656 | ) 657 | self.up_blocks.append(up_block) 658 | prev_output_channel = output_channel 659 | 660 | # out 661 | if norm_num_groups is not None: 662 | self.conv_norm_out = nn.GroupNorm( 663 | num_channels=block_out_channels[0], 664 | num_groups=norm_num_groups, 665 | eps=norm_eps, 666 | ) 667 | 668 | self.conv_act = get_activation(act_fn) 669 | 670 | else: 671 | self.conv_norm_out = None 672 | self.conv_act = None 673 | 674 | conv_out_padding = (conv_out_kernel - 1) // 2 675 | self.conv_out = nn.Conv2d( 676 | block_out_channels[0], 677 | out_channels, 678 | kernel_size=conv_out_kernel, 679 | padding=conv_out_padding, 680 | ) 681 | 682 | if attention_type in ["gated", "gated-text-image"]: 683 | positive_len = 768 684 | if isinstance(cross_attention_dim, int): 685 | positive_len = cross_attention_dim 686 | elif isinstance(cross_attention_dim, tuple) or isinstance( 687 | cross_attention_dim, list 688 | ): 689 | positive_len = cross_attention_dim[0] 690 | 691 | feature_type = "text-only" if attention_type == "gated" else "text-image" 692 | self.position_net = PositionNet( 693 | positive_len=positive_len, 694 | out_dim=cross_attention_dim, 695 | feature_type=feature_type, 696 | ) 697 | self.has_sctuner = False 698 | 699 | @property 700 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 701 | r""" 702 | Returns: 703 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 704 | indexed by its weight name. 705 | """ 706 | # set recursively 707 | processors = {} 708 | 709 | def fn_recursive_add_processors( 710 | name: str, 711 | module: torch.nn.Module, 712 | processors: Dict[str, AttentionProcessor], 713 | ): 714 | if hasattr(module, "get_processor"): 715 | processors[f"{name}.processor"] = module.get_processor( 716 | return_deprecated_lora=True 717 | ) 718 | 719 | for sub_name, child in module.named_children(): 720 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 721 | 722 | return processors 723 | 724 | for name, module in self.named_children(): 725 | fn_recursive_add_processors(name, module, processors) 726 | 727 | return processors 728 | 729 | def set_attn_processor( 730 | self, 731 | processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], 732 | _remove_lora=False, 733 | ): 734 | r""" 735 | Sets the attention processor to use to compute attention. 736 | 737 | Parameters: 738 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 739 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 740 | for **all** `Attention` layers. 741 | 742 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 743 | processor. This is strongly recommended when setting trainable attention processors. 
744 | 745 | """ 746 | count = len(self.attn_processors.keys()) 747 | 748 | if isinstance(processor, dict) and len(processor) != count: 749 | raise ValueError( 750 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 751 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 752 | ) 753 | 754 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 755 | if hasattr(module, "set_processor"): 756 | if not isinstance(processor, dict): 757 | module.set_processor(processor, _remove_lora=_remove_lora) 758 | else: 759 | module.set_processor( 760 | processor.pop(f"{name}.processor"), _remove_lora=_remove_lora 761 | ) 762 | 763 | for sub_name, child in module.named_children(): 764 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 765 | 766 | for name, module in self.named_children(): 767 | fn_recursive_attn_processor(name, module, processor) 768 | 769 | def set_default_attn_processor(self): 770 | """ 771 | Disables custom attention processors and sets the default attention implementation. 772 | """ 773 | if all( 774 | proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS 775 | for proc in self.attn_processors.values() 776 | ): 777 | processor = AttnAddedKVProcessor() 778 | elif all( 779 | proc.__class__ in CROSS_ATTENTION_PROCESSORS 780 | for proc in self.attn_processors.values() 781 | ): 782 | processor = AttnProcessor() 783 | else: 784 | raise ValueError( 785 | f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" 786 | ) 787 | 788 | self.set_attn_processor(processor, _remove_lora=True) 789 | 790 | def set_attention_slice(self, slice_size): 791 | r""" 792 | Enable sliced attention computation. 793 | 794 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 795 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 796 | 797 | Args: 798 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 799 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 800 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 801 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 802 | must be a multiple of `slice_size`. 
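        Example (illustrative sketch; assumes `unet` is an instantiated `UNet2DConditionModel`):

            # "auto" halves each attention head dimension, trading a little speed for lower peak memory
            unet.set_attention_slice("auto")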
803 | """ 804 | sliceable_head_dims = [] 805 | 806 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 807 | if hasattr(module, "set_attention_slice"): 808 | sliceable_head_dims.append(module.sliceable_head_dim) 809 | 810 | for child in module.children(): 811 | fn_recursive_retrieve_sliceable_dims(child) 812 | 813 | # retrieve number of attention layers 814 | for module in self.children(): 815 | fn_recursive_retrieve_sliceable_dims(module) 816 | 817 | num_sliceable_layers = len(sliceable_head_dims) 818 | 819 | if slice_size == "auto": 820 | # half the attention head size is usually a good trade-off between 821 | # speed and memory 822 | slice_size = [dim // 2 for dim in sliceable_head_dims] 823 | elif slice_size == "max": 824 | # make smallest slice possible 825 | slice_size = num_sliceable_layers * [1] 826 | 827 | slice_size = ( 828 | num_sliceable_layers * [slice_size] 829 | if not isinstance(slice_size, list) 830 | else slice_size 831 | ) 832 | 833 | if len(slice_size) != len(sliceable_head_dims): 834 | raise ValueError( 835 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 836 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 837 | ) 838 | 839 | for i in range(len(slice_size)): 840 | size = slice_size[i] 841 | dim = sliceable_head_dims[i] 842 | if size is not None and size > dim: 843 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 844 | 845 | # Recursively walk through all the children. 846 | # Any children which exposes the set_attention_slice method 847 | # gets the message 848 | def fn_recursive_set_attention_slice( 849 | module: torch.nn.Module, slice_size: List[int] 850 | ): 851 | if hasattr(module, "set_attention_slice"): 852 | module.set_attention_slice(slice_size.pop()) 853 | 854 | for child in module.children(): 855 | fn_recursive_set_attention_slice(child, slice_size) 856 | 857 | reversed_slice_size = list(reversed(slice_size)) 858 | for module in self.children(): 859 | fn_recursive_set_attention_slice(module, reversed_slice_size) 860 | 861 | def _set_gradient_checkpointing(self, module, value=False): 862 | if hasattr(module, "gradient_checkpointing"): 863 | module.gradient_checkpointing = value 864 | 865 | def enable_freeu(self, s1, s2, b1, b2): 866 | r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. 867 | 868 | The suffixes after the scaling factors represent the stage blocks where they are being applied. 869 | 870 | Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that 871 | are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. 872 | 873 | Args: 874 | s1 (`float`): 875 | Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to 876 | mitigate the "oversmoothing effect" in the enhanced denoising process. 877 | s2 (`float`): 878 | Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to 879 | mitigate the "oversmoothing effect" in the enhanced denoising process. 880 | b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. 881 | b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
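        Example (illustrative sketch; the values below are commonly cited for SDXL, but check the
        official FreeU repository for the settings recommended for your pipeline):

            unet.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
            # ... run inference ...
            unet.disable_freeu()  # restore the original skip/backbone behaviour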
882 | """ 883 | for i, upsample_block in enumerate(self.up_blocks): 884 | setattr(upsample_block, "s1", s1) 885 | setattr(upsample_block, "s2", s2) 886 | setattr(upsample_block, "b1", b1) 887 | setattr(upsample_block, "b2", b2) 888 | 889 | def disable_freeu(self): 890 | """Disables the FreeU mechanism.""" 891 | freeu_keys = {"s1", "s2", "b1", "b2"} 892 | for i, upsample_block in enumerate(self.up_blocks): 893 | for k in freeu_keys: 894 | if ( 895 | hasattr(upsample_block, k) 896 | or getattr(upsample_block, k, None) is not None 897 | ): 898 | setattr(upsample_block, k, None) 899 | 900 | def set_sctuner(self, sctuner_module=SCTunerLinearLayer, **kwargs): 901 | for upsample_block in self.up_blocks: 902 | upsample_block.set_sctuner(sctuner_module=sctuner_module, **kwargs) 903 | self.has_sctuner = True 904 | 905 | def forward( 906 | self, 907 | sample: torch.FloatTensor, 908 | timestep: Union[torch.Tensor, float, int], 909 | encoder_hidden_states: torch.Tensor, 910 | class_labels: Optional[torch.Tensor] = None, 911 | timestep_cond: Optional[torch.Tensor] = None, 912 | attention_mask: Optional[torch.Tensor] = None, 913 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 914 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 915 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 916 | mid_block_additional_residual: Optional[torch.Tensor] = None, 917 | down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 918 | encoder_attention_mask: Optional[torch.Tensor] = None, 919 | return_dict: bool = True, 920 | ) -> Union[UNet2DConditionOutput, Tuple]: 921 | r""" 922 | The [`UNet2DConditionModel`] forward method. 923 | 924 | Args: 925 | sample (`torch.FloatTensor`): 926 | The noisy input tensor with the following shape `(batch, channel, height, width)`. 927 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 928 | encoder_hidden_states (`torch.FloatTensor`): 929 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 930 | class_labels (`torch.Tensor`, *optional*, defaults to `None`): 931 | Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. 932 | timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): 933 | Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed 934 | through the `self.time_embedding` layer to obtain the timestep embeddings. 935 | attention_mask (`torch.Tensor`, *optional*, defaults to `None`): 936 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 937 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 938 | negative values to the attention scores corresponding to "discard" tokens. 939 | cross_attention_kwargs (`dict`, *optional*): 940 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 941 | `self.processor` in 942 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 943 | added_cond_kwargs: (`dict`, *optional*): 944 | A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that 945 | are passed along to the UNet blocks. 
946 | down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): 947 | A tuple of tensors that if specified are added to the residuals of down unet blocks. 948 | mid_block_additional_residual: (`torch.Tensor`, *optional*): 949 | A tensor that if specified is added to the residual of the middle unet block. 950 | encoder_attention_mask (`torch.Tensor`): 951 | A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If 952 | `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, 953 | which adds large negative values to the attention scores corresponding to "discard" tokens. 954 | return_dict (`bool`, *optional*, defaults to `True`): 955 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 956 | tuple. 957 | cross_attention_kwargs (`dict`, *optional*): 958 | A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. 959 | added_cond_kwargs: (`dict`, *optional*): 960 | A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that 961 | are passed along to the UNet blocks. 962 | down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): 963 | additional residuals to be added to UNet long skip connections from down blocks to up blocks for 964 | example from ControlNet side model(s) 965 | mid_block_additional_residual (`torch.Tensor`, *optional*): 966 | additional residual to be added to UNet mid block output, for example from ControlNet side model 967 | down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): 968 | additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) 969 | 970 | Returns: 971 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 972 | If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise 973 | a `tuple` is returned where the first element is the sample tensor. 974 | """ 975 | # By default samples have to be AT least a multiple of the overall upsampling factor. 976 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 977 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 978 | # on the fly if necessary. 979 | default_overall_up_factor = 2**self.num_upsamplers 980 | 981 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 982 | forward_upsample_size = False 983 | upsample_size = None 984 | 985 | for dim in sample.shape[-2:]: 986 | if dim % default_overall_up_factor != 0: 987 | # Forward upsample size to force interpolation output size. 988 | forward_upsample_size = True 989 | break 990 | 991 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension 992 | # expects mask of shape: 993 | # [batch, key_tokens] 994 | # adds singleton query_tokens dimension: 995 | # [batch, 1, key_tokens] 996 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 997 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 998 | # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) 999 | if attention_mask is not None: 1000 | # assume that mask is expressed as: 1001 | # (1 = keep, 0 = discard) 1002 | # convert mask into a bias that can be added to attention scores: 1003 | # (keep = +0, discard = -10000.0) 1004 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 1005 | attention_mask = attention_mask.unsqueeze(1) 1006 | 1007 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 1008 | if encoder_attention_mask is not None: 1009 | encoder_attention_mask = ( 1010 | 1 - encoder_attention_mask.to(sample.dtype) 1011 | ) * -10000.0 1012 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 1013 | 1014 | # 0. center input if necessary 1015 | if self.config.center_input_sample: 1016 | sample = 2 * sample - 1.0 1017 | 1018 | # 1. time 1019 | timesteps = timestep 1020 | if not torch.is_tensor(timesteps): 1021 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 1022 | # This would be a good case for the `match` statement (Python 3.10+) 1023 | is_mps = sample.device.type == "mps" 1024 | if isinstance(timestep, float): 1025 | dtype = torch.float32 if is_mps else torch.float64 1026 | else: 1027 | dtype = torch.int32 if is_mps else torch.int64 1028 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 1029 | elif len(timesteps.shape) == 0: 1030 | timesteps = timesteps[None].to(sample.device) 1031 | 1032 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 1033 | timesteps = timesteps.expand(sample.shape[0]) 1034 | 1035 | t_emb = self.time_proj(timesteps) 1036 | 1037 | # `Timesteps` does not contain any weights and will always return f32 tensors 1038 | # but time_embedding might actually be running in fp16. so we need to cast here. 1039 | # there might be better ways to encapsulate this. 1040 | t_emb = t_emb.to(dtype=sample.dtype) 1041 | 1042 | emb = self.time_embedding(t_emb, timestep_cond) 1043 | aug_emb = None 1044 | 1045 | if self.class_embedding is not None: 1046 | if class_labels is None: 1047 | raise ValueError( 1048 | "class_labels should be provided when num_class_embeds > 0" 1049 | ) 1050 | 1051 | if self.config.class_embed_type == "timestep": 1052 | class_labels = self.time_proj(class_labels) 1053 | 1054 | # `Timesteps` does not contain any weights and will always return f32 tensors 1055 | # there might be better ways to encapsulate this. 
1056 | class_labels = class_labels.to(dtype=sample.dtype) 1057 | 1058 | class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) 1059 | 1060 | if self.config.class_embeddings_concat: 1061 | emb = torch.cat([emb, class_emb], dim=-1) 1062 | else: 1063 | emb = emb + class_emb 1064 | 1065 | if self.config.addition_embed_type == "text": 1066 | aug_emb = self.add_embedding(encoder_hidden_states) 1067 | elif self.config.addition_embed_type == "text_image": 1068 | # Kandinsky 2.1 - style 1069 | if "image_embeds" not in added_cond_kwargs: 1070 | raise ValueError( 1071 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 1072 | ) 1073 | 1074 | image_embs = added_cond_kwargs.get("image_embeds") 1075 | text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) 1076 | aug_emb = self.add_embedding(text_embs, image_embs) 1077 | elif self.config.addition_embed_type == "text_time": 1078 | # SDXL - style 1079 | if "text_embeds" not in added_cond_kwargs: 1080 | raise ValueError( 1081 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 1082 | ) 1083 | text_embeds = added_cond_kwargs.get("text_embeds") 1084 | if "time_ids" not in added_cond_kwargs: 1085 | raise ValueError( 1086 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 1087 | ) 1088 | time_ids = added_cond_kwargs.get("time_ids") 1089 | time_embeds = self.add_time_proj(time_ids.flatten()) 1090 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 1091 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 1092 | add_embeds = add_embeds.to(emb.dtype) 1093 | aug_emb = self.add_embedding(add_embeds) 1094 | elif self.config.addition_embed_type == "image": 1095 | # Kandinsky 2.2 - style 1096 | if "image_embeds" not in added_cond_kwargs: 1097 | raise ValueError( 1098 | f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 1099 | ) 1100 | image_embs = added_cond_kwargs.get("image_embeds") 1101 | aug_emb = self.add_embedding(image_embs) 1102 | elif self.config.addition_embed_type == "image_hint": 1103 | # Kandinsky 2.2 - style 1104 | if ( 1105 | "image_embeds" not in added_cond_kwargs 1106 | or "hint" not in added_cond_kwargs 1107 | ): 1108 | raise ValueError( 1109 | f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" 1110 | ) 1111 | image_embs = added_cond_kwargs.get("image_embeds") 1112 | hint = added_cond_kwargs.get("hint") 1113 | aug_emb, hint = self.add_embedding(image_embs, hint) 1114 | sample = torch.cat([sample, hint], dim=1) 1115 | 1116 | emb = emb + aug_emb if aug_emb is not None else emb 1117 | 1118 | if self.time_embed_act is not None: 1119 | emb = self.time_embed_act(emb) 1120 | 1121 | if ( 1122 | self.encoder_hid_proj is not None 1123 | and self.config.encoder_hid_dim_type == "text_proj" 1124 | ): 1125 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) 1126 | elif ( 1127 | self.encoder_hid_proj is not None 1128 | and self.config.encoder_hid_dim_type == "text_image_proj" 1129 | ): 1130 | # Kadinsky 
2.1 - style 1131 | if "image_embeds" not in added_cond_kwargs: 1132 | raise ValueError( 1133 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 1134 | ) 1135 | 1136 | image_embeds = added_cond_kwargs.get("image_embeds") 1137 | encoder_hidden_states = self.encoder_hid_proj( 1138 | encoder_hidden_states, image_embeds 1139 | ) 1140 | elif ( 1141 | self.encoder_hid_proj is not None 1142 | and self.config.encoder_hid_dim_type == "image_proj" 1143 | ): 1144 | # Kandinsky 2.2 - style 1145 | if "image_embeds" not in added_cond_kwargs: 1146 | raise ValueError( 1147 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 1148 | ) 1149 | image_embeds = added_cond_kwargs.get("image_embeds") 1150 | encoder_hidden_states = self.encoder_hid_proj(image_embeds) 1151 | elif ( 1152 | self.encoder_hid_proj is not None 1153 | and self.config.encoder_hid_dim_type == "ip_image_proj" 1154 | ): 1155 | if "image_embeds" not in added_cond_kwargs: 1156 | raise ValueError( 1157 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 1158 | ) 1159 | image_embeds = added_cond_kwargs.get("image_embeds") 1160 | image_embeds = self.encoder_hid_proj(image_embeds).to( 1161 | encoder_hidden_states.dtype 1162 | ) 1163 | encoder_hidden_states = torch.cat( 1164 | [encoder_hidden_states, image_embeds], dim=1 1165 | ) 1166 | 1167 | # 2. pre-process 1168 | sample = self.conv_in(sample) 1169 | 1170 | # 2.5 GLIGEN position net 1171 | if ( 1172 | cross_attention_kwargs is not None 1173 | and cross_attention_kwargs.get("gligen", None) is not None 1174 | ): 1175 | cross_attention_kwargs = cross_attention_kwargs.copy() 1176 | gligen_args = cross_attention_kwargs.pop("gligen") 1177 | cross_attention_kwargs["gligen"] = { 1178 | "objs": self.position_net(**gligen_args) 1179 | } 1180 | 1181 | # 3. down 1182 | lora_scale = ( 1183 | cross_attention_kwargs.get("scale", 1.0) 1184 | if cross_attention_kwargs is not None 1185 | else 1.0 1186 | ) 1187 | if USE_PEFT_BACKEND: 1188 | # weight the lora layers by setting `lora_scale` for each PEFT layer 1189 | scale_lora_layers(self, lora_scale) 1190 | 1191 | is_controlnet = ( 1192 | mid_block_additional_residual is not None 1193 | and down_block_additional_residuals is not None 1194 | ) 1195 | # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets 1196 | is_adapter = down_intrablock_additional_residuals is not None 1197 | # maintain backward compatibility for legacy usage, where 1198 | # T2I-Adapter and ControlNet both use down_block_additional_residuals arg 1199 | # but can only use one or the other 1200 | if ( 1201 | not is_adapter 1202 | and mid_block_additional_residual is None 1203 | and down_block_additional_residuals is not None 1204 | ): 1205 | deprecate( 1206 | "T2I should not use down_block_additional_residuals", 1207 | "1.3.0", 1208 | "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ 1209 | and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ 1210 | for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. 
", 1211 | standard_warn=False, 1212 | ) 1213 | down_intrablock_additional_residuals = down_block_additional_residuals 1214 | is_adapter = True 1215 | 1216 | down_block_res_samples = (sample,) 1217 | for downsample_block in self.down_blocks: 1218 | if ( 1219 | hasattr(downsample_block, "has_cross_attention") 1220 | and downsample_block.has_cross_attention 1221 | ): 1222 | # For t2i-adapter CrossAttnDownBlock2D 1223 | additional_residuals = {} 1224 | if is_adapter and len(down_intrablock_additional_residuals) > 0: 1225 | additional_residuals[ 1226 | "additional_residuals" 1227 | ] = down_intrablock_additional_residuals.pop(0) 1228 | 1229 | sample, res_samples = downsample_block( 1230 | hidden_states=sample, 1231 | temb=emb, 1232 | encoder_hidden_states=encoder_hidden_states, 1233 | attention_mask=attention_mask, 1234 | cross_attention_kwargs=cross_attention_kwargs, 1235 | encoder_attention_mask=encoder_attention_mask, 1236 | **additional_residuals, 1237 | ) 1238 | else: 1239 | sample, res_samples = downsample_block( 1240 | hidden_states=sample, temb=emb, scale=lora_scale 1241 | ) 1242 | if is_adapter and len(down_intrablock_additional_residuals) > 0: 1243 | sample += down_intrablock_additional_residuals.pop(0) 1244 | 1245 | down_block_res_samples += res_samples 1246 | 1247 | if is_controlnet: 1248 | new_down_block_res_samples = () 1249 | 1250 | for down_block_res_sample, down_block_additional_residual in zip( 1251 | down_block_res_samples, down_block_additional_residuals 1252 | ): 1253 | down_block_res_sample = ( 1254 | down_block_res_sample + down_block_additional_residual 1255 | ) 1256 | new_down_block_res_samples = new_down_block_res_samples + ( 1257 | down_block_res_sample, 1258 | ) 1259 | 1260 | down_block_res_samples = new_down_block_res_samples 1261 | 1262 | # 4. mid 1263 | if self.mid_block is not None: 1264 | if ( 1265 | hasattr(self.mid_block, "has_cross_attention") 1266 | and self.mid_block.has_cross_attention 1267 | ): 1268 | sample = self.mid_block( 1269 | sample, 1270 | emb, 1271 | encoder_hidden_states=encoder_hidden_states, 1272 | attention_mask=attention_mask, 1273 | cross_attention_kwargs=cross_attention_kwargs, 1274 | encoder_attention_mask=encoder_attention_mask, 1275 | ) 1276 | else: 1277 | sample = self.mid_block(sample, emb) 1278 | 1279 | # To support T2I-Adapter-XL 1280 | if ( 1281 | is_adapter 1282 | and len(down_intrablock_additional_residuals) > 0 1283 | and sample.shape == down_intrablock_additional_residuals[0].shape 1284 | ): 1285 | sample += down_intrablock_additional_residuals.pop(0) 1286 | 1287 | if is_controlnet: 1288 | sample = sample + mid_block_additional_residual 1289 | 1290 | # 5. 
up 1291 | for i, upsample_block in enumerate(self.up_blocks): 1292 | is_final_block = i == len(self.up_blocks) - 1 1293 | 1294 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 1295 | down_block_res_samples = down_block_res_samples[ 1296 | : -len(upsample_block.resnets) 1297 | ] 1298 | 1299 | # if we have not reached the final block and need to forward the 1300 | # upsample size, we do it here 1301 | if not is_final_block and forward_upsample_size: 1302 | upsample_size = down_block_res_samples[-1].shape[2:] 1303 | 1304 | if ( 1305 | hasattr(upsample_block, "has_cross_attention") 1306 | and upsample_block.has_cross_attention 1307 | ): 1308 | sample = upsample_block( 1309 | hidden_states=sample, 1310 | temb=emb, 1311 | res_hidden_states_tuple=res_samples, 1312 | encoder_hidden_states=encoder_hidden_states, 1313 | cross_attention_kwargs=cross_attention_kwargs, 1314 | upsample_size=upsample_size, 1315 | attention_mask=attention_mask, 1316 | encoder_attention_mask=encoder_attention_mask, 1317 | ) 1318 | else: 1319 | sample = upsample_block( 1320 | hidden_states=sample, 1321 | temb=emb, 1322 | res_hidden_states_tuple=res_samples, 1323 | upsample_size=upsample_size, 1324 | scale=lora_scale, 1325 | ) 1326 | 1327 | # 6. post-process 1328 | if self.conv_norm_out: 1329 | sample = self.conv_norm_out(sample) 1330 | sample = self.conv_act(sample) 1331 | sample = self.conv_out(sample) 1332 | 1333 | if USE_PEFT_BACKEND: 1334 | # remove `lora_scale` from each PEFT layer 1335 | unscale_lora_layers(self, lora_scale) 1336 | 1337 | if not return_dict: 1338 | return (sample,) 1339 | 1340 | return UNet2DConditionOutput(sample=sample) 1341 | -------------------------------------------------------------------------------- /scedit_pytorch/scedit.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | import torch 3 | from torch import nn 4 | from diffusers.models.modeling_utils import get_parameter_device, get_parameter_dtype 5 | 6 | 7 | class AbstractSCTunerLayer(nn.Module): 8 | def __init__( 9 | self, 10 | dim: int, 11 | device: Optional[Union[torch.device, str]] = None, 12 | dtype: Optional[torch.dtype] = None, 13 | ): 14 | super().__init__() 15 | self.dim = dim 16 | 17 | 18 | class SCTunerLinearLayer(AbstractSCTunerLayer): 19 | r""" 20 | A linear layer that is used with SCEdit. 21 | 22 | Parameters: 23 | dim (`int`): 24 | Number of dim. 25 | out_features (`int`): 26 | Number of output features. 27 | rank (`int`, `optional`, defaults to 4): 28 | The rank of the LoRA layer. 29 | device (`torch.device`, `optional`, defaults to `None`): 30 | The device to use for the layer's weights. 31 | dtype (`torch.dtype`, `optional`, defaults to `None`): 32 | The dtype to use for the layer's weights. 
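    Example (illustrative sketch; `dim=320` is just a typical skip-connection channel count, and the
    input follows the NCHW layout expected by `forward`):

        import torch

        tuner = SCTunerLinearLayer(dim=320, rank=64, scale=1.0)
        skip = torch.randn(1, 320, 64, 64)
        out = tuner(skip)  # same shape as `skip`; the tuned residual is added back to the input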
33 | """ 34 | 35 | def __init__( 36 | self, 37 | dim: int, 38 | rank: Optional[int] = None, 39 | scale: float = 1.0, 40 | device: Optional[Union[torch.device, str]] = None, 41 | dtype: Optional[torch.dtype] = None, 42 | ): 43 | super().__init__(dim=dim) 44 | if rank is None: 45 | rank = dim 46 | self.down = nn.Linear(dim, rank, device=device, dtype=dtype) 47 | self.up = nn.Linear(rank, dim, device=device, dtype=dtype) 48 | self.act = nn.GELU() 49 | self.rank = rank 50 | self.scale = scale 51 | 52 | nn.init.normal_(self.down.weight, std=1 / rank) 53 | nn.init.zeros_(self.up.weight) 54 | 55 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 56 | orig_dtype = hidden_states.dtype 57 | dtype = self.down.weight.dtype 58 | 59 | hidden_states_input = hidden_states.permute(0, 2, 3, 1) 60 | down_hidden_states = self.down(hidden_states_input.to(dtype)) 61 | up_hidden_states = self.up(self.act(down_hidden_states)) 62 | up_hidden_states = up_hidden_states.to(orig_dtype).permute(0, 3, 1, 2) 63 | return self.scale * up_hidden_states + hidden_states 64 | 65 | 66 | class SCTunerLinearLayer2(AbstractSCTunerLayer): 67 | def __init__( 68 | self, 69 | dim: int, 70 | rank: Optional[int] = None, 71 | scale: float = 1.0, 72 | device: Optional[Union[torch.device, str]] = None, 73 | dtype: Optional[torch.dtype] = None, 74 | ): 75 | super().__init__(dim=dim) 76 | if rank is None: 77 | rank = dim 78 | self.model = nn.Sequential( 79 | nn.Linear(dim, rank, device=device, dtype=dtype), 80 | nn.GELU(), 81 | nn.Linear(rank, rank, device=device, dtype=dtype), 82 | nn.GELU(), 83 | nn.Linear(rank, dim, device=device, dtype=dtype), 84 | ) 85 | self.rank = rank 86 | self.scale = scale 87 | 88 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 89 | orig_dtype = hidden_states.dtype 90 | dtype = self.model[0].weight.dtype 91 | 92 | hidden_states_input = hidden_states.permute(0, 2, 3, 1) 93 | hidden_states_output = self.model(hidden_states_input.to(dtype)) 94 | hidden_states_output = hidden_states_output.to(orig_dtype).permute(0, 3, 1, 2) 95 | return self.scale * hidden_states_output + hidden_states 96 | 97 | 98 | class SCTuner(nn.Module): 99 | def __init__(self): 100 | super().__init__() 101 | self.sc_tuners = None 102 | self.res_skip_channels = None 103 | 104 | @property 105 | def device(self) -> torch.device: 106 | """ 107 | `torch.device`: The device on which the module is (assuming that all the module parameters are on the same 108 | device). 109 | """ 110 | return get_parameter_device(self) 111 | 112 | @property 113 | def dtype(self) -> torch.dtype: 114 | """ 115 | `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 116 | """ 117 | return get_parameter_dtype(self) 118 | 119 | def set_sctuner(self, sctuner_module=SCTunerLinearLayer, **kwargs): 120 | assert isinstance(self.res_skip_channels, list) 121 | sc_tuners = [] 122 | for c in self.res_skip_channels: 123 | sc_tuners.append( 124 | sctuner_module(dim=c, device=self.device, dtype=self.dtype, **kwargs) 125 | ) 126 | self.sc_tuners = nn.ModuleList(sc_tuners) 127 | -------------------------------------------------------------------------------- /scedit_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | from typing import Dict, Union 4 | 5 | import torch 6 | import safetensors 7 | from diffusers.utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, _get_model_file 8 | from . 
import UNet2DConditionModel 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | SCEDIT_WEIGHT_NAME_SAFE = "pytorch_scedit_weights.safetensors" 13 | 14 | 15 | def save_function(weights, filename): 16 | return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"}) 17 | 18 | 19 | def save_scedit(state_dict: Dict[str, torch.Tensor], save_directory: str): 20 | if os.path.isfile(save_directory): 21 | logger.error( 22 | f"Provided path ({save_directory}) should be a directory, not a file" 23 | ) 24 | return 25 | os.makedirs(save_directory, exist_ok=True) 26 | save_function(state_dict, os.path.join(save_directory, SCEDIT_WEIGHT_NAME_SAFE)) 27 | logger.info( 28 | f"Model weights saved in {os.path.join(save_directory, SCEDIT_WEIGHT_NAME_SAFE)}" 29 | ) 30 | 31 | 32 | def scedit_state_dict( 33 | pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs 34 | ) -> Dict[str, torch.Tensor]: 35 | cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) 36 | force_download = kwargs.pop("force_download", False) 37 | resume_download = kwargs.pop("resume_download", False) 38 | proxies = kwargs.pop("proxies", None) 39 | local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) 40 | use_auth_token = kwargs.pop("use_auth_token", None) 41 | revision = kwargs.pop("revision", None) 42 | subfolder = kwargs.pop("subfolder", None) 43 | weight_name = kwargs.pop("weight_name", None) 44 | user_agent = { 45 | "file_type": "attn_procs_weights", 46 | "framework": "pytorch", 47 | } 48 | 49 | model_file = None 50 | if not isinstance(pretrained_model_name_or_path_or_dict, dict): 51 | # Here we're relaxing the loading check to enable more Inference API 52 | # friendliness where sometimes, it's not at all possible to automatically 53 | # determine `weight_name`. 54 | if weight_name is None: 55 | weight_name = SCEDIT_WEIGHT_NAME_SAFE 56 | model_file = _get_model_file( 57 | pretrained_model_name_or_path_or_dict, 58 | weights_name=weight_name, 59 | cache_dir=cache_dir, 60 | force_download=force_download, 61 | resume_download=resume_download, 62 | proxies=proxies, 63 | local_files_only=local_files_only, 64 | use_auth_token=use_auth_token, 65 | revision=revision, 66 | subfolder=subfolder, 67 | user_agent=user_agent, 68 | ) 69 | state_dict = safetensors.torch.load_file(model_file, device="cpu") 70 | else: 71 | state_dict = pretrained_model_name_or_path_or_dict 72 | 73 | return state_dict 74 | 75 | 76 | def load_scedit_into_unet( 77 | state_dict: Union[str, Dict[str, torch.Tensor]], 78 | unet: UNet2DConditionModel, 79 | **kwargs, 80 | ) -> UNet2DConditionModel: 81 | if isinstance(state_dict, str): 82 | state_dict = scedit_state_dict(state_dict, **kwargs) 83 | assert unet.has_sctuner, "Make sure to call `set_sctuner` before!" 
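    # The SCEdit checkpoint is expected to hold only the SC-Tuner parameters, so it is loaded
    # non-strictly on top of the already-initialized base UNet weights below.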
84 | unet.load_state_dict(state_dict, strict=False) 85 | return unet 86 | -------------------------------------------------------------------------------- /scripts/app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append(".") 4 | import argparse 5 | import gradio as gr 6 | import torch 7 | from huggingface_hub.repocard import RepoCard 8 | from diffusers import DiffusionPipeline 9 | from scedit_pytorch import UNet2DConditionModel, load_scedit_into_unet 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument( 15 | "--pretrained_model_name_or_path", 16 | type=str, 17 | default="stabilityai/stable-diffusion-xl-base-1.0", 18 | help="pretrained model path", 19 | ) 20 | parser.add_argument( 21 | "--scedit_name_or_path", type=str, required=True, help="ziplora path" 22 | ) 23 | parser.add_argument("--scale", type=float, default=1.0, help="weight scale") 24 | return parser.parse_args() 25 | 26 | 27 | args = parse_args() 28 | device = "cuda" if torch.cuda.is_available() else "cpu" 29 | # load unet with sctuner 30 | unet = UNet2DConditionModel.from_pretrained( 31 | args.pretrained_model_name_or_path, subfolder="unet" 32 | ) 33 | unet.set_sctuner(scale=args.scale) 34 | unet = load_scedit_into_unet(args.scedit_name_or_path, unet) 35 | # load pipeline 36 | pipeline = DiffusionPipeline.from_pretrained( 37 | args.pretrained_model_name_or_path, unet=unet 38 | ) 39 | pipeline = pipeline.to(device=device, dtype=torch.float16) 40 | 41 | 42 | def run(prompt: str): 43 | # generator = torch.Generator(device="cuda").manual_seed(42) 44 | generator = None 45 | image = pipeline(prompt=prompt, generator=generator).images[0] 46 | return image 47 | 48 | 49 | with gr.Blocks() as demo: 50 | with gr.Row(): 51 | with gr.Column(): 52 | prompt = gr.Text(label="prompt", value="A picture of a sbu dog in a bucket") 53 | bttn = gr.Button(value="Run") 54 | with gr.Column(): 55 | out = gr.Image(label="out") 56 | prompt.submit(fn=run, inputs=[prompt], outputs=[out]) 57 | bttn.click(fn=run, inputs=[prompt], outputs=[out]) 58 | 59 | demo.launch(share=True, debug=True, show_error=True) 60 | -------------------------------------------------------------------------------- /train_dreambooth_scedit_sdxl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import itertools 4 | import logging 5 | import math 6 | import os 7 | import shutil 8 | import warnings 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import torch 13 | import torch.nn.functional as F 14 | import torch.utils.checkpoint 15 | import transformers 16 | from accelerate import Accelerator 17 | from accelerate.logging import get_logger 18 | from accelerate.utils import ( 19 | DistributedDataParallelKwargs, 20 | ProjectConfiguration, 21 | set_seed, 22 | ) 23 | from huggingface_hub import create_repo, upload_folder 24 | from huggingface_hub.utils import insecure_hashlib 25 | from packaging import version 26 | from PIL import Image 27 | from PIL.ImageOps import exif_transpose 28 | from torch.utils.data import Dataset 29 | from torchvision import transforms 30 | from tqdm.auto import tqdm 31 | from transformers import AutoTokenizer, PretrainedConfig 32 | 33 | import diffusers 34 | from diffusers import ( 35 | AutoencoderKL, 36 | DDPMScheduler, 37 | DPMSolverMultistepScheduler, 38 | StableDiffusionXLPipeline, 39 | # UNet2DConditionModel, 40 | ) 41 | from diffusers.optimization import 
get_scheduler 42 | from diffusers.training_utils import compute_snr 43 | from diffusers.utils import check_min_version, is_wandb_available 44 | from diffusers.utils.import_utils import is_xformers_available 45 | from scedit_pytorch import UNet2DConditionModel, save_scedit, load_scedit_into_unet 46 | 47 | 48 | logger = get_logger(__name__) 49 | 50 | 51 | def save_model_card( 52 | repo_id: str, 53 | images=None, 54 | base_model=str, 55 | instance_prompt=str, 56 | validation_prompt=str, 57 | repo_folder=None, 58 | vae_path=None, 59 | ): 60 | img_str = "widget:\n" if images else "" 61 | for i, image in enumerate(images): 62 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 63 | img_str += f""" 64 | - text: '{validation_prompt if validation_prompt else ' ' }' 65 | output: 66 | url: 67 | "image_{i}.png" 68 | """ 69 | 70 | yaml = f""" 71 | --- 72 | tags: 73 | - stable-diffusion-xl 74 | - stable-diffusion-xl-diffusers 75 | - text-to-image 76 | - diffusers 77 | - scedit 78 | - template:sd-lora 79 | {img_str} 80 | base_model: {base_model} 81 | instance_prompt: {instance_prompt} 82 | license: openrail++ 83 | --- 84 | """ 85 | 86 | model_card = f""" 87 | # SDXL SCEdit DreamBooth - {repo_id} 88 | 89 | 90 | 91 | ## Model description 92 | 93 | These are {repo_id} SC-Tuner adaption weights for {base_model}. 94 | 95 | The weights were trained using [DreamBooth](https://dreambooth.github.io/). 96 | 97 | Special VAE used for training: {vae_path}. 98 | 99 | ## Trigger words 100 | 101 | You should use {instance_prompt} to trigger the image generation. 102 | 103 | ## Download model 104 | 105 | Weights for this model are available in Safetensors format. 106 | 107 | [Download]({repo_id}/tree/main) them in the Files & versions tab. 108 | 109 | """ 110 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 111 | f.write(yaml + model_card) 112 | 113 | 114 | def import_model_class_from_model_name_or_path( 115 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 116 | ): 117 | text_encoder_config = PretrainedConfig.from_pretrained( 118 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 119 | ) 120 | model_class = text_encoder_config.architectures[0] 121 | 122 | if model_class == "CLIPTextModel": 123 | from transformers import CLIPTextModel 124 | 125 | return CLIPTextModel 126 | elif model_class == "CLIPTextModelWithProjection": 127 | from transformers import CLIPTextModelWithProjection 128 | 129 | return CLIPTextModelWithProjection 130 | else: 131 | raise ValueError(f"{model_class} is not supported.") 132 | 133 | 134 | def parse_args(input_args=None): 135 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 136 | parser.add_argument( 137 | "--pretrained_model_name_or_path", 138 | type=str, 139 | default=None, 140 | required=True, 141 | help="Path to pretrained model or model identifier from huggingface.co/models.", 142 | ) 143 | parser.add_argument( 144 | "--pretrained_vae_model_name_or_path", 145 | type=str, 146 | default=None, 147 | help="Path to pretrained VAE model with better numerical stability. 
More details: https://github.com/huggingface/diffusers/pull/4038.", 148 | ) 149 | parser.add_argument( 150 | "--revision", 151 | type=str, 152 | default=None, 153 | required=False, 154 | help="Revision of pretrained model identifier from huggingface.co/models.", 155 | ) 156 | parser.add_argument( 157 | "--variant", 158 | type=str, 159 | default=None, 160 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", 161 | ) 162 | parser.add_argument( 163 | "--dataset_name", 164 | type=str, 165 | default=None, 166 | help=( 167 | "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," 168 | " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," 169 | " or to a folder containing files that 🤗 Datasets can understand." 170 | ), 171 | ) 172 | parser.add_argument( 173 | "--dataset_config_name", 174 | type=str, 175 | default=None, 176 | help="The config of the Dataset, leave as None if there's only one config.", 177 | ) 178 | parser.add_argument( 179 | "--instance_data_dir", 180 | type=str, 181 | default=None, 182 | help=("A folder containing the training data. "), 183 | ) 184 | 185 | parser.add_argument( 186 | "--cache_dir", 187 | type=str, 188 | default=None, 189 | help="The directory where the downloaded models and datasets will be stored.", 190 | ) 191 | 192 | parser.add_argument( 193 | "--image_column", 194 | type=str, 195 | default="image", 196 | help="The column of the dataset containing the target image. By " 197 | "default, the standard Image Dataset maps out 'file_name' " 198 | "to 'image'.", 199 | ) 200 | parser.add_argument( 201 | "--caption_column", 202 | type=str, 203 | default=None, 204 | help="The column of the dataset containing the instance prompt for each image", 205 | ) 206 | 207 | parser.add_argument( 208 | "--repeats", 209 | type=int, 210 | default=1, 211 | help="How many times to repeat the training data.", 212 | ) 213 | 214 | parser.add_argument( 215 | "--class_data_dir", 216 | type=str, 217 | default=None, 218 | required=False, 219 | help="A folder containing the training data of class images.", 220 | ) 221 | parser.add_argument( 222 | "--instance_prompt", 223 | type=str, 224 | default=None, 225 | required=True, 226 | help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", 227 | ) 228 | parser.add_argument( 229 | "--class_prompt", 230 | type=str, 231 | default=None, 232 | help="The prompt to specify images in the same class as provided instance images.", 233 | ) 234 | parser.add_argument( 235 | "--validation_prompt", 236 | type=str, 237 | default=None, 238 | help="A prompt that is used during validation to verify that the model is learning.", 239 | ) 240 | parser.add_argument( 241 | "--num_validation_images", 242 | type=int, 243 | default=4, 244 | help="Number of images that should be generated during validation with `validation_prompt`.", 245 | ) 246 | parser.add_argument( 247 | "--validation_epochs", 248 | type=int, 249 | default=50, 250 | help=( 251 | "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" 252 | " `args.validation_prompt` multiple times: `args.num_validation_images`." 
253 | ), 254 | ) 255 | parser.add_argument( 256 | "--with_prior_preservation", 257 | default=False, 258 | action="store_true", 259 | help="Flag to add prior preservation loss.", 260 | ) 261 | parser.add_argument( 262 | "--prior_loss_weight", 263 | type=float, 264 | default=1.0, 265 | help="The weight of prior preservation loss.", 266 | ) 267 | parser.add_argument( 268 | "--num_class_images", 269 | type=int, 270 | default=100, 271 | help=( 272 | "Minimal class images for prior preservation loss. If there are not enough images already present in" 273 | " class_data_dir, additional images will be sampled with class_prompt." 274 | ), 275 | ) 276 | parser.add_argument( 277 | "--output_dir", 278 | type=str, 279 | default="scedit-dreambooth-model", 280 | help="The output directory where the model predictions and checkpoints will be written.", 281 | ) 282 | parser.add_argument( 283 | "--seed", type=int, default=None, help="A seed for reproducible training." 284 | ) 285 | parser.add_argument( 286 | "--resolution", 287 | type=int, 288 | default=1024, 289 | help=( 290 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 291 | " resolution" 292 | ), 293 | ) 294 | parser.add_argument( 295 | "--crops_coords_top_left_h", 296 | type=int, 297 | default=0, 298 | help=( 299 | "Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet." 300 | ), 301 | ) 302 | parser.add_argument( 303 | "--crops_coords_top_left_w", 304 | type=int, 305 | default=0, 306 | help=( 307 | "Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet." 308 | ), 309 | ) 310 | parser.add_argument( 311 | "--center_crop", 312 | default=False, 313 | action="store_true", 314 | help=( 315 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 316 | " cropped. The images will be resized to the resolution first before cropping." 317 | ), 318 | ) 319 | parser.add_argument( 320 | "--train_batch_size", 321 | type=int, 322 | default=4, 323 | help="Batch size (per device) for the training dataloader.", 324 | ) 325 | parser.add_argument( 326 | "--sample_batch_size", 327 | type=int, 328 | default=4, 329 | help="Batch size (per device) for sampling images.", 330 | ) 331 | parser.add_argument("--num_train_epochs", type=int, default=1) 332 | parser.add_argument( 333 | "--max_train_steps", 334 | type=int, 335 | default=None, 336 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 337 | ) 338 | parser.add_argument( 339 | "--checkpointing_steps", 340 | type=int, 341 | default=500, 342 | help=( 343 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 344 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 345 | " training using `--resume_from_checkpoint`." 346 | ), 347 | ) 348 | parser.add_argument( 349 | "--checkpoints_total_limit", 350 | type=int, 351 | default=None, 352 | help=("Max number of checkpoints to store."), 353 | ) 354 | parser.add_argument( 355 | "--resume_from_checkpoint", 356 | type=str, 357 | default=None, 358 | help=( 359 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 360 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
361 | ), 362 | ) 363 | parser.add_argument( 364 | "--gradient_accumulation_steps", 365 | type=int, 366 | default=1, 367 | help="Number of updates steps to accumulate before performing a backward/update pass.", 368 | ) 369 | parser.add_argument( 370 | "--gradient_checkpointing", 371 | action="store_true", 372 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 373 | ) 374 | parser.add_argument( 375 | "--learning_rate", 376 | type=float, 377 | default=1e-4, 378 | help="Initial learning rate (after the potential warmup period) to use.", 379 | ) 380 | 381 | parser.add_argument( 382 | "--text_encoder_lr", 383 | type=float, 384 | default=5e-6, 385 | help="Text encoder learning rate to use.", 386 | ) 387 | parser.add_argument( 388 | "--scale_lr", 389 | action="store_true", 390 | default=False, 391 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 392 | ) 393 | parser.add_argument( 394 | "--lr_scheduler", 395 | type=str, 396 | default="constant", 397 | help=( 398 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 399 | ' "constant", "constant_with_warmup"]' 400 | ), 401 | ) 402 | 403 | parser.add_argument( 404 | "--snr_gamma", 405 | type=float, 406 | default=None, 407 | help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " 408 | "More details here: https://arxiv.org/abs/2303.09556.", 409 | ) 410 | parser.add_argument( 411 | "--lr_warmup_steps", 412 | type=int, 413 | default=500, 414 | help="Number of steps for the warmup in the lr scheduler.", 415 | ) 416 | parser.add_argument( 417 | "--lr_num_cycles", 418 | type=int, 419 | default=1, 420 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 421 | ) 422 | parser.add_argument( 423 | "--lr_power", 424 | type=float, 425 | default=1.0, 426 | help="Power factor of the polynomial scheduler.", 427 | ) 428 | parser.add_argument( 429 | "--dataloader_num_workers", 430 | type=int, 431 | default=0, 432 | help=( 433 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 434 | ), 435 | ) 436 | 437 | parser.add_argument( 438 | "--optimizer", 439 | type=str, 440 | default="AdamW", 441 | help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), 442 | ) 443 | 444 | parser.add_argument( 445 | "--use_8bit_adam", 446 | action="store_true", 447 | help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", 448 | ) 449 | 450 | parser.add_argument( 451 | "--adam_beta1", 452 | type=float, 453 | default=0.9, 454 | help="The beta1 parameter for the Adam and Prodigy optimizers.", 455 | ) 456 | parser.add_argument( 457 | "--adam_beta2", 458 | type=float, 459 | default=0.999, 460 | help="The beta2 parameter for the Adam and Prodigy optimizers.", 461 | ) 462 | parser.add_argument( 463 | "--prodigy_beta3", 464 | type=float, 465 | default=None, 466 | help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " 467 | "uses the value of square root of beta2. 
Ignored if optimizer is adamW", 468 | ) 469 | parser.add_argument( 470 | "--prodigy_decouple", 471 | type=bool, 472 | default=True, 473 | help="Use AdamW style decoupled weight decay", 474 | ) 475 | parser.add_argument( 476 | "--adam_weight_decay", 477 | type=float, 478 | default=1e-04, 479 | help="Weight decay to use for unet params", 480 | ) 481 | parser.add_argument( 482 | "--adam_weight_decay_text_encoder", 483 | type=float, 484 | default=1e-03, 485 | help="Weight decay to use for text_encoder", 486 | ) 487 | 488 | parser.add_argument( 489 | "--adam_epsilon", 490 | type=float, 491 | default=1e-08, 492 | help="Epsilon value for the Adam optimizer and Prodigy optimizers.", 493 | ) 494 | 495 | parser.add_argument( 496 | "--prodigy_use_bias_correction", 497 | type=bool, 498 | default=True, 499 | help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", 500 | ) 501 | parser.add_argument( 502 | "--prodigy_safeguard_warmup", 503 | type=bool, 504 | default=True, 505 | help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " 506 | "Ignored if optimizer is adamW", 507 | ) 508 | parser.add_argument( 509 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 510 | ) 511 | parser.add_argument( 512 | "--push_to_hub", 513 | action="store_true", 514 | help="Whether or not to push the model to the Hub.", 515 | ) 516 | parser.add_argument( 517 | "--hub_token", 518 | type=str, 519 | default=None, 520 | help="The token to use to push to the Model Hub.", 521 | ) 522 | parser.add_argument( 523 | "--hub_model_id", 524 | type=str, 525 | default=None, 526 | help="The name of the repository to keep in sync with the local `output_dir`.", 527 | ) 528 | parser.add_argument( 529 | "--logging_dir", 530 | type=str, 531 | default="logs", 532 | help=( 533 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 534 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 535 | ), 536 | ) 537 | parser.add_argument( 538 | "--allow_tf32", 539 | action="store_true", 540 | help=( 541 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 542 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 543 | ), 544 | ) 545 | parser.add_argument( 546 | "--report_to", 547 | type=str, 548 | default="tensorboard", 549 | help=( 550 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 551 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 552 | ), 553 | ) 554 | parser.add_argument( 555 | "--mixed_precision", 556 | type=str, 557 | default=None, 558 | choices=["no", "fp16", "bf16"], 559 | help=( 560 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 561 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 562 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 563 | ), 564 | ) 565 | parser.add_argument( 566 | "--prior_generation_precision", 567 | type=str, 568 | default=None, 569 | choices=["no", "fp32", "fp16", "bf16"], 570 | help=( 571 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 572 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." 
573 | ), 574 | ) 575 | parser.add_argument( 576 | "--local_rank", 577 | type=int, 578 | default=-1, 579 | help="For distributed training: local_rank", 580 | ) 581 | parser.add_argument( 582 | "--enable_xformers_memory_efficient_attention", 583 | action="store_true", 584 | help="Whether or not to use xformers.", 585 | ) 586 | 587 | if input_args is not None: 588 | args = parser.parse_args(input_args) 589 | else: 590 | args = parser.parse_args() 591 | 592 | if args.dataset_name is None and args.instance_data_dir is None: 593 | raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") 594 | 595 | if args.dataset_name is not None and args.instance_data_dir is not None: 596 | raise ValueError( 597 | "Specify only one of `--dataset_name` or `--instance_data_dir`" 598 | ) 599 | 600 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 601 | if env_local_rank != -1 and env_local_rank != args.local_rank: 602 | args.local_rank = env_local_rank 603 | 604 | if args.with_prior_preservation: 605 | if args.class_data_dir is None: 606 | raise ValueError("You must specify a data directory for class images.") 607 | if args.class_prompt is None: 608 | raise ValueError("You must specify prompt for class images.") 609 | else: 610 | # logger is not available yet 611 | if args.class_data_dir is not None: 612 | warnings.warn( 613 | "You need not use --class_data_dir without --with_prior_preservation." 614 | ) 615 | if args.class_prompt is not None: 616 | warnings.warn( 617 | "You need not use --class_prompt without --with_prior_preservation." 618 | ) 619 | 620 | return args 621 | 622 | 623 | class DreamBoothDataset(Dataset): 624 | """ 625 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 626 | It pre-processes the images. 627 | """ 628 | 629 | def __init__( 630 | self, 631 | instance_data_root, 632 | instance_prompt, 633 | class_prompt, 634 | class_data_root=None, 635 | class_num=None, 636 | size=1024, 637 | repeats=1, 638 | center_crop=False, 639 | ): 640 | self.size = size 641 | self.center_crop = center_crop 642 | 643 | self.instance_prompt = instance_prompt 644 | self.custom_instance_prompts = None 645 | self.class_prompt = class_prompt 646 | 647 | # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, 648 | # we load the training data using load_dataset 649 | if args.dataset_name is not None: 650 | try: 651 | from datasets import load_dataset 652 | except ImportError: 653 | raise ImportError( 654 | "You are trying to load your data using the datasets library. If you wish to train using custom " 655 | "captions please install the datasets library: `pip install datasets`. If you wish to load a " 656 | "local folder containing images only, specify --instance_data_dir instead." 657 | ) 658 | # Downloading and loading a dataset from the hub. 659 | # See more about loading custom images at 660 | # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script 661 | dataset = load_dataset( 662 | args.dataset_name, 663 | args.dataset_config_name, 664 | cache_dir=args.cache_dir, 665 | ) 666 | # Preprocessing the datasets. 667 | column_names = dataset["train"].column_names 668 | 669 | # 6. Get the column names for input/target. 
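            # No --image_column provided: fall back to the first column of the dataset.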
670 | if args.image_column is None: 671 | image_column = column_names[0] 672 | logger.info(f"image column defaulting to {image_column}") 673 | else: 674 | image_column = args.image_column 675 | if image_column not in column_names: 676 | raise ValueError( 677 | f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 678 | ) 679 | instance_images = dataset["train"][image_column] 680 | 681 | if args.caption_column is None: 682 | logger.info( 683 | "No caption column provided, defaulting to instance_prompt for all images. If your dataset " 684 | "contains captions/prompts for the images, make sure to specify the " 685 | "column as --caption_column" 686 | ) 687 | self.custom_instance_prompts = None 688 | else: 689 | if args.caption_column not in column_names: 690 | raise ValueError( 691 | f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" 692 | ) 693 | custom_instance_prompts = dataset["train"][args.caption_column] 694 | # create final list of captions according to --repeats 695 | self.custom_instance_prompts = [] 696 | for caption in custom_instance_prompts: 697 | self.custom_instance_prompts.extend( 698 | itertools.repeat(caption, repeats) 699 | ) 700 | else: 701 | self.instance_data_root = Path(instance_data_root) 702 | if not self.instance_data_root.exists(): 703 | raise ValueError("Instance images root doesn't exist.") 704 | 705 | instance_images = [ 706 | Image.open(path) for path in list(Path(instance_data_root).iterdir()) 707 | ] 708 | self.custom_instance_prompts = None 709 | 710 | self.instance_images = [] 711 | for img in instance_images: 712 | self.instance_images.extend(itertools.repeat(img, repeats)) 713 | self.num_instance_images = len(self.instance_images) 714 | self._length = self.num_instance_images 715 | 716 | if class_data_root is not None: 717 | self.class_data_root = Path(class_data_root) 718 | self.class_data_root.mkdir(parents=True, exist_ok=True) 719 | self.class_images_path = list(self.class_data_root.iterdir()) 720 | if class_num is not None: 721 | self.num_class_images = min(len(self.class_images_path), class_num) 722 | else: 723 | self.num_class_images = len(self.class_images_path) 724 | self._length = max(self.num_class_images, self.num_instance_images) 725 | else: 726 | self.class_data_root = None 727 | 728 | self.image_transforms = transforms.Compose( 729 | [ 730 | transforms.Resize( 731 | size, interpolation=transforms.InterpolationMode.BILINEAR 732 | ), 733 | transforms.CenterCrop(size) 734 | if center_crop 735 | else transforms.RandomCrop(size), 736 | transforms.ToTensor(), 737 | transforms.Normalize([0.5], [0.5]), 738 | ] 739 | ) 740 | 741 | def __len__(self): 742 | return self._length 743 | 744 | def __getitem__(self, index): 745 | example = {} 746 | instance_image = self.instance_images[index % self.num_instance_images] 747 | instance_image = exif_transpose(instance_image) 748 | 749 | if not instance_image.mode == "RGB": 750 | instance_image = instance_image.convert("RGB") 751 | example["instance_images"] = self.image_transforms(instance_image) 752 | 753 | if self.custom_instance_prompts: 754 | caption = self.custom_instance_prompts[index % self.num_instance_images] 755 | if caption: 756 | example["instance_prompt"] = caption 757 | else: 758 | example["instance_prompt"] = self.instance_prompt 759 | 760 | else: # no custom prompts were provided, fall back to the shared instance prompt 761 |
example["instance_prompt"] = self.instance_prompt 762 | 763 | if self.class_data_root: 764 | class_image = Image.open( 765 | self.class_images_path[index % self.num_class_images] 766 | ) 767 | class_image = exif_transpose(class_image) 768 | 769 | if not class_image.mode == "RGB": 770 | class_image = class_image.convert("RGB") 771 | example["class_images"] = self.image_transforms(class_image) 772 | example["class_prompt"] = self.class_prompt 773 | 774 | return example 775 | 776 | 777 | def collate_fn(examples, with_prior_preservation=False): 778 | pixel_values = [example["instance_images"] for example in examples] 779 | prompts = [example["instance_prompt"] for example in examples] 780 | 781 | # Concat class and instance examples for prior preservation. 782 | # We do this to avoid doing two forward passes. 783 | if with_prior_preservation: 784 | pixel_values += [example["class_images"] for example in examples] 785 | prompts += [example["class_prompt"] for example in examples] 786 | 787 | pixel_values = torch.stack(pixel_values) 788 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 789 | 790 | batch = {"pixel_values": pixel_values, "prompts": prompts} 791 | return batch 792 | 793 | 794 | class PromptDataset(Dataset): 795 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 796 | 797 | def __init__(self, prompt, num_samples): 798 | self.prompt = prompt 799 | self.num_samples = num_samples 800 | 801 | def __len__(self): 802 | return self.num_samples 803 | 804 | def __getitem__(self, index): 805 | example = {} 806 | example["prompt"] = self.prompt 807 | example["index"] = index 808 | return example 809 | 810 | 811 | def tokenize_prompt(tokenizer, prompt): 812 | text_inputs = tokenizer( 813 | prompt, 814 | padding="max_length", 815 | max_length=tokenizer.model_max_length, 816 | truncation=True, 817 | return_tensors="pt", 818 | ) 819 | text_input_ids = text_inputs.input_ids 820 | return text_input_ids 821 | 822 | 823 | # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt 824 | def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): 825 | prompt_embeds_list = [] 826 | 827 | for i, text_encoder in enumerate(text_encoders): 828 | if tokenizers is not None: 829 | tokenizer = tokenizers[i] 830 | text_input_ids = tokenize_prompt(tokenizer, prompt) 831 | else: 832 | assert text_input_ids_list is not None 833 | text_input_ids = text_input_ids_list[i] 834 | 835 | prompt_embeds = text_encoder( 836 | text_input_ids.to(text_encoder.device), 837 | output_hidden_states=True, 838 | ) 839 | 840 | # We are only ALWAYS interested in the pooled output of the final text encoder 841 | pooled_prompt_embeds = prompt_embeds[0] 842 | prompt_embeds = prompt_embeds.hidden_states[-2] 843 | bs_embed, seq_len, _ = prompt_embeds.shape 844 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 845 | prompt_embeds_list.append(prompt_embeds) 846 | 847 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 848 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) 849 | return prompt_embeds, pooled_prompt_embeds 850 | 851 | 852 | def main(args): 853 | logging_dir = Path(args.output_dir, args.logging_dir) 854 | 855 | accelerator_project_config = ProjectConfiguration( 856 | project_dir=args.output_dir, logging_dir=logging_dir 857 | ) 858 | kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 859 | accelerator = Accelerator( 860 | gradient_accumulation_steps=args.gradient_accumulation_steps, 861 | 
mixed_precision=args.mixed_precision, 862 | log_with=args.report_to, 863 | project_config=accelerator_project_config, 864 | kwargs_handlers=[kwargs], 865 | ) 866 | 867 | if args.report_to == "wandb": 868 | if not is_wandb_available(): 869 | raise ImportError( 870 | "Make sure to install wandb if you want to use it for logging during training." 871 | ) 872 | import wandb 873 | 874 | # Make one log on every process with the configuration for debugging. 875 | logging.basicConfig( 876 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 877 | datefmt="%m/%d/%Y %H:%M:%S", 878 | level=logging.INFO, 879 | ) 880 | logger.info(accelerator.state, main_process_only=False) 881 | if accelerator.is_local_main_process: 882 | transformers.utils.logging.set_verbosity_warning() 883 | diffusers.utils.logging.set_verbosity_info() 884 | else: 885 | transformers.utils.logging.set_verbosity_error() 886 | diffusers.utils.logging.set_verbosity_error() 887 | 888 | # If passed along, set the training seed now. 889 | if args.seed is not None: 890 | set_seed(args.seed) 891 | 892 | # Generate class images if prior preservation is enabled. 893 | if args.with_prior_preservation: 894 | class_images_dir = Path(args.class_data_dir) 895 | if not class_images_dir.exists(): 896 | class_images_dir.mkdir(parents=True) 897 | cur_class_images = len(list(class_images_dir.iterdir())) 898 | 899 | if cur_class_images < args.num_class_images: 900 | torch_dtype = ( 901 | torch.float16 if accelerator.device.type == "cuda" else torch.float32 902 | ) 903 | if args.prior_generation_precision == "fp32": 904 | torch_dtype = torch.float32 905 | elif args.prior_generation_precision == "fp16": 906 | torch_dtype = torch.float16 907 | elif args.prior_generation_precision == "bf16": 908 | torch_dtype = torch.bfloat16 909 | pipeline = StableDiffusionXLPipeline.from_pretrained( 910 | args.pretrained_model_name_or_path, 911 | torch_dtype=torch_dtype, 912 | revision=args.revision, 913 | variant=args.variant, 914 | ) 915 | pipeline.set_progress_bar_config(disable=True) 916 | 917 | num_new_images = args.num_class_images - cur_class_images 918 | logger.info(f"Number of class images to sample: {num_new_images}.") 919 | 920 | sample_dataset = PromptDataset(args.class_prompt, num_new_images) 921 | sample_dataloader = torch.utils.data.DataLoader( 922 | sample_dataset, batch_size=args.sample_batch_size 923 | ) 924 | 925 | sample_dataloader = accelerator.prepare(sample_dataloader) 926 | pipeline.to(accelerator.device) 927 | 928 | for example in tqdm( 929 | sample_dataloader, 930 | desc="Generating class images", 931 | disable=not accelerator.is_local_main_process, 932 | ): 933 | images = pipeline(example["prompt"]).images 934 | 935 | for i, image in enumerate(images): 936 | hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() 937 | image_filename = ( 938 | class_images_dir 939 | / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 940 | ) 941 | image.save(image_filename) 942 | 943 | del pipeline 944 | if torch.cuda.is_available(): 945 | torch.cuda.empty_cache() 946 | 947 | # Handle the repository creation 948 | if accelerator.is_main_process: 949 | if args.output_dir is not None: 950 | os.makedirs(args.output_dir, exist_ok=True) 951 | 952 | if args.push_to_hub: 953 | repo_id = create_repo( 954 | repo_id=args.hub_model_id or Path(args.output_dir).name, 955 | exist_ok=True, 956 | token=args.hub_token, 957 | ).repo_id 958 | 959 | # Load the tokenizers 960 | tokenizer_one = AutoTokenizer.from_pretrained( 961 | 
args.pretrained_model_name_or_path, 962 | subfolder="tokenizer", 963 | revision=args.revision, 964 | use_fast=False, 965 | ) 966 | tokenizer_two = AutoTokenizer.from_pretrained( 967 | args.pretrained_model_name_or_path, 968 | subfolder="tokenizer_2", 969 | revision=args.revision, 970 | use_fast=False, 971 | ) 972 | 973 | # import correct text encoder classes 974 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 975 | args.pretrained_model_name_or_path, args.revision 976 | ) 977 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 978 | args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" 979 | ) 980 | 981 | # Load scheduler and models 982 | noise_scheduler = DDPMScheduler.from_pretrained( 983 | args.pretrained_model_name_or_path, subfolder="scheduler" 984 | ) 985 | text_encoder_one = text_encoder_cls_one.from_pretrained( 986 | args.pretrained_model_name_or_path, 987 | subfolder="text_encoder", 988 | revision=args.revision, 989 | variant=args.variant, 990 | ) 991 | text_encoder_two = text_encoder_cls_two.from_pretrained( 992 | args.pretrained_model_name_or_path, 993 | subfolder="text_encoder_2", 994 | revision=args.revision, 995 | variant=args.variant, 996 | ) 997 | vae_path = ( 998 | args.pretrained_model_name_or_path 999 | if args.pretrained_vae_model_name_or_path is None 1000 | else args.pretrained_vae_model_name_or_path 1001 | ) 1002 | vae = AutoencoderKL.from_pretrained( 1003 | vae_path, 1004 | subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, 1005 | revision=args.revision, 1006 | variant=args.variant, 1007 | ) 1008 | unet = UNet2DConditionModel.from_pretrained( 1009 | args.pretrained_model_name_or_path, 1010 | subfolder="unet", 1011 | revision=args.revision, 1012 | variant=args.variant, 1013 | ) 1014 | 1015 | # We only train the additional adapter layers 1016 | vae.requires_grad_(False) 1017 | text_encoder_one.requires_grad_(False) 1018 | text_encoder_two.requires_grad_(False) 1019 | unet.requires_grad_(False) 1020 | 1021 | # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision 1022 | # as these weights are only used for inference, keeping weights in full precision is not required. 1023 | weight_dtype = torch.float32 1024 | if accelerator.mixed_precision == "fp16": 1025 | weight_dtype = torch.float16 1026 | elif accelerator.mixed_precision == "bf16": 1027 | weight_dtype = torch.bfloat16 1028 | 1029 | # Move unet, vae and text_encoder to device and cast to weight_dtype 1030 | unet.to(accelerator.device, dtype=weight_dtype) 1031 | 1032 | # The VAE is always in float32 to avoid NaN losses. 1033 | vae.to(accelerator.device, dtype=torch.float32) 1034 | 1035 | text_encoder_one.to(accelerator.device, dtype=weight_dtype) 1036 | text_encoder_two.to(accelerator.device, dtype=weight_dtype) 1037 | 1038 | if args.enable_xformers_memory_efficient_attention: 1039 | if is_xformers_available(): 1040 | import xformers 1041 | 1042 | xformers_version = version.parse(xformers.__version__) 1043 | if xformers_version == version.parse("0.0.16"): 1044 | logger.warn( 1045 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, " 1046 | "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 1047 | ) 1048 | unet.enable_xformers_memory_efficient_attention() 1049 | else: 1050 | raise ValueError( 1051 | "xformers is not available. 
Make sure it is installed correctly" 1052 | ) 1053 | 1054 | if args.gradient_checkpointing: 1055 | unet.enable_gradient_checkpointing() 1056 | 1057 | # now we will add SC-Tuner 1058 | unet.set_sctuner() 1059 | 1060 | # Make sure the trainable params are in float32. 1061 | if args.mixed_precision == "fp16": 1062 | models = [unet] 1063 | for model in models: 1064 | for param in model.parameters(): 1065 | # only upcast trainable parameters (LoRA) into fp32 1066 | if param.requires_grad: 1067 | param.data = param.to(torch.float32) 1068 | 1069 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 1070 | def save_model_hook(models, weights, output_dir): 1071 | if accelerator.is_main_process: 1072 | unet_sctuner_layers_to_save = None 1073 | for model in models: 1074 | if isinstance(model, type(accelerator.unwrap_model(unet))): 1075 | unet_sctuner_layers_to_save = { 1076 | k: v for k, v in model.state_dict().items() if "sc_tuners" in k 1077 | } 1078 | else: 1079 | raise ValueError(f"unexpected save model: {model.__class__}") 1080 | 1081 | # make sure to pop weight so that corresponding model is not saved again 1082 | weights.pop() 1083 | save_scedit( 1084 | save_directory=output_dir, state_dict=unet_sctuner_layers_to_save 1085 | ) 1086 | 1087 | def load_model_hook(models, input_dir): 1088 | unet_ = None 1089 | 1090 | while len(models) > 0: 1091 | model = models.pop() 1092 | 1093 | if isinstance(model, type(accelerator.unwrap_model(unet))): 1094 | unet_ = model 1095 | else: 1096 | raise ValueError(f"unexpected save model: {model.__class__}") 1097 | load_scedit_into_unet(state_dict=input_dir, unet=unet_) 1098 | 1099 | accelerator.register_save_state_pre_hook(save_model_hook) 1100 | accelerator.register_load_state_pre_hook(load_model_hook) 1101 | 1102 | # Enable TF32 for faster training on Ampere GPUs, 1103 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 1104 | if args.allow_tf32: 1105 | torch.backends.cuda.matmul.allow_tf32 = True 1106 | 1107 | if args.scale_lr: 1108 | args.learning_rate = ( 1109 | args.learning_rate 1110 | * args.gradient_accumulation_steps 1111 | * args.train_batch_size 1112 | * accelerator.num_processes 1113 | ) 1114 | 1115 | unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) 1116 | total_params = sum(p.numel() for p in unet_lora_parameters) 1117 | 1118 | # Optimization parameters 1119 | unet_lora_parameters_with_lr = { 1120 | "params": unet_lora_parameters, 1121 | "lr": args.learning_rate, 1122 | } 1123 | params_to_optimize = [unet_lora_parameters_with_lr] 1124 | 1125 | # Optimizer creation 1126 | if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): 1127 | logger.warn( 1128 | f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." 1129 | "Defaulting to adamW" 1130 | ) 1131 | args.optimizer = "adamw" 1132 | 1133 | if args.use_8bit_adam and not args.optimizer.lower() == "adamw": 1134 | logger.warn( 1135 | f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " 1136 | f"set to {args.optimizer.lower()}" 1137 | ) 1138 | 1139 | if args.optimizer.lower() == "adamw": 1140 | if args.use_8bit_adam: 1141 | try: 1142 | import bitsandbytes as bnb 1143 | except ImportError: 1144 | raise ImportError( 1145 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 
1146 | ) 1147 | 1148 | optimizer_class = bnb.optim.AdamW8bit 1149 | else: 1150 | optimizer_class = torch.optim.AdamW 1151 | 1152 | optimizer = optimizer_class( 1153 | params_to_optimize, 1154 | betas=(args.adam_beta1, args.adam_beta2), 1155 | weight_decay=args.adam_weight_decay, 1156 | eps=args.adam_epsilon, 1157 | ) 1158 | 1159 | if args.optimizer.lower() == "prodigy": 1160 | try: 1161 | import prodigyopt 1162 | except ImportError: 1163 | raise ImportError( 1164 | "To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`" 1165 | ) 1166 | 1167 | optimizer_class = prodigyopt.Prodigy 1168 | 1169 | if args.learning_rate <= 0.1: 1170 | logger.warn( 1171 | "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" 1172 | ) 1173 | 1174 | optimizer = optimizer_class( 1175 | params_to_optimize, 1176 | lr=args.learning_rate, 1177 | betas=(args.adam_beta1, args.adam_beta2), 1178 | beta3=args.prodigy_beta3, 1179 | weight_decay=args.adam_weight_decay, 1180 | eps=args.adam_epsilon, 1181 | decouple=args.prodigy_decouple, 1182 | use_bias_correction=args.prodigy_use_bias_correction, 1183 | safeguard_warmup=args.prodigy_safeguard_warmup, 1184 | ) 1185 | 1186 | # Dataset and DataLoaders creation: 1187 | train_dataset = DreamBoothDataset( 1188 | instance_data_root=args.instance_data_dir, 1189 | instance_prompt=args.instance_prompt, 1190 | class_prompt=args.class_prompt, 1191 | class_data_root=args.class_data_dir if args.with_prior_preservation else None, 1192 | class_num=args.num_class_images, 1193 | size=args.resolution, 1194 | repeats=args.repeats, 1195 | center_crop=args.center_crop, 1196 | ) 1197 | 1198 | train_dataloader = torch.utils.data.DataLoader( 1199 | train_dataset, 1200 | batch_size=args.train_batch_size, 1201 | shuffle=True, 1202 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), 1203 | num_workers=args.dataloader_num_workers, 1204 | ) 1205 | 1206 | # Computes additional embeddings/ids required by the SDXL UNet. 1207 | # regular text embeddings (when `train_text_encoder` is not True) 1208 | # pooled text embeddings 1209 | # time ids 1210 | 1211 | def compute_time_ids(): 1212 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids 1213 | original_size = (args.resolution, args.resolution) 1214 | target_size = (args.resolution, args.resolution) 1215 | crops_coords_top_left = ( 1216 | args.crops_coords_top_left_h, 1217 | args.crops_coords_top_left_w, 1218 | ) 1219 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 1220 | add_time_ids = torch.tensor([add_time_ids]) 1221 | add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) 1222 | return add_time_ids 1223 | 1224 | tokenizers = [tokenizer_one, tokenizer_two] 1225 | text_encoders = [text_encoder_one, text_encoder_two] 1226 | 1227 | def compute_text_embeddings(prompt, text_encoders, tokenizers): 1228 | with torch.no_grad(): 1229 | prompt_embeds, pooled_prompt_embeds = encode_prompt( 1230 | text_encoders, tokenizers, prompt 1231 | ) 1232 | prompt_embeds = prompt_embeds.to(accelerator.device) 1233 | pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) 1234 | return prompt_embeds, pooled_prompt_embeds 1235 | 1236 | # Handle instance prompt. 1237 | instance_time_ids = compute_time_ids() 1238 | 1239 | # If no type of tuning is done on the text_encoder and custom instance prompts are NOT 1240 | # provided (i.e. 
the --instance_prompt is used for all images), we encode the instance prompt once to avoid 1241 | # the redundant encoding. 1242 | if not train_dataset.custom_instance_prompts: 1243 | ( 1244 | instance_prompt_hidden_states, 1245 | instance_pooled_prompt_embeds, 1246 | ) = compute_text_embeddings(args.instance_prompt, text_encoders, tokenizers) 1247 | 1248 | # Handle class prompt for prior-preservation. 1249 | if args.with_prior_preservation: 1250 | class_time_ids = compute_time_ids() 1251 | ( 1252 | class_prompt_hidden_states, 1253 | class_pooled_prompt_embeds, 1254 | ) = compute_text_embeddings(args.class_prompt, text_encoders, tokenizers) 1255 | 1256 | # Clear the memory here 1257 | if not train_dataset.custom_instance_prompts: 1258 | del tokenizers, text_encoders 1259 | gc.collect() 1260 | torch.cuda.empty_cache() 1261 | 1262 | # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), 1263 | # pack the statically computed variables appropriately here. This is so that we don't 1264 | # have to pass them to the dataloader. 1265 | add_time_ids = instance_time_ids 1266 | if args.with_prior_preservation: 1267 | add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0) 1268 | 1269 | if not train_dataset.custom_instance_prompts: 1270 | prompt_embeds = instance_prompt_hidden_states 1271 | unet_add_text_embeds = instance_pooled_prompt_embeds 1272 | if args.with_prior_preservation: 1273 | prompt_embeds = torch.cat( 1274 | [prompt_embeds, class_prompt_hidden_states], dim=0 1275 | ) 1276 | unet_add_text_embeds = torch.cat( 1277 | [unet_add_text_embeds, class_pooled_prompt_embeds], dim=0 1278 | ) 1279 | 1280 | # Scheduler and math around the number of training steps. 1281 | overrode_max_train_steps = False 1282 | num_update_steps_per_epoch = math.ceil( 1283 | len(train_dataloader) / args.gradient_accumulation_steps 1284 | ) 1285 | if args.max_train_steps is None: 1286 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1287 | overrode_max_train_steps = True 1288 | 1289 | lr_scheduler = get_scheduler( 1290 | args.lr_scheduler, 1291 | optimizer=optimizer, 1292 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 1293 | num_training_steps=args.max_train_steps * accelerator.num_processes, 1294 | num_cycles=args.lr_num_cycles, 1295 | power=args.lr_power, 1296 | ) 1297 | 1298 | # Prepare everything with our `accelerator`. 1299 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 1300 | unet, optimizer, train_dataloader, lr_scheduler 1301 | ) 1302 | 1303 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 1304 | num_update_steps_per_epoch = math.ceil( 1305 | len(train_dataloader) / args.gradient_accumulation_steps 1306 | ) 1307 | if overrode_max_train_steps: 1308 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1309 | # Afterwards we recalculate our number of training epochs 1310 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 1311 | 1312 | # We need to initialize the trackers we use, and also store our configuration. 1313 | # The trackers initialize automatically on the main process. 1314 | if accelerator.is_main_process: 1315 | accelerator.init_trackers("dreambooth-scedit-sd-xl", config=vars(args)) 1316 | 1317 | # Train!
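    # Worked example (illustrative numbers, not defaults): with train_batch_size=1,
    # accelerator.num_processes=2 and gradient_accumulation_steps=4, the total batch
    # size computed below is 1 * 2 * 4 = 8 samples per optimizer update.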
1318 | total_batch_size = ( 1319 | args.train_batch_size 1320 | * accelerator.num_processes 1321 | * args.gradient_accumulation_steps 1322 | ) 1323 | 1324 | logger.info("***** Running training *****") 1325 | logger.info(f" Num examples = {len(train_dataset)}") 1326 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 1327 | logger.info(f" Num Epochs = {args.num_train_epochs}") 1328 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 1329 | logger.info( 1330 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 1331 | ) 1332 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 1333 | logger.info(f" Total optimization steps = {args.max_train_steps}") 1334 | logger.info(f" Number of Trainable Parameters = {total_params * 1.e-6:.2f} M") 1335 | 1336 | global_step = 0 1337 | first_epoch = 0 1338 | 1339 | # Potentially load in the weights and states from a previous save 1340 | if args.resume_from_checkpoint: 1341 | if args.resume_from_checkpoint != "latest": 1342 | path = os.path.basename(args.resume_from_checkpoint) 1343 | else: 1344 | # Get the most recent checkpoint 1345 | dirs = os.listdir(args.output_dir) 1346 | dirs = [d for d in dirs if d.startswith("checkpoint")] 1347 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 1348 | path = dirs[-1] if len(dirs) > 0 else None 1349 | 1350 | if path is None: 1351 | accelerator.print( 1352 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 1353 | ) 1354 | args.resume_from_checkpoint = None 1355 | initial_global_step = 0 1356 | else: 1357 | accelerator.print(f"Resuming from checkpoint {path}") 1358 | accelerator.load_state(os.path.join(args.output_dir, path)) 1359 | global_step = int(path.split("-")[1]) 1360 | 1361 | initial_global_step = global_step 1362 | first_epoch = global_step // num_update_steps_per_epoch 1363 | 1364 | else: 1365 | initial_global_step = 0 1366 | 1367 | progress_bar = tqdm( 1368 | range(0, args.max_train_steps), 1369 | initial=initial_global_step, 1370 | desc="Steps", 1371 | # Only show the progress bar once on each machine.
1372 | disable=not accelerator.is_local_main_process, 1373 | ) 1374 | 1375 | for epoch in range(first_epoch, args.num_train_epochs): 1376 | unet.train() 1377 | for step, batch in enumerate(train_dataloader): 1378 | with accelerator.accumulate(unet): 1379 | pixel_values = batch["pixel_values"].to(dtype=vae.dtype) 1380 | prompts = batch["prompts"] 1381 | 1382 | # encode batch prompts when custom prompts are provided for each image - 1383 | if train_dataset.custom_instance_prompts: 1384 | prompt_embeds, unet_add_text_embeds = compute_text_embeddings( 1385 | prompts, text_encoders, tokenizers 1386 | ) 1387 | 1388 | # Convert images to latent space 1389 | model_input = vae.encode(pixel_values).latent_dist.sample() 1390 | model_input = model_input * vae.config.scaling_factor 1391 | if args.pretrained_vae_model_name_or_path is None: 1392 | model_input = model_input.to(weight_dtype) 1393 | 1394 | # Sample noise that we'll add to the latents 1395 | noise = torch.randn_like(model_input) 1396 | bsz = model_input.shape[0] 1397 | # Sample a random timestep for each image 1398 | timesteps = torch.randint( 1399 | 0, 1400 | noise_scheduler.config.num_train_timesteps, 1401 | (bsz,), 1402 | device=model_input.device, 1403 | ) 1404 | timesteps = timesteps.long() 1405 | 1406 | # Add noise to the model input according to the noise magnitude at each timestep 1407 | # (this is the forward diffusion process) 1408 | noisy_model_input = noise_scheduler.add_noise( 1409 | model_input, noise, timesteps 1410 | ) 1411 | 1412 | # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. 1413 | if not train_dataset.custom_instance_prompts: 1414 | elems_to_repeat_text_embeds = ( 1415 | bsz // 2 if args.with_prior_preservation else bsz 1416 | ) 1417 | elems_to_repeat_time_ids = ( 1418 | bsz // 2 if args.with_prior_preservation else bsz 1419 | ) 1420 | else: 1421 | elems_to_repeat_text_embeds = 1 1422 | elems_to_repeat_time_ids = ( 1423 | bsz // 2 if args.with_prior_preservation else bsz 1424 | ) 1425 | 1426 | # Predict the noise residual 1427 | unet_added_conditions = { 1428 | "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1), 1429 | "text_embeds": unet_add_text_embeds.repeat( 1430 | elems_to_repeat_text_embeds, 1 1431 | ), 1432 | } 1433 | prompt_embeds_input = prompt_embeds.repeat( 1434 | elems_to_repeat_text_embeds, 1, 1 1435 | ) 1436 | model_pred = unet( 1437 | noisy_model_input, 1438 | timesteps, 1439 | prompt_embeds_input, 1440 | added_cond_kwargs=unet_added_conditions, 1441 | ).sample 1442 | 1443 | # Get the target for loss depending on the prediction type 1444 | if noise_scheduler.config.prediction_type == "epsilon": 1445 | target = noise 1446 | elif noise_scheduler.config.prediction_type == "v_prediction": 1447 | target = noise_scheduler.get_velocity(model_input, noise, timesteps) 1448 | else: 1449 | raise ValueError( 1450 | f"Unknown prediction type {noise_scheduler.config.prediction_type}" 1451 | ) 1452 | 1453 | if args.with_prior_preservation: 1454 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 
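                    # collate_fn concatenated the instance batch and the class batch along dim 0,
                    # so the first half of model_pred/target is the instance part and the
                    # second half is the prior-preservation (class) part.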
1455 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 1456 | target, target_prior = torch.chunk(target, 2, dim=0) 1457 | 1458 | # Compute prior loss 1459 | prior_loss = F.mse_loss( 1460 | model_pred_prior.float(), target_prior.float(), reduction="mean" 1461 | ) 1462 | 1463 | if args.snr_gamma is None: 1464 | loss = F.mse_loss( 1465 | model_pred.float(), target.float(), reduction="mean" 1466 | ) 1467 | else: 1468 | # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. 1469 | # Since we predict the noise instead of x_0, the original formulation is slightly changed. 1470 | # This is discussed in Section 4.2 of the same paper. 1471 | snr = compute_snr(noise_scheduler, timesteps) 1472 | base_weight = ( 1473 | torch.stack( 1474 | [snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1 1475 | ).min(dim=1)[0] 1476 | / snr 1477 | ) 1478 | 1479 | if noise_scheduler.config.prediction_type == "v_prediction": 1480 | # Velocity objective needs to be floored to an SNR weight of one. 1481 | mse_loss_weights = base_weight + 1 1482 | else: 1483 | # Epsilon and sample both use the same loss weights. 1484 | mse_loss_weights = base_weight 1485 | 1486 | loss = F.mse_loss( 1487 | model_pred.float(), target.float(), reduction="none" 1488 | ) 1489 | loss = ( 1490 | loss.mean(dim=list(range(1, len(loss.shape)))) 1491 | * mse_loss_weights 1492 | ) 1493 | loss = loss.mean() 1494 | 1495 | if args.with_prior_preservation: 1496 | # Add the prior loss to the instance loss. 1497 | loss = loss + args.prior_loss_weight * prior_loss 1498 | 1499 | accelerator.backward(loss) 1500 | if accelerator.sync_gradients: 1501 | params_to_clip = unet_lora_parameters 1502 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 1503 | optimizer.step() 1504 | lr_scheduler.step() 1505 | optimizer.zero_grad() 1506 | 1507 | # Checks if the accelerator has performed an optimization step behind the scenes 1508 | if accelerator.sync_gradients: 1509 | progress_bar.update(1) 1510 | global_step += 1 1511 | 1512 | if accelerator.is_main_process: 1513 | if global_step % args.checkpointing_steps == 0: 1514 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 1515 | if args.checkpoints_total_limit is not None: 1516 | checkpoints = os.listdir(args.output_dir) 1517 | checkpoints = [ 1518 | d for d in checkpoints if d.startswith("checkpoint") 1519 | ] 1520 | checkpoints = sorted( 1521 | checkpoints, key=lambda x: int(x.split("-")[1]) 1522 | ) 1523 | 1524 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 1525 | if len(checkpoints) >= args.checkpoints_total_limit: 1526 | num_to_remove = ( 1527 | len(checkpoints) - args.checkpoints_total_limit + 1 1528 | ) 1529 | removing_checkpoints = checkpoints[0:num_to_remove] 1530 | 1531 | logger.info( 1532 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 1533 | ) 1534 | logger.info( 1535 | f"removing checkpoints: {', '.join(removing_checkpoints)}" 1536 | ) 1537 | 1538 | for removing_checkpoint in removing_checkpoints: 1539 | removing_checkpoint = os.path.join( 1540 | args.output_dir, removing_checkpoint 1541 | ) 1542 | shutil.rmtree(removing_checkpoint) 1543 | 1544 | save_path = os.path.join( 1545 | args.output_dir, f"checkpoint-{global_step}" 1546 | ) 1547 | accelerator.save_state(save_path) 1548 | logger.info(f"Saved state to {save_path}") 1549 | 1550 | logs = {"loss": loss.detach().item(), "lr": 
lr_scheduler.get_last_lr()[0]} 1551 | progress_bar.set_postfix(**logs) 1552 | accelerator.log(logs, step=global_step) 1553 | 1554 | if global_step >= args.max_train_steps: 1555 | break 1556 | 1557 | if accelerator.is_main_process: 1558 | if ( 1559 | args.validation_prompt is not None 1560 | and epoch % args.validation_epochs == 0 1561 | ): 1562 | logger.info( 1563 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 1564 | f" {args.validation_prompt}." 1565 | ) 1566 | # create pipeline 1567 | text_encoder_one = text_encoder_cls_one.from_pretrained( 1568 | args.pretrained_model_name_or_path, 1569 | subfolder="text_encoder", 1570 | revision=args.revision, 1571 | variant=args.variant, 1572 | ) 1573 | text_encoder_two = text_encoder_cls_two.from_pretrained( 1574 | args.pretrained_model_name_or_path, 1575 | subfolder="text_encoder_2", 1576 | revision=args.revision, 1577 | variant=args.variant, 1578 | ) 1579 | pipeline = StableDiffusionXLPipeline.from_pretrained( 1580 | args.pretrained_model_name_or_path, 1581 | vae=vae, 1582 | text_encoder=accelerator.unwrap_model(text_encoder_one), 1583 | text_encoder_2=accelerator.unwrap_model(text_encoder_two), 1584 | unet=accelerator.unwrap_model(unet), 1585 | revision=args.revision, 1586 | variant=args.variant, 1587 | torch_dtype=weight_dtype, 1588 | ) 1589 | 1590 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it 1591 | scheduler_args = {} 1592 | 1593 | if "variance_type" in pipeline.scheduler.config: 1594 | variance_type = pipeline.scheduler.config.variance_type 1595 | 1596 | if variance_type in ["learned", "learned_range"]: 1597 | variance_type = "fixed_small" 1598 | 1599 | scheduler_args["variance_type"] = variance_type 1600 | 1601 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config( 1602 | pipeline.scheduler.config, **scheduler_args 1603 | ) 1604 | 1605 | pipeline = pipeline.to(accelerator.device) 1606 | pipeline.set_progress_bar_config(disable=True) 1607 | 1608 | # run inference 1609 | generator = ( 1610 | torch.Generator(device=accelerator.device).manual_seed(args.seed) 1611 | if args.seed 1612 | else None 1613 | ) 1614 | pipeline_args = {"prompt": args.validation_prompt} 1615 | 1616 | images = [ 1617 | pipeline(**pipeline_args, generator=generator).images[0] 1618 | for _ in range(args.num_validation_images) 1619 | ] 1620 | 1621 | for tracker in accelerator.trackers: 1622 | if tracker.name == "tensorboard": 1623 | np_images = np.stack([np.asarray(img) for img in images]) 1624 | tracker.writer.add_images( 1625 | "validation", np_images, epoch, dataformats="NHWC" 1626 | ) 1627 | if tracker.name == "wandb": 1628 | tracker.log( 1629 | { 1630 | "validation": [ 1631 | wandb.Image( 1632 | image, caption=f"{i}: {args.validation_prompt}" 1633 | ) 1634 | for i, image in enumerate(images) 1635 | ] 1636 | } 1637 | ) 1638 | 1639 | del pipeline 1640 | torch.cuda.empty_cache() 1641 | 1642 | # Save the SC-Tuner layers 1643 | accelerator.wait_for_everyone() 1644 | if accelerator.is_main_process: 1645 | del optimizer 1646 | torch.cuda.empty_cache() 1647 | unet = accelerator.unwrap_model(unet) 1648 | unet = unet.to(torch.float32) 1649 | unet_sctuner_layers_to_save = { 1650 | k: v for k, v in unet.state_dict().items() if "sc_tuners" in k 1651 | } 1652 | save_scedit( 1653 | save_directory=args.output_dir, state_dict=unet_sctuner_layers_to_save 1654 | ) 1655 | 1656 | # Final inference 1657 | pipeline = StableDiffusionXLPipeline.from_pretrained(
1658 | args.pretrained_model_name_or_path, 1659 | unet=unet, 1660 | vae=vae, 1661 | revision=args.revision, 1662 | variant=args.variant, 1663 | torch_dtype=weight_dtype, 1664 | ) 1665 | 1666 | # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it 1667 | scheduler_args = {} 1668 | 1669 | if "variance_type" in pipeline.scheduler.config: 1670 | variance_type = pipeline.scheduler.config.variance_type 1671 | 1672 | if variance_type in ["learned", "learned_range"]: 1673 | variance_type = "fixed_small" 1674 | 1675 | scheduler_args["variance_type"] = variance_type 1676 | 1677 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config( 1678 | pipeline.scheduler.config, **scheduler_args 1679 | ) 1680 | 1681 | # run inference 1682 | images = [] 1683 | if args.validation_prompt and args.num_validation_images > 0: 1684 | pipeline = pipeline.to(accelerator.device, dtype=weight_dtype) 1685 | generator = ( 1686 | torch.Generator(device=accelerator.device).manual_seed(args.seed) 1687 | if args.seed 1688 | else None 1689 | ) 1690 | images = [ 1691 | pipeline( 1692 | args.validation_prompt, num_inference_steps=25, generator=generator 1693 | ).images[0] 1694 | for _ in range(args.num_validation_images) 1695 | ] 1696 | 1697 | for tracker in accelerator.trackers: 1698 | if tracker.name == "tensorboard": 1699 | np_images = np.stack([np.asarray(img) for img in images]) 1700 | tracker.writer.add_images( 1701 | "test", np_images, epoch, dataformats="NHWC" 1702 | ) 1703 | if tracker.name == "wandb": 1704 | tracker.log( 1705 | { 1706 | "test": [ 1707 | wandb.Image( 1708 | image, caption=f"{i}: {args.validation_prompt}" 1709 | ) 1710 | for i, image in enumerate(images) 1711 | ] 1712 | } 1713 | ) 1714 | 1715 | if args.push_to_hub: 1716 | save_model_card( 1717 | repo_id, 1718 | images=images, 1719 | base_model=args.pretrained_model_name_or_path, 1720 | instance_prompt=args.instance_prompt, 1721 | validation_prompt=args.validation_prompt, 1722 | repo_folder=args.output_dir, 1723 | vae_path=args.pretrained_vae_model_name_or_path, 1724 | ) 1725 | upload_folder( 1726 | repo_id=repo_id, 1727 | folder_path=args.output_dir, 1728 | commit_message="End of training", 1729 | ignore_patterns=["step_*", "epoch_*"], 1730 | ) 1731 | 1732 | accelerator.end_training() 1733 | 1734 | 1735 | if __name__ == "__main__": 1736 | args = parse_args() 1737 | main(args) 1738 | --------------------------------------------------------------------------------
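Usage note (a minimal inference sketch, not one of the repository files above): it assumes `scedit_pytorch` exports `UNet2DConditionModel` (with the `set_sctuner()` method) and `load_scedit_into_unet` with the same call signatures used in `train_dreambooth_scedit_sdxl.py`; the base model id, checkpoint directory, prompt, and output filename are placeholders.

import torch
from diffusers import StableDiffusionXLPipeline
from scedit_pytorch import UNet2DConditionModel, load_scedit_into_unet  # assumed exports

base_model = "stabilityai/stable-diffusion-xl-base-1.0"  # placeholder base model
scedit_dir = "output/checkpoint-500"  # placeholder directory written by save_scedit

# Rebuild the SDXL UNet, attach the SC-Tuner layers, then load their trained weights,
# mirroring how load_model_hook calls load_scedit_into_unet in the training script.
unet = UNet2DConditionModel.from_pretrained(base_model, subfolder="unet")
unet.set_sctuner()
load_scedit_into_unet(state_dict=scedit_dir, unet=unet)

pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model, unet=unet.to(torch.float16), torch_dtype=torch.float16
)
pipe.to("cuda")
image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]  # placeholder prompt
image.save("scedit_sample.png")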