├── .github └── workflows │ ├── publish.yml │ └── python-tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── environment.yml ├── examples ├── IIC.png └── linearForest.json ├── generate_grid.py ├── generate_grid_from_json.py ├── mixdiff ├── __init__.py ├── canvas.py ├── extrasmixin.py ├── imgtools.py └── tiling.py ├── requirements.txt ├── setup.py └── tests ├── __init__.py └── canvas_test.py /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish package on PyPI 2 | 3 | on: [push] 4 | 5 | jobs: 6 | test-publish-package: 7 | runs-on: ubuntu-latest 8 | name: publish package on test-PyPI 9 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 10 | 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python 14 | uses: actions/setup-python@v4 15 | with: 16 | python-version: '3.x' 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install build 21 | - name: Get release tag 22 | id: tag 23 | uses: dawidd6/action-get-tag@v1 24 | - name: Build package 25 | run: GITHUB_TAG=${{steps.tag.outputs.tag}} python -m build 26 | - name: Publish package 27 | uses: pypa/gh-action-pypi-publish@release/v1 28 | with: 29 | repository_url: https://test.pypi.org/legacy/ 30 | user: __token__ 31 | password: ${{ secrets.TESTPYPI_API_TOKEN }} 32 | 33 | publish-package: 34 | runs-on: ubuntu-latest 35 | name: publish package on PyPI 36 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 37 | needs: test-publish-package 38 | 39 | steps: 40 | - uses: actions/checkout@v3 41 | - name: Set up Python 42 | uses: actions/setup-python@v4 43 | with: 44 | python-version: '3.x' 45 | - name: Install dependencies 46 | run: | 47 | python -m pip install --upgrade pip 48 | pip install build 49 | - name: Get release tag 50 | id: tag 51 | uses: dawidd6/action-get-tag@v1 52 | - name: Build package 53 | run: GITHUB_TAG=${{steps.tag.outputs.tag}} python -m build 54 | - name: Publish package 55 | uses: pypa/gh-action-pypi-publish@release/v1 56 | with: 57 | user: __token__ 58 | password: ${{ secrets.PYPI_API_TOKEN }} 59 | 60 | test-package-download: 61 | runs-on: ubuntu-latest 62 | strategy: 63 | matrix: 64 | python-version: ["3.8", "3.9", "3.10"] 65 | name: test package works by installing from test-PyPI 66 | if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags') 67 | needs: publish-package 68 | 69 | steps: 70 | - uses: actions/checkout@v3 71 | - name: Erase source code 72 | run: rm -rf mixdiff 73 | - name: Set up Python 74 | uses: actions/setup-python@v4 75 | with: 76 | python-version: ${{ matrix.python-version }} 77 | - name: Get release tag 78 | id: tag 79 | uses: dawidd6/action-get-tag@v1 80 | - name: Install package 81 | run: python -m pip install mixdiff==${{steps.tag.outputs.tag}} 82 | - name: Test with pytest 83 | env: 84 | HUGGING_FACE_HUB_TOKEN: ${{ secrets.huggingface_token }} 85 | run: | 86 | pip install pytest 87 | pytest 88 | -------------------------------------------------------------------------------- /.github/workflows/python-tests.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python 3 | 4 | name: Unit tests 5 | 6 | on: 7 | push: 8 | pull_request: 9 | branches: [ "master" ] 10 | 11 
| jobs: 12 | test-local-environment: 13 | runs-on: ubuntu-latest 14 | defaults: 15 | run: 16 | shell: bash -l {0} 17 | strategy: 18 | fail-fast: false 19 | matrix: 20 | python-version: ["3.8", "3.9", "3.10"] 21 | 22 | steps: 23 | - uses: actions/checkout@v3 24 | - name: Setup Miniconda 25 | uses: conda-incubator/setup-miniconda@v2.2.0 26 | with: 27 | python-version: ${{ matrix.python-version }} 28 | auto-activate-base: false 29 | activate-environment: mixture_of_diffusers 30 | environment-file: environment.yml 31 | - name: Test with pytest 32 | env: 33 | HUGGING_FACE_HUB_TOKEN: ${{ secrets.huggingface_token }} 34 | run: | 35 | conda install pytest && pytest 36 | 37 | test-package: 38 | runs-on: ubuntu-latest 39 | strategy: 40 | fail-fast: false 41 | matrix: 42 | python-version: ["3.8", "3.9", "3.10"] 43 | steps: 44 | - uses: actions/checkout@v3 45 | - name: Set up Python 46 | uses: actions/setup-python@v4 47 | with: 48 | python-version: ${{ matrix.python-version }} 49 | - name: Install dependencies 50 | run: | 51 | python -m pip install --upgrade pip 52 | pip install build 53 | - name: Build package 54 | run: python -m build 55 | - name: Install package 56 | run: | 57 | pip install . 58 | - name: Test with pytest 59 | env: 60 | HUGGING_FACE_HUB_TOKEN: ${{ secrets.huggingface_token }} 61 | run: | 62 | pip install pytest 63 | rm -rf mixdiff # Delete source to test installed package 64 | pytest 65 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | **/__pycache__ 3 | nohup.out 4 | outputs 5 | venv 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Álvaro Barbero Jiménez 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mixture of Diffusers 2 | 3 | ![2022-10-12 15_35_27 305133_A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, s_640x640_schelms_seed7178915308_gc8_steps50](https://user-images.githubusercontent.com/9654655/195362341-bc7766c2-f5c6-40f2-b457-59277aa11027.png) 4 | 5 | [![Unit tests](https://github.com/albarji/mixture-of-diffusers/actions/workflows/python-tests.yml/badge.svg)](https://github.com/albarji/mixture-of-diffusers/actions/workflows/python-tests.yml) 6 | 7 | [![huggingface space](https://camo.githubusercontent.com/00380c35e60d6b04be65d3d94a58332be5cc93779f630bcdfc18ab9a3a7d3388/68747470733a2f2f696d672e736869656c64732e696f2f62616467652f25463025394625413425393725323048756767696e67253230466163652d5370616365732d626c7565)](https://huggingface.co/spaces/albarji/mixture-of-diffusers) 8 | 9 | This repository holds various scripts and tools implementing a method for integrating a mixture of different diffusion processes collaborating to generate a single image. Each diffuser focuses on a particular region on the image, taking into account boundary effects to promote a smooth blending. 10 | 11 | If you prefer a more user friendly graphical interface to use this algorithm, I recommend trying the [Tiled Diffusion & VAE](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111) plugin developed by pkuliyi2015 for [AUTOMATIC1111's stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui). 12 | 13 | ## Motivation 14 | 15 | Current image generation methods, such as Stable Diffusion, struggle to position objects at specific locations. While the content of the generated image (somewhat) reflects the objects present in the prompt, it is difficult to frame the prompt in a way that creates an specific composition. For instance, take a prompt expressing a complex composition such as 16 | 17 | > A charming house in the countryside on the left, 18 | > in the center a dirt road in the countryside crossing pastures, 19 | > on the right an old and rusty giant robot lying on a dirt road, 20 | > by jakub rozalski, 21 | > sunset lighting on the left and center, dark sunset lighting on the right 22 | > elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece 23 | 24 | Out of a sample of 20 Stable Diffusion generations with different seeds, the generated images that align best with the prompt are the following: 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 |
33 | 34 | The method proposed here strives to provide a better tool for image composition by using several diffusion processes in parallel, each configured with a specific prompt and settings, and focused on a particular region of the image. For example, the following are three outputs from this method, using the following prompts from left to right: 35 | 36 | * "**A charming house in the countryside, by jakub rozalski, sunset lighting**, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece" 37 | * "**A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting**, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece" 38 | * "**An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting**, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece" 39 | 40 | ![2022-10-12 15_25_40 021063_A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, s_640x640_schelms_seed9764851938_gc8_steps50](https://user-images.githubusercontent.com/9654655/195362152-6f3af44d-cf8a-494b-8cf8-36acd8f86871.png) 41 | ![2022-10-12 15_32_11 563087_A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, s_640x640_schelms_seed2096547054_gc8_steps50](https://user-images.githubusercontent.com/9654655/195362315-8c2d01a8-62f2-4d96-90ca-9ad22f69398e.png) 42 | ![2022-10-12 15_35_27 305133_A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, s_640x640_schelms_seed7178915308_gc8_steps50](https://user-images.githubusercontent.com/9654655/195362341-bc7766c2-f5c6-40f2-b457-59277aa11027.png) 43 | 44 | The mixture of diffusion processes is done in a way that harmonizes the generation process, preventing "seam" effects in the generated image. 45 | 46 | Using several diffusion processes in parallel has also practical advantages when generating very large images, as the GPU memory requirements are similar to that of generating an image of the size of a single tile. 47 | 48 | ## Usage 49 | 50 | This repository provides two new pipelines, `StableDiffusionTilingPipeline` and `StableDiffusionCanvasPipeline`, that extend the standard Stable Diffusion pipeline from [Diffusers](https://github.com/huggingface/diffusers). They feature new options that allow defining the mixture of diffusers, which are distributed as a number of "diffusion regions" over the image to be generated. `StableDiffusionTilingPipeline` is simpler to use and arranges the diffusion regions as a grid over the canvas, while `StableDiffusionCanvasPipeline` allows a more flexible placement and also features image2image capabilities. 51 | 52 | ### Prerequisites 53 | 54 | Since this work is based on Stable Diffusion models, you will need to [request access and accept the usage terms of Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion#model-access). You will also need to [configure your Hugging Face User Access Token](https://huggingface.co/docs/hub/security-tokens) in your running environment. 
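If you have not set the token up before, two common ways of doing so are sketched below (the token value itself is a placeholder): logging in with the Hugging Face CLI, which caches the token locally, or exporting the `HUGGING_FACE_HUB_TOKEN` environment variable, as this repository's CI workflows do.

```bash
# Option 1: interactive login, caches the token for later use (requires huggingface_hub)
pip install huggingface_hub
huggingface-cli login

# Option 2: expose the token through an environment variable for the current shell session
export HUGGING_FACE_HUB_TOKEN=<your-access-token>
```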
55 | 56 | ### StableDiffusionTilingPipeline 57 | 58 | The header image in this repo can be generated as follows: 59 | 60 | ```python 61 | from diffusers import LMSDiscreteScheduler 62 | from mixdiff import StableDiffusionTilingPipeline 63 | 64 | # Create scheduler and model (similar to StableDiffusionPipeline) 65 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 66 | pipeline = StableDiffusionTilingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True).to("cuda:0") 67 | 68 | # Mixture of Diffusers generation 69 | image = pipeline( 70 | prompt=[[ 71 | "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", 72 | "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece", 73 | "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece" 74 | ]], 75 | tile_height=640, 76 | tile_width=640, 77 | tile_row_overlap=0, 78 | tile_col_overlap=256, 79 | guidance_scale=8, 80 | seed=7178915308, 81 | num_inference_steps=50, 82 | )["sample"][0] 83 | ``` 84 | 85 | The prompts must be provided as a list of lists, where each inner list represents a row of diffusion regions. The geometry of the canvas is inferred from these lists, e.g. in the example above we are creating a grid of 1x3 diffusion regions (1 row and 3 columns). The rest of the parameters specify the size of these regions and how much they overlap with their neighbors. 86 | 87 | Alternatively, it is possible to specify the grid parameters through a JSON configuration file. In the following example a grid of 1x10 tiles (1 row, 10 columns) is configured to generate a forest in changing styles: 88 | 89 | ![gridExampleLabeled](https://user-images.githubusercontent.com/9654655/195371664-54d8a599-25d8-46ba-b823-3c7726ecb6ff.png) 90 | 91 | A `StableDiffusionTilingPipeline` is configured to use 10 prompts with changing styles. Each tile takes a shape of 768x512 pixels, and tiles overlap 256 pixels to avoid seam effects.
All the details are specified in a configuration file: 92 | 93 | ```json 94 | { 95 | "cpu_vae": true, 96 | "gc": 8, 97 | "gc_tiles": null, 98 | "prompt": [ 99 | [ 100 | "a forest, ukiyo-e, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 101 | "a forest, ukiyo-e, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 102 | "a forest, by velazquez, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 103 | "a forest, by velazquez, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 104 | "a forest, impressionist style by van gogh, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 105 | "a forest, impressionist style by van gogh, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 106 | "a forest, cubist style by Pablo Picasso intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 107 | "a forest, cubist style by Pablo Picasso intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 108 | "a forest, 80s synthwave style, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 109 | "a forest, 80s synthwave style, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors" 110 | ] 111 | ], 112 | "scheduler": "lms", 113 | "seed": 639688656, 114 | "steps": 50, 115 | "tile_col_overlap": 256, 116 | "tile_height": 768, 117 | "tile_row_overlap": 256, 118 | "tile_width": 512 119 | } 120 | ``` 121 | 122 | You can try generating this image using this configuration file by running: 123 | 124 | ```bash 125 | python generate_grid_from_json.py examples/linearForest.json 126 | ``` 127 | 128 | The full list of arguments to a `StableDiffusionTilingPipeline` is: 129 | 130 | > `prompt`: either a single string (no tiling) or a list of lists with all the prompts to use (one list for each row of tiles). This will also define the tiling structure. 131 | > 132 | > `num_inference_steps`: number of diffusion steps. 133 | > 134 | > `guidance_scale`: classifier-free guidance scale. 135 | > 136 | > `seed`: general random seed to initialize latents. 137 | > 138 | > `tile_height`: height in pixels of each grid tile. 139 | > 140 | > `tile_width`: width in pixels of each grid tile. 141 | > 142 | > `tile_row_overlap`: number of overlap pixels between tiles in consecutive rows. 143 | > 144 | > `tile_col_overlap`: number of overlap pixels between tiles in consecutive columns. 145 | > 146 | > `guidance_scale_tiles`: specific weights for classifier-free guidance in each tile. If `None`, the value provided in `guidance_scale` will be used. 147 | > 148 | > 149 | > 150 | > `seed_tiles`: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard `seed` parameter. 151 | > 152 | > `seed_tiles_mode`: either `"full"` or `"exclusive"`. If `"full"`, all the latents affected by the tile will be overridden. If `"exclusive"`, only the latents that are affected exclusively by this tile (and no other tiles) will be overridden.
153 | 154 | > `seed_reroll_regions`: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`. 155 | > 156 | > `cpu_vae`: the decoder from latent space to pixel space can require too much GPU RAM for large images. If you run into out-of-memory errors at the end of the generation process, try setting this parameter to `True` to run the decoder in CPU. Slower, but should run without memory issues. 157 | 158 | A script showing a more advanced use of this pipeline is available as [generate_grid.py](./generate_grid.py). 159 | 160 | ### StableDiffusionCanvasPipeline 161 | 162 | The `StableDiffusionCanvasPipeline` works by defining a list of `Text2ImageRegion` objects that detail the region of influence of each diffuser. As an illustrative example, the heading image of this repo can be generated with the following code: 163 | 164 | ```python 165 | from diffusers import LMSDiscreteScheduler 166 | from mixdiff import StableDiffusionCanvasPipeline, Text2ImageRegion 167 | 168 | # Create scheduler and model (similar to StableDiffusionPipeline) 169 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 170 | pipeline = StableDiffusionCanvasPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True).to("cuda:0") 171 | 172 | # Mixture of Diffusers generation 173 | image = pipeline( 174 | canvas_height=640, 175 | canvas_width=1408, 176 | regions=[ 177 | Text2ImageRegion(0, 640, 0, 640, guidance_scale=8, 178 | prompt=f"A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"), 179 | Text2ImageRegion(0, 640, 384, 1024, guidance_scale=8, 180 | prompt=f"A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"), 181 | Text2ImageRegion(0, 640, 768, 1408, guidance_scale=8, 182 | prompt=f"An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"), 183 | ], 184 | num_inference_steps=50, 185 | seed=7178915308, 186 | )["sample"][0] 187 | ``` 188 | 189 | `Image2Image` regions can also be added at any position, to use a particular image as guidance. In the following example we create a Christmas postcard by taking a photo of a building (available in this repo) and using it as guidance in a region of the canvas.
190 | 191 | ```python 192 | from PIL import Image 193 | from diffusers import LMSDiscreteScheduler 194 | from mixdiff import StableDiffusionCanvasPipeline, Text2ImageRegion, Image2ImageRegion, preprocess_image 195 | 196 | # Create scheduler and model (similar to StableDiffusionPipeline) 197 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 198 | pipeline = StableDiffusionCanvasPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True).to("cuda:0") 199 | 200 | # Load and preprocess guide image 201 | iic_image = preprocess_image(Image.open("examples/IIC.png").convert("RGB")) 202 | 203 | # Mixture of Diffusers generation 204 | image = pipeline( 205 | canvas_height=800, 206 | canvas_width=352, 207 | regions=[ 208 | Text2ImageRegion(0, 800, 0, 352, guidance_scale=8, 209 | prompt=f"Christmas postcard, a charming house in the countryside surrounded by snow, a giant christmas tree, under a starry night sky, by jakub rozalski and alayna danner and guweiz, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"), 210 | Image2ImageRegion(800-352, 800, 0, 352, reference_image=iic_image, strength=0.8), 211 | ], 212 | num_inference_steps=57, 213 | seed=5525475061, 214 | )["sample"][0] 215 | ``` 216 | 217 | ![githubIIC](https://user-images.githubusercontent.com/9654655/218306373-fbae1381-178a-454c-89bf-0c299af4fb96.png) 218 | 219 | The full list of arguments to a `StableDiffusionCanvasPipeline` is: 220 | 221 | > `canvas_height`: height in pixels of the image to generate. Must be a multiple of 8. 222 | > 223 | > `canvas_width`: width in pixels of the image to generate. Must be a multiple of 8. 224 | > 225 | > `regions`: list of `Text2Image` or `Image2Image` diffusion regions (see below). 226 | > 227 | > `num_inference_steps`: number of diffusion steps. 228 | > 229 | > `seed`: general random seed to initialize latents. 230 | > 231 | > `reroll_regions`: list of `RerollRegion` regions in which to reroll latents (see below). Useful if you like the overall look of the generated image, but want to regenerate a specific region using a different random seed. 232 | > 233 | > `cpu_vae`: whether to perform encoder-decoder operations in CPU, even if the diffusion process runs in GPU. Use `cpu_vae=True` if you run out of GPU memory at the end of the generation process for large canvas dimensions, or if you create large `Image2Image` regions. 234 | > 235 | > `decode_steps`: if `True`, the result will include not only the final image, but also all the intermediate steps in the generation. Note: this will greatly increase running times. 236 | 237 | All regions are configured with the following parameters: 238 | 239 | > `row_init`: starting row in pixel space (included). Must be a multiple of 8. 240 | > 241 | > `row_end`: end row in pixel space (not included). Must be a multiple of 8. 242 | > 243 | > `col_init`: starting column in pixel space (included). Must be a multiple of 8. 244 | > 245 | > `col_end`: end column in pixel space (not included). Must be a multiple of 8. 246 | > 247 | > `region_seed`: seed for random operations in this region. 248 | > 249 | > `noise_eps`: deviation of a zero-mean gaussian noise to be applied over the latents in this region.
Useful for slightly "rerolling" latents. 250 | 251 | Additionally, `Text2Image` regions use the following arguments: 252 | 253 | > `prompt`: text prompt guiding the diffuser in this region. 254 | > 255 | > `guidance_scale`: guidance scale of the diffuser in this region. If `None`, a random value will be used. 256 | > 257 | > `mask_type`: kind of weight mask applied to this region, must be one of `["constant", "gaussian", "quartic"]`. 258 | > 259 | > `mask_weight`: global weights multiplier of the mask. 260 | 261 | `Image2Image` regions are configured with the basic region parameters plus the following: 262 | 263 | > `reference_image`: image to use as guidance. Must be loaded as a PIL image and pre-processed using the `preprocess_image` function (see example above). It will be automatically rescaled to the shape of the region. 264 | > 265 | > `strength`: strength of the image guidance, must lie in the range `[0.0, 1.0]` (from no guidance to absolute priority of the original image). 266 | 267 | Finally, `RerollRegion` regions accept the basic arguments plus the following: 268 | 269 | > `reroll_mode`: kind of reroll to perform, either `reset` (completely reset the latents with new ones) or `epsilon` (slightly alter the latents in the region). 270 | 271 | ## Citing and full technical details 272 | 273 | If you find this repository useful, please be so kind as to cite the corresponding paper, which also contains the full details about this method: 274 | 275 | > Álvaro Barbero Jiménez. Mixture of Diffusers for scene composition and high resolution image generation. https://arxiv.org/abs/2302.02412 276 | 277 | ## Responsible use 278 | 279 | The same recommendations as in Stable Diffusion apply, so please check the corresponding [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4). 280 | 281 | More broadly speaking, always bear this in mind: YOU are responsible for the content you create using this tool. Do not fully blame, credit, or place the responsibility on the software. 282 | 283 | ## Gallery 284 | 285 | Here are some relevant illustrations I have created using this software (and putting quite a few hours into them!).
286 | 287 | ### Darkness Dawning 288 | 289 | ![Darkness Dawning](https://images-wixmp-ed30a86b8c4ca887773594c2.wixmp.com/f/cd1358aa-80d5-4c59-b95b-cdfde5dcc4f5/dfidq8n-6da9a886-9f1c-40ae-8341-d77af9552395.png?token=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ1cm46YXBwOjdlMGQxODg5ODIyNjQzNzNhNWYwZDQxNWVhMGQyNmUwIiwiaXNzIjoidXJuOmFwcDo3ZTBkMTg4OTgyMjY0MzczYTVmMGQ0MTVlYTBkMjZlMCIsIm9iaiI6W1t7InBhdGgiOiJcL2ZcL2NkMTM1OGFhLTgwZDUtNGM1OS1iOTViLWNkZmRlNWRjYzRmNVwvZGZpZHE4bi02ZGE5YTg4Ni05ZjFjLTQwYWUtODM0MS1kNzdhZjk1NTIzOTUucG5nIn1dXSwiYXVkIjpbInVybjpzZXJ2aWNlOmZpbGUuZG93bmxvYWQiXX0.ff6XoVBPdUbcTLcuHUpQMPrD2TaXBM_s6HfRhsARDw0) 290 | 291 | ### Yog-Sothoth 292 | 293 | ![Yog-Sothoth](https://images-wixmp-ed30a86b8c4ca887773594c2.wixmp.com/f/cd1358aa-80d5-4c59-b95b-cdfde5dcc4f5/dfidsq4-174dd428-2c5a-48f6-a78f-9441fb3cffea.png?token=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJ1cm46YXBwOjdlMGQxODg5ODIyNjQzNzNhNWYwZDQxNWVhMGQyNmUwIiwiaXNzIjoidXJuOmFwcDo3ZTBkMTg4OTgyMjY0MzczYTVmMGQ0MTVlYTBkMjZlMCIsIm9iaiI6W1t7InBhdGgiOiJcL2ZcL2NkMTM1OGFhLTgwZDUtNGM1OS1iOTViLWNkZmRlNWRjYzRmNVwvZGZpZHNxNC0xNzRkZDQyOC0yYzVhLTQ4ZjYtYTc4Zi05NDQxZmIzY2ZmZWEucG5nIn1dXSwiYXVkIjpbInVybjpzZXJ2aWNlOmZpbGUuZG93bmxvYWQiXX0.X42zWgsk3lYnYwuEgkifRFRH2km-npHvrdleDN3m6bA) 294 | 295 | ### Looking through the eyes of giants 296 | 297 | ![Looking through the eyes of giants](https://user-images.githubusercontent.com/9654655/218307148-95ce88b6-b2a3-458d-b469-daf5bd56e3a7.jpg) 298 | 299 | [Follow me on DeviantArt for more!](https://www.deviantart.com/albarji) 300 | 301 | ## Acknowledgements 302 | 303 | First and foremost, my most sincere appreciation for the [Stable Diffusion team](https://stability.ai/blog/stable-diffusion-public-release) for releasing such an awesome model, and for letting me take part of the closed beta. Kudos also to the Hugging Face community and developers for implementing the [Diffusers library](https://github.com/huggingface/diffusers). 304 | 305 | Thanks to Instituto de Ingeniería del Conocimiento and Grupo de Aprendizaje Automático (Universidad Autónoma de Madrid) for providing GPU resources for testing and experimenting this library. 306 | 307 | Thanks also to the vibrant communities of the Stable Diffusion discord channel and [Lexica](https://lexica.art/), where I have learned about many amazing artists and styles. And to my friend Abril for sharing many tips on cool artists! 
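As a final, combined illustration of the region options documented in the Usage section above, here is a minimal, untested sketch mixing two `Text2ImageRegion` objects with different mask settings and a `RerollRegion`. Prompts, sizes and seeds are made up for illustration; only the class and argument names come from this library.

```python
from diffusers import LMSDiscreteScheduler
from mixdiff import StableDiffusionCanvasPipeline, Text2ImageRegion, RerollRegion

# Create scheduler and model (same setup as in the examples above)
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
pipeline = StableDiffusionCanvasPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True).to("cuda:0")

# Two overlapping text2image regions: a gaussian-masked region on the left and a
# down-weighted constant-masked region on the right, plus a rerolled top-left corner.
# All coordinates are multiples of 8, as required by the regions.
image = pipeline(
    canvas_height=512,
    canvas_width=1024,
    regions=[
        Text2ImageRegion(0, 512, 0, 640, guidance_scale=8,
            mask_type="gaussian", mask_weight=1.0,
            prompt="a snowy mountain range at dawn, highly detailed"),  # illustrative prompt
        Text2ImageRegion(0, 512, 384, 1024, guidance_scale=8,
            mask_type="constant", mask_weight=0.7, noise_eps=0.05,
            prompt="a pine forest under a starry night sky, highly detailed"),  # illustrative prompt
    ],
    # Reset the initial latents of the top-left 256x256 pixel corner using a different seed
    reroll_regions=[RerollRegion(0, 256, 0, 256, region_seed=1234)],
    num_inference_steps=50,
    seed=12345,
)["sample"][0]
image.save("combined_regions_example.png")
```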
308 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: mixture_of_diffusers 2 | dependencies: 3 | - pip=22.1.* 4 | - pip: 5 | - -r requirements.txt 6 | -------------------------------------------------------------------------------- /examples/IIC.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albarji/mixture-of-diffusers/af42292d0a8cb414f6da2eeac79be4c60afbbe48/examples/IIC.png -------------------------------------------------------------------------------- /examples/linearForest.json: -------------------------------------------------------------------------------- 1 | { 2 | "cpu_vae": true, 3 | "gc": 8, 4 | "gc_tiles": null, 5 | "prompt": [ 6 | [ 7 | "a forest, ukiyo-e, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 8 | "a forest, ukiyo-e, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 9 | "a forest, by velazquez, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 10 | "a forest, by velazquez, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 11 | "a forest, impressionist style by van gogh, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 12 | "a forest, impressionist style by van gogh, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 13 | "a forest, cubist style by Pablo Picasso intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 14 | "a forest, cubist style by Pablo Picasso intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 15 | "a forest, 80s synthwave style, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors", 16 | "a forest, 80s synthwave style, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors" 17 | ] 18 | ], 19 | "scheduler": "lms", 20 | "seed": 639688656, 21 | "steps": 50, 22 | "tile_col_overlap": 256, 23 | "tile_height": 768, 24 | "tile_row_overlap": 256, 25 | "tile_width": 512 26 | } -------------------------------------------------------------------------------- /generate_grid.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | import datetime 3 | from diffusers import LMSDiscreteScheduler, DDIMScheduler 4 | import git 5 | import json 6 | import numpy as np 7 | from pathlib import Path 8 | 9 | from mixdiff import StableDiffusionTilingPipeline 10 | 11 | ### CONFIG START 12 | n = 3 13 | sche = "lms" 14 | gc = 8 15 | seed = None 16 | steps = 50 17 | 18 | suffix = "elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece" 19 | prompt = [ 20 | [ 21 | f"A charming house in the countryside, by jakub rozalski, sunset lighting, {suffix}", 22 | f"A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, {suffix}", 23 | f"An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, {suffix}" 24 | ] 25 | ] 26 | 27 | gc_tiles = 
[ 28 | [None, None, None], 29 | ] 30 | 31 | seed_tiles = [ 32 | [None, None, None], 33 | ] 34 | seed_tiles_mode = StableDiffusionTilingPipeline.SeedTilesMode.FULL.value 35 | 36 | seed_reroll_regions = [] 37 | 38 | tile_height = 640 39 | tile_width = 640 40 | tile_row_overlap = 256 41 | tile_col_overlap = 256 42 | cpu_vae = False 43 | 44 | ### CONFIG END 45 | 46 | # Prepared scheduler 47 | if sche == "ddim": 48 | scheduler = DDIMScheduler() 49 | elif sche == "lms": 50 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 51 | else: 52 | raise ValueError(f"Unrecognized scheduler {sche}") 53 | 54 | # Load model 55 | model_id = "CompVis/stable-diffusion-v1-4" 56 | pipelinekind = StableDiffusionTilingPipeline 57 | pipe = pipelinekind.from_pretrained(model_id, scheduler=scheduler, use_auth_token=True).to("cuda:0") 58 | 59 | for _ in range(n): 60 | # Prepare parameters 61 | gc_image = gc if gc is not None else np.random.randint(5, 30) 62 | steps_image = steps if steps is not None else np.random.randint(50, 150) 63 | seed_image = seed if seed is not None else np.random.randint(9999999999) 64 | seed_tiles_image = deepcopy(seed_tiles) 65 | for row in range(len(seed_tiles)): 66 | for col in range(len(seed_tiles[0])): 67 | if seed_tiles[row][col] == "RNG": 68 | seed_tiles_image[row][col] = np.random.randint(9999999999) 69 | seed_reroll_regions_image = deepcopy(seed_reroll_regions) 70 | for i in range(len(seed_reroll_regions)): 71 | row_init, row_end, col_init, col_end, seed_reroll = seed_reroll_regions[i] 72 | if seed_reroll == "RNG": 73 | seed_reroll = np.random.randint(9999999999) 74 | seed_reroll_regions_image[i] = (row_init, row_end, col_init, col_end, seed_reroll) 75 | pipeargs = { 76 | "guidance_scale": gc_image, 77 | "num_inference_steps": steps_image, 78 | "seed": seed_image, 79 | "prompt": prompt, 80 | "tile_height": tile_height, 81 | "tile_width": tile_width, 82 | "tile_row_overlap": tile_row_overlap, 83 | "tile_col_overlap": tile_col_overlap, 84 | "guidance_scale_tiles": gc_tiles, 85 | "seed_tiles": seed_tiles_image, 86 | "seed_tiles_mode": seed_tiles_mode, 87 | "seed_reroll_regions": seed_reroll_regions_image, 88 | "cpu_vae": cpu_vae, 89 | } 90 | image = pipe(**pipeargs)["sample"][0] 91 | ct = datetime.datetime.now() 92 | outname = f"{ct}_{prompt[0][0][0:100]}_{tile_height}x{tile_width}_sche{sche}_seed{seed_image}_gc{gc_image}_steps{steps_image}" 93 | outpath = "./outputs" 94 | Path(outpath).mkdir(parents=True, exist_ok=True) 95 | image.save(f"{outpath}/{outname}.png") 96 | logspath = "./logs" 97 | Path(logspath).mkdir(parents=True, exist_ok=True) 98 | with open(f"{logspath}/{outname}.json", "w") as f: 99 | json.dump( 100 | { 101 | "prompt": prompt, 102 | "tile_height": tile_height, 103 | "tile_width": tile_width, 104 | "tile_row_overlap": tile_row_overlap, 105 | "tile_col_overlap": tile_col_overlap, 106 | "scheduler": sche, 107 | "seed": seed_image, 108 | "gc": gc_image, 109 | "gc_tiles": gc_tiles, 110 | "steps": steps_image, 111 | "seed_tiles": seed_tiles_image, 112 | "seed_tiles_mode": seed_tiles_mode, 113 | "seed_reroll_regions": seed_reroll_regions_image, 114 | "cpu_vae": cpu_vae, 115 | "git_commit": git.Repo(search_parent_directories=True).head.object.hexsha, 116 | }, 117 | f, 118 | sort_keys=True, 119 | indent=4 120 | ) 121 | -------------------------------------------------------------------------------- /generate_grid_from_json.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | from diffusers import LMSDiscreteScheduler, DDIMScheduler 4 | import json 5 | from pathlib import Path 6 | import torch 7 | 8 | from mixdiff.tiling import StableDiffusionTilingPipeline 9 | 10 | def generate_grid(generation_arguments): 11 | model_id = "CompVis/stable-diffusion-v1-4" 12 | # Prepared scheduler 13 | if generation_arguments["scheduler"] == "ddim": 14 | scheduler = DDIMScheduler() 15 | elif generation_arguments["scheduler"] == "lms": 16 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000) 17 | else: 18 | raise ValueError(f"Unrecognized scheduler {generation_arguments['scheduler']}") 19 | pipe = StableDiffusionTilingPipeline.from_pretrained(model_id, scheduler=scheduler, use_auth_token=True).to("cuda:0") 20 | 21 | pipeargs = { 22 | "guidance_scale": generation_arguments["gc"], 23 | "num_inference_steps": generation_arguments["steps"], 24 | "seed": generation_arguments["seed"], 25 | "prompt": generation_arguments["prompt"], 26 | "tile_height": generation_arguments["tile_height"], 27 | "tile_width": generation_arguments["tile_width"], 28 | "tile_row_overlap": generation_arguments["tile_row_overlap"], 29 | "tile_col_overlap": generation_arguments["tile_col_overlap"], 30 | "guidance_scale_tiles": generation_arguments["gc_tiles"], 31 | "cpu_vae": generation_arguments["cpu_vae"] if "cpu_vae" in generation_arguments else False, 32 | } 33 | if "seed_tiles" in generation_arguments: pipeargs = {**pipeargs, "seed_tiles": generation_arguments["seed_tiles"]} 34 | if "seed_tiles_mode" in generation_arguments: pipeargs = {**pipeargs, "seed_tiles_mode": generation_arguments["seed_tiles_mode"]} 35 | if "seed_reroll_regions" in generation_arguments: pipeargs = {**pipeargs, "seed_reroll_regions": generation_arguments["seed_reroll_regions"]} 36 | image = pipe(**pipeargs)["sample"][0] 37 | outname = "output" 38 | outpath = "./outputs" 39 | Path(outpath).mkdir(parents=True, exist_ok=True) 40 | image.save(f"{outpath}/{outname}.png") 41 | 42 | if __name__ == "__main__": 43 | parser = argparse.ArgumentParser(description='Generate a stable diffusion grid using a JSON file with all configuration parameters.') 44 | parser.add_argument('config', type=str, help='Path to configuration file') 45 | args = parser.parse_args() 46 | with open(args.config, "r") as f: 47 | generation_arguments = json.load(f) 48 | generate_grid(generation_arguments) 49 | -------------------------------------------------------------------------------- /mixdiff/__init__.py: -------------------------------------------------------------------------------- 1 | from .canvas import Image2ImageRegion, RerollRegion, StableDiffusionCanvasPipeline, Text2ImageRegion 2 | from .imgtools import preprocess_image 3 | from .tiling import StableDiffusionTilingPipeline -------------------------------------------------------------------------------- /mixdiff/canvas.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from dataclasses import asdict, dataclass 3 | from enum import Enum 4 | import numpy as np 5 | from numpy import pi, exp, sqrt 6 | import re 7 | import torch 8 | from torchvision.transforms.functional import resize 9 | from tqdm.auto import tqdm 10 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 11 | from typing import List, Optional, Tuple, Union 12 | 
13 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 14 | from diffusers.pipeline_utils import DiffusionPipeline 15 | from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler 16 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 17 | 18 | 19 | class MaskModes(Enum): 20 | """Modes in which the influence of diffuser is masked""" 21 | CONSTANT = "constant" 22 | GAUSSIAN = "gaussian" 23 | QUARTIC = "quartic" # See https://en.wikipedia.org/wiki/Kernel_(statistics) 24 | 25 | 26 | class RerollModes(Enum): 27 | """Modes in which the reroll regions operate""" 28 | RESET = "reset" # Completely reset the random noise in the region 29 | EPSILON = "epsilon" # Alter slightly the latents in the region 30 | 31 | 32 | @dataclass 33 | class CanvasRegion: 34 | """Class defining a rectangular region in the canvas""" 35 | row_init: int # Region starting row in pixel space (included) 36 | row_end: int # Region end row in pixel space (not included) 37 | col_init: int # Region starting column in pixel space (included) 38 | col_end: int # Region end column in pixel space (not included) 39 | region_seed: int = None # Seed for random operations in this region 40 | noise_eps: float = 0.0 # Deviation of a zero-mean gaussian noise to be applied over the latents in this region. Useful for slightly "rerolling" latents 41 | 42 | def __post_init__(self): 43 | # Initialize arguments if not specified 44 | if self.region_seed is None: 45 | self.region_seed = np.random.randint(9999999999) 46 | # Check coordinates are non-negative 47 | for coord in [self.row_init, self.row_end, self.col_init, self.col_end]: 48 | if coord < 0: 49 | raise ValueError(f"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})") 50 | # Check coordinates are divisible by 8, else we end up with nasty rounding error when mapping to latent space 51 | for coord in [self.row_init, self.row_end, self.col_init, self.col_end]: 52 | if coord // 8 != coord / 8: 53 | raise ValueError(f"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})") 54 | # Check noise eps is non-negative 55 | if self.noise_eps < 0: 56 | raise ValueError(f"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}") 57 | # Compute coordinates for this region in latent space 58 | self.latent_row_init = self.row_init // 8 59 | self.latent_row_end = self.row_end // 8 60 | self.latent_col_init = self.col_init // 8 61 | self.latent_col_end = self.col_end // 8 62 | 63 | @property 64 | def width(self): 65 | return self.col_end - self.col_init 66 | 67 | @property 68 | def height(self): 69 | return self.row_end - self.row_init 70 | 71 | def get_region_generator(self, device="cpu"): 72 | """Creates a torch.Generator based on the random seed of this region""" 73 | # Initialize region generator 74 | return torch.Generator(device).manual_seed(self.region_seed) 75 | 76 | @property 77 | def __dict__(self): 78 | return asdict(self) 79 | 80 | 81 | @dataclass 82 | class DiffusionRegion(CanvasRegion): 83 | """Abstract class defining a region where some class of diffusion process is acting""" 84 | pass 85 | 86 | 87 | @dataclass 88 | class RerollRegion(CanvasRegion): 89 | """Class defining a rectangular canvas region in which initial latent noise will be rerolled""" 90 | reroll_mode: RerollModes = RerollModes.RESET.value 91 | 92 | 93 | @dataclass 94 | class 
Text2ImageRegion(DiffusionRegion): 95 | """Class defining a region where a text guided diffusion process is acting""" 96 | prompt: str = "" # Text prompt guiding the diffuser in this region 97 | guidance_scale: float = 7.5 # Guidance scale of the diffuser in this region. If None, randomize 98 | mask_type: MaskModes = MaskModes.GAUSSIAN.value # Kind of weight mask applied to this region 99 | mask_weight: float = 1.0 # Global weights multiplier of the mask 100 | tokenized_prompt = None # Tokenized prompt 101 | encoded_prompt = None # Encoded prompt 102 | 103 | def __post_init__(self): 104 | super().__post_init__() 105 | # Mask weight cannot be negative 106 | if self.mask_weight < 0: 107 | raise ValueError(f"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}") 108 | # Mask type must be an actual known mask 109 | if self.mask_type not in [e.value for e in MaskModes]: 110 | raise ValueError(f"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({[e.value for e in MaskModes]})") 111 | # Randomize arguments if given as None 112 | if self.guidance_scale is None: 113 | self.guidance_scale = np.random.randint(5, 30) 114 | # Clean prompt 115 | self.prompt = re.sub(' +', ' ', self.prompt).replace("\n", " ") 116 | 117 | def tokenize_prompt(self, tokenizer): 118 | """Tokenizes the prompt for this diffusion region using a given tokenizer""" 119 | self.tokenized_prompt = tokenizer(self.prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt") 120 | 121 | def encode_prompt(self, text_encoder, device): 122 | """Encodes the previously tokenized prompt for this diffusion region using a given encoder""" 123 | assert self.tokenized_prompt is not None, ValueError("Prompt in diffusion region must be tokenized before encoding") 124 | self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0] 125 | 126 | 127 | @dataclass 128 | class Image2ImageRegion(DiffusionRegion): 129 | """Class defining a region where an image guided diffusion process is acting""" 130 | reference_image: torch.FloatTensor = None 131 | strength: float = 0.8 # Strength of the image 132 | 133 | def __post_init__(self): 134 | super().__post_init__() 135 | if self.reference_image is None: 136 | raise ValueError("Must provide a reference image when creating an Image2ImageRegion") 137 | if self.strength < 0 or self.strength > 1: 138 | raise ValueError(f'The value of strength should in [0.0, 1.0] but is {self.strength}') 139 | # Rescale image to region shape 140 | self.reference_image = resize(self.reference_image, size=[self.height, self.width]) 141 | 142 | def encode_reference_image(self, encoder, device, generator, cpu_vae=False): 143 | """Encodes the reference image for this Image2Image region into the latent space""" 144 | # Place encoder in CPU or not following the parameter cpu_vae 145 | if cpu_vae: 146 | # Note here we use mean instead of sample, to avoid moving also generator to CPU, which is troublesome 147 | self.reference_latents = encoder.cpu().encode(self.reference_image).latent_dist.mean.to(device) 148 | else: 149 | self.reference_latents = encoder.encode(self.reference_image.to(device)).latent_dist.sample(generator=generator) 150 | self.reference_latents = 0.18215 * self.reference_latents 151 | 152 | @property 153 | def __dict__(self): 154 | # This class requires special casting to dict because of the reference_image tensor. 
Otherwise it cannot be casted to JSON 155 | 156 | # Get all basic fields from parent class 157 | super_fields = {key: getattr(self, key) for key in DiffusionRegion.__dataclass_fields__.keys()} 158 | # Pack other fields 159 | return { 160 | **super_fields, 161 | "reference_image": self.reference_image.cpu().tolist(), 162 | "strength": self.strength 163 | } 164 | 165 | 166 | @dataclass 167 | class MaskWeightsBuilder: 168 | """Auxiliary class to compute a tensor of weights for a given diffusion region""" 169 | latent_space_dim: int # Size of the U-net latent space 170 | nbatch: int = 1 # Batch size in the U-net 171 | 172 | def compute_mask_weights(self, region: DiffusionRegion) -> torch.tensor: 173 | """Computes a tensor of weights for a given diffusion region""" 174 | MASK_BUILDERS = { 175 | MaskModes.CONSTANT.value: self._constant_weights, 176 | MaskModes.GAUSSIAN.value: self._gaussian_weights, 177 | MaskModes.QUARTIC.value: self._quartic_weights, 178 | } 179 | return MASK_BUILDERS[region.mask_type](region) 180 | 181 | def _constant_weights(self, region: DiffusionRegion) -> torch.tensor: 182 | """Computes a tensor of constant for a given diffusion region""" 183 | latent_width = region.latent_col_end - region.latent_col_init 184 | latent_height = region.latent_row_end - region.latent_row_init 185 | return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight 186 | 187 | def _gaussian_weights(self, region: DiffusionRegion) -> torch.tensor: 188 | """Generates a gaussian mask of weights for tile contributions""" 189 | latent_width = region.latent_col_end - region.latent_col_init 190 | latent_height = region.latent_row_end - region.latent_row_init 191 | 192 | var = 0.01 193 | midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 194 | x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)] 195 | midpoint = (latent_height -1) / 2 196 | y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)] 197 | 198 | weights = np.outer(y_probs, x_probs) * region.mask_weight 199 | return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1)) 200 | 201 | def _quartic_weights(self, region: DiffusionRegion) -> torch.tensor: 202 | """Generates a quartic mask of weights for tile contributions 203 | 204 | The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits. 205 | """ 206 | quartic_constant = 15. / 16. 207 | 208 | support = (np.array(range(region.latent_col_init, region.latent_col_end)) - region.latent_col_init) / (region.latent_col_end - region.latent_col_init - 1) * 1.99 - (1.99 / 2.) 209 | x_probs = quartic_constant * np.square(1 - np.square(support)) 210 | support = (np.array(range(region.latent_row_init, region.latent_row_end)) - region.latent_row_init) / (region.latent_row_end - region.latent_row_init - 1) * 1.99 - (1.99 / 2.) 
211 | y_probs = quartic_constant * np.square(1 - np.square(support)) 212 | 213 | weights = np.outer(y_probs, x_probs) * region.mask_weight 214 | return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1)) 215 | 216 | 217 | class StableDiffusionCanvasPipeline(DiffusionPipeline): 218 | """Stable Diffusion pipeline that mixes several diffusers in the same canvas""" 219 | def __init__( 220 | self, 221 | vae: AutoencoderKL, 222 | text_encoder: CLIPTextModel, 223 | tokenizer: CLIPTokenizer, 224 | unet: UNet2DConditionModel, 225 | scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler], 226 | safety_checker: StableDiffusionSafetyChecker, 227 | feature_extractor: CLIPFeatureExtractor, 228 | ): 229 | super().__init__() 230 | self.register_modules( 231 | vae=vae, 232 | text_encoder=text_encoder, 233 | tokenizer=tokenizer, 234 | unet=unet, 235 | scheduler=scheduler, 236 | safety_checker=safety_checker, 237 | feature_extractor=feature_extractor, 238 | ) 239 | 240 | def decode_latents(self, latents, cpu_vae=False): 241 | """Decodes a given array of latents into pixel space""" 242 | # scale and decode the image latents with vae 243 | if cpu_vae: 244 | lat = deepcopy(latents).cpu() 245 | vae = deepcopy(self.vae).cpu() 246 | else: 247 | lat = latents 248 | vae = self.vae 249 | 250 | lat = 1 / 0.18215 * lat 251 | image = vae.decode(lat).sample 252 | 253 | image = (image / 2 + 0.5).clamp(0, 1) 254 | image = image.cpu().permute(0, 2, 3, 1).numpy() 255 | 256 | return self.numpy_to_pil(image) 257 | 258 | def get_latest_timestep_img2img(self, num_inference_steps, strength): 259 | """Finds the latest timesteps where an img2img strength does not impose latents anymore""" 260 | # get the original timestep using init_timestep 261 | offset = self.scheduler.config.get("steps_offset", 0) 262 | init_timestep = int(num_inference_steps * (1 - strength)) + offset 263 | init_timestep = min(init_timestep, num_inference_steps) 264 | 265 | t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps-1) 266 | latest_timestep = self.scheduler.timesteps[t_start] 267 | 268 | return latest_timestep 269 | 270 | @torch.no_grad() 271 | def __call__( 272 | self, 273 | canvas_height: int, 274 | canvas_width: int, 275 | regions: List[DiffusionRegion], 276 | num_inference_steps: Optional[int] = 50, 277 | seed: Optional[int] = 12345, 278 | reroll_regions: Optional[List[RerollRegion]] = None, 279 | cpu_vae: Optional[bool] = False, 280 | decode_steps: Optional[bool] = False 281 | ): 282 | if reroll_regions is None: 283 | reroll_regions = [] 284 | batch_size = 1 285 | 286 | if decode_steps: 287 | steps_images = [] 288 | 289 | # Prepare scheduler 290 | self.scheduler.set_timesteps(num_inference_steps, device=self.device) 291 | 292 | # Split diffusion regions by their kind 293 | text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)] 294 | image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)] 295 | 296 | # Prepare text embeddings 297 | for region in text2image_regions: 298 | region.tokenize_prompt(self.tokenizer) 299 | region.encode_prompt(self.text_encoder, self.device) 300 | 301 | # Create original noisy latents using the timesteps 302 | latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8) 303 | generator = torch.Generator(self.device).manual_seed(seed) 304 | init_noise = torch.randn(latents_shape, generator=generator, device=self.device) 305 | 306 | # Reset 
latents in seed reroll regions, if requested 307 | for region in reroll_regions: 308 | if region.reroll_mode == RerollModes.RESET.value: 309 | region_shape = (latents_shape[0], latents_shape[1], region.latent_row_end - region.latent_row_init, region.latent_col_end - region.latent_col_init) 310 | init_noise[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] = torch.randn(region_shape, generator=region.get_region_generator(self.device), device=self.device) 311 | 312 | # Apply epsilon noise to regions: first diffusion regions, then reroll regions 313 | all_eps_rerolls = regions + [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value] 314 | for region in all_eps_rerolls: 315 | if region.noise_eps > 0: 316 | region_noise = init_noise[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] 317 | eps_noise = torch.randn(region_noise.shape, generator=region.get_region_generator(self.device), device=self.device) * region.noise_eps 318 | init_noise[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] += eps_noise 319 | 320 | # scale the initial noise by the standard deviation required by the scheduler 321 | latents = init_noise * self.scheduler.init_noise_sigma 322 | 323 | # Get unconditional embeddings for classifier free guidance in text2image regions 324 | for region in text2image_regions: 325 | max_length = region.tokenized_prompt.input_ids.shape[-1] 326 | uncond_input = self.tokenizer( 327 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 328 | ) 329 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 330 | 331 | # For classifier free guidance, we need to do two forward passes. 
332 | # Here we concatenate the unconditional and text embeddings into a single batch 333 | # to avoid doing two forward passes 334 | region.encoded_prompt = torch.cat([uncond_embeddings, region.encoded_prompt]) 335 | 336 | # Prepare image latents 337 | for region in image2image_regions: 338 | region.encode_reference_image(self.vae, device=self.device, generator=generator) 339 | 340 | # Prepare mask of weights for each region 341 | mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size) 342 | mask_weights = [mask_builder.compute_mask_weights(region).to(self.device) for region in text2image_regions] 343 | 344 | # Diffusion timesteps 345 | for i, t in tqdm(enumerate(self.scheduler.timesteps)): 346 | # Diffuse each region 347 | noise_preds_regions = [] 348 | 349 | # text2image regions 350 | for region in text2image_regions: 351 | region_latents = latents[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] 352 | # expand the latents if we are doing classifier free guidance 353 | latent_model_input = torch.cat([region_latents] * 2) 354 | # scale model input following scheduler rules 355 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 356 | # predict the noise residual 357 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)["sample"] 358 | # perform guidance 359 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 360 | noise_pred_region = noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond) 361 | noise_preds_regions.append(noise_pred_region) 362 | 363 | # Merge noise predictions for all tiles 364 | noise_pred = torch.zeros(latents.shape, device=self.device) 365 | contributors = torch.zeros(latents.shape, device=self.device) 366 | # Add each tile contribution to overall latents 367 | for region, noise_pred_region, mask_weights_region in zip(text2image_regions, noise_preds_regions, mask_weights): 368 | noise_pred[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] += noise_pred_region * mask_weights_region 369 | contributors[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] += mask_weights_region 370 | # Average overlapping areas with more than 1 contributor 371 | noise_pred /= contributors 372 | noise_pred = torch.nan_to_num(noise_pred) # Replace NaNs by zeros: NaN can appear if a position is not covered by any DiffusionRegion 373 | 374 | # compute the previous noisy sample x_t -> x_t-1 375 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample 376 | 377 | # Image2Image regions: override latents generated by the scheduler 378 | for region in image2image_regions: 379 | influence_step = self.get_latest_timestep_img2img(num_inference_steps, region.strength) 380 | # Only override in the timesteps before the last influence step of the image (given by its strength) 381 | if t > influence_step: 382 | timestep = t.repeat(batch_size) 383 | region_init_noise = init_noise[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] 384 | region_latents = self.scheduler.add_noise(region.reference_latents, region_init_noise, timestep) 385 | latents[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] = region_latents 386 | 387 | if decode_steps: 388 | steps_images.append(self.decode_latents(latents, cpu_vae)) 389 | 390 | # scale 
and decode the image latents with vae 391 | image = self.decode_latents(latents, cpu_vae) 392 | 393 | output = {"sample": image} 394 | if decode_steps: 395 | output = {**output, "steps_images": steps_images} 396 | return output 397 | -------------------------------------------------------------------------------- /mixdiff/extrasmixin.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | # TODO: remove after adaptaing tiling to use canvas 4 | class StableDiffusionExtrasMixin: 5 | """Mixin providing additional convenience method to Stable Diffusion pipelines""" 6 | 7 | def decode_latents(self, latents, cpu_vae=False): 8 | """Decodes a given array of latents into pixel space""" 9 | # scale and decode the image latents with vae 10 | if cpu_vae: 11 | lat = deepcopy(latents).cpu() 12 | vae = deepcopy(self.vae).cpu() 13 | else: 14 | lat = latents 15 | vae = self.vae 16 | 17 | lat = 1 / 0.18215 * lat 18 | image = vae.decode(lat).sample 19 | 20 | image = (image / 2 + 0.5).clamp(0, 1) 21 | image = image.cpu().permute(0, 2, 3, 1).numpy() 22 | 23 | return self.numpy_to_pil(image) 24 | -------------------------------------------------------------------------------- /mixdiff/imgtools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image, ImageFilter 4 | 5 | 6 | def preprocess_image(image): 7 | """Preprocess an input image 8 | 9 | Same as https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44 10 | """ 11 | w, h = image.size 12 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 13 | image = image.resize((w, h), resample=Image.LANCZOS) 14 | image = np.array(image).astype(np.float32) / 255.0 15 | image = image[None].transpose(0, 3, 1, 2) 16 | image = torch.from_numpy(image) 17 | return 2.0 * image - 1.0 18 | 19 | 20 | def preprocess_mask(mask, smoothing=None): 21 | """Preprocess an inpainting mask""" 22 | mask = mask.convert("L") 23 | if smoothing is not None: 24 | smoothed = mask.filter(ImageFilter.GaussianBlur(smoothing)) 25 | mask = Image.composite(mask, smoothed, mask) # Original mask values kept as 1, out of mask get smoothed 26 | mask.save("outputs/smoothed_mask.png") # FIXME 27 | w, h = mask.size 28 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32 29 | mask = mask.resize((w // 8, h // 8), resample=Image.NEAREST) 30 | mask = np.array(mask).astype(np.float32) / 255.0 31 | mask = np.tile(mask, (4, 1, 1)) 32 | mask = mask[None].transpose(0, 1, 2, 3) # what does this step do? 
33 | mask = 1 - mask # repaint white, keep black 34 | mask = torch.from_numpy(mask) 35 | return mask 36 | -------------------------------------------------------------------------------- /mixdiff/tiling.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | import inspect 3 | from ligo.segments import segment 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import torch 7 | 8 | from tqdm.auto import tqdm 9 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 10 | 11 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 12 | from diffusers.pipeline_utils import DiffusionPipeline 13 | from diffusers.schedulers import DDIMScheduler, PNDMScheduler 14 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker 15 | from diffusers.schedulers import LMSDiscreteScheduler 16 | 17 | from .extrasmixin import StableDiffusionExtrasMixin 18 | 19 | 20 | class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixin): 21 | def __init__( 22 | self, 23 | vae: AutoencoderKL, 24 | text_encoder: CLIPTextModel, 25 | tokenizer: CLIPTokenizer, 26 | unet: UNet2DConditionModel, 27 | scheduler: Union[DDIMScheduler, PNDMScheduler], 28 | safety_checker: StableDiffusionSafetyChecker, 29 | feature_extractor: CLIPFeatureExtractor, 30 | ): 31 | super().__init__() 32 | self.register_modules( 33 | vae=vae, 34 | text_encoder=text_encoder, 35 | tokenizer=tokenizer, 36 | unet=unet, 37 | scheduler=scheduler, 38 | safety_checker=safety_checker, 39 | feature_extractor=feature_extractor, 40 | ) 41 | 42 | class SeedTilesMode(Enum): 43 | """Modes in which the latents of a particular tile can be re-seeded""" 44 | FULL = "full" 45 | EXCLUSIVE = "exclusive" 46 | 47 | @torch.no_grad() 48 | def __call__( 49 | self, 50 | prompt: Union[str, List[List[str]]], 51 | num_inference_steps: Optional[int] = 50, 52 | guidance_scale: Optional[float] = 7.5, 53 | eta: Optional[float] = 0.0, 54 | seed: Optional[int] = None, 55 | tile_height: Optional[int] = 512, 56 | tile_width: Optional[int] = 512, 57 | tile_row_overlap: Optional[int] = 256, 58 | tile_col_overlap: Optional[int] = 256, 59 | guidance_scale_tiles: Optional[List[List[float]]] = None, 60 | seed_tiles: Optional[List[List[int]]] = None, 61 | seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full", 62 | seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None, 63 | cpu_vae: Optional[bool] = False, 64 | ): 65 | 66 | if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt): 67 | raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}") 68 | grid_rows = len(prompt) 69 | grid_cols = len(prompt[0]) 70 | if not all(len(row) == grid_cols for row in prompt): 71 | raise ValueError(f"All prompt rows must have the same number of prompt columns") 72 | if not isinstance(seed_tiles_mode, str) and (not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)): 73 | raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}") 74 | if isinstance(seed_tiles_mode, str): 75 | seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt] 76 | if any(mode not in (modes := [mode.value for mode in self.SeedTilesMode]) for row in seed_tiles_mode for mode in row): 77 | raise ValueError(f"Seed tiles mode must be one of {modes}") 78 | if seed_reroll_regions is None: 79 | seed_reroll_regions = [] 80 | batch_size 
= 1 81 | 82 | # create original noisy latents using the timesteps 83 | height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap) 84 | width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap) 85 | latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8) 86 | generator = torch.Generator("cuda").manual_seed(seed) 87 | latents = torch.randn(latents_shape, generator=generator, device=self.device) 88 | 89 | # overwrite latents for specific tiles if provided 90 | if seed_tiles is not None: 91 | for row in range(grid_rows): 92 | for col in range(grid_cols): 93 | if (seed_tile := seed_tiles[row][col]) is not None: 94 | mode = seed_tiles_mode[row][col] 95 | if mode == self.SeedTilesMode.FULL.value: 96 | row_init, row_end, col_init, col_end = _tile2latent_indices(row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap) 97 | else: 98 | row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices(row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, grid_rows, grid_cols) 99 | tile_generator = torch.Generator("cuda").manual_seed(seed_tile) 100 | tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init) 101 | latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(tile_shape, generator=tile_generator, device=self.device) 102 | 103 | # overwrite again for seed reroll regions 104 | for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions: 105 | row_init, row_end, col_init, col_end = _pixel2latent_indices(row_init, row_end, col_init, col_end) # to latent space coordinates 106 | reroll_generator = torch.Generator("cuda").manual_seed(seed_reroll) 107 | region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init) 108 | latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(region_shape, generator=reroll_generator, device=self.device) 109 | 110 | # Prepare scheduler 111 | accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys()) 112 | extra_set_kwargs = {} 113 | if accepts_offset: 114 | extra_set_kwargs["offset"] = 1 115 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs) 116 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas 117 | if isinstance(self.scheduler, LMSDiscreteScheduler): 118 | latents = latents * self.scheduler.sigmas[0] 119 | 120 | # get prompts text embeddings 121 | text_input = [ 122 | [ 123 | self.tokenizer( 124 | col, 125 | padding="max_length", 126 | max_length=self.tokenizer.model_max_length, 127 | truncation=True, 128 | return_tensors="pt", 129 | ) 130 | for col in row 131 | ] 132 | for row in prompt 133 | ] 134 | text_embeddings = [ 135 | [ 136 | self.text_encoder(col.input_ids.to(self.device))[0] 137 | for col in row 138 | ] 139 | for row in text_input 140 | ] 141 | 142 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 143 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 144 | # corresponds to doing no classifier free guidance. 
145 | do_classifier_free_guidance = guidance_scale > 1.0 # TODO: also active if any tile has guidance scale 146 | # get unconditional embeddings for classifier free guidance 147 | if do_classifier_free_guidance: 148 | for i in range(grid_rows): 149 | for j in range(grid_cols): 150 | max_length = text_input[i][j].input_ids.shape[-1] 151 | uncond_input = self.tokenizer( 152 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt" 153 | ) 154 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] 155 | 156 | # For classifier free guidance, we need to do two forward passes. 157 | # Here we concatenate the unconditional and text embeddings into a single batch 158 | # to avoid doing two forward passes 159 | text_embeddings[i][j] = torch.cat([uncond_embeddings, text_embeddings[i][j]]) 160 | 161 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 162 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 163 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 164 | # and should be between [0, 1] 165 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 166 | extra_step_kwargs = {} 167 | if accepts_eta: 168 | extra_step_kwargs["eta"] = eta 169 | 170 | # Mask for tile weights strenght 171 | tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size) 172 | 173 | # Diffusion timesteps 174 | for i, t in tqdm(enumerate(self.scheduler.timesteps)): 175 | # Diffuse each tile 176 | noise_preds = [] 177 | for row in range(grid_rows): 178 | noise_preds_row = [] 179 | for col in range(grid_cols): 180 | px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap) 181 | tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end] 182 | # expand the latents if we are doing classifier free guidance 183 | latent_model_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents 184 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 185 | # predict the noise residual 186 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[row][col])["sample"] 187 | # perform guidance 188 | if do_classifier_free_guidance: 189 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 190 | guidance = guidance_scale if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None else guidance_scale_tiles[row][col] 191 | noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) 192 | noise_preds_row.append(noise_pred_tile) 193 | noise_preds.append(noise_preds_row) 194 | # Stitch noise predictions for all tiles 195 | noise_pred = torch.zeros(latents.shape, device=self.device) 196 | contributors = torch.zeros(latents.shape, device=self.device) 197 | # Add each tile contribution to overall latents 198 | for row in range(grid_rows): 199 | for col in range(grid_cols): 200 | px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap) 201 | noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += noise_preds[row][col] * tile_weights 202 | contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights 203 | # Average overlapping areas with more than 1 contributor 204 | noise_pred /= contributors 205 | 206 | 
# compute the previous noisy sample x_t -> x_t-1 207 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample 208 | 209 | # scale and decode the image latents with vae 210 | image = self.decode_latents(latents, cpu_vae) 211 | 212 | return {"sample": image} 213 | 214 | def _gaussian_weights(self, tile_width, tile_height, nbatches): 215 | """Generates a gaussian mask of weights for tile contributions""" 216 | from numpy import pi, exp, sqrt 217 | import numpy as np 218 | 219 | latent_width = tile_width // 8 220 | latent_height = tile_height // 8 221 | 222 | var = 0.01 223 | midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1 224 | x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)] 225 | midpoint = latent_height / 2 226 | y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)] 227 | 228 | weights = np.outer(y_probs, x_probs) 229 | return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1)) 230 | 231 | 232 | 233 | def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap): 234 | """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image 235 | 236 | Returns a tuple with: 237 | - Starting coordinates of rows in pixel space 238 | - Ending coordinates of rows in pixel space 239 | - Starting coordinates of columns in pixel space 240 | - Ending coordinates of columns in pixel space 241 | """ 242 | px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap) 243 | px_row_end = px_row_init + tile_height 244 | px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap) 245 | px_col_end = px_col_init + tile_width 246 | return px_row_init, px_row_end, px_col_init, px_col_end 247 | 248 | 249 | def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end): 250 | """Translates coordinates in pixel space to coordinates in latent space""" 251 | return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8 252 | 253 | 254 | def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap): 255 | """Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image 256 | 257 | Returns a tuple with: 258 | - Starting coordinates of rows in latent space 259 | - Ending coordinates of rows in latent space 260 | - Starting coordinates of columns in latent space 261 | - Ending coordinates of columns in latent space 262 | """ 263 | px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap) 264 | return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end) 265 | 266 | 267 | def _tile2latent_exclusive_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns): 268 | """Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image 269 | 270 | Returns a tuple with: 271 | - Starting coordinates of rows in latent space 272 | - Ending coordinates of rows in latent space 273 | - Starting coordinates of columns in latent space 274 | - Ending coordinates of columns in latent space 275 | """ 276 | row_init, row_end, col_init, col_end = 
_tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap) 277 | row_segment = segment(row_init, row_end) 278 | col_segment = segment(col_init, col_end) 279 | # Iterate over the rest of tiles, clipping the region for the current tile 280 | for row in range(rows): 281 | for column in range(columns): 282 | if row != tile_row and column != tile_col: 283 | clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices(row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap) 284 | row_segment = row_segment - segment(clip_row_init, clip_row_end) 285 | col_segment = col_segment - segment(clip_col_init, clip_col_end) 286 | #return row_init, row_end, col_init, col_end 287 | return row_segment[0], row_segment[1], col_segment[0], col_segment[1] 288 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.10.* 2 | diffusers[torch]==0.16.* 3 | ftfy==6.1.* 4 | gitpython==3.1.* 5 | ligo-segments==1.4.* 6 | torch==2.0.* 7 | torchvision==0.15.* 8 | transformers==4.29.* 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | 4 | 5 | # Read long description from readme 6 | with open("README.md", "r", encoding="utf-8") as fh: 7 | LONG_DESCRIPTION = fh.read() 8 | 9 | 10 | # Get tag from Github environment variables 11 | TAG = os.environ['GITHUB_TAG'] if 'GITHUB_TAG' in os.environ else "0.0.0" 12 | 13 | 14 | setup( 15 | name="mixdiff", 16 | version=TAG, 17 | description="Mixture of Diffusers for scene composition and high resolution image generation .", 18 | long_description=LONG_DESCRIPTION, 19 | long_description_content_type="text/markdown", 20 | packages=['mixdiff'], 21 | install_requires=[ 22 | 'numpy>=1.19,<2', 23 | 'tqdm>=4.62,<5', 24 | 'scipy==1.10.*', 25 | 'diffusers[torch]==0.16.*', 26 | 'ftfy==6.1.*', 27 | 'gitpython==3.1.*', 28 | 'ligo-segments==1.4.*', 29 | 'torch==2.0.*', 30 | 'torchvision==0.15.*', 31 | 'transformers==4.29.*' 32 | ], 33 | author="Alvaro Barbero", 34 | url='https://github.com/albarji/mixture-of-diffusers', 35 | license='MIT', 36 | classifiers=[ 37 | 'Development Status :: 1 - Planning', 38 | 'Environment :: GPU :: NVIDIA CUDA', 39 | 'Intended Audience :: Science/Research', 40 | 'License :: OSI Approved :: MIT License', 41 | 'Operating System :: Unix', 42 | 'Programming Language :: Python :: 3.8', 43 | 'Programming Language :: Python :: 3.9', 44 | 'Programming Language :: Python :: 3.10', 45 | 'Topic :: Artistic Software', 46 | 'Topic :: Multimedia :: Graphics' 47 | ], 48 | keywords='artificial-intelligence, deep-learning, diffusion-models', 49 | test_suite="pytest", 50 | ) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/albarji/mixture-of-diffusers/af42292d0a8cb414f6da2eeac79be4c60afbbe48/tests/__init__.py -------------------------------------------------------------------------------- /tests/canvas_test.py: -------------------------------------------------------------------------------- 1 | from diffusers import LMSDiscreteScheduler 2 | from PIL import Image 3 | import pytest 4 | 5 | from mixdiff import Image2ImageRegion, 
StableDiffusionCanvasPipeline, Text2ImageRegion, preprocess_image 6 | from mixdiff.canvas import CanvasRegion 7 | 8 | ### CanvasRegion tests 9 | 10 | @pytest.mark.parametrize("region_params", [ 11 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512}, 12 | {"row_init": 0, "row_end": 256, "col_init": 0, "col_end": 512}, 13 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 256}, 14 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "region_seed": 12345}, 15 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "noise_eps": 0.1} 16 | ]) 17 | def test_create_canvas_region_correct(region_params): 18 | """Creating a correct canvas region with basic parameters works""" 19 | region = CanvasRegion(**region_params) 20 | assert region.row_init == region_params["row_init"] 21 | assert region.row_end == region_params["row_end"] 22 | assert region.col_init == region_params["col_init"] 23 | assert region.col_end == region_params["col_end"] 24 | assert region.height == region.row_end - region.row_init 25 | assert region.width == region.col_end - region.col_init 26 | 27 | def test_create_canvas_region_eps(): 28 | """Creating a correct canvas region works""" 29 | CanvasRegion(0, 512, 0, 512) 30 | 31 | def test_create_canvas_region_non_multiple_size(): 32 | """Creating a canvas region with sizes that are not a multiple of 8 fails""" 33 | with pytest.raises(ValueError): 34 | CanvasRegion(0, 17, 0, 15) 35 | 36 | def test_create_canvas_region_negative_indices(): 37 | """Creating a canvas region with negative indices fails""" 38 | with pytest.raises(ValueError): 39 | CanvasRegion(-512, 0, -256, 0) 40 | 41 | def test_create_canvas_region_negative_eps(): 42 | """Creating a canvas region with negative epsilon noise fails""" 43 | with pytest.raises(ValueError): 44 | CanvasRegion(0, 512, 0, 512, noise_eps=-3) 45 | 46 | ### Text2ImageRegion tests 47 | 48 | @pytest.mark.parametrize("region_params", [ 49 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512}, 50 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers"}, 51 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers", "guidance_scale": 15.}, 52 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers", "guidance_scale": 15., "mask_type": "constant", "mask_weight": 1.0}, 53 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers", "guidance_scale": 15., "mask_type": "gaussian", "mask_weight": 0.75}, 54 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers", "guidance_scale": 15., "mask_type": "quartic", "mask_weight": 1.5}, 55 | ]) 56 | def test_create_text2image_region_correct(region_params): 57 | """Creating a Text2Image region with correct parameters works""" 58 | region = Text2ImageRegion(**region_params) 59 | if "prompt" in region_params: assert region.prompt == region_params["prompt"] 60 | if "guidance_scale" in region_params: assert region.guidance_scale == region_params["guidance_scale"] 61 | if "mask_type" in region_params: assert region.mask_type == region_params["mask_type"] 62 | if "mask_weight" in region_params: assert region.mask_weight == region_params["mask_weight"] 63 | 64 | def test_create_text2image_region_negative_weight(): 65 | """We can't specify a Text2Image region with mask 
weight""" 66 | with pytest.raises(ValueError): 67 | Text2ImageRegion(0, 512, 0, 512, prompt="Pikachu unit-testing Mixture of Diffusers", mask_type="gaussian", mask_weight=-0.1) 68 | 69 | def test_create_text2image_region_unknown_mask(): 70 | """We can't specify a Text2Image region with mask not in the recognized masks list""" 71 | with pytest.raises(ValueError): 72 | Text2ImageRegion(0, 512, 0, 512, prompt="Link unit-testing Mixture of Diffusers", mask_type="majora", mask_weight=1.0) 73 | 74 | ### Image2ImageRegion tests 75 | 76 | @pytest.fixture(scope="session") 77 | def base_image(): 78 | return preprocess_image(Image.open("examples/IIC.png").convert("RGB")) 79 | 80 | @pytest.mark.parametrize("region_params", [ 81 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512}, 82 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "strength": 0.1}, 83 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "strength": 0.9}, 84 | ]) 85 | def test_create_image2image_region_correct(region_params, base_image): 86 | """Creating a Image2Image region with correct parameters works""" 87 | region = Image2ImageRegion(**region_params, reference_image=base_image) 88 | assert region.reference_image.shape == (1, 3, region.height, region.width) 89 | 90 | @pytest.mark.parametrize("region_params", [ 91 | {"row_init": 0, "row_end": 256, "col_init": 0, "col_end": 256, "strength": -0.3}, 92 | {"row_init": 0, "row_end": 256, "col_init": 0, "col_end": 256, "strength": 1.1} 93 | ]) 94 | def test_create_image2image_region_negative_strength(region_params, base_image): 95 | """We can't specify an Image2Image region with strength values outside of [0, 1]""" 96 | with pytest.raises(ValueError): 97 | Image2ImageRegion(**region_params, reference_image=base_image) 98 | 99 | ### StableDiffusionCanvasPipeline tests 100 | 101 | @pytest.fixture(scope="session") 102 | def canvas_pipeline(): 103 | return StableDiffusionCanvasPipeline.from_pretrained( 104 | "CompVis/stable-diffusion-v1-4", 105 | scheduler=LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000), 106 | use_auth_token=True 107 | ) 108 | 109 | @pytest.fixture() 110 | def basic_canvas_params(): 111 | return { 112 | "canvas_height": 64, 113 | "canvas_width": 64, 114 | "regions": [ 115 | Text2ImageRegion(0, 48, 0, 48, mask_type="gaussian", prompt="Something"), 116 | Text2ImageRegion(16, 64, 0, 48, mask_type="gaussian", prompt="Something else"), 117 | Text2ImageRegion(0, 48, 16, 64, mask_type="gaussian", prompt="Something more"), 118 | Text2ImageRegion(16, 64, 16, 64, mask_type="gaussian", prompt="One last thing"), 119 | ] 120 | } 121 | 122 | @pytest.mark.parametrize("extra_canvas_params", [ 123 | {"num_inference_steps": 1}, 124 | {"num_inference_steps": 1, "cpu_vae": True}, 125 | ]) 126 | def test_stable_diffusion_canvas_pipeline_correct(canvas_pipeline, basic_canvas_params, extra_canvas_params): 127 | """The StableDiffusionCanvasPipeline works for some correct configurations""" 128 | image = canvas_pipeline(**basic_canvas_params, **extra_canvas_params)["sample"][0] 129 | assert image.size == (64, 64) 130 | 131 | @pytest.mark.parametrize("extra_canvas_params", [ 132 | {"num_inference_steps": 3}, 133 | {"num_inference_steps": 3, "cpu_vae": True}, 134 | ]) 135 | def test_stable_diffusion_canvas_pipeline_image2image_correct(canvas_pipeline, basic_canvas_params, base_image, extra_canvas_params): 136 | """The StableDiffusionCanvasPipeline works for some correct configurations when 
including an Image2ImageRegion"""
137 |     all_canvas_params = {**basic_canvas_params, **extra_canvas_params}
138 |     all_canvas_params["regions"] += [Image2ImageRegion(16, 64, 0, 48, reference_image=base_image, strength=0.5)]
139 | 
140 |     image = canvas_pipeline(**all_canvas_params)["sample"][0]
141 |     assert image.size == (64, 64)
142 | --------------------------------------------------------------------------------
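
For quick reference, below is a minimal usage sketch of the canvas pipeline assembled from the test fixtures in tests/canvas_test.py above. The model id, scheduler settings, region arguments and call signature are taken from those tests; the canvas size, prompts and the .to("cuda") placement are illustrative assumptions, not part of the test suite.

from diffusers import LMSDiscreteScheduler
from mixdiff import StableDiffusionCanvasPipeline, Text2ImageRegion

# Scheduler and model id as used in the canvas_pipeline test fixture
scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
pipeline = StableDiffusionCanvasPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to("cuda")  # device placement is an assumption

# Two overlapping Text2Image regions over a 512x768 canvas; region bounds must be multiples of 8, as enforced by CanvasRegion
output = pipeline(
    canvas_height=512,
    canvas_width=768,
    regions=[
        Text2ImageRegion(0, 512, 0, 512, mask_type="gaussian", prompt="A charming house in the countryside"),
        Text2ImageRegion(0, 512, 256, 768, mask_type="gaussian", prompt="A dirt road in the countryside"),
    ],
    num_inference_steps=50,
)
image = output["sample"][0]  # the pipeline returns a dict with the decoded PIL images under "sample"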