├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── .gitmodules
├── LICENSE
├── README.md
├── assets
│   ├── car-result.png
│   ├── chair-result.png
│   ├── dog.png
│   ├── kasumi.png
│   ├── kumamon.png
│   ├── lora.png
│   ├── scale-result.png
│   ├── scale.png
│   └── svdiff.png
├── inference.py
├── paper.png
├── requirements.txt
├── scripts
│   ├── svdiff_pytorch.ipynb
│   ├── train_dreambooth.py
│   └── train_dreambooth_lora.py
├── setup.py
├── svdiff_pytorch
│   ├── __init__.py
│   ├── diffusers_models
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   ├── cross_attention.py
│   │   ├── dual_transformer_2d.py
│   │   ├── embeddings.py
│   │   ├── resnet.py
│   │   ├── transformer_2d.py
│   │   ├── unet_2d_blocks.py
│   │   └── unet_2d_condition.py
│   ├── layers.py
│   ├── pipeline_stable_diffusion_ddim_inversion.py
│   ├── transformers_models_clip
│   │   ├── __init__.py
│   │   └── modeling_clip.py
│   └── utils.py
└── train_svdiff.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 | # This workflow will upload a Python Package using Twine when a release is created
2 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
3 |
4 | # This workflow uses actions that are not certified by GitHub.
5 | # They are provided by a third-party and are governed by
6 | # separate terms of service, privacy policy, and support
7 | # documentation.
8 |
9 | name: Upload Python Package
10 |
11 | on:
12 | release:
13 | types: [published]
14 |
15 | jobs:
16 | deploy:
17 |
18 | runs-on: ubuntu-latest
19 |
20 | steps:
21 | - uses: actions/checkout@v2
22 | - name: Set up Python
23 | uses: actions/setup-python@v2
24 | with:
25 | python-version: '3.x'
26 | - name: Install dependencies
27 | run: |
28 | python -m pip install --upgrade pip
29 | pip install build
30 | - name: Build package
31 | run: python -m build
32 | - name: Publish package
33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
34 | with:
35 | user: __token__
36 | password: ${{ secrets.PYPI_API_TOKEN }}
37 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | private
132 | test.ipynb
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "scripts/gradio"]
2 | path = scripts/gradio
3 | url = https://huggingface.co/spaces/svdiff-library/SVDiff-Training-UI
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 mkshing
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SVDiff-pytorch
2 |
3 | [](https://huggingface.co/spaces/svdiff-library/SVDiff-Training-UI)
4 |
5 |
6 | An implementation of [SVDiff: Compact Parameter Space for Diffusion Fine-Tuning](https://arxiv.org/abs/2303.11305) using d🧨ffusers.
7 |
8 | My summary tweet is found [here](https://twitter.com/mk1stats/status/1642865505106272257).
9 |
10 |
11 | 
12 | left: LoRA, right: SVDiff
13 |
14 |
15 | Compared with LoRA, SVDiff has roughly 0.5 M fewer trainable parameters, and its checkpoint file is only 1.2MB (LoRA: 3.1MB)!!
16 |
17 | 
18 |
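To check these numbers on your own checkpoint, you can count the parameters stored in the saved spectral shifts. A minimal sketch, assuming a training output directory that contains `spectral_shifts.safetensors` (the path below is a placeholder):

```python
# Minimal sketch: count the spectral-shift parameters saved by train_svdiff.py.
# "ckpt-dir-path" is a placeholder for your own training output directory.
from safetensors.torch import load_file

state_dict = load_file("ckpt-dir-path/spectral_shifts.safetensors")
num_params = sum(t.numel() for t in state_dict.values())
print(f"spectral shifts: {num_params / 1e6:.2f} M parameters")
```
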
19 | ## Updates
20 | ### 2023.4.11
21 | - Released v0.2.0 (please see [here](https://github.com/mkshing/svdiff-pytorch/releases/tag/v0.2.0) for the details). With this change, you get better results with fewer training steps than the first release, v0.1.1!!
22 | - Add [Single Image Editing](#single-image-editing)
23 |
24 | 
25 |
"photo of a ~~pink~~ **blue** chair with black legs" (without DDIM Inversion)
26 |
27 |
28 | ## Installation
29 | ```bash
30 | $ pip install svdiff-pytorch
31 | ```
32 | Or, install manually from source:
33 | ```bash
34 | $ git clone https://github.com/mkshing/svdiff-pytorch
35 | $ cd svdiff-pytorch && pip install -r requirements.txt
36 | ```
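Either way, a quick import check (run from the repository root if you installed from source) confirms that the package and the public entry points exported by `svdiff_pytorch/__init__.py` are available:

```python
# Quick sanity check that svdiff-pytorch is importable.
from svdiff_pytorch import (
    load_unet_for_svdiff,
    load_text_encoder_for_svdiff,
    StableDiffusionPipelineWithDDIMInversion,
)

print("svdiff-pytorch is ready")
```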
37 |
38 | ## Single-Subject Generation
39 | "Single-Subject Generation" is domain tuning on a single object or concept using 3-5 images (see Section 4.1 of the paper).
40 |
41 | ### Training
42 | According to the paper, the learning rate for SVDiff needs to be about 1,000 times larger than the learning rate used for regular fine-tuning.
43 |
44 | ```bash
45 | export MODEL_NAME="runwayml/stable-diffusion-v1-5"
46 | export INSTANCE_DIR="path-to-instance-images"
47 | export CLASS_DIR="path-to-class-images"
48 | export OUTPUT_DIR="path-to-save-model"
49 |
50 | accelerate launch train_svdiff.py \
51 | --pretrained_model_name_or_path=$MODEL_NAME \
52 | --instance_data_dir=$INSTANCE_DIR \
53 | --class_data_dir=$CLASS_DIR \
54 | --output_dir=$OUTPUT_DIR \
55 | --with_prior_preservation --prior_loss_weight=1.0 \
56 | --instance_prompt="photo of a sks dog" \
57 | --class_prompt="photo of a dog" \
58 | --resolution=512 \
59 | --train_batch_size=1 \
60 | --gradient_accumulation_steps=1 \
61 | --learning_rate=1e-3 \
62 | --learning_rate_1d=1e-6 \
63 | --train_text_encoder \
64 | --lr_scheduler="constant" \
65 | --lr_warmup_steps=0 \
66 | --num_class_images=200 \
67 | --max_train_steps=500
68 | ```
69 |
70 | ### Inference
71 |
72 | ```python
73 | from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
74 | import torch
75 |
76 | from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff
77 |
78 | pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
79 | spectral_shifts_ckpt_dir = "ckpt-dir-path"
80 | unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet")
81 | text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder")
82 | # load pipe
83 | pipe = StableDiffusionPipeline.from_pretrained(
84 | pretrained_model_name_or_path,
85 | unet=unet,
86 | text_encoder=text_encoder,
87 | )
88 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
89 | pipe.to("cuda")
90 | image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
91 | ```
92 |
93 | You can also use the following CLI. Once it finishes, the result is saved as `grid.png`.
94 |
95 | ```bash
96 | python inference.py \
97 | --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
98 | --spectral_shifts_ckpt="ckpt-dir-path" \
99 | --prompt="A picture of a sks dog in a bucket" \
100 | --scheduler_type="dpm_solver++" \
101 | --num_inference_steps=25 \
102 | --num_images_per_prompt=2
103 | ```
104 |
105 | ## Single Image Editing
106 | ### Training
107 | In Single Image Editing, your instance prompt should be just the description of your input image **without the identifier**.
108 |
109 | ```bash
110 | export MODEL_NAME="runwayml/stable-diffusion-v1-5"
111 | export INSTANCE_DIR="dir-path-to-input-image"
112 | export CLASS_DIR="path-to-class-images"
113 | export OUTPUT_DIR="path-to-save-model"
114 |
115 | accelerate launch train_svdiff.py \
116 | --pretrained_model_name_or_path=$MODEL_NAME \
117 | --instance_data_dir=$INSTANCE_DIR \
118 | --class_data_dir=$CLASS_DIR \
119 | --output_dir=$OUTPUT_DIR \
120 | --instance_prompt="photo of a pink chair with black legs" \
121 | --resolution=512 \
122 | --train_batch_size=1 \
123 | --gradient_accumulation_steps=1 \
124 | --learning_rate=1e-3 \
125 | --learning_rate_1d=1e-6 \
126 | --train_text_encoder \
127 | --lr_scheduler="constant" \
128 | --lr_warmup_steps=0 \
129 | --max_train_steps=500
130 | ```
131 |
132 | ### Inference
133 |
134 | ```python
135 | import torch
136 | from PIL import Image
137 | from diffusers import DDIMScheduler
138 | from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, StableDiffusionPipelineWithDDIMInversion
139 |
140 | pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
141 | spectral_shifts_ckpt_dir = "ckpt-dir-path"
142 | image = "path-to-image"
143 | source_prompt = "prompt-for-image"
144 | target_prompt = "prompt-you-want-to-generate"
145 |
146 | unet = load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="unet")
147 | text_encoder = load_text_encoder_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=spectral_shifts_ckpt_dir, subfolder="text_encoder")
148 | # load pipe
149 | pipe = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
150 | pretrained_model_name_or_path,
151 | unet=unet,
152 | text_encoder=text_encoder,
153 | )
154 | pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
155 | pipe.to("cuda")
156 |
157 | # (optional) ddim inversion
158 | # if you skip DDIM inversion, set inv_latents = None
159 | image = Image.open(image).convert("RGB").resize((512, 512))
160 | # SVDiff uses guidance_scale=1 for DDIM inversion
161 | # and inverts with the target_prompt for better results; see the comparison between source_prompt and target_prompt below.
162 | inv_latents = pipe.invert(target_prompt, image=image, guidance_scale=1.0).latents
163 |
164 | # They use a small cfg scale in Single Image Editing
165 | image = pipe(target_prompt, latents=inv_latents, guidance_scale=3, eta=0.5).images[0]
166 | ```
167 |
168 | DDIM inversion with the target prompt (left) vs. the source prompt (right):
169 |
170 | 
171 | "photo of a grey ~~Beetle~~ **Mustang** car" (original image: https://unsplash.com/photos/YEPDV3T8Vi8)
172 |
173 | To add more stochasticity, you can use slerp between the inverted latents and random noise:
174 | ```python
175 | from svdiff_pytorch.utils import slerp_tensor
176 |
177 | # prev steps omitted
178 | inv_latents = pipe.invert(target_prompt, image=image, guidance_scale=1.0).latents
179 | noise_latents = pipe.prepare_latents(inv_latents.shape[0], inv_latents.shape[1], 512, 512, dtype=inv_latents.dtype, device=pipe.device, generator=torch.Generator("cuda").manual_seed(0))
180 | inv_latents = slerp_tensor(0.5, inv_latents, noise_latents)
181 | image = pipe(target_prompt, latents=inv_latents).images[0]
182 | ```
183 |
184 |
185 | ## Gradio
186 | You can also try SVDiff-pytorch in a UI with [gradio](https://gradio.app/). This demo supports both training and inference!
187 |
188 | [](https://huggingface.co/spaces/svdiff-library/SVDiff-Training-UI)
189 |
190 | If you want to run it locally, run the following commands step by step.
191 | ```bash
192 | $ git clone --recursive https://github.com/mkshing/svdiff-pytorch.git
193 | $ cd svdiff-pytorch/scripts/gradio
194 | $ pip install -r requirements.txt
195 | $ export HF_TOKEN="YOUR_HF_TOKEN_HERE"
196 | $ python app.py
197 | ```
198 |
199 | ## Additional Features
200 |
201 | ### Spectral Shift Scaling
202 |
203 | 
204 |
205 | You can adjust the strength of the learned spectral shifts with the `--spectral_shifts_scale` flag of `inference.py` (or from Python, as sketched below).
206 |
207 | Here are results for scales of 0.8, 1.0, and 1.2 (1.0 is the default).
208 | 
209 |
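If you build the pipeline in Python rather than through the CLI, the same scaling can be applied by walking the modules and calling `set_scale`, mirroring what `inference.py` does. A sketch, assuming `pipe` was built with `load_unet_for_svdiff` / `load_text_encoder_for_svdiff` as in the inference examples above:

```python
# Sketch: scale the spectral shifts of an already-loaded SVDiff pipeline.
# This mirrors the --spectral_shifts_scale handling in inference.py.
scale = 1.2
for model in (pipe.unet, pipe.text_encoder):
    for module in model.modules():
        if hasattr(module, "set_scale"):
            module.set_scale(scale=scale)

image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```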
210 |
211 | ### Fast prior generation using ToMe
212 | With [ToMe for SD](https://github.com/dbolya/tomesd), prior (class) image generation becomes faster!
213 | ```bash
214 | $ pip install tomesd
215 | ```
216 | Then add `--enable_tome_merging` to your training arguments!
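ToMe can also speed up sampling at inference time; `inference.py` already patches the pipeline when `tomesd` is importable. A minimal sketch, assuming `pipe` is a loaded `StableDiffusionPipeline` as in the examples above:

```python
# Optional: apply ToMe token merging to an already-built pipeline
# (this is what inference.py does when tomesd is installed).
import tomesd

tomesd.apply_patch(pipe, ratio=0.5)  # a higher ratio merges more tokens for a larger speed-up
```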
217 |
218 | ## Citation
219 |
220 | ```bibtex
221 | @misc{https://doi.org/10.48550/arXiv.2303.11305,
222 | title = {SVDiff: Compact Parameter Space for Diffusion Fine-Tuning},
223 | author = {Ligong Han and Yinxiao Li and Han Zhang and Peyman Milanfar and Dimitris Metaxas and Feng Yang},
224 | year = {2023},
225 | eprint = {2303.11305},
226 | archivePrefix = {arXiv},
227 | primaryClass = {cs.CV},
228 | url = {https://arxiv.org/abs/2303.11305}
229 | }
230 | ```
231 |
232 | ```bibtex
233 | @misc{hu2021lora,
234 | title = {LoRA: Low-Rank Adaptation of Large Language Models},
235 | author = {Hu, Edward and Shen, Yelong and Wallis, Phil and Allen-Zhu, Zeyuan and Li, Yuanzhi and Wang, Lu and Chen, Weizhu},
236 | year = {2021},
237 | eprint = {2106.09685},
238 | archivePrefix = {arXiv},
239 | primaryClass = {cs.CL}
240 | }
241 | ```
242 |
243 | ```bibtex
244 | @article{bolya2023tomesd,
245 | title = {Token Merging for Fast Stable Diffusion},
246 | author = {Bolya, Daniel and Hoffman, Judy},
247 | journal = {arXiv},
248 | url = {https://arxiv.org/abs/2303.17604},
249 | year = {2023}
250 | }
251 | ```
252 |
253 | ## Reference
254 | - [DreamBooth in diffusers](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth)
255 | - [DreamBooth in ShivamShrirao](https://github.com/ShivamShrirao/diffusers/tree/main/examples/dreambooth)
256 | - [Data from custom-diffusion](https://github.com/adobe-research/custom-diffusion#getting-started)
257 |
258 | ## TODO
259 | - [x] Training
260 | - [x] Inference
261 | - [x] Scaling spectral shifts
262 | - [x] Support Single Image Editing
263 | - [ ] Support multiple spectral shifts (Section 3.2)
264 | - [ ] Cut-Mix-Unmix (Section 3.3)
265 | - [ ] SVDiff + LoRA
266 |
--------------------------------------------------------------------------------
/assets/car-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/car-result.png
--------------------------------------------------------------------------------
/assets/chair-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/chair-result.png
--------------------------------------------------------------------------------
/assets/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/dog.png
--------------------------------------------------------------------------------
/assets/kasumi.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/kasumi.png
--------------------------------------------------------------------------------
/assets/kumamon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/kumamon.png
--------------------------------------------------------------------------------
/assets/lora.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/lora.png
--------------------------------------------------------------------------------
/assets/scale-result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/scale-result.png
--------------------------------------------------------------------------------
/assets/scale.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/scale.png
--------------------------------------------------------------------------------
/assets/svdiff.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/assets/svdiff.png
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from tqdm import tqdm
4 | import random
5 | import torch
6 | import huggingface_hub
7 | from transformers import CLIPTextModel
8 | from diffusers import StableDiffusionPipeline
9 | from diffusers.utils import is_xformers_available
10 | from svdiff_pytorch import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING, image_grid
11 |
12 |
13 | def parse_args():
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument("--pretrained_model_name_or_path", type=str, help="pretrained model name or path")
16 | parser.add_argument("--spectral_shifts_ckpt", type=str, help="path to spectral_shifts.safetensors")
17 | # diffusers config
18 | parser.add_argument("--prompt", type=str, nargs="?", default="a photo of *s", help="the prompt to render")
19 | parser.add_argument("--num_inference_steps", type=int, default=50, help="number of sampling steps")
20 | parser.add_argument("--guidance_scale", type=float, default=7.5, help="unconditional guidance scale")
21 | parser.add_argument("--num_images_per_prompt", type=int, default=1, help="number of images per prompt")
22 | parser.add_argument("--height", type=int, default=512, help="image height, in pixel space",)
23 | parser.add_argument("--width", type=int, default=512, help="image width, in pixel space",)
24 | parser.add_argument("--seed", type=str, default="random_seed", help="the seed (for reproducible sampling)")
25 | parser.add_argument("--scheduler_type", type=str, choices=["ddim", "plms", "lms", "euler", "euler_ancestral", "dpm_solver++"], default="ddim", help="diffusion scheduler type")
26 | parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
27 | parser.add_argument("--spectral_shifts_scale", type=float, default=1.0, help="scaling spectral shifts")
28 | parser.add_argument("--fp16", action="store_true", help="fp16 inference")
29 | args = parser.parse_args()
30 | return args
31 |
32 |
33 | def load_text_encoder(pretrained_model_name_or_path, spectral_shifts_ckpt, device, fp16=False):
34 | if os.path.isdir(spectral_shifts_ckpt):
35 | spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts_te.safetensors")
36 | elif not os.path.exists(spectral_shifts_ckpt):
37 | # download from hub
38 |         hf_hub_kwargs = {}  # no extra kwargs are passed to hf_hub_download here
39 | try:
40 | spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts_te.safetensors", **hf_hub_kwargs)
41 | except huggingface_hub.utils.EntryNotFoundError:
42 | return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device)
43 | if not os.path.exists(spectral_shifts_ckpt):
44 | return CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch.float16 if fp16 else None).to(device)
45 | text_encoder = load_text_encoder_for_svdiff(
46 | pretrained_model_name_or_path=pretrained_model_name_or_path,
47 | spectral_shifts_ckpt=spectral_shifts_ckpt,
48 | subfolder="text_encoder",
49 | )
50 | # first perform svd and cache
51 | for module in text_encoder.modules():
52 | if hasattr(module, "perform_svd"):
53 | module.perform_svd()
54 | if fp16:
55 | text_encoder = text_encoder.to(device, dtype=torch.float16)
56 | return text_encoder
57 |
58 |
59 |
60 | def main():
61 | args = parse_args()
62 | device = "cuda" if torch.cuda.is_available() else "cpu"
63 | print(f"device: {device}")
64 | # load unet
65 | unet = load_unet_for_svdiff(args.pretrained_model_name_or_path, spectral_shifts_ckpt=args.spectral_shifts_ckpt, subfolder="unet")
66 | unet = unet.to(device)
67 | # first perform svd and cache
68 | for module in unet.modules():
69 | if hasattr(module, "perform_svd"):
70 | module.perform_svd()
71 | if args.fp16:
72 | unet = unet.to(device, dtype=torch.float16)
73 | text_encoder = load_text_encoder(
74 | pretrained_model_name_or_path=args.pretrained_model_name_or_path,
75 | spectral_shifts_ckpt=args.spectral_shifts_ckpt,
76 | fp16=args.fp16,
77 | device=device
78 | )
79 |
80 | # load pipe
81 | pipe = StableDiffusionPipeline.from_pretrained(
82 | args.pretrained_model_name_or_path,
83 | unet=unet,
84 | text_encoder=text_encoder,
85 | requires_safety_checker=False,
86 | safety_checker=None,
87 | feature_extractor=None,
88 | scheduler=SCHEDULER_MAPPING[args.scheduler_type].from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler"),
89 | torch_dtype=torch.float16 if args.fp16 else None,
90 | )
91 | if args.enable_xformers_memory_efficient_attention:
92 | assert is_xformers_available()
93 | pipe.enable_xformers_memory_efficient_attention()
94 | print("Using xformers!")
95 | try:
96 | import tomesd
97 | tomesd.apply_patch(pipe, ratio=0.5)
98 | print("Using tomesd!")
99 |     except ImportError:
100 |         pass  # tomesd is optional; skip patching if it is not installed
101 | pipe = pipe.to(device)
102 | print("loaded pipeline")
103 | # run!
104 | if pipe.unet.conv_out.scale != args.spectral_shifts_scale:
105 | for module in pipe.unet.modules():
106 | if hasattr(module, "set_scale"):
107 | module.set_scale(scale=args.spectral_shifts_scale)
108 | if not isinstance(pipe.text_encoder, CLIPTextModel):
109 | for module in pipe.text_encoder.modules():
110 | if hasattr(module, "set_scale"):
111 | module.set_scale(scale=args.spectral_shifts_scale)
112 |
113 | print(f"Set spectral_shifts_scale to {args.spectral_shifts_scale}!")
114 |
115 | if args.seed == "random_seed":
116 | random.seed()
117 | seed = random.randint(0, 2**32)
118 | else:
119 | seed = int(args.seed)
120 | generator = torch.Generator(device=device).manual_seed(seed)
121 | print(f"seed: {seed}")
122 | prompts = args.prompt.split("::")
123 | all_images = []
124 | for prompt in tqdm(prompts):
125 | with torch.autocast(device), torch.inference_mode():
126 | images = pipe(
127 | prompt,
128 | num_inference_steps=args.num_inference_steps,
129 | guidance_scale=args.guidance_scale,
130 | generator=generator,
131 | num_images_per_prompt=args.num_images_per_prompt,
132 | height=args.height,
133 | width=args.width,
134 | ).images
135 | all_images.extend(images)
136 | grid_image = image_grid(all_images, len(prompts), args.num_images_per_prompt)
137 | grid_image.save("grid.png")
138 | print("DONE! See `grid.png` for the results!")
139 |
140 |
141 | if __name__ == '__main__':
142 | main()
143 |
144 |
--------------------------------------------------------------------------------
/paper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mkshing/svdiff-pytorch/a78f69e14410c1963318806050a566d262eca9f8/paper.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | diffusers==0.14.0
2 | accelerate
3 | torchvision
4 | safetensors
5 | transformers>=4.25.1, <=4.27.3
6 | ftfy
7 | tensorboard
8 | Jinja2
9 | einops
10 | wandb
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 |
4 | setup(
5 | name="svdiff-pytorch",
6 | version="0.2.0",
7 | author="Makoto Shing",
8 | url="https://github.com/mkshing/svdiff-pytorch",
9 | description="Implementation of 'SVDiff: Compact Parameter Space for Diffusion Fine-Tuning'",
10 | install_requires=[
11 | "diffusers==0.14.0",
12 | "accelerate",
13 | "torchvision",
14 | "safetensors",
15 | "transformers>=4.25.1",
16 | "ftfy",
17 | "tensorboard",
18 | "Jinja2",
19 | "einops",
20 | "wandb"
21 | ],
22 | packages=find_packages(exclude=("examples", "build")),
23 |     license="MIT",
24 | long_description=open("README.md", "r", encoding="utf-8").read(),
25 | long_description_content_type="text/markdown",
26 | )
--------------------------------------------------------------------------------
/svdiff_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from svdiff_pytorch.diffusers_models.unet_2d_condition import UNet2DConditionModel as UNet2DConditionModelForSVDiff
2 | from svdiff_pytorch.transformers_models_clip.modeling_clip import CLIPTextModel as CLIPTextModelForSVDiff
3 | from svdiff_pytorch.utils import load_unet_for_svdiff, load_text_encoder_for_svdiff, image_grid, SCHEDULER_MAPPING
4 | from svdiff_pytorch.pipeline_stable_diffusion_ddim_inversion import StableDiffusionPipelineWithDDIMInversion
5 |
--------------------------------------------------------------------------------
/svdiff_pytorch/diffusers_models/__init__.py:
--------------------------------------------------------------------------------
1 | # all files in this folder were taken from https://github.com/huggingface/diffusers/tree/main/src/diffusers/models
2 | # so, these files follow the LICENSE of diffusers
--------------------------------------------------------------------------------
/svdiff_pytorch/diffusers_models/attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 | from typing import Callable, Optional
16 |
17 | import torch
18 | import torch.nn.functional as F
19 | from torch import nn
20 |
21 | from diffusers.utils.import_utils import is_xformers_available
22 | from svdiff_pytorch.diffusers_models.cross_attention import CrossAttention
23 | from diffusers.models.embeddings import CombinedTimestepLabelEmbeddings
24 | from svdiff_pytorch.layers import SVDLinear, SVDGroupNorm, SVDLayerNorm
25 |
26 |
27 | if is_xformers_available():
28 | import xformers
29 | import xformers.ops
30 | else:
31 | xformers = None
32 |
33 |
34 | class AttentionBlock(nn.Module):
35 | """
36 | An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
37 | to the N-d case.
38 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
39 | Uses three q, k, v linear layers to compute attention.
40 |
41 | Parameters:
42 | channels (`int`): The number of channels in the input and output.
43 | num_head_channels (`int`, *optional*):
44 | The number of channels in each head. If None, then `num_heads` = 1.
45 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
46 | rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
47 | eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
48 | """
49 |
50 | # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
51 |
52 | def __init__(
53 | self,
54 | channels: int,
55 | num_head_channels: Optional[int] = None,
56 | norm_num_groups: int = 32,
57 | rescale_output_factor: float = 1.0,
58 | eps: float = 1e-5,
59 | ):
60 | super().__init__()
61 | self.channels = channels
62 |
63 | self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
64 | self.num_head_size = num_head_channels
65 | self.group_norm = SVDGroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
66 |
67 | # define q,k,v as linear layers
68 | self.query = SVDLinear(channels, channels)
69 | self.key = SVDLinear(channels, channels)
70 | self.value = SVDLinear(channels, channels)
71 |
72 | self.rescale_output_factor = rescale_output_factor
73 | self.proj_attn = SVDLinear(channels, channels, 1)
74 |
75 | self._use_memory_efficient_attention_xformers = False
76 | self._attention_op = None
77 |
78 | def reshape_heads_to_batch_dim(self, tensor):
79 | batch_size, seq_len, dim = tensor.shape
80 | head_size = self.num_heads
81 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
82 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
83 | return tensor
84 |
85 | def reshape_batch_dim_to_heads(self, tensor):
86 | batch_size, seq_len, dim = tensor.shape
87 | head_size = self.num_heads
88 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
89 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
90 | return tensor
91 |
92 | def set_use_memory_efficient_attention_xformers(
93 | self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
94 | ):
95 | if use_memory_efficient_attention_xformers:
96 | if not is_xformers_available():
97 | raise ModuleNotFoundError(
98 | (
99 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
100 | " xformers"
101 | ),
102 | name="xformers",
103 | )
104 | elif not torch.cuda.is_available():
105 | raise ValueError(
106 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
107 | " only available for GPU "
108 | )
109 | else:
110 | try:
111 | # Make sure we can run the memory efficient attention
112 | _ = xformers.ops.memory_efficient_attention(
113 | torch.randn((1, 2, 40), device="cuda"),
114 | torch.randn((1, 2, 40), device="cuda"),
115 | torch.randn((1, 2, 40), device="cuda"),
116 | )
117 | except Exception as e:
118 | raise e
119 | self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
120 | self._attention_op = attention_op
121 |
122 | def forward(self, hidden_states):
123 | residual = hidden_states
124 | batch, channel, height, width = hidden_states.shape
125 |
126 | # norm
127 | hidden_states = self.group_norm(hidden_states)
128 |
129 | hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
130 |
131 | # proj to q, k, v
132 | query_proj = self.query(hidden_states)
133 | key_proj = self.key(hidden_states)
134 | value_proj = self.value(hidden_states)
135 |
136 | scale = 1 / math.sqrt(self.channels / self.num_heads)
137 |
138 | query_proj = self.reshape_heads_to_batch_dim(query_proj)
139 | key_proj = self.reshape_heads_to_batch_dim(key_proj)
140 | value_proj = self.reshape_heads_to_batch_dim(value_proj)
141 |
142 | if self._use_memory_efficient_attention_xformers:
143 | # Memory efficient attention
144 | hidden_states = xformers.ops.memory_efficient_attention(
145 | query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
146 | )
147 | hidden_states = hidden_states.to(query_proj.dtype)
148 | else:
149 | attention_scores = torch.baddbmm(
150 | torch.empty(
151 | query_proj.shape[0],
152 | query_proj.shape[1],
153 | key_proj.shape[1],
154 | dtype=query_proj.dtype,
155 | device=query_proj.device,
156 | ),
157 | query_proj,
158 | key_proj.transpose(-1, -2),
159 | beta=0,
160 | alpha=scale,
161 | )
162 | attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
163 | hidden_states = torch.bmm(attention_probs, value_proj)
164 |
165 | # reshape hidden_states
166 | hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
167 |
168 | # compute next hidden_states
169 | hidden_states = self.proj_attn(hidden_states)
170 |
171 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
172 |
173 | # res connect and rescale
174 | hidden_states = (hidden_states + residual) / self.rescale_output_factor
175 | return hidden_states
176 |
177 |
178 | class BasicTransformerBlock(nn.Module):
179 | r"""
180 | A basic Transformer block.
181 |
182 | Parameters:
183 | dim (`int`): The number of channels in the input and output.
184 | num_attention_heads (`int`): The number of heads to use for multi-head attention.
185 | attention_head_dim (`int`): The number of channels in each head.
186 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
187 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
188 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
189 | num_embeds_ada_norm (:
190 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
191 | attention_bias (:
192 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
193 | """
194 |
195 | def __init__(
196 | self,
197 | dim: int,
198 | num_attention_heads: int,
199 | attention_head_dim: int,
200 | dropout=0.0,
201 | cross_attention_dim: Optional[int] = None,
202 | activation_fn: str = "geglu",
203 | num_embeds_ada_norm: Optional[int] = None,
204 | attention_bias: bool = False,
205 | only_cross_attention: bool = False,
206 | upcast_attention: bool = False,
207 | norm_elementwise_affine: bool = True,
208 | norm_type: str = "layer_norm",
209 | final_dropout: bool = False,
210 | ):
211 | super().__init__()
212 | self.only_cross_attention = only_cross_attention
213 |
214 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
215 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
216 |
217 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
218 | raise ValueError(
219 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
220 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
221 | )
222 |
223 | # 1. Self-Attn
224 | self.attn1 = CrossAttention(
225 | query_dim=dim,
226 | heads=num_attention_heads,
227 | dim_head=attention_head_dim,
228 | dropout=dropout,
229 | bias=attention_bias,
230 | cross_attention_dim=cross_attention_dim if only_cross_attention else None,
231 | upcast_attention=upcast_attention,
232 | )
233 |
234 | self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
235 |
236 | # 2. Cross-Attn
237 | if cross_attention_dim is not None:
238 | self.attn2 = CrossAttention(
239 | query_dim=dim,
240 | cross_attention_dim=cross_attention_dim,
241 | heads=num_attention_heads,
242 | dim_head=attention_head_dim,
243 | dropout=dropout,
244 | bias=attention_bias,
245 | upcast_attention=upcast_attention,
246 | ) # is self-attn if encoder_hidden_states is none
247 | else:
248 | self.attn2 = None
249 |
250 | if self.use_ada_layer_norm:
251 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
252 | elif self.use_ada_layer_norm_zero:
253 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
254 | else:
255 | self.norm1 = SVDLayerNorm(dim, elementwise_affine=norm_elementwise_affine)
256 |
257 | if cross_attention_dim is not None:
258 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
259 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
260 | # the second cross attention block.
261 | self.norm2 = (
262 | AdaLayerNorm(dim, num_embeds_ada_norm)
263 | if self.use_ada_layer_norm
264 | else SVDLayerNorm(dim, elementwise_affine=norm_elementwise_affine)
265 | )
266 | else:
267 | self.norm2 = None
268 |
269 | # 3. Feed-forward
270 | self.norm3 = SVDLayerNorm(dim, elementwise_affine=norm_elementwise_affine)
271 |
272 | def forward(
273 | self,
274 | hidden_states,
275 | encoder_hidden_states=None,
276 | timestep=None,
277 | attention_mask=None,
278 | cross_attention_kwargs=None,
279 | class_labels=None,
280 | ):
281 | if self.use_ada_layer_norm:
282 | norm_hidden_states = self.norm1(hidden_states, timestep)
283 | elif self.use_ada_layer_norm_zero:
284 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
285 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
286 | )
287 | else:
288 | norm_hidden_states = self.norm1(hidden_states)
289 |
290 | # 1. Self-Attention
291 | cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
292 | attn_output = self.attn1(
293 | norm_hidden_states,
294 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
295 | attention_mask=attention_mask,
296 | **cross_attention_kwargs,
297 | )
298 | if self.use_ada_layer_norm_zero:
299 | attn_output = gate_msa.unsqueeze(1) * attn_output
300 | hidden_states = attn_output + hidden_states
301 |
302 | if self.attn2 is not None:
303 | norm_hidden_states = (
304 | self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
305 | )
306 |
307 | # 2. Cross-Attention
308 | attn_output = self.attn2(
309 | norm_hidden_states,
310 | encoder_hidden_states=encoder_hidden_states,
311 | attention_mask=attention_mask,
312 | **cross_attention_kwargs,
313 | )
314 | hidden_states = attn_output + hidden_states
315 |
316 | # 3. Feed-forward
317 | norm_hidden_states = self.norm3(hidden_states)
318 |
319 | if self.use_ada_layer_norm_zero:
320 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
321 |
322 | ff_output = self.ff(norm_hidden_states)
323 |
324 | if self.use_ada_layer_norm_zero:
325 | ff_output = gate_mlp.unsqueeze(1) * ff_output
326 |
327 | hidden_states = ff_output + hidden_states
328 |
329 | return hidden_states
330 |
331 |
332 | class FeedForward(nn.Module):
333 | r"""
334 | A feed-forward layer.
335 |
336 | Parameters:
337 | dim (`int`): The number of channels in the input.
338 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
339 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
340 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
341 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
342 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
343 | """
344 |
345 | def __init__(
346 | self,
347 | dim: int,
348 | dim_out: Optional[int] = None,
349 | mult: int = 4,
350 | dropout: float = 0.0,
351 | activation_fn: str = "geglu",
352 | final_dropout: bool = False,
353 | ):
354 | super().__init__()
355 | inner_dim = int(dim * mult)
356 | dim_out = dim_out if dim_out is not None else dim
357 |
358 | if activation_fn == "gelu":
359 | act_fn = GELU(dim, inner_dim)
360 | if activation_fn == "gelu-approximate":
361 | act_fn = GELU(dim, inner_dim, approximate="tanh")
362 | elif activation_fn == "geglu":
363 | act_fn = GEGLU(dim, inner_dim)
364 | elif activation_fn == "geglu-approximate":
365 | act_fn = ApproximateGELU(dim, inner_dim)
366 |
367 | self.net = nn.ModuleList([])
368 | # project in
369 | self.net.append(act_fn)
370 | # project dropout
371 | self.net.append(nn.Dropout(dropout))
372 | # project out
373 | self.net.append(SVDLinear(inner_dim, dim_out))
374 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
375 | if final_dropout:
376 | self.net.append(nn.Dropout(dropout))
377 |
378 | def forward(self, hidden_states):
379 | for module in self.net:
380 | hidden_states = module(hidden_states)
381 | return hidden_states
382 |
383 |
384 | class GELU(nn.Module):
385 | r"""
386 | GELU activation function with tanh approximation support with `approximate="tanh"`.
387 | """
388 |
389 | def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
390 | super().__init__()
391 | self.proj = SVDLinear(dim_in, dim_out)
392 | self.approximate = approximate
393 |
394 | def gelu(self, gate):
395 | if gate.device.type != "mps":
396 | return F.gelu(gate, approximate=self.approximate)
397 | # mps: gelu is not implemented for float16
398 | return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
399 |
400 | def forward(self, hidden_states):
401 | hidden_states = self.proj(hidden_states)
402 | hidden_states = self.gelu(hidden_states)
403 | return hidden_states
404 |
405 |
406 | class GEGLU(nn.Module):
407 | r"""
408 | A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
409 |
410 | Parameters:
411 | dim_in (`int`): The number of channels in the input.
412 | dim_out (`int`): The number of channels in the output.
413 | """
414 |
415 | def __init__(self, dim_in: int, dim_out: int):
416 | super().__init__()
417 | self.proj = SVDLinear(dim_in, dim_out * 2)
418 |
419 | def gelu(self, gate):
420 | if gate.device.type != "mps":
421 | return F.gelu(gate)
422 | # mps: gelu is not implemented for float16
423 | return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
424 |
425 | def forward(self, hidden_states):
426 | hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
427 | return hidden_states * self.gelu(gate)
428 |
429 |
430 | class ApproximateGELU(nn.Module):
431 | """
432 | The approximate form of Gaussian Error Linear Unit (GELU)
433 |
434 | For more details, see section 2: https://arxiv.org/abs/1606.08415
435 | """
436 |
437 | def __init__(self, dim_in: int, dim_out: int):
438 | super().__init__()
439 | self.proj = SVDLinear(dim_in, dim_out)
440 |
441 | def forward(self, x):
442 | x = self.proj(x)
443 | return x * torch.sigmoid(1.702 * x)
444 |
445 |
446 | class AdaLayerNorm(nn.Module):
447 | """
448 | Norm layer modified to incorporate timestep embeddings.
449 | """
450 |
451 | def __init__(self, embedding_dim, num_embeddings):
452 | super().__init__()
453 | self.emb = nn.Embedding(num_embeddings, embedding_dim)
454 | self.silu = nn.SiLU()
455 | self.linear = SVDLinear(embedding_dim, embedding_dim * 2)
456 | self.norm = SVDLayerNorm(embedding_dim, elementwise_affine=False)
457 |
458 | def forward(self, x, timestep):
459 | emb = self.linear(self.silu(self.emb(timestep)))
460 | scale, shift = torch.chunk(emb, 2)
461 | x = self.norm(x) * (1 + scale) + shift
462 | return x
463 |
464 |
465 | class AdaLayerNormZero(nn.Module):
466 | """
467 | Norm layer adaptive layer norm zero (adaLN-Zero).
468 | """
469 |
470 | def __init__(self, embedding_dim, num_embeddings):
471 | super().__init__()
472 |
473 | self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
474 |
475 | self.silu = nn.SiLU()
476 | self.linear = SVDLinear(embedding_dim, 6 * embedding_dim, bias=True)
477 | self.norm = SVDLayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
478 |
479 | def forward(self, x, timestep, class_labels, hidden_dtype=None):
480 | emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
481 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
482 | x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
483 | return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
484 |
485 |
486 | class AdaGroupNorm(nn.Module):
487 | """
488 | GroupNorm layer modified to incorporate timestep embeddings.
489 | """
490 |
491 | def __init__(
492 | self, embedding_dim: int, out_dim: int, num_groups: int, act_fn: Optional[str] = None, eps: float = 1e-5
493 | ):
494 | super().__init__()
495 | self.num_groups = num_groups
496 | self.eps = eps
497 | self.act = None
498 | if act_fn == "swish":
499 | self.act = lambda x: F.silu(x)
500 | elif act_fn == "mish":
501 | self.act = nn.Mish()
502 | elif act_fn == "silu":
503 | self.act = nn.SiLU()
504 | elif act_fn == "gelu":
505 | self.act = nn.GELU()
506 |
507 | self.linear = SVDLinear(embedding_dim, out_dim * 2)
508 |
509 | def forward(self, x, emb):
510 | if self.act:
511 | emb = self.act(emb)
512 | emb = self.linear(emb)
513 | emb = emb[:, :, None, None]
514 | scale, shift = emb.chunk(2, dim=1)
515 |
516 | x = F.group_norm(x, self.num_groups, eps=self.eps)
517 | x = x * (1 + scale) + shift
518 | return x
519 |
--------------------------------------------------------------------------------
/svdiff_pytorch/diffusers_models/cross_attention.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Callable, Optional, Union
15 |
16 | import torch
17 | import torch.nn.functional as F
18 | from torch import nn
19 |
20 | from diffusers.utils import deprecate, logging
21 | from diffusers.utils.import_utils import is_xformers_available
22 | from svdiff_pytorch.layers import SVDLinear, SVDGroupNorm, SVDLayerNorm
23 |
24 |
25 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
26 |
27 |
28 | if is_xformers_available():
29 | import xformers
30 | import xformers.ops
31 | else:
32 | xformers = None
33 |
34 |
35 | class CrossAttention(nn.Module):
36 | r"""
37 | A cross attention layer.
38 |
39 | Parameters:
40 | query_dim (`int`): The number of channels in the query.
41 | cross_attention_dim (`int`, *optional*):
42 | The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
43 | heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
44 | dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
45 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
46 | bias (`bool`, *optional*, defaults to False):
47 | Set to `True` for the query, key, and value linear layers to contain a bias parameter.
48 | """
49 |
50 | def __init__(
51 | self,
52 | query_dim: int,
53 | cross_attention_dim: Optional[int] = None,
54 | heads: int = 8,
55 | dim_head: int = 64,
56 | dropout: float = 0.0,
57 | bias=False,
58 | upcast_attention: bool = False,
59 | upcast_softmax: bool = False,
60 | cross_attention_norm: bool = False,
61 | added_kv_proj_dim: Optional[int] = None,
62 | norm_num_groups: Optional[int] = None,
63 | processor: Optional["AttnProcessor"] = None,
64 | ):
65 | super().__init__()
66 | inner_dim = dim_head * heads
67 | cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
68 | self.upcast_attention = upcast_attention
69 | self.upcast_softmax = upcast_softmax
70 | self.cross_attention_norm = cross_attention_norm
71 |
72 | self.scale = dim_head**-0.5
73 |
74 | self.heads = heads
75 | # for slice_size > 0 the attention score computation
76 | # is split across the batch axis to save memory
77 | # You can set slice_size with `set_attention_slice`
78 | self.sliceable_head_dim = heads
79 |
80 | self.added_kv_proj_dim = added_kv_proj_dim
81 |
82 | if norm_num_groups is not None:
83 | self.group_norm = SVDGroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
84 | else:
85 | self.group_norm = None
86 |
87 | if cross_attention_norm:
88 | self.norm_cross = SVDLayerNorm(cross_attention_dim)
89 |
90 | self.to_q = SVDLinear(query_dim, inner_dim, bias=bias)
91 | self.to_k = SVDLinear(cross_attention_dim, inner_dim, bias=bias)
92 | self.to_v = SVDLinear(cross_attention_dim, inner_dim, bias=bias)
93 |
94 | if self.added_kv_proj_dim is not None:
95 | self.add_k_proj = SVDLinear(added_kv_proj_dim, cross_attention_dim)
96 | self.add_v_proj = SVDLinear(added_kv_proj_dim, cross_attention_dim)
97 |
98 | self.to_out = nn.ModuleList([])
99 | self.to_out.append(SVDLinear(inner_dim, query_dim))
100 | self.to_out.append(nn.Dropout(dropout))
101 |
102 | # set attention processor
103 | # We use the AttnProcessor2_0 by default when torch2.x is used which uses
104 | # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
105 | if processor is None:
106 | processor = AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else CrossAttnProcessor()
107 | self.set_processor(processor)
108 |
109 | def set_use_memory_efficient_attention_xformers(
110 | self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
111 | ):
112 | is_lora = hasattr(self, "processor") and isinstance(
113 | self.processor, (LoRACrossAttnProcessor, LoRAXFormersCrossAttnProcessor)
114 | )
115 |
116 | if use_memory_efficient_attention_xformers:
117 | if self.added_kv_proj_dim is not None:
118 | # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
119 | # which uses this type of cross attention ONLY because the attention mask of format
120 | # [0, ..., -10.000, ..., 0, ...,] is not supported
121 | raise NotImplementedError(
122 | "Memory efficient attention with `xformers` is currently not supported when"
123 | " `self.added_kv_proj_dim` is defined."
124 | )
125 | elif not is_xformers_available():
126 | raise ModuleNotFoundError(
127 | (
128 | "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
129 | " xformers"
130 | ),
131 | name="xformers",
132 | )
133 | elif not torch.cuda.is_available():
134 | raise ValueError(
135 | "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
136 | " only available for GPU "
137 | )
138 | else:
139 | try:
140 | # Make sure we can run the memory efficient attention
141 | _ = xformers.ops.memory_efficient_attention(
142 | torch.randn((1, 2, 40), device="cuda"),
143 | torch.randn((1, 2, 40), device="cuda"),
144 | torch.randn((1, 2, 40), device="cuda"),
145 | )
146 | except Exception as e:
147 | raise e
148 |
149 | if is_lora:
150 | processor = LoRAXFormersCrossAttnProcessor(
151 | hidden_size=self.processor.hidden_size,
152 | cross_attention_dim=self.processor.cross_attention_dim,
153 | rank=self.processor.rank,
154 | attention_op=attention_op,
155 | )
156 | processor.load_state_dict(self.processor.state_dict())
157 | processor.to(self.processor.to_q_lora.up.weight.device)
158 | else:
159 | processor = XFormersCrossAttnProcessor(attention_op=attention_op)
160 | else:
161 | if is_lora:
162 | processor = LoRACrossAttnProcessor(
163 | hidden_size=self.processor.hidden_size,
164 | cross_attention_dim=self.processor.cross_attention_dim,
165 | rank=self.processor.rank,
166 | )
167 | processor.load_state_dict(self.processor.state_dict())
168 | processor.to(self.processor.to_q_lora.up.weight.device)
169 | else:
170 | processor = CrossAttnProcessor()
171 |
172 | self.set_processor(processor)
173 |
174 | def set_attention_slice(self, slice_size):
175 | if slice_size is not None and slice_size > self.sliceable_head_dim:
176 | raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
177 |
178 | if slice_size is not None and self.added_kv_proj_dim is not None:
179 | processor = SlicedAttnAddedKVProcessor(slice_size)
180 | elif slice_size is not None:
181 | processor = SlicedAttnProcessor(slice_size)
182 | elif self.added_kv_proj_dim is not None:
183 | processor = CrossAttnAddedKVProcessor()
184 | else:
185 | processor = CrossAttnProcessor()
186 |
187 | self.set_processor(processor)
188 |
189 | def set_processor(self, processor: "AttnProcessor"):
190 | # if current processor is in `self._modules` and if passed `processor` is not, we need to
191 | # pop `processor` from `self._modules`
192 | if (
193 | hasattr(self, "processor")
194 | and isinstance(self.processor, torch.nn.Module)
195 | and not isinstance(processor, torch.nn.Module)
196 | ):
197 | logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
198 | self._modules.pop("processor")
199 |
200 | self.processor = processor
201 |
202 | def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
203 | # The `CrossAttention` class can call different attention processors / attention functions
204 | # here we simply pass along all tensors to the selected processor class
205 | # For standard processors that are defined here, `**cross_attention_kwargs` is empty
206 | return self.processor(
207 | self,
208 | hidden_states,
209 | encoder_hidden_states=encoder_hidden_states,
210 | attention_mask=attention_mask,
211 | **cross_attention_kwargs,
212 | )
213 |
214 | def batch_to_head_dim(self, tensor):
215 | head_size = self.heads
216 | batch_size, seq_len, dim = tensor.shape
217 | tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
218 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
219 | return tensor
220 |
221 | def head_to_batch_dim(self, tensor):
222 | head_size = self.heads
223 | batch_size, seq_len, dim = tensor.shape
224 | tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
225 | tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
226 | return tensor
227 |
228 | def get_attention_scores(self, query, key, attention_mask=None):
229 | dtype = query.dtype
230 | if self.upcast_attention:
231 | query = query.float()
232 | key = key.float()
233 |
234 | if attention_mask is None:
235 | baddbmm_input = torch.empty(
236 | query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device
237 | )
238 | beta = 0
239 | else:
240 | baddbmm_input = attention_mask
241 | beta = 1
242 |
243 | attention_scores = torch.baddbmm(
244 | baddbmm_input,
245 | query,
246 | key.transpose(-1, -2),
247 | beta=beta,
248 | alpha=self.scale,
249 | )
250 |
251 | if self.upcast_softmax:
252 | attention_scores = attention_scores.float()
253 |
254 | attention_probs = attention_scores.softmax(dim=-1)
255 | attention_probs = attention_probs.to(dtype)
256 |
257 | return attention_probs
258 |
259 | def prepare_attention_mask(self, attention_mask, target_length, batch_size=None):
260 | if batch_size is None:
261 | deprecate(
262 | "batch_size=None",
263 | "0.0.15",
264 | message=(
265 | "Not passing the `batch_size` parameter to `prepare_attention_mask` can lead to incorrect"
266 | " attention mask preparation and is deprecated behavior. Please make sure to pass `batch_size` to"
267 | " `prepare_attention_mask` when preparing the attention_mask."
268 | ),
269 | )
270 | batch_size = 1
271 |
272 | head_size = self.heads
273 | if attention_mask is None:
274 | return attention_mask
275 |
276 | if attention_mask.shape[-1] != target_length:
277 | if attention_mask.device.type == "mps":
278 | # HACK: MPS: Does not support padding by greater than dimension of input tensor.
279 | # Instead, we can manually construct the padding tensor.
280 | padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
281 | padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
282 | attention_mask = torch.cat([attention_mask, padding], dim=2)
283 | else:
284 | attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
285 |
286 | if attention_mask.shape[0] < batch_size * head_size:
287 | attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
288 | return attention_mask
289 |
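# Editor's note: the helper below is an illustrative sketch added for this document, not part
# of the original diffusers module. It shows, with plain tensors, the shape round trip that
# `head_to_batch_dim` / `batch_to_head_dim` above perform: heads are folded into the batch
# axis so `torch.baddbmm` / `torch.bmm` in `get_attention_scores` run one matmul per head.
def _head_batch_round_trip_demo(batch_size=2, seq_len=4, heads=8, head_dim=64):
    x = torch.randn(batch_size, seq_len, heads * head_dim)      # (B, S, H*D)
    folded = (
        x.reshape(batch_size, seq_len, heads, head_dim)
        .permute(0, 2, 1, 3)
        .reshape(batch_size * heads, seq_len, head_dim)          # (B*H, S, D)
    )
    restored = (
        folded.reshape(batch_size, heads, seq_len, head_dim)
        .permute(0, 2, 1, 3)
        .reshape(batch_size, seq_len, heads * head_dim)          # back to (B, S, H*D)
    )
    assert torch.equal(x, restored)
    return folded
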
290 |
291 | class CrossAttnProcessor:
292 | def __call__(
293 | self,
294 | attn: CrossAttention,
295 | hidden_states,
296 | encoder_hidden_states=None,
297 | attention_mask=None,
298 | ):
299 | batch_size, sequence_length, _ = hidden_states.shape
300 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
301 | query = attn.to_q(hidden_states)
302 |
303 | if encoder_hidden_states is None:
304 | encoder_hidden_states = hidden_states
305 | elif attn.cross_attention_norm:
306 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
307 |
308 | key = attn.to_k(encoder_hidden_states)
309 | value = attn.to_v(encoder_hidden_states)
310 |
311 | query = attn.head_to_batch_dim(query)
312 | key = attn.head_to_batch_dim(key)
313 | value = attn.head_to_batch_dim(value)
314 |
315 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
316 | hidden_states = torch.bmm(attention_probs, value)
317 | hidden_states = attn.batch_to_head_dim(hidden_states)
318 |
319 | # linear proj
320 | hidden_states = attn.to_out[0](hidden_states)
321 | # dropout
322 | hidden_states = attn.to_out[1](hidden_states)
323 |
324 | return hidden_states
325 |
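# Editor's note: the class below is an illustrative sketch added for this document, not part of
# the original diffusers module. Because `CrossAttention.forward` simply delegates to
# `self.processor(self, ...)`, any object with this `__call__` signature can be installed via
# `set_processor`; this one wraps the default processor and logs the incoming shape.
class _ShapeLoggingCrossAttnProcessor:
    def __init__(self):
        self._inner = CrossAttnProcessor()

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        logger.info(f"cross-attention input: {tuple(hidden_states.shape)}")
        return self._inner(attn, hidden_states, encoder_hidden_states, attention_mask)

# hypothetical usage: some_cross_attention_module.set_processor(_ShapeLoggingCrossAttnProcessor())
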
326 |
327 | class LoRALinearLayer(nn.Module):
328 | def __init__(self, in_features, out_features, rank=4):
329 | super().__init__()
330 |
331 | if rank > min(in_features, out_features):
332 |             raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")
333 |
334 | self.down = SVDLinear(in_features, rank, bias=False)
335 | self.up = SVDLinear(rank, out_features, bias=False)
336 |
337 | nn.init.normal_(self.down.weight, std=1 / rank)
338 | nn.init.zeros_(self.up.weight)
339 |
340 | def forward(self, hidden_states):
341 | orig_dtype = hidden_states.dtype
342 | dtype = self.down.weight.dtype
343 |
344 | down_hidden_states = self.down(hidden_states.to(dtype))
345 | up_hidden_states = self.up(down_hidden_states)
346 |
347 | return up_hidden_states.to(orig_dtype)
348 |
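# Editor's note (illustrative, not part of the original module): `LoRALinearLayer` parameterises
# a low-rank update y = up(down(x)) with down: in_features -> rank and up: rank -> out_features.
# Because `up` is zero-initialised, the update starts at exactly zero and only the frozen base
# projection is active at the beginning of training. In plain matrices, assuming in = out = 8
# and rank = 2:
#
#     down = torch.randn(2, 8) / 2      # std = 1 / rank
#     up = torch.zeros(8, 2)            # zero init -> no initial change
#     x = torch.randn(1, 8)
#     assert torch.equal(x @ down.T @ up.T, torch.zeros(1, 8))
#
# Note that in this repository the `down` / `up` projections are `SVDLinear` modules rather than
# plain `nn.Linear`, so SVDiff's spectral-shift parameters are available inside the LoRA layers too.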
349 |
350 | class LoRACrossAttnProcessor(nn.Module):
351 | def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
352 | super().__init__()
353 |
354 | self.hidden_size = hidden_size
355 | self.cross_attention_dim = cross_attention_dim
356 | self.rank = rank
357 |
358 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
359 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
360 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
361 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
362 |
363 | def __call__(
364 | self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
365 | ):
366 | batch_size, sequence_length, _ = hidden_states.shape
367 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
368 |
369 | query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
370 | query = attn.head_to_batch_dim(query)
371 |
372 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
373 |
374 | key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
375 | value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
376 |
377 | key = attn.head_to_batch_dim(key)
378 | value = attn.head_to_batch_dim(value)
379 |
380 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
381 | hidden_states = torch.bmm(attention_probs, value)
382 | hidden_states = attn.batch_to_head_dim(hidden_states)
383 |
384 | # linear proj
385 | hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
386 | # dropout
387 | hidden_states = attn.to_out[1](hidden_states)
388 |
389 | return hidden_states
390 |
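# Editor's note (illustrative, not part of the original module): the extra `scale` argument of the
# LoRA processors arrives through `CrossAttention.forward(**cross_attention_kwargs)`, so callers can
# blend the LoRA update in or out at inference time without touching any weights. Assuming a
# hypothetical diffusers pipeline `pipe` whose UNet uses `LoRACrossAttnProcessor` modules:
#
#     image = pipe(prompt, cross_attention_kwargs={"scale": 0.5}).images[0]   # half-strength LoRA
#     image = pipe(prompt, cross_attention_kwargs={"scale": 0.0}).images[0]   # LoRA disabled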
391 |
392 | class CrossAttnAddedKVProcessor:
393 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
394 | residual = hidden_states
395 | hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
396 | batch_size, sequence_length, _ = hidden_states.shape
397 | encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
398 |
399 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
400 |
401 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
402 |
403 | query = attn.to_q(hidden_states)
404 | query = attn.head_to_batch_dim(query)
405 |
406 | key = attn.to_k(hidden_states)
407 | value = attn.to_v(hidden_states)
408 | key = attn.head_to_batch_dim(key)
409 | value = attn.head_to_batch_dim(value)
410 |
411 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
412 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
413 | encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
414 | encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
415 |
416 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
417 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
418 |
419 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
420 | hidden_states = torch.bmm(attention_probs, value)
421 | hidden_states = attn.batch_to_head_dim(hidden_states)
422 |
423 | # linear proj
424 | hidden_states = attn.to_out[0](hidden_states)
425 | # dropout
426 | hidden_states = attn.to_out[1](hidden_states)
427 |
428 | hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
429 | hidden_states = hidden_states + residual
430 |
431 | return hidden_states
432 |
433 |
434 | class XFormersCrossAttnProcessor:
435 | def __init__(self, attention_op: Optional[Callable] = None):
436 | self.attention_op = attention_op
437 |
438 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
439 | batch_size, sequence_length, _ = hidden_states.shape
440 |
441 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
442 |
443 | query = attn.to_q(hidden_states)
444 |
445 | if encoder_hidden_states is None:
446 | encoder_hidden_states = hidden_states
447 | elif attn.cross_attention_norm:
448 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
449 |
450 | key = attn.to_k(encoder_hidden_states)
451 | value = attn.to_v(encoder_hidden_states)
452 |
453 | query = attn.head_to_batch_dim(query).contiguous()
454 | key = attn.head_to_batch_dim(key).contiguous()
455 | value = attn.head_to_batch_dim(value).contiguous()
456 |
457 | hidden_states = xformers.ops.memory_efficient_attention(
458 | query, key, value, attn_bias=attention_mask, op=self.attention_op
459 | )
460 | hidden_states = hidden_states.to(query.dtype)
461 | hidden_states = attn.batch_to_head_dim(hidden_states)
462 |
463 | # linear proj
464 | hidden_states = attn.to_out[0](hidden_states)
465 | # dropout
466 | hidden_states = attn.to_out[1](hidden_states)
467 | return hidden_states
468 |
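# Editor's note (illustrative, not part of the original module): this processor is installed by
# `set_use_memory_efficient_attention_xformers(True)` above and computes the same attention as
# `CrossAttnProcessor`, just through a fused xformers kernel. A minimal equivalence check with
# the same 3-D (batch*heads, seq_len, head_dim) layout, assuming xformers and a CUDA device:
#
#     q, k, v = (torch.randn(2, 16, 40, device="cuda") for _ in range(3))
#     ref = torch.softmax(q @ k.transpose(-1, -2) / 40 ** 0.5, dim=-1) @ v
#     out = xformers.ops.memory_efficient_attention(q, k, v)   # default scale is 1/sqrt(head_dim)
#     torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)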
469 |
470 | class AttnProcessor2_0:
471 | def __init__(self):
472 | if not hasattr(F, "scaled_dot_product_attention"):
473 |             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
474 |
475 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
476 | batch_size, sequence_length, inner_dim = hidden_states.shape
477 |
478 | if attention_mask is not None:
479 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
480 | # scaled_dot_product_attention expects attention_mask shape to be
481 | # (batch, heads, source_length, target_length)
482 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
483 |
484 | query = attn.to_q(hidden_states)
485 |
486 | if encoder_hidden_states is None:
487 | encoder_hidden_states = hidden_states
488 | elif attn.cross_attention_norm:
489 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
490 |
491 | key = attn.to_k(encoder_hidden_states)
492 | value = attn.to_v(encoder_hidden_states)
493 |
494 | head_dim = inner_dim // attn.heads
495 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
496 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
497 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
498 |
499 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
500 | hidden_states = F.scaled_dot_product_attention(
501 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
502 | )
503 |
504 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
505 | hidden_states = hidden_states.to(query.dtype)
506 |
507 | # linear proj
508 | hidden_states = attn.to_out[0](hidden_states)
509 | # dropout
510 | hidden_states = attn.to_out[1](hidden_states)
511 | return hidden_states
512 |
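# Editor's note (illustrative, not part of the original module): unlike the other processors, which
# fold heads into the batch axis, PyTorch 2.0's `F.scaled_dot_product_attention` expects
# (batch, heads, seq_len, head_dim) tensors, which is why `AttnProcessor2_0` uses
# `view(...).transpose(1, 2)` instead of `head_to_batch_dim`. A minimal call, assuming torch >= 2.0:
#
#     q = torch.randn(1, 8, 16, 40)   # (batch, heads, seq, head_dim)
#     k = torch.randn(1, 8, 16, 40)
#     v = torch.randn(1, 8, 16, 40)
#     out = F.scaled_dot_product_attention(q, k, v)   # -> (1, 8, 16, 40)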
513 |
514 | class LoRAXFormersCrossAttnProcessor(nn.Module):
515 | def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
516 | super().__init__()
517 |
518 | self.hidden_size = hidden_size
519 | self.cross_attention_dim = cross_attention_dim
520 | self.rank = rank
521 | self.attention_op = attention_op
522 |
523 | self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
524 | self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
525 | self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
526 | self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
527 |
528 | def __call__(
529 | self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
530 | ):
531 | batch_size, sequence_length, _ = hidden_states.shape
532 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
533 |
534 | query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
535 | query = attn.head_to_batch_dim(query).contiguous()
536 |
537 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
538 |
539 | key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
540 | value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
541 |
542 | key = attn.head_to_batch_dim(key).contiguous()
543 | value = attn.head_to_batch_dim(value).contiguous()
544 |
545 | hidden_states = xformers.ops.memory_efficient_attention(
546 | query, key, value, attn_bias=attention_mask, op=self.attention_op
547 | )
548 | hidden_states = attn.batch_to_head_dim(hidden_states)
549 |
550 | # linear proj
551 | hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
552 | # dropout
553 | hidden_states = attn.to_out[1](hidden_states)
554 |
555 | return hidden_states
556 |
557 |
558 | class SlicedAttnProcessor:
559 | def __init__(self, slice_size):
560 | self.slice_size = slice_size
561 |
562 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
563 | batch_size, sequence_length, _ = hidden_states.shape
564 |
565 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
566 |
567 | query = attn.to_q(hidden_states)
568 | dim = query.shape[-1]
569 | query = attn.head_to_batch_dim(query)
570 |
571 | if encoder_hidden_states is None:
572 | encoder_hidden_states = hidden_states
573 | elif attn.cross_attention_norm:
574 | encoder_hidden_states = attn.norm_cross(encoder_hidden_states)
575 |
576 | key = attn.to_k(encoder_hidden_states)
577 | value = attn.to_v(encoder_hidden_states)
578 | key = attn.head_to_batch_dim(key)
579 | value = attn.head_to_batch_dim(value)
580 |
581 | batch_size_attention = query.shape[0]
582 | hidden_states = torch.zeros(
583 | (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
584 | )
585 |
586 | for i in range(hidden_states.shape[0] // self.slice_size):
587 | start_idx = i * self.slice_size
588 | end_idx = (i + 1) * self.slice_size
589 |
590 | query_slice = query[start_idx:end_idx]
591 | key_slice = key[start_idx:end_idx]
592 | attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
593 |
594 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
595 |
596 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
597 |
598 | hidden_states[start_idx:end_idx] = attn_slice
599 |
600 | hidden_states = attn.batch_to_head_dim(hidden_states)
601 |
602 | # linear proj
603 | hidden_states = attn.to_out[0](hidden_states)
604 | # dropout
605 | hidden_states = attn.to_out[1](hidden_states)
606 |
607 | return hidden_states
608 |
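# Editor's note (illustrative, not part of the original module): slicing trades speed for memory by
# running attention over the folded (batch*heads) axis in chunks of `slice_size`, so only a
# (slice_size, seq_len, seq_len) score tensor is alive at any time instead of a
# (batch*heads, seq_len, seq_len) one. Note that the loop assumes batch*heads is divisible by
# `slice_size`. The processor is installed via `set_attention_slice` above, e.g. (hypothetical):
#
#     module.set_attention_slice(1)   # process one folded batch entry at a time: lowest memory, slowest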
609 |
610 | class SlicedAttnAddedKVProcessor:
611 | def __init__(self, slice_size):
612 | self.slice_size = slice_size
613 |
614 | def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
615 | residual = hidden_states
616 | hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
617 | encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
618 |
619 | batch_size, sequence_length, _ = hidden_states.shape
620 |
621 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
622 |
623 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
624 |
625 | query = attn.to_q(hidden_states)
626 | dim = query.shape[-1]
627 | query = attn.head_to_batch_dim(query)
628 |
629 | key = attn.to_k(hidden_states)
630 | value = attn.to_v(hidden_states)
631 | encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
632 | encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
633 |
634 | key = attn.head_to_batch_dim(key)
635 | value = attn.head_to_batch_dim(value)
636 | encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
637 | encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)
638 |
639 | key = torch.cat([encoder_hidden_states_key_proj, key], dim=1)
640 | value = torch.cat([encoder_hidden_states_value_proj, value], dim=1)
641 |
642 | batch_size_attention = query.shape[0]
643 | hidden_states = torch.zeros(
644 | (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
645 | )
646 |
647 | for i in range(hidden_states.shape[0] // self.slice_size):
648 | start_idx = i * self.slice_size
649 | end_idx = (i + 1) * self.slice_size
650 |
651 | query_slice = query[start_idx:end_idx]
652 | key_slice = key[start_idx:end_idx]
653 | attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None
654 |
655 | attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
656 |
657 | attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
658 |
659 | hidden_states[start_idx:end_idx] = attn_slice
660 |
661 | hidden_states = attn.batch_to_head_dim(hidden_states)
662 |
663 | # linear proj
664 | hidden_states = attn.to_out[0](hidden_states)
665 | # dropout
666 | hidden_states = attn.to_out[1](hidden_states)
667 |
668 | hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
669 | hidden_states = hidden_states + residual
670 |
671 | return hidden_states
672 |
673 |
674 | AttnProcessor = Union[
675 | CrossAttnProcessor,
676 | XFormersCrossAttnProcessor,
677 | SlicedAttnProcessor,
678 | CrossAttnAddedKVProcessor,
679 | SlicedAttnAddedKVProcessor,
680 | LoRACrossAttnProcessor,
681 | LoRAXFormersCrossAttnProcessor,
682 | ]
683 |
--------------------------------------------------------------------------------
/svdiff_pytorch/diffusers_models/dual_transformer_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from typing import Optional
15 |
16 | from torch import nn
17 |
18 | from svdiff_pytorch.diffusers_models.transformer_2d import Transformer2DModel, Transformer2DModelOutput
19 |
20 |
21 | class DualTransformer2DModel(nn.Module):
22 | """
23 | Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
24 |
25 | Parameters:
26 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
27 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
28 | in_channels (`int`, *optional*):
29 | Pass if the input is continuous. The number of channels in the input and output.
30 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
31 |         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
32 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
33 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
34 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See
35 | `ImagePositionalEmbeddings`.
36 | num_vector_embeds (`int`, *optional*):
37 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
38 | Includes the class for the masked latent pixel.
39 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
40 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
41 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used
42 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
43 |             up to but not more than `num_embeds_ada_norm` steps.
44 | attention_bias (`bool`, *optional*):
45 | Configure if the TransformerBlocks' attention should contain a bias parameter.
46 | """
47 |
48 | def __init__(
49 | self,
50 | num_attention_heads: int = 16,
51 | attention_head_dim: int = 88,
52 | in_channels: Optional[int] = None,
53 | num_layers: int = 1,
54 | dropout: float = 0.0,
55 | norm_num_groups: int = 32,
56 | cross_attention_dim: Optional[int] = None,
57 | attention_bias: bool = False,
58 | sample_size: Optional[int] = None,
59 | num_vector_embeds: Optional[int] = None,
60 | activation_fn: str = "geglu",
61 | num_embeds_ada_norm: Optional[int] = None,
62 | ):
63 | super().__init__()
64 | self.transformers = nn.ModuleList(
65 | [
66 | Transformer2DModel(
67 | num_attention_heads=num_attention_heads,
68 | attention_head_dim=attention_head_dim,
69 | in_channels=in_channels,
70 | num_layers=num_layers,
71 | dropout=dropout,
72 | norm_num_groups=norm_num_groups,
73 | cross_attention_dim=cross_attention_dim,
74 | attention_bias=attention_bias,
75 | sample_size=sample_size,
76 | num_vector_embeds=num_vector_embeds,
77 | activation_fn=activation_fn,
78 | num_embeds_ada_norm=num_embeds_ada_norm,
79 | )
80 | for _ in range(2)
81 | ]
82 | )
83 |
84 | # Variables that can be set by a pipeline:
85 |
86 | # The ratio of transformer1 to transformer2's output states to be combined during inference
87 | self.mix_ratio = 0.5
88 |
89 | # The shape of `encoder_hidden_states` is expected to be
90 | # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
91 | self.condition_lengths = [77, 257]
92 |
93 | # Which transformer to use to encode which condition.
94 | # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
95 | self.transformer_index_for_condition = [1, 0]
96 |
97 | def forward(
98 | self,
99 | hidden_states,
100 | encoder_hidden_states,
101 | timestep=None,
102 | attention_mask=None,
103 | cross_attention_kwargs=None,
104 | return_dict: bool = True,
105 | ):
106 | """
107 | Args:
108 |             hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` when discrete,
109 |                 `torch.FloatTensor` of shape `(batch size, channel, height, width)` when continuous): Input
110 |                 hidden_states
111 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
112 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
113 | self-attention.
114 | timestep ( `torch.long`, *optional*):
115 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
116 | attention_mask (`torch.FloatTensor`, *optional*):
117 | Optional attention mask to be applied in CrossAttention
118 | return_dict (`bool`, *optional*, defaults to `True`):
119 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
120 |
121 | Returns:
122 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
123 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
124 | returning a tuple, the first element is the sample tensor.
125 | """
126 | input_states = hidden_states
127 |
128 | encoded_states = []
129 | tokens_start = 0
130 | # attention_mask is not used yet
131 | for i in range(2):
132 | # for each of the two transformers, pass the corresponding condition tokens
133 | condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
134 | transformer_index = self.transformer_index_for_condition[i]
135 | encoded_state = self.transformers[transformer_index](
136 | input_states,
137 | encoder_hidden_states=condition_state,
138 | timestep=timestep,
139 | cross_attention_kwargs=cross_attention_kwargs,
140 | return_dict=False,
141 | )[0]
142 | encoded_states.append(encoded_state - input_states)
143 | tokens_start += self.condition_lengths[i]
144 |
145 | output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
146 | output_states = output_states + input_states
147 |
148 | if not return_dict:
149 | return (output_states,)
150 |
151 | return Transformer2DModelOutput(sample=output_states)
152 |
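# Editor's note (illustrative, not part of the original module): with the defaults above, the first
# 77 condition tokens go to `transformers[1]` and the next 257 to `transformers[0]`, and their
# residuals are blended as
#
#     output = input + mix_ratio * (transformers[1](input, cond0) - input)
#                    + (1 - mix_ratio) * (transformers[0](input, cond1) - input)
#
# so `mix_ratio = 1.0` uses only the first condition stream and `mix_ratio = 0.0` only the second.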
--------------------------------------------------------------------------------
/svdiff_pytorch/diffusers_models/embeddings.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import math
15 | from typing import Optional
16 |
17 | import numpy as np
18 | import torch
19 | from torch import nn
20 | from svdiff_pytorch.layers import SVDLinear, SVDLayerNorm
21 |
22 |
23 | def get_timestep_embedding(
24 | timesteps: torch.Tensor,
25 | embedding_dim: int,
26 | flip_sin_to_cos: bool = False,
27 | downscale_freq_shift: float = 1,
28 | scale: float = 1,
29 | max_period: int = 10000,
30 | ):
31 | """
32 | This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
33 |
34 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
35 | These may be fractional.
36 | :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
37 | embeddings. :return: an [N x dim] Tensor of positional embeddings.
38 | """
39 | assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
40 |
41 | half_dim = embedding_dim // 2
42 | exponent = -math.log(max_period) * torch.arange(
43 | start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
44 | )
45 | exponent = exponent / (half_dim - downscale_freq_shift)
46 |
47 | emb = torch.exp(exponent)
48 | emb = timesteps[:, None].float() * emb[None, :]
49 |
50 | # scale embeddings
51 | emb = scale * emb
52 |
53 | # concat sine and cosine embeddings
54 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
55 |
56 | # flip sine and cosine embeddings
57 | if flip_sin_to_cos:
58 | emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
59 |
60 | # zero pad
61 | if embedding_dim % 2 == 1:
62 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
63 | return emb
64 |
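# Editor's note (illustrative, not part of the original module): a minimal call, assuming an even
# `embedding_dim`; the result concatenates the sine and cosine halves along the last axis.
#
#     t = torch.arange(4)                                   # 4 timesteps
#     emb = get_timestep_embedding(t, embedding_dim=128)    # -> shape (4, 128)
#     assert emb.shape == (4, 128)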
65 |
66 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
67 | """
68 | grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
69 | [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
70 | """
71 | grid_h = np.arange(grid_size, dtype=np.float32)
72 | grid_w = np.arange(grid_size, dtype=np.float32)
73 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
74 | grid = np.stack(grid, axis=0)
75 |
76 | grid = grid.reshape([2, 1, grid_size, grid_size])
77 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
78 | if cls_token and extra_tokens > 0:
79 | pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
80 | return pos_embed
81 |
82 |
83 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
84 | if embed_dim % 2 != 0:
85 | raise ValueError("embed_dim must be divisible by 2")
86 |
87 | # use half of dimensions to encode grid_h
88 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
89 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
90 |
91 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
92 | return emb
93 |
94 |
95 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
96 | """
97 | embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
98 | """
99 | if embed_dim % 2 != 0:
100 | raise ValueError("embed_dim must be divisible by 2")
101 |
102 | omega = np.arange(embed_dim // 2, dtype=np.float64)
103 | omega /= embed_dim / 2.0
104 | omega = 1.0 / 10000**omega # (D/2,)
105 |
106 | pos = pos.reshape(-1) # (M,)
107 | out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
108 |
109 | emb_sin = np.sin(out) # (M, D/2)
110 | emb_cos = np.cos(out) # (M, D/2)
111 |
112 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
113 | return emb
114 |
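# Editor's note (illustrative, not part of the original module): for a 16x16 patch grid and a
# 768-dim embedding, the helpers above produce one fixed sin/cos vector per grid position (as a
# numpy array, which `PatchEmbed` below registers as a non-persistent buffer):
#
#     pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)
#     assert pos.shape == (16 * 16, 768)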
115 |
116 | class PatchEmbed(nn.Module):
117 | """2D Image to Patch Embedding"""
118 |
119 | def __init__(
120 | self,
121 | height=224,
122 | width=224,
123 | patch_size=16,
124 | in_channels=3,
125 | embed_dim=768,
126 | layer_norm=False,
127 | flatten=True,
128 | bias=True,
129 | ):
130 | super().__init__()
131 |
132 | num_patches = (height // patch_size) * (width // patch_size)
133 | self.flatten = flatten
134 | self.layer_norm = layer_norm
135 |
136 | self.proj = nn.Conv2d(
137 | in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
138 | )
139 | if layer_norm:
140 | self.norm = SVDLayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
141 | else:
142 | self.norm = None
143 |
144 | pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
145 | self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
146 |
147 | def forward(self, latent):
148 | latent = self.proj(latent)
149 | if self.flatten:
150 | latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
151 | if self.layer_norm:
152 | latent = self.norm(latent)
153 | return latent + self.pos_embed
154 |
155 |
156 | class TimestepEmbedding(nn.Module):
157 | def __init__(
158 | self,
159 | in_channels: int,
160 | time_embed_dim: int,
161 | act_fn: str = "silu",
162 | out_dim: int = None,
163 | post_act_fn: Optional[str] = None,
164 | cond_proj_dim=None,
165 | ):
166 | super().__init__()
167 |
168 | self.linear_1 = SVDLinear(in_channels, time_embed_dim)
169 |
170 | if cond_proj_dim is not None:
171 | self.cond_proj = SVDLinear(cond_proj_dim, in_channels, bias=False)
172 | else:
173 | self.cond_proj = None
174 |
175 | if act_fn == "silu":
176 | self.act = nn.SiLU()
177 | elif act_fn == "mish":
178 | self.act = nn.Mish()
179 | elif act_fn == "gelu":
180 | self.act = nn.GELU()
181 | else:
182 | raise ValueError(f"{act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
183 |
184 | if out_dim is not None:
185 | time_embed_dim_out = out_dim
186 | else:
187 | time_embed_dim_out = time_embed_dim
188 | self.linear_2 = SVDLinear(time_embed_dim, time_embed_dim_out)
189 |
190 | if post_act_fn is None:
191 | self.post_act = None
192 | elif post_act_fn == "silu":
193 | self.post_act = nn.SiLU()
194 | elif post_act_fn == "mish":
195 | self.post_act = nn.Mish()
196 | elif post_act_fn == "gelu":
197 | self.post_act = nn.GELU()
198 | else:
199 | raise ValueError(f"{post_act_fn} does not exist. Make sure to define one of 'silu', 'mish', or 'gelu'")
200 |
201 | def forward(self, sample, condition=None):
202 | if condition is not None:
203 | sample = sample + self.cond_proj(condition)
204 | sample = self.linear_1(sample)
205 |
206 | if self.act is not None:
207 | sample = self.act(sample)
208 |
209 | sample = self.linear_2(sample)
210 |
211 | if self.post_act is not None:
212 | sample = self.post_act(sample)
213 | return sample
214 |
215 |
216 | class Timesteps(nn.Module):
217 | def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
218 | super().__init__()
219 | self.num_channels = num_channels
220 | self.flip_sin_to_cos = flip_sin_to_cos
221 | self.downscale_freq_shift = downscale_freq_shift
222 |
223 | def forward(self, timesteps):
224 | t_emb = get_timestep_embedding(
225 | timesteps,
226 | self.num_channels,
227 | flip_sin_to_cos=self.flip_sin_to_cos,
228 | downscale_freq_shift=self.downscale_freq_shift,
229 | )
230 | return t_emb
231 |
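# Editor's note (illustrative, not part of the original module): `Timesteps` produces fixed
# sinusoidal features and `TimestepEmbedding` projects them through trainable `SVDLinear` layers;
# the two are chained this way in `CombinedTimestepLabelEmbeddings` below and in the UNet's time
# embedding. The sizes used here are only an assumption for the sketch:
#
#     time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
#     time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)
#     t_emb = time_embedding(time_proj(torch.tensor([10, 500])))   # -> shape (2, 1280)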
232 |
233 | class GaussianFourierProjection(nn.Module):
234 | """Gaussian Fourier embeddings for noise levels."""
235 |
236 | def __init__(
237 | self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
238 | ):
239 | super().__init__()
240 | self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
241 | self.log = log
242 | self.flip_sin_to_cos = flip_sin_to_cos
243 |
244 | if set_W_to_weight:
245 | # to delete later
246 | self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
247 |
248 | self.weight = self.W
249 |
250 | def forward(self, x):
251 | if self.log:
252 | x = torch.log(x)
253 |
254 | x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
255 |
256 | if self.flip_sin_to_cos:
257 | out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
258 | else:
259 | out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
260 | return out
261 |
262 |
263 | class ImagePositionalEmbeddings(nn.Module):
264 | """
265 | Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
266 | height and width of the latent space.
267 |
268 | For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
269 |
270 | For VQ-diffusion:
271 |
272 | Output vector embeddings are used as input for the transformer.
273 |
274 | Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
275 |
276 | Args:
277 | num_embed (`int`):
278 | Number of embeddings for the latent pixels embeddings.
279 | height (`int`):
280 | Height of the latent image i.e. the number of height embeddings.
281 | width (`int`):
282 | Width of the latent image i.e. the number of width embeddings.
283 | embed_dim (`int`):
284 | Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
285 | """
286 |
287 | def __init__(
288 | self,
289 | num_embed: int,
290 | height: int,
291 | width: int,
292 | embed_dim: int,
293 | ):
294 | super().__init__()
295 |
296 | self.height = height
297 | self.width = width
298 | self.num_embed = num_embed
299 | self.embed_dim = embed_dim
300 |
301 | self.emb = nn.Embedding(self.num_embed, embed_dim)
302 | self.height_emb = nn.Embedding(self.height, embed_dim)
303 | self.width_emb = nn.Embedding(self.width, embed_dim)
304 |
305 | def forward(self, index):
306 | emb = self.emb(index)
307 |
308 | height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
309 |
310 | # 1 x H x D -> 1 x H x 1 x D
311 | height_emb = height_emb.unsqueeze(2)
312 |
313 | width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
314 |
315 | # 1 x W x D -> 1 x 1 x W x D
316 | width_emb = width_emb.unsqueeze(1)
317 |
318 | pos_emb = height_emb + width_emb
319 |
320 |         # 1 x H x W x D -> 1 x L x D
321 | pos_emb = pos_emb.view(1, self.height * self.width, -1)
322 |
323 | emb = emb + pos_emb[:, : emb.shape[1], :]
324 |
325 | return emb
326 |
327 |
328 | class LabelEmbedding(nn.Module):
329 | """
330 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
331 |
332 | Args:
333 | num_classes (`int`): The number of classes.
334 | hidden_size (`int`): The size of the vector embeddings.
335 | dropout_prob (`float`): The probability of dropping a label.
336 | """
337 |
338 | def __init__(self, num_classes, hidden_size, dropout_prob):
339 | super().__init__()
340 | use_cfg_embedding = dropout_prob > 0
341 | self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
342 | self.num_classes = num_classes
343 | self.dropout_prob = dropout_prob
344 |
345 | def token_drop(self, labels, force_drop_ids=None):
346 | """
347 | Drops labels to enable classifier-free guidance.
348 | """
349 | if force_drop_ids is None:
350 | drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
351 | else:
352 | drop_ids = torch.tensor(force_drop_ids == 1)
353 | labels = torch.where(drop_ids, self.num_classes, labels)
354 | return labels
355 |
356 | def forward(self, labels, force_drop_ids=None):
357 | use_dropout = self.dropout_prob > 0
358 | if (self.training and use_dropout) or (force_drop_ids is not None):
359 | labels = self.token_drop(labels, force_drop_ids)
360 | embeddings = self.embedding_table(labels)
361 | return embeddings
362 |
363 |
364 | class CombinedTimestepLabelEmbeddings(nn.Module):
365 | def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
366 | super().__init__()
367 |
368 | self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
369 | self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
370 | self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
371 |
372 | def forward(self, timestep, class_labels, hidden_dtype=None):
373 | timesteps_proj = self.time_proj(timestep)
374 | timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
375 |
376 | class_labels = self.class_embedder(class_labels) # (N, D)
377 |
378 | conditioning = timesteps_emb + class_labels # (N, D)
379 |
380 | return conditioning
381 |
--------------------------------------------------------------------------------
/svdiff_pytorch/diffusers_models/resnet.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Optional
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from svdiff_pytorch.diffusers_models.attention import AdaGroupNorm
9 | from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDGroupNorm, SVDLayerNorm
10 |
11 |
12 | class Upsample1D(nn.Module):
13 | """
14 | An upsampling layer with an optional convolution.
15 |
16 | Parameters:
17 | channels: channels in the inputs and outputs.
18 | use_conv: a bool determining if a convolution is applied.
19 | use_conv_transpose:
20 | out_channels:
21 | """
22 |
23 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
24 | super().__init__()
25 | self.channels = channels
26 | self.out_channels = out_channels or channels
27 | self.use_conv = use_conv
28 | self.use_conv_transpose = use_conv_transpose
29 | self.name = name
30 |
31 | self.conv = None
32 | if use_conv_transpose:
33 | self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
34 | elif use_conv:
35 | self.conv = SVDConv1d(self.channels, self.out_channels, 3, padding=1)
36 |
37 | def forward(self, x):
38 | assert x.shape[1] == self.channels
39 | if self.use_conv_transpose:
40 | return self.conv(x)
41 |
42 | x = F.interpolate(x, scale_factor=2.0, mode="nearest")
43 |
44 | if self.use_conv:
45 | x = self.conv(x)
46 |
47 | return x
48 |
49 |
50 | class Downsample1D(nn.Module):
51 | """
52 | A downsampling layer with an optional convolution.
53 |
54 | Parameters:
55 | channels: channels in the inputs and outputs.
56 | use_conv: a bool determining if a convolution is applied.
57 | out_channels:
58 | padding:
59 | """
60 |
61 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
62 | super().__init__()
63 | self.channels = channels
64 | self.out_channels = out_channels or channels
65 | self.use_conv = use_conv
66 | self.padding = padding
67 | stride = 2
68 | self.name = name
69 |
70 | if use_conv:
71 | self.conv = SVDConv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
72 | else:
73 | assert self.channels == self.out_channels
74 | self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
75 |
76 | def forward(self, x):
77 | assert x.shape[1] == self.channels
78 | return self.conv(x)
79 |
80 |
81 | class Upsample2D(nn.Module):
82 | """
83 | An upsampling layer with an optional convolution.
84 |
85 | Parameters:
86 | channels: channels in the inputs and outputs.
87 | use_conv: a bool determining if a convolution is applied.
88 | use_conv_transpose:
89 | out_channels:
90 | """
91 |
92 | def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
93 | super().__init__()
94 | self.channels = channels
95 | self.out_channels = out_channels or channels
96 | self.use_conv = use_conv
97 | self.use_conv_transpose = use_conv_transpose
98 | self.name = name
99 |
100 | conv = None
101 | if use_conv_transpose:
102 | conv = nn.ConvTranspose2d(channels, self.out_channels, 4, 2, 1)
103 | elif use_conv:
104 | conv = SVDConv2d(self.channels, self.out_channels, 3, padding=1)
105 |
106 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
107 | if name == "conv":
108 | self.conv = conv
109 | else:
110 | self.Conv2d_0 = conv
111 |
112 | def forward(self, hidden_states, output_size=None):
113 | assert hidden_states.shape[1] == self.channels
114 |
115 | if self.use_conv_transpose:
116 | return self.conv(hidden_states)
117 |
118 |         # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
119 | # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch
120 | # https://github.com/pytorch/pytorch/issues/86679
121 | dtype = hidden_states.dtype
122 | if dtype == torch.bfloat16:
123 | hidden_states = hidden_states.to(torch.float32)
124 |
125 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
126 | if hidden_states.shape[0] >= 64:
127 | hidden_states = hidden_states.contiguous()
128 |
129 | # if `output_size` is passed we force the interpolation output
130 | # size and do not make use of `scale_factor=2`
131 | if output_size is None:
132 | hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
133 | else:
134 | hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
135 |
136 | # If the input is bfloat16, we cast back to bfloat16
137 | if dtype == torch.bfloat16:
138 | hidden_states = hidden_states.to(dtype)
139 |
140 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
141 | if self.use_conv:
142 | if self.name == "conv":
143 | hidden_states = self.conv(hidden_states)
144 | else:
145 | hidden_states = self.Conv2d_0(hidden_states)
146 |
147 | return hidden_states
148 |
149 |
150 | class Downsample2D(nn.Module):
151 | """
152 | A downsampling layer with an optional convolution.
153 |
154 | Parameters:
155 | channels: channels in the inputs and outputs.
156 | use_conv: a bool determining if a convolution is applied.
157 | out_channels:
158 | padding:
159 | """
160 |
161 | def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
162 | super().__init__()
163 | self.channels = channels
164 | self.out_channels = out_channels or channels
165 | self.use_conv = use_conv
166 | self.padding = padding
167 | stride = 2
168 | self.name = name
169 |
170 | if use_conv:
171 | conv = SVDConv2d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
172 | else:
173 | assert self.channels == self.out_channels
174 | conv = nn.AvgPool2d(kernel_size=stride, stride=stride)
175 |
176 | # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
177 | if name == "conv":
178 | self.Conv2d_0 = conv
179 | self.conv = conv
180 | elif name == "Conv2d_0":
181 | self.conv = conv
182 | else:
183 | self.conv = conv
184 |
185 | def forward(self, hidden_states):
186 | assert hidden_states.shape[1] == self.channels
187 | if self.use_conv and self.padding == 0:
188 | pad = (0, 1, 0, 1)
189 | hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)
190 |
191 | assert hidden_states.shape[1] == self.channels
192 | hidden_states = self.conv(hidden_states)
193 |
194 | return hidden_states
195 |
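# Editor's note (illustrative, not part of the original module): with the default 3x3 convolutions
# these blocks simply double or halve the spatial resolution while keeping the channel count:
#
#     x = torch.randn(1, 64, 32, 32)
#     assert Downsample2D(64, use_conv=True)(x).shape == (1, 64, 16, 16)
#     assert Upsample2D(64, use_conv=True)(x).shape == (1, 64, 64, 64)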
196 |
197 | class FirUpsample2D(nn.Module):
198 | def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
199 | super().__init__()
200 | out_channels = out_channels if out_channels else channels
201 | if use_conv:
202 | self.Conv2d_0 = SVDConv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
203 | self.use_conv = use_conv
204 | self.fir_kernel = fir_kernel
205 | self.out_channels = out_channels
206 |
207 | def _upsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
208 | """Fused `upsample_2d()` followed by `Conv2d()`.
209 |
210 | Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
211 |         efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
212 | arbitrary order.
213 |
214 | Args:
215 | hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
216 | weight: Weight tensor of the shape `[filterH, filterW, inChannels,
217 | outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
218 | kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
219 | (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
220 | factor: Integer upsampling factor (default: 2).
221 | gain: Scaling factor for signal magnitude (default: 1.0).
222 |
223 | Returns:
224 | output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
225 | datatype as `hidden_states`.
226 | """
227 |
228 | assert isinstance(factor, int) and factor >= 1
229 |
230 | # Setup filter kernel.
231 | if kernel is None:
232 | kernel = [1] * factor
233 |
234 | # setup kernel
235 | kernel = torch.tensor(kernel, dtype=torch.float32)
236 | if kernel.ndim == 1:
237 | kernel = torch.outer(kernel, kernel)
238 | kernel /= torch.sum(kernel)
239 |
240 | kernel = kernel * (gain * (factor**2))
241 |
242 | if self.use_conv:
243 | convH = weight.shape[2]
244 | convW = weight.shape[3]
245 | inC = weight.shape[1]
246 |
247 | pad_value = (kernel.shape[0] - factor) - (convW - 1)
248 |
249 | stride = (factor, factor)
250 | # Determine data dimensions.
251 | output_shape = (
252 | (hidden_states.shape[2] - 1) * factor + convH,
253 | (hidden_states.shape[3] - 1) * factor + convW,
254 | )
255 | output_padding = (
256 | output_shape[0] - (hidden_states.shape[2] - 1) * stride[0] - convH,
257 | output_shape[1] - (hidden_states.shape[3] - 1) * stride[1] - convW,
258 | )
259 | assert output_padding[0] >= 0 and output_padding[1] >= 0
260 | num_groups = hidden_states.shape[1] // inC
261 |
262 | # Transpose weights.
263 | weight = torch.reshape(weight, (num_groups, -1, inC, convH, convW))
264 | weight = torch.flip(weight, dims=[3, 4]).permute(0, 2, 1, 3, 4)
265 | weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
266 |
267 | inverse_conv = F.conv_transpose2d(
268 | hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
269 | )
270 |
271 | output = upfirdn2d_native(
272 | inverse_conv,
273 | torch.tensor(kernel, device=inverse_conv.device),
274 | pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
275 | )
276 | else:
277 | pad_value = kernel.shape[0] - factor
278 | output = upfirdn2d_native(
279 | hidden_states,
280 | torch.tensor(kernel, device=hidden_states.device),
281 | up=factor,
282 | pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
283 | )
284 |
285 | return output
286 |
287 | def forward(self, hidden_states):
288 | if self.use_conv:
289 | height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
290 | height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
291 | else:
292 | height = self._upsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
293 |
294 | return height
295 |
296 |
297 | class FirDownsample2D(nn.Module):
298 | def __init__(self, channels=None, out_channels=None, use_conv=False, fir_kernel=(1, 3, 3, 1)):
299 | super().__init__()
300 | out_channels = out_channels if out_channels else channels
301 | if use_conv:
302 | self.Conv2d_0 = SVDConv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
303 | self.fir_kernel = fir_kernel
304 | self.use_conv = use_conv
305 | self.out_channels = out_channels
306 |
307 | def _downsample_2d(self, hidden_states, weight=None, kernel=None, factor=2, gain=1):
308 | """Fused `Conv2d()` followed by `downsample_2d()`.
309 | Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
310 |         efficient than performing the same calculation using standard PyTorch ops. It supports gradients of
311 | arbitrary order.
312 |
313 | Args:
314 | hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
315 | weight:
316 | Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
317 | performed by `inChannels = x.shape[0] // numGroups`.
318 | kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
319 | factor`, which corresponds to average pooling.
320 | factor: Integer downsampling factor (default: 2).
321 | gain: Scaling factor for signal magnitude (default: 1.0).
322 |
323 | Returns:
324 | output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
325 | same datatype as `x`.
326 | """
327 |
328 | assert isinstance(factor, int) and factor >= 1
329 | if kernel is None:
330 | kernel = [1] * factor
331 |
332 | # setup kernel
333 | kernel = torch.tensor(kernel, dtype=torch.float32)
334 | if kernel.ndim == 1:
335 | kernel = torch.outer(kernel, kernel)
336 | kernel /= torch.sum(kernel)
337 |
338 | kernel = kernel * gain
339 |
340 | if self.use_conv:
341 | _, _, convH, convW = weight.shape
342 | pad_value = (kernel.shape[0] - factor) + (convW - 1)
343 | stride_value = [factor, factor]
344 | upfirdn_input = upfirdn2d_native(
345 | hidden_states,
346 | torch.tensor(kernel, device=hidden_states.device),
347 | pad=((pad_value + 1) // 2, pad_value // 2),
348 | )
349 | output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
350 | else:
351 | pad_value = kernel.shape[0] - factor
352 | output = upfirdn2d_native(
353 | hidden_states,
354 | torch.tensor(kernel, device=hidden_states.device),
355 | down=factor,
356 | pad=((pad_value + 1) // 2, pad_value // 2),
357 | )
358 |
359 | return output
360 |
361 | def forward(self, hidden_states):
362 | if self.use_conv:
363 | downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
364 | hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
365 | else:
366 | hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)
367 |
368 | return hidden_states
369 |
370 |
371 | # downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
372 | class KDownsample2D(nn.Module):
373 | def __init__(self, pad_mode="reflect"):
374 | super().__init__()
375 | self.pad_mode = pad_mode
376 | kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
377 | self.pad = kernel_1d.shape[1] // 2 - 1
378 | self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
379 |
380 | def forward(self, x):
381 | x = F.pad(x, (self.pad,) * 4, self.pad_mode)
382 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
383 | indices = torch.arange(x.shape[1], device=x.device)
384 | weight[indices, indices] = self.kernel.to(weight)
385 | return F.conv2d(x, weight, stride=2)
386 |
387 |
388 | class KUpsample2D(nn.Module):
389 | def __init__(self, pad_mode="reflect"):
390 | super().__init__()
391 | self.pad_mode = pad_mode
392 | kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]]) * 2
393 | self.pad = kernel_1d.shape[1] // 2 - 1
394 | self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)
395 |
396 | def forward(self, x):
397 | x = F.pad(x, ((self.pad + 1) // 2,) * 4, self.pad_mode)
398 | weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
399 | indices = torch.arange(x.shape[1], device=x.device)
400 | weight[indices, indices] = self.kernel.to(weight)
401 | return F.conv_transpose2d(x, weight, stride=2, padding=self.pad * 2 + 1)
402 |
403 |
404 | class ResnetBlock2D(nn.Module):
405 | r"""
406 | A Resnet block.
407 |
408 | Parameters:
409 | in_channels (`int`): The number of channels in the input.
410 | out_channels (`int`, *optional*, default to be `None`):
411 | The number of output channels for the first conv2d layer. If None, same as `in_channels`.
412 | dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
413 | temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
414 | groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
415 | groups_out (`int`, *optional*, default to None):
416 | The number of groups to use for the second normalization layer. if set to None, same as `groups`.
417 | eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
418 | non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
419 | time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
420 | By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
421 | "ada_group" for a stronger conditioning with scale and shift.
422 |         kernel (`torch.FloatTensor`, *optional*, default to None): FIR filter, see
423 | [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
424 | output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
425 | use_in_shortcut (`bool`, *optional*, default to `True`):
426 | If `True`, add a 1x1 nn.conv2d layer for skip-connection.
427 | up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
428 | down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
429 | conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
430 | `conv_shortcut` output.
431 | conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
432 | If None, same as `out_channels`.
433 | """
434 |
435 | def __init__(
436 | self,
437 | *,
438 | in_channels,
439 | out_channels=None,
440 | conv_shortcut=False,
441 | dropout=0.0,
442 | temb_channels=512,
443 | groups=32,
444 | groups_out=None,
445 | pre_norm=True,
446 | eps=1e-6,
447 | non_linearity="swish",
448 | time_embedding_norm="default", # default, scale_shift, ada_group
449 | kernel=None,
450 | output_scale_factor=1.0,
451 | use_in_shortcut=None,
452 | up=False,
453 | down=False,
454 | conv_shortcut_bias: bool = True,
455 | conv_2d_out_channels: Optional[int] = None,
456 | ):
457 | super().__init__()
458 | self.pre_norm = pre_norm
459 | self.pre_norm = True
460 | self.in_channels = in_channels
461 | out_channels = in_channels if out_channels is None else out_channels
462 | self.out_channels = out_channels
463 | self.use_conv_shortcut = conv_shortcut
464 | self.up = up
465 | self.down = down
466 | self.output_scale_factor = output_scale_factor
467 | self.time_embedding_norm = time_embedding_norm
468 |
469 | if groups_out is None:
470 | groups_out = groups
471 |
472 | if self.time_embedding_norm == "ada_group":
473 | self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
474 | else:
475 | self.norm1 = SVDGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
476 |
477 | self.conv1 = SVDConv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
478 |
479 | if temb_channels is not None:
480 | if self.time_embedding_norm == "default":
481 | self.time_emb_proj = SVDLinear(temb_channels, out_channels)
482 | elif self.time_embedding_norm == "scale_shift":
483 | self.time_emb_proj = SVDLinear(temb_channels, 2 * out_channels)
484 | elif self.time_embedding_norm == "ada_group":
485 | self.time_emb_proj = None
486 | else:
487 | raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
488 | else:
489 | self.time_emb_proj = None
490 |
491 | if self.time_embedding_norm == "ada_group":
492 | self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
493 | else:
494 | self.norm2 = SVDGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
495 |
496 | self.dropout = torch.nn.Dropout(dropout)
497 | conv_2d_out_channels = conv_2d_out_channels or out_channels
498 | self.conv2 = SVDConv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
499 |
500 | if non_linearity == "swish":
501 | self.nonlinearity = lambda x: F.silu(x)
502 | elif non_linearity == "mish":
503 | self.nonlinearity = nn.Mish()
504 | elif non_linearity == "silu":
505 | self.nonlinearity = nn.SiLU()
506 | elif non_linearity == "gelu":
507 | self.nonlinearity = nn.GELU()
508 |
509 | self.upsample = self.downsample = None
510 | if self.up:
511 | if kernel == "fir":
512 | fir_kernel = (1, 3, 3, 1)
513 | self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
514 | elif kernel == "sde_vp":
515 | self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
516 | else:
517 | self.upsample = Upsample2D(in_channels, use_conv=False)
518 | elif self.down:
519 | if kernel == "fir":
520 | fir_kernel = (1, 3, 3, 1)
521 | self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
522 | elif kernel == "sde_vp":
523 | self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
524 | else:
525 | self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
526 |
527 | self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
528 |
529 | self.conv_shortcut = None
530 | if self.use_in_shortcut:
531 | self.conv_shortcut = SVDConv2d(
532 | in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
533 | )
534 |
535 | def forward(self, input_tensor, temb):
536 | hidden_states = input_tensor
537 |
538 | if self.time_embedding_norm == "ada_group":
539 | hidden_states = self.norm1(hidden_states, temb)
540 | else:
541 | hidden_states = self.norm1(hidden_states)
542 |
543 | hidden_states = self.nonlinearity(hidden_states)
544 |
545 | if self.upsample is not None:
546 | # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
547 | if hidden_states.shape[0] >= 64:
548 | input_tensor = input_tensor.contiguous()
549 | hidden_states = hidden_states.contiguous()
550 | input_tensor = self.upsample(input_tensor)
551 | hidden_states = self.upsample(hidden_states)
552 | elif self.downsample is not None:
553 | input_tensor = self.downsample(input_tensor)
554 | hidden_states = self.downsample(hidden_states)
555 |
556 | hidden_states = self.conv1(hidden_states)
557 |
558 | if self.time_emb_proj is not None:
559 | temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
560 |
561 | if temb is not None and self.time_embedding_norm == "default":
562 | hidden_states = hidden_states + temb
563 |
564 | if self.time_embedding_norm == "ada_group":
565 | hidden_states = self.norm2(hidden_states, temb)
566 | else:
567 | hidden_states = self.norm2(hidden_states)
568 |
569 | if temb is not None and self.time_embedding_norm == "scale_shift":
570 | scale, shift = torch.chunk(temb, 2, dim=1)
571 | hidden_states = hidden_states * (1 + scale) + shift
572 |
573 | hidden_states = self.nonlinearity(hidden_states)
574 |
575 | hidden_states = self.dropout(hidden_states)
576 | hidden_states = self.conv2(hidden_states)
577 |
578 | if self.conv_shortcut is not None:
579 | input_tensor = self.conv_shortcut(input_tensor)
580 |
581 | output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
582 |
583 | return output_tensor
584 |
585 |
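A minimal smoke test of the block above (an illustrative sketch, not part of the file; the channel and embedding sizes are arbitrary, and the import path assumes the package layout shown in this repository). With the default `time_embedding_norm="default"`, the projected timestep embedding is simply added to the hidden states right after `conv1`:

import torch
from svdiff_pytorch.diffusers_models.resnet import ResnetBlock2D

block = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=512)  # 64 channels keeps the default 32 groups valid
x = torch.randn(2, 64, 16, 16)   # (batch, channels, height, width)
temb = torch.randn(2, 512)       # timestep embedding, projected by the SVDLinear inside the block
out = block(x, temb)
print(out.shape)                 # torch.Size([2, 64, 16, 16])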
586 | class Mish(torch.nn.Module):
587 | def forward(self, hidden_states):
588 | return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
589 |
590 |
591 | # unet_rl.py
592 | def rearrange_dims(tensor):
593 | if len(tensor.shape) == 2:
594 | return tensor[:, :, None]
595 | if len(tensor.shape) == 3:
596 | return tensor[:, :, None, :]
597 | elif len(tensor.shape) == 4:
598 | return tensor[:, :, 0, :]
599 | else:
600 |         raise ValueError(f"`len(tensor.shape)`: {len(tensor.shape)} has to be 2, 3 or 4.")
601 |
602 |
603 | class Conv1dBlock(nn.Module):
604 | """
605 | Conv1d --> GroupNorm --> Mish
606 | """
607 |
608 | def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
609 | super().__init__()
610 |
611 | self.conv1d = SVDConv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
612 | self.group_norm = SVDGroupNorm(n_groups, out_channels)
613 | self.mish = nn.Mish()
614 |
615 | def forward(self, x):
616 | x = self.conv1d(x)
617 | x = rearrange_dims(x)
618 | x = self.group_norm(x)
619 | x = rearrange_dims(x)
620 | x = self.mish(x)
621 | return x
622 |
623 |
624 | # unet_rl.py
625 | class ResidualTemporalBlock1D(nn.Module):
626 | def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
627 | super().__init__()
628 | self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
629 | self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
630 |
631 | self.time_emb_act = nn.Mish()
632 | self.time_emb = SVDLinear(embed_dim, out_channels)
633 |
634 | self.residual_conv = (
635 | SVDConv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
636 | )
637 |
638 | def forward(self, x, t):
639 | """
640 | Args:
641 | x : [ batch_size x inp_channels x horizon ]
642 | t : [ batch_size x embed_dim ]
643 |
644 | returns:
645 | out : [ batch_size x out_channels x horizon ]
646 | """
647 | t = self.time_emb_act(t)
648 | t = self.time_emb(t)
649 | out = self.conv_in(x) + rearrange_dims(t)
650 | out = self.conv_out(out)
651 | return out + self.residual_conv(x)
652 |
653 |
654 | def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
655 |     r"""Upsamples a batch of 2D images with the given filter.
656 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
657 | filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
658 | `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its shape is
659 |     a multiple of the upsampling factor.
660 |
661 | Args:
662 | hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
663 | kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
664 | (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
665 | factor: Integer upsampling factor (default: 2).
666 | gain: Scaling factor for signal magnitude (default: 1.0).
667 |
668 | Returns:
669 | output: Tensor of the shape `[N, C, H * factor, W * factor]`
670 | """
671 | assert isinstance(factor, int) and factor >= 1
672 | if kernel is None:
673 | kernel = [1] * factor
674 |
675 | kernel = torch.tensor(kernel, dtype=torch.float32)
676 | if kernel.ndim == 1:
677 | kernel = torch.outer(kernel, kernel)
678 | kernel /= torch.sum(kernel)
679 |
680 | kernel = kernel * (gain * (factor**2))
681 | pad_value = kernel.shape[0] - factor
682 | output = upfirdn2d_native(
683 | hidden_states,
684 | kernel.to(device=hidden_states.device),
685 | up=factor,
686 | pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
687 | )
688 | return output
689 |
690 |
691 | def downsample_2d(hidden_states, kernel=None, factor=2, gain=1):
692 |     r"""Downsamples a batch of 2D images with the given filter.
693 | Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
694 | given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
695 | specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
696 | shape is a multiple of the downsampling factor.
697 |
698 | Args:
699 | hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
700 | kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
701 | (separable). The default is `[1] * factor`, which corresponds to average pooling.
702 | factor: Integer downsampling factor (default: 2).
703 | gain: Scaling factor for signal magnitude (default: 1.0).
704 |
705 | Returns:
706 | output: Tensor of the shape `[N, C, H // factor, W // factor]`
707 | """
708 |
709 | assert isinstance(factor, int) and factor >= 1
710 | if kernel is None:
711 | kernel = [1] * factor
712 |
713 | kernel = torch.tensor(kernel, dtype=torch.float32)
714 | if kernel.ndim == 1:
715 | kernel = torch.outer(kernel, kernel)
716 | kernel /= torch.sum(kernel)
717 |
718 | kernel = kernel * gain
719 | pad_value = kernel.shape[0] - factor
720 | output = upfirdn2d_native(
721 | hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
722 | )
723 | return output
724 |
725 |
726 | def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
727 | up_x = up_y = up
728 | down_x = down_y = down
729 | pad_x0 = pad_y0 = pad[0]
730 | pad_x1 = pad_y1 = pad[1]
731 |
732 | _, channel, in_h, in_w = tensor.shape
733 | tensor = tensor.reshape(-1, in_h, in_w, 1)
734 |
735 | _, in_h, in_w, minor = tensor.shape
736 | kernel_h, kernel_w = kernel.shape
737 |
738 | out = tensor.view(-1, in_h, 1, in_w, 1, minor)
739 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
740 | out = out.view(-1, in_h * up_y, in_w * up_x, minor)
741 |
742 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
743 | out = out.to(tensor.device) # Move back to mps if necessary
744 | out = out[
745 | :,
746 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
747 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
748 | :,
749 | ]
750 |
751 | out = out.permute(0, 3, 1, 2)
752 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
753 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
754 | out = F.conv2d(out, w)
755 | out = out.reshape(
756 | -1,
757 | minor,
758 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
759 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
760 | )
761 | out = out.permute(0, 2, 3, 1)
762 | out = out[:, ::down_y, ::down_x, :]
763 |
764 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
765 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
766 |
767 | return out.view(-1, channel, out_h, out_w)
768 |
--------------------------------------------------------------------------------
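As a quick sanity check of the FIR helpers at the end of the file above (an illustrative sketch; `upsample_2d` and `downsample_2d` are the module-level functions shown there): with the default kernel `[1] * factor` they reduce to nearest-neighbor upsampling and 2x2 average pooling, and a constant image stays constant because the kernel is normalized before `gain` is applied.

import torch
from svdiff_pytorch.diffusers_models.resnet import upsample_2d, downsample_2d

x = torch.ones(1, 3, 8, 8)
up = upsample_2d(x, factor=2)      # nearest-neighbor: (1, 3, 16, 16), still all ones
down = downsample_2d(x, factor=2)  # average pooling:  (1, 3, 4, 4), still all ones
assert up.shape == (1, 3, 16, 16) and torch.allclose(up, torch.ones_like(up))
assert down.shape == (1, 3, 4, 4) and torch.allclose(down, torch.ones_like(down))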
/svdiff_pytorch/diffusers_models/transformer_2d.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import Optional
16 |
17 | import torch
18 | import torch.nn.functional as F
19 | from torch import nn
20 |
21 | from diffusers.configuration_utils import ConfigMixin, register_to_config
22 | from diffusers.models.embeddings import ImagePositionalEmbeddings
23 | from diffusers.utils import BaseOutput, deprecate
24 | from svdiff_pytorch.diffusers_models.attention import BasicTransformerBlock
25 | from diffusers.models.embeddings import PatchEmbed
26 | from diffusers.models.modeling_utils import ModelMixin
27 | from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDGroupNorm, SVDLayerNorm
28 |
29 |
30 | @dataclass
31 | class Transformer2DModelOutput(BaseOutput):
32 | """
33 | Args:
34 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
35 | Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
36 | for the unnoised latent pixels.
37 | """
38 |
39 | sample: torch.FloatTensor
40 |
41 |
42 | class Transformer2DModel(ModelMixin, ConfigMixin):
43 | """
44 | Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
45 | embeddings) inputs.
46 |
47 | When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
48 | transformer action. Finally, reshape to image.
49 |
50 | When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
51 | embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
52 | classes of unnoised image.
53 |
54 | Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
55 | image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
56 |
57 | Parameters:
58 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
59 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
60 | in_channels (`int`, *optional*):
61 | Pass if the input is continuous. The number of channels in the input and output.
62 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
63 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
64 | cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
65 | sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
66 | Note that this is fixed at training time as it is used for learning a number of position embeddings. See
67 | `ImagePositionalEmbeddings`.
68 | num_vector_embeds (`int`, *optional*):
69 | Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
70 | Includes the class for the masked latent pixel.
71 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
72 | num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
73 | The number of diffusion steps used during training. Note that this is fixed at training time as it is used
74 | to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
75 |             up to, but not more than, `num_embeds_ada_norm` steps.
76 | attention_bias (`bool`, *optional*):
77 | Configure if the TransformerBlocks' attention should contain a bias parameter.
78 | """
79 |
80 | @register_to_config
81 | def __init__(
82 | self,
83 | num_attention_heads: int = 16,
84 | attention_head_dim: int = 88,
85 | in_channels: Optional[int] = None,
86 | out_channels: Optional[int] = None,
87 | num_layers: int = 1,
88 | dropout: float = 0.0,
89 | norm_num_groups: int = 32,
90 | cross_attention_dim: Optional[int] = None,
91 | attention_bias: bool = False,
92 | sample_size: Optional[int] = None,
93 | num_vector_embeds: Optional[int] = None,
94 | patch_size: Optional[int] = None,
95 | activation_fn: str = "geglu",
96 | num_embeds_ada_norm: Optional[int] = None,
97 | use_linear_projection: bool = False,
98 | only_cross_attention: bool = False,
99 | upcast_attention: bool = False,
100 | norm_type: str = "layer_norm",
101 | norm_elementwise_affine: bool = True,
102 | ):
103 | super().__init__()
104 | self.use_linear_projection = use_linear_projection
105 | self.num_attention_heads = num_attention_heads
106 | self.attention_head_dim = attention_head_dim
107 | inner_dim = num_attention_heads * attention_head_dim
108 |
109 |         # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` and quantized image embeddings of shape `(batch_size, num_image_vectors)`
110 | # Define whether input is continuous or discrete depending on configuration
111 | self.is_input_continuous = (in_channels is not None) and (patch_size is None)
112 | self.is_input_vectorized = num_vector_embeds is not None
113 | self.is_input_patches = in_channels is not None and patch_size is not None
114 |
115 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
116 | deprecation_message = (
117 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
118 |                 " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
119 |                 " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
120 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
121 | " would be very nice if you could open a Pull request for the `transformer/config.json` file"
122 | )
123 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
124 | norm_type = "ada_norm"
125 |
126 | if self.is_input_continuous and self.is_input_vectorized:
127 | raise ValueError(
128 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
129 | " sure that either `in_channels` or `num_vector_embeds` is None."
130 | )
131 | elif self.is_input_vectorized and self.is_input_patches:
132 | raise ValueError(
133 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
134 | " sure that either `num_vector_embeds` or `num_patches` is None."
135 | )
136 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
137 | raise ValueError(
138 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
139 | f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
140 | )
141 |
142 | # 2. Define input layers
143 | if self.is_input_continuous:
144 | self.in_channels = in_channels
145 |
146 | self.norm = SVDGroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
147 | if use_linear_projection:
148 | self.proj_in = SVDLinear(in_channels, inner_dim)
149 | else:
150 | self.proj_in = SVDConv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
151 | elif self.is_input_vectorized:
152 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
153 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
154 |
155 | self.height = sample_size
156 | self.width = sample_size
157 | self.num_vector_embeds = num_vector_embeds
158 | self.num_latent_pixels = self.height * self.width
159 |
160 | self.latent_image_embedding = ImagePositionalEmbeddings(
161 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
162 | )
163 | elif self.is_input_patches:
164 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
165 |
166 | self.height = sample_size
167 | self.width = sample_size
168 |
169 | self.patch_size = patch_size
170 | self.pos_embed = PatchEmbed(
171 | height=sample_size,
172 | width=sample_size,
173 | patch_size=patch_size,
174 | in_channels=in_channels,
175 | embed_dim=inner_dim,
176 | )
177 |
178 | # 3. Define transformers blocks
179 | self.transformer_blocks = nn.ModuleList(
180 | [
181 | BasicTransformerBlock(
182 | inner_dim,
183 | num_attention_heads,
184 | attention_head_dim,
185 | dropout=dropout,
186 | cross_attention_dim=cross_attention_dim,
187 | activation_fn=activation_fn,
188 | num_embeds_ada_norm=num_embeds_ada_norm,
189 | attention_bias=attention_bias,
190 | only_cross_attention=only_cross_attention,
191 | upcast_attention=upcast_attention,
192 | norm_type=norm_type,
193 | norm_elementwise_affine=norm_elementwise_affine,
194 | )
195 | for d in range(num_layers)
196 | ]
197 | )
198 |
199 | # 4. Define output layers
200 | self.out_channels = in_channels if out_channels is None else out_channels
201 | if self.is_input_continuous:
202 |             # TODO: should use out_channels for continuous projections
203 | if use_linear_projection:
204 | self.proj_out = SVDLinear(inner_dim, in_channels)
205 | else:
206 | self.proj_out = SVDConv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
207 | elif self.is_input_vectorized:
208 | self.norm_out = SVDLayerNorm(inner_dim)
209 | self.out = SVDLinear(inner_dim, self.num_vector_embeds - 1)
210 | elif self.is_input_patches:
211 | self.norm_out = SVDLayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
212 | self.proj_out_1 = SVDLinear(inner_dim, 2 * inner_dim)
213 | self.proj_out_2 = SVDLinear(inner_dim, patch_size * patch_size * self.out_channels)
214 |
215 | def forward(
216 | self,
217 | hidden_states,
218 | encoder_hidden_states=None,
219 | timestep=None,
220 | class_labels=None,
221 | cross_attention_kwargs=None,
222 | return_dict: bool = True,
223 | ):
224 | """
225 | Args:
226 |             hidden_states (When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
227 |                 When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
228 | hidden_states
229 | encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
230 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
231 | self-attention.
232 | timestep ( `torch.long`, *optional*):
233 | Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
234 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
235 | Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
236 | conditioning.
237 | return_dict (`bool`, *optional*, defaults to `True`):
238 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
239 |
240 | Returns:
241 | [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
242 | [`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
243 | returning a tuple, the first element is the sample tensor.
244 | """
245 | # 1. Input
246 | if self.is_input_continuous:
247 | batch, _, height, width = hidden_states.shape
248 | residual = hidden_states
249 |
250 | hidden_states = self.norm(hidden_states)
251 | if not self.use_linear_projection:
252 | hidden_states = self.proj_in(hidden_states)
253 | inner_dim = hidden_states.shape[1]
254 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
255 | else:
256 | inner_dim = hidden_states.shape[1]
257 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
258 | hidden_states = self.proj_in(hidden_states)
259 | elif self.is_input_vectorized:
260 | hidden_states = self.latent_image_embedding(hidden_states)
261 | elif self.is_input_patches:
262 | hidden_states = self.pos_embed(hidden_states)
263 |
264 | # 2. Blocks
265 | for block in self.transformer_blocks:
266 | hidden_states = block(
267 | hidden_states,
268 | encoder_hidden_states=encoder_hidden_states,
269 | timestep=timestep,
270 | cross_attention_kwargs=cross_attention_kwargs,
271 | class_labels=class_labels,
272 | )
273 |
274 | # 3. Output
275 | if self.is_input_continuous:
276 | if not self.use_linear_projection:
277 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
278 | hidden_states = self.proj_out(hidden_states)
279 | else:
280 | hidden_states = self.proj_out(hidden_states)
281 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
282 |
283 | output = hidden_states + residual
284 | elif self.is_input_vectorized:
285 | hidden_states = self.norm_out(hidden_states)
286 | logits = self.out(hidden_states)
287 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
288 | logits = logits.permute(0, 2, 1)
289 |
290 | # log(p(x_0))
291 | output = F.log_softmax(logits.double(), dim=1).float()
292 | elif self.is_input_patches:
293 | # TODO: cleanup!
294 | conditioning = self.transformer_blocks[0].norm1.emb(
295 | timestep, class_labels, hidden_dtype=hidden_states.dtype
296 | )
297 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
298 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
299 | hidden_states = self.proj_out_2(hidden_states)
300 |
301 | # unpatchify
302 | height = width = int(hidden_states.shape[1] ** 0.5)
303 | hidden_states = hidden_states.reshape(
304 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
305 | )
306 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
307 | output = hidden_states.reshape(
308 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
309 | )
310 |
311 | if not return_dict:
312 | return (output,)
313 |
314 | return Transformer2DModelOutput(sample=output)
315 |
--------------------------------------------------------------------------------
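A minimal sketch of the continuous-input path described in the docstring above (not part of the file; the head count, `in_channels`, sequence length, and `cross_attention_dim` are illustrative, and the import path follows this repository's layout). The input is projected to `inner_dim`, flattened to `(batch, height * width, inner_dim)`, run through the transformer blocks, and reshaped back to an image:

import torch
from svdiff_pytorch.diffusers_models.transformer_2d import Transformer2DModel

model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=8,     # inner_dim = 8 * 8 = 64
    in_channels=64,           # continuous input: patch_size / num_vector_embeds stay None
    num_layers=1,
    cross_attention_dim=32,
)
hidden_states = torch.randn(2, 64, 16, 16)      # (batch, channels, height, width)
encoder_hidden_states = torch.randn(2, 77, 32)  # conditioning features for cross-attention
out = model(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)                                # torch.Size([2, 64, 16, 16])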
/svdiff_pytorch/diffusers_models/unet_2d_condition.py:
--------------------------------------------------------------------------------
1 | # Copyright 2023 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | from dataclasses import dataclass
15 | from typing import Any, Dict, List, Optional, Tuple, Union
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.utils.checkpoint
20 |
21 | from diffusers.configuration_utils import ConfigMixin, register_to_config
22 | from diffusers.loaders import UNet2DConditionLoadersMixin
23 | from diffusers.utils import BaseOutput, logging
24 | from diffusers.models.cross_attention import AttnProcessor
25 | from svdiff_pytorch.diffusers_models.embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
26 | from diffusers.models.modeling_utils import ModelMixin
27 | from svdiff_pytorch.diffusers_models.unet_2d_blocks import (
28 | CrossAttnDownBlock2D,
29 | CrossAttnUpBlock2D,
30 | DownBlock2D,
31 | UNetMidBlock2DCrossAttn,
32 | UNetMidBlock2DSimpleCrossAttn,
33 | UpBlock2D,
34 | get_down_block,
35 | get_up_block,
36 | )
37 | from svdiff_pytorch.layers import SVDConv1d, SVDConv2d, SVDLinear, SVDGroupNorm, SVDLayerNorm
38 |
39 |
40 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41 |
42 |
43 | @dataclass
44 | class UNet2DConditionOutput(BaseOutput):
45 | """
46 | Args:
47 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
48 | Hidden states conditioned on `encoder_hidden_states` input. Output of last layer of model.
49 | """
50 |
51 | sample: torch.FloatTensor
52 |
53 |
54 | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
55 | r"""
56 | UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
57 | and returns sample shaped output.
58 |
59 | This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
60 | implements for all the models (such as downloading or saving, etc.)
61 |
62 | Parameters:
63 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
64 | Height and width of input/output sample.
65 | in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
66 | out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
67 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
68 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
69 | Whether to flip the sin to cos in the time embedding.
70 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
71 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
72 | The tuple of downsample blocks to use.
73 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
74 | The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`, will skip the
75 | mid block layer if `None`.
76 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
77 | The tuple of upsample blocks to use.
78 |         only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
79 | Whether to include self-attention in the basic transformer blocks, see
80 | [`~models.attention.BasicTransformerBlock`].
81 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
82 | The tuple of output channels for each block.
83 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
84 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
85 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
86 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
87 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
88 | If `None`, it will skip the normalization and activation layers in post-processing
89 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
90 | cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
91 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
92 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
93 | for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
94 | class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
95 | summed with the time embeddings. Choose from `None`, `"timestep"`, `"identity"`, or `"projection"`.
96 | num_class_embeds (`int`, *optional*, defaults to None):
97 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
98 | class conditioning with `class_embed_type` equal to `None`.
99 |         time_embedding_type (`str`, *optional*, defaults to `positional`):
100 |             The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
101 |         timestep_post_act (`str`, *optional*, defaults to `None`):
102 |             The second activation function to use in the timestep embedding. Choose from `silu`, `mish` and `gelu`.
103 |         time_cond_proj_dim (`int`, *optional*, defaults to `None`):
104 |             The dimension of the `cond_proj` layer in the timestep embedding.
105 |         conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
106 |         conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
107 | projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
108 | using the "projection" `class_embed_type`. Required when using the "projection" `class_embed_type`.
109 | """
110 |
111 | _supports_gradient_checkpointing = True
112 |
113 | @register_to_config
114 | def __init__(
115 | self,
116 | sample_size: Optional[int] = None,
117 | in_channels: int = 4,
118 | out_channels: int = 4,
119 | center_input_sample: bool = False,
120 | flip_sin_to_cos: bool = True,
121 | freq_shift: int = 0,
122 | down_block_types: Tuple[str] = (
123 | "CrossAttnDownBlock2D",
124 | "CrossAttnDownBlock2D",
125 | "CrossAttnDownBlock2D",
126 | "DownBlock2D",
127 | ),
128 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
129 | up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
130 | only_cross_attention: Union[bool, Tuple[bool]] = False,
131 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
132 | layers_per_block: int = 2,
133 | downsample_padding: int = 1,
134 | mid_block_scale_factor: float = 1,
135 | act_fn: str = "silu",
136 | norm_num_groups: Optional[int] = 32,
137 | norm_eps: float = 1e-5,
138 | cross_attention_dim: int = 1280,
139 | attention_head_dim: Union[int, Tuple[int]] = 8,
140 | dual_cross_attention: bool = False,
141 | use_linear_projection: bool = False,
142 | class_embed_type: Optional[str] = None,
143 | num_class_embeds: Optional[int] = None,
144 | upcast_attention: bool = False,
145 | resnet_time_scale_shift: str = "default",
146 | time_embedding_type: str = "positional",
147 | timestep_post_act: Optional[str] = None,
148 | time_cond_proj_dim: Optional[int] = None,
149 | conv_in_kernel: int = 3,
150 | conv_out_kernel: int = 3,
151 | projection_class_embeddings_input_dim: Optional[int] = None,
152 | ):
153 | super().__init__()
154 |
155 | self.sample_size = sample_size
156 |
157 | # Check inputs
158 | if len(down_block_types) != len(up_block_types):
159 | raise ValueError(
160 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
161 | )
162 |
163 | if len(block_out_channels) != len(down_block_types):
164 | raise ValueError(
165 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
166 | )
167 |
168 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
169 | raise ValueError(
170 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
171 | )
172 |
173 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
174 | raise ValueError(
175 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
176 | )
177 |
178 | # input
179 | conv_in_padding = (conv_in_kernel - 1) // 2
180 | self.conv_in = SVDConv2d(
181 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
182 | )
183 |
184 | # time
185 | if time_embedding_type == "fourier":
186 | time_embed_dim = block_out_channels[0] * 2
187 | if time_embed_dim % 2 != 0:
188 | raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
189 | self.time_proj = GaussianFourierProjection(
190 | time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
191 | )
192 | timestep_input_dim = time_embed_dim
193 | elif time_embedding_type == "positional":
194 | time_embed_dim = block_out_channels[0] * 4
195 |
196 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
197 | timestep_input_dim = block_out_channels[0]
198 | else:
199 | raise ValueError(
200 |                 f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
201 | )
202 |
203 | self.time_embedding = TimestepEmbedding(
204 | timestep_input_dim,
205 | time_embed_dim,
206 | act_fn=act_fn,
207 | post_act_fn=timestep_post_act,
208 | cond_proj_dim=time_cond_proj_dim,
209 | )
210 |
211 | # class embedding
212 | if class_embed_type is None and num_class_embeds is not None:
213 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
214 | elif class_embed_type == "timestep":
215 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
216 | elif class_embed_type == "identity":
217 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
218 | elif class_embed_type == "projection":
219 | if projection_class_embeddings_input_dim is None:
220 | raise ValueError(
221 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
222 | )
223 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
224 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
225 | # 2. it projects from an arbitrary input dimension.
226 | #
227 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
228 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
229 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
230 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
231 | else:
232 | self.class_embedding = None
233 |
234 | self.down_blocks = nn.ModuleList([])
235 | self.up_blocks = nn.ModuleList([])
236 |
237 | if isinstance(only_cross_attention, bool):
238 | only_cross_attention = [only_cross_attention] * len(down_block_types)
239 |
240 | if isinstance(attention_head_dim, int):
241 | attention_head_dim = (attention_head_dim,) * len(down_block_types)
242 |
243 | # down
244 | output_channel = block_out_channels[0]
245 | for i, down_block_type in enumerate(down_block_types):
246 | input_channel = output_channel
247 | output_channel = block_out_channels[i]
248 | is_final_block = i == len(block_out_channels) - 1
249 |
250 | down_block = get_down_block(
251 | down_block_type,
252 | num_layers=layers_per_block,
253 | in_channels=input_channel,
254 | out_channels=output_channel,
255 | temb_channels=time_embed_dim,
256 | add_downsample=not is_final_block,
257 | resnet_eps=norm_eps,
258 | resnet_act_fn=act_fn,
259 | resnet_groups=norm_num_groups,
260 | cross_attention_dim=cross_attention_dim,
261 | attn_num_head_channels=attention_head_dim[i],
262 | downsample_padding=downsample_padding,
263 | dual_cross_attention=dual_cross_attention,
264 | use_linear_projection=use_linear_projection,
265 | only_cross_attention=only_cross_attention[i],
266 | upcast_attention=upcast_attention,
267 | resnet_time_scale_shift=resnet_time_scale_shift,
268 | )
269 | self.down_blocks.append(down_block)
270 |
271 | # mid
272 | if mid_block_type == "UNetMidBlock2DCrossAttn":
273 | self.mid_block = UNetMidBlock2DCrossAttn(
274 | in_channels=block_out_channels[-1],
275 | temb_channels=time_embed_dim,
276 | resnet_eps=norm_eps,
277 | resnet_act_fn=act_fn,
278 | output_scale_factor=mid_block_scale_factor,
279 | resnet_time_scale_shift=resnet_time_scale_shift,
280 | cross_attention_dim=cross_attention_dim,
281 | attn_num_head_channels=attention_head_dim[-1],
282 | resnet_groups=norm_num_groups,
283 | dual_cross_attention=dual_cross_attention,
284 | use_linear_projection=use_linear_projection,
285 | upcast_attention=upcast_attention,
286 | )
287 | elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
288 | self.mid_block = UNetMidBlock2DSimpleCrossAttn(
289 | in_channels=block_out_channels[-1],
290 | temb_channels=time_embed_dim,
291 | resnet_eps=norm_eps,
292 | resnet_act_fn=act_fn,
293 | output_scale_factor=mid_block_scale_factor,
294 | cross_attention_dim=cross_attention_dim,
295 | attn_num_head_channels=attention_head_dim[-1],
296 | resnet_groups=norm_num_groups,
297 | resnet_time_scale_shift=resnet_time_scale_shift,
298 | )
299 | elif mid_block_type is None:
300 | self.mid_block = None
301 | else:
302 | raise ValueError(f"unknown mid_block_type : {mid_block_type}")
303 |
304 | # count how many layers upsample the images
305 | self.num_upsamplers = 0
306 |
307 | # up
308 | reversed_block_out_channels = list(reversed(block_out_channels))
309 | reversed_attention_head_dim = list(reversed(attention_head_dim))
310 | only_cross_attention = list(reversed(only_cross_attention))
311 |
312 | output_channel = reversed_block_out_channels[0]
313 | for i, up_block_type in enumerate(up_block_types):
314 | is_final_block = i == len(block_out_channels) - 1
315 |
316 | prev_output_channel = output_channel
317 | output_channel = reversed_block_out_channels[i]
318 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
319 |
320 | # add upsample block for all BUT final layer
321 | if not is_final_block:
322 | add_upsample = True
323 | self.num_upsamplers += 1
324 | else:
325 | add_upsample = False
326 |
327 | up_block = get_up_block(
328 | up_block_type,
329 | num_layers=layers_per_block + 1,
330 | in_channels=input_channel,
331 | out_channels=output_channel,
332 | prev_output_channel=prev_output_channel,
333 | temb_channels=time_embed_dim,
334 | add_upsample=add_upsample,
335 | resnet_eps=norm_eps,
336 | resnet_act_fn=act_fn,
337 | resnet_groups=norm_num_groups,
338 | cross_attention_dim=cross_attention_dim,
339 | attn_num_head_channels=reversed_attention_head_dim[i],
340 | dual_cross_attention=dual_cross_attention,
341 | use_linear_projection=use_linear_projection,
342 | only_cross_attention=only_cross_attention[i],
343 | upcast_attention=upcast_attention,
344 | resnet_time_scale_shift=resnet_time_scale_shift,
345 | )
346 | self.up_blocks.append(up_block)
347 | prev_output_channel = output_channel
348 |
349 | # out
350 | if norm_num_groups is not None:
351 | self.conv_norm_out = SVDGroupNorm(
352 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
353 | )
354 | self.conv_act = nn.SiLU()
355 | else:
356 | self.conv_norm_out = None
357 | self.conv_act = None
358 |
359 | conv_out_padding = (conv_out_kernel - 1) // 2
360 | self.conv_out = SVDConv2d(
361 | block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
362 | )
363 |
364 | @property
365 | def attn_processors(self) -> Dict[str, AttnProcessor]:
366 | r"""
367 | Returns:
368 |             `dict` of attention processors: A dictionary containing all attention processors used in the model,
369 |             indexed by their weight names.
370 | """
371 | # set recursively
372 | processors = {}
373 |
374 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
375 | if hasattr(module, "set_processor"):
376 | processors[f"{name}.processor"] = module.processor
377 |
378 | for sub_name, child in module.named_children():
379 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
380 |
381 | return processors
382 |
383 | for name, module in self.named_children():
384 | fn_recursive_add_processors(name, module, processors)
385 |
386 | return processors
387 |
388 | def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
389 | r"""
390 | Parameters:
391 |             processor (`dict` of `AttnProcessor` or `AttnProcessor`):
392 | The instantiated processor class or a dictionary of processor classes that will be set as the processor
393 | of **all** `CrossAttention` layers.
394 |             In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors.
395 |
396 | """
397 | count = len(self.attn_processors.keys())
398 |
399 | if isinstance(processor, dict) and len(processor) != count:
400 | raise ValueError(
401 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
402 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
403 | )
404 |
405 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
406 | if hasattr(module, "set_processor"):
407 | if not isinstance(processor, dict):
408 | module.set_processor(processor)
409 | else:
410 | module.set_processor(processor.pop(f"{name}.processor"))
411 |
412 | for sub_name, child in module.named_children():
413 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
414 |
415 | for name, module in self.named_children():
416 | fn_recursive_attn_processor(name, module, processor)
417 |
418 | def set_attention_slice(self, slice_size):
419 | r"""
420 | Enable sliced attention computation.
421 |
422 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention
423 | in several steps. This is useful to save some memory in exchange for a small speed decrease.
424 |
425 | Args:
426 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
427 | When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
428 |                 `"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number is
429 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
430 | must be a multiple of `slice_size`.
431 | """
432 | sliceable_head_dims = []
433 |
434 | def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
435 | if hasattr(module, "set_attention_slice"):
436 | sliceable_head_dims.append(module.sliceable_head_dim)
437 |
438 | for child in module.children():
439 | fn_recursive_retrieve_slicable_dims(child)
440 |
441 | # retrieve number of attention layers
442 | for module in self.children():
443 | fn_recursive_retrieve_slicable_dims(module)
444 |
445 | num_slicable_layers = len(sliceable_head_dims)
446 |
447 | if slice_size == "auto":
448 | # half the attention head size is usually a good trade-off between
449 | # speed and memory
450 | slice_size = [dim // 2 for dim in sliceable_head_dims]
451 | elif slice_size == "max":
452 | # make smallest slice possible
453 | slice_size = num_slicable_layers * [1]
454 |
455 | slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
456 |
457 | if len(slice_size) != len(sliceable_head_dims):
458 | raise ValueError(
459 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
460 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
461 | )
462 |
463 | for i in range(len(slice_size)):
464 | size = slice_size[i]
465 | dim = sliceable_head_dims[i]
466 | if size is not None and size > dim:
467 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
468 |
469 | # Recursively walk through all the children.
470 | # Any children which exposes the set_attention_slice method
471 | # gets the message
472 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
473 | if hasattr(module, "set_attention_slice"):
474 | module.set_attention_slice(slice_size.pop())
475 |
476 | for child in module.children():
477 | fn_recursive_set_attention_slice(child, slice_size)
478 |
479 | reversed_slice_size = list(reversed(slice_size))
480 | for module in self.children():
481 | fn_recursive_set_attention_slice(module, reversed_slice_size)
482 |
483 | def _set_gradient_checkpointing(self, module, value=False):
484 | if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
485 | module.gradient_checkpointing = value
486 |
487 | def forward(
488 | self,
489 | sample: torch.FloatTensor,
490 | timestep: Union[torch.Tensor, float, int],
491 | encoder_hidden_states: torch.Tensor,
492 | class_labels: Optional[torch.Tensor] = None,
493 | timestep_cond: Optional[torch.Tensor] = None,
494 | attention_mask: Optional[torch.Tensor] = None,
495 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
496 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
497 | mid_block_additional_residual: Optional[torch.Tensor] = None,
498 | return_dict: bool = True,
499 | ) -> Union[UNet2DConditionOutput, Tuple]:
500 | r"""
501 | Args:
502 | sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
503 | timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
504 | encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
505 | return_dict (`bool`, *optional*, defaults to `True`):
506 | Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
507 | cross_attention_kwargs (`dict`, *optional*):
508 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
509 | `self.processor` in
510 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
511 |
512 | Returns:
513 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
514 | [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
515 | returning a tuple, the first element is the sample tensor.
516 | """
517 |         # By default, samples have to be at least a multiple of the overall upsampling factor.
518 |         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
519 | # However, the upsampling interpolation output size can be forced to fit any upsampling size
520 | # on the fly if necessary.
521 | default_overall_up_factor = 2**self.num_upsamplers
522 |
523 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
524 | forward_upsample_size = False
525 | upsample_size = None
526 |
527 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
528 | logger.info("Forward upsample size to force interpolation output size.")
529 | forward_upsample_size = True
530 |
531 | # prepare attention_mask
532 | if attention_mask is not None:
533 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
534 | attention_mask = attention_mask.unsqueeze(1)
535 |
536 | # 0. center input if necessary
537 | if self.config.center_input_sample:
538 | sample = 2 * sample - 1.0
539 |
540 | # 1. time
541 | timesteps = timestep
542 | if not torch.is_tensor(timesteps):
543 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
544 | # This would be a good case for the `match` statement (Python 3.10+)
545 | is_mps = sample.device.type == "mps"
546 | if isinstance(timestep, float):
547 | dtype = torch.float32 if is_mps else torch.float64
548 | else:
549 | dtype = torch.int32 if is_mps else torch.int64
550 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
551 | elif len(timesteps.shape) == 0:
552 | timesteps = timesteps[None].to(sample.device)
553 |
554 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
555 | timesteps = timesteps.expand(sample.shape[0])
556 |
557 | t_emb = self.time_proj(timesteps)
558 |
559 | # timesteps does not contain any weights and will always return f32 tensors
560 | # but time_embedding might actually be running in fp16. so we need to cast here.
561 | # there might be better ways to encapsulate this.
562 | t_emb = t_emb.to(dtype=self.dtype)
563 |
564 | emb = self.time_embedding(t_emb, timestep_cond)
565 |
566 | if self.class_embedding is not None:
567 | if class_labels is None:
568 | raise ValueError("class_labels should be provided when num_class_embeds > 0")
569 |
570 | if self.config.class_embed_type == "timestep":
571 | class_labels = self.time_proj(class_labels)
572 |
573 | class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
574 | emb = emb + class_emb
575 |
576 | # 2. pre-process
577 | sample = self.conv_in(sample)
578 |
579 | # 3. down
580 | down_block_res_samples = (sample,)
581 | for downsample_block in self.down_blocks:
582 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
583 | sample, res_samples = downsample_block(
584 | hidden_states=sample,
585 | temb=emb,
586 | encoder_hidden_states=encoder_hidden_states,
587 | attention_mask=attention_mask,
588 | cross_attention_kwargs=cross_attention_kwargs,
589 | )
590 | else:
591 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
592 |
593 | down_block_res_samples += res_samples
594 |
595 | if down_block_additional_residuals is not None:
596 | new_down_block_res_samples = ()
597 |
598 | for down_block_res_sample, down_block_additional_residual in zip(
599 | down_block_res_samples, down_block_additional_residuals
600 | ):
601 | down_block_res_sample += down_block_additional_residual
602 | new_down_block_res_samples += (down_block_res_sample,)
603 |
604 | down_block_res_samples = new_down_block_res_samples
605 |
606 | # 4. mid
607 | if self.mid_block is not None:
608 | sample = self.mid_block(
609 | sample,
610 | emb,
611 | encoder_hidden_states=encoder_hidden_states,
612 | attention_mask=attention_mask,
613 | cross_attention_kwargs=cross_attention_kwargs,
614 | )
615 |
616 | if mid_block_additional_residual is not None:
617 | sample += mid_block_additional_residual
618 |
619 | # 5. up
620 | for i, upsample_block in enumerate(self.up_blocks):
621 | is_final_block = i == len(self.up_blocks) - 1
622 |
623 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
624 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
625 |
626 | # if we have not reached the final block and need to forward the
627 | # upsample size, we do it here
628 | if not is_final_block and forward_upsample_size:
629 | upsample_size = down_block_res_samples[-1].shape[2:]
630 |
631 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
632 | sample = upsample_block(
633 | hidden_states=sample,
634 | temb=emb,
635 | res_hidden_states_tuple=res_samples,
636 | encoder_hidden_states=encoder_hidden_states,
637 | cross_attention_kwargs=cross_attention_kwargs,
638 | upsample_size=upsample_size,
639 | attention_mask=attention_mask,
640 | )
641 | else:
642 | sample = upsample_block(
643 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
644 | )
645 |
646 | # 6. post-process
647 | if self.conv_norm_out:
648 | sample = self.conv_norm_out(sample)
649 | sample = self.conv_act(sample)
650 | sample = self.conv_out(sample)
651 |
652 | if not return_dict:
653 | return (sample,)
654 |
655 | return UNet2DConditionOutput(sample=sample)
656 |
--------------------------------------------------------------------------------
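A smoke-test sketch of the UNet above (illustrative only; the tiny block/channel configuration below is made up for speed, in the spirit of the reduced configs used in diffusers tests, and is not the Stable Diffusion default listed in the docstring). Because the conv and linear layers are the SVD* variants from `layers.py` (shown next), their freshly added spectral-shift parameters are registered as `delta` while the pretrained `weight` tensors are frozen, which is the quantity an SVDiff-style training loop would optimize:

import torch
from svdiff_pytorch.diffusers_models.unet_2d_condition import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    cross_attention_dim=32,
    attention_head_dim=8,
)
sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 32)
out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])

spectral_shifts = [n for n, p in unet.named_parameters() if "delta" in n]
print(len(spectral_shifts), "spectral-shift parameter tensors")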
/svdiff_pytorch/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | from einops import rearrange
5 |
6 |
7 |
8 | class SVDConv2d(nn.Conv2d):
9 | def __init__(
10 | self,
11 | in_channels: int,
12 | out_channels: int,
13 | kernel_size: int,
14 | scale: float = 1.0,
15 | **kwargs
16 | ):
17 | nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
18 | assert type(kernel_size) is int
19 | weight_reshaped = rearrange(self.weight, 'co cin h w -> co (cin h w)')
20 | self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
21 | # initialize to 0 for smooth tuning
22 | self.delta = nn.Parameter(torch.zeros_like(self.S))
23 | self.weight.requires_grad = False
24 | self.done_svd = False
25 | self.scale = scale
26 | self.reset_parameters()
27 |
28 | def set_scale(self, scale: float):
29 | self.scale = scale
30 |
31 | def perform_svd(self):
32 | # shape
33 | weight_reshaped = rearrange(self.weight, 'co cin h w -> co (cin h w)')
34 | self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
35 | self.done_svd = True
36 |
37 | def reset_parameters(self):
38 | nn.Conv2d.reset_parameters(self)
39 | if hasattr(self, 'delta'):
40 | nn.init.zeros_(self.delta)
41 |
42 | def forward(self, x: torch.Tensor):
43 | if not self.done_svd:
44 | # this happens after loading the state dict
45 | self.perform_svd()
46 | weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
47 | weight_updated = rearrange(weight_updated, 'co (cin h w) -> co cin h w', cin=self.weight.size(1), h=self.weight.size(2), w=self.weight.size(3))
48 | return F.conv2d(x, weight_updated, self.bias, self.stride, self.padding, self.dilation, self.groups)
49 |
50 |
51 | class SVDConv1d(nn.Conv1d):
52 | def __init__(
53 | self,
54 | in_channels: int,
55 | out_channels: int,
56 | kernel_size: int,
57 | scale: float = 1.0,
58 | **kwargs
59 | ):
60 | nn.Conv1d.__init__(self, in_channels, out_channels, kernel_size, **kwargs)
61 | assert type(kernel_size) is int
62 | weight_reshaped = rearrange(self.weight, 'co cin k -> co (cin k)')  # Conv1d kernels are 3-D: (co, cin, k)
63 | self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
64 | # initialize to 0 for smooth tuning
65 | self.delta = nn.Parameter(torch.zeros_like(self.S))
66 | self.weight.requires_grad = False
67 | self.done_svd = False
68 | self.scale = scale
69 | self.reset_parameters()
70 |
71 | def set_scale(self, scale: float):
72 | self.scale = scale
73 |
74 | def perform_svd(self):
75 | # recompute the SVD of the flattened kernel (needed after the weights are loaded from a state dict)
76 | weight_reshaped = rearrange(self.weight, 'co cin k -> co (cin k)')
77 | self.U, self.S, self.Vh = torch.linalg.svd(weight_reshaped, full_matrices=False)
78 | self.done_svd = True
79 |
80 | def reset_parameters(self):
81 | nn.Conv1d.reset_parameters(self)
82 | if hasattr(self, 'delta'):
83 | nn.init.zeros_(self.delta)
84 |
85 | def forward(self, x: torch.Tensor):
86 | if not self.done_svd:
87 | # this happens after loading the state dict
88 | self.perform_svd()
89 | weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
90 | weight_updated = rearrange(weight_updated, 'co (cin k) -> co cin k', cin=self.weight.size(1), k=self.weight.size(2))
91 | return F.conv1d(x, weight_updated, self.bias, self.stride, self.padding, self.dilation, self.groups)
92 |
93 |
94 |
95 | class SVDLinear(nn.Linear):
96 | def __init__(
97 | self,
98 | in_features: int,
99 | out_features: int,
100 | scale: float = 1.0,
101 | **kwargs
102 | ):
103 | nn.Linear.__init__(self, in_features, out_features, **kwargs)
104 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
105 | # initialize to 0 for smooth tuning
106 | self.delta = nn.Parameter(torch.zeros_like(self.S))
107 | self.weight.requires_grad = False
108 | self.done_svd = False
109 | self.scale = scale
110 | self.reset_parameters()
111 |
112 | def set_scale(self, scale: float):
113 | self.scale = scale
114 |
115 | def perform_svd(self):
116 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
117 | self.done_svd = True
118 |
119 | def reset_parameters(self):
120 | nn.Linear.reset_parameters(self)
121 | if hasattr(self, 'delta'):
122 | nn.init.zeros_(self.delta)
123 |
124 | def forward(self, x: torch.Tensor):
125 | if not self.done_svd:
126 | # this happens after loading the state dict
127 | self.perform_svd()
128 | weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
129 | return F.linear(x, weight_updated, bias=self.bias)
130 |
131 |
132 | class SVDEmbedding(nn.Embedding):
133 | # spectral-shift (SVDiff) version of nn.Embedding
134 | def __init__(
135 | self,
136 | num_embeddings: int,
137 | embedding_dim: int,
138 | scale: float = 1.0,
139 | **kwargs
140 | ):
141 | nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
142 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
143 | # initialize to 0 for smooth tuning
144 | self.delta = nn.Parameter(torch.zeros_like(self.S))
145 | self.weight.requires_grad = False
146 | self.done_svd = False
147 | self.scale = scale
148 | self.reset_parameters()
149 |
150 | def set_scale(self, scale: float):
151 | self.scale = scale
152 |
153 | def perform_svd(self):
154 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight, full_matrices=False)
155 | self.done_svd = True
156 |
157 | def reset_parameters(self):
158 | nn.Embedding.reset_parameters(self)
159 | if hasattr(self, 'delta'):
160 | nn.init.zeros_(self.delta)
161 |
162 | def forward(self, x: torch.Tensor):
163 | if not self.done_svd:
164 | # this happens after loading the state dict
165 | self.perform_svd()
166 | weight_updated = self.U.to(x.device) @ torch.diag(F.relu(self.S.to(x.device)+self.scale * self.delta)) @ self.Vh.to(x.device)
167 | return F.embedding(x, weight_updated, padding_idx=self.padding_idx, max_norm=self.max_norm, norm_type=self.norm_type, scale_grad_by_freq=self.scale_grad_by_freq, sparse=self.sparse)
168 |
169 |
170 | # 1-D weights (LayerNorm/GroupNorm): the SVD is taken of the weight viewed as a (1, C) matrix
171 | class SVDLayerNorm(nn.LayerNorm):
172 | def __init__(
173 | self,
174 | normalized_shape: int,
175 | scale: float = 1.0,
176 | **kwargs
177 | ):
178 | nn.LayerNorm.__init__(self, normalized_shape=normalized_shape, **kwargs)
179 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False)
180 | # initialize to 0 for smooth tuning
181 | self.delta = nn.Parameter(torch.zeros_like(self.S))
182 | self.weight.requires_grad = False
183 | self.done_svd = False
184 | self.scale = scale
185 | self.reset_parameters()
186 |
187 | def set_scale(self, scale: float):
188 | self.scale = scale
189 |
190 | def perform_svd(self):
191 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False)
192 | self.done_svd = True
193 |
194 | def reset_parameters(self):
195 | nn.LayerNorm.reset_parameters(self)
196 | if hasattr(self, 'delta'):
197 | nn.init.zeros_(self.delta)
198 |
199 | def forward(self, x: torch.Tensor):
200 | if not self.done_svd:
201 | # this happens after loading the state dict
202 | self.perform_svd()
203 | weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
204 | weight_updated = weight_updated.squeeze(0)
205 | return F.layer_norm(x, normalized_shape=self.normalized_shape, weight=weight_updated, bias=self.bias, eps=self.eps)
206 |
207 |
208 | class SVDGroupNorm(nn.GroupNorm):
209 | def __init__(
210 | self,
211 | num_groups: int,
212 | num_channels: int,
213 | scale: float = 1.0,
214 | **kwargs
215 | ):
216 | nn.GroupNorm.__init__(self, num_groups, num_channels, **kwargs)
217 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False)
218 | # initialize to 0 for smooth tuning
219 | self.delta = nn.Parameter(torch.zeros_like(self.S))
220 | self.weight.requires_grad = False
221 | self.done_svd = False
222 | self.scale = scale
223 | self.reset_parameters()
224 |
225 | def set_scale(self, scale: float):
226 | self.scale = scale
227 |
228 | def perform_svd(self):
229 | self.U, self.S, self.Vh = torch.linalg.svd(self.weight.unsqueeze(0), full_matrices=False)
230 | self.done_svd = True
231 |
232 | def reset_parameters(self):
233 | nn.GroupNorm.reset_parameters(self)
234 | if hasattr(self, 'delta'):
235 | nn.init.zeros_(self.delta)
236 |
237 | def forward(self, x: torch.Tensor):
238 | if not self.done_svd:
239 | # this happens after loading the state dict
240 | self.perform_svd()
241 | weight_updated = self.U.to(x.device, dtype=x.dtype) @ torch.diag(F.relu(self.S.to(x.device, dtype=x.dtype)+self.scale * self.delta)) @ self.Vh.to(x.device, dtype=x.dtype)
242 | weight_updated = weight_updated.squeeze(0)
243 | return F.group_norm(x, num_groups=self.num_groups, weight=weight_updated, bias=self.bias, eps=self.eps)
244 |
245 |
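As a quick illustration of the layers above, a hedged, self-contained sketch using SVDLinear; the toy sizes and the comparison against F.linear are illustrative only. At initialization delta is zero, so each SVD layer reproduces its frozen base weight, and only delta (and the bias) receives gradients.

import torch
from svdiff_pytorch.layers import SVDLinear

layer = SVDLinear(16, 32)              # the SVD of the frozen base weight is taken at construction
print(layer.weight.requires_grad)      # False -- the pretrained weight stays frozen
print(layer.delta.requires_grad)       # True  -- only the spectral shifts are learned

x = torch.randn(4, 16)
out = layer(x)                         # uses W' = U @ diag(relu(S + scale * delta)) @ Vh
print(out.shape)                       # torch.Size([4, 32])

# with delta == 0 the layer matches the plain linear map, up to SVD round-off
with torch.no_grad():
    reference = torch.nn.functional.linear(x, layer.weight, layer.bias)
print(torch.allclose(out, reference, atol=1e-5))   # True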
--------------------------------------------------------------------------------
/svdiff_pytorch/pipeline_stable_diffusion_ddim_inversion.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Callable, Dict, List, Optional, Union
2 | import PIL
3 | import torch
4 | from diffusers import StableDiffusionPipeline, DDIMInverseScheduler
5 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import preprocess
6 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_pix2pix_zero import Pix2PixInversionPipelineOutput
7 |
8 |
9 | class StableDiffusionPipelineWithDDIMInversion(StableDiffusionPipeline):
10 | def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker: bool = True):
11 | super().__init__(vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker)
12 | self.inverse_scheduler = DDIMInverseScheduler.from_config(self.scheduler.config)
13 | # self.register_modules(inverse_scheduler=DDIMInverseScheduler.from_config(self.scheduler.config))
14 |
15 |
16 | def prepare_image_latents(self, image, batch_size, dtype, device, generator=None):
17 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
18 | raise ValueError(
19 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
20 | )
21 |
22 | image = image.to(device=device, dtype=dtype)
23 |
24 | if isinstance(generator, list) and len(generator) != batch_size:
25 | raise ValueError(
26 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
27 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
28 | )
29 |
30 | if isinstance(generator, list):
31 | init_latents = [
32 | self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
33 | ]
34 | init_latents = torch.cat(init_latents, dim=0)
35 | else:
36 | init_latents = self.vae.encode(image).latent_dist.sample(generator)
37 |
38 | init_latents = self.vae.config.scaling_factor * init_latents
39 |
40 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
41 | raise ValueError(
42 | f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
43 | )
44 | else:
45 | init_latents = torch.cat([init_latents], dim=0)
46 |
47 | latents = init_latents
48 |
49 | return latents
50 |
51 | def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep: int):
52 | pred_type = self.inverse_scheduler.config.prediction_type
53 | alpha_prod_t = self.inverse_scheduler.alphas_cumprod[timestep]
54 |
55 | beta_prod_t = 1 - alpha_prod_t
56 |
57 | if pred_type == "epsilon":
58 | return model_output
59 | elif pred_type == "sample":
60 | return (sample - alpha_prod_t ** (0.5) * model_output) / beta_prod_t ** (0.5)
61 | elif pred_type == "v_prediction":
62 | return (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
63 | else:
64 | raise ValueError(
65 | f"prediction_type given as {pred_type} must be one of `epsilon`, `sample`, or `v_prediction`"
66 | )
67 |
68 | def auto_corr_loss(self, hidden_states, generator=None):
69 | batch_size, channel, height, width = hidden_states.shape
70 | if batch_size > 1:
71 | raise ValueError("Only batch_size 1 is supported for now")
72 |
73 | hidden_states = hidden_states.squeeze(0)
74 | # hidden_states must be shape [C,H,W] now
75 | reg_loss = 0.0
76 | for i in range(hidden_states.shape[0]):
77 | noise = hidden_states[i][None, None, :, :]
78 | while True:
79 | roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
80 | reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
81 | reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
82 |
83 | if noise.shape[2] <= 8:
84 | break
85 | noise = torch.nn.functional.avg_pool2d(noise, kernel_size=2)  # fully qualified; F is not imported in this module
86 | return reg_loss
87 |
88 | def kl_divergence(self, hidden_states):
89 | mean = hidden_states.mean()
90 | var = hidden_states.var()
91 | return var + mean**2 - 1 - torch.log(var + 1e-7)
92 |
93 |
94 | # based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py#L1063
95 | @torch.no_grad()
96 | def invert(
97 | self,
98 | prompt: Optional[str] = None,
99 | image: Union[torch.FloatTensor, PIL.Image.Image] = None,
100 | num_inference_steps: int = 50,
101 | guidance_scale: float = 1,
102 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
103 | latents: Optional[torch.FloatTensor] = None,
104 | prompt_embeds: Optional[torch.FloatTensor] = None,
105 | output_type: Optional[str] = "pil",
106 | return_dict: bool = True,
107 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
108 | callback_steps: Optional[int] = 1,
109 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
110 | lambda_auto_corr: float = 20.0,
111 | lambda_kl: float = 20.0,
112 | num_reg_steps: int = 0, # disabled
113 | num_auto_corr_rolls: int = 5,
114 | ):
115 | # 1. Define call parameters
116 | if prompt is not None and isinstance(prompt, str):
117 | batch_size = 1
118 | elif prompt is not None and isinstance(prompt, list):
119 | batch_size = len(prompt)
120 | else:
121 | batch_size = prompt_embeds.shape[0]
122 | if cross_attention_kwargs is None:
123 | cross_attention_kwargs = {}
124 |
125 | device = self._execution_device
126 | # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
127 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
128 | # corresponds to doing no classifier free guidance.
129 | do_classifier_free_guidance = guidance_scale > 1.0
130 |
131 | # 3. Preprocess image
132 | image = preprocess(image)
133 |
134 | # 4. Prepare latent variables
135 | latents = self.prepare_image_latents(image, batch_size, self.vae.dtype, device, generator)
136 |
137 | # 5. Encode input prompt
138 | num_images_per_prompt = 1
139 | prompt_embeds = self._encode_prompt(
140 | prompt,
141 | device,
142 | num_images_per_prompt,
143 | do_classifier_free_guidance,
144 | prompt_embeds=prompt_embeds,
145 | )
146 |
147 | # 6. Prepare timesteps
148 | self.inverse_scheduler.set_timesteps(num_inference_steps, device=device)
149 | timesteps = self.inverse_scheduler.timesteps
150 |
151 | # 7. DDIM inversion loop
152 | num_warmup_steps = len(timesteps) - num_inference_steps * self.inverse_scheduler.order
153 | with self.progress_bar(total=num_inference_steps - 1) as progress_bar:
154 | for i, t in enumerate(timesteps[:-1]):
155 | # expand the latents if we are doing classifier free guidance
156 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
157 | latent_model_input = self.inverse_scheduler.scale_model_input(latent_model_input, t)
158 |
159 | # predict the noise residual
160 | noise_pred = self.unet(
161 | latent_model_input,
162 | t,
163 | encoder_hidden_states=prompt_embeds,
164 | cross_attention_kwargs=cross_attention_kwargs,
165 | ).sample
166 |
167 | # perform guidance
168 | if do_classifier_free_guidance:
169 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
170 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
171 |
172 | # regularization of the noise prediction
173 | with torch.enable_grad():
174 | for _ in range(num_reg_steps):
175 | if lambda_auto_corr > 0:
176 | for _ in range(num_auto_corr_rolls):
177 | var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
178 |
179 | # Derive epsilon from model output before regularizing to IID standard normal
180 | var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
181 |
182 | l_ac = self.auto_corr_loss(var_epsilon, generator=generator)
183 | l_ac.backward()
184 |
185 | grad = var.grad.detach() / num_auto_corr_rolls
186 | noise_pred = noise_pred - lambda_auto_corr * grad
187 |
188 | if lambda_kl > 0:
189 | var = torch.autograd.Variable(noise_pred.detach().clone(), requires_grad=True)
190 |
191 | # Derive epsilon from model output before regularizing to IID standard normal
192 | var_epsilon = self.get_epsilon(var, latent_model_input.detach(), t)
193 |
194 | l_kld = self.kl_divergence(var_epsilon)
195 | l_kld.backward()
196 |
197 | grad = var.grad.detach()
198 | noise_pred = noise_pred - lambda_kl * grad
199 |
200 | noise_pred = noise_pred.detach()
201 |
202 | # take one step with the inverse scheduler (moving the sample toward the fully noised latent)
203 | latents = self.inverse_scheduler.step(noise_pred, t, latents).prev_sample
204 |
205 | # call the callback, if provided
206 | if i == len(timesteps) - 1 or (
207 | (i + 1) > num_warmup_steps and (i + 1) % self.inverse_scheduler.order == 0
208 | ):
209 | progress_bar.update()
210 | if callback is not None and i % callback_steps == 0:
211 | callback(i, t, latents)
212 |
213 | inverted_latents = latents.detach().clone()
214 |
215 | # 8. Post-processing
216 | image = self.decode_latents(latents.detach())
217 |
218 | # Offload last model to CPU
219 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
220 | self.final_offload_hook.offload()
221 |
222 | # 9. Convert to PIL.
223 | if output_type == "pil":
224 | image = self.numpy_to_pil(image)
225 |
226 | if not return_dict:
227 | return (inverted_latents, image)
228 |
229 | return Pix2PixInversionPipelineOutput(latents=inverted_latents, images=image)
230 |
231 |
232 |
233 | if __name__ == '__main__':
234 | from PIL import Image
235 | from diffusers import DDIMScheduler
236 | model_id = "CompVis/stable-diffusion-v1-4"
237 | input_prompt = "A photo of Barack Obama"
238 | prompt = "A photo of Barack Obama smiling with a big grin"
239 | url = "obama.png" # https://github.com/cccntu/efficient-prompt-to-prompt/blob/main/ddim-inversion.ipynb
240 |
241 | pipe = StableDiffusionPipelineWithDDIMInversion.from_pretrained(
242 | model_id,
243 | # make sure to load ddim here
244 | scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler"),
245 | )
246 | image = Image.open(url).convert("RGB").resize((512, 512))
247 | # in SVDiff, they use guidance scale=1 in ddim inversion
248 | inv_latents = pipe.invert(input_prompt, image=image, guidance_scale=1.0).latents
249 | image = pipe(prompt, latents=inv_latents).images[0]
250 | image.save("out.png")
--------------------------------------------------------------------------------
/svdiff_pytorch/transformers_models_clip/__init__.py:
--------------------------------------------------------------------------------
1 | # all files in this folder were taken from https://github.com/huggingface/transformers/blob/v4.27.3/src/transformers/models/clip/modeling_clip.py
2 | # so these files are governed by the transformers LICENSE
--------------------------------------------------------------------------------
/svdiff_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import inspect
3 | from PIL import Image
4 |
5 | import torch
6 | import accelerate
7 | from accelerate.utils import set_module_tensor_to_device
8 | from diffusers import (
9 | LMSDiscreteScheduler,
10 | DDIMScheduler,
11 | PNDMScheduler,
12 | DPMSolverMultistepScheduler,
13 | EulerDiscreteScheduler,
14 | EulerAncestralDiscreteScheduler,
15 | )
16 | from transformers import CLIPTextModel, CLIPTextConfig
17 | from diffusers import UNet2DConditionModel
18 | from safetensors.torch import safe_open
19 | import huggingface_hub
20 | from svdiff_pytorch import UNet2DConditionModelForSVDiff, CLIPTextModelForSVDiff
21 |
22 |
23 |
24 | def load_unet_for_svdiff(pretrained_model_name_or_path, spectral_shifts_ckpt=None, hf_hub_kwargs=None, **kwargs):
25 | """
26 | https://github.com/huggingface/diffusers/blob/v0.14.0/src/diffusers/models/modeling_utils.py#L541
27 | """
28 | config = UNet2DConditionModel.load_config(pretrained_model_name_or_path, **kwargs)
29 | original_model = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
30 | state_dict = original_model.state_dict()
31 | with accelerate.init_empty_weights():
32 | model = UNet2DConditionModelForSVDiff.from_config(config)
33 | # load pre-trained weights
34 | param_device = "cpu"
35 | torch_dtype = kwargs["torch_dtype"] if "torch_dtype" in kwargs else None
36 | spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
37 | state_dict.update(spectral_shifts_weights)
38 | # move the params from meta device to cpu
39 | missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
40 | if len(missing_keys) > 0:
41 | raise ValueError(
42 | f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
43 | f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
44 | " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
45 | " those weights or else make sure your checkpoint file is correct."
46 | )
47 |
48 | for param_name, param in state_dict.items():
49 | accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
50 | if accepts_dtype:
51 | set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
52 | else:
53 | set_module_tensor_to_device(model, param_name, param_device, value=param)
54 |
55 | if spectral_shifts_ckpt:
56 | if os.path.isdir(spectral_shifts_ckpt):
57 | spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts.safetensors")
58 | elif not os.path.exists(spectral_shifts_ckpt):
59 | # download from hub
60 | hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
61 | spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts.safetensors", **hf_hub_kwargs)
62 | assert os.path.exists(spectral_shifts_ckpt)
63 |
64 | with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
65 | for key in f.keys():
66 | # spectral_shifts_weights[key] = f.get_tensor(key)
67 | accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
68 | if accepts_dtype:
69 | set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
70 | else:
71 | set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
72 | print(f"Resumed from {spectral_shifts_ckpt}")
73 | if "torch_dtype"in kwargs:
74 | model = model.to(kwargs["torch_dtype"])
75 | model.register_to_config(_name_or_path=pretrained_model_name_or_path)
76 | # Set model in evaluation mode to deactivate DropOut modules by default
77 | model.eval()
78 | del original_model
79 | torch.cuda.empty_cache()
80 | return model
81 |
82 |
83 |
84 | def load_text_encoder_for_svdiff(
85 | pretrained_model_name_or_path,
86 | spectral_shifts_ckpt=None,
87 | hf_hub_kwargs=None,
88 | **kwargs
89 | ):
90 | """
91 | https://github.com/huggingface/diffusers/blob/v0.14.0/src/diffusers/models/modeling_utils.py#L541
92 | """
93 | config = CLIPTextConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
94 | original_model = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
95 | state_dict = original_model.state_dict()
96 | with accelerate.init_empty_weights():
97 | model = CLIPTextModelForSVDiff(config)
98 | # load pre-trained weights
99 | param_device = "cpu"
100 | torch_dtype = kwargs["torch_dtype"] if "torch_dtype" in kwargs else None
101 | spectral_shifts_weights = {n: torch.zeros(p.shape) for n, p in model.named_parameters() if "delta" in n}
102 | state_dict.update(spectral_shifts_weights)
103 | # move the params from meta device to cpu
104 | missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
105 | if len(missing_keys) > 0:
106 | raise ValueError(
107 | f"Cannot load {model.__class__.__name__} from {pretrained_model_name_or_path} because the following keys are"
108 | f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
109 | " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomely initialize"
110 | " those weights or else make sure your checkpoint file is correct."
111 | )
112 |
113 | for param_name, param in state_dict.items():
114 | accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
115 | if accepts_dtype:
116 | set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
117 | else:
118 | set_module_tensor_to_device(model, param_name, param_device, value=param)
119 |
120 | if spectral_shifts_ckpt:
121 | if os.path.isdir(spectral_shifts_ckpt):
122 | spectral_shifts_ckpt = os.path.join(spectral_shifts_ckpt, "spectral_shifts_te.safetensors")
123 | elif not os.path.exists(spectral_shifts_ckpt):
124 | # download from hub
125 | hf_hub_kwargs = {} if hf_hub_kwargs is None else hf_hub_kwargs
126 | try:
127 | spectral_shifts_ckpt = huggingface_hub.hf_hub_download(spectral_shifts_ckpt, filename="spectral_shifts_te.safetensors", **hf_hub_kwargs)
128 | except huggingface_hub.utils.EntryNotFoundError:
129 | spectral_shifts_ckpt = None
130 | # load state dict only if `spectral_shifts_te.safetensors` exists
131 | if spectral_shifts_ckpt is not None and os.path.exists(spectral_shifts_ckpt):  # may be None if the hub download above failed
132 | with safe_open(spectral_shifts_ckpt, framework="pt", device="cpu") as f:
133 | for key in f.keys():
134 | # spectral_shifts_weights[key] = f.get_tensor(key)
135 | accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
136 | if accepts_dtype:
137 | set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key), dtype=torch_dtype)
138 | else:
139 | set_module_tensor_to_device(model, key, param_device, value=f.get_tensor(key))
140 | print(f"Resumed from {spectral_shifts_ckpt}")
141 |
142 | if "torch_dtype"in kwargs:
143 | model = model.to(kwargs["torch_dtype"])
144 | # model.register_to_config(_name_or_path=pretrained_model_name_or_path)
145 | # Set model in evaluation mode to deactivate DropOut modules by default
146 | model.eval()
147 | del original_model
148 | torch.cuda.empty_cache()
149 | return model
150 |
151 |
152 |
153 | def image_grid(imgs, rows, cols):
154 | assert len(imgs) == rows * cols
155 | w, h = imgs[0].size
156 | grid = Image.new('RGB', size=(cols * w, rows * h))
157 | for i, img in enumerate(imgs):
158 | grid.paste(img, box=(i % cols * w, i // cols * h))
159 | return grid
160 |
161 |
162 | def slerp(val, low, high):
163 | """ taken from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/4
164 | """
165 | low_norm = low/torch.norm(low, dim=1, keepdim=True)
166 | high_norm = high/torch.norm(high, dim=1, keepdim=True)
167 | omega = torch.acos((low_norm*high_norm).sum(1))
168 | so = torch.sin(omega)
169 | res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
170 | return res
171 |
172 |
173 | def slerp_tensor(val, low, high):
174 | shape = low.shape
175 | res = slerp(val, low.flatten(1), high.flatten(1))
176 | return res.reshape(shape)
177 |
178 |
179 | SCHEDULER_MAPPING = {
180 | "ddim": DDIMScheduler,
181 | "plms": PNDMScheduler,
182 | "lms": LMSDiscreteScheduler,
183 | "euler": EulerDiscreteScheduler,
184 | "euler_ancestral": EulerAncestralDiscreteScheduler,
185 | "dpm_solver++": DPMSolverMultistepScheduler,
186 | }
187 |
188 |
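Putting the two loaders and SCHEDULER_MAPPING together, a hedged inference sketch; the model id, checkpoint directory, and prompt are placeholders, and the pipeline assembly itself is standard diffusers usage rather than something defined in this file.

from diffusers import StableDiffusionPipeline
from svdiff_pytorch.utils import load_unet_for_svdiff, load_text_encoder_for_svdiff, SCHEDULER_MAPPING

model_id = "runwayml/stable-diffusion-v1-5"   # placeholder base model
ckpt_dir = "path/to/checkpoint-dir"           # placeholder: contains spectral_shifts.safetensors (and optionally spectral_shifts_te.safetensors)

unet = load_unet_for_svdiff(model_id, spectral_shifts_ckpt=ckpt_dir, subfolder="unet")
text_encoder = load_text_encoder_for_svdiff(model_id, spectral_shifts_ckpt=ckpt_dir, subfolder="text_encoder")

pipe = StableDiffusionPipeline.from_pretrained(model_id, unet=unet, text_encoder=text_encoder)
# pick any scheduler from the mapping above, e.g. DPM-Solver++
pipe.scheduler = SCHEDULER_MAPPING["dpm_solver++"].from_config(pipe.scheduler.config)

image = pipe("a photo of a sks dog", num_inference_steps=25).images[0]
image.save("sample.png")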
--------------------------------------------------------------------------------