├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── builder ├── build_model.py └── requirements.txt └── src ├── comic_generator_xl.py ├── photomaker_id_encoder.py ├── pipeline.py ├── rp_handler.py ├── rp_schema.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 2 | 3 | # Build args 4 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 5 | ENV WORKER_MODEL_DIR=/app/model 6 | ENV WORKER_USE_CUDA=True 7 | ENV WORKER_MODEL_NAME=SG161222/RealVisXL_V4.0 8 | ENV WORKER_ID_LENGTH=4 9 | ENV WORKER_TOTAL_LENGTH=5 10 | ENV WORKER_SCHEDULER_TYPE=euler 11 | ENV RUNPOD_DEBUG_LEVEL=INFO 12 | 13 | SHELL ["/bin/bash", "-o", "pipefail", "-c"] 14 | 15 | ENV WORKER_DIR=/app 16 | RUN mkdir ${WORKER_DIR} 17 | WORKDIR ${WORKER_DIR} 18 | 19 | SHELL ["/bin/bash", "-c"] 20 | ENV DEBIAN_FRONTEND=noninteractive 21 | ENV SHELL=/bin/bash 22 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/x86_64-linux-gnu 23 | 24 | # Install some basic utilities 25 | RUN apt-get update --fix-missing && \ 26 | apt-get install -y wget bzip2 ca-certificates curl git sudo gcc build-essential openssh-client cmake g++ ninja-build && \ 27 | apt-get install -y libaio-dev && \ 28 | DEBIAN_FRONTEND=noninteractive apt-get install -y python3-dev python3-pip && \ 29 | apt-get clean && \ 30 | rm -rf /var/lib/apt/lists/* 31 | 32 | # Create a non-root user and switch to it 33 | RUN adduser --disabled-password --gecos '' --shell /bin/bash user \ 34 | && chown -R user:user ${WORKER_DIR} 35 | RUN echo "user ALL=(ALL) NOPASSWD:ALL" > /etc/sudoers.d/90-user 36 | USER user 37 | 38 | # All users can use /home/user as their home directory 39 | ENV HOME=/home/user 40 | ENV SHELL=/bin/bash 41 | 42 | # Install Python dependencies (Worker Template) 43 | COPY builder/requirements.txt ${WORKER_DIR}/requirements.txt 44 | RUN pip install --no-cache-dir -r ${WORKER_DIR}/requirements.txt && \ 45 | rm ${WORKER_DIR}/requirements.txt 46 | 47 | # Fetch the model 48 | COPY builder/build_model.py ${WORKER_DIR}/build_model.py 49 | RUN python3 -u ${WORKER_DIR}/build_model.py --model-name="${WORKER_MODEL_NAME}" --model-dir="${WORKER_MODEL_DIR}" --use-cuda && \ 50 | rm ${WORKER_DIR}/build_model.py && \ 51 | rm -rf /home/user/.cache/huggingface/ 52 | 53 | # Add src files (Worker Template) 54 | ADD src ${WORKER_DIR} 55 | 56 | CMD python3 -u ${WORKER_DIR}/rp_handler.py --model-dir="${WORKER_MODEL_DIR}" 57 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 | 
189 | Copyright [yyyy] [name of copyright owner]
190 | 
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 | 
195 | http://www.apache.org/licenses/LICENSE-2.0
196 | 
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # StoryDiffusion: Serverless RunPod Worker
2 | 
3 | ## RunPod Endpoint
4 | 
5 | This repository contains the worker for the StoryDiffusion AI Endpoints.
6 | 
7 | ## Docker Image
8 | 
9 | ```bash
10 | docker build .
11 | ```
12 | or
13 | 
14 | ```bash
15 | docker pull devbes/story-diffusion-serverless-worker:latest
16 | ```
17 | 
18 | ## Environment Variables
19 | 
20 | ### S3 storage
21 | 
22 | - BUCKET_ENDPOINT_URL
23 | - BUCKET_ACCESS_KEY_ID
24 | - BUCKET_SECRET_ACCESS_KEY
25 | 
26 | ### Dockerfile configuration
27 | 
28 | - WORKER_MODEL_NAME (default = SG161222/RealVisXL_V4.0)
29 | - WORKER_ID_LENGTH (default = 4)
30 | - WORKER_TOTAL_LENGTH (default = 5)
31 | - WORKER_SCHEDULER_TYPE (default = euler)
32 | 
33 | 
34 | ## Continuous Deployment
35 | This worker follows a modified version of the [worker template](https://github.com/runpod-workers/worker-template) whose Docker build workflow also builds and pushes images for additional SD models.
36 | 
37 | ## API
38 | 
39 | Use 'img' as the trigger word for personalized generation; the reference identity is supplied via `image_ref`.
40 | 
41 | ```json
42 | {
43 |   "input": {
44 |     "prompts": ["<prompt 1>", "<prompt 2>", "..."],
45 |     "negative_prompt": "<negative prompt>",
46 |     "width": <int>,
47 |     "height": <int>,
48 |     "sa32": <float, e.g. 0.5>,
49 |     "sa64": <float, e.g. 0.5>,
50 |     "guidance_scale": <float>,
51 |     "num_inference_steps": <int>,
52 |     "seed": <int>,
53 |     "image_ref": "<reference image URL>"
54 |   }
55 | }
56 | ```
57 | 
58 | Sample request:
59 | ```json
60 | {
61 |   "input": {
62 |     "prompts": ["Harold img is a curious and clever boy with bright blue eyes and messy brown hair. He always wears a red hat and carries a tiny backpack full of gadgets. discovering a golden key in his grandmother's attic.", "Harold img is a curious and clever boy with bright blue eyes and messy brown hair. He always wears a red hat and carries a tiny backpack full of gadgets. talking to a squirrel in a magical forest.", "Harold img is a curious and clever boy with bright blue eyes and messy brown hair. He always wears a red hat and carries a tiny backpack full of gadgets. jumping on giant marshmallows at the top of a mountain.", "Harold img is a curious and clever boy with bright blue eyes and messy brown hair. He always wears a red hat and carries a tiny backpack full of gadgets.
hopping into a boat on a sparkling river.", "Harold img is a curious and clever boy with bright blue eyes and messy brown hair. He always wears a red hat and carries a tiny backpack full of gadgets. talking to a mole in an underground cave.", "Harold img is a curious and clever boy with bright blue eyes and messy brown hair. He always wears a red hat and carries a tiny backpack full of gadgets. dancing at a village festival.", "Harold img is a curious and clever boy with bright blue eyes and messy brown hair. He always wears a red hat and carries a tiny backpack full of gadgets. talking to the squirrel again in the magical forest.", "Harold img is a curious and clever boy with bright blue eyes and messy brown hair. He always wears a red hat and carries a tiny backpack full of gadgets. telling his grandmother about his adventure at her home."], 63 | "negative_prompt": "naked, deformed, bad anatomy, disfigured, poorly drawn face, mutation, extra limb, ugly, disgusting, poorly drawn hands, missing limb, floating limbs, disconnected limbs, blurry, watermarks, oversaturated, distorted hands, amputation", 64 | "width": 768, 65 | "height": 768, 66 | "sa32": 0.5, 67 | "sa64": 0.5, 68 | "guidance_scale": 5.0, 69 | "num_inference_steps": 25, 70 | "seed": 42, 71 | "image_ref": "https://alpinabook.ru/upload/resize_cache/iblock/8d9/550_800_1/8d9cd63476f15e85f0d8796555ab1e6b.jpg" 72 | } 73 | } 74 | ``` 75 | 76 | ## Related Resources 77 | 78 | This project is based on original implementation of [StoryDiffusion](https://github.com/HVision-NKU/StoryDiffusion). 79 | -------------------------------------------------------------------------------- /builder/build_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | from huggingface_hub import hf_hub_download 5 | from diffusers import StableDiffusionXLPipeline 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--model-name", type=str, required=True) 11 | parser.add_argument("--model-dir", type=str, required=True) 12 | parser.add_argument("--use-cuda", action="store_true") 13 | args = parser.parse_args() 14 | pipe = StableDiffusionXLPipeline.from_pretrained( 15 | args.model_name, 16 | torch_dtype=torch.float16, 17 | ) 18 | pipe.save_pretrained(args.model_dir, revision="fp16") 19 | # load photomaker 20 | photomaker_path = hf_hub_download( 21 | repo_id="TencentARC/PhotoMaker", 22 | filename="photomaker-v1.bin", 23 | repo_type="model", 24 | local_dir=os.path.join(args.model_dir, "photomaker") 25 | ) 26 | -------------------------------------------------------------------------------- /builder/requirements.txt: -------------------------------------------------------------------------------- 1 | runpod==1.4.2 2 | diffusers==0.27.2 3 | transformers==4.40.1 4 | accelerate==0.30.0 5 | tqdm 6 | peft 7 | torchvision 8 | -------------------------------------------------------------------------------- /src/comic_generator_xl.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import pickle 4 | import random 5 | import requests 6 | import sys 7 | # numpy 8 | import numpy as np 9 | # torch 10 | import torch 11 | import torch.nn.functional as F 12 | # utils 13 | from utils import is_torch2_available, cal_attn_mask_xl, setup_seed 14 | if is_torch2_available(): 15 | from utils import \ 16 | AttnProcessor2_0 as AttnProcessor 17 | else: 18 | from utils import 
AttnProcessor 19 | # diffusers 20 | import diffusers 21 | # from diffusers import StableDiffusionXLPipeline 22 | from pipeline import StoryDiffusionXLPipeline 23 | from diffusers import DDIMScheduler, EulerDiscreteScheduler 24 | # utils 25 | from PIL import Image 26 | 27 | 28 | ################################################# 29 | ########Consistent Self-Attention################ 30 | ################################################# 31 | class SpatialAttnProcessor2_0(torch.nn.Module): 32 | r""" 33 | Attention processor for IP-Adapater for PyTorch 2.0. 34 | Args: 35 | hidden_size (`int`): 36 | The hidden size of the attention layer. 37 | cross_attention_dim (`int`): 38 | The number of channels in the `encoder_hidden_states`. 39 | text_context_len (`int`, defaults to 77): 40 | The context length of the text features. 41 | scale (`float`, defaults to 1.0): 42 | the weight scale of image prompt. 43 | """ 44 | 45 | def __init__( 46 | self, 47 | hidden_size = None, 48 | cross_attention_dim=None, 49 | id_length = 4, 50 | device = "cuda", 51 | dtype = torch.float16 52 | ): 53 | super().__init__() 54 | if not hasattr(F, "scaled_dot_product_attention"): 55 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 56 | self.device = device 57 | self.dtype = dtype 58 | self.hidden_size = hidden_size 59 | self.cross_attention_dim = cross_attention_dim 60 | self.total_length = id_length + 1 61 | self.id_length = id_length 62 | self.id_bank = {} 63 | 64 | def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None): 65 | global _total_count, _attn_count, _cur_step, _mask1024, _mask4096 66 | global _sa32, _sa64 67 | global _write 68 | global _height, _width 69 | if _write: 70 | self.id_bank[_cur_step] = [hidden_states[:self.id_length], hidden_states[self.id_length:]] 71 | else: 72 | encoder_hidden_states = torch.cat(( 73 | self.id_bank[_cur_step][0].to(self.device), 74 | hidden_states[:1], 75 | self.id_bank[_cur_step][1].to(self.device), hidden_states[1:] 76 | )) 77 | # skip in early step 78 | if _cur_step < 5: 79 | hidden_states = self.__call2__(attn, hidden_states, encoder_hidden_states, attention_mask, temb) 80 | else: # 256 1024 4096 81 | random_number = random.random() 82 | if _cur_step < 20: 83 | rand_num = 0.3 84 | else: 85 | rand_num = 0.1 86 | if random_number > rand_num: 87 | if not _write: 88 | if hidden_states.shape[1] == (_height//32) * (_width//32): 89 | attention_mask = _mask1024[_mask1024.shape[0] // self.total_length * self.id_length:] 90 | else: 91 | attention_mask = _mask4096[_mask4096.shape[0] // self.total_length * self.id_length:] 92 | else: 93 | if hidden_states.shape[1] == (_height//32) * (_width//32): 94 | attention_mask = _mask1024[:_mask1024.shape[0] // self.total_length * self.id_length,:_mask1024.shape[0] // self.total_length * self.id_length] 95 | else: 96 | attention_mask = _mask4096[:_mask4096.shape[0] // self.total_length * self.id_length,:_mask4096.shape[0] // self.total_length * self.id_length] 97 | hidden_states = self.__call1__(attn, hidden_states, encoder_hidden_states, attention_mask,temb) 98 | else: 99 | hidden_states = self.__call2__(attn, hidden_states, None, attention_mask, temb) 100 | _attn_count +=1 101 | if _attn_count == _total_count: 102 | _attn_count = 0 103 | _cur_step += 1 104 | _mask1024, _mask4096 = cal_attn_mask_xl( 105 | self.total_length, 106 | self.id_length, 107 | _sa32, 108 | _sa64, 109 | _height, 110 | _width, 111 | device = self.device, 112 | dtype = 
self.dtype 113 | ) 114 | return hidden_states 115 | 116 | def __call1__( 117 | self, 118 | attn, 119 | hidden_states, 120 | encoder_hidden_states=None, 121 | attention_mask=None, 122 | temb=None, 123 | ): 124 | residual = hidden_states 125 | if attn.spatial_norm is not None: 126 | hidden_states = attn.spatial_norm(hidden_states, temb) 127 | input_ndim = hidden_states.ndim 128 | 129 | if input_ndim == 4: 130 | total_batch_size, channel, height, width = hidden_states.shape 131 | hidden_states = hidden_states.view(total_batch_size, channel, height * width).transpose(1, 2) 132 | total_batch_size,nums_token,channel = hidden_states.shape 133 | img_nums = total_batch_size//2 134 | hidden_states = hidden_states.view(-1,img_nums,nums_token,channel).reshape(-1,img_nums * nums_token,channel) 135 | 136 | batch_size, sequence_length, _ = hidden_states.shape 137 | 138 | if attn.group_norm is not None: 139 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 140 | 141 | query = attn.to_q(hidden_states) 142 | 143 | if encoder_hidden_states is None: 144 | encoder_hidden_states = hidden_states # B, N, C 145 | else: 146 | encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,nums_token,channel).reshape(-1,(self.id_length+1) * nums_token,channel) 147 | 148 | key = attn.to_k(encoder_hidden_states) 149 | value = attn.to_v(encoder_hidden_states) 150 | 151 | inner_dim = key.shape[-1] 152 | head_dim = inner_dim // attn.heads 153 | 154 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 155 | 156 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 157 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 158 | hidden_states = F.scaled_dot_product_attention( 159 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 160 | ) 161 | 162 | hidden_states = hidden_states.transpose(1, 2).reshape(total_batch_size, -1, attn.heads * head_dim) 163 | hidden_states = hidden_states.to(query.dtype) 164 | 165 | # linear proj 166 | hidden_states = attn.to_out[0](hidden_states) 167 | # dropout 168 | hidden_states = attn.to_out[1](hidden_states) 169 | 170 | if input_ndim == 4: 171 | hidden_states = hidden_states.transpose(-1, -2).reshape(total_batch_size, channel, height, width) 172 | if attn.residual_connection: 173 | hidden_states = hidden_states + residual 174 | hidden_states = hidden_states / attn.rescale_output_factor 175 | # print(hidden_states.shape) 176 | return hidden_states 177 | 178 | def __call2__( 179 | self, 180 | attn, 181 | hidden_states, 182 | encoder_hidden_states=None, 183 | attention_mask=None, 184 | temb=None): 185 | residual = hidden_states 186 | 187 | if attn.spatial_norm is not None: 188 | hidden_states = attn.spatial_norm(hidden_states, temb) 189 | 190 | input_ndim = hidden_states.ndim 191 | 192 | if input_ndim == 4: 193 | batch_size, channel, height, width = hidden_states.shape 194 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 195 | 196 | batch_size, sequence_length, channel = ( 197 | hidden_states.shape 198 | ) 199 | # print(hidden_states.shape) 200 | if attention_mask is not None: 201 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 202 | # scaled_dot_product_attention expects attention_mask shape to be 203 | # (batch, heads, source_length, target_length) 204 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 205 | 206 | if attn.group_norm is not None: 
207 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 208 | 209 | query = attn.to_q(hidden_states) 210 | 211 | if encoder_hidden_states is None: 212 | encoder_hidden_states = hidden_states # B, N, C 213 | else: 214 | encoder_hidden_states = encoder_hidden_states.view(-1,self.id_length+1,sequence_length,channel).reshape(-1,(self.id_length+1) * sequence_length,channel) 215 | 216 | key = attn.to_k(encoder_hidden_states) 217 | value = attn.to_v(encoder_hidden_states) 218 | 219 | inner_dim = key.shape[-1] 220 | head_dim = inner_dim // attn.heads 221 | 222 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 223 | 224 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 225 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 226 | 227 | hidden_states = F.scaled_dot_product_attention( 228 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 229 | ) 230 | 231 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 232 | hidden_states = hidden_states.to(query.dtype) 233 | 234 | # linear proj 235 | hidden_states = attn.to_out[0](hidden_states) 236 | # dropout 237 | hidden_states = attn.to_out[1](hidden_states) 238 | 239 | if input_ndim == 4: 240 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 241 | 242 | if attn.residual_connection: 243 | hidden_states = hidden_states + residual 244 | 245 | hidden_states = hidden_states / attn.rescale_output_factor 246 | 247 | return hidden_states 248 | 249 | def set_attention_processor(unet,id_length): 250 | attn_procs = {} 251 | for name in unet.attn_processors.keys(): 252 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 253 | if name.startswith("mid_block"): 254 | hidden_size = unet.config.block_out_channels[-1] 255 | elif name.startswith("up_blocks"): 256 | block_id = int(name[len("up_blocks.")]) 257 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 258 | elif name.startswith("down_blocks"): 259 | block_id = int(name[len("down_blocks.")]) 260 | hidden_size = unet.config.block_out_channels[block_id] 261 | if cross_attention_dim is None: 262 | if name.startswith("up_blocks") : 263 | attn_procs[name] = SpatialAttnProcessor2_0(id_length = id_length) 264 | else: 265 | attn_procs[name] = AttnProcessor() 266 | else: 267 | attn_procs[name] = AttnProcessor() 268 | unet.set_attn_processor(attn_procs) 269 | 270 | 271 | class ComicGeneratorXL: 272 | def __init__( 273 | self, 274 | model_name: str, 275 | id_length: int = 4, 276 | total_length: int = 5, 277 | device: str = "cuda", 278 | torch_dtype: torch.dtype = torch.float16, 279 | scheduler_type: str = "euler", 280 | trigger_word: str = "img", 281 | ): 282 | global _total_count 283 | _total_count = 0 284 | # params 285 | self.model_name = model_name 286 | self.id_length = id_length 287 | self.total_length = total_length 288 | self.device = device 289 | self.torch_dtype = torch_dtype 290 | self.trigger_word = trigger_word 291 | # load pipeline 292 | # self.pipe = StableDiffusionXLPipeline.from_pretrained( 293 | # TODO: add photomaker loader 294 | self.pipe = StoryDiffusionXLPipeline.from_pretrained( 295 | model_name, 296 | torch_dtype=torch_dtype 297 | ).to(device) 298 | # load photomaker for personalization 299 | photomaker_path = os.path.join(model_name, "photomaker", "photomaker-v1.bin") 300 | self.pipe.load_photomaker_adapter( 301 | 
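            # load_photomaker_adapter (defined in src/pipeline.py) loads the CLIP-based ID encoder
            # and LoRA weights from photomaker-v1.bin and registers the trigger word ("img") that
            # marks where the reference identity is injected into the prompt.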
photomaker_path, 302 | subfolder = "", 303 | weight_name = os.path.basename(photomaker_path), 304 | trigger_word = self.trigger_word 305 | ) 306 | self.pipe.enable_freeu(s1=0.6, s2=0.4, b1=1.1, b2=1.2) 307 | if scheduler_type == "euler": 308 | self.pipe.scheduler = EulerDiscreteScheduler.from_config(self.pipe.scheduler.config) 309 | else: 310 | self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config) 311 | self.pipe.scheduler.set_timesteps(50) 312 | ### Insert PairedAttention 313 | unet = self.pipe.unet 314 | attn_procs = {} 315 | for name in unet.attn_processors.keys(): 316 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim 317 | if name.startswith("mid_block"): 318 | hidden_size = unet.config.block_out_channels[-1] 319 | elif name.startswith("up_blocks"): 320 | block_id = int(name[len("up_blocks.")]) 321 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] 322 | elif name.startswith("down_blocks"): 323 | block_id = int(name[len("down_blocks.")]) 324 | hidden_size = unet.config.block_out_channels[block_id] 325 | if cross_attention_dim is None and (name.startswith("up_blocks") ) : 326 | attn_procs[name] = SpatialAttnProcessor2_0(id_length = id_length) 327 | _total_count +=1 328 | else: 329 | attn_procs[name] = AttnProcessor() 330 | print("successsfully load consistent self-attention") 331 | print(f"number of the processor : {_total_count}") 332 | unet.set_attn_processor(copy.deepcopy(attn_procs)) 333 | 334 | def __call__( 335 | self, 336 | prompts: list, 337 | negative_prompt: str, 338 | width: int = 768, 339 | height: int = 768, 340 | # strength of consistent self-attention: the larger, the stronger 341 | sa32: float = 0.5, 342 | sa64: float = 0.5, 343 | # sdxl params 344 | guidance_scale: float = 5.0, 345 | num_inference_steps: int = 50, 346 | seed: int = 2047, 347 | image_ref: Image.Image = None 348 | ): 349 | global _sa32, _sa64, _height, _width, _write, _mask1024, _mask4096, _cur_step, _attn_count 350 | # strength of consistent self-attention: the larger, the stronger 351 | _sa32 = sa32 352 | _sa64 = sa64 353 | # size 354 | _height = height 355 | _width = width 356 | ### 357 | _write = False 358 | _mask1024, _mask4096 = cal_attn_mask_xl( 359 | self.total_length, 360 | self.id_length, 361 | _sa32, 362 | _sa64, 363 | _height, 364 | _width, 365 | device = self.device, 366 | dtype = self.torch_dtype 367 | ) 368 | # setup seed 369 | setup_seed(seed) 370 | generator = torch.Generator(device=self.device).manual_seed(seed) 371 | # prepare consistent memory 372 | id_prompts = prompts[:self.id_length] 373 | real_prompts = prompts[self.id_length:] 374 | torch.cuda.empty_cache() 375 | _write = True 376 | _cur_step = 0 377 | _attn_count = 0 378 | input_id_images = [image_ref] if image_ref is not None else None 379 | id_images = self.pipe( 380 | id_prompts, 381 | num_inference_steps = num_inference_steps, 382 | guidance_scale = guidance_scale, 383 | height = height, 384 | width = width, 385 | negative_prompt = negative_prompt, 386 | generator = generator, 387 | input_id_images = input_id_images 388 | ).images 389 | _write = False 390 | real_images = [] 391 | for real_prompt in real_prompts: 392 | _cur_step = 0 393 | real_images.append( 394 | self.pipe( 395 | real_prompt, 396 | negative_prompt = negative_prompt, 397 | num_inference_steps = num_inference_steps, 398 | guidance_scale = guidance_scale, 399 | height = height, 400 | width = width, 401 | generator = generator, 402 | input_id_images = input_id_images 
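                    # _write is False during this pass, so SpatialAttnProcessor2_0 reads the id_bank
                    # features cached while generating the first id_length images, keeping every
                    # additional panel consistent with that identity set.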
403 | ).images[0] 404 | ) 405 | return id_images + real_images 406 | -------------------------------------------------------------------------------- /src/photomaker_id_encoder.py: -------------------------------------------------------------------------------- 1 | # Merge image encoder and fuse module to create an ID Encoder 2 | # send multiple ID images, we can directly obtain the updated text encoder containing a stacked ID embedding 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers.models.clip.modeling_clip import CLIPVisionModelWithProjection 7 | from transformers.models.clip.configuration_clip import CLIPVisionConfig 8 | from transformers import PretrainedConfig 9 | 10 | VISION_CONFIG_DICT = { 11 | "hidden_size": 1024, 12 | "intermediate_size": 4096, 13 | "num_attention_heads": 16, 14 | "num_hidden_layers": 24, 15 | "patch_size": 14, 16 | "projection_dim": 768 17 | } 18 | 19 | class MLP(nn.Module): 20 | def __init__(self, in_dim, out_dim, hidden_dim, use_residual=True): 21 | super().__init__() 22 | if use_residual: 23 | assert in_dim == out_dim 24 | self.layernorm = nn.LayerNorm(in_dim) 25 | self.fc1 = nn.Linear(in_dim, hidden_dim) 26 | self.fc2 = nn.Linear(hidden_dim, out_dim) 27 | self.use_residual = use_residual 28 | self.act_fn = nn.GELU() 29 | 30 | def forward(self, x): 31 | residual = x 32 | x = self.layernorm(x) 33 | x = self.fc1(x) 34 | x = self.act_fn(x) 35 | x = self.fc2(x) 36 | if self.use_residual: 37 | x = x + residual 38 | return x 39 | 40 | 41 | class FuseModule(nn.Module): 42 | def __init__(self, embed_dim): 43 | super().__init__() 44 | self.mlp1 = MLP(embed_dim * 2, embed_dim, embed_dim, use_residual=False) 45 | self.mlp2 = MLP(embed_dim, embed_dim, embed_dim, use_residual=True) 46 | self.layer_norm = nn.LayerNorm(embed_dim) 47 | 48 | def fuse_fn(self, prompt_embeds, id_embeds): 49 | stacked_id_embeds = torch.cat([prompt_embeds, id_embeds], dim=-1) 50 | stacked_id_embeds = self.mlp1(stacked_id_embeds) + prompt_embeds 51 | stacked_id_embeds = self.mlp2(stacked_id_embeds) 52 | stacked_id_embeds = self.layer_norm(stacked_id_embeds) 53 | return stacked_id_embeds 54 | 55 | def forward( 56 | self, 57 | prompt_embeds, 58 | id_embeds, 59 | class_tokens_mask, 60 | ) -> torch.Tensor: 61 | # id_embeds shape: [b, max_num_inputs, 1, 2048] 62 | id_embeds = id_embeds.to(prompt_embeds.dtype) 63 | num_inputs = class_tokens_mask.sum().unsqueeze(0) # TODO: check for training case 64 | batch_size, max_num_inputs = id_embeds.shape[:2] 65 | # seq_length: 77 66 | seq_length = prompt_embeds.shape[1] 67 | # flat_id_embeds shape: [b*max_num_inputs, 1, 2048] 68 | flat_id_embeds = id_embeds.view( 69 | -1, id_embeds.shape[-2], id_embeds.shape[-1] 70 | ) 71 | # valid_id_mask [b*max_num_inputs] 72 | valid_id_mask = ( 73 | torch.arange(max_num_inputs, device=flat_id_embeds.device)[None, :] 74 | < num_inputs[:, None] 75 | ) 76 | valid_id_embeds = flat_id_embeds[valid_id_mask.flatten()] 77 | 78 | prompt_embeds = prompt_embeds.view(-1, prompt_embeds.shape[-1]) 79 | class_tokens_mask = class_tokens_mask.view(-1) 80 | valid_id_embeds = valid_id_embeds.view(-1, valid_id_embeds.shape[-1]) 81 | # slice out the image token embeddings 82 | image_token_embeds = prompt_embeds[class_tokens_mask] 83 | stacked_id_embeds = self.fuse_fn(image_token_embeds, valid_id_embeds) 84 | assert class_tokens_mask.sum() == stacked_id_embeds.shape[0], f"{class_tokens_mask.sum()} != {stacked_id_embeds.shape[0]}" 85 | prompt_embeds.masked_scatter_(class_tokens_mask[:, None], 
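        # in-place scatter: overwrite only the class-token positions of prompt_embeds with the
        # fused ID embeddings; every other token embedding is left untouched.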
stacked_id_embeds.to(prompt_embeds.dtype)) 86 | updated_prompt_embeds = prompt_embeds.view(batch_size, seq_length, -1) 87 | return updated_prompt_embeds 88 | 89 | class PhotoMakerIDEncoder(CLIPVisionModelWithProjection): 90 | def __init__(self): 91 | super().__init__(CLIPVisionConfig(**VISION_CONFIG_DICT)) 92 | self.visual_projection_2 = nn.Linear(1024, 1280, bias=False) 93 | self.fuse_module = FuseModule(2048) 94 | 95 | def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask): 96 | b, num_inputs, c, h, w = id_pixel_values.shape 97 | id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) 98 | 99 | shared_id_embeds = self.vision_model(id_pixel_values)[1] 100 | id_embeds = self.visual_projection(shared_id_embeds) 101 | id_embeds_2 = self.visual_projection_2(shared_id_embeds) 102 | 103 | id_embeds = id_embeds.view(b, num_inputs, 1, -1) 104 | id_embeds_2 = id_embeds_2.view(b, num_inputs, 1, -1) 105 | 106 | id_embeds = torch.cat((id_embeds, id_embeds_2), dim=-1) 107 | updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) 108 | 109 | return updated_prompt_embeds 110 | 111 | 112 | if __name__ == "__main__": 113 | PhotoMakerIDEncoder() 114 | -------------------------------------------------------------------------------- /src/pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Optional, Union, Tuple 2 | from collections import OrderedDict 3 | import os 4 | import PIL 5 | import numpy as np 6 | 7 | import torch 8 | from torchvision import transforms as T 9 | 10 | from safetensors import safe_open 11 | from huggingface_hub.utils import validate_hf_hub_args 12 | from transformers import CLIPImageProcessor, CLIPTokenizer 13 | from diffusers import StableDiffusionXLPipeline 14 | from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput 15 | from diffusers.utils import ( 16 | _get_model_file, 17 | is_transformers_available, 18 | logging, 19 | ) 20 | 21 | from photomaker_id_encoder import PhotoMakerIDEncoder 22 | from utils import remove_word 23 | 24 | PipelineImageInput = Union[ 25 | PIL.Image.Image, 26 | torch.FloatTensor, 27 | List[PIL.Image.Image], 28 | List[torch.FloatTensor], 29 | ] 30 | 31 | 32 | class StoryDiffusionXLPipeline(StableDiffusionXLPipeline): 33 | @validate_hf_hub_args 34 | def load_photomaker_adapter( 35 | self, 36 | pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], 37 | weight_name: str, 38 | subfolder: str = '', 39 | trigger_word: str = 'img', 40 | **kwargs, 41 | ): 42 | """ 43 | Parameters: 44 | pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): 45 | Can be either: 46 | 47 | - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on 48 | the Hub. 49 | - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved 50 | with [`ModelMixin.save_pretrained`]. 51 | - A [torch state 52 | dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). 53 | 54 | weight_name (`str`): 55 | The weight name NOT the path to the weight. 56 | 57 | subfolder (`str`, defaults to `""`): 58 | The subfolder location of a model file within a larger model repository on the Hub or locally. 
59 | 60 | trigger_word (`str`, *optional*, defaults to `"img"`): 61 | The trigger word is used to identify the position of class word in the text prompt, 62 | and it is recommended not to set it as a common word. 63 | This trigger word must be placed after the class word when used, otherwise, it will affect the performance of the personalized generation. 64 | """ 65 | 66 | # Load the main state dict first. 67 | cache_dir = kwargs.pop("cache_dir", None) 68 | force_download = kwargs.pop("force_download", False) 69 | resume_download = kwargs.pop("resume_download", False) 70 | proxies = kwargs.pop("proxies", None) 71 | local_files_only = kwargs.pop("local_files_only", None) 72 | token = kwargs.pop("token", None) 73 | revision = kwargs.pop("revision", None) 74 | 75 | user_agent = { 76 | "file_type": "attn_procs_weights", 77 | "framework": "pytorch", 78 | } 79 | 80 | if not isinstance(pretrained_model_name_or_path_or_dict, dict): 81 | model_file = _get_model_file( 82 | pretrained_model_name_or_path_or_dict, 83 | weights_name=weight_name, 84 | cache_dir=cache_dir, 85 | force_download=force_download, 86 | resume_download=resume_download, 87 | proxies=proxies, 88 | local_files_only=local_files_only, 89 | token=token, 90 | revision=revision, 91 | subfolder=subfolder, 92 | user_agent=user_agent, 93 | ) 94 | if weight_name.endswith(".safetensors"): 95 | state_dict = {"id_encoder": {}, "lora_weights": {}} 96 | with safe_open(model_file, framework="pt", device="cpu") as f: 97 | for key in f.keys(): 98 | if key.startswith("id_encoder."): 99 | state_dict["id_encoder"][key.replace("id_encoder.", "")] = f.get_tensor(key) 100 | elif key.startswith("lora_weights."): 101 | state_dict["lora_weights"][key.replace("lora_weights.", "")] = f.get_tensor(key) 102 | else: 103 | state_dict = torch.load(model_file, map_location="cpu") 104 | else: 105 | state_dict = pretrained_model_name_or_path_or_dict 106 | 107 | keys = list(state_dict.keys()) 108 | if keys != ["id_encoder", "lora_weights"]: 109 | raise ValueError("Required keys are (`id_encoder` and `lora_weights`) missing from the state dict.") 110 | 111 | self.trigger_word = trigger_word 112 | # load finetuned CLIP image encoder and fuse module here if it has not been registered to the pipeline yet 113 | print(f"Loading PhotoMaker components [1] id_encoder from [{pretrained_model_name_or_path_or_dict}]...") 114 | id_encoder = PhotoMakerIDEncoder() 115 | id_encoder.load_state_dict(state_dict["id_encoder"], strict=True) 116 | id_encoder = id_encoder.to(self.device, dtype=self.unet.dtype) 117 | self.id_encoder = id_encoder 118 | self.id_image_processor = CLIPImageProcessor() 119 | 120 | # load lora into models 121 | print(f"Loading PhotoMaker components [2] lora_weights from [{pretrained_model_name_or_path_or_dict}]") 122 | self.load_lora_weights(state_dict["lora_weights"], adapter_name="photomaker") 123 | 124 | # Add trigger word token 125 | if self.tokenizer is not None: 126 | self.tokenizer.add_tokens([self.trigger_word], special_tokens=True) 127 | 128 | self.tokenizer_2.add_tokens([self.trigger_word], special_tokens=True) 129 | 130 | def encode_prompt_with_trigger_word( 131 | self, 132 | prompt: str, 133 | prompt_2: Optional[str] = None, 134 | num_id_images: int = 1, 135 | device: Optional[torch.device] = None, 136 | prompt_embeds: Optional[torch.FloatTensor] = None, 137 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 138 | class_tokens_mask: Optional[torch.LongTensor] = None, 139 | ): 140 | device = device or self._execution_device 141 | 142 | if 
prompt is not None and isinstance(prompt, str): 143 | batch_size = 1 144 | elif prompt is not None and isinstance(prompt, list): 145 | batch_size = len(prompt) 146 | else: 147 | batch_size = prompt_embeds.shape[0] 148 | 149 | # Find the token id of the trigger word 150 | image_token_id = self.tokenizer_2.convert_tokens_to_ids(self.trigger_word) 151 | 152 | # Define tokenizers and text encoders 153 | tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] 154 | text_encoders = ( 155 | [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] 156 | ) 157 | 158 | if prompt_embeds is None: 159 | prompt_2 = prompt_2 or prompt 160 | prompt_embeds_list = [] 161 | prompts = [prompt, prompt_2] 162 | for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): 163 | input_ids = tokenizer.encode(prompt) # TODO: batch encode 164 | clean_index = 0 165 | clean_input_ids = [] 166 | class_token_index = [] 167 | # Find out the corresponding class word token based on the newly added trigger word token 168 | for i, token_id in enumerate(input_ids): 169 | if token_id == image_token_id: 170 | class_token_index.append(clean_index - 1) 171 | else: 172 | clean_input_ids.append(token_id) 173 | clean_index += 1 174 | 175 | if len(class_token_index) != 1: 176 | raise ValueError( 177 | f"PhotoMaker currently does not support multiple trigger words in a single prompt.\ 178 | Trigger word: {self.trigger_word}, Prompt: {prompt}." 179 | ) 180 | class_token_index = class_token_index[0] 181 | 182 | # Expand the class word token and corresponding mask 183 | class_token = clean_input_ids[class_token_index] 184 | clean_input_ids = clean_input_ids[:class_token_index] + [class_token] * num_id_images + \ 185 | clean_input_ids[class_token_index+1:] 186 | 187 | # Truncation or padding 188 | max_len = tokenizer.model_max_length 189 | if len(clean_input_ids) > max_len: 190 | clean_input_ids = clean_input_ids[:max_len] 191 | else: 192 | clean_input_ids = clean_input_ids + [tokenizer.pad_token_id] * ( 193 | max_len - len(clean_input_ids) 194 | ) 195 | 196 | class_tokens_mask = [True if class_token_index <= i < class_token_index+num_id_images else False \ 197 | for i in range(len(clean_input_ids))] 198 | 199 | clean_input_ids = torch.tensor(clean_input_ids, dtype=torch.long).unsqueeze(0) 200 | class_tokens_mask = torch.tensor(class_tokens_mask, dtype=torch.bool).unsqueeze(0) 201 | 202 | prompt_embeds = text_encoder( 203 | clean_input_ids.to(device), 204 | output_hidden_states=True, 205 | ) 206 | 207 | # We are only ALWAYS interested in the pooled output of the final text encoder 208 | pooled_prompt_embeds = prompt_embeds[0] 209 | prompt_embeds = prompt_embeds.hidden_states[-2] 210 | prompt_embeds_list.append(prompt_embeds) 211 | 212 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 213 | 214 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) 215 | class_tokens_mask = class_tokens_mask.to(device=device) # TODO: ignoring two-prompt case 216 | 217 | return prompt_embeds, pooled_prompt_embeds, class_tokens_mask 218 | 219 | @property 220 | def interrupt(self): 221 | return self._interrupt 222 | 223 | @torch.no_grad() 224 | def __call__( 225 | self, 226 | prompt: Union[str, List[str]] = None, 227 | prompt_2: Optional[Union[str, List[str]]] = None, 228 | height: Optional[int] = None, 229 | width: Optional[int] = None, 230 | num_inference_steps: int = 50, 231 | denoising_end: Optional[float] = 
None, 232 | guidance_scale: float = 5.0, 233 | negative_prompt: Optional[Union[str, List[str]]] = None, 234 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 235 | num_images_per_prompt: Optional[int] = 1, 236 | eta: float = 0.0, 237 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 238 | latents: Optional[torch.FloatTensor] = None, 239 | prompt_embeds: Optional[torch.FloatTensor] = None, 240 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 241 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 242 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 243 | output_type: Optional[str] = "pil", 244 | return_dict: bool = True, 245 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 246 | guidance_rescale: float = 0.0, 247 | original_size: Optional[Tuple[int, int]] = None, 248 | crops_coords_top_left: Tuple[int, int] = (0, 0), 249 | target_size: Optional[Tuple[int, int]] = None, 250 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 251 | callback_steps: int = 1, 252 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 253 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 254 | # Added parameters (for PhotoMaker) 255 | input_id_images: PipelineImageInput = None, 256 | start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future 257 | class_tokens_mask: Optional[torch.LongTensor] = None, 258 | prompt_embeds_text_only: Optional[torch.FloatTensor] = None, 259 | pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None, 260 | ): 261 | if input_id_images is None: 262 | # remove stopword from prompt 263 | prompt = remove_word(prompt, self.trigger_word) 264 | prompt_2 = remove_word(prompt, self.trigger_word) 265 | negative_prompt = remove_word(negative_prompt, self.trigger_word) 266 | negative_prompt_2 = remove_word(negative_prompt_2, self.trigger_word) 267 | # inference with the original pipeline 268 | return super().__call__( 269 | prompt=prompt, 270 | prompt_2=prompt_2, 271 | height=height, 272 | width=width, 273 | num_inference_steps=num_inference_steps, 274 | denoising_end=denoising_end, 275 | guidance_scale=guidance_scale, 276 | negative_prompt=negative_prompt, 277 | negative_prompt_2=negative_prompt_2, 278 | num_images_per_prompt=num_images_per_prompt, 279 | eta=eta, 280 | generator=generator, 281 | latents=latents, 282 | prompt_embeds=prompt_embeds, 283 | negative_prompt_embeds=negative_prompt_embeds, 284 | pooled_prompt_embeds=pooled_prompt_embeds, 285 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 286 | output_type=output_type, 287 | return_dict=return_dict, 288 | # cross_attention_kwargs=cross_attention_kwargs, 289 | cross_attention_kwargs={"scale": 0.0}, # suppress photomaker adapter 290 | guidance_rescale=guidance_rescale, 291 | original_size=original_size, 292 | crops_coords_top_left=crops_coords_top_left, 293 | target_size=target_size, 294 | callback=callback, 295 | callback_steps=callback_steps, 296 | callback_on_step_end=callback_on_step_end, 297 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs 298 | ) 299 | else: 300 | return self.generate_id( 301 | prompt=prompt, 302 | prompt_2=prompt_2, 303 | height=height, 304 | width=width, 305 | num_inference_steps=num_inference_steps, 306 | denoising_end=denoising_end, 307 | guidance_scale=guidance_scale, 308 | negative_prompt=negative_prompt, 309 | negative_prompt_2=negative_prompt_2, 310 | num_images_per_prompt=num_images_per_prompt, 
311 | eta=eta, 312 | generator=generator, 313 | latents=latents, 314 | prompt_embeds=prompt_embeds, 315 | negative_prompt_embeds=negative_prompt_embeds, 316 | pooled_prompt_embeds=pooled_prompt_embeds, 317 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 318 | output_type=output_type, 319 | return_dict=return_dict, 320 | cross_attention_kwargs=cross_attention_kwargs, 321 | guidance_rescale=guidance_rescale, 322 | original_size=original_size, 323 | crops_coords_top_left=crops_coords_top_left, 324 | target_size=target_size, 325 | callback=callback, 326 | callback_steps=callback_steps, 327 | callback_on_step_end=callback_on_step_end, 328 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 329 | input_id_images=input_id_images, 330 | start_merge_step=start_merge_step, # TODO: change to `style_strength_ratio` in the future 331 | class_tokens_mask=class_tokens_mask, 332 | prompt_embeds_text_only=prompt_embeds_text_only, 333 | pooled_prompt_embeds_text_only=pooled_prompt_embeds_text_only 334 | ) 335 | 336 | @torch.no_grad() 337 | def generate_id( 338 | self, 339 | prompt: Union[str, List[str]] = None, 340 | prompt_2: Optional[Union[str, List[str]]] = None, 341 | height: Optional[int] = None, 342 | width: Optional[int] = None, 343 | num_inference_steps: int = 50, 344 | denoising_end: Optional[float] = None, 345 | guidance_scale: float = 5.0, 346 | negative_prompt: Optional[Union[str, List[str]]] = None, 347 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 348 | num_images_per_prompt: Optional[int] = 1, 349 | eta: float = 0.0, 350 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 351 | latents: Optional[torch.FloatTensor] = None, 352 | prompt_embeds: Optional[torch.FloatTensor] = None, 353 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 354 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 355 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 356 | output_type: Optional[str] = "pil", 357 | return_dict: bool = True, 358 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 359 | guidance_rescale: float = 0.0, 360 | original_size: Optional[Tuple[int, int]] = None, 361 | crops_coords_top_left: Tuple[int, int] = (0, 0), 362 | target_size: Optional[Tuple[int, int]] = None, 363 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 364 | callback_steps: int = 1, 365 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 366 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 367 | # Added parameters (for PhotoMaker) 368 | input_id_images: PipelineImageInput = None, 369 | start_merge_step: int = 0, # TODO: change to `style_strength_ratio` in the future 370 | class_tokens_mask: Optional[torch.LongTensor] = None, 371 | prompt_embeds_text_only: Optional[torch.FloatTensor] = None, 372 | pooled_prompt_embeds_text_only: Optional[torch.FloatTensor] = None, 373 | ): 374 | r""" 375 | Function invoked when calling the pipeline for generation. 376 | Only the parameters introduced by PhotoMaker are discussed here. 377 | For explanations of the previous parameters in StableDiffusionXLPipeline, please refer to https://github.com/huggingface/diffusers/blob/v0.25.0/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py 378 | 379 | Args: 380 | input_id_images (`PipelineImageInput`, *optional*): 381 | Input ID Image to work with PhotoMaker. 382 | class_tokens_mask (`torch.LongTensor`, *optional*): 383 | Pre-generated class token. 
When the `prompt_embeds` parameter is provided in advance, it is necessary to prepare the `class_tokens_mask` beforehand for marking out the position of the class word. 384 | prompt_embeds_text_only (`torch.FloatTensor`, *optional*): 385 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 386 | provided, text embeddings will be generated from `prompt` input argument. 387 | pooled_prompt_embeds_text_only (`torch.FloatTensor`, *optional*): 388 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 389 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 390 | 391 | Returns: 392 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: 393 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a 394 | `tuple`. When returning a tuple, the first element is a list with the generated images. 395 | """ 396 | # 0. Default height and width to unet 397 | height = height or self.unet.config.sample_size * self.vae_scale_factor 398 | width = width or self.unet.config.sample_size * self.vae_scale_factor 399 | 400 | original_size = original_size or (height, width) 401 | target_size = target_size or (height, width) 402 | 403 | # 1. Check inputs. Raise error if not correct 404 | self.check_inputs( 405 | prompt, 406 | prompt_2, 407 | height, 408 | width, 409 | callback_steps, 410 | negative_prompt, 411 | negative_prompt_2, 412 | prompt_embeds, 413 | negative_prompt_embeds, 414 | pooled_prompt_embeds, 415 | negative_pooled_prompt_embeds, 416 | callback_on_step_end_tensor_inputs, 417 | ) 418 | 419 | self._interrupt = False 420 | 421 | # `class_tokens_mask` must accompany pre-computed `prompt_embeds` 422 | if prompt_embeds is not None and class_tokens_mask is None: 423 | raise ValueError( 424 | "If `prompt_embeds` are provided, `class_tokens_mask` also has to be passed. Make sure to generate `class_tokens_mask` from the same tokenizer that was used to generate `prompt_embeds`." 425 | ) 426 | # check the input id images 427 | if input_id_images is None: 428 | raise ValueError( 429 | "Provide `input_id_images`. Cannot leave `input_id_images` undefined for the PhotoMaker pipeline." 430 | ) 431 | if not isinstance(input_id_images, list): 432 | input_id_images = [input_id_images] 433 | 434 | # 2. Define call parameters 435 | if prompt is not None and isinstance(prompt, str): 436 | batch_size = 1 437 | prompt = [prompt] 438 | elif prompt is not None and isinstance(prompt, list): 439 | batch_size = len(prompt) 440 | else: 441 | batch_size = prompt_embeds.shape[0] 442 | 443 | device = self._execution_device 444 | 445 | # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) 446 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 447 | # corresponds to doing no classifier free guidance. 448 | do_classifier_free_guidance = guidance_scale >= 1.0 449 | 450 | assert do_classifier_free_guidance 451 | 452 | # 3.
Encode input prompt 453 | num_id_images = len(input_id_images) 454 | if isinstance(prompt, list): 455 | prompt_arr = prompt 456 | negative_prompt_embeds_arr = [] 457 | prompt_embeds_text_only_arr = [] 458 | prompt_embeds_arr = [] 459 | latents_arr = [] 460 | add_time_ids_arr = [] 461 | negative_pooled_prompt_embeds_arr = [] 462 | pooled_prompt_embeds_text_only_arr = [] 463 | pooled_prompt_embeds_arr = [] 464 | for prompt in prompt_arr: 465 | ( 466 | prompt_embeds, 467 | pooled_prompt_embeds, 468 | class_tokens_mask, 469 | ) = self.encode_prompt_with_trigger_word( 470 | prompt=prompt, 471 | prompt_2=prompt_2, 472 | device=device, 473 | num_id_images=num_id_images, 474 | prompt_embeds=prompt_embeds, 475 | pooled_prompt_embeds=pooled_prompt_embeds, 476 | class_tokens_mask=class_tokens_mask, 477 | ) 478 | 479 | # 4. Encode input prompt without the trigger word for delayed conditioning 480 | # encode, remove trigger word token, then decode 481 | tokens_text_only = self.tokenizer.encode(prompt, add_special_tokens=False) 482 | trigger_word_token = self.tokenizer.convert_tokens_to_ids(self.trigger_word) 483 | tokens_text_only.remove(trigger_word_token) 484 | prompt_text_only = self.tokenizer.decode(tokens_text_only, add_special_tokens=False) 485 | ( 486 | prompt_embeds_text_only, 487 | negative_prompt_embeds, 488 | pooled_prompt_embeds_text_only, # TODO: replace the pooled_prompt_embeds with text only prompt 489 | negative_pooled_prompt_embeds, 490 | ) = self.encode_prompt( 491 | prompt=prompt_text_only, 492 | prompt_2=prompt_2, 493 | device=device, 494 | num_images_per_prompt=num_images_per_prompt, 495 | do_classifier_free_guidance=True, 496 | negative_prompt=negative_prompt, 497 | negative_prompt_2=negative_prompt_2, 498 | prompt_embeds=prompt_embeds_text_only, 499 | negative_prompt_embeds=negative_prompt_embeds, 500 | pooled_prompt_embeds=pooled_prompt_embeds_text_only, 501 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 502 | ) 503 | 504 | # 5. Prepare the input ID images 505 | dtype = next(self.id_encoder.parameters()).dtype 506 | if not isinstance(input_id_images[0], torch.Tensor): 507 | id_pixel_values = self.id_image_processor(input_id_images, return_tensors="pt").pixel_values 508 | 509 | id_pixel_values = id_pixel_values.unsqueeze(0).to(device=device, dtype=dtype) # TODO: multiple prompts 510 | 511 | # 6. 
Get the updated text embedding with the stacked ID embedding 512 | prompt_embeds = self.id_encoder(id_pixel_values, prompt_embeds, class_tokens_mask) 513 | 514 | bs_embed, seq_len, _ = prompt_embeds.shape 515 | # duplicate text embeddings for each generation per prompt, using mps friendly method 516 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 517 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 518 | pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( 519 | bs_embed * num_images_per_prompt, -1 520 | ) 521 | 522 | # collect the per-prompt embeddings, then reset the locals so the next prompt in the loop is re-encoded 523 | negative_prompt_embeds_arr.append(negative_prompt_embeds) 524 | negative_prompt_embeds = None 525 | negative_pooled_prompt_embeds_arr.append(negative_pooled_prompt_embeds) 526 | negative_pooled_prompt_embeds = None 527 | prompt_embeds_text_only_arr.append(prompt_embeds_text_only) 528 | prompt_embeds_text_only = None 529 | prompt_embeds_arr.append(prompt_embeds) 530 | prompt_embeds = None 531 | pooled_prompt_embeds_arr.append(pooled_prompt_embeds) 532 | pooled_prompt_embeds = None 533 | pooled_prompt_embeds_text_only_arr.append(pooled_prompt_embeds_text_only) 534 | pooled_prompt_embeds_text_only = None 535 | # 7. Prepare timesteps 536 | self.scheduler.set_timesteps(num_inference_steps, device=device) 537 | timesteps = self.scheduler.timesteps 538 | 539 | negative_prompt_embeds = torch.cat(negative_prompt_embeds_arr, dim=0) 540 | prompt_embeds = torch.cat(prompt_embeds_arr, dim=0) 541 | 542 | prompt_embeds_text_only = torch.cat(prompt_embeds_text_only_arr, dim=0) 543 | pooled_prompt_embeds_text_only = torch.cat(pooled_prompt_embeds_text_only_arr, dim=0) 544 | 545 | negative_pooled_prompt_embeds = torch.cat(negative_pooled_prompt_embeds_arr, dim=0) 546 | pooled_prompt_embeds = torch.cat(pooled_prompt_embeds_arr, dim=0) 547 | # 8. Prepare latent variables 548 | num_channels_latents = self.unet.config.in_channels 549 | latents = self.prepare_latents( 550 | batch_size * num_images_per_prompt, 551 | num_channels_latents, 552 | height, 553 | width, 554 | prompt_embeds.dtype, 555 | device, 556 | generator, 557 | latents, 558 | ) 559 | 560 | # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 561 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 562 | 563 | # 10. Prepare added time ids & embeddings 564 | if self.text_encoder_2 is None: 565 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) 566 | else: 567 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim 568 | 569 | add_time_ids = self._get_add_time_ids( 570 | original_size, 571 | crops_coords_top_left, 572 | target_size, 573 | dtype=prompt_embeds.dtype, 574 | text_encoder_projection_dim=text_encoder_projection_dim, 575 | ) 576 | add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) 577 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) 578 | 579 | # 11.
Denoising loop 580 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order 581 | with self.progress_bar(total=num_inference_steps) as progress_bar: 582 | for i, t in enumerate(timesteps): 583 | if self.interrupt: 584 | continue 585 | 586 | latent_model_input = ( 587 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 588 | ) 589 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 590 | 591 | if i <= start_merge_step: 592 | current_prompt_embeds = torch.cat( 593 | [negative_prompt_embeds, prompt_embeds_text_only], dim=0 594 | ) 595 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds_text_only], dim=0) 596 | else: 597 | current_prompt_embeds = torch.cat( 598 | [negative_prompt_embeds, prompt_embeds], dim=0 599 | ) 600 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0) 601 | # predict the noise residual 602 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 603 | noise_pred = self.unet( 604 | latent_model_input, 605 | t, 606 | encoder_hidden_states=current_prompt_embeds, 607 | cross_attention_kwargs=cross_attention_kwargs, 608 | added_cond_kwargs=added_cond_kwargs, 609 | return_dict=False, 610 | )[0] 611 | # perform guidance 612 | if do_classifier_free_guidance: 613 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 614 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 615 | 616 | if do_classifier_free_guidance and guidance_rescale > 0.0: 617 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 618 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 619 | 620 | # compute the previous noisy sample x_t -> x_t-1 621 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 622 | 623 | if callback_on_step_end is not None: 624 | callback_kwargs = {} 625 | for k in callback_on_step_end_tensor_inputs: 626 | callback_kwargs[k] = locals()[k] 627 | 628 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 629 | 630 | latents = callback_outputs.pop("latents", latents) 631 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 632 | negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) 633 | add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds) 634 | 635 | # call the callback, if provided 636 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 637 | progress_bar.update() 638 | if callback is not None and i % callback_steps == 0: 639 | step_idx = i // getattr(self.scheduler, "order", 1) 640 | callback(step_idx, t, latents) 641 | 642 | # make sure the VAE is in float32 mode, as it overflows in float16 643 | if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: 644 | self.upcast_vae() 645 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 646 | 647 | if not output_type == "latent": 648 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 649 | else: 650 | image = latents 651 | return StableDiffusionXLPipelineOutput(images=image) 652 | 653 | image = self.image_processor.postprocess(image, output_type=output_type) 654 | 655 | # Offload all models 656 | self.maybe_free_model_hooks() 657 | 658 | if not return_dict: 659 | return (image,) 660 | 661 | return
StableDiffusionXLPipelineOutput(images=image) 662 | -------------------------------------------------------------------------------- /src/rp_handler.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import os 4 | import argparse 5 | # runpod utils 6 | import runpod 7 | from runpod.serverless.utils.rp_validator import validate 8 | from runpod.serverless.utils.rp_upload import upload_in_memory_object 9 | from runpod.serverless.utils import rp_download, rp_cleanup 10 | # predictor 11 | import torch 12 | from comic_generator_xl import ComicGeneratorXL 13 | from rp_schema import INPUT_SCHEMA 14 | # utils 15 | from utils import compress_images_to_zip, download_image 16 | 17 | 18 | # Worker params 19 | model_dir = os.getenv("WORKER_MODEL_DIR", "/model") 20 | id_length = int(os.getenv("WORKER_ID_LENGTH", 4)) 21 | total_length = int(os.getenv("WORKER_TOTAL_LENGTH", 5)) 22 | device = "cuda" if os.getenv("WORKER_USE_CUDA", "False").lower() == "true" else "cpu" 23 | scheduler_type = os.getenv("WORKER_SCHEDULER_TYPE", "euler").lower() 24 | 25 | 26 | def bytesio_to_base64(bytes_io: io.BytesIO) -> str: 27 | """ Convert BytesIO object to base64 string """ 28 | # Extract bytes from BytesIO object 29 | byte_data = bytes_io.getvalue() 30 | # Encode these bytes to a base64 string 31 | base64_encoded = base64.b64encode(byte_data) 32 | # Convert bytes to string 33 | base64_string = base64_encoded.decode('utf-8') 34 | return base64_string 35 | 36 | 37 | def upload_result(result: io.BytesIO, key: str) -> str: 38 | """ Uploads result to S3 bucket if it is available, otherwise returns base64 encoded file. """ 39 | # Upload to S3 40 | if os.environ.get('BUCKET_ENDPOINT_URL', False): 41 | return upload_in_memory_object( 42 | key, 43 | result.getvalue(), 44 | bucket_creds = { 45 | "endpointUrl": os.environ.get('BUCKET_ENDPOINT_URL', None), 46 | "accessId": os.environ.get('BUCKET_ACCESS_KEY_ID', None), 47 | "accessSecret": os.environ.get('BUCKET_SECRET_ACCESS_KEY', None) 48 | } 49 | ) 50 | # Base64 encode 51 | return bytesio_to_base64(result) 52 | 53 | 54 | def run(job): 55 | job_input = job['input'] 56 | 57 | # Input validation 58 | validated_input = validate(job_input, INPUT_SCHEMA) 59 | 60 | if 'errors' in validated_input: 61 | return {"error": validated_input['errors']} 62 | validated_input = validated_input['validated_input'] 63 | 64 | # download the reference image if one was provided 65 | if validated_input["image_ref"] != "": 66 | image_ref = download_image(validated_input["image_ref"]) 67 | else: 68 | image_ref = None 69 | 70 | # Inference image generator 71 | images = MODEL( 72 | prompts = validated_input["prompts"], 73 | negative_prompt = validated_input.get("negative_prompt", None), 74 | width = validated_input.get("width", 768), 75 | height = validated_input.get("height", 768), 76 | sa32 = validated_input.get("sa32", 0.5), 77 | sa64 = validated_input.get("sa64", 0.5), 78 | guidance_scale = validated_input.get("guidance_scale", 5.0), 79 | num_inference_steps = validated_input.get("num_inference_steps", 25), 80 | seed = validated_input.get("seed", 42), 81 | image_ref = image_ref 82 | ) 83 | 84 | # Upload output object 85 | zip_data = compress_images_to_zip(images) 86 | output_data = upload_result(zip_data, f"{job['id']}.zip") 87 | job_output = { 88 | "output_data": output_data 89 | } 90 | 91 | # Remove downloaded input objects 92 | rp_cleanup.clean(['input_objects']) 93 | 94 | return job_output 95 | 96 | 97 | if __name__ == "__main__": 98 | MODEL = ComicGeneratorXL( 99 |
model_name=model_dir, 100 | id_length=id_length, 101 | total_length=total_length, 102 | device=device, 103 | torch_dtype=torch.float16 if device == "cuda" else torch.float32, 104 | scheduler_type=scheduler_type 105 | ) 106 | 107 | runpod.serverless.start({"handler": run}) 108 | -------------------------------------------------------------------------------- /src/rp_schema.py: -------------------------------------------------------------------------------- 1 | INPUT_SCHEMA = { 2 | "prompts": { 3 | "type": list, 4 | "required": True 5 | }, 6 | "negative_prompt": { 7 | "type": str, 8 | "required": True 9 | }, 10 | "width": { 11 | "type": int, 12 | "required": True 13 | }, 14 | "height": { 15 | "type": int, 16 | "required": True 17 | }, 18 | "sa32": { 19 | "type": float, 20 | "required": True 21 | }, 22 | "sa64": { 23 | "type": float, 24 | "required": True, 25 | }, 26 | "guidance_scale": { 27 | "type": float, 28 | "required": True 29 | }, 30 | "num_inference_steps": { 31 | "type": int, 32 | "required": True 33 | }, 34 | "seed": { 35 | "type": int, 36 | "required": True 37 | }, 38 | "image_ref": { 39 | "type": str, 40 | "required": True 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | import requests 4 | import io 5 | import zipfile 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from PIL import Image 11 | 12 | 13 | def download_image(url: str) -> Image.Image: 14 | """ Download an image from a URL 15 | 16 | Args: 17 | url (str): Image URL 18 | 19 | Returns: 20 | Image.Image: Image object 21 | """ 22 | response = requests.get(url) 23 | response.raise_for_status() 24 | img = Image.open(io.BytesIO(response.content)) 25 | return img 26 | 27 | 28 | def compress_images_to_zip(images: list[Image.Image]) -> io.BytesIO: 29 | """ Compress images to a zip file 30 | 31 | Args: 32 | images (list[Image.Image]): List of images 33 | 34 | Returns: 35 | io.BytesIO: Zip file 36 | """ 37 | zip_data = io.BytesIO() 38 | with zipfile.ZipFile(zip_data, 'w', zipfile.ZIP_DEFLATED) as zip_file: 39 | for i, image in enumerate(images): 40 | image_data = io.BytesIO() 41 | image.save(image_data, format="PNG") 42 | image_data.seek(0) 43 | zip_file.writestr(f"{i}.png", image_data.getvalue()) 44 | return zip_data 45 | 46 | 47 | def remove_word(input_string: str, word: str) -> str: 48 | """ Remove a word from a string 49 | 50 | Args: 51 | input_string (str): Input string 52 | word (str): Word to remove 53 | 54 | Returns: 55 | str: Cleaned string 56 | """ 57 | try: 58 | # Pattern to find the stopword with optional spaces around it 59 | pattern = r'\s*\b' + re.escape(word) + r'\b\s*' 60 | # Replace the stopword with a single space to handle potential extra spaces 61 | cleaned_string = re.sub(pattern, ' ', input_string) 62 | # Strip leading/trailing spaces and normalize multiple spaces to one 63 | cleaned_string = re.sub(r'\s+', ' ', cleaned_string).strip() 64 | return cleaned_string 65 | except Exception as e: 66 | return input_string 67 | 68 | 69 | def setup_seed(seed: int): 70 | """ Set random seed for reproducibility 71 | 72 | Args: 73 | seed (int): random seed 74 | """ 75 | torch.manual_seed(seed) 76 | torch.cuda.manual_seed_all(seed) 77 | np.random.seed(seed) 78 | random.seed(seed) 79 | torch.backends.cudnn.deterministic = True 80 | 81 | 82 | def is_torch2_available() -> bool: 83 | """ 
Check if torch2 is available 84 | 85 | Returns: 86 | bool: True if torch2 is available 87 | """ 88 | return hasattr(F, "scaled_dot_product_attention") 89 | 90 | 91 | def cal_attn_mask_xl( 92 | total_length: int, 93 | id_length: int, 94 | sa32: float, 95 | sa64: float, 96 | height: int, 97 | width: int, 98 | device: str = "cuda", 99 | dtype: torch.dtype = torch.float16 100 | ) -> tuple[torch.Tensor, torch.Tensor]: 101 | """ Calculate the attention masks for SDXL at the 32x32 and 64x64 token resolutions 102 | 103 | Args: 104 | total_length (int): Total number of images covered by the masks 105 | id_length (int): Number of leading ID images whose tokens may be shared across images 106 | sa32 (float): Probability that a token at the 32x32 resolution is included in the shared attention mask 107 | sa64 (float): Probability that a token at the 64x64 resolution is included in the shared attention mask 108 | height (int): Image height 109 | width (int): Image width 110 | device (str): Device (default: "cuda") 111 | dtype (torch.dtype): Data type (default: torch.float16) 112 | 113 | Returns: 114 | tuple[torch.Tensor, torch.Tensor]: Attention masks for the 32x32 and 64x64 token resolutions 115 | """ 116 | nums_1024 = (height // 32) * (width // 32)  # tokens per image at downsample factor 32 117 | nums_4096 = (height // 16) * (width // 16)  # tokens per image at downsample factor 16 118 | bool_matrix1024 = torch.rand((1, total_length * nums_1024), device=device, dtype=dtype) < sa32 119 | bool_matrix4096 = torch.rand((1, total_length * nums_4096), device=device, dtype=dtype) < sa64 120 | bool_matrix1024 = bool_matrix1024.repeat(total_length, 1) 121 | bool_matrix4096 = bool_matrix4096.repeat(total_length, 1) 122 | for i in range(total_length): 123 | bool_matrix1024[i:i+1, id_length*nums_1024:] = False  # drop tokens that do not belong to the ID images 124 | bool_matrix4096[i:i+1, id_length*nums_4096:] = False 125 | bool_matrix1024[i:i+1, i*nums_1024:(i+1)*nums_1024] = True  # every image always attends to its own tokens 126 | bool_matrix4096[i:i+1, i*nums_4096:(i+1)*nums_4096] = True 127 | mask1024 = bool_matrix1024.unsqueeze(1).repeat(1, nums_1024, 1).reshape(-1, total_length * nums_1024) 128 | mask4096 = bool_matrix4096.unsqueeze(1).repeat(1, nums_4096, 1).reshape(-1, total_length * nums_4096) 129 | return mask1024, mask4096 130 | 131 | 132 | class AttnProcessor(nn.Module): 133 | r""" 134 | Default processor for performing attention-related computations.
135 | """ 136 | def __init__( 137 | self, 138 | hidden_size=None, 139 | cross_attention_dim=None, 140 | ): 141 | super().__init__() 142 | 143 | def __call__( 144 | self, 145 | attn, 146 | hidden_states, 147 | encoder_hidden_states=None, 148 | attention_mask=None, 149 | temb=None, 150 | ): 151 | residual = hidden_states 152 | 153 | if attn.spatial_norm is not None: 154 | hidden_states = attn.spatial_norm(hidden_states, temb) 155 | 156 | input_ndim = hidden_states.ndim 157 | 158 | if input_ndim == 4: 159 | batch_size, channel, height, width = hidden_states.shape 160 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 161 | 162 | batch_size, sequence_length, _ = ( 163 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 164 | ) 165 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 166 | 167 | if attn.group_norm is not None: 168 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 169 | 170 | query = attn.to_q(hidden_states) 171 | 172 | if encoder_hidden_states is None: 173 | encoder_hidden_states = hidden_states 174 | elif attn.norm_cross: 175 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 176 | 177 | key = attn.to_k(encoder_hidden_states) 178 | value = attn.to_v(encoder_hidden_states) 179 | 180 | query = attn.head_to_batch_dim(query) 181 | key = attn.head_to_batch_dim(key) 182 | value = attn.head_to_batch_dim(value) 183 | 184 | attention_probs = attn.get_attention_scores(query, key, attention_mask) 185 | hidden_states = torch.bmm(attention_probs, value) 186 | hidden_states = attn.batch_to_head_dim(hidden_states) 187 | 188 | # linear proj 189 | hidden_states = attn.to_out[0](hidden_states) 190 | # dropout 191 | hidden_states = attn.to_out[1](hidden_states) 192 | 193 | if input_ndim == 4: 194 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 195 | 196 | if attn.residual_connection: 197 | hidden_states = hidden_states + residual 198 | 199 | hidden_states = hidden_states / attn.rescale_output_factor 200 | 201 | return hidden_states 202 | 203 | 204 | class AttnProcessor2_0(torch.nn.Module): 205 | r""" 206 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). 
207 | """ 208 | def __init__( 209 | self, 210 | hidden_size=None, 211 | cross_attention_dim=None, 212 | ): 213 | super().__init__() 214 | if not hasattr(F, "scaled_dot_product_attention"): 215 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 216 | 217 | def __call__( 218 | self, 219 | attn, 220 | hidden_states, 221 | encoder_hidden_states=None, 222 | attention_mask=None, 223 | temb=None, 224 | ): 225 | residual = hidden_states 226 | 227 | if attn.spatial_norm is not None: 228 | hidden_states = attn.spatial_norm(hidden_states, temb) 229 | 230 | input_ndim = hidden_states.ndim 231 | 232 | if input_ndim == 4: 233 | batch_size, channel, height, width = hidden_states.shape 234 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 235 | 236 | batch_size, sequence_length, _ = ( 237 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 238 | ) 239 | 240 | if attention_mask is not None: 241 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 242 | # scaled_dot_product_attention expects attention_mask shape to be 243 | # (batch, heads, source_length, target_length) 244 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 245 | 246 | if attn.group_norm is not None: 247 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 248 | 249 | query = attn.to_q(hidden_states) 250 | 251 | if encoder_hidden_states is None: 252 | encoder_hidden_states = hidden_states 253 | elif attn.norm_cross: 254 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 255 | 256 | key = attn.to_k(encoder_hidden_states) 257 | value = attn.to_v(encoder_hidden_states) 258 | 259 | inner_dim = key.shape[-1] 260 | head_dim = inner_dim // attn.heads 261 | 262 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 263 | 264 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 265 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 266 | 267 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 268 | # TODO: add support for attn.scale when we move to Torch 2.1 269 | hidden_states = F.scaled_dot_product_attention( 270 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 271 | ) 272 | 273 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 274 | hidden_states = hidden_states.to(query.dtype) 275 | 276 | # linear proj 277 | hidden_states = attn.to_out[0](hidden_states) 278 | # dropout 279 | hidden_states = attn.to_out[1](hidden_states) 280 | 281 | if input_ndim == 4: 282 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 283 | 284 | if attn.residual_connection: 285 | hidden_states = hidden_states + residual 286 | 287 | hidden_states = hidden_states / attn.rescale_output_factor 288 | 289 | return hidden_states 290 | --------------------------------------------------------------------------------