├── facechain ├── __init__.py ├── data_process │ ├── __init__.py │ ├── preprocessing.py │ ├── preprocessing -win10.py │ └── deepbooru.py ├── merge_lora.py ├── inference.py └── train_text_to_image_lora.py ├── .gitattributes ├── resources ├── example1.jpg ├── example2.jpg ├── example3.jpg ├── framework.jpg └── framework_eng.jpg ├── requirements.txt ├── train_lora.sh ├── train_lora.bat ├── run_inference.py ├── README_ZH.md ├── README.md ├── app.py ├── app-win10.py └── LICENSE /facechain/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /facechain/data_process/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | *.safetensors filter=lfs diff=lfs merge=lfs -text 2 | *.jpg filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /resources/example1.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:c773cdd97a6ab7dc8fe4075f216eefad71043b5f099fcf8399ff00c1a0bed2cc 3 | size 83419 4 | -------------------------------------------------------------------------------- /resources/example2.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:637efd55906eeefc27dc76f97049424ff3cfc96eb3ab3089fc8f325ebda5f857 3 | size 51409 4 | -------------------------------------------------------------------------------- /resources/example3.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:896bcd46629910602709ae94602c2ba99d239ce9c6eb4790a05111c4df7eaa6e 3 | size 131820 4 | -------------------------------------------------------------------------------- /resources/framework.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1ef2305413f33d93bac369836231f4a79a4df2463086b118ca1584b115c2e2b0 3 | size 476295 4 | -------------------------------------------------------------------------------- /resources/framework_eng.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:18adcd55c884e59d70addb717b52aa961369caa18fc5436115dbdc1f58614086 3 | size 487479 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | transformers 3 | diffusers 4 | onnxruntime 5 | modelscope[framework] 6 | Pillow 7 | opencv-python 8 | torchvision 9 | mmdet==2.26.0 10 | mmengine 11 | tensorflow==2.7.0 12 | #tensorflow-cpu # slower but no cuda conflicts 13 | numpy==1.22.0 14 | protobuf==3.20.1 15 | timm 16 | scikit-image 17 | gradio 18 | 19 | # mmcv-full (need mim install) 20 | -------------------------------------------------------------------------------- /train_lora.sh: 
-------------------------------------------------------------------------------- 1 | export MODEL_NAME=$1 2 | export VERSION=$2 3 | export SUB_PATH=$3 4 | export DATASET_NAME=$4 5 | export OUTPUT_DATASET_NAME=$5 6 | export WORK_DIR=$6 7 | 8 | accelerate launch facechain/train_text_to_image_lora.py \ 9 | --pretrained_model_name_or_path=$MODEL_NAME \ 10 | --revision=$VERSION \ 11 | --sub_path=$SUB_PATH \ 12 | --dataset_name=$DATASET_NAME \ 13 | --output_dataset_name=$OUTPUT_DATASET_NAME \ 14 | --caption_column="text" \ 15 | --resolution=512 --random_flip \ 16 | --train_batch_size=1 \ 17 | --num_train_epochs=200 --checkpointing_steps=5000 \ 18 | --learning_rate=1e-04 --lr_scheduler="cosine" --lr_warmup_steps=0 \ 19 | --seed=42 \ 20 | --output_dir=$WORK_DIR \ 21 | --lora_r=32 --lora_alpha=32 \ 22 | --lora_text_encoder_r=32 --lora_text_encoder_alpha=32 23 | 24 | 25 | -------------------------------------------------------------------------------- /train_lora.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set MODEL_NAME=%1 3 | set VERSION=%2 4 | set SUB_PATH=%3 5 | set DATASET_NAME=%4 6 | set OUTPUT_DATASET_NAME=%5 7 | set WORK_DIR=%6 8 | 9 | accelerate launch facechain/train_text_to_image_lora.py ^ 10 | --pretrained_model_name_or_path=%MODEL_NAME% ^ 11 | --revision=%VERSION% ^ 12 | --sub_path=%SUB_PATH% ^ 13 | --dataset_name=%DATASET_NAME% ^ 14 | --output_dataset_name=%OUTPUT_DATASET_NAME% ^ 15 | --caption_column="text" ^ 16 | --resolution=512 --random_flip ^ 17 | --train_batch_size=1 ^ 18 | --num_train_epochs=200 --checkpointing_steps=5000 ^ 19 | --learning_rate=1e-04 --lr_scheduler="cosine" --lr_warmup_steps=0 ^ 20 | --seed=42 ^ 21 | --output_dir=%WORK_DIR% ^ 22 | --lora_r=32 --lora_alpha=32 ^ 23 | --lora_text_encoder_r=32 --lora_text_encoder_alpha=32 24 | -------------------------------------------------------------------------------- /run_inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | import os 3 | 4 | from facechain.inference import GenPortrait 5 | import cv2 6 | 7 | use_main_model = True 8 | use_face_swap = True 9 | use_post_process = True 10 | use_stylization = False 11 | 12 | gen_portrait = GenPortrait(use_main_model, use_face_swap, use_post_process, 13 | use_stylization) 14 | 15 | processed_dir = './processed' 16 | num_generate = 5 17 | base_model = 'ly261666/cv_portrait_model' 18 | revision = 'v2.0' 19 | base_model_sub_dir = 'film/film' 20 | train_output_dir = './output' 21 | output_dir = './generated' 22 | 23 | outputs = gen_portrait(processed_dir, num_generate, base_model, 24 | train_output_dir, base_model_sub_dir, revision) 25 | 26 | os.makedirs(output_dir, exist_ok=True) 27 | 28 | for i, out_tmp in enumerate(outputs): 29 | cv2.imwrite(os.path.join(output_dir, f'{i}.png'), out_tmp) 30 | -------------------------------------------------------------------------------- /facechain/merge_lora.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
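# (Added note, not part of the original file.) Usage sketch: `merge_lora` folds LoRA
# weights into an already-loaded diffusers StableDiffusionPipeline in place.
# `facechain/inference.py` calls it once for the style LoRA (a .safetensors file,
# from_safetensor=True) and once for the face LoRA (a directory containing
# pytorch_lora_weights.bin, from_safetensor=False), for example:
#
#     pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
#     pipe = merge_lora(pipe, style_model_path, 0.25, from_safetensor=True)
#     pipe = merge_lora(pipe, lora_model_path, 1.0, from_safetensor=False)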
2 | import torch 3 | import os 4 | import re 5 | from collections import defaultdict 6 | from safetensors.torch import load_file 7 | 8 | 9 | def merge_lora(pipeline, lora_path, multiplier, from_safetensor=False, device='cpu', dtype=torch.float32): 10 | LORA_PREFIX_UNET = "lora_unet" 11 | LORA_PREFIX_TEXT_ENCODER = "lora_te" 12 | if from_safetensor: 13 | state_dict = load_file(lora_path, device=device) 14 | else: 15 | checkpoint = torch.load(os.path.join(lora_path, 'pytorch_lora_weights.bin'), map_location=torch.device(device)) 16 | new_dict = dict() 17 | for idx, key in enumerate(checkpoint): 18 | new_key = re.sub(r'\.processor\.', '_', key) 19 | new_key = re.sub(r'mid_block\.', 'mid_block_', new_key) 20 | new_key = re.sub('_lora.up.', '.lora_up.', new_key) 21 | new_key = re.sub('_lora.down.', '.lora_down.', new_key) 22 | new_key = re.sub(r'\.(\d+)\.', '_\\1_', new_key) 23 | new_key = re.sub('to_out', 'to_out_0', new_key) 24 | new_key = 'lora_unet_' + new_key 25 | new_dict[new_key] = checkpoint[key] 26 | state_dict = new_dict 27 | updates = defaultdict(dict) 28 | for key, value in state_dict.items(): 29 | layer, elem = key.split('.', 1) 30 | updates[layer][elem] = value 31 | 32 | for layer, elems in updates.items(): 33 | 34 | if "text" in layer: 35 | layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") 36 | curr_layer = pipeline.text_encoder 37 | else: 38 | layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_") 39 | curr_layer = pipeline.unet 40 | 41 | temp_name = layer_infos.pop(0) 42 | while len(layer_infos) > -1: 43 | try: 44 | curr_layer = curr_layer.__getattr__(temp_name) 45 | if len(layer_infos) > 0: 46 | temp_name = layer_infos.pop(0) 47 | elif len(layer_infos) == 0: 48 | break 49 | except Exception: 50 | if len(layer_infos) == 0: 51 | print('Error loading layer') 52 | if len(temp_name) > 0: 53 | temp_name += "_" + layer_infos.pop(0) 54 | else: 55 | temp_name = layer_infos.pop(0) 56 | 57 | weight_up = elems['lora_up.weight'].to(dtype) 58 | weight_down = elems['lora_down.weight'].to(dtype) 59 | if 'alpha' in elems.keys(): 60 | alpha = elems['alpha'].item() / weight_up.shape[1] 61 | else: 62 | alpha = 1.0 63 | 64 | curr_layer.weight.data = curr_layer.weight.data.to(device) 65 | if len(weight_up.shape) == 4: 66 | curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), 67 | weight_down.squeeze(3).squeeze(2)).unsqueeze( 68 | 2).unsqueeze(3) 69 | else: 70 | curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down) 71 | 72 | return pipeline 73 | -------------------------------------------------------------------------------- /README_ZH.md: -------------------------------------------------------------------------------- 1 |

FaceChain 2 | 3 | 4 | 5 | 6 |
7 | 8 | 9 | 10 | # 介绍 11 | 12 | FaceChain是一个可以用来打造个人数字形象的深度学习模型工具。用户仅需要提供最低三张照片即可获得独属于自己的个人形象数字替身。FaceChain支持在gradio的界面中使用模型训练和推理能力,也支持资深开发者使用python脚本进行训练推理。同时,FaceChain欢迎开发者对本Repo进行继续开发和贡献。 13 | 14 | 您也可以在[ModelScope创空间](https://modelscope.cn/studios/CVstudio/cv_human_portrait/summary)中直接体验这项技术而无需安装任何软件。 15 | 16 | FaceChain的模型由[ModelScope](https://github.com/modelscope/modelscope)开源模型社区提供支持。 17 | 18 | ![image](resources/example1.jpg) 19 | 20 | ![image](resources/example2.jpg) 21 | 22 | ![image](resources/example3.jpg) 23 | 24 | # 安装 25 | 26 | 您也可以使用pip和conda搭建本地python环境,我们推荐使用[Anaconda](https://docs.anaconda.com/anaconda/install/)来管理您的依赖,安装完成后,执行如下命令: 27 | 28 | ```shell 29 | conda create -n facechain python=3.8 # python version >= 3.8 30 | conda activate facechain 31 | 32 | pip3 install -r requirements.txt 33 | pip3 install -U openmim 34 | mim install mmcv-full==1.7.0 35 | ``` 36 | 37 | 或者,您可以使用ModelScope提供的官方镜像,这样您只需要安装gradio即可使用: 38 | 39 | ```shell 40 | registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.7.1-py38-torch2.0.1-tf1.15.5-1.8.0 41 | ``` 42 | 43 | 我们也推荐使用我们的[notebook](https://www.modelscope.cn/my/mynotebook/preset)来进行训练和推理。 44 | 45 | 将本仓库克隆到本地: 46 | 47 | ```shell 48 | GIT_LFS_SKIP_SMUDGE=1 git clone https://github.com/modelscope/facechain.git 49 | cd facechain 50 | ``` 51 | 52 | 安装依赖: 53 | 54 | ```shell 55 | # 如果使用了官方镜像,只需要执行 56 | pip3 install gradio 57 | 58 | # 如果使用conda虚拟环境,则参考上述”安装“章节 59 | ``` 60 | 61 | 62 | 运行gradio来生成个人数字形象: 63 | 64 | ```shell 65 | python app.py 66 | ``` 67 | 68 | 您可以看到log中的gradio启动日志,等待展示出http链接后,将http链接复制到浏览器中进行访问。之后在页面中点击“选择图片上传”,并选择最少一张包含人脸的图片。点击“开始训练”即可训练模型。训练完成后日志中会有对应展示,之后切换到“形象体验”标签页点击“开始推理”即可生成属于自己的数字形象。 69 | 70 | # 脚本运行 71 | 72 | FaceChain支持在python环境中直接进行训练和推理。在克隆后的文件夹中直接运行如下命令来进行训练: 73 | 74 | ```shell 75 | PYTHONPATH=. 
sh train_lora.sh "ly261666/cv_portrait_model" "v2.0" "film/film" "./imgs" "./processed" "./output" 76 | ``` 77 | 78 | 参数含义: 79 | 80 | ```text 81 | ly261666/cv_portrait_model: ModelScope模型仓库的stable diffusion基模型,该模型会用于训练,可以不修改 82 | v2.0: 该基模型的版本号,可以不修改 83 | film/film: 该基模型包含了多个不同风格的子目录,其中使用了film/film目录中的风格模型,可以不修改 84 | ./imgs: 本参数需要用实际值替换,本参数是一个本地文件目录,包含了用来训练和生成的原始照片 85 | ./processed: 预处理之后的图片文件夹,这个参数需要在推理中被传入相同的值,可以不修改 86 | ./output: 训练生成保存模型weights的文件夹,可以不修改 87 | ``` 88 | 89 | 等待5-20分钟即可训练完成。用户也可以调节其他训练超参数,训练支持的超参数可以查看`train_lora.sh`的配置,或者`facechain/train_text_to_image_lora.py`中的完整超参数列表。 90 | 91 | 进行推理时,请编辑run_inference.py中的代码: 92 | 93 | ```python 94 | # 填入上述的预处理之后的图片文件夹,需要和训练时相同 95 | processed_dir = './processed' 96 | # 推理生成的图片数量 97 | num_generate = 5 98 | # 训练时使用的stable diffusion基模型,可以不修改 99 | base_model = 'ly261666/cv_portrait_model' 100 | # 该基模型的版本号,可以不修改 101 | revision = 'v2.0' 102 | # 该基模型包含了多个不同风格的子目录,其中使用了film/film目录中的风格模型,可以不修改 103 | base_model_sub_dir = 'film/film' 104 | # 训练生成保存模型weights的文件夹,需要保证和训练时相同 105 | train_output_dir = './output' 106 | # 指定一个保存生成的图片的文件夹,本参数可以根据需要修改 107 | output_dir = './generated' 108 | ``` 109 | 110 | 之后执行: 111 | 112 | ```python 113 | python run_inference.py 114 | ``` 115 | 116 | 即可在`output_dir`中找到生成的个人数字形象照片。 117 | 118 | # 算法介绍 119 | 120 | ## 基本原理 121 | 122 | 个人写真模型的能力来源于Stable Diffusion模型的文生图功能,输入一段文本或一系列提示词,输出对应的图像。我们考虑影响个人写真生成效果的主要因素:写真风格信息,以及用户人物信息。为此,我们分别使用线下训练的风格LoRA模型和线上训练的人脸LoRA模型以学习上述信息。LoRA是一种具有较少可训练参数的微调模型,在Stable Diffusion中,可以通过对少量输入图像进行文生图训练的方式将输入图像的信息注入到LoRA模型中。因此,个人写真模型的能力分为训练与推断两个阶段,训练阶段生成用于微调Stable Diffusion模型的图像与文本标签数据,得到人脸LoRA模型;推断阶段基于人脸LoRA模型和风格LoRA模型生成个人写真图像。 123 | 124 | ![image](resources/framework.jpg) 125 | 126 | ## 训练阶段 127 | 128 | 输入:用户上传的包含清晰人脸区域的图像 129 | 130 | 输出:人脸LoRA模型 131 | 132 | 描述:首先,我们分别使用基于朝向判断的图像旋转模型,以及基于人脸检测和关键点模型的人脸精细化旋转方法处理用户上传图像,得到包含正向人脸的图像;接下来,我们使用人体解析模型和人像美肤模型,以获得高质量的人脸训练图像;随后,我们使用人脸属性模型和文本标注模型,结合标签后处理方法,产生训练图像的精细化标签;最后,我们使用上述图像和标签数据微调Stable Diffusion模型得到人脸LoRA模型。 133 | 134 | ## 推断阶段 135 | 136 | 输入:训练阶段用户上传图像,预设的用于生成个人写真的输入提示词 137 | 138 | 输出:个人写真图像 139 | 140 | 描述:首先,我们将人脸LoRA模型和风格LoRA模型的权重融合到Stable Diffusion模型中;接下来,我们使用Stable Diffusion模型的文生图功能,基于预设的输入提示词初步生成个人写真图像;随后,我们使用人脸融合模型进一步改善上述写真图像的人脸细节,其中用于融合的模板人脸通过人脸质量评估模型在训练图像中挑选;最后,我们使用人脸识别模型计算生成的写真图像与模板人脸的相似度,以此对写真图像进行排序,并输出排名靠前的个人写真图像作为最终输出结果。 141 | 142 | ## 模型列表 143 | 144 | 附(流程图中模型链接) 145 | 146 | [1] 人脸检测+关键点模型DamoFD:https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damof 147 | 148 | [2] 图像旋转模型:创空间内置模型 149 | 150 | [3] 人体解析模型M2FP:https://modelscope.cn/models/damo/cv_resnet101_image-multiple-human-parsing 151 | 152 | [4] 人像美肤模型ABPN:https://modelscope.cn/models/damo/cv_unet_skin-retouching 153 | 154 | [5] 人脸属性模型FairFace:https://modelscope.cn/models/damo/cv_resnet34_face-attribute-recognition_fairface 155 | 156 | [6] 文本标注模型Deepbooru:https://github.com/KichangKim/DeepDanbooru 157 | 158 | [7] 模板脸筛选模型FQA:https://modelscope.cn/models/damo/cv_manual_face-quality-assessment_fqa 159 | 160 | [8] 人脸融合模型:https://modelscope.cn/models/damo/cv_unet-image-face-fusion_damo 161 | 162 | [9] 人脸识别模型RTS:https://modelscope.cn/models/damo/cv_ir_face-recognition-ood_rts 163 | 164 | # 更多信息 165 | 166 | - [ModelScope library](https://github.com/modelscope/modelscope/) 167 | 168 | ModelScope Library是一个托管于github上的模型生态仓库,隶属于达摩院魔搭项目。 169 | 170 | - [贡献模型到ModelScope](https://modelscope.cn/docs/ModelScope%E6%A8%A1%E5%9E%8B%E6%8E%A5%E5%85%A5%E6%B5%81%E7%A8%8B%E6%A6%82%E8%A7%88) 171 | 172 | # License 173 | 174 | This project is licensed under the [Apache License (Version 
2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). 175 | 176 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

FaceChain 2 | 3 | 4 | 5 | 6 |
7 | 8 | # Introduction 9 | 10 | 如果您熟悉中文,可以阅读[中文版本的README](./README_ZH.md)。 11 | 12 | FaceChain is a deep-learning toolchain for generating your Digital-Twin. With a minimum of one portrait photo, you can create a Digital-Twin of your own and generate personal photos in different settings (work photos as a starter!). You may train your Digital-Twin model and generate photos via FaceChain's Python scripts, or via the familiar Gradio interface. You can also experience FaceChain directly with our [ModelScope Studio](https://modelscope.cn/studios/CVstudio/cv_human_portrait/summary). 13 | 14 | FaceChain is powered by [ModelScope](https://github.com/modelscope/modelscope). 15 | 16 | ![image](resources/example1.jpg) 17 | 18 | ![image](resources/example2.jpg) 19 | 20 | ![image](resources/example3.jpg) 21 | 22 | # Installation 23 | 24 | You may use pip and conda to build a local Python environment. We recommend using [Anaconda](https://docs.anaconda.com/anaconda/install/) to manage your dependencies. After installation, execute the following commands: 25 | 26 | ```shell 27 | conda create -n facechain python=3.8 # python version >= 3.8 28 | conda activate facechain 29 | 30 | pip3 install -r requirements.txt 31 | pip3 install -U openmim 32 | mim install mmcv-full==1.7.0 33 | ``` 34 | 35 | Alternatively, you may use the official docker image provided by ModelScope: 36 | 37 | ```shell 38 | registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.7.1-py38-torch2.0.1-tf1.15.5-1.8.0 39 | ``` 40 | With this docker image, the only thing you need to install is Gradio. 41 | 42 | For online training and inference, you may leverage the ModelScope [notebook](https://www.modelscope.cn/my/mynotebook/) to start the process immediately. 43 | 44 | Clone the repo: 45 | 46 | ```shell 47 | GIT_LFS_SKIP_SMUDGE=1 git clone https://github.com/modelscope/facechain.git 48 | cd facechain 49 | ``` 50 | 51 | Install dependencies: 52 | 53 | ```shell 54 | # If you use the official docker image, you only need to execute 55 | pip3 install gradio 56 | 57 | # If you use the conda env, please refer to the Installation section above. 58 | ``` 59 | 60 | Launch Gradio to generate personal digital images: 61 | 62 | ```shell 63 | python app.py 64 | ``` 65 | 66 | Watch the Gradio startup log; once the hyperlink is displayed, copy it into your browser. Then click "Select Image Upload" on the page and select at least one picture containing a face. Click "Start Training" to train the model. When training is complete, a message appears in the log. Afterwards, switch to the "Image Experience" tab and click "Start Inference" to generate your own digital image. 67 | 68 | # Script Execution 69 | 70 | FaceChain also supports training and inference directly from Python scripts. Run the following command in the cloned folder to start training: 71 | 72 | ```shell 73 | PYTHONPATH=. sh train_lora.sh "ly261666/cv_portrait_model" "v2.0" "film/film" "./imgs" "./processed" "./output" 74 | ``` 75 | 76 | Parameter meaning: 77 | 78 | ```text 79 | ly261666/cv_portrait_model: The Stable Diffusion base model on the ModelScope model hub that is used for training; no need to change. 80 | v2.0: The revision of this base model; no need to change. 81 | film/film: This base model contains multiple subdirectories for different styles; we currently use film/film; no need to change. 82 | ./imgs: Replace this with the actual value: a local directory containing the original photos used for training and generation. 83 | ./processed: The folder for the preprocessed images; the same value must be passed at inference time; no need to change. 84 | ./output: The folder where the trained model weights are stored; no need to change. 85 | ``` 86 | 87 | Training completes in roughly 5-20 minutes. You can also adjust other training hyperparameters; the supported hyperparameters are listed in `train_lora.sh`, and the complete list is in `facechain/train_text_to_image_lora.py`. 88 | 89 | For inference, edit the variables in `run_inference.py`: 90 | 91 | ```python 92 | # The folder of the preprocessed images above; it must be the same as during training 93 | processed_dir = './processed' 94 | # The number of images to generate 95 | num_generate = 5 96 | # The Stable Diffusion base model used in training; no need to change 97 | base_model = 'ly261666/cv_portrait_model' 98 | # The revision of this base model; no need to change 99 | revision = 'v2.0' 100 | # This base model contains multiple subdirectories for different styles; we currently use film/film; no need to change 101 | base_model_sub_dir = 'film/film' 102 | # The folder where the trained model weights are stored; it must be the same as during training 103 | train_output_dir = './output' 104 | # A folder to save the generated images; modify as needed 105 | output_dir = './generated' 106 | ``` 107 | 108 | Then execute: 109 | 110 | ```shell 111 | python run_inference.py 112 | ``` 113 | 114 | You can find the generated personal digital images in `output_dir`. 115 | 116 | # Algorithm Introduction 117 | 118 | ## Principle 119 | 120 | The personal-portrait capability builds on the text-to-image function of the Stable Diffusion model: given a text prompt, it outputs a corresponding image. Two main factors affect the quality of generated portraits: the portrait style and the user's facial identity. We learn these separately with a style LoRA model trained offline and a face LoRA model trained online. LoRA is a fine-tuning method with a small number of trainable parameters; in Stable Diffusion, the information contained in a handful of input images can be injected into a LoRA model by running text-to-image training on them. The personal portrait pipeline therefore has a training stage and an inference stage: the training stage produces the image and text-label data used to fine-tune the Stable Diffusion model and yields the face LoRA model; the inference stage generates personal portrait images from the face LoRA model together with the style LoRA model. 121 | 122 | ![image](resources/framework_eng.jpg) 123 |
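At inference time the two LoRA models are fused into the base Stable Diffusion weights (see `facechain/merge_lora.py`). Below is a minimal sketch of that merge for a single weight matrix; the function and variable names are illustrative, not part of the codebase:

```python
import torch

def merge_lora_weight(weight: torch.Tensor,
                      lora_up: torch.Tensor,    # shape (out_features, r)
                      lora_down: torch.Tensor,  # shape (r, in_features)
                      alpha: float,
                      multiplier: float) -> torch.Tensor:
    """Fold a LoRA update into a base weight: W' = W + multiplier * (alpha / r) * (up @ down)."""
    r = lora_up.shape[1]              # LoRA rank
    scale = multiplier * (alpha / r)  # merge_lora.py divides the stored alpha by the rank in the same way
    return weight + scale * (lora_up @ lora_down)

# Toy example: a 320x320 linear layer with a rank-32 LoRA (the rank set in train_lora.sh).
W = torch.randn(320, 320)
up, down = torch.randn(320, 32), torch.randn(32, 320)
W_merged = merge_lora_weight(W, up, down, alpha=32.0, multiplier=1.0)
```

Because the update is added directly into the base weights, merging the face LoRA (multiplier 1.0) and the style LoRA (multiplier 0.25, see `facechain/inference.py`) adds no extra cost at generation time.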
124 | ## Training 125 | 126 | Input: User-uploaded images that contain clear face areas 127 | 128 | Output: Face LoRA model 129 | 130 | Description: First, we process the user-uploaded images with an image rotation model based on orientation prediction, followed by a face refinement rotation method based on face detection and keypoint models, to obtain images containing forward-facing faces. Next, we use a human parsing model and a portrait skin retouching model to obtain high-quality face training images.
Afterwards, we use a face attribute model and a text annotation model, combined with tag post-processing methods, to generate fine-grained labels for training images. Finally, we use the above images and label data to fine-tune the Stable Diffusion model to obtain the face LoRA model. 131 | 132 | ## Inference 133 | 134 | Input: User-uploaded images in the training phase, preset input prompt words for generating personal portraits 135 | 136 | Output: Personal portrait image 137 | 138 | Description: First, we fuse the weights of the face LoRA model and style LoRA model into the Stable Diffusion model. Next, we use the text generation image function of the Stable Diffusion model to preliminarily generate personal portrait images based on the preset input prompt words. Then we further improve the face details of the above portrait image using the face fusion model. The template face used for fusion is selected from the training images through the face quality evaluation model. Finally, we use the face recognition model to calculate the similarity between the generated portrait image and the template face, and use this to sort the portrait images, and output the personal portrait image that ranks first as the final output result. 139 | 140 | ## Model List 141 | 142 | The models used in FaceChain: 143 | 144 | [1] Face detection model DamoFD:https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damof 145 | 146 | [2] Image rotating model, offered in the ModelScope studio 147 | 148 | [3] Human parsing model M2FP:https://modelscope.cn/models/damo/cv_resnet101_image-multiple-human-parsing 149 | 150 | [4] Skin retouching model ABPN:https://modelscope.cn/models/damo/cv_unet_skin-retouching 151 | 152 | [5] Face attribute recognition model FairFace:https://modelscope.cn/models/damo/cv_resnet34_face-attribute-recognition_fairface 153 | 154 | [6] DeepDanbooru model:https://github.com/KichangKim/DeepDanbooru 155 | 156 | [7] Face quality assessment FQA:https://modelscope.cn/models/damo/cv_manual_face-quality-assessment_fqa 157 | 158 | [8] Face fusion model:https://modelscope.cn/models/damo/cv_unet-image-face-fusion_damo 159 | 160 | [9] Face recognition model RTS:https://modelscope.cn/models/damo/cv_ir_face-recognition-ood_rts 161 | 162 | # More Information 163 | 164 | - [ModelScope library](https://github.com/modelscope/modelscope/) 165 | 166 | 167 | ​ ModelScope Library provides the foundation for building the model-ecosystem of ModelScope, including the interface and implementation to integrate various models into ModelScope. 168 | 169 | - [Contribute models to ModelScope](https://modelscope.cn/docs/ModelScope%E6%A8%A1%E5%9E%8B%E6%8E%A5%E5%85%A5%E6%B5%81%E7%A8%8B%E6%A6%82%E8%A7%88) 170 | 171 | # License 172 | 173 | This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). 174 | -------------------------------------------------------------------------------- /facechain/inference.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
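# (Added note, not part of the original file.) Pipeline overview: GenPortrait.__call__
# downloads the base model with snapshot_download, then runs:
#   1. main_model_inference      - merge the style LoRA and face LoRA into Stable Diffusion
#                                  and sample candidate portraits (10 by default)
#   2. select_high_quality_face  - pick the best face from the labeled training images
#                                  using the face quality assessment model
#   3. face_swap_fn              - fuse that template face into each generated image
#   4. post_process_fn           - rank candidates by face-recognition similarity to the
#                                  template face and keep the top num_gen_images
#   5. stylization_fn            - optional stylization (currently a pass-through TODO)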
2 | import json 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from diffusers import StableDiffusionPipeline 10 | from modelscope.outputs import OutputKeys 11 | from modelscope.pipelines import pipeline 12 | from modelscope.utils.constant import Tasks 13 | from modelscope import snapshot_download 14 | 15 | from facechain.merge_lora import merge_lora 16 | from facechain.data_process.preprocessing import Blipv2 17 | 18 | 19 | def data_process_fn(input_img_dir, use_data_process): 20 | ## TODO add face quality filter 21 | if use_data_process: 22 | ## TODO 23 | data_process_fn = Blipv2() 24 | out_json_name = data_process_fn(input_img_dir) 25 | return out_json_name 26 | else: 27 | return os.path.join(str(input_img_dir) + '_labeled', "metadata.jsonl") 28 | 29 | 30 | def txt2img(pipe, pos_prompt, neg_prompt, num_images=10): 31 | images_out = [] 32 | for i in range(int(num_images / 5)): 33 | images_style = pipe(prompt=pos_prompt, height=512, width=512, guidance_scale=7, negative_prompt=neg_prompt, 34 | num_inference_steps=40, num_images_per_prompt=5).images 35 | images_out.extend(images_style) 36 | return images_out 37 | 38 | 39 | def main_diffusion_inference(input_img_dir, base_model_path, style_model_path, lora_model_path, multiplier_style=0.25, 40 | multiplier_human=1.0): 41 | pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) 42 | neg_prompt = 'nsfw, paintings, sketches, (worst quality:2), (low quality:2) lowers, normal quality, ((monochrome)), ((grayscale)), logo, word, character' 43 | pos_prompt = 'raw photo, masterpiece, chinese, simple background, wearing high-class business/working suit, high-class pure color background, solo, medium shot, high detail face, looking straight into the camera with shoulders parallel to the frame, slim body, photorealistic, best quality' 44 | lora_style_path = style_model_path 45 | lora_human_path = lora_model_path 46 | pipe = merge_lora(pipe, lora_style_path, multiplier_style, from_safetensor=True) 47 | pipe = merge_lora(pipe, lora_human_path, multiplier_human, from_safetensor=False) 48 | train_dir = str(input_img_dir) + '_labeled' 49 | add_prompt_style = [] 50 | f = open(os.path.join(train_dir, 'metadata.jsonl'), 'r') 51 | tags_all = [] 52 | cnt = 0 53 | cnts_trigger = np.zeros(6) 54 | for line in f: 55 | cnt += 1 56 | data = json.loads(line)['text'].split(', ') 57 | tags_all.extend(data) 58 | if data[1] == 'a boy': 59 | cnts_trigger[0] += 1 60 | elif data[1] == 'a girl': 61 | cnts_trigger[1] += 1 62 | elif data[1] == 'a handsome man': 63 | cnts_trigger[2] += 1 64 | elif data[1] == 'a beautiful woman': 65 | cnts_trigger[3] += 1 66 | elif data[1] == 'a mature man': 67 | cnts_trigger[4] += 1 68 | elif data[1] == 'a mature woman': 69 | cnts_trigger[5] += 1 70 | else: 71 | print('Error.') 72 | f.close() 73 | 74 | attr_idx = np.argmax(cnts_trigger) 75 | trigger_styles = ['a boy, children, ', 'a girl, children, ', 'a handsome man, ', 'a beautiful woman, ', 76 | 'a mature man, ', 'a mature woman, '] 77 | trigger_style = ', ' + trigger_styles[attr_idx] 78 | if attr_idx == 2 or attr_idx == 4: 79 | neg_prompt += ', children' 80 | 81 | for tag in tags_all: 82 | if tags_all.count(tag) > 0.5 * cnt: 83 | if ('hair' in tag or 'face' in tag or 'mouth' in tag or 'skin' in tag or 'smile' in tag): 84 | if not tag in add_prompt_style: 85 | add_prompt_style.append(tag) 86 | 87 | if len(add_prompt_style) > 0: 88 | add_prompt_style = ", ".join(add_prompt_style) + ', ' 89 | else: 90 | 
add_prompt_style = '' 91 | # trigger_style = trigger_style + 'with face, ' 92 | # pos_prompt = 'Generate a standard ID photo of a chinese {}, solo, wearing high-class business/working suit, beautiful smooth face, with high-class/simple pure color background, looking straight into the camera with shoulders parallel to the frame, smile, high detail face, best quality, photorealistic'.format(gender) 93 | pipe = pipe.to("cuda") 94 | # print(trigger_style + add_prompt_style + pos_prompt) 95 | images_style = txt2img(pipe, trigger_style + add_prompt_style + pos_prompt, neg_prompt, num_images=10) 96 | return images_style 97 | 98 | 99 | def stylization_fn(use_stylization, rank_results): 100 | if use_stylization: 101 | ## TODO 102 | pass 103 | else: 104 | return rank_results 105 | 106 | 107 | def main_model_inference(use_main_model, input_img_dir=None, base_model_path=None, lora_model_path=None): 108 | if use_main_model: 109 | model_dir = snapshot_download('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', revision='v1.0.0') 110 | style_model_path = os.path.join(model_dir, 'zjz_mj_jiyi_small_addtxt_fromleo.safetensors') 111 | image = main_diffusion_inference(input_img_dir, base_model_path, style_model_path, lora_model_path) 112 | return image 113 | 114 | 115 | def select_high_quality_face(input_img_dir): 116 | input_img_dir = str(input_img_dir) + '_labeled' 117 | quality_score_list = [] 118 | abs_img_path_list = [] 119 | ## TODO 120 | face_quality_func = pipeline(Tasks.face_quality_assessment, 'damo/cv_manual_face-quality-assessment_fqa') 121 | 122 | for img_name in os.listdir(input_img_dir): 123 | if img_name.endswith('jsonl') or img_name.startswith('.ipynb'): 124 | continue 125 | abs_img_name = os.path.join(input_img_dir, img_name) 126 | face_quality_score = face_quality_func(abs_img_name)[OutputKeys.SCORES] 127 | if face_quality_score is None: 128 | quality_score_list.append(0) 129 | else: 130 | quality_score_list.append(face_quality_score[0]) 131 | abs_img_path_list.append(abs_img_name) 132 | 133 | sort_idx = np.argsort(quality_score_list)[::-1] 134 | print('Selected face: ' + abs_img_path_list[sort_idx[0]]) 135 | 136 | return Image.open(abs_img_path_list[sort_idx[0]]) 137 | 138 | 139 | def face_swap_fn(use_face_swap, gen_results, template_face): 140 | if use_face_swap: 141 | ## TODO 142 | out_img_list = [] 143 | image_face_fusion = pipeline(Tasks.image_face_fusion, 144 | model='damo/cv_unet-image-face-fusion_damo') 145 | for img in gen_results: 146 | result = image_face_fusion(dict(template=img, user=template_face))[OutputKeys.OUTPUT_IMG] 147 | out_img_list.append(result) 148 | 149 | return out_img_list 150 | else: 151 | ret_results = [] 152 | for img in gen_results: 153 | ret_results.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)) 154 | return ret_results 155 | 156 | 157 | def post_process_fn(use_post_process, swap_results_ori, selected_face, num_gen_images): 158 | if use_post_process: 159 | sim_list = [] 160 | ## TODO 161 | # face_recognition_func = pipeline(Tasks.face_recognition, 'damo/cv_vit_face-recognition') 162 | face_recognition_func = pipeline(Tasks.face_recognition, 'damo/cv_ir_face-recognition-ood_rts') 163 | face_det_func = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd') 164 | swap_results = [] 165 | for img in swap_results_ori: 166 | result_det = face_det_func(img) 167 | bboxes = result_det['boxes'] 168 | if len(bboxes) == 1: 169 | bbox = bboxes[0] 170 | lenface = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) 171 | if 160 < lenface < 360: 172 
| swap_results.append(img) 173 | 174 | select_face_emb = face_recognition_func(selected_face)[OutputKeys.IMG_EMBEDDING][0] 175 | 176 | for img in swap_results: 177 | emb = face_recognition_func(img)[OutputKeys.IMG_EMBEDDING] 178 | if emb is None or select_face_emb is None: 179 | sim_list.append(0) 180 | else: 181 | sim = np.dot(emb, select_face_emb) 182 | sim_list.append(sim.item()) 183 | sort_idx = np.argsort(sim_list)[::-1] 184 | 185 | return np.array(swap_results)[sort_idx[:min(int(num_gen_images), len(swap_results))]] 186 | else: 187 | return np.array(swap_results_ori) 188 | 189 | 190 | class GenPortrait: 191 | def __init__(self, use_main_model=True, use_face_swap=True, 192 | use_post_process=True, use_stylization=True): 193 | self.use_main_model = use_main_model 194 | self.use_face_swap = use_face_swap 195 | self.use_post_process = use_post_process 196 | self.use_stylization = use_stylization 197 | 198 | def __call__(self, input_img_dir, num_gen_images=6, base_model_path=None, 199 | lora_model_path=None, sub_path=None, revision=None): 200 | base_model_path = snapshot_download(base_model_path, revision=revision) 201 | if sub_path is not None and len(sub_path) > 0: 202 | base_model_path = os.path.join(base_model_path, sub_path) 203 | 204 | # main_model_inference PIL 205 | gen_results = main_model_inference(self.use_main_model, input_img_dir=input_img_dir, 206 | lora_model_path=lora_model_path, base_model_path=base_model_path) 207 | # select_high_quality_face PIL 208 | selected_face = select_high_quality_face(input_img_dir) 209 | # face_swap cv2 210 | swap_results = face_swap_fn(self.use_face_swap, gen_results, selected_face) 211 | # pose_process 212 | rank_results = post_process_fn(self.use_post_process, swap_results, selected_face, 213 | num_gen_images=num_gen_images) 214 | # stylization 215 | final_gen_results = stylization_fn(self.use_stylization, rank_results) 216 | 217 | return final_gen_results 218 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | import enum 3 | import os 4 | import shutil 5 | import sys 6 | import time 7 | from concurrent.futures import ThreadPoolExecutor 8 | 9 | import cv2 10 | import gradio as gr 11 | import numpy as np 12 | import torch 13 | 14 | from facechain.inference import GenPortrait 15 | 16 | sys.path.append('facechain') 17 | 18 | 19 | training_threadpool = ThreadPoolExecutor(max_workers=1) 20 | inference_threadpool = ThreadPoolExecutor(max_workers=5) 21 | 22 | training_done_count = 0 23 | inference_done_count = 0 24 | 25 | HOT_MODELS = [ 26 | "\N{fire}数字身份", 27 | ] 28 | 29 | 30 | class UploadTarget(enum.Enum): 31 | PERSONAL_PROFILE = 'Personal Profile' 32 | LORA_LIaBRARY = 'LoRA Library' 33 | 34 | 35 | def concatenate_images(images): 36 | heights = [img.shape[0] for img in images] 37 | max_width = sum([img.shape[1] for img in images]) 38 | 39 | concatenated_image = np.zeros((max(heights), max_width, 3), dtype=np.uint8) 40 | x_offset = 0 41 | for img in images: 42 | concatenated_image[0:img.shape[0], x_offset:x_offset + img.shape[1], :] = img 43 | x_offset += img.shape[1] 44 | return concatenated_image 45 | 46 | 47 | def train_lora_fn(foundation_model_path=None, revision=None, input_img_dir=None, output_img_dir=None, work_dir=None): 48 | sh_file_path = os.path.join('/'.join(os.path.abspath(__file__).split('/')[:-1]), 49 | 'train_lora.sh') 50 | os.system(f'PYTHONPATH=. 
sh {sh_file_path} {foundation_model_path} {revision} "film/film" {input_img_dir} {output_img_dir} {work_dir}') 51 | 52 | 53 | def launch_pipeline(uuid, 54 | user_models, 55 | num_images=1, 56 | ): 57 | base_model = 'ly261666/cv_portrait_model' 58 | before_queue_size = inference_threadpool._work_queue.qsize() 59 | before_done_count = inference_done_count 60 | 61 | print("-------user_models: ", user_models) 62 | if not uuid: 63 | if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio': 64 | return "请登陆后使用! " 65 | else: 66 | uuid = 'qw' 67 | 68 | use_main_model = True 69 | use_face_swap = True 70 | use_post_process = True 71 | use_stylization = False 72 | 73 | output_model_name = 'personalizaition_lora' 74 | instance_data_dir = os.path.join('/tmp', uuid, 'training_data', output_model_name) 75 | 76 | lora_model_path = f'/tmp/{uuid}/{output_model_name}' 77 | 78 | gen_portrait = GenPortrait(use_main_model, use_face_swap, use_post_process, 79 | use_stylization) 80 | 81 | num_images = min(6, num_images) 82 | future = inference_threadpool.submit(gen_portrait, instance_data_dir, 83 | num_images, base_model, lora_model_path, 'film/film', 'v2.0') 84 | 85 | while not future.done(): 86 | is_processing = future.running() 87 | if not is_processing: 88 | cur_done_count = inference_done_count 89 | to_wait = before_queue_size - (cur_done_count - before_done_count) 90 | # yield ["排队等待资源中,前方还有{}个生成任务, 预计需要等待{}分钟...".format(to_wait, round(to_wait*2.5, 5)), None] 91 | yield ["排队等待资源中,前方还有{}个生成任务, 预计需要等待{}分钟...".format(to_wait, to_wait * 2.5), None] 92 | else: 93 | yield ["生成中, 请耐心等待...", None] 94 | time.sleep(1) 95 | 96 | outputs = future.result() 97 | outputs_RGB = [] 98 | for out_tmp in outputs: 99 | outputs_RGB.append(cv2.cvtColor(out_tmp, cv2.COLOR_BGR2RGB)) 100 | image_path = './lora_result.png' 101 | result = concatenate_images(outputs) 102 | cv2.imwrite(image_path, result) 103 | # image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) 104 | 105 | yield ["生成完毕!", outputs_RGB] 106 | 107 | 108 | class Trainer: 109 | def __init__(self): 110 | pass 111 | 112 | def run( 113 | self, 114 | uuid: str, 115 | instance_images: list, 116 | ) -> str: 117 | 118 | if not torch.cuda.is_available(): 119 | raise gr.Error('CUDA is not available.') 120 | if instance_images is None: 121 | raise gr.Error('您需要上传训练图片!') 122 | # return "请传完图片后点击<开始训练>! " 123 | if len(instance_images) > 10: 124 | raise gr.Error('您需要上传小于10张训练图片!') 125 | if not uuid: 126 | if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio': 127 | return "请登陆后使用! 
" 128 | else: 129 | uuid = 'qw' 130 | 131 | output_model_name = 'personalizaition_lora' 132 | 133 | # mv user upload data to target dir 134 | instance_data_dir = os.path.join('/tmp', uuid, 'training_data', output_model_name) 135 | print("--------uuid: ", uuid) 136 | 137 | if not os.path.exists(f"/tmp/{uuid}"): 138 | os.makedirs(f"/tmp/{uuid}") 139 | work_dir = f"/tmp/{uuid}/{output_model_name}" 140 | print("----------work_dir: ", work_dir) 141 | 142 | source_img_dir = f"/tmp/{uuid}/sources" 143 | shutil.rmtree(source_img_dir, ignore_errors=True) 144 | os.makedirs(source_img_dir, exist_ok=True) 145 | for img in instance_images: 146 | shutil.copy(img['name'], os.path.join(source_img_dir, os.path.basename(img['name']))) 147 | 148 | # train lora 149 | train_lora_fn(foundation_model_path='ly261666/cv_portrait_model', 150 | revision='v2.0', input_img_dir=source_img_dir, 151 | output_img_dir=instance_data_dir, 152 | work_dir=work_dir) 153 | 154 | message = f'训练已经完成!请切换至 [形象体验] 标签体验模型效果' 155 | print(message) 156 | return message 157 | 158 | 159 | def flash_model_list(uuid): 160 | folder_path = f"/tmp/{uuid}" 161 | folder_list = [] 162 | print("------flash_model_list folder_path: ", folder_path) 163 | if not os.path.exists(folder_path): 164 | print('--------The folder_path is missing.') 165 | else: 166 | files = os.listdir(folder_path) 167 | for file in files: 168 | file_path = os.path.join(folder_path, file) 169 | if os.path.isdir(folder_path): 170 | file_lora_path = f"{file_path}/output/pytorch_lora_weights.bin" 171 | if os.path.exists(file_lora_path): 172 | folder_list.append(file) 173 | 174 | print("-------folder_list + HOT_MODELS: ", folder_list + HOT_MODELS) 175 | return gr.Radio.update(choices=HOT_MODELS + folder_list) 176 | 177 | 178 | def upload_file(files): 179 | file_paths = [file.name for file in files] 180 | return file_paths 181 | 182 | 183 | def train_input(): 184 | trainer = Trainer() 185 | 186 | with gr.Blocks() as demo: 187 | uuid = gr.Text(label="modelscope_uuid", visible=False) 188 | with gr.Row(): 189 | with gr.Column(): 190 | with gr.Box(): 191 | gr.Markdown('训练数据') 192 | # instance_images = gr.Files(label='Instance images', visible=False) 193 | instance_images = gr.Gallery() 194 | upload_button = gr.UploadButton("选择图片上传", file_types=["image"], file_count="multiple") 195 | upload_button.upload(upload_file, upload_button, instance_images) 196 | gr.Markdown(''' 197 | - Step 0. 登陆ModelScope账号,未登录无法使用定制功能 198 | - Step 1. 上传你计划训练的图片,3~10张头肩照(注意:图片中多人脸、脸部遮挡等情况会导致效果异常,需要重新上传符合规范图片训练) 199 | - Step 2. 点击 [形象定制] ,启动模型训练,等待约15分钟,请您耐心等待 200 | - Step 3. 切换至 [形象体验] ,生成你的风格照片 201 | - 注意:生成结果严禁用于非法用途! 202 | ''') 203 | 204 | run_button = gr.Button('开始训练(等待上传图片加载显示出来再点,否则会报错)') 205 | 206 | with gr.Box(): 207 | gr.Markdown( 208 | '输出信号(出现error时训练可能已完成或还在进行。可直接切到形象体验tab页面,如果体验时报错则训练还没好,再等待一般10来分钟。)') 209 | output_message = gr.Markdown() 210 | with gr.Box(): 211 | gr.Markdown(''' 212 | 碰到抓狂的错误或者计算资源紧张的情况下,推荐直接在[NoteBook](https://modelscope.cn/my/mynotebook/preset)上按照如下命令自行体验 213 | 1. git clone https://www.modelscope.cn/studios/CVstudio/cv_human_portrait.git 214 | 2. cd cv_human_portrait 215 | 3. pip install -r requirements.txt 216 | 4. pip install gradio==3.35.2 217 | 5. 
python app.py 218 | ''') 219 | 220 | run_button.click(fn=trainer.run, 221 | inputs=[ 222 | uuid, 223 | instance_images, 224 | ], 225 | outputs=[output_message]) 226 | 227 | return demo 228 | 229 | 230 | def inference_input(): 231 | with gr.Blocks() as demo: 232 | uuid = gr.Text(label="modelscope_uuid", visible=False) 233 | with gr.Row(): 234 | with gr.Column(): 235 | # user_models = gr.Radio(label="风格选择", choices=['\N{fire}商务证件'], type="value", value='\N{fire}商务证件') 236 | user_models = gr.Radio(label="模型选择", choices=HOT_MODELS, type="value", value=HOT_MODELS[0]) 237 | # flash_button = gr.Button('刷新模型列表') 238 | 239 | with gr.Box(): 240 | num_images = gr.Number( 241 | label='生成图片数量', value=6, precision=1) 242 | gr.Markdown(''' 243 | 注意:最多支持生成6张图片! 244 | ''') 245 | 246 | display_button = gr.Button('开始推理') 247 | 248 | with gr.Box(): 249 | infer_progress = gr.Textbox(label="生成进度", value="当前无生成任务", interactive=False) 250 | with gr.Box(): 251 | gr.Markdown('生成结果') 252 | # output_image = gr.Image() 253 | output_images = gr.Gallery(label='Output', show_label=False).style(columns=3, rows=2, height=600, 254 | object_fit="contain") 255 | 256 | display_button.click(fn=launch_pipeline, 257 | inputs=[uuid, user_models, num_images], 258 | outputs=[infer_progress, output_images]) 259 | 260 | return demo 261 | 262 | 263 | with gr.Blocks(css='style.css') as demo: 264 | with gr.Tabs(): 265 | with gr.TabItem('\N{rocket}形象定制'): 266 | train_input() 267 | with gr.TabItem('\N{party popper}形象体验'): 268 | inference_input() 269 | 270 | # demo.queue(max_size=100).launch(share=False) 271 | # demo.queue(concurrency_count=20).launch(share=False) 272 | demo.queue(status_update_rate=1).launch(share=True) 273 | -------------------------------------------------------------------------------- /app-win10.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | import enum 3 | import os 4 | import shutil 5 | import sys 6 | import time 7 | from concurrent.futures import ThreadPoolExecutor 8 | 9 | import cv2 10 | import gradio as gr 11 | import numpy as np 12 | import torch 13 | 14 | from facechain.inference import GenPortrait 15 | 16 | sys.path.append('facechain') 17 | os.environ["PYTHONPATH"] = os.path.dirname(os.path.abspath(__file__)) 18 | # print("PYTHONPATH:", os.environ["PYTHONPATH"]) 19 | 20 | training_threadpool = ThreadPoolExecutor(max_workers=1) 21 | inference_threadpool = ThreadPoolExecutor(max_workers=5) 22 | 23 | training_done_count = 0 24 | inference_done_count = 0 25 | 26 | HOT_MODELS = [ 27 | "\N{fire}数字身份", 28 | ] 29 | 30 | 31 | class UploadTarget(enum.Enum): 32 | PERSONAL_PROFILE = 'Personal Profile' 33 | LORA_LIaBRARY = 'LoRA Library' 34 | 35 | 36 | def concatenate_images(images): 37 | heights = [img.shape[0] for img in images] 38 | max_width = sum([img.shape[1] for img in images]) 39 | 40 | concatenated_image = np.zeros((max(heights), max_width, 3), dtype=np.uint8) 41 | x_offset = 0 42 | for img in images: 43 | concatenated_image[0:img.shape[0], x_offset:x_offset + img.shape[1], :] = img 44 | x_offset += img.shape[1] 45 | return concatenated_image 46 | 47 | 48 | # def train_lora_fn(foundation_model_path=None, revision=None, input_img_dir=None, output_img_dir=None, work_dir=None): 49 | # sh_file_path = os.path.join('/'.join(os.path.abspath(__file__).split('/')[:-1]), 50 | # 'train_lora.sh') 51 | # os.system(f'PYTHONPATH=. 
sh {sh_file_path} {foundation_model_path} {revision} "film/film" {input_img_dir} {output_img_dir} {work_dir}') 52 | 53 | 54 | def train_lora_fn(foundation_model_path=None, revision=None, input_img_dir=None, output_img_dir=None, work_dir=None): 55 | bat_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'train_lora.bat') 56 | 57 | os.system(f'{bat_file_path} {foundation_model_path} {revision} "film/film" {input_img_dir} {output_img_dir} {work_dir}') 58 | 59 | 60 | def launch_pipeline(uuid, 61 | user_models, 62 | num_images=1, 63 | ): 64 | base_model = 'ly261666/cv_portrait_model' 65 | before_queue_size = inference_threadpool._work_queue.qsize() 66 | before_done_count = inference_done_count 67 | 68 | print("-------user_models: ", user_models) 69 | if not uuid: 70 | if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio': 71 | return "请登陆后使用! " 72 | else: 73 | uuid = 'qw' 74 | 75 | use_main_model = True 76 | use_face_swap = True 77 | use_post_process = True 78 | use_stylization = False 79 | 80 | output_model_name = 'personalizaition_lora' 81 | instance_data_dir = os.path.join('D:\\AI\\', uuid, 'training_data', output_model_name) 82 | 83 | lora_model_path = f'D:\\AI\\{uuid}\{output_model_name}' 84 | 85 | gen_portrait = GenPortrait(use_main_model, use_face_swap, use_post_process, 86 | use_stylization) 87 | 88 | num_images = min(6, num_images) 89 | future = inference_threadpool.submit(gen_portrait, instance_data_dir, 90 | num_images, base_model, lora_model_path, 'film/film', 'v2.0') 91 | 92 | while not future.done(): 93 | is_processing = future.running() 94 | if not is_processing: 95 | cur_done_count = inference_done_count 96 | to_wait = before_queue_size - (cur_done_count - before_done_count) 97 | # yield ["排队等待资源中,前方还有{}个生成任务, 预计需要等待{}分钟...".format(to_wait, round(to_wait*2.5, 5)), None] 98 | yield ["排队等待资源中,前方还有{}个生成任务, 预计需要等待{}分钟...".format(to_wait, to_wait * 2.5), None] 99 | else: 100 | yield ["生成中, 请耐心等待...", None] 101 | time.sleep(1) 102 | 103 | outputs = future.result() 104 | outputs_RGB = [] 105 | for out_tmp in outputs: 106 | outputs_RGB.append(cv2.cvtColor(out_tmp, cv2.COLOR_BGR2RGB)) 107 | image_path = '.\lora_result.png' 108 | result = concatenate_images(outputs) 109 | cv2.imwrite(image_path, result) 110 | # image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB) 111 | 112 | yield ["生成完毕!", outputs_RGB] 113 | 114 | 115 | class Trainer: 116 | def __init__(self): 117 | pass 118 | 119 | def run( 120 | self, 121 | uuid: str, 122 | instance_images: list, 123 | ) -> str: 124 | 125 | if not torch.cuda.is_available(): 126 | raise gr.Error('CUDA is not available.') 127 | if instance_images is None: 128 | raise gr.Error('您需要上传训练图片!') 129 | # return "请传完图片后点击<开始训练>! " 130 | if len(instance_images) > 10: 131 | raise gr.Error('您需要上传小于10张训练图片!') 132 | if not uuid: 133 | if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio': 134 | return "请登陆后使用! 
" 135 | else: 136 | uuid = 'qw' 137 | 138 | output_model_name = 'personalizaition_lora' 139 | 140 | # mv user upload data to target dir 141 | instance_data_dir = os.path.join('D:\\AI\\', uuid, 'training_data', output_model_name) 142 | print("--------uuid: ", uuid) 143 | 144 | if not os.path.exists(f"D:\\AI\\{uuid}"): 145 | os.makedirs(f"D:\\AI\\{uuid}") 146 | work_dir = f"D:\\AI\\{uuid}\\{output_model_name}" 147 | print("----------work_dir: ", work_dir) 148 | 149 | source_img_dir = f"D:\\AI\\{uuid}\\sources" 150 | shutil.rmtree(source_img_dir, ignore_errors=True) 151 | os.makedirs(source_img_dir, exist_ok=True) 152 | for img in instance_images: 153 | shutil.copy(img['name'], os.path.join(source_img_dir, os.path.basename(img['name']))) 154 | 155 | # train lora 156 | train_lora_fn(foundation_model_path='ly261666/cv_portrait_model', 157 | revision='v2.0', input_img_dir=source_img_dir, 158 | output_img_dir=instance_data_dir, 159 | work_dir=work_dir) 160 | 161 | message = f'训练已经完成!请切换至 [形象体验] 标签体验模型效果' 162 | print(message) 163 | return message 164 | 165 | 166 | def flash_model_list(uuid): 167 | folder_path = f"D:\\AI\\{uuid}" 168 | folder_list = [] 169 | print("------flash_model_list folder_path: ", folder_path) 170 | if not os.path.exists(folder_path): 171 | print('--------The folder_path is missing.') 172 | else: 173 | files = os.listdir(folder_path) 174 | for file in files: 175 | file_path = os.path.join(folder_path, file) 176 | if os.path.isdir(folder_path): 177 | file_lora_path = f"{file_path}\\output\\pytorch_lora_weights.bin" 178 | if os.path.exists(file_lora_path): 179 | folder_list.append(file) 180 | 181 | print("-------folder_list + HOT_MODELS: ", folder_list + HOT_MODELS) 182 | return gr.Radio.update(choices=HOT_MODELS + folder_list) 183 | 184 | 185 | def upload_file(files): 186 | file_paths = [file.name for file in files] 187 | return file_paths 188 | 189 | 190 | def train_input(): 191 | trainer = Trainer() 192 | 193 | with gr.Blocks() as demo: 194 | uuid = gr.Text(label="modelscope_uuid", visible=False) 195 | with gr.Row(): 196 | with gr.Column(): 197 | with gr.Box(): 198 | gr.Markdown('训练数据') 199 | # instance_images = gr.Files(label='Instance images', visible=False) 200 | instance_images = gr.Gallery() 201 | upload_button = gr.UploadButton("选择图片上传", file_types=["image"], file_count="multiple") 202 | upload_button.upload(upload_file, upload_button, instance_images) 203 | gr.Markdown(''' 204 | - Step 0. 登陆ModelScope账号,未登录无法使用定制功能 205 | - Step 1. 上传你计划训练的图片,3~10张头肩照(注意:图片中多人脸、脸部遮挡等情况会导致效果异常,需要重新上传符合规范图片训练) 206 | - Step 2. 点击 [形象定制] ,启动模型训练,等待约15分钟,请您耐心等待 207 | - Step 3. 切换至 [形象体验] ,生成你的风格照片 208 | - 注意:生成结果严禁用于非法用途! 209 | ''') 210 | 211 | run_button = gr.Button('开始训练(等待上传图片加载显示出来再点,否则会报错)') 212 | 213 | with gr.Box(): 214 | gr.Markdown( 215 | '输出信号(出现error时训练可能已完成或还在进行。可直接切到形象体验tab页面,如果体验时报错则训练还没好,再等待一般10来分钟。)') 216 | output_message = gr.Markdown() 217 | with gr.Box(): 218 | gr.Markdown(''' 219 | 碰到抓狂的错误或者计算资源紧张的情况下,推荐直接在[NoteBook](https://modelscope.cn/my/mynotebook/preset)上按照如下命令自行体验 220 | 1. git clone https://www.modelscope.cn/studios/CVstudio/cv_human_portrait.git 221 | 2. cd cv_human_portrait 222 | 3. pip install -r requirements.txt 223 | 4. pip install gradio==3.35.2 224 | 5. 
python app.py 225 | ''') 226 | 227 | run_button.click(fn=trainer.run, 228 | inputs=[ 229 | uuid, 230 | instance_images, 231 | ], 232 | outputs=[output_message]) 233 | 234 | return demo 235 | 236 | 237 | def inference_input(): 238 | with gr.Blocks() as demo: 239 | uuid = gr.Text(label="modelscope_uuid", visible=False) 240 | with gr.Row(): 241 | with gr.Column(): 242 | # user_models = gr.Radio(label="风格选择", choices=['\N{fire}商务证件'], type="value", value='\N{fire}商务证件') 243 | user_models = gr.Radio(label="模型选择", choices=HOT_MODELS, type="value", value=HOT_MODELS[0]) 244 | # flash_button = gr.Button('刷新模型列表') 245 | 246 | with gr.Box(): 247 | num_images = gr.Number( 248 | label='生成图片数量', value=6, precision=1) 249 | gr.Markdown(''' 250 | 注意:最多支持生成6张图片! 251 | ''') 252 | 253 | display_button = gr.Button('开始推理') 254 | 255 | with gr.Box(): 256 | infer_progress = gr.Textbox(label="生成进度", value="当前无生成任务", interactive=False) 257 | with gr.Box(): 258 | gr.Markdown('生成结果') 259 | # output_image = gr.Image() 260 | output_images = gr.Gallery(label='Output', show_label=False).style(columns=3, rows=2, height=600, 261 | object_fit="contain") 262 | 263 | display_button.click(fn=launch_pipeline, 264 | inputs=[uuid, user_models, num_images], 265 | outputs=[infer_progress, output_images]) 266 | 267 | return demo 268 | 269 | 270 | with gr.Blocks(css='style.css') as demo: 271 | with gr.Tabs(): 272 | with gr.TabItem('\N{rocket}形象定制'): 273 | train_input() 274 | with gr.TabItem('\N{party popper}形象体验'): 275 | inference_input() 276 | 277 | # demo.queue(max_size=100).launch(share=False) 278 | # demo.queue(concurrency_count=20).launch(share=False) 279 | demo.queue(status_update_rate=1).launch(share=True) 280 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /facechain/data_process/preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 
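# This module turns a folder of user photos into a 512x512, captioned training
# set for LoRA fine-tuning. For each image, Blipv2.__call__ detects the largest
# face, rotates the photo to a canonical 5-point landmark alignment, crops and
# resizes it, applies skin retouching, keeps only the head region using
# human-parsing masks, drops images with low landmark confidence, tags the
# result with DeepDanbooru, and pools gender/age attributes into a trigger
# phrase. Outputs are written to '<imdir>_labeled/' together with a
# 'metadata.jsonl' caption file.
#
# Illustrative usage (a minimal sketch; the ModelScope pipelines created in
# Blipv2.__init__ download their models on first use, and './training_images'
# is a placeholder path):
#
#   from facechain.data_process.preprocessing import Blipv2
#   metadata_path = Blipv2()('./training_images')
#   # -> './training_images_labeled/metadata.jsonl'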
2 | 3 | import os 4 | import shutil 5 | 6 | import numpy as np 7 | import json 8 | import math 9 | 10 | from .deepbooru import DeepDanbooru 11 | 12 | import cv2 13 | from PIL import Image 14 | from modelscope.pipelines import pipeline 15 | from modelscope.outputs import OutputKeys 16 | from modelscope.utils.constant import Tasks 17 | 18 | 19 | def crop_and_resize(im, bbox, thres=0.35, thres1=0.45): 20 | h, w, _ = im.shape 21 | thre = np.random.rand() * (thres1 - thres) + thres 22 | maxf = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) 23 | cx = (bbox[2] + bbox[0]) / 2 24 | cy = (bbox[3] + bbox[1]) / 2 25 | lenp = int(maxf / thre) 26 | yc = np.random.rand() * 0.15 + 0.35 27 | xc = 0.5 28 | xmin = int(cx - xc * lenp) 29 | xmax = xmin + lenp 30 | ymin = int(cy - yc * lenp) 31 | ymax = ymin + lenp 32 | x1 = 0 33 | x2 = lenp 34 | y1 = 0 35 | y2 = lenp 36 | if xmin < 0: 37 | x1 = -xmin 38 | xmin = 0 39 | if xmax > w: 40 | x2 = w - (xmax - lenp) 41 | xmax = w 42 | if ymin < 0: 43 | y1 = -ymin 44 | ymin = 0 45 | if ymax > h: 46 | y2 = h - (ymax - lenp) 47 | ymax = h 48 | imc = (np.ones((lenp, lenp, 3)) * 255).astype(np.uint8) 49 | imc[y1:y2, x1:x2, :] = im[ymin:ymax, xmin:xmax, :] 50 | imr = cv2.resize(imc, (512, 512)) 51 | return imr 52 | 53 | 54 | def pad_to_square(im): 55 | h, w, _ = im.shape 56 | ns = int(max(h, w) * 1.5) 57 | im = cv2.copyMakeBorder(im, int((ns - h) / 2), (ns - h) - int((ns - h) / 2), int((ns - w) / 2), 58 | (ns - w) - int((ns - w) / 2), cv2.BORDER_CONSTANT, 255) 59 | return im 60 | 61 | 62 | def post_process_naive(result_list, score_gender, score_age): 63 | # determine trigger word 64 | gender = np.argmax(score_gender) 65 | age = np.argmax(score_age) 66 | if age < 2: 67 | if gender == 0: 68 | tag_a_g = ['a boy', 'children'] 69 | else: 70 | tag_a_g = ['a girl', 'children'] 71 | elif age > 4: 72 | if gender == 0: 73 | tag_a_g = ['a mature man'] 74 | else: 75 | tag_a_g = ['a mature woman'] 76 | else: 77 | if gender == 0: 78 | tag_a_g = ['a handsome man'] 79 | else: 80 | tag_a_g = ['a beautiful woman'] 81 | num_images = len(result_list) 82 | cnt_girl = 0 83 | cnt_boy = 0 84 | result_list_new = [] 85 | for result in result_list: 86 | result_new = [] 87 | result_new.extend(tag_a_g) 88 | for tag in result: 89 | if tag == '1girl' or tag == '1boy': 90 | continue 91 | if tag[-4:] == '_man': 92 | continue 93 | if tag[-6:] == '_woman': 94 | continue 95 | if tag[-5:] == '_male': 96 | continue 97 | elif tag[-7:] == '_female': 98 | continue 99 | elif ( 100 | tag == 'ears' or tag == 'head' or tag == 'face' or tag == 'lips' or tag == 'mouth' or tag == '3d' or tag == 'asian' or tag == 'teeth'): 101 | continue 102 | elif ('eye' in tag and not 'eyewear' in tag): 103 | continue 104 | elif ('nose' in tag or 'body' in tag): 105 | continue 106 | elif tag[-5:] == '_lips': 107 | continue 108 | else: 109 | result_new.append(tag) 110 | # import pdb;pdb.set_trace() 111 | # result_new.append('slim body') 112 | result_list_new.append(result_new) 113 | 114 | return result_list_new 115 | 116 | 117 | def transformation_from_points(points1, points2): 118 | points1 = points1.astype(np.float64) 119 | points2 = points2.astype(np.float64) 120 | c1 = np.mean(points1, axis=0) 121 | c2 = np.mean(points2, axis=0) 122 | points1 -= c1 123 | points2 -= c2 124 | s1 = np.std(points1) 125 | s2 = np.std(points2) 126 | if s1 < 1.0e-4: 127 | s1 = 1.0e-4 128 | points1 /= s1 129 | points2 /= s2 130 | U, S, Vt = np.linalg.svd(points1.T * points2) 131 | R = (U * Vt).T 132 | return np.vstack([np.hstack(((s2 / s1) * R, c2.T - (s2 / s1) 
* R * c1.T)), np.matrix([0., 0., 1.])]) 133 | 134 | 135 | def rotate(im, keypoints): 136 | h, w, _ = im.shape 137 | points_array = np.zeros((5, 2)) 138 | dst_mean_face_size = 160 139 | dst_mean_face = np.asarray([0.31074522411511746, 0.2798131190011913, 140 | 0.6892073313037804, 0.2797830232679366, 141 | 0.49997367716346774, 0.5099309118810921, 142 | 0.35811903020866753, 0.7233174007629063, 143 | 0.6418878095835022, 0.7232890570786875]) 144 | dst_mean_face = np.reshape(dst_mean_face, (5, 2)) * dst_mean_face_size 145 | 146 | for k in range(5): 147 | points_array[k, 0] = keypoints[2 * k] 148 | points_array[k, 1] = keypoints[2 * k + 1] 149 | 150 | pts1 = np.float64(np.matrix([[point[0], point[1]] for point in points_array])) 151 | pts2 = np.float64(np.matrix([[point[0], point[1]] for point in dst_mean_face])) 152 | trans_mat = transformation_from_points(pts1, pts2) 153 | if trans_mat[1, 1] > 1.0e-4: 154 | angle = math.atan(trans_mat[1, 0] / trans_mat[1, 1]) 155 | else: 156 | angle = math.atan(trans_mat[0, 1] / trans_mat[0, 2]) 157 | im = pad_to_square(im) 158 | ns = int(1.5 * max(h, w)) 159 | M = cv2.getRotationMatrix2D((ns / 2, ns / 2), angle=-angle / np.pi * 180, scale=1.0) 160 | im = cv2.warpAffine(im, M=M, dsize=(ns, ns)) 161 | return im 162 | 163 | 164 | def get_mask_head(result): 165 | masks = result['masks'] 166 | scores = result['scores'] 167 | labels = result['labels'] 168 | mask_hair = np.zeros((512, 512)) 169 | mask_face = np.zeros((512, 512)) 170 | mask_human = np.zeros((512, 512)) 171 | for i in range(len(labels)): 172 | if scores[i] > 0.8: 173 | if labels[i] == 'Face': 174 | if np.sum(masks[i]) > np.sum(mask_face): 175 | mask_face = masks[i] 176 | elif labels[i] == 'Human': 177 | if np.sum(masks[i]) > np.sum(mask_human): 178 | mask_human = masks[i] 179 | elif labels[i] == 'Hair': 180 | if np.sum(masks[i]) > np.sum(mask_hair): 181 | mask_hair = masks[i] 182 | mask_head = np.clip(mask_hair + mask_face, 0, 1) 183 | ksize = max(int(np.sqrt(np.sum(mask_face)) / 20), 1) 184 | kernel = np.ones((ksize, ksize)) 185 | mask_head = cv2.dilate(mask_head, kernel, iterations=1) * mask_human 186 | _, mask_head = cv2.threshold((mask_head * 255).astype(np.uint8), 127, 255, cv2.THRESH_BINARY) 187 | contours, hierarchy = cv2.findContours(mask_head, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 188 | area = [] 189 | for j in range(len(contours)): 190 | area.append(cv2.contourArea(contours[j])) 191 | max_idx = np.argmax(area) 192 | mask_head = np.zeros((512, 512)).astype(np.uint8) 193 | cv2.fillPoly(mask_head, [contours[max_idx]], 255) 194 | mask_head = mask_head.astype(np.float32) / 255 195 | mask_head = np.clip(mask_head + mask_face, 0, 1) 196 | mask_head = np.expand_dims(mask_head, 2) 197 | return mask_head 198 | 199 | 200 | class Blipv2(): 201 | def __init__(self): 202 | self.model = DeepDanbooru() 203 | self.skin_retouching = pipeline(Tasks.skin_retouching, model='damo/cv_unet_skin-retouching') 204 | self.face_detection = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd') 205 | # self.mog_face_detection_func = pipeline(Tasks.face_detection, 'damo/cv_resnet101_face-detection_cvpr22papermogface') 206 | self.segmentation_pipeline = pipeline(Tasks.image_segmentation, 207 | 'damo/cv_resnet101_image-multiple-human-parsing') 208 | self.fair_face_attribute_func = pipeline(Tasks.face_attribute_recognition, 209 | 'damo/cv_resnet34_face-attribute-recognition_fairface') 210 | self.facial_landmark_confidence_func = pipeline(Tasks.face_2d_keypoints, 211 | 
'damo/cv_manual_facial-landmark-confidence_flcm') 212 | 213 | def __call__(self, imdir): 214 | self.model.start() 215 | savedir = str(imdir) + '_labeled' 216 | shutil.rmtree(savedir, ignore_errors=True) 217 | os.makedirs(savedir, exist_ok=True) 218 | 219 | imlist = os.listdir(imdir) 220 | result_list = [] 221 | imgs_list = [] 222 | 223 | cnt = 0 224 | tmp_path = os.path.join(savedir, 'tmp.png') 225 | for imname in imlist: 226 | try: 227 | # if 1: 228 | if imname.startswith('.'): 229 | continue 230 | img_path = os.path.join(imdir, imname) 231 | im = cv2.imread(img_path) 232 | h, w, _ = im.shape 233 | max_size = max(w, h) 234 | ratio = 1024 / max_size 235 | new_w = round(w * ratio) 236 | new_h = round(h * ratio) 237 | imt = cv2.resize(im, (new_w, new_h)) 238 | cv2.imwrite(tmp_path, imt) 239 | result_det = self.face_detection(tmp_path) 240 | bboxes = result_det['boxes'] 241 | if len(bboxes) > 1: 242 | areas = [] 243 | for i in range(len(bboxes)): 244 | bbox = bboxes[i] 245 | areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1])) 246 | areas = np.array(areas) 247 | areas_new = np.sort(areas)[::-1] 248 | idxs = np.argsort(areas)[::-1] 249 | if areas_new[0] < 4 * areas_new[1]: 250 | print('Detecting multiple faces, do not use image {}.'.format(imname)) 251 | continue 252 | else: 253 | keypoints = result_det['keypoints'][idxs[0]] 254 | elif len(bboxes) == 0: 255 | print('Detecting no face, do not use image {}.'.format(imname)) 256 | continue 257 | else: 258 | keypoints = result_det['keypoints'][0] 259 | 260 | im = rotate(im, keypoints) 261 | ns = im.shape[0] 262 | imt = cv2.resize(im, (1024, 1024)) 263 | cv2.imwrite(tmp_path, imt) 264 | result_det = self.face_detection(tmp_path) 265 | bboxes = result_det['boxes'] 266 | 267 | if len(bboxes) > 1: 268 | areas = [] 269 | for i in range(len(bboxes)): 270 | bbox = bboxes[i] 271 | areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1])) 272 | areas = np.array(areas) 273 | areas_new = np.sort(areas)[::-1] 274 | idxs = np.argsort(areas)[::-1] 275 | if areas_new[0] < 4 * areas_new[1]: 276 | print('Detecting multiple faces after rotation, do not use image {}.'.format(imname)) 277 | continue 278 | else: 279 | bbox = bboxes[idxs[0]] 280 | elif len(bboxes) == 0: 281 | print('Detecting no face after rotation, do not use this image {}'.format(imname)) 282 | continue 283 | else: 284 | bbox = bboxes[0] 285 | 286 | for idx in range(4): 287 | bbox[idx] = bbox[idx] * ns / 1024 288 | imr = crop_and_resize(im, bbox) 289 | cv2.imwrite(tmp_path, imr) 290 | 291 | result = self.skin_retouching(tmp_path) 292 | if (result is None or (result[OutputKeys.OUTPUT_IMG] is None)): 293 | print('Cannot do skin retouching, do not use this image.') 294 | continue 295 | cv2.imwrite(tmp_path, result[OutputKeys.OUTPUT_IMG]) 296 | 297 | result = self.segmentation_pipeline(tmp_path) 298 | mask_head = get_mask_head(result) 299 | im = cv2.imread(tmp_path) 300 | im = im * mask_head + 255 * (1 - mask_head) 301 | # print(im.shape) 302 | 303 | raw_result = self.facial_landmark_confidence_func(im) 304 | if raw_result is None: 305 | print('landmark quality fail...') 306 | continue 307 | 308 | print(imname, raw_result['scores'][0]) 309 | if float(raw_result['scores'][0]) < (1 - 0.145): 310 | print('landmark quality fail...') 311 | continue 312 | 313 | cv2.imwrite(os.path.join(savedir, '{}.png'.format(cnt)), im) 314 | imgs_list.append('{}.png'.format(cnt)) 315 | img = Image.open(os.path.join(savedir, '{}.png'.format(cnt))) 316 | result = self.model.tag(img) 317 | print(result) 318 | attribute_result 
= self.fair_face_attribute_func(tmp_path) 319 | if cnt == 0: 320 | score_gender = np.array(attribute_result['scores'][0]) 321 | score_age = np.array(attribute_result['scores'][1]) 322 | else: 323 | score_gender += np.array(attribute_result['scores'][0]) 324 | score_age += np.array(attribute_result['scores'][1]) 325 | 326 | result_list.append(result.split(', ')) 327 | cnt += 1 328 | except Exception as e: 329 | print('caught exception for image process of ' + imname) 330 | print('Error: ' + str(e)) 331 | 332 | print(result_list) 333 | if len(result_list) == 0: 334 | print('Error: no valid images were processed.') 335 | exit() 336 | return os.path.join(savedir, "metadata.jsonl") 337 | 338 | result_list = post_process_naive(result_list, score_gender, score_age) 339 | self.model.stop() 340 | os.system('rm ' + tmp_path) 341 | 342 | out_json_name = os.path.join(savedir, "metadata.jsonl") 343 | fo = open(out_json_name, 'w') 344 | for i in range(len(result_list)): 345 | generated_text = ", ".join(result_list[i]) 346 | print(imgs_list[i], generated_text) 347 | info_dict = {"file_name": imgs_list[i], "text": ", " + generated_text} 348 | fo.write(json.dumps(info_dict) + '\n') 349 | fo.close() 350 | return out_json_name 351 | -------------------------------------------------------------------------------- /facechain/data_process/preprocessing -win10.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates. 2 | 3 | import os 4 | import shutil 5 | 6 | import numpy as np 7 | import json 8 | import math 9 | 10 | from .deepbooru import DeepDanbooru 11 | 12 | import cv2 13 | from PIL import Image 14 | from modelscope.pipelines import pipeline 15 | from modelscope.outputs import OutputKeys 16 | from modelscope.utils.constant import Tasks 17 | 18 | 19 | def crop_and_resize(im, bbox, thres=0.35, thres1=0.45): 20 | h, w, _ = im.shape 21 | thre = np.random.rand() * (thres1 - thres) + thres 22 | maxf = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) 23 | cx = (bbox[2] + bbox[0]) / 2 24 | cy = (bbox[3] + bbox[1]) / 2 25 | lenp = int(maxf / thre) 26 | yc = np.random.rand() * 0.15 + 0.35 27 | xc = 0.5 28 | xmin = int(cx - xc * lenp) 29 | xmax = xmin + lenp 30 | ymin = int(cy - yc * lenp) 31 | ymax = ymin + lenp 32 | x1 = 0 33 | x2 = lenp 34 | y1 = 0 35 | y2 = lenp 36 | if xmin < 0: 37 | x1 = -xmin 38 | xmin = 0 39 | if xmax > w: 40 | x2 = w - (xmax - lenp) 41 | xmax = w 42 | if ymin < 0: 43 | y1 = -ymin 44 | ymin = 0 45 | if ymax > h: 46 | y2 = h - (ymax - lenp) 47 | ymax = h 48 | imc = (np.ones((lenp, lenp, 3)) * 255).astype(np.uint8) 49 | imc[y1:y2, x1:x2, :] = im[ymin:ymax, xmin:xmax, :] 50 | imr = cv2.resize(imc, (512, 512)) 51 | return imr 52 | 53 | 54 | def pad_to_square(im): 55 | h, w, _ = im.shape 56 | ns = int(max(h, w) * 1.5) 57 | im = cv2.copyMakeBorder(im, int((ns - h) / 2), (ns - h) - int((ns - h) / 2), int((ns - w) / 2), 58 | (ns - w) - int((ns - w) / 2), cv2.BORDER_CONSTANT, 255) 59 | return im 60 | 61 | 62 | def post_process_naive(result_list, score_gender, score_age): 63 | # determine trigger word 64 | gender = np.argmax(score_gender) 65 | age = np.argmax(score_age) 66 | if age < 2: 67 | if gender == 0: 68 | tag_a_g = ['a boy', 'children'] 69 | else: 70 | tag_a_g = ['a girl', 'children'] 71 | elif age > 4: 72 | if gender == 0: 73 | tag_a_g = ['a mature man'] 74 | else: 75 | tag_a_g = ['a mature woman'] 76 | else: 77 | if gender == 0: 78 | tag_a_g = ['a handsome man'] 79 | else: 80 | tag_a_g = ['a beautiful woman'] 81 | num_images = len(result_list) 82 | cnt_girl = 0
83 | cnt_boy = 0 84 | result_list_new = [] 85 | for result in result_list: 86 | result_new = [] 87 | result_new.extend(tag_a_g) 88 | for tag in result: 89 | if tag == '1girl' or tag == '1boy': 90 | continue 91 | if tag[-4:] == '_man': 92 | continue 93 | if tag[-6:] == '_woman': 94 | continue 95 | if tag[-5:] == '_male': 96 | continue 97 | elif tag[-7:] == '_female': 98 | continue 99 | elif ( 100 | tag == 'ears' or tag == 'head' or tag == 'face' or tag == 'lips' or tag == 'mouth' or tag == '3d' or tag == 'asian' or tag == 'teeth'): 101 | continue 102 | elif ('eye' in tag and not 'eyewear' in tag): 103 | continue 104 | elif ('nose' in tag or 'body' in tag): 105 | continue 106 | elif tag[-5:] == '_lips': 107 | continue 108 | else: 109 | result_new.append(tag) 110 | # import pdb;pdb.set_trace() 111 | # result_new.append('slim body') 112 | result_list_new.append(result_new) 113 | 114 | return result_list_new 115 | 116 | 117 | def transformation_from_points(points1, points2): 118 | points1 = points1.astype(np.float64) 119 | points2 = points2.astype(np.float64) 120 | c1 = np.mean(points1, axis=0) 121 | c2 = np.mean(points2, axis=0) 122 | points1 -= c1 123 | points2 -= c2 124 | s1 = np.std(points1) 125 | s2 = np.std(points2) 126 | if s1 < 1.0e-4: 127 | s1 = 1.0e-4 128 | points1 /= s1 129 | points2 /= s2 130 | U, S, Vt = np.linalg.svd(points1.T * points2) 131 | R = (U * Vt).T 132 | return np.vstack([np.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)), np.matrix([0., 0., 1.])]) 133 | 134 | 135 | def rotate(im, keypoints): 136 | h, w, _ = im.shape 137 | points_array = np.zeros((5, 2)) 138 | dst_mean_face_size = 160 139 | dst_mean_face = np.asarray([0.31074522411511746, 0.2798131190011913, 140 | 0.6892073313037804, 0.2797830232679366, 141 | 0.49997367716346774, 0.5099309118810921, 142 | 0.35811903020866753, 0.7233174007629063, 143 | 0.6418878095835022, 0.7232890570786875]) 144 | dst_mean_face = np.reshape(dst_mean_face, (5, 2)) * dst_mean_face_size 145 | 146 | for k in range(5): 147 | points_array[k, 0] = keypoints[2 * k] 148 | points_array[k, 1] = keypoints[2 * k + 1] 149 | 150 | pts1 = np.float64(np.matrix([[point[0], point[1]] for point in points_array])) 151 | pts2 = np.float64(np.matrix([[point[0], point[1]] for point in dst_mean_face])) 152 | trans_mat = transformation_from_points(pts1, pts2) 153 | if trans_mat[1, 1] > 1.0e-4: 154 | angle = math.atan(trans_mat[1, 0] / trans_mat[1, 1]) 155 | else: 156 | angle = math.atan(trans_mat[0, 1] / trans_mat[0, 2]) 157 | im = pad_to_square(im) 158 | ns = int(1.5 * max(h, w)) 159 | M = cv2.getRotationMatrix2D((ns / 2, ns / 2), angle=-angle / np.pi * 180, scale=1.0) 160 | im = cv2.warpAffine(im, M=M, dsize=(ns, ns)) 161 | return im 162 | 163 | 164 | def get_mask_head(result): 165 | masks = result['masks'] 166 | scores = result['scores'] 167 | labels = result['labels'] 168 | mask_hair = np.zeros((512, 512)) 169 | mask_face = np.zeros((512, 512)) 170 | mask_human = np.zeros((512, 512)) 171 | for i in range(len(labels)): 172 | if scores[i] > 0.8: 173 | if labels[i] == 'Face': 174 | if np.sum(masks[i]) > np.sum(mask_face): 175 | mask_face = masks[i] 176 | elif labels[i] == 'Human': 177 | if np.sum(masks[i]) > np.sum(mask_human): 178 | mask_human = masks[i] 179 | elif labels[i] == 'Hair': 180 | if np.sum(masks[i]) > np.sum(mask_hair): 181 | mask_hair = masks[i] 182 | mask_head = np.clip(mask_hair + mask_face, 0, 1) 183 | ksize = max(int(np.sqrt(np.sum(mask_face)) / 20), 1) 184 | kernel = np.ones((ksize, ksize)) 185 | mask_head = cv2.dilate(mask_head, kernel, 
iterations=1) * mask_human 186 | _, mask_head = cv2.threshold((mask_head * 255).astype(np.uint8), 127, 255, cv2.THRESH_BINARY) 187 | contours, hierarchy = cv2.findContours(mask_head, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE) 188 | area = [] 189 | for j in range(len(contours)): 190 | area.append(cv2.contourArea(contours[j])) 191 | max_idx = np.argmax(area) 192 | mask_head = np.zeros((512, 512)).astype(np.uint8) 193 | cv2.fillPoly(mask_head, [contours[max_idx]], 255) 194 | mask_head = mask_head.astype(np.float32) / 255 195 | mask_head = np.clip(mask_head + mask_face, 0, 1) 196 | mask_head = np.expand_dims(mask_head, 2) 197 | return mask_head 198 | 199 | 200 | class Blipv2(): 201 | def __init__(self): 202 | self.model = DeepDanbooru() 203 | self.skin_retouching = pipeline(Tasks.skin_retouching, model='damo/cv_unet_skin-retouching') 204 | self.face_detection = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd') 205 | # self.mog_face_detection_func = pipeline(Tasks.face_detection, 'damo/cv_resnet101_face-detection_cvpr22papermogface') 206 | self.segmentation_pipeline = pipeline(Tasks.image_segmentation, 207 | 'damo/cv_resnet101_image-multiple-human-parsing') 208 | self.fair_face_attribute_func = pipeline(Tasks.face_attribute_recognition, 209 | 'damo/cv_resnet34_face-attribute-recognition_fairface') 210 | self.facial_landmark_confidence_func = pipeline(Tasks.face_2d_keypoints, 211 | 'damo/cv_manual_facial-landmark-confidence_flcm') 212 | 213 | def __call__(self, imdir): 214 | self.model.start() 215 | savedir = str(imdir) + '_labeled' 216 | shutil.rmtree(savedir, ignore_errors=True) 217 | os.makedirs(savedir, exist_ok=True) 218 | 219 | imlist = os.listdir(imdir) 220 | result_list = [] 221 | imgs_list = [] 222 | 223 | cnt = 0 224 | tmp_path = os.path.join(savedir, 'tmp.png') 225 | for imname in imlist: 226 | try: 227 | # if 1: 228 | if imname.startswith('.'): 229 | continue 230 | img_path = os.path.join(imdir, imname) 231 | im = cv2.imread(img_path) 232 | h, w, _ = im.shape 233 | max_size = max(w, h) 234 | ratio = 1024 / max_size 235 | new_w = round(w * ratio) 236 | new_h = round(h * ratio) 237 | imt = cv2.resize(im, (new_w, new_h)) 238 | cv2.imwrite(tmp_path, imt) 239 | result_det = self.face_detection(tmp_path) 240 | bboxes = result_det['boxes'] 241 | if len(bboxes) > 1: 242 | areas = [] 243 | for i in range(len(bboxes)): 244 | bbox = bboxes[i] 245 | areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1])) 246 | areas = np.array(areas) 247 | areas_new = np.sort(areas)[::-1] 248 | idxs = np.argsort(areas)[::-1] 249 | if areas_new[0] < 4 * areas_new[1]: 250 | print('Detecting multiple faces, do not use image {}.'.format(imname)) 251 | continue 252 | else: 253 | keypoints = result_det['keypoints'][idxs[0]] 254 | elif len(bboxes) == 0: 255 | print('Detecting no face, do not use image {}.'.format(imname)) 256 | continue 257 | else: 258 | keypoints = result_det['keypoints'][0] 259 | 260 | im = rotate(im, keypoints) 261 | ns = im.shape[0] 262 | imt = cv2.resize(im, (1024, 1024)) 263 | cv2.imwrite(tmp_path, imt) 264 | result_det = self.face_detection(tmp_path) 265 | bboxes = result_det['boxes'] 266 | 267 | if len(bboxes) > 1: 268 | areas = [] 269 | for i in range(len(bboxes)): 270 | bbox = bboxes[i] 271 | areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1])) 272 | areas = np.array(areas) 273 | areas_new = np.sort(areas)[::-1] 274 | idxs = np.argsort(areas)[::-1] 275 | if areas_new[0] < 4 * areas_new[1]: 276 | print('Detecting multiple faces after rotation, do not use 
image {}.'.format(imname)) 277 | continue 278 | else: 279 | bbox = bboxes[idxs[0]] 280 | elif len(bboxes) == 0: 281 | print('Detecting no face after rotation, do not use this image {}'.format(imname)) 282 | continue 283 | else: 284 | bbox = bboxes[0] 285 | 286 | for idx in range(4): 287 | bbox[idx] = bbox[idx] * ns / 1024 288 | imr = crop_and_resize(im, bbox) 289 | cv2.imwrite(tmp_path, imr) 290 | 291 | result = self.skin_retouching(tmp_path) 292 | if (result is None or (result[OutputKeys.OUTPUT_IMG] is None)): 293 | print('Cannot do skin retouching, do not use this image.') 294 | continue 295 | cv2.imwrite(tmp_path, result[OutputKeys.OUTPUT_IMG]) 296 | 297 | result = self.segmentation_pipeline(tmp_path) 298 | mask_head = get_mask_head(result) 299 | im = cv2.imread(tmp_path) 300 | im = im * mask_head + 255 * (1 - mask_head) 301 | # print(im.shape) 302 | 303 | raw_result = self.facial_landmark_confidence_func(im) 304 | if raw_result is None: 305 | print('landmark quality fail...') 306 | continue 307 | 308 | print(imname, raw_result['scores'][0]) 309 | if float(raw_result['scores'][0]) < (1 - 0.145): 310 | print('landmark quality fail...') 311 | continue 312 | 313 | cv2.imwrite(os.path.join(savedir, '{}.png'.format(cnt)), im) 314 | imgs_list.append('{}.png'.format(cnt)) 315 | img = Image.open(os.path.join(savedir, '{}.png'.format(cnt))) 316 | result = self.model.tag(img) 317 | print(result) 318 | attribute_result = self.fair_face_attribute_func(tmp_path) 319 | if cnt == 0: 320 | score_gender = np.array(attribute_result['scores'][0]) 321 | score_age = np.array(attribute_result['scores'][1]) 322 | else: 323 | score_gender += np.array(attribute_result['scores'][0]) 324 | score_age += np.array(attribute_result['scores'][1]) 325 | 326 | result_list.append(result.split(', ')) 327 | cnt += 1 328 | except Exception as e: 329 | print('caught exception for image process of ' + imname) 330 | print('Error: ' + str(e)) 331 | 332 | print(result_list) 333 | if len(result_list) == 0: 334 | print('Error: no valid images were processed.') 335 | exit() 336 | return os.path.join(savedir, "metadata.jsonl") 337 | 338 | result_list = post_process_naive(result_list, score_gender, score_age) 339 | self.model.stop() 340 | # os.system('rm ' + tmp_path) 341 | os.system('del ' + tmp_path) 342 | 343 | out_json_name = os.path.join(savedir, "metadata.jsonl") 344 | fo = open(out_json_name, 'w') 345 | for i in range(len(result_list)): 346 | generated_text = ", ".join(result_list[i]) 347 | print(imgs_list[i], generated_text) 348 | info_dict = {"file_name": imgs_list[i], "text": ", " + generated_text} 349 | fo.write(json.dumps(info_dict) + '\n') 350 | fo.close() 351 | return out_json_name 352 | -------------------------------------------------------------------------------- /facechain/data_process/deepbooru.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba, Inc. and its affiliates.
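# PyTorch port of the DeepDanbooru ResNet tagger (see the TorchDeepDanbooru
# link below). DeepDanbooruModel.forward expects an NHWC float image batch,
# permutes it to NCHW, runs a mechanically exported ResNet-style convolution
# graph with residual additions, and returns per-tag sigmoid scores; the final
# 1x1 convolution has 9176 output channels, one per tag, and the tag names are
# restored from the checkpoint's 'tags' entry in load_state_dict. The
# DeepDanbooru wrapper that preprocessing.py drives via start()/tag()/stop()
# is defined later in this file and presumably fetches its weights through the
# ModelScope snapshot_download imported below.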
2 | 3 | import os 4 | import re 5 | 6 | from PIL import Image 7 | import numpy as np 8 | 9 | re_special = re.compile(r'([\\()])') 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | from modelscope.hub.snapshot_download import snapshot_download 15 | 16 | # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more 17 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS) 18 | 19 | 20 | class DeepDanbooruModel(nn.Module): 21 | def __init__(self): 22 | super(DeepDanbooruModel, self).__init__() 23 | 24 | self.tags = [] 25 | 26 | self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2)) 27 | self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2)) 28 | self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256) 29 | self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64) 30 | self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64) 31 | self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256) 32 | self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64) 33 | self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64) 34 | self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256) 35 | self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64) 36 | self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64) 37 | self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256) 38 | self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2)) 39 | self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128) 40 | self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2)) 41 | self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) 42 | self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) 43 | self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) 44 | self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) 45 | self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) 46 | self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) 47 | self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) 48 | self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) 49 | self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) 50 | self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) 51 | self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) 52 | self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) 53 | self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) 54 | self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) 55 | self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) 56 | self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) 57 | self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) 58 | self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) 59 | self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, 
out_channels=512) 60 | self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128) 61 | self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128) 62 | self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512) 63 | self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2)) 64 | self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256) 65 | self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2)) 66 | self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 67 | self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 68 | self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 69 | self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 70 | self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 71 | self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 72 | self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 73 | self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 74 | self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 75 | self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 76 | self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 77 | self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 78 | self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 79 | self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 80 | self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 81 | self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 82 | self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 83 | self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 84 | self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 85 | self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 86 | self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 87 | self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 88 | self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 89 | self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 90 | self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 91 | self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 92 | self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 93 | self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 94 | self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 95 | self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 96 | self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 97 | self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 98 | self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 99 | self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 
100 | self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 101 | self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 102 | self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 103 | self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 104 | self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 105 | self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 106 | self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 107 | self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 108 | self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 109 | self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 110 | self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 111 | self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 112 | self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 113 | self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 114 | self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 115 | self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 116 | self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 117 | self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 118 | self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 119 | self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 120 | self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 121 | self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 122 | self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 123 | self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 124 | self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 125 | self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2)) 126 | self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 127 | self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2)) 128 | self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 129 | self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 130 | self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 131 | self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 132 | self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 133 | self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 134 | self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 135 | self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 136 | self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 137 | self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 138 | self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 139 | self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), 
in_channels=256, out_channels=1024) 140 | self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 141 | self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 142 | self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 143 | self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 144 | self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 145 | self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 146 | self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 147 | self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 148 | self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 149 | self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 150 | self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 151 | self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 152 | self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 153 | self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 154 | self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 155 | self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 156 | self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 157 | self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 158 | self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 159 | self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 160 | self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 161 | self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 162 | self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 163 | self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 164 | self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 165 | self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 166 | self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 167 | self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 168 | self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 169 | self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 170 | self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 171 | self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 172 | self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 173 | self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 174 | self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 175 | self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 176 | self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 177 | self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 178 | self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 179 | self.n_Conv_152 
= nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 180 | self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 181 | self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 182 | self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256) 183 | self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256) 184 | self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024) 185 | self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2)) 186 | self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512) 187 | self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2)) 188 | self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048) 189 | self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512) 190 | self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512) 191 | self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048) 192 | self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512) 193 | self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512) 194 | self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048) 195 | self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2)) 196 | self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024) 197 | self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2)) 198 | self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096) 199 | self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024) 200 | self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024) 201 | self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096) 202 | self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024) 203 | self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024) 204 | self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096) 205 | self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False) 206 | 207 | def forward(self, *inputs): 208 | t_358, = inputs 209 | t_359 = t_358.permute(*[0, 3, 1, 2]) 210 | t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0) 211 | # t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded) 212 | t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype)) 213 | t_361 = F.relu(t_360) 214 | t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf')) 215 | t_362 = self.n_MaxPool_0(t_361) 216 | t_363 = self.n_Conv_1(t_362) 217 | t_364 = self.n_Conv_2(t_362) 218 | t_365 = F.relu(t_364) 219 | t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0) 220 | t_366 = self.n_Conv_3(t_365_padded) 221 | t_367 = F.relu(t_366) 222 | t_368 = self.n_Conv_4(t_367) 223 | t_369 = torch.add(t_368, t_363) 224 | t_370 = F.relu(t_369) 225 | t_371 = self.n_Conv_5(t_370) 226 | t_372 = F.relu(t_371) 227 | t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0) 228 | t_373 = self.n_Conv_6(t_372_padded) 229 | t_374 = F.relu(t_373) 230 | t_375 = self.n_Conv_7(t_374) 231 | t_376 = torch.add(t_375, 
t_370) 232 | t_377 = F.relu(t_376) 233 | t_378 = self.n_Conv_8(t_377) 234 | t_379 = F.relu(t_378) 235 | t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0) 236 | t_380 = self.n_Conv_9(t_379_padded) 237 | t_381 = F.relu(t_380) 238 | t_382 = self.n_Conv_10(t_381) 239 | t_383 = torch.add(t_382, t_377) 240 | t_384 = F.relu(t_383) 241 | t_385 = self.n_Conv_11(t_384) 242 | t_386 = self.n_Conv_12(t_384) 243 | t_387 = F.relu(t_386) 244 | t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0) 245 | t_388 = self.n_Conv_13(t_387_padded) 246 | t_389 = F.relu(t_388) 247 | t_390 = self.n_Conv_14(t_389) 248 | t_391 = torch.add(t_390, t_385) 249 | t_392 = F.relu(t_391) 250 | t_393 = self.n_Conv_15(t_392) 251 | t_394 = F.relu(t_393) 252 | t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0) 253 | t_395 = self.n_Conv_16(t_394_padded) 254 | t_396 = F.relu(t_395) 255 | t_397 = self.n_Conv_17(t_396) 256 | t_398 = torch.add(t_397, t_392) 257 | t_399 = F.relu(t_398) 258 | t_400 = self.n_Conv_18(t_399) 259 | t_401 = F.relu(t_400) 260 | t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0) 261 | t_402 = self.n_Conv_19(t_401_padded) 262 | t_403 = F.relu(t_402) 263 | t_404 = self.n_Conv_20(t_403) 264 | t_405 = torch.add(t_404, t_399) 265 | t_406 = F.relu(t_405) 266 | t_407 = self.n_Conv_21(t_406) 267 | t_408 = F.relu(t_407) 268 | t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0) 269 | t_409 = self.n_Conv_22(t_408_padded) 270 | t_410 = F.relu(t_409) 271 | t_411 = self.n_Conv_23(t_410) 272 | t_412 = torch.add(t_411, t_406) 273 | t_413 = F.relu(t_412) 274 | t_414 = self.n_Conv_24(t_413) 275 | t_415 = F.relu(t_414) 276 | t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0) 277 | t_416 = self.n_Conv_25(t_415_padded) 278 | t_417 = F.relu(t_416) 279 | t_418 = self.n_Conv_26(t_417) 280 | t_419 = torch.add(t_418, t_413) 281 | t_420 = F.relu(t_419) 282 | t_421 = self.n_Conv_27(t_420) 283 | t_422 = F.relu(t_421) 284 | t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0) 285 | t_423 = self.n_Conv_28(t_422_padded) 286 | t_424 = F.relu(t_423) 287 | t_425 = self.n_Conv_29(t_424) 288 | t_426 = torch.add(t_425, t_420) 289 | t_427 = F.relu(t_426) 290 | t_428 = self.n_Conv_30(t_427) 291 | t_429 = F.relu(t_428) 292 | t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0) 293 | t_430 = self.n_Conv_31(t_429_padded) 294 | t_431 = F.relu(t_430) 295 | t_432 = self.n_Conv_32(t_431) 296 | t_433 = torch.add(t_432, t_427) 297 | t_434 = F.relu(t_433) 298 | t_435 = self.n_Conv_33(t_434) 299 | t_436 = F.relu(t_435) 300 | t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0) 301 | t_437 = self.n_Conv_34(t_436_padded) 302 | t_438 = F.relu(t_437) 303 | t_439 = self.n_Conv_35(t_438) 304 | t_440 = torch.add(t_439, t_434) 305 | t_441 = F.relu(t_440) 306 | t_442 = self.n_Conv_36(t_441) 307 | t_443 = self.n_Conv_37(t_441) 308 | t_444 = F.relu(t_443) 309 | t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0) 310 | t_445 = self.n_Conv_38(t_444_padded) 311 | t_446 = F.relu(t_445) 312 | t_447 = self.n_Conv_39(t_446) 313 | t_448 = torch.add(t_447, t_442) 314 | t_449 = F.relu(t_448) 315 | t_450 = self.n_Conv_40(t_449) 316 | t_451 = F.relu(t_450) 317 | t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0) 318 | t_452 = self.n_Conv_41(t_451_padded) 319 | t_453 = F.relu(t_452) 320 | t_454 = self.n_Conv_42(t_453) 321 | t_455 = torch.add(t_454, t_449) 322 | t_456 = F.relu(t_455) 323 | t_457 = self.n_Conv_43(t_456) 324 | t_458 = F.relu(t_457) 325 | t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0) 326 | t_459 = self.n_Conv_44(t_458_padded) 327 | t_460 = F.relu(t_459) 328 | t_461 = 
self.n_Conv_45(t_460) 329 | t_462 = torch.add(t_461, t_456) 330 | t_463 = F.relu(t_462) 331 | t_464 = self.n_Conv_46(t_463) 332 | t_465 = F.relu(t_464) 333 | t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0) 334 | t_466 = self.n_Conv_47(t_465_padded) 335 | t_467 = F.relu(t_466) 336 | t_468 = self.n_Conv_48(t_467) 337 | t_469 = torch.add(t_468, t_463) 338 | t_470 = F.relu(t_469) 339 | t_471 = self.n_Conv_49(t_470) 340 | t_472 = F.relu(t_471) 341 | t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0) 342 | t_473 = self.n_Conv_50(t_472_padded) 343 | t_474 = F.relu(t_473) 344 | t_475 = self.n_Conv_51(t_474) 345 | t_476 = torch.add(t_475, t_470) 346 | t_477 = F.relu(t_476) 347 | t_478 = self.n_Conv_52(t_477) 348 | t_479 = F.relu(t_478) 349 | t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0) 350 | t_480 = self.n_Conv_53(t_479_padded) 351 | t_481 = F.relu(t_480) 352 | t_482 = self.n_Conv_54(t_481) 353 | t_483 = torch.add(t_482, t_477) 354 | t_484 = F.relu(t_483) 355 | t_485 = self.n_Conv_55(t_484) 356 | t_486 = F.relu(t_485) 357 | t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0) 358 | t_487 = self.n_Conv_56(t_486_padded) 359 | t_488 = F.relu(t_487) 360 | t_489 = self.n_Conv_57(t_488) 361 | t_490 = torch.add(t_489, t_484) 362 | t_491 = F.relu(t_490) 363 | t_492 = self.n_Conv_58(t_491) 364 | t_493 = F.relu(t_492) 365 | t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0) 366 | t_494 = self.n_Conv_59(t_493_padded) 367 | t_495 = F.relu(t_494) 368 | t_496 = self.n_Conv_60(t_495) 369 | t_497 = torch.add(t_496, t_491) 370 | t_498 = F.relu(t_497) 371 | t_499 = self.n_Conv_61(t_498) 372 | t_500 = F.relu(t_499) 373 | t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0) 374 | t_501 = self.n_Conv_62(t_500_padded) 375 | t_502 = F.relu(t_501) 376 | t_503 = self.n_Conv_63(t_502) 377 | t_504 = torch.add(t_503, t_498) 378 | t_505 = F.relu(t_504) 379 | t_506 = self.n_Conv_64(t_505) 380 | t_507 = F.relu(t_506) 381 | t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0) 382 | t_508 = self.n_Conv_65(t_507_padded) 383 | t_509 = F.relu(t_508) 384 | t_510 = self.n_Conv_66(t_509) 385 | t_511 = torch.add(t_510, t_505) 386 | t_512 = F.relu(t_511) 387 | t_513 = self.n_Conv_67(t_512) 388 | t_514 = F.relu(t_513) 389 | t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0) 390 | t_515 = self.n_Conv_68(t_514_padded) 391 | t_516 = F.relu(t_515) 392 | t_517 = self.n_Conv_69(t_516) 393 | t_518 = torch.add(t_517, t_512) 394 | t_519 = F.relu(t_518) 395 | t_520 = self.n_Conv_70(t_519) 396 | t_521 = F.relu(t_520) 397 | t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0) 398 | t_522 = self.n_Conv_71(t_521_padded) 399 | t_523 = F.relu(t_522) 400 | t_524 = self.n_Conv_72(t_523) 401 | t_525 = torch.add(t_524, t_519) 402 | t_526 = F.relu(t_525) 403 | t_527 = self.n_Conv_73(t_526) 404 | t_528 = F.relu(t_527) 405 | t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0) 406 | t_529 = self.n_Conv_74(t_528_padded) 407 | t_530 = F.relu(t_529) 408 | t_531 = self.n_Conv_75(t_530) 409 | t_532 = torch.add(t_531, t_526) 410 | t_533 = F.relu(t_532) 411 | t_534 = self.n_Conv_76(t_533) 412 | t_535 = F.relu(t_534) 413 | t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0) 414 | t_536 = self.n_Conv_77(t_535_padded) 415 | t_537 = F.relu(t_536) 416 | t_538 = self.n_Conv_78(t_537) 417 | t_539 = torch.add(t_538, t_533) 418 | t_540 = F.relu(t_539) 419 | t_541 = self.n_Conv_79(t_540) 420 | t_542 = F.relu(t_541) 421 | t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0) 422 | t_543 = self.n_Conv_80(t_542_padded) 423 | t_544 = F.relu(t_543) 424 | t_545 = self.n_Conv_81(t_544) 425 
| t_546 = torch.add(t_545, t_540) 426 | t_547 = F.relu(t_546) 427 | t_548 = self.n_Conv_82(t_547) 428 | t_549 = F.relu(t_548) 429 | t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0) 430 | t_550 = self.n_Conv_83(t_549_padded) 431 | t_551 = F.relu(t_550) 432 | t_552 = self.n_Conv_84(t_551) 433 | t_553 = torch.add(t_552, t_547) 434 | t_554 = F.relu(t_553) 435 | t_555 = self.n_Conv_85(t_554) 436 | t_556 = F.relu(t_555) 437 | t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0) 438 | t_557 = self.n_Conv_86(t_556_padded) 439 | t_558 = F.relu(t_557) 440 | t_559 = self.n_Conv_87(t_558) 441 | t_560 = torch.add(t_559, t_554) 442 | t_561 = F.relu(t_560) 443 | t_562 = self.n_Conv_88(t_561) 444 | t_563 = F.relu(t_562) 445 | t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0) 446 | t_564 = self.n_Conv_89(t_563_padded) 447 | t_565 = F.relu(t_564) 448 | t_566 = self.n_Conv_90(t_565) 449 | t_567 = torch.add(t_566, t_561) 450 | t_568 = F.relu(t_567) 451 | t_569 = self.n_Conv_91(t_568) 452 | t_570 = F.relu(t_569) 453 | t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0) 454 | t_571 = self.n_Conv_92(t_570_padded) 455 | t_572 = F.relu(t_571) 456 | t_573 = self.n_Conv_93(t_572) 457 | t_574 = torch.add(t_573, t_568) 458 | t_575 = F.relu(t_574) 459 | t_576 = self.n_Conv_94(t_575) 460 | t_577 = F.relu(t_576) 461 | t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0) 462 | t_578 = self.n_Conv_95(t_577_padded) 463 | t_579 = F.relu(t_578) 464 | t_580 = self.n_Conv_96(t_579) 465 | t_581 = torch.add(t_580, t_575) 466 | t_582 = F.relu(t_581) 467 | t_583 = self.n_Conv_97(t_582) 468 | t_584 = F.relu(t_583) 469 | t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0) 470 | t_585 = self.n_Conv_98(t_584_padded) 471 | t_586 = F.relu(t_585) 472 | t_587 = self.n_Conv_99(t_586) 473 | t_588 = self.n_Conv_100(t_582) 474 | t_589 = torch.add(t_587, t_588) 475 | t_590 = F.relu(t_589) 476 | t_591 = self.n_Conv_101(t_590) 477 | t_592 = F.relu(t_591) 478 | t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0) 479 | t_593 = self.n_Conv_102(t_592_padded) 480 | t_594 = F.relu(t_593) 481 | t_595 = self.n_Conv_103(t_594) 482 | t_596 = torch.add(t_595, t_590) 483 | t_597 = F.relu(t_596) 484 | t_598 = self.n_Conv_104(t_597) 485 | t_599 = F.relu(t_598) 486 | t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0) 487 | t_600 = self.n_Conv_105(t_599_padded) 488 | t_601 = F.relu(t_600) 489 | t_602 = self.n_Conv_106(t_601) 490 | t_603 = torch.add(t_602, t_597) 491 | t_604 = F.relu(t_603) 492 | t_605 = self.n_Conv_107(t_604) 493 | t_606 = F.relu(t_605) 494 | t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0) 495 | t_607 = self.n_Conv_108(t_606_padded) 496 | t_608 = F.relu(t_607) 497 | t_609 = self.n_Conv_109(t_608) 498 | t_610 = torch.add(t_609, t_604) 499 | t_611 = F.relu(t_610) 500 | t_612 = self.n_Conv_110(t_611) 501 | t_613 = F.relu(t_612) 502 | t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0) 503 | t_614 = self.n_Conv_111(t_613_padded) 504 | t_615 = F.relu(t_614) 505 | t_616 = self.n_Conv_112(t_615) 506 | t_617 = torch.add(t_616, t_611) 507 | t_618 = F.relu(t_617) 508 | t_619 = self.n_Conv_113(t_618) 509 | t_620 = F.relu(t_619) 510 | t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0) 511 | t_621 = self.n_Conv_114(t_620_padded) 512 | t_622 = F.relu(t_621) 513 | t_623 = self.n_Conv_115(t_622) 514 | t_624 = torch.add(t_623, t_618) 515 | t_625 = F.relu(t_624) 516 | t_626 = self.n_Conv_116(t_625) 517 | t_627 = F.relu(t_626) 518 | t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0) 519 | t_628 = self.n_Conv_117(t_627_padded) 520 | t_629 = F.relu(t_628) 521 | t_630 
= self.n_Conv_118(t_629) 522 | t_631 = torch.add(t_630, t_625) 523 | t_632 = F.relu(t_631) 524 | t_633 = self.n_Conv_119(t_632) 525 | t_634 = F.relu(t_633) 526 | t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0) 527 | t_635 = self.n_Conv_120(t_634_padded) 528 | t_636 = F.relu(t_635) 529 | t_637 = self.n_Conv_121(t_636) 530 | t_638 = torch.add(t_637, t_632) 531 | t_639 = F.relu(t_638) 532 | t_640 = self.n_Conv_122(t_639) 533 | t_641 = F.relu(t_640) 534 | t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0) 535 | t_642 = self.n_Conv_123(t_641_padded) 536 | t_643 = F.relu(t_642) 537 | t_644 = self.n_Conv_124(t_643) 538 | t_645 = torch.add(t_644, t_639) 539 | t_646 = F.relu(t_645) 540 | t_647 = self.n_Conv_125(t_646) 541 | t_648 = F.relu(t_647) 542 | t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0) 543 | t_649 = self.n_Conv_126(t_648_padded) 544 | t_650 = F.relu(t_649) 545 | t_651 = self.n_Conv_127(t_650) 546 | t_652 = torch.add(t_651, t_646) 547 | t_653 = F.relu(t_652) 548 | t_654 = self.n_Conv_128(t_653) 549 | t_655 = F.relu(t_654) 550 | t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0) 551 | t_656 = self.n_Conv_129(t_655_padded) 552 | t_657 = F.relu(t_656) 553 | t_658 = self.n_Conv_130(t_657) 554 | t_659 = torch.add(t_658, t_653) 555 | t_660 = F.relu(t_659) 556 | t_661 = self.n_Conv_131(t_660) 557 | t_662 = F.relu(t_661) 558 | t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0) 559 | t_663 = self.n_Conv_132(t_662_padded) 560 | t_664 = F.relu(t_663) 561 | t_665 = self.n_Conv_133(t_664) 562 | t_666 = torch.add(t_665, t_660) 563 | t_667 = F.relu(t_666) 564 | t_668 = self.n_Conv_134(t_667) 565 | t_669 = F.relu(t_668) 566 | t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0) 567 | t_670 = self.n_Conv_135(t_669_padded) 568 | t_671 = F.relu(t_670) 569 | t_672 = self.n_Conv_136(t_671) 570 | t_673 = torch.add(t_672, t_667) 571 | t_674 = F.relu(t_673) 572 | t_675 = self.n_Conv_137(t_674) 573 | t_676 = F.relu(t_675) 574 | t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0) 575 | t_677 = self.n_Conv_138(t_676_padded) 576 | t_678 = F.relu(t_677) 577 | t_679 = self.n_Conv_139(t_678) 578 | t_680 = torch.add(t_679, t_674) 579 | t_681 = F.relu(t_680) 580 | t_682 = self.n_Conv_140(t_681) 581 | t_683 = F.relu(t_682) 582 | t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0) 583 | t_684 = self.n_Conv_141(t_683_padded) 584 | t_685 = F.relu(t_684) 585 | t_686 = self.n_Conv_142(t_685) 586 | t_687 = torch.add(t_686, t_681) 587 | t_688 = F.relu(t_687) 588 | t_689 = self.n_Conv_143(t_688) 589 | t_690 = F.relu(t_689) 590 | t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0) 591 | t_691 = self.n_Conv_144(t_690_padded) 592 | t_692 = F.relu(t_691) 593 | t_693 = self.n_Conv_145(t_692) 594 | t_694 = torch.add(t_693, t_688) 595 | t_695 = F.relu(t_694) 596 | t_696 = self.n_Conv_146(t_695) 597 | t_697 = F.relu(t_696) 598 | t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0) 599 | t_698 = self.n_Conv_147(t_697_padded) 600 | t_699 = F.relu(t_698) 601 | t_700 = self.n_Conv_148(t_699) 602 | t_701 = torch.add(t_700, t_695) 603 | t_702 = F.relu(t_701) 604 | t_703 = self.n_Conv_149(t_702) 605 | t_704 = F.relu(t_703) 606 | t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0) 607 | t_705 = self.n_Conv_150(t_704_padded) 608 | t_706 = F.relu(t_705) 609 | t_707 = self.n_Conv_151(t_706) 610 | t_708 = torch.add(t_707, t_702) 611 | t_709 = F.relu(t_708) 612 | t_710 = self.n_Conv_152(t_709) 613 | t_711 = F.relu(t_710) 614 | t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0) 615 | t_712 = self.n_Conv_153(t_711_padded) 616 | t_713 = F.relu(t_712) 
617 | t_714 = self.n_Conv_154(t_713) 618 | t_715 = torch.add(t_714, t_709) 619 | t_716 = F.relu(t_715) 620 | t_717 = self.n_Conv_155(t_716) 621 | t_718 = F.relu(t_717) 622 | t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0) 623 | t_719 = self.n_Conv_156(t_718_padded) 624 | t_720 = F.relu(t_719) 625 | t_721 = self.n_Conv_157(t_720) 626 | t_722 = torch.add(t_721, t_716) 627 | t_723 = F.relu(t_722) 628 | t_724 = self.n_Conv_158(t_723) 629 | t_725 = self.n_Conv_159(t_723) 630 | t_726 = F.relu(t_725) 631 | t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0) 632 | t_727 = self.n_Conv_160(t_726_padded) 633 | t_728 = F.relu(t_727) 634 | t_729 = self.n_Conv_161(t_728) 635 | t_730 = torch.add(t_729, t_724) 636 | t_731 = F.relu(t_730) 637 | t_732 = self.n_Conv_162(t_731) 638 | t_733 = F.relu(t_732) 639 | t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0) 640 | t_734 = self.n_Conv_163(t_733_padded) 641 | t_735 = F.relu(t_734) 642 | t_736 = self.n_Conv_164(t_735) 643 | t_737 = torch.add(t_736, t_731) 644 | t_738 = F.relu(t_737) 645 | t_739 = self.n_Conv_165(t_738) 646 | t_740 = F.relu(t_739) 647 | t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0) 648 | t_741 = self.n_Conv_166(t_740_padded) 649 | t_742 = F.relu(t_741) 650 | t_743 = self.n_Conv_167(t_742) 651 | t_744 = torch.add(t_743, t_738) 652 | t_745 = F.relu(t_744) 653 | t_746 = self.n_Conv_168(t_745) 654 | t_747 = self.n_Conv_169(t_745) 655 | t_748 = F.relu(t_747) 656 | t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0) 657 | t_749 = self.n_Conv_170(t_748_padded) 658 | t_750 = F.relu(t_749) 659 | t_751 = self.n_Conv_171(t_750) 660 | t_752 = torch.add(t_751, t_746) 661 | t_753 = F.relu(t_752) 662 | t_754 = self.n_Conv_172(t_753) 663 | t_755 = F.relu(t_754) 664 | t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0) 665 | t_756 = self.n_Conv_173(t_755_padded) 666 | t_757 = F.relu(t_756) 667 | t_758 = self.n_Conv_174(t_757) 668 | t_759 = torch.add(t_758, t_753) 669 | t_760 = F.relu(t_759) 670 | t_761 = self.n_Conv_175(t_760) 671 | t_762 = F.relu(t_761) 672 | t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0) 673 | t_763 = self.n_Conv_176(t_762_padded) 674 | t_764 = F.relu(t_763) 675 | t_765 = self.n_Conv_177(t_764) 676 | t_766 = torch.add(t_765, t_760) 677 | t_767 = F.relu(t_766) 678 | t_768 = self.n_Conv_178(t_767) 679 | t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:]) 680 | t_770 = torch.squeeze(t_769, 3) 681 | t_770 = torch.squeeze(t_770, 2) 682 | t_771 = torch.sigmoid(t_770) 683 | return t_771 684 | 685 | def load_state_dict(self, state_dict, **kwargs): 686 | self.tags = state_dict.get('tags', []) 687 | 688 | super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'}) 689 | 690 | 691 | def resize_image(im, width, height): 692 | ratio = width / height 693 | src_ratio = im.width / im.height 694 | 695 | src_w = width if ratio < src_ratio else im.width * height // im.height 696 | src_h = height if ratio >= src_ratio else im.height * width // im.width 697 | 698 | resized = im.resize((src_w, src_h), resample=LANCZOS) 699 | res = Image.new("RGB", (width, height)) 700 | res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2)) 701 | 702 | if ratio < src_ratio: 703 | fill_height = height // 2 - src_h // 2 704 | res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0)) 705 | res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), 706 | box=(0, fill_height + src_h)) 707 | elif ratio > src_ratio: 708 | fill_width = width // 2 - 
src_w // 2 709 | res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0)) 710 | res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), 711 | box=(fill_width + src_w, 0)) 712 | 713 | return res 714 | 715 | 716 | class DeepDanbooru: 717 | def __init__(self): 718 | self.model = DeepDanbooruModel() 719 | 720 | foundation_model_id = 'ly261666/cv_portrait_model' 721 | snapshot_path = snapshot_download(foundation_model_id, revision='v4.0') 722 | pretrain_model_path = os.path.join(snapshot_path, 'model-resnet_custom_v3.pt') 723 | 724 | self.model.load_state_dict(torch.load(pretrain_model_path, map_location="cpu")) 725 | self.model.eval() 726 | self.model.to(torch.float16) 727 | 728 | def start(self): 729 | self.model.cuda() 730 | 731 | def stop(self): 732 | self.model.cpu() 733 | torch.cuda.empty_cache() 734 | torch.cuda.ipc_collect() 735 | 736 | def tag(self, pil_image): 737 | threshold = 0.5 738 | use_spaces = False 739 | use_escape = True 740 | alpha_sort = True 741 | include_ranks = False 742 | 743 | pic = resize_image(pil_image.convert("RGB"), 512, 512) 744 | a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255 745 | 746 | with torch.no_grad(), torch.autocast("cuda"): 747 | x = torch.from_numpy(a).cuda() 748 | y = self.model(x)[0].detach().cpu().numpy() 749 | 750 | probability_dict = {} 751 | 752 | for tag, probability in zip(self.model.tags, y): 753 | if probability < threshold: 754 | continue 755 | 756 | if tag.startswith("rating:"): 757 | continue 758 | 759 | probability_dict[tag] = probability 760 | 761 | if alpha_sort: 762 | tags = sorted(probability_dict) 763 | else: 764 | tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])] 765 | 766 | res = [] 767 | 768 | for tag in [x for x in tags]: 769 | probability = probability_dict[tag] 770 | tag_outformat = tag 771 | if use_spaces: 772 | tag_outformat = tag_outformat.replace('_', ' ') 773 | if use_escape: 774 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat) 775 | if include_ranks: 776 | tag_outformat = f"({tag_outformat}:{probability:.3f})" 777 | 778 | res.append(tag_outformat) 779 | 780 | return ", ".join(res) 781 | 782 | 783 | ''' 784 | model = DeepDanbooru() 785 | impath = 'lyf' 786 | imlist = os.listdir(impath) 787 | result_list = [] 788 | for im in imlist: 789 | if im[-4:]=='.png': 790 | print(im) 791 | img = Image.open(os.path.join(impath, im)) 792 | result = model.tag(img) 793 | print(result) 794 | result_list.append(result) 795 | model.stop() 796 | ''' 797 | -------------------------------------------------------------------------------- /facechain/train_text_to_image_lora.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright (c) Alibaba, Inc. and its affiliates. 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | """Fine-tuning script for Stable Diffusion for text2image with support for LoRA.""" 17 | 18 | import argparse 19 | import itertools 20 | import json 21 | import logging 22 | import math 23 | import os 24 | import random 25 | import shutil 26 | from pathlib import Path 27 | 28 | import PIL.Image 29 | import cv2 30 | import datasets 31 | import diffusers 32 | import numpy as np 33 | import onnxruntime 34 | import torch 35 | import torch.nn.functional as F 36 | import torch.utils.checkpoint 37 | import transformers 38 | from PIL import Image 39 | from accelerate import Accelerator 40 | from accelerate.logging import get_logger 41 | from accelerate.utils import ProjectConfiguration, set_seed 42 | from datasets import load_dataset 43 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel 44 | from diffusers.loaders import AttnProcsLayers 45 | from diffusers.models.attention_processor import LoRAAttnProcessor 46 | from diffusers.optimization import get_scheduler 47 | from diffusers.utils import check_min_version, is_wandb_available 48 | from diffusers.utils.import_utils import is_xformers_available 49 | from huggingface_hub import create_repo, upload_folder 50 | from modelscope import snapshot_download 51 | from packaging import version 52 | from torchvision import transforms 53 | from tqdm.auto import tqdm 54 | from transformers import CLIPTextModel, CLIPTokenizer 55 | 56 | from facechain.inference import data_process_fn 57 | 58 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks. 59 | check_min_version("0.14.0.dev0") 60 | 61 | logger = get_logger(__name__, log_level="INFO") 62 | 63 | 64 | def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None): 65 | img_str = "" 66 | for i, image in enumerate(images): 67 | image.save(os.path.join(repo_folder, f"image_{i}.png")) 68 | img_str += f"![img_{i}](./image_{i}.png)\n" 69 | 70 | yaml = f""" 71 | --- 72 | license: creativeml-openrail-m 73 | base_model: {base_model} 74 | tags: 75 | - stable-diffusion 76 | - stable-diffusion-diffusers 77 | - text-to-image 78 | - diffusers 79 | - lora 80 | inference: true 81 | --- 82 | """ 83 | model_card = f""" 84 | # LoRA text2image fine-tuning - {repo_id} 85 | These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. 
\n 86 | {img_str} 87 | """ 88 | with open(os.path.join(repo_folder, "README.md"), "w") as f: 89 | f.write(yaml + model_card) 90 | 91 | 92 | def softmax(x): 93 | x -= np.max(x, axis=0, keepdims=True) 94 | x = np.exp(x) / np.sum(np.exp(x), axis=0, keepdims=True) 95 | return x 96 | 97 | 98 | def get_rot(image): 99 | model_dir = snapshot_download('Cherrytest/rot_bgr', revision='v1.0.0') 100 | model_path = os.path.join(model_dir, 'rot_bgr.onnx') 101 | ort_session = onnxruntime.InferenceSession(model_path) 102 | 103 | img_cv = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR) 104 | img_clone = img_cv.copy() 105 | img_np = cv2.resize(img_cv, (224, 224)) 106 | img_np = img_np.astype(np.float32) 107 | mean = np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape((1, 1, 3)) 108 | norm = np.array([0.01742919, 0.017507, 0.01712475], dtype=np.float32).reshape((1, 1, 3)) 109 | img_np = (img_np - mean) * norm 110 | img_tensor = torch.from_numpy(img_np) 111 | img_tensor = img_tensor.unsqueeze(0) 112 | img_nchw = img_tensor.permute(0, 3, 1, 2) 113 | ort_inputs = {ort_session.get_inputs()[0].name: img_nchw.numpy()} 114 | outputs = ort_session.run(None, ort_inputs) 115 | logits = outputs[0].reshape((-1,)) 116 | probs = softmax(logits) 117 | rot_idx = np.argmax(probs) 118 | if rot_idx == 1: 119 | print('rot 90') 120 | img_clone = cv2.transpose(img_clone) 121 | img_clone = np.flip(img_clone, 1) 122 | return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB)) 123 | elif rot_idx == 2: 124 | print('rot 180') 125 | img_clone = cv2.flip(img_clone, -1) 126 | return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB)) 127 | elif rot_idx == 3: 128 | print('rot 270') 129 | img_clone = cv2.transpose(img_clone) 130 | img_clone = np.flip(img_clone, 0) 131 | return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB)) 132 | else: 133 | return image 134 | 135 | 136 | def prepare_dataset(instance_images: list, output_dataset_dir): 137 | if not os.path.exists(output_dataset_dir): 138 | os.makedirs(output_dataset_dir) 139 | for i, temp_path in enumerate(instance_images): 140 | image = PIL.Image.open(temp_path) 141 | # image = PIL.Image.open(temp_path.name) 142 | ''' 143 | w, h = image.size 144 | max_size = max(w, h) 145 | ratio = 1024 / max_size 146 | new_w = round(w * ratio) 147 | new_h = round(h * ratio) 148 | ''' 149 | image = image.convert('RGB') 150 | image = get_rot(image) 151 | # image = image.resize((new_w, new_h)) 152 | # image = image.resize((new_w, new_h), PIL.Image.ANTIALIAS) 153 | out_path = f'{output_dataset_dir}/{i:03d}.jpg' 154 | image.save(out_path, format='JPEG', quality=100) 155 | 156 | 157 | def parse_args(): 158 | parser = argparse.ArgumentParser(description="Simple example of a training script.") 159 | parser.add_argument( 160 | "--pretrained_model_name_or_path", 161 | type=str, 162 | default=None, 163 | required=True, 164 | help="Path to pretrained model or model identifier.", 165 | ) 166 | parser.add_argument( 167 | "--revision", 168 | type=str, 169 | default=None, 170 | required=False, 171 | help="Revision of pretrained model identifier.", 172 | ) 173 | parser.add_argument( 174 | "--sub_path", 175 | type=str, 176 | default=None, 177 | required=False, 178 | help="The sub model path of the `pretrained_model_name_or_path`", 179 | ) 180 | parser.add_argument( 181 | "--dataset_name", 182 | type=str, 183 | default=None, 184 | help=( 185 | "The data images dir" 186 | ), 187 | ) 188 | parser.add_argument( 189 | "--dataset_config_name", 190 | type=str, 191 | default=None, 192 | 
help="The config of the Dataset, leave as None if there's only one config.", 193 | ) 194 | parser.add_argument( 195 | "--train_data_dir", 196 | type=str, 197 | default=None, 198 | help=( 199 | "A folder containing the training data. Folder contents must follow the structure described in" 200 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" 201 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." 202 | ), 203 | ) 204 | parser.add_argument( 205 | "--output_dataset_name", 206 | type=str, 207 | default=None, 208 | help=( 209 | "The dataset dir after processing" 210 | ), 211 | ) 212 | parser.add_argument( 213 | "--image_column", type=str, default="image", help="The column of the dataset containing an image." 214 | ) 215 | parser.add_argument( 216 | "--caption_column", 217 | type=str, 218 | default="text", 219 | help="The column of the dataset containing a caption or a list of captions.", 220 | ) 221 | parser.add_argument( 222 | "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference." 223 | ) 224 | parser.add_argument( 225 | "--num_validation_images", 226 | type=int, 227 | default=4, 228 | help="Number of images that should be generated during validation with `validation_prompt`.", 229 | ) 230 | parser.add_argument( 231 | "--validation_epochs", 232 | type=int, 233 | default=1, 234 | help=( 235 | "Run fine-tuning validation every X epochs. The validation process consists of running the prompt" 236 | " `args.validation_prompt` multiple times: `args.num_validation_images`." 237 | ), 238 | ) 239 | parser.add_argument( 240 | "--max_train_samples", 241 | type=int, 242 | default=None, 243 | help=( 244 | "For debugging purposes or quicker training, truncate the number of training examples to this " 245 | "value if set." 246 | ), 247 | ) 248 | parser.add_argument( 249 | "--output_dir", 250 | type=str, 251 | default="sd-model-finetuned-lora", 252 | help="The output directory where the model predictions and checkpoints will be written.", 253 | ) 254 | parser.add_argument( 255 | "--cache_dir", 256 | type=str, 257 | default=None, 258 | help="The directory where the downloaded models and datasets will be stored.", 259 | ) 260 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") 261 | parser.add_argument( 262 | "--resolution", 263 | type=int, 264 | default=512, 265 | help=( 266 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" 267 | " resolution" 268 | ), 269 | ) 270 | parser.add_argument( 271 | "--center_crop", 272 | default=False, 273 | action="store_true", 274 | help=( 275 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly" 276 | " cropped. The images will be resized to the resolution first before cropping." 
277 | ), 278 | ) 279 | parser.add_argument( 280 | "--random_flip", 281 | action="store_true", 282 | help="whether to randomly flip images horizontally", 283 | ) 284 | parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") 285 | 286 | # lora args 287 | parser.add_argument("--use_peft", action="store_true", help="Whether to use peft to support lora") 288 | parser.add_argument("--lora_r", type=int, default=4, help="Lora rank, only used if use_lora is True") 289 | parser.add_argument("--lora_alpha", type=int, default=32, help="Lora alpha, only used if lora is True") 290 | parser.add_argument("--lora_dropout", type=float, default=0.0, help="Lora dropout, only used if use_lora is True") 291 | parser.add_argument( 292 | "--lora_bias", 293 | type=str, 294 | default="none", 295 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True", 296 | ) 297 | parser.add_argument( 298 | "--lora_text_encoder_r", 299 | type=int, 300 | default=4, 301 | help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True", 302 | ) 303 | parser.add_argument( 304 | "--lora_text_encoder_alpha", 305 | type=int, 306 | default=32, 307 | help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True", 308 | ) 309 | parser.add_argument( 310 | "--lora_text_encoder_dropout", 311 | type=float, 312 | default=0.0, 313 | help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True", 314 | ) 315 | parser.add_argument( 316 | "--lora_text_encoder_bias", 317 | type=str, 318 | default="none", 319 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True", 320 | ) 321 | 322 | parser.add_argument( 323 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." 324 | ) 325 | parser.add_argument("--num_train_epochs", type=int, default=100) 326 | parser.add_argument( 327 | "--max_train_steps", 328 | type=int, 329 | default=None, 330 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", 331 | ) 332 | parser.add_argument( 333 | "--gradient_accumulation_steps", 334 | type=int, 335 | default=1, 336 | help="Number of updates steps to accumulate before performing a backward/update pass.", 337 | ) 338 | parser.add_argument( 339 | "--gradient_checkpointing", 340 | action="store_true", 341 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", 342 | ) 343 | parser.add_argument( 344 | "--learning_rate", 345 | type=float, 346 | default=1e-4, 347 | help="Initial learning rate (after the potential warmup period) to use.", 348 | ) 349 | parser.add_argument( 350 | "--scale_lr", 351 | action="store_true", 352 | default=False, 353 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", 354 | ) 355 | parser.add_argument( 356 | "--lr_scheduler", 357 | type=str, 358 | default="constant", 359 | help=( 360 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' 361 | ' "constant", "constant_with_warmup"]' 362 | ), 363 | ) 364 | parser.add_argument( 365 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 
366 | ) 367 | parser.add_argument( 368 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." 369 | ) 370 | parser.add_argument( 371 | "--allow_tf32", 372 | action="store_true", 373 | help=( 374 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" 375 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" 376 | ), 377 | ) 378 | parser.add_argument( 379 | "--dataloader_num_workers", 380 | type=int, 381 | default=0, 382 | help=( 383 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." 384 | ), 385 | ) 386 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") 387 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") 388 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") 389 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") 390 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") 391 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") 392 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") 393 | parser.add_argument( 394 | "--hub_model_id", 395 | type=str, 396 | default=None, 397 | help="The name of the repository to keep in sync with the local `output_dir`.", 398 | ) 399 | parser.add_argument( 400 | "--logging_dir", 401 | type=str, 402 | default="logs", 403 | help=( 404 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" 405 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." 406 | ), 407 | ) 408 | parser.add_argument( 409 | "--mixed_precision", 410 | type=str, 411 | default=None, 412 | choices=["no", "fp16", "bf16"], 413 | help=( 414 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" 415 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" 416 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." 417 | ), 418 | ) 419 | parser.add_argument( 420 | "--report_to", 421 | type=str, 422 | default="tensorboard", 423 | help=( 424 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 425 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 426 | ), 427 | ) 428 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 429 | parser.add_argument( 430 | "--checkpointing_steps", 431 | type=int, 432 | default=500, 433 | help=( 434 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming" 435 | " training using `--resume_from_checkpoint`." 436 | ), 437 | ) 438 | parser.add_argument( 439 | "--checkpoints_total_limit", 440 | type=int, 441 | default=None, 442 | help=( 443 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`." 
444 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state" 445 | " for more docs" 446 | ), 447 | ) 448 | parser.add_argument( 449 | "--resume_from_checkpoint", 450 | type=str, 451 | default=None, 452 | help=( 453 | "Whether training should be resumed from a previous checkpoint. Use a path saved by" 454 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' 455 | ), 456 | ) 457 | parser.add_argument( 458 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." 459 | ) 460 | 461 | args = parser.parse_args() 462 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) 463 | if env_local_rank != -1 and env_local_rank != args.local_rank: 464 | args.local_rank = env_local_rank 465 | 466 | # Sanity checks 467 | if args.dataset_name is None and args.train_data_dir is None: 468 | raise ValueError("Need either a dataset name or a training folder.") 469 | 470 | return args 471 | 472 | 473 | DATASET_NAME_MAPPING = { 474 | "lambdalabs/pokemon-blip-captions": ("image", "text"), 475 | } 476 | 477 | 478 | def main(): 479 | args = parse_args() 480 | logging_dir = os.path.join(args.output_dir, args.logging_dir) 481 | shutil.rmtree(args.output_dataset_name, ignore_errors=True) 482 | shutil.rmtree(args.output_dir, ignore_errors=True) 483 | os.makedirs(args.output_dir) 484 | args.dataset_name = [os.path.join(args.dataset_name, x) for x in os.listdir(args.dataset_name)] 485 | 486 | print('All input images:', args.dataset_name) 487 | prepare_dataset(args.dataset_name, args.output_dataset_name) 488 | ## Our data process fn 489 | data_process_fn(input_img_dir=args.output_dataset_name, use_data_process=True) 490 | args.dataset_name = args.output_dataset_name + '_labeled' 491 | 492 | accelerator_project_config = ProjectConfiguration( 493 | total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir 494 | ) 495 | 496 | accelerator = Accelerator( 497 | gradient_accumulation_steps=args.gradient_accumulation_steps, 498 | mixed_precision=args.mixed_precision, 499 | log_with=args.report_to, 500 | project_config=accelerator_project_config, 501 | ) 502 | if args.report_to == "wandb": 503 | if not is_wandb_available(): 504 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.") 505 | import wandb 506 | 507 | # Make one log on every process with the configuration for debugging. 508 | logging.basicConfig( 509 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 510 | datefmt="%m/%d/%Y %H:%M:%S", 511 | level=logging.INFO, 512 | ) 513 | logger.info(accelerator.state, main_process_only=False) 514 | if accelerator.is_local_main_process: 515 | datasets.utils.logging.set_verbosity_warning() 516 | transformers.utils.logging.set_verbosity_warning() 517 | diffusers.utils.logging.set_verbosity_info() 518 | else: 519 | datasets.utils.logging.set_verbosity_error() 520 | transformers.utils.logging.set_verbosity_error() 521 | diffusers.utils.logging.set_verbosity_error() 522 | 523 | # If passed along, set the training seed now. 
524 | if args.seed is not None: 525 | set_seed(args.seed) 526 | 527 | # Handle the repository creation 528 | if accelerator.is_main_process: 529 | if args.output_dir is not None: 530 | os.makedirs(args.output_dir, exist_ok=True) 531 | 532 | if args.push_to_hub: 533 | repo_id = create_repo( 534 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token 535 | ).repo_id 536 | 537 | ## Download foundation Model 538 | model_dir = snapshot_download(args.pretrained_model_name_or_path, revision=args.revision) 539 | 540 | if args.sub_path is not None and len(args.sub_path) > 0: 541 | model_dir = os.path.join(model_dir, args.sub_path) 542 | 543 | # Load scheduler, tokenizer and models. 544 | noise_scheduler = DDPMScheduler.from_pretrained(model_dir, subfolder="scheduler") 545 | tokenizer = CLIPTokenizer.from_pretrained( 546 | model_dir, subfolder="tokenizer" 547 | ) 548 | text_encoder = CLIPTextModel.from_pretrained( 549 | model_dir, subfolder="text_encoder" 550 | ) 551 | vae = AutoencoderKL.from_pretrained(model_dir, subfolder="vae") 552 | unet = UNet2DConditionModel.from_pretrained( 553 | model_dir, subfolder="unet" 554 | ) 555 | 556 | # For mixed precision training we cast the text_encoder and vae weights to half-precision 557 | # as these models are only used for inference, keeping weights in full precision is not required. 558 | weight_dtype = torch.float32 559 | if accelerator.mixed_precision == "fp16": 560 | weight_dtype = torch.float16 561 | elif accelerator.mixed_precision == "bf16": 562 | weight_dtype = torch.bfloat16 563 | 564 | if args.use_peft: 565 | from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict 566 | 567 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"] 568 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"] 569 | 570 | config = LoraConfig( 571 | r=args.lora_r, 572 | lora_alpha=args.lora_alpha, 573 | target_modules=UNET_TARGET_MODULES, 574 | lora_dropout=args.lora_dropout, 575 | bias=args.lora_bias, 576 | ) 577 | unet = LoraModel(config, unet) 578 | 579 | vae.requires_grad_(False) 580 | if args.train_text_encoder: 581 | config = LoraConfig( 582 | r=args.lora_text_encoder_r, 583 | lora_alpha=args.lora_text_encoder_alpha, 584 | target_modules=TEXT_ENCODER_TARGET_MODULES, 585 | lora_dropout=args.lora_text_encoder_dropout, 586 | bias=args.lora_text_encoder_bias, 587 | ) 588 | text_encoder = LoraModel(config, text_encoder) 589 | else: 590 | # freeze parameters of models to save more memory 591 | unet.requires_grad_(False) 592 | vae.requires_grad_(False) 593 | 594 | text_encoder.requires_grad_(False) 595 | 596 | # now we will add new LoRA weights to the attention layers 597 | # It's important to realize here how many attention weights will be added and of which sizes 598 | # The sizes of the attention layers consist only of two different variables: 599 | # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`. 600 | # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`. 601 | 602 | # Let's first see how many attention processors we will have to set. 
603 | # For Stable Diffusion, it should be equal to: 604 | # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12 605 | # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2 606 | # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18 607 | # => 32 layers 608 | 609 | # Set correct lora layers 610 | lora_attn_procs = {} 611 | for name in unet.attn_processors.keys(): 612 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 613 | if name.startswith("mid_block"): 614 | hidden_size = unet.config.block_out_channels[-1] 615 | elif name.startswith("up_blocks"): 616 | block_id = int(name[len("up_blocks.")]) 617 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 618 | elif name.startswith("down_blocks"): 619 | block_id = int(name[len("down_blocks.")]) 620 | hidden_size = unet.config.block_out_channels[block_id] 621 | 622 | lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) 623 | 624 | unet.set_attn_processor(lora_attn_procs) 625 | lora_layers = AttnProcsLayers(unet.attn_processors) 626 | 627 | # Move unet, vae and text_encoder to device and cast to weight_dtype 628 | vae.to(accelerator.device, dtype=weight_dtype) 629 | if not args.train_text_encoder: 630 | text_encoder.to(accelerator.device, dtype=weight_dtype) 631 | 632 | if args.enable_xformers_memory_efficient_attention: 633 | if is_xformers_available(): 634 | import xformers 635 | 636 | xformers_version = version.parse(xformers.__version__) 637 | if xformers_version == version.parse("0.0.16"): 638 | logger.warn( 639 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 640 | ) 641 | unet.enable_xformers_memory_efficient_attention() 642 | else: 643 | raise ValueError("xformers is not available. Make sure it is installed correctly") 644 | 645 | # Enable TF32 for faster training on Ampere GPUs, 646 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices 647 | if args.allow_tf32: 648 | torch.backends.cuda.matmul.allow_tf32 = True 649 | 650 | if args.scale_lr: 651 | args.learning_rate = ( 652 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes 653 | ) 654 | 655 | # Initialize the optimizer 656 | if args.use_8bit_adam: 657 | try: 658 | import bitsandbytes as bnb 659 | except ImportError: 660 | raise ImportError( 661 | "Please install bitsandbytes to use 8-bit Adam.
You can do so by running `pip install bitsandbytes`" 662 | ) 663 | 664 | optimizer_cls = bnb.optim.AdamW8bit 665 | else: 666 | optimizer_cls = torch.optim.AdamW 667 | 668 | if args.use_peft: 669 | # Optimizer creation 670 | params_to_optimize = ( 671 | itertools.chain(unet.parameters(), text_encoder.parameters()) 672 | if args.train_text_encoder 673 | else unet.parameters() 674 | ) 675 | optimizer = optimizer_cls( 676 | params_to_optimize, 677 | lr=args.learning_rate, 678 | betas=(args.adam_beta1, args.adam_beta2), 679 | weight_decay=args.adam_weight_decay, 680 | eps=args.adam_epsilon, 681 | ) 682 | else: 683 | optimizer = optimizer_cls( 684 | lora_layers.parameters(), 685 | lr=args.learning_rate, 686 | betas=(args.adam_beta1, args.adam_beta2), 687 | weight_decay=args.adam_weight_decay, 688 | eps=args.adam_epsilon, 689 | ) 690 | 691 | # Get the datasets: you can either provide your own training and evaluation files (see below) 692 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). 693 | 694 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently 695 | # download the dataset. 696 | if args.dataset_name is not None: 697 | # Downloading and loading a dataset from the hub. 698 | dataset = load_dataset( 699 | args.dataset_name, 700 | args.dataset_config_name, 701 | cache_dir=args.cache_dir, 702 | ) 703 | else: 704 | data_files = {} 705 | if args.train_data_dir is not None: 706 | data_files["train"] = os.path.join(args.train_data_dir, "**") 707 | dataset = load_dataset( 708 | "imagefolder", 709 | data_files=data_files, 710 | cache_dir=args.cache_dir, 711 | ) 712 | # See more about loading custom images at 713 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder 714 | 715 | # Preprocessing the datasets. 716 | # We need to tokenize inputs and targets. 717 | column_names = dataset["train"].column_names 718 | 719 | # 6. Get the column names for input/target. 720 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) 721 | if args.image_column is None: 722 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 723 | else: 724 | image_column = args.image_column 725 | if image_column not in column_names: 726 | raise ValueError( 727 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" 728 | ) 729 | if args.caption_column is None: 730 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 731 | else: 732 | caption_column = args.caption_column 733 | if caption_column not in column_names: 734 | raise ValueError( 735 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" 736 | ) 737 | 738 | # Preprocessing the datasets. 739 | # We need to tokenize input captions and transform the images. 740 | def tokenize_captions(examples, is_train=True): 741 | captions = [] 742 | for caption in examples[caption_column]: 743 | if isinstance(caption, str): 744 | captions.append(caption) 745 | elif isinstance(caption, (list, np.ndarray)): 746 | # take a random caption if there are multiple 747 | captions.append(random.choice(caption) if is_train else caption[0]) 748 | else: 749 | raise ValueError( 750 | f"Caption column `{caption_column}` should contain either strings or lists of strings." 
751 | ) 752 | inputs = tokenizer( 753 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt" 754 | ) 755 | return inputs.input_ids 756 | 757 | # Preprocessing the datasets. 758 | train_transforms = transforms.Compose( 759 | [ 760 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), 761 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), 762 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), 763 | transforms.ToTensor(), 764 | transforms.Normalize([0.5], [0.5]), 765 | ] 766 | ) 767 | 768 | def preprocess_train(examples): 769 | images = [image.convert("RGB") for image in examples[image_column]] 770 | examples["pixel_values"] = [train_transforms(image) for image in images] 771 | examples["input_ids"] = tokenize_captions(examples) 772 | return examples 773 | 774 | with accelerator.main_process_first(): 775 | if args.max_train_samples is not None: 776 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) 777 | # Set the training transforms 778 | train_dataset = dataset["train"].with_transform(preprocess_train) 779 | 780 | def collate_fn(examples): 781 | pixel_values = torch.stack([example["pixel_values"] for example in examples]) 782 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() 783 | input_ids = torch.stack([example["input_ids"] for example in examples]) 784 | return {"pixel_values": pixel_values, "input_ids": input_ids} 785 | 786 | # DataLoaders creation: 787 | train_dataloader = torch.utils.data.DataLoader( 788 | train_dataset, 789 | shuffle=True, 790 | collate_fn=collate_fn, 791 | batch_size=args.train_batch_size, 792 | num_workers=args.dataloader_num_workers, 793 | ) 794 | 795 | # Scheduler and math around the number of training steps. 796 | overrode_max_train_steps = False 797 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 798 | if args.max_train_steps is None: 799 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 800 | overrode_max_train_steps = True 801 | 802 | lr_scheduler = get_scheduler( 803 | args.lr_scheduler, 804 | optimizer=optimizer, 805 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 806 | num_training_steps=args.max_train_steps * accelerator.num_processes, 807 | ) 808 | 809 | # Prepare everything with our `accelerator`. 810 | if args.use_peft: 811 | if args.train_text_encoder: 812 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 813 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler 814 | ) 815 | else: 816 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 817 | unet, optimizer, train_dataloader, lr_scheduler 818 | ) 819 | else: 820 | lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 821 | lora_layers, optimizer, train_dataloader, lr_scheduler 822 | ) 823 | unet = unet.cuda() 824 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 
825 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 826 | if overrode_max_train_steps: 827 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 828 | # Afterwards we recalculate our number of training epochs 829 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 830 | 831 | # We need to initialize the trackers we use, and also store our configuration. 832 | # The trackers initializes automatically on the main process. 833 | if accelerator.is_main_process: 834 | accelerator.init_trackers("text2image-fine-tune", config=vars(args)) 835 | 836 | # Train! 837 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 838 | 839 | logger.info("***** Running training *****") 840 | logger.info(f" Num examples = {len(train_dataset)}") 841 | logger.info(f" Num Epochs = {args.num_train_epochs}") 842 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 843 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 844 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 845 | logger.info(f" Total optimization steps = {args.max_train_steps}") 846 | global_step = 0 847 | first_epoch = 0 848 | 849 | # Potentially load in the weights and states from a previous save 850 | if args.resume_from_checkpoint: 851 | if args.resume_from_checkpoint != "latest": 852 | path = os.path.basename(args.resume_from_checkpoint) 853 | else: 854 | # Get the most recent checkpoint 855 | dirs = os.listdir(args.output_dir) 856 | dirs = [d for d in dirs if d.startswith("checkpoint")] 857 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) 858 | path = dirs[-1] if len(dirs) > 0 else None 859 | 860 | if path is None: 861 | accelerator.print( 862 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." 863 | ) 864 | args.resume_from_checkpoint = None 865 | else: 866 | accelerator.print(f"Resuming from checkpoint {path}") 867 | accelerator.load_state(os.path.join(args.output_dir, path)) 868 | global_step = int(path.split("-")[1]) 869 | 870 | resume_global_step = global_step * args.gradient_accumulation_steps 871 | first_epoch = global_step // num_update_steps_per_epoch 872 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps) 873 | 874 | # Only show the progress bar once on each machine. 
875 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) 876 | progress_bar.set_description("Steps") 877 | 878 | for epoch in range(first_epoch, args.num_train_epochs): 879 | unet.train() 880 | if args.train_text_encoder: 881 | text_encoder.train() 882 | train_loss = 0.0 883 | for step, batch in enumerate(train_dataloader): 884 | # Skip steps until we reach the resumed step 885 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: 886 | if step % args.gradient_accumulation_steps == 0: 887 | progress_bar.update(1) 888 | continue 889 | 890 | with accelerator.accumulate(unet): 891 | # Convert images to latent space 892 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() 893 | latents = latents * vae.config.scaling_factor 894 | 895 | # Sample noise that we'll add to the latents 896 | noise = torch.randn_like(latents) 897 | bsz = latents.shape[0] 898 | # Sample a random timestep for each image 899 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) 900 | timesteps = timesteps.long() 901 | 902 | # Add noise to the latents according to the noise magnitude at each timestep 903 | # (this is the forward diffusion process) 904 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 905 | 906 | # Get the text embedding for conditioning 907 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] 908 | 909 | # Get the target for loss depending on the prediction type 910 | if noise_scheduler.config.prediction_type == "epsilon": 911 | target = noise 912 | elif noise_scheduler.config.prediction_type == "v_prediction": 913 | target = noise_scheduler.get_velocity(latents, noise, timesteps) 914 | else: 915 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") 916 | 917 | # Predict the noise residual and compute loss 918 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 919 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") 920 | 921 | # Gather the losses across all processes for logging (if we use distributed training). 
922 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean() 923 | train_loss += avg_loss.item() / args.gradient_accumulation_steps 924 | 925 | # Backpropagate 926 | accelerator.backward(loss) 927 | if accelerator.sync_gradients: 928 | if args.use_peft: 929 | params_to_clip = ( 930 | itertools.chain(unet.parameters(), text_encoder.parameters()) 931 | if args.train_text_encoder 932 | else unet.parameters() 933 | ) 934 | else: 935 | params_to_clip = lora_layers.parameters() 936 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) 937 | optimizer.step() 938 | lr_scheduler.step() 939 | optimizer.zero_grad() 940 | 941 | # Checks if the accelerator has performed an optimization step behind the scenes 942 | if accelerator.sync_gradients: 943 | progress_bar.update(1) 944 | global_step += 1 945 | accelerator.log({"train_loss": train_loss}, step=global_step) 946 | train_loss = 0.0 947 | 948 | if global_step % args.checkpointing_steps == 0: 949 | if accelerator.is_main_process: 950 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 951 | accelerator.save_state(save_path) 952 | logger.info(f"Saved state to {save_path}") 953 | 954 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 955 | progress_bar.set_postfix(**logs) 956 | 957 | if global_step >= args.max_train_steps: 958 | break 959 | 960 | if accelerator.is_main_process: 961 | if args.validation_prompt is not None and epoch % args.validation_epochs == 0: 962 | logger.info( 963 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:" 964 | f" {args.validation_prompt}." 965 | ) 966 | # create pipeline 967 | pipeline = DiffusionPipeline.from_pretrained( 968 | model_dir, 969 | unet=accelerator.unwrap_model(unet), 970 | text_encoder=accelerator.unwrap_model(text_encoder), 971 | torch_dtype=weight_dtype, 972 | ) 973 | pipeline = pipeline.to(accelerator.device) 974 | pipeline.set_progress_bar_config(disable=True) 975 | 976 | # run inference 977 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 978 | images = [] 979 | for _ in range(args.num_validation_images): 980 | images.append( 981 | pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0] 982 | ) 983 | 984 | if accelerator.is_main_process: 985 | for tracker in accelerator.trackers: 986 | if tracker.name == "tensorboard": 987 | np_images = np.stack([np.asarray(img) for img in images]) 988 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") 989 | if tracker.name == "wandb": 990 | tracker.log( 991 | { 992 | "validation": [ 993 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}") 994 | for i, image in enumerate(images) 995 | ] 996 | } 997 | ) 998 | 999 | del pipeline 1000 | torch.cuda.empty_cache() 1001 | 1002 | # Save the lora layers 1003 | accelerator.wait_for_everyone() 1004 | if accelerator.is_main_process: 1005 | if args.use_peft: 1006 | lora_config = {} 1007 | unwarpped_unet = accelerator.unwrap_model(unet) 1008 | state_dict = get_peft_model_state_dict(unwarpped_unet, state_dict=accelerator.get_state_dict(unet)) 1009 | lora_config["peft_config"] = unwarpped_unet.get_peft_config_as_dict(inference=True) 1010 | if args.train_text_encoder: 1011 | unwarpped_text_encoder = accelerator.unwrap_model(text_encoder) 1012 | text_encoder_state_dict = get_peft_model_state_dict( 1013 | unwarpped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder) 1014 | ) 1015 | text_encoder_state_dict = 
{f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} 1016 | state_dict.update(text_encoder_state_dict) 1017 | lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict( 1018 | inference=True 1019 | ) 1020 | 1021 | accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt")) 1022 | with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f: 1023 | json.dump(lora_config, f) 1024 | else: 1025 | unet = unet.to(torch.float32) 1026 | unet.save_attn_procs(args.output_dir) 1027 | 1028 | if args.push_to_hub: 1029 | save_model_card( 1030 | repo_id, 1031 | images=images, 1032 | base_model=model_dir, 1033 | dataset_name=args.dataset_name, 1034 | repo_folder=args.output_dir, 1035 | ) 1036 | upload_folder( 1037 | repo_id=repo_id, 1038 | folder_path=args.output_dir, 1039 | commit_message="End of training", 1040 | ignore_patterns=["step_*", "epoch_*"], 1041 | ) 1042 | 1043 | # Final inference 1044 | # Load previous pipeline 1045 | pipeline = DiffusionPipeline.from_pretrained( 1046 | model_dir, torch_dtype=weight_dtype 1047 | ) 1048 | 1049 | if args.use_peft: 1050 | 1051 | def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype): 1052 | with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f: 1053 | lora_config = json.load(f) 1054 | print(lora_config) 1055 | 1056 | checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt") 1057 | lora_checkpoint_sd = torch.load(checkpoint) 1058 | unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} 1059 | text_encoder_lora_ds = { 1060 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k 1061 | } 1062 | 1063 | unet_config = LoraConfig(**lora_config["peft_config"]) 1064 | pipe.unet = LoraModel(unet_config, pipe.unet) 1065 | set_peft_model_state_dict(pipe.unet, unet_lora_ds) 1066 | 1067 | if "text_encoder_peft_config" in lora_config: 1068 | text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"]) 1069 | pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder) 1070 | set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds) 1071 | 1072 | if dtype in (torch.float16, torch.bfloat16): 1073 | pipe.unet.half() 1074 | pipe.text_encoder.half() 1075 | 1076 | pipe.to(device) 1077 | return pipe 1078 | 1079 | pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype) 1080 | 1081 | else: 1082 | pipeline = pipeline.to(accelerator.device) 1083 | # load attention processors 1084 | pipeline.unet.load_attn_procs(args.output_dir) 1085 | 1086 | # run inference 1087 | if args.seed is not None: 1088 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) 1089 | else: 1090 | generator = None 1091 | images = [] 1092 | 1093 | accelerator.end_training() 1094 | 1095 | 1096 | if __name__ == "__main__": 1097 | main() 1098 | --------------------------------------------------------------------------------