├── .gitignore ├── LICENSE ├── README.md ├── WEIGHTS_LICENSE ├── interactive_text2image.py ├── requirements.txt ├── requirements_m1.txt ├── setup.py ├── stable_diffusion_tf ├── __init__.py ├── autoencoder_kl.py ├── clip_encoder.py ├── clip_tokenizer │ ├── __init__.py │ └── bpe_simple_vocab_16e6.txt.gz ├── constants.py ├── diffusion_model.py ├── layers.py └── stable_diffusion.py └── text2image.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 105 | __pypackages__/ 106 | 107 | # Celery stuff 108 | celerybeat-schedule 109 | celerybeat.pid 110 | 111 | # SageMath parsed files 112 | *.sage.py 113 | 114 | # Environments 115 | .env 116 | .venv 117 | env/ 118 | venv/ 119 | ENV/ 120 | env.bak/ 121 | venv.bak/ 122 | 123 | # Spyder project settings 124 | .spyderproject 125 | .spyproject 126 | 127 | # Rope project settings 128 | .ropeproject 129 | 130 | # mkdocs documentation 131 | /site 132 | 133 | # mypy 134 | .mypy_cache/ 135 | .dmypy.json 136 | dmypy.json 137 | 138 | # Pyre type checker 139 | .pyre/ 140 | 141 | # pytype static type analyzer 142 | .pytype/ 143 | 144 | # Cython debug symbols 145 | cython_debug/ 146 | 147 | # PyCharm 148 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 149 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 150 | # and can be added to the global gitignore or merged into this file. For a more nuclear 151 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 152 | #.idea/ 153 | 154 | 155 | # General 156 | .DS_Store 157 | .AppleDouble 158 | .LSOverride 159 | 160 | # Icon must end with two \r 161 | Icon 162 | 163 | # Thumbnails 164 | ._* 165 | 166 | # Files that might appear in the root of a volume 167 | .DocumentRevisions-V100 168 | .fseventsd 169 | .Spotlight-V100 170 | .TemporaryItems 171 | .Trashes 172 | .VolumeIcon.icns 173 | .com.apple.timemachine.donotpresent 174 | 175 | # Directories potentially created on remote AFP share 176 | .AppleDB 177 | .AppleDesktop 178 | Network Trash Folder 179 | Temporary Items 180 | .apdisk 181 | 182 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The code in this repository is under the Apache 2.0 license (see below). 
2 | The Stable Diffusion weights (not included as part of this repository) are under the CreativeML Open RAIL-M license (see WEIGHTS_LICENSE file in the repository). 3 | 4 | 5 | Apache License 6 | Version 2.0, January 2004 7 | http://www.apache.org/licenses/ 8 | 9 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 10 | 11 | 1. Definitions. 12 | 13 | "License" shall mean the terms and conditions for use, reproduction, 14 | and distribution as defined by Sections 1 through 9 of this document. 15 | 16 | "Licensor" shall mean the copyright owner or entity authorized by 17 | the copyright owner that is granting the License. 18 | 19 | "Legal Entity" shall mean the union of the acting entity and all 20 | other entities that control, are controlled by, or are under common 21 | control with that entity. For the purposes of this definition, 22 | "control" means (i) the power, direct or indirect, to cause the 23 | direction or management of such entity, whether by contract or 24 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 25 | outstanding shares, or (iii) beneficial ownership of such entity. 26 | 27 | "You" (or "Your") shall mean an individual or Legal Entity 28 | exercising permissions granted by this License. 29 | 30 | "Source" form shall mean the preferred form for making modifications, 31 | including but not limited to software source code, documentation 32 | source, and configuration files. 33 | 34 | "Object" form shall mean any form resulting from mechanical 35 | transformation or translation of a Source form, including but 36 | not limited to compiled object code, generated documentation, 37 | and conversions to other media types. 38 | 39 | "Work" shall mean the work of authorship, whether in Source or 40 | Object form, made available under the License, as indicated by a 41 | copyright notice that is included in or attached to the work 42 | (an example is provided in the Appendix below). 
43 | 44 | "Derivative Works" shall mean any work, whether in Source or Object 45 | form, that is based on (or derived from) the Work and for which the 46 | editorial revisions, annotations, elaborations, or other modifications 47 | represent, as a whole, an original work of authorship. For the purposes 48 | of this License, Derivative Works shall not include works that remain 49 | separable from, or merely link (or bind by name) to the interfaces of, 50 | the Work and Derivative Works thereof. 51 | 52 | "Contribution" shall mean any work of authorship, including 53 | the original version of the Work and any modifications or additions 54 | to that Work or Derivative Works thereof, that is intentionally 55 | submitted to Licensor for inclusion in the Work by the copyright owner 56 | or by an individual or Legal Entity authorized to submit on behalf of 57 | the copyright owner. For the purposes of this definition, "submitted" 58 | means any form of electronic, verbal, or written communication sent 59 | to the Licensor or its representatives, including but not limited to 60 | communication on electronic mailing lists, source code control systems, 61 | and issue tracking systems that are managed by, or on behalf of, the 62 | Licensor for the purpose of discussing and improving the Work, but 63 | excluding communication that is conspicuously marked or otherwise 64 | designated in writing by the copyright owner as "Not a Contribution." 65 | 66 | "Contributor" shall mean Licensor and any individual or Legal Entity 67 | on behalf of whom a Contribution has been received by Licensor and 68 | subsequently incorporated within the Work. 69 | 70 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 71 | this License, each Contributor hereby grants to You a perpetual, 72 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 73 | copyright license to reproduce, prepare Derivative Works of, 74 | publicly display, publicly perform, sublicense, and distribute the 75 | Work and such Derivative Works in Source or Object form. 76 | 77 | 3. Grant of Patent License. Subject to the terms and conditions of 78 | this License, each Contributor hereby grants to You a perpetual, 79 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 80 | (except as stated in this section) patent license to make, have made, 81 | use, offer to sell, sell, import, and otherwise transfer the Work, 82 | where such license applies only to those patent claims licensable 83 | by such Contributor that are necessarily infringed by their 84 | Contribution(s) alone or by combination of their Contribution(s) 85 | with the Work to which such Contribution(s) was submitted. If You 86 | institute patent litigation against any entity (including a 87 | cross-claim or counterclaim in a lawsuit) alleging that the Work 88 | or a Contribution incorporated within the Work constitutes direct 89 | or contributory patent infringement, then any patent licenses 90 | granted to You under this License for that Work shall terminate 91 | as of the date such litigation is filed. 92 | 93 | 4. Redistribution. 
You may reproduce and distribute copies of the 94 | Work or Derivative Works thereof in any medium, with or without 95 | modifications, and in Source or Object form, provided that You 96 | meet the following conditions: 97 | 98 | (a) You must give any other recipients of the Work or 99 | Derivative Works a copy of this License; and 100 | 101 | (b) You must cause any modified files to carry prominent notices 102 | stating that You changed the files; and 103 | 104 | (c) You must retain, in the Source form of any Derivative Works 105 | that You distribute, all copyright, patent, trademark, and 106 | attribution notices from the Source form of the Work, 107 | excluding those notices that do not pertain to any part of 108 | the Derivative Works; and 109 | 110 | (d) If the Work includes a "NOTICE" text file as part of its 111 | distribution, then any Derivative Works that You distribute must 112 | include a readable copy of the attribution notices contained 113 | within such NOTICE file, excluding those notices that do not 114 | pertain to any part of the Derivative Works, in at least one 115 | of the following places: within a NOTICE text file distributed 116 | as part of the Derivative Works; within the Source form or 117 | documentation, if provided along with the Derivative Works; or, 118 | within a display generated by the Derivative Works, if and 119 | wherever such third-party notices normally appear. The contents 120 | of the NOTICE file are for informational purposes only and 121 | do not modify the License. You may add Your own attribution 122 | notices within Derivative Works that You distribute, alongside 123 | or as an addendum to the NOTICE text from the Work, provided 124 | that such additional attribution notices cannot be construed 125 | as modifying the License. 
126 | 127 | You may add Your own copyright statement to Your modifications and 128 | may provide additional or different license terms and conditions 129 | for use, reproduction, or distribution of Your modifications, or 130 | for any such Derivative Works as a whole, provided Your use, 131 | reproduction, and distribution of the Work otherwise complies with 132 | the conditions stated in this License. 133 | 134 | 5. Submission of Contributions. Unless You explicitly state otherwise, 135 | any Contribution intentionally submitted for inclusion in the Work 136 | by You to the Licensor shall be under the terms and conditions of 137 | this License, without any additional terms or conditions. 138 | Notwithstanding the above, nothing herein shall supersede or modify 139 | the terms of any separate license agreement you may have executed 140 | with Licensor regarding such Contributions. 141 | 142 | 6. Trademarks. This License does not grant permission to use the trade 143 | names, trademarks, service marks, or product names of the Licensor, 144 | except as required for reasonable and customary use in describing the 145 | origin of the Work and reproducing the content of the NOTICE file. 146 | 147 | 7. Disclaimer of Warranty. Unless required by applicable law or 148 | agreed to in writing, Licensor provides the Work (and each 149 | Contributor provides its Contributions) on an "AS IS" BASIS, 150 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 151 | implied, including, without limitation, any warranties or conditions 152 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 153 | PARTICULAR PURPOSE. You are solely responsible for determining the 154 | appropriateness of using or redistributing the Work and assume any 155 | risks associated with Your exercise of permissions under this License. 156 | 157 | 8. Limitation of Liability. 
In no event and under no legal theory, 158 | whether in tort (including negligence), contract, or otherwise, 159 | unless required by applicable law (such as deliberate and grossly 160 | negligent acts) or agreed to in writing, shall any Contributor be 161 | liable to You for damages, including any direct, indirect, special, 162 | incidental, or consequential damages of any character arising as a 163 | result of this License or out of the use or inability to use the 164 | Work (including but not limited to damages for loss of goodwill, 165 | work stoppage, computer failure or malfunction, or any and all 166 | other commercial damages or losses), even if such Contributor 167 | has been advised of the possibility of such damages. 168 | 169 | 9. Accepting Warranty or Additional Liability. While redistributing 170 | the Work or Derivative Works thereof, You may choose to offer, 171 | and charge a fee for, acceptance of support, warranty, indemnity, 172 | or other liability obligations and/or rights consistent with this 173 | License. However, in accepting such obligations, You may act only 174 | on Your own behalf and on Your sole responsibility, not on behalf 175 | of any other Contributor, and only if You agree to indemnify, 176 | defend, and hold each Contributor harmless for any liability 177 | incurred by, or claims asserted against, such Contributor by reason 178 | of your accepting any such warranty or additional liability. 179 | 180 | END OF TERMS AND CONDITIONS 181 | 182 | APPENDIX: How to apply the Apache License to your work. 183 | 184 | To apply the Apache License to your work, attach the following 185 | boilerplate notice, with the fields enclosed by brackets "[]" 186 | replaced with your own identifying information. (Don't include 187 | the brackets!) The text should be enclosed in the appropriate 188 | comment syntax for the file format. 
We also recommend that a 189 | file or class name and description of purpose be included on the 190 | same "printed page" as the copyright notice for easier 191 | identification within third-party archives. 192 | 193 | Copyright [yyyy] [name of copyright owner] 194 | 195 | Licensed under the Apache License, Version 2.0 (the "License"); 196 | you may not use this file except in compliance with the License. 197 | You may obtain a copy of the License at 198 | 199 | http://www.apache.org/licenses/LICENSE-2.0 200 | 201 | Unless required by applicable law or agreed to in writing, software 202 | distributed under the License is distributed on an "AS IS" BASIS, 203 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 204 | See the License for the specific language governing permissions and 205 | limitations under the License. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion in TensorFlow / Keras 2 | 3 | A Keras / Tensorflow implementation of Stable Diffusion. 4 | 5 | This is a fork of [stable-diffusion-tensorflow](https://github.com/divamgupta/stable-diffusion-tensorflow) 6 | created by @divamgupta. The weights were ported from the original implementation. 7 | 8 | 9 | ## Usage 10 | 11 | 1) Try it out with [this GPU Colab](https://colab.research.google.com/drive/1zVTa4mLeM_w44WaFwl7utTaa6JcaH1zK). 
12 | 13 | 2) Using the command line : 14 | 15 | ``` 16 | python text2image.py --prompt="An astronaut riding a horse" 17 | ``` 18 | 19 | 3) Using the python interface: 20 | 21 | ``` 22 | pip install git+https://github.com/fchollet/stable-diffusion-tensorflow 23 | ``` 24 | 25 | ```python 26 | from stable_diffusion_tf.stable_diffusion import Text2Image 27 | from PIL import Image 28 | 29 | generator = Text2Image( 30 | img_height=512, 31 | img_width=512, 32 | jit_compile=False, 33 | ) 34 | img = generator.generate( 35 | "An astronaut riding a horse", 36 | num_steps=50, 37 | unconditional_guidance_scale=7.5, 38 | temperature=1, 39 | batch_size=1, 40 | ) 41 | Image.fromarray(img[0]).save("output.png") 42 | ``` 43 | 44 | ## Example outputs 45 | 46 | The following outputs have been generated using this implementation: 47 | 48 | 1) *A epic and beautiful rococo werewolf drinking coffee, in a burning coffee shop. ultra-detailed. anime, pixiv, uhd 8k cryengine, octane render* 49 | 50 | ![a](https://user-images.githubusercontent.com/1890549/190841598-3d0b9bd1-d679-4c8d-bd5e-b1e24397b5c8.png) 51 | 52 | 53 | 2) *Spider-Gwen Gwen-Stacy Skyscraper Pink White Pink-White Spiderman Photo-realistic 4K* 54 | 55 | ![a](https://user-images.githubusercontent.com/1890549/190841999-689c9c38-ece4-46a0-ad85-f459ec64c5b8.png) 56 | 57 | 58 | 3) *A vision of paradise, Unreal Engine* 59 | 60 | ![a](https://user-images.githubusercontent.com/1890549/190841886-239406ea-72cb-4570-8f4c-fcd074a7ad7f.png) 61 | 62 | 63 | ## References 64 | 65 | 1) https://github.com/CompVis/stable-diffusion 66 | 2) https://github.com/geohot/tinygrad/blob/master/examples/stable_diffusion.py 67 | -------------------------------------------------------------------------------- /WEIGHTS_LICENSE: -------------------------------------------------------------------------------- 1 | The Stable Diffusion weights (not included as part of this repository) are under the following license (CreativeML Open RAIL-M): 2 | 3 | 
https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE 4 | 5 | The weights content has not been modified as part of this reimplementation (only reformatted). 6 | -------------------------------------------------------------------------------- /interactive_text2image.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from stable_diffusion_tf.stable_diffusion import Text2Image 3 | import argparse 4 | from PIL import Image 5 | 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument( 9 | "--H", 10 | type=int, 11 | default=512, 12 | help="Image height, in pixels", 13 | ) 14 | 15 | parser.add_argument( 16 | "--W", 17 | type=int, 18 | default=512, 19 | help="Image width, in pixels", 20 | ) 21 | 22 | parser.add_argument( 23 | "--mp", 24 | default=False, 25 | action="store_true", 26 | help="Enable mixed precision (fp16 computation)", 27 | ) 28 | 29 | parser.add_argument( 30 | "--jit", 31 | default=False, 32 | action="store_true", 33 | help="Enable XLA compilation", 34 | ) 35 | 36 | parser.add_argument( 37 | "--scale", 38 | type=float, 39 | default=7.5, 40 | help="Unconditional guidance scale", 41 | ) 42 | 43 | parser.add_argument("--steps", type=int, default=50, help="Number of diffusion steps") 44 | 45 | parser.add_argument( 46 | "--seed", 47 | type=int, 48 | help="Optionally specify a seed integer for reproducible results", 49 | ) 50 | 51 | parser.add_argument( 52 | "--batch_size", 53 | type=int, 54 | default=1, 55 | help="How many images to generate", 56 | ) 57 | 58 | args = parser.parse_args() 59 | 60 | if args.mp: 61 | print("Using mixed precision.") 62 | keras.mixed_precision.set_global_policy("mixed_float16") 63 | 64 | generator = Text2Image(img_height=args.H, img_width=args.W, jit_compile=args.jit) 65 | 66 | while True: 67 | prompt = input("Enter prompt (or enter 'exit' to exit):") 68 | if prompt == "exit": 69 | break 70 | fname = input("Enter file name (where to save 
the results):") 71 | 72 | print( 73 | f"Generating {args.batch_size} image{'' if args.batch_size == 1 else 's'} for prompt '{prompt}'" 74 | ) 75 | img = generator.generate( 76 | prompt, 77 | num_steps=args.steps, 78 | unconditional_guidance_scale=args.scale, 79 | temperature=1, 80 | batch_size=args.batch_size, 81 | seed=args.seed, 82 | ) 83 | 84 | if fname.endswith(".png"): 85 | fname = fname[:-4] 86 | if args.batch_size == 1: 87 | Image.fromarray(img[0]).save(f"{fname}.png") 88 | print(f"saved at {fname}.png") 89 | else: 90 | for i in range(args.batch_size): 91 | fname_i = f"{fname}_{i}.png" 92 | Image.fromarray(img[i]).save(fname_i) 93 | print(f"Saved at {fname_i}") 94 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==2.10.0 2 | tensorflow-addons==0.18.0 3 | h5py==3.7.0 4 | Pillow==9.2.0 5 | tqdm==4.64.1 6 | ftfy==6.1.1 7 | regex==2022.9.13 8 | -------------------------------------------------------------------------------- /requirements_m1.txt: -------------------------------------------------------------------------------- 1 | tensorflow-macos==2.10.0 2 | tensorflow-metal==0.6.0 3 | tensorflow_addons==0.17.1 4 | h5py==3.7.0 5 | Pillow==9.2.0 6 | tqdm==4.64.1 7 | protobuf==3.19 8 | ftfy==6.1.1 9 | regex==2022.9.13 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="stable_diffusion_tf", 5 | version="0.1", 6 | description="Stable Diffusion in Tensorflow / Keras", 7 | author="Divam Gupta", 8 | author_email="guptadivam@gmail.com", 9 | platforms=["any"], # or more specific, e.g. 
"win32", "cygwin", "osx" 10 | url="https://github.com/divamgupta/stable-diffusion-tensorflow", 11 | packages=find_packages(), 12 | ) 13 | -------------------------------------------------------------------------------- /stable_diffusion_tf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fchollet/stable-diffusion-tensorflow/7e9cb598e19518c8d7116c0b7c0f17db39543466/stable_diffusion_tf/__init__.py -------------------------------------------------------------------------------- /stable_diffusion_tf/autoencoder_kl.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import tensorflow_addons as tfa 4 | 5 | from .layers import apply_seq, PaddedConv2D 6 | 7 | 8 | class AttentionBlock(keras.layers.Layer): 9 | def __init__(self, channels): 10 | super().__init__() 11 | self.norm = tfa.layers.GroupNormalization(epsilon=1e-5) 12 | self.q = PaddedConv2D(channels, 1) 13 | self.k = PaddedConv2D(channels, 1) 14 | self.v = PaddedConv2D(channels, 1) 15 | self.proj_out = PaddedConv2D(channels, 1) 16 | 17 | def call(self, x): 18 | h_ = self.norm(x) 19 | q, k, v = self.q(h_), self.k(h_), self.v(h_) 20 | 21 | # Compute attention 22 | b, h, w, c = q.shape 23 | q = tf.reshape(q, (-1, h * w, c)) # b,hw,c 24 | k = keras.layers.Permute((3, 1, 2))(k) 25 | k = tf.reshape(k, (-1, c, h * w)) # b,c,hw 26 | w_ = q @ k 27 | w_ = w_ * (c ** (-0.5)) 28 | w_ = keras.activations.softmax(w_) 29 | 30 | # Attend to values 31 | v = keras.layers.Permute((3, 1, 2))(v) 32 | v = tf.reshape(v, (-1, c, h * w)) 33 | w_ = keras.layers.Permute((2, 1))(w_) 34 | h_ = v @ w_ 35 | h_ = keras.layers.Permute((2, 1))(h_) 36 | h_ = tf.reshape(h_, (-1, h, w, c)) 37 | return x + self.proj_out(h_) 38 | 39 | 40 | class ResnetBlock(keras.layers.Layer): 41 | def __init__(self, in_channels, out_channels): 42 | super().__init__() 43 | self.norm1 = 
tfa.layers.GroupNormalization(epsilon=1e-5) 44 | self.conv1 = PaddedConv2D(out_channels, 3, padding=1) 45 | self.norm2 = tfa.layers.GroupNormalization(epsilon=1e-5) 46 | self.conv2 = PaddedConv2D(out_channels, 3, padding=1) 47 | self.nin_shortcut = ( 48 | PaddedConv2D(out_channels, 1) 49 | if in_channels != out_channels 50 | else lambda x: x 51 | ) 52 | 53 | def call(self, x): 54 | h = self.conv1(keras.activations.swish(self.norm1(x))) 55 | h = self.conv2(keras.activations.swish(self.norm2(h))) 56 | return self.nin_shortcut(x) + h 57 | 58 | 59 | class Decoder(keras.Sequential): 60 | def __init__(self): 61 | super().__init__( 62 | [ 63 | keras.layers.Lambda(lambda x: 1 / 0.18215 * x), 64 | PaddedConv2D(4, 1), 65 | PaddedConv2D(512, 3, padding=1), 66 | ResnetBlock(512, 512), 67 | AttentionBlock(512), 68 | ResnetBlock(512, 512), 69 | ResnetBlock(512, 512), 70 | ResnetBlock(512, 512), 71 | ResnetBlock(512, 512), 72 | keras.layers.UpSampling2D(size=(2, 2)), 73 | PaddedConv2D(512, 3, padding=1), 74 | ResnetBlock(512, 512), 75 | ResnetBlock(512, 512), 76 | ResnetBlock(512, 512), 77 | keras.layers.UpSampling2D(size=(2, 2)), 78 | PaddedConv2D(512, 3, padding=1), 79 | ResnetBlock(512, 256), 80 | ResnetBlock(256, 256), 81 | ResnetBlock(256, 256), 82 | keras.layers.UpSampling2D(size=(2, 2)), 83 | PaddedConv2D(256, 3, padding=1), 84 | ResnetBlock(256, 128), 85 | ResnetBlock(128, 128), 86 | ResnetBlock(128, 128), 87 | tfa.layers.GroupNormalization(epsilon=1e-5), 88 | keras.layers.Activation("swish"), 89 | PaddedConv2D(3, 3, padding=1), 90 | ] 91 | ) 92 | -------------------------------------------------------------------------------- /stable_diffusion_tf/clip_encoder.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import tensorflow_addons as tfa 4 | import numpy as np 5 | 6 | from .layers import quick_gelu 7 | 8 | 9 | class CLIPAttention(keras.layers.Layer): 10 | def 
__init__(self): 11 | super().__init__() 12 | self.embed_dim = 768 13 | self.num_heads = 12 14 | self.head_dim = self.embed_dim // self.num_heads 15 | self.scale = self.head_dim**-0.5 16 | self.q_proj = keras.layers.Dense(self.embed_dim) 17 | self.k_proj = keras.layers.Dense(self.embed_dim) 18 | self.v_proj = keras.layers.Dense(self.embed_dim) 19 | self.out_proj = keras.layers.Dense(self.embed_dim) 20 | 21 | def _shape(self, tensor, seq_len: int, bsz: int): 22 | a = tf.reshape(tensor, (bsz, seq_len, self.num_heads, self.head_dim)) 23 | return keras.layers.Permute((2, 1, 3))(a) # bs , n_head , seq_len , head_dim 24 | 25 | def call(self, inputs): 26 | hidden_states, causal_attention_mask = inputs 27 | bsz, tgt_len, embed_dim = hidden_states.shape 28 | query_states = self.q_proj(hidden_states) * self.scale 29 | key_states = self._shape(self.k_proj(hidden_states), tgt_len, -1) 30 | value_states = self._shape(self.v_proj(hidden_states), tgt_len, -1) 31 | 32 | proj_shape = (-1, tgt_len, self.head_dim) 33 | query_states = self._shape(query_states, tgt_len, -1) 34 | query_states = tf.reshape(query_states, proj_shape) 35 | key_states = tf.reshape(key_states, proj_shape) 36 | 37 | src_len = tgt_len 38 | value_states = tf.reshape(value_states, proj_shape) 39 | attn_weights = query_states @ keras.layers.Permute((2, 1))(key_states) 40 | 41 | attn_weights = tf.reshape(attn_weights, (-1, self.num_heads, tgt_len, src_len)) 42 | attn_weights = attn_weights + causal_attention_mask 43 | attn_weights = tf.reshape(attn_weights, (-1, tgt_len, src_len)) 44 | 45 | attn_weights = tf.nn.softmax(attn_weights) 46 | attn_output = attn_weights @ value_states 47 | 48 | attn_output = tf.reshape( 49 | attn_output, (-1, self.num_heads, tgt_len, self.head_dim) 50 | ) 51 | attn_output = keras.layers.Permute((2, 1, 3))(attn_output) 52 | attn_output = tf.reshape(attn_output, (-1, tgt_len, embed_dim)) 53 | 54 | return self.out_proj(attn_output) 55 | 56 | 57 | class CLIPEncoderLayer(keras.layers.Layer): 
58 | def __init__(self): 59 | super().__init__() 60 | self.layer_norm1 = keras.layers.LayerNormalization(epsilon=1e-5) 61 | self.self_attn = CLIPAttention() 62 | self.layer_norm2 = keras.layers.LayerNormalization(epsilon=1e-5) 63 | self.fc1 = keras.layers.Dense(3072) 64 | self.fc2 = keras.layers.Dense(768) 65 | 66 | def call(self, inputs): 67 | hidden_states, causal_attention_mask = inputs 68 | residual = hidden_states 69 | 70 | hidden_states = self.layer_norm1(hidden_states) 71 | hidden_states = self.self_attn([hidden_states, causal_attention_mask]) 72 | hidden_states = residual + hidden_states 73 | 74 | residual = hidden_states 75 | hidden_states = self.layer_norm2(hidden_states) 76 | 77 | hidden_states = self.fc1(hidden_states) 78 | hidden_states = quick_gelu(hidden_states) 79 | hidden_states = self.fc2(hidden_states) 80 | 81 | return residual + hidden_states 82 | 83 | 84 | class CLIPEncoder(keras.layers.Layer): 85 | def __init__(self): 86 | super().__init__() 87 | self.layers = [CLIPEncoderLayer() for i in range(12)] 88 | 89 | def call(self, inputs): 90 | [hidden_states, causal_attention_mask] = inputs 91 | for l in self.layers: 92 | hidden_states = l([hidden_states, causal_attention_mask]) 93 | return hidden_states 94 | 95 | 96 | class CLIPTextEmbeddings(keras.layers.Layer): 97 | def __init__(self, n_words=77): 98 | super().__init__() 99 | self.token_embedding_layer = keras.layers.Embedding( 100 | 49408, 768, name="token_embedding" 101 | ) 102 | self.position_embedding_layer = keras.layers.Embedding( 103 | n_words, 768, name="position_embedding" 104 | ) 105 | 106 | def call(self, inputs): 107 | input_ids, position_ids = inputs 108 | word_embeddings = self.token_embedding_layer(input_ids) 109 | position_embeddings = self.position_embedding_layer(position_ids) 110 | return word_embeddings + position_embeddings 111 | 112 | 113 | class CLIPTextTransformer(keras.models.Model): 114 | def __init__(self, n_words=77): 115 | super().__init__() 116 | self.embeddings = 
CLIPTextEmbeddings(n_words=n_words) 117 | self.encoder = CLIPEncoder() 118 | self.final_layer_norm = keras.layers.LayerNormalization(epsilon=1e-5) 119 | self.causal_attention_mask = tf.constant( 120 | np.triu(np.ones((1, 1, 77, 77), dtype="float32") * -np.inf, k=1) 121 | ) 122 | 123 | def call(self, inputs): 124 | input_ids, position_ids = inputs 125 | x = self.embeddings([input_ids, position_ids]) 126 | x = self.encoder([x, self.causal_attention_mask]) 127 | return self.final_layer_norm(x) 128 | -------------------------------------------------------------------------------- /stable_diffusion_tf/clip_tokenizer/__init__.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | import tensorflow as tf 10 | from tensorflow import keras 11 | 12 | 13 | @lru_cache() 14 | def default_bpe(): 15 | p = os.path.join( 16 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 17 | ) 18 | if os.path.exists(p): 19 | return p 20 | else: 21 | return keras.utils.get_file( 22 | "bpe_simple_vocab_16e6.txt.gz", 23 | "https://github.com/openai/CLIP/blob/main/clip/bpe_simple_vocab_16e6.txt.gz?raw=true", 24 | ) 25 | 26 | 27 | @lru_cache() 28 | def bytes_to_unicode(): 29 | """ 30 | Returns list of utf-8 byte and a corresponding list of unicode strings. 31 | The reversible bpe codes work on unicode strings. 32 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 33 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 34 | This is a signficant percentage of your normal, say, 32K bpe vocab. 35 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 36 | And avoids mapping to whitespace/control characters the bpe code barfs on. 
37 | """ 38 | bs = ( 39 | list(range(ord("!"), ord("~") + 1)) 40 | + list(range(ord("¡"), ord("¬") + 1)) 41 | + list(range(ord("®"), ord("ÿ") + 1)) 42 | ) 43 | cs = bs[:] 44 | n = 0 45 | for b in range(2**8): 46 | if b not in bs: 47 | bs.append(b) 48 | cs.append(2**8 + n) 49 | n += 1 50 | cs = [chr(n) for n in cs] 51 | return dict(zip(bs, cs)) 52 | 53 | 54 | def get_pairs(word): 55 | """Return set of symbol pairs in a word. 56 | Word is represented as tuple of symbols (symbols being variable-length strings). 57 | """ 58 | pairs = set() 59 | prev_char = word[0] 60 | for char in word[1:]: 61 | pairs.add((prev_char, char)) 62 | prev_char = char 63 | return pairs 64 | 65 | 66 | def basic_clean(text): 67 | text = ftfy.fix_text(text) 68 | text = html.unescape(html.unescape(text)) 69 | return text.strip() 70 | 71 | 72 | def whitespace_clean(text): 73 | text = re.sub(r"\s+", " ", text) 74 | text = text.strip() 75 | return text 76 | 77 | 78 | class SimpleTokenizer(object): 79 | def __init__(self, bpe_path: str = default_bpe()): 80 | self.byte_encoder = bytes_to_unicode() 81 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 82 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 83 | merges = merges[1 : 49152 - 256 - 2 + 1] 84 | merges = [tuple(merge.split()) for merge in merges] 85 | vocab = list(bytes_to_unicode().values()) 86 | vocab = vocab + [v + "" for v in vocab] 87 | for merge in merges: 88 | vocab.append("".join(merge)) 89 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 90 | self.encoder = dict(zip(vocab, range(len(vocab)))) 91 | self.decoder = {v: k for k, v in self.encoder.items()} 92 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 93 | self.cache = { 94 | "<|startoftext|>": "<|startoftext|>", 95 | "<|endoftext|>": "<|endoftext|>", 96 | } 97 | self.pat = re.compile( 98 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 99 | re.IGNORECASE, 100 | ) 101 | 102 | def 
bpe(self, token): 103 | if token in self.cache: 104 | return self.cache[token] 105 | word = tuple(token[:-1]) + (token[-1] + "",) 106 | pairs = get_pairs(word) 107 | 108 | if not pairs: 109 | return token + "" 110 | 111 | while True: 112 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 113 | if bigram not in self.bpe_ranks: 114 | break 115 | first, second = bigram 116 | new_word = [] 117 | i = 0 118 | while i < len(word): 119 | try: 120 | j = word.index(first, i) 121 | new_word.extend(word[i:j]) 122 | i = j 123 | except: 124 | new_word.extend(word[i:]) 125 | break 126 | 127 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 128 | new_word.append(first + second) 129 | i += 2 130 | else: 131 | new_word.append(word[i]) 132 | i += 1 133 | new_word = tuple(new_word) 134 | word = new_word 135 | if len(word) == 1: 136 | break 137 | else: 138 | pairs = get_pairs(word) 139 | word = " ".join(word) 140 | self.cache[token] = word 141 | return word 142 | 143 | def encode(self, text): 144 | bpe_tokens = [] 145 | text = whitespace_clean(basic_clean(text)).lower() 146 | for token in re.findall(self.pat, text): 147 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) 148 | bpe_tokens.extend( 149 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 150 | ) 151 | return [49406] + bpe_tokens + [49407] 152 | 153 | def decode(self, tokens): 154 | text = "".join([self.decoder[token] for token in tokens]) 155 | text = ( 156 | bytearray([self.byte_decoder[c] for c in text]) 157 | .decode("utf-8", errors="replace") 158 | .replace("", " ") 159 | ) 160 | return text 161 | -------------------------------------------------------------------------------- /stable_diffusion_tf/clip_tokenizer/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/fchollet/stable-diffusion-tensorflow/7e9cb598e19518c8d7116c0b7c0f17db39543466/stable_diffusion_tf/clip_tokenizer/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /stable_diffusion_tf/constants.py: -------------------------------------------------------------------------------- 1 | _UNCONDITIONAL_TOKENS = [ 2 | 49406, 3 | 49407, 4 | 49407, 5 | 49407, 6 | 49407, 7 | 49407, 8 | 49407, 9 | 49407, 10 | 49407, 11 | 49407, 12 | 49407, 13 | 49407, 14 | 49407, 15 | 49407, 16 | 49407, 17 | 49407, 18 | 49407, 19 | 49407, 20 | 49407, 21 | 49407, 22 | 49407, 23 | 49407, 24 | 49407, 25 | 49407, 26 | 49407, 27 | 49407, 28 | 49407, 29 | 49407, 30 | 49407, 31 | 49407, 32 | 49407, 33 | 49407, 34 | 49407, 35 | 49407, 36 | 49407, 37 | 49407, 38 | 49407, 39 | 49407, 40 | 49407, 41 | 49407, 42 | 49407, 43 | 49407, 44 | 49407, 45 | 49407, 46 | 49407, 47 | 49407, 48 | 49407, 49 | 49407, 50 | 49407, 51 | 49407, 52 | 49407, 53 | 49407, 54 | 49407, 55 | 49407, 56 | 49407, 57 | 49407, 58 | 49407, 59 | 49407, 60 | 49407, 61 | 49407, 62 | 49407, 63 | 49407, 64 | 49407, 65 | 49407, 66 | 49407, 67 | 49407, 68 | 49407, 69 | 49407, 70 | 49407, 71 | 49407, 72 | 49407, 73 | 49407, 74 | 49407, 75 | 49407, 76 | 49407, 77 | 49407, 78 | 49407, 79 | ] 80 | _ALPHAS_CUMPROD = [ 81 | 0.99915, 82 | 0.998296, 83 | 0.9974381, 84 | 0.9965762, 85 | 0.99571025, 86 | 0.9948404, 87 | 0.9939665, 88 | 0.9930887, 89 | 0.9922069, 90 | 0.9913211, 91 | 0.9904313, 92 | 0.98953754, 93 | 0.9886398, 94 | 0.9877381, 95 | 0.9868324, 96 | 0.98592263, 97 | 0.98500896, 98 | 0.9840913, 99 | 0.9831696, 100 | 0.982244, 101 | 0.98131436, 102 | 0.9803808, 103 | 0.97944313, 104 | 0.97850156, 105 | 0.977556, 106 | 0.9766064, 107 | 0.97565293, 108 | 0.9746954, 109 | 0.9737339, 110 | 0.9727684, 111 | 0.97179896, 112 | 0.97082555, 113 | 0.96984816, 114 | 0.96886677, 115 | 0.9678814, 116 | 0.96689206, 117 | 0.96589875, 118 | 0.9649015, 119 | 
0.96390027, 120 | 0.9628951, 121 | 0.9618859, 122 | 0.96087277, 123 | 0.95985574, 124 | 0.95883465, 125 | 0.9578097, 126 | 0.95678073, 127 | 0.95574784, 128 | 0.954711, 129 | 0.95367026, 130 | 0.9526256, 131 | 0.9515769, 132 | 0.95052433, 133 | 0.94946784, 134 | 0.94840735, 135 | 0.947343, 136 | 0.94627476, 137 | 0.9452025, 138 | 0.9441264, 139 | 0.9430464, 140 | 0.9419625, 141 | 0.9408747, 142 | 0.939783, 143 | 0.9386874, 144 | 0.93758786, 145 | 0.9364845, 146 | 0.93537724, 147 | 0.9342661, 148 | 0.9331511, 149 | 0.9320323, 150 | 0.9309096, 151 | 0.929783, 152 | 0.9286526, 153 | 0.9275183, 154 | 0.9263802, 155 | 0.92523825, 156 | 0.92409253, 157 | 0.92294294, 158 | 0.9217895, 159 | 0.92063236, 160 | 0.9194713, 161 | 0.9183065, 162 | 0.9171379, 163 | 0.91596556, 164 | 0.9147894, 165 | 0.9136095, 166 | 0.91242576, 167 | 0.9112383, 168 | 0.9100471, 169 | 0.9088522, 170 | 0.9076535, 171 | 0.9064511, 172 | 0.90524495, 173 | 0.9040351, 174 | 0.90282154, 175 | 0.9016043, 176 | 0.90038335, 177 | 0.8991587, 178 | 0.8979304, 179 | 0.8966984, 180 | 0.89546275, 181 | 0.89422345, 182 | 0.8929805, 183 | 0.89173394, 184 | 0.89048374, 185 | 0.88922995, 186 | 0.8879725, 187 | 0.8867115, 188 | 0.88544685, 189 | 0.88417864, 190 | 0.88290685, 191 | 0.8816315, 192 | 0.88035256, 193 | 0.8790701, 194 | 0.87778413, 195 | 0.8764946, 196 | 0.8752016, 197 | 0.873905, 198 | 0.87260497, 199 | 0.8713014, 200 | 0.8699944, 201 | 0.86868393, 202 | 0.86737, 203 | 0.8660526, 204 | 0.8647318, 205 | 0.86340755, 206 | 0.8620799, 207 | 0.8607488, 208 | 0.85941434, 209 | 0.8580765, 210 | 0.8567353, 211 | 0.8553907, 212 | 0.8540428, 213 | 0.85269153, 214 | 0.85133696, 215 | 0.84997904, 216 | 0.84861785, 217 | 0.8472533, 218 | 0.8458856, 219 | 0.8445145, 220 | 0.84314024, 221 | 0.84176266, 222 | 0.8403819, 223 | 0.8389979, 224 | 0.8376107, 225 | 0.8362203, 226 | 0.83482677, 227 | 0.83343, 228 | 0.8320301, 229 | 0.8306271, 230 | 0.8292209, 231 | 0.82781166, 232 | 0.82639927, 233 | 0.8249838, 234 | 
0.82356524, 235 | 0.8221436, 236 | 0.82071894, 237 | 0.81929123, 238 | 0.81786054, 239 | 0.8164268, 240 | 0.8149901, 241 | 0.8135504, 242 | 0.81210774, 243 | 0.81066215, 244 | 0.8092136, 245 | 0.8077621, 246 | 0.80630773, 247 | 0.80485046, 248 | 0.8033903, 249 | 0.80192727, 250 | 0.8004614, 251 | 0.79899275, 252 | 0.79752123, 253 | 0.7960469, 254 | 0.7945698, 255 | 0.7930899, 256 | 0.79160726, 257 | 0.7901219, 258 | 0.7886338, 259 | 0.787143, 260 | 0.7856495, 261 | 0.7841533, 262 | 0.78265446, 263 | 0.78115296, 264 | 0.7796488, 265 | 0.77814204, 266 | 0.7766327, 267 | 0.7751208, 268 | 0.7736063, 269 | 0.77208924, 270 | 0.7705697, 271 | 0.7690476, 272 | 0.767523, 273 | 0.7659959, 274 | 0.7644664, 275 | 0.76293445, 276 | 0.7614, 277 | 0.7598632, 278 | 0.75832397, 279 | 0.75678235, 280 | 0.75523835, 281 | 0.75369203, 282 | 0.7521434, 283 | 0.75059247, 284 | 0.7490392, 285 | 0.7474837, 286 | 0.7459259, 287 | 0.7443659, 288 | 0.74280363, 289 | 0.7412392, 290 | 0.7396726, 291 | 0.7381038, 292 | 0.73653287, 293 | 0.7349598, 294 | 0.7333846, 295 | 0.73180735, 296 | 0.730228, 297 | 0.7286466, 298 | 0.7270631, 299 | 0.7254777, 300 | 0.72389024, 301 | 0.72230077, 302 | 0.7207094, 303 | 0.71911603, 304 | 0.7175208, 305 | 0.7159236, 306 | 0.71432453, 307 | 0.7127236, 308 | 0.71112084, 309 | 0.7095162, 310 | 0.7079098, 311 | 0.7063016, 312 | 0.70469165, 313 | 0.70307994, 314 | 0.7014665, 315 | 0.69985133, 316 | 0.6982345, 317 | 0.696616, 318 | 0.6949958, 319 | 0.69337404, 320 | 0.69175065, 321 | 0.69012564, 322 | 0.6884991, 323 | 0.68687093, 324 | 0.6852413, 325 | 0.68361014, 326 | 0.6819775, 327 | 0.6803434, 328 | 0.67870784, 329 | 0.6770708, 330 | 0.6754324, 331 | 0.6737926, 332 | 0.67215145, 333 | 0.670509, 334 | 0.66886514, 335 | 0.66722, 336 | 0.6655736, 337 | 0.66392595, 338 | 0.662277, 339 | 0.6606269, 340 | 0.65897554, 341 | 0.657323, 342 | 0.65566933, 343 | 0.6540145, 344 | 0.6523586, 345 | 0.6507016, 346 | 0.6490435, 347 | 0.64738435, 348 | 0.6457241, 349 | 0.64406294, 
350 | 0.6424008, 351 | 0.64073765, 352 | 0.63907355, 353 | 0.63740855, 354 | 0.6357426, 355 | 0.6340758, 356 | 0.6324082, 357 | 0.6307397, 358 | 0.6290704, 359 | 0.6274003, 360 | 0.6257294, 361 | 0.62405777, 362 | 0.6223854, 363 | 0.62071234, 364 | 0.6190386, 365 | 0.61736417, 366 | 0.6156891, 367 | 0.61401343, 368 | 0.6123372, 369 | 0.6106603, 370 | 0.6089829, 371 | 0.607305, 372 | 0.6056265, 373 | 0.6039476, 374 | 0.60226816, 375 | 0.6005883, 376 | 0.598908, 377 | 0.59722733, 378 | 0.5955463, 379 | 0.59386486, 380 | 0.5921831, 381 | 0.59050107, 382 | 0.5888187, 383 | 0.5871361, 384 | 0.5854532, 385 | 0.5837701, 386 | 0.5820868, 387 | 0.5804033, 388 | 0.5787197, 389 | 0.5770359, 390 | 0.575352, 391 | 0.57366806, 392 | 0.571984, 393 | 0.5702999, 394 | 0.5686158, 395 | 0.56693166, 396 | 0.56524754, 397 | 0.5635635, 398 | 0.5618795, 399 | 0.56019557, 400 | 0.5585118, 401 | 0.5568281, 402 | 0.55514455, 403 | 0.5534612, 404 | 0.551778, 405 | 0.5500951, 406 | 0.5484124, 407 | 0.54673, 408 | 0.5450478, 409 | 0.54336596, 410 | 0.54168445, 411 | 0.54000324, 412 | 0.53832245, 413 | 0.5366421, 414 | 0.53496206, 415 | 0.5332825, 416 | 0.53160346, 417 | 0.5299248, 418 | 0.52824676, 419 | 0.5265692, 420 | 0.52489215, 421 | 0.5232157, 422 | 0.5215398, 423 | 0.51986456, 424 | 0.51818997, 425 | 0.51651603, 426 | 0.51484275, 427 | 0.5131702, 428 | 0.5114983, 429 | 0.5098272, 430 | 0.50815684, 431 | 0.5064873, 432 | 0.50481856, 433 | 0.50315064, 434 | 0.50148356, 435 | 0.4998174, 436 | 0.4981521, 437 | 0.49648774, 438 | 0.49482432, 439 | 0.49316183, 440 | 0.49150035, 441 | 0.48983985, 442 | 0.4881804, 443 | 0.486522, 444 | 0.48486462, 445 | 0.4832084, 446 | 0.48155323, 447 | 0.4798992, 448 | 0.47824633, 449 | 0.47659463, 450 | 0.4749441, 451 | 0.47329482, 452 | 0.4716468, 453 | 0.47, 454 | 0.46835446, 455 | 0.46671024, 456 | 0.46506736, 457 | 0.4634258, 458 | 0.46178558, 459 | 0.46014675, 460 | 0.45850933, 461 | 0.45687333, 462 | 0.45523876, 463 | 0.45360568, 464 | 0.45197406, 465 | 
0.45034397, 466 | 0.44871536, 467 | 0.44708833, 468 | 0.44546285, 469 | 0.44383895, 470 | 0.44221666, 471 | 0.440596, 472 | 0.43897697, 473 | 0.43735963, 474 | 0.43574396, 475 | 0.43412998, 476 | 0.43251774, 477 | 0.43090722, 478 | 0.4292985, 479 | 0.42769152, 480 | 0.42608637, 481 | 0.42448303, 482 | 0.4228815, 483 | 0.42128187, 484 | 0.4196841, 485 | 0.41808826, 486 | 0.4164943, 487 | 0.4149023, 488 | 0.41331223, 489 | 0.41172415, 490 | 0.41013804, 491 | 0.40855396, 492 | 0.4069719, 493 | 0.4053919, 494 | 0.40381396, 495 | 0.4022381, 496 | 0.40066436, 497 | 0.39909273, 498 | 0.39752322, 499 | 0.3959559, 500 | 0.39439073, 501 | 0.39282778, 502 | 0.39126703, 503 | 0.3897085, 504 | 0.3881522, 505 | 0.3865982, 506 | 0.38504648, 507 | 0.38349706, 508 | 0.38194993, 509 | 0.38040516, 510 | 0.37886274, 511 | 0.37732267, 512 | 0.375785, 513 | 0.37424973, 514 | 0.37271687, 515 | 0.37118647, 516 | 0.36965853, 517 | 0.36813304, 518 | 0.36661002, 519 | 0.36508954, 520 | 0.36357155, 521 | 0.3620561, 522 | 0.36054322, 523 | 0.3590329, 524 | 0.35752517, 525 | 0.35602003, 526 | 0.35451752, 527 | 0.35301763, 528 | 0.3515204, 529 | 0.3500258, 530 | 0.3485339, 531 | 0.3470447, 532 | 0.34555823, 533 | 0.34407446, 534 | 0.34259343, 535 | 0.34111515, 536 | 0.33963963, 537 | 0.33816692, 538 | 0.336697, 539 | 0.3352299, 540 | 0.33376563, 541 | 0.3323042, 542 | 0.33084565, 543 | 0.32938993, 544 | 0.32793713, 545 | 0.3264872, 546 | 0.32504022, 547 | 0.32359615, 548 | 0.32215503, 549 | 0.32071686, 550 | 0.31928164, 551 | 0.31784943, 552 | 0.3164202, 553 | 0.314994, 554 | 0.3135708, 555 | 0.31215066, 556 | 0.31073356, 557 | 0.3093195, 558 | 0.30790854, 559 | 0.30650064, 560 | 0.30509588, 561 | 0.30369422, 562 | 0.30229566, 563 | 0.30090025, 564 | 0.299508, 565 | 0.2981189, 566 | 0.29673296, 567 | 0.29535022, 568 | 0.2939707, 569 | 0.29259437, 570 | 0.29122123, 571 | 0.28985137, 572 | 0.28848472, 573 | 0.28712133, 574 | 0.2857612, 575 | 0.28440437, 576 | 0.2830508, 577 | 0.28170055, 578 | 
0.2803536, 579 | 0.27900997, 580 | 0.27766964, 581 | 0.27633268, 582 | 0.27499905, 583 | 0.2736688, 584 | 0.27234194, 585 | 0.27101842, 586 | 0.2696983, 587 | 0.26838157, 588 | 0.26706827, 589 | 0.26575837, 590 | 0.26445192, 591 | 0.26314887, 592 | 0.2618493, 593 | 0.26055318, 594 | 0.2592605, 595 | 0.25797132, 596 | 0.2566856, 597 | 0.2554034, 598 | 0.25412467, 599 | 0.25284946, 600 | 0.25157773, 601 | 0.2503096, 602 | 0.24904492, 603 | 0.24778382, 604 | 0.24652626, 605 | 0.24527225, 606 | 0.2440218, 607 | 0.24277493, 608 | 0.24153163, 609 | 0.24029191, 610 | 0.23905578, 611 | 0.23782326, 612 | 0.23659433, 613 | 0.23536903, 614 | 0.23414734, 615 | 0.23292927, 616 | 0.23171483, 617 | 0.23050404, 618 | 0.22929688, 619 | 0.22809339, 620 | 0.22689353, 621 | 0.22569734, 622 | 0.22450483, 623 | 0.22331597, 624 | 0.2221308, 625 | 0.22094932, 626 | 0.21977153, 627 | 0.21859743, 628 | 0.21742703, 629 | 0.21626033, 630 | 0.21509734, 631 | 0.21393807, 632 | 0.21278252, 633 | 0.21163069, 634 | 0.21048258, 635 | 0.20933822, 636 | 0.20819758, 637 | 0.2070607, 638 | 0.20592754, 639 | 0.20479813, 640 | 0.20367248, 641 | 0.20255059, 642 | 0.20143245, 643 | 0.20031808, 644 | 0.19920748, 645 | 0.19810064, 646 | 0.19699757, 647 | 0.19589828, 648 | 0.19480278, 649 | 0.19371104, 650 | 0.1926231, 651 | 0.19153893, 652 | 0.19045855, 653 | 0.18938197, 654 | 0.18830918, 655 | 0.18724018, 656 | 0.18617497, 657 | 0.18511358, 658 | 0.18405597, 659 | 0.18300217, 660 | 0.18195218, 661 | 0.18090598, 662 | 0.1798636, 663 | 0.17882504, 664 | 0.17779027, 665 | 0.1767593, 666 | 0.17573217, 667 | 0.17470883, 668 | 0.1736893, 669 | 0.1726736, 670 | 0.1716617, 671 | 0.17065361, 672 | 0.16964935, 673 | 0.1686489, 674 | 0.16765225, 675 | 0.16665943, 676 | 0.16567042, 677 | 0.16468522, 678 | 0.16370384, 679 | 0.16272627, 680 | 0.16175252, 681 | 0.16078258, 682 | 0.15981644, 683 | 0.15885411, 684 | 0.1578956, 685 | 0.15694089, 686 | 0.15599, 687 | 0.15504292, 688 | 0.15409963, 689 | 0.15316014, 690 | 
0.15222447, 691 | 0.15129258, 692 | 0.1503645, 693 | 0.14944021, 694 | 0.14851972, 695 | 0.14760303, 696 | 0.14669013, 697 | 0.14578101, 698 | 0.14487568, 699 | 0.14397413, 700 | 0.14307636, 701 | 0.14218238, 702 | 0.14129217, 703 | 0.14040573, 704 | 0.13952307, 705 | 0.13864417, 706 | 0.13776903, 707 | 0.13689767, 708 | 0.13603005, 709 | 0.13516618, 710 | 0.13430607, 711 | 0.13344972, 712 | 0.1325971, 713 | 0.13174823, 714 | 0.1309031, 715 | 0.13006169, 716 | 0.12922402, 717 | 0.12839006, 718 | 0.12755983, 719 | 0.12673332, 720 | 0.12591052, 721 | 0.12509143, 722 | 0.12427604, 723 | 0.12346435, 724 | 0.12265636, 725 | 0.121852055, 726 | 0.12105144, 727 | 0.1202545, 728 | 0.11946124, 729 | 0.11867165, 730 | 0.11788572, 731 | 0.11710346, 732 | 0.11632485, 733 | 0.115549885, 734 | 0.11477857, 735 | 0.11401089, 736 | 0.11324684, 737 | 0.11248643, 738 | 0.11172963, 739 | 0.11097645, 740 | 0.110226884, 741 | 0.10948092, 742 | 0.10873855, 743 | 0.10799977, 744 | 0.107264586, 745 | 0.106532976, 746 | 0.105804935, 747 | 0.10508047, 748 | 0.10435956, 749 | 0.1036422, 750 | 0.10292839, 751 | 0.10221813, 752 | 0.1015114, 753 | 0.10080819, 754 | 0.100108504, 755 | 0.09941233, 756 | 0.098719664, 757 | 0.0980305, 758 | 0.09734483, 759 | 0.09666264, 760 | 0.09598393, 761 | 0.095308684, 762 | 0.09463691, 763 | 0.093968585, 764 | 0.09330372, 765 | 0.092642285, 766 | 0.09198428, 767 | 0.09132971, 768 | 0.09067855, 769 | 0.090030804, 770 | 0.089386456, 771 | 0.088745505, 772 | 0.088107936, 773 | 0.08747375, 774 | 0.08684293, 775 | 0.08621547, 776 | 0.085591376, 777 | 0.084970616, 778 | 0.08435319, 779 | 0.0837391, 780 | 0.08312833, 781 | 0.08252087, 782 | 0.08191671, 783 | 0.08131585, 784 | 0.08071827, 785 | 0.080123976, 786 | 0.07953294, 787 | 0.078945175, 788 | 0.078360654, 789 | 0.077779375, 790 | 0.07720133, 791 | 0.07662651, 792 | 0.07605491, 793 | 0.07548651, 794 | 0.07492131, 795 | 0.0743593, 796 | 0.07380046, 797 | 0.073244795, 798 | 0.07269229, 799 | 0.07214294, 800 | 
0.07159673, 801 | 0.07105365, 802 | 0.070513695, 803 | 0.06997685, 804 | 0.069443114, 805 | 0.06891247, 806 | 0.06838491, 807 | 0.067860425, 808 | 0.06733901, 809 | 0.066820644, 810 | 0.06630533, 811 | 0.06579305, 812 | 0.0652838, 813 | 0.06477757, 814 | 0.06427433, 815 | 0.0637741, 816 | 0.063276865, 817 | 0.06278259, 818 | 0.062291294, 819 | 0.061802953, 820 | 0.06131756, 821 | 0.0608351, 822 | 0.060355574, 823 | 0.05987896, 824 | 0.059405252, 825 | 0.058934443, 826 | 0.05846652, 827 | 0.058001474, 828 | 0.057539295, 829 | 0.05707997, 830 | 0.056623492, 831 | 0.05616985, 832 | 0.05571903, 833 | 0.055271026, 834 | 0.054825824, 835 | 0.05438342, 836 | 0.053943794, 837 | 0.053506944, 838 | 0.05307286, 839 | 0.052641522, 840 | 0.052212927, 841 | 0.051787063, 842 | 0.051363923, 843 | 0.05094349, 844 | 0.050525755, 845 | 0.05011071, 846 | 0.04969834, 847 | 0.049288645, 848 | 0.0488816, 849 | 0.048477206, 850 | 0.048075445, 851 | 0.04767631, 852 | 0.047279786, 853 | 0.04688587, 854 | 0.046494544, 855 | 0.046105802, 856 | 0.04571963, 857 | 0.04533602, 858 | 0.04495496, 859 | 0.04457644, 860 | 0.044200446, 861 | 0.04382697, 862 | 0.043456003, 863 | 0.043087535, 864 | 0.042721547, 865 | 0.042358037, 866 | 0.04199699, 867 | 0.041638397, 868 | 0.041282244, 869 | 0.040928524, 870 | 0.040577225, 871 | 0.040228333, 872 | 0.039881844, 873 | 0.039537743, 874 | 0.039196018, 875 | 0.038856663, 876 | 0.038519662, 877 | 0.038185004, 878 | 0.037852682, 879 | 0.037522685, 880 | 0.037195, 881 | 0.036869615, 882 | 0.036546525, 883 | 0.036225714, 884 | 0.03590717, 885 | 0.035590887, 886 | 0.035276853, 887 | 0.034965057, 888 | 0.034655485, 889 | 0.03434813, 890 | 0.03404298, 891 | 0.033740025, 892 | 0.033439253, 893 | 0.033140652, 894 | 0.032844216, 895 | 0.03254993, 896 | 0.032257784, 897 | 0.03196777, 898 | 0.031679876, 899 | 0.031394087, 900 | 0.031110398, 901 | 0.030828796, 902 | 0.030549273, 903 | 0.030271813, 904 | 0.02999641, 905 | 0.029723052, 906 | 0.029451728, 907 | 0.029182427, 
908 | 0.02891514, 909 | 0.028649855, 910 | 0.028386563, 911 | 0.028125253, 912 | 0.02786591, 913 | 0.027608532, 914 | 0.027353102, 915 | 0.027099613, 916 | 0.026848052, 917 | 0.026598409, 918 | 0.026350675, 919 | 0.02610484, 920 | 0.02586089, 921 | 0.02561882, 922 | 0.025378617, 923 | 0.025140269, 924 | 0.024903767, 925 | 0.0246691, 926 | 0.02443626, 927 | 0.024205236, 928 | 0.023976017, 929 | 0.023748592, 930 | 0.023522953, 931 | 0.023299087, 932 | 0.023076987, 933 | 0.022856642, 934 | 0.02263804, 935 | 0.022421172, 936 | 0.022206029, 937 | 0.0219926, 938 | 0.021780876, 939 | 0.021570845, 940 | 0.021362498, 941 | 0.021155827, 942 | 0.020950818, 943 | 0.020747466, 944 | 0.020545758, 945 | 0.020345684, 946 | 0.020147236, 947 | 0.019950403, 948 | 0.019755175, 949 | 0.019561544, 950 | 0.019369498, 951 | 0.019179028, 952 | 0.018990126, 953 | 0.01880278, 954 | 0.018616982, 955 | 0.018432721, 956 | 0.01824999, 957 | 0.018068777, 958 | 0.017889075, 959 | 0.017710872, 960 | 0.01753416, 961 | 0.017358929, 962 | 0.017185168, 963 | 0.017012872, 964 | 0.016842028, 965 | 0.016672628, 966 | 0.016504662, 967 | 0.016338123, 968 | 0.016173, 969 | 0.016009282, 970 | 0.015846964, 971 | 0.015686033, 972 | 0.015526483, 973 | 0.015368304, 974 | 0.015211486, 975 | 0.0150560215, 976 | 0.014901901, 977 | 0.014749114, 978 | 0.014597654, 979 | 0.014447511, 980 | 0.0142986765, 981 | 0.014151142, 982 | 0.014004898, 983 | 0.013859936, 984 | 0.013716248, 985 | 0.0135738235, 986 | 0.013432656, 987 | 0.013292736, 988 | 0.013154055, 989 | 0.013016605, 990 | 0.012880377, 991 | 0.012745362, 992 | 0.012611552, 993 | 0.012478939, 994 | 0.012347515, 995 | 0.01221727, 996 | 0.012088198, 997 | 0.0119602885, 998 | 0.0118335355, 999 | 0.011707929, 1000 | 0.011583461, 1001 | 0.011460125, 1002 | 0.011337912, 1003 | 0.011216813, 1004 | 0.011096821, 1005 | 0.010977928, 1006 | 0.0108601255, 1007 | 0.010743406, 1008 | 0.010627762, 1009 | 0.0105131855, 1010 | 0.010399668, 1011 | 0.010287202, 1012 | 0.01017578, 
1013 | 0.010065395, 1014 | 0.009956039, 1015 | 0.009847702, 1016 | 0.009740381, 1017 | 0.0096340645, 1018 | 0.009528747, 1019 | 0.009424419, 1020 | 0.009321076, 1021 | 0.009218709, 1022 | 0.00911731, 1023 | 0.009016872, 1024 | 0.008917389, 1025 | 0.008818853, 1026 | 0.008721256, 1027 | 0.008624591, 1028 | 0.008528852, 1029 | 0.00843403, 1030 | 0.00834012, 1031 | 0.008247114, 1032 | 0.008155004, 1033 | 0.008063785, 1034 | 0.007973449, 1035 | 0.007883989, 1036 | 0.007795398, 1037 | 0.0077076694, 1038 | 0.0076207966, 1039 | 0.0075347726, 1040 | 0.007449591, 1041 | 0.0073652444, 1042 | 0.007281727, 1043 | 0.0071990318, 1044 | 0.007117152, 1045 | 0.0070360815, 1046 | 0.0069558136, 1047 | 0.0068763415, 1048 | 0.006797659, 1049 | 0.00671976, 1050 | 0.0066426382, 1051 | 0.0065662866, 1052 | 0.006490699, 1053 | 0.0064158696, 1054 | 0.006341792, 1055 | 0.00626846, 1056 | 0.0061958674, 1057 | 0.0061240084, 1058 | 0.0060528764, 1059 | 0.0059824656, 1060 | 0.0059127696, 1061 | 0.0058437833, 1062 | 0.0057755, 1063 | 0.0057079145, 1064 | 0.00564102, 1065 | 0.0055748112, 1066 | 0.0055092825, 1067 | 0.005444428, 1068 | 0.005380241, 1069 | 0.0053167176, 1070 | 0.005253851, 1071 | 0.005191636, 1072 | 0.005130066, 1073 | 0.0050691366, 1074 | 0.0050088423, 1075 | 0.0049491767, 1076 | 0.004890135, 1077 | 0.0048317118, 1078 | 0.004773902, 1079 | 0.004716699, 1080 | 0.0046600983, 1081 | ] 1082 | -------------------------------------------------------------------------------- /stable_diffusion_tf/diffusion_model.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | import tensorflow_addons as tfa 4 | 5 | from .layers import PaddedConv2D, apply_seq, td_dot, GEGLU 6 | 7 | 8 | class ResBlock(keras.layers.Layer): 9 | def __init__(self, channels, out_channels): 10 | super().__init__() 11 | self.in_layers = [ 12 | tfa.layers.GroupNormalization(epsilon=1e-5), 13 | keras.activations.swish, 14 | 
PaddedConv2D(out_channels, 3, padding=1), 15 | ] 16 | self.emb_layers = [ 17 | keras.activations.swish, 18 | keras.layers.Dense(out_channels), 19 | ] 20 | self.out_layers = [ 21 | tfa.layers.GroupNormalization(epsilon=1e-5), 22 | keras.activations.swish, 23 | PaddedConv2D(out_channels, 3, padding=1), 24 | ] 25 | self.skip_connection = ( 26 | PaddedConv2D(out_channels, 1) if channels != out_channels else lambda x: x 27 | ) 28 | 29 | def call(self, inputs): 30 | x, emb = inputs 31 | h = apply_seq(x, self.in_layers) 32 | emb_out = apply_seq(emb, self.emb_layers) 33 | h = h + emb_out[:, None, None] 34 | h = apply_seq(h, self.out_layers) 35 | ret = self.skip_connection(x) + h 36 | return ret 37 | 38 | 39 | class CrossAttention(keras.layers.Layer): 40 | def __init__(self, n_heads, d_head): 41 | super().__init__() 42 | self.to_q = keras.layers.Dense(n_heads * d_head, use_bias=False) 43 | self.to_k = keras.layers.Dense(n_heads * d_head, use_bias=False) 44 | self.to_v = keras.layers.Dense(n_heads * d_head, use_bias=False) 45 | self.scale = d_head**-0.5 46 | self.num_heads = n_heads 47 | self.head_size = d_head 48 | self.to_out = [keras.layers.Dense(n_heads * d_head)] 49 | 50 | def call(self, inputs): 51 | assert type(inputs) is list 52 | if len(inputs) == 1: 53 | inputs = inputs + [None] 54 | x, context = inputs 55 | context = x if context is None else context 56 | q, k, v = self.to_q(x), self.to_k(context), self.to_v(context) 57 | assert len(x.shape) == 3 58 | q = tf.reshape(q, (-1, x.shape[1], self.num_heads, self.head_size)) 59 | k = tf.reshape(k, (-1, context.shape[1], self.num_heads, self.head_size)) 60 | v = tf.reshape(v, (-1, context.shape[1], self.num_heads, self.head_size)) 61 | 62 | q = keras.layers.Permute((2, 1, 3))(q) # (bs, num_heads, time, head_size) 63 | k = keras.layers.Permute((2, 3, 1))(k) # (bs, num_heads, head_size, time) 64 | v = keras.layers.Permute((2, 1, 3))(v) # (bs, num_heads, time, head_size) 65 | 66 | score = td_dot(q, k) * self.scale 67 | 
weights = keras.activations.softmax(score) # (bs, num_heads, time, time) 68 | attention = td_dot(weights, v) 69 | attention = keras.layers.Permute((2, 1, 3))( 70 | attention 71 | ) # (bs, time, num_heads, head_size) 72 | h_ = tf.reshape(attention, (-1, x.shape[1], self.num_heads * self.head_size)) 73 | return apply_seq(h_, self.to_out) 74 | 75 | 76 | class BasicTransformerBlock(keras.layers.Layer): 77 | def __init__(self, dim, n_heads, d_head): 78 | super().__init__() 79 | self.norm1 = keras.layers.LayerNormalization(epsilon=1e-5) 80 | self.attn1 = CrossAttention(n_heads, d_head) 81 | 82 | self.norm2 = keras.layers.LayerNormalization(epsilon=1e-5) 83 | self.attn2 = CrossAttention(n_heads, d_head) 84 | 85 | self.norm3 = keras.layers.LayerNormalization(epsilon=1e-5) 86 | self.geglu = GEGLU(dim * 4) 87 | self.dense = keras.layers.Dense(dim) 88 | 89 | def call(self, inputs): 90 | x, context = inputs 91 | x = self.attn1([self.norm1(x)]) + x 92 | x = self.attn2([self.norm2(x), context]) + x 93 | return self.dense(self.geglu(self.norm3(x))) + x 94 | 95 | 96 | class SpatialTransformer(keras.layers.Layer): 97 | def __init__(self, channels, n_heads, d_head): 98 | super().__init__() 99 | self.norm = tfa.layers.GroupNormalization(epsilon=1e-5) 100 | assert channels == n_heads * d_head 101 | self.proj_in = PaddedConv2D(n_heads * d_head, 1) 102 | self.transformer_blocks = [BasicTransformerBlock(channels, n_heads, d_head)] 103 | self.proj_out = PaddedConv2D(channels, 1) 104 | 105 | def call(self, inputs): 106 | x, context = inputs 107 | b, h, w, c = x.shape 108 | x_in = x 109 | x = self.norm(x) 110 | x = self.proj_in(x) 111 | x = tf.reshape(x, (-1, h * w, c)) 112 | for block in self.transformer_blocks: 113 | x = block([x, context]) 114 | x = tf.reshape(x, (-1, h, w, c)) 115 | return self.proj_out(x) + x_in 116 | 117 | 118 | class Downsample(keras.layers.Layer): 119 | def __init__(self, channels): 120 | super().__init__() 121 | self.op = PaddedConv2D(channels, 3, stride=2, 
padding=1) 122 | 123 | def call(self, x): 124 | return self.op(x) 125 | 126 | 127 | class Upsample(keras.layers.Layer): 128 | def __init__(self, channels): 129 | super().__init__() 130 | self.ups = keras.layers.UpSampling2D(size=(2, 2)) 131 | self.conv = PaddedConv2D(channels, 3, padding=1) 132 | 133 | def call(self, x): 134 | x = self.ups(x) 135 | return self.conv(x) 136 | 137 | 138 | class UNetModel(keras.models.Model): 139 | def __init__(self): 140 | super().__init__() 141 | self.time_embed = [ 142 | keras.layers.Dense(1280), 143 | keras.activations.swish, 144 | keras.layers.Dense(1280), 145 | ] 146 | self.input_blocks = [ 147 | [PaddedConv2D(320, kernel_size=3, padding=1)], 148 | [ResBlock(320, 320), SpatialTransformer(320, 8, 40)], 149 | [ResBlock(320, 320), SpatialTransformer(320, 8, 40)], 150 | [Downsample(320)], 151 | [ResBlock(320, 640), SpatialTransformer(640, 8, 80)], 152 | [ResBlock(640, 640), SpatialTransformer(640, 8, 80)], 153 | [Downsample(640)], 154 | [ResBlock(640, 1280), SpatialTransformer(1280, 8, 160)], 155 | [ResBlock(1280, 1280), SpatialTransformer(1280, 8, 160)], 156 | [Downsample(1280)], 157 | [ResBlock(1280, 1280)], 158 | [ResBlock(1280, 1280)], 159 | ] 160 | self.middle_block = [ 161 | ResBlock(1280, 1280), 162 | SpatialTransformer(1280, 8, 160), 163 | ResBlock(1280, 1280), 164 | ] 165 | self.output_blocks = [ 166 | [ResBlock(2560, 1280)], 167 | [ResBlock(2560, 1280)], 168 | [ResBlock(2560, 1280), Upsample(1280)], 169 | [ResBlock(2560, 1280), SpatialTransformer(1280, 8, 160)], 170 | [ResBlock(2560, 1280), SpatialTransformer(1280, 8, 160)], 171 | [ 172 | ResBlock(1920, 1280), 173 | SpatialTransformer(1280, 8, 160), 174 | Upsample(1280), 175 | ], 176 | [ResBlock(1920, 640), SpatialTransformer(640, 8, 80)], # 6 177 | [ResBlock(1280, 640), SpatialTransformer(640, 8, 80)], 178 | [ 179 | ResBlock(960, 640), 180 | SpatialTransformer(640, 8, 80), 181 | Upsample(640), 182 | ], 183 | [ResBlock(960, 320), SpatialTransformer(320, 8, 40)], 184 | 
[ResBlock(640, 320), SpatialTransformer(320, 8, 40)], 185 | [ResBlock(640, 320), SpatialTransformer(320, 8, 40)], 186 | ] 187 | self.out = [ 188 | tfa.layers.GroupNormalization(epsilon=1e-5), 189 | keras.activations.swish, 190 | PaddedConv2D(4, kernel_size=3, padding=1), 191 | ] 192 | 193 | def call(self, inputs): 194 | x, t_emb, context = inputs 195 | emb = apply_seq(t_emb, self.time_embed) 196 | 197 | def apply(x, layer): 198 | if isinstance(layer, ResBlock): 199 | x = layer([x, emb]) 200 | elif isinstance(layer, SpatialTransformer): 201 | x = layer([x, context]) 202 | else: 203 | x = layer(x) 204 | return x 205 | 206 | saved_inputs = [] 207 | for b in self.input_blocks: 208 | for layer in b: 209 | x = apply(x, layer) 210 | saved_inputs.append(x) 211 | 212 | for layer in self.middle_block: 213 | x = apply(x, layer) 214 | 215 | for b in self.output_blocks: 216 | x = tf.concat([x, saved_inputs.pop()], axis=-1) 217 | for layer in b: 218 | x = apply(x, layer) 219 | return apply_seq(x, self.out) 220 | -------------------------------------------------------------------------------- /stable_diffusion_tf/layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow import keras 3 | 4 | 5 | class PaddedConv2D(keras.layers.Layer): 6 | def __init__(self, channels, kernel_size, padding=0, stride=1): 7 | super().__init__() 8 | self.padding2d = keras.layers.ZeroPadding2D((padding, padding)) 9 | self.conv2d = keras.layers.Conv2D( 10 | channels, kernel_size, strides=(stride, stride) 11 | ) 12 | 13 | def call(self, x): 14 | x = self.padding2d(x) 15 | return self.conv2d(x) 16 | 17 | 18 | class GEGLU(keras.layers.Layer): 19 | def __init__(self, dim_out): 20 | super().__init__() 21 | self.proj = keras.layers.Dense(dim_out * 2) 22 | self.dim_out = dim_out 23 | 24 | def call(self, x): 25 | xp = self.proj(x) 26 | x, gate = xp[..., : self.dim_out], xp[..., self.dim_out :] 27 | return x * gelu(gate) 28 | 29 | 
def gelu(x):
    """Tanh approximation of GELU (matches the original SD weights)."""
    tanh_res = keras.activations.tanh(x * 0.7978845608 * (1 + 0.044715 * (x**2)))
    return 0.5 * x * (1 + tanh_res)


