├── README.md
├── assets
│   ├── examples.png
│   ├── results_customization.png
│   └── results_fidelity.png
├── download.py
├── download.sh
├── get_face_info.py
├── infer.py
├── infer_from_pkl.py
├── ip_adapter
│   ├── __pycache__
│   │   ├── attention_processor.cpython-310.pyc
│   │   ├── resampler.cpython-310.pyc
│   │   └── utils.cpython-310.pyc
│   ├── attention_processor.py
│   ├── mlp.py
│   ├── resampler.py
│   └── utils.py
├── pipeline_stable_diffusion_xl_instantid.py
├── train_instantId_sdxl.py
├── train_instantId_sdxl.sh
└── utils
    └── dataset.py


/README.md:
--------------------------------------------------------------------------------
# Turn That Frown Upside Down: FaceID Customization via Cross-Training Data

This repository contains resources referenced in the paper [Turn That Frown Upside Down: FaceID Customization via Cross-Training Data](https://arxiv.org/abs/2501.15407v1).

If you find this repository helpful, please cite the following:
```latex
@misc{wang2025turnfrownupsidedown,
      title={Turn That Frown Upside Down: FaceID Customization via Cross-Training Data},
      author={Shuhe Wang and Xiaoya Li and Xiaofei Sun and Guoyin Wang and Tianwei Zhang and Jiwei Li and Eduard Hovy},
      year={2025},
      eprint={2501.15407},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2501.15407},
}
```

## 🥳 News

**Stay tuned! More related work will be updated!**
* **[25 Jan, 2025]** The repository is created.
* **[25 Jan, 2025]** We release the first version of the paper.

## Links
- [Turn That Frown Upside Down: FaceID Customization via Cross-Training Data](#turn-that-frown-upside-down-faceid-customization-via-cross-training-data)
  - [🥳 News](#-news)
  - [Links](#links)
  - [Introduction](#introduction)
  - [Comparison with Previous Works](#comparison-with-previous-works)
    - [FaceID Fidelity](#faceid-fidelity)
    - [FaceID Customization](#faceid-customization)
  - [Released CrossFaceID dataset](#released-crossfaceid-dataset)
  - [Released FaceID Customization Models](#released-faceid-customization-models)
  - [Usage](#usage)
    - [Training](#training)
      - [Step 1. Download Required Models](#step-1-download-required-models)
      - [Step 2. Download Required Dataset](#step-2-download-required-dataset)
      - [Step 3. Training](#step-3-training)
    - [Inference](#inference)
  - [Contact](#contact)

## Introduction

CrossFaceID is the first large-scale, high-quality, publicly available dataset designed specifically to improve the facial modification capabilities of FaceID customization models. It consists of 40,000 text-image pairs from approximately 2,000 people, with each person represented by around 20 images showing diverse facial attributes such as poses, expressions, angles, and adornments. During training, one face image of a person is used as input, and the FaceID customization model is required to generate a different image of the same person with altered facial features. This teaches the model to personalize and modify known facial features at inference time.

![CrossFaceID examples](assets/examples.png)
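The cross-training setup described above pairs each target image with a *different* image of the same person as the input face. The sketch below illustrates that pairing under stated assumptions: `make_cross_pairs` and the `images_by_person` mapping (person ID to a list of image paths) are hypothetical, and the actual CrossFaceID file layout and the repository's `utils/dataset.py` may organize pairs differently.

```python
import random
from typing import Dict, List, Tuple


def make_cross_pairs(images_by_person: Dict[str, List[str]],
                     seed: int = 0) -> List[Tuple[str, str]]:
    """Build (input_face, target_image) pairs for cross-training.

    Every image of a person becomes a target once, and its input face is a
    different image of the same person, so the model cannot simply copy the
    input. `images_by_person` is a hypothetical layout: person ID -> paths.
    """
    rng = random.Random(seed)  # seeded so the pairing is reproducible
    pairs: List[Tuple[str, str]] = []
    for images in images_by_person.values():
        unique = list(dict.fromkeys(images))  # drop duplicate paths, keep order
        if len(unique) < 2:
            # A cross pair needs at least two distinct images of one person.
            continue
        for target in unique:
            candidates = [img for img in unique if img != target]
            pairs.append((rng.choice(candidates), target))
    return pairs


if __name__ == "__main__":
    demo = {
        "person_0001": ["p1_front.jpg", "p1_smile.jpg", "p1_side.jpg"],
        "person_0002": ["p2_only.jpg"],  # skipped: only one image
    }
    for src, tgt in make_cross_pairs(demo):
        print(f"input={src}  target={tgt}")
```

Sampling the input uniformly from the person's other images is only one simple choice; a real pipeline might instead pick inputs that differ most in pose or expression, which this sketch does not attempt.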


## Comparison with Previous Works


### FaceID Fidelity

![FaceID fidelity results](assets/results_fidelity.png)

The results show how well FaceID customization models maintain FaceID fidelity. "InstantID" refers to the official InstantID model, while "InstantID + CrossFaceID" is that model further fine-tuned on our CrossFaceID dataset. "LAION" denotes the InstantID model pre-trained on our curated LAION dataset, and "LAION + CrossFaceID" is that model further trained on the CrossFaceID dataset. These results indicate that (1) for both the official InstantID model and the LAION-trained model, FaceID fidelity remains consistent before and after fine-tuning on CrossFaceID, and (2) the model trained on our curated LAION dataset performs comparably to the official InstantID model in preserving FaceID fidelity.

### FaceID Customization

![FaceID customization results](assets/results_customization.png)

These results show how well FaceID customization models customize or edit FaceID. As before, "InstantID" is the official InstantID model, "InstantID + CrossFaceID" is that model fine-tuned on our CrossFaceID dataset, "LAION" is the InstantID model pre-trained on our curated LAION dataset, and "LAION + CrossFaceID" is that model further fine-tuned on CrossFaceID. We observe a clear improvement in both models' ability to customize FaceID after fine-tuning on our CrossFaceID dataset.


## Released CrossFaceID dataset

Our CrossFaceID dataset is available on [Huggingface](https://huggingface.co/datasets/Super-shuhe/CrossFaceID). It comprises 40,000 text-image pairs from approximately 2,000 individuals, with each person represented by around 20 images that capture various facial attributes, including different poses, expressions, angles, and adornments.


## Released FaceID Customization Models

The trained InstantID model is available [here](https://huggingface.co/Super-shuhe/CrossFaceID-InstantID).

## Usage

### Training
Because the original InstantID repository (https://github.com/InstantID/InstantID) does not include training code, we follow [this repository](https://github.com/MFaceTech/InstantID?tab=readme-ov-file) to train our own InstantID.

#### Step 1. Download Required Models

You can download the model directly from [Huggingface](https://huggingface.co/InstantX/InstantID).
You can also download the model with a Python script:

```python
from huggingface_hub import hf_hub_download
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
```

If you cannot access Huggingface, you can use [hf-mirror](https://hf-mirror.com/) to download the models:
```bash
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download --resume-download InstantX/InstantID --local-dir checkpoints
```

For the face encoder, you need to download the models manually via this [URL](https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304) into `models/antelopev2`, because the default download link is invalid. Once you have prepared all models, the folder tree should look like this:

```
.
├── models
├── checkpoints
├── ip_adapter
├── pipeline_stable_diffusion_xl_instantid.py
├── download.py
├── download.sh
├── get_face_info.py
├── infer_from_pkl.py
├── infer.py
├── train_instantId_sdxl.py
├── train_instantId_sdxl.sh
└── README.md
```


#### Step 2. Download Required Dataset

Please download our released dataset from [Huggingface](https://huggingface.co/datasets/Super-shuhe/CrossFaceID).

#### Step 3. Training

1. Fill in `MODEL_NAME`, `ENCODER_NAME`, `ADAPTOR_NAME`, `CONTROLNET_NAME`, and `JSON_FILE` in our training script `./train_instantId_sdxl.sh`, where:
   1. `MODEL_NAME` is the backbone diffusion model, e.g., `stable-diffusion-xl-base-1.0`;
   2. `ENCODER_NAME` is the downloaded encoder, e.g., `image_encoder`;
   3. `ADAPTOR_NAME` and `CONTROLNET_NAME` refer to the pre-trained official InstantID model, e.g., `checkpoints/ip-adapter.bin` and `checkpoints/ControlNetModel`;
   4. `JSON_FILE` refers to our released CrossFaceID dataset.
2. Run the training script, e.g., `bash ./train_instantId_sdxl.sh`.


### Inference

1. Fill in `base_model_path`, `face_adapter`, `controlnet_path`, `prompt0`, and `face_image` in our inference script `./infer_from_pkl.py`, where:
   1. `base_model_path` is the backbone diffusion model, e.g., `stable-diffusion-xl-base-1.0`;
   2. `face_adapter` and `controlnet_path` refer to your trained model, e.g., `checkpoints/ip-adapter.bin` and `checkpoints/ControlNetModel`;
   3. `prompt0` and `face_image` refer to your test sample (put the image path in the `load_image("")` call).
2. Run the inference script, e.g., `python ./infer_from_pkl.py`. Results are saved to `./results/`, which must already exist.


## Contact

If you have any issues or questions about this repo, feel free to contact shuhewang@student.unimelb.edu.au

--------------------------------------------------------------------------------
/assets/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/assets/examples.png

--------------------------------------------------------------------------------
/assets/results_customization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/assets/results_customization.png

--------------------------------------------------------------------------------
/assets/results_fidelity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/assets/results_fidelity.png
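The inference scripts in this repository (`get_face_info.py`, `infer.py`, `infer_from_pkl.py`) share two preprocessing steps: `resize_img` scales an image so its longer side equals `max_side` and then floors both sides to a multiple of `base_pixel_number` (64), and the face-selection line keeps the detected face with the largest bounding-box area. The sketch below reproduces that arithmetic on plain numbers so it can be checked without PIL or insightface. The helper names `resize_target_size`, `bbox_area`, and `pick_largest_face` are hypothetical, and faces are assumed to be dicts with an `[x1, y1, x2, y2]` `"bbox"` entry, which is how the scripts index insightface results.

```python
from typing import Dict, List, Sequence, Tuple


def resize_target_size(w: int, h: int, max_side: int = 1280, min_side: int = 1024,
                       base_pixel_number: int = 64) -> Tuple[int, int]:
    """Mirror the size arithmetic of resize_img() when `size` is None."""
    # Step 1: scale so the shorter side equals min_side.
    ratio = min_side / min(h, w)
    w, h = round(ratio * w), round(ratio * h)
    # Step 2: rescale so the longer side equals max_side.
    ratio = max_side / max(h, w)
    # Step 3: floor each side to a multiple of base_pixel_number.
    w_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
    h_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    return w_new, h_new


def bbox_area(bbox: Sequence[float]) -> float:
    """Area of an [x1, y1, x2, y2] box: width * height."""
    x1, y1, x2, y2 = bbox
    return (x2 - x1) * (y2 - y1)


def pick_largest_face(faces: List[Dict]) -> Dict:
    """Return the detected face with the largest bounding-box area."""
    if not faces:
        raise ValueError("no face detected")
    return max(faces, key=lambda face: bbox_area(face["bbox"]))


if __name__ == "__main__":
    print(resize_target_size(1000, 500))  # (1280, 640)
    faces = [{"bbox": [0, 0, 50, 50]}, {"bbox": [0, 900, 20, 1000]}]
    print(pick_largest_face(faces)["bbox"])  # [0, 0, 50, 50]
```

Parenthesizing the height term matters: for the second box above, `(x2 - x1) * y2 - y1` evaluates to 19100 and would select it even though its area (2000) is smaller than the first box's (2500). Raising `ValueError` on an empty list is a design choice of this sketch; the scripts themselves would fail with an `IndexError` when no face is found.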

--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
import os
from huggingface_hub import hf_hub_download
local_dir = "checkpoints"
os.makedirs(local_dir, exist_ok=True)

hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=local_dir)
hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=local_dir)
hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=local_dir)


--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
export HF_ENDPOINT=https://hf-mirror.com
huggingface-cli download --resume-download InstantX/InstantID --local-dir checkpoints

--------------------------------------------------------------------------------
/get_face_info.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np
from PIL import Image
import pickle
from diffusers.utils import load_image
from insightface.app import FaceAnalysis


def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):

    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image


root = "model_zoo/instantID/"
app = FaceAnalysis(name='antelopev2', root=root, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

app.prepare(ctx_id=0, det_size=(640, 640))


face_image = load_image("./examples/wenyongshan.png")
# face_image = load_image("./examples/zhuyilong.jpg")


face_image = resize_img(face_image)
face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
# Bounding-box area = (x2 - x1) * (y2 - y1); only use the maximum face
face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]

print(type(face_info))

# Create an example dictionary whose values are NumPy arrays with dtype float32
data_dict = {
    'embedding': face_info['embedding'],
    'kps': face_info['kps']
}

# Serialize the dictionary with pickle and save it to a file
with open('examples/face_info.pkl', 'wb') as pickle_file:
    pickle.dump(data_dict, pickle_file)

--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
import cv2
import torch
import numpy as np
from PIL import Image
import PIL
from diffusers.utils import load_image
from diffusers.models import ControlNetModel
import math
from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline

def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
    stickwidth = 4
    limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
    kps = np.array(kps)

    w, h = image_pil.size
    out_img = np.zeros([h, w, 3])

    for i in range(len(limbSeq)):
        index = limbSeq[i]
        color = color_list[index[0]]

        x = kps[index][:, 0]
        y = kps[index][:, 1]
        length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
        angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
        polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
                                   360, 1)
        out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
    out_img = (out_img * 0.6).astype(np.uint8)

    for idx_kp, kp in enumerate(kps):
        color = color_list[idx_kp]
        x, y = kp
        out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)

    out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
    return out_img_pil

def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):

    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image


if __name__ == "__main__":

    root = "aigc_models/InstantID"
    app = FaceAnalysis(name='antelopev2', root=root, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])

    app.prepare(ctx_id=0, det_size=(640, 640))

    # Path to InstantID models
    # face_adapter = f'./checkpoints/ip-adapter.bin'
    # controlnet_path = f'./checkpoints/ControlNetModel'
    face_adapter = "InstantID/checkpoints/ip-adapter.bin"
    controlnet_path = "InstantID/checkpoints/ControlNetModel"

    # Load pipeline
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

    # base_model_path = 'stabilityai/stable-diffusion-xl-base-1.0'
    base_model_path = "huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/462165984030d82259a11f4367a4eed129e94a7b"

    pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        torch_dtype=torch.float16,
    )
    pipe.cuda()
    pipe.load_ip_adapter_instantid(face_adapter)

    prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
    n_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"

    face_image = load_image("./examples/yann-lecun_resize.jpg")
    face_image = resize_img(face_image)

    face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
    # Bounding-box area = (x2 - x1) * (y2 - y1); only use the maximum face
    face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]
    face_emb = face_info['embedding']
    face_kps = draw_kps(face_image, face_info['kps'])

    pipe.set_ip_adapter_scale(0.8)
    image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        image_embeds=face_emb,
        image=face_kps,
        controlnet_conditioning_scale=0.8,
        num_inference_steps=30,
        guidance_scale=5,
    ).images[0]

    image.save('result.jpg')

--------------------------------------------------------------------------------
/infer_from_pkl.py:
--------------------------------------------------------------------------------
import cv2
import torch
import numpy as np
from PIL import Image
import pickle
from diffusers.utils import load_image
from diffusers.models import ControlNetModel

# from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps

import json
import os
from tqdm import tqdm

from insightface.app import FaceAnalysis


def face_extraction(image):
    app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    # image = load_image(image_path)

    # prepare face emb
    face_info = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))

    face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the maximum face
    face_emb = face_info['embedding']
    face_kps = face_info['kps'].tolist()
    face_bbox = face_info['bbox'].tolist()

    return face_emb, face_kps


def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):

    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image


if __name__ == "__main__":
    # user model

    base_model_path = ".../stable-diffusion-xl-base-1.0"
    face_adapter = ".../checkpoints/ip-adapter.bin"
    controlnet_path = ".../checkpoints/ControlNetModel/"

    # Load pipeline
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

    pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        torch_dtype=torch.float16,
    )
    pipe.cuda()
    pipe.load_ip_adapter_instantid(face_adapter)

    # prompt0 = "jpeg artifacts,asian man,early twenties,casual pose,low quality,blurry,poorly drawn,worst quality,photography,vintage,"
    # prompt1 = "photography aesthetic, elegant portrait photography, young man exuding classical beauty and sophistication, pale skin, delicate features, captivating gaze, slightly parted lips, contemporary fashion fused with vintage elements, naturally-lit setting, warm neutral backdrop, soft, diffused lighting highlighting subject's face, subtle shadowing, clear, high-resolution photo,film grain"
    # prompt2 = "portrait, young asian woman sleek, shoulder-length black hair gazes contemplatively towards camera, dark eyes slightlycast, lips closed neutral expression hints poise introspection, softness skin enhanced,"
    # # prompt2 = "portrait, young asian man, gazes contemplatively towards camera, dark eyes slightlycast, lips closed neutral expression hints poise introspection, softness skin enhanced,"
    # prompt3 = "photography aesthetic,Studio portrait photography, striking sexy female model, serene countenance, youthful appearance, light skin, elegant goth makeup emphasizing sharp eyebrows, pronounced eyelashes, dark lipstick, jewelry showcasing sparkling goth earrings,fashionably styled dark hair, exposing clear skin, backdrop complementing the subject's attire, natural light simulation enhancing soft shadows, subtle retouching ensuring smooth skin textures, harmonious color palette imbued with warm, muted tones,film grain,"
    # # prompt3 = "photography aesthetic,Studio portrait photography, striking male model, serene countenance, youthful appearance, light skin, elegant goth makeup emphasizing sharp eyebrows, pronounced eyelashes, dark lipstick,fashionably styled dark hair, exposing clear skin, backdrop complementing the subject's attire, natural light simulation enhancing soft shadows, subtle retouching ensuring smooth skin textures, harmonious color palette imbued with warm, muted tones,film grain,"
    # prompt4 = "a beautiful woman dressed in casual attire, looking energetic and vibrant, positive and upbeat atmosphere"
    # # prompt4 = "a man dressed in casual attire, looking energetic and vibrant, positive and upbeat atmosphere"
    # prompt5 = "A serene ambience,traditional Chinese aesthetic,a young woman,gentle expression,concentration on instrument,traditional Chinese guzheng,flowing pale blue and white hanfu with delicate floral accents,a backdrop of lush foliage,soft natural lighting,harmonious color palette of cool tones,ancient heritage,cultural reverence,timeless elegance,poised positioning amidst rocks,black hair adorned with classical hairpin,embodiment of classical Chinese music and beauty,tranquility amidst nature,subtlety in details,fine craftsmanship of the guzheng,ethereal atmosphere,cultural homage."
    # # prompt5 = "A serene ambience,traditional Chinese aesthetic,a young man,gentle expression,concentration on instrument,traditional Chinese guzheng,flowing pale blue and white hanfu with delicate floral accents,a backdrop of lush foliage,soft natural lighting,harmonious color palette of cool tones,ancient heritage,cultural reverence,timeless elegance,poised positioning amidst rocks,black hair adorned with classical hairpin,embodiment of classical Chinese music and beauty,tranquility amidst nature,subtlety in details,fine craftsmanship of the guzheng,ethereal atmosphere,cultural homage."

    # prompt0 = "a beautiful girl wearing casual shirt in a garden and smiling"

    # prompt0 = "a beautiful girl wearing casual shirt in a garden"

    prompt0 = "facing one side, wearing red sunglasses, a golden chain, and a green cap"

    # prompt0 = "A man on the red carpet"

    # prompt0 = "a man is standing here, his eyes sharp and full of spirit."

    # prompt0 = "A man holding a cup of coffee"

    # prompt0 = "A man wearing casual shirt in a garden"

    # prompt0 = "a beautiful woman dressed in casual attire, looking energetic and vibrant, positive and upbeat atmosphere"

    # prompt0 = "side face, wearing red sunglasses, a golden chain, and a green cap"

    # prompt0 = "a man is wearing a sunglasses"

    # prompt0 = "a girl looks very sad"

    # prompt0 = "a man is wearing a mask"

    n_prompt = "ng_deepnegative_v1_75t, (badhandv4:1.2), (worst quality:2), (low quality:2), (normal quality:2), lowres, bad anatomy, bad hands,((monochrome)), ((grayscale)) watermark, moles, large breast, big breast"

    face_image = load_image("")  # image path

    face_image = resize_img(face_image, size=(1024, 1024))
    # face_image = resize_img(face_image, size=(512, 512))

    face_emb, face_kps = face_extraction(image=face_image)
    face_kps = draw_kps(face_image, face_kps)
    # prompts = [prompt0, prompt1, prompt2, prompt3, prompt4, prompt5, prompt6, prompt7,prompt8,prompt9]
    prompts = [prompt0]

    pipe.set_ip_adapter_scale(0.8)

    print("================")

    # inference
    for i in range(1):
        print("-------------")
        image = pipe(
            prompt=prompts[i],
            negative_prompt=n_prompt,
            image_embeds=face_emb,
            image=face_kps,
            controlnet_conditioning_scale=0.5,
            num_inference_steps=32,
            guidance_scale=5,
        ).images[0]
        print("+++++++++++++++++")
        ind = len(os.listdir("./results/"))
        # image.save("./results/test_%d.jpg" % (i))
        image.save("./results/test_%d.jpg" % (ind))


--------------------------------------------------------------------------------
/ip_adapter/__pycache__/attention_processor.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/ip_adapter/__pycache__/attention_processor.cpython-310.pyc

--------------------------------------------------------------------------------
/ip_adapter/__pycache__/resampler.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/ip_adapter/__pycache__/resampler.cpython-310.pyc

--------------------------------------------------------------------------------
/ip_adapter/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/ip_adapter/__pycache__/utils.cpython-310.pyc

--------------------------------------------------------------------------------
/ip_adapter/attention_processor.py:
--------------------------------------------------------------------------------
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import xformers
    import xformers.ops
    xformers_available = True
except Exception as e:
    xformers_available = False


class RegionControler(object):
    def __init__(self) -> None:
        self.prompt_image_conditioning = []
region_control = RegionControler()


class AttnProcessor(nn.Module):
    r"""
    Default processor for performing attention-related computations.
    """
    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.
    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            the weight scale of image prompt.
        num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()

        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens

        self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
        self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # get encoder_hidden_states, ip_hidden_states
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        if xformers_available:
            hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # for ip-adapter
        ip_key = self.to_k_ip(ip_hidden_states)
        ip_value = self.to_v_ip(ip_hidden_states)

        ip_key = attn.head_to_batch_dim(ip_key)
        ip_value = attn.head_to_batch_dim(ip_value)

        if xformers_available:
            ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
        else:
            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
        ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

        # region control
        if len(region_control.prompt_image_conditioning) == 1:
            region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
            if region_mask is not None:
                h, w = region_mask.shape[:2]
ratio = (h * w / query.shape[1]) ** 0.5 190 | mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1]) 191 | else: 192 | mask = torch.ones_like(ip_hidden_states) 193 | ip_hidden_states = ip_hidden_states * mask 194 | 195 | hidden_states = hidden_states + self.scale * ip_hidden_states 196 | 197 | # linear proj 198 | hidden_states = attn.to_out[0](hidden_states) 199 | # dropout 200 | hidden_states = attn.to_out[1](hidden_states) 201 | 202 | if input_ndim == 4: 203 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 204 | 205 | if attn.residual_connection: 206 | hidden_states = hidden_states + residual 207 | 208 | hidden_states = hidden_states / attn.rescale_output_factor 209 | 210 | return hidden_states 211 | 212 | 213 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): 214 | # TODO attention_mask 215 | query = query.contiguous() 216 | key = key.contiguous() 217 | value = value.contiguous() 218 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) 219 | # hidden_states = self.reshape_batch_dim_to_heads(hidden_states) 220 | return hidden_states 221 | 222 | 223 | class AttnProcessor2_0(torch.nn.Module): 224 | r""" 225 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
226 | """ 227 | def __init__( 228 | self, 229 | hidden_size=None, 230 | cross_attention_dim=None, 231 | ): 232 | super().__init__() 233 | if not hasattr(F, "scaled_dot_product_attention"): 234 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0 or later.") 235 | 236 | def __call__( 237 | self, 238 | attn, 239 | hidden_states, 240 | encoder_hidden_states=None, 241 | attention_mask=None, 242 | temb=None, 243 | ): 244 | residual = hidden_states 245 | 246 | if attn.spatial_norm is not None: 247 | hidden_states = attn.spatial_norm(hidden_states, temb) 248 | 249 | input_ndim = hidden_states.ndim 250 | 251 | if input_ndim == 4: 252 | batch_size, channel, height, width = hidden_states.shape 253 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 254 | 255 | batch_size, sequence_length, _ = ( 256 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 257 | ) 258 | 259 | if attention_mask is not None: 260 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 261 | # scaled_dot_product_attention expects attention_mask shape to be 262 | # (batch, heads, source_length, target_length) 263 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 264 | 265 | if attn.group_norm is not None: 266 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 267 | 268 | query = attn.to_q(hidden_states) 269 | 270 | if encoder_hidden_states is None: 271 | encoder_hidden_states = hidden_states 272 | elif attn.norm_cross: 273 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 274 | 275 | key = attn.to_k(encoder_hidden_states) 276 | value = attn.to_v(encoder_hidden_states) 277 | 278 | inner_dim = key.shape[-1] 279 | head_dim = inner_dim // attn.heads 280 | 281 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 282 | 283 | key =
key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 284 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 285 | 286 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 287 | # TODO: add support for attn.scale when we move to Torch 2.1 288 | hidden_states = F.scaled_dot_product_attention( 289 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 290 | ) 291 | 292 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 293 | hidden_states = hidden_states.to(query.dtype) 294 | 295 | # linear proj 296 | hidden_states = attn.to_out[0](hidden_states) 297 | # dropout 298 | hidden_states = attn.to_out[1](hidden_states) 299 | 300 | if input_ndim == 4: 301 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 302 | 303 | if attn.residual_connection: 304 | hidden_states = hidden_states + residual 305 | 306 | hidden_states = hidden_states / attn.rescale_output_factor 307 | 308 | return hidden_states 309 | 398 | 399 | class IPAttnProcessor2_0(torch.nn.Module): 400 | r""" 401 | Attention processor for IP-Adapter for PyTorch 2.0. 402 | Args: 403 | hidden_size (`int`): 404 | The hidden size of the attention layer. 405 | cross_attention_dim (`int`): 406 | The number of channels in the `encoder_hidden_states`. 407 | scale (`float`, defaults to 1.0): 408 | The weight scale of the image prompt. 409 | num_tokens (`int`, defaults to 4; use 16 for IP-Adapter Plus): 410 | The context length of the image features.
411 | """ 412 | 413 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): 414 | super().__init__() 415 | 416 | if not hasattr(F, "scaled_dot_product_attention"): 417 | raise ImportError("IPAttnProcessor2_0 requires PyTorch 2.0; to use it, please upgrade PyTorch to 2.0 or later.") 418 | 419 | self.hidden_size = hidden_size 420 | self.cross_attention_dim = cross_attention_dim 421 | self.scale = scale 422 | self.num_tokens = num_tokens 423 | 424 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 425 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 426 | 427 | def __call__( 428 | self, 429 | attn, 430 | hidden_states, 431 | encoder_hidden_states=None, 432 | attention_mask=None, 433 | temb=None, 434 | ): 435 | residual = hidden_states 436 | 437 | if attn.spatial_norm is not None: 438 | hidden_states = attn.spatial_norm(hidden_states, temb) 439 | 440 | input_ndim = hidden_states.ndim 441 | 442 | if input_ndim == 4: 443 | batch_size, channel, height, width = hidden_states.shape 444 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 445 | 446 | batch_size, sequence_length, _ = ( 447 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 448 | ) 449 | 450 | if attention_mask is not None: 451 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 452 | # scaled_dot_product_attention expects attention_mask shape to be 453 | # (batch, heads, source_length, target_length) 454 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 455 | 456 | if attn.group_norm is not None: 457 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 458 | 459 | query = attn.to_q(hidden_states) 460 | 461 | if encoder_hidden_states is None: 462 | encoder_hidden_states = hidden_states 463 | else: 464 | # get encoder_hidden_states,
ip_hidden_states 465 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 466 | encoder_hidden_states, ip_hidden_states = ( 467 | encoder_hidden_states[:, :end_pos, :], 468 | encoder_hidden_states[:, end_pos:, :], 469 | ) 470 | if attn.norm_cross: 471 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 472 | 473 | key = attn.to_k(encoder_hidden_states) 474 | value = attn.to_v(encoder_hidden_states) 475 | 476 | inner_dim = key.shape[-1] 477 | head_dim = inner_dim // attn.heads 478 | 479 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 480 | 481 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 482 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 483 | 484 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 485 | # TODO: add support for attn.scale when we move to Torch 2.1 486 | hidden_states = F.scaled_dot_product_attention( 487 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 488 | ) 489 | 490 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 491 | hidden_states = hidden_states.to(query.dtype) 492 | 493 | # for ip-adapter 494 | ip_key = self.to_k_ip(ip_hidden_states) 495 | ip_value = self.to_v_ip(ip_hidden_states) 496 | 497 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 498 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 499 | 500 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 501 | # TODO: add support for attn.scale when we move to Torch 2.1 502 | ip_hidden_states = F.scaled_dot_product_attention( 503 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 504 | ) 505 | 506 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 507 | ip_hidden_states = ip_hidden_states.to(query.dtype) 508 | 509 | hidden_states = hidden_states + 
self.scale * ip_hidden_states 510 | 511 | # linear proj 512 | hidden_states = attn.to_out[0](hidden_states) 513 | # dropout 514 | hidden_states = attn.to_out[1](hidden_states) 515 | 516 | if input_ndim == 4: 517 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 518 | 519 | if attn.residual_connection: 520 | hidden_states = hidden_states + residual 521 | 522 | hidden_states = hidden_states / attn.rescale_output_factor 523 | 524 | return hidden_states 525 | -------------------------------------------------------------------------------- /ip_adapter/mlp.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | class MLPFeatureProjModel(nn.Module): 8 | """Two-layer MLP (with a final LayerNorm) that projects face ID embeddings to the cross-attention dimension.""" 9 | 10 | # torch.Size([1, 49, 1536]) -> torch.Size([1, 49, 768]) 11 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=1536): 12 | super().__init__() 13 | 14 | self.cross_attention_dim = cross_attention_dim 15 | 16 | self.proj = torch.nn.Sequential( 17 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim), 18 | torch.nn.GELU(), 19 | torch.nn.Linear(id_embeddings_dim, cross_attention_dim), 20 | torch.nn.LayerNorm(cross_attention_dim) 21 | ) 22 | 23 | def forward(self, id_embeds): 24 | feature_tokens = self.proj(id_embeds) 25 | 26 | return feature_tokens -------------------------------------------------------------------------------- /ip_adapter/resampler.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py 2 | import math 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | # FFN 9 | def 
FeedForward(dim, mult=4): 10 | inner_dim = int(dim * mult) 11 | return nn.Sequential( 12 | nn.LayerNorm(dim),
nn.LayerNorm(dim), 13 | nn.Linear(dim, inner_dim, bias=False), 14 | nn.GELU(), 15 | nn.Linear(inner_dim, dim, bias=False), 16 | ) 17 | 18 | 19 | def reshape_tensor(x, heads): 20 | bs, length, width = x.shape 21 | # (bs, length, width) --> (bs, length, n_heads, dim_per_head) 22 | x = x.view(bs, length, heads, -1) 23 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 24 | x = x.transpose(1, 2) 25 | # make the tensor contiguous; the shape stays (bs, n_heads, length, dim_per_head) 26 | x = x.reshape(bs, heads, length, -1) 27 | return x 28 | 29 | 30 | class PerceiverAttention(nn.Module): 31 | def __init__(self, *, dim, dim_head=64, heads=8): 32 | super().__init__() 33 | self.scale = dim_head**-0.5 34 | self.dim_head = dim_head 35 | self.heads = heads 36 | inner_dim = dim_head * heads 37 | 38 | self.norm1 = nn.LayerNorm(dim) 39 | self.norm2 = nn.LayerNorm(dim) 40 | 41 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 42 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 43 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 44 | 45 | 46 | def forward(self, x, latents): 47 | """ 48 | Args: 49 | x (torch.Tensor): image features 50 | shape (b, n1, D) 51 | latents (torch.Tensor): latent features 52 | shape (b, n2, D) 53 | """ 54 | x = self.norm1(x) 55 | latents = self.norm2(latents) 56 | 57 | b, l, _ = latents.shape 58 | 59 | q = self.to_q(latents) 60 | kv_input = torch.cat((x, latents), dim=-2) 61 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 62 | 63 | q = reshape_tensor(q, self.heads) 64 | k = reshape_tensor(k, self.heads) 65 | v = reshape_tensor(v, self.heads) 66 | 67 | # attention 68 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 69 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 70 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 71 | out = weight @ v 72 | 73 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 74 | 75 | return self.to_out(out) 76 | 77 | 78 | class
Resampler(nn.Module): 79 | def __init__( 80 | self, 81 | dim=1024, 82 | depth=8, 83 | dim_head=64, 84 | heads=16, 85 | num_queries=8, 86 | embedding_dim=768, 87 | output_dim=1024, 88 | ff_mult=4, 89 | ): 90 | super().__init__() 91 | 92 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5) 93 | 94 | self.proj_in = nn.Linear(embedding_dim, dim) 95 | 96 | self.proj_out = nn.Linear(dim, output_dim) 97 | self.norm_out = nn.LayerNorm(output_dim) 98 | 99 | self.layers = nn.ModuleList([]) 100 | for _ in range(depth): 101 | self.layers.append( 102 | nn.ModuleList( 103 | [ 104 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 105 | FeedForward(dim=dim, mult=ff_mult), 106 | ] 107 | ) 108 | ) 109 | 110 | def forward(self, x): 111 | 112 | latents = self.latents.repeat(x.size(0), 1, 1) 113 | 114 | x = self.proj_in(x) 115 | 116 | for attn, ff in self.layers: 117 | latents = attn(x, latents) + latents 118 | latents = ff(latents) + latents 119 | 120 | latents = self.proj_out(latents) 121 | return self.norm_out(latents) 122 | 123 | -------------------------------------------------------------------------------- /ip_adapter/utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | 3 | 4 | def is_torch2_available(): 5 | return hasattr(F, "scaled_dot_product_attention") 6 | -------------------------------------------------------------------------------- /pipeline_stable_diffusion_xl_instantid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The InstantX Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 17 | 18 | import cv2 19 | import math 20 | 21 | import numpy as np 22 | import PIL.Image 23 | import torch 24 | import torch.nn.functional as F 25 | 26 | from diffusers.image_processor import PipelineImageInput 27 | 28 | from diffusers.models import ControlNetModel 29 | 30 | from diffusers.utils import ( 31 | deprecate, 32 | logging, 33 | replace_example_docstring, 34 | ) 35 | from diffusers.utils.torch_utils import is_compiled_module, is_torch_version 36 | from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput 37 | 38 | from diffusers import StableDiffusionXLControlNetPipeline 39 | from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel 40 | from diffusers.utils.import_utils import is_xformers_available 41 | 42 | from ip_adapter.resampler import Resampler 43 | 44 | from ip_adapter.utils import is_torch2_available 45 | 46 | if is_torch2_available(): 47 | from ip_adapter.attention_processor import ( 48 | AttnProcessor2_0 as AttnProcessor, 49 | ) 50 | from ip_adapter.attention_processor import ( 51 | IPAttnProcessor2_0 as IPAttnProcessor, 52 | ) 53 | else: 54 | from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor 55 | 56 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 57 | 58 | 59 | EXAMPLE_DOC_STRING = """ 60 | Examples: 61 | ```py 62 | >>> # !pip install opencv-python transformers accelerate insightface 63 | >>> import diffusers 64 | >>> from 
diffusers.utils import load_image 65 | >>> from diffusers.models import ControlNetModel 66 | 67 | >>> import cv2 68 | >>> import torch 69 | >>> import numpy as np 70 | >>> from PIL import Image 71 | 72 | >>> from insightface.app import FaceAnalysis 73 | >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps 74 | 75 | >>> # download 'antelopev2' under ./models 76 | >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 77 | >>> app.prepare(ctx_id=0, det_size=(640, 640)) 78 | 79 | >>> # download models under ./checkpoints 80 | >>> face_adapter = f'./checkpoints/ip-adapter.bin' 81 | >>> controlnet_path = f'./checkpoints/ControlNetModel' 82 | 83 | >>> # load IdentityNet 84 | >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) 85 | 86 | >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained( 87 | ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16 88 | ... ) 89 | >>> pipe.cuda() 90 | 91 | >>> # load adapter 92 | >>> pipe.load_ip_adapter_instantid(face_adapter) 93 | 94 | >>> prompt = "analog film photo of a man. 
faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality" 95 | >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured" 96 | 97 | >>> # load a face image 98 | >>> face_image = load_image("your-example.jpg") 99 | 100 | >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1] 101 | >>> face_emb = face_info['embedding'] 102 | >>> face_kps = draw_kps(face_image, face_info['kps']) 103 | 104 | >>> pipe.set_ip_adapter_scale(0.8) 105 | 106 | >>> # generate image 107 | >>> image = pipe( 108 | ... prompt, negative_prompt=negative_prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8 109 | ... ).images[0] 110 | ``` 111 | 112 | """ 113 | 114 | def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]): 115 | 116 | stickwidth = 4 117 | limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) 118 | kps = np.array(kps) 119 | 120 | w, h = image_pil.size 121 | out_img = np.zeros([h, w, 3]) 122 | 123 | for i in range(len(limbSeq)): 124 | index = limbSeq[i] 125 | color = color_list[index[0]] 126 | 127 | x = kps[index][:, 0] 128 | y = kps[index][:, 1] 129 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 130 | angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) 131 | polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 132 | out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) 133 | out_img = (out_img * 0.6).astype(np.uint8) 134 | 135 | for idx_kp, kp in enumerate(kps): 136 | color =
color_list[idx_kp] 137 | x, y = kp 138 | out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) 139 | 140 | out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) 141 | return out_img_pil 142 | 143 | class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline): 144 | 145 | def cuda(self, dtype=torch.float16, use_xformers=False): 146 | self.to('cuda', dtype) 147 | 148 | if hasattr(self, 'image_proj_model'): 149 | self.image_proj_model.to(self.unet.device).to(self.unet.dtype) 150 | 151 | if use_xformers: 152 | if is_xformers_available(): 153 | import xformers 154 | from packaging import version 155 | 156 | xformers_version = version.parse(xformers.__version__) 157 | if xformers_version == version.parse("0.0.16"): 158 | logger.warning( 159 | "xFormers 0.0.16 cannot be used for training on some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 160 | ) 161 | self.enable_xformers_memory_efficient_attention() 162 | else: 163 | raise ValueError("xformers is not available.
Make sure it is installed correctly") 164 | 165 | def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5): 166 | self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens) 167 | self.set_ip_adapter(model_ckpt, num_tokens, scale) 168 | 169 | def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16): 170 | 171 | image_proj_model = Resampler( 172 | dim=1280, 173 | depth=4, 174 | dim_head=64, 175 | heads=20, 176 | num_queries=num_tokens, 177 | embedding_dim=image_emb_dim, 178 | output_dim=self.unet.config.cross_attention_dim, 179 | ff_mult=4, 180 | ) 181 | 182 | image_proj_model.eval() 183 | 184 | self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype) 185 | state_dict = torch.load(model_ckpt, map_location="cpu") 186 | if 'image_proj' in state_dict: 187 | state_dict = state_dict["image_proj"] 188 | self.image_proj_model.load_state_dict(state_dict) 189 | 190 | self.image_proj_model_in_features = image_emb_dim 191 | 192 | def set_ip_adapter(self, model_ckpt, num_tokens, scale): 193 | 194 | unet = self.unet 195 | attn_procs = {} 196 | for name in unet.attn_processors.keys(): 197 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 198 | if name.startswith("mid_block"): 199 | hidden_size = unet.config.block_out_channels[-1] 200 | elif name.startswith("up_blocks"): 201 | block_id = int(name[len("up_blocks.")]) 202 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 203 | elif name.startswith("down_blocks"): 204 | block_id = int(name[len("down_blocks.")]) 205 | hidden_size = unet.config.block_out_channels[block_id] 206 | if cross_attention_dim is None: 207 | attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype) 208 | else: 209 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, 210 | cross_attention_dim=cross_attention_dim, 211 | scale=scale, 212 | num_tokens=num_tokens).to(unet.device, 
dtype=unet.dtype) 213 | unet.set_attn_processor(attn_procs) 214 | 215 | state_dict = torch.load(model_ckpt, map_location="cpu") 216 | ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values()) 217 | if 'ip_adapter' in state_dict: 218 | state_dict = state_dict['ip_adapter'] 219 | ip_layers.load_state_dict(state_dict) 220 | 221 | def set_ip_adapter_scale(self, scale): 222 | unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet 223 | for attn_processor in unet.attn_processors.values(): 224 | if isinstance(attn_processor, IPAttnProcessor): 225 | attn_processor.scale = scale 226 | 227 | def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance): 228 | 229 | if isinstance(prompt_image_emb, torch.Tensor): 230 | prompt_image_emb = prompt_image_emb.clone().detach() 231 | else: 232 | prompt_image_emb = torch.tensor(prompt_image_emb) 233 | 234 | prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype) 235 | prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features]) 236 | 237 | if do_classifier_free_guidance: 238 | prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0) 239 | else: 240 | prompt_image_emb = torch.cat([prompt_image_emb], dim=0) 241 | 242 | prompt_image_emb = self.image_proj_model(prompt_image_emb) 243 | return prompt_image_emb 244 | 245 | @torch.no_grad() 246 | @replace_example_docstring(EXAMPLE_DOC_STRING) 247 | def __call__( 248 | self, 249 | prompt: Union[str, List[str]] = None, 250 | prompt_2: Optional[Union[str, List[str]]] = None, 251 | image: PipelineImageInput = None, 252 | height: Optional[int] = None, 253 | width: Optional[int] = None, 254 | num_inference_steps: int = 50, 255 | guidance_scale: float = 5.0, 256 | negative_prompt: Optional[Union[str, List[str]]] = None, 257 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 258 | num_images_per_prompt: Optional[int] = 1, 259 | eta: float = 0.0, 260 
| generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 261 | latents: Optional[torch.FloatTensor] = None, 262 | prompt_embeds: Optional[torch.FloatTensor] = None, 263 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 264 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 265 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 266 | image_embeds: Optional[torch.FloatTensor] = None, 267 | output_type: Optional[str] = "pil", 268 | return_dict: bool = True, 269 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 270 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0, 271 | guess_mode: bool = False, 272 | control_guidance_start: Union[float, List[float]] = 0.0, 273 | control_guidance_end: Union[float, List[float]] = 1.0, 274 | original_size: Tuple[int, int] = None, 275 | crops_coords_top_left: Tuple[int, int] = (0, 0), 276 | target_size: Tuple[int, int] = None, 277 | negative_original_size: Optional[Tuple[int, int]] = None, 278 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0), 279 | negative_target_size: Optional[Tuple[int, int]] = None, 280 | clip_skip: Optional[int] = None, 281 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 282 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 283 | **kwargs, 284 | ): 285 | r""" 286 | The call function to the pipeline for generation. 287 | 288 | Args: 289 | prompt (`str` or `List[str]`, *optional*): 290 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 291 | prompt_2 (`str` or `List[str]`, *optional*): 292 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 293 | used in both text-encoders. 
294 | image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`, 295 | `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): 296 | The ControlNet input condition to provide guidance to the `unet` for generation. If the type is 297 | specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be 298 | accepted as an image. The dimensions of the output image default to `image`'s dimensions. If height 299 | and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in 300 | `init`, images must be passed as a list such that each element of the list can be correctly batched for 301 | input to a single ControlNet. 302 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 303 | The height in pixels of the generated image. Anything below 512 pixels won't work well for 304 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 305 | and checkpoints that are not specifically fine-tuned on low resolutions. 306 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): 307 | The width in pixels of the generated image. Anything below 512 pixels won't work well for 308 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) 309 | and checkpoints that are not specifically fine-tuned on low resolutions. 310 | num_inference_steps (`int`, *optional*, defaults to 50): 311 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 312 | expense of slower inference.
313 | guidance_scale (`float`, *optional*, defaults to 5.0): 314 | A higher guidance scale value encourages the model to generate images closely linked to the text 315 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. 316 | negative_prompt (`str` or `List[str]`, *optional*): 317 | The prompt or prompts to guide what not to include in image generation. If not defined, you need to 318 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale <= 1`). 319 | negative_prompt_2 (`str` or `List[str]`, *optional*): 320 | The prompt or prompts to guide what not to include in image generation. This is sent to `tokenizer_2` 321 | and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. 322 | num_images_per_prompt (`int`, *optional*, defaults to 1): 323 | The number of images to generate per prompt. 324 | eta (`float`, *optional*, defaults to 0.0): 325 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 326 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. 327 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 328 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 329 | generation deterministic. 330 | latents (`torch.FloatTensor`, *optional*): 331 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image 332 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 333 | tensor is generated by sampling using the supplied random `generator`. 334 | prompt_embeds (`torch.FloatTensor`, *optional*): 335 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not 336 | provided, text embeddings are generated from the `prompt` input argument.
337 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 338 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 339 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. 340 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 341 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If 342 | not provided, pooled text embeddings are generated from the `prompt` input argument. 343 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 344 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt 345 | weighting). If not provided, pooled `negative_prompt_embeds` are generated from the `negative_prompt` input 346 | argument. 347 | image_embeds (`torch.FloatTensor`, *optional*): 348 | Pre-computed face ID embeddings (for example, an InsightFace identity feature). They are projected by `image_proj_model`, appended to the text embeddings for the UNet, and used as the ControlNet's cross-attention context. Although typed as optional, the pipeline cannot run without them. 349 | output_type (`str`, *optional*, defaults to `"pil"`): 350 | The output format of the generated image. Choose between `PIL.Image` and `np.array`. 351 | return_dict (`bool`, *optional*, defaults to `True`): 352 | Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead of a 353 | plain tuple. 354 | cross_attention_kwargs (`dict`, *optional*): 355 | A kwargs dictionary that, if specified, is passed along to the [`AttentionProcessor`] as defined in 356 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 357 | controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): 358 | The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added 359 | to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set 360 | the corresponding scale as a list.
361 | guess_mode (`bool`, *optional*, defaults to `False`): 362 | The ControlNet encoder tries to recognize the content of the input image even if you remove all 363 | prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. 364 | control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): 365 | The fraction of total steps (0.0 to 1.0) at which the ControlNet starts applying. 366 | control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): 367 | The fraction of total steps (0.0 to 1.0) at which the ControlNet stops applying. 368 | original_size (`Tuple[int]`, *optional*): 369 | If `original_size` is not the same as `target_size`, the image will appear to be down- or upsampled. 370 | If not specified, `original_size` defaults to the spatial size of the prepared ControlNet image, i.e. `(height, width)`. Part of SDXL's micro-conditioning as 371 | explained in section 2.2 of 372 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 373 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 374 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position 375 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting 376 | `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of 377 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 378 | target_size (`Tuple[int]`, *optional*): 379 | For most cases, `target_size` should be set to the desired height and width of the generated image. If 380 | not specified, it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in 381 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
382 | negative_original_size (`Tuple[int]`, *optional*): 383 | To negatively condition the generation process based on a specific image resolution. Only takes effect when `negative_target_size` is also given; otherwise the positive micro-conditioning is reused for the negative branch. Part of SDXL's 384 | micro-conditioning as explained in section 2.2 of 385 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 386 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 387 | negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 388 | To negatively condition the generation process based on specific crop coordinates. Part of SDXL's 389 | micro-conditioning as explained in section 2.2 of 390 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 391 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 392 | negative_target_size (`Tuple[int]`, *optional*): 393 | To negatively condition the generation process based on a target image resolution. It should be the same 394 | as the `target_size` for most cases, and only takes effect when `negative_original_size` is also given. Part of SDXL's micro-conditioning as explained in section 2.2 of 395 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more 396 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208. 397 | clip_skip (`int`, *optional*): 398 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that 399 | the output of the pre-final layer will be used for computing the prompt embeddings. 400 | callback_on_step_end (`Callable`, *optional*): 401 | A function that is called at the end of each denoising step during inference. The function is called 402 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 403 | callback_kwargs: Dict)`.
`callback_kwargs` will include a list of all tensors as specified by 404 | `callback_on_step_end_tensor_inputs`. 405 | callback_on_step_end_tensor_inputs (`List`, *optional*): 406 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 407 | will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the 408 | `._callback_tensor_inputs` attribute of your pipeline class. 409 | 410 | Examples: 411 | 412 | Returns: 413 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: 414 | If `return_dict` is `True`, [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] is returned, 415 | otherwise a `tuple` is returned containing the output images. 416 | """ 417 | 418 | callback = kwargs.pop("callback", None) 419 | callback_steps = kwargs.pop("callback_steps", None) 420 | 421 | if callback is not None: 422 | deprecate( 423 | "callback", 424 | "1.0.0", 425 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 426 | ) 427 | if callback_steps is not None: 428 | deprecate( 429 | "callback_steps", 430 | "1.0.0", 431 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", 432 | ) 433 | 434 | controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet 435 | 436 | # align format for control guidance 437 | if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): 438 | control_guidance_start = len(control_guidance_end) * [control_guidance_start] 439 | elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): 440 | control_guidance_end = len(control_guidance_start) * [control_guidance_end] 441 | elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): 442 | mult = len(controlnet.nets) if isinstance(controlnet,
MultiControlNetModel) else 1 443 | control_guidance_start, control_guidance_end = ( 444 | mult * [control_guidance_start], 445 | mult * [control_guidance_end], 446 | ) 447 | 448 | # 1. Check inputs. Raise error if not correct 449 | self.check_inputs( 450 | prompt, 451 | prompt_2, 452 | image, 453 | callback_steps, 454 | negative_prompt, 455 | negative_prompt_2, 456 | prompt_embeds, 457 | negative_prompt_embeds, 458 | pooled_prompt_embeds, 459 | negative_pooled_prompt_embeds, 460 | controlnet_conditioning_scale, 461 | control_guidance_start, 462 | control_guidance_end, 463 | callback_on_step_end_tensor_inputs, 464 | ) 465 | 466 | self._guidance_scale = guidance_scale 467 | self._clip_skip = clip_skip 468 | self._cross_attention_kwargs = cross_attention_kwargs 469 | 470 | # 2. Define call parameters 471 | if prompt is not None and isinstance(prompt, str): 472 | batch_size = 1 473 | elif prompt is not None and isinstance(prompt, list): 474 | batch_size = len(prompt) 475 | else: 476 | batch_size = prompt_embeds.shape[0] 477 | 478 | device = self._execution_device 479 | 480 | if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float): 481 | controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets) 482 | 483 | global_pool_conditions = ( 484 | controlnet.config.global_pool_conditions 485 | if isinstance(controlnet, ControlNetModel) 486 | else controlnet.nets[0].config.global_pool_conditions 487 | ) 488 | guess_mode = guess_mode or global_pool_conditions 489 | 490 | # 3.1 Encode input prompt 491 | text_encoder_lora_scale = ( 492 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None 493 | ) 494 | ( 495 | prompt_embeds, 496 | negative_prompt_embeds, 497 | pooled_prompt_embeds, 498 | negative_pooled_prompt_embeds, 499 | ) = self.encode_prompt( 500 | prompt, 501 | prompt_2, 502 | device, 503 | num_images_per_prompt, 504 | self.do_classifier_free_guidance, 505 
| negative_prompt, 506 | negative_prompt_2, 507 | prompt_embeds=prompt_embeds, 508 | negative_prompt_embeds=negative_prompt_embeds, 509 | pooled_prompt_embeds=pooled_prompt_embeds, 510 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 511 | lora_scale=text_encoder_lora_scale, 512 | clip_skip=self.clip_skip, 513 | ) 514 | 515 | # 3.2 Encode image prompt 516 | prompt_image_emb = self._encode_prompt_image_emb(image_embeds, 517 | device, 518 | self.unet.dtype, 519 | self.do_classifier_free_guidance) 520 | 521 | # 4. Prepare image 522 | if isinstance(controlnet, ControlNetModel): 523 | image = self.prepare_image( 524 | image=image, 525 | width=width, 526 | height=height, 527 | batch_size=batch_size * num_images_per_prompt, 528 | num_images_per_prompt=num_images_per_prompt, 529 | device=device, 530 | dtype=controlnet.dtype, 531 | do_classifier_free_guidance=self.do_classifier_free_guidance, 532 | guess_mode=guess_mode, 533 | ) 534 | height, width = image.shape[-2:] 535 | print("image: ", image.shape) 536 | elif isinstance(controlnet, MultiControlNetModel): 537 | images = [] 538 | 539 | for image_ in image: 540 | image_ = self.prepare_image( 541 | image=image_, 542 | width=width, 543 | height=height, 544 | batch_size=batch_size * num_images_per_prompt, 545 | num_images_per_prompt=num_images_per_prompt, 546 | device=device, 547 | dtype=controlnet.dtype, 548 | do_classifier_free_guidance=self.do_classifier_free_guidance, 549 | guess_mode=guess_mode, 550 | ) 551 | 552 | images.append(image_) 553 | 554 | image = images 555 | height, width = image[0].shape[-2:] 556 | else: 557 | assert False 558 | 559 | # 5. Prepare timesteps 560 | self.scheduler.set_timesteps(num_inference_steps, device=device) 561 | timesteps = self.scheduler.timesteps 562 | self._num_timesteps = len(timesteps) 563 | 564 | # 6. 
Prepare latent variables 565 | num_channels_latents = self.unet.config.in_channels 566 | latents = self.prepare_latents( 567 | batch_size * num_images_per_prompt, 568 | num_channels_latents, 569 | height, 570 | width, 571 | prompt_embeds.dtype, 572 | device, 573 | generator, 574 | latents, 575 | ) 576 | 577 | # 6.5 Optionally get Guidance Scale Embedding 578 | timestep_cond = None 579 | if self.unet.config.time_cond_proj_dim is not None: 580 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) 581 | timestep_cond = self.get_guidance_scale_embedding( 582 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim 583 | ).to(device=device, dtype=latents.dtype) 584 | 585 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 586 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 587 | 588 | # 7.1 Create tensor stating which controlnets to keep 589 | controlnet_keep = [] 590 | for i in range(len(timesteps)): 591 | keeps = [ 592 | 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) 593 | for s, e in zip(control_guidance_start, control_guidance_end) 594 | ] 595 | controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps) 596 | 597 | # 7.2 Prepare added time ids & embeddings 598 | if isinstance(image, list): 599 | original_size = original_size or image[0].shape[-2:] 600 | else: 601 | original_size = original_size or image.shape[-2:] 602 | target_size = target_size or (height, width) 603 | 604 | add_text_embeds = pooled_prompt_embeds 605 | if self.text_encoder_2 is None: 606 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) 607 | else: 608 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim 609 | 610 | add_time_ids = self._get_add_time_ids( 611 | original_size, 612 | crops_coords_top_left, 613 | target_size, 614 | dtype=prompt_embeds.dtype, 615 | 
text_encoder_projection_dim=text_encoder_projection_dim, 616 | ) 617 | 618 | if negative_original_size is not None and negative_target_size is not None: 619 | negative_add_time_ids = self._get_add_time_ids( 620 | negative_original_size, 621 | negative_crops_coords_top_left, 622 | negative_target_size, 623 | dtype=prompt_embeds.dtype, 624 | text_encoder_projection_dim=text_encoder_projection_dim, 625 | ) 626 | else: 627 | negative_add_time_ids = add_time_ids 628 | 629 | if self.do_classifier_free_guidance: 630 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 631 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 632 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) 633 | 634 | prompt_embeds = prompt_embeds.to(device) 635 | add_text_embeds = add_text_embeds.to(device) 636 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) 637 | encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1) 638 | 639 | # 8. 
Denoising loop 640 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 641 | is_unet_compiled = is_compiled_module(self.unet) 642 | is_controlnet_compiled = is_compiled_module(self.controlnet) 643 | is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") 644 | 645 | with self.progress_bar(total=num_inference_steps) as progress_bar: 646 | for i, t in enumerate(timesteps): 647 | # Relevant thread: 648 | # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 649 | if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1: 650 | torch._inductor.cudagraph_mark_step_begin() 651 | # expand the latents if we are doing classifier free guidance 652 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents 653 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 654 | 655 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 656 | 657 | # controlnet(s) inference 658 | if guess_mode and self.do_classifier_free_guidance: 659 | # Infer ControlNet only for the conditional batch. 
660 | control_model_input = latents 661 | control_model_input = self.scheduler.scale_model_input(control_model_input, t) 662 | controlnet_prompt_embeds = prompt_image_emb.chunk(2)[1] # the ControlNet cross-attends to the face embedding, so keep only its conditional half 663 | controlnet_added_cond_kwargs = { 664 | "text_embeds": add_text_embeds.chunk(2)[1], 665 | "time_ids": add_time_ids.chunk(2)[1], 666 | } 667 | else: 668 | control_model_input = latent_model_input 669 | controlnet_prompt_embeds = prompt_image_emb 670 | controlnet_added_cond_kwargs = added_cond_kwargs 671 | 672 | if isinstance(controlnet_keep[i], list): 673 | cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] 674 | else: 675 | controlnet_cond_scale = controlnet_conditioning_scale 676 | if isinstance(controlnet_cond_scale, list): 677 | controlnet_cond_scale = controlnet_cond_scale[0] 678 | cond_scale = controlnet_cond_scale * controlnet_keep[i] 679 | 680 | down_block_res_samples, mid_block_res_sample = self.controlnet( 681 | control_model_input, 682 | t, 683 | encoder_hidden_states=controlnet_prompt_embeds, 684 | controlnet_cond=image, 685 | conditioning_scale=cond_scale, 686 | guess_mode=guess_mode, 687 | added_cond_kwargs=controlnet_added_cond_kwargs, 688 | return_dict=False, 689 | ) 690 | 691 | if guess_mode and self.do_classifier_free_guidance: 692 | # Inferred ControlNet only for the conditional batch. 693 | # To apply the output of ControlNet to both the unconditional and conditional batches, 694 | # add 0 to the unconditional batch to keep it unchanged.
695 | down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] 696 | mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) 697 | 698 | # predict the noise residual 699 | noise_pred = self.unet( 700 | latent_model_input, 701 | t, 702 | encoder_hidden_states=encoder_hidden_states, 703 | timestep_cond=timestep_cond, 704 | cross_attention_kwargs=self.cross_attention_kwargs, 705 | down_block_additional_residuals=down_block_res_samples, 706 | mid_block_additional_residual=mid_block_res_sample, 707 | added_cond_kwargs=added_cond_kwargs, 708 | return_dict=False, 709 | )[0] 710 | 711 | # perform guidance 712 | if self.do_classifier_free_guidance: 713 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 714 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 715 | 716 | # compute the previous noisy sample x_t -> x_t-1 717 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 718 | 719 | if callback_on_step_end is not None: 720 | callback_kwargs = {} 721 | for k in callback_on_step_end_tensor_inputs: 722 | callback_kwargs[k] = locals()[k] 723 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 724 | 725 | latents = callback_outputs.pop("latents", latents) 726 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 727 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 728 | 729 | # call the callback, if provided 730 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 731 | progress_bar.update() 732 | if callback is not None and i % callback_steps == 0: 733 | step_idx = i // getattr(self.scheduler, "order", 1) 734 | callback(step_idx, t, latents) 735 | 736 | if not output_type == "latent": 737 | # make sure the VAE is in float32 mode, as it overflows in float16 738 | 
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast 739 | if needs_upcasting: 740 | self.upcast_vae() 741 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 742 | 743 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 744 | 745 | # cast back to fp16 if needed 746 | if needs_upcasting: 747 | self.vae.to(dtype=torch.float16) 748 | else: 749 | image = latents 750 | 751 | if not output_type == "latent": 752 | # apply watermark if available 753 | if self.watermark is not None: 754 | image = self.watermark.apply_watermark(image) 755 | 756 | image = self.image_processor.postprocess(image, output_type=output_type) 757 | 758 | # Offload all models 759 | self.maybe_free_model_hooks() 760 | 761 | if not return_dict: 762 | return (image,) 763 | 764 | return StableDiffusionXLPipelineOutput(images=image) -------------------------------------------------------------------------------- /train_instantId_sdxl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import random 4 | import argparse 5 | from pathlib import Path 6 | import json 7 | import itertools 8 | import time 9 | from datetime import datetime 10 | import shutil 11 | import torch 12 | import torch.nn.functional as F 13 | import numpy as np 14 | import math 15 | import cv2 16 | from torchvision import transforms 17 | from PIL import Image 18 | import PIL 19 | from transformers import CLIPImageProcessor 20 | from accelerate import Accelerator 21 | from accelerate.logging import get_logger 22 | from accelerate.utils import ProjectConfiguration 23 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, ControlNetModel 24 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection 25 | 26 | from ip_adapter.resampler import Resampler 27 | from ip_adapter.utils import is_torch2_available 28 
| 29 | if is_torch2_available(): 30 | from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor 31 | else: 32 | from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor 33 | 34 | 35 | # Draw the ControlNet conditioning image from the five facial keypoints: keypoints 0, 1, 3 and 4 are each joined to keypoint 2 by a colored ellipse, and every keypoint is marked with a filled circle on a black canvas. 36 | def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]): 37 | stickwidth = 4 38 | limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) 39 | kps = np.array(kps) 40 | 41 | w, h = image_pil.size 42 | out_img = np.zeros([h, w, 3]) 43 | 44 | for i in range(len(limbSeq)): 45 | index = limbSeq[i] 46 | color = color_list[index[0]] 47 | 48 | x = kps[index][:, 0] 49 | y = kps[index][:, 1] 50 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 51 | angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) 52 | polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 53 | 360, 1) 54 | out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) 55 | out_img = (out_img * 0.6).astype(np.uint8) 56 | 57 | for idx_kp, kp in enumerate(kps): 58 | color = color_list[idx_kp] 59 | x, y = kp 60 | out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) 61 | 62 | out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) 63 | return out_img_pil 64 | 65 | # Dataset that reads a JSON Lines annotation file (one JSON object per line). Each record gives the image file, its text caption (`additional_feature`), the face bounding box, the facial keypoint coordinates (`landmarks`), and the pre-extracted InsightFace feature file.
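# --- Illustrative sketch (hypothetical helper, not part of the original training script) ---
# MyDataset below reads its annotation file as JSON Lines, calling json.loads once per line,
# and __getitem__ indexes the keys listed in REQUIRED_KEYS. This helper does the same per-line
# parsing up front and reports a malformed record with its line number, instead of failing
# later inside a DataLoader worker. The five-landmark check follows from draw_kps above:
# its default color_list has five entries and limbSeq references keypoints 0 to 4.
import json  # already imported at the top of the script; repeated so the sketch stands alone

REQUIRED_KEYS = ("image_file", "additional_feature", "bbox", "landmarks", "insightface_feature_file")


def load_annotation_lines(lines):
    """Parse JSON Lines records and check the keys that MyDataset.__getitem__ reads."""
    records = []
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # skip blank lines; json.loads would raise on them
        record = json.loads(line)
        missing = [key for key in REQUIRED_KEYS if key not in record]
        if missing:
            raise KeyError(f"line {line_no}: missing keys {missing}")
        if len(record["landmarks"]) != 5:
            raise ValueError(f"line {line_no}: expected 5 landmarks, got {len(record['landmarks'])}")
        records.append(record)
    return records


# Usage (hypothetical file name):
#     with open("train.jsonl", encoding="utf-8") as f:
#         records = load_annotation_lines(f)
# --- end of sketch ---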
66 | class MyDataset(torch.utils.data.Dataset): 67 | 68 | def __init__(self, json_file, tokenizer, tokenizer_2, size=1024, center_crop=True, 69 | t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""): 70 | super().__init__() 71 | 72 | self.tokenizer = tokenizer 73 | self.tokenizer_2 = tokenizer_2 74 | self.size = size 75 | self.center_crop = center_crop 76 | self.i_drop_rate = i_drop_rate 77 | self.t_drop_rate = t_drop_rate 78 | self.ti_drop_rate = ti_drop_rate 79 | self.image_root_path = image_root_path 80 | 81 | self.data = [] 82 | with open(json_file, 'r') as f: 83 | for line in f: 84 | self.data.append(json.loads(line)) 85 | 86 | self.image_transforms = transforms.Compose( 87 | [ 88 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 89 | transforms.ToTensor(), 90 | transforms.Normalize([0.5], [0.5]), 91 | ] 92 | ) 93 | 94 | self.conditioning_image_transforms = transforms.Compose( 95 | [ 96 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 97 | transforms.ToTensor(), 98 | ] 99 | ) 100 | 101 | self.clip_image_processor = CLIPImageProcessor() 102 | 103 | def __getitem__(self, idx): 104 | item = self.data[idx] 105 | image_file = item["image_file"] 106 | text = item["additional_feature"] 107 | bbox = item['bbox'] 108 | landmarks = item['landmarks'] 109 | feature_file = item["insightface_feature_file"] 110 | 111 | # read image 112 | raw_image = Image.open(os.path.join(self.image_root_path, image_file)) 113 | # draw keypoints 114 | kps_image = draw_kps(raw_image.convert("RGB"), landmarks) 115 | 116 | # original size 117 | original_width, original_height = raw_image.size 118 | original_size = torch.tensor([original_height, original_width]) 119 | 120 | # transform raw_image and kps_image 121 | image_tensor = self.image_transforms(raw_image.convert("RGB")) 122 | kps_image_tensor = self.conditioning_image_transforms(kps_image) 123 | 124 | # random crop 125 | delta_h = 
image_tensor.shape[1] - self.size 126 | delta_w = image_tensor.shape[2] - self.size 127 | assert not all([delta_h, delta_w]) 128 | 129 | if self.center_crop: 130 | top = delta_h // 2 131 | left = delta_w // 2 132 | else: 133 | top = np.random.randint(0, delta_h // 2 + 1) # random top crop 134 | # top = np.random.randint(0, delta_h + 1) # random crop 135 | left = np.random.randint(0, delta_w + 1) # random crop 136 | 137 | # The image and kps_image must follow the same cropping to ensure that the facial coordinates correspond correctly. 138 | image = transforms.functional.crop( 139 | image_tensor, top=top, left=left, height=self.size, width=self.size 140 | ) 141 | kps_image = transforms.functional.crop( 142 | kps_image_tensor, top=top, left=left, height=self.size, width=self.size 143 | ) 144 | 145 | crop_coords_top_left = torch.tensor([top, left]) 146 | 147 | # load face feature 148 | # face_id_embed = torch.load(os.path.join(self.image_root_path, feature_file), map_location="cpu") 149 | face_id_embed = np.load(os.path.join(self.image_root_path, feature_file)) 150 | face_id_embed = torch.from_numpy(face_id_embed) 151 | face_id_embed = face_id_embed.reshape(1, -1) 152 | 153 | # set cfg drop rate 154 | drop_feature_embed = 0 155 | drop_text_embed = 0 156 | rand_num = random.random() 157 | if rand_num < self.i_drop_rate: 158 | drop_feature_embed = 1 159 | elif rand_num < (self.i_drop_rate + self.t_drop_rate): 160 | drop_text_embed = 1 161 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate): 162 | drop_text_embed = 1 163 | drop_feature_embed = 1 164 | 165 | # CFG process 166 | if drop_text_embed: 167 | text = "" 168 | if drop_feature_embed: 169 | face_id_embed = torch.zeros_like(face_id_embed) 170 | 171 | # get text and tokenize 172 | text_input_ids = self.tokenizer( 173 | text, 174 | max_length=self.tokenizer.model_max_length, 175 | padding="max_length", 176 | truncation=True, 177 | return_tensors="pt" 178 | ).input_ids 179 | 180 | 
text_input_ids_2 = self.tokenizer_2( 181 | text, 182 | max_length=self.tokenizer_2.model_max_length, 183 | padding="max_length", 184 | truncation=True, 185 | return_tensors="pt" 186 | ).input_ids 187 | 188 | return { 189 | "image": image, 190 | "kps_image": kps_image, 191 | "text_input_ids": text_input_ids, 192 | "text_input_ids_2": text_input_ids_2, 193 | "face_id_embed": face_id_embed, 194 | "original_size": original_size, 195 | "crop_coords_top_left": crop_coords_top_left, 196 | "target_size": torch.tensor([self.size, self.size]), 197 | } 198 | 199 | def __len__(self): 200 | return len(self.data) 201 | 202 | 203 | def collate_fn(data): 204 | images = torch.stack([example["image"] for example in data]) 205 | kps_images = torch.stack([example["kps_image"] for example in data]) 206 | 207 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) 208 | text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0) 209 | face_id_embed = torch.stack([example["face_id_embed"] for example in data]) 210 | original_size = torch.stack([example["original_size"] for example in data]) 211 | crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data]) 212 | target_size = torch.stack([example["target_size"] for example in data]) 213 | 214 | return { 215 | "images": images, 216 | "kps_images": kps_images, 217 | "text_input_ids": text_input_ids, 218 | "text_input_ids_2": text_input_ids_2, 219 | "face_id_embed": face_id_embed, 220 | "original_size": original_size, 221 | "crop_coords_top_left": crop_coords_top_left, 222 | "target_size": target_size, 223 | } 224 | 225 | 226 | class InstantIDAdapter(torch.nn.Module): 227 | """InstantIDAdapter""" 228 | def __init__(self, unet, controlnet, feature_proj_model, adapter_modules, ckpt_path=None): 229 | super().__init__() 230 | self.unet = unet 231 | self.controlnet = controlnet 232 | self.feature_proj_model = feature_proj_model 233 | self.adapter_modules = 
adapter_modules 234 | if ckpt_path is not None: 235 | self.load_from_checkpoint(ckpt_path) 236 | 237 | def forward(self,noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, feature_embeds, controlnet_image): 238 | face_embedding = self.feature_proj_model(feature_embeds) 239 | encoder_hidden_states = torch.cat([encoder_hidden_states, face_embedding], dim=1) 240 | # ControlNet conditioning. 241 | down_block_res_samples, mid_block_res_sample = self.controlnet( 242 | noisy_latents, 243 | timesteps, 244 | encoder_hidden_states=face_embedding, # Insightface feature 245 | added_cond_kwargs=unet_added_cond_kwargs, 246 | controlnet_cond=controlnet_image, # keypoints image 247 | return_dict=False, 248 | ) 249 | # Predict the noise residual. 250 | noise_pred = self.unet( 251 | noisy_latents, 252 | timesteps, 253 | encoder_hidden_states=encoder_hidden_states, 254 | added_cond_kwargs=unet_added_cond_kwargs, 255 | down_block_additional_residuals=[sample for sample in down_block_res_samples], 256 | mid_block_additional_residual=mid_block_res_sample, 257 | ).sample 258 | 259 | return noise_pred 260 | 261 | def load_from_checkpoint(self, ckpt_path: str): 262 | # Calculate original checksums 263 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.feature_proj_model.parameters()])) 264 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 265 | 266 | state_dict = torch.load(ckpt_path, map_location="cpu") 267 | 268 | # Check if 'latents' exists in both the saved state_dict and the current model's state_dict 269 | strict_load_feature_proj_model = True 270 | if "latents" in state_dict["image_proj"] and "latents" in self.feature_proj_model.state_dict(): 271 | # Check if the shapes are mismatched 272 | if state_dict["image_proj"]["latents"].shape != self.feature_proj_model.state_dict()["latents"].shape: 273 | print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not 
match.") 274 | print("Removing 'latents' from checkpoint and loading the rest of the weights.") 275 | del state_dict["image_proj"]["latents"] 276 | strict_load_feature_proj_model = False 277 | 278 | # Load state dict for feature_proj_model and adapter_modules 279 | self.feature_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_feature_proj_model) 280 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True) 281 | 282 | # Calculate new checksums 283 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.feature_proj_model.parameters()])) 284 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()])) 285 | 286 | # Verify if the weights have changed 287 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of feature_proj_model did not change!" 288 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!" 289 | 290 | print(f"Successfully loaded weights from checkpoint {ckpt_path}") 291 | 292 | 293 | def parse_args(): 294 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 295 | parser.add_argument( 296 | "--pretrained_model_name_or_path", 297 | type=str, 298 | default=None, 299 | required=True, 300 | help="Path to pretrained model or model identifier from huggingface.co/models.", 301 | ) 302 | parser.add_argument( 303 | "--pretrained_ip_adapter_path", 304 | type=str, 305 | default=None, 306 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.", 307 | ) 308 | parser.add_argument( 309 | "--controlnet_model_name_or_path", 310 | type=str, 311 | default=None, 312 | help="Path to pretrained controlnet model. 
If not specified, weights are initialized from unet.", 313 | ) 314 | 315 | parser.add_argument( 316 | "--num_tokens", 317 | type=int, 318 | default=16, 319 | help="Number of query tokens the Resampler produces from the face ID embedding.", 320 | ) 321 | parser.add_argument( 322 | "--checkpoints_total_limit", 323 | type=int, 324 | default=1, 325 | help=( 326 | "Maximum number of checkpoints to keep; the oldest ones are deleted before a new one is saved." 327 | ), 328 | ) 329 | parser.add_argument( 330 | "--data_json_file", 331 | type=str, 332 | default=None, 333 | required=True, 334 | help="Path to the training data JSON file.", 335 | ) 336 | parser.add_argument( 337 | "--data_root_path", 338 | type=str, 339 | default="", 340 | required=False, 341 | help="Root directory prepended to the image and feature paths listed in the JSON file.", 342 | ) 343 | parser.add_argument('--clip_proc_mode', 344 | choices=["seg_align", "seg_crop", "orig_align", "orig_crop", "seg_align_pad", 345 | "orig_align_pad"], 346 | default="orig_crop", 347 | help='The mode to preprocess clip image encoder input.') 348 | 349 | parser.add_argument( 350 | "--image_encoder_path", 351 | type=str, 352 | default=None, 353 | required=True, 354 | help="Path to CLIP image encoder", 355 | ) 356 | parser.add_argument( 357 | "--center_crop", 358 | default=False, 359 | action="store_true", 360 | help=( 361 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 362 | " cropped. The images will be resized to the resolution first before cropping." 363 | ), 364 | ) 365 | parser.add_argument( 366 | "--output_dir", 367 | type=str, 368 | default="sd-ip_adapter", 369 | help="The output directory where the model predictions and checkpoints will be written.", 370 | ) 371 | parser.add_argument( 372 | "--logging_dir", 373 | type=str, 374 | default="logs", 375 | help=( 376 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory, relative to `output_dir`. Defaults to" 377 | " *output_dir/logs*."
378 | ), 379 | ) 380 | parser.add_argument( 381 | "--resolution", 382 | type=int, 383 | default=512, 384 | help=( 385 | "Side length of the square training crop. Images are resized so the shorter side matches it, then cropped." 386 | ), 387 | ) 388 | parser.add_argument( 389 | "--learning_rate", 390 | type=float, 391 | default=1e-4, 392 | help="Learning rate to use.", 393 | ) 394 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.") 395 | parser.add_argument("--num_train_epochs", type=int, default=100) 396 | parser.add_argument( 397 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader." 398 | ) 399 | parser.add_argument( 400 | "--dataloader_num_workers", 401 | type=int, 402 | default=0, 403 | help=( 404 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 405 | ), 406 | ) 407 | parser.add_argument( 408 | "--save_steps", 409 | type=int, 410 | default=2000, 411 | help=( 412 | "Save a checkpoint of the training state every X updates" 413 | ), 414 | ) 415 | parser.add_argument( 416 | "--mixed_precision", 417 | type=str, 418 | default=None, 419 | choices=["no", "fp16", "bf16"], 420 | help=( 421 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 422 | " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the" 423 | " flag passed with the `accelerate launch` command. Use this argument to override the accelerate config." 424 | ), 425 | ) 426 | parser.add_argument( 427 | "--report_to", 428 | type=str, 429 | default="tensorboard", 430 | help=( 431 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 432 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
433 | ), 434 | ) 435 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 436 | parser.add_argument("--noise_offset", type=float, default=None, help="Scale of the per-channel offset noise added to the sampled noise; disabled if unset.") 437 | 438 | args = parser.parse_args() 439 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 440 | if env_local_rank != -1 and env_local_rank != args.local_rank: 441 | args.local_rank = env_local_rank 442 | 443 | return args 444 | 445 | 446 | def main(): 447 | args = parse_args() 448 | logging_dir = Path(args.output_dir, args.logging_dir) 449 | 450 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 451 | 452 | accelerator = Accelerator( 453 | mixed_precision=args.mixed_precision, 454 | log_with=args.report_to, 455 | project_config=accelerator_project_config, 456 | ) 457 | 458 | num_devices = accelerator.num_processes 459 | 460 | if accelerator.is_main_process: 461 | if args.output_dir is not None: 462 | os.makedirs(args.output_dir, exist_ok=True) 463 | 464 | # Load scheduler, tokenizer and models.
465 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") 466 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer") 467 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder") 468 | tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2") 469 | text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2") 470 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae") 471 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet") 472 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path) 473 | if args.controlnet_model_name_or_path: 474 | print("Loading existing controlnet weights") 475 | controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) 476 | else: 477 | print("Initializing controlnet weights from unet") 478 | controlnet = ControlNetModel.from_unet(unet) 479 | 480 | # freeze parameters of models to save more memory 481 | unet.requires_grad_(False) 482 | vae.requires_grad_(False) 483 | text_encoder.requires_grad_(False) 484 | text_encoder_2.requires_grad_(False) 485 | image_encoder.requires_grad_(False) 486 | controlnet.requires_grad_(True) 487 | controlnet.train() 488 | 489 | # ip-adapter: Resampler that turns the 512-d InsightFace embedding into num_tokens tokens 490 | num_tokens = args.num_tokens 491 | 492 | feature_proj_model = Resampler( 493 | dim=1280, 494 | depth=4, 495 | dim_head=64, 496 | heads=20, 497 | num_queries=num_tokens, 498 | embedding_dim=512, 499 | output_dim=unet.config.cross_attention_dim, 500 | ff_mult=4, 501 | ) 502 | 503 | # init adapter modules 504 | attn_procs = {} 505 | unet_sd = unet.state_dict() 506 | for name in unet.attn_processors.keys(): 507 | cross_attention_dim = None if
name.endswith("attn1.processor") else unet.config.cross_attention_dim 508 | if name.startswith("mid_block"): 509 | hidden_size = unet.config.block_out_channels[-1] 510 | elif name.startswith("up_blocks"): 511 | block_id = int(name[len("up_blocks.")]) 512 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 513 | elif name.startswith("down_blocks"): 514 | block_id = int(name[len("down_blocks.")]) 515 | hidden_size = unet.config.block_out_channels[block_id] 516 | if cross_attention_dim is None: 517 | attn_procs[name] = AttnProcessor() 518 | else: 519 | layer_name = name.split(".processor")[0] 520 | weights = { 521 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"], 522 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"], 523 | } 524 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens) 525 | attn_procs[name].load_state_dict(weights) 526 | unet.set_attn_processor(attn_procs) 527 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values()) 528 | # Instantiate InstantIDAdapter from pretrained model or from scratch. 529 | ip_adapter = InstantIDAdapter(unet, controlnet, feature_proj_model, adapter_modules, args.pretrained_ip_adapter_path) 530 | 531 | # Register a hook function to process the state of a specific module before saving. 532 | def save_model_hook(models, weights, output_dir): 533 | if accelerator.is_main_process: 534 | # find instance of InstantIDAdapter Model. 535 | for i, model_instance in enumerate(models): 536 | if isinstance(model_instance, InstantIDAdapter): 537 | # When saving a checkpoint, only save the ip-adapter and image_proj, do not save the unet. 
538 | ip_adapter_state = { 539 | 'image_proj': model_instance.feature_proj_model.state_dict(), 540 | 'ip_adapter': model_instance.adapter_modules.state_dict(), 541 | } 542 | torch.save(ip_adapter_state, os.path.join(output_dir, 'pytorch_model.bin')) 543 | print(f"IP-Adapter Model weights saved in {os.path.join(output_dir, 'pytorch_model.bin')}") 544 | # Save controlnet separately. 545 | sub_dir = "controlnet" 546 | model_instance.controlnet.save_pretrained(os.path.join(output_dir, sub_dir)) 547 | print(f"Controlnet weights saved in {os.path.join(output_dir, sub_dir)}") 548 | # Remove the corresponding weights from the weights list because they have been saved separately. 549 | # Do not remove the corresponding model, otherwise the model cannot be saved 550 | # from the second epoch onward. 551 | weights.pop(i) 552 | break 553 | 554 | def load_model_hook(models, input_dir): 555 | # find instance of InstantIDAdapter Model. 556 | while len(models) > 0: 557 | model_instance = models.pop() 558 | if isinstance(model_instance, InstantIDAdapter): 559 | ip_adapter_path = os.path.join(input_dir, 'pytorch_model.bin') 560 | if os.path.exists(ip_adapter_path): 561 | ip_adapter_state = torch.load(ip_adapter_path, map_location="cpu") 562 | model_instance.feature_proj_model.load_state_dict(ip_adapter_state['image_proj']) 563 | model_instance.adapter_modules.load_state_dict(ip_adapter_state['ip_adapter']) 564 | sub_dir = "controlnet" # from_pretrained is a classmethod, so copy the loaded weights into the existing module 565 | model_instance.controlnet.load_state_dict(ControlNetModel.from_pretrained(os.path.join(input_dir, sub_dir)).state_dict()) 566 | print(f"Model weights loaded from {ip_adapter_path}") 567 | else: 568 | print(f"No saved weights found at {ip_adapter_path}") 569 | 570 | 571 | # Register hook functions for saving and loading.
572 | accelerator.register_save_state_pre_hook(save_model_hook) 573 | accelerator.register_load_state_pre_hook(load_model_hook) 574 | 575 | weight_dtype = torch.float32 576 | if accelerator.mixed_precision == "fp16": 577 | weight_dtype = torch.float16 578 | elif accelerator.mixed_precision == "bf16": 579 | weight_dtype = torch.bfloat16 580 | # unet.to(accelerator.device, dtype=weight_dtype) # error 581 | vae.to(accelerator.device) # use fp32 582 | text_encoder.to(accelerator.device, dtype=weight_dtype) 583 | text_encoder_2.to(accelerator.device, dtype=weight_dtype) 584 | image_encoder.to(accelerator.device, dtype=weight_dtype) 585 | # controlnet.to(accelerator.device, dtype=weight_dtype) # error 586 | controlnet.to(accelerator.device) 587 | 588 | # trainable params 589 | params_to_opt = itertools.chain(ip_adapter.feature_proj_model.parameters(), 590 | ip_adapter.adapter_modules.parameters(), 591 | ip_adapter.controlnet.parameters()) 592 | 593 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay) 594 | 595 | # dataloader 596 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, tokenizer_2=tokenizer_2, size=args.resolution, 597 | center_crop=args.center_crop, image_root_path=args.data_root_path) 598 | total_data_size = len(train_dataset) 599 | 600 | train_dataloader = torch.utils.data.DataLoader( 601 | train_dataset, 602 | shuffle=True, 603 | collate_fn=collate_fn, 604 | batch_size=args.train_batch_size, 605 | num_workers=args.dataloader_num_workers, 606 | ) 607 | 608 | # Prepare everything with our `accelerator`. 
609 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader) 610 | 611 | # # Restore checkpoints 612 | # checkpoint_folders = [folder for folder in os.listdir(args.output_dir) if folder.startswith('checkpoint-')] 613 | # if checkpoint_folders: 614 | # # Extract step numbers from all checkpoints and find the maximum step number 615 | # global_step = max(int(folder.split('-')[-1]) for folder in checkpoint_folders if folder.split('-')[-1].isdigit()) 616 | # checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 617 | # # Load the checkpoint 618 | # accelerator.load_state(checkpoint_path) 619 | # else: 620 | # global_step = 0 621 | # print("No checkpoint folders found.") 622 | global_step = 0 623 | # Calculate steps per epoch and the current epoch and its step number 624 | # steps_per_epoch = total_data_size // (args.train_batch_size * num_devices) 625 | # current_epoch = global_step // steps_per_epoch 626 | # current_step_in_epoch = global_step % steps_per_epoch 627 | 628 | # Training loop 629 | for epoch in range(0, args.num_train_epochs): 630 | begin = time.perf_counter() 631 | for step, batch in enumerate(train_dataloader): 632 | load_data_time = time.perf_counter() - begin 633 | with accelerator.accumulate(ip_adapter): 634 | # Convert images to latent space 635 | with torch.no_grad(): 636 | # vae of sdxl should use fp32 637 | latents = vae.encode( 638 | batch["images"].to(accelerator.device, dtype=torch.float32)).latent_dist.sample() 639 | latents = latents * vae.config.scaling_factor 640 | latents = latents.to(accelerator.device, dtype=weight_dtype) 641 | 642 | # Sample noise that we'll add to the latents 643 | noise = torch.randn_like(latents) 644 | if args.noise_offset: 645 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise 646 | noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to( 647 | accelerator.device, dtype=weight_dtype) 648 | 649 | bsz = 
latents.shape[0] 650 | # Sample a random timestep for each image 651 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 652 | timesteps = timesteps.long() 653 | 654 | # Add noise to the latents according to the noise magnitude at each timestep 655 | # (this is the forward diffusion process) 656 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 657 | 658 | # face ID embeddings (already zeroed in the dataset when dropped for CFG) and keypoint images 659 | feat_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype) 660 | kps_images = batch["kps_images"].to(accelerator.device, dtype=weight_dtype) 661 | 662 | # for other experiments 663 | # clip_images = [] 664 | # for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]): 665 | # if drop_image_embed == 1: 666 | # clip_images.append(torch.zeros_like(clip_image)) 667 | # else: 668 | # clip_images.append(clip_image) 669 | # clip_images = torch.stack(clip_images, dim=0) 670 | # with torch.no_grad(): 671 | # image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), 672 | # output_hidden_states=True).hidden_states[-2] 673 | 674 | with torch.no_grad(): 675 | encoder_output = text_encoder(batch['text_input_ids'].to(accelerator.device), output_hidden_states=True) 676 | text_embeds = encoder_output.hidden_states[-2] 677 | encoder_output_2 = text_encoder_2(batch['text_input_ids_2'].to(accelerator.device), output_hidden_states=True) 678 | pooled_text_embeds = encoder_output_2[0] 679 | text_embeds_2 = encoder_output_2.hidden_states[-2] 680 | text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1) # concatenate both encoders along the feature dim 681 | 682 | # SDXL micro-conditioning: original size, crop top-left, target size 683 | add_time_ids = [ 684 | batch["original_size"].to(accelerator.device), 685 | batch["crop_coords_top_left"].to(accelerator.device), 686 | batch["target_size"].to(accelerator.device), 687 | ] 688 | add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype) 689 |
unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids} 690 | 691 | noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, feat_embeds, kps_images) 692 | 693 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") 694 | 695 | # Gather the losses across all processes for logging (if we use distributed training). 696 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item() 697 | 698 | # Backpropagate 699 | accelerator.backward(loss) 700 | optimizer.step() 701 | optimizer.zero_grad() 702 | 703 | now = datetime.now() 704 | formatted_time = now.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3] 705 | if accelerator.is_main_process and step % 10 == 0: 706 | print("[{}]: Epoch {}, global_step {}, step {}, data_time: {}, time: {}, step_loss: {}".format( 707 | formatted_time, epoch, global_step, step, load_data_time, time.perf_counter() - begin, 708 | avg_loss)) 709 | 710 | global_step += 1 711 | if accelerator.is_main_process and global_step % args.save_steps == 0: 712 | # before saving state, check if this save would set us over the `checkpoints_total_limit` 713 | if args.checkpoints_total_limit is not None: 714 | checkpoints = os.listdir(args.output_dir) 715 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] 716 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) 717 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints 718 | if len(checkpoints) >= args.checkpoints_total_limit: 719 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 720 | removing_checkpoints = checkpoints[0:num_to_remove] 721 | print( 722 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints") 723 | print(f"removing checkpoints: {', '.join(removing_checkpoints)}") 724 | 725 | for removing_checkpoint in removing_checkpoints: 726 | removing_checkpoint = 
os.path.join(args.output_dir, removing_checkpoint) 727 | shutil.rmtree(removing_checkpoint) 728 | 729 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 730 | accelerator.save_state(save_path) 731 | 732 | begin = time.perf_counter() 733 | 734 | 735 | if __name__ == "__main__": 736 | main() 737 | -------------------------------------------------------------------------------- /train_instantId_sdxl.sh: -------------------------------------------------------------------------------- 1 | # SDXL Model 2 | export MODEL_NAME=.../stable-diffusion-xl-base-1.0 3 | # CLIP image encoder 4 | export ENCODER_NAME=".../image_encoder" 5 | # Pretrained InstantID weights (IP-Adapter and ControlNet) 6 | export ADAPTOR_NAME=".../checkpoints/ip-adapter.bin" 7 | export CONTROLNET_NAME=".../checkpoints/ControlNetModel" 8 | 9 | # Each line of this JSON Lines file has the format: 10 | # {"file_name": "/data/train_data/images_part0/84634599103.jpg", "additional_feature": "myolv1,a man with glasses and a 11 | # tie on posing for a picture in front of a window with a building in the background, Andrew Law, johnson ting, a picture, 12 | # mannerism", "bbox": [-31.329412311315536, 160.6865997314453, 496.19240215420723, 688.1674156188965], 13 | # "landmarks": [[133.046875, 318], [319.3125, 318], [221.0625, 422], [153.515625, 535], [298.84375, 537]], 14 | # "insightface_feature_file": "/data/feature_data/images_part0/84634599103.bin"} 15 | export JSON_FILE=".../CrossFaceID.jsonl" 16 | 17 | 18 | # Output 19 | export OUTPUT_DIR="..."
20 | 21 | 22 | echo "OUTPUT_DIR: $OUTPUT_DIR" 23 | #accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \ 24 | #CUDA_VISIBLE_DEVICES=0 \ 25 | 26 | accelerate launch --mixed_precision="fp16" train_instantId_sdxl.py \ 27 | --pretrained_model_name_or_path $MODEL_NAME \ 28 | --controlnet_model_name_or_path $CONTROLNET_NAME \ 29 | --image_encoder_path $ENCODER_NAME \ 30 | --pretrained_ip_adapter_path $ADAPTOR_NAME \ 31 | --data_json_file $JSON_FILE \ 32 | --output_dir $OUTPUT_DIR \ 33 | --clip_proc_mode orig_crop \ 34 | --mixed_precision="fp16" \ 35 | --resolution 512 \ 36 | --learning_rate 1e-5 \ 37 | --weight_decay=0.01 \ 38 | --num_train_epochs 5 \ 39 | --train_batch_size 8 \ 40 | --dataloader_num_workers=8 \ 41 | --checkpoints_total_limit 20 \ 42 | --save_steps 10000 43 | 44 | 45 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import torch 5 | from torchvision import transforms 6 | from PIL import Image 7 | from transformers import CLIPImageProcessor 8 | import numpy as np 9 | import cv2 10 | import math 11 | import PIL 12 | 13 | 14 | def crop_with_expanded_size(image, crop_coords, expand_factor=1.1): 15 | # Open the image file 16 | # img = Image.open(path) 17 | 18 | # Original image size 19 | original_width, original_height = image.size 20 | 21 | # Known crop coordinates (left, top, right, bottom) 22 | # crop_coords = (left, top, right, bottom) 23 | 24 | # Compute the center of the original crop region 25 | center_x = (crop_coords[0] + crop_coords[2]) / 2 26 | center_y = (crop_coords[1] + crop_coords[3]) / 2 27 | 28 | # Compute the width and height of the original crop region 29 | original_crop_width = crop_coords[2] - crop_coords[0] 30 | original_crop_height = crop_coords[3] - crop_coords[1] 31 | 32 | # Compute the width and height of the expanded crop region 33 | new_crop_width = original_crop_width * expand_factor 34 | new_crop_height = original_crop_height * expand_factor 35 | 36 | # Compute the new crop coordinates, clamped to the image bounds 37 |
new_left = max(center_x - new_crop_width / 2, 0) 38 | new_top = max(center_y - new_crop_height / 2, 0) 39 | new_right = min(center_x + new_crop_width / 2, original_width) 40 | new_bottom = min(center_y + new_crop_height / 2, original_height) 41 | 42 | # New crop coordinates 43 | new_crop_coords = (int(new_left), int(new_top), int(new_right), int(new_bottom)) 44 | 45 | # Crop the image 46 | cropped_img = image.crop(new_crop_coords) 47 | 48 | return cropped_img 49 | 50 | 51 | class CropToRatioTransform(object): 52 | def __init__(self, target_aspect_ratio=512 / 640): 53 | self.target_aspect_ratio = target_aspect_ratio 54 | 55 | def __call__(self, img): 56 | # Compute the current aspect ratio (width / height) 57 | current_w, current_h = img.size 58 | current_aspect_ratio = current_w / current_h 59 | # print(current_aspect_ratio) 60 | 61 | # If the image is wider than the target aspect ratio, crop its width to match 62 | if current_aspect_ratio > self.target_aspect_ratio: 63 | # Compute the target width 64 | target_w = int(current_h * self.target_aspect_ratio) 65 | # Compute the horizontally centered region to keep 66 | left = (current_w - target_w) // 2 67 | right = left + target_w 68 | # Crop the image 69 | img = img.crop((left, 0, right, current_h)) 70 | 71 | return img 72 | 73 | 74 | class TopCropTransform(object): 75 | def __init__(self, crop_size): 76 | # crop_size can be a single int or a list/tuple of two ints (height, width) 77 | if isinstance(crop_size, int): 78 | self.crop_height = crop_size 79 | self.crop_width = crop_size 80 | elif isinstance(crop_size, (list, tuple)) and len(crop_size) == 2: 81 | self.crop_height, self.crop_width = crop_size 82 | else: 83 | raise TypeError('crop_size must be an int or a list/tuple of length 2.') 84 | 85 | def __call__(self, img): 86 | # Check that crop_size does not exceed the image dimensions 87 | w, h = img.size 88 | if self.crop_width > w or self.crop_height > h: 89 | raise ValueError('crop_size must not exceed the dimensions of the image.') 90 | 91 | top = 0 92 | center = w // 2 93 | crop_width, crop_height = self.crop_width, self.crop_height 94 | left = center - crop_width // 2 95 | 96 | # Keep the coordinates within the image bounds 97 | left = max(0, left) 98 | right = min(w, left + crop_width) 99
| bottom = min(h, top + crop_height) 100 | 101 | # Perform the crop 102 | img = img.crop((left, top, right, bottom)) 103 | return img 104 | 105 | 106 | def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]): 107 | stickwidth = 4 108 | limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]]) 109 | kps = np.array(kps) 110 | 111 | w, h = image_pil.size 112 | out_img = np.zeros([h, w, 3]) 113 | 114 | for i in range(len(limbSeq)): 115 | index = limbSeq[i] 116 | color = color_list[index[0]] 117 | 118 | x = kps[index][:, 0] 119 | y = kps[index][:, 1] 120 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5 121 | angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1])) 122 | polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 123 | 360, 1) 124 | out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color) 125 | out_img = (out_img * 0.6).astype(np.uint8) 126 | 127 | for idx_kp, kp in enumerate(kps): 128 | color = color_list[idx_kp] 129 | x, y = kp 130 | out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1) 131 | 132 | out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8)) 133 | return out_img_pil 134 | 135 | class MyDataset(torch.utils.data.Dataset): 136 | def __init__(self, json_file, tokenizer, size=512, crop_size=(640, 512), center_crop=False, image_root_path=""): 137 | super().__init__() 138 | self.tokenizer = tokenizer 139 | self.size = size # the shorter side is resized to `size` 140 | self.image_root_path = image_root_path 141 | # List that holds the parsed records 142 | self.data = [] 143 | # Read and parse each line of the JSON Lines file 144 | with open(json_file, 'r') as f: 145 | for line in f: 146 | # Parse the JSON record and append it to the list 147 | self.data.append(json.loads(line)) 148 | 149 | self.transform = transforms.Compose([ 150 | CropToRatioTransform(target_aspect_ratio=crop_size[1] / crop_size[0]), 151 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 152 | transforms.CenterCrop(size) if
center_crop else TopCropTransform(crop_size), 153 | transforms.ToTensor(), 154 | transforms.Normalize([0.5], [0.5]), 155 | ]) 156 | 157 | self.kps_transform = transforms.Compose([ 158 | CropToRatioTransform(target_aspect_ratio=crop_size[1] / crop_size[0]), 159 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR), 160 | transforms.CenterCrop(size) if center_crop else TopCropTransform(crop_size), 161 | transforms.ToTensor(), 162 | ]) 163 | 164 | self.clip_image_processor = CLIPImageProcessor() 165 | 166 | def __getitem__(self, idx): 167 | item = self.data[idx] 168 | image_file = item["file_name"] 169 | text = item["additional_feature"] 170 | bbox = item['bbox'] 171 | landmarks = item['landmarks'] 172 | feature_file = item["penult_id_embed_file"] 173 | clip_from_seg_file = item["clip_from_seg_file"] 174 | clip_from_orig_file = item["clip_from_orig_file"] 175 | seg_map_orig_file = item["seg_map_orig_file"] 176 | 177 | # read image 178 | raw_image = Image.open(os.path.join(self.image_root_path, image_file)) 179 | image = self.transform(raw_image.convert("RGB")) 180 | kps_image = draw_kps(raw_image.convert("RGB"), landmarks) 181 | kps_image = self.kps_transform(kps_image) 182 | 183 | # crop an enlarged face region for the CLIP image encoder 184 | crop_image = crop_with_expanded_size(raw_image, bbox) 185 | clip_image = self.clip_image_processor(images=crop_image, return_tensors="pt").pixel_values 186 | # load face feature 187 | face_id_embed = torch.load(os.path.join(self.image_root_path, feature_file), map_location="cpu") 188 | face_id_embed = torch.as_tensor(face_id_embed) # accepts both a saved tensor and a saved numpy array 189 | 190 | # Define every possible drop combination and its probability 191 | drop_combinations = { 192 | ('text',): 0.05, 193 | ('feature',): 0.05, 194 | ('feature', 'text'): 0.05, 195 | } 196 | # drop_combinations = { 197 | # ('text',): 0.05, 198 | # ('feature',): 0.04, 199 | # ('image',): 0.04, 200 | # ('image', 'feature'): 0.03, 201 | # ('image', 'text'): 0.03, 202 | # ('feature', 'text'): 0.03, 203 | # ('image', 'feature', 'text'): 0.03 204 |
# } 205 | # Compute the remaining probability 206 | remaining_probability = 1 - sum(drop_combinations.values()) 207 | # Add an entry for dropping no condition at all 208 | drop_combinations[()] = remaining_probability 209 | # Sample one drop combination according to the probabilities 210 | drop_choice = random.choices(list(drop_combinations.keys()), weights=list(drop_combinations.values()), k=1)[0] 211 | # Derive the drop flags from the chosen combination 212 | drop_text_embed = int('text' in drop_choice) 213 | drop_feature_embed = int('feature' in drop_choice) 214 | drop_image_embed = int('image' in drop_choice) 215 | 216 | # CFG processing (drop conditions for classifier-free guidance) 217 | if drop_text_embed: 218 | text = "" 219 | if drop_feature_embed: 220 | face_id_embed = torch.zeros_like(face_id_embed) 221 | if drop_image_embed: 222 | pass # the CLIP image is zeroed in the train loop 223 | 224 | # get text and tokenize 225 | text_input_ids = self.tokenizer( 226 | text, 227 | max_length=self.tokenizer.model_max_length, 228 | padding="max_length", 229 | truncation=True, 230 | return_tensors="pt" 231 | ).input_ids 232 | 233 | return { 234 | "image": image, 235 | "kps_image": kps_image, 236 | "clip_image": clip_image, 237 | "text_input_ids": text_input_ids, 238 | "face_id_embed": face_id_embed, 239 | "drop_image_embed": drop_image_embed 240 | } 241 | 242 | def __len__(self): 243 | return len(self.data) 244 | 245 | 246 | def collate_fn(data): 247 | images = torch.stack([example["image"] for example in data]) 248 | kps_images = torch.stack([example["kps_image"] for example in data]) 249 | 250 | clip_images = torch.cat([example["clip_image"] for example in data], dim=0) 251 | 252 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0) 253 | face_id_embed = torch.stack([example["face_id_embed"] for example in data]) 254 | drop_image_embeds = [example["drop_image_embed"] for example in data] 255 | 256 | return { 257 | "images": images, 258 | "kps_images": kps_images, 259 | "clip_images": clip_images, 260 | "text_input_ids": text_input_ids, 261 | "face_id_embed": face_id_embed, 262 | "drop_image_embeds": drop_image_embeds 263 | } 264 |
--------------------------------------------------------------------------------
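The condition dropping in `MyDataset.__getitem__` of `train_instantId_sdxl.py` splits a single uniform draw into disjoint intervals, so each sample drops the face feature, the text, both, or neither. A minimal standalone sketch of that interval logic follows. The helper name `sample_drop_flags` is hypothetical. The 0.05 rates are borrowed from the `drop_combinations` dict in `utils/dataset.py`, because the training script's own `i_drop_rate`, `t_drop_rate` and `ti_drop_rate` values are set outside this excerpt.

```python
import random
from typing import Optional, Tuple


def sample_drop_flags(
    i_drop_rate: float,
    t_drop_rate: float,
    ti_drop_rate: float,
    u: Optional[float] = None,
) -> Tuple[int, int]:
    """Return (drop_text_embed, drop_feature_embed) for one training sample.

    Hypothetical helper: one uniform draw u is split into disjoint intervals,
    [0, i) drops the face feature, [i, i + t) drops the text,
    [i + t, i + t + ti) drops both, and the rest keeps both conditions.
    """
    rates = (i_drop_rate, t_drop_rate, ti_drop_rate)
    if min(rates) < 0 or sum(rates) > 1:
        raise ValueError("drop rates must be non-negative and sum to at most 1")
    if u is None:
        u = random.random()  # uniform in [0, 1)
    drop_text, drop_feature = 0, 0
    if u < i_drop_rate:
        drop_feature = 1
    elif u < i_drop_rate + t_drop_rate:
        drop_text = 1
    elif u < i_drop_rate + t_drop_rate + ti_drop_rate:
        drop_text, drop_feature = 1, 1
    return drop_text, drop_feature


# Empirical check: with 5% per event, roughly 85% of samples keep both conditions.
random.seed(0)
draws = [sample_drop_flags(0.05, 0.05, 0.05) for _ in range(100_000)]
keep_both = draws.count((0, 0)) / len(draws)
print(f"keep both: {keep_both:.3f}")
```

Partitioning one draw this way yields the same outcome distribution as the `random.choices` call in `utils/dataset.py` with weights 0.05, 0.05, 0.05 and 0.85 for the "drop nothing" case.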
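The attention-processor loop in `main()` of `train_instantId_sdxl.py` maps each processor name to the channel width of the UNet block it belongs to, and only cross-attention layers (`attn2`) receive an `IPAttnProcessor`. The sketch below isolates that mapping as plain functions. The helper names are hypothetical, and `(320, 640, 1280)` is assumed as the `block_out_channels` of the SDXL base UNet.

```python
from typing import Sequence


def attn_hidden_size(name: str, block_out_channels: Sequence[int]) -> int:
    """Channel width of the UNet block that owns attention processor `name`.

    Hypothetical helper mirroring the name parsing in main(); like the original,
    it reads a single-digit block index right after the prefix.
    """
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        # up blocks mirror the down blocks, so up_blocks.0 is the widest
        return list(reversed(block_out_channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return block_out_channels[block_id]
    raise ValueError(f"unexpected attention processor name: {name}")


def uses_ip_processor(name: str) -> bool:
    # attn1 is self-attention and keeps a plain AttnProcessor;
    # attn2 is cross-attention and gets an IPAttnProcessor.
    return not name.endswith("attn1.processor")


SDXL_BLOCK_OUT_CHANNELS = (320, 640, 1280)  # assumed SDXL base UNet config value
example = "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor"
print(attn_hidden_size(example, SDXL_BLOCK_OUT_CHANNELS), uses_ip_processor(example))
```

Unlike the loop in `main()`, which would silently reuse the previous `hidden_size` for an unrecognized name, this sketch raises `ValueError`.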
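The checkpoint rotation in the training loop of `train_instantId_sdxl.py` deletes the oldest `checkpoint-<step>` directories so that at most `--checkpoints_total_limit` remain once the new checkpoint is written. A minimal sketch of that selection as a pure function, with a hypothetical name and no file-system side effects:

```python
from typing import Iterable, List


def checkpoints_to_remove(entries: Iterable[str], total_limit: int) -> List[str]:
    """Oldest `checkpoint-<step>` directories to delete before saving a new one.

    Hypothetical helper mirroring the rotation in the training loop: after the
    deletion at most `total_limit - 1` checkpoints remain, so the upcoming save
    brings the count back to `total_limit`.
    """
    checkpoints = [d for d in entries if d.startswith("checkpoint")]
    # Sort by the numeric step; a plain string sort would put
    # "checkpoint-10000" before "checkpoint-2000".
    checkpoints.sort(key=lambda d: int(d.split("-")[1]))
    if len(checkpoints) < total_limit:
        return []
    num_to_remove = len(checkpoints) - total_limit + 1
    return checkpoints[:num_to_remove]


# Non-checkpoint entries in output_dir, such as a "logs" directory, are skipped.
entries = ["checkpoint-20000", "logs", "checkpoint-2000", "checkpoint-10000"]
print(checkpoints_to_remove(entries, total_limit=2))
```

With the script's default `--checkpoints_total_limit 1`, every save removes all earlier checkpoints; `train_instantId_sdxl.sh` raises the limit to 20.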