├── model
│   ├── open_clip
│   │   ├── __init__.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── tokenizer.py
│   │   ├── model.py
│   │   └── transformer.py
│   ├── __init__.py
│   ├── config.py
│   ├── clip.py
│   ├── distributions.py
│   ├── gaussian_diffusion.py
│   ├── cldm.py
│   ├── util.py
│   ├── controlnet.py
│   ├── lkpn.py
│   ├── attention.py
│   └── vae.py
├── utils
│   ├── cond_fn.py
│   ├── inference.py
│   ├── common.py
│   ├── pipeline.py
│   └── sampler.py
├── configs
│   └── inference
│       ├── diffusion.yaml
│       └── cldm.yaml
├── test.sh
├── README.md
├── LICENSE
├── inference.py
└── environment.yml
--------------------------------------------------------------------------------
/model/open_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import CLIP
2 | from .tokenizer import tokenize
3 | 
4 | __all__ = ["CLIP", "tokenize"]
5 | 
--------------------------------------------------------------------------------
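These two exports feed a single consumer: `FrozenOpenCLIPEmbedder` in `model/clip.py`, which builds the CLIP text tower from the `clip_cfg` block of `configs/inference/cldm.yaml` and deletes the visual branch. A minimal sketch of that path, not a file in this repository (the prompt string is illustrative, and in practice `ControlLDM` constructs the embedder and loads the released checkpoint weights):

```python
from omegaconf import OmegaConf
from model.clip import FrozenOpenCLIPEmbedder

# clip_cfg carries embed_dim, vision_cfg, text_cfg and layer ("penultimate")
clip_cfg = OmegaConf.load("configs/inference/cldm.yaml").params.clip_cfg
encoder = FrozenOpenCLIPEmbedder(**clip_cfg)       # builds CLIP(...) and drops model.visual
z = encoder.encode(["a sharp photo of a street"])  # text embeddings of shape (1, 77, 1024)
```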
/model/open_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kkkls/DeblurDiff/HEAD/model/open_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/configs/inference/diffusion.yaml:
--------------------------------------------------------------------------------
1 | target: model.Diffusion
2 | params:
3 |   linear_start: 0.00085
4 |   linear_end: 0.0120
5 |   timesteps: 1000
6 | 
--------------------------------------------------------------------------------
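Both files under `configs/inference` follow the same `target`/`params` convention: `utils.common.instantiate_from_config` imports the class named by `target` and passes `params` as constructor keyword arguments, which is how `utils/inference.py` builds both the diffusion schedule and the `ControlLDM` wrapper. A minimal sketch, not a file in this repository (the `print` line is illustrative):

```python
from omegaconf import OmegaConf
from utils.common import instantiate_from_config

# Reads `target: model.Diffusion` and forwards `params` as kwargs, i.e. equivalent to
# model.Diffusion(linear_start=0.00085, linear_end=0.0120, timesteps=1000).
diffusion = instantiate_from_config(OmegaConf.load("configs/inference/diffusion.yaml"))
print(diffusion.num_timesteps)  # 1000
```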
/test.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=7 python -u inference.py \
2 |     --model ./checkpoint/model.pth \
3 |     --input /data0/konglingshun/dataset/Real_image/Image \
4 |     --output results/Real_image \
5 |     --device cuda
6 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from . import config
2 | 
3 | from .controlnet import ControlledUnetModel, ControlNet
4 | from .vae import AutoencoderKL
5 | from .clip import FrozenOpenCLIPEmbedder
6 | 
7 | from .cldm import ControlLDM
8 | from .gaussian_diffusion import Diffusion
9 | 
10 | # from .swinir import SwinIR
11 | # from .bsrnet import RRDBNet
12 | # from .scunet import SCUNet
13 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeblurDiff: Real-World Image Deblurring with Generative Diffusion Models
2 | [![paper](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2502.03810)
3 | 
4 | ## Dependencies
5 | 
6 | conda env create -f environment.yml
7 | 
8 | ## Checkpoint
9 | 
10 | [Download](https://drive.google.com/drive/folders/1CUtnUKbu_zTyjJ17F95UYyh2SDzCOHeW?usp=drive_link) the pretrained model and place it at `./checkpoint/model.pth`, the path expected by `test.sh`.
11 | 
12 | ## Test
13 | 
14 | bash test.sh
15 | 
16 | ## Acknowledgment
17 | 
18 | This code builds on [DiffBIR](https://github.com/XPixelGroup/DiffBIR) and [DemystifyLocalViT](https://github.com/Atten4Vis/DemystifyLocalViT). Thanks for their awesome work.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2025 kkkls
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /configs/inference/cldm.yaml: -------------------------------------------------------------------------------- 1 | target: model.ControlLDM 2 | params: 3 | latent_scale_factor: 0.18215 4 | unet_cfg: 5 | use_checkpoint: True 6 | image_size: 32 # unused 7 | in_channels: 4 8 | out_channels: 4 9 | model_channels: 320 10 | attention_resolutions: [ 4, 2, 1 ] 11 | num_res_blocks: 2 12 | channel_mult: [ 1, 2, 4, 4 ] 13 | num_head_channels: 64 # need to fix for flash-attn 14 | use_spatial_transformer: True 15 | use_linear_in_transformer: True 16 | transformer_depth: 1 17 | context_dim: 1024 18 | legacy: False 19 | vae_cfg: 20 | embed_dim: 4 21 | ddconfig: 22 | double_z: true 23 | z_channels: 4 24 | resolution: 256 25 | in_channels: 3 26 | out_ch: 3 27 | ch: 128 28 | ch_mult: 29 | - 1 30 | - 2 31 | - 4 32 | - 4 33 | num_res_blocks: 2 34 | attn_resolutions: [] 35 | dropout: 0.0 36 | clip_cfg: 37 | embed_dim: 1024 38 | vision_cfg: 39 | image_size: 224 40 | layers: 32 41 | width: 1280 42 | head_width: 80 43 | patch_size: 14 44 | text_cfg: 45 | context_length: 77 46 | vocab_size: 49408 47 | width: 1024 48 | heads: 16 49 | layers: 24 50 | layer: "penultimate" 51 | controlnet_cfg: 52 | use_checkpoint: True 53 | image_size: 32 # unused 54 | in_channels: 4 55 | hint_channels: 8 56 | model_channels: 320 57 | attention_resolutions: [ 4, 2, 1 ] 58 | num_res_blocks: 2 59 | channel_mult: [ 1, 2, 4, 4 ] 60 | num_head_channels: 64 # need to fix for flash-attn 61 | use_spatial_transformer: True 62 | use_linear_in_transformer: True 63 | transformer_depth: 1 64 | context_dim: 1024 65 | legacy: False 66 | -------------------------------------------------------------------------------- /model/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Optional, Literal 3 | from types import ModuleType 4 | import enum 5 | from packaging import version 6 | 7 | import torch 8 | 9 | # collect system information 10 | if version.parse(torch.__version__) >= version.parse("2.0.0"): 11 | SDP_IS_AVAILABLE = True 12 | else: 13 | SDP_IS_AVAILABLE = False 14 | 15 | try: 16 | import xformers 17 | import xformers.ops 18 | XFORMERS_IS_AVAILBLE = True 19 | except: 20 | XFORMERS_IS_AVAILBLE = False 21 | 22 | 23 | class AttnMode(enum.Enum): 24 | SDP = 0 25 | XFORMERS = 1 26 | VANILLA = 2 27 | 28 | 29 | class Config: 30 | xformers: Optional[ModuleType] = None 31 | attn_mode: AttnMode = AttnMode.VANILLA 32 | 33 | 34 | # initialize attention mode 35 | if SDP_IS_AVAILABLE: 36 | Config.attn_mode = AttnMode.SDP 37 | print(f"use sdp attention as default") 38 | elif XFORMERS_IS_AVAILBLE: 39 | Config.attn_mode = AttnMode.XFORMERS 40 | print(f"use xformers attention as default") 41 | else: 42 | print(f"both sdp attention and xformers are not available, use vanilla attention (very expensive) as default") 43 | 44 | if XFORMERS_IS_AVAILBLE: 45 | Config.xformers = xformers 46 | 47 | 48 | # user-specified attention mode 49 | ATTN_MODE = os.environ.get("ATTN_MODE", None) 50 | if ATTN_MODE is not None: 51 | assert ATTN_MODE in ["vanilla", "sdp", "xformers"] 52 | if ATTN_MODE == "sdp": 53 | assert SDP_IS_AVAILABLE 54 | Config.attn_mode = AttnMode.SDP 55 | elif ATTN_MODE == "xformers": 56 | assert XFORMERS_IS_AVAILBLE 57 | Config.attn_mode = AttnMode.XFORMERS 58 | else: 59 | Config.attn_mode = AttnMode.VANILLA 60 | print(f"set attention mode to {ATTN_MODE}") 61 | else: 62 | print("keep default attention 
mode") 63 | -------------------------------------------------------------------------------- /model/clip.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.checkpoint import checkpoint 5 | from model.open_clip import CLIP, tokenize 6 | 7 | ### pretrained model path 8 | # _VITH14 = dict( 9 | # laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), 10 | # ) 11 | 12 | class FrozenOpenCLIPEmbedder(nn.Module): 13 | """ 14 | Uses the OpenCLIP transformer encoder for text 15 | """ 16 | LAYERS = [ 17 | #"pooled", 18 | "last", 19 | "penultimate" 20 | ] 21 | def __init__(self, embed_dim, vision_cfg, text_cfg, layer="last"): 22 | super().__init__() 23 | assert layer in self.LAYERS 24 | # model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) 25 | model = CLIP(embed_dim, dict(vision_cfg), dict(text_cfg)) 26 | del model.visual 27 | self.model = model 28 | 29 | self.layer = layer 30 | if self.layer == "last": 31 | self.layer_idx = 0 32 | elif self.layer == "penultimate": 33 | self.layer_idx = 1 34 | else: 35 | raise NotImplementedError() 36 | 37 | def forward(self, tokens): 38 | z = self.encode_with_transformer(tokens) 39 | return z 40 | 41 | def encode_with_transformer(self, text): 42 | x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] 43 | x = x + self.model.positional_embedding 44 | x = x.permute(1, 0, 2) # NLD -> LND 45 | x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) 46 | x = x.permute(1, 0, 2) # LND -> NLD 47 | x = self.model.ln_final(x) 48 | return x 49 | 50 | def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): 51 | for i, r in enumerate(self.model.transformer.resblocks): 52 | if i == len(self.model.transformer.resblocks) - self.layer_idx: 53 | break 54 | if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): 55 | x = checkpoint(r, x, attn_mask) 56 | else: 57 | x = r(x, attn_mask=attn_mask) 58 | return x 59 | 60 | def encode(self, text: List[str]) -> torch.Tensor: 61 | # convert a batch of text to tensor 62 | tokens = tokenize(text) 63 | # move tensor to model device 64 | tokens = tokens.to(next(self.model.parameters()).device) 65 | return self(tokens) 66 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser, Namespace 2 | 3 | import torch 4 | 5 | from accelerate.utils import set_seed 6 | from utils.inference import InferenceLoop 7 | 8 | 9 | def check_device(device: str) -> str: 10 | if device == "cuda": 11 | if not torch.cuda.is_available(): 12 | print("CUDA not available because the current PyTorch install was not " 13 | "built with CUDA enabled.") 14 | device = "cpu" 15 | else: 16 | if device == "mps": 17 | if not torch.backends.mps.is_available(): 18 | if not torch.backends.mps.is_built(): 19 | print("MPS not available because the current PyTorch install was not " 20 | "built with MPS enabled.") 21 | device = "cpu" 22 | else: 23 | print("MPS not available because the current MacOS version is not 12.3+ " 24 | "and/or you do not have an MPS-enabled device on this machine.") 25 | device = "cpu" 26 | print(f"using device {device}") 27 | return device 28 | 29 | 30 | def parse_args() -> Namespace: 31 | parser = ArgumentParser() 32 | parser.add_argument("--steps", 
type=int, default=50) 33 | parser.add_argument("--better_start", action="store_true") 34 | parser.add_argument("--tiled", action="store_true") 35 | parser.add_argument("--tile_size", type=int, default=512) 36 | parser.add_argument("--tile_stride", type=int, default=256) 37 | parser.add_argument("--pos_prompt", type=str, default="") 38 | parser.add_argument("--neg_prompt", type=str, default="low quality, blurry, low-resolution, noisy, unsharp, weird textures") 39 | parser.add_argument("--cfg_scale", type=float, default=1.0) 40 | ### input parameters 41 | parser.add_argument("--input", type=str, required=True) 42 | parser.add_argument("--model", type=str, required=True) 43 | parser.add_argument("--n_samples", type=int, default=1) 44 | ### guidance parameters 45 | parser.add_argument("--guidance", action="store_true") 46 | parser.add_argument("--g_loss", type=str, default="w_mse", choices=["mse", "w_mse"]) 47 | parser.add_argument("--g_scale", type=float, default=0.0) 48 | parser.add_argument("--g_start", type=int, default=1001) 49 | parser.add_argument("--g_stop", type=int, default=-1) 50 | parser.add_argument("--g_space", type=str, default="latent") 51 | parser.add_argument("--g_repeat", type=int, default=1) 52 | ### output parameters 53 | parser.add_argument("--output", type=str, required=True) 54 | ### common parameters 55 | parser.add_argument("--seed", type=int, default=231) 56 | parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"]) 57 | 58 | return parser.parse_args() 59 | 60 | 61 | def main(): 62 | args = parse_args() 63 | args.device = check_device(args.device) 64 | set_seed(args.seed) 65 | 66 | InferenceLoop(args).run() 67 | print("done!") 68 | 69 | 70 | if __name__ == "__main__": 71 | main() 72 | -------------------------------------------------------------------------------- /model/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = 
np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /utils/cond_fn.py: -------------------------------------------------------------------------------- 1 | from typing import overload, Tuple 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | class Guidance: 7 | 8 | def __init__(self, scale: float, t_start: int, t_stop: int, space: str, repeat: int) -> "Guidance": 9 | """ 10 | Initialize restoration guidance. 11 | 12 | Args: 13 | scale (float): Gradient scale (denoted as `s` in our paper). The larger the gradient scale, 14 | the closer the final result will be to the output of the first stage model. 15 | t_start (int), t_stop (int): The timestep to start or stop guidance. Note that the sampling 16 | process starts from t=1000 to t=0, the `t_start` should be larger than `t_stop`. 17 | space (str): The data space for computing loss function (rgb or latent). 18 | 19 | Our restoration guidance is based on [GDP](https://github.com/Fayeben/GenerativeDiffusionPrior). 20 | Thanks for their work! 21 | """ 22 | self.scale = scale * 3000 23 | self.t_start = t_start 24 | self.t_stop = t_stop 25 | self.target = None 26 | self.space = space 27 | self.repeat = repeat 28 | 29 | def load_target(self, target: torch.Tensor) -> None: 30 | self.target = target 31 | 32 | def __call__(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: 33 | # avoid propagating gradient out of this scope 34 | pred_x0 = pred_x0.detach().clone() 35 | target_x0 = target_x0.detach().clone() 36 | return self._forward(target_x0, pred_x0, t) 37 | 38 | @overload 39 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: 40 | ... 
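# --- Editorial sketch, not part of cond_fn.py: how this class is wired up. ----------
# utils/inference.py builds one of the two subclasses below from the CLI flags and
# hands it to the sampling pipeline, which applies it between t_start and t_stop.
# Illustrative usage (the flag values and tensor name are made up, not defaults):
#
#     cond_fn = WeightedMSEGuidance(scale=0.1, t_start=601, t_stop=-1,
#                                   space="latent", repeat=1)
#     cond_fn.load_target(first_stage_output)           # tensor to pull the sample toward
#     grad, loss = cond_fn(cond_fn.target, pred_x0, t)  # gradient used to nudge pred_x0
# -------------------------------------------------------------------------------------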
41 | 42 | 43 | class MSEGuidance(Guidance): 44 | 45 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: 46 | # inputs: [-1, 1], nchw, rgb 47 | with torch.enable_grad(): 48 | pred_x0.requires_grad_(True) 49 | loss = (pred_x0 - target_x0).pow(2).mean((1, 2, 3)).sum() 50 | scale = self.scale 51 | g = -torch.autograd.grad(loss, pred_x0)[0] * scale 52 | return g, loss.item() 53 | 54 | 55 | class WeightedMSEGuidance(Guidance): 56 | 57 | def _get_weight(self, target: torch.Tensor) -> torch.Tensor: 58 | # convert RGB to G 59 | rgb_to_gray_kernel = torch.tensor([0.2989, 0.5870, 0.1140]).view(1, 3, 1, 1) 60 | target = torch.sum(target * rgb_to_gray_kernel.to(target.device), dim=1, keepdim=True) 61 | # initialize sobel kernel in x and y axis 62 | G_x = [ 63 | [1, 0, -1], 64 | [2, 0, -2], 65 | [1, 0, -1] 66 | ] 67 | G_y = [ 68 | [1, 2, 1], 69 | [0, 0, 0], 70 | [-1, -2, -1] 71 | ] 72 | G_x = torch.tensor(G_x, dtype=target.dtype, device=target.device)[None] 73 | G_y = torch.tensor(G_y, dtype=target.dtype, device=target.device)[None] 74 | G = torch.stack((G_x, G_y)) 75 | 76 | target = F.pad(target, (1, 1, 1, 1), mode='replicate') # padding = 1 77 | grad = F.conv2d(target, G, stride=1) 78 | mag = grad.pow(2).sum(dim=1, keepdim=True).sqrt() 79 | 80 | n, c, h, w = mag.size() 81 | block_size = 2 82 | blocks = mag.view(n, c, h // block_size, block_size, w // block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous() 83 | block_mean = blocks.sum(dim=(-2, -1), keepdim=True).tanh().repeat(1, 1, 1, 1, block_size, block_size).permute(0, 1, 2, 4, 3, 5).contiguous() 84 | block_mean = block_mean.view(n, c, h, w) 85 | weight_map = 1 - block_mean 86 | 87 | return weight_map 88 | 89 | def _forward(self, target_x0: torch.Tensor, pred_x0: torch.Tensor, t: int) -> Tuple[torch.Tensor, float]: 90 | # inputs: [-1, 1], nchw, rgb 91 | with torch.no_grad(): 92 | w = self._get_weight((target_x0 + 1) / 2) 93 | with torch.enable_grad(): 94 | pred_x0.requires_grad_(True) 95 | loss = ((pred_x0 - target_x0).pow(2) * w).mean((1, 2, 3)).sum() 96 | scale = self.scale 97 | g = -torch.autograd.grad(loss, pred_x0)[0] * scale 98 | return g, loss.item() 99 | -------------------------------------------------------------------------------- /utils/inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import overload, Generator, Dict 3 | from argparse import Namespace 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from omegaconf import OmegaConf 9 | 10 | from model.cldm import ControlLDM 11 | from model.gaussian_diffusion import Diffusion 12 | 13 | from utils.common import instantiate_from_config, load_file_from_url, count_vram_usage 14 | from utils.pipeline import ( 15 | Pipeline, 16 | bicubic_resize 17 | ) 18 | from utils.cond_fn import MSEGuidance, WeightedMSEGuidance 19 | import torch 20 | 21 | class InferenceLoop: 22 | 23 | def __init__(self, args: Namespace) -> "InferenceLoop": 24 | self.args = args 25 | self.loop_ctx = {} 26 | self.pipeline: Pipeline = None 27 | self.init_model() 28 | self.init_cond_fn() 29 | self.init_pipeline() 30 | 31 | 32 | @count_vram_usage 33 | def init_model(self) -> None: 34 | 35 | self.cldm: ControlLDM = instantiate_from_config(OmegaConf.load("configs/inference/cldm.yaml")) 36 | self.cldm.load_state_dict(torch.load(self.args.model)) 37 | self.cldm.eval().to(self.args.device) 38 | ### load diffusion 39 | self.diffusion: Diffusion = 
instantiate_from_config(OmegaConf.load("configs/inference/diffusion.yaml")) 40 | self.diffusion.to(self.args.device) 41 | 42 | def init_cond_fn(self) -> None: 43 | if not self.args.guidance: 44 | self.cond_fn = None 45 | return 46 | if self.args.g_loss == "mse": 47 | cond_fn_cls = MSEGuidance 48 | elif self.args.g_loss == "w_mse": 49 | cond_fn_cls = WeightedMSEGuidance 50 | else: 51 | raise ValueError(self.args.g_loss) 52 | self.cond_fn = cond_fn_cls( 53 | scale=self.args.g_scale, t_start=self.args.g_start, t_stop=self.args.g_stop, 54 | space=self.args.g_space, repeat=self.args.g_repeat 55 | ) 56 | 57 | 58 | def init_pipeline(self) -> None: 59 | self.pipeline = Pipeline(self.cldm, self.diffusion, self.cond_fn, self.args.device) 60 | 61 | 62 | def setup(self) -> None: 63 | self.output_dir = self.args.output 64 | os.makedirs(self.output_dir, exist_ok=True) 65 | 66 | def lq_loader(self) -> Generator[np.ndarray, None, None]: 67 | img_exts = [".png", ".jpg", ".jpeg",".PNG"] 68 | if os.path.isdir(self.args.input): 69 | file_names = sorted([ 70 | file_name for file_name in os.listdir(self.args.input) if os.path.splitext(file_name)[-1] in img_exts 71 | ]) 72 | file_paths = [os.path.join(self.args.input, file_name) for file_name in file_names] 73 | else: 74 | assert os.path.splitext(self.args.input)[-1] in img_exts 75 | file_paths = [self.args.input] 76 | 77 | def _loader() -> Generator[np.ndarray, None, None]: 78 | for file_path in file_paths: 79 | ### load lq 80 | lq = np.array(Image.open(file_path).convert("RGB")) 81 | print(f"load lq: {file_path}") 82 | ### set context for saving results 83 | self.loop_ctx["file_stem"] = os.path.splitext(os.path.basename(file_path))[0] 84 | for i in range(self.args.n_samples): 85 | self.loop_ctx["repeat_idx"] = i 86 | yield lq 87 | 88 | return _loader 89 | 90 | def after_load_lq(self, lq: np.ndarray) -> np.ndarray: 91 | return lq 92 | 93 | @torch.no_grad() 94 | def run(self) -> None: 95 | self.setup() 96 | loader = self.lq_loader() 97 | for lq in loader(): 98 | lq = self.after_load_lq(lq) 99 | sample = self.pipeline.run( 100 | lq[None], self.args.steps, 1.0, self.args.tiled, 101 | self.args.tile_size, self.args.tile_stride, 102 | self.args.pos_prompt, self.args.neg_prompt, self.args.cfg_scale, 103 | self.args.better_start 104 | )[0] 105 | self.save(sample) 106 | 107 | def save(self, sample: np.ndarray) -> None: 108 | file_stem, repeat_idx = self.loop_ctx["file_stem"], self.loop_ctx["repeat_idx"] 109 | file_name = f"{file_stem}_{repeat_idx}.png" if self.args.n_samples > 1 else f"{file_stem}.png" 110 | save_path = os.path.join(self.args.output, file_name) 111 | Image.fromarray(sample).save(save_path) 112 | print(f"save result to {save_path}") 113 | 114 | 115 | 116 | -------------------------------------------------------------------------------- /model/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Tuple 3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | 8 | 9 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 10 | if schedule == "linear": 11 | betas = ( 12 | np.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=np.float64) ** 2 13 | ) 14 | 15 | elif schedule == "cosine": 16 | timesteps = ( 17 | np.arange(n_timestep + 1, dtype=np.float64) / n_timestep + cosine_s 18 | ) 19 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 20 | alphas = np.cos(alphas).pow(2) 21 | 
alphas = alphas / alphas[0] 22 | betas = 1 - alphas[1:] / alphas[:-1] 23 | betas = np.clip(betas, a_min=0, a_max=0.999) 24 | 25 | elif schedule == "sqrt_linear": 26 | betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) 27 | elif schedule == "sqrt": 28 | betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64) ** 0.5 29 | else: 30 | raise ValueError(f"schedule '{schedule}' unknown.") 31 | return betas 32 | 33 | 34 | def extract_into_tensor(a: torch.Tensor, t: torch.Tensor, x_shape: Tuple[int]) -> torch.Tensor: 35 | b, *_ = t.shape 36 | out = a.gather(-1, t) 37 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 38 | 39 | 40 | class Diffusion(nn.Module): 41 | 42 | def __init__( 43 | self, 44 | timesteps=1000, 45 | beta_schedule="linear", 46 | loss_type="l2", 47 | linear_start=1e-4, 48 | linear_end=2e-2, 49 | cosine_s=8e-3, 50 | parameterization="eps" 51 | ): 52 | super().__init__() 53 | self.num_timesteps = timesteps 54 | self.beta_schedule = beta_schedule 55 | self.linear_start = linear_start 56 | self.linear_end = linear_end 57 | self.cosine_s = cosine_s 58 | assert parameterization in ["eps", "x0", "v"], "currently only supporting 'eps' and 'x0' and 'v'" 59 | self.parameterization = parameterization 60 | self.loss_type = loss_type 61 | 62 | betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end, 63 | cosine_s=cosine_s) 64 | alphas = 1. - betas 65 | alphas_cumprod = np.cumprod(alphas, axis=0) 66 | sqrt_alphas_cumprod = np.sqrt(alphas_cumprod) 67 | sqrt_one_minus_alphas_cumprod = np.sqrt(1. - alphas_cumprod) 68 | 69 | self.betas = betas 70 | self.register("sqrt_alphas_cumprod", sqrt_alphas_cumprod) 71 | self.register("sqrt_one_minus_alphas_cumprod", sqrt_one_minus_alphas_cumprod) 72 | 73 | def register(self, name: str, value: np.ndarray) -> None: 74 | self.register_buffer(name, torch.tensor(value, dtype=torch.float32)) 75 | 76 | def q_sample(self, x_start, t, noise): 77 | return ( 78 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + 79 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 80 | ) 81 | 82 | def get_v(self, x, noise, t): 83 | return ( 84 | extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise - 85 | extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x 86 | ) 87 | 88 | def get_loss(self, pred, target, mean=True): 89 | if self.loss_type == 'l1': 90 | loss = (target - pred).abs() 91 | if mean: 92 | loss = loss.mean() 93 | elif self.loss_type == 'l2': 94 | if mean: 95 | loss = torch.nn.functional.mse_loss(target, pred) 96 | else: 97 | loss = torch.nn.functional.mse_loss(target, pred, reduction='none') 98 | else: 99 | raise NotImplementedError("unknown loss type '{loss_type}'") 100 | 101 | return loss 102 | 103 | def p_losses(self, model, x_start, t, cond): 104 | noise = torch.randn_like(x_start) 105 | x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) 106 | model_output,lr_kpn = model(x_noisy, t, cond) 107 | 108 | if self.parameterization == "x0": 109 | target = x_start 110 | elif self.parameterization == "eps": 111 | target = noise 112 | elif self.parameterization == "v": 113 | target = self.get_v(x_start, noise, t) 114 | else: 115 | raise NotImplementedError() 116 | 117 | loss_simple = self.get_loss(model_output, target, mean=False).mean() 118 | loss_kpn = self.get_loss(lr_kpn, x_start, mean=False).mean() 119 | loss = loss_kpn+loss_simple 120 | return loss 121 | 
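# --- Editorial sketch, not part of gaussian_diffusion.py. q_sample implements the ---
# standard DDPM forward process,
#     x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise,
# which p_losses uses to noise clean latents during training. Illustrative usage
# (the latent shape is made up):
#
#     import torch
#     from model.gaussian_diffusion import Diffusion
#
#     diffusion = Diffusion(timesteps=1000, linear_start=0.00085, linear_end=0.0120)
#     x0 = torch.randn(1, 4, 32, 32)
#     t = torch.randint(0, diffusion.num_timesteps, (1,))
#     xt = diffusion.q_sample(x_start=x0, t=t, noise=torch.randn_like(x0))
# -------------------------------------------------------------------------------------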
-------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: sd3-kpn 2 | channels: 3 | - conda-forge 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=conda_forge 8 | - _openmp_mutex=4.5=2_gnu 9 | - bzip2=1.0.8=h4bc722e_7 10 | - ca-certificates=2025.1.31=hbcca054_0 11 | - cuda-version=11.8=h70ddcb2_3 12 | - cudatoolkit=11.8.0=h4ba93d1_13 13 | - cupy=12.3.0=py38h7b7cd4b_2 14 | - fastrlock=0.8.2=py38h17151c0_2 15 | - ld_impl_linux-64=2.43=h712a8e2_1 16 | - libblas=3.9.0=28_h59b9bed_openblas 17 | - libcblas=3.9.0=28_he106b2a_openblas 18 | - libffi=3.4.2=h7f98852_5 19 | - libgcc=14.1.0=h77fa898_1 20 | - libgcc-ng=14.1.0=h69a702a_1 21 | - libgfortran=14.1.0=h69a702a_1 22 | - libgfortran-ng=14.1.0=h69a702a_1 23 | - libgfortran5=14.1.0=hc5f4f2c_1 24 | - libgomp=14.1.0=h77fa898_1 25 | - liblapack=3.9.0=28_h7ac8fdf_openblas 26 | - libnsl=2.0.1=hd590300_0 27 | - libopenblas=0.3.28=pthreads_h94d23a6_0 28 | - libsqlite=3.46.1=hadc24fc_0 29 | - libstdcxx=14.1.0=hc0a3c3a_1 30 | - libstdcxx-ng=14.1.0=h4852527_1 31 | - libuuid=2.38.1=h0b41bf4_0 32 | - libxcrypt=4.4.36=hd590300_1 33 | - libzlib=1.3.1=h4ab18f5_1 34 | - ncurses=6.5=he02047a_1 35 | - numpy=1.24.4=py38h59b608b_0 36 | - openssl=3.4.0=h7b32b05_1 37 | - pip=24.2=pyh8b19718_1 38 | - python=3.8.20=h4a871b0_2_cpython 39 | - python_abi=3.8=5_cp38 40 | - readline=8.2=h8228510_1 41 | - setuptools=75.1.0=pyhd8ed1ab_0 42 | - tk=8.6.13=noxft_h4845f30_101 43 | - wheel=0.44.0=pyhd8ed1ab_0 44 | - xz=5.2.6=h166bdaf_0 45 | - pip: 46 | - accelerate==0.34.2 47 | - aiofiles==23.2.1 48 | - aiohttp==3.9.5 49 | - aiosignal==1.3.1 50 | - alembic==1.13.3 51 | - altair==5.3.0 52 | - antlr4-python3-runtime==4.9.3 53 | - anyio==4.3.0 54 | - async-timeout==4.0.3 55 | - attrs==23.2.0 56 | - banal==1.0.6 57 | - certifi==2024.2.2 58 | - charset-normalizer==3.3.2 59 | - click==8.1.7 60 | - clip==0.2.0 61 | - contourpy==1.1.1 62 | - cycler==0.12.1 63 | - dataset==1.6.2 64 | - diffusers==0.31.0 65 | - einops==0.7.0 66 | - exceptiongroup==1.2.1 67 | - fairscale==0.4.13 68 | - fastapi==0.110.2 69 | - ffmpy==0.3.2 70 | - filelock==3.13.4 71 | - fonttools==4.51.0 72 | - frozenlist==1.4.1 73 | - fsspec==2024.3.1 74 | - ftfy==6.2.0 75 | - gradio==3.24.0 76 | - gradio-client==0.15.1 77 | - greenlet==3.1.1 78 | - h11==0.14.0 79 | - httpcore==1.0.5 80 | - httpx==0.27.0 81 | - huggingface-hub==0.25.1 82 | - idna==3.7 83 | - importlib-metadata==7.1.0 84 | - importlib-resources==6.4.0 85 | - jinja2==3.1.3 86 | - jsonschema==4.21.1 87 | - jsonschema-specifications==2023.12.1 88 | - kiwisolver==1.4.5 89 | - lightning-utilities==0.11.2 90 | - linkify-it-py==2.0.3 91 | - loralib==0.1.2 92 | - mako==1.3.5 93 | - markdown-it-py==2.2.0 94 | - markupsafe==2.1.5 95 | - matplotlib==3.7.5 96 | - mdit-py-plugins==0.3.3 97 | - mdurl==0.1.2 98 | - mpmath==1.3.0 99 | - multidict==6.0.5 100 | - networkx==3.1 101 | - nvidia-cublas-cu12==12.1.3.1 102 | - nvidia-cuda-cupti-cu12==12.1.105 103 | - nvidia-cuda-nvrtc-cu12==12.1.105 104 | - nvidia-cuda-runtime-cu12==12.1.105 105 | - nvidia-cudnn-cu12==8.9.2.26 106 | - nvidia-cufft-cu12==11.0.2.54 107 | - nvidia-curand-cu12==10.3.2.106 108 | - nvidia-cusolver-cu12==11.4.5.107 109 | - nvidia-cusparse-cu12==12.1.0.106 110 | - nvidia-nccl-cu12==2.18.1 111 | - nvidia-nvjitlink-cu12==12.4.127 112 | - nvidia-nvtx-cu12==12.1.105 113 | - omegaconf==2.3.0 114 | - 
open-clip-torch==2.24.0 115 | - openai-clip==1.0.1 116 | - opencv-python==4.9.0.80 117 | - orjson==3.10.1 118 | - packaging==24.0 119 | - pandas==2.0.3 120 | - peft==0.13.2 121 | - pillow==10.3.0 122 | - pkgutil-resolve-name==1.3.10 123 | - protobuf==5.26.1 124 | - psutil==5.9.8 125 | - pydantic==1.10.11 126 | - pydub==0.25.1 127 | - pyparsing==3.1.2 128 | - python-dateutil==2.9.0.post0 129 | - python-multipart==0.0.9 130 | - pytorch-lightning==2.1.3 131 | - pytz==2024.1 132 | - pyyaml==6.0.1 133 | - referencing==0.34.0 134 | - regex==2024.4.16 135 | - requests==2.31.0 136 | - rpds-py==0.18.0 137 | - safetensors==0.4.3 138 | - scipy==1.10.1 139 | - semantic-version==2.10.0 140 | - sentencepiece==0.2.0 141 | - six==1.16.0 142 | - sniffio==1.3.1 143 | - sqlalchemy==1.4.54 144 | - starlette==0.37.2 145 | - sympy==1.12 146 | - timm==0.9.16 147 | - tokenizers==0.20.0 148 | - toolz==0.12.1 149 | - torch==2.1.2 150 | - torchmetrics==1.3.2 151 | - torchvision==0.16.2 152 | - tqdm==4.66.2 153 | - transformers==4.45.1 154 | - triton==2.1.0 155 | - typing-extensions==4.11.0 156 | - tzdata==2024.1 157 | - uc-micro-py==1.0.3 158 | - urllib3==2.2.1 159 | - uvicorn==0.29.0 160 | - wcwidth==0.2.13 161 | - websockets==11.0.3 162 | - xformers==0.0.23.post1 163 | - yarl==1.9.4 164 | - zipp==3.18.1 165 | prefix: /home/konglingshun/miniconda3/envs/sd3-kpn 166 | -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping, Any, Tuple, Callable 2 | import importlib 3 | import os 4 | from urllib.parse import urlparse 5 | 6 | import torch 7 | from torch import Tensor 8 | from torch.nn import functional as F 9 | import numpy as np 10 | 11 | from torch.hub import download_url_to_file, get_dir 12 | 13 | 14 | def get_obj_from_str(string: str, reload: bool=False) -> Any: 15 | module, cls = string.rsplit(".", 1) 16 | if reload: 17 | module_imp = importlib.import_module(module) 18 | importlib.reload(module_imp) 19 | return getattr(importlib.import_module(module, package=None), cls) 20 | 21 | 22 | def instantiate_from_config(config: Mapping[str, Any]) -> Any: 23 | if not "target" in config: 24 | raise KeyError("Expected key `target` to instantiate.") 25 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 26 | 27 | 28 | def wavelet_blur(image: Tensor, radius: int): 29 | """ 30 | Apply wavelet blur to the input tensor. 31 | """ 32 | # input shape: (1, 3, H, W) 33 | # convolution kernel 34 | kernel_vals = [ 35 | [0.0625, 0.125, 0.0625], 36 | [0.125, 0.25, 0.125], 37 | [0.0625, 0.125, 0.0625], 38 | ] 39 | kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) 40 | # add channel dimensions to the kernel to make it a 4D tensor 41 | kernel = kernel[None, None] 42 | # repeat the kernel across all input channels 43 | kernel = kernel.repeat(3, 1, 1, 1) 44 | image = F.pad(image, (radius, radius, radius, radius), mode='replicate') 45 | # apply convolution 46 | output = F.conv2d(image, kernel, groups=3, dilation=radius) 47 | return output 48 | 49 | 50 | def wavelet_decomposition(image: Tensor, levels=5): 51 | """ 52 | Apply wavelet decomposition to the input tensor. 53 | This function only returns the low frequency & the high frequency. 
54 | """ 55 | high_freq = torch.zeros_like(image) 56 | for i in range(levels): 57 | radius = 2 ** i 58 | low_freq = wavelet_blur(image, radius) 59 | high_freq += (image - low_freq) 60 | image = low_freq 61 | 62 | return high_freq, low_freq 63 | 64 | 65 | def wavelet_reconstruction(content_feat:Tensor, style_feat:Tensor): 66 | """ 67 | Apply wavelet decomposition, so that the content will have the same color as the style. 68 | """ 69 | # calculate the wavelet decomposition of the content feature 70 | content_high_freq, content_low_freq = wavelet_decomposition(content_feat) 71 | del content_low_freq 72 | # calculate the wavelet decomposition of the style feature 73 | style_high_freq, style_low_freq = wavelet_decomposition(style_feat) 74 | del style_high_freq 75 | # reconstruct the content feature with the style's high frequency 76 | return content_high_freq + style_low_freq 77 | 78 | 79 | # https://github.com/XPixelGroup/BasicSR/blob/master/basicsr/utils/download_util.py/ 80 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 81 | """Load file form http url, will download models if necessary. 82 | 83 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 84 | 85 | Args: 86 | url (str): URL to be downloaded. 87 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 88 | Default: None. 89 | progress (bool): Whether to show the download progress. Default: True. 90 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 91 | 92 | Returns: 93 | str: The path to the downloaded file. 94 | """ 95 | if model_dir is None: # use the pytorch hub_dir 96 | hub_dir = get_dir() 97 | model_dir = os.path.join(hub_dir, 'checkpoints') 98 | 99 | os.makedirs(model_dir, exist_ok=True) 100 | 101 | parts = urlparse(url) 102 | filename = os.path.basename(parts.path) 103 | if file_name is not None: 104 | filename = file_name 105 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 106 | if not os.path.exists(cached_file): 107 | print(f'Downloading: "{url}" to {cached_file}\n') 108 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 109 | return cached_file 110 | 111 | 112 | def sliding_windows(h: int, w: int, tile_size: int, tile_stride: int) -> Tuple[int, int, int, int]: 113 | hi_list = list(range(0, h - tile_size + 1, tile_stride)) 114 | if (h - tile_size) % tile_stride != 0: 115 | hi_list.append(h - tile_size) 116 | 117 | wi_list = list(range(0, w - tile_size + 1, tile_stride)) 118 | if (w - tile_size) % tile_stride != 0: 119 | wi_list.append(w - tile_size) 120 | 121 | coords = [] 122 | for hi in hi_list: 123 | for wi in wi_list: 124 | coords.append((hi, hi + tile_size, wi, wi + tile_size)) 125 | return coords 126 | 127 | 128 | # https://github.com/csslc/CCSR/blob/main/model/q_sampler.py#L503 129 | def gaussian_weights(tile_width: int, tile_height: int) -> np.ndarray: 130 | """Generates a gaussian mask of weights for tile contributions""" 131 | latent_width = tile_width 132 | latent_height = tile_height 133 | var = 0.01 134 | midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 135 | x_probs = [ 136 | np.exp(-(x - midpoint) * (x - midpoint) / (latent_width * latent_width) / (2 * var)) / np.sqrt(2 * np.pi * var) 137 | for x in range(latent_width)] 138 | midpoint = latent_height / 2 139 | y_probs = [ 140 | np.exp(-(y - midpoint) * (y - midpoint) / (latent_height * latent_height) / (2 * var)) / 
np.sqrt(2 * np.pi * var) 141 | for y in range(latent_height)] 142 | weights = np.outer(y_probs, x_probs) 143 | return weights 144 | 145 | 146 | COUNT_VRAM = bool(os.environ.get("COUNT_VRAM", False)) 147 | 148 | def count_vram_usage(func: Callable) -> Callable: 149 | if not COUNT_VRAM: 150 | return func 151 | 152 | def wrapper(*args, **kwargs): 153 | peak_before = torch.cuda.max_memory_allocated() / (1024 ** 3) 154 | ret = func(*args, **kwargs) 155 | torch.cuda.synchronize() 156 | peak_after = torch.cuda.max_memory_allocated() / (1024 ** 3) 157 | print(f"VRAM peak before {func.__name__}: {peak_before:.5f} GB, after: {peak_after:.5f} GB") 158 | return ret 159 | return wrapper 160 | -------------------------------------------------------------------------------- /model/cldm.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Set, List, Dict 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from model import ( 7 | ControlledUnetModel, ControlNet, 8 | AutoencoderKL, FrozenOpenCLIPEmbedder 9 | ) 10 | from utils.common import sliding_windows, count_vram_usage, gaussian_weights 11 | 12 | from model.lkpn import LKPN 13 | 14 | 15 | def disabled_train(self: nn.Module) -> nn.Module: 16 | """Overwrite model.train with this function to make sure train/eval mode 17 | does not change anymore.""" 18 | return self 19 | 20 | 21 | class ControlLDM(nn.Module): 22 | 23 | def __init__( 24 | self, 25 | unet_cfg, 26 | vae_cfg, 27 | clip_cfg, 28 | controlnet_cfg, 29 | latent_scale_factor 30 | ): 31 | super().__init__() 32 | self.unet = ControlledUnetModel(**unet_cfg) 33 | self.kpn = LKPN() 34 | self.vae = AutoencoderKL(**vae_cfg) 35 | self.clip = FrozenOpenCLIPEmbedder(**clip_cfg) 36 | self.controlnet = ControlNet(**controlnet_cfg) 37 | self.scale_factor = latent_scale_factor 38 | self.control_scales = [1.0] * 13 39 | 40 | @torch.no_grad() 41 | def load_pretrained_sd(self, sd: Dict[str, torch.Tensor]) -> Set[str]: 42 | module_map = { 43 | "unet": "model.diffusion_model", 44 | "vae": "first_stage_model", 45 | "clip": "cond_stage_model", 46 | } 47 | modules = [("unet", self.unet), ("vae", self.vae), ("clip", self.clip)] 48 | used = set() 49 | for name, module in modules: 50 | init_sd = {} 51 | scratch_sd = module.state_dict() 52 | for key in scratch_sd: 53 | target_key = ".".join([module_map[name], key]) 54 | init_sd[key] = sd[target_key].clone() 55 | used.add(target_key) 56 | module.load_state_dict(init_sd, strict=True) 57 | unused = set(sd.keys()) - used 58 | # NOTE: this is slightly different from previous version, which haven't switched 59 | # the UNet to eval mode and disabled the requires_grad flag. 
60 | for module in [self.vae, self.clip, self.unet]: 61 | module.eval() 62 | module.train = disabled_train 63 | for p in module.parameters(): 64 | p.requires_grad = False 65 | return unused 66 | 67 | @torch.no_grad() 68 | def load_controlnet_from_ckpt(self, sd: Dict[str, torch.Tensor]) -> None: 69 | self.controlnet.load_state_dict(sd, strict=True) 70 | 71 | @torch.no_grad() 72 | def load_controlnet_from_unet(self) -> Tuple[Set[str]]: 73 | unet_sd = self.unet.state_dict() 74 | scratch_sd = self.controlnet.state_dict() 75 | init_sd = {} 76 | init_with_new_zero = set() 77 | init_with_scratch = set() 78 | for key in scratch_sd: 79 | if key in unet_sd: 80 | this, target = scratch_sd[key], unet_sd[key] 81 | if this.size() == target.size(): 82 | init_sd[key] = target.clone() 83 | else: 84 | d_ic = this.size(1) - target.size(1) 85 | oc, _, h, w = this.size() 86 | zeros = torch.zeros((oc, d_ic, h, w), dtype=target.dtype) 87 | init_sd[key] = torch.cat((target, zeros), dim=1) 88 | init_with_new_zero.add(key) 89 | else: 90 | init_sd[key] = scratch_sd[key].clone() 91 | init_with_scratch.add(key) 92 | self.controlnet.load_state_dict(init_sd, strict=True) 93 | return init_with_new_zero, init_with_scratch 94 | 95 | def vae_encode(self, image: torch.Tensor, sample: bool = True) -> torch.Tensor: 96 | if sample: 97 | return self.vae.encode(image).sample() * self.scale_factor 98 | else: 99 | return self.vae.encode(image).mode() * self.scale_factor 100 | 101 | def vae_encode_tiled(self, image: torch.Tensor, tile_size: int, tile_stride: int, 102 | sample: bool = True) -> torch.Tensor: 103 | bs, _, h, w = image.shape 104 | z = torch.zeros((bs, 4, h // 8, w // 8), dtype=torch.float32, device=image.device) 105 | count = torch.zeros_like(z, dtype=torch.float32) 106 | weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None] 107 | weights = torch.tensor(weights, dtype=torch.float32, device=image.device) 108 | tiles = sliding_windows(h // 8, w // 8, tile_size // 8, tile_stride // 8) 109 | # print(tiles) 110 | for hi, hi_end, wi, wi_end in tiles: 111 | tile_image = image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] 112 | z[:, :, hi:hi_end, wi:wi_end] = self.vae_encode(tile_image, sample=sample) 113 | 114 | return z 115 | 116 | def vae_decode(self, z: torch.Tensor) -> torch.Tensor: 117 | return self.vae.decode(z / self.scale_factor) 118 | 119 | @count_vram_usage 120 | def vae_decode_tiled(self, z: torch.Tensor, tile_size: int, tile_stride: int) -> torch.Tensor: 121 | bs, _, h, w = z.shape 122 | image = torch.zeros((bs, 3, h * 8, w * 8), dtype=torch.float32, device=z.device) 123 | count = torch.zeros_like(image, dtype=torch.float32) 124 | weights = gaussian_weights(tile_size * 8, tile_size * 8)[None, None] 125 | weights = torch.tensor(weights, dtype=torch.float32, device=z.device) 126 | tiles = sliding_windows(h, w, tile_size, tile_stride) 127 | for hi, hi_end, wi, wi_end in tiles: 128 | tile_z = z[:, :, hi:hi_end, wi:wi_end] 129 | image[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += self.vae_decode(tile_z) * weights 130 | count[:, :, hi * 8:hi_end * 8, wi * 8:wi_end * 8] += weights 131 | image.div_(count) 132 | return image 133 | 134 | def prepare_condition(self, clean: torch.Tensor, txt: List[str]) -> Dict[str, torch.Tensor]: 135 | return dict( 136 | c_txt=self.clip.encode(txt), 137 | c_img=self.vae_encode(clean, sample=True) 138 | ) 139 | 140 | @count_vram_usage 141 | def prepare_condition_tiled(self, clean: torch.Tensor, txt: List[str], tile_size: int, tile_stride: int) -> Dict[ 142 | str, torch.Tensor]: 
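# Tiled variant of prepare_condition above: the text prompt is encoded as usual, while the clean
# image is encoded into latent space tile-by-tile via vae_encode_tiled so large inputs do not
# exhaust VRAM; note that this path uses the mode of the VAE posterior (sample=False) instead of
# drawing a random sample.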
143 | return dict( 144 | c_txt=self.clip.encode(txt), 145 | c_img=self.vae_encode_tiled(clean, tile_size, tile_stride, sample=False) 146 | ) 147 | 148 | def forward(self, x_noisy, t, cond): 149 | c_txt = cond["c_txt"] 150 | c_img = cond["c_img"] 151 | 152 | c_img_kpn, x_noisy_kpn = c_img.contiguous(),x_noisy.contiguous() 153 | lr_kpn = self.kpn(c_img_kpn, x_noisy_kpn, t, c_txt) 154 | cond = torch.cat((c_img,lr_kpn),dim=1) 155 | control = self.controlnet( 156 | x=x_noisy, hint=cond, 157 | timesteps=t, context=c_txt 158 | ) 159 | control = [c * scale for c, scale in zip(control, self.control_scales)] 160 | eps = self.unet( 161 | x=x_noisy, timesteps=t, 162 | context=c_txt, control=control, only_mid_control=False 163 | ) 164 | 165 | return eps,lr_kpn 166 | -------------------------------------------------------------------------------- /utils/pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import overload, Tuple, Optional 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | import numpy as np 7 | from PIL import Image 8 | from einops import rearrange 9 | 10 | from model.cldm import ControlLDM 11 | from model.gaussian_diffusion import Diffusion 12 | from utils.sampler import SpacedSampler 13 | from utils.cond_fn import Guidance 14 | from utils.common import wavelet_decomposition, wavelet_reconstruction, count_vram_usage 15 | from torch import Tensor 16 | 17 | def calc_mean_std(feat: Tensor, eps=1e-5): 18 | """Calculate mean and std for adaptive_instance_normalization. 19 | Args: 20 | feat (Tensor): 4D tensor. 21 | eps (float): A small value added to the variance to avoid 22 | divide-by-zero. Default: 1e-5. 23 | """ 24 | size = feat.size() 25 | assert len(size) == 4, 'The input feature should be 4D tensor.' 26 | b, c = size[:2] 27 | feat_var = feat.reshape(b, c, -1).var(dim=2) + eps 28 | feat_std = feat_var.sqrt().reshape(b, c, 1, 1) 29 | feat_mean = feat.reshape(b, c, -1).mean(dim=2).reshape(b, c, 1, 1) 30 | return feat_mean, feat_std 31 | 32 | def adaptive_instance_normalization(content_feat:Tensor, style_feat:Tensor): 33 | """Adaptive instance normalization. 34 | Adjust the reference features to have the similar color and illuminations 35 | as those in the degradate features. 36 | Args: 37 | content_feat (Tensor): The reference feature. 38 | style_feat (Tensor): The degradate features. 
39 | """ 40 | size = content_feat.size() 41 | style_mean, style_std = calc_mean_std(style_feat) 42 | content_mean, content_std = calc_mean_std(content_feat) 43 | normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) 44 | return normalized_feat * style_std.expand(size) + style_mean.expand(size) 45 | 46 | def bicubic_resize(img: np.ndarray, scale: float) -> np.ndarray: 47 | pil = Image.fromarray(img) 48 | res = pil.resize(tuple(int(x * scale) for x in pil.size), Image.BICUBIC) 49 | return np.array(res) 50 | 51 | 52 | def resize_short_edge_to(imgs: torch.Tensor, size: int) -> torch.Tensor: 53 | _, _, h, w = imgs.size() 54 | if h == w: 55 | new_h, new_w = size, size 56 | elif h < w: 57 | new_h, new_w = size, int(w * (size / h)) 58 | else: 59 | new_h, new_w = int(h * (size / w)), size 60 | return F.interpolate(imgs, size=(new_h, new_w), mode="bicubic", antialias=True) 61 | 62 | 63 | def pad_to_multiples_of(imgs: torch.Tensor, multiple: int) -> torch.Tensor: 64 | _, _, h, w = imgs.size() 65 | if h % multiple == 0 and w % multiple == 0: 66 | return imgs.clone() 67 | # get_pad = lambda x: (x // multiple + 1) * multiple - x 68 | get_pad = lambda x: (x // multiple + int(x % multiple != 0)) * multiple - x 69 | ph, pw = get_pad(h), get_pad(w) 70 | return F.pad(imgs, pad=(0, pw, 0, ph), mode="constant", value=0) 71 | 72 | 73 | class Pipeline: 74 | 75 | def __init__(self, cldm: ControlLDM, diffusion: Diffusion, cond_fn: Optional[Guidance], device: str) -> None: 76 | # self.stage1_model = stage1_model 77 | self.cldm = cldm 78 | self.diffusion = diffusion 79 | self.cond_fn = cond_fn 80 | self.device = device 81 | self.final_size: Tuple[int] = None 82 | 83 | def set_final_size(self, lq: torch.Tensor) -> None: 84 | h, w = lq.shape[2:] 85 | self.final_size = (h, w) 86 | 87 | 88 | @count_vram_usage 89 | def run_diff( 90 | self, 91 | clean: torch.Tensor, 92 | steps: int, 93 | strength: float, 94 | tiled: bool, 95 | tile_size: int, 96 | tile_stride: int, 97 | pos_prompt: str, 98 | neg_prompt: str, 99 | cfg_scale: float, 100 | better_start: float 101 | ) -> torch.Tensor: 102 | ### preprocess 103 | bs, _, ori_h, ori_w = clean.shape 104 | # pad: ensure that height & width are multiples of 64 105 | pad_clean = pad_to_multiples_of(clean, multiple=64) 106 | h, w = pad_clean.shape[2:] 107 | # prepare conditon 108 | if not tiled: 109 | cond = self.cldm.prepare_condition(pad_clean, [pos_prompt] * bs) 110 | uncond = self.cldm.prepare_condition(pad_clean, [neg_prompt] * bs) 111 | else: 112 | cond = self.cldm.prepare_condition_tiled(pad_clean, [pos_prompt] * bs, tile_size, tile_stride) 113 | uncond = self.cldm.prepare_condition_tiled(pad_clean, [neg_prompt] * bs, tile_size, tile_stride) 114 | if self.cond_fn: 115 | self.cond_fn.load_target(pad_clean * 2 - 1) 116 | old_control_scales = self.cldm.control_scales 117 | self.cldm.control_scales = [strength] * 13 118 | if better_start: 119 | # using noised low frequency part of condition as a better start point of 120 | # reverse sampling, which can prevent our model from generating noise in 121 | # image background. 
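# Concretely: the low-frequency band of the padded condition image is encoded into latent space
# (optionally tile-by-tile) and then noised to the final diffusion timestep with q_sample; the
# result is used as x_T instead of a pure Gaussian sample.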
122 | _, low_freq = wavelet_decomposition(pad_clean) 123 | if not tiled: 124 | x_0 = self.cldm.vae_encode(low_freq) 125 | else: 126 | x_0 = self.cldm.vae_encode_tiled(low_freq, tile_size, tile_stride) 127 | x_T = self.diffusion.q_sample( 128 | x_0, 129 | torch.full((bs, ), self.diffusion.num_timesteps - 1, dtype=torch.long, device=self.device), 130 | torch.randn(x_0.shape, dtype=torch.float32, device=self.device) 131 | ) 132 | # print(f"diffusion sqrt_alphas_cumprod: {self.diffusion.sqrt_alphas_cumprod[-1]}") 133 | else: 134 | x_T = torch.randn((bs, 4, h // 8, w // 8), dtype=torch.float32, device=self.device) 135 | ### run sampler 136 | sampler = SpacedSampler(self.diffusion.betas) 137 | z = sampler.sample( 138 | model=self.cldm, device=self.device, steps=steps, batch_size=bs, x_size=(4, h // 8, w // 8), 139 | cond=cond, uncond=uncond, cfg_scale=cfg_scale, x_T=x_T, progress=True, 140 | progress_leave=True, cond_fn=self.cond_fn, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride 141 | ) 142 | if not tiled: 143 | x = self.cldm.vae_decode(z) 144 | else: 145 | x = self.cldm.vae_decode_tiled(z, tile_size // 8, tile_stride // 8) 146 | ### postprocess 147 | self.cldm.control_scales = old_control_scales 148 | sample = x[:, :, :ori_h, :ori_w] 149 | return sample 150 | 151 | @torch.no_grad() 152 | def run( 153 | self, 154 | lq: np.ndarray, 155 | steps: int, 156 | strength: float, 157 | tiled: bool, 158 | tile_size: int, 159 | tile_stride: int, 160 | pos_prompt: str, 161 | neg_prompt: str, 162 | cfg_scale: float, 163 | better_start: bool 164 | ) -> np.ndarray: 165 | # image to tensor 166 | lq = torch.tensor((lq / 255.).clip(0, 1), dtype=torch.float32, device=self.device) 167 | lq = rearrange(lq, "n h w c -> n c h w").contiguous() 168 | # set pipeline output size 169 | self.set_final_size(lq) 170 | clean = lq 171 | sample = self.run_diff( 172 | clean, steps, strength, tiled, tile_size, tile_stride, 173 | pos_prompt, neg_prompt, cfg_scale, better_start 174 | ) 175 | # colorfix (borrowed from StableSR, thanks for their work) 176 | sample = (sample + 1) / 2 177 | sample = adaptive_instance_normalization(sample, clean) 178 | 179 | sample = rearrange(sample * 255., "n c h w -> n h w c") 180 | sample = sample.contiguous().clamp(0, 255).to(torch.uint8).cpu().numpy() 181 | return sample 182 | 183 | 184 | -------------------------------------------------------------------------------- /model/open_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 
32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['', ''] 85 | else: 86 | special_tokens = ['', ''] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 | 139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text = whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b 
in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens]) 149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | def decode(output_ids: torch.Tensor): 156 | output_ids = output_ids.cpu().numpy() 157 | return _tokenizer.decode(output_ids) 158 | 159 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 160 | """ 161 | Returns the tokenized representation of given input string(s) 162 | 163 | Parameters 164 | ---------- 165 | texts : Union[str, List[str]] 166 | An input string or a list of input strings to tokenize 167 | context_length : int 168 | The context length to use; all CLIP models use 77 as the context length 169 | 170 | Returns 171 | ------- 172 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 173 | """ 174 | if isinstance(texts, str): 175 | texts = [texts] 176 | 177 | sot_token = _tokenizer.encoder[""] 178 | eot_token = _tokenizer.encoder[""] 179 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 180 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 181 | 182 | for i, tokens in enumerate(all_tokens): 183 | if len(tokens) > context_length: 184 | tokens = tokens[:context_length] # Truncate 185 | tokens[-1] = eot_token 186 | result[i, :len(tokens)] = torch.tensor(tokens) 187 | 188 | return result 189 | 190 | 191 | class HFTokenizer: 192 | """HuggingFace tokenizer wrapper""" 193 | 194 | def __init__(self, tokenizer_name: str): 195 | from transformers import AutoTokenizer 196 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 197 | 198 | def save_pretrained(self, dest): 199 | self.tokenizer.save_pretrained(dest) 200 | 201 | def __call__(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.Tensor: 202 | # same cleaning as for default tokenizer, except lowercasing 203 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 207 | input_ids = self.tokenizer( 208 | texts, 209 | return_tensors='pt', 210 | max_length=context_length, 211 | padding='max_length', 212 | truncation=True, 213 | ).input_ids 214 | return input_ids 215 | -------------------------------------------------------------------------------- /model/open_clip/model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
4 | """ 5 | from dataclasses import dataclass 6 | from typing import Optional, Tuple, Union 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn 12 | 13 | from .transformer import LayerNormFp32, LayerNorm, QuickGELU, VisionTransformer, TextTransformer 14 | 15 | 16 | @dataclass 17 | class CLIPVisionCfg: 18 | layers: Union[Tuple[int, int, int, int], int] = 12 19 | width: int = 768 20 | head_width: int = 64 21 | mlp_ratio: float = 4.0 22 | patch_size: int = 16 23 | image_size: Union[Tuple[int, int], int] = 224 24 | 25 | ls_init_value: Optional[float] = None # layer scale initial value 26 | patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results 27 | input_patchnorm: bool = False # whether to use dual patchnorm - would only apply the input layernorm on each patch, as post-layernorm already exist in original clip vit design 28 | global_average_pool: bool = False # whether to global average pool the last embedding layer, instead of using CLS token (https://arxiv.org/abs/2205.01580) 29 | attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer 30 | n_queries: int = 256 # n_queries for attentional pooler 31 | attn_pooler_heads: int = 8 # n heads for attentional_pooling 32 | output_tokens: bool = False 33 | 34 | timm_model_name: str = None # a valid model name overrides layers, width, patch_size 35 | timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model 36 | timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') 37 | timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') 38 | timm_proj_bias: bool = False # enable bias final projection 39 | timm_drop: float = 0. # head dropout 40 | timm_drop_path: Optional[float] = None # backbone stochastic depth 41 | 42 | 43 | @dataclass 44 | class CLIPTextCfg: 45 | context_length: int = 77 46 | vocab_size: int = 49408 47 | width: int = 512 48 | heads: int = 8 49 | layers: int = 12 50 | ls_init_value: Optional[float] = None # layer scale initial value 51 | hf_model_name: str = None 52 | hf_tokenizer_name: str = None 53 | hf_model_pretrained: bool = True 54 | proj: str = 'mlp' 55 | pooler_type: str = 'mean_pooler' 56 | embed_cls: bool = False 57 | pad_id: int = 0 58 | output_tokens: bool = False 59 | 60 | 61 | def get_cast_dtype(precision: str): 62 | cast_dtype = None 63 | if precision == 'bf16': 64 | cast_dtype = torch.bfloat16 65 | elif precision == 'fp16': 66 | cast_dtype = torch.float16 67 | return cast_dtype 68 | 69 | 70 | def _build_vision_tower( 71 | embed_dim: int, 72 | vision_cfg: CLIPVisionCfg, 73 | quick_gelu: bool = False, 74 | cast_dtype: Optional[torch.dtype] = None 75 | ): 76 | if isinstance(vision_cfg, dict): 77 | vision_cfg = CLIPVisionCfg(**vision_cfg) 78 | 79 | # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more 80 | # memory efficient in recent PyTorch releases (>= 1.10). 81 | # NOTE: timm models always use native GELU regardless of quick_gelu flag. 
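# QuickGELU approximates GELU as x * sigmoid(1.702 * x); it is only needed to reproduce the
# original OpenAI CLIP weights, which is why nn.GELU is used whenever quick_gelu is False.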
82 | act_layer = QuickGELU if quick_gelu else nn.GELU 83 | 84 | vision_heads = vision_cfg.width // vision_cfg.head_width 85 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 86 | visual = VisionTransformer( 87 | image_size=vision_cfg.image_size, 88 | patch_size=vision_cfg.patch_size, 89 | width=vision_cfg.width, 90 | layers=vision_cfg.layers, 91 | heads=vision_heads, 92 | mlp_ratio=vision_cfg.mlp_ratio, 93 | ls_init_value=vision_cfg.ls_init_value, 94 | patch_dropout=vision_cfg.patch_dropout, 95 | input_patchnorm=vision_cfg.input_patchnorm, 96 | global_average_pool=vision_cfg.global_average_pool, 97 | attentional_pool=vision_cfg.attentional_pool, 98 | n_queries=vision_cfg.n_queries, 99 | attn_pooler_heads=vision_cfg.attn_pooler_heads, 100 | output_tokens=vision_cfg.output_tokens, 101 | output_dim=embed_dim, 102 | act_layer=act_layer, 103 | norm_layer=norm_layer, 104 | ) 105 | 106 | return visual 107 | 108 | 109 | def _build_text_tower( 110 | embed_dim: int, 111 | text_cfg: CLIPTextCfg, 112 | quick_gelu: bool = False, 113 | cast_dtype: Optional[torch.dtype] = None, 114 | ): 115 | if isinstance(text_cfg, dict): 116 | text_cfg = CLIPTextCfg(**text_cfg) 117 | 118 | act_layer = QuickGELU if quick_gelu else nn.GELU 119 | norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm 120 | 121 | text = TextTransformer( 122 | context_length=text_cfg.context_length, 123 | vocab_size=text_cfg.vocab_size, 124 | width=text_cfg.width, 125 | heads=text_cfg.heads, 126 | layers=text_cfg.layers, 127 | ls_init_value=text_cfg.ls_init_value, 128 | output_dim=embed_dim, 129 | embed_cls=text_cfg.embed_cls, 130 | output_tokens=text_cfg.output_tokens, 131 | pad_id=text_cfg.pad_id, 132 | act_layer=act_layer, 133 | norm_layer=norm_layer, 134 | ) 135 | return text 136 | 137 | 138 | class CLIP(nn.Module): 139 | output_dict: torch.jit.Final[bool] 140 | 141 | def __init__( 142 | self, 143 | embed_dim: int, 144 | vision_cfg: CLIPVisionCfg, 145 | text_cfg: CLIPTextCfg, 146 | quick_gelu: bool = False, 147 | cast_dtype: Optional[torch.dtype] = None, 148 | output_dict: bool = False, 149 | ): 150 | super().__init__() 151 | self.output_dict = output_dict 152 | self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) 153 | 154 | text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) 155 | self.transformer = text.transformer 156 | self.context_length = text.context_length 157 | self.vocab_size = text.vocab_size 158 | self.token_embedding = text.token_embedding 159 | self.positional_embedding = text.positional_embedding 160 | self.ln_final = text.ln_final 161 | self.text_projection = text.text_projection 162 | self.register_buffer('attn_mask', text.attn_mask, persistent=False) 163 | 164 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 165 | 166 | def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): 167 | # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 168 | self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) 169 | 170 | @torch.jit.ignore 171 | def set_grad_checkpointing(self, enable=True): 172 | self.visual.set_grad_checkpointing(enable) 173 | self.transformer.grad_checkpointing = enable 174 | 175 | def encode_image(self, image, normalize: bool = False): 176 | features = self.visual(image) 177 | return F.normalize(features, dim=-1) if normalize else features 178 | 179 | def encode_text(self, text, normalize: bool = False): 180 | 
cast_dtype = self.transformer.get_cast_dtype() 181 | 182 | x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] 183 | 184 | x = x + self.positional_embedding.to(cast_dtype) 185 | x = x.permute(1, 0, 2) # NLD -> LND 186 | x = self.transformer(x, attn_mask=self.attn_mask) 187 | x = x.permute(1, 0, 2) # LND -> NLD 188 | x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] 189 | # take features from the eot embedding (eot_token is the highest number in each sequence) 190 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 191 | return F.normalize(x, dim=-1) if normalize else x 192 | 193 | def forward( 194 | self, 195 | image: Optional[torch.Tensor] = None, 196 | text: Optional[torch.Tensor] = None, 197 | ): 198 | image_features = self.encode_image(image, normalize=True) if image is not None else None 199 | text_features = self.encode_text(text, normalize=True) if text is not None else None 200 | if self.output_dict: 201 | return { 202 | "image_features": image_features, 203 | "text_features": text_features, 204 | "logit_scale": self.logit_scale.exp() 205 | } 206 | return image_features, text_features, self.logit_scale.exp() 207 | -------------------------------------------------------------------------------- /model/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | from inspect import isfunction 14 | import torch 15 | import torch.nn as nn 16 | import numpy as np 17 | from einops import repeat 18 | 19 | 20 | def exists(val): 21 | return val is not None 22 | 23 | 24 | def default(val, d): 25 | if exists(val): 26 | return val 27 | return d() if isfunction(d) else d 28 | 29 | 30 | def checkpoint(func, inputs, params, flag): 31 | """ 32 | Evaluate a function without caching intermediate activations, allowing for 33 | reduced memory at the expense of extra compute in the backward pass. 34 | :param func: the function to evaluate. 35 | :param inputs: the argument sequence to pass to `func`. 36 | :param params: a sequence of parameters `func` depends on but does not 37 | explicitly take as arguments. 38 | :param flag: if False, disable gradient checkpointing. 
39 | """ 40 | if flag: 41 | args = tuple(inputs) + tuple(params) 42 | return CheckpointFunction.apply(func, len(inputs), *args) 43 | else: 44 | return func(*inputs) 45 | 46 | 47 | # class CheckpointFunction(torch.autograd.Function): 48 | # @staticmethod 49 | # def forward(ctx, run_function, length, *args): 50 | # ctx.run_function = run_function 51 | # ctx.input_tensors = list(args[:length]) 52 | # ctx.input_params = list(args[length:]) 53 | # ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 54 | # "dtype": torch.get_autocast_gpu_dtype(), 55 | # "cache_enabled": torch.is_autocast_cache_enabled()} 56 | # with torch.no_grad(): 57 | # output_tensors = ctx.run_function(*ctx.input_tensors) 58 | # return output_tensors 59 | 60 | # @staticmethod 61 | # def backward(ctx, *output_grads): 62 | # ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 63 | # with torch.enable_grad(), \ 64 | # torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 65 | # # Fixes a bug where the first op in run_function modifies the 66 | # # Tensor storage in place, which is not allowed for detach()'d 67 | # # Tensors. 68 | # shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 69 | # output_tensors = ctx.run_function(*shallow_copies) 70 | # input_grads = torch.autograd.grad( 71 | # output_tensors, 72 | # ctx.input_tensors + ctx.input_params, 73 | # output_grads, 74 | # allow_unused=True, 75 | # ) 76 | # del ctx.input_tensors 77 | # del ctx.input_params 78 | # del output_tensors 79 | # return (None, None) + input_grads 80 | 81 | 82 | # Fixes: When we set unet parameters with requires_grad=False, the original CheckpointFunction 83 | # still tries to compute gradient for unet parameters. 84 | # https://discuss.pytorch.org/t/get-runtimeerror-one-of-the-differentiated-tensors-does-not-require-grad-in-pytorch-lightning/179738/6 85 | class CheckpointFunction(torch.autograd.Function): 86 | @staticmethod 87 | def forward(ctx, run_function, length, *args): 88 | ctx.run_function = run_function 89 | ctx.input_tensors = list(args[:length]) 90 | ctx.input_params = list(args[length:]) 91 | ctx.gpu_autocast_kwargs = {"enabled": torch.is_autocast_enabled(), 92 | "dtype": torch.get_autocast_gpu_dtype(), 93 | "cache_enabled": torch.is_autocast_cache_enabled()} 94 | with torch.no_grad(): 95 | output_tensors = ctx.run_function(*ctx.input_tensors) 96 | return output_tensors 97 | 98 | @staticmethod 99 | def backward(ctx, *output_grads): 100 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 101 | with torch.enable_grad(), \ 102 | torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs): 103 | # Fixes a bug where the first op in run_function modifies the 104 | # Tensor storage in place, which is not allowed for detach()'d 105 | # Tensors. 
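# Unlike the commented-out implementation above, gradients are requested only for the parameters
# that actually have requires_grad=True; the loop further below re-inserts None placeholders so
# the returned tuple still lines up with (input_tensors + input_params).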
106 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 107 | output_tensors = ctx.run_function(*shallow_copies) 108 | grads = torch.autograd.grad( 109 | output_tensors, 110 | ctx.input_tensors + [x for x in ctx.input_params if x.requires_grad], 111 | output_grads, 112 | allow_unused=True, 113 | ) 114 | grads = list(grads) 115 | # Assign gradients to the correct positions, matching None for those that do not require gradients 116 | input_grads = [] 117 | for tensor in ctx.input_tensors + ctx.input_params: 118 | if tensor.requires_grad: 119 | input_grads.append(grads.pop(0)) # Get the next computed gradient 120 | else: 121 | input_grads.append(None) # No gradient required for this tensor 122 | del ctx.input_tensors 123 | del ctx.input_params 124 | del output_tensors 125 | return (None, None) + tuple(input_grads) 126 | 127 | 128 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 129 | """ 130 | Create sinusoidal timestep embeddings. 131 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 132 | These may be fractional. 133 | :param dim: the dimension of the output. 134 | :param max_period: controls the minimum frequency of the embeddings. 135 | :return: an [N x dim] Tensor of positional embeddings. 136 | """ 137 | if not repeat_only: 138 | half = dim // 2 139 | freqs = torch.exp( 140 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 141 | ).to(device=timesteps.device) 142 | args = timesteps[:, None].float() * freqs[None] 143 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 144 | if dim % 2: 145 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 146 | else: 147 | embedding = repeat(timesteps, 'b -> b d', d=dim) 148 | return embedding 149 | 150 | 151 | def zero_module(module): 152 | """ 153 | Zero out the parameters of a module and return it. 154 | """ 155 | for p in module.parameters(): 156 | p.detach().zero_() 157 | return module 158 | 159 | 160 | def scale_module(module, scale): 161 | """ 162 | Scale the parameters of a module and return it. 163 | """ 164 | for p in module.parameters(): 165 | p.detach().mul_(scale) 166 | return module 167 | 168 | 169 | def mean_flat(tensor): 170 | """ 171 | Take the mean over all non-batch dimensions. 172 | """ 173 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 174 | 175 | 176 | def normalization(channels): 177 | """ 178 | Make a standard normalization layer. 179 | :param channels: number of input channels. 180 | :return: an nn.Module for normalization. 181 | """ 182 | return GroupNorm32(32, channels) 183 | 184 | 185 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 186 | class SiLU(nn.Module): 187 | def forward(self, x): 188 | return x * torch.sigmoid(x) 189 | 190 | 191 | class GroupNorm32(nn.GroupNorm): 192 | def forward(self, x): 193 | return super().forward(x.float()).type(x.dtype) 194 | 195 | def conv_nd(dims, *args, **kwargs): 196 | """ 197 | Create a 1D, 2D, or 3D convolution module. 198 | """ 199 | if dims == 1: 200 | return nn.Conv1d(*args, **kwargs) 201 | elif dims == 2: 202 | return nn.Conv2d(*args, **kwargs) 203 | elif dims == 3: 204 | return nn.Conv3d(*args, **kwargs) 205 | raise ValueError(f"unsupported dimensions: {dims}") 206 | 207 | 208 | def linear(*args, **kwargs): 209 | """ 210 | Create a linear module. 211 | """ 212 | return nn.Linear(*args, **kwargs) 213 | 214 | 215 | def avg_pool_nd(dims, *args, **kwargs): 216 | """ 217 | Create a 1D, 2D, or 3D average pooling module. 
218 | """ 219 | if dims == 1: 220 | return nn.AvgPool1d(*args, **kwargs) 221 | elif dims == 2: 222 | return nn.AvgPool2d(*args, **kwargs) 223 | elif dims == 3: 224 | return nn.AvgPool3d(*args, **kwargs) 225 | raise ValueError(f"unsupported dimensions: {dims}") 226 | -------------------------------------------------------------------------------- /model/controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch as th 3 | import torch.nn as nn 4 | 5 | from model.util import ( 6 | conv_nd, 7 | linear, 8 | zero_module, 9 | timestep_embedding, 10 | exists 11 | ) 12 | from model.attention import SpatialTransformer 13 | from model.unet import ( 14 | TimestepEmbedSequential, ResBlock, Downsample, AttentionBlock, UNetModel 15 | ) 16 | 17 | 18 | class ControlledUnetModel(UNetModel): 19 | 20 | def forward(self, x, timesteps=None, context=None, control=None, only_mid_control=False, **kwargs): 21 | hs = [] 22 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) 23 | emb = self.time_embed(t_emb) 24 | h = x.type(self.dtype) 25 | for module in self.input_blocks: 26 | h = module(h, emb, context) 27 | hs.append(h) 28 | h = self.middle_block(h, emb, context) 29 | 30 | if control is not None: 31 | h += control.pop() 32 | 33 | for i, module in enumerate(self.output_blocks): 34 | if only_mid_control or control is None: 35 | h = torch.cat([h, hs.pop()], dim=1) 36 | else: 37 | h = torch.cat([h, hs.pop() + control.pop()], dim=1) 38 | h = module(h, emb, context) 39 | 40 | h = h.type(x.dtype) 41 | return self.out(h) 42 | 43 | 44 | class ControlNet(nn.Module): 45 | 46 | def __init__( 47 | self, 48 | image_size, 49 | in_channels, 50 | model_channels, 51 | hint_channels, 52 | num_res_blocks, 53 | attention_resolutions, 54 | dropout=0, 55 | channel_mult=(1, 2, 4, 8), 56 | conv_resample=True, 57 | dims=2, 58 | use_checkpoint=False, 59 | use_fp16=False, 60 | num_heads=-1, 61 | num_head_channels=-1, 62 | num_heads_upsample=-1, 63 | use_scale_shift_norm=False, 64 | resblock_updown=False, 65 | use_new_attention_order=False, 66 | use_spatial_transformer=False, # custom transformer support 67 | transformer_depth=1, # custom transformer support 68 | context_dim=None, # custom transformer support 69 | n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model 70 | legacy=True, 71 | disable_self_attentions=None, 72 | num_attention_blocks=None, 73 | disable_middle_self_attn=False, 74 | use_linear_in_transformer=False, 75 | ): 76 | super().__init__() 77 | if use_spatial_transformer: 78 | assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...' 79 | 80 | if context_dim is not None: 81 | assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...' 
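# Note: this ControlNet mirrors the encoder half of the UNet but, unlike the reference ControlNet
# implementation, it has no separate hint encoder. forward() simply concatenates hint with x along
# the channel dimension, which is why the first conv below takes in_channels + hint_channels input
# channels, and why load_controlnet_from_unet (cldm.py) pads that conv's weight with
# zero-initialized extra input channels.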
82 | from omegaconf.listconfig import ListConfig 83 | if type(context_dim) == ListConfig: 84 | context_dim = list(context_dim) 85 | 86 | if num_heads_upsample == -1: 87 | num_heads_upsample = num_heads 88 | 89 | if num_heads == -1: 90 | assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set' 91 | 92 | if num_head_channels == -1: 93 | assert num_heads != -1, 'Either num_heads or num_head_channels has to be set' 94 | 95 | self.dims = dims 96 | self.image_size = image_size 97 | self.in_channels = in_channels 98 | self.model_channels = model_channels 99 | if isinstance(num_res_blocks, int): 100 | self.num_res_blocks = len(channel_mult) * [num_res_blocks] 101 | else: 102 | if len(num_res_blocks) != len(channel_mult): 103 | raise ValueError("provide num_res_blocks either as an int (globally constant) or " 104 | "as a list/tuple (per-level) with the same length as channel_mult") 105 | self.num_res_blocks = num_res_blocks 106 | if disable_self_attentions is not None: 107 | # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not 108 | assert len(disable_self_attentions) == len(channel_mult) 109 | if num_attention_blocks is not None: 110 | assert len(num_attention_blocks) == len(self.num_res_blocks) 111 | assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks)))) 112 | print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. " 113 | f"This option has LESS priority than attention_resolutions {attention_resolutions}, " 114 | f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, " 115 | f"attention will still not be set.") 116 | 117 | self.attention_resolutions = attention_resolutions 118 | self.dropout = dropout 119 | self.channel_mult = channel_mult 120 | self.conv_resample = conv_resample 121 | self.use_checkpoint = use_checkpoint 122 | self.dtype = th.float16 if use_fp16 else th.float32 123 | self.num_heads = num_heads 124 | self.num_head_channels = num_head_channels 125 | self.num_heads_upsample = num_heads_upsample 126 | self.predict_codebook_ids = n_embed is not None 127 | 128 | time_embed_dim = model_channels * 4 129 | self.time_embed = nn.Sequential( 130 | linear(model_channels, time_embed_dim), 131 | nn.SiLU(), 132 | linear(time_embed_dim, time_embed_dim), 133 | ) 134 | 135 | self.input_blocks = nn.ModuleList( 136 | [ 137 | TimestepEmbedSequential( 138 | conv_nd(dims, in_channels + hint_channels, model_channels, 3, padding=1) 139 | ) 140 | ] 141 | ) 142 | self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)]) 143 | 144 | self._feature_size = model_channels 145 | input_block_chans = [model_channels] 146 | ch = model_channels 147 | ds = 1 148 | k=64 149 | for level, mult in enumerate(channel_mult): 150 | for nr in range(self.num_res_blocks[level]): 151 | layers = [ 152 | ResBlock( 153 | ch, 154 | time_embed_dim, 155 | dropout, 156 | out_channels=mult * model_channels, 157 | dims=dims, 158 | use_checkpoint=use_checkpoint, 159 | use_scale_shift_norm=use_scale_shift_norm, 160 | ) 161 | ] 162 | ch = mult * model_channels 163 | if ds in attention_resolutions: 164 | if num_head_channels == -1: 165 | dim_head = ch // num_heads 166 | else: 167 | num_heads = ch // num_head_channels 168 | dim_head = num_head_channels 169 | if legacy: 170 | # num_heads = 1 171 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 172 | if exists(disable_self_attentions): 173 | disabled_sa 
= disable_self_attentions[level] 174 | else: 175 | disabled_sa = False 176 | 177 | if not exists(num_attention_blocks) or nr < num_attention_blocks[level]: 178 | layers.append( 179 | AttentionBlock( 180 | ch, 181 | use_checkpoint=use_checkpoint, 182 | num_heads=num_heads, 183 | num_head_channels=dim_head, 184 | use_new_attention_order=use_new_attention_order, 185 | ) if not use_spatial_transformer else SpatialTransformer( 186 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 187 | disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer, 188 | use_checkpoint=use_checkpoint,kernel_size=(64,64) 189 | ) 190 | ) 191 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 192 | self.zero_convs.append(self.make_zero_conv(ch)) 193 | self._feature_size += ch 194 | input_block_chans.append(ch) 195 | if level != len(channel_mult) - 1: 196 | out_ch = ch 197 | self.input_blocks.append( 198 | TimestepEmbedSequential( 199 | ResBlock( 200 | ch, 201 | time_embed_dim, 202 | dropout, 203 | out_channels=out_ch, 204 | dims=dims, 205 | use_checkpoint=use_checkpoint, 206 | use_scale_shift_norm=use_scale_shift_norm, 207 | down=True, 208 | ) 209 | if resblock_updown 210 | else Downsample( 211 | ch, conv_resample, dims=dims, out_channels=out_ch 212 | ) 213 | ) 214 | ) 215 | ch = out_ch 216 | input_block_chans.append(ch) 217 | self.zero_convs.append(self.make_zero_conv(ch)) 218 | ds *= 2 219 | k = k//2 220 | self._feature_size += ch 221 | 222 | if num_head_channels == -1: 223 | dim_head = ch // num_heads 224 | else: 225 | num_heads = ch // num_head_channels 226 | dim_head = num_head_channels 227 | if legacy: 228 | # num_heads = 1 229 | dim_head = ch // num_heads if use_spatial_transformer else num_head_channels 230 | self.middle_block = TimestepEmbedSequential( 231 | ResBlock( 232 | ch, 233 | time_embed_dim, 234 | dropout, 235 | dims=dims, 236 | use_checkpoint=use_checkpoint, 237 | use_scale_shift_norm=use_scale_shift_norm, 238 | ), 239 | AttentionBlock( 240 | ch, 241 | use_checkpoint=use_checkpoint, 242 | num_heads=num_heads, 243 | num_head_channels=dim_head, 244 | use_new_attention_order=use_new_attention_order, 245 | ) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn 246 | ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim, 247 | disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer, 248 | use_checkpoint=use_checkpoint,kernel_size=(64,64) 249 | ), 250 | ResBlock( 251 | ch, 252 | time_embed_dim, 253 | dropout, 254 | dims=dims, 255 | use_checkpoint=use_checkpoint, 256 | use_scale_shift_norm=use_scale_shift_norm, 257 | ), 258 | ) 259 | self.middle_block_out = self.make_zero_conv(ch) 260 | self._feature_size += ch 261 | 262 | def make_zero_conv(self, channels): 263 | return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0))) 264 | 265 | def forward(self, x, hint, timesteps, context, **kwargs): 266 | t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) 267 | emb = self.time_embed(t_emb) 268 | x = torch.cat((x, hint), dim=1) 269 | outs = [] 270 | 271 | h = x.type(self.dtype) 272 | for module, zero_conv in zip(self.input_blocks, self.zero_convs): 273 | h = module(h, emb, context) 274 | outs.append(zero_conv(h, emb, context)) 275 | 276 | h = self.middle_block(h, emb, context) 277 | outs.append(self.middle_block_out(h, emb, context)) 278 | 279 | return outs 280 | 
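Taken together, ControlledUnetModel and ControlNet implement the usual ControlNet wiring: the control branch emits one zero-convolved feature per encoder block plus one for the middle block (13 tensors, matching the 13 entries of control_scales in cldm.py); the last entry is added to the UNet middle feature, and the rest are popped in reverse and added to the skip connections during decoding. The snippet below is a minimal, shape-only sketch of that bookkeeping with dummy tensors; the channel list and spatial size are illustrative assumptions and are not taken from this repository's configs.

import torch

# Channel sizes per encoder skip; purely illustrative, not read from the repo's cldm.yaml.
enc_channels = [320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280]
hs = [torch.zeros(1, c, 8, 8) for c in enc_channels]       # UNet encoder skip features
control = [torch.zeros(1, c, 8, 8) for c in enc_channels]  # per-block zero-conv outputs
control.append(torch.zeros(1, 1280, 8, 8))                 # middle_block_out residual -> 13 total
control = [c * s for c, s in zip(control, [1.0] * 13)]     # ControlLDM.control_scales

h = torch.zeros(1, 1280, 8, 8)                             # UNet middle-block feature
h = h + control.pop()                                      # inject the middle residual

for _ in range(len(enc_channels)):                         # one iteration per output block
    skip = hs.pop() + control.pop()                        # residual added onto the skip...
    h = torch.cat([h, skip], dim=1)                        # ...then concatenated, as in the UNet
    h = h[:, :1280]                                        # stand-in for the real output block

print("all control residuals consumed:", len(control) == 0)

Because every zero conv is created through make_zero_conv (which wraps zero_module from model/util.py), all 13 residuals are exactly zero at initialization, so the controlled UNet initially reproduces the frozen Stable Diffusion UNet.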
-------------------------------------------------------------------------------- /model/lkpn.py: -------------------------------------------------------------------------------- 1 | # ----------------------------------------------------------------------------------- 2 | # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 3 | # Originally Written by Ze Liu, Modified by Jingyun Liang. 4 | # ----------------------------------------------------------------------------------- 5 | 6 | # Originally borrowed from DifFace (https://github.com/zsyOAOA/DifFace/blob/master/models/swinir.py) 7 | 8 | import math 9 | from typing import Set 10 | 11 | import torch.nn.functional as F 12 | 13 | from model import AutoencoderKL 14 | from torch.autograd import Function 15 | import torch 16 | from torch.nn.modules.utils import _pair 17 | import torch.nn as nn 18 | 19 | from collections import namedtuple 20 | import cupy # idynamic implement is based on cupy-cuda 21 | from string import Template 22 | from omegaconf import OmegaConf 23 | from model.unet import UNetModel 24 | 25 | Stream = namedtuple('Stream', ['ptr']) 26 | 27 | 28 | def Dtype(t): 29 | if isinstance(t, torch.cuda.FloatTensor): 30 | return 'float' 31 | elif isinstance(t, torch.cuda.DoubleTensor): 32 | return 'double' 33 | 34 | 35 | @cupy._util.memoize(for_each_device=True) 36 | def load_kernel(kernel_name, code, **kwargs): 37 | code = Template(code).substitute(**kwargs) 38 | # kernel_code = cupy.cuda.compile_with_cache(code) 39 | return cupy.RawKernel(code, kernel_name) 40 | # return kernel_code.get_function(kernel_name) 41 | 42 | 43 | CUDA_NUM_THREADS = 512 44 | # if you use in 3090 and above, please set 1024 for the fastest calculation 45 | # CUDA_NUM_THREADS = 1024 # FIXME: cuda 46 | 47 | 48 | kernel_loop = ''' 49 | #define CUDA_KERNEL_LOOP(i, n) \ 50 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ 51 | i < (n); \ 52 | i += blockDim.x * gridDim.x) 53 | ''' 54 | 55 | 56 | def GET_BLOCKS(N): 57 | return (N + CUDA_NUM_THREADS - 1) // CUDA_NUM_THREADS 58 | 59 | 60 | _idynamic_kernel = kernel_loop + ''' 61 | extern "C" 62 | __global__ void idynamic_forward_kernel( 63 | const ${Dtype}* bottom_data, const ${Dtype}* weight_data, ${Dtype}* top_data) { 64 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 65 | const int n = index / ${channels} / ${top_height} / ${top_width}; 66 | const int c = (index / ${top_height} / ${top_width}) % ${channels}; 67 | const int h = (index / ${top_width}) % ${top_height}; 68 | const int w = index % ${top_width}; 69 | const int g = c / (${channels} / ${groups}); 70 | ${Dtype} value = 0; 71 | #pragma unroll 72 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 73 | #pragma unroll 74 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 75 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 76 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 77 | if ((h_in >= 0) && (h_in < ${bottom_height}) 78 | && (w_in >= 0) && (w_in < ${bottom_width})) { 79 | const int offset = ((n * ${channels} + c) * ${bottom_height} + h_in) 80 | * ${bottom_width} + w_in; 81 | const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h) 82 | * ${top_width} + w; 83 | value += weight_data[offset_weight] * bottom_data[offset]; 84 | } 85 | } 86 | } 87 | top_data[index] = value; 88 | } 89 | } 90 | ''' 91 | 92 | _idynamic_kernel_backward_grad_input = kernel_loop + ''' 93 | extern "C" 94 | __global__ void idynamic_backward_grad_input_kernel( 95 | const ${Dtype}* 
const top_diff, const ${Dtype}* const weight_data, ${Dtype}* const bottom_diff) { 96 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 97 | const int n = index / ${channels} / ${bottom_height} / ${bottom_width}; 98 | const int c = (index / ${bottom_height} / ${bottom_width}) % ${channels}; 99 | const int h = (index / ${bottom_width}) % ${bottom_height}; 100 | const int w = index % ${bottom_width}; 101 | const int g = c / (${channels} / ${groups}); 102 | ${Dtype} value = 0; 103 | #pragma unroll 104 | for (int kh = 0; kh < ${kernel_h}; ++kh) { 105 | #pragma unroll 106 | for (int kw = 0; kw < ${kernel_w}; ++kw) { 107 | const int h_out_s = h + ${pad_h} - kh * ${dilation_h}; 108 | const int w_out_s = w + ${pad_w} - kw * ${dilation_w}; 109 | if (((h_out_s % ${stride_h}) == 0) && ((w_out_s % ${stride_w}) == 0)) { 110 | const int h_out = h_out_s / ${stride_h}; 111 | const int w_out = w_out_s / ${stride_w}; 112 | if ((h_out >= 0) && (h_out < ${top_height}) 113 | && (w_out >= 0) && (w_out < ${top_width})) { 114 | const int offset = ((n * ${channels} + c) * ${top_height} + h_out) 115 | * ${top_width} + w_out; 116 | const int offset_weight = ((((n * ${groups} + g) * ${kernel_h} + kh) * ${kernel_w} + kw) * ${top_height} + h_out) 117 | * ${top_width} + w_out; 118 | value += weight_data[offset_weight] * top_diff[offset]; 119 | } 120 | } 121 | } 122 | } 123 | bottom_diff[index] = value; 124 | } 125 | } 126 | ''' 127 | 128 | _idynamic_kernel_backward_grad_weight = kernel_loop + ''' 129 | extern "C" 130 | __global__ void idynamic_backward_grad_weight_kernel( 131 | const ${Dtype}* const top_diff, const ${Dtype}* const bottom_data, ${Dtype}* const buffer_data) { 132 | CUDA_KERNEL_LOOP(index, ${nthreads}) { 133 | const int h = (index / ${top_width}) % ${top_height}; 134 | const int w = index % ${top_width}; 135 | const int kh = (index / ${kernel_w} / ${top_height} / ${top_width}) 136 | % ${kernel_h}; 137 | const int kw = (index / ${top_height} / ${top_width}) % ${kernel_w}; 138 | const int h_in = -${pad_h} + h * ${stride_h} + kh * ${dilation_h}; 139 | const int w_in = -${pad_w} + w * ${stride_w} + kw * ${dilation_w}; 140 | if ((h_in >= 0) && (h_in < ${bottom_height}) 141 | && (w_in >= 0) && (w_in < ${bottom_width})) { 142 | const int g = (index / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${groups}; 143 | const int n = (index / ${groups} / ${kernel_h} / ${kernel_w} / ${top_height} / ${top_width}) % ${num}; 144 | ${Dtype} value = 0; 145 | #pragma unroll 146 | for (int c = g * (${channels} / ${groups}); c < (g + 1) * (${channels} / ${groups}); ++c) { 147 | const int top_offset = ((n * ${channels} + c) * ${top_height} + h) 148 | * ${top_width} + w; 149 | const int bottom_offset = ((n * ${channels} + c) * ${bottom_height} + h_in) 150 | * ${bottom_width} + w_in; 151 | value += top_diff[top_offset] * bottom_data[bottom_offset]; 152 | } 153 | buffer_data[index] = value; 154 | } else { 155 | buffer_data[index] = 0; 156 | } 157 | } 158 | } 159 | ''' 160 | 161 | 162 | class _idynamic(Function): 163 | @staticmethod 164 | def forward(ctx, input, weight, stride, padding, dilation): 165 | assert input.dim() == 4 and input.is_cuda 166 | assert weight.dim() == 6 and weight.is_cuda 167 | batch_size, channels, height, width = input.size() 168 | kernel_h, kernel_w = weight.size()[2:4] 169 | output_h = int((height + 2 * padding[0] - (dilation[0] * (kernel_h - 1) + 1)) / stride[0] + 1) 170 | output_w = int((width + 2 * padding[1] - (dilation[1] * (kernel_w - 1) + 1)) / stride[1] + 1) 171 | 172 | output = 
input.new(batch_size, channels, output_h, output_w) 173 | n = output.numel() 174 | 175 | with torch.cuda.device_of(input): 176 | f = load_kernel('idynamic_forward_kernel', _idynamic_kernel, Dtype=Dtype(input), nthreads=n, 177 | num=batch_size, channels=channels, groups=weight.size()[1], 178 | bottom_height=height, bottom_width=width, 179 | top_height=output_h, top_width=output_w, 180 | kernel_h=kernel_h, kernel_w=kernel_w, 181 | stride_h=stride[0], stride_w=stride[1], 182 | dilation_h=dilation[0], dilation_w=dilation[1], 183 | pad_h=padding[0], pad_w=padding[1]) 184 | f(block=(CUDA_NUM_THREADS, 1, 1), 185 | grid=(GET_BLOCKS(n), 1, 1), 186 | args=[input.data_ptr(), weight.data_ptr(), output.data_ptr()], 187 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 188 | 189 | ctx.save_for_backward(input, weight) 190 | ctx.stride, ctx.padding, ctx.dilation = stride, padding, dilation 191 | return output 192 | 193 | @staticmethod 194 | def backward(ctx, grad_output): 195 | assert grad_output.is_cuda 196 | if not grad_output.is_contiguous(): 197 | grad_output.contiguous() 198 | input, weight = ctx.saved_tensors 199 | stride, padding, dilation = ctx.stride, ctx.padding, ctx.dilation 200 | 201 | batch_size, channels, height, width = input.size() 202 | kernel_h, kernel_w = weight.size()[2:4] 203 | output_h, output_w = grad_output.size()[2:] 204 | 205 | grad_input, grad_weight = None, None 206 | 207 | opt = dict(Dtype=Dtype(grad_output), 208 | num=batch_size, channels=channels, groups=weight.size()[1], 209 | bottom_height=height, bottom_width=width, 210 | top_height=output_h, top_width=output_w, 211 | kernel_h=kernel_h, kernel_w=kernel_w, 212 | stride_h=stride[0], stride_w=stride[1], 213 | dilation_h=dilation[0], dilation_w=dilation[1], 214 | pad_h=padding[0], pad_w=padding[1]) 215 | 216 | with torch.cuda.device_of(input): 217 | if ctx.needs_input_grad[0]: 218 | grad_input = input.new(input.size()) 219 | 220 | n = grad_input.numel() 221 | opt['nthreads'] = n 222 | 223 | f = load_kernel('idynamic_backward_grad_input_kernel', 224 | _idynamic_kernel_backward_grad_input, **opt) 225 | f(block=(CUDA_NUM_THREADS, 1, 1), 226 | grid=(GET_BLOCKS(n), 1, 1), 227 | args=[grad_output.data_ptr(), weight.data_ptr(), grad_input.data_ptr()], 228 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 229 | 230 | if ctx.needs_input_grad[1]: 231 | grad_weight = weight.new(weight.size()) 232 | 233 | n = grad_weight.numel() 234 | opt['nthreads'] = n 235 | 236 | f = load_kernel('idynamic_backward_grad_weight_kernel', 237 | _idynamic_kernel_backward_grad_weight, **opt) 238 | f(block=(CUDA_NUM_THREADS, 1, 1), 239 | grid=(GET_BLOCKS(n), 1, 1), 240 | args=[grad_output.data_ptr(), input.data_ptr(), grad_weight.data_ptr()], 241 | stream=Stream(ptr=torch.cuda.current_stream().cuda_stream)) 242 | 243 | return grad_input, grad_weight, None, None, None 244 | 245 | 246 | def _idynamic_cuda(input, weight, bias=None, stride=1, padding=0, dilation=1): 247 | """ idynamic kernel 248 | """ 249 | assert input.size(0) == weight.size(0) 250 | assert input.size(-2) // stride == weight.size(-2) 251 | assert input.size(-1) // stride == weight.size(-1) 252 | if input.is_cuda: 253 | out = _idynamic.apply(input, weight, _pair(stride), _pair(padding), _pair(dilation)) 254 | if bias is not None: 255 | out += bias.view(1, -1, 1, 1) 256 | else: 257 | raise NotImplementedError 258 | return out 259 | 260 | 261 | class IDynamicConv(nn.Module): 262 | """ 263 | IDynamicDWConv: HyperNet for the weight of DynamicDWConv 264 | """ 265 | 266 | def 
__init__(self): 267 | """ 268 | code based on github: https://github.com/Atten4Vis/DemystifyLocalViT 269 | :param channels: the feature 270 | :param kernel_size: as window_size 271 | :param group_channels: as num_heads 272 | :param bias: bias for the conv in the HyperNet; Default: True 273 | """ 274 | super(IDynamicConv, self).__init__() 275 | self.kernel_size = 5 276 | # self.channels = channels 277 | 278 | def forward(self, x, weight): 279 | b, c, h, w = weight.shape 280 | 281 | weight = weight.view(b, c // (self.kernel_size * self.kernel_size), self.kernel_size, self.kernel_size, h, w) 282 | out = _idynamic_cuda(x, weight, stride=1, padding=(self.kernel_size - 1) // 2) 283 | return out 284 | 285 | 286 | class LKPN(nn.Module): 287 | 288 | def __init__(self) -> "LKPN": 289 | super(LKPN, self).__init__() 290 | 291 | 292 | self.unet = UNetModel(use_checkpoint=True, image_size=32, in_channels=8, out_channels=4*5*5, model_channels=128, 293 | attention_resolutions=[4, 2, 1], 294 | num_res_blocks=2, channel_mult=[1, 2, 4, 4], num_head_channels=64, 295 | use_spatial_transformer=True, use_linear_in_transformer=True, 296 | transformer_depth=1, context_dim=1024, legacy=False) 297 | 298 | self.idy_conv = IDynamicConv() 299 | 300 | 301 | def forward(self, x, hint,timesteps, context): 302 | 303 | merge = torch.cat([x, hint], dim=1) 304 | 305 | kernel = self.unet(x=merge,timesteps=timesteps,context=context) 306 | 307 | result = self.idy_conv(x, kernel)+x 308 | 309 | return result 310 | 311 | -------------------------------------------------------------------------------- /model/attention.py: -------------------------------------------------------------------------------- 1 | from packaging import version 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from einops import rearrange, repeat 6 | from typing import Optional, Any 7 | 8 | from model.util import ( 9 | checkpoint, zero_module, exists, default 10 | ) 11 | from model.config import Config, AttnMode 12 | 13 | # CrossAttn precision handling 14 | import os 15 | 16 | _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32") 17 | 18 | 19 | # feedforward 20 | class GEGLU(nn.Module): 21 | def __init__(self, dim_in, dim_out): 22 | super().__init__() 23 | self.proj = nn.Linear(dim_in, dim_out * 2) 24 | 25 | def forward(self, x): 26 | x, gate = self.proj(x).chunk(2, dim=-1) 27 | return x * F.gelu(gate) 28 | 29 | 30 | class FeedForward(nn.Module): 31 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 32 | super().__init__() 33 | inner_dim = int(dim * mult) 34 | dim_out = default(dim_out, dim) 35 | project_in = nn.Sequential( 36 | nn.Linear(dim, inner_dim), 37 | nn.GELU() 38 | ) if not glu else GEGLU(dim, inner_dim) 39 | 40 | self.net = nn.Sequential( 41 | project_in, 42 | nn.Dropout(dropout), 43 | nn.Linear(inner_dim, dim_out) 44 | ) 45 | 46 | def forward(self, x): 47 | return self.net(x) 48 | 49 | 50 | def Normalize(in_channels): 51 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 52 | 53 | 54 | class CrossAttention(nn.Module): 55 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 56 | super().__init__() 57 | # print( 58 | # f"Setting up {self.__class__.__name__} (vanilla). 
Query dim is {query_dim}, context_dim is {context_dim} and using " 59 | # f"{heads} heads.") 60 | inner_dim = dim_head * heads 61 | context_dim = default(context_dim, query_dim) 62 | 63 | self.scale = dim_head ** -0.5 64 | self.heads = heads 65 | 66 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 67 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 68 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 69 | 70 | self.to_out = nn.Sequential( 71 | nn.Linear(inner_dim, query_dim), 72 | nn.Dropout(dropout) 73 | ) 74 | 75 | def forward(self, x, context=None, mask=None): 76 | h = self.heads 77 | 78 | q = self.to_q(x) 79 | context = default(context, x) 80 | k = self.to_k(context) 81 | v = self.to_v(context) 82 | 83 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 84 | 85 | # force cast to fp32 to avoid overflowing 86 | if _ATTN_PRECISION == "fp32": 87 | # with torch.autocast(enabled=False, device_type = 'cuda'): 88 | with torch.autocast(enabled=False, device_type="cuda" if str(x.device).startswith("cuda") else "cpu"): 89 | q, k = q.float(), k.float() 90 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 91 | else: 92 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 93 | 94 | del q, k 95 | 96 | if exists(mask): 97 | mask = rearrange(mask, 'b ... -> b (...)') 98 | max_neg_value = -torch.finfo(sim.dtype).max 99 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 100 | sim.masked_fill_(~mask, max_neg_value) 101 | 102 | # attention, what we cannot get enough of 103 | sim = sim.softmax(dim=-1) 104 | 105 | out = einsum('b i j, b j d -> b i d', sim, v) 106 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 107 | return self.to_out(out) 108 | 109 | 110 | class MemoryEfficientCrossAttention(nn.Module): 111 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 112 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 113 | super().__init__() 114 | # print( 115 | # f"Setting up {self.__class__.__name__} (xformers). 
Query dim is {query_dim}, context_dim is {context_dim} and using " 116 | # f"{heads} heads.") 117 | inner_dim = dim_head * heads 118 | context_dim = default(context_dim, query_dim) 119 | 120 | self.heads = heads 121 | self.dim_head = dim_head 122 | 123 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 124 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 125 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 126 | 127 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 128 | self.attention_op: Optional[Any] = None 129 | 130 | def forward(self, x, context=None, mask=None): 131 | q = self.to_q(x) 132 | context = default(context, x) 133 | k = self.to_k(context) 134 | v = self.to_v(context) 135 | 136 | b, _, _ = q.shape 137 | q, k, v = map( 138 | lambda t: t.unsqueeze(3) 139 | .reshape(b, t.shape[1], self.heads, self.dim_head) 140 | .permute(0, 2, 1, 3) 141 | .reshape(b * self.heads, t.shape[1], self.dim_head) 142 | .contiguous(), 143 | (q, k, v), 144 | ) 145 | 146 | # actually compute the attention, what we cannot get enough of 147 | out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 148 | 149 | if exists(mask): 150 | raise NotImplementedError 151 | out = ( 152 | out.unsqueeze(0) 153 | .reshape(b, self.heads, out.shape[1], self.dim_head) 154 | .permute(0, 2, 1, 3) 155 | .reshape(b, out.shape[1], self.heads * self.dim_head) 156 | ) 157 | return self.to_out(out) 158 | 159 | 160 | class SDPCrossAttention(nn.Module): 161 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 162 | super().__init__() 163 | # print( 164 | # f"Setting up {self.__class__.__name__} (sdp). Query dim is {query_dim}, context_dim is {context_dim} and using " 165 | # f"{heads} heads.") 166 | inner_dim = dim_head * heads 167 | context_dim = default(context_dim, query_dim) 168 | 169 | self.heads = heads 170 | self.dim_head = dim_head 171 | 172 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 173 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 174 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 175 | 176 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 177 | 178 | def forward(self, x, context=None, mask=None): 179 | q = self.to_q(x) 180 | b,_,_ = q.shape 181 | context = default(context, x) 182 | k = self.to_k(context) 183 | v = self.to_v(context) 184 | bk,_,_ = k.shape 185 | bv,_,_ = v.shape 186 | 187 | k = k.repeat(b//bk, 1, 1) 188 | v = v.repeat(b//bv, 1, 1) 189 | 190 | 191 | q, k, v = map( 192 | lambda t: t.unsqueeze(3) 193 | .reshape(b, t.shape[1], self.heads, self.dim_head) 194 | .permute(0, 2, 1, 3) 195 | .reshape(b * self.heads, t.shape[1], self.dim_head) 196 | .contiguous(), 197 | (q, k, v), 198 | ) 199 | 200 | # actually compute the attention, what we cannot get enough of 201 | out = F.scaled_dot_product_attention(q, k, v) 202 | 203 | if exists(mask): 204 | raise NotImplementedError 205 | out = ( 206 | out.unsqueeze(0) 207 | .reshape(b, self.heads, out.shape[1], self.dim_head) 208 | .permute(0, 2, 1, 3) 209 | .reshape(b, out.shape[1], self.heads * self.dim_head) 210 | ) 211 | return self.to_out(out) 212 | 213 | 214 | class BasicTransformerBlock(nn.Module): 215 | ATTENTION_MODES = { 216 | AttnMode.VANILLA: CrossAttention, # vanilla attention 217 | AttnMode.XFORMERS: MemoryEfficientCrossAttention, 218 | AttnMode.SDP: SDPCrossAttention 219 | } 220 | 221 | def __init__(self, dim, n_heads, d_head, dropout=0., 
context_dim=None, gated_ff=True, checkpoint=True, 222 | disable_self_attn=False): 223 | super().__init__() 224 | attn_cls = self.ATTENTION_MODES[Config.attn_mode] 225 | self.disable_self_attn = disable_self_attn 226 | self.attn1 = attn_cls(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 227 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 228 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 229 | self.attn2 = attn_cls(query_dim=dim, context_dim=context_dim, 230 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 231 | self.norm1 = nn.LayerNorm(dim) 232 | self.norm2 = nn.LayerNorm(dim) 233 | self.norm3 = nn.LayerNorm(dim) 234 | self.checkpoint = checkpoint 235 | 236 | def forward(self, x, context=None): 237 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 238 | 239 | def _forward(self, x, context=None): 240 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 241 | x = self.attn2(self.norm2(x), context=context) + x 242 | x = self.ff(self.norm3(x)) + x 243 | return x 244 | 245 | 246 | class SpatialTransformer(nn.Module): 247 | """ 248 | Transformer block for image-like data. 249 | First, project the input (aka embedding) 250 | and reshape to b, t, d. 251 | Then apply standard transformer action. 252 | Finally, reshape to image 253 | NEW: use_linear for more efficiency instead of the 1x1 convs 254 | """ 255 | 256 | def __init__(self, in_channels, n_heads, d_head, 257 | depth=1, dropout=0., context_dim=None, 258 | disable_self_attn=False, use_linear=False, 259 | use_checkpoint=True, kernel_size=(64, 64)): 260 | super().__init__() 261 | if exists(context_dim) and not isinstance(context_dim, list): 262 | context_dim = [context_dim] 263 | self.in_channels = in_channels 264 | inner_dim = n_heads * d_head 265 | self.norm = Normalize(in_channels) 266 | if not use_linear: 267 | self.proj_in = nn.Conv2d(in_channels, 268 | inner_dim, 269 | kernel_size=1, 270 | stride=1, 271 | padding=0) 272 | else: 273 | self.proj_in = nn.Linear(in_channels, inner_dim) 274 | 275 | self.transformer_blocks = nn.ModuleList( 276 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d], 277 | disable_self_attn=disable_self_attn, checkpoint=use_checkpoint) 278 | for d in range(depth)] 279 | ) 280 | if not use_linear: 281 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 282 | in_channels, 283 | kernel_size=1, 284 | stride=1, 285 | padding=0)) 286 | else: 287 | self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) 288 | self.use_linear = use_linear 289 | 290 | self.kernel_size = kernel_size 291 | 292 | def grids(self, x): 293 | b, c, h, w = x.shape 294 | self.original_size = (b, c, h, w) 295 | assert b == 1 296 | k1, k2 = self.kernel_size 297 | k1 = min(h, k1) 298 | k2 = min(w, k2) 299 | num_row = (h - 1) // k1 + 1 300 | num_col = (w - 1) // k2 + 1 301 | self.nr = num_row 302 | self.nc = num_col 303 | 304 | import math 305 | step_j = k2 if num_col == 1 else math.ceil((w - k2) / (num_col - 1) - 1e-8) 306 | step_i = k1 if num_row == 1 else math.ceil((h - k1) / (num_row - 1) - 1e-8) 307 | 308 | parts = [] 309 | idxes = [] 310 | i = 0 # 0~h-1 311 | last_i = False 312 | while i < h and not last_i: 313 | j = 0 314 | if i + k1 >= h: 315 | i = h - k1 316 | last_i = True 317 | last_j = False 318 | while j < w and not last_j: 319 | if j + k2 >= w: 320 | j = w - k2 321 | last_j = True 322 | 
parts.append(x[:, :, i:i + k1, j:j + k2]) 323 | idxes.append({'i': i, 'j': j}) 324 | j = j + step_j 325 | i = i + step_i 326 | 327 | parts = torch.cat(parts, dim=0) 328 | self.idxes = idxes 329 | return parts 330 | 331 | def grids_inverse(self, outs): 332 | preds = torch.zeros(self.original_size).to(outs.device) 333 | b, c, h, w = self.original_size 334 | 335 | count_mt = torch.zeros((b, 1, h, w)).to(outs.device) 336 | k1, k2 = self.kernel_size 337 | k1 = min(h, k1) 338 | k2 = min(w, k2) 339 | 340 | for cnt, each_idx in enumerate(self.idxes): 341 | i = each_idx['i'] 342 | j = each_idx['j'] 343 | preds[0, :, i:i + k1, j:j + k2] += outs[cnt, :, :, :] 344 | count_mt[0, 0, i:i + k1, j:j + k2] += 1. 345 | 346 | del outs 347 | torch.cuda.empty_cache() 348 | return preds / count_mt 349 | 350 | def forward(self, x, context=None): 351 | # note: if no context is given, cross-attention defaults to self-attention 352 | if not isinstance(context, list): 353 | context = [context] 354 | # b, c, h, w = x.shape 355 | x_in = x 356 | 357 | 358 | x = self.grids(x) 359 | 360 | 361 | x = self.norm(x) 362 | b, c, h, w = x.shape 363 | 364 | if not self.use_linear: 365 | x = self.proj_in(x) 366 | 367 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 368 | if self.use_linear: 369 | x = self.proj_in(x) 370 | for i, block in enumerate(self.transformer_blocks): 371 | x = block(x, context=context[i]) 372 | if self.use_linear: 373 | x = self.proj_out(x) 374 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 375 | if not self.use_linear: 376 | x = self.proj_out(x) 377 | 378 | x = self.grids_inverse(x) 379 | 380 | return x + x_in 381 | -------------------------------------------------------------------------------- /utils/sampler.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple, Dict 2 | 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from model.gaussian_diffusion import extract_into_tensor 9 | from model.cldm import ControlLDM 10 | from utils.cond_fn import Guidance 11 | from utils.common import sliding_windows, gaussian_weights 12 | 13 | 14 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py 15 | def space_timesteps(num_timesteps, section_counts): 16 | """ 17 | Create a list of timesteps to use from an original diffusion process, 18 | given the number of timesteps we want to take from equally-sized portions 19 | of the original process. 20 | For example, if there's 300 timesteps and the section counts are [10,15,20] 21 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 22 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 23 | If the stride is a string starting with "ddim", then the fixed striding 24 | from the DDIM paper is used, and only one section is allowed. 25 | :param num_timesteps: the number of diffusion steps in the original 26 | process to divide up. 27 | :param section_counts: either a list of numbers, or a string containing 28 | comma-separated numbers, indicating the step count 29 | per section. As a special case, use "ddimN" where N 30 | is a number of steps to use the striding from the 31 | DDIM paper. 32 | :return: a set of diffusion steps from the original process to use. 
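    A concrete example of the striding implemented below: space_timesteps(1000, "ddim50")
    uses the fixed DDIM stride of 20 and returns the 50 steps {0, 20, 40, ..., 980}, while
    space_timesteps(1000, "50") spreads 50 steps roughly evenly across the original
    1000-step schedule.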
33 | """ 34 | if isinstance(section_counts, str): 35 | if section_counts.startswith("ddim"): 36 | desired_count = int(section_counts[len("ddim"):]) 37 | for i in range(1, num_timesteps): 38 | if len(range(0, num_timesteps, i)) == desired_count: 39 | return set(range(0, num_timesteps, i)) 40 | raise ValueError( 41 | f"cannot create exactly {num_timesteps} steps with an integer stride" 42 | ) 43 | section_counts = [int(x) for x in section_counts.split(",")] 44 | size_per = num_timesteps // len(section_counts) 45 | extra = num_timesteps % len(section_counts) 46 | start_idx = 0 47 | all_steps = [] 48 | for i, section_count in enumerate(section_counts): 49 | size = size_per + (1 if i < extra else 0) 50 | if size < section_count: 51 | raise ValueError( 52 | f"cannot divide section of {size} steps into {section_count}" 53 | ) 54 | if section_count <= 1: 55 | frac_stride = 1 56 | else: 57 | frac_stride = (size - 1) / (section_count - 1) 58 | cur_idx = 0.0 59 | taken_steps = [] 60 | for _ in range(section_count): 61 | taken_steps.append(start_idx + round(cur_idx)) 62 | cur_idx += frac_stride 63 | all_steps += taken_steps 64 | start_idx += size 65 | return set(all_steps) 66 | 67 | 68 | class SpacedSampler(nn.Module): 69 | """ 70 | Implementation for spaced sampling schedule proposed in IDDPM. This class is designed 71 | for sampling ControlLDM. 72 | 73 | https://arxiv.org/pdf/2102.09672.pdf 74 | """ 75 | 76 | def __init__(self, betas: np.ndarray) -> "SpacedSampler": 77 | super().__init__() 78 | self.num_timesteps = len(betas) 79 | self.original_betas = betas 80 | self.original_alphas_cumprod = np.cumprod(1.0 - betas, axis=0) 81 | self.context = {} 82 | 83 | def register(self, name: str, value: np.ndarray) -> None: 84 | self.register_buffer(name, torch.tensor(value, dtype=torch.float32)) 85 | 86 | def make_schedule(self, num_steps: int) -> None: 87 | # calcualte betas for spaced sampling 88 | # https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/respace.py 89 | used_timesteps = space_timesteps(self.num_timesteps, str(num_steps)) 90 | betas = [] 91 | last_alpha_cumprod = 1.0 92 | for i, alpha_cumprod in enumerate(self.original_alphas_cumprod): 93 | if i in used_timesteps: 94 | # marginal distribution is the same as q(x_{S_t}|x_0) 95 | betas.append(1 - alpha_cumprod / last_alpha_cumprod) 96 | last_alpha_cumprod = alpha_cumprod 97 | assert len(betas) == num_steps 98 | self.timesteps = np.array(sorted(list(used_timesteps)), dtype=np.int32) # e.g. [0, 10, 20, ...] 99 | 100 | betas = np.array(betas, dtype=np.float64) 101 | alphas = 1.0 - betas 102 | alphas_cumprod = np.cumprod(alphas, axis=0) 103 | # print(f"sampler sqrt_alphas_cumprod: {np.sqrt(alphas_cumprod)[-1]}") 104 | alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1]) 105 | sqrt_recip_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod) 106 | sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / alphas_cumprod - 1) 107 | # calculations for posterior q(x_{t-1} | x_t, x_0) 108 | posterior_variance = ( 109 | betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) 110 | ) 111 | # log calculation clipped because the posterior variance is 0 at the 112 | # beginning of the diffusion chain. 
113 | posterior_log_variance_clipped = np.log( 114 | np.append(posterior_variance[1], posterior_variance[1:]) 115 | ) 116 | posterior_mean_coef1 = ( 117 | betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod) 118 | ) 119 | posterior_mean_coef2 = ( 120 | (1.0 - alphas_cumprod_prev) 121 | * np.sqrt(alphas) 122 | / (1.0 - alphas_cumprod) 123 | ) 124 | 125 | self.register("sqrt_recip_alphas_cumprod", sqrt_recip_alphas_cumprod) 126 | self.register("sqrt_recipm1_alphas_cumprod", sqrt_recipm1_alphas_cumprod) 127 | self.register("posterior_variance", posterior_variance) 128 | self.register("posterior_log_variance_clipped", posterior_log_variance_clipped) 129 | self.register("posterior_mean_coef1", posterior_mean_coef1) 130 | self.register("posterior_mean_coef2", posterior_mean_coef2) 131 | 132 | def q_posterior_mean_variance(self, x_start: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> Tuple[ 133 | torch.Tensor]: 134 | """ 135 | Implement the posterior distribution q(x_{t-1}|x_t, x_0). 136 | 137 | Args: 138 | x_start (torch.Tensor): The predicted images (NCHW) in timestep `t`. 139 | x_t (torch.Tensor): The sampled intermediate variables (NCHW) of timestep `t`. 140 | t (torch.Tensor): Timestep (N) of `x_t`. `t` serves as an index to get 141 | parameters for each timestep. 142 | 143 | Returns: 144 | posterior_mean (torch.Tensor): Mean of the posterior distribution. 145 | posterior_variance (torch.Tensor): Variance of the posterior distribution. 146 | posterior_log_variance_clipped (torch.Tensor): Log variance of the posterior distribution. 147 | """ 148 | posterior_mean = ( 149 | extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 150 | + extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 151 | ) 152 | posterior_variance = extract_into_tensor(self.posterior_variance, t, x_t.shape) 153 | posterior_log_variance_clipped = extract_into_tensor( 154 | self.posterior_log_variance_clipped, t, x_t.shape 155 | ) 156 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 157 | 158 | def _predict_xstart_from_eps(self, x_t: torch.Tensor, t: torch.Tensor, eps: torch.Tensor) -> torch.Tensor: 159 | return ( 160 | extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 161 | - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 162 | ) 163 | 164 | def apply_cond_fn( 165 | self, 166 | model: ControlLDM, 167 | pred_x0: torch.Tensor, 168 | t: torch.Tensor, 169 | index: torch.Tensor, 170 | cond_fn: Guidance 171 | ) -> torch.Tensor: 172 | t_now = int(t[0].item()) + 1 173 | if not (cond_fn.t_stop < t_now and t_now < cond_fn.t_start): 174 | # stop guidance 175 | self.context["g_apply"] = False 176 | return pred_x0 177 | grad_rescale = 1 / extract_into_tensor(self.posterior_mean_coef1, index, pred_x0.shape) 178 | # apply guidance for multiple times 179 | loss_vals = [] 180 | for _ in range(cond_fn.repeat): 181 | # set target and pred for gradient computation 182 | target, pred = None, None 183 | if cond_fn.space == "latent": 184 | target = model.vae_encode(cond_fn.target) 185 | pred = pred_x0 186 | elif cond_fn.space == "rgb": 187 | # We need to backward gradient to x0 in latent space, so it's required 188 | # to trace the computation graph while decoding the latent. 
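                # Sketch of the flow below: pred_x0 is cloned with requires_grad enabled,
                # decoded to RGB under enable_grad() (sampling itself runs under no_grad()),
                # cond_fn produces a pixel-space correction, and that correction is
                # backpropagated through the VAE decoder via pred.backward(delta_pred),
                # so pred_x0_rg.grad holds the equivalent latent-space update, which is
                # then rescaled by 1 / posterior_mean_coef1 further down.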
189 | with torch.enable_grad(): 190 | target = cond_fn.target 191 | pred_x0_rg = pred_x0.detach().clone().requires_grad_(True) 192 | pred = model.vae_decode(pred_x0_rg) 193 | assert pred.requires_grad 194 | else: 195 | raise NotImplementedError(cond_fn.space) 196 | # compute gradient 197 | delta_pred, loss_val = cond_fn(target, pred, t_now) 198 | loss_vals.append(loss_val) 199 | # update pred_x0 w.r.t gradient 200 | if cond_fn.space == "latent": 201 | delta_pred_x0 = delta_pred 202 | pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale 203 | elif cond_fn.space == "rgb": 204 | pred.backward(delta_pred) 205 | delta_pred_x0 = pred_x0_rg.grad 206 | pred_x0 = pred_x0 + delta_pred_x0 * grad_rescale 207 | else: 208 | raise NotImplementedError(cond_fn.space) 209 | self.context["g_apply"] = True 210 | self.context["g_loss"] = float(np.mean(loss_vals)) 211 | return pred_x0 212 | 213 | def predict_noise( 214 | self, 215 | model: ControlLDM, 216 | x: torch.Tensor, 217 | t: torch.Tensor, 218 | cond: Dict[str, torch.Tensor], 219 | uncond: Optional[Dict[str, torch.Tensor]], 220 | cfg_scale: float 221 | ) -> torch.Tensor: 222 | if uncond is None or cfg_scale == 1.: 223 | model_output, kpn = model(x, t, cond) 224 | else: 225 | # apply classifier-free guidance 226 | model_cond,kpn = model(x, t, cond) 227 | model_uncond,_ = model(x, t, uncond) 228 | model_output = model_uncond + cfg_scale * (model_cond - model_uncond) 229 | return model_output, kpn 230 | 231 | @torch.no_grad() 232 | def predict_noise_tiled( 233 | self, 234 | model: ControlLDM, 235 | x: torch.Tensor, 236 | t: torch.Tensor, 237 | cond: Dict[str, torch.Tensor], 238 | uncond: Optional[Dict[str, torch.Tensor]], 239 | cfg_scale: float, 240 | tile_size: int, 241 | tile_stride: int 242 | ): 243 | _, _, h, w = x.shape 244 | tiles = tqdm(sliding_windows(h, w, tile_size // 8, tile_stride // 8), unit="tile", leave=False) 245 | eps = torch.zeros_like(x) 246 | kpn = torch.zeros_like(x) 247 | count = torch.zeros_like(x, dtype=torch.float32) 248 | weights = gaussian_weights(tile_size // 8, tile_size // 8)[None, None] 249 | weights = torch.tensor(weights, dtype=torch.float32, device=x.device) 250 | for hi, hi_end, wi, wi_end in tiles: 251 | tiles.set_description(f"Process tile ({hi} {hi_end}), ({wi} {wi_end})") 252 | tile_x = x[:, :, hi:hi_end, wi:wi_end] 253 | tile_cond = { 254 | "c_img": cond["c_img"][:, :, hi:hi_end, wi:wi_end], 255 | "c_txt": cond["c_txt"] 256 | } 257 | if uncond: 258 | tile_uncond = { 259 | "c_img": uncond["c_img"][:, :, hi:hi_end, wi:wi_end], 260 | "c_txt": uncond["c_txt"] 261 | } 262 | tile_eps,tile_kpn = self.predict_noise(model, tile_x, t, tile_cond, tile_uncond, cfg_scale) 263 | kpn[:, :, hi:hi_end, wi:wi_end] = tile_kpn 264 | # accumulate noise 265 | eps[:, :, hi:hi_end, wi:wi_end] += tile_eps * weights 266 | count[:, :, hi:hi_end, wi:wi_end] += weights 267 | # average on noise (score) 268 | eps.div_(count) 269 | return eps,kpn 270 | 271 | @torch.no_grad() 272 | def p_sample( 273 | self, 274 | model: ControlLDM, 275 | x: torch.Tensor, 276 | t: torch.Tensor, 277 | index: torch.Tensor, 278 | cond: Dict[str, torch.Tensor], 279 | uncond: Optional[Dict[str, torch.Tensor]], 280 | cfg_scale: float, 281 | cond_fn: Optional[Guidance], 282 | tiled: bool, 283 | tile_size: int, 284 | tile_stride: int 285 | ) -> torch.Tensor: 286 | if tiled: 287 | eps, kpn = self.predict_noise_tiled(model, x, t, cond, uncond, cfg_scale, tile_size, tile_stride) 288 | else: 289 | eps, kpn = self.predict_noise(model, x, t, cond, uncond, cfg_scale) 290 | pred_x0 = 
self._predict_xstart_from_eps(x, index, eps) 291 | if cond_fn: 292 | assert not tiled, f"tiled sampling currently doesn't support guidance" 293 | pred_x0 = self.apply_cond_fn(model, pred_x0, t, index, cond_fn) 294 | model_mean, model_variance, _ = self.q_posterior_mean_variance(pred_x0, x, index) 295 | noise = torch.randn_like(x) 296 | nonzero_mask = ( 297 | (index != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 298 | ) 299 | x_prev = model_mean + nonzero_mask * torch.sqrt(model_variance) * noise 300 | return x_prev, kpn 301 | 302 | @torch.no_grad() 303 | def sample( 304 | self, 305 | model: ControlLDM, 306 | device: str, 307 | steps: int, 308 | batch_size: int, 309 | x_size: Tuple[int], 310 | cond: Dict[str, torch.Tensor], 311 | uncond: Dict[str, torch.Tensor], 312 | cfg_scale: float, 313 | cond_fn: Optional[Guidance] = None, 314 | tiled: bool = False, 315 | tile_size: int = -1, 316 | tile_stride: int = -1, 317 | x_T: Optional[torch.Tensor] = None, 318 | progress: bool = True, 319 | progress_leave: bool = True, 320 | ) -> torch.Tensor: 321 | self.make_schedule(steps) 322 | self.to(device) 323 | if x_T is None: 324 | # TODO: not convert to float32, may trigger an error 325 | img = torch.randn((batch_size, *x_size), device=device) 326 | else: 327 | img = x_T 328 | timesteps = np.flip(self.timesteps) # [1000, 950, 900, ...] 329 | total_steps = len(self.timesteps) 330 | iterator = tqdm(timesteps, total=total_steps, leave=progress_leave, disable=not progress) 331 | for i, step in enumerate(iterator): 332 | ts = torch.full((batch_size,), step, device=device, dtype=torch.long) 333 | index = torch.full_like(ts, fill_value=total_steps - i - 1) 334 | img, kpn = self.p_sample( 335 | model, img, ts, index, cond, uncond, cfg_scale, cond_fn, 336 | tiled, tile_size, tile_stride 337 | ) 338 | if cond_fn and self.context["g_apply"]: 339 | loss_val = self.context["g_loss"] 340 | desc = f"Spaced Sampler With Guidance, Loss: {loss_val:.6f}" 341 | else: 342 | desc = "Spaced Sampler" 343 | iterator.set_description(desc) 344 | return img -------------------------------------------------------------------------------- /model/vae.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from einops import rearrange 7 | from typing import Optional, Any 8 | 9 | from model.distributions import DiagonalGaussianDistribution 10 | from model.config import Config, AttnMode 11 | 12 | 13 | def nonlinearity(x): 14 | # swish 15 | return x*torch.sigmoid(x) 16 | 17 | 18 | def Normalize(in_channels, num_groups=32): 19 | return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) 20 | 21 | 22 | class Upsample(nn.Module): 23 | def __init__(self, in_channels, with_conv): 24 | super().__init__() 25 | self.with_conv = with_conv 26 | if self.with_conv: 27 | self.conv = torch.nn.Conv2d(in_channels, 28 | in_channels, 29 | kernel_size=3, 30 | stride=1, 31 | padding=1) 32 | 33 | def forward(self, x): 34 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 35 | if self.with_conv: 36 | x = self.conv(x) 37 | return x 38 | 39 | 40 | class Downsample(nn.Module): 41 | def __init__(self, in_channels, with_conv): 42 | super().__init__() 43 | self.with_conv = with_conv 44 | if self.with_conv: 45 | # no asymmetric padding in torch conv, must do it ourselves 46 | self.conv = torch.nn.Conv2d(in_channels, 47 | in_channels, 48 | 
kernel_size=3, 49 | stride=2, 50 | padding=0) 51 | 52 | def forward(self, x): 53 | if self.with_conv: 54 | pad = (0,1,0,1) 55 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 56 | x = self.conv(x) 57 | else: 58 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 59 | return x 60 | 61 | 62 | class ResnetBlock(nn.Module): 63 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 64 | dropout, temb_channels=512): 65 | super().__init__() 66 | self.in_channels = in_channels 67 | out_channels = in_channels if out_channels is None else out_channels 68 | self.out_channels = out_channels 69 | self.use_conv_shortcut = conv_shortcut 70 | 71 | self.norm1 = Normalize(in_channels) 72 | self.conv1 = torch.nn.Conv2d(in_channels, 73 | out_channels, 74 | kernel_size=3, 75 | stride=1, 76 | padding=1) 77 | if temb_channels > 0: 78 | self.temb_proj = torch.nn.Linear(temb_channels, 79 | out_channels) 80 | self.norm2 = Normalize(out_channels) 81 | self.dropout = torch.nn.Dropout(dropout) 82 | self.conv2 = torch.nn.Conv2d(out_channels, 83 | out_channels, 84 | kernel_size=3, 85 | stride=1, 86 | padding=1) 87 | if self.in_channels != self.out_channels: 88 | if self.use_conv_shortcut: 89 | self.conv_shortcut = torch.nn.Conv2d(in_channels, 90 | out_channels, 91 | kernel_size=3, 92 | stride=1, 93 | padding=1) 94 | else: 95 | self.nin_shortcut = torch.nn.Conv2d(in_channels, 96 | out_channels, 97 | kernel_size=1, 98 | stride=1, 99 | padding=0) 100 | 101 | def forward(self, x, temb): 102 | h = x 103 | h = self.norm1(h) 104 | h = nonlinearity(h) 105 | h = self.conv1(h) 106 | 107 | if temb is not None: 108 | h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] 109 | 110 | h = self.norm2(h) 111 | h = nonlinearity(h) 112 | h = self.dropout(h) 113 | h = self.conv2(h) 114 | 115 | if self.in_channels != self.out_channels: 116 | if self.use_conv_shortcut: 117 | x = self.conv_shortcut(x) 118 | else: 119 | x = self.nin_shortcut(x) 120 | 121 | return x+h 122 | 123 | 124 | class AttnBlock(nn.Module): 125 | def __init__(self, in_channels): 126 | super().__init__() 127 | print(f"building AttnBlock (vanilla) with {in_channels} in_channels") 128 | 129 | self.in_channels = in_channels 130 | 131 | self.norm = Normalize(in_channels) 132 | self.q = torch.nn.Conv2d(in_channels, 133 | in_channels, 134 | kernel_size=1, 135 | stride=1, 136 | padding=0) 137 | self.k = torch.nn.Conv2d(in_channels, 138 | in_channels, 139 | kernel_size=1, 140 | stride=1, 141 | padding=0) 142 | self.v = torch.nn.Conv2d(in_channels, 143 | in_channels, 144 | kernel_size=1, 145 | stride=1, 146 | padding=0) 147 | self.proj_out = torch.nn.Conv2d(in_channels, 148 | in_channels, 149 | kernel_size=1, 150 | stride=1, 151 | padding=0) 152 | 153 | def forward(self, x): 154 | h_ = x 155 | h_ = self.norm(h_) 156 | q = self.q(h_) 157 | k = self.k(h_) 158 | v = self.v(h_) 159 | 160 | # compute attention 161 | b,c,h,w = q.shape 162 | q = q.reshape(b,c,h*w) 163 | q = q.permute(0,2,1) # b,hw,c 164 | k = k.reshape(b,c,h*w) # b,c,hw 165 | w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] 166 | w_ = w_ * (int(c)**(-0.5)) 167 | w_ = torch.nn.functional.softmax(w_, dim=2) 168 | 169 | # attend to values 170 | v = v.reshape(b,c,h*w) 171 | w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) 172 | h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] 173 | h_ = h_.reshape(b,c,h,w) 174 | 175 | h_ = self.proj_out(h_) 176 | 177 | return x+h_ 178 | 179 | 180 | class 
MemoryEfficientAttnBlock(nn.Module): 181 | """ 182 | Uses xformers efficient implementation, 183 | see https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 184 | Note: this is a single-head self-attention operation 185 | """ 186 | # 187 | def __init__(self, in_channels): 188 | super().__init__() 189 | print(f"building MemoryEfficientAttnBlock (xformers) with {in_channels} in_channels") 190 | self.in_channels = in_channels 191 | 192 | self.norm = Normalize(in_channels) 193 | self.q = torch.nn.Conv2d(in_channels, 194 | in_channels, 195 | kernel_size=1, 196 | stride=1, 197 | padding=0) 198 | self.k = torch.nn.Conv2d(in_channels, 199 | in_channels, 200 | kernel_size=1, 201 | stride=1, 202 | padding=0) 203 | self.v = torch.nn.Conv2d(in_channels, 204 | in_channels, 205 | kernel_size=1, 206 | stride=1, 207 | padding=0) 208 | self.proj_out = torch.nn.Conv2d(in_channels, 209 | in_channels, 210 | kernel_size=1, 211 | stride=1, 212 | padding=0) 213 | self.attention_op: Optional[Any] = None 214 | 215 | def forward(self, x): 216 | h_ = x 217 | h_ = self.norm(h_) 218 | q = self.q(h_) 219 | k = self.k(h_) 220 | v = self.v(h_) 221 | 222 | # compute attention 223 | B, C, H, W = q.shape 224 | q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) 225 | 226 | q, k, v = map( 227 | lambda t: t.unsqueeze(3) 228 | .reshape(B, t.shape[1], 1, C) 229 | .permute(0, 2, 1, 3) 230 | .reshape(B * 1, t.shape[1], C) 231 | .contiguous(), 232 | (q, k, v), 233 | ) 234 | out = Config.xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 235 | 236 | out = ( 237 | out.unsqueeze(0) 238 | .reshape(B, 1, out.shape[1], C) 239 | .permute(0, 2, 1, 3) 240 | .reshape(B, out.shape[1], C) 241 | ) 242 | out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) 243 | out = self.proj_out(out) 244 | return x+out 245 | 246 | 247 | class SDPAttnBlock(nn.Module): 248 | 249 | def __init__(self, in_channels): 250 | super().__init__() 251 | print(f"building SDPAttnBlock (sdp) with {in_channels} in_channels") 252 | self.in_channels = in_channels 253 | 254 | self.norm = Normalize(in_channels) 255 | self.q = torch.nn.Conv2d(in_channels, 256 | in_channels, 257 | kernel_size=1, 258 | stride=1, 259 | padding=0) 260 | self.k = torch.nn.Conv2d(in_channels, 261 | in_channels, 262 | kernel_size=1, 263 | stride=1, 264 | padding=0) 265 | self.v = torch.nn.Conv2d(in_channels, 266 | in_channels, 267 | kernel_size=1, 268 | stride=1, 269 | padding=0) 270 | self.proj_out = torch.nn.Conv2d(in_channels, 271 | in_channels, 272 | kernel_size=1, 273 | stride=1, 274 | padding=0) 275 | 276 | def forward(self, x): 277 | h_ = x 278 | h_ = self.norm(h_) 279 | q = self.q(h_) 280 | k = self.k(h_) 281 | v = self.v(h_) 282 | 283 | # compute attention 284 | B, C, H, W = q.shape 285 | q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'), (q, k, v)) 286 | 287 | q, k, v = map( 288 | lambda t: t.unsqueeze(3) 289 | .reshape(B, t.shape[1], 1, C) 290 | .permute(0, 2, 1, 3) 291 | .reshape(B * 1, t.shape[1], C) 292 | .contiguous(), 293 | (q, k, v), 294 | ) 295 | out = F.scaled_dot_product_attention(q, k, v) 296 | 297 | out = ( 298 | out.unsqueeze(0) 299 | .reshape(B, 1, out.shape[1], C) 300 | .permute(0, 2, 1, 3) 301 | .reshape(B, out.shape[1], C) 302 | ) 303 | out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C) 304 | out = self.proj_out(out) 305 | return x+out 306 | 307 | 308 | def make_attn(in_channels, attn_type="vanilla", 
attn_kwargs=None): 309 | assert attn_type in ["vanilla", "sdp", "xformers", "linear", "none"], f'attn_type {attn_type} unknown' 310 | if attn_type == "vanilla": 311 | assert attn_kwargs is None 312 | return AttnBlock(in_channels) 313 | elif attn_type == "sdp": 314 | return SDPAttnBlock(in_channels) 315 | elif attn_type == "xformers": 316 | return MemoryEfficientAttnBlock(in_channels) 317 | elif attn_type == "none": 318 | return nn.Identity(in_channels) 319 | else: 320 | raise NotImplementedError() 321 | 322 | 323 | class Encoder(nn.Module): 324 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 325 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 326 | resolution, z_channels, double_z=True, use_linear_attn=False, 327 | **ignore_kwargs): 328 | super().__init__() 329 | ### setup attention type 330 | if Config.attn_mode == AttnMode.SDP: 331 | attn_type = "sdp" 332 | elif Config.attn_mode == AttnMode.XFORMERS: 333 | attn_type = "xformers" 334 | else: 335 | attn_type = "vanilla" 336 | if use_linear_attn: attn_type = "linear" 337 | self.ch = ch 338 | self.temb_ch = 0 339 | self.num_resolutions = len(ch_mult) 340 | self.num_res_blocks = num_res_blocks 341 | self.resolution = resolution 342 | self.in_channels = in_channels 343 | 344 | # downsampling 345 | self.conv_in = torch.nn.Conv2d(in_channels, 346 | self.ch, 347 | kernel_size=3, 348 | stride=1, 349 | padding=1) 350 | 351 | curr_res = resolution 352 | in_ch_mult = (1,)+tuple(ch_mult) 353 | self.in_ch_mult = in_ch_mult 354 | self.down = nn.ModuleList() 355 | for i_level in range(self.num_resolutions): 356 | block = nn.ModuleList() 357 | attn = nn.ModuleList() 358 | block_in = ch*in_ch_mult[i_level] 359 | block_out = ch*ch_mult[i_level] 360 | for i_block in range(self.num_res_blocks): 361 | block.append(ResnetBlock(in_channels=block_in, 362 | out_channels=block_out, 363 | temb_channels=self.temb_ch, 364 | dropout=dropout)) 365 | block_in = block_out 366 | if curr_res in attn_resolutions: 367 | attn.append(make_attn(block_in, attn_type=attn_type)) 368 | down = nn.Module() 369 | down.block = block 370 | down.attn = attn 371 | if i_level != self.num_resolutions-1: 372 | down.downsample = Downsample(block_in, resamp_with_conv) 373 | curr_res = curr_res // 2 374 | self.down.append(down) 375 | 376 | # middle 377 | self.mid = nn.Module() 378 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 379 | out_channels=block_in, 380 | temb_channels=self.temb_ch, 381 | dropout=dropout) 382 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 383 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 384 | out_channels=block_in, 385 | temb_channels=self.temb_ch, 386 | dropout=dropout) 387 | 388 | # end 389 | self.norm_out = Normalize(block_in) 390 | self.conv_out = torch.nn.Conv2d(block_in, 391 | 2*z_channels if double_z else z_channels, 392 | kernel_size=3, 393 | stride=1, 394 | padding=1) 395 | 396 | def forward(self, x): 397 | # timestep embedding 398 | temb = None 399 | 400 | # downsampling 401 | hs = [self.conv_in(x)] 402 | for i_level in range(self.num_resolutions): 403 | for i_block in range(self.num_res_blocks): 404 | h = self.down[i_level].block[i_block](hs[-1], temb) 405 | if len(self.down[i_level].attn) > 0: 406 | h = self.down[i_level].attn[i_block](h) 407 | hs.append(h) 408 | if i_level != self.num_resolutions-1: 409 | hs.append(self.down[i_level].downsample(hs[-1])) 410 | 411 | # middle 412 | h = hs[-1] 413 | h = self.mid.block_1(h, temb) 414 | h = self.mid.attn_1(h) 415 | h = self.mid.block_2(h, 
temb) 416 | 417 | # end 418 | h = self.norm_out(h) 419 | h = nonlinearity(h) 420 | h = self.conv_out(h) 421 | return h 422 | 423 | 424 | class Decoder(nn.Module): 425 | def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, 426 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 427 | resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, 428 | **ignorekwargs): 429 | super().__init__() 430 | ### setup attention type 431 | if Config.attn_mode == AttnMode.SDP: 432 | attn_type = "sdp" 433 | elif Config.attn_mode == AttnMode.XFORMERS: 434 | attn_type = "xformers" 435 | else: 436 | attn_type = "vanilla" 437 | if use_linear_attn: attn_type = "linear" 438 | self.ch = ch 439 | self.temb_ch = 0 440 | self.num_resolutions = len(ch_mult) 441 | self.num_res_blocks = num_res_blocks 442 | self.resolution = resolution 443 | self.in_channels = in_channels 444 | self.give_pre_end = give_pre_end 445 | self.tanh_out = tanh_out 446 | 447 | # compute in_ch_mult, block_in and curr_res at lowest res 448 | in_ch_mult = (1,)+tuple(ch_mult) 449 | block_in = ch*ch_mult[self.num_resolutions-1] 450 | curr_res = resolution // 2**(self.num_resolutions-1) 451 | self.z_shape = (1,z_channels,curr_res,curr_res) 452 | 453 | # z to block_in 454 | self.conv_in = torch.nn.Conv2d(z_channels, 455 | block_in, 456 | kernel_size=3, 457 | stride=1, 458 | padding=1) 459 | 460 | # middle 461 | self.mid = nn.Module() 462 | self.mid.block_1 = ResnetBlock(in_channels=block_in, 463 | out_channels=block_in, 464 | temb_channels=self.temb_ch, 465 | dropout=dropout) 466 | self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) 467 | self.mid.block_2 = ResnetBlock(in_channels=block_in, 468 | out_channels=block_in, 469 | temb_channels=self.temb_ch, 470 | dropout=dropout) 471 | 472 | # upsampling 473 | self.up = nn.ModuleList() 474 | for i_level in reversed(range(self.num_resolutions)): 475 | block = nn.ModuleList() 476 | attn = nn.ModuleList() 477 | block_out = ch*ch_mult[i_level] 478 | for i_block in range(self.num_res_blocks+1): 479 | block.append(ResnetBlock(in_channels=block_in, 480 | out_channels=block_out, 481 | temb_channels=self.temb_ch, 482 | dropout=dropout)) 483 | block_in = block_out 484 | if curr_res in attn_resolutions: 485 | attn.append(make_attn(block_in, attn_type=attn_type)) 486 | up = nn.Module() 487 | up.block = block 488 | up.attn = attn 489 | if i_level != 0: 490 | up.upsample = Upsample(block_in, resamp_with_conv) 491 | curr_res = curr_res * 2 492 | self.up.insert(0, up) # prepend to get consistent order 493 | 494 | # end 495 | self.norm_out = Normalize(block_in) 496 | self.conv_out = torch.nn.Conv2d(block_in, 497 | out_ch, 498 | kernel_size=3, 499 | stride=1, 500 | padding=1) 501 | 502 | def forward(self, z): 503 | #assert z.shape[1:] == self.z_shape[1:] 504 | self.last_z_shape = z.shape 505 | 506 | # timestep embedding 507 | temb = None 508 | 509 | # z to block_in 510 | h = self.conv_in(z) 511 | 512 | # middle 513 | h = self.mid.block_1(h, temb) 514 | h = self.mid.attn_1(h) 515 | h = self.mid.block_2(h, temb) 516 | 517 | # upsampling 518 | for i_level in reversed(range(self.num_resolutions)): 519 | for i_block in range(self.num_res_blocks+1): 520 | h = self.up[i_level].block[i_block](h, temb) 521 | if len(self.up[i_level].attn) > 0: 522 | h = self.up[i_level].attn[i_block](h) 523 | if i_level != 0: 524 | h = self.up[i_level].upsample(h) 525 | 526 | # end 527 | if self.give_pre_end: 528 | return h 529 | 530 | h = self.norm_out(h) 531 | h = nonlinearity(h) 532 | 
h = self.conv_out(h) 533 | if self.tanh_out: 534 | h = torch.tanh(h) 535 | return h 536 | 537 | 538 | class AutoencoderKL(nn.Module): 539 | 540 | def __init__(self, ddconfig, embed_dim): 541 | super().__init__() 542 | self.encoder = Encoder(**ddconfig) 543 | self.decoder = Decoder(**ddconfig) 544 | assert ddconfig["double_z"] 545 | self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) 546 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 547 | self.embed_dim = embed_dim 548 | 549 | def encode(self, x): 550 | h = self.encoder(x) 551 | moments = self.quant_conv(h) 552 | posterior = DiagonalGaussianDistribution(moments) 553 | return posterior 554 | 555 | def decode(self, z): 556 | z = self.post_quant_conv(z) 557 | dec = self.decoder(z) 558 | return dec 559 | 560 | def forward(self, input, sample_posterior=True): 561 | posterior = self.encode(input) 562 | if sample_posterior: 563 | z = posterior.sample() 564 | else: 565 | z = posterior.mode() 566 | dec = self.decode(z) 567 | return dec, posterior 568 | -------------------------------------------------------------------------------- /model/open_clip/transformer.py: -------------------------------------------------------------------------------- 1 | import collections 2 | from collections import OrderedDict 3 | import math 4 | from typing import Callable, Optional, Sequence, Tuple 5 | from itertools import repeat 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | from torch.utils.checkpoint import checkpoint 11 | 12 | # From PyTorch internals 13 | def _ntuple(n): 14 | def parse(x): 15 | if isinstance(x, collections.abc.Iterable): 16 | return x 17 | return tuple(repeat(x, n)) 18 | return parse 19 | 20 | to_2tuple = _ntuple(2) 21 | 22 | 23 | class LayerNormFp32(nn.LayerNorm): 24 | """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" 25 | 26 | def forward(self, x: torch.Tensor): 27 | orig_type = x.dtype 28 | x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) 29 | return x.to(orig_type) 30 | 31 | 32 | class LayerNorm(nn.LayerNorm): 33 | """Subclass torch's LayerNorm (with cast back to input dtype).""" 34 | 35 | def forward(self, x: torch.Tensor): 36 | orig_type = x.dtype 37 | x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) 38 | return x.to(orig_type) 39 | 40 | 41 | class QuickGELU(nn.Module): 42 | # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory 43 | def forward(self, x: torch.Tensor): 44 | return x * torch.sigmoid(1.702 * x) 45 | 46 | 47 | class LayerScale(nn.Module): 48 | def __init__(self, dim, init_values=1e-5, inplace=False): 49 | super().__init__() 50 | self.inplace = inplace 51 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 52 | 53 | def forward(self, x): 54 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 55 | 56 | 57 | class PatchDropout(nn.Module): 58 | """ 59 | https://arxiv.org/abs/2212.00794 60 | """ 61 | 62 | def __init__(self, prob, exclude_first_token=True): 63 | super().__init__() 64 | assert 0 <= prob < 1. 
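        # Behaviour (see forward below): during training each sample keeps a random subset of
        # max(1, int(num_tokens * (1 - prob))) patch tokens, preserving the CLS token when
        # exclude_first_token=True; at eval time, or when prob == 0, the module is a no-op.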
65 | self.prob = prob 66 | self.exclude_first_token = exclude_first_token # exclude CLS token 67 | 68 | def forward(self, x): 69 | if not self.training or self.prob == 0.: 70 | return x 71 | 72 | if self.exclude_first_token: 73 | cls_tokens, x = x[:, :1], x[:, 1:] 74 | else: 75 | cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) 76 | 77 | batch = x.size()[0] 78 | num_tokens = x.size()[1] 79 | 80 | batch_indices = torch.arange(batch) 81 | batch_indices = batch_indices[..., None] 82 | 83 | keep_prob = 1 - self.prob 84 | num_patches_keep = max(1, int(num_tokens * keep_prob)) 85 | 86 | rand = torch.randn(batch, num_tokens) 87 | patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices 88 | 89 | x = x[batch_indices, patch_indices_keep] 90 | 91 | if self.exclude_first_token: 92 | x = torch.cat((cls_tokens, x), dim=1) 93 | 94 | return x 95 | 96 | 97 | class Attention(nn.Module): 98 | def __init__( 99 | self, 100 | dim, 101 | num_heads=8, 102 | qkv_bias=True, 103 | scaled_cosine=False, 104 | scale_heads=False, 105 | logit_scale_max=math.log(1. / 0.01), 106 | attn_drop=0., 107 | proj_drop=0. 108 | ): 109 | super().__init__() 110 | self.scaled_cosine = scaled_cosine 111 | self.scale_heads = scale_heads 112 | assert dim % num_heads == 0, 'dim should be divisible by num_heads' 113 | self.num_heads = num_heads 114 | self.head_dim = dim // num_heads 115 | self.scale = self.head_dim ** -0.5 116 | self.logit_scale_max = logit_scale_max 117 | 118 | # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original 119 | self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) 120 | if qkv_bias: 121 | self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) 122 | else: 123 | self.in_proj_bias = None 124 | 125 | if self.scaled_cosine: 126 | self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) 127 | else: 128 | self.logit_scale = None 129 | self.attn_drop = nn.Dropout(attn_drop) 130 | if self.scale_heads: 131 | self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) 132 | else: 133 | self.head_scale = None 134 | self.out_proj = nn.Linear(dim, dim) 135 | self.out_drop = nn.Dropout(proj_drop) 136 | 137 | def forward(self, x, attn_mask: Optional[torch.Tensor] = None): 138 | L, N, C = x.shape 139 | q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) 140 | q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 141 | k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 142 | v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) 143 | 144 | if self.logit_scale is not None: 145 | attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) 146 | logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() 147 | attn = attn.view(N, self.num_heads, L, L) * logit_scale 148 | attn = attn.view(-1, L, L) 149 | else: 150 | q = q * self.scale 151 | attn = torch.bmm(q, k.transpose(-1, -2)) 152 | 153 | if attn_mask is not None: 154 | if attn_mask.dtype == torch.bool: 155 | new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) 156 | new_attn_mask.masked_fill_(attn_mask, float("-inf")) 157 | attn_mask = new_attn_mask 158 | attn += attn_mask 159 | 160 | attn = attn.softmax(dim=-1) 161 | attn = self.attn_drop(attn) 162 | 163 | x = torch.bmm(attn, v) 164 | if self.head_scale is not None: 165 | x = x.view(N, self.num_heads, L, C) * self.head_scale 166 | x = x.view(-1, L, C) 167 | x = x.transpose(0, 1).reshape(L, N, C) 168 | x = 
self.out_proj(x) 169 | x = self.out_drop(x) 170 | return x 171 | 172 | 173 | class AttentionalPooler(nn.Module): 174 | def __init__( 175 | self, 176 | d_model: int, 177 | context_dim: int, 178 | n_head: int = 8, 179 | n_queries: int = 256, 180 | norm_layer: Callable = LayerNorm 181 | ): 182 | super().__init__() 183 | self.query = nn.Parameter(torch.randn(n_queries, d_model)) 184 | self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) 185 | self.ln_q = norm_layer(d_model) 186 | self.ln_k = norm_layer(context_dim) 187 | 188 | def forward(self, x: torch.Tensor): 189 | x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND 190 | N = x.shape[1] 191 | q = self.ln_q(self.query) 192 | out = self.attn(self._repeat(q, N), x, x, need_weights=False)[0] 193 | return out.permute(1, 0, 2) # LND -> NLD 194 | 195 | def _repeat(self, query, N: int): 196 | return query.unsqueeze(1).repeat(1, N, 1) 197 | 198 | 199 | class ResidualAttentionBlock(nn.Module): 200 | def __init__( 201 | self, 202 | d_model: int, 203 | n_head: int, 204 | mlp_ratio: float = 4.0, 205 | ls_init_value: float = None, 206 | act_layer: Callable = nn.GELU, 207 | norm_layer: Callable = LayerNorm, 208 | is_cross_attention: bool = False, 209 | ): 210 | super().__init__() 211 | 212 | self.ln_1 = norm_layer(d_model) 213 | self.attn = nn.MultiheadAttention(d_model, n_head) 214 | self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 215 | if is_cross_attention: 216 | self.ln_1_kv = norm_layer(d_model) 217 | 218 | self.ln_2 = norm_layer(d_model) 219 | mlp_width = int(d_model * mlp_ratio) 220 | self.mlp = nn.Sequential(OrderedDict([ 221 | ("c_fc", nn.Linear(d_model, mlp_width)), 222 | ("gelu", act_layer()), 223 | ("c_proj", nn.Linear(mlp_width, d_model)) 224 | ])) 225 | self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 226 | 227 | def attention( 228 | self, 229 | q_x: torch.Tensor, 230 | k_x: Optional[torch.Tensor] = None, 231 | v_x: Optional[torch.Tensor] = None, 232 | attn_mask: Optional[torch.Tensor] = None, 233 | ): 234 | k_x = k_x if k_x is not None else q_x 235 | v_x = v_x if v_x is not None else q_x 236 | 237 | attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None 238 | return self.attn( 239 | q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask 240 | )[0] 241 | 242 | def forward( 243 | self, 244 | q_x: torch.Tensor, 245 | k_x: Optional[torch.Tensor] = None, 246 | v_x: Optional[torch.Tensor] = None, 247 | attn_mask: Optional[torch.Tensor] = None, 248 | ): 249 | k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None 250 | v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None 251 | 252 | x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) 253 | x = x + self.ls_2(self.mlp(self.ln_2(x))) 254 | return x 255 | 256 | 257 | class CustomResidualAttentionBlock(nn.Module): 258 | def __init__( 259 | self, 260 | d_model: int, 261 | n_head: int, 262 | mlp_ratio: float = 4.0, 263 | ls_init_value: float = None, 264 | act_layer: Callable = nn.GELU, 265 | norm_layer: Callable = LayerNorm, 266 | scale_cosine_attn: bool = False, 267 | scale_heads: bool = False, 268 | scale_attn: bool = False, 269 | scale_fc: bool = False, 270 | ): 271 | super().__init__() 272 | 273 | self.ln_1 = norm_layer(d_model) 274 | self.attn = Attention( 275 | d_model, n_head, 276 | scaled_cosine=scale_cosine_attn, 277 | scale_heads=scale_heads, 278 | ) 279 | 
self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() 280 | self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 281 | 282 | self.ln_2 = norm_layer(d_model) 283 | mlp_width = int(d_model * mlp_ratio) 284 | self.mlp = nn.Sequential(OrderedDict([ 285 | ("c_fc", nn.Linear(d_model, mlp_width)), 286 | ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), 287 | ("gelu", act_layer()), 288 | ("c_proj", nn.Linear(mlp_width, d_model)) 289 | ])) 290 | self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() 291 | 292 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 293 | x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) 294 | x = x + self.ls_2(self.mlp(self.ln_2(x))) 295 | return x 296 | 297 | 298 | class Transformer(nn.Module): 299 | def __init__( 300 | self, 301 | width: int, 302 | layers: int, 303 | heads: int, 304 | mlp_ratio: float = 4.0, 305 | ls_init_value: float = None, 306 | act_layer: Callable = nn.GELU, 307 | norm_layer: Callable = LayerNorm, 308 | ): 309 | super().__init__() 310 | self.width = width 311 | self.layers = layers 312 | self.grad_checkpointing = False 313 | 314 | self.resblocks = nn.ModuleList([ 315 | ResidualAttentionBlock( 316 | width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) 317 | for _ in range(layers) 318 | ]) 319 | 320 | def get_cast_dtype(self) -> torch.dtype: 321 | if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): 322 | return self.resblocks[0].mlp.c_fc.int8_original_dtype 323 | return self.resblocks[0].mlp.c_fc.weight.dtype 324 | 325 | def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): 326 | for r in self.resblocks: 327 | if self.grad_checkpointing and not torch.jit.is_scripting(): 328 | # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 329 | x = checkpoint(r, x, None, None, attn_mask) 330 | else: 331 | x = r(x, attn_mask=attn_mask) 332 | return x 333 | 334 | 335 | class VisionTransformer(nn.Module): 336 | output_tokens: torch.jit.Final[bool] 337 | 338 | def __init__( 339 | self, 340 | image_size: int, 341 | patch_size: int, 342 | width: int, 343 | layers: int, 344 | heads: int, 345 | mlp_ratio: float, 346 | ls_init_value: float = None, 347 | global_average_pool: bool = False, 348 | attentional_pool: bool = False, 349 | n_queries: int = 256, 350 | attn_pooler_heads: int = 8, 351 | output_dim: int = 512, 352 | patch_dropout: float = 0., 353 | input_patchnorm: bool = False, 354 | act_layer: Callable = nn.GELU, 355 | norm_layer: Callable = LayerNorm, 356 | output_tokens: bool = False 357 | ): 358 | super().__init__() 359 | self.output_tokens = output_tokens 360 | image_height, image_width = self.image_size = to_2tuple(image_size) 361 | patch_height, patch_width = self.patch_size = to_2tuple(patch_size) 362 | self.grid_size = (image_height // patch_height, image_width // patch_width) 363 | self.output_dim = output_dim 364 | 365 | # whether to layernorm each patch, as done in dual patchnorm paper - https://arxiv.org/abs/2302.01327v1 366 | self.input_patchnorm = input_patchnorm 367 | 368 | if input_patchnorm: 369 | patch_input_dim = patch_height * patch_width * 3 370 | self.patchnorm_pre_ln = LayerNorm(patch_input_dim) 371 | self.conv1 = nn.Linear(patch_input_dim, width) 372 | else: 373 | self.patchnorm_pre_ln = nn.Identity() 374 | self.conv1 = nn.Conv2d(in_channels=3, 
out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 375 | 376 | # class embeddings and positional embeddings 377 | scale = width ** -0.5 378 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 379 | self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) 380 | 381 | # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn 382 | self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() 383 | 384 | self.ln_pre = norm_layer(width) 385 | self.transformer = Transformer( 386 | width, 387 | layers, 388 | heads, 389 | mlp_ratio, 390 | ls_init_value=ls_init_value, 391 | act_layer=act_layer, 392 | norm_layer=norm_layer, 393 | ) 394 | 395 | self.global_average_pool = global_average_pool 396 | if attentional_pool: 397 | self.attn_pool = AttentionalPooler(output_dim, width, n_head=attn_pooler_heads, n_queries=n_queries) 398 | self.ln_post = norm_layer(output_dim) 399 | self.proj = nn.Parameter(scale * torch.randn(output_dim, output_dim)) 400 | else: 401 | self.attn_pool = None 402 | self.ln_post = norm_layer(width) 403 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 404 | 405 | self.init_parameters() 406 | 407 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 408 | for param in self.parameters(): 409 | param.requires_grad = False 410 | 411 | if unlocked_groups != 0: 412 | groups = [ 413 | [ 414 | self.conv1, 415 | self.class_embedding, 416 | self.positional_embedding, 417 | self.ln_pre, 418 | ], 419 | *self.transformer.resblocks[:-1], 420 | [ 421 | self.transformer.resblocks[-1], 422 | self.ln_post, 423 | ], 424 | self.proj, 425 | ] 426 | 427 | def _unlock(x): 428 | if isinstance(x, Sequence): 429 | for g in x: 430 | _unlock(g) 431 | else: 432 | if isinstance(x, torch.nn.Parameter): 433 | x.requires_grad = True 434 | else: 435 | for p in x.parameters(): 436 | p.requires_grad = True 437 | 438 | _unlock(groups[-unlocked_groups:]) 439 | 440 | def init_parameters(self): 441 | # FIXME OpenAI CLIP did not define an init for the VisualTransformer 442 | # TODO experiment if default PyTorch init, below, or alternate init is best. 
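        # Note: as written this method is a no-op; the OpenAI-CLIP-style init below is
        # commented out and the function falls through to `pass`, so the vision tower keeps
        # PyTorch's default parameter initialization.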
443 | 444 | # nn.init.normal_(self.class_embedding, std=self.scale) 445 | # nn.init.normal_(self.positional_embedding, std=self.scale) 446 | # 447 | # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 448 | # attn_std = self.transformer.width ** -0.5 449 | # fc_std = (2 * self.transformer.width) ** -0.5 450 | # for block in self.transformer.resblocks: 451 | # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 452 | # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 453 | # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 454 | # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 455 | # 456 | # if self.text_projection is not None: 457 | # nn.init.normal_(self.text_projection, std=self.scale) 458 | pass 459 | 460 | @torch.jit.ignore 461 | def set_grad_checkpointing(self, enable=True): 462 | self.transformer.grad_checkpointing = enable 463 | 464 | def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 465 | if self.global_average_pool: 466 | return x.mean(dim=1), x 467 | else: 468 | return x[:, 0], x[:, 1:] 469 | 470 | def forward(self, x: torch.Tensor): 471 | 472 | # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1 473 | if self.input_patchnorm: 474 | # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)') 475 | x = x.reshape(x.shape[0], x.shape[1], self.grid_size[0], self.patch_size[0], self.grid_size[1], self.patch_size[1]) 476 | x = x.permute(0, 2, 4, 1, 3, 5) 477 | x = x.reshape(x.shape[0], self.grid_size[0] * self.grid_size[1], -1) 478 | x = self.patchnorm_pre_ln(x) 479 | x = self.conv1(x) 480 | else: 481 | x = self.conv1(x) # shape = [*, width, grid, grid] 482 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 483 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 484 | 485 | # class embeddings and positional embeddings 486 | x = torch.cat( 487 | [self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 488 | x], dim=1) # shape = [*, grid ** 2 + 1, width] 489 | x = x + self.positional_embedding.to(x.dtype) 490 | 491 | # a patch_dropout of 0. 
492 |         x = self.patch_dropout(x)
493 |         x = self.ln_pre(x)
494 | 
495 |         x = x.permute(1, 0, 2)  # NLD -> LND
496 |         x = self.transformer(x)
497 |         x = x.permute(1, 0, 2)  # LND -> NLD
498 | 
499 |         if self.attn_pool is not None:
500 |             x = self.attn_pool(x)
501 |             x = self.ln_post(x)
502 |             pooled, tokens = self._global_pool(x)
503 |         else:
504 |             pooled, tokens = self._global_pool(x)
505 |             pooled = self.ln_post(pooled)
506 | 
507 |         if self.proj is not None:
508 |             pooled = pooled @ self.proj
509 | 
510 |         if self.output_tokens:
511 |             return pooled, tokens
512 | 
513 |         return pooled
514 | 
515 | 
516 | class TextTransformer(nn.Module):
517 |     output_tokens: torch.jit.Final[bool]
518 | 
519 |     def __init__(
520 |             self,
521 |             context_length: int = 77,
522 |             vocab_size: int = 49408,
523 |             width: int = 512,
524 |             heads: int = 8,
525 |             layers: int = 12,
526 |             ls_init_value: float = None,
527 |             output_dim: int = 512,
528 |             act_layer: Callable = nn.GELU,
529 |             norm_layer: Callable = LayerNorm,
530 |             embed_cls: bool = False,
531 |             pad_id: int = 0,
532 |             output_tokens: bool = False,
533 |     ):
534 |         super().__init__()
535 |         self.output_tokens = output_tokens
536 |         self.num_pos = self.context_length = context_length
537 |         self.vocab_size = vocab_size
538 |         self.width = width
539 |         self.output_dim = output_dim
540 |         self.heads = heads
541 |         self.pad_id = pad_id
542 | 
543 |         self.text_projection = nn.Parameter(torch.empty(width, output_dim))
544 | 
545 |         if embed_cls:
546 |             self.cls_emb = nn.Parameter(torch.empty(width))
547 |             self.num_pos += 1
548 |         else:
549 |             self.cls_emb = None
550 | 
551 |         self.token_embedding = nn.Embedding(vocab_size, width)
552 |         self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width))
553 |         self.transformer = Transformer(
554 |             width=width,
555 |             layers=layers,
556 |             heads=heads,
557 |             ls_init_value=ls_init_value,
558 |             act_layer=act_layer,
559 |             norm_layer=norm_layer,
560 |         )
561 |         self.ln_final = norm_layer(width)
562 | 
563 |         self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
564 | 
565 |         self.init_parameters()
566 | 
567 |     def init_parameters(self):
568 |         nn.init.normal_(self.token_embedding.weight, std=0.02)
569 |         nn.init.normal_(self.positional_embedding, std=0.01)
570 |         if self.cls_emb is not None:
571 |             nn.init.normal_(self.cls_emb, std=0.01)
572 | 
573 |         proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
574 |         attn_std = self.transformer.width ** -0.5
575 |         fc_std = (2 * self.transformer.width) ** -0.5
576 |         for block in self.transformer.resblocks:
577 |             nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
578 |             nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
579 |             nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
580 |             nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
581 | 
582 |         if self.text_projection is not None:
583 |             nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
584 | 
585 |     @torch.jit.ignore
586 |     def set_grad_checkpointing(self, enable=True):
587 |         self.transformer.grad_checkpointing = enable
588 | 
589 |     def build_attention_mask(self):
590 |         # lazily create causal attention mask, with full attention between the tokens
591 |         # pytorch uses additive attention mask; fill with -inf
592 |         mask = torch.empty(self.num_pos, self.num_pos)
593 |         mask.fill_(float("-inf"))
594 |         mask.triu_(1)  # zero out the lower diagonal
595 |         return mask
596 | 
597 |     def build_cls_mask(self, text, cast_dtype: torch.dtype):
598 |         cls_mask = (text != self.pad_id).unsqueeze(1)
599 |         cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=1.0)
600 |         additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device)
601 |         additive_mask.fill_(0)
602 |         additive_mask.masked_fill_(~cls_mask, float("-inf"))
603 |         additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0)
604 |         return additive_mask
605 | 
606 |     def _repeat(self, t, N: int):
607 |         return t.reshape(1, 1, -1).repeat(N, 1, 1)
608 | 
609 |     def forward(self, text):
610 |         cast_dtype = self.transformer.get_cast_dtype()
611 |         seq_len = text.shape[1]
612 | 
613 |         x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
614 |         attn_mask = self.attn_mask
615 |         if self.cls_emb is not None:
616 |             seq_len += 1
617 |             x = torch.cat([x, self._repeat(self.cls_emb, x.shape[0])], dim=1)
618 |             cls_mask = self.build_cls_mask(text, cast_dtype)
619 |             attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len]
620 | 
621 |         x = x + self.positional_embedding[:seq_len].to(cast_dtype)
622 |         x = x.permute(1, 0, 2)  # NLD -> LND
623 |         x = self.transformer(x, attn_mask=attn_mask)
624 |         x = x.permute(1, 0, 2)  # LND -> NLD
625 | 
626 |         # x.shape = [batch_size, n_ctx, transformer.width]
627 |         # take features from the eot embedding (eot_token is the highest number in each sequence)
628 |         if self.cls_emb is not None:
629 |             pooled, tokens = x[:, -1], x[:, :-1]
630 |             pooled = self.ln_final(pooled)
631 |         else:
632 |             x = self.ln_final(x)
633 |             pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x
634 | 
635 |         if self.text_projection is not None:
636 |             pooled = pooled @ self.text_projection
637 | 
638 |         if self.output_tokens:
639 |             return pooled, tokens
640 | 
641 |         return pooled
642 | 
643 | 
644 | class MultimodalTransformer(Transformer):
645 |     def __init__(
646 |             self,
647 |             width: int,
648 |             layers: int,
649 |             heads: int,
650 |             context_length: int = 77,
651 |             mlp_ratio: float = 4.0,
652 |             ls_init_value: float = None,
653 |             act_layer: Callable = nn.GELU,
654 |             norm_layer: Callable = LayerNorm,
655 |             output_dim: int = 512,
656 |     ):
657 | 
658 |         super().__init__(
659 |             width=width,
660 |             layers=layers,
661 |             heads=heads,
662 |             mlp_ratio=mlp_ratio,
663 |             ls_init_value=ls_init_value,
664 |             act_layer=act_layer,
665 |             norm_layer=norm_layer,
666 |         )
667 |         self.context_length = context_length
668 |         self.cross_attn = nn.ModuleList([
669 |             ResidualAttentionBlock(
670 |                 width,
671 |                 heads,
672 |                 mlp_ratio,
673 |                 ls_init_value=ls_init_value,
674 |                 act_layer=act_layer,
675 |                 norm_layer=norm_layer,
676 |                 is_cross_attention=True,
677 |             )
678 |             for _ in range(layers)
679 |         ])
680 | 
681 |         self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False)
682 | 
683 |         self.ln_final = norm_layer(width)
684 |         self.text_projection = nn.Parameter(torch.empty(width, output_dim))
685 | 
686 |     def init_parameters(self):
687 |         # this class subclasses Transformer directly, so width/layers/resblocks live on self
688 |         proj_std = (self.width ** -0.5) * ((2 * self.layers) ** -0.5)
689 |         attn_std = self.width ** -0.5
690 |         fc_std = (2 * self.width) ** -0.5
691 |         for block in self.resblocks:
692 |             nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
693 |             nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
694 |             nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
695 |             nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
696 |         for block in self.cross_attn:
697 |             nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
698 |             nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
699 |             nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
700 |             nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
701 | 
702 |         if self.text_projection is not None:
703 |             nn.init.normal_(self.text_projection, std=self.width ** -0.5)
704 | 
705 |     def build_attention_mask(self):
706 |         # lazily create causal attention mask, with full attention between the tokens
707 |         # pytorch uses additive attention mask; fill with -inf
708 |         mask = torch.empty(self.context_length, self.context_length)
709 |         mask.fill_(float("-inf"))
710 |         mask.triu_(1)  # zero out the lower diagonal
711 |         return mask
712 | 
713 |     def forward(self, image_embs, text_embs):
714 |         text_embs = text_embs.permute(1, 0, 2)  # NLD -> LND
715 |         image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND
716 |         seq_len = text_embs.shape[0]
717 | 
718 |         for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
719 |             if self.grad_checkpointing and not torch.jit.is_scripting():
720 |                 # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372
721 |                 text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len])
722 |                 text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None)
723 |             else:
724 |                 text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len])
725 |                 text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)
726 | 
727 |         x = text_embs.permute(1, 0, 2)  # LND -> NLD
728 |         x = self.ln_final(x)
729 | 
730 |         if self.text_projection is not None:
731 |             x = x @ self.text_projection
732 | 
733 |         return x
734 | 
735 |     @torch.jit.ignore
736 |     def set_grad_checkpointing(self, enable=True):
737 |         self.grad_checkpointing = enable
--------------------------------------------------------------------------------
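A minimal usage sketch for the TextTransformer defined above. It is not part of the repository; it assumes this vendored module is importable as model.open_clip.transformer and relies only on the constructor defaults shown in the class (context_length=77, vocab_size=49408, width=512, heads=8, layers=12, output_dim=512).

    import torch

    from model.open_clip.transformer import TextTransformer  # assumed import path for this vendored file

    # Text tower with the defaults shown above; embed_cls=False, so pooling uses the
    # argmax (EOT) token position followed by the text_projection matrix.
    text_model = TextTransformer()
    text_model.eval()

    # A batch of 2 sequences of 77 token ids; random ids stand in for tokenizer output here.
    tokens = torch.randint(low=0, high=49408, size=(2, 77))

    with torch.no_grad():
        pooled = text_model(tokens)  # pooled features: EOT-position token @ text_projection

    print(pooled.shape)  # torch.Size([2, 512])

Passing output_tokens=True to the constructor would instead return (pooled, tokens) per the forward method above.

--------------------------------------------------------------------------------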