├── FADING_demo.png ├── FADING_util ├── ptp_utils.py ├── seq_aligner.py └── util.py ├── README.md ├── age_editing.py ├── null_inversion.py ├── p2p.py └── specialize.py /FADING_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MunchkinChen/FADING/14fefbf7edbf5b1d3f867bd4720a61bfc1bdf2c0/FADING_demo.png -------------------------------------------------------------------------------- /FADING_util/ptp_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import numpy as np 16 | import torch 17 | from PIL import Image, ImageDraw, ImageFont 18 | import cv2 19 | from typing import Optional, Union, Tuple, List, Callable, Dict 20 | from IPython.display import display 21 | from tqdm.notebook import tqdm 22 | 23 | 24 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): 25 | h, w, c = image.shape 26 | offset = int(h * .2) 27 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 28 | font = cv2.FONT_HERSHEY_SIMPLEX 29 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) 30 | img[:h] = image 31 | textsize = cv2.getTextSize(text, font, 1, 2)[0] 32 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 33 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2) 34 | return img 35 | 36 | 37 | def view_images(images, num_rows=1, offset_ratio=0.02): 38 | if type(images) is list: 39 | num_empty = len(images) % num_rows 40 | elif images.ndim == 4: 41 | num_empty = images.shape[0] % num_rows 42 | else: 43 | images = [images] 44 | num_empty = 0 45 | 46 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 47 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty 48 | num_items = len(images) 49 | 50 | h, w, c = images[0].shape 51 | offset = int(h * offset_ratio) 52 | num_cols = num_items // num_rows 53 | image_ = np.ones((h * num_rows + offset * (num_rows - 1), 54 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 55 | for i in range(num_rows): 56 | for j in range(num_cols): 57 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ 58 | i * num_cols + j] 59 | 60 | pil_img = Image.fromarray(image_) 61 | return image_ 62 | 63 | 64 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False): 65 | if low_resource: 66 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"] 67 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"] 68 | else: 69 | latents_input = torch.cat([latents] * 2) 70 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 71 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 72 | noise_pred = 
noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 73 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"] 74 | latents = controller.step_callback(latents) 75 | return latents 76 | 77 | 78 | def latent2image(vae, latents): 79 | latents = 1 / 0.18215 * latents 80 | image = vae.decode(latents)['sample'] 81 | image = (image / 2 + 0.5).clamp(0, 1) 82 | image = image.cpu().permute(0, 2, 3, 1).numpy() 83 | image = (image * 255).astype(np.uint8) 84 | return image 85 | 86 | 87 | def init_latent(latent, model, height, width, generator, batch_size): 88 | if latent is None: 89 | latent = torch.randn( 90 | (1, model.unet.in_channels, height // 8, width // 8), 91 | generator=generator,device=model.device, 92 | ) 93 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device) 94 | return latent, latents 95 | 96 | # 97 | # @torch.no_grad() 98 | # def text2image_ldm( 99 | # model, 100 | # prompt: List[str], 101 | # controller, 102 | # num_inference_steps: int = 50, 103 | # guidance_scale: Optional[float] = 7., 104 | # generator: Optional[torch.Generator] = None, 105 | # latent: Optional[torch.FloatTensor] = None, 106 | # ): 107 | # register_attention_control(model, controller) 108 | # height = width = 256 109 | # batch_size = len(prompt) 110 | # 111 | # uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt") 112 | # uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0] 113 | # 114 | # text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt") 115 | # text_embeddings = model.bert(text_input.input_ids.to(model.device))[0] 116 | # latent, latents = init_latent(latent, model, height, width, generator, batch_size) 117 | # context = torch.cat([uncond_embeddings, text_embeddings]) 118 | # 119 | # model.scheduler.set_timesteps(num_inference_steps) 120 | # for t in tqdm(model.scheduler.timesteps): 121 | # latents = diffusion_step(model, controller, latents, context, t, guidance_scale) 122 | # 123 | # image = latent2image(model.vqvae, latents) 124 | # 125 | # return image, latent 126 | # 127 | # 128 | # @torch.no_grad() 129 | # def text2image_ldm_stable( 130 | # model, 131 | # prompt: List[str], 132 | # controller, 133 | # num_inference_steps: int = 50, 134 | # guidance_scale: float = 7.5, 135 | # generator: Optional[torch.Generator] = None, 136 | # latent: Optional[torch.FloatTensor] = None, 137 | # low_resource: bool = False, 138 | # ): 139 | # register_attention_control(model, controller) 140 | # height = width = 512 141 | # batch_size = len(prompt) 142 | # 143 | # text_input = model.tokenizer( 144 | # prompt, 145 | # padding="max_length", 146 | # max_length=model.tokenizer.model_max_length, 147 | # truncation=True, 148 | # return_tensors="pt", 149 | # ) 150 | # text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 151 | # max_length = text_input.input_ids.shape[-1] 152 | # uncond_input = model.tokenizer( 153 | # [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 154 | # ) 155 | # uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 156 | # 157 | # context = [uncond_embeddings, text_embeddings] 158 | # if not low_resource: 159 | # context = torch.cat(context) 160 | # latent, latents = init_latent(latent, model, height, width, generator, batch_size) 161 | # 162 | # # set timesteps 163 | # # extra_set_kwargs = {"offset": 1} 164 | 
# # model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 165 | # model.scheduler.set_timesteps(num_inference_steps) 166 | # for t in tqdm(model.scheduler.timesteps): 167 | # latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource) 168 | # 169 | # image = latent2image(model.vae, latents) 170 | # 171 | # return image, latent 172 | 173 | 174 | def register_attention_control(model, controller): 175 | def ca_forward(self, place_in_unet): 176 | to_out = self.to_out 177 | if type(to_out) is torch.nn.modules.container.ModuleList: 178 | to_out = self.to_out[0] 179 | else: 180 | to_out = self.to_out 181 | 182 | def forward(x, context=None, mask=None): 183 | batch_size, sequence_length, dim = x.shape 184 | h = self.heads 185 | q = self.to_q(x) 186 | is_cross = context is not None 187 | context = context if is_cross else x 188 | k = self.to_k(context) 189 | v = self.to_v(context) 190 | q = self.reshape_heads_to_batch_dim(q) 191 | k = self.reshape_heads_to_batch_dim(k) 192 | v = self.reshape_heads_to_batch_dim(v) 193 | 194 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale 195 | 196 | if mask is not None: 197 | mask = mask.reshape(batch_size, -1) 198 | max_neg_value = -torch.finfo(sim.dtype).max 199 | mask = mask[:, None, :].repeat(h, 1, 1) 200 | sim.masked_fill_(~mask, max_neg_value) 201 | 202 | # attention, what we cannot get enough of 203 | attn = sim.softmax(dim=-1) 204 | attn = controller(attn, is_cross, place_in_unet) 205 | out = torch.einsum("b i j, b j d -> b i d", attn, v) 206 | out = self.reshape_batch_dim_to_heads(out) 207 | return to_out(out) 208 | 209 | return forward 210 | 211 | class DummyController: 212 | 213 | def __call__(self, *args): 214 | return args[0] 215 | 216 | def __init__(self): 217 | self.num_att_layers = 0 218 | 219 | if controller is None: 220 | controller = DummyController() 221 | 222 | def register_recr(net_, count, place_in_unet): 223 | if net_.__class__.__name__ == 'CrossAttention': 224 | net_.forward = ca_forward(net_, place_in_unet) 225 | return count + 1 226 | elif hasattr(net_, 'children'): 227 | for net__ in net_.children(): 228 | count = register_recr(net__, count, place_in_unet) 229 | return count 230 | 231 | cross_att_count = 0 232 | sub_nets = model.unet.named_children() 233 | for net in sub_nets: 234 | if "down" in net[0]: 235 | cross_att_count += register_recr(net[1], 0, "down") 236 | elif "up" in net[0]: 237 | cross_att_count += register_recr(net[1], 0, "up") 238 | elif "mid" in net[0]: 239 | cross_att_count += register_recr(net[1], 0, "mid") 240 | 241 | controller.num_att_layers = cross_att_count 242 | 243 | 244 | def get_word_inds(text: str, word_place: int, tokenizer): 245 | split_text = text.split(" ") 246 | if type(word_place) is str: 247 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 248 | elif type(word_place) is int: 249 | word_place = [word_place] 250 | out = [] 251 | if len(word_place) > 0: 252 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 253 | cur_len, ptr = 0, 0 254 | 255 | for i in range(len(words_encode)): 256 | cur_len += len(words_encode[i]) 257 | if ptr in word_place: 258 | out.append(i + 1) 259 | if cur_len >= len(split_text[ptr]): 260 | ptr += 1 261 | cur_len = 0 262 | return np.array(out) 263 | 264 | 265 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int, 266 | word_inds: Optional[torch.Tensor] = None): 267 | if type(bounds) is float: 268 | 
bounds = 0, bounds 269 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0]) 270 | if word_inds is None: 271 | word_inds = torch.arange(alpha.shape[2]) 272 | alpha[: start, prompt_ind, word_inds] = 0 273 | alpha[start: end, prompt_ind, word_inds] = 1 274 | alpha[end:, prompt_ind, word_inds] = 0 275 | return alpha 276 | 277 | 278 | def get_time_words_attention_alpha(prompts, num_steps, 279 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]], 280 | tokenizer, max_num_words=77): 281 | if type(cross_replace_steps) is not dict: 282 | cross_replace_steps = {"default_": cross_replace_steps} 283 | if "default_" not in cross_replace_steps: 284 | cross_replace_steps["default_"] = (0., 1.) 285 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words) 286 | for i in range(len(prompts) - 1): 287 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"], 288 | i) 289 | for key, item in cross_replace_steps.items(): 290 | if key != "default_": 291 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))] 292 | for i, ind in enumerate(inds): 293 | if len(ind) > 0: 294 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind) 295 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) 296 | return alpha_time_words -------------------------------------------------------------------------------- /FADING_util/seq_aligner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Google LLC 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
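# NOTE: this module builds token-level alignments between a source prompt and its
# edited variants so that Prompt-to-Prompt (p2p.py) can match cross-attention maps
# across prompts. `global_align` runs a global-alignment DP (Needleman-Wunsch style)
# over token ids scored by `ScoreParams`; `get_refinement_mapper` returns, for each
# edited prompt, an index mapper back to the source tokens plus alphas flagging which
# edited tokens have a source counterpart (used by AttentionRefine); and
# `get_replacement_mapper` builds a 77x77 position-mapping matrix for prompts with
# the same word count (used by AttentionReplace).
#
# Illustrative sketch (hypothetical prompts; assumes the pipeline's CLIP tokenizer):
#   mapper, alphas = get_mapper("photo of a man", "photo of a young man", tokenizer)
#   # mapper[j] -> index of the source token aligned to edited-prompt token j (-1 if none)
#   # alphas[j] == 0 where the edited prompt introduces a token with no source match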
14 | import torch 15 | import numpy as np 16 | 17 | 18 | class ScoreParams: 19 | 20 | def __init__(self, gap, match, mismatch): 21 | self.gap = gap 22 | self.match = match 23 | self.mismatch = mismatch 24 | 25 | def mis_match_char(self, x, y): 26 | if x != y: 27 | return self.mismatch 28 | else: 29 | return self.match 30 | 31 | 32 | def get_matrix(size_x, size_y, gap): 33 | matrix = [] 34 | for i in range(len(size_x) + 1): 35 | sub_matrix = [] 36 | for j in range(len(size_y) + 1): 37 | sub_matrix.append(0) 38 | matrix.append(sub_matrix) 39 | for j in range(1, len(size_y) + 1): 40 | matrix[0][j] = j * gap 41 | for i in range(1, len(size_x) + 1): 42 | matrix[i][0] = i * gap 43 | return matrix 44 | 45 | 46 | def get_matrix(size_x, size_y, gap): 47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 50 | return matrix 51 | 52 | 53 | def get_traceback_matrix(size_x, size_y): 54 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 55 | matrix[0, 1:] = 1 56 | matrix[1:, 0] = 2 57 | matrix[0, 0] = 4 58 | return matrix 59 | 60 | 61 | def global_align(x, y, score): 62 | matrix = get_matrix(len(x), len(y), score.gap) 63 | trace_back = get_traceback_matrix(len(x), len(y)) 64 | for i in range(1, len(x) + 1): 65 | for j in range(1, len(y) + 1): 66 | left = matrix[i, j - 1] + score.gap 67 | up = matrix[i - 1, j] + score.gap 68 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 69 | matrix[i, j] = max(left, up, diag) 70 | if matrix[i, j] == left: 71 | trace_back[i, j] = 1 72 | elif matrix[i, j] == up: 73 | trace_back[i, j] = 2 74 | else: 75 | trace_back[i, j] = 3 76 | return matrix, trace_back 77 | 78 | 79 | def get_aligned_sequences(x, y, trace_back): 80 | x_seq = [] 81 | y_seq = [] 82 | i = len(x) 83 | j = len(y) 84 | mapper_y_to_x = [] 85 | while i > 0 or j > 0: 86 | if trace_back[i, j] == 3: 87 | x_seq.append(x[i - 1]) 88 | y_seq.append(y[j - 1]) 89 | i = i - 1 90 | j = j - 1 91 | mapper_y_to_x.append((j, i)) 92 | elif trace_back[i][j] == 1: 93 | x_seq.append('-') 94 | y_seq.append(y[j - 1]) 95 | j = j - 1 96 | mapper_y_to_x.append((j, -1)) 97 | elif trace_back[i][j] == 2: 98 | x_seq.append(x[i - 1]) 99 | y_seq.append('-') 100 | i = i - 1 101 | elif trace_back[i][j] == 4: 102 | break 103 | mapper_y_to_x.reverse() 104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64) 105 | 106 | 107 | def get_mapper(x: str, y: str, tokenizer, max_len=77): 108 | x_seq = tokenizer.encode(x) 109 | y_seq = tokenizer.encode(y) 110 | score = ScoreParams(0, 1, -1) 111 | matrix, trace_back = global_align(x_seq, y_seq, score) 112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 113 | alphas = torch.ones(max_len) 114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float() 115 | mapper = torch.zeros(max_len, dtype=torch.int64) 116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1] 117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq)) 118 | return mapper, alphas 119 | 120 | 121 | def get_refinement_mapper(prompts, tokenizer, max_len=77): 122 | x_seq = prompts[0] 123 | mappers, alphas = [], [] 124 | for i in range(1, len(prompts)): 125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len) 126 | mappers.append(mapper) 127 | alphas.append(alpha) 128 | return torch.stack(mappers), torch.stack(alphas) 129 | 130 | 131 | def get_word_inds(text: str, word_place: int, tokenizer): 132 | split_text = 
text.split(" ") 133 | if type(word_place) is str: 134 | word_place = [i for i, word in enumerate(split_text) if word_place == word] 135 | elif type(word_place) is int: 136 | word_place = [word_place] 137 | out = [] 138 | if len(word_place) > 0: 139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1] 140 | cur_len, ptr = 0, 0 141 | 142 | for i in range(len(words_encode)): 143 | cur_len += len(words_encode[i]) 144 | if ptr in word_place: 145 | out.append(i + 1) 146 | if cur_len >= len(split_text[ptr]): 147 | ptr += 1 148 | cur_len = 0 149 | return np.array(out) 150 | 151 | 152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77): 153 | words_x = x.split(' ') 154 | words_y = y.split(' ') 155 | if len(words_x) != len(words_y): 156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length" 157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.") 158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]] 159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace] 160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace] 161 | mapper = np.zeros((max_len, max_len)) 162 | i = j = 0 163 | cur_inds = 0 164 | while i < max_len and j < max_len: 165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i: 166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds] 167 | if len(inds_source_) == len(inds_target_): 168 | mapper[inds_source_, inds_target_] = 1 169 | else: 170 | ratio = 1 / len(inds_target_) 171 | for i_t in inds_target_: 172 | mapper[inds_source_, i_t] = ratio 173 | cur_inds += 1 174 | i += len(inds_source_) 175 | j += len(inds_target_) 176 | elif cur_inds < len(inds_source): 177 | mapper[i, j] = 1 178 | i += 1 179 | j += 1 180 | else: 181 | mapper[j, j] = 1 182 | i += 1 183 | j += 1 184 | 185 | return torch.from_numpy(mapper).float() 186 | 187 | 188 | def get_replacement_mapper(prompts, tokenizer, max_len=77): 189 | x_seq = prompts[0] 190 | mappers = [] 191 | for i in range(1, len(prompts)): 192 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len) 193 | mappers.append(mapper) 194 | return torch.stack(mappers) 195 | -------------------------------------------------------------------------------- /FADING_util/util.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import os 7 | from torchvision import transforms 8 | import json 9 | 10 | def get_instance_prompt(dreambooth_dir): 11 | json_path = os.path.join(dreambooth_dir, "model_config.json") 12 | with open(json_path, 'r') as file: 13 | model_config = json.load(file) 14 | return model_config['instance_prompt'] 15 | 16 | # get_instance_prompt('saved_model/de-id/FFHQ_512_00006_100') 17 | #%% 18 | def mydisplay(img): 19 | plt.axis('off') 20 | plt.imshow(img) 21 | plt.show() 22 | 23 | def load_image(p, arr=False, resize=None): 24 | ''' 25 | Function to load images from a defined path 26 | ''' 27 | ret = Image.open(p).convert('RGB') 28 | if resize is not None: 29 | ret = ret.resize((resize[0],resize[1])) 30 | if not arr: 31 | return ret 32 | return np.array(ret) 33 | 34 | 35 | 36 | def numpy_to_pil(images): 37 | """ 38 | Convert a numpy image or a batch of images to a PIL image. 
39 | """ 40 | if images.ndim == 3: 41 | images = images[None, ...] 42 | images = (images * 255).round().astype("uint8") 43 | if images.shape[-1] == 1: 44 | # special case for grayscale (single channel) images 45 | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images] 46 | else: 47 | pil_images = [Image.fromarray(image) for image in images] 48 | 49 | return pil_images 50 | 51 | 52 | def tensor_to_img(tensor,arr=False): 53 | tmp = tensor.clone().squeeze(0).cpu() 54 | tfpil = transforms.ToPILImage() 55 | tmp = tfpil(tmp) 56 | # tmp = (tmp+1)*0.5 57 | if arr: 58 | tmp = np.array(tmp) 59 | return tmp 60 | 61 | 62 | #%% 63 | def image_grid(imgs_, rows=None, cols=None, sort_file_filter=None, remove_filter=None, border=0, resize=None): 64 | if isinstance(imgs_, str) or (isinstance(imgs_, list) and isinstance(imgs_[0], str)): 65 | if isinstance(imgs_, str): 66 | # imgs 是一个dir 67 | files = os.listdir(imgs_) 68 | 69 | if remove_filter: 70 | files = remove_filter(files) 71 | 72 | if sort_file_filter: 73 | files = sorted(files, key=sort_file_filter) 74 | 75 | files = [os.path.join(imgs_, f) for f in files] 76 | else: 77 | # imgs 是一个dir的list 78 | files = imgs_ 79 | 80 | print(files) 81 | 82 | imgs = [] 83 | for f in files[:]: 84 | img = load_image(f,resize=resize) 85 | imgs.append(img) 86 | 87 | elif isinstance(imgs_, np.ndarray): 88 | # imgs 是一个ndarray 89 | imgs = [Image.fromarray(i) for i in imgs_] 90 | 91 | else: 92 | # imgs 是一个PIL的list 93 | imgs = imgs_[:] 94 | 95 | if not rows or not cols: 96 | rows = 1 97 | cols = len(imgs) 98 | 99 | assert len(imgs) == rows * cols 100 | 101 | w, h = imgs[-1].size 102 | grid = Image.new('RGB', size=(cols * w + (cols-1)*border, rows * h + (rows-1)*border), color='white') 103 | grid_w, grid_h = grid.size 104 | 105 | for i, img in enumerate(imgs): 106 | grid.paste(img, box=(i % cols * (w+border), i // cols * (h+border))) 107 | return grid 108 | 109 | #%% 110 | def sort_by_num(separator='-'): 111 | def sort_by_num_(x): 112 | return int(x.split(separator, 1)[0]) 113 | return sort_by_num_ 114 | def remove_filter(files): 115 | ret_files = [] 116 | for f in files: 117 | if f[0]!= '.' 
and f.split('.')[0] not in ['1','8','17']: 118 | ret_files.append(f) 119 | return ret_files 120 | def tmp(x): 121 | num = int(x[:5]) 122 | return num 123 | 124 | 125 | def get_person_placeholder(age=None, predicted_gender=None): 126 | if predicted_gender is not None: 127 | if age and age <= 15: 128 | person_placeholder = ['boy', 'girl'][predicted_gender == 'Female' or predicted_gender == 1] 129 | else: # init age > 15 或者根本没有init age 130 | person_placeholder = ['man','woman'][predicted_gender == 'Female' or predicted_gender == 1] 131 | else: 132 | if age and age <= 15: 133 | person_placeholder = "child" 134 | else: 135 | person_placeholder = "person" 136 | return person_placeholder -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FADING 2 | 3 | ## Face Aging via Diffusion-based Editing 4 | [![arXiv](https://img.shields.io/badge/arXiv-2309.11321-b31b1b)](https://arxiv.org/abs/2309.11321) 5 | [![Project Page](https://img.shields.io/badge/Project-Website-orange)](https://proceedings.bmvc2023.org/595/) 6 | [![Static Badge](https://img.shields.io/badge/supplementary-blue)](https://bmvc2022.mpi-inf.mpg.de/BMVC2023/0595_supp.pdf) 7 | 8 | Official repo for BMVC 2023 paper: [Face Aging via Diffusion-based Editing](https://proceedings.bmvc2023.org/595/). 9 | 10 | For more visualization results, please check [supplementary materiels](https://bmvc2022.mpi-inf.mpg.de/BMVC2023/0595_supp.pdf). 11 | 12 |
13 | ![FADING demo](FADING_demo.png) 14 |
15 | 16 | > In this paper, we address the problem of face aging—generating past or future facial images by incorporating age-related changes to the given face. Previous aging methods rely solely on human facial image datasets and are thus constrained by their inherent scale and bias. This restricts their application to a limited generatable age range and the inability to handle large age gaps. We propose FADING, a novel approach to address Face Aging via DIffusion-based editiNG. We go beyond existing methods by leveraging the rich prior of large-scale language-image diffusion models. First, we specialize a pre-trained diffusion model for the task of face age editing by using an age-aware fine-tuning scheme. Next, we invert the input image to latent noise and obtain optimized null text embeddings. Finally, we perform text-guided local age editing via attention control. The quantitative and qualitative analyses demonstrate that our method outperforms existing approaches with respect to aging accuracy, attribute preservation, and aging quality. 17 | 18 | 19 | 20 | ## Dataset 21 | The FFHQ-Aging Dataset used for training FADING could be downloaded from https://github.com/royorel/FFHQ-Aging-Dataset 22 | 23 | ## Training (Specialization) 24 | 25 | ### Available pretrained weights 26 | We release weights of our specialized model at https://drive.google.com/file/d/1galwrcHq1HoZNfOI4jdJJqVs5ehB_dvO/view?usp=share_link 27 | 28 | ### Train a new model 29 | 30 | ```shell 31 | accelerate launch specialize_general.py \ 32 | --instance_data_dir 'specialization_data/training_images' \ 33 | --instance_age_path 'specialization_data/training_ages.npy' \ 34 | --output_dir \ 35 | --max_train_steps 150 36 | ``` 37 | Training images should be saved at `specialization_data/training_images`. The training set is described through `training_ages.npy` that contains the age of the training images. 38 | ```angular2html 39 | array([['00007.jpg', '1'], 40 | ['00004.jpg', '35'], 41 | ... 
42 | ['00009.jpg', '35']], dtype=' \ 50 | --age_init \ 51 | --gender \ 52 | --save_aged_dir \ 53 | --specialized_path \ 54 | --target_ages 10 20 40 60 80 55 | ``` 56 | -------------------------------------------------------------------------------- /age_editing.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from diffusers import StableDiffusionPipeline, DDIMScheduler 4 | 5 | from FADING_util import util 6 | from p2p import * 7 | from null_inversion import * 8 | 9 | #%% 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument('--image_path', required=True, help='Path to input image') 13 | parser.add_argument('--age_init', required=True, type=int, help='Specify the initial age') 14 | parser.add_argument('--gender', required=True, choices=["female", "male"], help="Specify the gender ('female' or 'male')") 15 | parser.add_argument('--specialized_path', required=True, help='Path to specialized diffusion model') 16 | parser.add_argument('--save_aged_dir', default='./outputs', help='Path to save outputs') 17 | parser.add_argument('--target_ages', nargs='+', default=[10, 20, 40, 60, 80], type=int, help='Target age values') 18 | 19 | args = parser.parse_args() 20 | 21 | #%% 22 | image_path = args.image_path 23 | age_init = args.age_init 24 | gender = args.gender 25 | save_aged_dir = args.save_aged_dir 26 | specialized_path = args.specialized_path 27 | target_ages = args.target_ages 28 | 29 | if not os.path.exists(save_aged_dir): 30 | os.makedirs(save_aged_dir) 31 | 32 | gt_gender = int(gender == 'female') 33 | person_placeholder = util.get_person_placeholder(age_init, gt_gender) 34 | inversion_prompt = f"photo of {age_init} year old {person_placeholder}" 35 | 36 | input_img_name = image_path.split('/')[-1].split('.')[-2] 37 | 38 | #%% load specialized diffusion model 39 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", 40 | clip_sample=False, set_alpha_to_one=False, 41 | steps_offset=1) 42 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 43 | g_cuda = torch.Generator(device=device) 44 | 45 | ldm_stable = StableDiffusionPipeline.from_pretrained(specialized_path, 46 | scheduler=scheduler, 47 | safety_checker=None).to(device) 48 | tokenizer = ldm_stable.tokenizer 49 | 50 | #%% null text inversion 51 | null_inversion = NullInversion(ldm_stable) 52 | (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, inversion_prompt, 53 | offsets=(0,0,0,0), verbose=True) 54 | #%% age editing 55 | for age_new in target_ages: 56 | print(f'Age editing with target age {age_new}...') 57 | new_person_placeholder = util.get_person_placeholder(age_new, gt_gender) 58 | new_prompt = inversion_prompt.replace(person_placeholder, new_person_placeholder) 59 | 60 | new_prompt = new_prompt.replace(str(age_init),str(age_new)) 61 | blend_word = (((str(age_init),person_placeholder,), (str(age_new),new_person_placeholder,))) 62 | is_replace_controller = True 63 | 64 | prompts = [inversion_prompt, new_prompt] 65 | 66 | cross_replace_steps = {'default_': .8,} 67 | self_replace_steps = .5 68 | 69 | eq_params = {"words": (str(age_new)), "values": (1,)} 70 | 71 | controller = make_controller(prompts, is_replace_controller, cross_replace_steps, self_replace_steps, 72 | tokenizer, blend_word, eq_params) 73 | 74 | images, _ = p2p_text2image(ldm_stable, prompts, controller, generator=g_cuda.manual_seed(0), 75 | latent=x_t, 
uncond_embeddings=uncond_embeddings) 76 | 77 | new_img = images[-1] 78 | new_img_pil = Image.fromarray(new_img) 79 | new_img_pil.save(os.path.join(save_aged_dir,f'{input_img_name}_{age_new}.png')) 80 | -------------------------------------------------------------------------------- /null_inversion.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from tqdm import tqdm 3 | import torch 4 | # from diffusers import StableDiffusionPipeline, DDIMScheduler 5 | import torch.nn.functional as nnf 6 | import numpy as np 7 | from torch.optim.adam import Adam 8 | from PIL import Image 9 | 10 | import FADING_util.ptp_utils as ptp_utils 11 | 12 | 13 | LOW_RESOURCE = False 14 | NUM_DDIM_STEPS = 50 15 | GUIDANCE_SCALE = 7.5 16 | MAX_NUM_WORDS = 77 17 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 18 | 19 | 20 | 21 | def load_512(image_path, left=0, right=0, top=0, bottom=0): 22 | if type(image_path) is str: 23 | image = np.array(Image.open(image_path))[:, :, :3] 24 | else: 25 | image = image_path 26 | h, w, c = image.shape 27 | left = min(left, w - 1) 28 | right = min(right, w - left - 1) 29 | top = min(top, h - left - 1) 30 | bottom = min(bottom, h - top - 1) 31 | image = image[top:h - bottom, left:w - right] 32 | h, w, c = image.shape 33 | if h < w: 34 | offset = (w - h) // 2 35 | image = image[:, offset:offset + h] 36 | elif w < h: 37 | offset = (h - w) // 2 38 | image = image[offset:offset + w] 39 | image = np.array(Image.fromarray(image).resize((512, 512))) 40 | return image 41 | 42 | 43 | class NullInversion: 44 | 45 | def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 46 | sample: Union[torch.FloatTensor, np.ndarray]): 47 | prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps 48 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] 49 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[ 50 | prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod 51 | beta_prod_t = 1 - alpha_prod_t 52 | pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 53 | pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output 54 | prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction 55 | return prev_sample 56 | 57 | def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int, 58 | sample: Union[torch.FloatTensor, np.ndarray]): 59 | timestep, next_timestep = min( 60 | timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep 61 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod 62 | alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep] 63 | beta_prod_t = 1 - alpha_prod_t 64 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5 65 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output 66 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction 67 | return next_sample 68 | 69 | def get_noise_pred_single(self, latents, t, context): 70 | noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"] 71 | return noise_pred 72 | 73 | def get_noise_pred(self, latents, t, is_forward=True, context=None): 74 | latents_input = torch.cat([latents] * 2) 75 | if context is None: 76 | 
context = self.context 77 | guidance_scale = 1 if is_forward else GUIDANCE_SCALE 78 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"] 79 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2) 80 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond) 81 | if is_forward: 82 | latents = self.next_step(noise_pred, t, latents) 83 | else: 84 | latents = self.prev_step(noise_pred, t, latents) 85 | return latents 86 | 87 | @torch.no_grad() 88 | def latent2image(self, latents, return_type='np'): 89 | latents = 1 / 0.18215 * latents.detach() 90 | image = self.model.vae.decode(latents)['sample'] 91 | if return_type == 'np': 92 | image = (image / 2 + 0.5).clamp(0, 1) 93 | image = image.cpu().permute(0, 2, 3, 1).numpy()[0] 94 | image = (image * 255).astype(np.uint8) 95 | return image 96 | 97 | @torch.no_grad() 98 | def image2latent(self, image): 99 | with torch.no_grad(): 100 | if type(image) is Image: 101 | image = np.array(image) 102 | if type(image) is torch.Tensor and image.dim() == 4: 103 | latents = image 104 | else: 105 | image = torch.from_numpy(image).float() / 127.5 - 1 106 | image = image.permute(2, 0, 1).unsqueeze(0).to(device) 107 | latents = self.model.vae.encode(image)['latent_dist'].mean 108 | latents = latents * 0.18215 109 | return latents 110 | 111 | @torch.no_grad() 112 | def init_prompt(self, prompt: str): 113 | uncond_input = self.tokenizer( 114 | [""], padding="max_length", max_length=self.tokenizer.model_max_length, 115 | return_tensors="pt" 116 | ) 117 | uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0] 118 | text_input = self.tokenizer( 119 | [prompt], 120 | padding="max_length", 121 | max_length=self.tokenizer.model_max_length, 122 | truncation=True, 123 | return_tensors="pt", 124 | ) 125 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0] 126 | self.context = torch.cat([uncond_embeddings, text_embeddings]) 127 | self.prompt = prompt 128 | 129 | @torch.no_grad() 130 | def ddim_loop(self, latent): 131 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 132 | all_latent = [latent] 133 | latent = latent.clone().detach() 134 | for i in range(NUM_DDIM_STEPS): 135 | t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1] 136 | noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings) 137 | latent = self.next_step(noise_pred, t, latent) 138 | all_latent.append(latent) 139 | return all_latent 140 | 141 | @property 142 | def scheduler(self): 143 | return self.model.scheduler 144 | 145 | @torch.no_grad() 146 | def ddim_inversion(self, image): 147 | latent = self.image2latent(image) 148 | image_rec = self.latent2image(latent) 149 | ddim_latents = self.ddim_loop(latent) 150 | return image_rec, ddim_latents 151 | 152 | def null_optimization(self, latents, num_inner_steps, epsilon): 153 | uncond_embeddings, cond_embeddings = self.context.chunk(2) 154 | uncond_embeddings_list = [] 155 | latent_cur = latents[-1] 156 | bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS) 157 | for i in range(NUM_DDIM_STEPS): 158 | uncond_embeddings = uncond_embeddings.clone().detach() 159 | uncond_embeddings.requires_grad = True 160 | optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. 
- i / 100.)) 161 | latent_prev = latents[len(latents) - i - 2] 162 | t = self.model.scheduler.timesteps[i] 163 | with torch.no_grad(): 164 | noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings) 165 | for j in range(num_inner_steps): 166 | noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings) 167 | noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond) 168 | latents_prev_rec = self.prev_step(noise_pred, t, latent_cur) 169 | loss = nnf.mse_loss(latents_prev_rec, latent_prev) 170 | optimizer.zero_grad() 171 | loss.backward() 172 | optimizer.step() 173 | loss_item = loss.item() 174 | bar.update() 175 | if loss_item < epsilon + i * 2e-5: 176 | break 177 | for j in range(j + 1, num_inner_steps): 178 | bar.update() 179 | uncond_embeddings_list.append(uncond_embeddings[:1].detach()) 180 | with torch.no_grad(): 181 | context = torch.cat([uncond_embeddings, cond_embeddings]) 182 | latent_cur = self.get_noise_pred(latent_cur, t, False, context) 183 | bar.close() 184 | return uncond_embeddings_list 185 | 186 | def invert(self, image_path: str, prompt: str, offsets=(0, 0, 0, 0), num_inner_steps=10, early_stop_epsilon=1e-5, 187 | verbose=False): 188 | self.init_prompt(prompt) 189 | ptp_utils.register_attention_control(self.model, None) 190 | image_gt = load_512(image_path, *offsets) 191 | if verbose: 192 | print("DDIM inversion...") 193 | image_rec, ddim_latents = self.ddim_inversion(image_gt) 194 | if verbose: 195 | print("Null-text optimization...") 196 | uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon) 197 | return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings 198 | 199 | def __init__(self, model): 200 | # scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, 201 | # set_alpha_to_one=False) 202 | self.model = model 203 | self.tokenizer = self.model.tokenizer 204 | self.model.scheduler.set_timesteps(NUM_DDIM_STEPS) 205 | self.prompt = None 206 | self.context = None -------------------------------------------------------------------------------- /p2p.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union, Tuple, List, Dict 2 | from tqdm import tqdm 3 | import torch 4 | import torch.nn.functional as nnf 5 | import numpy as np 6 | from PIL import Image 7 | import abc 8 | 9 | import FADING_util.ptp_utils as ptp_utils 10 | import FADING_util.seq_aligner as seq_aligner 11 | 12 | LOW_RESOURCE = False 13 | NUM_DDIM_STEPS = 50 14 | GUIDANCE_SCALE = 7.5 15 | MAX_NUM_WORDS = 77 16 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu') 17 | #%% 18 | #% Prompt-to-Prompt code 19 | class LocalBlend: 20 | 21 | def get_mask(self, x_t, maps, alpha, use_pool): 22 | k = 1 23 | maps = (maps * alpha).sum(-1).mean(1) 24 | if use_pool: 25 | maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k)) 26 | mask = nnf.interpolate(maps, size=(x_t.shape[2:])) 27 | mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0] 28 | mask = mask.gt(self.th[1 - int(use_pool)]) 29 | mask = mask[:1] + mask 30 | return mask 31 | 32 | def __call__(self, x_t, attention_store): 33 | self.counter += 1 34 | if self.counter > self.start_blend: 35 | 36 | maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3] 37 | maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in 
maps] 38 | maps = torch.cat(maps, dim=1) 39 | mask = self.get_mask(x_t, maps, self.alpha_layers, True) 40 | if self.substruct_layers is not None: 41 | maps_sub = ~self.get_mask(x_t, maps, self.substruct_layers, False) 42 | mask = mask * maps_sub 43 | mask = mask.float() 44 | x_t = x_t[:1] + mask * (x_t - x_t[:1]) 45 | return x_t 46 | 47 | def __init__(self, prompts: List[str], words: [List[List[str]]], tokenizer, substruct_words=None, start_blend=0.2, 48 | th=(.3, .3)): 49 | alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) 50 | for i, (prompt, words_) in enumerate(zip(prompts, words)): 51 | if type(words_) is str: 52 | words_ = [words_] 53 | for word in words_: 54 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer) 55 | alpha_layers[i, :, :, :, :, ind] = 1 56 | 57 | if substruct_words is not None: 58 | substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS) 59 | for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)): 60 | if type(words_) is str: 61 | words_ = [words_] 62 | for word in words_: 63 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer) 64 | substruct_layers[i, :, :, :, :, ind] = 1 65 | self.substruct_layers = substruct_layers.to(device) 66 | else: 67 | self.substruct_layers = None 68 | self.alpha_layers = alpha_layers.to(device) 69 | self.start_blend = int(start_blend * NUM_DDIM_STEPS) 70 | self.counter = 0 71 | self.th = th 72 | 73 | 74 | class EmptyControl: 75 | 76 | def step_callback(self, x_t): 77 | return x_t 78 | 79 | def between_steps(self): 80 | return 81 | 82 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 83 | return attn 84 | 85 | 86 | class AttentionControl(abc.ABC): 87 | 88 | def step_callback(self, x_t): 89 | return x_t 90 | 91 | def between_steps(self): 92 | return 93 | 94 | @property 95 | def num_uncond_att_layers(self): 96 | return self.num_att_layers if LOW_RESOURCE else 0 97 | 98 | @abc.abstractmethod 99 | def forward(self, attn, is_cross: bool, place_in_unet: str): 100 | raise NotImplementedError 101 | 102 | def __call__(self, attn, is_cross: bool, place_in_unet: str): 103 | if self.cur_att_layer >= self.num_uncond_att_layers: 104 | if LOW_RESOURCE: 105 | attn = self.forward(attn, is_cross, place_in_unet) 106 | else: 107 | h = attn.shape[0] 108 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) 109 | self.cur_att_layer += 1 110 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: 111 | self.cur_att_layer = 0 112 | self.cur_step += 1 113 | self.between_steps() 114 | return attn 115 | 116 | def reset(self): 117 | self.cur_step = 0 118 | self.cur_att_layer = 0 119 | 120 | def __init__(self): 121 | self.cur_step = 0 122 | self.num_att_layers = -1 123 | self.cur_att_layer = 0 124 | 125 | 126 | class SpatialReplace(EmptyControl): 127 | 128 | def step_callback(self, x_t): 129 | if self.cur_step < self.stop_inject: 130 | b = x_t.shape[0] 131 | x_t = x_t[:1].expand(b, *x_t.shape[1:]) 132 | return x_t 133 | 134 | def __init__(self, stop_inject: float): 135 | super(SpatialReplace, self).__init__() 136 | self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS) 137 | 138 | 139 | class AttentionStore(AttentionControl): 140 | 141 | @staticmethod 142 | def get_empty_store(): 143 | return {"down_cross": [], "mid_cross": [], "up_cross": [], 144 | "down_self": [], "mid_self": [], "up_self": []} 145 | 146 | def forward(self, attn, is_cross: bool, place_in_unet: str): 147 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" 148 | if attn.shape[1] <= 32 
** 2: # avoid memory overhead 149 | self.step_store[key].append(attn) 150 | return attn 151 | 152 | def between_steps(self): 153 | if len(self.attention_store) == 0: 154 | self.attention_store = self.step_store 155 | else: 156 | for key in self.attention_store: 157 | for i in range(len(self.attention_store[key])): 158 | self.attention_store[key][i] += self.step_store[key][i] 159 | self.step_store = self.get_empty_store() 160 | 161 | def get_average_attention(self): 162 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in 163 | self.attention_store} 164 | return average_attention 165 | 166 | def reset(self): 167 | super(AttentionStore, self).reset() 168 | self.step_store = self.get_empty_store() 169 | self.attention_store = {} 170 | 171 | def __init__(self): 172 | super(AttentionStore, self).__init__() 173 | self.step_store = self.get_empty_store() 174 | self.attention_store = {} 175 | 176 | 177 | class AttentionControlEdit(AttentionStore, abc.ABC): 178 | 179 | def step_callback(self, x_t): 180 | if self.local_blend is not None: 181 | x_t = self.local_blend(x_t, self.attention_store) 182 | return x_t 183 | 184 | def replace_self_attention(self, attn_base, att_replace, place_in_unet): 185 | if att_replace.shape[2] <= 32 ** 2: 186 | attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) 187 | return attn_base 188 | else: 189 | return att_replace 190 | 191 | @abc.abstractmethod 192 | def replace_cross_attention(self, attn_base, att_replace): 193 | raise NotImplementedError 194 | 195 | def forward(self, attn, is_cross: bool, place_in_unet: str): 196 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) 197 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): 198 | h = attn.shape[0] // (self.batch_size) 199 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) 200 | attn_base, attn_repalce = attn[0], attn[1:] 201 | if is_cross: 202 | alpha_words = self.cross_replace_alpha[self.cur_step] 203 | attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + ( 204 | 1 - alpha_words) * attn_repalce 205 | attn[1:] = attn_repalce_new 206 | else: 207 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce, place_in_unet) 208 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) 209 | return attn 210 | 211 | def __init__(self, prompts, num_steps: int, 212 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], 213 | self_replace_steps: Union[float, Tuple[float, float]], 214 | tokenizer, 215 | local_blend: Optional[LocalBlend]): 216 | super(AttentionControlEdit, self).__init__() 217 | self.batch_size = len(prompts) 218 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps, 219 | tokenizer).to(device) 220 | if type(self_replace_steps) is float: 221 | self_replace_steps = 0, self_replace_steps 222 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) 223 | self.local_blend = local_blend 224 | 225 | 226 | class AttentionReplace(AttentionControlEdit): 227 | 228 | def replace_cross_attention(self, attn_base, att_replace): 229 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) 230 | 231 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 232 | tokenizer, 233 | local_blend: Optional[LocalBlend] = None): 234 | super(AttentionReplace, 
self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, tokenizer, local_blend) 235 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) 236 | 237 | 238 | class AttentionRefine(AttentionControlEdit): 239 | 240 | def replace_cross_attention(self, attn_base, att_replace): 241 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3) 242 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) 243 | # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True) 244 | return attn_replace 245 | 246 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 247 | tokenizer, 248 | local_blend: Optional[LocalBlend] = None): 249 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, tokenizer, local_blend) 250 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) 251 | self.mapper, alphas = self.mapper.to(device), alphas.to(device) 252 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) 253 | 254 | 255 | class AttentionReweight(AttentionControlEdit): 256 | 257 | def replace_cross_attention(self, attn_base, att_replace): 258 | if self.prev_controller is not None: 259 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) 260 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] 261 | # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True) 262 | return attn_replace 263 | 264 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, 265 | tokenizer, equalizer, 266 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): 267 | super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, 268 | tokenizer, 269 | local_blend) 270 | self.equalizer = equalizer.to(device) 271 | self.prev_controller = controller 272 | 273 | 274 | def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],Tuple[float, ...]], 275 | tokenizer): 276 | if type(word_select) is int or type(word_select) is str: 277 | word_select = (word_select,) 278 | equalizer = torch.ones(1, 77) 279 | 280 | for word, val in zip(word_select, values): 281 | inds = ptp_utils.get_word_inds(text, word, tokenizer) 282 | equalizer[:, inds] = val 283 | return equalizer 284 | 285 | 286 | 287 | 288 | def make_controller(prompts: List[str], is_replace_controller: bool, 289 | cross_replace_steps: Dict[str, float], 290 | self_replace_steps: float, 291 | tokenizer, 292 | blend_words=None, equilizer_params=None) -> AttentionControlEdit: 293 | if blend_words is None: 294 | lb = None 295 | else: 296 | lb = LocalBlend(prompts, blend_words, tokenizer=tokenizer) 297 | if is_replace_controller: 298 | controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, 299 | self_replace_steps=self_replace_steps, 300 | tokenizer=tokenizer, 301 | local_blend=lb) 302 | else: 303 | controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps, 304 | self_replace_steps=self_replace_steps, 305 | tokenizer=tokenizer, 306 | local_blend=lb) 307 | if equilizer_params is not None: 308 | eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"], tokenizer=tokenizer) 309 | controller = AttentionReweight(prompts, NUM_DDIM_STEPS, 
cross_replace_steps=cross_replace_steps, 310 | self_replace_steps=self_replace_steps, 311 | tokenizer=tokenizer, 312 | equalizer=eq, local_blend=lb, 313 | controller=controller) 314 | return controller 315 | 316 | 317 | @torch.no_grad() 318 | def p2p_text2image( 319 | model, 320 | prompt: List[str], 321 | controller, 322 | num_inference_steps: int = 50, 323 | guidance_scale: Optional[float] = 7.5, 324 | generator: Optional[torch.Generator] = None, 325 | latent: Optional[torch.FloatTensor] = None, 326 | uncond_embeddings=None, 327 | start_time=50, 328 | return_type='image', 329 | height=512, width=512 330 | ): 331 | tokenizer = model.tokenizer 332 | batch_size = len(prompt) 333 | ptp_utils.register_attention_control(model, controller) 334 | 335 | text_input = tokenizer( 336 | prompt, 337 | padding="max_length", 338 | max_length=tokenizer.model_max_length, 339 | truncation=True, 340 | return_tensors="pt", 341 | ) 342 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0] 343 | max_length = text_input.input_ids.shape[-1] 344 | if uncond_embeddings is None: 345 | uncond_input = tokenizer( 346 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 347 | ) 348 | uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0] 349 | else: 350 | uncond_embeddings_ = None 351 | 352 | latent, latents = ptp_utils.init_latent(latent, model, height, width, generator, batch_size) 353 | model.scheduler.set_timesteps(num_inference_steps) 354 | 355 | for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])): 356 | if uncond_embeddings_ is None: 357 | context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings]) 358 | # context = torch.cat([uncond_embeddings[i], text_embeddings]) # 我改了 359 | else: 360 | context = torch.cat([uncond_embeddings_, text_embeddings]) 361 | latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False) 362 | 363 | if return_type == 'image': 364 | image = ptp_utils.latent2image(model.vae, latents) 365 | else: 366 | image = latents 367 | 368 | return image, latent 369 | 370 | 371 | 372 | # def run_and_display(my_ldm_stable, prompts, controller, latent=None, run_baseline=False, generator=None, uncond_embeddings=None, 373 | # verbose=True): 374 | # if run_baseline: 375 | # print("w.o. 
prompt-to-prompt") 376 | # images, latent = run_and_display(my_ldm_stable, prompts, EmptyControl(), latent=latent, run_baseline=False, 377 | # generator=generator) 378 | # print("with prompt-to-prompt") 379 | # images, x_t = p2p_text2image(my_ldm_stable, prompts, controller, latent=latent, 380 | # num_inference_steps=NUM_DDIM_STEPS, guidance_scale=GUIDANCE_SCALE, 381 | # generator=generator, uncond_embeddings=uncond_embeddings) 382 | # if verbose: 383 | # images = ptp_utils.view_images(images) 384 | # 385 | # return images, x_t 386 | 387 | 388 | 389 | 390 | 391 | 392 | 393 | 394 | def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int): 395 | out = [] 396 | attention_maps = attention_store.get_average_attention() 397 | num_pixels = res ** 2 398 | for location in from_where: 399 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: 400 | if item.shape[1] == num_pixels: 401 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] 402 | out.append(cross_maps) 403 | out = torch.cat(out, dim=0) 404 | out = out.sum(0) / out.shape[0] 405 | return out.cpu() 406 | 407 | def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0): 408 | tokens = tokenizer.encode(prompts[select]) 409 | decoder = tokenizer.decode 410 | attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select) 411 | images = [] 412 | for i in range(len(tokens)): 413 | image = attention_maps[:, :, i] 414 | image = 255 * image / image.max() 415 | image = image.unsqueeze(-1).expand(*image.shape, 3) 416 | image = image.numpy().astype(np.uint8) 417 | image = np.array(Image.fromarray(image).resize((256, 256))) 418 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i]))) 419 | images.append(image) 420 | image_ = ptp_utils.view_images(np.stack(images, axis=0)) 421 | return image_ 422 | 423 | 424 | def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], 425 | max_com=10, select: int = 0): 426 | attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape( 427 | (res ** 2, res ** 2)) 428 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) 429 | images = [] 430 | for i in range(max_com): 431 | image = vh[i].reshape(res, res) 432 | image = image - image.min() 433 | image = 255 * image / image.max() 434 | image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) 435 | image = Image.fromarray(image).resize((256, 256)) 436 | image = np.array(image) 437 | images.append(image) 438 | ptp_utils.view_images(np.concatenate(images, axis=1)) 439 | 440 | 441 | #% null inversion 442 | -------------------------------------------------------------------------------- /specialize.py: -------------------------------------------------------------------------------- 1 | # code adapted from https://github.com/huggingface/diffusers/blob/v0.10.0/examples/dreambooth/train_dreambooth.py 2 | 3 | import os 4 | import argparse 5 | import json 6 | import hashlib 7 | import itertools 8 | import logging 9 | import math 10 | import numpy as np 11 | from PIL import Image 12 | 13 | try: 14 | from torchvision.transforms import InterpolationMode 15 | BICUBIC = InterpolationMode.BICUBIC 16 | except ImportError: 17 | BICUBIC = Image.BICUBIC 18 | import warnings 19 | from pathlib import Path 20 | from typing import 
List, Optional, Tuple, Union 21 | 22 | import torch 23 | import torch.nn.functional as F 24 | import torch.utils.checkpoint 25 | from torch.utils.data import Dataset 26 | 27 | import diffusers 28 | import transformers 29 | from accelerate import Accelerator 30 | from accelerate.logging import get_logger 31 | from accelerate.utils import set_seed 32 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel 33 | from diffusers.optimization import get_scheduler 34 | # from diffusers.utils import check_min_version 35 | from diffusers.utils.import_utils import is_xformers_available 36 | from huggingface_hub import HfFolder, Repository, whoami 37 | 38 | from torchvision import transforms 39 | from tqdm.auto import tqdm 40 | from transformers import AutoTokenizer, PretrainedConfig 41 | 42 | #%% 43 | 44 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 45 | # check_min_version("0.10.0.dev0") 46 | 47 | logger = get_logger(__name__) 48 | 49 | 50 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): 51 | text_encoder_config = PretrainedConfig.from_pretrained( 52 | pretrained_model_name_or_path, 53 | subfolder="text_encoder", 54 | revision=revision, 55 | ) 56 | model_class = text_encoder_config.architectures[0] 57 | 58 | if model_class == "CLIPTextModel": 59 | from transformers import CLIPTextModel 60 | 61 | return CLIPTextModel 62 | elif model_class == "RobertaSeriesModelWithTransformation": 63 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation 64 | 65 | return RobertaSeriesModelWithTransformation 66 | else: 67 | raise ValueError(f"{model_class} is not supported.") 68 | 69 | 70 | def parse_args(input_args=None): 71 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 72 | parser.add_argument( 73 | "--pretrained_model_name_or_path", 74 | type=str, 75 | default="runwayml/stable-diffusion-v1-5", 76 | required=False, 77 | help="Path to pretrained model or model identifier from huggingface.co/models.", 78 | ) 79 | parser.add_argument( 80 | "--instance_data_dir", 81 | type=str, 82 | default=None, 83 | required=True, 84 | help="A folder containing the training data of training images.", 85 | ) 86 | parser.add_argument( 87 | "--instance_age_path", 88 | type=str, 89 | default=None, 90 | required=True, 91 | help="A numpy array that contains the initial ages of training images.", 92 | ) 93 | parser.add_argument( 94 | "--instance_prompt", 95 | type=str, 96 | default="photo of a person", 97 | required=False, 98 | help="The prompt with identifier specifying the instance", 99 | ) 100 | parser.add_argument( 101 | "--output_dir", 102 | type=str, 103 | default=None, 104 | required=True, 105 | help="The output directory where the model predictions and checkpoints will be written.", 106 | ) 107 | parser.add_argument( 108 | "--finetune_mode", 109 | type=str, 110 | default='finetune_double_prompt', 111 | required=False, 112 | help="Specialization mode, 'finetune_double_prompt'|'finetune_single_prompt'", 113 | ) 114 | 115 | 116 | 117 | ##### 118 | parser.add_argument( 119 | "--revision", 120 | type=str, 121 | default=None, 122 | required=False, 123 | help=( 124 | "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" 125 | " float32 precision." 
126 | ), 127 | ) 128 | parser.add_argument( 129 | "--tokenizer_name", 130 | type=str, 131 | default=None, 132 | help="Pretrained tokenizer name or path if not the same as model_name", 133 | ) 134 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 135 | parser.add_argument( 136 | "--resolution", 137 | type=int, 138 | default=512, 139 | help=( 140 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 141 | " resolution" 142 | ), 143 | ) 144 | parser.add_argument( 145 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 146 | ) 147 | parser.add_argument( 148 | "--train_text_encoder", 149 | action="store_true", 150 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", 151 | ) 152 | parser.add_argument( 153 | "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." 154 | ) 155 | parser.add_argument( 156 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." 157 | ) 158 | parser.add_argument("--num_train_epochs", type=int, default=1) 159 | parser.add_argument( 160 | "--max_train_steps", 161 | type=int, 162 | default=150, 163 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 164 | ) 165 | 166 | parser.add_argument( 167 | "--checkpointing_steps", 168 | type=int, 169 | default=500, 170 | help=( 171 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" 172 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" 173 | " training using `--resume_from_checkpoint`." 174 | ), 175 | ) 176 | parser.add_argument( 177 | "--resume_from_checkpoint", 178 | type=str, 179 | default=None, 180 | help=( 181 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 182 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 183 | ), 184 | ) 185 | parser.add_argument( 186 | "--gradient_accumulation_steps", 187 | type=int, 188 | default=1, 189 | help="Number of updates steps to accumulate before performing a backward/update pass.", 190 | ) 191 | parser.add_argument( 192 | "--gradient_checkpointing", 193 | action="store_true", 194 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 195 | ) 196 | parser.add_argument( 197 | "--learning_rate", 198 | type=float, 199 | default=5e-6, 200 | help="Initial learning rate (after the potential warmup period) to use.", 201 | ) 202 | parser.add_argument( 203 | "--scale_lr", 204 | action="store_true", 205 | default=False, 206 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 207 | ) 208 | parser.add_argument( 209 | "--lr_scheduler", 210 | type=str, 211 | default="constant", 212 | help=( 213 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 214 | ' "constant", "constant_with_warmup"]' 215 | ), 216 | ) 217 | parser.add_argument( 218 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
219 | ) 220 | parser.add_argument( 221 | "--lr_num_cycles", 222 | type=int, 223 | default=1, 224 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.", 225 | ) 226 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") 227 | parser.add_argument( 228 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 229 | ) 230 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 231 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 232 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 233 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 234 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 235 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 236 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 237 | parser.add_argument( 238 | "--hub_model_id", 239 | type=str, 240 | default=None, 241 | help="The name of the repository to keep in sync with the local `output_dir`.", 242 | ) 243 | parser.add_argument( 244 | "--logging_dir", 245 | type=str, 246 | default="logs", 247 | help=( 248 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 249 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 250 | ), 251 | ) 252 | parser.add_argument( 253 | "--allow_tf32", 254 | action="store_true", 255 | help=( 256 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 257 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 258 | ), 259 | ) 260 | parser.add_argument( 261 | "--report_to", 262 | type=str, 263 | default="tensorboard", 264 | help=( 265 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 266 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 267 | ), 268 | ) 269 | parser.add_argument( 270 | "--mixed_precision", 271 | type=str, 272 | default=None, 273 | choices=["no", "fp16", "bf16"], 274 | help=( 275 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 276 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 277 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 278 | ), 279 | ) 280 | parser.add_argument( 281 | "--prior_generation_precision", 282 | type=str, 283 | default=None, 284 | choices=["no", "fp32", "fp16", "bf16"], 285 | help=( 286 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 287 | " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32." 288 | ), 289 | ) 290 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 291 | parser.add_argument( 292 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 
293 | ) 294 | 295 | if input_args is not None: 296 | args = parser.parse_args(input_args) 297 | else: 298 | args = parser.parse_args() 299 | 300 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 301 | if env_local_rank != -1 and env_local_rank != args.local_rank: 302 | args.local_rank = env_local_rank 303 | 304 | return args 305 | 306 | 307 | 308 | #%% 309 | class DreamBoothDatasetAge(Dataset): 310 | """ 311 | return instance (img, token of 'photo of a xx year old person') 312 | and (img, token of 'photo of a person') 313 | """ 314 | def __init__( 315 | self, 316 | instance_data_root, 317 | instance_age_path, # numpy array that stores the age of training images 318 | tokenizer, 319 | instance_prompt="photo of a person", 320 | size=512, 321 | center_crop=False, 322 | ): 323 | 324 | self.size = size 325 | self.center_crop = center_crop 326 | self.tokenizer = tokenizer 327 | 328 | self.instance_data_root = Path(instance_data_root) 329 | if not self.instance_data_root.exists(): 330 | raise ValueError("Instance images root doesn't exists.") 331 | 332 | if self.instance_data_root.is_file(): 333 | self.instance_images_path = [Path(instance_data_root)] 334 | else: 335 | self.instance_images_path = [os.path.join(instance_data_root,filename) for filename in os.listdir(instance_data_root)] 336 | 337 | 338 | self.num_instance_images = len(self.instance_images_path) 339 | self.instance_prompt = instance_prompt 340 | self._length = self.num_instance_images 341 | 342 | self.age_labels = np.load(instance_age_path) 343 | 344 | 345 | self.image_transforms = transforms.Compose( 346 | [ 347 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), 348 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), 349 | transforms.ToTensor(), 350 | transforms.Normalize([0.5], [0.5]), 351 | ] 352 | ) 353 | 354 | 355 | def __len__(self): 356 | return self._length 357 | 358 | def __getitem__(self, index): 359 | example = {} 360 | 361 | instance_image_path = self.instance_images_path[index % self.num_instance_images] 362 | instance_image = Image.open(instance_image_path) 363 | if not instance_image.mode == "RGB": 364 | instance_image = instance_image.convert("RGB") 365 | 366 | example["instance_images"] = self.image_transforms(instance_image) 367 | 368 | example["instance_prompt"] = self.instance_prompt 369 | example["instance_prompt_ids"] = self.tokenizer( # token(photo of a person) 370 | self.instance_prompt, 371 | truncation=True, 372 | padding="max_length", 373 | max_length=self.tokenizer.model_max_length, 374 | return_tensors="pt", 375 | ).input_ids 376 | 377 | instance_image_name = instance_image_path.split('/')[-1] 378 | instance_image_age = int(self.age_labels[self.age_labels[:, 0] == instance_image_name][0, 1]) 379 | 380 | example["instance_image_age"] = instance_image_age 381 | instance_age_prompt = self.instance_prompt.replace(" a ", f" a {instance_image_age} year old ") 382 | 383 | example["instance_age_prompt"] = instance_age_prompt 384 | example["instance_age_prompt_ids"] = self.tokenizer( # token(photo of a xx year old person) 385 | instance_age_prompt, 386 | truncation=True, 387 | padding="max_length", 388 | max_length=self.tokenizer.model_max_length, 389 | return_tensors="pt", 390 | ).input_ids 391 | 392 | example["blank_prompt_ids"] = self.tokenizer( # token("") 393 | "", 394 | truncation=True, 395 | padding="max_length", 396 | max_length=self.tokenizer.model_max_length, 397 | return_tensors="pt", 398 | ).input_ids 399 | 400 | 401 | return example 402 
| #% 403 | 404 | def collate_fn(examples, finetune_mode="finetune_double_prompt"): 405 | if len(examples)!=1: 406 | raise ValueError("batchsize can only be 1..") 407 | 408 | if finetune_mode == "finetune_double_prompt": 409 | input_ids = [example["instance_prompt_ids"] for example in examples] 410 | input_ids += [example["instance_age_prompt_ids"] for example in examples] 411 | 412 | elif finetune_mode == "finetune_single_prompt": 413 | input_ids = [example["instance_age_prompt_ids"] for example in examples] 414 | input_ids += [example["instance_age_prompt_ids"] for example in examples] 415 | 416 | # elif finetune_mode == "finetune_single_prompt_no_age": 417 | # input_ids = [example["instance_prompt_ids"] for example in examples] 418 | # input_ids += [example["instance_prompt_ids"] for example in examples] 419 | # 420 | # elif finetune_mode == "finetune_no_prompt": 421 | # input_ids = [example["blank_prompt_ids"] for example in examples] 422 | # input_ids += [example["blank_prompt_ids"] for example in examples] 423 | 424 | else: 425 | raise ValueError("invalid finetune_mode") 426 | 427 | pixel_values = [example["instance_images"] for example in examples] 428 | pixel_values += [example["instance_images"] for example in examples] 429 | 430 | pixel_values_ages = [example["instance_image_age"] for example in examples] 431 | 432 | pixel_values = torch.stack(pixel_values) 433 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 434 | 435 | input_ids = torch.cat(input_ids, dim=0) 436 | 437 | batch = { 438 | "input_ids": input_ids, 439 | "pixel_values": pixel_values, 440 | 441 | "pixel_values_ages": pixel_values_ages, 442 | } 443 | 444 | return batch 445 | 446 | 447 | 448 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): 449 | if token is None: 450 | token = HfFolder.get_token() 451 | if organization is None: 452 | username = whoami(token)["name"] 453 | return f"{username}/{model_id}" 454 | else: 455 | return f"{organization}/{model_id}" 456 | 457 | 458 | #%% 459 | def main(args): 460 | argparse_dict = vars(args) 461 | 462 | argparse_json = json.dumps(argparse_dict, indent=4) 463 | print("model configuration:\n", argparse_json) 464 | os.makedirs(args.output_dir, exist_ok=True) 465 | with open(os.path.join(args.output_dir, "model_config.json"), "w") as outfile: 466 | outfile.write(argparse_json) 467 | 468 | logging_dir = Path(args.output_dir, args.logging_dir) 469 | 470 | accelerator = Accelerator( 471 | gradient_accumulation_steps=args.gradient_accumulation_steps, 472 | mixed_precision=args.mixed_precision, 473 | log_with=args.report_to, 474 | logging_dir=logging_dir, 475 | ) 476 | 477 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate 478 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. 479 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. 480 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: 481 | raise ValueError( 482 | "Gradient accumulation is not supported when training the text encoder in distributed training. " 483 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." 484 | ) 485 | 486 | # Make one log on every process with the configuration for debugging. 
487 | logging.basicConfig( 488 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 489 | datefmt="%m/%d/%Y %H:%M:%S", 490 | level=logging.INFO, 491 | ) 492 | logger.info(accelerator.state, main_process_only=False) 493 | if accelerator.is_local_main_process: 494 | # datasets.utils.logging.set_verbosity_warning() 495 | transformers.utils.logging.set_verbosity_warning() 496 | diffusers.utils.logging.set_verbosity_info() 497 | else: 498 | # datasets.utils.logging.set_verbosity_error() 499 | transformers.utils.logging.set_verbosity_error() 500 | diffusers.utils.logging.set_verbosity_error() 501 | 502 | # If passed along, set the training seed now. 503 | if args.seed is not None: 504 | set_seed(args.seed) 505 | 506 | # Handle the repository creation 507 | if accelerator.is_main_process: 508 | if args.push_to_hub: 509 | if args.hub_model_id is None: 510 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) 511 | else: 512 | repo_name = args.hub_model_id 513 | repo = Repository(args.output_dir, clone_from=repo_name) 514 | 515 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: 516 | if "step_*" not in gitignore: 517 | gitignore.write("step_*\n") 518 | if "epoch_*" not in gitignore: 519 | gitignore.write("epoch_*\n") 520 | elif args.output_dir is not None: 521 | os.makedirs(args.output_dir, exist_ok=True) 522 | 523 | # Load the tokenizer 524 | if args.tokenizer_name: 525 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) 526 | elif args.pretrained_model_name_or_path: 527 | tokenizer = AutoTokenizer.from_pretrained( 528 | args.pretrained_model_name_or_path, 529 | subfolder="tokenizer", 530 | revision=args.revision, 531 | use_fast=False, 532 | ) 533 | 534 | # import correct text encoder class 535 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) 536 | 537 | # Load scheduler and models 538 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 539 | 540 | 541 | def pred_original_samples(samples: torch.FloatTensor, 542 | noise: torch.FloatTensor, 543 | timesteps: torch.IntTensor): 544 | sqrt_alpha_prod = noise_scheduler.alphas_cumprod[timesteps] ** 0.5 545 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 546 | while len(sqrt_alpha_prod.shape) < len(samples.shape): 547 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 548 | 549 | sqrt_one_minus_alpha_prod = (1 - noise_scheduler.alphas_cumprod[timesteps]) ** 0.5 550 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 551 | while len(sqrt_one_minus_alpha_prod.shape) < len(samples.shape): 552 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 553 | 554 | original_samples = (samples - sqrt_one_minus_alpha_prod * noise) / sqrt_alpha_prod 555 | return original_samples 556 | 557 | text_encoder = text_encoder_cls.from_pretrained( 558 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision 559 | ) 560 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) 561 | unet = UNet2DConditionModel.from_pretrained( 562 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision 563 | ) 564 | 565 | vae.requires_grad_(False) 566 | if not args.train_text_encoder: 567 | text_encoder.requires_grad_(False) 568 | 569 | if args.enable_xformers_memory_efficient_attention: 570 | if is_xformers_available(): 
571 | unet.enable_xformers_memory_efficient_attention() 572 | else: 573 | raise ValueError("xformers is not available. Make sure it is installed correctly") 574 | 575 | if args.gradient_checkpointing: 576 | unet.enable_gradient_checkpointing() 577 | if args.train_text_encoder: 578 | text_encoder.gradient_checkpointing_enable() 579 | 580 | # Enable TF32 for faster training on Ampere GPUs, 581 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 582 | if args.allow_tf32: 583 | torch.backends.cuda.matmul.allow_tf32 = True 584 | 585 | if args.scale_lr: 586 | args.learning_rate = ( 587 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 588 | ) 589 | 590 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 591 | if args.use_8bit_adam: 592 | try: 593 | import bitsandbytes as bnb 594 | except ImportError: 595 | raise ImportError( 596 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." 597 | ) 598 | 599 | optimizer_class = bnb.optim.AdamW8bit 600 | else: 601 | optimizer_class = torch.optim.AdamW 602 | 603 | # Optimizer creation 604 | params_to_optimize = ( 605 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() 606 | ) 607 | optimizer = optimizer_class( 608 | params_to_optimize, 609 | lr=args.learning_rate, 610 | betas=(args.adam_beta1, args.adam_beta2), 611 | weight_decay=args.adam_weight_decay, 612 | eps=args.adam_epsilon, 613 | ) 614 | 615 | 616 | train_dataset = DreamBoothDatasetAge( 617 | instance_data_root = args.instance_data_dir, 618 | instance_age_path = args.instance_age_path, 619 | tokenizer = tokenizer, 620 | instance_prompt=args.instance_prompt, 621 | size=args.resolution, 622 | center_crop=args.center_crop, 623 | ) 624 | train_dataloader = torch.utils.data.DataLoader( 625 | train_dataset, 626 | batch_size=args.train_batch_size, 627 | shuffle=False, 628 | collate_fn=lambda examples: collate_fn(examples, args.finetune_mode), 629 | num_workers=1, 630 | ) 631 | 632 | # Scheduler and math around the number of training steps. 633 | overrode_max_train_steps = False 634 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 635 | if args.max_train_steps is None: 636 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 637 | overrode_max_train_steps = True 638 | 639 | lr_scheduler = get_scheduler( 640 | args.lr_scheduler, 641 | optimizer=optimizer, 642 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, 643 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, 644 | # num_cycles=args.lr_num_cycles, 645 | # power=args.lr_power, 646 | ) 647 | 648 | # Prepare everything with our `accelerator`. 649 | if args.train_text_encoder: 650 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 651 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 652 | ) 653 | else: 654 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 655 | unet, optimizer, train_dataloader, lr_scheduler 656 | ) 657 | 658 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 659 | # as these models are only used for inference, keeping weights in full precision is not required. 
660 | weight_dtype = torch.float32 661 | if accelerator.mixed_precision == "fp16": 662 | weight_dtype = torch.float16 663 | elif accelerator.mixed_precision == "bf16": 664 | weight_dtype = torch.bfloat16 665 | 666 | # Move vae and text_encoder to device and cast to weight_dtype 667 | vae.to(accelerator.device, dtype=weight_dtype) 668 | if not args.train_text_encoder: 669 | text_encoder.to(accelerator.device, dtype=weight_dtype) 670 | 671 | low_precision_error_string = ( 672 | "Please make sure to always have all model weights in full float32 precision when starting training - even if" 673 | " doing mixed precision training. copy of the weights should still be float32." 674 | ) 675 | 676 | if unet.dtype != torch.float32: 677 | raise ValueError(f"Unet loaded as datatype {unet.dtype}. {low_precision_error_string}") 678 | 679 | if args.train_text_encoder and text_encoder.dtype != torch.float32: 680 | raise ValueError(f"Text encoder loaded as datatype {text_encoder.dtype}. {low_precision_error_string}") 681 | 682 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 683 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 684 | if overrode_max_train_steps: 685 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 686 | # Afterwards we recalculate our number of training epochs 687 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 688 | 689 | # We need to initialize the trackers we use, and also store our configuration. 690 | # The trackers initializes automatically on the main process. 691 | if accelerator.is_main_process: 692 | accelerator.init_trackers("dreambooth", config=vars(args)) 693 | 694 | # Train! 695 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 696 | 697 | logger.info("***** Running training *****") 698 | logger.info(f" Num examples = {len(train_dataset)}") 699 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 700 | logger.info(f" Num Epochs = {args.num_train_epochs}") 701 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 702 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 703 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 704 | logger.info(f" Total optimization steps = {args.max_train_steps}") 705 | global_step = 0 706 | first_epoch = 0 707 | 708 | # Potentially load in the weights and states from a previous save 709 | if args.resume_from_checkpoint: 710 | if args.resume_from_checkpoint != "latest": 711 | path = os.path.basename(args.resume_from_checkpoint) 712 | else: 713 | # Get the mos recent checkpoint 714 | dirs = os.listdir(args.output_dir) 715 | dirs = [d for d in dirs if d.startswith("checkpoint")] 716 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 717 | path = dirs[-1] 718 | accelerator.print(f"Resuming from checkpoint {path}") 719 | accelerator.load_state(os.path.join(args.output_dir, path)) 720 | global_step = int(path.split("-")[1]) 721 | 722 | resume_global_step = global_step * args.gradient_accumulation_steps 723 | first_epoch = resume_global_step // num_update_steps_per_epoch 724 | resume_step = resume_global_step % num_update_steps_per_epoch 725 | 726 | # Only show the progress bar once on each machine. 
727 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 728 | progress_bar.set_description("Steps") 729 | 730 | for epoch in range(first_epoch, args.num_train_epochs): 731 | unet.train() 732 | if args.train_text_encoder: 733 | text_encoder.train() 734 | 735 | for step, batch in enumerate(train_dataloader): 736 | # Skip steps until we reach the resumed step 737 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 738 | if step % args.gradient_accumulation_steps == 0: 739 | progress_bar.update(1) 740 | continue 741 | 742 | with accelerator.accumulate(unet): 743 | # Convert images to latent space 744 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 745 | latents = latents * 0.18215 746 | 747 | # Sample noise that we'll add to the latents 748 | noise = torch.randn_like(latents) 749 | 750 | noise_single = torch.randn_like(latents[0]) 751 | bsz = latents.shape[0] 752 | noise = torch.cat([noise_single.unsqueeze(0)] * bsz, dim=0) 753 | 754 | # Sample a random timestep for each image 755 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 756 | timesteps = timesteps.long() 757 | 758 | # Add noise to the latents according to the noise magnitude at each timestep 759 | # (this is the forward diffusion process) 760 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 761 | 762 | # Get the text embedding for conditioning 763 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 764 | 765 | # Predict the noise residual 766 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 767 | 768 | # Get the target for loss depending on the prediction type 769 | if noise_scheduler.config.prediction_type == "epsilon": 770 | target = noise 771 | elif noise_scheduler.config.prediction_type == "v_prediction": 772 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 773 | else: 774 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 775 | 776 | pred_denoised_latents = pred_original_samples(samples = noisy_latents, 777 | noise = model_pred, 778 | timesteps = timesteps) 779 | 780 | # noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 781 | pred_noisy_latents = noise_scheduler.add_noise(latents, model_pred.to(dtype=weight_dtype), timesteps) 782 | 783 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 784 | instance_loss = loss 785 | 786 | # Checks if the accelerator has performed an optimization step behind the scenes 787 | if accelerator.sync_gradients: 788 | progress_bar.update(1) 789 | global_step += 1 790 | 791 | if global_step % args.checkpointing_steps == 0: 792 | if accelerator.is_main_process: 793 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 794 | accelerator.save_state(save_path) 795 | logger.info(f"Saved state to {save_path}") 796 | 797 | logs = {"loss": loss.detach().item(), 798 | "instance_loss": instance_loss.detach().item(), 799 | "lr": lr_scheduler.get_last_lr()[0]} 800 | 801 | progress_bar.set_postfix(**logs) 802 | accelerator.log(logs, step=global_step) 803 | 804 | if global_step >= args.max_train_steps: 805 | break 806 | 807 | # Create the pipeline using using the trained modules and save it. 
808 | accelerator.wait_for_everyone() 809 | if accelerator.is_main_process: 810 | pipeline = DiffusionPipeline.from_pretrained( 811 | args.pretrained_model_name_or_path, 812 | unet=accelerator.unwrap_model(unet), 813 | text_encoder=accelerator.unwrap_model(text_encoder), 814 | revision=args.revision, 815 | ) 816 | pipeline.save_pretrained(args.output_dir) 817 | 818 | if args.push_to_hub: 819 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True) 820 | 821 | accelerator.end_training() 822 | 823 | if __name__ == "__main__": 824 | args = parse_args() 825 | 826 | # args.instance_data_dir = 'specialization_data/training_images' 827 | # args.instance_age_path = 'specialization_data/training_ages.npy' 828 | # 829 | # args.output_dir = 'specialized_models/tmp' 830 | 831 | main(args) 832 | 833 | 834 | 835 | 836 | --------------------------------------------------------------------------------
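The training loop in specialize.py computes `loss` inside `with accelerator.accumulate(unet):`, but the corresponding gradient update is not shown above. Below is a minimal sketch of that step, assuming the same pattern as the upstream diffusers v0.10 train_dreambooth.py the file says it was adapted from; the helper name `apply_gradient_update` and its argument list are illustrative, not part of the repository.

# Sketch of the per-step gradient update used by diffusers' train_dreambooth.py
# (an assumption about the omitted step, not FADING's verbatim code).
import itertools

def apply_gradient_update(accelerator, loss, unet, text_encoder, optimizer,
                          lr_scheduler, args):
    # Backpropagate through the accelerator so mixed precision and gradient
    # accumulation are handled consistently with the rest of the script.
    accelerator.backward(loss)
    if accelerator.sync_gradients:
        # Clip the same parameter set that was handed to the optimizer.
        params_to_clip = (
            itertools.chain(unet.parameters(), text_encoder.parameters())
            if args.train_text_encoder
            else unet.parameters()
        )
        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
    optimizer.step()
    lr_scheduler.step()
    optimizer.zero_grad()

Clipping the same parameter set that was passed to the optimizer keeps `--max_grad_norm` meaningful whether or not `--train_text_encoder` is enabled.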
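`pred_original_samples()` in specialize.py inverts the forward-diffusion formula used by `DDPMScheduler.add_noise` (x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps); in the training loop it is applied with the UNet's predicted noise to obtain `pred_denoised_latents`. A quick standalone check of that inversion, assuming a default `DDPMScheduler` rather than the Stable Diffusion one:

# Numeric sanity check of the x0-prediction formula used by pred_original_samples().
import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
x0 = torch.randn(1, 4, 64, 64)          # stand-in for a clean latent
eps = torch.randn_like(x0)              # the true noise
t = torch.tensor([500])

# Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
xt = scheduler.add_noise(x0, eps, t)

# Inversion, same algebra as pred_original_samples():
alpha_bar = scheduler.alphas_cumprod[t].view(-1, 1, 1, 1)
x0_rec = (xt - (1.0 - alpha_bar).sqrt() * eps) / alpha_bar.sqrt()
print(torch.allclose(x0, x0_rec, atol=1e-4))  # expected: True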
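A hypothetical way to launch the specialization stage from Python, mirroring the example paths that appear commented out at the bottom of specialize.py; the image filenames and the (filename, age) layout of the .npy file are assumptions inferred from how `DreamBoothDatasetAge.__getitem__` indexes `age_labels`.

# Illustrative invocation only; paths and filenames are placeholders.
import numpy as np
from specialize import parse_args, main

# Ages file: one row per training image, [filename, age].
np.save("specialization_data/training_ages.npy",
        np.array([["img_00.jpg", "25"], ["img_01.jpg", "25"]]))

args = parse_args([
    "--instance_data_dir", "specialization_data/training_images",
    "--instance_age_path", "specialization_data/training_ages.npy",
    "--output_dir", "specialized_models/tmp",
    # "finetune_double_prompt" pairs every image with both
    # "photo of a person" and "photo of a <age> year old person".
    "--finetune_mode", "finetune_double_prompt",
])
main(args)

With the default finetune_double_prompt mode, `collate_fn` pairs each image with both the generic prompt and its age-annotated variant, which is why the effective batch contains two copies of every training image.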