├── LICENSE ├── README.md ├── attention_control.py ├── guidance_config_example.json ├── inference.py ├── layout_guidance.py ├── pics ├── 2_subjects.jpg ├── 3_subjects.jpg ├── 4_subjects.jpg ├── arch.jpg ├── architecture.png ├── layout_example.png ├── more.png ├── pipeline.gif └── teaser.png ├── requirements.txt ├── train_cones2.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 DAMO Vision Intelligence Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Cones 2 2 | 3 | Official repo for [Cones 2: Customizable Image Synthesis with Multiple Subjects](https://arxiv.org/abs/2305.19327) | [Project Page](https://cones-page.github.io/). 4 |
10 | 11 | [Cones 2](https://cones-page.github.io/) allows you to represent a specific subject as a **residual embedding** by 12 | fine-tuning the text encoder of a pre-trained text-to-image diffusion model, such as 13 | [Stable Diffusion](https://github.com/CompVis/stable-diffusion). After tuning, we only need to save the residual between the tuned text encoder and the frozen 14 | one, so the storage space required for each additional subject is only **5 KB**. This step takes about 20~30 minutes per subject on a single 80 GB A100 15 | GPU. 16 | 17 | When sampling, our **layout guidance sampling** method further allows you to employ an easy-to-obtain layout as guidance for arranging multiple 18 | subjects, as shown in the following figure. 19 | 20 |

21 | *(figure: layout-guided arrangement of multiple customized subjects)* 22 | 23 |
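To make the **5 KB** claim above concrete, the following is a minimal sketch with dummy tensors (not the actual training script) of what "saving the residual" amounts to. It condenses what `train_cones2.py` does after training: average the tuned-minus-frozen text-encoder difference at the subject token position over a set of prompt templates and save that one small tensor.

```python
# Minimal sketch with dummy tensors -- not the real training code.
import torch

def compute_residual(tuned_token_embeds, frozen_token_embeds):
    """Average the tuned-minus-frozen difference at the subject token position."""
    diffs = [t - f for t, f in zip(tuned_token_embeds, frozen_token_embeds)]
    return torch.stack(diffs).mean(dim=0)  # shape: (token_num, 1024)

# One (token_num, 1024) slice per prompt template; random values for illustration.
tuned = [torch.randn(1, 1024) for _ in range(8)]
frozen = [torch.randn(1, 1024) for _ in range(8)]

residual = compute_residual(tuned, frozen)
torch.save(residual, "residual.pt")  # one (1, 1024) float32 tensor, roughly 5 KB on disk
```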

24 | 25 | ## Results 26 | 27 | All results are synthesized by the pre-trained [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) 28 | model and our customized residual embeddings. We show diverse results on various categories of images, including 29 | scenes, pets, personal toys, humans, _etc._ For more results, please refer to our [paper](https://arxiv.org/abs/2305.19327) or [website](https://cones-page.github.io/). 30 | 31 | ### Two-Subject Results 32 | ![two subject cases](pics/2_subjects.jpg "two subject cases") 33 | 34 | ### Three-Subject Results 35 | ![three subject cases](pics/3_subjects.jpg "three subject cases") 36 | 37 | ### Four-Subject Results 38 | ![four subject cases](pics/4_subjects.jpg "four subject cases") 39 | 40 | ### More Results 41 | ![more challenging cases](pics/more.png "more challenging cases") 42 | 43 | ## Method 44 | 45 | ![method](pics/arch.jpg "method") 46 | 47 | (a) Given few-shot images of the customized subject, we fine-tune the text encoder to learn a residual embedding on top 48 | of the base embedding of the raw subject. (b) Based on the residual embeddings, we then employ the layout as 49 | spatial guidance for subject arrangement, injecting it into the attention maps. This lets us strengthen the signal of target 50 | subjects and weaken the signal of irrelevant subjects. For more details, please refer to our [paper](https://arxiv.org/abs/2305.19327). 51 | 52 | ## Getting Started 53 | 54 | ### Installing the dependencies 55 | 56 | The implementation of Cones 2 is built entirely on [diffusers](https://github.com/huggingface/diffusers/tree/main). 57 | Before running our code, make sure to install the library's training dependencies. To do this, execute the following 58 | steps in a new virtual environment: 59 | 60 | ```bash 61 | git clone https://github.com/damo-vilab/Cones-V2.git 62 | cd Cones-V2 63 | pip install -r requirements.txt 64 | ``` 65 | 66 | Then initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with: 67 | 68 | ```bash 69 | accelerate config 70 | ``` 71 | 72 | Or, for a default Accelerate configuration without answering questions about your environment: 73 | 74 | ```bash 75 | accelerate config default 76 | ``` 77 | 78 | ## Training (Flower example) 79 | 80 | First, download the dataset from 81 | [here](https://modelscope.cn/api/v1/datasets/zyf619/cones2_residual/repo?Revision=master&FilePath=data.zip) 82 | and unzip it to `./data`. A few images of a flower (now in `./data/flower`) are used to learn its 83 | customized residual embedding. 84 | 85 | ```bash 86 | export MODEL_NAME='path-to-stable-diffusion-v2-1' 87 | export INSTANCE_DIR="./data/flower" 88 | export OUTPUT_DIR="path-to-save-model" 89 | accelerate launch train_cones2.py \ 90 | --pretrained_model_name_or_path=$MODEL_NAME \ 91 | --instance_data_dir=$INSTANCE_DIR \ 92 | --instance_prompt="flower" \ 93 | --token_num=1 \ 94 | --output_dir=$OUTPUT_DIR \ 95 | --resolution=768 \ 96 | --train_batch_size=1 \ 97 | --gradient_accumulation_steps=1 \ 98 | --learning_rate=5e-6 \ 99 | --lr_scheduler="constant" \ 100 | --lr_warmup_steps=0 \ 101 | --max_train_steps=4000 \ 102 | --loss_rate_first=1e-2 \ 103 | --loss_rate_second=1e-3 104 | ``` 105 | 106 | ## Inference 107 | 108 | Once you have trained several residual embeddings of different subjects using the above command, you can run our 109 | **layout guidance sampling method** simply using [inference.py](inference.py).
We provide several pre-trained 110 | models for quick validation. 111 |
| residual embedding | raw token | download |
| --- | --- | --- |
| barn | barn | barn.pt |
| white dog | dog | dog.pt |
| flower | flower | flower.pt |
| lake | lake | lake.pt |
| mug | mug | mug.pt |
| sunglasses | sunglasses | sunglasses.pt |
| bag | bag | bag.pt |
| robot | robot toy | robot_toy.pt |
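For reference, here is a minimal sketch, with dummy shapes, of how such a residual file is consumed; it mirrors the embedding edit in [layout_guidance.py](layout_guidance.py), and `inference.py` already performs it for you. The path assumes you downloaded `flower.pt` into `residuals/`.

```python
import torch

# Dummy stand-in for the frozen text encoder's output for the prompt:
# (batch, sequence_length, hidden_size) = (1, 77, 1024) for Stable Diffusion v2.1.
cond_embeddings = torch.zeros(1, 77, 1024)

token_index = 2                               # subject token position, cf. "subject_list" in the json config
residual = torch.load("residuals/flower.pt")  # the downloaded residual, shape (token_num, 1024)
cond_embeddings[0][token_index] += residual.reshape(1024)  # token_num = 1 here
```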
159 | 160 | 161 | Remember to provide a pre-defined layout like [layout_example.png](pics/layout_example.png) and a 162 | [json](guidance_config_example.json) file that details the inference settings. The 163 | [json](guidance_config_example.json) file should include the following information: 164 | 165 | - "prompt": the text prompt you want to generate. 166 | - "residual_dict": the paths to all the required residual embeddings. 167 | - "color_context": the color of each region in the layout and its corresponding subject, along with 168 | the weight for strengthening the signal of the target subject (default: 2.5). 169 | - "guidance_steps": the number of steps of the layout guidance. 170 | - "guidance_weight": the strength of the layout guidance (default: 0.08, we recommend 0.05 ~ 0.10). 171 | - "weight_negative": the weight for weakening the signal of irrelevant subjects. 172 | - "layout": the path to the user-defined layout image (a sketch for drawing one from scratch is given at the end of this README). 173 | - "subject_list": the list of all subjects to be customized and their corresponding token positions in the prompt. 174 | 175 | Then you can simply run the inference script with: 176 | 177 | ```bash 178 | python inference.py --pretrained_model_name_or_path /path/to/stable-diffusion-2-1 --inference_config guidance_config_example.json 179 | ``` 180 | 181 | ## References 182 | 183 | ```BibTeX 184 | @article{liu2023cones, 185 | title={Cones 2: Customizable Image Synthesis with Multiple Subjects}, 186 | author={Liu, Zhiheng and Zhang, Yifei and Shen, Yujun and Zheng, Kecheng and Zhu, Kai and Feng, Ruili and Liu, Yu and Zhao, Deli and Zhou, Jingren and Cao, Yang}, 187 | journal={arXiv preprint arXiv:2305.19327}, 188 | year={2023} 189 | } 190 | ``` 191 | 192 | ### Acknowledgements 193 | We thank [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) and [diffusers](https://github.com/huggingface/diffusers/tree/main) for providing the pre-trained model and an open-source codebase.
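### Creating a layout image (sketch)

If you prefer to draw a layout from scratch instead of editing [layout_example.png](pics/layout_example.png), the snippet below is a minimal, hypothetical sketch (the box coordinates and output path are made up): it only needs to produce solid regions whose exact RGB values match the keys of "color_context" in [guidance_config_example.json](guidance_config_example.json).

```python
# Hypothetical helper, not part of the repo: draw a 768x768 layout whose
# region colors exactly match the "color_context" keys in the json config.
from PIL import Image, ImageDraw

layout = Image.new("RGB", (768, 768), (255, 255, 255))   # white = unassigned background
draw = ImageDraw.Draw(layout)
draw.rectangle([40, 300, 330, 640], fill=(255, 192, 0))  # region for "mug" (arbitrary box)
draw.rectangle([400, 260, 730, 700], fill=(255, 0, 0))   # region for "dog" (arbitrary box)
layout.save("my_layout.png")                             # point "layout" in the json here
```

The region colors are matched exactly (see `_image_context_seperator` in [layout_guidance.py](layout_guidance.py)), so save the layout as PNG and use flat fills rather than anti-aliased gradients.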
194 | -------------------------------------------------------------------------------- /attention_control.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from diffusers.models.cross_attention import CrossAttention 7 | 8 | 9 | class Cones2AttnProcessor: 10 | def __init__(self): 11 | super().__init__() 12 | 13 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): 14 | batch_size, sequence_length, _ = hidden_states.shape 15 | query = attn.to_q(hidden_states) 16 | is_dict_format = True 17 | if encoder_hidden_states is not None: 18 | try: 19 | encoder_hidden = encoder_hidden_states["CONDITION_TENSOR"] 20 | except: 21 | encoder_hidden = encoder_hidden_states 22 | is_dict_format = False 23 | if attn.cross_attention_norm: 24 | encoder_hidden = attn.norm_cross(encoder_hidden) 25 | else: 26 | encoder_hidden = hidden_states 27 | 28 | key = attn.to_k(encoder_hidden) 29 | value = attn.to_v(encoder_hidden) 30 | 31 | query = attn.head_to_batch_dim(query) 32 | key = attn.head_to_batch_dim(key) 33 | value = attn.head_to_batch_dim(value) 34 | 35 | attention_scores = torch.matmul(query, key.transpose(-1, -2)) 36 | attention_size_of_img = attention_scores.size()[-2] 37 | 38 | if attention_scores.size()[2] == 77: 39 | if is_dict_format: 40 | f = encoder_hidden_states["function"] 41 | try: 42 | w = encoder_hidden_states[f"CA_WEIGHT_{attention_size_of_img}"] 43 | except KeyError: 44 | w = encoder_hidden_states[f"CA_WEIGHT_ORIG"] 45 | if not isinstance(w, int): 46 | img_h, img_w, nc = w.shape 47 | ratio = math.sqrt(img_h * img_w / attention_size_of_img) 48 | w = F.interpolate(w.permute(2, 0, 1).unsqueeze(0), scale_factor=1 / ratio, mode="bilinear", 49 | align_corners=True) 50 | w = F.interpolate(w.reshape(1, nc, -1), size=(attention_size_of_img,), mode='nearest').permute( 51 | 2, 1, 0).squeeze() 52 | else: 53 | w = 0 54 | if type(w) is int and w == 0: 55 | sigma = encoder_hidden_states["SIGMA"] 56 | cross_attention_weight = f(w, sigma, attention_scores) 57 | else: 58 | bias = torch.zeros_like(w) 59 | bias[torch.where(w > 0)] = attention_scores.std() * 0 60 | sigma = encoder_hidden_states["SIGMA"] 61 | cross_attention_weight = f(w, sigma, attention_scores) 62 | cross_attention_weight = cross_attention_weight + bias 63 | else: 64 | cross_attention_weight = 0.0 65 | else: 66 | cross_attention_weight = 0.0 67 | 68 | attention_scores = (attention_scores + cross_attention_weight) * attn.scale 69 | attention_probs = attention_scores.softmax(dim=-1) 70 | 71 | hidden_states = torch.matmul(attention_probs, value) 72 | hidden_states = attn.batch_to_head_dim(hidden_states) 73 | 74 | # linear proj 75 | hidden_states = attn.to_out[0](hidden_states) 76 | # dropout 77 | hidden_states = attn.to_out[1](hidden_states) 78 | 79 | return hidden_states 80 | 81 | 82 | def register_attention_control(unet): 83 | attn_procs = {} 84 | for name in unet.attn_processors.keys(): 85 | attn_procs[name] = Cones2AttnProcessor() 86 | 87 | unet.set_attn_processor(attn_procs) 88 | -------------------------------------------------------------------------------- /guidance_config_example.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "prompt":"a mug and a dog on the beach", 4 | "residual_dict": { 5 | "dog":"residuals/dog.pt", 6 | "mug":"residuals/mug.pt", 7 | "flower":"residuals/flower.pt", 8 | 
"sunglasses":"residuals/sunglasses.pt", 9 | "lake":"residuals/lake.pt", 10 | "barn":"residuals/barn.pt" 11 | }, 12 | "color_context":{ 13 | "255,192,0":["mug",2.5], 14 | "255,0,0":["dog",2.5] 15 | }, 16 | "guidance_steps":50, 17 | "guidance_weight":0.08, 18 | "weight_negative":-1e8, 19 | "layout":"layouts/layout_example.png", 20 | "subject_list":[["mug",2],["dog",5]] 21 | } 22 | ] -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | 5 | from PIL import Image 6 | 7 | from diffusers import StableDiffusionPipeline 8 | 9 | from layout_guidance import layout_guidance_sampling 10 | from utils import image_grid 11 | 12 | 13 | def get_args(): 14 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter) 15 | parser.add_argument( 16 | "--pretrained_model_name_or_path", 17 | type=str, 18 | default=None, 19 | required=True, 20 | help="Path to pretrained model or model identifier from huggingface.co/models.", 21 | ) 22 | parser.add_argument( 23 | "--inference_config", 24 | type=str, 25 | default=None, 26 | required=True, 27 | help='Path to a json file containing settings for inference, containing "residual_path", "prompt", ' 28 | '"color_context", "edit_tokens", "layout", "subject_list".', 29 | ) 30 | parser.add_argument( 31 | "--output_dir", 32 | type=str, 33 | default="inference_results", 34 | help="The output directory where the model predictions and checkpoints will be written.", 35 | ) 36 | return parser.parse_args() 37 | 38 | 39 | def main(args): 40 | # Initialize pre-trained Stable Diffusion pipeline. 41 | pipeline = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path).to("cuda") 42 | 43 | # Load the settings required for inference from the configuration file. 
44 | with open(args.inference_config, "r") as f: 45 | inference_cfg = json.load(f) 46 | 47 | prompt = inference_cfg[0]["prompt"] 48 | residual_dict = inference_cfg[0]["residual_dict"] 49 | subject_list = inference_cfg[0]["subject_list"] 50 | guidance_steps = inference_cfg[0]["guidance_steps"] 51 | guidance_weight = inference_cfg[0]["guidance_weight"] 52 | weight_negative = inference_cfg[0]["weight_negative"] 53 | layout = Image.open(inference_cfg[0]["layout"]).resize((768, 768)).convert("RGB") 54 | color_context = inference_cfg[0]["color_context"] 55 | subject_color_dict = {tuple(map(int, key.split(','))): value for key, value in color_context.items()} 56 | 57 | if args.output_dir is not None: 58 | os.makedirs(args.output_dir, exist_ok=True) 59 | subject_info = '_'.join([s[0] for s in sorted(subject_list)]) 60 | prompt_info = '_'.join(prompt.split()) 61 | save_dir = os.path.join(args.output_dir, subject_info, prompt_info) 62 | os.makedirs(save_dir, exist_ok=True) 63 | 64 | images = [] 65 | 66 | for i in range(4): 67 | image = layout_guidance_sampling( 68 | seed=i, 69 | device="cuda:0", 70 | resolution=768, 71 | pipeline=pipeline, 72 | prompt=prompt, 73 | residual_dict=residual_dict, 74 | subject_list=subject_list, 75 | subject_color_dict=subject_color_dict, 76 | layout=layout, 77 | cfg_scale=7.5, 78 | inference_steps=50, 79 | guidance_steps=guidance_steps, 80 | guidance_weight=guidance_weight, 81 | weight_negative=weight_negative, 82 | ) 83 | 84 | image.save(os.path.join(save_dir, f"{i}.png")) 85 | images.append(image) 86 | 87 | all_image = image_grid(images=images, rows=2, cols=2) 88 | all_image.save(os.path.join(save_dir, f"all_images.png")) 89 | 90 | 91 | if __name__ == "__main__": 92 | args = get_args() 93 | main(args) 94 | -------------------------------------------------------------------------------- /layout_guidance.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | 5 | from tqdm.auto import tqdm 6 | 7 | from diffusers import LMSDiscreteScheduler 8 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput 9 | 10 | from attention_control import register_attention_control 11 | from utils import latents_to_images, downsampling 12 | 13 | 14 | @torch.no_grad() 15 | def layout_guidance_sampling(seed, 16 | device, 17 | resolution, 18 | pipeline, 19 | prompt="", 20 | residual_dict=None, 21 | subject_list=None, 22 | subject_color_dict=None, 23 | layout=None, 24 | cfg_scale=7.5, 25 | inference_steps=50, 26 | guidance_steps=50, 27 | guidance_weight=0.05, 28 | weight_negative=-1e8): 29 | vae = pipeline.vae 30 | unet = pipeline.unet 31 | text_encoder = pipeline.text_encoder 32 | tokenizer = pipeline.tokenizer 33 | unconditional_input_prompt = "" 34 | scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) 35 | scheduler.set_timesteps(inference_steps, device=device) 36 | if guidance_steps > 0: 37 | guidance_steps = min(guidance_steps, inference_steps) 38 | scheduler_guidance = LMSDiscreteScheduler( 39 | beta_start=0.00085, 40 | beta_end=0.012, 41 | beta_schedule="scaled_linear", 42 | num_train_timesteps=1000, 43 | ) 44 | scheduler_guidance.set_timesteps(guidance_steps, device=device) 45 | 46 | # Process input prompt text 47 | text_input = tokenizer( 48 | [prompt], 49 | padding="max_length", 50 | max_length=tokenizer.model_max_length, 51 | truncation=True, 52 | return_tensors="pt", 53 | ) 54 | 55 | # Edit text embedding conditions with 
residual token embeddings. 56 | cond_embeddings = text_encoder(text_input.input_ids.to(device))[0] 57 | for name, token in subject_list: 58 | residual_token_embedding = torch.load(residual_dict[name]) 59 | cond_embeddings[0][token] += residual_token_embedding.reshape(1024) 60 | 61 | # Process unconditional input "" for classifier-free guidance. 62 | max_length = text_input.input_ids.shape[-1] 63 | uncond_input = tokenizer( 64 | [unconditional_input_prompt], 65 | padding="max_length", 66 | max_length=max_length, 67 | return_tensors="pt", 68 | ) 69 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0] 70 | 71 | register_attention_control(unet) 72 | 73 | # Calculate the hidden features for each cross attention layer. 74 | hidden_states, uncond_hidden_states = _extract_cross_attention(tokenizer, 75 | device, 76 | layout, 77 | subject_color_dict, 78 | text_input, 79 | weight_negative) 80 | hidden_states["CONDITION_TENSOR"] = cond_embeddings 81 | uncond_hidden_states["CONDITION_TENSOR"] = uncond_embeddings 82 | hidden_states["function"] = lambda w, x, qk: (guidance_weight * w * math.log(1 + x ** 2)) * qk.std() 83 | uncond_hidden_states["function"] = lambda w, x, qk: 0.0 84 | 85 | # Sampling the initial latents. 86 | latent_size = (1, unet.in_channels, resolution // 8, resolution // 8) 87 | latents = torch.randn(latent_size, generator=torch.manual_seed(seed)) 88 | latents = latents.to(device) 89 | latents = latents * scheduler.init_noise_sigma 90 | 91 | for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)): 92 | # Improve the harmony of generated images by self-recurrence. 93 | if i < guidance_steps: 94 | loop = 2 95 | else: 96 | loop = 1 97 | for k in range(loop): 98 | if i < guidance_steps: 99 | sigma = scheduler_guidance.sigmas[i] 100 | latent_model_input = scheduler.scale_model_input(latents, t) 101 | 102 | hidden_states.update({ 103 | "SIGMA": sigma, 104 | }) 105 | 106 | noise_pred_text = unet( 107 | latent_model_input, 108 | t, 109 | encoder_hidden_states=hidden_states, 110 | ).sample 111 | 112 | uncond_hidden_states.update({ 113 | "SIGMA": sigma, 114 | }) 115 | 116 | noise_pred_uncond = unet( 117 | latent_model_input, 118 | t, 119 | encoder_hidden_states=uncond_hidden_states, 120 | ).sample 121 | 122 | noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond) 123 | latents = scheduler.step(noise_pred, t, latents, 1).prev_sample 124 | 125 | # Self-recurrence. 
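# (During the guided steps each timestep is denoised twice: after the first pass, Gaussian
# noise scaled by sqrt(sigma_i^2 - sigma_{i+1}^2) is added back below, returning the latent
# to roughly the noise level of the current step so it can be denoised again.)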
126 | if k < 1 and loop > 1: 127 | noise_recurrent = torch.randn(latents.shape).to(device) 128 | noise_scale = ((scheduler.sigmas[i] ** 2 - scheduler.sigmas[i + 1] ** 2) ** 0.5) 129 | latents = latents + noise_scale * noise_recurrent 130 | else: 131 | latent_model_input = scheduler.scale_model_input(latents, t) 132 | noise_pred_text = unet( 133 | latent_model_input, 134 | t, 135 | encoder_hidden_states=cond_embeddings, 136 | ).sample 137 | 138 | latent_model_input = scheduler.scale_model_input(latents, t) 139 | 140 | noise_pred_uncond = unet( 141 | latent_model_input, 142 | t, 143 | encoder_hidden_states=uncond_embeddings, 144 | ).sample 145 | 146 | noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond) 147 | latents = scheduler.step(noise_pred, t, latents, 1).prev_sample 148 | 149 | edited_images = latents_to_images(vae, latents) 150 | 151 | return StableDiffusionPipelineOutput(images=edited_images, nsfw_content_detected=None).images[0] 152 | 153 | 154 | def _tokens_img_attention_weight(img_context_seperated, tokenized_texts, ratio: int = 8, original_shape=False): 155 | token_lis = tokenized_texts["input_ids"][0].tolist() 156 | w, h = img_context_seperated[0][1].shape 157 | 158 | w_r, h_r = round(w / ratio), round(h / ratio) 159 | ret_tensor = torch.zeros((w_r * h_r, len(token_lis)), dtype=torch.float32) 160 | for v_as_tokens, img_where_color in img_context_seperated: 161 | is_in = 0 162 | for idx, tok in enumerate(token_lis): 163 | if token_lis[idx: idx + len(v_as_tokens)] == v_as_tokens: 164 | is_in = 1 165 | 166 | ret_tensor[:, idx: idx + len(v_as_tokens)] += ( 167 | downsampling(img_where_color, w_r, h_r) 168 | .reshape(-1, 1) 169 | .repeat(1, len(v_as_tokens)) 170 | ) 171 | 172 | if not is_in == 1: 173 | print(f"Warning ratio {ratio} : tokens {v_as_tokens} not found in text") 174 | 175 | if original_shape: 176 | ret_tensor = ret_tensor.reshape((w_r, h_r, len(token_lis))) 177 | 178 | return ret_tensor 179 | 180 | 181 | def _image_context_seperator(img, color_context: dict, _tokenizer, neg: float): 182 | ret_lists = [] 183 | 184 | if img is not None: 185 | w, h = img.size 186 | matrix = np.zeros((h, w)) 187 | for color, v in color_context.items(): 188 | color = tuple(color) 189 | if len(color) > 3: 190 | color = color[:3] 191 | if isinstance(color, str): 192 | r, g, b = color[1:3], color[3:5], color[5:7] 193 | color = (int(r, 16), int(g, 16), int(b, 16)) 194 | img_where_color = (np.array(img) == color).all(axis=-1) 195 | matrix[img_where_color] = 1 196 | 197 | for color, (subject, weight_active) in color_context.items(): 198 | if len(color) > 3: 199 | color = color[:3] 200 | v_input = _tokenizer( 201 | subject, 202 | max_length=_tokenizer.model_max_length, 203 | truncation=True, 204 | ) 205 | 206 | v_as_tokens = v_input["input_ids"][1:-1] 207 | if isinstance(color, str): 208 | r, g, b = color[1:3], color[3:5], color[5:7] 209 | color = (int(r, 16), int(g, 16), int(b, 16)) 210 | img_where_color = (np.array(img) == color).all(axis=-1) 211 | matrix[img_where_color] = 1 212 | if not img_where_color.sum() > 0: 213 | print(f"Warning : not a single color {color} not found in image") 214 | 215 | img_where_color_init = torch.where(torch.tensor(img_where_color, dtype=torch.bool), weight_active, neg) 216 | 217 | img_where_color = torch.where(torch.from_numpy(matrix == 1) & (img_where_color_init == 0.0), 218 | torch.tensor(neg), img_where_color_init) 219 | 220 | # Add the image location corresponding to the token. 
221 | ret_lists.append((v_as_tokens, img_where_color)) 222 | else: 223 | w, h = 768, 768 224 | 225 | if len(ret_lists) == 0: 226 | ret_lists.append(([-1], torch.zeros((w, h), dtype=torch.float32))) 227 | return ret_lists, w, h 228 | 229 | 230 | def _extract_cross_attention(tokenizer, device, color_map_image, color_context, text_input, neg): 231 | # Process color map image and context 232 | seperated_word_contexts, width, height = _image_context_seperator( 233 | color_map_image, color_context, tokenizer, neg 234 | ) 235 | 236 | # Compute cross-attention weights 237 | cross_attention_weight_1 = _tokens_img_attention_weight( 238 | seperated_word_contexts, text_input, ratio=1, original_shape=True 239 | ).to(device) 240 | cross_attention_weight_8 = _tokens_img_attention_weight( 241 | seperated_word_contexts, text_input, ratio=8 242 | ).to(device) 243 | cross_attention_weight_16 = _tokens_img_attention_weight( 244 | seperated_word_contexts, text_input, ratio=16 245 | ).to(device) 246 | cross_attention_weight_32 = _tokens_img_attention_weight( 247 | seperated_word_contexts, text_input, ratio=32 248 | ).to(device) 249 | cross_attention_weight_64 = _tokens_img_attention_weight( 250 | seperated_word_contexts, text_input, ratio=64 251 | ).to(device) 252 | 253 | hidden_states = { 254 | "CA_WEIGHT_ORIG": cross_attention_weight_1, # 768 x 768 255 | "CA_WEIGHT_9216": cross_attention_weight_8, # 96 x 96 256 | "CA_WEIGHT_2304": cross_attention_weight_16, # 48 x 48 257 | "CA_WEIGHT_576": cross_attention_weight_32, # 24 x 24 258 | "CA_WEIGHT_144": cross_attention_weight_64, # 12 x 12 259 | } 260 | 261 | uncond_hidden_states = { 262 | "CA_WEIGHT_ORIG": 0, 263 | "CA_WEIGHT_9216": 0, 264 | "CA_WEIGHT_2304": 0, 265 | "CA_WEIGHT_576": 0, 266 | "CA_WEIGHT_144": 0, 267 | } 268 | 269 | return hidden_states, uncond_hidden_states 270 | -------------------------------------------------------------------------------- /pics/2_subjects.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/2_subjects.jpg -------------------------------------------------------------------------------- /pics/3_subjects.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/3_subjects.jpg -------------------------------------------------------------------------------- /pics/4_subjects.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/4_subjects.jpg -------------------------------------------------------------------------------- /pics/arch.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/arch.jpg -------------------------------------------------------------------------------- /pics/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/architecture.png -------------------------------------------------------------------------------- /pics/layout_example.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/layout_example.png -------------------------------------------------------------------------------- /pics/more.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/more.png -------------------------------------------------------------------------------- /pics/pipeline.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/pipeline.gif -------------------------------------------------------------------------------- /pics/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1 2 | torchvision==0.14.1 3 | omegaconf==2.2.3 4 | opencv-python 5 | imageio==2.9.0 6 | transformers==4.26.1 7 | diffusers==0.13.1 8 | accelerate==0.20.0 9 | scipy==1.9.1 10 | hydra-core==1.2.0 11 | tqdm 12 | gradio==3.23.0 13 | pillow 14 | packaging 15 | numpy -------------------------------------------------------------------------------- /train_cones2.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2022 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | 16 | import argparse 17 | import itertools 18 | import logging 19 | import math 20 | import os 21 | from pathlib import Path 22 | 23 | import accelerate 24 | import torch 25 | import torch.nn.functional as F 26 | import torch.utils.checkpoint 27 | import transformers 28 | from accelerate import Accelerator 29 | from accelerate.logging import get_logger 30 | from accelerate.utils import ProjectConfiguration, set_seed 31 | from packaging import version 32 | from PIL import Image 33 | from torch.utils.data import Dataset 34 | from torchvision import transforms 35 | from tqdm.auto import tqdm 36 | from transformers import AutoTokenizer, PretrainedConfig 37 | 38 | import diffusers 39 | from diffusers import ( 40 | AutoencoderKL, 41 | DDPMScheduler, 42 | DiffusionPipeline, 43 | UNet2DConditionModel, 44 | StableDiffusionPipeline, 45 | ) 46 | from diffusers.optimization import get_scheduler 47 | from diffusers.utils.import_utils import is_xformers_available 48 | 49 | 50 | logger = get_logger(__name__) 51 | 52 | 53 | PROMPT_TEMPLETE = [ 54 | "a photo of a {}", 55 | "a rendering of a {}", 56 | "a cropped photo of the {}", 57 | "the photo of a {}", 58 | "a photo of a clean {}", 59 | "a photo of a dirty {}", 60 | "a dark photo of the {}", 61 | "a photo of my {}", 62 | "a photo of the cool {}", 63 | "a close-up photo of a {}", 64 | "a bright photo of the {}", 65 | "a cropped photo of a {}", 66 | "a photo of the {}", 67 | "a good photo of the {}", 68 | "a photo of one {}", 69 | "a close-up photo of the {}", 70 | "a rendition of the {}", 71 | "a photo of the clean {}", 72 | "a rendition of a {}", 73 | "a photo of a nice {}", 74 | "a good photo of a {}", 75 | "a photo of the nice {}", 76 | "a photo of the small {}", 77 | "a photo of the weird {}", 78 | "a photo of the large {}", 79 | "a photo of a cool {}", 80 | "a photo of a small {}", 81 | ] 82 | 83 | 84 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 85 | text_encoder_config = PretrainedConfig.from_pretrained( 86 | pretrained_model_name_or_path, 87 | subfolder="text_encoder", 88 | revision=revision, 89 | ) 90 | model_class = text_encoder_config.architectures[0] 91 | 92 | if model_class == "CLIPTextModel": 93 | from transformers import CLIPTextModel 94 | 95 | return CLIPTextModel 96 | elif model_class == "RobertaSeriesModelWithTransformation": 97 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 98 | 99 | return RobertaSeriesModelWithTransformation 100 | elif model_class == "T5EncoderModel": 101 | from transformers import T5EncoderModel 102 | 103 | return T5EncoderModel 104 | else: 105 | raise ValueError(f"{model_class} is not supported.") 106 | 107 | 108 | def parse_args(input_args=None): 109 | parser = argparse.ArgumentParser(description="Simple example of a script for training Cones 2.") 110 | parser.add_argument( 111 | "--pretrained_model_name_or_path", 112 | type=str, 113 | default=None, 114 | required=True, 115 | help="Path to pretrained model or model identifier from huggingface.co/models.", 116 | ) 117 | parser.add_argument( 118 | "--revision", 119 | type=str, 120 | default=None, 121 | required=False, 122 | help=( 123 | "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" 124 | " float32 precision." 
125 | ), 126 | ) 127 | parser.add_argument( 128 | "--tokenizer_name", 129 | type=str, 130 | default=None, 131 | help="Pretrained tokenizer name or path if not the same as model_name", 132 | ) 133 | parser.add_argument( 134 | "--instance_data_dir", 135 | type=str, 136 | default=None, 137 | required=True, 138 | help="A folder containing the training data of instance images.", 139 | ) 140 | parser.add_argument( 141 | "--instance_prompt", 142 | type=str, 143 | default=None, 144 | required=True, 145 | help="The prompt with identifier specifying the instance", 146 | ) 147 | parser.add_argument( 148 | "--token_num", 149 | type=int, 150 | default=1, 151 | help="Number of updates steps to accumulate before performing a backward/update pass.", 152 | ) 153 | parser.add_argument( 154 | "--output_dir", 155 | type=str, 156 | default="cones2-model", 157 | help="The output directory where the model predictions and checkpoints will be written.", 158 | ) 159 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 160 | parser.add_argument( 161 | "--resolution", 162 | type=int, 163 | default=768, 164 | help=( 165 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 166 | " resolution" 167 | ), 168 | ) 169 | parser.add_argument( 170 | "--center_crop", 171 | default=False, 172 | action="store_true", 173 | help=( 174 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 175 | " cropped. The images will be resized to the resolution first before cropping." 176 | ), 177 | ) 178 | parser.add_argument( 179 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." 180 | ) 181 | parser.add_argument( 182 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 183 | ) 184 | parser.add_argument("--num_train_epochs", type=int, default=1) 185 | parser.add_argument( 186 | "--max_train_steps", 187 | type=int, 188 | default=None, 189 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 190 | ) 191 | parser.add_argument( 192 | "--checkpointing_steps", 193 | type=int, 194 | default=400, 195 | help=( 196 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via" 197 | " `--resume_from_checkpoint`. In the case that the checkpoint is better than the final trained model, the" 198 | " checkpoint can also be used for inference. Using a checkpoint for inference requires separate loading of" 199 | " the original pipeline and the individual checkpointed model components." 200 | ), 201 | ) 202 | parser.add_argument( 203 | "--checkpoints_total_limit", 204 | type=int, 205 | default=None, 206 | help=( 207 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 208 | ), 209 | ) 210 | parser.add_argument( 211 | "--resume_from_checkpoint", 212 | type=str, 213 | default=None, 214 | help=( 215 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 216 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 
217 | ), 218 | ) 219 | parser.add_argument( 220 | "--gradient_accumulation_steps", 221 | type=int, 222 | default=1, 223 | help="Number of updates steps to accumulate before performing a backward/update pass.", 224 | ) 225 | parser.add_argument( 226 | "--gradient_checkpointing", 227 | action="store_true", 228 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 229 | ) 230 | parser.add_argument( 231 | "--learning_rate", 232 | type=float, 233 | default=5e-6, 234 | help="Initial learning rate (after the potential warmup period) to use.", 235 | ) 236 | parser.add_argument( 237 | "--loss_rate_first", 238 | type=float, 239 | default=5e-3, 240 | help="loss_rate", 241 | ) 242 | parser.add_argument( 243 | "--loss_rate_second", 244 | type=float, 245 | default=5e-4, 246 | help="loss_rate_second", 247 | ) 248 | parser.add_argument( 249 | "--scale_lr", 250 | action="store_true", 251 | default=False, 252 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 253 | ) 254 | parser.add_argument( 255 | "--lr_scheduler", 256 | type=str, 257 | default="constant", 258 | help=( 259 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 260 | ' "constant", "constant_with_warmup"]' 261 | ), 262 | ) 263 | parser.add_argument( 264 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 265 | ) 266 | parser.add_argument( 267 | "--lr_num_cycles", 268 | type=int, 269 | default=1, 270 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 271 | ) 272 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 273 | parser.add_argument( 274 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 275 | ) 276 | parser.add_argument( 277 | "--dataloader_num_workers", 278 | type=int, 279 | default=0, 280 | help=( 281 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 282 | ), 283 | ) 284 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 285 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 286 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 287 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 288 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 289 | parser.add_argument( 290 | "--logging_dir", 291 | type=str, 292 | default="logs", 293 | help=( 294 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 295 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 296 | ), 297 | ) 298 | parser.add_argument( 299 | "--allow_tf32", 300 | action="store_true", 301 | help=( 302 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 303 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 304 | ), 305 | ) 306 | parser.add_argument( 307 | "--report_to", 308 | type=str, 309 | default="tensorboard", 310 | help=( 311 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 312 | ' (default), `"wandb"` and `"comet_ml"`. 
Use `"all"` to report to all integrations.' 313 | ), 314 | ) 315 | parser.add_argument( 316 | "--mixed_precision", 317 | type=str, 318 | default=None, 319 | choices=["no", "fp16", "bf16"], 320 | help=( 321 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 322 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 323 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 324 | ), 325 | ) 326 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 327 | parser.add_argument( 328 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 329 | ) 330 | parser.add_argument( 331 | "--set_grads_to_none", 332 | action="store_true", 333 | help=( 334 | "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" 335 | " behaviors, so disable this argument if it causes any problems. More info:" 336 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" 337 | ), 338 | ) 339 | 340 | if input_args is not None: 341 | args = parser.parse_args(input_args) 342 | else: 343 | args = parser.parse_args() 344 | 345 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 346 | if env_local_rank != -1 and env_local_rank != args.local_rank: 347 | args.local_rank = env_local_rank 348 | 349 | return args 350 | 351 | 352 | class ImagePromptDataset(Dataset): 353 | """ 354 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model. 355 | It pre-processes the images and the tokenizes prompts. 356 | """ 357 | 358 | def __init__( 359 | self, 360 | instance_data_root, 361 | instance_prompt, 362 | tokenizer, 363 | size=768, 364 | center_crop=False, 365 | ): 366 | self.size = size 367 | self.center_crop = center_crop 368 | self.tokenizer = tokenizer 369 | 370 | self.instance_data_root = Path(instance_data_root) 371 | if not self.instance_data_root.exists(): 372 | raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.") 373 | 374 | self.instance_images_path = list(Path(instance_data_root).iterdir()) 375 | self.num_instance_images = len(self.instance_images_path) 376 | self.instance_prompt = instance_prompt 377 | self._length = self.num_instance_images 378 | 379 | self.image_transforms = transforms.Compose( 380 | [ 381 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 382 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 383 | transforms.ToTensor(), 384 | transforms.Normalize([0.5], [0.5]), 385 | ] 386 | ) 387 | 388 | def __len__(self): 389 | return self._length 390 | 391 | def __getitem__(self, index): 392 | example = {} 393 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) 394 | if not instance_image.mode == "RGB": 395 | instance_image = instance_image.convert("RGB") 396 | example["instance_images"] = self.image_transforms(instance_image) 397 | example["instance_prompt_ids"] = self.tokenizer( 398 | self.instance_prompt, 399 | truncation=True, 400 | padding="max_length", 401 | max_length=self.tokenizer.model_max_length, 402 | return_tensors="pt", 403 | ).input_ids 404 | 405 | return example 406 | 407 | 408 | def collate_fn(examples): 409 | input_ids = [example["instance_prompt_ids"] for example in examples] 410 | pixel_values = 
[example["instance_images"] for example in examples] 411 | 412 | pixel_values = torch.stack(pixel_values) 413 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 414 | 415 | input_ids = torch.cat(input_ids, dim=0) 416 | 417 | batch = {"input_ids": input_ids, "pixel_values": pixel_values} 418 | return batch 419 | 420 | 421 | class PromptDataset(Dataset): 422 | """A simple dataset to prepare the prompts to generate class images on multiple GPUs.""" 423 | 424 | def __init__(self, prompt, num_samples): 425 | self.prompt = prompt 426 | self.num_samples = num_samples 427 | 428 | def __len__(self): 429 | return self.num_samples 430 | 431 | def __getitem__(self, index): 432 | example = {"prompt": self.prompt, "index": index} 433 | return example 434 | 435 | 436 | def main(args): 437 | logging_dir = Path(args.output_dir, args.logging_dir) 438 | 439 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 440 | accelerator = Accelerator( 441 | gradient_accumulation_steps=args.gradient_accumulation_steps, 442 | mixed_precision=args.mixed_precision, 443 | log_with=args.report_to, 444 | project_config=accelerator_project_config, 445 | ) 446 | 447 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 448 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 449 | if args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 450 | raise ValueError( 451 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 452 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 453 | ) 454 | 455 | # Make one log on every process with the configuration for debugging. 456 | logging.basicConfig( 457 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 458 | datefmt="%m/%d/%Y %H:%M:%S", 459 | level=logging.INFO, 460 | ) 461 | logger.info(accelerator.state, main_process_only=False) 462 | if accelerator.is_local_main_process: 463 | transformers.utils.logging.set_verbosity_warning() 464 | diffusers.utils.logging.set_verbosity_info() 465 | else: 466 | transformers.utils.logging.set_verbosity_error() 467 | diffusers.utils.logging.set_verbosity_error() 468 | 469 | # If passed along, set the training seed now. 
470 | if args.seed is not None: 471 | set_seed(args.seed) 472 | 473 | # Handle the repository creation 474 | if accelerator.is_main_process: 475 | if args.output_dir is not None: 476 | os.makedirs(args.output_dir, exist_ok=True) 477 | 478 | # Load the tokenizer 479 | if args.tokenizer_name: 480 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) 481 | elif args.pretrained_model_name_or_path: 482 | tokenizer = AutoTokenizer.from_pretrained( 483 | args.pretrained_model_name_or_path, 484 | subfolder="tokenizer", 485 | revision=args.revision, 486 | use_fast=False, 487 | ) 488 | 489 | # import correct text encoder class 490 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 491 | 492 | # Load scheduler and models 493 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 494 | text_encoder = text_encoder_cls.from_pretrained( 495 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 496 | ) 497 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 498 | unet = UNet2DConditionModel.from_pretrained( 499 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 500 | ) 501 | 502 | # `accelerate` 0.16.0 will have better support for customized saving 503 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"): 504 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format 505 | def save_model_hook(models, weights, output_dir): 506 | for model in models: 507 | sub_dir = "unet" if type(model) == type(unet) else "text_encoder" 508 | model.save_pretrained(os.path.join(output_dir, sub_dir)) 509 | 510 | # make sure to pop weight so that corresponding model is not saved again 511 | weights.pop() 512 | 513 | def load_model_hook(models, input_dir): 514 | while len(models) > 0: 515 | # pop models so that they are not loaded again 516 | model = models.pop() 517 | 518 | if type(model) == type(text_encoder): 519 | # load transformers style into model 520 | load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") 521 | model.config = load_model.config 522 | else: 523 | # load diffusers style into model 524 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") 525 | model.register_to_config(**load_model.config) 526 | 527 | model.load_state_dict(load_model.state_dict()) 528 | del load_model 529 | 530 | accelerator.register_save_state_pre_hook(save_model_hook) 531 | accelerator.register_load_state_pre_hook(load_model_hook) 532 | 533 | vae.requires_grad_(False) 534 | unet.requires_grad_(False) 535 | 536 | if args.enable_xformers_memory_efficient_attention: 537 | if is_xformers_available(): 538 | import xformers 539 | xformers_version = version.parse(xformers.__version__) 540 | if xformers_version == version.parse("0.0.16"): 541 | logger.warn( 542 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training," 543 | " please update xFormers to at least 0.0.17." 544 | " See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 545 | ) 546 | unet.enable_xformers_memory_efficient_attention() 547 | else: 548 | raise ValueError("xformers is not available. 
Make sure it is installed correctly") 549 | 550 | if args.gradient_checkpointing: 551 | unet.enable_gradient_checkpointing() 552 | # if args.train_text_encoder: 553 | # text_encoder.gradient_checkpointing_enable() 554 | text_encoder.gradient_checkpointing_enable() 555 | 556 | # Check that all trainable models are in full precision 557 | low_precision_error_string = ( 558 | "Please make sure to always have all model weights in full float32 precision when starting training - even if" 559 | " doing mixed precision training. copy of the weights should still be float32." 560 | ) 561 | 562 | if accelerator.unwrap_model(text_encoder).dtype != torch.float32: 563 | raise ValueError( 564 | f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." 565 | f" {low_precision_error_string}" 566 | ) 567 | 568 | # Enable TF32 for faster training on Ampere GPUs, 569 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 570 | if args.allow_tf32: 571 | torch.backends.cuda.matmul.allow_tf32 = True 572 | 573 | if args.scale_lr: 574 | args.learning_rate = ( 575 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 576 | ) 577 | 578 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 579 | if args.use_8bit_adam: 580 | try: 581 | import bitsandbytes as bnb 582 | except ImportError: 583 | raise ImportError( 584 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 585 | ) 586 | 587 | optimizer_class = bnb.optim.AdamW8bit 588 | else: 589 | optimizer_class = torch.optim.AdamW 590 | 591 | # Optimizer creation 592 | params_to_optimize = (itertools.chain(text_encoder.parameters())) 593 | optimizer = optimizer_class( 594 | params_to_optimize, 595 | lr=args.learning_rate, 596 | betas=(args.adam_beta1, args.adam_beta2), 597 | weight_decay=args.adam_weight_decay, 598 | eps=args.adam_epsilon, 599 | ) 600 | 601 | # Dataset and DataLoaders creation: 602 | train_dataset = ImagePromptDataset( 603 | instance_data_root=args.instance_data_dir, 604 | instance_prompt=args.instance_prompt, 605 | tokenizer=tokenizer, 606 | size=args.resolution, 607 | center_crop=args.center_crop, 608 | ) 609 | 610 | train_dataloader = torch.utils.data.DataLoader( 611 | train_dataset, 612 | batch_size=args.train_batch_size, 613 | shuffle=True, 614 | collate_fn=lambda examples: collate_fn(examples), 615 | num_workers=args.dataloader_num_workers, 616 | ) 617 | 618 | # Scheduler and math around the number of training steps. 
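# (If --max_train_steps is not given, it is derived from --num_train_epochs here; it is
# recomputed after accelerator.prepare() below, since the dataloader length may change.)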
619 | overrode_max_train_steps = False 620 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 621 | if args.max_train_steps is None: 622 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 623 | overrode_max_train_steps = True 624 | 625 | lr_scheduler = get_scheduler( 626 | args.lr_scheduler, 627 | optimizer=optimizer, 628 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 629 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 630 | num_cycles=args.lr_num_cycles, 631 | power=args.lr_power, 632 | ) 633 | 634 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 635 | text_encoder, optimizer, train_dataloader, lr_scheduler 636 | ) 637 | 638 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 639 | # as these models are only used for inference, keeping weights in full precision is not required. 640 | weight_dtype = torch.float32 641 | if accelerator.mixed_precision == "fp16": 642 | weight_dtype = torch.float16 643 | elif accelerator.mixed_precision == "bf16": 644 | weight_dtype = torch.bfloat16 645 | 646 | # Move vae and unet to device and cast to weight_dtype 647 | vae.to(accelerator.device, dtype=weight_dtype) 648 | unet.to(accelerator.device, dtype=weight_dtype) 649 | 650 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 651 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 652 | if overrode_max_train_steps: 653 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 654 | 655 | # Afterwards we recalculate our number of training epochs 656 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 657 | 658 | # We need to initialize the trackers we use, and also store our configuration. 659 | # The trackers initializes automatically on the main process. 660 | if accelerator.is_main_process: 661 | accelerator.init_trackers("cones2", config=vars(args)) 662 | 663 | # Train! 664 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 665 | 666 | logger.info("***** Running training *****") 667 | logger.info(f" Num examples = {len(train_dataset)}") 668 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 669 | logger.info(f" Num Epochs = {args.num_train_epochs}") 670 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 671 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 672 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 673 | logger.info(f" Total optimization steps = {args.max_train_steps}") 674 | global_step = 0 675 | first_epoch = 0 676 | 677 | # Potentially load in the weights and states from a previous save 678 | if args.resume_from_checkpoint: 679 | if args.resume_from_checkpoint != "latest": 680 | path = os.path.basename(args.resume_from_checkpoint) 681 | else: 682 | # Get the mos recent checkpoint 683 | dirs = os.listdir(args.output_dir) 684 | dirs = [d for d in dirs if d.startswith("checkpoint")] 685 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 686 | path = dirs[-1] if len(dirs) > 0 else None 687 | 688 | if path is None: 689 | accelerator.print( 690 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 
691 | ) 692 | args.resume_from_checkpoint = None 693 | else: 694 | accelerator.print(f"Resuming from checkpoint {path}") 695 | accelerator.load_state(os.path.join(args.output_dir, path)) 696 | global_step = int(path.split("-")[1]) 697 | 698 | resume_global_step = global_step * args.gradient_accumulation_steps 699 | first_epoch = global_step // num_update_steps_per_epoch 700 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 701 | 702 | # Only show the progress bar once on each machine. 703 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 704 | progress_bar.set_description("Steps") 705 | 706 | num_iter = 0 707 | 708 | for epoch in range(first_epoch, args.num_train_epochs): 709 | text_encoder.train() 710 | for step, batch in enumerate(train_dataloader): 711 | # Skip steps until we reach the resumed step 712 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 713 | if step % args.gradient_accumulation_steps == 0: 714 | progress_bar.update(1) 715 | continue 716 | 717 | with accelerator.accumulate(text_encoder): 718 | # Convert images to latent space 719 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 720 | latents = latents * vae.config.scaling_factor 721 | 722 | # Sample noise that we'll add to the latents 723 | noise = torch.randn_like(latents) 724 | bsz = latents.shape[0] 725 | 726 | # Sample a random timestep for each image 727 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 728 | timesteps = timesteps.long() 729 | 730 | # Add noise to the latents according to the noise magnitude at each timestep 731 | # (this is the forward diffusion process) 732 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 733 | 734 | # Get the text embedding for conditioning 735 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 736 | if num_iter == 0: 737 | target_embed = encoder_hidden_states.detach() 738 | 739 | l_first_part = torch.norm(torch.squeeze(target_embed)[:1] - torch.squeeze(encoder_hidden_states)[:1], 2) 740 | l_second_part = torch.norm(torch.squeeze(target_embed)[1 + args.token_num:] 741 | - torch.squeeze(encoder_hidden_states)[1 + args.token_num:], 2) 742 | loss_embedding = args.loss_rate_first * l_first_part + args.loss_rate_second * l_second_part 743 | 744 | # Predict the noise residual 745 | model_pred = unet(noisy_latents.float(), timesteps, encoder_hidden_states.float()).sample 746 | 747 | # Get the target for loss depending on the prediction type 748 | if noise_scheduler.config.prediction_type == "epsilon": 749 | target = noise 750 | elif noise_scheduler.config.prediction_type == "v_prediction": 751 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 752 | else: 753 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 754 | 755 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 756 | loss = loss + loss_embedding 757 | 758 | accelerator.backward(loss) 759 | if accelerator.sync_gradients: 760 | params_to_clip = (itertools.chain(text_encoder.parameters())) 761 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 762 | optimizer.step() 763 | lr_scheduler.step() 764 | 765 | optimizer.zero_grad(set_to_none=args.set_grads_to_none) 766 | num_iter += 1 767 | 768 | # Checks if the accelerator has performed an optimization step behind the scenes 769 | 
if accelerator.sync_gradients: 770 | progress_bar.update(1) 771 | global_step += 1 772 | 773 | if global_step % args.checkpointing_steps == 0: 774 | if accelerator.is_main_process: 775 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 776 | accelerator.save_state(save_path) 777 | logger.info(f"Saved state to {save_path}") 778 | 779 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 780 | progress_bar.set_postfix(**logs) 781 | accelerator.log(logs, step=global_step) 782 | 783 | if global_step >= args.max_train_steps: 784 | break 785 | 786 | # Create the pipeline using the trained modules and save it. 787 | accelerator.wait_for_everyone() 788 | if accelerator.is_main_process: 789 | pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path).to("cuda") 790 | 791 | text_inputs_origin = pipe.tokenizer( 792 | args.instance_prompt, 793 | padding="max_length", 794 | max_length=tokenizer.model_max_length, 795 | truncation=True, 796 | return_tensors="pt", 797 | ) 798 | text_inputs_origin_ids = text_inputs_origin.input_ids 799 | index = text_inputs_origin_ids[0][1] 800 | prompt_embeds_new = 0 801 | prompt_embeds_origin = 0 802 | for template in PROMPT_TEMPLETE: 803 | text_inputs = pipe.tokenizer( 804 | template.format(args.instance_prompt), 805 | padding="max_length", 806 | max_length=tokenizer.model_max_length, 807 | truncation=True, 808 | return_tensors="pt", 809 | ) 810 | text_input_ids = text_inputs.input_ids 811 | index_template = int(torch.where(text_input_ids[0] == index)[0][0]) 812 | prompt_embeds_now = text_encoder(text_input_ids.to("cuda"), attention_mask=None) 813 | prompt_embeds_now = prompt_embeds_now[0][0][index_template: index_template + args.token_num] 814 | prompt_embeds = pipe.text_encoder(text_input_ids.to("cuda"), attention_mask=None) 815 | prompt_embeds = prompt_embeds[0][0][index_template: index_template + args.token_num] 816 | prompt_embeds_new += prompt_embeds_now 817 | prompt_embeds_origin += prompt_embeds 818 | 819 | residual_save_path = args.output_dir + '/residual.pt' 820 | torch.save((prompt_embeds_new - prompt_embeds_origin) / len(PROMPT_TEMPLETE), residual_save_path) 821 | 822 | pipeline = DiffusionPipeline.from_pretrained( 823 | args.pretrained_model_name_or_path, 824 | unet=accelerator.unwrap_model(unet), 825 | text_encoder=accelerator.unwrap_model(text_encoder), 826 | revision=args.revision, 827 | ) 828 | pipeline.save_pretrained(args.output_dir) 829 | accelerator.print(f"The customized residual embedding saved in {residual_save_path}") 830 | 831 | accelerator.end_training() 832 | 833 | 834 | if __name__ == "__main__": 835 | args = parse_args() 836 | main(args) 837 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def downsampling(img: torch.tensor, w: int, h: int) -> torch.tensor: 8 | return F.interpolate( 9 | img.unsqueeze(0).unsqueeze(1), 10 | size=(w, h), 11 | mode="bilinear", 12 | align_corners=True, 13 | ).squeeze() 14 | 15 | 16 | def image_grid(images, rows=2, cols=2): 17 | w, h = images[0].size 18 | grid = Image.new('RGB', size=(cols * w, rows * h)) 19 | 20 | for i, img in enumerate(images): 21 | grid.paste(img, box=(i % cols * w, i // cols * h)) 22 | return grid 23 | 24 | 25 | def latents_to_images(vae, latents, scale_factor=0.18215): 26 | """ 27 | Decode latents to PIL 
images. 28 | """ 29 | scaled_latents = 1.0 / scale_factor * latents.clone() 30 | images = vae.decode(scaled_latents).sample 31 | images = (images / 2 + 0.5).clamp(0, 1) 32 | images = images.detach().cpu().permute(0, 2, 3, 1).numpy() 33 | 34 | if images.ndim == 3: 35 | images = images[None, ...] 36 | images = (images * 255).round().astype("uint8") 37 | pil_images = [Image.fromarray(image) for image in images] 38 | 39 | return pil_images 40 | --------------------------------------------------------------------------------