├── scripts
│   ├── lint.sh
│   └── ci_test.sh
├── assets
│   ├── app.jpg
│   ├── arch.png
│   ├── init.jpg
│   ├── create.png
│   ├── model.png
│   ├── success.png
│   ├── arch_ctr.png
│   ├── arch_opt.png
│   ├── arch_tran.png
│   ├── pipeline.png
│   └── prepare_model.png
├── diffusers
│   ├── tests
│   │   ├── __pycache__
│   │   │   └── ut_config.cpython-38.pyc
│   │   ├── ut_config.py
│   │   ├── run.py
│   │   └── test_utils
│   │       ├── test_convert.py
│   │       ├── test_image_process.py
│   │       ├── test_io.py
│   │       └── test_blade.py
│   ├── example
│   │   ├── async_example_post.py
│   │   ├── async_example_get.py
│   │   ├── sync_example_base.py
│   │   └── sync_example_control.py
│   ├── Dockerfile
│   ├── ev_error.py
│   ├── utils
│   │   ├── convert.py
│   │   ├── io.py
│   │   ├── image_process.py
│   │   └── blade.py
│   └── lpw_stable_diffusion.py
├── LICENSE
├── doc
│   ├── app.md
│   ├── deploy.md
│   └── param.md
└── README.md
/scripts/lint.sh:
--------------------------------------------------------------------------------
1 | yapf -r -i ${1}
2 | isort -rc ${1}
3 | flake8 ${1}
4 |
--------------------------------------------------------------------------------
/assets/app.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/app.jpg
--------------------------------------------------------------------------------
/assets/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/arch.png
--------------------------------------------------------------------------------
/assets/init.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/init.jpg
--------------------------------------------------------------------------------
/assets/create.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/create.png
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/model.png
--------------------------------------------------------------------------------
/assets/success.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/success.png
--------------------------------------------------------------------------------
/assets/arch_ctr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/arch_ctr.png
--------------------------------------------------------------------------------
/assets/arch_opt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/arch_opt.png
--------------------------------------------------------------------------------
/assets/arch_tran.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/arch_tran.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/prepare_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/prepare_model.png
--------------------------------------------------------------------------------
/diffusers/tests/__pycache__/ut_config.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/diffusers/tests/__pycache__/ut_config.cpython-38.pyc
--------------------------------------------------------------------------------
/diffusers/example/async_example_post.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import json
4 |
5 | from eas_prediction import PredictClient, StringRequest
6 |
7 | if __name__ == '__main__':
8 | client = PredictClient('http://xxx.cn-hangzhou.pai-eas.aliyuncs.com/',
9 | 'service_name')
10 | client.set_token(
11 | 'xxx')
12 |
13 | client.init()
14 |
15 | datas = json.dumps({
16 | 'task_id': 'async',
17 | 'prompt': '一个可爱的女孩',
18 | 'steps': 100,
19 | 'image_num': 3,
20 | 'width': 512,
21 | 'height': 512,
22 | 'seed': '123',
23 | })
24 |
25 | request = StringRequest(datas)
26 |
27 | for x in range(0, 1):
28 | resp = client.predict(request)
29 | print(resp)
30 |
--------------------------------------------------------------------------------
/diffusers/tests/ut_config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | UT_ROOT = '/mnt/xinyi.zxy/diffuser/ut_test/'
4 | MODEL_DIR = os.path.join(UT_ROOT, 'models/test_model')
5 | BASE_MODEL_PATH = os.path.join(MODEL_DIR, 'base_model')
6 | CONTROLNET_MODEL_PATH = os.path.join(MODEL_DIR, 'controlnet')
7 | LORA_PATH = os.path.join(MODEL_DIR, 'lora_model/animeoutlineV4_16.safetensors')
8 | LORA_PATH_BIN = os.path.join(MODEL_DIR, 'lora_model/pytorch_lora_weights.bin')
9 |
10 | MODEL_DIR_NEW = os.path.join(UT_ROOT, 'models/new_model')
11 | CKPT_PATH = os.path.join(MODEL_DIR_NEW, 'colorful_v26.safetensors')
12 |
13 | PRETRAIN_DIR = os.path.join(UT_ROOT, 'models/pretrained_models')
14 | SAVE_DIR = os.path.join(UT_ROOT, 'results')
15 | IMAGE_DIR = os.path.join(UT_ROOT, 'images')
16 |
17 | CUSTOM_PIPELINE = './lpw_stable_diffusion.py'
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Alibaba
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/diffusers/example/async_example_get.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | from eas_prediction import QueueClient
4 |
5 | if __name__ == '__main__':
6 |     # Create the output queue object, used to subscribe to and read the result data.
7 |
8 | sink_queue = QueueClient(
9 | 'http://xxx.cn-hangzhou.pai-eas.aliyuncs.com',
10 | 'service_name/sink')
11 | sink_queue.set_token(
12 | 'xxx')
13 |
14 | sink_queue.init()
15 |
16 |     # Watch data from the output queue with a window size of 1.
17 | i = 0
18 | watcher = sink_queue.watch(0, 1, auto_commit=False)
19 | for x in watcher.run():
20 | data = x.data.decode('utf-8')
21 | data = json.loads(data)
22 | print(data.keys())
23 | if data['success']:
24 | print(data['image_url'])
25 | print(data['oss_url'])
26 | print(data['task_id'])
27 | print(data['use_blade'])
28 | print(data['seed'])
29 | print(data['is_nsfw'])
30 | else:
31 | print(data['error_msg'])
32 |         # Manually commit after each received item has been processed.
33 | sink_queue.commit(x.index)
34 | i += 1
35 | if i == 10:
36 | break
37 |
38 |     # Close the opened watcher object. Each client instance allows only one watcher; if it is not closed, running again will raise an error.
39 | watcher.close()
40 |
--------------------------------------------------------------------------------
/scripts/ci_test.sh:
--------------------------------------------------------------------------------
1 | #================================================================
2 | # Copyright (C) 2022 Alibaba Ltd. All rights reserved.
3 | #
4 | #================================================================
5 |
6 | cd ${1}
7 |
8 | # pre-commit run --all-files
9 | # if [ $? -ne 0 ]; then
10 | # echo "linter test failed, please run 'pre-commit run --all-files' to check"
11 | # exit -1
12 | # fi
13 |
14 | mkdir -p logs
15 |
16 | # coverage UT and report
17 | PYTHONPATH=. coverage run tests/run.py | tee logs/ci_test.log
18 | PYTHONPATH=. coverage report | tee logs/ci_report.log
19 | # PYTHONPATH=. coverage html
20 |
21 |
22 | # please add the key requirements you think are worth recording
23 |
24 | echo "" | tee -a logs/ci_report.log
25 | echo "" | tee -a logs/ci_report.log
26 | echo "Requirements Version
27 | -----------------------------------------------------------" | tee -a logs/ci_report.log
28 | key_requirements=("^torch" "easycv" "easynlp" "easyretrieval"\
29 | "blade" "mmcv" "tokenizer")
30 | pip list >> logs/envlistcache.txt
31 |
32 | for val1 in ${key_requirements[*]}; do
33 |     grep "$val1" logs/envlistcache.txt | tee -a logs/ci_report.log
34 | done
35 | echo "-----------------------------------------------------------" | tee -a logs/ci_report.log
36 |
37 | rm logs/envlistcache.txt
38 |
--------------------------------------------------------------------------------
/doc/app.md:
--------------------------------------------------------------------------------
1 | ## Developing a custom EAS processor
2 |
3 | Using the Python SDK as an example, this document describes custom processor development based on PAI-EAS; the core of the work is maintaining the main file [app.py](../diffusers/app.py).
4 |
5 | ### Template overview
6 |
7 | You need to inherit from the base class **BaseProcessor** provided by PAI-EAS and implement the **initialize()** and **process()** functions. Both the input and the output of **process()** are of type BYTES; the outputs are **response_data** and **status_code**, and a normal request can return **0** or **200** as the **status_code**.
8 |
9 | | Function | Description | Parameters |
10 | | -------------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ |
11 | | init(worker_threads=5, worker_processes=1, endpoint=None) | Processor constructor. | **worker_threads**: number of worker threads, default 5. **worker_processes**: number of worker processes, default 1. If **worker_processes** is 1, the service runs in single-process, multi-thread mode. If **worker_processes** is greater than 1, **worker_threads** only reads data; requests are handled concurrently by multiple processes, and every process executes **initialize()**. **endpoint**: the endpoint the service listens on; use it to specify the listening address and port, e.g. **endpoint='0.0.0.0:8079'**. |
12 | | initialize() | Processor initialization function. Performs initialization work such as model loading when the service starts. | No parameters. |
13 | | process(data) | Request handling function. Each request body is passed to **process()** as its argument, and the return value is sent back to the client. | **data** is the request body, of type BYTES. The return value is also of type BYTES. |
14 | | run() | Starts the service. | No parameters. |
15 |
16 |
17 |
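A minimal sketch of such a processor is shown below. It assumes the `allspark` SDK installed in the provided image; `MyProcessor` and its echo logic are placeholders for illustration only, and the real logic of this project lives in [app.py](../diffusers/app.py).

```python
# Minimal custom-processor sketch (assumes the allspark SDK from the image;
# MyProcessor and the echo behaviour are illustrative placeholders).
import allspark


class MyProcessor(allspark.BaseProcessor):
    def initialize(self):
        # Load models and other resources once when the service starts.
        self.model = None

    def process(self, data):
        # `data` is the raw request body (BYTES); return (response_data, status_code).
        try:
            request = data.decode('utf-8')
            return request.encode('utf-8'), 200  # echo the request back
        except Exception as e:
            return str(e).encode('utf-8'), 460


if __name__ == '__main__':
    # worker_threads / worker_processes / endpoint as described in the table above.
    runner = MyProcessor(worker_threads=5, worker_processes=1, endpoint='0.0.0.0:8079')
    runner.run()
```
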
18 | ### Further development
19 |
20 | [app.py](../diffusers/app.py) already implements part of the functionality on top of the diffusers API. This section presents the core code flow as flowcharts to make further development easier.
21 |
22 | #### initialize()
23 |
24 |
25 | #### process(data)
26 |
27 | 
28 |
--------------------------------------------------------------------------------
/diffusers/tests/run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import unittest
5 | from fnmatch import fnmatch
6 |
7 |
8 | def gather_test_cases(test_dir, pattern, list_tests):
9 | case_list = []
10 | dir_list = []
11 | for dirpath, dirnames, filenames in os.walk(test_dir):
12 | for file in filenames:
13 | if fnmatch(file, pattern):
14 | case_list.append(file)
15 | dir_list.append(dirpath)
16 | test_suite = unittest.TestSuite()
17 |
18 | for dirname, case in zip(dir_list, case_list):
19 | test_case = unittest.defaultTestLoader.\
20 | discover(start_dir=dirname, pattern=case)
21 | test_suite.addTest(test_case)
22 | if hasattr(test_case, '__iter__'):
23 | for subcase in test_case:
24 | if list_tests:
25 | print(subcase)
26 | else:
27 | if list_tests:
28 | print(test_case)
29 |
30 | return test_suite
31 |
32 |
33 | def main(args):
34 |
35 | runner = unittest.TextTestRunner()
36 | test_suite = gather_test_cases(os.path.abspath(args.test_dir),
37 | args.pattern, args.list_tests)
38 | if not args.list_tests:
39 | result = runner.run(test_suite)
40 | if len(result.failures) > 0 or len(result.errors) > 0:
41 | sys.exit(1)
42 |
43 |
44 | if __name__ == '__main__':
45 | parser = argparse.ArgumentParser('test runner')
46 | parser.add_argument('--list_tests',
47 | action='store_true',
48 | help='list all tests')
49 | parser.add_argument('--pattern',
50 | default='test_*.py',
51 | help='test file pattern')
52 | parser.add_argument('--test_dir',
53 | default='tests',
54 | help='directory to be tested')
55 |
56 | args = parser.parse_args()
57 | main(args)
58 |
--------------------------------------------------------------------------------
/diffusers/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM bladedisc/bladedisc:latest-devel-cu113
2 | ENV BLADE_GEMM_TUNE_JIT=1 DISC_ENABLE_PREDEFINED_PDL=true DISC_ENABLE_PACK_QKV=true
3 | RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple
4 |
5 | RUN pip install https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/pytorch/wheels/torch-1.12.0%2Bcu113-cp38-cp38-linux_x86_64.whl
6 |
7 | RUN pip install https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/temp/xformers-0.0.17%2B658ebab.d20230327-cp38-cp38-linux_x86_64.whl &&\
8 | pip install transformers &&\
9 | pip install opencv-python-headless &&\
10 | pip install diffusers==0.15.0 &&\
11 | pip install -U http://eas-data.oss-cn-shanghai.aliyuncs.com/sdk/allspark-0.15-py2.py3-none-any.whl &&\
12 | pip install https://pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com/zxy/diffusers/torch_blade-0.0.1%2B1.12.0.cu113-cp38-cp38-linux_x86_64.whl &&\
13 | pip install safetensors &&\
14 | pip install modelscope &&\
15 | pip install subword_nmt &&\
16 | pip install jieba &&\
17 | pip install sacremoses &&\
18 | pip install tensorflow &&\
19 | pip install omegaconf
20 | RUN pip install scikit-image
21 | RUN pip install https://pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com/zxy/diffusers/controlnet_aux-0.0.3-py3-none-any.whl --no-deps
22 | RUN pip install torchvision==0.13.0
23 | RUN pip install timm
24 | RUN pip install mediapipe
25 | RUN pip cache purge
26 |
27 | RUN apt-get update && apt-get install -y wget
28 | RUN mkdir /home/pai/
29 |
30 | ADD ./app.py /home/pai/app.py
31 | ADD ./utils /home/pai/utils
32 | ADD ./ev_error.py /home/pai/ev_error.py
33 | ADD ./lpw_stable_diffusion.py /home/pai/lpw_stable_diffusion.py
34 |
35 | RUN wget https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/image/optimized_model.tar.gz \
36 | && tar -xvf optimized_model.tar.gz \
37 | && mv optimized_model /home/pai/optimized_model \
38 | && rm optimized_model.tar.gz
39 |
40 | RUN wget https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/image/pretrained_models.tar.gz \
41 | && tar -xvf pretrained_models.tar.gz \
42 | && mv pretrained_models /home/pai/pretrained_models \
43 | && rm pretrained_models.tar.gz
44 |
--------------------------------------------------------------------------------
/diffusers/tests/test_utils/test_convert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | import unittest
4 |
5 | import torch
6 | from diffusers import StableDiffusionPipeline
7 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import \
8 | download_from_original_stable_diffusion_ckpt
9 | from tests.ut_config import BASE_MODEL_PATH, CKPT_PATH, LORA_PATH
10 | from utils.convert import (convert_base_model_to_diffuser,
11 | convert_lora_safetensor_to_bin, convert_name_to_bin)
12 |
13 |
14 | class TestModelConvert(unittest.TestCase):
15 | def test_convert_lora_safetensor_to_bin(self):
16 | # for .safetensors to load by ori diffuser api (only attn in unet will be loaded)
17 | with tempfile.TemporaryDirectory() as temp_dir:
18 |             # write the converted weights to a temporary file
19 |             bin_path = os.path.join(temp_dir, 'bin_model.pth')
20 | convert_lora_safetensor_to_bin(LORA_PATH, bin_path)
21 | self.assertTrue(os.path.exists(bin_path))
22 |
23 | # Load the converted lora model
24 | pipe = StableDiffusionPipeline.from_pretrained(
25 | BASE_MODEL_PATH,
26 | revision='fp16',
27 | torch_dtype=torch.float16,
28 | safety_checker=None)
29 | pipe.unet.load_attn_procs(bin_path, use_safetensors=False)
30 |
31 | def test_convert_base_model_to_diffuser(self):
32 | # convert .safetensors to multiple dirs
33 | from_safetensors = True
34 |
35 | with tempfile.TemporaryDirectory() as temp_dir:
36 | convert_base_model_to_diffuser(CKPT_PATH, temp_dir,
37 | from_safetensors)
38 | files = os.listdir(temp_dir)
39 | print(files)
40 | need_file_list = [
41 | 'feature_extractor', 'model_index.json', 'safety_checker',
42 | 'scheduler', 'text_encoder', 'tokenizer', 'unet', 'vae'
43 | ]
44 | self.assertTrue(set(need_file_list).issubset(set(files)))
45 |
46 | # Load the converted base model
47 | pipe = StableDiffusionPipeline.from_pretrained(
48 | temp_dir,
49 | revision='fp16',
50 | torch_dtype=torch.float16,
51 | safety_checker=None)
52 |
53 |
54 | if __name__ == '__main__':
55 | unittest.main()
56 |
--------------------------------------------------------------------------------
/diffusers/tests/test_utils/test_image_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 |
4 | import numpy as np
5 | from PIL import Image
6 |
7 | import torch
8 | # needs to be imported, otherwise a malloc error is raised by controlnet_aux
9 | from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
10 | from tests.ut_config import IMAGE_DIR, PRETRAIN_DIR, SAVE_DIR
11 | from utils.image_process import (generate_mask_and_img_expand,
12 | preprocess_control, transform_image)
13 |
14 |
15 | class TestImageProcessing(unittest.TestCase):
16 | def setUp(self):
17 | # image for testing
18 | img_path = os.path.join(IMAGE_DIR, 'room.png')
19 | self.image = Image.open(img_path).convert('RGB')
20 |
21 | def test_preprocess_control(self):
22 | # Test with valid process_func
23 | process_func_list = [
24 | 'canny', 'depth', 'hed', 'mlsd', 'normal', 'openpose', 'scribble',
25 | 'seg'
26 | ]
27 | for process_func in process_func_list:
28 | print('Process: {}'.format(process_func))
29 | processed_image = preprocess_control(self.image, process_func,
30 | PRETRAIN_DIR)
31 | self.assertIsInstance(processed_image, Image.Image)
32 | processed_image.save(
33 | os.path.join(SAVE_DIR, 'pre_{}.jpg'.format(process_func)))
34 |
35 | # Test with an invalid process_func
36 | error_message = preprocess_control(self.image, 'invalid_func',
37 | PRETRAIN_DIR)
38 | self.assertIsInstance(error_message, str)
39 |
40 | def test_transform_image(self):
41 | test_params = {0: 'Stretch', 1: 'Crop', 2: 'Pad'}
42 |
43 | expected_sizes = [(1024, 1024), (768, 1024), (1024, 768)]
44 |
45 | for expected_size in expected_sizes:
46 | width, height = expected_size
47 | for mode, mode_type in test_params.items():
48 | print('Process: {}, width: {}, height: {}'.format(
49 | mode, width, height))
50 | transformed_image = transform_image(self.image, width, height,
51 | mode)
52 | self.assertEqual(transformed_image.size, (width, height))
53 | transformed_image.save(
54 | os.path.join(SAVE_DIR, '{}.jpg'.format(mode_type)))
55 |
56 | def test_generate_mask_and_img_expand(self):
57 | expand = (10, 20, 30, 40)
58 |
59 | left, right, up, down = expand
60 | width, height = self.image.size
61 | new_width, new_height = width + left + right, height + up + down
62 |
63 | expand_list = ['copy', 'reflect']
64 | for expand_type in expand_list:
65 | expanded_image, mask = generate_mask_and_img_expand(
66 | self.image, expand, expand_type)
67 | self.assertEqual(expanded_image.size, (new_width, new_height))
68 | self.assertEqual(mask.size, (new_width, new_height))
69 |
70 | expanded_image.save(
71 | os.path.join(SAVE_DIR,
72 | 'expanded_image_{}.jpg'.format(expand_type)))
73 | mask.save(os.path.join(SAVE_DIR,
74 | 'mask_{}.jpg'.format(expand_type)))
75 |
76 |
77 | if __name__ == '__main__':
78 | unittest.main()
79 |
--------------------------------------------------------------------------------
/diffusers/ev_error.py:
--------------------------------------------------------------------------------
1 | class Error(object):
2 | """A class for representing errors that occur in the program.
3 |
4 | Attributes:
5 | _code (int): The error code.
6 | _msg_prefix (str): The prefix for the error message.
7 | _msg (str): The error message.
8 | """
9 | def __init__(self, code, msg_prefix):
10 | """Initialize a new Error instance.
11 |
12 | Args:
13 | code (int): The error code.
14 | msg_prefix (str): The prefix for the error message.
15 | """
16 | self._code = code
17 | self._msg_prefix = msg_prefix
18 | self._msg = ''
19 |
20 | @property
21 | def code(self):
22 | """int: The error code."""
23 | return self._code
24 |
25 | @code.setter
26 | def code(self, code):
27 | """Set the error code.
28 |
29 | Args:
30 | code (int): The new error code.
31 | """
32 | self._code = code
33 |
34 | @property
35 | def msg(self):
36 | """str: The full error message, including the prefix and message."""
37 | return self._msg_prefix + ' - ' + self._msg
38 |
39 | @msg.setter
40 | def msg(self, msg):
41 | """Set the error message.
42 |
43 | Args:
44 | msg (str): The new error message.
45 | """
46 | self._msg = msg
47 |
48 |
49 | class JsonParseError(Error):
50 | """A class for representing JSON parse errors that occur in the program.
51 |
52 | Args:
53 | msg (str, optional): The error message. Defaults to an empty string.
54 | """
55 | def __init__(self, msg=''):
56 | """Initialize a new JsonParseError instance.
57 |
58 | Args:
59 | msg (str, optional): The error message. Defaults to an empty
60 | string.
61 | """
62 | super(JsonParseError, self).__init__(460, 'Json Parse Error')
63 | self.msg = msg
64 |
65 |
66 | class InputFormatError(Error):
67 | """A class for representing input format errors that occur in the program.
68 |
69 | Args:
70 | msg (str, optional): The error message. Defaults to an empty string.
71 | """
72 | def __init__(self, msg=''):
73 | """Initialize a new InputFormatError instance.
74 |
75 | Args:
76 | msg (str, optional): The error message. Defaults to an empty
77 | string.
78 | """
79 | super(InputFormatError, self).__init__(460, 'Input Body Format Error')
80 | self.msg = msg
81 |
82 |
83 | class ImageDecodeError(Error):
84 | """A class for representing image decode errors that occur in the program.
85 |
86 | Args:
87 | msg (str, optional): The error message. Defaults to an empty string.
88 | """
89 | def __init__(self, msg=''):
90 | """Initialize a new ImageDecodeError instance.
91 |
92 | Args:
93 | msg (str, optional): The error message. Defaults to an empty
94 | string.
95 | """
96 | super(ImageDecodeError, self).__init__(462, 'Image Decode Error')
97 | self.msg = msg
98 |
99 |
100 | class UnExpectedServerError(Error):
101 | """A class for representing unexpected server errors that occur in
102 | the program.
103 |
104 | Args:
105 | msg (str, optional): The error message. Defaults to an empty string.
106 | """
107 | def __init__(self, msg=''):
108 | """Initialize a new UnExpectedServerError instance.
109 |
110 | Args:
111 | msg (str, optional): The error message. Defaults to an
112 | empty string.
113 | """
114 |         super(UnExpectedServerError,
115 |               self).__init__(469,
116 |                              'Unexpected Server Error')
117 | self.msg = msg
118 |
--------------------------------------------------------------------------------
/diffusers/tests/test_utils/test_io.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 |
4 | from PIL import Image
5 |
6 | import torch
7 | from tests.ut_config import (BASE_MODEL_PATH, CONTROLNET_MODEL_PATH,
8 | CUSTOM_PIPELINE, IMAGE_DIR, PRETRAIN_DIR,
9 | SAVE_DIR)
10 | from utils.image_process import preprocess_control
11 | from utils.io import load_diffusers_pipeline
12 |
13 |
14 | class TestLoadDiffusersPipeline(unittest.TestCase):
15 | def setUp(self):
16 | # hyper parameters
17 | self.prompt = 'a dog'
18 | img_path = os.path.join(IMAGE_DIR, 'image.png')
19 | mask_path = os.path.join(IMAGE_DIR, 'mask.png')
20 |
21 | self.image = Image.open(img_path).convert('RGB')
22 | self.mask = Image.open(mask_path).convert('RGB')
23 | self.num_inference_steps = 20
24 | self.num_images_per_prompt = 1
25 |
26 | self.device = torch.device(
27 | 'cuda') if torch.cuda.is_available() else torch.device('cpu')
28 |
29 | def test_base_mode(self):
30 | mode = 'base'
31 | close_safety = False
32 |
33 | pipe = load_diffusers_pipeline(BASE_MODEL_PATH, None, None,
34 | self.device, mode, close_safety,
35 | CUSTOM_PIPELINE)
36 |
37 | self.assertIsNotNone(pipe)
38 |
39 | # t2i
40 | with torch.no_grad():
41 | res = pipe.text2img(
42 | prompt=self.prompt,
43 | num_inference_steps=self.num_inference_steps,
44 | num_images_per_prompt=self.num_images_per_prompt)
45 | image = res.images[0]
46 | self.assertIsInstance(image, Image.Image)
47 | image.save(os.path.join(SAVE_DIR, 't2i.jpg'))
48 |
49 | # i2i
50 | with torch.no_grad():
51 | res = pipe.img2img(
52 | prompt=self.prompt,
53 | image=self.image,
54 | num_inference_steps=self.num_inference_steps,
55 | num_images_per_prompt=self.num_images_per_prompt)
56 | image = res.images[0]
57 | self.assertIsInstance(image, Image.Image)
58 | image.save(os.path.join(SAVE_DIR, 'i2i.jpg'))
59 |
60 | # inpaint
61 | with torch.no_grad():
62 | res = pipe.inpaint(
63 | prompt=self.prompt,
64 | image=self.image,
65 | mask_image=self.mask,
66 | num_inference_steps=self.num_inference_steps,
67 | num_images_per_prompt=self.num_images_per_prompt)
68 | image = res.images[0]
69 | self.assertIsInstance(image, Image.Image)
70 | image.save(os.path.join(SAVE_DIR, 'inpaint.jpg'))
71 |
72 | def test_controlnet_mode(self):
73 | mode = 'controlnet'
74 | close_safety = False
75 |
76 | pipe = load_diffusers_pipeline(BASE_MODEL_PATH, None,
77 | CONTROLNET_MODEL_PATH, self.device,
78 | mode, close_safety, CUSTOM_PIPELINE)
79 |
80 | self.assertIsNotNone(pipe)
81 |
82 | with torch.no_grad():
83 | process_image = preprocess_control(self.image, 'canny',
84 | PRETRAIN_DIR)
85 | res = pipe(prompt=self.prompt,
86 | image=process_image,
87 | num_inference_steps=self.num_inference_steps,
88 | num_images_per_prompt=self.num_images_per_prompt)
89 | image = res.images[0]
90 | self.assertIsInstance(image, Image.Image)
91 | image.save(os.path.join(SAVE_DIR, 'control.jpg'))
92 |
93 | def test_invalid_mode(self):
94 | with self.assertRaises(ValueError):
95 | mode = 'invalid'
96 | close_safety = False
97 | load_diffusers_pipeline(BASE_MODEL_PATH, None, None, self.device,
98 | mode, close_safety, CUSTOM_PIPELINE)
99 |
100 |
101 | if __name__ == '__main__':
102 | unittest.main()
103 |
--------------------------------------------------------------------------------
/doc/deploy.md:
--------------------------------------------------------------------------------
1 | ### Deployment guide
2 |
3 | With [PAI-EAS](https://help.aliyun.com/document_detail/113696.html?spm=a2c4g.113696.0.0.2c421af2NqTtmW) you can quickly deploy a model as a RESTful API and then call the service through HTTP requests. PAI-EAS provides elastic scaling, version control, resource monitoring and other features, making it easy to use model services in your business.
4 |
5 | This document describes how to deploy a diffusers-based service on PAI-EAS with the provided image. The service is deployed using a custom image plus an OSS mount. After further development, you can modify the image and follow this document to deploy your own customized service.
6 |
7 | #### Prerequisites
8 |
9 | Before you start, make sure you have completed the following preparations.
10 |
11 | - PAI (DSW, EAS) pay-as-you-go has been activated and a default workspace has been created. For details, see [Activate PAI and create a default workspace](https://help.aliyun.com/document_detail/326190.htm#task-2121596).
12 | - An OSS bucket has been created to store datasets, trained model files and configuration files. For how to create a bucket, see [Create a bucket](https://help.aliyun.com/document_detail/31885.htm#task-u3p-3n4-tdb).
13 |
14 |
15 | #### Step 1: Prepare the models
16 |
17 | Prepare the model files to be loaded and upload them to your own OSS bucket (a sketch of a possible bucket layout is given at the end of this step).
18 |
19 | - Store all files required by the main model in the base_model folder of the specified OSS bucket [required]
20 | - Put ControlNet-related models in the controlnet folder [optional]
21 | - Put the bundled translation model in the translate folder [optional]
22 |
23 |
24 |
25 | - The base_model folder supports the following formats:
26 |
27 |   - the multi-folder format supported by the diffusers API (a single .safetensors/.ckpt file will be converted automatically)
28 |
29 | 
30 |
31 |
32 | - The controlnet folder must contain the following files [required when using ControlNet]
33 |
34 |
35 |
36 | - The optimized_model folder must contain the following three/four files (ControlNet produces an additional controlnet.pt). When the --use_blade switch is enabled at deployment time, the Blade models are optimized automatically in the background and swapped in once optimization finishes
37 |
38 |
39 |
40 | - The translate folder: you can download the built-in translation model files via the links below. [optional]
41 |
42 |   - Model page: https://www.modelscope.cn/models/damo/nlp_csanmt_translation_zh2en/files
43 |   - Reference download link: https://pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com/zxy/model/damo_translate.tar.gz
44 |
45 |
46 |
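For reference, one possible layout of the mounted model path might look like the sketch below (a non-authoritative example; only base_model is required, and the folder names follow the conventions above):

```
oss://your-bucket/your-path/        # mounted to /oss in the service
├── base_model/                     # required: diffusers-format model, or a single .safetensors/.ckpt
├── controlnet/                     # optional: ControlNet weights
├── translate/                      # optional: zh->en translation model
└── optimized_model/                # written/used when --use_blade is enabled
```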
47 |
48 | #### Step 2: Deploy the service
49 |
50 | Once the models are ready, you can deploy the service with EAS. You can adjust the deployment command as needed to deploy different types of services.
51 |
52 | Specifically, you need to perform the following steps:
53 |
54 | - **Step 1: Create an EAS service**
55 |
56 |   
57 |
58 | - **Step 2: Choose image-based deployment and modify the parameters**
59 |
60 |   - In the corresponding configuration editor, copy the configuration below and modify the parameters accordingly
61 |
62 | ```json
63 | {
64 | "cloud": {
65 | "computing": {
66 | "instance_type": "ml.gu7i.c8m30.1-gu30"
67 | }
68 | },
69 | "containers": [
70 | {
71 | "command": "python /home/pai/app.py --func_name base --oss_save_dir oss://converter-offline-installer/zxy/diffusers/0605/ --region hangzhou --use_blade --use_translate",
72 | "image": "eas-registry-vpc.cn-hangzhou.cr.aliyuncs.com/pai-eas/diffuser-inference:2.2.1-py38-cu113-unbuntu2004-blade-public"
73 | }
74 | ],
75 | "features": {
76 | "eas.aliyun.com/extra-ephemeral-storage": "100Gi"
77 | },
78 | "metadata": {
79 | "cpu": 8,
80 | "gpu": 1,
81 | "instance": 1,
82 | "memory": 30000,
83 | "name": "diffuser_base_ch_async",
84 | "rpc": {
85 | "keepalive": 500000
86 | },
87 | "type": "Async"
88 | },
89 | "name": "diffuser_base_ch_async",
90 | "storage": [
91 | {
92 | "mount_path": "/oss",
93 | "oss": {
94 | "path": "oss://converter-offline-installer/zxy/diffusers/0605/base/"
95 | }
96 | },
97 | {
98 | "mount_path": "/result",
99 | "oss": {
100 | "path": "oss://converter-offline-installer/zxy/diffusers/0605/"
101 | }
102 | }
103 | ]
104 | }
105 | ```
106 |
107 | - containers.command options
108 |
109 | | key | value | Description |
110 | | --------------- | --------------------- | ------------------------------------ |
111 | | --func_name | base | one service supporting t2i/i2i/inpaint/outpaint |
112 | | | controlnet | ControlNet-based image editing |
113 | | --oss_save_dir | oss://xxx | OSS path where generated images are saved |
114 | | --region | hangzhou/shanghai/... | full pinyin name of the region where the service is deployed |
115 | | --use_blade | — | enable Blade inference optimization |
116 | | --use_translate | — | load the translation model to support Chinese prompt input |
117 |
118 | - containers.image: replace hangzhou with the region where the service is deployed
119 |
120 | - metadata.name / name: change to any service name you like
121 |
122 | - metadata.type: set to Async to deploy an asynchronous service (omit this field for a synchronous service)
123 |
124 | - storage: add the OSS mounts for the models and for the generated results
125 |
126 |   - mount_path must not be changed (or change the corresponding code in the image accordingly)
127 |
128 | | mount_path | oss | Description |
129 | | ---------- | --------- | -------------------- |
130 | | /oss | oss://xxx | the mounted model path |
131 | | /result | oss://xxx | the mounted location for generated images |
132 |
133 | - Click the Deploy button and wait for the deployment to finish.
134 |
135 | 
136 |
137 | - 点击上图的调用信息获得 测试所需的:
138 |
139 | ```json
140 | hosts = 'xxx'
141 | head = {
142 | "Authorization": "xxx"
143 | }
144 | ```
145 |
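A minimal sketch of a synchronous call using these values (the host and token are placeholders; complete examples live in [sync_example_base.py](../diffusers/example/sync_example_base.py) and [sync_example_control.py](../diffusers/example/sync_example_control.py)):

```python
# Hypothetical minimal client: hosts/Authorization come from the invocation info above.
import json

import requests

hosts = 'http://xxx.cn-hangzhou.pai-eas.aliyuncs.com/api/predict/service_name'
head = {'Authorization': 'xxx'}

payload = json.dumps({'task_id': 'demo', 'prompt': 'a cat', 'steps': 50})
r = requests.post(hosts, data=payload, headers=head)
print(json.loads(r.content.decode('utf-8')).keys())
```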
--------------------------------------------------------------------------------
/diffusers/utils/convert.py:
--------------------------------------------------------------------------------
1 | """
2 | convert different model types to the standard diffusers format
3 | """
4 |
5 | import torch
6 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import \
7 | download_from_original_stable_diffusion_ckpt
8 | from safetensors.torch import load_file
9 |
10 | LORA_PREFIX_UNET = 'lora_unet'
11 |
12 |
13 | def convert_name_to_bin(name: str) -> str:
14 | """
15 | Convert a name to binary format.
16 |
17 | Args:
18 | name (str): Name to be converted.
19 |
20 | Returns:
21 | str: Converted name in binary format.
22 | """
23 |
24 | # down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up
25 | new_name = name.replace(LORA_PREFIX_UNET + '_', '')
26 | new_name = new_name.replace('.weight', '')
27 |
28 | # ['down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q', 'lora.up']
29 | parts = new_name.split('.')
30 |
31 | #parts[0] = parts[0].replace('_0', '')
32 | if 'out' in parts[0]:
33 | parts[0] = '_'.join(parts[0].split('_')[:-1])
34 | parts[1] = parts[1].replace('_', '.')
35 |
36 | # ['down', 'blocks', '0', 'attentions', '0', 'transformer', 'blocks', '0', 'attn1', 'to', 'q']
37 | # ['mid', 'block', 'attentions', '0', 'transformer', 'blocks', '0', 'attn2', 'to', 'out']
38 | sub_parts = parts[0].split('_')
39 |
40 | # down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q_
41 | new_sub_parts = ''
42 | for i in range(len(sub_parts)):
43 | if sub_parts[i] in [
44 | 'block', 'blocks', 'attentions'
45 | ] or sub_parts[i].isnumeric() or 'attn' in sub_parts[i]:
46 | if 'attn' in sub_parts[i]:
47 | new_sub_parts += sub_parts[i] + '.processor.'
48 | else:
49 | new_sub_parts += sub_parts[i] + '.'
50 | else:
51 | new_sub_parts += sub_parts[i] + '_'
52 |
53 | # down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.up
54 | new_sub_parts += parts[1]
55 |
56 | new_name = new_sub_parts + '.weight'
57 |
58 | return new_name
59 |
60 |
61 | def convert_lora_safetensor_to_bin(safetensor_path: str,
62 | bin_path: str) -> None:
63 | """
64 | Convert LoRA safetensor file to binary format and save it. (only the attn parameters will be saved)
65 |
66 | Args:
67 | safetensor_path (str): Path to the safetensor file.
68 | bin_path (str): Path to save the binary file.
69 | """
70 |
71 | bin_state_dict = {}
72 | safetensors_state_dict = load_file(safetensor_path)
73 |
74 | for key_safetensors in safetensors_state_dict:
75 | # these if are required by current diffusers' API
76 | # remove these may have negative effect as not all LoRAs are used
77 | if 'text' in key_safetensors:
78 | continue
79 | if 'unet' not in key_safetensors:
80 | continue
81 | if 'transformer_blocks' not in key_safetensors:
82 | continue
83 | if 'ff_net' in key_safetensors or 'alpha' in key_safetensors:
84 | continue
85 | key_bin = convert_name_to_bin(key_safetensors)
86 | bin_state_dict[key_bin] = safetensors_state_dict[key_safetensors]
87 |
88 | torch.save(bin_state_dict, bin_path)
89 |
90 |
91 | def convert_base_model_to_diffuser(checkpoint_path: str,
92 | target_path: str,
93 | from_safetensors: bool = False,
94 | save_half: bool = False,
95 | controlnet: str = None,
96 | to_safetensors: bool = False) -> None:
97 | """
98 | Convert base model to diffuser format and save it.
99 |
100 | Args:
101 | checkpoint_path (str): Path to the checkpoint file.
102 | target_path (str): Path to save the diffuser model.
103 | from_safetensors (bool, optional): Flag indicating whether to load from safetensors.
104 | save_half (bool, optional): Flag indicating whether to save the model in half precision.
105 | controlnet (str, optional): Controlnet model path.
106 | to_safetensors (bool, optional): Flag indicating whether to serialize in safetensors format.
107 | """
108 |
109 | pipe = download_from_original_stable_diffusion_ckpt(
110 | checkpoint_path=checkpoint_path,
111 | from_safetensors=from_safetensors,
112 | controlnet=controlnet)
113 |
114 | if save_half:
115 | pipe.to(torch_dtype=torch.float16)
116 |
117 | if controlnet:
118 | # only save the controlnet model
119 | pipe.controlnet.save_pretrained(target_path,
120 | safe_serialization=to_safetensors)
121 | else:
122 | pipe.save_pretrained(target_path, safe_serialization=to_safetensors)
123 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Service for diffusers api
2 |
3 | ## Project overview
4 |
5 | This project provides a cloud service implementation of the [diffusers API](https://github.com/huggingface/diffusers) on PAI-EAS.
6 |
7 | 
8 |
9 |
10 | After optimization with PAI-Blade, model inference performance improves significantly:
11 | (measured on an A10 GPU)
12 | | **image size** | **sample steps** | **PyTorch time (s)** | **PAI-Blade time (s)** | **speedup** | **PyTorch memory (GB)** | **PAI-Blade memory (GB)** |
13 | | -------------- | ---------------- | -------------------- | ---------------------- | ----------- | ----------------------- | ------------------------- |
14 | | 1024x1024 | 50 | OOM | 14.36 | - | OOM | 5.98 |
15 | | 768x768 | 50 | 14.88 | 6.45 | 2.31X | 15.35 | 5.77 |
16 | | 512x512 | 50 | 6.14 | 2.68 | 2.29X | 6.98 | 5.44 |
17 |
18 |
19 |
20 | ### Feature overview
21 |
22 | With this project as a reference, you can:
23 |
24 | - Deploy a service with the already-supported features on PAI-EAS, directly using the image provided by this project. The main supported features are:
25 |
26 |   (different service types are deployed by specifying --func_name at deployment time)
27 |
28 |   - base service (selected via the func_name parameter of the request)
29 |     - text-to-image t2i / image-to-image i2i / image inpaint / image outpaint
30 |   - controlnet service
31 |     - end-to-end image editing based on ControlNet
32 |     - online switching of ControlNet models
33 |   - adding/changing LoRA models
34 |     - use LoRA models by specifying lora_path and lora_attn in the request
35 |
36 | - Use the project source code as a reference for rapid further development and deploy any diffusers-API-based service.
37 |
38 |
39 |
40 | ### Features
41 |
42 | - Inference optimization based on PAI-Blade
43 |
44 |   - reduces the end-to-end latency of the Text2Img and Img2Img pipelines by 2.3x and significantly reduces GPU memory usage, outperforming industry SOTA optimizations such as TensorRT-v8.5
45 |
46 | - Extends and stays compatible with the diffusers API and the Web UI, so that community-downloadable models can be used
47 |
48 | - Supports fusing multiple LoRA models, LoRA models trained with sd-script (models downloaded from third-party sites such as civitai), and online switching of LoRA models
49 |
50 | - Provides asynchronous inference and elastic scheduling based on PAI-EAS
51 |
52 | - Built-in translation model supporting both Chinese and English prompt input
53 |
54 | - Simple API implementation that is easy to extend
55 |
56 |
57 | ### WebUI
58 | ⚠️ This project provides a custom practice for deploying SD model services on EAS. It is partially compatible with SD-WebUI and, compared with the SD-WebUI API, is better suited for further development.
59 |
60 | You can also follow the [WebUI user guide](https://alidocs.dingtalk.com/i/nodes/R1zknDm0WR6XzZ4Lt1aQewElWBQEx5rG) to deploy SD-WebUI-based front-end/back-end services on PAI-EAS.
61 |
62 | The PAI-SDWEBUI solution provides:
63 |
64 | - Quick, out-of-the-box deployment
65 |   - With an Alibaba Cloud account, a single-machine SD WebUI deployment takes about 5 minutes and gives you a WebUI URL, with exactly the same experience as on a PC
66 |   - Common plugins preinstalled, accelerated downloads of popular open-source models
67 |   - Dynamic switching of the underlying resources: swap among GU30, A10, A100 and other GPUs on demand
68 | - Enterprise-grade features
69 |   - Front-end/back-end separation, supporting cluster scheduling of multiple users over multiple GPUs (e.g. 10 users sharing 3 GPUs)
70 |   - User identification via Alibaba Cloud sub-accounts, isolating each user's models and output images while sharing the WebUI URL and GPU compute
71 |   - Billing can be split by studio, cluster and other dimensions
72 | - Plugins and optimization
73 |   - Integrated PAI-Blade performance optimization: once enabled, image generation is 2-3x faster than the native WebUI and 20%-60% faster than WebUI with xformers enabled, with no loss in model quality
74 |   - filebrowser plugin so users can upload and download cloud-side models and images from a PC
75 |   - In-house modelzoo plugin with accelerated open-source model downloads and per-model management of prompts, generated images and parameters
76 |   - For enterprise use, plugins can be managed centrally or customized per user
77 |
78 |
79 |
80 | ## Quick start
81 |
82 | ### Step 1: Set up the environment
83 |
84 | You can use the prebuilt image or build your own with the [Dockerfile](./diffusers/Dockerfile). (Blade inference optimization currently only supports A10/A100 GPUs and the GU-series GPUs available on PAI.)
85 |
86 | ```bash
87 | # For EAS deployment, replace {region} with hangzhou/shanghai/... depending on where the service is deployed; this image is downloaded with built-in acceleration, so the service deploys quickly
88 | eas-registry-vpc.cn-{region}.cr.aliyuncs.com/pai-eas/diffuser-inference:2.2.1-py38-cu113-unbuntu2004-blade-public
89 |
90 | # Public image
91 | registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:2.2.1-py38-cu113-unbuntu2004-blade-public
92 | ```
93 |
94 | ### Step 2: Develop and deploy the service
95 |
96 | - Custom processor development
97 |
98 |   - You can modify the prediction service's main file [app.py](./diffusers/app.py) to develop a custom processor. [Here](./doc/app.md) we briefly walk through the service flow to make further development easier.
99 |
100 |   - For more on custom processor development for PAI-EAS, see the [official documentation](https://help.aliyun.com/document_detail/143418.html?spm=a2c4g.130248.0.0.3c316f27SLZN0o).
101 |
102 |
103 | - Service deployment
104 |
105 |
106 |   - Taking the provided app.py as an example, follow the [deployment guide](./doc/deploy.md) to deploy the service on PAI-EAS. After deployment finishes, you can call the service.
107 |
108 |
109 |   - For a local deployment test you can run the command below directly (see the [deployment guide](./doc/deploy.md) for more command-line parameters)
110 |
111 | ```bash
112 | # use --local_debug to enable local testing
113 | # for local testing, --model_dir sets the model location and --save_dir sets the output location (for EAS deployment these come from the JSON mounts, so no extra input is needed)
114 | # for local testing, --oss_save_dir and --region are only used to build the saved OSS links and can be ignored (any value works)
115 | # --func_name base/controlnet
116 | # for local debugging, add `export BLADE_AUTH_USE_COUNTING=1`; a limited number of trial runs is provided, and PAI-Blade inference optimization can be used without limits on PAI
117 | python /home/pai/app.py --func_name base --oss_save_dir oss://xxx --region hangzhou --model_dir=your_path_to_model --save_dir=your_path_to_output --use_blade --local_debug
118 | ```
119 |
120 |
121 |
122 | ### Step 3: Call the service
123 |
124 | **API definition**: [param.md](./doc/param.md)
125 |
126 | **Synchronous calls:**
127 |
128 | Calling the base service: [sync_example_base.py](./diffusers/example/sync_example_base.py)
129 |
130 | Calling the controlnet service: [sync_example_control.py](./diffusers/example/sync_example_control.py)
131 |
132 | **Asynchronous calls:**
133 |
134 | For AIGC workloads, inference usually takes a long time, so synchronous responses can easily time out.
135 |
136 | You can deploy an asynchronous EAS service and, when calling it, send requests and subscribe to results separately.
137 |
138 | - When deploying the service, set
139 |
140 | ```json
141 | "metadata": {
142 | "type": "Async"
143 | },
144 | ```
145 |
146 | Taking the Python SDK as an example, you can refer to the scripts below; for more EAS asynchronous inference SDKs, see the [reference documentation](https://help.aliyun.com/document_detail/446942.html)
147 |
148 | Sending requests: [async_example_post.py](./diffusers/example/async_example_post.py)
149 |
150 | Subscribing to results: [async_example_get.py](./diffusers/example/async_example_get.py)
151 |
152 | ⚠️ Asynchronous inference does not currently support returning images as base64.
153 |
--------------------------------------------------------------------------------
/diffusers/example/sync_example_base.py:
--------------------------------------------------------------------------------
1 | """
2 | post example when deploy the service name as base
3 | set --use_translate and upload the translation model to post Chinese prompts
4 | translate model: https://www.modelscope.cn/models/damo/nlp_csanmt_translation_zh2en/summary
5 | """
6 |
7 | import base64
8 | import json
9 | import os
10 | import sys
11 | from io import BytesIO
12 |
13 | import requests
14 | from PIL import Image, PngImagePlugin
15 |
16 | ENCODING = 'utf-8'
17 |
18 | hosts = 'http://xxx.cn-hangzhou.pai-eas.aliyuncs.com/api/predict/service_name'
19 | head = {
20 | 'Authorization': 'xxx'
21 | }
22 |
23 | func_list = ['t2i', 'i2i', 'inpaint', 'outpaint']
24 |
25 | def decode_base64(image_base64, save_file):
26 | img = Image.open(BytesIO(base64.urlsafe_b64decode(image_base64)))
27 | img.save(save_file)
28 |
29 |
30 | def select_data(func_name):
31 | if func_name == 't2i':
32 | datas = json.dumps({
33 | 'task_id':
34 | func_name,
35 | 'prompt':
36 | '一只可爱的小猫',
37 | 'func_name':
38 | func_name, # or default is t2i
39 | 'negative_prompt':
40 | 'NSFW',
41 | 'steps':
42 | 50,
43 | 'image_num':
44 | 1,
45 | 'width':
46 | 512,
47 | 'height':
48 | 512,
49 | 'lora_path':
50 | ['lora/animeLineartMangaLike_v30MangaLike.safetensors'],
51 | 'lora_attn':
52 | 1.0
53 | })
54 | elif func_name == 'i2i':
55 | datas = json.dumps({
56 | 'image_link':
57 | 'https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/image.png',
58 | # 'image_base64': base64.b64encode(open('/mnt/xinyi.zxy/diffuser/models/bosi2/result/a001_20230602_075912_12.png', 'rb').read()).decode(ENCODING),
59 | 'task_id':
60 | func_name,
61 | 'prompt':
62 | 'a cat',
63 | 'func_name':
64 | func_name,
65 | 'negative_prompt':
66 | 'NSFW',
67 | 'steps':
68 | 50,
69 | 'image_num':
70 | 1,
71 | 'width':
72 | 512,
73 | 'height':
74 | 512,
75 | 'lora_path':
76 | ['lora/animeLineartMangaLike_v30MangaLike.safetensors']
77 | })
78 | elif func_name == 'inpaint':
79 | datas = json.dumps({
80 | 'image_link':
81 | 'https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/image.png',
82 | 'mask_link':
83 | 'https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/mask.png',
84 | 'task_id':
85 | func_name,
86 | 'prompt':
87 | 'a cat',
88 | 'func_name':
89 | func_name,
90 | 'negative_prompt':
91 | 'NSFW',
92 | 'steps':
93 | 50,
94 | 'image_num':
95 | 1,
96 | 'width':
97 | 512,
98 | 'height':
99 | 512,
100 | 'lora_path':
101 | ['lora/animeLineartMangaLike_v30MangaLike.safetensors']
102 | })
103 | elif func_name == 'outpaint':
104 | datas = json.dumps({
105 | 'image_link':
106 | 'https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/image.png',
107 | # 'image_base64': base64.b64encode(open('/mnt/xinyi.zxy/diffuser/models/bosi2/result/a001_20230602_075912_12.png', 'rb').read()).decode(ENCODING),
108 | 'task_id': func_name,
109 | 'prompt': 'a cat',
110 | 'func_name': func_name,
111 | 'negative_prompt': 'NSFW',
112 | 'steps': 50,
113 | 'image_num': 1,
114 | 'width': 512,
115 | 'height': 512,
116 | 'expand': [256, 256, 0, 0], # [left,right,up,down]
117 | 'expand_type': 'copy', # or 'reflect',
118 | 'denoising_strength': 0.6
119 | # 'lora_path': ['path_to_your_lora_model']
120 | })
121 | else:
122 | raise ValueError('Invalid process_func value')
123 |
124 | return datas
125 |
126 |
127 | for func_name in func_list:
128 | datas = select_data(func_name)
129 |
130 | r = requests.post(hosts, data=datas, headers=head)
131 | # r = requests.post("http://0.0.0.0:8000/test", data=datas, timeout=1500)
132 |
133 | data = json.loads(r.content.decode('utf-8'))
134 | print(data.keys())
135 |
136 | if data['success']:
137 | print(data['image_url'])
138 | print(data['oss_url'])
139 | print(data['task_id'])
140 | print(data['use_blade'])
141 | print(data['seed'])
142 | print(data['is_nsfw'])
143 | if 'images_base64' in data.keys():
144 | for i, image_base64 in enumerate(data['images_base64']):
145 | decode_base64(image_base64, './result_{}.png'.format(str(i)))
146 |
147 | else:
148 | print(data['error_msg'])
149 |
--------------------------------------------------------------------------------
/diffusers/utils/io.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Dict, Optional, Tuple, Union
4 |
5 | import requests
6 | from PIL import Image
7 |
8 | import torch
9 | from diffusers import (ControlNetModel, DiffusionPipeline,
10 | DPMSolverMultistepScheduler,
11 | StableDiffusionControlNetPipeline,
12 | StableDiffusionPipeline)
13 |
14 |
15 | def load_diffusers_pipeline(
16 | model_base: str,
17 | lora_path: Optional[str],
18 | controlnet_path: Optional[str],
19 | device: str,
20 | mode: str = 'base',
21 | close_safety: bool = False,
22 | custom_pipeline: str = '/home/pai/lpw_stable_diffusion.py'
23 | ) -> Union[DiffusionPipeline, StableDiffusionControlNetPipeline]:
24 | """
25 | Loads a DiffusionPipeline or StableDiffusionControlNetPipeline with a LoRA checkpoint,
26 | based on the specified mode of operation.
27 |
28 | Args:
29 | model_base (str): The path to the base model checkpoint
30 | lora_path (str, optional): The path to the LoRA checkpoint
31 | controlnet_path (str, optional): The path to the controlnet checkpoint (if mode='controlnet')
32 | device (str): The device where the pipeline will run (e.g. 'cpu' or 'cuda')
33 | mode (str): The mode of operation ('base', or 'controlnet')
34 | close_safety (bool): Whether to disable safety checks in the pipeline
35 | custom_pipeline (str): The path to a custom pipeline script (if any)
36 |
37 | Returns:
38 | Union[DiffusionPipeline, StableDiffusionControlNetPipeline]:
39 | A DiffusionPipeline (LPW) or StableDiffusionControlNetPipeline object with a LoRA checkpoint loaded.
40 | """
41 | if mode == 'base':
42 | if close_safety:
43 | pipe = DiffusionPipeline.from_pretrained(
44 | model_base,
45 | custom_pipeline=custom_pipeline,
46 | torch_dtype=torch.float16,
47 | safety_checker=None)
48 | else:
49 | pipe = DiffusionPipeline.from_pretrained(
50 | model_base,
51 | custom_pipeline=custom_pipeline,
52 | torch_dtype=torch.float16,
53 | )
54 |
55 | elif mode == 'controlnet':
56 | controlnet = ControlNetModel.from_pretrained(controlnet_path,
57 | torch_dtype=torch.float16)
58 |
59 | if close_safety:
60 | pipe = StableDiffusionControlNetPipeline.from_pretrained(
61 | model_base,
62 | controlnet=controlnet,
63 | revision='fp16',
64 | torch_dtype=torch.float16,
65 | safety_checker=None)
66 | else:
67 | pipe = StableDiffusionControlNetPipeline.from_pretrained(
68 | model_base,
69 | controlnet=controlnet,
70 | revision='fp16',
71 | torch_dtype=torch.float16)
72 | else:
73 | raise ValueError(
74 | 'Unrecognized function name: {}. We support base(t2i)/controlnet'.
75 | format(mode))
76 |
77 | pipe.to(device)
78 |
79 | if lora_path is not None:
80 | pipe.unet.load_attn_procs(lora_path, use_safetensors=False)
81 |
82 | return pipe
83 |
84 |
85 | def download_image(image_link: str) -> Image.Image:
86 | """
87 | Download an image from the given image_link and return it as a PIL Image object.
88 |
89 | Args:
90 | image_link (str): The URL of the image to download.
91 |
92 | Returns:
93 | Image.Image: The downloaded image as a PIL Image object.
94 | """
95 | response = requests.get(image_link)
96 | image_name = image_link.split('/')[-1]
97 |     with open(image_name, 'wb') as f:
98 | f.write(response.content)
99 | f.flush()
100 | img = Image.open(image_name).convert('RGB')
101 |
102 | return img
103 |
104 |
105 | def get_result_str(result_dict: Optional[Dict[str, Union[int, str]]] = None,
106 | error: Optional[Exception] = None) -> Tuple[bytes, int]:
107 | """
108 | Generates a result string in JSON format based on the provided result dictionary and error.
109 |
110 | Args:
111 | result_dict (Optional[Dict[str, Union[int, str]]]): A dictionary containing the result information.
112 | error (Optional[Exception]): An error object representing any occurred error.
113 |
114 | Returns:
115 | Tuple[bytes, int]: A tuple containing the result string encoded in UTF-8 and the HTTP status code.
116 |
117 | """
118 | result = {}
119 |
120 | if error is not None:
121 | result['success'] = 0
122 | result['error_code'] = error.code
123 | result['error_msg'] = error.msg[:200]
124 | stat = error.code
125 |
126 | if result_dict is not None and 'task_id' in result_dict.keys():
127 | result['task_id'] = result_dict['task_id']
128 |
129 | elif result_dict is not None:
130 | result['success'] = 1
131 | result.update(result_dict)
132 | stat = 200
133 |
134 | result_str = json.dumps(result).encode('utf-8')
135 |
136 | return result_str, stat
137 |
--------------------------------------------------------------------------------
/diffusers/tests/test_utils/test_blade.py:
--------------------------------------------------------------------------------
1 | import os
2 | import unittest
3 |
4 | from PIL import Image
5 |
6 | import torch
7 | import torch_blade
8 | from tests.ut_config import (BASE_MODEL_PATH, CONTROLNET_MODEL_PATH,
9 | CUSTOM_PIPELINE, IMAGE_DIR, MODEL_DIR,
10 | PRETRAIN_DIR, SAVE_DIR)
11 | from utils.blade import load_blade_model, optimize_and_save_blade_model
12 | from utils.image_process import preprocess_control
13 | from utils.io import load_diffusers_pipeline
14 |
15 | # important! otherwise the Blade result will be incorrect
16 | os.environ['DISC_ENABLE_DOT_MERGE'] = '0'
17 |
18 |
19 | class TestBladeOptimization(unittest.TestCase):
20 | def setUp(self):
21 | # hyper parameters
22 | self.prompt = 'a dog'
23 | img_path = os.path.join(IMAGE_DIR, 'image.png')
24 | mask_path = os.path.join(IMAGE_DIR, 'mask.png')
25 | self.image = Image.open(img_path).convert('RGB')
26 | self.mask = Image.open(mask_path).convert('RGB')
27 | self.num_inference_steps = 20
28 | self.num_images_per_prompt = 1
29 | self.device = torch.device(
30 | 'cuda') if torch.cuda.is_available() else torch.device('cpu')
31 |
32 | def test_optimize_and_save_blade_model_base(self):
33 | # save and optimize base model
34 | blade_dir = os.path.join(MODEL_DIR, 'optimized_model')
35 | os.makedirs(blade_dir, exist_ok=True)
36 | encoder_path = os.path.join(blade_dir, 'encoder.pt')
37 | unet_path = os.path.join(blade_dir, 'unet.pt')
38 | decoder_path = os.path.join(blade_dir, 'decoder.pt')
39 | controlnet_path = None
40 | mode = 'base'
41 | close_safety = False
42 | pipe = load_diffusers_pipeline(BASE_MODEL_PATH, None, None,
43 | self.device, mode, close_safety,
44 | CUSTOM_PIPELINE)
45 | optimize_and_save_blade_model(pipe, encoder_path, unet_path,
46 | decoder_path, controlnet_path)
47 |
48 | # save and optimize base model
49 | assert os.path.exists(
50 | encoder_path), f"Encoder path '{encoder_path}' does not exist!"
51 | assert os.path.exists(
52 | unet_path), f"UNet path '{unet_path}' does not exist!"
53 | assert os.path.exists(
54 | decoder_path), f"Decoder path '{decoder_path}' does not exist!"
55 | # load
56 | pipe = load_blade_model(pipe, encoder_path, unet_path, decoder_path,
57 | controlnet_path)
58 | with torch.no_grad():
59 | res = pipe.text2img(
60 | prompt=self.prompt,
61 | num_inference_steps=self.num_inference_steps,
62 | num_images_per_prompt=self.num_images_per_prompt)
63 | image = res.images[0]
64 | self.assertIsInstance(image, Image.Image)
65 | image.save(os.path.join(SAVE_DIR, 't2i_blade.jpg'))
66 |
67 | def test_optimize_and_save_blade_model_controlnet(self):
68 | # save and optimize base model
69 | blade_dir = os.path.join(MODEL_DIR, 'optimized_control_model')
70 | os.makedirs(blade_dir, exist_ok=True)
71 | encoder_path = os.path.join(blade_dir, 'encoder.pt')
72 | unet_path = os.path.join(blade_dir, 'unet.pt')
73 | decoder_path = os.path.join(blade_dir, 'decoder.pt')
74 | controlnet_path = os.path.join(blade_dir, 'controlnet.pt')
75 | mode = 'controlnet'
76 | close_safety = False
77 | pipe = load_diffusers_pipeline(BASE_MODEL_PATH, None,
78 | CONTROLNET_MODEL_PATH, self.device,
79 | mode, close_safety, CUSTOM_PIPELINE)
80 |
81 | optimize_and_save_blade_model(pipe, encoder_path, unet_path,
82 | decoder_path, controlnet_path)
83 |
84 | # save and optimize base model
85 | assert os.path.exists(
86 | encoder_path), f"Encoder path '{encoder_path}' does not exist!"
87 | assert os.path.exists(
88 | unet_path), f"UNet path '{unet_path}' does not exist!"
89 | assert os.path.exists(
90 | decoder_path), f"Decoder path '{decoder_path}' does not exist!"
91 | assert os.path.exists(
92 | controlnet_path
93 | ), f"ControlNet path '{controlnet_path}' does not exist!"
94 | # load
95 | pipe = load_blade_model(pipe, encoder_path, unet_path, decoder_path,
96 | controlnet_path)
97 | with torch.no_grad():
98 | process_image = preprocess_control(self.image, 'canny',
99 | PRETRAIN_DIR)
100 | res = pipe(prompt=self.prompt,
101 | image=process_image,
102 | num_inference_steps=self.num_inference_steps,
103 | num_images_per_prompt=self.num_images_per_prompt)
104 | image = res.images[0]
105 | self.assertIsInstance(image, Image.Image)
106 | image.save(os.path.join(SAVE_DIR, 'control_blade.jpg'))
107 |
108 |
109 | if __name__ == '__main__':
110 | unittest.main()
111 |
--------------------------------------------------------------------------------
/doc/param.md:
--------------------------------------------------------------------------------
1 | ## Service input and output parameters
2 |
3 | - POST input parameters
4 |
5 | | **Parameter** | **Description** | **Type** | **Default / Required** |
6 | | ------------------ | ------------------------------------------------------------ | ------------------ | ---------------------------------- |
7 | | task_id | task ID | string | required |
8 | | prompt | positive prompt provided by the user | string | required |
9 | | func_name | requested function; when the deployed service is base, t2i/i2i/inpaint can be passed to switch the functionality | string | t2i |
10 | | steps | number of sampling steps | int | 50 |
11 | | cfg_scale | guidance_scale | int | 7 |
12 | | denoising_strength | blend ratio with the original image [only effective for img2img] | float | 0.55 |
13 | | width | width of the generated image | int | 512 |
14 | | height | height of the generated image | int | 512 |
15 | | negative_prompt | negative prompt provided by the user | string | "" |
16 | | image_num | number of images to generate | int | 1 |
17 | | resize_mode | how the generated image is resized: 0 stretch, 1 crop, 2 pad | int | 0 |
18 | | image_link | URL of the input image | string | required for img2img, inpaint and controlnet |
19 | | mask_link | URL of the input mask | string | required for inpaint |
20 | | image_base64 | input image in base64 format | base64 | either this or image_link |
21 | | mask_base64 | input mask in base64 format | base64 | either this or mask_link |
22 | | use_base64 | whether to also return the generated images as base64 | bool | False |
23 | | lora_attn | LoRA weight; a list is supported when fusing multiple LoRAs | float / List[float] | 0.75 |
24 | | lora_path | path of the LoRA model(s) to load, relative to the mounted OSS path; a list is supported when fusing multiple LoRAs | string / List[string] | none |
25 | | controlnet_path | path of the ControlNet model to switch to, relative to the mounted OSS path (a safetensors/bin file downloadable from Hugging Face) | string | none |
26 | | process_func | image preprocessing method used to generate the ControlNet control image | string | see the table below for supported values |
27 | | expand | number of pixels to pad in each direction for outpaint: [left, right, up, down] | list | none |
28 | | expand_type | how the original image is expanded (affects outpaint results): copy (replicate the border) or reflect (mirror the border) | string | copy |
29 | | save_dir | folder name; output files are stored in this folder under the mounted result path | string | result |
30 |
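For illustration, a hypothetical t2i request body using the parameters above might look like the following (all values are placeholders; see [sync_example_base.py](../diffusers/example/sync_example_base.py) for complete examples):

```json
{
  "task_id": "demo",
  "prompt": "a cat",
  "func_name": "t2i",
  "negative_prompt": "NSFW",
  "steps": 50,
  "cfg_scale": 7,
  "width": 512,
  "height": 512,
  "image_num": 1,
  "use_base64": false,
  "lora_path": ["lora/your_lora_model.safetensors"],
  "lora_attn": 0.75
}
```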
31 | - Supported ControlNet preprocessors (only the 8 types in the table below are supported end to end; for other ControlNets you can preprocess the control image yourself and use it to guide generation). An example ControlNet request is sketched after this table.
32 |
33 | | process_func | Function | Reference ControlNet download |
34 | | ------------ | -------- | ------------------------------------------------------------ |
35 | | canny | edge detection | [edge detection](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-canny/diffusion_pytorch_model.safetensors) |
36 | | depth | depth estimation | [depth estimation](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-depth/diffusion_pytorch_model.safetensors) |
37 | | hed | line-art colorization (HED soft edges) | [hed](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-hed/diffusion_pytorch_model.safetensors) |
38 | | mlsd | line segment detection | [line segment detection](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-mlsd/diffusion_pytorch_model.safetensors) |
39 | | normal | normal map estimation | [normal map](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--fusing--stable-diffusion-v1-5-controlnet-normal/diffusion_pytorch_model.safetensors) |
40 | | openpose | pose estimation | [pose estimation](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-openpose/diffusion_pytorch_model.safetensors) |
41 | | scribble | scribble colorization | [scribble colorization](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-scribble/diffusion_pytorch_model.safetensors) |
42 | | seg | semantic segmentation | [semantic segmentation](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-seg/diffusion_pytorch_model.safetensors) |
43 |
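A controlnet request sketch using the fields above; the image_link and controlnet_path values are illustrative and must point to resources reachable from your deployment (see diffusers/example/sync_example_control.py for full examples):

```python
payload = json.dumps({
    'task_id': 'canny-demo',
    'prompt': 'oil painting of a man, masterpiece',
    'steps': 50,
    'image_link': 'https://huggingface.co/lllyasviel/sd-controlnet-hed/resolve/main/images/man.png',
    'process_func': 'canny',
    # optional: switch to controlnet weights mounted under the OSS path (illustrative relative path)
    'controlnet_path': 'new_controlnet/models--lllyasviel--sd-controlnet-canny/diffusion_pytorch_model.safetensors',
})
resp = requests.post(host, data=payload, headers=headers)  # host / headers as in the example above
```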
44 | - POST output parameters (a sketch of reading the response follows the table)
45 | 
46 | | **Parameter** | **Description** | **Type** |
47 | | ------------- | ------------------------------------------------------------ | ---------- |
48 | | image_url | Publicly accessible links to the generated images (valid once the OSS ACL is opened) | list |
49 | | images_base64 | Generated images in base64 format (returned when use_base64 is enabled) | list |
50 | | oss_url | OSS addresses of the generated images | list |
51 | | success | Whether the request succeeded: 0 - failure, 1 - success | int |
52 | | seed | Seed used to generate the images | string |
53 | | task_id | Task ID | string |
54 | | error_msg | Error message (only returned when success=0) | string |
55 | | use_blade | Whether Blade-optimized inference was used (once the Blade models are successfully optimized, they are used by default from the first inference) | bool |
56 | | is_nsfw | Whether each generated image was flagged as NSFW (True means the image is blacked out) | list[bool] |
57 |
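A sketch of reading the response fields listed above; the exact URLs depend on the deployment:

```python
result = json.loads(resp.content.decode('utf-8'))  # resp from the request sketches above
if result['success']:
    print(result['image_url'])   # public links (require an open ACL)
    print(result['oss_url'])     # OSS addresses of the generated images
    print(result['seed'], result['use_blade'], result['is_nsfw'])
else:
    print(result['error_msg'])
```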
--------------------------------------------------------------------------------
/diffusers/example/sync_example_control.py:
--------------------------------------------------------------------------------
1 | """
2 | POST example for a service deployed as controlnet
3 | """
4 |
5 | import base64
6 | import json
7 | import os
8 | import sys
9 | from io import BytesIO
10 |
11 | import requests
12 | from PIL import Image, PngImagePlugin
13 |
14 | ENCODING = 'utf-8'
15 |
16 | hosts = 'http://xxx.cn-hangzhou.pai-eas.aliyuncs.com/api/predict/service_name'
17 | head = {
18 | 'Authorization': 'xxx'
19 | }
20 |
21 |
22 | def decode_base64(image_base64, save_file):
23 | img = Image.open(BytesIO(base64.urlsafe_b64decode(image_base64)))
24 | img.save(save_file)
25 |
26 |
27 | def select_data(process_func):
28 | if process_func == 'canny':
29 | datas = json.dumps({
30 | 'task_id': 'canny',
31 | 'steps': 50,
32 | 'image_num': 1,
33 | 'width': 512,
34 | 'height': 512,
35 | 'image_link':
36 | 'https://huggingface.co/lllyasviel/sd-controlnet-hed/resolve/main/images/man.png',
37 | 'prompt': 'man',
38 | 'process_func': 'canny',
39 | })
40 | elif process_func == 'depth':
41 | datas = json.dumps({
42 | 'task_id': 'depth',
43 | 'steps': 50,
44 | 'image_num': 1,
45 | 'width': 512,
46 | 'height': 512,
47 | 'image_link':
48 | 'https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png',
49 | 'prompt': "Stormtrooper's lecture",
50 | 'controlnet_path':
51 | 'new_controlnet/models--lllyasviel--sd-controlnet-depth/diffusion_pytorch_model.safetensors', # use to change the controlnet path
52 | 'process_func': 'depth',
53 | })
54 | elif process_func == 'hed':
55 | datas = json.dumps({
56 | 'task_id': 'hed',
57 | 'steps': 50,
58 | 'image_num': 1,
59 | 'width': 512,
60 | 'height': 512,
61 | 'image_link':
62 | 'https://huggingface.co/lllyasviel/sd-controlnet-hed/resolve/main/images/man.png',
63 | 'prompt': 'oil painting of handsome old man, masterpiece',
64 | 'controlnet_path':
65 | 'new_controlnet/models--lllyasviel--sd-controlnet-hed/diffusion_pytorch_model.safetensors',
66 | 'process_func': 'hed',
67 | })
68 | elif process_func == 'mlsd':
69 | datas = json.dumps({
70 | 'task_id': 'mlsd',
71 | 'steps': 50,
72 | 'image_num': 1,
73 | 'width': 512,
74 | 'height': 512,
75 | 'image_link':
76 | 'https://huggingface.co/lllyasviel/sd-controlnet-mlsd/resolve/main/images/room.png',
77 | 'prompt': 'room',
78 | 'controlnet_path':
79 | 'new_controlnet/models--lllyasviel--sd-controlnet-mlsd/diffusion_pytorch_model.safetensors',
80 | 'process_func': 'mlsd',
81 | })
82 | elif process_func == 'normal':
83 | datas = json.dumps({
84 | 'task_id': 'normal',
85 | 'steps': 50,
86 | 'image_num': 1,
87 | 'width': 512,
88 | 'height': 512,
89 | 'image_link':
90 | 'https://huggingface.co/lllyasviel/sd-controlnet-normal/resolve/main/images/toy.png',
91 | 'prompt': 'cute toy',
92 | 'controlnet_path':
93 | 'new_controlnet/models--fusing--stable-diffusion-v1-5-controlnet-normal/diffusion_pytorch_model.safetensors',
94 | 'process_func': 'normal',
95 | })
96 | elif process_func == 'openpose':
97 | datas = json.dumps({
98 | 'task_id': 'openpose',
99 | 'steps': 50,
100 | 'image_num': 1,
101 | 'width': 512,
102 | 'height': 512,
103 | 'image_link':
104 | 'https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png',
105 | 'prompt': 'chef in the kitchen',
106 | 'controlnet_path':
107 | 'new_controlnet/models--lllyasviel--sd-controlnet-openpose/diffusion_pytorch_model.safetensors',
108 | 'process_func': 'openpose',
109 | })
110 | elif process_func == 'scribble':
111 |
112 | datas = json.dumps({
113 | 'task_id': 'scribble',
114 | 'steps': 50,
115 | 'image_num': 1,
116 | 'width': 512,
117 | 'height': 512,
118 | 'image_link':
119 | 'https://huggingface.co/lllyasviel/sd-controlnet-scribble/resolve/main/images/bag.png',
120 | 'prompt': 'bag',
121 | 'controlnet_path':
122 | 'new_controlnet/models--lllyasviel--sd-controlnet-scribble/diffusion_pytorch_model.safetensors',
123 | 'process_func': 'scribble',
124 | })
125 |
126 | elif process_func == 'seg':
127 | datas = json.dumps({
128 | 'task_id': 'seg',
129 | 'steps': 50,
130 | 'image_num': 1,
131 | 'width': 512,
132 | 'height': 512,
133 | 'image_link':
134 | 'https://huggingface.co/lllyasviel/sd-controlnet-seg/resolve/main/images/house.png',
135 | 'prompt': 'house',
136 | 'controlnet_path':
137 | 'new_controlnet/models--lllyasviel--sd-controlnet-seg/diffusion_pytorch_model.safetensors',
138 | 'process_func': 'seg',
139 | })
140 | else:
141 | raise ValueError('Invalid process_func value')
142 |
143 | return datas
144 |
145 |
146 | process_func_list = [
147 | 'canny', 'depth', 'hed', 'mlsd', 'normal', 'openpose', 'scribble', 'seg'
148 | ]
149 |
150 | for process_func in process_func_list:
151 | datas = select_data(process_func)
152 |
153 | r = requests.post(hosts, data=datas, headers=head)
154 | # r = requests.post("http://0.0.0.0:8000/test", data=datas, timeout=1500)
155 |
156 | data = json.loads(r.content.decode('utf-8'))
157 | print(data.keys())
158 |
159 | if data['success']:
160 | print(data['image_url'])
161 | print(data['oss_url'])
162 | print(data['task_id'])
163 | print(data['use_blade'])
164 | print(data['seed'])
165 | print(data['is_nsfw'])
166 | if 'images_base64' in data.keys():
167 | for i, image_base64 in enumerate(data['images_base64']):
168 | decode_base64(image_base64,
169 | './decode_ldm_base64_{}.png'.format(str(i)))
170 |
171 | else:
172 | print(data['error_msg'])
173 |
--------------------------------------------------------------------------------
/diffusers/utils/image_process.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Tuple, Union
3 |
4 | import cv2
5 | import numpy as np
6 | from PIL import Image, ImageDraw
7 |
8 | import torch
9 | from controlnet_aux import HEDdetector, MLSDdetector, OpenposeDetector
10 | # this import is required, or controlnet_aux will trigger a malloc error
11 | from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
12 | from transformers import (AutoImageProcessor, UperNetForSemanticSegmentation,
13 | pipeline)
14 |
15 |
16 | def canny(image: Image.Image,
17 | pretrain_dir: str,
18 | low_threshold: int = 100,
19 | high_threshold: int = 200) -> Image.Image:
20 | """
21 | Apply the Canny edge detection algorithm to the image.
22 |
23 | Args:
24 | image (Image.Image): The input image.
25 | low_threshold (int): The lower threshold for edge detection (default: 100).
26 | high_threshold (int): The higher threshold for edge detection (default: 200).
27 |
28 | Returns:
29 | Image.Image: The processed image with detected edges.
30 | """
31 | image = np.array(image)
32 | image = cv2.Canny(image, low_threshold, high_threshold)
33 | image = image[:, :, None]
34 | image = np.concatenate([image, image, image], axis=2)
35 | image = Image.fromarray(image)
36 | return image
37 |
38 |
39 | def depth(image: Image.Image, pretrain_dir: str) -> Image.Image:
40 | """
41 | Estimate the depth map of the image using a pre-trained depth estimation model.
42 |
43 | Args:
44 | image (Image.Image): The input image.
45 | pretrain_dir (str): The directory containing the pre-trained models.
46 |
47 | Returns:
48 | Image.Image: The estimated depth map of the image.
49 | """
50 | depth_estimator = pipeline('depth-estimation',
51 | model=os.path.join(pretrain_dir,
52 | 'models--Intel--dpt-large'))
53 | image = depth_estimator(image)['depth']
54 | image = np.array(image)
55 | image = image[:, :, None]
56 | image = np.concatenate([image, image, image], axis=2)
57 | image = Image.fromarray(image)
58 | return image
59 |
60 |
61 | def hed(image: Image.Image, pretrain_dir: str) -> Image.Image:
62 | """
63 | Apply the Holistically-Nested Edge Detection (HED) algorithm to the image.
64 |
65 | Args:
66 | image (Image.Image): The input image.
67 | pretrain_dir (str): The directory containing the pre-trained models.
68 |
69 | Returns:
70 | Image.Image: The processed image with detected edges.
71 | """
72 | hed = HEDdetector.from_pretrained(
73 | os.path.join(pretrain_dir, 'models--lllyasviel--ControlNet'))
74 | image = hed(image)
75 | return image
76 |
77 |
78 | def mlsd(image: Image.Image, pretrain_dir: str) -> Image.Image:
79 | """
80 | Apply MLSD (Multi-Line Segment Detection) model to the input image.
81 |
82 | Args:
83 | image (Image.Image): The input image.
84 | pretrain_dir (str): The directory path where the pre-trained model is located.
85 |
86 | Returns:
87 | Image.Image: The processed image.
88 |
89 | """
90 | mlsd = MLSDdetector.from_pretrained(
91 | os.path.join(pretrain_dir, 'models--lllyasviel--ControlNet'))
92 | image = mlsd(image)
93 | return image
94 |
95 |
96 | def normal(image: Image.Image,
97 | pretrain_dir: str,
98 | bg_threshold: float = 0.4) -> Image.Image:
99 | """
100 | Perform normal estimation on the input image.
101 |
102 | Args:
103 | image (Image.Image): The input image.
104 | pretrain_dir (str): The directory path where the pre-trained model is located.
105 | bg_threshold (float, optional): Background depth threshold. Default is 0.4.
106 |
107 | Returns:
108 | Image.Image: The image with normal estimation.
109 |
110 | """
111 | depth_estimator = pipeline('depth-estimation',
112 | model=os.path.join(
113 | pretrain_dir,
114 | 'models--Intel--dpt-hybrid-midas'))
115 | image = depth_estimator(image)['predicted_depth'][0]
116 | image = image.numpy()
117 |
118 | image_depth = image.copy()
119 | image_depth -= np.min(image_depth)
120 | image_depth /= np.max(image_depth)
121 |
122 | x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3)
123 | x[image_depth < bg_threshold] = 0
124 |
125 | y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3)
126 | y[image_depth < bg_threshold] = 0
127 |
128 | z = np.ones_like(x) * np.pi * 2.0
129 |
130 | image = np.stack([x, y, z], axis=2)
131 | image /= np.sum(image**2.0, axis=2, keepdims=True)**0.5
132 | image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8)
133 | image = Image.fromarray(image)
134 |
135 | return image
136 |
137 |
138 | def openpose(image: Image.Image, pretrain_dir: str) -> Image.Image:
139 | """
140 | Apply OpenPose model to the input image.
141 |
142 | Args:
143 | image (Image.Image): The input image.
144 | pretrain_dir (str): The directory path where the pre-trained model is located.
145 |
146 | Returns:
147 | Image.Image: The processed image.
148 |
149 | """
150 | openpose = OpenposeDetector.from_pretrained(
151 | os.path.join(pretrain_dir, 'models--lllyasviel--ControlNet'))
152 | image = openpose(image)
153 | return image
154 |
155 |
156 | def scribble(image: Image.Image, pretrain_dir: str) -> Image.Image:
157 | """
158 | Apply scribble-based HED (Holistically-Nested Edge Detection) model to the input image.
159 |
160 | Args:
161 | image (Image.Image): The input image.
162 | pretrain_dir (str): The directory path where the pre-trained model is located.
163 |
164 | Returns:
165 | Image.Image: The processed image.
166 |
167 | """
168 | hed = HEDdetector.from_pretrained(pretrained_model_or_path=os.path.join(
169 | pretrain_dir, 'models--lllyasviel--ControlNet'))
170 | image = hed(image, scribble=True)
171 | return image
172 |
173 |
174 | def seg(image: Image.Image, pretrain_dir: str) -> Image.Image:
175 | """
176 | Apply semantic segmentation to the input image.
177 |
178 | Args:
179 | image (Image.Image): The input image.
180 | pretrain_dir (str): The directory path where the pre-trained models are located.
181 |
182 | Returns:
183 | Image.Image: The processed image.
184 |
185 | """
186 |
187 | palette = np.asarray([
188 | [0, 0, 0],
189 | [120, 120, 120],
190 | [180, 120, 120],
191 | [6, 230, 230],
192 | [80, 50, 50],
193 | [4, 200, 3],
194 | [120, 120, 80],
195 | [140, 140, 140],
196 | [204, 5, 255],
197 | [230, 230, 230],
198 | [4, 250, 7],
199 | [224, 5, 255],
200 | [235, 255, 7],
201 | [150, 5, 61],
202 | [120, 120, 70],
203 | [8, 255, 51],
204 | [255, 6, 82],
205 | [143, 255, 140],
206 | [204, 255, 4],
207 | [255, 51, 7],
208 | [204, 70, 3],
209 | [0, 102, 200],
210 | [61, 230, 250],
211 | [255, 6, 51],
212 | [11, 102, 255],
213 | [255, 7, 71],
214 | [255, 9, 224],
215 | [9, 7, 230],
216 | [220, 220, 220],
217 | [255, 9, 92],
218 | [112, 9, 255],
219 | [8, 255, 214],
220 | [7, 255, 224],
221 | [255, 184, 6],
222 | [10, 255, 71],
223 | [255, 41, 10],
224 | [7, 255, 255],
225 | [224, 255, 8],
226 | [102, 8, 255],
227 | [255, 61, 6],
228 | [255, 194, 7],
229 | [255, 122, 8],
230 | [0, 255, 20],
231 | [255, 8, 41],
232 | [255, 5, 153],
233 | [6, 51, 255],
234 | [235, 12, 255],
235 | [160, 150, 20],
236 | [0, 163, 255],
237 | [140, 140, 140],
238 | [250, 10, 15],
239 | [20, 255, 0],
240 | [31, 255, 0],
241 | [255, 31, 0],
242 | [255, 224, 0],
243 | [153, 255, 0],
244 | [0, 0, 255],
245 | [255, 71, 0],
246 | [0, 235, 255],
247 | [0, 173, 255],
248 | [31, 0, 255],
249 | [11, 200, 200],
250 | [255, 82, 0],
251 | [0, 255, 245],
252 | [0, 61, 255],
253 | [0, 255, 112],
254 | [0, 255, 133],
255 | [255, 0, 0],
256 | [255, 163, 0],
257 | [255, 102, 0],
258 | [194, 255, 0],
259 | [0, 143, 255],
260 | [51, 255, 0],
261 | [0, 82, 255],
262 | [0, 255, 41],
263 | [0, 255, 173],
264 | [10, 0, 255],
265 | [173, 255, 0],
266 | [0, 255, 153],
267 | [255, 92, 0],
268 | [255, 0, 255],
269 | [255, 0, 245],
270 | [255, 0, 102],
271 | [255, 173, 0],
272 | [255, 0, 20],
273 | [255, 184, 184],
274 | [0, 31, 255],
275 | [0, 255, 61],
276 | [0, 71, 255],
277 | [255, 0, 204],
278 | [0, 255, 194],
279 | [0, 255, 82],
280 | [0, 10, 255],
281 | [0, 112, 255],
282 | [51, 0, 255],
283 | [0, 194, 255],
284 | [0, 122, 255],
285 | [0, 255, 163],
286 | [255, 153, 0],
287 | [0, 255, 10],
288 | [255, 112, 0],
289 | [143, 255, 0],
290 | [82, 0, 255],
291 | [163, 255, 0],
292 | [255, 235, 0],
293 | [8, 184, 170],
294 | [133, 0, 255],
295 | [0, 255, 92],
296 | [184, 0, 255],
297 | [255, 0, 31],
298 | [0, 184, 255],
299 | [0, 214, 255],
300 | [255, 0, 112],
301 | [92, 255, 0],
302 | [0, 224, 255],
303 | [112, 224, 255],
304 | [70, 184, 160],
305 | [163, 0, 255],
306 | [153, 0, 255],
307 | [71, 255, 0],
308 | [255, 0, 163],
309 | [255, 204, 0],
310 | [255, 0, 143],
311 | [0, 255, 235],
312 | [133, 255, 0],
313 | [255, 0, 235],
314 | [245, 0, 255],
315 | [255, 0, 122],
316 | [255, 245, 0],
317 | [10, 190, 212],
318 | [214, 255, 0],
319 | [0, 204, 255],
320 | [20, 0, 255],
321 | [255, 255, 0],
322 | [0, 153, 255],
323 | [0, 41, 255],
324 | [0, 255, 204],
325 | [41, 0, 255],
326 | [41, 255, 0],
327 | [173, 0, 255],
328 | [0, 245, 255],
329 | [71, 0, 255],
330 | [122, 0, 255],
331 | [0, 255, 184],
332 | [0, 92, 255],
333 | [184, 255, 0],
334 | [0, 133, 255],
335 | [255, 214, 0],
336 | [25, 194, 194],
337 | [102, 255, 0],
338 | [92, 0, 255],
339 | ])
340 |
341 | image_processor = AutoImageProcessor.from_pretrained(
342 | os.path.join(pretrain_dir,
343 | 'models--openmmlab--upernet-convnext-small'))
344 | image_segmentor = UperNetForSemanticSegmentation.from_pretrained(
345 | os.path.join(pretrain_dir,
346 | 'models--openmmlab--upernet-convnext-small'))
347 |
348 | pixel_values = image_processor(image, return_tensors='pt').pixel_values
349 |
350 | with torch.no_grad():
351 | outputs = image_segmentor(pixel_values)
352 |
353 | seg = image_processor.post_process_semantic_segmentation(
354 | outputs, target_sizes=[image.size[::-1]])[0]
355 |
356 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
357 |
358 | for label, color in enumerate(palette):
359 | color_seg[seg == label, :] = color
360 |
361 | color_seg = color_seg.astype(np.uint8)
362 |
363 | image = Image.fromarray(color_seg)
364 |
365 | return image
366 |
367 |
368 | def preprocess_control(image: Image.Image, process_func: str,
369 | pretrain_dir: str) -> Union[str, Image.Image]:
370 | """
371 | Apply the specified image processing function to the input image for controlnet.
372 |
373 | Args:
374 | image (Image.Image): The input image to be processed.
375 | process_func (str): The name of the processing function to be applied.
376 | pretrain_dir (str): The directory containing the pre-trained models.
377 |
378 | Returns:
379 | Union[str, Image.Image]: The processed image if successful, or an error message as a string if the specified process_func is not supported.
380 | """
381 | process_func_dict = {
382 | 'canny': canny,
383 | 'depth': depth,
384 | 'hed': hed,
385 | 'mlsd': mlsd,
386 | 'normal': normal,
387 | 'openpose': openpose,
388 | 'scribble': scribble,
389 | 'seg': seg
390 | }
391 |
392 | if process_func not in process_func_dict:
393 | return 'We only support process functions: {}. But got {}.'.format(
394 | list(process_func_dict.keys()), process_func)
395 |
396 | process_func = process_func_dict[process_func]
397 | processed_image = process_func(image, pretrain_dir)
398 | return processed_image
399 |
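# Usage sketch (paths are illustrative): build a canny conditioning image for a
# controlnet request, assuming PRETRAIN_DIR holds the preprocessor weights used
# above (e.g. models--lllyasviel--ControlNet, models--Intel--dpt-large):
#
#   control = preprocess_control(Image.open('input.png'), 'canny', PRETRAIN_DIR)
#   if isinstance(control, str):
#       raise ValueError(control)  # an unsupported process_func returns an error string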
400 |
401 | def transform_image(image: Image.Image,
402 | width: int,
403 | height: int,
404 | mode: int = 0) -> Image.Image:
405 | """
406 | Transform the input image to the specified width and height using the specified mode.
407 |
408 | Args:
409 | image (PIL Image object): The image that needs to be transformed.
410 | width (int): The width of the output image.
411 | height (int): The height of the output image.
412 | mode (int, optional): Specifies the mode of image transformation.
413 |             0 - Stretch, 1 - Crop, 2 - Padding. Defaults to 0 (aligned with the webui behavior).
414 |
415 | Returns:
416 | PIL Image object: The transformed image.
417 | """
418 |
419 | if mode == 0: # Stretch
420 | image = image.resize((width, height))
421 | elif mode == 1: # Crop
422 | aspect_ratio = float(image.size[0]) / float(image.size[1])
423 | new_aspect_ratio = float(width) / float(height)
424 |
425 | if aspect_ratio > new_aspect_ratio:
426 | # Crop the width
427 | new_width = int(float(height) * aspect_ratio)
428 | left = int((new_width - width) / 2)
429 | right = new_width - left
430 | image = image.resize((new_width, height))
431 | image = image.crop((left, 0, left + width, height))
432 | else:
433 | # Crop the height
434 | new_height = int(float(width) / aspect_ratio)
435 | up = int((new_height - height) / 2)
436 | down = new_height - up
437 | image = image.resize((width, new_height))
438 | image = image.crop((0, up, width, up + height))
439 |
440 | elif mode == 2: # Padding
441 | new_image = Image.new('RGB', (width, height), (255, 255, 255))
442 | new_image.paste(image, ((width - image.size[0]) // 2,
443 | (height - image.size[1]) // 2))
444 | image = new_image
445 |
446 | return image
447 |
448 |
449 | def generate_mask_and_img_expand(
450 | img: Image.Image,
451 | expand: Tuple[int, int, int, int],
452 | expand_type: str = 'copy') -> Tuple[Image.Image, Image.Image]:
453 | """
454 | Generate a mask and an expanded image based on the given image and expand parameters.
455 |
456 | Args:
457 | img (Image.Image): The original image.
458 | expand (Tuple[int, int, int, int]): The expansion values for left, right, up, and down directions.
459 | expand_type (str, optional): The type of expansion ('copy' or 'reflect'). Defaults to 'copy'.
460 |
461 | Returns:
462 | Tuple[Image.Image, Image.Image]: The expanded image and the corresponding mask.
463 | """
464 |
465 | left, right, up, down = expand
466 |
467 | width, height = img.size
468 | new_width, new_height = width + left + right, height + up + down
469 |
470 | # ----------- 1. Create mask where the image is black and the expanded region is white -----------
471 | mask = Image.new('L', (new_width, new_height), 0)
472 | draw = ImageDraw.Draw(mask)
473 | # Add white edge
474 | color = 255
475 | draw.rectangle((0, 0, new_width, up), fill=color) # up
476 | draw.rectangle((0, new_height - down, new_width, new_height),
477 | fill=color) # down
478 | draw.rectangle((0, 0, left, new_height), fill=color) # left
479 | draw.rectangle((new_width - right, 0, new_width, new_height),
480 | fill=color) # right
481 |
482 | # ----------- 2. Expand the image by a copy or reflection operation -----------
483 |     # simply filling the new region with constant pixels does not give the unified pipeline enough context to generate a meaningful image
484 | # img_expand = Image.new('RGB', (new_width, new_height), (255, 255, 255))
485 | # img_expand.paste(img, (left, up))
486 |
487 | # Convert the image to a NumPy array
488 | image_array = np.array(img)
489 |
490 | # new img
491 | expanded_image_array = np.zeros((new_height, new_width, 3), dtype=np.uint8)
492 |
493 | # copy ori img
494 | expanded_image_array[up:up + height, left:left + width, :] = image_array
495 |
496 | if expand_type == 'reflect':
497 | # Reflect the boundary pixels to the new boundaries
498 | expanded_image_array[:up, left:left + width, :] = np.flipud(
499 | expanded_image_array[up:2 * up, left:left + width, :]) # up
500 | expanded_image_array[up + height:, left:left + width, :] = np.flipud(
501 | expanded_image_array[up + height - 2:up + height - 2 - down:-1,
502 | left:left + width, :]) # down
503 | expanded_image_array[:, :left, :] = np.fliplr(
504 | expanded_image_array[:, left:2 * left, :]) # left
505 | expanded_image_array[:, left + width:, :] = np.fliplr(
506 | expanded_image_array[:, left + width - 2:left + width - 2 -
507 | right:-1, :]) # right
508 |
509 | else:
510 | # Copy the boundary pixels to the new boundaries
511 | expanded_image_array[:up, left:left +
512 | width, :] = image_array[0:1, :, :] # up
513 | expanded_image_array[up + height:, left:left +
514 | width, :] = image_array[height -
515 | 1:height, :, :] # down
516 | expanded_image_array[:, :left, :] = expanded_image_array[:, left:left +
517 | 1, :] # left
518 | expanded_image_array[:, left +
519 | width:, :] = expanded_image_array[:, left +
520 | width - 1:left +
521 | width, :] # right
522 |
523 | # Create a new image from the expanded image array
524 | img_expand = Image.fromarray(expanded_image_array)
525 |
526 | return img_expand, mask
527 |
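# Usage sketch: expand a 512x512 image by 64 px on every side for outpainting.
# The returned mask is white over the newly added border and black over the
# original content, i.e. the region that the inpaint pipeline is asked to fill:
#
#   img_expand, mask = generate_mask_and_img_expand(img, (64, 64, 64, 64), 'reflect')
#   # img_expand.size == (640, 640) and mask.size == (640, 640)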
--------------------------------------------------------------------------------
/diffusers/utils/blade.py:
--------------------------------------------------------------------------------
1 | """
2 | The file is used for blade optimization.
3 | The main function are three fold:
4 | 1. optimize_and_save_blade_model: do blade optimization and save optimized models (it takes about 30min)
5 | 2. load_blade_model: use to load the optimized blade model
6 | 3. load_attn_procs / unload_lora: online change and merge multiple lora weights
7 | """
8 |
9 | from dataclasses import dataclass
10 | from functools import lru_cache
11 | from pathlib import Path
12 | from typing import Dict, List, Optional, Tuple, Union
13 |
14 | import torch
15 | import torch_blade
16 | from safetensors.torch import load_file
17 | from torch import Tensor, nn
18 | from torch_blade import optimize as blade_optimize
19 |
20 | # ------------------------ 1. optimize_and_save_blade_model ------------------------
21 |
22 |
23 | def gen_inputs(
24 | pipe,
25 | use_controlnet: bool = False
26 | ) -> Tuple[Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...], Tensor]:
27 | """
28 |     Generate example inputs for the given pipeline's submodules (used for tracing, optimization and warmup).
29 |
30 | Args:
31 | pipe: The diffusion pipeline.
32 | use_controlnet (bool, optional): Flag indicating whether to use controlnet inputs. Default is False.
33 |
34 | Returns:
35 | Tuple[Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...], Tensor]: The generated inputs consisting of:
36 | - encoder_inputs: Tensor of shape (1, text_max_length) containing integer values.
37 | - controlnet_inputs: Tuple of tensors containing inputs for controlnet.
38 | - unet_inputs: Tuple of tensors containing inputs for unet.
39 | - decoder_inputs: Tensor of shape (1, unet_out_channels, 128, 128) containing float values.
40 | """
41 | device = torch.device('cuda:0')
42 | # use bs=1 to trace and optimize
43 | text_max_length = pipe.tokenizer.model_max_length
44 | text_embedding_size = pipe.text_encoder.config.hidden_size
45 | sample_size = pipe.unet.config.sample_size
46 | unet_in_channels = pipe.unet.config.in_channels
47 | unet_out_channels = pipe.unet.config.out_channels
48 |
49 | encoder_inputs = torch.randint(1,
50 | 999, (1, text_max_length),
51 | device=device,
52 | dtype=torch.int64)
53 |
54 | unet_inputs = [
55 | torch.randn(
56 | (2, unet_in_channels, sample_size, sample_size),
57 | dtype=torch.half,
58 | device=device,
59 | ),
60 | torch.tensor(999, device=device, dtype=torch.half),
61 | torch.randn((2, text_max_length, text_embedding_size),
62 | dtype=torch.half,
63 | device=device),
64 | ]
65 |
66 |     # controlnet takes the same inputs as the unet, plus an additional conditioning image
67 | controlnet_inputs = unet_inputs + [
68 | torch.randn((2, 3, 512, 512), dtype=torch.half, device=device),
69 | ]
70 |
71 | decoder_inputs = torch.randn(1,
72 | unet_out_channels,
73 | 128,
74 | 128,
75 | device=device,
76 | dtype=torch.half)
77 |
78 | return encoder_inputs, controlnet_inputs, unet_inputs, decoder_inputs
79 |
80 |
81 | def optimize_and_save_blade_model(
82 | pipe: nn.Module,
83 | encoder_path: str,
84 | unet_path: str,
85 | decoder_path: str,
86 | controlnet_path: Optional[str] = None) -> None:
87 | """
88 | Optimize and save the Blade model.
89 |
90 | Args:
91 | pipe (nn.Module): The pipeline module.
92 | encoder_path (str): The path to save the optimized encoder model.
93 | unet_path (str): The path to save the optimized UNet model.
94 | decoder_path (str): The path to save the optimized decoder model.
95 | controlnet_path (str, optional): The path to save the optimized controlnet model. Default is None.
96 |
97 | Returns:
98 | None
99 | """
100 |
101 | if controlnet_path is not None:
102 | use_controlnet = True
103 | else:
104 | use_controlnet = False
105 |
106 | encoder_inputs, controlnet_inputs, unet_inputs, decoder_inputs = gen_inputs(
107 | pipe, use_controlnet=use_controlnet)
108 |
109 | if not use_controlnet:
110 | # base
111 | class UnetWrapper(torch.nn.Module):
112 | def __init__(self, unet):
113 | super().__init__()
114 | self.unet = unet
115 |
116 | def forward(self, sample, timestep, encoder_hidden_states):
117 | return self.unet(
118 | sample,
119 | timestep,
120 | encoder_hidden_states=encoder_hidden_states,
121 | )
122 |
123 | opt_cfg = torch_blade.Config()
124 | opt_cfg.enable_fp16 = True
125 |         opt_cfg.freeze_module = False  # allow changing the LoRA weights at inference time
126 |
127 | from torch_blade.monkey_patch import patch_utils
128 |         # change the conv layer layout [NCHW]->[NHWC] for faster inference
129 | patch_utils.patch_conv2d(pipe.unet)
130 | patch_utils.patch_conv2d(pipe.vae.decoder)
131 |
132 | with opt_cfg, torch.no_grad():
133 | unet = torch.jit.trace(
134 | UnetWrapper(pipe.unet).eval(),
135 | tuple(unet_inputs),
136 | strict=False,
137 | check_trace=False,
138 | )
139 | # unet = torch.jit.trace(pipe.unet, unet_inputs, strict=False, check_trace=False)
140 |
141 | unet = torch_blade.optimize(unet,
142 | model_inputs=tuple(unet_inputs),
143 | allow_tracing=True)
144 |
145 | encoder = torch_blade.optimize(pipe.text_encoder,
146 | model_inputs=encoder_inputs,
147 | allow_tracing=True)
148 |
149 | decoder = torch.jit.trace(pipe.vae.decoder,
150 | decoder_inputs,
151 | strict=False,
152 | check_trace=False)
153 | decoder = torch_blade.optimize(decoder,
154 | model_inputs=decoder_inputs,
155 | allow_tracing=True)
156 |
157 | torch.jit.save(encoder, encoder_path)
158 | torch.jit.save(unet, unet_path)
159 | torch.jit.save(decoder, decoder_path)
160 |
161 | else:
162 | # controlnet
163 | opt_cfg = torch_blade.Config()
164 | opt_cfg.enable_fp16 = True
165 |
166 | class UnetWrapper(torch.nn.Module):
167 | def __init__(self, unet):
168 | super().__init__()
169 | self.unet = unet
170 |
171 | def forward(
172 | self,
173 | sample,
174 | timestep,
175 | encoder_hidden_states,
176 | down_block_additional_residuals,
177 | mid_block_additional_residual,
178 | ):
179 | return self.unet(
180 | sample,
181 | timestep,
182 | encoder_hidden_states=encoder_hidden_states,
183 | down_block_additional_residuals=
184 | down_block_additional_residuals,
185 | mid_block_additional_residual=mid_block_additional_residual,
186 | )
187 |
188 | import functools
189 |
190 | from torch_blade.monkey_patch import patch_utils
191 |
192 | patch_utils.patch_conv2d(pipe.unet)
193 | patch_utils.patch_conv2d(pipe.controlnet)
194 |
195 | pipe.controlnet.forward = functools.partial(pipe.controlnet.forward,
196 | return_dict=False)
197 |
198 | with opt_cfg, torch.no_grad():
199 | encoder = torch_blade.optimize(pipe.text_encoder,
200 | model_inputs=encoder_inputs,
201 | allow_tracing=True)
202 | # decoder = torch.jit.trace(pipe.vae.decoder, decoder_inputs, strict=False, check_trace=False)
203 | decoder = torch_blade.optimize(pipe.vae.decoder,
204 | model_inputs=decoder_inputs,
205 | allow_tracing=True)
206 |
207 |             # do not freeze the module, so other weights can still be loaded later
208 | opt_cfg.freeze_module = False
209 |
210 | controlnet = torch.jit.trace(pipe.controlnet,
211 | tuple(controlnet_inputs),
212 | strict=False,
213 | check_trace=False)
214 | controlnet = torch_blade.optimize(
215 | controlnet,
216 | model_inputs=tuple(controlnet_inputs),
217 | allow_tracing=True)
218 | # add controlnet outputs to unet inputs
219 | down_block_res_samples, mid_block_res_sample = controlnet(
220 | *controlnet_inputs)
221 |
222 | device = torch.device('cuda:0')
223 |
224 | unet_inputs += [
225 | tuple(down_block_res_samples),
226 | mid_block_res_sample,
227 | ]
228 |
229 | unet = torch.jit.trace(
230 | UnetWrapper(pipe.unet).eval(),
231 | tuple(unet_inputs),
232 | strict=False,
233 | check_trace=False,
234 | )
235 |
236 | unet = torch_blade.optimize(unet,
237 | model_inputs=tuple(unet_inputs),
238 | allow_tracing=True)
239 |
240 | torch.jit.save(encoder, encoder_path)
241 | torch.jit.save(controlnet, controlnet_path)
242 | torch.jit.save(unet, unet_path)
243 | torch.jit.save(decoder, decoder_path)
244 |
245 |
246 | # ------------------------ 2. load_blade_model ------------------------
247 |
248 |
249 | def load_blade_model(pipe: nn.Module,
250 | encoder_path: str,
251 | unet_path: str,
252 | decoder_path: str,
253 | controlnet_path: Optional[str] = None) -> nn.Module:
254 | """
255 | Load the Blade model.
256 |
257 | Args:
258 | pipe (nn.Module): The pipeline module.
259 | encoder_path (str): The path to the optimized encoder model.
260 | unet_path (str): The path to the optimized UNet model.
261 | decoder_path (str): The path to the optimized decoder model.
262 | controlnet_path (str, optional): The path to the optimized controlnet model. Default is None.
263 |
264 | Returns:
265 | nn.Module: The loaded Blade model.
266 | """
267 |
268 | if controlnet_path is not None:
269 | use_controlnet = True
270 | else:
271 | use_controlnet = False
272 |
273 | encoder_inputs, controlnet_inputs, unet_inputs, decoder_inputs = gen_inputs(
274 | pipe, use_controlnet=use_controlnet)
275 |
276 | # encoder = torch.jit.load(encoder_path).eval().cuda()
277 | unet = torch.jit.load(unet_path).eval().cuda()
278 | decoder = torch.jit.load(decoder_path).eval().cuda()
279 |
280 | # load weights from current model
281 | if not use_controlnet:
282 | unet_state_dict = {
283 | 'unet.' + k: v
284 | for k, v in pipe.unet.state_dict().items()
285 | }
286 | _, unexpected = unet.load_state_dict(unet_state_dict, strict=False)
287 | print(unexpected)
288 |
289 | _, unexpected = decoder.load_state_dict(pipe.vae.decoder.state_dict(),
290 | strict=False)
291 | print(unexpected)
292 |
293 | # warmup
294 | # encoder(encoder_inputs)
295 | unet(*unet_inputs)
296 | decoder(*decoder_inputs)
297 |
298 | patch_conv_weights(unet)
299 | patch_conv_weights(decoder)
300 |
301 | if use_controlnet:
302 | controlnet = torch.jit.load(controlnet_path).eval().cuda()
303 |
304 | @dataclass
305 | class UNet2DConditionOutput:
306 | sample: torch.FloatTensor
307 |
308 | class TracedEncoder(torch.nn.Module):
309 | def __init__(self):
310 | super().__init__()
311 | self.config = pipe.text_encoder.config
312 | self.device = pipe.text_encoder.device
313 | self.dtype = torch.half
314 |
315 | def forward(self, input_ids, **kwargs):
316 | embeddings = encoder(input_ids.long())
317 | return [embeddings['last_hidden_state']]
318 |
319 | if use_controlnet:
320 | # controlnet
321 | class TracedControlNet(torch.nn.Module):
322 | def __init__(self):
323 | super().__init__()
324 | self.controlnet_conditioning_channel_order = 'rgb'
325 |
326 | def forward(self, sample, timestep, encoder_hidden_states,
327 | **kwargs):
328 | if self.controlnet_conditioning_channel_order == 'rgb':
329 | return controlnet(sample.half(), timestep.half(),
330 | encoder_hidden_states.half(),
331 | kwargs['controlnet_cond'])
332 | else:
333 | return controlnet(
334 | sample.half(), timestep.half(),
335 | encoder_hidden_states.half(),
336 | torch.flip(kwargs['controlnet_cond'], dims=[1]))
337 |
338 | def load_state_dict(self, state_dict, strict=False):
339 | _, unexpected = controlnet.load_state_dict(state_dict,
340 | strict=strict)
341 | if unexpected:
342 | print(
343 | f'load controlNet with unexpected keys: {unexpected}')
344 | return
345 |
346 | def state_dict(self):
347 | return controlnet.state_dict()
348 |
349 | def set_channel_order(self, channel_order):
350 | self.controlnet_conditioning_channel_order = channel_order
351 |
352 | class TracedUNet(torch.nn.Module):
353 | def __init__(self):
354 | super().__init__()
355 | self.config = pipe.unet.config
356 | self.in_channels = pipe.unet.in_channels
357 |                 self.device = pipe.unet.device
359 | self.lora_weights = {}
360 | self.cur_lora = {}
361 |
362 | def state_dict(self):
363 | return unet.state_dict()
364 |
365 | def forward(self, latent_model_input, t, encoder_hidden_states,
366 | **kwargs):
367 | if kwargs.get('down_block_additional_residuals', None) is None:
368 | kwargs['down_block_additional_residuals'] = tuple([
369 | torch.tensor(
370 | [[[[0.0]]]], device=self.device, dtype=torch.half)
371 | ] * 13)
372 | if kwargs.get('mid_block_additional_residual', None) is None:
373 | kwargs['mid_block_additional_residual'] = torch.tensor(
374 | [[[[0.0]]]], device=self.device, dtype=torch.half)
375 |
376 | sample = unet(
377 | latent_model_input.half(),
378 | t.half(),
379 | encoder_hidden_states.half(),
380 | kwargs['down_block_additional_residuals'],
381 | kwargs['mid_block_additional_residual'],
382 | )['sample']
383 |
384 | return UNet2DConditionOutput(sample=sample)
385 |
386 | else:
387 | # base model
388 | class TracedUNet(torch.nn.Module):
389 | def __init__(self):
390 | super().__init__()
391 | self.config = pipe.unet.config
392 | self.in_channels = pipe.unet.in_channels
393 | self.device = pipe.unet.device
394 | self.lora_weights = {}
395 | self.cur_lora = {}
396 |
397 | def state_dict(self):
398 | return unet.state_dict()
399 |
400 | def forward(self, latent_model_input, t, encoder_hidden_states,
401 | **kwargs):
402 | sample = unet(latent_model_input.half(), t.half(),
403 | encoder_hidden_states.half())['sample']
404 | return UNet2DConditionOutput(sample=sample)
405 |
406 | class TracedDecoder(torch.nn.Module):
407 | def forward(self, input):
408 | return decoder(input.half())
409 |
410 |     # pipe.text_encoder = TracedEncoder()  # leads to incorrect output
411 |
412 | if use_controlnet:
413 | controlnet_wrapper = TracedControlNet()
414 | pipe.controlnet.forward = controlnet_wrapper.forward
415 | pipe.controlnet.load_state_dict = controlnet_wrapper.load_state_dict
416 | pipe.controlnet.state_dict = controlnet_wrapper.state_dict
417 | pipe.controlnet.set_channel_order = controlnet_wrapper.set_channel_order
418 |
419 | pipe.unet = TracedUNet()
420 | pipe.vae.decoder = TracedDecoder()
421 |
422 | return pipe
423 |
424 |
425 | # ------------------------ 3. load_attn_procs / unload_lora: online change and merge multiple lora weights ------------------
426 |
427 | LORA_PREFIX_TEXT_ENCODER, LORA_PREFIX_UNET = 'lora_te', 'lora_unet'
428 |
429 |
430 | @lru_cache(maxsize=32)
431 | def load_lora_and_mul(
432 | lora_path: str,
433 | dtype: torch.dtype) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
434 | """
435 | Load and process LoRA weights from the specified path.
436 |
437 | Args:
438 | lora_path (str): Path to the LoRA weights file.
439 | dtype (torch.dtype): Desired data type for the processed weights.
440 |
441 | Returns:
442 | Tuple[Dict[str, Tensor], Dict[str, Tensor]]: A tuple containing two dictionaries:
443 | - text_encoder_state_dict: Dictionary containing the state dictionary for the text encoder.
444 | - unet_state_dict: Dictionary containing the state dictionary for the UNet.
445 | """
446 | if lora_path.endswith('.safetensors'):
447 | # lora model trained by webui script (e.g., model from civitai)
448 |
449 | state_dict = load_file(lora_path)
450 |
451 | visited, unet_state_dict, text_encoder_state_dict = [], {}, {}
452 |
453 | # directly update weight in diffusers model
454 | for key in state_dict:
455 |             # it helps to print the key; it usually looks like
456 |             # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
457 | 
458 |             # alpha entries are consumed while processing the lora_up/lora_down pairs, so skip them here
459 | if '.alpha' in key or key in visited:
460 | continue
461 |
462 | if 'text' in key:
463 | diffusers_key = key.split(
464 | '.')[0].split(LORA_PREFIX_TEXT_ENCODER + '_')[-1].replace(
465 | '_', '.').replace('text.model', 'text_model').replace(
466 | '.proj', '_proj').replace('self.attn',
467 | 'self_attn') + '.weight'
468 | curr_state_dict = text_encoder_state_dict
469 | else:
470 | diffusers_key = 'unet.' + key.split('.')[0].split(
471 | LORA_PREFIX_UNET + '_')[-1].replace('_', '.').replace(
472 | '.block', '_block').replace('to.', 'to_').replace(
473 | 'proj.', 'proj_') + '.weight'
474 | curr_state_dict = unet_state_dict
475 |
476 | pair_keys = []
477 | if 'lora_down' in key:
478 | alpha = state_dict.get(
479 | key.replace('lora_down.weight', 'alpha'), None)
480 | pair_keys.append(key.replace('lora_down', 'lora_up'))
481 | pair_keys.append(key)
482 | else:
483 | alpha = state_dict.get(key.replace('lora_up.weight', 'alpha'),
484 | None)
485 | pair_keys.append(key)
486 | pair_keys.append(key.replace('lora_up', 'lora_down'))
487 |
488 | # update weight
489 | if alpha:
490 | alpha = alpha.item() / state_dict[pair_keys[0]].shape[1]
491 | else:
492 | alpha = 0.75
493 |
494 | if len(state_dict[pair_keys[0]].shape) == 4:
495 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(
496 | torch.float32)
497 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(
498 | 2).to(torch.float32)
499 | if len(weight_up.shape) == len(weight_down.shape):
500 | curr_state_dict[diffusers_key] = alpha * torch.mm(
501 | weight_up.cuda(),
502 | weight_down.cuda()).unsqueeze(2).unsqueeze(3).to(dtype)
503 | else:
504 | curr_state_dict[diffusers_key] = alpha * torch.einsum(
505 | 'a b, b c h w -> a c h w', weight_up.cuda(),
506 | weight_down.cuda()).to(dtype)
507 | else:
508 | weight_up = state_dict[pair_keys[0]].to(torch.float32)
509 | weight_down = state_dict[pair_keys[1]].to(torch.float32)
510 | curr_state_dict[diffusers_key] = alpha * torch.mm(
511 | weight_up.cuda(), weight_down.cuda()).to(dtype)
512 |
513 | # update visited list
514 | for item in pair_keys:
515 | visited.append(item)
516 | return text_encoder_state_dict, unet_state_dict
517 | else:
518 | # model trained by diffusers api (lora attn only in unet)
519 | state_dict = torch.load(lora_path)
520 | multied_state_dict = {}
521 | for k, v in state_dict.items():
522 | if '_lora.up.weight' in k:
523 | up_weight = v
524 | new_k = 'unet.' + k.replace(
525 | '_lora.up.weight', '.weight').replace(
526 | 'processor.', '').replace('to_out', 'to_out.0')
527 | down_weight = state_dict[k.replace('up.weight', 'down.weight')]
528 | # xxxx_lora.up.weight
529 | multied_state_dict[new_k] = torch.matmul(
530 | up_weight.cuda(), down_weight.cuda())
531 | return {}, multied_state_dict
532 |
533 |
534 | def patch_conv_weights(model: nn.Module) -> nn.Module:
535 | """
536 |     Patch the convolutional weights in the model to be compatible with the NHWC format,
537 |     used for model acceleration in Blade optimization.
538 |
539 | Args:
540 | model (nn.Module): The model to be patched.
541 |
542 | Returns:
543 | nn.Module: The patched model.
544 | """
545 | origin_state_dict = model.state_dict()
546 | state_dict = {}
547 | for k, v in origin_state_dict.items():
548 | if k.endswith('_nhwc'):
549 | state_dict[k] = origin_state_dict[k[:-5]].permute([0, 2, 3, 1])
550 | model.load_state_dict(state_dict, strict=False)
551 | return model
552 |
553 |
554 | def merge_lora_weights(origin: Dict[str, Tensor], to_merge: Dict[str, Tensor],
555 |                        scale: float) -> None:
556 | """
557 | Merge LoRA weights into the origin dictionary with the specified scale.
558 |
559 | Args:
560 | origin (Dict[str, Tensor]): The original weights dictionary.
561 | to_merge (Dict[str, Tensor]): The weights dictionary to be merged.
562 | scale (float): The scaling factor for the merged weights.
563 |
564 |     Returns:
565 |         None. ``origin`` is updated in place with the scaled weights added.
566 | """
567 | for k, v in to_merge.items():
568 | v = v.to('cuda')
569 | weight = v * scale
570 | if origin.get(k, None) is None:
571 | origin[k] = weight
572 | else:
573 | origin[k] += weight
574 |
575 |
576 | def apply_lora_weights(model_state_dict: Dict[str, Tensor],
577 | weights: Dict[str, Tensor]) -> None:
578 | """
579 | Apply LoRA weights to the model state dictionary.
580 |
581 | Args:
582 | model_state_dict (Dict[str, Tensor]): The model's state dictionary.
583 | weights (Dict[str, Tensor]): The LoRA weights to be applied.
584 |
585 | Returns:
586 | None
587 | """
588 |
589 | with torch.no_grad():
590 | for k, v in weights.items():
591 | v = v.to('cuda')
592 | model_state_dict[k].add_(v)
593 |
594 |
595 | def unload_lora_weights(model_state_dict: Dict[str, Tensor],
596 | weights: Dict[str, Tensor]) -> None:
597 | """
598 | Unload LoRA weights from the model state dictionary.
599 |
600 | Args:
601 | model_state_dict (Dict[str, Tensor]): The model's state dictionary.
602 | weights (Dict[str, Tensor]): The LoRA weights to be unloaded.
603 |
604 | Returns:
605 | None
606 | """
607 | with torch.no_grad():
608 | for k, v in weights.items():
609 | v = v.to('cuda')
610 | model_state_dict[k].sub_(v)
611 |
612 |
613 | def load_attn_procs(pipe,
614 | attn_procs_paths: Union[str, List[str]],
615 | scales: Union[float, List[float]] = 0.75) -> None:
616 | """
617 | Load and merge multiple lora model weights into the pipeline.
618 |
619 | Args:
620 | pipe: Stable diffusion pipeline
621 | attn_procs_paths (Union[str, List[str]]): The paths to the attention processor weights.
622 | scales (Union[float, List[float]], optional): The scaling factor(s) for the merged weights. Default is 0.75.
623 |
624 | Returns:
625 | None
626 | """
627 |
628 |     if isinstance(attn_procs_paths, str):
629 | attn_procs_paths = [attn_procs_paths]
630 | if isinstance(scales, float):
631 | scales = [scales] * len(attn_procs_paths)
632 |
633 | pipe.text_encoder_merged_weights, pipe.unet_merged_weights = {}, {}
634 |
635 | for attn_procs_path, scale in zip(attn_procs_paths, scales):
636 | text_encoder_state_dict, unet_state_dict = load_lora_and_mul(
637 | attn_procs_path, dtype=torch.half)
638 | # merge weights from multiple lora models with scale
639 | merge_lora_weights(pipe.text_encoder_merged_weights,
640 | text_encoder_state_dict, scale)
641 | merge_lora_weights(pipe.unet_merged_weights, unet_state_dict, scale)
642 |
643 | # apply the final lora weights to text_encoder and unet
644 | apply_lora_weights(pipe.text_encoder.state_dict(),
645 | pipe.text_encoder_merged_weights)
646 | apply_lora_weights(pipe.unet.state_dict(), pipe.unet_merged_weights)
647 |
648 | patch_conv_weights(pipe.unet)
649 |
650 |
651 | def unload_lora(pipe):
652 |     # unload the LoRA weights after each inference
653 | unload_lora_weights(pipe.text_encoder.state_dict(),
654 | pipe.text_encoder_merged_weights)
655 | unload_lora_weights(pipe.unet.state_dict(), pipe.unet_merged_weights)
656 | pipe.text_encoder_merged_weights, pipe.unet_merged_weights = {}, {}
657 | patch_conv_weights(pipe.unet)
658 |
--------------------------------------------------------------------------------
/diffusers/lpw_stable_diffusion.py:
--------------------------------------------------------------------------------
1 | # This code is borrowed from the HuggingFace diffusers library
2 | # Source: https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py
3 | # License: Apache-2.0
4 |
5 | import inspect
6 | import re
7 | from typing import Callable, List, Optional, Union
8 |
9 | import numpy as np
10 | import PIL
11 | from packaging import version
12 |
13 | import diffusers
14 | import torch
15 | from diffusers import SchedulerMixin, StableDiffusionPipeline
16 | from diffusers.models import AutoencoderKL, UNet2DConditionModel
17 | from diffusers.pipelines.stable_diffusion import (
18 | StableDiffusionPipelineOutput, StableDiffusionSafetyChecker)
19 | from diffusers.utils import logging
20 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
21 |
22 | try:
23 | from diffusers.utils import PIL_INTERPOLATION
24 | except ImportError:
25 | if version.parse(version.parse(
26 | PIL.__version__).base_version) >= version.parse('9.1.0'):
27 | PIL_INTERPOLATION = {
28 | 'linear': PIL.Image.Resampling.BILINEAR,
29 | 'bilinear': PIL.Image.Resampling.BILINEAR,
30 | 'bicubic': PIL.Image.Resampling.BICUBIC,
31 | 'lanczos': PIL.Image.Resampling.LANCZOS,
32 | 'nearest': PIL.Image.Resampling.NEAREST,
33 | }
34 | else:
35 | PIL_INTERPOLATION = {
36 | 'linear': PIL.Image.LINEAR,
37 | 'bilinear': PIL.Image.BILINEAR,
38 | 'bicubic': PIL.Image.BICUBIC,
39 | 'lanczos': PIL.Image.LANCZOS,
40 | 'nearest': PIL.Image.NEAREST,
41 | }
42 | # ------------------------------------------------------------------------------
43 |
44 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45 |
46 | re_attention = re.compile(
47 | r"""
48 | \\\(|
49 | \\\)|
50 | \\\[|
51 | \\]|
52 | \\\\|
53 | \\|
54 | \(|
55 | \[|
56 | :([+-]?[.\d]+)\)|
57 | \)|
58 | ]|
59 | [^\\()\[\]:]+|
60 | :
61 | """,
62 | re.X,
63 | )
64 |
65 |
66 | def parse_prompt_attention(text):
67 | """
68 | Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
69 | Accepted tokens are:
70 | (abc) - increases attention to abc by a multiplier of 1.1
71 | (abc:3.12) - increases attention to abc by a multiplier of 3.12
72 | [abc] - decreases attention to abc by a multiplier of 1.1
73 | \( - literal character '('
74 | \[ - literal character '['
75 | \) - literal character ')'
76 | \] - literal character ']'
77 | \\ - literal character '\'
78 | anything else - just text
79 | >>> parse_prompt_attention('normal text')
80 | [['normal text', 1.0]]
81 | >>> parse_prompt_attention('an (important) word')
82 | [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
83 | >>> parse_prompt_attention('(unbalanced')
84 | [['unbalanced', 1.1]]
85 | >>> parse_prompt_attention('\(literal\]')
86 | [['(literal]', 1.0]]
87 | >>> parse_prompt_attention('(unnecessary)(parens)')
88 | [['unnecessaryparens', 1.1]]
89 | >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
90 | [['a ', 1.0],
91 | ['house', 1.5730000000000004],
92 | [' ', 1.1],
93 | ['on', 1.0],
94 | [' a ', 1.1],
95 | ['hill', 0.55],
96 | [', sun, ', 1.1],
97 | ['sky', 1.4641000000000006],
98 | ['.', 1.1]]
99 | """
100 |
101 | res = []
102 | round_brackets = []
103 | square_brackets = []
104 |
105 | round_bracket_multiplier = 1.1
106 | square_bracket_multiplier = 1 / 1.1
107 |
108 | def multiply_range(start_position, multiplier):
109 | for p in range(start_position, len(res)):
110 | res[p][1] *= multiplier
111 |
112 | for m in re_attention.finditer(text):
113 | text = m.group(0)
114 | weight = m.group(1)
115 |
116 | if text.startswith('\\'):
117 | res.append([text[1:], 1.0])
118 | elif text == '(':
119 | round_brackets.append(len(res))
120 | elif text == '[':
121 | square_brackets.append(len(res))
122 | elif weight is not None and len(round_brackets) > 0:
123 | multiply_range(round_brackets.pop(), float(weight))
124 | elif text == ')' and len(round_brackets) > 0:
125 | multiply_range(round_brackets.pop(), round_bracket_multiplier)
126 | elif text == ']' and len(square_brackets) > 0:
127 | multiply_range(square_brackets.pop(), square_bracket_multiplier)
128 | else:
129 | res.append([text, 1.0])
130 |
131 | for pos in round_brackets:
132 | multiply_range(pos, round_bracket_multiplier)
133 |
134 | for pos in square_brackets:
135 | multiply_range(pos, square_bracket_multiplier)
136 |
137 | if len(res) == 0:
138 | res = [['', 1.0]]
139 |
140 | # merge runs of identical weights
141 | i = 0
142 | while i + 1 < len(res):
143 | if res[i][1] == res[i + 1][1]:
144 | res[i][0] += res[i + 1][0]
145 | res.pop(i + 1)
146 | else:
147 | i += 1
148 |
149 | return res
150 |
151 |
152 | def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str],
153 | max_length: int):
154 | r"""
155 | Tokenize a list of prompts and return its tokens with weights of each token.
156 |
157 | No padding, starting or ending token is included.
158 | """
159 | tokens = []
160 | weights = []
161 | truncated = False
162 | for text in prompt:
163 | texts_and_weights = parse_prompt_attention(text)
164 | text_token = []
165 | text_weight = []
166 | for word, weight in texts_and_weights:
167 | # tokenize and discard the starting and the ending token
168 | token = pipe.tokenizer(word).input_ids[1:-1]
169 | text_token += token
170 | # copy the weight by length of token
171 | text_weight += [weight] * len(token)
172 | # stop if the text is too long (longer than truncation limit)
173 | if len(text_token) > max_length:
174 | truncated = True
175 | break
176 | # truncate
177 | if len(text_token) > max_length:
178 | truncated = True
179 | text_token = text_token[:max_length]
180 | text_weight = text_weight[:max_length]
181 | tokens.append(text_token)
182 | weights.append(text_weight)
183 | if truncated:
184 | logger.warning(
185 | 'Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples'
186 | )
187 | return tokens, weights
188 |
189 |
190 | def pad_tokens_and_weights(tokens,
191 | weights,
192 | max_length,
193 | bos,
194 | eos,
195 | pad,
196 | no_boseos_middle=True,
197 | chunk_length=77):
198 | r"""
199 | Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
200 | """
201 | max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
202 | weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
203 | for i in range(len(tokens)):
204 | tokens[i] = [
205 | bos
206 | ] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos]
207 | if no_boseos_middle:
208 | weights[i] = [
209 | 1.0
210 | ] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
211 | else:
212 | w = []
213 | if len(weights[i]) == 0:
214 | w = [1.0] * weights_length
215 | else:
216 | for j in range(max_embeddings_multiples):
217 | w.append(1.0) # weight for starting token in this chunk
218 | w += weights[i][j * (chunk_length -
219 | 2):min(len(weights[i]), (j + 1) *
220 | (chunk_length - 2))]
221 | w.append(1.0) # weight for ending token in this chunk
222 | w += [1.0] * (weights_length - len(w))
223 | weights[i] = w[:]
224 |
225 | return tokens, weights
226 |
227 |
228 | def get_unweighted_text_embeddings(
229 | pipe: StableDiffusionPipeline,
230 | text_input: torch.Tensor,
231 | chunk_length: int,
232 | no_boseos_middle: Optional[bool] = True,
233 | ):
234 | """
235 | When the length of tokens is a multiple of the capacity of the text encoder,
236 | it should be split into chunks and sent to the text encoder individually.
237 | """
238 | max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
239 | if max_embeddings_multiples > 1:
240 | text_embeddings = []
241 | for i in range(max_embeddings_multiples):
242 | # extract the i-th chunk
243 | text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1) *
244 | (chunk_length - 2) + 2].clone()
245 |
246 | # cover the head and the tail by the starting and the ending tokens
247 | text_input_chunk[:, 0] = text_input[0, 0]
248 | text_input_chunk[:, -1] = text_input[0, -1]
249 | text_embedding = pipe.text_encoder(text_input_chunk)[0]
250 |
251 | if no_boseos_middle:
252 | if i == 0:
253 | # discard the ending token
254 | text_embedding = text_embedding[:, :-1]
255 | elif i == max_embeddings_multiples - 1:
256 | # discard the starting token
257 | text_embedding = text_embedding[:, 1:]
258 | else:
259 | # discard both starting and ending tokens
260 | text_embedding = text_embedding[:, 1:-1]
261 |
262 | text_embeddings.append(text_embedding)
263 | text_embeddings = torch.concat(text_embeddings, axis=1)
264 | else:
265 | text_embeddings = pipe.text_encoder(text_input)[0]
266 | return text_embeddings
267 |
268 |
269 | def get_weighted_text_embeddings(
270 | pipe: StableDiffusionPipeline,
271 | prompt: Union[str, List[str]],
272 | uncond_prompt: Optional[Union[str, List[str]]] = None,
273 | max_embeddings_multiples: Optional[int] = 3,
274 | no_boseos_middle: Optional[bool] = False,
275 | skip_parsing: Optional[bool] = False,
276 | skip_weighting: Optional[bool] = False,
277 | ):
278 | r"""
279 | Prompts can be assigned with local weights using brackets. For example,
280 | prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
281 | and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
282 |
283 |     Also, to regularize the embedding, the weighted embedding is scaled to preserve the original mean.
284 |
285 | Args:
286 | pipe (`StableDiffusionPipeline`):
287 | Pipe to provide access to the tokenizer and the text encoder.
288 | prompt (`str` or `List[str]`):
289 | The prompt or prompts to guide the image generation.
290 | uncond_prompt (`str` or `List[str]`):
291 |             The unconditional prompt or prompts to guide the image generation. If an unconditional prompt
292 | is provided, the embeddings of prompt and uncond_prompt are concatenated.
293 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
294 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
295 | no_boseos_middle (`bool`, *optional*, defaults to `False`):
296 |             If the length of the text tokens is a multiple of the text encoder's capacity, whether to keep the starting and
297 |             ending tokens in each of the chunks in the middle.
298 | skip_parsing (`bool`, *optional*, defaults to `False`):
299 | Skip the parsing of brackets.
300 | skip_weighting (`bool`, *optional*, defaults to `False`):
301 | Skip the weighting. When the parsing is skipped, it is forced True.
302 | """
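    # Example: in 'a (white:1.2) cat on a [rainy] street', the tokens of 'white' are
    # weighted by 1.2 and those of 'rainy' by 1/1.1; the embeddings are then rescaled
    # below so that their overall mean matches that of the unweighted embeddings.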
303 | max_length = (pipe.tokenizer.model_max_length -
304 | 2) * max_embeddings_multiples + 2
305 | if isinstance(prompt, str):
306 | prompt = [prompt]
307 |
308 | if not skip_parsing:
309 | prompt_tokens, prompt_weights = get_prompts_with_weights(
310 | pipe, prompt, max_length - 2)
311 | if uncond_prompt is not None:
312 | if isinstance(uncond_prompt, str):
313 | uncond_prompt = [uncond_prompt]
314 | uncond_tokens, uncond_weights = get_prompts_with_weights(
315 | pipe, uncond_prompt, max_length - 2)
316 | else:
317 | prompt_tokens = [
318 | token[1:-1] for token in pipe.tokenizer(
319 | prompt, max_length=max_length, truncation=True).input_ids
320 | ]
321 | prompt_weights = [[1.0] * len(token) for token in prompt_tokens]
322 | if uncond_prompt is not None:
323 | if isinstance(uncond_prompt, str):
324 | uncond_prompt = [uncond_prompt]
325 | uncond_tokens = [
326 | token[1:-1]
327 | for token in pipe.tokenizer(uncond_prompt,
328 | max_length=max_length,
329 | truncation=True).input_ids
330 | ]
331 | uncond_weights = [[1.0] * len(token) for token in uncond_tokens]
332 |
333 | # round up the longest length of tokens to a multiple of (model_max_length - 2)
334 | max_length = max([len(token) for token in prompt_tokens])
335 | if uncond_prompt is not None:
336 | max_length = max(max_length,
337 | max([len(token) for token in uncond_tokens]))
338 |
339 | max_embeddings_multiples = min(
340 | max_embeddings_multiples,
341 | (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1,
342 | )
343 | max_embeddings_multiples = max(1, max_embeddings_multiples)
344 | max_length = (pipe.tokenizer.model_max_length -
345 | 2) * max_embeddings_multiples + 2
346 |
347 | # pad the length of tokens and weights
348 | if isinstance(pipe.tokenizer, CLIPTokenizer):
349 | bos = pipe.tokenizer.bos_token_id
350 | eos = pipe.tokenizer.eos_token_id
351 | else:
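        # tokenizers without bos/eos (e.g. BERT-style) fall back to the cls/sep tokens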
352 | bos = pipe.tokenizer.cls_token_id
353 | eos = pipe.tokenizer.sep_token_id
354 | pad = getattr(pipe.tokenizer, 'pad_token_id', eos)
355 | prompt_tokens, prompt_weights = pad_tokens_and_weights(
356 | prompt_tokens,
357 | prompt_weights,
358 | max_length,
359 | bos,
360 | eos,
361 | pad,
362 | no_boseos_middle=no_boseos_middle,
363 | chunk_length=pipe.tokenizer.model_max_length,
364 | )
365 | prompt_tokens = torch.tensor(prompt_tokens,
366 | dtype=torch.long,
367 | device=pipe.device)
368 | if uncond_prompt is not None:
369 | uncond_tokens, uncond_weights = pad_tokens_and_weights(
370 | uncond_tokens,
371 | uncond_weights,
372 | max_length,
373 | bos,
374 | eos,
375 | pad,
376 | no_boseos_middle=no_boseos_middle,
377 | chunk_length=pipe.tokenizer.model_max_length,
378 | )
379 | uncond_tokens = torch.tensor(uncond_tokens,
380 | dtype=torch.long,
381 | device=pipe.device)
382 |
383 | # get the embeddings
384 | text_embeddings = get_unweighted_text_embeddings(
385 | pipe,
386 | prompt_tokens,
387 | pipe.tokenizer.model_max_length,
388 | no_boseos_middle=no_boseos_middle,
389 | )
390 | prompt_weights = torch.tensor(prompt_weights,
391 | dtype=text_embeddings.dtype,
392 | device=pipe.device)
393 | if uncond_prompt is not None:
394 | uncond_embeddings = get_unweighted_text_embeddings(
395 | pipe,
396 | uncond_tokens,
397 | pipe.tokenizer.model_max_length,
398 | no_boseos_middle=no_boseos_middle,
399 | )
400 | uncond_weights = torch.tensor(uncond_weights,
401 | dtype=uncond_embeddings.dtype,
402 | device=pipe.device)
403 |
404 | # assign weights to the prompts and normalize in the sense of mean
405 | # TODO: should we normalize by chunk or in a whole (current implementation)?
406 | if (not skip_parsing) and (not skip_weighting):
407 | previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
408 | text_embeddings.dtype)
409 | text_embeddings *= prompt_weights.unsqueeze(-1)
410 | current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(
411 | text_embeddings.dtype)
412 | text_embeddings *= (previous_mean /
413 | current_mean).unsqueeze(-1).unsqueeze(-1)
414 | if uncond_prompt is not None:
415 | previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
416 | uncond_embeddings.dtype)
417 | uncond_embeddings *= uncond_weights.unsqueeze(-1)
418 | current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to(
419 | uncond_embeddings.dtype)
420 | uncond_embeddings *= (previous_mean /
421 | current_mean).unsqueeze(-1).unsqueeze(-1)
422 |
423 | if uncond_prompt is not None:
424 | return text_embeddings, uncond_embeddings
425 | return text_embeddings, None
426 |
427 |
428 | def preprocess_image(image):
429 | w, h = image.size
430 |     w, h = (x - x % 32 for x in (w, h))  # round width and height down to a multiple of 32
431 | image = image.resize((w, h), resample=PIL_INTERPOLATION['lanczos'])
432 | image = np.array(image).astype(np.float32) / 255.0
433 | image = image[None].transpose(0, 3, 1, 2)
434 | image = torch.from_numpy(image)
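    # map pixel values from [0, 1] to [-1, 1], the range expected by the VAE encoder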
435 | return 2.0 * image - 1.0
436 |
437 |
438 | def preprocess_mask(mask, scale_factor=8):
439 | mask = mask.convert('L')
440 | w, h = mask.size
441 |     w, h = (x - x % 32 for x in (w, h))  # round width and height down to a multiple of 32
442 | mask = mask.resize((w // scale_factor, h // scale_factor),
443 | resample=PIL_INTERPOLATION['nearest'])
444 | mask = np.array(mask).astype(np.float32) / 255.0
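    # replicate the single-channel mask across the 4 latent channels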
445 | mask = np.tile(mask, (4, 1, 1))
446 |     mask = mask[None].transpose(0, 1, 2, 3)  # add a batch dimension (the identity transpose is a no-op)
447 | mask = 1 - mask # repaint white, keep black
448 | mask = torch.from_numpy(mask)
449 | return mask
450 |
451 |
452 | class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
453 | r"""
454 |     Pipeline for text-to-image generation using Stable Diffusion without a token length limit, with support for
455 |     parsing weights in the prompt.
456 |
457 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
458 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
459 |
460 | Args:
461 | vae ([`AutoencoderKL`]):
462 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
463 | text_encoder ([`CLIPTextModel`]):
464 | Frozen text-encoder. Stable Diffusion uses the text portion of
465 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
466 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
467 | tokenizer (`CLIPTokenizer`):
468 | Tokenizer of class
469 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
470 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
471 | scheduler ([`SchedulerMixin`]):
472 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
473 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
474 | safety_checker ([`StableDiffusionSafetyChecker`]):
475 | Classification module that estimates whether generated images could be considered offensive or harmful.
476 | Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
477 | feature_extractor ([`CLIPImageProcessor`]):
478 | Model that extracts features from generated images to be used as inputs for the `safety_checker`.
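
    Example (a minimal, illustrative sketch; the model id, dtype and device below are assumptions, not
    something prescribed by this repository):

        >>> import torch
        >>> pipe = StableDiffusionLongPromptWeightingPipeline.from_pretrained(
        ...     'runwayml/stable-diffusion-v1-5', torch_dtype=torch.float16).to('cuda')
        >>> image = pipe.text2img(
        ...     'A (very beautiful) masterpiece, highly detailed landscape',
        ...     negative_prompt='lowres, bad anatomy',
        ...     max_embeddings_multiples=3,
        ... ).images[0]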
479 | """
480 |
481 | if version.parse(version.parse(
482 | diffusers.__version__).base_version) >= version.parse('0.9.0'):
483 |
484 | def __init__(
485 | self,
486 | vae: AutoencoderKL,
487 | text_encoder: CLIPTextModel,
488 | tokenizer: CLIPTokenizer,
489 | unet: UNet2DConditionModel,
490 | scheduler: SchedulerMixin,
491 | safety_checker: StableDiffusionSafetyChecker,
492 | feature_extractor: CLIPImageProcessor,
493 | requires_safety_checker: bool = True,
494 | ):
495 | super().__init__(
496 | vae=vae,
497 | text_encoder=text_encoder,
498 | tokenizer=tokenizer,
499 | unet=unet,
500 | scheduler=scheduler,
501 | safety_checker=safety_checker,
502 | feature_extractor=feature_extractor,
503 | requires_safety_checker=requires_safety_checker,
504 | )
505 | self.__init__additional__()
506 |
507 | else:
508 |
509 | def __init__(
510 | self,
511 | vae: AutoencoderKL,
512 | text_encoder: CLIPTextModel,
513 | tokenizer: CLIPTokenizer,
514 | unet: UNet2DConditionModel,
515 | scheduler: SchedulerMixin,
516 | safety_checker: StableDiffusionSafetyChecker,
517 | feature_extractor: CLIPImageProcessor,
518 | ):
519 | super().__init__(
520 | vae=vae,
521 | text_encoder=text_encoder,
522 | tokenizer=tokenizer,
523 | unet=unet,
524 | scheduler=scheduler,
525 | safety_checker=safety_checker,
526 | feature_extractor=feature_extractor,
527 | )
528 | self.__init__additional__()
529 |
530 | def __init__additional__(self):
531 | if not hasattr(self, 'vae_scale_factor'):
532 | setattr(self, 'vae_scale_factor',
533 | 2**(len(self.vae.config.block_out_channels) - 1))
534 |
535 | @property
536 | def _execution_device(self):
537 | r"""
538 | Returns the device on which the pipeline's models will be executed. After calling
539 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
540 | hooks.
541 | """
542 | if self.device != torch.device('meta') or not hasattr(
543 | self.unet, '_hf_hook'):
544 | return self.device
545 | for module in self.unet.modules():
546 | if (hasattr(module, '_hf_hook')
547 | and hasattr(module._hf_hook, 'execution_device')
548 | and module._hf_hook.execution_device is not None):
549 | return torch.device(module._hf_hook.execution_device)
550 | return self.device
551 |
552 | def _encode_prompt(
553 | self,
554 | prompt,
555 | device,
556 | num_images_per_prompt,
557 | do_classifier_free_guidance,
558 | negative_prompt,
559 | max_embeddings_multiples,
560 | ):
561 | r"""
562 | Encodes the prompt into text encoder hidden states.
563 |
564 | Args:
565 |             prompt (`str` or `List[str]`):
566 |                 prompt to be encoded
567 |             device (`torch.device`):
568 |                 torch device
569 | num_images_per_prompt (`int`):
570 | number of images that should be generated per prompt
571 | do_classifier_free_guidance (`bool`):
572 | whether to use classifier free guidance or not
573 | negative_prompt (`str` or `List[str]`):
574 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
575 | if `guidance_scale` is less than `1`).
576 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
577 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
578 | """
579 | batch_size = len(prompt) if isinstance(prompt, list) else 1
580 |
581 | if negative_prompt is None:
582 | negative_prompt = [''] * batch_size
583 | elif isinstance(negative_prompt, str):
584 | negative_prompt = [negative_prompt] * batch_size
585 | if batch_size != len(negative_prompt):
586 | raise ValueError(
587 | f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:'
588 | f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches'
589 | ' the batch size of `prompt`.')
590 |
591 | text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
592 | pipe=self,
593 | prompt=prompt,
594 | uncond_prompt=negative_prompt
595 | if do_classifier_free_guidance else None,
596 | max_embeddings_multiples=max_embeddings_multiples,
597 | )
598 | bs_embed, seq_len, _ = text_embeddings.shape
599 | text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
600 | text_embeddings = text_embeddings.view(
601 | bs_embed * num_images_per_prompt, seq_len, -1)
602 |
603 | if do_classifier_free_guidance:
604 | bs_embed, seq_len, _ = uncond_embeddings.shape
605 | uncond_embeddings = uncond_embeddings.repeat(
606 | 1, num_images_per_prompt, 1)
607 | uncond_embeddings = uncond_embeddings.view(
608 | bs_embed * num_images_per_prompt, seq_len, -1)
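            # for classifier-free guidance, concatenate [uncond, cond] so the UNet processes both in one forward pass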
609 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
610 |
611 | return text_embeddings
612 |
613 | def check_inputs(self, prompt, height, width, strength, callback_steps):
614 | if not isinstance(prompt, str) and not isinstance(prompt, list):
615 | raise ValueError(
616 | f'`prompt` has to be of type `str` or `list` but is {type(prompt)}'
617 | )
618 |
619 | if strength < 0 or strength > 1:
620 | raise ValueError(
621 |                 f'The value of strength should be in [0.0, 1.0] but is {strength}'
622 | )
623 |
624 | if height % 8 != 0 or width % 8 != 0:
625 | raise ValueError(
626 | f'`height` and `width` have to be divisible by 8 but are {height} and {width}.'
627 | )
628 |
629 | if (callback_steps is None) or (callback_steps is not None and
630 | (not isinstance(callback_steps, int)
631 | or callback_steps <= 0)):
632 | raise ValueError(
633 | f'`callback_steps` has to be a positive integer but is {callback_steps} of type'
634 | f' {type(callback_steps)}.')
635 |
636 | def get_timesteps(self, num_inference_steps, strength, device,
637 | is_text2img):
638 | if is_text2img:
639 | return self.scheduler.timesteps.to(device), num_inference_steps
640 | else:
641 | # get the original timestep using init_timestep
642 | offset = self.scheduler.config.get('steps_offset', 0)
643 | init_timestep = int(num_inference_steps * strength) + offset
644 | init_timestep = min(init_timestep, num_inference_steps)
645 |
646 | t_start = max(num_inference_steps - init_timestep + offset, 0)
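            # e.g. num_inference_steps=50, strength=0.8, offset=0 -> init_timestep=40, t_start=10,
            # so denoising runs over the final 40 scheduler timesteps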
647 | timesteps = self.scheduler.timesteps[t_start:].to(device)
648 | return timesteps, num_inference_steps - t_start
649 |
650 | def run_safety_checker(self, image, device, dtype):
651 | if self.safety_checker is not None:
652 | safety_checker_input = self.feature_extractor(
653 | self.numpy_to_pil(image), return_tensors='pt').to(device)
654 | image, has_nsfw_concept = self.safety_checker(
655 | images=image,
656 | clip_input=safety_checker_input.pixel_values.to(dtype))
657 | else:
658 | has_nsfw_concept = None
659 | return image, has_nsfw_concept
660 |
661 | def decode_latents(self, latents):
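        # undo the fixed Stable Diffusion v1 VAE latent scaling factor (0.18215) before decoding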
662 | latents = 1 / 0.18215 * latents
663 | image = self.vae.decode(latents).sample
664 | image = (image / 2 + 0.5).clamp(0, 1)
665 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
666 | image = image.cpu().permute(0, 2, 3, 1).float().numpy()
667 | return image
668 |
669 | def prepare_extra_step_kwargs(self, generator, eta):
670 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
671 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
672 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
673 | # and should be between [0, 1]
674 |
675 | accepts_eta = 'eta' in set(
676 | inspect.signature(self.scheduler.step).parameters.keys())
677 | extra_step_kwargs = {}
678 | if accepts_eta:
679 | extra_step_kwargs['eta'] = eta
680 |
681 | # check if the scheduler accepts generator
682 | accepts_generator = 'generator' in set(
683 | inspect.signature(self.scheduler.step).parameters.keys())
684 | if accepts_generator:
685 | extra_step_kwargs['generator'] = generator
686 | return extra_step_kwargs
687 |
688 | def prepare_latents(self,
689 | image,
690 | timestep,
691 | batch_size,
692 | height,
693 | width,
694 | dtype,
695 | device,
696 | generator,
697 | latents=None):
698 | if image is None:
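            # text-to-image: no init image, so start from pure Gaussian noise in latent space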
699 | shape = (
700 | batch_size,
701 | self.unet.config.in_channels,
702 | height // self.vae_scale_factor,
703 | width // self.vae_scale_factor,
704 | )
705 |
706 | if latents is None:
707 | if device.type == 'mps':
708 | # randn does not work reproducibly on mps
709 | latents = torch.randn(shape,
710 | generator=generator,
711 | device='cpu',
712 | dtype=dtype).to(device)
713 | else:
714 | latents = torch.randn(shape,
715 | generator=generator,
716 | device=device,
717 | dtype=dtype)
718 | else:
719 | if latents.shape != shape:
720 | raise ValueError(
721 | f'Unexpected latents shape, got {latents.shape}, expected {shape}'
722 | )
723 | latents = latents.to(device)
724 |
725 | # scale the initial noise by the standard deviation required by the scheduler
726 | latents = latents * self.scheduler.init_noise_sigma
727 | return latents, None, None
728 | else:
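            # image-to-image / inpainting: encode the init image with the VAE, then noise it to the given timestep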
729 | init_latent_dist = self.vae.encode(image).latent_dist
730 | init_latents = init_latent_dist.sample(generator=generator)
731 | init_latents = 0.18215 * init_latents
732 | init_latents = torch.cat([init_latents] * batch_size, dim=0)
733 | init_latents_orig = init_latents
734 | shape = init_latents.shape
735 |
736 | # add noise to latents using the timesteps
737 | if device.type == 'mps':
738 | noise = torch.randn(shape,
739 | generator=generator,
740 | device='cpu',
741 | dtype=dtype).to(device)
742 | else:
743 | noise = torch.randn(shape,
744 | generator=generator,
745 | device=device,
746 | dtype=dtype)
747 | latents = self.scheduler.add_noise(init_latents, noise, timestep)
748 | return latents, init_latents_orig, noise
749 |
750 | @torch.no_grad()
751 | def __call__(
752 | self,
753 | prompt: Union[str, List[str]],
754 | negative_prompt: Optional[Union[str, List[str]]] = None,
755 | image: Union[torch.FloatTensor, PIL.Image.Image] = None,
756 | mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
757 | height: int = 512,
758 | width: int = 512,
759 | num_inference_steps: int = 50,
760 | guidance_scale: float = 7.5,
761 | strength: float = 0.8,
762 | num_images_per_prompt: Optional[int] = 1,
763 | eta: float = 0.0,
764 | generator: Optional[torch.Generator] = None,
765 | latents: Optional[torch.FloatTensor] = None,
766 | max_embeddings_multiples: Optional[int] = 3,
767 | output_type: Optional[str] = 'pil',
768 | return_dict: bool = True,
769 | callback: Optional[Callable[[int, int, torch.FloatTensor],
770 | None]] = None,
771 | is_cancelled_callback: Optional[Callable[[], bool]] = None,
772 | callback_steps: int = 1,
773 | ):
774 | r"""
775 | Function invoked when calling the pipeline for generation.
776 |
777 | Args:
778 | prompt (`str` or `List[str]`):
779 | The prompt or prompts to guide the image generation.
780 | negative_prompt (`str` or `List[str]`, *optional*):
781 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
782 | if `guidance_scale` is less than `1`).
783 | image (`torch.FloatTensor` or `PIL.Image.Image`):
784 | `Image`, or tensor representing an image batch, that will be used as the starting point for the
785 | process.
786 | mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
787 | `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
788 | replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
789 | PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
790 | contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
791 | height (`int`, *optional*, defaults to 512):
792 | The height in pixels of the generated image.
793 | width (`int`, *optional*, defaults to 512):
794 | The width in pixels of the generated image.
795 | num_inference_steps (`int`, *optional*, defaults to 50):
796 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
797 | expense of slower inference.
798 | guidance_scale (`float`, *optional*, defaults to 7.5):
799 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
800 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
801 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
802 |                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
803 | usually at the expense of lower image quality.
804 | strength (`float`, *optional*, defaults to 0.8):
805 | Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
806 | `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
807 | number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
808 | noise will be maximum and the denoising process will run for the full number of iterations specified in
809 | `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
810 | num_images_per_prompt (`int`, *optional*, defaults to 1):
811 | The number of images to generate per prompt.
812 | eta (`float`, *optional*, defaults to 0.0):
813 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
814 | [`schedulers.DDIMScheduler`], will be ignored for others.
815 | generator (`torch.Generator`, *optional*):
816 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
817 | deterministic.
818 | latents (`torch.FloatTensor`, *optional*):
819 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
820 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
821 |                 tensor will be generated by sampling using the supplied random `generator`.
822 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
823 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
824 | output_type (`str`, *optional*, defaults to `"pil"`):
825 |                 The output format of the generated image. Choose between
826 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
827 | return_dict (`bool`, *optional*, defaults to `True`):
828 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
829 | plain tuple.
830 | callback (`Callable`, *optional*):
831 | A function that will be called every `callback_steps` steps during inference. The function will be
832 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
833 | is_cancelled_callback (`Callable`, *optional*):
834 | A function that will be called every `callback_steps` steps during inference. If the function returns
835 | `True`, the inference will be cancelled.
836 | callback_steps (`int`, *optional*, defaults to 1):
837 | The frequency at which the `callback` function will be called. If not specified, the callback will be
838 | called at every step.
839 |
840 | Returns:
841 | `None` if cancelled by `is_cancelled_callback`,
842 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
843 |             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
844 | When returning a tuple, the first element is a list with the generated images, and the second element is a
845 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
846 | (nsfw) content, according to the `safety_checker`.
847 | """
848 | # 0. Default height and width to unet
849 | height = height or self.unet.config.sample_size * self.vae_scale_factor
850 | width = width or self.unet.config.sample_size * self.vae_scale_factor
851 |
852 | # 1. Check inputs. Raise error if not correct
853 | self.check_inputs(prompt, height, width, strength, callback_steps)
854 |
855 | # 2. Define call parameters
856 | batch_size = 1 if isinstance(prompt, str) else len(prompt)
857 | device = self._execution_device
858 |         # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
859 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
860 | # corresponds to doing no classifier free guidance.
861 | do_classifier_free_guidance = guidance_scale > 1.0
862 |
863 | # 3. Encode input prompt
864 | text_embeddings = self._encode_prompt(
865 | prompt,
866 | device,
867 | num_images_per_prompt,
868 | do_classifier_free_guidance,
869 | negative_prompt,
870 | max_embeddings_multiples,
871 | )
872 | dtype = text_embeddings.dtype
873 |
874 | # 4. Preprocess image and mask
875 | if isinstance(image, PIL.Image.Image):
876 | image = preprocess_image(image)
877 | if image is not None:
878 |             image = image.to(device=device, dtype=dtype)
879 | if isinstance(mask_image, PIL.Image.Image):
880 | mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
881 | if mask_image is not None:
882 |             mask = mask_image.to(device=device, dtype=dtype)
883 | mask = torch.cat([mask] * batch_size * num_images_per_prompt)
884 | else:
885 | mask = None
886 |
887 | # 5. set timesteps
888 | self.scheduler.set_timesteps(num_inference_steps, device=device)
889 | timesteps, num_inference_steps = self.get_timesteps(
890 | num_inference_steps, strength, device, image is None)
891 | latent_timestep = timesteps[:1].repeat(batch_size *
892 | num_images_per_prompt)
893 |
894 | # 6. Prepare latent variables
895 | latents, init_latents_orig, noise = self.prepare_latents(
896 | image,
897 | latent_timestep,
898 | batch_size * num_images_per_prompt,
899 | height,
900 | width,
901 | dtype,
902 | device,
903 | generator,
904 | latents,
905 | )
906 |
907 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
908 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
909 |
910 | # 8. Denoising loop
911 | for i, t in enumerate(self.progress_bar(timesteps)):
912 | # expand the latents if we are doing classifier free guidance
913 | latent_model_input = torch.cat(
914 | [latents] * 2) if do_classifier_free_guidance else latents
915 | latent_model_input = self.scheduler.scale_model_input(
916 | latent_model_input, t)
917 |
918 | # predict the noise residual
919 | noise_pred = self.unet(
920 | latent_model_input, t,
921 | encoder_hidden_states=text_embeddings).sample
922 |
923 | # perform guidance
924 | if do_classifier_free_guidance:
925 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
926 | noise_pred = noise_pred_uncond + guidance_scale * (
927 | noise_pred_text - noise_pred_uncond)
928 |
929 | # compute the previous noisy sample x_t -> x_t-1
930 | latents = self.scheduler.step(noise_pred, t, latents,
931 | **extra_step_kwargs).prev_sample
932 |
933 | if mask is not None:
934 |                 # inpainting: keep the preserved (black-masked) regions pinned to the noised original latents
935 | init_latents_proper = self.scheduler.add_noise(
936 | init_latents_orig, noise, torch.tensor([t]))
937 | latents = (init_latents_proper * mask) + (latents * (1 - mask))
938 |
939 | # call the callback, if provided
940 | if i % callback_steps == 0:
941 | if callback is not None:
942 | callback(i, t, latents)
943 |                 if (is_cancelled_callback is not None
944 |                         and is_cancelled_callback()):
945 | return None
946 |
947 | # 9. Post-processing
948 | image = self.decode_latents(latents)
949 |
950 | # 10. Run safety checker
951 | image, has_nsfw_concept = self.run_safety_checker(
952 | image, device, text_embeddings.dtype)
953 |
954 | # 11. Convert to PIL
955 | if output_type == 'pil':
956 | image = self.numpy_to_pil(image)
957 |
958 | if not return_dict:
959 | return image, has_nsfw_concept
960 |
961 | return StableDiffusionPipelineOutput(
962 | images=image, nsfw_content_detected=has_nsfw_concept)
963 |
964 | def text2img(
965 | self,
966 | prompt: Union[str, List[str]],
967 | negative_prompt: Optional[Union[str, List[str]]] = None,
968 | height: int = 512,
969 | width: int = 512,
970 | num_inference_steps: int = 50,
971 | guidance_scale: float = 7.5,
972 | num_images_per_prompt: Optional[int] = 1,
973 | eta: float = 0.0,
974 | generator: Optional[torch.Generator] = None,
975 | latents: Optional[torch.FloatTensor] = None,
976 | max_embeddings_multiples: Optional[int] = 3,
977 | output_type: Optional[str] = 'pil',
978 | return_dict: bool = True,
979 | callback: Optional[Callable[[int, int, torch.FloatTensor],
980 | None]] = None,
981 | is_cancelled_callback: Optional[Callable[[], bool]] = None,
982 | callback_steps: int = 1,
983 | ):
984 | r"""
985 | Function for text-to-image generation.
986 | Args:
987 | prompt (`str` or `List[str]`):
988 | The prompt or prompts to guide the image generation.
989 | negative_prompt (`str` or `List[str]`, *optional*):
990 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
991 | if `guidance_scale` is less than `1`).
992 | height (`int`, *optional*, defaults to 512):
993 | The height in pixels of the generated image.
994 | width (`int`, *optional*, defaults to 512):
995 | The width in pixels of the generated image.
996 | num_inference_steps (`int`, *optional*, defaults to 50):
997 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
998 | expense of slower inference.
999 | guidance_scale (`float`, *optional*, defaults to 7.5):
1000 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1001 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
1002 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1003 |                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1004 | usually at the expense of lower image quality.
1005 | num_images_per_prompt (`int`, *optional*, defaults to 1):
1006 | The number of images to generate per prompt.
1007 | eta (`float`, *optional*, defaults to 0.0):
1008 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1009 | [`schedulers.DDIMScheduler`], will be ignored for others.
1010 | generator (`torch.Generator`, *optional*):
1011 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1012 | deterministic.
1013 | latents (`torch.FloatTensor`, *optional*):
1014 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1015 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1016 |                 tensor will be generated by sampling using the supplied random `generator`.
1017 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1018 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
1019 | output_type (`str`, *optional*, defaults to `"pil"`):
1020 |                 The output format of the generated image. Choose between
1021 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1022 | return_dict (`bool`, *optional*, defaults to `True`):
1023 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1024 | plain tuple.
1025 | callback (`Callable`, *optional*):
1026 | A function that will be called every `callback_steps` steps during inference. The function will be
1027 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1028 | is_cancelled_callback (`Callable`, *optional*):
1029 | A function that will be called every `callback_steps` steps during inference. If the function returns
1030 | `True`, the inference will be cancelled.
1031 | callback_steps (`int`, *optional*, defaults to 1):
1032 | The frequency at which the `callback` function will be called. If not specified, the callback will be
1033 | called at every step.
1034 | Returns:
1035 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1036 |             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1037 | When returning a tuple, the first element is a list with the generated images, and the second element is a
1038 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1039 | (nsfw) content, according to the `safety_checker`.
1040 | """
1041 | return self.__call__(
1042 | prompt=prompt,
1043 | negative_prompt=negative_prompt,
1044 | height=height,
1045 | width=width,
1046 | num_inference_steps=num_inference_steps,
1047 | guidance_scale=guidance_scale,
1048 | num_images_per_prompt=num_images_per_prompt,
1049 | eta=eta,
1050 | generator=generator,
1051 | latents=latents,
1052 | max_embeddings_multiples=max_embeddings_multiples,
1053 | output_type=output_type,
1054 | return_dict=return_dict,
1055 | callback=callback,
1056 | is_cancelled_callback=is_cancelled_callback,
1057 | callback_steps=callback_steps,
1058 | )
1059 |
1060 | def img2img(
1061 | self,
1062 | image: Union[torch.FloatTensor, PIL.Image.Image],
1063 | prompt: Union[str, List[str]],
1064 | negative_prompt: Optional[Union[str, List[str]]] = None,
1065 | strength: float = 0.8,
1066 | num_inference_steps: Optional[int] = 50,
1067 | guidance_scale: Optional[float] = 7.5,
1068 | num_images_per_prompt: Optional[int] = 1,
1069 | eta: Optional[float] = 0.0,
1070 | generator: Optional[torch.Generator] = None,
1071 | max_embeddings_multiples: Optional[int] = 3,
1072 | output_type: Optional[str] = 'pil',
1073 | return_dict: bool = True,
1074 | callback: Optional[Callable[[int, int, torch.FloatTensor],
1075 | None]] = None,
1076 | is_cancelled_callback: Optional[Callable[[], bool]] = None,
1077 | callback_steps: int = 1,
1078 | ):
1079 | r"""
1080 | Function for image-to-image generation.
1081 | Args:
1082 | image (`torch.FloatTensor` or `PIL.Image.Image`):
1083 | `Image`, or tensor representing an image batch, that will be used as the starting point for the
1084 | process.
1085 | prompt (`str` or `List[str]`):
1086 | The prompt or prompts to guide the image generation.
1087 | negative_prompt (`str` or `List[str]`, *optional*):
1088 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1089 | if `guidance_scale` is less than `1`).
1090 | strength (`float`, *optional*, defaults to 0.8):
1091 | Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
1092 | `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
1093 | number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
1094 | noise will be maximum and the denoising process will run for the full number of iterations specified in
1095 | `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
1096 | num_inference_steps (`int`, *optional*, defaults to 50):
1097 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1098 | expense of slower inference. This parameter will be modulated by `strength`.
1099 | guidance_scale (`float`, *optional*, defaults to 7.5):
1100 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1101 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
1102 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1103 |                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1104 | usually at the expense of lower image quality.
1105 | num_images_per_prompt (`int`, *optional*, defaults to 1):
1106 | The number of images to generate per prompt.
1107 | eta (`float`, *optional*, defaults to 0.0):
1108 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1109 | [`schedulers.DDIMScheduler`], will be ignored for others.
1110 | generator (`torch.Generator`, *optional*):
1111 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1112 | deterministic.
1113 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1114 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
1115 | output_type (`str`, *optional*, defaults to `"pil"`):
1116 |                 The output format of the generated image. Choose between
1117 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1118 | return_dict (`bool`, *optional*, defaults to `True`):
1119 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1120 | plain tuple.
1121 | callback (`Callable`, *optional*):
1122 | A function that will be called every `callback_steps` steps during inference. The function will be
1123 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1124 | is_cancelled_callback (`Callable`, *optional*):
1125 | A function that will be called every `callback_steps` steps during inference. If the function returns
1126 | `True`, the inference will be cancelled.
1127 | callback_steps (`int`, *optional*, defaults to 1):
1128 | The frequency at which the `callback` function will be called. If not specified, the callback will be
1129 | called at every step.
1130 | Returns:
1131 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1132 |             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1133 | When returning a tuple, the first element is a list with the generated images, and the second element is a
1134 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1135 | (nsfw) content, according to the `safety_checker`.
1136 | """
1137 | return self.__call__(
1138 | prompt=prompt,
1139 | negative_prompt=negative_prompt,
1140 | image=image,
1141 | num_inference_steps=num_inference_steps,
1142 | guidance_scale=guidance_scale,
1143 | strength=strength,
1144 | num_images_per_prompt=num_images_per_prompt,
1145 | eta=eta,
1146 | generator=generator,
1147 | max_embeddings_multiples=max_embeddings_multiples,
1148 | output_type=output_type,
1149 | return_dict=return_dict,
1150 | callback=callback,
1151 | is_cancelled_callback=is_cancelled_callback,
1152 | callback_steps=callback_steps,
1153 | )
1154 |
1155 | def inpaint(
1156 | self,
1157 | image: Union[torch.FloatTensor, PIL.Image.Image],
1158 | mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1159 | prompt: Union[str, List[str]],
1160 | negative_prompt: Optional[Union[str, List[str]]] = None,
1161 | strength: float = 0.8,
1162 | num_inference_steps: Optional[int] = 50,
1163 | guidance_scale: Optional[float] = 7.5,
1164 | num_images_per_prompt: Optional[int] = 1,
1165 | eta: Optional[float] = 0.0,
1166 | generator: Optional[torch.Generator] = None,
1167 | max_embeddings_multiples: Optional[int] = 3,
1168 | output_type: Optional[str] = 'pil',
1169 | return_dict: bool = True,
1170 | callback: Optional[Callable[[int, int, torch.FloatTensor],
1171 | None]] = None,
1172 | is_cancelled_callback: Optional[Callable[[], bool]] = None,
1173 | callback_steps: int = 1,
1174 | ):
1175 | r"""
1176 |         Function for inpainting.
1177 | Args:
1178 | image (`torch.FloatTensor` or `PIL.Image.Image`):
1179 | `Image`, or tensor representing an image batch, that will be used as the starting point for the
1180 | process. This is the image whose masked region will be inpainted.
1181 | mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1182 | `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1183 | replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1184 | PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1185 | contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1186 | prompt (`str` or `List[str]`):
1187 | The prompt or prompts to guide the image generation.
1188 | negative_prompt (`str` or `List[str]`, *optional*):
1189 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1190 | if `guidance_scale` is less than `1`).
1191 | strength (`float`, *optional*, defaults to 0.8):
1192 | Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1193 | is 1, the denoising process will be run on the masked area for the full number of iterations specified
1194 | in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1195 | noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1196 | num_inference_steps (`int`, *optional*, defaults to 50):
1197 | The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1198 | the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1199 | guidance_scale (`float`, *optional*, defaults to 7.5):
1200 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1201 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
1202 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1203 |                 1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
1204 | usually at the expense of lower image quality.
1205 | num_images_per_prompt (`int`, *optional*, defaults to 1):
1206 | The number of images to generate per prompt.
1207 | eta (`float`, *optional*, defaults to 0.0):
1208 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1209 | [`schedulers.DDIMScheduler`], will be ignored for others.
1210 | generator (`torch.Generator`, *optional*):
1211 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1212 | deterministic.
1213 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1214 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
1215 | output_type (`str`, *optional*, defaults to `"pil"`):
1216 |                 The output format of the generated image. Choose between
1217 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1218 | return_dict (`bool`, *optional*, defaults to `True`):
1219 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1220 | plain tuple.
1221 | callback (`Callable`, *optional*):
1222 | A function that will be called every `callback_steps` steps during inference. The function will be
1223 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1224 | is_cancelled_callback (`Callable`, *optional*):
1225 | A function that will be called every `callback_steps` steps during inference. If the function returns
1226 | `True`, the inference will be cancelled.
1227 | callback_steps (`int`, *optional*, defaults to 1):
1228 | The frequency at which the `callback` function will be called. If not specified, the callback will be
1229 | called at every step.
1230 | Returns:
1231 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1232 |             [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1233 | When returning a tuple, the first element is a list with the generated images, and the second element is a
1234 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1235 | (nsfw) content, according to the `safety_checker`.
1236 | """
1237 | return self.__call__(
1238 | prompt=prompt,
1239 | negative_prompt=negative_prompt,
1240 | image=image,
1241 | mask_image=mask_image,
1242 | num_inference_steps=num_inference_steps,
1243 | guidance_scale=guidance_scale,
1244 | strength=strength,
1245 | num_images_per_prompt=num_images_per_prompt,
1246 | eta=eta,
1247 | generator=generator,
1248 | max_embeddings_multiples=max_embeddings_multiples,
1249 | output_type=output_type,
1250 | return_dict=return_dict,
1251 | callback=callback,
1252 | is_cancelled_callback=is_cancelled_callback,
1253 | callback_steps=callback_steps,
1254 | )
1255 |
--------------------------------------------------------------------------------