├── data
│   ├── example.txt
│   ├── Mask
│   │   └── example.png
│   ├── Image
│   │   └── example.png
│   └── Caption
│       └── example.txt
├── .gitignore
├── asset
│   └── Teaser.jpg
├── output
│   └── example.png
├── ip_adapter
│   ├── utils.py
│   ├── __init__.py
│   ├── test_resampler.py
│   ├── resampler.py
│   └── ip_adapter.py
├── inference.sh
├── requirements.txt
├── README.md
├── dataloader.py
├── inference.py
└── src
    ├── transformerhacked_garmnet.py
    ├── transformerhacked_tryon.py
    ├── attentionhacked_tryon.py
    └── attentionhacked_garmnet.py

/data/example.txt:
--------------------------------------------------------------------------------
1 | example.png
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.pyc
3 | *.pyo
4 | *.pyd
--------------------------------------------------------------------------------
/asset/Teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lanjiong-Li/AssetDropper/HEAD/asset/Teaser.jpg
--------------------------------------------------------------------------------
/output/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lanjiong-Li/AssetDropper/HEAD/output/example.png
--------------------------------------------------------------------------------
/data/Mask/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lanjiong-Li/AssetDropper/HEAD/data/Mask/example.png
--------------------------------------------------------------------------------
/data/Image/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lanjiong-Li/AssetDropper/HEAD/data/Image/example.png
--------------------------------------------------------------------------------
/data/Caption/example.txt:
--------------------------------------------------------------------------------
1 | a cheerful green bottle-shaped character holding a top hat from which a vibrant rainbow arches upward
--------------------------------------------------------------------------------
/ip_adapter/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | 
3 | 
4 | def is_torch2_available():
5 |     return hasattr(F, "scaled_dot_product_attention")
6 | 
--------------------------------------------------------------------------------
/inference.sh:
--------------------------------------------------------------------------------
1 | accelerate launch --main_process_port 29521 inference.py \
2 |     --num_inference_steps 120 \
3 |     --output_dir "./output" \
4 |     --data_dir "./data" \
5 |     --test_batch_size 8 --guidance_scale 2.0 \
6 |     --txt_name "example" \
7 |     --pretrained_model_name_or_path "LLanv/AssetDropper" \
8 |     --seed 42
--------------------------------------------------------------------------------
/ip_adapter/__init__.py:
--------------------------------------------------------------------------------
1 | from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull, IPAdapterPlus_Lora, IPAdapterPlus_Lora_up
2 | 
3 | __all__ = [
4 |     "IPAdapter",
5 |     "IPAdapterPlus",
6 |     "IPAdapterPlusXL",
7 |     "IPAdapterXL",
8 |     "IPAdapterFull",
9 |     "IPAdapterPlus_Lora",
10 |     "IPAdapterPlus_Lora_up",
11 | ]
12 | 
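The `data/` inputs shown above (`example.txt`, the `Image/`, `Mask/`, and `Caption/` files) follow the layout that `inference.sh` and `dataloader.py` expect. A minimal helper for laying that directory out could look like the sketch below; this is a hedged illustration, not part of the repository, and `prepare_data_dir` plus all paths are hypothetical.

```python
# Minimal sketch (not part of this repository): lay out the data/ directory that
# inference.sh and dataloader.py expect -- data/<list>.txt naming the images to
# process, plus matching files under data/Image, data/Mask and data/Caption.
from pathlib import Path
import shutil


def prepare_data_dir(data_dir, items, list_name="example"):
    """items maps an image file name to an (image_path, mask_path, caption) tuple."""
    root = Path(data_dir)
    for sub in ("Image", "Mask", "Caption"):
        (root / sub).mkdir(parents=True, exist_ok=True)

    names = []
    for image_name, (image_path, mask_path, caption) in items.items():
        stem = Path(image_name).stem
        shutil.copy(image_path, root / "Image" / image_name)
        # dataloader.py looks the mask up as Mask/<stem>.jpg or Mask/<stem>.png
        shutil.copy(mask_path, root / "Mask" / (stem + Path(mask_path).suffix))
        (root / "Caption" / f"{stem}.txt").write_text(caption + "\n")
        names.append(image_name)

    # the list file selected via --txt_name: one image file name per line
    (root / f"{list_name}.txt").write_text("\n".join(names) + "\n")


# Hypothetical usage with placeholder paths:
# prepare_data_dir("./data", {"example.png": ("/path/to/image.png", "/path/to/mask.png", "a cheerful green bottle-shaped character")})
```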
-------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.1.0 2 | certifi==2025.4.26 3 | charset-normalizer==3.4.2 4 | diffusers==0.29.2 5 | einops==0.8.1 6 | filelock==3.13.1 7 | fsspec==2024.6.1 8 | hf-xet==1.1.3 9 | huggingface-hub==0.32.4 10 | idna==3.10 11 | importlib_metadata==8.7.0 12 | Jinja2==3.1.4 13 | MarkupSafe==2.1.5 14 | mpmath==1.3.0 15 | networkx==3.3 16 | numpy==2.1.2 17 | nvidia-cublas-cu11==11.11.3.6 18 | nvidia-cuda-cupti-cu11==11.8.87 19 | nvidia-cuda-nvrtc-cu11==11.8.89 20 | nvidia-cuda-runtime-cu11==11.8.89 21 | nvidia-cudnn-cu11==9.1.0.70 22 | nvidia-cufft-cu11==10.9.0.58 23 | nvidia-curand-cu11==10.3.0.86 24 | nvidia-cusolver-cu11==11.4.1.48 25 | nvidia-cusparse-cu11==11.7.5.86 26 | nvidia-nccl-cu11==2.20.5 27 | nvidia-nvtx-cu11==11.8.86 28 | opencv-python==4.10.0.84 29 | packaging==25.0 30 | pillow==11.0.0 31 | psutil==7.0.0 32 | PyYAML==6.0.2 33 | regex==2024.11.6 34 | requests==2.32.3 35 | safetensors==0.5.3 36 | sympy==1.13.3 37 | tokenizers==0.15.2 38 | tqdm==4.67.1 39 | transformers==4.36.2 40 | triton==3.0.0 41 | typing_extensions==4.12.2 42 | urllib3==2.4.0 43 | zipp==3.22.0 44 | -------------------------------------------------------------------------------- /ip_adapter/test_resampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from resampler import Resampler 3 | from transformers import CLIPVisionModel 4 | 5 | BATCH_SIZE = 2 6 | OUTPUT_DIM = 1280 7 | NUM_QUERIES = 8 8 | NUM_LATENTS_MEAN_POOLED = 4 # 0 for no mean pooling (previous behavior) 9 | APPLY_POS_EMB = True # False for no positional embeddings (previous behavior) 10 | IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K" 11 | 12 | 13 | def main(): 14 | image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH) 15 | embedding_dim = image_encoder.config.hidden_size 16 | print(f"image_encoder hidden size: ", embedding_dim) 17 | 18 | image_proj_model = Resampler( 19 | dim=1024, 20 | depth=2, 21 | dim_head=64, 22 | heads=16, 23 | num_queries=NUM_QUERIES, 24 | embedding_dim=embedding_dim, 25 | output_dim=OUTPUT_DIM, 26 | ff_mult=2, 27 | max_seq_len=257, 28 | apply_pos_emb=APPLY_POS_EMB, 29 | num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED, 30 | ) 31 | 32 | dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224) 33 | with torch.no_grad(): 34 | image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2] 35 | print("image_embds shape: ", image_embeds.shape) 36 | 37 | with torch.no_grad(): 38 | ip_tokens = image_proj_model(image_embeds) 39 | print("ip_tokens shape:", ip_tokens.shape) 40 | assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM) 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AssetDropper: Asset Extraction via Diffusion Models with Reward-Driven Optimization 2 | 3 | ![teaser](asset/Teaser.jpg) 4 | 5 | ![Version](https://img.shields.io/badge/version-1.0.0-blue)   6 |   7 |   8 | [![HuggingFace Model](https://img.shields.io/badge/🤗%20Hugging%20Face-Model-green)](https://huggingface.co/LLanv/AssetDropper)  9 | 10 | ## Installation 11 | ```bash 12 | git clone https://github.com/Lanjiong-Li/AssetDropper.git 13 | cd 
AssetDropper
14 | 
15 | conda create -n assetdropper python=3.10 -y
16 | conda activate assetdropper
17 | 
18 | # Install torch and torchvision based on your machine configuration
19 | pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu118
20 | 
21 | # Install other dependencies
22 | pip install -r requirements.txt
23 | ```
24 | 
25 | ## Usage
26 | 
27 | ### Prepare Input
28 | To get started with your own images, follow this simple data structure:
29 | put your **image** (`.jpg` or `.png`), the corresponding **mask** (`.jpg` or `.png`), and a **caption** in the subdirectories of `data`.
30 | 
31 | Here is an overview of the data structure:
32 | 
33 | ```
34 | data
35 | ├── Caption/
36 | │   └── example.txt
37 | ├── Image/
38 | │   └── example.png
39 | ├── Mask/
40 | │   └── example.png
41 | └── example.txt (list the image names you want to process)
42 | ```
43 | 
44 | ### Get Asset from Reference Image & Mask
45 | 
46 | Run the following command to extract the asset from the reference image:
47 | 
48 | ```bash
49 | python inference.py \
50 |     --pretrained_model_name_or_path "LLanv/AssetDropper" \
51 |     --data_dir "./data" \
52 |     --output_dir "./output" \
53 |     --txt_name "example" \
54 |     --test_batch_size 8 \
55 |     --guidance_scale 2.0 \
56 |     --num_inference_steps 120
57 | ```
58 | - `--pretrained_model_name_or_path`: Path to the pre-trained AssetDropper model checkpoint.
59 | - `--data_dir`: Path to the directory containing the input images, masks, and captions.
60 | - `--output_dir`: Path to the output directory.
61 | - `--txt_name`: Name of the `.txt` file (without extension) that lists the image names you want to process.
62 | 
63 | Or simply run:
64 | ```bash
65 | bash inference.sh
66 | ```
67 | 
68 | ## ToDo List
69 | - [x] Inference code
70 | - [ ] Gradio & Hugging Face demo (Coming Soon)
71 | - [ ] Dataset (Coming Soon)
72 | 
73 | ## Citation
74 | If you find this work useful for your research, please consider citing:
75 | ```
76 | @article{li2025assetdropper,
77 |   title={AssetDropper: Asset Extraction via Diffusion Models with Reward-Driven Optimization},
78 |   author={Li, Lanjiong and Zhao, Guanhua and Zhu, Lingting and Cai, Zeyu and Yu, Lequan and Zhang, Jian and Wang, Zeyu},
79 |   journal={arXiv preprint arXiv:2506.07738},
80 |   year={2025}
81 | }
82 | ```
--------------------------------------------------------------------------------
/ip_adapter/resampler.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2 | # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3 | 
4 | import math
5 | 
6 | import torch
7 | import torch.nn as nn
8 | from einops import rearrange
9 | from einops.layers.torch import Rearrange
10 | 
11 | 
12 | # FFN
13 | def FeedForward(dim, mult=4):
14 |     inner_dim = int(dim * mult)
15 |     return nn.Sequential(
16 |         nn.LayerNorm(dim),
17 |         nn.Linear(dim, inner_dim, bias=False),
18 |         nn.GELU(),
19 |         nn.Linear(inner_dim, dim, bias=False),
20 |     )
21 | 
22 | 
23 | def reshape_tensor(x, heads):
24 |     bs, length, width = x.shape
25 |     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26 |     x = x.view(bs, length, heads, -1)
27 |     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28 |     x = x.transpose(1, 2)
29 |     # final shape stays (bs, n_heads, length, dim_per_head)
30 |     x = x.reshape(bs, heads, length, -1)
31 |     return x
32 | 
33 | 
34 | class 
PerceiverAttention(nn.Module): 35 | def __init__(self, *, dim, dim_head=64, heads=8): 36 | super().__init__() 37 | self.scale = dim_head**-0.5 38 | self.dim_head = dim_head 39 | self.heads = heads 40 | inner_dim = dim_head * heads 41 | 42 | self.norm1 = nn.LayerNorm(dim) 43 | self.norm2 = nn.LayerNorm(dim) 44 | 45 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 46 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 47 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 48 | 49 | def forward(self, x, latents): 50 | """ 51 | Args: 52 | x (torch.Tensor): image features 53 | shape (b, n1, D) 54 | latent (torch.Tensor): latent features 55 | shape (b, n2, D) 56 | """ 57 | x = self.norm1(x) 58 | latents = self.norm2(latents) 59 | 60 | b, l, _ = latents.shape 61 | 62 | q = self.to_q(latents) 63 | kv_input = torch.cat((x, latents), dim=-2) 64 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 65 | 66 | q = reshape_tensor(q, self.heads) 67 | k = reshape_tensor(k, self.heads) 68 | v = reshape_tensor(v, self.heads) 69 | 70 | # attention 71 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 72 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 73 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 74 | out = weight @ v 75 | 76 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 77 | 78 | return self.to_out(out) 79 | 80 | 81 | class CrossAttention(nn.Module): 82 | def __init__(self, *, dim, dim_head=64, heads=8): 83 | super().__init__() 84 | self.scale = dim_head**-0.5 85 | self.dim_head = dim_head 86 | self.heads = heads 87 | inner_dim = dim_head * heads 88 | 89 | self.norm1 = nn.LayerNorm(dim) 90 | self.norm2 = nn.LayerNorm(dim) 91 | 92 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 93 | self.to_k = nn.Linear(dim, inner_dim, bias=False) 94 | self.to_v = nn.Linear(dim, inner_dim, bias=False) 95 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 96 | 97 | 98 | def forward(self, x, x2): 99 | """ 100 | Args: 101 | x (torch.Tensor): image features 102 | shape (b, n1, D) 103 | latent (torch.Tensor): latent features 104 | shape (b, n2, D) 105 | """ 106 | x = self.norm1(x) 107 | x2 = self.norm2(x2) 108 | 109 | b, l, _ = x2.shape 110 | 111 | q = self.to_q(x) 112 | k = self.to_k(x2) 113 | v = self.to_v(x2) 114 | 115 | q = reshape_tensor(q, self.heads) 116 | k = reshape_tensor(k, self.heads) 117 | v = reshape_tensor(v, self.heads) 118 | 119 | # attention 120 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 121 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 122 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 123 | out = weight @ v 124 | 125 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 126 | return self.to_out(out) 127 | 128 | 129 | class Resampler(nn.Module): 130 | def __init__( 131 | self, 132 | dim=1024, 133 | depth=8, 134 | dim_head=64, 135 | heads=16, 136 | num_queries=8, 137 | embedding_dim=768, 138 | output_dim=1024, 139 | ff_mult=4, 140 | max_seq_len: int = 257, # CLIP tokens + CLS token 141 | apply_pos_emb: bool = False, 142 | num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence 143 | ): 144 | super().__init__() 145 | 146 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 147 | 148 | self.proj_in = nn.Linear(embedding_dim, dim) 149 | 150 | self.proj_out = nn.Linear(dim, output_dim) 151 | self.norm_out = nn.LayerNorm(output_dim) 152 | 153 | self.layers = 
nn.ModuleList([]) 154 | for _ in range(depth): 155 | self.layers.append( 156 | nn.ModuleList( 157 | [ 158 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 159 | FeedForward(dim=dim, mult=ff_mult), 160 | ] 161 | ) 162 | ) 163 | 164 | def forward(self, x): 165 | 166 | latents = self.latents.repeat(x.size(0), 1, 1) 167 | 168 | x = self.proj_in(x) 169 | 170 | 171 | for attn, ff in self.layers: 172 | latents = attn(x, latents) + latents 173 | latents = ff(latents) + latents 174 | 175 | latents = self.proj_out(latents) 176 | return self.norm_out(latents) 177 | 178 | 179 | 180 | def masked_mean(t, *, dim, mask=None): 181 | if mask is None: 182 | return t.mean(dim=dim) 183 | 184 | denom = mask.sum(dim=dim, keepdim=True) 185 | mask = rearrange(mask, "b n -> b n 1") 186 | masked_t = t.masked_fill(~mask, 0.0) 187 | 188 | return masked_t.sum(dim=dim) / denom.clamp(min=1e-5) 189 | -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch.nn.functional as F 3 | from torchvision import transforms 4 | from PIL import Image 5 | from transformers import CLIPImageProcessor 6 | from typing import Literal, Tuple 7 | import torch.utils.data as data 8 | import numpy as np 9 | import cv2 10 | import torch 11 | 12 | class AssetDataset(data.Dataset): 13 | def __init__( 14 | self, 15 | dataroot_path: str, 16 | phase: Literal["train", "test"], 17 | size: Tuple[int, int] = (512, 512), 18 | txt_name: str = None, 19 | ): 20 | super(AssetDataset, self).__init__() 21 | self.dataroot = dataroot_path 22 | self.phase = phase 23 | self.height = size[0] 24 | self.width = size[1] 25 | self.size = size 26 | self.txt_name = txt_name 27 | 28 | self.norm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 29 | self.transform = transforms.Compose( 30 | [ 31 | transforms.ToTensor(), 32 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 33 | ] 34 | ) 35 | self.transform2D = transforms.Compose( 36 | [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] 37 | ) 38 | 39 | self.toTensor = transforms.ToTensor() 40 | 41 | image_names = [] 42 | caption_names = [] 43 | dataroot_names = [] 44 | 45 | 46 | if phase == "train": 47 | filename = os.path.join(dataroot_path, f"{phase}.txt") 48 | else: 49 | if txt_name is None: 50 | filename = os.path.join(dataroot_path, f"{phase}.txt") 51 | else: 52 | filename = os.path.join(dataroot_path, f"{txt_name}.txt") 53 | 54 | with open(filename, "r") as f: 55 | for line in f.readlines(): 56 | 57 | image_name = line.strip() 58 | 59 | name_no_ext, _ = os.path.splitext(image_name) 60 | caption_name = name_no_ext + ".txt" 61 | 62 | image_names.append(image_name) 63 | caption_names.append(caption_name) 64 | dataroot_names.append(dataroot_path) 65 | 66 | self.image_names = image_names 67 | self.caption_names = caption_names 68 | self.dataroot_names = dataroot_names 69 | self.flip_transform = transforms.RandomHorizontalFlip(p=1) 70 | self.clip_processor = CLIPImageProcessor() 71 | 72 | def _crop_and_resize_by_mask( 73 | self, 74 | image: Image.Image, 75 | mask: Image.Image, 76 | output_size=(512, 512) 77 | ) -> Tuple[Image.Image, Image.Image]: 78 | 79 | mask_np = np.array(mask.convert("L")) 80 | if mask_np.max() == 0: 81 | return image.resize(output_size), mask.resize(output_size) 82 | 83 | ys, xs = np.nonzero(mask_np) 84 | min_x, max_x = xs.min(), xs.max() 85 | min_y, max_y = ys.min(), ys.max() 86 | 87 | box_width = max_x - min_x 88 
| box_height = max_y - min_y
89 |         box_size = max(box_width, box_height)
90 | 
91 |         center_x = (min_x + max_x) // 2
92 |         center_y = (min_y + max_y) // 2
93 |         half_size = box_size // 2
94 | 
95 |         left = max(center_x - half_size, 0)
96 |         upper = max(center_y - half_size, 0)
97 |         right = min(center_x + half_size, image.width)
98 |         lower = min(center_y + half_size, image.height)
99 | 
100 |         if right - left < box_size:
101 |             if left == 0:
102 |                 right = min(left + box_size, image.width)
103 |             else:
104 |                 left = max(right - box_size, 0)
105 | 
106 |         if lower - upper < box_size:
107 |             if upper == 0:
108 |                 lower = min(upper + box_size, image.height)
109 |             else:
110 |                 upper = max(lower - box_size, 0)
111 | 
112 |         crop_box = (left, upper, right, lower)
113 | 
114 |         cropped_image = image.crop(crop_box).resize(output_size, resample=Image.BICUBIC)
115 |         cropped_mask = mask.crop(crop_box).resize(output_size, resample=Image.NEAREST)
116 | 
117 |         return cropped_image, cropped_mask
118 | 
119 |     def __getitem__(self, index):
120 |         image_name = self.image_names[index]
121 |         caption_name = self.caption_names[index]
122 | 
123 |         #1 image
124 |         image = Image.open(os.path.join(self.dataroot, "Image", image_name))
125 | 
126 |         if image.mode == 'RGBA':
127 |             white_bg = Image.new("RGB", image.size, (255, 255, 255))
128 |             white_bg.paste(image, (0, 0), image)
129 |             image = white_bg
130 |         else:
131 |             image = image.convert('RGB')
132 | 
133 |         image = image.resize((512, 512))
134 | 
135 |         mask_name_without_ext = os.path.splitext(image_name)[0]
136 |         print(f"mask_name_without_ext:{mask_name_without_ext}")
137 | 
138 |         possible_ext = ['.jpg', '.png']
139 |         mask_path = None  # initialize so a missing mask raises FileNotFoundError below instead of NameError
140 |         for ext in possible_ext:
141 |             test_path = os.path.join(self.dataroot, "Mask", mask_name_without_ext + ext)
142 |             if os.path.exists(test_path):
143 |                 mask_path = test_path
144 |                 break
145 | 
146 |         if mask_path is None:
147 |             raise FileNotFoundError(f"Missing Mask: {image_name}")
148 | 
149 |         #2 mask
150 |         mask = Image.open(mask_path).resize((512, 512))
151 | 
152 |         image, mask = self._crop_and_resize_by_mask(image, mask, output_size=(512, 512))
153 | 
154 |         #3 pattern
155 |         pattern = self.toTensor(image)
156 | 
157 |         image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
158 |         mask_cv = np.array(mask.convert("L"))
159 | 
160 |         #4 masked_image for IP-Adapter
161 |         masked_image_cv = cv2.bitwise_and(image_cv, image_cv, mask=mask_cv)
162 |         masked_image = Image.fromarray(cv2.cvtColor(masked_image_cv, cv2.COLOR_BGR2RGB)).resize((512, 512))
163 |         mask_img_trim = self.clip_processor(images=masked_image, return_tensors="pt").pixel_values
164 | 
165 |         #5 edgemap
166 |         image_gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
167 |         kernel = np.ones((3, 3), np.uint8)
168 |         eroded_mask = cv2.erode(mask_cv, kernel, iterations=3)
169 |         sobelx = cv2.Sobel(image_gray, cv2.CV_64F, 1, 0, ksize=3)
170 |         sobely = cv2.Sobel(image_gray, cv2.CV_64F, 0, 1, ksize=3)
171 |         gradient = cv2.addWeighted(cv2.convertScaleAbs(sobelx), 0.5, cv2.convertScaleAbs(sobely), 0.5, 0)
172 |         gradient[eroded_mask == 0] = 0
173 |         edgemap = Image.fromarray(gradient).resize((512, 512))
174 | 
175 |         mask = self.toTensor(mask)
176 |         edgemap = self.toTensor(edgemap)
177 |         mask = mask[:1]
178 |         edgemap = edgemap[:1]
179 | 
180 |         pattern = self.norm(pattern)
181 |         image = self.transform(image)  # norm [-1, 1]
182 | 
183 |         #caption
184 |         with open(f"{self.dataroot}/Caption/{caption_name}", "r") as f:
185 |             caption = f.readline().strip()
186 | 
187 |         result = {}
188 | 
189 |         result["image_name"] = image_name
190 |         result["image"] = image 
191 | result["mask"] = mask 192 | result["edgemap"] = edgemap 193 | result["masked_image"] = mask_img_trim 194 | result["pattern"] = pattern 195 | result["caption_pattern"] = f"The pattern is {caption}" 196 | result["caption_gen"] = f"A normalized square pattern of {caption}" 197 | 198 | return result 199 | 200 | def __len__(self): 201 | return len(self.image_names) -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Literal 15 | from ip_adapter.ip_adapter import Resampler 16 | 17 | import argparse 18 | import logging 19 | import os 20 | import torch.utils.data as data 21 | import torchvision 22 | import json 23 | import accelerate 24 | import numpy as np 25 | import torch 26 | from PIL import Image 27 | import torch.nn.functional as F 28 | import transformers 29 | from accelerate import Accelerator 30 | from accelerate.logging import get_logger 31 | from accelerate.utils import ProjectConfiguration, set_seed 32 | from packaging import version 33 | from torchvision import transforms 34 | import diffusers 35 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, StableDiffusionXLControlNetInpaintPipeline 36 | from transformers import AutoTokenizer, PretrainedConfig,CLIPImageProcessor, CLIPVisionModelWithProjection,CLIPTextModelWithProjection, CLIPTextModel, CLIPTokenizer 37 | import random 38 | from diffusers.utils.import_utils import is_xformers_available 39 | 40 | from src.unet_hacked_tryon import UNet2DConditionModel 41 | from src.unet_hacked_garmnet import UNet2DConditionModel as UNet2DConditionModel_ref 42 | from src.assetdropper_pipeline import StableDiffusionXLInpaintPipeline as AssetDropperPipeline 43 | from huggingface_hub import snapshot_download 44 | from dataloader import AssetDataset 45 | 46 | logger = get_logger(__name__, log_level="INFO") 47 | 48 | def parse_args(): 49 | parser = argparse.ArgumentParser(description="paras for inference.") 50 | parser.add_argument("--pretrained_model_name_or_path",type=str,default="",required=False,) 51 | parser.add_argument("--width",type=int,default=512,) 52 | parser.add_argument("--height",type=int,default=512,) 53 | parser.add_argument("--Pwidth",type=int,default=512,) 54 | parser.add_argument("--Pheight",type=int,default=512,) 55 | parser.add_argument("--txt_name",type=str,default=None) 56 | parser.add_argument("--num_inference_steps",type=int,default=50,) 57 | parser.add_argument("--output_dir",type=str,default="./output",) 58 | parser.add_argument("--data_dir",type=str,default="./dataset") 59 | parser.add_argument("--seed", type=int, default=42,) 60 | parser.add_argument("--test_batch_size", type=int, default=2,) 61 | parser.add_argument("--guidance_scale",type=float,default=2.0,) 62 
| parser.add_argument("--mixed_precision",type=str,default=None,choices=["no", "fp16", "bf16"],) 63 | parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.") 64 | args = parser.parse_args() 65 | 66 | return args 67 | 68 | def pil_to_tensor(images): 69 | images = np.array(images).astype(np.float32) / 255.0 70 | images = torch.from_numpy(images.transpose(2, 0, 1)) 71 | return images 72 | 73 | 74 | def main(): 75 | args = parse_args() 76 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir) 77 | accelerator = Accelerator( 78 | mixed_precision=args.mixed_precision, 79 | project_config=accelerator_project_config, 80 | ) 81 | if accelerator.is_local_main_process: 82 | transformers.utils.logging.set_verbosity_warning() 83 | diffusers.utils.logging.set_verbosity_info() 84 | else: 85 | transformers.utils.logging.set_verbosity_error() 86 | diffusers.utils.logging.set_verbosity_error() 87 | 88 | if args.seed is not None: 89 | set_seed(args.seed) 90 | 91 | if accelerator.is_main_process: 92 | if args.output_dir is not None: 93 | os.makedirs(args.output_dir, exist_ok=True) 94 | 95 | weight_dtype = torch.float16 96 | 97 | # Load scheduler, tokenizer and models. 98 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="checkpoint-37500/scheduler") 99 | 100 | vae = AutoencoderKL.from_pretrained( 101 | args.pretrained_model_name_or_path, 102 | subfolder="checkpoint-37500/vae", 103 | torch_dtype=torch.float16, 104 | ) 105 | 106 | unet_dir = snapshot_download( 107 | repo_id="LLanv/AssetDropper", 108 | repo_type="model", 109 | allow_patterns=["checkpoint-37500/unet/*"], 110 | ) 111 | unet_path = os.path.join(unet_dir, "checkpoint-37500/unet") 112 | unet = UNet2DConditionModel.from_pretrained( 113 | pretrained_model_name_or_path=unet_path, 114 | use_safetensors=True, 115 | low_cpu_mem_usage=True 116 | ) 117 | 118 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( 119 | args.pretrained_model_name_or_path, 120 | subfolder="checkpoint-37500/image_encoder", 121 | torch_dtype=torch.float16, 122 | ) 123 | unet_encoder = UNet2DConditionModel_ref.from_pretrained( 124 | 'stabilityai/stable-diffusion-xl-base-1.0', 125 | subfolder="unet" 126 | ) 127 | unet_encoder.config.addition_embed_type = None 128 | unet_encoder.config["addition_embed_type"] = None 129 | 130 | text_encoder_one = CLIPTextModel.from_pretrained( 131 | args.pretrained_model_name_or_path, 132 | subfolder="checkpoint-37500/text_encoder", 133 | torch_dtype=torch.float16, 134 | ) 135 | text_encoder_two = CLIPTextModelWithProjection.from_pretrained( 136 | args.pretrained_model_name_or_path, 137 | subfolder="checkpoint-37500/text_encoder_2", 138 | torch_dtype=torch.float16, 139 | ) 140 | tokenizer_one = AutoTokenizer.from_pretrained( 141 | args.pretrained_model_name_or_path, 142 | subfolder="checkpoint-37500/tokenizer", 143 | revision=None, 144 | use_fast=False, 145 | ) 146 | tokenizer_two = AutoTokenizer.from_pretrained( 147 | args.pretrained_model_name_or_path, 148 | subfolder="checkpoint-37500/tokenizer_2", 149 | revision=None, 150 | use_fast=False, 151 | ) 152 | 153 | unet.requires_grad_(False) 154 | vae.requires_grad_(False) 155 | image_encoder.requires_grad_(False) 156 | unet_encoder.requires_grad_(False) 157 | text_encoder_one.requires_grad_(False) 158 | text_encoder_two.requires_grad_(False) 159 | unet_encoder.to(accelerator.device, weight_dtype) 160 | unet.eval() 161 | unet_encoder.eval() 162 | 163 
| conv_new_encoder = torch.nn.Conv2d( 164 | in_channels=6, 165 | out_channels=unet_encoder.conv_in.out_channels, 166 | kernel_size=3, 167 | padding=1, 168 | ) 169 | torch.nn.init.kaiming_normal_(conv_new_encoder.weight) 170 | conv_new_encoder.weight.data = conv_new_encoder.weight.data * 0. 171 | conv_new_encoder.weight.data[:, :4] = unet_encoder.conv_in.weight.data[:, :4] 172 | conv_new_encoder.bias.data = unet_encoder.conv_in.bias.data 173 | unet_encoder.conv_in = conv_new_encoder # replace conv layer in unet 174 | unet_encoder.config['in_channels'] = 6 # update config 175 | unet_encoder.config.in_channels = 6 # update config 176 | 177 | if args.enable_xformers_memory_efficient_attention: 178 | if is_xformers_available(): 179 | import xformers 180 | 181 | xformers_version = version.parse(xformers.__version__) 182 | if xformers_version == version.parse("0.0.16"): 183 | logger.warn( 184 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 185 | ) 186 | unet.enable_xformers_memory_efficient_attention() 187 | else: 188 | raise ValueError("xformers is not available. Make sure it is installed correctly") 189 | 190 | test_dataset = AssetDataset( 191 | dataroot_path=args.data_dir, 192 | phase="test", 193 | size=(args.height, args.width), 194 | txt_name=args.txt_name, 195 | ) 196 | 197 | test_dataloader = torch.utils.data.DataLoader( 198 | test_dataset, 199 | shuffle=False, 200 | batch_size=args.test_batch_size, 201 | num_workers=4, 202 | ) 203 | 204 | newpipe = AssetDropperPipeline.from_pretrained( 205 | args.pretrained_model_name_or_path, 206 | unet=unet, 207 | vae=vae, 208 | feature_extractor= CLIPImageProcessor(), 209 | text_encoder = text_encoder_one, 210 | text_encoder_2 = text_encoder_two, 211 | tokenizer = tokenizer_one, 212 | tokenizer_2 = tokenizer_two, 213 | scheduler = noise_scheduler, 214 | image_encoder=image_encoder, 215 | unet_encoder = unet_encoder, 216 | torch_dtype=torch.float16, 217 | add_watermarker=False, 218 | safety_checker=None, 219 | ).to(accelerator.device) 220 | 221 | with torch.no_grad(): 222 | with torch.cuda.amp.autocast(): 223 | with torch.no_grad(): 224 | for sample in test_dataloader: 225 | 226 | masked_image_emb_list = [] 227 | 228 | for i in range(sample['masked_image'].shape[0]): 229 | masked_image_emb_list.append(sample['masked_image'][i]) 230 | 231 | masked_image_embeds = torch.cat(masked_image_emb_list, dim=0) 232 | 233 | prompt = sample["caption_pattern"] 234 | num_prompts = sample['image'].shape[0] 235 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 236 | 237 | if not isinstance(prompt, List): 238 | prompt = [prompt] * num_prompts 239 | if not isinstance(negative_prompt, List): 240 | negative_prompt = [negative_prompt] * num_prompts 241 | 242 | with torch.inference_mode(): 243 | ( 244 | prompt_embeds, 245 | negative_prompt_embeds, 246 | pooled_prompt_embeds, 247 | negative_pooled_prompt_embeds, 248 | ) = newpipe.encode_prompt( 249 | prompt, 250 | num_images_per_prompt=1, 251 | do_classifier_free_guidance=True, 252 | negative_prompt=negative_prompt, 253 | ) 254 | 255 | prompt = sample["caption_gen"] 256 | num_prompts = sample['image'].shape[0] 257 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 258 | 259 | if not isinstance(prompt, List): 260 | prompt = [prompt] * num_prompts 261 | if not 
isinstance(negative_prompt, List): 262 | negative_prompt = [negative_prompt] * num_prompts 263 | 264 | 265 | with torch.inference_mode(): 266 | ( 267 | prompt_embeds_c, 268 | _, 269 | _, 270 | _, 271 | ) = newpipe.encode_prompt( 272 | prompt, 273 | num_images_per_prompt=1, 274 | do_classifier_free_guidance=False, 275 | negative_prompt=negative_prompt, 276 | ) 277 | 278 | seed = args.seed 279 | generator = torch.Generator(newpipe.device).manual_seed(seed) 280 | 281 | images = newpipe( 282 | prompt_embeds=prompt_embeds, 283 | negative_prompt_embeds=negative_prompt_embeds, 284 | pooled_prompt_embeds=pooled_prompt_embeds, 285 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 286 | num_inference_steps=args.num_inference_steps, 287 | generator=generator, 288 | strength = 1.0, 289 | reference_image_embed=prompt_embeds_c, #reference_image_embed 290 | image = sample["image"].to(accelerator.device), 291 | mask = sample['mask'], 292 | edgemap = sample['edgemap'], 293 | pattern = sample['pattern'], 294 | height=args.height, 295 | width=args.width, 296 | P_height=args.Pheight, 297 | P_width=args.Pwidth, 298 | guidance_scale=args.guidance_scale, 299 | ip_adapter_image = masked_image_embeds, 300 | )[0] 301 | 302 | for i in range(len(images)): 303 | x_sample = pil_to_tensor(images[i]) 304 | save_path = os.path.join(args.output_dir, f"{sample['image_name'][i]}") 305 | torchvision.utils.save_image(x_sample, save_path) 306 | 307 | torch.cuda.empty_cache() 308 | 309 | 310 | 311 | 312 | if __name__ == "__main__": 313 | main() 314 | -------------------------------------------------------------------------------- /src/transformerhacked_garmnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, Optional 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch import nn 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.models.embeddings import ImagePositionalEmbeddings 23 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version 24 | from src.attentionhacked_garmnet import BasicTransformerBlock 25 | from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection 26 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 27 | from diffusers.models.modeling_utils import ModelMixin 28 | from diffusers.models.normalization import AdaLayerNormSingle 29 | 30 | 31 | @dataclass 32 | class Transformer2DModelOutput(BaseOutput): 33 | """ 34 | The output of [`Transformer2DModel`]. 
35 | 36 | Args: 37 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 38 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 39 | distributions for the unnoised latent pixels. 40 | """ 41 | 42 | sample: torch.FloatTensor 43 | 44 | 45 | class Transformer2DModel(ModelMixin, ConfigMixin): 46 | """ 47 | A 2D Transformer model for image-like data. 48 | 49 | Parameters: 50 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 51 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 52 | in_channels (`int`, *optional*): 53 | The number of channels in the input and output (specify if the input is **continuous**). 54 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 55 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 56 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 57 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 58 | This is fixed during training since it is used to learn a number of position embeddings. 59 | num_vector_embeds (`int`, *optional*): 60 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 61 | Includes the class for the masked latent pixel. 62 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 63 | num_embeds_ada_norm ( `int`, *optional*): 64 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is 65 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are 66 | added to the hidden states. 67 | 68 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 69 | attention_bias (`bool`, *optional*): 70 | Configure if the `TransformerBlocks` attention should contain a bias parameter. 
71 | """ 72 | 73 | _supports_gradient_checkpointing = True 74 | 75 | @register_to_config 76 | def __init__( 77 | self, 78 | num_attention_heads: int = 16, 79 | attention_head_dim: int = 88, 80 | in_channels: Optional[int] = None, 81 | out_channels: Optional[int] = None, 82 | num_layers: int = 1, 83 | dropout: float = 0.0, 84 | norm_num_groups: int = 32, 85 | cross_attention_dim: Optional[int] = None, 86 | attention_bias: bool = False, 87 | sample_size: Optional[int] = None, 88 | num_vector_embeds: Optional[int] = None, 89 | patch_size: Optional[int] = None, 90 | activation_fn: str = "geglu", 91 | num_embeds_ada_norm: Optional[int] = None, 92 | use_linear_projection: bool = False, 93 | only_cross_attention: bool = False, 94 | double_self_attention: bool = False, 95 | upcast_attention: bool = False, 96 | norm_type: str = "layer_norm", 97 | norm_elementwise_affine: bool = True, 98 | norm_eps: float = 1e-5, 99 | attention_type: str = "default", 100 | caption_channels: int = None, 101 | ): 102 | super().__init__() 103 | self.use_linear_projection = use_linear_projection 104 | self.num_attention_heads = num_attention_heads 105 | self.attention_head_dim = attention_head_dim 106 | inner_dim = num_attention_heads * attention_head_dim 107 | 108 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 109 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 110 | 111 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 112 | # Define whether input is continuous or discrete depending on configuration 113 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 114 | self.is_input_vectorized = num_vector_embeds is not None 115 | self.is_input_patches = in_channels is not None and patch_size is not None 116 | 117 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 118 | deprecation_message = ( 119 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 120 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 121 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 122 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 123 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 124 | ) 125 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) 126 | norm_type = "ada_norm" 127 | 128 | if self.is_input_continuous and self.is_input_vectorized: 129 | raise ValueError( 130 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 131 | " sure that either `in_channels` or `num_vector_embeds` is None." 132 | ) 133 | elif self.is_input_vectorized and self.is_input_patches: 134 | raise ValueError( 135 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 136 | " sure that either `num_vector_embeds` or `num_patches` is None." 137 | ) 138 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: 139 | raise ValueError( 140 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 141 | f" {patch_size}. 
Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 142 | ) 143 | 144 | # 2. Define input layers 145 | if self.is_input_continuous: 146 | self.in_channels = in_channels 147 | 148 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 149 | if use_linear_projection: 150 | self.proj_in = linear_cls(in_channels, inner_dim) 151 | else: 152 | self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 153 | elif self.is_input_vectorized: 154 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" 155 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" 156 | 157 | self.height = sample_size 158 | self.width = sample_size 159 | self.num_vector_embeds = num_vector_embeds 160 | self.num_latent_pixels = self.height * self.width 161 | 162 | self.latent_image_embedding = ImagePositionalEmbeddings( 163 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width 164 | ) 165 | elif self.is_input_patches: 166 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" 167 | 168 | self.height = sample_size 169 | self.width = sample_size 170 | 171 | self.patch_size = patch_size 172 | interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 173 | interpolation_scale = max(interpolation_scale, 1) 174 | self.pos_embed = PatchEmbed( 175 | height=sample_size, 176 | width=sample_size, 177 | patch_size=patch_size, 178 | in_channels=in_channels, 179 | embed_dim=inner_dim, 180 | interpolation_scale=interpolation_scale, 181 | ) 182 | 183 | # 3. Define transformers blocks 184 | self.transformer_blocks = nn.ModuleList( 185 | [ 186 | BasicTransformerBlock( 187 | inner_dim, 188 | num_attention_heads, 189 | attention_head_dim, 190 | dropout=dropout, 191 | cross_attention_dim=cross_attention_dim, 192 | activation_fn=activation_fn, 193 | num_embeds_ada_norm=num_embeds_ada_norm, 194 | attention_bias=attention_bias, 195 | only_cross_attention=only_cross_attention, 196 | double_self_attention=double_self_attention, 197 | upcast_attention=upcast_attention, 198 | norm_type=norm_type, 199 | norm_elementwise_affine=norm_elementwise_affine, 200 | norm_eps=norm_eps, 201 | attention_type=attention_type, 202 | ) 203 | for d in range(num_layers) 204 | ] 205 | ) 206 | 207 | # 4. 
Define output layers 208 | self.out_channels = in_channels if out_channels is None else out_channels 209 | if self.is_input_continuous: 210 | # TODO: should use out_channels for continuous projections 211 | if use_linear_projection: 212 | self.proj_out = linear_cls(inner_dim, in_channels) 213 | else: 214 | self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 215 | elif self.is_input_vectorized: 216 | self.norm_out = nn.LayerNorm(inner_dim) 217 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) 218 | elif self.is_input_patches and norm_type != "ada_norm_single": 219 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 220 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) 221 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 222 | elif self.is_input_patches and norm_type == "ada_norm_single": 223 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 224 | self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) 225 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 226 | 227 | # 5. PixArt-Alpha blocks. 228 | self.adaln_single = None 229 | self.use_additional_conditions = False 230 | if norm_type == "ada_norm_single": 231 | self.use_additional_conditions = self.config.sample_size == 128 232 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use 233 | # additional conditions until we find better name 234 | self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) 235 | 236 | self.caption_projection = None 237 | if caption_channels is not None: 238 | self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) 239 | 240 | self.gradient_checkpointing = False 241 | 242 | def _set_gradient_checkpointing(self, module, value=False): 243 | if hasattr(module, "gradient_checkpointing"): 244 | module.gradient_checkpointing = value 245 | 246 | def forward( 247 | self, 248 | hidden_states: torch.Tensor, 249 | encoder_hidden_states: Optional[torch.Tensor] = None, 250 | timestep: Optional[torch.LongTensor] = None, 251 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 252 | class_labels: Optional[torch.LongTensor] = None, 253 | cross_attention_kwargs: Dict[str, Any] = None, 254 | attention_mask: Optional[torch.Tensor] = None, 255 | encoder_attention_mask: Optional[torch.Tensor] = None, 256 | return_dict: bool = True, 257 | ): 258 | """ 259 | The [`Transformer2DModel`] forward method. 260 | 261 | Args: 262 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 263 | Input `hidden_states`. 264 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 265 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 266 | self-attention. 267 | timestep ( `torch.LongTensor`, *optional*): 268 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 269 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 270 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 271 | `AdaLayerZeroNorm`. 
272 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*): 273 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 274 | `self.processor` in 275 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 276 | attention_mask ( `torch.Tensor`, *optional*): 277 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 278 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 279 | negative values to the attention scores corresponding to "discard" tokens. 280 | encoder_attention_mask ( `torch.Tensor`, *optional*): 281 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 282 | 283 | * Mask `(batch, sequence_length)` True = keep, False = discard. 284 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 285 | 286 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 287 | above. This bias will be added to the cross-attention scores. 288 | return_dict (`bool`, *optional*, defaults to `True`): 289 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 290 | tuple. 291 | 292 | Returns: 293 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 294 | `tuple` where the first element is the sample tensor. 295 | """ 296 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 297 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 298 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 299 | # expects mask of shape: 300 | # [batch, key_tokens] 301 | # adds singleton query_tokens dimension: 302 | # [batch, 1, key_tokens] 303 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 304 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 305 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 306 | if attention_mask is not None and attention_mask.ndim == 2: 307 | # assume that mask is expressed as: 308 | # (1 = keep, 0 = discard) 309 | # convert mask into a bias that can be added to attention scores: 310 | # (keep = +0, discard = -10000.0) 311 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 312 | attention_mask = attention_mask.unsqueeze(1) 313 | 314 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 315 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 316 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 317 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 318 | 319 | # Retrieve lora scale. 320 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 321 | 322 | # 1. 
Input 323 | if self.is_input_continuous: 324 | batch, _, height, width = hidden_states.shape 325 | # print("1111hidden_states.shape",hidden_states.shape) 326 | residual = hidden_states 327 | 328 | hidden_states = self.norm(hidden_states) 329 | if not self.use_linear_projection: 330 | hidden_states = ( 331 | self.proj_in(hidden_states, scale=lora_scale) 332 | if not USE_PEFT_BACKEND 333 | else self.proj_in(hidden_states) 334 | ) 335 | inner_dim = hidden_states.shape[1] 336 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 337 | else: 338 | inner_dim = hidden_states.shape[1] 339 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 340 | hidden_states = ( 341 | self.proj_in(hidden_states, scale=lora_scale) 342 | if not USE_PEFT_BACKEND 343 | else self.proj_in(hidden_states) 344 | ) 345 | 346 | elif self.is_input_vectorized: 347 | hidden_states = self.latent_image_embedding(hidden_states) 348 | elif self.is_input_patches: 349 | height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size 350 | hidden_states = self.pos_embed(hidden_states) 351 | 352 | if self.adaln_single is not None: 353 | if self.use_additional_conditions and added_cond_kwargs is None: 354 | raise ValueError( 355 | "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." 356 | ) 357 | batch_size = hidden_states.shape[0] 358 | timestep, embedded_timestep = self.adaln_single( 359 | timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype 360 | ) 361 | 362 | # 2. Blocks 363 | if self.caption_projection is not None: 364 | batch_size = hidden_states.shape[0] 365 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 366 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 367 | # print("1:encoder_hidden_states.shape",encoder_hidden_states.shape) 368 | 369 | garment_features = [] 370 | for block in self.transformer_blocks: 371 | if self.training and self.gradient_checkpointing: 372 | 373 | def create_custom_forward(module, return_dict=None): 374 | def custom_forward(*inputs): 375 | if return_dict is not None: 376 | return module(*inputs, return_dict=return_dict) 377 | else: 378 | return module(*inputs) 379 | 380 | return custom_forward 381 | 382 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 383 | hidden_states,out_garment_feat = torch.utils.checkpoint.checkpoint( 384 | create_custom_forward(block), 385 | hidden_states, 386 | attention_mask, 387 | encoder_hidden_states, 388 | encoder_attention_mask, 389 | timestep, 390 | cross_attention_kwargs, 391 | class_labels, 392 | **ckpt_kwargs, 393 | ) 394 | else: 395 | #print("transformer.shape",encoder_hidden_states.shape) 396 | hidden_states,out_garment_feat = block( 397 | hidden_states, 398 | attention_mask=attention_mask, 399 | encoder_hidden_states=encoder_hidden_states, 400 | encoder_attention_mask=encoder_attention_mask, 401 | timestep=timestep, 402 | cross_attention_kwargs=cross_attention_kwargs, 403 | class_labels=class_labels, 404 | ) 405 | garment_features += out_garment_feat 406 | # 3. 
Output 407 | if self.is_input_continuous: 408 | if not self.use_linear_projection: 409 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 410 | hidden_states = ( 411 | self.proj_out(hidden_states, scale=lora_scale) 412 | if not USE_PEFT_BACKEND 413 | else self.proj_out(hidden_states) 414 | ) 415 | else: 416 | hidden_states = ( 417 | self.proj_out(hidden_states, scale=lora_scale) 418 | if not USE_PEFT_BACKEND 419 | else self.proj_out(hidden_states) 420 | ) 421 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 422 | 423 | output = hidden_states + residual 424 | elif self.is_input_vectorized: 425 | hidden_states = self.norm_out(hidden_states) 426 | logits = self.out(hidden_states) 427 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) 428 | logits = logits.permute(0, 2, 1) 429 | 430 | # log(p(x_0)) 431 | output = F.log_softmax(logits.double(), dim=1).float() 432 | 433 | if self.is_input_patches: 434 | if self.config.norm_type != "ada_norm_single": 435 | conditioning = self.transformer_blocks[0].norm1.emb( 436 | timestep, class_labels, hidden_dtype=hidden_states.dtype 437 | ) 438 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) 439 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] 440 | hidden_states = self.proj_out_2(hidden_states) 441 | elif self.config.norm_type == "ada_norm_single": 442 | shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) 443 | hidden_states = self.norm_out(hidden_states) 444 | # Modulation 445 | hidden_states = hidden_states * (1 + scale) + shift 446 | hidden_states = self.proj_out(hidden_states) 447 | hidden_states = hidden_states.squeeze(1) 448 | 449 | # unpatchify 450 | if self.adaln_single is None: 451 | height = width = int(hidden_states.shape[1] ** 0.5) 452 | hidden_states = hidden_states.reshape( 453 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 454 | ) 455 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 456 | output = hidden_states.reshape( 457 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 458 | ) 459 | 460 | if not return_dict: 461 | return (output,) ,garment_features 462 | 463 | return Transformer2DModelOutput(sample=output),garment_features 464 | -------------------------------------------------------------------------------- /src/transformerhacked_tryon.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | from dataclasses import dataclass 15 | from typing import Any, Dict, Optional 16 | 17 | import torch 18 | import torch.nn.functional as F 19 | from torch import nn 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.models.embeddings import ImagePositionalEmbeddings 23 | from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version 24 | from src.attentionhacked_tryon import BasicTransformerBlock 25 | from diffusers.models.embeddings import PatchEmbed, PixArtAlphaTextProjection 26 | from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear 27 | from diffusers.models.modeling_utils import ModelMixin 28 | from diffusers.models.normalization import AdaLayerNormSingle 29 | 30 | 31 | @dataclass 32 | class Transformer2DModelOutput(BaseOutput): 33 | """ 34 | The output of [`Transformer2DModel`]. 35 | 36 | Args: 37 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): 38 | The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability 39 | distributions for the unnoised latent pixels. 40 | """ 41 | 42 | sample: torch.FloatTensor 43 | 44 | 45 | class Transformer2DModel(ModelMixin, ConfigMixin): 46 | """ 47 | A 2D Transformer model for image-like data. 48 | 49 | Parameters: 50 | num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. 51 | attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. 52 | in_channels (`int`, *optional*): 53 | The number of channels in the input and output (specify if the input is **continuous**). 54 | num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. 55 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 56 | cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. 57 | sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). 58 | This is fixed during training since it is used to learn a number of position embeddings. 59 | num_vector_embeds (`int`, *optional*): 60 | The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). 61 | Includes the class for the masked latent pixel. 62 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. 63 | num_embeds_ada_norm ( `int`, *optional*): 64 | The number of diffusion steps used during training. Pass if at least one of the norm_layers is 65 | `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are 66 | added to the hidden states. 67 | 68 | During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. 69 | attention_bias (`bool`, *optional*): 70 | Configure if the `TransformerBlocks` attention should contain a bias parameter. 
71 | """ 72 | 73 | _supports_gradient_checkpointing = True 74 | 75 | @register_to_config 76 | def __init__( 77 | self, 78 | num_attention_heads: int = 16, 79 | attention_head_dim: int = 88, 80 | in_channels: Optional[int] = None, 81 | out_channels: Optional[int] = None, 82 | num_layers: int = 1, 83 | dropout: float = 0.0, 84 | norm_num_groups: int = 32, 85 | cross_attention_dim: Optional[int] = None, 86 | attention_bias: bool = False, 87 | sample_size: Optional[int] = None, 88 | num_vector_embeds: Optional[int] = None, 89 | patch_size: Optional[int] = None, 90 | activation_fn: str = "geglu", 91 | num_embeds_ada_norm: Optional[int] = None, 92 | use_linear_projection: bool = False, 93 | only_cross_attention: bool = False, 94 | double_self_attention: bool = False, 95 | upcast_attention: bool = False, 96 | norm_type: str = "layer_norm", 97 | norm_elementwise_affine: bool = True, 98 | norm_eps: float = 1e-5, 99 | attention_type: str = "default", 100 | caption_channels: int = None, 101 | ): 102 | super().__init__() 103 | self.use_linear_projection = use_linear_projection 104 | self.num_attention_heads = num_attention_heads 105 | self.attention_head_dim = attention_head_dim 106 | inner_dim = num_attention_heads * attention_head_dim 107 | 108 | conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv 109 | linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear 110 | 111 | # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` 112 | # Define whether input is continuous or discrete depending on configuration 113 | self.is_input_continuous = (in_channels is not None) and (patch_size is None) 114 | self.is_input_vectorized = num_vector_embeds is not None 115 | self.is_input_patches = in_channels is not None and patch_size is not None 116 | 117 | if norm_type == "layer_norm" and num_embeds_ada_norm is not None: 118 | deprecation_message = ( 119 | f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" 120 | " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." 121 | " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" 122 | " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" 123 | " would be very nice if you could open a Pull request for the `transformer/config.json` file" 124 | ) 125 | deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) 126 | norm_type = "ada_norm" 127 | 128 | if self.is_input_continuous and self.is_input_vectorized: 129 | raise ValueError( 130 | f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" 131 | " sure that either `in_channels` or `num_vector_embeds` is None." 132 | ) 133 | elif self.is_input_vectorized and self.is_input_patches: 134 | raise ValueError( 135 | f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" 136 | " sure that either `num_vector_embeds` or `num_patches` is None." 137 | ) 138 | elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: 139 | raise ValueError( 140 | f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" 141 | f" {patch_size}. 
Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." 142 | ) 143 | 144 | # 2. Define input layers 145 | if self.is_input_continuous: 146 | self.in_channels = in_channels 147 | 148 | self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) 149 | if use_linear_projection: 150 | self.proj_in = linear_cls(in_channels, inner_dim) 151 | else: 152 | self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 153 | elif self.is_input_vectorized: 154 | assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" 155 | assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" 156 | 157 | self.height = sample_size 158 | self.width = sample_size 159 | self.num_vector_embeds = num_vector_embeds 160 | self.num_latent_pixels = self.height * self.width 161 | 162 | self.latent_image_embedding = ImagePositionalEmbeddings( 163 | num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width 164 | ) 165 | elif self.is_input_patches: 166 | assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" 167 | 168 | self.height = sample_size 169 | self.width = sample_size 170 | 171 | self.patch_size = patch_size 172 | interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1 173 | interpolation_scale = max(interpolation_scale, 1) 174 | self.pos_embed = PatchEmbed( 175 | height=sample_size, 176 | width=sample_size, 177 | patch_size=patch_size, 178 | in_channels=in_channels, 179 | embed_dim=inner_dim, 180 | interpolation_scale=interpolation_scale, 181 | ) 182 | 183 | # 3. Define transformers blocks 184 | self.transformer_blocks = nn.ModuleList( 185 | [ 186 | BasicTransformerBlock( 187 | inner_dim, 188 | num_attention_heads, 189 | attention_head_dim, 190 | dropout=dropout, 191 | cross_attention_dim=cross_attention_dim, 192 | activation_fn=activation_fn, 193 | num_embeds_ada_norm=num_embeds_ada_norm, 194 | attention_bias=attention_bias, 195 | only_cross_attention=only_cross_attention, 196 | double_self_attention=double_self_attention, 197 | upcast_attention=upcast_attention, 198 | norm_type=norm_type, 199 | norm_elementwise_affine=norm_elementwise_affine, 200 | norm_eps=norm_eps, 201 | attention_type=attention_type, 202 | ) 203 | for d in range(num_layers) 204 | ] 205 | ) 206 | 207 | # 4. 
Define output layers 208 | self.out_channels = in_channels if out_channels is None else out_channels 209 | if self.is_input_continuous: 210 | # TODO: should use out_channels for continuous projections 211 | if use_linear_projection: 212 | self.proj_out = linear_cls(inner_dim, in_channels) 213 | else: 214 | self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) 215 | elif self.is_input_vectorized: 216 | self.norm_out = nn.LayerNorm(inner_dim) 217 | self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) 218 | elif self.is_input_patches and norm_type != "ada_norm_single": 219 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 220 | self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) 221 | self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 222 | elif self.is_input_patches and norm_type == "ada_norm_single": 223 | self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) 224 | self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) 225 | self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) 226 | 227 | # 5. PixArt-Alpha blocks. 228 | self.adaln_single = None 229 | self.use_additional_conditions = False 230 | if norm_type == "ada_norm_single": 231 | self.use_additional_conditions = self.config.sample_size == 128 232 | # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use 233 | # additional conditions until we find better name 234 | self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions) 235 | 236 | self.caption_projection = None 237 | if caption_channels is not None: 238 | self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim) 239 | 240 | self.gradient_checkpointing = False 241 | 242 | def _set_gradient_checkpointing(self, module, value=False): 243 | if hasattr(module, "gradient_checkpointing"): 244 | module.gradient_checkpointing = value 245 | 246 | def forward( 247 | self, 248 | hidden_states: torch.Tensor, 249 | encoder_hidden_states: Optional[torch.Tensor] = None, 250 | timestep: Optional[torch.LongTensor] = None, 251 | added_cond_kwargs: Dict[str, torch.Tensor] = None, 252 | class_labels: Optional[torch.LongTensor] = None, 253 | cross_attention_kwargs: Dict[str, Any] = None, 254 | attention_mask: Optional[torch.Tensor] = None, 255 | encoder_attention_mask: Optional[torch.Tensor] = None, 256 | garment_features=None, 257 | curr_garment_feat_idx=0, 258 | return_dict: bool = True, 259 | ): 260 | """ 261 | The [`Transformer2DModel`] forward method. 262 | 263 | Args: 264 | hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): 265 | Input `hidden_states`. 266 | encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): 267 | Conditional embeddings for cross attention layer. If not given, cross-attention defaults to 268 | self-attention. 269 | timestep ( `torch.LongTensor`, *optional*): 270 | Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. 271 | class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): 272 | Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in 273 | `AdaLayerZeroNorm`. 
274 | cross_attention_kwargs ( `Dict[str, Any]`, *optional*): 275 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 276 | `self.processor` in 277 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 278 | attention_mask ( `torch.Tensor`, *optional*): 279 | An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask 280 | is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large 281 | negative values to the attention scores corresponding to "discard" tokens. 282 | encoder_attention_mask ( `torch.Tensor`, *optional*): 283 | Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: 284 | 285 | * Mask `(batch, sequence_length)` True = keep, False = discard. 286 | * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. 287 | 288 | If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format 289 | above. This bias will be added to the cross-attention scores. 290 | return_dict (`bool`, *optional*, defaults to `True`): 291 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 292 | tuple. 293 | 294 | Returns: 295 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 296 | `tuple` where the first element is the sample tensor. 297 | """ 298 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. 299 | # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. 300 | # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. 301 | # expects mask of shape: 302 | # [batch, key_tokens] 303 | # adds singleton query_tokens dimension: 304 | # [batch, 1, key_tokens] 305 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 306 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 307 | # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) 308 | if attention_mask is not None and attention_mask.ndim == 2: 309 | # assume that mask is expressed as: 310 | # (1 = keep, 0 = discard) 311 | # convert mask into a bias that can be added to attention scores: 312 | # (keep = +0, discard = -10000.0) 313 | attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 314 | attention_mask = attention_mask.unsqueeze(1) 315 | 316 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 317 | if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: 318 | encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 319 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 320 | 321 | # Retrieve lora scale. 322 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 323 | 324 | # 1. 
Input 325 | if self.is_input_continuous: 326 | batch, _, height, width = hidden_states.shape 327 | residual = hidden_states 328 | 329 | hidden_states = self.norm(hidden_states) 330 | if not self.use_linear_projection: 331 | hidden_states = ( 332 | self.proj_in(hidden_states, scale=lora_scale) 333 | if not USE_PEFT_BACKEND 334 | else self.proj_in(hidden_states) 335 | ) 336 | inner_dim = hidden_states.shape[1] 337 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 338 | else: 339 | inner_dim = hidden_states.shape[1] 340 | hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) 341 | hidden_states = ( 342 | self.proj_in(hidden_states, scale=lora_scale) 343 | if not USE_PEFT_BACKEND 344 | else self.proj_in(hidden_states) 345 | ) 346 | 347 | elif self.is_input_vectorized: 348 | hidden_states = self.latent_image_embedding(hidden_states) 349 | elif self.is_input_patches: 350 | height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size 351 | hidden_states = self.pos_embed(hidden_states) 352 | 353 | if self.adaln_single is not None: 354 | if self.use_additional_conditions and added_cond_kwargs is None: 355 | raise ValueError( 356 | "`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`." 357 | ) 358 | batch_size = hidden_states.shape[0] 359 | timestep, embedded_timestep = self.adaln_single( 360 | timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype 361 | ) 362 | 363 | # 2. Blocks 364 | if self.caption_projection is not None: 365 | batch_size = hidden_states.shape[0] 366 | encoder_hidden_states = self.caption_projection(encoder_hidden_states) 367 | encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1]) 368 | 369 | 370 | for block in self.transformer_blocks: 371 | if self.training and self.gradient_checkpointing: 372 | 373 | def create_custom_forward(module, return_dict=None): 374 | def custom_forward(*inputs): 375 | if return_dict is not None: 376 | return module(*inputs, return_dict=return_dict) 377 | else: 378 | return module(*inputs) 379 | 380 | return custom_forward 381 | 382 | ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} 383 | hidden_states,curr_garment_feat_idx = torch.utils.checkpoint.checkpoint( 384 | create_custom_forward(block), 385 | hidden_states, 386 | attention_mask, 387 | encoder_hidden_states, 388 | encoder_attention_mask, 389 | timestep, 390 | cross_attention_kwargs, 391 | class_labels, 392 | garment_features, 393 | curr_garment_feat_idx, 394 | **ckpt_kwargs, 395 | ) 396 | else: 397 | hidden_states,curr_garment_feat_idx = block( 398 | hidden_states, 399 | attention_mask=attention_mask, 400 | encoder_hidden_states=encoder_hidden_states, 401 | encoder_attention_mask=encoder_attention_mask, 402 | timestep=timestep, 403 | cross_attention_kwargs=cross_attention_kwargs, 404 | class_labels=class_labels, 405 | garment_features=garment_features, 406 | curr_garment_feat_idx=curr_garment_feat_idx, 407 | ) 408 | 409 | 410 | # 3. 
Output 411 | if self.is_input_continuous: 412 | if not self.use_linear_projection: 413 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 414 | hidden_states = ( 415 | self.proj_out(hidden_states, scale=lora_scale) 416 | if not USE_PEFT_BACKEND 417 | else self.proj_out(hidden_states) 418 | ) 419 | else: 420 | hidden_states = ( 421 | self.proj_out(hidden_states, scale=lora_scale) 422 | if not USE_PEFT_BACKEND 423 | else self.proj_out(hidden_states) 424 | ) 425 | hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() 426 | 427 | output = hidden_states + residual 428 | elif self.is_input_vectorized: 429 | hidden_states = self.norm_out(hidden_states) 430 | logits = self.out(hidden_states) 431 | # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) 432 | logits = logits.permute(0, 2, 1) 433 | 434 | # log(p(x_0)) 435 | output = F.log_softmax(logits.double(), dim=1).float() 436 | 437 | if self.is_input_patches: 438 | if self.config.norm_type != "ada_norm_single": 439 | conditioning = self.transformer_blocks[0].norm1.emb( 440 | timestep, class_labels, hidden_dtype=hidden_states.dtype 441 | ) 442 | shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) 443 | hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] 444 | hidden_states = self.proj_out_2(hidden_states) 445 | elif self.config.norm_type == "ada_norm_single": 446 | shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1) 447 | hidden_states = self.norm_out(hidden_states) 448 | # Modulation 449 | hidden_states = hidden_states * (1 + scale) + shift 450 | hidden_states = self.proj_out(hidden_states) 451 | hidden_states = hidden_states.squeeze(1) 452 | 453 | # unpatchify 454 | if self.adaln_single is None: 455 | height = width = int(hidden_states.shape[1] ** 0.5) 456 | hidden_states = hidden_states.reshape( 457 | shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) 458 | ) 459 | hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) 460 | output = hidden_states.reshape( 461 | shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) 462 | ) 463 | 464 | if not return_dict: 465 | return (output,),curr_garment_feat_idx 466 | 467 | return Transformer2DModelOutput(sample=output),curr_garment_feat_idx 468 | -------------------------------------------------------------------------------- /src/attentionhacked_tryon.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
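# NOTE: modified copy of diffusers' attention module for the try-on branch. The main change is in
# BasicTransformerBlock.forward(), which takes `garment_features` and `curr_garment_feat_idx`:
# before self-attention it concatenates the current garment feature onto the normalized hidden
# states along the sequence dimension, runs attn1 over the joined sequence, keeps only the
# original query tokens of the result, and returns the incremented index together with the
# hidden states. The core of the change, as implemented in forward() below, is roughly:
#
#     joined = torch.cat([norm_hidden_states, garment_features[curr_garment_feat_idx]], dim=1)
#     curr_garment_feat_idx += 1
#     attn_output = self.attn1(joined, ...)
#     hidden_states = attn_output[:, :hidden_states.shape[-2], :] + hidden_states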
14 | from typing import Any, Dict, Optional 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | from diffusers.utils import USE_PEFT_BACKEND 21 | from diffusers.utils.torch_utils import maybe_allow_in_graph 22 | from diffusers.models.activations import GEGLU, GELU, ApproximateGELU 23 | from diffusers.models.attention_processor import Attention 24 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding 25 | from diffusers.models.lora import LoRACompatibleLinear 26 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm 27 | 28 | 29 | def _chunked_feed_forward( 30 | ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None 31 | ): 32 | # "feed_forward_chunk_size" can be used to save memory 33 | if hidden_states.shape[chunk_dim] % chunk_size != 0: 34 | raise ValueError( 35 | f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 36 | ) 37 | 38 | num_chunks = hidden_states.shape[chunk_dim] // chunk_size 39 | if lora_scale is None: 40 | ff_output = torch.cat( 41 | [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 42 | dim=chunk_dim, 43 | ) 44 | else: 45 | # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete 46 | ff_output = torch.cat( 47 | [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 48 | dim=chunk_dim, 49 | ) 50 | 51 | return ff_output 52 | 53 | 54 | @maybe_allow_in_graph 55 | class GatedSelfAttentionDense(nn.Module): 56 | r""" 57 | A gated self-attention dense layer that combines visual features and object features. 58 | 59 | Parameters: 60 | query_dim (`int`): The number of channels in the query. 61 | context_dim (`int`): The number of channels in the context. 62 | n_heads (`int`): The number of heads to use for attention. 63 | d_head (`int`): The number of channels in each head. 64 | """ 65 | 66 | def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): 67 | super().__init__() 68 | 69 | # we need a linear projection since we need cat visual feature and obj feature 70 | self.linear = nn.Linear(context_dim, query_dim) 71 | 72 | self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 73 | self.ff = FeedForward(query_dim, activation_fn="geglu") 74 | 75 | self.norm1 = nn.LayerNorm(query_dim) 76 | self.norm2 = nn.LayerNorm(query_dim) 77 | 78 | self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) 79 | self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) 80 | 81 | self.enabled = True 82 | 83 | def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: 84 | if not self.enabled: 85 | return x 86 | 87 | n_visual = x.shape[1] 88 | objs = self.linear(objs) 89 | 90 | x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] 91 | x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) 92 | 93 | return x 94 | 95 | 96 | @maybe_allow_in_graph 97 | class BasicTransformerBlock(nn.Module): 98 | r""" 99 | A basic Transformer block. 100 | 101 | Parameters: 102 | dim (`int`): The number of channels in the input and output. 103 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 
104 | attention_head_dim (`int`): The number of channels in each head. 105 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 106 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 107 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 108 | num_embeds_ada_norm (: 109 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 110 | attention_bias (: 111 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 112 | only_cross_attention (`bool`, *optional*): 113 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 114 | double_self_attention (`bool`, *optional*): 115 | Whether to use two self-attention layers. In this case no cross attention layers are used. 116 | upcast_attention (`bool`, *optional*): 117 | Whether to upcast the attention computation to float32. This is useful for mixed precision training. 118 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`): 119 | Whether to use learnable elementwise affine parameters for normalization. 120 | norm_type (`str`, *optional*, defaults to `"layer_norm"`): 121 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. 122 | final_dropout (`bool` *optional*, defaults to False): 123 | Whether to apply a final dropout after the last feed-forward layer. 124 | attention_type (`str`, *optional*, defaults to `"default"`): 125 | The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. 126 | positional_embeddings (`str`, *optional*, defaults to `None`): 127 | The type of positional embeddings to apply to. 128 | num_positional_embeddings (`int`, *optional*, defaults to `None`): 129 | The maximum number of positional embeddings to apply. 
130 | """ 131 | 132 | def __init__( 133 | self, 134 | dim: int, 135 | num_attention_heads: int, 136 | attention_head_dim: int, 137 | dropout=0.0, 138 | cross_attention_dim: Optional[int] = None, 139 | activation_fn: str = "geglu", 140 | num_embeds_ada_norm: Optional[int] = None, 141 | attention_bias: bool = False, 142 | only_cross_attention: bool = False, 143 | double_self_attention: bool = False, 144 | upcast_attention: bool = False, 145 | norm_elementwise_affine: bool = True, 146 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' 147 | norm_eps: float = 1e-5, 148 | final_dropout: bool = False, 149 | attention_type: str = "default", 150 | positional_embeddings: Optional[str] = None, 151 | num_positional_embeddings: Optional[int] = None, 152 | ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, 153 | ada_norm_bias: Optional[int] = None, 154 | ff_inner_dim: Optional[int] = None, 155 | ff_bias: bool = True, 156 | attention_out_bias: bool = True, 157 | ): 158 | super().__init__() 159 | self.only_cross_attention = only_cross_attention 160 | 161 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 162 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 163 | self.use_ada_layer_norm_single = norm_type == "ada_norm_single" 164 | self.use_layer_norm = norm_type == "layer_norm" 165 | self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" 166 | 167 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 168 | raise ValueError( 169 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 170 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 171 | ) 172 | 173 | if positional_embeddings and (num_positional_embeddings is None): 174 | raise ValueError( 175 | "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 176 | ) 177 | 178 | if positional_embeddings == "sinusoidal": 179 | self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) 180 | else: 181 | self.pos_embed = None 182 | 183 | # Define 3 blocks. Each block has its own normalization layer. 184 | # 1. Self-Attn 185 | if self.use_ada_layer_norm: 186 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 187 | elif self.use_ada_layer_norm_zero: 188 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 189 | elif self.use_ada_layer_norm_continuous: 190 | self.norm1 = AdaLayerNormContinuous( 191 | dim, 192 | ada_norm_continous_conditioning_embedding_dim, 193 | norm_elementwise_affine, 194 | norm_eps, 195 | ada_norm_bias, 196 | "rms_norm", 197 | ) 198 | else: 199 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) 200 | 201 | self.attn1 = Attention( 202 | query_dim=dim, 203 | heads=num_attention_heads, 204 | dim_head=attention_head_dim, 205 | dropout=dropout, 206 | bias=attention_bias, 207 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 208 | upcast_attention=upcast_attention, 209 | out_bias=attention_out_bias, 210 | ) 211 | 212 | # 2. Cross-Attn 213 | if cross_attention_dim is not None or double_self_attention: 214 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 215 | # I.e. 
the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 216 | # the second cross attention block. 217 | if self.use_ada_layer_norm: 218 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) 219 | elif self.use_ada_layer_norm_continuous: 220 | self.norm2 = AdaLayerNormContinuous( 221 | dim, 222 | ada_norm_continous_conditioning_embedding_dim, 223 | norm_elementwise_affine, 224 | norm_eps, 225 | ada_norm_bias, 226 | "rms_norm", 227 | ) 228 | else: 229 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 230 | 231 | self.attn2 = Attention( 232 | query_dim=dim, 233 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 234 | heads=num_attention_heads, 235 | dim_head=attention_head_dim, 236 | dropout=dropout, 237 | bias=attention_bias, 238 | upcast_attention=upcast_attention, 239 | out_bias=attention_out_bias, 240 | ) # is self-attn if encoder_hidden_states is none 241 | else: 242 | self.norm2 = None 243 | self.attn2 = None 244 | 245 | # 3. Feed-forward 246 | if self.use_ada_layer_norm_continuous: 247 | self.norm3 = AdaLayerNormContinuous( 248 | dim, 249 | ada_norm_continous_conditioning_embedding_dim, 250 | norm_elementwise_affine, 251 | norm_eps, 252 | ada_norm_bias, 253 | "layer_norm", 254 | ) 255 | elif not self.use_ada_layer_norm_single: 256 | self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 257 | 258 | self.ff = FeedForward( 259 | dim, 260 | dropout=dropout, 261 | activation_fn=activation_fn, 262 | final_dropout=final_dropout, 263 | inner_dim=ff_inner_dim, 264 | bias=ff_bias, 265 | ) 266 | 267 | # 4. Fuser 268 | if attention_type == "gated" or attention_type == "gated-text-image": 269 | self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) 270 | 271 | # 5. Scale-shift for PixArt-Alpha. 272 | if self.use_ada_layer_norm_single: 273 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) 274 | 275 | # let chunk size default to None 276 | self._chunk_size = None 277 | self._chunk_dim = 0 278 | 279 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): 280 | # Sets chunk feed-forward 281 | self._chunk_size = chunk_size 282 | self._chunk_dim = dim 283 | 284 | def forward( 285 | self, 286 | hidden_states: torch.FloatTensor, 287 | attention_mask: Optional[torch.FloatTensor] = None, 288 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 289 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 290 | timestep: Optional[torch.LongTensor] = None, 291 | cross_attention_kwargs: Dict[str, Any] = None, 292 | class_labels: Optional[torch.LongTensor] = None, 293 | garment_features=None, 294 | curr_garment_feat_idx=0, 295 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 296 | ) -> torch.FloatTensor: 297 | # Notice that normalization is always applied before the real computation in the following blocks. 298 | # 0. 
Self-Attention 299 | batch_size = hidden_states.shape[0] 300 | 301 | 302 | 303 | if self.use_ada_layer_norm: 304 | norm_hidden_states = self.norm1(hidden_states, timestep) 305 | elif self.use_ada_layer_norm_zero: 306 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 307 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 308 | ) 309 | elif self.use_layer_norm: 310 | norm_hidden_states = self.norm1(hidden_states) 311 | elif self.use_ada_layer_norm_continuous: 312 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 313 | elif self.use_ada_layer_norm_single: 314 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 315 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 316 | ).chunk(6, dim=1) 317 | norm_hidden_states = self.norm1(hidden_states) 318 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 319 | norm_hidden_states = norm_hidden_states.squeeze(1) 320 | else: 321 | raise ValueError("Incorrect norm used") 322 | 323 | if self.pos_embed is not None: 324 | norm_hidden_states = self.pos_embed(norm_hidden_states) 325 | 326 | # 1. Retrieve lora scale. 327 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 328 | 329 | # 2. Prepare GLIGEN inputs 330 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 331 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 332 | 333 | 334 | modify_norm_hidden_states = torch.cat([norm_hidden_states, garment_features[curr_garment_feat_idx]], dim=1)  # concatenate the stored garment feature tokens along the sequence dimension 335 | curr_garment_feat_idx += 1  # advance to the next stored garment feature 336 | attn_output = self.attn1( 337 | #norm_hidden_states, 338 | modify_norm_hidden_states, 339 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 340 | attention_mask=attention_mask, 341 | **cross_attention_kwargs, 342 | ) 343 | if self.use_ada_layer_norm_zero: 344 | attn_output = gate_msa.unsqueeze(1) * attn_output 345 | elif self.use_ada_layer_norm_single: 346 | attn_output = gate_msa * attn_output 347 | 348 | hidden_states = attn_output[:, :hidden_states.shape[-2], :] + hidden_states  # keep only the original query tokens before the residual add 349 | 350 | 351 | 352 | 353 | if hidden_states.ndim == 4: 354 | hidden_states = hidden_states.squeeze(1) 355 | 356 | # 2.5 GLIGEN Control 357 | if gligen_kwargs is not None: 358 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 359 | 360 | # 3.
Cross-Attention 361 | if self.attn2 is not None: 362 | if self.use_ada_layer_norm: 363 | norm_hidden_states = self.norm2(hidden_states, timestep) 364 | elif self.use_ada_layer_norm_zero or self.use_layer_norm: 365 | norm_hidden_states = self.norm2(hidden_states) 366 | elif self.use_ada_layer_norm_single: 367 | # For PixArt norm2 isn't applied here: 368 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 369 | norm_hidden_states = hidden_states 370 | elif self.use_ada_layer_norm_continuous: 371 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 372 | else: 373 | raise ValueError("Incorrect norm") 374 | 375 | if self.pos_embed is not None and self.use_ada_layer_norm_single is False: 376 | norm_hidden_states = self.pos_embed(norm_hidden_states) 377 | 378 | attn_output = self.attn2( 379 | norm_hidden_states, 380 | encoder_hidden_states=encoder_hidden_states, 381 | attention_mask=encoder_attention_mask, 382 | **cross_attention_kwargs, 383 | ) 384 | hidden_states = attn_output + hidden_states 385 | 386 | # 4. Feed-forward 387 | if self.use_ada_layer_norm_continuous: 388 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 389 | elif not self.use_ada_layer_norm_single: 390 | norm_hidden_states = self.norm3(hidden_states) 391 | 392 | if self.use_ada_layer_norm_zero: 393 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 394 | 395 | if self.use_ada_layer_norm_single: 396 | norm_hidden_states = self.norm2(hidden_states) 397 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 398 | 399 | if self._chunk_size is not None: 400 | # "feed_forward_chunk_size" can be used to save memory 401 | ff_output = _chunked_feed_forward( 402 | self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale 403 | ) 404 | else: 405 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 406 | 407 | if self.use_ada_layer_norm_zero: 408 | ff_output = gate_mlp.unsqueeze(1) * ff_output 409 | elif self.use_ada_layer_norm_single: 410 | ff_output = gate_mlp * ff_output 411 | 412 | hidden_states = ff_output + hidden_states 413 | if hidden_states.ndim == 4: 414 | hidden_states = hidden_states.squeeze(1) 415 | return hidden_states,curr_garment_feat_idx 416 | 417 | 418 | @maybe_allow_in_graph 419 | class TemporalBasicTransformerBlock(nn.Module): 420 | r""" 421 | A basic Transformer block for video like data. 422 | 423 | Parameters: 424 | dim (`int`): The number of channels in the input and output. 425 | time_mix_inner_dim (`int`): The number of channels for temporal attention. 426 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 427 | attention_head_dim (`int`): The number of channels in each head. 428 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 429 | """ 430 | 431 | def __init__( 432 | self, 433 | dim: int, 434 | time_mix_inner_dim: int, 435 | num_attention_heads: int, 436 | attention_head_dim: int, 437 | cross_attention_dim: Optional[int] = None, 438 | ): 439 | super().__init__() 440 | self.is_res = dim == time_mix_inner_dim 441 | 442 | self.norm_in = nn.LayerNorm(dim) 443 | 444 | # Define 3 blocks. Each block has its own normalization layer. 445 | # 1. 
Self-Attn 446 | self.norm_in = nn.LayerNorm(dim) 447 | self.ff_in = FeedForward( 448 | dim, 449 | dim_out=time_mix_inner_dim, 450 | activation_fn="geglu", 451 | ) 452 | 453 | self.norm1 = nn.LayerNorm(time_mix_inner_dim) 454 | self.attn1 = Attention( 455 | query_dim=time_mix_inner_dim, 456 | heads=num_attention_heads, 457 | dim_head=attention_head_dim, 458 | cross_attention_dim=None, 459 | ) 460 | 461 | # 2. Cross-Attn 462 | if cross_attention_dim is not None: 463 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 464 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 465 | # the second cross attention block. 466 | self.norm2 = nn.LayerNorm(time_mix_inner_dim) 467 | self.attn2 = Attention( 468 | query_dim=time_mix_inner_dim, 469 | cross_attention_dim=cross_attention_dim, 470 | heads=num_attention_heads, 471 | dim_head=attention_head_dim, 472 | ) # is self-attn if encoder_hidden_states is none 473 | else: 474 | self.norm2 = None 475 | self.attn2 = None 476 | 477 | # 3. Feed-forward 478 | self.norm3 = nn.LayerNorm(time_mix_inner_dim) 479 | self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") 480 | 481 | # let chunk size default to None 482 | self._chunk_size = None 483 | self._chunk_dim = None 484 | 485 | def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): 486 | # Sets chunk feed-forward 487 | self._chunk_size = chunk_size 488 | # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off 489 | self._chunk_dim = 1 490 | 491 | def forward( 492 | self, 493 | hidden_states: torch.FloatTensor, 494 | num_frames: int, 495 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 496 | ) -> torch.FloatTensor: 497 | # Notice that normalization is always applied before the real computation in the following blocks. 498 | # 0. Self-Attention 499 | batch_size = hidden_states.shape[0] 500 | 501 | batch_frames, seq_length, channels = hidden_states.shape 502 | batch_size = batch_frames // num_frames 503 | 504 | hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) 505 | hidden_states = hidden_states.permute(0, 2, 1, 3) 506 | hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) 507 | 508 | residual = hidden_states 509 | hidden_states = self.norm_in(hidden_states) 510 | 511 | if self._chunk_size is not None: 512 | hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) 513 | else: 514 | hidden_states = self.ff_in(hidden_states) 515 | 516 | if self.is_res: 517 | hidden_states = hidden_states + residual 518 | 519 | norm_hidden_states = self.norm1(hidden_states) 520 | attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) 521 | hidden_states = attn_output + hidden_states 522 | 523 | # 3. Cross-Attention 524 | if self.attn2 is not None: 525 | norm_hidden_states = self.norm2(hidden_states) 526 | attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) 527 | hidden_states = attn_output + hidden_states 528 | 529 | # 4. 
Feed-forward 530 | norm_hidden_states = self.norm3(hidden_states) 531 | 532 | if self._chunk_size is not None: 533 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 534 | else: 535 | ff_output = self.ff(norm_hidden_states) 536 | 537 | if self.is_res: 538 | hidden_states = ff_output + hidden_states 539 | else: 540 | hidden_states = ff_output 541 | 542 | hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) 543 | hidden_states = hidden_states.permute(0, 2, 1, 3) 544 | hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) 545 | 546 | return hidden_states 547 | 548 | 549 | class SkipFFTransformerBlock(nn.Module): 550 | def __init__( 551 | self, 552 | dim: int, 553 | num_attention_heads: int, 554 | attention_head_dim: int, 555 | kv_input_dim: int, 556 | kv_input_dim_proj_use_bias: bool, 557 | dropout=0.0, 558 | cross_attention_dim: Optional[int] = None, 559 | attention_bias: bool = False, 560 | attention_out_bias: bool = True, 561 | ): 562 | super().__init__() 563 | if kv_input_dim != dim: 564 | self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) 565 | else: 566 | self.kv_mapper = None 567 | 568 | self.norm1 = RMSNorm(dim, 1e-06) 569 | 570 | self.attn1 = Attention( 571 | query_dim=dim, 572 | heads=num_attention_heads, 573 | dim_head=attention_head_dim, 574 | dropout=dropout, 575 | bias=attention_bias, 576 | cross_attention_dim=cross_attention_dim, 577 | out_bias=attention_out_bias, 578 | ) 579 | 580 | self.norm2 = RMSNorm(dim, 1e-06) 581 | 582 | self.attn2 = Attention( 583 | query_dim=dim, 584 | cross_attention_dim=cross_attention_dim, 585 | heads=num_attention_heads, 586 | dim_head=attention_head_dim, 587 | dropout=dropout, 588 | bias=attention_bias, 589 | out_bias=attention_out_bias, 590 | ) 591 | 592 | def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): 593 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 594 | 595 | if self.kv_mapper is not None: 596 | encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) 597 | 598 | norm_hidden_states = self.norm1(hidden_states) 599 | 600 | attn_output = self.attn1( 601 | norm_hidden_states, 602 | encoder_hidden_states=encoder_hidden_states, 603 | **cross_attention_kwargs, 604 | ) 605 | 606 | hidden_states = attn_output + hidden_states 607 | 608 | norm_hidden_states = self.norm2(hidden_states) 609 | 610 | attn_output = self.attn2( 611 | norm_hidden_states, 612 | encoder_hidden_states=encoder_hidden_states, 613 | **cross_attention_kwargs, 614 | ) 615 | 616 | hidden_states = attn_output + hidden_states 617 | 618 | return hidden_states 619 | 620 | 621 | class FeedForward(nn.Module): 622 | r""" 623 | A feed-forward layer. 624 | 625 | Parameters: 626 | dim (`int`): The number of channels in the input. 627 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 628 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 629 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 630 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 631 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 632 | bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 
633 | """ 634 | 635 | def __init__( 636 | self, 637 | dim: int, 638 | dim_out: Optional[int] = None, 639 | mult: int = 4, 640 | dropout: float = 0.0, 641 | activation_fn: str = "geglu", 642 | final_dropout: bool = False, 643 | inner_dim=None, 644 | bias: bool = True, 645 | ): 646 | super().__init__() 647 | if inner_dim is None: 648 | inner_dim = int(dim * mult) 649 | dim_out = dim_out if dim_out is not None else dim 650 | linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear 651 | 652 | if activation_fn == "gelu": 653 | act_fn = GELU(dim, inner_dim, bias=bias) 654 | if activation_fn == "gelu-approximate": 655 | act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) 656 | elif activation_fn == "geglu": 657 | act_fn = GEGLU(dim, inner_dim, bias=bias) 658 | elif activation_fn == "geglu-approximate": 659 | act_fn = ApproximateGELU(dim, inner_dim, bias=bias) 660 | 661 | self.net = nn.ModuleList([]) 662 | # project in 663 | self.net.append(act_fn) 664 | # project dropout 665 | self.net.append(nn.Dropout(dropout)) 666 | # project out 667 | self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) 668 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 669 | if final_dropout: 670 | self.net.append(nn.Dropout(dropout)) 671 | 672 | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: 673 | compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) 674 | for module in self.net: 675 | if isinstance(module, compatible_cls): 676 | hidden_states = module(hidden_states, scale) 677 | else: 678 | hidden_states = module(hidden_states) 679 | return hidden_states 680 | -------------------------------------------------------------------------------- /src/attentionhacked_garmnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
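# NOTE: modified copy of diffusers' attention module for the garment (reference) branch.
# BasicTransformerBlock.forward() here takes no extra inputs; instead it records the normalized
# hidden states computed just before self-attention in a `garment_features` list and returns that
# list together with the block output, i.e. roughly:
#
#     garment_features = [norm_hidden_states]
#     ...
#     return hidden_states, garment_features
#
# The Transformer2DModel in src/transformerhacked_garmnet.py returns these collected features
# alongside its sample, and the try-on blocks in src/attentionhacked_tryon.py consume them via
# `curr_garment_feat_idx`.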
14 | from typing import Any, Dict, Optional 15 | 16 | import torch 17 | import torch.nn.functional as F 18 | from torch import nn 19 | 20 | from diffusers.utils import USE_PEFT_BACKEND 21 | from diffusers.utils.torch_utils import maybe_allow_in_graph 22 | from diffusers.models.activations import GEGLU, GELU, ApproximateGELU 23 | from diffusers.models.attention_processor import Attention 24 | from diffusers.models.embeddings import SinusoidalPositionalEmbedding 25 | from diffusers.models.lora import LoRACompatibleLinear 26 | from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm 27 | 28 | 29 | def _chunked_feed_forward( 30 | ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None 31 | ): 32 | # "feed_forward_chunk_size" can be used to save memory 33 | if hidden_states.shape[chunk_dim] % chunk_size != 0: 34 | raise ValueError( 35 | f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`." 36 | ) 37 | 38 | num_chunks = hidden_states.shape[chunk_dim] // chunk_size 39 | if lora_scale is None: 40 | ff_output = torch.cat( 41 | [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 42 | dim=chunk_dim, 43 | ) 44 | else: 45 | # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete 46 | ff_output = torch.cat( 47 | [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)], 48 | dim=chunk_dim, 49 | ) 50 | 51 | return ff_output 52 | 53 | 54 | @maybe_allow_in_graph 55 | class GatedSelfAttentionDense(nn.Module): 56 | r""" 57 | A gated self-attention dense layer that combines visual features and object features. 58 | 59 | Parameters: 60 | query_dim (`int`): The number of channels in the query. 61 | context_dim (`int`): The number of channels in the context. 62 | n_heads (`int`): The number of heads to use for attention. 63 | d_head (`int`): The number of channels in each head. 64 | """ 65 | 66 | def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int): 67 | super().__init__() 68 | 69 | # we need a linear projection since we need cat visual feature and obj feature 70 | self.linear = nn.Linear(context_dim, query_dim) 71 | 72 | self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head) 73 | self.ff = FeedForward(query_dim, activation_fn="geglu") 74 | 75 | self.norm1 = nn.LayerNorm(query_dim) 76 | self.norm2 = nn.LayerNorm(query_dim) 77 | 78 | self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0))) 79 | self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0))) 80 | 81 | self.enabled = True 82 | 83 | def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor: 84 | if not self.enabled: 85 | return x 86 | 87 | n_visual = x.shape[1] 88 | objs = self.linear(objs) 89 | 90 | x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :] 91 | x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x)) 92 | 93 | return x 94 | 95 | 96 | @maybe_allow_in_graph 97 | class BasicTransformerBlock(nn.Module): 98 | r""" 99 | A basic Transformer block. 100 | 101 | Parameters: 102 | dim (`int`): The number of channels in the input and output. 103 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 
104 | attention_head_dim (`int`): The number of channels in each head. 105 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 106 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 107 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 108 | num_embeds_ada_norm (: 109 | obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. 110 | attention_bias (: 111 | obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. 112 | only_cross_attention (`bool`, *optional*): 113 | Whether to use only cross-attention layers. In this case two cross attention layers are used. 114 | double_self_attention (`bool`, *optional*): 115 | Whether to use two self-attention layers. In this case no cross attention layers are used. 116 | upcast_attention (`bool`, *optional*): 117 | Whether to upcast the attention computation to float32. This is useful for mixed precision training. 118 | norm_elementwise_affine (`bool`, *optional*, defaults to `True`): 119 | Whether to use learnable elementwise affine parameters for normalization. 120 | norm_type (`str`, *optional*, defaults to `"layer_norm"`): 121 | The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. 122 | final_dropout (`bool` *optional*, defaults to False): 123 | Whether to apply a final dropout after the last feed-forward layer. 124 | attention_type (`str`, *optional*, defaults to `"default"`): 125 | The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. 126 | positional_embeddings (`str`, *optional*, defaults to `None`): 127 | The type of positional embeddings to apply to. 128 | num_positional_embeddings (`int`, *optional*, defaults to `None`): 129 | The maximum number of positional embeddings to apply. 
130 | """ 131 | 132 | def __init__( 133 | self, 134 | dim: int, 135 | num_attention_heads: int, 136 | attention_head_dim: int, 137 | dropout=0.0, 138 | cross_attention_dim: Optional[int] = None, 139 | activation_fn: str = "geglu", 140 | num_embeds_ada_norm: Optional[int] = None, 141 | attention_bias: bool = False, 142 | only_cross_attention: bool = False, 143 | double_self_attention: bool = False, 144 | upcast_attention: bool = False, 145 | norm_elementwise_affine: bool = True, 146 | norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single' 147 | norm_eps: float = 1e-5, 148 | final_dropout: bool = False, 149 | attention_type: str = "default", 150 | positional_embeddings: Optional[str] = None, 151 | num_positional_embeddings: Optional[int] = None, 152 | ada_norm_continous_conditioning_embedding_dim: Optional[int] = None, 153 | ada_norm_bias: Optional[int] = None, 154 | ff_inner_dim: Optional[int] = None, 155 | ff_bias: bool = True, 156 | attention_out_bias: bool = True, 157 | ): 158 | super().__init__() 159 | self.only_cross_attention = only_cross_attention 160 | 161 | self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero" 162 | self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm" 163 | self.use_ada_layer_norm_single = norm_type == "ada_norm_single" 164 | self.use_layer_norm = norm_type == "layer_norm" 165 | self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous" 166 | 167 | if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None: 168 | raise ValueError( 169 | f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to" 170 | f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}." 171 | ) 172 | 173 | if positional_embeddings and (num_positional_embeddings is None): 174 | raise ValueError( 175 | "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined." 176 | ) 177 | 178 | if positional_embeddings == "sinusoidal": 179 | self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings) 180 | else: 181 | self.pos_embed = None 182 | 183 | # Define 3 blocks. Each block has its own normalization layer. 184 | # 1. Self-Attn 185 | if self.use_ada_layer_norm: 186 | self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) 187 | elif self.use_ada_layer_norm_zero: 188 | self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm) 189 | elif self.use_ada_layer_norm_continuous: 190 | self.norm1 = AdaLayerNormContinuous( 191 | dim, 192 | ada_norm_continous_conditioning_embedding_dim, 193 | norm_elementwise_affine, 194 | norm_eps, 195 | ada_norm_bias, 196 | "rms_norm", 197 | ) 198 | else: 199 | self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps) 200 | 201 | self.attn1 = Attention( 202 | query_dim=dim, 203 | heads=num_attention_heads, 204 | dim_head=attention_head_dim, 205 | dropout=dropout, 206 | bias=attention_bias, 207 | cross_attention_dim=cross_attention_dim if only_cross_attention else None, 208 | upcast_attention=upcast_attention, 209 | out_bias=attention_out_bias, 210 | ) 211 | 212 | # 2. Cross-Attn 213 | if cross_attention_dim is not None or double_self_attention: 214 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 215 | # I.e. 
the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 216 | # the second cross attention block. 217 | if self.use_ada_layer_norm: 218 | self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) 219 | elif self.use_ada_layer_norm_continuous: 220 | self.norm2 = AdaLayerNormContinuous( 221 | dim, 222 | ada_norm_continous_conditioning_embedding_dim, 223 | norm_elementwise_affine, 224 | norm_eps, 225 | ada_norm_bias, 226 | "rms_norm", 227 | ) 228 | else: 229 | self.norm2 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 230 | 231 | self.attn2 = Attention( 232 | query_dim=dim, 233 | cross_attention_dim=cross_attention_dim if not double_self_attention else None, 234 | heads=num_attention_heads, 235 | dim_head=attention_head_dim, 236 | dropout=dropout, 237 | bias=attention_bias, 238 | upcast_attention=upcast_attention, 239 | out_bias=attention_out_bias, 240 | ) # is self-attn if encoder_hidden_states is none 241 | else: 242 | self.norm2 = None 243 | self.attn2 = None 244 | 245 | # 3. Feed-forward 246 | if self.use_ada_layer_norm_continuous: 247 | self.norm3 = AdaLayerNormContinuous( 248 | dim, 249 | ada_norm_continous_conditioning_embedding_dim, 250 | norm_elementwise_affine, 251 | norm_eps, 252 | ada_norm_bias, 253 | "layer_norm", 254 | ) 255 | elif not self.use_ada_layer_norm_single: 256 | self.norm3 = nn.LayerNorm(dim, norm_eps, norm_elementwise_affine) 257 | 258 | self.ff = FeedForward( 259 | dim, 260 | dropout=dropout, 261 | activation_fn=activation_fn, 262 | final_dropout=final_dropout, 263 | inner_dim=ff_inner_dim, 264 | bias=ff_bias, 265 | ) 266 | 267 | # 4. Fuser 268 | if attention_type == "gated" or attention_type == "gated-text-image": 269 | self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim) 270 | 271 | # 5. Scale-shift for PixArt-Alpha. 272 | if self.use_ada_layer_norm_single: 273 | self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5) 274 | 275 | # let chunk size default to None 276 | self._chunk_size = None 277 | self._chunk_dim = 0 278 | 279 | def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): 280 | # Sets chunk feed-forward 281 | self._chunk_size = chunk_size 282 | self._chunk_dim = dim 283 | 284 | def forward( 285 | self, 286 | hidden_states: torch.FloatTensor, 287 | attention_mask: Optional[torch.FloatTensor] = None, 288 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 289 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 290 | timestep: Optional[torch.LongTensor] = None, 291 | cross_attention_kwargs: Dict[str, Any] = None, 292 | class_labels: Optional[torch.LongTensor] = None, 293 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 294 | ) -> torch.FloatTensor: 295 | # Notice that normalization is always applied before the real computation in the following blocks. 296 | # 0. 
Self-Attention 297 | batch_size = hidden_states.shape[0] 298 | #print("batchsize",batch_size) 299 | if self.use_ada_layer_norm: 300 | norm_hidden_states = self.norm1(hidden_states, timestep) 301 | elif self.use_ada_layer_norm_zero: 302 | norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1( 303 | hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype 304 | ) 305 | elif self.use_layer_norm: 306 | norm_hidden_states = self.norm1(hidden_states) 307 | elif self.use_ada_layer_norm_continuous: 308 | norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"]) 309 | elif self.use_ada_layer_norm_single: 310 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( 311 | self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1) 312 | ).chunk(6, dim=1) 313 | norm_hidden_states = self.norm1(hidden_states) 314 | norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa 315 | norm_hidden_states = norm_hidden_states.squeeze(1) 316 | else: 317 | raise ValueError("Incorrect norm used") 318 | 319 | if self.pos_embed is not None: 320 | norm_hidden_states = self.pos_embed(norm_hidden_states) 321 | 322 | garment_features = [] 323 | garment_features.append(norm_hidden_states) 324 | 325 | # 1. Retrieve lora scale. 326 | lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 327 | 328 | # 2. Prepare GLIGEN inputs 329 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 330 | gligen_kwargs = cross_attention_kwargs.pop("gligen", None) 331 | 332 | attn_output = self.attn1( 333 | norm_hidden_states, 334 | encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None, 335 | attention_mask=attention_mask, 336 | **cross_attention_kwargs, 337 | ) 338 | if self.use_ada_layer_norm_zero: 339 | attn_output = gate_msa.unsqueeze(1) * attn_output 340 | elif self.use_ada_layer_norm_single: 341 | attn_output = gate_msa * attn_output 342 | #print("!!!!!!",attn_output.shape) 343 | hidden_states = attn_output + hidden_states 344 | #print("before.shape",hidden_states.shape) 345 | if hidden_states.ndim == 4: 346 | hidden_states = hidden_states.squeeze(1) 347 | 348 | # 2.5 GLIGEN Control 349 | if gligen_kwargs is not None: 350 | hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"]) 351 | 352 | # 3. 
Cross-Attention 353 | if self.attn2 is not None: 354 | if self.use_ada_layer_norm: 355 | norm_hidden_states = self.norm2(hidden_states, timestep) 356 | elif self.use_ada_layer_norm_zero or self.use_layer_norm: 357 | norm_hidden_states = self.norm2(hidden_states) 358 | elif self.use_ada_layer_norm_single: 359 | # For PixArt norm2 isn't applied here: 360 | # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103 361 | norm_hidden_states = hidden_states 362 | elif self.use_ada_layer_norm_continuous: 363 | norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"]) 364 | else: 365 | raise ValueError("Incorrect norm") 366 | 367 | if self.pos_embed is not None and self.use_ada_layer_norm_single is False: 368 | norm_hidden_states = self.pos_embed(norm_hidden_states) 369 | # print("hidden_states.shape",hidden_states.shape) 370 | # print("norm_hidden_states.shape",norm_hidden_states.shape) 371 | # print("encoder_hidden_states.shape",encoder_hidden_states.shape) 372 | attn_output = self.attn2( 373 | norm_hidden_states, 374 | encoder_hidden_states=encoder_hidden_states, 375 | attention_mask=encoder_attention_mask, 376 | **cross_attention_kwargs, 377 | ) 378 | #print("attn_output.shape",attn_output.shape) 379 | hidden_states = attn_output + hidden_states 380 | 381 | # 4. Feed-forward 382 | if self.use_ada_layer_norm_continuous: 383 | norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"]) 384 | elif not self.use_ada_layer_norm_single: 385 | norm_hidden_states = self.norm3(hidden_states) 386 | 387 | if self.use_ada_layer_norm_zero: 388 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None] 389 | 390 | if self.use_ada_layer_norm_single: 391 | norm_hidden_states = self.norm2(hidden_states) 392 | norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp 393 | 394 | if self._chunk_size is not None: 395 | # "feed_forward_chunk_size" can be used to save memory 396 | ff_output = _chunked_feed_forward( 397 | self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale 398 | ) 399 | else: 400 | ff_output = self.ff(norm_hidden_states, scale=lora_scale) 401 | 402 | if self.use_ada_layer_norm_zero: 403 | ff_output = gate_mlp.unsqueeze(1) * ff_output 404 | elif self.use_ada_layer_norm_single: 405 | ff_output = gate_mlp * ff_output 406 | 407 | hidden_states = ff_output + hidden_states 408 | if hidden_states.ndim == 4: 409 | hidden_states = hidden_states.squeeze(1) 410 | 411 | return hidden_states, garment_features 412 | 413 | 414 | @maybe_allow_in_graph 415 | class TemporalBasicTransformerBlock(nn.Module): 416 | r""" 417 | A basic Transformer block for video like data. 418 | 419 | Parameters: 420 | dim (`int`): The number of channels in the input and output. 421 | time_mix_inner_dim (`int`): The number of channels for temporal attention. 422 | num_attention_heads (`int`): The number of heads to use for multi-head attention. 423 | attention_head_dim (`int`): The number of channels in each head. 424 | cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. 
425 | """ 426 | 427 | def __init__( 428 | self, 429 | dim: int, 430 | time_mix_inner_dim: int, 431 | num_attention_heads: int, 432 | attention_head_dim: int, 433 | cross_attention_dim: Optional[int] = None, 434 | ): 435 | super().__init__() 436 | self.is_res = dim == time_mix_inner_dim 437 | 438 | self.norm_in = nn.LayerNorm(dim) 439 | 440 | # Define 3 blocks. Each block has its own normalization layer. 441 | # 1. Self-Attn 442 | self.norm_in = nn.LayerNorm(dim) 443 | self.ff_in = FeedForward( 444 | dim, 445 | dim_out=time_mix_inner_dim, 446 | activation_fn="geglu", 447 | ) 448 | 449 | self.norm1 = nn.LayerNorm(time_mix_inner_dim) 450 | self.attn1 = Attention( 451 | query_dim=time_mix_inner_dim, 452 | heads=num_attention_heads, 453 | dim_head=attention_head_dim, 454 | cross_attention_dim=None, 455 | ) 456 | 457 | # 2. Cross-Attn 458 | if cross_attention_dim is not None: 459 | # We currently only use AdaLayerNormZero for self attention where there will only be one attention block. 460 | # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during 461 | # the second cross attention block. 462 | self.norm2 = nn.LayerNorm(time_mix_inner_dim) 463 | self.attn2 = Attention( 464 | query_dim=time_mix_inner_dim, 465 | cross_attention_dim=cross_attention_dim, 466 | heads=num_attention_heads, 467 | dim_head=attention_head_dim, 468 | ) # is self-attn if encoder_hidden_states is none 469 | else: 470 | self.norm2 = None 471 | self.attn2 = None 472 | 473 | # 3. Feed-forward 474 | self.norm3 = nn.LayerNorm(time_mix_inner_dim) 475 | self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu") 476 | 477 | # let chunk size default to None 478 | self._chunk_size = None 479 | self._chunk_dim = None 480 | 481 | def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs): 482 | # Sets chunk feed-forward 483 | self._chunk_size = chunk_size 484 | # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off 485 | self._chunk_dim = 1 486 | 487 | def forward( 488 | self, 489 | hidden_states: torch.FloatTensor, 490 | num_frames: int, 491 | encoder_hidden_states: Optional[torch.FloatTensor] = None, 492 | ) -> torch.FloatTensor: 493 | # Notice that normalization is always applied before the real computation in the following blocks. 494 | # 0. Self-Attention 495 | batch_size = hidden_states.shape[0] 496 | 497 | batch_frames, seq_length, channels = hidden_states.shape 498 | batch_size = batch_frames // num_frames 499 | 500 | hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels) 501 | hidden_states = hidden_states.permute(0, 2, 1, 3) 502 | hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels) 503 | 504 | residual = hidden_states 505 | hidden_states = self.norm_in(hidden_states) 506 | 507 | if self._chunk_size is not None: 508 | hidden_states = _chunked_feed_forward(self.ff_in, hidden_states, self._chunk_dim, self._chunk_size) 509 | else: 510 | hidden_states = self.ff_in(hidden_states) 511 | 512 | if self.is_res: 513 | hidden_states = hidden_states + residual 514 | 515 | norm_hidden_states = self.norm1(hidden_states) 516 | attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None) 517 | hidden_states = attn_output + hidden_states 518 | 519 | # 3. 
Cross-Attention 520 | if self.attn2 is not None: 521 | norm_hidden_states = self.norm2(hidden_states) 522 | attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states) 523 | hidden_states = attn_output + hidden_states 524 | 525 | # 4. Feed-forward 526 | norm_hidden_states = self.norm3(hidden_states) 527 | 528 | if self._chunk_size is not None: 529 | ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size) 530 | else: 531 | ff_output = self.ff(norm_hidden_states) 532 | 533 | if self.is_res: 534 | hidden_states = ff_output + hidden_states 535 | else: 536 | hidden_states = ff_output 537 | 538 | hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels) 539 | hidden_states = hidden_states.permute(0, 2, 1, 3) 540 | hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels) 541 | 542 | return hidden_states 543 | 544 | 545 | class SkipFFTransformerBlock(nn.Module): 546 | def __init__( 547 | self, 548 | dim: int, 549 | num_attention_heads: int, 550 | attention_head_dim: int, 551 | kv_input_dim: int, 552 | kv_input_dim_proj_use_bias: bool, 553 | dropout=0.0, 554 | cross_attention_dim: Optional[int] = None, 555 | attention_bias: bool = False, 556 | attention_out_bias: bool = True, 557 | ): 558 | super().__init__() 559 | if kv_input_dim != dim: 560 | self.kv_mapper = nn.Linear(kv_input_dim, dim, kv_input_dim_proj_use_bias) 561 | else: 562 | self.kv_mapper = None 563 | 564 | self.norm1 = RMSNorm(dim, 1e-06) 565 | 566 | self.attn1 = Attention( 567 | query_dim=dim, 568 | heads=num_attention_heads, 569 | dim_head=attention_head_dim, 570 | dropout=dropout, 571 | bias=attention_bias, 572 | cross_attention_dim=cross_attention_dim, 573 | out_bias=attention_out_bias, 574 | ) 575 | 576 | self.norm2 = RMSNorm(dim, 1e-06) 577 | 578 | self.attn2 = Attention( 579 | query_dim=dim, 580 | cross_attention_dim=cross_attention_dim, 581 | heads=num_attention_heads, 582 | dim_head=attention_head_dim, 583 | dropout=dropout, 584 | bias=attention_bias, 585 | out_bias=attention_out_bias, 586 | ) 587 | 588 | def forward(self, hidden_states, encoder_hidden_states, cross_attention_kwargs): 589 | cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} 590 | 591 | if self.kv_mapper is not None: 592 | encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states)) 593 | 594 | norm_hidden_states = self.norm1(hidden_states) 595 | 596 | attn_output = self.attn1( 597 | norm_hidden_states, 598 | encoder_hidden_states=encoder_hidden_states, 599 | **cross_attention_kwargs, 600 | ) 601 | 602 | hidden_states = attn_output + hidden_states 603 | 604 | norm_hidden_states = self.norm2(hidden_states) 605 | 606 | attn_output = self.attn2( 607 | norm_hidden_states, 608 | encoder_hidden_states=encoder_hidden_states, 609 | **cross_attention_kwargs, 610 | ) 611 | 612 | hidden_states = attn_output + hidden_states 613 | 614 | return hidden_states 615 | 616 | 617 | class FeedForward(nn.Module): 618 | r""" 619 | A feed-forward layer. 620 | 621 | Parameters: 622 | dim (`int`): The number of channels in the input. 623 | dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. 624 | mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. 625 | dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. 
626 | activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. 627 | final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. 628 | bias (`bool`, defaults to True): Whether to use a bias in the linear layer. 629 | """ 630 | 631 | def __init__( 632 | self, 633 | dim: int, 634 | dim_out: Optional[int] = None, 635 | mult: int = 4, 636 | dropout: float = 0.0, 637 | activation_fn: str = "geglu", 638 | final_dropout: bool = False, 639 | inner_dim=None, 640 | bias: bool = True, 641 | ): 642 | super().__init__() 643 | if inner_dim is None: 644 | inner_dim = int(dim * mult) 645 | dim_out = dim_out if dim_out is not None else dim 646 | linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear 647 | 648 | if activation_fn == "gelu": 649 | act_fn = GELU(dim, inner_dim, bias=bias) 650 | if activation_fn == "gelu-approximate": 651 | act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) 652 | elif activation_fn == "geglu": 653 | act_fn = GEGLU(dim, inner_dim, bias=bias) 654 | elif activation_fn == "geglu-approximate": 655 | act_fn = ApproximateGELU(dim, inner_dim, bias=bias) 656 | 657 | self.net = nn.ModuleList([]) 658 | # project in 659 | self.net.append(act_fn) 660 | # project dropout 661 | self.net.append(nn.Dropout(dropout)) 662 | # project out 663 | self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) 664 | # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout 665 | if final_dropout: 666 | self.net.append(nn.Dropout(dropout)) 667 | 668 | def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: 669 | compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear) 670 | for module in self.net: 671 | if isinstance(module, compatible_cls): 672 | hidden_states = module(hidden_states, scale) 673 | else: 674 | hidden_states = module(hidden_states) 675 | return hidden_states 676 | -------------------------------------------------------------------------------- /ip_adapter/ip_adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | from diffusers import StableDiffusionPipeline 6 | from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel 7 | from PIL import Image 8 | from safetensors import safe_open 9 | from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection 10 | 11 | from .utils import is_torch2_available 12 | 13 | if is_torch2_available(): 14 | from .attention_processor import ( 15 | AttnProcessor2_0 as AttnProcessor, 16 | ) 17 | from .attention_processor import ( 18 | CNAttnProcessor2_0 as CNAttnProcessor, 19 | ) 20 | from .attention_processor import ( 21 | IPAttnProcessor2_0 as IPAttnProcessor, 22 | ) 23 | from .attention_processor import IPAttnProcessor2_0_Lora 24 | # else: 25 | # from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor 26 | from .resampler import Resampler 27 | from diffusers.models.lora import LoRALinearLayer 28 | 29 | 30 | class ImageProjModel(torch.nn.Module): 31 | """Projection Model""" 32 | 33 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): 34 | super().__init__() 35 | 36 | self.cross_attention_dim = cross_attention_dim 37 | self.clip_extra_context_tokens = clip_extra_context_tokens 38 | self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) 39 | self.norm 
= torch.nn.LayerNorm(cross_attention_dim) 40 | 41 | def forward(self, image_embeds): 42 | embeds = image_embeds 43 | clip_extra_context_tokens = self.proj(embeds).reshape( 44 | -1, self.clip_extra_context_tokens, self.cross_attention_dim 45 | ) 46 | clip_extra_context_tokens = self.norm(clip_extra_context_tokens) 47 | return clip_extra_context_tokens 48 | 49 | 50 | class MLPProjModel(torch.nn.Module): 51 | """SD model with image prompt""" 52 | def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): 53 | super().__init__() 54 | 55 | self.proj = torch.nn.Sequential( 56 | torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), 57 | torch.nn.GELU(), 58 | torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), 59 | torch.nn.LayerNorm(cross_attention_dim) 60 | ) 61 | 62 | def forward(self, image_embeds): 63 | clip_extra_context_tokens = self.proj(image_embeds) 64 | return clip_extra_context_tokens 65 | 66 | 67 | class IPAdapter: 68 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4): 69 | self.device = device 70 | self.image_encoder_path = image_encoder_path 71 | self.ip_ckpt = ip_ckpt 72 | self.num_tokens = num_tokens 73 | 74 | self.pipe = sd_pipe.to(self.device) 75 | self.set_ip_adapter() 76 | 77 | # load image encoder 78 | self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to( 79 | self.device, dtype=torch.float16 80 | ) 81 | self.clip_image_processor = CLIPImageProcessor() 82 | # image proj model 83 | self.image_proj_model = self.init_proj() 84 | 85 | self.load_ip_adapter() 86 | 87 | def init_proj(self): 88 | image_proj_model = ImageProjModel( 89 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 90 | clip_embeddings_dim=self.image_encoder.config.projection_dim, 91 | clip_extra_context_tokens=self.num_tokens, 92 | ).to(self.device, dtype=torch.float16) 93 | return image_proj_model 94 | 95 | def set_ip_adapter(self): 96 | unet = self.pipe.unet 97 | attn_procs = {} 98 | for name in unet.attn_processors.keys(): 99 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 100 | if name.startswith("mid_block"): 101 | hidden_size = unet.config.block_out_channels[-1] 102 | elif name.startswith("up_blocks"): 103 | block_id = int(name[len("up_blocks.")]) 104 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 105 | elif name.startswith("down_blocks"): 106 | block_id = int(name[len("down_blocks.")]) 107 | hidden_size = unet.config.block_out_channels[block_id] 108 | if cross_attention_dim is None: 109 | attn_procs[name] = AttnProcessor() 110 | else: 111 | attn_procs[name] = IPAttnProcessor( 112 | hidden_size=hidden_size, 113 | cross_attention_dim=cross_attention_dim, 114 | scale=1.0, 115 | num_tokens=self.num_tokens, 116 | ).to(self.device, dtype=torch.float16) 117 | unet.set_attn_processor(attn_procs) 118 | if hasattr(self.pipe, "controlnet"): 119 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 120 | for controlnet in self.pipe.controlnet.nets: 121 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 122 | else: 123 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 124 | 125 | def load_ip_adapter(self): 126 | if self.ip_ckpt is not None: 127 | if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 128 | state_dict = {"image_proj": {}, "ip_adapter": {}} 129 | with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 130 | for key in f.keys(): 131 | if 
key.startswith("image_proj."): 132 | state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) 133 | elif key.startswith("ip_adapter."): 134 | state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 135 | else: 136 | state_dict = torch.load(self.ip_ckpt, map_location="cpu") 137 | self.image_proj_model.load_state_dict(state_dict["image_proj"]) 138 | ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 139 | ip_layers.load_state_dict(state_dict["ip_adapter"]) 140 | 141 | 142 | # def load_ip_adapter(self): 143 | # if self.ip_ckpt is not None: 144 | # if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors": 145 | # state_dict = {"image_proj_model": {}, "ip_adapter": {}} 146 | # with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f: 147 | # for key in f.keys(): 148 | # if key.startswith("image_proj_model."): 149 | # state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = f.get_tensor(key) 150 | # elif key.startswith("ip_adapter."): 151 | # state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) 152 | # else: 153 | # state_dict = torch.load(self.ip_ckpt, map_location="cpu") 154 | 155 | # tmp1 = {} 156 | # for k,v in state_dict.items(): 157 | # if 'image_proj_model' in k: 158 | # tmp1[k.replace('image_proj_model.','')] = v 159 | # self.image_proj_model.load_state_dict(tmp1, strict=True) 160 | # # ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values()) 161 | # tmp2 = {} 162 | # for k,v in state_dict.ites(): 163 | # if 'adapter_mode' in k: 164 | # tmp1[k] = v 165 | 166 | # print(ip_layers.state_dict()) 167 | # ip_layers.load_state_dict(state_dict,strict=False) 168 | 169 | 170 | @torch.inference_mode() 171 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 172 | if pil_image is not None: 173 | if isinstance(pil_image, Image.Image): 174 | pil_image = [pil_image] 175 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 176 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds 177 | else: 178 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 179 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 180 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) 181 | return image_prompt_embeds, uncond_image_prompt_embeds 182 | 183 | def get_image_embeds_train(self, pil_image=None, clip_image_embeds=None): 184 | if pil_image is not None: 185 | if isinstance(pil_image, Image.Image): 186 | pil_image = [pil_image] 187 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 188 | clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds 189 | else: 190 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32) 191 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 192 | uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds)) 193 | return image_prompt_embeds, uncond_image_prompt_embeds 194 | 195 | 196 | def set_scale(self, scale): 197 | for attn_processor in self.pipe.unet.attn_processors.values(): 198 | if isinstance(attn_processor, IPAttnProcessor): 199 | attn_processor.scale = scale 200 | 201 | def generate( 202 | self, 203 | pil_image=None, 204 | clip_image_embeds=None, 205 | prompt=None, 206 | negative_prompt=None, 207 | scale=1.0, 208 | num_samples=4, 209 | seed=None, 
210 | guidance_scale=7.5, 211 | num_inference_steps=50, 212 | **kwargs, 213 | ): 214 | self.set_scale(scale) 215 | 216 | if pil_image is not None: 217 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 218 | else: 219 | num_prompts = clip_image_embeds.size(0) 220 | 221 | if prompt is None: 222 | prompt = "best quality, high quality" 223 | if negative_prompt is None: 224 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 225 | 226 | if not isinstance(prompt, List): 227 | prompt = [prompt] * num_prompts 228 | if not isinstance(negative_prompt, List): 229 | negative_prompt = [negative_prompt] * num_prompts 230 | 231 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 232 | pil_image=pil_image, clip_image_embeds=clip_image_embeds 233 | ) 234 | bs_embed, seq_len, _ = image_prompt_embeds.shape 235 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 236 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 237 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 238 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 239 | 240 | with torch.inference_mode(): 241 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 242 | prompt, 243 | device=self.device, 244 | num_images_per_prompt=num_samples, 245 | do_classifier_free_guidance=True, 246 | negative_prompt=negative_prompt, 247 | ) 248 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 249 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 250 | 251 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 252 | images = self.pipe( 253 | prompt_embeds=prompt_embeds, 254 | negative_prompt_embeds=negative_prompt_embeds, 255 | guidance_scale=guidance_scale, 256 | num_inference_steps=num_inference_steps, 257 | generator=generator, 258 | **kwargs, 259 | ).images 260 | 261 | return images 262 | 263 | 264 | class IPAdapterXL(IPAdapter): 265 | """SDXL""" 266 | 267 | def generate_test( 268 | self, 269 | pil_image, 270 | prompt=None, 271 | negative_prompt=None, 272 | scale=1.0, 273 | num_samples=4, 274 | seed=None, 275 | num_inference_steps=30, 276 | **kwargs, 277 | ): 278 | self.set_scale(scale) 279 | 280 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 281 | 282 | if prompt is None: 283 | prompt = "best quality, high quality" 284 | if negative_prompt is None: 285 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 286 | 287 | if not isinstance(prompt, List): 288 | prompt = [prompt] * num_prompts 289 | if not isinstance(negative_prompt, List): 290 | negative_prompt = [negative_prompt] * num_prompts 291 | 292 | 293 | with torch.inference_mode(): 294 | ( 295 | prompt_embeds, 296 | negative_prompt_embeds, 297 | pooled_prompt_embeds, 298 | negative_pooled_prompt_embeds, 299 | ) = self.pipe.encode_prompt( 300 | prompt, 301 | num_images_per_prompt=num_samples, 302 | do_classifier_free_guidance=True, 303 | negative_prompt=negative_prompt, 304 | ) 305 | 306 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 307 | images = self.pipe( 308 | prompt_embeds=prompt_embeds, 309 | negative_prompt_embeds=negative_prompt_embeds, 310 | pooled_prompt_embeds=pooled_prompt_embeds, 311 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 312 | 
num_inference_steps=num_inference_steps, 313 | generator=generator, 314 | **kwargs, 315 | ).images 316 | 317 | 318 | # with torch.autocast("cuda"): 319 | # images = self.pipe( 320 | # prompt_embeds=prompt_embeds, 321 | # negative_prompt_embeds=negative_prompt_embeds, 322 | # pooled_prompt_embeds=pooled_prompt_embeds, 323 | # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 324 | # num_inference_steps=num_inference_steps, 325 | # generator=generator, 326 | # **kwargs, 327 | # ).images 328 | 329 | return images 330 | 331 | 332 | def generate( 333 | self, 334 | pil_image, 335 | prompt=None, 336 | negative_prompt=None, 337 | scale=1.0, 338 | num_samples=4, 339 | seed=None, 340 | num_inference_steps=30, 341 | **kwargs, 342 | ): 343 | self.set_scale(scale) 344 | 345 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 346 | 347 | if prompt is None: 348 | prompt = "best quality, high quality" 349 | if negative_prompt is None: 350 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 351 | 352 | if not isinstance(prompt, List): 353 | prompt = [prompt] * num_prompts 354 | if not isinstance(negative_prompt, List): 355 | negative_prompt = [negative_prompt] * num_prompts 356 | 357 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) 358 | bs_embed, seq_len, _ = image_prompt_embeds.shape 359 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 360 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 361 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 362 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 363 | 364 | with torch.inference_mode(): 365 | ( 366 | prompt_embeds, 367 | negative_prompt_embeds, 368 | pooled_prompt_embeds, 369 | negative_pooled_prompt_embeds, 370 | ) = self.pipe.encode_prompt( 371 | prompt, 372 | num_images_per_prompt=num_samples, 373 | do_classifier_free_guidance=True, 374 | negative_prompt=negative_prompt, 375 | ) 376 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 377 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 378 | 379 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 380 | images = self.pipe( 381 | prompt_embeds=prompt_embeds, 382 | negative_prompt_embeds=negative_prompt_embeds, 383 | pooled_prompt_embeds=pooled_prompt_embeds, 384 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 385 | num_inference_steps=num_inference_steps, 386 | generator=generator, 387 | **kwargs, 388 | ).images 389 | 390 | 391 | # with torch.autocast("cuda"): 392 | # images = self.pipe( 393 | # prompt_embeds=prompt_embeds, 394 | # negative_prompt_embeds=negative_prompt_embeds, 395 | # pooled_prompt_embeds=pooled_prompt_embeds, 396 | # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 397 | # num_inference_steps=num_inference_steps, 398 | # generator=generator, 399 | # **kwargs, 400 | # ).images 401 | 402 | return images 403 | 404 | 405 | class IPAdapterPlus(IPAdapter): 406 | """IP-Adapter with fine-grained features""" 407 | 408 | def generate( 409 | self, 410 | pil_image=None, 411 | clip_image_embeds=None, 412 | prompt=None, 413 | negative_prompt=None, 414 | scale=1.0, 415 | num_samples=4, 416 | seed=None, 417 | guidance_scale=7.5, 418 | num_inference_steps=50, 419 | **kwargs, 420 | ): 421 | self.set_scale(scale) 422 | 423 | if 
pil_image is not None: 424 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 425 | else: 426 | num_prompts = clip_image_embeds.size(0) 427 | 428 | if prompt is None: 429 | prompt = "best quality, high quality" 430 | if negative_prompt is None: 431 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 432 | 433 | if not isinstance(prompt, List): 434 | prompt = [prompt] * num_prompts 435 | if not isinstance(negative_prompt, List): 436 | negative_prompt = [negative_prompt] * num_prompts 437 | 438 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 439 | pil_image=pil_image, clip_image=clip_image_embeds 440 | ) 441 | bs_embed, seq_len, _ = image_prompt_embeds.shape 442 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 443 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 444 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 445 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 446 | 447 | with torch.inference_mode(): 448 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 449 | prompt, 450 | device=self.device, 451 | num_images_per_prompt=num_samples, 452 | do_classifier_free_guidance=True, 453 | negative_prompt=negative_prompt, 454 | ) 455 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 456 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 457 | 458 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 459 | images = self.pipe( 460 | prompt_embeds=prompt_embeds, 461 | negative_prompt_embeds=negative_prompt_embeds, 462 | guidance_scale=guidance_scale, 463 | num_inference_steps=num_inference_steps, 464 | generator=generator, 465 | **kwargs, 466 | ).images 467 | 468 | return images 469 | 470 | 471 | def init_proj(self): 472 | image_proj_model = Resampler( 473 | dim=self.pipe.unet.config.cross_attention_dim, 474 | depth=4, 475 | dim_head=64, 476 | heads=12, 477 | num_queries=self.num_tokens, 478 | embedding_dim=self.image_encoder.config.hidden_size, 479 | output_dim=self.pipe.unet.config.cross_attention_dim, 480 | ff_mult=4, 481 | ).to(self.device, dtype=torch.float16) 482 | return image_proj_model 483 | 484 | @torch.inference_mode() 485 | def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None): 486 | if pil_image is not None: 487 | if isinstance(pil_image, Image.Image): 488 | pil_image = [pil_image] 489 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 490 | clip_image = clip_image.to(self.device, dtype=torch.float16) 491 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 492 | else: 493 | clip_image = clip_image.to(self.device, dtype=torch.float16) 494 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 495 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 496 | uncond_clip_image_embeds = self.image_encoder( 497 | torch.zeros_like(clip_image), output_hidden_states=True 498 | ).hidden_states[-2] 499 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 500 | return image_prompt_embeds, uncond_image_prompt_embeds 501 | 502 | 503 | 504 | 505 | class IPAdapterPlus_Lora(IPAdapter): 506 | """IP-Adapter with fine-grained features""" 507 | 508 | def __init__(self, sd_pipe, 
image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32): 509 | self.rank = rank 510 | super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens) 511 | 512 | 513 | def generate( 514 | self, 515 | pil_image=None, 516 | clip_image_embeds=None, 517 | prompt=None, 518 | negative_prompt=None, 519 | scale=1.0, 520 | num_samples=4, 521 | seed=None, 522 | guidance_scale=7.5, 523 | num_inference_steps=50, 524 | **kwargs, 525 | ): 526 | self.set_scale(scale) 527 | 528 | if pil_image is not None: 529 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 530 | else: 531 | num_prompts = clip_image_embeds.size(0) 532 | 533 | if prompt is None: 534 | prompt = "best quality, high quality" 535 | if negative_prompt is None: 536 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 537 | 538 | if not isinstance(prompt, List): 539 | prompt = [prompt] * num_prompts 540 | if not isinstance(negative_prompt, List): 541 | negative_prompt = [negative_prompt] * num_prompts 542 | 543 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 544 | pil_image=pil_image, clip_image=clip_image_embeds 545 | ) 546 | bs_embed, seq_len, _ = image_prompt_embeds.shape 547 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 548 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 549 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 550 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 551 | 552 | with torch.inference_mode(): 553 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 554 | prompt, 555 | device=self.device, 556 | num_images_per_prompt=num_samples, 557 | do_classifier_free_guidance=True, 558 | negative_prompt=negative_prompt, 559 | ) 560 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 561 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 562 | 563 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 564 | images = self.pipe( 565 | prompt_embeds=prompt_embeds, 566 | negative_prompt_embeds=negative_prompt_embeds, 567 | guidance_scale=guidance_scale, 568 | num_inference_steps=num_inference_steps, 569 | generator=generator, 570 | **kwargs, 571 | ).images 572 | 573 | return images 574 | 575 | 576 | def init_proj(self): 577 | image_proj_model = Resampler( 578 | dim=self.pipe.unet.config.cross_attention_dim, 579 | depth=4, 580 | dim_head=64, 581 | heads=12, 582 | num_queries=self.num_tokens, 583 | embedding_dim=self.image_encoder.config.hidden_size, 584 | output_dim=self.pipe.unet.config.cross_attention_dim, 585 | ff_mult=4, 586 | ).to(self.device, dtype=torch.float16) 587 | return image_proj_model 588 | 589 | @torch.inference_mode() 590 | def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None): 591 | if pil_image is not None: 592 | if isinstance(pil_image, Image.Image): 593 | pil_image = [pil_image] 594 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 595 | clip_image = clip_image.to(self.device, dtype=torch.float16) 596 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 597 | else: 598 | clip_image = clip_image.to(self.device, dtype=torch.float16) 599 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 600 | 
image_prompt_embeds = self.image_proj_model(clip_image_embeds) 601 | uncond_clip_image_embeds = self.image_encoder( 602 | torch.zeros_like(clip_image), output_hidden_states=True 603 | ).hidden_states[-2] 604 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 605 | return image_prompt_embeds, uncond_image_prompt_embeds 606 | 607 | def set_ip_adapter(self): 608 | unet = self.pipe.unet 609 | attn_procs = {} 610 | unet_sd = unet.state_dict() 611 | 612 | for attn_processor_name, attn_processor in unet.attn_processors.items(): 613 | # Parse the attention module. 614 | cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim 615 | if attn_processor_name.startswith("mid_block"): 616 | hidden_size = unet.config.block_out_channels[-1] 617 | elif attn_processor_name.startswith("up_blocks"): 618 | block_id = int(attn_processor_name[len("up_blocks.")]) 619 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 620 | elif attn_processor_name.startswith("down_blocks"): 621 | block_id = int(attn_processor_name[len("down_blocks.")]) 622 | hidden_size = unet.config.block_out_channels[block_id] 623 | if cross_attention_dim is None: 624 | attn_procs[attn_processor_name] = AttnProcessor() 625 | else: 626 | layer_name = attn_processor_name.split(".processor")[0] 627 | weights = { 628 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], 629 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], 630 | } 631 | attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens) 632 | attn_procs[attn_processor_name].load_state_dict(weights,strict=False) 633 | 634 | attn_module = unet 635 | for n in attn_processor_name.split(".")[:-1]: 636 | attn_module = getattr(attn_module, n) 637 | 638 | attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank) 639 | attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank) 640 | attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank) 641 | attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank) 642 | 643 | unet.set_attn_processor(attn_procs) 644 | if hasattr(self.pipe, "controlnet"): 645 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 646 | for controlnet in self.pipe.controlnet.nets: 647 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 648 | else: 649 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 650 | 651 | 652 | 653 | class IPAdapterPlus_Lora_up(IPAdapter): 654 | """IP-Adapter with fine-grained features""" 655 | 656 | def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32): 657 | self.rank = rank 658 | super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens) 659 | 660 | 661 | def generate( 662 | self, 663 | pil_image=None, 664 | clip_image_embeds=None, 665 | prompt=None, 666 | negative_prompt=None, 667 | scale=1.0, 668 | num_samples=4, 669 | seed=None, 670 | guidance_scale=7.5, 671 | num_inference_steps=50, 672 | **kwargs, 673 | ): 674 | self.set_scale(scale) 675 | 676 | if pil_image is not None: 677 | num_prompts = 1 if 
isinstance(pil_image, Image.Image) else len(pil_image) 678 | else: 679 | num_prompts = clip_image_embeds.size(0) 680 | 681 | if prompt is None: 682 | prompt = "best quality, high quality" 683 | if negative_prompt is None: 684 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 685 | 686 | if not isinstance(prompt, List): 687 | prompt = [prompt] * num_prompts 688 | if not isinstance(negative_prompt, List): 689 | negative_prompt = [negative_prompt] * num_prompts 690 | 691 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds( 692 | pil_image=pil_image, clip_image=clip_image_embeds 693 | ) 694 | bs_embed, seq_len, _ = image_prompt_embeds.shape 695 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 696 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 697 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 698 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 699 | 700 | with torch.inference_mode(): 701 | prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt( 702 | prompt, 703 | device=self.device, 704 | num_images_per_prompt=num_samples, 705 | do_classifier_free_guidance=True, 706 | negative_prompt=negative_prompt, 707 | ) 708 | prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1) 709 | negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1) 710 | 711 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 712 | images = self.pipe( 713 | prompt_embeds=prompt_embeds, 714 | negative_prompt_embeds=negative_prompt_embeds, 715 | guidance_scale=guidance_scale, 716 | num_inference_steps=num_inference_steps, 717 | generator=generator, 718 | **kwargs, 719 | ).images 720 | 721 | return images 722 | 723 | 724 | def init_proj(self): 725 | image_proj_model = Resampler( 726 | dim=self.pipe.unet.config.cross_attention_dim, 727 | depth=4, 728 | dim_head=64, 729 | heads=12, 730 | num_queries=self.num_tokens, 731 | embedding_dim=self.image_encoder.config.hidden_size, 732 | output_dim=self.pipe.unet.config.cross_attention_dim, 733 | ff_mult=4, 734 | ).to(self.device, dtype=torch.float16) 735 | return image_proj_model 736 | 737 | @torch.inference_mode() 738 | def get_image_embeds(self, pil_image=None, clip_image=None, uncond= None): 739 | if pil_image is not None: 740 | if isinstance(pil_image, Image.Image): 741 | pil_image = [pil_image] 742 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 743 | clip_image = clip_image.to(self.device, dtype=torch.float16) 744 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 745 | else: 746 | clip_image = clip_image.to(self.device, dtype=torch.float16) 747 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 748 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 749 | uncond_clip_image_embeds = self.image_encoder( 750 | torch.zeros_like(clip_image), output_hidden_states=True 751 | ).hidden_states[-2] 752 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 753 | return image_prompt_embeds, uncond_image_prompt_embeds 754 | 755 | def set_ip_adapter(self): 756 | unet = self.pipe.unet 757 | attn_procs = {} 758 | unet_sd = unet.state_dict() 759 | 760 | for attn_processor_name, attn_processor in 
unet.attn_processors.items(): 761 | # Parse the attention module. 762 | cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim 763 | if attn_processor_name.startswith("mid_block"): 764 | hidden_size = unet.config.block_out_channels[-1] 765 | elif attn_processor_name.startswith("up_blocks"): 766 | block_id = int(attn_processor_name[len("up_blocks.")]) 767 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 768 | elif attn_processor_name.startswith("down_blocks"): 769 | block_id = int(attn_processor_name[len("down_blocks.")]) 770 | hidden_size = unet.config.block_out_channels[block_id] 771 | if cross_attention_dim is None: 772 | attn_procs[attn_processor_name] = AttnProcessor() 773 | else: 774 | layer_name = attn_processor_name.split(".processor")[0] 775 | weights = { 776 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], 777 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], 778 | } 779 | attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens) 780 | attn_procs[attn_processor_name].load_state_dict(weights,strict=False) 781 | 782 | attn_module = unet 783 | for n in attn_processor_name.split(".")[:-1]: 784 | attn_module = getattr(attn_module, n) 785 | 786 | 787 | if "up_blocks" in attn_processor_name: 788 | attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank) 789 | attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank) 790 | attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank) 791 | attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank) 792 | 793 | 794 | 795 | unet.set_attn_processor(attn_procs) 796 | if hasattr(self.pipe, "controlnet"): 797 | if isinstance(self.pipe.controlnet, MultiControlNetModel): 798 | for controlnet in self.pipe.controlnet.nets: 799 | controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 800 | else: 801 | self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens)) 802 | 803 | 804 | 805 | class IPAdapterFull(IPAdapterPlus): 806 | """IP-Adapter with full features""" 807 | 808 | def init_proj(self): 809 | image_proj_model = MLPProjModel( 810 | cross_attention_dim=self.pipe.unet.config.cross_attention_dim, 811 | clip_embeddings_dim=self.image_encoder.config.hidden_size, 812 | ).to(self.device, dtype=torch.float16) 813 | return image_proj_model 814 | 815 | 816 | class IPAdapterPlusXL(IPAdapter): 817 | """SDXL""" 818 | 819 | def init_proj(self): 820 | image_proj_model = Resampler( 821 | dim=1280, 822 | depth=4, 823 | dim_head=64, 824 | heads=20, 825 | num_queries=self.num_tokens, 826 | embedding_dim=self.image_encoder.config.hidden_size, 827 | output_dim=self.pipe.unet.config.cross_attention_dim, 828 | ff_mult=4, 829 | ).to(self.device, dtype=torch.float16) 830 | return image_proj_model 831 | 832 | @torch.inference_mode() 833 | def get_image_embeds(self, pil_image=None, clip_image_embeds=None): 834 | if pil_image is not None: 835 | if isinstance(pil_image, Image.Image): 836 | pil_image = [pil_image] 837 | clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values 838 
| clip_image = clip_image.to(self.device, dtype=torch.float16) 839 | clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] 840 | else: 841 | clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16) 842 | image_prompt_embeds = self.image_proj_model(clip_image_embeds) 843 | uncond_clip_image_embeds = self.image_encoder( 844 | torch.zeros_like(clip_image), output_hidden_states=True 845 | ).hidden_states[-2] 846 | uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds) 847 | return image_prompt_embeds, uncond_image_prompt_embeds 848 | 849 | def generate( 850 | self, 851 | pil_image, 852 | prompt=None, 853 | negative_prompt=None, 854 | scale=1.0, 855 | num_samples=4, 856 | seed=None, 857 | num_inference_steps=30, 858 | **kwargs, 859 | ): 860 | self.set_scale(scale) 861 | 862 | num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image) 863 | 864 | if prompt is None: 865 | prompt = "best quality, high quality" 866 | if negative_prompt is None: 867 | negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality" 868 | 869 | if not isinstance(prompt, List): 870 | prompt = [prompt] * num_prompts 871 | if not isinstance(negative_prompt, List): 872 | negative_prompt = [negative_prompt] * num_prompts 873 | 874 | image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image) 875 | bs_embed, seq_len, _ = image_prompt_embeds.shape 876 | image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1) 877 | image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 878 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1) 879 | uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1) 880 | 881 | with torch.inference_mode(): 882 | ( 883 | prompt_embeds, 884 | negative_prompt_embeds, 885 | pooled_prompt_embeds, 886 | negative_pooled_prompt_embeds, 887 | ) = self.pipe.encode_prompt( 888 | prompt, 889 | num_images_per_prompt=num_samples, 890 | do_classifier_free_guidance=True, 891 | negative_prompt=negative_prompt, 892 | ) 893 | prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1) 894 | negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1) 895 | 896 | generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None 897 | images = self.pipe( 898 | prompt_embeds=prompt_embeds, 899 | negative_prompt_embeds=negative_prompt_embeds, 900 | pooled_prompt_embeds=pooled_prompt_embeds, 901 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 902 | num_inference_steps=num_inference_steps, 903 | generator=generator, 904 | **kwargs, 905 | ).images 906 | 907 | return images 908 | --------------------------------------------------------------------------------
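For reference, a minimal usage sketch of the `IPAdapterPlus` class defined in `ip_adapter/ip_adapter.py` above. The base-model ID, image-encoder path, and adapter checkpoint below are placeholders rather than files shipped with this repository, the sketch assumes the `ip_adapter.attention_processor` module imported by `ip_adapter.py` is available in the package, and `num_tokens` has to match the number of resampler queries the checkpoint was trained with.

```python
# Illustrative sketch only: model IDs and checkpoint paths are placeholders.
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline

from ip_adapter import IPAdapterPlus

device = "cuda"

# Any SD 1.5-style pipeline can be plugged in; this model ID is an assumption.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

ip_model = IPAdapterPlus(
    pipe,
    image_encoder_path="path/to/clip_image_encoder",  # placeholder
    ip_ckpt="path/to/ip-adapter-plus.bin",            # placeholder
    device=device,
    num_tokens=16,  # must match the resampler's num_queries used at training time
)

# Condition generation on the example reference image from ./data.
reference = Image.open("data/Image/example.png").convert("RGB")
images = ip_model.generate(
    pil_image=reference,
    prompt="a cheerful green bottle-shaped character holding a top hat",
    num_samples=1,
    num_inference_steps=50,
    guidance_scale=7.5,
    seed=42,
)
images[0].save("output/ip_adapter_example.png")
```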