├── LICENSE
├── README.md
├── attention_control.py
├── guidance_config_example.json
├── inference.py
├── layout_guidance.py
├── pics
│   ├── 2_subjects.jpg
│   ├── 3_subjects.jpg
│   ├── 4_subjects.jpg
│   ├── arch.jpg
│   ├── architecture.png
│   ├── layout_example.png
│   ├── more.png
│   ├── pipeline.gif
│   └── teaser.png
├── requirements.txt
├── train_cones2.py
└── utils.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 DAMO Vision Intelligence Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cones 2
2 |
3 | Official repo for [Cones 2: Customizable Image Synthesis with Multiple Subjects](https://arxiv.org/abs/2305.19327) | [Project Page](https://cones-page.github.io/).
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | [Cones 2](https://cones-page.github.io/) allows you to represent a specific subject as a **residual embedding** by
12 | fine-tuning the text encoder of a pre-trained text-to-image diffusion model, such as
13 | [Stable Diffusion](https://github.com/CompVis/stable-diffusion). After tuning, we only need to save the residual between the tuned text encoder and the frozen
14 | one. Thus, the storage space required for each additional subject is only **5 KB**. This step only takes about 20~30 minutes on a single 80 GB A100
15 | GPU for each subject.
16 |
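For intuition, applying a saved residual at inference time is just an addition to the prompt's token embedding. Below is a minimal sketch of that step (the `(1, 77, 1024)` embedding shape of Stable Diffusion v2.1 and the file path are assumptions; the actual logic lives in [layout_guidance.py](layout_guidance.py)):

```python
import torch

def apply_residual(cond_embeddings, token_index, residual_path):
    """Add a saved residual embedding to one token of the prompt embedding.

    cond_embeddings: text-encoder output, shape (1, 77, 1024) for SD v2.1.
    token_index: position of the subject token in the prompt.
    residual_path: path to a ~5 KB residual, e.g. "residuals/flower.pt" (hypothetical).
    """
    residual = torch.load(residual_path)
    cond_embeddings[0][token_index] += residual.reshape(1024)
    return cond_embeddings
```
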
17 | When sampling, our **layout guidance sampling** method further allows you to employ an easy-to-obtain layout as guidance for arranging
18 | multiple subjects, as shown in the following figure.
19 |
20 |
21 |
22 |
23 |
24 |
25 | ## Results
26 |
27 | All results are synthesized by the pre-trained [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1)
28 | model and our customized residual embeddings. We show diverse results on various categories of images, including
29 | scenes, pets, personal toys, humans, _etc._ For more results, please refer to our [paper](https://arxiv.org/abs/2305.19327) or [website](https://cones-page.github.io/).
30 |
31 | ### Two-Subject Results
32 | ![Two-subject results](pics/2_subjects.jpg)
33 |
34 | ### Three-Subject Results
35 | ![Three-subject results](pics/3_subjects.jpg)
36 |
37 | ### Four-Subject Results
38 | ![Four-subject results](pics/4_subjects.jpg)
39 |
40 | ### More Results
41 | ![More results](pics/more.png)
42 |
43 | ## Method
44 |
45 | 
46 |
47 | (a) Given few-shot images of the customized subject, we fine-tune the text encoder to learn a residual embedding on top
48 | of the base embedding of the raw subject. (b) Based on the residual embeddings, we then employ the layout as
49 | spatial guidance for subject arrangement, injecting it into the cross-attention maps. This allows us to strengthen the signal of the target
50 | subjects and weaken the signal of irrelevant subjects. For more details, please refer to our [paper](https://arxiv.org/abs/2305.19327).
51 |
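Conceptually, the guidance adds a bias to the cross-attention scores before the softmax: positive inside the region assigned to a subject and strongly negative elsewhere. The sketch below illustrates the idea only (names and shapes are illustrative, and the attention scale factor is omitted); the full implementation is in [attention_control.py](attention_control.py) and [layout_guidance.py](layout_guidance.py):

```python
import math
import torch

def biased_attention(scores, weight_map, sigma, guidance_weight=0.08):
    """Bias cross-attention with a layout weight map.

    scores: (heads, pixels, tokens) raw query-key logits.
    weight_map: (pixels, tokens) layout weights, positive where a subject
        should appear and strongly negative (e.g. -1e8) where it should not.
    sigma: current noise level of the sampler.
    """
    bias = guidance_weight * weight_map * math.log(1 + sigma ** 2) * scores.std()
    return (scores + bias).softmax(dim=-1)
```
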
52 | ## Getting Started
53 |
54 | ### Installing the dependencies
55 |
56 | The implementation of Cones 2 is based entirely on [diffusers](https://github.com/huggingface/diffusers/tree/main).
57 | Before running our code, make sure to install the library's training dependencies. To do this, execute the following
58 | steps in a new virtual environment:
59 |
60 | ```bash
61 | git clone https://github.com/damo-vilab/Cones-V2.git
62 | cd Cones-V2
63 | pip install -r requirements.txt
64 | ```
65 |
66 | Then initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:
67 |
68 | ```bash
69 | accelerate config
70 | ```
71 |
72 | Or, for a default accelerate configuration without answering questions about your environment:
73 |
74 | ```bash
75 | accelerate config default
76 | ```
77 |
78 | ## Training (Flower example)
79 |
80 | First, download the dataset from
81 | [here](https://modelscope.cn/api/v1/datasets/zyf619/cones2_residual/repo?Revision=master&FilePath=data.zip)
82 | and unzip it to `./data`. A few images of a flower (now in `./data/flower`) are then used to learn its
83 | customized residual embedding.
84 |
85 | ```bash
86 | export MODEL_NAME='path-to-stable-diffusion-v2-1'
87 | export INSTANCE_DIR="./data/flower"
88 | export OUTPUT_DIR="path-to-save-model"
89 | accelerate launch train_cones2.py \
90 | --pretrained_model_name_or_path=$MODEL_NAME \
91 | --instance_data_dir=$INSTANCE_DIR \
92 | --instance_prompt="flower" \
93 | --token_num=1 \
94 | --output_dir=$OUTPUT_DIR \
95 | --resolution=768 \
96 | --train_batch_size=1 \
97 | --gradient_accumulation_steps=1 \
98 | --learning_rate=5e-6 \
99 | --lr_scheduler="constant" \
100 | --lr_warmup_steps=0 \
101 | --max_train_steps=4000 \
102 | --loss_rate_first=1e-2 \
103 | --loss_rate_second=1e-3
104 | ```
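
The `--loss_rate_first` and `--loss_rate_second` flags weight two regularization terms that keep the rest of the text embedding close to the frozen encoder's output, so that only the subject token(s) effectively move. A simplified sketch of the combined training loss (variable names are illustrative; see [train_cones2.py](train_cones2.py) for the real code):

```python
import torch
import torch.nn.functional as F

def cones2_loss(model_pred, target, new_embed, frozen_embed,
                token_num=1, loss_rate_first=1e-2, loss_rate_second=1e-3):
    """Denoising loss plus embedding-preservation terms.

    new_embed / frozen_embed: (1, 77, dim) text embeddings of the instance
    prompt from the tuned and the frozen text encoder, respectively.
    """
    diffusion_loss = F.mse_loss(model_pred.float(), target.float())
    new, old = new_embed.squeeze(0), frozen_embed.squeeze(0)
    # Keep the first token and everything after the subject tokens in place.
    l_first = torch.norm(old[:1] - new[:1], 2)
    l_rest = torch.norm(old[1 + token_num:] - new[1 + token_num:], 2)
    return diffusion_loss + loss_rate_first * l_first + loss_rate_second * l_rest
```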
105 |
106 | ## Inference
107 |
108 | Once you have trained residual embeddings for several different subjects using the above command, you can run our
109 | **layout guidance sampling** method with [inference.py](inference.py). We provide several pre-trained
110 | residual embeddings for quick validation.
111 |
| residual embedding | raw token | download |
| --- | --- | --- |
| barn | barn | barn.pt |
| white dog | dog | dog.pt |
| flower | flower | flower.pt |
| lake | lake | lake.pt |
| mug | mug | mug.pt |
| sunglasses | sunglasses | sunglasses.pt |
| bag | bag | bag.pt |
| robot | robot toy | robot_toy.pt |
160 |
161 | Remember to provide a pre-defined layout like [layout_example.png](pics/layout_example.png) and a
162 | [JSON](guidance_config_example.json) file with the details of the inference settings. The
163 | JSON file should include the following information (a minimal example follows the list):
164 |
165 | - "prompt": the text prompt you want to generate.
166 | - "residual_dict": the paths to all the required residual embeddings.
167 | - "color_context": the color information of different regions in the layout and their corresponding subjects, along with
168 | the weight for strengthening the signal of target subject. (default: 2.5).
169 | - "guidance_steps": the number of steps of the layout guidance.
170 | - "guidance_weight": the strength of the layout guidance (default: 0.08, we recommond 0.05 ~ 0.10).
171 | - "weight_negative": the weight for weakening the signal of irrelevant subject.
172 | - "layout": the path to user-defined layout image.
173 | - "subject_list": the list containing all the subjects to be customized and their corresponding positions in the prompt.
174 |
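A minimal configuration in the same format as [guidance_config_example.json](guidance_config_example.json) looks like this (the residual and layout paths are placeholders):

```json
[
  {
    "prompt": "a mug and a dog on the beach",
    "residual_dict": {"mug": "residuals/mug.pt", "dog": "residuals/dog.pt"},
    "color_context": {"255,192,0": ["mug", 2.5], "255,0,0": ["dog", 2.5]},
    "guidance_steps": 50,
    "guidance_weight": 0.08,
    "weight_negative": -1e8,
    "layout": "layouts/layout_example.png",
    "subject_list": [["mug", 2], ["dog", 5]]
  }
]
```
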
175 | Then you can simply run the inference script with:
176 |
177 | ```bash
178 | python inference.py --pretrained_model_name_or_path /path/to/stable-diffusion-2-1 --inference_config guidance_config_example.json
179 | ```
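
The layout itself is just a color-block image: each RGB color listed in "color_context" marks the region where the corresponding subject should appear. A simplified sketch of how a region mask is derived from the layout (adapted from `_image_context_seperator` in [layout_guidance.py](layout_guidance.py); the function name here is illustrative):

```python
import numpy as np
from PIL import Image

def region_mask(layout_path, rgb):
    """Boolean (H, W) mask of pixels whose color exactly matches `rgb`, e.g. (255, 0, 0)."""
    img = np.array(Image.open(layout_path).convert("RGB"))
    return (img == np.array(rgb)).all(axis=-1)
```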
180 |
181 | ## References
182 |
183 | ```BibTeX
184 | @article{liu2023cones,
185 | title={Cones 2: Customizable Image Synthesis with Multiple Subjects},
186 | author={Liu, Zhiheng and Zhang, Yifei and Shen, Yujun and Zheng, Kecheng and Zhu, Kai and Feng, Ruili and Liu, Yu and Zhao, Deli and Zhou, Jingren and Cao, Yang},
187 | journal={arXiv preprint arXiv:2305.19327},
188 | year={2023}
189 | }
190 | ```
191 |
192 | ### Acknowledgements
193 | We thank [Stable Diffusion v2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) and [diffusers](https://github.com/huggingface/diffusers/tree/main) for providing the pre-trained model and an open-source codebase.
194 |
--------------------------------------------------------------------------------
/attention_control.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from diffusers.models.cross_attention import CrossAttention
7 |
8 |
9 | class Cones2AttnProcessor:
10 | def __init__(self):
11 | super().__init__()
12 |
13 | def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
14 | batch_size, sequence_length, _ = hidden_states.shape
15 | query = attn.to_q(hidden_states)
16 | is_dict_format = True
17 | if encoder_hidden_states is not None:
18 | try:
19 | encoder_hidden = encoder_hidden_states["CONDITION_TENSOR"]
20 | except (TypeError, IndexError, KeyError):
21 | encoder_hidden = encoder_hidden_states
22 | is_dict_format = False
23 | if attn.cross_attention_norm:
24 | encoder_hidden = attn.norm_cross(encoder_hidden)
25 | else:
26 | encoder_hidden = hidden_states
27 |
28 | key = attn.to_k(encoder_hidden)
29 | value = attn.to_v(encoder_hidden)
30 |
31 | query = attn.head_to_batch_dim(query)
32 | key = attn.head_to_batch_dim(key)
33 | value = attn.head_to_batch_dim(value)
34 |
35 | attention_scores = torch.matmul(query, key.transpose(-1, -2))
36 | attention_size_of_img = attention_scores.size()[-2]
37 |
38 | if attention_scores.size()[2] == 77:
39 | if is_dict_format:
40 | f = encoder_hidden_states["function"]
41 | try:
42 | w = encoder_hidden_states[f"CA_WEIGHT_{attention_size_of_img}"]
43 | except KeyError:
44 | w = encoder_hidden_states["CA_WEIGHT_ORIG"]
45 | if not isinstance(w, int):
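# Resize the layout weight map (H x W x tokens) to match this layer's attention resolution.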
46 | img_h, img_w, nc = w.shape
47 | ratio = math.sqrt(img_h * img_w / attention_size_of_img)
48 | w = F.interpolate(w.permute(2, 0, 1).unsqueeze(0), scale_factor=1 / ratio, mode="bilinear",
49 | align_corners=True)
50 | w = F.interpolate(w.reshape(1, nc, -1), size=(attention_size_of_img,), mode='nearest').permute(
51 | 2, 1, 0).squeeze()
52 | else:
53 | w = 0
54 | if isinstance(w, int) and w == 0:
55 | sigma = encoder_hidden_states["SIGMA"]
56 | cross_attention_weight = f(w, sigma, attention_scores)
57 | else:
58 | bias = torch.zeros_like(w)
59 | bias[torch.where(w > 0)] = attention_scores.std() * 0
60 | sigma = encoder_hidden_states["SIGMA"]
61 | cross_attention_weight = f(w, sigma, attention_scores)
62 | cross_attention_weight = cross_attention_weight + bias
63 | else:
64 | cross_attention_weight = 0.0
65 | else:
66 | cross_attention_weight = 0.0
67 |
68 | attention_scores = (attention_scores + cross_attention_weight) * attn.scale
69 | attention_probs = attention_scores.softmax(dim=-1)
70 |
71 | hidden_states = torch.matmul(attention_probs, value)
72 | hidden_states = attn.batch_to_head_dim(hidden_states)
73 |
74 | # linear proj
75 | hidden_states = attn.to_out[0](hidden_states)
76 | # dropout
77 | hidden_states = attn.to_out[1](hidden_states)
78 |
79 | return hidden_states
80 |
81 |
82 | def register_attention_control(unet):
83 | attn_procs = {}
84 | for name in unet.attn_processors.keys():
85 | attn_procs[name] = Cones2AttnProcessor()
86 |
87 | unet.set_attn_processor(attn_procs)
88 |
--------------------------------------------------------------------------------
/guidance_config_example.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "prompt":"a mug and a dog on the beach",
4 | "residual_dict": {
5 | "dog":"residuals/dog.pt",
6 | "mug":"residuals/mug.pt",
7 | "flower":"residuals/flower.pt",
8 | "sunglasses":"residuals/sunglasses.pt",
9 | "lake":"residuals/lake.pt",
10 | "barn":"residuals/barn.pt"
11 | },
12 | "color_context":{
13 | "255,192,0":["mug",2.5],
14 | "255,0,0":["dog",2.5]
15 | },
16 | "guidance_steps":50,
17 | "guidance_weight":0.08,
18 | "weight_negative":-1e8,
19 | "layout":"layouts/layout_example.png",
20 | "subject_list":[["mug",2],["dog",5]]
21 | }
22 | ]
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 |
5 | from PIL import Image
6 |
7 | from diffusers import StableDiffusionPipeline
8 |
9 | from layout_guidance import layout_guidance_sampling
10 | from utils import image_grid
11 |
12 |
13 | def get_args():
14 | parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
15 | parser.add_argument(
16 | "--pretrained_model_name_or_path",
17 | type=str,
18 | default=None,
19 | required=True,
20 | help="Path to pretrained model or model identifier from huggingface.co/models.",
21 | )
22 | parser.add_argument(
23 | "--inference_config",
24 | type=str,
25 | default=None,
26 | required=True,
27 | help='Path to a JSON file containing settings for inference, including "prompt", "residual_dict", '
28 | '"color_context", "guidance_steps", "guidance_weight", "weight_negative", "layout", "subject_list".',
29 | )
30 | parser.add_argument(
31 | "--output_dir",
32 | type=str,
33 | default="inference_results",
34 | help="The output directory where the model predictions and checkpoints will be written.",
35 | )
36 | return parser.parse_args()
37 |
38 |
39 | def main(args):
40 | # Initialize pre-trained Stable Diffusion pipeline.
41 | pipeline = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path).to("cuda")
42 |
43 | # Load the settings required for inference from the configuration file.
44 | with open(args.inference_config, "r") as f:
45 | inference_cfg = json.load(f)
46 |
47 | prompt = inference_cfg[0]["prompt"]
48 | residual_dict = inference_cfg[0]["residual_dict"]
49 | subject_list = inference_cfg[0]["subject_list"]
50 | guidance_steps = inference_cfg[0]["guidance_steps"]
51 | guidance_weight = inference_cfg[0]["guidance_weight"]
52 | weight_negative = inference_cfg[0]["weight_negative"]
53 | layout = Image.open(inference_cfg[0]["layout"]).resize((768, 768)).convert("RGB")
54 | color_context = inference_cfg[0]["color_context"]
55 | subject_color_dict = {tuple(map(int, key.split(','))): value for key, value in color_context.items()}
56 |
57 | if args.output_dir is not None:
58 | os.makedirs(args.output_dir, exist_ok=True)
59 | subject_info = '_'.join([s[0] for s in sorted(subject_list)])
60 | prompt_info = '_'.join(prompt.split())
61 | save_dir = os.path.join(args.output_dir, subject_info, prompt_info)
62 | os.makedirs(save_dir, exist_ok=True)
63 |
64 | images = []
65 |
66 | for i in range(4):
67 | image = layout_guidance_sampling(
68 | seed=i,
69 | device="cuda:0",
70 | resolution=768,
71 | pipeline=pipeline,
72 | prompt=prompt,
73 | residual_dict=residual_dict,
74 | subject_list=subject_list,
75 | subject_color_dict=subject_color_dict,
76 | layout=layout,
77 | cfg_scale=7.5,
78 | inference_steps=50,
79 | guidance_steps=guidance_steps,
80 | guidance_weight=guidance_weight,
81 | weight_negative=weight_negative,
82 | )
83 |
84 | image.save(os.path.join(save_dir, f"{i}.png"))
85 | images.append(image)
86 |
87 | all_image = image_grid(images=images, rows=2, cols=2)
88 | all_image.save(os.path.join(save_dir, "all_images.png"))
89 |
90 |
91 | if __name__ == "__main__":
92 | args = get_args()
93 | main(args)
94 |
--------------------------------------------------------------------------------
/layout_guidance.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import torch
4 |
5 | from tqdm.auto import tqdm
6 |
7 | from diffusers import LMSDiscreteScheduler
8 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
9 |
10 | from attention_control import register_attention_control
11 | from utils import latents_to_images, downsampling
12 |
13 |
14 | @torch.no_grad()
15 | def layout_guidance_sampling(seed,
16 | device,
17 | resolution,
18 | pipeline,
19 | prompt="",
20 | residual_dict=None,
21 | subject_list=None,
22 | subject_color_dict=None,
23 | layout=None,
24 | cfg_scale=7.5,
25 | inference_steps=50,
26 | guidance_steps=50,
27 | guidance_weight=0.05,
28 | weight_negative=-1e8):
29 | vae = pipeline.vae
30 | unet = pipeline.unet
31 | text_encoder = pipeline.text_encoder
32 | tokenizer = pipeline.tokenizer
33 | unconditional_input_prompt = ""
34 | scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
35 | scheduler.set_timesteps(inference_steps, device=device)
36 | if guidance_steps > 0:
37 | guidance_steps = min(guidance_steps, inference_steps)
38 | scheduler_guidance = LMSDiscreteScheduler(
39 | beta_start=0.00085,
40 | beta_end=0.012,
41 | beta_schedule="scaled_linear",
42 | num_train_timesteps=1000,
43 | )
44 | scheduler_guidance.set_timesteps(guidance_steps, device=device)
45 |
46 | # Process input prompt text
47 | text_input = tokenizer(
48 | [prompt],
49 | padding="max_length",
50 | max_length=tokenizer.model_max_length,
51 | truncation=True,
52 | return_tensors="pt",
53 | )
54 |
55 | # Edit text embedding conditions with residual token embeddings.
56 | cond_embeddings = text_encoder(text_input.input_ids.to(device))[0]
57 | for name, token in subject_list:
58 | residual_token_embedding = torch.load(residual_dict[name])
59 | cond_embeddings[0][token] += residual_token_embedding.reshape(1024)
60 |
61 | # Process unconditional input "" for classifier-free guidance.
62 | max_length = text_input.input_ids.shape[-1]
63 | uncond_input = tokenizer(
64 | [unconditional_input_prompt],
65 | padding="max_length",
66 | max_length=max_length,
67 | return_tensors="pt",
68 | )
69 | uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
70 |
71 | register_attention_control(unet)
72 |
73 | # Calculate the hidden features for each cross attention layer.
74 | hidden_states, uncond_hidden_states = _extract_cross_attention(tokenizer,
75 | device,
76 | layout,
77 | subject_color_dict,
78 | text_input,
79 | weight_negative)
80 | hidden_states["CONDITION_TENSOR"] = cond_embeddings
81 | uncond_hidden_states["CONDITION_TENSOR"] = uncond_embeddings
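# The guidance function scales the layout weight map by the current noise level
# (via log(1 + sigma^2)) and by the spread of the attention scores, so the layout
# bias is strongest early in sampling and fades as sigma decreases.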
82 | hidden_states["function"] = lambda w, x, qk: (guidance_weight * w * math.log(1 + x ** 2)) * qk.std()
83 | uncond_hidden_states["function"] = lambda w, x, qk: 0.0
84 |
85 | # Sampling the initial latents.
86 | latent_size = (1, unet.in_channels, resolution // 8, resolution // 8)
87 | latents = torch.randn(latent_size, generator=torch.manual_seed(seed))
88 | latents = latents.to(device)
89 | latents = latents * scheduler.init_noise_sigma
90 |
91 | for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
92 | # Improve the harmony of generated images by self-recurrence.
93 | if i < guidance_steps:
94 | loop = 2
95 | else:
96 | loop = 1
97 | for k in range(loop):
98 | if i < guidance_steps:
99 | sigma = scheduler_guidance.sigmas[i]
100 | latent_model_input = scheduler.scale_model_input(latents, t)
101 |
102 | hidden_states.update({
103 | "SIGMA": sigma,
104 | })
105 |
106 | noise_pred_text = unet(
107 | latent_model_input,
108 | t,
109 | encoder_hidden_states=hidden_states,
110 | ).sample
111 |
112 | uncond_hidden_states.update({
113 | "SIGMA": sigma,
114 | })
115 |
116 | noise_pred_uncond = unet(
117 | latent_model_input,
118 | t,
119 | encoder_hidden_states=uncond_hidden_states,
120 | ).sample
121 |
122 | noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
123 | latents = scheduler.step(noise_pred, t, latents, 1).prev_sample
124 |
125 | # Self-recurrence.
126 | if k < 1 and loop > 1:
127 | noise_recurrent = torch.randn(latents.shape).to(device)
128 | noise_scale = ((scheduler.sigmas[i] ** 2 - scheduler.sigmas[i + 1] ** 2) ** 0.5)
129 | latents = latents + noise_scale * noise_recurrent
130 | else:
131 | latent_model_input = scheduler.scale_model_input(latents, t)
132 | noise_pred_text = unet(
133 | latent_model_input,
134 | t,
135 | encoder_hidden_states=cond_embeddings,
136 | ).sample
137 |
138 | latent_model_input = scheduler.scale_model_input(latents, t)
139 |
140 | noise_pred_uncond = unet(
141 | latent_model_input,
142 | t,
143 | encoder_hidden_states=uncond_embeddings,
144 | ).sample
145 |
146 | noise_pred = noise_pred_uncond + cfg_scale * (noise_pred_text - noise_pred_uncond)
147 | latents = scheduler.step(noise_pred, t, latents, 1).prev_sample
148 |
149 | edited_images = latents_to_images(vae, latents)
150 |
151 | return StableDiffusionPipelineOutput(images=edited_images, nsfw_content_detected=None).images[0]
152 |
153 |
154 | def _tokens_img_attention_weight(img_context_seperated, tokenized_texts, ratio: int = 8, original_shape=False):
155 | token_lis = tokenized_texts["input_ids"][0].tolist()
156 | w, h = img_context_seperated[0][1].shape
157 |
158 | w_r, h_r = round(w / ratio), round(h / ratio)
159 | ret_tensor = torch.zeros((w_r * h_r, len(token_lis)), dtype=torch.float32)
160 | for v_as_tokens, img_where_color in img_context_seperated:
161 | is_in = 0
162 | for idx, tok in enumerate(token_lis):
163 | if token_lis[idx: idx + len(v_as_tokens)] == v_as_tokens:
164 | is_in = 1
165 |
166 | ret_tensor[:, idx: idx + len(v_as_tokens)] += (
167 | downsampling(img_where_color, w_r, h_r)
168 | .reshape(-1, 1)
169 | .repeat(1, len(v_as_tokens))
170 | )
171 |
172 | if not is_in == 1:
173 | print(f"Warning ratio {ratio} : tokens {v_as_tokens} not found in text")
174 |
175 | if original_shape:
176 | ret_tensor = ret_tensor.reshape((w_r, h_r, len(token_lis)))
177 |
178 | return ret_tensor
179 |
180 |
181 | def _image_context_seperator(img, color_context: dict, _tokenizer, neg: float):
182 | ret_lists = []
183 |
184 | if img is not None:
185 | w, h = img.size
186 | matrix = np.zeros((h, w))
187 | for color, v in color_context.items():
188 | color = tuple(color)
189 | if len(color) > 3:
190 | color = color[:3]
191 | if isinstance(color, str):
192 | r, g, b = color[1:3], color[3:5], color[5:7]
193 | color = (int(r, 16), int(g, 16), int(b, 16))
194 | img_where_color = (np.array(img) == color).all(axis=-1)
195 | matrix[img_where_color] = 1
196 |
197 | for color, (subject, weight_active) in color_context.items():
198 | if len(color) > 3:
199 | color = color[:3]
200 | v_input = _tokenizer(
201 | subject,
202 | max_length=_tokenizer.model_max_length,
203 | truncation=True,
204 | )
205 |
206 | v_as_tokens = v_input["input_ids"][1:-1]
207 | if isinstance(color, str):
208 | r, g, b = color[1:3], color[3:5], color[5:7]
209 | color = (int(r, 16), int(g, 16), int(b, 16))
210 | img_where_color = (np.array(img) == color).all(axis=-1)
211 | matrix[img_where_color] = 1
212 | if not img_where_color.sum() > 0:
213 | print(f"Warning : not a single color {color} not found in image")
214 |
215 | img_where_color_init = torch.where(torch.tensor(img_where_color, dtype=torch.bool), weight_active, neg)
216 |
217 | img_where_color = torch.where(torch.from_numpy(matrix == 1) & (img_where_color_init == 0.0),
218 | torch.tensor(neg), img_where_color_init)
219 |
220 | # Add the image location corresponding to the token.
221 | ret_lists.append((v_as_tokens, img_where_color))
222 | else:
223 | w, h = 768, 768
224 |
225 | if len(ret_lists) == 0:
226 | ret_lists.append(([-1], torch.zeros((w, h), dtype=torch.float32)))
227 | return ret_lists, w, h
228 |
229 |
230 | def _extract_cross_attention(tokenizer, device, color_map_image, color_context, text_input, neg):
231 | # Process color map image and context
232 | seperated_word_contexts, width, height = _image_context_seperator(
233 | color_map_image, color_context, tokenizer, neg
234 | )
235 |
236 | # Compute cross-attention weights
237 | cross_attention_weight_1 = _tokens_img_attention_weight(
238 | seperated_word_contexts, text_input, ratio=1, original_shape=True
239 | ).to(device)
240 | cross_attention_weight_8 = _tokens_img_attention_weight(
241 | seperated_word_contexts, text_input, ratio=8
242 | ).to(device)
243 | cross_attention_weight_16 = _tokens_img_attention_weight(
244 | seperated_word_contexts, text_input, ratio=16
245 | ).to(device)
246 | cross_attention_weight_32 = _tokens_img_attention_weight(
247 | seperated_word_contexts, text_input, ratio=32
248 | ).to(device)
249 | cross_attention_weight_64 = _tokens_img_attention_weight(
250 | seperated_word_contexts, text_input, ratio=64
251 | ).to(device)
252 |
253 | hidden_states = {
254 | "CA_WEIGHT_ORIG": cross_attention_weight_1, # 768 x 768
255 | "CA_WEIGHT_9216": cross_attention_weight_8, # 96 x 96
256 | "CA_WEIGHT_2304": cross_attention_weight_16, # 48 x 48
257 | "CA_WEIGHT_576": cross_attention_weight_32, # 24 x 24
258 | "CA_WEIGHT_144": cross_attention_weight_64, # 12 x 12
259 | }
260 |
261 | uncond_hidden_states = {
262 | "CA_WEIGHT_ORIG": 0,
263 | "CA_WEIGHT_9216": 0,
264 | "CA_WEIGHT_2304": 0,
265 | "CA_WEIGHT_576": 0,
266 | "CA_WEIGHT_144": 0,
267 | }
268 |
269 | return hidden_states, uncond_hidden_states
270 |
--------------------------------------------------------------------------------
/pics/2_subjects.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/2_subjects.jpg
--------------------------------------------------------------------------------
/pics/3_subjects.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/3_subjects.jpg
--------------------------------------------------------------------------------
/pics/4_subjects.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/4_subjects.jpg
--------------------------------------------------------------------------------
/pics/arch.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/arch.jpg
--------------------------------------------------------------------------------
/pics/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/architecture.png
--------------------------------------------------------------------------------
/pics/layout_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/layout_example.png
--------------------------------------------------------------------------------
/pics/more.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/more.png
--------------------------------------------------------------------------------
/pics/pipeline.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/pipeline.gif
--------------------------------------------------------------------------------
/pics/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ali-vilab/Cones-V2/447c194e0ac99a2871efb15f378c96cf12238ac2/pics/teaser.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1
2 | torchvision==0.14.1
3 | omegaconf==2.2.3
4 | opencv-python
5 | imageio==2.9.0
6 | transformers==4.26.1
7 | diffusers==0.13.1
8 | accelerate==0.20.0
9 | scipy==1.9.1
10 | hydra-core==1.2.0
11 | tqdm
12 | gradio==3.23.0
13 | pillow
14 | packaging
15 | numpy
--------------------------------------------------------------------------------
/train_cones2.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | import argparse
17 | import itertools
18 | import logging
19 | import math
20 | import os
21 | from pathlib import Path
22 |
23 | import accelerate
24 | import torch
25 | import torch.nn.functional as F
26 | import torch.utils.checkpoint
27 | import transformers
28 | from accelerate import Accelerator
29 | from accelerate.logging import get_logger
30 | from accelerate.utils import ProjectConfiguration, set_seed
31 | from packaging import version
32 | from PIL import Image
33 | from torch.utils.data import Dataset
34 | from torchvision import transforms
35 | from tqdm.auto import tqdm
36 | from transformers import AutoTokenizer, PretrainedConfig
37 |
38 | import diffusers
39 | from diffusers import (
40 | AutoencoderKL,
41 | DDPMScheduler,
42 | DiffusionPipeline,
43 | UNet2DConditionModel,
44 | StableDiffusionPipeline,
45 | )
46 | from diffusers.optimization import get_scheduler
47 | from diffusers.utils.import_utils import is_xformers_available
48 |
49 |
50 | logger = get_logger(__name__)
51 |
52 |
53 | PROMPT_TEMPLETE = [
54 | "a photo of a {}",
55 | "a rendering of a {}",
56 | "a cropped photo of the {}",
57 | "the photo of a {}",
58 | "a photo of a clean {}",
59 | "a photo of a dirty {}",
60 | "a dark photo of the {}",
61 | "a photo of my {}",
62 | "a photo of the cool {}",
63 | "a close-up photo of a {}",
64 | "a bright photo of the {}",
65 | "a cropped photo of a {}",
66 | "a photo of the {}",
67 | "a good photo of the {}",
68 | "a photo of one {}",
69 | "a close-up photo of the {}",
70 | "a rendition of the {}",
71 | "a photo of the clean {}",
72 | "a rendition of a {}",
73 | "a photo of a nice {}",
74 | "a good photo of a {}",
75 | "a photo of the nice {}",
76 | "a photo of the small {}",
77 | "a photo of the weird {}",
78 | "a photo of the large {}",
79 | "a photo of a cool {}",
80 | "a photo of a small {}",
81 | ]
82 |
83 |
84 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
85 | text_encoder_config = PretrainedConfig.from_pretrained(
86 | pretrained_model_name_or_path,
87 | subfolder="text_encoder",
88 | revision=revision,
89 | )
90 | model_class = text_encoder_config.architectures[0]
91 |
92 | if model_class == "CLIPTextModel":
93 | from transformers import CLIPTextModel
94 |
95 | return CLIPTextModel
96 | elif model_class == "RobertaSeriesModelWithTransformation":
97 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
98 |
99 | return RobertaSeriesModelWithTransformation
100 | elif model_class == "T5EncoderModel":
101 | from transformers import T5EncoderModel
102 |
103 | return T5EncoderModel
104 | else:
105 | raise ValueError(f"{model_class} is not supported.")
106 |
107 |
108 | def parse_args(input_args=None):
109 | parser = argparse.ArgumentParser(description="Simple example of a script for training Cones 2.")
110 | parser.add_argument(
111 | "--pretrained_model_name_or_path",
112 | type=str,
113 | default=None,
114 | required=True,
115 | help="Path to pretrained model or model identifier from huggingface.co/models.",
116 | )
117 | parser.add_argument(
118 | "--revision",
119 | type=str,
120 | default=None,
121 | required=False,
122 | help=(
123 | "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
124 | " float32 precision."
125 | ),
126 | )
127 | parser.add_argument(
128 | "--tokenizer_name",
129 | type=str,
130 | default=None,
131 | help="Pretrained tokenizer name or path if not the same as model_name",
132 | )
133 | parser.add_argument(
134 | "--instance_data_dir",
135 | type=str,
136 | default=None,
137 | required=True,
138 | help="A folder containing the training data of instance images.",
139 | )
140 | parser.add_argument(
141 | "--instance_prompt",
142 | type=str,
143 | default=None,
144 | required=True,
145 | help="The prompt with identifier specifying the instance",
146 | )
147 | parser.add_argument(
148 | "--token_num",
149 | type=int,
150 | default=1,
151 | help="Number of updates steps to accumulate before performing a backward/update pass.",
152 | )
153 | parser.add_argument(
154 | "--output_dir",
155 | type=str,
156 | default="cones2-model",
157 | help="The output directory where the model predictions and checkpoints will be written.",
158 | )
159 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
160 | parser.add_argument(
161 | "--resolution",
162 | type=int,
163 | default=768,
164 | help=(
165 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
166 | " resolution"
167 | ),
168 | )
169 | parser.add_argument(
170 | "--center_crop",
171 | default=False,
172 | action="store_true",
173 | help=(
174 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
175 | " cropped. The images will be resized to the resolution first before cropping."
176 | ),
177 | )
178 | parser.add_argument(
179 | "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
180 | )
181 | parser.add_argument(
182 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
183 | )
184 | parser.add_argument("--num_train_epochs", type=int, default=1)
185 | parser.add_argument(
186 | "--max_train_steps",
187 | type=int,
188 | default=None,
189 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
190 | )
191 | parser.add_argument(
192 | "--checkpointing_steps",
193 | type=int,
194 | default=400,
195 | help=(
196 | "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via"
197 | " `--resume_from_checkpoint`. In the case that the checkpoint is better than the final trained model, the"
198 | " checkpoint can also be used for inference. Using a checkpoint for inference requires separate loading of"
199 | " the original pipeline and the individual checkpointed model components."
200 | ),
201 | )
202 | parser.add_argument(
203 | "--checkpoints_total_limit",
204 | type=int,
205 | default=None,
206 | help=(
207 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
208 | ),
209 | )
210 | parser.add_argument(
211 | "--resume_from_checkpoint",
212 | type=str,
213 | default=None,
214 | help=(
215 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
216 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
217 | ),
218 | )
219 | parser.add_argument(
220 | "--gradient_accumulation_steps",
221 | type=int,
222 | default=1,
223 | help="Number of updates steps to accumulate before performing a backward/update pass.",
224 | )
225 | parser.add_argument(
226 | "--gradient_checkpointing",
227 | action="store_true",
228 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
229 | )
230 | parser.add_argument(
231 | "--learning_rate",
232 | type=float,
233 | default=5e-6,
234 | help="Initial learning rate (after the potential warmup period) to use.",
235 | )
236 | parser.add_argument(
237 | "--loss_rate_first",
238 | type=float,
239 | default=5e-3,
240 | help="loss_rate",
241 | )
242 | parser.add_argument(
243 | "--loss_rate_second",
244 | type=float,
245 | default=5e-4,
246 | help="loss_rate_second",
247 | )
248 | parser.add_argument(
249 | "--scale_lr",
250 | action="store_true",
251 | default=False,
252 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
253 | )
254 | parser.add_argument(
255 | "--lr_scheduler",
256 | type=str,
257 | default="constant",
258 | help=(
259 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
260 | ' "constant", "constant_with_warmup"]'
261 | ),
262 | )
263 | parser.add_argument(
264 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
265 | )
266 | parser.add_argument(
267 | "--lr_num_cycles",
268 | type=int,
269 | default=1,
270 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
271 | )
272 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
273 | parser.add_argument(
274 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
275 | )
276 | parser.add_argument(
277 | "--dataloader_num_workers",
278 | type=int,
279 | default=0,
280 | help=(
281 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
282 | ),
283 | )
284 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
285 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
286 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
287 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
288 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
289 | parser.add_argument(
290 | "--logging_dir",
291 | type=str,
292 | default="logs",
293 | help=(
294 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
295 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
296 | ),
297 | )
298 | parser.add_argument(
299 | "--allow_tf32",
300 | action="store_true",
301 | help=(
302 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
303 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
304 | ),
305 | )
306 | parser.add_argument(
307 | "--report_to",
308 | type=str,
309 | default="tensorboard",
310 | help=(
311 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
312 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
313 | ),
314 | )
315 | parser.add_argument(
316 | "--mixed_precision",
317 | type=str,
318 | default=None,
319 | choices=["no", "fp16", "bf16"],
320 | help=(
321 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
322 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
323 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
324 | ),
325 | )
326 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
327 | parser.add_argument(
328 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
329 | )
330 | parser.add_argument(
331 | "--set_grads_to_none",
332 | action="store_true",
333 | help=(
334 | "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
335 | " behaviors, so disable this argument if it causes any problems. More info:"
336 | " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
337 | ),
338 | )
339 |
340 | if input_args is not None:
341 | args = parser.parse_args(input_args)
342 | else:
343 | args = parser.parse_args()
344 |
345 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
346 | if env_local_rank != -1 and env_local_rank != args.local_rank:
347 | args.local_rank = env_local_rank
348 |
349 | return args
350 |
351 |
352 | class ImagePromptDataset(Dataset):
353 | """
354 | A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
355 | It pre-processes the images and tokenizes the prompts.
356 | """
357 |
358 | def __init__(
359 | self,
360 | instance_data_root,
361 | instance_prompt,
362 | tokenizer,
363 | size=768,
364 | center_crop=False,
365 | ):
366 | self.size = size
367 | self.center_crop = center_crop
368 | self.tokenizer = tokenizer
369 |
370 | self.instance_data_root = Path(instance_data_root)
371 | if not self.instance_data_root.exists():
372 | raise ValueError(f"Instance {self.instance_data_root} images root doesn't exists.")
373 |
374 | self.instance_images_path = list(Path(instance_data_root).iterdir())
375 | self.num_instance_images = len(self.instance_images_path)
376 | self.instance_prompt = instance_prompt
377 | self._length = self.num_instance_images
378 |
379 | self.image_transforms = transforms.Compose(
380 | [
381 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
382 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
383 | transforms.ToTensor(),
384 | transforms.Normalize([0.5], [0.5]),
385 | ]
386 | )
387 |
388 | def __len__(self):
389 | return self._length
390 |
391 | def __getitem__(self, index):
392 | example = {}
393 | instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
394 | if not instance_image.mode == "RGB":
395 | instance_image = instance_image.convert("RGB")
396 | example["instance_images"] = self.image_transforms(instance_image)
397 | example["instance_prompt_ids"] = self.tokenizer(
398 | self.instance_prompt,
399 | truncation=True,
400 | padding="max_length",
401 | max_length=self.tokenizer.model_max_length,
402 | return_tensors="pt",
403 | ).input_ids
404 |
405 | return example
406 |
407 |
408 | def collate_fn(examples):
409 | input_ids = [example["instance_prompt_ids"] for example in examples]
410 | pixel_values = [example["instance_images"] for example in examples]
411 |
412 | pixel_values = torch.stack(pixel_values)
413 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
414 |
415 | input_ids = torch.cat(input_ids, dim=0)
416 |
417 | batch = {"input_ids": input_ids, "pixel_values": pixel_values}
418 | return batch
419 |
420 |
421 | class PromptDataset(Dataset):
422 | """A simple dataset to prepare the prompts to generate class images on multiple GPUs."""
423 |
424 | def __init__(self, prompt, num_samples):
425 | self.prompt = prompt
426 | self.num_samples = num_samples
427 |
428 | def __len__(self):
429 | return self.num_samples
430 |
431 | def __getitem__(self, index):
432 | example = {"prompt": self.prompt, "index": index}
433 | return example
434 |
435 |
436 | def main(args):
437 | logging_dir = Path(args.output_dir, args.logging_dir)
438 |
439 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
440 | accelerator = Accelerator(
441 | gradient_accumulation_steps=args.gradient_accumulation_steps,
442 | mixed_precision=args.mixed_precision,
443 | log_with=args.report_to,
444 | project_config=accelerator_project_config,
445 | )
446 |
447 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
448 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
449 | if args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
450 | raise ValueError(
451 | "Gradient accumulation is not supported when training the text encoder in distributed training. "
452 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
453 | )
454 |
455 | # Make one log on every process with the configuration for debugging.
456 | logging.basicConfig(
457 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
458 | datefmt="%m/%d/%Y %H:%M:%S",
459 | level=logging.INFO,
460 | )
461 | logger.info(accelerator.state, main_process_only=False)
462 | if accelerator.is_local_main_process:
463 | transformers.utils.logging.set_verbosity_warning()
464 | diffusers.utils.logging.set_verbosity_info()
465 | else:
466 | transformers.utils.logging.set_verbosity_error()
467 | diffusers.utils.logging.set_verbosity_error()
468 |
469 | # If passed along, set the training seed now.
470 | if args.seed is not None:
471 | set_seed(args.seed)
472 |
473 | # Handle the repository creation
474 | if accelerator.is_main_process:
475 | if args.output_dir is not None:
476 | os.makedirs(args.output_dir, exist_ok=True)
477 |
478 | # Load the tokenizer
479 | if args.tokenizer_name:
480 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
481 | elif args.pretrained_model_name_or_path:
482 | tokenizer = AutoTokenizer.from_pretrained(
483 | args.pretrained_model_name_or_path,
484 | subfolder="tokenizer",
485 | revision=args.revision,
486 | use_fast=False,
487 | )
488 |
489 | # import correct text encoder class
490 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
491 |
492 | # Load scheduler and models
493 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
494 | text_encoder = text_encoder_cls.from_pretrained(
495 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
496 | )
497 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
498 | unet = UNet2DConditionModel.from_pretrained(
499 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
500 | )
501 |
502 | # `accelerate` 0.16.0 will have better support for customized saving
503 | if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
504 | # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
505 | def save_model_hook(models, weights, output_dir):
506 | for model in models:
507 | sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
508 | model.save_pretrained(os.path.join(output_dir, sub_dir))
509 |
510 | # make sure to pop weight so that corresponding model is not saved again
511 | weights.pop()
512 |
513 | def load_model_hook(models, input_dir):
514 | while len(models) > 0:
515 | # pop models so that they are not loaded again
516 | model = models.pop()
517 |
518 | if type(model) == type(text_encoder):
519 | # load transformers style into model
520 | load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
521 | model.config = load_model.config
522 | else:
523 | # load diffusers style into model
524 | load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
525 | model.register_to_config(**load_model.config)
526 |
527 | model.load_state_dict(load_model.state_dict())
528 | del load_model
529 |
530 | accelerator.register_save_state_pre_hook(save_model_hook)
531 | accelerator.register_load_state_pre_hook(load_model_hook)
532 |
533 | vae.requires_grad_(False)
534 | unet.requires_grad_(False)
535 |
536 | if args.enable_xformers_memory_efficient_attention:
537 | if is_xformers_available():
538 | import xformers
539 | xformers_version = version.parse(xformers.__version__)
540 | if xformers_version == version.parse("0.0.16"):
541 | logger.warn(
542 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training,"
543 | " please update xFormers to at least 0.0.17."
544 | " See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
545 | )
546 | unet.enable_xformers_memory_efficient_attention()
547 | else:
548 | raise ValueError("xformers is not available. Make sure it is installed correctly")
549 |
550 | if args.gradient_checkpointing:
551 | unet.enable_gradient_checkpointing()
552 | # if args.train_text_encoder:
553 | # text_encoder.gradient_checkpointing_enable()
554 | text_encoder.gradient_checkpointing_enable()
555 |
556 | # Check that all trainable models are in full precision
557 | low_precision_error_string = (
558 | "Please make sure to always have all model weights in full float32 precision when starting training - even if"
559 | " doing mixed precision training. copy of the weights should still be float32."
560 | )
561 |
562 | if accelerator.unwrap_model(text_encoder).dtype != torch.float32:
563 | raise ValueError(
564 | f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
565 | f" {low_precision_error_string}"
566 | )
567 |
568 | # Enable TF32 for faster training on Ampere GPUs,
569 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
570 | if args.allow_tf32:
571 | torch.backends.cuda.matmul.allow_tf32 = True
572 |
573 | if args.scale_lr:
574 | args.learning_rate = (
575 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
576 | )
577 |
578 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
579 | if args.use_8bit_adam:
580 | try:
581 | import bitsandbytes as bnb
582 | except ImportError:
583 | raise ImportError(
584 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
585 | )
586 |
587 | optimizer_class = bnb.optim.AdamW8bit
588 | else:
589 | optimizer_class = torch.optim.AdamW
590 |
591 | # Optimizer creation
592 | params_to_optimize = (itertools.chain(text_encoder.parameters()))
593 | optimizer = optimizer_class(
594 | params_to_optimize,
595 | lr=args.learning_rate,
596 | betas=(args.adam_beta1, args.adam_beta2),
597 | weight_decay=args.adam_weight_decay,
598 | eps=args.adam_epsilon,
599 | )
600 |
601 | # Dataset and DataLoaders creation:
602 | train_dataset = ImagePromptDataset(
603 | instance_data_root=args.instance_data_dir,
604 | instance_prompt=args.instance_prompt,
605 | tokenizer=tokenizer,
606 | size=args.resolution,
607 | center_crop=args.center_crop,
608 | )
609 |
610 | train_dataloader = torch.utils.data.DataLoader(
611 | train_dataset,
612 | batch_size=args.train_batch_size,
613 | shuffle=True,
614 | collate_fn=lambda examples: collate_fn(examples),
615 | num_workers=args.dataloader_num_workers,
616 | )
617 |
618 | # Scheduler and math around the number of training steps.
619 | overrode_max_train_steps = False
620 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
621 | if args.max_train_steps is None:
622 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
623 | overrode_max_train_steps = True
624 |
625 | lr_scheduler = get_scheduler(
626 | args.lr_scheduler,
627 | optimizer=optimizer,
628 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
629 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
630 | num_cycles=args.lr_num_cycles,
631 | power=args.lr_power,
632 | )
633 |
634 | text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
635 | text_encoder, optimizer, train_dataloader, lr_scheduler
636 | )
637 |
638 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
639 | # as these models are only used for inference, keeping weights in full precision is not required.
640 | weight_dtype = torch.float32
641 | if accelerator.mixed_precision == "fp16":
642 | weight_dtype = torch.float16
643 | elif accelerator.mixed_precision == "bf16":
644 | weight_dtype = torch.bfloat16
645 |
646 | # Move vae and unet to device and cast to weight_dtype
647 | vae.to(accelerator.device, dtype=weight_dtype)
648 | unet.to(accelerator.device, dtype=weight_dtype)
649 |
650 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
651 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
652 | if overrode_max_train_steps:
653 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
654 |
655 | # Afterwards we recalculate our number of training epochs
656 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
657 |
658 | # We need to initialize the trackers we use, and also store our configuration.
659 | # The trackers initializes automatically on the main process.
660 | if accelerator.is_main_process:
661 | accelerator.init_trackers("cones2", config=vars(args))
662 |
663 | # Train!
664 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
665 |
666 | logger.info("***** Running training *****")
667 | logger.info(f" Num examples = {len(train_dataset)}")
668 | logger.info(f" Num batches each epoch = {len(train_dataloader)}")
669 | logger.info(f" Num Epochs = {args.num_train_epochs}")
670 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
671 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
672 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
673 | logger.info(f" Total optimization steps = {args.max_train_steps}")
674 | global_step = 0
675 | first_epoch = 0
676 |
677 | # Potentially load in the weights and states from a previous save
678 | if args.resume_from_checkpoint:
679 | if args.resume_from_checkpoint != "latest":
680 | path = os.path.basename(args.resume_from_checkpoint)
681 | else:
682 | # Get the most recent checkpoint
683 | dirs = os.listdir(args.output_dir)
684 | dirs = [d for d in dirs if d.startswith("checkpoint")]
685 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
686 | path = dirs[-1] if len(dirs) > 0 else None
687 |
688 | if path is None:
689 | accelerator.print(
690 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
691 | )
692 | args.resume_from_checkpoint = None
693 | else:
694 | accelerator.print(f"Resuming from checkpoint {path}")
695 | accelerator.load_state(os.path.join(args.output_dir, path))
696 | global_step = int(path.split("-")[1])
697 |
698 | resume_global_step = global_step * args.gradient_accumulation_steps
699 | first_epoch = global_step // num_update_steps_per_epoch
700 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
701 |
702 | # Only show the progress bar once on each machine.
703 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
704 | progress_bar.set_description("Steps")
705 |
706 | num_iter = 0
707 |
708 | for epoch in range(first_epoch, args.num_train_epochs):
709 | text_encoder.train()
710 | for step, batch in enumerate(train_dataloader):
711 | # Skip steps until we reach the resumed step
712 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
713 | if step % args.gradient_accumulation_steps == 0:
714 | progress_bar.update(1)
715 | continue
716 |
717 | with accelerator.accumulate(text_encoder):
718 | # Convert images to latent space
719 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
720 | latents = latents * vae.config.scaling_factor
721 |
722 | # Sample noise that we'll add to the latents
723 | noise = torch.randn_like(latents)
724 | bsz = latents.shape[0]
725 |
726 | # Sample a random timestep for each image
727 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
728 | timesteps = timesteps.long()
729 |
730 | # Add noise to the latents according to the noise magnitude at each timestep
731 | # (this is the forward diffusion process)
732 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
733 |
734 | # Get the text embedding for conditioning
735 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
736 | if num_iter == 0:
737 | target_embed = encoder_hidden_states.detach()
738 |
739 | l_first_part = torch.norm(torch.squeeze(target_embed)[:1] - torch.squeeze(encoder_hidden_states)[:1], 2)
740 | l_second_part = torch.norm(torch.squeeze(target_embed)[1 + args.token_num:]
741 | - torch.squeeze(encoder_hidden_states)[1 + args.token_num:], 2)
742 | loss_embedding = args.loss_rate_first * l_first_part + args.loss_rate_second * l_second_part
743 |
744 | # Predict the noise residual
745 | model_pred = unet(noisy_latents.float(), timesteps, encoder_hidden_states.float()).sample
746 |
747 | # Get the target for loss depending on the prediction type
748 | if noise_scheduler.config.prediction_type == "epsilon":
749 | target = noise
750 | elif noise_scheduler.config.prediction_type == "v_prediction":
751 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
752 | else:
753 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
754 |
755 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
756 | loss = loss + loss_embedding
757 |
758 | accelerator.backward(loss)
759 | if accelerator.sync_gradients:
760 | params_to_clip = (itertools.chain(text_encoder.parameters()))
761 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
762 | optimizer.step()
763 | lr_scheduler.step()
764 |
765 | optimizer.zero_grad(set_to_none=args.set_grads_to_none)
766 | num_iter += 1
767 |
768 | # Checks if the accelerator has performed an optimization step behind the scenes
769 | if accelerator.sync_gradients:
770 | progress_bar.update(1)
771 | global_step += 1
772 |
773 | if global_step % args.checkpointing_steps == 0:
774 | if accelerator.is_main_process:
775 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
776 | accelerator.save_state(save_path)
777 | logger.info(f"Saved state to {save_path}")
778 |
779 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
780 | progress_bar.set_postfix(**logs)
781 | accelerator.log(logs, step=global_step)
782 |
783 | if global_step >= args.max_train_steps:
784 | break
785 |
786 | # Create the pipeline using the trained modules and save it.
787 | accelerator.wait_for_everyone()
788 | if accelerator.is_main_process:
789 | pipe = StableDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path).to("cuda")
790 |
791 | text_inputs_origin = pipe.tokenizer(
792 | args.instance_prompt,
793 | padding="max_length",
794 | max_length=tokenizer.model_max_length,
795 | truncation=True,
796 | return_tensors="pt",
797 | )
798 | text_inputs_origin_ids = text_inputs_origin.input_ids
799 | index = text_inputs_origin_ids[0][1]
800 | prompt_embeds_new = 0
801 | prompt_embeds_origin = 0
802 | for template in PROMPT_TEMPLETE:
803 | text_inputs = pipe.tokenizer(
804 | template.format(args.instance_prompt),
805 | padding="max_length",
806 | max_length=tokenizer.model_max_length,
807 | truncation=True,
808 | return_tensors="pt",
809 | )
810 | text_input_ids = text_inputs.input_ids
811 | index_template = int(torch.where(text_input_ids[0] == index)[0][0])
812 | prompt_embeds_now = text_encoder(text_input_ids.to("cuda"), attention_mask=None)
813 | prompt_embeds_now = prompt_embeds_now[0][0][index_template: index_template + args.token_num]
814 | prompt_embeds = pipe.text_encoder(text_input_ids.to("cuda"), attention_mask=None)
815 | prompt_embeds = prompt_embeds[0][0][index_template: index_template + args.token_num]
816 | prompt_embeds_new += prompt_embeds_now
817 | prompt_embeds_origin += prompt_embeds
818 |
819 | residual_save_path = args.output_dir + '/residual.pt'
820 | torch.save((prompt_embeds_new - prompt_embeds_origin) / len(PROMPT_TEMPLETE), residual_save_path)
821 |
822 | pipeline = DiffusionPipeline.from_pretrained(
823 | args.pretrained_model_name_or_path,
824 | unet=accelerator.unwrap_model(unet),
825 | text_encoder=accelerator.unwrap_model(text_encoder),
826 | revision=args.revision,
827 | )
828 | pipeline.save_pretrained(args.output_dir)
829 | accelerator.print(f"The customized residual embedding saved in {residual_save_path}")
830 |
831 | accelerator.end_training()
832 |
833 |
834 | if __name__ == "__main__":
835 | args = parse_args()
836 | main(args)
837 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | def downsampling(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
8 | return F.interpolate(
9 | img.unsqueeze(0).unsqueeze(1),
10 | size=(w, h),
11 | mode="bilinear",
12 | align_corners=True,
13 | ).squeeze()
14 |
15 |
16 | def image_grid(images, rows=2, cols=2):
17 | w, h = images[0].size
18 | grid = Image.new('RGB', size=(cols * w, rows * h))
19 |
20 | for i, img in enumerate(images):
21 | grid.paste(img, box=(i % cols * w, i // cols * h))
22 | return grid
23 |
24 |
25 | def latents_to_images(vae, latents, scale_factor=0.18215):
26 | """
27 | Decode latents to PIL images.
28 | """
29 | scaled_latents = 1.0 / scale_factor * latents.clone()
30 | images = vae.decode(scaled_latents).sample
31 | images = (images / 2 + 0.5).clamp(0, 1)
32 | images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
33 |
34 | if images.ndim == 3:
35 | images = images[None, ...]
36 | images = (images * 255).round().astype("uint8")
37 | pil_images = [Image.fromarray(image) for image in images]
38 |
39 | return pil_images
40 |
--------------------------------------------------------------------------------