Text-to-image generative models have become a prominent and powerful tool that excels at generating
high-resolution, realistic images. However, guiding the generative process of these models with detailed
conditioning that reflects style and/or structure information remains an open problem. In this paper, we
present LoRAdapter, an approach that unifies style and structure conditioning under the same formulation
using a novel conditional LoRA block that enables zero-shot control. LoRAdapter is an efficient, powerful,
and architecture-agnostic approach to conditioning text-to-image diffusion models that enables fine-grained
control during generation and outperforms recent state-of-the-art approaches.

How it works

Following the standard LoRA method, we keep the original weight matrix W_0 frozen and add two new
trainable weight matrices A and B for each layer i that we want to adapt. Usually, A and B would be
trained on a small dataset to capture a specific style or subject, resulting in an adapter that is fixed
at inference time. However, we propose to dynamically apply a transformation φ to the intermediate
embedding produced by the first LoRA matrix A. In practice, we implement φ as an affine transformation
with scale and shift parameters γ and β, respectively, which are predicted by a mapping network that
depends on the conditioning c.
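
The sketch below illustrates one such conditional LoRA block in PyTorch. It is a minimal illustration, not
the implementation from this repository: the names CondLoRALinear and MappingNetwork, the two-layer mapping
network, and passing the conditioning c directly to forward are assumptions made for clarity (in the code
below, the mapped conditioning is instead distributed to the LoRA layers through a DataProvider via
dp.set_batch).

import torch
from torch import nn


class MappingNetwork(nn.Module):
    """Predicts the scale (gamma) and shift (beta) of phi from the conditioning c."""

    def __init__(self, c_dim: int, rank: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(c_dim, 2 * rank),
            nn.SiLU(),
            nn.Linear(2 * rank, 2 * rank),
        )

    def forward(self, c: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        gamma, beta = self.net(c).chunk(2, dim=-1)  # each of shape (batch, rank)
        return gamma, beta


class CondLoRALinear(nn.Module):
    """Frozen base layer W_0 plus a conditional low-rank update B(phi(A(x); c))."""

    def __init__(self, base: nn.Linear, rank: int, c_dim: int, scale: float = 1.0):
        super().__init__()
        self.W0 = base.requires_grad_(False)                     # frozen pretrained weights
        self.A = nn.Linear(base.in_features, rank, bias=False)   # trainable down-projection
        self.B = nn.Linear(rank, base.out_features, bias=False)  # trainable up-projection
        nn.init.zeros_(self.B.weight)                            # adapter starts as a no-op
        self.mapper = MappingNetwork(c_dim, rank)
        self.scale = scale

    def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
        # x: (batch, tokens, in_features), c: (batch, c_dim)
        gamma, beta = self.mapper(c)
        h = self.A(x)                                    # low-rank embedding
        h = gamma.unsqueeze(1) * h + beta.unsqueeze(1)   # conditional affine transformation phi
        return self.W0(x) + self.scale * self.B(h)

Because B is zero-initialized in this sketch, the adapted layer initially reproduces the frozen model, and
zeroing the conditioning c yields the unconditional branch used for classifier-free guidance, mirroring the
dropout on c during training in the code below.
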
Qualitative Comparison
Style

Samples from our method with style conditioning, compared against other methods. We used an empty prompt
and conditioned only on the image. We generally perform on par with IP-Adapter and outperform it on some
samples. Note that the third image from the left is less degraded, and the third image from the right
captures the horse's mane better.

Structure
252 |
253 |
254 |
255 | Samples from our method with structural conditioning compared against other methods. Note that for our
256 | method, especially compared with T2I Adapter, the details of the images are substantially more closely
257 | aligned with the depth prompt (see e.g. the lamp in the background of the living room scene and the side
258 | table's legs, or the salad on the pizza)
Quantitative Comparison
Style

Best results are in bold. LoRAdapter requires the fewest parameters and achieves state-of-the-art
performance while also enabling direct structure control.

Structure

Best results are in bold. We evaluate cycle consistency (MSE-d), FID, and LPIPS. Configurations A and B
differ in the number of adapted layers, resulting in a different number of parameters. LoRAdapter
outperforms all other methods on all metrics.

@misc{stracke2024loradapter,
    title={CTRLorALTer: Conditional LoRAdapter for Efficient 0-Shot Control & Altering of T2I Models},
    author={Nick Stracke and Stefan Andreas Baumann and Joshua Susskind and Miguel Angel Bautista and Björn Ommer},
    year={2024},
    eprint={2405.07913},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

--------------------------------------------------------------------------------
/src/model.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 | from typing import Union, Literal
3 | import torch
4 | from torch import nn
5 | from src.utils import DataProvider
6 | import src.lora as loras
7 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline
8 | from diffusers import AutoencoderTiny
9 | from src.utils import getattr_recursive
10 |
11 | import torch.nn.functional as F
12 | from pydoc import locate
13 | from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
14 |
15 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
16 | retrieve_timesteps,
17 | )
18 |
19 | from tqdm.auto import tqdm
20 | import random
21 |
22 | from diffusers import ControlNetModel
23 | from torchvision.transforms import Compose
24 | from typing import Callable
25 |
26 | ATTENTION_MODULES = ["to_k", "to_v"]
27 |
28 | # only for SD15
29 | CONV_MODULES = ["conv1", "conv2"]
30 |
31 | ADAPTION_MODE = Literal[
32 | "full_attention",
33 | "only_self",
34 | "only_cross",
35 | "only_conv",
36 | "only_first_conv",
37 | "only_res_conv",
38 | "full",
39 | "no_cross",
40 | "only_value",
41 | # below only works for sdxl
42 | "b-lora_style",
43 | "b-lora_content",
44 | "b-lora",
45 | "sdxl_cross",
46 | "sdxl_self",
47 | "sdxl_inner",
48 | ]
49 |
50 | CONDITION_MODE = Literal["style", "structure"]
51 |
52 |
53 | class ModelBase(ABC, nn.Module):
54 |
55 | def __init__(
56 | self,
57 | pipeline_type: str,
58 | model_name: str,
59 | local_files_only: bool = True,
60 | c_dropout: float = 0.05,
61 | guidance_scale: float = 7.5,
62 | use_controlnet: bool = False,
63 | annotator: None | nn.Module = None,
64 | tiny_vae: bool = False,
65 | ) -> None:
66 | super().__init__()
67 | self.params_to_optimize: list[nn.Parameter] = []
68 | self.lora_state_dict_keys: dict[str, list[str]] = {}
69 | self.lora_layers: dict[str, list[nn.Module]] = {}
70 | self.lora_transforms: list[Compose | None] = []
71 |
72 | self.encoders: list[nn.Module] = list()
73 | self.mappers: list[nn.Module] = list()
74 | self.dps: list[DataProvider] = []
75 |
76 | self.tiny_vae = tiny_vae
77 | self.c_dropout = c_dropout
78 | self.guidance_scale = guidance_scale
79 | self.use_controlnet = use_controlnet
80 |
81 | addition_config = {}
82 |
83 | # Note that this requires the controlnet pipe which also has to be set in the config
84 |
85 | if tiny_vae:
86 | vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", local_files_only=local_files_only)
87 | addition_config["vae"] = vae
88 |
89 | if self.use_controlnet:
90 | assert annotator is not None, "Need annotator for controlnet"
91 |
92 | controlnet = ControlNetModel.from_pretrained(
93 | "lllyasviel/sd-controlnet-depth",
94 | use_safetensors=True,
95 | local_files_only=local_files_only,
96 | **addition_config,
97 | )
98 | # controlnet.requires_grad_(False)
99 | # controlnet.eval()
100 | addition_config["controlnet"] = controlnet
101 |
102 |             # fix this cheap workaround!
103 | self.encoders.append(annotator)
104 | self.mappers.append(controlnet)
105 |
106 | self.pipe: DiffusionPipeline = locate(pipeline_type).from_pretrained(
107 | model_name,
108 | local_files_only=local_files_only,
109 |             safety_checker=None, # too annoying
110 |             use_safetensors=True,
111 | **addition_config,
112 | )
113 | assert isinstance(self.pipe, DiffusionPipeline)
114 |
115 | self.noise_scheduler = DDPMScheduler.from_config(
116 | {**self.pipe.scheduler.config, "rescale_betas_zero_snr": False},
117 | subfolder="scheduler",
118 | )
119 |
120 | self.max_depth = len(self.pipe.unet.config["block_out_channels"]) - 1
121 |
122 | # we register all the individual pipeline modules here
123 |         # such that all the typical calls like .to and .prepare affect them.
124 | self.unet = self.pipe.unet
125 | self.unet.requires_grad_(False)
126 |
127 | self.vae = self.pipe.vae
128 | self.text_encoder = self.pipe.text_encoder
129 | self.tokenizer = self.pipe.tokenizer
130 |
131 |         self.vae.requires_grad_(False)
132 |         self.text_encoder.requires_grad_(False)
133 |
134 | # handle sdxl case
135 | if hasattr(self.pipe, "text_encoder_2"):
136 | self.text_encoder_2 = self.pipe.text_encoder_2
137 | self.text_encoder_2.requires_grad_(False)
138 |
139 | def add_lora_to_unet(
140 | self,
141 | config: dict,
142 | name: str,
143 | data_provider: DataProvider,
144 | encoder: nn.Module,
145 | mapper: nn.Module,
146 | optimize: bool = True,
147 | transforms: list[Callable] = [],
148 | ):
149 | self.rank = config.rank
150 | self.c_dim = config.c_dim
151 | unet = self.unet
152 | sd = unet.state_dict()
153 |
154 | self.mappers.append(mapper)
155 | self.encoders.append(encoder)
156 | self.dps.append(data_provider)
157 |
158 | self.lora_transforms.append(Compose(transforms) if len(transforms) > 0 else None)
159 |
160 | print(f"adding {len(transforms)} transforms to LoRA {name}")
161 |
162 | lora_cls = config.lora_cls
163 | adaption_mode = config.adaption_mode
164 |
165 | if not optimize:
166 | mapper.eval()
167 | mapper.requires_grad_(False)
168 |
169 | local_lora_sd_keys: list[str] = []
170 |
171 | for path, w in sd.items():
172 | class_config = {**config}
173 | del class_config["lora_cls"]
174 | del class_config["adaption_mode"]
175 |
176 | _continue = True
177 | if adaption_mode == "full_attention" and "attn" in path:
178 | _continue = False
179 |
180 | if adaption_mode == "only_self" and "attn1" in path:
181 | _continue = False
182 |
183 | if adaption_mode == "only_cross" and "attn2" in path:
184 | _continue = False
185 |
186 | if adaption_mode == "only_conv" and ("conv1" in path or "conv2" in path):
187 | _continue = False
188 |
189 | # only the first conv layer in each resnet block
190 | if adaption_mode == "only_first_conv" and "0.conv1" in path:
191 | _continue = False
192 |
193 | if adaption_mode == "only_res_conv" and ("0.conv1" in path or "1.conv1" in path):
194 | _continue = False
195 |
196 | if adaption_mode == "full" and ("attn" in path or "conv" in path):
197 | _continue = False
198 |
199 | if adaption_mode == "no_cross" and "attn2" not in path:
200 | _continue = False
201 |
202 | if adaption_mode == "b-lora_content" and ("up_blocks.0.attentions.0" in path and "attn" in path):
203 | _continue = False
204 |
205 | if adaption_mode == "b-lora_style" and ("up_blocks.0.attentions.1" in path and "attn" in path):
206 | _continue = False
207 |
208 | if (
209 | adaption_mode == "b-lora"
210 | and ("up_blocks.0.attentions.0" in path or "up_blocks.0.attentions.1" in path)
211 | and "attn" in path
212 | ):
213 | _continue = False
214 |
215 |             # supposedly setting content to have no effect
216 | # if "up_blocks.0.attentions.0" in path:
217 | # class_config["lora_scale"] = 0.0
218 |
219 | # "down_blocks.2.attentions.1" in path or
220 | if (
221 | adaption_mode == "sdxl_inner"
222 | and ("mid_block" in path or "up_blocks.0.attentions.0" in path or "up_blocks.0.attentions.1" in path)
223 | and "attn2" in path
224 | ):
225 | _continue = False
226 |
227 | if (
228 | adaption_mode == "sdxl_cross"
229 | and ("down_blocks.2" in path or "up_blocks.0" in path or "mid_block" in path)
230 | and "attn2" in path
231 | ):
232 | _continue = False
233 |
234 | if (
235 | adaption_mode == "sdxl_self"
236 | and ("down_blocks.2" in path or "up_blocks.0" in path or "mid_block" in path)
237 | and "attn1" in path
238 | ):
239 | _continue = False
240 |
241 | if _continue:
242 | continue
243 |
244 | if "bias" in path:
245 | # we handle the bias together with the weight
246 | # this is only relevant for the conv layers
247 | continue
248 |
249 | parent_path = ".".join(path.split(".")[:-2])
250 | target_path = ".".join(path.split(".")[:-1])
251 | target_name = path.split(".")[-2]
252 | parent_module = getattr_recursive(unet, parent_path)
253 | target_module = getattr_recursive(unet, target_path)
254 |
255 | if "mid_block" in path:
256 | depth = self.max_depth
257 | elif "down_blocks" in path:
258 | depth = int(path.split("down_blocks.")[1][0])
259 | elif "up_blocks" in path:
260 | depth = self.max_depth - int(path.split("up_blocks.")[1][0])
261 | else:
262 | raise ValueError(f"Unknown module {path}")
263 |
264 | lora = None
265 | if "attn" in path:
266 | if not any([m in path for m in ATTENTION_MODULES]):
267 | continue
268 |
269 | lora = getattr(loras, lora_cls)(
270 | out_features=target_module.out_features,
271 | in_features=target_module.in_features,
272 | data_provider=data_provider,
273 | depth=depth,
274 | **class_config,
275 | )
276 |
277 | # W is the original weight matrix
278 | # those layers have no bias
279 | lora.W.load_state_dict({path.split(".")[-1]: w})
280 |
281 | if lora_cls == "IPLinear":
282 | # for faster convergence
283 | lora.W_IP.load_state_dict({path.split(".")[-1]: w})
284 |
285 | if "conv" in path:
286 | lora = getattr(loras, lora_cls)(
287 | in_channels=target_module.in_channels,
288 | out_channels=target_module.out_channels,
289 | kernel_size=target_module.kernel_size,
290 | stride=target_module.stride,
291 | padding=target_module.padding,
292 | data_provider=data_provider,
293 | depth=depth,
294 | **class_config,
295 | )
296 |
297 | # find bias term
298 | bias_path = ".".join(path.split(".")[:-1] + ["bias"])
299 | b = sd[bias_path]
300 | lora.W.load_state_dict({path.split(".")[-1]: w, "bias": b})
301 |
302 | if lora is None:
303 | raise ValueError(f"Unknown module {path}")
304 |
305 | for k in lora.state_dict().keys():
306 | # W is by design the original weight matrix which we don't need to save
307 | if k.split(".")[0] == "W":
308 | continue
309 |
310 | local_lora_sd_keys.append(f"{target_path}.{k}")
311 |
312 | self.lora_state_dict_keys[name] = local_lora_sd_keys
313 |
314 | setattr(
315 | parent_module,
316 | target_name,
317 | lora,
318 | )
319 |
320 | if optimize:
321 | for p in lora.parameters():
322 | if p.requires_grad:
323 | self.params_to_optimize.append(p)
324 | else:
325 | lora.eval()
326 | for p in lora.parameters():
327 | p.requires_grad_(False)
328 |
329 | self.lora_layers[name] = [lora] + self.lora_layers.get(name, [])
330 |
331 | def get_lora_state_dict(self, unet: Union[nn.Module, None] = None):
332 | lora_sd = {}
333 |
334 | if unet is None:
335 | unet = self.unet
336 |
337 | for k, v in unet.state_dict().items():
338 | for n, keys in self.lora_state_dict_keys.items():
339 | if n not in lora_sd:
340 | lora_sd[n] = {}
341 |
342 | if k in keys:
343 | lora_sd[n][k] = v.cpu()
344 |
345 | return lora_sd
346 |
347 | @abstractmethod
348 | def get_input(self, imgs: torch.Tensor, prompts: list[str]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
349 | raise NotImplementedError()
350 |
351 |     # -> epsilon, loss, x0, cond
352 |     def forward_easy(self, *args, **kwargs) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
353 |         return self(*args, **kwargs)
354 |
355 | def sample(self, *args, **kwargs):
356 | return self.pipe(*args, **kwargs).images
357 |
358 |
359 | class SD15(ModelBase):
360 | def __init__(self, pipeline_type, model_name, *args, **kwargs) -> None:
361 | super().__init__(pipeline_type, model_name, *args, **kwargs)
362 |
363 | @torch.no_grad()
364 | def get_input(self, imgs: torch.Tensor, prompts: list[str]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
365 | assert len(imgs.shape) == 4
366 | assert imgs.min() >= -1.0
367 | assert imgs.max() <= 1.0
368 |
369 | imgs = imgs.clip(-1.0, 1.0)
370 |
371 | # Convert images to latent space
372 | if self.tiny_vae:
373 | latents = self.vae.encode(imgs).latents
374 | else:
375 | latents = self.vae.encode(imgs).latent_dist.sample()
376 |
377 | latents = latents * self.vae.config.scaling_factor
378 |
379 | # prompt dropout
380 | prompts = ["" if random.random() < self.c_dropout else p for p in prompts]
381 |
382 | # prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
383 | # prompt=prompts,
384 | # device=self.unet.device,
385 | # num_images_per_prompt=1,
386 | # do_classifier_free_guidance=False,
387 | # )
388 |
389 | # do it manually to avoid stupid warnings
390 | input_ids = self.tokenizer(
391 | prompts,
392 | truncation=True,
393 | padding="max_length",
394 | max_length=self.tokenizer.model_max_length,
395 | return_tensors="pt",
396 | ).input_ids
397 | prompt_embeds = self.text_encoder(input_ids.to(imgs.device))["last_hidden_state"]
398 |
399 | # assert (prompt_embeds - prompt_embeds).mean() < 1e-6
400 | # assert (prompt_embeds == prompt_embeds).all()
401 |
402 | c = {
403 | "prompt_embeds": prompt_embeds,
404 | }
405 |
406 | return latents, c
407 |
408 | def forward(
409 | self,
410 | latents: torch.Tensor,
411 | c: dict[str, torch.Tensor],
412 | cs: list[torch.Tensor],
413 | timesteps: torch.Tensor,
414 | noise: torch.Tensor,
415 | cfg_mask: list[bool] | None = None,
416 | skip_encode: bool = False,
417 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
418 |
419 | noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
420 | prompt_embeds = c["prompt_embeds"]
421 | bsz = latents.shape[0]
422 | encoders = self.encoders
423 | mappers = self.mappers
424 |
425 | additional_inputs = {}
426 | if self.use_controlnet:
427 | # controlnet related stuff is always at index 0
428 | cn_input = cs[0]
429 | cs = cs[1:]
430 |
431 | controlnet = mappers[0]
432 | mappers = mappers[1:]
433 |
434 | annotator = encoders[0]
435 | encoders = encoders[1:]
436 |
437 | with torch.no_grad():
438 | cn_cond = annotator(cn_input)
439 |
440 | down_block_res_samples, mid_block_res_sample = controlnet(
441 | noisy_latents,
442 | timesteps,
443 | encoder_hidden_states=prompt_embeds,
444 | controlnet_cond=cn_cond,
445 | conditioning_scale=1.0,
446 | return_dict=False,
447 | )
448 |
449 | additional_inputs["down_block_additional_residuals"] = down_block_res_samples
450 | additional_inputs["mid_block_additional_residual"] = mid_block_res_sample
451 |
452 | # add our lora conditioning
453 | # cs in [-1, 1]
454 | for i, (encoder, dp, mapper, lora_c) in enumerate(zip(encoders, self.dps, mappers, cs)):
455 | if cfg_mask is None or cfg_mask[i]:
456 | dropout_mask = torch.rand(bsz, device=lora_c.device) < self.c_dropout
457 |
458 | # apply dropout for cfg
459 | lora_c[dropout_mask] = torch.zeros_like(lora_c[dropout_mask])
460 |
461 | if skip_encode:
462 | cond = lora_c
463 | else:
464 | # some encoders we want to finetune
465 | # so no torch.no_grad() here
466 | # instead we set requires_grad in the corresponding classes/configs
467 | t = self.lora_transforms[i]
468 | if t is not None:
469 | lora_c = t(lora_c)
470 | cond = encoder(lora_c)
471 | mapped_cond = mapper(cond)
472 | dp.set_batch(mapped_cond)
473 |
474 | # Predict the noise residual
475 | model_pred = self.unet(
476 | noisy_latents, timesteps, encoder_hidden_states=prompt_embeds, **additional_inputs
477 | ).sample
478 |
479 |         # get x0 prediction: x0 = (x_t - sqrt(1 - alpha_bar_t) * eps_pred) / sqrt(alpha_bar_t)
480 | alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(device=model_pred.device, dtype=model_pred.dtype)
481 |
482 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
483 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
484 | while len(sqrt_alpha_prod.shape) < len(model_pred.shape):
485 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
486 |
487 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
488 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
489 | while len(sqrt_one_minus_alpha_prod.shape) < len(model_pred.shape):
490 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
491 |
492 | x0 = (noisy_latents - sqrt_one_minus_alpha_prod * model_pred) / sqrt_alpha_prod
493 |
494 | # Get the target for loss depending on the prediction type
495 | if self.noise_scheduler.config.prediction_type == "epsilon":
496 | target = noise
497 | elif self.noise_scheduler.config.prediction_type == "v_prediction":
498 | target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
499 | else:
500 | raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
501 |
502 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
503 |
504 | return model_pred, loss, x0, cond
505 |
506 | def forward_easy(
507 | self,
508 | imgs: torch.Tensor,
509 | prompts: list[str],
510 | cs: list[torch.Tensor],
511 | cfg_mask: list[bool] | None = None,
512 | skip_encode: bool = False,
513 | batch: dict | None = None,
514 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
515 |
516 | latents, c = self.get_input(imgs, prompts)
517 |
518 | # Sample noise that we'll add to the latents
519 | noise = torch.randn_like(latents)
520 | bsz = latents.shape[0]
521 | # Sample a random timestep for each image
522 | timesteps = torch.randint(
523 | 0,
524 | self.noise_scheduler.config.num_train_timesteps,
525 | (bsz,),
526 | device=latents.device,
527 | )
528 | # timesteps = timesteps.long()
529 |
530 | return self(
531 | latents=latents,
532 | c=c,
533 | cs=cs,
534 | timesteps=timesteps,
535 | noise=noise,
536 | cfg_mask=cfg_mask,
537 | skip_encode=skip_encode,
538 | )
539 |
540 | @torch.no_grad()
541 | def sample_custom(
542 | self,
543 | prompt,
544 | num_images_per_prompt,
545 | cs: list[torch.Tensor],
546 | generator,
547 | cfg_mask: list[bool] | None = None,
548 | prompt_offset_step: int = 0,
549 | skip_encode: bool = False,
550 | **kwargs,
551 | ):
552 | height = self.unet.config.sample_size * self.pipe.vae_scale_factor
553 | width = self.unet.config.sample_size * self.pipe.vae_scale_factor
554 |
555 | if prompt is not None and isinstance(prompt, str):
556 | batch_size = 1
557 | elif prompt is not None and isinstance(prompt, list):
558 | batch_size = len(prompt)
559 |
560 | batch_size = batch_size * num_images_per_prompt
561 |
562 | device = self.unet.device
563 |
564 | prompt_embeds, negative_prompt_embeds = self.pipe.encode_prompt(
565 | prompt, device, num_images_per_prompt, True
566 | ) # do cfg
567 | dtype = prompt_embeds.dtype
568 |
569 | # for cfg
570 | c_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]).to(dtype)
571 | uc_prompt_embeds = torch.cat([negative_prompt_embeds, negative_prompt_embeds]).to(dtype)
572 |
573 | # we have to do two separate forward passes for the cfg with the loras
574 | # add our lora conditioning
575 | for i, (encoder, dp, mapper, c) in enumerate(zip(self.encoders, self.dps, self.mappers, cs)):
576 |
577 | if c.shape[0] != batch_size:
578 | assert c.shape[0] == 1
579 | c = torch.cat(batch_size * [c]) # repeat along batch dim
580 |
581 | neg_c = torch.zeros_like(c)
582 | if cfg_mask is not None and not cfg_mask[i]:
583 | print("no cfg for lora nr", i)
584 | c = torch.cat([c, c])
585 | else:
586 | c = torch.cat([neg_c, c])
587 |
588 | if skip_encode:
589 | cond = c
590 | else:
591 | cond = encoder(c)
592 | mapped_cond = mapper(cond)
593 | if isinstance(mapped_cond, tuple) or isinstance(mapped_cond, list):
594 | mapped_cond = [mc.to(dtype) for mc in mapped_cond]
595 | else:
596 | mapped_cond = mapped_cond.to(dtype)
597 |
598 | dp.set_batch(mapped_cond)
599 |
600 | # 4. Prepare timesteps
601 | timesteps, num_inference_steps = retrieve_timesteps(self.pipe.scheduler, 50, device)
602 |
603 | # 5. Prepare latent variables
604 | num_channels_latents = 4 # self.unet.config.in_channels
605 | latents = self.pipe.prepare_latents(
606 | batch_size,
607 | num_channels_latents,
608 | height,
609 | width,
610 | c_prompt_embeds.dtype,
611 | device,
612 | generator,
613 | )
614 |
615 | for i, t in tqdm(enumerate(timesteps)):
616 | # cfg
617 | latent_model_input = torch.cat([latents] * 2)
618 |
619 | noise_pred = self.unet(
620 | latent_model_input,
621 | t,
622 | encoder_hidden_states=(c_prompt_embeds if i >= prompt_offset_step else uc_prompt_embeds),
623 | return_dict=False,
624 | )[0]
625 |
626 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
627 | noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
628 |
629 | latents = self.pipe.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
630 |
631 | latents = latents.to(torch.float32)
632 | image = self.vae.decode(
633 | latents / self.vae.config.scaling_factor,
634 | return_dict=False,
635 | generator=generator,
636 | )[0]
637 | do_denormalize = [True] * image.shape[0]
638 |
639 | image = self.pipe.image_processor.postprocess(image, output_type="pil", do_denormalize=do_denormalize)
640 |
641 | return image
642 |
643 | # with self.progress_bar(total=num_inference_steps) as progress_bar:
644 |
645 | @torch.no_grad()
646 | def sample(self, *args, **kwargs):
647 | return self.sample_easy(*args, **kwargs)
648 |
649 | @torch.no_grad()
650 | def sample_easy(
651 | self,
652 | prompt,
653 | num_images_per_prompt,
654 | cs: list[torch.Tensor],
655 | generator,
656 | cfg_mask: list[bool] | None = None,
657 | prompt_offset_step: int = 0,
658 | # dtype=torch.float32,
659 | **kwargs,
660 | ):
661 | if prompt is not None and isinstance(prompt, str):
662 | batch_size = 1
663 | elif prompt is not None and isinstance(prompt, list):
664 | batch_size = len(prompt)
665 |
666 | batch_size = batch_size * num_images_per_prompt
667 |
668 | mappers = self.mappers
669 | encoders = self.encoders
670 | if self.use_controlnet:
671 | # controlnet related stuff is always at index 0
672 | cn_input = cs[0]
673 | cs = cs[1:]
674 |
675 | mappers = mappers[1:]
676 |
677 | annotator = encoders[0]
678 | encoders = encoders[1:]
679 |
680 | with torch.no_grad():
681 | cn_cond = annotator(cn_input)
682 |
683 | kwargs["image"] = cn_cond
684 |
685 | # we have to do two separate forward passes for the cfg with the loras
686 | # add our lora conditioning
687 | for i, (encoder, dp, mapper, c) in enumerate(zip(encoders, self.dps, mappers, cs)):
688 |
689 | if c.shape[0] != batch_size:
690 | assert c.shape[0] == 1
691 | c = torch.cat(batch_size * [c]) # repeat along batch dim
692 |
693 | neg_c = torch.zeros_like(c)
694 | if cfg_mask is not None and not cfg_mask[i]:
695 | print("no cfg for lora nr", i)
696 | c = torch.cat([c, c])
697 | else:
698 | c = torch.cat([neg_c, c])
699 | cond = encoder(c)
700 | mapped_cond = mapper(cond)
701 | # if isinstance(mapped_cond, tuple) or isinstance(mapped_cond, list):
702 | # mapped_cond = [mc.to(dtype) for mc in mapped_cond]
703 | # else:
704 | # mapped_cond = mapped_cond.to(dtype)
705 |
706 | dp.set_batch(mapped_cond)
707 |
708 | return self.pipe(
709 | prompt=prompt,
710 | num_images_per_prompt=num_images_per_prompt,
711 | generator=generator,
712 | **kwargs,
713 | ).images
714 |
715 |
716 | class SDXL(ModelBase):
717 | def __init__(self, pipeline_type, model_name, *args, **kwargs) -> None:
718 | super().__init__(pipeline_type, model_name, *args, **kwargs)
719 |
720 | def get_input(self, imgs: torch.Tensor, prompts: list[str]) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
721 | raise NotImplementedError()
722 |
723 | def compute_time_ids(self, device, weight_dtype):
724 | # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
725 |
726 | # we could adjust this if we knew that we had cropped / shifted images
727 | original_size = (1024, 1024)
728 | target_size = (1024, 1024)
729 | crops_coords_top_left = (0, 0)
730 | add_time_ids = list(original_size + crops_coords_top_left + target_size)
731 | add_time_ids = torch.tensor([add_time_ids])
732 | add_time_ids = add_time_ids.to(device, dtype=weight_dtype)
733 | return add_time_ids
734 |
735 | def get_conditioning(
736 | self,
737 | prompts: list[str],
738 | bsz: int,
739 | device: torch.device,
740 | dtype: torch.dtype,
741 | do_cfg=False,
742 | ):
743 | add_time_ids = self.compute_time_ids(device, dtype)
744 | negative_add_time_ids = add_time_ids # no conditioning for now
745 |
746 | (
747 | prompt_embeds,
748 | negative_prompt_embeds,
749 | pooled_prompt_embeds,
750 | negative_pooled_prompt_embeds,
751 | ) = self.pipe.encode_prompt(
752 | prompt=prompts,
753 | device=device,
754 | num_images_per_prompt=1,
755 | do_classifier_free_guidance=do_cfg,
756 | )
757 |
758 | if do_cfg:
759 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
760 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
761 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
762 | else:
763 | # prompt_embeds = prompt_embeds
764 | add_text_embeds = pooled_prompt_embeds
765 | # add_time_ids = add_time_ids.repeat
766 |
767 | prompt_embeds = prompt_embeds.to(device)
768 | add_text_embeds = add_text_embeds.to(device)
769 | add_time_ids = add_time_ids.to(device).repeat(bsz, 1)
770 |
771 | return {
772 | "prompt_embeds": prompt_embeds,
773 | "add_text_embeds": add_text_embeds,
774 | "add_time_ids": add_time_ids,
775 | }
776 |
777 | def forward_easy(self, *args, **kwargs):
778 | return self.forward(*args, **kwargs)
779 |
780 | def forward(
781 | self,
782 | imgs: torch.Tensor,
783 | prompts: list[str],
784 | cs: list[torch.Tensor],
785 | cfg_mask: list[bool] | None = None,
786 | skip_encode: bool = False,
787 | batch: dict | None = None,
788 |     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
789 | assert len(imgs.shape) == 4
790 | assert imgs.min() >= -1.0
791 | assert imgs.max() <= 1.0
792 |
793 | B = imgs.shape[0]
794 |
795 | with torch.no_grad():
796 | # Convert images to latent space
797 | imgs = imgs.to(self.unet.device)
798 | latents = self.pipe.vae.encode(imgs).latent_dist.sample()
799 | latents = latents * self.pipe.vae.config.scaling_factor
800 |
801 | # prompt dropout
802 | prompts = ["" if random.random() < self.c_dropout else p for p in prompts]
803 |
804 | c = self.get_conditioning(prompts, B, latents.device, latents.dtype)
805 |
806 | unet_added_conditions = {
807 | "time_ids": c["add_time_ids"],
808 | "text_embeds": c["add_text_embeds"],
809 | }
810 | prompt_embeds_input = c["prompt_embeds"]
811 |
812 | # Sample noise that we'll add to the latents
813 | noise = torch.randn_like(latents)
814 | bsz = latents.shape[0]
815 | # Sample a random timestep for each image
816 | timesteps = torch.randint(
817 | 0,
818 | self.noise_scheduler.config.num_train_timesteps,
819 | (B,),
820 | device=latents.device,
821 | )
822 | timesteps = timesteps.long()
823 |
824 | # Add noise to the latents according to the noise magnitude at each timestep
825 | # (this is the forward diffusion process)
826 | noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
827 |
828 | # add our lora conditioning
829 | for i, (encoder, dp, mapper, c) in enumerate(zip(self.encoders, self.dps, self.mappers, cs)):
830 | if cfg_mask is None or cfg_mask[i]:
831 | dropout_mask = torch.rand(bsz, device=c.device) < self.c_dropout
832 |
833 | # apply dropout for cfg
834 | c[dropout_mask] = torch.zeros_like(c[dropout_mask])
835 |
836 | with torch.no_grad():
837 | cond = encoder(c)
838 | mapped_cond = mapper(cond)
839 | dp.set_batch(mapped_cond)
840 |
841 | # Predict the noise residual
842 | model_pred = self.unet(
843 | noisy_latents,
844 | timesteps,
845 | prompt_embeds_input,
846 | added_cond_kwargs=unet_added_conditions,
847 | ).sample
848 |
849 | # get the x0 prediction in ddpm sampling
850 | alphas_cumprod = self.noise_scheduler.alphas_cumprod.to(device=model_pred.device, dtype=model_pred.dtype)
851 |
852 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
853 | sqrt_alpha_prod = sqrt_alpha_prod.flatten()
854 | while len(sqrt_alpha_prod.shape) < len(model_pred.shape):
855 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
856 |
857 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
858 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
859 | while len(sqrt_one_minus_alpha_prod.shape) < len(model_pred.shape):
860 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
861 |
862 | x0 = (noisy_latents - sqrt_one_minus_alpha_prod * model_pred) / sqrt_alpha_prod
863 |
864 | # Get the target for loss depending on the prediction type
865 | if self.noise_scheduler.config.prediction_type == "epsilon":
866 | target = noise
867 |
868 | elif self.noise_scheduler.config.prediction_type == "v_prediction":
869 | target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
870 | else:
871 | raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
872 |
873 | loss = F.mse_loss(model_pred, target, reduction="mean")
874 |
875 | return model_pred, loss, x0
876 |
877 | @torch.no_grad()
878 | def sample(
879 | self,
880 | prompt,
881 | num_images_per_prompt,
882 | cs: list[torch.Tensor],
883 | generator,
884 | cfg_mask: list[bool] | None = None,
885 | prompt_offset_step: int = 0,
886 | skip_encode: bool = False,
887 | dtype=torch.float32,
888 | batch: dict | None = None,
889 | **kwargs,
890 | ):
891 | if prompt is not None and isinstance(prompt, str):
892 | batch_size = 1
893 | elif prompt is not None and isinstance(prompt, list):
894 | batch_size = len(prompt)
895 |
896 | batch_size = batch_size * num_images_per_prompt
897 |
898 | device = self.unet.device
899 |
900 | prompt_embeds = None
901 | pooled_prompt_embeds = None
902 |
903 | # we have to do two separate forward passes for the cfg with the loras
904 | # add our lora conditioning
905 | for i, (encoder, dp, mapper, c) in enumerate(zip(self.encoders, self.dps, self.mappers, cs)):
906 |
907 | if c.shape[0] != batch_size:
908 | assert c.shape[0] == 1
909 | c = torch.cat(batch_size * [c]) # repeat along batch dim
910 |
911 | neg_c = torch.zeros_like(c)
912 | if self.guidance_scale > 1:
913 | if cfg_mask is not None and not cfg_mask[i]:
914 | print("no cfg for lora nr", i)
915 | c = torch.cat([c, c])
916 | else:
917 | c = torch.cat([neg_c, c])
918 | cond = encoder(c)
919 | mapped_cond = mapper(cond)
920 | # if isinstance(mapped_cond, tuple) or isinstance(mapped_cond, list):
921 | # mapped_cond = [mc.to(dtype) for mc in mapped_cond]
922 | # else:
923 | # mapped_cond = mapped_cond.to(dtype)
924 |
925 | dp.set_batch(mapped_cond)
926 |
927 | return self.pipe(
928 | prompt=prompt,
929 | num_images_per_prompt=num_images_per_prompt,
930 | generator=generator,
931 | guidance_scale=self.guidance_scale,
932 | prompt_embeds=prompt_embeds,
933 | pooled_prompt_embeds=pooled_prompt_embeds,
934 | **kwargs,
935 | ).images
936 |