├── .gitignore ├── README.md ├── assert ├── .DS_Store └── img │ ├── .DS_Store │ ├── multiple.png │ ├── singleconcept.png │ ├── story.jpg │ └── video.jpg ├── blip2t.py ├── config └── concept.json ├── data ├── barn │ ├── cora-leach-UlKxoeGP4Rc-unsplash.jpg │ ├── emmy-sobieski-pPoR4vyOomY-unsplash.jpg │ ├── isaac-martin-ZCAiJPiO_98-unsplash.jpg │ ├── jonah-brown-eYAsBnNISa4-unsplash.jpg │ ├── melissa-cronin-P8Bfpi1HLIE-unsplash.jpg │ ├── pedro-lastra-hXunh-ivkPc-unsplash.jpg │ └── sarat-karumuri-vs1voV4jhpU-unsplash.jpg ├── bear │ ├── marina-shatskih-6MDi8o6VYHg-unsplash.jpg │ ├── marina-shatskih-B3Jnug3rKrs-unsplash.jpg │ ├── marina-shatskih-BYZdCQDSNTY-unsplash.jpg │ ├── marina-shatskih-F_imBoYwuTk-unsplash.jpg │ ├── marina-shatskih-kBo2MFJz2QU-unsplash.jpg │ ├── marina-shatskih-nKoIXtEfkZ8-unsplash.jpg │ └── marina-shatskih-sIzj2poobss-unsplash.jpg └── dog │ ├── 00.jpg │ ├── 01.jpg │ ├── 02.jpg │ ├── 03.jpg │ └── 04.jpg ├── requirements.txt ├── scripts ├── train_multi.sh └── train_single.sh ├── train_class_diffusion.py ├── utils ├── blip2t.py ├── load_attn_weight.py ├── semantic_preservation_loss.py └── similarity.py └── videogen.py /.gitignore: -------------------------------------------------------------------------------- 1 | ckpt/ 2 | log/ 3 | wandb/ 4 | __pycache__/ 5 | *.gif 6 | result/ 7 | 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ICLR 2025] ClassDiffusion: More Aligned Personalization Tuning with Explicit Class Guidance 2 | 3 | Official imple. of ClassDiffusion: More Aligned Personalization Tuning with Explicit Class Guidance 4 | 5 | > Recent text-to-image customization works have been proven successful in generating images of given concepts by fine-tuning the diffusion models on a few examples. However, these methods tend to overfit the concepts, resulting in failure to create the concept under multiple conditions (_e.g._, headphone is missing when generating “a <sks> dog wearing a headphone”). Interestingly, we notice that the base model before fine-tuning exhibits the capability to compose the base concept with other elements (_e.g._, “a dog wearing a headphone”), implying that the compositional ability only disappears after personalization tuning. Inspired by this observation, we present ClassDiffusion, a simple technique that leverages a semantic preservation loss to explicitly regulate the concept space when learning the new concept. Despite its simplicity, this helps avoid semantic drift when fine-tuning on the target concepts. Extensive qualitative and quantitative experiments demonstrate that the use of semantic preservation loss effectively improves the compositional abilities of the fine-tune models. In response to the ineffective evaluation of CLIP-T metrics, we introduce BLIP2-T metric, a more equitable and effective evaluation metric for this particular domain. We also provide in-depth empirical study and theoretical analysis to better understand the role of the proposed loss. Lastly, we also extend our ClassDiffusion to personalized video generation, demonstrating its flexibility. 
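At its core, the semantic preservation loss keeps the text-encoder features of the personalized phrase (e.g. `a photo of a <new1> dog`, where `<new1>` is the learned modifier token) close to those of the plain class phrase (`a photo of a dog`), so the new token stays inside the class concept rather than drifting away from it. The snippet below is only a minimal illustrative sketch (the helper name and the cosine-distance form are assumptions); the actual loss used for training is implemented in `utils/semantic_preservation_loss.py` and applied in `train_class_diffusion.py`.

```python
import torch.nn.functional as F

def semantic_preservation_loss(text_encoder, tokenizer, personalized_prompt, class_prompt):
    # Illustrative sketch only: pull the pooled CLIP text feature of the
    # personalized phrase toward that of the plain class phrase.
    device = text_encoder.device
    ids_new = tokenizer(personalized_prompt, padding="longest", return_tensors="pt").input_ids.to(device)
    ids_cls = tokenizer(class_prompt, padding="longest", return_tensors="pt").input_ids.to(device)
    feat_new = text_encoder(ids_new).pooler_output           # gradients reach the learned token embedding
    feat_cls = text_encoder(ids_cls).pooler_output.detach()  # class phrase acts as a fixed anchor
    return 1.0 - F.cosine_similarity(feat_new, feat_cls, dim=-1).mean()

# Sketched usage (components loaded as in train_class_diffusion.py; spl_weight mirrors --spl_weight):
# from transformers import CLIPTextModel, CLIPTokenizer
# tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
# text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
# loss = denoising_loss + spl_weight * semantic_preservation_loss(
#     text_encoder, tokenizer, "a photo of a <new1> dog", "a photo of a dog")
```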
6 | 7 | --- 8 | 9 | **[ClassDiffusion: More Aligned Personalization Tuning with Explicit Class Guidance](https://arxiv.org/pdf/2405.17532)** 10 | 11 | [Jiannan Huang](https://rbrq03.github.io/), [Jun Hao Liew](https://scholar.google.com.sg/citations?user=8gm-CYYAAAAJ), [Hanshu Yan](https://hanshuyan.github.io), [Yuyang Yin](https://yuyangyin.github.io), [Yao Zhao](http://mepro.bjtu.edu.cn/zhaoyao/index.htm), [Humphrey Shi](https://www.humphreyshi.com/), [Yunchao Wei](https://weiyc.github.io/index.html) 12 | 13 | [![Project Website](https://img.shields.io/badge/Project-Website-orange)](https://classdiffusion.github.io/) 14 | [![arXiv](https://img.shields.io/badge/arXiv-2405.17532-b31b1b.svg)](https://arxiv.org/pdf/2405.17532) 15 | 16 |

17 | 18 |
19 | Our method can generate more aligned personalized images with explicit class guidance 20 |

21 | 22 | ## News 23 | 24 | - [23 Jan, 2025] 🎉 Our paper is accepted to ICLR2025! 25 | - [8 Jun. 2024] Code for BLIP2-T and Video Generation Released! 26 | - [3 Jun. 2024] Code Released! 27 | - [29 May. 2024] Paper Released! 28 | 29 | ## Code Usage 30 | 31 | **Set Up** 32 | 33 | ``` 34 | git clone https://github.com/Rbrq03/ClassDiffusion.git 35 | cd ClassDiffusion 36 | pip install -r requirements.txt 37 | ``` 38 | 39 | _Warning: Currently, ClassDiffusion doesn't support PEFT; please ensure PEFT is uninstalled in your environment, or check [PR](https://github.com/huggingface/diffusers/pull/7272). We will move forward with this PR merge soon._ 40 | 41 | ### Training 42 | 43 | **Single Concept** 44 | 45 | ``` 46 | bash scripts/train_single.sh 47 | ``` 48 | 49 | **Multiple Concepts** 50 | 51 | ``` 52 | bash scripts/train_multi.sh 53 | ``` 54 | 55 | ### Inference 56 | 57 | **Single Concept** 58 | 59 | ``` 60 | import torch 61 | from diffusers import DiffusionPipeline 62 | 63 | pipeline = DiffusionPipeline.from_pretrained( 64 | "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, 65 | ).to("cuda") 66 | pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin") 67 | pipeline.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin") 68 | 69 | image = pipeline( 70 | "<new1> dog swimming in the pool", 71 | num_inference_steps=100, 72 | guidance_scale=6.0, 73 | eta=1.0, 74 | ).images[0] 75 | image.save("dog.png") 76 | ``` 77 | 78 | **Multiple Concepts** 79 | 80 | ``` 81 | import torch 82 | from diffusers import DiffusionPipeline 83 | 84 | pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda") 85 | pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin") 86 | pipeline.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin") 87 | pipeline.load_textual_inversion("path-to-save-model", weight_name="<new2>.bin") 88 | 89 | image = pipeline( 90 | "a <new1> teddy bear sitting in front of a <new2> barn", 91 | num_inference_steps=100, 92 | guidance_scale=6.0, 93 | eta=1.0, 94 | ).images[0] 95 | image.save("multi-subject.png") 96 | ``` 97 | 98 | **BLIP2-T** 99 | 100 | You can use the following code: 101 | 102 | ``` 103 | from PIL import Image 104 | from utils.blip2t import BLIP2T 105 | 106 | blip2t = BLIP2T("Salesforce/blip-itm-large-coco", "cuda") 107 | 108 | prompt = "photo of a dog" 109 | image = Image.open("data/dog/00.jpg") 110 | 111 | score = blip2t.text_similarity(prompt, image)[0] 112 | print(score) 113 | ``` 114 | 115 | or 116 | 117 | ``` 118 | python blip2t.py 119 | ``` 120 | 121 | **Video Generation** 122 | 123 | ``` 124 | python videogen.py 125 | ``` 126 | 127 | **Ckpt for quick test** 128 | 129 | |Concept(s)|Weight| 130 | |--|--| 131 | |dog|[weight](https://drive.google.com/drive/folders/12KhBmFCUb2opotOQeAH0-dvW9XArwukt?usp=drive_link)| 132 | |bear+barn|[weight](https://drive.google.com/drive/folders/1VQTvszl2FqKhc-9YaKsRTRV3qlhmOqK4?usp=drive_link)| 133 | 134 | ## Results 135 | 136 | **Single Concept Results** 137 | 138 |

139 | 140 |
141 |

142 | 143 | **Multiple Concepts Results** 144 | 145 |

146 | 147 |
148 | 149 |

150 | 151 | **Video Generation Results** 152 | 153 |

154 | 155 |
156 | 157 |

158 | 159 | ## TODO 160 | 161 | - [x] Training Code for ClassDiffusion 162 | - [x] Inference Code for ClassDiffusion 163 | - [x] Pipeline for BLIP2-T Score 164 | - [x] Inference Code for Video Generation with ClassDiffusion 165 | 166 | ## Citation 167 | 168 | If you make use of our work, please cite our paper. 169 | 170 | ```bibtex 171 | @article{huang2024classdiffusion, 172 | title={ClassDiffusion: More Aligned Personalization Tuning with Explicit Class Guidance}, 173 | author={Huang, Jiannan and Liew, Jun Hao and Yan, Hanshu and Yin, Yuyang and Zhao, Yao and Wei, Yunchao}, 174 | journal={arXiv preprint arXiv:2405.17532}, 175 | year={2024} 176 | } 177 | ``` 178 | 179 | ## Acknowledgement 180 | 181 | We thank the following repos for their excellent and well-documented codebases: 182 | 183 | - Diffusers: [https://github.com/huggingface/diffusers](https://github.com/huggingface/diffusers) 184 | - Custom Diffusion: [https://github.com/adobe-research/custom-diffusion](https://github.com/adobe-research/custom-diffusion) 185 | - Transformers: [https://github.com/huggingface/transformers](https://github.com/huggingface/transformers) 186 | - DreamBooth: [https://github.com/google/dreambooth](https://github.com/google/dreambooth) 187 | - AnimateDiff: [https://github.com/guoyww/AnimateDiff](https://github.com/guoyww/AnimateDiff) 188 | -------------------------------------------------------------------------------- /assert/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/.DS_Store -------------------------------------------------------------------------------- /assert/img/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/.DS_Store -------------------------------------------------------------------------------- /assert/img/multiple.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/multiple.png -------------------------------------------------------------------------------- /assert/img/singleconcept.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/singleconcept.png -------------------------------------------------------------------------------- /assert/img/story.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/story.jpg -------------------------------------------------------------------------------- /assert/img/video.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/assert/img/video.jpg -------------------------------------------------------------------------------- /blip2t.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from utils.blip2t import BLIP2T 3 | 4 | blip2t = BLIP2T("Salesforce/blip-itm-large-coco", "cpu") 5 | 6 | prompt = "photo of a dog" 7 | image =
Image.open("data/dog/00.jpg") 8 | 9 | score = blip2t.text_similarity(prompt, image)[0] 10 | print(score) 11 | -------------------------------------------------------------------------------- /config/concept.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instance_prompt": "a photo of a teddy bear", 4 | "class_prompt": "bear", 5 | "instance_data_dir": "data/teddybear" 6 | }, 7 | { 8 | "instance_prompt": "a photo of a barn", 9 | "class_prompt": "barn", 10 | "instance_data_dir": "data/barn" 11 | } 12 | ] 13 | -------------------------------------------------------------------------------- /data/barn/cora-leach-UlKxoeGP4Rc-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/cora-leach-UlKxoeGP4Rc-unsplash.jpg -------------------------------------------------------------------------------- /data/barn/emmy-sobieski-pPoR4vyOomY-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/emmy-sobieski-pPoR4vyOomY-unsplash.jpg -------------------------------------------------------------------------------- /data/barn/isaac-martin-ZCAiJPiO_98-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/isaac-martin-ZCAiJPiO_98-unsplash.jpg -------------------------------------------------------------------------------- /data/barn/jonah-brown-eYAsBnNISa4-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/jonah-brown-eYAsBnNISa4-unsplash.jpg -------------------------------------------------------------------------------- /data/barn/melissa-cronin-P8Bfpi1HLIE-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/melissa-cronin-P8Bfpi1HLIE-unsplash.jpg -------------------------------------------------------------------------------- /data/barn/pedro-lastra-hXunh-ivkPc-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/pedro-lastra-hXunh-ivkPc-unsplash.jpg -------------------------------------------------------------------------------- /data/barn/sarat-karumuri-vs1voV4jhpU-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/barn/sarat-karumuri-vs1voV4jhpU-unsplash.jpg -------------------------------------------------------------------------------- /data/bear/marina-shatskih-6MDi8o6VYHg-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-6MDi8o6VYHg-unsplash.jpg -------------------------------------------------------------------------------- 
/data/bear/marina-shatskih-B3Jnug3rKrs-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-B3Jnug3rKrs-unsplash.jpg -------------------------------------------------------------------------------- /data/bear/marina-shatskih-BYZdCQDSNTY-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-BYZdCQDSNTY-unsplash.jpg -------------------------------------------------------------------------------- /data/bear/marina-shatskih-F_imBoYwuTk-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-F_imBoYwuTk-unsplash.jpg -------------------------------------------------------------------------------- /data/bear/marina-shatskih-kBo2MFJz2QU-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-kBo2MFJz2QU-unsplash.jpg -------------------------------------------------------------------------------- /data/bear/marina-shatskih-nKoIXtEfkZ8-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-nKoIXtEfkZ8-unsplash.jpg -------------------------------------------------------------------------------- /data/bear/marina-shatskih-sIzj2poobss-unsplash.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/bear/marina-shatskih-sIzj2poobss-unsplash.jpg -------------------------------------------------------------------------------- /data/dog/00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/00.jpg -------------------------------------------------------------------------------- /data/dog/01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/01.jpg -------------------------------------------------------------------------------- /data/dog/02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/02.jpg -------------------------------------------------------------------------------- /data/dog/03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/03.jpg -------------------------------------------------------------------------------- /data/dog/04.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Rbrq03/ClassDiffusion/06a1d2ec201a78c461209e95c28da8932c2b8663/data/dog/04.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.24.1 2 | torchvision 3 | transformers>=4.25.1 4 | ftfy 5 | tensorboard 6 | Jinja2 7 | git+https://github.com/huggingface/diffusers -------------------------------------------------------------------------------- /scripts/train_multi.sh: -------------------------------------------------------------------------------- 1 | export CLS_TOKEN1="teddybear" 2 | export CLS_TOKEN2="barn" 3 | export MODEL_NAME="runwayml/stable-diffusion-v1-5" 4 | export OUTPUT_DIR="ckpt/${CLS_TOKEN1}_${CLS_TOKEN2}" 5 | 6 | accelerate launch train_class_diffusion.py \ 7 | --pretrained_model_name_or_path=$MODEL_NAME \ 8 | --output_dir=$OUTPUT_DIR \ 9 | --concepts_list=config/concept.json \ 10 | --resolution=512 \ 11 | --train_batch_size=2 \ 12 | --learning_rate=1e-5 \ 13 | --lr_warmup_steps=0 \ 14 | --max_train_steps=800 \ 15 | --scale_lr \ 16 | --hflip \ 17 | --modifier_token "<new1>+<new2>" \ 18 | --use_spl \ 19 | --spl_weight=1.0 \ 20 | --cls_token "${CLS_TOKEN1}+${CLS_TOKEN2}" \ 21 | --no_safe_serialization \ 22 | --report_to "wandb" \ 23 | --validation_steps=25 \ 24 | --validation_prompt="a ${CLS_TOKEN1} sitting in front of a ${CLS_TOKEN2}" \ -------------------------------------------------------------------------------- /scripts/train_single.sh: -------------------------------------------------------------------------------- 1 | export CLS_TOKEN="dog" 2 | export MODEL_NAME="runwayml/stable-diffusion-v1-5" 3 | export OUTPUT_DIR="./ckpt/${CLS_TOKEN}_cls" 4 | export INSTANCE_DIR="./data/${CLS_TOKEN}" 5 | 6 | accelerate launch train_class_diffusion.py \ 7 | --pretrained_model_name_or_path=$MODEL_NAME \ 8 | --instance_data_dir=$INSTANCE_DIR \ 9 | --output_dir=$OUTPUT_DIR \ 10 | --instance_prompt="a photo of a ${CLS_TOKEN}" \ 11 | --resolution=512 \ 12 | --train_batch_size=2 \ 13 | --learning_rate=1e-5 \ 14 | --lr_warmup_steps=0 \ 15 | --max_train_steps=500 \ 16 | --scale_lr \ 17 | --hflip \ 18 | --modifier_token "<new1>" \ 19 | --no_safe_serialization \ 20 | --use_spl \ 21 | --spl_weight=1 \ 22 | --cls_token "${CLS_TOKEN}" \ 23 | --report_to "wandb" \ 24 | --validation_steps=50 \ 25 | --validation_prompt="a ${CLS_TOKEN} is swimming" \ 26 | --tracker_name "custom-diffusion-multi" \ 27 | -------------------------------------------------------------------------------- /train_class_diffusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 Class Diffusion authors from BJTU and the HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import itertools 18 | import json 19 | import logging 20 | import math 21 | import os 22 | import time 23 | import random 24 | import shutil 25 | import warnings 26 | from pathlib import Path 27 | import types 28 | import numpy as np 29 | import safetensors 30 | import torch 31 | import torch.nn.functional as F 32 | import torch.utils.checkpoint 33 | import transformers 34 | from accelerate import Accelerator 35 | from accelerate.logging import get_logger 36 | from accelerate.utils import ProjectConfiguration, set_seed 37 | from huggingface_hub import HfApi, create_repo 38 | from huggingface_hub.utils import insecure_hashlib 39 | from packaging import version 40 | from PIL import Image 41 | from torch.utils.data import Dataset 42 | from torchvision import transforms 43 | from tqdm.auto import tqdm 44 | from transformers import AutoTokenizer, PretrainedConfig 45 | import torch.nn.functional as F 46 | 47 | import subprocess 48 | import diffusers 49 | from diffusers import ( 50 | AutoencoderKL, 51 | DDPMScheduler, 52 | DiffusionPipeline, 53 | DPMSolverMultistepScheduler, 54 | UNet2DConditionModel, 55 | ) 56 | from diffusers.loaders import AttnProcsLayers 57 | from diffusers.models.attention_processor import ( 58 | CustomDiffusionAttnProcessor, 59 | CustomDiffusionAttnProcessor2_0, 60 | CustomDiffusionXFormersAttnProcessor, 61 | ) 62 | from diffusers.optimization import get_scheduler 63 | from diffusers.utils import check_min_version, is_wandb_available 64 | from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card 65 | from diffusers.utils.import_utils import is_xformers_available 66 | from utils.semantic_preservation_loss import forward_with_custom_embeddings, SPL 67 | 68 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 69 | check_min_version("0.27.0.dev0") 70 | 71 | logger = get_logger(__name__) 72 | 73 | 74 | def freeze_params(params): 75 | for param in params: 76 | param.requires_grad = False 77 | 78 | 79 | def save_model_card( 80 | repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None 81 | ): 82 | img_str = "" 83 | for i, image in enumerate(images): 84 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 85 | img_str += f"![img_{i}](./image_{i}.png)\n" 86 | 87 | model_description = f""" 88 | # Class Diffusion - {repo_id} 89 | 90 | These are Class Diffusion adaption weights for {base_model}. The weights were trained on {prompt} using [Class Diffusion](https://www.cs.cmu.edu/~custom-diffusion). You can find some example images in the following. \n 91 | {img_str} 92 | 93 | \nFor more details on the training, please follow [this link](https://github.com/huggingface/diffusers/blob/main/examples/custom_diffusion). 
94 | """ 95 | model_card = load_or_create_model_card( 96 | repo_id_or_path=repo_id, 97 | from_training=True, 98 | license="creativeml-openrail-m", 99 | base_model=base_model, 100 | prompt=prompt, 101 | model_description=model_description, 102 | inference=True, 103 | ) 104 | 105 | tags = [ 106 | "text-to-image", 107 | "diffusers", 108 | "stable-diffusion", 109 | "stable-diffusion-diffusers", 110 | "custom-diffusion", 111 | "diffusers-training", 112 | ] 113 | model_card = populate_model_card(model_card, tags=tags) 114 | 115 | model_card.save(os.path.join(repo_folder, "README.md")) 116 | 117 | 118 | def import_model_class_from_model_name_or_path( 119 | pretrained_model_name_or_path: str, revision: str 120 | ): 121 | text_encoder_config = PretrainedConfig.from_pretrained( 122 | pretrained_model_name_or_path, 123 | subfolder="text_encoder", 124 | revision=revision, 125 | ) 126 | model_class = text_encoder_config.architectures[0] 127 | 128 | if model_class == "CLIPTextModel": 129 | from transformers import CLIPTextModel 130 | 131 | return CLIPTextModel 132 | elif model_class == "RobertaSeriesModelWithTransformation": 133 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import ( 134 | RobertaSeriesModelWithTransformation, 135 | ) 136 | 137 | return RobertaSeriesModelWithTransformation 138 | else: 139 | raise ValueError(f"{model_class} is not supported.") 140 | 141 | 142 | def collate_fn(examples, with_prior_preservation): 143 | input_ids = [example["instance_prompt_ids"] for example in examples] 144 | pixel_values = [example["instance_images"] for example in examples] 145 | mask = [example["mask"] for example in examples] 146 | # Concat class and instance examples for prior preservation. 147 | # We do this to avoid doing two forward passes. 148 | if with_prior_preservation: 149 | input_ids += [example["class_prompt_ids"] for example in examples] 150 | pixel_values += [example["class_images"] for example in examples] 151 | mask += [example["class_mask"] for example in examples] 152 | 153 | input_ids = torch.cat(input_ids, dim=0) 154 | pixel_values = torch.stack(pixel_values) 155 | mask = torch.stack(mask) 156 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 157 | mask = mask.to(memory_format=torch.contiguous_format).float() 158 | 159 | batch = { 160 | "input_ids": input_ids, 161 | "pixel_values": pixel_values, 162 | "mask": mask.unsqueeze(1), 163 | } 164 | return batch 165 | 166 | 167 | class PromptDataset(Dataset): 168 | "A simple dataset to prepare the prompts to generate class images on multiple GPUs." 169 | 170 | def __init__(self, prompt, num_samples): 171 | self.prompt = prompt 172 | self.num_samples = num_samples 173 | 174 | def __len__(self): 175 | return self.num_samples 176 | 177 | def __getitem__(self, index): 178 | example = {} 179 | example["prompt"] = self.prompt 180 | example["index"] = index 181 | return example 182 | 183 | 184 | class ClassDiffusionDataset(Dataset): 185 | """ 186 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 187 | It pre-processes the images and the tokenizes prompts. 
188 | """ 189 | 190 | def __init__( 191 | self, 192 | concepts_list, 193 | tokenizer, 194 | size=512, 195 | mask_size=64, 196 | center_crop=False, 197 | with_prior_preservation=False, 198 | num_class_images=200, 199 | hflip=False, 200 | aug=True, 201 | ): 202 | self.size = size 203 | self.mask_size = mask_size 204 | self.center_crop = center_crop 205 | self.tokenizer = tokenizer 206 | self.interpolation = Image.BILINEAR 207 | self.aug = aug 208 | 209 | self.instance_images_path = [] 210 | self.class_images_path = [] 211 | self.with_prior_preservation = with_prior_preservation 212 | for concept in concepts_list: 213 | inst_img_path = [ 214 | (x, concept["instance_prompt"]) 215 | for x in Path(concept["instance_data_dir"]).iterdir() 216 | if x.is_file() 217 | ] 218 | self.instance_images_path.extend(inst_img_path) 219 | 220 | if with_prior_preservation: 221 | class_data_root = Path(concept["class_data_dir"]) 222 | if os.path.isdir(class_data_root): 223 | class_images_path = list(class_data_root.iterdir()) 224 | class_prompt = [ 225 | concept["class_prompt"] for _ in range(len(class_images_path)) 226 | ] 227 | else: 228 | with open(class_data_root, "r") as f: 229 | class_images_path = f.read().splitlines() 230 | with open(concept["class_prompt"], "r") as f: 231 | class_prompt = f.read().splitlines() 232 | 233 | class_img_path = list(zip(class_images_path, class_prompt)) 234 | self.class_images_path.extend(class_img_path[:num_class_images]) 235 | 236 | random.shuffle(self.instance_images_path) 237 | self.num_instance_images = len(self.instance_images_path) 238 | self.num_class_images = len(self.class_images_path) 239 | self._length = max(self.num_class_images, self.num_instance_images) 240 | self.flip = transforms.RandomHorizontalFlip(0.5 * hflip) 241 | 242 | self.image_transforms = transforms.Compose( 243 | [ 244 | self.flip, 245 | transforms.Resize( 246 | size, interpolation=transforms.InterpolationMode.BILINEAR 247 | ), 248 | ( 249 | transforms.CenterCrop(size) 250 | if center_crop 251 | else transforms.RandomCrop(size) 252 | ), 253 | transforms.ToTensor(), 254 | transforms.Normalize([0.5], [0.5]), 255 | ] 256 | ) 257 | 258 | def __len__(self): 259 | return self._length 260 | 261 | def preprocess(self, image, scale, resample): 262 | outer, inner = self.size, scale 263 | factor = self.size // self.mask_size 264 | if scale > self.size: 265 | outer, inner = scale, self.size 266 | top, left = np.random.randint(0, outer - inner + 1), np.random.randint( 267 | 0, outer - inner + 1 268 | ) 269 | image = image.resize((scale, scale), resample=resample) 270 | image = np.array(image).astype(np.uint8) 271 | image = (image / 127.5 - 1.0).astype(np.float32) 272 | instance_image = np.zeros((self.size, self.size, 3), dtype=np.float32) 273 | mask = np.zeros((self.size // factor, self.size // factor)) 274 | if scale > self.size: 275 | instance_image = image[top : top + inner, left : left + inner, :] 276 | mask = np.ones((self.size // factor, self.size // factor)) 277 | else: 278 | instance_image[top : top + inner, left : left + inner, :] = image 279 | mask[ 280 | top // factor + 1 : (top + scale) // factor - 1, 281 | left // factor + 1 : (left + scale) // factor - 1, 282 | ] = 1.0 283 | return instance_image, mask 284 | 285 | def __getitem__(self, index): 286 | example = {} 287 | instance_image, instance_prompt = self.instance_images_path[ 288 | index % self.num_instance_images 289 | ] 290 | instance_image = Image.open(instance_image) 291 | if not instance_image.mode == "RGB": 292 | instance_image = 
instance_image.convert("RGB") 293 | instance_image = self.flip(instance_image) 294 | 295 | # apply resize augmentation and create a valid image region mask 296 | random_scale = self.size 297 | if self.aug: 298 | random_scale = ( 299 | np.random.randint(self.size // 3, self.size + 1) 300 | if np.random.uniform() < 0.66 301 | else np.random.randint(int(1.2 * self.size), int(1.4 * self.size)) 302 | ) 303 | instance_image, mask = self.preprocess( 304 | instance_image, random_scale, self.interpolation 305 | ) 306 | 307 | if random_scale < 0.6 * self.size: 308 | instance_prompt = ( 309 | np.random.choice(["a far away ", "very small "]) + instance_prompt 310 | ) 311 | elif random_scale > self.size: 312 | instance_prompt = ( 313 | np.random.choice(["zoomed in ", "close up "]) + instance_prompt 314 | ) 315 | 316 | example["instance_images"] = torch.from_numpy(instance_image).permute(2, 0, 1) 317 | example["mask"] = torch.from_numpy(mask) 318 | example["instance_prompt_ids"] = self.tokenizer( 319 | instance_prompt, 320 | truncation=True, 321 | padding="max_length", 322 | max_length=self.tokenizer.model_max_length, 323 | return_tensors="pt", 324 | ).input_ids 325 | 326 | if self.with_prior_preservation: 327 | class_image, class_prompt = self.class_images_path[ 328 | index % self.num_class_images 329 | ] 330 | class_image = Image.open(class_image) 331 | if not class_image.mode == "RGB": 332 | class_image = class_image.convert("RGB") 333 | example["class_images"] = self.image_transforms(class_image) 334 | example["class_mask"] = torch.ones_like(example["mask"]) 335 | example["class_prompt_ids"] = self.tokenizer( 336 | class_prompt, 337 | truncation=True, 338 | padding="max_length", 339 | max_length=self.tokenizer.model_max_length, 340 | return_tensors="pt", 341 | ).input_ids 342 | 343 | return example 344 | 345 | 346 | def save_new_embed( 347 | text_encoder, 348 | modifier_token_id, 349 | accelerator, 350 | args, 351 | output_dir, 352 | safe_serialization=True, 353 | ): 354 | """Saves the new token embeddings from the text encoder.""" 355 | logger.info("Saving embeddings") 356 | learned_embeds = ( 357 | accelerator.unwrap_model(text_encoder).get_input_embeddings().weight 358 | ) 359 | for x, y in zip(modifier_token_id, args.modifier_token): 360 | learned_embeds_dict = {} 361 | learned_embeds_dict[y] = learned_embeds[x] 362 | filename = f"{output_dir}/{y}.bin" 363 | 364 | if safe_serialization: 365 | safetensors.torch.save_file( 366 | learned_embeds_dict, filename, metadata={"format": "pt"} 367 | ) 368 | else: 369 | torch.save(learned_embeds_dict, filename) 370 | 371 | 372 | def parse_args(input_args=None): 373 | parser = argparse.ArgumentParser(description="Class Diffusion training script.") 374 | parser.add_argument( 375 | "--pretrained_model_name_or_path", 376 | type=str, 377 | default=None, 378 | required=True, 379 | help="Path to pretrained model or model identifier from huggingface.co/models.", 380 | ) 381 | parser.add_argument( 382 | "--revision", 383 | type=str, 384 | default=None, 385 | required=False, 386 | help="Revision of pretrained model identifier from huggingface.co/models.", 387 | ) 388 | parser.add_argument( 389 | "--variant", 390 | type=str, 391 | default=None, 392 | help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' 
fp16", 393 | ) 394 | parser.add_argument( 395 | "--tokenizer_name", 396 | type=str, 397 | default=None, 398 | help="Pretrained tokenizer name or path if not the same as model_name", 399 | ) 400 | parser.add_argument( 401 | "--instance_data_dir", 402 | type=str, 403 | default=None, 404 | help="A folder containing the training data of instance images.", 405 | ) 406 | parser.add_argument( 407 | "--class_data_dir", 408 | type=str, 409 | default=None, 410 | help="A folder containing the training data of class images.", 411 | ) 412 | parser.add_argument( 413 | "--instance_prompt", 414 | type=str, 415 | default=None, 416 | help="The prompt with identifier specifying the instance", 417 | ) 418 | parser.add_argument( 419 | "--class_prompt", 420 | type=str, 421 | default=None, 422 | help="The prompt to specify images in the same class as provided instance images.", 423 | ) 424 | parser.add_argument( 425 | "--validation_prompt", 426 | type=str, 427 | default=None, 428 | help="A prompt that is used during validation to verify that the model is learning.", 429 | ) 430 | parser.add_argument( 431 | "--num_validation_images", 432 | type=int, 433 | default=2, 434 | help="Number of images that should be generated during validation with `validation_prompt`.", 435 | ) 436 | parser.add_argument( 437 | "--validation_steps", 438 | type=int, 439 | default=50, 440 | help=( 441 | "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" 442 | " `args.validation_prompt` multiple times: `args.num_validation_images`." 443 | ), 444 | ) 445 | parser.add_argument( 446 | "--with_prior_preservation", 447 | default=False, 448 | action="store_true", 449 | help="Flag to add prior preservation loss.", 450 | ) 451 | parser.add_argument( 452 | "--real_prior", 453 | default=False, 454 | action="store_true", 455 | help="real images as prior.", 456 | ) 457 | parser.add_argument( 458 | "--prior_loss_weight", 459 | type=float, 460 | default=1.0, 461 | help="The weight of prior preservation loss.", 462 | ) 463 | parser.add_argument( 464 | "--num_class_images", 465 | type=int, 466 | default=200, 467 | help=( 468 | "Minimal class images for prior preservation loss. If there are not enough images already present in" 469 | " class_data_dir, additional images will be sampled with class_prompt." 470 | ), 471 | ) 472 | parser.add_argument( 473 | "--output_dir", 474 | type=str, 475 | default="custom-diffusion-model", 476 | help="The output directory where the model predictions and checkpoints will be written.", 477 | ) 478 | parser.add_argument( 479 | "--seed", type=int, default=42, help="A seed for reproducible training." 480 | ) 481 | parser.add_argument( 482 | "--resolution", 483 | type=int, 484 | default=512, 485 | help=( 486 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 487 | " resolution" 488 | ), 489 | ) 490 | parser.add_argument( 491 | "--center_crop", 492 | default=False, 493 | action="store_true", 494 | help=( 495 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 496 | " cropped. The images will be resized to the resolution first before cropping." 
497 | ), 498 | ) 499 | parser.add_argument( 500 | "--train_batch_size", 501 | type=int, 502 | default=4, 503 | help="Batch size (per device) for the training dataloader.", 504 | ) 505 | parser.add_argument( 506 | "--sample_batch_size", 507 | type=int, 508 | default=4, 509 | help="Batch size (per device) for sampling images.", 510 | ) 511 | parser.add_argument("--num_train_epochs", type=int, default=1) 512 | parser.add_argument( 513 | "--max_train_steps", 514 | type=int, 515 | default=None, 516 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 517 | ) 518 | parser.add_argument( 519 | "--checkpointing_steps", 520 | type=int, 521 | default=250, 522 | help=( 523 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 524 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 525 | " training using `--resume_from_checkpoint`." 526 | ), 527 | ) 528 | parser.add_argument( 529 | "--checkpoints_total_limit", 530 | type=int, 531 | default=None, 532 | help=("Max number of checkpoints to store."), 533 | ) 534 | parser.add_argument( 535 | "--resume_from_checkpoint", 536 | type=str, 537 | default=None, 538 | help=( 539 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 540 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 541 | ), 542 | ) 543 | parser.add_argument( 544 | "--gradient_accumulation_steps", 545 | type=int, 546 | default=1, 547 | help="Number of updates steps to accumulate before performing a backward/update pass.", 548 | ) 549 | parser.add_argument( 550 | "--gradient_checkpointing", 551 | action="store_true", 552 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 553 | ) 554 | parser.add_argument( 555 | "--learning_rate", 556 | type=float, 557 | default=1e-5, 558 | help="Initial learning rate (after the potential warmup period) to use.", 559 | ) 560 | parser.add_argument( 561 | "--scale_lr", 562 | action="store_true", 563 | default=False, 564 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 565 | ) 566 | parser.add_argument( 567 | "--dataloader_num_workers", 568 | type=int, 569 | default=2, 570 | help=( 571 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 572 | ), 573 | ) 574 | parser.add_argument( 575 | "--freeze_model", 576 | type=str, 577 | default="crossattn_kv", 578 | choices=["crossattn_kv", "crossattn"], 579 | help="crossattn to enable fine-tuning of all params in the cross attention", 580 | ) 581 | parser.add_argument( 582 | "--lr_scheduler", 583 | type=str, 584 | default="constant", 585 | help=( 586 | 'The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 587 | ' "constant", "constant_with_warmup"]' 588 | ), 589 | ) 590 | parser.add_argument( 591 | "--lr_warmup_steps", 592 | type=int, 593 | default=500, 594 | help="Number of steps for the warmup in the lr scheduler.", 595 | ) 596 | parser.add_argument( 597 | "--use_8bit_adam", 598 | action="store_true", 599 | help="Whether or not to use 8-bit Adam from bitsandbytes.", 600 | ) 601 | parser.add_argument( 602 | "--adam_beta1", 603 | type=float, 604 | default=0.9, 605 | help="The beta1 parameter for the Adam optimizer.", 606 | ) 607 | parser.add_argument( 608 | "--adam_beta2", 609 | type=float, 610 | default=0.999, 611 | help="The beta2 parameter for the Adam optimizer.", 612 | ) 613 | parser.add_argument( 614 | "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use." 615 | ) 616 | parser.add_argument( 617 | "--adam_epsilon", 618 | type=float, 619 | default=1e-08, 620 | help="Epsilon value for the Adam optimizer", 621 | ) 622 | parser.add_argument( 623 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." 624 | ) 625 | parser.add_argument( 626 | "--push_to_hub", 627 | action="store_true", 628 | help="Whether or not to push the model to the Hub.", 629 | ) 630 | parser.add_argument( 631 | "--hub_token", 632 | type=str, 633 | default=None, 634 | help="The token to use to push to the Model Hub.", 635 | ) 636 | parser.add_argument( 637 | "--hub_model_id", 638 | type=str, 639 | default=None, 640 | help="The name of the repository to keep in sync with the local `output_dir`.", 641 | ) 642 | parser.add_argument( 643 | "--logging_dir", 644 | type=str, 645 | default="logs", 646 | help=( 647 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 648 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 649 | ), 650 | ) 651 | parser.add_argument( 652 | "--allow_tf32", 653 | action="store_true", 654 | help=( 655 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 656 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 657 | ), 658 | ) 659 | parser.add_argument( 660 | "--report_to", 661 | type=str, 662 | default="tensorboard", 663 | help=( 664 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 665 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 666 | ), 667 | ) 668 | parser.add_argument( 669 | "--mixed_precision", 670 | type=str, 671 | default=None, 672 | choices=["no", "fp16", "bf16"], 673 | help=( 674 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 675 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 676 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 677 | ), 678 | ) 679 | parser.add_argument( 680 | "--prior_generation_precision", 681 | type=str, 682 | default=None, 683 | choices=["no", "fp32", "fp16", "bf16"], 684 | help=( 685 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 686 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." 
687 | ), 688 | ) 689 | parser.add_argument( 690 | "--concepts_list", 691 | type=str, 692 | default=None, 693 | help="Path to json containing multiple concepts, will overwrite parameters like instance_prompt, class_prompt, etc.", 694 | ) 695 | parser.add_argument( 696 | "--local_rank", 697 | type=int, 698 | default=-1, 699 | help="For distributed training: local_rank", 700 | ) 701 | parser.add_argument( 702 | "--enable_xformers_memory_efficient_attention", 703 | action="store_true", 704 | help="Whether or not to use xformers.", 705 | ) 706 | parser.add_argument( 707 | "--set_grads_to_none", 708 | action="store_true", 709 | help=( 710 | "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" 711 | " behaviors, so disable this argument if it causes any problems. More info:" 712 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 713 | ), 714 | ) 715 | parser.add_argument( 716 | "--modifier_token", 717 | type=str, 718 | default=None, 719 | help="A token to use as a modifier for the concept.", 720 | ) 721 | parser.add_argument( 722 | "--initializer_token", 723 | type=str, 724 | default="ktn+pll+ucd", 725 | help="A token to use as initializer word.", 726 | ) 727 | parser.add_argument( 728 | "--hflip", action="store_true", help="Apply horizontal flip data augmentation." 729 | ) 730 | parser.add_argument( 731 | "--noaug", 732 | action="store_true", 733 | help="Dont apply augmentation during data augmentation when this flag is enabled.", 734 | ) 735 | parser.add_argument( 736 | "--no_safe_serialization", 737 | action="store_true", 738 | help="If specified save the checkpoint not in `safetensors` format, but in original PyTorch format instead.", 739 | ) 740 | parser.add_argument("--use_spl", action="store_true") 741 | parser.add_argument("--spl_weight", type=float, default=1.0) 742 | parser.add_argument("--cls_token", type=str) 743 | parser.add_argument("--tracker_name", type=str, default="custom-diffusion") 744 | 745 | if input_args is not None: 746 | args = parser.parse_args(input_args) 747 | else: 748 | args = parser.parse_args() 749 | 750 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 751 | if env_local_rank != -1 and env_local_rank != args.local_rank: 752 | args.local_rank = env_local_rank 753 | 754 | if args.with_prior_preservation: 755 | if args.concepts_list is None: 756 | if args.class_data_dir is None: 757 | raise ValueError("You must specify a data directory for class images.") 758 | if args.class_prompt is None: 759 | raise ValueError("You must specify prompt for class images.") 760 | else: 761 | # logger is not available yet 762 | if args.class_data_dir is not None: 763 | warnings.warn( 764 | "You need not use --class_data_dir without --with_prior_preservation." 765 | ) 766 | if args.class_prompt is not None: 767 | warnings.warn( 768 | "You need not use --class_prompt without --with_prior_preservation." 769 | ) 770 | 771 | return args 772 | 773 | 774 | def main(args): 775 | if args.report_to == "wandb" and args.hub_token is not None: 776 | raise ValueError( 777 | "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." 778 | " Please use `huggingface-cli login` to authenticate with the Hub." 
779 | ) 780 | 781 | logging_dir = Path(args.output_dir, args.logging_dir) 782 | 783 | accelerator_project_config = ProjectConfiguration( 784 | project_dir=args.output_dir, logging_dir=logging_dir 785 | ) 786 | 787 | accelerator = Accelerator( 788 | gradient_accumulation_steps=args.gradient_accumulation_steps, 789 | mixed_precision=args.mixed_precision, 790 | log_with=args.report_to, 791 | project_config=accelerator_project_config, 792 | ) 793 | 794 | if args.report_to == "wandb": 795 | if not is_wandb_available(): 796 | raise ImportError( 797 | "Make sure to install wandb if you want to use it for logging during training." 798 | ) 799 | import wandb 800 | 801 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 802 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 803 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 804 | # Make one log on every process with the configuration for debugging. 805 | logging.basicConfig( 806 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 807 | datefmt="%m/%d/%Y %H:%M:%S", 808 | level=logging.INFO, 809 | ) 810 | logger.info(accelerator.state, main_process_only=False) 811 | if accelerator.is_local_main_process: 812 | transformers.utils.logging.set_verbosity_warning() 813 | diffusers.utils.logging.set_verbosity_info() 814 | else: 815 | transformers.utils.logging.set_verbosity_error() 816 | diffusers.utils.logging.set_verbosity_error() 817 | 818 | # We need to initialize the trackers we use, and also store our configuration. 819 | # The trackers initializes automatically on the main process. 820 | if accelerator.is_main_process: 821 | accelerator.init_trackers(args.tracker_name, config=vars(args)) 822 | 823 | # If passed along, set the training seed now. 824 | if args.seed is not None: 825 | set_seed(args.seed) 826 | if args.concepts_list is None: 827 | args.concepts_list = [ 828 | { 829 | "instance_prompt": args.instance_prompt, 830 | "class_prompt": args.class_prompt, 831 | "instance_data_dir": args.instance_data_dir, 832 | "class_data_dir": args.class_data_dir, 833 | } 834 | ] 835 | else: 836 | with open(args.concepts_list, "r") as f: 837 | args.concepts_list = json.load(f) 838 | 839 | # Generate class images if prior preservation is enabled. 
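# When --real_prior is set, previously retrieved real class images are expected on disk
# (the assertions below point to retrieve.py) and caption.txt / images.txt are wired in as
# the class prompt / class data source; otherwise, any shortfall below --num_class_images
# is generated with the base pipeline from the class prompt and saved into class_data_dir
# with hashed filenames.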
840 | if args.with_prior_preservation: 841 | for i, concept in enumerate(args.concepts_list): 842 | class_images_dir = Path(concept["class_data_dir"]) 843 | if not class_images_dir.exists(): 844 | class_images_dir.mkdir(parents=True, exist_ok=True) 845 | if args.real_prior: 846 | assert ( 847 | class_images_dir / "images" 848 | ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" 849 | assert ( 850 | len(list((class_images_dir / "images").iterdir())) 851 | == args.num_class_images 852 | ), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" 853 | assert ( 854 | class_images_dir / "caption.txt" 855 | ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" 856 | assert ( 857 | class_images_dir / "images.txt" 858 | ).exists(), f"Please run: python retrieve.py --class_prompt \"{concept['class_prompt']}\" --class_data_dir {class_images_dir} --num_class_images {args.num_class_images}" 859 | concept["class_prompt"] = os.path.join(class_images_dir, "caption.txt") 860 | concept["class_data_dir"] = os.path.join(class_images_dir, "images.txt") 861 | args.concepts_list[i] = concept 862 | accelerator.wait_for_everyone() 863 | else: 864 | cur_class_images = len(list(class_images_dir.iterdir())) 865 | 866 | if cur_class_images < args.num_class_images: 867 | torch_dtype = ( 868 | torch.float16 869 | if accelerator.device.type == "cuda" 870 | else torch.float32 871 | ) 872 | if args.prior_generation_precision == "fp32": 873 | torch_dtype = torch.float32 874 | elif args.prior_generation_precision == "fp16": 875 | torch_dtype = torch.float16 876 | elif args.prior_generation_precision == "bf16": 877 | torch_dtype = torch.bfloat16 878 | pipeline = DiffusionPipeline.from_pretrained( 879 | args.pretrained_model_name_or_path, 880 | torch_dtype=torch_dtype, 881 | safety_checker=None, 882 | revision=args.revision, 883 | variant=args.variant, 884 | ) 885 | pipeline.set_progress_bar_config(disable=True) 886 | 887 | num_new_images = args.num_class_images - cur_class_images 888 | logger.info(f"Number of class images to sample: {num_new_images}.") 889 | 890 | sample_dataset = PromptDataset( 891 | concept["class_prompt"], num_new_images 892 | ) 893 | sample_dataloader = torch.utils.data.DataLoader( 894 | sample_dataset, batch_size=args.sample_batch_size 895 | ) 896 | 897 | sample_dataloader = accelerator.prepare(sample_dataloader) 898 | pipeline.to(accelerator.device) 899 | 900 | for example in tqdm( 901 | sample_dataloader, 902 | desc="Generating class images", 903 | disable=not accelerator.is_local_main_process, 904 | ): 905 | images = pipeline(example["prompt"]).images 906 | 907 | for i, image in enumerate(images): 908 | hash_image = insecure_hashlib.sha1( 909 | image.tobytes() 910 | ).hexdigest() 911 | image_filename = ( 912 | class_images_dir 913 | / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" 914 | ) 915 | image.save(image_filename) 916 | 917 | del pipeline 918 | if torch.cuda.is_available(): 919 | torch.cuda.empty_cache() 920 | 921 | # Handle the repository creation 922 | if accelerator.is_main_process: 923 | if args.output_dir is not None: 924 | os.makedirs(args.output_dir, exist_ok=True) 925 | 926 | if args.push_to_hub: 927 | repo_id = create_repo( 928 | 
repo_id=args.hub_model_id or Path(args.output_dir).name, 929 | exist_ok=True, 930 | token=args.hub_token, 931 | ).repo_id 932 | 933 | # Load the tokenizer 934 | if args.tokenizer_name: 935 | tokenizer = AutoTokenizer.from_pretrained( 936 | args.tokenizer_name, 937 | revision=args.revision, 938 | use_fast=False, 939 | ) 940 | elif args.pretrained_model_name_or_path: 941 | tokenizer = AutoTokenizer.from_pretrained( 942 | args.pretrained_model_name_or_path, 943 | subfolder="tokenizer", 944 | revision=args.revision, 945 | use_fast=False, 946 | ) 947 | 948 | # import correct text encoder class 949 | text_encoder_cls = import_model_class_from_model_name_or_path( 950 | args.pretrained_model_name_or_path, args.revision 951 | ) 952 | 953 | # Load scheduler and models 954 | noise_scheduler = DDPMScheduler.from_pretrained( 955 | args.pretrained_model_name_or_path, subfolder="scheduler" 956 | ) 957 | text_encoder = text_encoder_cls.from_pretrained( 958 | args.pretrained_model_name_or_path, 959 | subfolder="text_encoder", 960 | revision=args.revision, 961 | variant=args.variant, 962 | ) 963 | vae = AutoencoderKL.from_pretrained( 964 | args.pretrained_model_name_or_path, 965 | subfolder="vae", 966 | revision=args.revision, 967 | variant=args.variant, 968 | ) 969 | unet = UNet2DConditionModel.from_pretrained( 970 | args.pretrained_model_name_or_path, 971 | subfolder="unet", 972 | revision=args.revision, 973 | variant=args.variant, 974 | ) 975 | 976 | # Add forward_with_custom_embeddings to TextModel 977 | text_encoder.text_model.forward_with_custom_embeddings = types.MethodType( 978 | forward_with_custom_embeddings, text_encoder.text_model 979 | ) 980 | 981 | # Adding a modifier token which is optimized #### 982 | # Code taken from https://github.com/huggingface/diffusers/blob/main/examples/textual_inversion/textual_inversion.py 983 | modifier_token_id = [] 984 | initializer_token_id = [] 985 | if args.modifier_token is not None: 986 | args.modifier_token = args.modifier_token.split("+") 987 | args.initializer_token = args.initializer_token.split("+") 988 | if len(args.modifier_token) > len(args.initializer_token): 989 | raise ValueError( 990 | "You must specify + separated initializer token for each modifier token." 991 | ) 992 | for modifier_token, initializer_token in zip( 993 | args.modifier_token, args.initializer_token[: len(args.modifier_token)] 994 | ): 995 | # Add the placeholder token in tokenizer 996 | num_added_tokens = tokenizer.add_tokens(modifier_token) 997 | if num_added_tokens == 0: 998 | raise ValueError( 999 | f"The tokenizer already contains the token {modifier_token}. Please pass a different" 1000 | " `modifier_token` that is not already in the tokenizer." 
1001 | ) 1002 | 1003 | # Convert the initializer_token, placeholder_token to ids 1004 | token_ids = tokenizer.encode([initializer_token], add_special_tokens=False) 1005 | # Check if initializer_token is a single token or a sequence of tokens 1006 | if len(token_ids) > 1: 1007 | raise ValueError("The initializer token must be a single token.") 1008 | 1009 | initializer_token_id.append(token_ids[0]) 1010 | modifier_token_id.append(tokenizer.convert_tokens_to_ids(modifier_token)) 1011 | 1012 | # Resize the token embeddings as we are adding new special tokens to the tokenizer 1013 | text_encoder.resize_token_embeddings(len(tokenizer)) 1014 | 1015 | # Initialise the newly added placeholder token with the embeddings of the initializer token 1016 | token_embeds = text_encoder.get_input_embeddings().weight.data 1017 | for x, y in zip(modifier_token_id, initializer_token_id): 1018 | token_embeds[x] = token_embeds[y] 1019 | 1020 | print(modifier_token_id) 1021 | 1022 | if args.cls_token is not None: 1023 | 1024 | args.cls_token = args.cls_token.split("+") 1025 | 1026 | if len(args.modifier_token) != len(args.cls_token): 1027 | raise ValueError( 1028 | "num of Modifier token shoule be the same as the class token" 1029 | ) 1030 | 1031 | modifier_cls_text_inputs = [] 1032 | cls_text_inputs = [] 1033 | 1034 | for modifier_token, cls_token in zip(args.modifier_token, args.cls_token): 1035 | print(modifier_token, cls_token) 1036 | modifier_cls_text_input = tokenizer( 1037 | "a photo of a " + modifier_token + " " + cls_token, 1038 | padding="longest", 1039 | return_tensors="pt", 1040 | ) 1041 | 1042 | cls_text_input = tokenizer( 1043 | "a photo of a " + cls_token, 1044 | padding="longest", 1045 | return_tensors="pt", 1046 | ) 1047 | 1048 | modifier_cls_text_inputs.append(modifier_cls_text_input) 1049 | cls_text_inputs.append(cls_text_input) 1050 | 1051 | use_attention_mask = ( 1052 | hasattr(text_encoder.config, "use_attention_mask") 1053 | and text_encoder.config.use_attention_mask 1054 | ) 1055 | 1056 | spl_loss = SPL( 1057 | text_encoder=text_encoder, use_attention_mask=use_attention_mask 1058 | ) 1059 | 1060 | # Freeze all parameters except for the token embeddings in text encoder 1061 | params_to_freeze = itertools.chain( 1062 | text_encoder.text_model.encoder.parameters(), 1063 | text_encoder.text_model.final_layer_norm.parameters(), 1064 | text_encoder.text_model.embeddings.position_embedding.parameters(), 1065 | ) 1066 | freeze_params(params_to_freeze) 1067 | 1068 | ######################################################## 1069 | ######################################################## 1070 | 1071 | vae.requires_grad_(False) 1072 | if args.modifier_token is None: 1073 | text_encoder.requires_grad_(False) 1074 | unet.requires_grad_(False) 1075 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 1076 | # as these models are only used for inference, keeping weights in full precision is not required. 
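# Note: when a modifier token is being trained and mixed precision is fp16, the text
# encoder is not cast to half precision below, so the trainable token embedding stays in fp32.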
1077 | weight_dtype = torch.float32 1078 | if accelerator.mixed_precision == "fp16": 1079 | weight_dtype = torch.float16 1080 | elif accelerator.mixed_precision == "bf16": 1081 | weight_dtype = torch.bfloat16 1082 | 1083 | # Move unet, vae and text_encoder to device and cast to weight_dtype 1084 | if accelerator.mixed_precision != "fp16" and args.modifier_token is not None: 1085 | text_encoder.to(accelerator.device, dtype=weight_dtype) 1086 | unet.to(accelerator.device, dtype=weight_dtype) 1087 | vae.to(accelerator.device, dtype=weight_dtype) 1088 | 1089 | attention_class = ( 1090 | CustomDiffusionAttnProcessor2_0 1091 | if hasattr(F, "scaled_dot_product_attention") 1092 | else CustomDiffusionAttnProcessor 1093 | ) 1094 | if args.enable_xformers_memory_efficient_attention: 1095 | if is_xformers_available(): 1096 | import xformers 1097 | 1098 | xformers_version = version.parse(xformers.__version__) 1099 | if xformers_version == version.parse("0.0.16"): 1100 | logger.warn( 1101 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 1102 | ) 1103 | attention_class = CustomDiffusionXFormersAttnProcessor 1104 | else: 1105 | raise ValueError( 1106 | "xformers is not available. Make sure it is installed correctly" 1107 | ) 1108 | 1109 | # now we will add new Custom Diffusion weights to the attention layers 1110 | # It's important to realize here how many attention weights will be added and of which sizes 1111 | # The sizes of the attention layers consist only of two different variables: 1112 | # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. 1113 | # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. 1114 | 1115 | # Let's first see how many attention processors we will have to set. 
1116 | # For Stable Diffusion, it should be equal to: 1117 | # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 1118 | # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 1119 | # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18 1120 | # => 32 layers 1121 | 1122 | # Only train key, value projection layers if freeze_model = 'crossattn_kv' else train all params in the cross attention layer 1123 | train_kv = True 1124 | train_q_out = False if args.freeze_model == "crossattn_kv" else True 1125 | custom_diffusion_attn_procs = {} 1126 | 1127 | st = unet.state_dict() 1128 | for name, _ in unet.attn_processors.items(): 1129 | cross_attention_dim = ( 1130 | None 1131 | if name.endswith("attn1.processor") 1132 | else unet.config.cross_attention_dim 1133 | ) 1134 | if name.startswith("mid_block"): 1135 | hidden_size = unet.config.block_out_channels[-1] 1136 | elif name.startswith("up_blocks"): 1137 | block_id = int(name[len("up_blocks.")]) 1138 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 1139 | elif name.startswith("down_blocks"): 1140 | block_id = int(name[len("down_blocks.")]) 1141 | hidden_size = unet.config.block_out_channels[block_id] 1142 | layer_name = name.split(".processor")[0] 1143 | weights = { 1144 | "to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"], 1145 | "to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"], 1146 | } 1147 | if train_q_out: 1148 | weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"] 1149 | weights["to_out_custom_diffusion.0.weight"] = st[ 1150 | layer_name + ".to_out.0.weight" 1151 | ] 1152 | weights["to_out_custom_diffusion.0.bias"] = st[ 1153 | layer_name + ".to_out.0.bias" 1154 | ] 1155 | if cross_attention_dim is not None: 1156 | custom_diffusion_attn_procs[name] = attention_class( 1157 | train_kv=train_kv, 1158 | train_q_out=train_q_out, 1159 | hidden_size=hidden_size, 1160 | cross_attention_dim=cross_attention_dim, 1161 | ).to(unet.device) 1162 | custom_diffusion_attn_procs[name].load_state_dict(weights) 1163 | else: 1164 | custom_diffusion_attn_procs[name] = attention_class( 1165 | train_kv=False, 1166 | train_q_out=False, 1167 | hidden_size=hidden_size, 1168 | cross_attention_dim=cross_attention_dim, 1169 | ) 1170 | del st 1171 | unet.set_attn_processor(custom_diffusion_attn_procs) 1172 | custom_diffusion_layers = AttnProcsLayers(unet.attn_processors) 1173 | 1174 | accelerator.register_for_checkpointing(custom_diffusion_layers) 1175 | 1176 | if args.gradient_checkpointing: 1177 | unet.enable_gradient_checkpointing() 1178 | if args.modifier_token is not None: 1179 | text_encoder.gradient_checkpointing_enable() 1180 | # Enable TF32 for faster training on Ampere GPUs, 1181 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 1182 | if args.allow_tf32: 1183 | torch.backends.cuda.matmul.allow_tf32 = True 1184 | 1185 | if args.scale_lr: 1186 | args.learning_rate = ( 1187 | args.learning_rate 1188 | * args.gradient_accumulation_steps 1189 | * args.train_batch_size 1190 | * accelerator.num_processes 1191 | ) 1192 | if args.with_prior_preservation: 1193 | args.learning_rate = args.learning_rate * 2.0 1194 | 1195 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 1196 | if args.use_8bit_adam: 1197 | try: 1198 | import bitsandbytes as bnb 1199 | except ImportError: 1200 | raise ImportError( 1201 | "To use 8-bit Adam, please 
install the bitsandbytes library: `pip install bitsandbytes`." 1202 | ) 1203 | 1204 | optimizer_class = bnb.optim.AdamW8bit 1205 | else: 1206 | optimizer_class = torch.optim.AdamW 1207 | 1208 | # Optimizer creation 1209 | optimizer = optimizer_class( 1210 | ( 1211 | itertools.chain( 1212 | text_encoder.get_input_embeddings().parameters(), 1213 | custom_diffusion_layers.parameters(), 1214 | ) 1215 | if args.modifier_token is not None 1216 | else custom_diffusion_layers.parameters() 1217 | ), 1218 | lr=args.learning_rate, 1219 | betas=(args.adam_beta1, args.adam_beta2), 1220 | weight_decay=args.adam_weight_decay, 1221 | eps=args.adam_epsilon, 1222 | ) 1223 | 1224 | # Dataset and DataLoaders creation: 1225 | train_dataset = ClassDiffusionDataset( 1226 | concepts_list=args.concepts_list, 1227 | tokenizer=tokenizer, 1228 | with_prior_preservation=args.with_prior_preservation, 1229 | size=args.resolution, 1230 | mask_size=vae.encode( 1231 | torch.randn(1, 3, args.resolution, args.resolution) 1232 | .to(dtype=weight_dtype) 1233 | .to(accelerator.device) 1234 | ) 1235 | .latent_dist.sample() 1236 | .size()[-1], 1237 | center_crop=args.center_crop, 1238 | num_class_images=args.num_class_images, 1239 | hflip=args.hflip, 1240 | aug=not args.noaug, 1241 | ) 1242 | 1243 | train_dataloader = torch.utils.data.DataLoader( 1244 | train_dataset, 1245 | batch_size=args.train_batch_size, 1246 | shuffle=True, 1247 | collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), 1248 | num_workers=args.dataloader_num_workers, 1249 | ) 1250 | 1251 | # Scheduler and math around the number of training steps. 1252 | overrode_max_train_steps = False 1253 | num_update_steps_per_epoch = math.ceil( 1254 | len(train_dataloader) / args.gradient_accumulation_steps 1255 | ) 1256 | if args.max_train_steps is None: 1257 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1258 | overrode_max_train_steps = True 1259 | 1260 | lr_scheduler = get_scheduler( 1261 | args.lr_scheduler, 1262 | optimizer=optimizer, 1263 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 1264 | num_training_steps=args.max_train_steps * accelerator.num_processes, 1265 | ) 1266 | 1267 | # Prepare everything with our `accelerator`. 1268 | if args.modifier_token is not None: 1269 | ( 1270 | custom_diffusion_layers, 1271 | text_encoder, 1272 | optimizer, 1273 | train_dataloader, 1274 | lr_scheduler, 1275 | ) = accelerator.prepare( 1276 | custom_diffusion_layers, 1277 | text_encoder, 1278 | optimizer, 1279 | train_dataloader, 1280 | lr_scheduler, 1281 | ) 1282 | else: 1283 | custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler = ( 1284 | accelerator.prepare( 1285 | custom_diffusion_layers, optimizer, train_dataloader, lr_scheduler 1286 | ) 1287 | ) 1288 | 1289 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 1290 | num_update_steps_per_epoch = math.ceil( 1291 | len(train_dataloader) / args.gradient_accumulation_steps 1292 | ) 1293 | if overrode_max_train_steps: 1294 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 1295 | # Afterwards we recalculate our number of training epochs 1296 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 1297 | 1298 | # Train! 
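# Effective batch size = per-device batch size x number of processes x gradient accumulation steps.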
1299 | total_batch_size = ( 1300 | args.train_batch_size 1301 | * accelerator.num_processes 1302 | * args.gradient_accumulation_steps 1303 | ) 1304 | 1305 | logger.info("***** Running training *****") 1306 | logger.info(f" Num examples = {len(train_dataset)}") 1307 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 1308 | logger.info(f" Num Epochs = {args.num_train_epochs}") 1309 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 1310 | logger.info( 1311 | f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" 1312 | ) 1313 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 1314 | logger.info(f" Total optimization steps = {args.max_train_steps}") 1315 | global_step = 0 1316 | first_epoch = 0 1317 | 1318 | # Potentially load in the weights and states from a previous save 1319 | if args.resume_from_checkpoint: 1320 | if args.resume_from_checkpoint != "latest": 1321 | path = os.path.basename(args.resume_from_checkpoint) 1322 | else: 1323 | # Get the most recent checkpoint 1324 | dirs = os.listdir(args.output_dir) 1325 | dirs = [d for d in dirs if d.startswith("checkpoint")] 1326 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 1327 | path = dirs[-1] if len(dirs) > 0 else None 1328 | 1329 | if path is None: 1330 | accelerator.print( 1331 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 1332 | ) 1333 | args.resume_from_checkpoint = None 1334 | initial_global_step = 0 1335 | else: 1336 | accelerator.print(f"Resuming from checkpoint {path}") 1337 | accelerator.load_state(os.path.join(args.output_dir, path)) 1338 | global_step = int(path.split("-")[1]) 1339 | 1340 | initial_global_step = global_step 1341 | first_epoch = global_step // num_update_steps_per_epoch 1342 | else: 1343 | initial_global_step = 0 1344 | 1345 | progress_bar = tqdm( 1346 | range(0, args.max_train_steps), 1347 | initial=initial_global_step, 1348 | desc="Steps", 1349 | # Only show the progress bar once on each machine. 
1350 | disable=not accelerator.is_local_main_process, 1351 | ) 1352 | 1353 | for epoch in range(first_epoch, args.num_train_epochs): 1354 | unet.train() 1355 | if args.modifier_token is not None: 1356 | text_encoder.train() 1357 | for step, batch in enumerate(train_dataloader): 1358 | with accelerator.accumulate(unet), accelerator.accumulate(text_encoder): 1359 | # Convert images to latent space 1360 | latents = vae.encode( 1361 | batch["pixel_values"].to(dtype=weight_dtype) 1362 | ).latent_dist.sample() 1363 | latents = latents * vae.config.scaling_factor 1364 | 1365 | # Sample noise that we'll add to the latents 1366 | noise = torch.randn_like(latents) 1367 | bsz = latents.shape[0] 1368 | # Sample a random timestep for each image 1369 | timesteps = torch.randint( 1370 | 0, 1371 | noise_scheduler.config.num_train_timesteps, 1372 | (bsz,), 1373 | device=latents.device, 1374 | ) 1375 | timesteps = timesteps.long() 1376 | 1377 | # Add noise to the latents according to the noise magnitude at each timestep 1378 | # (this is the forward diffusion process) 1379 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 1380 | 1381 | # Get the text embedding for conditioning 1382 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 1383 | 1384 | # Predict the noise residual 1385 | model_pred = unet( 1386 | noisy_latents, timesteps, encoder_hidden_states 1387 | ).sample 1388 | 1389 | # Get the target for loss depending on the prediction type 1390 | if noise_scheduler.config.prediction_type == "epsilon": 1391 | target = noise 1392 | elif noise_scheduler.config.prediction_type == "v_prediction": 1393 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 1394 | else: 1395 | raise ValueError( 1396 | f"Unknown prediction type {noise_scheduler.config.prediction_type}" 1397 | ) 1398 | 1399 | if args.with_prior_preservation: 1400 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. 1401 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) 1402 | target, target_prior = torch.chunk(target, 2, dim=0) 1403 | mask = torch.chunk(batch["mask"], 2, dim=0)[0] 1404 | # Compute instance loss 1405 | loss = F.mse_loss( 1406 | model_pred.float(), target.float(), reduction="none" 1407 | ) 1408 | loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() 1409 | 1410 | # Compute prior loss 1411 | prior_loss = F.mse_loss( 1412 | model_pred_prior.float(), target_prior.float(), reduction="mean" 1413 | ) 1414 | 1415 | # Add the prior loss to the instance loss. 
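# `prior_loss_weight` balances fidelity to the instance images against preservation of the class prior.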
1416 | loss = loss + args.prior_loss_weight * prior_loss 1417 | elif args.use_spl: 1418 | 1419 | mask = batch["mask"] 1420 | loss = F.mse_loss( 1421 | model_pred.float(), target.float(), reduction="none" 1422 | ) 1423 | loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() 1424 | 1425 | if accelerator.num_processes > 1: 1426 | text_embeddings = ( 1427 | text_encoder.module.get_input_embeddings().weight 1428 | ) 1429 | else: 1430 | text_embeddings = text_encoder.get_input_embeddings().weight 1431 | 1432 | spl_l = 0.0 1433 | 1434 | for i in range(len(modifier_token_id)): 1435 | spl_l = ( 1436 | spl_l 1437 | + spl_loss( 1438 | text_embeddings, 1439 | modifier_token_id[i], 1440 | modifier_cls_text_inputs[i], 1441 | cls_text_inputs[i], 1442 | ).mean() 1443 | ) 1444 | 1445 | spl_l = spl_l / len(modifier_token_id) 1446 | 1447 | spl_l = args.spl_weight * spl_l 1448 | 1449 | loss = loss + spl_l 1450 | 1451 | else: 1452 | mask = batch["mask"] 1453 | loss = F.mse_loss( 1454 | model_pred.float(), target.float(), reduction="none" 1455 | ) 1456 | loss = ((loss * mask).sum([1, 2, 3]) / mask.sum([1, 2, 3])).mean() 1457 | accelerator.backward(loss) 1458 | # Zero out the gradients for all token embeddings except the newly added 1459 | # embeddings for the concept, as we only want to optimize the concept embeddings 1460 | if args.modifier_token is not None: 1461 | if accelerator.num_processes > 1: 1462 | grads_text_encoder = ( 1463 | text_encoder.module.get_input_embeddings().weight.grad 1464 | ) 1465 | else: 1466 | grads_text_encoder = ( 1467 | text_encoder.get_input_embeddings().weight.grad 1468 | ) 1469 | # Get the index for tokens that we want to zero the grads for 1470 | index_grads_to_zero = ( 1471 | torch.arange(len(tokenizer)) != modifier_token_id[0] 1472 | ) 1473 | for i in range(len(modifier_token_id[1:])): 1474 | index_grads_to_zero = index_grads_to_zero & ( 1475 | torch.arange(len(tokenizer)) != modifier_token_id[i + 1] 1476 | ) 1477 | 1478 | grads_text_encoder.data[index_grads_to_zero, :] = ( 1479 | grads_text_encoder.data[index_grads_to_zero, :].fill_(0) 1480 | ) 1481 | 1482 | if accelerator.sync_gradients: 1483 | params_to_clip = ( 1484 | itertools.chain( 1485 | text_encoder.parameters(), 1486 | custom_diffusion_layers.parameters(), 1487 | ) 1488 | if args.modifier_token is not None 1489 | else custom_diffusion_layers.parameters() 1490 | ) 1491 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 1492 | 1493 | optimizer.step() 1494 | lr_scheduler.step() 1495 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 1496 | 1497 | # Checks if the accelerator has performed an optimization step behind the scenes 1498 | if accelerator.sync_gradients: 1499 | progress_bar.update(1) 1500 | global_step += 1 1501 | 1502 | if global_step % args.checkpointing_steps == 0: 1503 | if accelerator.is_main_process: 1504 | # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` 1505 | if args.checkpoints_total_limit is not None: 1506 | checkpoints = os.listdir(args.output_dir) 1507 | checkpoints = [ 1508 | d for d in checkpoints if d.startswith("checkpoint") 1509 | ] 1510 | checkpoints = sorted( 1511 | checkpoints, key=lambda x: int(x.split("-")[1]) 1512 | ) 1513 | 1514 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 1515 | if len(checkpoints) >= args.checkpoints_total_limit: 1516 | num_to_remove = ( 1517 | len(checkpoints) - args.checkpoints_total_limit + 1 1518 | ) 1519 | 
removing_checkpoints = checkpoints[0:num_to_remove] 1520 | 1521 | logger.info( 1522 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" 1523 | ) 1524 | logger.info( 1525 | f"removing checkpoints: {', '.join(removing_checkpoints)}" 1526 | ) 1527 | 1528 | for removing_checkpoint in removing_checkpoints: 1529 | removing_checkpoint = os.path.join( 1530 | args.output_dir, removing_checkpoint 1531 | ) 1532 | shutil.rmtree(removing_checkpoint) 1533 | 1534 | save_path = os.path.join( 1535 | args.output_dir, f"checkpoint-{global_step}" 1536 | ) 1537 | accelerator.save_state(save_path) 1538 | logger.info(f"Saved state to {save_path}") 1539 | 1540 | logs = { 1541 | "loss": loss.clone().detach().item(), 1542 | "lr": lr_scheduler.get_last_lr()[0], 1543 | } 1544 | if args.use_spl: 1545 | logs["spl"] = spl_l.clone().detach().item() / args.spl_weight 1546 | logs["recon_loss"] = logs["loss"] - logs["spl"] * args.spl_weight 1547 | 1548 | progress_bar.set_postfix(**logs) 1549 | accelerator.log(logs, step=global_step) 1550 | 1551 | if global_step >= args.max_train_steps: 1552 | break 1553 | 1554 | if accelerator.is_main_process: 1555 | images = [] 1556 | 1557 | if ( 1558 | args.validation_prompt is not None 1559 | and global_step % args.validation_steps == 0 1560 | ): 1561 | logger.info( 1562 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 1563 | f" {args.validation_prompt}." 1564 | ) 1565 | # create pipeline 1566 | pipeline = DiffusionPipeline.from_pretrained( 1567 | args.pretrained_model_name_or_path, 1568 | unet=accelerator.unwrap_model(unet), 1569 | text_encoder=accelerator.unwrap_model(text_encoder), 1570 | tokenizer=tokenizer, 1571 | revision=args.revision, 1572 | variant=args.variant, 1573 | torch_dtype=weight_dtype, 1574 | ) 1575 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config( 1576 | pipeline.scheduler.config 1577 | ) 1578 | pipeline = pipeline.to(accelerator.device) 1579 | pipeline.set_progress_bar_config(disable=True) 1580 | 1581 | # run inference 1582 | generator = torch.Generator(device=accelerator.device).manual_seed( 1583 | args.seed 1584 | ) 1585 | images = [ 1586 | pipeline( 1587 | args.validation_prompt, 1588 | num_inference_steps=25, 1589 | generator=generator, 1590 | eta=1.0, 1591 | ).images[0] 1592 | for _ in range(args.num_validation_images) 1593 | ] 1594 | 1595 | for tracker in accelerator.trackers: 1596 | if tracker.name == "tensorboard": 1597 | np_images = np.stack([np.asarray(img) for img in images]) 1598 | tracker.writer.add_images( 1599 | "validation", np_images, global_step, dataformats="NHWC" 1600 | ) 1601 | if tracker.name == "wandb": 1602 | tracker.log( 1603 | { 1604 | "validation": [ 1605 | wandb.Image( 1606 | image, 1607 | caption=f"{i}: {args.validation_prompt}", 1608 | ) 1609 | for i, image in enumerate(images) 1610 | ] 1611 | } 1612 | ) 1613 | 1614 | del pipeline 1615 | torch.cuda.empty_cache() 1616 | 1617 | accelerator.wait_for_everyone() 1618 | if accelerator.is_main_process: 1619 | unet = unet.to(torch.float32) 1620 | unet.save_attn_procs( 1621 | args.output_dir, safe_serialization=not args.no_safe_serialization 1622 | ) 1623 | save_new_embed( 1624 | text_encoder, 1625 | modifier_token_id, 1626 | accelerator, 1627 | args, 1628 | args.output_dir, 1629 | safe_serialization=not args.no_safe_serialization, 1630 | ) 1631 | 1632 | # Final inference 1633 | # Load previous pipeline 1634 | pipeline = DiffusionPipeline.from_pretrained( 1635 | 
args.pretrained_model_name_or_path, 1636 | revision=args.revision, 1637 | variant=args.variant, 1638 | torch_dtype=weight_dtype, 1639 | ) 1640 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config( 1641 | pipeline.scheduler.config 1642 | ) 1643 | pipeline = pipeline.to(accelerator.device) 1644 | 1645 | # load attention processors 1646 | weight_name = ( 1647 | "pytorch_custom_diffusion_weights.safetensors" 1648 | if not args.no_safe_serialization 1649 | else "pytorch_custom_diffusion_weights.bin" 1650 | ) 1651 | pipeline.unet.load_attn_procs(args.output_dir, weight_name=weight_name) 1652 | for token in args.modifier_token: 1653 | token_weight_name = ( 1654 | f"{token}.safetensors" 1655 | if not args.no_safe_serialization 1656 | else f"{token}.bin" 1657 | ) 1658 | pipeline.load_textual_inversion( 1659 | args.output_dir, weight_name=token_weight_name 1660 | ) 1661 | 1662 | # run inference 1663 | if args.validation_prompt and args.num_validation_images > 0: 1664 | generator = ( 1665 | torch.Generator(device=accelerator.device).manual_seed(args.seed) 1666 | if args.seed 1667 | else None 1668 | ) 1669 | images = [ 1670 | pipeline( 1671 | args.validation_prompt, 1672 | num_inference_steps=25, 1673 | generator=generator, 1674 | eta=1.0, 1675 | ).images[0] 1676 | for _ in range(args.num_validation_images) 1677 | ] 1678 | 1679 | for tracker in accelerator.trackers: 1680 | if tracker.name == "tensorboard": 1681 | np_images = np.stack([np.asarray(img) for img in images]) 1682 | tracker.writer.add_images( 1683 | "test", np_images, epoch, dataformats="NHWC" 1684 | ) 1685 | if tracker.name == "wandb": 1686 | tracker.log( 1687 | { 1688 | "test": [ 1689 | wandb.Image( 1690 | image, caption=f"{i}: {args.validation_prompt}" 1691 | ) 1692 | for i, image in enumerate(images) 1693 | ] 1694 | } 1695 | ) 1696 | 1697 | if args.push_to_hub: 1698 | save_model_card( 1699 | repo_id, 1700 | images=images, 1701 | base_model=args.pretrained_model_name_or_path, 1702 | prompt=args.instance_prompt, 1703 | repo_folder=args.output_dir, 1704 | ) 1705 | api = HfApi(token=args.hub_token) 1706 | api.upload_folder( 1707 | repo_id=repo_id, 1708 | folder_path=args.output_dir, 1709 | commit_message="End of training", 1710 | ignore_patterns=["step_*", "epoch_*"], 1711 | ) 1712 | 1713 | accelerator.end_training() 1714 | 1715 | 1716 | if __name__ == "__main__": 1717 | args = parse_args() 1718 | main(args) 1719 | -------------------------------------------------------------------------------- /utils/blip2t.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import BlipProcessor, BlipForImageTextRetrieval 3 | 4 | 5 | class BLIP2T: 6 | 7 | def __init__(self, model_name, device): 8 | self.device = device 9 | self.processor = BlipProcessor.from_pretrained(model_name) 10 | self.model = BlipForImageTextRetrieval.from_pretrained( 11 | model_name, 12 | torch_dtype=torch.float16, 13 | ).to(device) 14 | 15 | @torch.no_grad() 16 | def text_similarity(self, prompt, image): 17 | """ 18 | Calculate text similarity between prompt and image 19 | 20 | Args: 21 | prompt: str 22 | image: PIL.Image 23 | 24 | Return 25 | score: float 26 | """ 27 | inputs = self.processor(image, prompt, return_tensors="pt").to( 28 | self.device, torch.float16 29 | ) 30 | scores = self.model(**inputs, use_itm_head=False)[0] 31 | 32 | if self.device == "cpu": 33 | scores = scores.detach().numpy()[0] 34 | else: 35 | scores = scores.detach().cpu().numpy()[0] 36 | 37 | return scores 38 | 
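# Minimal usage sketch for the scorer above. The checkpoint name, device, and image path
# are assumptions for illustration; substitute the model and generated image you actually evaluate.
if __name__ == "__main__":
    from PIL import Image

    scorer = BLIP2T("Salesforce/blip-itm-base-coco", "cuda")  # assumed BLIP ITM checkpoint
    image = Image.open("result/sample.png").convert("RGB")  # assumed path to a generated image
    score = scorer.text_similarity("a photo of a dog wearing a headphone", image)
    print("BLIP2-T similarity:", score)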
-------------------------------------------------------------------------------- /utils/load_attn_weight.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | import safetensors 4 | from typing import Callable, Dict, List, Optional, Union 5 | 6 | from diffusers.utils import ( 7 | USE_PEFT_BACKEND, 8 | _get_model_file, 9 | ) 10 | 11 | TEXT_ENCODER_NAME = "text_encoder" 12 | UNET_NAME = "unet" 13 | 14 | LORA_WEIGHT_NAME = "pytorch_lora_weights.bin" 15 | LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors" 16 | 17 | CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin" 18 | CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors" 19 | 20 | 21 | def load_custom_attn_param( 22 | self, 23 | pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], 24 | **kwargs 25 | ): 26 | r""" 27 | Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be 28 | defined in 29 | [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py) 30 | and be a `torch.nn.Module` class. 31 | 32 | Parameters: 33 | pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): 34 | Can be either: 35 | 36 | - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on 37 | the Hub. 38 | - A path to a directory (for example `./my_model_directory`) containing the model weights saved 39 | with [`ModelMixin.save_pretrained`]. 40 | - A [torch state 41 | dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). 42 | 43 | cache_dir (`Union[str, os.PathLike]`, *optional*): 44 | Path to a directory where a downloaded pretrained model configuration is cached if the standard cache 45 | is not used. 46 | force_download (`bool`, *optional*, defaults to `False`): 47 | Whether or not to force the (re-)download of the model weights and configuration files, overriding the 48 | cached versions if they exist. 49 | resume_download (`bool`, *optional*, defaults to `False`): 50 | Whether or not to resume downloading the model weights and configuration files. If set to `False`, any 51 | incompletely downloaded files are deleted. 52 | proxies (`Dict[str, str]`, *optional*): 53 | A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', 54 | 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. 55 | local_files_only (`bool`, *optional*, defaults to `False`): 56 | Whether to only load local model weights and configuration files or not. If set to `True`, the model 57 | won't be downloaded from the Hub. 58 | token (`str` or *bool*, *optional*): 59 | The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from 60 | `diffusers-cli login` (stored in `~/.huggingface`) is used. 61 | low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`): 62 | Speed up model loading only loading the pretrained weights and not initializing the weights. This also 63 | tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model. 64 | Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this 65 | argument to `True` will raise an error. 
66 | revision (`str`, *optional*, defaults to `"main"`): 67 | The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier 68 | allowed by Git. 69 | subfolder (`str`, *optional*, defaults to `""`): 70 | The subfolder location of a model file within a larger model repository on the Hub or locally. 71 | mirror (`str`, *optional*): 72 | Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not 73 | guarantee the timeliness or safety of the source, and you should refer to the mirror site for more 74 | information. 75 | 76 | Example: 77 | 78 | ```py 79 | from diffusers import AutoPipelineForText2Image 80 | import torch 81 | 82 | pipeline = AutoPipelineForText2Image.from_pretrained( 83 | "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 84 | ).to("cuda") 85 | pipeline.unet.load_attn_procs( 86 | "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic" 87 | ) 88 | ``` 89 | """ 90 | from diffusers.models.attention_processor import CustomDiffusionAttnProcessor 91 | from diffusers.models.lora import ( 92 | LoRACompatibleConv, 93 | LoRACompatibleLinear, 94 | LoRAConv2dLayer, 95 | LoRALinearLayer, 96 | ) 97 | 98 | cache_dir = kwargs.pop("cache_dir", None) 99 | force_download = kwargs.pop("force_download", False) 100 | resume_download = kwargs.pop("resume_download", False) 101 | proxies = kwargs.pop("proxies", None) 102 | local_files_only = kwargs.pop("local_files_only", None) 103 | token = kwargs.pop("token", None) 104 | revision = kwargs.pop("revision", None) 105 | subfolder = kwargs.pop("subfolder", None) 106 | weight_name = kwargs.pop("weight_name", None) 107 | use_safetensors = kwargs.pop("use_safetensors", None) 108 | # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
109 | # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning 110 | network_alphas = kwargs.pop("network_alphas", None) 111 | 112 | _pipeline = kwargs.pop("_pipeline", None) 113 | use_custom_diffusion = kwargs.pop("custom_diffusion", None) 114 | 115 | is_network_alphas_none = network_alphas is None 116 | 117 | allow_pickle = False 118 | 119 | if use_safetensors is None: 120 | use_safetensors = True 121 | allow_pickle = True 122 | 123 | user_agent = { 124 | "file_type": "attn_procs_weights", 125 | "framework": "pytorch", 126 | } 127 | 128 | model_file = None 129 | if not isinstance(pretrained_model_name_or_path_or_dict, dict): 130 | # Let's first try to load .safetensors weights 131 | if (use_safetensors and weight_name is None) or ( 132 | weight_name is not None and weight_name.endswith(".safetensors") 133 | ): 134 | try: 135 | model_file = _get_model_file( 136 | pretrained_model_name_or_path_or_dict, 137 | weights_name=weight_name or LORA_WEIGHT_NAME_SAFE, 138 | cache_dir=cache_dir, 139 | force_download=force_download, 140 | resume_download=resume_download, 141 | proxies=proxies, 142 | local_files_only=local_files_only, 143 | token=token, 144 | revision=revision, 145 | subfolder=subfolder, 146 | user_agent=user_agent, 147 | ) 148 | state_dict = safetensors.torch.load_file(model_file, device="cpu") 149 | except IOError as e: 150 | if not allow_pickle: 151 | raise e 152 | # try loading non-safetensors weights 153 | pass 154 | if model_file is None: 155 | model_file = _get_model_file( 156 | pretrained_model_name_or_path_or_dict, 157 | weights_name=weight_name or LORA_WEIGHT_NAME, 158 | cache_dir=cache_dir, 159 | force_download=force_download, 160 | resume_download=resume_download, 161 | proxies=proxies, 162 | local_files_only=local_files_only, 163 | token=token, 164 | revision=revision, 165 | subfolder=subfolder, 166 | user_agent=user_agent, 167 | ) 168 | state_dict = torch.load(model_file, map_location="cpu") 169 | else: 170 | state_dict = pretrained_model_name_or_path_or_dict 171 | 172 | # is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys()) 173 | is_custom_diffusion = ( 174 | any("custom_diffusion" in k for k in state_dict.keys()) or use_custom_diffusion 175 | ) 176 | 177 | assert is_custom_diffusion == True 178 | 179 | unet_state_dict = self.state_dict() 180 | 181 | for name, weight in state_dict.items(): 182 | if name.endswith("weight"): 183 | name = name.split(".") 184 | module_unet = ( 185 | ".".join(name[:-3]) + ".to_" + name[-2].split("_")[1] + ".weight" 186 | ) 187 | unet_state_dict[module_unet] = weight 188 | 189 | self.load_state_dict(unet_state_dict) 190 | 191 | warn_messages = "you are loading custom diffusion weights into a U-Net without the official attention processor, which may cause potential issues. Please ensure you understand the implications of your actions." 
192 | warnings.warn(warn_messages) 193 | -------------------------------------------------------------------------------- /utils/semantic_preservation_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from dataclasses import dataclass 3 | from typing import Any, Optional, Tuple, Union 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | from torch import nn 8 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 9 | import torch.nn.functional as F 10 | 11 | from transformers.modeling_outputs import ( 12 | BaseModelOutput, 13 | BaseModelOutputWithPooling, 14 | ImageClassifierOutput, 15 | ) 16 | 17 | from transformers.modeling_attn_mask_utils import ( 18 | _create_4d_causal_attention_mask, 19 | _prepare_4d_attention_mask, 20 | ) 21 | 22 | from transformers import CLIPTextModel 23 | 24 | 25 | def forward_with_custom_embeddings( 26 | self, 27 | input_ids: Optional[torch.Tensor] = None, 28 | attention_mask: Optional[torch.Tensor] = None, 29 | position_ids: Optional[torch.Tensor] = None, 30 | output_attentions: Optional[bool] = None, 31 | output_hidden_states: Optional[bool] = None, 32 | return_dict: Optional[bool] = None, 33 | input_modifier_embeddings: Optional[torch.Tensor] = None, 34 | modifier_token_id: Optional[torch.Tensor] = None, 35 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 36 | r""" 37 | Returns: 38 | 39 | """ 40 | output_attentions = ( 41 | output_attentions 42 | if output_attentions is not None 43 | else self.config.output_attentions 44 | ) 45 | output_hidden_states = ( 46 | output_hidden_states 47 | if output_hidden_states is not None 48 | else self.config.output_hidden_states 49 | ) 50 | return_dict = ( 51 | return_dict if return_dict is not None else self.config.use_return_dict 52 | ) 53 | 54 | use_custom_embeddings = ( 55 | input_modifier_embeddings is not None and modifier_token_id is not None 56 | ) 57 | 58 | if input_ids is None: 59 | raise ValueError("You have to specify input_ids") 60 | 61 | input_shape = input_ids.size() 62 | input_ids = input_ids.view(-1, input_shape[-1]) 63 | 64 | hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) 65 | 66 | if use_custom_embeddings: 67 | modifier_index = torch.where(input_ids.squeeze(0) == modifier_token_id) 68 | hidden_states[0, modifier_index, :] = input_modifier_embeddings 69 | 70 | # CLIP's text model uses causal mask, prepare it here. 71 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 72 | causal_attention_mask = _create_4d_causal_attention_mask( 73 | input_shape, hidden_states.dtype, device=hidden_states.device 74 | ) 75 | # expand attention_mask 76 | if attention_mask is not None: 77 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 78 | attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) 79 | 80 | encoder_outputs = self.encoder( 81 | inputs_embeds=hidden_states, 82 | attention_mask=attention_mask, 83 | causal_attention_mask=causal_attention_mask, 84 | output_attentions=output_attentions, 85 | output_hidden_states=output_hidden_states, 86 | return_dict=return_dict, 87 | ) 88 | 89 | last_hidden_state = encoder_outputs[0] 90 | last_hidden_state = self.final_layer_norm(last_hidden_state) 91 | 92 | if self.eos_token_id == 2: 93 | # The `eos_token_id` was incorrect before PR #24773: Let's keep what has been done here. 
94 | # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added 95 | # ------------------------------------------------------------ 96 | # text_embeds.shape = [batch_size, sequence_length, transformer.width] 97 | # take features from the eot embedding (eot_token is the highest number in each sequence) 98 | # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 99 | pooled_output = last_hidden_state[ 100 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), 101 | input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax( 102 | dim=-1 103 | ), 104 | ] 105 | else: 106 | # The config gets updated `eos_token_id` from PR #24773 (so the use of exta new tokens is possible) 107 | pooled_output = last_hidden_state[ 108 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), 109 | # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`) 110 | ( 111 | input_ids.to(dtype=torch.int, device=last_hidden_state.device) 112 | == self.eos_token_id 113 | ) 114 | .int() 115 | .argmax(dim=-1), 116 | ] 117 | 118 | if not return_dict: 119 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 120 | 121 | return BaseModelOutputWithPooling( 122 | last_hidden_state=last_hidden_state, 123 | pooler_output=pooled_output, 124 | hidden_states=encoder_outputs.hidden_states, 125 | attentions=encoder_outputs.attentions, 126 | ) 127 | 128 | 129 | class SPL: 130 | 131 | def __init__( 132 | self, 133 | text_encoder: CLIPTextModel, 134 | use_attention_mask: Optional[bool] = None, 135 | ): 136 | self.text_encoder = text_encoder 137 | self.use_attention_mask = use_attention_mask 138 | 139 | def _encode( 140 | self, 141 | text_input: Optional[Dict] = None, 142 | input_modifier_embeddings: Optional[torch.Tensor] = None, 143 | modifier_token_id: Optional[torch.Tensor] = None, 144 | ): 145 | 146 | text_input_ids = text_input.input_ids 147 | device = self.text_encoder.device 148 | 149 | if ( 150 | hasattr(self.text_encoder.config, "use_attention_mask") 151 | and self.text_encoder.config.use_attention_mask 152 | ): 153 | attention_mask = text_input.attention_mask.to(device) 154 | else: 155 | attention_mask = None 156 | 157 | prompt_embeds = self.text_encoder.text_model.forward_with_custom_embeddings( 158 | text_input_ids.to(device), 159 | attention_mask=attention_mask, 160 | input_modifier_embeddings=input_modifier_embeddings, 161 | modifier_token_id=modifier_token_id, 162 | ) 163 | 164 | prompt_embeds = prompt_embeds.pooler_output 165 | 166 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) 167 | 168 | return prompt_embeds 169 | 170 | def __call__( 171 | self, 172 | text_embeddings: Optional[torch.Tensor], 173 | modifier_token_id: Optional[torch.Tensor], 174 | modifier_cls_text_input: Optional[Dict], 175 | cls_text_input: Optional[Dict], 176 | ) -> torch.Tensor: 177 | 178 | modifier_embedding = text_embeddings[modifier_token_id].clone() 179 | 180 | modifier_cls_embedding = self._encode( 181 | text_input=modifier_cls_text_input, 182 | input_modifier_embeddings=modifier_embedding, 183 | modifier_token_id=modifier_token_id, 184 | ) 185 | 186 | cls_embedding = self._encode(text_input=cls_text_input) 187 | 188 | dis = F.cosine_similarity(modifier_cls_embedding, cls_embedding) 189 | 190 | dis = 1 - dis 191 | 192 | return dis 193 | -------------------------------------------------------------------------------- 
/utils/similarity.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | from PIL import Image 4 | 5 | 6 | class Similarity: 7 | 8 | def __init__(self, model_name, device): 9 | self.device = device 10 | self.model, self.preprocess = clip.load(model_name, device) 11 | 12 | @torch.no_grad() 13 | def text_similarity(self, prompt, image): 14 | """ 15 | Calculate text similarity between prompt and image 16 | 17 | Args: 18 | prompt: str 19 | image: PIL.Image 20 | 21 | Return 22 | score: float 23 | """ 24 | image = self.preprocess(image).unsqueeze(0).to(self.device) 25 | text = clip.tokenize([prompt]).to(self.device) 26 | 27 | image_features = self.model.encode_image(image) 28 | text_features = self.model.encode_text(text) 29 | 30 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 31 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 32 | 33 | similarity = torch.matmul(text_features, image_features.T).squeeze() 34 | 35 | score = similarity.detach().cpu().numpy() 36 | 37 | return score 38 | 39 | def image_similarity(self, source, generate): 40 | """ 41 | Calculate image similarity between source image and generate image 42 | 43 | Args: 44 | prompt: PIL.Image 45 | image: PIL.Image 46 | 47 | Return 48 | score: float 49 | """ 50 | source = self.preprocess(source).unsqueeze(0).to(self.device) 51 | generate = self.preprocess(generate).unsqueeze(0).to(self.device) 52 | 53 | image_features_source = self.model.encode_image(source) 54 | image_features_generate = self.model.encode_image(generate) 55 | 56 | image_features_source = image_features_source / image_features_source.norm( 57 | dim=-1, keepdim=True 58 | ) 59 | image_features_generate = ( 60 | image_features_generate / image_features_generate.norm(dim=-1, keepdim=True) 61 | ) 62 | 63 | similarity = torch.matmul( 64 | image_features_source, image_features_generate.T 65 | ).squeeze() 66 | 67 | score = similarity.detach().cpu().numpy() 68 | 69 | return score 70 | -------------------------------------------------------------------------------- /videogen.py: -------------------------------------------------------------------------------- 1 | # import torch 2 | 3 | import types 4 | import torch 5 | from utils.load_attn_weight import load_custom_attn_param 6 | 7 | from diffusers import ( 8 | AnimateDiffPipeline, 9 | DDIMScheduler, 10 | MotionAdapter, 11 | DiffusionPipeline, 12 | ) 13 | from diffusers.utils import export_to_gif 14 | 15 | model_id = "runwayml/stable-diffusion-v1-5" 16 | ckpt_path = "your_ckpt_path" # TODO: modify here 17 | 18 | pipeline = DiffusionPipeline.from_pretrained( 19 | model_id, 20 | torch_dtype=torch.float16, 21 | ).to("cuda") 22 | 23 | pipeline.unet.load_custom_attn_param = types.MethodType( 24 | load_custom_attn_param, pipeline.unet 25 | ) 26 | pipeline.unet.load_custom_attn_param( 27 | ckpt_path, 28 | weight_name="pytorch_custom_diffusion_weights.bin", 29 | ) 30 | pipeline.load_textual_inversion(ckpt_path, weight_name=".bin") 31 | 32 | adapter = MotionAdapter.from_pretrained( 33 | "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16 34 | ) 35 | pipe = AnimateDiffPipeline.from_pretrained( 36 | model_id, 37 | motion_adapter=adapter, 38 | unet=pipeline.unet, 39 | text_encoder=pipeline.text_encoder, 40 | tokenizer=pipeline.tokenizer, 41 | torch_dtype=torch.float16, 42 | ) 43 | scheduler = DDIMScheduler.from_pretrained( 44 | model_id, 45 | subfolder="scheduler", 46 | clip_sample=False, 47 | 
timestep_spacing="linspace", 48 | beta_schedule="linear", 49 | steps_offset=1, 50 | ) 51 | pipe.scheduler = scheduler 52 | pipe.enable_vae_slicing() 53 | pipe.enable_model_cpu_offload() 54 | 55 | output = pipe( 56 | prompt=(" dog running on the street"), 57 | negative_prompt="bad quality, worse quality", 58 | num_frames=16, 59 | guidance_scale=7.5, 60 | num_inference_steps=25, 61 | ) 62 | frames = output.frames[0] 63 | export_to_gif(frames, "test.gif", fps=4) 64 | --------------------------------------------------------------------------------