├── requirements.txt ├── assets ├── SpeedUp.jpg ├── forest1.jpg ├── forest2.jpg ├── forest3.jpg ├── woman.jpg └── Average-Speed.jpg ├── hyper_tile ├── __init__.py ├── utils.py └── hyper_tile.py ├── .gitignore ├── setup.py ├── LICENSE ├── README.md └── playground.ipynb /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | einops -------------------------------------------------------------------------------- /assets/SpeedUp.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/HyperTile/HEAD/assets/SpeedUp.jpg -------------------------------------------------------------------------------- /assets/forest1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/HyperTile/HEAD/assets/forest1.jpg -------------------------------------------------------------------------------- /assets/forest2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/HyperTile/HEAD/assets/forest2.jpg -------------------------------------------------------------------------------- /assets/forest3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/HyperTile/HEAD/assets/forest3.jpg -------------------------------------------------------------------------------- /assets/woman.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/HyperTile/HEAD/assets/woman.jpg -------------------------------------------------------------------------------- /hyper_tile/__init__.py: -------------------------------------------------------------------------------- 1 | from .hyper_tile import split_attention 2 | from .utils import flush 3 | -------------------------------------------------------------------------------- /assets/Average-Speed.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tfernd/HyperTile/HEAD/assets/Average-Speed.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .ipynb_checkpoints 3 | 4 | # Dev 5 | /dev 6 | format.bat 7 | /*.jpg 8 | /*.jpeg 9 | /*.png 10 | 11 | # stabilidty 12 | /ldm 13 | /config 14 | 15 | # notebook 16 | /*.ipynb 17 | !/playground.ipynb 18 | 19 | # build 20 | /build 21 | /dist 22 | /*.egg-info -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="hyper_tile", 5 | version="0.1.5", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "torch", 9 | "einops", 10 | ], 11 | author="Thales Fernandes", 12 | author_email="thalesfdfernandes@gmail.com", 13 | description="Tiled-optimizations for Stable-Diffusion", 14 | url="https://github.com/tfernd/HyperTile", 15 | classifiers=[ 16 | "Programming Language :: Python :: 3", 17 | "Operating System :: OS Independent", 18 | ], 19 | ) 20 | -------------------------------------------------------------------------------- /hyper_tile/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
annotations 2 | 3 | import gc 4 | import torch 5 | import random 6 | 7 | 8 | def flush() -> None: 9 | gc.collect() 10 | if torch.cuda.is_available(): 11 | torch.cuda.empty_cache() 12 | torch.cuda.ipc_collect() 13 | 14 | 15 | def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int: 16 | min_value = min(min_value, value) 17 | 18 | # All big divisors of value (inclusive) 19 | divisors = [i for i in range(min_value, value + 1) if value % i == 0] 20 | 21 | ns = [value // i for i in divisors[:max_options]] # has at least 1 element 22 | 23 | idx = random.randint(0, len(ns) - 1) 24 | 25 | return ns[idx] 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Thales Fernandes 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HyperTile: Tiled-optimizations for Stable-Diffusion 2 | 3 | HyperTile optimizes the self-attention layers within the Stable-Diffusion U-Net and VAE models, yielding a speed-up of 1 to 4 times depending on the initial resolution and tile size. The implementation is **exceptionally** straightforward. 4 | 5 | To get started with HyperTile and experiment using the Jupyter notebook, follow these steps: 6 | 7 | 1. Clone the repository: 8 | 9 | ```bash 10 | git clone https://github.com/tfernd/HyperTile 11 | cd HyperTile 12 | ``` 13 | 14 | 2. Open the Jupyter notebook `playground.ipynb` (install _jupyter_ if you don't have it installed already). 15 | 16 | ```bash 17 | jupyter-notebook playground.ipynb 18 | ``` 19 | 20 | Alternatively, you can install HyperTile using pip: 21 | 22 | ```bash 23 | pip install git+https://github.com/tfernd/HyperTile 24 | ``` 25 | 26 | ## Interested in Integrating It into Your Preferred Web UI? 27 | 28 | You can seamlessly incorporate this functionality with just three lines of code: 29 | 30 | ```python 31 | from hyper_tile import split_attention 32 | 33 | with split_attention(vae, height, width, vae_chunk): 34 | with split_attention(unet, height, width, unet_chunk): 35 | # Continue with the rest of your code, including the diffusion process 36 | ``` 37 | 38 | By adjusting the `vae_chunk` and `unet_chunk` sizes, you can fine-tune your setup according to your specific requirements. For Stable-Diffusion 1.5, it's advisable to keep the chunk size at 256 or 384 for the U-Net, and 128 for the VAE.
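For a more concrete picture, here is a minimal sketch of the same pattern wrapped around a `diffusers` pipeline. It is written against the `split_attention` signature defined in `hyper_tile/hyper_tile.py` (which takes `aspect_ratio = width/height`, `tile_size`, and `swap_size`); the model ID, prompt, and resolution are placeholders, not part of HyperTile itself:

```python
import torch
from diffusers import StableDiffusionPipeline

from hyper_tile import split_attention, flush

# Placeholder checkpoint: any SD 1.5 model should work here
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

height, width = 1536, 1024  # multiples of 128 divide into tiles more evenly

# Attention is tiled only while the context managers are active
with split_attention(pipe.vae, aspect_ratio=width / height, tile_size=128, swap_size=1):
    with split_attention(pipe.unet, aspect_ratio=width / height, tile_size=256, swap_size=2):
        flush()  # free cached CUDA memory before the large generation
        image = pipe(
            prompt="a forest path, detailed",
            height=height, width=width,
            num_inference_steps=30,
        ).images[0]

image.save("output.png")
```

Note that `playground.ipynb` passes `height` and `width` positionally instead of an aspect ratio, so double-check the signature of the version you have installed before copying either form.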
39 | 40 | ## Examples 41 | 42 | All examples are from images found on the internet or from my own generations. They were upscaled with `loopback=2` and a strength between 0.3 and 0.4. 43 | 44 | **Note**: The only reason I'm using loopback is that I'm using a naive upscaler from PIL (Lanczos), which makes images very blurry. 45 | 46 | Woman in a dress: 512x768 -> 1664x2560 47 | ![woman](assets/woman.jpg) 48 | 49 | Forest 1: 681x503 -> 2816x2048 50 | ![forest1](assets/forest1.jpg) 51 | 52 | Forest 2: 768x384 -> 3072x1536 53 | ![forest2](assets/forest2.jpg) 54 | 55 | Forest 3: 512x768 -> 1664x2560 56 | ![forest3](assets/forest3.jpg) 57 | 58 | ## Performance 59 | 60 | In this performance evaluation, I conducted three image-generation experiments, each consisting of 30 steps. I used the diffusers backend in PyTorch 2.0.1, with the assistance of [SDPA](https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html). The generated images are square, with their dimension varying along the x-axis. The black dots represent speed measurements without tiling, while the colored dots represent tiled runs, each corresponding to a specific dimension-to-chunk ratio (size/chunk), with a minimum tile size of 128. 61 | 62 | ![Average Speed](assets/Average-Speed.jpg) 63 | 64 | The subsequent graph illustrates the speed-up achieved for each tile ratio. As the target image dimension increases, the potential speed-up becomes more substantial. 65 | 66 | ![Speed-Up](assets/SpeedUp.jpg) 67 | 68 | It's important to note that, currently, I have exclusively tested with the diffusers backend due to its superior performance. Additionally, there is currently no LoRA model available for HD resolution that is compatible with diffusers. Consequently, text-to-image generation, whether tiled or non-tiled, may exhibit aberrations. Addressing this issue necessitates a fine-tuned LoRA model specifically tailored for high-resolution images with HyperTile enabled. 69 | 70 | ## Limitations 71 | 72 | - Stable-Diffusion's training data is based on 512 x 512 images, limiting its effectiveness for larger images. However, you can use an image-to-image approach with reduced `strength` to add details to larger images, which is especially beneficial when achieving a 3-4 times speed-up for very large images, typically in the 3-4K resolution range. 73 | 74 | - When working at 4K resolution with 16 GB VRAM, the diffusion process functions properly. However, the VAE implementation within `diffusers` struggles to decode the latents, even with the sliced-VAE option enabled, resulting in out-of-memory errors. Further investigation into this issue is warranted. 75 | 76 | - In some cases, you may notice soft tiles in the images, which are not conventional hard tiles. These soft tiles may contain more detail. One potential mitigation strategy is to alternate the tile sizes: for example, use a set of smaller tiles initially and then gradually transition to slightly larger ones, or use larger tiles at the beginning and smaller ones towards the end of the process. This approach is still under exploration. 77 | 78 | ## Future 79 | 80 | - Try tiling the second depth of the U-Net with bigger chunks, or train a LoRA to remove the resulting artifacts. Is it worth it? 81 | 82 | - Identify other areas of the U-Net that can be tiled.
83 | 84 | - Tile Rotation: With each function call, a varying tile size is employed to prevent any overlap-related concerns in some special circunstances. 85 | -------------------------------------------------------------------------------- /hyper_tile/hyper_tile.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from typing import Callable 3 | from typing_extensions import Literal 4 | 5 | import logging 6 | from functools import wraps 7 | from contextlib import contextmanager 8 | 9 | import math 10 | import torch.nn as nn 11 | from einops import rearrange 12 | 13 | from .utils import random_divisor 14 | 15 | 16 | # TODO add SD-XL layers 17 | DEPTH_LAYERS = { 18 | 0: [ 19 | # SD 1.5 U-Net (diffusers) 20 | "down_blocks.0.attentions.0.transformer_blocks.0.attn1", 21 | "down_blocks.0.attentions.1.transformer_blocks.0.attn1", 22 | "up_blocks.3.attentions.0.transformer_blocks.0.attn1", 23 | "up_blocks.3.attentions.1.transformer_blocks.0.attn1", 24 | "up_blocks.3.attentions.2.transformer_blocks.0.attn1", 25 | # SD 1.5 U-Net (ldm) 26 | "input_blocks.1.1.transformer_blocks.0.attn1", 27 | "input_blocks.2.1.transformer_blocks.0.attn1", 28 | "output_blocks.9.1.transformer_blocks.0.attn1", 29 | "output_blocks.10.1.transformer_blocks.0.attn1", 30 | "output_blocks.11.1.transformer_blocks.0.attn1", 31 | # SD 1.5 VAE 32 | "decoder.mid_block.attentions.0", 33 | ], 34 | 1: [ 35 | # SD 1.5 U-Net (diffusers) 36 | "down_blocks.1.attentions.0.transformer_blocks.0.attn1", 37 | "down_blocks.1.attentions.1.transformer_blocks.0.attn1", 38 | "up_blocks.2.attentions.0.transformer_blocks.0.attn1", 39 | "up_blocks.2.attentions.1.transformer_blocks.0.attn1", 40 | "up_blocks.2.attentions.2.transformer_blocks.0.attn1", 41 | # SD 1.5 U-Net (ldm) 42 | "input_blocks.4.1.transformer_blocks.0.attn1", 43 | "input_blocks.5.1.transformer_blocks.0.attn1", 44 | "output_blocks.6.1.transformer_blocks.0.attn1", 45 | "output_blocks.7.1.transformer_blocks.0.attn1", 46 | "output_blocks.8.1.transformer_blocks.0.attn1", 47 | ], 48 | 2: [ 49 | # SD 1.5 U-Net (diffusers) 50 | "down_blocks.2.attentions.0.transformer_blocks.0.attn1", 51 | "down_blocks.2.attentions.1.transformer_blocks.0.attn1", 52 | "up_blocks.1.attentions.0.transformer_blocks.0.attn1", 53 | "up_blocks.1.attentions.1.transformer_blocks.0.attn1", 54 | "up_blocks.1.attentions.2.transformer_blocks.0.attn1", 55 | # SD 1.5 U-Net (ldm) 56 | "input_blocks.7.1.transformer_blocks.0.attn1", 57 | "input_blocks.8.1.transformer_blocks.0.attn1", 58 | "output_blocks.3.1.transformer_blocks.0.attn1", 59 | "output_blocks.4.1.transformer_blocks.0.attn1", 60 | "output_blocks.5.1.transformer_blocks.0.attn1", 61 | ], 62 | 3: [ 63 | # SD 1.5 U-Net (diffusers) 64 | "mid_block.attentions.0.transformer_blocks.0.attn1", 65 | # SD 1.5 U-Net (ldm) 66 | "middle_block.1.transformer_blocks.0.attn1", 67 | ], 68 | } 69 | 70 | 71 | @contextmanager 72 | def split_attention( 73 | layer: nn.Module, 74 | /, 75 | aspect_ratio: float, # width/height 76 | tile_size: int = 256, # 128 for VAE 77 | swap_size: int = 2, # 1 for VAE 78 | *, 79 | disable: bool = False, 80 | max_depth: Literal[0, 1, 2, 3] = 0, # ! 
Try 0 or 1 81 | scale_depth: bool = False, # scale the tile-size depending on the depth 82 | ): 83 | # Hijacks AttnBlock from ldm and Attention from diffusers 84 | 85 | if disable: 86 | logging.info(f"Attention for {layer.__class__.__qualname__} not splitted") 87 | yield 88 | return 89 | 90 | latent_tile_size = max(32, tile_size) // 8 91 | 92 | def self_attn_forward(forward: Callable, depth: int, layer_name: str, module: nn.Module) -> Callable: 93 | @wraps(forward) 94 | def wrapper(*args, **kwargs): 95 | x = args[0] 96 | 97 | # VAE 98 | if x.ndim == 4: 99 | b, c, h, w = x.shape 100 | 101 | nh = random_divisor(h, latent_tile_size, swap_size) 102 | nw = random_divisor(w, latent_tile_size, swap_size) 103 | 104 | if nh * nw > 1: 105 | x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw) 106 | 107 | out = forward(x, *args[1:], **kwargs) 108 | 109 | if nh * nw > 1: 110 | out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw) 111 | 112 | # U-Net 113 | else: 114 | hw = x.size(1) 115 | h, w = round(math.sqrt(hw * aspect_ratio)), round(math.sqrt(hw / aspect_ratio)) 116 | 117 | factor = 2**depth if scale_depth else 1 118 | nh = random_divisor(h, latent_tile_size * factor, swap_size) 119 | nw = random_divisor(w, latent_tile_size * factor, swap_size) 120 | 121 | module._split_sizes.append((nh, nw)) # type: ignore 122 | 123 | if nh * nw > 1: 124 | x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw) 125 | 126 | out = forward(x, *args[1:], **kwargs) 127 | 128 | if nh * nw > 1: 129 | out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw) 130 | out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw) 131 | 132 | return out 133 | 134 | return wrapper 135 | 136 | # Handle hikajing the forward method and recovering afterwards 137 | try: 138 | for depth in range(max_depth + 1): 139 | for layer_name, module in layer.named_modules(): 140 | if any(layer_name.endswith(try_name) for try_name in DEPTH_LAYERS[depth]): 141 | logging.info(f"HyperTile hijacking attention layer at depth {depth}: {layer_name}") 142 | 143 | # save original forward for recovery later 144 | setattr(module, "_original_forward", module.forward) 145 | setattr(module, "forward", self_attn_forward(module.forward, depth, layer_name, module)) 146 | 147 | setattr(module, "_split_sizes", []) 148 | yield 149 | finally: 150 | for layer_name, module in layer.named_modules(): 151 | # remove hijack 152 | if hasattr(module, "_original_forward"): 153 | if module._split_sizes: 154 | logging.debug(f"layer {layer_name} splitted with ({module._split_sizes})") 155 | 156 | setattr(module, "forward", module._original_forward) 157 | del module._original_forward 158 | del module._split_sizes 159 | -------------------------------------------------------------------------------- /playground.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# HyperTile Playground" 8 | ] 9 | }, 10 | { 11 | "cell_type": "markdown", 12 | "metadata": {}, 13 | "source": [ 14 | "I'll presume you have some familiarity with the `diffusers` library.\n", 15 | "\n", 16 | "I encourage you to experiment with both the `text2img` and `img2img` variations. Keep in mind that due to the absence of an HD-LoRA model for `diffusers`, the `text2img` results may exhibit suboptimal structures. 
However, it's essential to recognize that these limitations stem from the initial training of SD at 512x512 resolution, and improvements are anticipated in the future.\n", 17 | "\n", 18 | "If you find yourself needing further information, don't hesitate to consult the comprehensive documentation provided by `diffusers`; it offers valuable insights and guidance." 19 | ] 20 | }, 21 | { 22 | "cell_type": "markdown", 23 | "metadata": {}, 24 | "source": [ 25 | "## Introduction" 26 | ] 27 | }, 28 | { 29 | "cell_type": "markdown", 30 | "metadata": {}, 31 | "source": [ 32 | "Initialize the packages we need." 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "metadata": {}, 39 | "outputs": [], 40 | "source": [ 41 | "from PIL import Image\n", 42 | "from tqdm.auto import trange\n", 43 | "import logging\n", 44 | "\n", 45 | "import torch\n", 46 | "\n", 47 | "from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, StableDiffusionXLPipeline\n", 48 | "from diffusers.schedulers import UniPCMultistepScheduler\n", 49 | "\n", 50 | "from diffusers.utils import load_image\n", 51 | "\n", 52 | "from hyper_tile import split_attention, flush\n", 53 | "\n", 54 | "# To log attention-splitting\n", 55 | "logging.basicConfig(level=logging.INFO)" 56 | ] 57 | }, 58 | { 59 | "cell_type": "markdown", 60 | "metadata": {}, 61 | "source": [ 62 | "Select the path to the model you want (*safetensors*)." 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "dtype = torch.float16 # bfloat16 can result of out-of-memory with big images due to interpolation limits, well-document in diffusers library\n", 72 | "device = torch.device('cuda')\n", 73 | "\n", 74 | "model_path = r\"path-to-model.safetensors\"" 75 | ] 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "metadata": {}, 80 | "source": [ 81 | "## Text-to-Image" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "pipe: StableDiffusionPipeline = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=dtype, local_files_only=True, use_safetensors=True, load_safety_checker=False) # type: ignore\n", 91 | "pipe.to(device)\n", 92 | "pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)" 93 | ] 94 | }, 95 | { 96 | "cell_type": "markdown", 97 | "metadata": {}, 98 | "source": [ 99 | "1. Choose your desired `height` and `width`.\n", 100 | "\n", 101 | "2. You have the flexibility to adjust the `tile_size` independently for the VAE and UNet components. For the VAE, a `tile_size` of 128 is optimal without sacrificing performance. However, for the UNet, it's advisable to use a chunk size of 256 or greater. `swap_size` determine how many different tiles per dimension are used, to avoid overlap seams in some cases.\n", 102 | "\n", 103 | "3. Modify the `disable` parameter to either True or False to observe the results with or without HyperTile.\n", 104 | "\n", 105 | "**Note**: For improved chunk division, consider using dimensions that are multiples of 128. This can enhance the effectiveness of the chunking process. 
(This is enforced)" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "# Try lower value if you dont have 16 Gb VRAM\n", 115 | "\n", 116 | "height, width = 2688, 1792\n", 117 | "\n", 118 | "height = int(height)//128*128 # enforcing multiples of 128\n", 119 | "width = int(width)//128*128\n", 120 | "print(height, width)\n", 121 | "\n", 122 | "with split_attention(pipe.vae, height, width, tile_size=128):\n", 123 | " # ! Change the tile_size and disable to see their effects\n", 124 | " with split_attention(pipe.unet, height, width, tile_size=128, swap_size=2, disable=False):\n", 125 | " flush()\n", 126 | " img = pipe(\n", 127 | " # ! Change the prompt and other parameters\n", 128 | " prompt='forest, path, stone, red trees, detailed, buildings', \n", 129 | " negative_prompt='blurry, low detail',\n", 130 | " num_inference_steps=26, guidance_scale=7.5, \n", 131 | " height=height, width=width,\n", 132 | " ).images[0]\n", 133 | "img" 134 | ] 135 | }, 136 | { 137 | "cell_type": "markdown", 138 | "metadata": {}, 139 | "source": [ 140 | "## Image-to-Image" 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": null, 146 | "metadata": {}, 147 | "outputs": [], 148 | "source": [ 149 | "pipe: StableDiffusionImg2ImgPipeline = StableDiffusionImg2ImgPipeline.from_single_file(model_path, torch_dtype=dtype, local_files_only=True, use_safetensors=True, load_safety_checker=False) # type: ignore\n", 150 | "pipe.to(device)\n", 151 | "pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)" 152 | ] 153 | }, 154 | { 155 | "cell_type": "markdown", 156 | "metadata": {}, 157 | "source": [ 158 | "Load the image that you want." 159 | ] 160 | }, 161 | { 162 | "cell_type": "code", 163 | "execution_count": null, 164 | "metadata": {}, 165 | "outputs": [], 166 | "source": [ 167 | "image = load_image(\"image.png\")\n", 168 | "ar = image.height / image.width" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "1. Define the target `height` and the number of loopbacks required, indicating how many times we perform image-to-image operations. This parameter is essential as we employ simple Lanczos upscaling.\n", 176 | "\n", 177 | "2. Adjust the `strength` and `loopback` settings to achieve the optimal outcome. You can experiment with lower strength values paired with more loopbacks or larger strength values with fewer loopbacks.\n", 178 | "\n", 179 | "3. Customize the `tile_size` separately for the VAE and UNet components. A `tile_size` of 128 is recommended for the VAE without compromising quality. For the UNet, it's advisable to use a `tile_size` size of 256 or greater. `swap_size` determine how many different tiles per dimension are used, to avoid overlap seams in some cases.\n", 180 | "\n", 181 | "4. Toggle the `disable` parameter between True and False to observe the results with or without the use of HyperTile.\n", 182 | "\n", 183 | "**Note**: For improved chunk division, consider using dimensions that are multiples of 128. This practice enhances the efficiency of the chunking process.\n", 184 | "\n", 185 | "**Note**: The inclusion of loopbacks is essential due to the original training of Stable-Diffusion (SD) on 512x512 images. When we upscale these images 3-4 times or more, the use of Lanczos upscaling introduces blurriness. 
Loopbacks play a crucial role in mitigating this issue, effectively restoring image clarity and preserving details during the upscaling process." 186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "metadata": {}, 192 | "outputs": [], 193 | "source": [ 194 | "loopback = 2 # Use 1 or 2, depending on how much upscaling we are doing\n", 195 | "# Try lower value if you dont have 16 Gb VRAM\n", 196 | "height = 512*3\n", 197 | "\n", 198 | "width = height/ar\n", 199 | "height = int(height)//128*128 # enforcing multiples of 128\n", 200 | "width = int(width)//128*128\n", 201 | "print(height, width)\n", 202 | "\n", 203 | "# Upscale to the correct resolution\n", 204 | "img = image.resize((width, height), resample=Image.LANCZOS) if image.size != (width, height) else image\n", 205 | "\n", 206 | "with split_attention(pipe.vae, height, width, tile_size=128):\n", 207 | " # ! Change the chunk and disable to see their effects\n", 208 | " with split_attention(pipe.unet, height, width, tile_size=256, swap_size=2, disable=False):\n", 209 | " flush()\n", 210 | " for i in trange(loopback):\n", 211 | " img = pipe(\n", 212 | " prompt='forest, path, stone, red trees, detailed', \n", 213 | " negative_prompt='blurry, low detail',\n", 214 | " num_inference_steps=28, guidance_scale=7.5, \n", 215 | " image=img, strength=0.46, # ! you can also change the strength\n", 216 | " ).images[0]\n", 217 | "img" 218 | ] 219 | } 220 | ], 221 | "metadata": { 222 | "kernelspec": { 223 | "display_name": "torch", 224 | "language": "python", 225 | "name": "python3" 226 | }, 227 | "language_info": { 228 | "codemirror_mode": { 229 | "name": "ipython", 230 | "version": 3 231 | }, 232 | "file_extension": ".py", 233 | "mimetype": "text/x-python", 234 | "name": "python", 235 | "nbconvert_exporter": "python", 236 | "pygments_lexer": "ipython3", 237 | "version": "3.10.11" 238 | }, 239 | "orig_nbformat": 4 240 | }, 241 | "nbformat": 4, 242 | "nbformat_minor": 2 243 | } 244 | --------------------------------------------------------------------------------