├── .gitignore ├── LICENSE ├── Makefile ├── README.md ├── assets ├── bowl.svg ├── chiwawa.svg ├── doggy.svg ├── pizza_example.svg ├── teaser.svg └── tree.svg ├── notebooks ├── Attention Re-weighting.ipynb └── Prompt Editing.ipynb ├── ptp_utils.py ├── requirements.txt ├── requirements_dev.txt ├── seq_aligner.py ├── setup.cfg └── stable_diffusion.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | # Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | # poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | # pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | ptp_dev/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # .PHONY defines parts of the makefile that are not dependant on any specific file 2 | # This is most often used to store functions 3 | .PHONY = init format format_check format_notebooks lint type_check all 4 | 5 | init: 6 | @echo "Create a development python environment" 7 | python3 -m pip install virtualenv 8 | python3 -m venv ptp_dev 9 | . 
ptp_dev/bin/activate && cat requirements.txt | xargs -n 1 pip install 10 | . ptp_dev/bin/activate && pip install -r requirements_dev.txt 11 | 12 | format: 13 | @echo "Format code according to isort" 14 | . ptp_dev/bin/activate && isort *.py 15 | @echo "Format code according to black" 16 | . ptp_dev/bin/activate && black *.py 17 | 18 | format_check: 19 | @echo "Check code format according to isort" 20 | . ptp_dev/bin/activate && isort *.py --check 21 | @echo "Check code format according to black" 22 | . ptp_dev/bin/activate && black *.py --check 23 | 24 | format_notebooks: 25 | @echo "Format notebooks according to isort" 26 | . ptp_dev/bin/activate && nbqa isort . 27 | @echo "Format notebooks according to black" 28 | . ptp_dev/bin/activate && nbqa black . 29 | 30 | lint: 31 | @echo "Linter check: Flake8" 32 | . ptp_dev/bin/activate && flake8 . 33 | 34 | type_check: 35 | @echo "Type-test check: mypy" 36 | . ptp_dev/bin/activate && mypy . 37 | 38 | all: format format_notebooks lint type_check 39 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Prompt-to-Prompt: Tensorflow Implementation 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1AoRRd-6oXtFEfx9Ff85GNuTcwSssb5zz?usp=sharing) [![Hugging Face Demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/fchollet/stable-diffusion) 4 | 5 | 6 | 7 | ### Unofficial Implementation of the paper Prompt-to-Prompt Image Editing with Cross Attention Control 8 | 9 | ![teaser](assets/teaser.svg) 10 | 11 | [Link to the paper](https://arxiv.org/abs/2208.01626) | [Official PyTorch implementation](https://github.com/google/prompt-to-prompt/) | [Project page](https://prompt-to-prompt.github.io/) 12 | 13 | This repository contains the Tensorflow/Keras code implementation for the paper "**[Prompt-to-Prompt Image Editing with Cross Attention Control](https://arxiv.org/abs/2208.01626)**". 14 | 15 | # 🚀 Quickstart 16 | 17 | Current state-of-the-art methods require the user to provide a spatial mask to localize the edit, an approach that ignores the original structure and content within the masked region. 18 | The paper proposes a novel technique to edit the generated content of large-scale text-to-image models such as [DALL·E 2](https://openai.com/dall-e-2/), [Imagen](https://imagen.research.google/) or [Stable Diffusion](https://github.com/CompVis/stable-diffusion), **by only manipulating the text of the original prompt**. 19 | 20 | To achieve this result, the authors present the *Prompt-to-Prompt* framework, comprising two functionalities: 21 | 22 | - **Prompt Editing**: the key idea is to edit generated images by injecting cross-attention maps during the diffusion process, controlling which pixels attend to which tokens of the prompt text. 23 | 24 | - **Attention Re-weighting**: amplifies or attenuates the effect of a word in the generated image. This is done by first assigning a weight to each token and later scaling the attention map associated with that token. It's a nice alternative to **negative prompting** and **multi-prompting**. 25 | 26 | ## :gear: Installation 27 | 28 | Install dependencies using the `requirements.txt`.
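Optionally, create and activate a virtual environment first; the sketch below mirrors what the `make init` target does (using `ptp_dev` as the environment name):

```bash
python3 -m venv ptp_dev && . ptp_dev/bin/activate
```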
29 | 30 | ```bash 31 | pip install -r requirements.txt 32 | ``` 33 | 34 | Essentially, you need to have installed [TensorFlow](https://github.com/tensorflow/tensorflow) and [Keras-cv](https://github.com/keras-team/keras-cv/). 35 | ## 📚 Notebooks 36 | 37 | Try it yourself: 38 | 39 | - [**Prompt-to-Prompt: Prompt Editing** - Stable Diffusion](notebooks/Prompt%20Editing.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1AoRRd-6oXtFEfx9Ff85GNuTcwSssb5zz?usp=sharing)
40 | Notebook with examples for the *Prompt-to-Prompt* prompt editing approach for Stable Diffusion. 41 | 42 | - [**Prompt-to-Prompt: Attention Re-weighting** - Stable Diffusion](notebooks/Attention%20Re-weighting.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1UcIFg2Nd_LVaO3-UPPysCSVCNvbKjO11?usp=sharing)
43 | Notebook with examples for the *Prompt-to-Prompt* attention re-weighting approach for Stable Diffusion. 44 | 45 | # :dart: Prompt-to-Prompt Examples 46 | 47 | To start using the *Prompt-to-Prompt* framework, you first need to set up a Tensorflow [strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy) for running computations across multiple devices (in case you have more than one). 48 | 49 | For example, you can check the available hardware with: 50 | 51 | ```python 52 | gpus = tf.config.list_physical_devices("GPU") 53 | tpus = tf.config.list_physical_devices("TPU") 54 | print(f"Num GPUs Available: {len(gpus)} | Num TPUs Available: {len(tpus)}") 55 | ``` 56 | 57 | And adjust according to your needs: 58 | 59 | ```python 60 | import tensorflow as tf 61 | 62 | # For running on multiple GPUs 63 | strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1", ...]) 64 | # To get the default strategy 65 | strategy = tf.distribute.get_strategy() 66 | ... 67 | ``` 68 | 69 | ## Prompt Editing 70 | 71 | Once the strategy is set, you can start generating images just like in [Keras-cv](https://github.com/keras-team/keras-cv/): 72 | 73 | ```python 74 | # Imports 75 | import tensorflow as tf 76 | from stable_diffusion import StableDiffusion 77 | 78 | generator = StableDiffusion( 79 | strategy=strategy, 80 | img_height=512, 81 | img_width=512, 82 | jit_compile=False, 83 | ) 84 | 85 | # Generate text-to-image 86 | img = generator.text_to_image( 87 | prompt="a photo of a chiwawa with sunglasses and a bandana", 88 | num_steps=50, 89 | unconditional_guidance_scale=8, 90 | seed=5681067, 91 | batch_size=1, 92 | ) 93 | # Generate Prompt-to-Prompt 94 | img_edit = generator.text_to_image_ptp( 95 | prompt="a photo of a chiwawa with sunglasses and a bandana", 96 | prompt_edit="a photo of a chiwawa with sunglasses and a pirate bandana", 97 | num_steps=50, 98 | unconditional_guidance_scale=8, 99 | cross_attn2_replace_steps_start=0.0, 100 | cross_attn2_replace_steps_end=1.0, 101 | cross_attn1_replace_steps_start=0.8, 102 | cross_attn1_replace_steps_end=1.0, 103 | seed=5681067, 104 | batch_size=1, 105 | ) 106 | ``` 107 | 108 | This generates the original and pirate bandana images shown below. You can play around and change the `prompt` and `prompt_edit` arguments and many others! 109 | 110 | ![teaser](assets/chiwawa.svg) 111 | 112 | Another example of prompt editing where one can control the content of the basket just by replacing a couple of words in the prompt: 113 | 114 | ```python 115 | img_edit = generator.text_to_image_ptp( 116 | prompt="a photo of basket with apples", 117 | prompt_edit="a photo of basket with oranges", 118 | num_steps=50, 119 | unconditional_guidance_scale=8, 120 | cross_attn2_replace_steps_start=0.0, 121 | cross_attn2_replace_steps_end=1.0, 122 | cross_attn1_replace_steps_start=0.0, 123 | cross_attn1_replace_steps_end=1.0, 124 | seed=1597337, 125 | batch_size=1, 126 | ) 127 | ``` 128 | 129 | The image below showcases examples where only the word `apples` was replaced with other fruits or animals. Try changing `basket` to other containers (e.g. bowl or nest) and see what happens! 130 | 131 | ![teaser](assets/bowl.svg) 132 | 133 | ## Attention Re-weighting 134 | 135 | To manipulate the relative importance of tokens, we've added a `prompt_weights` argument that can be passed to both the `text_to_image` and `text_to_image_ptp` methods. You can create an array of weights using our method `create_prompt_weights`.
136 | 137 | For example, if you generated a pizza that doesn't have enough pineapple on it, you can edit the weights of your prompt: 138 | 139 | ```python 140 | prompt = "a photo of a pizza with pineapple" 141 | prompt_weights = generator.create_prompt_weights(prompt, [('pineapple', 2)]) 142 | ``` 143 | 144 | This will create an array of 1's, except at the `pineapple` token position, where the value will be 2. 145 | 146 | To generate a pizza with more pineapple (yuck!), you just need to pass the variable `prompt_weights` to the `text_to_image` method: 147 | 148 | ```python 149 | img = generator.text_to_image( 150 | prompt="a photo of a pizza with pineapple", 151 | num_steps=50, 152 | unconditional_guidance_scale=8, 153 | prompt_weights=prompt_weights, 154 | seed=1234, 155 | batch_size=1, 156 | ) 157 | ``` 158 | 159 | ![teaser](assets/pizza_example.svg) 160 | 161 | Now suppose you want to reduce the amount of blossom in a tree: 162 | 163 | ```python 164 | prompt = "A photo of a blossom tree" 165 | prompt_weights = generator.create_prompt_weights(prompt, [('blossom', -1)]) 166 | 167 | img = generator.text_to_image( 168 | prompt="A photo of a blossom tree", 169 | num_steps=50, 170 | unconditional_guidance_scale=8, 171 | prompt_weights=prompt_weights, 172 | seed=1407923, 173 | batch_size=1, 174 | ) 175 | ``` 176 | 177 | Decreasing the weight associated with `blossom` will generate the following images. 178 | 179 | ![teaser](assets/tree.svg) 180 | 181 | ## Note about the cross-attention parameters 182 | 183 | For the prompt editing method, implemented in the function `text_to_image_ptp`, varying the parameters that indicate in which phase of the diffusion process the edited cross-attention maps get injected (e.g. `cross_attn2_replace_steps_start`, `cross_attn1_replace_steps_start`) may produce different results (image below). 184 | 185 | The cross-attention and prompt weights hyperparameters should be tuned according to the user's needs and desired outputs. 186 | 187 | ![teaser](assets/doggy.svg) 188 | 189 | More info in [bloc97/CrossAttentionControl](https://github.com/bloc97/CrossAttentionControl#usage) and the [paper](https://arxiv.org/abs/2208.01626). 190 | 191 | # :ballot_box_with_check: TODO 192 | 193 | - [x] Add tutorials and Google Colabs. 194 | - [x] Add multi-batch support. 195 | - [ ] Add examples for Stable Diffusion 2.x. 196 | 197 | # 👨‍🎓 References 198 | 199 | - [keras-cv](https://github.com/keras-team/keras-cv/tree/master/keras_cv/models/generative/stable_diffusion) for the TensorFlow implementation of Stable Diffusion. 200 | - [bloc97/CrossAttentionControl](https://github.com/bloc97/CrossAttentionControl) unofficial implementation of the paper, where the method `get_matching_sentence_tokens` and code logic were used. 201 | - [google/prompt-to-prompt](https://github.com/google/prompt-to-prompt) Official implementation of the paper in PyTorch. 202 | 203 | # 🔬 Contributing 204 | 205 | Feel free to open an [issue](https://github.com/miguelcalado/prompt-to-prompt-tensorflow/issues) or create a [Pull Request](https://github.com/miguelcalado/prompt-to-prompt-tensorflow/pulls). 206 | 207 | For PRs, after implementing the changes please use the `Makefile` targets to format and lint the submitted code: 208 | 209 | - `make init`: to create a Python environment with all the developer packages (optional). 210 | - `make format`: to format the code. 211 | - `make lint`: to lint the code. 212 | - `make type_check`: to check for type hints. 213 | - `make all`: to run all the checks.
214 | 215 | # :scroll: License 216 | 217 | Licensed under the Apache License 2.0. See [LICENSE](LICENSE) to read it in full. 218 | -------------------------------------------------------------------------------- /ptp_utils.py: -------------------------------------------------------------------------------- 1 | """Utility methods used to implement the Prompt-to-Prompt paper in TensorFlow. 2 | 3 | References 4 | ---------- 5 | - "Prompt-to-Prompt Image Editing with Cross-Attention Control." 6 | Amir Hertz, Ron Mokady, Jay Tenenbaum, Kfir Aberman, Yael Pritch, Daniel Cohen-Or. 7 | https://arxiv.org/abs/2208.01626 8 | 9 | Credits 10 | ---------- 11 | - Unofficial implementation of the paper, where the method `get_matching_sentence_tokens` 12 | and code logic were used: [bloc97/CrossAttentionControl](https://github.com/bloc97/CrossAttentionControl). 13 | """ 14 | 15 | from typing import Tuple 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | from keras_cv.models.stable_diffusion.diffusion_model import td_dot 20 | from tensorflow import keras 21 | 22 | import seq_aligner 23 | 24 | MAX_TEXT_LEN = 77 25 | 26 | 27 | def rename_cross_attention_layers(diff_model: tf.keras.Model): 28 | """Add a suffix to the cross-attention layers. 29 | 30 | This becomes useful when using the prompt editing method to save the 31 | attention maps and manipulate the control variables. 32 | 33 | Parameters 34 | ---------- 35 | diff_model : tf.keras.Model 36 | Diffusion model. 37 | 38 | Notes 39 | ----- 40 | The layers are renamed in place: alternating matches receive the suffix 41 | "_attn1" (used as self-attention) and "_attn2" (cross-attention). 42 | """ 43 | cross_attention_count = 0 44 | for submodule in diff_model.submodules: 45 | submodule_name = submodule.name 46 | if not cross_attention_count % 2 and "cross_attention" in submodule_name: 47 | submodule._name = f"{submodule_name}_attn1" 48 | cross_attention_count += 1 49 | elif cross_attention_count % 2 and "cross_attention" in submodule_name: 50 | submodule._name = f"{submodule_name}_attn2" 51 | cross_attention_count += 1 52 | 53 | 54 | def update_cross_attn_mode( 55 | diff_model: tf.keras.Model, mode: str, attn_suffix: str = "attn" 56 | ): 57 | """Update the mode control variable. 58 | 59 | Parameters 60 | ---------- 61 | diff_model : tf.keras.Model 62 | Diffusion model. 63 | mode : str 64 | The mode parameter can take the following values: 65 | - save: to save the attention map. 66 | - edit: to calculate the attention map with respect to the edited prompt. 67 | - use_last: to reuse the last saved attention map. - unconditional/injection: to perform the standard attention computations. 68 | attn_suffix : str, optional 69 | Suffix used to find the attention layer, by default "attn". 70 | """ 71 | for submodule in diff_model.submodules: 72 | submodule_name = submodule.name 73 | if ( 74 | "cross_attention" in submodule_name 75 | and attn_suffix in submodule_name.split("_")[-1] 76 | ): 77 | submodule.cross_attn_mode.assign(mode) 78 | 79 | 80 | def update_attn_weights_usage(diff_model: tf.keras.Model, use: bool): 81 | """Update the prompt-weights usage control variable. 82 | 83 | Parameters 84 | ---------- 85 | diff_model : tf.keras.Model 86 | Diffusion model. 87 | use : bool 88 | Whether to use the prompt weights.
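Notes
-----
A usage sketch (assuming `generator` is a `StableDiffusion` instance from
`stable_diffusion.py`): the prompt-editing loop enables the prompt weights
right before predicting the edited conditional noise residual and disables
them again afterwards.

>>> update_attn_weights_usage(generator.diffusion_model_ptp, use=True)
>>> # ... predict the edited conditional noise residual ...
>>> update_attn_weights_usage(generator.diffusion_model_ptp, use=False)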
89 | """ 90 | for submodule in diff_model.submodules: 91 | submodule_name = submodule.name 92 | if ( 93 | "cross_attention" in submodule_name 94 | and "attn2" in submodule_name.split("_")[-1] 95 | ): 96 | submodule.use_prompt_weights.assign(use) 97 | 98 | 99 | def add_attn_weights(diff_model: tf.keras.Model, prompt_weights: np.ndarray): 100 | """Assign the attention weights to the diffusion model's corresponding tf.variable. 101 | 102 | Parameters 103 | ---------- 104 | diff_model : tf.keras.Model 105 | Diffusion model. 106 | prompt_weights : np.ndarray 107 | Weights of the attention tokens. 108 | """ 109 | for submodule in diff_model.submodules: 110 | submodule_name = submodule.name 111 | if ( 112 | "cross_attention" in submodule_name 113 | and "attn2" in submodule_name.split("_")[-1] 114 | ): 115 | submodule.prompt_weights.assign(prompt_weights) 116 | 117 | 118 | def put_mask_dif_model( 119 | diff_model: tf.keras.Model, mask: np.ndarray, indices: np.ndarray 120 | ): 121 | """Assign the diffusion model's tf.variables with the passed mask and indices. 122 | 123 | Parameters 124 | ---------- 125 | diff_model : tf.keras.Model 126 | Diffusion model. 127 | mask : np.ndarray 128 | Mask of the original and edited prompt overlap. 129 | indices : np.ndarray 130 | Indices of the original and edited prompt overlap. 131 | """ 132 | for submodule in diff_model.submodules: 133 | submodule_name = submodule.name 134 | if ( 135 | "cross_attention" in submodule_name 136 | and "attn2" in submodule_name.split("_")[-1] 137 | ): 138 | submodule.prompt_edit_mask.assign(mask) 139 | submodule.prompt_edit_indices.assign(indices) 140 | 141 | 142 | def get_matching_sentence_tokens( 143 | prompt, prompt_edit, tokenizer 144 | ) -> Tuple[np.ndarray, np.ndarray]: 145 | """Create the mask and indices of the overlap between the tokens of the original \ 146 | prompt and the edited one. 147 | 148 | Original code source: https://github.com/bloc97/CrossAttentionControl/ 149 | 150 | Parameters 151 | ---------- 152 | prompt : str 153 | Text of the original prompt. 154 | prompt_edit : str 155 | Text of the edited prompt. tokenizer : SimpleTokenizer Tokenizer used to encode both prompts. 156 | 157 | Returns 158 | ------- 159 | Tuple[np.ndarray, np.ndarray] 160 | Mask and indices of the overlap between the original and edited prompt tokens. 161 | """ 162 | tokens_conditional = tokenizer.encode(prompt) 163 | tokens_conditional_edit = tokenizer.encode(prompt_edit) 164 | mask, indices = seq_aligner.get_mapper(tokens_conditional, tokens_conditional_edit) 165 | return mask, indices 166 | 167 | 168 | def set_initial_tf_variables(diff_model: tf.keras.Model): 169 | """Create initial control variables to support the prompt editing method. 170 | 171 | Parameters 172 | ---------- 173 | diff_model : tf.keras.Model 174 | Diffusion model.
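Notes
-----
A usage sketch (hypothetical driver code; `generator` stands for a
`StableDiffusion` instance): this function is meant to be paired with
`reset_initial_tf_variables`, which restores the same variables to their
defaults after a prompt-editing run.

>>> set_initial_tf_variables(generator.diffusion_model_ptp)
>>> # ... run the prompt-editing diffusion loop ...
>>> reset_initial_tf_variables(generator.diffusion_model_ptp)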
175 | """ 176 | for submodule in diff_model.submodules: 177 | submodule_name = submodule.name 178 | if "cross_attention" in submodule_name: 179 | # Set control variables 180 | submodule.cross_attn_mode = tf.Variable( 181 | "", dtype=tf.string, trainable=False 182 | ) 183 | submodule.use_prompt_weights = tf.Variable( 184 | False, dtype=tf.bool, trainable=False 185 | ) 186 | # Set array variables 187 | submodule.attn_map = tf.Variable( 188 | [], shape=tf.TensorShape(None), dtype=tf.float32, trainable=False 189 | ) 190 | submodule.prompt_edit_mask = tf.Variable( 191 | [], shape=tf.TensorShape(None), dtype=tf.float32, trainable=False 192 | ) 193 | submodule.prompt_edit_indices = tf.Variable( 194 | [], shape=tf.TensorShape(None), dtype=tf.int32, trainable=False 195 | ) 196 | submodule.prompt_weights = tf.Variable( 197 | [], shape=tf.TensorShape(None), dtype=tf.float32, trainable=False 198 | ) 199 | 200 | 201 | def reset_initial_tf_variables(diff_model: tf.keras.Model): 202 | """Reset the control variables to their default values. 203 | 204 | Parameters 205 | ---------- 206 | diff_model : tf.keras.Model 207 | Diffusion model. 208 | """ 209 | for submodule in diff_model.submodules: 210 | submodule_name = submodule.name 211 | if "cross_attention" in submodule_name: 212 | # Reset control variables 213 | submodule.cross_attn_mode.assign("") 214 | submodule.use_prompt_weights.assign(False) 215 | # Reset array variables 216 | submodule.attn_map.assign([]) 217 | submodule.prompt_edit_mask.assign([]) 218 | submodule.prompt_edit_indices.assign([]) 219 | submodule.prompt_weights.assign([]) 220 | 221 | 222 | def overwrite_forward_call(diff_model: tf.keras.Model): 223 | """Update the attention forward pass with a custom call method. 224 | 225 | Parameters 226 | ---------- 227 | diff_model : tf.keras.Model 228 | Diffusion model. 
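Notes
-----
Each matching submodule's `call` is rebound to `call_attn_edit` (defined
below) via `__get__`, so the custom attention logic runs in place of the
layer's original forward pass, e.g. (a sketch):

>>> overwrite_forward_call(generator.diffusion_model_ptp)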
229 | """ 230 | for submodule in diff_model.submodules: 231 | submodule_name = submodule.name 232 | if "cross_attention" in submodule_name: 233 | # Overwrite forward pass method 234 | submodule.call = call_attn_edit.__get__(submodule) 235 | 236 | 237 | def call_attn_edit(self, inputs): 238 | """Implementation of the custom attention forward pass used in the paper's method.""" 239 | inputs, context = inputs 240 | context = inputs if context is None else context 241 | q, k, v = self.to_q(inputs), self.to_k(context), self.to_v(context) 242 | q = tf.reshape(q, (-1, inputs.shape[1], self.num_heads, self.head_size)) 243 | k = tf.reshape(k, (-1, context.shape[1], self.num_heads, self.head_size)) 244 | v = tf.reshape(v, (-1, context.shape[1], self.num_heads, self.head_size)) 245 | 246 | q = tf.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) 247 | k = tf.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time) 248 | v = tf.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size) 249 | 250 | score = td_dot(q, k) * self.scale 251 | weights = keras.activations.softmax(score) # (bs, num_heads, time, time) 252 | 253 | # Method: Prompt Refinement 254 | if tf.equal(self.cross_attn_mode, "edit") and tf.not_equal( 255 | tf.size(self.prompt_edit_mask), 0 256 | ): # not empty 257 | weights_masked = tf.gather(self.attn_map, self.prompt_edit_indices, axis=-1) 258 | edit_weights = weights_masked * self.prompt_edit_mask + weights * ( 259 | 1 - self.prompt_edit_mask 260 | ) 261 | weights = tf.reshape(edit_weights, shape=tf.shape(weights)) 262 | 263 | # Use the attention from the original prompt (M_t) 264 | if tf.equal(self.cross_attn_mode, "use_last"): 265 | weights = tf.reshape(self.attn_map, shape=tf.shape(weights)) 266 | 267 | # Save attention 268 | if tf.equal(self.cross_attn_mode, "save"): 269 | self.attn_map.assign(weights) 270 | 271 | # Method: Attention Re-weighting 272 | if tf.equal(self.use_prompt_weights, True) and tf.not_equal( 273 | tf.size(self.prompt_weights), 0 274 | ): 275 | attn_map_weighted = weights * self.prompt_weights 276 | weights = tf.reshape(attn_map_weighted, shape=tf.shape(weights)) 277 | 278 | attn = td_dot(weights, v) 279 | attn = tf.transpose(attn, (0, 2, 1, 3)) # (bs, time, num_heads, head_size) 280 | out = tf.reshape(attn, (-1, inputs.shape[1], self.num_heads * self.head_size)) 281 | return self.out_proj(out) 282 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.10.0 2 | tensorflow_datasets 3 | h5py==3.7.0 4 | Pillow==9.2.0 5 | tqdm==4.64.1 6 | ftfy==6.1.1 7 | regex==2022.9.13 8 | tensorflow-addons==0.17.1 9 | git+https://github.com/keras-team/keras-cv -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | black 2 | isort 3 | mypy 4 | flake8 5 | nbqa -------------------------------------------------------------------------------- /seq_aligner.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class ScoreParams: 5 | def __init__(self, gap, match, mismatch): 6 | self.gap = gap 7 | self.match = match 8 | self.mismatch = mismatch 9 | 10 | def mis_match_char(self, x, y): 11 | if x != y: 12 | return self.mismatch 13 | else: 14 | return self.match 15 | 16 | 17 | def get_mapper(x_seq, y_seq): 18 | score =
ScoreParams(0, 1, -1) 19 | matrix, trace_back = global_align(x_seq, y_seq, score) 20 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1] 21 | alphas = np.ones(77) 22 | alphas[: mapper_base.shape[0]] = (mapper_base[:, 1] >= 0).astype(float) 23 | mapper = np.zeros(77, dtype=int) 24 | mapper[: mapper_base.shape[0]] = mapper_base[:, 1] 25 | mapper[mapper_base.shape[0] :] = len(y_seq) + np.arange(77 - len(y_seq)) 26 | return alphas, mapper 27 | 28 | 29 | def global_align(x, y, score): 30 | matrix = get_matrix(len(x), len(y), score.gap) 31 | trace_back = get_traceback_matrix(len(x), len(y)) 32 | for i in range(1, len(x) + 1): 33 | for j in range(1, len(y) + 1): 34 | left = matrix[i, j - 1] + score.gap 35 | up = matrix[i - 1, j] + score.gap 36 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1]) 37 | matrix[i, j] = max(left, up, diag) 38 | if matrix[i, j] == left: 39 | trace_back[i, j] = 1 40 | elif matrix[i, j] == up: 41 | trace_back[i, j] = 2 42 | else: 43 | trace_back[i, j] = 3 44 | return matrix, trace_back 45 | 46 | 47 | def get_matrix(size_x, size_y, gap): 48 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 49 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap 50 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap 51 | return matrix 52 | 53 | 54 | def get_traceback_matrix(size_x, size_y): 55 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32) 56 | matrix[0, 1:] = 1 57 | matrix[1:, 0] = 2 58 | matrix[0, 0] = 4 59 | return matrix 60 | 61 | 62 | def get_aligned_sequences(x, y, trace_back): 63 | x_seq = [] 64 | y_seq = [] 65 | i = len(x) 66 | j = len(y) 67 | mapper_y_to_x = [] 68 | while i > 0 or j > 0: 69 | if trace_back[i, j] == 3: 70 | x_seq.append(x[i - 1]) 71 | y_seq.append(y[j - 1]) 72 | i = i - 1 73 | j = j - 1 74 | mapper_y_to_x.append((j, i)) 75 | elif trace_back[i][j] == 1: 76 | x_seq.append("-") 77 | y_seq.append(y[j - 1]) 78 | j = j - 1 79 | mapper_y_to_x.append((j, -1)) 80 | elif trace_back[i][j] == 2: 81 | x_seq.append(x[i - 1]) 82 | y_seq.append("-") 83 | i = i - 1 84 | elif trace_back[i][j] == 4: 85 | break 86 | mapper_y_to_x.reverse() 87 | return x_seq, y_seq, np.array(mapper_y_to_x) 88 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 125 3 | max-complexity = 18 4 | docstring-convention = numpy 5 | show-source = True 6 | statistics = True 7 | count = True 8 | # https://www.flake8rules.com/ 9 | ignore = 10 | # Too many leading ```#``` for a block comment 11 | E266, 12 | # Line break occurred before a binary operator 13 | W503, 14 | # Missing docstring in public module 15 | D100, 16 | # Whitespace before ':' 17 | E203, 18 | extend-exclude = 19 | ptp_dev/ 20 | 21 | [isort] 22 | profile=black 23 | 24 | [mypy] 25 | check_untyped_defs = True 26 | warn_unused_configs = True 27 | 28 | [mypy-matplotlib.*] 29 | ignore_missing_imports = True 30 | 31 | [mypy-numpy.*] 32 | ignore_missing_imports = True 33 | 34 | [mypy-tensorflow.*] 35 | ignore_missing_imports = True 36 | 37 | [mypy-tqdm.*] 38 | ignore_missing_imports = True 39 | 40 | [mypy-keras_cv.*] 41 | ignore_missing_imports = True -------------------------------------------------------------------------------- /stable_diffusion.py: -------------------------------------------------------------------------------- 1 | """TensorFlow/Keras implementation of Stable Diffusion and Prompt-to-Prompt papers. 
2 | 3 | References 4 | ---------- 5 | - "High-Resolution Image Synthesis With Latent Diffusion Models" 6 | Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bjorn 7 | Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 8 | https://arxiv.org/abs/2112.10752 9 | - "Prompt-to-Prompt Image Editing with Cross-Attention Control." 10 | Amir Hertz, Ron Mokady, Jay Tenenbaum, Kfir Aberman, Yael Pritch, Daniel Cohen-Or. 11 | https://arxiv.org/abs/2208.01626 12 | 13 | Credits 14 | ---------- 15 | - [keras-cv](https://github.com/keras-team/keras-cv/tree/master/keras_cv/models/generative/stable_diffusion) \ 16 | for the TensorFlow/Keras implementation of Stable Diffusion. 17 | - [bloc97/CrossAttentionControl](https://github.com/bloc97/CrossAttentionControl) unofficial implementation of \ 18 | the paper, where the method `get_matching_sentence_tokens` and code logic were used. 19 | - [google/prompt-to-prompt](https://github.com/google/prompt-to-prompt) official implementation of the paper in PyTorch. 20 | """ 21 | 22 | import math 23 | from typing import List, Optional, Tuple, Union 24 | 25 | import numpy as np 26 | import tensorflow as tf 27 | from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer 28 | from keras_cv.models.stable_diffusion.constants import ( 29 | _ALPHAS_CUMPROD, 30 | _UNCONDITIONAL_TOKENS, 31 | ) 32 | from keras_cv.models.stable_diffusion.decoder import Decoder 33 | from keras_cv.models.stable_diffusion.diffusion_model import ( 34 | DiffusionModel, 35 | DiffusionModelV2, 36 | ) 37 | from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder 38 | from keras_cv.models.stable_diffusion.text_encoder import TextEncoder, TextEncoderV2 39 | from tensorflow import keras 40 | 41 | import ptp_utils 42 | 43 | MAX_PROMPT_LENGTH = 77 44 | NUM_TRAIN_TIMESTEPS = 1000 45 | 46 | 47 | class StableDiffusionBase: 48 | """Implementation of Stable Diffusion and Prompt-to-Prompt papers in TensorFlow/Keras. 49 | 50 | Parameters 51 | ---------- 52 | strategy : tf.distribute 53 | TensorFlow strategy for running computations across multiple devices. 54 | img_height : int, optional 55 | Image height, by default 512. 56 | img_width : int, optional 57 | Image width, by default 512. 58 | jit_compile : bool, optional 59 | Flag to compile the models to XLA, by default False. 60 | download_weights : bool, optional 61 | Flag to download the model weights, by default True.
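Notes
-----
`img_height` and `img_width` are rounded in `__init__` to the nearest
multiple of 128, since the UNet requires spatial dimensions that are
multiples of 2**7 = 128.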
62 | 63 | Examples 64 | -------- 65 | >>> import tensorflow as tf 66 | >>> from PIL import Image 67 | >>> from stable_diffusion import StableDiffusion 68 | >>> strategy = tf.distribute.get_strategy() # To use only one GPU 69 | >>> generator = StableDiffusion( 70 | strategy=strategy, 71 | img_height=512, 72 | img_width=512, 73 | jit_compile=False, 74 | ) 75 | >>> img = generator.text_to_image( 76 | prompt="teddy bear with sunglasses relaxing in a pool", 77 | num_steps=50, 78 | unconditional_guidance_scale=8, 79 | seed=3345435, 80 | batch_size=1, 81 | ) 82 | >>> Image.fromarray(img[0]).save("original_prompt.png") 83 | 84 | Now let's edit the image to customize the teddy bear's sunglasses: 85 | 86 | >>> img = generator.text_to_image_ptp( 87 | prompt="teddy bear with sunglasses relaxing in a pool", 88 | prompt_edit="teddy bear with heart-shaped red colored sunglasses relaxing in a pool", 89 | method="refine", 90 | num_steps=50, 91 | unconditional_guidance_scale=8, 92 | self_attn_steps=0.0, 93 | cross_attn_steps=1.0, 94 | seed=3345435, 95 | batch_size=1, 96 | ) 97 | 98 | >>> Image.fromarray(img[0]).save("edited_prompt.png") 99 | """ 100 | 101 | def __init__( 102 | self, 103 | img_height: int = 512, 104 | img_width: int = 512, 105 | jit_compile: bool = False, 106 | ): 107 | 108 | # UNet requires multiples of 2**7 = 128 109 | img_height = round(img_height / 128) * 128 110 | img_width = round(img_width / 128) * 128 111 | self.img_height = img_height 112 | self.img_width = img_width 113 | 114 | # lazy initialize the component models and the tokenizer 115 | self._image_encoder = None 116 | self._text_encoder = None 117 | self._diffusion_model = None 118 | self._diffusion_model_ptp = None 119 | self._decoder = None 120 | self._tokenizer = None 121 | 122 | self.jit_compile = jit_compile 123 | 124 | def text_to_image( 125 | self, 126 | prompt: str, 127 | negative_prompt: Optional[str] = None, 128 | num_steps: int = 50, 129 | unconditional_guidance_scale: float = 7.5, 130 | batch_size: int = 1, 131 | seed: Optional[int] = None, 132 | ) -> np.ndarray: 133 | """Generate an image based on a prompt text. 134 | 135 | Parameters 136 | ---------- 137 | prompt : str 138 | Text containing the information for the model to generate. 139 | negative_prompt : str 140 | A string containing information to negatively guide the image 141 | generation (e.g. by removing or altering certain aspects of the 142 | generated image). 143 | num_steps : int, optional 144 | Number of diffusion steps (controls image quality), by default 50. 145 | unconditional_guidance_scale : float, optional 146 | Controls how closely the image should adhere to the prompt, by default 7.5. 147 | batch_size : int, optional 148 | Batch size (number of images to generate), by default 1. 149 | seed : Optional[int], optional 150 | Number to seed the random noise, by default None. 151 | 152 | Returns 153 | ------- 154 | np.ndarray 155 | Generated image.
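Examples
--------
A minimal usage sketch (the prompt and seed here are arbitrary;
`generator` is a constructed `StableDiffusion` instance, as in the class
docstring above):

>>> img = generator.text_to_image(
        prompt="a photo of an astronaut riding a horse",
        num_steps=50,
        unconditional_guidance_scale=7.5,
        batch_size=1,
        seed=42,
    )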
156 | """ 157 | 158 | # Tokenize and encode prompt 159 | encoded_text = self.encode_text(prompt) 160 | 161 | conditional_context = self._expand_tensor(encoded_text, batch_size) 162 | 163 | if negative_prompt is None: 164 | unconditional_context = tf.repeat( 165 | self._get_unconditional_context(), batch_size, axis=0 166 | ) 167 | else: 168 | unconditional_context = self.encode_text(negative_prompt) 169 | unconditional_context = self._expand_tensor( 170 | unconditional_context, batch_size 171 | ) 172 | 173 | # Get initial random noise 174 | latent = self._get_initial_diffusion_noise(batch_size, seed) 175 | 176 | # Scheduler 177 | timesteps = tf.range(1, 1000, 1000 // num_steps) 178 | 179 | # Get initial parameters 180 | alphas, alphas_prev = self._get_initial_alphas(timesteps) 181 | 182 | progbar = keras.utils.Progbar(len(timesteps)) 183 | iteration = 0 184 | # Diffusion stage 185 | for index, timestep in list(enumerate(timesteps))[::-1]: 186 | 187 | t_emb = self._get_timestep_embedding(timestep, batch_size) 188 | 189 | # Predict the unconditional noise residual 190 | unconditional_latent = self.diffusion_model.predict_on_batch( 191 | [latent, t_emb, unconditional_context] 192 | ) 193 | 194 | # Predict the conditional noise residual 195 | conditional_latent = self.diffusion_model.predict_on_batch( 196 | [latent, t_emb, conditional_context] 197 | ) 198 | 199 | # Perform guidance 200 | e_t = unconditional_latent + unconditional_guidance_scale * ( 201 | conditional_latent - unconditional_latent 202 | ) 203 | 204 | a_t, a_prev = alphas[index], alphas_prev[index] 205 | latent = self._get_x_prev(latent, e_t, a_t, a_prev) 206 | 207 | iteration += 1 208 | progbar.update(iteration) 209 | 210 | # Decode image 211 | img = self._get_decoding_stage(latent) 212 | 213 | return img 214 | 215 | def encode_text(self, prompt): 216 | """Encodes a prompt into a latent text encoding. 217 | The encoding produced by this method should be used as the 218 | `encoded_text` parameter of `StableDiffusion.generate_image`. Encoding 219 | text separately from generating an image can be used to arbitrarily 220 | modify the text encoding prior to image generation, e.g. for walking 221 | between two prompts. 222 | Args: 223 | prompt: a string to encode, must be 77 tokens or shorter. 224 | Example: 225 | ```python 226 | from keras_cv.models import StableDiffusion 227 | model = StableDiffusion(img_height=512, img_width=512, jit_compile=True) 228 | encoded_text = model.encode_text("Tacos at dawn") 229 | img = model.generate_image(encoded_text) 230 | ``` 231 | """ 232 | # Tokenize prompt (i.e.
starting context) 233 | inputs = self.tokenizer.encode(prompt) 234 | if len(inputs) > MAX_PROMPT_LENGTH: 235 | raise ValueError( 236 | f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)" 237 | ) 238 | phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs)) 239 | phrase = tf.convert_to_tensor([phrase], dtype=tf.int32) 240 | 241 | context = self.text_encoder.predict_on_batch([phrase, self._get_pos_ids()]) 242 | 243 | return context 244 | 245 | def text_to_image_ptp( 246 | self, 247 | prompt: str, 248 | prompt_edit: str, 249 | method: str, 250 | self_attn_steps: Union[float, Tuple[float, float]], 251 | cross_attn_steps: Union[float, Tuple[float, float]], 252 | attn_edit_weights: np.ndarray = np.array([]), 253 | negative_prompt: Optional[str] = None, 254 | num_steps: int = 50, 255 | unconditional_guidance_scale: float = 7.5, 256 | batch_size: int = 1, 257 | seed: Optional[int] = None, 258 | ) -> np.ndarray: 259 | """Generate an image based on the Prompt-to-Prompt editing method. 260 | 261 | Edit a generated image controlled only through text. 262 | Paper: https://arxiv.org/abs/2208.01626 263 | 264 | Parameters 265 | ---------- 266 | prompt : str 267 | Text containing the information for the model to generate. 268 | prompt_edit : str 269 | Second prompt used to control the edit of the generated image. 270 | method : str 271 | Prompt-to-Prompt method to choose. Can be one of ['refine', 'replace', 'reweight']. 272 | self_attn_steps : Union[float, Tuple[float, float]] 273 | Fraction(s) of the diffusion process during which the self-attention maps 274 | of the original prompt are injected in place of the edited ones; a float x is interpreted as the interval (0.0, x). 275 | cross_attn_steps : Union[float, Tuple[float, float]] 276 | Fraction(s) of the diffusion process during which the cross-attention maps 277 | of the original prompt are injected in place of the edited ones; a float x is interpreted as the interval (0.0, x). 278 | attn_edit_weights : np.ndarray, optional 279 | Set of weights for each edit prompt token. 280 | This is used for manipulating the importance of the edit prompt tokens, 281 | increasing or decreasing the importance assigned to each word. 282 | negative_prompt : Optional[str] = None 283 | A string containing information to negatively guide the image 284 | generation (e.g. by removing or altering certain aspects of the 285 | generated image). 286 | num_steps : int, optional 287 | Number of diffusion steps (controls image quality), by default 50. 288 | unconditional_guidance_scale : float, optional 289 | Controls how closely the image should adhere to the prompt, by default 7.5. 290 | batch_size : int, optional 291 | Batch size (number of images to generate), by default 1. 292 | seed : Optional[int], optional 293 | Number to seed the random noise, by default None. 294 | 295 | Returns 296 | ------- 297 | np.ndarray 298 | Generated image edited with the prompt-editing method. 299 | 300 | Examples 301 | -------- 302 | >>> import tensorflow as tf 303 | >>> from PIL import Image 304 | >>> from stable_diffusion import StableDiffusion 305 | >>> strategy = tf.distribute.get_strategy() # To use only one GPU 306 | >>> generator = StableDiffusion( 307 | strategy=strategy, 308 | img_height=512, 309 | img_width=512, 310 | jit_compile=False, 311 | ) 312 | 313 | Edit the originally generated image by making the sunglasses heart-shaped and red colored.
314 | 315 | >>> img = generator.text_to_image_ptp( 316 | prompt="teddy bear with sunglasses relaxing in a pool", 317 | prompt_edit="teddy bear with heart-shaped red colored sunglasses relaxing in a pool", method="refine", 318 | num_steps=50, 319 | unconditional_guidance_scale=8, 320 | self_attn_steps=0.0, 321 | cross_attn_steps=1.0, 322 | seed=3345435, 323 | batch_size=1, 324 | ) 325 | >>> Image.fromarray(img[0]).save("edited_prompt.png") 326 | """ 327 | 328 | # Prompt-to-Prompt: check inputs 329 | if isinstance(self_attn_steps, float): 330 | self_attn_steps = (0.0, self_attn_steps) 331 | if isinstance(cross_attn_steps, float): 332 | cross_attn_steps = (0.0, cross_attn_steps) 333 | 334 | # Tokenize and encode prompt 335 | encoded_text = self.encode_text(prompt) 336 | conditional_context = self._expand_tensor(encoded_text, batch_size) 337 | 338 | # Tokenize and encode edit prompt 339 | encoded_text_edit = self.encode_text(prompt_edit) 340 | conditional_context_edit = self._expand_tensor(encoded_text_edit, batch_size) 341 | 342 | if negative_prompt is None: 343 | unconditional_context = tf.repeat( 344 | self._get_unconditional_context(), batch_size, axis=0 345 | ) 346 | else: 347 | unconditional_context = self.encode_text(negative_prompt) 348 | unconditional_context = self._expand_tensor( 349 | unconditional_context, batch_size 350 | ) 351 | 352 | if method == "refine": 353 | # Get the mask and indices of the difference between the original prompt's tokens and the edited one's 354 | mask, indices = ptp_utils.get_matching_sentence_tokens( 355 | prompt, prompt_edit, self.tokenizer 356 | ) 357 | # Add the mask and indices to the diffusion model 358 | ptp_utils.put_mask_dif_model(self.diffusion_model_ptp, mask, indices) 359 | 360 | # Update prompt weights variable 361 | if attn_edit_weights.size: 362 | ptp_utils.add_attn_weights( 363 | diff_model=self.diffusion_model_ptp, prompt_weights=attn_edit_weights 364 | ) 365 | 366 | # Get initial random noise 367 | latent = self._get_initial_diffusion_noise(batch_size, seed) 368 | 369 | # Scheduler 370 | timesteps = tf.range(1, 1000, 1000 // num_steps) 371 | 372 | # Get initial parameters 373 | alphas, alphas_prev = self._get_initial_alphas(timesteps) 374 | 375 | progbar = keras.utils.Progbar(len(timesteps)) 376 | iteration = 0 377 | # Diffusion stage 378 | for index, timestep in list(enumerate(timesteps))[::-1]: 379 | 380 | t_emb = self._get_timestep_embedding(timestep, batch_size) 381 | 382 | # Fraction of the diffusion process completed: ~0 at the first denoising step, ~1 at the last
383 | t_scale = 1 - (timestep / NUM_TRAIN_TIMESTEPS) 384 | 385 | # Update Cross-Attention mode to 'unconditional' 386 | ptp_utils.update_cross_attn_mode( 387 | diff_model=self.diffusion_model_ptp, mode="unconditional" 388 | ) 389 | 390 | # Predict the unconditional noise residual 391 | unconditional_latent = self.diffusion_model_ptp.predict_on_batch( 392 | [latent, t_emb, unconditional_context] 393 | ) 394 | 395 | # Save last cross attention activations 396 | ptp_utils.update_cross_attn_mode( 397 | diff_model=self.diffusion_model_ptp, mode="save" 398 | ) 399 | # Predict the conditional noise residual 400 | _ = self.diffusion_model_ptp.predict_on_batch( 401 | [latent, t_emb, conditional_context] 402 | ) 403 | 404 | # Edit the Cross-Attention layer activations 405 | if cross_attn_steps[0] <= t_scale <= cross_attn_steps[1]: 406 | if method == "replace": 407 | # Use cross attention from the original prompt (M_t) 408 | ptp_utils.update_cross_attn_mode( 409 | diff_model=self.diffusion_model_ptp, 410 | mode="use_last", 411 | attn_suffix="attn2", 412 | ) 413 | elif method == "refine": 414 | # Use cross attention with function A(J) 415 | ptp_utils.update_cross_attn_mode( 416 | diff_model=self.diffusion_model_ptp, 417 | mode="edit", 418 | attn_suffix="attn2", 419 | ) 420 | if method == "reweight" or attn_edit_weights.size: 421 | # Use the parsed weights on the edited prompt 422 | ptp_utils.update_attn_weights_usage( 423 | diff_model=self.diffusion_model_ptp, use=True 424 | ) 425 | 426 | else: 427 | # Use cross attention from the edited prompt (M^*_t) 428 | ptp_utils.update_cross_attn_mode( 429 | diff_model=self.diffusion_model_ptp, 430 | mode="injection", 431 | attn_suffix="attn2", 432 | ) 433 | 434 | # Edit the self-Attention layer activations 435 | if self_attn_steps[0] <= t_scale <= self_attn_steps[1]: 436 | # Use self attention from the original prompt (M_t) 437 | ptp_utils.update_cross_attn_mode( 438 | diff_model=self.diffusion_model_ptp, 439 | mode="use_last", 440 | attn_suffix="attn1", 441 | ) 442 | else: 443 | # Use self attention from the edited prompt (M^*_t) 444 | ptp_utils.update_cross_attn_mode( 445 | diff_model=self.diffusion_model_ptp, 446 | mode="injection", 447 | attn_suffix="attn1", 448 | ) 449 | 450 | # Predict the edited conditional noise residual 451 | conditional_latent_edit = self.diffusion_model_ptp.predict_on_batch( 452 | [latent, t_emb, conditional_context_edit], 453 | ) 454 | 455 | # Assign usage to False so it doesn't get used in other contexts 456 | if attn_edit_weights.size: 457 | ptp_utils.update_attn_weights_usage( 458 | diff_model=self.diffusion_model_ptp, use=False 459 | ) 460 | 461 | # Perform guidance 462 | e_t = unconditional_latent + unconditional_guidance_scale * ( 463 | conditional_latent_edit - unconditional_latent 464 | ) 465 | 466 | a_t, a_prev = alphas[index], alphas_prev[index] 467 | latent = self._get_x_prev(latent, e_t, a_t, a_prev) 468 | 469 | iteration += 1 470 | progbar.update(iteration) 471 | 472 | # Decode image 473 | img = self._get_decoding_stage(latent) 474 | 475 | # Reset control variables 476 | ptp_utils.reset_initial_tf_variables(self.diffusion_model_ptp) 477 | 478 | return img 479 | 480 | def _get_unconditional_context(self): 481 | unconditional_tokens = tf.convert_to_tensor( 482 | [_UNCONDITIONAL_TOKENS], dtype=tf.int32 483 | ) 484 | unconditional_context = self.text_encoder.predict_on_batch( 485 | [unconditional_tokens, self._get_pos_ids()] 486 | ) 487 | 488 | return unconditional_context 489 | 490 | def tokenize_prompt(self, prompt: str) 
    def tokenize_prompt(self, prompt: str) -> tf.Tensor:
        """Tokenize a phrase prompt.

        Parameters
        ----------
        prompt : str
            The prompt string to tokenize; must be 77 tokens or shorter.

        Returns
        -------
        tf.Tensor
            Tensor of shape (1, 77) with the prompt tokens, padded with the
            end-of-text token.
        """
        inputs = self.tokenizer.encode(prompt)
        if len(inputs) > MAX_PROMPT_LENGTH:
            raise ValueError(
                f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
            )
        phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs))
        phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
        return phrase

    def create_prompt_weights(
        self, prompt: str, prompt_weights: List[Tuple[str, float]]
    ) -> np.ndarray:
        """Create an array of weights for each prompt token.

        This is used to manipulate the importance of the prompt tokens,
        increasing or decreasing the importance assigned to each word.

        Parameters
        ----------
        prompt : str
            The prompt string to tokenize; must be 77 tokens or shorter.
        prompt_weights : List[Tuple[str, float]]
            A list of (word, weight) pairs for the words whose importance
            should be manipulated.

        Returns
        -------
        np.ndarray
            Array of weights controlling the importance of each prompt token.
        """
        # Initialize the weights to 1.
        weights = np.ones(MAX_PROMPT_LENGTH)

        # Get the prompt tokens
        tokens = self.tokenize_prompt(prompt)

        # Extract the new weights and tokens
        edit_weights = [weight for word, weight in prompt_weights]
        edit_tokens = [
            self.tokenizer.encode(word)[1:-1] for word, weight in prompt_weights
        ]

        # Get the indexes of the tokens
        index_edit_tokens = np.in1d(tokens, edit_tokens).nonzero()[0]

        # Replace the original weight values
        weights[index_edit_tokens] = edit_weights
        return weights

    def _expand_tensor(self, text_embedding, batch_size):
        """Extends a tensor by repeating it to fit the shape of the given batch size."""
        text_embedding = tf.squeeze(text_embedding)
        if text_embedding.shape.rank == 2:
            text_embedding = tf.repeat(
                tf.expand_dims(text_embedding, axis=0), batch_size, axis=0
            )
        return text_embedding

    def _get_initial_alphas(self, timesteps):
        alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
        alphas_prev = [1.0] + alphas[:-1]

        return alphas, alphas_prev

    def _get_initial_diffusion_noise(self, batch_size: int, seed: Optional[int]):
        return tf.random.normal(
            (batch_size, self.img_height // 8, self.img_width // 8, 4), seed=seed
        )

    def _get_timestep_embedding(self, timestep, batch_size, dim=320, max_period=10000):
        half = dim // 2
        freqs = tf.math.exp(
            -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
        )
        args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
        embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
        embedding = tf.reshape(embedding, [1, -1])
        return tf.repeat(embedding, batch_size, axis=0)

    def _get_decoding_stage(self, latent):
        decoded = self.decoder.predict_on_batch(latent)
        decoded = ((decoded + 1) / 2) * 255
        return np.clip(decoded, 0, 255).astype("uint8")
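
    # The deterministic DDIM update implemented below (the eta = 0 case)
    # recovers the predicted clean sample from the noise estimate and then
    # re-noises it to the previous timestep:
    #
    #   pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
    #   x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * e_t
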
    def _get_x_prev(self, x, e_t, a_t, a_prev):
        sqrt_one_minus_at = math.sqrt(1 - a_t)
        pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)
        # Direction pointing to x_t
        dir_xt = math.sqrt(1.0 - a_prev) * e_t
        x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt
        return x_prev

    @staticmethod
    def _get_pos_ids():
        return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)

    @property
    def image_encoder(self):
        """image_encoder returns the VAE Encoder with pretrained weights.

        Usage:
        ```python
        sd = keras_cv.models.StableDiffusion()
        my_image = np.ones((512, 512, 3))
        latent_representation = sd.image_encoder.predict(my_image)
        ```
        """
        if self._image_encoder is None:
            self._image_encoder = ImageEncoder(self.img_height, self.img_width)
            if self.jit_compile:
                self._image_encoder.compile(jit_compile=True)
        return self._image_encoder

    @property
    def text_encoder(self):
        """Implemented by subclasses, which return the concrete text encoder."""
        raise NotImplementedError

    @property
    def diffusion_model(self):
        """Implemented by subclasses, which return the concrete diffusion model."""
        raise NotImplementedError

    @property
    def decoder(self):
        """decoder returns the diffusion image decoder model with pretrained weights.

        Can be overridden for tasks where the decoder needs to be modified.
        """
        if self._decoder is None:
            self._decoder = Decoder(self.img_height, self.img_width)
            if self.jit_compile:
                self._decoder.compile(jit_compile=True)
        return self._decoder

    @property
    def tokenizer(self):
        """tokenizer returns the tokenizer used for text inputs.

        Can be overridden for tasks like textual inversion where the tokenizer
        needs to be modified.
        """
        if self._tokenizer is None:
            self._tokenizer = SimpleTokenizer()
        return self._tokenizer
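
# A minimal subclassing sketch (hypothetical names) of how the lazy properties
# above are meant to be overridden, e.g. to swap in a modified decoder:
#
#   class CustomStableDiffusion(StableDiffusion):
#       @property
#       def decoder(self):
#           if self._decoder is None:
#               self._decoder = MyModifiedDecoder(self.img_height, self.img_width)
#               if self.jit_compile:
#                   self._decoder.compile(jit_compile=True)
#           return self._decoder
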
class StableDiffusion(StableDiffusionBase):
    """Keras implementation of Stable Diffusion.

    Note that the StableDiffusion API, as well as the APIs of the sub-components
    of StableDiffusion (e.g. ImageEncoder, DiffusionModel) should be considered
    unstable at this point. We do not guarantee backward compatibility for
    future changes to these APIs.

    Stable Diffusion is a powerful image generation model that can be used,
    among other things, to generate pictures according to a short text
    description (called a "prompt").

    Arguments:
        img_height: Height of the images to generate, in pixels. Note that only
            multiples of 128 are supported; the value provided will be rounded
            to the nearest valid value. Default: 512.
        img_width: Width of the images to generate, in pixels. Note that only
            multiples of 128 are supported; the value provided will be rounded
            to the nearest valid value. Default: 512.
        jit_compile: Whether to compile the underlying models to XLA.
            This can lead to a significant speedup on some systems. Default: False.

    Example:
    ```python
    from keras_cv.models import StableDiffusion
    from PIL import Image

    model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
    img = model.text_to_image(
        prompt="A beautiful horse running through a field",
        batch_size=1,  # How many images to generate at once
        num_steps=25,  # Number of iterations (controls image quality)
        seed=123,  # Set this to always get the same image from the same prompt
    )
    Image.fromarray(img[0]).save("horse.png")
    print("saved at horse.png")
    ```

    References:
    - [About Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement)
    - [Original implementation](https://github.com/CompVis/stable-diffusion)
    """

    def __init__(
        self,
        img_height=512,
        img_width=512,
        jit_compile=False,
    ):
        super().__init__(img_height, img_width, jit_compile)
        print(
            "By using this model checkpoint, you acknowledge that its usage is "
            "subject to the terms of the CreativeML Open RAIL-M license at "
            "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE"
        )

    @property
    def text_encoder(self):
        """text_encoder returns the text encoder with pretrained weights.

        Can be overridden for tasks like textual inversion where the text
        encoder needs to be modified.
        """
        if self._text_encoder is None:
            self._text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
            if self.jit_compile:
                self._text_encoder.compile(jit_compile=True)
        return self._text_encoder

    @property
    def diffusion_model(self) -> tf.keras.Model:
        """diffusion_model returns the diffusion model with pretrained weights.

        Can be overridden for tasks where the diffusion model needs to be
        modified.
        """
        if self._diffusion_model is None:
            self._diffusion_model = DiffusionModel(
                self.img_height, self.img_width, MAX_PROMPT_LENGTH
            )
            if self.jit_compile:
                self._diffusion_model.compile(jit_compile=True)
        return self._diffusion_model

    @property
    def diffusion_model_ptp(self) -> tf.keras.Model:
        """diffusion_model_ptp returns the diffusion model with modifications
        for the Prompt-to-Prompt method.

        References
        ----------
        - "Prompt-to-Prompt Image Editing with Cross-Attention Control."
          Amir Hertz, Ron Mokady, Jay Tenenbaum, Kfir Aberman, Yael Pritch,
          Daniel Cohen-Or.
          https://arxiv.org/abs/2208.01626
        """
        if self._diffusion_model_ptp is None:
            if self._diffusion_model is None:
                self._diffusion_model_ptp = self.diffusion_model
            else:
                # Recompile to reset the graph before the variables and
                # forward calls are added or overwritten
                self._diffusion_model.compile(jit_compile=self.jit_compile)
                self._diffusion_model_ptp = self._diffusion_model

            # Add extra variables and callbacks
            ptp_utils.rename_cross_attention_layers(self._diffusion_model_ptp)
            ptp_utils.overwrite_forward_call(self._diffusion_model_ptp)
            ptp_utils.set_initial_tf_variables(self._diffusion_model_ptp)

        return self._diffusion_model_ptp
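
# A minimal usage sketch of the Prompt-to-Prompt property above; the first
# access lazily builds the diffusion model and patches its attention layers,
# and later accesses return the cached, already-patched model:
#
#   model = StableDiffusion(img_height=512, img_width=512)
#   ptp_model = model.diffusion_model_ptp
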
class StableDiffusionV2(StableDiffusionBase):
    """Keras implementation of Stable Diffusion v2.

    Note that the StableDiffusionV2 API, as well as the APIs of the
    sub-components of StableDiffusionV2 (e.g. ImageEncoder, DiffusionModelV2)
    should be considered unstable at this point. We do not guarantee backward
    compatibility for future changes to these APIs.

    Stable Diffusion is a powerful image generation model that can be used,
    among other things, to generate pictures according to a short text
    description (called a "prompt").

    Arguments:
        img_height: Height of the images to generate, in pixels. Note that only
            multiples of 128 are supported; the value provided will be rounded
            to the nearest valid value. Default: 512.
        img_width: Width of the images to generate, in pixels. Note that only
            multiples of 128 are supported; the value provided will be rounded
            to the nearest valid value. Default: 512.
        jit_compile: Whether to compile the underlying models to XLA.
            This can lead to a significant speedup on some systems. Default: False.

    Example:
    ```python
    from keras_cv.models import StableDiffusionV2
    from PIL import Image

    model = StableDiffusionV2(img_height=512, img_width=512, jit_compile=True)
    img = model.text_to_image(
        prompt="A beautiful horse running through a field",
        batch_size=1,  # How many images to generate at once
        num_steps=25,  # Number of iterations (controls image quality)
        seed=123,  # Set this to always get the same image from the same prompt
    )
    Image.fromarray(img[0]).save("horse.png")
    print("saved at horse.png")
    ```

    References:
    - [About Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement)
    - [Original implementation](https://github.com/Stability-AI/stablediffusion)
    """

    def __init__(
        self,
        img_height=512,
        img_width=512,
        jit_compile=False,
    ):
        super().__init__(img_height, img_width, jit_compile)
        print(
            "By using this model checkpoint, you acknowledge that its usage is "
            "subject to the terms of the CreativeML Open RAIL++-M license at "
            "https://github.com/Stability-AI/stablediffusion/blob/main/LICENSE-MODEL"
        )

    @property
    def text_encoder(self):
        """text_encoder returns the text encoder with pretrained weights.

        Can be overridden for tasks like textual inversion where the text
        encoder needs to be modified.
        """
        if self._text_encoder is None:
            self._text_encoder = TextEncoderV2(MAX_PROMPT_LENGTH)
            if self.jit_compile:
                self._text_encoder.compile(jit_compile=True)
        return self._text_encoder

    @property
    def diffusion_model(self) -> tf.keras.Model:
        """diffusion_model returns the diffusion model with pretrained weights.

        Can be overridden for tasks where the diffusion model needs to be
        modified.
        """
        if self._diffusion_model is None:
            self._diffusion_model = DiffusionModelV2(
                self.img_height, self.img_width, MAX_PROMPT_LENGTH
            )
            if self.jit_compile:
                self._diffusion_model.compile(jit_compile=True)
        return self._diffusion_model
    @property
    def diffusion_model_ptp(self) -> tf.keras.Model:
        """diffusion_model_ptp returns the diffusion model with modifications
        for the Prompt-to-Prompt method.

        References
        ----------
        - "Prompt-to-Prompt Image Editing with Cross-Attention Control."
          Amir Hertz, Ron Mokady, Jay Tenenbaum, Kfir Aberman, Yael Pritch,
          Daniel Cohen-Or.
          https://arxiv.org/abs/2208.01626
        """
        if self._diffusion_model_ptp is None:
            if self._diffusion_model is None:
                self._diffusion_model_ptp = self.diffusion_model
            else:
                # Recompile to reset the graph; this also frees memory
                self._diffusion_model.compile(jit_compile=self.jit_compile)
                self._diffusion_model_ptp = self._diffusion_model

            # Add extra variables and callbacks
            ptp_utils.rename_cross_attention_layers(self._diffusion_model_ptp)
            ptp_utils.overwrite_forward_call(self._diffusion_model_ptp)
            ptp_utils.set_initial_tf_variables(self._diffusion_model_ptp)

        return self._diffusion_model_ptp
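
if __name__ == "__main__":
    # Minimal smoke-test sketch mirroring the docstring examples above. It
    # assumes `text_to_image` is available as documented, that the pretrained
    # weights can be downloaded, and that Pillow is installed.
    from PIL import Image

    model = StableDiffusion(img_height=512, img_width=512, jit_compile=False)
    img = model.text_to_image(
        prompt="A beautiful horse running through a field",
        batch_size=1,
        num_steps=25,
        seed=123,
    )
    Image.fromarray(img[0]).save("horse.png")
    print("saved at horse.png")
--------------------------------------------------------------------------------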