├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── assets ├── car-result.png ├── chair-result.png ├── dog.png ├── kasumi.png ├── kumamon.png ├── lora.png ├── scale-result.png ├── scale.png └── svdiff.png ├── inference.py ├── paper.png ├── requirements.txt ├── scripts ├── svdiff_pytorch.ipynb ├── train_dreambooth.py └── train_dreambooth_lora.py ├── setup.py ├── svdiff_pytorch ├── __init__.py ├── diffusers_models │ ├── __init__.py │ ├── attention.py │ ├── cross_attention.py │ ├── dual_transformer_2d.py │ ├── embeddings.py │ ├── resnet.py │ ├── transformer_2d.py │ ├── unet_2d_blocks.py │ └── unet_2d_condition.py ├── layers.py ├── pipeline_stable_diffusion_ddim_inversion.py ├── transformers_models_clip │ ├── __init__.py │ └── modeling_clip.py └── utils.py └── train_svdiff.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | # This workflow will upload a Python Package using Twine when a release is created 2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 3 | 4 | # This workflow uses actions that are not certified by GitHub. 5 | # They are provided by a third-party and are governed by 6 | # separate terms of service, privacy policy, and support 7 | # documentation. 8 | 9 | name: Upload Python Package 10 | 11 | on: 12 | release: 13 | types: [published] 14 | 15 | jobs: 16 | deploy: 17 | 18 | runs-on: ubuntu-latest 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: '3.x' 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install build 30 | - name: Build package 31 | run: python -m build 32 | - name: Publish package 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | private 132 | test.ipynb -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "scripts/gradio"] 2 | path = scripts/gradio 3 | url = https://huggingface.co/spaces/svdiff-library/SVDiff-Training-UI 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 mkshing 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SVDiff-pytorch
2 | Open In Colab
3 | [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/svdiff-library/SVDiff-Training-UI)
4 | 
5 | 
6 | An implementation of [SVDiff: Compact Parameter Space for Diffusion Fine-Tuning](https://arxiv.org/abs/2303.11305) using d🧨ffusers.
7 | 
8 | My summary tweet can be found [here](https://twitter.com/mk1stats/status/1642865505106272257).
9 | 
10 | 
11 | ![result](assets/dog.png)
12 | left: LoRA, right: SVDiff
13 | 
14 | 
15 | Compared with LoRA, SVDiff has about 0.5M fewer trainable parameters, and the checkpoint file is only 1.2MB (LoRA: 3.1MB)!!
16 | 
17 | ![kumamon](assets/kumamon.png)
18 | 
19 | ## Updates
20 | ### 2023.4.11
21 | - Released v0.2.0 (please see [here](https://github.com/mkshing/svdiff-pytorch/releases/tag/v0.2.0) for the details). With this change, you get better results with fewer training steps than the first release v0.1.1!!
22 | - Add [Single Image Editing](#single-image-editing)
23 | 
24 | ![chair-result](assets/chair-result.png) 25 |
"photo of a ~~pink~~ **blue** chair with black legs" (without DDIM Inversion) 26 | 27 | 28 | ## Installation 29 | ``` 30 | $ pip install svdiff-pytorch 31 | ``` 32 | Or, manually 33 | ```bash 34 | $ git clone https://github.com/mkshing/svdiff-pytorch 35 | $ pip install -r requirements.txt 36 | ``` 37 | 38 | ## Single-Subject Generation 39 | "Single-Subject Generation" is a domain-tuning on a single object or concept (using 3-5 images). (See Section 4.1) 40 | 41 | ### Training 42 | According to the paper, the learning rate for SVDiff needs to be 1000 times larger than the lr used for fine-tuning. 43 | 44 | ```bash 45 | export MODEL_NAME="runwayml/stable-diffusion-v1-5" 46 | export INSTANCE_DIR="path-to-instance-images" 47 | export CLASS_DIR="path-to-class-images" 48 | export OUTPUT_DIR="path-to-save-model" 49 | 50 | accelerate launch train_svdiff.py \ 51 | --pretrained_model_name_or_path=$MODEL_NAME \ 52 | --instance_data_dir=$INSTANCE_DIR \ 53 | --class_data_dir=$CLASS_DIR \ 54 | --output_dir=$OUTPUT_DIR \ 55 | --with_prior_preservation --prior_loss_weight=1.0 \ 56 | --instance_prompt="photo of a sks dog" \ 57 | --class_prompt="photo of a dog" \ 58 | --resolution=512 \ 59 | --train_batch_size=1 \ 60 | --gradient_accumulation_steps=1 \ 61 | --learning_rate=1e-3 \ 62 | --learning_rate_1d=1e-6 \ 63 | --train_text_encoder \ 64 | --lr_scheduler="constant" \ 65 | --lr_warmup_steps=0 \ 66 | --num_class_images=200 \ 67 | --max_train_steps=500 68 | ``` 69 | 70 | ### Inference 71 | 72 | ```python 73 | from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler 74 | import torch 75 | 76 | from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff 77 | 78 | pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5" 79 | spectral_shifts_ckpt_dir = "ckpt-dir-path" 80 | unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet") 81 | text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder") 82 | # load pipe 83 | pipe = StableDiffusionPipeline.from_pretrained( 84 | pretrained_model_name_or_path, 85 | unet=unet, 86 | text_encoder=text_encoder, 87 | ) 88 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 89 | pipe.to("cuda") 90 | image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0] 91 | ``` 92 | 93 | You can use the following CLI too! Once it's done, you will see `grid.png` for the result. 94 | 95 | ```bash 96 | python inference.py \ 97 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ 98 | --spectral_shifts_ckpt="ckpt-dir-path" \ 99 | --prompt="A picture of a sks dog in a bucket" \ 100 | --scheduler_type="dpm_solver++" \ 101 | --num_inference_steps=25 \ 102 | --num_images_per_prompt=2 103 | ``` 104 | 105 | ## Single Image Editing 106 | ### Training 107 | In Single Image Editing, your instance prompt should be just the description of your input image **without the identifier**. 
108 | 109 | ```bash 110 | export MODEL_NAME="runwayml/stable-diffusion-v1-5" 111 | export INSTANCE_DIR="dir-path-to-input-image" 112 | export CLASS_DIR="path-to-class-images" 113 | export OUTPUT_DIR="path-to-save-model" 114 | 115 | accelerate launch train_svdiff.py \ 116 | --pretrained_model_name_or_path=$MODEL_NAME \ 117 | --instance_data_dir=$INSTANCE_DIR \ 118 | --class_data_dir=$CLASS_DIR \ 119 | --output_dir=$OUTPUT_DIR \ 120 | --instance_prompt="photo of a pink chair with black legs" \ 121 | --resolution=512 \ 122 | --train_batch_size=1 \ 123 | --gradient_accumulation_steps=1 \ 124 | --learning_rate=1e-3 \ 125 | --learning_rate_1d=1e-6 \ 126 | --train_text_encoder \ 127 | --lr_scheduler="constant" \ 128 | --lr_warmup_steps=0 \ 129 | --max_train_steps=500 130 | ``` 131 | 132 | ### Inference 133 | 134 | ```python 135 | import torch 136 | from PIL import Image 137 | from diffusers import DDIMScheduler 138 | from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, StableDiffusionPipelineWithDDIMInversion 139 | 140 | pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5" 141 | spectral_shifts_ckpt_dir = "ckpt-dir-path" 142 | image = "path-to-image" 143 | source_prompt = "prompt-for-image" 144 | target_prompt = "prompt-you-want-to-generate" 145 | 146 | unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet") 147 | text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder") 148 | # load pipe 149 | pipe = StableDiffusionPipelineWithDDIMInversion.from_pretrained( 150 | pretrained_model_name_or_path, 151 | unet=unet, 152 | text_encoder=text_encoder, 153 | ) 154 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config) 155 | pipe.to("cuda") 156 | 157 | # (optional) ddim inversion 158 | # if you don't do it, inv_latents = None 159 | image = Image.open(image).convert("RGB").resize((512, 512)) 160 | # in SVDiff, they use guidance scale=1 in ddim inversion 161 | # They use target_prompt in DDIM inversion for better results. See below for comparison between source_prompt and target_prompt. 162 | inv_latents = pipe.invert(target_prompt, image=image, guidance_scale=1.0).latents 163 | 164 | # They use a small cfg scale in Single Image Editing 165 | image = pipe(target_prompt, latents=inv_latents, guidance_scale=3, eta=0.5).images[0] 166 | ``` 167 | 168 | DDIM inversion with target prompt (left) v.s. source prompt (right): 169 |
170 | ![car-result](assets/car-result.png) 171 |
"photo of a grey ~~Beetle~~ **Mustang** car" (original image: https://unsplash.com/photos/YEPDV3T8Vi8) 172 | 173 | To use slerp to add more stochasticity, 174 | ```python 175 | from svdiff_pytorch.utils import slerp_tensor 176 | 177 | # prev steps omitted 178 | inv_latents = pipe.invert(target_prompt, image=image, guidance_scale=1.0).latents 179 | noise_latents = pipe.prepare_latents(inv_latents.shape[0], inv_latents.shape[1], 512, 512, dtype=inv_latents.dtype, device=pipe.device, generator=torch.Generator("cuda").manual_seed(0)) 180 | inv_latents = slerp_tensor(0.5, inv_latents, noise_latents) 181 | image = pipe(target_prompt, latents=inv_latents).images[0] 182 | ``` 183 | 184 | 185 | ## Gradio 186 | You can also try SVDiff-pytorch in a UI with [gradio](https://gradio.app/). This demo supports both training and inference! 187 | 188 | [![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/svdiff-library/SVDiff-Training-UI) 189 | 190 | If you want to run it locally, run the following commands step by step. 191 | ```bash 192 | $ git clone --recursive https://github.com/mkshing/svdiff-pytorch.git 193 | $ cd scripts/gradio 194 | $ pip install -r requirements.txt 195 | $ export HF_TOKEN="YOUR_HF_TOKEN_HERE" 196 | $ python app.py 197 | ``` 198 | 199 | ## Additional Features 200 | 201 | ### Spectral Shift Scaling 202 | 203 | ![scale](assets/scale.png) 204 | 205 | You can adjust the strength of the weights by `--spectral_shifts_scale` 206 | 207 | Here's a result for 0.8, 1.0, 1.2 (1.0 is the default). 208 | ![scale-result](assets/scale-result.png) 209 | 210 | 211 | ### Fast prior generation by using ToMe 212 | By using [ToMe for SD](https://github.com/dbolya/tomesd), the prior generation can be faster! 213 | ``` 214 | $ pip install tomesd 215 | ``` 216 | And, add `--enable_tome_merging` to your training arguments! 
217 | 218 | ## Citation 219 | 220 | ```bibtex 221 | @misc{https://doi.org/10.48550/arXiv.2303.11305, 222 | title = {SVDiff: Compact Parameter Space for Diffusion Fine-Tuning}, 223 | author = {Ligong Han and Yinxiao Li and Han Zhang and Peyman Milanfar and Dimitris Metaxas and Feng Yang}, 224 | year = {2023}, 225 | eprint = {2303.11305}, 226 | archivePrefix = {arXiv}, 227 | primaryClass = {cs.CV}, 228 | url = {https://arxiv.org/abs/2303.11305} 229 | } 230 | ``` 231 | 232 | ```bibtex 233 | @misc{hu2021lora, 234 | title = {LoRA: Low-Rank Adaptation of Large Language Models}, 235 | author = {Hu, Edward and Shen, Yelong and Wallis, Phil and Allen-Zhu, Zeyuan and Li, Yuanzhi and Wang, Lu and Chen, Weizhu}, 236 | year = {2021}, 237 | eprint = {2106.09685}, 238 | archivePrefix = {arXiv}, 239 | primaryClass = {cs.CL} 240 | } 241 | ``` 242 | 243 | ```bibtex 244 | @article{bolya2023tomesd, 245 | title = {Token Merging for Fast Stable Diffusion}, 246 | author = {Bolya, Daniel and Hoffman, Judy}, 247 | journal = {arXiv}, 248 | url = {https://arxiv.org/abs/2303.17604}, 249 | year = {2023} 250 | } 251 | ``` 252 | 253 | ## Reference 254 | - [DreamBooth in diffusers](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) 255 | - [DreamBooth in ShivamShrirao](https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth) 256 | - [Data from custom-diffusion](https://github.com/adobe-research/custom-diffusion#getting-started) 257 | 258 | ## TODO 259 | - [x] Training 260 | - [x] Inference 261 | - [x] Scaling spectral shifts 262 | - [x] Support Single Image Editing 263 | - [ ] Support multiple spectral shifts (Section 3.2) 264 | - [ ] Cut-Mix-Unmix (Section 3.3) 265 | - [ ] SVDiff + LoRA 266 | -------------------------------------------------------------------------------- /assets/car-result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/car-result.png -------------------------------------------------------------------------------- /assets/chair-result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/chair-result.png -------------------------------------------------------------------------------- /assets/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/dog.png -------------------------------------------------------------------------------- /assets/kasumi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/kasumi.png -------------------------------------------------------------------------------- /assets/kumamon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/kumamon.png -------------------------------------------------------------------------------- /assets/lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/lora.png 
--------------------------------------------------------------------------------
/assets/scale-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/scale-result.png
--------------------------------------------------------------------------------
/assets/scale.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/scale.png
--------------------------------------------------------------------------------
/assets/svdiff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/svdiff.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from tqdm import tqdm
4 | import random
5 | import torch
6 | import huggingface_hub
7 | from transformers import CLIPTextModel
8 | from diffusers import StableDiffusionPipeline
9 | from diffusers.utils import is_xformers_available
10 | from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING, image_grid
11 | 
12 | 
13 | def parse_args():
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument("--pretrained_model_name_or_path", type=str, help="pretrained model name or path")
16 |     parser.add_argument("--spectral_shifts_ckpt", type=str, help="path to spectral_shifts.safetensors")
17 |     # diffusers config
18 |     parser.add_argument("--prompt", type=str, nargs="?", default="a photo of *s", help="the prompt to render")
19 |     parser.add_argument("--num_inference_steps", type=int, default=50, help="number of sampling steps")
20 |     parser.add_argument("--guidance_scale", type=float, default=7.5, help="unconditional guidance scale")
21 |     parser.add_argument("--num_images_per_prompt", type=int, default=1, help="number of images per prompt")
22 |     parser.add_argument("--height", type=int, default=512, help="image height, in pixel space",)
23 |     parser.add_argument("--width", type=int, default=512, help="image width, in pixel space",)
24 |     parser.add_argument("--seed", type=str, default="random_seed", help="the seed (for reproducible sampling)")
25 |     parser.add_argument("--scheduler_type", type=str, choices=["ddim", "plms", "lms", "euler", "euler_ancestral", "dpm_solver++"], default="ddim", help="diffusion scheduler type")
26 |     parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
27 |     parser.add_argument("--spectral_shifts_scale", type=float, default=1.0, help="scaling spectral shifts")
28 |     parser.add_argument("--fp16", action="store_true", help="fp16 inference")
29 |     args = parser.parse_args()
30 |     return args
31 | 
32 | 
33 | def load_text_encoder(pretrained_model_name_or_path, spectral_shifts_ckpt, device, fp16=False):
34 |     if os.path.isdir(spectral_shifts_ckpt):
35 |         spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts_te.safetensors")
36 |     elif not os.path.exists(spectral_shifts_ckpt):
37 |         # download from hub
38 |         hf_hub_kwargs = {}  # no extra kwargs are passed to hf_hub_download here
39 |         try:
40 |             spectral_shifts_ckpt = 
huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts_te.safetensors", **hf_hub_kwargs) 41 | except huggingface_hub.utils.EntryNotFoundError: 42 | return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device) 43 | if not os.path.exists(spectral_shifts_ckpt): 44 | return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device) 45 | text_encoder = load_text_encoder_for_svdiff( 46 | pretrained_model_name_or_path=pretrained_model_name_or_path, 47 | spectral_shifts_ckpt=spectral_shifts_ckpt, 48 | subfolder="text_encoder", 49 | ) 50 | # first perform svd and cache 51 | for module in text_encoder.modules(): 52 | if hasattr(module, "perform_svd"): 53 | module.perform_svd() 54 | if fp16: 55 | text_encoder = text_encoder.to(device, dtype=torch.float16) 56 | return text_encoder 57 | 58 | 59 | 60 | def main(): 61 | args = parse_args() 62 | device = "cuda" if torch.cuda.is_available() else "cpu" 63 | print(f"device: {device}") 64 | # load unet 65 | unet = load_unet_for_svdiff(args.pretrained_model_name_or_path, spectral_shifts_ckpt=args.spectral_shifts_ckpt, subfolder="unet") 66 | unet = unet.to(device) 67 | # first perform svd and cache 68 | for module in unet.modules(): 69 | if hasattr(module, "perform_svd"): 70 | module.perform_svd() 71 | if args.fp16: 72 | unet = unet.to(device, dtype=torch.float16) 73 | text_encoder = load_text_encoder( 74 | pretrained_model_name_or_path=args.pretrained_model_name_or_path, 75 | spectral_shifts_ckpt=args.spectral_shifts_ckpt, 76 | fp16=args.fp16, 77 | device=device 78 | ) 79 | 80 | # load pipe 81 | pipe = StableDiffusionPipeline.from_pretrained( 82 | args.pretrained_model_name_or_path, 83 | unet=unet, 84 | text_encoder=text_encoder, 85 | requires_safety_checker=False, 86 | safety_checker=None, 87 | feature_extractor=None, 88 | scheduler=SCHEDULER_MAPPING[args.scheduler_type].from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"), 89 | torch_dtype=torch.float16 if args.fp16 else None, 90 | ) 91 | if args.enable_xformers_memory_efficient_attention: 92 | assert is_xformers_available() 93 | pipe.enable_xformers_memory_efficient_attention() 94 | print("Using xformers!") 95 | try: 96 | import tomesd 97 | tomesd.apply_patch(pipe, ratio=0.5) 98 | print("Using tomesd!") 99 | except: 100 | pass 101 | pipe = pipe.to(device) 102 | print("loaded pipeline") 103 | # run! 
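# NOTE: the block below applies --spectral_shifts_scale: every SVDiff layer exposes set_scale(), so we walk the UNet modules (and the text encoder modules, when the SVDiff text encoder is loaded) and set the requested scale before sampling.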
104 | if pipe.unet.conv_out.scale != args.spectral_shifts_scale: 105 | for module in pipe.unet.modules(): 106 | if hasattr(module, "set_scale"): 107 | module.set_scale(scale=args.spectral_shifts_scale) 108 | if not isinstance(pipe.text_encoder, CLIPTextModel): 109 | for module in pipe.text_encoder.modules(): 110 | if hasattr(module, "set_scale"): 111 | module.set_scale(scale=args.spectral_shifts_scale) 112 | 113 | print(f"Set spectral_shifts_scale to {args.spectral_shifts_scale}!") 114 | 115 | if args.seed == "random_seed": 116 | random.seed() 117 | seed = random.randint(0, 2**32) 118 | else: 119 | seed = int(args.seed) 120 | generator = torch.Generator(device=device).manual_seed(seed) 121 | print(f"seed: {seed}") 122 | prompts = args.prompt.split("::") 123 | all_images = [] 124 | for prompt in tqdm(prompts): 125 | with torch.autocast(device), torch.inference_mode(): 126 | images = pipe( 127 | prompt, 128 | num_inference_steps=args.num_inference_steps, 129 | guidance_scale=args.guidance_scale, 130 | generator=generator, 131 | num_images_per_prompt=args.num_images_per_prompt, 132 | height=args.height, 133 | width=args.width, 134 | ).images 135 | all_images.extend(images) 136 | grid_image = image_grid(all_images, len(prompts), args.num_images_per_prompt) 137 | grid_image.save("grid.png") 138 | print("DONE! See `grid.png` for the results!") 139 | 140 | 141 | if __name__ == '__main__': 142 | main() 143 | 144 | -------------------------------------------------------------------------------- /paper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/paper.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.14.0 2 | accelerate 3 | torchvision 4 | safetensors 5 | transformers>=4.25.1, <=4.27.3 6 | ftfy 7 | tensorboard 8 | Jinja2 9 | einops 10 | wandb -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | 4 | setup( 5 | name="svdiff-pytorch", 6 | version="0.2.0", 7 | author="Makoto Shing", 8 | url="https://github.com/mkshing/svdiff-pytorch", 9 | description="Implementation of 'SVDiff: Compact Parameter Space for Diffusion Fine-Tuning'", 10 | install_requires=[ 11 | "diffusers==0.14.0", 12 | "accelerate", 13 | "torchvision", 14 | "safetensors", 15 | "transformers>=4.25.1", 16 | "ftfy", 17 | "tensorboard", 18 | "Jinja2", 19 | "einops", 20 | "wandb" 21 | ], 22 | packages=find_packages(exclude=("examples", "build")), 23 | license = 'MIT', 24 | long_description=open("README.md", "r", encoding="utf-8").read(), 25 | long_description_content_type="text/markdown", 26 | ) -------------------------------------------------------------------------------- /svdiff_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from svdiff_pytorch.diffusers_models.unet_2d_condition import UNet2DConditionModel as UNet2DConditionModelForSVDiff 2 | from svdiff_pytorch.transformers_models_clip.modeling_clip import CLIPTextModel as CLIPTextModelForSVDiff 3 | from svdiff_pytorch.utils import load_unet_for_svdiff, load_text_encoder_for_svdiff, image_grid, SCHEDULER_MAPPING 4 | from 
svdiff_pytorch.pipeline_stable_diffusion_ddim_inversion import StableDiffusionPipelineWithDDIMInversion 5 | -------------------------------------------------------------------------------- /svdiff_pytorch/diffusers_models/__init__.py: -------------------------------------------------------------------------------- 1 | # all files in this folder were taken from https://github.com/huggingface/diffusers/tree/main/src/diffusers/models 2 | # so, these files follow the LICENSE of diffusers -------------------------------------------------------------------------------- /svdiff_pytorch/diffusers_models/attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | import math 15 | from typing import Callable, Optional 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch import nn 20 | 21 | from diffusers.utils.import_utils import is_xformers_available 22 | from svdiff_pytorch.diffusers_models.cross_attention import CrossAttention 23 | from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings 24 | from svdiff_pytorch.layers import SVDLinear, SVDGroupNorm, SVDLayerNorm 25 | 26 | 27 | if is_xformers_available(): 28 | import xformers 29 | import xformers.ops 30 | else: 31 | xformers = None 32 | 33 | 34 | class AttentionBlock(nn.Module): 35 | """ 36 | An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted 37 | to the N-d case. 38 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 39 | Uses three q, k, v linear layers to compute attention. 40 | 41 | Parameters: 42 | channels (`int`): The number of channels in the input and output. 43 | num_head_channels (`int`, *optional*): 44 | The number of channels in each head. If None, then `num_heads` = 1. 45 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. 46 | rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. 47 | eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. 48 | """ 49 | 50 | # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. 
Do not use it anymore 51 | 52 | def __init__( 53 | self, 54 | channels: int, 55 | num_head_channels: Optional[int] = None, 56 | norm_num_groups: int = 32, 57 | rescale_output_factor: float = 1.0, 58 | eps: float = 1e-5, 59 | ): 60 | super().__init__() 61 | self.channels = channels 62 | 63 | self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 64 | self.num_head_size = num_head_channels 65 | self.group_norm = SVDGroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) 66 | 67 | # define q,k,v as linear layers 68 | self.query = SVDLinear(channels, channels) 69 | self.key = SVDLinear(channels, channels) 70 | self.value = SVDLinear(channels, channels) 71 | 72 | self.rescale_output_factor = rescale_output_factor 73 | self.proj_attn = SVDLinear(channels, channels, 1) 74 | 75 | self._use_memory_efficient_attention_xformers = False 76 | self._attention_op = None 77 | 78 | def reshape_heads_to_batch_dim(self, tensor): 79 | batch_size, seq_len, dim = tensor.shape 80 | head_size = self.num_heads 81 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 82 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) 83 | return tensor 84 | 85 | def reshape_batch_dim_to_heads(self, tensor): 86 | batch_size, seq_len, dim = tensor.shape 87 | head_size = self.num_heads 88 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 89 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 90 | return tensor 91 | 92 | def set_use_memory_efficient_attention_xformers( 93 | self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None 94 | ): 95 | if use_memory_efficient_attention_xformers: 96 | if not is_xformers_available(): 97 | raise ModuleNotFoundError( 98 | ( 99 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 100 | " xformers" 101 | ), 102 | name="xformers", 103 | ) 104 | elif not torch.cuda.is_available(): 105 | raise ValueError( 106 | "torch.cuda.is_available() should be True but is False. 
xformers' memory efficient attention is" 107 | " only available for GPU " 108 | ) 109 | else: 110 | try: 111 | # Make sure we can run the memory efficient attention 112 | _ = xformers.ops.memory_efficient_attention( 113 | torch.randn((1, 2, 40), device="cuda"), 114 | torch.randn((1, 2, 40), device="cuda"), 115 | torch.randn((1, 2, 40), device="cuda"), 116 | ) 117 | except Exception as e: 118 | raise e 119 | self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers 120 | self._attention_op = attention_op 121 | 122 | def forward(self, hidden_states): 123 | residual = hidden_states 124 | batch, channel, height, width = hidden_states.shape 125 | 126 | # norm 127 | hidden_states = self.group_norm(hidden_states) 128 | 129 | hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) 130 | 131 | # proj to q, k, v 132 | query_proj = self.query(hidden_states) 133 | key_proj = self.key(hidden_states) 134 | value_proj = self.value(hidden_states) 135 | 136 | scale = 1 / math.sqrt(self.channels / self.num_heads) 137 | 138 | query_proj = self.reshape_heads_to_batch_dim(query_proj) 139 | key_proj = self.reshape_heads_to_batch_dim(key_proj) 140 | value_proj = self.reshape_heads_to_batch_dim(value_proj) 141 | 142 | if self._use_memory_efficient_attention_xformers: 143 | # Memory efficient attention 144 | hidden_states = xformers.ops.memory_efficient_attention( 145 | query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op 146 | ) 147 | hidden_states = hidden_states.to(query_proj.dtype) 148 | else: 149 | attention_scores = torch.baddbmm( 150 | torch.empty( 151 | query_proj.shape[0], 152 | query_proj.shape[1], 153 | key_proj.shape[1], 154 | dtype=query_proj.dtype, 155 | device=query_proj.device, 156 | ), 157 | query_proj, 158 | key_proj.transpose(-1, -2), 159 | beta=0, 160 | alpha=scale, 161 | ) 162 | attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) 163 | hidden_states = torch.bmm(attention_probs, value_proj) 164 | 165 | # reshape hidden_states 166 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 167 | 168 | # compute next hidden_states 169 | hidden_states = self.proj_attn(hidden_states) 170 | 171 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) 172 | 173 | # res connect and rescale 174 | hidden_states = (hidden_states + residual) / self.rescale_output_factor 175 | return hidden_states 176 | 177 | 178 | class BasicTransformerBlock(nn.Module): 179 | r""" 180 | A basic Transformer block. 181 | 182 | Parameters: 183 | dim (`int`): The number of channels in the input and output. 184 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 185 | attention_head_dim (`int`): The number of channels in each head. 186 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 187 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 188 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 189 | num_embeds_ada_norm (: 190 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 191 | attention_bias (: 192 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 
193 | """ 194 | 195 | def __init__( 196 | self, 197 | dim: int, 198 | num_attention_heads: int, 199 | attention_head_dim: int, 200 | dropout=0.0, 201 | cross_attention_dim: Optional[int] = None, 202 | activation_fn: str = "geglu", 203 | num_embeds_ada_norm: Optional[int] = None, 204 | attention_bias: bool = False, 205 | only_cross_attention: bool = False, 206 | upcast_attention: bool = False, 207 | norm_elementwise_affine: bool = True, 208 | norm_type: str = "layer_norm", 209 | final_dropout: bool = False, 210 | ): 211 | super().__init__() 212 | self.only_cross_attention = only_cross_attention 213 | 214 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 215 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 216 | 217 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 218 | raise ValueError( 219 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 220 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 221 | ) 222 | 223 | # 1. Self-Attn 224 | self.attn1 = CrossAttention( 225 | query_dim=dim, 226 | heads=num_attention_heads, 227 | dim_head=attention_head_dim, 228 | dropout=dropout, 229 | bias=attention_bias, 230 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 231 | upcast_attention=upcast_attention, 232 | ) 233 | 234 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout) 235 | 236 | # 2. Cross-Attn 237 | if cross_attention_dim is not None: 238 | self.attn2 = CrossAttention( 239 | query_dim=dim, 240 | cross_attention_dim=cross_attention_dim, 241 | heads=num_attention_heads, 242 | dim_head=attention_head_dim, 243 | dropout=dropout, 244 | bias=attention_bias, 245 | upcast_attention=upcast_attention, 246 | ) # is self-attn if encoder_hidden_states is none 247 | else: 248 | self.attn2 = None 249 | 250 | if self.use_ada_layer_norm: 251 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 252 | elif self.use_ada_layer_norm_zero: 253 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 254 | else: 255 | self.norm1 = SVDLayerNorm(dim, elementwise_affine=norm_elementwise_affine) 256 | 257 | if cross_attention_dim is not None: 258 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 259 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 260 | # the second cross attention block. 261 | self.norm2 = ( 262 | AdaLayerNorm(dim, num_embeds_ada_norm) 263 | if self.use_ada_layer_norm 264 | else SVDLayerNorm(dim, elementwise_affine=norm_elementwise_affine) 265 | ) 266 | else: 267 | self.norm2 = None 268 | 269 | # 3. Feed-forward 270 | self.norm3 = SVDLayerNorm(dim, elementwise_affine=norm_elementwise_affine) 271 | 272 | def forward( 273 | self, 274 | hidden_states, 275 | encoder_hidden_states=None, 276 | timestep=None, 277 | attention_mask=None, 278 | cross_attention_kwargs=None, 279 | class_labels=None, 280 | ): 281 | if self.use_ada_layer_norm: 282 | norm_hidden_states = self.norm1(hidden_states, timestep) 283 | elif self.use_ada_layer_norm_zero: 284 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 285 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 286 | ) 287 | else: 288 | norm_hidden_states = self.norm1(hidden_states) 289 | 290 | # 1. 
Self-Attention 291 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} 292 | attn_output = self.attn1( 293 | norm_hidden_states, 294 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 295 | attention_mask=attention_mask, 296 | **cross_attention_kwargs, 297 | ) 298 | if self.use_ada_layer_norm_zero: 299 | attn_output = gate_msa.unsqueeze(1) * attn_output 300 | hidden_states = attn_output + hidden_states 301 | 302 | if self.attn2 is not None: 303 | norm_hidden_states = ( 304 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) 305 | ) 306 | 307 | # 2. Cross-Attention 308 | attn_output = self.attn2( 309 | norm_hidden_states, 310 | encoder_hidden_states=encoder_hidden_states, 311 | attention_mask=attention_mask, 312 | **cross_attention_kwargs, 313 | ) 314 | hidden_states = attn_output + hidden_states 315 | 316 | # 3. Feed-forward 317 | norm_hidden_states = self.norm3(hidden_states) 318 | 319 | if self.use_ada_layer_norm_zero: 320 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 321 | 322 | ff_output = self.ff(norm_hidden_states) 323 | 324 | if self.use_ada_layer_norm_zero: 325 | ff_output = gate_mlp.unsqueeze(1) * ff_output 326 | 327 | hidden_states = ff_output + hidden_states 328 | 329 | return hidden_states 330 | 331 | 332 | class FeedForward(nn.Module): 333 | r""" 334 | A feed-forward layer. 335 | 336 | Parameters: 337 | dim (`int`): The number of channels in the input. 338 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 339 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 340 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 341 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 342 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 343 | """ 344 | 345 | def __init__( 346 | self, 347 | dim: int, 348 | dim_out: Optional[int] = None, 349 | mult: int = 4, 350 | dropout: float = 0.0, 351 | activation_fn: str = "geglu", 352 | final_dropout: bool = False, 353 | ): 354 | super().__init__() 355 | inner_dim = int(dim * mult) 356 | dim_out = dim_out if dim_out is not None else dim 357 | 358 | if activation_fn == "gelu": 359 | act_fn = GELU(dim, inner_dim) 360 | if activation_fn == "gelu-approximate": 361 | act_fn = GELU(dim, inner_dim, approximate="tanh") 362 | elif activation_fn == "geglu": 363 | act_fn = GEGLU(dim, inner_dim) 364 | elif activation_fn == "geglu-approximate": 365 | act_fn = ApproximateGELU(dim, inner_dim) 366 | 367 | self.net = nn.ModuleList([]) 368 | # project in 369 | self.net.append(act_fn) 370 | # project dropout 371 | self.net.append(nn.Dropout(dropout)) 372 | # project out 373 | self.net.append(SVDLinear(inner_dim, dim_out)) 374 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 375 | if final_dropout: 376 | self.net.append(nn.Dropout(dropout)) 377 | 378 | def forward(self, hidden_states): 379 | for module in self.net: 380 | hidden_states = module(hidden_states) 381 | return hidden_states 382 | 383 | 384 | class GELU(nn.Module): 385 | r""" 386 | GELU activation function with tanh approximation support with `approximate="tanh"`. 
387 | """ 388 | 389 | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"): 390 | super().__init__() 391 | self.proj = SVDLinear(dim_in, dim_out) 392 | self.approximate = approximate 393 | 394 | def gelu(self, gate): 395 | if gate.device.type != "mps": 396 | return F.gelu(gate, approximate=self.approximate) 397 | # mps: gelu is not implemented for float16 398 | return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype) 399 | 400 | def forward(self, hidden_states): 401 | hidden_states = self.proj(hidden_states) 402 | hidden_states = self.gelu(hidden_states) 403 | return hidden_states 404 | 405 | 406 | class GEGLU(nn.Module): 407 | r""" 408 | A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. 409 | 410 | Parameters: 411 | dim_in (`int`): The number of channels in the input. 412 | dim_out (`int`): The number of channels in the output. 413 | """ 414 | 415 | def __init__(self, dim_in: int, dim_out: int): 416 | super().__init__() 417 | self.proj = SVDLinear(dim_in, dim_out * 2) 418 | 419 | def gelu(self, gate): 420 | if gate.device.type != "mps": 421 | return F.gelu(gate) 422 | # mps: gelu is not implemented for float16 423 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) 424 | 425 | def forward(self, hidden_states): 426 | hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) 427 | return hidden_states * self.gelu(gate) 428 | 429 | 430 | class ApproximateGELU(nn.Module): 431 | """ 432 | The approximate form of Gaussian Error Linear Unit (GELU) 433 | 434 | For more details, see section 2: https://arxiv.org/abs/1606.08415 435 | """ 436 | 437 | def __init__(self, dim_in: int, dim_out: int): 438 | super().__init__() 439 | self.proj = SVDLinear(dim_in, dim_out) 440 | 441 | def forward(self, x): 442 | x = self.proj(x) 443 | return x * torch.sigmoid(1.702 * x) 444 | 445 | 446 | class AdaLayerNorm(nn.Module): 447 | """ 448 | Norm layer modified to incorporate timestep embeddings. 449 | """ 450 | 451 | def __init__(self, embedding_dim, num_embeddings): 452 | super().__init__() 453 | self.emb = nn.Embedding(num_embeddings, embedding_dim) 454 | self.silu = nn.SiLU() 455 | self.linear = SVDLinear(embedding_dim, embedding_dim * 2) 456 | self.norm = SVDLayerNorm(embedding_dim, elementwise_affine=False) 457 | 458 | def forward(self, x, timestep): 459 | emb = self.linear(self.silu(self.emb(timestep))) 460 | scale, shift = torch.chunk(emb, 2) 461 | x = self.norm(x) * (1 + scale) + shift 462 | return x 463 | 464 | 465 | class AdaLayerNormZero(nn.Module): 466 | """ 467 | Norm layer adaptive layer norm zero (adaLN-Zero). 
468 | """ 469 | 470 | def __init__(self, embedding_dim, num_embeddings): 471 | super().__init__() 472 | 473 | self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim) 474 | 475 | self.silu = nn.SiLU() 476 | self.linear = SVDLinear(embedding_dim, 6 * embedding_dim, bias=True) 477 | self.norm = SVDLayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6) 478 | 479 | def forward(self, x, timestep, class_labels, hidden_dtype=None): 480 | emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype))) 481 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1) 482 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None] 483 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp 484 | 485 | 486 | class AdaGroupNorm(nn.Module): 487 | """ 488 | GroupNorm layer modified to incorporate timestep embeddings. 489 | """ 490 | 491 | def __init__( 492 | self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5 493 | ): 494 | super().__init__() 495 | self.num_groups = num_groups 496 | self.eps = eps 497 | self.act = None 498 | if act_fn == "swish": 499 | self.act = lambda x: F.silu(x) 500 | elif act_fn == "mish": 501 | self.act = nn.Mish() 502 | elif act_fn == "silu": 503 | self.act = nn.SiLU() 504 | elif act_fn == "gelu": 505 | self.act = nn.GELU() 506 | 507 | self.linear = SVDLinear(embedding_dim, out_dim * 2) 508 | 509 | def forward(self, x, emb): 510 | if self.act: 511 | emb = self.act(emb) 512 | emb = self.linear(emb) 513 | emb = emb[:, :, None, None] 514 | scale, shift = emb.chunk(2, dim=1) 515 | 516 | x = F.group_norm(x, self.num_groups, eps=self.eps) 517 | x = x * (1 + scale) + shift 518 | return x 519 | -------------------------------------------------------------------------------- /svdiff_pytorch/diffusers_models/cross_attention.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Callable, Optional, Union 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | from diffusers.utils import deprecate, logging 21 | from diffusers.utils.import_utils import is_xformers_available 22 | from svdiff_pytorch.layers import SVDLinear, SVDGroupNorm, SVDLayerNorm 23 | 24 | 25 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 26 | 27 | 28 | if is_xformers_available(): 29 | import xformers 30 | import xformers.ops 31 | else: 32 | xformers = None 33 | 34 | 35 | class CrossAttention(nn.Module): 36 | r""" 37 | A cross attention layer. 38 | 39 | Parameters: 40 | query_dim (`int`): The number of channels in the query. 41 | cross_attention_dim (`int`, *optional*): 42 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. 
43 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. 44 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. 45 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 46 | bias (`bool`, *optional*, defaults to False): 47 | Set to `True` for the query, key, and value linear layers to contain a bias parameter. 48 | """ 49 | 50 | def __init__( 51 | self, 52 | query_dim: int, 53 | cross_attention_dim: Optional[int] = None, 54 | heads: int = 8, 55 | dim_head: int = 64, 56 | dropout: float = 0.0, 57 | bias=False, 58 | upcast_attention: bool = False, 59 | upcast_softmax: bool = False, 60 | cross_attention_norm: bool = False, 61 | added_kv_proj_dim: Optional[int] = None, 62 | norm_num_groups: Optional[int] = None, 63 | processor: Optional["AttnProcessor"] = None, 64 | ): 65 | super().__init__() 66 | inner_dim = dim_head * heads 67 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim 68 | self.upcast_attention = upcast_attention 69 | self.upcast_softmax = upcast_softmax 70 | self.cross_attention_norm = cross_attention_norm 71 | 72 | self.scale = dim_head**-0.5 73 | 74 | self.heads = heads 75 | # for slice_size > 0 the attention score computation 76 | # is split across the batch axis to save memory 77 | # You can set slice_size with `set_attention_slice` 78 | self.sliceable_head_dim = heads 79 | 80 | self.added_kv_proj_dim = added_kv_proj_dim 81 | 82 | if norm_num_groups is not None: 83 | self.group_norm = SVDGroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) 84 | else: 85 | self.group_norm = None 86 | 87 | if cross_attention_norm: 88 | self.norm_cross = SVDLayerNorm(cross_attention_dim) 89 | 90 | self.to_q = SVDLinear(query_dim, inner_dim, bias=bias) 91 | self.to_k = SVDLinear(cross_attention_dim, inner_dim, bias=bias) 92 | self.to_v = SVDLinear(cross_attention_dim, inner_dim, bias=bias) 93 | 94 | if self.added_kv_proj_dim is not None: 95 | self.add_k_proj = SVDLinear(added_kv_proj_dim, cross_attention_dim) 96 | self.add_v_proj = SVDLinear(added_kv_proj_dim, cross_attention_dim) 97 | 98 | self.to_out = nn.ModuleList([]) 99 | self.to_out.append(SVDLinear(inner_dim, query_dim)) 100 | self.to_out.append(nn.Dropout(dropout)) 101 | 102 | # set attention processor 103 | # We use the AttnProcessor2_0 by default when torch2.x is used which uses 104 | # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention 105 | if processor is None: 106 | processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor() 107 | self.set_processor(processor) 108 | 109 | def set_use_memory_efficient_attention_xformers( 110 | self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None 111 | ): 112 | is_lora = hasattr(self, "processor") and isinstance( 113 | self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor) 114 | ) 115 | 116 | if use_memory_efficient_attention_xformers: 117 | if self.added_kv_proj_dim is not None: 118 | # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP 119 | # which uses this type of cross attention ONLY because the attention mask of format 120 | # [0, ..., -10.000, ..., 0, ...,] is not supported 121 | raise NotImplementedError( 122 | "Memory efficient attention with `xformers` is currently not supported when" 123 | " `self.added_kv_proj_dim` is defined." 
124 | ) 125 | elif not is_xformers_available(): 126 | raise ModuleNotFoundError( 127 | ( 128 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install" 129 | " xformers" 130 | ), 131 | name="xformers", 132 | ) 133 | elif not torch.cuda.is_available(): 134 | raise ValueError( 135 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is" 136 | " only available for GPU " 137 | ) 138 | else: 139 | try: 140 | # Make sure we can run the memory efficient attention 141 | _ = xformers.ops.memory_efficient_attention( 142 | torch.randn((1, 2, 40), device="cuda"), 143 | torch.randn((1, 2, 40), device="cuda"), 144 | torch.randn((1, 2, 40), device="cuda"), 145 | ) 146 | except Exception as e: 147 | raise e 148 | 149 | if is_lora: 150 | processor = LoRAXFormersCrossAttnProcessor( 151 | hidden_size=self.processor.hidden_size, 152 | cross_attention_dim=self.processor.cross_attention_dim, 153 | rank=self.processor.rank, 154 | attention_op=attention_op, 155 | ) 156 | processor.load_state_dict(self.processor.state_dict()) 157 | processor.to(self.processor.to_q_lora.up.weight.device) 158 | else: 159 | processor = XFormersCrossAttnProcessor(attention_op=attention_op) 160 | else: 161 | if is_lora: 162 | processor = LoRACrossAttnProcessor( 163 | hidden_size=self.processor.hidden_size, 164 | cross_attention_dim=self.processor.cross_attention_dim, 165 | rank=self.processor.rank, 166 | ) 167 | processor.load_state_dict(self.processor.state_dict()) 168 | processor.to(self.processor.to_q_lora.up.weight.device) 169 | else: 170 | processor = CrossAttnProcessor() 171 | 172 | self.set_processor(processor) 173 | 174 | def set_attention_slice(self, slice_size): 175 | if slice_size is not None and slice_size > self.sliceable_head_dim: 176 | raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") 177 | 178 | if slice_size is not None and self.added_kv_proj_dim is not None: 179 | processor = SlicedAttnAddedKVProcessor(slice_size) 180 | elif slice_size is not None: 181 | processor = SlicedAttnProcessor(slice_size) 182 | elif self.added_kv_proj_dim is not None: 183 | processor = CrossAttnAddedKVProcessor() 184 | else: 185 | processor = CrossAttnProcessor() 186 | 187 | self.set_processor(processor) 188 | 189 | def set_processor(self, processor: "AttnProcessor"): 190 | # if current processor is in `self._modules` and if passed `processor` is not, we need to 191 | # pop `processor` from `self._modules` 192 | if ( 193 | hasattr(self, "processor") 194 | and isinstance(self.processor, torch.nn.Module) 195 | and not isinstance(processor, torch.nn.Module) 196 | ): 197 | logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}") 198 | self._modules.pop("processor") 199 | 200 | self.processor = processor 201 | 202 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): 203 | # The `CrossAttention` class can call different attention processors / attention functions 204 | # here we simply pass along all tensors to the selected processor class 205 | # For standard processors that are defined here, `**cross_attention_kwargs` is empty 206 | return self.processor( 207 | self, 208 | hidden_states, 209 | encoder_hidden_states=encoder_hidden_states, 210 | attention_mask=attention_mask, 211 | **cross_attention_kwargs, 212 | ) 213 | 214 | def batch_to_head_dim(self, tensor): 215 | head_size = self.heads 216 | batch_size, seq_len, dim = 
tensor.shape 217 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) 218 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) 219 | return tensor 220 | 221 | def head_to_batch_dim(self, tensor): 222 | head_size = self.heads 223 | batch_size, seq_len, dim = tensor.shape 224 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) 225 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) 226 | return tensor 227 | 228 | def get_attention_scores(self, query, key, attention_mask=None): 229 | dtype = query.dtype 230 | if self.upcast_attention: 231 | query = query.float() 232 | key = key.float() 233 | 234 | if attention_mask is None: 235 | baddbmm_input = torch.empty( 236 | query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device 237 | ) 238 | beta = 0 239 | else: 240 | baddbmm_input = attention_mask 241 | beta = 1 242 | 243 | attention_scores = torch.baddbmm( 244 | baddbmm_input, 245 | query, 246 | key.transpose(-1, -2), 247 | beta=beta, 248 | alpha=self.scale, 249 | ) 250 | 251 | if self.upcast_softmax: 252 | attention_scores = attention_scores.float() 253 | 254 | attention_probs = attention_scores.softmax(dim=-1) 255 | attention_probs = attention_probs.to(dtype) 256 | 257 | return attention_probs 258 | 259 | def prepare_attention_mask(self, attention_mask, target_length, batch_size=None): 260 | if batch_size is None: 261 | deprecate( 262 | "batch_size=None", 263 | "0.0.15", 264 | message=( 265 | "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect" 266 | " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to" 267 | " `prepare_attention_mask` when preparing the attention_mask." 268 | ), 269 | ) 270 | batch_size = 1 271 | 272 | head_size = self.heads 273 | if attention_mask is None: 274 | return attention_mask 275 | 276 | if attention_mask.shape[-1] != target_length: 277 | if attention_mask.device.type == "mps": 278 | # HACK: MPS: Does not support padding by greater than dimension of input tensor. 279 | # Instead, we can manually construct the padding tensor. 
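# The zero block constructed below keeps the mask's leading dimensions and is concatenated along the
# last (key-length) axis, mirroring the zero padding that F.pad applies in the non-MPS branch.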
280 | padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length) 281 | padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device) 282 | attention_mask = torch.cat([attention_mask, padding], dim=2) 283 | else: 284 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) 285 | 286 | if attention_mask.shape[0] < batch_size * head_size: 287 | attention_mask = attention_mask.repeat_interleave(head_size, dim=0) 288 | return attention_mask 289 | 290 | 291 | class CrossAttnProcessor: 292 | def __call__( 293 | self, 294 | attn: CrossAttention, 295 | hidden_states, 296 | encoder_hidden_states=None, 297 | attention_mask=None, 298 | ): 299 | batch_size, sequence_length, _ = hidden_states.shape 300 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 301 | query = attn.to_q(hidden_states) 302 | 303 | if encoder_hidden_states is None: 304 | encoder_hidden_states = hidden_states 305 | elif attn.cross_attention_norm: 306 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states) 307 | 308 | key = attn.to_k(encoder_hidden_states) 309 | value = attn.to_v(encoder_hidden_states) 310 | 311 | query = attn.head_to_batch_dim(query) 312 | key = attn.head_to_batch_dim(key) 313 | value = attn.head_to_batch_dim(value) 314 | 315 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 316 | hidden_states = torch.bmm(attention_probs, value) 317 | hidden_states = attn.batch_to_head_dim(hidden_states) 318 | 319 | # linear proj 320 | hidden_states = attn.to_out[0](hidden_states) 321 | # dropout 322 | hidden_states = attn.to_out[1](hidden_states) 323 | 324 | return hidden_states 325 | 326 | 327 | class LoRALinearLayer(nn.Module): 328 | def __init__(self, in_features, out_features, rank=4): 329 | super().__init__() 330 | 331 | if rank > min(in_features, out_features): 332 | raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}") 333 | 334 | self.down = SVDLinear(in_features, rank, bias=False) 335 | self.up = SVDLinear(rank, out_features, bias=False) 336 | 337 | nn.init.normal_(self.down.weight, std=1 / rank) 338 | nn.init.zeros_(self.up.weight) 339 | 340 | def forward(self, hidden_states): 341 | orig_dtype = hidden_states.dtype 342 | dtype = self.down.weight.dtype 343 | 344 | down_hidden_states = self.down(hidden_states.to(dtype)) 345 | up_hidden_states = self.up(down_hidden_states) 346 | 347 | return up_hidden_states.to(orig_dtype) 348 | 349 | 350 | class LoRACrossAttnProcessor(nn.Module): 351 | def __init__(self, hidden_size, cross_attention_dim=None, rank=4): 352 | super().__init__() 353 | 354 | self.hidden_size = hidden_size 355 | self.cross_attention_dim = cross_attention_dim 356 | self.rank = rank 357 | 358 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) 359 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) 360 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) 361 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) 362 | 363 | def __call__( 364 | self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 365 | ): 366 | batch_size, sequence_length, _ = hidden_states.shape 367 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 368 | 369 | query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) 370 | query = 
attn.head_to_batch_dim(query) 371 | 372 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 373 | 374 | key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) 375 | value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) 376 | 377 | key = attn.head_to_batch_dim(key) 378 | value = attn.head_to_batch_dim(value) 379 | 380 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 381 | hidden_states = torch.bmm(attention_probs, value) 382 | hidden_states = attn.batch_to_head_dim(hidden_states) 383 | 384 | # linear proj 385 | hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) 386 | # dropout 387 | hidden_states = attn.to_out[1](hidden_states) 388 | 389 | return hidden_states 390 | 391 | 392 | class CrossAttnAddedKVProcessor: 393 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 394 | residual = hidden_states 395 | hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) 396 | batch_size, sequence_length, _ = hidden_states.shape 397 | encoder_hidden_states = encoder_hidden_states.transpose(1, 2) 398 | 399 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 400 | 401 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 402 | 403 | query = attn.to_q(hidden_states) 404 | query = attn.head_to_batch_dim(query) 405 | 406 | key = attn.to_k(hidden_states) 407 | value = attn.to_v(hidden_states) 408 | key = attn.head_to_batch_dim(key) 409 | value = attn.head_to_batch_dim(value) 410 | 411 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 412 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 413 | encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) 414 | encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) 415 | 416 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) 417 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) 418 | 419 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 420 | hidden_states = torch.bmm(attention_probs, value) 421 | hidden_states = attn.batch_to_head_dim(hidden_states) 422 | 423 | # linear proj 424 | hidden_states = attn.to_out[0](hidden_states) 425 | # dropout 426 | hidden_states = attn.to_out[1](hidden_states) 427 | 428 | hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) 429 | hidden_states = hidden_states + residual 430 | 431 | return hidden_states 432 | 433 | 434 | class XFormersCrossAttnProcessor: 435 | def __init__(self, attention_op: Optional[Callable] = None): 436 | self.attention_op = attention_op 437 | 438 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 439 | batch_size, sequence_length, _ = hidden_states.shape 440 | 441 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 442 | 443 | query = attn.to_q(hidden_states) 444 | 445 | if encoder_hidden_states is None: 446 | encoder_hidden_states = hidden_states 447 | elif attn.cross_attention_norm: 448 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states) 449 | 450 | key = attn.to_k(encoder_hidden_states) 451 | value = attn.to_v(encoder_hidden_states) 452 | 453 | query = 
attn.head_to_batch_dim(query).contiguous() 454 | key = attn.head_to_batch_dim(key).contiguous() 455 | value = attn.head_to_batch_dim(value).contiguous() 456 | 457 | hidden_states = xformers.ops.memory_efficient_attention( 458 | query, key, value, attn_bias=attention_mask, op=self.attention_op 459 | ) 460 | hidden_states = hidden_states.to(query.dtype) 461 | hidden_states = attn.batch_to_head_dim(hidden_states) 462 | 463 | # linear proj 464 | hidden_states = attn.to_out[0](hidden_states) 465 | # dropout 466 | hidden_states = attn.to_out[1](hidden_states) 467 | return hidden_states 468 | 469 | 470 | class AttnProcessor2_0: 471 | def __init__(self): 472 | if not hasattr(F, "scaled_dot_product_attention"): 473 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 474 | 475 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 476 | batch_size, sequence_length, inner_dim = hidden_states.shape 477 | 478 | if attention_mask is not None: 479 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 480 | # scaled_dot_product_attention expects attention_mask shape to be 481 | # (batch, heads, source_length, target_length) 482 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 483 | 484 | query = attn.to_q(hidden_states) 485 | 486 | if encoder_hidden_states is None: 487 | encoder_hidden_states = hidden_states 488 | elif attn.cross_attention_norm: 489 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states) 490 | 491 | key = attn.to_k(encoder_hidden_states) 492 | value = attn.to_v(encoder_hidden_states) 493 | 494 | head_dim = inner_dim // attn.heads 495 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 496 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 497 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 498 | 499 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 500 | hidden_states = F.scaled_dot_product_attention( 501 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 502 | ) 503 | 504 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 505 | hidden_states = hidden_states.to(query.dtype) 506 | 507 | # linear proj 508 | hidden_states = attn.to_out[0](hidden_states) 509 | # dropout 510 | hidden_states = attn.to_out[1](hidden_states) 511 | return hidden_states 512 | 513 | 514 | class LoRAXFormersCrossAttnProcessor(nn.Module): 515 | def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None): 516 | super().__init__() 517 | 518 | self.hidden_size = hidden_size 519 | self.cross_attention_dim = cross_attention_dim 520 | self.rank = rank 521 | self.attention_op = attention_op 522 | 523 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) 524 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) 525 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) 526 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) 527 | 528 | def __call__( 529 | self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0 530 | ): 531 | batch_size, sequence_length, _ = hidden_states.shape 532 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 533 | 534 | query = 
attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states) 535 | query = attn.head_to_batch_dim(query).contiguous() 536 | 537 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 538 | 539 | key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states) 540 | value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states) 541 | 542 | key = attn.head_to_batch_dim(key).contiguous() 543 | value = attn.head_to_batch_dim(value).contiguous() 544 | 545 | hidden_states = xformers.ops.memory_efficient_attention( 546 | query, key, value, attn_bias=attention_mask, op=self.attention_op 547 | ) 548 | hidden_states = attn.batch_to_head_dim(hidden_states) 549 | 550 | # linear proj 551 | hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states) 552 | # dropout 553 | hidden_states = attn.to_out[1](hidden_states) 554 | 555 | return hidden_states 556 | 557 | 558 | class SlicedAttnProcessor: 559 | def __init__(self, slice_size): 560 | self.slice_size = slice_size 561 | 562 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 563 | batch_size, sequence_length, _ = hidden_states.shape 564 | 565 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 566 | 567 | query = attn.to_q(hidden_states) 568 | dim = query.shape[-1] 569 | query = attn.head_to_batch_dim(query) 570 | 571 | if encoder_hidden_states is None: 572 | encoder_hidden_states = hidden_states 573 | elif attn.cross_attention_norm: 574 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states) 575 | 576 | key = attn.to_k(encoder_hidden_states) 577 | value = attn.to_v(encoder_hidden_states) 578 | key = attn.head_to_batch_dim(key) 579 | value = attn.head_to_batch_dim(value) 580 | 581 | batch_size_attention = query.shape[0] 582 | hidden_states = torch.zeros( 583 | (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype 584 | ) 585 | 586 | for i in range(hidden_states.shape[0] // self.slice_size): 587 | start_idx = i * self.slice_size 588 | end_idx = (i + 1) * self.slice_size 589 | 590 | query_slice = query[start_idx:end_idx] 591 | key_slice = key[start_idx:end_idx] 592 | attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None 593 | 594 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) 595 | 596 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 597 | 598 | hidden_states[start_idx:end_idx] = attn_slice 599 | 600 | hidden_states = attn.batch_to_head_dim(hidden_states) 601 | 602 | # linear proj 603 | hidden_states = attn.to_out[0](hidden_states) 604 | # dropout 605 | hidden_states = attn.to_out[1](hidden_states) 606 | 607 | return hidden_states 608 | 609 | 610 | class SlicedAttnAddedKVProcessor: 611 | def __init__(self, slice_size): 612 | self.slice_size = slice_size 613 | 614 | def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None): 615 | residual = hidden_states 616 | hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2) 617 | encoder_hidden_states = encoder_hidden_states.transpose(1, 2) 618 | 619 | batch_size, sequence_length, _ = hidden_states.shape 620 | 621 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 622 | 623 | hidden_states = 
attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 624 | 625 | query = attn.to_q(hidden_states) 626 | dim = query.shape[-1] 627 | query = attn.head_to_batch_dim(query) 628 | 629 | key = attn.to_k(hidden_states) 630 | value = attn.to_v(hidden_states) 631 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) 632 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) 633 | 634 | key = attn.head_to_batch_dim(key) 635 | value = attn.head_to_batch_dim(value) 636 | encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj) 637 | encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj) 638 | 639 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=1) 640 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=1) 641 | 642 | batch_size_attention = query.shape[0] 643 | hidden_states = torch.zeros( 644 | (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype 645 | ) 646 | 647 | for i in range(hidden_states.shape[0] // self.slice_size): 648 | start_idx = i * self.slice_size 649 | end_idx = (i + 1) * self.slice_size 650 | 651 | query_slice = query[start_idx:end_idx] 652 | key_slice = key[start_idx:end_idx] 653 | attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None 654 | 655 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice) 656 | 657 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) 658 | 659 | hidden_states[start_idx:end_idx] = attn_slice 660 | 661 | hidden_states = attn.batch_to_head_dim(hidden_states) 662 | 663 | # linear proj 664 | hidden_states = attn.to_out[0](hidden_states) 665 | # dropout 666 | hidden_states = attn.to_out[1](hidden_states) 667 | 668 | hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape) 669 | hidden_states = hidden_states + residual 670 | 671 | return hidden_states 672 | 673 | 674 | AttnProcessor = Union[ 675 | CrossAttnProcessor, 676 | XFormersCrossAttnProcessor, 677 | SlicedAttnProcessor, 678 | CrossAttnAddedKVProcessor, 679 | SlicedAttnAddedKVProcessor, 680 | LoRACrossAttnProcessor, 681 | LoRAXFormersCrossAttnProcessor, 682 | ] 683 | -------------------------------------------------------------------------------- /svdiff_pytorch/diffusers_models/dual_transformer_2d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from typing import Optional 15 | 16 | from torch import nn 17 | 18 | from svdiff_pytorch.diffusers_models.transformer_2d import Transformer2DModel, Transformer2DModelOutput 19 | 20 | 21 | class DualTransformer2DModel(nn.Module): 22 | """ 23 | Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. 
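    Each of the two transformers attends to its own slice of `encoder_hidden_states` (split according to
    `condition_lengths`), and their residual outputs are blended with `mix_ratio` before being added back to the
    input hidden states.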
24 | 25 | Parameters: 26 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 27 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 28 | in_channels (`int`, *optional*): 29 | Pass if the input is continuous. The number of channels in the input and output. 30 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 31 | dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. 32 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. 33 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. 34 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See 35 | `ImagePositionalEmbeddings`. 36 | num_vector_embeds (`int`, *optional*): 37 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. 38 | Includes the class for the masked latent pixel. 39 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 40 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. 41 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used 42 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for 43 | up to but not more than steps than `num_embeds_ada_norm`. 44 | attention_bias (`bool`, *optional*): 45 | Configure if the TransformerBlocks' attention should contain a bias parameter. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | num_attention_heads: int = 16, 51 | attention_head_dim: int = 88, 52 | in_channels: Optional[int] = None, 53 | num_layers: int = 1, 54 | dropout: float = 0.0, 55 | norm_num_groups: int = 32, 56 | cross_attention_dim: Optional[int] = None, 57 | attention_bias: bool = False, 58 | sample_size: Optional[int] = None, 59 | num_vector_embeds: Optional[int] = None, 60 | activation_fn: str = "geglu", 61 | num_embeds_ada_norm: Optional[int] = None, 62 | ): 63 | super().__init__() 64 | self.transformers = nn.ModuleList( 65 | [ 66 | Transformer2DModel( 67 | num_attention_heads=num_attention_heads, 68 | attention_head_dim=attention_head_dim, 69 | in_channels=in_channels, 70 | num_layers=num_layers, 71 | dropout=dropout, 72 | norm_num_groups=norm_num_groups, 73 | cross_attention_dim=cross_attention_dim, 74 | attention_bias=attention_bias, 75 | sample_size=sample_size, 76 | num_vector_embeds=num_vector_embeds, 77 | activation_fn=activation_fn, 78 | num_embeds_ada_norm=num_embeds_ada_norm, 79 | ) 80 | for _ in range(2) 81 | ] 82 | ) 83 | 84 | # Variables that can be set by a pipeline: 85 | 86 | # The ratio of transformer1 to transformer2's output states to be combined during inference 87 | self.mix_ratio = 0.5 88 | 89 | # The shape of `encoder_hidden_states` is expected to be 90 | # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` 91 | self.condition_lengths = [77, 257] 92 | 93 | # Which transformer to use to encode which condition. 94 | # E.g. 
`(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` 95 | self.transformer_index_for_condition = [1, 0] 96 | 97 | def forward( 98 | self, 99 | hidden_states, 100 | encoder_hidden_states, 101 | timestep=None, 102 | attention_mask=None, 103 | cross_attention_kwargs=None, 104 | return_dict: bool = True, 105 | ): 106 | """ 107 | Args: 108 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 109 | When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 110 | hidden_states 111 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 112 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 113 | self-attention. 114 | timestep ( `torch.long`, *optional*): 115 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 116 | attention_mask (`torch.FloatTensor`, *optional*): 117 | Optional attention mask to be applied in CrossAttention 118 | return_dict (`bool`, *optional*, defaults to `True`): 119 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 120 | 121 | Returns: 122 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: 123 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When 124 | returning a tuple, the first element is the sample tensor. 125 | """ 126 | input_states = hidden_states 127 | 128 | encoded_states = [] 129 | tokens_start = 0 130 | # attention_mask is not used yet 131 | for i in range(2): 132 | # for each of the two transformers, pass the corresponding condition tokens 133 | condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] 134 | transformer_index = self.transformer_index_for_condition[i] 135 | encoded_state = self.transformers[transformer_index]( 136 | input_states, 137 | encoder_hidden_states=condition_state, 138 | timestep=timestep, 139 | cross_attention_kwargs=cross_attention_kwargs, 140 | return_dict=False, 141 | )[0] 142 | encoded_states.append(encoded_state - input_states) 143 | tokens_start += self.condition_lengths[i] 144 | 145 | output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) 146 | output_states = output_states + input_states 147 | 148 | if not return_dict: 149 | return (output_states,) 150 | 151 | return Transformer2DModelOutput(sample=output_states) 152 | -------------------------------------------------------------------------------- /svdiff_pytorch/diffusers_models/embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
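# A minimal usage sketch (tensor values and sizes are illustrative only) for the sinusoidal
# timestep-embedding helper `get_timestep_embedding` defined below:
#
#     timesteps = torch.tensor([0, 10, 999])            # one timestep per batch element
#     emb = get_timestep_embedding(
#         timesteps, embedding_dim=320, flip_sin_to_cos=True, downscale_freq_shift=0
#     )
#     # emb.shape == (3, 320); cosine terms come first because flip_sin_to_cos=True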
14 | import math 15 | from typing import Optional 16 | 17 | import numpy as np 18 | import torch 19 | from torch import nn 20 | from svdiff_pytorch.layers import SVDLinear, SVDLayerNorm 21 | 22 | 23 | def get_timestep_embedding( 24 | timesteps: torch.Tensor, 25 | embedding_dim: int, 26 | flip_sin_to_cos: bool = False, 27 | downscale_freq_shift: float = 1, 28 | scale: float = 1, 29 | max_period: int = 10000, 30 | ): 31 | """ 32 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. 33 | 34 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 35 | These may be fractional. 36 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the 37 | embeddings. :return: an [N x dim] Tensor of positional embeddings. 38 | """ 39 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" 40 | 41 | half_dim = embedding_dim // 2 42 | exponent = -math.log(max_period) * torch.arange( 43 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device 44 | ) 45 | exponent = exponent / (half_dim - downscale_freq_shift) 46 | 47 | emb = torch.exp(exponent) 48 | emb = timesteps[:, None].float() * emb[None, :] 49 | 50 | # scale embeddings 51 | emb = scale * emb 52 | 53 | # concat sine and cosine embeddings 54 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 55 | 56 | # flip sine and cosine embeddings 57 | if flip_sin_to_cos: 58 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) 59 | 60 | # zero pad 61 | if embedding_dim % 2 == 1: 62 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 63 | return emb 64 | 65 | 66 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 67 | """ 68 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or 69 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 70 | """ 71 | grid_h = np.arange(grid_size, dtype=np.float32) 72 | grid_w = np.arange(grid_size, dtype=np.float32) 73 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 74 | grid = np.stack(grid, axis=0) 75 | 76 | grid = grid.reshape([2, 1, grid_size, grid_size]) 77 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 78 | if cls_token and extra_tokens > 0: 79 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 80 | return pos_embed 81 | 82 | 83 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 84 | if embed_dim % 2 != 0: 85 | raise ValueError("embed_dim must be divisible by 2") 86 | 87 | # use half of dimensions to encode grid_h 88 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 89 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 90 | 91 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 92 | return emb 93 | 94 | 95 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 96 | """ 97 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) 98 | """ 99 | if embed_dim % 2 != 0: 100 | raise ValueError("embed_dim must be divisible by 2") 101 | 102 | omega = np.arange(embed_dim // 2, dtype=np.float64) 103 | omega /= embed_dim / 2.0 104 | omega = 1.0 / 10000**omega # (D/2,) 105 | 106 | pos = pos.reshape(-1) # (M,) 107 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product 108 | 109 | emb_sin = np.sin(out) # (M, D/2) 110 | emb_cos = np.cos(out) # (M, D/2) 111 | 112 | emb = np.concatenate([emb_sin, emb_cos], 
axis=1) # (M, D) 113 | return emb 114 | 115 | 116 | class PatchEmbed(nn.Module): 117 | """2D Image to Patch Embedding""" 118 | 119 | def __init__( 120 | self, 121 | height=224, 122 | width=224, 123 | patch_size=16, 124 | in_channels=3, 125 | embed_dim=768, 126 | layer_norm=False, 127 | flatten=True, 128 | bias=True, 129 | ): 130 | super().__init__() 131 | 132 | num_patches = (height // patch_size) * (width // patch_size) 133 | self.flatten = flatten 134 | self.layer_norm = layer_norm 135 | 136 | self.proj = nn.Conv2d( 137 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias 138 | ) 139 | if layer_norm: 140 | self.norm = SVDLayerNorm(embed_dim, elementwise_affine=False, eps=1e-6) 141 | else: 142 | self.norm = None 143 | 144 | pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5)) 145 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False) 146 | 147 | def forward(self, latent): 148 | latent = self.proj(latent) 149 | if self.flatten: 150 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC 151 | if self.layer_norm: 152 | latent = self.norm(latent) 153 | return latent + self.pos_embed 154 | 155 | 156 | class TimestepEmbedding(nn.Module): 157 | def __init__( 158 | self, 159 | in_channels: int, 160 | time_embed_dim: int, 161 | act_fn: str = "silu", 162 | out_dim: int = None, 163 | post_act_fn: Optional[str] = None, 164 | cond_proj_dim=None, 165 | ): 166 | super().__init__() 167 | 168 | self.linear_1 = SVDLinear(in_channels, time_embed_dim) 169 | 170 | if cond_proj_dim is not None: 171 | self.cond_proj = SVDLinear(cond_proj_dim, in_channels, bias=False) 172 | else: 173 | self.cond_proj = None 174 | 175 | if act_fn == "silu": 176 | self.act = nn.SiLU() 177 | elif act_fn == "mish": 178 | self.act = nn.Mish() 179 | elif act_fn == "gelu": 180 | self.act = nn.GELU() 181 | else: 182 | raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'") 183 | 184 | if out_dim is not None: 185 | time_embed_dim_out = out_dim 186 | else: 187 | time_embed_dim_out = time_embed_dim 188 | self.linear_2 = SVDLinear(time_embed_dim, time_embed_dim_out) 189 | 190 | if post_act_fn is None: 191 | self.post_act = None 192 | elif post_act_fn == "silu": 193 | self.post_act = nn.SiLU() 194 | elif post_act_fn == "mish": 195 | self.post_act = nn.Mish() 196 | elif post_act_fn == "gelu": 197 | self.post_act = nn.GELU() 198 | else: 199 | raise ValueError(f"{post_act_fn} does not exist. 
Make sure to define one of 'silu', 'mish', or 'gelu'") 200 | 201 | def forward(self, sample, condition=None): 202 | if condition is not None: 203 | sample = sample + self.cond_proj(condition) 204 | sample = self.linear_1(sample) 205 | 206 | if self.act is not None: 207 | sample = self.act(sample) 208 | 209 | sample = self.linear_2(sample) 210 | 211 | if self.post_act is not None: 212 | sample = self.post_act(sample) 213 | return sample 214 | 215 | 216 | class Timesteps(nn.Module): 217 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float): 218 | super().__init__() 219 | self.num_channels = num_channels 220 | self.flip_sin_to_cos = flip_sin_to_cos 221 | self.downscale_freq_shift = downscale_freq_shift 222 | 223 | def forward(self, timesteps): 224 | t_emb = get_timestep_embedding( 225 | timesteps, 226 | self.num_channels, 227 | flip_sin_to_cos=self.flip_sin_to_cos, 228 | downscale_freq_shift=self.downscale_freq_shift, 229 | ) 230 | return t_emb 231 | 232 | 233 | class GaussianFourierProjection(nn.Module): 234 | """Gaussian Fourier embeddings for noise levels.""" 235 | 236 | def __init__( 237 | self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False 238 | ): 239 | super().__init__() 240 | self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 241 | self.log = log 242 | self.flip_sin_to_cos = flip_sin_to_cos 243 | 244 | if set_W_to_weight: 245 | # to delete later 246 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) 247 | 248 | self.weight = self.W 249 | 250 | def forward(self, x): 251 | if self.log: 252 | x = torch.log(x) 253 | 254 | x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi 255 | 256 | if self.flip_sin_to_cos: 257 | out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) 258 | else: 259 | out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) 260 | return out 261 | 262 | 263 | class ImagePositionalEmbeddings(nn.Module): 264 | """ 265 | Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the 266 | height and width of the latent space. 267 | 268 | For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092 269 | 270 | For VQ-diffusion: 271 | 272 | Output vector embeddings are used as input for the transformer. 273 | 274 | Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE. 275 | 276 | Args: 277 | num_embed (`int`): 278 | Number of embeddings for the latent pixels embeddings. 279 | height (`int`): 280 | Height of the latent image i.e. the number of height embeddings. 281 | width (`int`): 282 | Width of the latent image i.e. the number of width embeddings. 283 | embed_dim (`int`): 284 | Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings. 
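    Example (a rough usage sketch; the sizes below are illustrative, not values tied to any particular pipeline):

        import torch

        emb_layer = ImagePositionalEmbeddings(num_embed=1024, height=32, width=32, embed_dim=512)
        latent_classes = torch.randint(0, 1024, (1, 32 * 32))  # flattened latent pixel classes
        hidden_states = emb_layer(latent_classes)              # -> (1, 1024, 512)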
285 | """ 286 | 287 | def __init__( 288 | self, 289 | num_embed: int, 290 | height: int, 291 | width: int, 292 | embed_dim: int, 293 | ): 294 | super().__init__() 295 | 296 | self.height = height 297 | self.width = width 298 | self.num_embed = num_embed 299 | self.embed_dim = embed_dim 300 | 301 | self.emb = nn.Embedding(self.num_embed, embed_dim) 302 | self.height_emb = nn.Embedding(self.height, embed_dim) 303 | self.width_emb = nn.Embedding(self.width, embed_dim) 304 | 305 | def forward(self, index): 306 | emb = self.emb(index) 307 | 308 | height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height)) 309 | 310 | # 1 x H x D -> 1 x H x 1 x D 311 | height_emb = height_emb.unsqueeze(2) 312 | 313 | width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width)) 314 | 315 | # 1 x W x D -> 1 x 1 x W x D 316 | width_emb = width_emb.unsqueeze(1) 317 | 318 | pos_emb = height_emb + width_emb 319 | 320 | # 1 x H x W x D -> 1 x L xD 321 | pos_emb = pos_emb.view(1, self.height * self.width, -1) 322 | 323 | emb = emb + pos_emb[:, : emb.shape[1], :] 324 | 325 | return emb 326 | 327 | 328 | class LabelEmbedding(nn.Module): 329 | """ 330 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 331 | 332 | Args: 333 | num_classes (`int`): The number of classes. 334 | hidden_size (`int`): The size of the vector embeddings. 335 | dropout_prob (`float`): The probability of dropping a label. 336 | """ 337 | 338 | def __init__(self, num_classes, hidden_size, dropout_prob): 339 | super().__init__() 340 | use_cfg_embedding = dropout_prob > 0 341 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) 342 | self.num_classes = num_classes 343 | self.dropout_prob = dropout_prob 344 | 345 | def token_drop(self, labels, force_drop_ids=None): 346 | """ 347 | Drops labels to enable classifier-free guidance. 
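        Labels selected for dropout are replaced with the reserved index `num_classes`, which points at the extra
        row of the embedding table used for the unconditional (classifier-free) embedding.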
348 | """ 349 | if force_drop_ids is None: 350 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob 351 | else: 352 | drop_ids = torch.tensor(force_drop_ids == 1) 353 | labels = torch.where(drop_ids, self.num_classes, labels) 354 | return labels 355 | 356 | def forward(self, labels, force_drop_ids=None): 357 | use_dropout = self.dropout_prob > 0 358 | if (self.training and use_dropout) or (force_drop_ids is not None): 359 | labels = self.token_drop(labels, force_drop_ids) 360 | embeddings = self.embedding_table(labels) 361 | return embeddings 362 | 363 | 364 | class CombinedTimestepLabelEmbeddings(nn.Module): 365 | def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1): 366 | super().__init__() 367 | 368 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1) 369 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) 370 | self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob) 371 | 372 | def forward(self, timestep, class_labels, hidden_dtype=None): 373 | timesteps_proj = self.time_proj(timestep) 374 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) 375 | 376 | class_labels = self.class_embedder(class_labels) # (N, D) 377 | 378 | conditioning = timesteps_emb + class_labels # (N, D) 379 | 380 | return conditioning 381 | -------------------------------------------------------------------------------- /svdiff_pytorch/diffusers_models/resnet.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Optional 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from svdiff_pytorch.diffusers_models.attention import AdaGroupNorm 9 | from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDGroupNorm, SVDLayerNorm 10 | 11 | 12 | class Upsample1D(nn.Module): 13 | """ 14 | An upsampling layer with an optional convolution. 15 | 16 | Parameters: 17 | channels: channels in the inputs and outputs. 18 | use_conv: a bool determining if a convolution is applied. 19 | use_conv_transpose: 20 | out_channels: 21 | """ 22 | 23 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 24 | super().__init__() 25 | self.channels = channels 26 | self.out_channels = out_channels or channels 27 | self.use_conv = use_conv 28 | self.use_conv_transpose = use_conv_transpose 29 | self.name = name 30 | 31 | self.conv = None 32 | if use_conv_transpose: 33 | self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) 34 | elif use_conv: 35 | self.conv = SVDConv1d(self.channels, self.out_channels, 3, padding=1) 36 | 37 | def forward(self, x): 38 | assert x.shape[1] == self.channels 39 | if self.use_conv_transpose: 40 | return self.conv(x) 41 | 42 | x = F.interpolate(x, scale_factor=2.0, mode="nearest") 43 | 44 | if self.use_conv: 45 | x = self.conv(x) 46 | 47 | return x 48 | 49 | 50 | class Downsample1D(nn.Module): 51 | """ 52 | A downsampling layer with an optional convolution. 53 | 54 | Parameters: 55 | channels: channels in the inputs and outputs. 56 | use_conv: a bool determining if a convolution is applied. 
57 | out_channels: 58 | padding: 59 | """ 60 | 61 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 62 | super().__init__() 63 | self.channels = channels 64 | self.out_channels = out_channels or channels 65 | self.use_conv = use_conv 66 | self.padding = padding 67 | stride = 2 68 | self.name = name 69 | 70 | if use_conv: 71 | self.conv = SVDConv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 72 | else: 73 | assert self.channels == self.out_channels 74 | self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) 75 | 76 | def forward(self, x): 77 | assert x.shape[1] == self.channels 78 | return self.conv(x) 79 | 80 | 81 | class Upsample2D(nn.Module): 82 | """ 83 | An upsampling layer with an optional convolution. 84 | 85 | Parameters: 86 | channels: channels in the inputs and outputs. 87 | use_conv: a bool determining if a convolution is applied. 88 | use_conv_transpose: 89 | out_channels: 90 | """ 91 | 92 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): 93 | super().__init__() 94 | self.channels = channels 95 | self.out_channels = out_channels or channels 96 | self.use_conv = use_conv 97 | self.use_conv_transpose = use_conv_transpose 98 | self.name = name 99 | 100 | conv = None 101 | if use_conv_transpose: 102 | conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1) 103 | elif use_conv: 104 | conv = SVDConv2d(self.channels, self.out_channels, 3, padding=1) 105 | 106 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 107 | if name == "conv": 108 | self.conv = conv 109 | else: 110 | self.Conv2d_0 = conv 111 | 112 | def forward(self, hidden_states, output_size=None): 113 | assert hidden_states.shape[1] == self.channels 114 | 115 | if self.use_conv_transpose: 116 | return self.conv(hidden_states) 117 | 118 | # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 119 | # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch 120 | # https://github.com/pytorch/pytorch/issues/86679 121 | dtype = hidden_states.dtype 122 | if dtype == torch.bfloat16: 123 | hidden_states = hidden_states.to(torch.float32) 124 | 125 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 126 | if hidden_states.shape[0] >= 64: 127 | hidden_states = hidden_states.contiguous() 128 | 129 | # if `output_size` is passed we force the interpolation output 130 | # size and do not make use of `scale_factor=2` 131 | if output_size is None: 132 | hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") 133 | else: 134 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") 135 | 136 | # If the input is bfloat16, we cast back to bfloat16 137 | if dtype == torch.bfloat16: 138 | hidden_states = hidden_states.to(dtype) 139 | 140 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 141 | if self.use_conv: 142 | if self.name == "conv": 143 | hidden_states = self.conv(hidden_states) 144 | else: 145 | hidden_states = self.Conv2d_0(hidden_states) 146 | 147 | return hidden_states 148 | 149 | 150 | class Downsample2D(nn.Module): 151 | """ 152 | A downsampling layer with an optional convolution. 153 | 154 | Parameters: 155 | channels: channels in the inputs and outputs. 156 | use_conv: a bool determining if a convolution is applied. 
157 | out_channels: 158 | padding: 159 | """ 160 | 161 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): 162 | super().__init__() 163 | self.channels = channels 164 | self.out_channels = out_channels or channels 165 | self.use_conv = use_conv 166 | self.padding = padding 167 | stride = 2 168 | self.name = name 169 | 170 | if use_conv: 171 | conv = SVDConv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding) 172 | else: 173 | assert self.channels == self.out_channels 174 | conv = nn.AvgPool2d(kernel_size=stride, stride=stride) 175 | 176 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed 177 | if name == "conv": 178 | self.Conv2d_0 = conv 179 | self.conv = conv 180 | elif name == "Conv2d_0": 181 | self.conv = conv 182 | else: 183 | self.conv = conv 184 | 185 | def forward(self, hidden_states): 186 | assert hidden_states.shape[1] == self.channels 187 | if self.use_conv and self.padding == 0: 188 | pad = (0, 1, 0, 1) 189 | hidden_states = F.pad(hidden_states, pad, mode="constant", value=0) 190 | 191 | assert hidden_states.shape[1] == self.channels 192 | hidden_states = self.conv(hidden_states) 193 | 194 | return hidden_states 195 | 196 | 197 | class FirUpsample2D(nn.Module): 198 | def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): 199 | super().__init__() 200 | out_channels = out_channels if out_channels else channels 201 | if use_conv: 202 | self.Conv2d_0 = SVDConv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) 203 | self.use_conv = use_conv 204 | self.fir_kernel = fir_kernel 205 | self.out_channels = out_channels 206 | 207 | def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): 208 | """Fused `upsample_2d()` followed by `Conv2d()`. 209 | 210 | Padding is performed only once at the beginning, not between the operations. The fused op is considerably more 211 | efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of 212 | arbitrary order. 213 | 214 | Args: 215 | hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 216 | weight: Weight tensor of the shape `[filterH, filterW, inChannels, 217 | outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`. 218 | kernel: FIR filter of the shape `[firH, firW]` or `[firN]` 219 | (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. 220 | factor: Integer upsampling factor (default: 2). 221 | gain: Scaling factor for signal magnitude (default: 1.0). 222 | 223 | Returns: 224 | output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same 225 | datatype as `hidden_states`. 226 | """ 227 | 228 | assert isinstance(factor, int) and factor >= 1 229 | 230 | # Setup filter kernel. 231 | if kernel is None: 232 | kernel = [1] * factor 233 | 234 | # setup kernel 235 | kernel = torch.tensor(kernel, dtype=torch.float32) 236 | if kernel.ndim == 1: 237 | kernel = torch.outer(kernel, kernel) 238 | kernel /= torch.sum(kernel) 239 | 240 | kernel = kernel * (gain * (factor**2)) 241 | 242 | if self.use_conv: 243 | convH = weight.shape[2] 244 | convW = weight.shape[3] 245 | inC = weight.shape[1] 246 | 247 | pad_value = (kernel.shape[0] - factor) - (convW - 1) 248 | 249 | stride = (factor, factor) 250 | # Determine data dimensions. 
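# output_shape is the spatial size the fused upsampling should produce; output_padding is chosen so that
# conv_transpose2d with stride=factor lands exactly on that size. The convolution weight is then spatially
# flipped and reshaped into the (in_channels, out_channels, kH, kW) layout that F.conv_transpose2d expects.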
251 | output_shape = ( 252 | (hidden_states.shape[2] - 1) * factor + convH, 253 | (hidden_states.shape[3] - 1) * factor + convW, 254 | ) 255 | output_padding = ( 256 | output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH, 257 | output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW, 258 | ) 259 | assert output_padding[0] >= 0 and output_padding[1] >= 0 260 | num_groups = hidden_states.shape[1] // inC 261 | 262 | # Transpose weights. 263 | weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW)) 264 | weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4) 265 | weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW)) 266 | 267 | inverse_conv = F.conv_transpose2d( 268 | hidden_states, weight, stride=stride, output_padding=output_padding, padding=0 269 | ) 270 | 271 | output = upfirdn2d_native( 272 | inverse_conv, 273 | torch.tensor(kernel, device=inverse_conv.device), 274 | pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1), 275 | ) 276 | else: 277 | pad_value = kernel.shape[0] - factor 278 | output = upfirdn2d_native( 279 | hidden_states, 280 | torch.tensor(kernel, device=hidden_states.device), 281 | up=factor, 282 | pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), 283 | ) 284 | 285 | return output 286 | 287 | def forward(self, hidden_states): 288 | if self.use_conv: 289 | height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel) 290 | height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1) 291 | else: 292 | height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) 293 | 294 | return height 295 | 296 | 297 | class FirDownsample2D(nn.Module): 298 | def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)): 299 | super().__init__() 300 | out_channels = out_channels if out_channels else channels 301 | if use_conv: 302 | self.Conv2d_0 = SVDConv2d(channels, out_channels, kernel_size=3, stride=1, padding=1) 303 | self.fir_kernel = fir_kernel 304 | self.use_conv = use_conv 305 | self.out_channels = out_channels 306 | 307 | def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1): 308 | """Fused `Conv2d()` followed by `downsample_2d()`. 309 | Padding is performed only once at the beginning, not between the operations. The fused op is considerably more 310 | efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of 311 | arbitrary order. 312 | 313 | Args: 314 | hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 315 | weight: 316 | Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be 317 | performed by `inChannels = x.shape[0] // numGroups`. 318 | kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * 319 | factor`, which corresponds to average pooling. 320 | factor: Integer downsampling factor (default: 2). 321 | gain: Scaling factor for signal magnitude (default: 1.0). 322 | 323 | Returns: 324 | output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and 325 | same datatype as `x`. 
326 | """ 327 | 328 | assert isinstance(factor, int) and factor >= 1 329 | if kernel is None: 330 | kernel = [1] * factor 331 | 332 | # setup kernel 333 | kernel = torch.tensor(kernel, dtype=torch.float32) 334 | if kernel.ndim == 1: 335 | kernel = torch.outer(kernel, kernel) 336 | kernel /= torch.sum(kernel) 337 | 338 | kernel = kernel * gain 339 | 340 | if self.use_conv: 341 | _, _, convH, convW = weight.shape 342 | pad_value = (kernel.shape[0] - factor) + (convW - 1) 343 | stride_value = [factor, factor] 344 | upfirdn_input = upfirdn2d_native( 345 | hidden_states, 346 | torch.tensor(kernel, device=hidden_states.device), 347 | pad=((pad_value + 1) // 2, pad_value // 2), 348 | ) 349 | output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0) 350 | else: 351 | pad_value = kernel.shape[0] - factor 352 | output = upfirdn2d_native( 353 | hidden_states, 354 | torch.tensor(kernel, device=hidden_states.device), 355 | down=factor, 356 | pad=((pad_value + 1) // 2, pad_value // 2), 357 | ) 358 | 359 | return output 360 | 361 | def forward(self, hidden_states): 362 | if self.use_conv: 363 | downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel) 364 | hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1) 365 | else: 366 | hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2) 367 | 368 | return hidden_states 369 | 370 | 371 | # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/DirUpsample2D instead 372 | class KDownsample2D(nn.Module): 373 | def __init__(self, pad_mode="reflect"): 374 | super().__init__() 375 | self.pad_mode = pad_mode 376 | kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) 377 | self.pad = kernel_1d.shape[1] // 2 - 1 378 | self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) 379 | 380 | def forward(self, x): 381 | x = F.pad(x, (self.pad,) * 4, self.pad_mode) 382 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) 383 | indices = torch.arange(x.shape[1], device=x.device) 384 | weight[indices, indices] = self.kernel.to(weight) 385 | return F.conv2d(x, weight, stride=2) 386 | 387 | 388 | class KUpsample2D(nn.Module): 389 | def __init__(self, pad_mode="reflect"): 390 | super().__init__() 391 | self.pad_mode = pad_mode 392 | kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2 393 | self.pad = kernel_1d.shape[1] // 2 - 1 394 | self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False) 395 | 396 | def forward(self, x): 397 | x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode) 398 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]]) 399 | indices = torch.arange(x.shape[1], device=x.device) 400 | weight[indices, indices] = self.kernel.to(weight) 401 | return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1) 402 | 403 | 404 | class ResnetBlock2D(nn.Module): 405 | r""" 406 | A Resnet block. 407 | 408 | Parameters: 409 | in_channels (`int`): The number of channels in the input. 410 | out_channels (`int`, *optional*, default to be `None`): 411 | The number of output channels for the first conv2d layer. If None, same as `in_channels`. 412 | dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. 413 | temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding. 
414 | groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. 415 | groups_out (`int`, *optional*, default to None): 416 | The number of groups to use for the second normalization layer. if set to None, same as `groups`. 417 | eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. 418 | non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use. 419 | time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config. 420 | By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or 421 | "ada_group" for a stronger conditioning with scale and shift. 422 | kernal (`torch.FloatTensor`, optional, default to None): FIR filter, see 423 | [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`]. 424 | output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output. 425 | use_in_shortcut (`bool`, *optional*, default to `True`): 426 | If `True`, add a 1x1 nn.conv2d layer for skip-connection. 427 | up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer. 428 | down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer. 429 | conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the 430 | `conv_shortcut` output. 431 | conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output. 432 | If None, same as `out_channels`. 433 | """ 434 | 435 | def __init__( 436 | self, 437 | *, 438 | in_channels, 439 | out_channels=None, 440 | conv_shortcut=False, 441 | dropout=0.0, 442 | temb_channels=512, 443 | groups=32, 444 | groups_out=None, 445 | pre_norm=True, 446 | eps=1e-6, 447 | non_linearity="swish", 448 | time_embedding_norm="default", # default, scale_shift, ada_group 449 | kernel=None, 450 | output_scale_factor=1.0, 451 | use_in_shortcut=None, 452 | up=False, 453 | down=False, 454 | conv_shortcut_bias: bool = True, 455 | conv_2d_out_channels: Optional[int] = None, 456 | ): 457 | super().__init__() 458 | self.pre_norm = pre_norm 459 | self.pre_norm = True 460 | self.in_channels = in_channels 461 | out_channels = in_channels if out_channels is None else out_channels 462 | self.out_channels = out_channels 463 | self.use_conv_shortcut = conv_shortcut 464 | self.up = up 465 | self.down = down 466 | self.output_scale_factor = output_scale_factor 467 | self.time_embedding_norm = time_embedding_norm 468 | 469 | if groups_out is None: 470 | groups_out = groups 471 | 472 | if self.time_embedding_norm == "ada_group": 473 | self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps) 474 | else: 475 | self.norm1 = SVDGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) 476 | 477 | self.conv1 = SVDConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 478 | 479 | if temb_channels is not None: 480 | if self.time_embedding_norm == "default": 481 | self.time_emb_proj = SVDLinear(temb_channels, out_channels) 482 | elif self.time_embedding_norm == "scale_shift": 483 | self.time_emb_proj = SVDLinear(temb_channels, 2 * out_channels) 484 | elif self.time_embedding_norm == "ada_group": 485 | self.time_emb_proj = None 486 | else: 487 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") 488 | else: 489 | self.time_emb_proj = None 490 | 491 | if self.time_embedding_norm == "ada_group": 492 | self.norm2 
= AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps) 493 | else: 494 | self.norm2 = SVDGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) 495 | 496 | self.dropout = torch.nn.Dropout(dropout) 497 | conv_2d_out_channels = conv_2d_out_channels or out_channels 498 | self.conv2 = SVDConv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1) 499 | 500 | if non_linearity == "swish": 501 | self.nonlinearity = lambda x: F.silu(x) 502 | elif non_linearity == "mish": 503 | self.nonlinearity = nn.Mish() 504 | elif non_linearity == "silu": 505 | self.nonlinearity = nn.SiLU() 506 | elif non_linearity == "gelu": 507 | self.nonlinearity = nn.GELU() 508 | 509 | self.upsample = self.downsample = None 510 | if self.up: 511 | if kernel == "fir": 512 | fir_kernel = (1, 3, 3, 1) 513 | self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel) 514 | elif kernel == "sde_vp": 515 | self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest") 516 | else: 517 | self.upsample = Upsample2D(in_channels, use_conv=False) 518 | elif self.down: 519 | if kernel == "fir": 520 | fir_kernel = (1, 3, 3, 1) 521 | self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel) 522 | elif kernel == "sde_vp": 523 | self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2) 524 | else: 525 | self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op") 526 | 527 | self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut 528 | 529 | self.conv_shortcut = None 530 | if self.use_in_shortcut: 531 | self.conv_shortcut = SVDConv2d( 532 | in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias 533 | ) 534 | 535 | def forward(self, input_tensor, temb): 536 | hidden_states = input_tensor 537 | 538 | if self.time_embedding_norm == "ada_group": 539 | hidden_states = self.norm1(hidden_states, temb) 540 | else: 541 | hidden_states = self.norm1(hidden_states) 542 | 543 | hidden_states = self.nonlinearity(hidden_states) 544 | 545 | if self.upsample is not None: 546 | # upsample_nearest_nhwc fails with large batch sizes. 
see https://github.com/huggingface/diffusers/issues/984 547 | if hidden_states.shape[0] >= 64: 548 | input_tensor = input_tensor.contiguous() 549 | hidden_states = hidden_states.contiguous() 550 | input_tensor = self.upsample(input_tensor) 551 | hidden_states = self.upsample(hidden_states) 552 | elif self.downsample is not None: 553 | input_tensor = self.downsample(input_tensor) 554 | hidden_states = self.downsample(hidden_states) 555 | 556 | hidden_states = self.conv1(hidden_states) 557 | 558 | if self.time_emb_proj is not None: 559 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None] 560 | 561 | if temb is not None and self.time_embedding_norm == "default": 562 | hidden_states = hidden_states + temb 563 | 564 | if self.time_embedding_norm == "ada_group": 565 | hidden_states = self.norm2(hidden_states, temb) 566 | else: 567 | hidden_states = self.norm2(hidden_states) 568 | 569 | if temb is not None and self.time_embedding_norm == "scale_shift": 570 | scale, shift = torch.chunk(temb, 2, dim=1) 571 | hidden_states = hidden_states * (1 + scale) + shift 572 | 573 | hidden_states = self.nonlinearity(hidden_states) 574 | 575 | hidden_states = self.dropout(hidden_states) 576 | hidden_states = self.conv2(hidden_states) 577 | 578 | if self.conv_shortcut is not None: 579 | input_tensor = self.conv_shortcut(input_tensor) 580 | 581 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor 582 | 583 | return output_tensor 584 | 585 | 586 | class Mish(torch.nn.Module): 587 | def forward(self, hidden_states): 588 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) 589 | 590 | 591 | # unet_rl.py 592 | def rearrange_dims(tensor): 593 | if len(tensor.shape) == 2: 594 | return tensor[:, :, None] 595 | if len(tensor.shape) == 3: 596 | return tensor[:, :, None, :] 597 | elif len(tensor.shape) == 4: 598 | return tensor[:, :, 0, :] 599 | else: 600 | raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") 601 | 602 | 603 | class Conv1dBlock(nn.Module): 604 | """ 605 | Conv1d --> GroupNorm --> Mish 606 | """ 607 | 608 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): 609 | super().__init__() 610 | 611 | self.conv1d = SVDConv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) 612 | self.group_norm = SVDGroupNorm(n_groups, out_channels) 613 | self.mish = nn.Mish() 614 | 615 | def forward(self, x): 616 | x = self.conv1d(x) 617 | x = rearrange_dims(x) 618 | x = self.group_norm(x) 619 | x = rearrange_dims(x) 620 | x = self.mish(x) 621 | return x 622 | 623 | 624 | # unet_rl.py 625 | class ResidualTemporalBlock1D(nn.Module): 626 | def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): 627 | super().__init__() 628 | self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) 629 | self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) 630 | 631 | self.time_emb_act = nn.Mish() 632 | self.time_emb = SVDLinear(embed_dim, out_channels) 633 | 634 | self.residual_conv = ( 635 | SVDConv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() 636 | ) 637 | 638 | def forward(self, x, t): 639 | """ 640 | Args: 641 | x : [ batch_size x inp_channels x horizon ] 642 | t : [ batch_size x embed_dim ] 643 | 644 | returns: 645 | out : [ batch_size x out_channels x horizon ] 646 | """ 647 | t = self.time_emb_act(t) 648 | t = self.time_emb(t) 649 | out = self.conv_in(x) + rearrange_dims(t) 650 | out = self.conv_out(out) 651 | return out + 
self.residual_conv(x) 652 | 653 | 654 | def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): 655 | r"""Upsample2D a batch of 2D images with the given filter. 656 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given 657 | filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified 658 | `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is 659 | a: multiple of the upsampling factor. 660 | 661 | Args: 662 | hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 663 | kernel: FIR filter of the shape `[firH, firW]` or `[firN]` 664 | (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling. 665 | factor: Integer upsampling factor (default: 2). 666 | gain: Scaling factor for signal magnitude (default: 1.0). 667 | 668 | Returns: 669 | output: Tensor of the shape `[N, C, H * factor, W * factor]` 670 | """ 671 | assert isinstance(factor, int) and factor >= 1 672 | if kernel is None: 673 | kernel = [1] * factor 674 | 675 | kernel = torch.tensor(kernel, dtype=torch.float32) 676 | if kernel.ndim == 1: 677 | kernel = torch.outer(kernel, kernel) 678 | kernel /= torch.sum(kernel) 679 | 680 | kernel = kernel * (gain * (factor**2)) 681 | pad_value = kernel.shape[0] - factor 682 | output = upfirdn2d_native( 683 | hidden_states, 684 | kernel.to(device=hidden_states.device), 685 | up=factor, 686 | pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2), 687 | ) 688 | return output 689 | 690 | 691 | def downsample_2d(hidden_states, kernel=None, factor=2, gain=1): 692 | r"""Downsample2D a batch of 2D images with the given filter. 693 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the 694 | given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the 695 | specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its 696 | shape is a multiple of the downsampling factor. 697 | 698 | Args: 699 | hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`. 700 | kernel: FIR filter of the shape `[firH, firW]` or `[firN]` 701 | (separable). The default is `[1] * factor`, which corresponds to average pooling. 702 | factor: Integer downsampling factor (default: 2). 703 | gain: Scaling factor for signal magnitude (default: 1.0). 
704 | 705 | Returns: 706 | output: Tensor of the shape `[N, C, H // factor, W // factor]` 707 | """ 708 | 709 | assert isinstance(factor, int) and factor >= 1 710 | if kernel is None: 711 | kernel = [1] * factor 712 | 713 | kernel = torch.tensor(kernel, dtype=torch.float32) 714 | if kernel.ndim == 1: 715 | kernel = torch.outer(kernel, kernel) 716 | kernel /= torch.sum(kernel) 717 | 718 | kernel = kernel * gain 719 | pad_value = kernel.shape[0] - factor 720 | output = upfirdn2d_native( 721 | hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2) 722 | ) 723 | return output 724 | 725 | 726 | def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)): 727 | up_x = up_y = up 728 | down_x = down_y = down 729 | pad_x0 = pad_y0 = pad[0] 730 | pad_x1 = pad_y1 = pad[1] 731 | 732 | _, channel, in_h, in_w = tensor.shape 733 | tensor = tensor.reshape(-1, in_h, in_w, 1) 734 | 735 | _, in_h, in_w, minor = tensor.shape 736 | kernel_h, kernel_w = kernel.shape 737 | 738 | out = tensor.view(-1, in_h, 1, in_w, 1, minor) 739 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 740 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 741 | 742 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 743 | out = out.to(tensor.device) # Move back to mps if necessary 744 | out = out[ 745 | :, 746 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 747 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 748 | :, 749 | ] 750 | 751 | out = out.permute(0, 3, 1, 2) 752 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 753 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 754 | out = F.conv2d(out, w) 755 | out = out.reshape( 756 | -1, 757 | minor, 758 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 759 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 760 | ) 761 | out = out.permute(0, 2, 3, 1) 762 | out = out[:, ::down_y, ::down_x, :] 763 | 764 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 765 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 766 | 767 | return out.view(-1, channel, out_h, out_w) 768 | -------------------------------------------------------------------------------- /svdiff_pytorch/diffusers_models/transformer_2d.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
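# NOTE: this module mirrors `diffusers.models.transformer_2d.Transformer2DModel`, with the
# parameter-bearing layers (nn.Linear / nn.Conv2d / nn.GroupNorm / nn.LayerNorm) swapped for the
# SVD* spectral-shift variants defined in `svdiff_pytorch.layers`. Those wrappers freeze the
# pretrained weight and expose a 1-D `delta` parameter that shifts the weight's singular values,
# so only the spectral shifts are trained.
#
# A minimal usage sketch (illustrative only, not part of the original file; the config values
# below are assumptions):
#
#   model = Transformer2DModel(num_attention_heads=8, attention_head_dim=64, in_channels=320)
#   spectral_shifts = [p for n, p in model.named_parameters() if n.endswith("delta")]
#   # pass only `spectral_shifts` to the optimizer; the original weights keep requires_grad=False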
14 | from dataclasses import dataclass 15 | from typing import Optional 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch import nn 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.models.embeddings import ImagePositionalEmbeddings 23 | from diffusers.utils import BaseOutput, deprecate 24 | from svdiff_pytorch.diffusers_models.attention import BasicTransformerBlock 25 | from diffusers.models.embeddings import PatchEmbed 26 | from diffusers.models.modeling_utils import ModelMixin 27 | from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDGroupNorm, SVDLayerNorm 28 | 29 | 30 | @dataclass 31 | class Transformer2DModelOutput(BaseOutput): 32 | """ 33 | Args: 34 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 35 | Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions 36 | for the unnoised latent pixels. 37 | """ 38 | 39 | sample: torch.FloatTensor 40 | 41 | 42 | class Transformer2DModel(ModelMixin, ConfigMixin): 43 | """ 44 | Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual 45 | embeddings) inputs. 46 | 47 | When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard 48 | transformer action. Finally, reshape to image. 49 | 50 | When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional 51 | embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict 52 | classes of unnoised image. 53 | 54 | Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised 55 | image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. 56 | 57 | Parameters: 58 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 59 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 60 | in_channels (`int`, *optional*): 61 | Pass if the input is continuous. The number of channels in the input and output. 62 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 63 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 64 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. 65 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. 66 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See 67 | `ImagePositionalEmbeddings`. 68 | num_vector_embeds (`int`, *optional*): 69 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. 70 | Includes the class for the masked latent pixel. 71 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 72 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. 73 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used 74 | to learn a number of embeddings that are added to the hidden states. 
During inference, you can denoise for 75 | up to but not more than steps than `num_embeds_ada_norm`. 76 | attention_bias (`bool`, *optional*): 77 | Configure if the TransformerBlocks' attention should contain a bias parameter. 78 | """ 79 | 80 | @register_to_config 81 | def __init__( 82 | self, 83 | num_attention_heads: int = 16, 84 | attention_head_dim: int = 88, 85 | in_channels: Optional[int] = None, 86 | out_channels: Optional[int] = None, 87 | num_layers: int = 1, 88 | dropout: float = 0.0, 89 | norm_num_groups: int = 32, 90 | cross_attention_dim: Optional[int] = None, 91 | attention_bias: bool = False, 92 | sample_size: Optional[int] = None, 93 | num_vector_embeds: Optional[int] = None, 94 | patch_size: Optional[int] = None, 95 | activation_fn: str = "geglu", 96 | num_embeds_ada_norm: Optional[int] = None, 97 | use_linear_projection: bool = False, 98 | only_cross_attention: bool = False, 99 | upcast_attention: bool = False, 100 | norm_type: str = "layer_norm", 101 | norm_elementwise_affine: bool = True, 102 | ): 103 | super().__init__() 104 | self.use_linear_projection = use_linear_projection 105 | self.num_attention_heads = num_attention_heads 106 | self.attention_head_dim = attention_head_dim 107 | inner_dim = num_attention_heads * attention_head_dim 108 | 109 | # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 110 | # Define whether input is continuous or discrete depending on configuration 111 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 112 | self.is_input_vectorized = num_vector_embeds is not None 113 | self.is_input_patches = in_channels is not None and patch_size is not None 114 | 115 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 116 | deprecation_message = ( 117 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 118 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 119 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 120 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 121 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 122 | ) 123 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) 124 | norm_type = "ada_norm" 125 | 126 | if self.is_input_continuous and self.is_input_vectorized: 127 | raise ValueError( 128 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 129 | " sure that either `in_channels` or `num_vector_embeds` is None." 130 | ) 131 | elif self.is_input_vectorized and self.is_input_patches: 132 | raise ValueError( 133 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 134 | " sure that either `num_vector_embeds` or `num_patches` is None." 135 | ) 136 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: 137 | raise ValueError( 138 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 139 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 140 | ) 141 | 142 | # 2. 
Define input layers 143 | if self.is_input_continuous: 144 | self.in_channels = in_channels 145 | 146 | self.norm = SVDGroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 147 | if use_linear_projection: 148 | self.proj_in = SVDLinear(in_channels, inner_dim) 149 | else: 150 | self.proj_in = SVDConv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 151 | elif self.is_input_vectorized: 152 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" 153 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" 154 | 155 | self.height = sample_size 156 | self.width = sample_size 157 | self.num_vector_embeds = num_vector_embeds 158 | self.num_latent_pixels = self.height * self.width 159 | 160 | self.latent_image_embedding = ImagePositionalEmbeddings( 161 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width 162 | ) 163 | elif self.is_input_patches: 164 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" 165 | 166 | self.height = sample_size 167 | self.width = sample_size 168 | 169 | self.patch_size = patch_size 170 | self.pos_embed = PatchEmbed( 171 | height=sample_size, 172 | width=sample_size, 173 | patch_size=patch_size, 174 | in_channels=in_channels, 175 | embed_dim=inner_dim, 176 | ) 177 | 178 | # 3. Define transformers blocks 179 | self.transformer_blocks = nn.ModuleList( 180 | [ 181 | BasicTransformerBlock( 182 | inner_dim, 183 | num_attention_heads, 184 | attention_head_dim, 185 | dropout=dropout, 186 | cross_attention_dim=cross_attention_dim, 187 | activation_fn=activation_fn, 188 | num_embeds_ada_norm=num_embeds_ada_norm, 189 | attention_bias=attention_bias, 190 | only_cross_attention=only_cross_attention, 191 | upcast_attention=upcast_attention, 192 | norm_type=norm_type, 193 | norm_elementwise_affine=norm_elementwise_affine, 194 | ) 195 | for d in range(num_layers) 196 | ] 197 | ) 198 | 199 | # 4. Define output layers 200 | self.out_channels = in_channels if out_channels is None else out_channels 201 | if self.is_input_continuous: 202 | # TODO: should use out_channels for continous projections 203 | if use_linear_projection: 204 | self.proj_out = SVDLinear(inner_dim, in_channels) 205 | else: 206 | self.proj_out = SVDConv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 207 | elif self.is_input_vectorized: 208 | self.norm_out = SVDLayerNorm(inner_dim) 209 | self.out = SVDLinear(inner_dim, self.num_vector_embeds - 1) 210 | elif self.is_input_patches: 211 | self.norm_out = SVDLayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 212 | self.proj_out_1 = SVDLinear(inner_dim, 2 * inner_dim) 213 | self.proj_out_2 = SVDLinear(inner_dim, patch_size * patch_size * self.out_channels) 214 | 215 | def forward( 216 | self, 217 | hidden_states, 218 | encoder_hidden_states=None, 219 | timestep=None, 220 | class_labels=None, 221 | cross_attention_kwargs=None, 222 | return_dict: bool = True, 223 | ): 224 | """ 225 | Args: 226 | hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. 227 | When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input 228 | hidden_states 229 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): 230 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 231 | self-attention. 
232 | timestep ( `torch.long`, *optional*): 233 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. 234 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 235 | Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels 236 | conditioning. 237 | return_dict (`bool`, *optional*, defaults to `True`): 238 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 239 | 240 | Returns: 241 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`: 242 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When 243 | returning a tuple, the first element is the sample tensor. 244 | """ 245 | # 1. Input 246 | if self.is_input_continuous: 247 | batch, _, height, width = hidden_states.shape 248 | residual = hidden_states 249 | 250 | hidden_states = self.norm(hidden_states) 251 | if not self.use_linear_projection: 252 | hidden_states = self.proj_in(hidden_states) 253 | inner_dim = hidden_states.shape[1] 254 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 255 | else: 256 | inner_dim = hidden_states.shape[1] 257 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 258 | hidden_states = self.proj_in(hidden_states) 259 | elif self.is_input_vectorized: 260 | hidden_states = self.latent_image_embedding(hidden_states) 261 | elif self.is_input_patches: 262 | hidden_states = self.pos_embed(hidden_states) 263 | 264 | # 2. Blocks 265 | for block in self.transformer_blocks: 266 | hidden_states = block( 267 | hidden_states, 268 | encoder_hidden_states=encoder_hidden_states, 269 | timestep=timestep, 270 | cross_attention_kwargs=cross_attention_kwargs, 271 | class_labels=class_labels, 272 | ) 273 | 274 | # 3. Output 275 | if self.is_input_continuous: 276 | if not self.use_linear_projection: 277 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 278 | hidden_states = self.proj_out(hidden_states) 279 | else: 280 | hidden_states = self.proj_out(hidden_states) 281 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 282 | 283 | output = hidden_states + residual 284 | elif self.is_input_vectorized: 285 | hidden_states = self.norm_out(hidden_states) 286 | logits = self.out(hidden_states) 287 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) 288 | logits = logits.permute(0, 2, 1) 289 | 290 | # log(p(x_0)) 291 | output = F.log_softmax(logits.double(), dim=1).float() 292 | elif self.is_input_patches: 293 | # TODO: cleanup! 
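            # For patched inputs, the conditioning (timestep + class labels) is recovered from the
            # adaptive-norm embedding (`norm1.emb`) of the first transformer block, projected by
            # `proj_out_1` into a shift/scale pair that modulates the final LayerNorm output; the
            # token sequence is then mapped back to patch pixels by `proj_out_2` and un-patchified
            # below into a `(batch, out_channels, height, width)` image.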
294 | conditioning = self.transformer_blocks[0].norm1.emb( 295 | timestep, class_labels, hidden_dtype=hidden_states.dtype 296 | ) 297 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) 298 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] 299 | hidden_states = self.proj_out_2(hidden_states) 300 | 301 | # unpatchify 302 | height = width = int(hidden_states.shape[1] ** 0.5) 303 | hidden_states = hidden_states.reshape( 304 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 305 | ) 306 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 307 | output = hidden_states.reshape( 308 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 309 | ) 310 | 311 | if not return_dict: 312 | return (output,) 313 | 314 | return Transformer2DModelOutput(sample=output) 315 | -------------------------------------------------------------------------------- /svdiff_pytorch/diffusers_models/unet_2d_condition.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.utils.checkpoint 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders import UNet2DConditionLoadersMixin 23 | from diffusers.utils import BaseOutput, logging 24 | from diffusers.models.cross_attention import AttnProcessor 25 | from svdiff_pytorch.diffusers_models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps 26 | from diffusers.models.modeling_utils import ModelMixin 27 | from svdiff_pytorch.diffusers_models.unet_2d_blocks import ( 28 | CrossAttnDownBlock2D, 29 | CrossAttnUpBlock2D, 30 | DownBlock2D, 31 | UNetMidBlock2DCrossAttn, 32 | UNetMidBlock2DSimpleCrossAttn, 33 | UpBlock2D, 34 | get_down_block, 35 | get_up_block, 36 | ) 37 | from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDGroupNorm, SVDLayerNorm 38 | 39 | 40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 41 | 42 | 43 | @dataclass 44 | class UNet2DConditionOutput(BaseOutput): 45 | """ 46 | Args: 47 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 48 | Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model. 49 | """ 50 | 51 | sample: torch.FloatTensor 52 | 53 | 54 | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 55 | r""" 56 | UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep 57 | and returns sample shaped output. 58 | 59 | This model inherits from [`ModelMixin`]. 
Check the superclass documentation for the generic methods the library 60 | implements for all the models (such as downloading or saving, etc.) 61 | 62 | Parameters: 63 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 64 | Height and width of input/output sample. 65 | in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. 66 | out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. 67 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 68 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`): 69 | Whether to flip the sin to cos in the time embedding. 70 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 71 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 72 | The tuple of downsample blocks to use. 73 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): 74 | The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the 75 | mid block layer if `None`. 76 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`): 77 | The tuple of upsample blocks to use. 78 | only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): 79 | Whether to include self-attention in the basic transformer blocks, see 80 | [`~models.attention.BasicTransformerBlock`]. 81 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 82 | The tuple of output channels for each block. 83 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 84 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 85 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. 86 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 87 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. 88 | If `None`, it will skip the normalization and activation layers in post-processing 89 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 90 | cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features. 91 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 92 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config 93 | for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`. 94 | class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately 95 | summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`. 96 | num_class_embeds (`int`, *optional*, defaults to None): 97 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing 98 | class conditioning with `class_embed_type` equal to `None`. 99 | time_embedding_type (`str`, *optional*, default to `positional`): 100 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 
101 | timestep_post_act (`str, *optional*, default to `None`): 102 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. 103 | time_cond_proj_dim (`int`, *optional*, default to `None`): 104 | The dimension of `cond_proj` layer in timestep embedding. 105 | conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. 106 | conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. 107 | projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when 108 | using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`. 109 | """ 110 | 111 | _supports_gradient_checkpointing = True 112 | 113 | @register_to_config 114 | def __init__( 115 | self, 116 | sample_size: Optional[int] = None, 117 | in_channels: int = 4, 118 | out_channels: int = 4, 119 | center_input_sample: bool = False, 120 | flip_sin_to_cos: bool = True, 121 | freq_shift: int = 0, 122 | down_block_types: Tuple[str] = ( 123 | "CrossAttnDownBlock2D", 124 | "CrossAttnDownBlock2D", 125 | "CrossAttnDownBlock2D", 126 | "DownBlock2D", 127 | ), 128 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", 129 | up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), 130 | only_cross_attention: Union[bool, Tuple[bool]] = False, 131 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 132 | layers_per_block: int = 2, 133 | downsample_padding: int = 1, 134 | mid_block_scale_factor: float = 1, 135 | act_fn: str = "silu", 136 | norm_num_groups: Optional[int] = 32, 137 | norm_eps: float = 1e-5, 138 | cross_attention_dim: int = 1280, 139 | attention_head_dim: Union[int, Tuple[int]] = 8, 140 | dual_cross_attention: bool = False, 141 | use_linear_projection: bool = False, 142 | class_embed_type: Optional[str] = None, 143 | num_class_embeds: Optional[int] = None, 144 | upcast_attention: bool = False, 145 | resnet_time_scale_shift: str = "default", 146 | time_embedding_type: str = "positional", 147 | timestep_post_act: Optional[str] = None, 148 | time_cond_proj_dim: Optional[int] = None, 149 | conv_in_kernel: int = 3, 150 | conv_out_kernel: int = 3, 151 | projection_class_embeddings_input_dim: Optional[int] = None, 152 | ): 153 | super().__init__() 154 | 155 | self.sample_size = sample_size 156 | 157 | # Check inputs 158 | if len(down_block_types) != len(up_block_types): 159 | raise ValueError( 160 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 161 | ) 162 | 163 | if len(block_out_channels) != len(down_block_types): 164 | raise ValueError( 165 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 166 | ) 167 | 168 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 169 | raise ValueError( 170 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 171 | ) 172 | 173 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): 174 | raise ValueError( 175 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. 
`down_block_types`: {down_block_types}." 176 | ) 177 | 178 | # input 179 | conv_in_padding = (conv_in_kernel - 1) // 2 180 | self.conv_in = SVDConv2d( 181 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 182 | ) 183 | 184 | # time 185 | if time_embedding_type == "fourier": 186 | time_embed_dim = block_out_channels[0] * 2 187 | if time_embed_dim % 2 != 0: 188 | raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") 189 | self.time_proj = GaussianFourierProjection( 190 | time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos 191 | ) 192 | timestep_input_dim = time_embed_dim 193 | elif time_embedding_type == "positional": 194 | time_embed_dim = block_out_channels[0] * 4 195 | 196 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 197 | timestep_input_dim = block_out_channels[0] 198 | else: 199 | raise ValueError( 200 | f"{time_embedding_type} does not exist. Pleaes make sure to use one of `fourier` or `positional`." 201 | ) 202 | 203 | self.time_embedding = TimestepEmbedding( 204 | timestep_input_dim, 205 | time_embed_dim, 206 | act_fn=act_fn, 207 | post_act_fn=timestep_post_act, 208 | cond_proj_dim=time_cond_proj_dim, 209 | ) 210 | 211 | # class embedding 212 | if class_embed_type is None and num_class_embeds is not None: 213 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 214 | elif class_embed_type == "timestep": 215 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 216 | elif class_embed_type == "identity": 217 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 218 | elif class_embed_type == "projection": 219 | if projection_class_embeddings_input_dim is None: 220 | raise ValueError( 221 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 222 | ) 223 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 224 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 225 | # 2. it projects from an arbitrary input dimension. 226 | # 227 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 228 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 229 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
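            # Here it simply projects `class_labels` of size `projection_class_embeddings_input_dim`
            # to `time_embed_dim`; the result is added to the timestep embedding in `forward`.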
230 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 231 | else: 232 | self.class_embedding = None 233 | 234 | self.down_blocks = nn.ModuleList([]) 235 | self.up_blocks = nn.ModuleList([]) 236 | 237 | if isinstance(only_cross_attention, bool): 238 | only_cross_attention = [only_cross_attention] * len(down_block_types) 239 | 240 | if isinstance(attention_head_dim, int): 241 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 242 | 243 | # down 244 | output_channel = block_out_channels[0] 245 | for i, down_block_type in enumerate(down_block_types): 246 | input_channel = output_channel 247 | output_channel = block_out_channels[i] 248 | is_final_block = i == len(block_out_channels) - 1 249 | 250 | down_block = get_down_block( 251 | down_block_type, 252 | num_layers=layers_per_block, 253 | in_channels=input_channel, 254 | out_channels=output_channel, 255 | temb_channels=time_embed_dim, 256 | add_downsample=not is_final_block, 257 | resnet_eps=norm_eps, 258 | resnet_act_fn=act_fn, 259 | resnet_groups=norm_num_groups, 260 | cross_attention_dim=cross_attention_dim, 261 | attn_num_head_channels=attention_head_dim[i], 262 | downsample_padding=downsample_padding, 263 | dual_cross_attention=dual_cross_attention, 264 | use_linear_projection=use_linear_projection, 265 | only_cross_attention=only_cross_attention[i], 266 | upcast_attention=upcast_attention, 267 | resnet_time_scale_shift=resnet_time_scale_shift, 268 | ) 269 | self.down_blocks.append(down_block) 270 | 271 | # mid 272 | if mid_block_type == "UNetMidBlock2DCrossAttn": 273 | self.mid_block = UNetMidBlock2DCrossAttn( 274 | in_channels=block_out_channels[-1], 275 | temb_channels=time_embed_dim, 276 | resnet_eps=norm_eps, 277 | resnet_act_fn=act_fn, 278 | output_scale_factor=mid_block_scale_factor, 279 | resnet_time_scale_shift=resnet_time_scale_shift, 280 | cross_attention_dim=cross_attention_dim, 281 | attn_num_head_channels=attention_head_dim[-1], 282 | resnet_groups=norm_num_groups, 283 | dual_cross_attention=dual_cross_attention, 284 | use_linear_projection=use_linear_projection, 285 | upcast_attention=upcast_attention, 286 | ) 287 | elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": 288 | self.mid_block = UNetMidBlock2DSimpleCrossAttn( 289 | in_channels=block_out_channels[-1], 290 | temb_channels=time_embed_dim, 291 | resnet_eps=norm_eps, 292 | resnet_act_fn=act_fn, 293 | output_scale_factor=mid_block_scale_factor, 294 | cross_attention_dim=cross_attention_dim, 295 | attn_num_head_channels=attention_head_dim[-1], 296 | resnet_groups=norm_num_groups, 297 | resnet_time_scale_shift=resnet_time_scale_shift, 298 | ) 299 | elif mid_block_type is None: 300 | self.mid_block = None 301 | else: 302 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 303 | 304 | # count how many layers upsample the images 305 | self.num_upsamplers = 0 306 | 307 | # up 308 | reversed_block_out_channels = list(reversed(block_out_channels)) 309 | reversed_attention_head_dim = list(reversed(attention_head_dim)) 310 | only_cross_attention = list(reversed(only_cross_attention)) 311 | 312 | output_channel = reversed_block_out_channels[0] 313 | for i, up_block_type in enumerate(up_block_types): 314 | is_final_block = i == len(block_out_channels) - 1 315 | 316 | prev_output_channel = output_channel 317 | output_channel = reversed_block_out_channels[i] 318 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 319 | 320 | # add upsample block for all BUT final 
layer 321 | if not is_final_block: 322 | add_upsample = True 323 | self.num_upsamplers += 1 324 | else: 325 | add_upsample = False 326 | 327 | up_block = get_up_block( 328 | up_block_type, 329 | num_layers=layers_per_block + 1, 330 | in_channels=input_channel, 331 | out_channels=output_channel, 332 | prev_output_channel=prev_output_channel, 333 | temb_channels=time_embed_dim, 334 | add_upsample=add_upsample, 335 | resnet_eps=norm_eps, 336 | resnet_act_fn=act_fn, 337 | resnet_groups=norm_num_groups, 338 | cross_attention_dim=cross_attention_dim, 339 | attn_num_head_channels=reversed_attention_head_dim[i], 340 | dual_cross_attention=dual_cross_attention, 341 | use_linear_projection=use_linear_projection, 342 | only_cross_attention=only_cross_attention[i], 343 | upcast_attention=upcast_attention, 344 | resnet_time_scale_shift=resnet_time_scale_shift, 345 | ) 346 | self.up_blocks.append(up_block) 347 | prev_output_channel = output_channel 348 | 349 | # out 350 | if norm_num_groups is not None: 351 | self.conv_norm_out = SVDGroupNorm( 352 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps 353 | ) 354 | self.conv_act = nn.SiLU() 355 | else: 356 | self.conv_norm_out = None 357 | self.conv_act = None 358 | 359 | conv_out_padding = (conv_out_kernel - 1) // 2 360 | self.conv_out = SVDConv2d( 361 | block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding 362 | ) 363 | 364 | @property 365 | def attn_processors(self) -> Dict[str, AttnProcessor]: 366 | r""" 367 | Returns: 368 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 369 | indexed by its weight name. 370 | """ 371 | # set recursively 372 | processors = {} 373 | 374 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]): 375 | if hasattr(module, "set_processor"): 376 | processors[f"{name}.processor"] = module.processor 377 | 378 | for sub_name, child in module.named_children(): 379 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 380 | 381 | return processors 382 | 383 | for name, module in self.named_children(): 384 | fn_recursive_add_processors(name, module, processors) 385 | 386 | return processors 387 | 388 | def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]): 389 | r""" 390 | Parameters: 391 | `processor (`dict` of `AttnProcessor` or `AttnProcessor`): 392 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 393 | of **all** `CrossAttention` layers. 394 | In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.: 395 | 396 | """ 397 | count = len(self.attn_processors.keys()) 398 | 399 | if isinstance(processor, dict) and len(processor) != count: 400 | raise ValueError( 401 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 402 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 
403 | ) 404 | 405 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 406 | if hasattr(module, "set_processor"): 407 | if not isinstance(processor, dict): 408 | module.set_processor(processor) 409 | else: 410 | module.set_processor(processor.pop(f"{name}.processor")) 411 | 412 | for sub_name, child in module.named_children(): 413 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 414 | 415 | for name, module in self.named_children(): 416 | fn_recursive_attn_processor(name, module, processor) 417 | 418 | def set_attention_slice(self, slice_size): 419 | r""" 420 | Enable sliced attention computation. 421 | 422 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 423 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 424 | 425 | Args: 426 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 427 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If 428 | `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is 429 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 430 | must be a multiple of `slice_size`. 431 | """ 432 | sliceable_head_dims = [] 433 | 434 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): 435 | if hasattr(module, "set_attention_slice"): 436 | sliceable_head_dims.append(module.sliceable_head_dim) 437 | 438 | for child in module.children(): 439 | fn_recursive_retrieve_slicable_dims(child) 440 | 441 | # retrieve number of attention layers 442 | for module in self.children(): 443 | fn_recursive_retrieve_slicable_dims(module) 444 | 445 | num_slicable_layers = len(sliceable_head_dims) 446 | 447 | if slice_size == "auto": 448 | # half the attention head size is usually a good trade-off between 449 | # speed and memory 450 | slice_size = [dim // 2 for dim in sliceable_head_dims] 451 | elif slice_size == "max": 452 | # make smallest slice possible 453 | slice_size = num_slicable_layers * [1] 454 | 455 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 456 | 457 | if len(slice_size) != len(sliceable_head_dims): 458 | raise ValueError( 459 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 460 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 461 | ) 462 | 463 | for i in range(len(slice_size)): 464 | size = slice_size[i] 465 | dim = sliceable_head_dims[i] 466 | if size is not None and size > dim: 467 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 468 | 469 | # Recursively walk through all the children. 
470 | # Any children which exposes the set_attention_slice method 471 | # gets the message 472 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 473 | if hasattr(module, "set_attention_slice"): 474 | module.set_attention_slice(slice_size.pop()) 475 | 476 | for child in module.children(): 477 | fn_recursive_set_attention_slice(child, slice_size) 478 | 479 | reversed_slice_size = list(reversed(slice_size)) 480 | for module in self.children(): 481 | fn_recursive_set_attention_slice(module, reversed_slice_size) 482 | 483 | def _set_gradient_checkpointing(self, module, value=False): 484 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): 485 | module.gradient_checkpointing = value 486 | 487 | def forward( 488 | self, 489 | sample: torch.FloatTensor, 490 | timestep: Union[torch.Tensor, float, int], 491 | encoder_hidden_states: torch.Tensor, 492 | class_labels: Optional[torch.Tensor] = None, 493 | timestep_cond: Optional[torch.Tensor] = None, 494 | attention_mask: Optional[torch.Tensor] = None, 495 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 496 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 497 | mid_block_additional_residual: Optional[torch.Tensor] = None, 498 | return_dict: bool = True, 499 | ) -> Union[UNet2DConditionOutput, Tuple]: 500 | r""" 501 | Args: 502 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor 503 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps 504 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states 505 | return_dict (`bool`, *optional*, defaults to `True`): 506 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. 507 | cross_attention_kwargs (`dict`, *optional*): 508 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under 509 | `self.processor` in 510 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). 511 | 512 | Returns: 513 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 514 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When 515 | returning a tuple, the first element is the sample tensor. 516 | """ 517 | # By default samples have to be AT least a multiple of the overall upsampling factor. 518 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). 519 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 520 | # on the fly if necessary. 521 | default_overall_up_factor = 2**self.num_upsamplers 522 | 523 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 524 | forward_upsample_size = False 525 | upsample_size = None 526 | 527 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 528 | logger.info("Forward upsample size to force interpolation output size.") 529 | forward_upsample_size = True 530 | 531 | # prepare attention_mask 532 | if attention_mask is not None: 533 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 534 | attention_mask = attention_mask.unsqueeze(1) 535 | 536 | # 0. center input if necessary 537 | if self.config.center_input_sample: 538 | sample = 2 * sample - 1.0 539 | 540 | # 1. 
time 541 | timesteps = timestep 542 | if not torch.is_tensor(timesteps): 543 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 544 | # This would be a good case for the `match` statement (Python 3.10+) 545 | is_mps = sample.device.type == "mps" 546 | if isinstance(timestep, float): 547 | dtype = torch.float32 if is_mps else torch.float64 548 | else: 549 | dtype = torch.int32 if is_mps else torch.int64 550 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 551 | elif len(timesteps.shape) == 0: 552 | timesteps = timesteps[None].to(sample.device) 553 | 554 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 555 | timesteps = timesteps.expand(sample.shape[0]) 556 | 557 | t_emb = self.time_proj(timesteps) 558 | 559 | # timesteps does not contain any weights and will always return f32 tensors 560 | # but time_embedding might actually be running in fp16. so we need to cast here. 561 | # there might be better ways to encapsulate this. 562 | t_emb = t_emb.to(dtype=self.dtype) 563 | 564 | emb = self.time_embedding(t_emb, timestep_cond) 565 | 566 | if self.class_embedding is not None: 567 | if class_labels is None: 568 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 569 | 570 | if self.config.class_embed_type == "timestep": 571 | class_labels = self.time_proj(class_labels) 572 | 573 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) 574 | emb = emb + class_emb 575 | 576 | # 2. pre-process 577 | sample = self.conv_in(sample) 578 | 579 | # 3. down 580 | down_block_res_samples = (sample,) 581 | for downsample_block in self.down_blocks: 582 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 583 | sample, res_samples = downsample_block( 584 | hidden_states=sample, 585 | temb=emb, 586 | encoder_hidden_states=encoder_hidden_states, 587 | attention_mask=attention_mask, 588 | cross_attention_kwargs=cross_attention_kwargs, 589 | ) 590 | else: 591 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 592 | 593 | down_block_res_samples += res_samples 594 | 595 | if down_block_additional_residuals is not None: 596 | new_down_block_res_samples = () 597 | 598 | for down_block_res_sample, down_block_additional_residual in zip( 599 | down_block_res_samples, down_block_additional_residuals 600 | ): 601 | down_block_res_sample += down_block_additional_residual 602 | new_down_block_res_samples += (down_block_res_sample,) 603 | 604 | down_block_res_samples = new_down_block_res_samples 605 | 606 | # 4. mid 607 | if self.mid_block is not None: 608 | sample = self.mid_block( 609 | sample, 610 | emb, 611 | encoder_hidden_states=encoder_hidden_states, 612 | attention_mask=attention_mask, 613 | cross_attention_kwargs=cross_attention_kwargs, 614 | ) 615 | 616 | if mid_block_additional_residual is not None: 617 | sample += mid_block_additional_residual 618 | 619 | # 5. 
up 620 | for i, upsample_block in enumerate(self.up_blocks): 621 | is_final_block = i == len(self.up_blocks) - 1 622 | 623 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 624 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 625 | 626 | # if we have not reached the final block and need to forward the 627 | # upsample size, we do it here 628 | if not is_final_block and forward_upsample_size: 629 | upsample_size = down_block_res_samples[-1].shape[2:] 630 | 631 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 632 | sample = upsample_block( 633 | hidden_states=sample, 634 | temb=emb, 635 | res_hidden_states_tuple=res_samples, 636 | encoder_hidden_states=encoder_hidden_states, 637 | cross_attention_kwargs=cross_attention_kwargs, 638 | upsample_size=upsample_size, 639 | attention_mask=attention_mask, 640 | ) 641 | else: 642 | sample = upsample_block( 643 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 644 | ) 645 | 646 | # 6. post-process 647 | if self.conv_norm_out: 648 | sample = self.conv_norm_out(sample) 649 | sample = self.conv_act(sample) 650 | sample = self.conv_out(sample) 651 | 652 | if not return_dict: 653 | return (sample,) 654 | 655 | return UNet2DConditionOutput(sample=sample) 656 | -------------------------------------------------------------------------------- /svdiff_pytorch/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from einops import rearrange 5 | 6 | 7 | 8 | class SVDConv2d(nn.Conv2d): 9 | def __init__( 10 | self, 11 | in_channels: int, 12 | out_channels: int, 13 | kernel_size: int, 14 | scale: float = 1.0, 15 | **kwargs 16 | ): 17 | nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) 18 | assert type(kernel_size) is int 19 | weight_reshaped = rearrange(self.weight, 'co cin h w -> co (cin h w)') 20 | self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False) 21 | # initialize to 0 for smooth tuning 22 | self.delta = nn.Parameter(torch.zeros_like(self.S)) 23 | self.weight.requires_grad = False 24 | self.done_svd = False 25 | self.scale = scale 26 | self.reset_parameters() 27 | 28 | def set_scale(self, scale: float): 29 | self.scale = scale 30 | 31 | def perform_svd(self): 32 | # shape 33 | weight_reshaped = rearrange(self.weight, 'co cin h w -> co (cin h w)') 34 | self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False) 35 | self.done_svd = True 36 | 37 | def reset_parameters(self): 38 | nn.Conv2d.reset_parameters(self) 39 | if hasattr(self, 'delta'): 40 | nn.init.zeros_(self.delta) 41 | 42 | def forward(self, x: torch.Tensor): 43 | if not self.done_svd: 44 | # this happens after loading the state dict 45 | self.perform_svd() 46 | weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype) 47 | weight_updated = rearrange(weight_updated, 'co (cin h w) -> co cin h w', cin=self.weight.size(1), h=self.weight.size(2), w=self.weight.size(3)) 48 | return F.conv2d(x, weight_updated, self.bias, self.stride, self.padding, self.dilation, self.groups) 49 | 50 | 51 | class SVDConv1d(nn.Conv1d): 52 | def __init__( 53 | self, 54 | in_channels: int, 55 | out_channels: int, 56 | kernel_size: int, 57 | scale: float = 1.0, 58 | **kwargs 59 
| ):
60 |         nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
61 |         assert type(kernel_size) is int
62 |         weight_reshaped = rearrange(self.weight, 'co cin k -> co (cin k)')  # Conv1d weight is 3-D: (co, cin, k)
63 |         self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
64 |         # initialize to 0 for smooth tuning
65 |         self.delta = nn.Parameter(torch.zeros_like(self.S))
66 |         self.weight.requires_grad = False
67 |         self.done_svd = False
68 |         self.scale = scale
69 |         self.reset_parameters()
70 | 
71 |     def set_scale(self, scale: float):
72 |         self.scale = scale
73 | 
74 |     def perform_svd(self):
75 |         # shape
76 |         weight_reshaped = rearrange(self.weight, 'co cin k -> co (cin k)')
77 |         self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
78 |         self.done_svd = True
79 | 
80 |     def reset_parameters(self):
81 |         nn.Conv1d.reset_parameters(self)
82 |         if hasattr(self, 'delta'):
83 |             nn.init.zeros_(self.delta)
84 | 
85 |     def forward(self, x: torch.Tensor):
86 |         if not self.done_svd:
87 |             # this happens after loading the state dict
88 |             self.perform_svd()
89 |         weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
90 |         weight_updated = rearrange(weight_updated, 'co (cin k) -> co cin k', cin=self.weight.size(1), k=self.weight.size(2))
91 |         return F.conv1d(x, weight_updated, self.bias, self.stride, self.padding, self.dilation, self.groups)
92 | 
93 | 
94 | 
95 | class SVDLinear(nn.Linear):
96 |     def __init__(
97 |         self,
98 |         in_features: int,
99 |         out_features: int,
100 |         scale: float = 1.0,
101 |         **kwargs
102 |     ):
103 |         nn.Linear.__init__(self, in_features, out_features, **kwargs)
104 |         self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
105 |         # initialize to 0 for smooth tuning
106 |         self.delta = nn.Parameter(torch.zeros_like(self.S))
107 |         self.weight.requires_grad = False
108 |         self.done_svd = False
109 |         self.scale = scale
110 |         self.reset_parameters()
111 | 
112 |     def set_scale(self, scale: float):
113 |         self.scale = scale
114 | 
115 |     def perform_svd(self):
116 |         self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
117 |         self.done_svd = True
118 | 
119 |     def reset_parameters(self):
120 |         nn.Linear.reset_parameters(self)
121 |         if hasattr(self, 'delta'):
122 |             nn.init.zeros_(self.delta)
123 | 
124 |     def forward(self, x: torch.Tensor):
125 |         if not self.done_svd:
126 |             # this happens after loading the state dict
127 |             self.perform_svd()
128 |         weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
129 |         return F.linear(x, weight_updated, bias=self.bias)
130 | 
131 | 
132 | class SVDEmbedding(nn.Embedding):
133 |     # SVDiff spectral shift applied to the embedding weight via its SVD
134 |     def __init__(
135 |         self,
136 |         num_embeddings: int,
137 |         embedding_dim: int,
138 |         scale: float = 1.0,
139 |         **kwargs
140 |     ):
141 |         nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
142 |         self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
143 |         # initialize to 0 for smooth tuning
144 |         self.delta = nn.Parameter(torch.zeros_like(self.S))
145 |         self.weight.requires_grad = False
146 |         self.done_svd = False
147 |         self.scale = scale
148 |         self.reset_parameters()
149 | 
150 |     def set_scale(self, scale: float):
151 |         self.scale = scale
152 | 
153 |     def perform_svd(self):
154 |         self.U,
self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False) 155 | self.done_svd = True 156 | 157 | def reset_parameters(self): 158 | nn.Embedding.reset_parameters(self) 159 | if hasattr(self, 'delta'): 160 | nn.init.zeros_(self.delta) 161 | 162 | def forward(self, x: torch.Tensor): 163 | if not self.done_svd: 164 | # this happens after loading the state dict 165 | self.perform_svd() 166 | weight_updated = self.U.to(x.device) @ torch.diag(F.relu(self.S.to(x.device)+self.scale * self.delta)) @ self.Vh.to(x.device) 167 | return F.embedding(x, weight_updated, padding_idx=self.padding_idx, max_norm=self.max_norm, norm_type=self.norm_type, scale_grad_by_freq=self.scale_grad_by_freq, sparse=self.sparse) 168 | 169 | 170 | # 1-D 171 | class SVDLayerNorm(nn.LayerNorm): 172 | def __init__( 173 | self, 174 | normalized_shape: int, 175 | scale: float = 1.0, 176 | **kwargs 177 | ): 178 | nn.LayerNorm.__init__(self, normalized_shape=normalized_shape, **kwargs) 179 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False) 180 | # initialize to 0 for smooth tuning 181 | self.delta = nn.Parameter(torch.zeros_like(self.S)) 182 | self.weight.requires_grad = False 183 | self.done_svd = False 184 | self.scale = scale 185 | self.reset_parameters() 186 | 187 | def set_scale(self, scale: float): 188 | self.scale = scale 189 | 190 | def perform_svd(self): 191 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False) 192 | self.done_svd = True 193 | 194 | def reset_parameters(self): 195 | nn.LayerNorm.reset_parameters(self) 196 | if hasattr(self, 'delta'): 197 | nn.init.zeros_(self.delta) 198 | 199 | def forward(self, x: torch.Tensor): 200 | if not self.done_svd: 201 | # this happens after loading the state dict 202 | self.perform_svd() 203 | weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype) 204 | weight_updated = weight_updated.squeeze(0) 205 | return F.layer_norm(x, normalized_shape=self.normalized_shape, weight=weight_updated, bias=self.bias, eps=self.eps) 206 | 207 | 208 | class SVDGroupNorm(nn.GroupNorm): 209 | def __init__( 210 | self, 211 | num_groups: int, 212 | num_channels: int, 213 | scale: float = 1.0, 214 | **kwargs 215 | ): 216 | nn.GroupNorm.__init__(self, num_groups, num_channels, **kwargs) 217 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False) 218 | # initialize to 0 for smooth tuning 219 | self.delta = nn.Parameter(torch.zeros_like(self.S)) 220 | self.weight.requires_grad = False 221 | self.done_svd = False 222 | self.scale = scale 223 | self.reset_parameters() 224 | 225 | def set_scale(self, scale: float): 226 | self.scale = scale 227 | 228 | def perform_svd(self): 229 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False) 230 | self.done_svd = True 231 | 232 | def reset_parameters(self): 233 | nn.GroupNorm.reset_parameters(self) 234 | if hasattr(self, 'delta'): 235 | nn.init.zeros_(self.delta) 236 | 237 | def forward(self, x: torch.Tensor): 238 | if not self.done_svd: 239 | # this happens after loading the state dict 240 | self.perform_svd() 241 | weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype) 242 | weight_updated = weight_updated.squeeze(0) 243 | return F.group_norm(x, num_groups=self.num_groups, 
weight=weight_updated, bias=self.bias, eps=self.eps) 244 | 245 | -------------------------------------------------------------------------------- /svdiff_pytorch/pipeline_stable_diffusion_ddim_inversion.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Union 2 | import PIL 3 | import torch 4 | from diffusers import StableDiffusionPipeline, DDIMInverseScheduler 5 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess 6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero import Pix2PixInversionPipelineOutput 7 | 8 | 9 | class StableDiffusionPipelineWithDDIMInversion(StableDiffusionPipeline): 10 | def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker: bool = True): 11 | super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker) 12 | self.inverse_scheduler = DDIMInverseScheduler.from_config(self.scheduler.config) 13 | # self.register_modules(inverse_scheduler=DDIMInverseScheduler.from_config(self.scheduler.config)) 14 | 15 | 16 | def prepare_image_latents(self, image, batch_size, dtype, device, generator=None): 17 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 18 | raise ValueError( 19 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 20 | ) 21 | 22 | image = image.to(device=device, dtype=dtype) 23 | 24 | if isinstance(generator, list) and len(generator) != batch_size: 25 | raise ValueError( 26 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 27 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 28 | ) 29 | 30 | if isinstance(generator, list): 31 | init_latents = [ 32 | self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) 33 | ] 34 | init_latents = torch.cat(init_latents, dim=0) 35 | else: 36 | init_latents = self.vae.encode(image).latent_dist.sample(generator) 37 | 38 | init_latents = self.vae.config.scaling_factor * init_latents 39 | 40 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: 41 | raise ValueError( 42 | f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 
43 | ) 44 | else: 45 | init_latents = torch.cat([init_latents], dim=0) 46 | 47 | latents = init_latents 48 | 49 | return latents 50 | 51 | def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int): 52 | pred_type = self.inverse_scheduler.config.prediction_type 53 | alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep] 54 | 55 | beta_prod_t = 1 - alpha_prod_t 56 | 57 | if pred_type == "epsilon": 58 | return model_output 59 | elif pred_type == "sample": 60 | return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5) 61 | elif pred_type == "v_prediction": 62 | return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample 63 | else: 64 | raise ValueError( 65 | f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`" 66 | ) 67 | 68 | def auto_corr_loss(self, hidden_states, generator=None): 69 | batch_size, channel, height, width = hidden_states.shape 70 | if batch_size > 1: 71 | raise ValueError("Only batch_size 1 is supported for now") 72 | 73 | hidden_states = hidden_states.squeeze(0) 74 | # hidden_states must be shape [C,H,W] now 75 | reg_loss = 0.0 76 | for i in range(hidden_states.shape[0]): 77 | noise = hidden_states[i][None, None, :, :] 78 | while True: 79 | roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item() 80 | reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2 81 | reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2 82 | 83 | if noise.shape[2] <= 8: 84 | break 85 | noise = F.avg_pool2d(noise, kernel_size=2) 86 | return reg_loss 87 | 88 | def kl_divergence(self, hidden_states): 89 | mean = hidden_states.mean() 90 | var = hidden_states.var() 91 | return var + mean**2 - 1 - torch.log(var + 1e-7) 92 | 93 | 94 | # based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py#L1063 95 | @torch.no_grad() 96 | def invert( 97 | self, 98 | prompt: Optional[str] = None, 99 | image: Union[torch.FloatTensor, PIL.Image.Image] = None, 100 | num_inference_steps: int = 50, 101 | guidance_scale: float = 1, 102 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 103 | latents: Optional[torch.FloatTensor] = None, 104 | prompt_embeds: Optional[torch.FloatTensor] = None, 105 | output_type: Optional[str] = "pil", 106 | return_dict: bool = True, 107 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 108 | callback_steps: Optional[int] = 1, 109 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 110 | lambda_auto_corr: float = 20.0, 111 | lambda_kl: float = 20.0, 112 | num_reg_steps: int = 0, # disabled 113 | num_auto_corr_rolls: int = 5, 114 | ): 115 | # 1. Define call parameters 116 | if prompt is not None and isinstance(prompt, str): 117 | batch_size = 1 118 | elif prompt is not None and isinstance(prompt, list): 119 | batch_size = len(prompt) 120 | else: 121 | batch_size = prompt_embeds.shape[0] 122 | if cross_attention_kwargs is None: 123 | cross_attention_kwargs = {} 124 | 125 | device = self._execution_device 126 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 127 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 128 | # corresponds to doing no classifier free guidance. 129 | do_classifier_free_guidance = guidance_scale > 1.0 130 | 131 | # 3. 
Preprocess image 132 | image = preprocess(image) 133 | 134 | # 4. Prepare latent variables 135 | latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator) 136 | 137 | # 5. Encode input prompt 138 | num_images_per_prompt = 1 139 | prompt_embeds = self._encode_prompt( 140 | prompt, 141 | device, 142 | num_images_per_prompt, 143 | do_classifier_free_guidance, 144 | prompt_embeds=prompt_embeds, 145 | ) 146 | 147 | # 4. Prepare timesteps 148 | self.inverse_scheduler.set_timesteps(num_inference_steps, device=device) 149 | timesteps = self.inverse_scheduler.timesteps 150 | 151 | # 7. Denoising loop where we obtain the cross-attention maps. 152 | num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order 153 | with self.progress_bar(total=num_inference_steps - 1) as progress_bar: 154 | for i, t in enumerate(timesteps[:-1]): 155 | # expand the latents if we are doing classifier free guidance 156 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 157 | latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t) 158 | 159 | # predict the noise residual 160 | noise_pred = self.unet( 161 | latent_model_input, 162 | t, 163 | encoder_hidden_states=prompt_embeds, 164 | cross_attention_kwargs=cross_attention_kwargs, 165 | ).sample 166 | 167 | # perform guidance 168 | if do_classifier_free_guidance: 169 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 170 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 171 | 172 | # regularization of the noise prediction 173 | with torch.enable_grad(): 174 | for _ in range(num_reg_steps): 175 | if lambda_auto_corr > 0: 176 | for _ in range(num_auto_corr_rolls): 177 | var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) 178 | 179 | # Derive epsilon from model output before regularizing to IID standard normal 180 | var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) 181 | 182 | l_ac = self.auto_corr_loss(var_epsilon, generator=generator) 183 | l_ac.backward() 184 | 185 | grad = var.grad.detach() / num_auto_corr_rolls 186 | noise_pred = noise_pred - lambda_auto_corr * grad 187 | 188 | if lambda_kl > 0: 189 | var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True) 190 | 191 | # Derive epsilon from model output before regularizing to IID standard normal 192 | var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t) 193 | 194 | l_kld = self.kl_divergence(var_epsilon) 195 | l_kld.backward() 196 | 197 | grad = var.grad.detach() 198 | noise_pred = noise_pred - lambda_kl * grad 199 | 200 | noise_pred = noise_pred.detach() 201 | 202 | # compute the previous noisy sample x_t -> x_t-1 203 | latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample 204 | 205 | # call the callback, if provided 206 | if i == len(timesteps) - 1 or ( 207 | (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0 208 | ): 209 | progress_bar.update() 210 | if callback is not None and i % callback_steps == 0: 211 | callback(i, t, latents) 212 | 213 | inverted_latents = latents.detach().clone() 214 | 215 | # 8. Post-processing 216 | image = self.decode_latents(latents.detach()) 217 | 218 | # Offload last model to CPU 219 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 220 | self.final_offload_hook.offload() 221 | 222 | # 9. Convert to PIL. 
223 | if output_type == "pil": 224 | image = self.numpy_to_pil(image) 225 | 226 | if not return_dict: 227 | return (inverted_latents, image) 228 | 229 | return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image) 230 | 231 | 232 | 233 | if __name__ == '__main__': 234 | from PIL import Image 235 | from diffusers import DDIMScheduler 236 | model_id = "CompVis/stable-diffusion-v1-4" 237 | input_prompt = "A photo of Barack Obama" 238 | prompt = "A photo of Barack Obama smiling with a big grin" 239 | url = "obama.png" # https://github.com/cccntu/efficient-prompt-to-prompt/blob/main/ddim-inversion.ipynb 240 | 241 | pipe = StableDiffusionPipelineWithDDIMInversion.from_pretrained( 242 | model_id, 243 | # make sure to load ddim here 244 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"), 245 | ) 246 | image = Image.open(url).convert("RGB").resize((512, 512)) 247 | # in SVDiff, they use guidance scale=1 in ddim inversion 248 | inv_latents = pipe.invert(input_prompt, image=image, guidance_scale=1.0).latents 249 | image = pipe(prompt, latents=inv_latents).images[0] 250 | image.save("out.png") -------------------------------------------------------------------------------- /svdiff_pytorch/transformers_models_clip/__init__.py: -------------------------------------------------------------------------------- 1 | # all files in this folder were taken from https://github.com/huggingface/transformers/blob/v4.27.3/src/transformers/models/clip/modeling_clip.py 2 | # so, these files follow the LICENSE of transformers -------------------------------------------------------------------------------- /svdiff_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import inspect 3 | from PIL import Image 4 | 5 | import torch 6 | import accelerate 7 | from accelerate.utils import set_module_tensor_to_device 8 | from diffusers import ( 9 | LMSDiscreteScheduler, 10 | DDIMScheduler, 11 | PNDMScheduler, 12 | DPMSolverMultistepScheduler, 13 | EulerDiscreteScheduler, 14 | EulerAncestralDiscreteScheduler, 15 | ) 16 | from transformers import CLIPTextModel, CLIPTextConfig 17 | from diffusers import UNet2DConditionModel 18 | from safetensors.torch import safe_open 19 | import huggingface_hub 20 | from svdiff_pytorch import UNet2DConditionModelForSVDiff, CLIPTextModelForSVDiff 21 | 22 | 23 | 24 | def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs): 25 | """ 26 | https://github.com/huggingface/diffusers/blob/v0.14.0/src/diffusers/models/modeling_utils.py#L541 27 | """ 28 | config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs) 29 | original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs) 30 | state_dict = original_model.state_dict() 31 | with accelerate.init_empty_weights(): 32 | model = UNet2DConditionModelForSVDiff.from_config(config) 33 | # load pre-trained weights 34 | param_device = "cpu" 35 | torch_dtype = kwargs["torch_dtype"] if "torch_dtype" in kwargs else None 36 | spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n} 37 | state_dict.update(spectral_shifts_weights) 38 | # move the params from meta device to cpu 39 | missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) 40 | if len(missing_keys) > 0: 41 | raise ValueError( 42 | f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following 
keys are" 43 | f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" 44 | " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize" 45 | " those weights or else make sure your checkpoint file is correct." 46 | ) 47 | 48 | for param_name, param in state_dict.items(): 49 | accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) 50 | if accepts_dtype: 51 | set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype) 52 | else: 53 | set_module_tensor_to_device(model, param_name, param_device, value=param) 54 | 55 | if spectral_shifts_ckpt: 56 | if os.path.isdir(spectral_shifts_ckpt): 57 | spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors") 58 | elif not os.path.exists(spectral_shifts_ckpt): 59 | # download from hub 60 | hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs 61 | spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs) 62 | assert os.path.exists(spectral_shifts_ckpt) 63 | 64 | with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f: 65 | for key in f.keys(): 66 | # spectral_shifts_weights[key] = f.get_tensor(key) 67 | accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys()) 68 | if accepts_dtype: 69 | set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype) 70 | else: 71 | set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key)) 72 | print(f"Resumed from {spectral_shifts_ckpt}") 73 | if "torch_dtype"in kwargs: 74 | model = model.to(kwargs["torch_dtype"]) 75 | model.register_to_config(_name_or_path=pretrained_model_name_or_path) 76 | # Set model in evaluation mode to deactivate DropOut modules by default 77 | model.eval() 78 | del original_model 79 | torch.cuda.empty_cache() 80 | return model 81 | 82 | 83 | 84 | def load_text_encoder_for_svdiff( 85 | pretrained_model_name_or_path, 86 | spectral_shifts_ckpt=None, 87 | hf_hub_kwargs=None, 88 | **kwargs 89 | ): 90 | """ 91 | https://github.com/huggingface/diffusers/blob/v0.14.0/src/diffusers/models/modeling_utils.py#L541 92 | """ 93 | config = CLIPTextConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) 94 | original_model = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, **kwargs) 95 | state_dict = original_model.state_dict() 96 | with accelerate.init_empty_weights(): 97 | model = CLIPTextModelForSVDiff(config) 98 | # load pre-trained weights 99 | param_device = "cpu" 100 | torch_dtype = kwargs["torch_dtype"] if "torch_dtype" in kwargs else None 101 | spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n} 102 | state_dict.update(spectral_shifts_weights) 103 | # move the params from meta device to cpu 104 | missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) 105 | if len(missing_keys) > 0: 106 | raise ValueError( 107 | f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are" 108 | f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass" 109 | " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize" 110 | " those weights or else make sure your checkpoint file is correct." 
111 |         )
112 | 
113 |     for param_name, param in state_dict.items():
114 |         accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
115 |         if accepts_dtype:
116 |             set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
117 |         else:
118 |             set_module_tensor_to_device(model, param_name, param_device, value=param)
119 | 
120 |     if spectral_shifts_ckpt:
121 |         if os.path.isdir(spectral_shifts_ckpt):
122 |             spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts_te.safetensors")
123 |         elif not os.path.exists(spectral_shifts_ckpt):
124 |             # download from hub
125 |             hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
126 |             try:
127 |                 spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts_te.safetensors", **hf_hub_kwargs)
128 |             except huggingface_hub.utils.EntryNotFoundError:
129 |                 spectral_shifts_ckpt = None
130 |         # load state dict only if `spectral_shifts_te.safetensors` exists
131 |         if spectral_shifts_ckpt is not None and os.path.exists(spectral_shifts_ckpt):
132 |             with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
133 |                 for key in f.keys():
134 |                     # spectral_shifts_weights[key] = f.get_tensor(key)
135 |                     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
136 |                     if accepts_dtype:
137 |                         set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
138 |                     else:
139 |                         set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
140 |             print(f"Resumed from {spectral_shifts_ckpt}")
141 | 
142 |     if "torch_dtype" in kwargs:
143 |         model = model.to(kwargs["torch_dtype"])
144 |     # model.register_to_config(_name_or_path=pretrained_model_name_or_path)
145 |     # Set model in evaluation mode to deactivate DropOut modules by default
146 |     model.eval()
147 |     del original_model
148 |     torch.cuda.empty_cache()
149 |     return model
150 | 
151 | 
152 | 
153 | def image_grid(imgs, rows, cols):
154 |     assert len(imgs) == rows * cols
155 |     w, h = imgs[0].size
156 |     grid = Image.new('RGB', size=(cols * w, rows * h))
157 |     for i, img in enumerate(imgs):
158 |         grid.paste(img, box=(i % cols * w, i // cols * h))
159 |     return grid
160 | 
161 | 
162 | def slerp(val, low, high):
163 |     """ taken from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4
164 |     """
165 |     low_norm = low/torch.norm(low, dim=1, keepdim=True)
166 |     high_norm = high/torch.norm(high, dim=1, keepdim=True)
167 |     omega = torch.acos((low_norm*high_norm).sum(1))
168 |     so = torch.sin(omega)
169 |     res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
170 |     return res
171 | 
172 | 
173 | def slerp_tensor(val, low, high):
174 |     shape = low.shape
175 |     res = slerp(val, low.flatten(1), high.flatten(1))
176 |     return res.reshape(shape)
177 | 
178 | 
179 | SCHEDULER_MAPPING = {
180 |     "ddim": DDIMScheduler,
181 |     "plms": PNDMScheduler,
182 |     "lms": LMSDiscreteScheduler,
183 |     "euler": EulerDiscreteScheduler,
184 |     "euler_ancestral": EulerAncestralDiscreteScheduler,
185 |     "dpm_solver++": DPMSolverMultistepScheduler,
186 | }
187 | 
188 | 
--------------------------------------------------------------------------------
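
The pieces above fit together as follows: the SVD layers in layers.py reparametrize each frozen weight as U diag(relu(S + scale * delta)) Vh with only the delta vector trainable, load_unet_for_svdiff and load_text_encoder_for_svdiff rebuild the Stable Diffusion modules with a trained delta checkpoint merged in, and SCHEDULER_MAPPING selects a sampler by name. The sketch below is not part of the repository; the base-model id, checkpoint path, prompt, and output filename are placeholders, and everything else uses only the functions defined above plus the public diffusers/torch API.

import torch
from diffusers import StableDiffusionPipeline
from svdiff_pytorch.utils import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING

# Placeholders: substitute your own base model and trained spectral-shift checkpoint directory
# (it should contain spectral_shifts.safetensors and, optionally, spectral_shifts_te.safetensors).
base_model = "runwayml/stable-diffusion-v1-5"
spectral_shifts_ckpt = "path/to/checkpoint-dir"

# Rebuild the UNet and text encoder with the learned deltas loaded into their SVD-parametrized layers.
unet = load_unet_for_svdiff(base_model, spectral_shifts_ckpt=spectral_shifts_ckpt, subfolder="unet")
text_encoder = load_text_encoder_for_svdiff(base_model, spectral_shifts_ckpt=spectral_shifts_ckpt, subfolder="text_encoder")

pipe = StableDiffusionPipeline.from_pretrained(base_model, unet=unet, text_encoder=text_encoder)
pipe.scheduler = SCHEDULER_MAPPING["ddim"].from_config(pipe.scheduler.config)
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

# Optionally rescale the spectral shifts at inference time: 0.0 recovers the base weights,
# 1.0 applies the full learned shift.
for module in unet.modules():
    if hasattr(module, "set_scale"):
        module.set_scale(1.0)

image = pipe("a photo of a sks dog in a bucket", num_inference_steps=50).images[0]
image.save("svdiff_sample.png")

Because only the one-dimensional delta vectors are trained while U, S, and Vh stay frozen, the saved spectral_shifts.safetensors file is far smaller than a full copy of the UNet weights, and set_scale gives a cheap way to interpolate between the base model and the personalized one.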