def quick_gelu(x):
    """Cheaper sigmoid-based GELU approximation used by the CLIP encoder."""
    return x * tf.sigmoid(x * 1.702)


def apply_seq(x, layers):
    """Apply a list of layers/callables to `x` in order."""
    for layer in layers:
        x = layer(x)
    return x


def td_dot(a, b):
    """Batched matmul over the trailing two axes of two 4-D tensors."""
    aa = tf.reshape(a, (-1, a.shape[2], a.shape[3]))
    bb = tf.reshape(b, (-1, b.shape[2], b.shape[3]))
    cc = keras.backend.batch_dot(aa, bb)
    return tf.reshape(cc, (-1, a.shape[1], cc.shape[1], cc.shape[2]))


# ==== stable_diffusion_tf/stable_diffusion.py ====
import numpy as np
from tqdm import tqdm
import math

import tensorflow as tf
from tensorflow import keras

from .autoencoder_kl import Decoder
from .diffusion_model import UNetModel
from .clip_encoder import CLIPTextTransformer
from .clip_tokenizer import SimpleTokenizer
from .constants import _UNCONDITIONAL_TOKENS, _ALPHAS_CUMPROD

MAX_TEXT_LEN = 77


class Text2Image:
    """End-to-end Stable Diffusion text-to-image pipeline (DDIM sampler)."""

    def __init__(self, img_height=1000, img_width=1000, jit_compile=False):
        # UNet requires multiples of 2**7 = 128 to prevent dimension mismatch
        self.img_height = round(img_height / 128) * 128
        self.img_width = round(img_width / 128) * 128
        self.tokenizer = SimpleTokenizer()

        text_encoder, diffusion_model, decoder = get_models(
            self.img_height, self.img_width
        )
        self.text_encoder = text_encoder
        self.diffusion_model = diffusion_model
        self.decoder = decoder
        if jit_compile:
            self.text_encoder.compile(jit_compile=True)
            self.diffusion_model.compile(jit_compile=True)
            self.decoder.compile(jit_compile=True)

    def generate(
        self,
        prompt,
        batch_size=1,
        num_steps=25,
        unconditional_guidance_scale=7.5,
        temperature=1,
        seed=None,
    ):
        """Generate `batch_size` images for `prompt`.

        Returns a uint8 array of shape (batch_size, img_height, img_width, 3).
        Raises ValueError if the prompt tokenizes to MAX_TEXT_LEN or more
        tokens.
        """
        # Tokenize prompt (i.e. starting context)
        inputs = self.tokenizer.encode(prompt)
        # FIX(review): explicit error instead of `assert` (asserts are
        # stripped under `python -O`); 77 replaced by the module constant.
        if len(inputs) >= MAX_TEXT_LEN:
            raise ValueError(
                f"Prompt is too long (should be < {MAX_TEXT_LEN} tokens)"
            )
        # Pad with 49407 (<|endoftext|>) up to the fixed sequence length.
        phrase = inputs + [49407] * (MAX_TEXT_LEN - len(inputs))
        phrase = np.array(phrase)[None].astype("int32")
        phrase = np.repeat(phrase, batch_size, axis=0)

        # Encode prompt tokens (and their positions) into a "context vector"
        pos_ids = np.arange(MAX_TEXT_LEN)[None].astype("int32")
        pos_ids = np.repeat(pos_ids, batch_size, axis=0)
        context = self.text_encoder.predict_on_batch([phrase, pos_ids])

        # Encode unconditional tokens (and their positions) into an
        # "unconditional context vector" for classifier-free guidance.
        unconditional_tokens = np.array(_UNCONDITIONAL_TOKENS)[None].astype("int32")
        unconditional_tokens = np.repeat(unconditional_tokens, batch_size, axis=0)
        self.unconditional_tokens = tf.convert_to_tensor(unconditional_tokens)
        unconditional_context = self.text_encoder.predict_on_batch(
            [self.unconditional_tokens, pos_ids]
        )
        timesteps = np.arange(1, 1000, 1000 // num_steps)
        latent, alphas, alphas_prev = self.get_starting_parameters(
            timesteps, batch_size, seed
        )

        # Diffusion stage: denoise from the last timestep back to the first.
        progbar = tqdm(list(enumerate(timesteps))[::-1])
        for index, timestep in progbar:
            progbar.set_description(f"{index:3d} {timestep:3d}")
            e_t = self.get_model_output(
                latent,
                timestep,
                context,
                unconditional_context,
                unconditional_guidance_scale,
                batch_size,
            )
            a_t, a_prev = alphas[index], alphas_prev[index]
            latent, pred_x0 = self.get_x_prev_and_pred_x0(
                latent, e_t, index, a_t, a_prev, temperature, seed
            )

        # Decoding stage: latent -> pixels in [-1, 1] -> uint8 in [0, 255].
        decoded = self.decoder.predict_on_batch(latent)
        decoded = ((decoded + 1) / 2) * 255
        return np.clip(decoded, 0, 255).astype("uint8")

    def timestep_embedding(self, timesteps, dim=320, max_period=10000):
        """Sinusoidal embedding of `timesteps`, shape (1, dim)."""
        half = dim // 2
        freqs = np.exp(
            -math.log(max_period) * np.arange(0, half, dtype="float32") / half
        )
        args = np.array(timesteps) * freqs
        embedding = np.concatenate([np.cos(args), np.sin(args)])
        return tf.convert_to_tensor(embedding.reshape(1, -1))

    def get_model_output(
        self,
        latent,
        t,
        context,
        unconditional_context,
        unconditional_guidance_scale,
        batch_size,
    ):
        """Classifier-free-guided noise prediction at timestep `t`."""
        timesteps = np.array([t])
        t_emb = self.timestep_embedding(timesteps)
        t_emb = np.repeat(t_emb, batch_size, axis=0)
        unconditional_latent = self.diffusion_model.predict_on_batch(
            [latent, t_emb, unconditional_context]
        )
        latent = self.diffusion_model.predict_on_batch([latent, t_emb, context])
        # Guidance: push the prediction away from the unconditional one.
        return unconditional_latent + unconditional_guidance_scale * (
            latent - unconditional_latent
        )

    def get_x_prev_and_pred_x0(self, x, e_t, index, a_t, a_prev, temperature, seed):
        """One DDIM update step; returns (x_{t-1}, predicted x_0)."""
        # sigma_t == 0 => deterministic DDIM sampling (eta = 0); the `noise`
        # term below is therefore always zero.  It is kept (rather than
        # deleted) so the tf RNG stream is consumed exactly as before.
        sigma_t = 0
        sqrt_one_minus_at = math.sqrt(1 - a_t)
        pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)

        # Direction pointing to x_t
        dir_xt = math.sqrt(1.0 - a_prev - sigma_t**2) * e_t
        noise = sigma_t * tf.random.normal(x.shape, seed=seed) * temperature
        x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt
        return x_prev, pred_x0

    def get_starting_parameters(self, timesteps, batch_size, seed):
        """Initial Gaussian latent plus per-step alpha-bar schedules."""
        n_h = self.img_height // 8  # VAE downsamples by 8x
        n_w = self.img_width // 8
        alphas = [_ALPHAS_CUMPROD[t] for t in timesteps]
        alphas_prev = [1.0] + alphas[:-1]
        latent = tf.random.normal((batch_size, n_h, n_w, 4), seed=seed)
        return latent, alphas, alphas_prev


