├── .gitignore ├── README.md ├── dataset.py ├── encoder_model.py ├── images └── asuka.png ├── requirements.txt ├── sample.py ├── sample.sh ├── sample_chill.sh ├── train.py ├── train_encoder.sh └── train_encoder_chill.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Initially taken from Github's Python gitignore file 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | amlt/ 8 | # C extensions 9 | *.so 10 | 11 | *_train_ffhq/* 12 | temp/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Encoder for Stable Diffusion Fast Personalization 2 | This is an unofficial implementation of [Encoder-based Domain Tuning for Fast Personalization of Text-to-Image Models](https://tuning-encoder.github.io/). The code is based on [Huggingface diffusers](https://github.com/huggingface/diffusers). 3 | 4 | This code is not exactly the same as the original paper; we use LoRA instead of Weight Offsets. 5 | 6 | 7 | ## Environment 8 | pip install -r requirements.txt 9 | 10 | ## Model Pretraining (using a Stable Diffusion v1.5-based model achieves better results) 11 | ```bash 12 | accelerate config 13 | accelerate launch train.py --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --images_dir $FFHQ_DIR --lr_scheduler constant_with_warmup \ 14 | --train_batch_size 5 --resolution 512 --scale_lr --output_dir $MODEL_SAVE_DIR --num_train_epochs 10 --save_steps 10000 --learning_rate 1.6e-6 --lr_scheduler cosine_with_restarts --reg_weight 0.01 --lora_rank 64 --placeholder_token face 15 | ``` 16 | 17 | A pretrained model is available at https://huggingface.co/yoctta/sd-personalization-encoder-face/tree/main 18 | ## Finetune and sample images 19 | ```bash 20 | accelerate config 21 | accelerate launch --multi_gpu sample.py --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" \ 22 | --model_path "$MODEL_SAVE_DIR/checkpoint-70000" --final_checkpoint \ 23 | --image_path $INPUT_IMAGE_PATH \ 24 | --train_batch_size 2 \ 25 | --finetune_steps 15 --reg_weight 0.1 --resolution 512 \ 26 | --prompt "a photo of face wearing sunglasses." 
--placeholder_token face \ 27 | --num_samples 2 --learning_rate 1.6e-5 --train_text_encoder --mixed_precision bf16 \ 28 | --output_dir $OUTPUT_IMAGE_PATH 29 | 30 | ``` 31 | ## Sampled Images 32 | ![Image 1](images/asuka.png) -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torch 3 | import numpy as np 4 | import PIL 5 | from PIL import Image 6 | import random 7 | from torchvision import transforms 8 | import os 9 | 10 | 11 | my_templates=["a photo of {}"] 12 | 13 | 14 | class Simpledataset(Dataset): 15 | def __init__( 16 | self, 17 | image_dir, 18 | tokenizer, 19 | preprocess, 20 | size=512, 21 | placeholder_token='person' 22 | ): 23 | self.image_dir = image_dir 24 | self.tokenizer = tokenizer 25 | self.placeholder_token = placeholder_token 26 | self.size=size 27 | self.preprocess=preprocess 28 | self.place_holder_id=self.tokenizer.encode(self.placeholder_token)[1] 29 | self.ids=os.listdir(image_dir) 30 | self.templates = [i.format(self.placeholder_token) for i in my_templates] 31 | 32 | def __len__(self): 33 | return len(self.ids) 34 | 35 | def __getitem__(self, i): 36 | example = {} 37 | img_path=os.path.join(self.image_dir,self.ids[i]) 38 | image = Image.open(img_path) 39 | if not image.mode == "RGB": 40 | image = image.convert("RGB") 41 | example["image"]=self.preprocess(image) 42 | text = random.choice(self.templates) 43 | example["input_ids"] = self.tokenizer( 44 | text, 45 | padding="max_length", 46 | truncation=True, 47 | max_length=self.tokenizer.model_max_length, 48 | return_tensors="pt", 49 | ).input_ids[0] 50 | example["input_placeholder_pos"]=example["input_ids"]==self.place_holder_id 51 | image = image.resize((self.size, self.size)) 52 | image = np.array(image).astype(np.uint8) 53 | image = (image / 127.5 - 1.0).astype(np.float32) 54 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) 55 | return example 56 | -------------------------------------------------------------------------------- /encoder_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import open_clip 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from PIL import Image 6 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel 7 | class feature_extractor: 8 | def __init__(self,clip_model='ViT-H-14/laion2b_s32b_b79k',unet_model="stabilityai/stable-diffusion-2-1"): 9 | unet = UNet2DConditionModel.from_pretrained(unet_model, subfolder="unet",use_auth_token=True).eval() 10 | del unet.mid_block 11 | del unet.up_blocks 12 | del unet.conv_out 13 | unet.eval() 14 | self.unet=unet 15 | clip_model, _, preprocess = open_clip.create_model_and_transforms(clip_model.split('/')[0], pretrained=clip_model.split('/')[1]) 16 | self.preprocess=preprocess 17 | self.image_encoder=clip_model.visual.eval() 18 | self.activation={} 19 | self.device=torch.device('cpu') 20 | self.dtype=torch.float32 21 | def getActivation(name): 22 | def hook(model, input, output): 23 | self.activation[name] = output[:,0,:].detach() 24 | return hook 25 | for i in range(1,len(self.image_encoder.transformer.resblocks),2): 26 | self.image_encoder.transformer.resblocks[i].register_forward_hook(getActivation(i)) 27 | 28 | def set_device(self,device,dtype=torch.float32,only_unet=False): 29 | self.device=torch.device(device) 30 | 
self.dtype=dtype 31 | self.unet.to(device,dtype=dtype) 32 | if not only_unet: 33 | self.image_encoder.to(device,dtype=dtype) 34 | 35 | def preprocess_images(self,images): 36 | if not type(images)==list: 37 | images=[images] 38 | return torch.stack([self.preprocess(i) for i in images]).to(self.device) 39 | 40 | def encode_image(self,images): 41 | with torch.no_grad(): 42 | _=self.image_encoder(images.to(self.dtype)) 43 | n=sorted(self.activation.keys()) 44 | return torch.stack([torch.cat([self.activation[j][i] for j in n],dim=0) for i in range(len(images))]) 45 | 46 | def encode_unet(self,latent,timestep,encoder_hidden_states): 47 | timesteps=timestep 48 | with torch.no_grad(): 49 | pooled_features=[] 50 | if not torch.is_tensor(timesteps): 51 | if isinstance(timestep, float): 52 | dtype = torch.float64 53 | else: 54 | dtype = torch.int64 55 | timesteps = torch.tensor([timesteps], dtype=dtype, device=latent.device) 56 | elif len(timesteps.shape) == 0: 57 | timesteps = timesteps[None].to(latent.device) 58 | timesteps = timesteps.expand(latent.shape[0]) 59 | t_emb = self.unet.time_proj(timesteps).to(dtype=self.dtype) 60 | emb = self.unet.time_embedding(t_emb) 61 | latent = self.unet.conv_in(latent.to(dtype=self.dtype)) 62 | for downsample_block in self.unet.down_blocks: 63 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 64 | latent, res_samples = downsample_block( 65 | hidden_states=latent, 66 | temb=emb, 67 | encoder_hidden_states=encoder_hidden_states.to(self.dtype)) 68 | else: 69 | latent, res_samples = downsample_block(hidden_states=latent, temb=emb) 70 | pooled_features.append(latent.mean(dim=[2,3])) 71 | return pooled_features 72 | 73 | 74 | 75 | 76 | 77 | class IDEncoder(nn.Module): 78 | def __init__(self,clip_model='ViT-H-14/laion2b_s32b_b79k',unet_model="stabilityai/stable-diffusion-2-1"): 79 | super().__init__() 80 | self.feature_extractor=feature_extractor(clip_model,unet_model) 81 | dim_width=self.feature_extractor.image_encoder.transformer.width 82 | dim_layers=self.feature_extractor.image_encoder.transformer.layers//2 83 | dim_unet = self.feature_extractor.unet.config.block_out_channels 84 | dim_cross_attn=self.feature_extractor.unet.config.cross_attention_dim 85 | self.id_encoder_feature=nn.Conv1d(dim_width*dim_layers,dim_width*dim_layers,1,groups=dim_layers) 86 | self.unet_encoder_feature=nn.ModuleList([nn.Linear(i,dim_width) for i in dim_unet]) 87 | self.last_linear=nn.Linear(dim_width,dim_cross_attn) 88 | self.dim_width=dim_width 89 | self.dim_layers=dim_layers 90 | def forward(self,batch,latent,timestep,encoder_hidden_states): 91 | if 'image_features' in batch: 92 | image_features=batch['image_features'] 93 | elif 'image' in batch: 94 | image_features=self.feature_extractor.encode_image(batch['image']).to(dtype=latent.dtype) 95 | unet_features=self.feature_extractor.encode_unet(latent,timestep,encoder_hidden_states) 96 | unet_features=torch.stack([f(i.to(dtype=latent.dtype)) for f,i in zip(self.unet_encoder_feature,unet_features)],dim=1) 97 | image_features_proj=self.id_encoder_feature(image_features.unsqueeze(-1)).reshape(image_features.shape[0],self.dim_layers,self.dim_width) 98 | features=torch.cat([unet_features,image_features_proj],dim=1) 99 | features=F.leaky_relu(features, negative_slope=0.1) 100 | features=torch.mean(features,dim=1) 101 | return self.last_linear(features) 102 | -------------------------------------------------------------------------------- /images/asuka.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoctta/sd_personalization_encoder/4f0e49a858fb905f28e9180c5eb58d11eb0a260e/images/asuka.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.16.0 2 | diffusers==0.13.1 3 | numpy==1.21.5 4 | open_clip_torch==2.9.3 5 | packaging==21.3 6 | Pillow==9.4.0 7 | torch==1.13.1 8 | torchvision==0.14.1 9 | tqdm==4.63.1 10 | transformers==4.25.1 11 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | from train import finetune, load_model 2 | from PIL import Image 3 | import os 4 | import argparse 5 | if __name__ == "__main__": 6 | parser=argparse.ArgumentParser() 7 | parser.add_argument('--pretrained_model_name_or_path',type=str,default="stabilityai/stable-diffusion-2-1") 8 | parser.add_argument('--model_path',type=str,default=None) 9 | parser.add_argument('--final_checkpoint', action='store_true') 10 | parser.add_argument('--train_text_encoder', action='store_true') 11 | parser.add_argument('--image_path',type=str,default=None) 12 | parser.add_argument("--mixed_precision",type=str,default="no",choices=["no", "fp16", "bf16"]) 13 | parser.add_argument('--learning_rate',type=float,default=1e-6) 14 | parser.add_argument('--reg_weight',type=float,default=0.1) 15 | parser.add_argument('--train_batch_size',type=int,default=1) 16 | parser.add_argument('--finetune_steps',type=int,default=15) 17 | parser.add_argument('--prompt',type=str,nargs='+') 18 | parser.add_argument('--placeholder_token',type=str,default='person') 19 | parser.add_argument('--resolution',type=int,default=768) 20 | parser.add_argument('--output_dir',type=str,default=None) 21 | parser.add_argument('--num_samples',type=int,default=1) 22 | 23 | args=parser.parse_args() 24 | if not args.output_dir: 25 | args.output_dir=os.path.join(args.model_path,'sampled_images') 26 | 27 | id_encoder,pipe=load_model(args.model_path,args.pretrained_model_name_or_path,args.final_checkpoint) 28 | image=Image.open(args.image_path) 29 | if not image.mode == "RGB": 30 | image = image.convert("RGB") 31 | finetune(image,pipe,id_encoder,mixed_precision=args.mixed_precision,learning_rate=args.learning_rate,train_batch_size=args.train_batch_size,\ 32 | train_steps=args.finetune_steps,text='a photo of '+args.placeholder_token,placeholder_token=args.placeholder_token,resize=args.resolution,\ 33 | prompts=args.prompt,output_dir=args.output_dir,num_samples=args.num_samples,train_text_encoder=args.train_text_encoder,reg_weight=args.reg_weight) 34 | -------------------------------------------------------------------------------- /sample.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --multi_gpu sample.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-2-1" \ 2 | --model_path "new_train_ffhq/checkpoint-140000" --final_checkpoint \ 3 | --image_path "images/asuka.png" \ 4 | --train_batch_size 2 \ 5 | --finetune_steps 15 --reg_weight 0.1 --resolution 768 \ 6 | --prompt "a pencil sketch of person." "an oil paint of person." 
"a photo of person surrounded by sunflowers" "a photo of person in red shirt" "a photo of person" \ 7 | --num_samples 2 --learning_rate 1.6e-5 --train_text_encoder --mixed_precision bf16 \ 8 | --output_dir new_train_ffhq/checkpoint-140000/samples_lr16_reg_0.1_f15_asuka_bf16 9 | 10 | -------------------------------------------------------------------------------- /sample_chill.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --multi_gpu sample.py --pretrained_model_name_or_path "/ssd/zhaohanqing/msws/diffusers/examples/model_lab/chilloutmix" \ 2 | --model_path "chillout_train_ffhq/checkpoint-70000" --final_checkpoint \ 3 | --image_path "temp/images/asuka.png" \ 4 | --train_batch_size 6 \ 5 | --finetune_steps 15 --reg_weight 0.1 --resolution 512 \ 6 | --prompt "a photo of face with sunglasses." "an photo of face on the beach." "a photo of face surrounded by sunflowers" "a photo of face in red shirt" --placeholder_token face \ 7 | --num_samples 6 --learning_rate 1.6e-5 --train_text_encoder --mixed_precision bf16 \ 8 | --output_dir temp/samples_lr16_reg_0.1_f15_asuka_bf16 9 | 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import math 4 | import os 5 | import random 6 | from pathlib import Path 7 | from typing import Optional 8 | 9 | import numpy as np 10 | import torch 11 | import torch.nn.functional as F 12 | import torch.utils.checkpoint 13 | from accelerate import DistributedDataParallelKwargs 14 | import PIL 15 | from accelerate import Accelerator 16 | from accelerate.logging import get_logger 17 | from accelerate.utils import set_seed 18 | from diffusers.optimization import get_scheduler 19 | from packaging import version 20 | from PIL import Image 21 | from torchvision import transforms 22 | from tqdm.auto import tqdm 23 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 24 | from dataset import Simpledataset 25 | from diffusers.loaders import AttnProcsLayers 26 | from diffusers.models.cross_attention import LoRACrossAttnProcessor 27 | from diffusers import ( 28 | AutoencoderKL, 29 | DDPMScheduler, 30 | DiffusionPipeline, 31 | DPMSolverMultistepScheduler, 32 | UNet2DConditionModel, 33 | ) 34 | from copy import deepcopy 35 | import itertools 36 | from encoder_model import IDEncoder 37 | logger = get_logger(__name__) 38 | 39 | def train_step(accelerator,batch,vae,noise_scheduler,id_encoder,text_encoder,unet,optimizer,lr_scheduler=None,reg_weight=0.01): 40 | if 'latents' in batch: 41 | latents=batch['latents'] 42 | else: 43 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() 44 | latents = latents * 0.18215 45 | noise = torch.randn(latents.shape,device=latents.device,dtype=latents.dtype) 46 | bsz = latents.shape[0] 47 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long() 48 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 49 | if 'text_encoder_states' in batch: 50 | text_encoder_states=batch['text_encoder_states'] 51 | else: 52 | text_encoder_states=text_encoder(batch["input_ids"])[0] 53 | id_f=id_encoder(batch,noisy_latents,timesteps,text_encoder_states) 54 | loss_reg=torch.mean(torch.norm(id_f,dim=1)**2) 55 | input_placeholder_pos=batch["input_placeholder_pos"].unsqueeze(-1) 56 | input_ids=batch["input_ids"] 57 | input_shape = 
input_ids.size() 58 | input_ids = input_ids.view(-1, input_shape[-1]) 59 | hidden_states = text_encoder.text_model.embeddings(input_ids=input_ids, position_ids=None) 60 | id_f=id_f.unsqueeze(1) 61 | hidden_states=id_f*input_placeholder_pos*0.1+hidden_states #mul 0.1 62 | #print("id_f",id_f.shape,"input_placeholder_pos",input_placeholder_pos.shape,"hidden_states",hidden_states.shape) 63 | bsz, seq_len = input_shape 64 | causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(hidden_states.device) 65 | encoder_outputs = text_encoder.text_model.encoder( 66 | inputs_embeds=hidden_states, 67 | attention_mask=None, 68 | causal_attention_mask=causal_attention_mask 69 | ) 70 | last_hidden_state = encoder_outputs[0] 71 | encoder_hidden_states = text_encoder.text_model.final_layer_norm(last_hidden_state) 72 | # Predict the noise residual 73 | if noise_scheduler.config.prediction_type == "epsilon": 74 | target = noise 75 | elif noise_scheduler.config.prediction_type == "v_prediction": 76 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 77 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 78 | loss_simple = F.mse_loss(model_pred, target) 79 | loss=loss_simple+reg_weight*loss_reg 80 | accelerator.backward(loss) 81 | optimizer.step() 82 | if lr_scheduler is not None: 83 | lr_scheduler.step() 84 | optimizer.zero_grad() 85 | return dict(loss_simple=loss_simple.detach().item(),loss_reg=loss_reg.detach().item()) 86 | 87 | 88 | def parse_args(): 89 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 90 | parser.add_argument( 91 | "--pretrained_model_name_or_path", 92 | type=str, 93 | default="stabilityai/stable-diffusion-2-1", 94 | help="Path to pretrained model or model identifier from huggingface.co/models.", 95 | ) 96 | parser.add_argument( 97 | "--images_dir", type=str, default=None, required=True, help="A folder containing the training data." 98 | ) 99 | parser.add_argument("--lora_rank",type=int,default=32) 100 | parser.add_argument("--placeholder_token", type=str, default='person') 101 | parser.add_argument("--resolution",type=int,default=768) 102 | parser.add_argument("--max_train_steps",type=int,default=None) 103 | parser.add_argument( 104 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 105 | ) 106 | parser.add_argument("--num_train_epochs", type=int, default=10) 107 | parser.add_argument("--output_dir",type=str,default='runs') 108 | parser.add_argument( 109 | "--learning_rate", 110 | type=float, 111 | default=1e-6, 112 | help="Initial learning rate (after the potential warmup period) to use.", 113 | ) 114 | parser.add_argument("--gradient_accumulation_steps",type=int,default=1) 115 | parser.add_argument( 116 | "--scale_lr", 117 | action="store_true", 118 | default=True, 119 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 120 | ) 121 | parser.add_argument( 122 | "--lr_scheduler", 123 | type=str, 124 | default="constant", 125 | help=( 126 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 127 | ' "constant", "constant_with_warmup"]' 128 | ), 129 | ) 130 | parser.add_argument( 131 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
132 | ) 133 | parser.add_argument( 134 | "--resume_from_checkpoint", 135 | type=str, 136 | default=None, 137 | help=( 138 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 139 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 140 | ), 141 | ) 142 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 143 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 144 | parser.add_argument("--adam_weight_decay", type=float, default=1e-4, help="Weight decay to use.") 145 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 146 | parser.add_argument("--reg_weight", type=float, default=0.01) 147 | parser.add_argument("--save_steps", type=int, default=50000) 148 | parser.add_argument( 149 | "--logging_dir", 150 | type=str, 151 | default="logs", 152 | help=( 153 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 154 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 155 | ), 156 | ) 157 | parser.add_argument( 158 | "--mixed_precision", 159 | type=str, 160 | default="no", 161 | choices=["no", "fp16", "bf16"], 162 | help=( 163 | "Whether to use mixed precision. Choose" 164 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." 165 | "and an Nvidia Ampere GPU." 166 | ), 167 | ) 168 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 169 | 170 | args = parser.parse_args() 171 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 172 | if env_local_rank != -1 and env_local_rank != args.local_rank: 173 | args.local_rank = env_local_rank 174 | 175 | 176 | return args 177 | 178 | 179 | def freeze_params(params): 180 | for param in params: 181 | param.requires_grad = False 182 | 183 | 184 | def main(): 185 | args = parse_args() 186 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 187 | 188 | accelerator = Accelerator( 189 | gradient_accumulation_steps=args.gradient_accumulation_steps, 190 | mixed_precision=args.mixed_precision, 191 | log_with="tensorboard", 192 | logging_dir=logging_dir 193 | #kwargs_handlers=[DistributedDataParallelKwargs(find_unused_parameters=True)] 194 | ) 195 | 196 | 197 | if accelerator.is_main_process: 198 | os.makedirs(args.output_dir, exist_ok=True) 199 | 200 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer",use_auth_token=True) 201 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder",use_auth_token=True) 202 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",use_auth_token=True) 203 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",use_auth_token=True) 204 | ######### id_encoder 205 | id_encoder=IDEncoder(unet_model=args.pretrained_model_name_or_path) 206 | weight_dtype = torch.float32 207 | if accelerator.mixed_precision == "fp16": 208 | weight_dtype = torch.float16 209 | elif accelerator.mixed_precision == "bf16": 210 | weight_dtype = torch.bfloat16 211 | for i in [unet,vae,text_encoder]: 212 | freeze_params(i.parameters()) 213 | i.eval() 214 | i.to(accelerator.device) 215 | lora_attn_procs = {} 216 | for name in unet.attn_processors.keys(): 217 | cross_attention_dim = None if 
name.endswith("attn1.processor") else unet.config.cross_attention_dim 218 | if name.startswith("mid_block"): 219 | hidden_size = unet.config.block_out_channels[-1] 220 | elif name.startswith("up_blocks"): 221 | block_id = int(name[len("up_blocks.")]) 222 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 223 | elif name.startswith("down_blocks"): 224 | block_id = int(name[len("down_blocks.")]) 225 | hidden_size = unet.config.block_out_channels[block_id] 226 | 227 | lora_attn_procs[name] = LoRACrossAttnProcessor( 228 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim,rank=args.lora_rank 229 | ) 230 | 231 | unet.set_attn_processor(lora_attn_procs) 232 | lora_layers = AttnProcsLayers(unet.attn_processors) 233 | accelerator.register_for_checkpointing(lora_layers) 234 | 235 | if args.scale_lr: 236 | args.learning_rate = ( 237 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 238 | ) 239 | 240 | # Initialize the optimizer 241 | optimizer = torch.optim.AdamW( 242 | itertools.chain(id_encoder.parameters(),lora_layers.parameters()), 243 | lr=args.learning_rate, 244 | betas=(args.adam_beta1, args.adam_beta2), 245 | weight_decay=args.adam_weight_decay, 246 | eps=args.adam_epsilon, 247 | ) 248 | 249 | noise_scheduler = DDPMScheduler.from_config(args.pretrained_model_name_or_path, subfolder="scheduler") 250 | 251 | train_dataset = Simpledataset( 252 | args.images_dir, 253 | tokenizer, 254 | id_encoder.feature_extractor.preprocess, 255 | size=args.resolution, 256 | placeholder_token=args.placeholder_token 257 | ) 258 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) 259 | 260 | # Scheduler and math around the number of training steps. 261 | overrode_max_train_steps = False 262 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 263 | if args.max_train_steps is None: 264 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 265 | overrode_max_train_steps = True 266 | 267 | lr_scheduler = get_scheduler( 268 | args.lr_scheduler, 269 | optimizer=optimizer, 270 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 271 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 272 | ) 273 | 274 | id_encoder,lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 275 | id_encoder,lora_layers, optimizer, train_dataloader, lr_scheduler 276 | ) 277 | id_encoder.module.feature_extractor.set_device(accelerator.device, dtype=weight_dtype) 278 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
279 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 280 | if overrode_max_train_steps: 281 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 282 | # Afterwards we recalculate our number of training epochs 283 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 284 | global_step = 0 285 | first_epoch = 0 286 | if args.resume_from_checkpoint: 287 | if args.resume_from_checkpoint != "latest": 288 | path = os.path.basename(args.resume_from_checkpoint) 289 | else: 290 | # Get the mos recent checkpoint 291 | dirs = os.listdir(args.output_dir) 292 | dirs = [d for d in dirs if d.startswith("checkpoint")] 293 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 294 | path = dirs[-1] if len(dirs) > 0 else None 295 | 296 | if path is None: 297 | accelerator.print( 298 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 299 | ) 300 | args.resume_from_checkpoint = None 301 | else: 302 | accelerator.print(f"Resuming from checkpoint {path}") 303 | accelerator.load_state(os.path.join(args.output_dir, path)) 304 | global_step = int(path.split("-")[1]) 305 | resume_global_step = global_step * args.gradient_accumulation_steps 306 | first_epoch = global_step // num_update_steps_per_epoch 307 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 308 | 309 | torch.cuda.empty_cache() 310 | #accelerator.free_memory() 311 | # We need to initialize the trackers we use, and also store our configuration. 312 | # The trackers initializes automatically on the main process. 313 | if accelerator.is_main_process: 314 | accelerator.init_trackers("personalization_encoder", config=vars(args)) 315 | 316 | # Train! 317 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 318 | 319 | logger.info("***** Running training *****") 320 | logger.info(f" Num examples = {len(train_dataset)}") 321 | logger.info(f" Num Epochs = {args.num_train_epochs}") 322 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 323 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 324 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 325 | logger.info(f" Total optimization steps = {args.max_train_steps}") 326 | # Only show the progress bar once on each machine. 
327 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 328 | progress_bar.set_description("Steps") 329 | for epoch in range(first_epoch,args.num_train_epochs): 330 | unet.train() 331 | id_encoder.train() 332 | for step, batch in enumerate(train_dataloader): 333 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 334 | if step % args.gradient_accumulation_steps == 0: 335 | progress_bar.update(1) 336 | continue 337 | losses=train_step(accelerator,batch,vae,noise_scheduler,id_encoder,text_encoder,unet,optimizer,lr_scheduler,args.reg_weight) 338 | progress_bar.update(1) 339 | global_step += 1 340 | if global_step % args.save_steps == 0: 341 | if accelerator.is_main_process: 342 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 343 | accelerator.save_state(save_path) 344 | logger.info(f"Saved state to {save_path}") 345 | 346 | logs = {"loss_simple": losses['loss_simple'],"loss_reg": losses['loss_reg'], "lr": lr_scheduler.get_last_lr()[0]} 347 | progress_bar.set_postfix(**logs) 348 | if global_step >= args.max_train_steps: 349 | break 350 | accelerator.wait_for_everyone() 351 | 352 | # Create the pipeline using using the trained modules and save it. 353 | if accelerator.is_main_process: 354 | save_path=os.path.join(args.output_dir,f"checkpoint-{global_step}") 355 | os.makedirs(save_path,exist_ok=True) 356 | accelerator.unwrap_model(unet).save_attn_procs(os.path.join(save_path,"LORA_module")) 357 | state_dict = accelerator.unwrap_model(id_encoder).state_dict() 358 | torch.save(state_dict, os.path.join(save_path, "id_encoder.pth")) 359 | accelerator.end_training() 360 | 361 | def load_model(path,model="stabilityai/stable-diffusion-2-1",final_ckpt=True): 362 | id_encoder=IDEncoder(unet_model=model) 363 | pipe = DiffusionPipeline.from_pretrained(model) 364 | if final_ckpt: 365 | id_encoder.load_state_dict(torch.load(os.path.join(path,"id_encoder.pth"),map_location='cpu')) 366 | pipe.unet.load_attn_procs(os.path.join(path,"LORA_module")) 367 | else: 368 | id_encoder.load_state_dict(torch.load(os.path.join(path,"pytorch_model.bin"),map_location='cpu')) 369 | pipe.unet.load_attn_procs(os.path.join(path,"pytorch_model_1.bin")) 370 | return id_encoder,pipe 371 | 372 | def finetune(image,pipe,id_encoder,mixed_precision='no',learning_rate=1e-6,train_batch_size=1,train_steps=15,\ 373 | text='a photo of person',placeholder_token='person',resize=768,prompts=None,output_dir='',num_samples=2,train_text_encoder=False,reg_weight=0.1): 374 | tokenizer=pipe.tokenizer 375 | text_encoder=pipe.text_encoder 376 | vae=pipe.vae.eval() 377 | unet=pipe.unet 378 | accelerator = Accelerator(mixed_precision=mixed_precision) 379 | raw_image=image 380 | with torch.no_grad(): 381 | image = image.resize((resize, resize)) 382 | image = np.array(image).astype(np.uint8) 383 | image = (image / 127.5 - 1.0).astype(np.float32) 384 | image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0) 385 | latents = vae.encode(image).latent_dist.sample().repeat(train_batch_size,1,1,1) 386 | latents = latents * 0.18215 387 | input_ids=tokenizer( 388 | text, 389 | padding="max_length", 390 | truncation=True, 391 | max_length=tokenizer.model_max_length, 392 | return_tensors="pt", 393 | ).input_ids.repeat(train_batch_size,1) 394 | place_holder_id=tokenizer.encode(placeholder_token)[1] 395 | text_encoder_states=text_encoder(input_ids)[0] 396 | input_placeholder_pos=input_ids==place_holder_id 397 | weight_dtype = torch.float32 398 | if 
accelerator.mixed_precision == "fp16": 399 | weight_dtype = torch.float16 400 | elif accelerator.mixed_precision == "bf16": 401 | weight_dtype = torch.bfloat16 402 | for module in [unet, text_encoder,id_encoder]: 403 | module.train() 404 | for param in module.parameters(): 405 | param.requires_grad = True 406 | unet.enable_gradient_checkpointing() 407 | #learning_rate = learning_rate *train_batch_size * accelerator.num_processes 408 | optimizer = torch.optim.AdamW(itertools.chain(id_encoder.parameters(),unet.parameters(),text_encoder.text_model.encoder.parameters(),text_encoder.text_model.final_layer_norm.parameters())\ 409 | ,lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-2,eps=1e-8) 410 | noise_scheduler = DDPMScheduler.from_config(pipe.scheduler.config) 411 | image_features=id_encoder.feature_extractor.encode_image(id_encoder.feature_extractor.preprocess_images(raw_image)).repeat(train_batch_size,1) 412 | id_encoder.feature_extractor.set_device(accelerator.device, dtype=weight_dtype,only_unet=True) 413 | if train_text_encoder: 414 | text_encoder.gradient_checkpointing_enable() 415 | id_encoder,unet,encoder_layers,final_layer_norm, optimizer = accelerator.prepare(id_encoder,unet,text_encoder.text_model.encoder,text_encoder.text_model.final_layer_norm,optimizer) 416 | text_encoder.text_model.encoder=encoder_layers 417 | text_encoder.text_model.final_layer_norm=final_layer_norm 418 | pipe.text_encoder.text_model.embeddings.to(device=accelerator.device,dtype=weight_dtype) 419 | else: 420 | id_encoder,unet, optimizer = accelerator.prepare(id_encoder,unet,optimizer) 421 | text_encoder.eval().to(device=accelerator.device) 422 | for param in text_encoder.parameters(): 423 | param.requires_grad = False 424 | batch=dict(latents=latents.to(accelerator.device),text_encoder_states=text_encoder_states.to(accelerator.device),input_ids=input_ids.to(accelerator.device),\ 425 | input_placeholder_pos=input_placeholder_pos.to(accelerator.device),image_features=image_features.to(accelerator.device)) 426 | progress_bar = tqdm(range(train_steps), disable=not accelerator.is_local_main_process) 427 | progress_bar.set_description("Steps") 428 | for step in range(train_steps): 429 | loss=train_step(accelerator,batch,vae,noise_scheduler,id_encoder,text_encoder,unet,optimizer,None,reg_weight=reg_weight) 430 | progress_bar.update(1) 431 | progress_bar.set_postfix(**loss) 432 | accelerator.wait_for_everyone() 433 | if train_text_encoder: 434 | pipe.text_encoder.text_model.encoder=accelerator.unwrap_model(text_encoder.text_model.encoder).eval().to(accelerator.device,dtype=weight_dtype) 435 | pipe.text_encoder.text_model.final_layer_norm=accelerator.unwrap_model(text_encoder.text_model.final_layer_norm).eval().to(accelerator.device,dtype=weight_dtype) 436 | else: 437 | text_encoder.to(dtype=weight_dtype) 438 | pipe.unet=accelerator.unwrap_model(unet).eval().to(accelerator.device,dtype=weight_dtype) 439 | id_encoder=accelerator.unwrap_model(id_encoder).eval().to(accelerator.device,dtype=weight_dtype) 440 | vae.to(device=accelerator.device,dtype=weight_dtype) 441 | if type(prompts)==str: 442 | sample(id_encoder,pipe,(resize,resize),prompts,accelerator,num_samples,output_dir,place_holder_id,batch['image_features'][:1].to(dtype=weight_dtype),batch['latents'][:1].to(dtype=weight_dtype)) 443 | elif type(prompts)==list: 444 | for prompt in prompts: 445 | 
sample(id_encoder,pipe,(resize,resize),prompt,accelerator,num_samples,output_dir,place_holder_id,batch['image_features'][:1].to(dtype=weight_dtype),batch['latents'][:1].to(dtype=weight_dtype)) 446 | 447 | @torch.no_grad() 448 | def sample(id_encoder,pipe,image_size,prompt,accelerator,num_samples,output_dir,place_holder_id,image_features,source_latents,guidance_scale=7.5,num_inference_steps=100): 449 | #pipe.scheduler=DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 450 | height,width = image_size 451 | if prompt is not None and isinstance(prompt, str): 452 | prompt=[prompt] 453 | if prompt is not None and isinstance(prompt, list): 454 | batch_size = len(prompt) 455 | device = accelerator.device 456 | weight_dtype=pipe.text_encoder.dtype 457 | text_inputs = pipe.tokenizer( 458 | prompt, 459 | padding="max_length", 460 | max_length=pipe.tokenizer.model_max_length, 461 | truncation=True, 462 | return_tensors="pt", 463 | ).input_ids 464 | text_inputs=torch.repeat_interleave(text_inputs,num_samples,dim=0) 465 | do_classifier_free_guidance = guidance_scale > 1.0 466 | if do_classifier_free_guidance: 467 | uncond_tokens = [""] * batch_size 468 | uncond_input = pipe.tokenizer( 469 | uncond_tokens, 470 | padding="max_length", 471 | max_length=pipe.tokenizer.model_max_length, 472 | truncation=True, 473 | return_tensors="pt", 474 | ) 475 | negative_prompt_embeds = torch.repeat_interleave(pipe.text_encoder(uncond_input.input_ids.to(device),None)[0],num_samples,dim=0) 476 | pipe.scheduler.set_timesteps(num_inference_steps, device=device) 477 | timesteps = pipe.scheduler.timesteps 478 | num_warmup_steps = len(timesteps) - num_inference_steps * pipe.scheduler.order 479 | num_channels_latents = pipe.unet.in_channels 480 | latents = pipe.prepare_latents(batch_size * num_samples,num_channels_latents,height,width,weight_dtype,device,None) 481 | source_latents=source_latents.repeat(batch_size * num_samples,1,1,1) 482 | input_placeholder_pos=(text_inputs==place_holder_id).to(device) #BxL 483 | if accelerator.is_main_process: 484 | print(input_placeholder_pos) 485 | os.makedirs(output_dir,exist_ok=True) 486 | prompt_embeds = pipe.text_encoder(text_inputs.to(device),None)[0] 487 | batch_input=dict(image_features=image_features.repeat(num_samples,1)) 488 | extra_step_kwargs = pipe.prepare_extra_step_kwargs(None, 0.0) 489 | noise = torch.randn(latents.shape,device=latents.device,dtype=latents.dtype) 490 | with pipe.progress_bar(total=num_inference_steps) as progress_bar: 491 | for i, t in enumerate(timesteps): 492 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 493 | latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t) 494 | #id_f=id_encoder(batch_input,pipe.scheduler.add_noise(source_latents, noise, t),t,prompt_embeds) #BxD 495 | id_f=id_encoder(batch_input,latents,t,prompt_embeds)*0.1 ##mul 0.1 496 | if do_classifier_free_guidance: 497 | noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=torch.cat([negative_prompt_embeds, prompt_embeds+id_f.unsqueeze(1)*input_placeholder_pos.unsqueeze(-1)])).sample 498 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 499 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 500 | else: 501 | noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds+id_f.unsqueeze(1)*input_placeholder_pos.unsqueeze(-1)).sample 502 | latents = pipe.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample 503 | if i 
== len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % pipe.scheduler.order == 0): 504 | progress_bar.update() 505 | image = pipe.decode_latents(latents) 506 | image = pipe.numpy_to_pil(image) 507 | for i in range(batch_size): 508 | for j in range(num_samples): 509 | rank=num_samples*accelerator.process_index+j 510 | image[i*num_samples+j].save(os.path.join(output_dir,f'{prompt[i]}-{rank}.png')) 511 | 512 | 513 | if __name__ == "__main__": 514 | main() 515 | -------------------------------------------------------------------------------- /train_encoder.sh: -------------------------------------------------------------------------------- 1 | accelerate launch train.py --pretrained_model_name_or_path "stabilityai/stable-diffusion-2-1" --images_dir /public/zhaohanqing/dataset/ffhq/ --lr_scheduler constant_with_warmup \ 2 | --train_batch_size 1 --resolution 768 --scale_lr --output_dir new_train_ffhq --num_train_epochs 10 --save_steps 5000 --resume_from_checkpoint new_train_ffhq/checkpoint-115000 -------------------------------------------------------------------------------- /train_encoder_chill.sh: -------------------------------------------------------------------------------- 1 | accelerate launch train.py --pretrained_model_name_or_path "/ssd/zhaohanqing/msws/diffusers/examples/model_lab/chilloutmix" --images_dir /public/zhaohanqing/dataset/ffhq/ --lr_scheduler constant_with_warmup \ 2 | --train_batch_size 5 --resolution 512 --scale_lr --output_dir chillout_train_ffhq --num_train_epochs 10 --save_steps 10000 --learning_rate 1.6e-6 --lr_scheduler cosine_with_restarts --reg_weight 0.01 --lora_rank 64 --placeholder_token face --------------------------------------------------------------------------------