├── facechain
│   ├── __init__.py
│   ├── data_process
│   │   ├── __init__.py
│   │   ├── preprocessing.py
│   │   ├── preprocessing -win10.py
│   │   └── deepbooru.py
│   ├── merge_lora.py
│   ├── inference.py
│   └── train_text_to_image_lora.py
├── .gitattributes
├── resources
│   ├── example1.jpg
│   ├── example2.jpg
│   ├── example3.jpg
│   ├── framework.jpg
│   └── framework_eng.jpg
├── requirements.txt
├── train_lora.sh
├── train_lora.bat
├── run_inference.py
├── README_ZH.md
├── README.md
├── app.py
├── app-win10.py
└── LICENSE
/facechain/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 |
--------------------------------------------------------------------------------
/facechain/data_process/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 |
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.safetensors filter=lfs diff=lfs merge=lfs -text
2 | *.jpg filter=lfs diff=lfs merge=lfs -text
3 |
--------------------------------------------------------------------------------
/resources/example1.jpg:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:c773cdd97a6ab7dc8fe4075f216eefad71043b5f099fcf8399ff00c1a0bed2cc
3 | size 83419
4 |
--------------------------------------------------------------------------------
/resources/example2.jpg:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:637efd55906eeefc27dc76f97049424ff3cfc96eb3ab3089fc8f325ebda5f857
3 | size 51409
4 |
--------------------------------------------------------------------------------
/resources/example3.jpg:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:896bcd46629910602709ae94602c2ba99d239ce9c6eb4790a05111c4df7eaa6e
3 | size 131820
4 |
--------------------------------------------------------------------------------
/resources/framework.jpg:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:1ef2305413f33d93bac369836231f4a79a4df2463086b118ca1584b115c2e2b0
3 | size 476295
4 |
--------------------------------------------------------------------------------
/resources/framework_eng.jpg:
--------------------------------------------------------------------------------
1 | version https://git-lfs.github.com/spec/v1
2 | oid sha256:18adcd55c884e59d70addb717b52aa961369caa18fc5436115dbdc1f58614086
3 | size 487479
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | transformers
3 | diffusers
4 | onnxruntime
5 | modelscope[framework]
6 | Pillow
7 | opencv-python
8 | torchvision
9 | mmdet==2.26.0
10 | mmengine
11 | tensorflow==2.7.0
12 | #tensorflow-cpu # slower but no cuda conflicts
13 | numpy==1.22.0
14 | protobuf==3.20.1
15 | timm
16 | scikit-image
17 | gradio
18 |
19 | # mmcv-full (need mim install)
20 |
--------------------------------------------------------------------------------
/train_lora.sh:
--------------------------------------------------------------------------------
1 | export MODEL_NAME=$1
2 | export VERSION=$2
3 | export SUB_PATH=$3
4 | export DATASET_NAME=$4
5 | export OUTPUT_DATASET_NAME=$5
6 | export WORK_DIR=$6
7 |
8 | accelerate launch facechain/train_text_to_image_lora.py \
9 | --pretrained_model_name_or_path=$MODEL_NAME \
10 | --revision=$VERSION \
11 | --sub_path=$SUB_PATH \
12 | --dataset_name=$DATASET_NAME \
13 | --output_dataset_name=$OUTPUT_DATASET_NAME \
14 | --caption_column="text" \
15 | --resolution=512 --random_flip \
16 | --train_batch_size=1 \
17 | --num_train_epochs=200 --checkpointing_steps=5000 \
18 | --learning_rate=1e-04 --lr_scheduler="cosine" --lr_warmup_steps=0 \
19 | --seed=42 \
20 | --output_dir=$WORK_DIR \
21 | --lora_r=32 --lora_alpha=32 \
22 | --lora_text_encoder_r=32 --lora_text_encoder_alpha=32
23 |
24 |
25 |
--------------------------------------------------------------------------------
/train_lora.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | set MODEL_NAME=%1
3 | set VERSION=%2
4 | set SUB_PATH=%3
5 | set DATASET_NAME=%4
6 | set OUTPUT_DATASET_NAME=%5
7 | set WORK_DIR=%6
8 |
9 | accelerate launch facechain/train_text_to_image_lora.py ^
10 | --pretrained_model_name_or_path=%MODEL_NAME% ^
11 | --revision=%VERSION% ^
12 | --sub_path=%SUB_PATH% ^
13 | --dataset_name=%DATASET_NAME% ^
14 | --output_dataset_name=%OUTPUT_DATASET_NAME% ^
15 | --caption_column="text" ^
16 | --resolution=512 --random_flip ^
17 | --train_batch_size=1 ^
18 | --num_train_epochs=200 --checkpointing_steps=5000 ^
19 | --learning_rate=1e-04 --lr_scheduler="cosine" --lr_warmup_steps=0 ^
20 | --seed=42 ^
21 | --output_dir=%WORK_DIR% ^
22 | --lora_r=32 --lora_alpha=32 ^
23 | --lora_text_encoder_r=32 --lora_text_encoder_alpha=32
24 |
--------------------------------------------------------------------------------
/run_inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 | import os
3 |
4 | from facechain.inference import GenPortrait
5 | import cv2
6 |
7 | use_main_model = True
8 | use_face_swap = True
9 | use_post_process = True
10 | use_stylization = False
11 |
12 | gen_portrait = GenPortrait(use_main_model, use_face_swap, use_post_process,
13 | use_stylization)
14 |
15 | processed_dir = './processed'
16 | num_generate = 5
17 | base_model = 'ly261666/cv_portrait_model'
18 | revision = 'v2.0'
19 | base_model_sub_dir = 'film/film'
20 | train_output_dir = './output'
21 | output_dir = './generated'
22 |
23 | outputs = gen_portrait(processed_dir, num_generate, base_model,
24 | train_output_dir, base_model_sub_dir, revision)
25 |
26 | os.makedirs(output_dir, exist_ok=True)
27 |
28 | for i, out_tmp in enumerate(outputs):
29 | cv2.imwrite(os.path.join(output_dir, f'{i}.png'), out_tmp)
30 |
--------------------------------------------------------------------------------
/facechain/merge_lora.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 | import torch
3 | import os
4 | import re
5 | from collections import defaultdict
6 | from safetensors.torch import load_file
7 |
8 |
9 | def merge_lora(pipeline, lora_path, multiplier, from_safetensor=False, device='cpu', dtype=torch.float32):
10 | LORA_PREFIX_UNET = "lora_unet"
11 | LORA_PREFIX_TEXT_ENCODER = "lora_te"
12 | if from_safetensor:
13 | state_dict = load_file(lora_path, device=device)
14 | else:
15 | checkpoint = torch.load(os.path.join(lora_path, 'pytorch_lora_weights.bin'), map_location=torch.device(device))
16 | new_dict = dict()
17 | for idx, key in enumerate(checkpoint):
18 | new_key = re.sub(r'\.processor\.', '_', key)
19 | new_key = re.sub(r'mid_block\.', 'mid_block_', new_key)
20 | new_key = re.sub('_lora.up.', '.lora_up.', new_key)
21 | new_key = re.sub('_lora.down.', '.lora_down.', new_key)
22 | new_key = re.sub(r'\.(\d+)\.', '_\\1_', new_key)
23 | new_key = re.sub('to_out', 'to_out_0', new_key)
24 | new_key = 'lora_unet_' + new_key
25 | new_dict[new_key] = checkpoint[key]
26 | state_dict = new_dict
27 | updates = defaultdict(dict)
28 | for key, value in state_dict.items():
29 | layer, elem = key.split('.', 1)
30 | updates[layer][elem] = value
31 |
32 | for layer, elems in updates.items():
33 |
34 | if "text" in layer:
35 | layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
36 | curr_layer = pipeline.text_encoder
37 | else:
38 | layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
39 | curr_layer = pipeline.unet
40 |
41 | temp_name = layer_infos.pop(0)
42 | while len(layer_infos) > -1:
43 | try:
44 | curr_layer = curr_layer.__getattr__(temp_name)
45 | if len(layer_infos) > 0:
46 | temp_name = layer_infos.pop(0)
47 | elif len(layer_infos) == 0:
48 | break
49 | except Exception:
50 | if len(layer_infos) == 0:
51 | print('Error loading layer')
52 | if len(temp_name) > 0:
53 | temp_name += "_" + layer_infos.pop(0)
54 | else:
55 | temp_name = layer_infos.pop(0)
56 |
57 | weight_up = elems['lora_up.weight'].to(dtype)
58 | weight_down = elems['lora_down.weight'].to(dtype)
59 | if 'alpha' in elems.keys():
60 | alpha = elems['alpha'].item() / weight_up.shape[1]
61 | else:
62 | alpha = 1.0
63 |
64 | curr_layer.weight.data = curr_layer.weight.data.to(device)
65 | if len(weight_up.shape) == 4:
66 | curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2),
67 | weight_down.squeeze(3).squeeze(2)).unsqueeze(
68 | 2).unsqueeze(3)
69 | else:
70 | curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
71 |
72 | return pipeline
73 |
--------------------------------------------------------------------------------
/README_ZH.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | FaceChain
6 |
7 |
8 |
9 |
10 | # 介绍
11 |
12 | FaceChain是一个可以用来打造个人数字形象的深度学习模型工具。用户仅需要提供最低三张照片即可获得独属于自己的个人形象数字替身。FaceChain支持在gradio的界面中使用模型训练和推理能力,也支持资深开发者使用python脚本进行训练推理。同时,FaceChain欢迎开发者对本Repo进行继续开发和贡献。
13 |
14 | 您也可以在[ModelScope创空间](https://modelscope.cn/studios/CVstudio/cv_human_portrait/summary)中直接体验这项技术而无需安装任何软件。
15 |
16 | FaceChain的模型由[ModelScope](https://github.com/modelscope/modelscope)开源模型社区提供支持。
17 |
18 | 
19 |
20 | 
21 |
22 | 
23 |
24 | # 安装
25 |
26 | 您可以使用pip和conda搭建本地python环境,我们推荐使用[Anaconda](https://docs.anaconda.com/anaconda/install/)来管理您的依赖,安装完成后,执行如下命令:
27 |
28 | ```shell
29 | conda create -n facechain python=3.8 # python version >= 3.8
30 | conda activate facechain
31 |
32 | pip3 install -r requirements.txt
33 | pip3 install -U openmim
34 | mim install mmcv-full==1.7.0
35 | ```
36 |
37 | 或者,您可以使用ModelScope提供的官方镜像,这样您只需要安装gradio即可使用:
38 |
39 | ```shell
40 | registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.7.1-py38-torch2.0.1-tf1.15.5-1.8.0
41 | ```
42 |
43 | 我们也推荐使用我们的[notebook](https://www.modelscope.cn/my/mynotebook/preset)来进行训练和推理。
44 |
45 | 将本仓库克隆到本地:
46 |
47 | ```shell
48 | GIT_LFS_SKIP_SMUDGE=1 git clone https://github.com/modelscope/facechain.git
49 | cd facechain
50 | ```
51 |
52 | 安装依赖:
53 |
54 | ```shell
55 | # 如果使用了官方镜像,只需要执行
56 | pip3 install gradio
57 |
58 | # 如果使用conda虚拟环境,则参考上述“安装”章节
59 | ```
60 |
61 |
62 | 运行gradio来生成个人数字形象:
63 |
64 | ```shell
65 | python app.py
66 | ```
67 |
68 | 您可以看到log中的gradio启动日志,等待展示出http链接后,将http链接复制到浏览器中进行访问。之后在页面中点击“选择图片上传”,并选择最少一张包含人脸的图片。点击“开始训练”即可训练模型。训练完成后日志中会有对应展示,之后切换到“形象体验”标签页点击“开始推理”即可生成属于自己的数字形象。
69 |
70 | # 脚本运行
71 |
72 | FaceChain支持在python环境中直接进行训练和推理。在克隆后的文件夹中直接运行如下命令来进行训练:
73 |
74 | ```shell
75 | PYTHONPATH=. sh train_lora.sh "ly261666/cv_portrait_model" "v2.0" "film/film" "./imgs" "./processed" "./output"
76 | ```
77 |
78 | 参数含义:
79 |
80 | ```text
81 | ly261666/cv_portrait_model: ModelScope模型仓库的stable diffusion基模型,该模型会用于训练,可以不修改
82 | v2.0: 该基模型的版本号,可以不修改
83 | film/film: 该基模型包含了多个不同风格的子目录,其中使用了film/film目录中的风格模型,可以不修改
84 | ./imgs: 本参数需要用实际值替换,本参数是一个本地文件目录,包含了用来训练和生成的原始照片
85 | ./processed: 预处理之后的图片文件夹,这个参数需要在推理中被传入相同的值,可以不修改
86 | ./output: 训练生成保存模型weights的文件夹,可以不修改
87 | ```
88 |
89 | 等待5-20分钟即可训练完成。用户也可以调节其他训练超参数,训练支持的超参数可以查看`train_lora.sh`的配置,或者`facechain/train_text_to_image_lora.py`中的完整超参数列表。
90 |
91 | 进行推理时,请编辑run_inference.py中的代码:
92 |
93 | ```python
94 | # 填入上述的预处理之后的图片文件夹,需要和训练时相同
95 | processed_dir = './processed'
96 | # 推理生成的图片数量
97 | num_generate = 5
98 | # 训练时使用的stable diffusion基模型,可以不修改
99 | base_model = 'ly261666/cv_portrait_model'
100 | # 该基模型的版本号,可以不修改
101 | revision = 'v2.0'
102 | # 该基模型包含了多个不同风格的子目录,其中使用了film/film目录中的风格模型,可以不修改
103 | base_model_sub_dir = 'film/film'
104 | # 训练生成保存模型weights的文件夹,需要保证和训练时相同
105 | train_output_dir = './output'
106 | # 指定一个保存生成的图片的文件夹,本参数可以根据需要修改
107 | output_dir = './generated'
108 | ```
109 |
110 | 之后执行:
111 |
112 | ```shell
113 | python run_inference.py
114 | ```
115 |
116 | 即可在`output_dir`中找到生成的个人数字形象照片。
117 |
118 | # 算法介绍
119 |
120 | ## 基本原理
121 |
122 | 个人写真模型的能力来源于Stable Diffusion模型的文生图功能,输入一段文本或一系列提示词,输出对应的图像。我们考虑影响个人写真生成效果的主要因素:写真风格信息,以及用户人物信息。为此,我们分别使用线下训练的风格LoRA模型和线上训练的人脸LoRA模型以学习上述信息。LoRA是一种具有较少可训练参数的微调模型,在Stable Diffusion中,可以通过对少量输入图像进行文生图训练的方式将输入图像的信息注入到LoRA模型中。因此,个人写真模型的能力分为训练与推断两个阶段,训练阶段生成用于微调Stable Diffusion模型的图像与文本标签数据,得到人脸LoRA模型;推断阶段基于人脸LoRA模型和风格LoRA模型生成个人写真图像。
123 |
124 | 
125 |
126 | ## 训练阶段
127 |
128 | 输入:用户上传的包含清晰人脸区域的图像
129 |
130 | 输出:人脸LoRA模型
131 |
132 | 描述:首先,我们分别使用基于朝向判断的图像旋转模型,以及基于人脸检测和关键点模型的人脸精细化旋转方法处理用户上传图像,得到包含正向人脸的图像;接下来,我们使用人体解析模型和人像美肤模型,以获得高质量的人脸训练图像;随后,我们使用人脸属性模型和文本标注模型,结合标签后处理方法,产生训练图像的精细化标签;最后,我们使用上述图像和标签数据微调Stable Diffusion模型得到人脸LoRA模型。
133 |
134 | ## 推断阶段
135 |
136 | 输入:训练阶段用户上传图像,预设的用于生成个人写真的输入提示词
137 |
138 | 输出:个人写真图像
139 |
140 | 描述:首先,我们将人脸LoRA模型和风格LoRA模型的权重融合到Stable Diffusion模型中;接下来,我们使用Stable Diffusion模型的文生图功能,基于预设的输入提示词初步生成个人写真图像;随后,我们使用人脸融合模型进一步改善上述写真图像的人脸细节,其中用于融合的模板人脸通过人脸质量评估模型在训练图像中挑选;最后,我们使用人脸识别模型计算生成的写真图像与模板人脸的相似度,以此对写真图像进行排序,并输出排名靠前的个人写真图像作为最终输出结果。下面给出模板人脸挑选步骤的一个简化示例。
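
该示例改写自 `facechain/inference.py` 中的 `select_high_quality_face`,其中目录名 `./imgs_labeled` 仅作示意,并非固定默认值:

```python
# 简化示例:按照 facechain/inference.py 中 select_high_quality_face 的做法,
# 用人脸质量评估模型从预处理后的图片中挑选模板人脸(目录名仅作示意)。
import os

import numpy as np
from PIL import Image
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

labeled_dir = './imgs_labeled'  # 预处理(打标)之后的图片目录
face_quality = pipeline(Tasks.face_quality_assessment, 'damo/cv_manual_face-quality-assessment_fqa')

scores, paths = [], []
for name in os.listdir(labeled_dir):
    if name.endswith('jsonl'):
        continue
    path = os.path.join(labeled_dir, name)
    score = face_quality(path)[OutputKeys.SCORES]
    scores.append(0 if score is None else score[0])
    paths.append(path)

# 选取质量分最高的人脸,作为后续人脸融合步骤的模板
template_face = Image.open(paths[int(np.argsort(scores)[::-1][0])])
```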
141 |
142 | ## 模型列表
143 |
144 | 附(流程图中模型链接)
145 |
146 | [1] 人脸检测+关键点模型DamoFD:https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damofd
147 |
148 | [2] 图像旋转模型:创空间内置模型
149 |
150 | [3] 人体解析模型M2FP:https://modelscope.cn/models/damo/cv_resnet101_image-multiple-human-parsing
151 |
152 | [4] 人像美肤模型ABPN:https://modelscope.cn/models/damo/cv_unet_skin-retouching
153 |
154 | [5] 人脸属性模型FairFace:https://modelscope.cn/models/damo/cv_resnet34_face-attribute-recognition_fairface
155 |
156 | [6] 文本标注模型Deepbooru:https://github.com/KichangKim/DeepDanbooru
157 |
158 | [7] 模板脸筛选模型FQA:https://modelscope.cn/models/damo/cv_manual_face-quality-assessment_fqa
159 |
160 | [8] 人脸融合模型:https://modelscope.cn/models/damo/cv_unet-image-face-fusion_damo
161 |
162 | [9] 人脸识别模型RTS:https://modelscope.cn/models/damo/cv_ir_face-recognition-ood_rts
163 |
164 | # 更多信息
165 |
166 | - [ModelScope library](https://github.com/modelscope/modelscope/)
167 |
168 | ModelScope Library是一个托管于github上的模型生态仓库,隶属于达摩院魔搭项目。
169 |
170 | - [贡献模型到ModelScope](https://modelscope.cn/docs/ModelScope%E6%A8%A1%E5%9E%8B%E6%8E%A5%E5%85%A5%E6%B5%81%E7%A8%8B%E6%A6%82%E8%A7%88)
171 |
172 | # License
173 |
174 | This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
175 |
176 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | FaceChain
6 |
7 |
8 | # Introduction
9 |
10 | 如果您熟悉中文,可以阅读[中文版本的README](./README_ZH.md)。
11 |
12 | FaceChain is a deep-learning toolchain for generating your Digital-Twin. With a minimum of 1 portrait photo, you can create a Digital-Twin of your own and generate personal photos in different settings (work photos as a starter!). You may train your Digital-Twin model and generate photos via FaceChain's Python scripts, or via the familiar Gradio interface. You can also experience FaceChain directly with our [ModelScope Studio](https://modelscope.cn/studios/CVstudio/cv_human_portrait/summary).
13 |
14 | FaceChain is powered by [ModelScope](https://github.com/modelscope/modelscope).
15 |
16 | 
17 |
18 | 
19 |
20 | 
21 |
22 | # Installation
23 |
24 | You may use pip and conda to build a local python environment. We recommend using [Anaconda](https://docs.anaconda.com/anaconda/install/) to manage your dependencies. After installation, execute the following commands:
25 |
26 | ```shell
27 | conda create -n facechain python=3.8 # python version >= 3.8
28 | conda activate facechain
29 |
30 | pip3 install -r requirements.txt
31 | pip3 install -U openmim
32 | mim install mmcv-full==1.7.0
33 | ```
34 |
35 | You may use the official docker-image provided by ModelScope:
36 |
37 | ```shell
38 | registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.7.1-py38-torch2.0.1-tf1.15.5-1.8.0
39 | ```
40 | With this docker image, the only thing you need to install is Gradio.
41 |
42 | For online training and inference, you may leverage the ModelScope [notebook](https://www.modelscope.cn/my/mynotebook/) to start the process immediately.
43 |
44 | Clone the repo:
45 |
46 | ```shell
47 | GIT_LFS_SKIP_SMUDGE=1 git clone https://github.com/modelscope/facechain.git
48 | cd facechain
49 | ```
50 |
51 | Install dependencies:
52 |
53 | ```shell
54 | # If you use the official docker image, you only need to execute
55 | pip3 install gradio
56 |
57 | # If you use the conda env, please refer to section Installation.
58 | ```
59 |
60 | Launch Gradio to generate personal digital images:
61 |
62 | ```shell
63 | python app.py
64 | ```
65 |
66 | Watch the console output for the Gradio startup log. Once the hyperlink is displayed, copy it into your browser to access the page. Then click "Select Image Upload" on the page and select at least one picture containing a face. Click "Start Training" to train the model. When training is complete, a message will appear in the log. Afterward, switch to the "Image Experience" tab and click "Start Inference" to generate your own digital image.
67 |
68 | # Script Execution
69 |
70 | FaceChain supports direct training and inference in the python environment. Run the following command in the cloned folder to start training:
71 |
72 | ```shell
73 | PYTHONPATH=. sh train_lora.sh "ly261666/cv_portrait_model" "v2.0" "film/film" "./imgs" "./processed" "./output"
74 | ```
75 |
76 | Parameter meaning:
77 |
78 | ```text
79 | ly261666/cv_portrait_model: The Stable Diffusion base model on the ModelScope model hub used for training; no need to change it
80 | v2.0: The version number of this base model; no need to change it
81 | film/film: This base model may contain multiple subdirectories of different styles; we currently use film/film; no need to change it
82 | ./imgs: Replace this parameter with the actual value; it is a local directory containing the original photos used for training and generation
83 | ./processed: The folder of the preprocessed images; the same value must be passed during inference; no need to change it
84 | ./output: The folder where the trained model weights are stored; no need to change it
85 | ```
86 |
87 | Training takes about 5-20 minutes. You can also adjust other training hyperparameters: see the flags set in `train_lora.sh`, or the complete hyperparameter list in `facechain/train_text_to_image_lora.py`.
88 |
89 | For inference, edit the code in `run_inference.py`:
90 |
91 | ```python
92 | # The folder of the preprocessed images from above; it must be the same as during training
93 | processed_dir = './processed'
94 | # The number of images to generate in inference
95 | num_generate = 5
96 | # The stable diffusion base model used in training, no need to be changed
97 | base_model = 'ly261666/cv_portrait_model'
98 | # The version number of this base model, no need to be changed
99 | revision = 'v2.0'
100 | # This base model may contain multiple subdirectories of different styles; we currently use film/film, no need to change it
101 | base_model_sub_dir = 'film/film'
102 | # The folder where the trained model weights are stored; it must be the same as during training
103 | train_output_dir = './output'
104 | # Specify a folder to save the generated images, this parameter can be modified as needed
105 | output_dir = './generated'
106 | ```
107 |
108 | Then execute:
109 |
110 | ```shell
111 | python run_inference.py
112 | ```
113 |
114 | You can find the generated personal digital image photos in the `output_dir`.
115 |
116 | # Algorithm Introduction
117 |
118 | ## Principle
119 |
120 | The capability of the personal portrait model comes from the text-to-image function of the Stable Diffusion model: given a text prompt or a series of prompt words, it outputs a corresponding image. We consider the two main factors that affect the quality of generated portraits: portrait style information and user identity information. To learn them, we use a style LoRA model trained offline and a face LoRA model trained online. LoRA is a fine-tuning method with a small number of trainable parameters; in Stable Diffusion, the information of the input images can be injected into a LoRA model through text-to-image training on a small number of input images. The personal portrait pipeline is therefore divided into a training stage and an inference stage. The training stage generates the image and text-label data used to fine-tune the Stable Diffusion model and produces the face LoRA model; the inference stage generates personal portrait images based on the face LoRA model and the style LoRA model.
121 |
122 | 
123 |
124 | ## Training
125 |
126 | Input: User-uploaded images that contain clear face areas
127 |
128 | Output: Face LoRA model
129 |
130 | Description: First, we process the user-uploaded images with an image rotation model based on orientation estimation and a face refinement rotation method based on face detection and keypoint models, obtaining images that contain front-facing faces. Next, we use a human parsing model and a skin retouching model to obtain high-quality face training images. Afterwards, we use a face attribute model and a text annotation model, combined with tag post-processing, to generate fine-grained labels for the training images. Finally, we fine-tune the Stable Diffusion model on these images and labels to obtain the face LoRA model. A sketch of the labeling step follows.
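
The labeling step can also be run on its own; the snippet below is a minimal sketch based on `data_process_fn` in `facechain/inference.py` (the `./imgs` directory name is a placeholder), and the resulting `metadata.jsonl` is what the LoRA fine-tuning consumes:

```python
# A sketch only: run the preprocessing / labeling stage by itself, as data_process_fn
# in facechain/inference.py does. './imgs' is a placeholder for the raw photo folder.
import json
import os

from facechain.data_process.preprocessing import Blipv2

input_img_dir = './imgs'

# Rotation, parsing/retouching and fine-grained tagging; the labeled data and
# metadata.jsonl are written next to the input folder as './imgs_labeled'.
Blipv2()(input_img_dir)

# Each line of metadata.jsonl carries the caption used to fine-tune the face LoRA.
with open(os.path.join(input_img_dir + '_labeled', 'metadata.jsonl')) as f:
    for line in f:
        print(json.loads(line)['text'])
```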
131 |
132 | ## Inference
133 |
134 | Input: User-uploaded images in the training phase, preset input prompt words for generating personal portraits
135 |
136 | Output: Personal portrait image
137 |
138 | Description: First, we fuse the weights of the face LoRA model and the style LoRA model into the Stable Diffusion model. Next, we use the text-to-image function of the Stable Diffusion model to generate preliminary personal portrait images from the preset input prompts. Then we further improve the face details of these portraits with the face fusion model; the template face used for fusion is selected from the training images by the face quality assessment model. Finally, we use the face recognition model to compute the similarity between each generated portrait and the template face, sort the portraits accordingly, and output the top-ranked portraits as the final result, as sketched below.
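
The ranking step can be summarized as follows. This is a condensed sketch of `post_process_fn` in `facechain/inference.py`; the helper name `rank_by_similarity` and the `top_k` argument are illustrative only, not part of the repo:

```python
# A sketch only: rank generated images by face-embedding similarity to the selected
# template face, condensed from post_process_fn in facechain/inference.py.
import numpy as np
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks


def rank_by_similarity(generated_images, template_face, top_k=6):
    # 'rank_by_similarity' and 'top_k' are illustrative names, not part of the repo.
    face_recognition = pipeline(Tasks.face_recognition, 'damo/cv_ir_face-recognition-ood_rts')
    template_emb = face_recognition(template_face)[OutputKeys.IMG_EMBEDDING][0]

    sims = []
    for img in generated_images:
        emb = face_recognition(img)[OutputKeys.IMG_EMBEDDING]
        sims.append(0 if emb is None else np.dot(emb, template_emb).item())

    # Highest similarity first; keep at most top_k images.
    order = np.argsort(sims)[::-1]
    return [generated_images[i] for i in order[:top_k]]
```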
139 |
140 | ## Model List
141 |
142 | The models used in FaceChain:
143 |
144 | [1] Face detection and keypoint model DamoFD: https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damofd
145 |
146 | [2] Image rotating model, offered in the ModelScope studio
147 |
148 | [3] Human parsing model M2FP: https://modelscope.cn/models/damo/cv_resnet101_image-multiple-human-parsing
149 |
150 | [4] Skin retouching model ABPN: https://modelscope.cn/models/damo/cv_unet_skin-retouching
151 |
152 | [5] Face attribute recognition model FairFace: https://modelscope.cn/models/damo/cv_resnet34_face-attribute-recognition_fairface
153 |
154 | [6] Text annotation model DeepDanbooru: https://github.com/KichangKim/DeepDanbooru
155 |
156 | [7] Face quality assessment model FQA: https://modelscope.cn/models/damo/cv_manual_face-quality-assessment_fqa
157 |
158 | [8] Face fusion model: https://modelscope.cn/models/damo/cv_unet-image-face-fusion_damo
159 |
160 | [9] Face recognition model RTS: https://modelscope.cn/models/damo/cv_ir_face-recognition-ood_rts
161 |
162 | # More Information
163 |
164 | - [ModelScope library](https://github.com/modelscope/modelscope/)
165 |
166 |
167 | ModelScope Library provides the foundation for building the model-ecosystem of ModelScope, including the interface and implementation to integrate various models into ModelScope.
168 |
169 | - [Contribute models to ModelScope](https://modelscope.cn/docs/ModelScope%E6%A8%A1%E5%9E%8B%E6%8E%A5%E5%85%A5%E6%B5%81%E7%A8%8B%E6%A6%82%E8%A7%88)
170 |
171 | # License
172 |
173 | This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
174 |
--------------------------------------------------------------------------------
/facechain/inference.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 | import json
3 | import os
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | from diffusers import StableDiffusionPipeline
10 | from modelscope.outputs import OutputKeys
11 | from modelscope.pipelines import pipeline
12 | from modelscope.utils.constant import Tasks
13 | from modelscope import snapshot_download
14 |
15 | from facechain.merge_lora import merge_lora
16 | from facechain.data_process.preprocessing import Blipv2
17 |
18 |
19 | def data_process_fn(input_img_dir, use_data_process):
20 | ## TODO add face quality filter
21 | if use_data_process:
22 | ## TODO
23 | data_process_fn = Blipv2()
24 | out_json_name = data_process_fn(input_img_dir)
25 | return out_json_name
26 | else:
27 | return os.path.join(str(input_img_dir) + '_labeled', "metadata.jsonl")
28 |
29 |
30 | def txt2img(pipe, pos_prompt, neg_prompt, num_images=10):
31 | images_out = []
32 | for i in range(int(num_images / 5)):
33 | images_style = pipe(prompt=pos_prompt, height=512, width=512, guidance_scale=7, negative_prompt=neg_prompt,
34 | num_inference_steps=40, num_images_per_prompt=5).images
35 | images_out.extend(images_style)
36 | return images_out
37 |
38 |
39 | def main_diffusion_inference(input_img_dir, base_model_path, style_model_path, lora_model_path, multiplier_style=0.25,
40 | multiplier_human=1.0):
41 | pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
42 | neg_prompt = 'nsfw, paintings, sketches, (worst quality:2), (low quality:2), lowres, normal quality, ((monochrome)), ((grayscale)), logo, word, character'
43 | pos_prompt = 'raw photo, masterpiece, chinese, simple background, wearing high-class business/working suit, high-class pure color background, solo, medium shot, high detail face, looking straight into the camera with shoulders parallel to the frame, slim body, photorealistic, best quality'
44 | lora_style_path = style_model_path
45 | lora_human_path = lora_model_path
46 | pipe = merge_lora(pipe, lora_style_path, multiplier_style, from_safetensor=True)
47 | pipe = merge_lora(pipe, lora_human_path, multiplier_human, from_safetensor=False)
48 | train_dir = str(input_img_dir) + '_labeled'
49 | add_prompt_style = []
50 | f = open(os.path.join(train_dir, 'metadata.jsonl'), 'r')
51 | tags_all = []
52 | cnt = 0
53 | cnts_trigger = np.zeros(6)
54 | for line in f:
55 | cnt += 1
56 | data = json.loads(line)['text'].split(', ')
57 | tags_all.extend(data)
58 | if data[1] == 'a boy':
59 | cnts_trigger[0] += 1
60 | elif data[1] == 'a girl':
61 | cnts_trigger[1] += 1
62 | elif data[1] == 'a handsome man':
63 | cnts_trigger[2] += 1
64 | elif data[1] == 'a beautiful woman':
65 | cnts_trigger[3] += 1
66 | elif data[1] == 'a mature man':
67 | cnts_trigger[4] += 1
68 | elif data[1] == 'a mature woman':
69 | cnts_trigger[5] += 1
70 | else:
71 | print('Error.')
72 | f.close()
73 |
74 | attr_idx = np.argmax(cnts_trigger)
75 | trigger_styles = ['a boy, children, ', 'a girl, children, ', 'a handsome man, ', 'a beautiful woman, ',
76 | 'a mature man, ', 'a mature woman, ']
77 | trigger_style = ', ' + trigger_styles[attr_idx]
78 | if attr_idx == 2 or attr_idx == 4:
79 | neg_prompt += ', children'
80 |
81 | for tag in tags_all:
82 | if tags_all.count(tag) > 0.5 * cnt:
83 | if ('hair' in tag or 'face' in tag or 'mouth' in tag or 'skin' in tag or 'smile' in tag):
84 | if not tag in add_prompt_style:
85 | add_prompt_style.append(tag)
86 |
87 | if len(add_prompt_style) > 0:
88 | add_prompt_style = ", ".join(add_prompt_style) + ', '
89 | else:
90 | add_prompt_style = ''
91 | # trigger_style = trigger_style + 'with face, '
92 | # pos_prompt = 'Generate a standard ID photo of a chinese {}, solo, wearing high-class business/working suit, beautiful smooth face, with high-class/simple pure color background, looking straight into the camera with shoulders parallel to the frame, smile, high detail face, best quality, photorealistic'.format(gender)
93 | pipe = pipe.to("cuda")
94 | # print(trigger_style + add_prompt_style + pos_prompt)
95 | images_style = txt2img(pipe, trigger_style + add_prompt_style + pos_prompt, neg_prompt, num_images=10)
96 | return images_style
97 |
98 |
99 | def stylization_fn(use_stylization, rank_results):
100 | if use_stylization:
101 | ## TODO
102 | pass
103 | else:
104 | return rank_results
105 |
106 |
107 | def main_model_inference(use_main_model, input_img_dir=None, base_model_path=None, lora_model_path=None):
108 | if use_main_model:
109 | model_dir = snapshot_download('Cherrytest/zjz_mj_jiyi_small_addtxt_fromleo', revision='v1.0.0')
110 | style_model_path = os.path.join(model_dir, 'zjz_mj_jiyi_small_addtxt_fromleo.safetensors')
111 | image = main_diffusion_inference(input_img_dir, base_model_path, style_model_path, lora_model_path)
112 | return image
113 |
114 |
115 | def select_high_quality_face(input_img_dir):
116 | input_img_dir = str(input_img_dir) + '_labeled'
117 | quality_score_list = []
118 | abs_img_path_list = []
119 | ## TODO
120 | face_quality_func = pipeline(Tasks.face_quality_assessment, 'damo/cv_manual_face-quality-assessment_fqa')
121 |
122 | for img_name in os.listdir(input_img_dir):
123 | if img_name.endswith('jsonl') or img_name.startswith('.ipynb'):
124 | continue
125 | abs_img_name = os.path.join(input_img_dir, img_name)
126 | face_quality_score = face_quality_func(abs_img_name)[OutputKeys.SCORES]
127 | if face_quality_score is None:
128 | quality_score_list.append(0)
129 | else:
130 | quality_score_list.append(face_quality_score[0])
131 | abs_img_path_list.append(abs_img_name)
132 |
133 | sort_idx = np.argsort(quality_score_list)[::-1]
134 | print('Selected face: ' + abs_img_path_list[sort_idx[0]])
135 |
136 | return Image.open(abs_img_path_list[sort_idx[0]])
137 |
138 |
139 | def face_swap_fn(use_face_swap, gen_results, template_face):
140 | if use_face_swap:
141 | ## TODO
142 | out_img_list = []
143 | image_face_fusion = pipeline(Tasks.image_face_fusion,
144 | model='damo/cv_unet-image-face-fusion_damo')
145 | for img in gen_results:
146 | result = image_face_fusion(dict(template=img, user=template_face))[OutputKeys.OUTPUT_IMG]
147 | out_img_list.append(result)
148 |
149 | return out_img_list
150 | else:
151 | ret_results = []
152 | for img in gen_results:
153 | ret_results.append(cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR))
154 | return ret_results
155 |
156 |
157 | def post_process_fn(use_post_process, swap_results_ori, selected_face, num_gen_images):
158 | if use_post_process:
159 | sim_list = []
160 | ## TODO
161 | # face_recognition_func = pipeline(Tasks.face_recognition, 'damo/cv_vit_face-recognition')
162 | face_recognition_func = pipeline(Tasks.face_recognition, 'damo/cv_ir_face-recognition-ood_rts')
163 | face_det_func = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd')
164 | swap_results = []
165 | for img in swap_results_ori:
166 | result_det = face_det_func(img)
167 | bboxes = result_det['boxes']
168 | if len(bboxes) == 1:
169 | bbox = bboxes[0]
170 | lenface = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
171 | if 160 < lenface < 360:
172 | swap_results.append(img)
173 |
174 | select_face_emb = face_recognition_func(selected_face)[OutputKeys.IMG_EMBEDDING][0]
175 |
176 | for img in swap_results:
177 | emb = face_recognition_func(img)[OutputKeys.IMG_EMBEDDING]
178 | if emb is None or select_face_emb is None:
179 | sim_list.append(0)
180 | else:
181 | sim = np.dot(emb, select_face_emb)
182 | sim_list.append(sim.item())
183 | sort_idx = np.argsort(sim_list)[::-1]
184 |
185 | return np.array(swap_results)[sort_idx[:min(int(num_gen_images), len(swap_results))]]
186 | else:
187 | return np.array(swap_results_ori)
188 |
189 |
190 | class GenPortrait:
191 | def __init__(self, use_main_model=True, use_face_swap=True,
192 | use_post_process=True, use_stylization=True):
193 | self.use_main_model = use_main_model
194 | self.use_face_swap = use_face_swap
195 | self.use_post_process = use_post_process
196 | self.use_stylization = use_stylization
197 |
198 | def __call__(self, input_img_dir, num_gen_images=6, base_model_path=None,
199 | lora_model_path=None, sub_path=None, revision=None):
200 | base_model_path = snapshot_download(base_model_path, revision=revision)
201 | if sub_path is not None and len(sub_path) > 0:
202 | base_model_path = os.path.join(base_model_path, sub_path)
203 |
204 | # main_model_inference PIL
205 | gen_results = main_model_inference(self.use_main_model, input_img_dir=input_img_dir,
206 | lora_model_path=lora_model_path, base_model_path=base_model_path)
207 | # select_high_quality_face PIL
208 | selected_face = select_high_quality_face(input_img_dir)
209 | # face_swap cv2
210 | swap_results = face_swap_fn(self.use_face_swap, gen_results, selected_face)
211 | # pose_process
212 | rank_results = post_process_fn(self.use_post_process, swap_results, selected_face,
213 | num_gen_images=num_gen_images)
214 | # stylization
215 | final_gen_results = stylization_fn(self.use_stylization, rank_results)
216 |
217 | return final_gen_results
218 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 | import enum
3 | import os
4 | import shutil
5 | import sys
6 | import time
7 | from concurrent.futures import ThreadPoolExecutor
8 |
9 | import cv2
10 | import gradio as gr
11 | import numpy as np
12 | import torch
13 |
14 | from facechain.inference import GenPortrait
15 |
16 | sys.path.append('facechain')
17 |
18 |
19 | training_threadpool = ThreadPoolExecutor(max_workers=1)
20 | inference_threadpool = ThreadPoolExecutor(max_workers=5)
21 |
22 | training_done_count = 0
23 | inference_done_count = 0
24 |
25 | HOT_MODELS = [
26 | "\N{fire}数字身份",
27 | ]
28 |
29 |
30 | class UploadTarget(enum.Enum):
31 | PERSONAL_PROFILE = 'Personal Profile'
32 | LORA_LIBRARY = 'LoRA Library'
33 |
34 |
35 | def concatenate_images(images):
36 | heights = [img.shape[0] for img in images]
37 | max_width = sum([img.shape[1] for img in images])
38 |
39 | concatenated_image = np.zeros((max(heights), max_width, 3), dtype=np.uint8)
40 | x_offset = 0
41 | for img in images:
42 | concatenated_image[0:img.shape[0], x_offset:x_offset + img.shape[1], :] = img
43 | x_offset += img.shape[1]
44 | return concatenated_image
45 |
46 |
47 | def train_lora_fn(foundation_model_path=None, revision=None, input_img_dir=None, output_img_dir=None, work_dir=None):
48 | sh_file_path = os.path.join('/'.join(os.path.abspath(__file__).split('/')[:-1]),
49 | 'train_lora.sh')
50 | os.system(f'PYTHONPATH=. sh {sh_file_path} {foundation_model_path} {revision} "film/film" {input_img_dir} {output_img_dir} {work_dir}')
51 |
52 |
53 | def launch_pipeline(uuid,
54 | user_models,
55 | num_images=1,
56 | ):
57 | base_model = 'ly261666/cv_portrait_model'
58 | before_queue_size = inference_threadpool._work_queue.qsize()
59 | before_done_count = inference_done_count
60 |
61 | print("-------user_models: ", user_models)
62 | if not uuid:
63 | if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio':
64 | return "请登陆后使用! "
65 | else:
66 | uuid = 'qw'
67 |
68 | use_main_model = True
69 | use_face_swap = True
70 | use_post_process = True
71 | use_stylization = False
72 |
73 | output_model_name = 'personalization_lora'
74 | instance_data_dir = os.path.join('/tmp', uuid, 'training_data', output_model_name)
75 |
76 | lora_model_path = f'/tmp/{uuid}/{output_model_name}'
77 |
78 | gen_portrait = GenPortrait(use_main_model, use_face_swap, use_post_process,
79 | use_stylization)
80 |
81 | num_images = min(6, num_images)
82 | future = inference_threadpool.submit(gen_portrait, instance_data_dir,
83 | num_images, base_model, lora_model_path, 'film/film', 'v2.0')
84 |
85 | while not future.done():
86 | is_processing = future.running()
87 | if not is_processing:
88 | cur_done_count = inference_done_count
89 | to_wait = before_queue_size - (cur_done_count - before_done_count)
90 | # yield ["排队等待资源中,前方还有{}个生成任务, 预计需要等待{}分钟...".format(to_wait, round(to_wait*2.5, 5)), None]
91 | yield ["排队等待资源中,前方还有{}个生成任务, 预计需要等待{}分钟...".format(to_wait, to_wait * 2.5), None]
92 | else:
93 | yield ["生成中, 请耐心等待...", None]
94 | time.sleep(1)
95 |
96 | outputs = future.result()
97 | outputs_RGB = []
98 | for out_tmp in outputs:
99 | outputs_RGB.append(cv2.cvtColor(out_tmp, cv2.COLOR_BGR2RGB))
100 | image_path = './lora_result.png'
101 | result = concatenate_images(outputs)
102 | cv2.imwrite(image_path, result)
103 | # image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
104 |
105 | yield ["生成完毕!", outputs_RGB]
106 |
107 |
108 | class Trainer:
109 | def __init__(self):
110 | pass
111 |
112 | def run(
113 | self,
114 | uuid: str,
115 | instance_images: list,
116 | ) -> str:
117 |
118 | if not torch.cuda.is_available():
119 | raise gr.Error('CUDA is not available.')
120 | if instance_images is None:
121 | raise gr.Error('您需要上传训练图片!')
122 | # return "请传完图片后点击<开始训练>! "
123 | if len(instance_images) > 10:
124 | raise gr.Error('您需要上传小于10张训练图片!')
125 | if not uuid:
126 | if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio':
127 | return "请登陆后使用! "
128 | else:
129 | uuid = 'qw'
130 |
131 | output_model_name = 'personalization_lora'
132 |
133 | # mv user upload data to target dir
134 | instance_data_dir = os.path.join('/tmp', uuid, 'training_data', output_model_name)
135 | print("--------uuid: ", uuid)
136 |
137 | if not os.path.exists(f"/tmp/{uuid}"):
138 | os.makedirs(f"/tmp/{uuid}")
139 | work_dir = f"/tmp/{uuid}/{output_model_name}"
140 | print("----------work_dir: ", work_dir)
141 |
142 | source_img_dir = f"/tmp/{uuid}/sources"
143 | shutil.rmtree(source_img_dir, ignore_errors=True)
144 | os.makedirs(source_img_dir, exist_ok=True)
145 | for img in instance_images:
146 | shutil.copy(img['name'], os.path.join(source_img_dir, os.path.basename(img['name'])))
147 |
148 | # train lora
149 | train_lora_fn(foundation_model_path='ly261666/cv_portrait_model',
150 | revision='v2.0', input_img_dir=source_img_dir,
151 | output_img_dir=instance_data_dir,
152 | work_dir=work_dir)
153 |
154 | message = f'训练已经完成!请切换至 [形象体验] 标签体验模型效果'
155 | print(message)
156 | return message
157 |
158 |
159 | def flash_model_list(uuid):
160 | folder_path = f"/tmp/{uuid}"
161 | folder_list = []
162 | print("------flash_model_list folder_path: ", folder_path)
163 | if not os.path.exists(folder_path):
164 | print('--------The folder_path is missing.')
165 | else:
166 | files = os.listdir(folder_path)
167 | for file in files:
168 | file_path = os.path.join(folder_path, file)
169 | if os.path.isdir(folder_path):
170 | file_lora_path = f"{file_path}/output/pytorch_lora_weights.bin"
171 | if os.path.exists(file_lora_path):
172 | folder_list.append(file)
173 |
174 | print("-------folder_list + HOT_MODELS: ", folder_list + HOT_MODELS)
175 | return gr.Radio.update(choices=HOT_MODELS + folder_list)
176 |
177 |
178 | def upload_file(files):
179 | file_paths = [file.name for file in files]
180 | return file_paths
181 |
182 |
183 | def train_input():
184 | trainer = Trainer()
185 |
186 | with gr.Blocks() as demo:
187 | uuid = gr.Text(label="modelscope_uuid", visible=False)
188 | with gr.Row():
189 | with gr.Column():
190 | with gr.Box():
191 | gr.Markdown('训练数据')
192 | # instance_images = gr.Files(label='Instance images', visible=False)
193 | instance_images = gr.Gallery()
194 | upload_button = gr.UploadButton("选择图片上传", file_types=["image"], file_count="multiple")
195 | upload_button.upload(upload_file, upload_button, instance_images)
196 | gr.Markdown('''
197 | - Step 0. 登陆ModelScope账号,未登录无法使用定制功能
198 | - Step 1. 上传你计划训练的图片,3~10张头肩照(注意:图片中多人脸、脸部遮挡等情况会导致效果异常,需要重新上传符合规范图片训练)
199 | - Step 2. 点击 [形象定制] ,启动模型训练,等待约15分钟,请您耐心等待
200 | - Step 3. 切换至 [形象体验] ,生成你的风格照片
201 | - 注意:生成结果严禁用于非法用途!
202 | ''')
203 |
204 | run_button = gr.Button('开始训练(等待上传图片加载显示出来再点,否则会报错)')
205 |
206 | with gr.Box():
207 | gr.Markdown(
208 | '输出信号(出现error时训练可能已完成或还在进行。可直接切到形象体验tab页面,如果体验时报错则训练还没好,再等待一般10来分钟。)')
209 | output_message = gr.Markdown()
210 | with gr.Box():
211 | gr.Markdown('''
212 | 碰到抓狂的错误或者计算资源紧张的情况下,推荐直接在[NoteBook](https://modelscope.cn/my/mynotebook/preset)上按照如下命令自行体验
213 | 1. git clone https://www.modelscope.cn/studios/CVstudio/cv_human_portrait.git
214 | 2. cd cv_human_portrait
215 | 3. pip install -r requirements.txt
216 | 4. pip install gradio==3.35.2
217 | 5. python app.py
218 | ''')
219 |
220 | run_button.click(fn=trainer.run,
221 | inputs=[
222 | uuid,
223 | instance_images,
224 | ],
225 | outputs=[output_message])
226 |
227 | return demo
228 |
229 |
230 | def inference_input():
231 | with gr.Blocks() as demo:
232 | uuid = gr.Text(label="modelscope_uuid", visible=False)
233 | with gr.Row():
234 | with gr.Column():
235 | # user_models = gr.Radio(label="风格选择", choices=['\N{fire}商务证件'], type="value", value='\N{fire}商务证件')
236 | user_models = gr.Radio(label="模型选择", choices=HOT_MODELS, type="value", value=HOT_MODELS[0])
237 | # flash_button = gr.Button('刷新模型列表')
238 |
239 | with gr.Box():
240 | num_images = gr.Number(
241 | label='生成图片数量', value=6, precision=1)
242 | gr.Markdown('''
243 | 注意:最多支持生成6张图片!
244 | ''')
245 |
246 | display_button = gr.Button('开始推理')
247 |
248 | with gr.Box():
249 | infer_progress = gr.Textbox(label="生成进度", value="当前无生成任务", interactive=False)
250 | with gr.Box():
251 | gr.Markdown('生成结果')
252 | # output_image = gr.Image()
253 | output_images = gr.Gallery(label='Output', show_label=False).style(columns=3, rows=2, height=600,
254 | object_fit="contain")
255 |
256 | display_button.click(fn=launch_pipeline,
257 | inputs=[uuid, user_models, num_images],
258 | outputs=[infer_progress, output_images])
259 |
260 | return demo
261 |
262 |
263 | with gr.Blocks(css='style.css') as demo:
264 | with gr.Tabs():
265 | with gr.TabItem('\N{rocket}形象定制'):
266 | train_input()
267 | with gr.TabItem('\N{party popper}形象体验'):
268 | inference_input()
269 |
270 | # demo.queue(max_size=100).launch(share=False)
271 | # demo.queue(concurrency_count=20).launch(share=False)
272 | demo.queue(status_update_rate=1).launch(share=True)
273 |
--------------------------------------------------------------------------------
/app-win10.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 | import enum
3 | import os
4 | import shutil
5 | import sys
6 | import time
7 | from concurrent.futures import ThreadPoolExecutor
8 |
9 | import cv2
10 | import gradio as gr
11 | import numpy as np
12 | import torch
13 |
14 | from facechain.inference import GenPortrait
15 |
16 | sys.path.append('facechain')
17 | os.environ["PYTHONPATH"] = os.path.dirname(os.path.abspath(__file__))
18 | # print("PYTHONPATH:", os.environ["PYTHONPATH"])
19 |
20 | training_threadpool = ThreadPoolExecutor(max_workers=1)
21 | inference_threadpool = ThreadPoolExecutor(max_workers=5)
22 |
23 | training_done_count = 0
24 | inference_done_count = 0
25 |
26 | HOT_MODELS = [
27 | "\N{fire}数字身份",
28 | ]
29 |
30 |
31 | class UploadTarget(enum.Enum):
32 | PERSONAL_PROFILE = 'Personal Profile'
33 | LORA_LIBRARY = 'LoRA Library'
34 |
35 |
36 | def concatenate_images(images):
37 | heights = [img.shape[0] for img in images]
38 | max_width = sum([img.shape[1] for img in images])
39 |
40 | concatenated_image = np.zeros((max(heights), max_width, 3), dtype=np.uint8)
41 | x_offset = 0
42 | for img in images:
43 | concatenated_image[0:img.shape[0], x_offset:x_offset + img.shape[1], :] = img
44 | x_offset += img.shape[1]
45 | return concatenated_image
46 |
47 |
48 | # def train_lora_fn(foundation_model_path=None, revision=None, input_img_dir=None, output_img_dir=None, work_dir=None):
49 | # sh_file_path = os.path.join('/'.join(os.path.abspath(__file__).split('/')[:-1]),
50 | # 'train_lora.sh')
51 | # os.system(f'PYTHONPATH=. sh {sh_file_path} {foundation_model_path} {revision} "film/film" {input_img_dir} {output_img_dir} {work_dir}')
52 |
53 |
54 | def train_lora_fn(foundation_model_path=None, revision=None, input_img_dir=None, output_img_dir=None, work_dir=None):
55 | bat_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'train_lora.bat')
56 |
57 | os.system(f'{bat_file_path} {foundation_model_path} {revision} "film/film" {input_img_dir} {output_img_dir} {work_dir}')
58 |
59 |
60 | def launch_pipeline(uuid,
61 | user_models,
62 | num_images=1,
63 | ):
64 | base_model = 'ly261666/cv_portrait_model'
65 | before_queue_size = inference_threadpool._work_queue.qsize()
66 | before_done_count = inference_done_count
67 |
68 | print("-------user_models: ", user_models)
69 | if not uuid:
70 | if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio':
71 | return "请登陆后使用! "
72 | else:
73 | uuid = 'qw'
74 |
75 | use_main_model = True
76 | use_face_swap = True
77 | use_post_process = True
78 | use_stylization = False
79 |
80 | output_model_name = 'personalization_lora'
81 | instance_data_dir = os.path.join('D:\\AI\\', uuid, 'training_data', output_model_name)
82 |
83 | lora_model_path = f'D:\\AI\\{uuid}\\{output_model_name}'
84 |
85 | gen_portrait = GenPortrait(use_main_model, use_face_swap, use_post_process,
86 | use_stylization)
87 |
88 | num_images = min(6, num_images)
89 | future = inference_threadpool.submit(gen_portrait, instance_data_dir,
90 | num_images, base_model, lora_model_path, 'film/film', 'v2.0')
91 |
92 | while not future.done():
93 | is_processing = future.running()
94 | if not is_processing:
95 | cur_done_count = inference_done_count
96 | to_wait = before_queue_size - (cur_done_count - before_done_count)
97 | # yield ["排队等待资源中,前方还有{}个生成任务, 预计需要等待{}分钟...".format(to_wait, round(to_wait*2.5, 5)), None]
98 | yield ["排队等待资源中,前方还有{}个生成任务, 预计需要等待{}分钟...".format(to_wait, to_wait * 2.5), None]
99 | else:
100 | yield ["生成中, 请耐心等待...", None]
101 | time.sleep(1)
102 |
103 | outputs = future.result()
104 | outputs_RGB = []
105 | for out_tmp in outputs:
106 | outputs_RGB.append(cv2.cvtColor(out_tmp, cv2.COLOR_BGR2RGB))
107 | image_path = '.\\lora_result.png'
108 | result = concatenate_images(outputs)
109 | cv2.imwrite(image_path, result)
110 | # image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
111 |
112 | yield ["生成完毕!", outputs_RGB]
113 |
114 |
115 | class Trainer:
116 | def __init__(self):
117 | pass
118 |
119 | def run(
120 | self,
121 | uuid: str,
122 | instance_images: list,
123 | ) -> str:
124 |
125 | if not torch.cuda.is_available():
126 | raise gr.Error('CUDA is not available.')
127 | if instance_images is None:
128 | raise gr.Error('您需要上传训练图片!')
129 | # return "请传完图片后点击<开始训练>! "
130 | if len(instance_images) > 10:
131 | raise gr.Error('您需要上传小于10张训练图片!')
132 | if not uuid:
133 | if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio':
134 | return "请登陆后使用! "
135 | else:
136 | uuid = 'qw'
137 |
138 | output_model_name = 'personalization_lora'
139 |
140 | # mv user upload data to target dir
141 | instance_data_dir = os.path.join('D:\\AI\\', uuid, 'training_data', output_model_name)
142 | print("--------uuid: ", uuid)
143 |
144 | if not os.path.exists(f"D:\\AI\\{uuid}"):
145 | os.makedirs(f"D:\\AI\\{uuid}")
146 | work_dir = f"D:\\AI\\{uuid}\\{output_model_name}"
147 | print("----------work_dir: ", work_dir)
148 |
149 | source_img_dir = f"D:\\AI\\{uuid}\\sources"
150 | shutil.rmtree(source_img_dir, ignore_errors=True)
151 | os.makedirs(source_img_dir, exist_ok=True)
152 | for img in instance_images:
153 | shutil.copy(img['name'], os.path.join(source_img_dir, os.path.basename(img['name'])))
154 |
155 | # train lora
156 | train_lora_fn(foundation_model_path='ly261666/cv_portrait_model',
157 | revision='v2.0', input_img_dir=source_img_dir,
158 | output_img_dir=instance_data_dir,
159 | work_dir=work_dir)
160 |
161 | message = f'训练已经完成!请切换至 [形象体验] 标签体验模型效果'
162 | print(message)
163 | return message
164 |
165 |
166 | def flash_model_list(uuid):
167 | folder_path = f"D:\\AI\\{uuid}"
168 | folder_list = []
169 | print("------flash_model_list folder_path: ", folder_path)
170 | if not os.path.exists(folder_path):
171 | print('--------The folder_path is missing.')
172 | else:
173 | files = os.listdir(folder_path)
174 | for file in files:
175 | file_path = os.path.join(folder_path, file)
176 | if os.path.isdir(folder_path):
177 | file_lora_path = f"{file_path}\\output\\pytorch_lora_weights.bin"
178 | if os.path.exists(file_lora_path):
179 | folder_list.append(file)
180 |
181 | print("-------folder_list + HOT_MODELS: ", folder_list + HOT_MODELS)
182 | return gr.Radio.update(choices=HOT_MODELS + folder_list)
183 |
184 |
185 | def upload_file(files):
186 | file_paths = [file.name for file in files]
187 | return file_paths
188 |
189 |
190 | def train_input():
191 | trainer = Trainer()
192 |
193 | with gr.Blocks() as demo:
194 | uuid = gr.Text(label="modelscope_uuid", visible=False)
195 | with gr.Row():
196 | with gr.Column():
197 | with gr.Box():
198 | gr.Markdown('训练数据')
199 | # instance_images = gr.Files(label='Instance images', visible=False)
200 | instance_images = gr.Gallery()
201 | upload_button = gr.UploadButton("选择图片上传", file_types=["image"], file_count="multiple")
202 | upload_button.upload(upload_file, upload_button, instance_images)
203 | gr.Markdown('''
204 | - Step 0. 登陆ModelScope账号,未登录无法使用定制功能
205 | - Step 1. 上传你计划训练的图片,3~10张头肩照(注意:图片中多人脸、脸部遮挡等情况会导致效果异常,需要重新上传符合规范图片训练)
206 | - Step 2. 点击 [形象定制] ,启动模型训练,等待约15分钟,请您耐心等待
207 | - Step 3. 切换至 [形象体验] ,生成你的风格照片
208 | - 注意:生成结果严禁用于非法用途!
209 | ''')
210 |
211 | run_button = gr.Button('开始训练(等待上传图片加载显示出来再点,否则会报错)')
212 |
213 | with gr.Box():
214 | gr.Markdown(
215 | '输出信号(出现error时训练可能已完成或还在进行。可直接切到形象体验tab页面,如果体验时报错则训练还没好,再等待一般10来分钟。)')
216 | output_message = gr.Markdown()
217 | with gr.Box():
218 | gr.Markdown('''
219 | 碰到抓狂的错误或者计算资源紧张的情况下,推荐直接在[NoteBook](https://modelscope.cn/my/mynotebook/preset)上按照如下命令自行体验
220 | 1. git clone https://www.modelscope.cn/studios/CVstudio/cv_human_portrait.git
221 | 2. cd cv_human_portrait
222 | 3. pip install -r requirements.txt
223 | 4. pip install gradio==3.35.2
224 | 5. python app.py
225 | ''')
226 |
227 | run_button.click(fn=trainer.run,
228 | inputs=[
229 | uuid,
230 | instance_images,
231 | ],
232 | outputs=[output_message])
233 |
234 | return demo
235 |
236 |
237 | def inference_input():
238 | with gr.Blocks() as demo:
239 | uuid = gr.Text(label="modelscope_uuid", visible=False)
240 | with gr.Row():
241 | with gr.Column():
242 | # user_models = gr.Radio(label="风格选择", choices=['\N{fire}商务证件'], type="value", value='\N{fire}商务证件')
243 | user_models = gr.Radio(label="模型选择", choices=HOT_MODELS, type="value", value=HOT_MODELS[0])
244 | # flash_button = gr.Button('刷新模型列表')
245 |
246 | with gr.Box():
247 | num_images = gr.Number(
248 | label='生成图片数量', value=6, precision=1)
249 | gr.Markdown('''
250 | 注意:最多支持生成6张图片!
251 | ''')
252 |
253 | display_button = gr.Button('开始推理')
254 |
255 | with gr.Box():
256 | infer_progress = gr.Textbox(label="生成进度", value="当前无生成任务", interactive=False)
257 | with gr.Box():
258 | gr.Markdown('生成结果')
259 | # output_image = gr.Image()
260 | output_images = gr.Gallery(label='Output', show_label=False).style(columns=3, rows=2, height=600,
261 | object_fit="contain")
262 |
263 | display_button.click(fn=launch_pipeline,
264 | inputs=[uuid, user_models, num_images],
265 | outputs=[infer_progress, output_images])
266 |
267 | return demo
268 |
269 |
270 | with gr.Blocks(css='style.css') as demo:
271 | with gr.Tabs():
272 | with gr.TabItem('\N{rocket}形象定制'):
273 | train_input()
274 | with gr.TabItem('\N{party popper}形象体验'):
275 | inference_input()
276 |
277 | # demo.queue(max_size=100).launch(share=False)
278 | # demo.queue(concurrency_count=20).launch(share=False)
279 | demo.queue(status_update_rate=1).launch(share=True)
280 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/facechain/data_process/preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 |
3 | import os
4 | import shutil
5 |
6 | import numpy as np
7 | import json
8 | import math
9 |
10 | from .deepbooru import DeepDanbooru
11 |
12 | import cv2
13 | from PIL import Image
14 | from modelscope.pipelines import pipeline
15 | from modelscope.outputs import OutputKeys
16 | from modelscope.utils.constant import Tasks
17 |
18 |
19 | def crop_and_resize(im, bbox, thres=0.35, thres1=0.45):  # crop a square around the detected face (with a randomized margin) and resize it to 512x512
20 | h, w, _ = im.shape
21 | thre = np.random.rand() * (thres1 - thres) + thres
22 | maxf = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
23 | cx = (bbox[2] + bbox[0]) / 2
24 | cy = (bbox[3] + bbox[1]) / 2
25 | lenp = int(maxf / thre)
26 | yc = np.random.rand() * 0.15 + 0.35
27 | xc = 0.5
28 | xmin = int(cx - xc * lenp)
29 | xmax = xmin + lenp
30 | ymin = int(cy - yc * lenp)
31 | ymax = ymin + lenp
32 | x1 = 0
33 | x2 = lenp
34 | y1 = 0
35 | y2 = lenp
36 | if xmin < 0:
37 | x1 = -xmin
38 | xmin = 0
39 | if xmax > w:
40 | x2 = w - (xmax - lenp)
41 | xmax = w
42 | if ymin < 0:
43 | y1 = -ymin
44 | ymin = 0
45 | if ymax > h:
46 | y2 = h - (ymax - lenp)
47 | ymax = h
48 | imc = (np.ones((lenp, lenp, 3)) * 255).astype(np.uint8)
49 | imc[y1:y2, x1:x2, :] = im[ymin:ymax, xmin:xmax, :]
50 | imr = cv2.resize(imc, (512, 512))
51 | return imr
52 |
53 |
54 | def pad_to_square(im):  # pad the image with white borders to a square of side 1.5 * max(h, w)
55 | h, w, _ = im.shape
56 | ns = int(max(h, w) * 1.5)
57 | im = cv2.copyMakeBorder(im, int((ns - h) / 2), (ns - h) - int((ns - h) / 2), int((ns - w) / 2),
58 | (ns - w) - int((ns - w) / 2), cv2.BORDER_CONSTANT, 255)
59 | return im
60 |
61 |
62 | def post_process_naive(result_list, score_gender, score_age):  # prepend a gender/age trigger phrase and drop face-specific tags from each tag list
63 | # determine trigger word
64 | gender = np.argmax(score_gender)
65 | age = np.argmax(score_age)
66 | if age < 2:
67 | if gender == 0:
68 | tag_a_g = ['a boy', 'children']
69 | else:
70 | tag_a_g = ['a girl', 'children']
71 | elif age > 4:
72 | if gender == 0:
73 | tag_a_g = ['a mature man']
74 | else:
75 | tag_a_g = ['a mature woman']
76 | else:
77 | if gender == 0:
78 | tag_a_g = ['a handsome man']
79 | else:
80 | tag_a_g = ['a beautiful woman']
81 | num_images = len(result_list)
82 | cnt_girl = 0
83 | cnt_boy = 0
84 | result_list_new = []
85 | for result in result_list:
86 | result_new = []
87 | result_new.extend(tag_a_g)
88 | for tag in result:
89 | if tag == '1girl' or tag == '1boy':
90 | continue
91 | if tag[-4:] == '_man':
92 | continue
93 | if tag[-6:] == '_woman':
94 | continue
95 | if tag[-5:] == '_male':
96 | continue
97 | elif tag[-7:] == '_female':
98 | continue
99 | elif (
100 | tag == 'ears' or tag == 'head' or tag == 'face' or tag == 'lips' or tag == 'mouth' or tag == '3d' or tag == 'asian' or tag == 'teeth'):
101 | continue
102 | elif ('eye' in tag and not 'eyewear' in tag):
103 | continue
104 | elif ('nose' in tag or 'body' in tag):
105 | continue
106 | elif tag[-5:] == '_lips':
107 | continue
108 | else:
109 | result_new.append(tag)
110 | # import pdb;pdb.set_trace()
111 | # result_new.append('slim body')
112 | result_list_new.append(result_new)
113 |
114 | return result_list_new
115 |
116 |
117 | def transformation_from_points(points1, points2):  # estimate a similarity transform (scale, rotation, translation) between two point sets via SVD
118 | points1 = points1.astype(np.float64)
119 | points2 = points2.astype(np.float64)
120 | c1 = np.mean(points1, axis=0)
121 | c2 = np.mean(points2, axis=0)
122 | points1 -= c1
123 | points2 -= c2
124 | s1 = np.std(points1)
125 | s2 = np.std(points2)
126 | if s1 < 1.0e-4:
127 | s1 = 1.0e-4
128 | points1 /= s1
129 | points2 /= s2
130 | U, S, Vt = np.linalg.svd(points1.T * points2)
131 | R = (U * Vt).T
132 | return np.vstack([np.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)), np.matrix([0., 0., 1.])])
133 |
134 |
135 | def rotate(im, keypoints):  # rotate the white-padded image so the five facial keypoints align with a canonical mean face
136 | h, w, _ = im.shape
137 | points_array = np.zeros((5, 2))
138 | dst_mean_face_size = 160
139 | dst_mean_face = np.asarray([0.31074522411511746, 0.2798131190011913,
140 | 0.6892073313037804, 0.2797830232679366,
141 | 0.49997367716346774, 0.5099309118810921,
142 | 0.35811903020866753, 0.7233174007629063,
143 | 0.6418878095835022, 0.7232890570786875])
144 | dst_mean_face = np.reshape(dst_mean_face, (5, 2)) * dst_mean_face_size
145 |
146 | for k in range(5):
147 | points_array[k, 0] = keypoints[2 * k]
148 | points_array[k, 1] = keypoints[2 * k + 1]
149 |
150 | pts1 = np.float64(np.matrix([[point[0], point[1]] for point in points_array]))
151 | pts2 = np.float64(np.matrix([[point[0], point[1]] for point in dst_mean_face]))
152 | trans_mat = transformation_from_points(pts1, pts2)
153 | if trans_mat[1, 1] > 1.0e-4:
154 | angle = math.atan(trans_mat[1, 0] / trans_mat[1, 1])
155 | else:
156 | angle = math.atan(trans_mat[0, 1] / trans_mat[0, 2])
157 | im = pad_to_square(im)
158 | ns = int(1.5 * max(h, w))
159 | M = cv2.getRotationMatrix2D((ns / 2, ns / 2), angle=-angle / np.pi * 180, scale=1.0)
160 | im = cv2.warpAffine(im, M=M, dsize=(ns, ns))
161 | return im
162 |
163 |
164 | def get_mask_head(result):  # build a 512x512 head mask (hair + face, limited to the human region) from the parsing result
165 | masks = result['masks']
166 | scores = result['scores']
167 | labels = result['labels']
168 | mask_hair = np.zeros((512, 512))
169 | mask_face = np.zeros((512, 512))
170 | mask_human = np.zeros((512, 512))
171 | for i in range(len(labels)):
172 | if scores[i] > 0.8:
173 | if labels[i] == 'Face':
174 | if np.sum(masks[i]) > np.sum(mask_face):
175 | mask_face = masks[i]
176 | elif labels[i] == 'Human':
177 | if np.sum(masks[i]) > np.sum(mask_human):
178 | mask_human = masks[i]
179 | elif labels[i] == 'Hair':
180 | if np.sum(masks[i]) > np.sum(mask_hair):
181 | mask_hair = masks[i]
182 | mask_head = np.clip(mask_hair + mask_face, 0, 1)
183 | ksize = max(int(np.sqrt(np.sum(mask_face)) / 20), 1)
184 | kernel = np.ones((ksize, ksize))
185 | mask_head = cv2.dilate(mask_head, kernel, iterations=1) * mask_human
186 | _, mask_head = cv2.threshold((mask_head * 255).astype(np.uint8), 127, 255, cv2.THRESH_BINARY)
187 | contours, hierarchy = cv2.findContours(mask_head, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
188 | area = []
189 | for j in range(len(contours)):
190 | area.append(cv2.contourArea(contours[j]))
191 | max_idx = np.argmax(area)
192 | mask_head = np.zeros((512, 512)).astype(np.uint8)
193 | cv2.fillPoly(mask_head, [contours[max_idx]], 255)
194 | mask_head = mask_head.astype(np.float32) / 255
195 | mask_head = np.clip(mask_head + mask_face, 0, 1)
196 | mask_head = np.expand_dims(mask_head, 2)
197 | return mask_head
198 |
199 |
200 | class Blipv2():  # labeling pipeline: detect, align and crop the face, retouch skin, mask the head, and tag with DeepDanbooru
201 | def __init__(self):
202 | self.model = DeepDanbooru()
203 | self.skin_retouching = pipeline(Tasks.skin_retouching, model='damo/cv_unet_skin-retouching')
204 | self.face_detection = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd')
205 | # self.mog_face_detection_func = pipeline(Tasks.face_detection, 'damo/cv_resnet101_face-detection_cvpr22papermogface')
206 | self.segmentation_pipeline = pipeline(Tasks.image_segmentation,
207 | 'damo/cv_resnet101_image-multiple-human-parsing')
208 | self.fair_face_attribute_func = pipeline(Tasks.face_attribute_recognition,
209 | 'damo/cv_resnet34_face-attribute-recognition_fairface')
210 | self.facial_landmark_confidence_func = pipeline(Tasks.face_2d_keypoints,
211 | 'damo/cv_manual_facial-landmark-confidence_flcm')
212 |
213 | def __call__(self, imdir):
214 | self.model.start()
215 | savedir = str(imdir) + '_labeled'
216 | shutil.rmtree(savedir, ignore_errors=True)
217 | os.makedirs(savedir, exist_ok=True)
218 |
219 | imlist = os.listdir(imdir)
220 | result_list = []
221 | imgs_list = []
222 |
223 | cnt = 0
224 | tmp_path = os.path.join(savedir, 'tmp.png')
225 | for imname in imlist:
226 | try:
227 | # if 1:
228 | if imname.startswith('.'):
229 | continue
230 | img_path = os.path.join(imdir, imname)
231 | im = cv2.imread(img_path)
232 | h, w, _ = im.shape
233 | max_size = max(w, h)
234 | ratio = 1024 / max_size
235 | new_w = round(w * ratio)
236 | new_h = round(h * ratio)
237 | imt = cv2.resize(im, (new_w, new_h))
238 | cv2.imwrite(tmp_path, imt)
239 | result_det = self.face_detection(tmp_path)
240 | bboxes = result_det['boxes']
241 | if len(bboxes) > 1:
242 | areas = []
243 | for i in range(len(bboxes)):
244 | bbox = bboxes[i]
245 | areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))
246 | areas = np.array(areas)
247 | areas_new = np.sort(areas)[::-1]
248 | idxs = np.argsort(areas)[::-1]
249 | if areas_new[0] < 4 * areas_new[1]:
250 | print('Detecting multiple faces, do not use image {}.'.format(imname))
251 | continue
252 | else:
253 | keypoints = result_det['keypoints'][idxs[0]]
254 | elif len(bboxes) == 0:
255 | print('Detecting no face, do not use image {}.'.format(imname))
256 | continue
257 | else:
258 | keypoints = result_det['keypoints'][0]
259 |
260 | im = rotate(im, keypoints)
261 | ns = im.shape[0]
262 | imt = cv2.resize(im, (1024, 1024))
263 | cv2.imwrite(tmp_path, imt)
264 | result_det = self.face_detection(tmp_path)
265 | bboxes = result_det['boxes']
266 |
267 | if len(bboxes) > 1:
268 | areas = []
269 | for i in range(len(bboxes)):
270 | bbox = bboxes[i]
271 | areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))
272 | areas = np.array(areas)
273 | areas_new = np.sort(areas)[::-1]
274 | idxs = np.argsort(areas)[::-1]
275 | if areas_new[0] < 4 * areas_new[1]:
276 | print('Detecting multiple faces after rotation, do not use image {}.'.format(imname))
277 | continue
278 | else:
279 | bbox = bboxes[idxs[0]]
280 | elif len(bboxes) == 0:
281 | print('Detecting no face after rotation, do not use this image {}'.format(imname))
282 | continue
283 | else:
284 | bbox = bboxes[0]
285 |
286 | for idx in range(4):
287 | bbox[idx] = bbox[idx] * ns / 1024
288 | imr = crop_and_resize(im, bbox)
289 | cv2.imwrite(tmp_path, imr)
290 |
291 | result = self.skin_retouching(tmp_path)
292 | if (result is None or (result[OutputKeys.OUTPUT_IMG] is None)):
293 | print('Cannot do skin retouching, do not use this image.')
294 | continue
295 | cv2.imwrite(tmp_path, result[OutputKeys.OUTPUT_IMG])
296 |
297 | result = self.segmentation_pipeline(tmp_path)
298 | mask_head = get_mask_head(result)
299 | im = cv2.imread(tmp_path)
300 | im = im * mask_head + 255 * (1 - mask_head)
301 | # print(im.shape)
302 |
303 | raw_result = self.facial_landmark_confidence_func(im)
304 | if raw_result is None:
305 | print('landmark quality fail...')
306 | continue
307 |
308 | print(imname, raw_result['scores'][0])
309 | if float(raw_result['scores'][0]) < (1 - 0.145):
310 | print('landmark quality fail...')
311 | continue
312 |
313 | cv2.imwrite(os.path.join(savedir, '{}.png'.format(cnt)), im)
314 | imgs_list.append('{}.png'.format(cnt))
315 | img = Image.open(os.path.join(savedir, '{}.png'.format(cnt)))
316 | result = self.model.tag(img)
317 | print(result)
318 | attribute_result = self.fair_face_attribute_func(tmp_path)
319 | if cnt == 0:
320 | score_gender = np.array(attribute_result['scores'][0])
321 | score_age = np.array(attribute_result['scores'][1])
322 | else:
323 | score_gender += np.array(attribute_result['scores'][0])
324 | score_age += np.array(attribute_result['scores'][1])
325 |
326 | result_list.append(result.split(', '))
327 | cnt += 1
328 |             except Exception as e:
329 |                 print('Exception caught while processing image ' + imname)
330 |                 print('Error: ' + str(e))
331 |
332 | print(result_list)
333 | if len(result_list) == 0:
334 |             print('Error: no valid face image was found; aborting.')
335 | exit()
336 | return os.path.join(savedir, "metadata.jsonl")
337 |
338 | result_list = post_process_naive(result_list, score_gender, score_age)
339 | self.model.stop()
340 | os.system('rm ' + tmp_path)
341 |
342 | out_json_name = os.path.join(savedir, "metadata.jsonl")
343 | fo = open(out_json_name, 'w')
344 | for i in range(len(result_list)):
345 | generated_text = ", ".join(result_list[i])
346 | print(imgs_list[i], generated_text)
347 | info_dict = {"file_name": imgs_list[i], "text": ", " + generated_text}
348 | fo.write(json.dumps(info_dict) + '\n')
349 | fo.close()
350 | return out_json_name
351 |
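352 | # Minimal usage sketch (the directory name below is hypothetical): calling a Blipv2
353 | # instance on a folder of photos writes head-masked 512x512 crops plus a
354 | # metadata.jsonl caption file into '<imdir>_labeled' and returns the jsonl path.
355 | #
356 | #     labeler = Blipv2()
357 | #     metadata_path = labeler('my_photos')
358 | #     print('labels written to', metadata_path)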
--------------------------------------------------------------------------------
/facechain/data_process/preprocessing -win10.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 |
3 | import os
4 | import shutil
5 |
6 | import numpy as np
7 | import json
8 | import math
9 |
10 | from .deepbooru import DeepDanbooru
11 |
12 | import cv2
13 | from PIL import Image
14 | from modelscope.pipelines import pipeline
15 | from modelscope.outputs import OutputKeys
16 | from modelscope.utils.constant import Tasks
17 |
18 |
19 | def crop_and_resize(im, bbox, thres=0.35, thres1=0.45):
20 | h, w, _ = im.shape
21 | thre = np.random.rand() * (thres1 - thres) + thres
22 | maxf = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
23 | cx = (bbox[2] + bbox[0]) / 2
24 | cy = (bbox[3] + bbox[1]) / 2
25 | lenp = int(maxf / thre)
26 | yc = np.random.rand() * 0.15 + 0.35
27 | xc = 0.5
28 | xmin = int(cx - xc * lenp)
29 | xmax = xmin + lenp
30 | ymin = int(cy - yc * lenp)
31 | ymax = ymin + lenp
32 | x1 = 0
33 | x2 = lenp
34 | y1 = 0
35 | y2 = lenp
36 | if xmin < 0:
37 | x1 = -xmin
38 | xmin = 0
39 | if xmax > w:
40 | x2 = w - (xmax - lenp)
41 | xmax = w
42 | if ymin < 0:
43 | y1 = -ymin
44 | ymin = 0
45 | if ymax > h:
46 | y2 = h - (ymax - lenp)
47 | ymax = h
48 | imc = (np.ones((lenp, lenp, 3)) * 255).astype(np.uint8)
49 | imc[y1:y2, x1:x2, :] = im[ymin:ymax, xmin:xmax, :]
50 | imr = cv2.resize(imc, (512, 512))
51 | return imr
52 |
53 |
54 | def pad_to_square(im):
55 | h, w, _ = im.shape
56 | ns = int(max(h, w) * 1.5)
57 | im = cv2.copyMakeBorder(im, int((ns - h) / 2), (ns - h) - int((ns - h) / 2), int((ns - w) / 2),
58 | (ns - w) - int((ns - w) / 2), cv2.BORDER_CONSTANT, 255)
59 | return im
60 |
61 |
62 | def post_process_naive(result_list, score_gender, score_age):
63 | # determine trigger word
64 | gender = np.argmax(score_gender)
65 | age = np.argmax(score_age)
66 | if age < 2:
67 | if gender == 0:
68 | tag_a_g = ['a boy', 'children']
69 | else:
70 | tag_a_g = ['a girl', 'children']
71 | elif age > 4:
72 | if gender == 0:
73 | tag_a_g = ['a mature man']
74 | else:
75 | tag_a_g = ['a mature woman']
76 | else:
77 | if gender == 0:
78 | tag_a_g = ['a handsome man']
79 | else:
80 | tag_a_g = ['a beautiful woman']
81 | num_images = len(result_list)
82 | cnt_girl = 0
83 | cnt_boy = 0
84 | result_list_new = []
85 | for result in result_list:
86 | result_new = []
87 | result_new.extend(tag_a_g)
88 | for tag in result:
89 | if tag == '1girl' or tag == '1boy':
90 | continue
91 | if tag[-4:] == '_man':
92 | continue
93 | if tag[-6:] == '_woman':
94 | continue
95 | if tag[-5:] == '_male':
96 | continue
97 | elif tag[-7:] == '_female':
98 | continue
99 | elif (
100 | tag == 'ears' or tag == 'head' or tag == 'face' or tag == 'lips' or tag == 'mouth' or tag == '3d' or tag == 'asian' or tag == 'teeth'):
101 | continue
102 | elif ('eye' in tag and not 'eyewear' in tag):
103 | continue
104 | elif ('nose' in tag or 'body' in tag):
105 | continue
106 | elif tag[-5:] == '_lips':
107 | continue
108 | else:
109 | result_new.append(tag)
110 | # import pdb;pdb.set_trace()
111 | # result_new.append('slim body')
112 | result_list_new.append(result_new)
113 |
114 | return result_list_new
115 |
116 |
117 | def transformation_from_points(points1, points2):
118 | points1 = points1.astype(np.float64)
119 | points2 = points2.astype(np.float64)
120 | c1 = np.mean(points1, axis=0)
121 | c2 = np.mean(points2, axis=0)
122 | points1 -= c1
123 | points2 -= c2
124 | s1 = np.std(points1)
125 | s2 = np.std(points2)
126 | if s1 < 1.0e-4:
127 | s1 = 1.0e-4
128 | points1 /= s1
129 | points2 /= s2
130 | U, S, Vt = np.linalg.svd(points1.T * points2)
131 | R = (U * Vt).T
132 | return np.vstack([np.hstack(((s2 / s1) * R, c2.T - (s2 / s1) * R * c1.T)), np.matrix([0., 0., 1.])])
133 |
134 |
135 | def rotate(im, keypoints):
136 | h, w, _ = im.shape
137 | points_array = np.zeros((5, 2))
138 | dst_mean_face_size = 160
139 | dst_mean_face = np.asarray([0.31074522411511746, 0.2798131190011913,
140 | 0.6892073313037804, 0.2797830232679366,
141 | 0.49997367716346774, 0.5099309118810921,
142 | 0.35811903020866753, 0.7233174007629063,
143 | 0.6418878095835022, 0.7232890570786875])
144 | dst_mean_face = np.reshape(dst_mean_face, (5, 2)) * dst_mean_face_size
145 |
146 | for k in range(5):
147 | points_array[k, 0] = keypoints[2 * k]
148 | points_array[k, 1] = keypoints[2 * k + 1]
149 |
150 | pts1 = np.float64(np.matrix([[point[0], point[1]] for point in points_array]))
151 | pts2 = np.float64(np.matrix([[point[0], point[1]] for point in dst_mean_face]))
152 | trans_mat = transformation_from_points(pts1, pts2)
153 | if trans_mat[1, 1] > 1.0e-4:
154 | angle = math.atan(trans_mat[1, 0] / trans_mat[1, 1])
155 | else:
156 | angle = math.atan(trans_mat[0, 1] / trans_mat[0, 2])
157 | im = pad_to_square(im)
158 | ns = int(1.5 * max(h, w))
159 | M = cv2.getRotationMatrix2D((ns / 2, ns / 2), angle=-angle / np.pi * 180, scale=1.0)
160 | im = cv2.warpAffine(im, M=M, dsize=(ns, ns))
161 | return im
162 |
163 |
164 | def get_mask_head(result):
165 | masks = result['masks']
166 | scores = result['scores']
167 | labels = result['labels']
168 | mask_hair = np.zeros((512, 512))
169 | mask_face = np.zeros((512, 512))
170 | mask_human = np.zeros((512, 512))
171 | for i in range(len(labels)):
172 | if scores[i] > 0.8:
173 | if labels[i] == 'Face':
174 | if np.sum(masks[i]) > np.sum(mask_face):
175 | mask_face = masks[i]
176 | elif labels[i] == 'Human':
177 | if np.sum(masks[i]) > np.sum(mask_human):
178 | mask_human = masks[i]
179 | elif labels[i] == 'Hair':
180 | if np.sum(masks[i]) > np.sum(mask_hair):
181 | mask_hair = masks[i]
182 | mask_head = np.clip(mask_hair + mask_face, 0, 1)
183 | ksize = max(int(np.sqrt(np.sum(mask_face)) / 20), 1)
184 | kernel = np.ones((ksize, ksize))
185 | mask_head = cv2.dilate(mask_head, kernel, iterations=1) * mask_human
186 | _, mask_head = cv2.threshold((mask_head * 255).astype(np.uint8), 127, 255, cv2.THRESH_BINARY)
187 | contours, hierarchy = cv2.findContours(mask_head, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
188 | area = []
189 | for j in range(len(contours)):
190 | area.append(cv2.contourArea(contours[j]))
191 | max_idx = np.argmax(area)
192 | mask_head = np.zeros((512, 512)).astype(np.uint8)
193 | cv2.fillPoly(mask_head, [contours[max_idx]], 255)
194 | mask_head = mask_head.astype(np.float32) / 255
195 | mask_head = np.clip(mask_head + mask_face, 0, 1)
196 | mask_head = np.expand_dims(mask_head, 2)
197 | return mask_head
198 |
199 |
200 | class Blipv2():
201 | def __init__(self):
202 | self.model = DeepDanbooru()
203 | self.skin_retouching = pipeline(Tasks.skin_retouching, model='damo/cv_unet_skin-retouching')
204 | self.face_detection = pipeline(task=Tasks.face_detection, model='damo/cv_ddsar_face-detection_iclr23-damofd')
205 | # self.mog_face_detection_func = pipeline(Tasks.face_detection, 'damo/cv_resnet101_face-detection_cvpr22papermogface')
206 | self.segmentation_pipeline = pipeline(Tasks.image_segmentation,
207 | 'damo/cv_resnet101_image-multiple-human-parsing')
208 | self.fair_face_attribute_func = pipeline(Tasks.face_attribute_recognition,
209 | 'damo/cv_resnet34_face-attribute-recognition_fairface')
210 | self.facial_landmark_confidence_func = pipeline(Tasks.face_2d_keypoints,
211 | 'damo/cv_manual_facial-landmark-confidence_flcm')
212 |
213 | def __call__(self, imdir):
214 | self.model.start()
215 | savedir = str(imdir) + '_labeled'
216 | shutil.rmtree(savedir, ignore_errors=True)
217 | os.makedirs(savedir, exist_ok=True)
218 |
219 | imlist = os.listdir(imdir)
220 | result_list = []
221 | imgs_list = []
222 |
223 | cnt = 0
224 | tmp_path = os.path.join(savedir, 'tmp.png')
225 | for imname in imlist:
226 | try:
227 | # if 1:
228 | if imname.startswith('.'):
229 | continue
230 | img_path = os.path.join(imdir, imname)
231 | im = cv2.imread(img_path)
232 | h, w, _ = im.shape
233 | max_size = max(w, h)
234 | ratio = 1024 / max_size
235 | new_w = round(w * ratio)
236 | new_h = round(h * ratio)
237 | imt = cv2.resize(im, (new_w, new_h))
238 | cv2.imwrite(tmp_path, imt)
239 | result_det = self.face_detection(tmp_path)
240 | bboxes = result_det['boxes']
241 | if len(bboxes) > 1:
242 | areas = []
243 | for i in range(len(bboxes)):
244 | bbox = bboxes[i]
245 | areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))
246 | areas = np.array(areas)
247 | areas_new = np.sort(areas)[::-1]
248 | idxs = np.argsort(areas)[::-1]
249 | if areas_new[0] < 4 * areas_new[1]:
250 | print('Detecting multiple faces, do not use image {}.'.format(imname))
251 | continue
252 | else:
253 | keypoints = result_det['keypoints'][idxs[0]]
254 | elif len(bboxes) == 0:
255 | print('Detecting no face, do not use image {}.'.format(imname))
256 | continue
257 | else:
258 | keypoints = result_det['keypoints'][0]
259 |
260 | im = rotate(im, keypoints)
261 | ns = im.shape[0]
262 | imt = cv2.resize(im, (1024, 1024))
263 | cv2.imwrite(tmp_path, imt)
264 | result_det = self.face_detection(tmp_path)
265 | bboxes = result_det['boxes']
266 |
267 | if len(bboxes) > 1:
268 | areas = []
269 | for i in range(len(bboxes)):
270 | bbox = bboxes[i]
271 | areas.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))
272 | areas = np.array(areas)
273 | areas_new = np.sort(areas)[::-1]
274 | idxs = np.argsort(areas)[::-1]
275 | if areas_new[0] < 4 * areas_new[1]:
276 | print('Detecting multiple faces after rotation, do not use image {}.'.format(imname))
277 | continue
278 | else:
279 | bbox = bboxes[idxs[0]]
280 | elif len(bboxes) == 0:
281 | print('Detecting no face after rotation, do not use this image {}'.format(imname))
282 | continue
283 | else:
284 | bbox = bboxes[0]
285 |
286 | for idx in range(4):
287 | bbox[idx] = bbox[idx] * ns / 1024
288 | imr = crop_and_resize(im, bbox)
289 | cv2.imwrite(tmp_path, imr)
290 |
291 | result = self.skin_retouching(tmp_path)
292 | if (result is None or (result[OutputKeys.OUTPUT_IMG] is None)):
293 | print('Cannot do skin retouching, do not use this image.')
294 | continue
295 | cv2.imwrite(tmp_path, result[OutputKeys.OUTPUT_IMG])
296 |
297 | result = self.segmentation_pipeline(tmp_path)
298 | mask_head = get_mask_head(result)
299 | im = cv2.imread(tmp_path)
300 | im = im * mask_head + 255 * (1 - mask_head)
301 | # print(im.shape)
302 |
303 | raw_result = self.facial_landmark_confidence_func(im)
304 | if raw_result is None:
305 | print('landmark quality fail...')
306 | continue
307 |
308 | print(imname, raw_result['scores'][0])
309 | if float(raw_result['scores'][0]) < (1 - 0.145):
310 | print('landmark quality fail...')
311 | continue
312 |
313 | cv2.imwrite(os.path.join(savedir, '{}.png'.format(cnt)), im)
314 | imgs_list.append('{}.png'.format(cnt))
315 | img = Image.open(os.path.join(savedir, '{}.png'.format(cnt)))
316 | result = self.model.tag(img)
317 | print(result)
318 | attribute_result = self.fair_face_attribute_func(tmp_path)
319 | if cnt == 0:
320 | score_gender = np.array(attribute_result['scores'][0])
321 | score_age = np.array(attribute_result['scores'][1])
322 | else:
323 | score_gender += np.array(attribute_result['scores'][0])
324 | score_age += np.array(attribute_result['scores'][1])
325 |
326 | result_list.append(result.split(', '))
327 | cnt += 1
328 |             except Exception as e:
329 |                 print('Exception caught while processing image ' + imname)
330 |                 print('Error: ' + str(e))
331 |
332 | print(result_list)
333 | if len(result_list) == 0:
334 |             print('Error: no valid face image was found; aborting.')
335 | exit()
336 | return os.path.join(savedir, "metadata.jsonl")
337 |
338 | result_list = post_process_naive(result_list, score_gender, score_age)
339 | self.model.stop()
340 | # os.system('rm ' + tmp_path)
341 | os.system('del ' + tmp_path)
342 |
343 | out_json_name = os.path.join(savedir, "metadata.jsonl")
344 | fo = open(out_json_name, 'w')
345 | for i in range(len(result_list)):
346 | generated_text = ", ".join(result_list[i])
347 | print(imgs_list[i], generated_text)
348 | info_dict = {"file_name": imgs_list[i], "text": ", " + generated_text}
349 | fo.write(json.dumps(info_dict) + '\n')
350 | fo.close()
351 | return out_json_name
352 |
--------------------------------------------------------------------------------
/facechain/data_process/deepbooru.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Alibaba, Inc. and its affiliates.
2 |
3 | import os
4 | import re
5 |
6 | from PIL import Image
7 | import numpy as np
8 |
9 | re_special = re.compile(r'([\\()])')
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | from modelscope.hub.snapshot_download import snapshot_download
15 |
16 | # see https://github.com/AUTOMATIC1111/TorchDeepDanbooru for more
17 | LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
18 |
19 |
20 | class DeepDanbooruModel(nn.Module):  # ResNet-style bottleneck tagger (see link above); the final 1x1 conv emits 9176 tag scores
21 | def __init__(self):
22 | super(DeepDanbooruModel, self).__init__()
23 |
24 | self.tags = []
25 |
26 | self.n_Conv_0 = nn.Conv2d(kernel_size=(7, 7), in_channels=3, out_channels=64, stride=(2, 2))
27 | self.n_MaxPool_0 = nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2))
28 | self.n_Conv_1 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
29 | self.n_Conv_2 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=64)
30 | self.n_Conv_3 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
31 | self.n_Conv_4 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
32 | self.n_Conv_5 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
33 | self.n_Conv_6 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
34 | self.n_Conv_7 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
35 | self.n_Conv_8 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=64)
36 | self.n_Conv_9 = nn.Conv2d(kernel_size=(3, 3), in_channels=64, out_channels=64)
37 | self.n_Conv_10 = nn.Conv2d(kernel_size=(1, 1), in_channels=64, out_channels=256)
38 | self.n_Conv_11 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=512, stride=(2, 2))
39 | self.n_Conv_12 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=128)
40 | self.n_Conv_13 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128, stride=(2, 2))
41 | self.n_Conv_14 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
42 | self.n_Conv_15 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
43 | self.n_Conv_16 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
44 | self.n_Conv_17 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
45 | self.n_Conv_18 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
46 | self.n_Conv_19 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
47 | self.n_Conv_20 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
48 | self.n_Conv_21 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
49 | self.n_Conv_22 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
50 | self.n_Conv_23 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
51 | self.n_Conv_24 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
52 | self.n_Conv_25 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
53 | self.n_Conv_26 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
54 | self.n_Conv_27 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
55 | self.n_Conv_28 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
56 | self.n_Conv_29 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
57 | self.n_Conv_30 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
58 | self.n_Conv_31 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
59 | self.n_Conv_32 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
60 | self.n_Conv_33 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=128)
61 | self.n_Conv_34 = nn.Conv2d(kernel_size=(3, 3), in_channels=128, out_channels=128)
62 | self.n_Conv_35 = nn.Conv2d(kernel_size=(1, 1), in_channels=128, out_channels=512)
63 | self.n_Conv_36 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=1024, stride=(2, 2))
64 | self.n_Conv_37 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=256)
65 | self.n_Conv_38 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
66 | self.n_Conv_39 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
67 | self.n_Conv_40 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
68 | self.n_Conv_41 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
69 | self.n_Conv_42 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
70 | self.n_Conv_43 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
71 | self.n_Conv_44 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
72 | self.n_Conv_45 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
73 | self.n_Conv_46 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
74 | self.n_Conv_47 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
75 | self.n_Conv_48 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
76 | self.n_Conv_49 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
77 | self.n_Conv_50 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
78 | self.n_Conv_51 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
79 | self.n_Conv_52 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
80 | self.n_Conv_53 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
81 | self.n_Conv_54 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
82 | self.n_Conv_55 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
83 | self.n_Conv_56 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
84 | self.n_Conv_57 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
85 | self.n_Conv_58 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
86 | self.n_Conv_59 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
87 | self.n_Conv_60 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
88 | self.n_Conv_61 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
89 | self.n_Conv_62 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
90 | self.n_Conv_63 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
91 | self.n_Conv_64 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
92 | self.n_Conv_65 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
93 | self.n_Conv_66 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
94 | self.n_Conv_67 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
95 | self.n_Conv_68 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
96 | self.n_Conv_69 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
97 | self.n_Conv_70 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
98 | self.n_Conv_71 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
99 | self.n_Conv_72 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
100 | self.n_Conv_73 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
101 | self.n_Conv_74 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
102 | self.n_Conv_75 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
103 | self.n_Conv_76 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
104 | self.n_Conv_77 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
105 | self.n_Conv_78 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
106 | self.n_Conv_79 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
107 | self.n_Conv_80 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
108 | self.n_Conv_81 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
109 | self.n_Conv_82 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
110 | self.n_Conv_83 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
111 | self.n_Conv_84 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
112 | self.n_Conv_85 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
113 | self.n_Conv_86 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
114 | self.n_Conv_87 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
115 | self.n_Conv_88 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
116 | self.n_Conv_89 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
117 | self.n_Conv_90 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
118 | self.n_Conv_91 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
119 | self.n_Conv_92 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
120 | self.n_Conv_93 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
121 | self.n_Conv_94 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
122 | self.n_Conv_95 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
123 | self.n_Conv_96 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
124 | self.n_Conv_97 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
125 | self.n_Conv_98 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256, stride=(2, 2))
126 | self.n_Conv_99 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
127 | self.n_Conv_100 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=1024, stride=(2, 2))
128 | self.n_Conv_101 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
129 | self.n_Conv_102 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
130 | self.n_Conv_103 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
131 | self.n_Conv_104 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
132 | self.n_Conv_105 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
133 | self.n_Conv_106 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
134 | self.n_Conv_107 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
135 | self.n_Conv_108 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
136 | self.n_Conv_109 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
137 | self.n_Conv_110 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
138 | self.n_Conv_111 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
139 | self.n_Conv_112 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
140 | self.n_Conv_113 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
141 | self.n_Conv_114 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
142 | self.n_Conv_115 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
143 | self.n_Conv_116 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
144 | self.n_Conv_117 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
145 | self.n_Conv_118 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
146 | self.n_Conv_119 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
147 | self.n_Conv_120 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
148 | self.n_Conv_121 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
149 | self.n_Conv_122 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
150 | self.n_Conv_123 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
151 | self.n_Conv_124 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
152 | self.n_Conv_125 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
153 | self.n_Conv_126 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
154 | self.n_Conv_127 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
155 | self.n_Conv_128 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
156 | self.n_Conv_129 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
157 | self.n_Conv_130 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
158 | self.n_Conv_131 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
159 | self.n_Conv_132 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
160 | self.n_Conv_133 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
161 | self.n_Conv_134 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
162 | self.n_Conv_135 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
163 | self.n_Conv_136 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
164 | self.n_Conv_137 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
165 | self.n_Conv_138 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
166 | self.n_Conv_139 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
167 | self.n_Conv_140 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
168 | self.n_Conv_141 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
169 | self.n_Conv_142 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
170 | self.n_Conv_143 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
171 | self.n_Conv_144 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
172 | self.n_Conv_145 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
173 | self.n_Conv_146 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
174 | self.n_Conv_147 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
175 | self.n_Conv_148 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
176 | self.n_Conv_149 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
177 | self.n_Conv_150 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
178 | self.n_Conv_151 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
179 | self.n_Conv_152 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
180 | self.n_Conv_153 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
181 | self.n_Conv_154 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
182 | self.n_Conv_155 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=256)
183 | self.n_Conv_156 = nn.Conv2d(kernel_size=(3, 3), in_channels=256, out_channels=256)
184 | self.n_Conv_157 = nn.Conv2d(kernel_size=(1, 1), in_channels=256, out_channels=1024)
185 | self.n_Conv_158 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=2048, stride=(2, 2))
186 | self.n_Conv_159 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=512)
187 | self.n_Conv_160 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512, stride=(2, 2))
188 | self.n_Conv_161 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
189 | self.n_Conv_162 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
190 | self.n_Conv_163 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
191 | self.n_Conv_164 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
192 | self.n_Conv_165 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=512)
193 | self.n_Conv_166 = nn.Conv2d(kernel_size=(3, 3), in_channels=512, out_channels=512)
194 | self.n_Conv_167 = nn.Conv2d(kernel_size=(1, 1), in_channels=512, out_channels=2048)
195 | self.n_Conv_168 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=4096, stride=(2, 2))
196 | self.n_Conv_169 = nn.Conv2d(kernel_size=(1, 1), in_channels=2048, out_channels=1024)
197 | self.n_Conv_170 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024, stride=(2, 2))
198 | self.n_Conv_171 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
199 | self.n_Conv_172 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
200 | self.n_Conv_173 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
201 | self.n_Conv_174 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
202 | self.n_Conv_175 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=1024)
203 | self.n_Conv_176 = nn.Conv2d(kernel_size=(3, 3), in_channels=1024, out_channels=1024)
204 | self.n_Conv_177 = nn.Conv2d(kernel_size=(1, 1), in_channels=1024, out_channels=4096)
205 | self.n_Conv_178 = nn.Conv2d(kernel_size=(1, 1), in_channels=4096, out_channels=9176, bias=False)
206 |
207 | def forward(self, *inputs):
208 | t_358, = inputs
209 | t_359 = t_358.permute(*[0, 3, 1, 2])
210 | t_359_padded = F.pad(t_359, [2, 3, 2, 3], value=0)
211 | # t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype) if devices.unet_needs_upcast else t_359_padded)
212 | t_360 = self.n_Conv_0(t_359_padded.to(self.n_Conv_0.bias.dtype))
213 | t_361 = F.relu(t_360)
214 | t_361 = F.pad(t_361, [0, 1, 0, 1], value=float('-inf'))
215 | t_362 = self.n_MaxPool_0(t_361)
216 | t_363 = self.n_Conv_1(t_362)
217 | t_364 = self.n_Conv_2(t_362)
218 | t_365 = F.relu(t_364)
219 | t_365_padded = F.pad(t_365, [1, 1, 1, 1], value=0)
220 | t_366 = self.n_Conv_3(t_365_padded)
221 | t_367 = F.relu(t_366)
222 | t_368 = self.n_Conv_4(t_367)
223 | t_369 = torch.add(t_368, t_363)
224 | t_370 = F.relu(t_369)
225 | t_371 = self.n_Conv_5(t_370)
226 | t_372 = F.relu(t_371)
227 | t_372_padded = F.pad(t_372, [1, 1, 1, 1], value=0)
228 | t_373 = self.n_Conv_6(t_372_padded)
229 | t_374 = F.relu(t_373)
230 | t_375 = self.n_Conv_7(t_374)
231 | t_376 = torch.add(t_375, t_370)
232 | t_377 = F.relu(t_376)
233 | t_378 = self.n_Conv_8(t_377)
234 | t_379 = F.relu(t_378)
235 | t_379_padded = F.pad(t_379, [1, 1, 1, 1], value=0)
236 | t_380 = self.n_Conv_9(t_379_padded)
237 | t_381 = F.relu(t_380)
238 | t_382 = self.n_Conv_10(t_381)
239 | t_383 = torch.add(t_382, t_377)
240 | t_384 = F.relu(t_383)
241 | t_385 = self.n_Conv_11(t_384)
242 | t_386 = self.n_Conv_12(t_384)
243 | t_387 = F.relu(t_386)
244 | t_387_padded = F.pad(t_387, [0, 1, 0, 1], value=0)
245 | t_388 = self.n_Conv_13(t_387_padded)
246 | t_389 = F.relu(t_388)
247 | t_390 = self.n_Conv_14(t_389)
248 | t_391 = torch.add(t_390, t_385)
249 | t_392 = F.relu(t_391)
250 | t_393 = self.n_Conv_15(t_392)
251 | t_394 = F.relu(t_393)
252 | t_394_padded = F.pad(t_394, [1, 1, 1, 1], value=0)
253 | t_395 = self.n_Conv_16(t_394_padded)
254 | t_396 = F.relu(t_395)
255 | t_397 = self.n_Conv_17(t_396)
256 | t_398 = torch.add(t_397, t_392)
257 | t_399 = F.relu(t_398)
258 | t_400 = self.n_Conv_18(t_399)
259 | t_401 = F.relu(t_400)
260 | t_401_padded = F.pad(t_401, [1, 1, 1, 1], value=0)
261 | t_402 = self.n_Conv_19(t_401_padded)
262 | t_403 = F.relu(t_402)
263 | t_404 = self.n_Conv_20(t_403)
264 | t_405 = torch.add(t_404, t_399)
265 | t_406 = F.relu(t_405)
266 | t_407 = self.n_Conv_21(t_406)
267 | t_408 = F.relu(t_407)
268 | t_408_padded = F.pad(t_408, [1, 1, 1, 1], value=0)
269 | t_409 = self.n_Conv_22(t_408_padded)
270 | t_410 = F.relu(t_409)
271 | t_411 = self.n_Conv_23(t_410)
272 | t_412 = torch.add(t_411, t_406)
273 | t_413 = F.relu(t_412)
274 | t_414 = self.n_Conv_24(t_413)
275 | t_415 = F.relu(t_414)
276 | t_415_padded = F.pad(t_415, [1, 1, 1, 1], value=0)
277 | t_416 = self.n_Conv_25(t_415_padded)
278 | t_417 = F.relu(t_416)
279 | t_418 = self.n_Conv_26(t_417)
280 | t_419 = torch.add(t_418, t_413)
281 | t_420 = F.relu(t_419)
282 | t_421 = self.n_Conv_27(t_420)
283 | t_422 = F.relu(t_421)
284 | t_422_padded = F.pad(t_422, [1, 1, 1, 1], value=0)
285 | t_423 = self.n_Conv_28(t_422_padded)
286 | t_424 = F.relu(t_423)
287 | t_425 = self.n_Conv_29(t_424)
288 | t_426 = torch.add(t_425, t_420)
289 | t_427 = F.relu(t_426)
290 | t_428 = self.n_Conv_30(t_427)
291 | t_429 = F.relu(t_428)
292 | t_429_padded = F.pad(t_429, [1, 1, 1, 1], value=0)
293 | t_430 = self.n_Conv_31(t_429_padded)
294 | t_431 = F.relu(t_430)
295 | t_432 = self.n_Conv_32(t_431)
296 | t_433 = torch.add(t_432, t_427)
297 | t_434 = F.relu(t_433)
298 | t_435 = self.n_Conv_33(t_434)
299 | t_436 = F.relu(t_435)
300 | t_436_padded = F.pad(t_436, [1, 1, 1, 1], value=0)
301 | t_437 = self.n_Conv_34(t_436_padded)
302 | t_438 = F.relu(t_437)
303 | t_439 = self.n_Conv_35(t_438)
304 | t_440 = torch.add(t_439, t_434)
305 | t_441 = F.relu(t_440)
306 | t_442 = self.n_Conv_36(t_441)
307 | t_443 = self.n_Conv_37(t_441)
308 | t_444 = F.relu(t_443)
309 | t_444_padded = F.pad(t_444, [0, 1, 0, 1], value=0)
310 | t_445 = self.n_Conv_38(t_444_padded)
311 | t_446 = F.relu(t_445)
312 | t_447 = self.n_Conv_39(t_446)
313 | t_448 = torch.add(t_447, t_442)
314 | t_449 = F.relu(t_448)
315 | t_450 = self.n_Conv_40(t_449)
316 | t_451 = F.relu(t_450)
317 | t_451_padded = F.pad(t_451, [1, 1, 1, 1], value=0)
318 | t_452 = self.n_Conv_41(t_451_padded)
319 | t_453 = F.relu(t_452)
320 | t_454 = self.n_Conv_42(t_453)
321 | t_455 = torch.add(t_454, t_449)
322 | t_456 = F.relu(t_455)
323 | t_457 = self.n_Conv_43(t_456)
324 | t_458 = F.relu(t_457)
325 | t_458_padded = F.pad(t_458, [1, 1, 1, 1], value=0)
326 | t_459 = self.n_Conv_44(t_458_padded)
327 | t_460 = F.relu(t_459)
328 | t_461 = self.n_Conv_45(t_460)
329 | t_462 = torch.add(t_461, t_456)
330 | t_463 = F.relu(t_462)
331 | t_464 = self.n_Conv_46(t_463)
332 | t_465 = F.relu(t_464)
333 | t_465_padded = F.pad(t_465, [1, 1, 1, 1], value=0)
334 | t_466 = self.n_Conv_47(t_465_padded)
335 | t_467 = F.relu(t_466)
336 | t_468 = self.n_Conv_48(t_467)
337 | t_469 = torch.add(t_468, t_463)
338 | t_470 = F.relu(t_469)
339 | t_471 = self.n_Conv_49(t_470)
340 | t_472 = F.relu(t_471)
341 | t_472_padded = F.pad(t_472, [1, 1, 1, 1], value=0)
342 | t_473 = self.n_Conv_50(t_472_padded)
343 | t_474 = F.relu(t_473)
344 | t_475 = self.n_Conv_51(t_474)
345 | t_476 = torch.add(t_475, t_470)
346 | t_477 = F.relu(t_476)
347 | t_478 = self.n_Conv_52(t_477)
348 | t_479 = F.relu(t_478)
349 | t_479_padded = F.pad(t_479, [1, 1, 1, 1], value=0)
350 | t_480 = self.n_Conv_53(t_479_padded)
351 | t_481 = F.relu(t_480)
352 | t_482 = self.n_Conv_54(t_481)
353 | t_483 = torch.add(t_482, t_477)
354 | t_484 = F.relu(t_483)
355 | t_485 = self.n_Conv_55(t_484)
356 | t_486 = F.relu(t_485)
357 | t_486_padded = F.pad(t_486, [1, 1, 1, 1], value=0)
358 | t_487 = self.n_Conv_56(t_486_padded)
359 | t_488 = F.relu(t_487)
360 | t_489 = self.n_Conv_57(t_488)
361 | t_490 = torch.add(t_489, t_484)
362 | t_491 = F.relu(t_490)
363 | t_492 = self.n_Conv_58(t_491)
364 | t_493 = F.relu(t_492)
365 | t_493_padded = F.pad(t_493, [1, 1, 1, 1], value=0)
366 | t_494 = self.n_Conv_59(t_493_padded)
367 | t_495 = F.relu(t_494)
368 | t_496 = self.n_Conv_60(t_495)
369 | t_497 = torch.add(t_496, t_491)
370 | t_498 = F.relu(t_497)
371 | t_499 = self.n_Conv_61(t_498)
372 | t_500 = F.relu(t_499)
373 | t_500_padded = F.pad(t_500, [1, 1, 1, 1], value=0)
374 | t_501 = self.n_Conv_62(t_500_padded)
375 | t_502 = F.relu(t_501)
376 | t_503 = self.n_Conv_63(t_502)
377 | t_504 = torch.add(t_503, t_498)
378 | t_505 = F.relu(t_504)
379 | t_506 = self.n_Conv_64(t_505)
380 | t_507 = F.relu(t_506)
381 | t_507_padded = F.pad(t_507, [1, 1, 1, 1], value=0)
382 | t_508 = self.n_Conv_65(t_507_padded)
383 | t_509 = F.relu(t_508)
384 | t_510 = self.n_Conv_66(t_509)
385 | t_511 = torch.add(t_510, t_505)
386 | t_512 = F.relu(t_511)
387 | t_513 = self.n_Conv_67(t_512)
388 | t_514 = F.relu(t_513)
389 | t_514_padded = F.pad(t_514, [1, 1, 1, 1], value=0)
390 | t_515 = self.n_Conv_68(t_514_padded)
391 | t_516 = F.relu(t_515)
392 | t_517 = self.n_Conv_69(t_516)
393 | t_518 = torch.add(t_517, t_512)
394 | t_519 = F.relu(t_518)
395 | t_520 = self.n_Conv_70(t_519)
396 | t_521 = F.relu(t_520)
397 | t_521_padded = F.pad(t_521, [1, 1, 1, 1], value=0)
398 | t_522 = self.n_Conv_71(t_521_padded)
399 | t_523 = F.relu(t_522)
400 | t_524 = self.n_Conv_72(t_523)
401 | t_525 = torch.add(t_524, t_519)
402 | t_526 = F.relu(t_525)
403 | t_527 = self.n_Conv_73(t_526)
404 | t_528 = F.relu(t_527)
405 | t_528_padded = F.pad(t_528, [1, 1, 1, 1], value=0)
406 | t_529 = self.n_Conv_74(t_528_padded)
407 | t_530 = F.relu(t_529)
408 | t_531 = self.n_Conv_75(t_530)
409 | t_532 = torch.add(t_531, t_526)
410 | t_533 = F.relu(t_532)
411 | t_534 = self.n_Conv_76(t_533)
412 | t_535 = F.relu(t_534)
413 | t_535_padded = F.pad(t_535, [1, 1, 1, 1], value=0)
414 | t_536 = self.n_Conv_77(t_535_padded)
415 | t_537 = F.relu(t_536)
416 | t_538 = self.n_Conv_78(t_537)
417 | t_539 = torch.add(t_538, t_533)
418 | t_540 = F.relu(t_539)
419 | t_541 = self.n_Conv_79(t_540)
420 | t_542 = F.relu(t_541)
421 | t_542_padded = F.pad(t_542, [1, 1, 1, 1], value=0)
422 | t_543 = self.n_Conv_80(t_542_padded)
423 | t_544 = F.relu(t_543)
424 | t_545 = self.n_Conv_81(t_544)
425 | t_546 = torch.add(t_545, t_540)
426 | t_547 = F.relu(t_546)
427 | t_548 = self.n_Conv_82(t_547)
428 | t_549 = F.relu(t_548)
429 | t_549_padded = F.pad(t_549, [1, 1, 1, 1], value=0)
430 | t_550 = self.n_Conv_83(t_549_padded)
431 | t_551 = F.relu(t_550)
432 | t_552 = self.n_Conv_84(t_551)
433 | t_553 = torch.add(t_552, t_547)
434 | t_554 = F.relu(t_553)
435 | t_555 = self.n_Conv_85(t_554)
436 | t_556 = F.relu(t_555)
437 | t_556_padded = F.pad(t_556, [1, 1, 1, 1], value=0)
438 | t_557 = self.n_Conv_86(t_556_padded)
439 | t_558 = F.relu(t_557)
440 | t_559 = self.n_Conv_87(t_558)
441 | t_560 = torch.add(t_559, t_554)
442 | t_561 = F.relu(t_560)
443 | t_562 = self.n_Conv_88(t_561)
444 | t_563 = F.relu(t_562)
445 | t_563_padded = F.pad(t_563, [1, 1, 1, 1], value=0)
446 | t_564 = self.n_Conv_89(t_563_padded)
447 | t_565 = F.relu(t_564)
448 | t_566 = self.n_Conv_90(t_565)
449 | t_567 = torch.add(t_566, t_561)
450 | t_568 = F.relu(t_567)
451 | t_569 = self.n_Conv_91(t_568)
452 | t_570 = F.relu(t_569)
453 | t_570_padded = F.pad(t_570, [1, 1, 1, 1], value=0)
454 | t_571 = self.n_Conv_92(t_570_padded)
455 | t_572 = F.relu(t_571)
456 | t_573 = self.n_Conv_93(t_572)
457 | t_574 = torch.add(t_573, t_568)
458 | t_575 = F.relu(t_574)
459 | t_576 = self.n_Conv_94(t_575)
460 | t_577 = F.relu(t_576)
461 | t_577_padded = F.pad(t_577, [1, 1, 1, 1], value=0)
462 | t_578 = self.n_Conv_95(t_577_padded)
463 | t_579 = F.relu(t_578)
464 | t_580 = self.n_Conv_96(t_579)
465 | t_581 = torch.add(t_580, t_575)
466 | t_582 = F.relu(t_581)
467 | t_583 = self.n_Conv_97(t_582)
468 | t_584 = F.relu(t_583)
469 | t_584_padded = F.pad(t_584, [0, 1, 0, 1], value=0)
470 | t_585 = self.n_Conv_98(t_584_padded)
471 | t_586 = F.relu(t_585)
472 | t_587 = self.n_Conv_99(t_586)
473 | t_588 = self.n_Conv_100(t_582)
474 | t_589 = torch.add(t_587, t_588)
475 | t_590 = F.relu(t_589)
476 | t_591 = self.n_Conv_101(t_590)
477 | t_592 = F.relu(t_591)
478 | t_592_padded = F.pad(t_592, [1, 1, 1, 1], value=0)
479 | t_593 = self.n_Conv_102(t_592_padded)
480 | t_594 = F.relu(t_593)
481 | t_595 = self.n_Conv_103(t_594)
482 | t_596 = torch.add(t_595, t_590)
483 | t_597 = F.relu(t_596)
484 | t_598 = self.n_Conv_104(t_597)
485 | t_599 = F.relu(t_598)
486 | t_599_padded = F.pad(t_599, [1, 1, 1, 1], value=0)
487 | t_600 = self.n_Conv_105(t_599_padded)
488 | t_601 = F.relu(t_600)
489 | t_602 = self.n_Conv_106(t_601)
490 | t_603 = torch.add(t_602, t_597)
491 | t_604 = F.relu(t_603)
492 | t_605 = self.n_Conv_107(t_604)
493 | t_606 = F.relu(t_605)
494 | t_606_padded = F.pad(t_606, [1, 1, 1, 1], value=0)
495 | t_607 = self.n_Conv_108(t_606_padded)
496 | t_608 = F.relu(t_607)
497 | t_609 = self.n_Conv_109(t_608)
498 | t_610 = torch.add(t_609, t_604)
499 | t_611 = F.relu(t_610)
500 | t_612 = self.n_Conv_110(t_611)
501 | t_613 = F.relu(t_612)
502 | t_613_padded = F.pad(t_613, [1, 1, 1, 1], value=0)
503 | t_614 = self.n_Conv_111(t_613_padded)
504 | t_615 = F.relu(t_614)
505 | t_616 = self.n_Conv_112(t_615)
506 | t_617 = torch.add(t_616, t_611)
507 | t_618 = F.relu(t_617)
508 | t_619 = self.n_Conv_113(t_618)
509 | t_620 = F.relu(t_619)
510 | t_620_padded = F.pad(t_620, [1, 1, 1, 1], value=0)
511 | t_621 = self.n_Conv_114(t_620_padded)
512 | t_622 = F.relu(t_621)
513 | t_623 = self.n_Conv_115(t_622)
514 | t_624 = torch.add(t_623, t_618)
515 | t_625 = F.relu(t_624)
516 | t_626 = self.n_Conv_116(t_625)
517 | t_627 = F.relu(t_626)
518 | t_627_padded = F.pad(t_627, [1, 1, 1, 1], value=0)
519 | t_628 = self.n_Conv_117(t_627_padded)
520 | t_629 = F.relu(t_628)
521 | t_630 = self.n_Conv_118(t_629)
522 | t_631 = torch.add(t_630, t_625)
523 | t_632 = F.relu(t_631)
524 | t_633 = self.n_Conv_119(t_632)
525 | t_634 = F.relu(t_633)
526 | t_634_padded = F.pad(t_634, [1, 1, 1, 1], value=0)
527 | t_635 = self.n_Conv_120(t_634_padded)
528 | t_636 = F.relu(t_635)
529 | t_637 = self.n_Conv_121(t_636)
530 | t_638 = torch.add(t_637, t_632)
531 | t_639 = F.relu(t_638)
532 | t_640 = self.n_Conv_122(t_639)
533 | t_641 = F.relu(t_640)
534 | t_641_padded = F.pad(t_641, [1, 1, 1, 1], value=0)
535 | t_642 = self.n_Conv_123(t_641_padded)
536 | t_643 = F.relu(t_642)
537 | t_644 = self.n_Conv_124(t_643)
538 | t_645 = torch.add(t_644, t_639)
539 | t_646 = F.relu(t_645)
540 | t_647 = self.n_Conv_125(t_646)
541 | t_648 = F.relu(t_647)
542 | t_648_padded = F.pad(t_648, [1, 1, 1, 1], value=0)
543 | t_649 = self.n_Conv_126(t_648_padded)
544 | t_650 = F.relu(t_649)
545 | t_651 = self.n_Conv_127(t_650)
546 | t_652 = torch.add(t_651, t_646)
547 | t_653 = F.relu(t_652)
548 | t_654 = self.n_Conv_128(t_653)
549 | t_655 = F.relu(t_654)
550 | t_655_padded = F.pad(t_655, [1, 1, 1, 1], value=0)
551 | t_656 = self.n_Conv_129(t_655_padded)
552 | t_657 = F.relu(t_656)
553 | t_658 = self.n_Conv_130(t_657)
554 | t_659 = torch.add(t_658, t_653)
555 | t_660 = F.relu(t_659)
556 | t_661 = self.n_Conv_131(t_660)
557 | t_662 = F.relu(t_661)
558 | t_662_padded = F.pad(t_662, [1, 1, 1, 1], value=0)
559 | t_663 = self.n_Conv_132(t_662_padded)
560 | t_664 = F.relu(t_663)
561 | t_665 = self.n_Conv_133(t_664)
562 | t_666 = torch.add(t_665, t_660)
563 | t_667 = F.relu(t_666)
564 | t_668 = self.n_Conv_134(t_667)
565 | t_669 = F.relu(t_668)
566 | t_669_padded = F.pad(t_669, [1, 1, 1, 1], value=0)
567 | t_670 = self.n_Conv_135(t_669_padded)
568 | t_671 = F.relu(t_670)
569 | t_672 = self.n_Conv_136(t_671)
570 | t_673 = torch.add(t_672, t_667)
571 | t_674 = F.relu(t_673)
572 | t_675 = self.n_Conv_137(t_674)
573 | t_676 = F.relu(t_675)
574 | t_676_padded = F.pad(t_676, [1, 1, 1, 1], value=0)
575 | t_677 = self.n_Conv_138(t_676_padded)
576 | t_678 = F.relu(t_677)
577 | t_679 = self.n_Conv_139(t_678)
578 | t_680 = torch.add(t_679, t_674)
579 | t_681 = F.relu(t_680)
580 | t_682 = self.n_Conv_140(t_681)
581 | t_683 = F.relu(t_682)
582 | t_683_padded = F.pad(t_683, [1, 1, 1, 1], value=0)
583 | t_684 = self.n_Conv_141(t_683_padded)
584 | t_685 = F.relu(t_684)
585 | t_686 = self.n_Conv_142(t_685)
586 | t_687 = torch.add(t_686, t_681)
587 | t_688 = F.relu(t_687)
588 | t_689 = self.n_Conv_143(t_688)
589 | t_690 = F.relu(t_689)
590 | t_690_padded = F.pad(t_690, [1, 1, 1, 1], value=0)
591 | t_691 = self.n_Conv_144(t_690_padded)
592 | t_692 = F.relu(t_691)
593 | t_693 = self.n_Conv_145(t_692)
594 | t_694 = torch.add(t_693, t_688)
595 | t_695 = F.relu(t_694)
596 | t_696 = self.n_Conv_146(t_695)
597 | t_697 = F.relu(t_696)
598 | t_697_padded = F.pad(t_697, [1, 1, 1, 1], value=0)
599 | t_698 = self.n_Conv_147(t_697_padded)
600 | t_699 = F.relu(t_698)
601 | t_700 = self.n_Conv_148(t_699)
602 | t_701 = torch.add(t_700, t_695)
603 | t_702 = F.relu(t_701)
604 | t_703 = self.n_Conv_149(t_702)
605 | t_704 = F.relu(t_703)
606 | t_704_padded = F.pad(t_704, [1, 1, 1, 1], value=0)
607 | t_705 = self.n_Conv_150(t_704_padded)
608 | t_706 = F.relu(t_705)
609 | t_707 = self.n_Conv_151(t_706)
610 | t_708 = torch.add(t_707, t_702)
611 | t_709 = F.relu(t_708)
612 | t_710 = self.n_Conv_152(t_709)
613 | t_711 = F.relu(t_710)
614 | t_711_padded = F.pad(t_711, [1, 1, 1, 1], value=0)
615 | t_712 = self.n_Conv_153(t_711_padded)
616 | t_713 = F.relu(t_712)
617 | t_714 = self.n_Conv_154(t_713)
618 | t_715 = torch.add(t_714, t_709)
619 | t_716 = F.relu(t_715)
620 | t_717 = self.n_Conv_155(t_716)
621 | t_718 = F.relu(t_717)
622 | t_718_padded = F.pad(t_718, [1, 1, 1, 1], value=0)
623 | t_719 = self.n_Conv_156(t_718_padded)
624 | t_720 = F.relu(t_719)
625 | t_721 = self.n_Conv_157(t_720)
626 | t_722 = torch.add(t_721, t_716)
627 | t_723 = F.relu(t_722)
628 | t_724 = self.n_Conv_158(t_723)
629 | t_725 = self.n_Conv_159(t_723)
630 | t_726 = F.relu(t_725)
631 | t_726_padded = F.pad(t_726, [0, 1, 0, 1], value=0)
632 | t_727 = self.n_Conv_160(t_726_padded)
633 | t_728 = F.relu(t_727)
634 | t_729 = self.n_Conv_161(t_728)
635 | t_730 = torch.add(t_729, t_724)
636 | t_731 = F.relu(t_730)
637 | t_732 = self.n_Conv_162(t_731)
638 | t_733 = F.relu(t_732)
639 | t_733_padded = F.pad(t_733, [1, 1, 1, 1], value=0)
640 | t_734 = self.n_Conv_163(t_733_padded)
641 | t_735 = F.relu(t_734)
642 | t_736 = self.n_Conv_164(t_735)
643 | t_737 = torch.add(t_736, t_731)
644 | t_738 = F.relu(t_737)
645 | t_739 = self.n_Conv_165(t_738)
646 | t_740 = F.relu(t_739)
647 | t_740_padded = F.pad(t_740, [1, 1, 1, 1], value=0)
648 | t_741 = self.n_Conv_166(t_740_padded)
649 | t_742 = F.relu(t_741)
650 | t_743 = self.n_Conv_167(t_742)
651 | t_744 = torch.add(t_743, t_738)
652 | t_745 = F.relu(t_744)
653 | t_746 = self.n_Conv_168(t_745)
654 | t_747 = self.n_Conv_169(t_745)
655 | t_748 = F.relu(t_747)
656 | t_748_padded = F.pad(t_748, [0, 1, 0, 1], value=0)
657 | t_749 = self.n_Conv_170(t_748_padded)
658 | t_750 = F.relu(t_749)
659 | t_751 = self.n_Conv_171(t_750)
660 | t_752 = torch.add(t_751, t_746)
661 | t_753 = F.relu(t_752)
662 | t_754 = self.n_Conv_172(t_753)
663 | t_755 = F.relu(t_754)
664 | t_755_padded = F.pad(t_755, [1, 1, 1, 1], value=0)
665 | t_756 = self.n_Conv_173(t_755_padded)
666 | t_757 = F.relu(t_756)
667 | t_758 = self.n_Conv_174(t_757)
668 | t_759 = torch.add(t_758, t_753)
669 | t_760 = F.relu(t_759)
670 | t_761 = self.n_Conv_175(t_760)
671 | t_762 = F.relu(t_761)
672 | t_762_padded = F.pad(t_762, [1, 1, 1, 1], value=0)
673 | t_763 = self.n_Conv_176(t_762_padded)
674 | t_764 = F.relu(t_763)
675 | t_765 = self.n_Conv_177(t_764)
676 | t_766 = torch.add(t_765, t_760)
677 | t_767 = F.relu(t_766)
678 | t_768 = self.n_Conv_178(t_767)
679 | t_769 = F.avg_pool2d(t_768, kernel_size=t_768.shape[-2:])
680 | t_770 = torch.squeeze(t_769, 3)
681 | t_770 = torch.squeeze(t_770, 2)
682 | t_771 = torch.sigmoid(t_770)
683 | return t_771
684 |
685 | def load_state_dict(self, state_dict, **kwargs):
686 | self.tags = state_dict.get('tags', [])
687 |
688 | super(DeepDanbooruModel, self).load_state_dict({k: v for k, v in state_dict.items() if k != 'tags'})
689 |
690 |
691 | def resize_image(im, width, height):
692 | ratio = width / height
693 | src_ratio = im.width / im.height
694 |
695 | src_w = width if ratio < src_ratio else im.width * height // im.height
696 | src_h = height if ratio >= src_ratio else im.height * width // im.width
697 |
698 | resized = im.resize((src_w, src_h), resample=LANCZOS)
699 | res = Image.new("RGB", (width, height))
700 | res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
701 |
702 | if ratio < src_ratio:
703 | fill_height = height // 2 - src_h // 2
704 | res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
705 | res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
706 | box=(0, fill_height + src_h))
707 | elif ratio > src_ratio:
708 | fill_width = width // 2 - src_w // 2
709 | res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
710 | res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
711 | box=(fill_width + src_w, 0))
712 |
713 | return res
714 |
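# Usage sketch for resize_image above (illustrative only; 'face.jpg' is a placeholder path):
# the helper scales the image to the target size while preserving aspect ratio, centers it,
# and fills any leftover bands by stretching edge rows/columns of the resized image.
#
#   img = Image.open('face.jpg').convert('RGB')
#   square = resize_image(img, 512, 512)   # square.size == (512, 512)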
715 |
716 | class DeepDanbooru:
717 | def __init__(self):
718 | self.model = DeepDanbooruModel()
719 |
720 | foundation_model_id = 'ly261666/cv_portrait_model'
721 | snapshot_path = snapshot_download(foundation_model_id, revision='v4.0')
722 | pretrain_model_path = os.path.join(snapshot_path, 'model-resnet_custom_v3.pt')
723 |
724 | self.model.load_state_dict(torch.load(pretrain_model_path, map_location="cpu"))
725 | self.model.eval()
726 | self.model.to(torch.float16)
727 |
728 | def start(self):
729 | self.model.cuda()
730 |
731 | def stop(self):
732 | self.model.cpu()
733 | torch.cuda.empty_cache()
734 | torch.cuda.ipc_collect()
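# Note: start()/stop() above only move the model between GPU and CPU; empty_cache() and
# ipc_collect() then release cached CUDA memory so the VRAM can be reused by other models.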
735 |
736 | def tag(self, pil_image):
737 | threshold = 0.5
738 | use_spaces = False
739 | use_escape = True
740 | alpha_sort = True
741 | include_ranks = False
742 |
743 | pic = resize_image(pil_image.convert("RGB"), 512, 512)
744 | a = np.expand_dims(np.array(pic, dtype=np.float32), 0) / 255
745 |
746 | with torch.no_grad(), torch.autocast("cuda"):
747 | x = torch.from_numpy(a).cuda()
748 | y = self.model(x)[0].detach().cpu().numpy()
749 |
750 | probability_dict = {}
751 |
752 | for tag, probability in zip(self.model.tags, y):
753 | if probability < threshold:
754 | continue
755 |
756 | if tag.startswith("rating:"):
757 | continue
758 |
759 | probability_dict[tag] = probability
760 |
761 | if alpha_sort:
762 | tags = sorted(probability_dict)
763 | else:
764 | tags = [tag for tag, _ in sorted(probability_dict.items(), key=lambda x: -x[1])]
765 |
766 | res = []
767 |
768 | for tag in tags:
769 | probability = probability_dict[tag]
770 | tag_outformat = tag
771 | if use_spaces:
772 | tag_outformat = tag_outformat.replace('_', ' ')
773 | if use_escape:
774 | tag_outformat = re.sub(re_special, r'\\\1', tag_outformat)
775 | if include_ranks:
776 | tag_outformat = f"({tag_outformat}:{probability:.3f})"
777 |
778 | res.append(tag_outformat)
779 |
780 | return ", ".join(res)
781 |
782 |
783 | '''
784 | model = DeepDanbooru()
785 | impath = 'lyf'
786 | imlist = os.listdir(impath)
787 | result_list = []
788 | for im in imlist:
789 | if im[-4:]=='.png':
790 | print(im)
791 | img = Image.open(os.path.join(impath, im))
792 | result = model.tag(img)
793 | print(result)
794 | result_list.append(result)
795 | model.stop()
796 | '''
797 |
--------------------------------------------------------------------------------
/facechain/train_text_to_image_lora.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright (c) Alibaba, Inc. and its affiliates.
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
17 |
18 | import argparse
19 | import itertools
20 | import json
21 | import logging
22 | import math
23 | import os
24 | import random
25 | import shutil
26 | from pathlib import Path
27 |
28 | import PIL.Image
29 | import cv2
30 | import datasets
31 | import diffusers
32 | import numpy as np
33 | import onnxruntime
34 | import torch
35 | import torch.nn.functional as F
36 | import torch.utils.checkpoint
37 | import transformers
38 | from PIL import Image
39 | from accelerate import Accelerator
40 | from accelerate.logging import get_logger
41 | from accelerate.utils import ProjectConfiguration, set_seed
42 | from datasets import load_dataset
43 | from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel
44 | from diffusers.loaders import AttnProcsLayers
45 | from diffusers.models.attention_processor import LoRAAttnProcessor
46 | from diffusers.optimization import get_scheduler
47 | from diffusers.utils import check_min_version, is_wandb_available
48 | from diffusers.utils.import_utils import is_xformers_available
49 | from huggingface_hub import create_repo, upload_folder
50 | from modelscope import snapshot_download
51 | from packaging import version
52 | from torchvision import transforms
53 | from tqdm.auto import tqdm
54 | from transformers import CLIPTextModel, CLIPTokenizer
55 |
56 | from facechain.inference import data_process_fn
57 |
58 | # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
59 | check_min_version("0.14.0.dev0")
60 |
61 | logger = get_logger(__name__, log_level="INFO")
62 |
63 |
 64 | def save_model_card(repo_id: str, images=None, base_model: str = None, dataset_name: str = None, repo_folder=None):
65 | img_str = ""
66 | for i, image in enumerate(images):
67 | image.save(os.path.join(repo_folder, f"image_{i}.png"))
 68 | img_str += f"![img_{i}](./image_{i}.png)\n"
69 |
70 | yaml = f"""
71 | ---
72 | license: creativeml-openrail-m
73 | base_model: {base_model}
74 | tags:
75 | - stable-diffusion
76 | - stable-diffusion-diffusers
77 | - text-to-image
78 | - diffusers
79 | - lora
80 | inference: true
81 | ---
82 | """
83 | model_card = f"""
84 | # LoRA text2image fine-tuning - {repo_id}
 85 | These are LoRA adaptation weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images below.\n
86 | {img_str}
87 | """
88 | with open(os.path.join(repo_folder, "README.md"), "w") as f:
89 | f.write(yaml + model_card)
90 |
91 |
92 | def softmax(x):
93 | x -= np.max(x, axis=0, keepdims=True)
94 | x = np.exp(x) / np.sum(np.exp(x), axis=0, keepdims=True)
95 | return x
96 |
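# Quick numeric check for softmax above (values are only an example): subtracting the max
# keeps exp() from overflowing without changing the result, e.g.
#   softmax(np.array([1000.0, 1001.0]))   # ~ [0.269, 0.731]; a naive exp(1000) would overflow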
97 |
98 | def get_rot(image):
99 | model_dir = snapshot_download('Cherrytest/rot_bgr', revision='v1.0.0')
100 | model_path = os.path.join(model_dir, 'rot_bgr.onnx')
101 | ort_session = onnxruntime.InferenceSession(model_path)
102 |
103 | img_cv = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
104 | img_clone = img_cv.copy()
105 | img_np = cv2.resize(img_cv, (224, 224))
106 | img_np = img_np.astype(np.float32)
107 | mean = np.array([103.53, 116.28, 123.675], dtype=np.float32).reshape((1, 1, 3))
108 | norm = np.array([0.01742919, 0.017507, 0.01712475], dtype=np.float32).reshape((1, 1, 3))
109 | img_np = (img_np - mean) * norm
110 | img_tensor = torch.from_numpy(img_np)
111 | img_tensor = img_tensor.unsqueeze(0)
112 | img_nchw = img_tensor.permute(0, 3, 1, 2)
113 | ort_inputs = {ort_session.get_inputs()[0].name: img_nchw.numpy()}
114 | outputs = ort_session.run(None, ort_inputs)
115 | logits = outputs[0].reshape((-1,))
116 | probs = softmax(logits)
117 | rot_idx = np.argmax(probs)
118 | if rot_idx == 1:
119 | print('rot 90')
120 | img_clone = cv2.transpose(img_clone)
121 | img_clone = np.flip(img_clone, 1)
122 | return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
123 | elif rot_idx == 2:
124 | print('rot 180')
125 | img_clone = cv2.flip(img_clone, -1)
126 | return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
127 | elif rot_idx == 3:
128 | print('rot 270')
129 | img_clone = cv2.transpose(img_clone)
130 | img_clone = np.flip(img_clone, 0)
131 | return Image.fromarray(cv2.cvtColor(img_clone, cv2.COLOR_BGR2RGB))
132 | else:
133 | return image
134 |
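# Note on get_rot above: the ONNX classifier predicts one of four orientations
# (0/90/180/270 degrees), and a transpose/flip (equivalent to a right-angle rotation) is
# applied to correct the orientation. Hedged usage sketch ('photo.jpg' is a placeholder):
#   upright = get_rot(Image.open('photo.jpg').convert('RGB'))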
135 |
136 | def prepare_dataset(instance_images: list, output_dataset_dir):
137 | if not os.path.exists(output_dataset_dir):
138 | os.makedirs(output_dataset_dir)
139 | for i, temp_path in enumerate(instance_images):
140 | image = PIL.Image.open(temp_path)
141 | # image = PIL.Image.open(temp_path.name)
142 | '''
143 | w, h = image.size
144 | max_size = max(w, h)
145 | ratio = 1024 / max_size
146 | new_w = round(w * ratio)
147 | new_h = round(h * ratio)
148 | '''
149 | image = image.convert('RGB')
150 | image = get_rot(image)
151 | # image = image.resize((new_w, new_h))
152 | # image = image.resize((new_w, new_h), PIL.Image.ANTIALIAS)
153 | out_path = f'{output_dataset_dir}/{i:03d}.jpg'
154 | image.save(out_path, format='JPEG', quality=100)
155 |
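# Sketch of how prepare_dataset above may be called (paths are placeholders): each input
# image is converted to RGB, orientation-corrected, and written to output_dataset_dir as
# sequentially numbered JPEGs (000.jpg, 001.jpg, ...).
#   prepare_dataset(['/tmp/raw/a.png', '/tmp/raw/b.jpg'], '/tmp/processed')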
156 |
157 | def parse_args():
158 | parser = argparse.ArgumentParser(description="Simple example of a training script.")
159 | parser.add_argument(
160 | "--pretrained_model_name_or_path",
161 | type=str,
162 | default=None,
163 | required=True,
164 | help="Path to pretrained model or model identifier.",
165 | )
166 | parser.add_argument(
167 | "--revision",
168 | type=str,
169 | default=None,
170 | required=False,
171 | help="Revision of pretrained model identifier.",
172 | )
173 | parser.add_argument(
174 | "--sub_path",
175 | type=str,
176 | default=None,
177 | required=False,
178 | help="The sub model path of the `pretrained_model_name_or_path`",
179 | )
180 | parser.add_argument(
181 | "--dataset_name",
182 | type=str,
183 | default=None,
184 | help=(
185 | "The data images dir"
186 | ),
187 | )
188 | parser.add_argument(
189 | "--dataset_config_name",
190 | type=str,
191 | default=None,
192 | help="The config of the Dataset, leave as None if there's only one config.",
193 | )
194 | parser.add_argument(
195 | "--train_data_dir",
196 | type=str,
197 | default=None,
198 | help=(
199 | "A folder containing the training data. Folder contents must follow the structure described in"
200 | " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
201 | " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
202 | ),
203 | )
204 | parser.add_argument(
205 | "--output_dataset_name",
206 | type=str,
207 | default=None,
208 | help=(
209 | "The dataset dir after processing"
210 | ),
211 | )
212 | parser.add_argument(
213 | "--image_column", type=str, default="image", help="The column of the dataset containing an image."
214 | )
215 | parser.add_argument(
216 | "--caption_column",
217 | type=str,
218 | default="text",
219 | help="The column of the dataset containing a caption or a list of captions.",
220 | )
221 | parser.add_argument(
222 | "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
223 | )
224 | parser.add_argument(
225 | "--num_validation_images",
226 | type=int,
227 | default=4,
228 | help="Number of images that should be generated during validation with `validation_prompt`.",
229 | )
230 | parser.add_argument(
231 | "--validation_epochs",
232 | type=int,
233 | default=1,
234 | help=(
235 | "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
236 | " `args.validation_prompt` multiple times: `args.num_validation_images`."
237 | ),
238 | )
239 | parser.add_argument(
240 | "--max_train_samples",
241 | type=int,
242 | default=None,
243 | help=(
244 | "For debugging purposes or quicker training, truncate the number of training examples to this "
245 | "value if set."
246 | ),
247 | )
248 | parser.add_argument(
249 | "--output_dir",
250 | type=str,
251 | default="sd-model-finetuned-lora",
252 | help="The output directory where the model predictions and checkpoints will be written.",
253 | )
254 | parser.add_argument(
255 | "--cache_dir",
256 | type=str,
257 | default=None,
258 | help="The directory where the downloaded models and datasets will be stored.",
259 | )
260 | parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
261 | parser.add_argument(
262 | "--resolution",
263 | type=int,
264 | default=512,
265 | help=(
266 | "The resolution for input images, all the images in the train/validation dataset will be resized to this"
267 | " resolution"
268 | ),
269 | )
270 | parser.add_argument(
271 | "--center_crop",
272 | default=False,
273 | action="store_true",
274 | help=(
275 | "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
276 | " cropped. The images will be resized to the resolution first before cropping."
277 | ),
278 | )
279 | parser.add_argument(
280 | "--random_flip",
281 | action="store_true",
282 | help="whether to randomly flip images horizontally",
283 | )
284 | parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
285 |
286 | # lora args
287 | parser.add_argument("--use_peft", action="store_true", help="Whether to use peft to support lora")
288 | parser.add_argument("--lora_r", type=int, default=4, help="Lora rank, only used if use_lora is True")
289 | parser.add_argument("--lora_alpha", type=int, default=32, help="Lora alpha, only used if lora is True")
290 | parser.add_argument("--lora_dropout", type=float, default=0.0, help="Lora dropout, only used if use_lora is True")
291 | parser.add_argument(
292 | "--lora_bias",
293 | type=str,
294 | default="none",
295 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora is True",
296 | )
297 | parser.add_argument(
298 | "--lora_text_encoder_r",
299 | type=int,
300 | default=4,
301 | help="Lora rank for text encoder, only used if `use_lora` and `train_text_encoder` are True",
302 | )
303 | parser.add_argument(
304 | "--lora_text_encoder_alpha",
305 | type=int,
306 | default=32,
307 | help="Lora alpha for text encoder, only used if `use_lora` and `train_text_encoder` are True",
308 | )
309 | parser.add_argument(
310 | "--lora_text_encoder_dropout",
311 | type=float,
312 | default=0.0,
313 | help="Lora dropout for text encoder, only used if `use_lora` and `train_text_encoder` are True",
314 | )
315 | parser.add_argument(
316 | "--lora_text_encoder_bias",
317 | type=str,
318 | default="none",
319 | help="Bias type for Lora. Can be 'none', 'all' or 'lora_only', only used if use_lora and `train_text_encoder` are True",
320 | )
321 |
322 | parser.add_argument(
323 | "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
324 | )
325 | parser.add_argument("--num_train_epochs", type=int, default=100)
326 | parser.add_argument(
327 | "--max_train_steps",
328 | type=int,
329 | default=None,
330 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
331 | )
332 | parser.add_argument(
333 | "--gradient_accumulation_steps",
334 | type=int,
335 | default=1,
336 | help="Number of updates steps to accumulate before performing a backward/update pass.",
337 | )
338 | parser.add_argument(
339 | "--gradient_checkpointing",
340 | action="store_true",
341 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
342 | )
343 | parser.add_argument(
344 | "--learning_rate",
345 | type=float,
346 | default=1e-4,
347 | help="Initial learning rate (after the potential warmup period) to use.",
348 | )
349 | parser.add_argument(
350 | "--scale_lr",
351 | action="store_true",
352 | default=False,
353 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
354 | )
355 | parser.add_argument(
356 | "--lr_scheduler",
357 | type=str,
358 | default="constant",
359 | help=(
360 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
361 | ' "constant", "constant_with_warmup"]'
362 | ),
363 | )
364 | parser.add_argument(
365 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
366 | )
367 | parser.add_argument(
368 | "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
369 | )
370 | parser.add_argument(
371 | "--allow_tf32",
372 | action="store_true",
373 | help=(
374 | "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
375 | " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
376 | ),
377 | )
378 | parser.add_argument(
379 | "--dataloader_num_workers",
380 | type=int,
381 | default=0,
382 | help=(
383 | "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
384 | ),
385 | )
386 | parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
387 | parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
388 | parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
389 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
390 | parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
391 | parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
392 | parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
393 | parser.add_argument(
394 | "--hub_model_id",
395 | type=str,
396 | default=None,
397 | help="The name of the repository to keep in sync with the local `output_dir`.",
398 | )
399 | parser.add_argument(
400 | "--logging_dir",
401 | type=str,
402 | default="logs",
403 | help=(
404 | "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
405 | " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
406 | ),
407 | )
408 | parser.add_argument(
409 | "--mixed_precision",
410 | type=str,
411 | default=None,
412 | choices=["no", "fp16", "bf16"],
413 | help=(
414 | "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
415 | " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
416 | " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
417 | ),
418 | )
419 | parser.add_argument(
420 | "--report_to",
421 | type=str,
422 | default="tensorboard",
423 | help=(
424 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
425 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
426 | ),
427 | )
428 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
429 | parser.add_argument(
430 | "--checkpointing_steps",
431 | type=int,
432 | default=500,
433 | help=(
434 | "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
435 | " training using `--resume_from_checkpoint`."
436 | ),
437 | )
438 | parser.add_argument(
439 | "--checkpoints_total_limit",
440 | type=int,
441 | default=None,
442 | help=(
443 | "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
444 | " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
445 | " for more docs"
446 | ),
447 | )
448 | parser.add_argument(
449 | "--resume_from_checkpoint",
450 | type=str,
451 | default=None,
452 | help=(
453 | "Whether training should be resumed from a previous checkpoint. Use a path saved by"
454 | ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
455 | ),
456 | )
457 | parser.add_argument(
458 | "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
459 | )
460 |
461 | args = parser.parse_args()
462 | env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
463 | if env_local_rank != -1 and env_local_rank != args.local_rank:
464 | args.local_rank = env_local_rank
465 |
466 | # Sanity checks
467 | if args.dataset_name is None and args.train_data_dir is None:
468 | raise ValueError("Need either a dataset name or a training folder.")
469 |
470 | return args
471 |
472 |
473 | DATASET_NAME_MAPPING = {
474 | "lambdalabs/pokemon-blip-captions": ("image", "text"),
475 | }
476 |
477 |
478 | def main():
479 | args = parse_args()
480 | logging_dir = os.path.join(args.output_dir, args.logging_dir)
481 | shutil.rmtree(args.output_dataset_name, ignore_errors=True)
482 | shutil.rmtree(args.output_dir, ignore_errors=True)
483 | os.makedirs(args.output_dir)
484 | args.dataset_name = [os.path.join(args.dataset_name, x) for x in os.listdir(args.dataset_name)]
485 |
486 | print('All input images:', args.dataset_name)
487 | prepare_dataset(args.dataset_name, args.output_dataset_name)
488 | # Run our data processing function to produce the labeled dataset
489 | data_process_fn(input_img_dir=args.output_dataset_name, use_data_process=True)
490 | args.dataset_name = args.output_dataset_name + '_labeled'
491 |
492 | accelerator_project_config = ProjectConfiguration(
493 | total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
494 | )
495 |
496 | accelerator = Accelerator(
497 | gradient_accumulation_steps=args.gradient_accumulation_steps,
498 | mixed_precision=args.mixed_precision,
499 | log_with=args.report_to,
500 | project_config=accelerator_project_config,
501 | )
502 | if args.report_to == "wandb":
503 | if not is_wandb_available():
504 | raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
505 | import wandb
506 |
507 | # Make one log on every process with the configuration for debugging.
508 | logging.basicConfig(
509 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
510 | datefmt="%m/%d/%Y %H:%M:%S",
511 | level=logging.INFO,
512 | )
513 | logger.info(accelerator.state, main_process_only=False)
514 | if accelerator.is_local_main_process:
515 | datasets.utils.logging.set_verbosity_warning()
516 | transformers.utils.logging.set_verbosity_warning()
517 | diffusers.utils.logging.set_verbosity_info()
518 | else:
519 | datasets.utils.logging.set_verbosity_error()
520 | transformers.utils.logging.set_verbosity_error()
521 | diffusers.utils.logging.set_verbosity_error()
522 |
523 | # If passed along, set the training seed now.
524 | if args.seed is not None:
525 | set_seed(args.seed)
526 |
527 | # Handle the repository creation
528 | if accelerator.is_main_process:
529 | if args.output_dir is not None:
530 | os.makedirs(args.output_dir, exist_ok=True)
531 |
532 | if args.push_to_hub:
533 | repo_id = create_repo(
534 | repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
535 | ).repo_id
536 |
537 | # Download the foundation model
538 | model_dir = snapshot_download(args.pretrained_model_name_or_path, revision=args.revision)
539 |
540 | if args.sub_path is not None and len(args.sub_path) > 0:
541 | model_dir = os.path.join(model_dir, args.sub_path)
542 |
543 | # Load scheduler, tokenizer and models.
544 | noise_scheduler = DDPMScheduler.from_pretrained(model_dir, subfolder="scheduler")
545 | tokenizer = CLIPTokenizer.from_pretrained(
546 | model_dir, subfolder="tokenizer"
547 | )
548 | text_encoder = CLIPTextModel.from_pretrained(
549 | model_dir, subfolder="text_encoder"
550 | )
551 | vae = AutoencoderKL.from_pretrained(model_dir, subfolder="vae")
552 | unet = UNet2DConditionModel.from_pretrained(
553 | model_dir, subfolder="unet"
554 | )
555 |
556 | # For mixed precision training we cast the text_encoder and vae weights to half-precision
557 | # as these models are only used for inference, keeping weights in full precision is not required.
558 | weight_dtype = torch.float32
559 | if accelerator.mixed_precision == "fp16":
560 | weight_dtype = torch.float16
561 | elif accelerator.mixed_precision == "bf16":
562 | weight_dtype = torch.bfloat16
563 |
564 | if args.use_peft:
565 | from peft import LoraConfig, LoraModel, get_peft_model_state_dict, set_peft_model_state_dict
566 |
567 | UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
568 | TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
569 |
570 | config = LoraConfig(
571 | r=args.lora_r,
572 | lora_alpha=args.lora_alpha,
573 | target_modules=UNET_TARGET_MODULES,
574 | lora_dropout=args.lora_dropout,
575 | bias=args.lora_bias,
576 | )
577 | unet = LoraModel(config, unet)
578 |
579 | vae.requires_grad_(False)
580 | if args.train_text_encoder:
581 | config = LoraConfig(
582 | r=args.lora_text_encoder_r,
583 | lora_alpha=args.lora_text_encoder_alpha,
584 | target_modules=TEXT_ENCODER_TARGET_MODULES,
585 | lora_dropout=args.lora_text_encoder_dropout,
586 | bias=args.lora_text_encoder_bias,
587 | )
588 | text_encoder = LoraModel(config, text_encoder)
589 | else:
590 | # freeze parameters of models to save more memory
591 | unet.requires_grad_(False)
592 | vae.requires_grad_(False)
593 |
594 | text_encoder.requires_grad_(False)
595 |
596 | # now we will add new LoRA weights to the attention layers
597 | # It's important to realize here how many attention weights will be added and of which sizes
598 | # The sizes of the attention layers consist only of two different variables:
599 | # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
600 | # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
601 |
602 | # Let's first see how many attention processors we will have to set.
603 | # For Stable Diffusion, it should be equal to:
604 | # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
605 | # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
606 | # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
607 | # => 32 layers
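# For reference, the processor names iterated below are roughly of the form
# "down_blocks.1.attentions.0.transformer_blocks.0.attn1.processor", where "attn1" is
# self-attention (so cross_attention_dim stays None) and "attn2" is cross-attention.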
608 |
609 | # Set correct lora layers
610 | lora_attn_procs = {}
611 | for name in unet.attn_processors.keys():
612 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
613 | if name.startswith("mid_block"):
614 | hidden_size = unet.config.block_out_channels[-1]
615 | elif name.startswith("up_blocks"):
616 | block_id = int(name[len("up_blocks.")])
617 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
618 | elif name.startswith("down_blocks"):
619 | block_id = int(name[len("down_blocks.")])
620 | hidden_size = unet.config.block_out_channels[block_id]
621 |
622 | lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
623 |
624 | unet.set_attn_processor(lora_attn_procs)
625 | lora_layers = AttnProcsLayers(unet.attn_processors)
626 |
627 | # Move unet, vae and text_encoder to device and cast to weight_dtype
628 | vae.to(accelerator.device, dtype=weight_dtype)
629 | if not args.train_text_encoder:
630 | text_encoder.to(accelerator.device, dtype=weight_dtype)
631 |
632 | if args.enable_xformers_memory_efficient_attention:
633 | if is_xformers_available():
634 | import xformers
635 |
636 | xformers_version = version.parse(xformers.__version__)
637 | if xformers_version == version.parse("0.0.16"):
638 | logger.warn(
639 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
640 | )
641 | unet.enable_xformers_memory_efficient_attention()
642 | else:
643 | raise ValueError("xformers is not available. Make sure it is installed correctly")
644 |
645 | # Enable TF32 for faster training on Ampere GPUs,
646 | # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
647 | if args.allow_tf32:
648 | torch.backends.cuda.matmul.allow_tf32 = True
649 |
650 | if args.scale_lr:
651 | args.learning_rate = (
652 | args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
653 | )
654 |
655 | # Initialize the optimizer
656 | if args.use_8bit_adam:
657 | try:
658 | import bitsandbytes as bnb
659 | except ImportError:
660 | raise ImportError(
661 | "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
662 | )
663 |
664 | optimizer_cls = bnb.optim.AdamW8bit
665 | else:
666 | optimizer_cls = torch.optim.AdamW
667 |
668 | if args.use_peft:
669 | # Optimizer creation
670 | params_to_optimize = (
671 | itertools.chain(unet.parameters(), text_encoder.parameters())
672 | if args.train_text_encoder
673 | else unet.parameters()
674 | )
675 | optimizer = optimizer_cls(
676 | params_to_optimize,
677 | lr=args.learning_rate,
678 | betas=(args.adam_beta1, args.adam_beta2),
679 | weight_decay=args.adam_weight_decay,
680 | eps=args.adam_epsilon,
681 | )
682 | else:
683 | optimizer = optimizer_cls(
684 | lora_layers.parameters(),
685 | lr=args.learning_rate,
686 | betas=(args.adam_beta1, args.adam_beta2),
687 | weight_decay=args.adam_weight_decay,
688 | eps=args.adam_epsilon,
689 | )
690 |
691 | # Get the datasets: you can either provide your own training and evaluation files (see below)
692 | # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
693 |
694 | # In distributed training, the load_dataset function guarantees that only one local process can concurrently
695 | # download the dataset.
696 | if args.dataset_name is not None:
697 | # Downloading and loading a dataset from the hub.
698 | dataset = load_dataset(
699 | args.dataset_name,
700 | args.dataset_config_name,
701 | cache_dir=args.cache_dir,
702 | )
703 | else:
704 | data_files = {}
705 | if args.train_data_dir is not None:
706 | data_files["train"] = os.path.join(args.train_data_dir, "**")
707 | dataset = load_dataset(
708 | "imagefolder",
709 | data_files=data_files,
710 | cache_dir=args.cache_dir,
711 | )
712 | # See more about loading custom images at
713 | # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
714 |
715 | # Preprocessing the datasets.
716 | # We need to tokenize inputs and targets.
717 | column_names = dataset["train"].column_names
718 |
719 | # 6. Get the column names for input/target.
720 | dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
721 | if args.image_column is None:
722 | image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
723 | else:
724 | image_column = args.image_column
725 | if image_column not in column_names:
726 | raise ValueError(
727 | f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
728 | )
729 | if args.caption_column is None:
730 | caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
731 | else:
732 | caption_column = args.caption_column
733 | if caption_column not in column_names:
734 | raise ValueError(
735 | f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
736 | )
737 |
738 | # Preprocessing the datasets.
739 | # We need to tokenize input captions and transform the images.
740 | def tokenize_captions(examples, is_train=True):
741 | captions = []
742 | for caption in examples[caption_column]:
743 | if isinstance(caption, str):
744 | captions.append(caption)
745 | elif isinstance(caption, (list, np.ndarray)):
746 | # take a random caption if there are multiple
747 | captions.append(random.choice(caption) if is_train else caption[0])
748 | else:
749 | raise ValueError(
750 | f"Caption column `{caption_column}` should contain either strings or lists of strings."
751 | )
752 | inputs = tokenizer(
753 | captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
754 | )
755 | return inputs.input_ids
756 |
757 | # Preprocessing the datasets.
758 | train_transforms = transforms.Compose(
759 | [
760 | transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
761 | transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
762 | transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
763 | transforms.ToTensor(),
764 | transforms.Normalize([0.5], [0.5]),
765 | ]
766 | )
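# Note: ToTensor() produces pixel values in [0, 1]; Normalize([0.5], [0.5]) then maps them to
# [-1, 1] via (x - 0.5) / 0.5, which is the range the VAE encoder is fed later in training.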
767 |
768 | def preprocess_train(examples):
769 | images = [image.convert("RGB") for image in examples[image_column]]
770 | examples["pixel_values"] = [train_transforms(image) for image in images]
771 | examples["input_ids"] = tokenize_captions(examples)
772 | return examples
773 |
774 | with accelerator.main_process_first():
775 | if args.max_train_samples is not None:
776 | dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
777 | # Set the training transforms
778 | train_dataset = dataset["train"].with_transform(preprocess_train)
779 |
780 | def collate_fn(examples):
781 | pixel_values = torch.stack([example["pixel_values"] for example in examples])
782 | pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
783 | input_ids = torch.stack([example["input_ids"] for example in examples])
784 | return {"pixel_values": pixel_values, "input_ids": input_ids}
785 |
786 | # DataLoaders creation:
787 | train_dataloader = torch.utils.data.DataLoader(
788 | train_dataset,
789 | shuffle=True,
790 | collate_fn=collate_fn,
791 | batch_size=args.train_batch_size,
792 | num_workers=args.dataloader_num_workers,
793 | )
794 |
795 | # Scheduler and math around the number of training steps.
796 | overrode_max_train_steps = False
797 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
798 | if args.max_train_steps is None:
799 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
800 | overrode_max_train_steps = True
801 |
802 | lr_scheduler = get_scheduler(
803 | args.lr_scheduler,
804 | optimizer=optimizer,
805 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
806 | num_training_steps=args.max_train_steps * accelerator.num_processes,
807 | )
808 |
809 | # Prepare everything with our `accelerator`.
810 | if args.use_peft:
811 | if args.train_text_encoder:
812 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
813 | unet, text_encoder, optimizer, train_dataloader, lr_scheduler
814 | )
815 | else:
816 | unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
817 | unet, optimizer, train_dataloader, lr_scheduler
818 | )
819 | else:
820 | lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
821 | lora_layers, optimizer, train_dataloader, lr_scheduler
822 | )
823 | unet = unet.cuda()
824 | # We need to recalculate our total training steps as the size of the training dataloader may have changed.
825 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
826 | if overrode_max_train_steps:
827 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
828 | # Afterwards we recalculate our number of training epochs
829 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
830 |
831 | # We need to initialize the trackers we use, and also store our configuration.
832 | # The trackers are initialized automatically on the main process.
833 | if accelerator.is_main_process:
834 | accelerator.init_trackers("text2image-fine-tune", config=vars(args))
835 |
836 | # Train!
837 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
838 |
839 | logger.info("***** Running training *****")
840 | logger.info(f" Num examples = {len(train_dataset)}")
841 | logger.info(f" Num Epochs = {args.num_train_epochs}")
842 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
843 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
844 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
845 | logger.info(f" Total optimization steps = {args.max_train_steps}")
846 | global_step = 0
847 | first_epoch = 0
848 |
849 | # Potentially load in the weights and states from a previous save
850 | if args.resume_from_checkpoint:
851 | if args.resume_from_checkpoint != "latest":
852 | path = os.path.basename(args.resume_from_checkpoint)
853 | else:
854 | # Get the most recent checkpoint
855 | dirs = os.listdir(args.output_dir)
856 | dirs = [d for d in dirs if d.startswith("checkpoint")]
857 | dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
858 | path = dirs[-1] if len(dirs) > 0 else None
859 |
860 | if path is None:
861 | accelerator.print(
862 | f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
863 | )
864 | args.resume_from_checkpoint = None
865 | else:
866 | accelerator.print(f"Resuming from checkpoint {path}")
867 | accelerator.load_state(os.path.join(args.output_dir, path))
868 | global_step = int(path.split("-")[1])
869 |
870 | resume_global_step = global_step * args.gradient_accumulation_steps
871 | first_epoch = global_step // num_update_steps_per_epoch
872 | resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
873 |
874 | # Only show the progress bar once on each machine.
875 | progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
876 | progress_bar.set_description("Steps")
877 |
878 | for epoch in range(first_epoch, args.num_train_epochs):
879 | unet.train()
880 | if args.train_text_encoder:
881 | text_encoder.train()
882 | train_loss = 0.0
883 | for step, batch in enumerate(train_dataloader):
884 | # Skip steps until we reach the resumed step
885 | if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
886 | if step % args.gradient_accumulation_steps == 0:
887 | progress_bar.update(1)
888 | continue
889 |
890 | with accelerator.accumulate(unet):
891 | # Convert images to latent space
892 | latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
893 | latents = latents * vae.config.scaling_factor
894 |
895 | # Sample noise that we'll add to the latents
896 | noise = torch.randn_like(latents)
897 | bsz = latents.shape[0]
898 | # Sample a random timestep for each image
899 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
900 | timesteps = timesteps.long()
901 |
902 | # Add noise to the latents according to the noise magnitude at each timestep
903 | # (this is the forward diffusion process)
904 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
905 |
906 | # Get the text embedding for conditioning
907 | encoder_hidden_states = text_encoder(batch["input_ids"])[0]
908 |
909 | # Get the target for loss depending on the prediction type
910 | if noise_scheduler.config.prediction_type == "epsilon":
911 | target = noise
912 | elif noise_scheduler.config.prediction_type == "v_prediction":
913 | target = noise_scheduler.get_velocity(latents, noise, timesteps)
914 | else:
915 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
916 |
917 | # Predict the noise residual and compute loss
918 | model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
919 | loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
920 |
921 | # Gather the losses across all processes for logging (if we use distributed training).
922 | avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
923 | train_loss += avg_loss.item() / args.gradient_accumulation_steps
924 |
925 | # Backpropagate
926 | accelerator.backward(loss)
927 | if accelerator.sync_gradients:
928 | if args.use_peft:
929 | params_to_clip = (
930 | itertools.chain(unet.parameters(), text_encoder.parameters())
931 | if args.train_text_encoder
932 | else unet.parameters()
933 | )
934 | else:
935 | params_to_clip = lora_layers.parameters()
936 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
937 | optimizer.step()
938 | lr_scheduler.step()
939 | optimizer.zero_grad()
940 |
941 | # Checks if the accelerator has performed an optimization step behind the scenes
942 | if accelerator.sync_gradients:
943 | progress_bar.update(1)
944 | global_step += 1
945 | accelerator.log({"train_loss": train_loss}, step=global_step)
946 | train_loss = 0.0
947 |
948 | if global_step % args.checkpointing_steps == 0:
949 | if accelerator.is_main_process:
950 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
951 | accelerator.save_state(save_path)
952 | logger.info(f"Saved state to {save_path}")
953 |
954 | logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
955 | progress_bar.set_postfix(**logs)
956 |
957 | if global_step >= args.max_train_steps:
958 | break
959 |
960 | if accelerator.is_main_process:
961 | if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
962 | logger.info(
963 | f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
964 | f" {args.validation_prompt}."
965 | )
966 | # create pipeline
967 | pipeline = DiffusionPipeline.from_pretrained(
968 | model_dir,
969 | unet=accelerator.unwrap_model(unet),
970 | text_encoder=accelerator.unwrap_model(text_encoder),
971 | torch_dtype=weight_dtype,
972 | )
973 | pipeline = pipeline.to(accelerator.device)
974 | pipeline.set_progress_bar_config(disable=True)
975 |
976 | # run inference
977 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
978 | images = []
979 | for _ in range(args.num_validation_images):
980 | images.append(
981 | pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
982 | )
983 |
984 | if accelerator.is_main_process:
985 | for tracker in accelerator.trackers:
986 | if tracker.name == "tensorboard":
987 | np_images = np.stack([np.asarray(img) for img in images])
988 | tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
989 | if tracker.name == "wandb":
990 | tracker.log(
991 | {
992 | "validation": [
993 | wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
994 | for i, image in enumerate(images)
995 | ]
996 | }
997 | )
998 |
999 | del pipeline
1000 | torch.cuda.empty_cache()
1001 |
1002 | # Save the lora layers
1003 | accelerator.wait_for_everyone()
1004 | if accelerator.is_main_process:
1005 | if args.use_peft:
1006 | lora_config = {}
1007 | unwrapped_unet = accelerator.unwrap_model(unet)
1008 | state_dict = get_peft_model_state_dict(unwrapped_unet, state_dict=accelerator.get_state_dict(unet))
1009 | lora_config["peft_config"] = unwrapped_unet.get_peft_config_as_dict(inference=True)
1010 | if args.train_text_encoder:
1011 | unwrapped_text_encoder = accelerator.unwrap_model(text_encoder)
1012 | text_encoder_state_dict = get_peft_model_state_dict(
1013 | unwrapped_text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
1014 | )
1015 | text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
1016 | state_dict.update(text_encoder_state_dict)
1017 | lora_config["text_encoder_peft_config"] = unwarpped_text_encoder.get_peft_config_as_dict(
1018 | inference=True
1019 | )
1020 |
1021 | accelerator.save(state_dict, os.path.join(args.output_dir, f"{global_step}_lora.pt"))
1022 | with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "w") as f:
1023 | json.dump(lora_config, f)
1024 | else:
1025 | unet = unet.to(torch.float32)
1026 | unet.save_attn_procs(args.output_dir)
1027 |
1028 | if args.push_to_hub:
1029 | save_model_card(
1030 | repo_id,
1031 | images=images,
1032 | base_model=model_dir,
1033 | dataset_name=args.dataset_name,
1034 | repo_folder=args.output_dir,
1035 | )
1036 | upload_folder(
1037 | repo_id=repo_id,
1038 | folder_path=args.output_dir,
1039 | commit_message="End of training",
1040 | ignore_patterns=["step_*", "epoch_*"],
1041 | )
1042 |
1043 | # Final inference
1044 | # Load previous pipeline
1045 | pipeline = DiffusionPipeline.from_pretrained(
1046 | model_dir, torch_dtype=weight_dtype
1047 | )
1048 |
1049 | if args.use_peft:
1050 |
1051 | def load_and_set_lora_ckpt(pipe, ckpt_dir, global_step, device, dtype):
1052 | with open(os.path.join(args.output_dir, f"{global_step}_lora_config.json"), "r") as f:
1053 | lora_config = json.load(f)
1054 | print(lora_config)
1055 |
1056 | checkpoint = os.path.join(args.output_dir, f"{global_step}_lora.pt")
1057 | lora_checkpoint_sd = torch.load(checkpoint)
1058 | unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k}
1059 | text_encoder_lora_ds = {
1060 | k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k
1061 | }
1062 |
1063 | unet_config = LoraConfig(**lora_config["peft_config"])
1064 | pipe.unet = LoraModel(unet_config, pipe.unet)
1065 | set_peft_model_state_dict(pipe.unet, unet_lora_ds)
1066 |
1067 | if "text_encoder_peft_config" in lora_config:
1068 | text_encoder_config = LoraConfig(**lora_config["text_encoder_peft_config"])
1069 | pipe.text_encoder = LoraModel(text_encoder_config, pipe.text_encoder)
1070 | set_peft_model_state_dict(pipe.text_encoder, text_encoder_lora_ds)
1071 |
1072 | if dtype in (torch.float16, torch.bfloat16):
1073 | pipe.unet.half()
1074 | pipe.text_encoder.half()
1075 |
1076 | pipe.to(device)
1077 | return pipe
1078 |
1079 | pipeline = load_and_set_lora_ckpt(pipeline, args.output_dir, global_step, accelerator.device, weight_dtype)
1080 |
1081 | else:
1082 | pipeline = pipeline.to(accelerator.device)
1083 | # load attention processors
1084 | pipeline.unet.load_attn_procs(args.output_dir)
1085 |
1086 | # run inference
1087 | if args.seed is not None:
1088 | generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
1089 | else:
1090 | generator = None
1091 | images = []
1092 |
1093 | accelerator.end_training()
1094 |
1095 |
1096 | if __name__ == "__main__":
1097 | main()
1098 |
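'''
Example launch (a minimal sketch; the model id, paths and hyperparameters below are
placeholders, and the actual training scripts may pass different values):

accelerate launch facechain/train_text_to_image_lora.py \
    --pretrained_model_name_or_path=ly261666/cv_portrait_model \
    --dataset_name=/path/to/raw_photos \
    --output_dataset_name=/path/to/processed_photos \
    --output_dir=/path/to/lora_output \
    --resolution=512 \
    --train_batch_size=1 \
    --num_train_epochs=200 \
    --learning_rate=1e-4
'''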
--------------------------------------------------------------------------------