├── scripts
│   ├── lint.sh
│   └── ci_test.sh
├── assets
│   ├── app.jpg
│   ├── arch.png
│   ├── init.jpg
│   ├── create.png
│   ├── model.png
│   ├── success.png
│   ├── arch_ctr.png
│   ├── arch_opt.png
│   ├── arch_tran.png
│   ├── pipeline.png
│   └── prepare_model.png
├── diffusers
│   ├── tests
│   │   ├── __pycache__
│   │   │   └── ut_config.cpython-38.pyc
│   │   ├── ut_config.py
│   │   ├── run.py
│   │   └── test_utils
│   │       ├── test_convert.py
│   │       ├── test_image_process.py
│   │       ├── test_io.py
│   │       └── test_blade.py
│   ├── example
│   │   ├── async_example_post.py
│   │   ├── async_example_get.py
│   │   ├── sync_example_base.py
│   │   └── sync_example_control.py
│   ├── Dockerfile
│   ├── ev_error.py
│   ├── utils
│   │   ├── convert.py
│   │   ├── io.py
│   │   ├── image_process.py
│   │   └── blade.py
│   └── lpw_stable_diffusion.py
├── LICENSE
├── doc
│   ├── app.md
│   ├── deploy.md
│   └── param.md
└── README.md

/scripts/lint.sh:
--------------------------------------------------------------------------------
 1 | yapf -r -i ${1}
 2 | isort -rc ${1}
 3 | flake8 ${1}
--------------------------------------------------------------------------------
/assets/app.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/app.jpg
--------------------------------------------------------------------------------
/assets/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/arch.png
--------------------------------------------------------------------------------
/assets/init.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/init.jpg
--------------------------------------------------------------------------------
/assets/create.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/create.png
--------------------------------------------------------------------------------
/assets/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/model.png
--------------------------------------------------------------------------------
/assets/success.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/success.png
--------------------------------------------------------------------------------
/assets/arch_ctr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/arch_ctr.png
--------------------------------------------------------------------------------
/assets/arch_opt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/arch_opt.png
--------------------------------------------------------------------------------
/assets/arch_tran.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/arch_tran.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/pipeline.png
-------------------------------------------------------------------------------- /assets/prepare_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/assets/prepare_model.png -------------------------------------------------------------------------------- /diffusers/tests/__pycache__/ut_config.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/alibaba/diffusers-api/HEAD/diffusers/tests/__pycache__/ut_config.cpython-38.pyc -------------------------------------------------------------------------------- /diffusers/example/async_example_post.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import json 4 | 5 | from eas_prediction import PredictClient, StringRequest 6 | 7 | if __name__ == '__main__': 8 | client = PredictClient('http://xxx.cn-hangzhou.pai-eas.aliyuncs.com/', 9 | 'service_name') 10 | client.set_token( 11 | 'xxx') 12 | 13 | client.init() 14 | 15 | datas = json.dumps({ 16 | 'task_id': 'async', 17 | 'prompt': '一个可爱的女孩', 18 | 'steps': 100, 19 | 'image_num': 3, 20 | 'width': 512, 21 | 'height': 512, 22 | 'seed': '123', 23 | }) 24 | 25 | request = StringRequest(datas) 26 | 27 | for x in range(0, 1): 28 | resp = client.predict(request) 29 | print(resp) 30 | -------------------------------------------------------------------------------- /diffusers/tests/ut_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | UT_ROOT = '/mnt/xinyi.zxy/diffuser/ut_test/' 4 | MODEL_DIR = os.path.join(UT_ROOT, 'models/test_model') 5 | BASE_MODEL_PATH = os.path.join(MODEL_DIR, 'base_model') 6 | CONTROLNET_MODEL_PATH = os.path.join(MODEL_DIR, 'controlnet') 7 | LORA_PATH = os.path.join(MODEL_DIR, 'lora_model/animeoutlineV4_16.safetensors') 8 | LORA_PATH_BIN = os.path.join(MODEL_DIR, 'lora_model/pytorch_lora_weights.bin') 9 | 10 | MODEL_DIR_NEW = os.path.join(UT_ROOT, 'models/new_model') 11 | CKPT_PATH = os.path.join(MODEL_DIR_NEW, 'colorful_v26.safetensors') 12 | 13 | PRETRAIN_DIR = os.path.join(UT_ROOT, 'models/pretrained_models') 14 | SAVE_DIR = os.path.join(UT_ROOT, 'results') 15 | IMAGE_DIR = os.path.join(UT_ROOT, 'images') 16 | 17 | CUSTOM_PIPELINE = './lpw_stable_diffusion.py' 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Alibaba 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/diffusers/example/async_example_get.py:
--------------------------------------------------------------------------------
 1 | import json
 2 |
 3 | from eas_prediction import QueueClient
 4 |
 5 | if __name__ == '__main__':
 6 |     # Create the output (sink) queue object, used to subscribe to and read the result data.
 7 |
 8 |     sink_queue = QueueClient(
 9 |         'http://xxx.cn-hangzhou.pai-eas.aliyuncs.com',
10 |         'service_name/sink')
11 |     sink_queue.set_token(
12 |         'xxx')
13 |
14 |     sink_queue.init()
15 |
16 |     # Watch data from the output queue with a window size of 1.
17 |     i = 0
18 |     watcher = sink_queue.watch(0, 1, auto_commit=False)
19 |     for x in watcher.run():
20 |         data = x.data.decode('utf-8')
21 |         data = json.loads(data)
22 |         print(data.keys())
23 |         if data['success']:
24 |             print(data['image_url'])
25 |             print(data['oss_url'])
26 |             print(data['task_id'])
27 |             print(data['use_blade'])
28 |             print(data['seed'])
29 |             print(data['is_nsfw'])
30 |         else:
31 |             print(data['error_msg'])
32 |         # Commit manually after each received item has been processed.
33 |         sink_queue.commit(x.index)
34 |         i += 1
35 |         if i == 10:
36 |             break
37 |
38 |     # Close the opened watcher. Each client instance allows only one watcher object; if it is not closed, running again will raise an error.
39 |     watcher.close()
40 |
--------------------------------------------------------------------------------
/scripts/ci_test.sh:
--------------------------------------------------------------------------------
 1 | #================================================================
 2 | # Copyright (C) 2022 Alibaba Ltd. All rights reserved.
 3 | #
 4 | #================================================================
 5 |
 6 | cd ${1}
 7 |
 8 | # pre-commit run --all-files
 9 | # if [ $? -ne 0 ]; then
10 | # echo "linter test failed, please run 'pre-commit run --all-files' to check"
11 | # exit -1
12 | # fi
13 |
14 | mkdir -p logs
15 |
16 | # coverage UT and report
17 | PYTHONPATH=. coverage run tests/run.py | tee logs/ci_test.log
18 | PYTHONPATH=. coverage report | tee logs/ci_report.log
19 | # PYTHONPATH=.
coverage html 20 | 21 | 22 | # please add key requirements you think is worth record 23 | 24 | echo "" | tee >> logs/ci_report.log 25 | echo "" | tee >> logs/ci_report.log 26 | echo "Requirements Version 27 | -----------------------------------------------------------" | tee >> logs/ci_report.log 28 | key_requirements=("^torch" "easycv" "easynlp" "easyretrieval"\ 29 | "blade" "mmcv" "tokenizer") 30 | pip list >> logs/envlistcache.txt 31 | 32 | for val1 in ${key_requirements[*]}; do 33 | grep $val1 logs/envlistcache.txt | tee >> logs/ci_report.log 34 | done 35 | echo "-----------------------------------------------------------" | tee >> logs/ci_report.log 36 | 37 | rm logs/envlistcache.txt 38 | -------------------------------------------------------------------------------- /doc/app.md: -------------------------------------------------------------------------------- 1 | ## EAS 自定义processor开发 2 | 3 | 以Python SDK为例,本文档介绍 基于PAI-EAS的自定义processor开发,其核心在于维护主文件[app.py](../diffusers/app.py)。 4 | 5 | ### 模版介绍 6 | 7 | 您需要继承PAI-EAS提供的基类BaseProcessor,实现**initialize()**和**process()**函数。其**process()**函数的输入输出均为BYTES类型,输出参数分别为**response_data**和**status_code**,正常请求**status_code**可以返回**0**或**200**。 8 | 9 | | 函数 | 功能描述 | 参数描述 | 10 | | -------------------------------------------------------- | ------------------------------------------------------------ | ------------------------------------------------------------ | 11 | | init(worker_threads=5, worker_processes=1,endpoint=None) | Processor构建函数。 | **worker_threads**:Worker线程数,默认值为5。**worker_processes**:进程数,默认值为1。如果**worker_processes**为1,则表示单进程多线程模式。如果**worker_processes**大于1,则**worker_threads**只负责读取数据,请求由多进程并发处理,每个进程均会执行**initialize()**函数。**endpoint**:服务监听的Endpoint,通过该参数可以指定服务监听的地址和端口,例如**endpoint=’0.0.0.0:8079’**。 | 12 | | initialize() | Processor初始化函数。服务启动时,进行模型加载等初始化工作。 | 无参数。 | 13 | | process(data) | 请求处理函数。每个请求会将Request Body作为参数传递给**process()**进行处理,并将函数返回值返回至客户端。 | **data**为Request Body,类型为BYTES。返回值也为BYTES类型。 | 14 | | run() | 启动服务。 | 无参数。 | 15 | 16 | 17 | 18 | ### 二次开发 19 | 20 | 我们在[app.py](../diffusers/app.py)已经实现了部分基于diffusers api的功能实现,本节通过流程图的显示,对核心代码流程进行展示,方便您进行二次开发。 21 | 22 | #### initialize() 23 | 24 | 25 | #### process(data) 26 | 27 | ![img](../assets/app.jpg) 28 | -------------------------------------------------------------------------------- /diffusers/tests/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import unittest 5 | from fnmatch import fnmatch 6 | 7 | 8 | def gather_test_cases(test_dir, pattern, list_tests): 9 | case_list = [] 10 | dir_list = [] 11 | for dirpath, dirnames, filenames in os.walk(test_dir): 12 | for file in filenames: 13 | if fnmatch(file, pattern): 14 | case_list.append(file) 15 | dir_list.append(dirpath) 16 | test_suite = unittest.TestSuite() 17 | 18 | for dirname, case in zip(dir_list, case_list): 19 | test_case = unittest.defaultTestLoader.\ 20 | discover(start_dir=dirname, pattern=case) 21 | test_suite.addTest(test_case) 22 | if hasattr(test_case, '__iter__'): 23 | for subcase in test_case: 24 | if list_tests: 25 | print(subcase) 26 | else: 27 | if list_tests: 28 | print(test_case) 29 | 30 | return test_suite 31 | 32 | 33 | def main(args): 34 | 35 | runner = unittest.TextTestRunner() 36 | test_suite = gather_test_cases(os.path.abspath(args.test_dir), 37 | args.pattern, args.list_tests) 38 | if not args.list_tests: 39 | result = runner.run(test_suite) 40 | if len(result.failures) > 0 or len(result.errors) > 0: 41 | sys.exit(1) 
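# Example invocations (a sketch; it assumes the working directory is the diffusers/ project
# root, which is the directory scripts/ci_test.sh cds into before calling this runner):
#   PYTHONPATH=. python tests/run.py                         # run every tests/**/test_*.py
#   PYTHONPATH=. python tests/run.py --list_tests            # only print the discovered cases
#   PYTHONPATH=. python tests/run.py --pattern 'test_io.py'  # restrict discovery to one file
# main() exits with status 1 when any test fails or errors, so callers can key off the exit code.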
42 | 43 | 44 | if __name__ == '__main__': 45 | parser = argparse.ArgumentParser('test runner') 46 | parser.add_argument('--list_tests', 47 | action='store_true', 48 | help='list all tests') 49 | parser.add_argument('--pattern', 50 | default='test_*.py', 51 | help='test file pattern') 52 | parser.add_argument('--test_dir', 53 | default='tests', 54 | help='directory to be tested') 55 | 56 | args = parser.parse_args() 57 | main(args) 58 | -------------------------------------------------------------------------------- /diffusers/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM bladedisc/bladedisc:latest-devel-cu113 2 | ENV BLADE_GEMM_TUNE_JIT=1 DISC_ENABLE_PREDEFINED_PDL=true DISC_ENABLE_PACK_QKV=true 3 | RUN pip config set global.index-url https://mirrors.aliyun.com/pypi/simple 4 | 5 | RUN pip install https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/pytorch/wheels/torch-1.12.0%2Bcu113-cp38-cp38-linux_x86_64.whl 6 | 7 | RUN pip install https://pai-blade.oss-cn-zhangjiakou.aliyuncs.com/temp/xformers-0.0.17%2B658ebab.d20230327-cp38-cp38-linux_x86_64.whl &&\ 8 | pip install transformers &&\ 9 | pip install opencv-python-headless &&\ 10 | pip install diffusers==0.15.0 &&\ 11 | pip install -U http://eas-data.oss-cn-shanghai.aliyuncs.com/sdk/allspark-0.15-py2.py3-none-any.whl &&\ 12 | pip install https://pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com/zxy/diffusers/torch_blade-0.0.1%2B1.12.0.cu113-cp38-cp38-linux_x86_64.whl &&\ 13 | pip install safetensors &&\ 14 | pip install modelscope &&\ 15 | pip install subword_nmt &&\ 16 | pip install jieba &&\ 17 | pip install sacremoses &&\ 18 | pip install tensorflow &&\ 19 | pip install omegaconf 20 | RUN pip install scikit-image 21 | RUN pip install https://pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com/zxy/diffusers/controlnet_aux-0.0.3-py3-none-any.whl --no-deps 22 | RUN pip install torchvision==0.13.0 23 | RUN pip install timm 24 | RUN pip install mediapipe 25 | RUN pip cache purge 26 | 27 | RUN apt-get install wget 28 | RUN mkdir /home/pai/ 29 | 30 | ADD ./app.py /home/pai/app.py 31 | ADD ./utils /home/pai/utils 32 | ADD ./ev_error.py /home/pai/ev_error.py 33 | ADD ./lpw_stable_diffusion.py /home/pai/lpw_stable_diffusion.py 34 | 35 | RUN wget https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/image/optimized_model.tar.gz \ 36 | && tar -xvf optimized_model.tar.gz \ 37 | && mv optimized_model /home/pai/optimized_model \ 38 | && rm optimized_model.tar.gz 39 | 40 | RUN wget https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/image/pretrained_models.tar.gz \ 41 | && tar -xvf pretrained_models.tar.gz \ 42 | && mv pretrained_models /home/pai/pretrained_models \ 43 | && rm pretrained_models.tar.gz 44 | -------------------------------------------------------------------------------- /diffusers/tests/test_utils/test_convert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | 5 | import torch 6 | from diffusers import StableDiffusionPipeline 7 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import \ 8 | download_from_original_stable_diffusion_ckpt 9 | from tests.ut_config import BASE_MODEL_PATH, CKPT_PATH, LORA_PATH 10 | from utils.convert import (convert_base_model_to_diffuser, 11 | convert_lora_safetensor_to_bin, convert_name_to_bin) 12 | 13 | 14 | class TestModelConvert(unittest.TestCase): 15 | def test_convert_lora_safetensor_to_bin(self): 16 | # 
for .safetensors to load by ori diffuser api (only attn in unet will be loaded) 17 | with tempfile.TemporaryDirectory() as temp_dir: 18 | bin_path = LORA_PATH.replace('.safetensors', '.bin') 19 | bin_path = os.path.join(temp_dir, 'bin_model.pth') 20 | convert_lora_safetensor_to_bin(LORA_PATH, bin_path) 21 | self.assertTrue(os.path.exists(bin_path)) 22 | 23 | # Load the converted lora model 24 | pipe = StableDiffusionPipeline.from_pretrained( 25 | BASE_MODEL_PATH, 26 | revision='fp16', 27 | torch_dtype=torch.float16, 28 | safety_checker=None) 29 | pipe.unet.load_attn_procs(bin_path, use_safetensors=False) 30 | 31 | def test_convert_base_model_to_diffuser(self): 32 | # convert .safetensors to multiple dirs 33 | from_safetensors = True 34 | 35 | with tempfile.TemporaryDirectory() as temp_dir: 36 | convert_base_model_to_diffuser(CKPT_PATH, temp_dir, 37 | from_safetensors) 38 | files = os.listdir(temp_dir) 39 | print(files) 40 | need_file_list = [ 41 | 'feature_extractor', 'model_index.json', 'safety_checker', 42 | 'scheduler', 'text_encoder', 'tokenizer', 'unet', 'vae' 43 | ] 44 | self.assertTrue(set(need_file_list).issubset(set(files))) 45 | 46 | # Load the converted base model 47 | pipe = StableDiffusionPipeline.from_pretrained( 48 | temp_dir, 49 | revision='fp16', 50 | torch_dtype=torch.float16, 51 | safety_checker=None) 52 | 53 | 54 | if __name__ == '__main__': 55 | unittest.main() 56 | -------------------------------------------------------------------------------- /diffusers/tests/test_utils/test_image_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | import torch 8 | # need to be import or an malloc error will be occured by controlnet_aux 9 | from diffusers import ControlNetModel, StableDiffusionControlNetPipeline 10 | from tests.ut_config import IMAGE_DIR, PRETRAIN_DIR, SAVE_DIR 11 | from utils.image_process import (generate_mask_and_img_expand, 12 | preprocess_control, transform_image) 13 | 14 | 15 | class TestImageProcessing(unittest.TestCase): 16 | def setUp(self): 17 | # image for testing 18 | img_path = os.path.join(IMAGE_DIR, 'room.png') 19 | self.image = Image.open(img_path).convert('RGB') 20 | 21 | def test_preprocess_control(self): 22 | # Test with valid process_func 23 | process_func_list = [ 24 | 'canny', 'depth', 'hed', 'mlsd', 'normal', 'openpose', 'scribble', 25 | 'seg' 26 | ] 27 | for process_func in process_func_list: 28 | print('Process: {}'.format(process_func)) 29 | processed_image = preprocess_control(self.image, process_func, 30 | PRETRAIN_DIR) 31 | self.assertIsInstance(processed_image, Image.Image) 32 | processed_image.save( 33 | os.path.join(SAVE_DIR, 'pre_{}.jpg'.format(process_func))) 34 | 35 | # Test with an invalid process_func 36 | error_message = preprocess_control(self.image, 'invalid_func', 37 | PRETRAIN_DIR) 38 | self.assertIsInstance(error_message, str) 39 | 40 | def test_transform_image(self): 41 | test_params = {0: 'Stretch', 1: 'Crop', 2: 'Pad'} 42 | 43 | expected_sizes = [(1024, 1024), (768, 1024), (1024, 768)] 44 | 45 | for expected_size in expected_sizes: 46 | width, height = expected_size 47 | for mode, mode_type in test_params.items(): 48 | print('Process: {}, width: {}, height: {}'.format( 49 | mode, width, height)) 50 | transformed_image = transform_image(self.image, width, height, 51 | mode) 52 | self.assertEqual(transformed_image.size, (width, height)) 53 | transformed_image.save( 54 | 
os.path.join(SAVE_DIR, '{}.jpg'.format(mode_type))) 55 | 56 | def test_generate_mask_and_img_expand(self): 57 | expand = (10, 20, 30, 40) 58 | 59 | left, right, up, down = expand 60 | width, height = self.image.size 61 | new_width, new_height = width + left + right, height + up + down 62 | 63 | expand_list = ['copy', 'reflect'] 64 | for expand_type in expand_list: 65 | expanded_image, mask = generate_mask_and_img_expand( 66 | self.image, expand, expand_type) 67 | self.assertEqual(expanded_image.size, (new_width, new_height)) 68 | self.assertEqual(mask.size, (new_width, new_height)) 69 | 70 | expanded_image.save( 71 | os.path.join(SAVE_DIR, 72 | 'expanded_image_{}.jpg'.format(expand_type))) 73 | mask.save(os.path.join(SAVE_DIR, 74 | 'mask_{}.jpg'.format(expand_type))) 75 | 76 | 77 | if __name__ == '__main__': 78 | unittest.main() 79 | -------------------------------------------------------------------------------- /diffusers/ev_error.py: -------------------------------------------------------------------------------- 1 | class Error(object): 2 | """A class for representing errors that occur in the program. 3 | 4 | Attributes: 5 | _code (int): The error code. 6 | _msg_prefix (str): The prefix for the error message. 7 | _msg (str): The error message. 8 | """ 9 | def __init__(self, code, msg_prefix): 10 | """Initialize a new Error instance. 11 | 12 | Args: 13 | code (int): The error code. 14 | msg_prefix (str): The prefix for the error message. 15 | """ 16 | self._code = code 17 | self._msg_prefix = msg_prefix 18 | self._msg = '' 19 | 20 | @property 21 | def code(self): 22 | """int: The error code.""" 23 | return self._code 24 | 25 | @code.setter 26 | def code(self, code): 27 | """Set the error code. 28 | 29 | Args: 30 | code (int): The new error code. 31 | """ 32 | self._code = code 33 | 34 | @property 35 | def msg(self): 36 | """str: The full error message, including the prefix and message.""" 37 | return self._msg_prefix + ' - ' + self._msg 38 | 39 | @msg.setter 40 | def msg(self, msg): 41 | """Set the error message. 42 | 43 | Args: 44 | msg (str): The new error message. 45 | """ 46 | self._msg = msg 47 | 48 | 49 | class JsonParseError(Error): 50 | """A class for representing JSON parse errors that occur in the program. 51 | 52 | Args: 53 | msg (str, optional): The error message. Defaults to an empty string. 54 | """ 55 | def __init__(self, msg=''): 56 | """Initialize a new JsonParseError instance. 57 | 58 | Args: 59 | msg (str, optional): The error message. Defaults to an empty 60 | string. 61 | """ 62 | super(JsonParseError, self).__init__(460, 'Json Parse Error') 63 | self.msg = msg 64 | 65 | 66 | class InputFormatError(Error): 67 | """A class for representing input format errors that occur in the program. 68 | 69 | Args: 70 | msg (str, optional): The error message. Defaults to an empty string. 71 | """ 72 | def __init__(self, msg=''): 73 | """Initialize a new InputFormatError instance. 74 | 75 | Args: 76 | msg (str, optional): The error message. Defaults to an empty 77 | string. 78 | """ 79 | super(InputFormatError, self).__init__(460, 'Input Body Format Error') 80 | self.msg = msg 81 | 82 | 83 | class ImageDecodeError(Error): 84 | """A class for representing image decode errors that occur in the program. 85 | 86 | Args: 87 | msg (str, optional): The error message. Defaults to an empty string. 88 | """ 89 | def __init__(self, msg=''): 90 | """Initialize a new ImageDecodeError instance. 91 | 92 | Args: 93 | msg (str, optional): The error message. Defaults to an empty 94 | string. 
95 | """ 96 | super(ImageDecodeError, self).__init__(462, 'Image Decode Error') 97 | self.msg = msg 98 | 99 | 100 | class UnExpectedServerError(Error): 101 | """A class for representing unexpected server errors that occur in 102 | the program. 103 | 104 | Args: 105 | msg (str, optional): The error message. Defaults to an empty string. 106 | """ 107 | def __init__(self, msg=''): 108 | """Initialize a new UnExpectedServerError instance. 109 | 110 | Args: 111 | msg (str, optional): The error message. Defaults to an 112 | empty string. 113 | """ 114 | super(UnExpectedServerError, 115 | self).__init__(469, 'Unexpected \ 116 | Server Error') 117 | self.msg = msg 118 | -------------------------------------------------------------------------------- /diffusers/tests/test_utils/test_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | from PIL import Image 5 | 6 | import torch 7 | from tests.ut_config import (BASE_MODEL_PATH, CONTROLNET_MODEL_PATH, 8 | CUSTOM_PIPELINE, IMAGE_DIR, PRETRAIN_DIR, 9 | SAVE_DIR) 10 | from utils.image_process import preprocess_control 11 | from utils.io import load_diffusers_pipeline 12 | 13 | 14 | class TestLoadDiffusersPipeline(unittest.TestCase): 15 | def setUp(self): 16 | # hyper parameters 17 | self.prompt = 'a dog' 18 | img_path = os.path.join(IMAGE_DIR, 'image.png') 19 | mask_path = os.path.join(IMAGE_DIR, 'mask.png') 20 | 21 | self.image = Image.open(img_path).convert('RGB') 22 | self.mask = Image.open(mask_path).convert('RGB') 23 | self.num_inference_steps = 20 24 | self.num_images_per_prompt = 1 25 | 26 | self.device = torch.device( 27 | 'cuda') if torch.cuda.is_available() else torch.device('cpu') 28 | 29 | def test_base_mode(self): 30 | mode = 'base' 31 | close_safety = False 32 | 33 | pipe = load_diffusers_pipeline(BASE_MODEL_PATH, None, None, 34 | self.device, mode, close_safety, 35 | CUSTOM_PIPELINE) 36 | 37 | self.assertIsNotNone(pipe) 38 | 39 | # t2i 40 | with torch.no_grad(): 41 | res = pipe.text2img( 42 | prompt=self.prompt, 43 | num_inference_steps=self.num_inference_steps, 44 | num_images_per_prompt=self.num_images_per_prompt) 45 | image = res.images[0] 46 | self.assertIsInstance(image, Image.Image) 47 | image.save(os.path.join(SAVE_DIR, 't2i.jpg')) 48 | 49 | # i2i 50 | with torch.no_grad(): 51 | res = pipe.img2img( 52 | prompt=self.prompt, 53 | image=self.image, 54 | num_inference_steps=self.num_inference_steps, 55 | num_images_per_prompt=self.num_images_per_prompt) 56 | image = res.images[0] 57 | self.assertIsInstance(image, Image.Image) 58 | image.save(os.path.join(SAVE_DIR, 'i2i.jpg')) 59 | 60 | # inpaint 61 | with torch.no_grad(): 62 | res = pipe.inpaint( 63 | prompt=self.prompt, 64 | image=self.image, 65 | mask_image=self.mask, 66 | num_inference_steps=self.num_inference_steps, 67 | num_images_per_prompt=self.num_images_per_prompt) 68 | image = res.images[0] 69 | self.assertIsInstance(image, Image.Image) 70 | image.save(os.path.join(SAVE_DIR, 'inpaint.jpg')) 71 | 72 | def test_controlnet_mode(self): 73 | mode = 'controlnet' 74 | close_safety = False 75 | 76 | pipe = load_diffusers_pipeline(BASE_MODEL_PATH, None, 77 | CONTROLNET_MODEL_PATH, self.device, 78 | mode, close_safety, CUSTOM_PIPELINE) 79 | 80 | self.assertIsNotNone(pipe) 81 | 82 | with torch.no_grad(): 83 | process_image = preprocess_control(self.image, 'canny', 84 | PRETRAIN_DIR) 85 | res = pipe(prompt=self.prompt, 86 | image=process_image, 87 | num_inference_steps=self.num_inference_steps, 88 | 
num_images_per_prompt=self.num_images_per_prompt) 89 | image = res.images[0] 90 | self.assertIsInstance(image, Image.Image) 91 | image.save(os.path.join(SAVE_DIR, 'control.jpg')) 92 | 93 | def test_invalid_mode(self): 94 | with self.assertRaises(ValueError): 95 | mode = 'invalid' 96 | close_safety = False 97 | load_diffusers_pipeline(BASE_MODEL_PATH, None, None, self.device, 98 | mode, close_safety, CUSTOM_PIPELINE) 99 | 100 | 101 | if __name__ == '__main__': 102 | unittest.main() 103 | -------------------------------------------------------------------------------- /doc/deploy.md: -------------------------------------------------------------------------------- 1 | ### 部署文档 2 | 3 | 通过[PAI-EAS](https://help.aliyun.com/document_detail/113696.html?spm=a2c4g.113696.0.0.2c421af2NqTtmW),您可以将模型快速部署为RESTful API,再通过HTTP请求的方式调用该服务。PAI-EAS提供弹性扩缩容、版本控制及资源监控等功能,便于将模型服务应用于业务。 4 | 5 | 本文档介绍了如何在PAI-EAS,使用给定镜像部署基于diffusers的服务。我们使用自定义镜像部署+oss挂载的方式进行服务部署。在二次开发后,您可修改镜像并参考本文档部署自定义的服务。 6 | 7 | #### 前提条件 8 | 9 | 在开始执行操作前,请确认您已完成以下准备工作。 10 | 11 | - 已开通PAI(DSW、EAS)后付费,并创建默认工作空间,具体操作,请参见[开通并创建默认工作空间](https://help.aliyun.com/document_detail/326190.htm#task-2121596)。 12 | - 已创建OSS存储空间(Bucket),用于存储数据集、训练获得的模型文件和配置文件。关于如何创建存储空间,详情请参见[创建存储空间](https://help.aliyun.com/document_detail/31885.htm#task-u3p-3n4-tdb)。 13 | 14 | 15 | #### 步骤一:模型准备 16 | 17 | 准备好所需加载的模型文件,并上传至自己的oss bucket。 18 | 19 | - 将主模型所需文件均存在指定oss bucket的base_model文件夹中【必须】 20 | - ControlNet相关模型放在controlnet文件夹中【可选】 21 | - 讲预置的翻译模型放入translate文件夹中【可选】 22 | 23 | 24 | 25 | - base_model 文件夹支持下列格式: 26 | 27 | - diffusers api支持的多文件夹格式 (单个的.safetensors/.ckpt文件将自动转化) 28 | 29 | ![img](../assets/arch.png) 30 | 31 | 32 | - controlnet文件夹需包含下列文件【使用controlnet时必须】 33 | 34 | 35 | 36 | - optimized_model文件夹需包含以下三个/四个文件(controlnet会生成额外的controlnet.pt)。在部署服务时打开--use_blade开关,Blade模型将自动在后台优化,优化完成后自动进行模型替换 37 | 38 | 39 | 40 | - translate文件夹,您可通过下方链接下载模型文件内置翻译模型。【可选】 41 | 42 | - 模型下载地址:https://www.modelscope.cn/models/damo/nlp_csanmt_translation_zh2en/files 43 | - 参考下载链接:https://pai-vision-exp.oss-cn-zhangjiakou.aliyuncs.com/zxy/model/damo_translate.tar.gz 44 | 45 | 46 | 47 | 48 | #### 步骤二:模型部署 49 | 50 | 在模型准备完成后,可利用EAS进行服务的部署。您可以根据需要及时调整部署命令,以部署不同类型的服务。 51 | 52 | 具体地,您需要进行执行以下步骤: 53 | 54 | - **Step1:创建EAS服务** 55 | 56 | ![img](../assets/create.png) 57 | 58 | - **Step2:选择镜像部署并修改相应的参数** 59 | 60 | - ”对应配置编辑“处,复制下列配置文件,并修改相应的参数 61 | 62 | ```json 63 | { 64 | "cloud": { 65 | "computing": { 66 | "instance_type": "ml.gu7i.c8m30.1-gu30" 67 | } 68 | }, 69 | "containers": [ 70 | { 71 | "command": "python /home/pai/app.py --func_name base --oss_save_dir oss://converter-offline-installer/zxy/diffusers/0605/ --region hangzhou --use_blade --use_translate", 72 | "image": "eas-registry-vpc.cn-hangzhou.cr.aliyuncs.com/pai-eas/diffuser-inference:2.2.1-py38-cu113-unbuntu2004-blade-public" 73 | } 74 | ], 75 | "features": { 76 | "eas.aliyun.com/extra-ephemeral-storage": "100Gi" 77 | }, 78 | "metadata": { 79 | "cpu": 8, 80 | "gpu": 1, 81 | "instance": 1, 82 | "memory": 30000, 83 | "name": "diffuser_base_ch_async", 84 | "rpc": { 85 | "keepalive": 500000 86 | }, 87 | "type": "Async" 88 | }, 89 | "name": "diffuser_base_ch_async", 90 | "storage": [ 91 | { 92 | "mount_path": "/oss", 93 | "oss": { 94 | "path": "oss://converter-offline-installer/zxy/diffusers/0605/base/" 95 | } 96 | }, 97 | { 98 | "mount_path": "/result", 99 | "oss": { 100 | "path": "oss://converter-offline-installer/zxy/diffusers/0605/" 101 | } 102 | } 103 | ] 104 | } 105 | ``` 106 | 107 | - containers.command 命令 108 | 109 | | key | value | 说明 | 
110 | | --------------- | --------------------- | ------------------------------------ | 111 | | --func_name | base | 同时支持t2i/i2i/inpaint/outpaint功能 | 112 | | | controlnet | 进行基于ControlNet的图像编辑 | 113 | | --oss_save_dir | oss://xxx | 保存图片的oss路径 | 114 | | --region | hangzhou/shanghai/... | 服务部署所在区域的拼音全称 | 115 | | --use_blade | — | 启用Blade推理优化功能 | 116 | | --use_translate | — | 加载翻译模型支持中文prompt输入 | 117 | 118 | - containers.image:替换杭州为服务部署所在的reigon 119 | 120 | - metadata.name / name 修改为任意的服务名称 121 | 122 | - metadata.type 设置为 Async表示 部署异步服务 (同步服务无需添加该项) 123 | 124 | - storage 处 添加 模型和生成结果的oss路径挂载 125 | 126 | - mount_path 不可修改 / 或修改镜像中的对应代码 127 | 128 | | mount_path | oss | 说明 | 129 | | ---------- | --------- | -------------------- | 130 | | /oss | oss://xxx | 所挂载的模型路径 | 131 | | /result | oss://xxx | 所挂载的生成图片位置 | 132 | 133 | - 点击“部署”按钮,等待服务部署完成即可。 134 | 135 | ![img](../assets/success.png) 136 | 137 | - 点击上图的调用信息获得 测试所需的: 138 | 139 | ```json 140 | hosts = 'xxx' 141 | head = { 142 | "Authorization": "xxx" 143 | } 144 | ``` 145 | -------------------------------------------------------------------------------- /diffusers/utils/convert.py: -------------------------------------------------------------------------------- 1 | """ 2 | convert differnt model type to the standard diffuser type 3 | """ 4 | 5 | import torch 6 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import \ 7 | download_from_original_stable_diffusion_ckpt 8 | from safetensors.torch import load_file 9 | 10 | LORA_PREFIX_UNET = 'lora_unet' 11 | 12 | 13 | def convert_name_to_bin(name: str) -> str: 14 | """ 15 | Convert a name to binary format. 16 | 17 | Args: 18 | name (str): Name to be converted. 19 | 20 | Returns: 21 | str: Converted name in binary format. 22 | """ 23 | 24 | # down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_up 25 | new_name = name.replace(LORA_PREFIX_UNET + '_', '') 26 | new_name = new_name.replace('.weight', '') 27 | 28 | # ['down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q', 'lora.up'] 29 | parts = new_name.split('.') 30 | 31 | #parts[0] = parts[0].replace('_0', '') 32 | if 'out' in parts[0]: 33 | parts[0] = '_'.join(parts[0].split('_')[:-1]) 34 | parts[1] = parts[1].replace('_', '.') 35 | 36 | # ['down', 'blocks', '0', 'attentions', '0', 'transformer', 'blocks', '0', 'attn1', 'to', 'q'] 37 | # ['mid', 'block', 'attentions', '0', 'transformer', 'blocks', '0', 'attn2', 'to', 'out'] 38 | sub_parts = parts[0].split('_') 39 | 40 | # down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q_ 41 | new_sub_parts = '' 42 | for i in range(len(sub_parts)): 43 | if sub_parts[i] in [ 44 | 'block', 'blocks', 'attentions' 45 | ] or sub_parts[i].isnumeric() or 'attn' in sub_parts[i]: 46 | if 'attn' in sub_parts[i]: 47 | new_sub_parts += sub_parts[i] + '.processor.' 48 | else: 49 | new_sub_parts += sub_parts[i] + '.' 50 | else: 51 | new_sub_parts += sub_parts[i] + '_' 52 | 53 | # down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.up 54 | new_sub_parts += parts[1] 55 | 56 | new_name = new_sub_parts + '.weight' 57 | 58 | return new_name 59 | 60 | 61 | def convert_lora_safetensor_to_bin(safetensor_path: str, 62 | bin_path: str) -> None: 63 | """ 64 | Convert LoRA safetensor file to binary format and save it. (only the attn parameters will be saved) 65 | 66 | Args: 67 | safetensor_path (str): Path to the safetensor file. 68 | bin_path (str): Path to save the binary file. 
69 | """ 70 | 71 | bin_state_dict = {} 72 | safetensors_state_dict = load_file(safetensor_path) 73 | 74 | for key_safetensors in safetensors_state_dict: 75 | # these if are required by current diffusers' API 76 | # remove these may have negative effect as not all LoRAs are used 77 | if 'text' in key_safetensors: 78 | continue 79 | if 'unet' not in key_safetensors: 80 | continue 81 | if 'transformer_blocks' not in key_safetensors: 82 | continue 83 | if 'ff_net' in key_safetensors or 'alpha' in key_safetensors: 84 | continue 85 | key_bin = convert_name_to_bin(key_safetensors) 86 | bin_state_dict[key_bin] = safetensors_state_dict[key_safetensors] 87 | 88 | torch.save(bin_state_dict, bin_path) 89 | 90 | 91 | def convert_base_model_to_diffuser(checkpoint_path: str, 92 | target_path: str, 93 | from_safetensors: bool = False, 94 | save_half: bool = False, 95 | controlnet: str = None, 96 | to_safetensors: bool = False) -> None: 97 | """ 98 | Convert base model to diffuser format and save it. 99 | 100 | Args: 101 | checkpoint_path (str): Path to the checkpoint file. 102 | target_path (str): Path to save the diffuser model. 103 | from_safetensors (bool, optional): Flag indicating whether to load from safetensors. 104 | save_half (bool, optional): Flag indicating whether to save the model in half precision. 105 | controlnet (str, optional): Controlnet model path. 106 | to_safetensors (bool, optional): Flag indicating whether to serialize in safetensors format. 107 | """ 108 | 109 | pipe = download_from_original_stable_diffusion_ckpt( 110 | checkpoint_path=checkpoint_path, 111 | from_safetensors=from_safetensors, 112 | controlnet=controlnet) 113 | 114 | if save_half: 115 | pipe.to(torch_dtype=torch.float16) 116 | 117 | if controlnet: 118 | # only save the controlnet model 119 | pipe.controlnet.save_pretrained(target_path, 120 | safe_serialization=to_safetensors) 121 | else: 122 | pipe.save_pretrained(target_path, safe_serialization=to_safetensors) 123 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Service for diffusers api 2 | 3 | ## 项目概述 4 | 5 | 本项目提供了基于PAI-EAS的[diffusers api](https://github.com/huggingface/diffusers) 云服务实现。 6 | 7 | ![img](./assets/pipeline.png) 8 | 9 | 10 | 经PAI-Blade优化后,显著提升模型推理性能: 11 | (A10 显卡) 12 | | **image size** | **sample steps** | **Time of Pytorch(s)** | **Time of PAI-Blade(s)** | **speedup** | **Pytorch memory(GB)** | **PAI-Blade memory(GB)** | 13 | | -------------- | ---------------- | ---------------------- | ------------------------ | ----------- | ---------------------- | ------------------------ | 14 | | 1024x1024 | 50 | OOM | 14.36 | - | OOM | 5.98 | 15 | | 768x768 | 50 | 14.88 | 6.45 | 2.31X | 15.35 | 5.77 | 16 | | 512x512 | 50 | 6.14 | 2.68 | 2.29X | 6.98 | 5.44 | 17 | 18 | 19 | 20 | ### 功能概览 21 | 22 | 您可以参考本项目实现: 23 | 24 | - 直接基于本项目提供镜像,在PAI-EAS上部署已支持功能的服务。已支持的主流功能如下: 25 | 26 | (在服务部署时通过指定--func_name 进行不同类型的服务部署) 27 | 28 | - base服务 (通过请求时的func_name参数指定) 29 | - 文生图t2i/图生图i2i/图像inpaint/图像outpaint 30 | - controlnet服务 31 | - 基于ControlNet的端到端图像编辑 32 | - ControlNet的在线切换 33 | - LoRA模型的添加/修改 34 | - post时 通过指定lora_path,lora_attn 使用LoRA模型 35 | 36 | - 参考项目源码,快速进行二次开发,部署实现任意的diffusers api服务。 37 | 38 | 39 | 40 | ### Features 41 | 42 | - 基于PAI-Blade 进行推理优化 43 | 44 | - 降低 Text2Img、Img2Img 推理流程的端到端延迟 2.3 倍,同时可显著降低显存占用,超过TensorRT-v8.5等业内SOTA优化手段 45 | 46 | - 扩展和兼容diffusers API和Web UI,以适配社区可下载的模型 47 | 48 | - 
支持多个LoRA模型的融合,及sd-script训练的LoRA模型(civitai等第三方网站下载的模型),LoRA模型的在线切换 49 | 50 | - 基于PAI-EAS,提供异步推理及弹性调度能力 51 | 52 | - 内置翻译模型,支持中/英文prompt输入 53 | 54 | - 简单的API实现,方便进行二次开发 55 | 56 | 57 | ### WebUI 58 | ⚠️ 项目提供了EAS上部署SD model服务的自定义实践,对SD-WebUI具有一定的兼容性,相比SD-WebUI的api更适合二次开发。 59 | 60 | 您亦可参考[WebUI使用文档](https://alidocs.dingtalk.com/i/nodes/R1zknDm0WR6XzZ4Lt1aQewElWBQEx5rG),在PAI-EAS部署基于SDWEBUI的前/后端服务。 61 | 62 | PAI-SDWEBUI解决方案为您提供了: 63 | 64 | - 快捷部署,开箱即用 65 | - 获取阿里云账号后,5分钟即可完成单机SD WEBUI部署,获取webui 网址,具备与PC机使用完全一致的体验 66 | - 预置常用插件,热门开源模型下载加速 67 | - 底层资源动态切换,根据需求随时置换GU30、A10、A100等多种GPU 68 | - 企业级功能 69 | - 前后端分离改造,支持多用户对多GPU卡的集群调度(如10个用户共用3张卡) 70 | - 支持按阿里云子账号识别用户身份,隔离使用的模型、输出图片,但共享使用的webui网址、gpu算力 71 | - 支持账单按工作室、集群等方式拆分 72 | - 插件及优化 73 | - 集成PAI-blade性能优化工具,启用后,图片生成速度较原生WEBUI 有2-3倍的提升,较启用xformer优化有20%-60%的提升,且对模型效果无损 74 | - 提供filebrowser插件,支持用户在PC电脑上传下载云端模型、图片 75 | - 提供自研modelzoo插件,支持开源模型下载加速,模型维度的prompt、图片结果、参数的管理功能。 76 | - 支持企业用户使用时,插件集中化管理、用户个性化自定义使用两种形式 77 | 78 | 79 | 80 | ## 快速开始 81 | 82 | ### Step1: 环境搭建。 83 | 84 | 您可使用预置镜像或通过[Dockerfile](./diffusers/Dockerfile)自行搭建。(目前Blade推理优化仅支持A10/A100显卡,及PAI上推出的GU系列显卡) 85 | 86 | ```bash 87 | # EAS部署时 根据您所部署服务的region的不同 region可替换为hangzhou/shanghai等 该镜像下载自带加速 服务部署快 88 | eas-registry-vpc.cn-{region}.cr.aliyuncs.com/pai-eas/diffuser-inference:2.2.1-py38-cu113-unbuntu2004-blade-public 89 | 90 | # 公开镜像 91 | registry.cn-shanghai.aliyuncs.com/pai-ai-test/eas-service:2.2.1-py38-cu113-unbuntu2004-blade-public 92 | ``` 93 | 94 | ### Step2: 服务开发/部署。 95 | 96 | - 自定义processor开发 97 | 98 | - 您可修改预测服务主文件[app.py](./diffusers/app.py) 进行自定义processor的开发。[这里](./doc/app.md)我们简单梳理本服务的流程,便于您进行二次开发。 99 | 100 | - 更多PAI-EAS的自定义processor开发请参考[官方文档](https://help.aliyun.com/document_detail/143418.html?spm=a2c4g.130248.0.0.3c316f27SLZN0o)。 101 | 102 | 103 | - 服务部署 104 | 105 | 106 | - 以提供的app.py 为例,请参考 [部署文档](./doc/deploy.md) 在PAI-EAS 进行服务部署。部署完成后,您可以进行服务调用。 107 | 108 | 109 | - 本地服务部署测试可直接运行 (更多命令参数请参考 [部署文档](./doc/deploy.md) 的详细说明) 110 | 111 | ```bash 112 | # 使用 --local_debug 开启本地测试 113 | # 本地测试时 --model_dir 指定模型位置 --save_dir 指定出图位置 (EAS部署时 通过json挂载指定,无需额外输入) 114 | # 本地测试时 --oss_save_dir --region参数用于生成保存的oss链接,可忽略(任意输入) 115 | # --func_name base/controlnet 116 | # 本地调试请加入 export BLADE_AUTH_USE_COUNTING=1 我们提供了一定次数的试用权限。您可再PAI上无限次使用PAI-Blade推理优化。 117 | python /home/pai/app.py --func_name base --oss_save_dir oss://xxx --region hangzhou --model_dir=your_path_to_model --save_dir=your_path_to_output --use_blade --local_debug 118 | ``` 119 | 120 | 121 | 122 | ### Step3: 服务调用。 123 | 124 | **接口定义**:[param.md](./doc/param.md) 125 | 126 | **同步调用:** 127 | 128 | base服务调用:[sync_example_base.py](./diffusers/example/sync_example_base.py) 129 | 130 | controlnet服务调用:[sync_example_control.py](./diffusers/example/sync_example_control.py) 131 | 132 | **异步调用:** 133 | 134 | 对于aigc而言,通常推理时间较长,同步返回结果易超时。 135 | 136 | 可部署EAS异步服务,调用时分别 发送请求 和 订阅结果。 137 | 138 | - 服务部署时,设置 139 | 140 | ```json 141 | "metadata": { 142 | "type": "Async" 143 | }, 144 | ``` 145 | 146 | 以python SDK为例,可参考以下脚本,更多EAS异步推理SDK见:[参考文档](https://help.aliyun.com/document_detail/446942.html) 147 | 148 | 发送请求:[async_example_post.py](./diffusers/example/async_example_post.py) 149 | 150 | 订阅结果:[async_example_get.py](./diffusers/example/async_example_get.py) 151 | 152 | ⚠️:异步推理暂时不支持使用base64进行图像返回。 153 | -------------------------------------------------------------------------------- /diffusers/example/sync_example_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | post example when deploy the service name 
as base 3 | set --use_transalte and upload the translate model to post chinese prompt 4 | translate model: https://www.modelscope.cn/models/damo/nlp_csanmt_translation_zh2en/summary 5 | """ 6 | 7 | import base64 8 | import json 9 | import os 10 | import sys 11 | from io import BytesIO 12 | 13 | import requests 14 | from PIL import Image, PngImagePlugin 15 | 16 | ENCODING = 'utf-8' 17 | 18 | hosts = 'http://xxx.cn-hangzhou.pai-eas.aliyuncs.com/api/predict/service_name' 19 | head = { 20 | 'Authorization': 'xxx' 21 | } 22 | 23 | func_list = ['t2i','i2i','inpaint','outpaint'] 24 | 25 | def decode_base64(image_base64, save_file): 26 | img = Image.open(BytesIO(base64.urlsafe_b64decode(image_base64))) 27 | img.save(save_file) 28 | 29 | 30 | def select_data(func_name): 31 | if func_name == 't2i': 32 | datas = json.dumps({ 33 | 'task_id': 34 | func_name, 35 | 'prompt': 36 | '一只可爱的小猫', 37 | 'func_name': 38 | func_name, # or default is t2i 39 | 'negative_prompt': 40 | 'NSFW', 41 | 'steps': 42 | 50, 43 | 'image_num': 44 | 1, 45 | 'width': 46 | 512, 47 | 'height': 48 | 512, 49 | 'lora_path': 50 | ['lora/animeLineartMangaLike_v30MangaLike.safetensors'], 51 | 'lora_attn': 52 | 1.0 53 | }) 54 | elif func_name == 'i2i': 55 | datas = json.dumps({ 56 | 'image_link': 57 | 'https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/image.png', 58 | # 'image_base64': base64.b64encode(open('/mnt/xinyi.zxy/diffuser/models/bosi2/result/a001_20230602_075912_12.png', 'rb').read()).decode(ENCODING), 59 | 'task_id': 60 | func_name, 61 | 'prompt': 62 | 'a cat', 63 | 'func_name': 64 | func_name, 65 | 'negative_prompt': 66 | 'NSFW', 67 | 'steps': 68 | 50, 69 | 'image_num': 70 | 1, 71 | 'width': 72 | 512, 73 | 'height': 74 | 512, 75 | 'lora_path': 76 | ['lora/animeLineartMangaLike_v30MangaLike.safetensors'] 77 | }) 78 | elif func_name == 'inpaint': 79 | datas = json.dumps({ 80 | 'image_link': 81 | 'https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/image.png', 82 | 'mask_link': 83 | 'https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/mask.png', 84 | 'task_id': 85 | func_name, 86 | 'prompt': 87 | 'a cat', 88 | 'func_name': 89 | func_name, 90 | 'negative_prompt': 91 | 'NSFW', 92 | 'steps': 93 | 50, 94 | 'image_num': 95 | 1, 96 | 'width': 97 | 512, 98 | 'height': 99 | 512, 100 | 'lora_path': 101 | ['lora/animeLineartMangaLike_v30MangaLike.safetensors'] 102 | }) 103 | elif func_name == 'outpaint': 104 | datas = json.dumps({ 105 | 'image_link': 106 | 'https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/image.png', 107 | # 'image_base64': base64.b64encode(open('/mnt/xinyi.zxy/diffuser/models/bosi2/result/a001_20230602_075912_12.png', 'rb').read()).decode(ENCODING), 108 | 'task_id': func_name, 109 | 'prompt': 'a cat', 110 | 'func_name': func_name, 111 | 'negative_prompt': 'NSFW', 112 | 'steps': 50, 113 | 'image_num': 1, 114 | 'width': 512, 115 | 'height': 512, 116 | 'expand': [256, 256, 0, 0], # [left,right,up,down] 117 | 'expand_type': 'copy', # or 'reflect', 118 | 'denoising_strength': 0.6 119 | # 'lora_path': ['path_to_your_lora_model'] 120 | }) 121 | else: 122 | raise ValueError('Invalid process_func value') 123 | 124 | return datas 125 | 126 | 127 | for func_name in func_list: 128 | datas = select_data(func_name) 129 | 130 | r = requests.post(hosts, data=datas, headers=head) 131 | # r = requests.post("http://0.0.0.0:8000/test", data=datas, timeout=1500) 132 | 133 | data = json.loads(r.content.decode('utf-8')) 134 | print(data.keys()) 135 | 136 | if 
data['success']: 137 | print(data['image_url']) 138 | print(data['oss_url']) 139 | print(data['task_id']) 140 | print(data['use_blade']) 141 | print(data['seed']) 142 | print(data['is_nsfw']) 143 | if 'images_base64' in data.keys(): 144 | for i, image_base64 in enumerate(data['images_base64']): 145 | decode_base64(image_base64, './result_{}.png'.format(str(i))) 146 | 147 | else: 148 | print(data['error_msg']) 149 | -------------------------------------------------------------------------------- /diffusers/utils/io.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Dict, Optional, Tuple, Union 4 | 5 | import requests 6 | from PIL import Image 7 | 8 | import torch 9 | from diffusers import (ControlNetModel, DiffusionPipeline, 10 | DPMSolverMultistepScheduler, 11 | StableDiffusionControlNetPipeline, 12 | StableDiffusionPipeline) 13 | 14 | 15 | def load_diffusers_pipeline( 16 | model_base: str, 17 | lora_path: Optional[str], 18 | controlnet_path: Optional[str], 19 | device: str, 20 | mode: str = 'base', 21 | close_safety: bool = False, 22 | custom_pipeline: str = '/home/pai/lpw_stable_diffusion.py' 23 | ) -> Union[DiffusionPipeline, StableDiffusionControlNetPipeline]: 24 | """ 25 | Loads a DiffusionPipeline or StableDiffusionControlNetPipeline with a LoRA checkpoint, 26 | based on the specified mode of operation. 27 | 28 | Args: 29 | model_base (str): The path to the base model checkpoint 30 | lora_path (str, optional): The path to the LoRA checkpoint 31 | controlnet_path (str, optional): The path to the controlnet checkpoint (if mode='controlnet') 32 | device (str): The device where the pipeline will run (e.g. 'cpu' or 'cuda') 33 | mode (str): The mode of operation ('base', or 'controlnet') 34 | close_safety (bool): Whether to disable safety checks in the pipeline 35 | custom_pipeline (str): The path to a custom pipeline script (if any) 36 | 37 | Returns: 38 | Union[DiffusionPipeline, StableDiffusionControlNetPipeline]: 39 | A DiffusionPipeline (LPW) or StableDiffusionControlNetPipeline object with a LoRA checkpoint loaded. 40 | """ 41 | if mode == 'base': 42 | if close_safety: 43 | pipe = DiffusionPipeline.from_pretrained( 44 | model_base, 45 | custom_pipeline=custom_pipeline, 46 | torch_dtype=torch.float16, 47 | safety_checker=None) 48 | else: 49 | pipe = DiffusionPipeline.from_pretrained( 50 | model_base, 51 | custom_pipeline=custom_pipeline, 52 | torch_dtype=torch.float16, 53 | ) 54 | 55 | elif mode == 'controlnet': 56 | controlnet = ControlNetModel.from_pretrained(controlnet_path, 57 | torch_dtype=torch.float16) 58 | 59 | if close_safety: 60 | pipe = StableDiffusionControlNetPipeline.from_pretrained( 61 | model_base, 62 | controlnet=controlnet, 63 | revision='fp16', 64 | torch_dtype=torch.float16, 65 | safety_checker=None) 66 | else: 67 | pipe = StableDiffusionControlNetPipeline.from_pretrained( 68 | model_base, 69 | controlnet=controlnet, 70 | revision='fp16', 71 | torch_dtype=torch.float16) 72 | else: 73 | raise ValueError( 74 | 'Unrecognized function name: {}. We support base(t2i)/controlnet'. 75 | format(mode)) 76 | 77 | pipe.to(device) 78 | 79 | if lora_path is not None: 80 | pipe.unet.load_attn_procs(lora_path, use_safetensors=False) 81 | 82 | return pipe 83 | 84 | 85 | def download_image(image_link: str) -> Image.Image: 86 | """ 87 | Download an image from the given image_link and return it as a PIL Image object. 88 | 89 | Args: 90 | image_link (str): The URL of the image to download. 
91 | 92 | Returns: 93 | Image.Image: The downloaded image as a PIL Image object. 94 | """ 95 | response = requests.get(image_link) 96 | image_name = image_link.split('/')[-1] 97 | with open(image_name, 'ab') as f: 98 | f.write(response.content) 99 | f.flush() 100 | img = Image.open(image_name).convert('RGB') 101 | 102 | return img 103 | 104 | 105 | def get_result_str(result_dict: Optional[Dict[str, Union[int, str]]] = None, 106 | error: Optional[Exception] = None) -> Tuple[bytes, int]: 107 | """ 108 | Generates a result string in JSON format based on the provided result dictionary and error. 109 | 110 | Args: 111 | result_dict (Optional[Dict[str, Union[int, str]]]): A dictionary containing the result information. 112 | error (Optional[Exception]): An error object representing any occurred error. 113 | 114 | Returns: 115 | Tuple[bytes, int]: A tuple containing the result string encoded in UTF-8 and the HTTP status code. 116 | 117 | """ 118 | result = {} 119 | 120 | if error is not None: 121 | result['success'] = 0 122 | result['error_code'] = error.code 123 | result['error_msg'] = error.msg[:200] 124 | stat = error.code 125 | 126 | if result_dict is not None and 'task_id' in result_dict.keys(): 127 | result['task_id'] = result_dict['task_id'] 128 | 129 | elif result_dict is not None: 130 | result['success'] = 1 131 | result.update(result_dict) 132 | stat = 200 133 | 134 | result_str = json.dumps(result).encode('utf-8') 135 | 136 | return result_str, stat 137 | -------------------------------------------------------------------------------- /diffusers/tests/test_utils/test_blade.py: -------------------------------------------------------------------------------- 1 | import os 2 | import unittest 3 | 4 | from PIL import Image 5 | 6 | import torch 7 | import torch_blade 8 | from tests.ut_config import (BASE_MODEL_PATH, CONTROLNET_MODEL_PATH, 9 | CUSTOM_PIPELINE, IMAGE_DIR, MODEL_DIR, 10 | PRETRAIN_DIR, SAVE_DIR) 11 | from utils.blade import load_blade_model, optimize_and_save_blade_model 12 | from utils.image_process import preprocess_control 13 | from utils.io import load_diffusers_pipeline 14 | 15 | # important!or the blade result will be incorrect 16 | os.environ['DISC_ENABLE_DOT_MERGE'] = '0' 17 | 18 | 19 | class TestBladeOptimization(unittest.TestCase): 20 | def setUp(self): 21 | # hyper parameters 22 | self.prompt = 'a dog' 23 | img_path = os.path.join(IMAGE_DIR, 'image.png') 24 | mask_path = os.path.join(IMAGE_DIR, 'mask.png') 25 | self.image = Image.open(img_path).convert('RGB') 26 | self.mask = Image.open(mask_path).convert('RGB') 27 | self.num_inference_steps = 20 28 | self.num_images_per_prompt = 1 29 | self.device = torch.device( 30 | 'cuda') if torch.cuda.is_available() else torch.device('cpu') 31 | 32 | def test_optimize_and_save_blade_model_base(self): 33 | # save and optimize base model 34 | blade_dir = os.path.join(MODEL_DIR, 'optimized_model') 35 | os.makedirs(blade_dir, exist_ok=True) 36 | encoder_path = os.path.join(blade_dir, 'encoder.pt') 37 | unet_path = os.path.join(blade_dir, 'unet.pt') 38 | decoder_path = os.path.join(blade_dir, 'decoder.pt') 39 | controlnet_path = None 40 | mode = 'base' 41 | close_safety = False 42 | pipe = load_diffusers_pipeline(BASE_MODEL_PATH, None, None, 43 | self.device, mode, close_safety, 44 | CUSTOM_PIPELINE) 45 | optimize_and_save_blade_model(pipe, encoder_path, unet_path, 46 | decoder_path, controlnet_path) 47 | 48 | # save and optimize base model 49 | assert os.path.exists( 50 | encoder_path), f"Encoder path '{encoder_path}' does not 
exist!" 51 | assert os.path.exists( 52 | unet_path), f"UNet path '{unet_path}' does not exist!" 53 | assert os.path.exists( 54 | decoder_path), f"Decoder path '{decoder_path}' does not exist!" 55 | # load 56 | pipe = load_blade_model(pipe, encoder_path, unet_path, decoder_path, 57 | controlnet_path) 58 | with torch.no_grad(): 59 | res = pipe.text2img( 60 | prompt=self.prompt, 61 | num_inference_steps=self.num_inference_steps, 62 | num_images_per_prompt=self.num_images_per_prompt) 63 | image = res.images[0] 64 | self.assertIsInstance(image, Image.Image) 65 | image.save(os.path.join(SAVE_DIR, 't2i_blade.jpg')) 66 | 67 | def test_optimize_and_save_blade_model_controlnet(self): 68 | # save and optimize base model 69 | blade_dir = os.path.join(MODEL_DIR, 'optimized_control_model') 70 | os.makedirs(blade_dir, exist_ok=True) 71 | encoder_path = os.path.join(blade_dir, 'encoder.pt') 72 | unet_path = os.path.join(blade_dir, 'unet.pt') 73 | decoder_path = os.path.join(blade_dir, 'decoder.pt') 74 | controlnet_path = os.path.join(blade_dir, 'controlnet.pt') 75 | mode = 'controlnet' 76 | close_safety = False 77 | pipe = load_diffusers_pipeline(BASE_MODEL_PATH, None, 78 | CONTROLNET_MODEL_PATH, self.device, 79 | mode, close_safety, CUSTOM_PIPELINE) 80 | 81 | optimize_and_save_blade_model(pipe, encoder_path, unet_path, 82 | decoder_path, controlnet_path) 83 | 84 | # save and optimize base model 85 | assert os.path.exists( 86 | encoder_path), f"Encoder path '{encoder_path}' does not exist!" 87 | assert os.path.exists( 88 | unet_path), f"UNet path '{unet_path}' does not exist!" 89 | assert os.path.exists( 90 | decoder_path), f"Decoder path '{decoder_path}' does not exist!" 91 | assert os.path.exists( 92 | controlnet_path 93 | ), f"ControlNet path '{controlnet_path}' does not exist!" 
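        # A note on intent (an assumption based on how doc/deploy.md describes the
        # optimized_model folder): the encoder/unet/decoder/controlnet .pt files written above
        # are the artifacts a deployed service reuses, so load_blade_model below should not need
        # to repeat the Blade optimization pass once they exist on disk.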
94 | # load 95 | pipe = load_blade_model(pipe, encoder_path, unet_path, decoder_path, 96 | controlnet_path) 97 | with torch.no_grad(): 98 | process_image = preprocess_control(self.image, 'canny', 99 | PRETRAIN_DIR) 100 | res = pipe(prompt=self.prompt, 101 | image=process_image, 102 | num_inference_steps=self.num_inference_steps, 103 | num_images_per_prompt=self.num_images_per_prompt) 104 | image = res.images[0] 105 | self.assertIsInstance(image, Image.Image) 106 | image.save(os.path.join(SAVE_DIR, 'control_blade.jpg')) 107 | 108 | 109 | if __name__ == '__main__': 110 | unittest.main() 111 | -------------------------------------------------------------------------------- /doc/param.md: -------------------------------------------------------------------------------- 1 | ## 服务输入输出参数说明 2 | 3 | - post 输入参数 4 | 5 | | **参数名** | **说明** | **类型** | **默认值** | 6 | | ------------------ | ------------------------------------------------------------ | ------------------ | ---------------------------------- | 7 | | task_id | 任务ID | string | 必须 | 8 | | prompt | 用户输入的正向提示词 | string | 必须 | 9 | | func_name | post的功能部署服务为base时支持传入t2i/i2i/inpaint 进行功能转换 | string | t2i | 10 | | steps | 用户输入的步数 | int | 50 | 11 | | cfg_scale | guidance_scale | int | 7 | 12 | | denoising_strength | 与原图的合并比例【只在图生图中有效】 | float | 0.55 | 13 | | width | 生成图片宽度 | int | 512 | 14 | | height | 生成图片高度 | int | 512 | 15 | | negative_prompt | 用户输入的负向提示词 | string | “” | 16 | | image_num | 用户输入的图片数量 | int | 1 | 17 | | resize_mode | 调整生成图片缩放方式 0 拉伸 1 裁剪 2 填充 | int | 0 | 18 | | image_link | 用户输入的图片url地址 | string | 图生图,inpaint controlnet必须提供 | 19 | | mask_link | 用户输入的mask url 地址 | string | inpaint 必须提供 | 20 | | image_base64 | 用户输入的图片 base64格式 | base64 | 与image_link二选一 | 21 | | mask_base64 | 用户输入的mask base64格式 | base64 | 与mask_link二选一 | 22 | | use_base64 | 是否返回imagebase64的图像结果 | bool | False | 23 | | lora_attn | lora使用的比例当使用多LoRA融合时支持列表的输入 | floatList[float] | 0.75 | 24 | | lora_path | 需要更新的lora模型在oss挂载路径的相对位置使用多LoRA融合时支持列表的输入 | stringList[string] | 无 | 25 | | controlnet_path | 需要更新的controlnet模型在oss挂载路径的相对位置(huggingface 上可下载的safetensors/bin 文件) | string | 无 | 26 | | process_func | 图像预处理方式,用于生成controlnet的控制图像 | string | 具体支持的列表见下表 | 27 | | expand | outpaint时 各个方向需要填充的像素数[left,right,up,down] | list | 无 | 28 | | expand_type | 原始图像的扩充方式(影响outpaint的出图效果)copy(复制边缘)reflect(镜像翻转边缘) | string | copy | 29 | | save_dir | 传入文件夹的名字文件将存放在部署挂载的result路径中的save_dir文件夹中 | string | result | 30 | 31 | - controlnet支持列表(仅支持下表中的8种格式的端到端处理,对于其他controlnet 您可自行处理后用于控制生成的图像) 32 | 33 | | process_func | 实现功能 | controlnet参考下载地址 | 34 | | ------------ | -------- | ------------------------------------------------------------ | 35 | | canny | 边缘检测 | [边缘检测](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-canny/diffusion_pytorch_model.safetensors) | 36 | | depth | 深度检测 | [深度检测](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-depth/diffusion_pytorch_model.safetensors) | 37 | | hed | 线稿上色 | [线稿上色](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-hed/diffusion_pytorch_model.safetensors) | 38 | | mlsd | 线段识别 | 
[线段识别](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-mlsd/diffusion_pytorch_model.safetensors) | 39 | | normal | 物体识别 | [物体识别](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--fusing--stable-diffusion-v1-5-controlnet-normal/diffusion_pytorch_model.safetensors) | 40 | | openpose | 姿态识别 | [姿态识别](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-openpose/diffusion_pytorch_model.safetensors) | 41 | | scribble | 线稿上色 | [线稿上色](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-scribble/diffusion_pytorch_model.safetensors) | 42 | | seg | 语义分割 | [语义分割](https://converter-offline-installer.oss-cn-hangzhou.aliyuncs.com/zxy/diffusers/dbt/demo_controlnet/new_controlnet/models--lllyasviel--sd-controlnet-seg/diffusion_pytorch_model.safetensors) | 43 | 44 | - post 输出参数 45 | 46 | | **参数名** | **说明** | **类型** | 47 | | ------------- | ------------------------------------------------------------ | ---------- | 48 | | image_url | 生成图像的公网可访问链接 【在开放acl权限后有效】 | list | 49 | | images_base64 | 生成的图像列表 base64格式(use_base64开启时会返回) | list | 50 | | oss_url | 生成图像的oss地址 | list | 51 | | success | 是否成功 0-失败 1-成功 | int | 52 | | seed | 生成图像的种子 | string | 53 | | task_id | 任务ID | string | 54 | | error_msg | 错误的原因【只在success=0时返回错误】 | string | 55 | | use_blade | 是否使用了blade 进行推理优化【blade模型成功优化后,会在第一次推理时默认使用】 | bool | 56 | | is_nsfw | 用于表示生成图片是否不合法【True为黑图】 | list[bool] | 57 | -------------------------------------------------------------------------------- /diffusers/example/sync_example_control.py: -------------------------------------------------------------------------------- 1 | """ 2 | post example when deploy the service name as controlnet 3 | """ 4 | 5 | import base64 6 | import json 7 | import os 8 | import sys 9 | from io import BytesIO 10 | 11 | import requests 12 | from PIL import Image, PngImagePlugin 13 | 14 | ENCODING = 'utf-8' 15 | 16 | hosts = 'http://xxx.cn-hangzhou.pai-eas.aliyuncs.com/api/predict/service_name' 17 | head = { 18 | 'Authorization': 'xxx' 19 | } 20 | 21 | 22 | def decode_base64(image_base64, save_file): 23 | img = Image.open(BytesIO(base64.urlsafe_b64decode(image_base64))) 24 | img.save(save_file) 25 | 26 | 27 | def select_data(process_func): 28 | if process_func == 'canny': 29 | datas = json.dumps({ 30 | 'task_id': 'canny', 31 | 'steps': 50, 32 | 'image_num': 1, 33 | 'width': 512, 34 | 'height': 512, 35 | 'image_link': 36 | 'https://huggingface.co/lllyasviel/sd-controlnet-hed/resolve/main/images/man.png', 37 | 'prompt': 'man', 38 | 'process_func': 'canny', 39 | }) 40 | elif process_func == 'depth': 41 | datas = json.dumps({ 42 | 'task_id': 'depth', 43 | 'steps': 50, 44 | 'image_num': 1, 45 | 'width': 512, 46 | 'height': 512, 47 | 'image_link': 48 | 'https://huggingface.co/lllyasviel/sd-controlnet-depth/resolve/main/images/stormtrooper.png', 49 | 'prompt': "Stormtrooper's lecture", 50 | 'controlnet_path': 51 | 'new_controlnet/models--lllyasviel--sd-controlnet-depth/diffusion_pytorch_model.safetensors', # use to change the controlnet path 52 | 'process_func': 'depth', 53 | }) 54 | elif process_func == 'hed': 55 | datas = json.dumps({ 56 | 'task_id': 'hed', 57 | 'steps': 50, 58 | 'image_num': 1, 59 | 'width': 512, 60 | 'height': 512, 61 | 
'image_link': 62 | 'https://huggingface.co/lllyasviel/sd-controlnet-hed/resolve/main/images/man.png', 63 | 'prompt': 'oil painting of handsome old man, masterpiece', 64 | 'controlnet_path': 65 | 'new_controlnet/models--lllyasviel--sd-controlnet-hed/diffusion_pytorch_model.safetensors', 66 | 'process_func': 'hed', 67 | }) 68 | elif process_func == 'mlsd': 69 | datas = json.dumps({ 70 | 'task_id': 'mlsd', 71 | 'steps': 50, 72 | 'image_num': 1, 73 | 'width': 512, 74 | 'height': 512, 75 | 'image_link': 76 | 'https://huggingface.co/lllyasviel/sd-controlnet-mlsd/resolve/main/images/room.png', 77 | 'prompt': 'room', 78 | 'controlnet_path': 79 | 'new_controlnet/models--lllyasviel--sd-controlnet-mlsd/diffusion_pytorch_model.safetensors', 80 | 'process_func': 'mlsd', 81 | }) 82 | elif process_func == 'normal': 83 | datas = json.dumps({ 84 | 'task_id': 'normal', 85 | 'steps': 50, 86 | 'image_num': 1, 87 | 'width': 512, 88 | 'height': 512, 89 | 'image_link': 90 | 'https://huggingface.co/lllyasviel/sd-controlnet-normal/resolve/main/images/toy.png', 91 | 'prompt': 'cute toy', 92 | 'controlnet_path': 93 | 'new_controlnet/models--fusing--stable-diffusion-v1-5-controlnet-normal/diffusion_pytorch_model.safetensors', 94 | 'process_func': 'normal', 95 | }) 96 | elif process_func == 'openpose': 97 | datas = json.dumps({ 98 | 'task_id': 'openpose', 99 | 'steps': 50, 100 | 'image_num': 1, 101 | 'width': 512, 102 | 'height': 512, 103 | 'image_link': 104 | 'https://huggingface.co/lllyasviel/sd-controlnet-openpose/resolve/main/images/pose.png', 105 | 'prompt': 'chef in the kitchen', 106 | 'controlnet_path': 107 | 'new_controlnet/models--lllyasviel--sd-controlnet-openpose/diffusion_pytorch_model.safetensors', 108 | 'process_func': 'openpose', 109 | }) 110 | elif process_func == 'scribble': 111 | 112 | datas = json.dumps({ 113 | 'task_id': 'scribble', 114 | 'steps': 50, 115 | 'image_num': 1, 116 | 'width': 512, 117 | 'height': 512, 118 | 'image_link': 119 | 'https://huggingface.co/lllyasviel/sd-controlnet-scribble/resolve/main/images/bag.png', 120 | 'prompt': 'bag', 121 | 'controlnet_path': 122 | 'new_controlnet/models--lllyasviel--sd-controlnet-scribble/diffusion_pytorch_model.safetensors', 123 | 'process_func': 'scribble', 124 | }) 125 | 126 | elif process_func == 'seg': 127 | datas = json.dumps({ 128 | 'task_id': 'seg', 129 | 'steps': 50, 130 | 'image_num': 1, 131 | 'width': 512, 132 | 'height': 512, 133 | 'image_link': 134 | 'https://huggingface.co/lllyasviel/sd-controlnet-seg/resolve/main/images/house.png', 135 | 'prompt': 'house', 136 | 'controlnet_path': 137 | 'new_controlnet/models--lllyasviel--sd-controlnet-seg/diffusion_pytorch_model.safetensors', 138 | 'process_func': 'seg', 139 | }) 140 | else: 141 | raise ValueError('Invalid process_func value') 142 | 143 | return datas 144 | 145 | 146 | process_func_list = [ 147 | 'canny', 'depth', 'hed', 'mlsd', 'normal', 'openpose', 'scribble', 'seg' 148 | ] 149 | 150 | for process_func in process_func_list: 151 | datas = select_data(process_func) 152 | 153 | r = requests.post(hosts, data=datas, headers=head) 154 | # r = requests.post("http://0.0.0.0:8000/test", data=datas, timeout=1500) 155 | 156 | data = json.loads(r.content.decode('utf-8')) 157 | print(data.keys()) 158 | 159 | if data['success']: 160 | print(data['image_url']) 161 | print(data['oss_url']) 162 | print(data['task_id']) 163 | print(data['use_blade']) 164 | print(data['seed']) 165 | print(data['is_nsfw']) 166 | if 'images_base64' in data.keys(): 167 | for i, image_base64 in 
enumerate(data['images_base64']): 168 | decode_base64(image_base64, 169 | './decode_ldm_base64_{}.png'.format(str(i))) 170 | 171 | else: 172 | print(data['error_msg']) 173 | -------------------------------------------------------------------------------- /diffusers/utils/image_process.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Tuple, Union 3 | 4 | import cv2 5 | import numpy as np 6 | from PIL import Image, ImageDraw 7 | 8 | import torch 9 | from controlnet_aux import HEDdetector, MLSDdetector, OpenposeDetector 10 | # need to be import or an malloc error will be occured by controlnet_aux 11 | from diffusers import ControlNetModel, StableDiffusionControlNetPipeline 12 | from transformers import (AutoImageProcessor, UperNetForSemanticSegmentation, 13 | pipeline) 14 | 15 | 16 | def canny(image: Image.Image, 17 | pretrain_dir: str, 18 | low_threshold: int = 100, 19 | high_threshold: int = 200) -> Image.Image: 20 | """ 21 | Apply the Canny edge detection algorithm to the image. 22 | 23 | Args: 24 | image (Image.Image): The input image. 25 | low_threshold (int): The lower threshold for edge detection (default: 100). 26 | high_threshold (int): The higher threshold for edge detection (default: 200). 27 | 28 | Returns: 29 | Image.Image: The processed image with detected edges. 30 | """ 31 | image = np.array(image) 32 | image = cv2.Canny(image, low_threshold, high_threshold) 33 | image = image[:, :, None] 34 | image = np.concatenate([image, image, image], axis=2) 35 | image = Image.fromarray(image) 36 | return image 37 | 38 | 39 | def depth(image: Image.Image, pretrain_dir: str) -> Image.Image: 40 | """ 41 | Estimate the depth map of the image using a pre-trained depth estimation model. 42 | 43 | Args: 44 | image (Image.Image): The input image. 45 | pretrain_dir (str): The directory containing the pre-trained models. 46 | 47 | Returns: 48 | Image.Image: The estimated depth map of the image. 49 | """ 50 | depth_estimator = pipeline('depth-estimation', 51 | model=os.path.join(pretrain_dir, 52 | 'models--Intel--dpt-large')) 53 | image = depth_estimator(image)['depth'] 54 | image = np.array(image) 55 | image = image[:, :, None] 56 | image = np.concatenate([image, image, image], axis=2) 57 | image = Image.fromarray(image) 58 | return image 59 | 60 | 61 | def hed(image: Image.Image, pretrain_dir: str) -> Image.Image: 62 | """ 63 | Apply the Holistically-Nested Edge Detection (HED) algorithm to the image. 64 | 65 | Args: 66 | image (Image.Image): The input image. 67 | pretrain_dir (str): The directory containing the pre-trained models. 68 | 69 | Returns: 70 | Image.Image: The processed image with detected edges. 71 | """ 72 | hed = HEDdetector.from_pretrained( 73 | os.path.join(pretrain_dir, 'models--lllyasviel--ControlNet')) 74 | image = hed(image) 75 | return image 76 | 77 | 78 | def mlsd(image: Image.Image, pretrain_dir: str) -> Image.Image: 79 | """ 80 | Apply MLSD (Multi-Line Segment Detection) model to the input image. 81 | 82 | Args: 83 | image (Image.Image): The input image. 84 | pretrain_dir (str): The directory path where the pre-trained model is located. 85 | 86 | Returns: 87 | Image.Image: The processed image. 
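    Example (illustrative sketch; the weights directory below is a placeholder
    and must contain the 'models--lllyasviel--ControlNet' checkpoint that
    controlnet_aux expects)::

        from PIL import Image

        room = Image.open('room.png')  # any RGB input image
        control = mlsd(room, '/path/to/pretrained_models')  # line-segment map
        control.save('room_mlsd.png')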
88 | 89 | """ 90 | mlsd = MLSDdetector.from_pretrained( 91 | os.path.join(pretrain_dir, 'models--lllyasviel--ControlNet')) 92 | image = mlsd(image) 93 | return image 94 | 95 | 96 | def normal(image: Image.Image, 97 | pretrain_dir: str, 98 | bg_threshold: float = 0.4) -> Image.Image: 99 | """ 100 | Perform normal estimation on the input image. 101 | 102 | Args: 103 | image (Image.Image): The input image. 104 | pretrain_dir (str): The directory path where the pre-trained model is located. 105 | bg_threshold (float, optional): Background depth threshold. Default is 0.4. 106 | 107 | Returns: 108 | Image.Image: The image with normal estimation. 109 | 110 | """ 111 | depth_estimator = pipeline('depth-estimation', 112 | model=os.path.join( 113 | pretrain_dir, 114 | 'models--Intel--dpt-hybrid-midas')) 115 | image = depth_estimator(image)['predicted_depth'][0] 116 | image = image.numpy() 117 | 118 | image_depth = image.copy() 119 | image_depth -= np.min(image_depth) 120 | image_depth /= np.max(image_depth) 121 | 122 | x = cv2.Sobel(image, cv2.CV_32F, 1, 0, ksize=3) 123 | x[image_depth < bg_threshold] = 0 124 | 125 | y = cv2.Sobel(image, cv2.CV_32F, 0, 1, ksize=3) 126 | y[image_depth < bg_threshold] = 0 127 | 128 | z = np.ones_like(x) * np.pi * 2.0 129 | 130 | image = np.stack([x, y, z], axis=2) 131 | image /= np.sum(image**2.0, axis=2, keepdims=True)**0.5 132 | image = (image * 127.5 + 127.5).clip(0, 255).astype(np.uint8) 133 | image = Image.fromarray(image) 134 | 135 | return image 136 | 137 | 138 | def openpose(image: Image.Image, pretrain_dir: str) -> Image.Image: 139 | """ 140 | Apply OpenPose model to the input image. 141 | 142 | Args: 143 | image (Image.Image): The input image. 144 | pretrain_dir (str): The directory path where the pre-trained model is located. 145 | 146 | Returns: 147 | Image.Image: The processed image. 148 | 149 | """ 150 | openpose = OpenposeDetector.from_pretrained( 151 | os.path.join(pretrain_dir, 'models--lllyasviel--ControlNet')) 152 | image = openpose(image) 153 | return image 154 | 155 | 156 | def scribble(image: Image.Image, pretrain_dir: str) -> Image.Image: 157 | """ 158 | Apply scribble-based HED (Holistically-Nested Edge Detection) model to the input image. 159 | 160 | Args: 161 | image (Image.Image): The input image. 162 | pretrain_dir (str): The directory path where the pre-trained model is located. 163 | 164 | Returns: 165 | Image.Image: The processed image. 166 | 167 | """ 168 | hed = HEDdetector.from_pretrained(pretrained_model_or_path=os.path.join( 169 | pretrain_dir, 'models--lllyasviel--ControlNet')) 170 | image = hed(image, scribble=True) 171 | return image 172 | 173 | 174 | def seg(image: Image.Image, pretrain_dir: str) -> Image.Image: 175 | """ 176 | Apply semantic segmentation to the input image. 177 | 178 | Args: 179 | image (Image.Image): The input image. 180 | pretrain_dir (str): The directory path where the pre-trained models are located. 181 | 182 | Returns: 183 | Image.Image: The processed image. 
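    Example (illustrative sketch; 'pretrain_dir' below is a placeholder and
    must contain the 'models--openmmlab--upernet-convnext-small' checkpoint)::

        from PIL import Image

        house = Image.open('house.png')
        seg_map = seg(house, '/path/to/pretrained_models')  # color-coded ADE20K segmentation
        seg_map.save('house_seg.png')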
184 | 185 | """ 186 | 187 | palette = np.asarray([ 188 | [0, 0, 0], 189 | [120, 120, 120], 190 | [180, 120, 120], 191 | [6, 230, 230], 192 | [80, 50, 50], 193 | [4, 200, 3], 194 | [120, 120, 80], 195 | [140, 140, 140], 196 | [204, 5, 255], 197 | [230, 230, 230], 198 | [4, 250, 7], 199 | [224, 5, 255], 200 | [235, 255, 7], 201 | [150, 5, 61], 202 | [120, 120, 70], 203 | [8, 255, 51], 204 | [255, 6, 82], 205 | [143, 255, 140], 206 | [204, 255, 4], 207 | [255, 51, 7], 208 | [204, 70, 3], 209 | [0, 102, 200], 210 | [61, 230, 250], 211 | [255, 6, 51], 212 | [11, 102, 255], 213 | [255, 7, 71], 214 | [255, 9, 224], 215 | [9, 7, 230], 216 | [220, 220, 220], 217 | [255, 9, 92], 218 | [112, 9, 255], 219 | [8, 255, 214], 220 | [7, 255, 224], 221 | [255, 184, 6], 222 | [10, 255, 71], 223 | [255, 41, 10], 224 | [7, 255, 255], 225 | [224, 255, 8], 226 | [102, 8, 255], 227 | [255, 61, 6], 228 | [255, 194, 7], 229 | [255, 122, 8], 230 | [0, 255, 20], 231 | [255, 8, 41], 232 | [255, 5, 153], 233 | [6, 51, 255], 234 | [235, 12, 255], 235 | [160, 150, 20], 236 | [0, 163, 255], 237 | [140, 140, 140], 238 | [250, 10, 15], 239 | [20, 255, 0], 240 | [31, 255, 0], 241 | [255, 31, 0], 242 | [255, 224, 0], 243 | [153, 255, 0], 244 | [0, 0, 255], 245 | [255, 71, 0], 246 | [0, 235, 255], 247 | [0, 173, 255], 248 | [31, 0, 255], 249 | [11, 200, 200], 250 | [255, 82, 0], 251 | [0, 255, 245], 252 | [0, 61, 255], 253 | [0, 255, 112], 254 | [0, 255, 133], 255 | [255, 0, 0], 256 | [255, 163, 0], 257 | [255, 102, 0], 258 | [194, 255, 0], 259 | [0, 143, 255], 260 | [51, 255, 0], 261 | [0, 82, 255], 262 | [0, 255, 41], 263 | [0, 255, 173], 264 | [10, 0, 255], 265 | [173, 255, 0], 266 | [0, 255, 153], 267 | [255, 92, 0], 268 | [255, 0, 255], 269 | [255, 0, 245], 270 | [255, 0, 102], 271 | [255, 173, 0], 272 | [255, 0, 20], 273 | [255, 184, 184], 274 | [0, 31, 255], 275 | [0, 255, 61], 276 | [0, 71, 255], 277 | [255, 0, 204], 278 | [0, 255, 194], 279 | [0, 255, 82], 280 | [0, 10, 255], 281 | [0, 112, 255], 282 | [51, 0, 255], 283 | [0, 194, 255], 284 | [0, 122, 255], 285 | [0, 255, 163], 286 | [255, 153, 0], 287 | [0, 255, 10], 288 | [255, 112, 0], 289 | [143, 255, 0], 290 | [82, 0, 255], 291 | [163, 255, 0], 292 | [255, 235, 0], 293 | [8, 184, 170], 294 | [133, 0, 255], 295 | [0, 255, 92], 296 | [184, 0, 255], 297 | [255, 0, 31], 298 | [0, 184, 255], 299 | [0, 214, 255], 300 | [255, 0, 112], 301 | [92, 255, 0], 302 | [0, 224, 255], 303 | [112, 224, 255], 304 | [70, 184, 160], 305 | [163, 0, 255], 306 | [153, 0, 255], 307 | [71, 255, 0], 308 | [255, 0, 163], 309 | [255, 204, 0], 310 | [255, 0, 143], 311 | [0, 255, 235], 312 | [133, 255, 0], 313 | [255, 0, 235], 314 | [245, 0, 255], 315 | [255, 0, 122], 316 | [255, 245, 0], 317 | [10, 190, 212], 318 | [214, 255, 0], 319 | [0, 204, 255], 320 | [20, 0, 255], 321 | [255, 255, 0], 322 | [0, 153, 255], 323 | [0, 41, 255], 324 | [0, 255, 204], 325 | [41, 0, 255], 326 | [41, 255, 0], 327 | [173, 0, 255], 328 | [0, 245, 255], 329 | [71, 0, 255], 330 | [122, 0, 255], 331 | [0, 255, 184], 332 | [0, 92, 255], 333 | [184, 255, 0], 334 | [0, 133, 255], 335 | [255, 214, 0], 336 | [25, 194, 194], 337 | [102, 255, 0], 338 | [92, 0, 255], 339 | ]) 340 | 341 | image_processor = AutoImageProcessor.from_pretrained( 342 | os.path.join(pretrain_dir, 343 | 'models--openmmlab--upernet-convnext-small')) 344 | image_segmentor = UperNetForSemanticSegmentation.from_pretrained( 345 | os.path.join(pretrain_dir, 346 | 'models--openmmlab--upernet-convnext-small')) 347 | 348 | pixel_values = 
image_processor(image, return_tensors='pt').pixel_values 349 | 350 | with torch.no_grad(): 351 | outputs = image_segmentor(pixel_values) 352 | 353 | seg = image_processor.post_process_semantic_segmentation( 354 | outputs, target_sizes=[image.size[::-1]])[0] 355 | 356 | color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8) 357 | 358 | for label, color in enumerate(palette): 359 | color_seg[seg == label, :] = color 360 | 361 | color_seg = color_seg.astype(np.uint8) 362 | 363 | image = Image.fromarray(color_seg) 364 | 365 | return image 366 | 367 | 368 | def preprocess_control(image: Image.Image, process_func: str, 369 | pretrain_dir: str) -> Union[str, Image.Image]: 370 | """ 371 | Apply the specified image processing function to the input image for controlnet. 372 | 373 | Args: 374 | image (Image.Image): The input image to be processed. 375 | process_func (str): The name of the processing function to be applied. 376 | pretrain_dir (str): The directory containing the pre-trained models. 377 | 378 | Returns: 379 | Union[str, Image.Image]: The processed image if successful, or an error message as a string if the specified process_func is not supported. 380 | """ 381 | process_func_dict = { 382 | 'canny': canny, 383 | 'depth': depth, 384 | 'hed': hed, 385 | 'mlsd': mlsd, 386 | 'normal': normal, 387 | 'openpose': openpose, 388 | 'scribble': scribble, 389 | 'seg': seg 390 | } 391 | 392 | if process_func not in process_func_dict: 393 | return 'We only support process functions: {}. But got {}.'.format( 394 | list(process_func_dict.keys()), process_func) 395 | 396 | process_func = process_func_dict[process_func] 397 | processed_image = process_func(image, pretrain_dir) 398 | return processed_image 399 | 400 | 401 | def transform_image(image: Image.Image, 402 | width: int, 403 | height: int, 404 | mode: int = 0) -> Image.Image: 405 | """ 406 | Transform the input image to the specified width and height using the specified mode. 407 | 408 | Args: 409 | image (PIL Image object): The image that needs to be transformed. 410 | width (int): The width of the output image. 411 | height (int): The height of the output image. 412 | mode (int, optional): Specifies the mode of image transformation. 413 | 0 - Stretch 拉伸, 1 - Crop 裁剪, 2 - Padding 填充. Defaults to 0. align with webui 414 | 415 | Returns: 416 | PIL Image object: The transformed image. 
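    Example (illustrative sketch; 'input.png' is a placeholder image path)::

        from PIL import Image

        img = Image.open('input.png')
        stretched = transform_image(img, 512, 512, mode=0)  # resize directly, may distort aspect ratio
        cropped = transform_image(img, 512, 512, mode=1)    # resize keeping aspect ratio, then center crop
        padded = transform_image(img, 512, 512, mode=2)     # paste centered onto a white 512x512 canvas, no resize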
417 | """ 418 | 419 | if mode == 0: # Stretch 420 | image = image.resize((width, height)) 421 | elif mode == 1: # Crop 422 | aspect_ratio = float(image.size[0]) / float(image.size[1]) 423 | new_aspect_ratio = float(width) / float(height) 424 | 425 | if aspect_ratio > new_aspect_ratio: 426 | # Crop the width 427 | new_width = int(float(height) * aspect_ratio) 428 | left = int((new_width - width) / 2) 429 | right = new_width - left 430 | image = image.resize((new_width, height)) 431 | image = image.crop((left, 0, left + width, height)) 432 | else: 433 | # Crop the height 434 | new_height = int(float(width) / aspect_ratio) 435 | up = int((new_height - height) / 2) 436 | down = new_height - up 437 | image = image.resize((width, new_height)) 438 | image = image.crop((0, up, width, up + height)) 439 | 440 | elif mode == 2: # Padding 441 | new_image = Image.new('RGB', (width, height), (255, 255, 255)) 442 | new_image.paste(image, ((width - image.size[0]) // 2, 443 | (height - image.size[1]) // 2)) 444 | image = new_image 445 | 446 | return image 447 | 448 | 449 | def generate_mask_and_img_expand( 450 | img: Image.Image, 451 | expand: Tuple[int, int, int, int], 452 | expand_type: str = 'copy') -> Tuple[Image.Image, Image.Image]: 453 | """ 454 | Generate a mask and an expanded image based on the given image and expand parameters. 455 | 456 | Args: 457 | img (Image.Image): The original image. 458 | expand (Tuple[int, int, int, int]): The expansion values for left, right, up, and down directions. 459 | expand_type (str, optional): The type of expansion ('copy' or 'reflect'). Defaults to 'copy'. 460 | 461 | Returns: 462 | Tuple[Image.Image, Image.Image]: The expanded image and the corresponding mask. 463 | """ 464 | 465 | left, right, up, down = expand 466 | 467 | width, height = img.size 468 | new_width, new_height = width + left + right, height + up + down 469 | 470 | # ----------- 1. Create mask where the image is black and the expanded region is white ----------- 471 | mask = Image.new('L', (new_width, new_height), 0) 472 | draw = ImageDraw.Draw(mask) 473 | # Add white edge 474 | color = 255 475 | draw.rectangle((0, 0, new_width, up), fill=color) # up 476 | draw.rectangle((0, new_height - down, new_width, new_height), 477 | fill=color) # down 478 | draw.rectangle((0, 0, left, new_height), fill=color) # left 479 | draw.rectangle((new_width - right, 0, new_width, new_height), 480 | fill=color) # right 481 | 482 | # ----------- 2. 
Expand the image by a copy or reflection operation ----------- 483 | # simply use the filled pixel can not generate meaningful image in unified pipeline 484 | # img_expand = Image.new('RGB', (new_width, new_height), (255, 255, 255)) 485 | # img_expand.paste(img, (left, up)) 486 | 487 | # Convert the image to a NumPy array 488 | image_array = np.array(img) 489 | 490 | # new img 491 | expanded_image_array = np.zeros((new_height, new_width, 3), dtype=np.uint8) 492 | 493 | # copy ori img 494 | expanded_image_array[up:up + height, left:left + width, :] = image_array 495 | 496 | if expand_type == 'reflect': 497 | # Reflect the boundary pixels to the new boundaries 498 | expanded_image_array[:up, left:left + width, :] = np.flipud( 499 | expanded_image_array[up:2 * up, left:left + width, :]) # up 500 | expanded_image_array[up + height:, left:left + width, :] = np.flipud( 501 | expanded_image_array[up + height - 2:up + height - 2 - down:-1, 502 | left:left + width, :]) # down 503 | expanded_image_array[:, :left, :] = np.fliplr( 504 | expanded_image_array[:, left:2 * left, :]) # left 505 | expanded_image_array[:, left + width:, :] = np.fliplr( 506 | expanded_image_array[:, left + width - 2:left + width - 2 - 507 | right:-1, :]) # right 508 | 509 | else: 510 | # Copy the boundary pixels to the new boundaries 511 | expanded_image_array[:up, left:left + 512 | width, :] = image_array[0:1, :, :] # up 513 | expanded_image_array[up + height:, left:left + 514 | width, :] = image_array[height - 515 | 1:height, :, :] # down 516 | expanded_image_array[:, :left, :] = expanded_image_array[:, left:left + 517 | 1, :] # left 518 | expanded_image_array[:, left + 519 | width:, :] = expanded_image_array[:, left + 520 | width - 1:left + 521 | width, :] # right 522 | 523 | # Create a new image from the expanded image array 524 | img_expand = Image.fromarray(expanded_image_array) 525 | 526 | return img_expand, mask 527 | -------------------------------------------------------------------------------- /diffusers/utils/blade.py: -------------------------------------------------------------------------------- 1 | """ 2 | The file is used for blade optimization. 3 | The main function are three fold: 4 | 1. optimize_and_save_blade_model: do blade optimization and save optimized models (it takes about 30min) 5 | 2. load_blade_model: use to load the optimized blade model 6 | 3. load_attn_procs / unload_lora: online change and merge multiple lora weights 7 | """ 8 | 9 | from dataclasses import dataclass 10 | from functools import lru_cache 11 | from pathlib import Path 12 | from typing import Dict, List, Optional, Tuple, Union 13 | 14 | import torch 15 | import torch_blade 16 | from safetensors.torch import load_file 17 | from torch import Tensor, nn 18 | from torch_blade import optimize as blade_optimize 19 | 20 | # ------------------------ 1. optimize_and_save_blade_model ------------------------ 21 | 22 | 23 | def gen_inputs( 24 | pipe, 25 | use_controlnet: bool = False 26 | ) -> Tuple[Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...], Tensor]: 27 | """ 28 | Generate inputs for the specified pipe to forward pipeline. 29 | 30 | Args: 31 | pipe: The diffusion pipeline. 32 | use_controlnet (bool, optional): Flag indicating whether to use controlnet inputs. Default is False. 33 | 34 | Returns: 35 | Tuple[Tensor, Tuple[Tensor, ...], Tuple[Tensor, ...], Tensor]: The generated inputs consisting of: 36 | - encoder_inputs: Tensor of shape (1, text_max_length) containing integer values. 
37 | - controlnet_inputs: Tuple of tensors containing inputs for controlnet. 38 | - unet_inputs: Tuple of tensors containing inputs for unet. 39 | - decoder_inputs: Tensor of shape (1, unet_out_channels, 128, 128) containing float values. 40 | """ 41 | device = torch.device('cuda:0') 42 | # use bs=1 to trace and optimize 43 | text_max_length = pipe.tokenizer.model_max_length 44 | text_embedding_size = pipe.text_encoder.config.hidden_size 45 | sample_size = pipe.unet.config.sample_size 46 | unet_in_channels = pipe.unet.config.in_channels 47 | unet_out_channels = pipe.unet.config.out_channels 48 | 49 | encoder_inputs = torch.randint(1, 50 | 999, (1, text_max_length), 51 | device=device, 52 | dtype=torch.int64) 53 | 54 | unet_inputs = [ 55 | torch.randn( 56 | (2, unet_in_channels, sample_size, sample_size), 57 | dtype=torch.half, 58 | device=device, 59 | ), 60 | torch.tensor(999, device=device, dtype=torch.half), 61 | torch.randn((2, text_max_length, text_embedding_size), 62 | dtype=torch.half, 63 | device=device), 64 | ] 65 | 66 | # controlnet has same inputs as unet, with additional condition 67 | controlnet_inputs = unet_inputs + [ 68 | torch.randn((2, 3, 512, 512), dtype=torch.half, device=device), 69 | ] 70 | 71 | decoder_inputs = torch.randn(1, 72 | unet_out_channels, 73 | 128, 74 | 128, 75 | device=device, 76 | dtype=torch.half) 77 | 78 | return encoder_inputs, controlnet_inputs, unet_inputs, decoder_inputs 79 | 80 | 81 | def optimize_and_save_blade_model( 82 | pipe: nn.Module, 83 | encoder_path: str, 84 | unet_path: str, 85 | decoder_path: str, 86 | controlnet_path: Optional[str] = None) -> None: 87 | """ 88 | Optimize and save the Blade model. 89 | 90 | Args: 91 | pipe (nn.Module): The pipeline module. 92 | encoder_path (str): The path to save the optimized encoder model. 93 | unet_path (str): The path to save the optimized UNet model. 94 | decoder_path (str): The path to save the optimized decoder model. 95 | controlnet_path (str, optional): The path to save the optimized controlnet model. Default is None. 
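    Example (illustrative sketch -- the model directory and the .pt file names
    are placeholders; as noted in the module docstring the optimization itself
    takes roughly 30 minutes on GPU)::

        import torch
        from diffusers import StableDiffusionPipeline

        pipe = StableDiffusionPipeline.from_pretrained(
            '/path/to/base_model', torch_dtype=torch.float16).to('cuda')
        optimize_and_save_blade_model(pipe, 'encoder.pt', 'unet.pt', 'decoder.pt')
        # for a controlnet pipeline, also pass controlnet_path='controlnet.pt'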
96 | 97 | Returns: 98 | None 99 | """ 100 | 101 | if controlnet_path is not None: 102 | use_controlnet = True 103 | else: 104 | use_controlnet = False 105 | 106 | encoder_inputs, controlnet_inputs, unet_inputs, decoder_inputs = gen_inputs( 107 | pipe, use_controlnet=use_controlnet) 108 | 109 | if not use_controlnet: 110 | # base 111 | class UnetWrapper(torch.nn.Module): 112 | def __init__(self, unet): 113 | super().__init__() 114 | self.unet = unet 115 | 116 | def forward(self, sample, timestep, encoder_hidden_states): 117 | return self.unet( 118 | sample, 119 | timestep, 120 | encoder_hidden_states=encoder_hidden_states, 121 | ) 122 | 123 | opt_cfg = torch_blade.Config() 124 | opt_cfg.enable_fp16 = True 125 | opt_cfg.freeze_module = False # allow to change the lora weight when inferring 126 | 127 | from torch_blade.monkey_patch import patch_utils 128 | # change layout for conv layer [NCHW]->[NHWC] for a better inference time 129 | patch_utils.patch_conv2d(pipe.unet) 130 | patch_utils.patch_conv2d(pipe.vae.decoder) 131 | 132 | with opt_cfg, torch.no_grad(): 133 | unet = torch.jit.trace( 134 | UnetWrapper(pipe.unet).eval(), 135 | tuple(unet_inputs), 136 | strict=False, 137 | check_trace=False, 138 | ) 139 | # unet = torch.jit.trace(pipe.unet, unet_inputs, strict=False, check_trace=False) 140 | 141 | unet = torch_blade.optimize(unet, 142 | model_inputs=tuple(unet_inputs), 143 | allow_tracing=True) 144 | 145 | encoder = torch_blade.optimize(pipe.text_encoder, 146 | model_inputs=encoder_inputs, 147 | allow_tracing=True) 148 | 149 | decoder = torch.jit.trace(pipe.vae.decoder, 150 | decoder_inputs, 151 | strict=False, 152 | check_trace=False) 153 | decoder = torch_blade.optimize(decoder, 154 | model_inputs=decoder_inputs, 155 | allow_tracing=True) 156 | 157 | torch.jit.save(encoder, encoder_path) 158 | torch.jit.save(unet, unet_path) 159 | torch.jit.save(decoder, decoder_path) 160 | 161 | else: 162 | # controlnet 163 | opt_cfg = torch_blade.Config() 164 | opt_cfg.enable_fp16 = True 165 | 166 | class UnetWrapper(torch.nn.Module): 167 | def __init__(self, unet): 168 | super().__init__() 169 | self.unet = unet 170 | 171 | def forward( 172 | self, 173 | sample, 174 | timestep, 175 | encoder_hidden_states, 176 | down_block_additional_residuals, 177 | mid_block_additional_residual, 178 | ): 179 | return self.unet( 180 | sample, 181 | timestep, 182 | encoder_hidden_states=encoder_hidden_states, 183 | down_block_additional_residuals= 184 | down_block_additional_residuals, 185 | mid_block_additional_residual=mid_block_additional_residual, 186 | ) 187 | 188 | import functools 189 | 190 | from torch_blade.monkey_patch import patch_utils 191 | 192 | patch_utils.patch_conv2d(pipe.unet) 193 | patch_utils.patch_conv2d(pipe.controlnet) 194 | 195 | pipe.controlnet.forward = functools.partial(pipe.controlnet.forward, 196 | return_dict=False) 197 | 198 | with opt_cfg, torch.no_grad(): 199 | encoder = torch_blade.optimize(pipe.text_encoder, 200 | model_inputs=encoder_inputs, 201 | allow_tracing=True) 202 | # decoder = torch.jit.trace(pipe.vae.decoder, decoder_inputs, strict=False, check_trace=False) 203 | decoder = torch_blade.optimize(pipe.vae.decoder, 204 | model_inputs=decoder_inputs, 205 | allow_tracing=True) 206 | 207 | # not freeze to load other weights 208 | opt_cfg.freeze_module = False 209 | 210 | controlnet = torch.jit.trace(pipe.controlnet, 211 | tuple(controlnet_inputs), 212 | strict=False, 213 | check_trace=False) 214 | controlnet = torch_blade.optimize( 215 | controlnet, 216 | 
model_inputs=tuple(controlnet_inputs), 217 | allow_tracing=True) 218 | # add controlnet outputs to unet inputs 219 | down_block_res_samples, mid_block_res_sample = controlnet( 220 | *controlnet_inputs) 221 | 222 | device = torch.device('cuda:0') 223 | 224 | unet_inputs += [ 225 | tuple(down_block_res_samples), 226 | mid_block_res_sample, 227 | ] 228 | 229 | unet = torch.jit.trace( 230 | UnetWrapper(pipe.unet).eval(), 231 | tuple(unet_inputs), 232 | strict=False, 233 | check_trace=False, 234 | ) 235 | 236 | unet = torch_blade.optimize(unet, 237 | model_inputs=tuple(unet_inputs), 238 | allow_tracing=True) 239 | 240 | torch.jit.save(encoder, encoder_path) 241 | torch.jit.save(controlnet, controlnet_path) 242 | torch.jit.save(unet, unet_path) 243 | torch.jit.save(decoder, decoder_path) 244 | 245 | 246 | # ------------------------ 2. load_blade_model ------------------------ 247 | 248 | 249 | def load_blade_model(pipe: nn.Module, 250 | encoder_path: str, 251 | unet_path: str, 252 | decoder_path: str, 253 | controlnet_path: Optional[str] = None) -> nn.Module: 254 | """ 255 | Load the Blade model. 256 | 257 | Args: 258 | pipe (nn.Module): The pipeline module. 259 | encoder_path (str): The path to the optimized encoder model. 260 | unet_path (str): The path to the optimized UNet model. 261 | decoder_path (str): The path to the optimized decoder model. 262 | controlnet_path (str, optional): The path to the optimized controlnet model. Default is None. 263 | 264 | Returns: 265 | nn.Module: The loaded Blade model. 266 | """ 267 | 268 | if controlnet_path is not None: 269 | use_controlnet = True 270 | else: 271 | use_controlnet = False 272 | 273 | encoder_inputs, controlnet_inputs, unet_inputs, decoder_inputs = gen_inputs( 274 | pipe, use_controlnet=use_controlnet) 275 | 276 | # encoder = torch.jit.load(encoder_path).eval().cuda() 277 | unet = torch.jit.load(unet_path).eval().cuda() 278 | decoder = torch.jit.load(decoder_path).eval().cuda() 279 | 280 | # load weights from current model 281 | if not use_controlnet: 282 | unet_state_dict = { 283 | 'unet.' 
+ k: v 284 | for k, v in pipe.unet.state_dict().items() 285 | } 286 | _, unexpected = unet.load_state_dict(unet_state_dict, strict=False) 287 | print(unexpected) 288 | 289 | _, unexpected = decoder.load_state_dict(pipe.vae.decoder.state_dict(), 290 | strict=False) 291 | print(unexpected) 292 | 293 | # warmup 294 | # encoder(encoder_inputs) 295 | unet(*unet_inputs) 296 | decoder(*decoder_inputs) 297 | 298 | patch_conv_weights(unet) 299 | patch_conv_weights(decoder) 300 | 301 | if use_controlnet: 302 | controlnet = torch.jit.load(controlnet_path).eval().cuda() 303 | 304 | @dataclass 305 | class UNet2DConditionOutput: 306 | sample: torch.FloatTensor 307 | 308 | class TracedEncoder(torch.nn.Module): 309 | def __init__(self): 310 | super().__init__() 311 | self.config = pipe.text_encoder.config 312 | self.device = pipe.text_encoder.device 313 | self.dtype = torch.half 314 | 315 | def forward(self, input_ids, **kwargs): 316 | embeddings = encoder(input_ids.long()) 317 | return [embeddings['last_hidden_state']] 318 | 319 | if use_controlnet: 320 | # controlnet 321 | class TracedControlNet(torch.nn.Module): 322 | def __init__(self): 323 | super().__init__() 324 | self.controlnet_conditioning_channel_order = 'rgb' 325 | 326 | def forward(self, sample, timestep, encoder_hidden_states, 327 | **kwargs): 328 | if self.controlnet_conditioning_channel_order == 'rgb': 329 | return controlnet(sample.half(), timestep.half(), 330 | encoder_hidden_states.half(), 331 | kwargs['controlnet_cond']) 332 | else: 333 | return controlnet( 334 | sample.half(), timestep.half(), 335 | encoder_hidden_states.half(), 336 | torch.flip(kwargs['controlnet_cond'], dims=[1])) 337 | 338 | def load_state_dict(self, state_dict, strict=False): 339 | _, unexpected = controlnet.load_state_dict(state_dict, 340 | strict=strict) 341 | if unexpected: 342 | print( 343 | f'load controlNet with unexpected keys: {unexpected}') 344 | return 345 | 346 | def state_dict(self): 347 | return controlnet.state_dict() 348 | 349 | def set_channel_order(self, channel_order): 350 | self.controlnet_conditioning_channel_order = channel_order 351 | 352 | class TracedUNet(torch.nn.Module): 353 | def __init__(self): 354 | super().__init__() 355 | self.config = pipe.unet.config 356 | self.in_channels = pipe.unet.in_channels 357 | self.device = pipe.unet.device 358 | self.device = pipe.unet.device 359 | self.lora_weights = {} 360 | self.cur_lora = {} 361 | 362 | def state_dict(self): 363 | return unet.state_dict() 364 | 365 | def forward(self, latent_model_input, t, encoder_hidden_states, 366 | **kwargs): 367 | if kwargs.get('down_block_additional_residuals', None) is None: 368 | kwargs['down_block_additional_residuals'] = tuple([ 369 | torch.tensor( 370 | [[[[0.0]]]], device=self.device, dtype=torch.half) 371 | ] * 13) 372 | if kwargs.get('mid_block_additional_residual', None) is None: 373 | kwargs['mid_block_additional_residual'] = torch.tensor( 374 | [[[[0.0]]]], device=self.device, dtype=torch.half) 375 | 376 | sample = unet( 377 | latent_model_input.half(), 378 | t.half(), 379 | encoder_hidden_states.half(), 380 | kwargs['down_block_additional_residuals'], 381 | kwargs['mid_block_additional_residual'], 382 | )['sample'] 383 | 384 | return UNet2DConditionOutput(sample=sample) 385 | 386 | else: 387 | # base model 388 | class TracedUNet(torch.nn.Module): 389 | def __init__(self): 390 | super().__init__() 391 | self.config = pipe.unet.config 392 | self.in_channels = pipe.unet.in_channels 393 | self.device = pipe.unet.device 394 | self.lora_weights = {} 395 | 
self.cur_lora = {} 396 | 397 | def state_dict(self): 398 | return unet.state_dict() 399 | 400 | def forward(self, latent_model_input, t, encoder_hidden_states, 401 | **kwargs): 402 | sample = unet(latent_model_input.half(), t.half(), 403 | encoder_hidden_states.half())['sample'] 404 | return UNet2DConditionOutput(sample=sample) 405 | 406 | class TracedDecoder(torch.nn.Module): 407 | def forward(self, input): 408 | return decoder(input.half()) 409 | 410 | # pipe.text_encoder = TracedEncoder() # lead to incorrect output 411 | 412 | if use_controlnet: 413 | controlnet_wrapper = TracedControlNet() 414 | pipe.controlnet.forward = controlnet_wrapper.forward 415 | pipe.controlnet.load_state_dict = controlnet_wrapper.load_state_dict 416 | pipe.controlnet.state_dict = controlnet_wrapper.state_dict 417 | pipe.controlnet.set_channel_order = controlnet_wrapper.set_channel_order 418 | 419 | pipe.unet = TracedUNet() 420 | pipe.vae.decoder = TracedDecoder() 421 | 422 | return pipe 423 | 424 | 425 | # ------------------------ 3. load_attn_procs / unload_lora: online change and merge multiple lora weights ------------------ 426 | 427 | LORA_PREFIX_TEXT_ENCODER, LORA_PREFIX_UNET = 'lora_te', 'lora_unet' 428 | 429 | 430 | @lru_cache(maxsize=32) 431 | def load_lora_and_mul( 432 | lora_path: str, 433 | dtype: torch.dtype) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]: 434 | """ 435 | Load and process LoRA weights from the specified path. 436 | 437 | Args: 438 | lora_path (str): Path to the LoRA weights file. 439 | dtype (torch.dtype): Desired data type for the processed weights. 440 | 441 | Returns: 442 | Tuple[Dict[str, Tensor], Dict[str, Tensor]]: A tuple containing two dictionaries: 443 | - text_encoder_state_dict: Dictionary containing the state dictionary for the text encoder. 444 | - unet_state_dict: Dictionary containing the state dictionary for the UNet. 445 | """ 446 | if lora_path.endswith('.safetensors'): 447 | # lora model trained by webui script (e.g., model from civitai) 448 | 449 | state_dict = load_file(lora_path) 450 | 451 | visited, unet_state_dict, text_encoder_state_dict = [], {}, {} 452 | 453 | # directly update weight in diffusers model 454 | for key in state_dict: 455 | # it is suggested to print out the key, it usually will be something like below 456 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 457 | 458 | # as we have set the alpha beforehand, so just skip 459 | if '.alpha' in key or key in visited: 460 | continue 461 | 462 | if 'text' in key: 463 | diffusers_key = key.split( 464 | '.')[0].split(LORA_PREFIX_TEXT_ENCODER + '_')[-1].replace( 465 | '_', '.').replace('text.model', 'text_model').replace( 466 | '.proj', '_proj').replace('self.attn', 467 | 'self_attn') + '.weight' 468 | curr_state_dict = text_encoder_state_dict 469 | else: 470 | diffusers_key = 'unet.' 
+ key.split('.')[0].split( 471 | LORA_PREFIX_UNET + '_')[-1].replace('_', '.').replace( 472 | '.block', '_block').replace('to.', 'to_').replace( 473 | 'proj.', 'proj_') + '.weight' 474 | curr_state_dict = unet_state_dict 475 | 476 | pair_keys = [] 477 | if 'lora_down' in key: 478 | alpha = state_dict.get( 479 | key.replace('lora_down.weight', 'alpha'), None) 480 | pair_keys.append(key.replace('lora_down', 'lora_up')) 481 | pair_keys.append(key) 482 | else: 483 | alpha = state_dict.get(key.replace('lora_up.weight', 'alpha'), 484 | None) 485 | pair_keys.append(key) 486 | pair_keys.append(key.replace('lora_up', 'lora_down')) 487 | 488 | # update weight 489 | if alpha: 490 | alpha = alpha.item() / state_dict[pair_keys[0]].shape[1] 491 | else: 492 | alpha = 0.75 493 | 494 | if len(state_dict[pair_keys[0]].shape) == 4: 495 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to( 496 | torch.float32) 497 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze( 498 | 2).to(torch.float32) 499 | if len(weight_up.shape) == len(weight_down.shape): 500 | curr_state_dict[diffusers_key] = alpha * torch.mm( 501 | weight_up.cuda(), 502 | weight_down.cuda()).unsqueeze(2).unsqueeze(3).to(dtype) 503 | else: 504 | curr_state_dict[diffusers_key] = alpha * torch.einsum( 505 | 'a b, b c h w -> a c h w', weight_up.cuda(), 506 | weight_down.cuda()).to(dtype) 507 | else: 508 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 509 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 510 | curr_state_dict[diffusers_key] = alpha * torch.mm( 511 | weight_up.cuda(), weight_down.cuda()).to(dtype) 512 | 513 | # update visited list 514 | for item in pair_keys: 515 | visited.append(item) 516 | return text_encoder_state_dict, unet_state_dict 517 | else: 518 | # model trained by diffusers api (lora attn only in unet) 519 | state_dict = torch.load(lora_path) 520 | multied_state_dict = {} 521 | for k, v in state_dict.items(): 522 | if '_lora.up.weight' in k: 523 | up_weight = v 524 | new_k = 'unet.' + k.replace( 525 | '_lora.up.weight', '.weight').replace( 526 | 'processor.', '').replace('to_out', 'to_out.0') 527 | down_weight = state_dict[k.replace('up.weight', 'down.weight')] 528 | # xxxx_lora.up.weight 529 | multied_state_dict[new_k] = torch.matmul( 530 | up_weight.cuda(), down_weight.cuda()) 531 | return {}, multied_state_dict 532 | 533 | 534 | def patch_conv_weights(model: nn.Module) -> nn.Module: 535 | """ 536 | Patch the convolutional weights in the model to be compatible with NHWC format. 537 | For model acceleration in blade optimization 538 | 539 | Args: 540 | model (nn.Module): The model to be patched. 541 | 542 | Returns: 543 | nn.Module: The patched model. 544 | """ 545 | origin_state_dict = model.state_dict() 546 | state_dict = {} 547 | for k, v in origin_state_dict.items(): 548 | if k.endswith('_nhwc'): 549 | state_dict[k] = origin_state_dict[k[:-5]].permute([0, 2, 3, 1]) 550 | model.load_state_dict(state_dict, strict=False) 551 | return model 552 | 553 | 554 | def merge_lora_weights(origin: Dict[str, Tensor], to_merge: Dict[str, Tensor], 555 | scale: float) -> Dict[str, Tensor]: 556 | """ 557 | Merge LoRA weights into the origin dictionary with the specified scale. 558 | 559 | Args: 560 | origin (Dict[str, Tensor]): The original weights dictionary. 561 | to_merge (Dict[str, Tensor]): The weights dictionary to be merged. 562 | scale (float): The scaling factor for the merged weights. 563 | 564 | Returns: 565 | Dict[str, Tensor]: The merged weights dictionary. 
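    Example (illustrative sketch; the key name and tensor are dummies -- in
    practice the dictionaries come from load_lora_and_mul and the merged
    result is applied with apply_lora_weights; a CUDA device is required
    because the weights are moved to 'cuda' internally)::

        merged = {}
        lora_a = {'unet.up_blocks.0.attentions.0.proj_in.weight': torch.ones(4, 4)}
        merge_lora_weights(merged, lora_a, scale=0.5)
        merge_lora_weights(merged, lora_a, scale=0.25)  # scales accumulate in place
        # merged['unet.up_blocks.0.attentions.0.proj_in.weight'] now equals 0.75 * torch.ones(4, 4)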
566 | """ 567 | for k, v in to_merge.items(): 568 | v = v.to('cuda') 569 | weight = v * scale 570 | if origin.get(k, None) is None: 571 | origin[k] = weight 572 | else: 573 | origin[k] += weight 574 | 575 | 576 | def apply_lora_weights(model_state_dict: Dict[str, Tensor], 577 | weights: Dict[str, Tensor]) -> None: 578 | """ 579 | Apply LoRA weights to the model state dictionary. 580 | 581 | Args: 582 | model_state_dict (Dict[str, Tensor]): The model's state dictionary. 583 | weights (Dict[str, Tensor]): The LoRA weights to be applied. 584 | 585 | Returns: 586 | None 587 | """ 588 | 589 | with torch.no_grad(): 590 | for k, v in weights.items(): 591 | v = v.to('cuda') 592 | model_state_dict[k].add_(v) 593 | 594 | 595 | def unload_lora_weights(model_state_dict: Dict[str, Tensor], 596 | weights: Dict[str, Tensor]) -> None: 597 | """ 598 | Unload LoRA weights from the model state dictionary. 599 | 600 | Args: 601 | model_state_dict (Dict[str, Tensor]): The model's state dictionary. 602 | weights (Dict[str, Tensor]): The LoRA weights to be unloaded. 603 | 604 | Returns: 605 | None 606 | """ 607 | with torch.no_grad(): 608 | for k, v in weights.items(): 609 | v = v.to('cuda') 610 | model_state_dict[k].sub_(v) 611 | 612 | 613 | def load_attn_procs(pipe, 614 | attn_procs_paths: Union[str, List[str]], 615 | scales: Union[float, List[float]] = 0.75) -> None: 616 | """ 617 | Load and merge multiple lora model weights into the pipeline. 618 | 619 | Args: 620 | pipe: Stable diffusion pipeline 621 | attn_procs_paths (Union[str, List[str]]): The paths to the attention processor weights. 622 | scales (Union[float, List[float]], optional): The scaling factor(s) for the merged weights. Default is 0.75. 623 | 624 | Returns: 625 | None 626 | """ 627 | 628 | if isinstance(scales, str): 629 | attn_procs_paths = [attn_procs_paths] 630 | if isinstance(scales, float): 631 | scales = [scales] * len(attn_procs_paths) 632 | 633 | pipe.text_encoder_merged_weights, pipe.unet_merged_weights = {}, {} 634 | 635 | for attn_procs_path, scale in zip(attn_procs_paths, scales): 636 | text_encoder_state_dict, unet_state_dict = load_lora_and_mul( 637 | attn_procs_path, dtype=torch.half) 638 | # merge weights from multiple lora models with scale 639 | merge_lora_weights(pipe.text_encoder_merged_weights, 640 | text_encoder_state_dict, scale) 641 | merge_lora_weights(pipe.unet_merged_weights, unet_state_dict, scale) 642 | 643 | # apply the final lora weights to text_encoder and unet 644 | apply_lora_weights(pipe.text_encoder.state_dict(), 645 | pipe.text_encoder_merged_weights) 646 | apply_lora_weights(pipe.unet.state_dict(), pipe.unet_merged_weights) 647 | 648 | patch_conv_weights(pipe.unet) 649 | 650 | 651 | def unload_lora(pipe): 652 | # unload the lora weight after each infer 653 | unload_lora_weights(pipe.text_encoder.state_dict(), 654 | pipe.text_encoder_merged_weights) 655 | unload_lora_weights(pipe.unet.state_dict(), pipe.unet_merged_weights) 656 | pipe.text_encoder_merged_weights, pipe.unet_merged_weights = {}, {} 657 | patch_conv_weights(pipe.unet) 658 | -------------------------------------------------------------------------------- /diffusers/lpw_stable_diffusion.py: -------------------------------------------------------------------------------- 1 | # This code is borrowed from the HuggingFace diffusers library 2 | # Source: https://github.com/huggingface/diffusers/blob/main/examples/community/lpw_stable_diffusion.py 3 | # License: Apache-2.0 4 | 5 | import inspect 6 | import re 7 | from typing import Callable, List, 
Optional, Union 8 | 9 | import numpy as np 10 | import PIL 11 | from packaging import version 12 | 13 | import diffusers 14 | import torch 15 | from diffusers import SchedulerMixin, StableDiffusionPipeline 16 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 17 | from diffusers.pipelines.stable_diffusion import ( 18 | StableDiffusionPipelineOutput, StableDiffusionSafetyChecker) 19 | from diffusers.utils import logging 20 | from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer 21 | 22 | try: 23 | from diffusers.utils import PIL_INTERPOLATION 24 | except ImportError: 25 | if version.parse(version.parse( 26 | PIL.__version__).base_version) >= version.parse('9.1.0'): 27 | PIL_INTERPOLATION = { 28 | 'linear': PIL.Image.Resampling.BILINEAR, 29 | 'bilinear': PIL.Image.Resampling.BILINEAR, 30 | 'bicubic': PIL.Image.Resampling.BICUBIC, 31 | 'lanczos': PIL.Image.Resampling.LANCZOS, 32 | 'nearest': PIL.Image.Resampling.NEAREST, 33 | } 34 | else: 35 | PIL_INTERPOLATION = { 36 | 'linear': PIL.Image.LINEAR, 37 | 'bilinear': PIL.Image.BILINEAR, 38 | 'bicubic': PIL.Image.BICUBIC, 39 | 'lanczos': PIL.Image.LANCZOS, 40 | 'nearest': PIL.Image.NEAREST, 41 | } 42 | # ------------------------------------------------------------------------------ 43 | 44 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 45 | 46 | re_attention = re.compile( 47 | r""" 48 | \\\(| 49 | \\\)| 50 | \\\[| 51 | \\]| 52 | \\\\| 53 | \\| 54 | \(| 55 | \[| 56 | :([+-]?[.\d]+)\)| 57 | \)| 58 | ]| 59 | [^\\()\[\]:]+| 60 | : 61 | """, 62 | re.X, 63 | ) 64 | 65 | 66 | def parse_prompt_attention(text): 67 | """ 68 | Parses a string with attention tokens and returns a list of pairs: text and its associated weight. 69 | Accepted tokens are: 70 | (abc) - increases attention to abc by a multiplier of 1.1 71 | (abc:3.12) - increases attention to abc by a multiplier of 3.12 72 | [abc] - decreases attention to abc by a multiplier of 1.1 73 | \( - literal character '(' 74 | \[ - literal character '[' 75 | \) - literal character ')' 76 | \] - literal character ']' 77 | \\ - literal character '\' 78 | anything else - just text 79 | >>> parse_prompt_attention('normal text') 80 | [['normal text', 1.0]] 81 | >>> parse_prompt_attention('an (important) word') 82 | [['an ', 1.0], ['important', 1.1], [' word', 1.0]] 83 | >>> parse_prompt_attention('(unbalanced') 84 | [['unbalanced', 1.1]] 85 | >>> parse_prompt_attention('\(literal\]') 86 | [['(literal]', 1.0]] 87 | >>> parse_prompt_attention('(unnecessary)(parens)') 88 | [['unnecessaryparens', 1.1]] 89 | >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') 90 | [['a ', 1.0], 91 | ['house', 1.5730000000000004], 92 | [' ', 1.1], 93 | ['on', 1.0], 94 | [' a ', 1.1], 95 | ['hill', 0.55], 96 | [', sun, ', 1.1], 97 | ['sky', 1.4641000000000006], 98 | ['.', 1.1]] 99 | """ 100 | 101 | res = [] 102 | round_brackets = [] 103 | square_brackets = [] 104 | 105 | round_bracket_multiplier = 1.1 106 | square_bracket_multiplier = 1 / 1.1 107 | 108 | def multiply_range(start_position, multiplier): 109 | for p in range(start_position, len(res)): 110 | res[p][1] *= multiplier 111 | 112 | for m in re_attention.finditer(text): 113 | text = m.group(0) 114 | weight = m.group(1) 115 | 116 | if text.startswith('\\'): 117 | res.append([text[1:], 1.0]) 118 | elif text == '(': 119 | round_brackets.append(len(res)) 120 | elif text == '[': 121 | square_brackets.append(len(res)) 122 | elif weight is not None and len(round_brackets) > 0: 123 | 
multiply_range(round_brackets.pop(), float(weight)) 124 | elif text == ')' and len(round_brackets) > 0: 125 | multiply_range(round_brackets.pop(), round_bracket_multiplier) 126 | elif text == ']' and len(square_brackets) > 0: 127 | multiply_range(square_brackets.pop(), square_bracket_multiplier) 128 | else: 129 | res.append([text, 1.0]) 130 | 131 | for pos in round_brackets: 132 | multiply_range(pos, round_bracket_multiplier) 133 | 134 | for pos in square_brackets: 135 | multiply_range(pos, square_bracket_multiplier) 136 | 137 | if len(res) == 0: 138 | res = [['', 1.0]] 139 | 140 | # merge runs of identical weights 141 | i = 0 142 | while i + 1 < len(res): 143 | if res[i][1] == res[i + 1][1]: 144 | res[i][0] += res[i + 1][0] 145 | res.pop(i + 1) 146 | else: 147 | i += 1 148 | 149 | return res 150 | 151 | 152 | def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], 153 | max_length: int): 154 | r""" 155 | Tokenize a list of prompts and return its tokens with weights of each token. 156 | 157 | No padding, starting or ending token is included. 158 | """ 159 | tokens = [] 160 | weights = [] 161 | truncated = False 162 | for text in prompt: 163 | texts_and_weights = parse_prompt_attention(text) 164 | text_token = [] 165 | text_weight = [] 166 | for word, weight in texts_and_weights: 167 | # tokenize and discard the starting and the ending token 168 | token = pipe.tokenizer(word).input_ids[1:-1] 169 | text_token += token 170 | # copy the weight by length of token 171 | text_weight += [weight] * len(token) 172 | # stop if the text is too long (longer than truncation limit) 173 | if len(text_token) > max_length: 174 | truncated = True 175 | break 176 | # truncate 177 | if len(text_token) > max_length: 178 | truncated = True 179 | text_token = text_token[:max_length] 180 | text_weight = text_weight[:max_length] 181 | tokens.append(text_token) 182 | weights.append(text_weight) 183 | if truncated: 184 | logger.warning( 185 | 'Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples' 186 | ) 187 | return tokens, weights 188 | 189 | 190 | def pad_tokens_and_weights(tokens, 191 | weights, 192 | max_length, 193 | bos, 194 | eos, 195 | pad, 196 | no_boseos_middle=True, 197 | chunk_length=77): 198 | r""" 199 | Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. 
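    Example (illustrative sketch; 49406/49407 are the CLIP BOS/EOS ids and the
    other token ids are dummies)::

        tokens, weights = pad_tokens_and_weights(
            [[320, 1125]], [[1.0, 1.1]], max_length=77,
            bos=49406, eos=49407, pad=49407, chunk_length=77)
        # tokens[0]  -> [bos, 320, 1125, pad, ..., pad, eos];  len(tokens[0]) == 77
        # weights[0] -> [1.0, 1.0, 1.1, 1.0, ..., 1.0];        len(weights[0]) == 77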
200 | """ 201 | max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) 202 | weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length 203 | for i in range(len(tokens)): 204 | tokens[i] = [ 205 | bos 206 | ] + tokens[i] + [pad] * (max_length - 1 - len(tokens[i]) - 1) + [eos] 207 | if no_boseos_middle: 208 | weights[i] = [ 209 | 1.0 210 | ] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) 211 | else: 212 | w = [] 213 | if len(weights[i]) == 0: 214 | w = [1.0] * weights_length 215 | else: 216 | for j in range(max_embeddings_multiples): 217 | w.append(1.0) # weight for starting token in this chunk 218 | w += weights[i][j * (chunk_length - 219 | 2):min(len(weights[i]), (j + 1) * 220 | (chunk_length - 2))] 221 | w.append(1.0) # weight for ending token in this chunk 222 | w += [1.0] * (weights_length - len(w)) 223 | weights[i] = w[:] 224 | 225 | return tokens, weights 226 | 227 | 228 | def get_unweighted_text_embeddings( 229 | pipe: StableDiffusionPipeline, 230 | text_input: torch.Tensor, 231 | chunk_length: int, 232 | no_boseos_middle: Optional[bool] = True, 233 | ): 234 | """ 235 | When the length of tokens is a multiple of the capacity of the text encoder, 236 | it should be split into chunks and sent to the text encoder individually. 237 | """ 238 | max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) 239 | if max_embeddings_multiples > 1: 240 | text_embeddings = [] 241 | for i in range(max_embeddings_multiples): 242 | # extract the i-th chunk 243 | text_input_chunk = text_input[:, i * (chunk_length - 2):(i + 1) * 244 | (chunk_length - 2) + 2].clone() 245 | 246 | # cover the head and the tail by the starting and the ending tokens 247 | text_input_chunk[:, 0] = text_input[0, 0] 248 | text_input_chunk[:, -1] = text_input[0, -1] 249 | text_embedding = pipe.text_encoder(text_input_chunk)[0] 250 | 251 | if no_boseos_middle: 252 | if i == 0: 253 | # discard the ending token 254 | text_embedding = text_embedding[:, :-1] 255 | elif i == max_embeddings_multiples - 1: 256 | # discard the starting token 257 | text_embedding = text_embedding[:, 1:] 258 | else: 259 | # discard both starting and ending tokens 260 | text_embedding = text_embedding[:, 1:-1] 261 | 262 | text_embeddings.append(text_embedding) 263 | text_embeddings = torch.concat(text_embeddings, axis=1) 264 | else: 265 | text_embeddings = pipe.text_encoder(text_input)[0] 266 | return text_embeddings 267 | 268 | 269 | def get_weighted_text_embeddings( 270 | pipe: StableDiffusionPipeline, 271 | prompt: Union[str, List[str]], 272 | uncond_prompt: Optional[Union[str, List[str]]] = None, 273 | max_embeddings_multiples: Optional[int] = 3, 274 | no_boseos_middle: Optional[bool] = False, 275 | skip_parsing: Optional[bool] = False, 276 | skip_weighting: Optional[bool] = False, 277 | ): 278 | r""" 279 | Prompts can be assigned with local weights using brackets. For example, 280 | prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', 281 | and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. 282 | 283 | Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. 284 | 285 | Args: 286 | pipe (`StableDiffusionPipeline`): 287 | Pipe to provide access to the tokenizer and the text encoder. 288 | prompt (`str` or `List[str]`): 289 | The prompt or prompts to guide the image generation. 
290 | uncond_prompt (`str` or `List[str]`): 291 | The unconditional prompt or prompts for guide the image generation. If unconditional prompt 292 | is provided, the embeddings of prompt and uncond_prompt are concatenated. 293 | max_embeddings_multiples (`int`, *optional*, defaults to `3`): 294 | The max multiple length of prompt embeddings compared to the max output length of text encoder. 295 | no_boseos_middle (`bool`, *optional*, defaults to `False`): 296 | If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and 297 | ending token in each of the chunk in the middle. 298 | skip_parsing (`bool`, *optional*, defaults to `False`): 299 | Skip the parsing of brackets. 300 | skip_weighting (`bool`, *optional*, defaults to `False`): 301 | Skip the weighting. When the parsing is skipped, it is forced True. 302 | """ 303 | max_length = (pipe.tokenizer.model_max_length - 304 | 2) * max_embeddings_multiples + 2 305 | if isinstance(prompt, str): 306 | prompt = [prompt] 307 | 308 | if not skip_parsing: 309 | prompt_tokens, prompt_weights = get_prompts_with_weights( 310 | pipe, prompt, max_length - 2) 311 | if uncond_prompt is not None: 312 | if isinstance(uncond_prompt, str): 313 | uncond_prompt = [uncond_prompt] 314 | uncond_tokens, uncond_weights = get_prompts_with_weights( 315 | pipe, uncond_prompt, max_length - 2) 316 | else: 317 | prompt_tokens = [ 318 | token[1:-1] for token in pipe.tokenizer( 319 | prompt, max_length=max_length, truncation=True).input_ids 320 | ] 321 | prompt_weights = [[1.0] * len(token) for token in prompt_tokens] 322 | if uncond_prompt is not None: 323 | if isinstance(uncond_prompt, str): 324 | uncond_prompt = [uncond_prompt] 325 | uncond_tokens = [ 326 | token[1:-1] 327 | for token in pipe.tokenizer(uncond_prompt, 328 | max_length=max_length, 329 | truncation=True).input_ids 330 | ] 331 | uncond_weights = [[1.0] * len(token) for token in uncond_tokens] 332 | 333 | # round up the longest length of tokens to a multiple of (model_max_length - 2) 334 | max_length = max([len(token) for token in prompt_tokens]) 335 | if uncond_prompt is not None: 336 | max_length = max(max_length, 337 | max([len(token) for token in uncond_tokens])) 338 | 339 | max_embeddings_multiples = min( 340 | max_embeddings_multiples, 341 | (max_length - 1) // (pipe.tokenizer.model_max_length - 2) + 1, 342 | ) 343 | max_embeddings_multiples = max(1, max_embeddings_multiples) 344 | max_length = (pipe.tokenizer.model_max_length - 345 | 2) * max_embeddings_multiples + 2 346 | 347 | # pad the length of tokens and weights 348 | if isinstance(pipe.tokenizer, CLIPTokenizer): 349 | bos = pipe.tokenizer.bos_token_id 350 | eos = pipe.tokenizer.eos_token_id 351 | else: 352 | bos = pipe.tokenizer.cls_token_id 353 | eos = pipe.tokenizer.sep_token_id 354 | pad = getattr(pipe.tokenizer, 'pad_token_id', eos) 355 | prompt_tokens, prompt_weights = pad_tokens_and_weights( 356 | prompt_tokens, 357 | prompt_weights, 358 | max_length, 359 | bos, 360 | eos, 361 | pad, 362 | no_boseos_middle=no_boseos_middle, 363 | chunk_length=pipe.tokenizer.model_max_length, 364 | ) 365 | prompt_tokens = torch.tensor(prompt_tokens, 366 | dtype=torch.long, 367 | device=pipe.device) 368 | if uncond_prompt is not None: 369 | uncond_tokens, uncond_weights = pad_tokens_and_weights( 370 | uncond_tokens, 371 | uncond_weights, 372 | max_length, 373 | bos, 374 | eos, 375 | pad, 376 | no_boseos_middle=no_boseos_middle, 377 | chunk_length=pipe.tokenizer.model_max_length, 378 | ) 379 | uncond_tokens = 
torch.tensor(uncond_tokens, 380 | dtype=torch.long, 381 | device=pipe.device) 382 | 383 | # get the embeddings 384 | text_embeddings = get_unweighted_text_embeddings( 385 | pipe, 386 | prompt_tokens, 387 | pipe.tokenizer.model_max_length, 388 | no_boseos_middle=no_boseos_middle, 389 | ) 390 | prompt_weights = torch.tensor(prompt_weights, 391 | dtype=text_embeddings.dtype, 392 | device=pipe.device) 393 | if uncond_prompt is not None: 394 | uncond_embeddings = get_unweighted_text_embeddings( 395 | pipe, 396 | uncond_tokens, 397 | pipe.tokenizer.model_max_length, 398 | no_boseos_middle=no_boseos_middle, 399 | ) 400 | uncond_weights = torch.tensor(uncond_weights, 401 | dtype=uncond_embeddings.dtype, 402 | device=pipe.device) 403 | 404 | # assign weights to the prompts and normalize in the sense of mean 405 | # TODO: should we normalize by chunk or in a whole (current implementation)? 406 | if (not skip_parsing) and (not skip_weighting): 407 | previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to( 408 | text_embeddings.dtype) 409 | text_embeddings *= prompt_weights.unsqueeze(-1) 410 | current_mean = text_embeddings.float().mean(axis=[-2, -1]).to( 411 | text_embeddings.dtype) 412 | text_embeddings *= (previous_mean / 413 | current_mean).unsqueeze(-1).unsqueeze(-1) 414 | if uncond_prompt is not None: 415 | previous_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to( 416 | uncond_embeddings.dtype) 417 | uncond_embeddings *= uncond_weights.unsqueeze(-1) 418 | current_mean = uncond_embeddings.float().mean(axis=[-2, -1]).to( 419 | uncond_embeddings.dtype) 420 | uncond_embeddings *= (previous_mean / 421 | current_mean).unsqueeze(-1).unsqueeze(-1) 422 | 423 | if uncond_prompt is not None: 424 | return text_embeddings, uncond_embeddings 425 | return text_embeddings, None 426 | 427 | 428 | def preprocess_image(image): 429 | w, h = image.size 430 | w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 431 | image = image.resize((w, h), resample=PIL_INTERPOLATION['lanczos']) 432 | image = np.array(image).astype(np.float32) / 255.0 433 | image = image[None].transpose(0, 3, 1, 2) 434 | image = torch.from_numpy(image) 435 | return 2.0 * image - 1.0 436 | 437 | 438 | def preprocess_mask(mask, scale_factor=8): 439 | mask = mask.convert('L') 440 | w, h = mask.size 441 | w, h = (x - x % 32 for x in (w, h)) # resize to integer multiple of 32 442 | mask = mask.resize((w // scale_factor, h // scale_factor), 443 | resample=PIL_INTERPOLATION['nearest']) 444 | mask = np.array(mask).astype(np.float32) / 255.0 445 | mask = np.tile(mask, (4, 1, 1)) 446 | mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? 447 | mask = 1 - mask # repaint white, keep black 448 | mask = torch.from_numpy(mask) 449 | return mask 450 | 451 | 452 | class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline): 453 | r""" 454 | Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing 455 | weighting in prompt. 456 | 457 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 458 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 459 | 460 | Args: 461 | vae ([`AutoencoderKL`]): 462 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 463 | text_encoder ([`CLIPTextModel`]): 464 | Frozen text-encoder. 
Stable Diffusion uses the text portion of 465 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 466 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 467 | tokenizer (`CLIPTokenizer`): 468 | Tokenizer of class 469 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 470 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 471 | scheduler ([`SchedulerMixin`]): 472 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 473 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 474 | safety_checker ([`StableDiffusionSafetyChecker`]): 475 | Classification module that estimates whether generated images could be considered offensive or harmful. 476 | Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details. 477 | feature_extractor ([`CLIPImageProcessor`]): 478 | Model that extracts features from generated images to be used as inputs for the `safety_checker`. 479 | """ 480 | 481 | if version.parse(version.parse( 482 | diffusers.__version__).base_version) >= version.parse('0.9.0'): 483 | 484 | def __init__( 485 | self, 486 | vae: AutoencoderKL, 487 | text_encoder: CLIPTextModel, 488 | tokenizer: CLIPTokenizer, 489 | unet: UNet2DConditionModel, 490 | scheduler: SchedulerMixin, 491 | safety_checker: StableDiffusionSafetyChecker, 492 | feature_extractor: CLIPImageProcessor, 493 | requires_safety_checker: bool = True, 494 | ): 495 | super().__init__( 496 | vae=vae, 497 | text_encoder=text_encoder, 498 | tokenizer=tokenizer, 499 | unet=unet, 500 | scheduler=scheduler, 501 | safety_checker=safety_checker, 502 | feature_extractor=feature_extractor, 503 | requires_safety_checker=requires_safety_checker, 504 | ) 505 | self.__init__additional__() 506 | 507 | else: 508 | 509 | def __init__( 510 | self, 511 | vae: AutoencoderKL, 512 | text_encoder: CLIPTextModel, 513 | tokenizer: CLIPTokenizer, 514 | unet: UNet2DConditionModel, 515 | scheduler: SchedulerMixin, 516 | safety_checker: StableDiffusionSafetyChecker, 517 | feature_extractor: CLIPImageProcessor, 518 | ): 519 | super().__init__( 520 | vae=vae, 521 | text_encoder=text_encoder, 522 | tokenizer=tokenizer, 523 | unet=unet, 524 | scheduler=scheduler, 525 | safety_checker=safety_checker, 526 | feature_extractor=feature_extractor, 527 | ) 528 | self.__init__additional__() 529 | 530 | def __init__additional__(self): 531 | if not hasattr(self, 'vae_scale_factor'): 532 | setattr(self, 'vae_scale_factor', 533 | 2**(len(self.vae.config.block_out_channels) - 1)) 534 | 535 | @property 536 | def _execution_device(self): 537 | r""" 538 | Returns the device on which the pipeline's models will be executed. After calling 539 | `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module 540 | hooks. 
541 | """ 542 | if self.device != torch.device('meta') or not hasattr( 543 | self.unet, '_hf_hook'): 544 | return self.device 545 | for module in self.unet.modules(): 546 | if (hasattr(module, '_hf_hook') 547 | and hasattr(module._hf_hook, 'execution_device') 548 | and module._hf_hook.execution_device is not None): 549 | return torch.device(module._hf_hook.execution_device) 550 | return self.device 551 | 552 | def _encode_prompt( 553 | self, 554 | prompt, 555 | device, 556 | num_images_per_prompt, 557 | do_classifier_free_guidance, 558 | negative_prompt, 559 | max_embeddings_multiples, 560 | ): 561 | r""" 562 | Encodes the prompt into text encoder hidden states. 563 | 564 | Args: 565 | prompt (`str` or `list(int)`): 566 | prompt to be encoded 567 | device: (`torch.device`): 568 | torch device 569 | num_images_per_prompt (`int`): 570 | number of images that should be generated per prompt 571 | do_classifier_free_guidance (`bool`): 572 | whether to use classifier free guidance or not 573 | negative_prompt (`str` or `List[str]`): 574 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 575 | if `guidance_scale` is less than `1`). 576 | max_embeddings_multiples (`int`, *optional*, defaults to `3`): 577 | The max multiple length of prompt embeddings compared to the max output length of text encoder. 578 | """ 579 | batch_size = len(prompt) if isinstance(prompt, list) else 1 580 | 581 | if negative_prompt is None: 582 | negative_prompt = [''] * batch_size 583 | elif isinstance(negative_prompt, str): 584 | negative_prompt = [negative_prompt] * batch_size 585 | if batch_size != len(negative_prompt): 586 | raise ValueError( 587 | f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:' 588 | f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches' 589 | ' the batch size of `prompt`.') 590 | 591 | text_embeddings, uncond_embeddings = get_weighted_text_embeddings( 592 | pipe=self, 593 | prompt=prompt, 594 | uncond_prompt=negative_prompt 595 | if do_classifier_free_guidance else None, 596 | max_embeddings_multiples=max_embeddings_multiples, 597 | ) 598 | bs_embed, seq_len, _ = text_embeddings.shape 599 | text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1) 600 | text_embeddings = text_embeddings.view( 601 | bs_embed * num_images_per_prompt, seq_len, -1) 602 | 603 | if do_classifier_free_guidance: 604 | bs_embed, seq_len, _ = uncond_embeddings.shape 605 | uncond_embeddings = uncond_embeddings.repeat( 606 | 1, num_images_per_prompt, 1) 607 | uncond_embeddings = uncond_embeddings.view( 608 | bs_embed * num_images_per_prompt, seq_len, -1) 609 | text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) 610 | 611 | return text_embeddings 612 | 613 | def check_inputs(self, prompt, height, width, strength, callback_steps): 614 | if not isinstance(prompt, str) and not isinstance(prompt, list): 615 | raise ValueError( 616 | f'`prompt` has to be of type `str` or `list` but is {type(prompt)}' 617 | ) 618 | 619 | if strength < 0 or strength > 1: 620 | raise ValueError( 621 | f'The value of strength should in [0.0, 1.0] but is {strength}' 622 | ) 623 | 624 | if height % 8 != 0 or width % 8 != 0: 625 | raise ValueError( 626 | f'`height` and `width` have to be divisible by 8 but are {height} and {width}.' 
627 | ) 628 | 629 | if (callback_steps is None) or (callback_steps is not None and 630 | (not isinstance(callback_steps, int) 631 | or callback_steps <= 0)): 632 | raise ValueError( 633 | f'`callback_steps` has to be a positive integer but is {callback_steps} of type' 634 | f' {type(callback_steps)}.') 635 | 636 | def get_timesteps(self, num_inference_steps, strength, device, 637 | is_text2img): 638 | if is_text2img: 639 | return self.scheduler.timesteps.to(device), num_inference_steps 640 | else: 641 | # get the original timestep using init_timestep 642 | offset = self.scheduler.config.get('steps_offset', 0) 643 | init_timestep = int(num_inference_steps * strength) + offset 644 | init_timestep = min(init_timestep, num_inference_steps) 645 | 646 | t_start = max(num_inference_steps - init_timestep + offset, 0) 647 | timesteps = self.scheduler.timesteps[t_start:].to(device) 648 | return timesteps, num_inference_steps - t_start 649 | 650 | def run_safety_checker(self, image, device, dtype): 651 | if self.safety_checker is not None: 652 | safety_checker_input = self.feature_extractor( 653 | self.numpy_to_pil(image), return_tensors='pt').to(device) 654 | image, has_nsfw_concept = self.safety_checker( 655 | images=image, 656 | clip_input=safety_checker_input.pixel_values.to(dtype)) 657 | else: 658 | has_nsfw_concept = None 659 | return image, has_nsfw_concept 660 | 661 | def decode_latents(self, latents): 662 | latents = 1 / 0.18215 * latents 663 | image = self.vae.decode(latents).sample 664 | image = (image / 2 + 0.5).clamp(0, 1) 665 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 666 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 667 | return image 668 | 669 | def prepare_extra_step_kwargs(self, generator, eta): 670 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 671 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
672 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 673 | # and should be between [0, 1] 674 | 675 | accepts_eta = 'eta' in set( 676 | inspect.signature(self.scheduler.step).parameters.keys()) 677 | extra_step_kwargs = {} 678 | if accepts_eta: 679 | extra_step_kwargs['eta'] = eta 680 | 681 | # check if the scheduler accepts generator 682 | accepts_generator = 'generator' in set( 683 | inspect.signature(self.scheduler.step).parameters.keys()) 684 | if accepts_generator: 685 | extra_step_kwargs['generator'] = generator 686 | return extra_step_kwargs 687 | 688 | def prepare_latents(self, 689 | image, 690 | timestep, 691 | batch_size, 692 | height, 693 | width, 694 | dtype, 695 | device, 696 | generator, 697 | latents=None): 698 | if image is None: 699 | shape = ( 700 | batch_size, 701 | self.unet.config.in_channels, 702 | height // self.vae_scale_factor, 703 | width // self.vae_scale_factor, 704 | ) 705 | 706 | if latents is None: 707 | if device.type == 'mps': 708 | # randn does not work reproducibly on mps 709 | latents = torch.randn(shape, 710 | generator=generator, 711 | device='cpu', 712 | dtype=dtype).to(device) 713 | else: 714 | latents = torch.randn(shape, 715 | generator=generator, 716 | device=device, 717 | dtype=dtype) 718 | else: 719 | if latents.shape != shape: 720 | raise ValueError( 721 | f'Unexpected latents shape, got {latents.shape}, expected {shape}' 722 | ) 723 | latents = latents.to(device) 724 | 725 | # scale the initial noise by the standard deviation required by the scheduler 726 | latents = latents * self.scheduler.init_noise_sigma 727 | return latents, None, None 728 | else: 729 | init_latent_dist = self.vae.encode(image).latent_dist 730 | init_latents = init_latent_dist.sample(generator=generator) 731 | init_latents = 0.18215 * init_latents 732 | init_latents = torch.cat([init_latents] * batch_size, dim=0) 733 | init_latents_orig = init_latents 734 | shape = init_latents.shape 735 | 736 | # add noise to latents using the timesteps 737 | if device.type == 'mps': 738 | noise = torch.randn(shape, 739 | generator=generator, 740 | device='cpu', 741 | dtype=dtype).to(device) 742 | else: 743 | noise = torch.randn(shape, 744 | generator=generator, 745 | device=device, 746 | dtype=dtype) 747 | latents = self.scheduler.add_noise(init_latents, noise, timestep) 748 | return latents, init_latents_orig, noise 749 | 750 | @torch.no_grad() 751 | def __call__( 752 | self, 753 | prompt: Union[str, List[str]], 754 | negative_prompt: Optional[Union[str, List[str]]] = None, 755 | image: Union[torch.FloatTensor, PIL.Image.Image] = None, 756 | mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, 757 | height: int = 512, 758 | width: int = 512, 759 | num_inference_steps: int = 50, 760 | guidance_scale: float = 7.5, 761 | strength: float = 0.8, 762 | num_images_per_prompt: Optional[int] = 1, 763 | eta: float = 0.0, 764 | generator: Optional[torch.Generator] = None, 765 | latents: Optional[torch.FloatTensor] = None, 766 | max_embeddings_multiples: Optional[int] = 3, 767 | output_type: Optional[str] = 'pil', 768 | return_dict: bool = True, 769 | callback: Optional[Callable[[int, int, torch.FloatTensor], 770 | None]] = None, 771 | is_cancelled_callback: Optional[Callable[[], bool]] = None, 772 | callback_steps: int = 1, 773 | ): 774 | r""" 775 | Function invoked when calling the pipeline for generation. 776 | 777 | Args: 778 | prompt (`str` or `List[str]`): 779 | The prompt or prompts to guide the image generation. 
780 | negative_prompt (`str` or `List[str]`, *optional*):
781 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
782 | if `guidance_scale` is less than `1`).
783 | image (`torch.FloatTensor` or `PIL.Image.Image`):
784 | `Image`, or tensor representing an image batch, that will be used as the starting point for the
785 | process.
786 | mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
787 | `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
788 | replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
789 | PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
790 | contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
791 | height (`int`, *optional*, defaults to 512):
792 | The height in pixels of the generated image.
793 | width (`int`, *optional*, defaults to 512):
794 | The width in pixels of the generated image.
795 | num_inference_steps (`int`, *optional*, defaults to 50):
796 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
797 | expense of slower inference.
798 | guidance_scale (`float`, *optional*, defaults to 7.5):
799 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
800 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
801 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
802 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
803 | usually at the expense of lower image quality.
804 | strength (`float`, *optional*, defaults to 0.8):
805 | Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
806 | `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
807 | number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
808 | noise will be maximum and the denoising process will run for the full number of iterations specified in
809 | `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
810 | num_images_per_prompt (`int`, *optional*, defaults to 1):
811 | The number of images to generate per prompt.
812 | eta (`float`, *optional*, defaults to 0.0):
813 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
814 | [`schedulers.DDIMScheduler`], will be ignored for others.
815 | generator (`torch.Generator`, *optional*):
816 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
817 | deterministic.
818 | latents (`torch.FloatTensor`, *optional*):
819 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
820 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
821 | tensor will be generated by sampling using the supplied random `generator`.
822 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
823 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
824 | output_type (`str`, *optional*, defaults to `"pil"`):
825 | The output format of the generated image.
Choose between
826 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
827 | return_dict (`bool`, *optional*, defaults to `True`):
828 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
829 | plain tuple.
830 | callback (`Callable`, *optional*):
831 | A function that will be called every `callback_steps` steps during inference. The function will be
832 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
833 | is_cancelled_callback (`Callable`, *optional*):
834 | A function that will be called every `callback_steps` steps during inference. If the function returns
835 | `True`, the inference will be cancelled.
836 | callback_steps (`int`, *optional*, defaults to 1):
837 | The frequency at which the `callback` function will be called. If not specified, the callback will be
838 | called at every step.
839 |
840 | Returns:
841 | `None` if cancelled by `is_cancelled_callback`,
842 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
843 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
844 | When returning a tuple, the first element is a list with the generated images, and the second element is a
845 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
846 | (nsfw) content, according to the `safety_checker`.
847 | """
848 | # 0. Default height and width to unet
849 | height = height or self.unet.config.sample_size * self.vae_scale_factor
850 | width = width or self.unet.config.sample_size * self.vae_scale_factor
851 |
852 | # 1. Check inputs. Raise error if not correct
853 | self.check_inputs(prompt, height, width, strength, callback_steps)
854 |
855 | # 2. Define call parameters
856 | batch_size = 1 if isinstance(prompt, str) else len(prompt)
857 | device = self._execution_device
858 | # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
859 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
860 | # corresponds to doing no classifier free guidance.
861 | do_classifier_free_guidance = guidance_scale > 1.0
862 |
863 | # 3. Encode input prompt
864 | text_embeddings = self._encode_prompt(
865 | prompt,
866 | device,
867 | num_images_per_prompt,
868 | do_classifier_free_guidance,
869 | negative_prompt,
870 | max_embeddings_multiples,
871 | )
872 | dtype = text_embeddings.dtype
873 |
874 | # 4. Preprocess image and mask
875 | if isinstance(image, PIL.Image.Image):
876 | image = preprocess_image(image)
877 | if image is not None:
878 | image = image.to(device=self.device, dtype=dtype)
879 | if isinstance(mask_image, PIL.Image.Image):
880 | mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
881 | if mask_image is not None:
882 | mask = mask_image.to(device=self.device, dtype=dtype)
883 | mask = torch.cat([mask] * batch_size * num_images_per_prompt)
884 | else:
885 | mask = None
886 |
887 | # 5. set timesteps
888 | self.scheduler.set_timesteps(num_inference_steps, device=device)
889 | timesteps, num_inference_steps = self.get_timesteps(
890 | num_inference_steps, strength, device, image is None)
891 | latent_timestep = timesteps[:1].repeat(batch_size *
892 | num_images_per_prompt)
893 |
894 | # 6.
Prepare latent variables 895 | latents, init_latents_orig, noise = self.prepare_latents( 896 | image, 897 | latent_timestep, 898 | batch_size * num_images_per_prompt, 899 | height, 900 | width, 901 | dtype, 902 | device, 903 | generator, 904 | latents, 905 | ) 906 | 907 | # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 908 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 909 | 910 | # 8. Denoising loop 911 | for i, t in enumerate(self.progress_bar(timesteps)): 912 | # expand the latents if we are doing classifier free guidance 913 | latent_model_input = torch.cat( 914 | [latents] * 2) if do_classifier_free_guidance else latents 915 | latent_model_input = self.scheduler.scale_model_input( 916 | latent_model_input, t) 917 | 918 | # predict the noise residual 919 | noise_pred = self.unet( 920 | latent_model_input, t, 921 | encoder_hidden_states=text_embeddings).sample 922 | 923 | # perform guidance 924 | if do_classifier_free_guidance: 925 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 926 | noise_pred = noise_pred_uncond + guidance_scale * ( 927 | noise_pred_text - noise_pred_uncond) 928 | 929 | # compute the previous noisy sample x_t -> x_t-1 930 | latents = self.scheduler.step(noise_pred, t, latents, 931 | **extra_step_kwargs).prev_sample 932 | 933 | if mask is not None: 934 | # masking 935 | init_latents_proper = self.scheduler.add_noise( 936 | init_latents_orig, noise, torch.tensor([t])) 937 | latents = (init_latents_proper * mask) + (latents * (1 - mask)) 938 | 939 | # call the callback, if provided 940 | if i % callback_steps == 0: 941 | if callback is not None: 942 | callback(i, t, latents) 943 | if is_cancelled_callback is not None and is_cancelled_callback( 944 | ): 945 | return None 946 | 947 | # 9. Post-processing 948 | image = self.decode_latents(latents) 949 | 950 | # 10. Run safety checker 951 | image, has_nsfw_concept = self.run_safety_checker( 952 | image, device, text_embeddings.dtype) 953 | 954 | # 11. Convert to PIL 955 | if output_type == 'pil': 956 | image = self.numpy_to_pil(image) 957 | 958 | if not return_dict: 959 | return image, has_nsfw_concept 960 | 961 | return StableDiffusionPipelineOutput( 962 | images=image, nsfw_content_detected=has_nsfw_concept) 963 | 964 | def text2img( 965 | self, 966 | prompt: Union[str, List[str]], 967 | negative_prompt: Optional[Union[str, List[str]]] = None, 968 | height: int = 512, 969 | width: int = 512, 970 | num_inference_steps: int = 50, 971 | guidance_scale: float = 7.5, 972 | num_images_per_prompt: Optional[int] = 1, 973 | eta: float = 0.0, 974 | generator: Optional[torch.Generator] = None, 975 | latents: Optional[torch.FloatTensor] = None, 976 | max_embeddings_multiples: Optional[int] = 3, 977 | output_type: Optional[str] = 'pil', 978 | return_dict: bool = True, 979 | callback: Optional[Callable[[int, int, torch.FloatTensor], 980 | None]] = None, 981 | is_cancelled_callback: Optional[Callable[[], bool]] = None, 982 | callback_steps: int = 1, 983 | ): 984 | r""" 985 | Function for text-to-image generation. 986 | Args: 987 | prompt (`str` or `List[str]`): 988 | The prompt or prompts to guide the image generation. 989 | negative_prompt (`str` or `List[str]`, *optional*): 990 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 991 | if `guidance_scale` is less than `1`). 992 | height (`int`, *optional*, defaults to 512): 993 | The height in pixels of the generated image. 
994 | width (`int`, *optional*, defaults to 512):
995 | The width in pixels of the generated image.
996 | num_inference_steps (`int`, *optional*, defaults to 50):
997 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
998 | expense of slower inference.
999 | guidance_scale (`float`, *optional*, defaults to 7.5):
1000 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1001 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
1002 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1003 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1004 | usually at the expense of lower image quality.
1005 | num_images_per_prompt (`int`, *optional*, defaults to 1):
1006 | The number of images to generate per prompt.
1007 | eta (`float`, *optional*, defaults to 0.0):
1008 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1009 | [`schedulers.DDIMScheduler`], will be ignored for others.
1010 | generator (`torch.Generator`, *optional*):
1011 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1012 | deterministic.
1013 | latents (`torch.FloatTensor`, *optional*):
1014 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1015 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1016 | tensor will be generated by sampling using the supplied random `generator`.
1017 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1018 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
1019 | output_type (`str`, *optional*, defaults to `"pil"`):
1020 | The output format of the generated image. Choose between
1021 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1022 | return_dict (`bool`, *optional*, defaults to `True`):
1023 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1024 | plain tuple.
1025 | callback (`Callable`, *optional*):
1026 | A function that will be called every `callback_steps` steps during inference. The function will be
1027 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1028 | is_cancelled_callback (`Callable`, *optional*):
1029 | A function that will be called every `callback_steps` steps during inference. If the function returns
1030 | `True`, the inference will be cancelled.
1031 | callback_steps (`int`, *optional*, defaults to 1):
1032 | The frequency at which the `callback` function will be called. If not specified, the callback will be
1033 | called at every step.
1034 | Returns:
1035 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1036 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1037 | When returning a tuple, the first element is a list with the generated images, and the second element is a
1038 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1039 | (nsfw) content, according to the `safety_checker`.
1040 | """ 1041 | return self.__call__( 1042 | prompt=prompt, 1043 | negative_prompt=negative_prompt, 1044 | height=height, 1045 | width=width, 1046 | num_inference_steps=num_inference_steps, 1047 | guidance_scale=guidance_scale, 1048 | num_images_per_prompt=num_images_per_prompt, 1049 | eta=eta, 1050 | generator=generator, 1051 | latents=latents, 1052 | max_embeddings_multiples=max_embeddings_multiples, 1053 | output_type=output_type, 1054 | return_dict=return_dict, 1055 | callback=callback, 1056 | is_cancelled_callback=is_cancelled_callback, 1057 | callback_steps=callback_steps, 1058 | ) 1059 | 1060 | def img2img( 1061 | self, 1062 | image: Union[torch.FloatTensor, PIL.Image.Image], 1063 | prompt: Union[str, List[str]], 1064 | negative_prompt: Optional[Union[str, List[str]]] = None, 1065 | strength: float = 0.8, 1066 | num_inference_steps: Optional[int] = 50, 1067 | guidance_scale: Optional[float] = 7.5, 1068 | num_images_per_prompt: Optional[int] = 1, 1069 | eta: Optional[float] = 0.0, 1070 | generator: Optional[torch.Generator] = None, 1071 | max_embeddings_multiples: Optional[int] = 3, 1072 | output_type: Optional[str] = 'pil', 1073 | return_dict: bool = True, 1074 | callback: Optional[Callable[[int, int, torch.FloatTensor], 1075 | None]] = None, 1076 | is_cancelled_callback: Optional[Callable[[], bool]] = None, 1077 | callback_steps: int = 1, 1078 | ): 1079 | r""" 1080 | Function for image-to-image generation. 1081 | Args: 1082 | image (`torch.FloatTensor` or `PIL.Image.Image`): 1083 | `Image`, or tensor representing an image batch, that will be used as the starting point for the 1084 | process. 1085 | prompt (`str` or `List[str]`): 1086 | The prompt or prompts to guide the image generation. 1087 | negative_prompt (`str` or `List[str]`, *optional*): 1088 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored 1089 | if `guidance_scale` is less than `1`). 1090 | strength (`float`, *optional*, defaults to 0.8): 1091 | Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1. 1092 | `image` will be used as a starting point, adding more noise to it the larger the `strength`. The 1093 | number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added 1094 | noise will be maximum and the denoising process will run for the full number of iterations specified in 1095 | `num_inference_steps`. A value of 1, therefore, essentially ignores `image`. 1096 | num_inference_steps (`int`, *optional*, defaults to 50): 1097 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 1098 | expense of slower inference. This parameter will be modulated by `strength`. 1099 | guidance_scale (`float`, *optional*, defaults to 7.5): 1100 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 1101 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 1102 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1103 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 1104 | usually at the expense of lower image quality. 1105 | num_images_per_prompt (`int`, *optional*, defaults to 1): 1106 | The number of images to generate per prompt. 1107 | eta (`float`, *optional*, defaults to 0.0): 1108 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. 
Only applies to
1109 | [`schedulers.DDIMScheduler`], will be ignored for others.
1110 | generator (`torch.Generator`, *optional*):
1111 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1112 | deterministic.
1113 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1114 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
1115 | output_type (`str`, *optional*, defaults to `"pil"`):
1116 | The output format of the generated image. Choose between
1117 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1118 | return_dict (`bool`, *optional*, defaults to `True`):
1119 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1120 | plain tuple.
1121 | callback (`Callable`, *optional*):
1122 | A function that will be called every `callback_steps` steps during inference. The function will be
1123 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1124 | is_cancelled_callback (`Callable`, *optional*):
1125 | A function that will be called every `callback_steps` steps during inference. If the function returns
1126 | `True`, the inference will be cancelled.
1127 | callback_steps (`int`, *optional*, defaults to 1):
1128 | The frequency at which the `callback` function will be called. If not specified, the callback will be
1129 | called at every step.
1130 | Returns:
1131 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1132 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1133 | When returning a tuple, the first element is a list with the generated images, and the second element is a
1134 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1135 | (nsfw) content, according to the `safety_checker`.
1136 | """
1137 | return self.__call__(
1138 | prompt=prompt,
1139 | negative_prompt=negative_prompt,
1140 | image=image,
1141 | num_inference_steps=num_inference_steps,
1142 | guidance_scale=guidance_scale,
1143 | strength=strength,
1144 | num_images_per_prompt=num_images_per_prompt,
1145 | eta=eta,
1146 | generator=generator,
1147 | max_embeddings_multiples=max_embeddings_multiples,
1148 | output_type=output_type,
1149 | return_dict=return_dict,
1150 | callback=callback,
1151 | is_cancelled_callback=is_cancelled_callback,
1152 | callback_steps=callback_steps,
1153 | )
1154 |
1155 | def inpaint(
1156 | self,
1157 | image: Union[torch.FloatTensor, PIL.Image.Image],
1158 | mask_image: Union[torch.FloatTensor, PIL.Image.Image],
1159 | prompt: Union[str, List[str]],
1160 | negative_prompt: Optional[Union[str, List[str]]] = None,
1161 | strength: float = 0.8,
1162 | num_inference_steps: Optional[int] = 50,
1163 | guidance_scale: Optional[float] = 7.5,
1164 | num_images_per_prompt: Optional[int] = 1,
1165 | eta: Optional[float] = 0.0,
1166 | generator: Optional[torch.Generator] = None,
1167 | max_embeddings_multiples: Optional[int] = 3,
1168 | output_type: Optional[str] = 'pil',
1169 | return_dict: bool = True,
1170 | callback: Optional[Callable[[int, int, torch.FloatTensor],
1171 | None]] = None,
1172 | is_cancelled_callback: Optional[Callable[[], bool]] = None,
1173 | callback_steps: int = 1,
1174 | ):
1175 | r"""
1176 | Function for inpainting.
1177 | Args:
1178 | image (`torch.FloatTensor` or `PIL.Image.Image`):
1179 | `Image`, or tensor representing an image batch, that will be used as the starting point for the
1180 | process. This is the image whose masked region will be inpainted.
1181 | mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
1182 | `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1183 | replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
1184 | PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
1185 | contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
1186 | prompt (`str` or `List[str]`):
1187 | The prompt or prompts to guide the image generation.
1188 | negative_prompt (`str` or `List[str]`, *optional*):
1189 | The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
1190 | if `guidance_scale` is less than `1`).
1191 | strength (`float`, *optional*, defaults to 0.8):
1192 | Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
1193 | is 1, the denoising process will be run on the masked area for the full number of iterations specified
1194 | in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
1195 | noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
1196 | num_inference_steps (`int`, *optional*, defaults to 50):
1197 | The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
1198 | the expense of slower inference. This parameter will be modulated by `strength`, as explained above.
1199 | guidance_scale (`float`, *optional*, defaults to 7.5):
1200 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1201 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
1202 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1203 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1204 | usually at the expense of lower image quality.
1205 | num_images_per_prompt (`int`, *optional*, defaults to 1):
1206 | The number of images to generate per prompt.
1207 | eta (`float`, *optional*, defaults to 0.0):
1208 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1209 | [`schedulers.DDIMScheduler`], will be ignored for others.
1210 | generator (`torch.Generator`, *optional*):
1211 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
1212 | deterministic.
1213 | max_embeddings_multiples (`int`, *optional*, defaults to `3`):
1214 | The max multiple length of prompt embeddings compared to the max output length of text encoder.
1215 | output_type (`str`, *optional*, defaults to `"pil"`):
1216 | The output format of the generated image. Choose between
1217 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1218 | return_dict (`bool`, *optional*, defaults to `True`):
1219 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1220 | plain tuple.
1221 | callback (`Callable`, *optional*):
1222 | A function that will be called every `callback_steps` steps during inference.
The function will be
1223 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1224 | is_cancelled_callback (`Callable`, *optional*):
1225 | A function that will be called every `callback_steps` steps during inference. If the function returns
1226 | `True`, the inference will be cancelled.
1227 | callback_steps (`int`, *optional*, defaults to 1):
1228 | The frequency at which the `callback` function will be called. If not specified, the callback will be
1229 | called at every step.
1230 | Returns:
1231 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1232 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1233 | When returning a tuple, the first element is a list with the generated images, and the second element is a
1234 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1235 | (nsfw) content, according to the `safety_checker`.
1236 | """
1237 | return self.__call__(
1238 | prompt=prompt,
1239 | negative_prompt=negative_prompt,
1240 | image=image,
1241 | mask_image=mask_image,
1242 | num_inference_steps=num_inference_steps,
1243 | guidance_scale=guidance_scale,
1244 | strength=strength,
1245 | num_images_per_prompt=num_images_per_prompt,
1246 | eta=eta,
1247 | generator=generator,
1248 | max_embeddings_multiples=max_embeddings_multiples,
1249 | output_type=output_type,
1250 | return_dict=return_dict,
1251 | callback=callback,
1252 | is_cancelled_callback=is_cancelled_callback,
1253 | callback_steps=callback_steps,
1254 | )
1255 | --------------------------------------------------------------------------------
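A minimal usage sketch for the pipeline above. It is illustrative only and not part of the repository files: the model path, input image, prompts, and device are placeholders, it assumes the file is importable as `lpw_stable_diffusion`, and the `(word)` / `[word]` emphasis syntax mentioned in the comments is the usual long-prompt-weighting convention handled by the helper functions earlier in this file.

# Illustrative sketch only; paths, prompts, seed, and device are placeholder assumptions.
import torch
from PIL import Image

from lpw_stable_diffusion import StableDiffusionLongPromptWeightingPipeline

MODEL_PATH = '/path/to/stable-diffusion-model'  # placeholder: a Stable Diffusion v1-style model directory

# Calling from_pretrained on the subclass loads the standard components
# (vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor)
# into this pipeline class.
pipe = StableDiffusionLongPromptWeightingPipeline.from_pretrained(
    MODEL_PATH, torch_dtype=torch.float16).to('cuda')

# text2img: long prompts are split into chunks of the text encoder's capacity and
# re-weighted; parentheses typically up-weight tokens, square brackets down-weight them.
result = pipe.text2img(
    prompt='a (highly detailed) photo of a cat wearing a tiny hat',
    negative_prompt='lowres, watermark',
    height=512,
    width=512,
    num_inference_steps=30,
    guidance_scale=7.5,
    max_embeddings_multiples=3,
    generator=torch.Generator(device='cuda').manual_seed(123),
)
result.images[0].save('text2img.png')

# img2img with the same pipeline; a lower `strength` stays closer to the input image.
init_image = Image.open('input.png').convert('RGB').resize((512, 512))
result = pipe.img2img(
    image=init_image,
    prompt='the same cat as an oil painting',
    strength=0.6,
    num_inference_steps=30,
)
result.images[0].save('img2img.png')

Using the `text2img` / `img2img` / `inpaint` helpers rather than `__call__` directly keeps each task's argument set explicit; all three ultimately dispatch to the same `__call__` shown above.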