├── .gitmodules ├── LICENSE ├── README.md ├── arc2face ├── __init__.py ├── models.py └── utils.py ├── assets ├── controlnet.jpg ├── examples │ ├── freddie.png │ ├── freeman.jpg │ ├── hepburn.png │ ├── jackie.png │ ├── joacquin.png │ ├── lily.png │ ├── pose1.png │ ├── pose2.png │ └── pose3.jpg ├── samples.jpg └── teaser.gif ├── gradio_demo ├── app.py └── app_controlnet.py ├── requirements.txt └── requirements_controlnet.txt /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "external/emoca"] 2 | path = external/emoca 3 | url = https://github.com/foivospar/emoca.git 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Foivos Paraperas Papantoniou 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Arc2Face: A Foundation Model for ID-Consistent Human Faces 4 | 5 | [Foivos Paraperas Papantoniou](https://foivospar.github.io/)1   [Alexandros Lattas](https://alexlattas.com/)1   [Stylianos Moschoglou](https://moschoglou.com/)1 6 | 7 | [Jiankang Deng](https://jiankangdeng.github.io/)1   [Bernhard Kainz](https://bernhard-kainz.com/)1,2   [Stefanos Zafeiriou](https://www.imperial.ac.uk/people/s.zafeiriou)1 8 | 9 | 1Imperial College London, UK
10 | 2FAU Erlangen-Nürnberg, Germany 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | This is the official implementation of **[Arc2Face](https://arc2face.github.io/)**, an ID-conditioned face model: 21 | 22 |  ✅ that generates high-quality images of any subject given only its ArcFace embedding, within a few seconds
23 |  ✅ trained on the large-scale WebFace42M dataset, offers superior ID similarity compared to existing models
24 |  ✅ built on top of Stable Diffusion, can be extended to different input modalities, e.g. with ControlNet
25 | 26 | 27 | 28 | # News/Updates 29 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/arc2face-a-foundation-model-of-human-faces/diffusion-personalization-tuning-free-on)](https://paperswithcode.com/sota/diffusion-personalization-tuning-free-on?p=arc2face-a-foundation-model-of-human-faces) 30 | 31 | - [2024/08/16] 🔥 Accepted to ECCV24 as an **oral**! 32 | - [2024/08/06] 🔥 ComfyUI support available at [caleboleary/ComfyUI-Arc2Face](https://github.com/caleboleary/ComfyUI-Arc2Face)! 33 | - [2024/04/12] 🔥 We add LCM-LoRA support for even faster inference (check the details [below](#lcm-lora-acceleration)). 34 | - [2024/04/11] 🔥 We release the training dataset on [HuggingFace Datasets](https://huggingface.co/datasets/FoivosPar/Arc2Face). 35 | - [2024/03/31] 🔥 We release our demo for pose control using Arc2Face + ControlNet (see instructions [below](#arc2face--controlnet-pose)). 36 | - [2024/03/28] 🔥 We release our Gradio [demo](https://huggingface.co/spaces/FoivosPar/Arc2Face) on HuggingFace Spaces (thanks to the HF team for their free GPU support)! 37 | - [2024/03/14] 🔥 We release Arc2Face. 
38 | 39 | # Installation 40 | ```bash 41 | conda create -n arc2face python=3.10 42 | conda activate arc2face 43 | 44 | # Install requirements 45 | pip install -r requirements.txt 46 | ``` 47 | 48 | # Download Models 49 | 1) The models can be downloaded manually from [HuggingFace](https://huggingface.co/FoivosPar/Arc2Face) or using python: 50 | ```python 51 | from huggingface_hub import hf_hub_download 52 | 53 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="arc2face/config.json", local_dir="./models") 54 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="arc2face/diffusion_pytorch_model.safetensors", local_dir="./models") 55 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="encoder/config.json", local_dir="./models") 56 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="encoder/pytorch_model.bin", local_dir="./models") 57 | ``` 58 | 59 | 2) For face detection and ID-embedding extraction, manually download the [antelopev2](https://github.com/deepinsight/insightface/tree/master/python-package#model-zoo) package ([direct link](https://drive.google.com/file/d/18wEUfMNohBJ4K3Ly5wpTejPfDzp-8fI8/view)) and place the checkpoints under `models/antelopev2`. 60 | 61 | 3) We use an ArcFace recognition model trained on WebFace42M. Download `arcface.onnx` from [HuggingFace](https://huggingface.co/FoivosPar/Arc2Face) and put it in `models/antelopev2` or using python: 62 | ```python 63 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="arcface.onnx", local_dir="./models/antelopev2") 64 | ``` 65 | 4) Then **delete** `glintr100.onnx` (the default backbone from insightface). 66 | 67 | The `models` folder structure should finally be: 68 | ``` 69 | . 
── models ──┌── antelopev2 70 | ├── arc2face 71 | └── encoder 72 | ``` 73 | 74 | # Usage 75 | 76 | Load pipeline using [diffusers](https://huggingface.co/docs/diffusers/index): 77 | ```python 78 | from diffusers import ( 79 | StableDiffusionPipeline, 80 | UNet2DConditionModel, 81 | DPMSolverMultistepScheduler, 82 | ) 83 | 84 | from arc2face import CLIPTextModelWrapper, project_face_embs 85 | 86 | import torch 87 | from insightface.app import FaceAnalysis 88 | from PIL import Image 89 | import numpy as np 90 | 91 | # Arc2Face is built upon SD1.5 92 | # The repo below can be used instead of the now deprecated 'runwayml/stable-diffusion-v1-5' 93 | base_model = 'stable-diffusion-v1-5/stable-diffusion-v1-5' 94 | 95 | encoder = CLIPTextModelWrapper.from_pretrained( 96 | 'models', subfolder="encoder", torch_dtype=torch.float16 97 | ) 98 | 99 | unet = UNet2DConditionModel.from_pretrained( 100 | 'models', subfolder="arc2face", torch_dtype=torch.float16 101 | ) 102 | 103 | pipeline = StableDiffusionPipeline.from_pretrained( 104 | base_model, 105 | text_encoder=encoder, 106 | unet=unet, 107 | torch_dtype=torch.float16, 108 | safety_checker=None 109 | ) 110 | ``` 111 | You can use any SD-compatible schedulers and steps, just like with Stable Diffusion. By default, we use `DPMSolverMultistepScheduler` with 25 steps, which produces very good results in just a few seconds. 
112 | ```python 113 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 114 | pipeline = pipeline.to('cuda') 115 | ``` 116 | Pick an image and extract the ID-embedding: 117 | ```python 118 | app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 119 | app.prepare(ctx_id=0, det_size=(640, 640)) 120 | 121 | img = np.array(Image.open('assets/examples/joacquin.png'))[:,:,::-1] 122 | 123 | faces = app.get(img) 124 | faces = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected) 125 | id_emb = torch.tensor(faces['embedding'], dtype=torch.float16)[None].cuda() 126 | id_emb = id_emb/torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding 127 | id_emb = project_face_embs(pipeline, id_emb) # pass through the encoder 128 | ``` 129 | 130 |
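The selection and normalization steps above assume at least one detection: `sorted(faces, ...)[-1]` raises an `IndexError` when `faces` is empty. Below is a minimal, dependency-free sketch of the same two steps with an explicit guard. The helper names (`select_largest_face`, `l2_normalize`) and the toy inputs are illustrative assumptions, not part of the Arc2Face or insightface APIs:

```python
import math


def bbox_area(face):
    # bbox is assumed to be [x1, y1, x2, y2], as in the snippet above
    x1, y1, x2, y2 = face["bbox"]
    return (x2 - x1) * (y2 - y1)


def select_largest_face(faces):
    # Guard against an empty detection result before picking the largest box
    if not faces:
        raise ValueError("no face detected; try an image with a larger margin")
    return max(faces, key=bbox_area)


def l2_normalize(vec):
    # Scale the vector to unit Euclidean length
    norm = math.sqrt(sum(v * v for v in vec))
    if norm == 0.0:
        raise ValueError("cannot normalize a zero vector")
    return [v / norm for v in vec]


# Toy detections with 2-D "embeddings" (real ArcFace embeddings are 512-D)
faces = [
    {"bbox": [0, 0, 10, 10], "embedding": [3.0, 4.0]},
    {"bbox": [5, 5, 50, 60], "embedding": [0.0, 2.0]},
]
largest = select_largest_face(faces)
unit = l2_normalize(largest["embedding"])
```

In the actual pipeline the embedding is a 512-dimensional ArcFace vector and the normalization is done with `torch.norm`, as shown above; the guard is the only behavioral difference this sketch suggests.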
131 | 132 |
133 | 134 | Generate images: 135 | ```python 136 | num_images = 4 137 | images = pipeline(prompt_embeds=id_emb, num_inference_steps=25, guidance_scale=3.0, num_images_per_prompt=num_images).images 138 | ``` 139 |
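The sampling arguments can be kept in one place so that the default and LCM-LoRA settings are not mixed up. The sketch below is a hypothetical helper (`sampling_kwargs` is not part of this repo); the values are the ones quoted in this README: 25 steps with `guidance_scale=3.0` by default, and 2 steps with `guidance_scale=1.0` for LCM-LoRA (described in the next section).

```python
def sampling_kwargs(use_lcm=False, num_images=4):
    """Return keyword arguments for the pipeline call (hypothetical helper)."""
    if num_images < 1:
        raise ValueError("num_images must be at least 1")
    if use_lcm:
        # LCM-LoRA settings from the README: very few steps, guidance set to 1.0
        steps, guidance = 2, 1.0
    else:
        # Default settings from the README: DPM-Solver with 25 steps
        steps, guidance = 25, 3.0
    return {
        "num_inference_steps": steps,
        "guidance_scale": guidance,
        "num_images_per_prompt": num_images,
    }


default_kwargs = sampling_kwargs()
lcm_kwargs = sampling_kwargs(use_lcm=True, num_images=2)
print(default_kwargs)
print(lcm_kwargs)
```

With a loaded pipeline, this would be used as `pipeline(prompt_embeds=id_emb, **sampling_kwargs()).images`. Switching the scheduler and loading the LoRA weights still has to be done separately.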
140 | 141 |
142 | 143 | # LCM-LoRA acceleration 144 | 145 | [LCM-LoRA](https://arxiv.org/abs/2311.05556) allows you to reduce the sampling steps to as few as 2-4 for super-fast inference. Just plug in the pre-trained distillation adapter for SD v1.5 and switch to `LCMScheduler`: 146 | ```python 147 | from diffusers import LCMScheduler 148 | 149 | pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") 150 | pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config) 151 | ``` 152 | Then, you can sample with as few as 2 steps (and disable guidance by setting `guidance_scale` to 1.0, as LCM is very sensitive to it and even small values lead to oversaturation): 153 | ```python 154 | images = pipeline(prompt_embeds=id_emb, num_inference_steps=2, guidance_scale=1.0, num_images_per_prompt=num_images).images 155 | ``` 156 | Note that this technique accelerates sampling in exchange for a slight drop in quality. 157 | 158 | # Start a local gradio demo 159 | You can start a local demo for inference by running: 160 | ```bash 161 | python gradio_demo/app.py 162 | ``` 163 | 164 | # Arc2Face + ControlNet (pose) 165 |
166 | 167 |
168 | 169 | We provide a ControlNet model trained on top of Arc2Face for pose control. We use [EMOCA](https://github.com/radekd91/emoca) for 3D pose extraction. To run our demo, follow the steps below: 170 | ### 1) Download Model 171 | Download the ControlNet checkpoint manually from [HuggingFace](https://huggingface.co/FoivosPar/Arc2Face) or using python: 172 | ```python 173 | from huggingface_hub import hf_hub_download 174 | 175 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="controlnet/config.json", local_dir="./models") 176 | hf_hub_download(repo_id="FoivosPar/Arc2Face", filename="controlnet/diffusion_pytorch_model.safetensors", local_dir="./models") 177 | ``` 178 | ### 2) Pull EMOCA 179 | ```bash 180 | git submodule update --init external/emoca 181 | ``` 182 | ### 3) Installation 183 | This is the trickiest part. You will need PyTorch3D to run EMOCA. As its installation may cause conflicts, we suggest following the process below: 184 | 1) Create a new environment and start by installing PyTorch3D with GPU support (follow the official [instructions](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md)). 185 | 2) Add Arc2Face + EMOCA requirements with: 186 | ```bash 187 | pip install -r requirements_controlnet.txt 188 | ``` 189 | 3) Install EMOCA code: 190 | ```bash 191 | pip install -e external/emoca 192 | ``` 193 | 4) Finally, you need to download the EMOCA/FLAME assets. Run the following and follow the instructions in the terminal: 194 | ```bash 195 | cd external/emoca/gdl_apps/EMOCA/demos 196 | bash download_assets.sh 197 | cd ../../../../.. 198 | ``` 199 | ### 4) Start a local gradio demo 200 | You can start a local ControlNet demo by running: 201 | ```bash 202 | python gradio_demo/app_controlnet.py 203 | ``` 204 | 205 | # Test Data 206 | The test images used for comparisons in the paper (Synth-500, AgeDB) are available [here](https://drive.google.com/drive/folders/1exnvCECmqWcqNIFCck2EQD-hkE42Ayjc?usp=sharing).
Please use them only for evaluation purposes and make sure to cite the corresponding [sources](https://ibug.doc.ic.ac.uk/resources/agedb/) when using them. 207 | 208 | # Community Resources 209 | 210 | ### Replicate Demo 211 | - [Demo link](https://replicate.com/camenduru/arc2face) by [@camenduru](https://github.com/camenduru). 212 | 213 | ### ComfyUI 214 | - [caleboleary/ComfyUI-Arc2Face](https://github.com/caleboleary/ComfyUI-Arc2Face) by [@caleboleary](https://github.com/caleboleary). 215 | 216 | ### Pinokio 217 | - Pinokio [implementation](https://pinokio.computer/item?uri=https://github.com/cocktailpeanutlabs/arc2face) by [@cocktailpeanut](https://github.com/cocktailpeanut) (runs locally on all OS - Windows, Mac, Linux). 218 | 219 | # Acknowledgements 220 | - Thanks to the creators of Stable Diffusion and the HuggingFace [diffusers](https://github.com/huggingface/diffusers) team for the awesome work ❤️. 221 | - Thanks to the WebFace42M creators for providing such a million-scale facial dataset ❤️. 222 | - Thanks to the HuggingFace team for their generous support through the community GPU grant for our demo ❤️. 223 | - We also acknowledge the invaluable support of the HPC resources provided by the Erlangen National High Performance Computing Center (NHR@FAU) of the Friedrich-Alexander-Universität Erlangen-Nürnberg (FAU), which made the training of Arc2Face possible. 
224 | 225 | # Citation 226 | If you find Arc2Face useful for your research, please consider citing us: 227 | 228 | ```bibtex 229 | @inproceedings{paraperas2024arc2face, 230 | title={Arc2Face: A Foundation Model for ID-Consistent Human Faces}, 231 | author={Paraperas Papantoniou, Foivos and Lattas, Alexandros and Moschoglou, Stylianos and Deng, Jiankang and Kainz, Bernhard and Zafeiriou, Stefanos}, 232 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 233 | year={2024} 234 | } 235 | ``` 236 | -------------------------------------------------------------------------------- /arc2face/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import CLIPTextModelWrapper 2 | from .utils import project_face_embs, image_align -------------------------------------------------------------------------------- /arc2face/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import CLIPTextModel 3 | from typing import Any, Callable, Dict, Optional, Tuple, Union, List 4 | from transformers.modeling_outputs import BaseModelOutputWithPooling 5 | from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask 6 | 7 | 8 | class CLIPTextModelWrapper(CLIPTextModel): 9 | # Adapted from https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/clip/modeling_clip.py#L812 10 | # Modified to accept precomputed token embeddings "input_token_embs" as input or calculate them from input_ids and return them. 
11 | def forward( 12 | self, 13 | input_ids: Optional[torch.Tensor] = None, 14 | attention_mask: Optional[torch.Tensor] = None, 15 | position_ids: Optional[torch.Tensor] = None, 16 | output_attentions: Optional[bool] = None, 17 | output_hidden_states: Optional[bool] = None, 18 | return_dict: Optional[bool] = None, 19 | input_token_embs: Optional[torch.Tensor] = None, 20 | return_token_embs: Optional[bool] = False, 21 | ) -> Union[Tuple, torch.Tensor, BaseModelOutputWithPooling]: 22 | 23 | if return_token_embs: 24 | return self.text_model.embeddings.token_embedding(input_ids) 25 | 26 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 27 | 28 | output_attentions = output_attentions if output_attentions is not None else self.text_model.config.output_attentions 29 | output_hidden_states = ( 30 | output_hidden_states if output_hidden_states is not None else self.text_model.config.output_hidden_states 31 | ) 32 | return_dict = return_dict if return_dict is not None else self.text_model.config.use_return_dict 33 | 34 | if input_ids is None: 35 | raise ValueError("You have to specify input_ids") 36 | 37 | input_shape = input_ids.size() 38 | input_ids = input_ids.view(-1, input_shape[-1]) 39 | 40 | hidden_states = self.text_model.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=input_token_embs) 41 | 42 | # CLIP's text model uses causal mask, prepare it here. 
43 | # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 44 | causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device) 45 | # expand attention_mask 46 | if attention_mask is not None: 47 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 48 | attention_mask = _expand_mask(attention_mask, hidden_states.dtype) 49 | 50 | encoder_outputs = self.text_model.encoder( 51 | inputs_embeds=hidden_states, 52 | attention_mask=attention_mask, 53 | causal_attention_mask=causal_attention_mask, 54 | output_attentions=output_attentions, 55 | output_hidden_states=output_hidden_states, 56 | return_dict=return_dict, 57 | ) 58 | 59 | last_hidden_state = encoder_outputs[0] 60 | last_hidden_state = self.text_model.final_layer_norm(last_hidden_state) 61 | 62 | if self.text_model.eos_token_id == 2: 63 | # The `eos_token_id` was incorrect before PR #24773: Let's keep what has been done here. 64 | # A CLIP model with such `eos_token_id` in the config can't work correctly with extra new tokens added 65 | # ------------------------------------------------------------ 66 | # text_embeds.shape = [batch_size, sequence_length, transformer.width] 67 | # take features from the eot embedding (eot_token is the highest number in each sequence) 68 | # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14 69 | pooled_output = last_hidden_state[ 70 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), 71 | input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1), 72 | ] 73 | else: 74 | # The config gets updated `eos_token_id` from PR #24773 (so the use of extra new tokens is possible) 75 | pooled_output = last_hidden_state[ 76 | torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device), 77 | # We need to get the first position of `eos_token_id` value (`pad_token_ids` might be equal to `eos_token_id`) 78 | 
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.text_model.eos_token_id) 79 | .int() 80 | .argmax(dim=-1), 81 | ] 82 | 83 | if not return_dict: 84 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 85 | 86 | return BaseModelOutputWithPooling( 87 | last_hidden_state=last_hidden_state, 88 | pooler_output=pooled_output, 89 | hidden_states=encoder_outputs.hidden_states, 90 | attentions=encoder_outputs.attentions, 91 | ) -------------------------------------------------------------------------------- /arc2face/utils.py: -------------------------------------------------------------------------------- 1 | import scipy 2 | import PIL 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | @torch.no_grad() 8 | def project_face_embs(pipeline, face_embs): 9 | 10 | ''' 11 | face_embs: (N, 512) normalized ArcFace embeddings 12 | ''' 13 | 14 | arcface_token_id = pipeline.tokenizer.encode("id", add_special_tokens=False)[0] 15 | 16 | input_ids = pipeline.tokenizer( 17 | "photo of a id person", 18 | truncation=True, 19 | padding="max_length", 20 | max_length=pipeline.tokenizer.model_max_length, 21 | return_tensors="pt", 22 | ).input_ids.to(pipeline.device) 23 | 24 | face_embs_padded = F.pad(face_embs, (0, pipeline.text_encoder.config.hidden_size-512), "constant", 0) 25 | token_embs = pipeline.text_encoder(input_ids=input_ids.repeat(len(face_embs), 1), return_token_embs=True) 26 | token_embs[input_ids==arcface_token_id] = face_embs_padded 27 | 28 | prompt_embeds = pipeline.text_encoder( 29 | input_ids=input_ids, 30 | input_token_embs=token_embs 31 | )[0] 32 | 33 | return prompt_embeds 34 | 35 | 36 | def image_align(img, 37 | face_landmarks, 38 | output_size=1024, 39 | transform_size=4096, 40 | enable_padding=True): 41 | # Align function from FFHQ dataset pre-processing step 42 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py 43 | 44 | lm = face_landmarks 45 | lm_eye_left = lm[36:42] # 
left-clockwise 46 | lm_eye_right = lm[42:48] # left-clockwise 47 | lm_mouth_outer = lm[48:60] # left-clockwise 48 | 49 | # Calculate auxiliary vectors. 50 | eye_left = np.mean(lm_eye_left, axis=0) 51 | eye_right = np.mean(lm_eye_right, axis=0) 52 | eye_avg = (eye_left + eye_right) * 0.5 53 | eye_to_eye = eye_right - eye_left 54 | mouth_left = lm_mouth_outer[0] 55 | mouth_right = lm_mouth_outer[6] 56 | mouth_avg = (mouth_left + mouth_right) * 0.5 57 | eye_to_mouth = mouth_avg - eye_avg 58 | 59 | # Choose oriented crop rectangle. 60 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 61 | x /= np.hypot(*x) 62 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 63 | y = np.flipud(x) * [-1, 1] 64 | c = eye_avg + eye_to_mouth * 0.1 65 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 66 | qsize = np.hypot(*x) * 2 67 | 68 | img = img.convert('RGB') 69 | 70 | # Shrink. 71 | shrink = int(np.floor(qsize / output_size * 0.5)) 72 | if shrink > 1: 73 | rsize = (int(np.rint(float(img.size[0]) / shrink)), 74 | int(np.rint(float(img.size[1]) / shrink))) 75 | img = img.resize(rsize, PIL.Image.LANCZOS) 76 | quad /= shrink 77 | qsize /= shrink 78 | 79 | # Crop. 80 | border = max(int(np.rint(qsize * 0.1)), 3) 81 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), 82 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) 83 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), 84 | min(crop[2] + border, 85 | img.size[0]), min(crop[3] + border, img.size[1])) 86 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 87 | img = img.crop(crop) 88 | quad -= crop[0:2] 89 | 90 | # Pad. 
91 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), 92 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) 93 | pad = (max(-pad[0] + border, 94 | 0), max(-pad[1] + border, 95 | 0), max(pad[2] - img.size[0] + border, 96 | 0), max(pad[3] - img.size[1] + border, 0)) 97 | if enable_padding and max(pad) > border - 4: 98 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 99 | img = np.pad(np.float32(img), 100 | ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 101 | h, w, _ = img.shape 102 | y, x, _ = np.ogrid[:h, :w, :1] 103 | mask = np.maximum( 104 | 1.0 - 105 | np.minimum(np.float32(x) / pad[0], 106 | np.float32(w - 1 - x) / pad[2]), 1.0 - 107 | np.minimum(np.float32(y) / pad[1], 108 | np.float32(h - 1 - y) / pad[3])) 109 | blur = qsize * 0.02 110 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - 111 | img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 112 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 113 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 114 | 'RGB') 115 | quad += pad[:2] 116 | 117 | 118 | # Transform. 
119 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, 120 | (quad + 0.5).flatten(), PIL.Image.BILINEAR) 121 | if output_size < transform_size: 122 | img = img.resize((output_size, output_size), PIL.Image.LANCZOS) 123 | 124 | return img -------------------------------------------------------------------------------- /assets/controlnet.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/controlnet.jpg -------------------------------------------------------------------------------- /assets/examples/freddie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/freddie.png -------------------------------------------------------------------------------- /assets/examples/freeman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/freeman.jpg -------------------------------------------------------------------------------- /assets/examples/hepburn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/hepburn.png -------------------------------------------------------------------------------- /assets/examples/jackie.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/jackie.png -------------------------------------------------------------------------------- /assets/examples/joacquin.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/joacquin.png -------------------------------------------------------------------------------- /assets/examples/lily.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/lily.png -------------------------------------------------------------------------------- /assets/examples/pose1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/pose1.png -------------------------------------------------------------------------------- /assets/examples/pose2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/pose2.png -------------------------------------------------------------------------------- /assets/examples/pose3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/examples/pose3.jpg -------------------------------------------------------------------------------- /assets/samples.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/samples.jpg -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/foivospar/Arc2Face/017c4b414da03758478855f519fe0bc3e92caa4c/assets/teaser.gif -------------------------------------------------------------------------------- /gradio_demo/app.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | 4 | from diffusers import ( 5 | StableDiffusionPipeline, 6 | UNet2DConditionModel, 7 | DPMSolverMultistepScheduler, 8 | LCMScheduler 9 | ) 10 | 11 | from arc2face import CLIPTextModelWrapper, project_face_embs 12 | 13 | import torch 14 | from insightface.app import FaceAnalysis 15 | from PIL import Image 16 | import numpy as np 17 | import random 18 | 19 | import gradio as gr 20 | 21 | # global variable 22 | MAX_SEED = np.iinfo(np.int32).max 23 | if torch.cuda.is_available(): 24 | device = "cuda" 25 | dtype = torch.float16 26 | else: 27 | device = "cpu" 28 | dtype = torch.float32 29 | 30 | # Load face detection and recognition package 31 | app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 32 | app.prepare(ctx_id=0, det_size=(640, 640)) 33 | 34 | # Load pipeline 35 | base_model = 'stable-diffusion-v1-5/stable-diffusion-v1-5' 36 | encoder = CLIPTextModelWrapper.from_pretrained( 37 | 'models', subfolder="encoder", torch_dtype=dtype 38 | ) 39 | unet = UNet2DConditionModel.from_pretrained( 40 | 'models', subfolder="arc2face", torch_dtype=dtype 41 | ) 42 | pipeline = StableDiffusionPipeline.from_pretrained( 43 | base_model, 44 | text_encoder=encoder, 45 | unet=unet, 46 | torch_dtype=dtype, 47 | safety_checker=None 48 | ) 49 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 50 | pipeline = pipeline.to(device) 51 | 52 | # load and disable LCM 53 | pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") 54 | pipeline.disable_lora() 55 | 56 | def toggle_lcm_ui(value): 57 | if value: 58 | return ( 59 | gr.update(minimum=1, maximum=20, 
step=1, value=3), 60 | gr.update(minimum=0.1, maximum=10.0, step=0.1, value=1.0), 61 | ) 62 | else: 63 | return ( 64 | gr.update(minimum=1, maximum=100, step=1, value=25), 65 | gr.update(minimum=0.1, maximum=10.0, step=0.1, value=3.0), 66 | ) 67 | 68 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: 69 | if randomize_seed: 70 | seed = random.randint(0, MAX_SEED) 71 | return seed 72 | 73 | def get_example(): 74 | case = [ 75 | [ 76 | './assets/examples/freeman.jpg', 77 | ], 78 | [ 79 | './assets/examples/lily.png', 80 | ], 81 | [ 82 | './assets/examples/joacquin.png', 83 | ], 84 | [ 85 | './assets/examples/jackie.png', 86 | ], 87 | [ 88 | './assets/examples/freddie.png', 89 | ], 90 | [ 91 | './assets/examples/hepburn.png', 92 | ], 93 | ] 94 | return case 95 | 96 | def run_example(img_file): 97 | return generate_image(img_file, 25, 3, 23, 2, False) 98 | 99 | 100 | def generate_image(image_path, num_steps, guidance_scale, seed, num_images, use_lcm, progress=gr.Progress(track_tqdm=True)): 101 | 102 | if use_lcm: 103 | pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config) 104 | pipeline.enable_lora() 105 | else: 106 | pipeline.disable_lora() 107 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 108 | 109 | if image_path is None: 110 | raise gr.Error(f"Cannot find any input face image! Please upload a face image.") 111 | 112 | img = np.array(Image.open(image_path))[:,:,::-1] 113 | 114 | # Face detection and ID-embedding extraction 115 | faces = app.get(img) 116 | 117 | if len(faces) == 0: 118 | raise gr.Error(f"Face detection failed! 
Please try with another image") 119 | 120 | faces = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected) 121 | id_emb = torch.tensor(faces['embedding'], dtype=dtype)[None].to(device) 122 | id_emb = id_emb/torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding 123 | id_emb = project_face_embs(pipeline, id_emb) # pass through the encoder 124 | 125 | generator = torch.Generator(device=device).manual_seed(seed) 126 | 127 | print("Start inference...") 128 | images = pipeline( 129 | prompt_embeds=id_emb, 130 | num_inference_steps=num_steps, 131 | guidance_scale=guidance_scale, 132 | num_images_per_prompt=num_images, 133 | generator=generator 134 | ).images 135 | 136 | return images 137 | 138 | ### Description 139 | title = r""" 140 |

Arc2Face: A Foundation Model for ID-Consistent Human Faces

141 | """ 142 | 143 | description = r""" 144 | Official 🤗 Gradio demo for Arc2Face: A Foundation Model for ID-Consistent Human Faces.
145 | 146 | Steps:
147 | 1. Upload an image with a face. If multiple faces are detected, we use the largest one. For images with already tightly cropped faces, detection may fail; in that case, try images with a larger margin. 148 | 2. Click Submit to generate new images of the subject. 149 | """ 150 | 151 | Footer = r""" 152 | --- 153 | 📝 **Citation** 154 |
155 | If you find Arc2Face helpful for your research, please consider citing our paper: 156 | ```bibtex 157 | @inproceedings{paraperas2024arc2face, 158 | title={Arc2Face: A Foundation Model for ID-Consistent Human Faces}, 159 | author={Paraperas Papantoniou, Foivos and Lattas, Alexandros and Moschoglou, Stylianos and Deng, Jiankang and Kainz, Bernhard and Zafeiriou, Stefanos}, 160 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 161 | year={2024} 162 | } 163 | ``` 164 | """ 165 | 166 | css = ''' 167 | .gradio-container {width: 85% !important} 168 | ''' 169 | with gr.Blocks(css=css) as demo: 170 | 171 | # description 172 | gr.Markdown(title) 173 | gr.Markdown(description) 174 | 175 | with gr.Row(): 176 | with gr.Column(): 177 | 178 | # upload face image 179 | img_file = gr.Image(label="Upload a photo with a face", type="filepath") 180 | 181 | submit = gr.Button("Submit", variant="primary") 182 | 183 | use_lcm = gr.Checkbox( 184 | label="Use LCM-LoRA to accelerate sampling", value=False, 185 | info="Reduces sampling steps significantly, but may decrease quality.", 186 | ) 187 | 188 | with gr.Accordion(open=False, label="Advanced Options"): 189 | num_steps = gr.Slider( 190 | label="Number of sample steps", 191 | minimum=1, 192 | maximum=100, 193 | step=1, 194 | value=25, 195 | ) 196 | guidance_scale = gr.Slider( 197 | label="Guidance scale", 198 | minimum=0.1, 199 | maximum=10.0, 200 | step=0.1, 201 | value=3.0, 202 | ) 203 | num_images = gr.Slider( 204 | label="Number of output images", 205 | minimum=1, 206 | maximum=4, 207 | step=1, 208 | value=2, 209 | ) 210 | seed = gr.Slider( 211 | label="Seed", 212 | minimum=0, 213 | maximum=MAX_SEED, 214 | step=1, 215 | value=0, 216 | ) 217 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 218 | 219 | with gr.Column(): 220 | gallery = gr.Gallery(label="Generated Images") 221 | 222 | submit.click( 223 | fn=randomize_seed_fn, 224 | inputs=[seed, randomize_seed], 225 | outputs=seed, 
226 | queue=False, 227 | api_name=False, 228 | ).then( 229 | fn=generate_image, 230 | inputs=[img_file, num_steps, guidance_scale, seed, num_images, use_lcm], 231 | outputs=[gallery] 232 | ) 233 | 234 | use_lcm.input( 235 | fn=toggle_lcm_ui, 236 | inputs=[use_lcm], 237 | outputs=[num_steps, guidance_scale], 238 | queue=False, 239 | ) 240 | 241 | gr.Examples( 242 | examples=get_example(), 243 | inputs=[img_file], 244 | run_on_click=True, 245 | fn=run_example, 246 | outputs=[gallery], 247 | ) 248 | 249 | gr.Markdown(Footer) 250 | 251 | demo.launch() -------------------------------------------------------------------------------- /gradio_demo/app_controlnet.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('./') 3 | 4 | from diffusers import ( 5 | StableDiffusionPipeline, 6 | UNet2DConditionModel, 7 | DPMSolverMultistepScheduler, 8 | LCMScheduler, 9 | ControlNetModel, 10 | StableDiffusionControlNetPipeline 11 | ) 12 | 13 | from arc2face import CLIPTextModelWrapper, project_face_embs, image_align 14 | 15 | from gdl.utils.FaceDetector import FAN 16 | from gdl_apps.EMOCA.utils.load import load_model 17 | from gdl.datasets.ImageTestDataset import preprocess_for_emoca 18 | 19 | import torch 20 | from insightface.app import FaceAnalysis 21 | from PIL import Image 22 | import numpy as np 23 | import random 24 | 25 | import gradio as gr 26 | 27 | # global variable 28 | MAX_SEED = np.iinfo(np.int32).max 29 | if torch.cuda.is_available(): 30 | device = "cuda" 31 | dtype = torch.float16 32 | else: 33 | device = "cpu" 34 | dtype = torch.float32 35 | 36 | # Load face detection and recognition package 37 | app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 38 | app.prepare(ctx_id=0, det_size=(640, 640)) 39 | 40 | # Load pipeline 41 | base_model = 'stable-diffusion-v1-5/stable-diffusion-v1-5' 42 | encoder = CLIPTextModelWrapper.from_pretrained( 43 | 
'models', subfolder="encoder", torch_dtype=dtype 44 | ) 45 | unet = UNet2DConditionModel.from_pretrained( 46 | 'models', subfolder="arc2face", torch_dtype=dtype 47 | ) 48 | controlnet = ControlNetModel.from_pretrained( 49 | 'models', subfolder="controlnet", torch_dtype=dtype 50 | ) 51 | pipeline = StableDiffusionControlNetPipeline.from_pretrained( 52 | base_model, 53 | text_encoder=encoder, 54 | unet=unet, 55 | controlnet=controlnet, 56 | torch_dtype=dtype, 57 | safety_checker=None 58 | ) 59 | pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 60 | pipeline = pipeline.to(device) 61 | 62 | # load and disable LCM 63 | pipeline.load_lora_weights("latent-consistency/lcm-lora-sdv1-5") 64 | pipeline.disable_lora() 65 | 66 | def toggle_lcm_ui(value): 67 | if value: 68 | return ( 69 | gr.update(minimum=1, maximum=20, step=1, value=3), 70 | gr.update(minimum=0.1, maximum=10.0, step=0.1, value=1.0), 71 | ) 72 | else: 73 | return ( 74 | gr.update(minimum=1, maximum=100, step=1, value=25), 75 | gr.update(minimum=0.1, maximum=10.0, step=0.1, value=3.0), 76 | ) 77 | 78 | # Load Emoca 79 | face_detector = FAN() 80 | path_to_models = "external/emoca/assets/EMOCA/models" 81 | model_name = 'EMOCA_v2_lr_mse_20' 82 | mode = 'detail' 83 | emoca_model, conf = load_model(path_to_models, model_name, mode) 84 | emoca_model.to(device) 85 | emoca_model.eval() 86 | 87 | def randomize_seed_fn(seed: int, randomize_seed: bool) -> int: 88 | if randomize_seed: 89 | seed = random.randint(0, MAX_SEED) 90 | return seed 91 | 92 | def get_example(): 93 | case = [ 94 | [ 95 | './assets/examples/freddie.png', 96 | './assets/examples/pose1.png', 97 | ], 98 | [ 99 | './assets/examples/lily.png', 100 | './assets/examples/pose2.png', 101 | ], 102 | [ 103 | './assets/examples/freeman.jpg', 104 | './assets/examples/pose3.jpg', 105 | ], 106 | ] 107 | return case 108 | 109 | def run_example(img_file, ref_img_file): 110 | return generate_image(img_file, ref_img_file, 25, 3, 
23, 2, False) 111 | 112 | def run_emoca(img, ref_img): 113 | 114 | img_dict = preprocess_for_emoca(img, face_detector) 115 | img_dict['image'] = img_dict['image'].unsqueeze(0).to(device) 116 | with torch.no_grad(): 117 | codedict = emoca_model.encode(img_dict, training=False) 118 | 119 | bbox, bbox_type, landmarks = face_detector.run(np.array(ref_img.convert('RGB')), with_landmarks=True) 120 | if len(bbox) == 0: 121 | raise gr.Error(f"Face detection failed in reference image! Please try with another reference image.") 122 | if len(bbox)>1: # select largest face 123 | sizes = [(b[2]-b[0])*(b[3]-b[1]) for b in bbox] 124 | idx = np.argmax(sizes) 125 | lmks = landmarks[idx] 126 | else: 127 | lmks = landmarks[0] 128 | ref_img_aligned = image_align(ref_img.copy(), lmks, output_size=512) 129 | ref_img_dict = preprocess_for_emoca(ref_img_aligned, face_detector) 130 | ref_img_dict['image'] = ref_img_dict['image'].unsqueeze(0).to(device) 131 | with torch.no_grad(): 132 | ref_codedict = emoca_model.encode(ref_img_dict, training=False) 133 | ref_codedict['shapecode'] = codedict['shapecode'].clone() 134 | ref_codedict['detailcode'] = codedict['detailcode'].clone() 135 | tform = ref_img_dict['tform'].unsqueeze(0).to(device) 136 | tform = torch.inverse(tform).transpose(1, 2) 137 | visdict = emoca_model.decode(ref_codedict, training=False, render_orig=True, original_image=ref_img_dict['original_image'].unsqueeze(0).to(device), tform=tform) 138 | 139 | cond_img = Image.fromarray(((visdict['normal_images'][0]*0.5+0.5).clamp(0,1).permute(1,2,0).cpu().numpy()*255).astype(np.uint8)) 140 | 141 | return ref_img_aligned, cond_img 142 | 143 | def generate_image(image_path, ref_image_path, num_steps, guidance_scale, seed, num_images, use_lcm, progress=gr.Progress(track_tqdm=True)): 144 | 145 | if use_lcm: 146 | pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config) 147 | pipeline.enable_lora() 148 | else: 149 | pipeline.disable_lora() 150 | pipeline.scheduler = 
DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) 151 | 152 | if image_path is None: 153 | raise gr.Error(f"Cannot find any input face image! Please upload a face image.") 154 | 155 | if ref_image_path is None: 156 | raise gr.Error(f"Cannot find any reference image! Please upload a reference image.") 157 | 158 | img = np.array(Image.open(image_path))[:,:,::-1] 159 | 160 | # Face detection and ID-embedding extraction 161 | faces = app.get(img) 162 | 163 | if len(faces) == 0: 164 | raise gr.Error(f"Face detection failed! Please try with another input face image.") 165 | 166 | faces = sorted(faces, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # select largest face (if more than one detected) 167 | id_emb = torch.tensor(faces['embedding'], dtype=dtype)[None].to(device) 168 | id_emb = id_emb/torch.norm(id_emb, dim=1, keepdim=True) # normalize embedding 169 | id_emb = project_face_embs(pipeline, id_emb) # pass through the encoder 170 | 171 | # pose extraction with EMOCA 172 | ref_img_a, cond_img = run_emoca(Image.open(image_path), Image.open(ref_image_path)) 173 | 174 | generator = torch.Generator(device=device).manual_seed(seed) 175 | 176 | print("Start inference...") 177 | images = pipeline( 178 | image=cond_img, 179 | prompt_embeds=id_emb, 180 | num_inference_steps=num_steps, 181 | guidance_scale=guidance_scale, 182 | num_images_per_prompt=num_images, 183 | generator=generator 184 | ).images 185 | 186 | return [ref_img_a, cond_img] + images 187 | 188 | ### Description 189 | title = r""" 190 |

Arc2Face: A Foundation Model for ID-Consistent Human Faces

191 | """ 192 | 193 | description = r""" 194 | Official 🤗 Gradio demo for Arc2Face: A Foundation Model for ID-Consistent Human Faces.
195 | This demo uses Arc2Face with ControlNet to generate images of a subject with the pose extracted from a reference image. 196 | 197 | Steps:
198 | 1. Upload an image with a face. If multiple faces are detected, we use the largest one. For images with already tightly cropped faces, detection may fail; try images with a larger margin. 199 | 2. Upload a reference image for the pose (can be a different subject). We align this image to the FFHQ template for pose extraction (so the final generated images correspond to the aligned pose). Again, if multiple faces are detected, we use the largest one. 200 | 3. Click Submit to generate new images of the input subject with the reference pose. 201 | 202 | Note: For pose extraction and conditioning, we use a 3D reconstruction method based on the FLAME model (we use only the rendered mesh normals as the conditioning image, which is visualized alongside the output images). Reconstruction may fail in extreme poses (normals will correspond to random poses instead of the given one). 203 | """ 204 | 205 | Footer = r""" 206 | --- 207 | 📝 **Citation** 208 |
209 | If you find Arc2Face helpful for your research, please consider citing our paper: 210 | ```bibtex 211 | @inproceedings{paraperas2024arc2face, 212 | title={Arc2Face: A Foundation Model for ID-Consistent Human Faces}, 213 | author={Paraperas Papantoniou, Foivos and Lattas, Alexandros and Moschoglou, Stylianos and Deng, Jiankang and Kainz, Bernhard and Zafeiriou, Stefanos}, 214 | booktitle={Proceedings of the European Conference on Computer Vision (ECCV)}, 215 | year={2024} 216 | } 217 | ``` 218 | """ 219 | 220 | css = ''' 221 | .gradio-container {width: 85% !important} 222 | ''' 223 | with gr.Blocks(css=css) as demo: 224 | 225 | # description 226 | gr.Markdown(title) 227 | gr.Markdown(description) 228 | 229 | with gr.Row(): 230 | with gr.Column(): 231 | with gr.Row(): 232 | # upload face image 233 | img_file = gr.Image(label="Upload a photo with a face", type="filepath") 234 | 235 | # upload reference image 236 | ref_img_file = gr.Image(label="Upload a reference photo for the pose", type="filepath") 237 | 238 | submit = gr.Button("Submit", variant="primary") 239 | 240 | use_lcm = gr.Checkbox( 241 | label="Use LCM-LoRA to accelerate sampling", value=False, 242 | info="Reduces sampling steps significantly, but may decrease quality.", 243 | ) 244 | 245 | with gr.Accordion(open=False, label="Advanced Options"): 246 | num_steps = gr.Slider( 247 | label="Number of sample steps", 248 | minimum=1, 249 | maximum=100, 250 | step=1, 251 | value=25, 252 | ) 253 | guidance_scale = gr.Slider( 254 | label="Guidance scale", 255 | minimum=0.1, 256 | maximum=10.0, 257 | step=0.1, 258 | value=3.0, 259 | ) 260 | num_images = gr.Slider( 261 | label="Number of output images", 262 | minimum=1, 263 | maximum=4, 264 | step=1, 265 | value=2, 266 | ) 267 | seed = gr.Slider( 268 | label="Seed", 269 | minimum=0, 270 | maximum=MAX_SEED, 271 | step=1, 272 | value=0, 273 | ) 274 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 275 | 276 | with gr.Column(): 277 | gallery = 
gr.Gallery(label="Extracted Pose + Generated Images") 278 | 279 | submit.click( 280 | fn=randomize_seed_fn, 281 | inputs=[seed, randomize_seed], 282 | outputs=seed, 283 | queue=False, 284 | api_name=False, 285 | ).then( 286 | fn=generate_image, 287 | inputs=[img_file, ref_img_file, num_steps, guidance_scale, seed, num_images, use_lcm], 288 | outputs=[gallery] 289 | ) 290 | 291 | use_lcm.input( 292 | fn=toggle_lcm_ui, 293 | inputs=[use_lcm], 294 | outputs=[num_steps, guidance_scale], 295 | queue=False, 296 | ) 297 | 298 | gr.Examples( 299 | examples=get_example(), 300 | inputs=[img_file, ref_img_file], 301 | run_on_click=True, 302 | fn=run_example, 303 | outputs=[gallery], 304 | ) 305 | 306 | gr.Markdown(Footer) 307 | 308 | demo.launch() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy<1.24.0 2 | torch==2.0.1 3 | torchvision==0.15.2 4 | diffusers==0.23.0 5 | transformers==4.34.1 6 | peft 7 | accelerate 8 | insightface 9 | onnxruntime-gpu 10 | gradio 11 | -------------------------------------------------------------------------------- /requirements_controlnet.txt: -------------------------------------------------------------------------------- 1 | numpy<1.24.0 2 | diffusers==0.23.0 3 | transformers==4.34.1 4 | peft 5 | accelerate 6 | insightface 7 | onnxruntime-gpu 8 | gradio 9 | face-alignment 10 | omegaconf 11 | pytorch-lightning==1.4.9 12 | torchmetrics==0.6.2 13 | torchfile 14 | facenet-pytorch 15 | chumpy --------------------------------------------------------------------------------