├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── cog.yaml ├── latent_consistency_controlnet.py ├── predict.py └── samples.py /.dockerignore: -------------------------------------------------------------------------------- 1 | model_cache/ 2 | *.png 3 | *.jpg 4 | .git/ 5 | cog_class_data/ 6 | cog_instance_data/ 7 | model_cache.tar 8 | tmp/ 9 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | model_cache/ 2 | __pycache__ 3 | *.png 4 | *.jpg 5 | cog_class_data/ 6 | cog_instance_data/ 7 | model_cache.tar 8 | tmp/ 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 fofrAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # latent-consistency-model 2 | 3 | https://replicate.com/fofr/latent-consistency-model 4 | -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | # Configuration for Cog ⚙️ 2 | # Reference: https://github.com/replicate/cog/blob/main/docs/yaml.md 3 | 4 | build: 5 | gpu: true 6 | system_packages: 7 | - "libgl1-mesa-glx" 8 | - "libglib2.0-0" 9 | python_version: "3.11" 10 | python_packages: 11 | - "accelerate==0.23.0" 12 | - "torch==2.0.1" 13 | - "torchvision==0.15.2" 14 | - "diffusers==0.22.3" 15 | - "Pillow==10.1.0" 16 | - "transformers==4.34.1" 17 | - "xformers==0.0.22" 18 | - "opencv-python-headless==4.8.1.78" 19 | run: 20 | - curl -o /usr/local/bin/pget -L "https://github.com/replicate/pget/releases/download/v0.0.6/pget" && chmod +x /usr/local/bin/pget 21 | predict: "predict.py:Predictor" 22 | -------------------------------------------------------------------------------- /latent_consistency_controlnet.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/taabata/LCM_Inpaint_Outpaint_Comfy/blob/main/LCM/pipeline_cn.py 2 | # Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion 17 | # and https://github.com/hojonathanho/diffusion 18 | 19 | import math 20 | from dataclasses import dataclass 21 | from typing import Any, Dict, List, Optional, Tuple, Union 22 | 23 | import numpy as np 24 | import torch 25 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer 26 | 27 | from diffusers import ( 28 | AutoencoderKL, 29 | ConfigMixin, 30 | DiffusionPipeline, 31 | SchedulerMixin, 32 | UNet2DConditionModel, 33 | ControlNetModel, 34 | logging, 35 | ) 36 | from diffusers.configuration_utils import register_to_config 37 | from diffusers.image_processor import VaeImageProcessor, PipelineImageInput 38 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 39 | from diffusers.pipelines.stable_diffusion.safety_checker import ( 40 | StableDiffusionSafetyChecker, 41 | ) 42 | from diffusers.utils import BaseOutput 43 | 44 | from diffusers.utils.torch_utils import randn_tensor, is_compiled_module 45 | from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel 46 | 47 | 48 | import PIL.Image 49 | 50 | 51 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 52 | 53 | 54 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents 55 | def retrieve_latents(encoder_output, generator): 56 | if hasattr(encoder_output, "latent_dist"): 57 | return encoder_output.latent_dist.sample(generator) 58 | elif hasattr(encoder_output, "latents"): 59 | return encoder_output.latents 60 | else: 61 | raise AttributeError("Could not access latents of provided encoder_output") 62 | 63 | 64 | class LatentConsistencyModelPipeline_controlnet(DiffusionPipeline): 65 | _optional_components = ["scheduler"] 66 | 67 | def __init__( 68 | self, 69 | vae: AutoencoderKL, 70 | text_encoder: CLIPTextModel, 71 | tokenizer: CLIPTokenizer, 72 | controlnet: Union[ 73 | ControlNetModel, 74 | List[ControlNetModel], 75 | Tuple[ControlNetModel], 76 | MultiControlNetModel, 77 | ], 78 | unet: UNet2DConditionModel, 79 | scheduler: "LCMScheduler", 80 | safety_checker: StableDiffusionSafetyChecker, 81 | feature_extractor: CLIPImageProcessor, 82 | requires_safety_checker: bool = True, 83 | ): 84 | super().__init__() 85 | 86 | scheduler = ( 87 | scheduler 88 | if scheduler is not None 89 | else LCMScheduler_X( 90 | beta_start=0.00085, 91 | beta_end=0.0120, 92 | beta_schedule="scaled_linear", 93 | prediction_type="epsilon", 94 | ) 95 | ) 96 | 97 | self.register_modules( 98 | vae=vae, 99 | text_encoder=text_encoder, 100 | tokenizer=tokenizer, 101 | unet=unet, 102 | controlnet=controlnet, 103 | scheduler=scheduler, 104 | safety_checker=safety_checker, 105 | feature_extractor=feature_extractor, 106 | ) 107 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 108 | self.image_processor = 
VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 109 | self.control_image_processor = VaeImageProcessor( 110 | vae_scale_factor=self.vae_scale_factor, 111 | do_convert_rgb=True, 112 | do_normalize=False, 113 | ) 114 | 115 | def _encode_prompt( 116 | self, 117 | prompt, 118 | device, 119 | num_images_per_prompt, 120 | prompt_embeds: None, 121 | ): 122 | r""" 123 | Encodes the prompt into text encoder hidden states. 124 | Args: 125 | prompt (`str` or `List[str]`, *optional*): 126 | prompt to be encoded 127 | device: (`torch.device`): 128 | torch device 129 | num_images_per_prompt (`int`): 130 | number of images that should be generated per prompt 131 | prompt_embeds (`torch.FloatTensor`, *optional*): 132 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 133 | provided, text embeddings will be generated from `prompt` input argument. 134 | """ 135 | 136 | if prompt is not None and isinstance(prompt, str): 137 | pass 138 | elif prompt is not None and isinstance(prompt, list): 139 | len(prompt) 140 | else: 141 | prompt_embeds.shape[0] 142 | 143 | if prompt_embeds is None: 144 | text_inputs = self.tokenizer( 145 | prompt, 146 | padding="max_length", 147 | max_length=self.tokenizer.model_max_length, 148 | truncation=True, 149 | return_tensors="pt", 150 | ) 151 | text_input_ids = text_inputs.input_ids 152 | untruncated_ids = self.tokenizer( 153 | prompt, padding="longest", return_tensors="pt" 154 | ).input_ids 155 | 156 | if untruncated_ids.shape[-1] >= text_input_ids.shape[ 157 | -1 158 | ] and not torch.equal(text_input_ids, untruncated_ids): 159 | removed_text = self.tokenizer.batch_decode( 160 | untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] 161 | ) 162 | logger.warning( 163 | "The following part of your input was truncated because CLIP can only handle sequences up to" 164 | f" {self.tokenizer.model_max_length} tokens: {removed_text}" 165 | ) 166 | 167 | if ( 168 | hasattr(self.text_encoder.config, "use_attention_mask") 169 | and self.text_encoder.config.use_attention_mask 170 | ): 171 | attention_mask = text_inputs.attention_mask.to(device) 172 | else: 173 | attention_mask = None 174 | 175 | prompt_embeds = self.text_encoder( 176 | text_input_ids.to(device), 177 | attention_mask=attention_mask, 178 | ) 179 | prompt_embeds = prompt_embeds[0] 180 | 181 | if self.text_encoder is not None: 182 | prompt_embeds_dtype = self.text_encoder.dtype 183 | elif self.unet is not None: 184 | prompt_embeds_dtype = self.unet.dtype 185 | else: 186 | prompt_embeds_dtype = prompt_embeds.dtype 187 | 188 | prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 189 | 190 | bs_embed, seq_len, _ = prompt_embeds.shape 191 | # duplicate text embeddings for each generation per prompt, using mps friendly method 192 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 193 | prompt_embeds = prompt_embeds.view( 194 | bs_embed * num_images_per_prompt, seq_len, -1 195 | ) 196 | 197 | # Don't need to get uncond prompt embedding because of LCM Guided Distillation 198 | return prompt_embeds 199 | 200 | def run_safety_checker(self, image, device, dtype): 201 | if self.safety_checker is None: 202 | has_nsfw_concept = None 203 | else: 204 | if torch.is_tensor(image): 205 | feature_extractor_input = self.image_processor.postprocess( 206 | image, output_type="pil" 207 | ) 208 | else: 209 | feature_extractor_input = self.image_processor.numpy_to_pil(image) 210 | safety_checker_input = self.feature_extractor( 211 | 
feature_extractor_input, return_tensors="pt" 212 | ).to(device) 213 | image, has_nsfw_concept = self.safety_checker( 214 | images=image, clip_input=safety_checker_input.pixel_values.to(dtype) 215 | ) 216 | return image, has_nsfw_concept 217 | 218 | def prepare_control_image( 219 | self, 220 | image, 221 | width, 222 | height, 223 | batch_size, 224 | num_images_per_prompt, 225 | device, 226 | dtype, 227 | do_classifier_free_guidance=False, 228 | guess_mode=False, 229 | ): 230 | image = self.control_image_processor.preprocess( 231 | image, height=height, width=width 232 | ).to(dtype=dtype) 233 | image_batch_size = image.shape[0] 234 | 235 | if image_batch_size == 1: 236 | repeat_by = batch_size 237 | else: 238 | # image batch size is the same as prompt batch size 239 | repeat_by = num_images_per_prompt 240 | 241 | image = image.repeat_interleave(repeat_by, dim=0) 242 | 243 | image = image.to(device=device, dtype=dtype) 244 | 245 | if do_classifier_free_guidance and not guess_mode: 246 | image = torch.cat([image] * 2) 247 | 248 | return image 249 | 250 | def prepare_latents( 251 | self, 252 | image, 253 | timestep, 254 | batch_size, 255 | num_channels_latents, 256 | height, 257 | width, 258 | dtype, 259 | device, 260 | latents=None, 261 | generator=None, 262 | ): 263 | shape = ( 264 | batch_size, 265 | num_channels_latents, 266 | height // self.vae_scale_factor, 267 | width // self.vae_scale_factor, 268 | ) 269 | 270 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 271 | raise ValueError( 272 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 273 | ) 274 | 275 | image = image.to(device=device, dtype=dtype) 276 | 277 | # batch_size = batch_size * num_images_per_prompt 278 | 279 | if image.shape[1] == 4: 280 | init_latents = image 281 | 282 | else: 283 | if isinstance(generator, list) and len(generator) != batch_size: 284 | raise ValueError( 285 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 286 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 287 | ) 288 | 289 | elif isinstance(generator, list): 290 | init_latents = [ 291 | retrieve_latents( 292 | self.vae.encode(image[i : i + 1]), generator=generator[i] 293 | ) 294 | for i in range(batch_size) 295 | ] 296 | init_latents = torch.cat(init_latents, dim=0) 297 | else: 298 | init_latents = retrieve_latents( 299 | self.vae.encode(image), generator=generator 300 | ) 301 | 302 | init_latents = self.vae.config.scaling_factor * init_latents 303 | 304 | if ( 305 | batch_size > init_latents.shape[0] 306 | and batch_size % init_latents.shape[0] == 0 307 | ): 308 | # expand init_latents for batch_size 309 | deprecation_message = ( 310 | f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" 311 | " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" 312 | " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" 313 | " your script to pass as many initial images as text prompts to suppress this warning." 
314 | ) 315 | # deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) 316 | additional_image_per_prompt = batch_size // init_latents.shape[0] 317 | init_latents = torch.cat( 318 | [init_latents] * additional_image_per_prompt, dim=0 319 | ) 320 | elif ( 321 | batch_size > init_latents.shape[0] 322 | and batch_size % init_latents.shape[0] != 0 323 | ): 324 | raise ValueError( 325 | f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 326 | ) 327 | else: 328 | init_latents = torch.cat([init_latents], dim=0) 329 | 330 | shape = init_latents.shape 331 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 332 | 333 | # get latents 334 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep) 335 | latents = init_latents 336 | 337 | return latents 338 | 339 | if latents is None: 340 | latents = torch.randn(shape, dtype=dtype).to(device) 341 | else: 342 | latents = latents.to(device) 343 | # scale the initial noise by the standard deviation required by the scheduler 344 | latents = latents * self.scheduler.init_noise_sigma 345 | return latents 346 | 347 | def get_w_embedding(self, w, embedding_dim=512, dtype=torch.float32): 348 | """ 349 | see https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 350 | Args: 351 | timesteps: torch.Tensor: generate embedding vectors at these timesteps 352 | embedding_dim: int: dimension of the embeddings to generate 353 | dtype: data type of the generated embeddings 354 | Returns: 355 | embedding vectors with shape `(len(timesteps), embedding_dim)` 356 | """ 357 | assert len(w.shape) == 1 358 | w = w * 1000.0 359 | 360 | half_dim = embedding_dim // 2 361 | emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) 362 | emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) 363 | emb = w.to(dtype)[:, None] * emb[None, :] 364 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) 365 | if embedding_dim % 2 == 1: # zero pad 366 | emb = torch.nn.functional.pad(emb, (0, 1)) 367 | assert emb.shape == (w.shape[0], embedding_dim) 368 | return emb 369 | 370 | def get_timesteps(self, num_inference_steps, strength, device): 371 | # get the original timestep using init_timestep 372 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 373 | 374 | t_start = max(num_inference_steps - init_timestep, 0) 375 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] 376 | 377 | return timesteps, num_inference_steps - t_start 378 | 379 | @torch.no_grad() 380 | def __call__( 381 | self, 382 | prompt: Union[str, List[str]] = None, 383 | image: PipelineImageInput = None, 384 | control_image: PipelineImageInput = None, 385 | strength: float = 0.8, 386 | height: Optional[int] = 768, 387 | width: Optional[int] = 768, 388 | guidance_scale: float = 7.5, 389 | num_images_per_prompt: Optional[int] = 1, 390 | latents: Optional[torch.FloatTensor] = None, 391 | generator: Optional[torch.Generator] = None, 392 | num_inference_steps: int = 4, 393 | lcm_origin_steps: int = 50, 394 | prompt_embeds: Optional[torch.FloatTensor] = None, 395 | output_type: Optional[str] = "pil", 396 | return_dict: bool = True, 397 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 398 | controlnet_conditioning_scale: Union[float, List[float]] = 0.8, 399 | guess_mode: bool = True, 400 | control_guidance_start: Union[float, List[float]] = 0.0, 401 | control_guidance_end: Union[float, List[float]] = 1.0, 
402 | ): 403 | controlnet = ( 404 | self.controlnet._orig_mod 405 | if is_compiled_module(self.controlnet) 406 | else self.controlnet 407 | ) 408 | # 0. Default height and width to unet 409 | height = height or self.unet.config.sample_size * self.vae_scale_factor 410 | width = width or self.unet.config.sample_size * self.vae_scale_factor 411 | if not isinstance(control_guidance_start, list) and isinstance( 412 | control_guidance_end, list 413 | ): 414 | control_guidance_start = len(control_guidance_end) * [ 415 | control_guidance_start 416 | ] 417 | elif not isinstance(control_guidance_end, list) and isinstance( 418 | control_guidance_start, list 419 | ): 420 | control_guidance_end = len(control_guidance_start) * [control_guidance_end] 421 | elif not isinstance(control_guidance_start, list) and not isinstance( 422 | control_guidance_end, list 423 | ): 424 | mult = ( 425 | len(controlnet.nets) 426 | if isinstance(controlnet, MultiControlNetModel) 427 | else 1 428 | ) 429 | control_guidance_start, control_guidance_end = mult * [ 430 | control_guidance_start 431 | ], mult * [control_guidance_end] 432 | # 2. Define call parameters 433 | if prompt is not None and isinstance(prompt, str): 434 | batch_size = 1 435 | elif prompt is not None and isinstance(prompt, list): 436 | batch_size = len(prompt) 437 | else: 438 | batch_size = prompt_embeds.shape[0] 439 | 440 | device = self._execution_device 441 | # do_classifier_free_guidance = guidance_scale > 0.0 # In LCM Implementation: cfg_noise = noise_cond + cfg_scale * (noise_cond - noise_uncond) , (cfg_scale > 0.0 using CFG) 442 | global_pool_conditions = ( 443 | controlnet.config.global_pool_conditions 444 | if isinstance(controlnet, ControlNetModel) 445 | else controlnet.nets[0].config.global_pool_conditions 446 | ) 447 | guess_mode = guess_mode or global_pool_conditions 448 | # 3. Encode input prompt 449 | prompt_embeds = self._encode_prompt( 450 | prompt, 451 | device, 452 | num_images_per_prompt, 453 | prompt_embeds=prompt_embeds, 454 | ) 455 | 456 | # 3.5 encode image 457 | image = self.image_processor.preprocess(image) 458 | 459 | if isinstance(controlnet, ControlNetModel): 460 | control_image = self.prepare_control_image( 461 | image=control_image, 462 | width=width, 463 | height=height, 464 | batch_size=batch_size * num_images_per_prompt, 465 | num_images_per_prompt=num_images_per_prompt, 466 | device=device, 467 | dtype=controlnet.dtype, 468 | guess_mode=guess_mode, 469 | ) 470 | elif isinstance(controlnet, MultiControlNetModel): 471 | control_images = [] 472 | 473 | for control_image_ in control_image: 474 | control_image_ = self.prepare_control_image( 475 | image=control_image_, 476 | width=width, 477 | height=height, 478 | batch_size=batch_size * num_images_per_prompt, 479 | num_images_per_prompt=num_images_per_prompt, 480 | device=device, 481 | dtype=controlnet.dtype, 482 | do_classifier_free_guidance=do_classifier_free_guidance, 483 | guess_mode=guess_mode, 484 | ) 485 | 486 | control_images.append(control_image_) 487 | 488 | control_image = control_images 489 | else: 490 | assert False 491 | 492 | # 4. Prepare timesteps 493 | self.scheduler.set_timesteps(strength, num_inference_steps, lcm_origin_steps) 494 | # timesteps = self.scheduler.timesteps 495 | # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, 1.0, device) 496 | timesteps = self.scheduler.timesteps 497 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 498 | 499 | # print("timesteps: ", timesteps) 500 | 501 | # 5. 
Prepare latent variable 502 | num_channels_latents = self.unet.config.in_channels 503 | latents = self.prepare_latents( 504 | image, 505 | latent_timestep, 506 | batch_size * num_images_per_prompt, 507 | num_channels_latents, 508 | height, 509 | width, 510 | prompt_embeds.dtype, 511 | device, 512 | latents, 513 | generator, 514 | ) 515 | bs = batch_size * num_images_per_prompt 516 | 517 | # 6. Get Guidance Scale Embedding 518 | w = torch.tensor(guidance_scale).repeat(bs) 519 | w_embedding = self.get_w_embedding(w, embedding_dim=256).to( 520 | device=device, dtype=latents.dtype 521 | ) 522 | controlnet_keep = [] 523 | for i in range(len(timesteps)): 524 | keeps = [ 525 | 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) 526 | for s, e in zip(control_guidance_start, control_guidance_end) 527 | ] 528 | controlnet_keep.append( 529 | keeps[0] if isinstance(controlnet, ControlNetModel) else keeps 530 | ) 531 | # 7. LCM MultiStep Sampling Loop: 532 | with self.progress_bar(total=num_inference_steps) as progress_bar: 533 | for i, t in enumerate(timesteps): 534 | ts = torch.full((bs,), t, device=device, dtype=torch.long) 535 | latents = latents.to(prompt_embeds.dtype) 536 | if guess_mode: 537 | # Infer ControlNet only for the conditional batch. 538 | control_model_input = latents 539 | control_model_input = self.scheduler.scale_model_input( 540 | control_model_input, ts 541 | ) 542 | controlnet_prompt_embeds = prompt_embeds 543 | else: 544 | control_model_input = latents 545 | controlnet_prompt_embeds = prompt_embeds 546 | if isinstance(controlnet_keep[i], list): 547 | cond_scale = [ 548 | c * s 549 | for c, s in zip( 550 | controlnet_conditioning_scale, controlnet_keep[i] 551 | ) 552 | ] 553 | else: 554 | controlnet_cond_scale = controlnet_conditioning_scale 555 | if isinstance(controlnet_cond_scale, list): 556 | controlnet_cond_scale = controlnet_cond_scale[0] 557 | cond_scale = controlnet_cond_scale * controlnet_keep[i] 558 | 559 | down_block_res_samples, mid_block_res_sample = self.controlnet( 560 | control_model_input, 561 | ts, 562 | encoder_hidden_states=controlnet_prompt_embeds, 563 | controlnet_cond=control_image, 564 | conditioning_scale=cond_scale, 565 | guess_mode=guess_mode, 566 | return_dict=False, 567 | ) 568 | # model prediction (v-prediction, eps, x) 569 | model_pred = self.unet( 570 | latents, 571 | ts, 572 | timestep_cond=w_embedding, 573 | encoder_hidden_states=prompt_embeds, 574 | cross_attention_kwargs=cross_attention_kwargs, 575 | down_block_additional_residuals=down_block_res_samples, 576 | mid_block_additional_residual=mid_block_res_sample, 577 | return_dict=False, 578 | )[0] 579 | 580 | # compute the previous noisy sample x_t -> x_t-1 581 | latents, denoised = self.scheduler.step( 582 | model_pred, i, t, latents, return_dict=False 583 | ) 584 | 585 | # # call the callback, if provided 586 | # if i == len(timesteps) - 1: 587 | progress_bar.update() 588 | 589 | denoised = denoised.to(prompt_embeds.dtype) 590 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 591 | self.unet.to("cpu") 592 | self.controlnet.to("cpu") 593 | torch.cuda.empty_cache() 594 | if not output_type == "latent": 595 | image = self.vae.decode( 596 | denoised / self.vae.config.scaling_factor, return_dict=False 597 | )[0] 598 | image, has_nsfw_concept = self.run_safety_checker( 599 | image, device, prompt_embeds.dtype 600 | ) 601 | else: 602 | image = denoised 603 | has_nsfw_concept = None 604 | 605 | if has_nsfw_concept is None: 606 | do_denormalize = [True] 
* image.shape[0] 607 | else: 608 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] 609 | 610 | image = self.image_processor.postprocess( 611 | image, output_type=output_type, do_denormalize=do_denormalize 612 | ) 613 | 614 | if not return_dict: 615 | return (image, has_nsfw_concept) 616 | 617 | return StableDiffusionPipelineOutput( 618 | images=image, nsfw_content_detected=has_nsfw_concept 619 | ) 620 | 621 | 622 | @dataclass 623 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM 624 | class LCMSchedulerOutput(BaseOutput): 625 | """ 626 | Output class for the scheduler's `step` function output. 627 | Args: 628 | prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 629 | Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the 630 | denoising loop. 631 | pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): 632 | The predicted denoised sample `(x_{0})` based on the model output from the current timestep. 633 | `pred_original_sample` can be used to preview progress or for guidance. 634 | """ 635 | 636 | prev_sample: torch.FloatTensor 637 | denoised: Optional[torch.FloatTensor] = None 638 | 639 | 640 | # Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar 641 | def betas_for_alpha_bar( 642 | num_diffusion_timesteps, 643 | max_beta=0.999, 644 | alpha_transform_type="cosine", 645 | ): 646 | """ 647 | Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 648 | (1-beta) over time from t = [0,1]. 649 | Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up 650 | to that part of the diffusion process. 651 | Args: 652 | num_diffusion_timesteps (`int`): the number of betas to produce. 653 | max_beta (`float`): the maximum beta to use; use values lower than 1 to 654 | prevent singularities. 655 | alpha_transform_type (`str`, *optional*, default to `cosine`): the type of noise schedule for alpha_bar. 656 | Choose from `cosine` or `exp` 657 | Returns: 658 | betas (`np.ndarray`): the betas used by the scheduler to step the model outputs 659 | """ 660 | if alpha_transform_type == "cosine": 661 | 662 | def alpha_bar_fn(t): 663 | return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2 664 | 665 | elif alpha_transform_type == "exp": 666 | 667 | def alpha_bar_fn(t): 668 | return math.exp(t * -12.0) 669 | 670 | else: 671 | raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}") 672 | 673 | betas = [] 674 | for i in range(num_diffusion_timesteps): 675 | t1 = i / num_diffusion_timesteps 676 | t2 = (i + 1) / num_diffusion_timesteps 677 | betas.append(min(1 - alpha_bar_fn(t2) / alpha_bar_fn(t1), max_beta)) 678 | return torch.tensor(betas, dtype=torch.float32) 679 | 680 | 681 | def rescale_zero_terminal_snr(betas): 682 | """ 683 | Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1) 684 | Args: 685 | betas (`torch.FloatTensor`): 686 | the betas that the scheduler is being initialized with. 687 | Returns: 688 | `torch.FloatTensor`: rescaled betas with zero terminal SNR 689 | """ 690 | # Convert betas to alphas_bar_sqrt 691 | alphas = 1.0 - betas 692 | alphas_cumprod = torch.cumprod(alphas, dim=0) 693 | alphas_bar_sqrt = alphas_cumprod.sqrt() 694 | 695 | # Store old values. 
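# (The shift-and-scale below is the affine map x -> (x - x_T) * x_0 / (x_0 - x_T) applied to sqrt(alpha_bar): the first entry is kept fixed and the last entry is sent to exactly zero.)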
696 | alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone() 697 | alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone() 698 | 699 | # Shift so the last timestep is zero. 700 | alphas_bar_sqrt -= alphas_bar_sqrt_T 701 | 702 | # Scale so the first timestep is back to the old value. 703 | alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T) 704 | 705 | # Convert alphas_bar_sqrt to betas 706 | alphas_bar = alphas_bar_sqrt**2 # Revert sqrt 707 | alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod 708 | alphas = torch.cat([alphas_bar[0:1], alphas]) 709 | betas = 1 - alphas 710 | 711 | return betas 712 | 713 | 714 | class LCMScheduler_X(SchedulerMixin, ConfigMixin): 715 | """ 716 | `LCMScheduler` extends the denoising procedure introduced in denoising diffusion probabilistic models (DDPMs) with 717 | non-Markovian guidance. 718 | This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic 719 | methods the library implements for all schedulers such as loading and saving. 720 | Args: 721 | num_train_timesteps (`int`, defaults to 1000): 722 | The number of diffusion steps to train the model. 723 | beta_start (`float`, defaults to 0.0001): 724 | The starting `beta` value of inference. 725 | beta_end (`float`, defaults to 0.02): 726 | The final `beta` value. 727 | beta_schedule (`str`, defaults to `"linear"`): 728 | The beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from 729 | `linear`, `scaled_linear`, or `squaredcos_cap_v2`. 730 | trained_betas (`np.ndarray`, *optional*): 731 | Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`. 732 | clip_sample (`bool`, defaults to `True`): 733 | Clip the predicted sample for numerical stability. 734 | clip_sample_range (`float`, defaults to 1.0): 735 | The maximum magnitude for sample clipping. Valid only when `clip_sample=True`. 736 | set_alpha_to_one (`bool`, defaults to `True`): 737 | Each diffusion step uses the alphas product value at that step and at the previous one. For the final step 738 | there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, 739 | otherwise it uses the alpha value at step 0. 740 | steps_offset (`int`, defaults to 0): 741 | An offset added to the inference steps. You can use a combination of `offset=1` and 742 | `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable 743 | Diffusion. 744 | prediction_type (`str`, defaults to `epsilon`, *optional*): 745 | Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), 746 | `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen 747 | Video](https://imagen.research.google/video/paper.pdf) paper). 748 | thresholding (`bool`, defaults to `False`): 749 | Whether to use the "dynamic thresholding" method. This is unsuitable for latent-space diffusion models such 750 | as Stable Diffusion. 751 | dynamic_thresholding_ratio (`float`, defaults to 0.995): 752 | The ratio for the dynamic thresholding method. Valid only when `thresholding=True`. 753 | sample_max_value (`float`, defaults to 1.0): 754 | The threshold value for dynamic thresholding. Valid only when `thresholding=True`. 755 | timestep_spacing (`str`, defaults to `"leading"`): 756 | The way the timesteps should be scaled. 
Refer to Table 2 of the [Common Diffusion Noise Schedules and 757 | Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. 758 | rescale_betas_zero_snr (`bool`, defaults to `False`): 759 | Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and 760 | dark samples instead of limiting it to samples with medium brightness. Loosely related to 761 | [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506). 762 | """ 763 | 764 | # _compatibles = [e.name for e in KarrasDiffusionSchedulers] 765 | order = 1 766 | 767 | @register_to_config 768 | def __init__( 769 | self, 770 | num_train_timesteps: int = 1000, 771 | beta_start: float = 0.0001, 772 | beta_end: float = 0.02, 773 | beta_schedule: str = "linear", 774 | trained_betas: Optional[Union[np.ndarray, List[float]]] = None, 775 | clip_sample: bool = True, 776 | set_alpha_to_one: bool = True, 777 | steps_offset: int = 0, 778 | prediction_type: str = "epsilon", 779 | thresholding: bool = False, 780 | dynamic_thresholding_ratio: float = 0.995, 781 | clip_sample_range: float = 1.0, 782 | sample_max_value: float = 1.0, 783 | timestep_spacing: str = "leading", 784 | rescale_betas_zero_snr: bool = False, 785 | ): 786 | if trained_betas is not None: 787 | self.betas = torch.tensor(trained_betas, dtype=torch.float32) 788 | elif beta_schedule == "linear": 789 | self.betas = torch.linspace( 790 | beta_start, beta_end, num_train_timesteps, dtype=torch.float32 791 | ) 792 | elif beta_schedule == "scaled_linear": 793 | # this schedule is very specific to the latent diffusion model. 794 | self.betas = ( 795 | torch.linspace( 796 | beta_start**0.5, 797 | beta_end**0.5, 798 | num_train_timesteps, 799 | dtype=torch.float32, 800 | ) 801 | ** 2 802 | ) 803 | elif beta_schedule == "squaredcos_cap_v2": 804 | # Glide cosine schedule 805 | self.betas = betas_for_alpha_bar(num_train_timesteps) 806 | else: 807 | raise NotImplementedError( 808 | f"{beta_schedule} does is not implemented for {self.__class__}" 809 | ) 810 | 811 | # Rescale for zero SNR 812 | if rescale_betas_zero_snr: 813 | self.betas = rescale_zero_terminal_snr(self.betas) 814 | 815 | self.alphas = 1.0 - self.betas 816 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 817 | 818 | # At every step in ddim, we are looking into the previous alphas_cumprod 819 | # For the final step, there is no previous alphas_cumprod because we are already at 0 820 | # `set_alpha_to_one` decides whether we set this parameter simply to one or 821 | # whether we use the final alpha of the "non-previous" one. 822 | self.final_alpha_cumprod = ( 823 | torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] 824 | ) 825 | 826 | # standard deviation of the initial noise distribution 827 | self.init_noise_sigma = 1.0 828 | 829 | # setable values 830 | self.num_inference_steps = None 831 | self.timesteps = torch.from_numpy( 832 | np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64) 833 | ) 834 | 835 | def scale_model_input( 836 | self, sample: torch.FloatTensor, timestep: Optional[int] = None 837 | ) -> torch.FloatTensor: 838 | """ 839 | Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 840 | current timestep. 841 | Args: 842 | sample (`torch.FloatTensor`): 843 | The input sample. 844 | timestep (`int`, *optional*): 845 | The current timestep in the diffusion chain. 
846 | Returns: 847 | `torch.FloatTensor`: 848 | A scaled input sample. 849 | """ 850 | return sample 851 | 852 | def _get_variance(self, timestep, prev_timestep): 853 | alpha_prod_t = self.alphas_cumprod[timestep] 854 | alpha_prod_t_prev = ( 855 | self.alphas_cumprod[prev_timestep] 856 | if prev_timestep >= 0 857 | else self.final_alpha_cumprod 858 | ) 859 | beta_prod_t = 1 - alpha_prod_t 860 | beta_prod_t_prev = 1 - alpha_prod_t_prev 861 | 862 | variance = (beta_prod_t_prev / beta_prod_t) * ( 863 | 1 - alpha_prod_t / alpha_prod_t_prev 864 | ) 865 | 866 | return variance 867 | 868 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample 869 | def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: 870 | """ 871 | "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the 872 | prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by 873 | s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing 874 | pixels from saturation at each step. We find that dynamic thresholding results in significantly better 875 | photorealism as well as better image-text alignment, especially when using very large guidance weights." 876 | https://arxiv.org/abs/2205.11487 877 | """ 878 | dtype = sample.dtype 879 | batch_size, channels, height, width = sample.shape 880 | 881 | if dtype not in (torch.float32, torch.float64): 882 | sample = ( 883 | sample.float() 884 | ) # upcast for quantile calculation, and clamp not implemented for cpu half 885 | 886 | # Flatten sample for doing quantile calculation along each image 887 | sample = sample.reshape(batch_size, channels * height * width) 888 | 889 | abs_sample = sample.abs() # "a certain percentile absolute pixel value" 890 | 891 | s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) 892 | s = torch.clamp( 893 | s, min=1, max=self.config.sample_max_value 894 | ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] 895 | 896 | s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 897 | sample = ( 898 | torch.clamp(sample, -s, s) / s 899 | ) # "we threshold xt0 to the range [-s, s] and then divide by s" 900 | 901 | sample = sample.reshape(batch_size, channels, height, width) 902 | sample = sample.to(dtype) 903 | 904 | return sample 905 | 906 | def set_timesteps( 907 | self, 908 | stength, 909 | num_inference_steps: int, 910 | lcm_origin_steps: int, 911 | device: Union[str, torch.device] = None, 912 | ): 913 | """ 914 | Sets the discrete timesteps used for the diffusion chain (to be run before inference). 915 | Args: 916 | num_inference_steps (`int`): 917 | The number of diffusion steps used when generating samples with a pre-trained model. 918 | """ 919 | 920 | if num_inference_steps > self.config.num_train_timesteps: 921 | raise ValueError( 922 | f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" 923 | f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" 924 | f" maximal {self.config.num_train_timesteps} timesteps." 
925 | ) 926 | 927 | self.num_inference_steps = num_inference_steps 928 | 929 | # LCM Timesteps Setting: # Linear Spacing 930 | c = self.config.num_train_timesteps // lcm_origin_steps 931 | lcm_origin_timesteps = ( 932 | np.asarray(list(range(1, int(lcm_origin_steps * stength) + 1))) * c - 1 933 | ) # LCM Training Steps Schedule 934 | skipping_step = max(len(lcm_origin_timesteps) // num_inference_steps, 1) 935 | timesteps = lcm_origin_timesteps[::-skipping_step][ 936 | :num_inference_steps 937 | ] # LCM Inference Steps Schedule 938 | 939 | self.timesteps = torch.from_numpy(timesteps.copy()).to(device) 940 | 941 | def get_scalings_for_boundary_condition_discrete(self, t): 942 | self.sigma_data = 0.5 # Default: 0.5 943 | 944 | # By dividing 0.1: This is almost a delta function at t=0. 945 | c_skip = self.sigma_data**2 / ((t / 0.1) ** 2 + self.sigma_data**2) 946 | c_out = (t / 0.1) / ((t / 0.1) ** 2 + self.sigma_data**2) ** 0.5 947 | return c_skip, c_out 948 | 949 | def step( 950 | self, 951 | model_output: torch.FloatTensor, 952 | timeindex: int, 953 | timestep: int, 954 | sample: torch.FloatTensor, 955 | eta: float = 0.0, 956 | use_clipped_model_output: bool = False, 957 | generator=None, 958 | variance_noise: Optional[torch.FloatTensor] = None, 959 | return_dict: bool = True, 960 | ) -> Union[LCMSchedulerOutput, Tuple]: 961 | """ 962 | Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion 963 | process from the learned model outputs (most often the predicted noise). 964 | Args: 965 | model_output (`torch.FloatTensor`): 966 | The direct output from learned diffusion model. 967 | timestep (`float`): 968 | The current discrete timestep in the diffusion chain. 969 | sample (`torch.FloatTensor`): 970 | A current instance of a sample created by the diffusion process. 971 | eta (`float`): 972 | The weight of noise for added noise in diffusion step. 973 | use_clipped_model_output (`bool`, defaults to `False`): 974 | If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary 975 | because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no 976 | clipping has happened, "corrected" `model_output` would coincide with the one provided as input and 977 | `use_clipped_model_output` has no effect. 978 | generator (`torch.Generator`, *optional*): 979 | A random number generator. 980 | variance_noise (`torch.FloatTensor`): 981 | Alternative to generating noise with `generator` by directly providing the noise for the variance 982 | itself. Useful for methods such as [`CycleDiffusion`]. 983 | return_dict (`bool`, *optional*, defaults to `True`): 984 | Whether or not to return a [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] or `tuple`. 985 | Returns: 986 | [`~schedulers.scheduling_utils.LCMSchedulerOutput`] or `tuple`: 987 | If return_dict is `True`, [`~schedulers.scheduling_lcm.LCMSchedulerOutput`] is returned, otherwise a 988 | tuple is returned where the first element is the sample tensor. 989 | """ 990 | if self.num_inference_steps is None: 991 | raise ValueError( 992 | "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" 993 | ) 994 | 995 | # 1. get previous step value 996 | prev_timeindex = timeindex + 1 997 | if prev_timeindex < len(self.timesteps): 998 | prev_timestep = self.timesteps[prev_timeindex] 999 | else: 1000 | prev_timestep = timestep 1001 | 1002 | # 2. 
compute alphas, betas 1003 | alpha_prod_t = self.alphas_cumprod[timestep] 1004 | alpha_prod_t_prev = ( 1005 | self.alphas_cumprod[prev_timestep] 1006 | if prev_timestep >= 0 1007 | else self.final_alpha_cumprod 1008 | ) 1009 | 1010 | beta_prod_t = 1 - alpha_prod_t 1011 | beta_prod_t_prev = 1 - alpha_prod_t_prev 1012 | 1013 | # 3. Get scalings for boundary conditions 1014 | c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep) 1015 | 1016 | # 4. Different Parameterization: 1017 | parameterization = self.config.prediction_type 1018 | 1019 | if parameterization == "epsilon": # noise-prediction 1020 | pred_x0 = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() 1021 | 1022 | elif parameterization == "sample": # x-prediction 1023 | pred_x0 = model_output 1024 | 1025 | elif parameterization == "v_prediction": # v-prediction 1026 | pred_x0 = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output 1027 | 1028 | # 4. Denoise model output using boundary conditions 1029 | denoised = c_out * pred_x0 + c_skip * sample 1030 | 1031 | # 5. Sample z ~ N(0, I), For MultiStep Inference 1032 | # Noise is not used for one-step sampling. 1033 | if len(self.timesteps) > 1: 1034 | noise = torch.randn(model_output.shape).to(model_output.device) 1035 | prev_sample = ( 1036 | alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise 1037 | ) 1038 | else: 1039 | prev_sample = denoised 1040 | 1041 | if not return_dict: 1042 | return (prev_sample, denoised) 1043 | 1044 | return LCMSchedulerOutput(prev_sample=prev_sample, denoised=denoised) 1045 | 1046 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise 1047 | def add_noise( 1048 | self, 1049 | original_samples: torch.FloatTensor, 1050 | noise: torch.FloatTensor, 1051 | timesteps: torch.IntTensor, 1052 | ) -> torch.FloatTensor: 1053 | # Make sure alphas_cumprod and timestep have same device and dtype as original_samples 1054 | alphas_cumprod = self.alphas_cumprod.to( 1055 | device=original_samples.device, dtype=original_samples.dtype 1056 | ) 1057 | timesteps = timesteps.to(original_samples.device) 1058 | 1059 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 1060 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 1061 | while len(sqrt_alpha_prod.shape) < len(original_samples.shape): 1062 | sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) 1063 | 1064 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 1065 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 1066 | while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): 1067 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 1068 | 1069 | noisy_samples = ( 1070 | sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise 1071 | ) 1072 | return noisy_samples 1073 | 1074 | # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity 1075 | def get_velocity( 1076 | self, 1077 | sample: torch.FloatTensor, 1078 | noise: torch.FloatTensor, 1079 | timesteps: torch.IntTensor, 1080 | ) -> torch.FloatTensor: 1081 | # Make sure alphas_cumprod and timestep have same device and dtype as sample 1082 | alphas_cumprod = self.alphas_cumprod.to( 1083 | device=sample.device, dtype=sample.dtype 1084 | ) 1085 | timesteps = timesteps.to(sample.device) 1086 | 1087 | sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 1088 | sqrt_alpha_prod = sqrt_alpha_prod.flatten() 1089 | while len(sqrt_alpha_prod.shape) < len(sample.shape): 1090 | sqrt_alpha_prod = 
sqrt_alpha_prod.unsqueeze(-1) 1091 | 1092 | sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 1093 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() 1094 | while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): 1095 | sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) 1096 | 1097 | velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample 1098 | return velocity 1099 | 1100 | def __len__(self): 1101 | return self.config.num_train_timesteps 1102 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv 2 | import os 3 | import torch 4 | import datetime 5 | import tarfile 6 | import numpy as np 7 | import time 8 | import subprocess 9 | from typing import List, Optional 10 | from diffusers import ControlNetModel, DiffusionPipeline, AutoPipelineForImage2Image 11 | from latent_consistency_controlnet import LatentConsistencyModelPipeline_controlnet 12 | from cog import BasePredictor, Input, Path 13 | from PIL import Image 14 | 15 | MODEL_CACHE_URL = "https://weights.replicate.delivery/default/fofr-lcm/model_cache.tar" 16 | MODEL_CACHE = "model_cache" 17 | 18 | 19 | def download_weights(url, dest): 20 | start = time.time() 21 | print("downloading url: ", url) 22 | print("downloading to: ", dest) 23 | subprocess.check_call(["pget", "-x", url, dest], close_fds=False) 24 | print("downloading took: ", time.time() - start) 25 | 26 | 27 | class Predictor(BasePredictor): 28 | def create_pipeline( 29 | self, 30 | pipeline_class, 31 | safety_checker: bool = True, 32 | controlnet: Optional[ControlNetModel] = None, 33 | ): 34 | kwargs = { 35 | "cache_dir": MODEL_CACHE, 36 | "local_files_only": True, 37 | } 38 | 39 | if not safety_checker: 40 | kwargs["safety_checker"] = None 41 | 42 | if controlnet: 43 | kwargs["controlnet"] = controlnet 44 | kwargs["scheduler"] = None 45 | 46 | pipe = pipeline_class.from_pretrained("SimianLuo/LCM_Dreamshaper_v7", **kwargs) 47 | pipe.to(torch_device="cuda", torch_dtype=torch.float16) 48 | pipe.enable_xformers_memory_efficient_attention() 49 | return pipe 50 | 51 | def setup(self) -> None: 52 | """Load the model into memory to make running multiple predictions efficient""" 53 | 54 | if not os.path.exists(MODEL_CACHE): 55 | download_weights(MODEL_CACHE_URL, MODEL_CACHE) 56 | 57 | self.txt2img_pipe = self.create_pipeline(DiffusionPipeline) 58 | self.txt2img_pipe_unsafe = self.create_pipeline( 59 | DiffusionPipeline, safety_checker=False 60 | ) 61 | 62 | self.img2img_pipe = self.create_pipeline(AutoPipelineForImage2Image) 63 | self.img2img_pipe_unsafe = self.create_pipeline( 64 | AutoPipelineForImage2Image, safety_checker=False 65 | ) 66 | 67 | controlnet_canny = ControlNetModel.from_pretrained( 68 | "lllyasviel/control_v11p_sd15_canny", 69 | cache_dir="model_cache", 70 | local_files_only=True, 71 | torch_dtype=torch.float16, 72 | ).to("cuda") 73 | 74 | self.controlnet_pipe = self.create_pipeline( 75 | LatentConsistencyModelPipeline_controlnet, controlnet=controlnet_canny 76 | ) 77 | self.controlnet_pipe_unsafe = self.create_pipeline( 78 | LatentConsistencyModelPipeline_controlnet, 79 | safety_checker=False, 80 | controlnet=controlnet_canny, 81 | ) 82 | 83 | # warm the pipes 84 | self.txt2img_pipe(prompt="warmup") 85 | self.txt2img_pipe_unsafe(prompt="warmup") 86 | self.img2img_pipe(prompt="warmup", image=[Image.new("RGB", (768, 768))]) 87 | 
self.img2img_pipe_unsafe(prompt="warmup", image=[Image.new("RGB", (768, 768))]) 88 | self.controlnet_pipe( 89 | prompt="warmup", 90 | image=[Image.new("RGB", (768, 768))], 91 | control_image=[Image.new("RGB", (768, 768))], 92 | ) 93 | self.controlnet_pipe_unsafe( 94 | prompt="warmup", 95 | image=[Image.new("RGB", (768, 768))], 96 | control_image=[Image.new("RGB", (768, 768))], 97 | ) 98 | 99 | def control_image(self, image, canny_low_threshold, canny_high_threshold): 100 | image = np.array(image) 101 | canny = cv.Canny(image, canny_low_threshold, canny_high_threshold) 102 | return Image.fromarray(canny) 103 | 104 | def get_dimensions(self, image): 105 | original_width, original_height = image.size 106 | print( 107 | f"Original dimensions: Width: {original_width}, Height: {original_height}" 108 | ) 109 | resized_width, resized_height = self.get_resized_dimensions( 110 | original_width, original_height 111 | ) 112 | print( 113 | f"Dimensions to resize to: Width: {resized_width}, Height: {resized_height}" 114 | ) 115 | return resized_width, resized_height 116 | 117 | def get_allowed_dimensions(self, base=512, max_dim=1024): 118 | """ 119 | Function to generate allowed dimensions optimized around a base up to a max 120 | """ 121 | allowed_dimensions = [] 122 | for i in range(base, max_dim + 1, 64): 123 | for j in range(base, max_dim + 1, 64): 124 | allowed_dimensions.append((i, j)) 125 | return allowed_dimensions 126 | 127 | def get_resized_dimensions(self, width, height): 128 | """ 129 | Function adapted from Lucataco's implementation of SDXL-Controlnet for Replicate 130 | """ 131 | allowed_dimensions = self.get_allowed_dimensions() 132 | aspect_ratio = width / height 133 | print(f"Aspect Ratio: {aspect_ratio:.2f}") 134 | # Find the closest allowed dimensions that maintain the aspect ratio 135 | # and are closest to the optimum dimension of 768 136 | optimum_dimension = 768 137 | closest_dimensions = min( 138 | allowed_dimensions, 139 | key=lambda dim: abs(dim[0] / dim[1] - aspect_ratio) 140 | + abs(dim[0] - optimum_dimension), 141 | ) 142 | return closest_dimensions 143 | 144 | def resize_images(self, images, width, height): 145 | return [ 146 | img.resize((width, height)) if img is not None else None for img in images 147 | ] 148 | 149 | def open_image(self, image_path): 150 | return Image.open(str(image_path)) if image_path is not None else None 151 | 152 | def apply_sizing_strategy( 153 | self, sizing_strategy, width, height, control_image=None, image=None 154 | ): 155 | image = self.open_image(image) 156 | control_image = self.open_image(control_image) 157 | 158 | if image and image.mode == "RGBA": 159 | image = image.convert("RGB") 160 | 161 | if control_image and control_image.mode == "RGBA": 162 | control_image = control_image.convert("RGB") 163 | 164 | if sizing_strategy == "input_image": 165 | print("Resizing based on input image") 166 | width, height = self.get_dimensions(image) 167 | elif sizing_strategy == "control_image": 168 | print("Resizing based on control image") 169 | width, height = self.get_dimensions(control_image) 170 | else: 171 | print("Using given dimensions") 172 | 173 | image, control_image = self.resize_images([image, control_image], width, height) 174 | return width, height, control_image, image 175 | 176 | @torch.inference_mode() 177 | def predict( 178 | self, 179 | prompt: str = Input( 180 | description="For multiple prompts, enter each on a new line.", 181 | default="Self-portrait oil painting, a beautiful cyborg with golden hair, 8k", 182 | ), 183 | width: 
int = Input( 184 | description="Width of output image. Lower if out of memory", 185 | default=768, 186 | ), 187 | height: int = Input( 188 | description="Height of output image. Lower if out of memory", 189 | default=768, 190 | ), 191 | sizing_strategy: str = Input( 192 | description="Decide how to resize images – use width/height, resize based on input image or control image", 193 | choices=["width/height", "input_image", "control_image"], 194 | default="width/height", 195 | ), 196 | image: Path = Input( 197 | description="Input image for img2img", 198 | default=None, 199 | ), 200 | prompt_strength: float = Input( 201 | description="Prompt strength when using img2img. 1.0 corresponds to full destruction of information in image", 202 | ge=0.0, 203 | le=1.0, 204 | default=0.8, 205 | ), 206 | num_images: int = Input( 207 | description="Number of images per prompt", 208 | ge=1, 209 | le=50, 210 | default=1, 211 | ), 212 | num_inference_steps: int = Input( 213 | description="Number of denoising steps. Recommend 1 to 8 steps.", 214 | ge=1, 215 | le=50, 216 | default=8, 217 | ), 218 | guidance_scale: float = Input( 219 | description="Scale for classifier-free guidance", ge=1, le=20, default=8.0 220 | ), 221 | lcm_origin_steps: int = Input( 222 | ge=1, 223 | default=50, 224 | ), 225 | seed: int = Input( 226 | description="Random seed. Leave blank to randomize the seed", default=None 227 | ), 228 | control_image: Path = Input( 229 | description="Image for controlnet conditioning", 230 | default=None, 231 | ), 232 | controlnet_conditioning_scale: float = Input( 233 | description="Controlnet conditioning scale", 234 | ge=0.1, 235 | le=4.0, 236 | default=2.0, 237 | ), 238 | control_guidance_start: float = Input( 239 | description="Controlnet start", 240 | ge=0.0, 241 | le=1.0, 242 | default=0.0, 243 | ), 244 | control_guidance_end: float = Input( 245 | description="Controlnet end", 246 | ge=0.0, 247 | le=1.0, 248 | default=1.0, 249 | ), 250 | canny_low_threshold: float = Input( 251 | description="Canny low threshold", 252 | ge=1, 253 | le=255, 254 | default=100, 255 | ), 256 | canny_high_threshold: float = Input( 257 | description="Canny high threshold", 258 | ge=1, 259 | le=255, 260 | default=200, 261 | ), 262 | archive_outputs: bool = Input( 263 | description="Option to archive the output images", 264 | default=False, 265 | ), 266 | disable_safety_checker: bool = Input( 267 | description="Disable safety checker for generated images. 
This feature is only available through the API", 268 | default=False, 269 | ), 270 | ) -> List[Path]: 271 | """Run a single prediction on the model""" 272 | prediction_start = time.time() 273 | 274 | if seed is None: 275 | seed = int.from_bytes(os.urandom(2), "big") 276 | 277 | print(f"Using seed: {seed}") 278 | 279 | os.environ['PYTHONHASHSEED'] = str(seed) 280 | torch.manual_seed(seed) 281 | torch.cuda.manual_seed_all(seed) 282 | torch.backends.cudnn.deterministic = True 283 | 284 | prompt = prompt.strip().splitlines() 285 | if len(prompt) == 1: 286 | print("Found 1 prompt:") 287 | else: 288 | print(f"Found {len(prompt)} prompts:") 289 | for p in prompt: 290 | print(f"- {p}") 291 | 292 | if len(prompt) * num_images == 1: 293 | print("Making 1 image") 294 | else: 295 | print(f"Making {len(prompt) * num_images} images") 296 | 297 | if image or control_image: 298 | ( 299 | width, 300 | height, 301 | control_image, 302 | image, 303 | ) = self.apply_sizing_strategy( 304 | sizing_strategy, width, height, control_image, image 305 | ) 306 | 307 | kwargs = {} 308 | canny_image = None 309 | 310 | if image: 311 | kwargs["image"] = image 312 | kwargs["strength"] = prompt_strength 313 | 314 | if control_image: 315 | canny_image = self.control_image( 316 | control_image, canny_low_threshold, canny_high_threshold 317 | ) 318 | kwargs["control_guidance_start"]: control_guidance_start 319 | kwargs["control_guidance_end"]: control_guidance_end 320 | kwargs["controlnet_conditioning_scale"]: controlnet_conditioning_scale 321 | 322 | # TODO: This is a hack to get controlnet working without an image input 323 | # The current pipeline doesn't seem to support not having an image, so 324 | # we pass one in but set strength to 1 to ignore it 325 | if not image: 326 | kwargs["image"] = Image.new("RGB", (width, height), (128, 128, 128)) 327 | kwargs["strength"] = 1.0 328 | 329 | kwargs["control_image"] = canny_image 330 | 331 | mode = "controlnet" if control_image else "img2img" if image else "txt2img" 332 | print(f"{mode} mode") 333 | 334 | pipe = getattr( 335 | self, 336 | f"{mode}_pipe" if not disable_safety_checker else f"{mode}_pipe_unsafe", 337 | ) 338 | 339 | common_args = { 340 | "width": width, 341 | "height": height, 342 | "prompt": prompt, 343 | "guidance_scale": guidance_scale, 344 | "num_images_per_prompt": num_images, 345 | "num_inference_steps": num_inference_steps, 346 | "lcm_origin_steps": lcm_origin_steps, 347 | "output_type": "pil", 348 | } 349 | 350 | start = time.time() 351 | result = pipe( 352 | **common_args, 353 | **kwargs, 354 | generator=torch.Generator("cuda").manual_seed(seed), 355 | ).images 356 | print(f"Inference took: {time.time() - start:.2f}s") 357 | 358 | if archive_outputs: 359 | start = time.time() 360 | archive_start_time = datetime.datetime.now() 361 | print(f"Archiving images started at {archive_start_time}") 362 | 363 | tar_path = "/tmp/output_images.tar" 364 | with tarfile.open(tar_path, "w") as tar: 365 | for i, sample in enumerate(result): 366 | output_path = f"/tmp/out-{i}.png" 367 | sample.save(output_path) 368 | tar.add(output_path, f"out-{i}.png") 369 | 370 | print(f"Archiving took: {time.time() - start:.2f}s") 371 | return Path(tar_path) 372 | 373 | # If not archiving, or there is an error in archiving, return the paths of individual images. 
374 | output_paths = [] 375 | for i, sample in enumerate(result): 376 | output_path = f"/tmp/out-{i}.jpg" 377 | sample.save(output_path) 378 | output_paths.append(Path(output_path)) 379 | 380 | if canny_image: 381 | canny_image_path = "/tmp/canny-image.jpg" 382 | canny_image.save(canny_image_path) 383 | output_paths.append(Path(canny_image_path)) 384 | 385 | print(f"Prediction took: {time.time() - prediction_start:.2f}s") 386 | return output_paths 387 | -------------------------------------------------------------------------------- /samples.py: -------------------------------------------------------------------------------- 1 | """ 2 | A handy utility for verifying SDXL image generation locally. 3 | To set up, first run a local cog server using: 4 | cog run -p 5000 python -m cog.server.http 5 | Then, in a separate terminal, generate samples 6 | python samples.py 7 | """ 8 | 9 | import base64 10 | import os 11 | import sys 12 | import requests 13 | import glob 14 | import time 15 | 16 | def gen(output_fn, **kwargs): 17 | if glob.glob(f"{output_fn}*"): 18 | return 19 | 20 | print("Generating", output_fn) 21 | url = "http://localhost:5000/predictions" 22 | response = requests.post(url, json={"input": kwargs}) 23 | data = response.json() 24 | 25 | print(data) 26 | 27 | try: 28 | for i, datauri in enumerate(data["output"]): 29 | base64_encoded_data = datauri.split(",")[1] 30 | decoded_data = base64.b64decode(base64_encoded_data) 31 | with open( 32 | f"{output_fn.rsplit('.', 1)[0]}_{i}.{output_fn.rsplit('.', 1)[1]}", "wb" 33 | ) as f: 34 | f.write(decoded_data) 35 | except: 36 | print("Error!") 37 | print("input:", kwargs) 38 | print(data["logs"]) 39 | sys.exit(1) 40 | 41 | 42 | def main(): 43 | total_time = 0 44 | for i in range(10): 45 | start_time = time.time() 46 | gen( 47 | f"sample_{i}.png", 48 | prompt="A studio portrait photo of a cat", 49 | seed=1000, 50 | ) 51 | end_time = time.time() 52 | print(f"Time taken: {end_time - start_time} seconds") 53 | total_time += end_time - start_time 54 | average_time = total_time / 10 55 | print(f"Average time taken: {average_time} seconds") 56 | 57 | if __name__ == "__main__": 58 | main() 59 | --------------------------------------------------------------------------------
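A minimal sketch (not part of the repository) of how the gen() helper in samples.py could also exercise the img2img and controlnet inputs exposed by predict.py. The running local server from the samples.py docstring (cog run -p 5000 python -m cog.server.http) and the local file input.png are assumptions; Cog's HTTP API accepts file inputs as URLs or base64 data URIs, which is how the image is passed here.

import base64

def to_data_uri(path):
    # Encode a local image as a data URI so it can be sent as a cog file input.
    with open(path, "rb") as f:
        return "data:image/png;base64," + base64.b64encode(f.read()).decode()

# img2img: same prompt as main(), but conditioned on an input image.
gen(
    "img2img_sample.png",
    prompt="A studio portrait photo of a cat",
    image=to_data_uri("input.png"),
    prompt_strength=0.6,
    num_inference_steps=4,
    seed=1000,
)

# controlnet: canny edges are computed server-side from control_image.
gen(
    "controlnet_sample.png",
    prompt="A studio portrait photo of a cat",
    control_image=to_data_uri("input.png"),
    controlnet_conditioning_scale=2.0,
    num_inference_steps=4,
    seed=1000,
)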