├── .gitmodules
├── LICENSE
├── README.md
├── arc2face
│   ├── __init__.py
│   ├── models.py
│   └── utils.py
├── assets
│   ├── controlnet.jpg
│   ├── examples
│   │   ├── freddie.png
│   │   ├── freeman.jpg
│   │   ├── hepburn.png
│   │   ├── jackie.png
│   │   ├── joacquin.png
│   │   ├── lily.png
│   │   ├── pose1.png
│   │   ├── pose2.png
│   │   └── pose3.jpg
│   ├── samples.jpg
│   └── teaser.gif
├── gradio_demo
│   ├── app.py
│   └── app_controlnet.py
├── requirements.txt
└── requirements_controlnet.txt
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "external/emoca"]
2 | path = external/emoca
3 | url = https://github.com/foivospar/emoca.git
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Foivos Paraperas Papantoniou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Arc2Face: A Foundation Model for ID-Consistent Human Faces
4 |
5 | [Foivos Paraperas Papantoniou](https://foivospar.github.io/)<sup>1</sup> &nbsp; [Alexandros Lattas](https://alexlattas.com/)<sup>1</sup> &nbsp; [Stylianos Moschoglou](https://moschoglou.com/)<sup>1</sup>
6 |
7 | [Jiankang Deng](https://jiankangdeng.github.io/)<sup>1</sup> &nbsp; [Bernhard Kainz](https://bernhard-kainz.com/)<sup>1,2</sup> &nbsp; [Stefanos Zafeiriou](https://www.imperial.ac.uk/people/s.zafeiriou)<sup>1</sup>
8 |
9 | <sup>1</sup>Imperial College London, UK
10 | <sup>2</sup>FAU Erlangen-Nürnberg, Germany
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | This is the official implementation of **[Arc2Face](https://arc2face.github.io/)**, an ID-conditioned face model:
21 |
22 | ✅ that generates high-quality images of any subject given only its ArcFace embedding, within a few seconds
23 | ✅ that is trained on the large-scale WebFace42M dataset and offers superior ID similarity compared to existing models
24 | ✅ that is built on top of Stable Diffusion and can be extended to different input modalities, e.g. with ControlNet
25 |
26 |
27 |
28 | # News/Updates
29 | [PapersWithCode: tuning-free diffusion personalization benchmark](https://paperswithcode.com/sota/diffusion-personalization-tuning-free-on?p=arc2face-a-foundation-model-of-human-faces)
30 |
31 | - [2024/08/16] 🔥 Accepted to ECCV24 as an **oral**!
32 | - [2024/08/06] 🔥 ComfyUI support available at [caleboleary/ComfyUI-Arc2Face](https://github.com/caleboleary/ComfyUI-Arc2Face)!
33 | - [2024/04/12] 🔥 We add LCM-LoRA support for even faster inference (check the details [below](#lcm-lora-acceleration)).
34 | - [2024/04/11] 🔥 We release the training dataset on [HuggingFace Datasets](https://huggingface.co/datasets/FoivosPar/Arc2Face).
35 | - [2024/03/31] 🔥 We release our demo for pose control using Arc2Face + ControlNet (see instructions [below](#arc2face--controlnet-pose)).
36 | - [2024/03/28] 🔥 We release our Gradio [demo](https://huggingface.co/spaces/FoivosPar/Arc2Face) on HuggingFace Spaces (thanks to the HF team for their free GPU support)!
37 | - [2024/03/14] 🔥 We release Arc2Face.
38 |
39 | # Installation
40 | ```bash
41 | conda create -n arc2face python=3.10
42 | conda activate arc2face
43 |
44 | # Install requirements
45 | pip install -r requirements.txt
46 | ```
47 |
48 | # Download Models
49 | 1) The models can be downloaded manually from [HuggingFace](https://huggingface.co/FoivosPar/Arc2Face) or with python:
50 | ```python
51 | from huggingface_hub import hf_hub_download
52 |
53 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="arc2face/config.json", local_dir="./models")
54 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="arc2face/diffusion_pytorch_model.safetensors", local_dir="./models")
55 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="encoder/config.json", local_dir="./models")
56 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="encoder/pytorch_model.bin", local_dir="./models")
57 | ```
58 |
59 | 2) For face detection and ID-embedding extraction, manually download the [antelopev2](https://github.com/deepinsight/insightface/tree/master/python-package#model-zoo) package ([direct link](https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view)) and place the checkpoints under `models/antelopev2`.
60 |
61 | 3) We use an ArcFace recognition model trained on WebFace42M. Download `arcface.onnx` manually from [HuggingFace](https://huggingface.co/FoivosPar/Arc2Face) and put it in `models/antelopev2`, or download it with python:
62 | ```python
63 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="arcface.onnx", local_dir="./models/antelopev2")
64 | ```
65 | 4) Then **delete** `glintr100.onnx` (the default backbone from insightface).
66 |
67 | The `models` folder structure should finally be:
68 | ```
69 | . ── models ──┬── antelopev2
70 |               ├── arc2face
71 |               └── encoder
72 | ```
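
To sanity-check the setup, the optional snippet below (a minimal sketch, assuming the exact layout above) verifies that the downloaded files are in place and that the default backbone has been removed:
```python
import os

# files expected after the download steps above
expected = [
    "models/arc2face/config.json",
    "models/arc2face/diffusion_pytorch_model.safetensors",
    "models/encoder/config.json",
    "models/encoder/pytorch_model.bin",
    "models/antelopev2/arcface.onnx",
]
for path in expected:
    print(("OK      " if os.path.isfile(path) else "MISSING ") + path)

# glintr100.onnx (the default insightface backbone) must be deleted so that arcface.onnx is used
assert not os.path.isfile("models/antelopev2/glintr100.onnx"), "Please delete models/antelopev2/glintr100.onnx"
```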
73 |
74 | # Usage
75 |
76 | Load pipeline using [diffusers](https://huggingface.co/docs/diffusers/index):
77 | ```python
78 | from diffusers import (
79 | StableDiffusionPipeline,
80 | UNet2DConditionModel,
81 | DPMSolverMultistepScheduler,
82 | )
83 |
84 | from arc2face import CLIPTextModelWrapper, project_face_embs
85 |
86 | import torch
87 | from insightface.app import FaceAnalysis
88 | from PIL import Image
89 | import numpy as np
90 |
91 | # Arc2Face is built upon SD1.5
92 | # The repo below can be used instead of the now deprecated 'runwayml/stable-diffusion-v1-5'
93 | base_model = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
94 |
95 | encoder = CLIPTextModelWrapper.from_pretrained(
96 | 'models', subfolder="encoder", torch_dtype=torch.float16
97 | )
98 |
99 | unet = UNet2DConditionModel.from_pretrained(
100 | 'models', subfolder="arc2face", torch_dtype=torch.float16
101 | )
102 |
103 | pipeline = StableDiffusionPipeline.from_pretrained(
104 | base_model,
105 | text_encoder=encoder,
106 | unet=unet,
107 | torch_dtype=torch.float16,
108 | safety_checker=None
109 | )
110 | ```
111 | You can use any SD-compatible scheduler and number of steps, just like with Stable Diffusion. By default, we use `DPMSolverMultistepScheduler` with 25 steps, which produces very good results in just a few seconds.
112 | ```python
113 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
114 | pipeline = pipeline.to('cuda')
115 | ```
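
For example, if you prefer a different scheduler, any other diffusers scheduler compatible with SD1.5 can be swapped in the same way (an illustrative sketch, not part of the official code):
```python
from diffusers import EulerAncestralDiscreteScheduler

# build the new scheduler from the existing scheduler's config
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config)
```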
116 | Pick an image and extract the ID-embedding:
117 | ```python
118 | app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
119 | app.prepare(ctx_id=0, det_size=(640, 640))
120 |
121 | img = np.array(Image.open('assets/examples/joacquin.png'))[:,:,::-1]
122 |
123 | faces = app.get(img)
124 | faces = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected)
125 | id_emb = torch.tensor(faces['embedding'], dtype=torch.float16)[None].cuda()
126 | id_emb = id_emb/torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding
127 | id_emb = project_face_embs(pipeline, id_emb) # pass through the encoder
128 | ```
129 |
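Detection can fail on tightly cropped faces. Right after the `app.get(img)` call, you may want to guard against an empty result, similar to the check in `gradio_demo/app.py`:
```python
if len(faces) == 0:
    raise ValueError("Face detection failed! Try an image with a larger margin around the face.")
```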
130 |
131 |
132 |
133 |
134 | Generate images:
135 | ```python
136 | num_images = 4
137 | images = pipeline(prompt_embeds=id_emb, num_inference_steps=25, guidance_scale=3.0, num_images_per_prompt=num_images).images
138 | ```
139 |
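The returned `images` are standard PIL images, so they can be saved directly (a minimal sketch; the output paths are arbitrary):
```python
import os

os.makedirs("output", exist_ok=True)
for i, im in enumerate(images):
    im.save(f"output/sample_{i}.png")
```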
140 |
141 |
142 |
143 | # LCM-LoRA acceleration
144 |
145 | [LCM-LoRA](https://arxiv.org/abs/2311.05556) allows you to reduce the sampling steps to as few as 2-4 for super-fast inference. Just plug in the pre-trained distillation adapter for SD v1.5 and switch to `LCMScheduler`:
146 | ```python
147 | from diffusers import LCMScheduler
148 |
149 | pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
150 | pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
151 | ```
152 | Then, you can sample with as few as 2 steps (and effectively disable guidance by setting `guidance_scale` to 1.0, as LCM is very sensitive to it and even slightly larger values lead to oversaturation):
153 | ```python
154 | images = pipeline(prompt_embeds=id_emb, num_inference_steps=2, guidance_scale=1.0, num_images_per_prompt=num_images).images
155 | ```
156 | Note that this technique accelerates sampling in exchange for a slight drop in quality.
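To switch back to standard sampling afterwards, disable the LoRA adapter and restore the default scheduler, as done in `gradio_demo/app.py`:
```python
pipeline.disable_lora()
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
```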
157 |
158 | # Start a local gradio demo
159 | You can start a local demo for inference by running:
160 | ```bash
161 | python gradio_demo/app.py
162 | ```
163 |
164 | # Arc2Face + ControlNet (pose)
165 |
166 |
167 |
168 |
169 | We provide a ControlNet model trained on top of Arc2Face for pose control. We use [EMOCA](https://github.com/radekd91/emoca) for 3D pose extraction. To run our demo, follow the steps below:
170 | ### 1) Download Model
171 | Download the ControlNet checkpoint manually from [HuggingFace](https://huggingface.co/FoivosPar/Arc2Face) or using python:
172 | ```python
173 | from huggingface_hub import hf_hub_download
174 |
175 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="controlnet/config.json", local_dir="./models")
176 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="controlnet/diffusion_pytorch_model.safetensors", local_dir="./models")
177 | ```
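The checkpoint can then be loaded analogously to the base pipeline (a sketch assuming the `base_model`, `encoder`, and `unet` objects from the Usage section above; see `gradio_demo/app_controlnet.py` for the full demo code):
```python
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

controlnet = ControlNetModel.from_pretrained(
    'models', subfolder="controlnet", torch_dtype=torch.float16
)
pipeline = StableDiffusionControlNetPipeline.from_pretrained(
    base_model,
    text_encoder=encoder,
    unet=unet,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    safety_checker=None
)
```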
178 | ### 2) Pull EMOCA
179 | ```bash
180 | git submodule update --init external/emoca
181 | ```
182 | ### 3) Installation
183 | This is the trickiest part. You will need PyTorch3D to run EMOCA. As its installation may cause conflicts, we suggest following the process below:
184 | 1) Create a new environment and start by installing PyTorch3D with GPU support first (follow the official [instructions](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md)).
185 | 2) Add Arc2Face + EMOCA requirements with:
186 | ```bash
187 | pip install -r requirements_controlnet.txt
188 | ```
189 | 3) Install EMOCA code:
190 | ```bash
191 | pip install -e external/emoca
192 | ```
193 | 4) Finally, you need to download the EMOCA/FLAME assets. Run the following and follow the instructions in the terminal:
194 | ```bash
195 | cd external/emoca/gdl_apps/EMOCA/demos
196 | bash download_assets.sh
197 | cd ../../../../..
198 | ```
199 | ### 4) Start a local gradio demo
200 | You can start a local ControlNet demo by running:
201 | ```bash
202 | python gradio_demo/app_controlnet.py
203 | ```
204 |
205 | # Test Data
206 | The test images used for comparisons in the paper (Synth-500, AgeDB) are available [here](https://drive.google.com/drive/folders/1exnvCECmqWcqNIFCck2EQD-hkE42Ayjc?usp=sharing). Please use them only for evaluation purposes and make sure to cite the corresponding [sources](https://ibug.doc.ic.ac.uk/resources/agedb/) when using them.
207 |
208 | # Community Resources
209 |
210 | ### Replicate Demo
211 | - [Demo link](https://replicate.com/camenduru/arc2face) by [@camenduru](https://github.com/camenduru).
212 |
213 | ### ComfyUI
214 | - [caleboleary/ComfyUI-Arc2Face](https://github.com/caleboleary/ComfyUI-Arc2Face) by [@caleboleary](https://github.com/caleboleary).
215 |
216 | ### Pinokio
217 | - Pinokio [implementation](https://pinokio.computer/item?uri=https://github.com/cocktailpeanutlabs/arc2face) by [@cocktailpeanut](https://github.com/cocktailpeanut) (runs locally on Windows, Mac, and Linux).
218 |
219 | # Acknowledgements
220 | - Thanks to the creators of Stable Diffusion and the HuggingFace [diffusers](https://github.com/huggingface/diffusers) team for the awesome work ❤️.
221 | - Thanks to the WebFace42M creators for providing such a million-scale facial dataset ❤️.
222 | - Thanks to the HuggingFace team for their generous support through the community GPU grant for our demo ❤️.
223 | - We also acknowledge the invaluable support of the HPC resources provided by the Erlangen National High Performance Computing Center (NHR@FAU) of the Friedrich-Alexander-Universität Erlangen-Nürnberg (FAU), which made the training of Arc2Face possible.
224 |
225 | # Citation
226 | If you find Arc2Face useful for your research, please consider citing us:
227 |
228 | ```bibtex
229 | @inproceedings{paraperas2024arc2face,
230 | title={Arc2Face: A Foundation Model for ID-Consistent Human Faces},
231 | author={Paraperas Papantoniou, Foivos and Lattas, Alexandros and Moschoglou, Stylianos and Deng, Jiankang and Kainz, Bernhard and Zafeiriou, Stefanos},
232 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
233 | year={2024}
234 | }
235 | ```
236 |
--------------------------------------------------------------------------------
/arc2face/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import CLIPTextModelWrapper
2 | from .utils import project_face_embs, image_align
--------------------------------------------------------------------------------
/arc2face/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import CLIPTextModel
3 | from typing import Any, Callable, Dict, Optional, Tuple, Union, List
4 | from transformers.modeling_outputs import BaseModelOutputWithPooling
5 | from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
6 |
7 |
8 | class CLIPTextModelWrapper(CLIPTextModel):
9 | # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812
10 | # Modified to accept precomputed token embeddings ("input_token_embs") as input, or to compute the token embeddings from input_ids and return them when "return_token_embs" is True.
11 | def forward(
12 | self,
13 | input_ids: Optional[torch.Tensor] = None,
14 | attention_mask: Optional[torch.Tensor] = None,
15 | position_ids: Optional[torch.Tensor] = None,
16 | output_attentions: Optional[bool] = None,
17 | output_hidden_states: Optional[bool] = None,
18 | return_dict: Optional[bool] = None,
19 | input_token_embs: Optional[torch.Tensor] = None,
20 | return_token_embs: Optional[bool] = False,
21 | ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]:
22 |
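# when return_token_embs is True, short-circuit and return the raw token embeddings
# (project_face_embs edits these and feeds them back in via input_token_embs)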
23 | if return_token_embs:
24 | return self.text_model.embeddings.token_embedding(input_ids)
25 |
26 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
27 |
28 | output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions
29 | output_hidden_states = (
30 | output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states
31 | )
32 | return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict
33 |
34 | if input_ids is None:
35 | raise ValueError("You have to specify input_ids")
36 |
37 | input_shape = input_ids.size()
38 | input_ids = input_ids.view(-1, input_shape[-1])
39 |
40 | hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs)
41 |
42 | # CLIP's text model uses causal mask, prepare it here.
43 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
44 | causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
45 | # expand attention_mask
46 | if attention_mask is not None:
47 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
48 | attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
49 |
50 | encoder_outputs = self.text_model.encoder(
51 | inputs_embeds=hidden_states,
52 | attention_mask=attention_mask,
53 | causal_attention_mask=causal_attention_mask,
54 | output_attentions=output_attentions,
55 | output_hidden_states=output_hidden_states,
56 | return_dict=return_dict,
57 | )
58 |
59 | last_hidden_state = encoder_outputs[0]
60 | last_hidden_state = self.text_model.final_layer_norm(last_hidden_state)
61 |
62 | if self.text_model.eos_token_id == 2:
63 | # The `eos_token_id` was incorrect before PR #24773: Let's keep what has been done here.
64 | # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added
65 | # ------------------------------------------------------------
66 | # text_embeds.shape = [batch_size, sequence_length, transformer.width]
67 | # take features from the eot embedding (eot_token is the highest number in each sequence)
68 | # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
69 | pooled_output = last_hidden_state[
70 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
71 | input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
72 | ]
73 | else:
74 | # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible)
75 | pooled_output = last_hidden_state[
76 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
77 | # We need to get the first position of `eos_token_id` value (`pad_token_ids` might equal to `eos_token_id`)
78 | (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id)
79 | .int()
80 | .argmax(dim=-1),
81 | ]
82 |
83 | if not return_dict:
84 | return (last_hidden_state, pooled_output) + encoder_outputs[1:]
85 |
86 | return BaseModelOutputWithPooling(
87 | last_hidden_state=last_hidden_state,
88 | pooler_output=pooled_output,
89 | hidden_states=encoder_outputs.hidden_states,
90 | attentions=encoder_outputs.attentions,
91 | )
--------------------------------------------------------------------------------
/arc2face/utils.py:
--------------------------------------------------------------------------------
1 | import scipy
2 | import PIL
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 |
7 | @torch.no_grad()
8 | def project_face_embs(pipeline, face_embs):
9 |
10 | '''
11 | face_embs: (N, 512) normalized ArcFace embeddings
12 | '''
13 |
14 | arcface_token_id = pipeline.tokenizer.encode("id", add_special_tokens=False)[0]
15 |
16 | input_ids = pipeline.tokenizer(
17 | "photo of a id person",
18 | truncation=True,
19 | padding="max_length",
20 | max_length=pipeline.tokenizer.model_max_length,
21 | return_tensors="pt",
22 | ).input_ids.to(pipeline.device)
23 |
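# zero-pad the 512-d ArcFace embedding to the text encoder's hidden size and place it at the position of the 'id' token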
24 | face_embs_padded = F.pad(face_embs, (0, pipeline.text_encoder.config.hidden_size-512), "constant", 0)
25 | token_embs = pipeline.text_encoder(input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True)
26 | token_embs[input_ids==arcface_token_id] = face_embs_padded
27 |
28 | prompt_embeds = pipeline.text_encoder(
29 | input_ids=input_ids,
30 | input_token_embs=token_embs
31 | )[0]
32 |
33 | return prompt_embeds
34 |
35 |
36 | def image_align(img,
37 | face_landmarks,
38 | output_size=1024,
39 | transform_size=4096,
40 | enable_padding=True):
41 | # Align function from FFHQ dataset pre-processing step
42 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
43 |
44 | lm = face_landmarks
45 | lm_eye_left = lm[36:42] # left-clockwise
46 | lm_eye_right = lm[42:48] # left-clockwise
47 | lm_mouth_outer = lm[48:60] # left-clockwise
48 |
49 | # Calculate auxiliary vectors.
50 | eye_left = np.mean(lm_eye_left, axis=0)
51 | eye_right = np.mean(lm_eye_right, axis=0)
52 | eye_avg = (eye_left + eye_right) * 0.5
53 | eye_to_eye = eye_right - eye_left
54 | mouth_left = lm_mouth_outer[0]
55 | mouth_right = lm_mouth_outer[6]
56 | mouth_avg = (mouth_left + mouth_right) * 0.5
57 | eye_to_mouth = mouth_avg - eye_avg
58 |
59 | # Choose oriented crop rectangle.
60 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
61 | x /= np.hypot(*x)
62 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
63 | y = np.flipud(x) * [-1, 1]
64 | c = eye_avg + eye_to_mouth * 0.1
65 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
66 | qsize = np.hypot(*x) * 2
67 |
68 | img = img.convert('RGB')
69 |
70 | # Shrink.
71 | shrink = int(np.floor(qsize / output_size * 0.5))
72 | if shrink > 1:
73 | rsize = (int(np.rint(float(img.size[0]) / shrink)),
74 | int(np.rint(float(img.size[1]) / shrink)))
75 | img = img.resize(rsize, PIL.Image.LANCZOS)
76 | quad /= shrink
77 | qsize /= shrink
78 |
79 | # Crop.
80 | border = max(int(np.rint(qsize * 0.1)), 3)
81 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
82 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
83 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
84 | min(crop[2] + border,
85 | img.size[0]), min(crop[3] + border, img.size[1]))
86 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
87 | img = img.crop(crop)
88 | quad -= crop[0:2]
89 |
90 | # Pad.
91 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
92 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
93 | pad = (max(-pad[0] + border,
94 | 0), max(-pad[1] + border,
95 | 0), max(pad[2] - img.size[0] + border,
96 | 0), max(pad[3] - img.size[1] + border, 0))
97 | if enable_padding and max(pad) > border - 4:
98 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
99 | img = np.pad(np.float32(img),
100 | ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
101 | h, w, _ = img.shape
102 | y, x, _ = np.ogrid[:h, :w, :1]
103 | mask = np.maximum(
104 | 1.0 -
105 | np.minimum(np.float32(x) / pad[0],
106 | np.float32(w - 1 - x) / pad[2]), 1.0 -
107 | np.minimum(np.float32(y) / pad[1],
108 | np.float32(h - 1 - y) / pad[3]))
109 | blur = qsize * 0.02
110 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
111 | img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
112 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
113 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)),
114 | 'RGB')
115 | quad += pad[:2]
116 |
117 |
118 | # Transform.
119 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
120 | (quad + 0.5).flatten(), PIL.Image.BILINEAR)
121 | if output_size < transform_size:
122 | img = img.resize((output_size, output_size), PIL.Image.LANCZOS)
123 |
124 | return img
--------------------------------------------------------------------------------
/assets/controlnet.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/controlnet.jpg
--------------------------------------------------------------------------------
/assets/examples/freddie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/freddie.png
--------------------------------------------------------------------------------
/assets/examples/freeman.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/freeman.jpg
--------------------------------------------------------------------------------
/assets/examples/hepburn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/hepburn.png
--------------------------------------------------------------------------------
/assets/examples/jackie.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/jackie.png
--------------------------------------------------------------------------------
/assets/examples/joacquin.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/joacquin.png
--------------------------------------------------------------------------------
/assets/examples/lily.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/lily.png
--------------------------------------------------------------------------------
/assets/examples/pose1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/pose1.png
--------------------------------------------------------------------------------
/assets/examples/pose2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/pose2.png
--------------------------------------------------------------------------------
/assets/examples/pose3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/pose3.jpg
--------------------------------------------------------------------------------
/assets/samples.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/samples.jpg
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/teaser.gif
--------------------------------------------------------------------------------
/gradio_demo/app.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./')
3 |
4 | from diffusers import (
5 | StableDiffusionPipeline,
6 | UNet2DConditionModel,
7 | DPMSolverMultistepScheduler,
8 | LCMScheduler
9 | )
10 |
11 | from arc2face import CLIPTextModelWrapper, project_face_embs
12 |
13 | import torch
14 | from insightface.app import FaceAnalysis
15 | from PIL import Image
16 | import numpy as np
17 | import random
18 |
19 | import gradio as gr
20 |
21 | # global variable
22 | MAX_SEED = np.iinfo(np.int32).max
23 | if torch.cuda.is_available():
24 | device = "cuda"
25 | dtype = torch.float16
26 | else:
27 | device = "cpu"
28 | dtype = torch.float32
29 |
30 | # Load face detection and recognition package
31 | app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
32 | app.prepare(ctx_id=0, det_size=(640, 640))
33 |
34 | # Load pipeline
35 | base_model = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
36 | encoder = CLIPTextModelWrapper.from_pretrained(
37 | 'models', subfolder="encoder", torch_dtype=dtype
38 | )
39 | unet = UNet2DConditionModel.from_pretrained(
40 | 'models', subfolder="arc2face", torch_dtype=dtype
41 | )
42 | pipeline = StableDiffusionPipeline.from_pretrained(
43 | base_model,
44 | text_encoder=encoder,
45 | unet=unet,
46 | torch_dtype=dtype,
47 | safety_checker=None
48 | )
49 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
50 | pipeline = pipeline.to(device)
51 |
52 | # load and disable LCM
53 | pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
54 | pipeline.disable_lora()
55 |
56 | def toggle_lcm_ui(value):
57 | if value:
58 | return (
59 | gr.update(minimum=1, maximum=20, step=1, value=3),
60 | gr.update(minimum=0.1, maximum=10.0, step=0.1, value=1.0),
61 | )
62 | else:
63 | return (
64 | gr.update(minimum=1, maximum=100, step=1, value=25),
65 | gr.update(minimum=0.1, maximum=10.0, step=0.1, value=3.0),
66 | )
67 |
68 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
69 | if randomize_seed:
70 | seed = random.randint(0, MAX_SEED)
71 | return seed
72 |
73 | def get_example():
74 | case = [
75 | [
76 | './assets/examples/freeman.jpg',
77 | ],
78 | [
79 | './assets/examples/lily.png',
80 | ],
81 | [
82 | './assets/examples/joacquin.png',
83 | ],
84 | [
85 | './assets/examples/jackie.png',
86 | ],
87 | [
88 | './assets/examples/freddie.png',
89 | ],
90 | [
91 | './assets/examples/hepburn.png',
92 | ],
93 | ]
94 | return case
95 |
96 | def run_example(img_file):
97 | return generate_image(img_file, 25, 3, 23, 2, False)
98 |
99 |
100 | def generate_image(image_path, num_steps, guidance_scale, seed, num_images, use_lcm, progress=gr.Progress(track_tqdm=True)):
101 |
102 | if use_lcm:
103 | pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
104 | pipeline.enable_lora()
105 | else:
106 | pipeline.disable_lora()
107 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
108 |
109 | if image_path is None:
110 | raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
111 |
112 | img = np.array(Image.open(image_path))[:,:,::-1]
113 |
114 | # Face detection and ID-embedding extraction
115 | faces = app.get(img)
116 |
117 | if len(faces) == 0:
118 | raise gr.Error(f"Face detection failed! Please try with another image")
119 |
120 | faces = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected)
121 | id_emb = torch.tensor(faces['embedding'], dtype=dtype)[None].to(device)
122 | id_emb = id_emb/torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding
123 | id_emb = project_face_embs(pipeline, id_emb) # pass through the encoder
124 |
125 | generator = torch.Generator(device=device).manual_seed(seed)
126 |
127 | print("Start inference...")
128 | images = pipeline(
129 | prompt_embeds=id_emb,
130 | num_inference_steps=num_steps,
131 | guidance_scale=guidance_scale,
132 | num_images_per_prompt=num_images,
133 | generator=generator
134 | ).images
135 |
136 | return images
137 |
138 | ### Description
139 | title = r"""
140 | Arc2Face: A Foundation Model for ID-Consistent Human Faces
141 | """
142 |
143 | description = r"""
144 | Official 🤗 Gradio demo for Arc2Face: A Foundation Model for ID-Consistent Human Faces.
145 |
146 | Steps:
147 | 1. Upload an image with a face. If multiple faces are detected, we use the largest one. For images with already tightly cropped faces, detection may fail; in that case, try images with a larger margin.
148 | 2. Click Submit to generate new images of the subject.
149 | """
150 |
151 | Footer = r"""
152 | ---
153 | 📝 **Citation**
154 |
155 | If you find Arc2Face helpful for your research, please consider citing our paper:
156 | ```bibtex
157 | @inproceedings{paraperas2024arc2face,
158 | title={Arc2Face: A Foundation Model for ID-Consistent Human Faces},
159 | author={Paraperas Papantoniou, Foivos and Lattas, Alexandros and Moschoglou, Stylianos and Deng, Jiankang and Kainz, Bernhard and Zafeiriou, Stefanos},
160 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
161 | year={2024}
162 | }
163 | ```
164 | """
165 |
166 | css = '''
167 | .gradio-container {width: 85% !important}
168 | '''
169 | with gr.Blocks(css=css) as demo:
170 |
171 | # description
172 | gr.Markdown(title)
173 | gr.Markdown(description)
174 |
175 | with gr.Row():
176 | with gr.Column():
177 |
178 | # upload face image
179 | img_file = gr.Image(label="Upload a photo with a face", type="filepath")
180 |
181 | submit = gr.Button("Submit", variant="primary")
182 |
183 | use_lcm = gr.Checkbox(
184 | label="Use LCM-LoRA to accelerate sampling", value=False,
185 | info="Reduces sampling steps significantly, but may decrease quality.",
186 | )
187 |
188 | with gr.Accordion(open=False, label="Advanced Options"):
189 | num_steps = gr.Slider(
190 | label="Number of sample steps",
191 | minimum=1,
192 | maximum=100,
193 | step=1,
194 | value=25,
195 | )
196 | guidance_scale = gr.Slider(
197 | label="Guidance scale",
198 | minimum=0.1,
199 | maximum=10.0,
200 | step=0.1,
201 | value=3.0,
202 | )
203 | num_images = gr.Slider(
204 | label="Number of output images",
205 | minimum=1,
206 | maximum=4,
207 | step=1,
208 | value=2,
209 | )
210 | seed = gr.Slider(
211 | label="Seed",
212 | minimum=0,
213 | maximum=MAX_SEED,
214 | step=1,
215 | value=0,
216 | )
217 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
218 |
219 | with gr.Column():
220 | gallery = gr.Gallery(label="Generated Images")
221 |
222 | submit.click(
223 | fn=randomize_seed_fn,
224 | inputs=[seed, randomize_seed],
225 | outputs=seed,
226 | queue=False,
227 | api_name=False,
228 | ).then(
229 | fn=generate_image,
230 | inputs=[img_file, num_steps, guidance_scale, seed, num_images, use_lcm],
231 | outputs=[gallery]
232 | )
233 |
234 | use_lcm.input(
235 | fn=toggle_lcm_ui,
236 | inputs=[use_lcm],
237 | outputs=[num_steps, guidance_scale],
238 | queue=False,
239 | )
240 |
241 | gr.Examples(
242 | examples=get_example(),
243 | inputs=[img_file],
244 | run_on_click=True,
245 | fn=run_example,
246 | outputs=[gallery],
247 | )
248 |
249 | gr.Markdown(Footer)
250 |
251 | demo.launch()
--------------------------------------------------------------------------------
/gradio_demo/app_controlnet.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.append('./')
3 |
4 | from diffusers import (
5 | StableDiffusionPipeline,
6 | UNet2DConditionModel,
7 | DPMSolverMultistepScheduler,
8 | LCMScheduler,
9 | ControlNetModel,
10 | StableDiffusionControlNetPipeline
11 | )
12 |
13 | from arc2face import CLIPTextModelWrapper, project_face_embs, image_align
14 |
15 | from gdl.utils.FaceDetector import FAN
16 | from gdl_apps.EMOCA.utils.load import load_model
17 | from gdl.datasets.ImageTestDataset import preprocess_for_emoca
18 |
19 | import torch
20 | from insightface.app import FaceAnalysis
21 | from PIL import Image
22 | import numpy as np
23 | import random
24 |
25 | import gradio as gr
26 |
27 | # global variable
28 | MAX_SEED = np.iinfo(np.int32).max
29 | if torch.cuda.is_available():
30 | device = "cuda"
31 | dtype = torch.float16
32 | else:
33 | device = "cpu"
34 | dtype = torch.float32
35 |
36 | # Load face detection and recognition package
37 | app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
38 | app.prepare(ctx_id=0, det_size=(640, 640))
39 |
40 | # Load pipeline
41 | base_model = 'stable-diffusion-v1-5/stable-diffusion-v1-5'
42 | encoder = CLIPTextModelWrapper.from_pretrained(
43 | 'models', subfolder="encoder", torch_dtype=dtype
44 | )
45 | unet = UNet2DConditionModel.from_pretrained(
46 | 'models', subfolder="arc2face", torch_dtype=dtype
47 | )
48 | controlnet = ControlNetModel.from_pretrained(
49 | 'models', subfolder="controlnet", torch_dtype=dtype
50 | )
51 | pipeline = StableDiffusionControlNetPipeline.from_pretrained(
52 | base_model,
53 | text_encoder=encoder,
54 | unet=unet,
55 | controlnet=controlnet,
56 | torch_dtype=dtype,
57 | safety_checker=None
58 | )
59 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
60 | pipeline = pipeline.to(device)
61 |
62 | # load and disable LCM
63 | pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5")
64 | pipeline.disable_lora()
65 |
66 | def toggle_lcm_ui(value):
67 | if value:
68 | return (
69 | gr.update(minimum=1, maximum=20, step=1, value=3),
70 | gr.update(minimum=0.1, maximum=10.0, step=0.1, value=1.0),
71 | )
72 | else:
73 | return (
74 | gr.update(minimum=1, maximum=100, step=1, value=25),
75 | gr.update(minimum=0.1, maximum=10.0, step=0.1, value=3.0),
76 | )
77 |
78 | # Load Emoca
79 | face_detector = FAN()
80 | path_to_models = "external/emoca/assets/EMOCA/models"
81 | model_name = 'EMOCA_v2_lr_mse_20'
82 | mode = 'detail'
83 | emoca_model, conf = load_model(path_to_models, model_name, mode)
84 | emoca_model.to(device)
85 | emoca_model.eval()
86 |
87 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
88 | if randomize_seed:
89 | seed = random.randint(0, MAX_SEED)
90 | return seed
91 |
92 | def get_example():
93 | case = [
94 | [
95 | './assets/examples/freddie.png',
96 | './assets/examples/pose1.png',
97 | ],
98 | [
99 | './assets/examples/lily.png',
100 | './assets/examples/pose2.png',
101 | ],
102 | [
103 | './assets/examples/freeman.jpg',
104 | './assets/examples/pose3.jpg',
105 | ],
106 | ]
107 | return case
108 |
109 | def run_example(img_file, ref_img_file):
110 | return generate_image(img_file, ref_img_file, 25, 3, 23, 2, False)
111 |
112 | def run_emoca(img, ref_img):
113 |
114 | img_dict = preprocess_for_emoca(img, face_detector)
115 | img_dict['image'] = img_dict['image'].unsqueeze(0).to(device)
116 | with torch.no_grad():
117 | codedict = emoca_model.encode(img_dict, training=False)
118 |
119 | bbox, bbox_type, landmarks = face_detector.run(np.array(ref_img.convert('RGB')), with_landmarks=True)
120 | if len(bbox) == 0:
121 | raise gr.Error(f"Face detection failed in reference image! Please try with another reference image.")
122 | if len(bbox)>1: # select largest face
123 | sizes = [(b[2]-b[0])*(b[3]-b[1]) for b in bbox]
124 | idx = np.argmax(sizes)
125 | lmks = landmarks[idx]
126 | else:
127 | lmks = landmarks[0]
128 | ref_img_aligned = image_align(ref_img.copy(), lmks, output_size=512)
129 | ref_img_dict = preprocess_for_emoca(ref_img_aligned, face_detector)
130 | ref_img_dict['image'] = ref_img_dict['image'].unsqueeze(0).to(device)
131 | with torch.no_grad():
132 | ref_codedict = emoca_model.encode(ref_img_dict, training=False)
133 | ref_codedict['shapecode'] = codedict['shapecode'].clone()
134 | ref_codedict['detailcode'] = codedict['detailcode'].clone()
135 | tform = ref_img_dict['tform'].unsqueeze(0).to(device)
136 | tform = torch.inverse(tform).transpose(1, 2)
137 | visdict = emoca_model.decode(ref_codedict, training=False, render_orig=True, original_image=ref_img_dict['original_image'].unsqueeze(0).to(device), tform=tform)
138 |
139 | cond_img = Image.fromarray(((visdict['normal_images'][0]*0.5+0.5).clamp(0,1).permute(1,2,0).cpu().numpy()*255).astype(np.uint8))
140 |
141 | return ref_img_aligned, cond_img
142 |
143 | def generate_image(image_path, ref_image_path, num_steps, guidance_scale, seed, num_images, use_lcm, progress=gr.Progress(track_tqdm=True)):
144 |
145 | if use_lcm:
146 | pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)
147 | pipeline.enable_lora()
148 | else:
149 | pipeline.disable_lora()
150 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
151 |
152 | if image_path is None:
153 | raise gr.Error(f"Cannot find any input face image! Please upload a face image.")
154 |
155 | if ref_image_path is None:
156 | raise gr.Error(f"Cannot find any reference image! Please upload a reference image.")
157 |
158 | img = np.array(Image.open(image_path))[:,:,::-1]
159 |
160 | # Face detection and ID-embedding extraction
161 | faces = app.get(img)
162 |
163 | if len(faces) == 0:
164 | raise gr.Error(f"Face detection failed! Please try with another input face image.")
165 |
166 | faces = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected)
167 | id_emb = torch.tensor(faces['embedding'], dtype=dtype)[None].to(device)
168 | id_emb = id_emb/torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding
169 | id_emb = project_face_embs(pipeline, id_emb) # pass through the encoder
170 |
171 | # pose extraction with EMOCA
172 | ref_img_a, cond_img = run_emoca(Image.open(image_path), Image.open(ref_image_path))
173 |
174 | generator = torch.Generator(device=device).manual_seed(seed)
175 |
176 | print("Start inference...")
177 | images = pipeline(
178 | image=cond_img,
179 | prompt_embeds=id_emb,
180 | num_inference_steps=num_steps,
181 | guidance_scale=guidance_scale,
182 | num_images_per_prompt=num_images,
183 | generator=generator
184 | ).images
185 |
186 | return [ref_img_a, cond_img] + images
187 |
188 | ### Description
189 | title = r"""
190 | Arc2Face: A Foundation Model for ID-Consistent Human Faces
191 | """
192 |
193 | description = r"""
194 | Official 🤗 Gradio demo for Arc2Face: A Foundation Model for ID-Consistent Human Faces.
195 | This demo uses Arc2Face with ControlNet to generate images of a subject with the pose extracted from a reference image.
196 |
197 | Steps:
198 | 1. Upload an image with a face. If multiple faces are detected, we use the largest one. For images with already tightly cropped faces, detection may fail; in that case, try images with a larger margin.
199 | 2. Upload a reference image for the pose (can be a different subject). We align this image to the FFHQ template for pose extraction (so, the final generated images correspond to the aligned pose). Again, if multiple faces are detected, we use the largest one.
200 | 3. Click Submit to generate new images of the input subject with the reference pose.
201 |
202 | Note: For pose extraction and conditioning, we use a 3D reconstruction method based on the FLAME model (we use only the rendered mesh normals as conditioning image, which is visualized alongside the output images). Reconstruction may fail in extreme poses (normals will correspond to random poses instead of the given one).
203 | """
204 |
205 | Footer = r"""
206 | ---
207 | 📝 **Citation**
208 |
209 | If you find Arc2Face helpful for your research, please consider citing our paper:
210 | ```bibtex
211 | @inproceedings{paraperas2024arc2face,
212 | title={Arc2Face: A Foundation Model for ID-Consistent Human Faces},
213 | author={Paraperas Papantoniou, Foivos and Lattas, Alexandros and Moschoglou, Stylianos and Deng, Jiankang and Kainz, Bernhard and Zafeiriou, Stefanos},
214 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)},
215 | year={2024}
216 | }
217 | ```
218 | """
219 |
220 | css = '''
221 | .gradio-container {width: 85% !important}
222 | '''
223 | with gr.Blocks(css=css) as demo:
224 |
225 | # description
226 | gr.Markdown(title)
227 | gr.Markdown(description)
228 |
229 | with gr.Row():
230 | with gr.Column():
231 | with gr.Row():
232 | # upload face image
233 | img_file = gr.Image(label="Upload a photo with a face", type="filepath")
234 |
235 | # upload reference image
236 | ref_img_file = gr.Image(label="Upload a reference photo for the pose", type="filepath")
237 |
238 | submit = gr.Button("Submit", variant="primary")
239 |
240 | use_lcm = gr.Checkbox(
241 | label="Use LCM-LoRA to accelerate sampling", value=False,
242 | info="Reduces sampling steps significantly, but may decrease quality.",
243 | )
244 |
245 | with gr.Accordion(open=False, label="Advanced Options"):
246 | num_steps = gr.Slider(
247 | label="Number of sample steps",
248 | minimum=1,
249 | maximum=100,
250 | step=1,
251 | value=25,
252 | )
253 | guidance_scale = gr.Slider(
254 | label="Guidance scale",
255 | minimum=0.1,
256 | maximum=10.0,
257 | step=0.1,
258 | value=3.0,
259 | )
260 | num_images = gr.Slider(
261 | label="Number of output images",
262 | minimum=1,
263 | maximum=4,
264 | step=1,
265 | value=2,
266 | )
267 | seed = gr.Slider(
268 | label="Seed",
269 | minimum=0,
270 | maximum=MAX_SEED,
271 | step=1,
272 | value=0,
273 | )
274 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
275 |
276 | with gr.Column():
277 | gallery = gr.Gallery(label="Extracted Pose + Generated Images")
278 |
279 | submit.click(
280 | fn=randomize_seed_fn,
281 | inputs=[seed, randomize_seed],
282 | outputs=seed,
283 | queue=False,
284 | api_name=False,
285 | ).then(
286 | fn=generate_image,
287 | inputs=[img_file, ref_img_file, num_steps, guidance_scale, seed, num_images, use_lcm],
288 | outputs=[gallery]
289 | )
290 |
291 | use_lcm.input(
292 | fn=toggle_lcm_ui,
293 | inputs=[use_lcm],
294 | outputs=[num_steps, guidance_scale],
295 | queue=False,
296 | )
297 |
298 | gr.Examples(
299 | examples=get_example(),
300 | inputs=[img_file, ref_img_file],
301 | run_on_click=True,
302 | fn=run_example,
303 | outputs=[gallery],
304 | )
305 |
306 | gr.Markdown(Footer)
307 |
308 | demo.launch()
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy<1.24.0
2 | torch==2.0.1
3 | torchvision==0.15.2
4 | diffusers==0.23.0
5 | transformers==4.34.1
6 | peft
7 | accelerate
8 | insightface
9 | onnxruntime-gpu
10 | gradio
11 |
--------------------------------------------------------------------------------
/requirements_controlnet.txt:
--------------------------------------------------------------------------------
1 | numpy<1.24.0
2 | diffusers==0.23.0
3 | transformers==4.34.1
4 | peft
5 | accelerate
6 | insightface
7 | onnxruntime-gpu
8 | gradio
9 | face-alignment
10 | omegaconf
11 | pytorch-lightning==1.4.9
12 | torchmetrics==0.6.2
13 | torchfile
14 | facenet-pytorch
15 | chumpy
--------------------------------------------------------------------------------