├── .gitignore ├── LICENSE ├── README.md ├── demo.ipynb ├── requirements.txt └── stable_diffusion_pytorch ├── __init__.py ├── attention.py ├── clip.py ├── decoder.py ├── diffusion.py ├── encoder.py ├── model_loader.py ├── pipeline.py ├── samplers ├── __init__.py ├── k_euler.py ├── k_euler_ancestral.py └── k_lms.py ├── tokenizer.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/macos,python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,python 3 | 4 | ### macOS ### 5 | # General 6 | .DS_Store 7 | .AppleDouble 8 | .LSOverride 9 | 10 | # Icon must end with two \r 11 | Icon 12 | 13 | 14 | # Thumbnails 15 | ._* 16 | 17 | # Files that might appear in the root of a volume 18 | .DocumentRevisions-V100 19 | .fseventsd 20 | .Spotlight-V100 21 | .TemporaryItems 22 | .Trashes 23 | .VolumeIcon.icns 24 | .com.apple.timemachine.donotpresent 25 | 26 | # Directories potentially created on remote AFP share 27 | .AppleDB 28 | .AppleDesktop 29 | Network Trash Folder 30 | Temporary Items 31 | .apdisk 32 | 33 | ### macOS Patch ### 34 | # iCloud generated files 35 | *.icloud 36 | 37 | ### Python ### 38 | # Byte-compiled / optimized / DLL files 39 | __pycache__/ 40 | *.py[cod] 41 | *$py.class 42 | 43 | # C extensions 44 | *.so 45 | 46 | # Distribution / packaging 47 | .Python 48 | build/ 49 | develop-eggs/ 50 | dist/ 51 | downloads/ 52 | eggs/ 53 | .eggs/ 54 | lib/ 55 | lib64/ 56 | parts/ 57 | sdist/ 58 | var/ 59 | wheels/ 60 | share/python-wheels/ 61 | *.egg-info/ 62 | .installed.cfg 63 | *.egg 64 | MANIFEST 65 | 66 | # PyInstaller 67 | # Usually these files are written by a python script from a template 68 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 69 | *.manifest 70 | *.spec 71 | 72 | # Installer logs 73 | pip-log.txt 74 | pip-delete-this-directory.txt 75 | 76 | # Unit test / coverage reports 77 | htmlcov/ 78 | .tox/ 79 | .nox/ 80 | .coverage 81 | .coverage.* 82 | .cache 83 | nosetests.xml 84 | coverage.xml 85 | *.cover 86 | *.py,cover 87 | .hypothesis/ 88 | .pytest_cache/ 89 | cover/ 90 | 91 | # Translations 92 | *.mo 93 | *.pot 94 | 95 | # Django stuff: 96 | *.log 97 | local_settings.py 98 | db.sqlite3 99 | db.sqlite3-journal 100 | 101 | # Flask stuff: 102 | instance/ 103 | .webassets-cache 104 | 105 | # Scrapy stuff: 106 | .scrapy 107 | 108 | # Sphinx documentation 109 | docs/_build/ 110 | 111 | # PyBuilder 112 | .pybuilder/ 113 | target/ 114 | 115 | # Jupyter Notebook 116 | .ipynb_checkpoints 117 | 118 | # IPython 119 | profile_default/ 120 | ipython_config.py 121 | 122 | # pyenv 123 | # For a library or package, you might want to ignore these files since the code is 124 | # intended to run in multiple environments; otherwise, check them in: 125 | # .python-version 126 | 127 | # pipenv 128 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 129 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 130 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 131 | # install all needed dependencies. 132 | #Pipfile.lock 133 | 134 | # poetry 135 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 136 | # This is especially recommended for binary packages to ensure reproducibility, and is more 137 | # commonly ignored for libraries. 
138 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 139 | #poetry.lock 140 | 141 | # pdm 142 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 143 | #pdm.lock 144 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 145 | # in version control. 146 | # https://pdm.fming.dev/#use-with-ide 147 | .pdm.toml 148 | 149 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 150 | __pypackages__/ 151 | 152 | # Celery stuff 153 | celerybeat-schedule 154 | celerybeat.pid 155 | 156 | # SageMath parsed files 157 | *.sage.py 158 | 159 | # Environments 160 | .env 161 | .venv 162 | env/ 163 | venv/ 164 | ENV/ 165 | env.bak/ 166 | venv.bak/ 167 | 168 | # Spyder project settings 169 | .spyderproject 170 | .spyproject 171 | 172 | # Rope project settings 173 | .ropeproject 174 | 175 | # mkdocs documentation 176 | /site 177 | 178 | # mypy 179 | .mypy_cache/ 180 | .dmypy.json 181 | dmypy.json 182 | 183 | # Pyre type checker 184 | .pyre/ 185 | 186 | # pytype static type analyzer 187 | .pytype/ 188 | 189 | # Cython debug symbols 190 | cython_debug/ 191 | 192 | # PyCharm 193 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 194 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 195 | # and can be added to the global gitignore or merged into this file. For a more nuclear 196 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 197 | #.idea/ 198 | 199 | # End of https://www.toptal.com/developers/gitignore/api/macos,python 200 | 201 | data/ 202 | ddata/ 203 | migrators/ 204 | data.*.tar -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Jinseo Kim 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # stable-diffusion-pytorch
2 | 
3 | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/kjsman/stable-diffusion-pytorch/blob/main/demo.ipynb)
4 | 
5 | Yet another PyTorch implementation of [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release).
6 | 
7 | I tried my best to make the codebase minimal, self-contained, consistent, hackable, and easy to read. Features are pruned if not needed in Stable Diffusion (e.g. the attention mask at the CLIP tokenizer/encoder). Configs are hard-coded (based on Stable Diffusion v1.x). Loops are unrolled when that shape makes more sense.
8 | 
9 | Despite my efforts, I feel like [I cooked another spaghetti](https://xkcd.com/927/). Well, help yourself!
10 | 
11 | This implementation heavily references the following repositories. Big kudos to them!
12 | 
13 | * [divamgupta/stable-diffusion-tensorflow](https://github.com/divamgupta/stable-diffusion-tensorflow)
14 | * [CompVis/stable-diffusion](https://github.com/CompVis/stable-diffusion)
15 | * [huggingface/transformers](https://github.com/huggingface/transformers)
16 | * [crowsonkb/k-diffusion](https://github.com/crowsonkb/k-diffusion)
17 | * [karpathy/minGPT](https://github.com/karpathy/minGPT)
18 | 
19 | ## Dependencies
20 | 
21 | * PyTorch
22 | * Numpy
23 | * Pillow
24 | * regex
25 | * tqdm
26 | 
27 | ## How to Install
28 | 
29 | 1. Clone or download this repository.
30 | 2. Install dependencies: Run `pip install torch numpy Pillow regex tqdm` or `pip install -r requirements.txt`.
31 | 3. Download `data.v20221029.tar` from [here](https://huggingface.co/jinseokim/stable-diffusion-pytorch-data/resolve/main/data.v20221029.tar) and unpack it in the parent folder of `stable_diffusion_pytorch`. Your folders should look like this:
32 | ```
33 | stable-diffusion-pytorch(-main)/
34 | ├─ data/
35 | │  ├─ ckpt/
36 | │  ├─ ...
37 | ├─ stable_diffusion_pytorch/
38 | │  ├─ samplers/
39 | └  ┴─ ...
40 | ```
41 | *Note that the checkpoint files included in `data.v20221029.tar` [have a different license](#license) -- you must agree to that license to use the checkpoint files.*
42 | 
43 | ## How to Use
44 | 
45 | Import `stable_diffusion_pytorch` as a submodule.
46 | 
47 | Here are some example scripts. You can also read the docstring of `stable_diffusion_pytorch.pipeline.generate`.
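For example, you can print that docstring from a Python session with the built-in `help()` function (a minimal sketch; it assumes you run it from the repository root with the dependencies above installed):

```py
from stable_diffusion_pytorch import pipeline

# Show the full parameter reference for pipeline.generate
help(pipeline.generate)
```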
48 | 
49 | Text-to-image generation:
50 | ```py
51 | from stable_diffusion_pytorch import pipeline
52 | 
53 | prompts = ["a photograph of an astronaut riding a horse"]
54 | images = pipeline.generate(prompts)
55 | images[0].save('output.jpg')
56 | ```
57 | 
58 | ...with multiple prompts:
59 | ```py
60 | prompts = [
61 |     "a photograph of an astronaut riding a horse",
62 |     ""]
63 | images = pipeline.generate(prompts)
64 | ```
65 | 
66 | ...with unconditional (negative) prompts:
67 | ```py
68 | prompts = ["a photograph of an astronaut riding a horse"]
69 | uncond_prompts = ["low quality"]
70 | images = pipeline.generate(prompts, uncond_prompts)
71 | ```
72 | 
73 | ...with a seed:
74 | ```py
75 | prompts = ["a photograph of an astronaut riding a horse"]
76 | images = pipeline.generate(prompts, uncond_prompts, seed=42)
77 | ```
78 | 
79 | Preload models (you will need enough VRAM):
80 | ```py
81 | from stable_diffusion_pytorch import model_loader
82 | models = model_loader.preload_models('cuda')
83 | 
84 | prompts = ["a photograph of an astronaut riding a horse"]
85 | images = pipeline.generate(prompts, models=models)
86 | ```
87 | 
88 | If you get an OOM error with the above code but have enough RAM (not VRAM), you can move models to the GPU when needed
89 | and back to the CPU when not needed:
90 | ```py
91 | from stable_diffusion_pytorch import model_loader
92 | models = model_loader.preload_models('cpu')
93 | 
94 | prompts = ["a photograph of an astronaut riding a horse"]
95 | images = pipeline.generate(prompts, models=models, device='cuda', idle_device='cpu')
96 | ```
97 | 
98 | Image-to-image generation:
99 | ```py
100 | from PIL import Image
101 | 
102 | prompts = ["a photograph of an astronaut riding a horse"]
103 | input_images = [Image.open('space.jpg')]
104 | images = pipeline.generate(prompts, input_images=input_images)
105 | ```
106 | 
107 | ...with a custom strength:
108 | ```py
109 | prompts = ["a photograph of an astronaut riding a horse"]
110 | input_images = [Image.open('space.jpg')]
111 | images = pipeline.generate(prompts, input_images=input_images, strength=0.6)
112 | ```
113 | 
114 | Change the [classifier-free guidance](https://arxiv.org/abs/2207.12598) scale:
115 | ```py
116 | prompts = ["a photograph of an astronaut riding a horse"]
117 | images = pipeline.generate(prompts, cfg_scale=11)
118 | ```
119 | 
120 | ...or disable classifier-free guidance:
121 | ```py
122 | prompts = ["a photograph of an astronaut riding a horse"]
123 | images = pipeline.generate(prompts, do_cfg=False)
124 | ```
125 | 
126 | Reduce steps (faster generation, lower quality):
127 | ```py
128 | prompts = ["a photograph of an astronaut riding a horse"]
129 | images = pipeline.generate(prompts, n_inference_steps=28)
130 | ```
131 | 
132 | Use a different sampler:
133 | ```py
134 | prompts = ["a photograph of an astronaut riding a horse"]
135 | images = pipeline.generate(prompts, sampler="k_euler")
136 | # "k_lms" (default), "k_euler", and "k_euler_ancestral" are available
137 | ```
138 | 
139 | Generate images with a custom size:
140 | ```py
141 | prompts = ["a photograph of an astronaut riding a horse"]
142 | images = pipeline.generate(prompts, height=512, width=768)
143 | ```
144 | 
145 | ## LICENSE
146 | 
147 | All code in this repository is licensed under the MIT License. Please see the LICENSE file.
148 | 
149 | Note that the Stable Diffusion checkpoint files are licensed under the [CreativeML Open RAIL-M](https://huggingface.co/spaces/CompVis/stable-diffusion-license) license. It has a use-based restriction clause, so please read it before using the checkpoints.
150 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "collapsed_sections": [ 8 | "iDI2dKfRWTId" 9 | ] 10 | }, 11 | "kernelspec": { 12 | "name": "python3", 13 | "display_name": "Python 3" 14 | }, 15 | "language_info": { 16 | "name": "python" 17 | }, 18 | "accelerator": "GPU", 19 | "gpuClass": "standard" 20 | }, 21 | "cells": [ 22 | { 23 | "cell_type": "markdown", 24 | "source": [ 25 | "# Demo for stable-diffusion-pytorch" 26 | ], 27 | "metadata": { 28 | "id": "UhG5CzHQWNzr" 29 | } 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "source": [ 34 | "## Install (takes about 1~5 minutes)" 35 | ], 36 | "metadata": { 37 | "id": "iDI2dKfRWTId" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "source": [ 43 | "%cd /content\n", 44 | "!git clone https://github.com/kjsman/stable-diffusion-pytorch" 45 | ], 46 | "metadata": { 47 | "id": "AgkJdPCbVjf6" 48 | }, 49 | "execution_count": null, 50 | "outputs": [] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "source": [ 55 | "# Note that all depencdencies of stable-diffusion-pytorch is pre-installed\n", 56 | "# on Colab environment. This cell basically does nothing on Colab.\n", 57 | "%cd /content/stable-diffusion-pytorch\n", 58 | "%pip install -r requirements.txt" 59 | ], 60 | "metadata": { 61 | "id": "uUsTYf-6BZGs" 62 | }, 63 | "execution_count": null, 64 | "outputs": [] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "source": [ 69 | "%cd /content/stable-diffusion-pytorch\n", 70 | "!wget https://huggingface.co/jinseokim/stable-diffusion-pytorch-data/resolve/main/data.v20221029.tar\n", 71 | "!tar -xf data.v20221029.tar" 72 | ], 73 | "metadata": { 74 | "id": "NXnKKOxcMsin" 75 | }, 76 | "execution_count": null, 77 | "outputs": [] 78 | }, 79 | { 80 | "cell_type": "markdown", 81 | "source": [ 82 | "## Run" 83 | ], 84 | "metadata": { 85 | "id": "UyM8vbLnWVNP" 86 | } 87 | }, 88 | { 89 | "cell_type": "code", 90 | "source": [ 91 | "#@title Preload models (takes about ~20 seconds on default settings)\n", 92 | "\n", 93 | "from stable_diffusion_pytorch import model_loader\n", 94 | "models = model_loader.preload_models('cpu')" 95 | ], 96 | "metadata": { 97 | "cellView": "form", 98 | "id": "fGOopQsDS-7U" 99 | }, 100 | "execution_count": null, 101 | "outputs": [] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "source": [ 106 | "#@title Inference (takes about 30~40 seconds on default settings)\n", 107 | "\n", 108 | "from stable_diffusion_pytorch import pipeline\n", 109 | "\n", 110 | "prompt = \"a photograph of an astronaut riding a horse\" #@param { type: \"string\" }\n", 111 | "prompts = [prompt]\n", 112 | "\n", 113 | "uncond_prompt = \"\" #@param { type: \"string\" }\n", 114 | "uncond_prompts = [uncond_prompt] if uncond_prompt else None\n", 115 | "\n", 116 | "upload_input_image = False #@param { type: \"boolean\" }\n", 117 | "input_images = None\n", 118 | "if upload_input_image:\n", 119 | " from PIL import Image\n", 120 | " from google.colab import files\n", 121 | " print(\"Upload an input image:\")\n", 122 | " path = list(files.upload().keys())[0]\n", 123 | " input_images = [Image.open(path)]\n", 124 | "\n", 125 | "strength = 0.8 #@param { type:\"slider\", min: 0, max: 1, step: 0.01 }\n", 126 | "\n", 127 | "do_cfg = True #@param { type: \"boolean\" }\n", 128 | "cfg_scale = 7.5 #@param { type:\"slider\", min: 1, max: 14, step: 0.5 
}\n", 129 | "height = 512 #@param { type: \"integer\" }\n", 130 | "width = 512 #@param { type: \"integer\" }\n", 131 | "sampler = \"k_lms\" #@param [\"k_lms\", \"k_euler\", \"k_euler_ancestral\"]\n", 132 | "n_inference_steps = 50 #@param { type: \"integer\" }\n", 133 | "\n", 134 | "use_seed = False #@param { type: \"boolean\" }\n", 135 | "if use_seed:\n", 136 | " seed = 42 #@param { type: \"integer\" }\n", 137 | "else:\n", 138 | " seed = None\n", 139 | "\n", 140 | "pipeline.generate(prompts=prompts, uncond_prompts=uncond_prompts,\n", 141 | " input_images=input_images, strength=strength,\n", 142 | " do_cfg=do_cfg, cfg_scale=cfg_scale,\n", 143 | " height=height, width=width, sampler=sampler,\n", 144 | " n_inference_steps=n_inference_steps, seed=seed,\n", 145 | " models=models, device='cuda', idle_device='cpu')[0]" 146 | ], 147 | "metadata": { 148 | "cellView": "form", 149 | "id": "x_dhQfFYXoPu" 150 | }, 151 | "execution_count": null, 152 | "outputs": [] 153 | } 154 | ] 155 | } 156 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch >= 1.9 2 | numpy 3 | Pillow 4 | regex 5 | tqdm -------------------------------------------------------------------------------- /stable_diffusion_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenizer import Tokenizer 2 | from .clip import CLIP 3 | from .encoder import Encoder 4 | from .decoder import Decoder 5 | from .diffusion import Diffusion 6 | from .samplers import KLMSSampler, KEulerSampler, KEulerAncestralSampler -------------------------------------------------------------------------------- /stable_diffusion_pytorch/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import math 5 | 6 | 7 | class SelfAttention(nn.Module): 8 | def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True): 9 | super().__init__() 10 | self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias) 11 | self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias) 12 | self.n_heads = n_heads 13 | self.d_head = d_embed // n_heads 14 | 15 | def forward(self, x, causal_mask=False): 16 | input_shape = x.shape 17 | batch_size, sequence_length, d_embed = input_shape 18 | interim_shape = (batch_size, sequence_length, self.n_heads, self.d_head) 19 | 20 | q, k, v = self.in_proj(x).chunk(3, dim=-1) 21 | 22 | q = q.view(interim_shape).transpose(1, 2) 23 | k = k.view(interim_shape).transpose(1, 2) 24 | v = v.view(interim_shape).transpose(1, 2) 25 | 26 | weight = q @ k.transpose(-1, -2) 27 | if causal_mask: 28 | mask = torch.ones_like(weight, dtype=torch.bool).triu(1) 29 | weight.masked_fill_(mask, -torch.inf) 30 | weight /= math.sqrt(self.d_head) 31 | weight = F.softmax(weight, dim=-1) 32 | 33 | output = weight @ v 34 | output = output.transpose(1, 2) 35 | output = output.reshape(input_shape) 36 | output = self.out_proj(output) 37 | return output 38 | 39 | class CrossAttention(nn.Module): 40 | def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True): 41 | super().__init__() 42 | self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias) 43 | self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias) 44 | self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias) 45 | self.out_proj = nn.Linear(d_embed, 
d_embed, bias=out_proj_bias) 46 | self.n_heads = n_heads 47 | self.d_head = d_embed // n_heads 48 | 49 | def forward(self, x, y): 50 | input_shape = x.shape 51 | batch_size, sequence_length, d_embed = input_shape 52 | interim_shape = (batch_size, -1, self.n_heads, self.d_head) 53 | 54 | q = self.q_proj(x) 55 | k = self.k_proj(y) 56 | v = self.v_proj(y) 57 | 58 | q = q.view(interim_shape).transpose(1, 2) 59 | k = k.view(interim_shape).transpose(1, 2) 60 | v = v.view(interim_shape).transpose(1, 2) 61 | 62 | weight = q @ k.transpose(-1, -2) 63 | weight /= math.sqrt(self.d_head) 64 | weight = F.softmax(weight, dim=-1) 65 | 66 | output = weight @ v 67 | output = output.transpose(1, 2).contiguous() 68 | output = output.view(input_shape) 69 | output = self.out_proj(output) 70 | return output -------------------------------------------------------------------------------- /stable_diffusion_pytorch/clip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .attention import SelfAttention 5 | 6 | 7 | class CLIPEmbedding(nn.Module): 8 | def __init__(self, n_vocab: int, n_embd: int, n_token: int): 9 | super().__init__() 10 | self.token_embedding = nn.Embedding(n_vocab, n_embd) 11 | self.position_value = nn.Parameter(torch.zeros((n_token, n_embd))) 12 | 13 | def forward(self, tokens): 14 | x = self.token_embedding(tokens) 15 | x += self.position_value 16 | return x 17 | 18 | class CLIPLayer(nn.Module): 19 | def __init__(self, n_head: int, n_embd: int): 20 | super().__init__() 21 | self.layernorm_1 = nn.LayerNorm(n_embd) 22 | self.attention = SelfAttention(n_head, n_embd) 23 | self.layernorm_2 = nn.LayerNorm(n_embd) 24 | self.linear_1 = nn.Linear(n_embd, 4 * n_embd) 25 | self.linear_2 = nn.Linear(4 * n_embd, n_embd) 26 | 27 | def forward(self, x): 28 | residue = x 29 | x = self.layernorm_1(x) 30 | x = self.attention(x, causal_mask=True) 31 | x += residue 32 | 33 | residue = x 34 | x = self.layernorm_2(x) 35 | x = self.linear_1(x) 36 | x = x * torch.sigmoid(1.702 * x) # QuickGELU activation function 37 | x = self.linear_2(x) 38 | x += residue 39 | 40 | return x 41 | 42 | class CLIP(nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | self.embedding = CLIPEmbedding(49408, 768, 77) 46 | self.layers = nn.ModuleList([ 47 | CLIPLayer(12, 768) for i in range(12) 48 | ]) 49 | self.layernorm = nn.LayerNorm(768) 50 | 51 | def forward(self, tokens: torch.LongTensor) -> torch.FloatTensor: 52 | tokens = tokens.type(torch.long) 53 | 54 | state = self.embedding(tokens) 55 | for layer in self.layers: 56 | state = layer(state) 57 | output = self.layernorm(state) 58 | return output -------------------------------------------------------------------------------- /stable_diffusion_pytorch/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .attention import SelfAttention 5 | 6 | 7 | class AttentionBlock(nn.Module): 8 | def __init__(self, channels): 9 | super().__init__() 10 | self.groupnorm = nn.GroupNorm(32, channels) 11 | self.attention = SelfAttention(1, channels) 12 | 13 | def forward(self, x): 14 | residue = x 15 | x = self.groupnorm(x) 16 | 17 | n, c, h, w = x.shape 18 | x = x.view((n, c, h * w)) 19 | x = x.transpose(-1, -2) 20 | x = self.attention(x) 21 | x = x.transpose(-1, -2) 22 | x = x.view((n, c, h, w)) 23 | 24 | x += residue 25 | return x 26 | 27 
| class ResidualBlock(nn.Module): 28 | def __init__(self, in_channels, out_channels): 29 | super().__init__() 30 | self.groupnorm_1 = nn.GroupNorm(32, in_channels) 31 | self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 32 | 33 | self.groupnorm_2 = nn.GroupNorm(32, out_channels) 34 | self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 35 | 36 | if in_channels == out_channels: 37 | self.residual_layer = nn.Identity() 38 | else: 39 | self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 40 | 41 | def forward(self, x): 42 | residue = x 43 | 44 | x = self.groupnorm_1(x) 45 | x = F.silu(x) 46 | x = self.conv_1(x) 47 | 48 | x = self.groupnorm_2(x) 49 | x = F.silu(x) 50 | x = self.conv_2(x) 51 | 52 | return x + self.residual_layer(residue) 53 | 54 | class Decoder(nn.Sequential): 55 | def __init__(self): 56 | super().__init__( 57 | nn.Conv2d(4, 4, kernel_size=1, padding=0), 58 | nn.Conv2d(4, 512, kernel_size=3, padding=1), 59 | ResidualBlock(512, 512), 60 | AttentionBlock(512), 61 | ResidualBlock(512, 512), 62 | ResidualBlock(512, 512), 63 | ResidualBlock(512, 512), 64 | ResidualBlock(512, 512), 65 | nn.Upsample(scale_factor=2), 66 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 67 | ResidualBlock(512, 512), 68 | ResidualBlock(512, 512), 69 | ResidualBlock(512, 512), 70 | nn.Upsample(scale_factor=2), 71 | nn.Conv2d(512, 512, kernel_size=3, padding=1), 72 | ResidualBlock(512, 256), 73 | ResidualBlock(256, 256), 74 | ResidualBlock(256, 256), 75 | nn.Upsample(scale_factor=2), 76 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 77 | ResidualBlock(256, 128), 78 | ResidualBlock(128, 128), 79 | ResidualBlock(128, 128), 80 | nn.GroupNorm(32, 128), 81 | nn.SiLU(), 82 | nn.Conv2d(128, 3, kernel_size=3, padding=1), 83 | ) 84 | 85 | def forward(self, x): 86 | x /= 0.18215 87 | for module in self: 88 | x = module(x) 89 | return x -------------------------------------------------------------------------------- /stable_diffusion_pytorch/diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .attention import SelfAttention, CrossAttention 5 | 6 | 7 | class TimeEmbedding(nn.Module): 8 | def __init__(self, n_embd): 9 | super().__init__() 10 | self.linear_1 = nn.Linear(n_embd, 4 * n_embd) 11 | self.linear_2 = nn.Linear(4 * n_embd, 4 * n_embd) 12 | 13 | def forward(self, x): 14 | x = self.linear_1(x) 15 | x = F.silu(x) 16 | x = self.linear_2(x) 17 | return x 18 | 19 | class ResidualBlock(nn.Module): 20 | def __init__(self, in_channels, out_channels, n_time=1280): 21 | super().__init__() 22 | self.groupnorm_feature = nn.GroupNorm(32, in_channels) 23 | self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 24 | self.linear_time = nn.Linear(n_time, out_channels) 25 | 26 | self.groupnorm_merged = nn.GroupNorm(32, out_channels) 27 | self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 28 | 29 | if in_channels == out_channels: 30 | self.residual_layer = nn.Identity() 31 | else: 32 | self.residual_layer = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) 33 | 34 | def forward(self, feature, time): 35 | residue = feature 36 | 37 | feature = self.groupnorm_feature(feature) 38 | feature = F.silu(feature) 39 | feature = self.conv_feature(feature) 40 | 41 | time = F.silu(time) 42 | time = self.linear_time(time) 43 | 44 | merged = feature + 
time.unsqueeze(-1).unsqueeze(-1) 45 | merged = self.groupnorm_merged(merged) 46 | merged = F.silu(merged) 47 | merged = self.conv_merged(merged) 48 | 49 | return merged + self.residual_layer(residue) 50 | 51 | class AttentionBlock(nn.Module): 52 | def __init__(self, n_head: int, n_embd: int, d_context=768): 53 | super().__init__() 54 | channels = n_head * n_embd 55 | 56 | self.groupnorm = nn.GroupNorm(32, channels, eps=1e-6) 57 | self.conv_input = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 58 | 59 | self.layernorm_1 = nn.LayerNorm(channels) 60 | self.attention_1 = SelfAttention(n_head, channels, in_proj_bias=False) 61 | self.layernorm_2 = nn.LayerNorm(channels) 62 | self.attention_2 = CrossAttention(n_head, channels, d_context, in_proj_bias=False) 63 | self.layernorm_3 = nn.LayerNorm(channels) 64 | self.linear_geglu_1 = nn.Linear(channels, 4 * channels * 2) 65 | self.linear_geglu_2 = nn.Linear(4 * channels, channels) 66 | 67 | self.conv_output = nn.Conv2d(channels, channels, kernel_size=1, padding=0) 68 | 69 | def forward(self, x, context): 70 | residue_long = x 71 | 72 | x = self.groupnorm(x) 73 | x = self.conv_input(x) 74 | 75 | n, c, h, w = x.shape 76 | x = x.view((n, c, h * w)) # (n, c, hw) 77 | x = x.transpose(-1, -2) # (n, hw, c) 78 | 79 | residue_short = x 80 | x = self.layernorm_1(x) 81 | x = self.attention_1(x) 82 | x += residue_short 83 | 84 | residue_short = x 85 | x = self.layernorm_2(x) 86 | x = self.attention_2(x, context) 87 | x += residue_short 88 | 89 | residue_short = x 90 | x = self.layernorm_3(x) 91 | x, gate = self.linear_geglu_1(x).chunk(2, dim=-1) 92 | x = x * F.gelu(gate) 93 | x = self.linear_geglu_2(x) 94 | x += residue_short 95 | 96 | x = x.transpose(-1, -2) # (n, c, hw) 97 | x = x.view((n, c, h, w)) # (n, c, h, w) 98 | 99 | return self.conv_output(x) + residue_long 100 | 101 | class Upsample(nn.Module): 102 | def __init__(self, channels): 103 | super().__init__() 104 | self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 105 | 106 | def forward(self, x): 107 | x = F.interpolate(x, scale_factor=2, mode='nearest') 108 | return self.conv(x) 109 | 110 | class SwitchSequential(nn.Sequential): 111 | def forward(self, x, context, time): 112 | for layer in self: 113 | if isinstance(layer, AttentionBlock): 114 | x = layer(x, context) 115 | elif isinstance(layer, ResidualBlock): 116 | x = layer(x, time) 117 | else: 118 | x = layer(x) 119 | return x 120 | 121 | class UNet(nn.Module): 122 | def __init__(self): 123 | super().__init__() 124 | self.encoders = nn.ModuleList([ 125 | SwitchSequential(nn.Conv2d(4, 320, kernel_size=3, padding=1)), 126 | SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)), 127 | SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)), 128 | SwitchSequential(nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)), 129 | SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)), 130 | SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)), 131 | SwitchSequential(nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)), 132 | SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)), 133 | SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)), 134 | SwitchSequential(nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)), 135 | SwitchSequential(ResidualBlock(1280, 1280)), 136 | SwitchSequential(ResidualBlock(1280, 1280)), 137 | ]) 138 | self.bottleneck = SwitchSequential( 139 | ResidualBlock(1280, 1280), 140 | AttentionBlock(8, 160), 141 | ResidualBlock(1280, 
1280), 142 | ) 143 | self.decoders = nn.ModuleList([ 144 | SwitchSequential(ResidualBlock(2560, 1280)), 145 | SwitchSequential(ResidualBlock(2560, 1280)), 146 | SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)), 147 | SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)), 148 | SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)), 149 | SwitchSequential(ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)), 150 | SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)), 151 | SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)), 152 | SwitchSequential(ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)), 153 | SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)), 154 | SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)), 155 | SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)), 156 | ]) 157 | 158 | def forward(self, x, context, time): 159 | skip_connections = [] 160 | for layers in self.encoders: 161 | x = layers(x, context, time) 162 | skip_connections.append(x) 163 | 164 | x = self.bottleneck(x, context, time) 165 | 166 | for layers in self.decoders: 167 | x = torch.cat((x, skip_connections.pop()), dim=1) 168 | x = layers(x, context, time) 169 | 170 | return x 171 | 172 | 173 | class FinalLayer(nn.Module): 174 | def __init__(self, in_channels, out_channels): 175 | super().__init__() 176 | self.groupnorm = nn.GroupNorm(32, in_channels) 177 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 178 | 179 | def forward(self, x): 180 | x = self.groupnorm(x) 181 | x = F.silu(x) 182 | x = self.conv(x) 183 | return x 184 | 185 | class Diffusion(nn.Module): 186 | def __init__(self): 187 | super().__init__() 188 | self.time_embedding = TimeEmbedding(320) 189 | self.unet = UNet() 190 | self.final = FinalLayer(320, 4) 191 | 192 | def forward(self, latent, context, time): 193 | time = self.time_embedding(time) 194 | output = self.unet(latent, context, time) 195 | output = self.final(output) 196 | return output -------------------------------------------------------------------------------- /stable_diffusion_pytorch/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from .decoder import AttentionBlock, ResidualBlock 5 | 6 | 7 | class Encoder(nn.Sequential): 8 | def __init__(self): 9 | super().__init__( 10 | nn.Conv2d(3, 128, kernel_size=3, padding=1), 11 | ResidualBlock(128, 128), 12 | ResidualBlock(128, 128), 13 | nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0), 14 | ResidualBlock(128, 256), 15 | ResidualBlock(256, 256), 16 | nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0), 17 | ResidualBlock(256, 512), 18 | ResidualBlock(512, 512), 19 | nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0), 20 | ResidualBlock(512, 512), 21 | ResidualBlock(512, 512), 22 | ResidualBlock(512, 512), 23 | AttentionBlock(512), 24 | ResidualBlock(512, 512), 25 | nn.GroupNorm(32, 512), 26 | nn.SiLU(), 27 | nn.Conv2d(512, 8, kernel_size=3, padding=1), 28 | nn.Conv2d(8, 8, kernel_size=1, padding=0), 29 | ) 30 | 31 | def forward(self, x, noise): 32 | for module in self: 33 | if getattr(module, 'stride', None) == (2, 2): # Padding at downsampling should be asymmetric (see #8) 34 | x = F.pad(x, (0, 1, 0, 1)) 35 | x = module(x) 36 | 37 | mean, log_variance = torch.chunk(x, 2, dim=1) 38 | log_variance = torch.clamp(log_variance, -30, 20) 39 | 
variance = log_variance.exp() 40 | stdev = variance.sqrt() 41 | x = mean + stdev * noise 42 | 43 | x *= 0.18215 44 | return x 45 | 46 | -------------------------------------------------------------------------------- /stable_diffusion_pytorch/model_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from . import Tokenizer, CLIP, Encoder, Decoder, Diffusion 3 | from . import util 4 | import warnings 5 | 6 | 7 | def make_compatible(state_dict): 8 | keys = list(state_dict.keys()) 9 | changed = False 10 | for key in keys: 11 | if "causal_attention_mask" in key: 12 | del state_dict[key] 13 | changed = True 14 | elif "_proj_weight" in key: 15 | new_key = key.replace('_proj_weight', '_proj.weight') 16 | state_dict[new_key] = state_dict[key] 17 | del state_dict[key] 18 | changed = True 19 | elif "_proj_bias" in key: 20 | new_key = key.replace('_proj_bias', '_proj.bias') 21 | state_dict[new_key] = state_dict[key] 22 | del state_dict[key] 23 | changed = True 24 | 25 | if changed: 26 | warnings.warn(("Given checkpoint data were modified dynamically by make_compatible" 27 | " function on model_loader.py. Maybe this happened because you're" 28 | " running newer codes with older checkpoint files. This behavior" 29 | " (modify old checkpoints and notify rather than throw an error)" 30 | " will be removed soon, so please download latest checkpoints file.")) 31 | 32 | return state_dict 33 | 34 | def load_clip(device): 35 | state_dict = torch.load(util.get_file_path('ckpt/clip.pt')) 36 | state_dict = make_compatible(state_dict) 37 | 38 | clip = CLIP().to(device) 39 | clip.load_state_dict(state_dict) 40 | return clip 41 | 42 | def load_encoder(device): 43 | state_dict = torch.load(util.get_file_path('ckpt/encoder.pt')) 44 | state_dict = make_compatible(state_dict) 45 | 46 | encoder = Encoder().to(device) 47 | encoder.load_state_dict(state_dict) 48 | return encoder 49 | 50 | def load_decoder(device): 51 | state_dict = torch.load(util.get_file_path('ckpt/decoder.pt')) 52 | state_dict = make_compatible(state_dict) 53 | 54 | decoder = Decoder().to(device) 55 | decoder.load_state_dict(state_dict) 56 | return decoder 57 | 58 | def load_diffusion(device): 59 | state_dict = torch.load(util.get_file_path('ckpt/diffusion.pt')) 60 | state_dict = make_compatible(state_dict) 61 | 62 | diffusion = Diffusion().to(device) 63 | diffusion.load_state_dict(state_dict) 64 | return diffusion 65 | 66 | def preload_models(device): 67 | return { 68 | 'clip': load_clip(device), 69 | 'encoder': load_encoder(device), 70 | 'decoder': load_decoder(device), 71 | 'diffusion': load_diffusion(device), 72 | } -------------------------------------------------------------------------------- /stable_diffusion_pytorch/pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from PIL import Image 4 | from tqdm import tqdm 5 | from . import Tokenizer 6 | from . import KLMSSampler, KEulerSampler, KEulerAncestralSampler 7 | from . import util 8 | from . import model_loader 9 | 10 | 11 | def generate( 12 | prompts, 13 | uncond_prompts=None, 14 | input_images=None, 15 | strength=0.8, 16 | do_cfg=True, 17 | cfg_scale=7.5, 18 | height=512, 19 | width=512, 20 | sampler="k_lms", 21 | n_inference_steps=50, 22 | models={}, 23 | seed=None, 24 | device=None, 25 | idle_device=None 26 | ): 27 | r""" 28 | Function invoked when calling the pipeline for generation. 
29 | Args: 30 | prompts (`List[str]`): 31 | The prompts to guide the image generation. 32 | uncond_prompts (`List[str]`, *optional*, defaults to `[""] * len(prompts)`): 33 | The prompts not to guide the image generation. Ignored when not using guidance (i.e. ignored if 34 | `do_cfg` is False). 35 | input_images (List[Union[`PIL.Image.Image`, str]]): 36 | Images which are served as the starting point for the image generation. 37 | strength (`float`, *optional*, defaults to 0.8): 38 | Conceptually, indicates how much to transform the reference `input_images`. Must be between 0 and 1. 39 | `input_images` will be used as a starting point, adding more noise to it the larger the `strength`. 40 | The number of denoising steps depends on the amount of noise initially added. When `strength` is 1, 41 | added noise will be maximum and the denoising process will run for the full number of iterations 42 | specified in `n_inference_steps`. A value of 1, therefore, essentially ignores `input_images`. 43 | do_cfg (`bool`, *optional*, defaults to True): 44 | Enable [classifier-free guidance](https://arxiv.org/abs/2207.12598). 45 | cfg_scale (`float`, *optional*, defaults to 7.5): 46 | Guidance scale of classifier-free guidance. Ignored when it is disabled (i.e. ignored if 47 | `do_cfg` is False). Higher guidance scale encourages to generate images that are closely linked 48 | to the text `prompt`, usually at the expense of lower image quality. 49 | height (`int`, *optional*, defaults to 512): 50 | The height in pixels of the generated image. Ignored when `input_images` are provided. 51 | width (`int`, *optional*, defaults to 512): 52 | The width in pixels of the generated image. Ignored when `input_images` are provided. 53 | sampler (`str`, *optional*, defaults to "k_lms"): 54 | A sampler to be used to denoise the encoded image latents. Can be one of `"k_lms"`, `"k_euler"`, 55 | or `"k_euler_ancestral"`. 56 | n_inference_steps (`int`, *optional*, defaults to 50): 57 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 58 | expense of slower inference. This parameter will be modulated by `strength`. 59 | models (`Dict[str, torch.nn.Module]`, *optional*): 60 | Preloaded models. If some or all models are not provided, they will be loaded dynamically. 61 | seed (`int`, *optional*): 62 | A seed to make generation deterministic. 63 | device (`str` or `torch.device`, *optional*): 64 | PyTorch device which the image generation happens. If not provided, 'cuda' or 'cpu' will be used. 65 | idle_device (`str` or `torch.device`, *optional*): 66 | PyTorch device which the models no longer in use are moved to. 67 | Returns: 68 | `List[PIL.Image.Image]`: 69 | The generated images. 70 | Note: 71 | This docstring is heavily copied from huggingface/diffusers. 
72 |     """
73 |     with torch.no_grad():
74 |         if not isinstance(prompts, (list, tuple)) or not prompts:
75 |             raise ValueError("prompts must be a non-empty list or tuple")
76 | 
77 |         if uncond_prompts and not isinstance(uncond_prompts, (list, tuple)):
78 |             raise ValueError("uncond_prompts must be a non-empty list or tuple if provided")
79 |         if uncond_prompts and len(prompts) != len(uncond_prompts):
80 |             raise ValueError("length of uncond_prompts must be same as length of prompts")
81 |         uncond_prompts = uncond_prompts or [""] * len(prompts)
82 | 
83 |         if input_images and not isinstance(input_images, (list, tuple)):
84 |             raise ValueError("input_images must be a non-empty list or tuple if provided")
85 |         if input_images and len(prompts) != len(input_images):
86 |             raise ValueError("length of input_images must be same as length of prompts")
87 |         if not 0 < strength < 1:
88 |             raise ValueError("strength must be between 0 and 1")
89 | 
90 |         if height % 8 or width % 8:
91 |             raise ValueError("height and width must be a multiple of 8")
92 | 
93 |         if device is None:
94 |             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
95 | 
96 |         if idle_device:
97 |             to_idle = lambda x: x.to(idle_device)
98 |         else:
99 |             to_idle = lambda x: x
100 | 
101 |         generator = torch.Generator(device=device)
102 |         if seed is None:
103 |             generator.seed()
104 |         else:
105 |             generator.manual_seed(seed)
106 | 
107 |         tokenizer = Tokenizer()
108 |         clip = models.get('clip') or model_loader.load_clip(device)
109 |         clip.to(device)
110 | 
111 |         # use the dtype of the model weights as our dtype
112 |         dtype = clip.embedding.position_value.dtype
113 |         if do_cfg:
114 |             cond_tokens = tokenizer.encode_batch(prompts)
115 |             cond_tokens = torch.tensor(cond_tokens, dtype=torch.long, device=device)
116 |             cond_context = clip(cond_tokens)
117 |             uncond_tokens = tokenizer.encode_batch(uncond_prompts)
118 |             uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=device)
119 |             uncond_context = clip(uncond_tokens)
120 |             context = torch.cat([cond_context, uncond_context])
121 |         else:
122 |             tokens = tokenizer.encode_batch(prompts)
123 |             tokens = torch.tensor(tokens, dtype=torch.long, device=device)
124 |             context = clip(tokens)
125 |         to_idle(clip)
126 |         del tokenizer, clip
127 | 
128 |         if sampler == "k_lms":
129 |             sampler = KLMSSampler(n_inference_steps=n_inference_steps)
130 |         elif sampler == "k_euler":
131 |             sampler = KEulerSampler(n_inference_steps=n_inference_steps)
132 |         elif sampler == "k_euler_ancestral":
133 |             sampler = KEulerAncestralSampler(n_inference_steps=n_inference_steps,
134 |                                              generator=generator)
135 |         else:
136 |             raise ValueError(
137 |                 "Unknown sampler value %s.
" 138 | "Accepted values are {k_lms, k_euler, k_euler_ancestral}" 139 | % sampler 140 | ) 141 | 142 | noise_shape = (len(prompts), 4, height // 8, width // 8) 143 | 144 | if input_images: 145 | encoder = models.get('encoder') or model_loader.load_encoder(device) 146 | encoder.to(device) 147 | processed_input_images = [] 148 | for input_image in input_images: 149 | if type(input_image) is str: 150 | input_image = Image.open(input_image) 151 | 152 | input_image = input_image.resize((width, height)) 153 | input_image = np.array(input_image) 154 | input_image = torch.tensor(input_image, dtype=dtype) 155 | input_image = util.rescale(input_image, (0, 255), (-1, 1)) 156 | processed_input_images.append(input_image) 157 | input_images_tensor = torch.stack(processed_input_images).to(device) 158 | input_images_tensor = util.move_channel(input_images_tensor, to="first") 159 | 160 | _, _, height, width = input_images_tensor.shape 161 | noise_shape = (len(prompts), 4, height // 8, width // 8) 162 | 163 | encoder_noise = torch.randn(noise_shape, generator=generator, device=device, dtype=dtype) 164 | latents = encoder(input_images_tensor, encoder_noise) 165 | 166 | latents_noise = torch.randn(noise_shape, generator=generator, device=device, dtype=dtype) 167 | sampler.set_strength(strength=strength) 168 | latents += latents_noise * sampler.initial_scale 169 | 170 | to_idle(encoder) 171 | del encoder, processed_input_images, input_images_tensor, latents_noise 172 | else: 173 | latents = torch.randn(noise_shape, generator=generator, device=device, dtype=dtype) 174 | latents *= sampler.initial_scale 175 | 176 | diffusion = models.get('diffusion') or model_loader.load_diffusion(device) 177 | diffusion.to(device) 178 | 179 | timesteps = tqdm(sampler.timesteps) 180 | for i, timestep in enumerate(timesteps): 181 | time_embedding = util.get_time_embedding(timestep, dtype).to(device) 182 | 183 | input_latents = latents * sampler.get_input_scale() 184 | if do_cfg: 185 | input_latents = input_latents.repeat(2, 1, 1, 1) 186 | 187 | output = diffusion(input_latents, context, time_embedding) 188 | if do_cfg: 189 | output_cond, output_uncond = output.chunk(2) 190 | output = cfg_scale * (output_cond - output_uncond) + output_uncond 191 | 192 | latents = sampler.step(latents, output) 193 | 194 | to_idle(diffusion) 195 | del diffusion 196 | 197 | decoder = models.get('decoder') or model_loader.load_decoder(device) 198 | decoder.to(device) 199 | images = decoder(latents) 200 | to_idle(decoder) 201 | del decoder 202 | 203 | images = util.rescale(images, (-1, 1), (0, 255), clamp=True) 204 | images = util.move_channel(images, to="last") 205 | images = images.to('cpu', torch.uint8).numpy() 206 | 207 | return [Image.fromarray(image) for image in images] 208 | -------------------------------------------------------------------------------- /stable_diffusion_pytorch/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .k_lms import KLMSSampler 2 | from .k_euler import KEulerSampler 3 | from .k_euler_ancestral import KEulerAncestralSampler -------------------------------------------------------------------------------- /stable_diffusion_pytorch/samplers/k_euler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .. 
import util 3 | 4 | 5 | class KEulerSampler(): 6 | def __init__(self, n_inference_steps=50, n_training_steps=1000): 7 | timesteps = np.linspace(n_training_steps - 1, 0, n_inference_steps) 8 | 9 | alphas_cumprod = util.get_alphas_cumprod(n_training_steps=n_training_steps) 10 | sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 11 | log_sigmas = np.log(sigmas) 12 | log_sigmas = np.interp(timesteps, range(n_training_steps), log_sigmas) 13 | sigmas = np.exp(log_sigmas) 14 | sigmas = np.append(sigmas, 0) 15 | 16 | self.sigmas = sigmas 17 | self.initial_scale = sigmas.max() 18 | self.timesteps = timesteps 19 | self.n_inference_steps = n_inference_steps 20 | self.n_training_steps = n_training_steps 21 | self.step_count = 0 22 | 23 | def get_input_scale(self, step_count=None): 24 | if step_count is None: 25 | step_count = self.step_count 26 | sigma = self.sigmas[step_count] 27 | return 1 / (sigma ** 2 + 1) ** 0.5 28 | 29 | def set_strength(self, strength=1): 30 | start_step = self.n_inference_steps - int(self.n_inference_steps * strength) 31 | self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps) 32 | self.timesteps = self.timesteps[start_step:] 33 | self.initial_scale = self.sigmas[start_step] 34 | self.step_count = start_step 35 | 36 | def step(self, latents, output): 37 | t = self.step_count 38 | self.step_count += 1 39 | 40 | sigma_from = self.sigmas[t] 41 | sigma_to = self.sigmas[t + 1] 42 | latents += output * (sigma_to - sigma_from) 43 | return latents -------------------------------------------------------------------------------- /stable_diffusion_pytorch/samplers/k_euler_ancestral.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from .. 
import util 4 | 5 | 6 | class KEulerAncestralSampler(): 7 | def __init__(self, n_inference_steps=50, n_training_steps=1000, generator=None): 8 | timesteps = np.linspace(n_training_steps - 1, 0, n_inference_steps) 9 | 10 | alphas_cumprod = util.get_alphas_cumprod(n_training_steps=n_training_steps) 11 | sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 12 | log_sigmas = np.log(sigmas) 13 | log_sigmas = np.interp(timesteps, range(n_training_steps), log_sigmas) 14 | sigmas = np.exp(log_sigmas) 15 | sigmas = np.append(sigmas, 0) 16 | 17 | self.sigmas = sigmas 18 | self.initial_scale = sigmas.max() 19 | self.timesteps = timesteps 20 | self.n_inference_steps = n_inference_steps 21 | self.n_training_steps = n_training_steps 22 | self.step_count = 0 23 | self.generator = generator 24 | 25 | def get_input_scale(self, step_count=None): 26 | if step_count is None: 27 | step_count = self.step_count 28 | sigma = self.sigmas[step_count] 29 | return 1 / (sigma ** 2 + 1) ** 0.5 30 | 31 | def set_strength(self, strength=1): 32 | start_step = self.n_inference_steps - int(self.n_inference_steps * strength) 33 | self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps) 34 | self.timesteps = self.timesteps[start_step:] 35 | self.initial_scale = self.sigmas[start_step] 36 | self.step_count = start_step 37 | 38 | def step(self, latents, output): 39 | t = self.step_count 40 | self.step_count += 1 41 | 42 | sigma_from = self.sigmas[t] 43 | sigma_to = self.sigmas[t + 1] 44 | sigma_up = sigma_to * (1 - (sigma_to ** 2 / sigma_from ** 2)) ** 0.5 45 | sigma_down = sigma_to ** 2 / sigma_from 46 | latents += output * (sigma_down - sigma_from) 47 | noise = torch.randn( 48 | latents.shape, generator=self.generator, device=latents.device) 49 | latents += noise * sigma_up 50 | return latents -------------------------------------------------------------------------------- /stable_diffusion_pytorch/samplers/k_lms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .. 
import util 3 | 4 | 5 | class KLMSSampler(): 6 | def __init__(self, n_inference_steps=50, n_training_steps=1000, lms_order=4): 7 | timesteps = np.linspace(n_training_steps - 1, 0, n_inference_steps) 8 | 9 | alphas_cumprod = util.get_alphas_cumprod(n_training_steps=n_training_steps) 10 | sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 11 | log_sigmas = np.log(sigmas) 12 | log_sigmas = np.interp(timesteps, range(n_training_steps), log_sigmas) 13 | sigmas = np.exp(log_sigmas) 14 | sigmas = np.append(sigmas, 0) 15 | 16 | self.sigmas = sigmas 17 | self.initial_scale = sigmas.max() 18 | self.timesteps = timesteps 19 | self.n_inference_steps = n_inference_steps 20 | self.n_training_steps = n_training_steps 21 | self.lms_order = lms_order 22 | self.step_count = 0 23 | self.outputs = [] 24 | 25 | def get_input_scale(self, step_count=None): 26 | if step_count is None: 27 | step_count = self.step_count 28 | sigma = self.sigmas[step_count] 29 | return 1 / (sigma ** 2 + 1) ** 0.5 30 | 31 | def set_strength(self, strength=1): 32 | start_step = self.n_inference_steps - int(self.n_inference_steps * strength) 33 | self.timesteps = np.linspace(self.n_training_steps - 1, 0, self.n_inference_steps) 34 | self.timesteps = self.timesteps[start_step:] 35 | self.initial_scale = self.sigmas[start_step] 36 | self.step_count = start_step 37 | 38 | def step(self, latents, output): 39 | t = self.step_count 40 | self.step_count += 1 41 | 42 | self.outputs = [output] + self.outputs[:self.lms_order - 1] 43 | order = len(self.outputs) 44 | for i, output in enumerate(self.outputs): 45 | # Integrate polynomial by trapezoidal approx. method for 81 points. 46 | x = np.linspace(self.sigmas[t], self.sigmas[t + 1], 81) 47 | y = np.ones(81) 48 | for j in range(order): 49 | if i == j: 50 | continue 51 | y *= x - self.sigmas[t - j] 52 | y /= self.sigmas[t - i] - self.sigmas[t - j] 53 | lms_coeff = np.trapz(y=y, x=x) 54 | latents += lms_coeff * output 55 | return latents -------------------------------------------------------------------------------- /stable_diffusion_pytorch/tokenizer.py: -------------------------------------------------------------------------------- 1 | import unicodedata 2 | import functools 3 | import itertools 4 | import json 5 | from typing import List, Tuple 6 | import regex as re 7 | from . 
import util
8 | 
9 | 
10 | def create_bytes_table() -> dict:
11 |     table = {}
12 |     special_count = 0
13 |     for byte in range(256):
14 |         category = unicodedata.category(chr(byte))
15 |         if category[0] not in ['C', 'Z']:  # ith character is NOT control char or space
16 |             table[byte] = chr(byte)
17 |         else:  # ith character IS control char or space
18 |             table[byte] = chr(special_count + 256)
19 |             special_count += 1
20 |     return table
21 | 
22 | def pairwise(seq):
23 |     a = iter(seq)
24 |     b = iter(seq)
25 |     next(b)
26 |     return zip(a, b)
27 | 
28 | class Tokenizer:
29 |     def __init__(self, ):
30 |         with open(util.get_file_path('vocab.json'), encoding='utf-8') as f:
31 |             self.vocab = json.load(f)
32 | 
33 |         with open(util.get_file_path('merges.txt'), encoding='utf-8') as f:
34 |             lines = f.read().split('\n')
35 |             lines = lines[1:-1]
36 |         self.merges = {tuple(bigram.split()): i for i, bigram in enumerate(lines)}
37 | 
38 |         self.bos_token = self.vocab["<|startoftext|>"]
39 |         self.eos_token = self.vocab["<|endoftext|>"]
40 |         self.pad_token = self.vocab["<|endoftext|>"]
41 |         self.max_length = 77
42 |         self.bytes_table = create_bytes_table()
43 |         self.chunk_pattern = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
44 | 
45 |     def encode(self, text: str) -> List[int]:
46 |         text = unicodedata.normalize('NFC', text)
47 |         text = re.sub(r'\s+', ' ', text)
48 |         text = text.strip()
49 |         text = text.lower()
50 | 
51 |         tokens = [self.bos_token]
52 |         for chunk in re.findall(self.chunk_pattern, text):
53 |             chunk = ''.join(self.bytes_table[byte] for byte in chunk.encode('utf-8'))
54 |             tokens.extend(self.vocab[word] for word in self.bpe(chunk))
55 |         tokens.append(self.eos_token)
56 | 
57 |         tokens = tokens[:self.max_length]
58 |         token_length = len(tokens)
59 |         pad_length = self.max_length - token_length
60 |         tokens += [self.pad_token] * pad_length
61 |         return tokens
62 | 
63 |     def encode_batch(self, texts: List[str]) -> List[List[int]]:
64 |         return [self.encode(text) for text in texts]
65 | 
66 |     @functools.lru_cache(maxsize=10000)
67 |     def bpe(self, chunk: str) -> Tuple[str]:
68 |         words = list(chunk)
69 |         words[-1] += "</w>"
70 | 
71 |         while len(words) > 1:
72 |             valid_pairs = [pair for pair in pairwise(words) if pair in self.merges]
73 |             if not valid_pairs:
74 |                 break
75 | 
76 |             bigram = min(valid_pairs, key=lambda pair: self.merges[pair])
77 |             first, second = bigram
78 | 
79 |             new_words = []
80 |             for word in words:
81 |                 if word == second and new_words and new_words[-1] == first:
82 |                     new_words[-1] = first + second
83 |                 else:
84 |                     new_words.append(word)
85 |             words = new_words
86 | 
87 |         return tuple(words)
--------------------------------------------------------------------------------
/stable_diffusion_pytorch/util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | 
5 | 
6 | def get_time_embedding(timestep, dtype):
7 |     freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=dtype) / 160)
8 |     x = torch.tensor([timestep], dtype=dtype)[:, None] * freqs[None]
9 |     return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)
10 | 
11 | def get_alphas_cumprod(beta_start=0.00085, beta_end=0.0120, n_training_steps=1000):
12 |     betas = np.linspace(beta_start ** 0.5, beta_end ** 0.5, n_training_steps, dtype=np.float32) ** 2
13 |     alphas = 1.0 - betas
14 |     alphas_cumprod = np.cumprod(alphas, axis=0)
15 |     return alphas_cumprod
16 | 
17 | def get_file_path(filename, url=None):
18 | 
module_location = os.path.dirname(os.path.abspath(__file__)) 19 | parent_location = os.path.dirname(module_location) 20 | file_location = os.path.join(parent_location, "data", filename) 21 | return file_location 22 | 23 | def move_channel(image, to): 24 | if to == "first": 25 | return image.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) 26 | elif to == "last": 27 | return image.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) 28 | else: 29 | raise ValueError("to must be one of the following: first, last") 30 | 31 | def rescale(x, old_range, new_range, clamp=False): 32 | old_min, old_max = old_range 33 | new_min, new_max = new_range 34 | x -= old_min 35 | x *= (new_max - new_min) / (old_max - old_min) 36 | x += new_min 37 | if clamp: 38 | x = x.clamp(new_min, new_max) 39 | return x 40 | --------------------------------------------------------------------------------