├── .github
│   └── workflows
│       ├── publish.yml
│       └── python-tests.yml
├── .gitignore
├── LICENSE
├── README.md
├── environment.yml
├── examples
│   ├── IIC.png
│   └── linearForest.json
├── generate_grid.py
├── generate_grid_from_json.py
├── mixdiff
│   ├── __init__.py
│   ├── canvas.py
│   ├── extrasmixin.py
│   ├── imgtools.py
│   └── tiling.py
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    └── canvas_test.py
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish package on PyPI
2 |
3 | on: [push]
4 |
5 | jobs:
6 |   test-publish-package:
7 |     runs-on: ubuntu-latest
8 |     name: publish package on test-PyPI
9 |     if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
10 |
11 |     steps:
12 |       - uses: actions/checkout@v3
13 |       - name: Set up Python
14 |         uses: actions/setup-python@v4
15 |         with:
16 |           python-version: '3.x'
17 |       - name: Install dependencies
18 |         run: |
19 |           python -m pip install --upgrade pip
20 |           pip install build
21 |       - name: Get release tag
22 |         id: tag
23 |         uses: dawidd6/action-get-tag@v1
24 |       - name: Build package
25 |         run: GITHUB_TAG=${{steps.tag.outputs.tag}} python -m build
26 |       - name: Publish package
27 |         uses: pypa/gh-action-pypi-publish@release/v1
28 |         with:
29 |           repository_url: https://test.pypi.org/legacy/
30 |           user: __token__
31 |           password: ${{ secrets.TESTPYPI_API_TOKEN }}
32 |
33 |   publish-package:
34 |     runs-on: ubuntu-latest
35 |     name: publish package on PyPI
36 |     if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
37 |     needs: test-publish-package
38 |
39 |     steps:
40 |       - uses: actions/checkout@v3
41 |       - name: Set up Python
42 |         uses: actions/setup-python@v4
43 |         with:
44 |           python-version: '3.x'
45 |       - name: Install dependencies
46 |         run: |
47 |           python -m pip install --upgrade pip
48 |           pip install build
49 |       - name: Get release tag
50 |         id: tag
51 |         uses: dawidd6/action-get-tag@v1
52 |       - name: Build package
53 |         run: GITHUB_TAG=${{steps.tag.outputs.tag}} python -m build
54 |       - name: Publish package
55 |         uses: pypa/gh-action-pypi-publish@release/v1
56 |         with:
57 |           user: __token__
58 |           password: ${{ secrets.PYPI_API_TOKEN }}
59 |
60 |   test-package-download:
61 |     runs-on: ubuntu-latest
62 |     strategy:
63 |       matrix:
64 |         python-version: ["3.8", "3.9", "3.10"]
65 |     name: test package works by installing from test-PyPI
66 |     if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags')
67 |     needs: publish-package
68 |
69 |     steps:
70 |       - uses: actions/checkout@v3
71 |       - name: Erase source code
72 |         run: rm -rf mixdiff
73 |       - name: Set up Python
74 |         uses: actions/setup-python@v4
75 |         with:
76 |           python-version: ${{ matrix.python-version }}
77 |       - name: Get release tag
78 |         id: tag
79 |         uses: dawidd6/action-get-tag@v1
80 |       - name: Install package
81 |         run: python -m pip install mixdiff==${{steps.tag.outputs.tag}}
82 |       - name: Test with pytest
83 |         env:
84 |           HUGGING_FACE_HUB_TOKEN: ${{ secrets.huggingface_token }}
85 |         run: |
86 |           pip install pytest
87 |           pytest
88 |
--------------------------------------------------------------------------------
/.github/workflows/python-tests.yml:
--------------------------------------------------------------------------------
1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions
2 | # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3 |
4 | name: Unit tests
5 |
6 | on:
7 |   push:
8 |   pull_request:
9 |     branches: [ "master" ]
10 |
11 | jobs:
12 |   test-local-environment:
13 |     runs-on: ubuntu-latest
14 |     defaults:
15 |       run:
16 |         shell: bash -l {0}
17 |     strategy:
18 |       fail-fast: false
19 |       matrix:
20 |         python-version: ["3.8", "3.9", "3.10"]
21 |
22 |     steps:
23 |       - uses: actions/checkout@v3
24 |       - name: Setup Miniconda
25 |         uses: conda-incubator/setup-miniconda@v2.2.0
26 |         with:
27 |           python-version: ${{ matrix.python-version }}
28 |           auto-activate-base: false
29 |           activate-environment: mixture_of_diffusers
30 |           environment-file: environment.yml
31 |       - name: Test with pytest
32 |         env:
33 |           HUGGING_FACE_HUB_TOKEN: ${{ secrets.huggingface_token }}
34 |         run: |
35 |           conda install pytest && pytest
36 |
37 |   test-package:
38 |     runs-on: ubuntu-latest
39 |     strategy:
40 |       fail-fast: false
41 |       matrix:
42 |         python-version: ["3.8", "3.9", "3.10"]
43 |     steps:
44 |       - uses: actions/checkout@v3
45 |       - name: Set up Python
46 |         uses: actions/setup-python@v4
47 |         with:
48 |           python-version: ${{ matrix.python-version }}
49 |       - name: Install dependencies
50 |         run: |
51 |           python -m pip install --upgrade pip
52 |           pip install build
53 |       - name: Build package
54 |         run: python -m build
55 |       - name: Install package
56 |         run: |
57 |           pip install .
58 |       - name: Test with pytest
59 |         env:
60 |           HUGGING_FACE_HUB_TOKEN: ${{ secrets.huggingface_token }}
61 |         run: |
62 |           pip install pytest
63 |           rm -rf mixdiff # Delete source to test installed package
64 |           pytest
65 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints
2 | **/__pycache__
3 | nohup.out
4 | outputs
5 | venv
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Álvaro Barbero Jiménez
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Mixture of Diffusers
2 |
3 | 
4 |
5 | [](https://github.com/albarji/mixture-of-diffusers/actions/workflows/python-tests.yml)
6 |
7 | [](https://huggingface.co/spaces/albarji/mixture-of-diffusers)
8 |
9 | This repository holds various scripts and tools implementing a method for integrating a mixture of different diffusion processes collaborating to generate a single image. Each diffuser focuses on a particular region on the image, taking into account boundary effects to promote a smooth blending.
10 |
11 | If you prefer a more user-friendly graphical interface for this algorithm, I recommend trying the [Tiled Diffusion & VAE](https://github.com/pkuliyi2015/multidiffusion-upscaler-for-automatic1111) plugin developed by pkuliyi2015 for [AUTOMATIC1111's stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui).
12 |
13 | ## Motivation
14 |
15 | Current image generation methods, such as Stable Diffusion, struggle to position objects at specific locations. While the content of the generated image (somewhat) reflects the objects present in the prompt, it is difficult to frame the prompt in a way that creates a specific composition. For instance, take a prompt expressing a complex composition such as
16 |
17 | > A charming house in the countryside on the left,
18 | > in the center a dirt road in the countryside crossing pastures,
19 | > on the right an old and rusty giant robot lying on a dirt road,
20 | > by jakub rozalski,
21 | > sunset lighting on the left and center, dark sunset lighting on the right
22 | > elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece
23 |
24 | Out of a sample of 20 Stable Diffusion generations with different seeds, the generated images that align best with the prompt are the following:
25 |
26 |
27 |
28 |  |
29 |  |
30 |  |
31 |
32 |
33 |
34 | The method proposed here strives to provide a better tool for image composition by using several diffusion processes in parallel, each configured with a specific prompt and settings, and focused on a particular region of the image. For example, here are three outputs from this method, using the following prompts from left to right:
35 |
36 | * "**A charming house in the countryside, by jakub rozalski, sunset lighting**, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
37 | * "**A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting**, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
38 | * "**An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting**, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
39 |
40 | 
41 | 
42 | 
43 |
44 | The diffusion processes are mixed in a way that harmonizes the overall generation, preventing "seam" effects in the resulting image.
45 |
46 | Using several diffusion processes in parallel also has practical advantages when generating very large images, as the GPU memory requirements are similar to those of generating an image the size of a single tile.
47 |
48 | ## Usage
49 |
50 | This repository provides two new pipelines, `StableDiffusionTilingPipeline` and `StableDiffusionCanvasPipeline`, that extend the standard Stable Diffusion pipeline from [Diffusers](https://github.com/huggingface/diffusers). They feature new options that allow defining the mixture of diffusers, which are distributed as a number of "diffusion regions" over the image to be generated. `StableDiffusionTilingPipeline` is simpler to use and arranges the diffusion regions as a grid over the canvas, while `StableDiffusionCanvasPipeline` allows a more flexible placement and also features image2image capabilities.
51 |
52 | ### Prerequisites
53 |
54 | Since this work is based on Stable Diffusion models, you will need to [request access and accept the usage terms of Stable Diffusion](https://huggingface.co/CompVis/stable-diffusion#model-access). You will also need to [configure your Hugging Face User Access Token](https://huggingface.co/docs/hub/security-tokens) in your running environment.
55 |
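For example, one way to expose the token is through the `HUGGING_FACE_HUB_TOKEN` environment variable (this is what the CI workflows in this repo do), or by logging in interactively with the Hugging Face CLI. The token value below is just a placeholder:

```bash
# Option 1: environment variable (replace the placeholder with your own token)
export HUGGING_FACE_HUB_TOKEN="hf_xxxxxxxxxxxxxxxxxxxx"

# Option 2: interactive login through the huggingface_hub CLI
huggingface-cli login
```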
56 | ### StableDiffusionTilingPipeline
57 |
58 | The header image in this repo can be generated as follows:
59 |
60 | ```python
61 | from diffusers import LMSDiscreteScheduler
62 | from mixdiff import StableDiffusionTilingPipeline
63 |
64 | # Create scheduler and model (similar to StableDiffusionPipeline)
65 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
66 | pipeline = StableDiffusionTilingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True).to("cuda:0")
67 |
68 | # Mixture of Diffusers generation
69 | image = pipeline(
70 |     prompt=[[
71 |         "A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
72 |         "A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece",
73 |         "An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
74 |     ]],
75 |     tile_height=640,
76 |     tile_width=640,
77 |     tile_row_overlap=0,
78 |     tile_col_overlap=256,
79 |     guidance_scale=8,
80 |     seed=7178915308,
81 |     num_inference_steps=50,
82 | )["sample"][0]
83 | ```
84 |
85 | The prompts must be provided as a list of lists, where each inner list represents a row of diffusion regions. The geometry of the canvas is inferred from these lists: e.g. in the example above we are creating a grid of 1x3 diffusion regions (1 row and 3 columns). The rest of the parameters specify the size of these regions and how much they overlap with their neighbors.
86 |
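For instance, to request a 2x2 grid you would pass two rows with two prompts each. The prompts in this sketch are placeholders, just to illustrate the nesting:

```python
# Hypothetical prompts, for illustration only: 2 rows x 2 columns of diffusion regions
prompt = [
    ["a mountain lake at dawn, matte painting", "a pine forest at dawn, matte painting"],  # top row of tiles
    ["a mountain lake at dusk, matte painting", "a pine forest at dusk, matte painting"],  # bottom row of tiles
]
# With tile_height=640, tile_width=640 and 256 pixels of overlap in both directions,
# the resulting canvas would be 640 + (640 - 256) = 1024 pixels high and wide.
```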
87 | Alternatively, it is possible to specify the grid parameters through a JSON configuration file. In the following example a grid of 1x10 tiles is configured to generate a forest in changing styles:
88 |
89 | 
90 |
91 | A `StableDiffusionTilingPipeline` is configured to use 10 prompts with changing styles. Each tile is 768x512 pixels in size, and contiguous tiles overlap by 256 pixels to avoid seam effects. All the details are specified in a configuration file:
92 |
93 | ```json
94 | {
95 |     "cpu_vae": true,
96 |     "gc": 8,
97 |     "gc_tiles": null,
98 |     "prompt": [
99 |         [
100 |             "a forest, ukiyo-e, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
101 |             "a forest, ukiyo-e, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
102 |             "a forest, by velazquez, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
103 |             "a forest, by velazquez, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
104 |             "a forest, impressionist style by van gogh, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
105 |             "a forest, impressionist style by van gogh, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
106 |             "a forest, cubist style by Pablo Picasso intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
107 |             "a forest, cubist style by Pablo Picasso intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
108 |             "a forest, 80s synthwave style, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
109 |             "a forest, 80s synthwave style, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors"
110 |         ]
111 |     ],
112 |     "scheduler": "lms",
113 |     "seed": 639688656,
114 |     "steps": 50,
115 |     "tile_col_overlap": 256,
116 |     "tile_height": 768,
117 |     "tile_row_overlap": 256,
118 |     "tile_width": 512
119 | }
120 | ```
121 |
122 | You can try generating this image using this configuration file by running:
123 |
124 | ```bash
125 | python generate_grid_from_json.py examples/linearForest.json
126 | ```
127 |
128 | The full list of arguments to a `StableDiffusionTilingPipeline` is:
129 |
130 | > `prompt`: either a single string (no tiling) or a list of lists with all the prompts to use (one list for each row of tiles). This will also define the tiling structure.
131 | >
132 | > `num_inference_steps`: number of diffusion steps.
133 | >
134 | > `guidance_scale`: classifier-free guidance scale.
135 | >
136 | > `seed`: general random seed to initialize latents.
137 | >
138 | > `tile_height`: height in pixels of each grid tile.
139 | >
140 | > `tile_width`: width in pixels of each grid tile.
141 | >
142 | > `tile_row_overlap`: number of overlap pixels between tiles in consecutive rows.
143 | >
144 | > `tile_col_overlap`: number of overlap pixels between tiles in consecutive columns.
145 | >
146 | > `guidance_scale_tiles`: specific weights for classifier-free guidance in each tile. If `None`, the value provided in `guidance_scale` will be used.
147 | >
150 | > `seed_tiles`: specific seeds for the initialization latents in each tile. These will override the latents generated for the whole canvas using the standard `seed` parameter.
151 | >
152 | > `seed_tiles_mode`: either `"full"` or `"exclusive"`. If `"full"`, all the latents affected by the tile will be overridden. If `"exclusive"`, only the latents that are affected exclusively by this tile (and no other tiles) will be overridden.
153 | >
154 | > `seed_reroll_regions`: a list of tuples in the form (start row, end row, start column, end column, seed) defining regions in pixel space for which the latents will be overridden using the given seed. Takes priority over `seed_tiles`. See the sketch after this list.
155 | >
156 | > `cpu_vae`: the decoder from latent space to pixel space can require too much GPU RAM for large images. If you encounter out-of-memory errors at the end of the generation process, try setting this parameter to `True` to run the decoder on the CPU. Slower, but it should run without memory issues.
157 |
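As an illustrative sketch (not a tested recipe) of the tile-specific options above, the following call re-seeds only the rightmost tile of the 1x3 example and rerolls a small corner of the canvas. The seeds, guidance values and coordinates are made up, and `house_prompt`, `road_prompt`, `robot_prompt` are hypothetical variables holding the prompts from the example at the start of this section:

```python
# Sketch only: assumes `pipeline` and the three prompts are defined as in the example above
image = pipeline(
    prompt=[[house_prompt, road_prompt, robot_prompt]],
    tile_height=640,
    tile_width=640,
    tile_row_overlap=0,
    tile_col_overlap=256,
    guidance_scale=8,
    guidance_scale_tiles=[[None, None, 12]],       # stronger guidance only in the rightmost tile
    seed=7178915308,
    seed_tiles=[[None, None, 1234]],               # re-seed only the rightmost tile
    seed_tiles_mode="exclusive",                   # override only latents not shared with other tiles
    seed_reroll_regions=[(0, 128, 0, 128, 5678)],  # reroll a 128x128 pixel corner with seed 5678
    num_inference_steps=50,
    cpu_vae=True,                                  # decode on the CPU if the GPU runs out of memory
)["sample"][0]
```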
158 | A script showing a more advanced use of this pipeline is available as [generate_grid.py](./generate_grid.py).
159 |
160 | ### StableDiffusionCanvasPipeline
161 |
162 | The `StableDiffusionCanvasPipeline` works by defining a list of `Text2ImageRegion` objects that detail the region of influence of each diffuser. As an illustrative example, the header image of this repo can be generated with the following code:
163 |
164 | ```python
165 | from diffusers import LMSDiscreteScheduler
166 | from mixdiff import StableDiffusionCanvasPipeline, Text2ImageRegion
167 |
168 | # Create scheduler and model (similar to StableDiffusionPipeline)
169 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
170 | pipeline = StableDiffusionCanvasPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True).to("cuda:0")
171 |
172 | # Mixture of Diffusers generation
173 | image = pipeline(
174 |     canvas_height=640,
175 |     canvas_width=1408,
176 |     regions=[
177 |         Text2ImageRegion(0, 640, 0, 640, guidance_scale=8,
178 |             prompt=f"A charming house in the countryside, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"),
179 |         Text2ImageRegion(0, 640, 384, 1024, guidance_scale=8,
180 |             prompt=f"A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"),
181 |         Text2ImageRegion(0, 640, 768, 1408, guidance_scale=8,
182 |             prompt=f"An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"),
183 |     ],
184 |     num_inference_steps=50,
185 |     seed=7178915308,
186 | )["sample"][0]
187 | ```
188 |
189 | `Image2Image` regions can also be added at any position, to use a particular image as guidance. In the following example we create a Christmas postcard by taking a photo of a building (available in this repo) and using it as guidance in a region of the canvas.
190 |
191 | ```python
192 | from PIL import Image
193 | from diffusers import LMSDiscreteScheduler
194 | from mixdiff import StableDiffusionCanvasPipeline, Text2ImageRegion, Image2ImageRegion, preprocess_image
195 |
196 | # Create scheduler and model (similar to StableDiffusionPipeline)
197 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
198 | pipeline = StableDiffusionCanvasPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler, use_auth_token=True).to("cuda:0")
199 |
200 | # Load and preprocess guide image
201 | iic_image = preprocess_image(Image.open("examples/IIC.png").convert("RGB"))
202 |
203 | # Mixture of Diffusers generation
204 | image = pipeline(
205 |     canvas_height=800,
206 |     canvas_width=352,
207 |     regions=[
208 |         Text2ImageRegion(0, 800, 0, 352, guidance_scale=8,
209 |             prompt=f"Christmas postcard, a charming house in the countryside surrounded by snow, a giant christmas tree, under a starry night sky, by jakub rozalski and alayna danner and guweiz, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"),
210 |         Image2ImageRegion(800-352, 800, 0, 352, reference_image=iic_image, strength=0.8),
211 |     ],
212 |     num_inference_steps=57,
213 |     seed=5525475061,
214 | )["sample"][0]
215 | ```
216 |
217 | 
218 |
219 | The full list of arguments to a `StableDiffusionCanvasPipeline` is:
220 |
221 | > `canvas_height`: height in pixels of the image to generate. Must be a multiple of 8.
222 | >
223 | > `canvas_width`: width in pixels of the image to generate. Must be a multiple of 8.
224 | >
225 | > `regions`: list of `Text2Image` or `Image2Image` diffusion regions (see below).
226 | >
227 | > `num_inference_steps`: number of diffusion steps.
228 | >
229 | > `seed`: general random seed to initialize latents.
230 | >
231 | > `reroll_regions`: list of `RerollRegion` regions in which to reroll latents (see below). Useful if you like the overall aspect of the generated image, but want to regenerate a specific region using a different random seed. A usage sketch is included after the parameter lists below.
232 | >
233 | > `cpu_vae`: whether to perform encoder-decoder operations on the CPU, even if the diffusion process runs on the GPU. Use `cpu_vae=True` if you run out of GPU memory at the end of the generation process for large canvas dimensions, or if you create large `Image2Image` regions.
234 | >
235 | > `decode_steps`: if `True` the result will include not only the final image, but also all the intermediate steps in the generation. Note: this will greatly increase running times.
236 |
237 | All regions are configured with the following parameters:
238 |
239 | > `row_init`: starting row in pixel space (included). Must be a multiple of 8.
240 | >
241 | > `row_end`: end row in pixel space (not included). Must be a multiple of 8.
242 | >
243 | > `col_init`: starting column in pixel space (included). Must be a multiple of 8.
244 | >
245 | > `col_end`: end column in pixel space (not included). Must be a multiple of 8.
246 | >
247 | > `region_seed`: seed for random operations in this region
248 | >
249 | > `noise_eps`: deviation of a zero-mean gaussian noise to be applied over the latents in this region. Useful for slightly "rerolling" latents
250 |
251 | Additionally, `Text2Image` regions use the following arguments:
252 |
253 | > `prompt`: text prompt guiding the diffuser in this region
254 | >
255 | > `guidance_scale`: guidance scale of the diffuser in this region. If `None`, a random value will be used.
256 | >
257 | > `mask_type`: kind of weight mask applied to this region, must be one of `["constant", "gaussian", "quartic"]`.
258 | >
259 | > `mask_weight`: global weights multiplier of the mask.
260 |
261 | `Image2Image` regions are configured with the basic region parameters plus the following:
262 |
263 | > `reference_image`: image to use as guidance. Must be loaded as a PIL image and pre-processed using the `preprocess_image` function (see example above). It will be automatically rescaled to the shape of the region.
264 | >
265 | > `strength`: strength of the image guidance, must lie in the range `[0.0, 1.0]` (from no guidance to absolute priority of the original image).
266 |
267 | Finally, `RerollRegions` accept the basic arguments plus the following:
268 |
269 | > `reroll_mode`: kind of reroll to perform, either `reset` (completely reset the latents in the region with new ones) or `epsilon` (slightly alter the latents in the region).
270 |
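Putting several of these options together, here is an illustrative sketch (not a tested recipe) combining two `Text2ImageRegion` objects with different masks, a `RerollRegion`, and CPU decoding. It assumes `pipeline` was created as in the examples above and that `RerollRegion` is imported from `mixdiff`; all coordinates, prompts, seeds and weights below are made up:

```python
# Sketch only: prompts, coordinates, seeds and weights are placeholders
image = pipeline(
    canvas_height=640,
    canvas_width=1024,
    regions=[
        Text2ImageRegion(0, 640, 0, 640, guidance_scale=8,
            mask_type="gaussian", mask_weight=1.0,
            prompt="a hypothetical prompt for the left part of the canvas"),
        Text2ImageRegion(0, 640, 384, 1024, guidance_scale=8,
            mask_type="quartic", mask_weight=1.5,  # give this region more influence in the overlap
            noise_eps=0.05,                        # slightly perturb its initial latents
            prompt="a hypothetical prompt for the right part of the canvas"),
    ],
    reroll_regions=[
        RerollRegion(0, 128, 0, 128, reroll_mode="reset", region_seed=1234),  # regenerate this corner
    ],
    num_inference_steps=50,
    seed=12345,
    cpu_vae=True,        # decode on the CPU for large canvases
    decode_steps=False,  # set to True to also collect intermediate images (much slower)
)["sample"][0]
```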
271 | ## Citing and full technical details
272 |
273 | If you find this repository useful, please be so kind as to cite the corresponding paper, which also contains the full details about this method:
274 |
275 | > Álvaro Barbero Jiménez. Mixture of Diffusers for scene composition and high resolution image generation. https://arxiv.org/abs/2302.02412
276 |
277 | ## Responsible use
278 |
279 | The same recommendations as in Stable Diffusion apply, so please check the corresponding [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4).
280 |
281 | More broadly speaking, always bear this in mind: YOU are responsible for the content you create using this tool. Do not place all the blame, credit, or responsibility on the software.
282 |
283 | ## Gallery
284 |
285 | Here are some relevant illustrations I have created using this software (and put quite a few hours into them!).
286 |
287 | ### Darkness Dawning
288 |
289 | 
290 |
291 | ### Yog-Sothoth
292 |
293 | 
294 |
295 | ### Looking through the eyes of giants
296 |
297 | 
298 |
299 | [Follow me on DeviantArt for more!](https://www.deviantart.com/albarji)
300 |
301 | ## Acknowledgements
302 |
303 | First and foremost, my most sincere appreciation for the [Stable Diffusion team](https://stability.ai/blog/stable-diffusion-public-release) for releasing such an awesome model, and for letting me take part in the closed beta. Kudos also to the Hugging Face community and developers for implementing the [Diffusers library](https://github.com/huggingface/diffusers).
304 |
305 | Thanks to Instituto de Ingeniería del Conocimiento and Grupo de Aprendizaje Automático (Universidad Autónoma de Madrid) for providing GPU resources for testing and experimenting with this library.
306 |
307 | Thanks also to the vibrant communities of the Stable Diffusion discord channel and [Lexica](https://lexica.art/), where I have learned about many amazing artists and styles. And to my friend Abril for sharing many tips on cool artists!
308 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: mixture_of_diffusers
2 | dependencies:
3 |   - pip=22.1.*
4 |   - pip:
5 |     - -r requirements.txt
6 |
--------------------------------------------------------------------------------
/examples/IIC.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/albarji/mixture-of-diffusers/af42292d0a8cb414f6da2eeac79be4c60afbbe48/examples/IIC.png
--------------------------------------------------------------------------------
/examples/linearForest.json:
--------------------------------------------------------------------------------
1 | {
2 | "cpu_vae": true,
3 | "gc": 8,
4 | "gc_tiles": null,
5 | "prompt": [
6 | [
7 | "a forest, ukiyo-e, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
8 | "a forest, ukiyo-e, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
9 | "a forest, by velazquez, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
10 | "a forest, by velazquez, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
11 | "a forest, impressionist style by van gogh, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
12 | "a forest, impressionist style by van gogh, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
13 | "a forest, cubist style by Pablo Picasso intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
14 | "a forest, cubist style by Pablo Picasso intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
15 | "a forest, 80s synthwave style, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors",
16 | "a forest, 80s synthwave style, intricate, elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece, impressive colors"
17 | ]
18 | ],
19 | "scheduler": "lms",
20 | "seed": 639688656,
21 | "steps": 50,
22 | "tile_col_overlap": 256,
23 | "tile_height": 768,
24 | "tile_row_overlap": 256,
25 | "tile_width": 512
26 | }
--------------------------------------------------------------------------------
/generate_grid.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import datetime
3 | from diffusers import LMSDiscreteScheduler, DDIMScheduler
4 | import git
5 | import json
6 | import numpy as np
7 | from pathlib import Path
8 |
9 | from mixdiff import StableDiffusionTilingPipeline
10 |
11 | ### CONFIG START
12 | n = 3
13 | sche = "lms"
14 | gc = 8
15 | seed = None
16 | steps = 50
17 |
18 | suffix = "elegant, highly detailed, smooth, sharp focus, artstation, stunning masterpiece"
19 | prompt = [
20 | [
21 | f"A charming house in the countryside, by jakub rozalski, sunset lighting, {suffix}",
22 | f"A dirt road in the countryside crossing pastures, by jakub rozalski, sunset lighting, {suffix}",
23 | f"An old and rusty giant robot lying on a dirt road, by jakub rozalski, dark sunset lighting, {suffix}"
24 | ]
25 | ]
26 |
27 | gc_tiles = [
28 | [None, None, None],
29 | ]
30 |
31 | seed_tiles = [
32 | [None, None, None],
33 | ]
34 | seed_tiles_mode = StableDiffusionTilingPipeline.SeedTilesMode.FULL.value
35 |
36 | seed_reroll_regions = []
37 |
38 | tile_height = 640
39 | tile_width = 640
40 | tile_row_overlap = 256
41 | tile_col_overlap = 256
42 | cpu_vae = False
43 |
44 | ### CONFIG END
45 |
46 | # Prepared scheduler
47 | if sche == "ddim":
48 | scheduler = DDIMScheduler()
49 | elif sche == "lms":
50 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
51 | else:
52 | raise ValueError(f"Unrecognized scheduler {sche}")
53 |
54 | # Load model
55 | model_id = "CompVis/stable-diffusion-v1-4"
56 | pipelinekind = StableDiffusionTilingPipeline
57 | pipe = pipelinekind.from_pretrained(model_id, scheduler=scheduler, use_auth_token=True).to("cuda:0")
58 |
59 | for _ in range(n):
60 | # Prepare parameters
61 | gc_image = gc if gc is not None else np.random.randint(5, 30)
62 | steps_image = steps if steps is not None else np.random.randint(50, 150)
63 | seed_image = seed if seed is not None else np.random.randint(9999999999)
64 | seed_tiles_image = deepcopy(seed_tiles)
65 | for row in range(len(seed_tiles)):
66 | for col in range(len(seed_tiles[0])):
67 | if seed_tiles[row][col] == "RNG":
68 | seed_tiles_image[row][col] = np.random.randint(9999999999)
69 | seed_reroll_regions_image = deepcopy(seed_reroll_regions)
70 | for i in range(len(seed_reroll_regions)):
71 | row_init, row_end, col_init, col_end, seed_reroll = seed_reroll_regions[i]
72 | if seed_reroll == "RNG":
73 | seed_reroll = np.random.randint(9999999999)
74 | seed_reroll_regions_image[i] = (row_init, row_end, col_init, col_end, seed_reroll)
75 | pipeargs = {
76 | "guidance_scale": gc_image,
77 | "num_inference_steps": steps_image,
78 | "seed": seed_image,
79 | "prompt": prompt,
80 | "tile_height": tile_height,
81 | "tile_width": tile_width,
82 | "tile_row_overlap": tile_row_overlap,
83 | "tile_col_overlap": tile_col_overlap,
84 | "guidance_scale_tiles": gc_tiles,
85 | "seed_tiles": seed_tiles_image,
86 | "seed_tiles_mode": seed_tiles_mode,
87 | "seed_reroll_regions": seed_reroll_regions_image,
88 | "cpu_vae": cpu_vae,
89 | }
90 | image = pipe(**pipeargs)["sample"][0]
91 | ct = datetime.datetime.now()
92 | outname = f"{ct}_{prompt[0][0][0:100]}_{tile_height}x{tile_width}_sche{sche}_seed{seed_image}_gc{gc_image}_steps{steps_image}"
93 | outpath = "./outputs"
94 | Path(outpath).mkdir(parents=True, exist_ok=True)
95 | image.save(f"{outpath}/{outname}.png")
96 | logspath = "./logs"
97 | Path(logspath).mkdir(parents=True, exist_ok=True)
98 | with open(f"{logspath}/{outname}.json", "w") as f:
99 | json.dump(
100 | {
101 | "prompt": prompt,
102 | "tile_height": tile_height,
103 | "tile_width": tile_width,
104 | "tile_row_overlap": tile_row_overlap,
105 | "tile_col_overlap": tile_col_overlap,
106 | "scheduler": sche,
107 | "seed": seed_image,
108 | "gc": gc_image,
109 | "gc_tiles": gc_tiles,
110 | "steps": steps_image,
111 | "seed_tiles": seed_tiles_image,
112 | "seed_tiles_mode": seed_tiles_mode,
113 | "seed_reroll_regions": seed_reroll_regions_image,
114 | "cpu_vae": cpu_vae,
115 | "git_commit": git.Repo(search_parent_directories=True).head.object.hexsha,
116 | },
117 | f,
118 | sort_keys=True,
119 | indent=4
120 | )
121 |
--------------------------------------------------------------------------------
/generate_grid_from_json.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | from diffusers import LMSDiscreteScheduler, DDIMScheduler
4 | import json
5 | from pathlib import Path
6 | import torch
7 |
8 | from mixdiff.tiling import StableDiffusionTilingPipeline
9 |
10 | def generate_grid(generation_arguments):
11 | model_id = "CompVis/stable-diffusion-v1-4"
12 | # Prepared scheduler
13 | if generation_arguments["scheduler"] == "ddim":
14 | scheduler = DDIMScheduler()
15 | elif generation_arguments["scheduler"] == "lms":
16 | scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
17 | else:
18 | raise ValueError(f"Unrecognized scheduler {generation_arguments['scheduler']}")
19 | pipe = StableDiffusionTilingPipeline.from_pretrained(model_id, scheduler=scheduler, use_auth_token=True).to("cuda:0")
20 |
21 | pipeargs = {
22 | "guidance_scale": generation_arguments["gc"],
23 | "num_inference_steps": generation_arguments["steps"],
24 | "seed": generation_arguments["seed"],
25 | "prompt": generation_arguments["prompt"],
26 | "tile_height": generation_arguments["tile_height"],
27 | "tile_width": generation_arguments["tile_width"],
28 | "tile_row_overlap": generation_arguments["tile_row_overlap"],
29 | "tile_col_overlap": generation_arguments["tile_col_overlap"],
30 | "guidance_scale_tiles": generation_arguments["gc_tiles"],
31 | "cpu_vae": generation_arguments["cpu_vae"] if "cpu_vae" in generation_arguments else False,
32 | }
33 | if "seed_tiles" in generation_arguments: pipeargs = {**pipeargs, "seed_tiles": generation_arguments["seed_tiles"]}
34 | if "seed_tiles_mode" in generation_arguments: pipeargs = {**pipeargs, "seed_tiles_mode": generation_arguments["seed_tiles_mode"]}
35 | if "seed_reroll_regions" in generation_arguments: pipeargs = {**pipeargs, "seed_reroll_regions": generation_arguments["seed_reroll_regions"]}
36 | image = pipe(**pipeargs)["sample"][0]
37 | outname = "output"
38 | outpath = "./outputs"
39 | Path(outpath).mkdir(parents=True, exist_ok=True)
40 | image.save(f"{outpath}/{outname}.png")
41 |
42 | if __name__ == "__main__":
43 | parser = argparse.ArgumentParser(description='Generate a stable diffusion grid using a JSON file with all configuration parameters.')
44 | parser.add_argument('config', type=str, help='Path to configuration file')
45 | args = parser.parse_args()
46 | with open(args.config, "r") as f:
47 | generation_arguments = json.load(f)
48 | generate_grid(generation_arguments)
49 |
--------------------------------------------------------------------------------
/mixdiff/__init__.py:
--------------------------------------------------------------------------------
1 | from .canvas import Image2ImageRegion, RerollRegion, StableDiffusionCanvasPipeline, Text2ImageRegion
2 | from .imgtools import preprocess_image
3 | from .tiling import StableDiffusionTilingPipeline
--------------------------------------------------------------------------------
/mixdiff/canvas.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from dataclasses import asdict, dataclass
3 | from enum import Enum
4 | import numpy as np
5 | from numpy import pi, exp, sqrt
6 | import re
7 | import torch
8 | from torchvision.transforms.functional import resize
9 | from tqdm.auto import tqdm
10 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
11 | from typing import List, Optional, Tuple, Union
12 |
13 | from diffusers.models import AutoencoderKL, UNet2DConditionModel
14 | from diffusers.pipeline_utils import DiffusionPipeline
15 | from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
16 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
17 |
18 |
19 | class MaskModes(Enum):
20 | """Modes in which the influence of diffuser is masked"""
21 | CONSTANT = "constant"
22 | GAUSSIAN = "gaussian"
23 | QUARTIC = "quartic" # See https://en.wikipedia.org/wiki/Kernel_(statistics)
24 |
25 |
26 | class RerollModes(Enum):
27 | """Modes in which the reroll regions operate"""
28 | RESET = "reset" # Completely reset the random noise in the region
29 | EPSILON = "epsilon" # Alter slightly the latents in the region
30 |
31 |
32 | @dataclass
33 | class CanvasRegion:
34 | """Class defining a rectangular region in the canvas"""
35 | row_init: int # Region starting row in pixel space (included)
36 | row_end: int # Region end row in pixel space (not included)
37 | col_init: int # Region starting column in pixel space (included)
38 | col_end: int # Region end column in pixel space (not included)
39 | region_seed: int = None # Seed for random operations in this region
40 | noise_eps: float = 0.0 # Deviation of a zero-mean gaussian noise to be applied over the latents in this region. Useful for slightly "rerolling" latents
41 |
42 | def __post_init__(self):
43 | # Initialize arguments if not specified
44 | if self.region_seed is None:
45 | self.region_seed = np.random.randint(9999999999)
46 | # Check coordinates are non-negative
47 | for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:
48 | if coord < 0:
49 | raise ValueError(f"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})")
50 | # Check coordinates are divisible by 8, else we end up with nasty rounding error when mapping to latent space
51 | for coord in [self.row_init, self.row_end, self.col_init, self.col_end]:
52 | if coord // 8 != coord / 8:
53 | raise ValueError(f"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})")
54 | # Check noise eps is non-negative
55 | if self.noise_eps < 0:
56 | raise ValueError(f"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}")
57 | # Compute coordinates for this region in latent space
58 | self.latent_row_init = self.row_init // 8
59 | self.latent_row_end = self.row_end // 8
60 | self.latent_col_init = self.col_init // 8
61 | self.latent_col_end = self.col_end // 8
62 |
63 | @property
64 | def width(self):
65 | return self.col_end - self.col_init
66 |
67 | @property
68 | def height(self):
69 | return self.row_end - self.row_init
70 |
71 | def get_region_generator(self, device="cpu"):
72 | """Creates a torch.Generator based on the random seed of this region"""
73 | # Initialize region generator
74 | return torch.Generator(device).manual_seed(self.region_seed)
75 |
76 | @property
77 | def __dict__(self):
78 | return asdict(self)
79 |
80 |
81 | @dataclass
82 | class DiffusionRegion(CanvasRegion):
83 | """Abstract class defining a region where some class of diffusion process is acting"""
84 | pass
85 |
86 |
87 | @dataclass
88 | class RerollRegion(CanvasRegion):
89 | """Class defining a rectangular canvas region in which initial latent noise will be rerolled"""
90 | reroll_mode: RerollModes = RerollModes.RESET.value
91 |
92 |
93 | @dataclass
94 | class Text2ImageRegion(DiffusionRegion):
95 | """Class defining a region where a text guided diffusion process is acting"""
96 | prompt: str = "" # Text prompt guiding the diffuser in this region
97 | guidance_scale: float = 7.5 # Guidance scale of the diffuser in this region. If None, randomize
98 | mask_type: MaskModes = MaskModes.GAUSSIAN.value # Kind of weight mask applied to this region
99 | mask_weight: float = 1.0 # Global weights multiplier of the mask
100 | tokenized_prompt = None # Tokenized prompt
101 | encoded_prompt = None # Encoded prompt
102 |
103 | def __post_init__(self):
104 | super().__post_init__()
105 | # Mask weight cannot be negative
106 | if self.mask_weight < 0:
107 | raise ValueError(f"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}")
108 | # Mask type must be an actual known mask
109 | if self.mask_type not in [e.value for e in MaskModes]:
110 | raise ValueError(f"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({[e.value for e in MaskModes]})")
111 | # Randomize arguments if given as None
112 | if self.guidance_scale is None:
113 | self.guidance_scale = np.random.randint(5, 30)
114 | # Clean prompt
115 | self.prompt = re.sub(' +', ' ', self.prompt).replace("\n", " ")
116 |
117 | def tokenize_prompt(self, tokenizer):
118 | """Tokenizes the prompt for this diffusion region using a given tokenizer"""
119 | self.tokenized_prompt = tokenizer(self.prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
120 |
121 | def encode_prompt(self, text_encoder, device):
122 | """Encodes the previously tokenized prompt for this diffusion region using a given encoder"""
123 | assert self.tokenized_prompt is not None, ValueError("Prompt in diffusion region must be tokenized before encoding")
124 | self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0]
125 |
126 |
127 | @dataclass
128 | class Image2ImageRegion(DiffusionRegion):
129 | """Class defining a region where an image guided diffusion process is acting"""
130 | reference_image: torch.FloatTensor = None
131 | strength: float = 0.8 # Strength of the image
132 |
133 | def __post_init__(self):
134 | super().__post_init__()
135 | if self.reference_image is None:
136 | raise ValueError("Must provide a reference image when creating an Image2ImageRegion")
137 | if self.strength < 0 or self.strength > 1:
138 | raise ValueError(f'The value of strength should in [0.0, 1.0] but is {self.strength}')
139 | # Rescale image to region shape
140 | self.reference_image = resize(self.reference_image, size=[self.height, self.width])
141 |
142 | def encode_reference_image(self, encoder, device, generator, cpu_vae=False):
143 | """Encodes the reference image for this Image2Image region into the latent space"""
144 | # Place encoder in CPU or not following the parameter cpu_vae
145 | if cpu_vae:
146 | # Note here we use mean instead of sample, to avoid moving also generator to CPU, which is troublesome
147 | self.reference_latents = encoder.cpu().encode(self.reference_image).latent_dist.mean.to(device)
148 | else:
149 | self.reference_latents = encoder.encode(self.reference_image.to(device)).latent_dist.sample(generator=generator)
150 | self.reference_latents = 0.18215 * self.reference_latents
151 |
152 | @property
153 | def __dict__(self):
154 | # This class requires special casting to dict because of the reference_image tensor. Otherwise it cannot be casted to JSON
155 |
156 | # Get all basic fields from parent class
157 | super_fields = {key: getattr(self, key) for key in DiffusionRegion.__dataclass_fields__.keys()}
158 | # Pack other fields
159 | return {
160 | **super_fields,
161 | "reference_image": self.reference_image.cpu().tolist(),
162 | "strength": self.strength
163 | }
164 |
165 |
166 | @dataclass
167 | class MaskWeightsBuilder:
168 | """Auxiliary class to compute a tensor of weights for a given diffusion region"""
169 | latent_space_dim: int # Size of the U-net latent space
170 | nbatch: int = 1 # Batch size in the U-net
171 |
172 | def compute_mask_weights(self, region: DiffusionRegion) -> torch.tensor:
173 | """Computes a tensor of weights for a given diffusion region"""
174 | MASK_BUILDERS = {
175 | MaskModes.CONSTANT.value: self._constant_weights,
176 | MaskModes.GAUSSIAN.value: self._gaussian_weights,
177 | MaskModes.QUARTIC.value: self._quartic_weights,
178 | }
179 | return MASK_BUILDERS[region.mask_type](region)
180 |
181 | def _constant_weights(self, region: DiffusionRegion) -> torch.tensor:
182 | """Computes a tensor of constant for a given diffusion region"""
183 | latent_width = region.latent_col_end - region.latent_col_init
184 | latent_height = region.latent_row_end - region.latent_row_init
185 | return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight
186 |
187 | def _gaussian_weights(self, region: DiffusionRegion) -> torch.tensor:
188 | """Generates a gaussian mask of weights for tile contributions"""
189 | latent_width = region.latent_col_end - region.latent_col_init
190 | latent_height = region.latent_row_end - region.latent_row_init
191 |
192 | var = 0.01
193 | midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
194 | x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)]
195 | midpoint = (latent_height -1) / 2
196 | y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)]
197 |
198 | weights = np.outer(y_probs, x_probs) * region.mask_weight
199 | return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))
200 |
201 | def _quartic_weights(self, region: DiffusionRegion) -> torch.tensor:
202 | """Generates a quartic mask of weights for tile contributions
203 |
204 | The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits.
205 | """
206 | quartic_constant = 15. / 16.
207 |
208 | support = (np.array(range(region.latent_col_init, region.latent_col_end)) - region.latent_col_init) / (region.latent_col_end - region.latent_col_init - 1) * 1.99 - (1.99 / 2.)
209 | x_probs = quartic_constant * np.square(1 - np.square(support))
210 | support = (np.array(range(region.latent_row_init, region.latent_row_end)) - region.latent_row_init) / (region.latent_row_end - region.latent_row_init - 1) * 1.99 - (1.99 / 2.)
211 | y_probs = quartic_constant * np.square(1 - np.square(support))
212 |
213 | weights = np.outer(y_probs, x_probs) * region.mask_weight
214 | return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))
215 |
216 |
217 | class StableDiffusionCanvasPipeline(DiffusionPipeline):
218 | """Stable Diffusion pipeline that mixes several diffusers in the same canvas"""
219 | def __init__(
220 | self,
221 | vae: AutoencoderKL,
222 | text_encoder: CLIPTextModel,
223 | tokenizer: CLIPTokenizer,
224 | unet: UNet2DConditionModel,
225 | scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
226 | safety_checker: StableDiffusionSafetyChecker,
227 | feature_extractor: CLIPFeatureExtractor,
228 | ):
229 | super().__init__()
230 | self.register_modules(
231 | vae=vae,
232 | text_encoder=text_encoder,
233 | tokenizer=tokenizer,
234 | unet=unet,
235 | scheduler=scheduler,
236 | safety_checker=safety_checker,
237 | feature_extractor=feature_extractor,
238 | )
239 |
240 | def decode_latents(self, latents, cpu_vae=False):
241 | """Decodes a given array of latents into pixel space"""
242 | # scale and decode the image latents with vae
243 | if cpu_vae:
244 | lat = deepcopy(latents).cpu()
245 | vae = deepcopy(self.vae).cpu()
246 | else:
247 | lat = latents
248 | vae = self.vae
249 |
250 | lat = 1 / 0.18215 * lat
251 | image = vae.decode(lat).sample
252 |
253 | image = (image / 2 + 0.5).clamp(0, 1)
254 | image = image.cpu().permute(0, 2, 3, 1).numpy()
255 |
256 | return self.numpy_to_pil(image)
257 |
258 | def get_latest_timestep_img2img(self, num_inference_steps, strength):
259 | """Finds the latest timesteps where an img2img strength does not impose latents anymore"""
260 | # get the original timestep using init_timestep
261 | offset = self.scheduler.config.get("steps_offset", 0)
262 | init_timestep = int(num_inference_steps * (1 - strength)) + offset
263 | init_timestep = min(init_timestep, num_inference_steps)
264 |
265 | t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps-1)
266 | latest_timestep = self.scheduler.timesteps[t_start]
267 |
268 | return latest_timestep
269 |
270 | @torch.no_grad()
271 | def __call__(
272 | self,
273 | canvas_height: int,
274 | canvas_width: int,
275 | regions: List[DiffusionRegion],
276 | num_inference_steps: Optional[int] = 50,
277 | seed: Optional[int] = 12345,
278 | reroll_regions: Optional[List[RerollRegion]] = None,
279 | cpu_vae: Optional[bool] = False,
280 | decode_steps: Optional[bool] = False
281 | ):
282 | if reroll_regions is None:
283 | reroll_regions = []
284 | batch_size = 1
285 |
286 | if decode_steps:
287 | steps_images = []
288 |
289 | # Prepare scheduler
290 | self.scheduler.set_timesteps(num_inference_steps, device=self.device)
291 |
292 | # Split diffusion regions by their kind
293 | text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)]
294 | image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)]
295 |
296 | # Prepare text embeddings
297 | for region in text2image_regions:
298 | region.tokenize_prompt(self.tokenizer)
299 | region.encode_prompt(self.text_encoder, self.device)
300 |
301 | # Create original noisy latents using the timesteps
302 | latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8)
303 | generator = torch.Generator(self.device).manual_seed(seed)
304 | init_noise = torch.randn(latents_shape, generator=generator, device=self.device)
305 |
306 | # Reset latents in seed reroll regions, if requested
307 | for region in reroll_regions:
308 | if region.reroll_mode == RerollModes.RESET.value:
309 | region_shape = (latents_shape[0], latents_shape[1], region.latent_row_end - region.latent_row_init, region.latent_col_end - region.latent_col_init)
310 | init_noise[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] = torch.randn(region_shape, generator=region.get_region_generator(self.device), device=self.device)
311 |
312 | # Apply epsilon noise to regions: first diffusion regions, then reroll regions
313 | all_eps_rerolls = regions + [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value]
314 | for region in all_eps_rerolls:
315 | if region.noise_eps > 0:
316 | region_noise = init_noise[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end]
317 | eps_noise = torch.randn(region_noise.shape, generator=region.get_region_generator(self.device), device=self.device) * region.noise_eps
318 | init_noise[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] += eps_noise
319 |
320 | # scale the initial noise by the standard deviation required by the scheduler
321 | latents = init_noise * self.scheduler.init_noise_sigma
322 |
323 | # Get unconditional embeddings for classifier free guidance in text2image regions
324 | for region in text2image_regions:
325 | max_length = region.tokenized_prompt.input_ids.shape[-1]
326 | uncond_input = self.tokenizer(
327 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
328 | )
329 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
330 |
331 | # For classifier free guidance, we need to do two forward passes.
332 | # Here we concatenate the unconditional and text embeddings into a single batch
333 | # to avoid doing two forward passes
334 | region.encoded_prompt = torch.cat([uncond_embeddings, region.encoded_prompt])
335 |
336 | # Prepare image latents
337 | for region in image2image_regions:
338 | region.encode_reference_image(self.vae, device=self.device, generator=generator)
339 |
340 | # Prepare mask of weights for each region
341 | mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size)
342 | mask_weights = [mask_builder.compute_mask_weights(region).to(self.device) for region in text2image_regions]
343 |
344 | # Diffusion timesteps
345 | for i, t in tqdm(enumerate(self.scheduler.timesteps)):
346 | # Diffuse each region
347 | noise_preds_regions = []
348 |
349 | # text2image regions
350 | for region in text2image_regions:
351 | region_latents = latents[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end]
352 | # expand the latents if we are doing classifier free guidance
353 | latent_model_input = torch.cat([region_latents] * 2)
354 | # scale model input following scheduler rules
355 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
356 | # predict the noise residual
357 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)["sample"]
358 | # perform guidance
359 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
360 | noise_pred_region = noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond)
361 | noise_preds_regions.append(noise_pred_region)
362 |
363 | # Merge noise predictions for all tiles
364 | noise_pred = torch.zeros(latents.shape, device=self.device)
365 | contributors = torch.zeros(latents.shape, device=self.device)
366 | # Add each tile contribution to overall latents
367 | for region, noise_pred_region, mask_weights_region in zip(text2image_regions, noise_preds_regions, mask_weights):
368 | noise_pred[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] += noise_pred_region * mask_weights_region
369 | contributors[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] += mask_weights_region
370 | # Average overlapping areas with more than 1 contributor
371 | noise_pred /= contributors
372 | noise_pred = torch.nan_to_num(noise_pred) # Replace NaNs by zeros: NaN can appear if a position is not covered by any DiffusionRegion
373 |
374 | # compute the previous noisy sample x_t -> x_t-1
375 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample
376 |
377 | # Image2Image regions: override latents generated by the scheduler
378 | for region in image2image_regions:
379 | influence_step = self.get_latest_timestep_img2img(num_inference_steps, region.strength)
380 | # Only override in the timesteps before the last influence step of the image (given by its strength)
381 | if t > influence_step:
382 | timestep = t.repeat(batch_size)
383 | region_init_noise = init_noise[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end]
384 | region_latents = self.scheduler.add_noise(region.reference_latents, region_init_noise, timestep)
385 | latents[:, :, region.latent_row_init:region.latent_row_end, region.latent_col_init:region.latent_col_end] = region_latents
386 |
387 | if decode_steps:
388 | steps_images.append(self.decode_latents(latents, cpu_vae))
389 |
390 | # scale and decode the image latents with vae
391 | image = self.decode_latents(latents, cpu_vae)
392 |
393 | output = {"sample": image}
394 | if decode_steps:
395 | output = {**output, "steps_images": steps_images}
396 | return output
397 |
--------------------------------------------------------------------------------
/mixdiff/extrasmixin.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | # TODO: remove after adapting tiling to use canvas
4 | class StableDiffusionExtrasMixin:
5 |     """Mixin providing additional convenience method to Stable Diffusion pipelines"""
6 |
7 |     def decode_latents(self, latents, cpu_vae=False):
8 |         """Decodes a given array of latents into pixel space"""
9 |         # scale and decode the image latents with vae
10 |         if cpu_vae:
11 |             lat = deepcopy(latents).cpu()
12 |             vae = deepcopy(self.vae).cpu()
13 |         else:
14 |             lat = latents
15 |             vae = self.vae
16 |
17 |         lat = 1 / 0.18215 * lat
18 |         image = vae.decode(lat).sample
19 |
20 |         image = (image / 2 + 0.5).clamp(0, 1)
21 |         image = image.cpu().permute(0, 2, 3, 1).numpy()
22 |
23 |         return self.numpy_to_pil(image)
24 |
--------------------------------------------------------------------------------
/mixdiff/imgtools.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image, ImageFilter
4 |
5 |
6 | def preprocess_image(image):
7 | """Preprocess an input image
8 |
9 | Same as https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44
10 | """
11 | w, h = image.size
12 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
13 | image = image.resize((w, h), resample=Image.LANCZOS)
14 | image = np.array(image).astype(np.float32) / 255.0
15 | image = image[None].transpose(0, 3, 1, 2)
16 | image = torch.from_numpy(image)
17 | return 2.0 * image - 1.0
18 |
19 |
20 | def preprocess_mask(mask, smoothing=None):
21 | """Preprocess an inpainting mask"""
22 | mask = mask.convert("L")
23 | if smoothing is not None:
24 | smoothed = mask.filter(ImageFilter.GaussianBlur(smoothing))
25 | mask = Image.composite(mask, smoothed, mask) # Original mask values kept as 1, out of mask get smoothed
26 | mask.save("outputs/smoothed_mask.png") # FIXME: debug output; fails unless an "outputs" directory exists
27 | w, h = mask.size
28 | w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
29 | mask = mask.resize((w // 8, h // 8), resample=Image.NEAREST)
30 | mask = np.array(mask).astype(np.float32) / 255.0
31 | mask = np.tile(mask, (4, 1, 1))
32 | mask = mask[None].transpose(0, 1, 2, 3) # add a batch dimension; the (0, 1, 2, 3) transpose is an identity permutation (no-op)
33 | mask = 1 - mask # repaint white, keep black
34 | mask = torch.from_numpy(mask)
35 | return mask
36 |
--------------------------------------------------------------------------------
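A usage sketch (not part of the repository) for the two helpers above. The example image path comes from the examples/ folder of this repository; the mask file is hypothetical.

from PIL import Image
from mixdiff.imgtools import preprocess_image, preprocess_mask

image = Image.open("examples/IIC.png").convert("RGB")
image_tensor = preprocess_image(image)   # (1, 3, H, W) float tensor scaled to [-1, 1]

mask = Image.open("examples/mask.png")   # hypothetical inpainting mask (white = repaint, black = keep)
mask_tensor = preprocess_mask(mask)      # (1, 4, H/8, W/8) tensor with values inverted to 1 - mask
print(image_tensor.shape, mask_tensor.shape)
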
/mixdiff/tiling.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | import inspect
3 | from ligo.segments import segment
4 | from typing import List, Optional, Tuple, Union
5 |
6 | import torch
7 |
8 | from tqdm.auto import tqdm
9 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
10 |
11 | from diffusers.models import AutoencoderKL, UNet2DConditionModel
12 | from diffusers.pipeline_utils import DiffusionPipeline
13 | from diffusers.schedulers import DDIMScheduler, PNDMScheduler
14 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
15 | from diffusers.schedulers import LMSDiscreteScheduler
16 |
17 | from .extrasmixin import StableDiffusionExtrasMixin
18 |
19 |
20 | class StableDiffusionTilingPipeline(DiffusionPipeline, StableDiffusionExtrasMixin):
21 | def __init__(
22 | self,
23 | vae: AutoencoderKL,
24 | text_encoder: CLIPTextModel,
25 | tokenizer: CLIPTokenizer,
26 | unet: UNet2DConditionModel,
27 | scheduler: Union[DDIMScheduler, PNDMScheduler],
28 | safety_checker: StableDiffusionSafetyChecker,
29 | feature_extractor: CLIPFeatureExtractor,
30 | ):
31 | super().__init__()
32 | self.register_modules(
33 | vae=vae,
34 | text_encoder=text_encoder,
35 | tokenizer=tokenizer,
36 | unet=unet,
37 | scheduler=scheduler,
38 | safety_checker=safety_checker,
39 | feature_extractor=feature_extractor,
40 | )
41 |
42 | class SeedTilesMode(Enum):
43 | """Modes in which the latents of a particular tile can be re-seeded"""
44 | FULL = "full"
45 | EXCLUSIVE = "exclusive"
46 |
47 | @torch.no_grad()
48 | def __call__(
49 | self,
50 | prompt: Union[str, List[List[str]]],
51 | num_inference_steps: Optional[int] = 50,
52 | guidance_scale: Optional[float] = 7.5,
53 | eta: Optional[float] = 0.0,
54 | seed: Optional[int] = None,
55 | tile_height: Optional[int] = 512,
56 | tile_width: Optional[int] = 512,
57 | tile_row_overlap: Optional[int] = 256,
58 | tile_col_overlap: Optional[int] = 256,
59 | guidance_scale_tiles: Optional[List[List[float]]] = None,
60 | seed_tiles: Optional[List[List[int]]] = None,
61 | seed_tiles_mode: Optional[Union[str, List[List[str]]]] = "full",
62 | seed_reroll_regions: Optional[List[Tuple[int, int, int, int, int]]] = None,
63 | cpu_vae: Optional[bool] = False,
64 | ):
65 |
66 | if not isinstance(prompt, list) or not all(isinstance(row, list) for row in prompt):
67 | raise ValueError(f"`prompt` has to be a list of lists but is {type(prompt)}")
68 | grid_rows = len(prompt)
69 | grid_cols = len(prompt[0])
70 | if not all(len(row) == grid_cols for row in prompt):
71 | raise ValueError(f"All prompt rows must have the same number of prompt columns")
72 | if not isinstance(seed_tiles_mode, str) and (not isinstance(seed_tiles_mode, list) or not all(isinstance(row, list) for row in seed_tiles_mode)):
73 | raise ValueError(f"`seed_tiles_mode` has to be a string or list of lists but is {type(prompt)}")
74 | if isinstance(seed_tiles_mode, str):
75 | seed_tiles_mode = [[seed_tiles_mode for _ in range(len(row))] for row in prompt]
76 | if any(mode not in (modes := [mode.value for mode in self.SeedTilesMode]) for row in seed_tiles_mode for mode in row):
77 | raise ValueError(f"Seed tiles mode must be one of {modes}")
78 | if seed_reroll_regions is None:
79 | seed_reroll_regions = []
80 | batch_size = 1
81 |
82 | # create original noisy latents using the timesteps
83 | height = tile_height + (grid_rows - 1) * (tile_height - tile_row_overlap)
84 | width = tile_width + (grid_cols - 1) * (tile_width - tile_col_overlap)
85 | latents_shape = (batch_size, self.unet.config.in_channels, height // 8, width // 8)
86 | generator = torch.Generator("cuda").manual_seed(seed)
87 | latents = torch.randn(latents_shape, generator=generator, device=self.device)
88 |
89 | # overwrite latents for specific tiles if provided
90 | if seed_tiles is not None:
91 | for row in range(grid_rows):
92 | for col in range(grid_cols):
93 | if (seed_tile := seed_tiles[row][col]) is not None:
94 | mode = seed_tiles_mode[row][col]
95 | if mode == self.SeedTilesMode.FULL.value:
96 | row_init, row_end, col_init, col_end = _tile2latent_indices(row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap)
97 | else:
98 | row_init, row_end, col_init, col_end = _tile2latent_exclusive_indices(row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, grid_rows, grid_cols)
99 | tile_generator = torch.Generator("cuda").manual_seed(seed_tile)
100 | tile_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
101 | latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(tile_shape, generator=tile_generator, device=self.device)
102 |
103 | # overwrite again for seed reroll regions
104 | for row_init, row_end, col_init, col_end, seed_reroll in seed_reroll_regions:
105 | row_init, row_end, col_init, col_end = _pixel2latent_indices(row_init, row_end, col_init, col_end) # to latent space coordinates
106 | reroll_generator = torch.Generator("cuda").manual_seed(seed_reroll)
107 | region_shape = (latents_shape[0], latents_shape[1], row_end - row_init, col_end - col_init)
108 | latents[:, :, row_init:row_end, col_init:col_end] = torch.randn(region_shape, generator=reroll_generator, device=self.device)
109 |
110 | # Prepare scheduler
111 | accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
112 | extra_set_kwargs = {}
113 | if accepts_offset:
114 | extra_set_kwargs["offset"] = 1
115 | self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
116 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas
117 | if isinstance(self.scheduler, LMSDiscreteScheduler):
118 | latents = latents * self.scheduler.sigmas[0]
119 |
120 | # get prompts text embeddings
121 | text_input = [
122 | [
123 | self.tokenizer(
124 | col,
125 | padding="max_length",
126 | max_length=self.tokenizer.model_max_length,
127 | truncation=True,
128 | return_tensors="pt",
129 | )
130 | for col in row
131 | ]
132 | for row in prompt
133 | ]
134 | text_embeddings = [
135 | [
136 | self.text_encoder(col.input_ids.to(self.device))[0]
137 | for col in row
138 | ]
139 | for row in text_input
140 | ]
141 |
142 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
143 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
144 | # corresponds to doing no classifier free guidance.
145 | do_classifier_free_guidance = guidance_scale > 1.0 # TODO: should also be active if any tile has its own guidance_scale > 1
146 | # get unconditional embeddings for classifier free guidance
147 | if do_classifier_free_guidance:
148 | for i in range(grid_rows):
149 | for j in range(grid_cols):
150 | max_length = text_input[i][j].input_ids.shape[-1]
151 | uncond_input = self.tokenizer(
152 | [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
153 | )
154 | uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
155 |
156 | # For classifier free guidance, we need to do two forward passes.
157 | # Here we concatenate the unconditional and text embeddings into a single batch
158 | # to avoid doing two forward passes
159 | text_embeddings[i][j] = torch.cat([uncond_embeddings, text_embeddings[i][j]])
160 |
161 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
162 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
163 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
164 | # and should be between [0, 1]
165 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
166 | extra_step_kwargs = {}
167 | if accepts_eta:
168 | extra_step_kwargs["eta"] = eta
169 |
170 | # Mask of weights for tile contribution strength
171 | tile_weights = self._gaussian_weights(tile_width, tile_height, batch_size)
172 |
173 | # Diffusion timesteps
174 | for i, t in tqdm(enumerate(self.scheduler.timesteps)):
175 | # Diffuse each tile
176 | noise_preds = []
177 | for row in range(grid_rows):
178 | noise_preds_row = []
179 | for col in range(grid_cols):
180 | px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap)
181 | tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
182 | # expand the latents if we are doing classifier free guidance
183 | latent_model_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents
184 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
185 | # predict the noise residual
186 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings[row][col])["sample"]
187 | # perform guidance
188 | if do_classifier_free_guidance:
189 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
190 | guidance = guidance_scale if guidance_scale_tiles is None or guidance_scale_tiles[row][col] is None else guidance_scale_tiles[row][col]
191 | noise_pred_tile = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)
192 | noise_preds_row.append(noise_pred_tile)
193 | noise_preds.append(noise_preds_row)
194 | # Stitch noise predictions for all tiles
195 | noise_pred = torch.zeros(latents.shape, device=self.device)
196 | contributors = torch.zeros(latents.shape, device=self.device)
197 | # Add each tile contribution to overall latents
198 | for row in range(grid_rows):
199 | for col in range(grid_cols):
200 | px_row_init, px_row_end, px_col_init, px_col_end = _tile2latent_indices(row, col, tile_width, tile_height, tile_row_overlap, tile_col_overlap)
201 | noise_pred[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += noise_preds[row][col] * tile_weights
202 | contributors[:, :, px_row_init:px_row_end, px_col_init:px_col_end] += tile_weights
203 | # Average overlapping areas with more than 1 contributor
204 | noise_pred /= contributors
205 |
206 | # compute the previous noisy sample x_t -> x_t-1
207 | latents = self.scheduler.step(noise_pred, t, latents).prev_sample
208 |
209 | # scale and decode the image latents with vae
210 | image = self.decode_latents(latents, cpu_vae)
211 |
212 | return {"sample": image}
213 |
214 | def _gaussian_weights(self, tile_width, tile_height, nbatches):
215 | """Generates a gaussian mask of weights for tile contributions"""
216 | from numpy import pi, exp, sqrt
217 | import numpy as np
218 |
219 | latent_width = tile_width // 8
220 | latent_height = tile_height // 8
221 |
222 | var = 0.01
223 | midpoint = (latent_width - 1) / 2 # -1 because index goes from 0 to latent_width - 1
224 | x_probs = [exp(-(x-midpoint)*(x-midpoint)/(latent_width*latent_width)/(2*var)) / sqrt(2*pi*var) for x in range(latent_width)]
225 | midpoint = latent_height / 2
226 | y_probs = [exp(-(y-midpoint)*(y-midpoint)/(latent_height*latent_height)/(2*var)) / sqrt(2*pi*var) for y in range(latent_height)]
227 |
228 | weights = np.outer(y_probs, x_probs)
229 | return torch.tile(torch.tensor(weights, device=self.device), (nbatches, self.unet.config.in_channels, 1, 1))
230 |
231 |
232 |
233 | def _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
234 | """Given a tile row and column numbers returns the range of pixels affected by that tiles in the overall image
235 |
236 | Returns a tuple with:
237 | - Starting coordinates of rows in pixel space
238 | - Ending coordinates of rows in pixel space
239 | - Starting coordinates of columns in pixel space
240 | - Ending coordinates of columns in pixel space
241 | """
242 | px_row_init = 0 if tile_row == 0 else tile_row * (tile_height - tile_row_overlap)
243 | px_row_end = px_row_init + tile_height
244 | px_col_init = 0 if tile_col == 0 else tile_col * (tile_width - tile_col_overlap)
245 | px_col_end = px_col_init + tile_width
246 | return px_row_init, px_row_end, px_col_init, px_col_end
247 |
248 |
249 | def _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end):
250 | """Translates coordinates in pixel space to coordinates in latent space"""
251 | return px_row_init // 8, px_row_end // 8, px_col_init // 8, px_col_end // 8
252 |
253 |
254 | def _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap):
255 | """Given a tile row and column numbers returns the range of latents affected by that tiles in the overall image
256 |
257 | Returns a tuple with:
258 | - Starting coordinates of rows in latent space
259 | - Ending coordinates of rows in latent space
260 | - Starting coordinates of columns in latent space
261 | - Ending coordinates of columns in latent space
262 | """
263 | px_row_init, px_row_end, px_col_init, px_col_end = _tile2pixel_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap)
264 | return _pixel2latent_indices(px_row_init, px_row_end, px_col_init, px_col_end)
265 |
266 |
267 | def _tile2latent_exclusive_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap, rows, columns):
268 | """Given a tile row and column numbers returns the range of latents affected only by that tile in the overall image
269 |
270 | Returns a tuple with:
271 | - Starting coordinates of rows in latent space
272 | - Ending coordinates of rows in latent space
273 | - Starting coordinates of columns in latent space
274 | - Ending coordinates of columns in latent space
275 | """
276 | row_init, row_end, col_init, col_end = _tile2latent_indices(tile_row, tile_col, tile_width, tile_height, tile_row_overlap, tile_col_overlap)
277 | row_segment = segment(row_init, row_end)
278 | col_segment = segment(col_init, col_end)
279 | # Iterate over the rest of tiles, clipping the region for the current tile
280 | for row in range(rows):
281 | for column in range(columns):
282 | if row != tile_row and column != tile_col:
283 | clip_row_init, clip_row_end, clip_col_init, clip_col_end = _tile2latent_indices(row, column, tile_width, tile_height, tile_row_overlap, tile_col_overlap)
284 | row_segment = row_segment - segment(clip_row_init, clip_row_end)
285 | col_segment = col_segment - segment(clip_col_init, clip_col_end)
286 | #return row_init, row_end, col_init, col_end
287 | return row_segment[0], row_segment[1], col_segment[0], col_segment[1]
288 |
--------------------------------------------------------------------------------
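A usage sketch (not part of the repository) for the tiling pipeline above: a 1x2 grid of 512x512 tiles with the default 256-pixel column overlap yields a 512x768 canvas (width = 512 + (2 - 1) * (512 - 256) = 768). The model id, prompts, and scheduler settings are assumptions, and a CUDA device is required because the pipeline creates its generators on "cuda".

from diffusers import LMSDiscreteScheduler
from mixdiff.tiling import StableDiffusionTilingPipeline

scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
pipe = StableDiffusionTilingPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=scheduler).to("cuda")

output = pipe(
    prompt=[["A mountain lake at dawn, oil painting", "A pine forest at dawn, oil painting"]],  # 1 row x 2 columns
    seed=12345,
    num_inference_steps=50,
    tile_width=512,
    tile_height=512,
    tile_row_overlap=256,
    tile_col_overlap=256,
)
output["sample"][0].save("tiled.png")
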
/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.10.*
2 | diffusers[torch]==0.16.*
3 | ftfy==6.1.*
4 | gitpython==3.1.*
5 | ligo-segments==1.4.*
6 | torch==2.0.*
7 | torchvision==0.15.*
8 | transformers==4.29.*
9 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | from setuptools import setup
3 |
4 |
5 | # Read long description from readme
6 | with open("README.md", "r", encoding="utf-8") as fh:
7 | LONG_DESCRIPTION = fh.read()
8 |
9 |
10 | # Get tag from Github environment variables
11 | TAG = os.environ['GITHUB_TAG'] if 'GITHUB_TAG' in os.environ else "0.0.0"
12 |
13 |
14 | setup(
15 | name="mixdiff",
16 | version=TAG,
17 | description="Mixture of Diffusers for scene composition and high-resolution image generation.",
18 | long_description=LONG_DESCRIPTION,
19 | long_description_content_type="text/markdown",
20 | packages=['mixdiff'],
21 | install_requires=[
22 | 'numpy>=1.19,<2',
23 | 'tqdm>=4.62,<5',
24 | 'scipy==1.10.*',
25 | 'diffusers[torch]==0.16.*',
26 | 'ftfy==6.1.*',
27 | 'gitpython==3.1.*',
28 | 'ligo-segments==1.4.*',
29 | 'torch==2.0.*',
30 | 'torchvision==0.15.*',
31 | 'transformers==4.29.*'
32 | ],
33 | author="Alvaro Barbero",
34 | url='https://github.com/albarji/mixture-of-diffusers',
35 | license='MIT',
36 | classifiers=[
37 | 'Development Status :: 1 - Planning',
38 | 'Environment :: GPU :: NVIDIA CUDA',
39 | 'Intended Audience :: Science/Research',
40 | 'License :: OSI Approved :: MIT License',
41 | 'Operating System :: Unix',
42 | 'Programming Language :: Python :: 3.8',
43 | 'Programming Language :: Python :: 3.9',
44 | 'Programming Language :: Python :: 3.10',
45 | 'Topic :: Artistic Software',
46 | 'Topic :: Multimedia :: Graphics'
47 | ],
48 | keywords='artificial-intelligence, deep-learning, diffusion-models',
49 | test_suite="pytest",
50 | )
--------------------------------------------------------------------------------
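The package version above is taken from the GITHUB_TAG environment variable, falling back to "0.0.0" when unset. A sketch (not part of the repository) of building a local distribution with an explicit version; the tag value is hypothetical and the `build` package must be installed:

import os
import subprocess

os.environ["GITHUB_TAG"] = "1.2.3"  # hypothetical release tag picked up by setup.py
subprocess.run(["python", "-m", "build"], check=True)  # produces dist/mixdiff-1.2.3.tar.gz and a matching wheel
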
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/albarji/mixture-of-diffusers/af42292d0a8cb414f6da2eeac79be4c60afbbe48/tests/__init__.py
--------------------------------------------------------------------------------
/tests/canvas_test.py:
--------------------------------------------------------------------------------
1 | from diffusers import LMSDiscreteScheduler
2 | from PIL import Image
3 | import pytest
4 |
5 | from mixdiff import Image2ImageRegion, StableDiffusionCanvasPipeline, Text2ImageRegion, preprocess_image
6 | from mixdiff.canvas import CanvasRegion
7 |
8 | ### CanvasRegion tests
9 |
10 | @pytest.mark.parametrize("region_params", [
11 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512},
12 | {"row_init": 0, "row_end": 256, "col_init": 0, "col_end": 512},
13 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 256},
14 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "region_seed": 12345},
15 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "noise_eps": 0.1}
16 | ])
17 | def test_create_canvas_region_correct(region_params):
18 | """Creating a correct canvas region with basic parameters works"""
19 | region = CanvasRegion(**region_params)
20 | assert region.row_init == region_params["row_init"]
21 | assert region.row_end == region_params["row_end"]
22 | assert region.col_init == region_params["col_init"]
23 | assert region.col_end == region_params["col_end"]
24 | assert region.height == region.row_end - region.row_init
25 | assert region.width == region.col_end - region.col_init
26 |
27 | def test_create_canvas_region_eps():
28 | """Creating a correct canvas region works"""
29 | CanvasRegion(0, 512, 0, 512)
30 |
31 | def test_create_canvas_region_non_multiple_size():
32 | """Creating a canvas region with sizes that are not a multiple of 8 fails"""
33 | with pytest.raises(ValueError):
34 | CanvasRegion(0, 17, 0, 15)
35 |
36 | def test_create_canvas_region_negative_indices():
37 | """Creating a canvas region with negative indices fails"""
38 | with pytest.raises(ValueError):
39 | CanvasRegion(-512, 0, -256, 0)
40 |
41 | def test_create_canvas_region_negative_eps():
42 | """Creating a canvas region with negative epsilon noise fails"""
43 | with pytest.raises(ValueError):
44 | CanvasRegion(0, 512, 0, 512, noise_eps=-3)
45 |
46 | ### Text2ImageRegion tests
47 |
48 | @pytest.mark.parametrize("region_params", [
49 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512},
50 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers"},
51 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers", "guidance_scale": 15.},
52 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers", "guidance_scale": 15., "mask_type": "constant", "mask_weight": 1.0},
53 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers", "guidance_scale": 15., "mask_type": "gaussian", "mask_weight": 0.75},
54 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "prompt": "Pikachu unit-testing Mixture of Diffusers", "guidance_scale": 15., "mask_type": "quartic", "mask_weight": 1.5},
55 | ])
56 | def test_create_text2image_region_correct(region_params):
57 | """Creating a Text2Image region with correct parameters works"""
58 | region = Text2ImageRegion(**region_params)
59 | if "prompt" in region_params: assert region.prompt == region_params["prompt"]
60 | if "guidance_scale" in region_params: assert region.guidance_scale == region_params["guidance_scale"]
61 | if "mask_type" in region_params: assert region.mask_type == region_params["mask_type"]
62 | if "mask_weight" in region_params: assert region.mask_weight == region_params["mask_weight"]
63 |
64 | def test_create_text2image_region_negative_weight():
65 | """We can't specify a Text2Image region with mask weight"""
66 | with pytest.raises(ValueError):
67 | Text2ImageRegion(0, 512, 0, 512, prompt="Pikachu unit-testing Mixture of Diffusers", mask_type="gaussian", mask_weight=-0.1)
68 |
69 | def test_create_text2image_region_unknown_mask():
70 | """We can't specify a Text2Image region with mask not in the recognized masks list"""
71 | with pytest.raises(ValueError):
72 | Text2ImageRegion(0, 512, 0, 512, prompt="Link unit-testing Mixture of Diffusers", mask_type="majora", mask_weight=1.0)
73 |
74 | ### Image2ImageRegion tests
75 |
76 | @pytest.fixture(scope="session")
77 | def base_image():
78 | return preprocess_image(Image.open("examples/IIC.png").convert("RGB"))
79 |
80 | @pytest.mark.parametrize("region_params", [
81 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512},
82 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "strength": 0.1},
83 | {"row_init": 0, "row_end": 512, "col_init": 0, "col_end": 512, "strength": 0.9},
84 | ])
85 | def test_create_image2image_region_correct(region_params, base_image):
86 | """Creating a Image2Image region with correct parameters works"""
87 | region = Image2ImageRegion(**region_params, reference_image=base_image)
88 | assert region.reference_image.shape == (1, 3, region.height, region.width)
89 |
90 | @pytest.mark.parametrize("region_params", [
91 | {"row_init": 0, "row_end": 256, "col_init": 0, "col_end": 256, "strength": -0.3},
92 | {"row_init": 0, "row_end": 256, "col_init": 0, "col_end": 256, "strength": 1.1}
93 | ])
94 | def test_create_image2image_region_invalid_strength(region_params, base_image):
95 | """We can't specify an Image2Image region with strength values outside of [0, 1]"""
96 | with pytest.raises(ValueError):
97 | Image2ImageRegion(**region_params, reference_image=base_image)
98 |
99 | ### StableDiffusionCanvasPipeline tests
100 |
101 | @pytest.fixture(scope="session")
102 | def canvas_pipeline():
103 | return StableDiffusionCanvasPipeline.from_pretrained(
104 | "CompVis/stable-diffusion-v1-4",
105 | scheduler=LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000),
106 | use_auth_token=True
107 | )
108 |
109 | @pytest.fixture()
110 | def basic_canvas_params():
111 | return {
112 | "canvas_height": 64,
113 | "canvas_width": 64,
114 | "regions": [
115 | Text2ImageRegion(0, 48, 0, 48, mask_type="gaussian", prompt="Something"),
116 | Text2ImageRegion(16, 64, 0, 48, mask_type="gaussian", prompt="Something else"),
117 | Text2ImageRegion(0, 48, 16, 64, mask_type="gaussian", prompt="Something more"),
118 | Text2ImageRegion(16, 64, 16, 64, mask_type="gaussian", prompt="One last thing"),
119 | ]
120 | }
121 |
122 | @pytest.mark.parametrize("extra_canvas_params", [
123 | {"num_inference_steps": 1},
124 | {"num_inference_steps": 1, "cpu_vae": True},
125 | ])
126 | def test_stable_diffusion_canvas_pipeline_correct(canvas_pipeline, basic_canvas_params, extra_canvas_params):
127 | """The StableDiffusionCanvasPipeline works for some correct configurations"""
128 | image = canvas_pipeline(**basic_canvas_params, **extra_canvas_params)["sample"][0]
129 | assert image.size == (64, 64)
130 |
131 | @pytest.mark.parametrize("extra_canvas_params", [
132 | {"num_inference_steps": 3},
133 | {"num_inference_steps": 3, "cpu_vae": True},
134 | ])
135 | def test_stable_diffusion_canvas_pipeline_image2image_correct(canvas_pipeline, basic_canvas_params, base_image, extra_canvas_params):
136 | """The StableDiffusionCanvasPipeline works for some correct configurations when including a Text2ImageRegion"""
137 | all_canvas_params = {**basic_canvas_params, **extra_canvas_params}
138 | all_canvas_params["regions"] += [Image2ImageRegion(16, 64, 0, 48, reference_image=base_image, strength=0.5)]
139 |
140 | image = canvas_pipeline(**all_canvas_params)["sample"][0]
141 | assert image.size == (64, 64)
142 |
--------------------------------------------------------------------------------
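A sketch (not part of the repository) of running the canvas tests above. The pipeline fixture downloads CompVis/stable-diffusion-v1-4 with use_auth_token=True, so a Hugging Face token must be configured (for example via huggingface-cli login); the -k filter shown is only an illustration and restricts the run to the lightweight CanvasRegion tests.

import pytest

# Assumes the working directory is the repository root, so that "examples/IIC.png" resolves.
pytest.main(["tests/canvas_test.py", "-k", "canvas_region"])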