├── .gitignore
├── README.md
├── assert
│   ├── .DS_Store
│   └── img
│       ├── .DS_Store
│       ├── multiple.png
│       ├── singleconcept.png
│       ├── story.jpg
│       └── video.jpg
├── blip2t.py
├── config
│   └── concept.json
├── data
│   ├── barn
│   │   ├── cora-leach-UlKxoeGP4Rc-unsplash.jpg
│   │   ├── emmy-sobieski-pPoR4vyOomY-unsplash.jpg
│   │   ├── isaac-martin-ZCAiJPiO_98-unsplash.jpg
│   │   ├── jonah-brown-eYAsBnNISa4-unsplash.jpg
│   │   ├── melissa-cronin-P8Bfpi1HLIE-unsplash.jpg
│   │   ├── pedro-lastra-hXunh-ivkPc-unsplash.jpg
│   │   └── sarat-karumuri-vs1voV4jhpU-unsplash.jpg
│   ├── bear
│   │   ├── marina-shatskih-6MDi8o6VYHg-unsplash.jpg
│   │   ├── marina-shatskih-B3Jnug3rKrs-unsplash.jpg
│   │   ├── marina-shatskih-BYZdCQDSNTY-unsplash.jpg
│   │   ├── marina-shatskih-F_imBoYwuTk-unsplash.jpg
│   │   ├── marina-shatskih-kBo2MFJz2QU-unsplash.jpg
│   │   ├── marina-shatskih-nKoIXtEfkZ8-unsplash.jpg
│   │   └── marina-shatskih-sIzj2poobss-unsplash.jpg
│   └── dog
│       ├── 00.jpg
│       ├── 01.jpg
│       ├── 02.jpg
│       ├── 03.jpg
│       └── 04.jpg
├── requirements.txt
├── scripts
│   ├── train_multi.sh
│   └── train_single.sh
├── train_class_diffusion.py
├── utils
│   ├── blip2t.py
│   ├── load_attn_weight.py
│   ├── semantic_preservation_loss.py
│   └── similarity.py
└── videogen.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ckpt/
2 | log/
3 | wandb/
4 | __pycache__/
5 | *.gif
6 | result/
7 |
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [ICLR 2025] ClassDiffusion: More Aligned Personalization Tuning with Explicit Class Guidance
2 |
3 | Official implementation of ClassDiffusion: More Aligned Personalization Tuning with Explicit Class Guidance
4 |
5 | > Recent text-to-image customization works have proven successful in generating images of given concepts by fine-tuning the diffusion models on a few examples. However, these methods tend to overfit the concepts, resulting in failure to create the concept under multiple conditions (_e.g._, the headphone is missing when generating “a <sks> dog wearing a headphone”). Interestingly, we notice that the base model before fine-tuning exhibits the capability to compose the base concept with other elements (_e.g._, “a dog wearing a headphone”), implying that the compositional ability only disappears after personalization tuning. Inspired by this observation, we present ClassDiffusion, a simple technique that leverages a semantic preservation loss to explicitly regulate the concept space when learning the new concept. Despite its simplicity, this helps avoid semantic drift when fine-tuning on the target concepts. Extensive qualitative and quantitative experiments demonstrate that the use of the semantic preservation loss effectively improves the compositional abilities of the fine-tuned models. In response to the ineffective evaluation of the CLIP-T metric, we introduce the BLIP2-T metric, a more equitable and effective evaluation metric for this particular domain. We also provide an in-depth empirical study and theoretical analysis to better understand the role of the proposed loss. Lastly, we extend ClassDiffusion to personalized video generation, demonstrating its flexibility.
6 |
7 | ---
8 |
9 | **[ClassDiffusion: More Aligned Personalization Tuning with Explicit Class Guidance](https://arxiv.org/pdf/2405.17532)**
10 |
11 | [Jiannan Huang](https://rbrq03.github.io/), [Jun Hao Liew](https://scholar.google.com.sg/citations?user=8gm-CYYAAAAJ), [Hanshu Yan](https://hanshuyan.github.io), [Yuyang Yin](https://yuyangyin.github.io), [Yao Zhao](http://mepro.bjtu.edu.cn/zhaoyao/index.htm), [Humphrey Shi](https://www.humphreyshi.com/), [Yunchao Wei](https://weiyc.github.io/index.html)
12 |
13 | [Project Page](https://classdiffusion.github.io/)
14 | [Paper](https://arxiv.org/pdf/2405.17532)
15 |
16 |
17 |
18 |
19 | Our method can generate more aligned personalized images with explicit class guidance
20 |
21 |
22 | ## News
23 |
24 | - [23 Jan, 2025] 🎉 Our paper is accepted to ICLR 2025!
25 | - [8 Jun. 2024] Code for BLIP2-T and Video Generation Released!
26 | - [3 Jun. 2024] Code Released!
27 | - [29 May. 2024] Paper Released!
28 |
29 | ## Code Usage
30 |
31 | **Set Up**
32 |
33 | ```
34 | git clone https://github.com/Rbrq03/ClassDiffusion.git
35 | cd ClassDiffusion
36 | pip install -r requirements.txt
37 | ```
38 |
39 | _Warning: Currently, ClassDiffusion doesn't support PEFT. Please ensure PEFT is uninstalled in your environment, or check this [PR](https://github.com/huggingface/diffusers/pull/7272); we will migrate to it once the PR is merged._
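
If PEFT is installed in your environment, you can remove it with:

```
pip uninstall -y peft
```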
40 |
41 | ### Training
42 |
43 | **Single Concept**
44 |
45 | ```
46 | bash scripts/train_single.sh
47 | ```
48 |
49 | **Multiple Concepts**
50 |
51 | ```
52 | bash scripts/train_multi.sh
53 | ```
54 |
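Both training scripts add the semantic preservation loss (enabled with `--use_spl`, weighted by `--spl_weight`, and driven by `--cls_token`) on top of the usual Custom-Diffusion-style objective. As a rough illustration of the idea only (the actual implementation lives in `utils/semantic_preservation_loss.py` and differs in details), the loss keeps the text features of the personalized prompt close to those of the plain class prompt:

```python
import torch.nn.functional as F


def semantic_preservation_loss(text_encoder, tokenizer, modifier_token, cls_token, device="cuda"):
    # Illustrative sketch: encode the personalized prompt (e.g. "a photo of a <new1> dog")
    # and the plain class prompt ("a photo of a dog") with the text encoder.
    personalized_ids = tokenizer(
        f"a photo of a {modifier_token} {cls_token}", padding="longest", return_tensors="pt"
    ).input_ids.to(device)
    class_ids = tokenizer(
        f"a photo of a {cls_token}", padding="longest", return_tensors="pt"
    ).input_ids.to(device)

    personalized_feat = text_encoder(personalized_ids).pooler_output
    class_feat = text_encoder(class_ids).pooler_output.detach()

    # Penalize semantic drift: keep the personalized concept close to its super-class.
    return 1.0 - F.cosine_similarity(personalized_feat, class_feat, dim=-1).mean()
```
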
55 | ### Inference
56 |
57 | **Single Concept**
58 |
59 | ```
60 | import torch
61 | from diffusers import DiffusionPipeline
62 |
63 | pipeline = DiffusionPipeline.from_pretrained(
64 | "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16,
65 | ).to("cuda")
66 | pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
67 | pipeline.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
68 |
69 | image = pipeline(
70 |     "<new1> dog swimming in the pool",
71 | num_inference_steps=100,
72 | guidance_scale=6.0,
73 | eta=1.0,
74 | ).images[0]
75 | image.save("dog.png")
76 | ```
77 |
78 | **Multiple Concepts**
79 |
80 | ```
81 | import torch
82 | from diffusers import DiffusionPipeline
83 |
84 | pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
85 | pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
86 | pipeline.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
87 | pipeline.load_textual_inversion("path-to-save-model", weight_name="<new2>.bin")
88 |
89 | image = pipeline(
90 |     "a <new1> teddy bear sitting in front of a <new2> barn",
91 | num_inference_steps=100,
92 | guidance_scale=6.0,
93 | eta=1.0,
94 | ).images[0]
95 | image.save("multi-subject.png")
96 | ```
97 |
98 | **BLIP2-T**
99 |
100 | You can use the following code:
101 |
102 | ```
103 | from PIL import Image
104 | from utils.blip2t import BLIP2T
105 |
106 | blip2t = BLIP2T("Salesforce/blip-itm-large-coco", "cuda")
107 |
108 | prompt = "photo of a dog"
109 | image = Image.open("data/dog/00.jpg")
110 |
111 | score = blip2t.text_similarity(prompt, image)[0]
112 | print(score)
113 | ```
114 |
115 | or
116 |
117 | ```
118 | python blip2t.py
119 | ```
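
For reference, `utils/blip2t.py` wraps a BLIP image-text matching model from `transformers`. A minimal stand-in with the same interface might look roughly like this (hypothetical sketch, not the repo's actual implementation):

```python
import torch
from PIL import Image
from transformers import BlipForImageTextRetrieval, BlipProcessor


class BLIP2T:
    """Hypothetical minimal stand-in for utils/blip2t.BLIP2T (illustration only)."""

    def __init__(self, model_name: str, device: str = "cpu"):
        self.device = device
        self.processor = BlipProcessor.from_pretrained(model_name)
        self.model = BlipForImageTextRetrieval.from_pretrained(model_name).to(device).eval()

    @torch.no_grad()
    def text_similarity(self, prompt: str, image: Image.Image):
        inputs = self.processor(images=image, text=prompt, return_tensors="pt").to(self.device)
        # use_itm_head=False returns the image-text cosine similarity from the ITC head.
        score = self.model(**inputs, use_itm_head=False)[0]
        return score.item(), score
```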
120 |
121 | **Video Generation**
122 |
123 | ```
124 | python videogen.py
125 | ```
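
`videogen.py` plugs the learned concept into a video pipeline (AnimateDiff). For orientation only, a personalized video pipeline can be wired up in `diffusers` roughly as below; this is an illustrative sketch with placeholder paths and prompts, not the exact contents of `videogen.py`:

```python
import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler", beta_schedule="linear"
)

# Load the ClassDiffusion weights trained above (paths are placeholders).
# Depending on your diffusers version, loading the Custom Diffusion attention weights
# into the motion UNet may need the helper in utils/load_attn_weight.py instead.
pipe.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")

frames = pipe(
    "a <new1> dog running on the beach",
    num_frames=16,
    num_inference_steps=25,
    guidance_scale=7.5,
).frames[0]
export_to_gif(frames, "classdiffusion_video.gif")
```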
126 |
127 | **Ckpt for quick test**
128 |
129 | |Concept(s)|Weight|
130 | |--|--|
131 | |dog|[weight](https://drive.google.com/drive/folders/12KhBmFCUb2opotOQeAH0-dvW9XArwukt?usp=drive_link)|
132 | |bear+barn|[weight](https://drive.google.com/drive/folders/1VQTvszl2FqKhc-9YaKsRTRV3qlhmOqK4?usp=drive_link)|
133 |
134 | ## Results
135 |
136 | **Single Concept Results**
137 |
138 |
139 | ![Single concept results](assert/img/singleconcept.png)
140 |
141 |
142 |
143 | **Multiple Concepts Results**
144 |
145 |
146 | ![Multiple concepts results](assert/img/multiple.png)
147 |
148 |
149 |
150 |
151 | **Video Generation Results**
152 |
153 |
154 | ![Video generation results](assert/img/video.jpg)
155 |
156 |
157 |
158 |
159 | ## TODO
160 |
161 | - [x] Training Code for ClassDiffusion
162 | - [x] Inference Code for ClassDiffusion
163 | - [x] Pipeline for BLIP2-T Score
164 | - [x] Inference Code for Video Generation with ClassDiffusion
165 |
166 | ## Citation
167 |
168 | If you make use of our work, please cite our paper.
169 |
170 | ```bibtex
171 | @article{huang2024classdiffusion,
172 | title={ClassDiffusion: More Aligned Personalization Tuning with Explicit Class Guidance},
173 | author={Huang, Jiannan and Liew, Jun Hao and Yan, Hanshu and Yin, Yuyang and Zhao, Yao and Wei, Yunchao},
174 | journal={arXiv preprint arXiv:2405.17532},
175 | year={2024}
176 | }
177 | ```
178 |
179 | ## Acknowledgement
180 |
181 | We thank the following repos for their excellent and well-documented codebases:
182 |
183 | - Diffusers: [https://github.com/huggingface/diffusers](https://github.com/huggingface/diffusers)
184 | - Custom Diffusion: [https://github.com/adobe-research/custom-diffusion](https://github.com/adobe-research/custom-diffusion)
185 | - Transformers: [https://github.com/huggingface/transformers](https://github.com/huggingface/transformers)
186 | - DreamBooth: [https://github.com/google/dreambooth](https://github.com/google/dreambooth)
187 | - AnimateDiff: [https://github.com/guoyww/AnimateDiff](https://github.com/guoyww/AnimateDiff)
188 |
--------------------------------------------------------------------------------
/assert/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/.DS_Store
--------------------------------------------------------------------------------
/assert/img/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/.DS_Store
--------------------------------------------------------------------------------
/assert/img/multiple.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/multiple.png
--------------------------------------------------------------------------------
/assert/img/singleconcept.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/singleconcept.png
--------------------------------------------------------------------------------
/assert/img/story.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/story.jpg
--------------------------------------------------------------------------------
/assert/img/video.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/video.jpg
--------------------------------------------------------------------------------
/blip2t.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | from utils.blip2t import BLIP2T
3 |
4 | blip2t = BLIP2T("Salesforce/blip-itm-large-coco", "cpu")
5 |
6 | prompt = "photo of a dog"
7 | image = Image.open("data/dog/00.jpg")
8 |
9 | score = blip2t.text_similarity(prompt, image)[0]
10 | print(score)
11 |
--------------------------------------------------------------------------------
/config/concept.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "instance_prompt": "a photo of a teddy bear",
4 | "class_prompt": "bear",
5 |     "instance_data_dir": "data/bear"
6 | },
7 | {
8 | "instance_prompt": "a photo of a barn",
9 | "class_prompt": "barn",
10 | "instance_data_dir": "data/barn"
11 | }
12 | ]
13 |
--------------------------------------------------------------------------------
/data/barn/cora-leach-UlKxoeGP4Rc-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/cora-leach-UlKxoeGP4Rc-unsplash.jpg
--------------------------------------------------------------------------------
/data/barn/emmy-sobieski-pPoR4vyOomY-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/emmy-sobieski-pPoR4vyOomY-unsplash.jpg
--------------------------------------------------------------------------------
/data/barn/isaac-martin-ZCAiJPiO_98-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/isaac-martin-ZCAiJPiO_98-unsplash.jpg
--------------------------------------------------------------------------------
/data/barn/jonah-brown-eYAsBnNISa4-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/jonah-brown-eYAsBnNISa4-unsplash.jpg
--------------------------------------------------------------------------------
/data/barn/melissa-cronin-P8Bfpi1HLIE-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/melissa-cronin-P8Bfpi1HLIE-unsplash.jpg
--------------------------------------------------------------------------------
/data/barn/pedro-lastra-hXunh-ivkPc-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/pedro-lastra-hXunh-ivkPc-unsplash.jpg
--------------------------------------------------------------------------------
/data/barn/sarat-karumuri-vs1voV4jhpU-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/sarat-karumuri-vs1voV4jhpU-unsplash.jpg
--------------------------------------------------------------------------------
/data/bear/marina-shatskih-6MDi8o6VYHg-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-6MDi8o6VYHg-unsplash.jpg
--------------------------------------------------------------------------------
/data/bear/marina-shatskih-B3Jnug3rKrs-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-B3Jnug3rKrs-unsplash.jpg
--------------------------------------------------------------------------------
/data/bear/marina-shatskih-BYZdCQDSNTY-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-BYZdCQDSNTY-unsplash.jpg
--------------------------------------------------------------------------------
/data/bear/marina-shatskih-F_imBoYwuTk-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-F_imBoYwuTk-unsplash.jpg
--------------------------------------------------------------------------------
/data/bear/marina-shatskih-kBo2MFJz2QU-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-kBo2MFJz2QU-unsplash.jpg
--------------------------------------------------------------------------------
/data/bear/marina-shatskih-nKoIXtEfkZ8-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-nKoIXtEfkZ8-unsplash.jpg
--------------------------------------------------------------------------------
/data/bear/marina-shatskih-sIzj2poobss-unsplash.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-sIzj2poobss-unsplash.jpg
--------------------------------------------------------------------------------
/data/dog/00.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/00.jpg
--------------------------------------------------------------------------------
/data/dog/01.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/01.jpg
--------------------------------------------------------------------------------
/data/dog/02.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/02.jpg
--------------------------------------------------------------------------------
/data/dog/03.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/03.jpg
--------------------------------------------------------------------------------
/data/dog/04.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/04.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate==0.24.1
2 | torchvision
3 | transformers>=4.25.1
4 | ftfy
5 | tensorboard
6 | Jinja2
7 | git+https://github.com/huggingface/diffusers
--------------------------------------------------------------------------------
/scripts/train_multi.sh:
--------------------------------------------------------------------------------
1 | export CLS_TOKEN1="teddybear"
2 | export CLS_TOKEN2="barn"
3 | export MODEL_NAME="runwayml/stable-diffusion-v1-5"
4 | export OUTPUT_DIR="ckpt/${CLS_TOKEN1}_${CLS_TOKEN2}"
5 |
6 | accelerate launch train_class_diffusion.py \
7 | --pretrained_model_name_or_path=$MODEL_NAME \
8 | --output_dir=$OUTPUT_DIR \
9 | --concepts_list=config/concept.json \
10 | --resolution=512 \
11 | --train_batch_size=2 \
12 | --learning_rate=1e-5 \
13 | --lr_warmup_steps=0 \
14 | --max_train_steps=800 \
15 | --scale_lr \
16 | --hflip \
17 |   --modifier_token "<new1>+<new2>" \
18 | --use_spl \
19 | --spl_weight=1.0 \
20 | --cls_token "${CLS_TOKEN1}+${CLS_TOKEN2}" \
21 | --no_safe_serialization \
22 | --report_to "wandb" \
23 | --validation_steps=25 \
24 | --validation_prompt="a ${CLS_TOKEN1} sitting in front of a ${CLS_TOKEN2}" \
--------------------------------------------------------------------------------
/scripts/train_single.sh:
--------------------------------------------------------------------------------
1 | export CLS_TOKEN="dog"
2 | export MODEL_NAME="runwayml/stable-diffusion-v1-5"
3 | export OUTPUT_DIR="./ckpt/${CLS_TOKEN}_cls"
4 | export INSTANCE_DIR="./data/${CLS_TOKEN}"
5 |
6 | accelerate launch train_class_diffusion.py \
7 | --pretrained_model_name_or_path=$MODEL_NAME \
8 | --instance_data_dir=$INSTANCE_DIR \
9 | --output_dir=$OUTPUT_DIR \
10 | --instance_prompt="a photo of a ${CLS_TOKEN}" \
11 | --resolution=512 \
12 | --train_batch_size=2 \
13 | --learning_rate=1e-5 \
14 | --lr_warmup_steps=0 \
15 | --max_train_steps=500 \
16 | --scale_lr \
17 | --hflip \
18 |   --modifier_token "<new1>" \
19 | --no_safe_serialization \
20 | --use_spl \
21 | --spl_weight=1 \
22 | --cls_token "${CLS_TOKEN}" \
23 | --report_to "wandb" \
24 | --validation_steps=50 \
25 | --validation_prompt="a ${CLS_TOKEN} is swimming" \
26 | --tracker_name "custom-diffusion-multi" \
27 |
--------------------------------------------------------------------------------
/train_class_diffusion.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2024 Class Diffusion authors from BJTU and the HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import argparse
17 | import itertools
18 | import json
19 | import logging
20 | import math
21 | import os
22 | import time
23 | import random
24 | import shutil
25 | import warnings
26 | from pathlib import Path
27 | import types
28 | import numpy as np
29 | import safetensors
30 | import torch
31 | import torch.nn.functional as F
32 | import torch.utils.checkpoint
33 | import transformers
34 | from accelerate import Accelerator
35 | from accelerate.logging import get_logger
36 | from accelerate.utils import ProjectConfiguration, set_seed
37 | from huggingface_hub import HfApi, create_repo
38 | from huggingface_hub.utils import insecure_hashlib
39 | from packaging import version
40 | from PIL import Image
41 | from torch.utils.data import Dataset
42 | from torchvision import transforms
43 | from tqdm.auto import tqdm
44 | from transformers import AutoTokenizer, PretrainedConfig
45 | import torch.nn.functional as F
46 |
47 | import subprocess
48 | import diffusers
49 | from diffusers import (
50 | AutoencoderKL,
51 | DDPMScheduler,
52 | DiffusionPipeline,
53 | DPMSolverMultistepScheduler,
54 | UNet2DConditionModel,
55 | )
56 | from diffusers.loaders import AttnProcsLayers
57 | from diffusers.models.attention_processor import (
58 | CustomDiffusionAttnProcessor,
59 | CustomDiffusionAttnProcessor2_0,
60 | CustomDiffusionXFormersAttnProcessor,
61 | )
62 | from diffusers.optimization import get_scheduler
63 | from diffusers.utils import check_min_version, is_wandb_available
64 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
65 | from diffusers.utils.import_utils import is_xformers_available
66 | from utils.semantic_preservation_loss import forward_with_custom_embeddings, SPL
67 |
68 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
69 | check_min_version("0.27.0.dev0")
70 |
71 | logger = get_logger(__name__)
72 |
73 |
74 | def freeze_params(params):
75 | for param in params:
76 | param.requires_grad = False
77 |
78 |
79 | def save_model_card(
80 | repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None
81 | ):
82 | img_str = ""
83 | for i, image in enumerate(images):
84 | image.save(os.path.join(repo_folder, f"image_{i}.png"))
85 |         img_str += f"![img_{i}](./image_{i}.png)\n"
86 |
87 | model_description = f"""
88 | # Class Diffusion - {repo_id}
89 |
90 | These are Class Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Class Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n
91 | {img_str}
92 |
93 | \nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion).
94 | """
95 | model_card = load_or_create_model_card(
96 | repo_id_or_path=repo_id,
97 | from_training=True,
98 | license="creativeml-openrail-m",
99 | base_model=base_model,
100 | prompt=prompt,
101 | model_description=model_description,
102 | inference=True,
103 | )
104 |
105 | tags = [
106 | "text-to-image",
107 | "diffusers",
108 | "stable-diffusion",
109 | "stable-diffusion-diffusers",
110 | "custom-diffusion",
111 | "diffusers-training",
112 | ]
113 | model_card = populate_model_card(model_card, tags=tags)
114 |
115 | model_card.save(os.path.join(repo_folder, "README.md"))
116 |
117 |
118 | def import_model_class_from_model_name_or_path(
119 | pretrained_model_name_or_path: str, revision: str
120 | ):
121 | text_encoder_config = PretrainedConfig.from_pretrained(
122 | pretrained_model_name_or_path,
123 | subfolder="text_encoder",
124 | revision=revision,
125 | )
126 | model_class = text_encoder_config.architectures[0]
127 |
128 | if model_class == "CLIPTextModel":
129 | from transformers import CLIPTextModel
130 |
131 | return CLIPTextModel
132 | elif model_class == "RobertaSeriesModelWithTransformation":
133 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import (
134 | RobertaSeriesModelWithTransformation,
135 | )
136 |
137 | return RobertaSeriesModelWithTransformation
138 | else:
139 | raise ValueError(f"{model_class} is not supported.")
140 |
141 |
142 | def collate_fn(examples, with_prior_preservation):
143 | input_ids = [example["instance_prompt_ids"] for example in examples]
144 | pixel_values = [example["instance_images"] for example in examples]
145 | mask = [example["mask"] for example in examples]
146 | # Concat class and instance examples for prior preservation.
147 | # We do this to avoid doing two forward passes.
148 | if with_prior_preservation:
149 | input_ids += [example["class_prompt_ids"] for example in examples]
150 | pixel_values += [example["class_images"] for example in examples]
151 | mask += [example["class_mask"] for example in examples]
152 |
153 | input_ids = torch.cat(input_ids, dim=0)
154 | pixel_values = torch.stack(pixel_values)
155 | mask = torch.stack(mask)
156 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
157 | mask = mask.to(memory_format=torch.contiguous_format).float()
158 |
159 | batch = {
160 | "input_ids": input_ids,
161 | "pixel_values": pixel_values,
162 | "mask": mask.unsqueeze(1),
163 | }
164 | return batch
165 |
166 |
167 | class PromptDataset(Dataset):
168 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
169 |
170 | def __init__(self, prompt, num_samples):
171 | self.prompt = prompt
172 | self.num_samples = num_samples
173 |
174 | def __len__(self):
175 | return self.num_samples
176 |
177 | def __getitem__(self, index):
178 | example = {}
179 | example["prompt"] = self.prompt
180 | example["index"] = index
181 | return example
182 |
183 |
184 | class ClassDiffusionDataset(Dataset):
185 | """
186 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
187 |     It pre-processes the images and tokenizes the prompts.
188 | """
189 |
190 | def __init__(
191 | self,
192 | concepts_list,
193 | tokenizer,
194 | size=512,
195 | mask_size=64,
196 | center_crop=False,
197 | with_prior_preservation=False,
198 | num_class_images=200,
199 | hflip=False,
200 | aug=True,
201 | ):
202 | self.size = size
203 | self.mask_size = mask_size
204 | self.center_crop = center_crop
205 | self.tokenizer = tokenizer
206 | self.interpolation = Image.BILINEAR
207 | self.aug = aug
208 |
209 | self.instance_images_path = []
210 | self.class_images_path = []
211 | self.with_prior_preservation = with_prior_preservation
212 | for concept in concepts_list:
213 | inst_img_path = [
214 | (x, concept["instance_prompt"])
215 | for x in Path(concept["instance_data_dir"]).iterdir()
216 | if x.is_file()
217 | ]
218 | self.instance_images_path.extend(inst_img_path)
219 |
220 | if with_prior_preservation:
221 | class_data_root = Path(concept["class_data_dir"])
222 | if os.path.isdir(class_data_root):
223 | class_images_path = list(class_data_root.iterdir())
224 | class_prompt = [
225 | concept["class_prompt"] for _ in range(len(class_images_path))
226 | ]
227 | else:
228 | with open(class_data_root, "r") as f:
229 | class_images_path = f.read().splitlines()
230 | with open(concept["class_prompt"], "r") as f:
231 | class_prompt = f.read().splitlines()
232 |
233 | class_img_path = list(zip(class_images_path, class_prompt))
234 | self.class_images_path.extend(class_img_path[:num_class_images])
235 |
236 | random.shuffle(self.instance_images_path)
237 | self.num_instance_images = len(self.instance_images_path)
238 | self.num_class_images = len(self.class_images_path)
239 | self._length = max(self.num_class_images, self.num_instance_images)
240 | self.flip = transforms.RandomHorizontalFlip(0.5 * hflip)
241 |
242 | self.image_transforms = transforms.Compose(
243 | [
244 | self.flip,
245 | transforms.Resize(
246 | size, interpolation=transforms.InterpolationMode.BILINEAR
247 | ),
248 | (
249 | transforms.CenterCrop(size)
250 | if center_crop
251 | else transforms.RandomCrop(size)
252 | ),
253 | transforms.ToTensor(),
254 | transforms.Normalize([0.5], [0.5]),
255 | ]
256 | )
257 |
258 | def __len__(self):
259 | return self._length
260 |
261 | def preprocess(self, image, scale, resample):
262 | outer, inner = self.size, scale
263 | factor = self.size // self.mask_size
264 | if scale > self.size:
265 | outer, inner = scale, self.size
266 | top, left = np.random.randint(0, outer - inner + 1), np.random.randint(
267 | 0, outer - inner + 1
268 | )
269 | image = image.resize((scale, scale), resample=resample)
270 | image = np.array(image).astype(np.uint8)
271 | image = (image / 127.5 - 1.0).astype(np.float32)
272 | instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32)
273 | mask = np.zeros((self.size // factor, self.size // factor))
274 | if scale > self.size:
275 | instance_image = image[top : top + inner, left : left + inner, :]
276 | mask = np.ones((self.size // factor, self.size // factor))
277 | else:
278 | instance_image[top : top + inner, left : left + inner, :] = image
279 | mask[
280 | top // factor + 1 : (top + scale) // factor - 1,
281 | left // factor + 1 : (left + scale) // factor - 1,
282 | ] = 1.0
283 | return instance_image, mask
284 |
285 | def __getitem__(self, index):
286 | example = {}
287 | instance_image, instance_prompt = self.instance_images_path[
288 | index % self.num_instance_images
289 | ]
290 | instance_image = Image.open(instance_image)
291 | if not instance_image.mode == "RGB":
292 | instance_image = instance_image.convert("RGB")
293 | instance_image = self.flip(instance_image)
294 |
295 | # apply resize augmentation and create a valid image region mask
296 | random_scale = self.size
297 | if self.aug:
298 | random_scale = (
299 | np.random.randint(self.size // 3, self.size + 1)
300 | if np.random.uniform() < 0.66
301 | else np.random.randint(int(1.2 * self.size), int(1.4 * self.size))
302 | )
303 | instance_image, mask = self.preprocess(
304 | instance_image, random_scale, self.interpolation
305 | )
306 |
307 | if random_scale < 0.6 * self.size:
308 | instance_prompt = (
309 | np.random.choice(["a far away ", "very small "]) + instance_prompt
310 | )
311 | elif random_scale > self.size:
312 | instance_prompt = (
313 | np.random.choice(["zoomed in ", "close up "]) + instance_prompt
314 | )
315 |
316 | example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1)
317 | example["mask"] = torch.from_numpy(mask)
318 | example["instance_prompt_ids"] = self.tokenizer(
319 | instance_prompt,
320 | truncation=True,
321 | padding="max_length",
322 | max_length=self.tokenizer.model_max_length,
323 | return_tensors="pt",
324 | ).input_ids
325 |
326 | if self.with_prior_preservation:
327 | class_image, class_prompt = self.class_images_path[
328 | index % self.num_class_images
329 | ]
330 | class_image = Image.open(class_image)
331 | if not class_image.mode == "RGB":
332 | class_image = class_image.convert("RGB")
333 | example["class_images"] = self.image_transforms(class_image)
334 | example["class_mask"] = torch.ones_like(example["mask"])
335 | example["class_prompt_ids"] = self.tokenizer(
336 | class_prompt,
337 | truncation=True,
338 | padding="max_length",
339 | max_length=self.tokenizer.model_max_length,
340 | return_tensors="pt",
341 | ).input_ids
342 |
343 | return example
344 |
345 |
346 | def save_new_embed(
347 | text_encoder,
348 | modifier_token_id,
349 | accelerator,
350 | args,
351 | output_dir,
352 | safe_serialization=True,
353 | ):
354 | """Saves the new token embeddings from the text encoder."""
355 | logger.info("Saving embeddings")
356 | learned_embeds = (
357 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight
358 | )
359 | for x, y in zip(modifier_token_id, args.modifier_token):
360 | learned_embeds_dict = {}
361 | learned_embeds_dict[y] = learned_embeds[x]
362 | filename = f"{output_dir}/{y}.bin"
363 |
364 | if safe_serialization:
365 | safetensors.torch.save_file(
366 | learned_embeds_dict, filename, metadata={"format": "pt"}
367 | )
368 | else:
369 | torch.save(learned_embeds_dict, filename)
370 |
371 |
372 | def parse_args(input_args=None):
373 | parser = argparse.ArgumentParser(description="Class Diffusion training script.")
374 | parser.add_argument(
375 | "--pretrained_model_name_or_path",
376 | type=str,
377 | default=None,
378 | required=True,
379 | help="Path to pretrained model or model identifier from huggingface.co/models.",
380 | )
381 | parser.add_argument(
382 | "--revision",
383 | type=str,
384 | default=None,
385 | required=False,
386 | help="Revision of pretrained model identifier from huggingface.co/models.",
387 | )
388 | parser.add_argument(
389 | "--variant",
390 | type=str,
391 | default=None,
392 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
393 | )
394 | parser.add_argument(
395 | "--tokenizer_name",
396 | type=str,
397 | default=None,
398 | help="Pretrained tokenizer name or path if not the same as model_name",
399 | )
400 | parser.add_argument(
401 | "--instance_data_dir",
402 | type=str,
403 | default=None,
404 | help="A folder containing the training data of instance images.",
405 | )
406 | parser.add_argument(
407 | "--class_data_dir",
408 | type=str,
409 | default=None,
410 | help="A folder containing the training data of class images.",
411 | )
412 | parser.add_argument(
413 | "--instance_prompt",
414 | type=str,
415 | default=None,
416 | help="The prompt with identifier specifying the instance",
417 | )
418 | parser.add_argument(
419 | "--class_prompt",
420 | type=str,
421 | default=None,
422 | help="The prompt to specify images in the same class as provided instance images.",
423 | )
424 | parser.add_argument(
425 | "--validation_prompt",
426 | type=str,
427 | default=None,
428 | help="A prompt that is used during validation to verify that the model is learning.",
429 | )
430 | parser.add_argument(
431 | "--num_validation_images",
432 | type=int,
433 | default=2,
434 | help="Number of images that should be generated during validation with `validation_prompt`.",
435 | )
436 | parser.add_argument(
437 | "--validation_steps",
438 | type=int,
439 | default=50,
440 | help=(
441 |             "Run validation every X steps. Validation consists of running the prompt"
442 | " `args.validation_prompt` multiple times: `args.num_validation_images`."
443 | ),
444 | )
445 | parser.add_argument(
446 | "--with_prior_preservation",
447 | default=False,
448 | action="store_true",
449 | help="Flag to add prior preservation loss.",
450 | )
451 | parser.add_argument(
452 | "--real_prior",
453 | default=False,
454 | action="store_true",
455 | help="real images as prior.",
456 | )
457 | parser.add_argument(
458 | "--prior_loss_weight",
459 | type=float,
460 | default=1.0,
461 | help="The weight of prior preservation loss.",
462 | )
463 | parser.add_argument(
464 | "--num_class_images",
465 | type=int,
466 | default=200,
467 | help=(
468 | "Minimal class images for prior preservation loss. If there are not enough images already present in"
469 | " class_data_dir, additional images will be sampled with class_prompt."
470 | ),
471 | )
472 | parser.add_argument(
473 | "--output_dir",
474 | type=str,
475 | default="custom-diffusion-model",
476 | help="The output directory where the model predictions and checkpoints will be written.",
477 | )
478 | parser.add_argument(
479 | "--seed", type=int, default=42, help="A seed for reproducible training."
480 | )
481 | parser.add_argument(
482 | "--resolution",
483 | type=int,
484 | default=512,
485 | help=(
486 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
487 | " resolution"
488 | ),
489 | )
490 | parser.add_argument(
491 | "--center_crop",
492 | default=False,
493 | action="store_true",
494 | help=(
495 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
496 | " cropped. The images will be resized to the resolution first before cropping."
497 | ),
498 | )
499 | parser.add_argument(
500 | "--train_batch_size",
501 | type=int,
502 | default=4,
503 | help="Batch size (per device) for the training dataloader.",
504 | )
505 | parser.add_argument(
506 | "--sample_batch_size",
507 | type=int,
508 | default=4,
509 | help="Batch size (per device) for sampling images.",
510 | )
511 | parser.add_argument("--num_train_epochs", type=int, default=1)
512 | parser.add_argument(
513 | "--max_train_steps",
514 | type=int,
515 | default=None,
516 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
517 | )
518 | parser.add_argument(
519 | "--checkpointing_steps",
520 | type=int,
521 | default=250,
522 | help=(
523 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
524 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
525 | " training using `--resume_from_checkpoint`."
526 | ),
527 | )
528 | parser.add_argument(
529 | "--checkpoints_total_limit",
530 | type=int,
531 | default=None,
532 | help=("Max number of checkpoints to store."),
533 | )
534 | parser.add_argument(
535 | "--resume_from_checkpoint",
536 | type=str,
537 | default=None,
538 | help=(
539 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
540 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
541 | ),
542 | )
543 | parser.add_argument(
544 | "--gradient_accumulation_steps",
545 | type=int,
546 | default=1,
547 | help="Number of updates steps to accumulate before performing a backward/update pass.",
548 | )
549 | parser.add_argument(
550 | "--gradient_checkpointing",
551 | action="store_true",
552 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
553 | )
554 | parser.add_argument(
555 | "--learning_rate",
556 | type=float,
557 | default=1e-5,
558 | help="Initial learning rate (after the potential warmup period) to use.",
559 | )
560 | parser.add_argument(
561 | "--scale_lr",
562 | action="store_true",
563 | default=False,
564 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
565 | )
566 | parser.add_argument(
567 | "--dataloader_num_workers",
568 | type=int,
569 | default=2,
570 | help=(
571 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
572 | ),
573 | )
574 | parser.add_argument(
575 | "--freeze_model",
576 | type=str,
577 | default="crossattn_kv",
578 | choices=["crossattn_kv", "crossattn"],
579 | help="crossattn to enable fine-tuning of all params in the cross attention",
580 | )
581 | parser.add_argument(
582 | "--lr_scheduler",
583 | type=str,
584 | default="constant",
585 | help=(
586 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
587 | ' "constant", "constant_with_warmup"]'
588 | ),
589 | )
590 | parser.add_argument(
591 | "--lr_warmup_steps",
592 | type=int,
593 | default=500,
594 | help="Number of steps for the warmup in the lr scheduler.",
595 | )
596 | parser.add_argument(
597 | "--use_8bit_adam",
598 | action="store_true",
599 | help="Whether or not to use 8-bit Adam from bitsandbytes.",
600 | )
601 | parser.add_argument(
602 | "--adam_beta1",
603 | type=float,
604 | default=0.9,
605 | help="The beta1 parameter for the Adam optimizer.",
606 | )
607 | parser.add_argument(
608 | "--adam_beta2",
609 | type=float,
610 | default=0.999,
611 | help="The beta2 parameter for the Adam optimizer.",
612 | )
613 | parser.add_argument(
614 | "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
615 | )
616 | parser.add_argument(
617 | "--adam_epsilon",
618 | type=float,
619 | default=1e-08,
620 | help="Epsilon value for the Adam optimizer",
621 | )
622 | parser.add_argument(
623 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
624 | )
625 | parser.add_argument(
626 | "--push_to_hub",
627 | action="store_true",
628 | help="Whether or not to push the model to the Hub.",
629 | )
630 | parser.add_argument(
631 | "--hub_token",
632 | type=str,
633 | default=None,
634 | help="The token to use to push to the Model Hub.",
635 | )
636 | parser.add_argument(
637 | "--hub_model_id",
638 | type=str,
639 | default=None,
640 | help="The name of the repository to keep in sync with the local `output_dir`.",
641 | )
642 | parser.add_argument(
643 | "--logging_dir",
644 | type=str,
645 | default="logs",
646 | help=(
647 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
648 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
649 | ),
650 | )
651 | parser.add_argument(
652 | "--allow_tf32",
653 | action="store_true",
654 | help=(
655 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
656 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
657 | ),
658 | )
659 | parser.add_argument(
660 | "--report_to",
661 | type=str,
662 | default="tensorboard",
663 | help=(
664 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
665 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
666 | ),
667 | )
668 | parser.add_argument(
669 | "--mixed_precision",
670 | type=str,
671 | default=None,
672 | choices=["no", "fp16", "bf16"],
673 | help=(
674 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
675 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
676 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
677 | ),
678 | )
679 | parser.add_argument(
680 | "--prior_generation_precision",
681 | type=str,
682 | default=None,
683 | choices=["no", "fp32", "fp16", "bf16"],
684 | help=(
685 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
686 |             " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
687 | ),
688 | )
689 | parser.add_argument(
690 | "--concepts_list",
691 | type=str,
692 | default=None,
693 | help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.",
694 | )
695 | parser.add_argument(
696 | "--local_rank",
697 | type=int,
698 | default=-1,
699 | help="For distributed training: local_rank",
700 | )
701 | parser.add_argument(
702 | "--enable_xformers_memory_efficient_attention",
703 | action="store_true",
704 | help="Whether or not to use xformers.",
705 | )
706 | parser.add_argument(
707 | "--set_grads_to_none",
708 | action="store_true",
709 | help=(
710 |             "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
711 | " behaviors, so disable this argument if it causes any problems. More info:"
712 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
713 | ),
714 | )
715 | parser.add_argument(
716 | "--modifier_token",
717 | type=str,
718 | default=None,
719 | help="A token to use as a modifier for the concept.",
720 | )
721 | parser.add_argument(
722 | "--initializer_token",
723 | type=str,
724 | default="ktn+pll+ucd",
725 | help="A token to use as initializer word.",
726 | )
727 | parser.add_argument(
728 | "--hflip", action="store_true", help="Apply horizontal flip data augmentation."
729 | )
730 | parser.add_argument(
731 | "--noaug",
732 | action="store_true",
733 |         help="Don't apply data augmentation when this flag is enabled.",
734 | )
735 | parser.add_argument(
736 | "--no_safe_serialization",
737 | action="store_true",
738 |         help="If specified, save the checkpoint in the original PyTorch format instead of `safetensors`.",
739 | )
740 |     parser.add_argument("--use_spl", action="store_true", help="Whether to apply the semantic preservation loss (SPL) during training.")
741 |     parser.add_argument("--spl_weight", type=float, default=1.0, help="Weight of the semantic preservation loss.")
742 |     parser.add_argument("--cls_token", type=str, help="`+`-separated class token(s), one per modifier token, used to build the class prompts for the SPL.")
743 |     parser.add_argument("--tracker_name", type=str, default="custom-diffusion", help="Run/project name reported to the experiment tracker.")
744 |
745 | if input_args is not None:
746 | args = parser.parse_args(input_args)
747 | else:
748 | args = parser.parse_args()
749 |
750 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
751 | if env_local_rank != -1 and env_local_rank != args.local_rank:
752 | args.local_rank = env_local_rank
753 |
754 | if args.with_prior_preservation:
755 | if args.concepts_list is None:
756 | if args.class_data_dir is None:
757 | raise ValueError("You must specify a data directory for class images.")
758 | if args.class_prompt is None:
759 | raise ValueError("You must specify prompt for class images.")
760 | else:
761 | # logger is not available yet
762 | if args.class_data_dir is not None:
763 | warnings.warn(
764 | "You need not use --class_data_dir without --with_prior_preservation."
765 | )
766 | if args.class_prompt is not None:
767 | warnings.warn(
768 | "You need not use --class_prompt without --with_prior_preservation."
769 | )
770 |
771 | return args
772 |
773 |
774 | def main(args):
775 | if args.report_to == "wandb" and args.hub_token is not None:
776 | raise ValueError(
777 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
778 | " Please use `huggingface-cli login` to authenticate with the Hub."
779 | )
780 |
781 | logging_dir = Path(args.output_dir, args.logging_dir)
782 |
783 | accelerator_project_config = ProjectConfiguration(
784 | project_dir=args.output_dir, logging_dir=logging_dir
785 | )
786 |
787 | accelerator = Accelerator(
788 | gradient_accumulation_steps=args.gradient_accumulation_steps,
789 | mixed_precision=args.mixed_precision,
790 | log_with=args.report_to,
791 | project_config=accelerator_project_config,
792 | )
793 |
794 | if args.report_to == "wandb":
795 | if not is_wandb_available():
796 | raise ImportError(
797 | "Make sure to install wandb if you want to use it for logging during training."
798 | )
799 | import wandb
800 |
801 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
802 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
803 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
804 | # Make one log on every process with the configuration for debugging.
805 | logging.basicConfig(
806 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
807 | datefmt="%m/%d/%Y %H:%M:%S",
808 | level=logging.INFO,
809 | )
810 | logger.info(accelerator.state, main_process_only=False)
811 | if accelerator.is_local_main_process:
812 | transformers.utils.logging.set_verbosity_warning()
813 | diffusers.utils.logging.set_verbosity_info()
814 | else:
815 | transformers.utils.logging.set_verbosity_error()
816 | diffusers.utils.logging.set_verbosity_error()
817 |
818 | # We need to initialize the trackers we use, and also store our configuration.
819 | # The trackers initializes automatically on the main process.
820 | if accelerator.is_main_process:
821 | accelerator.init_trackers(args.tracker_name, config=vars(args))
822 |
823 | # If passed along, set the training seed now.
824 | if args.seed is not None:
825 | set_seed(args.seed)
826 | if args.concepts_list is None:
827 | args.concepts_list = [
828 | {
829 | "instance_prompt": args.instance_prompt,
830 | "class_prompt": args.class_prompt,
831 | "instance_data_dir": args.instance_data_dir,
832 | "class_data_dir": args.class_data_dir,
833 | }
834 | ]
835 | else:
836 | with open(args.concepts_list, "r") as f:
837 | args.concepts_list = json.load(f)
838 |
839 | # Generate class images if prior preservation is enabled.
840 | if args.with_prior_preservation:
841 | for i, concept in enumerate(args.concepts_list):
842 | class_images_dir = Path(concept["class_data_dir"])
843 | if not class_images_dir.exists():
844 | class_images_dir.mkdir(parents=True, exist_ok=True)
845 | if args.real_prior:
846 | assert (
847 | class_images_dir / "images"
848 | ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
849 | assert (
850 | len(list((class_images_dir / "images").iterdir()))
851 | == args.num_class_images
852 | ), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
853 | assert (
854 | class_images_dir / "caption.txt"
855 | ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
856 | assert (
857 | class_images_dir / "images.txt"
858 | ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}"
859 | concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt")
860 | concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt")
861 | args.concepts_list[i] = concept
862 | accelerator.wait_for_everyone()
863 | else:
864 | cur_class_images = len(list(class_images_dir.iterdir()))
865 |
866 | if cur_class_images < args.num_class_images:
867 | torch_dtype = (
868 | torch.float16
869 | if accelerator.device.type == "cuda"
870 | else torch.float32
871 | )
872 | if args.prior_generation_precision == "fp32":
873 | torch_dtype = torch.float32
874 | elif args.prior_generation_precision == "fp16":
875 | torch_dtype = torch.float16
876 | elif args.prior_generation_precision == "bf16":
877 | torch_dtype = torch.bfloat16
878 | pipeline = DiffusionPipeline.from_pretrained(
879 | args.pretrained_model_name_or_path,
880 | torch_dtype=torch_dtype,
881 | safety_checker=None,
882 | revision=args.revision,
883 | variant=args.variant,
884 | )
885 | pipeline.set_progress_bar_config(disable=True)
886 |
887 | num_new_images = args.num_class_images - cur_class_images
888 | logger.info(f"Number of class images to sample: {num_new_images}.")
889 |
890 | sample_dataset = PromptDataset(
891 | concept["class_prompt"], num_new_images
892 | )
893 | sample_dataloader = torch.utils.data.DataLoader(
894 | sample_dataset, batch_size=args.sample_batch_size
895 | )
896 |
897 | sample_dataloader = accelerator.prepare(sample_dataloader)
898 | pipeline.to(accelerator.device)
899 |
900 | for example in tqdm(
901 | sample_dataloader,
902 | desc="Generating class images",
903 | disable=not accelerator.is_local_main_process,
904 | ):
905 | images = pipeline(example["prompt"]).images
906 |
907 | for i, image in enumerate(images):
908 | hash_image = insecure_hashlib.sha1(
909 | image.tobytes()
910 | ).hexdigest()
911 | image_filename = (
912 | class_images_dir
913 | / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
914 | )
915 | image.save(image_filename)
916 |
917 | del pipeline
918 | if torch.cuda.is_available():
919 | torch.cuda.empty_cache()
920 |
921 | # Handle the repository creation
922 | if accelerator.is_main_process:
923 | if args.output_dir is not None:
924 | os.makedirs(args.output_dir, exist_ok=True)
925 |
926 | if args.push_to_hub:
927 | repo_id = create_repo(
928 | repo_id=args.hub_model_id or Path(args.output_dir).name,
929 | exist_ok=True,
930 | token=args.hub_token,
931 | ).repo_id
932 |
933 | # Load the tokenizer
934 | if args.tokenizer_name:
935 | tokenizer = AutoTokenizer.from_pretrained(
936 | args.tokenizer_name,
937 | revision=args.revision,
938 | use_fast=False,
939 | )
940 | elif args.pretrained_model_name_or_path:
941 | tokenizer = AutoTokenizer.from_pretrained(
942 | args.pretrained_model_name_or_path,
943 | subfolder="tokenizer",
944 | revision=args.revision,
945 | use_fast=False,
946 | )
947 |
948 | # import correct text encoder class
949 | text_encoder_cls = import_model_class_from_model_name_or_path(
950 | args.pretrained_model_name_or_path, args.revision
951 | )
952 |
953 | # Load scheduler and models
954 | noise_scheduler = DDPMScheduler.from_pretrained(
955 | args.pretrained_model_name_or_path, subfolder="scheduler"
956 | )
957 | text_encoder = text_encoder_cls.from_pretrained(
958 | args.pretrained_model_name_or_path,
959 | subfolder="text_encoder",
960 | revision=args.revision,
961 | variant=args.variant,
962 | )
963 | vae = AutoencoderKL.from_pretrained(
964 | args.pretrained_model_name_or_path,
965 | subfolder="vae",
966 | revision=args.revision,
967 | variant=args.variant,
968 | )
969 | unet = UNet2DConditionModel.from_pretrained(
970 | args.pretrained_model_name_or_path,
971 | subfolder="unet",
972 | revision=args.revision,
973 | variant=args.variant,
974 | )
975 |
976 | # Add forward_with_custom_embeddings to TextModel
977 | text_encoder.text_model.forward_with_custom_embeddings = types.MethodType(
978 | forward_with_custom_embeddings, text_encoder.text_model
979 | )
980 |
981 | # Adding a modifier token which is optimized ####
982 | # Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py
983 | modifier_token_id = []
984 | initializer_token_id = []
985 | if args.modifier_token is not None:
986 | args.modifier_token = args.modifier_token.split("+")
987 | args.initializer_token = args.initializer_token.split("+")
988 | if len(args.modifier_token) > len(args.initializer_token):
989 | raise ValueError(
990 | "You must specify + separated initializer token for each modifier token."
991 | )
992 | for modifier_token, initializer_token in zip(
993 | args.modifier_token, args.initializer_token[: len(args.modifier_token)]
994 | ):
995 | # Add the placeholder token in tokenizer
996 | num_added_tokens = tokenizer.add_tokens(modifier_token)
997 | if num_added_tokens == 0:
998 | raise ValueError(
999 | f"The tokenizer already contains the token {modifier_token}. Please pass a different"
1000 | " `modifier_token` that is not already in the tokenizer."
1001 | )
1002 |
1003 | # Convert the initializer_token, placeholder_token to ids
1004 | token_ids = tokenizer.encode([initializer_token], add_special_tokens=False)
1005 | # Check if initializer_token is a single token or a sequence of tokens
1006 | if len(token_ids) > 1:
1007 | raise ValueError("The initializer token must be a single token.")
1008 |
1009 | initializer_token_id.append(token_ids[0])
1010 | modifier_token_id.append(tokenizer.convert_tokens_to_ids(modifier_token))
1011 |
1012 | # Resize the token embeddings as we are adding new special tokens to the tokenizer
1013 | text_encoder.resize_token_embeddings(len(tokenizer))
1014 |
1015 | # Initialise the newly added placeholder token with the embeddings of the initializer token
1016 | token_embeds = text_encoder.get_input_embeddings().weight.data
1017 | for x, y in zip(modifier_token_id, initializer_token_id):
1018 | token_embeds[x] = token_embeds[y]
1019 |
1020 | print(modifier_token_id)
1021 |
1022 | if args.cls_token is not None:
1023 |
1024 | args.cls_token = args.cls_token.split("+")
1025 |
1026 | if len(args.modifier_token) != len(args.cls_token):
1027 | raise ValueError(
1028 |                 "The number of modifier tokens should be the same as the number of class tokens."
1029 | )
1030 |
1031 | modifier_cls_text_inputs = []
1032 | cls_text_inputs = []
1033 |
1034 | for modifier_token, cls_token in zip(args.modifier_token, args.cls_token):
1035 | print(modifier_token, cls_token)
1036 | modifier_cls_text_input = tokenizer(
1037 | "a photo of a " + modifier_token + " " + cls_token,
1038 | padding="longest",
1039 | return_tensors="pt",
1040 | )
1041 |
1042 | cls_text_input = tokenizer(
1043 | "a photo of a " + cls_token,
1044 | padding="longest",
1045 | return_tensors="pt",
1046 | )
1047 |
1048 | modifier_cls_text_inputs.append(modifier_cls_text_input)
1049 | cls_text_inputs.append(cls_text_input)
1050 |
1051 | use_attention_mask = (
1052 | hasattr(text_encoder.config, "use_attention_mask")
1053 | and text_encoder.config.use_attention_mask
1054 | )
1055 |
1056 | spl_loss = SPL(
1057 | text_encoder=text_encoder, use_attention_mask=use_attention_mask
1058 | )
1059 |
1060 | # Freeze all parameters except for the token embeddings in text encoder
1061 | params_to_freeze = itertools.chain(
1062 | text_encoder.text_model.encoder.parameters(),
1063 | text_encoder.text_model.final_layer_norm.parameters(),
1064 | text_encoder.text_model.embeddings.position_embedding.parameters(),
1065 | )
1066 | freeze_params(params_to_freeze)
1067 |
1068 | ########################################################
1069 | ########################################################
1070 |
1071 | vae.requires_grad_(False)
1072 | if args.modifier_token is None:
1073 | text_encoder.requires_grad_(False)
1074 | unet.requires_grad_(False)
1075 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
1076 | # as these models are only used for inference, keeping weights in full precision is not required.
1077 | weight_dtype = torch.float32
1078 | if accelerator.mixed_precision == "fp16":
1079 | weight_dtype = torch.float16
1080 | elif accelerator.mixed_precision == "bf16":
1081 | weight_dtype = torch.bfloat16
1082 |
1083 | # Move unet, vae and text_encoder to device and cast to weight_dtype
1084 | if accelerator.mixed_precision != "fp16" and args.modifier_token is not None:
1085 | text_encoder.to(accelerator.device, dtype=weight_dtype)
1086 | unet.to(accelerator.device, dtype=weight_dtype)
1087 | vae.to(accelerator.device, dtype=weight_dtype)
1088 |
1089 | attention_class = (
1090 | CustomDiffusionAttnProcessor2_0
1091 | if hasattr(F, "scaled_dot_product_attention")
1092 | else CustomDiffusionAttnProcessor
1093 | )
1094 | if args.enable_xformers_memory_efficient_attention:
1095 | if is_xformers_available():
1096 | import xformers
1097 |
1098 | xformers_version = version.parse(xformers.__version__)
1099 | if xformers_version == version.parse("0.0.16"):
1100 |                 logger.warning(
1101 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
1102 | )
1103 | attention_class = CustomDiffusionXFormersAttnProcessor
1104 | else:
1105 | raise ValueError(
1106 | "xformers is not available. Make sure it is installed correctly"
1107 | )
1108 |
1109 | # now we will add new Custom Diffusion weights to the attention layers
1110 | # It's important to realize here how many attention weights will be added and of which sizes
1111 | # The sizes of the attention layers consist only of two different variables:
1112 | # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
1113 | # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
1114 |
1115 | # Let's first see how many attention processors we will have to set.
1116 | # For Stable Diffusion, it should be equal to:
1117 | # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
1118 | # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
1119 |     # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
1120 | # => 32 layers
1121 |
1122 |     # Train only the key/value projections when freeze_model == "crossattn_kv"; otherwise train all cross-attention parameters.
1123 |     train_kv = True
1124 |     train_q_out = args.freeze_model != "crossattn_kv"
1125 | custom_diffusion_attn_procs = {}
1126 |
1127 | st = unet.state_dict()
1128 | for name, _ in unet.attn_processors.items():
1129 | cross_attention_dim = (
1130 | None
1131 | if name.endswith("attn1.processor")
1132 | else unet.config.cross_attention_dim
1133 | )
1134 | if name.startswith("mid_block"):
1135 | hidden_size = unet.config.block_out_channels[-1]
1136 | elif name.startswith("up_blocks"):
1137 | block_id = int(name[len("up_blocks.")])
1138 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
1139 | elif name.startswith("down_blocks"):
1140 | block_id = int(name[len("down_blocks.")])
1141 | hidden_size = unet.config.block_out_channels[block_id]
1142 | layer_name = name.split(".processor")[0]
1143 | weights = {
1144 | "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
1145 | "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
1146 | }
1147 | if train_q_out:
1148 | weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
1149 | weights["to_out_custom_diffusion.0.weight"] = st[
1150 | layer_name + ".to_out.0.weight"
1151 | ]
1152 | weights["to_out_custom_diffusion.0.bias"] = st[
1153 | layer_name + ".to_out.0.bias"
1154 | ]
1155 | if cross_attention_dim is not None:
1156 | custom_diffusion_attn_procs[name] = attention_class(
1157 | train_kv=train_kv,
1158 | train_q_out=train_q_out,
1159 | hidden_size=hidden_size,
1160 | cross_attention_dim=cross_attention_dim,
1161 | ).to(unet.device)
1162 | custom_diffusion_attn_procs[name].load_state_dict(weights)
1163 | else:
1164 | custom_diffusion_attn_procs[name] = attention_class(
1165 | train_kv=False,
1166 | train_q_out=False,
1167 | hidden_size=hidden_size,
1168 | cross_attention_dim=cross_attention_dim,
1169 | )
1170 | del st
1171 | unet.set_attn_processor(custom_diffusion_attn_procs)
1172 | custom_diffusion_layers = AttnProcsLayers(unet.attn_processors)
1173 |
1174 | accelerator.register_for_checkpointing(custom_diffusion_layers)
1175 |
1176 | if args.gradient_checkpointing:
1177 | unet.enable_gradient_checkpointing()
1178 | if args.modifier_token is not None:
1179 | text_encoder.gradient_checkpointing_enable()
1180 | # Enable TF32 for faster training on Ampere GPUs,
1181 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1182 | if args.allow_tf32:
1183 | torch.backends.cuda.matmul.allow_tf32 = True
1184 |
1185 | if args.scale_lr:
1186 | args.learning_rate = (
1187 | args.learning_rate
1188 | * args.gradient_accumulation_steps
1189 | * args.train_batch_size
1190 | * accelerator.num_processes
1191 | )
1192 | if args.with_prior_preservation:
1193 | args.learning_rate = args.learning_rate * 2.0
1194 |
1195 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
1196 | if args.use_8bit_adam:
1197 | try:
1198 | import bitsandbytes as bnb
1199 | except ImportError:
1200 | raise ImportError(
1201 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1202 | )
1203 |
1204 | optimizer_class = bnb.optim.AdamW8bit
1205 | else:
1206 | optimizer_class = torch.optim.AdamW
1207 |
1208 | # Optimizer creation
1209 | optimizer = optimizer_class(
1210 | (
1211 | itertools.chain(
1212 | text_encoder.get_input_embeddings().parameters(),
1213 | custom_diffusion_layers.parameters(),
1214 | )
1215 | if args.modifier_token is not None
1216 | else custom_diffusion_layers.parameters()
1217 | ),
1218 | lr=args.learning_rate,
1219 | betas=(args.adam_beta1, args.adam_beta2),
1220 | weight_decay=args.adam_weight_decay,
1221 | eps=args.adam_epsilon,
1222 | )
1223 |
1224 | # Dataset and DataLoaders creation:
1225 | train_dataset = ClassDiffusionDataset(
1226 | concepts_list=args.concepts_list,
1227 | tokenizer=tokenizer,
1228 | with_prior_preservation=args.with_prior_preservation,
1229 | size=args.resolution,
1230 | mask_size=vae.encode(
1231 | torch.randn(1, 3, args.resolution, args.resolution)
1232 | .to(dtype=weight_dtype)
1233 | .to(accelerator.device)
1234 | )
1235 | .latent_dist.sample()
1236 | .size()[-1],
1237 | center_crop=args.center_crop,
1238 | num_class_images=args.num_class_images,
1239 | hflip=args.hflip,
1240 | aug=not args.noaug,
1241 | )
1242 |
1243 | train_dataloader = torch.utils.data.DataLoader(
1244 | train_dataset,
1245 | batch_size=args.train_batch_size,
1246 | shuffle=True,
1247 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
1248 | num_workers=args.dataloader_num_workers,
1249 | )
1250 |
1251 | # Scheduler and math around the number of training steps.
1252 | overrode_max_train_steps = False
1253 | num_update_steps_per_epoch = math.ceil(
1254 | len(train_dataloader) / args.gradient_accumulation_steps
1255 | )
1256 | if args.max_train_steps is None:
1257 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1258 | overrode_max_train_steps = True
1259 |
1260 | lr_scheduler = get_scheduler(
1261 | args.lr_scheduler,
1262 | optimizer=optimizer,
1263 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1264 | num_training_steps=args.max_train_steps * accelerator.num_processes,
1265 | )
1266 |
1267 | # Prepare everything with our `accelerator`.
1268 | if args.modifier_token is not None:
1269 | (
1270 | custom_diffusion_layers,
1271 | text_encoder,
1272 | optimizer,
1273 | train_dataloader,
1274 | lr_scheduler,
1275 | ) = accelerator.prepare(
1276 | custom_diffusion_layers,
1277 | text_encoder,
1278 | optimizer,
1279 | train_dataloader,
1280 | lr_scheduler,
1281 | )
1282 | else:
1283 | custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler = (
1284 | accelerator.prepare(
1285 | custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler
1286 | )
1287 | )
1288 |
1289 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1290 | num_update_steps_per_epoch = math.ceil(
1291 | len(train_dataloader) / args.gradient_accumulation_steps
1292 | )
1293 | if overrode_max_train_steps:
1294 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1295 | # Afterwards we recalculate our number of training epochs
1296 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1297 |
1298 | # Train!
1299 | total_batch_size = (
1300 | args.train_batch_size
1301 | * accelerator.num_processes
1302 | * args.gradient_accumulation_steps
1303 | )
1304 |
1305 | logger.info("***** Running training *****")
1306 | logger.info(f" Num examples = {len(train_dataset)}")
1307 | logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1308 | logger.info(f" Num Epochs = {args.num_train_epochs}")
1309 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1310 | logger.info(
1311 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
1312 | )
1313 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1314 | logger.info(f" Total optimization steps = {args.max_train_steps}")
1315 | global_step = 0
1316 | first_epoch = 0
1317 |
1318 | # Potentially load in the weights and states from a previous save
1319 | if args.resume_from_checkpoint:
1320 | if args.resume_from_checkpoint != "latest":
1321 | path = os.path.basename(args.resume_from_checkpoint)
1322 | else:
1323 | # Get the most recent checkpoint
1324 | dirs = os.listdir(args.output_dir)
1325 | dirs = [d for d in dirs if d.startswith("checkpoint")]
1326 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1327 | path = dirs[-1] if len(dirs) > 0 else None
1328 |
1329 | if path is None:
1330 | accelerator.print(
1331 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1332 | )
1333 | args.resume_from_checkpoint = None
1334 | initial_global_step = 0
1335 | else:
1336 | accelerator.print(f"Resuming from checkpoint {path}")
1337 | accelerator.load_state(os.path.join(args.output_dir, path))
1338 | global_step = int(path.split("-")[1])
1339 |
1340 | initial_global_step = global_step
1341 | first_epoch = global_step // num_update_steps_per_epoch
1342 | else:
1343 | initial_global_step = 0
1344 |
1345 | progress_bar = tqdm(
1346 | range(0, args.max_train_steps),
1347 | initial=initial_global_step,
1348 | desc="Steps",
1349 | # Only show the progress bar once on each machine.
1350 | disable=not accelerator.is_local_main_process,
1351 | )
1352 |
1353 | for epoch in range(first_epoch, args.num_train_epochs):
1354 | unet.train()
1355 | if args.modifier_token is not None:
1356 | text_encoder.train()
1357 | for step, batch in enumerate(train_dataloader):
1358 | with accelerator.accumulate(unet), accelerator.accumulate(text_encoder):
1359 | # Convert images to latent space
1360 | latents = vae.encode(
1361 | batch["pixel_values"].to(dtype=weight_dtype)
1362 | ).latent_dist.sample()
1363 | latents = latents * vae.config.scaling_factor
1364 |
1365 | # Sample noise that we'll add to the latents
1366 | noise = torch.randn_like(latents)
1367 | bsz = latents.shape[0]
1368 | # Sample a random timestep for each image
1369 | timesteps = torch.randint(
1370 | 0,
1371 | noise_scheduler.config.num_train_timesteps,
1372 | (bsz,),
1373 | device=latents.device,
1374 | )
1375 | timesteps = timesteps.long()
1376 |
1377 | # Add noise to the latents according to the noise magnitude at each timestep
1378 | # (this is the forward diffusion process)
1379 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
1380 |
1381 | # Get the text embedding for conditioning
1382 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
1383 |
1384 | # Predict the noise residual
1385 | model_pred = unet(
1386 | noisy_latents, timesteps, encoder_hidden_states
1387 | ).sample
1388 |
1389 | # Get the target for loss depending on the prediction type
1390 | if noise_scheduler.config.prediction_type == "epsilon":
1391 | target = noise
1392 | elif noise_scheduler.config.prediction_type == "v_prediction":
1393 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
1394 | else:
1395 | raise ValueError(
1396 | f"Unknown prediction type {noise_scheduler.config.prediction_type}"
1397 | )
1398 |
1399 | if args.with_prior_preservation:
1400 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
1401 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
1402 | target, target_prior = torch.chunk(target, 2, dim=0)
1403 | mask = torch.chunk(batch["mask"], 2, dim=0)[0]
1404 | # Compute instance loss
1405 | loss = F.mse_loss(
1406 | model_pred.float(), target.float(), reduction="none"
1407 | )
1408 | loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
1409 |
1410 | # Compute prior loss
1411 | prior_loss = F.mse_loss(
1412 | model_pred_prior.float(), target_prior.float(), reduction="mean"
1413 | )
1414 |
1415 | # Add the prior loss to the instance loss.
1416 | loss = loss + args.prior_loss_weight * prior_loss
1417 | elif args.use_spl:
1418 |
1419 | mask = batch["mask"]
1420 | loss = F.mse_loss(
1421 | model_pred.float(), target.float(), reduction="none"
1422 | )
1423 | loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
1424 |
1425 | if accelerator.num_processes > 1:
1426 | text_embeddings = (
1427 | text_encoder.module.get_input_embeddings().weight
1428 | )
1429 | else:
1430 | text_embeddings = text_encoder.get_input_embeddings().weight
1431 |
1432 | spl_l = 0.0
1433 |
1434 | for i in range(len(modifier_token_id)):
1435 | spl_l = (
1436 | spl_l
1437 | + spl_loss(
1438 | text_embeddings,
1439 | modifier_token_id[i],
1440 | modifier_cls_text_inputs[i],
1441 | cls_text_inputs[i],
1442 | ).mean()
1443 | )
1444 |
1445 | spl_l = spl_l / len(modifier_token_id)
1446 |
1447 | spl_l = args.spl_weight * spl_l
1448 |
1449 | loss = loss + spl_l
1450 |
1451 | else:
1452 | mask = batch["mask"]
1453 | loss = F.mse_loss(
1454 | model_pred.float(), target.float(), reduction="none"
1455 | )
1456 | loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()
1457 | accelerator.backward(loss)
1458 | # Zero out the gradients for all token embeddings except the newly added
1459 | # embeddings for the concept, as we only want to optimize the concept embeddings
1460 | if args.modifier_token is not None:
1461 | if accelerator.num_processes > 1:
1462 | grads_text_encoder = (
1463 | text_encoder.module.get_input_embeddings().weight.grad
1464 | )
1465 | else:
1466 | grads_text_encoder = (
1467 | text_encoder.get_input_embeddings().weight.grad
1468 | )
1469 | # Get the index for tokens that we want to zero the grads for
1470 | index_grads_to_zero = (
1471 | torch.arange(len(tokenizer)) != modifier_token_id[0]
1472 | )
1473 | for i in range(len(modifier_token_id[1:])):
1474 | index_grads_to_zero = index_grads_to_zero & (
1475 | torch.arange(len(tokenizer)) != modifier_token_id[i + 1]
1476 | )
1477 |
1478 | grads_text_encoder.data[index_grads_to_zero, :] = (
1479 | grads_text_encoder.data[index_grads_to_zero, :].fill_(0)
1480 | )
1481 |
1482 | if accelerator.sync_gradients:
1483 | params_to_clip = (
1484 | itertools.chain(
1485 | text_encoder.parameters(),
1486 | custom_diffusion_layers.parameters(),
1487 | )
1488 | if args.modifier_token is not None
1489 | else custom_diffusion_layers.parameters()
1490 | )
1491 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1492 |
1493 | optimizer.step()
1494 | lr_scheduler.step()
1495 | optimizer.zero_grad(set_to_none=args.set_grads_to_none)
1496 |
1497 | # Checks if the accelerator has performed an optimization step behind the scenes
1498 | if accelerator.sync_gradients:
1499 | progress_bar.update(1)
1500 | global_step += 1
1501 |
1502 | if global_step % args.checkpointing_steps == 0:
1503 | if accelerator.is_main_process:
1504 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1505 | if args.checkpoints_total_limit is not None:
1506 | checkpoints = os.listdir(args.output_dir)
1507 | checkpoints = [
1508 | d for d in checkpoints if d.startswith("checkpoint")
1509 | ]
1510 | checkpoints = sorted(
1511 | checkpoints, key=lambda x: int(x.split("-")[1])
1512 | )
1513 |
1514 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1515 | if len(checkpoints) >= args.checkpoints_total_limit:
1516 | num_to_remove = (
1517 | len(checkpoints) - args.checkpoints_total_limit + 1
1518 | )
1519 | removing_checkpoints = checkpoints[0:num_to_remove]
1520 |
1521 | logger.info(
1522 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1523 | )
1524 | logger.info(
1525 | f"removing checkpoints: {', '.join(removing_checkpoints)}"
1526 | )
1527 |
1528 | for removing_checkpoint in removing_checkpoints:
1529 | removing_checkpoint = os.path.join(
1530 | args.output_dir, removing_checkpoint
1531 | )
1532 | shutil.rmtree(removing_checkpoint)
1533 |
1534 | save_path = os.path.join(
1535 | args.output_dir, f"checkpoint-{global_step}"
1536 | )
1537 | accelerator.save_state(save_path)
1538 | logger.info(f"Saved state to {save_path}")
1539 |
1540 | logs = {
1541 | "loss": loss.clone().detach().item(),
1542 | "lr": lr_scheduler.get_last_lr()[0],
1543 | }
1544 | if args.use_spl:
1545 | logs["spl"] = spl_l.clone().detach().item() / args.spl_weight
1546 | logs["recon_loss"] = logs["loss"] - logs["spl"] * args.spl_weight
1547 |
1548 | progress_bar.set_postfix(**logs)
1549 | accelerator.log(logs, step=global_step)
1550 |
1551 | if global_step >= args.max_train_steps:
1552 | break
1553 |
1554 | if accelerator.is_main_process:
1555 | images = []
1556 |
1557 | if (
1558 | args.validation_prompt is not None
1559 | and global_step % args.validation_steps == 0
1560 | ):
1561 | logger.info(
1562 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1563 | f" {args.validation_prompt}."
1564 | )
1565 | # create pipeline
1566 | pipeline = DiffusionPipeline.from_pretrained(
1567 | args.pretrained_model_name_or_path,
1568 | unet=accelerator.unwrap_model(unet),
1569 | text_encoder=accelerator.unwrap_model(text_encoder),
1570 | tokenizer=tokenizer,
1571 | revision=args.revision,
1572 | variant=args.variant,
1573 | torch_dtype=weight_dtype,
1574 | )
1575 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1576 | pipeline.scheduler.config
1577 | )
1578 | pipeline = pipeline.to(accelerator.device)
1579 | pipeline.set_progress_bar_config(disable=True)
1580 |
1581 | # run inference
1582 | generator = torch.Generator(device=accelerator.device).manual_seed(
1583 | args.seed
1584 | )
1585 | images = [
1586 | pipeline(
1587 | args.validation_prompt,
1588 | num_inference_steps=25,
1589 | generator=generator,
1590 | eta=1.0,
1591 | ).images[0]
1592 | for _ in range(args.num_validation_images)
1593 | ]
1594 |
1595 | for tracker in accelerator.trackers:
1596 | if tracker.name == "tensorboard":
1597 | np_images = np.stack([np.asarray(img) for img in images])
1598 | tracker.writer.add_images(
1599 | "validation", np_images, global_step, dataformats="NHWC"
1600 | )
1601 | if tracker.name == "wandb":
1602 | tracker.log(
1603 | {
1604 | "validation": [
1605 | wandb.Image(
1606 | image,
1607 | caption=f"{i}: {args.validation_prompt}",
1608 | )
1609 | for i, image in enumerate(images)
1610 | ]
1611 | }
1612 | )
1613 |
1614 | del pipeline
1615 | torch.cuda.empty_cache()
1616 |
1617 | accelerator.wait_for_everyone()
1618 | if accelerator.is_main_process:
1619 | unet = unet.to(torch.float32)
1620 | unet.save_attn_procs(
1621 | args.output_dir, safe_serialization=not args.no_safe_serialization
1622 | )
1623 | save_new_embed(
1624 | text_encoder,
1625 | modifier_token_id,
1626 | accelerator,
1627 | args,
1628 | args.output_dir,
1629 | safe_serialization=not args.no_safe_serialization,
1630 | )
1631 |
1632 | # Final inference
1633 | # Load previous pipeline
1634 | pipeline = DiffusionPipeline.from_pretrained(
1635 | args.pretrained_model_name_or_path,
1636 | revision=args.revision,
1637 | variant=args.variant,
1638 | torch_dtype=weight_dtype,
1639 | )
1640 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
1641 | pipeline.scheduler.config
1642 | )
1643 | pipeline = pipeline.to(accelerator.device)
1644 |
1645 | # load attention processors
1646 | weight_name = (
1647 | "pytorch_custom_diffusion_weights.safetensors"
1648 | if not args.no_safe_serialization
1649 | else "pytorch_custom_diffusion_weights.bin"
1650 | )
1651 | pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name)
1652 | for token in args.modifier_token:
1653 | token_weight_name = (
1654 | f"{token}.safetensors"
1655 | if not args.no_safe_serialization
1656 | else f"{token}.bin"
1657 | )
1658 | pipeline.load_textual_inversion(
1659 | args.output_dir, weight_name=token_weight_name
1660 | )
1661 |
1662 | # run inference
1663 | if args.validation_prompt and args.num_validation_images > 0:
1664 | generator = (
1665 | torch.Generator(device=accelerator.device).manual_seed(args.seed)
1666 | if args.seed
1667 | else None
1668 | )
1669 | images = [
1670 | pipeline(
1671 | args.validation_prompt,
1672 | num_inference_steps=25,
1673 | generator=generator,
1674 | eta=1.0,
1675 | ).images[0]
1676 | for _ in range(args.num_validation_images)
1677 | ]
1678 |
1679 | for tracker in accelerator.trackers:
1680 | if tracker.name == "tensorboard":
1681 | np_images = np.stack([np.asarray(img) for img in images])
1682 | tracker.writer.add_images(
1683 | "test", np_images, epoch, dataformats="NHWC"
1684 | )
1685 | if tracker.name == "wandb":
1686 | tracker.log(
1687 | {
1688 | "test": [
1689 | wandb.Image(
1690 | image, caption=f"{i}: {args.validation_prompt}"
1691 | )
1692 | for i, image in enumerate(images)
1693 | ]
1694 | }
1695 | )
1696 |
1697 | if args.push_to_hub:
1698 | save_model_card(
1699 | repo_id,
1700 | images=images,
1701 | base_model=args.pretrained_model_name_or_path,
1702 | prompt=args.instance_prompt,
1703 | repo_folder=args.output_dir,
1704 | )
1705 | api = HfApi(token=args.hub_token)
1706 | api.upload_folder(
1707 | repo_id=repo_id,
1708 | folder_path=args.output_dir,
1709 | commit_message="End of training",
1710 | ignore_patterns=["step_*", "epoch_*"],
1711 | )
1712 |
1713 | accelerator.end_training()
1714 |
1715 |
1716 | if __name__ == "__main__":
1717 | args = parse_args()
1718 | main(args)
1719 |
--------------------------------------------------------------------------------
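When `args.use_spl` is set and prior preservation is disabled, the training loop above adds the semantic preservation term on top of the masked reconstruction loss. The sketch below only restates the objective already assembled in that loop; the tensor names (`model_pred`, `target`, `mask`, `spl_terms`) are illustrative placeholders.

```python
# Sketch of the objective assembled in the training loop of train_class_diffusion.py.
# All names are illustrative; see the loop above for the actual wiring.
import torch
import torch.nn.functional as F


def class_diffusion_loss(model_pred, target, mask, spl_terms, spl_weight):
    # Masked reconstruction loss between the UNet prediction and the diffusion target.
    recon = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    recon = ((recon * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean()

    # Semantic preservation loss: one (1 - cosine similarity) term per modifier token,
    # computed by utils/semantic_preservation_loss.SPL between the pooled embeddings of
    # "a photo of a <modifier> <cls>" and "a photo of a <cls>", averaged over tokens.
    spl = torch.stack([term.mean() for term in spl_terms]).mean()

    return recon + spl_weight * spl
```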
/utils/blip2t.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import BlipProcessor, BlipForImageTextRetrieval
3 |
4 |
5 | class BLIP2T:
6 |
7 | def __init__(self, model_name, device):
8 | self.device = device
9 | self.processor = BlipProcessor.from_pretrained(model_name)
10 | self.model = BlipForImageTextRetrieval.from_pretrained(
11 | model_name,
12 | torch_dtype=torch.float16,
13 | ).to(device)
14 |
15 | @torch.no_grad()
16 | def text_similarity(self, prompt, image):
17 | """
18 | Calculate text similarity between prompt and image
19 |
20 | Args:
21 | prompt: str
22 | image: PIL.Image
23 |
24 |         Returns:
25 | score: float
26 | """
27 | inputs = self.processor(image, prompt, return_tensors="pt").to(
28 | self.device, torch.float16
29 | )
30 | scores = self.model(**inputs, use_itm_head=False)[0]
31 |
32 | if self.device == "cpu":
33 | scores = scores.detach().numpy()[0]
34 | else:
35 | scores = scores.detach().cpu().numpy()[0]
36 |
37 | return scores
38 |
--------------------------------------------------------------------------------
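A minimal usage sketch of the `BLIP2T` wrapper above, which is how a BLIP2-T score between a prompt and a generated image can be computed. The checkpoint id `Salesforce/blip-itm-base-coco` and the image path are assumptions rather than values from this repo; any checkpoint compatible with `BlipForImageTextRetrieval` should work.

```python
# Minimal sketch; the checkpoint id and image path are assumptions, not repo defaults.
from PIL import Image

from utils.blip2t import BLIP2T

scorer = BLIP2T("Salesforce/blip-itm-base-coco", "cuda")  # assumed BLIP ITM checkpoint
image = Image.open("result/sample.png").convert("RGB")  # placeholder generated image
score = scorer.text_similarity("a dog wearing a headphone", image)
print("BLIP2-T score:", score)
```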
/utils/load_attn_weight.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 | import safetensors
4 | from typing import Callable, Dict, List, Optional, Union
5 |
6 | from diffusers.utils import (
7 | USE_PEFT_BACKEND,
8 | _get_model_file,
9 | )
10 |
11 | TEXT_ENCODER_NAME = "text_encoder"
12 | UNET_NAME = "unet"
13 |
14 | LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
15 | LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
16 |
17 | CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
18 | CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
19 |
20 |
21 | def load_custom_attn_param(
22 | self,
23 | pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
24 | **kwargs
25 | ):
26 | r"""
27 | Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
28 | defined in
29 | [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
30 | and be a `torch.nn.Module` class.
31 |
32 | Parameters:
33 | pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
34 | Can be either:
35 |
36 | - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
37 | the Hub.
38 | - A path to a directory (for example `./my_model_directory`) containing the model weights saved
39 | with [`ModelMixin.save_pretrained`].
40 | - A [torch state
41 | dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
42 |
43 | cache_dir (`Union[str, os.PathLike]`, *optional*):
44 | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
45 | is not used.
46 | force_download (`bool`, *optional*, defaults to `False`):
47 | Whether or not to force the (re-)download of the model weights and configuration files, overriding the
48 | cached versions if they exist.
49 | resume_download (`bool`, *optional*, defaults to `False`):
50 | Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
51 | incompletely downloaded files are deleted.
52 | proxies (`Dict[str, str]`, *optional*):
53 | A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
54 | 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
55 | local_files_only (`bool`, *optional*, defaults to `False`):
56 | Whether to only load local model weights and configuration files or not. If set to `True`, the model
57 | won't be downloaded from the Hub.
58 | token (`str` or *bool*, *optional*):
59 | The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
60 | `diffusers-cli login` (stored in `~/.huggingface`) is used.
61 | low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
62 | Speed up model loading only loading the pretrained weights and not initializing the weights. This also
63 | tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
64 | Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
65 | argument to `True` will raise an error.
66 | revision (`str`, *optional*, defaults to `"main"`):
67 | The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
68 | allowed by Git.
69 | subfolder (`str`, *optional*, defaults to `""`):
70 | The subfolder location of a model file within a larger model repository on the Hub or locally.
71 | mirror (`str`, *optional*):
72 | Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
73 | guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
74 | information.
75 |
76 | Example:
77 |
78 | ```py
79 | from diffusers import AutoPipelineForText2Image
80 | import torch
81 |
82 | pipeline = AutoPipelineForText2Image.from_pretrained(
83 | "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
84 | ).to("cuda")
85 | pipeline.unet.load_attn_procs(
86 | "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
87 | )
88 | ```
89 | """
90 | from diffusers.models.attention_processor import CustomDiffusionAttnProcessor
91 | from diffusers.models.lora import (
92 | LoRACompatibleConv,
93 | LoRACompatibleLinear,
94 | LoRAConv2dLayer,
95 | LoRALinearLayer,
96 | )
97 |
98 | cache_dir = kwargs.pop("cache_dir", None)
99 | force_download = kwargs.pop("force_download", False)
100 | resume_download = kwargs.pop("resume_download", False)
101 | proxies = kwargs.pop("proxies", None)
102 | local_files_only = kwargs.pop("local_files_only", None)
103 | token = kwargs.pop("token", None)
104 | revision = kwargs.pop("revision", None)
105 | subfolder = kwargs.pop("subfolder", None)
106 | weight_name = kwargs.pop("weight_name", None)
107 | use_safetensors = kwargs.pop("use_safetensors", None)
108 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
109 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
110 | network_alphas = kwargs.pop("network_alphas", None)
111 |
112 | _pipeline = kwargs.pop("_pipeline", None)
113 | use_custom_diffusion = kwargs.pop("custom_diffusion", None)
114 |
115 | is_network_alphas_none = network_alphas is None
116 |
117 | allow_pickle = False
118 |
119 | if use_safetensors is None:
120 | use_safetensors = True
121 | allow_pickle = True
122 |
123 | user_agent = {
124 | "file_type": "attn_procs_weights",
125 | "framework": "pytorch",
126 | }
127 |
128 | model_file = None
129 | if not isinstance(pretrained_model_name_or_path_or_dict, dict):
130 | # Let's first try to load .safetensors weights
131 | if (use_safetensors and weight_name is None) or (
132 | weight_name is not None and weight_name.endswith(".safetensors")
133 | ):
134 | try:
135 | model_file = _get_model_file(
136 | pretrained_model_name_or_path_or_dict,
137 | weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
138 | cache_dir=cache_dir,
139 | force_download=force_download,
140 | resume_download=resume_download,
141 | proxies=proxies,
142 | local_files_only=local_files_only,
143 | token=token,
144 | revision=revision,
145 | subfolder=subfolder,
146 | user_agent=user_agent,
147 | )
148 | state_dict = safetensors.torch.load_file(model_file, device="cpu")
149 | except IOError as e:
150 | if not allow_pickle:
151 | raise e
152 | # try loading non-safetensors weights
153 | pass
154 | if model_file is None:
155 | model_file = _get_model_file(
156 | pretrained_model_name_or_path_or_dict,
157 | weights_name=weight_name or LORA_WEIGHT_NAME,
158 | cache_dir=cache_dir,
159 | force_download=force_download,
160 | resume_download=resume_download,
161 | proxies=proxies,
162 | local_files_only=local_files_only,
163 | token=token,
164 | revision=revision,
165 | subfolder=subfolder,
166 | user_agent=user_agent,
167 | )
168 | state_dict = torch.load(model_file, map_location="cpu")
169 | else:
170 | state_dict = pretrained_model_name_or_path_or_dict
171 |
172 | # is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
173 | is_custom_diffusion = (
174 | any("custom_diffusion" in k for k in state_dict.keys()) or use_custom_diffusion
175 | )
176 |
177 | assert is_custom_diffusion, "Expected Custom Diffusion attention weights in the state dict."
178 |
179 | unet_state_dict = self.state_dict()
180 |
181 | for name, weight in state_dict.items():
182 | if name.endswith("weight"):
183 | name = name.split(".")
184 | module_unet = (
185 | ".".join(name[:-3]) + ".to_" + name[-2].split("_")[1] + ".weight"
186 | )
187 | unet_state_dict[module_unet] = weight
188 |
189 | self.load_state_dict(unet_state_dict)
190 |
191 | warn_messages = "You are loading Custom Diffusion weights into a U-Net that does not use the official Custom Diffusion attention processors; this may cause issues, so make sure the loaded weights behave as expected."
192 | warnings.warn(warn_messages)
193 |
--------------------------------------------------------------------------------
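`load_custom_attn_param` is written as a free function so that it can be bound onto an existing `UNet2DConditionModel` at runtime; `videogen.py` below uses exactly this pattern. A brief sketch, where `ckpt_dir` is a placeholder for the output directory written by `train_class_diffusion.py`:

```python
# Sketch: binding load_custom_attn_param onto a pipeline's UNet (ckpt_dir is a placeholder).
import types

import torch
from diffusers import DiffusionPipeline

from utils.load_attn_weight import load_custom_attn_param

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.unet.load_custom_attn_param = types.MethodType(load_custom_attn_param, pipe.unet)
pipe.unet.load_custom_attn_param(
    "ckpt_dir",  # placeholder: training output_dir containing the saved attention weights
    weight_name="pytorch_custom_diffusion_weights.bin",
)
```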
/utils/semantic_preservation_loss.py:
--------------------------------------------------------------------------------
1 | # `Dict` is used as a typing annotation here; it must come from typing, not ast
2 | from dataclasses import dataclass
3 | from typing import Any, Dict, Optional, Tuple, Union
4 |
5 | import torch
6 | import torch.utils.checkpoint
7 | from torch import nn
8 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
9 | import torch.nn.functional as F
10 |
11 | from transformers.modeling_outputs import (
12 | BaseModelOutput,
13 | BaseModelOutputWithPooling,
14 | ImageClassifierOutput,
15 | )
16 |
17 | from transformers.modeling_attn_mask_utils import (
18 | _create_4d_causal_attention_mask,
19 | _prepare_4d_attention_mask,
20 | )
21 |
22 | from transformers import CLIPTextModel
23 |
24 |
25 | def forward_with_custom_embeddings(
26 | self,
27 | input_ids: Optional[torch.Tensor] = None,
28 | attention_mask: Optional[torch.Tensor] = None,
29 | position_ids: Optional[torch.Tensor] = None,
30 | output_attentions: Optional[bool] = None,
31 | output_hidden_states: Optional[bool] = None,
32 | return_dict: Optional[bool] = None,
33 | input_modifier_embeddings: Optional[torch.Tensor] = None,
34 | modifier_token_id: Optional[torch.Tensor] = None,
35 | ) -> Union[Tuple, BaseModelOutputWithPooling]:
36 | r"""
37 | Returns:
38 |
39 | """
40 | output_attentions = (
41 | output_attentions
42 | if output_attentions is not None
43 | else self.config.output_attentions
44 | )
45 | output_hidden_states = (
46 | output_hidden_states
47 | if output_hidden_states is not None
48 | else self.config.output_hidden_states
49 | )
50 | return_dict = (
51 | return_dict if return_dict is not None else self.config.use_return_dict
52 | )
53 |
54 | use_custom_embeddings = (
55 | input_modifier_embeddings is not None and modifier_token_id is not None
56 | )
57 |
58 | if input_ids is None:
59 | raise ValueError("You have to specify input_ids")
60 |
61 | input_shape = input_ids.size()
62 | input_ids = input_ids.view(-1, input_shape[-1])
63 |
64 | hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
65 |
66 | if use_custom_embeddings:
67 | modifier_index = torch.where(input_ids.squeeze(0) == modifier_token_id)
68 | hidden_states[0, modifier_index, :] = input_modifier_embeddings
69 |
70 | # CLIP's text model uses causal mask, prepare it here.
71 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
72 | causal_attention_mask = _create_4d_causal_attention_mask(
73 | input_shape, hidden_states.dtype, device=hidden_states.device
74 | )
75 | # expand attention_mask
76 | if attention_mask is not None:
77 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
78 | attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
79 |
80 | encoder_outputs = self.encoder(
81 | inputs_embeds=hidden_states,
82 | attention_mask=attention_mask,
83 | causal_attention_mask=causal_attention_mask,
84 | output_attentions=output_attentions,
85 | output_hidden_states=output_hidden_states,
86 | return_dict=return_dict,
87 | )
88 |
89 | last_hidden_state = encoder_outputs[0]
90 | last_hidden_state = self.final_layer_norm(last_hidden_state)
91 |
92 | if self.eos_token_id == 2:
93 |         # The `eos_token_id` was incorrect before PR #24773: let's keep what has been done here.
94 | # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
95 | # ------------------------------------------------------------
96 | # text_embeds.shape = [batch_size, sequence_length, transformer.width]
97 | # take features from the eot embedding (eot_token is the highest number in each sequence)
98 | # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
99 | pooled_output = last_hidden_state[
100 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
101 | input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(
102 | dim=-1
103 | ),
104 | ]
105 | else:
106 |         # The config gets the updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
107 | pooled_output = last_hidden_state[
108 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
109 | # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
110 | (
111 | input_ids.to(dtype=torch.int, device=last_hidden_state.device)
112 | == self.eos_token_id
113 | )
114 | .int()
115 | .argmax(dim=-1),
116 | ]
117 |
118 | if not return_dict:
119 | return (last_hidden_state, pooled_output) + encoder_outputs[1:]
120 |
121 | return BaseModelOutputWithPooling(
122 | last_hidden_state=last_hidden_state,
123 | pooler_output=pooled_output,
124 | hidden_states=encoder_outputs.hidden_states,
125 | attentions=encoder_outputs.attentions,
126 | )
127 |
128 |
129 | class SPL:
130 |
131 | def __init__(
132 | self,
133 | text_encoder: CLIPTextModel,
134 | use_attention_mask: Optional[bool] = None,
135 | ):
136 | self.text_encoder = text_encoder
137 | self.use_attention_mask = use_attention_mask
138 |
139 | def _encode(
140 | self,
141 | text_input: Optional[Dict] = None,
142 | input_modifier_embeddings: Optional[torch.Tensor] = None,
143 | modifier_token_id: Optional[torch.Tensor] = None,
144 | ):
145 |
146 | text_input_ids = text_input.input_ids
147 | device = self.text_encoder.device
148 |
149 | if (
150 | hasattr(self.text_encoder.config, "use_attention_mask")
151 | and self.text_encoder.config.use_attention_mask
152 | ):
153 | attention_mask = text_input.attention_mask.to(device)
154 | else:
155 | attention_mask = None
156 |
157 | prompt_embeds = self.text_encoder.text_model.forward_with_custom_embeddings(
158 | text_input_ids.to(device),
159 | attention_mask=attention_mask,
160 | input_modifier_embeddings=input_modifier_embeddings,
161 | modifier_token_id=modifier_token_id,
162 | )
163 |
164 | prompt_embeds = prompt_embeds.pooler_output
165 |
166 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
167 |
168 | return prompt_embeds
169 |
170 | def __call__(
171 | self,
172 | text_embeddings: Optional[torch.Tensor],
173 | modifier_token_id: Optional[torch.Tensor],
174 | modifier_cls_text_input: Optional[Dict],
175 | cls_text_input: Optional[Dict],
176 | ) -> torch.Tensor:
177 |
178 | modifier_embedding = text_embeddings[modifier_token_id].clone()
179 |
180 | modifier_cls_embedding = self._encode(
181 | text_input=modifier_cls_text_input,
182 | input_modifier_embeddings=modifier_embedding,
183 | modifier_token_id=modifier_token_id,
184 | )
185 |
186 | cls_embedding = self._encode(text_input=cls_text_input)
187 |
188 | dis = F.cosine_similarity(modifier_cls_embedding, cls_embedding)
189 |
190 | dis = 1 - dis
191 |
192 | return dis
193 |
--------------------------------------------------------------------------------
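`SPL` measures the cosine distance between the pooled text embedding of the "modifier + class" prompt and that of the plain class prompt, injecting the current modifier embedding through `forward_with_custom_embeddings`. Below is a minimal sketch of using it outside the training script; the model id and token strings are placeholders, and the newly added token is left randomly initialized here for brevity.

```python
# Sketch of computing the semantic preservation loss for a single modifier token.
# Model id and tokens are placeholders; train_class_diffusion.py wires this up the same way.
import types

from transformers import CLIPTextModel, CLIPTokenizer

from utils.semantic_preservation_loss import SPL, forward_with_custom_embeddings

model_id = "runwayml/stable-diffusion-v1-5"  # placeholder base model
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

# Attach the custom forward so SPL can swap in the learnable modifier embedding.
text_encoder.text_model.forward_with_custom_embeddings = types.MethodType(
    forward_with_custom_embeddings, text_encoder.text_model
)

modifier_token, cls_token = "<new1>", "dog"  # placeholder tokens
tokenizer.add_tokens(modifier_token)
text_encoder.resize_token_embeddings(len(tokenizer))
modifier_token_id = tokenizer.convert_tokens_to_ids(modifier_token)

modifier_cls_input = tokenizer(
    "a photo of a " + modifier_token + " " + cls_token,
    padding="longest",
    return_tensors="pt",
)
cls_input = tokenizer(
    "a photo of a " + cls_token, padding="longest", return_tensors="pt"
)

spl = SPL(text_encoder=text_encoder)
loss = spl(
    text_encoder.get_input_embeddings().weight,
    modifier_token_id,
    modifier_cls_input,
    cls_input,
).mean()
print("semantic preservation loss:", loss.item())
```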
/utils/similarity.py:
--------------------------------------------------------------------------------
1 | import clip
2 | import torch
3 | from PIL import Image
4 |
5 |
6 | class Similarity:
7 |
8 | def __init__(self, model_name, device):
9 | self.device = device
10 | self.model, self.preprocess = clip.load(model_name, device)
11 |
12 | @torch.no_grad()
13 | def text_similarity(self, prompt, image):
14 | """
15 | Calculate text similarity between prompt and image
16 |
17 | Args:
18 | prompt: str
19 | image: PIL.Image
20 |
21 |         Returns:
22 | score: float
23 | """
24 | image = self.preprocess(image).unsqueeze(0).to(self.device)
25 | text = clip.tokenize([prompt]).to(self.device)
26 |
27 | image_features = self.model.encode_image(image)
28 | text_features = self.model.encode_text(text)
29 |
30 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
31 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
32 |
33 | similarity = torch.matmul(text_features, image_features.T).squeeze()
34 |
35 | score = similarity.detach().cpu().numpy()
36 |
37 | return score
38 |
39 | def image_similarity(self, source, generate):
40 | """
41 |         Calculate image similarity between the source image and the generated image
42 | 
43 |         Args:
44 |             source: PIL.Image
45 |             generate: PIL.Image
46 | 
47 |         Returns:
48 | score: float
49 | """
50 | source = self.preprocess(source).unsqueeze(0).to(self.device)
51 | generate = self.preprocess(generate).unsqueeze(0).to(self.device)
52 |
53 | image_features_source = self.model.encode_image(source)
54 | image_features_generate = self.model.encode_image(generate)
55 |
56 | image_features_source = image_features_source / image_features_source.norm(
57 | dim=-1, keepdim=True
58 | )
59 | image_features_generate = (
60 | image_features_generate / image_features_generate.norm(dim=-1, keepdim=True)
61 | )
62 |
63 | similarity = torch.matmul(
64 | image_features_source, image_features_generate.T
65 | ).squeeze()
66 |
67 | score = similarity.detach().cpu().numpy()
68 |
69 | return score
70 |
--------------------------------------------------------------------------------
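A usage sketch of the CLIP-based `Similarity` helper above for CLIP-T and CLIP-I scores. The model name `ViT-B/32` and the generated-image path are assumptions; `data/dog/00.jpg` is one of the reference images in this repo.

```python
# Minimal sketch; the CLIP model name and generated-image path are assumptions.
from PIL import Image

from utils.similarity import Similarity

sim = Similarity("ViT-B/32", "cuda")

source = Image.open("data/dog/00.jpg").convert("RGB")  # reference image from data/
generated = Image.open("result/sample.png").convert("RGB")  # placeholder generated image

clip_t = sim.text_similarity("a dog wearing a headphone", generated)  # CLIP-T
clip_i = sim.image_similarity(source, generated)  # CLIP-I
print("CLIP-T:", clip_t, "CLIP-I:", clip_i)
```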
/videogen.py:
--------------------------------------------------------------------------------
1 | # import torch
2 |
3 | import types
4 | import torch
5 | from utils.load_attn_weight import load_custom_attn_param
6 |
7 | from diffusers import (
8 | AnimateDiffPipeline,
9 | DDIMScheduler,
10 | MotionAdapter,
11 | DiffusionPipeline,
12 | )
13 | from diffusers.utils import export_to_gif
14 |
15 | model_id = "runwayml/stable-diffusion-v1-5"
16 | ckpt_path = "your_ckpt_path" # TODO: modify here
17 |
18 | pipeline = DiffusionPipeline.from_pretrained(
19 | model_id,
20 | torch_dtype=torch.float16,
21 | ).to("cuda")
22 |
23 | pipeline.unet.load_custom_attn_param = types.MethodType(
24 | load_custom_attn_param, pipeline.unet
25 | )
26 | pipeline.unet.load_custom_attn_param(
27 | ckpt_path,
28 | weight_name="pytorch_custom_diffusion_weights.bin",
29 | )
30 | pipeline.load_textual_inversion(ckpt_path, weight_name=".bin")  # TODO: set weight_name to "<your_modifier_token>.bin"
31 |
32 | adapter = MotionAdapter.from_pretrained(
33 | "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
34 | )
35 | pipe = AnimateDiffPipeline.from_pretrained(
36 | model_id,
37 | motion_adapter=adapter,
38 | unet=pipeline.unet,
39 | text_encoder=pipeline.text_encoder,
40 | tokenizer=pipeline.tokenizer,
41 | torch_dtype=torch.float16,
42 | )
43 | scheduler = DDIMScheduler.from_pretrained(
44 | model_id,
45 | subfolder="scheduler",
46 | clip_sample=False,
47 | timestep_spacing="linspace",
48 | beta_schedule="linear",
49 | steps_offset=1,
50 | )
51 | pipe.scheduler = scheduler
52 | pipe.enable_vae_slicing()
53 | pipe.enable_model_cpu_offload()
54 |
55 | output = pipe(
56 |     prompt=(" dog running on the street"),  # TODO: prepend your modifier token before "dog"
57 | negative_prompt="bad quality, worse quality",
58 | num_frames=16,
59 | guidance_scale=7.5,
60 | num_inference_steps=25,
61 | )
62 | frames = output.frames[0]
63 | export_to_gif(frames, "test.gif", fps=4)
64 |
--------------------------------------------------------------------------------