def get_models(img_height, img_width, download_weights=True):
    """Build the text encoder, UNet, and decoder for the given resolution."""
    n_h = img_height // 8
    n_w = img_width // 8

    # Create text encoder
    input_word_ids = keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32")
    # NOTE(review): the remainder of get_models lies outside this chunk.
input_pos_ids = keras.layers.Input(shape=(MAX_TEXT_LEN,), dtype="int32") 145 | embeds = CLIPTextTransformer()([input_word_ids, input_pos_ids]) 146 | text_encoder = keras.models.Model([input_word_ids, input_pos_ids], embeds) 147 | 148 | # Creation diffusion UNet 149 | context = keras.layers.Input((MAX_TEXT_LEN, 768)) 150 | t_emb = keras.layers.Input((320,)) 151 | latent = keras.layers.Input((n_h, n_w, 4)) 152 | unet = UNetModel() 153 | diffusion_model = keras.models.Model( 154 | [latent, t_emb, context], unet([latent, t_emb, context]) 155 | ) 156 | 157 | # Create decoder 158 | latent = keras.layers.Input((n_h, n_w, 4)) 159 | decoder = Decoder() 160 | decoder = keras.models.Model(latent, decoder(latent)) 161 | 162 | text_encoder_weights_fpath = keras.utils.get_file( 163 | origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/text_encoder.h5", 164 | file_hash="d7805118aeb156fc1d39e38a9a082b05501e2af8c8fbdc1753c9cb85212d6619", 165 | ) 166 | diffusion_model_weights_fpath = keras.utils.get_file( 167 | origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/diffusion_model.h5", 168 | file_hash="a5b2eea58365b18b40caee689a2e5d00f4c31dbcb4e1d58a9cf1071f55bbbd3a", 169 | ) 170 | decoder_weights_fpath = keras.utils.get_file( 171 | origin="https://huggingface.co/fchollet/stable-diffusion/resolve/main/decoder.h5", 172 | file_hash="6d3c5ba91d5cc2b134da881aaa157b2d2adc648e5625560e3ed199561d0e39d5", 173 | ) 174 | 175 | text_encoder.load_weights(text_encoder_weights_fpath) 176 | diffusion_model.load_weights(diffusion_model_weights_fpath) 177 | decoder.load_weights(decoder_weights_fpath) 178 | return text_encoder, diffusion_model, decoder 179 | -------------------------------------------------------------------------------- /text2image.py: -------------------------------------------------------------------------------- 1 | from tensorflow import keras 2 | from stable_diffusion_tf.stable_diffusion import Text2Image 3 | import argparse 4 | from PIL import 
Image 5 | 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument( 9 | "--prompt", 10 | type=str, 11 | nargs="?", 12 | default="a painting of a virus monster playing guitar", 13 | help="the prompt to render", 14 | ) 15 | 16 | parser.add_argument( 17 | "--output", 18 | type=str, 19 | nargs="?", 20 | default="output", 21 | help="where to save the output image", 22 | ) 23 | 24 | parser.add_argument( 25 | "--H", 26 | type=int, 27 | default=512, 28 | help="image height, in pixels", 29 | ) 30 | 31 | parser.add_argument( 32 | "--W", 33 | type=int, 34 | default=512, 35 | help="image width, in pixels", 36 | ) 37 | 38 | parser.add_argument( 39 | "--scale", 40 | type=float, 41 | default=7.5, 42 | help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))", 43 | ) 44 | 45 | parser.add_argument( 46 | "--steps", type=int, default=50, help="number of ddim sampling steps" 47 | ) 48 | 49 | parser.add_argument( 50 | "--seed", 51 | type=int, 52 | help="optionally specify a seed integer for reproducible results", 53 | ) 54 | 55 | parser.add_argument( 56 | "--batch_size", 57 | type=int, 58 | default=1, 59 | help="how many images to generate", 60 | ) 61 | 62 | parser.add_argument( 63 | "--mp", 64 | default=False, 65 | action="store_true", 66 | help="Enable mixed precision (fp16 computation)", 67 | ) 68 | 69 | args = parser.parse_args() 70 | 71 | if args.mp: 72 | print("Using mixed precision.") 73 | keras.mixed_precision.set_global_policy("mixed_float16") 74 | 75 | generator = Text2Image(img_height=args.H, img_width=args.W, jit_compile=False) 76 | img = generator.generate( 77 | args.prompt, 78 | num_steps=args.steps, 79 | unconditional_guidance_scale=args.scale, 80 | temperature=1, 81 | batch_size=args.batch_size, 82 | seed=args.seed, 83 | ) 84 | 85 | fname = args.output 86 | if fname.endswith(".png"): 87 | fname = fname[:-4] 88 | if args.batch_size == 1: 89 | Image.fromarray(img[0]).save(args.output + ".png") 90 | print(f"saved at 
{args.output}.png") 91 | else: 92 | for i in range(args.batch_size): 93 | fname_i = f"{fname}_{i}.png" 94 | Image.fromarray(img[i]).save(fname_i) 95 | print(f"saved at {fname_i}") 96 | --------------------------------------------------------------------------------