├── README.md
├── assets
│   ├── examples.png
│   ├── results_customization.png
│   └── results_fidelity.png
├── download.py
├── download.sh
├── get_face_info.py
├── infer.py
├── infer_from_pkl.py
├── ip_adapter
│   ├── __pycache__
│   │   ├── attention_processor.cpython-310.pyc
│   │   ├── resampler.cpython-310.pyc
│   │   └── utils.cpython-310.pyc
│   ├── attention_processor.py
│   ├── mlp.py
│   ├── resampler.py
│   └── utils.py
├── pipeline_stable_diffusion_xl_instantid.py
├── train_instantId_sdxl.py
├── train_instantId_sdxl.sh
└── utils
    └── dataset.py
/README.md:
--------------------------------------------------------------------------------
1 | # Turn That Frown Upside Down: FaceID Customization via Cross-Training Data
2 |
3 |
4 |
5 | This repository contains resources referenced in the paper [Turn That Frown Upside Down: FaceID Customization via Cross-Training Data](https://arxiv.org/abs/2501.15407v1).
6 |
7 | If you find this repository helpful, please cite the following:
8 | ```latex
9 | @misc{wang2025turnfrownupsidedown,
10 | title={Turn That Frown Upside Down: FaceID Customization via Cross-Training Data},
11 | author={Shuhe Wang and Xiaoya Li and Xiaofei Sun and Guoyin Wang and Tianwei Zhang and Jiwei Li and Eduard Hovy},
12 | year={2025},
13 | eprint={2501.15407},
14 | archivePrefix={arXiv},
15 | primaryClass={cs.CV},
16 | url={https://arxiv.org/abs/2501.15407},
17 | }
18 | ```
19 |
20 |
21 |
22 | ## 🥳 News
23 |
24 | **Stay tuned! More related work will be updated!**
25 | * **[25 Jan, 2025]** The repository is created.
26 | * **[25 Jan, 2025]** We release the first version of the paper.
27 |
28 |
29 | ## Links
30 | - [Turn That Frown Upside Down: FaceID Customization via Cross-Training Data](#turn-that-frown-upside-down-faceid-customization-via-cross-training-data)
31 | - [🥳 News](#-news)
32 | - [Links](#links)
33 | - [Introduction](#introduction)
34 | - [Comparison with Previous Works](#comparison-with-previous-works)
35 | - [FaceID Fidelity](#faceid-fidelity)
36 | - [FaceID Customization](#faceid-customization)
37 | - [Released CrossFaceID dataset](#released-crossfaceid-dataset)
38 | - [Released FaceID Customization Models](#released-faceid-customization-models)
39 | - [Usage](#usage)
40 | - [Training](#training)
41 | - [Step 1. Download Required Models](#step-1-download-required-models)
42 | - [Step 2. Download Required Dataset](#step-2-download-required-dataset)
43 | - [Step 3. Training](#step-3-training)
44 | - [Inference](#inference)
45 | - [Contact](#contact)
46 |
47 |
48 |
49 | ## Introduction
50 |
51 | CrossFaceID is the first large-scale, high-quality, and publicly available dataset specifically designed to improve the facial modification capabilities of FaceID customization models. Specifically, CrossFaceID consists of 40,000 text-image pairs from approximately 2,000 persons, with each person represented by around 20 images showcasing diverse facial attributes such as poses, expressions, angles, and adornments. During the training stage, a specific face of a person is used as input, and the FaceID customization model is forced to generate another image of the same person but with altered facial features. This allows the FaceID customization model to acquire the ability to personalize and modify known facial features during the inference stage.
52 |
53 |
54 |
![](./assets/examples.png)
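To make the cross-training setup concrete, the sketch below shows one way such input/target pairs could be assembled from person-grouped records. It is only an illustration: the field names `person_id`, `image_path`, and `caption` are assumptions, not the schema of the released dataset.

```python
# Illustrative sketch of cross-training pair construction (field names are assumptions).
import random
from collections import defaultdict

def build_cross_pairs(records, seed=0):
    """Group records by person and pair each target image with a *different*
    image of the same person to use as the FaceID reference."""
    rng = random.Random(seed)
    by_person = defaultdict(list)
    for record in records:
        by_person[record["person_id"]].append(record)

    pairs = []
    for items in by_person.values():
        if len(items) < 2:
            continue  # need at least two images of the same person
        for target in items:
            reference = rng.choice([r for r in items if r is not target])
            # The model takes `reference` as the identity input and is trained to
            # reproduce `target`, whose pose/expression/adornments differ.
            pairs.append((reference["image_path"], target["image_path"], target["caption"]))
    return pairs
```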
55 |
56 |
57 |
58 | ## Comparison with Previous Works
59 |
60 |
61 | ### FaceID Fidelity
62 |
63 |
64 |
![](./assets/results_fidelity.png)
65 |
66 |
67 | The results demonstrate the performance of FaceID customization models in maintaining FaceID fidelity. Here, “InstantID” refers to the official InstantID model, while “InstantID + CrossFaceID” represents the model further fine-tuned on our CrossFaceID dataset. “LAION” denotes the InstantID model pre-trained on our curated LAION dataset, and “LAION + CrossFaceID” refers to the model further trained on the CrossFaceID dataset. These results indicate that (1) for both the official InstantID model and the LAION-trained model, the ability to maintain FaceID fidelity remains consistent before and after fine-tuning on our CrossFaceID dataset, and (2) the model trained on our curated LAION dataset achieves comparable performance to the official InstantID model in preserving FaceID fidelity.
68 |
69 | ### FaceID Customization
70 |
71 |
72 |
![](./assets/results_customization.png)
73 |
74 |
75 | These results show the performance of FaceID customization models in customizing or editing FaceID. Here, "InstantID" represents the official InstantID model, while "InstantID + CrossFaceID" refers to the model fine-tuned on our CrossFaceID dataset. Similarly, "LAION" denotes the InstantID model pre-trained on our curated LAION dataset, and "LAION + CrossFaceID" refers to the model further fine-tuned on the CrossFaceID dataset. From these results, we can clearly observe an improvement in the models' ability to customize FaceID after being fine-tuned on our constructed CrossFaceID dataset.
76 |
77 |
78 |
79 | ## Released CrossFaceID dataset
80 |
81 | Our CrossFaceID dataset is available on [Huggingface](https://huggingface.co/datasets/Super-shuhe/CrossFaceID). It comprises 40,000 text-image pairs from approximately 2,000 individuals, with each person represented by around 20 images that capture various facial attributes, including different poses, expressions, angles, and adornments.
82 |
83 |
84 | ## Released FaceID Customization Models
85 |
86 | The trained InstantID model is available [here](https://huggingface.co/Super-shuhe/CrossFaceID-InstantID).
87 |
88 | ## Usage
89 |
90 | ### Training
91 | As the original InstantID repository (https://github.com/InstantID/InstantID) doesn't contain training code, we follow [this repository](https://github.com/MFaceTech/InstantID?tab=readme-ov-file) to train our own InstantID.
92 |
93 | #### Step 1. Download Required Models
94 |
95 | You can directly download the models from [Huggingface](https://huggingface.co/InstantX/InstantID).
96 | You can also download them with a Python script:
97 |
98 | ```python
99 | from huggingface_hub import hf_hub_download
100 | hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir="./checkpoints")
101 | hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir="./checkpoints")
102 | hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir="./checkpoints")
103 | ```
104 |
105 | If you cannot access Huggingface, you can use [hf-mirror](https://hf-mirror.com/) to download the models.
106 | ```bash
107 | export HF_ENDPOINT=https://hf-mirror.com
108 | huggingface-cli download --resume-download InstantX/InstantID --local-dir checkpoints
109 | ```
110 |
111 | For the face encoder, you need to manually download it via this [URL](https://github.com/deepinsight/insightface/issues/1896#issuecomment-1023867304) and place it under `models/antelopev2`, as the default link is invalid. Once you have prepared all models, the folder tree should look like:
112 |
113 | ```
114 | .
115 | ├── models
116 | ├── checkpoints
117 | ├── ip_adapter
118 | ├── pipeline_stable_diffusion_xl_instantid.py
119 | ├── download.py
120 | ├── download.sh
121 | ├── get_face_info.py
122 | ├── infer_from_pkl.py
123 | ├── infer.py
124 | ├── train_instantId_sdxl.py
125 | ├── train_instantId_sdxl.sh
126 | └── README.md
127 | ```
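As a quick sanity check (a minimal sketch, assuming the layout above and a placeholder test image), you can verify that the checkpoints and the `antelopev2` face encoder are discoverable before training:

```python
# Sanity check for the downloaded models; "your_face.jpg" is a placeholder path.
import os
import cv2
from insightface.app import FaceAnalysis

assert os.path.exists("./checkpoints/ip-adapter.bin")
assert os.path.exists("./checkpoints/ControlNetModel/diffusion_pytorch_model.safetensors")

# With root='./', insightface looks for the encoder under ./models/antelopev2
app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
app.prepare(ctx_id=0, det_size=(640, 640))

faces = app.get(cv2.imread("your_face.jpg"))
print(f"detected {len(faces)} face(s)")
```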
128 |
129 |
130 | #### Step 2. Download Required Dataset
131 |
132 | Please download our released dataset from [Huggingface](https://huggingface.co/datasets/Super-shuhe/CrossFaceID).
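For example, you can fetch it with `huggingface_hub` (a minimal sketch; `local_dir` is an arbitrary example path):

```python
# Download the CrossFaceID dataset files to a local directory (example path).
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="Super-shuhe/CrossFaceID",
    repo_type="dataset",
    local_dir="./data/CrossFaceID",
)
```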
133 |
134 | #### Step 3. Training
135 |
136 | 1. Fill in the `MODEL_NAME`, `ENCODER_NAME`, `ADAPTOR_NAME`, `CONTROLNET_NAME`, and `JSON_FILE` variables in our provided training script `./train_instantId_sdxl.sh`, where:
137 |    1. `MODEL_NAME` refers to the backbone diffusion model, e.g., `stable-diffusion-xl-base-1.0`
138 |    2. `ENCODER_NAME` refers to the downloaded image encoder, e.g., `image_encoder`
139 |    3. `ADAPTOR_NAME` and `CONTROLNET_NAME` refer to the pre-trained official InstantID model, e.g., `checkpoints/ip-adapter.bin` and `checkpoints/ControlNetModel`
140 |    4. `JSON_FILE` refers to our released CrossFaceID dataset.
141 | 2. Run the training script, such as: `bash ./train_instantId_sdxl.sh`
142 |
143 |
144 | ### Inference
145 |
146 | 1. Fill in the `base_model_path`, `face_adapter`, `controlnet_path`, `prompt0`, and `face_image` variables in our provided inference script `./infer_from_pkl.py`, where:
147 |    1. `base_model_path` refers to the backbone diffusion model, e.g., `stable-diffusion-xl-base-1.0`
148 |    2. `face_adapter` and `controlnet_path` refer to your trained model, e.g., `checkpoints/ip-adapter.bin` and `checkpoints/ControlNetModel`
149 |    3. `prompt0` and `face_image` refer to your test sample.
150 | 2. Run the inference script, such as: `python ./infer_from_pkl.py` (a minimal example of the filled-in settings is shown below)
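A minimal sketch of the filled-in settings inside `./infer_from_pkl.py` might look like the following; every path and the prompt are examples only, not required values:

```python
# Example settings for infer_from_pkl.py; replace every path with your own.
from diffusers.utils import load_image

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"  # backbone diffusion model
face_adapter = "./checkpoints/ip-adapter.bin"  # your trained IP-Adapter weights
controlnet_path = "./checkpoints/ControlNetModel"  # your trained ControlNet
prompt0 = "facing one side, wearing red sunglasses, a golden chain, and a green cap"
face_image = load_image("./examples/your_test_image.jpg")  # placeholder test image path
```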
151 |
152 |
153 | ## Contact
154 |
155 | If you have any issues or questions about this repo, feel free to contact shuhewang@student.unimelb.edu.au.
156 |
--------------------------------------------------------------------------------
/assets/examples.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/assets/examples.png
--------------------------------------------------------------------------------
/assets/results_customization.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/assets/results_customization.png
--------------------------------------------------------------------------------
/assets/results_fidelity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/assets/results_fidelity.png
--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
1 | import os
2 | from huggingface_hub import hf_hub_download
3 | local_dir = "checkpoints"
4 | os.makedirs(local_dir, exist_ok=True)
5 |
6 | hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/config.json", local_dir=local_dir)
7 | hf_hub_download(repo_id="InstantX/InstantID", filename="ControlNetModel/diffusion_pytorch_model.safetensors", local_dir=local_dir)
8 | hf_hub_download(repo_id="InstantX/InstantID", filename="ip-adapter.bin", local_dir=local_dir)
9 |
10 |
--------------------------------------------------------------------------------
/download.sh:
--------------------------------------------------------------------------------
1 | export HF_ENDPOINT=https://hf-mirror.com
2 | huggingface-cli download --resume-download InstantX/InstantID --local-dir checkpoints
3 |
--------------------------------------------------------------------------------
/get_face_info.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | from PIL import Image
4 | import pickle
5 | from diffusers.utils import load_image
6 | from insightface.app import FaceAnalysis
7 |
8 |
9 | def resize_img(input_image, max_side=1280, min_side=1024, size=None,
10 | pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
11 |
12 | w, h = input_image.size
13 | if size is not None:
14 | w_resize_new, h_resize_new = size
15 | else:
16 | ratio = min_side / min(h, w)
17 | w, h = round(ratio*w), round(ratio*h)
18 | ratio = max_side / max(h, w)
19 | input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
20 | w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
21 | h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
22 | input_image = input_image.resize([w_resize_new, h_resize_new], mode)
23 |
24 | if pad_to_max_side:
25 | res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
26 | offset_x = (max_side - w_resize_new) // 2
27 | offset_y = (max_side - h_resize_new) // 2
28 | res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
29 | input_image = Image.fromarray(res)
30 | return input_image
31 |
32 |
33 |
34 | root = "model_zoo/instantID/"
35 | app = FaceAnalysis(name='antelopev2', root=root, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
36 |
37 | app.prepare(ctx_id=0, det_size=(640, 640))
38 |
39 |
40 | face_image = load_image("./examples/wenyongshan.png")
41 | # face_image = load_image("./examples/zhuyilong.jpg")
42 |
43 |
44 | face_image = resize_img(face_image)
45 | face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
46 | face_info = sorted(face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]))[-1]  # only use the maximum face
47 |
48 | print(type(face_info))
49 |
50 | # Build a dictionary whose values are float32 NumPy arrays (face embedding and keypoints)
51 | data_dict = {
52 | 'embedding': face_info['embedding'],
53 | 'kps': face_info['kps']
54 | }
55 |
56 | # Serialize the dictionary with pickle and save it to a file
57 | with open('examples/face_info.pkl', 'wb') as pickle_file:
58 | pickle.dump(data_dict, pickle_file)
--------------------------------------------------------------------------------
/infer.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import PIL
6 | from diffusers.utils import load_image
7 | from diffusers.models import ControlNetModel
8 | import math
9 | from insightface.app import FaceAnalysis
10 | from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline
11 |
12 | def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
13 | stickwidth = 4
14 | limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
15 | kps = np.array(kps)
16 |
17 | w, h = image_pil.size
18 | out_img = np.zeros([h, w, 3])
19 |
20 | for i in range(len(limbSeq)):
21 | index = limbSeq[i]
22 | color = color_list[index[0]]
23 |
24 | x = kps[index][:, 0]
25 | y = kps[index][:, 1]
26 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
27 | angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
28 | polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
29 | 360, 1)
30 | out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
31 | out_img = (out_img * 0.6).astype(np.uint8)
32 |
33 | for idx_kp, kp in enumerate(kps):
34 | color = color_list[idx_kp]
35 | x, y = kp
36 | out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
37 |
38 | out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
39 | return out_img_pil
40 | def resize_img(input_image, max_side=1280, min_side=1024, size=None,
41 | pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
42 |
43 | w, h = input_image.size
44 | if size is not None:
45 | w_resize_new, h_resize_new = size
46 | else:
47 | ratio = min_side / min(h, w)
48 | w, h = round(ratio*w), round(ratio*h)
49 | ratio = max_side / max(h, w)
50 | input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
51 | w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
52 | h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
53 | input_image = input_image.resize([w_resize_new, h_resize_new], mode)
54 |
55 | if pad_to_max_side:
56 | res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
57 | offset_x = (max_side - w_resize_new) // 2
58 | offset_y = (max_side - h_resize_new) // 2
59 | res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
60 | input_image = Image.fromarray(res)
61 | return input_image
62 |
63 |
64 | if __name__ == "__main__":
65 |
66 | root = "aigc_models/InstantID"
67 | app = FaceAnalysis(name='antelopev2', root=root, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
68 |
69 | app.prepare(ctx_id=0, det_size=(640, 640))
70 |
71 | # Path to InstantID models
72 | # face_adapter = f'./checkpoints/ip-adapter.bin'
73 | # controlnet_path = f'./checkpoints/ControlNetModel'
74 | face_adapter = "InstantID/checkpoints/ip-adapter.bin"
75 | controlnet_path = "InstantID/checkpoints/ControlNetModel"
76 |
77 | # Load pipeline
78 | controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
79 |
80 | # base_model_path = 'stabilityai/stable-diffusion-xl-base-1.0'
81 | base_model_path = "huggingface/hub/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/462165984030d82259a11f4367a4eed129e94a7b"
82 |
83 | pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
84 | base_model_path,
85 | controlnet=controlnet,
86 | torch_dtype=torch.float16,
87 | )
88 | pipe.cuda()
89 | pipe.load_ip_adapter_instantid(face_adapter)
90 |
91 | prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
92 | n_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
93 |
94 | face_image = load_image("./examples/yann-lecun_resize.jpg")
95 | face_image = resize_img(face_image)
96 |
97 | face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
98 | face_info = sorted(face_info, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the maximum face
99 | face_emb = face_info['embedding']
100 | face_kps = draw_kps(face_image, face_info['kps'])
101 |
102 | pipe.set_ip_adapter_scale(0.8)
103 | image = pipe(
104 | prompt=prompt,
105 | negative_prompt=n_prompt,
106 | image_embeds=face_emb,
107 | image=face_kps,
108 | controlnet_conditioning_scale=0.8,
109 | num_inference_steps=30,
110 | guidance_scale=5,
111 | ).images[0]
112 |
113 | image.save('result.jpg')
--------------------------------------------------------------------------------
/infer_from_pkl.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import pickle
6 | from diffusers.utils import load_image
7 | from diffusers.models import ControlNetModel
8 |
9 | # from insightface.app import FaceAnalysis
10 | from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
11 |
12 |
13 |
14 | import json
15 | import os
16 | from tqdm import tqdm
17 |
18 | from insightface.app import FaceAnalysis
19 |
20 |
21 |
22 |
23 | def face_extraction(image):
24 | app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
25 | app.prepare(ctx_id=0, det_size=(640, 640))
26 |
27 | # image = load_image(image_path)
28 |
29 | # prepare face emb
30 | face_info = app.get(cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR))
31 |
32 | face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
33 | face_emb = face_info['embedding']
34 | face_kps = face_info['kps'].tolist()
35 | face_bbox = face_info['bbox'].tolist()
36 |
37 | return face_emb, face_kps
38 |
39 |
40 |
41 |
42 | def resize_img(input_image, max_side=1280, min_side=1024, size=None,
43 | pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):
44 |
45 | w, h = input_image.size
46 | if size is not None:
47 | w_resize_new, h_resize_new = size
48 | else:
49 | ratio = min_side / min(h, w)
50 | w, h = round(ratio*w), round(ratio*h)
51 | ratio = max_side / max(h, w)
52 | input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
53 | w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
54 | h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
55 | input_image = input_image.resize([w_resize_new, h_resize_new], mode)
56 |
57 | if pad_to_max_side:
58 | res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
59 | offset_x = (max_side - w_resize_new) // 2
60 | offset_y = (max_side - h_resize_new) // 2
61 | res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
62 | input_image = Image.fromarray(res)
63 | return input_image
64 |
65 |
66 | if __name__ == "__main__":
67 | # user model
68 |
69 | base_model_path = ".../stable-diffusion-xl-base-1.0"
70 | face_adapter = ".../checkpoints/ip-adapter.bin"
71 | controlnet_path = ".../checkpoints/ControlNetModel/"
72 |
73 |
74 |
75 | # Load pipeline
76 | controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
77 |
78 | pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
79 | base_model_path,
80 | controlnet=controlnet,
81 | torch_dtype=torch.float16,
82 | )
83 | pipe.cuda()
84 | pipe.load_ip_adapter_instantid(face_adapter)
85 |
86 | # prompt0 = "jpeg artifacts,asian man,early twenties,casual pose,low quality,blurry,poorly drawn,worst quality,photography,vintage,"
87 | # prompt1 = "photography aesthetic, elegant portrait photography, young man exuding classical beauty and sophistication, pale skin, delicate features, captivating gaze, slightly parted lips, contemporary fashion fused with vintage elements, naturally-lit setting, warm neutral backdrop, soft, diffused lighting highlighting subject's face, subtle shadowing, clear, high-resolution photo,film grain"
88 | # prompt2 = "portrait, young asian woman sleek, shoulder-length black hair gazes contemplatively towards camera, dark eyes slightlycast, lips closed neutral expression hints poise introspection, softness skin enhanced,"
89 | # # prompt2 = "portrait, young asian man, gazes contemplatively towards camera, dark eyes slightlycast, lips closed neutral expression hints poise introspection, softness skin enhanced,"
90 | # prompt3 = "photography aesthetic,Studio portrait photography, striking sexy female model, serene countenance, youthful appearance, light skin, elegant goth makeup emphasizing sharp eyebrows, pronounced eyelashes, dark lipstick, jewelry showcasing sparkling goth earrings,fashionably styled dark hair, exposing clear skin, backdrop complementing the subject's attire, natural light simulation enhancing soft shadows, subtle retouching ensuring smooth skin textures, harmonious color palette imbued with warm, muted tones,film grain,"
91 | # # prompt3 = "photography aesthetic,Studio portrait photography, striking male model, serene countenance, youthful appearance, light skin, elegant goth makeup emphasizing sharp eyebrows, pronounced eyelashes, dark lipstick,fashionably styled dark hair, exposing clear skin, backdrop complementing the subject's attire, natural light simulation enhancing soft shadows, subtle retouching ensuring smooth skin textures, harmonious color palette imbued with warm, muted tones,film grain,"
92 | # prompt4 = "a beautiful woman dressed in casual attire, looking energetic and vibrant, positive and upbeat atmosphere"
93 | # # prompt4 = "a man dressed in casual attire, looking energetic and vibrant, positive and upbeat atmosphere"
94 | # prompt5 = "A serene ambience,traditional Chinese aesthetic,a young woman,gentle expression,concentration on instrument,traditional Chinese guzheng,flowing pale blue and white hanfu with delicate floral accents,a backdrop of lush foliage,soft natural lighting,harmonious color palette of cool tones,ancient heritage,cultural reverence,timeless elegance,poised positioning amidst rocks,black hair adorned with classical hairpin,embodiment of classical Chinese music and beauty,tranquility amidst nature,subtlety in details,fine craftsmanship of the guzheng,ethereal atmosphere,cultural homage."
95 | # # prompt5 = "A serene ambience,traditional Chinese aesthetic,a young man,gentle expression,concentration on instrument,traditional Chinese guzheng,flowing pale blue and white hanfu with delicate floral accents,a backdrop of lush foliage,soft natural lighting,harmonious color palette of cool tones,ancient heritage,cultural reverence,timeless elegance,poised positioning amidst rocks,black hair adorned with classical hairpin,embodiment of classical Chinese music and beauty,tranquility amidst nature,subtlety in details,fine craftsmanship of the guzheng,ethereal atmosphere,cultural homage."
96 |
97 |
98 | # prompt0 = "a beautiful girl wearing casual shirt in a garden and smiling"
99 |
100 | # prompt0 = "a beautiful girl wearing casual shirt in a garden"
101 |
102 | prompt0 = "facing one side, wearing red sunglasses, a golden chain, and a green cap"
103 |
104 | # prompt0 = "A man on the red carpet"
105 |
106 | # prompt0 = "a man is standing here, his eyes sharp and full of spirit."
107 |
108 | # prompt0 = "A man holding a cup of coffee"
109 |
110 | # prompt0 = "A man wearing casual shirt in a garden"
111 |
112 | # prompt0 = "a beautiful woman dressed in casual attire, looking energetic and vibrant, positive and upbeat atmosphere"
113 |
114 | # prompt0 = "side face, wearing red sunglasses, a golden chain, and a green cap"
115 |
116 | # prompt0 = "a man is wearing a sunglasses"
117 |
118 | # prompt0 = "a girl looks very sad"
119 |
120 | # prompt0 = "a man is wearing a mask"
121 |
122 |
123 | n_prompt = "ng_deepnegative_v1_75t, (badhandv4:1.2), (worst quality:2), (low quality:2), (normal quality:2), lowres, bad anatomy, bad hands,((monochrome)), ((grayscale)) watermark, moles, large breast, big breast"
124 |
125 |
126 |
127 | face_image = load_image("") # image path
128 |
129 |
130 |
131 | face_image = resize_img(face_image, size=(1024, 1024))
132 | # face_image = resize_img(face_image, size=(512, 512))
133 |
134 |
135 | face_emb, face_kps = face_extraction(image=face_image)
136 | face_kps = draw_kps(face_image, face_kps)
137 | # prompts = [prompt0, prompt1, prompt2, prompt3, prompt4, prompt5, prompt6, prompt7,prompt8,prompt9]
138 | prompts = [prompt0]
139 |
140 |
141 | pipe.set_ip_adapter_scale(0.8)
142 |
143 |
144 | print("================")
145 |
146 | # inference
147 | for i in range(1):
148 | print("-------------")
149 | image = pipe(
150 | prompt=prompts[i],
151 | negative_prompt=n_prompt,
152 | image_embeds=face_emb,
153 | image=face_kps,
154 | controlnet_conditioning_scale=0.5,
155 | num_inference_steps=32,
156 | guidance_scale=5,
157 | ).images[0]
158 | print("+++++++++++++++++")
159 | ind = len(os.listdir("./results/"))
160 | # image.save("./results/test_%d.jpg" % (i))
161 | image.save("./results/test_%d.jpg" % (ind))
162 |
163 |
--------------------------------------------------------------------------------
/ip_adapter/__pycache__/attention_processor.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/ip_adapter/__pycache__/attention_processor.cpython-310.pyc
--------------------------------------------------------------------------------
/ip_adapter/__pycache__/resampler.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/ip_adapter/__pycache__/resampler.cpython-310.pyc
--------------------------------------------------------------------------------
/ip_adapter/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ShuheSH/CrossFaceID/c32c351967aaf9adb7c9bf91b725b894d6bb7ccf/ip_adapter/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/ip_adapter/attention_processor.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | try:
7 | import xformers
8 | import xformers.ops
9 | xformers_available = True
10 | except Exception as e:
11 | xformers_available = False
12 |
13 |
14 |
15 | class RegionControler(object):
16 | def __init__(self) -> None:
17 | self.prompt_image_conditioning = []
18 | region_control = RegionControler()
19 |
20 |
21 | class AttnProcessor(nn.Module):
22 | r"""
23 | Default processor for performing attention-related computations.
24 | """
25 | def __init__(
26 | self,
27 | hidden_size=None,
28 | cross_attention_dim=None,
29 | ):
30 | super().__init__()
31 |
32 | def __call__(
33 | self,
34 | attn,
35 | hidden_states,
36 | encoder_hidden_states=None,
37 | attention_mask=None,
38 | temb=None,
39 | ):
40 | residual = hidden_states
41 |
42 | if attn.spatial_norm is not None:
43 | hidden_states = attn.spatial_norm(hidden_states, temb)
44 |
45 | input_ndim = hidden_states.ndim
46 |
47 | if input_ndim == 4:
48 | batch_size, channel, height, width = hidden_states.shape
49 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
50 |
51 | batch_size, sequence_length, _ = (
52 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
53 | )
54 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
55 |
56 | if attn.group_norm is not None:
57 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
58 |
59 | query = attn.to_q(hidden_states)
60 |
61 | if encoder_hidden_states is None:
62 | encoder_hidden_states = hidden_states
63 | elif attn.norm_cross:
64 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
65 |
66 | key = attn.to_k(encoder_hidden_states)
67 | value = attn.to_v(encoder_hidden_states)
68 |
69 | query = attn.head_to_batch_dim(query)
70 | key = attn.head_to_batch_dim(key)
71 | value = attn.head_to_batch_dim(value)
72 |
73 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
74 | hidden_states = torch.bmm(attention_probs, value)
75 | hidden_states = attn.batch_to_head_dim(hidden_states)
76 |
77 | # linear proj
78 | hidden_states = attn.to_out[0](hidden_states)
79 | # dropout
80 | hidden_states = attn.to_out[1](hidden_states)
81 |
82 | if input_ndim == 4:
83 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
84 |
85 | if attn.residual_connection:
86 | hidden_states = hidden_states + residual
87 |
88 | hidden_states = hidden_states / attn.rescale_output_factor
89 |
90 | return hidden_states
91 |
92 |
93 | class IPAttnProcessor(nn.Module):
94 | r"""
95 | Attention processor for IP-Adapater.
96 | Args:
97 | hidden_size (`int`):
98 | The hidden size of the attention layer.
99 | cross_attention_dim (`int`):
100 | The number of channels in the `encoder_hidden_states`.
101 | scale (`float`, defaults to 1.0):
102 | the weight scale of image prompt.
103 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
104 | The context length of the image features.
105 | """
106 |
107 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
108 | super().__init__()
109 |
110 | self.hidden_size = hidden_size
111 | self.cross_attention_dim = cross_attention_dim
112 | self.scale = scale
113 | self.num_tokens = num_tokens
114 |
115 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
116 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
117 |
118 | def __call__(
119 | self,
120 | attn,
121 | hidden_states,
122 | encoder_hidden_states=None,
123 | attention_mask=None,
124 | temb=None,
125 | ):
126 | residual = hidden_states
127 |
128 | if attn.spatial_norm is not None:
129 | hidden_states = attn.spatial_norm(hidden_states, temb)
130 |
131 | input_ndim = hidden_states.ndim
132 |
133 | if input_ndim == 4:
134 | batch_size, channel, height, width = hidden_states.shape
135 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
136 |
137 | batch_size, sequence_length, _ = (
138 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
139 | )
140 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
141 |
142 | if attn.group_norm is not None:
143 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
144 |
145 | query = attn.to_q(hidden_states)
146 |
147 | if encoder_hidden_states is None:
148 | encoder_hidden_states = hidden_states
149 | else:
150 | # get encoder_hidden_states, ip_hidden_states
151 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
152 | encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
153 | if attn.norm_cross:
154 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
155 |
156 | key = attn.to_k(encoder_hidden_states)
157 | value = attn.to_v(encoder_hidden_states)
158 |
159 | query = attn.head_to_batch_dim(query)
160 | key = attn.head_to_batch_dim(key)
161 | value = attn.head_to_batch_dim(value)
162 |
163 | if xformers_available:
164 | hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
165 | else:
166 | attention_probs = attn.get_attention_scores(query, key, attention_mask)
167 | hidden_states = torch.bmm(attention_probs, value)
168 | hidden_states = attn.batch_to_head_dim(hidden_states)
169 |
170 | # for ip-adapter
171 | ip_key = self.to_k_ip(ip_hidden_states)
172 | ip_value = self.to_v_ip(ip_hidden_states)
173 |
174 | ip_key = attn.head_to_batch_dim(ip_key)
175 | ip_value = attn.head_to_batch_dim(ip_value)
176 |
177 | if xformers_available:
178 | ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)
179 | else:
180 | ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
181 | ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
182 | ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
183 |
184 | # region control
185 | if len(region_control.prompt_image_conditioning) == 1:
186 | region_mask = region_control.prompt_image_conditioning[0].get('region_mask', None)
187 | if region_mask is not None:
188 | h, w = region_mask.shape[:2]
189 | ratio = (h * w / query.shape[1]) ** 0.5
190 | mask = F.interpolate(region_mask[None, None], scale_factor=1/ratio, mode='nearest').reshape([1, -1, 1])
191 | else:
192 | mask = torch.ones_like(ip_hidden_states)
193 | ip_hidden_states = ip_hidden_states * mask
194 |
195 | hidden_states = hidden_states + self.scale * ip_hidden_states
196 |
197 | # linear proj
198 | hidden_states = attn.to_out[0](hidden_states)
199 | # dropout
200 | hidden_states = attn.to_out[1](hidden_states)
201 |
202 | if input_ndim == 4:
203 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
204 |
205 | if attn.residual_connection:
206 | hidden_states = hidden_states + residual
207 |
208 | hidden_states = hidden_states / attn.rescale_output_factor
209 |
210 | return hidden_states
211 |
212 |
213 | def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
214 | # TODO attention_mask
215 | query = query.contiguous()
216 | key = key.contiguous()
217 | value = value.contiguous()
218 | hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
219 | # hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
220 | return hidden_states
221 |
222 |
223 | class AttnProcessor2_0(torch.nn.Module):
224 | r"""
225 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
226 | """
227 | def __init__(
228 | self,
229 | hidden_size=None,
230 | cross_attention_dim=None,
231 | ):
232 | super().__init__()
233 | if not hasattr(F, "scaled_dot_product_attention"):
234 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
235 |
236 | def __call__(
237 | self,
238 | attn,
239 | hidden_states,
240 | encoder_hidden_states=None,
241 | attention_mask=None,
242 | temb=None,
243 | ):
244 | residual = hidden_states
245 |
246 | if attn.spatial_norm is not None:
247 | hidden_states = attn.spatial_norm(hidden_states, temb)
248 |
249 | input_ndim = hidden_states.ndim
250 |
251 | if input_ndim == 4:
252 | batch_size, channel, height, width = hidden_states.shape
253 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
254 |
255 | batch_size, sequence_length, _ = (
256 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
257 | )
258 |
259 | if attention_mask is not None:
260 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
261 | # scaled_dot_product_attention expects attention_mask shape to be
262 | # (batch, heads, source_length, target_length)
263 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
264 |
265 | if attn.group_norm is not None:
266 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
267 |
268 | query = attn.to_q(hidden_states)
269 |
270 | if encoder_hidden_states is None:
271 | encoder_hidden_states = hidden_states
272 | elif attn.norm_cross:
273 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
274 |
275 | key = attn.to_k(encoder_hidden_states)
276 | value = attn.to_v(encoder_hidden_states)
277 |
278 | inner_dim = key.shape[-1]
279 | head_dim = inner_dim // attn.heads
280 |
281 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
282 |
283 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
284 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
285 |
286 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
287 | # TODO: add support for attn.scale when we move to Torch 2.1
288 | hidden_states = F.scaled_dot_product_attention(
289 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
290 | )
291 |
292 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
293 | hidden_states = hidden_states.to(query.dtype)
294 |
295 | # linear proj
296 | hidden_states = attn.to_out[0](hidden_states)
297 | # dropout
298 | hidden_states = attn.to_out[1](hidden_states)
299 |
300 | if input_ndim == 4:
301 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
302 |
303 | if attn.residual_connection:
304 | hidden_states = hidden_states + residual
305 |
306 | hidden_states = hidden_states / attn.rescale_output_factor
307 |
308 | return hidden_states
309 |
397 |
398 |
399 | class IPAttnProcessor2_0(torch.nn.Module):
400 | r"""
401 | Attention processor for IP-Adapater for PyTorch 2.0.
402 | Args:
403 | hidden_size (`int`):
404 | The hidden size of the attention layer.
405 | cross_attention_dim (`int`):
406 | The number of channels in the `encoder_hidden_states`.
407 | scale (`float`, defaults to 1.0):
408 | the weight scale of image prompt.
409 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
410 | The context length of the image features.
411 | """
412 |
413 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
414 | super().__init__()
415 |
416 | if not hasattr(F, "scaled_dot_product_attention"):
417 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
418 |
419 | self.hidden_size = hidden_size
420 | self.cross_attention_dim = cross_attention_dim
421 | self.scale = scale
422 | self.num_tokens = num_tokens
423 |
424 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
425 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
426 |
427 | def __call__(
428 | self,
429 | attn,
430 | hidden_states,
431 | encoder_hidden_states=None,
432 | attention_mask=None,
433 | temb=None,
434 | ):
435 | residual = hidden_states
436 |
437 | if attn.spatial_norm is not None:
438 | hidden_states = attn.spatial_norm(hidden_states, temb)
439 |
440 | input_ndim = hidden_states.ndim
441 |
442 | if input_ndim == 4:
443 | batch_size, channel, height, width = hidden_states.shape
444 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
445 |
446 | batch_size, sequence_length, _ = (
447 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
448 | )
449 |
450 | if attention_mask is not None:
451 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
452 | # scaled_dot_product_attention expects attention_mask shape to be
453 | # (batch, heads, source_length, target_length)
454 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
455 |
456 | if attn.group_norm is not None:
457 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
458 |
459 | query = attn.to_q(hidden_states)
460 |
461 | if encoder_hidden_states is None:
462 | encoder_hidden_states = hidden_states
463 | else:
464 | # get encoder_hidden_states, ip_hidden_states
465 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
466 | encoder_hidden_states, ip_hidden_states = (
467 | encoder_hidden_states[:, :end_pos, :],
468 | encoder_hidden_states[:, end_pos:, :],
469 | )
470 | if attn.norm_cross:
471 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
472 |
473 | key = attn.to_k(encoder_hidden_states)
474 | value = attn.to_v(encoder_hidden_states)
475 |
476 | inner_dim = key.shape[-1]
477 | head_dim = inner_dim // attn.heads
478 |
479 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
480 |
481 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
482 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
483 |
484 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
485 | # TODO: add support for attn.scale when we move to Torch 2.1
486 | hidden_states = F.scaled_dot_product_attention(
487 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
488 | )
489 |
490 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
491 | hidden_states = hidden_states.to(query.dtype)
492 |
493 | # for ip-adapter
494 | ip_key = self.to_k_ip(ip_hidden_states)
495 | ip_value = self.to_v_ip(ip_hidden_states)
496 |
497 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
498 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
499 |
500 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
501 | # TODO: add support for attn.scale when we move to Torch 2.1
502 | ip_hidden_states = F.scaled_dot_product_attention(
503 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
504 | )
505 |
506 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
507 | ip_hidden_states = ip_hidden_states.to(query.dtype)
508 |
509 | hidden_states = hidden_states + self.scale * ip_hidden_states
510 |
511 | # linear proj
512 | hidden_states = attn.to_out[0](hidden_states)
513 | # dropout
514 | hidden_states = attn.to_out[1](hidden_states)
515 |
516 | if input_ndim == 4:
517 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
518 |
519 | if attn.residual_connection:
520 | hidden_states = hidden_states + residual
521 |
522 | hidden_states = hidden_states / attn.rescale_output_factor
523 |
524 | return hidden_states
525 |
--------------------------------------------------------------------------------
/ip_adapter/mlp.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | class MLPFeatureProjModel(nn.Module):
8 | """SD model with feature prompt"""
9 |
10 | # torch.Size([1, 49, 1536]) -> torch.Size([1, 49, 768])
11 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=1536):
12 | super().__init__()
13 |
14 | self.cross_attention_dim = cross_attention_dim
15 |
16 | self.proj = torch.nn.Sequential(
17 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim),
18 | torch.nn.GELU(),
19 | torch.nn.Linear(id_embeddings_dim, cross_attention_dim),
20 | torch.nn.LayerNorm(cross_attention_dim)
21 | )
22 |
23 | def forward(self, id_embeds):
24 | feature_tokens = self.proj(id_embeds)
25 |
26 | return feature_tokens
--------------------------------------------------------------------------------
/ip_adapter/resampler.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2 | import math
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | # FFN
9 | def FeedForward(dim, mult=4):
10 | inner_dim = int(dim * mult)
11 | return nn.Sequential(
12 | nn.LayerNorm(dim),
13 | nn.Linear(dim, inner_dim, bias=False),
14 | nn.GELU(),
15 | nn.Linear(inner_dim, dim, bias=False),
16 | )
17 |
18 |
19 | def reshape_tensor(x, heads):
20 | bs, length, width = x.shape
21 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
22 | x = x.view(bs, length, heads, -1)
23 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
24 | x = x.transpose(1, 2)
25 | # shape is already (bs, n_heads, length, dim_per_head); this reshape leaves it unchanged
26 | x = x.reshape(bs, heads, length, -1)
27 | return x
28 |
29 |
30 | class PerceiverAttention(nn.Module):
31 | def __init__(self, *, dim, dim_head=64, heads=8):
32 | super().__init__()
33 | self.scale = dim_head**-0.5
34 | self.dim_head = dim_head
35 | self.heads = heads
36 | inner_dim = dim_head * heads
37 |
38 | self.norm1 = nn.LayerNorm(dim)
39 | self.norm2 = nn.LayerNorm(dim)
40 |
41 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
42 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
43 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
44 |
45 |
46 | def forward(self, x, latents):
47 | """
48 | Args:
49 | x (torch.Tensor): image features
50 | shape (b, n1, D)
51 | latent (torch.Tensor): latent features
52 | shape (b, n2, D)
53 | """
54 | x = self.norm1(x)
55 | latents = self.norm2(latents)
56 |
57 | b, l, _ = latents.shape
58 |
59 | q = self.to_q(latents)
60 | kv_input = torch.cat((x, latents), dim=-2)
61 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
62 |
63 | q = reshape_tensor(q, self.heads)
64 | k = reshape_tensor(k, self.heads)
65 | v = reshape_tensor(v, self.heads)
66 |
67 | # attention
68 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
69 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
70 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
71 | out = weight @ v
72 |
73 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
74 |
75 | return self.to_out(out)
76 |
77 |
78 | class Resampler(nn.Module):
79 | def __init__(
80 | self,
81 | dim=1024,
82 | depth=8,
83 | dim_head=64,
84 | heads=16,
85 | num_queries=8,
86 | embedding_dim=768,
87 | output_dim=1024,
88 | ff_mult=4,
89 | ):
90 | super().__init__()
91 |
92 | self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
93 |
94 | self.proj_in = nn.Linear(embedding_dim, dim)
95 |
96 | self.proj_out = nn.Linear(dim, output_dim)
97 | self.norm_out = nn.LayerNorm(output_dim)
98 |
99 | self.layers = nn.ModuleList([])
100 | for _ in range(depth):
101 | self.layers.append(
102 | nn.ModuleList(
103 | [
104 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
105 | FeedForward(dim=dim, mult=ff_mult),
106 | ]
107 | )
108 | )
109 |
110 | def forward(self, x):
111 |
112 | latents = self.latents.repeat(x.size(0), 1, 1)
113 |
114 | x = self.proj_in(x)
115 |
116 | for attn, ff in self.layers:
117 | latents = attn(x, latents) + latents
118 | latents = ff(latents) + latents
119 |
120 | latents = self.proj_out(latents)
121 | return self.norm_out(latents)
122 |
123 |
--------------------------------------------------------------------------------
/ip_adapter/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 |
3 |
4 | def is_torch2_available():
5 | return hasattr(F, "scaled_dot_product_attention")
6 |
--------------------------------------------------------------------------------
/pipeline_stable_diffusion_xl_instantid.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The InstantX Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17 |
18 | import cv2
19 | import math
20 |
21 | import numpy as np
22 | import PIL.Image
23 | import torch
24 | import torch.nn.functional as F
25 |
26 | from diffusers.image_processor import PipelineImageInput
27 |
28 | from diffusers.models import ControlNetModel
29 |
30 | from diffusers.utils import (
31 | deprecate,
32 | logging,
33 | replace_example_docstring,
34 | )
35 | from diffusers.utils.torch_utils import is_compiled_module, is_torch_version
36 | from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
37 |
38 | from diffusers import StableDiffusionXLControlNetPipeline
39 | from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
40 | from diffusers.utils.import_utils import is_xformers_available
41 |
42 | from ip_adapter.resampler import Resampler
43 |
44 | from ip_adapter.utils import is_torch2_available
45 |
46 | if is_torch2_available():
47 | from ip_adapter.attention_processor import (
48 | AttnProcessor2_0 as AttnProcessor,
49 | )
50 | from ip_adapter.attention_processor import (
51 | IPAttnProcessor2_0 as IPAttnProcessor,
52 | )
53 | else:
54 | from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
55 |
56 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57 |
58 |
59 | EXAMPLE_DOC_STRING = """
60 | Examples:
61 | ```py
62 | >>> # !pip install opencv-python transformers accelerate insightface
63 | >>> import diffusers
64 | >>> from diffusers.utils import load_image
65 | >>> from diffusers.models import ControlNetModel
66 |
67 | >>> import cv2
68 | >>> import torch
69 | >>> import numpy as np
70 | >>> from PIL import Image
71 |
72 | >>> from insightface.app import FaceAnalysis
73 | >>> from pipeline_stable_diffusion_xl_instantid import StableDiffusionXLInstantIDPipeline, draw_kps
74 |
75 | >>> # download 'antelopev2' under ./models
76 | >>> app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
77 | >>> app.prepare(ctx_id=0, det_size=(640, 640))
78 |
79 | >>> # download models under ./checkpoints
80 | >>> face_adapter = f'./checkpoints/ip-adapter.bin'
81 | >>> controlnet_path = f'./checkpoints/ControlNetModel'
82 |
83 | >>> # load IdentityNet
84 | >>> controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
85 |
86 | >>> pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
87 | ... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
88 | ... )
89 | >>> pipe.cuda()
90 |
91 | >>> # load adapter
92 | >>> pipe.load_ip_adapter_instantid(face_adapter)
93 |
94 | >>> prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
95 | >>> negative_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured (lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch,deformed, mutated, cross-eyed, ugly, disfigured"
96 |
97 | >>> # load an image
98 | >>> face_image = load_image("your-example.jpg")
99 |
100 | >>> face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))[-1]
101 | >>> face_emb = face_info['embedding']
102 | >>> face_kps = draw_kps(face_image, face_info['kps'])
103 |
104 | >>> pipe.set_ip_adapter_scale(0.8)
105 |
106 | >>> # generate image
107 | >>> image = pipe(
108 | ... prompt, image_embeds=face_emb, image=face_kps, controlnet_conditioning_scale=0.8
109 | ... ).images[0]
110 | ```
111 |
112 | """
113 |
114 | def draw_kps(image_pil, kps, color_list=[(255,0,0), (0,255,0), (0,0,255), (255,255,0), (255,0,255)]):
115 |
116 | stickwidth = 4
117 | limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
118 | kps = np.array(kps)
119 |
120 | w, h = image_pil.size
121 | out_img = np.zeros([h, w, 3])
122 |
123 | for i in range(len(limbSeq)):
124 | index = limbSeq[i]
125 | color = color_list[index[0]]
126 |
127 | x = kps[index][:, 0]
128 | y = kps[index][:, 1]
129 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
130 | angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
131 | polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
132 | out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
133 | out_img = (out_img * 0.6).astype(np.uint8)
134 |
135 | for idx_kp, kp in enumerate(kps):
136 | color = color_list[idx_kp]
137 | x, y = kp
138 | out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
139 |
140 | out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
141 | return out_img_pil
142 |
143 | class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
144 |
145 | def cuda(self, dtype=torch.float16, use_xformers=False):
146 | self.to('cuda', dtype)
147 |
148 | if hasattr(self, 'image_proj_model'):
149 | self.image_proj_model.to(self.unet.device).to(self.unet.dtype)
150 |
151 | if use_xformers:
152 | if is_xformers_available():
153 | import xformers
154 | from packaging import version
155 |
156 | xformers_version = version.parse(xformers.__version__)
157 | if xformers_version == version.parse("0.0.16"):
158 | logger.warn(
159 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
160 | )
161 | self.enable_xformers_memory_efficient_attention()
162 | else:
163 | raise ValueError("xformers is not available. Make sure it is installed correctly")
164 |
165 | def load_ip_adapter_instantid(self, model_ckpt, image_emb_dim=512, num_tokens=16, scale=0.5):
166 | self.set_image_proj_model(model_ckpt, image_emb_dim, num_tokens)
167 | self.set_ip_adapter(model_ckpt, num_tokens, scale)
168 |
169 | def set_image_proj_model(self, model_ckpt, image_emb_dim=512, num_tokens=16):
170 |
171 | image_proj_model = Resampler(
172 | dim=1280,
173 | depth=4,
174 | dim_head=64,
175 | heads=20,
176 | num_queries=num_tokens,
177 | embedding_dim=image_emb_dim,
178 | output_dim=self.unet.config.cross_attention_dim,
179 | ff_mult=4,
180 | )
181 |
182 | image_proj_model.eval()
183 |
184 | self.image_proj_model = image_proj_model.to(self.device, dtype=self.dtype)
185 | state_dict = torch.load(model_ckpt, map_location="cpu")
186 | if 'image_proj' in state_dict:
187 | state_dict = state_dict["image_proj"]
188 | self.image_proj_model.load_state_dict(state_dict)
189 |
190 | self.image_proj_model_in_features = image_emb_dim
191 |
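# Replace the UNet's attention processors: self-attention layers keep the plain AttnProcessor, while
# cross-attention layers get IPAttnProcessor so the face tokens can be injected; the processors are
# then initialized from the `ip_adapter` entry of the checkpoint.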
192 | def set_ip_adapter(self, model_ckpt, num_tokens, scale):
193 |
194 | unet = self.unet
195 | attn_procs = {}
196 | for name in unet.attn_processors.keys():
197 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
198 | if name.startswith("mid_block"):
199 | hidden_size = unet.config.block_out_channels[-1]
200 | elif name.startswith("up_blocks"):
201 | block_id = int(name[len("up_blocks.")])
202 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
203 | elif name.startswith("down_blocks"):
204 | block_id = int(name[len("down_blocks.")])
205 | hidden_size = unet.config.block_out_channels[block_id]
206 | if cross_attention_dim is None:
207 | attn_procs[name] = AttnProcessor().to(unet.device, dtype=unet.dtype)
208 | else:
209 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size,
210 | cross_attention_dim=cross_attention_dim,
211 | scale=scale,
212 | num_tokens=num_tokens).to(unet.device, dtype=unet.dtype)
213 | unet.set_attn_processor(attn_procs)
214 |
215 | state_dict = torch.load(model_ckpt, map_location="cpu")
216 | ip_layers = torch.nn.ModuleList(self.unet.attn_processors.values())
217 | if 'ip_adapter' in state_dict:
218 | state_dict = state_dict['ip_adapter']
219 | ip_layers.load_state_dict(state_dict)
220 |
221 | def set_ip_adapter_scale(self, scale):
222 | unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
223 | for attn_processor in unet.attn_processors.values():
224 | if isinstance(attn_processor, IPAttnProcessor):
225 | attn_processor.scale = scale
226 |
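# Project the face embedding into `num_tokens` image-prompt tokens with the Resampler; under
# classifier-free guidance a zero embedding is prepended as the unconditional branch.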
227 | def _encode_prompt_image_emb(self, prompt_image_emb, device, dtype, do_classifier_free_guidance):
228 |
229 | if isinstance(prompt_image_emb, torch.Tensor):
230 | prompt_image_emb = prompt_image_emb.clone().detach()
231 | else:
232 | prompt_image_emb = torch.tensor(prompt_image_emb)
233 |
234 | prompt_image_emb = prompt_image_emb.to(device=device, dtype=dtype)
235 | prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
236 |
237 | if do_classifier_free_guidance:
238 | prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
239 | else:
240 | prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
241 |
242 | prompt_image_emb = self.image_proj_model(prompt_image_emb)
243 | return prompt_image_emb
244 |
245 | @torch.no_grad()
246 | @replace_example_docstring(EXAMPLE_DOC_STRING)
247 | def __call__(
248 | self,
249 | prompt: Union[str, List[str]] = None,
250 | prompt_2: Optional[Union[str, List[str]]] = None,
251 | image: PipelineImageInput = None,
252 | height: Optional[int] = None,
253 | width: Optional[int] = None,
254 | num_inference_steps: int = 50,
255 | guidance_scale: float = 5.0,
256 | negative_prompt: Optional[Union[str, List[str]]] = None,
257 | negative_prompt_2: Optional[Union[str, List[str]]] = None,
258 | num_images_per_prompt: Optional[int] = 1,
259 | eta: float = 0.0,
260 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
261 | latents: Optional[torch.FloatTensor] = None,
262 | prompt_embeds: Optional[torch.FloatTensor] = None,
263 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
264 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
265 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
266 | image_embeds: Optional[torch.FloatTensor] = None,
267 | output_type: Optional[str] = "pil",
268 | return_dict: bool = True,
269 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
270 | controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
271 | guess_mode: bool = False,
272 | control_guidance_start: Union[float, List[float]] = 0.0,
273 | control_guidance_end: Union[float, List[float]] = 1.0,
274 | original_size: Tuple[int, int] = None,
275 | crops_coords_top_left: Tuple[int, int] = (0, 0),
276 | target_size: Tuple[int, int] = None,
277 | negative_original_size: Optional[Tuple[int, int]] = None,
278 | negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
279 | negative_target_size: Optional[Tuple[int, int]] = None,
280 | clip_skip: Optional[int] = None,
281 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
282 | callback_on_step_end_tensor_inputs: List[str] = ["latents"],
283 | **kwargs,
284 | ):
285 | r"""
286 | The call function to the pipeline for generation.
287 |
288 | Args:
289 | prompt (`str` or `List[str]`, *optional*):
290 | The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
291 | prompt_2 (`str` or `List[str]`, *optional*):
292 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
293 | used in both text-encoders.
294 | image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,:
295 | `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
296 | The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
297 | specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
298 | accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If height
299 | and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
300 | `init`, images must be passed as a list such that each element of the list can be correctly batched for
301 | input to a single ControlNet.
302 | height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
303 | The height in pixels of the generated image. Anything below 512 pixels won't work well for
304 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
305 | and checkpoints that are not specifically fine-tuned on low resolutions.
306 | width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
307 | The width in pixels of the generated image. Anything below 512 pixels won't work well for
308 | [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
309 | and checkpoints that are not specifically fine-tuned on low resolutions.
310 | num_inference_steps (`int`, *optional*, defaults to 50):
311 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
312 | expense of slower inference.
313 | guidance_scale (`float`, *optional*, defaults to 5.0):
314 | A higher guidance scale value encourages the model to generate images closely linked to the text
315 | `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
316 | negative_prompt (`str` or `List[str]`, *optional*):
317 | The prompt or prompts to guide what to not include in image generation. If not defined, you need to
318 | pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
319 | negative_prompt_2 (`str` or `List[str]`, *optional*):
320 | The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
321 | and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
322 | num_images_per_prompt (`int`, *optional*, defaults to 1):
323 | The number of images to generate per prompt.
324 | eta (`float`, *optional*, defaults to 0.0):
325 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
326 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
327 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
328 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
329 | generation deterministic.
330 | latents (`torch.FloatTensor`, *optional*):
331 | Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
332 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
333 | tensor is generated by sampling using the supplied random `generator`.
334 | prompt_embeds (`torch.FloatTensor`, *optional*):
335 | Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
336 | provided, text embeddings are generated from the `prompt` input argument.
337 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
338 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
339 | not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
340 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
341 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
342 | not provided, pooled text embeddings are generated from `prompt` input argument.
343 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
344 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
345 | weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
346 | argument.
347 | image_embeds (`torch.FloatTensor`, *optional*):
348 | Pre-generated image embeddings. For this pipeline, the InsightFace face embedding of the reference face.
349 | output_type (`str`, *optional*, defaults to `"pil"`):
350 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
351 | return_dict (`bool`, *optional*, defaults to `True`):
352 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
353 | plain tuple.
354 | cross_attention_kwargs (`dict`, *optional*):
355 | A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
356 | [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
357 | controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
358 | The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
359 | to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
360 | the corresponding scale as a list.
361 | guess_mode (`bool`, *optional*, defaults to `False`):
362 | The ControlNet encoder tries to recognize the content of the input image even if you remove all
363 | prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
364 | control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
365 | The percentage of total steps at which the ControlNet starts applying.
366 | control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
367 | The percentage of total steps at which the ControlNet stops applying.
368 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
369 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
370 | `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
371 | explained in section 2.2 of
372 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
373 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
374 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
375 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
376 | `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
377 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
378 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
379 | For most cases, `target_size` should be set to the desired height and width of the generated image. If
380 | not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
381 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
382 | negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
383 | To negatively condition the generation process based on a specific image resolution. Part of SDXL's
384 | micro-conditioning as explained in section 2.2 of
385 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
386 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
387 | negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
388 | To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
389 | micro-conditioning as explained in section 2.2 of
390 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
391 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
392 | negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
393 | To negatively condition the generation process based on a target image resolution. It should be the same
394 | as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
395 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
396 | information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
397 | clip_skip (`int`, *optional*):
398 | Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
399 | the output of the pre-final layer will be used for computing the prompt embeddings.
400 | callback_on_step_end (`Callable`, *optional*):
401 | A function that is called at the end of each denoising step during inference. The function is called
402 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
403 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
404 | `callback_on_step_end_tensor_inputs`.
405 | callback_on_step_end_tensor_inputs (`List`, *optional*):
406 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
407 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
408 | `._callback_tensor_inputs` attribute of your pipeline class.
409 |
410 | Examples:
411 |
412 | Returns:
413 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
414 | If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
415 | otherwise a `tuple` is returned containing the output images.
416 | """
417 |
418 | callback = kwargs.pop("callback", None)
419 | callback_steps = kwargs.pop("callback_steps", None)
420 |
421 | if callback is not None:
422 | deprecate(
423 | "callback",
424 | "1.0.0",
425 | "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
426 | )
427 | if callback_steps is not None:
428 | deprecate(
429 | "callback_steps",
430 | "1.0.0",
431 | "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
432 | )
433 |
434 | controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
435 |
436 | # align format for control guidance
437 | if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
438 | control_guidance_start = len(control_guidance_end) * [control_guidance_start]
439 | elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
440 | control_guidance_end = len(control_guidance_start) * [control_guidance_end]
441 | elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
442 | mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
443 | control_guidance_start, control_guidance_end = (
444 | mult * [control_guidance_start],
445 | mult * [control_guidance_end],
446 | )
447 |
448 | # 1. Check inputs. Raise error if not correct
449 | self.check_inputs(
450 | prompt,
451 | prompt_2,
452 | image,
453 | callback_steps,
454 | negative_prompt,
455 | negative_prompt_2,
456 | prompt_embeds,
457 | negative_prompt_embeds,
458 | pooled_prompt_embeds,
459 | negative_pooled_prompt_embeds,
460 | controlnet_conditioning_scale,
461 | control_guidance_start,
462 | control_guidance_end,
463 | callback_on_step_end_tensor_inputs,
464 | )
465 |
466 | self._guidance_scale = guidance_scale
467 | self._clip_skip = clip_skip
468 | self._cross_attention_kwargs = cross_attention_kwargs
469 |
470 | # 2. Define call parameters
471 | if prompt is not None and isinstance(prompt, str):
472 | batch_size = 1
473 | elif prompt is not None and isinstance(prompt, list):
474 | batch_size = len(prompt)
475 | else:
476 | batch_size = prompt_embeds.shape[0]
477 |
478 | device = self._execution_device
479 |
480 | if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
481 | controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
482 |
483 | global_pool_conditions = (
484 | controlnet.config.global_pool_conditions
485 | if isinstance(controlnet, ControlNetModel)
486 | else controlnet.nets[0].config.global_pool_conditions
487 | )
488 | guess_mode = guess_mode or global_pool_conditions
489 |
490 | # 3.1 Encode input prompt
491 | text_encoder_lora_scale = (
492 | self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
493 | )
494 | (
495 | prompt_embeds,
496 | negative_prompt_embeds,
497 | pooled_prompt_embeds,
498 | negative_pooled_prompt_embeds,
499 | ) = self.encode_prompt(
500 | prompt,
501 | prompt_2,
502 | device,
503 | num_images_per_prompt,
504 | self.do_classifier_free_guidance,
505 | negative_prompt,
506 | negative_prompt_2,
507 | prompt_embeds=prompt_embeds,
508 | negative_prompt_embeds=negative_prompt_embeds,
509 | pooled_prompt_embeds=pooled_prompt_embeds,
510 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
511 | lora_scale=text_encoder_lora_scale,
512 | clip_skip=self.clip_skip,
513 | )
514 |
515 | # 3.2 Encode image prompt
516 | prompt_image_emb = self._encode_prompt_image_emb(image_embeds,
517 | device,
518 | self.unet.dtype,
519 | self.do_classifier_free_guidance)
520 |
521 | # 4. Prepare image
522 | if isinstance(controlnet, ControlNetModel):
523 | image = self.prepare_image(
524 | image=image,
525 | width=width,
526 | height=height,
527 | batch_size=batch_size * num_images_per_prompt,
528 | num_images_per_prompt=num_images_per_prompt,
529 | device=device,
530 | dtype=controlnet.dtype,
531 | do_classifier_free_guidance=self.do_classifier_free_guidance,
532 | guess_mode=guess_mode,
533 | )
534 | height, width = image.shape[-2:]
535 | print("image: ", image.shape)
536 | elif isinstance(controlnet, MultiControlNetModel):
537 | images = []
538 |
539 | for image_ in image:
540 | image_ = self.prepare_image(
541 | image=image_,
542 | width=width,
543 | height=height,
544 | batch_size=batch_size * num_images_per_prompt,
545 | num_images_per_prompt=num_images_per_prompt,
546 | device=device,
547 | dtype=controlnet.dtype,
548 | do_classifier_free_guidance=self.do_classifier_free_guidance,
549 | guess_mode=guess_mode,
550 | )
551 |
552 | images.append(image_)
553 |
554 | image = images
555 | height, width = image[0].shape[-2:]
556 | else:
557 | assert False
558 |
559 | # 5. Prepare timesteps
560 | self.scheduler.set_timesteps(num_inference_steps, device=device)
561 | timesteps = self.scheduler.timesteps
562 | self._num_timesteps = len(timesteps)
563 |
564 | # 6. Prepare latent variables
565 | num_channels_latents = self.unet.config.in_channels
566 | latents = self.prepare_latents(
567 | batch_size * num_images_per_prompt,
568 | num_channels_latents,
569 | height,
570 | width,
571 | prompt_embeds.dtype,
572 | device,
573 | generator,
574 | latents,
575 | )
576 |
577 | # 6.5 Optionally get Guidance Scale Embedding
578 | timestep_cond = None
579 | if self.unet.config.time_cond_proj_dim is not None:
580 | guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
581 | timestep_cond = self.get_guidance_scale_embedding(
582 | guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
583 | ).to(device=device, dtype=latents.dtype)
584 |
585 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
586 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
587 |
588 | # 7.1 Create tensor stating which controlnets to keep
589 | controlnet_keep = []
590 | for i in range(len(timesteps)):
591 | keeps = [
592 | 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
593 | for s, e in zip(control_guidance_start, control_guidance_end)
594 | ]
595 | controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
596 |
597 | # 7.2 Prepare added time ids & embeddings
598 | if isinstance(image, list):
599 | original_size = original_size or image[0].shape[-2:]
600 | else:
601 | original_size = original_size or image.shape[-2:]
602 | target_size = target_size or (height, width)
603 |
604 | add_text_embeds = pooled_prompt_embeds
605 | if self.text_encoder_2 is None:
606 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
607 | else:
608 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
609 |
610 | add_time_ids = self._get_add_time_ids(
611 | original_size,
612 | crops_coords_top_left,
613 | target_size,
614 | dtype=prompt_embeds.dtype,
615 | text_encoder_projection_dim=text_encoder_projection_dim,
616 | )
617 |
618 | if negative_original_size is not None and negative_target_size is not None:
619 | negative_add_time_ids = self._get_add_time_ids(
620 | negative_original_size,
621 | negative_crops_coords_top_left,
622 | negative_target_size,
623 | dtype=prompt_embeds.dtype,
624 | text_encoder_projection_dim=text_encoder_projection_dim,
625 | )
626 | else:
627 | negative_add_time_ids = add_time_ids
628 |
629 | if self.do_classifier_free_guidance:
630 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
631 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
632 | add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
633 |
634 | prompt_embeds = prompt_embeds.to(device)
635 | add_text_embeds = add_text_embeds.to(device)
636 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
637 | encoder_hidden_states = torch.cat([prompt_embeds, prompt_image_emb], dim=1)
638 |
639 | # 8. Denoising loop
640 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
641 | is_unet_compiled = is_compiled_module(self.unet)
642 | is_controlnet_compiled = is_compiled_module(self.controlnet)
643 | is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
644 |
645 | with self.progress_bar(total=num_inference_steps) as progress_bar:
646 | for i, t in enumerate(timesteps):
647 | # Relevant thread:
648 | # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
649 | if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
650 | torch._inductor.cudagraph_mark_step_begin()
651 | # expand the latents if we are doing classifier free guidance
652 | latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
653 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
654 |
655 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
656 |
657 | # controlnet(s) inference
658 | if guess_mode and self.do_classifier_free_guidance:
659 | # Infer ControlNet only for the conditional batch.
660 | control_model_input = latents
661 | control_model_input = self.scheduler.scale_model_input(control_model_input, t)
662 | controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
663 | controlnet_added_cond_kwargs = {
664 | "text_embeds": add_text_embeds.chunk(2)[1],
665 | "time_ids": add_time_ids.chunk(2)[1],
666 | }
667 | else:
668 | control_model_input = latent_model_input
669 | controlnet_prompt_embeds = prompt_embeds
670 | controlnet_added_cond_kwargs = added_cond_kwargs
671 |
672 | if isinstance(controlnet_keep[i], list):
673 | cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
674 | else:
675 | controlnet_cond_scale = controlnet_conditioning_scale
676 | if isinstance(controlnet_cond_scale, list):
677 | controlnet_cond_scale = controlnet_cond_scale[0]
678 | cond_scale = controlnet_cond_scale * controlnet_keep[i]
679 |
680 | down_block_res_samples, mid_block_res_sample = self.controlnet(
681 | control_model_input,
682 | t,
683 | encoder_hidden_states=prompt_image_emb,
684 | controlnet_cond=image,
685 | conditioning_scale=cond_scale,
686 | guess_mode=guess_mode,
687 | added_cond_kwargs=controlnet_added_cond_kwargs,
688 | return_dict=False,
689 | )
690 |
691 | if guess_mode and self.do_classifier_free_guidance:
692 | # Inferred ControlNet only for the conditional batch.
693 | # To apply the output of ControlNet to both the unconditional and conditional batches,
694 | # add 0 to the unconditional batch to keep it unchanged.
695 | down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
696 | mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
697 |
698 | # predict the noise residual
699 | noise_pred = self.unet(
700 | latent_model_input,
701 | t,
702 | encoder_hidden_states=encoder_hidden_states,
703 | timestep_cond=timestep_cond,
704 | cross_attention_kwargs=self.cross_attention_kwargs,
705 | down_block_additional_residuals=down_block_res_samples,
706 | mid_block_additional_residual=mid_block_res_sample,
707 | added_cond_kwargs=added_cond_kwargs,
708 | return_dict=False,
709 | )[0]
710 |
711 | # perform guidance
712 | if self.do_classifier_free_guidance:
713 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
714 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
715 |
716 | # compute the previous noisy sample x_t -> x_t-1
717 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
718 |
719 | if callback_on_step_end is not None:
720 | callback_kwargs = {}
721 | for k in callback_on_step_end_tensor_inputs:
722 | callback_kwargs[k] = locals()[k]
723 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
724 |
725 | latents = callback_outputs.pop("latents", latents)
726 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
727 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
728 |
729 | # call the callback, if provided
730 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
731 | progress_bar.update()
732 | if callback is not None and i % callback_steps == 0:
733 | step_idx = i // getattr(self.scheduler, "order", 1)
734 | callback(step_idx, t, latents)
735 |
736 | if not output_type == "latent":
737 | # make sure the VAE is in float32 mode, as it overflows in float16
738 | needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
739 | if needs_upcasting:
740 | self.upcast_vae()
741 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
742 |
743 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
744 |
745 | # cast back to fp16 if needed
746 | if needs_upcasting:
747 | self.vae.to(dtype=torch.float16)
748 | else:
749 | image = latents
750 |
751 | if not output_type == "latent":
752 | # apply watermark if available
753 | if self.watermark is not None:
754 | image = self.watermark.apply_watermark(image)
755 |
756 | image = self.image_processor.postprocess(image, output_type=output_type)
757 |
758 | # Offload all models
759 | self.maybe_free_model_hooks()
760 |
761 | if not return_dict:
762 | return (image,)
763 |
764 | return StableDiffusionXLPipelineOutput(images=image)
--------------------------------------------------------------------------------
/train_instantId_sdxl.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import random
4 | import argparse
5 | from pathlib import Path
6 | import json
7 | import itertools
8 | import time
9 | from datetime import datetime
10 | import shutil
11 | import torch
12 | import torch.nn.functional as F
13 | import numpy as np
14 | import math
15 | import cv2
16 | from torchvision import transforms
17 | from PIL import Image
18 | import PIL
19 | from transformers import CLIPImageProcessor
20 | from accelerate import Accelerator
21 | from accelerate.logging import get_logger
22 | from accelerate.utils import ProjectConfiguration
23 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, ControlNetModel
24 | from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
25 |
26 | from ip_adapter.resampler import Resampler
27 | from ip_adapter.utils import is_torch2_available
28 |
29 | if is_torch2_available():
30 | from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
31 | else:
32 | from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor
33 |
34 |
35 | # Draw the input image for controlnet based on facial keypoints.
36 | def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
37 | stickwidth = 4
38 | limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
39 | kps = np.array(kps)
40 |
41 | w, h = image_pil.size
42 | out_img = np.zeros([h, w, 3])
43 |
44 | for i in range(len(limbSeq)):
45 | index = limbSeq[i]
46 | color = color_list[index[0]]
47 |
48 | x = kps[index][:, 0]
49 | y = kps[index][:, 1]
50 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
51 | angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
52 | polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
53 | 360, 1)
54 | out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
55 | out_img = (out_img * 0.6).astype(np.uint8)
56 |
57 | for idx_kp, kp in enumerate(kps):
58 | color = color_list[idx_kp]
59 | x, y = kp
60 | out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
61 |
62 | out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
63 | return out_img_pil
64 |
65 | # Dataset that loads records from a JSONL file; each record provides the image path, a text caption, the InsightFace feature file, the face bounding box, and the facial keypoint coordinates.
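# Illustrative (hypothetical) example of one JSONL record consumed below; the exact keys match
# __getitem__, but the paths and values depend on how the dataset was preprocessed:
# {"image_file": "person_0001/img_03.jpg", "additional_feature": "a woman smiling, outdoors",
#  "bbox": [x1, y1, x2, y2], "landmarks": [[x1, y1], ..., [x5, y5]],
#  "insightface_feature_file": "person_0001/img_03_emb.npy"}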
66 | class MyDataset(torch.utils.data.Dataset):
67 |
68 | def __init__(self, json_file, tokenizer, tokenizer_2, size=1024, center_crop=True,
69 | t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""):
70 | super().__init__()
71 |
72 | self.tokenizer = tokenizer
73 | self.tokenizer_2 = tokenizer_2
74 | self.size = size
75 | self.center_crop = center_crop
76 | self.i_drop_rate = i_drop_rate
77 | self.t_drop_rate = t_drop_rate
78 | self.ti_drop_rate = ti_drop_rate
79 | self.image_root_path = image_root_path
80 |
81 | self.data = []
82 | with open(json_file, 'r') as f:
83 | for line in f:
84 | self.data.append(json.loads(line))
85 |
86 | self.image_transforms = transforms.Compose(
87 | [
88 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
89 | transforms.ToTensor(),
90 | transforms.Normalize([0.5], [0.5]),
91 | ]
92 | )
93 |
94 | self.conditioning_image_transforms = transforms.Compose(
95 | [
96 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
97 | transforms.ToTensor(),
98 | ]
99 | )
100 |
101 | self.clip_image_processor = CLIPImageProcessor()
102 |
103 | def __getitem__(self, idx):
104 | item = self.data[idx]
105 | image_file = item["image_file"]
106 | text = item["additional_feature"]
107 | bbox = item['bbox']
108 | landmarks = item['landmarks']
109 | feature_file = item["insightface_feature_file"]
110 |
111 | # read image
112 | raw_image = Image.open(os.path.join(self.image_root_path, image_file))
113 | # draw keypoints
114 | kps_image = draw_kps(raw_image.convert("RGB"), landmarks)
115 |
116 | # original size
117 | original_width, original_height = raw_image.size
118 | original_size = torch.tensor([original_height, original_width])
119 |
120 | # transform raw_image and kps_image
121 | image_tensor = self.image_transforms(raw_image.convert("RGB"))
122 | kps_image_tensor = self.conditioning_image_transforms(kps_image)
123 |
124 | # random crop
125 | delta_h = image_tensor.shape[1] - self.size
126 | delta_w = image_tensor.shape[2] - self.size
127 | assert not all([delta_h, delta_w])
128 |
129 | if self.center_crop:
130 | top = delta_h // 2
131 | left = delta_w // 2
132 | else:
133 | top = np.random.randint(0, delta_h // 2 + 1) # random top crop
134 | # top = np.random.randint(0, delta_h + 1) # random crop
135 | left = np.random.randint(0, delta_w + 1) # random crop
136 |
137 | # The image and kps_image must follow the same cropping to ensure that the facial coordinates correspond correctly.
138 | image = transforms.functional.crop(
139 | image_tensor, top=top, left=left, height=self.size, width=self.size
140 | )
141 | kps_image = transforms.functional.crop(
142 | kps_image_tensor, top=top, left=left, height=self.size, width=self.size
143 | )
144 |
145 | crop_coords_top_left = torch.tensor([top, left])
146 |
147 | # load face feature
148 | # face_id_embed = torch.load(os.path.join(self.image_root_path, feature_file), map_location="cpu")
149 | face_id_embed = np.load(os.path.join(self.image_root_path, feature_file))
150 | face_id_embed = torch.from_numpy(face_id_embed)
151 | face_id_embed = face_id_embed.reshape(1, -1)
152 |
153 | # set cfg drop rate
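# With probability i_drop_rate only the face embedding is dropped, with probability t_drop_rate only the
# text is dropped, and with probability ti_drop_rate both are dropped, so the model also learns the
# unconditional branches needed for classifier-free guidance at inference time.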
154 | drop_feature_embed = 0
155 | drop_text_embed = 0
156 | rand_num = random.random()
157 | if rand_num < self.i_drop_rate:
158 | drop_feature_embed = 1
159 | elif rand_num < (self.i_drop_rate + self.t_drop_rate):
160 | drop_text_embed = 1
161 | elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
162 | drop_text_embed = 1
163 | drop_feature_embed = 1
164 |
165 | # CFG process
166 | if drop_text_embed:
167 | text = ""
168 | if drop_feature_embed:
169 | face_id_embed = torch.zeros_like(face_id_embed)
170 |
171 | # get text and tokenize
172 | text_input_ids = self.tokenizer(
173 | text,
174 | max_length=self.tokenizer.model_max_length,
175 | padding="max_length",
176 | truncation=True,
177 | return_tensors="pt"
178 | ).input_ids
179 |
180 | text_input_ids_2 = self.tokenizer_2(
181 | text,
182 | max_length=self.tokenizer_2.model_max_length,
183 | padding="max_length",
184 | truncation=True,
185 | return_tensors="pt"
186 | ).input_ids
187 |
188 | return {
189 | "image": image,
190 | "kps_image": kps_image,
191 | "text_input_ids": text_input_ids,
192 | "text_input_ids_2": text_input_ids_2,
193 | "face_id_embed": face_id_embed,
194 | "original_size": original_size,
195 | "crop_coords_top_left": crop_coords_top_left,
196 | "target_size": torch.tensor([self.size, self.size]),
197 | }
198 |
199 | def __len__(self):
200 | return len(self.data)
201 |
202 |
203 | def collate_fn(data):
204 | images = torch.stack([example["image"] for example in data])
205 | kps_images = torch.stack([example["kps_image"] for example in data])
206 |
207 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
208 | text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0)
209 | face_id_embed = torch.stack([example["face_id_embed"] for example in data])
210 | original_size = torch.stack([example["original_size"] for example in data])
211 | crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data])
212 | target_size = torch.stack([example["target_size"] for example in data])
213 |
214 | return {
215 | "images": images,
216 | "kps_images": kps_images,
217 | "text_input_ids": text_input_ids,
218 | "text_input_ids_2": text_input_ids_2,
219 | "face_id_embed": face_id_embed,
220 | "original_size": original_size,
221 | "crop_coords_top_left": crop_coords_top_left,
222 | "target_size": target_size,
223 | }
224 |
225 |
226 | class InstantIDAdapter(torch.nn.Module):
227 | """InstantIDAdapter"""
228 | def __init__(self, unet, controlnet, feature_proj_model, adapter_modules, ckpt_path=None):
229 | super().__init__()
230 | self.unet = unet
231 | self.controlnet = controlnet
232 | self.feature_proj_model = feature_proj_model
233 | self.adapter_modules = adapter_modules
234 | if ckpt_path is not None:
235 | self.load_from_checkpoint(ckpt_path)
236 |
237 | def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, feature_embeds, controlnet_image):
238 | face_embedding = self.feature_proj_model(feature_embeds)
239 | encoder_hidden_states = torch.cat([encoder_hidden_states, face_embedding], dim=1)
240 | # ControlNet conditioning.
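# Note: the ControlNet (IdentityNet) is conditioned on the projected face embedding alone, while the
# UNet below receives the text embeddings concatenated with the face tokens.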
241 | down_block_res_samples, mid_block_res_sample = self.controlnet(
242 | noisy_latents,
243 | timesteps,
244 | encoder_hidden_states=face_embedding, # Insightface feature
245 | added_cond_kwargs=unet_added_cond_kwargs,
246 | controlnet_cond=controlnet_image, # keypoints image
247 | return_dict=False,
248 | )
249 | # Predict the noise residual.
250 | noise_pred = self.unet(
251 | noisy_latents,
252 | timesteps,
253 | encoder_hidden_states=encoder_hidden_states,
254 | added_cond_kwargs=unet_added_cond_kwargs,
255 | down_block_additional_residuals=[sample for sample in down_block_res_samples],
256 | mid_block_additional_residual=mid_block_res_sample,
257 | ).sample
258 |
259 | return noise_pred
260 |
261 | def load_from_checkpoint(self, ckpt_path: str):
262 | # Calculate original checksums
263 | orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.feature_proj_model.parameters()]))
264 | orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
265 |
266 | state_dict = torch.load(ckpt_path, map_location="cpu")
267 |
268 | # Check if 'latents' exists in both the saved state_dict and the current model's state_dict
269 | strict_load_feature_proj_model = True
270 | if "latents" in state_dict["image_proj"] and "latents" in self.feature_proj_model.state_dict():
271 | # Check if the shapes are mismatched
272 | if state_dict["image_proj"]["latents"].shape != self.feature_proj_model.state_dict()["latents"].shape:
273 | print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
274 | print("Removing 'latents' from checkpoint and loading the rest of the weights.")
275 | del state_dict["image_proj"]["latents"]
276 | strict_load_feature_proj_model = False
277 |
278 | # Load state dict for feature_proj_model and adapter_modules
279 | self.feature_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_feature_proj_model)
280 | self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
281 |
282 | # Calculate new checksums
283 | new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.feature_proj_model.parameters()]))
284 | new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
285 |
286 | # Verify if the weights have changed
287 | assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of feature_proj_model did not change!"
288 | assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
289 |
290 | print(f"Successfully loaded weights from checkpoint {ckpt_path}")
291 |
292 |
293 | def parse_args():
294 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
295 | parser.add_argument(
296 | "--pretrained_model_name_or_path",
297 | type=str,
298 | default=None,
299 | required=True,
300 | help="Path to pretrained model or model identifier from huggingface.co/models.",
301 | )
302 | parser.add_argument(
303 | "--pretrained_ip_adapter_path",
304 | type=str,
305 | default=None,
306 | help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
307 | )
308 | parser.add_argument(
309 | "--controlnet_model_name_or_path",
310 | type=str,
311 | default=None,
312 | help="Path to pretrained controlnet model. If not specified weights are initialized from unet.",
313 | )
314 |
315 | parser.add_argument(
316 | "--num_tokens",
317 | type=int,
318 | default=16,
319 | help="Number of tokens to query from the CLIP image encoding.",
320 | )
321 | parser.add_argument(
322 | "--checkpoints_total_limit",
323 | type=int,
324 | default=1,
325 | help=(
326 | "Save a checkpoint of the training state every X updates"
327 | ),
328 | )
329 | parser.add_argument(
330 | "--data_json_file",
331 | type=str,
332 | default=None,
333 | required=True,
334 | help="Training data",
335 | )
336 | parser.add_argument(
337 | "--data_root_path",
338 | type=str,
339 | default="",
340 | required=False,
341 | help="Training data root path",
342 | )
343 | parser.add_argument('--clip_proc_mode',
344 | choices=["seg_align", "seg_crop", "orig_align", "orig_crop", "seg_align_pad",
345 | "orig_align_pad"],
346 | default="orig_crop",
347 | help='The mode to preprocess clip image encoder input.')
348 |
349 | parser.add_argument(
350 | "--image_encoder_path",
351 | type=str,
352 | default=None,
353 | required=True,
354 | help="Path to CLIP image encoder",
355 | )
356 | parser.add_argument(
357 | "--center_crop",
358 | default=False,
359 | action="store_true",
360 | help=(
361 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
362 | " cropped. The images will be resized to the resolution first before cropping."
363 | ),
364 | )
365 | parser.add_argument(
366 | "--output_dir",
367 | type=str,
368 | default="sd-ip_adapter",
369 | help="The output directory where the model predictions and checkpoints will be written.",
370 | )
371 | parser.add_argument(
372 | "--logging_dir",
373 | type=str,
374 | default="logs",
375 | help=(
376 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
377 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
378 | ),
379 | )
380 | parser.add_argument(
381 | "--resolution",
382 | type=int,
383 | default=512,
384 | help=(
385 | "The resolution for input images"
386 | ),
387 | )
388 | parser.add_argument(
389 | "--learning_rate",
390 | type=float,
391 | default=1e-4,
392 | help="Learning rate to use.",
393 | )
394 | parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
395 | parser.add_argument("--num_train_epochs", type=int, default=100)
396 | parser.add_argument(
397 | "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
398 | )
399 | parser.add_argument(
400 | "--dataloader_num_workers",
401 | type=int,
402 | default=0,
403 | help=(
404 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
405 | ),
406 | )
407 | parser.add_argument(
408 | "--save_steps",
409 | type=int,
410 | default=2000,
411 | help=(
412 | "Save a checkpoint of the training state every X updates"
413 | ),
414 | )
415 | parser.add_argument(
416 | "--mixed_precision",
417 | type=str,
418 | default=None,
419 | choices=["no", "fp16", "bf16"],
420 | help=(
421 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
422 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
423 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
424 | ),
425 | )
426 | parser.add_argument(
427 | "--report_to",
428 | type=str,
429 | default="tensorboard",
430 | help=(
431 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
432 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
433 | ),
434 | )
435 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
436 | parser.add_argument("--noise_offset", type=float, default=None, help="noise offset")
437 |
438 | args = parser.parse_args()
439 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
440 | if env_local_rank != -1 and env_local_rank != args.local_rank:
441 | args.local_rank = env_local_rank
442 |
443 | return args
444 |
445 |
446 | def main():
447 | args = parse_args()
448 | logging_dir = Path(args.output_dir, args.logging_dir)
449 |
450 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
451 |
452 | accelerator = Accelerator(
453 | mixed_precision=args.mixed_precision,
454 | log_with=args.report_to,
455 | project_config=accelerator_project_config,
456 | )
457 |
458 | num_devices = accelerator.num_processes
459 |
460 | if accelerator.is_main_process:
461 | if args.output_dir is not None:
462 | os.makedirs(args.output_dir, exist_ok=True)
463 |
464 | # Load scheduler, tokenizer and models.
465 | noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
466 | tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
467 | text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
468 | tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
469 | text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder_2")
470 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
471 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
472 | image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
473 | if args.controlnet_model_name_or_path:
474 | print("Loading existing controlnet weights")
475 | controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path)
476 | else:
477 | print("Initializing controlnet weights from unet")
478 | controlnet = ControlNetModel.from_unet(unet)
479 |
480 | # freeze parameters of models to save more memory
481 | unet.requires_grad_(False)
482 | vae.requires_grad_(False)
483 | text_encoder.requires_grad_(False)
484 | text_encoder_2.requires_grad_(False)
485 | image_encoder.requires_grad_(False)
486 | controlnet.requires_grad_(True)
487 | controlnet.train()
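# Only the ControlNet, the face-feature Resampler (feature_proj_model), and the IP-Adapter attention
# modules are trainable; the UNet, VAE, and the text/image encoders stay frozen.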
488 |
489 | # ip-adapter: insightface feature
490 | num_tokens = 16
491 |
492 | feature_proj_model = Resampler(
493 | dim=1280,
494 | depth=4,
495 | dim_head=64,
496 | heads=20,
497 | num_queries=num_tokens,
498 | embedding_dim=512,
499 | output_dim=unet.config.cross_attention_dim,
500 | ff_mult=4,
501 | )
502 |
503 | # init adapter modules
504 | attn_procs = {}
505 | unet_sd = unet.state_dict()
506 | for name in unet.attn_processors.keys():
507 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
508 | if name.startswith("mid_block"):
509 | hidden_size = unet.config.block_out_channels[-1]
510 | elif name.startswith("up_blocks"):
511 | block_id = int(name[len("up_blocks.")])
512 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
513 | elif name.startswith("down_blocks"):
514 | block_id = int(name[len("down_blocks.")])
515 | hidden_size = unet.config.block_out_channels[block_id]
516 | if cross_attention_dim is None:
517 | attn_procs[name] = AttnProcessor()
518 | else:
519 | layer_name = name.split(".processor")[0]
520 | weights = {
521 | "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
522 | "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
523 | }
524 | attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens)
525 | attn_procs[name].load_state_dict(weights)
526 | unet.set_attn_processor(attn_procs)
527 | adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
528 | # Instantiate InstantIDAdapter from pretrained model or from scratch.
529 | ip_adapter = InstantIDAdapter(unet, controlnet, feature_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
530 |
531 | # Register a hook function to process the state of a specific module before saving.
532 | def save_model_hook(models, weights, output_dir):
533 | if accelerator.is_main_process:
534 | # find instance of InstantIDAdapter Model.
535 | for i, model_instance in enumerate(models):
536 | if isinstance(model_instance, InstantIDAdapter):
537 | # When saving a checkpoint, only save the ip-adapter and image_proj, do not save the unet.
538 | ip_adapter_state = {
539 | 'image_proj': model_instance.feature_proj_model.state_dict(),
540 | 'ip_adapter': model_instance.adapter_modules.state_dict(),
541 | }
542 | torch.save(ip_adapter_state, os.path.join(output_dir, 'pytorch_model.bin'))
543 | print(f"IP-Adapter Model weights saved in {os.path.join(output_dir, 'pytorch_model.bin')}")
544 | # Save controlnet separately.
545 | sub_dir = "controlnet"
546 | model_instance.controlnet.save_pretrained(os.path.join(output_dir, sub_dir))
547 | print(f"Controlnet weights saved in {os.path.join(output_dir, sub_dir)}")
548 | # Remove the corresponding weights from the weights list because they have been saved separately.
549 | # Remember not to delete the corresponding model, otherwise, you will not be able to save the model
550 | # starting from the second epoch.
551 | weights.pop(i)
552 | break
553 |
554 | def load_model_hook(models, input_dir):
555 | # find instance of InstantIDAdapter Model.
556 | while len(models) > 0:
557 | model_instance = models.pop()
558 | if isinstance(model_instance, InstantIDAdapter):
559 | ip_adapter_path = os.path.join(input_dir, 'pytorch_model.bin')
560 | if os.path.exists(ip_adapter_path):
561 | ip_adapter_state = torch.load(ip_adapter_path)
562 | model_instance.feature_proj_model.load_state_dict(ip_adapter_state['image_proj'])
563 | model_instance.adapter_modules.load_state_dict(ip_adapter_state['ip_adapter'])
564 | sub_dir = "controlnet"
565 | model_instance.controlnet.load_state_dict(ControlNetModel.from_pretrained(os.path.join(input_dir, sub_dir)).state_dict())
566 | print(f"Model weights loaded from {ip_adapter_path}")
567 | else:
568 | print(f"No saved weights found at {ip_adapter_path}")
569 |
570 |
571 | # Register hook functions for saving and loading.
572 | accelerator.register_save_state_pre_hook(save_model_hook)
573 | accelerator.register_load_state_pre_hook(load_model_hook)
574 |
575 | weight_dtype = torch.float32
576 | if accelerator.mixed_precision == "fp16":
577 | weight_dtype = torch.float16
578 | elif accelerator.mixed_precision == "bf16":
579 | weight_dtype = torch.bfloat16
580 | # unet.to(accelerator.device, dtype=weight_dtype) # error
581 | vae.to(accelerator.device) # use fp32
582 | text_encoder.to(accelerator.device, dtype=weight_dtype)
583 | text_encoder_2.to(accelerator.device, dtype=weight_dtype)
584 | image_encoder.to(accelerator.device, dtype=weight_dtype)
585 | # controlnet.to(accelerator.device, dtype=weight_dtype) # error
586 | controlnet.to(accelerator.device)
587 |
588 | # trainable params
589 | params_to_opt = itertools.chain(ip_adapter.feature_proj_model.parameters(),
590 | ip_adapter.adapter_modules.parameters(),
591 | ip_adapter.controlnet.parameters())
592 |
593 | optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
594 |
595 | # dataloader
596 | train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, tokenizer_2=tokenizer_2, size=args.resolution,
597 | center_crop=args.center_crop, image_root_path=args.data_root_path)
598 | total_data_size = len(train_dataset)
599 |
600 | train_dataloader = torch.utils.data.DataLoader(
601 | train_dataset,
602 | shuffle=True,
603 | collate_fn=collate_fn,
604 | batch_size=args.train_batch_size,
605 | num_workers=args.dataloader_num_workers,
606 | )
607 |
608 | # Prepare everything with our `accelerator`.
609 | ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
610 |
611 | # # Restore checkpoints
612 | # checkpoint_folders = [folder for folder in os.listdir(args.output_dir) if folder.startswith('checkpoint-')]
613 | # if checkpoint_folders:
614 | # # Extract step numbers from all checkpoints and find the maximum step number
615 | # global_step = max(int(folder.split('-')[-1]) for folder in checkpoint_folders if folder.split('-')[-1].isdigit())
616 | # checkpoint_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
617 | # # Load the checkpoint
618 | # accelerator.load_state(checkpoint_path)
619 | # else:
620 | # global_step = 0
621 | # print("No checkpoint folders found.")
622 | global_step = 0
623 | # Calculate steps per epoch and the current epoch and its step number
624 | # steps_per_epoch = total_data_size // (args.train_batch_size * num_devices)
625 | # current_epoch = global_step // steps_per_epoch
626 | # current_step_in_epoch = global_step % steps_per_epoch
627 |
628 | # Training loop
629 | for epoch in range(0, args.num_train_epochs):
630 | begin = time.perf_counter()
631 | for step, batch in enumerate(train_dataloader):
632 | load_data_time = time.perf_counter() - begin
633 | with accelerator.accumulate(ip_adapter):
634 | # Convert images to latent space
635 | with torch.no_grad():
636 | # the SDXL VAE should run in fp32
637 | latents = vae.encode(
638 | batch["images"].to(accelerator.device, dtype=torch.float32)).latent_dist.sample()
639 | latents = latents * vae.config.scaling_factor
640 | latents = latents.to(accelerator.device, dtype=weight_dtype)
641 |
642 | # Sample noise that we'll add to the latents
643 | noise = torch.randn_like(latents)
644 | if args.noise_offset:
645 | # https://www.crosslabs.org//blog/diffusion-with-offset-noise
646 | noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(
647 | accelerator.device, dtype=weight_dtype)
648 |
649 | bsz = latents.shape[0]
650 | # Sample a random timestep for each image
651 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
652 | timesteps = timesteps.long()
653 |
654 | # Add noise to the latents according to the noise magnitude at each timestep
655 | # (this is the forward diffusion process)
656 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
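    | # For reference, the scheduler's add_noise applies the standard forward-diffusion step:
    | #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps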
657 |
658 | # get feature embeddings, with cfg
659 | feat_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype)
660 | kps_images = batch["kps_images"].to(accelerator.device, dtype=weight_dtype)
661 |
662 | # for other experiments
663 | # clip_images = []
664 | # for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
665 | # if drop_image_embed == 1:
666 | # clip_images.append(torch.zeros_like(clip_image))
667 | # else:
668 | # clip_images.append(clip_image)
669 | # clip_images = torch.stack(clip_images, dim=0)
670 | # with torch.no_grad():
671 | # image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype),
672 | # output_hidden_states=True).hidden_states[-2]
673 |
674 | with torch.no_grad():
675 | encoder_output = text_encoder(batch['text_input_ids'].to(accelerator.device), output_hidden_states=True)
676 | text_embeds = encoder_output.hidden_states[-2]
677 | encoder_output_2 = text_encoder_2(batch['text_input_ids_2'].to(accelerator.device), output_hidden_states=True)
678 | pooled_text_embeds = encoder_output_2[0]
679 | text_embeds_2 = encoder_output_2.hidden_states[-2]
680 | text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1) # concat
681 |
682 | # add cond
683 | add_time_ids = [
684 | batch["original_size"].to(accelerator.device),
685 | batch["crop_coords_top_left"].to(accelerator.device),
686 | batch["target_size"].to(accelerator.device),
687 | ]
688 | add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)
689 | unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
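    | # SDXL micro-conditioning: original_size, crop_coords_top_left and target_size are packed
    | # into `time_ids` and fed to the UNet together with the pooled text embedding.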
690 |
691 | noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, feat_embeds, kps_images)
692 |
693 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
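    | # Standard epsilon-prediction objective: the model is trained to recover the injected noise,
    | #   L = E || eps - eps_theta(x_t, t, cond) ||^2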
694 |
695 | # Gather the losses across all processes for logging (if we use distributed training).
696 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
697 |
698 | # Backpropagate
699 | accelerator.backward(loss)
700 | optimizer.step()
701 | optimizer.zero_grad()
702 |
703 | now = datetime.now()
704 | formatted_time = now.strftime('%Y-%m-%d %H:%M:%S.%f')[:-3]
705 | if accelerator.is_main_process and step % 10 == 0:
706 | print("[{}]: Epoch {}, global_step {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
707 | formatted_time, epoch, global_step, step, load_data_time, time.perf_counter() - begin,
708 | avg_loss))
709 |
710 | global_step += 1
711 | if accelerator.is_main_process and global_step % args.save_steps == 0:
712 | # before saving state, check if this save would set us over the `checkpoints_total_limit`
713 | if args.checkpoints_total_limit is not None:
714 | checkpoints = os.listdir(args.output_dir)
715 | checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
716 | checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
717 | # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
718 | if len(checkpoints) >= args.checkpoints_total_limit:
719 | num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
720 | removing_checkpoints = checkpoints[0:num_to_remove]
721 | print(
722 | f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints")
723 | print(f"removing checkpoints: {', '.join(removing_checkpoints)}")
724 |
725 | for removing_checkpoint in removing_checkpoints:
726 | removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
727 | shutil.rmtree(removing_checkpoint)
728 |
729 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
730 | accelerator.save_state(save_path)
731 |
732 | begin = time.perf_counter()
733 |
734 |
735 | if __name__ == "__main__":
736 | main()
737 |
--------------------------------------------------------------------------------
/train_instantId_sdxl.sh:
--------------------------------------------------------------------------------
1 | # SDXL Model
2 | export MODEL_NAME=".../stable-diffusion-xl-base-1.0"
3 | # CLIP Model
4 | export ENCODER_NAME=".../image_encoder"
5 | # pretrained InstantID model
6 | export ADAPTOR_NAME=".../checkpoints/ip-adapter.bin"
7 | export CONTROLNET_NAME=".../checkpoints/ControlNetModel"
8 |
9 | # The JSON file's format (one JSON object per line), e.g.:
10 | # {"file_name": "/data/train_data/images_part0/84634599103.jpg", "additional_feature": "myolv1,a man with glasses and a
11 | # tie on posing for a picture in front of a window with a building in the background, Andrew Law, johnson ting, a picture,
12 | # mannerism", "bbox": [-31.329412311315536, 160.6865997314453, 496.19240215420723, 688.1674156188965],
13 | # "landmarks": [[133.046875, 318], [319.3125, 318], [221.0625, 422], [153.515625, 535], [298.84375, 537]],
14 | # "insightface_feature_file": "/data/feature_data/images_part0/84634599103.bin"}
15 | export JSON_FILE=".../CrossFaceID.jsonl"
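   | # Optional sanity check (assumes python3 is on PATH): validate the first JSONL line
   | # head -n 1 "$JSON_FILE" | python3 -m json.tool > /dev/null && echo "JSONL looks valid"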
16 |
17 |
18 | # Output
19 | export OUTPUT_DIR="..."
20 |
21 |
22 | echo "OUTPUT_DIR: $OUTPUT_DIR"
23 | #accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
24 | #CUDA_VISIBLE_DEVICES=0 \
25 |
26 | accelerate launch --mixed_precision="fp16" train_instantId_sdxl.py \
27 | --pretrained_model_name_or_path $MODEL_NAME \
28 | --controlnet_model_name_or_path $CONTROLNET_NAME \
29 | --image_encoder_path $ENCODER_NAME \
30 | --pretrained_ip_adapter_path $ADAPTOR_NAME \
31 | --data_json_file $JSON_FILE \
32 | --output_dir $OUTPUT_DIR \
33 | --clip_proc_mode orig_crop \
34 | --mixed_precision="fp16" \
35 | --resolution 512 \
36 | --learning_rate 1e-5 \
37 | --weight_decay=0.01 \
38 | --num_train_epochs 5 \
39 | --train_batch_size 8 \
40 | --dataloader_num_workers=8 \
41 | --checkpoints_total_limit 20 \
42 | --save_steps 10000
43 |
44 |
45 |
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 | import torch
5 | from torchvision import transforms
6 | from PIL import Image
7 | from transformers import CLIPImageProcessor
8 | import numpy as np
9 | import cv2
10 | import math
11 | import PIL
12 |
13 |
14 | def crop_with_expanded_size(image, crop_coords, expand_factor=1.1):
15 | # Open the image file
16 | # img = Image.open(path)
17 |
18 | # Original image size
19 | original_width, original_height = image.size
20 |
21 | # Known crop coordinates (left, top, right, bottom)
22 | # crop_coords = (left, top, right, bottom)
23 |
24 | # Compute the center of the original crop region
25 | center_x = (crop_coords[0] + crop_coords[2]) / 2
26 | center_y = (crop_coords[1] + crop_coords[3]) / 2
27 |
28 | # Compute the width and height of the original crop region
29 | original_crop_width = crop_coords[2] - crop_coords[0]
30 | original_crop_height = crop_coords[3] - crop_coords[1]
31 |
32 | # Compute the width and height of the expanded crop region
33 | new_crop_width = original_crop_width * expand_factor
34 | new_crop_height = original_crop_height * expand_factor
35 |
36 | # Compute the new crop coordinates, clamped to the image bounds
37 | new_left = max(center_x - new_crop_width / 2, 0)
38 | new_top = max(center_y - new_crop_height / 2, 0)
39 | new_right = min(center_x + new_crop_width / 2, original_width)
40 | new_bottom = min(center_y + new_crop_height / 2, original_height)
41 |
42 | # New crop coordinates
43 | new_crop_coords = (int(new_left), int(new_top), int(new_right), int(new_bottom))
44 |
45 | # Crop the image
46 | cropped_img = image.crop(new_crop_coords)
47 |
48 | return cropped_img
49 |
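   | # Usage note: MyDataset.__getitem__ below calls crop_with_expanded_size(raw_image, bbox),
   | # i.e. the face box from the metadata is enlarged by 10% (expand_factor=1.1) around its
   | # center before the resulting crop is handed to the CLIP image processor.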
50 |
51 | class CropToRatioTransform(object):
52 | def __init__(self, target_aspect_ratio=512 / 640):
53 | self.target_aspect_ratio = target_aspect_ratio
54 |
55 | def __call__(self, img):
56 | # Compute the current aspect ratio
57 | current_w, current_h = img.size
58 | current_aspect_ratio = current_w / current_h
59 | # print(current_aspect_ratio)
60 |
61 | # If the current aspect ratio is wider than the target, crop the width down to the target ratio
62 | if current_aspect_ratio > self.target_aspect_ratio:
63 | # Compute the target width
64 | target_w = int(current_h * self.target_aspect_ratio)
65 | # Compute the region to crop
66 | left = (current_w - target_w) // 2
67 | right = left + target_w
68 | # Crop the image
69 | img = img.crop((left, 0, right, current_h))
70 |
71 | return img
72 |
73 |
74 | class TopCropTransform(object):
75 | def __init__(self, crop_size):
76 | # crop_size can be a single int or a tuple/list of two ints (height, width)
77 | if isinstance(crop_size, int):
78 | self.crop_height = crop_size
79 | self.crop_width = crop_size
80 | elif isinstance(crop_size, (list, tuple)) and len(crop_size) == 2:
81 | self.crop_height, self.crop_width = crop_size
82 | else:
83 | raise TypeError('crop_size must be an int or a list/tuple of length 2.')
84 |
85 | def __call__(self, img):
86 | # Check that the requested crop_size is not larger than the image
87 | w, h = img.size
88 | if self.crop_width > w or self.crop_height > h:
89 | raise ValueError('crop_size must be smaller than the dimensions of the image.')
90 |
91 | top = 0
92 | center = w // 2
93 | crop_width, crop_height = self.crop_width, self.crop_height
94 | left = center - crop_width // 2
95 |
96 | # Clamp the coordinates to the image bounds
97 | left = max(0, left)
98 | right = min(w, left + crop_width)
99 | bottom = min(h, top + crop_height)
100 |
101 | # Perform the crop
102 | img = img.crop((left, top, right, bottom))
103 | return img
104 |
105 |
106 | def draw_kps(image_pil, kps, color_list=[(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0), (255, 0, 255)]):
107 | stickwidth = 4
108 | limbSeq = np.array([[0, 2], [1, 2], [3, 2], [4, 2]])
109 | kps = np.array(kps)
110 |
111 | w, h = image_pil.size
112 | out_img = np.zeros([h, w, 3])
113 |
114 | for i in range(len(limbSeq)):
115 | index = limbSeq[i]
116 | color = color_list[index[0]]
117 |
118 | x = kps[index][:, 0]
119 | y = kps[index][:, 1]
120 | length = ((x[0] - x[1]) ** 2 + (y[0] - y[1]) ** 2) ** 0.5
121 | angle = math.degrees(math.atan2(y[0] - y[1], x[0] - x[1]))
122 | polygon = cv2.ellipse2Poly((int(np.mean(x)), int(np.mean(y))), (int(length / 2), stickwidth), int(angle), 0,
123 | 360, 1)
124 | out_img = cv2.fillConvexPoly(out_img.copy(), polygon, color)
125 | out_img = (out_img * 0.6).astype(np.uint8)
126 |
127 | for idx_kp, kp in enumerate(kps):
128 | color = color_list[idx_kp]
129 | x, y = kp
130 | out_img = cv2.circle(out_img.copy(), (int(x), int(y)), 10, color, -1)
131 |
132 | out_img_pil = PIL.Image.fromarray(out_img.astype(np.uint8))
133 | return out_img_pil
134 |
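    | # Note: draw_kps rasterizes the five facial landmarks (typically eyes, nose and mouth corners)
    | # into an RGB map; this image is used as the keypoint condition for the ControlNet branch.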
135 | class MyDataset(torch.utils.data.Dataset):
136 | def __init__(self, json_file, tokenizer, size=512, crop_size=(640, 512), center_crop=False, image_root_path=""):
137 | super().__init__()
138 | self.tokenizer = tokenizer
139 | self.size = size  # the short side is resized to this value
140 | self.image_root_path = image_root_path
141 | # Create an empty list to store the parsed entries
142 | self.data = []
143 | # Read and parse the JSON file line by line
144 | with open(json_file, 'r') as f:
145 | for line in f:
146 | # Parse each JSON line and append it to the list
147 | self.data.append(json.loads(line))
148 |
149 | self.transform = transforms.Compose([
150 | CropToRatioTransform(target_aspect_ratio=crop_size[1] / crop_size[0]),
151 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
152 | transforms.CenterCrop(size) if center_crop else TopCropTransform(crop_size),
153 | transforms.ToTensor(),
154 | transforms.Normalize([0.5], [0.5]),
155 | ])
156 |
157 | self.kps_transform = transforms.Compose([
158 | CropToRatioTransform(target_aspect_ratio=crop_size[1] / crop_size[0]),
159 | transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
160 | transforms.CenterCrop(size) if center_crop else TopCropTransform(crop_size),
161 | transforms.ToTensor(),
162 | ])
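    | # Note: unlike self.transform, kps_transform skips Normalize, so the keypoint map stays in
    | # [0, 1], the range ControlNet conditioning images are typically expected in.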
163 |
164 | self.clip_image_processor = CLIPImageProcessor()
165 |
166 | def __getitem__(self, idx):
167 | item = self.data[idx]
168 | image_file = item["file_name"]
169 | text = item["additional_feature"]
170 | bbox = item['bbox']
171 | landmarks = item['landmarks']
172 | feature_file = item["penult_id_embed_file"]
173 | clip_from_seg_file = item["clip_from_seg_file"]
174 | clip_from_orig_file = item["clip_from_orig_file"]
175 | seg_map_orig_file = item["seg_map_orig_file"]
176 |
177 | # read image
178 | raw_image = Image.open(os.path.join(self.image_root_path, image_file))
179 | image = self.transform(raw_image.convert("RGB"))
180 | kps_image = draw_kps(raw_image.convert("RGB"), landmarks)
181 | kps_image = self.kps_transform(kps_image)
182 |
183 | # crop image to clip
184 | crop_image = crop_with_expanded_size(raw_image, bbox)
185 | clip_image = self.clip_image_processor(images=crop_image, return_tensors="pt").pixel_values
186 | # load face feature
187 | face_id_embed = torch.load(os.path.join(self.image_root_path, feature_file), map_location="cpu")
188 | face_id_embed = torch.from_numpy(face_id_embed)
189 |
190 | # Define every possible condition-drop combination and its probability
191 | drop_combinations = {
192 | ('text',): 0.05,
193 | ('feature',): 0.05,
194 | ('feature', 'text'): 0.05,
195 | }
196 | # drop_combinations = {
197 | # ('text',): 0.05,
198 | # ('feature',): 0.04,
199 | # ('image',): 0.04,
200 | # ('image', 'feature'): 0.03,
201 | # ('image', 'text'): 0.03,
202 | # ('feature', 'text'): 0.03,
203 | # ('image', 'feature', 'text'): 0.03
204 | # }
205 | # Compute the remaining probability
206 | remaining_probability = 1 - sum(drop_combinations.values())
207 | # Add the empty combination, i.e. keep all conditions
208 | drop_combinations[()] = remaining_probability
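    | # e.g. with the three 0.05 drops above, remaining_probability = 1 - 0.15 = 0.85, so 85% of
    | # samples keep both the text prompt and the face-ID embedding.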
209 | # Sample one drop combination according to the probabilities
210 | drop_choice = random.choices(list(drop_combinations.keys()), weights=list(drop_combinations.values()), k=1)[0]
211 | # Drop the corresponding inputs based on the sampled combination
212 | drop_text_embed = int('text' in drop_choice)
213 | drop_feature_embed = int('feature' in drop_choice)
214 | drop_image_embed = int('image' in drop_choice)
215 |
216 | # Classifier-free guidance: zero out the dropped conditions
217 | if drop_text_embed:
218 | text = ""
219 | if drop_feature_embed:
220 | face_id_embed = torch.zeros_like(face_id_embed)
221 | if drop_image_embed:
222 | pass # drop in train loop
223 |
224 | # get text and tokenize
225 | text_input_ids = self.tokenizer(
226 | text,
227 | max_length=self.tokenizer.model_max_length,
228 | padding="max_length",
229 | truncation=True,
230 | return_tensors="pt"
231 | ).input_ids
232 |
233 | return {
234 | "image": image,
235 | "kps_image": kps_image,
236 | "clip_image": clip_image,
237 | "text_input_ids": text_input_ids,
238 | "face_id_embed": face_id_embed,
239 | "drop_image_embed": drop_image_embed
240 | }
241 |
242 | def __len__(self):
243 | return len(self.data)
244 |
245 |
246 | def collate_fn(data):
247 | images = torch.stack([example["image"] for example in data])
248 | kps_images = torch.stack([example["kps_image"] for example in data])
249 |
250 | clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
251 |
252 | text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
253 | face_id_embed = torch.stack([example["face_id_embed"] for example in data])
254 | drop_image_embeds = [example["drop_image_embed"] for example in data]
255 |
256 | return {
257 | "images": images,
258 | "kps_images": kps_images,
259 | "clip_images": clip_images,
260 | "text_input_ids": text_input_ids,
261 | "face_id_embed": face_id_embed,
262 | "drop_image_embeds": drop_image_embeds
263 | }
264 |
--------------------------------------------------------------------------------