├── .gitignore
├── LICENSE
├── Makefile
├── README.md
├── assets
│   ├── bowl.svg
│   ├── chiwawa.svg
│   ├── doggy.svg
│   ├── pizza_example.svg
│   ├── teaser.svg
│   └── tree.svg
├── notebooks
│   ├── Attention Re-weighting.ipynb
│   └── Prompt Editing.ipynb
├── ptp_utils.py
├── requirements.txt
├── requirements_dev.txt
├── seq_aligner.py
├── setup.cfg
└── stable_diffusion.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | # Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | # poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | # pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 | ptp_dev/
131 |
132 | # Spyder project settings
133 | .spyderproject
134 | .spyproject
135 |
136 | # Rope project settings
137 | .ropeproject
138 |
139 | # mkdocs documentation
140 | /site
141 |
142 | # mypy
143 | .mypy_cache/
144 | .dmypy.json
145 | dmypy.json
146 |
147 | # Pyre type checker
148 | .pyre/
149 |
150 | # pytype static type analyzer
151 | .pytype/
152 |
153 | # Cython debug symbols
154 | cython_debug/
155 |
156 | # PyCharm
157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159 | # and can be added to the global gitignore or merged into this file. For a more nuclear
160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161 | #.idea/
162 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # .PHONY declares targets that are not associated with files of the same name,
2 | # so their recipes always run when the targets are invoked.
3 | .PHONY: init format format_check format_notebooks lint type_check all
4 |
5 | init:
6 | @echo "Create a development python environment"
7 | python3 -m pip install virtualenv
8 | python3 -m venv ptp_dev
9 | . ptp_dev/bin/activate && cat requirements.txt | xargs -n 1 pip install
10 | . ptp_dev/bin/activate && pip install -r requirements_dev.txt
11 |
12 | format:
13 | @echo "Format code according to isort"
14 | . ptp_dev/bin/activate && isort *.py
15 | @echo "Format code according to black"
16 | . ptp_dev/bin/activate && black *.py
17 |
18 | format_check:
19 | @echo "Check code format according to isort"
20 | . ptp_dev/bin/activate && isort *.py --check
21 | @echo "Check code format according to black"
22 | . ptp_dev/bin/activate && black *.py --check
23 |
24 | format_notebooks:
25 | @echo "Format notebooks according to isort"
26 | . ptp_dev/bin/activate && nbqa isort .
27 | @echo "Format notebooks according to black"
28 | . ptp_dev/bin/activate && nbqa black .
29 |
30 | lint:
31 | @echo "Linter check: Flake8"
32 | . ptp_dev/bin/activate && flake8 .
33 |
34 | type_check:
35 | @echo "Type-test check: mypy"
36 | . ptp_dev/bin/activate && mypy .
37 |
38 | all: format format_notebooks lint type_check
39 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Prompt-to-Prompt: TensorFlow Implementation
2 |
3 | [Open in Colab](https://colab.research.google.com/drive/1AoRRd-6oXtFEfx9Ff85GNuTcwSssb5zz?usp=sharing) [Model weights on Hugging Face](https://huggingface.co/fchollet/stable-diffusion)
4 |
5 |
6 |
7 | ### Unofficial Implementation of the paper Prompt-to-Prompt Image Editing with Cross Attention Control
8 |
9 | ![Prompt-to-Prompt teaser](assets/teaser.svg)
10 |
11 | [Link to the paper](https://arxiv.org/abs/2208.01626) | [Official PyTorch implementation](https://github.com/google/prompt-to-prompt/) | [Project page](https://prompt-to-prompt.github.io/)
12 |
13 | This repository contains the TensorFlow/Keras implementation of the paper "**[Prompt-to-Prompt Image Editing with Cross Attention Control](https://arxiv.org/abs/2208.01626)**".
14 |
15 | # 🚀 Quickstart
16 |
17 | Current state-of-the-art methods require the user to provide a spatial mask to localize the edit, which ignores the original structure and content within the masked region.
18 | The paper proposes a novel technique to edit the content generated by large-scale text-to-image models such as [DALL·E 2](https://openai.com/dall-e-2/), [Imagen](https://imagen.research.google/) or [Stable Diffusion](https://github.com/CompVis/stable-diffusion), **by only manipulating the text of the original prompt**.
19 |
20 | To achieve this result, the authors present the *Prompt-to-Prompt* framework, which comprises two functionalities:
21 |
22 | - **Prompt Editing**: the key idea is to inject the original prompt's cross-attention maps during the diffusion process, controlling which pixels attend to which tokens of the prompt text.
23 |
24 | - **Attention Re-weighting**: amplifies or attenuates the effect of a word in the generated image. This is done by assigning a weight to each token and then scaling the attention map associated with that token. It's a nice alternative to **negative prompting** and **multi-prompting**.
25 |
26 | ## :gear: Installation
27 |
28 | Install dependencies using the `requirements.txt`.
29 |
30 | ```bash
31 | pip install -r requirements.txt
32 | ```
33 |
34 | Essentially, you need [TensorFlow](https://github.com/tensorflow/tensorflow) and [Keras-cv](https://github.com/keras-team/keras-cv/) installed.
35 | ## 📚 Notebooks
36 |
37 | Try it yourself:
38 |
39 | - [**Prompt-to-Prompt: Prompt Editing** - Stable Diffusion](notebooks/Prompt%20Editing.ipynb) [Open in Colab](https://colab.research.google.com/drive/1AoRRd-6oXtFEfx9Ff85GNuTcwSssb5zz?usp=sharing)
40 | Notebook with examples for the *Prompt-to-Prompt* prompt editing approach for Stable Diffusion.
41 |
42 | - [**Prompt-to-Prompt: Attention Re-weighting** - Stable Diffusion](notebooks/Attention%20Re-weighting.ipynb) [Open in Colab](https://colab.research.google.com/drive/1UcIFg2Nd_LVaO3-UPPysCSVCNvbKjO11?usp=sharing)
43 | Notebook with examples for the *Prompt-to-Prompt* attention re-weighting approach for Stable Diffusion.
44 |
45 | # :dart: Prompt-to-Prompt Examples
46 |
47 | To start using the *Prompt-to-Prompt* framework, you first need to set up a TensorFlow [strategy](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy) for running computations across multiple devices (in case you have more than one).
48 |
49 | For example, after `import tensorflow as tf`, you can check the available hardware with:
50 |
51 | ```python
52 | gpus = tf.config.list_physical_devices("GPU")
53 | tpus = tf.config.list_physical_devices("TPU")
54 | print(f"Num GPUs Available: {len(gpus)} | Num TPUs Available: {len(tpus)}")
55 | ```
56 |
57 | And adjust the strategy according to your needs:
58 |
59 | ```python
60 | import tensorflow as tf
61 |
62 | # For running on multiple GPUs
63 | strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1", ...])
64 | # To get the default strategy
65 | strategy = tf.distribute.get_strategy()
66 | ...
67 | ```
68 |
69 | ## Prompt Editing
70 |
71 | Once the strategy is set, you can start generating images just like in [Keras-cv](https://github.com/keras-team/keras-cv/):
72 |
73 | ```python
74 | # Imports
75 | import tensorflow as tf
76 | from stable_diffusion import StableDiffusion
77 |
78 | generator = StableDiffusion(
79 | strategy=strategy,
80 | img_height=512,
81 | img_width=512,
82 | jit_compile=False,
83 | )
84 |
85 | # Generate text-to-image
86 | img = generator.text_to_image(
87 | prompt="a photo of a chiwawa with sunglasses and a bandana",
88 | num_steps=50,
89 | unconditional_guidance_scale=8,
90 | seed=5681067,
91 | batch_size=1,
92 | )
93 | # Generate Prompt-to-Prompt
94 | img_edit = generator.text_to_image_ptp(
95 | prompt="a photo of a chiwawa with sunglasses and a bandana",
96 | prompt_edit="a photo of a chiwawa with sunglasses and a pirate bandana",
97 | num_steps=50,
98 | unconditional_guidance_scale=8,
99 | cross_attn2_replace_steps_start=0.0,
100 | cross_attn2_replace_steps_end=1.0,
101 | cross_attn1_replace_steps_start=0.8,
102 | cross_attn1_replace_steps_end=1.0,
103 | seed=5681067,
104 | batch_size=1,
105 | )
106 | ```
107 |
108 | This generates the original and the pirate bandana images shown below. You can play around with the prompt and cross-attention arguments shown above, and many others!
109 |
110 | ![Chiwawa prompt editing example](assets/chiwawa.svg)
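Both methods return the generated images as a NumPy array, one image per batch element. Below is a minimal sketch for saving the results with Pillow, following the docstring examples in `stable_diffusion.py` (the file names are just placeholders):

```python
from PIL import Image

# `img` and `img_edit` come from the calls above; index 0 selects the first image in the batch.
Image.fromarray(img[0]).save("chiwawa_original.png")
Image.fromarray(img_edit[0]).save("chiwawa_pirate_bandana.png")
```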
111 |
112 | Another example of prompt editing where one can control the content of the basket just by replacing a couple of words in the prompt:
113 |
114 | ```python
115 | img_edit = generator.text_to_image_ptp(
116 | prompt="a photo of basket with apples",
117 | prompt_edit="a photo of basket with oranges",
118 | num_steps=50,
119 | unconditional_guidance_scale=8,
120 | cross_attn2_replace_steps_start=0.0,
121 | cross_attn2_replace_steps_end=1.0,
122 | cross_attn1_replace_steps_start=0.0,
123 | cross_attn1_replace_steps_end=1.0,
124 | seed=1597337,
125 | batch_size=1,
126 | )
127 | ```
128 |
129 | The image below showcases examples where only the word `apples` was replaced with other fruits or animals. Try changing `basket` to other containers (e.g. bowl or nest) and see what happens!
130 |
131 | 
132 |
133 | ## Attention Re-weighting
134 |
135 | To manipulate the relative importance of tokens, we've added an argument that can be passed to both the `text_to_image` and `text_to_image_ptp` methods: an array of per-token weights created with the `create_prompt_weights` method.
136 |
137 | For example, if you generated a pizza that doesn't have enough pineapple on it, you can edit the weights of your prompt:
138 |
139 | ```python
140 | prompt = "a photo of a pizza with pineapple"
141 | prompt_weights = generator.create_prompt_weights(prompt, [('pineapple', 2)])
142 | ```
143 |
144 | This creates an array filled with 1's, except at the `pineapple` token position, where the value is 2.
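As a quick sanity check, you can inspect the weights array. This is only a sketch and assumes `create_prompt_weights` returns one weight per token position:

```python
import numpy as np

# Expected under this assumption: all 1's except a 2 at the "pineapple" token position.
print(prompt_weights)
print(np.flatnonzero(prompt_weights != 1))  # position(s) of the re-weighted token(s)
```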
145 |
146 | To generate a pizza with more pineapple (yuck!), you just need to pass the `prompt_weights` variable to the `text_to_image` method:
147 |
148 | ```python
149 | img = generator.text_to_image(
150 | prompt="a photo of a pizza with pineapple",
151 | num_steps=50,
152 | unconditional_guidance_scale=8,
153 | prompt_weights=prompt_weights,
154 | seed=1234,
155 | batch_size=1,
156 | )
157 | ```
158 |
159 | ![Pizza with more pineapple](assets/pizza_example.svg)
160 |
161 | Now suppose you want to reduce the amount of blossom in a tree:
162 |
163 | ```python
164 | prompt = "A photo of a blossom tree"
165 | prompt_weights = generator.create_prompt_weights(prompt, [('blossom', -1)])
166 |
167 | img = generator.text_to_image(
168 | prompt="A photo of a blossom tree",
169 | num_steps=50,
170 | unconditional_guidance_scale=8,
171 | prompt_weights=prompt_weights,
172 | seed=1407923,
173 | batch_size=1,
174 | )
175 | ```
176 |
177 | Decreasing the weight associated with `blossom` will generate the following images.
178 |
179 | ![Blossom tree with reduced blossom](assets/tree.svg)
180 |
181 | ## Note about the cross-attention parameters
182 |
183 | For the prompt editing method, implemented in the function `text_to_image_ptp`, varying the parameters that indicate in which phase of the diffusion process the edited cross-attention maps get injected (e.g. `cross_attn2_replace_steps_start`, `cross_attn1_replace_steps_start`) may produce different results (image below).
184 |
185 | The cross-attention and prompt-weight hyperparameters should be tuned to the user's needs and desired outputs.
186 |
187 | 
188 |
189 | More info in [bloc97/CrossAttentionControl](https://github.com/bloc97/CrossAttentionControl#usage) and the [paper](https://arxiv.org/abs/2208.01626).
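For instance, here is a sketch that reuses the chiwawa example from above but narrows the cross-attention injection window; the values are purely illustrative and the outputs will differ:

```python
# Keep the original prompt's cross-attention maps only for roughly the first 40% of the steps.
img_edit = generator.text_to_image_ptp(
    prompt="a photo of a chiwawa with sunglasses and a bandana",
    prompt_edit="a photo of a chiwawa with sunglasses and a pirate bandana",
    num_steps=50,
    unconditional_guidance_scale=8,
    cross_attn2_replace_steps_start=0.0,
    cross_attn2_replace_steps_end=0.4,
    cross_attn1_replace_steps_start=0.8,
    cross_attn1_replace_steps_end=1.0,
    seed=5681067,
    batch_size=1,
)
```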
190 |
191 | # :ballot_box_with_check: TODO
192 |
193 | - [x] Add tutorials and Google Colabs.
194 | - [x] Add multi-batch support.
195 | - [ ] Add examples for Stable Diffusion 2.x.
196 |
197 | # 👨🎓 References
198 |
199 | - [keras-cv](https://github.com/keras-team/keras-cv/tree/master/keras_cv/models/generative/stable_diffusion) for the TensorFlow implementation of Stable Diffusion.
200 | - [bloc97/CrossAttentionControl](https://github.com/bloc97/CrossAttentionControl): unofficial implementation of the paper, from which the method `get_matching_sentence_tokens` and parts of the code logic were adapted.
201 | - [google/prompt-to-prompt](https://github.com/google/prompt-to-prompt): official implementation of the paper in PyTorch.
202 |
203 | # 🔬 Contributing
204 |
205 | Feel free to open an [issue](https://github.com/miguelcalado/prompt-to-prompt-tensorflow/issues) or create a [Pull Request](https://github.com/miguelcalado/prompt-to-prompt-tensorflow/pulls).
206 |
207 | For PRs, after implementing the changes please run the `Makefile` for formatting and linting the submitted code:
208 |
209 | - `make init`: to create a Python environment with all the developer packages (optional).
210 | - `make format`: to format the code.
211 | - `make lint`: to lint the code.
212 | - `make type_check`: to check for type hints.
213 | - `make all`: to run all the checks.
214 |
215 | # :scroll: License
216 |
217 | Licensed under the Apache License 2.0. See [LICENSE](LICENSE) to read it in full.
218 |
--------------------------------------------------------------------------------
/ptp_utils.py:
--------------------------------------------------------------------------------
1 | """Utility methods used to implement Prompt-to-Prompt paper in TensorFlow.
2 |
3 | References
4 | ----------
5 | - "Prompt-to-Prompt Image Editing with Cross-Attention Control."
6 | Amir Hertz, Ron Mokady, Jay Tenenbaum, Kfir Aberman, Yael Pritch, Daniel Cohen-Or.
7 | https://arxiv.org/abs/2208.01626
8 |
9 | Credits
10 | ----------
11 | - Unofficial implementation of the paper, where the method `get_matching_sentence_tokens`
12 | and code logic were used: [bloc97/CrossAttentionControl](https://github.com/bloc97/CrossAttentionControl).
13 | """
14 |
15 | from typing import Tuple
16 |
17 | import numpy as np
18 | import tensorflow as tf
19 | from keras_cv.models.stable_diffusion.diffusion_model import td_dot
20 | from tensorflow import keras
21 |
22 | import seq_aligner
23 |
24 | MAX_TEXT_LEN = 77
25 |
26 |
27 | def rename_cross_attention_layers(diff_model: tf.keras.Model):
28 | """Add suffix to the cross attention layers.
29 |
30 | This becomes useful when using the prompt editing method to save the
31 | attention maps and manipulate the control variables.
32 |
33 | Parameters
34 | ----------
35 | diff_model : tf.keras.Model
36 | Diffusion model.
37 |
38 |     Returns
39 |     -------
40 |     None
41 |         The cross-attention layers are renamed in place.
42 | """
43 | cross_attention_count = 0
44 | for submodule in diff_model.submodules:
45 | submodule_name = submodule.name
46 | if not cross_attention_count % 2 and "cross_attention" in submodule_name:
47 | submodule._name = f"{submodule_name}_attn1"
48 | cross_attention_count += 1
49 | elif cross_attention_count % 2 and "cross_attention" in submodule_name:
50 | submodule._name = f"{submodule_name}_attn2"
51 | cross_attention_count += 1
52 |
53 |
54 | def update_cross_attn_mode(
55 | diff_model: tf.keras.Model, mode: str, attn_suffix: str = "attn"
56 | ):
57 | """Update the mode control variable.
58 |
59 | Parameters
60 | ----------
61 | diff_model : tf.keras.Model
62 | Diffusion model.
63 | mode : str
64 | The mode parameter can take 3 values:
65 | - save: to save the attention map.
66 | - edit: to calculate the attention map with respect to the edited prompt.
67 | - unconditional: to perform the standard attention computations.
68 | attn_suffix : str, optional
69 | Suffix used to find the attention layer, by default "attn".
70 | """
71 | for submodule in diff_model.submodules:
72 | submodule_name = submodule.name
73 | if (
74 | "cross_attention" in submodule_name
75 | and attn_suffix in submodule_name.split("_")[-1]
76 | ):
77 | submodule.cross_attn_mode.assign(mode)
78 |
79 |
80 | def update_attn_weights_usage(diff_model: tf.keras.Model, use: bool):
81 |     """Update the use_prompt_weights control variable.
82 |
83 | Parameters
84 | ----------
85 | diff_model : tf.keras.Model
86 | Diffusion model.
87 | use : bool
88 | Whether to use the prompt weights.
89 | """
90 | for submodule in diff_model.submodules:
91 | submodule_name = submodule.name
92 | if (
93 | "cross_attention" in submodule_name
94 | and "attn2" in submodule_name.split("_")[-1]
95 | ):
96 | submodule.use_prompt_weights.assign(use)
97 |
98 |
99 | def add_attn_weights(diff_model: tf.keras.Model, prompt_weights: np.ndarray):
100 | """Assign the attention weights to the diffusion model's corresponding tf.variable.
101 |
102 | Parameters
103 | ----------
104 | diff_model : tf.keras.Model
105 | Diffusion model.
106 |     prompt_weights : np.ndarray
107 | Weights of the attention tokens.
108 | """
109 | for submodule in diff_model.submodules:
110 | submodule_name = submodule.name
111 | if (
112 | "cross_attention" in submodule_name
113 | and "attn2" in submodule_name.split("_")[-1]
114 | ):
115 | submodule.prompt_weights.assign(prompt_weights)
116 |
117 |
118 | def put_mask_dif_model(
119 | diff_model: tf.keras.Model, mask: np.ndarray, indices: np.ndarray
120 | ):
121 | """Assign the diffusion model's tf.variables with the passed mask and indices.
122 |
123 | Parameters
124 | ----------
125 | diff_model : tf.keras.Model
126 | Diffusion model.
127 | mask : np.ndarray
128 | Mask of the original and edited prompt overlap.
129 | indices : np.ndarray
130 | Indices of the original and edited prompt overlap.
131 | """
132 | for submodule in diff_model.submodules:
133 | submodule_name = submodule.name
134 | if (
135 | "cross_attention" in submodule_name
136 | and "attn2" in submodule_name.split("_")[-1]
137 | ):
138 | submodule.prompt_edit_mask.assign(mask)
139 | submodule.prompt_edit_indices.assign(indices)
140 |
141 |
142 | def get_matching_sentence_tokens(
143 | prompt, prompt_edit, tokenizer
144 | ) -> Tuple[np.ndarray, np.ndarray]:
145 | """Create the mask and indices of the overlap between the tokens of the original \
146 | prompt and the edited one.
147 |
148 | Original code source: https://github.com/bloc97/CrossAttentionControl/
149 |
150 | Parameters
151 | ----------
152 |     prompt : str
153 |         The original prompt text, encoded with the passed tokenizer.
154 |     prompt_edit : str
155 |         The edited prompt text, encoded with the passed tokenizer.
156 |
157 | Returns
158 | -------
159 | Tuple[np.ndarray, np.ndarray]
160 |         Mask and indices of the overlap between the original and edited prompt tokens.
161 | """
162 | tokens_conditional = tokenizer.encode(prompt)
163 | tokens_conditional_edit = tokenizer.encode(prompt_edit)
164 | mask, indices = seq_aligner.get_mapper(tokens_conditional, tokens_conditional_edit)
165 | return mask, indices
166 |
167 |
168 | def set_initial_tf_variables(diff_model: tf.keras.Model):
169 |     """Create the initial control variables used by the prompt editing method.
170 |
171 | Parameters
172 | ----------
173 | diff_model : tf.keras.Model
174 | Diffusion model.
175 | """
176 | for submodule in diff_model.submodules:
177 | submodule_name = submodule.name
178 | if "cross_attention" in submodule_name:
179 | # Set control variables
180 | submodule.cross_attn_mode = tf.Variable(
181 | "", dtype=tf.string, trainable=False
182 | )
183 | submodule.use_prompt_weights = tf.Variable(
184 | False, dtype=tf.bool, trainable=False
185 | )
186 | # Set array variables
187 | submodule.attn_map = tf.Variable(
188 | [], shape=tf.TensorShape(None), dtype=tf.float32, trainable=False
189 | )
190 | submodule.prompt_edit_mask = tf.Variable(
191 | [], shape=tf.TensorShape(None), dtype=tf.float32, trainable=False
192 | )
193 | submodule.prompt_edit_indices = tf.Variable(
194 | [], shape=tf.TensorShape(None), dtype=tf.int32, trainable=False
195 | )
196 | submodule.prompt_weights = tf.Variable(
197 | [], shape=tf.TensorShape(None), dtype=tf.float32, trainable=False
198 | )
199 |
200 |
201 | def reset_initial_tf_variables(diff_model: tf.keras.Model):
202 | """Reset the control variables to their default values.
203 |
204 | Parameters
205 | ----------
206 | diff_model : tf.keras.Model
207 | Diffusion model.
208 | """
209 | for submodule in diff_model.submodules:
210 | submodule_name = submodule.name
211 | if "cross_attention" in submodule_name:
212 | # Reset control variables
213 | submodule.cross_attn_mode.assign("")
214 | submodule.use_prompt_weights.assign(False)
215 | # Reset array variables
216 | submodule.attn_map.assign([])
217 | submodule.prompt_edit_mask.assign([])
218 | submodule.prompt_edit_indices.assign([])
219 | submodule.prompt_weights.assign([])
220 |
221 |
222 | def overwrite_forward_call(diff_model: tf.keras.Model):
223 | """Update the attention forward pass with a custom call method.
224 |
225 | Parameters
226 | ----------
227 | diff_model : tf.keras.Model
228 | Diffusion model.
229 | """
230 | for submodule in diff_model.submodules:
231 | submodule_name = submodule.name
232 | if "cross_attention" in submodule_name:
233 | # Overwrite forward pass method
234 | submodule.call = call_attn_edit.__get__(submodule)
235 |
236 |
237 | def call_attn_edit(self, inputs):
238 |     """Implementation of the custom attention forward pass used in the paper's method."""
239 | inputs, context = inputs
240 | context = inputs if context is None else context
241 | q, k, v = self.to_q(inputs), self.to_k(context), self.to_v(context)
242 | q = tf.reshape(q, (-1, inputs.shape[1], self.num_heads, self.head_size))
243 | k = tf.reshape(k, (-1, context.shape[1], self.num_heads, self.head_size))
244 | v = tf.reshape(v, (-1, context.shape[1], self.num_heads, self.head_size))
245 |
246 | q = tf.transpose(q, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)
247 | k = tf.transpose(k, (0, 2, 3, 1)) # (bs, num_heads, head_size, time)
248 | v = tf.transpose(v, (0, 2, 1, 3)) # (bs, num_heads, time, head_size)
249 |
250 | score = td_dot(q, k) * self.scale
251 | weights = keras.activations.softmax(score) # (bs, num_heads, time, time)
252 |
253 | # Method: Prompt Refinement
254 | if tf.equal(self.cross_attn_mode, "edit") and tf.not_equal(
255 | tf.size(self.prompt_edit_mask), 0
256 | ): # not empty
257 | weights_masked = tf.gather(self.attn_map, self.prompt_edit_indices, axis=-1)
258 | edit_weights = weights_masked * self.prompt_edit_mask + weights * (
259 | 1 - self.prompt_edit_mask
260 | )
261 | weights = tf.reshape(edit_weights, shape=tf.shape(weights))
262 |
263 | # Use the attention from the original prompt (M_t)
264 | if tf.equal(self.cross_attn_mode, "use_last"):
265 | weights = tf.reshape(self.attn_map, shape=tf.shape(weights))
266 |
267 | # Save attention
268 | if tf.equal(self.cross_attn_mode, "save"):
269 | self.attn_map.assign(weights)
270 |
271 |     # Method: Attention Re-weighting
272 | if tf.equal(self.use_prompt_weights, True) and tf.not_equal(
273 | tf.size(self.prompt_weights), 0
274 | ):
275 | attn_map_weighted = weights * self.prompt_weights
276 | weights = tf.reshape(attn_map_weighted, shape=tf.shape(weights))
277 |
278 | attn = td_dot(weights, v)
279 | attn = tf.transpose(attn, (0, 2, 1, 3)) # (bs, time, num_heads, head_size)
280 | out = tf.reshape(attn, (-1, inputs.shape[1], self.num_heads * self.head_size))
281 | return self.out_proj(out)
282 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==2.10.0
2 | tensorflow_datasets
3 | h5py==3.7.0
4 | Pillow==9.2.0
5 | tqdm==4.64.1
6 | ftfy==6.1.1
7 | regex==2022.9.13
8 | tensorflow-addons==0.17.1
9 | git+https://github.com/keras-team/keras-cv
--------------------------------------------------------------------------------
/requirements_dev.txt:
--------------------------------------------------------------------------------
1 | black
2 | isort
3 | mypy
4 | flake8
5 | nbqa
--------------------------------------------------------------------------------
/seq_aligner.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class ScoreParams:
5 | def __init__(self, gap, match, mismatch):
6 | self.gap = gap
7 | self.match = match
8 | self.mismatch = mismatch
9 |
10 | def mis_match_char(self, x, y):
11 | if x != y:
12 | return self.mismatch
13 | else:
14 | return self.match
15 |
16 |
17 | def get_mapper(x_seq: str, y_seq: str):
18 |     score = ScoreParams(0, 1, -1)  # gap=0, match=1, mismatch=-1
19 | matrix, trace_back = global_align(x_seq, y_seq, score)
20 | mapper_base = get_aligned_sequences(x_seq, y_seq, trace_back)[-1]
21 |     alphas = np.ones(77)  # 77 = max CLIP prompt length
22 | alphas[: mapper_base.shape[0]] = (mapper_base[:, 1] >= 0).astype(float)
23 | mapper = np.zeros(77, dtype=int)
24 | mapper[: mapper_base.shape[0]] = mapper_base[:, 1]
25 |     mapper[mapper_base.shape[0] :] = len(y_seq) + np.arange(77 - len(y_seq))  # identity-style mapping for the padded tail
26 | return alphas, mapper
27 |
28 |
29 | def global_align(x, y, score):
30 | matrix = get_matrix(len(x), len(y), score.gap)
31 | trace_back = get_traceback_matrix(len(x), len(y))
32 | for i in range(1, len(x) + 1):
33 | for j in range(1, len(y) + 1):
34 | left = matrix[i, j - 1] + score.gap
35 | up = matrix[i - 1, j] + score.gap
36 | diag = matrix[i - 1, j - 1] + score.mis_match_char(x[i - 1], y[j - 1])
37 | matrix[i, j] = max(left, up, diag)
38 | if matrix[i, j] == left:
39 | trace_back[i, j] = 1
40 | elif matrix[i, j] == up:
41 | trace_back[i, j] = 2
42 | else:
43 | trace_back[i, j] = 3
44 | return matrix, trace_back
45 |
46 |
47 | def get_matrix(size_x, size_y, gap):
48 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
49 | matrix[0, 1:] = (np.arange(size_y) + 1) * gap
50 | matrix[1:, 0] = (np.arange(size_x) + 1) * gap
51 | return matrix
52 |
53 |
54 | def get_traceback_matrix(size_x, size_y):
55 | matrix = np.zeros((size_x + 1, size_y + 1), dtype=np.int32)
56 |     matrix[0, 1:] = 1  # 1 = came from the left (gap in x)
57 |     matrix[1:, 0] = 2  # 2 = came from above (gap in y)
58 |     matrix[0, 0] = 4  # 4 = origin (stop)
59 | return matrix
60 |
61 |
62 | def get_aligned_sequences(x, y, trace_back):
63 | x_seq = []
64 | y_seq = []
65 | i = len(x)
66 | j = len(y)
67 | mapper_y_to_x = []
68 | while i > 0 or j > 0:
69 | if trace_back[i, j] == 3:
70 | x_seq.append(x[i - 1])
71 | y_seq.append(y[j - 1])
72 | i = i - 1
73 | j = j - 1
74 | mapper_y_to_x.append((j, i))
75 | elif trace_back[i][j] == 1:
76 | x_seq.append("-")
77 | y_seq.append(y[j - 1])
78 | j = j - 1
79 | mapper_y_to_x.append((j, -1))
80 | elif trace_back[i][j] == 2:
81 | x_seq.append(x[i - 1])
82 | y_seq.append("-")
83 | i = i - 1
84 | elif trace_back[i][j] == 4:
85 | break
86 | mapper_y_to_x.reverse()
87 | return x_seq, y_seq, np.array(mapper_y_to_x)
88 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | max-line-length = 125
3 | max-complexity = 18
4 | docstring-convention = numpy
5 | show-source = True
6 | statistics = True
7 | count = True
8 | # https://www.flake8rules.com/
9 | ignore =
10 | # Too many leading ```#``` for a block comment
11 | E266,
12 | # Line break occurred before a binary operator
13 | W503,
14 | # Missing docstring in public module
15 | D100,
16 | # Whitespace before ':'
17 | E203,
18 | extend-exclude =
19 | ptp_dev/
20 |
21 | [isort]
22 | profile=black
23 |
24 | [mypy]
25 | check_untyped_defs = True
26 | warn_unused_configs = True
27 |
28 | [mypy-matplotlib.*]
29 | ignore_missing_imports = True
30 |
31 | [mypy-numpy.*]
32 | ignore_missing_imports = True
33 |
34 | [mypy-tensorflow.*]
35 | ignore_missing_imports = True
36 |
37 | [mypy-tqdm.*]
38 | ignore_missing_imports = True
39 |
40 | [mypy-keras_cv.*]
41 | ignore_missing_imports = True
--------------------------------------------------------------------------------
/stable_diffusion.py:
--------------------------------------------------------------------------------
1 | """TensorFlow/Keras implementation of Stable Diffusion and Prompt-to-Prompt papers.
2 |
3 | References
4 | ----------
5 | - "High-Resolution Image Synthesis With Latent Diffusion Models"
6 | Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bjorn
7 | Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)
8 | https://arxiv.org/abs/2112.10752
9 | - "Prompt-to-Prompt Image Editing with Cross-Attention Control."
10 | Amir Hertz, Ron Mokady, Jay Tenenbaum, Kfir Aberman, Yael Pritch, Daniel Cohen-Or.
11 | https://arxiv.org/abs/2208.01626
12 |
13 | Credits
14 | ----------
15 | - [keras-cv](https://github.com/keras-team/keras-cv/tree/master/keras_cv/models/generative/stable_diffusion) \
16 | for the TensorFlow/Keras implementation of Stable Diffusion.
17 | - [bloc97/CrossAttentionControl](https://github.com/bloc97/CrossAttentionControl) unofficial implementation of \
18 | the paper, where the method `get_matching_sentence_tokens` and code logic were used.
19 | - [google/prompt-to-prompt](https://github.com/google/prompt-to-prompt) official implementation of the paper in PyTorch.
20 | """
21 |
22 | import math
23 | from typing import List, Optional, Tuple, Union
24 |
25 | import numpy as np
26 | import tensorflow as tf
27 | from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
28 | from keras_cv.models.stable_diffusion.constants import (
29 | _ALPHAS_CUMPROD,
30 | _UNCONDITIONAL_TOKENS,
31 | )
32 | from keras_cv.models.stable_diffusion.decoder import Decoder
33 | from keras_cv.models.stable_diffusion.diffusion_model import (
34 | DiffusionModel,
35 | DiffusionModelV2,
36 | )
37 | from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
38 | from keras_cv.models.stable_diffusion.text_encoder import TextEncoder, TextEncoderV2
39 | from tensorflow import keras
40 |
41 | import ptp_utils
42 |
43 | MAX_PROMPT_LENGTH = 77
44 | NUM_TRAIN_TIMESTEPS = 1000
45 |
46 |
47 | class StableDiffusionBase:
48 | """Implementation of Stable Diffusion and Prompt-to-Prompt papers in TensorFlow/Keras.
49 |
50 | Parameters
51 | ----------
52 | strategy : tf.distribute
53 | TensorFlow strategy for running computations across multiple devices.
54 | img_height : int, optional
55 | Image height, by default 512
56 | img_width : int, optional
57 | Image width, by default 512
58 | jit_compile : bool, optional
59 | Flag to compile the models to XLA, by default False.
60 | download_weights : bool, optional
61 | Flag to download the models weights, by default True.
62 |
63 | Examples
64 | --------
65 | >>> import tensorflow as tf
66 | >>> from PIL import Image
67 | >>> from stable_diffusion import StableDiffusion
68 | >>> strategy = tf.distribute.get_strategy() # To use only one GPU
69 | >>> generator = StableDiffusion(
70 | strategy=strategy,
71 | img_height=512,
72 | img_width=512,
73 | jit_compile=False,
74 | )
75 | >>> img = generator.text_to_image(
76 | prompt="teddy bear with sunglasses relaxing in a pool",
77 | num_steps=50,
78 | unconditional_guidance_scale=8,
79 | seed=3345435,
80 | batch_size=1,
81 | )
82 | >>> Image.fromarray(img[0]).save("original_prompt.png")
83 |
84 |     Now let's edit the image to customize the teddy bear's sunglasses.
85 |
86 | >>> img = generator.text_to_image_ptp(
87 | prompt="teddy bear with sunglasses relaxing in a pool",
88 | prompt_edit="teddy bear with heart-shaped red colored sunglasses relaxing in a pool",
89 | num_steps=50,
90 | unconditional_guidance_scale=8,
91 | cross_attn2_replace_steps_start=0.0,
92 | cross_attn2_replace_steps_end=1.0,
93 | cross_attn1_replace_steps_start=1.0,
94 | cross_attn1_replace_steps_end=1.0,
95 | seed=3345435,
96 |         batch_size=1,
97 | )
98 | >>> Image.fromarray(img[0]).save("edited_prompt.png")
99 | """
100 |
101 | def __init__(
102 | self,
103 | img_height: int = 512,
104 | img_width: int = 512,
105 | jit_compile: bool = False,
106 | ):
107 |
108 | # UNet requires multiples of 2**7 = 128
109 | img_height = round(img_height / 128) * 128
110 | img_width = round(img_width / 128) * 128
111 | self.img_height = img_height
112 | self.img_width = img_width
113 |
114 | # lazy initialize the component models and the tokenizer
115 | self._image_encoder = None
116 | self._text_encoder = None
117 | self._diffusion_model = None
118 | self._diffusion_model_ptp = None
119 | self._decoder = None
120 | self._tokenizer = None
121 |
122 | self.jit_compile = jit_compile
123 |
124 | def text_to_image(
125 | self,
126 | prompt: str,
127 | negative_prompt: Optional[str] = None,
128 | num_steps: int = 50,
129 | unconditional_guidance_scale: float = 7.5,
130 | batch_size: int = 1,
131 | seed: Optional[int] = None,
132 | ) -> np.ndarray:
133 | """Generate an image based on a prompt text.
134 |
135 | Parameters
136 | ----------
137 | prompt : str
138 | Text containing the information for the model to generate.
139 | negative_prompt : str
140 | A string containing information to negatively guide the image
141 | generation (e.g. by removing or altering certain aspects of the
142 | generated image).
143 | num_steps : int, optional
144 | Number of diffusion steps (controls image quality), by default 50.
145 | unconditional_guidance_scale : float, optional
146 | Controls how closely the image should adhere to the prompt, by default 7.5.
147 | batch_size : int, optional
148 | Batch size (number of images to generate), by default 1.
149 | seed : Optional[int], optional
150 | Number to seed the random noise, by default None.
151 |
152 | Returns
153 | -------
154 | np.ndarray
155 | Generated image.
156 | """
157 |
158 | # Tokenize and encode prompt
159 | encoded_text = self.encode_text(prompt)
160 |
161 | conditional_context = self._expand_tensor(encoded_text, batch_size)
162 |
163 | if negative_prompt is None:
164 | unconditional_context = tf.repeat(
165 | self._get_unconditional_context(), batch_size, axis=0
166 | )
167 | else:
168 | unconditional_context = self.encode_text(negative_prompt)
169 | unconditional_context = self._expand_tensor(
170 | unconditional_context, batch_size
171 | )
172 |
173 | # Get initial random noise
174 | latent = self._get_initial_diffusion_noise(batch_size, seed)
175 |
176 | # Scheduler
177 | timesteps = tf.range(1, 1000, 1000 // num_steps)
178 |
179 | # Get Initial parameters
180 | alphas, alphas_prev = self._get_initial_alphas(timesteps)
181 |
182 | progbar = keras.utils.Progbar(len(timesteps))
183 | iteration = 0
184 | # Diffusion stage
185 | for index, timestep in list(enumerate(timesteps))[::-1]:
186 |
187 | t_emb = self._get_timestep_embedding(timestep, batch_size)
188 |
189 | # Predict the unconditional noise residual
190 | unconditional_latent = self.diffusion_model.predict_on_batch(
191 | [latent, t_emb, unconditional_context]
192 | )
193 |
194 | # Predict the conditional noise residual
195 | conditional_latent = self.diffusion_model.predict_on_batch(
196 | [latent, t_emb, conditional_context]
197 | )
198 |
199 | # Perform guidance
200 | e_t = unconditional_latent + unconditional_guidance_scale * (
201 | conditional_latent - unconditional_latent
202 | )
203 |
204 | a_t, a_prev = alphas[index], alphas_prev[index]
205 | latent = self._get_x_prev(latent, e_t, a_t, a_prev)
206 |
207 | iteration += 1
208 | progbar.update(iteration)
209 |
210 | # Decode image
211 | img = self._get_decoding_stage(latent)
212 |
213 | return img
214 |
215 | def encode_text(self, prompt):
216 | """Encodes a prompt into a latent text encoding.
217 | The encoding produced by this method should be used as the
218 | `encoded_text` parameter of `StableDiffusion.generate_image`. Encoding
219 | text separately from generating an image can be used to arbitrarily
220 |         modify the text encoding prior to image generation, e.g. for walking
221 | between two prompts.
222 | Args:
223 | prompt: a string to encode, must be 77 tokens or shorter.
224 | Example:
225 | ```python
226 | from keras_cv.models import StableDiffusion
227 | model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
228 | encoded_text = model.encode_text("Tacos at dawn")
229 | img = model.generate_image(encoded_text)
230 | ```
231 | """
232 | # Tokenize prompt (i.e. starting context)
233 | inputs = self.tokenizer.encode(prompt)
234 | if len(inputs) > MAX_PROMPT_LENGTH:
235 | raise ValueError(
236 | f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
237 | )
238 | phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs))
239 | phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
240 |
241 | context = self.text_encoder.predict_on_batch([phrase, self._get_pos_ids()])
242 |
243 | return context
244 |
245 | def text_to_image_ptp(
246 | self,
247 | prompt: str,
248 | prompt_edit: str,
249 | method: str,
250 | self_attn_steps: Union[float, Tuple[float, float]],
251 | cross_attn_steps: Union[float, Tuple[float, float]],
252 | attn_edit_weights: np.ndarray = np.array([]),
253 | negative_prompt: Optional[str] = None,
254 | num_steps: int = 50,
255 | unconditional_guidance_scale: float = 7.5,
256 | batch_size: int = 1,
257 | seed: Optional[int] = None,
258 | ) -> np.ndarray:
259 | """Generate an image based on the Prompt-to-Prompt editing method.
260 |
261 | Edit a generated image controlled only through text.
262 | Paper: https://arxiv.org/abs/2208.01626
263 |
264 | Parameters
265 | ----------
266 | prompt : str
267 | Text containing the information for the model to generate.
268 | prompt_edit : str
269 | Second prompt used to control the edit of the generated image.
270 | method : str
271 |             Prompt-to-Prompt method to choose. Can be ['refine', 'replace', 'reweight'].
272 | self_attn_steps : Union[float, Tuple[float, float]]
273 | Specifies at which step of the start of the diffusion process it replaces
274 | the self attention maps with the ones produced by the edited prompt.
275 | cross_attn_steps : Union[float, Tuple[float, float]]
276 | Specifies at which step of the start of the diffusion process it replaces
277 | the cross attention maps with the ones produced by the edited prompt.
278 |         attn_edit_weights : np.ndarray, optional
279 | Set of weights for each edit prompt token.
280 | This is used for manipulating the importance of the edit prompt tokens,
281 | increasing or decreasing the importance assigned to each word.
282 | negative_prompt : Optional[str] = None
283 | A string containing information to negatively guide the image
284 | generation (e.g. by removing or altering certain aspects of the
285 | generated image).
286 | num_steps : int, optional
287 | Number of diffusion steps (controls image quality), by default 50.
288 | unconditional_guidance_scale : float, optional
289 | Controls how closely the image should adhere to the prompt, by default 7.5.
290 | batch_size : int, optional
291 | Batch size (number of images to generate), by default 1.
292 | seed : Optional[int], optional
293 | Number to seed the random noise, by default None.
294 |
295 | Returns
296 | -------
297 | np.ndarray
298 | Generated image with edited prompt method.
299 |
300 | Examples
301 | --------
302 | >>> import tensorflow as tf
303 | >>> from PIL import Image
304 | >>> from stable_diffusion import StableDiffusion
305 | >>> strategy = tf.distribute.get_strategy() # To use only one GPU
306 | >>> generator = StableDiffusion(
307 | strategy=strategy,
308 | img_height=512,
309 | img_width=512,
310 | jit_compile=False,
311 | )
312 |
313 |         Edit the original generated image by making the sunglasses heart-shaped and red.
314 |
315 | >>> img = generator.text_to_image_ptp(
316 | prompt="teddy bear with sunglasses relaxing in a pool",
317 | prompt_edit="teddy bear with heart-shaped red colored sunglasses relaxing in a pool",
318 | num_steps=50,
319 | unconditional_guidance_scale=8,
320 | self_attn_steps=0.0,
321 | cross_attn_steps=1.0,
322 | seed=3345435,
323 | batch_size=1,
324 | )
325 | >>> Image.fromarray(img[0]).save("edited_prompt.png")
326 | """
327 |
328 | # Prompt-to-Prompt: check inputs
329 | if isinstance(self_attn_steps, float):
330 | self_attn_steps = (0.0, self_attn_steps)
331 | if isinstance(cross_attn_steps, float):
332 | cross_attn_steps = (0.0, cross_attn_steps)
333 |
334 | # Tokenize and encode prompt
335 | encoded_text = self.encode_text(prompt)
336 | conditional_context = self._expand_tensor(encoded_text, batch_size)
337 |
338 | # Tokenize and encode edit prompt
339 | encoded_text_edit = self.encode_text(prompt_edit)
340 | conditional_context_edit = self._expand_tensor(encoded_text_edit, batch_size)
341 |
342 | if negative_prompt is None:
343 | unconditional_context = tf.repeat(
344 | self._get_unconditional_context(), batch_size, axis=0
345 | )
346 | else:
347 | unconditional_context = self.encode_text(negative_prompt)
348 | unconditional_context = self._expand_tensor(
349 | unconditional_context, batch_size
350 | )
351 |
352 | if method == "refine":
353 |             # Get the mask and indices of the difference between the original prompt tokens and the edited ones
354 | mask, indices = ptp_utils.get_matching_sentence_tokens(
355 | prompt, prompt_edit, self.tokenizer
356 | )
357 | # Add the mask and indices to the diffusion model
358 | ptp_utils.put_mask_dif_model(self.diffusion_model_ptp, mask, indices)
359 |
360 | # Update prompt weights variable
361 | if attn_edit_weights.size:
362 | ptp_utils.add_attn_weights(
363 | diff_model=self.diffusion_model_ptp, prompt_weights=attn_edit_weights
364 | )
365 |
366 | # Get initial random noise
367 | latent = self._get_initial_diffusion_noise(batch_size, seed)
368 |
369 | # Scheduler
370 | timesteps = tf.range(1, 1000, 1000 // num_steps)
371 |
372 | # Get Initial parameters
373 | alphas, alphas_prev = self._get_initial_alphas(timesteps)
374 |
375 | progbar = keras.utils.Progbar(len(timesteps))
376 | iteration = 0
377 | # Diffusion stage
378 | for index, timestep in list(enumerate(timesteps))[::-1]:
379 |
380 | t_emb = self._get_timestep_embedding(timestep, batch_size)
381 |
382 |             # Progress through the sampling loop: ~0 at the first step, ~1 at the last
383 | t_scale = 1 - (timestep / NUM_TRAIN_TIMESTEPS)
384 |
385 | # Update Cross-Attention mode to 'unconditional'
386 | ptp_utils.update_cross_attn_mode(
387 | diff_model=self.diffusion_model_ptp, mode="unconditional"
388 | )
389 |
390 | # Predict the unconditional noise residual
391 | unconditional_latent = self.diffusion_model_ptp.predict_on_batch(
392 | [latent, t_emb, unconditional_context]
393 | )
394 |
395 | # Save last cross attention activations
396 | ptp_utils.update_cross_attn_mode(
397 | diff_model=self.diffusion_model_ptp, mode="save"
398 | )
399 | # Predict the conditional noise residual
400 | _ = self.diffusion_model_ptp.predict_on_batch(
401 | [latent, t_emb, conditional_context]
402 | )
403 |
404 | # Edit the Cross-Attention layer activations
405 | if cross_attn_steps[0] <= t_scale <= cross_attn_steps[1]:
406 | if method == "replace":
407 | # Use cross attention from the original prompt (M_t)
408 | ptp_utils.update_cross_attn_mode(
409 | diff_model=self.diffusion_model_ptp,
410 | mode="use_last",
411 | attn_suffix="attn2",
412 | )
413 | elif method == "refine":
414 |                     # Use cross attention aligned via the mapping function A(j)
415 | ptp_utils.update_cross_attn_mode(
416 | diff_model=self.diffusion_model_ptp,
417 | mode="edit",
418 | attn_suffix="attn2",
419 | )
420 | if method == "reweight" or attn_edit_weights.size:
421 | # Use the parsed weights on the edited prompt
422 | ptp_utils.update_attn_weights_usage(
423 | diff_model=self.diffusion_model_ptp, use=True
424 | )
425 |
426 | else:
427 | # Use cross attention from the edited prompt (M^*_t)
428 | ptp_utils.update_cross_attn_mode(
429 | diff_model=self.diffusion_model_ptp,
430 | mode="injection",
431 | attn_suffix="attn2",
432 | )
433 |
434 | # Edit the self-Attention layer activations
435 | if self_attn_steps[0] <= t_scale <= self_attn_steps[1]:
436 | # Use self attention from the original prompt (M_t)
437 | ptp_utils.update_cross_attn_mode(
438 | diff_model=self.diffusion_model_ptp,
439 | mode="use_last",
440 | attn_suffix="attn1",
441 | )
442 | else:
443 | # Use self attention from the edited prompt (M^*_t)
444 | ptp_utils.update_cross_attn_mode(
445 | diff_model=self.diffusion_model_ptp,
446 | mode="injection",
447 | attn_suffix="attn1",
448 | )
449 |
450 | # Predict the edited conditional noise residual
451 | conditional_latent_edit = self.diffusion_model_ptp.predict_on_batch(
452 | [latent, t_emb, conditional_context_edit],
453 | )
454 |
455 |             # Disable the prompt-weight usage flag so the weights are not applied to later predictions
456 | if attn_edit_weights.size:
457 | ptp_utils.update_attn_weights_usage(
458 | diff_model=self.diffusion_model_ptp, use=False
459 | )
460 |
461 |             # Perform classifier-free guidance with the edited conditional prediction
462 | e_t = unconditional_latent + unconditional_guidance_scale * (
463 | conditional_latent_edit - unconditional_latent
464 | )
465 |
466 | a_t, a_prev = alphas[index], alphas_prev[index]
467 | latent = self._get_x_prev(latent, e_t, a_t, a_prev)
468 |
469 | iteration += 1
470 | progbar.update(iteration)
471 |
472 | # Decode image
473 | img = self._get_decoding_stage(latent)
474 |
475 | # Reset control variables
476 | ptp_utils.reset_initial_tf_variables(self.diffusion_model_ptp)
477 |
478 | return img
479 |
480 | def _get_unconditional_context(self):
481 | unconditional_tokens = tf.convert_to_tensor(
482 | [_UNCONDITIONAL_TOKENS], dtype=tf.int32
483 | )
484 | unconditional_context = self.text_encoder.predict_on_batch(
485 | [unconditional_tokens, self._get_pos_ids()]
486 | )
487 |
488 | return unconditional_context
489 |
490 | def tokenize_prompt(self, prompt: str) -> tf.Tensor:
491 | """Tokenize a phrase prompt.
492 |
493 | Parameters
494 | ----------
495 |         prompt : str
496 |             The prompt string to tokenize; must encode to 77 tokens or fewer.
497 |             Shorter prompts are padded up to MAX_PROMPT_LENGTH with the
498 |             end-of-text token.
499 |
500 |         Returns
501 |         -------
502 |         tf.Tensor
503 |             Tensor of token ids with shape (1, MAX_PROMPT_LENGTH).
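
        Example
        -------
        A minimal sketch, continuing the `generator` example above:

        >>> tokens = generator.tokenize_prompt("teddy bear with sunglasses")
        >>> tokens.shape
        TensorShape([1, 77])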
504 | """
505 | inputs = self.tokenizer.encode(prompt)
506 | if len(inputs) > MAX_PROMPT_LENGTH:
507 | raise ValueError(
508 | f"Prompt is too long (should be <= {MAX_PROMPT_LENGTH} tokens)"
509 | )
510 | phrase = inputs + [49407] * (MAX_PROMPT_LENGTH - len(inputs))
511 | phrase = tf.convert_to_tensor([phrase], dtype=tf.int32)
512 | return phrase
513 |
514 | def create_prompt_weights(
515 | self, prompt: str, prompt_weights: List[Tuple[str, float]]
516 | ) -> np.ndarray:
517 | """Create an array of weights for each prompt token.
518 |
519 | This is used for manipulating the importance of the prompt tokens,
520 | increasing or decreasing the importance assigned to each word.
521 |
522 | Parameters
523 | ----------
524 | prompt : str
525 | The prompt string to tokenize, must be 77 tokens or shorter.
526 |         prompt_weights : List[Tuple[str, float]]
527 |             A list of (word, weight) pairs. Each listed word in the prompt has its
528 |             attention weight set to the paired value; all other tokens keep a
529 |             weight of 1.0.
530 |
531 | Returns
532 | -------
533 | np.ndarray
534 | Array of weights to control the importance of each prompt token.
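
        Example
        -------
        A minimal sketch, continuing the `generator` example above; it doubles the
        attention weight of the word "sunglasses":

        >>> weights = generator.create_prompt_weights(
        ...     "teddy bear with sunglasses", [("sunglasses", 2.0)]
        ... )
        >>> weights.shape
        (77,)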
535 | """
536 |
537 | # Initialize the weights to 1.
538 | weights = np.ones(MAX_PROMPT_LENGTH)
539 |
540 | # Get the prompt tokens
541 | tokens = self.tokenize_prompt(prompt)
542 |
543 | # Extract the new weights and tokens
544 | edit_weights = [weight for word, weight in prompt_weights]
545 | edit_tokens = [
546 | self.tokenizer.encode(word)[1:-1] for word, weight in prompt_weights
547 | ]
548 |
549 | # Get the indexes of the tokens
550 | index_edit_tokens = np.in1d(tokens, edit_tokens).nonzero()[0]
551 |
552 | # Replace the original weight values
553 | weights[index_edit_tokens] = edit_weights
554 | return weights
555 |
556 | def _expand_tensor(self, text_embedding, batch_size):
557 |         """Expands a text embedding to the given batch size by repeating it along the batch axis."""
558 | text_embedding = tf.squeeze(text_embedding)
559 | if text_embedding.shape.rank == 2:
560 | text_embedding = tf.repeat(
561 | tf.expand_dims(text_embedding, axis=0), batch_size, axis=0
562 | )
563 | return text_embedding
564 |
565 | def _get_initial_alphas(self, timesteps):
566 |
567 | alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
568 | alphas_prev = [1.0] + alphas[:-1]
569 |
570 | return alphas, alphas_prev
571 |
572 | def _get_initial_diffusion_noise(self, batch_size: int, seed: Optional[int]):
573 | return tf.random.normal(
574 | (batch_size, self.img_height // 8, self.img_width // 8, 4), seed=seed
575 | )
576 |
577 | def _get_timestep_embedding(self, timestep, batch_size, dim=320, max_period=10000):
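        # Sinusoidal timestep embedding: the first `dim // 2` channels are cosines and the
        # remaining channels sines of the timestep scaled by geometrically spaced
        # frequencies, mirroring the transformer positional-encoding construction.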
578 | half = dim // 2
579 | freqs = tf.math.exp(
580 | -math.log(max_period) * tf.range(0, half, dtype=tf.float32) / half
581 | )
582 | args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
583 | embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
584 | embedding = tf.reshape(embedding, [1, -1])
585 | return tf.repeat(embedding, batch_size, axis=0)
586 |
587 | def _get_decoding_stage(self, latent):
588 | decoded = self.decoder.predict_on_batch(latent)
589 | decoded = ((decoded + 1) / 2) * 255
590 | return np.clip(decoded, 0, 255).astype("uint8")
591 |
592 | def _get_x_prev(self, x, e_t, a_t, a_prev):
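        # Deterministic DDIM update (eta = 0): estimate the denoised sample pred_x0 from
        # the current latent and predicted noise e_t, then step to the previous timestep
        # by re-adding the noise direction scaled by sqrt(1 - a_prev).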
593 | sqrt_one_minus_at = math.sqrt(1 - a_t)
594 | pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)
595 | # Direction pointing to x_t
596 | dir_xt = math.sqrt(1.0 - a_prev) * e_t
597 | x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt
598 | return x_prev
599 |
600 | @staticmethod
601 | def _get_pos_ids():
602 | return tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)
603 |
604 | @property
605 | def image_encoder(self):
606 | """image_encoder returns the VAE Encoder with pretrained weights.
607 | Usage:
608 | ```python
609 | sd = keras_cv.models.StableDiffusion()
610 |         my_image = np.ones((1, 512, 512, 3))  # batch of one 512x512 RGB image
611 | latent_representation = sd.image_encoder.predict(my_image)
612 | ```
613 | """
614 | if self._image_encoder is None:
615 | self._image_encoder = ImageEncoder(self.img_height, self.img_width)
616 | if self.jit_compile:
617 | self._image_encoder.compile(jit_compile=True)
618 | return self._image_encoder
619 |
620 | @property
621 | def text_encoder(self):
622 | pass
623 |
624 | @property
625 | def diffusion_model(self):
626 | pass
627 |
628 | @property
629 | def decoder(self):
630 | """decoder returns the diffusion image decoder model with pretrained weights.
631 |         Can be overridden for tasks where the decoder needs to be modified.
632 | """
633 | if self._decoder is None:
634 | self._decoder = Decoder(self.img_height, self.img_width)
635 | if self.jit_compile:
636 | self._decoder.compile(jit_compile=True)
637 | return self._decoder
638 |
639 | @property
640 | def tokenizer(self):
641 | """tokenizer returns the tokenizer used for text inputs.
642 |         Can be overridden for tasks like textual inversion where the tokenizer needs to be modified.
643 | """
644 | if self._tokenizer is None:
645 | self._tokenizer = SimpleTokenizer()
646 | return self._tokenizer
647 |
648 |
649 | class StableDiffusion(StableDiffusionBase):
650 | """Keras implementation of Stable Diffusion.
651 |
652 | Note that the StableDiffusion API, as well as the APIs of the sub-components
653 | of StableDiffusion (e.g. ImageEncoder, DiffusionModel) should be considered
654 |     unstable at this point. We do not guarantee backward compatibility for
655 | future changes to these APIs.
656 | Stable Diffusion is a powerful image generation model that can be used,
657 | among other things, to generate pictures according to a short text description
658 | (called a "prompt").
659 | Arguments:
660 |         img_height: Height of the images to generate, in pixels. Note that only
661 | multiples of 128 are supported; the value provided will be rounded
662 | to the nearest valid value. Default: 512.
663 |         img_width: Width of the images to generate, in pixels. Note that only
664 | multiples of 128 are supported; the value provided will be rounded
665 | to the nearest valid value. Default: 512.
666 | jit_compile: Whether to compile the underlying models to XLA.
667 | This can lead to a significant speedup on some systems. Default: False.
668 | Example:
669 | ```python
670 | from keras_cv.models import StableDiffusion
671 | from PIL import Image
672 | model = StableDiffusion(img_height=512, img_width=512, jit_compile=True)
673 | img = model.text_to_image(
674 | prompt="A beautiful horse running through a field",
675 | batch_size=1, # How many images to generate at once
676 | num_steps=25, # Number of iterations (controls image quality)
677 | seed=123, # Set this to always get the same image from the same prompt
678 | )
679 | Image.fromarray(img[0]).save("horse.png")
680 | print("saved at horse.png")
681 | ```
682 | References:
683 | - [About Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement)
684 | - [Original implementation](https://github.com/CompVis/stable-diffusion)
685 | """
686 |
687 | def __init__(
688 | self,
689 | img_height=512,
690 | img_width=512,
691 | jit_compile=False,
692 | ):
693 | super().__init__(img_height, img_width, jit_compile)
694 | print(
695 | "By using this model checkpoint, you acknowledge that its usage is "
696 | "subject to the terms of the CreativeML Open RAIL-M license at "
697 | "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE"
698 | )
699 |
700 | @property
701 | def text_encoder(self):
702 | """text_encoder returns the text encoder with pretrained weights.
703 |         Can be overridden for tasks like textual inversion where the text encoder
704 | needs to be modified.
705 | """
706 | if self._text_encoder is None:
707 | self._text_encoder = TextEncoder(MAX_PROMPT_LENGTH)
708 | if self.jit_compile:
709 | self._text_encoder.compile(jit_compile=True)
710 | return self._text_encoder
711 |
712 | @property
713 | def diffusion_model(self) -> tf.keras.Model:
714 | """diffusion_model returns the diffusion model with pretrained weights.
715 |         Can be overridden for tasks where the diffusion model needs to be modified.
716 | """
717 | if self._diffusion_model is None:
718 | self._diffusion_model = DiffusionModel(
719 | self.img_height, self.img_width, MAX_PROMPT_LENGTH
720 | )
721 | if self.jit_compile:
722 | self._diffusion_model.compile(jit_compile=True)
723 | return self._diffusion_model
724 |
725 | @property
726 | def diffusion_model_ptp(self) -> tf.keras.Model:
727 | """diffusion_model_ptp returns the diffusion model with modifications for the Prompt-to-Prompt method.
728 |
729 | References
730 | ----------
731 | - "Prompt-to-Prompt Image Editing with Cross-Attention Control."
732 | Amir Hertz, Ron Mokady, Jay Tenenbaum, Kfir Aberman, Yael Pritch, Daniel Cohen-Or.
733 | https://arxiv.org/abs/2208.01626
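
        Example
        -------
        A minimal sketch; accessing the property builds (or reuses) the diffusion model
        and instruments its attention layers for Prompt-to-Prompt editing:

        >>> model = StableDiffusion(img_height=512, img_width=512)
        >>> ptp_unet = model.diffusion_model_ptp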
734 | """
735 | if self._diffusion_model_ptp is None:
736 | if self._diffusion_model is None:
737 | self._diffusion_model_ptp = self.diffusion_model
738 | else:
739 | # Reset the graph and add/overwrite variables and forward calls
740 | self._diffusion_model.compile(jit_compile=self.jit_compile)
741 | self._diffusion_model_ptp = self._diffusion_model
742 |
743 | # Add extra variables and callbacks
744 | ptp_utils.rename_cross_attention_layers(self._diffusion_model_ptp)
745 | ptp_utils.overwrite_forward_call(self._diffusion_model_ptp)
746 | ptp_utils.set_initial_tf_variables(self._diffusion_model_ptp)
747 |
748 | return self._diffusion_model_ptp
749 |
750 |
751 | class StableDiffusionV2(StableDiffusionBase):
752 | """Keras implementation of Stable Diffusion v2.
753 |     Note that the StableDiffusionV2 API, as well as the APIs of the sub-components
754 |     of StableDiffusionV2 (e.g. ImageEncoder, DiffusionModelV2) should be considered
755 |     unstable at this point. We do not guarantee backward compatibility for
756 | future changes to these APIs.
757 | Stable Diffusion is a powerful image generation model that can be used,
758 | among other things, to generate pictures according to a short text description
759 | (called a "prompt").
760 | Arguments:
761 |         img_height: Height of the images to generate, in pixels. Note that only
762 | multiples of 128 are supported; the value provided will be rounded
763 | to the nearest valid value. Default: 512.
764 |         img_width: Width of the images to generate, in pixels. Note that only
765 | multiples of 128 are supported; the value provided will be rounded
766 | to the nearest valid value. Default: 512.
767 | jit_compile: Whether to compile the underlying models to XLA.
768 | This can lead to a significant speedup on some systems. Default: False.
769 | Example:
770 | ```python
771 | from keras_cv.models import StableDiffusionV2
772 | from PIL import Image
773 | model = StableDiffusionV2(img_height=512, img_width=512, jit_compile=True)
774 | img = model.text_to_image(
775 | prompt="A beautiful horse running through a field",
776 | batch_size=1, # How many images to generate at once
777 | num_steps=25, # Number of iterations (controls image quality)
778 | seed=123, # Set this to always get the same image from the same prompt
779 | )
780 | Image.fromarray(img[0]).save("horse.png")
781 | print("saved at horse.png")
782 | ```
783 | References:
784 | - [About Stable Diffusion](https://stability.ai/blog/stable-diffusion-announcement)
785 | - [Original implementation](https://github.com/Stability-AI/stablediffusion)
786 | """
787 |
788 | def __init__(
789 | self,
790 | img_height=512,
791 | img_width=512,
792 | jit_compile=False,
793 | ):
794 | super().__init__(img_height, img_width, jit_compile)
795 | print(
796 | "By using this model checkpoint, you acknowledge that its usage is "
797 | "subject to the terms of the CreativeML Open RAIL++-M license at "
798 |             "https://github.com/Stability-AI/stablediffusion/blob/main/LICENSE-MODEL"
799 | )
800 |
801 | @property
802 | def text_encoder(self):
803 | """text_encoder returns the text encoder with pretrained weights.
804 |         Can be overridden for tasks like textual inversion where the text encoder
805 | needs to be modified.
806 | """
807 | if self._text_encoder is None:
808 | self._text_encoder = TextEncoderV2(MAX_PROMPT_LENGTH)
809 | if self.jit_compile:
810 | self._text_encoder.compile(jit_compile=True)
811 | return self._text_encoder
812 |
813 | @property
814 | def diffusion_model(self) -> tf.keras.Model:
815 | """diffusion_model returns the diffusion model with pretrained weights.
816 |         Can be overridden for tasks where the diffusion model needs to be modified.
817 | """
818 | if self._diffusion_model is None:
819 | self._diffusion_model = DiffusionModelV2(
820 | self.img_height, self.img_width, MAX_PROMPT_LENGTH
821 | )
822 | if self.jit_compile:
823 | self._diffusion_model.compile(jit_compile=True)
824 | return self._diffusion_model
825 |
826 | @property
827 | def diffusion_model_ptp(self) -> tf.keras.Model:
828 | """diffusion_model_ptp returns the diffusion model with modifications for the Prompt-to-Prompt method.
829 |
830 | References
831 | ----------
832 | - "Prompt-to-Prompt Image Editing with Cross-Attention Control."
833 | Amir Hertz, Ron Mokady, Jay Tenenbaum, Kfir Aberman, Yael Pritch, Daniel Cohen-Or.
834 | https://arxiv.org/abs/2208.01626
835 | """
836 | if self._diffusion_model_ptp is None:
837 | if self._diffusion_model is None:
838 |                 self._diffusion_model_ptp = self.diffusion_model
839 | else:
840 |                 # Reset the graph - this is to save memory
841 | self._diffusion_model.compile(jit_compile=self.jit_compile)
842 | self._diffusion_model_ptp = self._diffusion_model
843 |
844 | # Add extra variables and callbacks
845 | ptp_utils.rename_cross_attention_layers(self._diffusion_model_ptp)
846 | ptp_utils.overwrite_forward_call(self._diffusion_model_ptp)
847 | ptp_utils.set_initial_tf_variables(self._diffusion_model_ptp)
848 |
849 | return self._diffusion_model_ptp
850 |
--------------------------------------------------------------------------------