├── .gitignore
├── .style.yapf
├── COMPOSITION.md
├── LICENSE
├── README.md
├── __init__.py
├── __main__.py
├── composition
    ├── embeds.py
    ├── guide.py
    └── schema.py
├── encode
    └── clip.py
├── experiments
    ├── city_photo.png
    ├── city_photo_forest_blend.png
    ├── creepy_tree.png
    ├── creepy_tree_mapped.png
    ├── deer_base.png
    ├── deer_img2img_base.png
    ├── deer_img2img_defaults.png
    ├── deer_mod.webp
    ├── deer_modded_defaults.png
    ├── deer_tuned.png
    ├── deer_tuned_linear_only.png
    ├── settings.png
    ├── turtle_base.png
    ├── turtle_mod.webp
    ├── turtle_modded_defaults.png
    ├── turtle_to_rock_monkey.png
    ├── turtle_tuned.png
    ├── ui.png
    ├── zeus.webp
    ├── zeus_anime.png
    ├── zeus_mod.webp
    ├── zeus_modded_defaults.png
    └── zeus_tuned.png
├── guidance.py
├── interface
    ├── composer.py
    └── sandbox.py
├── pipeline
    ├── flex.py
    └── guide.py
├── requirements.txt
├── ui.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | old/
3 | outputs/
4 | dev_notes.md
--------------------------------------------------------------------------------
/.style.yapf:
--------------------------------------------------------------------------------
1 | [style]
2 | based_on_style = google
3 | column_limit = 80
4 | indent_width = 4
5 | spaces_before_comment = 1
6 | blank_line_before_nested_class_or_def = false
7 | split_before_arithmetic_operator = true
8 | split_before_bitwise_operator = true
9 | split_before_logical_operator = true
10 |
--------------------------------------------------------------------------------
/COMPOSITION.md:
--------------------------------------------------------------------------------
1 | # Composition Experimentation
2 |
3 | Composition is accessible through the second tab in the UI.
4 | Right now it functions as a rough proof of concept, but it shows
5 | how guidance can be applied to a certain area.
6 |
7 | ## Why?
8 |
9 | A compositional meta language or structure is essential for repeatable, template-driven asset generation...
10 | Say you want to generate unique character portraits for an entire game.. You would want to define a general artistic style, a base prompt, and then a modifier prompt or prompts that further tune aspects of the character.. a simple example could be the face.
11 |
12 | ## What?
13 |
14 | The current implementation here takes multiple prompts, and then blends their noise guidance within a certain space.
15 | What is interesting is that the "background" or main prompt will adapt the space and noise around it to match the targeted guidance space.
16 | A simple example is to place a bear in the middle of the image and provide no background prompt. You will see that it guides the background around the bear to match the heavily guided bear area in the middle.
17 |
18 | ## How?
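At each diffusion step the UNet is run for the background prompt and for every entity prompt, and each entity's noise prediction is blended into the background prediction only inside that entity's box in latent block space ( 1/8th of the output pixels ). The snippet below is a simplified sketch of just that blending step, with standalone tensors and an illustrative function name rather than the exact code from this repo:

```python
import torch


def blend_entity_noise(bg_noise: torch.Tensor, entity_noise: torch.Tensor,
                       offset_blocks: tuple[int, int],
                       size_blocks: tuple[int, int],
                       blend: float) -> torch.Tensor:
    '''Blend an entity's noise prediction into the background prediction.

    Both tensors are (batch, channel, height, width) in latent blocks; only
    the entity's box is touched, everything outside keeps background guidance.
    '''
    ow, oh = offset_blocks
    sw, sh = size_blocks
    bw, bh = ow + sw, oh + sh
    bgs = bg_noise[:, :, oh:bh, ow:bw]
    ens = entity_noise[:, :, oh:bh, ow:bw]
    bg_noise[:, :, oh:bh, ow:bw] = bgs + blend * (ens - bgs)
    return bg_noise
```

In the actual code this happens per batch item inside `CompositeGuide._guide_latents`, with classifier-free guidance applied after the blend.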
19 |
20 | See: `composition/guide.py` for the full implementation.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Tim Speed
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # "FlexDiffuse"
2 | ## *An adaptation of Stable Diffusion with Image Guidance.*
3 |
4 | ## TLDR
5 |
6 | You can find the code for all of this in `guidance.py`; feel free to copy and
7 | reference it as you desire, I provide no license or warranty. You can probably
8 | `pip install` this repo as a module and reference the file that way too if
9 | you'd like; I plan to make some additional improvements.
10 |
11 | ### What this is
12 |
13 | - Methods to map styles of an image using CLIP image embeddings onto the CLIP text
14 | embeddings after encoding the prompt, to allow reaching new spaces in prompt
15 | guidance.
16 | - A hopefully easy-to-use web UI for generating images
17 |
18 | ### What this isn't
19 |
20 | - **Img2Img**: The methods shown in this repo apply to image guidance, not the
21 | origin image; in this demo I show how Img2Img can be used in combination.
22 | - **Textual Inversion**: That requires training / tuning a model; this just modifies
23 | embeddings, and is only good at some style transfer, not reconstructing subjects
24 | as Textual Inversion has been demonstrated to do. I would like to integrate
25 | Textual Inversion into this demo at some point.
26 |
27 | ### To Run the demo
28 |
29 | 0. Ensure you have at least **python 3.10** installed and activated in your env.
30 | 1. `pip install -r requirements.txt`
31 | 2. `python ui.py --dl`
32 | 3. The `--dl` param can be omitted after the first run; it's used to connect to
33 | huggingface to download the models. You'll need to set up your account and
34 | all that; see https://huggingface.co and Stable Diffusion there for more
35 | details on how to get access.
36 |
37 | ### Image guidance method parameters
38 |
39 | - **Clustered** - This increases or decreases the association of the
40 | CLIP embeddings with their strongest related features and the features
41 | around them, with a linear reduction in weight.
42 | - **Linear** - This increases or decreases the association of the trailing
43 | embedding features. In Stable Diffusion this means more stylistic features,
44 | whereas the front-most features tend to represent the more major concepts
45 | of the image.
46 | - **Threshold** - This increases or decreases the similarity
47 | of associated features above or below a relative threshold defined by
48 | similarity.. the average similarity ( across all features ) by default. A rough sketch of how these weights can be shaped is included just before the Process section below.
49 |
50 | ### "Help, my image looks like garbage"
51 |
52 | Could be that CLIP doesn't have a good interpretation of it or that it doesn't
53 | align well with the prompt at all.
54 | You can set **Threshold** and **Clustered** guidance to 0.0-0.1 and try to
55 | increment from there, but if there's no alignment then it's probably not a
56 | good fit at all.. **Linear** style guidance should be fine to use in these
57 | cases, though you may still get mixed results; they won't be as bad as with the
58 | other options.
59 |
60 | In general I've found these methods work well to transfer styles from portraits
61 | ( photos, or artificial ) and more abstract / background scenes.
62 | CLIP alone does not have a great understanding of the physical world, so you
63 | won't be able to transfer actions or scene composition.
64 |
65 | !['UI Screenshot'](experiments/ui.png)
66 |
67 | ## Abstract
68 |
69 | Stable Diffusion is great, text prompts are alright... While powerful, they
70 | suck at expressing nuance and complexity; even in our regular speech we often
71 | underexpress meaning due to the encoding cost ( time, energy, etc. ) in our
72 | words, and instead we often rely on environmental or situational context to give
73 | meaning.
74 | CLIP text embeddings suffer from this same problem, being a sort of average
75 | mapping to image features. Through their use, Stable Diffusion has constraints on its
76 | expressive ability; concepts are effectively "lost in translation", as it's
77 | constrained by the limited meaning derived from text.
78 | In this repo I demonstrate several methods that can be used to augment text
79 | prompts to better guide / tune image generation; each with their own use case.
80 | This differs from Img2Img techniques I've currently seen, which transform
81 | ( generate from ) an existing image. I also show how this traditional
82 | Img2Img technique can be combined with these image guidance methods.
83 |
84 | ## Disclaimer
85 |
86 | ### Language
87 |
88 | I use terms like *embedding*, *features*, *latents* somewhat loosely and
89 | interchangeably, so take that as you will. If I mention *token* it could also
90 | mean the above, just of a text *token* from CLIP. If I mention *mappings* it
91 | probably relates somewhat to the concepts of *attention*, as *attention* is a
92 | core component of these transformer models, and this entire work is based on
93 | mapping to an *attention* pattern or *rhythm* of text *embeddings* that Stable
94 | Diffusion will recognize, so that image *features* can be leveraged in image
95 | generation guidance, not just as source material.
96 |
97 | ### About Me
98 |
99 | I am by no means an expert on this; I've only been doing ML for a couple of years
100 | and have a lot to learn, though I do have a lot of translatable skills after
101 | 15+ years of writing code and other things...
102 | I may make assumptions in this work that are incorrect, and I'm always looking
103 | to improve my knowledge if you'll humor me.
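Before moving on to the process, here is a rough sketch of how the three guidance parameter styles described earlier ( Clustered, Linear, Threshold ) can shape the per-token blend weights for a 77-token CLIP prompt. It is an illustration only: the function names are made up, the clustered peak / valley traversal is omitted for brevity, and the real logic lives in `guidance.py`.

```python
import torch

TOKENS = 77  # CLIP text encoder sequence length


def linear_weights(start: float, end: float) -> torch.Tensor:
    '''Front-to-back ramp: trailing "style" tokens get the most guidance.'''
    return torch.linspace(start, end, steps=TOKENS)


def threshold_weights(similarity: torch.Tensor, floor: float,
                      mult: float) -> torch.Tensor:
    '''Flat guidance on tokens whose mapped image feature aligns above floor.'''
    return torch.where(similarity >= floor, mult * torch.ones(TOKENS),
                       torch.zeros(TOKENS))


def combine(a: torch.Tensor, b: torch.Tensor, max_guidance: float,
            header_max: float) -> torch.Tensor:
    '''Blend weight curves and keep the header token ( index 0 ) nearly intact.'''
    w = torch.maximum(a, b).clamp(max=max_guidance)
    w[0] = min(float(w[0]), header_max)
    return w
```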
104 |
105 | ## Process
106 |
107 | ### The Problem
108 |
109 | Though Stable Diffusion was trained with images and descriptions of certain
110 | characters and concepts, it is not always capable of representing them through
111 | text.
112 |
113 | For example, just try to generate an image of Toad ( The Mario Character );
114 | the closest I've gotten is with `toad nintendo character 3d cute anthropomorphic
115 | mushroom wearing a blue vest, brown shoes and white pants`, and while this gets
116 | things that kind of look like him, I have yet to generate one.
117 | Toad should exist somewhere within the Stable Diffusion model, but I can't guide
118 | it to create him.
119 |
120 | Why... Well, in this case `Mario` is such a strong signal that putting his name in any
121 | prompt will hijack it; even `Nintendo` itself is such a strong term compared to
122 | `Toad` himself... and `Toad` alone is overloaded, because it's used to represent
123 | an amphibian...
124 |
125 | Beyond hidden characters there are plenty of scenes or concepts that Stable
126 | Diffusion can't generate well because of prompts, and sure, if you cycle through
127 | hundreds or thousands of seeds you might eventually land on the right one,
128 | but what if there was a way, beyond prompts alone, to guide the diffusion
129 | process...
130 |
131 | *CLIP ( the model used by Stable Diffusion to encode text ) was designed to
132 | encode text and images into a similar embedding space. This allows text and
133 | images to be linked: we can say that if these embeddings look x% similar to these
134 | other embeddings, they're likely to be the same thing.*
135 |
136 | So why not use CLIP image embeddings to guide image generation as well?
137 | - Turns out that the embeddings for images are in a different structure than
138 | the text ones.
139 | - The structure of the text embeddings provides meaning to Stable Diffusion
140 | through what I'm assuming to be the attention mechanism of transformers.
141 | - Giving unstructured or sparse embeddings will generate noise and textures.
142 |
143 | The easiest way I could give the image embeddings structure was to map them to
144 | text. I had done this before when using CLIP directly, so I thought I'd try it
145 | here and best fit the image embeddings to the text embeddings encoded from the
146 | text prompt.
147 |
148 | At first I saw noise and got a little discouraged, but then thought maybe I
149 | should amplify based on similarity, because I was kind of in the same boat as
150 | before: not all image features mapped well to the text embeddings; there were
151 | usually just a handful out of all **77**...
152 |
153 | ### Linear integration
154 |
155 | The first method I tried was a sort of linear exchange between image embeddings
156 | and text embeddings ( after they were mapped with best fit ). I knew Stable
157 | Diffusion gives the most attention to the front of the prompt, and was able
158 | to validate that by seeing my image get scrambled when I amplified the best
159 | fit image embeddings at the front vs the back.
160 |
161 | The back seems to be more of a "Style" end, where all the embeddings contribute
162 | to more image aesthetic than anything.. to the point where they could probably
163 | be zeroed and you'd still get nice images.
164 |
165 | So I settled on just doing a reverse linspace from 0-1 from the back of the
166 | embedding tensor all the way to the front, and this actually had some really
167 | nice results.. it would translate photo aesthetics pretty well, and even some
168 | information about background. So this is what I kept as the "Linear" parameter
169 | you'll see in the code.
170 |
171 | ### Clustered integration
172 |
173 | After my linear experiments I thought maybe I should just try mapping to the
174 | features that there was a strong connection to, and just modifying the features
175 | around them a bit. This started as mutating from either the peaks or the valleys
176 | of the mapped features ( mapped by similarity ).
177 |
178 | I found that mutating the valleys messed things up; as the features didn't have
179 | a strong connection, only subtle changes were acceptable, so I decided to remove
180 | the valley portion and just focus on maximizing the translation at the peaks.
181 |
182 | Continued experiments have found that this method is not great, but it is kind of a
183 | neat way of doing some residual tweening between prompt and image, like linear.
184 |
185 | ### Threshold integration
186 |
187 | The idea of this is that by setting a global threshold for mapped features you
188 | could tune just the embeddings that passed the threshold. This is somewhat how
189 | the clustered integration works, but not quite; this would be a little more
190 | direct and I think would give really good results, especially in a multi-image
191 | guided case.
192 |
193 | ### Prompt Mapping
194 |
195 | The idea here was to selectively map image concepts based on another prompt.
196 | This is done by relating a "Mapping Concepts" prompt to the image, and then
197 | relating that prompt to the base prompt with a high alignment threshold.
198 | This method can work quite well, and it gives much more control vs the direct
199 | threshold mapping method.
200 |
201 | ### The header embedding and direct image guidance
202 |
203 | I've learned that some of my earlier experiments mapping to the front were most
204 | likely ruined by what I now refer to as the dreaded header embedding...
205 | This thing would ideally not exist, and should probably just be hardcoded, as
206 | so far it doesn't really seem to change...
207 | It is the first embedding, and it seems to be consistent among a few prompts I
208 | tested; altering it causes a lot of noise and really messes with guidance.
209 | My guess ( not being well versed in transformers ) is that it is used as an anchor
210 | for the attention mechanism, allowing the guidance of the pattern or rhythm in
211 | the space of numbers after it...
212 | If I am right about it.. it would be nice to have it removed from the model
213 | or hardcoded in forward functions.
214 |
215 | So anyway, by skipping this token and not modifying it, modification of the
216 | earlier embeddings from index 1 onwards actually works well, and I have included
217 | a demonstration of pure image prompt guidance, and shifted all embedding
218 | modification to ignore index 0..
219 |
220 | Direct image guidance currently functions off the "pooling" embedding as the
221 | primary embedding, which is probably not ideal; the higher index "hidden"
222 | embeddings will probably lead to more image features when placed in series,
223 | as the pooling embedding might have too much compressed meaning, as it seems
224 | to generate very random things..
225 |
226 | This may also just be more of an issue with the pattern or rhythm of the CLIP
227 | embeddings from an image being nowhere near text.. and some kind of sequencer
228 | is needed, like maybe BLIP..
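To make the header-skipping behaviour described above concrete, here is a simplified stand-in for the final blend step in `guidance.py` ( `Tweener.tween` ). Treat it as a sketch: the real code also folds in the mapped similarity and the max guidance cap.

```python
import torch


def tween_embeddings(text_emb: torch.Tensor, image_emb: torch.Tensor,
                     mapping: list[int],
                     weights: torch.Tensor) -> torch.Tensor:
    '''Move each text token toward its best-fit image embedding by its weight.

    text_emb: (1, 77, dim) prompt embeddings, image_emb: (1, n, dim) image
    embeddings, mapping: best-fit image index per text token, weights: (77,).
    '''
    out = text_emb.clone()
    for t in range(1, text_emb.shape[1]):  # skip the header token at index 0
        w = float(weights[t])
        if w == 0.0:
            continue  # leave this text embedding untouched
        delta = image_emb[0, mapping[t]] - text_emb[0, t]
        out[0, t] = text_emb[0, t] + w * delta
    return out
```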
229 |
230 | ### Actually, we can modify the header embedding a bit...
231 |
232 | So it seems the attention mechanism actually supports a little tuning
233 | of the header, but not too much. So we can adjust this with a new setting!!
234 | I've added a cap, default = 0.15, for modification of the header.
235 | Seems to work well. If you go too far.. like around 0.35 or so, you'll start to
236 | get weird noise patterns.
237 |
238 | ### Text2Text Blending
239 |
240 | Seems to work pretty easily since both are text and in a compatible format for
241 | SD's attention mechanism; tweening can be done with any of the provided methods,
242 | so I just slapped a prompt in a tab so you can toggle between one or the other.
243 | The end goal is to make these blends layerable with a better UI; if you're
244 | familiar with GIMP or Photoshop, just think of the layers and blend settings
245 | on them.. except in this case blend settings would only exist between layers.
246 |
247 | ## Experiments
248 |
249 | Everything here I've sourced from myself or https://lexica.art/ so there
250 | shouldn't be any copyright issues...
251 |
252 | All experiments were run with the following settings and **seed** 1337 unless
253 | otherwise mentioned:
254 | - Diffusion Strength = 0.6
255 | - Steps = 30
256 | - Batches = 4
257 | - Guidance Scale = 8
258 | - Init Height & Width = 512
259 | - Threshold Match Guidance = 0.25
260 | - Threshold Floor = 0.75
261 | - Clustered Match Guidance = 0.25
262 | - Linear Style Guidance Start = 0.0
263 | - Linear Style Guidance End = 0.5
264 | - Max Image Guidance = 0.35
265 | - Max Image Header Mult = 0.0
266 | - Optimal Fit Mapping with reused Latents
267 |
268 | ![Default Settings](experiments/settings.png)
269 |
270 | ### "Deer colorful, fantasy, intricate, highly detailed, digital painting, hq, trending on artstation, illustration, lovecraftian dark ominous eldritch"
271 |
272 | Inspired by: https://lexica.art/?prompt=bf174f36-2be8-475b-8719-76fbf08f6149
273 |
274 | #### Base Generation
275 |
276 | ![Generated Base Images](experiments/deer_base.png)
277 |
278 | #### Img2Img
279 |
280 | This time we'll modify a generated image using the same prompt.
281 |
282 | ![Selected Image for Img2Img](experiments/deer_img2img_base.png)
283 |
284 | After Img2Img ( Without Image Guidance )
285 |
286 | ![Deer again after Img2Img](experiments/deer_img2img_defaults.png)
287 |
288 | #### Modifier
289 |
290 | https://lexica.art/prompt/225c52aa-fd4b-4d17-822e-3721b7eb9c05
291 | ![Guidance Image to apply to Prompt](experiments/deer_mod.webp)
292 |
293 | #### Applied with Defaults
294 |
295 | Except you actually need this due to its introduction:
296 | - Threshold Floor = 0.2
297 |
298 | ![Generated with image defaults](experiments/deer_modded_defaults.png)
299 |
300 | The differences here are extremely subtle; see if you can spot them. This
301 | indicates high alignment with the prompt. Let's amp up the settings a bit.
302 |
303 | #### Tuned Settings
304 |
305 | - Threshold Mult = 1.0
306 | - Threshold Floor = 0.2
307 | - Clustered = 0.0 ( turned off because it gave the deer a warped tree neck )
308 | - Linear End = 1.0
309 | - Max Image Guidance = 1.0
310 |
311 | ![Generated with tuned settings](experiments/deer_tuned.png)
312 |
313 | Much spooky!
314 |
315 | Threshold generally works better than clustered, but only if there's good
316 | alignment between prompt and image... I need to add a setting to adjust the
317 | actual threshold and not just the intensity. Right now the threshold is
318 | computed as the average of mapped latent similarities; instead it should be
319 | a specific value in between, probably around 0.5 ..
320 |
321 | Now watch what happens if we just use linear, setting `threshold = 0`
322 |
323 | ![Generated with linear only](experiments/deer_tuned_linear_only.png)
324 |
325 | All that spookiness went away, and we only see a slight variation from the
326 | base img2img generation, showing how little the trailing latents have an effect
327 | on the guidance of these images, which is not always the case.
328 |
329 | ### Rock Monkey ??? Custom threshold experiment
330 |
331 | Using a newly added slider to adjust the threshold, I took the modifier image from
332 | the deer experiment, applied it to the turtle prompt, and adjusted settings
333 | to:
334 | - Threshold Mult = 1.0
335 | - Threshold Floor = 0.1
336 | - Clustered = 0.0
337 | - Linear End = 1.0
338 | - Max Image Guidance = 1.0
339 |
340 | ![Rock Monkey?](experiments/turtle_to_rock_monkey.png)
341 |
342 | We get whatever this is... showing that sometimes embeddings with a low
343 | similarity of 0.1 - 0.2 ( it was still a turtle at 0.2 ) can still map in and
344 | generate something.. this isn't always the case. As with the original turtle
345 | modifier image, the wavy noise pattern maps in below a 0.75 threshold and takes
346 | us far away from anything resembling an animal.
347 |
348 | To me this points towards a potential mapping of core types of images to
349 | different settings. I could probably judge overall prompt similarity with the
350 | modifier image, as well as test different core classes of images: "portrait",
351 | "photo", "painting", "group" .. and with that infer ideal defaults.
352 |
353 |
354 | ### "a creepy tree creature, 8k dslr photo" - Concept Mapping
355 |
356 | I added a feature which allows directly mapping matched features through a second
357 | prompt. The prompt does its mapping on the image, which can then be mapped to
358 | the first prompt.
359 |
360 | ![Creepy Tree](experiments/creepy_tree.png)
361 |
362 | #### After applying modifier
363 |
364 | ![Creepy Tree Mapped](experiments/creepy_tree_mapped.png)
365 |
366 | By specifying "creepy tree" in the second prompt and using the modifier image
367 | from the deer experiment, I was able to specifically target and map features
368 | of the guidance image to the base prompt.
369 |
370 | For this experiment I used all default settings, but turned the default image
371 | guidance features all to 0:
372 | - Threshold Mult = 0
373 | - Clustered = 0
374 | - Linear End = 0
375 | - Max Image Guidance = 0
376 |
377 | This concept mapping is currently implemented after all these others, as a sort
378 | of override. You do not need to turn those features off, but I did it to
379 | demonstrate.
380 |
381 | ### Text2Text tweening "an urban landscape, city, dslr photo"
382 |
383 | Using the following settings modifications:
384 | - Linear Start = 0.1
385 |
386 | Technically that's the default at the time of writing this, but the other experiments
387 | don't use it.
388 |
389 | #### Base Image
390 |
391 | ![City Photos](experiments/city_photo.png)
392 |
393 | #### With Guide Text: "a painting of the deep woods, forest"
394 | Make sure to clear out any guide image if replicating...
395 |
396 | ![Trees in City Painting](experiments/city_photo_forest_blend.png)
397 |
398 | Pretty neat!
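To replicate something like this outside the UI, the pieces in `encode/clip.py` and `guidance.py` can be wired together roughly as below. This is a hedged sketch: the CLIP checkpoint name is an assumption, and the UI wires the encoder up from the models it downloads rather than like this.

```python
from transformers import CLIPModel, CLIPTokenizer

from encode.clip import CLIPEncoder
from guidance import Tweener

MODEL = 'openai/clip-vit-large-patch14'  # assumed checkpoint, see note above

encoder = CLIPEncoder(CLIPModel.from_pretrained(MODEL),
                      CLIPTokenizer.from_pretrained(MODEL))

base = encoder.prompt('an urban landscape, city, dslr photo')
guide = encoder.prompt('a painting of the deep woods, forest')

# Lean on the linear "style" ramp, as in the experiment above.
blended = Tweener(linear=(0.1, 0.5)).tween(base, guide)
# `blended` would then be fed to the pipeline in place of the plain prompt
# embeddings.
```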
399 |
400 | ## Archived Experiments - pay close attention to replicate
401 |
402 | These are broken by the introduction of the threshold floor or the removal of
403 | modification to the "header" embedding.
404 | Header embedding modification has been re-added, so it is possible to generate
405 | these again; I just need to go through and find the ideal settings as it's now
406 | an adjustable cap.
407 |
408 |
409 | ### Img2Img Zeus portrait to Anime?
410 |
411 | The introduction of the threshold floor has broken the exact replication of this
412 | generation, as the auto-calculated similarity or alignment used for this is
413 | not enabled as a setting; I believe it is 0.1663 ...
414 |
415 | Initial Image:
416 | https://lexica.art/prompt/6aed4d25-a58e-4461-bc7e-883f90acb6ed
417 |
418 | ![Zeus](experiments/zeus.webp)
419 |
420 | #### Img2Img: "anime portrait of a strong, masculine old man with a curly white beard and blue eyes, anime drawing"
421 |
422 | Cranking *Diffusion Strength up to 0.7* here to get more of a transformation,
423 | and *Steps to 40*.
424 |
425 | ![Anime Zeus](experiments/zeus_anime.png)
426 |
427 | #### Now let's try and get a little more style guidance from an image
428 |
429 | https://lexica.art/prompt/82bfb12d-86c7-412f-ab27-d3a08d9017af
430 |
431 | ![Mod Image for Zeus](experiments/zeus_mod.webp)
432 |
433 | #### Zeus with default Image Guidance
434 |
435 | ![Zeus modded with defaults from above](experiments/zeus_modded_defaults.png)
436 |
437 | #### Tuned Settings
438 |
439 | - Threshold = 1.0
440 | - Clustered = -0.5
441 | - Linear End = 1.0
442 | - Max Image Guidance = 1.0
443 |
444 | ![Zeus modded and tuned](experiments/zeus_tuned.png)
445 |
446 | Here we're able to take some of the finer aesthetics and amplify some matching
447 | features.
448 |
449 | ### "a photo of a turtle, hd 8k, dlsr photo"
450 | The introduction of the threshold floor has broken the exact replication of this
451 | generation, as the auto-calculated similarity or alignment used for this is
452 | not enabled as a setting; I believe it is 11.37%.
453 |
454 | #### Base Generation
455 |
456 | ![Generated Base Images](experiments/turtle_base.png)
457 |
458 | #### Modifier
459 |
460 | https://lexica.art/prompt/e1fdbf56-a71c-43eb-ac4b-347bacf7c496
461 | ![Guidance Image to apply to Prompt](experiments/turtle_mod.webp)
462 |
463 | #### Applied with Defaults
464 |
465 | ![Generated with image defaults](experiments/turtle_modded_defaults.png)
466 |
467 | #### Tuned Settings
468 |
469 | - Threshold = -0.25
470 | - Clustered = 0.0
471 | - Linear End = 1.0
472 | - Max Image Guidance = 1.0
473 |
474 | ![Generated with tuned settings](experiments/turtle_tuned.png)
475 |
476 | You can see we've guided the generation from this seed in a new direction
477 | while keeping true to the prompt.
478 |
479 | Explanation ( Guess ):
480 | - A negative Threshold setting moves us away from matched concepts between prompt
481 | and image in linear space.
482 | - A high Linear setting moves us towards minor stylistic details, fidelity,
483 | texture, color...
484 |
485 | ### Composition, text??
486 |
487 | Goal: Text-based composition research... eventually implement something that
488 | can replace the need for drawing tool integrations for those comfortable
489 | with a language ( natural or otherwise ) guided approach.. the ideal end
490 | state of this may be similar to a programming language:
491 |
492 | #### Research
493 | - From my friend Dom:
494 |   - https://primer.ought.org/
495 |   - https://arxiv.org/abs/1909.05858
496 |   - "Transformer-XL and retrieval augmented transformers have mechanisms
497 |     for packing more information into embeddings by retrieving context"
498 |   - "You can also take the pyramid approach that is popular in retrieval
499 |     and summarization where you have multiple separate embeddings and
500 |     have a “fusion” component that is trained to compose them"
501 | - From the StableDiffusion Discord:
502 |   - https://openreview.net/forum?id=PUIqjT4rzq7
503 |
504 | #### Determination
505 |
506 | After reviewing the work titled **"Training-Free Structured Diffusion
507 | Guidance for Compositional Text-to-Image Synthesis"** ...
508 |
509 | I have come to the conclusion that the model architecture of Stable
510 | Diffusion most likely cannot support complex composition through embedding
511 | instructions without sacrificing other components in the model..
512 |
513 | Case being: many artists have unrealistic styles that ignore the rules of
514 | stationary real world physics ( 3d positioning, scale, rotation .. ). The
515 | model's ability to transfer these styles between foreground and background
516 | concepts is one of its strengths and directly conflicts with its ability to
517 | hold an accurate representation.
518 |
519 | While it may be possible to train or modify Stable Diffusion to do both, I
520 | believe its current architecture is likely not "deep" or complex enough
521 | to do both well, as composition in a physical sense requires its own style
522 | transfer in how the physics are perceived.
523 |
524 | I believe this may be done better if a model were able to have a sort of 3d
525 | understanding of an image, and then be able to transfer style between
526 | different interpretations of 3d physics...
527 |
528 | But for now I have Stable Diffusion, and one way to treat it is as a brush:
529 | a brush that can handle single subject and background environmental concepts
530 | pretty well, and I believe I should use that brush on a canvas.
531 | Many people have been using the image editor plugins released for Krita,
532 | Gimp, Photoshop ... but these plugins are designed for "on the fly" use
533 | during a drawing session; the magic of writing some words and generating an
534 | image is somewhat lost, though they are a very powerful addition to the
535 | digital artist's toolset.
536 |
537 | I intend to develop a simple meta language for composing images,
538 | leveraging the prompt / embedding tweening things I've built. The purpose
539 | is to build repeatable, cloneable and easy to modify instruction sets that
540 | generate high quality images leveraging Stable Diffusion.
541 |
542 | It may be that Stable Diffusion improves in composition tasks soon, but this
543 | proposed meta language should be capable of then stitching together those
544 | compositions into even more complex ones, or composing complex 3d scenes,
545 | worlds, etc. using the 3d diffusion models that will arrive.. all of these
546 | will be powered by prompts, embeddings, latents.. and will require
547 | composition in a similar way.
548 |
549 | #### How
550 |
551 | CLIP latent space can be used as a 2d canvas.. the space is 1/8th the output
552 | in pixels, so it is highly compressed, but I believe we can build a larger
553 | canvas and target specific points on it with different prompts in different
554 | stages, and using the same technique as image2image, we can reintroduce
555 | noise in specific areas and regenerate concepts, as a blended sort of
556 | inpainting.
557 |
558 | We can arrange concepts through an ordered language:
559 | ```
560 | "A forest scene"
561 | . "A bear facing left"
562 | > "A hunter pointing a gun to the left"
563 | ```
564 |
565 | ## Future Work
566 |
567 | - Clean up the UI
568 | - Intersecting features of multiple images on prompt embeddings.
569 | - Support for negative prompts
570 | - Integrate Textual Inversion in an easy to use way
571 | - Dissect BLIP and StableDiffusion training, and build an algorithm or model
572 | that can order CLIP image embeddings to be usable by Stable Diffusion.
573 | - Make mapping and parameter decisions by identifying major image classes.
574 | Investigate the potential for a more automated solution where the user can
575 | just set a single guidance multiplier. This would be like applying BLIP or
576 | similar to generate a mapping prompt.
577 | - Some way to find or best fit a noise pattern to an image generation goal...
578 | ( model or algorithm ) as seeds seem to have a huge impact on image
579 | generation.
580 |
581 |
582 | ## Hopes, Dreams and Rambling...
583 |
584 | - Very excited for the LAION and Stability AI work being done to improve CLIP.
585 | A CLIP that understands more images and their features will perform better
586 | in conjunction with Stable Diffusion. Looking forward to Stable Diffusion
587 | supporting this.
588 | - I hope to see Stable Diffusion cross trained to support direct Image
589 | Embeddings ( sparsely distributed embeddings ) as potential inputs for
590 | guidance versus just the current sequential rhythmic attention stuff I'm
591 | seeing right now.
592 | - Interested to see if future versions of Stable Diffusion will incorporate
593 | something like NERFs .. I think if diffusion models had a better
594 | understanding of 3 dimensions, they'd be much better at recreating complex
595 | scenes vs just the typical background foreground single subject stuff
596 | they're good at today.
597 | - Even further out I think intersecting diffusion, NERFs and time scale
598 | normalized video frames ( 5+ ) could be really cool and lead to high quality
599 | video generation, with potentially an understanding of physics; seeing some
600 | work out there from NVIDIA and Google that points towards this.
601 | - I look forward to seeing larger variations of Stable Diffusion with more
602 | support for nuance and hopefully better scene composition. With the NVIDIA
603 | 4090 we'll see 48 GB of consumer GPU RAM soon...
604 |
605 |
606 | ## Thanks
607 |
608 | Everyone behind Stable Diffusion, CLIP and DALL-E for inspiring me and creating the
609 | foundation for this work ( Stability AI, and OpenAI ). Huggingface for code
610 | and infrastructure; Lexica for excellent prompt search and image repo; GitHub,
611 | Python, Gradio, Numpy, Pillow and all other libraries and code indirectly used.
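## Appendix: Composition schema sketch

For a concrete sense of how the ordered composition language from the How section above could be carried by the existing code, here is a hedged sketch using the dataclasses in `composition/schema.py`. The prompts, pixel offsets / sizes, style prompts and blend values are purely illustrative:

```python
from composition.schema import EntitySchema, Schema

schema = Schema(
    background_prompt='A forest scene',
    style_start_prompt='a photo',
    style_end_prompt='an oil painting',
    style_blend=(0.0, 0.5),
    entities=[
        EntitySchema(prompt='A bear facing left',
                     offset=(192, 256),
                     size=(256, 256),
                     blend=0.8),
        EntitySchema(prompt='A hunter pointing a gun to the left',
                     offset=(0, 256),
                     size=(192, 256),
                     blend=0.8),
    ])
print(schema.json())
```

Offsets and sizes are in pixels; `composition/embeds.py` divides them by 8 into latent blocks before guidance is applied.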
612 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | '''Module level exports''' 2 | import encode.clip as encode 3 | import guidance 4 | import pipeline.flex as flex 5 | import utils 6 | 7 | CLIPEncoder = encode.CLIPEncoder 8 | GUIDE_ORDER_TEXT = guidance.GUIDE_ORDER_TEXT 9 | GUIDE_ORDER_ALIGN = guidance.GUIDE_ORDER_ALIGN 10 | Guide = guidance.Guide 11 | preprocess = encode.preprocess 12 | FlexPipeline = flex.FlexPipeline 13 | image_grid = utils.image_grid 14 | Runner = utils.Runner 15 | -------------------------------------------------------------------------------- /__main__.py: -------------------------------------------------------------------------------- 1 | '''Main module for running the dir / module directly''' 2 | import ui 3 | 4 | if __name__ == '__main__': 5 | ui.launch() -------------------------------------------------------------------------------- /composition/embeds.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import List, Sequence, Tuple 3 | 4 | import torch 5 | 6 | from composition.schema import EntitySchema, Schema 7 | from encode.clip import CLIPEncoder 8 | 9 | 10 | @dataclass 11 | class EntityEmbeds(): 12 | embed: torch.Tensor 13 | offset_blocks: Tuple[int, int] 14 | size_blocks: Tuple[int, int] 15 | blend: float 16 | # TODO: Shape / mask 17 | 18 | 19 | @dataclass 20 | class Embeds(): 21 | background_embed: torch.Tensor 22 | style_start_embed: torch.Tensor 23 | style_end_embed: torch.Tensor 24 | style_blend: Tuple[float, float] 25 | entities: List[EntityEmbeds] 26 | 27 | 28 | def px_to_block(px_shape: Sequence[int]) -> Tuple[int, ...]: 29 | return tuple(px // 8 for px in px_shape) 30 | 31 | 32 | def encode_entity(e: EntitySchema, encode: CLIPEncoder) -> EntityEmbeds: 33 | return EntityEmbeds(embed=encode.prompt(e.prompt), 34 | offset_blocks=px_to_block(e.offset), 35 | size_blocks=px_to_block(e.size), 36 | blend=e.blend) 37 | 38 | 39 | def encode_schema(s: Schema, encode: CLIPEncoder) -> Embeds: 40 | return Embeds(background_embed=encode.prompt(s.background_prompt), 41 | style_start_embed=encode.prompt(s.style_start_prompt), 42 | style_end_embed=encode.prompt(s.style_end_prompt), 43 | style_blend=s.style_blend, 44 | entities=[encode_entity(e, encode) for e in s.entities]) 45 | -------------------------------------------------------------------------------- /composition/guide.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from typing import Tuple 3 | import torch 4 | from torch.nn.functional import interpolate 5 | from diffusers.models import UNet2DConditionModel 6 | 7 | from encode.clip import CLIPEncoder 8 | from pipeline.guide import GuideBase 9 | from composition.embeds import encode_schema 10 | from composition.schema import Schema 11 | 12 | MIN_DIM = 64 # 64 * 8 = 512, the number of pixels that SD generates best images 13 | 14 | 15 | def _upscale(t: torch.Tensor) -> torch.Tensor: 16 | low_dim = t.shape[-2] 17 | if t.shape[-1] < low_dim: # Width < Height 18 | low_dim = t.shape[-1] 19 | if low_dim >= MIN_DIM: 20 | return t 21 | # Scale up evenly to MIN_DIM 22 | scale_factor = MIN_DIM / low_dim 23 | shape = (ceil(t.shape[-2] * scale_factor), ceil(t.shape[-1] * scale_factor)) 24 | return _scale(t, shape) 25 | 26 | 27 | def _scale(t: torch.Tensor, shape: Tuple[int, int]): 28 | 
return interpolate(t, size=shape, mode='bicubic', 29 | antialias=True) # type: ignore - bad annotation in torch 30 | 31 | 32 | class CompositeGuide(GuideBase): 33 | def __init__(self, 34 | encoder: CLIPEncoder, 35 | unet: UNet2DConditionModel, 36 | guidance: float, 37 | schema: Schema, 38 | steps: int, 39 | batch_size: int = 1): 40 | GuideBase.__init__(self, encoder, unet, guidance, steps) 41 | self.schema = schema 42 | self.embeds = encode_schema(schema, encoder) 43 | self.batch_size = batch_size 44 | self.classifier_free_guidance = self.guidance > 1.0 45 | if self.classifier_free_guidance: 46 | # Combine uncond then clip embeds into one stack 47 | self.embed_tensor = torch.cat( 48 | ([self.uncond_embeds] * self.batch_size) + [ 49 | self.embeds.background_embed, 50 | *[e.embed for e in self.embeds.entities] 51 | ]) 52 | else: 53 | self.embed_tensor = torch.cat([ 54 | self.embeds.background_embed, 55 | *[e.embed for e in self.embeds.entities] 56 | ]) 57 | 58 | def _guide_latents(self, latents: torch.Tensor, embeds: torch.Tensor, 59 | step: int) -> torch.FloatTensor: 60 | in_latents = torch.cat([latents] * self.embed_tensor.shape[0]) 61 | # run unet on all uncond and clip embeds 62 | noise_pred = self.unet(in_latents, 63 | step, 64 | encoder_hidden_states=self.embed_tensor).sample 65 | # Combine all in order of declaration 66 | if self.classifier_free_guidance: 67 | noise_stack = noise_pred[self.batch_size:] # First batch is uncod 68 | else: 69 | noise_stack = noise_pred 70 | # Combine 71 | batch_noise = noise_stack[:self.batch_size] 72 | entity_noise = noise_stack[self.batch_size:] 73 | for i in range(self.batch_size): 74 | bi_noise_pred = batch_noise[i:i + 1, :, :, :] 75 | # Blend entity noise onto each item in batch noise 76 | for ei, e in enumerate(self.embeds.entities): 77 | # latents is shape (batch, channel, height, width) 78 | ow, oh = e.offset_blocks # Offset 79 | sw, sh = e.size_blocks # Size 80 | bw, bh = (ow + sw, oh + sh) # Box 81 | # Select noise shape 82 | ens = entity_noise[ei:ei + 1, :, oh:bh, ow:bw] 83 | # Blend 84 | bgs = bi_noise_pred[:, :, oh:bh, ow:bw] 85 | bi_noise_pred[:, :, oh:bh, ow:bw] = bgs + e.blend * (ens - bgs) 86 | # Set back in batch noise 87 | batch_noise[i:i + 1, :, :, :] = bi_noise_pred 88 | 89 | if self.classifier_free_guidance: 90 | # Combine the unconditioned guidance with the clip guidance for bg 91 | noise_pred_uncond = noise_pred[:self.batch_size] 92 | batch_noise = noise_pred_uncond + self.guidance * ( 93 | batch_noise - noise_pred_uncond) 94 | 95 | return batch_noise 96 | 97 | def noise_pred(self, latents: torch.Tensor, step: int) -> torch.FloatTensor: 98 | # noise_pred.shape == latents.shape after chunking 99 | # before noise_pred.shape == latent_model_input.shape 100 | # 1, 4, 64, 64 after ; 2, 4, 64, 64 before ; if using 512x512 101 | # TODO: Composition 102 | # for compositions we blend in the clip guidance here for each 103 | # individual entity at it's position, on top of the background 104 | # embedding 105 | # This means we have to also trim the latent_model_input based 106 | # on entity dimensions and run it through the unet to get a 107 | # noise pred per entity: 108 | # - uncond_pred for whole canvas 109 | # - background_pred for whole canvas 110 | # - entity_pred for each entity in entity block space 111 | # run unet on all 112 | bg_embeds = self.embeds.background_embed 113 | # blend styles 114 | progress = self.steps / step 115 | style_base = self.embeds.style_blend[0] 116 | style_d = self.embeds.style_blend[1] - style_base 117 | 
style = self.embeds.style_start_embed + ( 118 | self.embeds.style_end_embed 119 | - self.embeds.style_start_embed) * (style_base + 120 | (progress * style_d)) 121 | # TODO: Handle style blending 122 | bg_noise_pred = self._guide_latents(latents, bg_embeds, step) 123 | 124 | # # Run and map all entities 125 | # for e in self.embeds.entities: 126 | # # latents is shape (batch, channel, height, width) 127 | # ow, oh = e.offset_blocks # Offset 128 | # sw, sh = e.size_blocks # Size 129 | # bw, bh = (ow + sw, oh + sh) # Box 130 | # l = latents[:, :, oh:bh, ow:bw] 131 | # # Upscale before guiding latents as SD works at 512x512 132 | # ul = _upscale(l) 133 | # enp = self._guide_latents(ul, e.embed, step) 134 | # # Downscale back and blend 135 | # denp = _scale(enp, (sh, sw)) 136 | # bgs = bg_noise_pred[:, :, oh:bh, ow:bw] 137 | # bg_noise_pred[:, :, oh:bh, ow:bw] = bgs + e.blend * (denp - bgs) 138 | 139 | return bg_noise_pred -------------------------------------------------------------------------------- /composition/schema.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import json 3 | from typing import List, Tuple 4 | 5 | 6 | @dataclass 7 | class EntitySchema(): 8 | prompt: str 9 | offset: Tuple[int, int] 10 | size: Tuple[int, int] 11 | blend: float = 0.8 12 | # TODO: Shape / mask 13 | 14 | 15 | @dataclass 16 | class Schema(): 17 | background_prompt: str 18 | style_start_prompt: str 19 | style_end_prompt: str 20 | style_blend: Tuple[float, float] 21 | entities: List[EntitySchema] 22 | 23 | def json(self) -> str: 24 | s = self.__dict__.copy() 25 | s['entities'] = [e.__dict__ for e in self.entities] 26 | return json.dumps(s) -------------------------------------------------------------------------------- /encode/clip.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | import numpy as np 4 | from PIL.Image import Image, LANCZOS 5 | import torch 6 | from torchvision.transforms.functional import (center_crop, normalize, resize, 7 | InterpolationMode) 8 | from transformers.models.clip.modeling_clip import CLIPModel 9 | from transformers.models.clip.tokenization_clip import CLIPTokenizer 10 | 11 | CLIP_IMAGE_SIZE = 224 12 | MAX_SINGLE_DIM = 512 # for stable diffusion image 13 | 14 | 15 | def preprocess(image: Image | Any) -> torch.Tensor: 16 | '''Preprocess image for encoding 17 | 18 | Args: 19 | image (Image | Any): PIL Image 20 | 21 | Returns: 22 | torch.Tensor: Tensor of image data for encoding 23 | ''' 24 | w, h = image.size 25 | if h > w: 26 | w = (int(w / (h / MAX_SINGLE_DIM)) // 64) * 64 27 | h = MAX_SINGLE_DIM 28 | elif w > h: 29 | h = (int(h / (w / MAX_SINGLE_DIM)) // 64) * 64 30 | w = MAX_SINGLE_DIM 31 | else: 32 | h = MAX_SINGLE_DIM 33 | w = MAX_SINGLE_DIM 34 | image = image.resize((w, h), resample=LANCZOS) 35 | image = image.convert('RGB') 36 | image = np.array(image).astype(np.float32) / 255.0 37 | image = image[None].transpose(0, 3, 1, 2) 38 | image = torch.from_numpy(image) 39 | return 2.0 * image - 1.0 40 | 41 | 42 | class CLIPEncoder(): 43 | def __init__(self, clip: CLIPModel, token: CLIPTokenizer) -> None: 44 | self.clip = clip 45 | self.token = token 46 | 47 | def prompt(self, prompt: str | List[str]) -> torch.Tensor: 48 | '''Encode text embeddings for provided text 49 | 50 | Args: 51 | text (str | List[str]): Text or array of texts to encode. 52 | 53 | Returns: 54 | torch.Tensor: Encoded text embeddings. 
55 | ''' 56 | # get prompt text embeddings 57 | text_input = self.token( 58 | prompt, 59 | padding='max_length', 60 | max_length=self.token.model_max_length, 61 | truncation=True, 62 | return_tensors='pt', 63 | ) 64 | return self.clip.text_model(text_input.input_ids.to( 65 | self.clip.device))[0] 66 | 67 | def image(self, image: Image) -> torch.Tensor: 68 | '''Encode image embeddings for provided image 69 | 70 | Args: 71 | image (Image): The PIL.Image to encode. 72 | 73 | Returns: 74 | torch.Tensor: Encoded image embeddings. 75 | ''' 76 | guide_tensor = preprocess(image) 77 | crop_size = min(guide_tensor.shape[-2:]) 78 | guide_tensor = center_crop(guide_tensor, [crop_size, crop_size]) 79 | guide_tensor = resize(guide_tensor, [CLIP_IMAGE_SIZE, CLIP_IMAGE_SIZE], 80 | interpolation=InterpolationMode.BICUBIC, 81 | antialias=True) 82 | guide_tensor = normalize( 83 | guide_tensor, [0.48145466, 0.4578275, 0.40821073], 84 | [0.26862954, 0.26130258, 0.27577711]).to(self.clip.device) 85 | 86 | hidden_states = self.clip.vision_model.embeddings(guide_tensor) 87 | hidden_states = self.clip.vision_model.pre_layrnorm(hidden_states) 88 | 89 | encoder_outputs = self.clip.vision_model.encoder( 90 | inputs_embeds=hidden_states, 91 | output_attentions=False, 92 | output_hidden_states=False, 93 | return_dict=True, 94 | ) 95 | 96 | last_hidden_state = encoder_outputs[0] 97 | pooled_output = last_hidden_state[:, :, :] 98 | pooled_output = self.clip.vision_model.post_layernorm(pooled_output) 99 | 100 | return self.clip.visual_projection(pooled_output) -------------------------------------------------------------------------------- /experiments/city_photo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/city_photo.png -------------------------------------------------------------------------------- /experiments/city_photo_forest_blend.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/city_photo_forest_blend.png -------------------------------------------------------------------------------- /experiments/creepy_tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/creepy_tree.png -------------------------------------------------------------------------------- /experiments/creepy_tree_mapped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/creepy_tree_mapped.png -------------------------------------------------------------------------------- /experiments/deer_base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/deer_base.png -------------------------------------------------------------------------------- /experiments/deer_img2img_base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/deer_img2img_base.png 
-------------------------------------------------------------------------------- /experiments/deer_img2img_defaults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/deer_img2img_defaults.png -------------------------------------------------------------------------------- /experiments/deer_mod.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/deer_mod.webp -------------------------------------------------------------------------------- /experiments/deer_modded_defaults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/deer_modded_defaults.png -------------------------------------------------------------------------------- /experiments/deer_tuned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/deer_tuned.png -------------------------------------------------------------------------------- /experiments/deer_tuned_linear_only.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/deer_tuned_linear_only.png -------------------------------------------------------------------------------- /experiments/settings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/settings.png -------------------------------------------------------------------------------- /experiments/turtle_base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/turtle_base.png -------------------------------------------------------------------------------- /experiments/turtle_mod.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/turtle_mod.webp -------------------------------------------------------------------------------- /experiments/turtle_modded_defaults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/turtle_modded_defaults.png -------------------------------------------------------------------------------- /experiments/turtle_to_rock_monkey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/turtle_to_rock_monkey.png -------------------------------------------------------------------------------- /experiments/turtle_tuned.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/turtle_tuned.png -------------------------------------------------------------------------------- /experiments/ui.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/ui.png -------------------------------------------------------------------------------- /experiments/zeus.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/zeus.webp -------------------------------------------------------------------------------- /experiments/zeus_anime.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/zeus_anime.png -------------------------------------------------------------------------------- /experiments/zeus_mod.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/zeus_mod.webp -------------------------------------------------------------------------------- /experiments/zeus_modded_defaults.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/zeus_modded_defaults.png -------------------------------------------------------------------------------- /experiments/zeus_tuned.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tim-speed/flexdiffuse/12726691975b4ffe01c714cc66f9fe23170deb9a/experiments/zeus_tuned.png -------------------------------------------------------------------------------- /guidance.py: -------------------------------------------------------------------------------- 1 | '''Utilities for building prompt or image guided embeddings and tweening the 2 | space of their numbers.''' 3 | from itertools import pairwise 4 | import math 5 | from typing import List, Set, Tuple 6 | 7 | import numpy as np 8 | from PIL.Image import Image 9 | import torch 10 | from transformers.models.clip.modeling_clip import CLIPModel 11 | from transformers.models.clip.tokenization_clip import CLIPTokenizer 12 | 13 | from encode.clip import CLIPEncoder 14 | 15 | CLIP_IMAGE_SIZE = 224 16 | MAX_SINGLE_DIM = 512 # for stable diffusion image 17 | 18 | GUIDE_ORDER_TEXT = 0 19 | GUIDE_ORDER_ALIGN = 1 20 | GUIDE_ORDER_DIRECT = 2 21 | 22 | 23 | def _map_emb(alt_emb: torch.Tensor, 24 | txt_emb: torch.Tensor, 25 | alt_emb_reuse: bool = True, 26 | guide_order: int = GUIDE_ORDER_ALIGN) -> np.ndarray: 27 | '''Map the provided alternate embeddings with the provided text embeddings,\ 28 | according to the provided params to their highest alignment match. 29 | 30 | Args: 31 | alt_emb (torch.Tensor): The alternate embeddings. 32 | txt_emb (torch.Tensor): The text embeddings. 33 | alt_emb_reuse (bool, optional): Alternate embedding reuse, True allows\ 34 | embeddings to be allocated to multiple text embeddings, otherwise\ 35 | they will be used once, based on `guide_order`. Defaults to True. 
36 | guide_order (int, optional): Guide order, prioritize text order or \ 37 | best aligment? Defaults to GUIDE_ORDER_ALIGN. 38 | 39 | Returns: 40 | np.ndarray: The mappings of shape(MAX_TOKENS-1, 2): (alt_embed_index,\ 41 | Alignment) 42 | ''' 43 | altft = alt_emb / alt_emb.norm(dim=-1, keepdim=True) 44 | txtft = txt_emb / txt_emb.norm(dim=-1, keepdim=True) 45 | # All token matches is 256 * 77: imf_i, tf_i alignment 46 | all_matches: List[Tuple[int, int, float]] = [] 47 | # TODO-OPT: Can probably vectorize this 48 | for i, ialtft in enumerate(altft[0, :]): 49 | ialtft = ialtft.unsqueeze(0) 50 | similarity = (100.0 * (ialtft @ txtft.mT)).softmax(dim=-1) 51 | # Enumerate matches and similarity, we remove the first index here 52 | # so our range goes from 0:MAX_TOKENS -> 1:MAX_TOKENS-1 53 | # to ignore the header token. 54 | all_matches += [ 55 | (i, ii, v.item()) for ii, v in enumerate(similarity[0, 0, 1:]) 56 | ] 57 | if guide_order == GUIDE_ORDER_TEXT: 58 | # sort: asc text feature, desc alignment, asc image feature 59 | all_matches.sort(key=lambda t: (t[1], -t[2], t[0])) 60 | elif guide_order == GUIDE_ORDER_DIRECT: 61 | # Sort only by token ids 62 | # TODO-OPT: this can be optimized by preselecting pairs before running 63 | # similarity.. no extra comparisons.. no sort needed.. 64 | all_matches.sort(key=lambda t: (t[1], t[0])) 65 | mapped_tokens = np.zeros((txt_emb.shape[1], 2)) 66 | for alt_i, txt_i, s in all_matches: 67 | if alt_i == txt_i: 68 | mapped_tokens[txt_i] = (alt_i, s) 69 | return mapped_tokens 70 | else: 71 | # sort: desc alignment, asc text feature, asc image feature 72 | all_matches.sort(key=lambda t: (-t[2], t[1], t[0])) 73 | # Now map the alt token per text token, without reusing tokens 74 | # and skipping the first token, as it is a header not meant to 75 | # be changed. 76 | mapped_tokens = np.zeros((txt_emb.shape[1], 2)) 77 | alt_toks_used: Set[int] = set() 78 | # TODO-OPT: Optimize 79 | for alt_i, txt_i, s in all_matches: 80 | if mapped_tokens[txt_i, 1] > 0 or alt_i in alt_toks_used: 81 | continue 82 | mapped_tokens[txt_i] = (alt_i, s) 83 | if not alt_emb_reuse: 84 | alt_toks_used.add(alt_i) 85 | return mapped_tokens 86 | 87 | 88 | def _traverse_a_to_b(al: List[int], bl: List[int], weights: torch.Tensor, 89 | slope: float) -> torch.Tensor: 90 | '''Utility for applying linear slope adjustment on weights between points\ 91 | al and bl. 92 | 93 | Args: 94 | al (List[int]): The points to traverse from. 95 | bl (List[int]): The points to traverse to. 96 | weights (torch.Tensor): The weight of each point. 97 | slope (float): Slope to apply. 98 | 99 | Returns: 100 | torch.Tensor: The weight tensor modified in place. 
101 | '''
102 | bi = 0
103 |
104 | def traverse_left(a: int, b: int):
105 | d = a - b
106 | gslope = slope / d
107 | for i in range(1, d):
108 | weights[a - i] -= gslope * i
109 |
110 | def traverse_right(a: int, b: int):
111 | d = b - a
112 | gslope = slope / d
113 | for i in range(1, d + 1):
114 | weights[a + i] -= gslope * i
115 |
116 | if bl[0] == 0:
117 | # Apply full slope on left most point as our algo is right focused
118 | weights[0] -= slope
119 | for a in al:
120 | # Left
121 | b = bl[bi]
122 | if b < a:
123 | traverse_left(a, b)
124 | bi += 1
125 | # Right
126 | if bi >= len(bl):
127 | # Peak at end
128 | break
129 | b = bl[bi]
130 | traverse_right(a, b)
131 |
132 | return weights
133 |
134 |
135 | def _clustered_guidance(mapped_tokens: np.ndarray, threshold: float,
136 | guidance: float) -> None | torch.Tensor:
137 | '''Cluster by identifying all peaks that are over avg_similarity and then\
138 | traversing downward into the valleys as style.
139 | Similarity Peaks == Subject, Valleys == Style
140 | Slope will be calculated between peaks and valleys
141 |
142 | Args:
143 | mapped_tokens (np.ndarray): The mapped tokens to traverse
144 | threshold (float): The mapping alignment threshold for potential peaks.
145 | guidance (float): The amount of guidance to apply ( multiplier ).
146 |
147 | Returns:
148 | None | torch.Tensor: Clustered embedding weights
149 | '''
150 | token_len = mapped_tokens.shape[0]
151 | clustered_weights = None
152 | peaks: List[int] = []
153 | for txt_i, (_, s) in enumerate(mapped_tokens[1:-1], 1):
154 | if s < threshold:
155 | continue
156 | if (mapped_tokens[txt_i - 1, 1] <= s >= mapped_tokens[txt_i + 1, 1]):
157 | peaks.append(txt_i)
158 | if peaks:
159 | valleys: List[int] = []
160 | if peaks[0] != 0:
161 | valleys.append(0)
162 | for p1, p2 in pairwise(peaks):
163 | d = p2 - p1
164 | if d > 0:
165 | valleys.append(p1 + math.ceil(d / 2))
166 | if peaks[-1] != token_len - 1:
167 | valleys.append(token_len - 1)
168 | # Peaks to Valleys
169 | clustered_weights = _traverse_a_to_b(peaks, valleys,
170 | torch.ones(
171 | (token_len,)), 1.0) * guidance
172 | return clustered_weights
173 |
174 |
175 | def _blend_weights(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
176 | '''Blend the weights described in two tensors into one.
177 |
178 | Args:
179 | a (torch.Tensor): First tensor to blend
180 | b (torch.Tensor): Second tensor to blend
181 |
182 | Returns:
183 | torch.Tensor: A tensor blended from the two weights
184 | '''
185 | assert a.shape == b.shape, f'Tensor shapes a={a.shape} != b={b.shape}'
186 | if a.max() >= 0:
187 | if b.max() >= 0:
188 | return torch.maximum(a, b)
189 | # Fighting each other
190 | # TODO: might be a better way?
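# Illustrative example: a = tensor([0.6, 0.0]) and b = tensor([-0.4, -0.1])
# pull in opposite directions, so summing them gives tensor([0.2, -0.1]),
# letting the stronger signal win element by element.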
191 | return a + b 192 | # Both negative or zero 193 | return torch.minimum(a, b) 194 | 195 | 196 | class Tweener(): 197 | def __init__(self, 198 | threshold: Tuple[float, float] = (0.5, 0.5), 199 | linear: Tuple[float, float] = (0.0, 0.5), 200 | clustered: float = 0.5, 201 | max_guidance: float = 0.5, 202 | header_max: float = 0.15, 203 | align_mode: int = GUIDE_ORDER_ALIGN, 204 | mapping_reuse: bool = True) -> None: 205 | self.threshold_floor = threshold[0] 206 | self.threshold_mult = threshold[1] 207 | self.linear_start = linear[0] 208 | self.linear_end = linear[1] 209 | self.clustered = clustered 210 | self.max_guidance = max_guidance 211 | self.header_max = header_max 212 | self.align_mode = align_mode 213 | self.mapping_reuse = mapping_reuse 214 | 215 | def tween(self, base_emb: torch.Tensor, 216 | alt_emb: torch.Tensor) -> torch.Tensor: 217 | mapped_tokens = _map_emb(alt_emb, base_emb, self.mapping_reuse, 218 | self.align_mode) 219 | avg_similarity = mapped_tokens[:, 1].mean() 220 | print(f'Tweening with, Avg Similarity: {avg_similarity:.2%}, ' 221 | f'Threshold: {self.threshold_floor:.2%}, ' 222 | f'Threshold Multiplier: {self.threshold_mult:.2%}, ' 223 | f'Clustered: {self.clustered:.2%}, ' 224 | f'Linear: {self.linear_start:.2%}' 225 | f'-{self.linear_end:.2%}, ' 226 | f'Guidance Max: {self.max_guidance:.2%}') 227 | # TODO: Guidance slope param to make either linear or grouped 228 | # slopes quadratic .. AKA nice an smooth instead of sharp 229 | # Init img weights from linear slope, front to back, to amplify 230 | # backend / style features 231 | alt_weights = torch.linspace(self.linear_start, 232 | self.linear_end, 233 | steps=base_emb.shape[1]) 234 | if self.clustered != 0: 235 | clustered_weights = _clustered_guidance(mapped_tokens, 236 | avg_similarity, 237 | self.clustered) 238 | if clustered_weights is not None: 239 | alt_weights = _blend_weights(alt_weights, clustered_weights) 240 | 241 | if self.threshold_mult != 0: 242 | th_weights = torch.ones_like(alt_weights) * self.threshold_mult 243 | for txt_i, (_, s) in enumerate(mapped_tokens): 244 | if s < self.threshold_floor: 245 | th_weights[txt_i] = 0 246 | alt_weights = _blend_weights(alt_weights, th_weights) 247 | 248 | # Cap the header token 249 | if self.header_max < 1.0: 250 | hw = alt_weights[0].item() 251 | if hw >= 0: 252 | alt_weights[0] = min(hw, self.header_max) 253 | else: 254 | alt_weights[0] = max(hw, -self.header_max) 255 | print('Alt Embed Blend Weights:', alt_weights.shape, ':', alt_weights) 256 | # tween text and image embeddings 257 | # TODO-OPT: Vectorize, probably don't need if conditions 258 | clip_embeddings = torch.zeros_like(base_emb) 259 | for txt_i, (img_i, s) in enumerate(mapped_tokens): 260 | sd = 1.0 - s 261 | iw = min(alt_weights[txt_i].item(), self.max_guidance) 262 | if iw == 0: 263 | # Leave as is 264 | clip_embeddings[0, txt_i] = base_emb[0, txt_i] 265 | elif abs(iw) >= sd: 266 | # We cap at taking all the image 267 | clip_embeddings[0, txt_i] = alt_emb[0, int(img_i)] 268 | else: 269 | # Text towards image 270 | d_emb = alt_emb[0, int(img_i)] - base_emb[0, txt_i] 271 | clip_embeddings[0, txt_i] = base_emb[0, txt_i] + (d_emb * iw) 272 | return clip_embeddings 273 | 274 | 275 | class ConceptMapper(): 276 | def __init__(self, guide_embeddings: torch.Tensor, 277 | concept_embeddings: torch.Tensor) -> None: 278 | self.guide_embeddings = guide_embeddings 279 | self.concept_embeddings = concept_embeddings 280 | self.concept_mappings = _map_emb(guide_embeddings, concept_embeddings, 281 | False, 
GUIDE_ORDER_TEXT) 282 | # Print the result 283 | print(f'Image Feature and Concept alignment:') 284 | for txt_i, (img_i, s) in enumerate(self.concept_mappings, 1): 285 | print(f'ConceptTok {txt_i:>02d} ImgTok ' 286 | f'{int(img_i):>02d} {100 * s:.2f}%') 287 | 288 | def map(self, 289 | base_embeddings: torch.Tensor, 290 | output_embeddings: torch.Tensor | None = None) -> torch.Tensor: 291 | if output_embeddings is None: 292 | output_embeddings = base_embeddings.clone() 293 | concept_text = _map_emb(self.concept_embeddings, base_embeddings, True, 294 | GUIDE_ORDER_ALIGN) 295 | # Print the result 296 | print(f'Concept Feature and Token alignment:') 297 | for txt_i, (concept_i, s) in enumerate(concept_text, 1): 298 | concept_i = int(concept_i) 299 | cmi = int(concept_i) - 1 # because mappings start from 1 300 | if cmi < 0: 301 | # skip header token, it'd be weird for it to map though 302 | continue 303 | concept_image_i, concept_image_s = self.concept_mappings[cmi] 304 | concept_image_i = int(concept_image_i) 305 | if s > 0.9: 306 | output_embeddings[0, txt_i] = self.guide_embeddings[ 307 | 0, concept_image_i] 308 | print(f'TxtTok {txt_i:>02d} ConceptTok ' 309 | f'{concept_i:>02d} {s:.2%} ImageTok ' 310 | f'{concept_image_i:>03d} {concept_image_s:.2%}' 311 | + (' MAPPED' if s > 0.9 else '')) 312 | return output_embeddings 313 | 314 | 315 | class Guide(): 316 | def __init__(self, 317 | clip: CLIPModel, 318 | tokenizer: CLIPTokenizer, 319 | device: str = 'cuda') -> None: 320 | '''Init context for generating prompt or image embeddings and tweening\ 321 | the space of their numbers. 322 | 323 | Args: 324 | clip (CLIPModel): The CLIP model to use 325 | tokenizer (CLIPTokenizer): The tokenizer used by the CLIP model. 326 | device (str, optional): Only tested with cuda. Defaults to 'cuda'. 327 | ''' 328 | # TODO: Support regular CLIP not just huggingface 329 | self.clip = clip 330 | self.tokenizer = tokenizer 331 | self.device = device 332 | self.encoder = CLIPEncoder(clip, tokenizer) 333 | # Placeholder embed is used for its header token in direct image 334 | # guidance 335 | self.placeholder_embed = self.encoder.prompt('{}') 336 | 337 | def embeds(self, 338 | prompt: str | List[str] = '', 339 | guide: Image | str | None = None, 340 | mapping_concepts: str = '', 341 | guide_threshold_mult: float = 0.5, 342 | guide_threshold_floor: float = 0.5, 343 | guide_clustered: float = 0.5, 344 | guide_linear: Tuple[float, float] = (0.0, 0.5), 345 | guide_max_guidance: float = 0.5, 346 | guide_header_max: float = 0.15, 347 | guide_mode: int = GUIDE_ORDER_ALIGN, 348 | guide_reuse: bool = True) -> torch.Tensor: 349 | '''Generate embeddings from text or image or tween their space of\ 350 | numbers. 351 | 352 | Args: 353 | prompt (str | List[str], optional): The guidance prompt. Defaults\ 354 | to ''. 355 | guide (Image | str | None, optional): The guidance image or text to\ 356 | blend with the prompt. Defaults to None. 357 | mapping_concepts (str, optional): Concepts to fully map from image.\ 358 | Defaults to ''. 359 | guide_threshold_mult (float, optional): Multiplier for\ 360 | alignment threshold tweening. Defaults to 0.5. 361 | guide_threshold_floor (float, optional): Floor to accept\ 362 | embeddings for threshold tweening based on their alignment.\ 363 | Defaults to 0.5. 364 | guide_clustered (float, optional): A clustered match guidance\ 365 | approach, not as good as threshold but can make some minor\ 366 | adjustments. Defaults to 0.5. 
367 | guide_linear (Tuple[float, float], optional): Linear style\
368 | blending; the first value is mapped to the start of the prompt,\
369 | the second to the end. Can be used to push the prompt\
370 | towards the image at the front or back, or away from it.\
371 | Defaults to (0.0, 0.5).
372 | guide_max_guidance (float, optional): Cap on the overall\
373 | tweening, regardless of multiplier, does not affect mapping\
374 | concepts. Defaults to 0.5.
375 | guide_header_max (float, optional): Caps the manipulation of\
376 | the leading header token. Defaults to 0.15.
377 | guide_mode (int, optional): Image guidance mode, to prioritize\
378 | based on alignment first or text embedding order; the effect is\
379 | most noticeable when playing with the `reuse` parameter. Defaults\
380 | to GUIDE_ORDER_ALIGN.
381 | guide_reuse (bool, optional): Allow re-mapping already mapped\
382 | image concepts, True will result in a best fit between image\
383 | and text, but less "uniqueness". Defaults to True.
384 |
385 | Raises:
386 | ValueError: If you don't supply a proper prompt or guide image.
387 |
388 | Returns:
389 | torch.Tensor: CLIP embeddings for use in StableDiffusion.
390 | '''
391 |
392 | if isinstance(prompt, str):
393 | prompt = prompt.strip()
394 | elif isinstance(prompt, list):
395 | prompt = [ss for ss in (s.strip() for s in prompt) if ss]
396 | else:
397 | raise ValueError(f'`prompt` has to be of type `str` '
398 | f'or `list` but is {type(prompt)}')
399 |
400 | if not prompt and guide is None:
401 | raise ValueError('No prompt or guide image provided.')
402 |
403 | # Get embeddings and setup tweening and mapping
404 | text_embeddings: torch.Tensor | None = None
405 | guide_embeddings: torch.Tensor | None = None
406 | concept_mapper: ConceptMapper | None = None
407 | if prompt:
408 | # get prompt text embeddings
409 | text_embeddings = self.encoder.prompt(prompt)
410 | if guide is not None:
411 | if isinstance(guide, str):
412 | guide = guide.strip()
413 | if guide:
414 | guide_embeddings = self.encoder.prompt(guide)
415 | else:
416 | guide_embeddings = self.encoder.image(guide)
417 | if mapping_concepts:
418 | concept_mapper = ConceptMapper(
419 | guide_embeddings, self.encoder.prompt(mapping_concepts))
420 | tweener = Tweener((guide_threshold_floor, guide_threshold_mult),
421 | guide_linear, guide_clustered, guide_max_guidance,
422 | guide_header_max, guide_mode, guide_reuse)
423 |
424 | def _tween(img_emb: torch.Tensor,
425 | txt_emb: torch.Tensor) -> torch.Tensor:
426 | clip_embeddings = tweener.tween(txt_emb, img_emb)
427 | if concept_mapper is not None:
428 | # TODO: Add more customization for this feature, combine with
429 | # max guidance etc..
430 | clip_embeddings = concept_mapper.map(txt_emb, clip_embeddings)
431 | print('Tweened text and image embeddings:', img_emb.shape,
432 | ' text shape:', txt_emb.shape, ' embed shape:',
433 | clip_embeddings.shape)
434 | return clip_embeddings
435 |
436 | # Perform possible tweening based on available embeddings
437 | if text_embeddings is not None:
438 | if guide_embeddings is not None:
439 | if text_embeddings.shape[0] > 1:
440 | # Batch
441 | clip_embeddings = text_embeddings.clone()
442 | # TODO-OPT: Vectorize??
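# NOTE: each prompt in the batch is tweened against the same guide
# embeddings; only the text embeddings vary per batch row.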
443 | for i, txt_emb in enumerate(text_embeddings): 444 | clip_embeddings[i] = _tween(guide_embeddings, txt_emb) 445 | else: 446 | # Solo 447 | clip_embeddings = _tween(guide_embeddings, text_embeddings) 448 | else: 449 | clip_embeddings: torch.Tensor = text_embeddings 450 | else: 451 | assert guide_embeddings is not None 452 | # Select the first self.tokenizer.model_max_length image embedding 453 | # tokens only. 454 | # NOTE: This is not good guidance, StableDiffusion wasn't trained 455 | # for this. We need to map to prompts. 456 | # TODO: Build a model that can resequence image embeddings to 457 | # a similar structure as text, SEE: BLIP ?? No need for text tho. 458 | 459 | if isinstance(guide, str): 460 | print('Warning: using the guide like prompt.. just use prompt.') 461 | clip_embeddings = guide_embeddings 462 | else: 463 | print('Warning: trying to guide purely from image, ' 464 | 'this will generate weird stuff, enjoy :)\n' 465 | 'If you\'re bored try an image of yourself ' 466 | 'and see what the model thinks.') 467 | clip_embeddings = guide_embeddings[:, :self.tokenizer. 468 | model_max_length, :] 469 | d_emb = self.placeholder_embed[:, 0, :] - clip_embeddings[:, 470 | 0, :] 471 | # Move 85% towards the text header 472 | clip_embeddings[:, 0, :] += d_emb * 0.85 473 | 474 | return clip_embeddings 475 | -------------------------------------------------------------------------------- /interface/composer.py: -------------------------------------------------------------------------------- 1 | '''Sandbox UI for testing all features''' 2 | from typing import Any, Callable, Iterable, List 3 | import gradio as gr 4 | 5 | import utils 6 | from composition.schema import EntitySchema, Schema 7 | 8 | DEFAULT_SCHEMA = Schema('A forest with a bear and a deer', 'Photo', 'Painting', 9 | (0.0, 1.0), [ 10 | EntitySchema('A bear in the forest', (0, 256), 11 | (256, 256)), 12 | EntitySchema('A deer in the forest', (256, 256), 13 | (256, 256)) 14 | ]) 15 | 16 | 17 | def unpack(e: object) -> List[Any]: 18 | nout = [] 19 | for v in e.__dict__.values(): 20 | if not isinstance(v, str) and isinstance(v, Iterable): 21 | nout.extend(v) 22 | else: 23 | nout.append(v) 24 | return nout 25 | 26 | 27 | def block(runner: Callable[[], utils.Runner]): 28 | def run(bg_prompt, entities_df, start_style, end_style, 29 | style_blend_linear_start, style_blend_linear_end, init_image, 30 | samples, strength, steps, guidance_scale, height, width, seed, 31 | debug): 32 | if debug and samples * steps > 100: 33 | samples = 100 // steps 34 | print(f'Debug detected, forcing samples to {samples}' 35 | f', to avoid too much output... 
( <= 100 imgs )') 36 | imgs, grid = runner().compose( 37 | bg_prompt, entities_df, start_style, end_style, 38 | (style_blend_linear_start, style_blend_linear_end), init_image, 39 | samples, strength, steps, guidance_scale, (height, width), seed, 40 | debug) 41 | return imgs 42 | 43 | with gr.Group(): 44 | with gr.Box(): 45 | with gr.Row(): 46 | bg_prompt = gr.TextArea( 47 | label='Background / Main Prompt', 48 | value=DEFAULT_SCHEMA.background_prompt, 49 | max_lines=1, 50 | placeholder=( 51 | 'Enter text to define the main scene / background.'), 52 | ) 53 | with gr.Row(): 54 | entities_df = gr.Dataframe( 55 | label='Entities ( Ordered )', 56 | value=[unpack(e) for e in DEFAULT_SCHEMA.entities], 57 | headers=[ 58 | 'Prompt', 'Left', 'Top', 'Width', 'Height', 'Strength' 59 | ], 60 | datatype=[ 61 | 'str', 'number', 'number', 'number', 'number', 'number' 62 | ], 63 | col_count=(6, 'fixed'), 64 | interactive=True) 65 | with gr.Row(): 66 | start_style = gr.TextArea( 67 | label='Starting Style Prompt', 68 | value=DEFAULT_SCHEMA.style_start_prompt, 69 | max_lines=1, 70 | placeholder=( 71 | 'Enter text to define the style early on in generation.' 72 | ), 73 | ) 74 | 75 | with gr.Row(): 76 | end_style = gr.TextArea( 77 | label='Ending Style Prompt', 78 | value=DEFAULT_SCHEMA.style_end_prompt, 79 | max_lines=1, 80 | placeholder=( 81 | 'Enter text to define the final style in generation.'), 82 | ) 83 | 84 | with gr.Row(): 85 | style_blend_linear_start = gr.Slider( 86 | label='Linear Style Blend Start', 87 | minimum=-1, 88 | maximum=1, 89 | value=DEFAULT_SCHEMA.style_blend[0], 90 | step=0.01) 91 | style_blend_linear_end = gr.Slider( 92 | label='Linear Style Blend End', 93 | minimum=-1, 94 | maximum=1, 95 | value=DEFAULT_SCHEMA.style_blend[1], 96 | step=0.01) 97 | 98 | with gr.Row(): 99 | init_image = gr.Image(label='Initial image', 100 | source='upload', 101 | interactive=True, 102 | type='pil') 103 | with gr.Row(): 104 | samples = gr.Slider(label='Batches ( Images )', 105 | minimum=1, 106 | maximum=16, 107 | value=4, 108 | step=1) 109 | 110 | strength = gr.Slider(label='Diffusion Strength ( For Img2Img )', 111 | minimum=0, 112 | maximum=1, 113 | value=0.6, 114 | step=0.01) 115 | 116 | with gr.Row(): 117 | steps = gr.Slider(label='Steps', 118 | minimum=8, 119 | maximum=100, 120 | value=30, 121 | step=2) 122 | guidance_scale = gr.Slider(label='Guidance Scale ( Overall )', 123 | minimum=0, 124 | maximum=20, 125 | value=8, 126 | step=0.5) 127 | 128 | with gr.Row(): 129 | height = gr.Slider(minimum=64, 130 | maximum=2048, 131 | step=64, 132 | label="Init Height", 133 | value=512) 134 | width = gr.Slider(minimum=64, 135 | maximum=2048, 136 | step=64, 137 | label="Init Width", 138 | value=512) 139 | 140 | with gr.Row(): 141 | seed = gr.Number(label='Seed', 142 | precision=0, 143 | value=1337, 144 | interactive=True) 145 | debug = gr.Checkbox(label='Export Debug Images', value=False) 146 | generate: gr.Button = gr.Button(value='Generate image', 147 | variant='primary') 148 | 149 | gallery = gr.Gallery(label='Generated images', 150 | show_label=False, 151 | elem_id='gallery').style(grid=2, height='auto') 152 | 153 | bg_prompt.submit(run, 154 | inputs=[ 155 | bg_prompt, entities_df, start_style, end_style, 156 | style_blend_linear_start, style_blend_linear_end, 157 | init_image, samples, strength, steps, 158 | guidance_scale, height, width, seed, debug 159 | ], 160 | outputs=[gallery]) 161 | generate.click(run, 162 | inputs=[ 163 | bg_prompt, entities_df, start_style, end_style, 164 | style_blend_linear_start, 
style_blend_linear_end, 165 | init_image, samples, strength, steps, guidance_scale, 166 | height, width, seed, debug 167 | ], 168 | outputs=[gallery]) 169 | -------------------------------------------------------------------------------- /interface/sandbox.py: -------------------------------------------------------------------------------- 1 | '''Sandbox UI for testing all features''' 2 | from typing import Callable 3 | import gradio as gr 4 | 5 | import utils 6 | 7 | 8 | def block(runner: Callable[[], utils.Runner]): 9 | def run(prompt, init_image, guide_image, guide_text, height, width, 10 | mapping_concepts, guide_image_threshold_mult, 11 | guide_image_threshold_floor, guide_image_clustered, 12 | guide_image_linear_start, guide_image_linear_end, 13 | guide_image_max_guidance, guide_image_header, guide_image_mode, 14 | guide_image_reuse, strength, steps, guidance_scale, samples, seed, 15 | debug): 16 | if debug and samples * steps > 100: 17 | samples = 100 // steps 18 | print(f'Debug detected, forcing samples to {samples}' 19 | f', to avoid too much output... ( <= 100 imgs )') 20 | if not guide_image and guide_text: 21 | guide = guide_text 22 | else: 23 | guide = guide_image 24 | imgs, grid = runner().gen( 25 | prompt, init_image, guide, (height, width), mapping_concepts, 26 | guide_image_threshold_mult, guide_image_threshold_floor, 27 | guide_image_clustered, 28 | (guide_image_linear_start, guide_image_linear_end), 29 | guide_image_max_guidance, guide_image_header, guide_image_mode, 30 | guide_image_reuse, strength, steps, guidance_scale, samples, seed, 31 | debug) 32 | return imgs 33 | 34 | with gr.Group(): 35 | with gr.Box(): 36 | with gr.Row().style(equal_height=True): 37 | prompt = gr.TextArea( 38 | label='Enter your prompt', 39 | show_label=False, 40 | max_lines=1, 41 | placeholder='Enter your prompt', 42 | ).style( 43 | border=(True, False, True, True), 44 | rounded=(True, False, False, True), 45 | ) 46 | generate: gr.Button = gr.Button(value='Generate image', 47 | variant='primary').style( 48 | margin=False, 49 | rounded=(False, True, True, 50 | False), 51 | ) # type: ignore 52 | with gr.Row().style(equal_height=True): 53 | init_image = gr.Image(label='Initial image', 54 | source='upload', 55 | interactive=True, 56 | type='pil') 57 | with gr.Tab('Guide Image'): 58 | guide_image = gr.Image(label='Guidance image', 59 | source='upload', 60 | interactive=True, 61 | type='pil') 62 | with gr.Tab('Guide Text'): 63 | guide_text = gr.TextArea( 64 | label='Text Guide ( only usable if no image is set )', 65 | max_lines=1, 66 | placeholder=('Enter text to blend with the prompt ' 67 | 'using the same techniques used for ' 68 | 'image mapping'), 69 | ) 70 | with gr.Row(equal_height=True): 71 | strength = gr.Slider(label='Diffusion Strength ( For Img2Img )', 72 | minimum=0, 73 | maximum=1, 74 | value=0.6, 75 | step=0.01) 76 | mapping_concepts = gr.TextArea( 77 | label='Image Guidance Mapping Concepts', 78 | max_lines=1, 79 | placeholder= 80 | 'Enter items in image you would like to map directly', 81 | ) 82 | 83 | with gr.Row(equal_height=True): 84 | with gr.Column(scale=2, variant='panel'): 85 | steps = gr.Slider(label='Steps', 86 | minimum=8, 87 | maximum=100, 88 | value=30, 89 | step=2) 90 | with gr.Column(scale=1, variant='panel'): 91 | guide_image_threshold_mult = gr.Slider( 92 | label='Threshold "Match" Guidance Multiplier ( Image )', 93 | minimum=-1, 94 | maximum=1, 95 | value=0.25, 96 | step=0.01) 97 | with gr.Column(scale=1, variant='panel'): 98 | guide_image_threshold_floor = 
gr.Slider( 99 | label='Threshold "Match" Guidance Floor ( Image )', 100 | minimum=0, 101 | maximum=1, 102 | value=0.75, 103 | step=0.01) 104 | 105 | with gr.Row(equal_height=True): 106 | with gr.Column(scale=2, variant='panel'): 107 | samples = gr.Slider(label='Batches ( Images )', 108 | minimum=1, 109 | maximum=16, 110 | value=4, 111 | step=1) 112 | with gr.Column(scale=1, variant='panel'): 113 | guide_image_linear_start = gr.Slider( 114 | label='Linear Guidance Start ( Image )', 115 | minimum=-1, 116 | maximum=1, 117 | value=0.1, 118 | step=0.01) 119 | with gr.Column(scale=1, variant='panel'): 120 | guide_image_linear_end = gr.Slider( 121 | label='Linear Guidance End ( Image )', 122 | minimum=-1, 123 | maximum=1, 124 | value=0.5, 125 | step=0.01) 126 | 127 | with gr.Row(equal_height=True): 128 | guidance_scale = gr.Slider(label='Guidance Scale ( Overall )', 129 | minimum=0, 130 | maximum=20, 131 | value=8, 132 | step=0.5) 133 | guide_image_clustered = gr.Slider( 134 | label='Clustered "Match" Guidance ( Image )', 135 | minimum=-0.5, 136 | maximum=0.5, 137 | value=0.15, 138 | step=0.01) 139 | 140 | with gr.Row(equal_height=True): 141 | with gr.Column(scale=2, variant='panel'): 142 | seed = gr.Number(label='Seed', 143 | precision=0, 144 | value=0, 145 | interactive=True) 146 | with gr.Column(scale=1, variant='panel'): 147 | guide_image_max = gr.Slider(label='Max Image Guidance', 148 | minimum=0, 149 | maximum=1, 150 | value=0.35, 151 | step=0.01, 152 | interactive=True) 153 | with gr.Column(scale=1, variant='panel'): 154 | guide_image_header = gr.Slider(label='Max Image Header Mult', 155 | minimum=0, 156 | maximum=1, 157 | value=0, 158 | step=0.01, 159 | interactive=True) 160 | 161 | with gr.Row(equal_height=True): 162 | height = gr.Slider(minimum=64, 163 | maximum=2048, 164 | step=64, 165 | label="Init Height", 166 | value=512) 167 | width = gr.Slider(minimum=64, 168 | maximum=2048, 169 | step=64, 170 | label="Init Width", 171 | value=512) 172 | guide_image_mode = gr.Radio( 173 | label='Mapping Priority', 174 | choices=['Text Order', 'Best Fit', 'Direct'], 175 | value='Best Fit', 176 | type='index') 177 | with gr.Group(elem_id='cbgroup'): 178 | guide_image_reuse = gr.Checkbox(label='Reuse Latents', 179 | value=True) 180 | debug = gr.Checkbox(label='Export Debug Images', value=False) 181 | 182 | gallery = gr.Gallery(label='Generated images', 183 | show_label=False, 184 | elem_id='gallery').style(grid=2, height='auto') 185 | 186 | prompt.submit(run, 187 | inputs=[ 188 | prompt, init_image, guide_image, guide_text, height, 189 | width, mapping_concepts, guide_image_threshold_mult, 190 | guide_image_threshold_floor, guide_image_clustered, 191 | guide_image_linear_start, guide_image_linear_end, 192 | guide_image_max, guide_image_header, guide_image_mode, 193 | guide_image_reuse, strength, steps, guidance_scale, 194 | samples, seed, debug 195 | ], 196 | outputs=[gallery]) 197 | generate.click(run, 198 | inputs=[ 199 | prompt, init_image, guide_image, guide_text, height, 200 | width, mapping_concepts, guide_image_threshold_mult, 201 | guide_image_threshold_floor, guide_image_clustered, 202 | guide_image_linear_start, guide_image_linear_end, 203 | guide_image_max, guide_image_header, 204 | guide_image_mode, guide_image_reuse, strength, steps, 205 | guidance_scale, samples, seed, debug 206 | ], 207 | outputs=[gallery]) 208 | -------------------------------------------------------------------------------- /pipeline/flex.py: 
-------------------------------------------------------------------------------- 1 | '''Pipeline, a modification of Img2Img from huggingface/diffusers''' 2 | import inspect 3 | import warnings 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import numpy as np 7 | import torch 8 | 9 | import PIL 10 | import PIL.Image 11 | 12 | from transformers.models.clip.modeling_clip import CLIPModel 13 | from transformers.models.clip.tokenization_clip import CLIPTokenizer 14 | 15 | from diffusers.configuration_utils import FrozenDict 16 | from diffusers.models import AutoencoderKL, UNet2DConditionModel 17 | from diffusers.pipeline_utils import DiffusionPipeline 18 | from diffusers.schedulers import (DDIMScheduler, LMSDiscreteScheduler, 19 | PNDMScheduler) 20 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput 21 | 22 | from encode.clip import preprocess 23 | from pipeline.guide import GuideBase 24 | 25 | 26 | class FlexPipeline(DiffusionPipeline): 27 | r''' 28 | Pipeline for text-guided image to image generation using Stable Diffusion. 29 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 30 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 31 | Args: 32 | vae ([`AutoencoderKL`]): 33 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 34 | clip ([`CLIPModel`]): 35 | Frozen CLIPModel. Stable Diffusion traditionally uses the text portion but we'll allow both of: 36 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel), specifically 37 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 38 | tokenizer (`CLIPTokenizer`): 39 | Tokenizer of class 40 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 41 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 42 | scheduler ([`SchedulerMixin`]): 43 | A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of 44 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 45 | ''' 46 | def __init__( 47 | self, 48 | vae: AutoencoderKL, 49 | clip: CLIPModel, 50 | tokenizer: CLIPTokenizer, 51 | unet: UNet2DConditionModel, 52 | scheduler: Union[DDIMScheduler, PNDMScheduler], 53 | ): 54 | super().__init__() 55 | scheduler = scheduler.set_format('pt') 56 | 57 | if hasattr(scheduler.config, 58 | 'steps_offset') and scheduler.config['steps_offset'] != 1: 59 | warnings.warn( 60 | f'The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`' 61 | f' should be set to 1 instead of {scheduler.config["steps_offset"]}. Please make sure ' 62 | 'to update the config accordingly as leaving `steps_offset` might led to incorrect results' 63 | ' in future versions. 
If you have downloaded this checkpoint from the Hugging Face Hub,' 64 | ' it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`' 65 | ' file', 66 | DeprecationWarning, 67 | ) 68 | new_config = dict(scheduler.config) 69 | new_config['steps_offset'] = 1 70 | setattr(scheduler, '_internal_dict', FrozenDict(new_config)) 71 | 72 | self.vae = vae 73 | self.clip = clip 74 | self.tokenizer = tokenizer 75 | self.unet = unet 76 | self.scheduler = scheduler 77 | self.register_modules( 78 | vae=vae, 79 | clip=clip, 80 | tokenizer=tokenizer, 81 | unet=unet, 82 | scheduler=scheduler, 83 | ) 84 | 85 | def enable_attention_slicing(self, 86 | slice_size: Optional[Union[str, 87 | int]] = 'auto'): 88 | r''' 89 | Enable sliced attention computation. 90 | When this option is enabled, the attention module will split the input tensor in slices, to compute attention 91 | in several steps. This is useful to save some memory in exchange for a small speed decrease. 92 | Args: 93 | slice_size (`str` or `int`, *optional*, defaults to `'auto'`): 94 | When `'auto'`, halves the input to the attention heads, so attention will be computed in two steps. If 95 | a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, 96 | `attention_head_dim` must be a multiple of `slice_size`. 97 | ''' 98 | if slice_size == 'auto': 99 | # half the attention head size is usually a good trade-off between 100 | # speed and memory 101 | slice_size = self.unet.config['attention_head_dim'] // 2 102 | self.unet.set_attention_slice(slice_size) 103 | 104 | def disable_attention_slicing(self): 105 | r''' 106 | Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go 107 | back to computing attention in one step. 108 | ''' 109 | # set slice_size = `None` to disable `set_attention_slice` 110 | self.enable_attention_slicing(None) 111 | 112 | def _latents_to_image( 113 | self, 114 | latents: torch.Tensor, 115 | pil: bool = True) -> Union[np.ndarray, List[PIL.Image.Image]]: 116 | # scale and decode the image latents with vae 117 | latents = 1 / 0.18215 * latents 118 | image = self.vae.decode(latents).sample # type: ignore 119 | 120 | image = (image / 2 + 0.5).clamp(0, 1) 121 | image = image.cpu().permute(0, 2, 3, 1).numpy() 122 | if pil: 123 | return self.numpy_to_pil(image) 124 | return image 125 | 126 | @torch.no_grad() 127 | def __call__(self, 128 | guide: GuideBase, 129 | init_image: Optional[Union[torch.Tensor, torch.FloatTensor, 130 | PIL.Image.Image]] = None, 131 | init_size: Tuple[int, int] = (512, 512), 132 | strength: float = 0.6, 133 | eta: float = 0.0, 134 | generator: Optional[torch.Generator] = None, 135 | output_type: str = 'pil', 136 | return_dict: bool = True, 137 | debug: bool = False): 138 | r''' 139 | Function invoked when calling the pipeline for generation. 140 | Args: 141 | init_image (`torch.FloatTensor` or `PIL.Image.Image`): 142 | `Image`, or tensor representing an image batch, that will be used as the starting point for the 143 | process. 144 | init_size (`Tuple[int, int]`): 145 | height and width for the output image initial latents if init_image not provided 146 | strength (`float`, *optional*, defaults to 0.8): 147 | Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1. 148 | `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. 
The 149 | number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added 150 | noise will be maximum and the denoising process will run for the full number of iterations specified in 151 | `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. 152 | eta (`float`, *optional*, defaults to 0.0): 153 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 154 | [`schedulers.DDIMScheduler`], will be ignored for others. 155 | generator (`torch.Generator`, *optional*): 156 | A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation 157 | deterministic. 158 | output_type (`str`, *optional*, defaults to `'pil'`): 159 | The output format of the generate image. Choose between 160 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 161 | return_dict (`bool`, *optional*, defaults to `True`): 162 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a 163 | plain tuple. 164 | Returns: 165 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: 166 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 167 | When returning a tuple, the first element is a list with the generated images, and the second element is 168 | False. 169 | ''' 170 | if strength < 0 or strength > 1: 171 | raise ValueError( 172 | f'The value of strength should in [0.0, 1.0] but is {strength}') 173 | 174 | batch_size = guide.batch_size 175 | 176 | # set timesteps 177 | self.scheduler.set_timesteps(guide.steps) 178 | assert self.scheduler.timesteps is not None 179 | 180 | # Handle VAE latent initialization 181 | if init_image: 182 | if isinstance(init_image, PIL.Image.Image): 183 | init_image = preprocess(init_image) 184 | # TODO: Confirm correct tensor shape expectation 185 | width, height = init_image.shape[-2:] 186 | init_image = init_image.to(self.device) 187 | 188 | # encode the init image into latents and scale the latents 189 | init_latent_dist = self.vae.encode(init_image # type: ignore 190 | ).latent_dist 191 | init_latents = init_latent_dist.sample(generator=generator) 192 | init_latents = 0.18215 * init_latents 193 | # expand init_latents for batch_size 194 | init_latents = torch.cat([init_latents] * batch_size) 195 | 196 | # get the original timestep using init_timestep 197 | offset = self.scheduler.config.get('steps_offset', 0) 198 | init_timestep = int(guide.steps * strength) + offset 199 | init_timestep = min(init_timestep, guide.steps) 200 | if isinstance(self.scheduler, LMSDiscreteScheduler): 201 | timesteps = torch.tensor([guide.steps - init_timestep] 202 | * batch_size, 203 | dtype=torch.long, 204 | device=self.device) 205 | else: 206 | timesteps = self.scheduler.timesteps[-init_timestep] 207 | timesteps = torch.tensor([timesteps] * batch_size, 208 | dtype=torch.long, 209 | device=self.device) 210 | 211 | # add noise to latents using the timesteps 212 | noise = torch.randn(init_latents.shape, 213 | generator=generator, 214 | device=self.device) 215 | init_latents = self.scheduler.add_noise( 216 | init_latents, # type: ignore 217 | noise, # type: ignore 218 | timesteps) # type: ignore 219 | 220 | # Calc start timestep 221 | t_start = max(guide.steps - init_timestep + offset, 0) 222 | else: 223 | height, width = init_size 224 | channels: int = self.unet.in_channels # type: ignore 225 | # Random init 226 | 
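# NOTE: latents live at 1/8 of the requested pixel resolution (the VAE
# downsampling factor), so e.g. a 512x512 request draws noise of shape
# (batch_size, channels, 64, 64); channels comes from unet.in_channels
# (4 for SD v1 checkpoints).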
init_latents = torch.randn( 227 | (batch_size, channels, height // 8, width // 8), 228 | generator=generator, 229 | device=self.device, 230 | ) 231 | 232 | # set timesteps 233 | self.scheduler.set_timesteps(guide.steps) 234 | 235 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas 236 | if isinstance(self.scheduler, LMSDiscreteScheduler): 237 | init_latents: torch.Tensor = ( 238 | init_latents * self.scheduler.sigmas[0]) # type: ignore 239 | 240 | # No init image so we start from 0 241 | t_start = 0 242 | 243 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 244 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 245 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 246 | # and should be between [0, 1] 247 | accepts_eta = 'eta' in set( 248 | inspect.signature(self.scheduler.step).parameters.keys()) 249 | extra_step_kwargs = {} 250 | if accepts_eta: 251 | extra_step_kwargs['eta'] = eta 252 | 253 | latents = init_latents 254 | all_latents = None 255 | if debug: 256 | all_latents = [init_latents] 257 | # TODO: Training-Free Structured Diffusion: 258 | # https://openreview.net/pdf?id=PUIqjT4rzq7 259 | # Method that modifies the cross attention layers of the model to 260 | # improve "compositionality". It does this through an algorithmic 261 | # process... 262 | for i, t in enumerate( 263 | self.progress_bar(self.scheduler.timesteps[t_start:])): 264 | t_index = t 265 | 266 | # expand the latents if we are doing classifier free guidance 267 | latent_model_input = latents 268 | 269 | # if we use LMSDiscreteScheduler, let's make sure latents are multiplied by sigmas 270 | if isinstance(self.scheduler, LMSDiscreteScheduler): 271 | t_index = t_start + i 272 | sigma = self.scheduler.sigmas[t_index] # type: ignore 273 | # the model input needs to be scaled to match the continuous ODE formulation in K-LMS 274 | latent_model_input = latent_model_input / ((sigma**2 + 1)**0.5) 275 | 276 | # predict the noise residual 277 | noise_pred = guide.noise_pred(latent_model_input, t) 278 | 279 | # compute the previous noisy sample x_t -> x_t-1 280 | step = self.scheduler.step( 281 | noise_pred, 282 | t_index, 283 | latents, # type: ignore 284 | **extra_step_kwargs) 285 | latents = step.prev_sample # type: ignore 286 | if all_latents: 287 | all_latents.append(latents) 288 | 289 | if all_latents: 290 | image_batches = [ 291 | self._latents_to_image(l, output_type == 'pil') 292 | for l in all_latents 293 | ] 294 | if isinstance(image_batches[0], list): 295 | batch_images = [] 296 | for ib in image_batches: 297 | batch_images.extend(ib) 298 | else: 299 | batch_images = np.concatenate( 300 | image_batches, # type: ignore 301 | axis=0) 302 | else: 303 | batch_images = self._latents_to_image(latents, output_type == 'pil') 304 | 305 | if not return_dict: 306 | return (batch_images, False) 307 | 308 | return StableDiffusionPipelineOutput( 309 | images=batch_images, 310 | nsfw_content_detected=[False for _ in batch_images]) 311 | -------------------------------------------------------------------------------- /pipeline/guide.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | from diffusers.models import UNet2DConditionModel 4 | 5 | from encode.clip import CLIPEncoder 6 | 7 | 8 | class GuideBase(): 9 | def __init__(self, encoder: CLIPEncoder, unet: UNet2DConditionModel, 10 | guidance: float, steps: int) 
-> None: 11 | '''Initialize a guide for guiding noise into images. 12 | 13 | Args: 14 | encoder (CLIPEncoder): An encoder for encoding text to CLIP latents. 15 | unet (UNet2DConditionModel): Model for predicting the next\ 16 | stage of noise from provided latents and clip_embeddings. 17 | guidance (float): Guidance scale as defined in [Classifier-Free\ 18 | Diffusion Guidance](https://arxiv.org/abs/2207.12598).\ 19 | `guidance_scale` is defined as `w` of equation 2. of [Imagen\ 20 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is\ 21 | enabled by setting `guidance_scale > 1`. Higher guidance scale\ 22 | encourages to generate images that are closely linked to the\ 23 | text `prompt`, usually at the expense of lower image quality. 24 | steps (int): Number of steps to generate the image from noise over.\ 25 | More denoising steps usually lead to a higher quality image at\ 26 | the expense of slower inference. 27 | ''' 28 | self.encoder = encoder 29 | self.unet = unet 30 | self.uncond_embeds = encoder.prompt('') 31 | self.batch_size = 1 32 | self.guidance = guidance 33 | self.steps = steps 34 | 35 | def noise_pred(self, latents: torch.Tensor, step: int) -> torch.FloatTensor: 36 | raise NotImplementedError('noise_pred must be implemented.') 37 | 38 | 39 | class SimpleGuide(GuideBase): 40 | def __init__(self, encoder: CLIPEncoder, unet: UNet2DConditionModel, 41 | guidance: float, steps: int, clip_embeds: torch.Tensor): 42 | GuideBase.__init__(self, encoder, unet, guidance, steps) 43 | self.embeds = clip_embeds 44 | self.batch_size = self.embeds.shape[0] 45 | 46 | def noise_pred(self, latents: torch.Tensor, step: int) -> torch.FloatTensor: 47 | classifier_free_guidance = self.guidance > 1.0 48 | # run unet on all 49 | clip_embeddings = self.embeds 50 | if classifier_free_guidance: 51 | # Combine uncond then clip embeds into one stack 52 | clip_embeddings = torch.cat(([self.uncond_embeds] * self.batch_size) 53 | + [clip_embeddings]) 54 | latents = torch.cat([latents] * 2) 55 | # run unet on all uncond and clip embeds 56 | noise_pred = self.unet(latents, 57 | step, 58 | encoder_hidden_states=clip_embeddings).sample 59 | if classifier_free_guidance: 60 | # Combine the unconditioned guidance with the clip guidance 61 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 62 | noise_pred = noise_pred_uncond + self.guidance * ( 63 | noise_pred_text - noise_pred_uncond) 64 | return noise_pred 65 | 66 | 67 | class PromptGuide(SimpleGuide): 68 | def __init__(self, encoder: CLIPEncoder, unet: UNet2DConditionModel, 69 | guidance: float, steps: int, prompt: str | List[str]): 70 | SimpleGuide.__init__(self, encoder, unet, guidance, steps, 71 | encoder.prompt(prompt)) 72 | self.prompt = prompt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | diffusers==0.3.0 2 | gradio==3.4.0 3 | numpy==1.23.1 4 | Pillow==9.2.0 5 | torch==1.11.0 6 | torchvision==0.12.0 7 | transformers==4.21.1 8 | -------------------------------------------------------------------------------- /ui.py: -------------------------------------------------------------------------------- 1 | '''Simple Test UI''' 2 | import sys 3 | import gradio as gr 4 | 5 | import utils 6 | import interface.composer as composer 7 | import interface.sandbox as sandbox 8 | 9 | runner = None 10 | pargs = [a.strip().lower() for a in sys.argv[1:]] 11 | 12 | 13 | def _has_arg_like(*args: str) -> bool: 14 | return 
bool([pa for pa in pargs for a in args if a in pa]) 15 | 16 | 17 | def get_runner() -> utils.Runner: 18 | global runner 19 | if runner is None: 20 | runner = utils.Runner(not _has_arg_like('dl', 'download')) 21 | return runner 22 | 23 | 24 | def launch(): 25 | css = ''' 26 | textarea { 27 | max-height: 60px; 28 | } 29 | div.gr-block button.gr-button { 30 | max-width: 200px; 31 | } 32 | #gallery>div>.h-full { 33 | min-height: 20rem; 34 | } 35 | div.row, div.row>div.col { 36 | gap: 0; 37 | padding: 0; 38 | } 39 | div.row>div.col>div, div.row>div.col>div>div, div.row>div.col fieldset,\ 40 | #cbgroup, #cbgroup>div { 41 | min-height: 100%; 42 | } 43 | div#cbgroup { 44 | max-width: 25% 45 | } 46 | 47 | ''' 48 | block = gr.Blocks(css=css) 49 | 50 | with block: 51 | with gr.Tab('Sandbox'): 52 | sandbox.block(get_runner) 53 | with gr.Tab('Compose'): 54 | composer.block(get_runner) 55 | 56 | block.launch(server_name=('0.0.0.0' if _has_arg_like('lan') else None), 57 | debug=True) 58 | 59 | 60 | if __name__ == '__main__': 61 | launch() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | '''Utilities for running the pipeline''' 2 | import math 3 | import os 4 | from time import time 5 | from typing import Any, List, Sequence, Tuple 6 | 7 | import torch 8 | from torch.autocast_mode import autocast 9 | from PIL import Image 10 | 11 | from transformers.models.clip.modeling_clip import CLIPModel 12 | from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline 13 | from diffusers.schedulers import DDIMScheduler 14 | 15 | from guidance import Guide 16 | from composition.guide import CompositeGuide 17 | from composition.schema import EntitySchema, Schema 18 | from encode.clip import CLIPEncoder 19 | from pipeline.flex import FlexPipeline 20 | from pipeline.guide import GuideBase, SimpleGuide 21 | 22 | MAX_SEED = 2147483647 23 | 24 | sd_model = 'CompVis/stable-diffusion-v1-4' 25 | clip_model = 'openai/clip-vit-large-patch14' 26 | output_dir = './outputs' 27 | grid_dir = f'{output_dir}/grids' 28 | 29 | os.makedirs(grid_dir, exist_ok=True) 30 | 31 | 32 | def _i100(f: float) -> int: 33 | return int(f * 100) 34 | 35 | 36 | def image_grid(imgs: Sequence[Image.Image]): 37 | '''Builds an image that is a grid arrangement of all provided images''' 38 | num = len(imgs) 39 | # TODO: account for image height and width to compute ideal grid arrangement 40 | # TODO: support inputting a desired best fit aspect ration 41 | # We'll default to width first over height in terms of images not size 42 | cols = math.ceil(num**(1 / 2)) # sqrt 43 | rows = num // cols 44 | 45 | w, h = imgs[0].size 46 | grid = Image.new('RGB', size=(cols * w, rows * h)) 47 | 48 | for i, img in enumerate(imgs): 49 | grid.paste(img, box=((i % cols) * w, (i // cols) * h)) 50 | return grid 51 | 52 | 53 | class Runner(): 54 | def __init__(self, local: bool = True, device: str = 'cuda') -> None: 55 | if local: 56 | print('Running local mode only, to download models add --dl') 57 | else: 58 | print('Connecting to huggingface to check models...') 59 | assert FlexPipeline.from_pretrained is not None 60 | 61 | clip = CLIPModel.from_pretrained(clip_model, 62 | local_files_only=local, 63 | use_auth_token=not local) 64 | assert isinstance(clip, CLIPModel) 65 | sd: Any = StableDiffusionPipeline.from_pretrained( 66 | sd_model, local_files_only=local, use_auth_token=not local) 67 | # TODO: Tweak betas?? can we load?? 
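# NOTE: the commented-out betas below match the scheduler config commonly
# shipped with SD v1.x checkpoints (scaled_linear, 0.00085 -> 0.012), so
# reusing the checkpoint's own scheduler (sd.scheduler, passed to
# FlexPipeline below) should be equivalent.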
68 | # self.scheduler = scheduler = DDIMScheduler( 69 | # beta_start=0.00085, beta_end=0.012, beta_schedule='scaled_linear') 70 | self.pipe = FlexPipeline(sd.vae, clip, sd.tokenizer, sd.unet, 71 | sd.scheduler).to(device) 72 | self.eta = 0 73 | self.encoder = CLIPEncoder(clip, self.pipe.tokenizer) 74 | self.guide = Guide(clip, self.pipe.tokenizer, device=device) 75 | self.device = device 76 | self.generator = torch.Generator(device=device) 77 | 78 | def _set_seed(self, seed: int | None): 79 | if not seed: 80 | seed = int(torch.randint(0, MAX_SEED, (1,))[0]) 81 | else: 82 | seed = min(max(seed, 0), MAX_SEED) 83 | self.generator.manual_seed(seed) 84 | 85 | def _run(self, batches: int, guide: GuideBase, 86 | init_image: Image.Image | None, init_size: Tuple[int, int], 87 | strength: float, debug: bool, 88 | fp: str) -> Tuple[List[Image.Image], Image.Image]: 89 | all_images = [] 90 | for _ in range(batches): 91 | with autocast(self.device): 92 | stime = time() 93 | ms_time = int(stime * 1000) 94 | with torch.no_grad(): 95 | output = self.pipe(guide=guide, 96 | init_image=init_image, 97 | init_size=init_size, 98 | strength=strength, 99 | generator=self.generator, 100 | eta=self.eta, 101 | debug=debug) 102 | images = output['sample'] # type: ignore 103 | self.eta = time() - stime 104 | for i, img in enumerate(images): 105 | img.save(f'{output_dir}/{ms_time:>013d}_{i:>02d}_{fp}.png', 106 | format='png') 107 | all_images.extend(images) 108 | 109 | ms_time = int(time() * 1000) 110 | grid = image_grid(all_images) 111 | grid.save(f'{grid_dir}/{ms_time:>013d}_{fp}.png', format='png') 112 | return all_images, grid 113 | 114 | def gen(self, 115 | prompt: str | List[str] = '', 116 | init_image: Image.Image | None = None, 117 | guide: Image.Image | str | None = None, 118 | init_size: Tuple[int, int] = (512, 512), 119 | mapping_concepts: str = '', 120 | guide_threshold_mult: float = 0.5, 121 | guide_threshold_floor: float = 0.5, 122 | guide_clustered: float = 0.5, 123 | guide_linear: Tuple = (0.0, 0.5), 124 | guide_max_guidance: float = 0.5, 125 | guide_header_max: float = 0.15, 126 | guide_mode: int = 0, 127 | guide_reuse: bool = True, 128 | strength: float = 0.6, 129 | steps: int = 10, 130 | guidance_scale: float = 8, 131 | samples: int = 1, 132 | seed: int | None = None, 133 | debug: bool = False): 134 | 135 | fp = f'i2i_ds{int(strength * 100)}' if init_image else 't2i' 136 | if guide: 137 | fp += (f'_itm{_i100(guide_threshold_mult)}' 138 | f'_itf{_i100(guide_threshold_floor)}' 139 | f'_ic{_i100(guide_clustered)}' 140 | f'_il{_i100(guide_linear[0])}' 141 | f'-{_i100(guide_linear[1])}' 142 | f'_mg{_i100(guide_max_guidance)}' 143 | f'_hm{_i100(guide_header_max)}' 144 | f'_im{guide_mode:d}') 145 | fp += f'_st{steps}_gs{int(guidance_scale)}' 146 | if seed: 147 | fp += f'_se{seed}' 148 | 149 | self._set_seed(seed) 150 | 151 | guide_embeds = self.guide.embeds( 152 | prompt=prompt, 153 | guide=guide, 154 | mapping_concepts=mapping_concepts, 155 | guide_threshold_mult=guide_threshold_mult, 156 | guide_threshold_floor=guide_threshold_floor, 157 | guide_clustered=guide_clustered, 158 | guide_linear=guide_linear, 159 | guide_max_guidance=guide_max_guidance, 160 | guide_header_max=guide_header_max, 161 | guide_mode=guide_mode, 162 | guide_reuse=guide_reuse) 163 | pipeline_guide = SimpleGuide(self.encoder, self.pipe.unet, 164 | guidance_scale, steps, guide_embeds) 165 | return self._run(samples, pipeline_guide, init_image, init_size, 166 | strength, debug, fp) 167 | 168 | def compose(self, 169 | bg_prompt: str = '', 
170 | entities_df: List[List[Any]] = [], 171 | start_style: str = '', 172 | end_style: str = '', 173 | style_blend: Tuple[float, float] = (0.0, 1.0), 174 | init_image: Image.Image | None = None, 175 | batches: int = 4, 176 | strength: float = 0.7, 177 | steps: int = 30, 178 | guidance_scale: float = 8.0, 179 | init_size: Tuple[int, int] = (512, 512), 180 | seed: int | None = None, 181 | debug: bool = False): 182 | 183 | fp = f'ci2i_ds{int(strength * 100)}' if init_image else 'ct2i' 184 | fp += f'_st{steps}_gs{int(guidance_scale)}' 185 | if seed: 186 | fp += f'_se{seed}' 187 | 188 | self._set_seed(seed) 189 | 190 | def _row_to_ent(row: List[Any]) -> EntitySchema | None: 191 | try: 192 | return EntitySchema( 193 | str(row[0]).strip(), (int(row[1]), int(row[2])), 194 | (int(row[3]), int(row[4])), float(row[5])) 195 | except Exception as ex: 196 | print('Failed to build EntitySchema:', ex) 197 | return None 198 | 199 | if hasattr(entities_df, '_values'): 200 | entities_df = entities_df._values # type: ignore 201 | rows = [_row_to_ent(r) for r in entities_df] 202 | rows = [r for r in rows if r and r.prompt] 203 | schema = Schema(bg_prompt, start_style, end_style, style_blend, rows) 204 | pipeline_guide = CompositeGuide(self.encoder, self.pipe.unet, 205 | guidance_scale, schema, steps) 206 | return self._run(batches, pipeline_guide, init_image, init_size, 207 | strength, debug, fp) 208 | --------------------------------------------------------------------------------