├── FADING_demo.png
├── FADING_util
│   ├── ptp_utils.py
│   ├── seq_aligner.py
│   └── util.py
├── README.md
├── age_editing.py
├── null_inversion.py
├── p2p.py
└── specialize.py
/FADING_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MunchkinChen/FADING/14fefbf7edbf5b1d3f867bd4720a61bfc1bdf2c0/FADING_demo.png
--------------------------------------------------------------------------------
/FADING_util/ptp_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import numpy as np
16 | import torch
17 | from PIL import Image, ImageDraw, ImageFont
18 | import cv2
19 | from typing import Optional, Union, Tuple, List, Callable, Dict
20 | from IPython.display import display
21 | from tqdm.notebook import tqdm
22 |
23 |
24 | def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
25 | h, w, c = image.shape
26 | offset = int(h * .2)
27 | img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
28 | font = cv2.FONT_HERSHEY_SIMPLEX
29 | # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
30 | img[:h] = image
31 | textsize = cv2.getTextSize(text, font, 1, 2)[0]
32 | text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
33 | cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
34 | return img
35 |
36 |
37 | def view_images(images, num_rows=1, offset_ratio=0.02):
38 | if type(images) is list:
39 | num_empty = (num_rows - len(images) % num_rows) % num_rows
40 | elif images.ndim == 4:
41 | num_empty = (num_rows - images.shape[0] % num_rows) % num_rows
42 | else:
43 | images = [images]
44 | num_empty = 0
45 |
46 | empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
47 | images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
48 | num_items = len(images)
49 |
50 | h, w, c = images[0].shape
51 | offset = int(h * offset_ratio)
52 | num_cols = num_items // num_rows
53 | image_ = np.ones((h * num_rows + offset * (num_rows - 1),
54 | w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
55 | for i in range(num_rows):
56 | for j in range(num_cols):
57 | image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
58 | i * num_cols + j]
59 |
60 | pil_img = Image.fromarray(image_)
61 | return image_
62 |
63 |
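# Performs one denoising step with classifier-free guidance (a single batched UNet call,
# or two separate calls in low_resource mode), then lets the controller post-process the
# latents via step_callback (used by LocalBlend in p2p.py).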
64 | def diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False):
65 | if low_resource:
66 | noise_pred_uncond = model.unet(latents, t, encoder_hidden_states=context[0])["sample"]
67 | noise_prediction_text = model.unet(latents, t, encoder_hidden_states=context[1])["sample"]
68 | else:
69 | latents_input = torch.cat([latents] * 2)
70 | noise_pred = model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
71 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
72 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
73 | latents = model.scheduler.step(noise_pred, t, latents)["prev_sample"]
74 | latents = controller.step_callback(latents)
75 | return latents
76 |
77 |
78 | def latent2image(vae, latents):
79 | latents = 1 / 0.18215 * latents
80 | image = vae.decode(latents)['sample']
81 | image = (image / 2 + 0.5).clamp(0, 1)
82 | image = image.cpu().permute(0, 2, 3, 1).numpy()
83 | image = (image * 255).astype(np.uint8)
84 | return image
85 |
86 |
87 | def init_latent(latent, model, height, width, generator, batch_size):
88 | if latent is None:
89 | latent = torch.randn(
90 | (1, model.unet.in_channels, height // 8, width // 8),
91 | generator=generator,device=model.device,
92 | )
93 | latents = latent.expand(batch_size, model.unet.in_channels, height // 8, width // 8).to(model.device)
94 | return latent, latents
95 |
96 | #
97 | # @torch.no_grad()
98 | # def text2image_ldm(
99 | # model,
100 | # prompt: List[str],
101 | # controller,
102 | # num_inference_steps: int = 50,
103 | # guidance_scale: Optional[float] = 7.,
104 | # generator: Optional[torch.Generator] = None,
105 | # latent: Optional[torch.FloatTensor] = None,
106 | # ):
107 | # register_attention_control(model, controller)
108 | # height = width = 256
109 | # batch_size = len(prompt)
110 | #
111 | # uncond_input = model.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
112 | # uncond_embeddings = model.bert(uncond_input.input_ids.to(model.device))[0]
113 | #
114 | # text_input = model.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
115 | # text_embeddings = model.bert(text_input.input_ids.to(model.device))[0]
116 | # latent, latents = init_latent(latent, model, height, width, generator, batch_size)
117 | # context = torch.cat([uncond_embeddings, text_embeddings])
118 | #
119 | # model.scheduler.set_timesteps(num_inference_steps)
120 | # for t in tqdm(model.scheduler.timesteps):
121 | # latents = diffusion_step(model, controller, latents, context, t, guidance_scale)
122 | #
123 | # image = latent2image(model.vqvae, latents)
124 | #
125 | # return image, latent
126 | #
127 | #
128 | # @torch.no_grad()
129 | # def text2image_ldm_stable(
130 | # model,
131 | # prompt: List[str],
132 | # controller,
133 | # num_inference_steps: int = 50,
134 | # guidance_scale: float = 7.5,
135 | # generator: Optional[torch.Generator] = None,
136 | # latent: Optional[torch.FloatTensor] = None,
137 | # low_resource: bool = False,
138 | # ):
139 | # register_attention_control(model, controller)
140 | # height = width = 512
141 | # batch_size = len(prompt)
142 | #
143 | # text_input = model.tokenizer(
144 | # prompt,
145 | # padding="max_length",
146 | # max_length=model.tokenizer.model_max_length,
147 | # truncation=True,
148 | # return_tensors="pt",
149 | # )
150 | # text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
151 | # max_length = text_input.input_ids.shape[-1]
152 | # uncond_input = model.tokenizer(
153 | # [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
154 | # )
155 | # uncond_embeddings = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
156 | #
157 | # context = [uncond_embeddings, text_embeddings]
158 | # if not low_resource:
159 | # context = torch.cat(context)
160 | # latent, latents = init_latent(latent, model, height, width, generator, batch_size)
161 | #
162 | # # set timesteps
163 | # # extra_set_kwargs = {"offset": 1}
164 | # # model.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
165 | # model.scheduler.set_timesteps(num_inference_steps)
166 | # for t in tqdm(model.scheduler.timesteps):
167 | # latents = diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource)
168 | #
169 | # image = latent2image(model.vae, latents)
170 | #
171 | # return image, latent
172 |
173 |
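# Monkey-patches the forward pass of every module named 'CrossAttention' in the UNet so
# that its attention probabilities are routed through `controller` before being applied to
# the values. Layers are registered per UNet branch ("down"/"mid"/"up"), and the total
# number of patched layers is stored in controller.num_att_layers.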
174 | def register_attention_control(model, controller):
175 | def ca_forward(self, place_in_unet):
176 | to_out = self.to_out
177 | if type(to_out) is torch.nn.modules.container.ModuleList:
178 | to_out = self.to_out[0]
179 | else:
180 | to_out = self.to_out
181 |
182 | def forward(x, context=None, mask=None):
183 | batch_size, sequence_length, dim = x.shape
184 | h = self.heads
185 | q = self.to_q(x)
186 | is_cross = context is not None
187 | context = context if is_cross else x
188 | k = self.to_k(context)
189 | v = self.to_v(context)
190 | q = self.reshape_heads_to_batch_dim(q)
191 | k = self.reshape_heads_to_batch_dim(k)
192 | v = self.reshape_heads_to_batch_dim(v)
193 |
194 | sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
195 |
196 | if mask is not None:
197 | mask = mask.reshape(batch_size, -1)
198 | max_neg_value = -torch.finfo(sim.dtype).max
199 | mask = mask[:, None, :].repeat(h, 1, 1)
200 | sim.masked_fill_(~mask, max_neg_value)
201 |
202 | # attention, what we cannot get enough of
203 | attn = sim.softmax(dim=-1)
204 | attn = controller(attn, is_cross, place_in_unet)
205 | out = torch.einsum("b i j, b j d -> b i d", attn, v)
206 | out = self.reshape_batch_dim_to_heads(out)
207 | return to_out(out)
208 |
209 | return forward
210 |
211 | class DummyController:
212 |
213 | def __call__(self, *args):
214 | return args[0]
215 |
216 | def __init__(self):
217 | self.num_att_layers = 0
218 |
219 | if controller is None:
220 | controller = DummyController()
221 |
222 | def register_recr(net_, count, place_in_unet):
223 | if net_.__class__.__name__ == 'CrossAttention':
224 | net_.forward = ca_forward(net_, place_in_unet)
225 | return count + 1
226 | elif hasattr(net_, 'children'):
227 | for net__ in net_.children():
228 | count = register_recr(net__, count, place_in_unet)
229 | return count
230 |
231 | cross_att_count = 0
232 | sub_nets = model.unet.named_children()
233 | for net in sub_nets:
234 | if "down" in net[0]:
235 | cross_att_count += register_recr(net[1], 0, "down")
236 | elif "up" in net[0]:
237 | cross_att_count += register_recr(net[1], 0, "up")
238 | elif "mid" in net[0]:
239 | cross_att_count += register_recr(net[1], 0, "mid")
240 |
241 | controller.num_att_layers = cross_att_count
242 |
243 |
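# Maps a word in `text` (given either as a whitespace-split index or as the word itself)
# to the indices of its sub-word tokens in the tokenized prompt (shifted by 1 for the BOS token).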
244 | def get_word_inds(text: str, word_place: int, tokenizer):
245 | split_text = text.split(" ")
246 | if type(word_place) is str:
247 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
248 | elif type(word_place) is int:
249 | word_place = [word_place]
250 | out = []
251 | if len(word_place) > 0:
252 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
253 | cur_len, ptr = 0, 0
254 |
255 | for i in range(len(words_encode)):
256 | cur_len += len(words_encode[i])
257 | if ptr in word_place:
258 | out.append(i + 1)
259 | if cur_len >= len(split_text[ptr]):
260 | ptr += 1
261 | cur_len = 0
262 | return np.array(out)
263 |
264 |
265 | def update_alpha_time_word(alpha, bounds: Union[float, Tuple[float, float]], prompt_ind: int,
266 | word_inds: Optional[torch.Tensor] = None):
267 | if type(bounds) is float:
268 | bounds = 0, bounds
269 | start, end = int(bounds[0] * alpha.shape[0]), int(bounds[1] * alpha.shape[0])
270 | if word_inds is None:
271 | word_inds = torch.arange(alpha.shape[2])
272 | alpha[: start, prompt_ind, word_inds] = 0
273 | alpha[start: end, prompt_ind, word_inds] = 1
274 | alpha[end:, prompt_ind, word_inds] = 0
275 | return alpha
276 |
277 |
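# Builds a (num_steps + 1, len(prompts) - 1, 1, 1, max_num_words) tensor of 0/1 weights:
# 1 means the edited cross-attention is injected at that diffusion step for that token,
# 0 keeps the edited prompt's own attention. `cross_replace_steps` is either a single
# float / (start, end) interval (key "default_") or a per-word dictionary of intervals.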
278 | def get_time_words_attention_alpha(prompts, num_steps,
279 | cross_replace_steps: Union[float, Dict[str, Tuple[float, float]]],
280 | tokenizer, max_num_words=77):
281 | if type(cross_replace_steps) is not dict:
282 | cross_replace_steps = {"default_": cross_replace_steps}
283 | if "default_" not in cross_replace_steps:
284 | cross_replace_steps["default_"] = (0., 1.)
285 | alpha_time_words = torch.zeros(num_steps + 1, len(prompts) - 1, max_num_words)
286 | for i in range(len(prompts) - 1):
287 | alpha_time_words = update_alpha_time_word(alpha_time_words, cross_replace_steps["default_"],
288 | i)
289 | for key, item in cross_replace_steps.items():
290 | if key != "default_":
291 | inds = [get_word_inds(prompts[i], key, tokenizer) for i in range(1, len(prompts))]
292 | for i, ind in enumerate(inds):
293 | if len(ind) > 0:
294 | alpha_time_words = update_alpha_time_word(alpha_time_words, item, i, ind)
295 | alpha_time_words = alpha_time_words.reshape(num_steps + 1, len(prompts) - 1, 1, 1, max_num_words)
296 | return alpha_time_words
--------------------------------------------------------------------------------
/FADING_util/seq_aligner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Google LLC
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | import torch
15 | import numpy as np
16 |
17 |
18 | class ScoreParams:
19 |
20 | def __init__(self, gap, match, mismatch):
21 | self.gap = gap
22 | self.match = match
23 | self.mismatch = mismatch
24 |
25 | def mis_match_char(self, x, y):
26 | if x != y:
27 | return self.mismatch
28 | else:
29 | return self.match
30 |
31 |
32 | def get_matrix(size_x, size_y, gap):
33 | matrix = []
34 | for i in range(len(size_x) + 1):
35 | sub_matrix = []
36 | for j in range(len(size_y) + 1):
37 | sub_matrix.append(0)
38 | matrix.append(sub_matrix)
39 | for j in range(1, len(size_y) + 1):
40 | matrix[0][j] = j * gap
41 | for i in range(1, len(size_x) + 1):
42 | matrix[i][0] = i * gap
43 | return matrix
44 |
45 |
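# NumPy reimplementation below; it shadows the pure-Python get_matrix defined above,
# so only this version is actually used.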
46 | def get_matrix(size_x, size_y, gap):
47 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
48 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap
49 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap
50 | return matrix
51 |
52 |
53 | def get_traceback_matrix(size_x, size_y):
54 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
55 | matrix[0, 1:] = 1
56 | matrix[1:, 0] = 2
57 | matrix[0, 0] = 4
58 | return matrix
59 |
60 |
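# Needleman-Wunsch-style global alignment of two token sequences under the ScoreParams
# gap/match/mismatch scores. Returns the score matrix and a traceback matrix
# (1 = gap in x, 2 = gap in y, 3 = match/mismatch, 4 = origin).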
61 | def global_align(x, y, score):
62 | matrix = get_matrix(len(x), len(y), score.gap)
63 | trace_back = get_traceback_matrix(len(x), len(y))
64 | for i in range(1, len(x) + 1):
65 | for j in range(1, len(y) + 1):
66 | left = matrix[i, j - 1] + score.gap
67 | up = matrix[i - 1, j] + score.gap
68 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
69 | matrix[i, j] = max(left, up, diag)
70 | if matrix[i, j] == left:
71 | trace_back[i, j] = 1
72 | elif matrix[i, j] == up:
73 | trace_back[i, j] = 2
74 | else:
75 | trace_back[i, j] = 3
76 | return matrix, trace_back
77 |
78 |
79 | def get_aligned_sequences(x, y, trace_back):
80 | x_seq = []
81 | y_seq = []
82 | i = len(x)
83 | j = len(y)
84 | mapper_y_to_x = []
85 | while i > 0 or j > 0:
86 | if trace_back[i, j] == 3:
87 | x_seq.append(x[i - 1])
88 | y_seq.append(y[j - 1])
89 | i = i - 1
90 | j = j - 1
91 | mapper_y_to_x.append((j, i))
92 | elif trace_back[i][j] == 1:
93 | x_seq.append('-')
94 | y_seq.append(y[j - 1])
95 | j = j - 1
96 | mapper_y_to_x.append((j, -1))
97 | elif trace_back[i][j] == 2:
98 | x_seq.append(x[i - 1])
99 | y_seq.append('-')
100 | i = i - 1
101 | elif trace_back[i][j] == 4:
102 | break
103 | mapper_y_to_x.reverse()
104 | return x_seq, y_seq, torch.tensor(mapper_y_to_x, dtype=torch.int64)
105 |
106 |
107 | def get_mapper(x: str, y: str, tokenizer, max_len=77):
108 | x_seq = tokenizer.encode(x)
109 | y_seq = tokenizer.encode(y)
110 | score = ScoreParams(0, 1, -1)
111 | matrix, trace_back = global_align(x_seq, y_seq, score)
112 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
113 | alphas = torch.ones(max_len)
114 | alphas[: mapper_base.shape[0]] = mapper_base[:, 1].ne(-1).float()
115 | mapper = torch.zeros(max_len, dtype=torch.int64)
116 | mapper[:mapper_base.shape[0]] = mapper_base[:, 1]
117 | mapper[mapper_base.shape[0]:] = len(y_seq) + torch.arange(max_len - len(y_seq))
118 | return mapper, alphas
119 |
120 |
121 | def get_refinement_mapper(prompts, tokenizer, max_len=77):
122 | x_seq = prompts[0]
123 | mappers, alphas = [], []
124 | for i in range(1, len(prompts)):
125 | mapper, alpha = get_mapper(x_seq, prompts[i], tokenizer, max_len)
126 | mappers.append(mapper)
127 | alphas.append(alpha)
128 | return torch.stack(mappers), torch.stack(alphas)
129 |
130 |
131 | def get_word_inds(text: str, word_place: int, tokenizer):
132 | split_text = text.split(" ")
133 | if type(word_place) is str:
134 | word_place = [i for i, word in enumerate(split_text) if word_place == word]
135 | elif type(word_place) is int:
136 | word_place = [word_place]
137 | out = []
138 | if len(word_place) > 0:
139 | words_encode = [tokenizer.decode([item]).strip("#") for item in tokenizer.encode(text)][1:-1]
140 | cur_len, ptr = 0, 0
141 |
142 | for i in range(len(words_encode)):
143 | cur_len += len(words_encode[i])
144 | if ptr in word_place:
145 | out.append(i + 1)
146 | if cur_len >= len(split_text[ptr]):
147 | ptr += 1
148 | cur_len = 0
149 | return np.array(out)
150 |
151 |
152 | def get_replacement_mapper_(x: str, y: str, tokenizer, max_len=77):
153 | words_x = x.split(' ')
154 | words_y = y.split(' ')
155 | if len(words_x) != len(words_y):
156 | raise ValueError(f"attention replacement edit can only be applied on prompts with the same length"
157 | f" but prompt A has {len(words_x)} words and prompt B has {len(words_y)} words.")
158 | inds_replace = [i for i in range(len(words_y)) if words_y[i] != words_x[i]]
159 | inds_source = [get_word_inds(x, i, tokenizer) for i in inds_replace]
160 | inds_target = [get_word_inds(y, i, tokenizer) for i in inds_replace]
161 | mapper = np.zeros((max_len, max_len))
162 | i = j = 0
163 | cur_inds = 0
164 | while i < max_len and j < max_len:
165 | if cur_inds < len(inds_source) and inds_source[cur_inds][0] == i:
166 | inds_source_, inds_target_ = inds_source[cur_inds], inds_target[cur_inds]
167 | if len(inds_source_) == len(inds_target_):
168 | mapper[inds_source_, inds_target_] = 1
169 | else:
170 | ratio = 1 / len(inds_target_)
171 | for i_t in inds_target_:
172 | mapper[inds_source_, i_t] = ratio
173 | cur_inds += 1
174 | i += len(inds_source_)
175 | j += len(inds_target_)
176 | elif cur_inds < len(inds_source):
177 | mapper[i, j] = 1
178 | i += 1
179 | j += 1
180 | else:
181 | mapper[j, j] = 1
182 | i += 1
183 | j += 1
184 |
185 | return torch.from_numpy(mapper).float()
186 |
187 |
188 | def get_replacement_mapper(prompts, tokenizer, max_len=77):
189 | x_seq = prompts[0]
190 | mappers = []
191 | for i in range(1, len(prompts)):
192 | mapper = get_replacement_mapper_(x_seq, prompts[i], tokenizer, max_len)
193 | mappers.append(mapper)
194 | return torch.stack(mappers)
195 |
--------------------------------------------------------------------------------
/FADING_util/util.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw
2 | import matplotlib.pyplot as plt
3 | import torch
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import os
7 | from torchvision import transforms
8 | import json
9 |
10 | def get_instance_prompt(dreambooth_dir):
11 | json_path = os.path.join(dreambooth_dir, "model_config.json")
12 | with open(json_path, 'r') as file:
13 | model_config = json.load(file)
14 | return model_config['instance_prompt']
15 |
16 | # get_instance_prompt('saved_model/de-id/FFHQ_512_00006_100')
17 | #%%
18 | def mydisplay(img):
19 | plt.axis('off')
20 | plt.imshow(img)
21 | plt.show()
22 |
23 | def load_image(p, arr=False, resize=None):
24 | '''
25 | Function to load images from a defined path
26 | '''
27 | ret = Image.open(p).convert('RGB')
28 | if resize is not None:
29 | ret = ret.resize((resize[0],resize[1]))
30 | if not arr:
31 | return ret
32 | return np.array(ret)
33 |
34 |
35 |
36 | def numpy_to_pil(images):
37 | """
38 | Convert a numpy image or a batch of images to a PIL image.
39 | """
40 | if images.ndim == 3:
41 | images = images[None, ...]
42 | images = (images * 255).round().astype("uint8")
43 | if images.shape[-1] == 1:
44 | # special case for grayscale (single channel) images
45 | pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
46 | else:
47 | pil_images = [Image.fromarray(image) for image in images]
48 |
49 | return pil_images
50 |
51 |
52 | def tensor_to_img(tensor,arr=False):
53 | tmp = tensor.clone().squeeze(0).cpu()
54 | tfpil = transforms.ToPILImage()
55 | tmp = tfpil(tmp)
56 | # tmp = (tmp+1)*0.5
57 | if arr:
58 | tmp = np.array(tmp)
59 | return tmp
60 |
61 |
62 | #%%
63 | def image_grid(imgs_, rows=None, cols=None, sort_file_filter=None, remove_filter=None, border=0, resize=None):
64 | if isinstance(imgs_, str) or (isinstance(imgs_, list) and isinstance(imgs_[0], str)):
65 | if isinstance(imgs_, str):
66 | # imgs_ is a directory of images
67 | files = os.listdir(imgs_)
68 |
69 | if remove_filter:
70 | files = remove_filter(files)
71 |
72 | if sort_file_filter:
73 | files = sorted(files, key=sort_file_filter)
74 |
75 | files = [os.path.join(imgs_, f) for f in files]
76 | else:
77 | # imgs_ is a list of image file paths
78 | files = imgs_
79 |
80 | print(files)
81 |
82 | imgs = []
83 | for f in files[:]:
84 | img = load_image(f,resize=resize)
85 | imgs.append(img)
86 |
87 | elif isinstance(imgs_, np.ndarray):
88 | # imgs_ is an ndarray of images
89 | imgs = [Image.fromarray(i) for i in imgs_]
90 |
91 | else:
92 | # imgs_ is a list of PIL images
93 | imgs = imgs_[:]
94 |
95 | if not rows or not cols:
96 | rows = 1
97 | cols = len(imgs)
98 |
99 | assert len(imgs) == rows * cols
100 |
101 | w, h = imgs[-1].size
102 | grid = Image.new('RGB', size=(cols * w + (cols-1)*border, rows * h + (rows-1)*border), color='white')
103 | grid_w, grid_h = grid.size
104 |
105 | for i, img in enumerate(imgs):
106 | grid.paste(img, box=(i % cols * (w+border), i // cols * (h+border)))
107 | return grid
108 |
109 | #%%
110 | def sort_by_num(separator='-'):
111 | def sort_by_num_(x):
112 | return int(x.split(separator, 1)[0])
113 | return sort_by_num_
114 | def remove_filter(files):
115 | ret_files = []
116 | for f in files:
117 | if f[0]!= '.' and f.split('.')[0] not in ['1','8','17']:
118 | ret_files.append(f)
119 | return ret_files
120 | def tmp(x):
121 | num = int(x[:5])
122 | return num
123 |
124 |
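# Chooses the noun used in prompts for a given age/gender: "boy"/"girl" for ages <= 15 and
# "man"/"woman" otherwise when a gender is known, falling back to "child" or "person".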
125 | def get_person_placeholder(age=None, predicted_gender=None):
126 | if predicted_gender is not None:
127 | if age and age <= 15:
128 | person_placeholder = ['boy', 'girl'][predicted_gender == 'Female' or predicted_gender == 1]
129 | else: # init age > 15, or no init age given at all
130 | person_placeholder = ['man','woman'][predicted_gender == 'Female' or predicted_gender == 1]
131 | else:
132 | if age and age <= 15:
133 | person_placeholder = "child"
134 | else:
135 | person_placeholder = "person"
136 | return person_placeholder
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FADING
2 |
3 | ## Face Aging via Diffusion-based Editing
4 | [arXiv](https://arxiv.org/abs/2309.11321)
5 | [BMVC 2023](https://proceedings.bmvc2023.org/595/)
6 | [Supplementary Material](https://bmvc2022.mpi-inf.mpg.de/BMVC2023/0595_supp.pdf)
7 |
8 | Official repo for BMVC 2023 paper: [Face Aging via Diffusion-based Editing](https://proceedings.bmvc2023.org/595/).
9 |
10 | For more visualization results, please check the [supplementary materials](https://bmvc2022.mpi-inf.mpg.de/BMVC2023/0595_supp.pdf).
11 |
12 |
13 | ![FADING demo](FADING_demo.png)
14 |
15 |
16 | > In this paper, we address the problem of face aging—generating past or future facial images by incorporating age-related changes to the given face. Previous aging methods rely solely on human facial image datasets and are thus constrained by their inherent scale and bias. This restricts their application to a limited generatable age range and the inability to handle large age gaps. We propose FADING, a novel approach to address Face Aging via DIffusion-based editiNG. We go beyond existing methods by leveraging the rich prior of large-scale language-image diffusion models. First, we specialize a pre-trained diffusion model for the task of face age editing by using an age-aware fine-tuning scheme. Next, we invert the input image to latent noise and obtain optimized null text embeddings. Finally, we perform text-guided local age editing via attention control. The quantitative and qualitative analyses demonstrate that our method outperforms existing approaches with respect to aging accuracy, attribute preservation, and aging quality.
17 |
18 |
19 |
20 | ## Dataset
21 | The FFHQ-Aging Dataset used for training FADING can be downloaded from https://github.com/royorel/FFHQ-Aging-Dataset.
22 |
23 | ## Training (Specialization)
24 |
25 | ### Available pretrained weights
26 | We release weights of our specialized model at https://drive.google.com/file/d/1galwrcHq1HoZNfOI4jdJJqVs5ehB_dvO/view?usp=share_link
27 |
28 | ### Train a new model
29 |
30 | ```shell
31 | accelerate launch specialize.py \
32 | --instance_data_dir 'specialization_data/training_images' \
33 | --instance_age_path 'specialization_data/training_ages.npy' \
34 | --output_dir <output_dir> \
35 | --max_train_steps 150
36 | ```
37 | Training images should be placed in `specialization_data/training_images`. The training set is described by `training_ages.npy`, which records the age of each training image, e.g.:
38 | ```
39 | array([['00007.jpg', '1'],
40 |        ['00004.jpg', '35'],
41 |        ...
42 |        ['00009.jpg', '35']], dtype='<U9')
43 | ```
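Each row pairs a training image file name with its age, stored as strings. As a minimal sketch (not part of the repository), such a file could be generated with NumPy; the file names below are just the placeholders from the example above:

```python
import numpy as np

# Hypothetical example: build the (file name, age) array expected by --instance_age_path.
# Replace the placeholder pairs with the actual images in specialization_data/training_images.
pairs = [
    ("00007.jpg", "1"),
    ("00004.jpg", "35"),
    ("00009.jpg", "35"),
]
np.save("specialization_data/training_ages.npy", np.array(pairs))
```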
44 | 
45 | ## Age Editing
46 | 
47 | ```shell
48 | python age_editing.py \
49 | --image_path <path to input image> \
50 | --age_init <initial age> \
51 | --gender <'female' or 'male'> \
52 | --save_aged_dir <output directory> \
53 | --specialized_path <path to specialized diffusion model> \
54 | --target_ages 10 20 40 60 80
56 |
--------------------------------------------------------------------------------
/age_editing.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from diffusers import StableDiffusionPipeline, DDIMScheduler
4 |
5 | from FADING_util import util
6 | from p2p import *
7 | from null_inversion import *
8 |
9 | #%%
10 | parser = argparse.ArgumentParser()
11 |
12 | parser.add_argument('--image_path', required=True, help='Path to input image')
13 | parser.add_argument('--age_init', required=True, type=int, help='Specify the initial age')
14 | parser.add_argument('--gender', required=True, choices=["female", "male"], help="Specify the gender ('female' or 'male')")
15 | parser.add_argument('--specialized_path', required=True, help='Path to specialized diffusion model')
16 | parser.add_argument('--save_aged_dir', default='./outputs', help='Path to save outputs')
17 | parser.add_argument('--target_ages', nargs='+', default=[10, 20, 40, 60, 80], type=int, help='Target age values')
18 |
19 | args = parser.parse_args()
20 |
21 | #%%
22 | image_path = args.image_path
23 | age_init = args.age_init
24 | gender = args.gender
25 | save_aged_dir = args.save_aged_dir
26 | specialized_path = args.specialized_path
27 | target_ages = args.target_ages
28 |
29 | if not os.path.exists(save_aged_dir):
30 | os.makedirs(save_aged_dir)
31 |
32 | gt_gender = int(gender == 'female')
33 | person_placeholder = util.get_person_placeholder(age_init, gt_gender)
34 | inversion_prompt = f"photo of {age_init} year old {person_placeholder}"
35 |
36 | input_img_name = image_path.split('/')[-1].split('.')[-2]
37 |
38 | #%% load specialized diffusion model
39 | scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
40 | clip_sample=False, set_alpha_to_one=False,
41 | steps_offset=1)
42 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
43 | g_cuda = torch.Generator(device=device)
44 |
45 | ldm_stable = StableDiffusionPipeline.from_pretrained(specialized_path,
46 | scheduler=scheduler,
47 | safety_checker=None).to(device)
48 | tokenizer = ldm_stable.tokenizer
49 |
50 | #%% null text inversion
51 | null_inversion = NullInversion(ldm_stable)
52 | (image_gt, image_enc), x_t, uncond_embeddings = null_inversion.invert(image_path, inversion_prompt,
53 | offsets=(0,0,0,0), verbose=True)
54 | #%% age editing
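# For each target age: swap the age and person placeholder in the inversion prompt, build a
# Prompt-to-Prompt replacement controller with local blending around those tokens and an
# equalizer on the new age token, then regenerate from the inverted latent x_t using the
# per-step optimized null-text embeddings.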
55 | for age_new in target_ages:
56 | print(f'Age editing with target age {age_new}...')
57 | new_person_placeholder = util.get_person_placeholder(age_new, gt_gender)
58 | new_prompt = inversion_prompt.replace(person_placeholder, new_person_placeholder)
59 |
60 | new_prompt = new_prompt.replace(str(age_init),str(age_new))
61 | blend_word = (((str(age_init),person_placeholder,), (str(age_new),new_person_placeholder,)))
62 | is_replace_controller = True
63 |
64 | prompts = [inversion_prompt, new_prompt]
65 |
66 | cross_replace_steps = {'default_': .8,}
67 | self_replace_steps = .5
68 |
69 | eq_params = {"words": (str(age_new)), "values": (1,)}
70 |
71 | controller = make_controller(prompts, is_replace_controller, cross_replace_steps, self_replace_steps,
72 | tokenizer, blend_word, eq_params)
73 |
74 | images, _ = p2p_text2image(ldm_stable, prompts, controller, generator=g_cuda.manual_seed(0),
75 | latent=x_t, uncond_embeddings=uncond_embeddings)
76 |
77 | new_img = images[-1]
78 | new_img_pil = Image.fromarray(new_img)
79 | new_img_pil.save(os.path.join(save_aged_dir,f'{input_img_name}_{age_new}.png'))
80 |
--------------------------------------------------------------------------------
/null_inversion.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from tqdm import tqdm
3 | import torch
4 | # from diffusers import StableDiffusionPipeline, DDIMScheduler
5 | import torch.nn.functional as nnf
6 | import numpy as np
7 | from torch.optim.adam import Adam
8 | from PIL import Image
9 |
10 | import FADING_util.ptp_utils as ptp_utils
11 |
12 |
13 | LOW_RESOURCE = False
14 | NUM_DDIM_STEPS = 50
15 | GUIDANCE_SCALE = 7.5
16 | MAX_NUM_WORDS = 77
17 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
18 |
19 |
20 |
21 | def load_512(image_path, left=0, right=0, top=0, bottom=0):
22 | if type(image_path) is str:
23 | image = np.array(Image.open(image_path))[:, :, :3]
24 | else:
25 | image = image_path
26 | h, w, c = image.shape
27 | left = min(left, w - 1)
28 | right = min(right, w - left - 1)
29 | top = min(top, h - 1)
30 | bottom = min(bottom, h - top - 1)
31 | image = image[top:h - bottom, left:w - right]
32 | h, w, c = image.shape
33 | if h < w:
34 | offset = (w - h) // 2
35 | image = image[:, offset:offset + h]
36 | elif w < h:
37 | offset = (h - w) // 2
38 | image = image[offset:offset + w]
39 | image = np.array(Image.fromarray(image).resize((512, 512)))
40 | return image
41 |
42 |
43 | class NullInversion:
44 |
45 | def prev_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
46 | sample: Union[torch.FloatTensor, np.ndarray]):
47 | prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
48 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
49 | alpha_prod_t_prev = self.scheduler.alphas_cumprod[
50 | prev_timestep] if prev_timestep >= 0 else self.scheduler.final_alpha_cumprod
51 | beta_prod_t = 1 - alpha_prod_t
52 | pred_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
53 | pred_sample_direction = (1 - alpha_prod_t_prev) ** 0.5 * model_output
54 | prev_sample = alpha_prod_t_prev ** 0.5 * pred_original_sample + pred_sample_direction
55 | return prev_sample
56 |
57 | def next_step(self, model_output: Union[torch.FloatTensor, np.ndarray], timestep: int,
58 | sample: Union[torch.FloatTensor, np.ndarray]):
59 | timestep, next_timestep = min(
60 | timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999), timestep
61 | alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
62 | alpha_prod_t_next = self.scheduler.alphas_cumprod[next_timestep]
63 | beta_prod_t = 1 - alpha_prod_t
64 | next_original_sample = (sample - beta_prod_t ** 0.5 * model_output) / alpha_prod_t ** 0.5
65 | next_sample_direction = (1 - alpha_prod_t_next) ** 0.5 * model_output
66 | next_sample = alpha_prod_t_next ** 0.5 * next_original_sample + next_sample_direction
67 | return next_sample
68 |
69 | def get_noise_pred_single(self, latents, t, context):
70 | noise_pred = self.model.unet(latents, t, encoder_hidden_states=context)["sample"]
71 | return noise_pred
72 |
73 | def get_noise_pred(self, latents, t, is_forward=True, context=None):
74 | latents_input = torch.cat([latents] * 2)
75 | if context is None:
76 | context = self.context
77 | guidance_scale = 1 if is_forward else GUIDANCE_SCALE
78 | noise_pred = self.model.unet(latents_input, t, encoder_hidden_states=context)["sample"]
79 | noise_pred_uncond, noise_prediction_text = noise_pred.chunk(2)
80 | noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
81 | if is_forward:
82 | latents = self.next_step(noise_pred, t, latents)
83 | else:
84 | latents = self.prev_step(noise_pred, t, latents)
85 | return latents
86 |
87 | @torch.no_grad()
88 | def latent2image(self, latents, return_type='np'):
89 | latents = 1 / 0.18215 * latents.detach()
90 | image = self.model.vae.decode(latents)['sample']
91 | if return_type == 'np':
92 | image = (image / 2 + 0.5).clamp(0, 1)
93 | image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
94 | image = (image * 255).astype(np.uint8)
95 | return image
96 |
97 | @torch.no_grad()
98 | def image2latent(self, image):
99 | with torch.no_grad():
100 | if isinstance(image, Image.Image):
101 | image = np.array(image)
102 | if type(image) is torch.Tensor and image.dim() == 4:
103 | latents = image
104 | else:
105 | image = torch.from_numpy(image).float() / 127.5 - 1
106 | image = image.permute(2, 0, 1).unsqueeze(0).to(device)
107 | latents = self.model.vae.encode(image)['latent_dist'].mean
108 | latents = latents * 0.18215
109 | return latents
110 |
111 | @torch.no_grad()
112 | def init_prompt(self, prompt: str):
113 | uncond_input = self.tokenizer(
114 | [""], padding="max_length", max_length=self.tokenizer.model_max_length,
115 | return_tensors="pt"
116 | )
117 | uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
118 | text_input = self.tokenizer(
119 | [prompt],
120 | padding="max_length",
121 | max_length=self.tokenizer.model_max_length,
122 | truncation=True,
123 | return_tensors="pt",
124 | )
125 | text_embeddings = self.model.text_encoder(text_input.input_ids.to(self.model.device))[0]
126 | self.context = torch.cat([uncond_embeddings, text_embeddings])
127 | self.prompt = prompt
128 |
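# Deterministic DDIM inversion: starting from the clean latent, repeatedly applies
# next_step under the conditional text embedding only (no classifier-free guidance),
# recording the latent trajectory for the later null-text optimization.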
129 | @torch.no_grad()
130 | def ddim_loop(self, latent):
131 | uncond_embeddings, cond_embeddings = self.context.chunk(2)
132 | all_latent = [latent]
133 | latent = latent.clone().detach()
134 | for i in range(NUM_DDIM_STEPS):
135 | t = self.model.scheduler.timesteps[len(self.model.scheduler.timesteps) - i - 1]
136 | noise_pred = self.get_noise_pred_single(latent, t, cond_embeddings)
137 | latent = self.next_step(noise_pred, t, latent)
138 | all_latent.append(latent)
139 | return all_latent
140 |
141 | @property
142 | def scheduler(self):
143 | return self.model.scheduler
144 |
145 | @torch.no_grad()
146 | def ddim_inversion(self, image):
147 | latent = self.image2latent(image)
148 | image_rec = self.latent2image(latent)
149 | ddim_latents = self.ddim_loop(latent)
150 | return image_rec, ddim_latents
151 |
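# Null-text optimization: for each DDIM step, the unconditional ("null") embedding is
# optimized with Adam so that the guided prediction maps the current latent onto the
# corresponding latent of the inversion trajectory; it stops early once the MSE drops
# below epsilon + i * 2e-5 and returns one optimized embedding per step.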
152 | def null_optimization(self, latents, num_inner_steps, epsilon):
153 | uncond_embeddings, cond_embeddings = self.context.chunk(2)
154 | uncond_embeddings_list = []
155 | latent_cur = latents[-1]
156 | bar = tqdm(total=num_inner_steps * NUM_DDIM_STEPS)
157 | for i in range(NUM_DDIM_STEPS):
158 | uncond_embeddings = uncond_embeddings.clone().detach()
159 | uncond_embeddings.requires_grad = True
160 | optimizer = Adam([uncond_embeddings], lr=1e-2 * (1. - i / 100.))
161 | latent_prev = latents[len(latents) - i - 2]
162 | t = self.model.scheduler.timesteps[i]
163 | with torch.no_grad():
164 | noise_pred_cond = self.get_noise_pred_single(latent_cur, t, cond_embeddings)
165 | for j in range(num_inner_steps):
166 | noise_pred_uncond = self.get_noise_pred_single(latent_cur, t, uncond_embeddings)
167 | noise_pred = noise_pred_uncond + GUIDANCE_SCALE * (noise_pred_cond - noise_pred_uncond)
168 | latents_prev_rec = self.prev_step(noise_pred, t, latent_cur)
169 | loss = nnf.mse_loss(latents_prev_rec, latent_prev)
170 | optimizer.zero_grad()
171 | loss.backward()
172 | optimizer.step()
173 | loss_item = loss.item()
174 | bar.update()
175 | if loss_item < epsilon + i * 2e-5:
176 | break
177 | for j in range(j + 1, num_inner_steps):
178 | bar.update()
179 | uncond_embeddings_list.append(uncond_embeddings[:1].detach())
180 | with torch.no_grad():
181 | context = torch.cat([uncond_embeddings, cond_embeddings])
182 | latent_cur = self.get_noise_pred(latent_cur, t, False, context)
183 | bar.close()
184 | return uncond_embeddings_list
185 |
186 | def invert(self, image_path: str, prompt: str, offsets=(0, 0, 0, 0), num_inner_steps=10, early_stop_epsilon=1e-5,
187 | verbose=False):
188 | self.init_prompt(prompt)
189 | ptp_utils.register_attention_control(self.model, None)
190 | image_gt = load_512(image_path, *offsets)
191 | if verbose:
192 | print("DDIM inversion...")
193 | image_rec, ddim_latents = self.ddim_inversion(image_gt)
194 | if verbose:
195 | print("Null-text optimization...")
196 | uncond_embeddings = self.null_optimization(ddim_latents, num_inner_steps, early_stop_epsilon)
197 | return (image_gt, image_rec), ddim_latents[-1], uncond_embeddings
198 |
199 | def __init__(self, model):
200 | # scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
201 | # set_alpha_to_one=False)
202 | self.model = model
203 | self.tokenizer = self.model.tokenizer
204 | self.model.scheduler.set_timesteps(NUM_DDIM_STEPS)
205 | self.prompt = None
206 | self.context = None
--------------------------------------------------------------------------------
/p2p.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union, Tuple, List, Dict
2 | from tqdm import tqdm
3 | import torch
4 | import torch.nn.functional as nnf
5 | import numpy as np
6 | from PIL import Image
7 | import abc
8 |
9 | import FADING_util.ptp_utils as ptp_utils
10 | import FADING_util.seq_aligner as seq_aligner
11 |
12 | LOW_RESOURCE = False
13 | NUM_DDIM_STEPS = 50
14 | GUIDANCE_SCALE = 7.5
15 | MAX_NUM_WORDS = 77
16 | device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
17 | #%%
18 | #% Prompt-to-Prompt code
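# LocalBlend restricts the edit to regions attended to by the given blend words: it
# aggregates their 16x16 cross-attention maps, thresholds them into a binary mask, and
# after each diffusion step (past start_blend) copies the source-branch latent back
# everywhere outside that mask.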
19 | class LocalBlend:
20 |
21 | def get_mask(self, x_t, maps, alpha, use_pool):
22 | k = 1
23 | maps = (maps * alpha).sum(-1).mean(1)
24 | if use_pool:
25 | maps = nnf.max_pool2d(maps, (k * 2 + 1, k * 2 + 1), (1, 1), padding=(k, k))
26 | mask = nnf.interpolate(maps, size=(x_t.shape[2:]))
27 | mask = mask / mask.max(2, keepdims=True)[0].max(3, keepdims=True)[0]
28 | mask = mask.gt(self.th[1 - int(use_pool)])
29 | mask = mask[:1] + mask
30 | return mask
31 |
32 | def __call__(self, x_t, attention_store):
33 | self.counter += 1
34 | if self.counter > self.start_blend:
35 |
36 | maps = attention_store["down_cross"][2:4] + attention_store["up_cross"][:3]
37 | maps = [item.reshape(self.alpha_layers.shape[0], -1, 1, 16, 16, MAX_NUM_WORDS) for item in maps]
38 | maps = torch.cat(maps, dim=1)
39 | mask = self.get_mask(x_t, maps, self.alpha_layers, True)
40 | if self.substruct_layers is not None:
41 | maps_sub = ~self.get_mask(x_t, maps, self.substruct_layers, False)
42 | mask = mask * maps_sub
43 | mask = mask.float()
44 | x_t = x_t[:1] + mask * (x_t - x_t[:1])
45 | return x_t
46 |
47 | def __init__(self, prompts: List[str], words: List[List[str]], tokenizer, substruct_words=None, start_blend=0.2,
48 | th=(.3, .3)):
49 | alpha_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
50 | for i, (prompt, words_) in enumerate(zip(prompts, words)):
51 | if type(words_) is str:
52 | words_ = [words_]
53 | for word in words_:
54 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
55 | alpha_layers[i, :, :, :, :, ind] = 1
56 |
57 | if substruct_words is not None:
58 | substruct_layers = torch.zeros(len(prompts), 1, 1, 1, 1, MAX_NUM_WORDS)
59 | for i, (prompt, words_) in enumerate(zip(prompts, substruct_words)):
60 | if type(words_) is str:
61 | words_ = [words_]
62 | for word in words_:
63 | ind = ptp_utils.get_word_inds(prompt, word, tokenizer)
64 | substruct_layers[i, :, :, :, :, ind] = 1
65 | self.substruct_layers = substruct_layers.to(device)
66 | else:
67 | self.substruct_layers = None
68 | self.alpha_layers = alpha_layers.to(device)
69 | self.start_blend = int(start_blend * NUM_DDIM_STEPS)
70 | self.counter = 0
71 | self.th = th
72 |
73 |
74 | class EmptyControl:
75 |
76 | def step_callback(self, x_t):
77 | return x_t
78 |
79 | def between_steps(self):
80 | return
81 |
82 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
83 | return attn
84 |
85 |
86 | class AttentionControl(abc.ABC):
87 |
88 | def step_callback(self, x_t):
89 | return x_t
90 |
91 | def between_steps(self):
92 | return
93 |
94 | @property
95 | def num_uncond_att_layers(self):
96 | return self.num_att_layers if LOW_RESOURCE else 0
97 |
98 | @abc.abstractmethod
99 | def forward(self, attn, is_cross: bool, place_in_unet: str):
100 | raise NotImplementedError
101 |
102 | def __call__(self, attn, is_cross: bool, place_in_unet: str):
103 | if self.cur_att_layer >= self.num_uncond_att_layers:
104 | if LOW_RESOURCE:
105 | attn = self.forward(attn, is_cross, place_in_unet)
106 | else:
107 | h = attn.shape[0]
108 | attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
109 | self.cur_att_layer += 1
110 | if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
111 | self.cur_att_layer = 0
112 | self.cur_step += 1
113 | self.between_steps()
114 | return attn
115 |
116 | def reset(self):
117 | self.cur_step = 0
118 | self.cur_att_layer = 0
119 |
120 | def __init__(self):
121 | self.cur_step = 0
122 | self.num_att_layers = -1
123 | self.cur_att_layer = 0
124 |
125 |
126 | class SpatialReplace(EmptyControl):
127 |
128 | def step_callback(self, x_t):
129 | if self.cur_step < self.stop_inject:
130 | b = x_t.shape[0]
131 | x_t = x_t[:1].expand(b, *x_t.shape[1:])
132 | return x_t
133 |
134 | def __init__(self, stop_inject: float):
135 | super(SpatialReplace, self).__init__()
136 | self.stop_inject = int((1 - stop_inject) * NUM_DDIM_STEPS)
137 |
138 |
139 | class AttentionStore(AttentionControl):
140 |
141 | @staticmethod
142 | def get_empty_store():
143 | return {"down_cross": [], "mid_cross": [], "up_cross": [],
144 | "down_self": [], "mid_self": [], "up_self": []}
145 |
146 | def forward(self, attn, is_cross: bool, place_in_unet: str):
147 | key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
148 | if attn.shape[1] <= 32 ** 2: # avoid memory overhead
149 | self.step_store[key].append(attn)
150 | return attn
151 |
152 | def between_steps(self):
153 | if len(self.attention_store) == 0:
154 | self.attention_store = self.step_store
155 | else:
156 | for key in self.attention_store:
157 | for i in range(len(self.attention_store[key])):
158 | self.attention_store[key][i] += self.step_store[key][i]
159 | self.step_store = self.get_empty_store()
160 |
161 | def get_average_attention(self):
162 | average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] for key in
163 | self.attention_store}
164 | return average_attention
165 |
166 | def reset(self):
167 | super(AttentionStore, self).reset()
168 | self.step_store = self.get_empty_store()
169 | self.attention_store = {}
170 |
171 | def __init__(self):
172 | super(AttentionStore, self).__init__()
173 | self.step_store = self.get_empty_store()
174 | self.attention_store = {}
175 |
176 |
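# Base class for Prompt-to-Prompt edits: records attention like AttentionStore, replaces
# the edited prompt's self-attention with the source prompt's within the self_replace_steps
# interval, and blends cross-attention according to cross_replace_alpha; step_callback
# applies the optional LocalBlend mask to the latents.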
177 | class AttentionControlEdit(AttentionStore, abc.ABC):
178 |
179 | def step_callback(self, x_t):
180 | if self.local_blend is not None:
181 | x_t = self.local_blend(x_t, self.attention_store)
182 | return x_t
183 |
184 | def replace_self_attention(self, attn_base, att_replace, place_in_unet):
185 | if att_replace.shape[2] <= 32 ** 2:
186 | attn_base = attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape)
187 | return attn_base
188 | else:
189 | return att_replace
190 |
191 | @abc.abstractmethod
192 | def replace_cross_attention(self, attn_base, att_replace):
193 | raise NotImplementedError
194 |
195 | def forward(self, attn, is_cross: bool, place_in_unet: str):
196 | super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet)
197 | if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]):
198 | h = attn.shape[0] // (self.batch_size)
199 | attn = attn.reshape(self.batch_size, h, *attn.shape[1:])
200 | attn_base, attn_repalce = attn[0], attn[1:]
201 | if is_cross:
202 | alpha_words = self.cross_replace_alpha[self.cur_step]
203 | attn_repalce_new = self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + (
204 | 1 - alpha_words) * attn_repalce
205 | attn[1:] = attn_repalce_new
206 | else:
207 | attn[1:] = self.replace_self_attention(attn_base, attn_repalce, place_in_unet)
208 | attn = attn.reshape(self.batch_size * h, *attn.shape[2:])
209 | return attn
210 |
211 | def __init__(self, prompts, num_steps: int,
212 | cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]],
213 | self_replace_steps: Union[float, Tuple[float, float]],
214 | tokenizer,
215 | local_blend: Optional[LocalBlend]):
216 | super(AttentionControlEdit, self).__init__()
217 | self.batch_size = len(prompts)
218 | self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha(prompts, num_steps, cross_replace_steps,
219 | tokenizer).to(device)
220 | if type(self_replace_steps) is float:
221 | self_replace_steps = 0, self_replace_steps
222 | self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1])
223 | self.local_blend = local_blend
224 |
225 |
226 | class AttentionReplace(AttentionControlEdit):
227 |
228 | def replace_cross_attention(self, attn_base, att_replace):
229 | return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper)
230 |
231 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
232 | tokenizer,
233 | local_blend: Optional[LocalBlend] = None):
234 | super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, tokenizer, local_blend)
235 | self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device)
236 |
237 |
238 | class AttentionRefine(AttentionControlEdit):
239 |
240 | def replace_cross_attention(self, attn_base, att_replace):
241 | attn_base_replace = attn_base[:, :, self.mapper].permute(2, 0, 1, 3)
242 | attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas)
243 | # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True)
244 | return attn_replace
245 |
246 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
247 | tokenizer,
248 | local_blend: Optional[LocalBlend] = None):
249 | super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, tokenizer, local_blend)
250 | self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer)
251 | self.mapper, alphas = self.mapper.to(device), alphas.to(device)
252 | self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1])
253 |
254 |
255 | class AttentionReweight(AttentionControlEdit):
256 |
257 | def replace_cross_attention(self, attn_base, att_replace):
258 | if self.prev_controller is not None:
259 | attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace)
260 | attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :]
261 | # attn_replace = attn_replace / attn_replace.sum(-1, keepdims=True)
262 | return attn_replace
263 |
264 | def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float,
265 | tokenizer, equalizer,
266 | local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None):
267 | super(AttentionReweight, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps,
268 | tokenizer,
269 | local_blend)
270 | self.equalizer = equalizer.to(device)
271 | self.prev_controller = controller
272 |
273 |
274 | def get_equalizer(text: str, word_select: Union[int, Tuple[int, ...]], values: Union[List[float],Tuple[float, ...]],
275 | tokenizer):
276 | if type(word_select) is int or type(word_select) is str:
277 | word_select = (word_select,)
278 | equalizer = torch.ones(1, 77)
279 |
280 | for word, val in zip(word_select, values):
281 | inds = ptp_utils.get_word_inds(text, word, tokenizer)
282 | equalizer[:, inds] = val
283 | return equalizer
284 |
285 |
286 |
287 |
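# Builds the controller used by age_editing.py: AttentionReplace when is_replace_controller
# is set (word-for-word prompt swap; prompts must have the same number of words),
# AttentionRefine otherwise; optionally wrapped in AttentionReweight when equilizer_params
# is given and restricted by a LocalBlend mask when blend_words is given.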
288 | def make_controller(prompts: List[str], is_replace_controller: bool,
289 | cross_replace_steps: Dict[str, float],
290 | self_replace_steps: float,
291 | tokenizer,
292 | blend_words=None, equilizer_params=None) -> AttentionControlEdit:
293 | if blend_words is None:
294 | lb = None
295 | else:
296 | lb = LocalBlend(prompts, blend_words, tokenizer=tokenizer)
297 | if is_replace_controller:
298 | controller = AttentionReplace(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
299 | self_replace_steps=self_replace_steps,
300 | tokenizer=tokenizer,
301 | local_blend=lb)
302 | else:
303 | controller = AttentionRefine(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
304 | self_replace_steps=self_replace_steps,
305 | tokenizer=tokenizer,
306 | local_blend=lb)
307 | if equilizer_params is not None:
308 | eq = get_equalizer(prompts[1], equilizer_params["words"], equilizer_params["values"], tokenizer=tokenizer)
309 | controller = AttentionReweight(prompts, NUM_DDIM_STEPS, cross_replace_steps=cross_replace_steps,
310 | self_replace_steps=self_replace_steps,
311 | tokenizer=tokenizer,
312 | equalizer=eq, local_blend=lb,
313 | controller=controller)
314 | return controller
315 |
316 |
317 | @torch.no_grad()
318 | def p2p_text2image(
319 | model,
320 | prompt: List[str],
321 | controller,
322 | num_inference_steps: int = 50,
323 | guidance_scale: Optional[float] = 7.5,
324 | generator: Optional[torch.Generator] = None,
325 | latent: Optional[torch.FloatTensor] = None,
326 | uncond_embeddings=None,
327 | start_time=50,
328 | return_type='image',
329 | height=512, width=512
330 | ):
331 | tokenizer = model.tokenizer
332 | batch_size = len(prompt)
333 | ptp_utils.register_attention_control(model, controller)
334 |
335 | text_input = tokenizer(
336 | prompt,
337 | padding="max_length",
338 | max_length=tokenizer.model_max_length,
339 | truncation=True,
340 | return_tensors="pt",
341 | )
342 | text_embeddings = model.text_encoder(text_input.input_ids.to(model.device))[0]
343 | max_length = text_input.input_ids.shape[-1]
344 | if uncond_embeddings is None:
345 | uncond_input = tokenizer(
346 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
347 | )
348 | uncond_embeddings_ = model.text_encoder(uncond_input.input_ids.to(model.device))[0]
349 | else:
350 | uncond_embeddings_ = None
351 |
352 | latent, latents = ptp_utils.init_latent(latent, model, height, width, generator, batch_size)
353 | model.scheduler.set_timesteps(num_inference_steps)
354 |
355 | for i, t in enumerate(tqdm(model.scheduler.timesteps[-start_time:])):
356 | if uncond_embeddings_ is None:
357 | context = torch.cat([uncond_embeddings[i].expand(*text_embeddings.shape), text_embeddings])
358 | # context = torch.cat([uncond_embeddings[i], text_embeddings])  # I changed this
359 | else:
360 | context = torch.cat([uncond_embeddings_, text_embeddings])
361 | latents = ptp_utils.diffusion_step(model, controller, latents, context, t, guidance_scale, low_resource=False)
362 |
363 | if return_type == 'image':
364 | image = ptp_utils.latent2image(model.vae, latents)
365 | else:
366 | image = latents
367 |
368 | return image, latent
369 |
370 |
371 |
372 | # def run_and_display(my_ldm_stable, prompts, controller, latent=None, run_baseline=False, generator=None, uncond_embeddings=None,
373 | # verbose=True):
374 | # if run_baseline:
375 | # print("w.o. prompt-to-prompt")
376 | # images, latent = run_and_display(my_ldm_stable, prompts, EmptyControl(), latent=latent, run_baseline=False,
377 | # generator=generator)
378 | # print("with prompt-to-prompt")
379 | # images, x_t = p2p_text2image(my_ldm_stable, prompts, controller, latent=latent,
380 | # num_inference_steps=NUM_DDIM_STEPS, guidance_scale=GUIDANCE_SCALE,
381 | # generator=generator, uncond_embeddings=uncond_embeddings)
382 | # if verbose:
383 | # images = ptp_utils.view_images(images)
384 | #
385 | # return images, x_t
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 | def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
395 | out = []
396 | attention_maps = attention_store.get_average_attention()
397 | num_pixels = res ** 2
398 | for location in from_where:
399 | for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
400 | if item.shape[1] == num_pixels:
401 | cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
402 | out.append(cross_maps)
403 | out = torch.cat(out, dim=0)
404 | out = out.sum(0) / out.shape[0]
405 | return out.cpu()
406 |
407 | def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
408 | tokens = tokenizer.encode(prompts[select])
409 | decoder = tokenizer.decode
410 | attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
411 | images = []
412 | for i in range(len(tokens)):
413 | image = attention_maps[:, :, i]
414 | image = 255 * image / image.max()
415 | image = image.unsqueeze(-1).expand(*image.shape, 3)
416 | image = image.numpy().astype(np.uint8)
417 | image = np.array(Image.fromarray(image).resize((256, 256)))
418 | image = ptp_utils.text_under_image(image, decoder(int(tokens[i])))
419 | images.append(image)
420 | image_ = ptp_utils.view_images(np.stack(images, axis=0))
421 | return image_
422 |
423 |
424 | def show_self_attention_comp(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
425 | max_com=10, select: int = 0):
426 | attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).numpy().reshape(
427 | (res ** 2, res ** 2))
428 | u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
429 | images = []
430 | for i in range(max_com):
431 | image = vh[i].reshape(res, res)
432 | image = image - image.min()
433 | image = 255 * image / image.max()
434 | image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
435 | image = Image.fromarray(image).resize((256, 256))
436 | image = np.array(image)
437 | images.append(image)
438 | ptp_utils.view_images(np.concatenate(images, axis=1))
439 |
440 |
441 | #% null inversion
442 |
--------------------------------------------------------------------------------
/specialize.py:
--------------------------------------------------------------------------------
1 | # code adapted from https://github.com/huggingface/diffusers/blob/v0.10.0/examples/dreambooth/train_dreambooth.py
2 |
3 | import os
4 | import argparse
5 | import json
6 | import hashlib
7 | import itertools
8 | import logging
9 | import math
10 | import numpy as np
11 | from PIL import Image
12 |
13 | try:
14 | from torchvision.transforms import InterpolationMode
15 | BICUBIC = InterpolationMode.BICUBIC
16 | except ImportError:
17 | BICUBIC = Image.BICUBIC
18 | import warnings
19 | from pathlib import Path
20 | from typing import List, Optional, Tuple, Union
21 |
22 | import torch
23 | import torch.nn.functional as F
24 | import torch.utils.checkpoint
25 | from torch.utils.data import Dataset
26 |
27 | import diffusers
28 | import transformers
29 | from accelerate import Accelerator
30 | from accelerate.logging import get_logger
31 | from accelerate.utils import set_seed
32 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
33 | from diffusers.optimization import get_scheduler
34 | # from diffusers.utils import check_min_version
35 | from diffusers.utils.import_utils import is_xformers_available
36 | from huggingface_hub import HfFolder, Repository, whoami
37 |
38 | from torchvision import transforms
39 | from tqdm.auto import tqdm
40 | from transformers import AutoTokenizer, PretrainedConfig
41 |
42 | #%%
43 |
44 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
45 | # check_min_version("0.10.0.dev0")
46 |
47 | logger = get_logger(__name__)
48 |
49 |
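    | # Select the text-encoder class matching the pretrained checkpoint by inspecting its
    | # text_encoder config (CLIPTextModel for Stable Diffusion, Roberta-based for AltDiffusion).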
50 | def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
51 | text_encoder_config = PretrainedConfig.from_pretrained(
52 | pretrained_model_name_or_path,
53 | subfolder="text_encoder",
54 | revision=revision,
55 | )
56 | model_class = text_encoder_config.architectures[0]
57 |
58 | if model_class == "CLIPTextModel":
59 | from transformers import CLIPTextModel
60 |
61 | return CLIPTextModel
62 | elif model_class == "RobertaSeriesModelWithTransformation":
63 | from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
64 |
65 | return RobertaSeriesModelWithTransformation
66 | else:
67 | raise ValueError(f"{model_class} is not supported.")
68 |
69 |
70 | def parse_args(input_args=None):
71 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
72 | parser.add_argument(
73 | "--pretrained_model_name_or_path",
74 | type=str,
75 | default="runwayml/stable-diffusion-v1-5",
76 | required=False,
77 | help="Path to pretrained model or model identifier from huggingface.co/models.",
78 | )
79 | parser.add_argument(
80 | "--instance_data_dir",
81 | type=str,
82 | default=None,
83 | required=True,
84 |         help="A folder containing the training images.",
85 | )
86 | parser.add_argument(
87 | "--instance_age_path",
88 | type=str,
89 | default=None,
90 | required=True,
91 |         help="Path to a .npy file with the initial ages of the training images (filename, age pairs).",
92 | )
93 | parser.add_argument(
94 | "--instance_prompt",
95 | type=str,
96 | default="photo of a person",
97 | required=False,
98 |         help="The prompt describing the instance images, e.g. 'photo of a person'.",
99 | )
100 | parser.add_argument(
101 | "--output_dir",
102 | type=str,
103 | default=None,
104 | required=True,
105 | help="The output directory where the model predictions and checkpoints will be written.",
106 | )
107 | parser.add_argument(
108 | "--finetune_mode",
109 | type=str,
110 | default='finetune_double_prompt',
111 | required=False,
112 |         help="Specialization mode: 'finetune_double_prompt' or 'finetune_single_prompt'.",
113 | )
114 |
115 |
116 |
117 |     # The remaining arguments follow the upstream DreamBooth training script.
118 | parser.add_argument(
119 | "--revision",
120 | type=str,
121 | default=None,
122 | required=False,
123 | help=(
124 | "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
125 | " float32 precision."
126 | ),
127 | )
128 | parser.add_argument(
129 | "--tokenizer_name",
130 | type=str,
131 | default=None,
132 | help="Pretrained tokenizer name or path if not the same as model_name",
133 | )
134 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
135 | parser.add_argument(
136 | "--resolution",
137 | type=int,
138 | default=512,
139 | help=(
140 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
141 | " resolution"
142 | ),
143 | )
144 | parser.add_argument(
145 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
146 | )
147 | parser.add_argument(
148 | "--train_text_encoder",
149 | action="store_true",
150 | help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
151 | )
152 | parser.add_argument(
153 | "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
154 | )
155 | parser.add_argument(
156 | "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
157 | )
158 | parser.add_argument("--num_train_epochs", type=int, default=1)
159 | parser.add_argument(
160 | "--max_train_steps",
161 | type=int,
162 | default=150,
163 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
164 | )
165 |
166 | parser.add_argument(
167 | "--checkpointing_steps",
168 | type=int,
169 | default=500,
170 | help=(
171 | "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
172 | " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
173 | " training using `--resume_from_checkpoint`."
174 | ),
175 | )
176 | parser.add_argument(
177 | "--resume_from_checkpoint",
178 | type=str,
179 | default=None,
180 | help=(
181 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
182 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
183 | ),
184 | )
185 | parser.add_argument(
186 | "--gradient_accumulation_steps",
187 | type=int,
188 | default=1,
189 | help="Number of updates steps to accumulate before performing a backward/update pass.",
190 | )
191 | parser.add_argument(
192 | "--gradient_checkpointing",
193 | action="store_true",
194 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
195 | )
196 | parser.add_argument(
197 | "--learning_rate",
198 | type=float,
199 | default=5e-6,
200 | help="Initial learning rate (after the potential warmup period) to use.",
201 | )
202 | parser.add_argument(
203 | "--scale_lr",
204 | action="store_true",
205 | default=False,
206 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
207 | )
208 | parser.add_argument(
209 | "--lr_scheduler",
210 | type=str,
211 | default="constant",
212 | help=(
213 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
214 | ' "constant", "constant_with_warmup"]'
215 | ),
216 | )
217 | parser.add_argument(
218 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
219 | )
220 | parser.add_argument(
221 | "--lr_num_cycles",
222 | type=int,
223 | default=1,
224 | help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
225 | )
226 | parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
227 | parser.add_argument(
228 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
229 | )
230 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
231 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
232 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
233 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
234 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
235 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
236 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
237 | parser.add_argument(
238 | "--hub_model_id",
239 | type=str,
240 | default=None,
241 | help="The name of the repository to keep in sync with the local `output_dir`.",
242 | )
243 | parser.add_argument(
244 | "--logging_dir",
245 | type=str,
246 | default="logs",
247 | help=(
248 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
249 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
250 | ),
251 | )
252 | parser.add_argument(
253 | "--allow_tf32",
254 | action="store_true",
255 | help=(
256 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
257 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
258 | ),
259 | )
260 | parser.add_argument(
261 | "--report_to",
262 | type=str,
263 | default="tensorboard",
264 | help=(
265 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
266 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
267 | ),
268 | )
269 | parser.add_argument(
270 | "--mixed_precision",
271 | type=str,
272 | default=None,
273 | choices=["no", "fp16", "bf16"],
274 | help=(
275 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
276 |             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
277 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
278 | ),
279 | )
280 | parser.add_argument(
281 | "--prior_generation_precision",
282 | type=str,
283 | default=None,
284 | choices=["no", "fp32", "fp16", "bf16"],
285 | help=(
286 | "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
287 |             " 1.10 and an Nvidia Ampere GPU. Defaults to fp16 if a GPU is available, else fp32."
288 | ),
289 | )
290 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
291 | parser.add_argument(
292 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
293 | )
294 |
295 | if input_args is not None:
296 | args = parser.parse_args(input_args)
297 | else:
298 | args = parser.parse_args()
299 |
300 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
301 | if env_local_rank != -1 and env_local_rank != args.local_rank:
302 | args.local_rank = env_local_rank
303 |
304 | return args
305 |
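    | # Illustrative invocation (paths match the commented example at the bottom of this file):
    | #   accelerate launch specialize.py \
    | #       --instance_data_dir specialization_data/training_images \
    | #       --instance_age_path specialization_data/training_ages.npy \
    | #       --output_dir specialized_models/tmp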
306 |
307 |
308 | #%%
309 | class DreamBoothDatasetAge(Dataset):
310 | """
311 |     For each training image, return both (img, tokens of 'photo of a xx year old person')
312 |     and (img, tokens of 'photo of a person'), where xx is the image's annotated age.
313 | """
314 | def __init__(
315 | self,
316 | instance_data_root,
317 | instance_age_path, # numpy array that stores the age of training images
318 | tokenizer,
319 | instance_prompt="photo of a person",
320 | size=512,
321 | center_crop=False,
322 | ):
323 |
324 | self.size = size
325 | self.center_crop = center_crop
326 | self.tokenizer = tokenizer
327 |
328 | self.instance_data_root = Path(instance_data_root)
329 | if not self.instance_data_root.exists():
330 |             raise ValueError("Instance images root doesn't exist.")
331 |
332 | if self.instance_data_root.is_file():
333 | self.instance_images_path = [Path(instance_data_root)]
334 | else:
335 | self.instance_images_path = [os.path.join(instance_data_root,filename) for filename in os.listdir(instance_data_root)]
336 |
337 |
338 | self.num_instance_images = len(self.instance_images_path)
339 | self.instance_prompt = instance_prompt
340 | self._length = self.num_instance_images
341 |
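    |         # Age annotations: an (N, 2) array with the image filename in column 0 and its
    |         # age in column 1 (see the lookup in __getitem__ below).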
342 | self.age_labels = np.load(instance_age_path)
343 |
344 |
345 | self.image_transforms = transforms.Compose(
346 | [
347 | transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
348 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
349 | transforms.ToTensor(),
350 | transforms.Normalize([0.5], [0.5]),
351 | ]
352 | )
353 |
354 |
355 | def __len__(self):
356 | return self._length
357 |
358 | def __getitem__(self, index):
359 | example = {}
360 |
361 | instance_image_path = self.instance_images_path[index % self.num_instance_images]
362 | instance_image = Image.open(instance_image_path)
363 | if not instance_image.mode == "RGB":
364 | instance_image = instance_image.convert("RGB")
365 |
366 | example["instance_images"] = self.image_transforms(instance_image)
367 |
368 | example["instance_prompt"] = self.instance_prompt
369 | example["instance_prompt_ids"] = self.tokenizer( # token(photo of a person)
370 | self.instance_prompt,
371 | truncation=True,
372 | padding="max_length",
373 | max_length=self.tokenizer.model_max_length,
374 | return_tensors="pt",
375 | ).input_ids
376 |
377 |         instance_image_name = os.path.basename(str(instance_image_path))  # handles both str and Path entries
378 | instance_image_age = int(self.age_labels[self.age_labels[:, 0] == instance_image_name][0, 1])
379 |
380 | example["instance_image_age"] = instance_image_age
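    |         # e.g. "photo of a person" -> "photo of a 25 year old person" (age read from the annotations)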
381 | instance_age_prompt = self.instance_prompt.replace(" a ", f" a {instance_image_age} year old ")
382 |
383 | example["instance_age_prompt"] = instance_age_prompt
384 | example["instance_age_prompt_ids"] = self.tokenizer( # token(photo of a xx year old person)
385 | instance_age_prompt,
386 | truncation=True,
387 | padding="max_length",
388 | max_length=self.tokenizer.model_max_length,
389 | return_tensors="pt",
390 | ).input_ids
391 |
392 | example["blank_prompt_ids"] = self.tokenizer( # token("")
393 | "",
394 | truncation=True,
395 | padding="max_length",
396 | max_length=self.tokenizer.model_max_length,
397 | return_tensors="pt",
398 | ).input_ids
399 |
400 |
401 | return example
402 | #%%
403 |
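    | # Each batch holds the same image twice: once with the plain instance prompt and once with the
    | # age-specific prompt ('finetune_double_prompt'), or twice with the age prompt ('finetune_single_prompt').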
404 | def collate_fn(examples, finetune_mode="finetune_double_prompt"):
405 |     if len(examples) != 1:
406 |         raise ValueError("Batch size can only be 1.")
407 |
408 | if finetune_mode == "finetune_double_prompt":
409 | input_ids = [example["instance_prompt_ids"] for example in examples]
410 | input_ids += [example["instance_age_prompt_ids"] for example in examples]
411 |
412 | elif finetune_mode == "finetune_single_prompt":
413 | input_ids = [example["instance_age_prompt_ids"] for example in examples]
414 | input_ids += [example["instance_age_prompt_ids"] for example in examples]
415 |
416 | # elif finetune_mode == "finetune_single_prompt_no_age":
417 | # input_ids = [example["instance_prompt_ids"] for example in examples]
418 | # input_ids += [example["instance_prompt_ids"] for example in examples]
419 | #
420 | # elif finetune_mode == "finetune_no_prompt":
421 | # input_ids = [example["blank_prompt_ids"] for example in examples]
422 | # input_ids += [example["blank_prompt_ids"] for example in examples]
423 |
424 | else:
425 | raise ValueError("invalid finetune_mode")
426 |
427 | pixel_values = [example["instance_images"] for example in examples]
428 | pixel_values += [example["instance_images"] for example in examples]
429 |
430 | pixel_values_ages = [example["instance_image_age"] for example in examples]
431 |
432 | pixel_values = torch.stack(pixel_values)
433 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
434 |
435 | input_ids = torch.cat(input_ids, dim=0)
436 |
437 | batch = {
438 | "input_ids": input_ids,
439 | "pixel_values": pixel_values,
440 |
441 | "pixel_values_ages": pixel_values_ages,
442 | }
443 |
444 | return batch
445 |
446 |
447 |
448 | def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
449 | if token is None:
450 | token = HfFolder.get_token()
451 | if organization is None:
452 | username = whoami(token)["name"]
453 | return f"{username}/{model_id}"
454 | else:
455 | return f"{organization}/{model_id}"
456 |
457 |
458 | #%%
459 | def main(args):
460 | argparse_dict = vars(args)
461 |
462 | argparse_json = json.dumps(argparse_dict, indent=4)
463 | print("model configuration:\n", argparse_json)
464 | os.makedirs(args.output_dir, exist_ok=True)
465 | with open(os.path.join(args.output_dir, "model_config.json"), "w") as outfile:
466 | outfile.write(argparse_json)
467 |
468 | logging_dir = Path(args.output_dir, args.logging_dir)
469 |
470 | accelerator = Accelerator(
471 | gradient_accumulation_steps=args.gradient_accumulation_steps,
472 | mixed_precision=args.mixed_precision,
473 | log_with=args.report_to,
474 | logging_dir=logging_dir,
475 | )
476 |
477 | # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
478 | # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
479 | # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
480 | if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
481 | raise ValueError(
482 | "Gradient accumulation is not supported when training the text encoder in distributed training. "
483 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
484 | )
485 |
486 | # Make one log on every process with the configuration for debugging.
487 | logging.basicConfig(
488 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
489 | datefmt="%m/%d/%Y %H:%M:%S",
490 | level=logging.INFO,
491 | )
492 | logger.info(accelerator.state, main_process_only=False)
493 | if accelerator.is_local_main_process:
494 | # datasets.utils.logging.set_verbosity_warning()
495 | transformers.utils.logging.set_verbosity_warning()
496 | diffusers.utils.logging.set_verbosity_info()
497 | else:
498 | # datasets.utils.logging.set_verbosity_error()
499 | transformers.utils.logging.set_verbosity_error()
500 | diffusers.utils.logging.set_verbosity_error()
501 |
502 | # If passed along, set the training seed now.
503 | if args.seed is not None:
504 | set_seed(args.seed)
505 |
506 | # Handle the repository creation
507 | if accelerator.is_main_process:
508 | if args.push_to_hub:
509 | if args.hub_model_id is None:
510 | repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
511 | else:
512 | repo_name = args.hub_model_id
513 | repo = Repository(args.output_dir, clone_from=repo_name)
514 |
515 | with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
516 | if "step_*" not in gitignore:
517 | gitignore.write("step_*\n")
518 | if "epoch_*" not in gitignore:
519 | gitignore.write("epoch_*\n")
520 | elif args.output_dir is not None:
521 | os.makedirs(args.output_dir, exist_ok=True)
522 |
523 | # Load the tokenizer
524 | if args.tokenizer_name:
525 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
526 | elif args.pretrained_model_name_or_path:
527 | tokenizer = AutoTokenizer.from_pretrained(
528 | args.pretrained_model_name_or_path,
529 | subfolder="tokenizer",
530 | revision=args.revision,
531 | use_fast=False,
532 | )
533 |
534 | # import correct text encoder class
535 | text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
536 |
537 | # Load scheduler and models
538 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
539 |
540 |
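    |     # Invert the DDPM forward process x_t = sqrt(alpha_bar_t)*x_0 + sqrt(1-alpha_bar_t)*eps
    |     # to recover the predicted clean latents: x_0 = (x_t - sqrt(1-alpha_bar_t)*eps) / sqrt(alpha_bar_t).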
541 | def pred_original_samples(samples: torch.FloatTensor,
542 | noise: torch.FloatTensor,
543 | timesteps: torch.IntTensor):
544 | sqrt_alpha_prod = noise_scheduler.alphas_cumprod[timesteps] ** 0.5
545 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
546 | while len(sqrt_alpha_prod.shape) < len(samples.shape):
547 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
548 |
549 | sqrt_one_minus_alpha_prod = (1 - noise_scheduler.alphas_cumprod[timesteps]) ** 0.5
550 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
551 | while len(sqrt_one_minus_alpha_prod.shape) < len(samples.shape):
552 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
553 |
554 | original_samples = (samples - sqrt_one_minus_alpha_prod * noise) / sqrt_alpha_prod
555 | return original_samples
556 |
557 | text_encoder = text_encoder_cls.from_pretrained(
558 | args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
559 | )
560 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
561 | unet = UNet2DConditionModel.from_pretrained(
562 | args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
563 | )
564 |
565 | vae.requires_grad_(False)
566 | if not args.train_text_encoder:
567 | text_encoder.requires_grad_(False)
568 |
569 | if args.enable_xformers_memory_efficient_attention:
570 | if is_xformers_available():
571 | unet.enable_xformers_memory_efficient_attention()
572 | else:
573 | raise ValueError("xformers is not available. Make sure it is installed correctly")
574 |
575 | if args.gradient_checkpointing:
576 | unet.enable_gradient_checkpointing()
577 | if args.train_text_encoder:
578 | text_encoder.gradient_checkpointing_enable()
579 |
580 | # Enable TF32 for faster training on Ampere GPUs,
581 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
582 | if args.allow_tf32:
583 | torch.backends.cuda.matmul.allow_tf32 = True
584 |
585 | if args.scale_lr:
586 | args.learning_rate = (
587 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
588 | )
589 |
590 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
591 | if args.use_8bit_adam:
592 | try:
593 | import bitsandbytes as bnb
594 | except ImportError:
595 | raise ImportError(
596 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
597 | )
598 |
599 | optimizer_class = bnb.optim.AdamW8bit
600 | else:
601 | optimizer_class = torch.optim.AdamW
602 |
603 | # Optimizer creation
604 | params_to_optimize = (
605 | itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
606 | )
607 | optimizer = optimizer_class(
608 | params_to_optimize,
609 | lr=args.learning_rate,
610 | betas=(args.adam_beta1, args.adam_beta2),
611 | weight_decay=args.adam_weight_decay,
612 | eps=args.adam_epsilon,
613 | )
614 |
615 |
616 | train_dataset = DreamBoothDatasetAge(
617 | instance_data_root = args.instance_data_dir,
618 | instance_age_path = args.instance_age_path,
619 | tokenizer = tokenizer,
620 | instance_prompt=args.instance_prompt,
621 | size=args.resolution,
622 | center_crop=args.center_crop,
623 | )
624 | train_dataloader = torch.utils.data.DataLoader(
625 | train_dataset,
626 | batch_size=args.train_batch_size,
627 | shuffle=False,
628 | collate_fn=lambda examples: collate_fn(examples, args.finetune_mode),
629 | num_workers=1,
630 | )
631 |
632 | # Scheduler and math around the number of training steps.
633 | overrode_max_train_steps = False
634 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
635 | if args.max_train_steps is None:
636 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
637 | overrode_max_train_steps = True
638 |
639 | lr_scheduler = get_scheduler(
640 | args.lr_scheduler,
641 | optimizer=optimizer,
642 | num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
643 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
644 | # num_cycles=args.lr_num_cycles,
645 | # power=args.lr_power,
646 | )
647 |
648 | # Prepare everything with our `accelerator`.
649 | if args.train_text_encoder:
650 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
651 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler
652 | )
653 | else:
654 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
655 | unet, optimizer, train_dataloader, lr_scheduler
656 | )
657 |
658 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
659 | # as these models are only used for inference, keeping weights in full precision is not required.
660 | weight_dtype = torch.float32
661 | if accelerator.mixed_precision == "fp16":
662 | weight_dtype = torch.float16
663 | elif accelerator.mixed_precision == "bf16":
664 | weight_dtype = torch.bfloat16
665 |
666 | # Move vae and text_encoder to device and cast to weight_dtype
667 | vae.to(accelerator.device, dtype=weight_dtype)
668 | if not args.train_text_encoder:
669 | text_encoder.to(accelerator.device, dtype=weight_dtype)
670 |
671 | low_precision_error_string = (
672 | "Please make sure to always have all model weights in full float32 precision when starting training - even if"
673 |         " doing mixed precision training. A copy of the weights should still be float32."
674 | )
675 |
676 | if unet.dtype != torch.float32:
677 | raise ValueError(f"Unet loaded as datatype {unet.dtype}. {low_precision_error_string}")
678 |
679 | if args.train_text_encoder and text_encoder.dtype != torch.float32:
680 | raise ValueError(f"Text encoder loaded as datatype {text_encoder.dtype}. {low_precision_error_string}")
681 |
682 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
683 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
684 | if overrode_max_train_steps:
685 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
686 | # Afterwards we recalculate our number of training epochs
687 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
688 |
689 | # We need to initialize the trackers we use, and also store our configuration.
690 | # The trackers initializes automatically on the main process.
691 | if accelerator.is_main_process:
692 | accelerator.init_trackers("dreambooth", config=vars(args))
693 |
694 | # Train!
695 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
696 |
697 | logger.info("***** Running training *****")
698 | logger.info(f" Num examples = {len(train_dataset)}")
699 | logger.info(f" Num batches each epoch = {len(train_dataloader)}")
700 | logger.info(f" Num Epochs = {args.num_train_epochs}")
701 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
702 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
703 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
704 | logger.info(f" Total optimization steps = {args.max_train_steps}")
705 | global_step = 0
706 | first_epoch = 0
707 |
708 | # Potentially load in the weights and states from a previous save
709 | if args.resume_from_checkpoint:
710 | if args.resume_from_checkpoint != "latest":
711 | path = os.path.basename(args.resume_from_checkpoint)
712 | else:
713 |             # Get the most recent checkpoint
714 | dirs = os.listdir(args.output_dir)
715 | dirs = [d for d in dirs if d.startswith("checkpoint")]
716 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
717 | path = dirs[-1]
718 | accelerator.print(f"Resuming from checkpoint {path}")
719 | accelerator.load_state(os.path.join(args.output_dir, path))
720 | global_step = int(path.split("-")[1])
721 |
722 | resume_global_step = global_step * args.gradient_accumulation_steps
723 | first_epoch = resume_global_step // num_update_steps_per_epoch
724 | resume_step = resume_global_step % num_update_steps_per_epoch
725 |
726 | # Only show the progress bar once on each machine.
727 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
728 | progress_bar.set_description("Steps")
729 |
730 | for epoch in range(first_epoch, args.num_train_epochs):
731 | unet.train()
732 | if args.train_text_encoder:
733 | text_encoder.train()
734 |
735 | for step, batch in enumerate(train_dataloader):
736 | # Skip steps until we reach the resumed step
737 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
738 | if step % args.gradient_accumulation_steps == 0:
739 | progress_bar.update(1)
740 | continue
741 |
742 | with accelerator.accumulate(unet):
743 | # Convert images to latent space
744 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
745 | latents = latents * 0.18215
746 |
747 |                 # Sample a single noise map and share it across the batch, so both
748 |                 # prompt-conditioned copies of the image are denoised from the same noise
749 |                 bsz = latents.shape[0]
750 |                 noise_single = torch.randn_like(latents[0])
751 |                 noise = torch.cat([noise_single.unsqueeze(0)] * bsz, dim=0)
753 |
754 | # Sample a random timestep for each image
755 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
756 | timesteps = timesteps.long()
757 |
758 | # Add noise to the latents according to the noise magnitude at each timestep
759 | # (this is the forward diffusion process)
760 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
761 |
762 | # Get the text embedding for conditioning
763 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
764 |
765 | # Predict the noise residual
766 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
767 |
768 | # Get the target for loss depending on the prediction type
769 | if noise_scheduler.config.prediction_type == "epsilon":
770 | target = noise
771 | elif noise_scheduler.config.prediction_type == "v_prediction":
772 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
773 | else:
774 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
775 |
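    |                 # The reconstructions below are computed from the current prediction but are
    |                 # not used by the MSE loss further down.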
776 | pred_denoised_latents = pred_original_samples(samples = noisy_latents,
777 | noise = model_pred,
778 | timesteps = timesteps)
779 |
780 | # noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
781 | pred_noisy_latents = noise_scheduler.add_noise(latents, model_pred.to(dtype=weight_dtype), timesteps)
782 |
783 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
784 | instance_loss = loss
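    |                 # Backward pass and optimizer update (standard pattern from the upstream
    |                 # train_dreambooth.py this script adapts).
    |                 accelerator.backward(loss)
    |                 if accelerator.sync_gradients:
    |                     params_to_clip = (
    |                         itertools.chain(unet.parameters(), text_encoder.parameters())
    |                         if args.train_text_encoder
    |                         else unet.parameters()
    |                     )
    |                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
    |                 optimizer.step()
    |                 lr_scheduler.step()
    |                 optimizer.zero_grad()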
785 |
786 | # Checks if the accelerator has performed an optimization step behind the scenes
787 | if accelerator.sync_gradients:
788 | progress_bar.update(1)
789 | global_step += 1
790 |
791 | if global_step % args.checkpointing_steps == 0:
792 | if accelerator.is_main_process:
793 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
794 | accelerator.save_state(save_path)
795 | logger.info(f"Saved state to {save_path}")
796 |
797 | logs = {"loss": loss.detach().item(),
798 | "instance_loss": instance_loss.detach().item(),
799 | "lr": lr_scheduler.get_last_lr()[0]}
800 |
801 | progress_bar.set_postfix(**logs)
802 | accelerator.log(logs, step=global_step)
803 |
804 | if global_step >= args.max_train_steps:
805 | break
806 |
807 | # Create the pipeline using using the trained modules and save it.
808 | accelerator.wait_for_everyone()
809 | if accelerator.is_main_process:
810 | pipeline = DiffusionPipeline.from_pretrained(
811 | args.pretrained_model_name_or_path,
812 | unet=accelerator.unwrap_model(unet),
813 | text_encoder=accelerator.unwrap_model(text_encoder),
814 | revision=args.revision,
815 | )
816 | pipeline.save_pretrained(args.output_dir)
817 |
818 | if args.push_to_hub:
819 | repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
820 |
821 | accelerator.end_training()
822 |
823 | if __name__ == "__main__":
824 | args = parse_args()
825 |
826 | # args.instance_data_dir = 'specialization_data/training_images'
827 | # args.instance_age_path = 'specialization_data/training_ages.npy'
828 | #
829 | # args.output_dir = 'specialized_models/tmp'
830 |
831 | main(args)
832 |
833 |
834 |
835 |
836 |
--------------------------------------------------------------------------------