├── .gitattributes ├── .github └── workflows │ └── manual.yml ├── .gitignore ├── .gradio └── certificate.pem ├── .vscode └── settings.json ├── Compiler.py ├── HomeImage.png ├── LICENSE ├── README.md ├── _internal ├── ESRGAN │ └── put_esrgan_and_other_upscale_models_here ├── checkpoints │ └── put_checkpoints_here ├── clip │ └── sd1_clip_config.json ├── embeddings │ └── put_embeddings_or_textual_inversion_concepts_here ├── last_seed.txt ├── loras │ └── put_loras_here ├── output │ ├── Adetailer │ │ └── Adetailer_images_end_up_here │ ├── Flux │ │ └── Flux_images_end_up_here │ ├── HiresFix │ │ └── HiresFixed_images_end_up_here │ ├── Img2Img │ │ └── Upscaled_images_end_up_here │ └── classic │ │ └── normal_images_end_up_here ├── prompt.txt ├── sd1_tokenizer │ ├── merges.txt │ ├── special_tokens_map.json │ ├── tokenizer_config.json │ └── vocab.json └── yolos │ └── put_yolo_and_seg_files_here ├── app.py ├── dependency_flow.png ├── modules ├── Attention │ ├── Attention.py │ └── AttentionMethods.py ├── AutoDetailer │ ├── AD_util.py │ ├── ADetailer.py │ ├── SAM.py │ ├── SEGS.py │ ├── bbox.py │ ├── mask_util.py │ └── tensor_util.py ├── AutoEncoders │ ├── ResBlock.py │ ├── VariationalAE.py │ └── taesd.py ├── AutoHDR │ └── ahdr.py ├── BlackForest │ └── Flux.py ├── Device │ └── Device.py ├── FileManaging │ ├── Downloader.py │ ├── ImageSaver.py │ └── Loader.py ├── Model │ ├── LoRas.py │ ├── ModelBase.py │ └── ModelPatcher.py ├── NeuralNetwork │ ├── transformer.py │ └── unet.py ├── Quantize │ └── Quantizer.py ├── SD15 │ ├── SD15.py │ ├── SDClip.py │ └── SDToken.py ├── StableFast │ └── StableFast.py ├── UltimateSDUpscale │ ├── RDRB.py │ ├── USDU_upscaler.py │ ├── USDU_util.py │ ├── UltimateSDUpscale.py │ └── image_util.py ├── Utilities │ ├── Enhancer.py │ ├── Latent.py │ ├── upscale.py │ └── util.py ├── WaveSpeed │ ├── fbcache_nodes.py │ ├── first_block_cache.py │ ├── misc_nodes.py │ └── utils.py ├── clip │ ├── CLIPTextModel.py │ ├── Clip.py │ ├── FluxClip.py │ └── clip │ │ ├── hydit_clip.json │ │ ├── long_clipl.json │ │ ├── mt5_config_xl.json │ │ ├── sd1_clip_config.json │ │ ├── sd2_clip_config.json │ │ ├── t5_config_base.json │ │ ├── t5_config_xxl.json │ │ ├── t5_pile_config_xl.json │ │ └── t5_tokenizer │ │ ├── special_tokens_map.json │ │ ├── tokenizer.json │ │ └── tokenizer_config.json ├── cond │ ├── Activation.py │ ├── cast.py │ ├── cond.py │ └── cond_util.py ├── hidiffusion │ ├── msw_msa_attention.py │ └── utils.py ├── sample │ ├── CFG.py │ ├── ksampler_util.py │ ├── samplers.py │ ├── sampling.py │ └── sampling_util.py ├── tests │ └── test.py └── user │ ├── GUI.py │ ├── app_instance.py │ └── pipeline.py ├── pipeline.bat ├── pipeline.sh ├── requirements.txt ├── run.bat ├── run.sh ├── run_web.bat ├── run_web.sh └── stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.github/workflows/manual.yml: -------------------------------------------------------------------------------- 1 | name: Manual workflow 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | test: 11 | runs-on: self-hosted 12 | 13 | steps: 14 | - uses: actions/checkout@v3 15 | 16 | - name: Set up Python 3.10 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: '3.10' 20 | 21 | - name: Cache dependencies 22 | 
uses: actions/cache@v3 23 | with: 24 | path: ~/.cache/pip 25 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }} 26 | restore-keys: | 27 | ${{ runner.os }}-pip- 28 | 29 | - name: Create virtual environment 30 | run: | 31 | python -m venv .venv 32 | if [ "$RUNNER_OS" == "Windows" ]; then 33 | . .venv/Scripts/activate 34 | else 35 | . .venv/bin/activate 36 | fi 37 | shell: bash 38 | 39 | - name: Install dependencies 40 | run: | 41 | if [ "$RUNNER_OS" == "Windows" ]; then 42 | . .venv/Scripts/activate 43 | else 44 | . .venv/bin/activate 45 | fi 46 | python -m pip install --upgrade pip 47 | pip install uv 48 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124 49 | pip install "numpy<2.0.0" 50 | if [ -f requirements.txt ]; then 51 | uv pip install -r requirements.txt 52 | fi 53 | shell: bash 54 | 55 | - name: Test pipeline variants 56 | run: | 57 | if [ "$RUNNER_OS" == "Windows" ]; then 58 | . .venv/Scripts/activate 59 | else 60 | . .venv/bin/activate 61 | fi 62 | # Test basic pipeline 63 | python modules/user/pipeline.py "1girl" 512 512 1 1 --hires-fix --adetailer --autohdr --prio-speed 64 | # Test image to image 65 | python modules/user/pipeline.py "./_internal/output/Adetailer/LD-head_00001_.png" 512 512 1 1 --img2img --prio-speed 66 | shell: bash 67 | 68 | - name: Upload test artifacts 69 | if: always() 70 | uses: actions/upload-artifact@v4 71 | with: 72 | name: test-outputs-${{ github.sha }} 73 | path: | 74 | _internal/output/**/*.png 75 | _internal/output/Classic/*.png 76 | _internal/output/Flux/*.png 77 | _internal/output/HF/*.png 78 | retention-days: 5 79 | compression-level: 6 80 | if-no-files-found: warn 81 | 82 | - name: Report status 83 | if: always() 84 | run: | 85 | if [ ${{ job.status }} == 'success' ]; then 86 | echo "All tests passed successfully!" 87 | else 88 | echo "Some tests failed. Check the logs above for details." 
89 | exit 1 90 | fi 91 | shell: bash 92 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | *.pyc 3 | *.pth 4 | *.pt 5 | *.safetensors 6 | *.gguf 7 | *.png 8 | /.idea 9 | /htmlcov 10 | .coverage 11 | .toml 12 | __pycache__ 13 | .venv 14 | !HomeImage.png 15 | *.txt 16 | -------------------------------------------------------------------------------- /.gradio/certificate.pem: -------------------------------------------------------------------------------- 1 | -----BEGIN CERTIFICATE----- 2 | MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw 3 | TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh 4 | cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4 5 | WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu 6 | ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY 7 | MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc 8 | h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+ 9 | 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U 10 | A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW 11 | T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH 12 | B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC 13 | B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv 14 | KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn 15 | OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn 16 | jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw 17 | qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI 18 | rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV 19 | HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq 20 | hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL 21 | ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ 22 | 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK 23 | NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5 24 | ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur 25 | TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC 26 | jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc 27 | oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq 28 | 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA 29 | mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d 30 | emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc= 31 | -----END CERTIFICATE----- 32 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.unittestArgs": [ 3 | "-v", 4 | "-s", 5 | ".", 6 | "-p", 7 | "*test.py" 8 | ], 9 | "python.testing.pytestEnabled": false, 10 | "python.testing.unittestEnabled": true, 11 | "python.analysis.autoImportCompletions": true, 12 | "python.analysis.typeCheckingMode": "off" 13 | } -------------------------------------------------------------------------------- /Compiler.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | 4 | files_ordered = [ 5 | "./modules/Utilities/util.py", 6 | "./modules/sample/sampling_util.py", 7 | "./modules/Device/Device.py", 8 | "./modules/cond/cond_util.py", 9 | 
"./modules/cond/cond.py", 10 | "./modules/sample/ksampler_util.py", 11 | "./modules/cond/cast.py", 12 | "./modules/Attention/AttentionMethods.py", 13 | "./modules/AutoEncoders/taesd.py", 14 | "./modules/cond/cond.py", 15 | "./modules/cond/Activation.py", 16 | "./modules/Attention/Attention.py", 17 | "./modules/sample/samplers.py", 18 | "./modules/sample/CFG.py", 19 | "./modules/NeuralNetwork/transformer.py", 20 | "./modules/sample/sampling.py", 21 | "./modules/clip/CLIPTextModel.py", 22 | "./modules/AutoEncoders/ResBlock.py", 23 | "./modules/AutoDetailer/mask_util.py", 24 | "./modules/NeuralNetwork/unet.py", 25 | "./modules/SD15/SDClip.py", 26 | "./modules/SD15/SDToken.py", 27 | "./modules/UltimateSDUpscale/USDU_util.py", 28 | "./modules/StableFast/SF_util.py", 29 | "./modules/Utilities/Latent.py", 30 | "./modules/AutoDetailer/SEGS.py", 31 | "./modules/AutoDetailer/tensor_util.py", 32 | "./modules/AutoDetailer/AD_util.py", 33 | "./modules/clip/FluxClip.py", 34 | "./modules/Model/ModelPatcher.py", 35 | "./modules/Model/ModelBase.py", 36 | "./modules/UltimateSDUpscale/image_util.py", 37 | "./modules/UltimateSDUpscale/RDRB.py", 38 | "./modules/StableFast/ModuleFactory.py", 39 | "./modules/AutoDetailer/bbox.py", 40 | "./modules/AutoEncoders/VariationalAE.py", 41 | "./modules/clip/Clip.py", 42 | "./modules/Model/LoRas.py", 43 | "./modules/BlackForest/Flux.py", 44 | "./modules/UltimateSDUpscale/USDU_upscaler.py", 45 | "./modules/StableFast/ModuleTracing.py", 46 | "./modules/hidiffusion/utils.py", 47 | "./modules/FileManaging/Downloader.py", 48 | "./modules/AutoDetailer/SAM.py", 49 | "./modules/AutoDetailer/ADetailer.py", 50 | "./modules/Quantize/Quantizer.py", 51 | "./modules/FileManaging/Loader.py", 52 | "./modules/SD15/SD15.py", 53 | "./modules/UltimateSDUpscale/UltimateSDUpscale.py", 54 | "./modules/StableFast/StableFast.py", 55 | "./modules/hidiffusion/msw_msa_attention.py", 56 | "./modules/FileManaging/ImageSaver.py", 57 | "./modules/Utilities/Enhancer.py", 58 | "./modules/Utilities/upscale.py", 59 | "./modules/user/pipeline.py", 60 | ] 61 | 62 | def get_file_patterns(): 63 | patterns = [] 64 | seen = set() 65 | for path in files_ordered: 66 | filename = os.path.basename(path) 67 | name = os.path.splitext(filename)[0] 68 | if name not in seen: 69 | # Pattern 1: matches module name when not in brackets or after a dot 70 | pattern1 = rf'(? 2 | 3 | # Say hi to LightDiffusion-Next 👋 4 | 5 | [![demo platform](https://img.shields.io/badge/Play%20with%20LightDiffusion%21-LightDiffusion%20demo%20platform-lightblue)](https://huggingface.co/spaces/Aatricks/LightDiffusion-Next)  6 | 7 | **LightDiffusion-Next** is the fastest AI-powered image generation GUI/CLI, combining speed, precision, and flexibility in one cohesive tool. 8 |
9 |
10 | 
 11 | [Logo image: HomeImage.png]
 12 | 
 13 | 
15 | 16 | 17 | As a refactored and improved version of the original [LightDiffusion repository](https://github.com/Aatrick/LightDiffusion), this project enhances usability, maintainability, and functionality while introducing a host of new features to streamline your creative workflows. 18 | 19 | ## Motivation: 20 | 21 | **LightDiffusion** was originally meant to be made in Rust, but due to the lack of support for the Rust language in the AI community, it was made in Python with the goal of being the simplest and fastest AI image generation tool. 22 | 23 | That's when the first version of LightDiffusion was born which only counted [3000 lines of code](https://github.com/LightDiffusion/LightDiffusion-original), only using Pytorch. With time, the [project](https://github.com/Aatrick/LightDiffusion) grew and became more complex, and the need for a refactor was evident. This is where **LightDiffusion-Next** comes in, with a more modular and maintainable codebase, and a plethora of new features and optimizations. 24 | 25 | 📚 Learn more in the [official documentation](https://aatrick.github.io/LightDiffusion/). 26 | 27 | --- 28 | 29 | ## 🌟 Highlights 30 | 31 | ![image](https://github.com/user-attachments/assets/b994fe0d-3a2e-44ff-93a4-46919cf865e3) 32 | 33 | **LightDiffusion-Next** offers a powerful suite of tools to cater to creators at every level. At its core, it supports **Text-to-Image** (Txt2Img) and **Image-to-Image** (Img2Img) generation, offering a variety of upscale methods and samplers, to make it easier to create stunning images with minimal effort. 34 | 35 | Advanced users can take advantage of features like **attention syntax**, **Hires-Fix** or **ADetailer**. These tools provide better quality and flexibility for generating complex and high-resolution outputs. 36 | 37 | **LightDiffusion-Next** is fine-tuned for **performance**. Features such as **Xformers** acceleration, **BFloat16** precision support, **WaveSpeed** dynamic caching, and **Stable-Fast** model compilation (which offers up to a 70% speed boost) ensure smooth and efficient operation, even on demanding workloads. 38 | 39 | --- 40 | 41 | ## ✨ Feature Showcase 42 | 43 | Here’s what makes LightDiffusion-Next stand out: 44 | 45 | - **Speed and Efficiency**: 46 | Enjoy industry-leading performance with built-in Xformers, Pytorch, Wavespeed and Stable-Fast optimizations, achieving up to 30% faster speeds compared to the rest of the AI image generation backends in SD1.5 and up to 2x for Flux. 47 | 48 | - **Automatic Detailing**: 49 | Effortlessly enhance faces and body details with AI-driven tools based on the [Impact Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack). 50 | 51 | - **State Preservation**: 52 | Save and resume your progress with saved states, ensuring seamless transitions between sessions. 53 | 54 | - **Advanced GUI, WebUI and CLI**: 55 | Work through a user-friendly graphical interface as GUI or in the browser using Gradio or leverage the streamlined pipeline for CLI-based workflows. 56 | 57 | - **Integration-Ready**: 58 | Collaborate and create directly in Discord with [Boubou](https://github.com/Aatrick/Boubou), or preview images dynamically with the optional **TAESD preview mode**. 59 | 60 | - **Image Previewing**: 61 | Get a real-time preview of your generated images with TAESD, allowing for user-friendly and interactive workflows. 62 | 63 | - **Image Upscaling**: 64 | Enhance your images with advanced upscaling options like UltimateSDUpscaling, ensuring high-quality results every time. 
65 | 66 | - **Prompt Refinement**: 67 | Use the Ollama-powered automatic prompt enhancer to refine your prompts and generate more accurate and detailed outputs. 68 | 69 | - **LoRa and Textual Inversion Embeddings**: 70 | Leverage LoRa and textual inversion embeddings for highly customized and nuanced results, adding a new dimension to your creative process. 71 | 72 | - **Low-End Device Support**: 73 | Run LightDiffusion-Next on low-end devices with as little as 2GB of VRAM or even no GPU, ensuring accessibility for all users. 74 | 75 | - **CFG++**: 76 | Uses samplers modified to use CFG++ for better quality results compared to traditional methods. 77 | 78 | --- 79 | 80 | ## ⚡ Performance Benchmarks 81 | 82 | **LightDiffusion-Next** dominates in performance: 83 | 84 | | **Tool** | **Speed (it/s)** | 85 | |------------------------------------|------------------| 86 | | **LightDiffusion with Stable-Fast** | 2.8 | 87 | | **LightDiffusion** | 1.9 | 88 | | **ComfyUI** | 1.4 | 89 | | **SDForge** | 1.3 | 90 | | **SDWebUI** | 0.9 | 91 | 92 | (All benchmarks are based on a 1024x1024 resolution with a batch size of 1 using BFloat16 precision without tweaking installations. Made with a 3060 mobile GPU using SD1.5.) 93 | 94 | With its unmatched speed and efficiency, LightDiffusion-Next sets the benchmark for AI image generation tools. 95 | 96 | --- 97 | 98 | ## 🛠 Installation 99 | 100 | ### Quick Start 101 | 102 | 1. Download a release or clone this repository. 103 | 2. Run `run.bat` in a terminal. 104 | 3. Start creating! 105 | 106 | ### Command-Line Pipeline 107 | 108 | For a GUI-free experience, use the pipeline: 109 | ```bash 110 | pipeline.bat 111 | ``` 112 | Use `pipeline.bat -h` for more options. 113 | 114 | --- 115 | 116 | ### Advanced Setup 117 | 118 | - **Install from Source**: 119 | Install dependencies via: 120 | ```bash 121 | pip install -r requirements.txt 122 | ``` 123 | Add your SD1/1.5 safetensors model to the `checkpoints` directory, then launch the application. 124 | 125 | - **⚡Stable-Fast Optimization**: 126 | Follow [this guide](https://github.com/chengzeyi/stable-fast?tab=readme-ov-file#installation) to enable Stable-Fast mode for optimal performance. 127 | 128 | - **🦙 Prompt Enhancer**: 129 | Refine your prompts with Ollama: 130 | ```bash 131 | pip install ollama 132 | ollama run deepseek-r1 133 | ``` 134 | See the [Ollama guide](https://github.com/ollama/ollama?tab=readme-ov-file) for details. 135 | 136 | - **🤖 Discord Integration**: 137 | Set up the Discord bot by following the [Boubou installation guide](https://github.com/Aatrick/Boubou). 138 | 139 | --- 140 | 141 | 🎨 Enjoy exploring the powerful features of LightDiffusion-Next! 
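
For programmatic use, the same pipeline that powers the GUI, WebUI and CLI can also be driven directly from Python. The sketch below mirrors the keyword arguments used by the call in `app.py`; the prompt string and flag values are illustrative only:

```python
import torch
from modules.user.pipeline import pipeline

# Minimal sketch: argument names mirror the pipeline() call made in app.py.
with torch.inference_mode():
    pipeline(
        prompt="masterpiece, best quality, 1girl",  # illustrative prompt
        w=512,
        h=512,
        number=1,              # how many images to generate
        batch=1,               # batch size per run
        hires_fix=False,
        adetailer=False,
        enhance_prompt=False,  # requires Ollama when enabled
        img2img=False,
        stable_fast=False,
        reuse_seed=False,
        flux_enabled=False,
        prio_speed=True,
        autohdr=True,
        realistic_model=False,
    )

# Generated images are written under _internal/output/ (see load_generated_images() in app.py).
```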
142 | -------------------------------------------------------------------------------- /_internal/ESRGAN/put_esrgan_and_other_upscale_models_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/ESRGAN/put_esrgan_and_other_upscale_models_here -------------------------------------------------------------------------------- /_internal/checkpoints/put_checkpoints_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/checkpoints/put_checkpoints_here -------------------------------------------------------------------------------- /_internal/clip/sd1_clip_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "openai/clip-vit-large-patch14", 3 | "architectures": [ 4 | "CLIPTextModel" 5 | ], 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 0, 8 | "dropout": 0.0, 9 | "eos_token_id": 2, 10 | "hidden_act": "quick_gelu", 11 | "hidden_size": 768, 12 | "initializer_factor": 1.0, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-05, 16 | "max_position_embeddings": 77, 17 | "model_type": "clip_text_model", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 1, 21 | "projection_dim": 768, 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.24.0", 24 | "vocab_size": 49408 25 | } 26 | -------------------------------------------------------------------------------- /_internal/embeddings/put_embeddings_or_textual_inversion_concepts_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/embeddings/put_embeddings_or_textual_inversion_concepts_here -------------------------------------------------------------------------------- /_internal/last_seed.txt: -------------------------------------------------------------------------------- 1 | 7100032452232484160 -------------------------------------------------------------------------------- /_internal/loras/put_loras_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/loras/put_loras_here -------------------------------------------------------------------------------- /_internal/output/Adetailer/Adetailer_images_end_up_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/Adetailer/Adetailer_images_end_up_here -------------------------------------------------------------------------------- /_internal/output/Flux/Flux_images_end_up_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/Flux/Flux_images_end_up_here -------------------------------------------------------------------------------- /_internal/output/HiresFix/HiresFixed_images_end_up_here: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/HiresFix/HiresFixed_images_end_up_here -------------------------------------------------------------------------------- /_internal/output/Img2Img/Upscaled_images_end_up_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/Img2Img/Upscaled_images_end_up_here -------------------------------------------------------------------------------- /_internal/output/classic/normal_images_end_up_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/classic/normal_images_end_up_here -------------------------------------------------------------------------------- /_internal/prompt.txt: -------------------------------------------------------------------------------- 1 | prompt: masterpiece, best quality, (extremely detailed CG unity 8k wallpaper, masterpiece, best quality, ultra-detailed, best shadow), dynamic angle, light particles, high contrast, (best illumination), ((cinematic light)), colorful, hyper detail, dramatic light, intricate details, depth of field, 1girl, cowboy shot, white hair, long hair, 2 | neg: (worst quality, low quality:1.4), (zombie, sketch, interlocked fingers, comic), (embedding:EasyNegative), (embedding:badhandv4) 3 | w: 768 4 | h: 1024 5 | cfg: 8 6 | -------------------------------------------------------------------------------- /_internal/sd1_tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /_internal/sd1_tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "do_lower_case": true, 12 | "eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 77, 22 | "name_or_path": "openai/clip-vit-large-patch14", 23 | "pad_token": "<|endoftext|>", 24 | "special_tokens_map_file": "./special_tokens_map.json", 25 | "tokenizer_class": "CLIPTokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 32 | "single_word": false 33 | } 34 | } 35 | 
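
The four files under `_internal/sd1_tokenizer` (`vocab.json`, `merges.txt`, `tokenizer_config.json`, `special_tokens_map.json`) form a standard CLIP tokenizer bundle. As an illustration only, independent of the project's own tokenizer handling, such a bundle can be loaded with the Hugging Face `transformers` package (assumed to be installed):

```python
from transformers import CLIPTokenizer

# Load the bundled SD1.5 tokenizer files directly from the repository directory.
tokenizer = CLIPTokenizer.from_pretrained("./_internal/sd1_tokenizer")

# model_max_length=77 comes from tokenizer_config.json, so padded prompts are 77 tokens long.
encoded = tokenizer(
    "masterpiece, best quality, 1girl",
    padding="max_length",
    truncation=True,
)
print(len(encoded["input_ids"]))  # 77
```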
-------------------------------------------------------------------------------- /_internal/yolos/put_yolo_and_seg_files_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/yolos/put_yolo_and_seg_files_here -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import gradio as gr 3 | import sys 4 | import os 5 | from PIL import Image 6 | import numpy as np 7 | import spaces 8 | 9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))) 10 | 11 | from modules.user.pipeline import pipeline 12 | import torch 13 | 14 | 15 | def load_generated_images(): 16 | """Load generated images with given prefix from disk""" 17 | image_files = glob.glob("./_internal/output/**/*.png") 18 | 19 | # If there are no image files, return 20 | if not image_files: 21 | return [] 22 | 23 | # Sort files by modification time in descending order 24 | image_files.sort(key=os.path.getmtime, reverse=True) 25 | 26 | # Get most recent timestamp 27 | latest_time = os.path.getmtime(image_files[0]) 28 | 29 | # Get all images from same batch (within 1 second of most recent) 30 | batch_images = [] 31 | for file in image_files: 32 | if abs(os.path.getmtime(file) - latest_time) < 1.0: 33 | try: 34 | img = Image.open(file) 35 | batch_images.append(img) 36 | except: 37 | continue 38 | 39 | if not batch_images: 40 | return [] 41 | return batch_images 42 | 43 | 44 | @spaces.GPU 45 | def generate_images( 46 | prompt: str, 47 | width: int = 512, 48 | height: int = 512, 49 | num_images: int = 1, 50 | batch_size: int = 1, 51 | hires_fix: bool = False, 52 | adetailer: bool = False, 53 | enhance_prompt: bool = False, 54 | img2img_enabled: bool = False, 55 | img2img_image: str = None, 56 | stable_fast: bool = False, 57 | reuse_seed: bool = False, 58 | flux_enabled: bool = False, 59 | prio_speed: bool = False, 60 | realistic_model: bool = False, 61 | progress=gr.Progress(), 62 | ): 63 | """Generate images using the LightDiffusion pipeline""" 64 | try: 65 | if img2img_enabled and img2img_image is not None: 66 | # Convert numpy array to PIL Image 67 | if isinstance(img2img_image, np.ndarray): 68 | img_pil = Image.fromarray(img2img_image) 69 | img_pil.save("temp_img2img.png") 70 | prompt = "temp_img2img.png" 71 | 72 | # Run pipeline and capture saved images 73 | with torch.inference_mode(): 74 | pipeline( 75 | prompt=prompt, 76 | w=width, 77 | h=height, 78 | number=num_images, 79 | batch=batch_size, 80 | hires_fix=hires_fix, 81 | adetailer=adetailer, 82 | enhance_prompt=enhance_prompt, 83 | img2img=img2img_enabled, 84 | stable_fast=stable_fast, 85 | reuse_seed=reuse_seed, 86 | flux_enabled=flux_enabled, 87 | prio_speed=prio_speed, 88 | autohdr=True, 89 | realistic_model=realistic_model, 90 | ) 91 | 92 | # Clean up temporary file if it exists 93 | if os.path.exists("temp_img2img.png"): 94 | os.remove("temp_img2img.png") 95 | 96 | return load_generated_images() 97 | 98 | except Exception: 99 | import traceback 100 | 101 | print(traceback.format_exc()) 102 | # Clean up temporary file if it exists 103 | if os.path.exists("temp_img2img.png"): 104 | os.remove("temp_img2img.png") 105 | return [Image.new("RGB", (512, 512), color="black")] 106 | 107 | 108 | # Create Gradio interface 109 | with gr.Blocks(title="LightDiffusion Web UI") as 
demo: 110 | gr.Markdown("# LightDiffusion Web UI") 111 | gr.Markdown("Generate AI images using LightDiffusion") 112 | gr.Markdown( 113 | "This is the demo for LightDiffusion, the fastest diffusion backend for generating images. https://github.com/LightDiffusion/LightDiffusion-Next" 114 | ) 115 | 116 | with gr.Row(): 117 | with gr.Column(): 118 | # Input components 119 | prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...") 120 | 121 | with gr.Row(): 122 | width = gr.Slider( 123 | minimum=64, maximum=2048, value=512, step=64, label="Width" 124 | ) 125 | height = gr.Slider( 126 | minimum=64, maximum=2048, value=512, step=64, label="Height" 127 | ) 128 | 129 | with gr.Row(): 130 | num_images = gr.Slider( 131 | minimum=1, maximum=10, value=1, step=1, label="Number of Images" 132 | ) 133 | batch_size = gr.Slider( 134 | minimum=1, maximum=4, value=1, step=1, label="Batch Size" 135 | ) 136 | 137 | with gr.Row(): 138 | hires_fix = gr.Checkbox(label="HiRes Fix") 139 | adetailer = gr.Checkbox(label="Auto Face/Body Enhancement") 140 | enhance_prompt = gr.Checkbox(label="Enhance Prompt") 141 | stable_fast = gr.Checkbox(label="Stable Fast Mode") 142 | 143 | with gr.Row(): 144 | reuse_seed = gr.Checkbox(label="Reuse Seed") 145 | flux_enabled = gr.Checkbox(label="Flux Mode") 146 | prio_speed = gr.Checkbox(label="Prioritize Speed") 147 | realistic_model = gr.Checkbox(label="Realistic Model") 148 | 149 | with gr.Row(): 150 | img2img_enabled = gr.Checkbox(label="Image to Image Mode") 151 | img2img_image = gr.Image(label="Input Image for img2img", visible=False) 152 | 153 | # Make input image visible only when img2img is enabled 154 | img2img_enabled.change( 155 | fn=lambda x: gr.update(visible=x), 156 | inputs=[img2img_enabled], 157 | outputs=[img2img_image], 158 | ) 159 | 160 | generate_btn = gr.Button("Generate") 161 | 162 | # Output gallery 163 | gallery = gr.Gallery( 164 | label="Generated Images", 165 | show_label=True, 166 | elem_id="gallery", 167 | columns=[2], 168 | rows=[2], 169 | object_fit="contain", 170 | height="auto", 171 | ) 172 | 173 | # Connect generate button to pipeline 174 | generate_btn.click( 175 | fn=generate_images, 176 | inputs=[ 177 | prompt, 178 | width, 179 | height, 180 | num_images, 181 | batch_size, 182 | hires_fix, 183 | adetailer, 184 | enhance_prompt, 185 | img2img_enabled, 186 | img2img_image, 187 | stable_fast, 188 | reuse_seed, 189 | flux_enabled, 190 | prio_speed, 191 | realistic_model, 192 | ], 193 | outputs=gallery, 194 | ) 195 | 196 | 197 | def is_huggingface_space(): 198 | return "SPACE_ID" in os.environ 199 | 200 | 201 | # For local testing 202 | if __name__ == "__main__": 203 | if is_huggingface_space(): 204 | demo.launch( 205 | debug=False, 206 | server_name="0.0.0.0", 207 | server_port=7860, # Standard HF Spaces port 208 | ) 209 | else: 210 | demo.launch( 211 | server_name="0.0.0.0", 212 | server_port=8000, 213 | auth=None, 214 | share=True, # Only enable sharing locally 215 | debug=True, 216 | ) 217 | -------------------------------------------------------------------------------- /dependency_flow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/dependency_flow.png -------------------------------------------------------------------------------- /modules/Attention/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 
as nn 3 | import logging 4 | 5 | from modules.Utilities import util 6 | from modules.Attention import AttentionMethods 7 | from modules.Device import Device 8 | from modules.cond import cast 9 | 10 | 11 | def Normalize( 12 | in_channels: int, dtype: torch.dtype = None, device: torch.device = None 13 | ) -> torch.nn.GroupNorm: 14 | """#### Normalize the input channels. 15 | 16 | #### Args: 17 | - `in_channels` (int): The input channels. 18 | - `dtype` (torch.dtype, optional): The data type. Defaults to `None`. 19 | - `device` (torch.device, optional): The device. Defaults to `None`. 20 | 21 | #### Returns: 22 | - `torch.nn.GroupNorm`: The normalized input channels 23 | """ 24 | return torch.nn.GroupNorm( 25 | num_groups=32, 26 | num_channels=in_channels, 27 | eps=1e-6, 28 | affine=True, 29 | dtype=dtype, 30 | device=device, 31 | ) 32 | 33 | 34 | if Device.xformers_enabled(): 35 | logging.info("Using xformers cross attention") 36 | optimized_attention = AttentionMethods.attention_xformers 37 | else: 38 | logging.info("Using pytorch cross attention") 39 | optimized_attention = AttentionMethods.attention_pytorch 40 | 41 | optimized_attention_masked = optimized_attention 42 | 43 | 44 | def optimized_attention_for_device() -> AttentionMethods.attention_pytorch: 45 | """#### Get the optimized attention for a device. 46 | 47 | #### Returns: 48 | - `function`: The optimized attention function. 49 | """ 50 | return AttentionMethods.attention_pytorch 51 | 52 | 53 | class CrossAttention(nn.Module): 54 | """#### Cross attention module, which applies attention across the query and context. 55 | 56 | #### Args: 57 | - `query_dim` (int): The query dimension. 58 | - `context_dim` (int, optional): The context dimension. Defaults to `None`. 59 | - `heads` (int, optional): The number of heads. Defaults to `8`. 60 | - `dim_head` (int, optional): The head dimension. Defaults to `64`. 61 | - `dropout` (float, optional): The dropout rate. Defaults to `0.0`. 62 | - `dtype` (torch.dtype, optional): The data type. Defaults to `None`. 63 | - `device` (torch.device, optional): The device. Defaults to `None`. 64 | - `operations` (cast.disable_weight_init, optional): The operations. Defaults to `cast.disable_weight_init`. 65 | """ 66 | 67 | def __init__( 68 | self, 69 | query_dim: int, 70 | context_dim: int = None, 71 | heads: int = 8, 72 | dim_head: int = 64, 73 | dropout: float = 0.0, 74 | dtype: torch.dtype = None, 75 | device: torch.device = None, 76 | operations: cast.disable_weight_init = cast.disable_weight_init, 77 | ): 78 | super().__init__() 79 | inner_dim = dim_head * heads 80 | context_dim = util.default(context_dim, query_dim) 81 | 82 | self.heads = heads 83 | self.dim_head = dim_head 84 | 85 | self.to_q = operations.Linear( 86 | query_dim, inner_dim, bias=False, dtype=dtype, device=device 87 | ) 88 | self.to_k = operations.Linear( 89 | context_dim, inner_dim, bias=False, dtype=dtype, device=device 90 | ) 91 | self.to_v = operations.Linear( 92 | context_dim, inner_dim, bias=False, dtype=dtype, device=device 93 | ) 94 | 95 | self.to_out = nn.Sequential( 96 | operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), 97 | nn.Dropout(dropout), 98 | ) 99 | 100 | def forward( 101 | self, 102 | x: torch.Tensor, 103 | context: torch.Tensor = None, 104 | value: torch.Tensor = None, 105 | mask: torch.Tensor = None, 106 | ) -> torch.Tensor: 107 | """#### Forward pass of the cross attention module. 108 | 109 | #### Args: 110 | - `x` (torch.Tensor): The input tensor. 
111 | - `context` (torch.Tensor, optional): The context tensor. Defaults to `None`. 112 | - `value` (torch.Tensor, optional): The value tensor. Defaults to `None`. 113 | - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`. 114 | 115 | #### Returns: 116 | - `torch.Tensor`: The output tensor. 117 | """ 118 | q = self.to_q(x) 119 | context = util.default(context, x) 120 | k = self.to_k(context) 121 | v = self.to_v(context) 122 | 123 | out = optimized_attention(q, k, v, self.heads) 124 | return self.to_out(out) 125 | 126 | 127 | class AttnBlock(nn.Module): 128 | """#### Attention block, which applies attention to the input tensor. 129 | 130 | #### Args: 131 | - `in_channels` (int): The input channels. 132 | """ 133 | 134 | def __init__(self, in_channels: int): 135 | super().__init__() 136 | self.in_channels = in_channels 137 | 138 | self.norm = Normalize(in_channels) 139 | self.q = cast.disable_weight_init.Conv2d( 140 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 141 | ) 142 | self.k = cast.disable_weight_init.Conv2d( 143 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 144 | ) 145 | self.v = cast.disable_weight_init.Conv2d( 146 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 147 | ) 148 | self.proj_out = cast.disable_weight_init.Conv2d( 149 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 150 | ) 151 | 152 | if Device.xformers_enabled_vae(): 153 | logging.info("Using xformers attention in VAE") 154 | self.optimized_attention = AttentionMethods.xformers_attention 155 | else: 156 | logging.info("Using pytorch attention in VAE") 157 | self.optimized_attention = AttentionMethods.pytorch_attention 158 | 159 | def forward(self, x: torch.Tensor) -> torch.Tensor: 160 | """#### Forward pass of the attention block. 161 | 162 | #### Args: 163 | - `x` (torch.Tensor): The input tensor. 164 | 165 | #### Returns: 166 | - `torch.Tensor`: The output tensor. 167 | """ 168 | h_ = x 169 | h_ = self.norm(h_) 170 | q = self.q(h_) 171 | k = self.k(h_) 172 | v = self.v(h_) 173 | 174 | h_ = self.optimized_attention(q, k, v) 175 | 176 | h_ = self.proj_out(h_) 177 | 178 | return x + h_ 179 | 180 | 181 | def make_attn(in_channels: int, attn_type: str = "vanilla") -> AttnBlock: 182 | """#### Make an attention block. 183 | 184 | #### Args: 185 | - `in_channels` (int): The input channels. 186 | - `attn_type` (str, optional): The attention type. Defaults to "vanilla". 187 | 188 | #### Returns: 189 | - `AttnBlock`: A class instance of the attention block. 190 | """ 191 | return AttnBlock(in_channels) 192 | -------------------------------------------------------------------------------- /modules/Attention/AttentionMethods.py: -------------------------------------------------------------------------------- 1 | try : 2 | import xformers 3 | except ImportError: 4 | pass 5 | import torch 6 | 7 | BROKEN_XFORMERS = False 8 | try: 9 | x_vers = xformers.__version__ 10 | # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error) 11 | BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20") 12 | except: 13 | pass 14 | 15 | 16 | def attention_xformers( 17 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False, flux=False 18 | ) -> torch.Tensor: 19 | """#### Make an attention call using xformers. Fastest attention implementation. 20 | 21 | #### Args: 22 | - `q` (torch.Tensor): The query tensor. 
23 | - `k` (torch.Tensor): The key tensor, must have the same shape as `q`. 24 | - `v` (torch.Tensor): The value tensor, must have the same shape as `q`. 25 | - `heads` (int): The number of heads, must be a divisor of the hidden dimension. 26 | - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`. 27 | 28 | #### Returns: 29 | - `torch.Tensor`: The output tensor. 30 | """ 31 | if not flux: 32 | b, _, dim_head = q.shape 33 | dim_head //= heads 34 | 35 | q, k, v = map( 36 | lambda t: t.unsqueeze(3) 37 | .reshape(b, -1, heads, dim_head) 38 | .permute(0, 2, 1, 3) 39 | .reshape(b * heads, -1, dim_head) 40 | .contiguous(), 41 | (q, k, v), 42 | ) 43 | 44 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) 45 | 46 | out = ( 47 | out.unsqueeze(0) 48 | .reshape(b, heads, -1, dim_head) 49 | .permute(0, 2, 1, 3) 50 | .reshape(b, -1, heads * dim_head) 51 | ) 52 | return out 53 | else: 54 | if skip_reshape: 55 | b, _, _, dim_head = q.shape 56 | else: 57 | b, _, dim_head = q.shape 58 | dim_head //= heads 59 | 60 | disabled_xformers = False 61 | 62 | if BROKEN_XFORMERS: 63 | if b * heads > 65535: 64 | disabled_xformers = True 65 | 66 | if not disabled_xformers: 67 | if torch.jit.is_tracing() or torch.jit.is_scripting(): 68 | disabled_xformers = True 69 | 70 | if disabled_xformers: 71 | return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape) 72 | 73 | if skip_reshape: 74 | q, k, v = map( 75 | lambda t: t.reshape(b * heads, -1, dim_head), 76 | (q, k, v), 77 | ) 78 | else: 79 | q, k, v = map( 80 | lambda t: t.reshape(b, -1, heads, dim_head), 81 | (q, k, v), 82 | ) 83 | 84 | if mask is not None: 85 | pad = 8 - q.shape[1] % 8 86 | mask_out = torch.empty( 87 | [q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device 88 | ) 89 | mask_out[:, :, : mask.shape[-1]] = mask 90 | mask = mask_out[:, :, : mask.shape[-1]] 91 | 92 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) 93 | 94 | if skip_reshape: 95 | out = ( 96 | out.unsqueeze(0) 97 | .reshape(b, heads, -1, dim_head) 98 | .permute(0, 2, 1, 3) 99 | .reshape(b, -1, heads * dim_head) 100 | ) 101 | else: 102 | out = out.reshape(b, -1, heads * dim_head) 103 | 104 | return out 105 | 106 | 107 | def attention_pytorch( 108 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False, flux=False 109 | ) -> torch.Tensor: 110 | """#### Make an attention call using PyTorch. 111 | 112 | #### Args: 113 | - `q` (torch.Tensor): The query tensor. 114 | - `k` (torch.Tensor): The key tensor, must have the same shape as `q. 115 | - `v` (torch.Tensor): The value tensor, must have the same shape as `q. 116 | - `heads` (int): The number of heads, must be a divisor of the hidden dimension. 117 | - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`. 118 | 119 | #### Returns: 120 | - `torch.Tensor`: The output tensor. 
121 | """ 122 | if not flux: 123 | b, _, dim_head = q.shape 124 | dim_head //= heads 125 | q, k, v = map( 126 | lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), 127 | (q, k, v), 128 | ) 129 | 130 | out = torch.nn.functional.scaled_dot_product_attention( 131 | q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False 132 | ) 133 | out = out.transpose(1, 2).reshape(b, -1, heads * dim_head) 134 | return out 135 | else: 136 | if skip_reshape: 137 | b, _, _, dim_head = q.shape 138 | else: 139 | b, _, dim_head = q.shape 140 | dim_head //= heads 141 | q, k, v = map( 142 | lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), 143 | (q, k, v), 144 | ) 145 | 146 | out = torch.nn.functional.scaled_dot_product_attention( 147 | q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False 148 | ) 149 | out = out.transpose(1, 2).reshape(b, -1, heads * dim_head) 150 | return out 151 | 152 | def xformers_attention( 153 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor 154 | ) -> torch.Tensor: 155 | """#### Compute attention using xformers. 156 | 157 | #### Args: 158 | - `q` (torch.Tensor): The query tensor. 159 | - `k` (torch.Tensor): The key tensor, must have the same shape as `q`. 160 | - `v` (torch.Tensor): The value tensor, must have the same shape as `q`. 161 | 162 | Returns: 163 | - `torch.Tensor`: The output tensor. 164 | """ 165 | B, C, H, W = q.shape 166 | q, k, v = map( 167 | lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(), 168 | (q, k, v), 169 | ) 170 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) 171 | out = out.transpose(1, 2).reshape(B, C, H, W) 172 | return out 173 | 174 | 175 | def pytorch_attention( 176 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor 177 | ) -> torch.Tensor: 178 | """#### Compute attention using PyTorch. 179 | 180 | #### Args: 181 | - `q` (torch.Tensor): The query tensor. 182 | - `k` (torch.Tensor): The key tensor, must have the same shape as `q. 183 | - `v` (torch.Tensor): The value tensor, must have the same shape as `q. 184 | 185 | #### Returns: 186 | - `torch.Tensor`: The output tensor. 187 | """ 188 | B, C, H, W = q.shape 189 | q, k, v = map( 190 | lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), 191 | (q, k, v), 192 | ) 193 | out = torch.nn.functional.scaled_dot_product_attention( 194 | q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False 195 | ) 196 | out = out.transpose(2, 3).reshape(B, C, H, W) 197 | return out 198 | -------------------------------------------------------------------------------- /modules/AutoDetailer/AD_util.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from ultralytics import YOLO 6 | from PIL import Image 7 | 8 | orig_torch_load = torch.load 9 | 10 | # importing YOLO breaking original torch.load capabilities 11 | torch.load = orig_torch_load 12 | 13 | 14 | def load_yolo(model_path: str) -> YOLO: 15 | """#### Load YOLO model. 16 | 17 | #### Args: 18 | - `model_path` (str): The path to the YOLO model. 19 | 20 | #### Returns: 21 | - `YOLO`: The YOLO model initialized with the specified model path. 22 | """ 23 | try: 24 | return YOLO(model_path) 25 | except ModuleNotFoundError: 26 | print("please download yolo model") 27 | 28 | 29 | def inference_bbox( 30 | model: YOLO, 31 | image: Image.Image, 32 | confidence: float = 0.3, 33 | device: str = "", 34 | ) -> List: 35 | """#### Perform inference on an image and return bounding boxes. 
36 | 37 | #### Args: 38 | - `model` (YOLO): The YOLO model. 39 | - `image` (Image.Image): The image to perform inference on. 40 | - `confidence` (float): The confidence threshold for the bounding boxes. 41 | - `device` (str): The device to run the model on. 42 | 43 | #### Returns: 44 | - `List[List[str, List[int], np.ndarray, float]]`: The list of bounding boxes. 45 | """ 46 | pred = model(image, conf=confidence, device=device) 47 | 48 | bboxes = pred[0].boxes.xyxy.cpu().numpy() 49 | cv2_image = np.array(image) 50 | cv2_image = cv2_image[:, :, ::-1].copy() # Convert RGB to BGR for cv2 processing 51 | cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY) 52 | 53 | segms = [] 54 | for x0, y0, x1, y1 in bboxes: 55 | cv2_mask = np.zeros(cv2_gray.shape, np.uint8) 56 | cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1) 57 | cv2_mask_bool = cv2_mask.astype(bool) 58 | segms.append(cv2_mask_bool) 59 | 60 | results = [[], [], [], []] 61 | for i in range(len(bboxes)): 62 | results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())]) 63 | results[1].append(bboxes[i]) 64 | results[2].append(segms[i]) 65 | results[3].append(pred[0].boxes[i].conf.cpu().numpy()) 66 | 67 | return results 68 | 69 | 70 | def create_segmasks(results: List) -> List: 71 | """#### Create segmentation masks from the results of the inference. 72 | 73 | #### Args: 74 | - `results` (List[List[str, List[int], np.ndarray, float]]): The results of the inference. 75 | 76 | #### Returns: 77 | - `List[List[int], np.ndarray, float]`: The list of segmentation masks. 78 | """ 79 | bboxs = results[1] 80 | segms = results[2] 81 | confidence = results[3] 82 | 83 | results = [] 84 | for i in range(len(segms)): 85 | item = (bboxs[i], segms[i].astype(np.float32), confidence[i]) 86 | results.append(item) 87 | return results 88 | 89 | 90 | def dilate_masks(segmasks: List, dilation_factor: int, iter: int = 1) -> List: 91 | """#### Dilate the segmentation masks. 92 | 93 | #### Args: 94 | - `segmasks` (List[List[int], np.ndarray, float]): The segmentation masks. 95 | - `dilation_factor` (int): The dilation factor. 96 | - `iter` (int): The number of iterations. 97 | 98 | #### Returns: 99 | - `List[List[int], np.ndarray, float]`: The dilated segmentation masks. 100 | """ 101 | dilated_masks = [] 102 | kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8) 103 | 104 | for i in range(len(segmasks)): 105 | cv2_mask = segmasks[i][1] 106 | 107 | dilated_mask = cv2.dilate(cv2_mask, kernel, iter) 108 | 109 | item = (segmasks[i][0], dilated_mask, segmasks[i][2]) 110 | dilated_masks.append(item) 111 | 112 | return dilated_masks 113 | 114 | 115 | def normalize_region(limit: int, startp: int, size: int) -> List: 116 | """#### Normalize the region. 117 | 118 | #### Args: 119 | - `limit` (int): The limit. 120 | - `startp` (int): The start point. 121 | - `size` (int): The size. 122 | 123 | #### Returns: 124 | - `List[int]`: The normalized start and end points. 125 | """ 126 | if startp < 0: 127 | new_endp = min(limit, size) 128 | new_startp = 0 129 | elif startp + size > limit: 130 | new_startp = max(0, limit - size) 131 | new_endp = limit 132 | else: 133 | new_startp = startp 134 | new_endp = min(limit, startp + size) 135 | 136 | return int(new_startp), int(new_endp) 137 | 138 | 139 | def make_crop_region(w: int, h: int, bbox: List, crop_factor: float) -> List: 140 | """#### Make the crop region. 141 | 142 | #### Args: 143 | - `w` (int): The width. 144 | - `h` (int): The height. 
145 | - `bbox` (List[int]): The bounding box. 146 | - `crop_factor` (float): The crop factor. 147 | 148 | #### Returns: 149 | - `List[x1: int, y1: int, x2: int, y2: int]`: The crop region. 150 | """ 151 | x1 = bbox[0] 152 | y1 = bbox[1] 153 | x2 = bbox[2] 154 | y2 = bbox[3] 155 | 156 | bbox_w = x2 - x1 157 | bbox_h = y2 - y1 158 | 159 | crop_w = bbox_w * crop_factor 160 | crop_h = bbox_h * crop_factor 161 | 162 | kernel_x = x1 + bbox_w / 2 163 | kernel_y = y1 + bbox_h / 2 164 | 165 | new_x1 = int(kernel_x - crop_w / 2) 166 | new_y1 = int(kernel_y - crop_h / 2) 167 | 168 | # make sure position in (w,h) 169 | new_x1, new_x2 = normalize_region(w, new_x1, crop_w) 170 | new_y1, new_y2 = normalize_region(h, new_y1, crop_h) 171 | 172 | return [new_x1, new_y1, new_x2, new_y2] 173 | 174 | 175 | def crop_ndarray2(npimg: np.ndarray, crop_region: List) -> np.ndarray: 176 | """#### Crop the ndarray in 2 dimensions. 177 | 178 | #### Args: 179 | - `npimg` (np.ndarray): The ndarray to crop. 180 | - `crop_region` (List[int]): The crop region. 181 | 182 | #### Returns: 183 | - `np.ndarray`: The cropped ndarray. 184 | """ 185 | x1 = crop_region[0] 186 | y1 = crop_region[1] 187 | x2 = crop_region[2] 188 | y2 = crop_region[3] 189 | 190 | cropped = npimg[y1:y2, x1:x2] 191 | 192 | return cropped 193 | 194 | 195 | def crop_ndarray4(npimg: np.ndarray, crop_region: List) -> np.ndarray: 196 | """#### Crop the ndarray in 4 dimensions. 197 | 198 | #### Args: 199 | - `npimg` (np.ndarray): The ndarray to crop. 200 | - `crop_region` (List[int]): The crop region. 201 | 202 | #### Returns: 203 | - `np.ndarray`: The cropped ndarray. 204 | """ 205 | x1 = crop_region[0] 206 | y1 = crop_region[1] 207 | x2 = crop_region[2] 208 | y2 = crop_region[3] 209 | 210 | cropped = npimg[:, y1:y2, x1:x2, :] 211 | 212 | return cropped 213 | 214 | 215 | def crop_image(image: Image.Image, crop_region: List) -> Image.Image: 216 | """#### Crop the image. 217 | 218 | #### Args: 219 | - `image` (Image.Image): The image to crop. 220 | - `crop_region` (List[int]): The crop region. 221 | 222 | #### Returns: 223 | - `Image.Image`: The cropped image. 224 | """ 225 | return crop_ndarray4(image, crop_region) 226 | 227 | 228 | def segs_scale_match(segs: List[np.ndarray], target_shape: List) -> List: 229 | """#### Match the scale of the segmentation masks. 230 | 231 | #### Args: 232 | - `segs` (List[np.ndarray]): The segmentation masks. 233 | - `target_shape` (List[int]): The target shape. 234 | 235 | #### Returns: 236 | - `List[np.ndarray]`: The matched segmentation masks. 237 | """ 238 | h = segs[0][0] 239 | w = segs[0][1] 240 | 241 | th = target_shape[1] 242 | tw = target_shape[2] 243 | 244 | if (h == th and w == tw) or h == 0 or w == 0: 245 | return segs 246 | -------------------------------------------------------------------------------- /modules/AutoDetailer/SAM.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from segment_anything import SamPredictor, sam_model_registry 4 | import torch 5 | 6 | from modules.AutoDetailer import mask_util 7 | from modules.Device import Device 8 | 9 | 10 | def sam_predict( 11 | predictor: SamPredictor, points: list, plabs: list, bbox: list, threshold: float 12 | ) -> list: 13 | """#### Predict masks using SAM. 14 | 15 | #### Args: 16 | - `predictor` (SamPredictor): The SAM predictor. 17 | - `points` (list): List of points. 18 | - `plabs` (list): List of point labels. 19 | - `bbox` (list): Bounding box. 
20 | - `threshold` (float): Threshold for mask selection. 21 | 22 | #### Returns: 23 | - `list`: List of predicted masks. 24 | """ 25 | point_coords = None if not points else np.array(points) 26 | point_labels = None if not plabs else np.array(plabs) 27 | 28 | box = np.array([bbox]) if bbox is not None else None 29 | 30 | cur_masks, scores, _ = predictor.predict( 31 | point_coords=point_coords, point_labels=point_labels, box=box 32 | ) 33 | 34 | total_masks = [] 35 | 36 | selected = False 37 | max_score = 0 38 | max_mask = None 39 | for idx in range(len(scores)): 40 | if scores[idx] > max_score: 41 | max_score = scores[idx] 42 | max_mask = cur_masks[idx] 43 | 44 | if scores[idx] >= threshold: 45 | selected = True 46 | total_masks.append(cur_masks[idx]) 47 | else: 48 | pass 49 | 50 | if not selected and max_mask is not None: 51 | total_masks.append(max_mask) 52 | 53 | return total_masks 54 | 55 | 56 | def is_same_device(a: torch.device, b: torch.device) -> bool: 57 | """#### Check if two devices are the same. 58 | 59 | #### Args: 60 | - `a` (torch.device): The first device. 61 | - `b` (torch.device): The second device. 62 | 63 | #### Returns: 64 | - `bool`: Whether the devices are the same. 65 | """ 66 | a_device = torch.device(a) if isinstance(a, str) else a 67 | b_device = torch.device(b) if isinstance(b, str) else b 68 | return a_device.type == b_device.type and a_device.index == b_device.index 69 | 70 | 71 | class SafeToGPU: 72 | """#### Class to safely move objects to GPU.""" 73 | 74 | def __init__(self, size: int): 75 | self.size = size 76 | 77 | def to_device(self, obj: torch.nn.Module, device: torch.device) -> None: 78 | """#### Move an object to a device. 79 | 80 | #### Args: 81 | - `obj` (torch.nn.Module): The object to move. 82 | - `device` (torch.device): The target device. 83 | """ 84 | if is_same_device(device, "cpu"): 85 | obj.to(device) 86 | else: 87 | if is_same_device(obj.device, "cpu"): # cpu to gpu 88 | Device.free_memory(self.size * 1.3, device) 89 | if Device.get_free_memory(device) > self.size * 1.3: 90 | try: 91 | obj.to(device) 92 | except: 93 | print( 94 | f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]" 95 | ) 96 | else: 97 | print( 98 | f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]" 99 | ) 100 | 101 | 102 | class SAMWrapper: 103 | """#### Wrapper class for SAM model.""" 104 | 105 | def __init__( 106 | self, model: torch.nn.Module, is_auto_mode: bool, safe_to_gpu: SafeToGPU = None 107 | ): 108 | self.model = model 109 | self.safe_to_gpu = safe_to_gpu if safe_to_gpu is not None else SafeToGPU() 110 | self.is_auto_mode = is_auto_mode 111 | 112 | def prepare_device(self) -> None: 113 | """#### Prepare the device for the model.""" 114 | if self.is_auto_mode: 115 | device = Device.get_torch_device() 116 | self.safe_to_gpu.to_device(self.model, device=device) 117 | 118 | def release_device(self) -> None: 119 | """#### Release the device from the model.""" 120 | if self.is_auto_mode: 121 | self.model.to(device="cpu") 122 | 123 | def predict( 124 | self, image: np.ndarray, points: list, plabs: list, bbox: list, threshold: float 125 | ) -> list: 126 | """#### Predict masks using the SAM model. 127 | 128 | #### Args: 129 | - `image` (np.ndarray): The input image. 130 | - `points` (list): List of points. 131 | - `plabs` (list): List of point labels. 132 | - `bbox` (list): Bounding box. 133 | - `threshold` (float): Threshold for mask selection. 
134 | 135 | #### Returns: 136 | - `list`: List of predicted masks. 137 | """ 138 | predictor = SamPredictor(self.model) 139 | predictor.set_image(image, "RGB") 140 | 141 | return sam_predict(predictor, points, plabs, bbox, threshold) 142 | 143 | 144 | class SAMLoader: 145 | """#### Class to load SAM models.""" 146 | 147 | def load_model(self, model_name: str, device_mode: str = "auto") -> tuple: 148 | """#### Load a SAM model. 149 | 150 | #### Args: 151 | - `model_name` (str): The name of the model. 152 | - `device_mode` (str, optional): The device mode. Defaults to "auto". 153 | 154 | #### Returns: 155 | - `tuple`: The loaded SAM model. 156 | """ 157 | modelname = "./_internal/yolos/" + model_name 158 | 159 | if "vit_h" in model_name: 160 | model_kind = "vit_h" 161 | elif "vit_l" in model_name: 162 | model_kind = "vit_l" 163 | else: 164 | model_kind = "vit_b" 165 | 166 | sam = sam_model_registry[model_kind](checkpoint=modelname) 167 | size = os.path.getsize(modelname) 168 | safe_to = SafeToGPU(size) 169 | 170 | # Unless user explicitly wants to use CPU, we use GPU 171 | device = Device.get_torch_device() if device_mode == "Prefer GPU" else "CPU" 172 | 173 | if device_mode == "Prefer GPU": 174 | safe_to.to_device(sam, device) 175 | 176 | is_auto_mode = device_mode == "AUTO" 177 | 178 | sam_obj = SAMWrapper(sam, is_auto_mode=is_auto_mode, safe_to_gpu=safe_to) 179 | sam.sam_wrapper = sam_obj 180 | 181 | print(f"Loads SAM model: {modelname} (device:{device_mode})") 182 | return (sam,) 183 | 184 | 185 | def make_sam_mask( 186 | sam: SAMWrapper, 187 | segs: tuple, 188 | image: torch.Tensor, 189 | detection_hint: bool, 190 | dilation: int, 191 | threshold: float, 192 | bbox_expansion: int, 193 | mask_hint_threshold: float, 194 | mask_hint_use_negative: bool, 195 | ) -> torch.Tensor: 196 | """#### Create a SAM mask. 197 | 198 | #### Args: 199 | - `sam` (SAMWrapper): The SAM wrapper. 200 | - `segs` (tuple): Segmentation information. 201 | - `image` (torch.Tensor): The input image. 202 | - `detection_hint` (bool): Whether to use detection hint. 203 | - `dilation` (int): Dilation value. 204 | - `threshold` (float): Threshold for mask selection. 205 | - `bbox_expansion` (int): Bounding box expansion value. 206 | - `mask_hint_threshold` (float): Mask hint threshold. 207 | - `mask_hint_use_negative` (bool): Whether to use negative mask hint. 208 | 209 | #### Returns: 210 | - `torch.Tensor`: The created SAM mask. 
211 | """ 212 | sam_obj = sam.sam_wrapper 213 | sam_obj.prepare_device() 214 | 215 | try: 216 | image = np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) 217 | 218 | total_masks = [] 219 | # seg_shape = segs[0] 220 | segs = segs[1] 221 | for i in range(len(segs)): 222 | bbox = segs[i].bbox 223 | center = mask_util.center_of_bbox(bbox) 224 | x1 = max(bbox[0] - bbox_expansion, 0) 225 | y1 = max(bbox[1] - bbox_expansion, 0) 226 | x2 = min(bbox[2] + bbox_expansion, image.shape[1]) 227 | y2 = min(bbox[3] + bbox_expansion, image.shape[0]) 228 | dilated_bbox = [x1, y1, x2, y2] 229 | points = [] 230 | plabs = [] 231 | points.append(center) 232 | plabs = [1] # 1 = foreground point, 0 = background point 233 | detected_masks = sam_obj.predict( 234 | image, points, plabs, dilated_bbox, threshold 235 | ) 236 | total_masks += detected_masks 237 | 238 | # merge every collected masks 239 | mask = mask_util.combine_masks2(total_masks) 240 | 241 | finally: 242 | sam_obj.release_device() 243 | 244 | if mask is not None: 245 | mask = mask.float() 246 | mask = mask_util.dilate_mask(mask.cpu().numpy(), dilation) 247 | mask = torch.from_numpy(mask) 248 | 249 | mask = mask_util.make_3d_mask(mask) 250 | return mask 251 | else: 252 | return None 253 | 254 | 255 | class SAMDetectorCombined: 256 | """#### Class to combine SAM detection.""" 257 | 258 | def doit( 259 | self, 260 | sam_model: SAMWrapper, 261 | segs: tuple, 262 | image: torch.Tensor, 263 | detection_hint: bool, 264 | dilation: int, 265 | threshold: float, 266 | bbox_expansion: int, 267 | mask_hint_threshold: float, 268 | mask_hint_use_negative: bool, 269 | ) -> tuple: 270 | """#### Combine SAM detection. 271 | 272 | #### Args: 273 | - `sam_model` (SAMWrapper): The SAM wrapper. 274 | - `segs` (tuple): Segmentation information. 275 | - `image` (torch.Tensor): The input image. 276 | - `detection_hint` (bool): Whether to use detection hint. 277 | - `dilation` (int): Dilation value. 278 | - `threshold` (float): Threshold for mask selection. 279 | - `bbox_expansion` (int): Bounding box expansion value. 280 | - `mask_hint_threshold` (float): Mask hint threshold. 281 | - `mask_hint_use_negative` (bool): Whether to use negative mask hint. 282 | 283 | #### Returns: 284 | - `tuple`: The combined SAM detection result. 285 | """ 286 | sam = make_sam_mask( 287 | sam_model, 288 | segs, 289 | image, 290 | detection_hint, 291 | dilation, 292 | threshold, 293 | bbox_expansion, 294 | mask_hint_threshold, 295 | mask_hint_use_negative, 296 | ) 297 | if sam is not None: 298 | return (sam,) 299 | else: 300 | return None 301 | -------------------------------------------------------------------------------- /modules/AutoDetailer/SEGS.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import numpy as np 3 | import torch 4 | from modules.AutoDetailer import mask_util 5 | 6 | SEG = namedtuple( 7 | "SEG", 8 | [ 9 | "cropped_image", 10 | "cropped_mask", 11 | "confidence", 12 | "crop_region", 13 | "bbox", 14 | "label", 15 | "control_net_wrapper", 16 | ], 17 | defaults=[None], 18 | ) 19 | 20 | 21 | def segs_bitwise_and_mask(segs: tuple, mask: torch.Tensor) -> tuple: 22 | """#### Apply bitwise AND operation between segmentation masks and a given mask. 23 | 24 | #### Args: 25 | - `segs` (tuple): A tuple containing segmentation information. 26 | - `mask` (torch.Tensor): The mask tensor. 
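    #### Example (illustrative sketch; `segs` is the tuple produced by the detectors in this package and `mask` is a hypothetical mask tensor):

        shape, seg_list = segs  # (height, width) and a list of SEG namedtuples
        masked_shape, masked_items = segs_bitwise_and_mask(segs, mask)
        for seg in masked_items:
            print(seg.label, seg.confidence, seg.crop_region)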
27 | 28 | #### Returns: 29 | - `tuple`: A tuple containing the original segmentation and the updated items. 30 | """ 31 | mask = mask_util.make_2d_mask(mask) 32 | items = [] 33 | 34 | mask = (mask.cpu().numpy() * 255).astype(np.uint8) 35 | 36 | for seg in segs[1]: 37 | cropped_mask = (seg.cropped_mask * 255).astype(np.uint8) 38 | crop_region = seg.crop_region 39 | 40 | cropped_mask2 = mask[ 41 | crop_region[1] : crop_region[3], crop_region[0] : crop_region[2] 42 | ] 43 | 44 | new_mask = np.bitwise_and(cropped_mask.astype(np.uint8), cropped_mask2) 45 | new_mask = new_mask.astype(np.float32) / 255.0 46 | 47 | item = SEG( 48 | seg.cropped_image, 49 | new_mask, 50 | seg.confidence, 51 | seg.crop_region, 52 | seg.bbox, 53 | seg.label, 54 | None, 55 | ) 56 | items.append(item) 57 | 58 | return segs[0], items 59 | 60 | 61 | class SegsBitwiseAndMask: 62 | """#### Class to apply bitwise AND operation between segmentation masks and a given mask.""" 63 | 64 | def doit(self, segs: tuple, mask: torch.Tensor) -> tuple: 65 | """#### Apply bitwise AND operation between segmentation masks and a given mask. 66 | 67 | #### Args: 68 | - `segs` (tuple): A tuple containing segmentation information. 69 | - `mask` (torch.Tensor): The mask tensor. 70 | 71 | #### Returns: 72 | - `tuple`: A tuple containing the original segmentation and the updated items. 73 | """ 74 | return (segs_bitwise_and_mask(segs, mask),) 75 | 76 | 77 | class SEGSLabelFilter: 78 | """#### Class to filter segmentation labels.""" 79 | 80 | @staticmethod 81 | def filter(segs: tuple, labels: list) -> tuple: 82 | """#### Filter segmentation labels. 83 | 84 | #### Args: 85 | - `segs` (tuple): A tuple containing segmentation information. 86 | - `labels` (list): A list of labels to filter. 87 | 88 | #### Returns: 89 | - `tuple`: A tuple containing the original segmentation and an empty list. 90 | """ 91 | labels = set([label.strip() for label in labels]) 92 | return ( 93 | segs, 94 | (segs[0], []), 95 | ) 96 | -------------------------------------------------------------------------------- /modules/AutoDetailer/bbox.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from ultralytics import YOLO 3 | from modules.AutoDetailer import SEGS, AD_util, tensor_util 4 | from typing import List, Tuple, Optional 5 | 6 | 7 | class UltraBBoxDetector: 8 | """#### Class to detect bounding boxes using a YOLO model.""" 9 | 10 | bbox_model: Optional[YOLO] = None 11 | 12 | def __init__(self, bbox_model: YOLO): 13 | """#### Initialize the UltraBBoxDetector with a YOLO model. 14 | 15 | #### Args: 16 | - `bbox_model` (YOLO): The YOLO model to use for detection. 17 | """ 18 | self.bbox_model = bbox_model 19 | 20 | def detect( 21 | self, 22 | image: torch.Tensor, 23 | threshold: float, 24 | dilation: int, 25 | crop_factor: float, 26 | drop_size: int = 1, 27 | detailer_hook: Optional[callable] = None, 28 | ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]: 29 | """#### Detect bounding boxes in an image. 30 | 31 | #### Args: 32 | - `image` (torch.Tensor): The input image tensor. 33 | - `threshold` (float): The detection threshold. 34 | - `dilation` (int): The dilation factor for masks. 35 | - `crop_factor` (float): The crop factor for bounding boxes. 36 | - `drop_size` (int, optional): The minimum size of bounding boxes to keep. Defaults to 1. 37 | - `detailer_hook` (callable, optional): A hook function for additional processing. Defaults to None. 
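        #### Example (illustrative sketch; the weight file is the one fetched by the downloader, the thresholds are hypothetical):

            model = AD_util.load_yolo("./_internal/yolos/face_yolov9c.pt")
            detector = UltraBBoxDetector(model)
            segs = detector.detect(image, threshold=0.3, dilation=4,
                                   crop_factor=3.0, drop_size=10)
            print(f"{len(segs[1])} region(s) found in an image of shape {segs[0]}")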
38 | 39 | #### Returns: 40 | - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The shape of the image and a list of detected segments. 41 | """ 42 | drop_size = max(drop_size, 1) 43 | detected_results = AD_util.inference_bbox( 44 | self.bbox_model, tensor_util.tensor2pil(image), threshold 45 | ) 46 | segmasks = AD_util.create_segmasks(detected_results) 47 | 48 | if dilation > 0: 49 | segmasks = AD_util.dilate_masks(segmasks, dilation) 50 | 51 | items = [] 52 | h = image.shape[1] 53 | w = image.shape[2] 54 | 55 | for x, label in zip(segmasks, detected_results[0]): 56 | item_bbox = x[0] 57 | item_mask = x[1] 58 | 59 | y1, x1, y2, x2 = item_bbox 60 | 61 | if ( 62 | x2 - x1 > drop_size and y2 - y1 > drop_size 63 | ): # minimum dimension must be (2,2) to avoid squeeze issue 64 | crop_region = AD_util.make_crop_region(w, h, item_bbox, crop_factor) 65 | 66 | cropped_image = AD_util.crop_image(image, crop_region) 67 | cropped_mask = AD_util.crop_ndarray2(item_mask, crop_region) 68 | confidence = x[2] 69 | 70 | item = SEGS.SEG( 71 | cropped_image, 72 | cropped_mask, 73 | confidence, 74 | crop_region, 75 | item_bbox, 76 | label, 77 | None, 78 | ) 79 | 80 | items.append(item) 81 | 82 | shape = image.shape[1], image.shape[2] 83 | segs = shape, items 84 | 85 | return segs 86 | 87 | 88 | class UltraSegmDetector: 89 | """#### Class to detect segments using a YOLO model.""" 90 | 91 | bbox_model: Optional[YOLO] = None 92 | 93 | def __init__(self, bbox_model: YOLO): 94 | """#### Initialize the UltraSegmDetector with a YOLO model. 95 | 96 | #### Args: 97 | - `bbox_model` (YOLO): The YOLO model to use for detection. 98 | """ 99 | self.bbox_model = bbox_model 100 | 101 | 102 | class NO_SEGM_DETECTOR: 103 | """#### Placeholder class for no segment detector.""" 104 | 105 | pass 106 | 107 | 108 | class UltralyticsDetectorProvider: 109 | """#### Class to provide YOLO models for detection.""" 110 | 111 | def doit(self, model_name: str) -> Tuple[UltraBBoxDetector, UltraSegmDetector]: 112 | """#### Load a YOLO model and return detectors. 113 | 114 | #### Args: 115 | - `model_name` (str): The name of the YOLO model to load. 116 | 117 | #### Returns: 118 | - `Tuple[UltraBBoxDetector, UltraSegmDetector]`: The bounding box and segment detectors. 119 | """ 120 | model = AD_util.load_yolo("./_internal/yolos/" + model_name) 121 | return UltraBBoxDetector(model), UltraSegmDetector(model) 122 | 123 | 124 | class BboxDetectorForEach: 125 | """#### Class to detect bounding boxes for each segment.""" 126 | 127 | def doit( 128 | self, 129 | bbox_detector: UltraBBoxDetector, 130 | image: torch.Tensor, 131 | threshold: float, 132 | dilation: int, 133 | crop_factor: float, 134 | drop_size: int, 135 | labels: Optional[str] = None, 136 | detailer_hook: Optional[callable] = None, 137 | ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]: 138 | """#### Detect bounding boxes for each segment in an image. 139 | 140 | #### Args: 141 | - `bbox_detector` (UltraBBoxDetector): The bounding box detector. 142 | - `image` (torch.Tensor): The input image tensor. 143 | - `threshold` (float): The detection threshold. 144 | - `dilation` (int): The dilation factor for masks. 145 | - `crop_factor` (float): The crop factor for bounding boxes. 146 | - `drop_size` (int): The minimum size of bounding boxes to keep. 147 | - `labels` (str, optional): The labels to filter. Defaults to None. 148 | - `detailer_hook` (callable, optional): A hook function for additional processing. Defaults to None. 
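        #### Example (illustrative sketch; the YOLO weight name is the one fetched by the downloader, the thresholds are hypothetical):

            bbox_detector, _ = UltralyticsDetectorProvider().doit("face_yolov9c.pt")
            segs = BboxDetectorForEach().doit(
                bbox_detector, image, threshold=0.3, dilation=4,
                crop_factor=3.0, drop_size=10, labels="face",
            )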
149 | 150 | #### Returns: 151 | - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The shape of the image and a list of detected segments. 152 | """ 153 | segs = bbox_detector.detect( 154 | image, threshold, dilation, crop_factor, drop_size, detailer_hook 155 | ) 156 | 157 | if labels is not None and labels != "": 158 | labels = labels.split(",") 159 | if len(labels) > 0: 160 | segs, _ = SEGS.SEGSLabelFilter.filter(segs, labels) 161 | 162 | return segs 163 | 164 | 165 | class WildcardChooser: 166 | """#### Class to choose wildcards for segments.""" 167 | 168 | def __init__(self, items: List[Tuple[None, str]], randomize_when_exhaust: bool): 169 | """#### Initialize the WildcardChooser. 170 | 171 | #### Args: 172 | - `items` (List[Tuple[None, str]]): The list of items to choose from. 173 | - `randomize_when_exhaust` (bool): Whether to randomize when the list is exhausted. 174 | """ 175 | self.i = 0 176 | self.items = items 177 | self.randomize_when_exhaust = randomize_when_exhaust 178 | 179 | def get(self, seg: SEGS.SEG) -> Tuple[None, str]: 180 | """#### Get the next item from the list. 181 | 182 | #### Args: 183 | - `seg` (SEGS.SEG): The segment. 184 | 185 | #### Returns: 186 | - `Tuple[None, str]`: The next item from the list. 187 | """ 188 | item = self.items[self.i] 189 | self.i += 1 190 | 191 | return item 192 | 193 | 194 | def process_wildcard_for_segs(wildcard: str) -> Tuple[None, WildcardChooser]: 195 | """#### Process a wildcard for segments. 196 | 197 | #### Args: 198 | - `wildcard` (str): The wildcard. 199 | 200 | #### Returns: 201 | - `Tuple[None, WildcardChooser]`: The processed wildcard and a WildcardChooser. 202 | """ 203 | return None, WildcardChooser([(None, wildcard)], False) 204 | -------------------------------------------------------------------------------- /modules/AutoDetailer/mask_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def center_of_bbox(bbox: list) -> tuple[float, float]: 6 | """#### Calculate the center of a bounding box. 7 | 8 | #### Args: 9 | - `bbox` (list): The bounding box coordinates [x1, y1, x2, y2]. 10 | 11 | #### Returns: 12 | - `tuple[float, float]`: The center coordinates (x, y). 13 | """ 14 | w, h = bbox[2] - bbox[0], bbox[3] - bbox[1] 15 | return bbox[0] + w / 2, bbox[1] + h / 2 16 | 17 | 18 | def make_2d_mask(mask: torch.Tensor) -> torch.Tensor: 19 | """#### Convert a mask to 2D. 20 | 21 | #### Args: 22 | - `mask` (torch.Tensor): The input mask tensor. 23 | 24 | #### Returns: 25 | - `torch.Tensor`: The 2D mask tensor. 26 | """ 27 | if len(mask.shape) == 4: 28 | return mask.squeeze(0).squeeze(0) 29 | elif len(mask.shape) == 3: 30 | return mask.squeeze(0) 31 | return mask 32 | 33 | 34 | def combine_masks2(masks: list) -> torch.Tensor | None: 35 | """#### Combine multiple masks into one. 36 | 37 | #### Args: 38 | - `masks` (list): A list of mask tensors. 39 | 40 | #### Returns: 41 | - `torch.Tensor | None`: The combined mask tensor or None if no masks are provided. 42 | """ 43 | try: 44 | mask = torch.from_numpy(np.array(masks[0]).astype(np.uint8)) 45 | except: 46 | print("No Human Detected") 47 | return None 48 | return mask 49 | 50 | 51 | def dilate_mask( 52 | mask: torch.Tensor, dilation_factor: int, iter: int = 1 53 | ) -> torch.Tensor: 54 | """#### Dilate a mask. 55 | 56 | #### Args: 57 | - `mask` (torch.Tensor): The input mask tensor. 58 | - `dilation_factor` (int): The dilation factor. 59 | - `iter` (int, optional): The number of iterations. 
Defaults to 1. 60 | 61 | #### Returns: 62 | - `torch.Tensor`: The dilated mask tensor. 63 | """ 64 | return make_2d_mask(mask) 65 | 66 | 67 | def make_3d_mask(mask: torch.Tensor) -> torch.Tensor: 68 | """#### Convert a mask to 3D. 69 | 70 | #### Args: 71 | - `mask` (torch.Tensor): The input mask tensor. 72 | 73 | #### Returns: 74 | - `torch.Tensor`: The 3D mask tensor. 75 | """ 76 | if len(mask.shape) == 4: 77 | return mask.squeeze(0) 78 | elif len(mask.shape) == 2: 79 | return mask.unsqueeze(0) 80 | return mask 81 | -------------------------------------------------------------------------------- /modules/AutoDetailer/tensor_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | import torchvision 5 | 6 | from modules.Device import Device 7 | 8 | 9 | def _tensor_check_image(image: torch.Tensor) -> None: 10 | """#### Check if the input is a valid tensor image. 11 | 12 | #### Args: 13 | - `image` (torch.Tensor): The input tensor image. 14 | """ 15 | return 16 | 17 | 18 | def tensor2pil(image: torch.Tensor) -> Image.Image: 19 | """#### Convert a tensor to a PIL image. 20 | 21 | #### Args: 22 | - `image` (torch.Tensor): The input tensor. 23 | 24 | #### Returns: 25 | - `Image.Image`: The converted PIL image. 26 | """ 27 | _tensor_check_image(image) 28 | return Image.fromarray( 29 | np.clip(255.0 * image.cpu().numpy().squeeze(0), 0, 255).astype(np.uint8) 30 | ) 31 | 32 | 33 | def general_tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor: 34 | """#### Resize a tensor image using bilinear interpolation. 35 | 36 | #### Args: 37 | - `image` (torch.Tensor): The input tensor image. 38 | - `w` (int): The target width. 39 | - `h` (int): The target height. 40 | 41 | #### Returns: 42 | - `torch.Tensor`: The resized tensor image. 43 | """ 44 | _tensor_check_image(image) 45 | image = image.permute(0, 3, 1, 2) 46 | image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear") 47 | image = image.permute(0, 2, 3, 1) 48 | return image 49 | 50 | 51 | def pil2tensor(image: Image.Image) -> torch.Tensor: 52 | """#### Convert a PIL image to a tensor. 53 | 54 | #### Args: 55 | - `image` (Image.Image): The input PIL image. 56 | 57 | #### Returns: 58 | - `torch.Tensor`: The converted tensor. 59 | """ 60 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) 61 | 62 | 63 | class TensorBatchBuilder: 64 | """#### Class for building a batch of tensors.""" 65 | 66 | def __init__(self): 67 | self.tensor: torch.Tensor | None = None 68 | 69 | def concat(self, new_tensor: torch.Tensor) -> None: 70 | """#### Concatenate a new tensor to the batch. 71 | 72 | #### Args: 73 | - `new_tensor` (torch.Tensor): The new tensor to concatenate. 74 | """ 75 | self.tensor = new_tensor 76 | 77 | 78 | LANCZOS = Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.LANCZOS 79 | 80 | 81 | def tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor: 82 | """#### Resize a tensor image. 83 | 84 | #### Args: 85 | - `image` (torch.Tensor): The input tensor image. 86 | - `w` (int): The target width. 87 | - `h` (int): The target height. 88 | 89 | #### Returns: 90 | - `torch.Tensor`: The resized tensor image. 
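    #### Example (illustrative sketch; "example.png" is a hypothetical input file):

        batch = pil2tensor(Image.open("example.png").convert("RGB"))  # (1, H, W, 3), float32 in [0, 1]
        resized = tensor_resize(batch, 512, 768)                      # -> (1, 768, 512, 3)
        preview = tensor2pil(resized)                                 # back to a PIL.Image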
91 | """ 92 | _tensor_check_image(image) 93 | if image.shape[3] >= 3: 94 | scaled_images = TensorBatchBuilder() 95 | for single_image in image: 96 | single_image = single_image.unsqueeze(0) 97 | single_pil = tensor2pil(single_image) 98 | scaled_pil = single_pil.resize((w, h), resample=LANCZOS) 99 | 100 | single_image = pil2tensor(scaled_pil) 101 | scaled_images.concat(single_image) 102 | 103 | return scaled_images.tensor 104 | else: 105 | return general_tensor_resize(image, w, h) 106 | 107 | 108 | def tensor_paste( 109 | image1: torch.Tensor, 110 | image2: torch.Tensor, 111 | left_top: tuple[int, int], 112 | mask: torch.Tensor, 113 | ) -> None: 114 | """#### Paste one tensor image onto another using a mask. 115 | 116 | #### Args: 117 | - `image1` (torch.Tensor): The base tensor image. 118 | - `image2` (torch.Tensor): The tensor image to paste. 119 | - `left_top` (tuple[int, int]): The top-left corner where the image2 will be pasted. 120 | - `mask` (torch.Tensor): The mask tensor. 121 | """ 122 | _tensor_check_image(image1) 123 | _tensor_check_image(image2) 124 | _tensor_check_mask(mask) 125 | 126 | x, y = left_top 127 | _, h1, w1, _ = image1.shape 128 | _, h2, w2, _ = image2.shape 129 | 130 | # calculate image patch size 131 | w = min(w1, x + w2) - x 132 | h = min(h1, y + h2) - y 133 | 134 | mask = mask[:, :h, :w, :] 135 | image1[:, y : y + h, x : x + w, :] = (1 - mask) * image1[ 136 | :, y : y + h, x : x + w, : 137 | ] + mask * image2[:, :h, :w, :] 138 | return 139 | 140 | 141 | def tensor_convert_rgba(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor: 142 | """#### Convert a tensor image to RGBA format. 143 | 144 | #### Args: 145 | - `image` (torch.Tensor): The input tensor image. 146 | - `prefer_copy` (bool, optional): Whether to prefer copying the tensor. Defaults to True. 147 | 148 | #### Returns: 149 | - `torch.Tensor`: The converted RGBA tensor image. 150 | """ 151 | _tensor_check_image(image) 152 | alpha = torch.ones((*image.shape[:-1], 1)) 153 | return torch.cat((image, alpha), axis=-1) 154 | 155 | 156 | def tensor_convert_rgb(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor: 157 | """#### Convert a tensor image to RGB format. 158 | 159 | #### Args: 160 | - `image` (torch.Tensor): The input tensor image. 161 | - `prefer_copy` (bool, optional): Whether to prefer copying the tensor. Defaults to True. 162 | 163 | #### Returns: 164 | - `torch.Tensor`: The converted RGB tensor image. 165 | """ 166 | _tensor_check_image(image) 167 | return image 168 | 169 | 170 | def tensor_get_size(image: torch.Tensor) -> tuple[int, int]: 171 | """#### Get the size of a tensor image. 172 | 173 | #### Args: 174 | - `image` (torch.Tensor): The input tensor image. 175 | 176 | #### Returns: 177 | - `tuple[int, int]`: The width and height of the tensor image. 178 | """ 179 | _tensor_check_image(image) 180 | _, h, w, _ = image.shape 181 | return (w, h) 182 | 183 | 184 | def tensor_putalpha(image: torch.Tensor, mask: torch.Tensor) -> None: 185 | """#### Add an alpha channel to a tensor image using a mask. 186 | 187 | #### Args: 188 | - `image` (torch.Tensor): The input tensor image. 189 | - `mask` (torch.Tensor): The mask tensor. 190 | """ 191 | _tensor_check_image(image) 192 | _tensor_check_mask(mask) 193 | image[..., -1] = mask[..., 0] 194 | 195 | 196 | def _tensor_check_mask(mask: torch.Tensor) -> None: 197 | """#### Check if the input is a valid tensor mask. 198 | 199 | #### Args: 200 | - `mask` (torch.Tensor): The input tensor mask. 
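    Masks in this module are expected in channel-last batch layout,
    i.e. `(batch, height, width, 1)` with float values in [0, 1], which is how
    `tensor_paste` and `tensor_putalpha` index them.

    #### Example (illustrative sketch):

        mask = torch.ones((1, 64, 64, 1), dtype=torch.float32)
        _tensor_check_mask(mask)  # currently a no-op placeholder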
201 | """ 202 | return 203 | 204 | 205 | def tensor_gaussian_blur_mask( 206 | mask: torch.Tensor | np.ndarray, kernel_size: int, sigma: float = 10.0 207 | ) -> torch.Tensor: 208 | """#### Apply Gaussian blur to a tensor mask. 209 | 210 | #### Args: 211 | - `mask` (torch.Tensor | np.ndarray): The input tensor mask. 212 | - `kernel_size` (int): The size of the Gaussian kernel. 213 | - `sigma` (float, optional): The standard deviation of the Gaussian kernel. Defaults to 10.0. 214 | 215 | #### Returns: 216 | - `torch.Tensor`: The blurred tensor mask. 217 | """ 218 | if isinstance(mask, np.ndarray): 219 | mask = torch.from_numpy(mask) 220 | 221 | if mask.ndim == 2: 222 | mask = mask[None, ..., None] 223 | 224 | _tensor_check_mask(mask) 225 | 226 | kernel_size = kernel_size * 2 + 1 227 | 228 | prev_device = mask.device 229 | device = Device.get_torch_device() 230 | mask.to(device) 231 | 232 | # apply gaussian blur 233 | mask = mask[:, None, ..., 0] 234 | blurred_mask = torchvision.transforms.GaussianBlur( 235 | kernel_size=kernel_size, sigma=sigma 236 | )(mask) 237 | blurred_mask = blurred_mask[:, 0, ..., None] 238 | 239 | blurred_mask.to(prev_device) 240 | 241 | return blurred_mask 242 | 243 | 244 | def to_tensor(image: np.ndarray) -> torch.Tensor: 245 | """#### Convert a numpy array to a tensor. 246 | 247 | #### Args: 248 | - `image` (np.ndarray): The input numpy array. 249 | 250 | #### Returns: 251 | - `torch.Tensor`: The converted tensor. 252 | """ 253 | return torch.from_numpy(image) 254 | -------------------------------------------------------------------------------- /modules/AutoHDR/ahdr.py: -------------------------------------------------------------------------------- 1 | # Taken and adapted from https://github.com/SuperBeastsAI/ComfyUI-SuperBeasts 2 | 3 | import numpy as np 4 | from PIL import Image, ImageOps, ImageDraw, ImageFilter, ImageEnhance, ImageCms 5 | from PIL.PngImagePlugin import PngInfo 6 | import torch 7 | import torch.nn.functional as F 8 | import json 9 | import random 10 | 11 | 12 | sRGB_profile = ImageCms.createProfile("sRGB") 13 | Lab_profile = ImageCms.createProfile("LAB") 14 | 15 | # Tensor to PIL 16 | def tensor2pil(image): 17 | return Image.fromarray(np.clip(255. 
* image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)) 18 | 19 | # PIL to Tensor 20 | def pil2tensor(image): 21 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0) 22 | 23 | def adjust_shadows_non_linear(luminance, shadow_intensity, max_shadow_adjustment=1.5): 24 | lum_array = np.array(luminance, dtype=np.float32) / 255.0 # Normalize 25 | # Apply a non-linear darkening effect based on shadow_intensity 26 | shadows = lum_array ** (1 / (1 + shadow_intensity * max_shadow_adjustment)) 27 | return np.clip(shadows * 255, 0, 255).astype(np.uint8) # Re-scale to [0, 255] 28 | 29 | def adjust_highlights_non_linear(luminance, highlight_intensity, max_highlight_adjustment=1.5): 30 | lum_array = np.array(luminance, dtype=np.float32) / 255.0 # Normalize 31 | # Brighten highlights more aggressively based on highlight_intensity 32 | highlights = 1 - (1 - lum_array) ** (1 + highlight_intensity * max_highlight_adjustment) 33 | return np.clip(highlights * 255, 0, 255).astype(np.uint8) # Re-scale to [0, 255] 34 | 35 | def merge_adjustments_with_blend_modes(luminance, shadows, highlights, hdr_intensity, shadow_intensity, highlight_intensity): 36 | # Ensure the data is in the correct format for processing 37 | base = np.array(luminance, dtype=np.float32) 38 | 39 | # Scale the adjustments based on hdr_intensity 40 | scaled_shadow_intensity = shadow_intensity ** 2 * hdr_intensity 41 | scaled_highlight_intensity = highlight_intensity ** 2 * hdr_intensity 42 | 43 | # Create luminance-based masks for shadows and highlights 44 | shadow_mask = np.clip((1 - (base / 255)) ** 2, 0, 1) 45 | highlight_mask = np.clip((base / 255) ** 2, 0, 1) 46 | 47 | # Apply the adjustments using the masks 48 | adjusted_shadows = np.clip(base * (1 - shadow_mask * scaled_shadow_intensity), 0, 255) 49 | adjusted_highlights = np.clip(base + (255 - base) * highlight_mask * scaled_highlight_intensity, 0, 255) 50 | 51 | # Combine the adjusted shadows and highlights 52 | adjusted_luminance = np.clip(adjusted_shadows + adjusted_highlights - base, 0, 255) 53 | 54 | # Blend the adjusted luminance with the original luminance based on hdr_intensity 55 | final_luminance = np.clip(base * (1 - hdr_intensity) + adjusted_luminance * hdr_intensity, 0, 255).astype(np.uint8) 56 | 57 | return Image.fromarray(final_luminance) 58 | 59 | def apply_gamma_correction(lum_array, gamma): 60 | """ 61 | Apply gamma correction to the luminance array. 62 | :param lum_array: Luminance channel as a NumPy array. 63 | :param gamma: Gamma value for correction. 
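    :return: Gamma-corrected luminance as a uint8 NumPy array.

    Worked example (values rounded, using the default gamma_intensity of 0.25
    from apply_hdr2): with gamma=0.25 the exponent becomes
    1 / (1.1 - 0.25) ≈ 1.18, so a mid-gray value of 128 maps to
    255 * (128 / 255) ** 1.18 ≈ 113, i.e. midtones are darkened slightly;
    gamma=0 returns the luminance unchanged apart from the clip to uint8.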
64 | """ 65 | if gamma == 0: 66 | return np.clip(lum_array, 0, 255).astype(np.uint8) 67 | 68 | epsilon = 1e-7 # Small value to avoid dividing by zero 69 | gamma_corrected = 1 / (1.1 - gamma) 70 | adjusted = 255 * ((lum_array / 255) ** gamma_corrected) 71 | return np.clip(adjusted, 0, 255).astype(np.uint8) 72 | 73 | # create a wrapper function that can apply a function to multiple images in a batch while passing all other arguments to the function 74 | def apply_to_batch(func): 75 | def wrapper(self, image, *args, **kwargs): 76 | images = [] 77 | for img in image: 78 | images.append(func(self, img, *args, **kwargs)) 79 | batch_tensor = torch.cat(images, dim=0) 80 | return (batch_tensor, ) 81 | return wrapper 82 | 83 | class HDREffects: 84 | @apply_to_batch 85 | def apply_hdr2(self, image, hdr_intensity=0.75, shadow_intensity=0.25, highlight_intensity=0.5, gamma_intensity=0.25, contrast=0.1, enhance_color=0.25): 86 | # Load the image 87 | img = tensor2pil(image) 88 | 89 | # Step 1: Convert RGB to LAB for better color preservation 90 | img_lab = ImageCms.profileToProfile(img, sRGB_profile, Lab_profile, outputMode='LAB') 91 | 92 | # Extract L, A, and B channels 93 | luminance, a, b = img_lab.split() 94 | 95 | # Convert luminance to a NumPy array for processing 96 | lum_array = np.array(luminance, dtype=np.float32) 97 | 98 | # Preparing adjustment layers (shadows, midtones, highlights) 99 | # This example assumes you have methods to extract or calculate these adjustments 100 | shadows_adjusted = adjust_shadows_non_linear(luminance, shadow_intensity) 101 | highlights_adjusted = adjust_highlights_non_linear(luminance, highlight_intensity) 102 | 103 | 104 | merged_adjustments = merge_adjustments_with_blend_modes(lum_array, shadows_adjusted, highlights_adjusted, hdr_intensity, shadow_intensity, highlight_intensity) 105 | 106 | # Apply gamma correction with a base_gamma value (define based on desired effect) 107 | gamma_corrected = apply_gamma_correction(np.array(merged_adjustments), gamma_intensity) 108 | gamma_corrected = Image.fromarray(gamma_corrected).resize(a.size) 109 | 110 | 111 | # Merge L channel back with original A and B channels 112 | adjusted_lab = Image.merge('LAB', (gamma_corrected, a, b)) 113 | 114 | # Step 3: Convert LAB back to RGB 115 | img_adjusted = ImageCms.profileToProfile(adjusted_lab, Lab_profile, sRGB_profile, outputMode='RGB') 116 | 117 | 118 | # Enhance contrast 119 | enhancer = ImageEnhance.Contrast(img_adjusted) 120 | contrast_adjusted = enhancer.enhance(1 + contrast) 121 | 122 | 123 | # Enhance color saturation 124 | enhancer = ImageEnhance.Color(contrast_adjusted) 125 | color_adjusted = enhancer.enhance(1 + enhance_color * 0.2) 126 | 127 | return pil2tensor(color_adjusted) -------------------------------------------------------------------------------- /modules/FileManaging/Downloader.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from huggingface_hub import hf_hub_download 3 | 4 | 5 | def CheckAndDownload(): 6 | """#### Check and download all the necessary safetensors and checkpoints models""" 7 | if glob.glob("./_internal/checkpoints/*.safetensors") == []: 8 | 9 | hf_hub_download( 10 | repo_id="Meina/MeinaMix", 11 | filename="Meina V10 - baked VAE.safetensors", 12 | local_dir="./_internal/checkpoints/", 13 | ) 14 | hf_hub_download( 15 | repo_id="Lykon/DreamShaper", 16 | filename="DreamShaper_8_pruned.safetensors", 17 | local_dir="./_internal/checkpoints/", 18 | ) 19 | if glob.glob("./_internal/yolos/*.pt") 
== []: 20 | 21 | hf_hub_download( 22 | repo_id="Bingsu/adetailer", 23 | filename="hand_yolov9c.pt", 24 | local_dir="./_internal/yolos/", 25 | ) 26 | hf_hub_download( 27 | repo_id="Bingsu/adetailer", 28 | filename="face_yolov9c.pt", 29 | local_dir="./_internal/yolos/", 30 | ) 31 | hf_hub_download( 32 | repo_id="Bingsu/adetailer", 33 | filename="person_yolov8m-seg.pt", 34 | local_dir="./_internal/yolos/", 35 | ) 36 | hf_hub_download( 37 | repo_id="segments-arnaud/sam_vit_b", 38 | filename="sam_vit_b_01ec64.pth", 39 | local_dir="./_internal/yolos/", 40 | ) 41 | if glob.glob("./_internal/ESRGAN/*.pth") == []: 42 | 43 | hf_hub_download( 44 | repo_id="lllyasviel/Annotators", 45 | filename="RealESRGAN_x4plus.pth", 46 | local_dir="./_internal/ESRGAN/", 47 | ) 48 | if glob.glob("./_internal/loras/*.safetensors") == []: 49 | 50 | hf_hub_download( 51 | repo_id="EvilEngine/add_detail", 52 | filename="add_detail.safetensors", 53 | local_dir="./_internal/loras/", 54 | ) 55 | if glob.glob("./_internal/embeddings/*.pt") == []: 56 | 57 | hf_hub_download( 58 | repo_id="EvilEngine/badhandv4", 59 | filename="badhandv4.pt", 60 | local_dir="./_internal/embeddings/", 61 | ) 62 | # hf_hub_download( 63 | # repo_id="segments-arnaud/sam_vit_b", 64 | # filename="EasyNegative.safetensors", 65 | # local_dir="./_internal/embeddings/", 66 | # ) 67 | if glob.glob("./_internal/vae_approx/*.pth") == []: 68 | 69 | hf_hub_download( 70 | repo_id="madebyollin/taesd", 71 | filename="taesd_decoder.safetensors", 72 | local_dir="./_internal/vae_approx/", 73 | ) 74 | 75 | def CheckAndDownloadFlux(): 76 | """#### Check and download all the necessary safetensors and checkpoints models for FLUX""" 77 | if glob.glob("./_internal/embeddings/*.pt") == []: 78 | hf_hub_download( 79 | repo_id="EvilEngine/badhandv4", 80 | filename="badhandv4.pt", 81 | local_dir="./_internal/embeddings", 82 | ) 83 | if glob.glob("./_internal/unet/*.gguf") == []: 84 | 85 | hf_hub_download( 86 | repo_id="city96/FLUX.1-dev-gguf", 87 | filename="flux1-dev-Q8_0.gguf", 88 | local_dir="./_internal/unet", 89 | ) 90 | if glob.glob("./_internal/clip/*.gguf") == []: 91 | 92 | hf_hub_download( 93 | repo_id="city96/t5-v1_1-xxl-encoder-gguf", 94 | filename="t5-v1_1-xxl-encoder-Q8_0.gguf", 95 | local_dir="./_internal/clip", 96 | ) 97 | hf_hub_download( 98 | repo_id="comfyanonymous/flux_text_encoders", 99 | filename="clip_l.safetensors", 100 | local_dir="./_internal/clip", 101 | ) 102 | if glob.glob("./_internal/vae/*.safetensors") == []: 103 | 104 | hf_hub_download( 105 | repo_id="black-forest-labs/FLUX.1-schnell", 106 | filename="ae.safetensors", 107 | local_dir="./_internal/vae", 108 | ) 109 | 110 | if glob.glob("./_internal/vae_approx/*.pth") == []: 111 | 112 | hf_hub_download( 113 | repo_id="madebyollin/taef1", 114 | filename="diffusion_pytorch_model.safetensors", 115 | local_dir="./_internal/vae_approx/", 116 | ) 117 | -------------------------------------------------------------------------------- /modules/FileManaging/ImageSaver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | 5 | output_directory = "./_internal/output" 6 | 7 | 8 | def get_output_directory() -> str: 9 | """#### Get the output directory. 10 | 11 | #### Returns: 12 | - `str`: The output directory. 
13 | """ 14 | global output_directory 15 | return output_directory 16 | 17 | 18 | def get_save_image_path( 19 | filename_prefix: str, output_dir: str, image_width: int = 0, image_height: int = 0 20 | ) -> tuple: 21 | """#### Get the save image path. 22 | 23 | #### Args: 24 | - `filename_prefix` (str): The filename prefix. 25 | - `output_dir` (str): The output directory. 26 | - `image_width` (int, optional): The image width. Defaults to 0. 27 | - `image_height` (int, optional): The image height. Defaults to 0. 28 | 29 | #### Returns: 30 | - `tuple`: The full output folder, filename, counter, subfolder, and filename prefix. 31 | """ 32 | 33 | def map_filename(filename: str) -> tuple: 34 | prefix_len = len(os.path.basename(filename_prefix)) 35 | prefix = filename[: prefix_len + 1] 36 | try: 37 | digits = int(filename[prefix_len + 1 :].split("_")[0]) 38 | except: 39 | digits = 0 40 | return (digits, prefix) 41 | 42 | def compute_vars(input: str, image_width: int, image_height: int) -> str: 43 | input = input.replace("%width%", str(image_width)) 44 | input = input.replace("%height%", str(image_height)) 45 | return input 46 | 47 | filename_prefix = compute_vars(filename_prefix, image_width, image_height) 48 | 49 | subfolder = os.path.dirname(os.path.normpath(filename_prefix)) 50 | filename = os.path.basename(os.path.normpath(filename_prefix)) 51 | 52 | full_output_folder = os.path.join(output_dir, subfolder) 53 | subfolder_paths = [ 54 | os.path.join(full_output_folder, x) 55 | for x in ["Classic", "HiresFix", "Img2Img", "Flux", "Adetailer"] 56 | ] 57 | for path in subfolder_paths: 58 | os.makedirs(path, exist_ok=True) 59 | # Find highest counter across all subfolders 60 | counter = 1 61 | for path in subfolder_paths: 62 | if os.path.exists(path): 63 | files = os.listdir(path) 64 | if files: 65 | numbers = [ 66 | map_filename(f)[0] 67 | for f in files 68 | if f.startswith(filename) and f.endswith(".png") 69 | ] 70 | if numbers: 71 | counter = max(max(numbers) + 1, counter) 72 | 73 | return full_output_folder, filename, counter, subfolder, filename_prefix 74 | 75 | 76 | MAX_RESOLUTION = 16384 77 | 78 | 79 | class SaveImage: 80 | """#### Class for saving images.""" 81 | 82 | def __init__(self): 83 | """#### Initialize the SaveImage class.""" 84 | self.output_dir = get_output_directory() 85 | self.type = "output" 86 | self.prefix_append = "" 87 | self.compress_level = 4 88 | 89 | def save_images( 90 | self, 91 | images: list, 92 | filename_prefix: str = "LD", 93 | prompt: str = None, 94 | extra_pnginfo: dict = None, 95 | ) -> dict: 96 | """#### Save images to the output directory. 97 | 98 | #### Args: 99 | - `images` (list): The list of images. 100 | - `filename_prefix` (str, optional): The filename prefix. Defaults to "LD". 101 | - `prompt` (str, optional): The prompt. Defaults to None. 102 | - `extra_pnginfo` (dict, optional): Additional PNG info. Defaults to None. 103 | 104 | #### Returns: 105 | - `dict`: The saved images information. 
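        #### Example (illustrative sketch; `images` is a hypothetical batch of H x W x 3 tensors with values in [0, 1]):

            saver = SaveImage()
            result = saver.save_images(images, filename_prefix="LD-HF")  # routed to the HiresFix subfolder
            print(result["ui"]["images"][0]["filename"])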
106 | """ 107 | filename_prefix += self.prefix_append 108 | full_output_folder, filename, counter, subfolder, filename_prefix = ( 109 | get_save_image_path( 110 | filename_prefix, self.output_dir, images[0].shape[-2], images[0].shape[-1] 111 | ) 112 | ) 113 | results = list() 114 | for batch_number, image in enumerate(images): 115 | # Ensure correct shape by squeezing extra dimensions 116 | i = 255.0 * image.cpu().numpy() 117 | i = np.squeeze(i) # Remove extra dimensions 118 | 119 | # Ensure we have a valid 3D array (height, width, channels) 120 | if i.ndim == 4: 121 | i = i.reshape(-1, i.shape[-2], i.shape[-1]) 122 | 123 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) 124 | metadata = None 125 | 126 | filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) 127 | file = f"{filename_with_batch_num}_{counter:05}_.png" 128 | if filename_prefix == "LD-HF": 129 | full_output_folder = os.path.join(full_output_folder, "HiresFix") 130 | elif filename_prefix == "LD-I2I": 131 | full_output_folder = os.path.join(full_output_folder, "Img2Img") 132 | elif filename_prefix == "LD-Flux": 133 | full_output_folder = os.path.join(full_output_folder, "Flux") 134 | elif filename_prefix == "LD-head" or filename_prefix == "LD-body": 135 | full_output_folder = os.path.join(full_output_folder, "Adetailer") 136 | else: 137 | full_output_folder = os.path.join(full_output_folder, "Classic") 138 | img.save( 139 | os.path.join(full_output_folder, file), 140 | pnginfo=metadata, 141 | compress_level=self.compress_level, 142 | ) 143 | results.append( 144 | {"filename": file, "subfolder": subfolder, "type": self.type} 145 | ) 146 | counter += 1 147 | 148 | return {"ui": {"images": results}} -------------------------------------------------------------------------------- /modules/FileManaging/Loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from modules.Utilities import util 4 | from modules.AutoEncoders import VariationalAE 5 | from modules.Device import Device 6 | from modules.Model import ModelPatcher 7 | from modules.NeuralNetwork import unet 8 | from modules.clip import Clip 9 | 10 | 11 | def load_checkpoint_guess_config( 12 | ckpt_path: str, 13 | output_vae: bool = True, 14 | output_clip: bool = True, 15 | output_clipvision: bool = False, 16 | embedding_directory: str = None, 17 | output_model: bool = True, 18 | ) -> tuple: 19 | """#### Load a checkpoint and guess the configuration. 20 | 21 | #### Args: 22 | - `ckpt_path` (str): The path to the checkpoint file. 23 | - `output_vae` (bool, optional): Whether to output the VAE. Defaults to True. 24 | - `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True. 25 | - `output_clipvision` (bool, optional): Whether to output the CLIP vision. Defaults to False. 26 | - `embedding_directory` (str, optional): The embedding directory. Defaults to None. 27 | - `output_model` (bool, optional): Whether to output the model. Defaults to True. 28 | 29 | #### Returns: 30 | - `tuple`: The model patcher, CLIP, VAE, and CLIP vision. 
31 | """ 32 | sd = util.load_torch_file(ckpt_path) 33 | sd.keys() 34 | clip = None 35 | clipvision = None 36 | vae = None 37 | model = None 38 | model_patcher = None 39 | clip_target = None 40 | 41 | parameters = util.calculate_parameters(sd, "model.diffusion_model.") 42 | load_device = Device.get_torch_device() 43 | 44 | model_config = unet.model_config_from_unet(sd, "model.diffusion_model.") 45 | unet_dtype = unet.unet_dtype1( 46 | model_params=parameters, 47 | supported_dtypes=model_config.supported_inference_dtypes, 48 | ) 49 | manual_cast_dtype = Device.unet_manual_cast( 50 | unet_dtype, load_device, model_config.supported_inference_dtypes 51 | ) 52 | model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) 53 | 54 | if output_model: 55 | inital_load_device = Device.unet_inital_load_device(parameters, unet_dtype) 56 | Device.unet_offload_device() 57 | model = model_config.get_model( 58 | sd, "model.diffusion_model.", device=inital_load_device 59 | ) 60 | model.load_model_weights(sd, "model.diffusion_model.") 61 | 62 | if output_vae: 63 | vae_sd = util.state_dict_prefix_replace( 64 | sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True 65 | ) 66 | vae_sd = model_config.process_vae_state_dict(vae_sd) 67 | vae = VariationalAE.VAE(sd=vae_sd) 68 | 69 | if output_clip: 70 | clip_target = model_config.clip_target() 71 | if clip_target is not None: 72 | clip_sd = model_config.process_clip_state_dict(sd) 73 | if len(clip_sd) > 0: 74 | clip = Clip.CLIP(clip_target, embedding_directory=embedding_directory) 75 | m, u = clip.load_sd(clip_sd, full_model=True) 76 | if len(m) > 0: 77 | m_filter = list( 78 | filter( 79 | lambda a: ".logit_scale" not in a 80 | and ".transformer.text_projection.weight" not in a, 81 | m, 82 | ) 83 | ) 84 | if len(m_filter) > 0: 85 | logging.warning("clip missing: {}".format(m)) 86 | else: 87 | logging.debug("clip missing: {}".format(m)) 88 | 89 | if len(u) > 0: 90 | logging.debug("clip unexpected {}:".format(u)) 91 | else: 92 | logging.warning( 93 | "no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded." 94 | ) 95 | 96 | left_over = sd.keys() 97 | if len(left_over) > 0: 98 | logging.debug("left over keys: {}".format(left_over)) 99 | 100 | if output_model: 101 | model_patcher = ModelPatcher.ModelPatcher( 102 | model, 103 | load_device=load_device, 104 | offload_device=Device.unet_offload_device(), 105 | current_device=inital_load_device, 106 | ) 107 | if inital_load_device != torch.device("cpu"): 108 | logging.info("loaded straight to GPU") 109 | Device.load_model_gpu(model_patcher) 110 | 111 | return (model_patcher, clip, vae, clipvision) 112 | 113 | 114 | class CheckpointLoaderSimple: 115 | """#### Class for loading checkpoints.""" 116 | 117 | def load_checkpoint( 118 | self, ckpt_name: str, output_vae: bool = True, output_clip: bool = True 119 | ) -> tuple: 120 | """#### Load a checkpoint. 121 | 122 | #### Args: 123 | - `ckpt_name` (str): The name of the checkpoint. 124 | - `output_vae` (bool, optional): Whether to output the VAE. Defaults to True. 125 | - `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True. 126 | 127 | #### Returns: 128 | - `tuple`: The model patcher, CLIP, and VAE. 
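        #### Example (illustrative sketch; the checkpoint is one of the files fetched by the downloader):

            loader = CheckpointLoaderSimple()
            model_patcher, clip, vae = loader.load_checkpoint(
                "./_internal/checkpoints/Meina V10 - baked VAE.safetensors"
            )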
129 | """ 130 | ckpt_path = f"{ckpt_name}" 131 | out = load_checkpoint_guess_config( 132 | ckpt_path, 133 | output_vae=output_vae, 134 | output_clip=output_clip, 135 | embedding_directory="./_internal/embeddings/", 136 | ) 137 | print("loading", ckpt_path) 138 | return out[:3] 139 | -------------------------------------------------------------------------------- /modules/Model/LoRas.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.Utilities import util 3 | from modules.NeuralNetwork import unet 4 | 5 | LORA_CLIP_MAP = { 6 | "mlp.fc1": "mlp_fc1", 7 | "mlp.fc2": "mlp_fc2", 8 | "self_attn.k_proj": "self_attn_k_proj", 9 | "self_attn.q_proj": "self_attn_q_proj", 10 | "self_attn.v_proj": "self_attn_v_proj", 11 | "self_attn.out_proj": "self_attn_out_proj", 12 | } 13 | 14 | 15 | def load_lora(lora: dict, to_load: dict) -> dict: 16 | """#### Load a LoRA model. 17 | 18 | #### Args: 19 | - `lora` (dict): The LoRA model state dictionary. 20 | - `to_load` (dict): The keys to load from the LoRA model. 21 | 22 | #### Returns: 23 | - `dict`: The loaded LoRA model. 24 | """ 25 | patch_dict = {} 26 | loaded_keys = set() 27 | for x in to_load: 28 | alpha_name = "{}.alpha".format(x) 29 | alpha = None 30 | if alpha_name in lora.keys(): 31 | alpha = lora[alpha_name].item() 32 | loaded_keys.add(alpha_name) 33 | 34 | "{}.dora_scale".format(x) 35 | dora_scale = None 36 | 37 | regular_lora = "{}.lora_up.weight".format(x) 38 | "{}_lora.up.weight".format(x) 39 | "{}.lora_linear_layer.up.weight".format(x) 40 | A_name = None 41 | 42 | if regular_lora in lora.keys(): 43 | A_name = regular_lora 44 | B_name = "{}.lora_down.weight".format(x) 45 | "{}.lora_mid.weight".format(x) 46 | 47 | if A_name is not None: 48 | mid = None 49 | patch_dict[to_load[x]] = ( 50 | "lora", 51 | (lora[A_name], lora[B_name], alpha, mid, dora_scale), 52 | ) 53 | loaded_keys.add(A_name) 54 | loaded_keys.add(B_name) 55 | return patch_dict 56 | 57 | 58 | def model_lora_keys_clip(model: torch.nn.Module, key_map: dict = {}) -> dict: 59 | """#### Get the keys for a LoRA model's CLIP component. 60 | 61 | #### Args: 62 | - `model` (torch.nn.Module): The LoRA model. 63 | - `key_map` (dict, optional): The key map. Defaults to {}. 64 | 65 | #### Returns: 66 | - `dict`: The keys for the CLIP component. 67 | """ 68 | sdk = model.state_dict().keys() 69 | 70 | text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" 71 | for b in range(32): 72 | for c in LORA_CLIP_MAP: 73 | k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c) 74 | if k in sdk: 75 | lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c]) 76 | key_map[lora_key] = k 77 | lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format( 78 | b, LORA_CLIP_MAP[c] 79 | ) # SDXL base 80 | key_map[lora_key] = k 81 | lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format( 82 | b, c 83 | ) # diffusers lora 84 | key_map[lora_key] = k 85 | return key_map 86 | 87 | 88 | def model_lora_keys_unet(model: torch.nn.Module, key_map: dict = {}) -> dict: 89 | """#### Get the keys for a LoRA model's UNet component. 90 | 91 | #### Args: 92 | - `model` (torch.nn.Module): The LoRA model. 93 | - `key_map` (dict, optional): The key map. Defaults to {}. 94 | 95 | #### Returns: 96 | - `dict`: The keys for the UNet component. 
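    #### Example (illustrative; the exact layer name is hypothetical):
    A UNet weight named `diffusion_model.input_blocks.1.1.proj_in.weight` is
    exposed to LoRA files under the flattened key
    `lora_unet_input_blocks_1_1_proj_in`: the `diffusion_model.` prefix and the
    `.weight` suffix are stripped and the remaining dots are replaced by
    underscores.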
97 | """ 98 | sdk = model.state_dict().keys() 99 | 100 | for k in sdk: 101 | if k.startswith("diffusion_model.") and k.endswith(".weight"): 102 | key_lora = k[len("diffusion_model.") : -len(".weight")].replace(".", "_") 103 | key_map["lora_unet_{}".format(key_lora)] = k 104 | key_map["lora_prior_unet_{}".format(key_lora)] = k # cascade lora: 105 | 106 | diffusers_keys = unet.unet_to_diffusers(model.model_config.unet_config) 107 | for k in diffusers_keys: 108 | if k.endswith(".weight"): 109 | unet_key = "diffusion_model.{}".format(diffusers_keys[k]) 110 | key_lora = k[: -len(".weight")].replace(".", "_") 111 | key_map["lora_unet_{}".format(key_lora)] = unet_key 112 | 113 | diffusers_lora_prefix = ["", "unet."] 114 | for p in diffusers_lora_prefix: 115 | diffusers_lora_key = "{}{}".format( 116 | p, k[: -len(".weight")].replace(".to_", ".processor.to_") 117 | ) 118 | if diffusers_lora_key.endswith(".to_out.0"): 119 | diffusers_lora_key = diffusers_lora_key[:-2] 120 | key_map[diffusers_lora_key] = unet_key 121 | return key_map 122 | 123 | 124 | def load_lora_for_models( 125 | model: object, clip: object, lora: dict, strength_model: float, strength_clip: float 126 | ) -> tuple: 127 | """#### Load a LoRA model for the given models. 128 | 129 | #### Args: 130 | - `model` (object): The model. 131 | - `clip` (object): The CLIP model. 132 | - `lora` (dict): The LoRA model state dictionary. 133 | - `strength_model` (float): The strength of the model. 134 | - `strength_clip` (float): The strength of the CLIP model. 135 | 136 | #### Returns: 137 | - `tuple`: The new model patcher and CLIP model. 138 | """ 139 | key_map = {} 140 | if model is not None: 141 | key_map = model_lora_keys_unet(model.model, key_map) 142 | if clip is not None: 143 | key_map = model_lora_keys_clip(clip.cond_stage_model, key_map) 144 | 145 | loaded = load_lora(lora, key_map) 146 | new_modelpatcher = model.clone() 147 | k = new_modelpatcher.add_patches(loaded, strength_model) 148 | 149 | new_clip = clip.clone() 150 | k1 = new_clip.add_patches(loaded, strength_clip) 151 | k = set(k) 152 | k1 = set(k1) 153 | 154 | return (new_modelpatcher, new_clip) 155 | 156 | 157 | class LoraLoader: 158 | """#### Class for loading LoRA models.""" 159 | 160 | def __init__(self): 161 | """#### Initialize the LoraLoader class.""" 162 | self.loaded_lora = None 163 | 164 | def load_lora( 165 | self, 166 | model: object, 167 | clip: object, 168 | lora_name: str, 169 | strength_model: float, 170 | strength_clip: float, 171 | ) -> tuple: 172 | """#### Load a LoRA model. 173 | 174 | #### Args: 175 | - `model` (object): The model. 176 | - `clip` (object): The CLIP model. 177 | - `lora_name` (str): The name of the LoRA model. 178 | - `strength_model` (float): The strength of the model. 179 | - `strength_clip` (float): The strength of the CLIP model. 180 | 181 | #### Returns: 182 | - `tuple`: The new model patcher and CLIP model. 
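        #### Example (illustrative sketch; `model` and `clip` come from the checkpoint loader and the LoRA file is the one fetched by the downloader; the strengths are hypothetical):

            loader = LoraLoader()
            model_lora, clip_lora = loader.load_lora(
                model, clip, "add_detail.safetensors",
                strength_model=0.7, strength_clip=0.7,
            )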
183 | """ 184 | lora_path = util.get_full_path("loras", lora_name) 185 | lora = None 186 | if lora is None: 187 | lora = util.load_torch_file(lora_path, safe_load=True) 188 | self.loaded_lora = (lora_path, lora) 189 | 190 | model_lora, clip_lora = load_lora_for_models( 191 | model, clip, lora, strength_model, strength_clip 192 | ) 193 | return (model_lora, clip_lora) 194 | -------------------------------------------------------------------------------- /modules/SD15/SD15.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from modules.BlackForest import Flux 3 | from modules.Utilities import util 4 | from modules.Model import ModelBase 5 | from modules.SD15 import SDClip, SDToken 6 | from modules.Utilities import Latent 7 | from modules.clip import Clip 8 | 9 | 10 | class sm_SD15(ModelBase.BASE): 11 | """#### Class representing the SD15 model. 12 | 13 | #### Args: 14 | - `ModelBase.BASE` (ModelBase.BASE): The base model class. 15 | """ 16 | 17 | unet_config: dict = { 18 | "context_dim": 768, 19 | "model_channels": 320, 20 | "use_linear_in_transformer": False, 21 | "adm_in_channels": None, 22 | "use_temporal_attention": False, 23 | } 24 | 25 | unet_extra_config: dict = { 26 | "num_heads": 8, 27 | "num_head_channels": -1, 28 | } 29 | 30 | latent_format: Latent.SD15 = Latent.SD15 31 | 32 | def process_clip_state_dict(self, state_dict: dict) -> dict: 33 | """#### Process the state dictionary for the CLIP model. 34 | 35 | #### Args: 36 | - `state_dict` (dict): The state dictionary. 37 | 38 | #### Returns: 39 | - `dict`: The processed state dictionary. 40 | """ 41 | k = list(state_dict.keys()) 42 | for x in k: 43 | if x.startswith("cond_stage_model.transformer.") and not x.startswith( 44 | "cond_stage_model.transformer.text_model." 45 | ): 46 | y = x.replace( 47 | "cond_stage_model.transformer.", 48 | "cond_stage_model.transformer.text_model.", 49 | ) 50 | state_dict[y] = state_dict.pop(x) 51 | 52 | if ( 53 | "cond_stage_model.transformer.text_model.embeddings.position_ids" 54 | in state_dict 55 | ): 56 | ids = state_dict[ 57 | "cond_stage_model.transformer.text_model.embeddings.position_ids" 58 | ] 59 | if ids.dtype == torch.float32: 60 | state_dict[ 61 | "cond_stage_model.transformer.text_model.embeddings.position_ids" 62 | ] = ids.round() 63 | 64 | replace_prefix = {} 65 | replace_prefix["cond_stage_model."] = "clip_l." 66 | state_dict = util.state_dict_prefix_replace( 67 | state_dict, replace_prefix, filter_keys=True 68 | ) 69 | return state_dict 70 | 71 | def clip_target(self) -> Clip.ClipTarget: 72 | """#### Get the target CLIP model. 73 | 74 | #### Returns: 75 | - `Clip.ClipTarget`: The target CLIP model. 
76 | """ 77 | return Clip.ClipTarget(SDToken.SD1Tokenizer, SDClip.SD1ClipModel) 78 | 79 | models = [ 80 | sm_SD15, Flux.Flux 81 | ] -------------------------------------------------------------------------------- /modules/StableFast/StableFast.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import functools 3 | import logging 4 | from dataclasses import dataclass 5 | 6 | import torch 7 | 8 | try: 9 | from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig 10 | from sfast.compilers.diffusion_pipeline_compiler import ( 11 | _enable_xformers, 12 | _modify_model, 13 | ) 14 | from sfast.cuda.graphs import make_dynamic_graphed_callable 15 | from sfast.jit import utils as jit_utils 16 | from sfast.jit.trace_helper import trace_with_kwargs 17 | except: 18 | pass 19 | 20 | 21 | def hash_arg(arg): 22 | # micro optimization: bool obj is an instance of int 23 | if isinstance(arg, (str, int, float, bytes)): 24 | return arg 25 | if isinstance(arg, (tuple, list)): 26 | return tuple(map(hash_arg, arg)) 27 | if isinstance(arg, dict): 28 | return tuple( 29 | sorted( 30 | ((hash_arg(k), hash_arg(v)) for k, v in arg.items()), key=lambda x: x[0] 31 | ) 32 | ) 33 | return type(arg) 34 | 35 | 36 | class ModuleFactory: 37 | def get_converted_kwargs(self): 38 | return self.converted_kwargs 39 | 40 | 41 | import torch as th 42 | import torch.nn as nn 43 | import copy 44 | 45 | 46 | class BaseModelApplyModelModule(torch.nn.Module): 47 | def __init__(self, func, module): 48 | super().__init__() 49 | self.func = func 50 | self.module = module 51 | 52 | def forward( 53 | self, 54 | input_x, 55 | timestep, 56 | c_concat=None, 57 | c_crossattn=None, 58 | y=None, 59 | control=None, 60 | transformer_options={}, 61 | ): 62 | kwargs = {"y": y} 63 | 64 | new_transformer_options = {} 65 | 66 | return self.func( 67 | input_x, 68 | timestep, 69 | c_concat=c_concat, 70 | c_crossattn=c_crossattn, 71 | control=control, 72 | transformer_options=new_transformer_options, 73 | **kwargs, 74 | ) 75 | 76 | 77 | class BaseModelApplyModelModuleFactory(ModuleFactory): 78 | kwargs_name = ( 79 | "input_x", 80 | "timestep", 81 | "c_concat", 82 | "c_crossattn", 83 | "y", 84 | "control", 85 | ) 86 | 87 | def __init__(self, callable, kwargs) -> None: 88 | self.callable = callable 89 | self.unet_config = callable.__self__.model_config.unet_config 90 | self.kwargs = kwargs 91 | self.patch_module = {} 92 | self.patch_module_parameter = {} 93 | self.converted_kwargs = self.gen_converted_kwargs() 94 | 95 | def gen_converted_kwargs(self): 96 | converted_kwargs = {} 97 | for arg_name, arg in self.kwargs.items(): 98 | if arg_name in self.kwargs_name: 99 | converted_kwargs[arg_name] = arg 100 | 101 | transformer_options = self.kwargs.get("transformer_options", {}) 102 | patches = transformer_options.get("patches", {}) 103 | 104 | patch_module = {} 105 | patch_module_parameter = {} 106 | 107 | new_transformer_options = {} 108 | new_transformer_options["patches"] = patch_module_parameter 109 | 110 | self.patch_module = patch_module 111 | self.patch_module_parameter = patch_module_parameter 112 | return converted_kwargs 113 | 114 | def gen_cache_key(self): 115 | key_kwargs = {} 116 | for k, v in self.converted_kwargs.items(): 117 | key_kwargs[k] = v 118 | 119 | patch_module_cache_key = {} 120 | return ( 121 | self.callable.__class__.__qualname__, 122 | hash_arg(self.unet_config), 123 | hash_arg(key_kwargs), 124 | hash_arg(patch_module_cache_key), 125 | ) 126 | 127 | 
@contextlib.contextmanager 128 | def converted_module_context(self): 129 | module = BaseModelApplyModelModule(self.callable, self.callable.__self__) 130 | yield (module, self.converted_kwargs) 131 | 132 | 133 | logger = logging.getLogger() 134 | 135 | 136 | @dataclass 137 | class TracedModuleCacheItem: 138 | module: object 139 | patch_id: int 140 | device: str 141 | 142 | 143 | class LazyTraceModule: 144 | traced_modules = {} 145 | 146 | def __init__(self, config=None, patch_id=None, **kwargs_) -> None: 147 | self.config = config 148 | self.patch_id = patch_id 149 | self.kwargs_ = kwargs_ 150 | self.modify_model = functools.partial( 151 | _modify_model, 152 | enable_cnn_optimization=config.enable_cnn_optimization, 153 | prefer_lowp_gemm=config.prefer_lowp_gemm, 154 | enable_triton=config.enable_triton, 155 | enable_triton_reshape=config.enable_triton, 156 | memory_format=config.memory_format, 157 | ) 158 | self.cuda_graph_modules = {} 159 | 160 | def ts_compiler( 161 | self, 162 | m, 163 | ): 164 | with torch.jit.optimized_execution(True): 165 | if self.config.enable_jit_freeze: 166 | # raw freeze causes Tensor reference leak 167 | # because the constant Tensors in the GraphFunction of 168 | # the compilation unit are never freed. 169 | m.eval() 170 | m = jit_utils.better_freeze(m) 171 | self.modify_model(m) 172 | 173 | if self.config.enable_cuda_graph: 174 | m = make_dynamic_graphed_callable(m) 175 | return m 176 | 177 | def __call__(self, model_function, /, **kwargs): 178 | module_factory = BaseModelApplyModelModuleFactory(model_function, kwargs) 179 | kwargs = module_factory.get_converted_kwargs() 180 | key = module_factory.gen_cache_key() 181 | 182 | traced_module = self.cuda_graph_modules.get(key) 183 | if traced_module is None: 184 | with module_factory.converted_module_context() as (m_model, m_kwargs): 185 | logger.info( 186 | f'Tracing {getattr(m_model, "__name__", m_model.__class__.__name__)}' 187 | ) 188 | traced_m, call_helper = trace_with_kwargs( 189 | m_model, None, m_kwargs, **self.kwargs_ 190 | ) 191 | 192 | traced_m = self.ts_compiler(traced_m) 193 | traced_module = call_helper(traced_m) 194 | self.cuda_graph_modules[key] = traced_module 195 | 196 | return traced_module(**kwargs) 197 | 198 | 199 | def build_lazy_trace_module(config, device, patch_id): 200 | config.enable_cuda_graph = config.enable_cuda_graph and device.type == "cuda" 201 | 202 | if config.enable_xformers: 203 | _enable_xformers(None) 204 | 205 | return LazyTraceModule( 206 | config=config, 207 | patch_id=patch_id, 208 | check_trace=True, 209 | strict=True, 210 | ) 211 | 212 | 213 | def gen_stable_fast_config(): 214 | config = CompilationConfig.Default() 215 | try: 216 | import xformers 217 | 218 | config.enable_xformers = True 219 | except ImportError: 220 | print("xformers not installed, skip") 221 | 222 | # CUDA Graph is suggested for small batch sizes. 223 | # After capturing, the model only accepts one fixed image size. 224 | # If you want the model to be dynamic, don't enable it. 
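    # This config is consumed by ApplyStableFastUnet.apply_stable_fast(); an
    # illustrative call (a sketch; `model` is a ModelPatcher from the checkpoint
    # loader) looks like:
    #
    #     (fast_model,) = ApplyStableFastUnet().apply_stable_fast(model, enable_cuda_graph=False)
    #
    # apply_stable_fast() clones the ModelPatcher and wraps its UNet call in a
    # StableFastPatch, so the original model object is left untouched.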
225 | config.enable_cuda_graph = False 226 | # config.enable_jit_freeze = False 227 | return config 228 | 229 | 230 | class StableFastPatch: 231 | def __init__(self, model, config): 232 | self.model = model 233 | self.config = config 234 | self.stable_fast_model = None 235 | 236 | def __call__(self, model_function, params): 237 | input_x = params.get("input") 238 | timestep_ = params.get("timestep") 239 | c = params.get("c") 240 | 241 | if self.stable_fast_model is None: 242 | self.stable_fast_model = build_lazy_trace_module( 243 | self.config, 244 | input_x.device, 245 | id(self), 246 | ) 247 | 248 | return self.stable_fast_model( 249 | model_function, input_x=input_x, timestep=timestep_, **c 250 | ) 251 | 252 | def to(self, device): 253 | if type(device) == torch.device: 254 | if self.config.enable_cuda_graph or self.config.enable_jit_freeze: 255 | if device.type == "cpu": 256 | del self.stable_fast_model 257 | self.stable_fast_model = None 258 | print( 259 | "\33[93mWarning: Your graphics card doesn't have enough video memory to keep the model. If you experience a noticeable delay every time you start sampling, please consider disable enable_cuda_graph.\33[0m" 260 | ) 261 | return self 262 | 263 | 264 | class ApplyStableFastUnet: 265 | def apply_stable_fast(self, model, enable_cuda_graph): 266 | config = gen_stable_fast_config() 267 | 268 | if config.memory_format is not None: 269 | model.model.to(memory_format=config.memory_format) 270 | 271 | patch = StableFastPatch(model, config) 272 | model_stable_fast = model.clone() 273 | model_stable_fast.set_model_unet_function_wrapper(patch) 274 | return (model_stable_fast,) -------------------------------------------------------------------------------- /modules/UltimateSDUpscale/USDU_upscaler.py: -------------------------------------------------------------------------------- 1 | import logging as logger 2 | import torch 3 | from PIL import Image 4 | 5 | from modules.Device import Device 6 | from modules.UltimateSDUpscale import RDRB 7 | from modules.UltimateSDUpscale import image_util 8 | from modules.Utilities import util 9 | 10 | 11 | def load_state_dict(state_dict: dict) -> RDRB.PyTorchModel: 12 | """#### Load a state dictionary into a PyTorch model. 13 | 14 | #### Args: 15 | - `state_dict` (dict): The state dictionary. 16 | 17 | #### Returns: 18 | - `RDRB.PyTorchModel`: The loaded PyTorch model. 19 | """ 20 | logger.debug("Loading state dict into pytorch model arch") 21 | state_dict_keys = list(state_dict.keys()) 22 | if "params_ema" in state_dict_keys: 23 | state_dict = state_dict["params_ema"] 24 | model = RDRB.RRDBNet(state_dict) 25 | return model 26 | 27 | 28 | class UpscaleModelLoader: 29 | """#### Class for loading upscale models.""" 30 | 31 | def load_model(self, model_name: str) -> tuple: 32 | """#### Load an upscale model. 33 | 34 | #### Args: 35 | - `model_name` (str): The name of the model. 36 | 37 | #### Returns: 38 | - `tuple`: The loaded model. 39 | """ 40 | model_path = f"./_internal/ESRGAN/{model_name}" 41 | sd = util.load_torch_file(model_path, safe_load=True) 42 | if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: 43 | sd = util.state_dict_prefix_replace(sd, {"module.": ""}) 44 | out = load_state_dict(sd).eval() 45 | return (out,) 46 | 47 | 48 | class ImageUpscaleWithModel: 49 | """#### Class for upscaling images with a model.""" 50 | 51 | def upscale(self, upscale_model: torch.nn.Module, image: torch.Tensor) -> tuple: 52 | """#### Upscale an image using a model. 
53 | 54 | #### Args: 55 | - `upscale_model` (torch.nn.Module): The upscale model. 56 | - `image` (torch.Tensor): The input image tensor. 57 | 58 | #### Returns: 59 | - `tuple`: The upscaled image tensor. 60 | """ 61 | if torch.cuda.is_available(): 62 | device = torch.device(torch.cuda.current_device()) 63 | else: 64 | device = torch.device("cpu") 65 | upscale_model.to(device) 66 | in_img = image.movedim(-1, -3).to(device) 67 | Device.get_free_memory(device) 68 | 69 | tile = 512 70 | overlap = 32 71 | 72 | oom = True 73 | while oom: 74 | steps = in_img.shape[0] * image_util.get_tiled_scale_steps( 75 | in_img.shape[3], 76 | in_img.shape[2], 77 | tile_x=tile, 78 | tile_y=tile, 79 | overlap=overlap, 80 | ) 81 | pbar = util.ProgressBar(steps) 82 | s = image_util.tiled_scale( 83 | in_img, 84 | lambda a: upscale_model(a), 85 | tile_x=tile, 86 | tile_y=tile, 87 | overlap=overlap, 88 | upscale_amount=upscale_model.scale, 89 | pbar=pbar, 90 | ) 91 | oom = False 92 | 93 | upscale_model.cpu() 94 | s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0) 95 | return (s,) 96 | 97 | 98 | def torch_gc() -> None: 99 | """#### Perform garbage collection for PyTorch.""" 100 | pass 101 | 102 | 103 | class Script: 104 | """#### Class representing a script.""" 105 | pass 106 | 107 | 108 | class Options: 109 | """#### Class representing options.""" 110 | 111 | img2img_background_color: str = "#ffffff" # Set to white for now 112 | 113 | 114 | class State: 115 | """#### Class representing the state.""" 116 | 117 | interrupted: bool = False 118 | 119 | def begin(self) -> None: 120 | """#### Begin the state.""" 121 | pass 122 | 123 | def end(self) -> None: 124 | """#### End the state.""" 125 | pass 126 | 127 | 128 | opts = Options() 129 | state = State() 130 | 131 | # Will only ever hold 1 upscaler 132 | sd_upscalers = [None] 133 | actual_upscaler = None 134 | 135 | # Batch of images to upscale 136 | batch = None 137 | 138 | 139 | if not hasattr(Image, "Resampling"): # For older versions of Pillow 140 | Image.Resampling = Image 141 | 142 | 143 | class Upscaler: 144 | """#### Class for upscaling images.""" 145 | 146 | def _upscale(self, img: Image.Image, scale: float) -> Image.Image: 147 | """#### Upscale an image. 148 | 149 | #### Args: 150 | - `img` (Image.Image): The input image. 151 | - `scale` (float): The scale factor. 152 | 153 | #### Returns: 154 | - `Image.Image`: The upscaled image. 155 | """ 156 | global actual_upscaler 157 | tensor = image_util.pil_to_tensor(img) 158 | image_upscale_node = ImageUpscaleWithModel() 159 | (upscaled,) = image_upscale_node.upscale(actual_upscaler, tensor) 160 | return image_util.tensor_to_pil(upscaled) 161 | 162 | def upscale(self, img: Image.Image, scale: float, selected_model: str = None) -> Image.Image: 163 | """#### Upscale an image with a selected model. 164 | 165 | #### Args: 166 | - `img` (Image.Image): The input image. 167 | - `scale` (float): The scale factor. 168 | - `selected_model` (str, optional): The selected model. Defaults to None. 169 | 170 | #### Returns: 171 | - `Image.Image`: The upscaled image. 
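        #### Note:
        - Every image in the module-level `batch` is upscaled and the first result is
          returned; the `img` argument is shadowed by the batch loop and
          `selected_model` is currently unused.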
172 | """ 173 | global batch 174 | batch = [self._upscale(img, scale) for img in batch] 175 | return batch[0] 176 | 177 | 178 | class UpscalerData: 179 | """#### Class for storing upscaler data.""" 180 | 181 | name: str = "" 182 | data_path: str = "" 183 | 184 | def __init__(self): 185 | self.scaler = Upscaler() -------------------------------------------------------------------------------- /modules/UltimateSDUpscale/USDU_util.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | import torch 3 | import torch.nn as nn 4 | 5 | ConvMode = Literal["CNA", "NAC", "CNAC"] 6 | 7 | def act(act_type: str, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1) -> nn.Module: 8 | """#### Get the activation layer. 9 | 10 | #### Args: 11 | - `act_type` (str): The type of activation. 12 | - `inplace` (bool, optional): Whether to perform the operation in-place. Defaults to True. 13 | - `neg_slope` (float, optional): The negative slope for LeakyReLU. Defaults to 0.2. 14 | - `n_prelu` (int, optional): The number of PReLU parameters. Defaults to 1. 15 | 16 | #### Returns: 17 | - `nn.Module`: The activation layer. 18 | """ 19 | act_type = act_type.lower() 20 | layer = nn.LeakyReLU(neg_slope, inplace) 21 | return layer 22 | 23 | def get_valid_padding(kernel_size: int, dilation: int) -> int: 24 | """#### Get the valid padding for a convolutional layer. 25 | 26 | #### Args: 27 | - `kernel_size` (int): The size of the kernel. 28 | - `dilation` (int): The dilation rate. 29 | 30 | #### Returns: 31 | - `int`: The valid padding. 32 | """ 33 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1) 34 | padding = (kernel_size - 1) // 2 35 | return padding 36 | 37 | def sequential(*args: nn.Module) -> nn.Sequential: 38 | """#### Create a sequential container. 39 | 40 | #### Args: 41 | - `*args` (nn.Module): The modules to include in the sequential container. 42 | 43 | #### Returns: 44 | - `nn.Sequential`: The sequential container. 45 | """ 46 | modules = [] 47 | for module in args: 48 | if isinstance(module, nn.Sequential): 49 | for submodule in module.children(): 50 | modules.append(submodule) 51 | elif isinstance(module, nn.Module): 52 | modules.append(module) 53 | return nn.Sequential(*modules) 54 | 55 | def conv_block( 56 | in_nc: int, 57 | out_nc: int, 58 | kernel_size: int, 59 | stride: int = 1, 60 | dilation: int = 1, 61 | groups: int = 1, 62 | bias: bool = True, 63 | pad_type: str = "zero", 64 | norm_type: str | None = None, 65 | act_type: str | None = "relu", 66 | mode: ConvMode = "CNA", 67 | c2x2: bool = False, 68 | ) -> nn.Sequential: 69 | """#### Create a convolutional block. 70 | 71 | #### Args: 72 | - `in_nc` (int): The number of input channels. 73 | - `out_nc` (int): The number of output channels. 74 | - `kernel_size` (int): The size of the kernel. 75 | - `stride` (int, optional): The stride of the convolution. Defaults to 1. 76 | - `dilation` (int, optional): The dilation rate. Defaults to 1. 77 | - `groups` (int, optional): The number of groups. Defaults to 1. 78 | - `bias` (bool, optional): Whether to include a bias term. Defaults to True. 79 | - `pad_type` (str, optional): The type of padding. Defaults to "zero". 80 | - `norm_type` (str | None, optional): The type of normalization. Defaults to None. 81 | - `act_type` (str | None, optional): The type of activation. Defaults to "relu". 82 | - `mode` (ConvMode, optional): The mode of the convolution. Defaults to "CNA". 
83 | - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False. 84 | 85 | #### Returns: 86 | - `nn.Sequential`: The convolutional block. 87 | """ 88 | assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode) 89 | padding = get_valid_padding(kernel_size, dilation) 90 | padding = padding if pad_type == "zero" else 0 91 | 92 | c = nn.Conv2d( 93 | in_nc, 94 | out_nc, 95 | kernel_size=kernel_size, 96 | stride=stride, 97 | padding=padding, 98 | dilation=dilation, 99 | bias=bias, 100 | groups=groups, 101 | ) 102 | a = act(act_type) if act_type else None 103 | if mode in ("CNA", "CNAC"): 104 | return sequential(None, c, None, a) 105 | 106 | def upconv_block( 107 | in_nc: int, 108 | out_nc: int, 109 | upscale_factor: int = 2, 110 | kernel_size: int = 3, 111 | stride: int = 1, 112 | bias: bool = True, 113 | pad_type: str = "zero", 114 | norm_type: str | None = None, 115 | act_type: str = "relu", 116 | mode: str = "nearest", 117 | c2x2: bool = False, 118 | ) -> nn.Sequential: 119 | """#### Create an upsampling convolutional block. 120 | 121 | #### Args: 122 | - `in_nc` (int): The number of input channels. 123 | - `out_nc` (int): The number of output channels. 124 | - `upscale_factor` (int, optional): The upscale factor. Defaults to 2. 125 | - `kernel_size` (int, optional): The size of the kernel. Defaults to 3. 126 | - `stride` (int, optional): The stride of the convolution. Defaults to 1. 127 | - `bias` (bool, optional): Whether to include a bias term. Defaults to True. 128 | - `pad_type` (str, optional): The type of padding. Defaults to "zero". 129 | - `norm_type` (str | None, optional): The type of normalization. Defaults to None. 130 | - `act_type` (str, optional): The type of activation. Defaults to "relu". 131 | - `mode` (str, optional): The mode of upsampling. Defaults to "nearest". 132 | - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False. 133 | 134 | #### Returns: 135 | - `nn.Sequential`: The upsampling convolutional block. 136 | """ 137 | upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode) 138 | conv = conv_block( 139 | in_nc, 140 | out_nc, 141 | kernel_size, 142 | stride, 143 | bias=bias, 144 | pad_type=pad_type, 145 | norm_type=norm_type, 146 | act_type=act_type, 147 | c2x2=c2x2, 148 | ) 149 | return sequential(upsample, conv) 150 | 151 | class ShortcutBlock(nn.Module): 152 | """#### Elementwise sum the output of a submodule to its input.""" 153 | 154 | def __init__(self, submodule: nn.Module): 155 | """#### Initialize the ShortcutBlock. 156 | 157 | #### Args: 158 | - `submodule` (nn.Module): The submodule to apply. 159 | """ 160 | super(ShortcutBlock, self).__init__() 161 | self.sub = submodule 162 | 163 | def forward(self, x: torch.Tensor) -> torch.Tensor: 164 | """#### Forward pass. 165 | 166 | #### Args: 167 | - `x` (torch.Tensor): The input tensor. 168 | 169 | #### Returns: 170 | - `torch.Tensor`: The output tensor. 171 | """ 172 | output = x + self.sub(x) 173 | return output -------------------------------------------------------------------------------- /modules/UltimateSDUpscale/image_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import torch 4 | from PIL import Image 5 | 6 | 7 | def get_tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int: 8 | """#### Calculate the number of steps required for tiled scaling. 9 | 10 | #### Args: 11 | - `width` (int): The width of the image. 
12 | - `height` (int): The height of the image. 13 | - `tile_x` (int): The width of each tile. 14 | - `tile_y` (int): The height of each tile. 15 | - `overlap` (int): The overlap between tiles. 16 | 17 | #### Returns: 18 | - `int`: The number of steps required for tiled scaling. 19 | """ 20 | return math.ceil((height / (tile_y - overlap))) * math.ceil( 21 | (width / (tile_x - overlap)) 22 | ) 23 | 24 | 25 | @torch.inference_mode() 26 | def tiled_scale( 27 | samples: torch.Tensor, 28 | function: callable, 29 | tile_x: int = 64, 30 | tile_y: int = 64, 31 | overlap: int = 8, 32 | upscale_amount: float = 4, 33 | out_channels: int = 3, 34 | pbar: any = None, 35 | ) -> torch.Tensor: 36 | """#### Perform tiled scaling on a batch of samples. 37 | 38 | #### Args: 39 | - `samples` (torch.Tensor): The input samples. 40 | - `function` (callable): The function to apply to each tile. 41 | - `tile_x` (int, optional): The width of each tile. Defaults to 64. 42 | - `tile_y` (int, optional): The height of each tile. Defaults to 64. 43 | - `overlap` (int, optional): The overlap between tiles. Defaults to 8. 44 | - `upscale_amount` (float, optional): The upscale amount. Defaults to 4. 45 | - `out_channels` (int, optional): The number of output channels. Defaults to 3. 46 | - `pbar` (any, optional): The progress bar. Defaults to None. 47 | 48 | #### Returns: 49 | - `torch.Tensor`: The scaled output tensor. 50 | """ 51 | output = torch.empty( 52 | ( 53 | samples.shape[0], 54 | out_channels, 55 | round(samples.shape[2] * upscale_amount), 56 | round(samples.shape[3] * upscale_amount), 57 | ), 58 | device="cpu", 59 | ) 60 | for b in range(samples.shape[0]): 61 | s = samples[b : b + 1] 62 | out = torch.zeros( 63 | ( 64 | s.shape[0], 65 | out_channels, 66 | round(s.shape[2] * upscale_amount), 67 | round(s.shape[3] * upscale_amount), 68 | ), 69 | device="cpu", 70 | ) 71 | out_div = torch.zeros( 72 | ( 73 | s.shape[0], 74 | out_channels, 75 | round(s.shape[2] * upscale_amount), 76 | round(s.shape[3] * upscale_amount), 77 | ), 78 | device="cpu", 79 | ) 80 | for y in range(0, s.shape[2], tile_y - overlap): 81 | for x in range(0, s.shape[3], tile_x - overlap): 82 | s_in = s[:, :, y : y + tile_y, x : x + tile_x] 83 | 84 | ps = function(s_in).cpu() 85 | mask = torch.ones_like(ps) 86 | feather = round(overlap * upscale_amount) 87 | for t in range(feather): 88 | mask[:, :, t : 1 + t, :] *= (1.0 / feather) * (t + 1) 89 | mask[:, :, mask.shape[2] - 1 - t : mask.shape[2] - t, :] *= ( 90 | 1.0 / feather 91 | ) * (t + 1) 92 | mask[:, :, :, t : 1 + t] *= (1.0 / feather) * (t + 1) 93 | mask[:, :, :, mask.shape[3] - 1 - t : mask.shape[3] - t] *= ( 94 | 1.0 / feather 95 | ) * (t + 1) 96 | out[ 97 | :, 98 | :, 99 | round(y * upscale_amount) : round((y + tile_y) * upscale_amount), 100 | round(x * upscale_amount) : round((x + tile_x) * upscale_amount), 101 | ] += ps * mask 102 | out_div[ 103 | :, 104 | :, 105 | round(y * upscale_amount) : round((y + tile_y) * upscale_amount), 106 | round(x * upscale_amount) : round((x + tile_x) * upscale_amount), 107 | ] += mask 108 | 109 | output[b : b + 1] = out / out_div 110 | return output 111 | 112 | 113 | def flatten(img: Image.Image, bgcolor: str) -> Image.Image: 114 | """#### Replace transparency with a background color. 115 | 116 | #### Args: 117 | - `img` (Image.Image): The input image. 118 | - `bgcolor` (str): The background color. 119 | 120 | #### Returns: 121 | - `Image.Image`: The image with transparency replaced by the background color. 
122 | """ 123 | if img.mode in ("RGB"): 124 | return img 125 | return Image.alpha_composite(Image.new("RGBA", img.size, bgcolor), img).convert( 126 | "RGB" 127 | ) 128 | 129 | 130 | BLUR_KERNEL_SIZE = 15 131 | 132 | 133 | def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image: 134 | """#### Convert a tensor to a PIL image. 135 | 136 | #### Args: 137 | - `img_tensor` (torch.Tensor): The input tensor. 138 | - `batch_index` (int, optional): The batch index. Defaults to 0. 139 | 140 | #### Returns: 141 | - `Image.Image`: The converted PIL image. 142 | """ 143 | img_tensor = img_tensor[batch_index].unsqueeze(0) 144 | i = 255.0 * img_tensor.cpu().numpy() 145 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8).squeeze()) 146 | return img 147 | 148 | 149 | def pil_to_tensor(image: Image.Image) -> torch.Tensor: 150 | """#### Convert a PIL image to a tensor. 151 | 152 | #### Args: 153 | - `image` (Image.Image): The input PIL image. 154 | 155 | #### Returns: 156 | - `torch.Tensor`: The converted tensor. 157 | """ 158 | image = np.array(image).astype(np.float32) / 255.0 159 | image = torch.from_numpy(image).unsqueeze(0) 160 | return image 161 | 162 | 163 | def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple: 164 | """#### Get the coordinates of the white rectangular mask region. 165 | 166 | #### Args: 167 | - `mask` (Image.Image): The input mask image in 'L' mode. 168 | - `pad` (int, optional): The padding to apply. Defaults to 0. 169 | 170 | #### Returns: 171 | - `tuple`: The coordinates of the crop region. 172 | """ 173 | coordinates = mask.getbbox() 174 | if coordinates is not None: 175 | x1, y1, x2, y2 = coordinates 176 | else: 177 | x1, y1, x2, y2 = mask.width, mask.height, 0, 0 178 | # Apply padding 179 | x1 = max(x1 - pad, 0) 180 | y1 = max(y1 - pad, 0) 181 | x2 = min(x2 + pad, mask.width) 182 | y2 = min(y2 + pad, mask.height) 183 | return fix_crop_region((x1, y1, x2, y2), (mask.width, mask.height)) 184 | 185 | 186 | def fix_crop_region(region: tuple, image_size: tuple) -> tuple: 187 | """#### Remove the extra pixel added by the get_crop_region function. 188 | 189 | #### Args: 190 | - `region` (tuple): The crop region coordinates. 191 | - `image_size` (tuple): The size of the image. 192 | 193 | #### Returns: 194 | - `tuple`: The fixed crop region coordinates. 195 | """ 196 | image_width, image_height = image_size 197 | x1, y1, x2, y2 = region 198 | if x2 < image_width: 199 | x2 -= 1 200 | if y2 < image_height: 201 | y2 -= 1 202 | return x1, y1, x2, y2 203 | 204 | 205 | def expand_crop(region: tuple, width: int, height: int, target_width: int, target_height: int) -> tuple: 206 | """#### Expand a crop region to a specified target size. 207 | 208 | #### Args: 209 | - `region` (tuple): The crop region coordinates. 210 | - `width` (int): The width of the image. 211 | - `height` (int): The height of the image. 212 | - `target_width` (int): The desired width of the crop region. 213 | - `target_height` (int): The desired height of the crop region. 214 | 215 | #### Returns: 216 | - `tuple`: The expanded crop region coordinates and the target size. 
217 | """ 218 | x1, y1, x2, y2 = region 219 | actual_width = x2 - x1 220 | actual_height = y2 - y1 221 | 222 | # Try to expand region to the right of half the difference 223 | width_diff = target_width - actual_width 224 | x2 = min(x2 + width_diff // 2, width) 225 | # Expand region to the left of the difference including the pixels that could not be expanded to the right 226 | width_diff = target_width - (x2 - x1) 227 | x1 = max(x1 - width_diff, 0) 228 | # Try the right again 229 | width_diff = target_width - (x2 - x1) 230 | x2 = min(x2 + width_diff, width) 231 | 232 | # Try to expand region to the bottom of half the difference 233 | height_diff = target_height - actual_height 234 | y2 = min(y2 + height_diff // 2, height) 235 | # Expand region to the top of the difference including the pixels that could not be expanded to the bottom 236 | height_diff = target_height - (y2 - y1) 237 | y1 = max(y1 - height_diff, 0) 238 | # Try the bottom again 239 | height_diff = target_height - (y2 - y1) 240 | y2 = min(y2 + height_diff, height) 241 | 242 | return (x1, y1, x2, y2), (target_width, target_height) 243 | 244 | 245 | def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple, tile_size: tuple, w_pad: int = 0, h_pad: int = 0) -> list: 246 | """#### Crop conditioning data to match a specific region. 247 | 248 | #### Args: 249 | - `cond` (list): The conditioning data. 250 | - `region` (tuple): The crop region coordinates. 251 | - `init_size` (tuple): The initial size of the image. 252 | - `canvas_size` (tuple): The size of the canvas. 253 | - `tile_size` (tuple): The size of the tile. 254 | - `w_pad` (int, optional): The width padding. Defaults to 0. 255 | - `h_pad` (int, optional): The height padding. Defaults to 0. 256 | 257 | #### Returns: 258 | - `list`: The cropped conditioning data. 259 | """ 260 | cropped = [] 261 | for emb, x in cond: 262 | cond_dict = x.copy() 263 | n = [emb, cond_dict] 264 | cropped.append(n) 265 | return cropped -------------------------------------------------------------------------------- /modules/Utilities/Enhancer.py: -------------------------------------------------------------------------------- 1 | import ollama 2 | import os 3 | 4 | from modules.Utilities import util 5 | 6 | 7 | def enhance_prompt(p: str) -> str: 8 | """#### Enhance a text-to-image prompt using Ollama. 9 | 10 | #### Args: 11 | - `p` (str, optional): The prompt. Defaults to `None`. 12 | 13 | #### Returns: 14 | - `str`: The enhanced prompt 15 | """ 16 | 17 | # Load the prompt from the file 18 | prompt = util.load_parameters_from_file()[0] 19 | if p is None: 20 | pass 21 | else: 22 | prompt = p 23 | print(prompt) 24 | response = ollama.chat( 25 | model="deepseek-r1", 26 | messages=[ 27 | { 28 | "role": "user", 29 | "content": f"""Your goal is to generate a text-to-image prompt based on a user's input, detailing their desired final outcome for an image. The user will provide specific details about the characteristics, features, or elements they want the image to include. The prompt should guide the generation of an image that aligns with the user's desired outcome. 
30 | 31 | Generate a text-to-image prompt by arranging the following blocks in a single string, separated by commas: 32 | 33 | Image Type: [Specify desired image type] 34 | 35 | Aesthetic or Mood: [Describe desired aesthetic or mood] 36 | 37 | Lighting Conditions: [Specify desired lighting conditions] 38 | 39 | Composition or Framing: [Provide details about desired composition or framing] 40 | 41 | Background: [Specify desired background elements or setting] 42 | 43 | Colors: [Mention any specific colors or color palette] 44 | 45 | Objects or Elements: [List specific objects or features] 46 | 47 | Style or Artistic Influence: [Mention desired artistic style or influence] 48 | 49 | Subject's Appearance: [Describe appearance of main subject] 50 | 51 | Ensure the blocks are arranged in order of visual importance, from the most significant to the least significant, to effectively guide image generation; a block can be surrounded by parentheses to gain additional significance. 52 | 53 | This is an example of a user's input: "a beautiful blonde lady in lingerie sitting in seiza in a seducing way with a focus on her assets" 54 | 55 | And this is an example of a desired output: "portrait| serene and mysterious| soft, diffused lighting| close-up shot, emphasizing facial features| simple and blurred background| earthy tones with a hint of warm highlights| renaissance painting| a beautiful lady with freckles and dark makeup" 56 | 57 | Here is the user's input: {prompt} 58 | 59 | Write the prompt in the same style as the example above, in a single line, with absolutely no additional information, words or symbols other than the enhanced prompt. 60 | 61 | Output:""", 62 | }, 63 | ], 64 | ) 65 | content = response["message"]["content"] 66 | print("raw model response:", content) 67 | 68 | if "<think>" in content and "</think>" in content: 69 | # Get everything after the closing </think> tag 70 | enhanced = content.split("</think>")[-1].strip() 71 | else: 72 | enhanced = content.strip() 73 | print("here's the enhanced prompt:", enhanced) 74 | os.system("ollama stop deepseek-r1") 75 | return "masterpiece, best quality, (extremely detailed CG unity 8k wallpaper, masterpiece, best quality, ultra-detailed, best shadow), high contrast, (best illumination), ((cinematic light)), hyper detail, dramatic light, depth of field," + enhanced 76 | -------------------------------------------------------------------------------- /modules/Utilities/Latent.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple 2 | import torch 3 | from modules.Device import Device 4 | from modules.Utilities import util 5 | 6 | class LatentFormat: 7 | """#### Base class for latent formats. 8 | 9 | #### Attributes: 10 | - `scale_factor` (float): The scale factor for the latent format. 11 | 12 | #### Returns: 13 | - `LatentFormat`: A latent format object. 14 | """ 15 | 16 | scale_factor: float = 1.0 17 | latent_channels: int = 4 18 | 19 | def process_in(self, latent: torch.Tensor) -> torch.Tensor: 20 | """#### Process the latent input, by multiplying it by the scale factor. 21 | 22 | #### Args: 23 | - `latent` (torch.Tensor): The latent tensor. 24 | 25 | #### Returns: 26 | - `torch.Tensor`: The processed latent tensor. 27 | """ 28 | return latent * self.scale_factor 29 | 30 | def process_out(self, latent: torch.Tensor) -> torch.Tensor: 31 | """#### Process the latent output, by dividing it by the scale factor. 32 | 33 | #### Args: 34 | - `latent` (torch.Tensor): The latent tensor.
35 | 36 | #### Returns: 37 | - `torch.Tensor`: The processed latent tensor. 38 | """ 39 | return latent / self.scale_factor 40 | 41 | class SD15(LatentFormat): 42 | """#### SD15 latent format. 43 | 44 | #### Args: 45 | - `LatentFormat` (LatentFormat): The base latent format class. 46 | """ 47 | latent_channels: int = 4 48 | def __init__(self, scale_factor: float = 0.18215): 49 | """#### Initialize the SD15 latent format. 50 | 51 | #### Args: 52 | - `scale_factor` (float, optional): The scale factor. Defaults to 0.18215. 53 | """ 54 | self.scale_factor = scale_factor 55 | self.latent_rgb_factors = [ 56 | # R G B 57 | [0.3512, 0.2297, 0.3227], 58 | [0.3250, 0.4974, 0.2350], 59 | [-0.2829, 0.1762, 0.2721], 60 | [-0.2120, -0.2616, -0.7177], 61 | ] 62 | self.taesd_decoder_name = "taesd_decoder" 63 | 64 | class SD3(LatentFormat): 65 | latent_channels = 16 66 | 67 | def __init__(self): 68 | """#### Initialize the SD3 latent format.""" 69 | self.scale_factor = 1.5305 70 | self.shift_factor = 0.0609 71 | self.latent_rgb_factors = [ 72 | [-0.0645, 0.0177, 0.1052], 73 | [0.0028, 0.0312, 0.0650], 74 | [0.1848, 0.0762, 0.0360], 75 | [0.0944, 0.0360, 0.0889], 76 | [0.0897, 0.0506, -0.0364], 77 | [-0.0020, 0.1203, 0.0284], 78 | [0.0855, 0.0118, 0.0283], 79 | [-0.0539, 0.0658, 0.1047], 80 | [-0.0057, 0.0116, 0.0700], 81 | [-0.0412, 0.0281, -0.0039], 82 | [0.1106, 0.1171, 0.1220], 83 | [-0.0248, 0.0682, -0.0481], 84 | [0.0815, 0.0846, 0.1207], 85 | [-0.0120, -0.0055, -0.0867], 86 | [-0.0749, -0.0634, -0.0456], 87 | [-0.1418, -0.1457, -0.1259], 88 | ] 89 | self.taesd_decoder_name = "taesd3_decoder" 90 | 91 | def process_in(self, latent: torch.Tensor) -> torch.Tensor: 92 | """#### Process the latent input, by multiplying it by the scale factor and subtracting the shift factor. 93 | 94 | #### Args: 95 | - `latent` (torch.Tensor): The latent tensor. 96 | 97 | #### Returns: 98 | - `torch.Tensor`: The processed latent tensor. 99 | """ 100 | return (latent - self.shift_factor) * self.scale_factor 101 | 102 | def process_out(self, latent: torch.Tensor) -> torch.Tensor: 103 | """#### Process the latent output, by dividing it by the scale factor and adding the shift factor. 104 | 105 | #### Args: 106 | - `latent` (torch.Tensor): The latent tensor. 107 | 108 | #### Returns: 109 | - `torch.Tensor`: The processed latent tensor. 110 | """ 111 | return (latent / self.scale_factor) + self.shift_factor 112 | 113 | 114 | class Flux1(SD3): 115 | latent_channels = 16 116 | 117 | def __init__(self): 118 | """#### Initialize the Flux1 latent format.""" 119 | self.scale_factor = 0.3611 120 | self.shift_factor = 0.1159 121 | self.latent_rgb_factors = [ 122 | [-0.0404, 0.0159, 0.0609], 123 | [0.0043, 0.0298, 0.0850], 124 | [0.0328, -0.0749, -0.0503], 125 | [-0.0245, 0.0085, 0.0549], 126 | [0.0966, 0.0894, 0.0530], 127 | [0.0035, 0.0399, 0.0123], 128 | [0.0583, 0.1184, 0.1262], 129 | [-0.0191, -0.0206, -0.0306], 130 | [-0.0324, 0.0055, 0.1001], 131 | [0.0955, 0.0659, -0.0545], 132 | [-0.0504, 0.0231, -0.0013], 133 | [0.0500, -0.0008, -0.0088], 134 | [0.0982, 0.0941, 0.0976], 135 | [-0.1233, -0.0280, -0.0897], 136 | [-0.0005, -0.0530, -0.0020], 137 | [-0.1273, -0.0932, -0.0680], 138 | ] 139 | self.taesd_decoder_name = "taef1_decoder" 140 | 141 | def process_in(self, latent: torch.Tensor) -> torch.Tensor: 142 | """#### Process the latent input, by multiplying it by the scale factor and subtracting the shift factor. 143 | 144 | #### Args: 145 | - `latent` (torch.Tensor): The latent tensor. 
146 | 147 | #### Returns: 148 | - `torch.Tensor`: The processed latent tensor. 149 | """ 150 | return (latent - self.shift_factor) * self.scale_factor 151 | 152 | def process_out(self, latent: torch.Tensor) -> torch.Tensor: 153 | """#### Process the latent output, by dividing it by the scale factor and adding the shift factor. 154 | 155 | #### Args: 156 | - `latent` (torch.Tensor): The latent tensor. 157 | 158 | #### Returns: 159 | - `torch.Tensor`: The processed latent tensor. 160 | """ 161 | return (latent / self.scale_factor) + self.shift_factor 162 | 163 | class EmptyLatentImage: 164 | """#### A class to generate an empty latent image. 165 | 166 | #### Args: 167 | - `Device` (Device): The device to use for the latent image. 168 | """ 169 | 170 | def __init__(self): 171 | """#### Initialize the EmptyLatentImage class.""" 172 | self.device = Device.intermediate_device() 173 | 174 | def generate( 175 | self, width: int, height: int, batch_size: int = 1 176 | ) -> Tuple[Dict[str, torch.Tensor]]: 177 | """#### Generate an empty latent image 178 | 179 | #### Args: 180 | - `width` (int): The width of the latent image. 181 | - `height` (int): The height of the latent image. 182 | - `batch_size` (int, optional): The batch size. Defaults to 1. 183 | 184 | #### Returns: 185 | - `Tuple[Dict[str, torch.Tensor]]`: The generated latent image. 186 | """ 187 | latent = torch.zeros( 188 | [batch_size, 4, height // 8, width // 8], device=self.device 189 | ) 190 | return ({"samples": latent},) 191 | 192 | def fix_empty_latent_channels(model, latent_image): 193 | """#### Fix the empty latent image channels. 194 | 195 | #### Args: 196 | - `model` (Model): The model object. 197 | - `latent_image` (torch.Tensor): The latent image. 198 | 199 | #### Returns: 200 | - `torch.Tensor`: The fixed latent image. 201 | """ 202 | latent_channels = model.get_model_object( 203 | "latent_format" 204 | ).latent_channels # Resize the empty latent image so it has the right number of channels 205 | if ( 206 | latent_channels != latent_image.shape[1] 207 | and torch.count_nonzero(latent_image) == 0 208 | ): 209 | latent_image = util.repeat_to_batch_size(latent_image, latent_channels, dim=1) 210 | return latent_image -------------------------------------------------------------------------------- /modules/Utilities/upscale.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | 4 | 5 | def bislerp(samples: torch.Tensor, width: int, height: int) -> torch.Tensor: 6 | """#### Perform bilinear interpolation on samples. 7 | 8 | #### Args: 9 | - `samples` (torch.Tensor): The input samples. 10 | - `width` (int): The target width. 11 | - `height` (int): The target height. 12 | 13 | #### Returns: 14 | - `torch.Tensor`: The interpolated samples. 15 | """ 16 | 17 | def slerp(b1: torch.Tensor, b2: torch.Tensor, r: torch.Tensor) -> torch.Tensor: 18 | """#### Perform spherical linear interpolation between two vectors. 19 | 20 | #### Args: 21 | - `b1` (torch.Tensor): The first vector. 22 | - `b2` (torch.Tensor): The second vector. 23 | - `r` (torch.Tensor): The interpolation ratio. 24 | 25 | #### Returns: 26 | - `torch.Tensor`: The interpolated vector. 
27 | """ 28 | 29 | c = b1.shape[-1] 30 | 31 | # norms 32 | b1_norms = torch.norm(b1, dim=-1, keepdim=True) 33 | b2_norms = torch.norm(b2, dim=-1, keepdim=True) 34 | 35 | # normalize 36 | b1_normalized = b1 / b1_norms 37 | b2_normalized = b2 / b2_norms 38 | 39 | # zero when norms are zero 40 | b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0 41 | b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0 42 | 43 | # slerp 44 | dot = (b1_normalized * b2_normalized).sum(1) 45 | omega = torch.acos(dot) 46 | so = torch.sin(omega) 47 | 48 | # technically not mathematically correct, but more pleasing? 49 | res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze( 50 | 1 51 | ) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze( 52 | 1 53 | ) * b2_normalized 54 | res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c) 55 | 56 | # edge cases for same or polar opposites 57 | res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5] 58 | res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1] 59 | return res 60 | 61 | def generate_bilinear_data( 62 | length_old: int, length_new: int, device: torch.device 63 | ) -> List[torch.Tensor]: 64 | """#### Generate bilinear data for interpolation. 65 | 66 | #### Args: 67 | - `length_old` (int): The old length. 68 | - `length_new` (int): The new length. 69 | - `device` (torch.device): The device to use. 70 | 71 | #### Returns: 72 | - `torch.Tensor`: The ratios. 73 | - `torch.Tensor`: The first coordinates. 74 | - `torch.Tensor`: The second coordinates. 75 | """ 76 | coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape( 77 | (1, 1, 1, -1) 78 | ) 79 | coords_1 = torch.nn.functional.interpolate( 80 | coords_1, size=(1, length_new), mode="bilinear" 81 | ) 82 | ratios = coords_1 - coords_1.floor() 83 | coords_1 = coords_1.to(torch.int64) 84 | 85 | coords_2 = ( 86 | torch.arange(length_old, dtype=torch.float32, device=device).reshape( 87 | (1, 1, 1, -1) 88 | ) 89 | + 1 90 | ) 91 | coords_2[:, :, :, -1] -= 1 92 | coords_2 = torch.nn.functional.interpolate( 93 | coords_2, size=(1, length_new), mode="bilinear" 94 | ) 95 | coords_2 = coords_2.to(torch.int64) 96 | return ratios, coords_1, coords_2 97 | 98 | orig_dtype = samples.dtype 99 | samples = samples.float() 100 | n, c, h, w = samples.shape 101 | h_new, w_new = (height, width) 102 | 103 | # linear w 104 | ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device) 105 | coords_1 = coords_1.expand((n, c, h, -1)) 106 | coords_2 = coords_2.expand((n, c, h, -1)) 107 | ratios = ratios.expand((n, 1, h, -1)) 108 | 109 | pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c)) 110 | pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c)) 111 | ratios = ratios.movedim(1, -1).reshape((-1, 1)) 112 | 113 | result = slerp(pass_1, pass_2, ratios) 114 | result = result.reshape(n, h, w_new, c).movedim(-1, 1) 115 | 116 | # linear h 117 | ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device) 118 | coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new)) 119 | coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new)) 120 | ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new)) 121 | 122 | pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c)) 123 | pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c)) 124 | ratios = ratios.movedim(1, -1).reshape((-1, 1)) 125 | 126 | result = slerp(pass_1, pass_2, ratios) 127 | result = result.reshape(n, h_new, w_new, c).movedim(-1, 1) 
128 | return result.to(orig_dtype) 129 | 130 | 131 | def common_upscale(samples: List, width: int, height: int) -> torch.Tensor: 132 | """#### Upscales the given samples to the specified width and height using the specified method and crop settings. 133 | #### Args: 134 | - `samples` (list): The list of samples to be upscaled. 135 | - `width` (int): The target width for the upscaled samples. 136 | - `height` (int): The target height for the upscaled samples. 137 | #### Returns: 138 | - `torch.Tensor`: The upscaled samples. 139 | """ 140 | s = samples 141 | return bislerp(s, width, height) 142 | 143 | 144 | class LatentUpscale: 145 | """#### A class to upscale latent codes.""" 146 | 147 | def upscale(self, samples: dict, width: int, height: int) -> tuple: 148 | """#### Upscales the given latent codes. 149 | 150 | #### Args: 151 | - `samples` (dict): The latent codes to be upscaled. 152 | - `width` (int): The target width for the upscaled samples. 153 | - `height` (int): The target height for the upscaled samples. 154 | 155 | #### Returns: 156 | - `tuple`: The upscaled samples. 157 | """ 158 | if width == 0 and height == 0: 159 | s = samples 160 | else: 161 | s = samples.copy() 162 | width = max(64, width) 163 | height = max(64, height) 164 | 165 | s["samples"] = common_upscale(samples["samples"], width // 8, height // 8) 166 | return (s,) 167 | -------------------------------------------------------------------------------- /modules/WaveSpeed/fbcache_nodes.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import unittest 3 | import torch 4 | 5 | from . import first_block_cache 6 | 7 | 8 | class ApplyFBCacheOnModel: 9 | 10 | def patch( 11 | self, 12 | model, 13 | object_to_patch, 14 | residual_diff_threshold, 15 | max_consecutive_cache_hits=-1, 16 | start=0.0, 17 | end=1.0, 18 | ): 19 | if residual_diff_threshold <= 0.0 or max_consecutive_cache_hits == 0: 20 | return (model, ) 21 | 22 | # first_block_cache.patch_get_output_data() 23 | 24 | using_validation = max_consecutive_cache_hits >= 0 or start > 0 or end < 1 25 | if using_validation: 26 | model_sampling = model.get_model_object("model_sampling") 27 | start_sigma, end_sigma = (float( 28 | model_sampling.percent_to_sigma(pct)) for pct in (start, end)) 29 | del model_sampling 30 | 31 | @torch.compiler.disable() 32 | def validate_use_cache(use_cached): 33 | nonlocal consecutive_cache_hits 34 | use_cached = use_cached and end_sigma <= current_timestep <= start_sigma 35 | use_cached = use_cached and (max_consecutive_cache_hits < 0 36 | or consecutive_cache_hits 37 | < max_consecutive_cache_hits) 38 | consecutive_cache_hits = consecutive_cache_hits + 1 if use_cached else 0 39 | return use_cached 40 | else: 41 | validate_use_cache = None 42 | 43 | prev_timestep = None 44 | prev_input_state = None 45 | current_timestep = None 46 | consecutive_cache_hits = 0 47 | 48 | def reset_cache_state(): 49 | # Resets the cache state and hits/time tracking variables. 50 | nonlocal prev_input_state, prev_timestep, consecutive_cache_hits 51 | prev_input_state = prev_timestep = None 52 | consecutive_cache_hits = 0 53 | first_block_cache.set_current_cache_context( 54 | first_block_cache.create_cache_context()) 55 | 56 | def ensure_cache_state(model_input: torch.Tensor, timestep: float): 57 | # Validates the current cache state and hits/time tracking variables 58 | # and triggers a reset if necessary. Also updates current_timestep. 
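            # Sigmas decrease monotonically within a single sampling run, so a timestep
            # that is not lower than the previous one signals a new run and the residual
            # cache is rebuilt; a change in input shape, dtype or device likewise
            # invalidates the cached residuals.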
59 | nonlocal current_timestep 60 | input_state = (model_input.shape, model_input.dtype, model_input.device) 61 | need_reset = ( 62 | prev_timestep is None or 63 | prev_input_state != input_state or 64 | first_block_cache.get_current_cache_context() is None or 65 | timestep >= prev_timestep 66 | ) 67 | if need_reset: 68 | reset_cache_state() 69 | current_timestep = timestep 70 | 71 | def update_cache_state(model_input: torch.Tensor, timestep: float): 72 | # Updates the previous timestep and input state validation variables. 73 | nonlocal prev_timestep, prev_input_state 74 | prev_timestep = timestep 75 | prev_input_state = (model_input.shape, model_input.dtype, model_input.device) 76 | 77 | model = model[0].clone() 78 | diffusion_model = model.get_model_object(object_to_patch) 79 | 80 | if diffusion_model.__class__.__name__ in ("UNetModel", "Flux"): 81 | 82 | if diffusion_model.__class__.__name__ == "UNetModel": 83 | create_patch_function = first_block_cache.create_patch_unet_model__forward 84 | elif diffusion_model.__class__.__name__ == "Flux": 85 | create_patch_function = first_block_cache.create_patch_flux_forward_orig 86 | else: 87 | raise ValueError( 88 | f"Unsupported model {diffusion_model.__class__.__name__}") 89 | 90 | patch_forward = create_patch_function( 91 | diffusion_model, 92 | residual_diff_threshold=residual_diff_threshold, 93 | validate_can_use_cache_function=validate_use_cache, 94 | ) 95 | 96 | def model_unet_function_wrapper(model_function, kwargs): 97 | try: 98 | input = kwargs["input"] 99 | timestep = kwargs["timestep"] 100 | c = kwargs["c"] 101 | t = timestep[0].item() 102 | 103 | ensure_cache_state(input, t) 104 | 105 | with patch_forward(): 106 | result = model_function(input, timestep, **c) 107 | update_cache_state(input, t) 108 | return result 109 | except Exception as exc: 110 | reset_cache_state() 111 | raise exc from None 112 | else: 113 | is_non_native_ltxv = False 114 | if diffusion_model.__class__.__name__ == "LTXVTransformer3D": 115 | is_non_native_ltxv = True 116 | diffusion_model = diffusion_model.transformer 117 | 118 | double_blocks_name = None 119 | single_blocks_name = None 120 | if hasattr(diffusion_model, "transformer_blocks"): 121 | double_blocks_name = "transformer_blocks" 122 | elif hasattr(diffusion_model, "double_blocks"): 123 | double_blocks_name = "double_blocks" 124 | elif hasattr(diffusion_model, "joint_blocks"): 125 | double_blocks_name = "joint_blocks" 126 | else: 127 | raise ValueError( 128 | f"No double blocks found for {diffusion_model.__class__.__name__}" 129 | ) 130 | 131 | if hasattr(diffusion_model, "single_blocks"): 132 | single_blocks_name = "single_blocks" 133 | 134 | if is_non_native_ltxv: 135 | original_create_skip_layer_mask = getattr( 136 | diffusion_model, "create_skip_layer_mask", None) 137 | if original_create_skip_layer_mask is not None: 138 | # original_double_blocks = getattr(diffusion_model, 139 | # double_blocks_name) 140 | 141 | def new_create_skip_layer_mask(self, *args, **kwargs): 142 | # with unittest.mock.patch.object(self, double_blocks_name, 143 | # original_double_blocks): 144 | # return original_create_skip_layer_mask(*args, **kwargs) 145 | # return original_create_skip_layer_mask(*args, **kwargs) 146 | raise RuntimeError( 147 | "STG is not supported with FBCache yet") 148 | 149 | diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__( 150 | diffusion_model) 151 | 152 | cached_transformer_blocks = torch.nn.ModuleList([ 153 | first_block_cache.CachedTransformerBlocks( 154 | None if 
double_blocks_name is None else getattr( 155 | diffusion_model, double_blocks_name), 156 | None if single_blocks_name is None else getattr( 157 | diffusion_model, single_blocks_name), 158 | residual_diff_threshold=residual_diff_threshold, 159 | validate_can_use_cache_function=validate_use_cache, 160 | cat_hidden_states_first=diffusion_model.__class__.__name__ 161 | == "HunyuanVideo", 162 | return_hidden_states_only=diffusion_model.__class__. 163 | __name__ == "LTXVModel" or is_non_native_ltxv, 164 | clone_original_hidden_states=diffusion_model.__class__. 165 | __name__ == "LTXVModel", 166 | return_hidden_states_first=diffusion_model.__class__. 167 | __name__ != "OpenAISignatureMMDITWrapper", 168 | accept_hidden_states_first=diffusion_model.__class__. 169 | __name__ != "OpenAISignatureMMDITWrapper", 170 | ) 171 | ]) 172 | dummy_single_transformer_blocks = torch.nn.ModuleList() 173 | 174 | def model_unet_function_wrapper(model_function, kwargs): 175 | try: 176 | input = kwargs["input"] 177 | timestep = kwargs["timestep"] 178 | c = kwargs["c"] 179 | t = timestep[0].item() 180 | 181 | ensure_cache_state(input, t) 182 | 183 | with unittest.mock.patch.object( 184 | diffusion_model, 185 | double_blocks_name, 186 | cached_transformer_blocks, 187 | ), unittest.mock.patch.object( 188 | diffusion_model, 189 | single_blocks_name, 190 | dummy_single_transformer_blocks, 191 | ) if single_blocks_name is not None else contextlib.nullcontext( 192 | ): 193 | result = model_function(input, timestep, **c) 194 | update_cache_state(input, t) 195 | return result 196 | except Exception as exc: 197 | reset_cache_state() 198 | raise exc from None 199 | 200 | model.set_model_unet_function_wrapper(model_unet_function_wrapper) 201 | return (model, ) 202 | -------------------------------------------------------------------------------- /modules/WaveSpeed/misc_nodes.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import json 3 | 4 | from . 
import utils 5 | 6 | 7 | class EnhancedCompileModel: 8 | 9 | def patch( 10 | self, 11 | model, 12 | is_patcher, 13 | object_to_patch, 14 | compiler, 15 | fullgraph, 16 | dynamic, 17 | mode, 18 | options, 19 | disable, 20 | backend, 21 | ): 22 | utils.patch_optimized_module() 23 | utils.patch_same_meta() 24 | 25 | import_path, function_name = compiler.rsplit(".", 1) 26 | module = importlib.import_module(import_path) 27 | compile_function = getattr(module, function_name) 28 | 29 | mode = mode if mode else None 30 | options = json.loads(options) if options else None 31 | 32 | if compiler == "torch.compile" and backend == "inductor" and dynamic: 33 | # TODO: Fix this 34 | # File "pytorch/torch/_inductor/fx_passes/post_grad.py", line 643, in same_meta 35 | # and statically_known_true(sym_eq(val1.size(), val2.size())) 36 | # AttributeError: 'SymInt' object has no attribute 'size' 37 | pass 38 | 39 | if is_patcher: 40 | patcher = model[0].clone() 41 | else: 42 | patcher = model.patcher 43 | patcher = patcher.clone() 44 | 45 | patcher.add_object_patch( 46 | object_to_patch, 47 | compile_function( 48 | patcher.get_model_object(object_to_patch), 49 | fullgraph=fullgraph, 50 | dynamic=dynamic, 51 | mode=mode, 52 | options=options, 53 | disable=disable, 54 | backend=backend, 55 | ), 56 | ) 57 | 58 | if is_patcher: 59 | return (patcher,) 60 | else: 61 | model.patcher = patcher 62 | return (model,) 63 | -------------------------------------------------------------------------------- /modules/WaveSpeed/utils.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import unittest 3 | 4 | import torch 5 | 6 | 7 | 8 | 9 | # wildcard trick is taken from pythongossss's 10 | class AnyType(str): 11 | 12 | def __ne__(self, __value: object) -> bool: 13 | return False 14 | 15 | 16 | any_typ = AnyType("*") 17 | 18 | 19 | def get_weight_dtype_inputs(): 20 | return { 21 | "weight_dtype": ( 22 | [ 23 | "default", 24 | "float32", 25 | "float64", 26 | "bfloat16", 27 | "float16", 28 | "fp8_e4m3fn", 29 | "fp8_e4m3fn_fast", 30 | "fp8_e5m2", 31 | ], 32 | ), 33 | } 34 | 35 | 36 | def parse_weight_dtype(model_options, weight_dtype): 37 | dtype = { 38 | "float32": torch.float32, 39 | "float64": torch.float64, 40 | "bfloat16": torch.bfloat16, 41 | "float16": torch.float16, 42 | "fp8_e4m3fn": torch.float8_e4m3fn, 43 | "fp8_e4m3fn_fast": torch.float8_e4m3fn, 44 | "fp8_e5m2": torch.float8_e5m2, 45 | }.get(weight_dtype, None) 46 | if dtype is not None: 47 | model_options["dtype"] = dtype 48 | if weight_dtype == "fp8_e4m3fn_fast": 49 | model_options["fp8_optimizations"] = True 50 | return model_options 51 | 52 | 53 | @contextlib.contextmanager 54 | def disable_load_models_gpu(): 55 | def foo(*args, **kwargs): 56 | pass 57 | from modules.Device import Device 58 | with unittest.mock.patch.object(Device, "load_models_gpu", foo): 59 | yield 60 | 61 | 62 | def patch_optimized_module(): 63 | try: 64 | from torch._dynamo.eval_frame import OptimizedModule 65 | except ImportError: 66 | return 67 | 68 | if getattr(OptimizedModule, "_patched", False): 69 | return 70 | 71 | def __getattribute__(self, name): 72 | if name == "_orig_mod": 73 | return object.__getattribute__(self, "_modules")[name] 74 | if name in ( 75 | "__class__", 76 | "_modules", 77 | "state_dict", 78 | "load_state_dict", 79 | "parameters", 80 | "named_parameters", 81 | "buffers", 82 | "named_buffers", 83 | "children", 84 | "named_children", 85 | "modules", 86 | "named_modules", 87 | ): 88 | return 
getattr(object.__getattribute__(self, "_orig_mod"), name) 89 | return object.__getattribute__(self, name) 90 | 91 | def __delattr__(self, name): 92 | # unload_lora_weights() wants to del peft_config 93 | return delattr(self._orig_mod, name) 94 | 95 | @classmethod 96 | def __instancecheck__(cls, instance): 97 | return isinstance(instance, OptimizedModule) or issubclass( 98 | object.__getattribute__(instance, "__class__"), cls 99 | ) 100 | 101 | OptimizedModule.__getattribute__ = __getattribute__ 102 | OptimizedModule.__delattr__ = __delattr__ 103 | OptimizedModule.__instancecheck__ = __instancecheck__ 104 | OptimizedModule._patched = True 105 | 106 | 107 | def patch_same_meta(): 108 | try: 109 | from torch._inductor.fx_passes import post_grad 110 | except ImportError: 111 | return 112 | 113 | same_meta = getattr(post_grad, "same_meta", None) 114 | if same_meta is None: 115 | return 116 | 117 | if getattr(same_meta, "_patched", False): 118 | return 119 | 120 | def new_same_meta(a, b): 121 | try: 122 | return same_meta(a, b) 123 | except Exception: 124 | return False 125 | 126 | post_grad.same_meta = new_same_meta 127 | new_same_meta._patched = True 128 | -------------------------------------------------------------------------------- /modules/clip/CLIPTextModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class CLIPTextModel_(torch.nn.Module): 4 | """#### The CLIPTextModel_ module.""" 5 | def __init__( 6 | self, 7 | config_dict: dict, 8 | dtype: torch.dtype, 9 | device: torch.device, 10 | operations: object, 11 | ): 12 | """#### Initialize the CLIPTextModel_ module. 13 | 14 | #### Args: 15 | - `config_dict` (dict): The configuration dictionary. 16 | - `dtype` (torch.dtype): The data type. 17 | - `device` (torch.device): The device to use. 18 | - `operations` (object): The operations object. 19 | """ 20 | num_layers = config_dict["num_hidden_layers"] 21 | embed_dim = config_dict["hidden_size"] 22 | heads = config_dict["num_attention_heads"] 23 | intermediate_size = config_dict["intermediate_size"] 24 | intermediate_activation = config_dict["hidden_act"] 25 | num_positions = config_dict["max_position_embeddings"] 26 | self.eos_token_id = config_dict["eos_token_id"] 27 | 28 | super().__init__() 29 | from modules.clip.Clip import CLIPEmbeddings, CLIPEncoder 30 | self.embeddings = CLIPEmbeddings( 31 | embed_dim, 32 | num_positions=num_positions, 33 | dtype=dtype, 34 | device=device, 35 | operations=operations, 36 | ) 37 | self.encoder = CLIPEncoder( 38 | num_layers, 39 | embed_dim, 40 | heads, 41 | intermediate_size, 42 | intermediate_activation, 43 | dtype, 44 | device, 45 | operations, 46 | ) 47 | self.final_layer_norm = operations.LayerNorm( 48 | embed_dim, dtype=dtype, device=device 49 | ) 50 | 51 | def forward( 52 | self, 53 | input_tokens: torch.Tensor, 54 | attention_mask: torch.Tensor = None, 55 | intermediate_output: int = None, 56 | final_layer_norm_intermediate: bool = True, 57 | dtype: torch.dtype = torch.float32, 58 | ) -> tuple: 59 | """#### Forward pass for the CLIPTextModel_ module. 60 | 61 | #### Args: 62 | - `input_tokens` (torch.Tensor): The input tokens. 63 | - `attention_mask` (torch.Tensor, optional): The attention mask. Defaults to None. 64 | - `intermediate_output` (int, optional): The intermediate output layer. Defaults to None. 65 | - `final_layer_norm_intermediate` (bool, optional): Whether to apply final layer normalization to the intermediate output. Defaults to True. 
66 | 67 | #### Returns: 68 | - `tuple`: The output tensor, the intermediate output tensor, and the pooled output tensor. 69 | """ 70 | x = self.embeddings(input_tokens, dtype=dtype) 71 | mask = None 72 | if attention_mask is not None: 73 | mask = 1.0 - attention_mask.to(x.dtype).reshape( 74 | (attention_mask.shape[0], 1, -1, attention_mask.shape[-1]) 75 | ).expand( 76 | attention_mask.shape[0], 77 | 1, 78 | attention_mask.shape[-1], 79 | attention_mask.shape[-1], 80 | ) 81 | mask = mask.masked_fill(mask.to(torch.bool), float("-inf")) 82 | 83 | causal_mask = ( 84 | torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device) 85 | .fill_(float("-inf")) 86 | .triu_(1) 87 | ) 88 | if mask is not None: 89 | mask += causal_mask 90 | else: 91 | mask = causal_mask 92 | 93 | x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output) 94 | x = self.final_layer_norm(x) 95 | if i is not None and final_layer_norm_intermediate: 96 | i = self.final_layer_norm(i) 97 | 98 | pooled_output = x[ 99 | torch.arange(x.shape[0], device=x.device), 100 | ( 101 | torch.round(input_tokens).to(dtype=torch.int, device=x.device) 102 | == self.eos_token_id 103 | ) 104 | .int() 105 | .argmax(dim=-1), 106 | ] 107 | return x, i, pooled_output 108 | 109 | class CLIPTextModel(torch.nn.Module): 110 | """#### The CLIPTextModel module.""" 111 | def __init__( 112 | self, 113 | config_dict: dict, 114 | dtype: torch.dtype, 115 | device: torch.device, 116 | operations: object, 117 | ): 118 | """#### Initialize the CLIPTextModel module. 119 | 120 | #### Args: 121 | - `config_dict` (dict): The configuration dictionary. 122 | - `dtype` (torch.dtype): The data type. 123 | - `device` (torch.device): The device to use. 124 | - `operations` (object): The operations object. 125 | """ 126 | super().__init__() 127 | self.num_layers = config_dict["num_hidden_layers"] 128 | self.text_model = CLIPTextModel_(config_dict, dtype, device, operations) 129 | embed_dim = config_dict["hidden_size"] 130 | self.text_projection = operations.Linear( 131 | embed_dim, embed_dim, bias=False, dtype=dtype, device=device 132 | ) 133 | self.dtype = dtype 134 | 135 | def get_input_embeddings(self) -> torch.nn.Embedding: 136 | """#### Get the input embeddings. 137 | 138 | #### Returns: 139 | - `torch.nn.Embedding`: The input embeddings. 140 | """ 141 | return self.text_model.embeddings.token_embedding 142 | 143 | def set_input_embeddings(self, embeddings: torch.nn.Embedding) -> None: 144 | """#### Set the input embeddings. 145 | 146 | #### Args: 147 | - `embeddings` (torch.nn.Embedding): The input embeddings. 148 | """ 149 | self.text_model.embeddings.token_embedding = embeddings 150 | 151 | def forward(self, *args, **kwargs) -> tuple: 152 | """#### Forward pass for the CLIPTextModel module. 153 | 154 | #### Args: 155 | - `*args`: Variable length argument list. 156 | - `**kwargs`: Arbitrary keyword arguments. 157 | 158 | #### Returns: 159 | - `tuple`: The output tensors. 
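        The tuple is `(last_hidden_state, intermediate_output, projected_pooled_output,
        pooled_output)`, i.e. the pooled output is additionally passed through
        `text_projection` and inserted as the third element.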
160 | """ 161 | x = self.text_model(*args, **kwargs) 162 | out = self.text_projection(x[2]) 163 | return (x[0], x[1], out, x[2]) -------------------------------------------------------------------------------- /modules/clip/clip/hydit_clip.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "hfl/chinese-roberta-wwm-ext-large", 3 | "architectures": [ 4 | "BertModel" 5 | ], 6 | "attention_probs_dropout_prob": 0.1, 7 | "bos_token_id": 0, 8 | "classifier_dropout": null, 9 | "directionality": "bidi", 10 | "eos_token_id": 2, 11 | "hidden_act": "gelu", 12 | "hidden_dropout_prob": 0.1, 13 | "hidden_size": 1024, 14 | "initializer_range": 0.02, 15 | "intermediate_size": 4096, 16 | "layer_norm_eps": 1e-12, 17 | "max_position_embeddings": 512, 18 | "model_type": "bert", 19 | "num_attention_heads": 16, 20 | "num_hidden_layers": 24, 21 | "output_past": true, 22 | "pad_token_id": 0, 23 | "pooler_fc_size": 768, 24 | "pooler_num_attention_heads": 12, 25 | "pooler_num_fc_layers": 3, 26 | "pooler_size_per_head": 128, 27 | "pooler_type": "first_token_transform", 28 | "position_embedding_type": "absolute", 29 | "torch_dtype": "float32", 30 | "transformers_version": "4.22.1", 31 | "type_vocab_size": 2, 32 | "use_cache": true, 33 | "vocab_size": 47020 34 | } 35 | 36 | -------------------------------------------------------------------------------- /modules/clip/clip/long_clipl.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "openai/clip-vit-large-patch14", 3 | "architectures": [ 4 | "CLIPTextModel" 5 | ], 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 0, 8 | "dropout": 0.0, 9 | "eos_token_id": 49407, 10 | "hidden_act": "quick_gelu", 11 | "hidden_size": 768, 12 | "initializer_factor": 1.0, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 3072, 15 | "layer_norm_eps": 1e-05, 16 | "max_position_embeddings": 248, 17 | "model_type": "clip_text_model", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 1, 21 | "projection_dim": 768, 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.24.0", 24 | "vocab_size": 49408 25 | } 26 | -------------------------------------------------------------------------------- /modules/clip/clip/mt5_config_xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_ff": 5120, 3 | "d_kv": 64, 4 | "d_model": 2048, 5 | "decoder_start_token_id": 0, 6 | "dropout_rate": 0.1, 7 | "eos_token_id": 1, 8 | "dense_act_fn": "gelu_pytorch_tanh", 9 | "initializer_factor": 1.0, 10 | "is_encoder_decoder": true, 11 | "is_gated_act": true, 12 | "layer_norm_epsilon": 1e-06, 13 | "model_type": "mt5", 14 | "num_decoder_layers": 24, 15 | "num_heads": 32, 16 | "num_layers": 24, 17 | "output_past": true, 18 | "pad_token_id": 0, 19 | "relative_attention_num_buckets": 32, 20 | "tie_word_embeddings": false, 21 | "vocab_size": 250112 22 | } 23 | -------------------------------------------------------------------------------- /modules/clip/clip/sd1_clip_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "openai/clip-vit-large-patch14", 3 | "architectures": [ 4 | "CLIPTextModel" 5 | ], 6 | "attention_dropout": 0.0, 7 | "bos_token_id": 0, 8 | "dropout": 0.0, 9 | "eos_token_id": 49407, 10 | "hidden_act": "quick_gelu", 11 | "hidden_size": 768, 12 | "initializer_factor": 1.0, 13 | "initializer_range": 0.02, 14 | "intermediate_size": 
3072, 15 | "layer_norm_eps": 1e-05, 16 | "max_position_embeddings": 77, 17 | "model_type": "clip_text_model", 18 | "num_attention_heads": 12, 19 | "num_hidden_layers": 12, 20 | "pad_token_id": 1, 21 | "projection_dim": 768, 22 | "torch_dtype": "float32", 23 | "transformers_version": "4.24.0", 24 | "vocab_size": 49408 25 | } 26 | -------------------------------------------------------------------------------- /modules/clip/clip/sd2_clip_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "CLIPTextModel" 4 | ], 5 | "attention_dropout": 0.0, 6 | "bos_token_id": 0, 7 | "dropout": 0.0, 8 | "eos_token_id": 49407, 9 | "hidden_act": "gelu", 10 | "hidden_size": 1024, 11 | "initializer_factor": 1.0, 12 | "initializer_range": 0.02, 13 | "intermediate_size": 4096, 14 | "layer_norm_eps": 1e-05, 15 | "max_position_embeddings": 77, 16 | "model_type": "clip_text_model", 17 | "num_attention_heads": 16, 18 | "num_hidden_layers": 24, 19 | "pad_token_id": 1, 20 | "projection_dim": 1024, 21 | "torch_dtype": "float32", 22 | "vocab_size": 49408 23 | } 24 | -------------------------------------------------------------------------------- /modules/clip/clip/t5_config_base.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_ff": 3072, 3 | "d_kv": 64, 4 | "d_model": 768, 5 | "decoder_start_token_id": 0, 6 | "dropout_rate": 0.1, 7 | "eos_token_id": 1, 8 | "dense_act_fn": "relu", 9 | "initializer_factor": 1.0, 10 | "is_encoder_decoder": true, 11 | "is_gated_act": false, 12 | "layer_norm_epsilon": 1e-06, 13 | "model_type": "t5", 14 | "num_decoder_layers": 12, 15 | "num_heads": 12, 16 | "num_layers": 12, 17 | "output_past": true, 18 | "pad_token_id": 0, 19 | "relative_attention_num_buckets": 32, 20 | "tie_word_embeddings": false, 21 | "vocab_size": 32128 22 | } 23 | -------------------------------------------------------------------------------- /modules/clip/clip/t5_config_xxl.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_ff": 10240, 3 | "d_kv": 64, 4 | "d_model": 4096, 5 | "decoder_start_token_id": 0, 6 | "dropout_rate": 0.1, 7 | "eos_token_id": 1, 8 | "dense_act_fn": "gelu_pytorch_tanh", 9 | "initializer_factor": 1.0, 10 | "is_encoder_decoder": true, 11 | "is_gated_act": true, 12 | "layer_norm_epsilon": 1e-06, 13 | "model_type": "t5", 14 | "num_decoder_layers": 24, 15 | "num_heads": 64, 16 | "num_layers": 24, 17 | "output_past": true, 18 | "pad_token_id": 0, 19 | "relative_attention_num_buckets": 32, 20 | "tie_word_embeddings": false, 21 | "vocab_size": 32128 22 | } 23 | -------------------------------------------------------------------------------- /modules/clip/clip/t5_pile_config_xl.json: -------------------------------------------------------------------------------- 1 | { 2 | "d_ff": 5120, 3 | "d_kv": 64, 4 | "d_model": 2048, 5 | "decoder_start_token_id": 0, 6 | "dropout_rate": 0.1, 7 | "eos_token_id": 2, 8 | "dense_act_fn": "gelu_pytorch_tanh", 9 | "initializer_factor": 1.0, 10 | "is_encoder_decoder": true, 11 | "is_gated_act": true, 12 | "layer_norm_epsilon": 1e-06, 13 | "model_type": "umt5", 14 | "num_decoder_layers": 24, 15 | "num_heads": 32, 16 | "num_layers": 24, 17 | "output_past": true, 18 | "pad_token_id": 1, 19 | "relative_attention_num_buckets": 32, 20 | "tie_word_embeddings": false, 21 | "vocab_size": 32128 22 | } 23 | -------------------------------------------------------------------------------- 
/modules/clip/clip/t5_tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 |   "additional_special_tokens": [
3 |     "<extra_id_0>",
4 |     "<extra_id_1>",
5 |     "<extra_id_2>",
6 |     "<extra_id_3>",
7 |     "<extra_id_4>",
8 |     "<extra_id_5>",
9 |     "<extra_id_6>",
10 |     "<extra_id_7>",
11 |     "<extra_id_8>",
12 |     "<extra_id_9>",
13 |     "<extra_id_10>",
14 |     "<extra_id_11>",
15 |     "<extra_id_12>",
16 |     "<extra_id_13>",
17 |     "<extra_id_14>",
18 |     "<extra_id_15>",
19 |     "<extra_id_16>",
20 |     "<extra_id_17>",
21 |     "<extra_id_18>",
22 |     "<extra_id_19>",
23 |     "<extra_id_20>",
24 |     "<extra_id_21>",
25 |     "<extra_id_22>",
26 |     "<extra_id_23>",
27 |     "<extra_id_24>",
28 |     "<extra_id_25>",
29 |     "<extra_id_26>",
30 |     "<extra_id_27>",
31 |     "<extra_id_28>",
32 |     "<extra_id_29>",
33 |     "<extra_id_30>",
34 |     "<extra_id_31>",
35 |     "<extra_id_32>",
36 |     "<extra_id_33>",
37 |     "<extra_id_34>",
38 |     "<extra_id_35>",
39 |     "<extra_id_36>",
40 |     "<extra_id_37>",
41 |     "<extra_id_38>",
42 |     "<extra_id_39>",
43 |     "<extra_id_40>",
44 |     "<extra_id_41>",
45 |     "<extra_id_42>",
46 |     "<extra_id_43>",
47 |     "<extra_id_44>",
48 |     "<extra_id_45>",
49 |     "<extra_id_46>",
50 |     "<extra_id_47>",
51 |     "<extra_id_48>",
52 |     "<extra_id_49>",
53 |     "<extra_id_50>",
54 |     "<extra_id_51>",
55 |     "<extra_id_52>",
56 |     "<extra_id_53>",
57 |     "<extra_id_54>",
58 |     "<extra_id_55>",
59 |     "<extra_id_56>",
60 |     "<extra_id_57>",
61 |     "<extra_id_58>",
62 |     "<extra_id_59>",
63 |     "<extra_id_60>",
64 |     "<extra_id_61>",
65 |     "<extra_id_62>",
66 |     "<extra_id_63>",
67 |     "<extra_id_64>",
68 |     "<extra_id_65>",
69 |     "<extra_id_66>",
70 |     "<extra_id_67>",
71 |     "<extra_id_68>",
72 |     "<extra_id_69>",
73 |     "<extra_id_70>",
74 |     "<extra_id_71>",
75 |     "<extra_id_72>",
76 |     "<extra_id_73>",
77 |     "<extra_id_74>",
78 |     "<extra_id_75>",
79 |     "<extra_id_76>",
80 |     "<extra_id_77>",
81 |     "<extra_id_78>",
82 |     "<extra_id_79>",
83 |     "<extra_id_80>",
84 |     "<extra_id_81>",
85 |     "<extra_id_82>",
86 |     "<extra_id_83>",
87 |     "<extra_id_84>",
88 |     "<extra_id_85>",
89 |     "<extra_id_86>",
90 |     "<extra_id_87>",
91 |     "<extra_id_88>",
92 |     "<extra_id_89>",
93 |     "<extra_id_90>",
94 |     "<extra_id_91>",
95 |     "<extra_id_92>",
96 |     "<extra_id_93>",
97 |     "<extra_id_94>",
98 |     "<extra_id_95>",
99 |     "<extra_id_96>",
100 |     "<extra_id_97>",
101 |     "<extra_id_98>",
102 |     "<extra_id_99>"
103 |   ],
104 |   "eos_token": {
105 |     "content": "</s>",
106 |     "lstrip": false,
107 |     "normalized": false,
108 |     "rstrip": false,
109 |     "single_word": false
110 |   },
111 |   "pad_token": {
112 |     "content": "<pad>",
113 |     "lstrip": false,
114 |     "normalized": false,
115 |     "rstrip": false,
116 |     "single_word": false
117 |   },
118 |   "unk_token": {
119 |     "content": "<unk>",
120 |     "lstrip": false,
121 |     "normalized": false,
122 |     "rstrip": false,
123 |     "single_word": false
124 |   }
125 | }
126 | 
--------------------------------------------------------------------------------
/modules/cond/Activation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from modules.cond import cast
4 | 
5 | 
6 | class GEGLU(nn.Module):
7 |     """#### Class representing the GEGLU activation function.
8 | 
9 |     GEGLU is a gated linear unit (GLU) variant: a linear projection is split in two,
10 |     and one half, passed through GELU, gates the other half.
11 | 
12 |     #### Args:
13 |     - `dim_in` (int): The input dimension.
14 |     - `dim_out` (int): The output dimension.
15 |     """
16 | 
17 |     def __init__(self, dim_in: int, dim_out: int):
18 |         super().__init__()
19 |         self.proj = cast.manual_cast.Linear(dim_in, dim_out * 2)
20 | 
21 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
22 |         """#### Forward pass for the GEGLU activation function.
23 | 
24 |         #### Args:
25 |         - `x` (torch.Tensor): The input tensor.
26 | 
27 |         #### Returns:
28 |         - `torch.Tensor`: The output tensor.
29 |         """
30 |         x, gate = self.proj(x).chunk(2, dim=-1)
31 |         return x * torch.nn.functional.gelu(gate)
32 | 
--------------------------------------------------------------------------------
/modules/cond/cond_util.py:
--------------------------------------------------------------------------------
1 | from modules.Device import Device
2 | import torch
3 | from typing import List, Tuple, Any
4 | 
5 | 
6 | def get_models_from_cond(cond: dict, model_type: str) -> List[object]:
7 |     """#### Get models from a condition.
8 | 
9 |     #### Args:
10 |     - `cond` (dict): The condition.
11 |     - `model_type` (str): The model type.
12 | 
13 |     #### Returns:
14 |     - `List[object]`: The list of models.
15 |     """
16 |     models = []
17 |     for c in cond:
18 |         if model_type in c:
19 |             models += [c[model_type]]
20 |     return models
21 | 
22 | 
23 | def get_additional_models(conds: dict, dtype: torch.dtype) -> Tuple[List[object], int]:
24 |     """#### Load additional models in conditioning.
25 | 26 | #### Args: 27 | - `conds` (dict): The conditions. 28 | - `dtype` (torch.dtype): The data type. 29 | 30 | #### Returns: 31 | - `Tuple[List[object], int]`: The list of models and the inference memory. 32 | """ 33 | cnets = [] 34 | gligen = [] 35 | 36 | for k in conds: 37 | cnets += get_models_from_cond(conds[k], "control") 38 | gligen += get_models_from_cond(conds[k], "gligen") 39 | 40 | control_nets = set(cnets) 41 | 42 | inference_memory = 0 43 | control_models = [] 44 | for m in control_nets: 45 | control_models += m.get_models() 46 | inference_memory += m.inference_memory_requirements(dtype) 47 | 48 | gligen = [x[1] for x in gligen] 49 | models = control_models + gligen 50 | return models, inference_memory 51 | 52 | 53 | def prepare_sampling( 54 | model: object, noise_shape: Tuple[int], conds: dict, flux_enabled: bool = False 55 | ) -> Tuple[object, dict, List[object]]: 56 | """#### Prepare the model for sampling. 57 | 58 | #### Args: 59 | - `model` (object): The model. 60 | - `noise_shape` (Tuple[int]): The shape of the noise. 61 | - `conds` (dict): The conditions. 62 | - `flux_enabled` (bool, optional): Whether flux is enabled. Defaults to False. 63 | 64 | #### Returns: 65 | - `Tuple[object, dict, List[object]]`: The prepared model, conditions, and additional models. 66 | """ 67 | real_model = None 68 | models, inference_memory = get_additional_models(conds, model.model_dtype()) 69 | memory_required = ( 70 | model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:])) 71 | + inference_memory 72 | ) 73 | minimum_memory_required = ( 74 | model.memory_required([noise_shape[0]] + list(noise_shape[1:])) 75 | + inference_memory 76 | ) 77 | Device.load_models_gpu( 78 | [model] + models, 79 | memory_required=memory_required, 80 | minimum_memory_required=minimum_memory_required, 81 | flux_enabled=flux_enabled, 82 | ) 83 | real_model = model.model 84 | 85 | return real_model, conds, models 86 | 87 | def cleanup_additional_models(models: List[object]) -> None: 88 | """#### Clean up additional models. 89 | 90 | #### Args: 91 | - `models` (List[object]): The list of models. 92 | """ 93 | for m in models: 94 | if hasattr(m, "cleanup"): 95 | m.cleanup() 96 | 97 | def cleanup_models(conds: dict, models: List[object]) -> None: 98 | """#### Clean up the models after sampling. 99 | 100 | #### Args: 101 | - `conds` (dict): The conditions. 102 | - `models` (List[object]): The list of models. 103 | """ 104 | cleanup_additional_models(models) 105 | 106 | control_cleanup = [] 107 | for k in conds: 108 | control_cleanup += get_models_from_cond(conds[k], "control") 109 | 110 | cleanup_additional_models(set(control_cleanup)) 111 | 112 | 113 | def cond_equal_size(c1: Any, c2: Any) -> bool: 114 | """#### Check if two conditions have equal size. 115 | 116 | #### Args: 117 | - `c1` (Any): The first condition. 118 | - `c2` (Any): The second condition. 119 | 120 | #### Returns: 121 | - `bool`: Whether the conditions have equal size. 122 | """ 123 | if c1 is c2: 124 | return True 125 | if c1.keys() != c2.keys(): 126 | return False 127 | return True 128 | 129 | 130 | def can_concat_cond(c1: Any, c2: Any) -> bool: 131 | """#### Check if two conditions can be concatenated. 132 | 133 | #### Args: 134 | - `c1` (Any): The first condition. 135 | - `c2` (Any): The second condition. 136 | 137 | #### Returns: 138 | - `bool`: Whether the conditions can be concatenated. 
139 | """ 140 | if c1.input_x.shape != c2.input_x.shape: 141 | return False 142 | 143 | def objects_concatable(obj1, obj2): 144 | """#### Check if two objects can be concatenated.""" 145 | if (obj1 is None) != (obj2 is None): 146 | return False 147 | if obj1 is not None: 148 | if obj1 is not obj2: 149 | return False 150 | return True 151 | 152 | if not objects_concatable(c1.control, c2.control): 153 | return False 154 | 155 | if not objects_concatable(c1.patches, c2.patches): 156 | return False 157 | 158 | return cond_equal_size(c1.conditioning, c2.conditioning) 159 | 160 | 161 | def cond_cat(c_list: List[dict]) -> dict: 162 | """#### Concatenate a list of conditions. 163 | 164 | #### Args: 165 | - `c_list` (List[dict]): The list of conditions. 166 | 167 | #### Returns: 168 | - `dict`: The concatenated conditions. 169 | """ 170 | temp = {} 171 | for x in c_list: 172 | for k in x: 173 | cur = temp.get(k, []) 174 | cur.append(x[k]) 175 | temp[k] = cur 176 | 177 | out = {} 178 | for k in temp: 179 | conds = temp[k] 180 | out[k] = conds[0].concat(conds[1:]) 181 | 182 | return out 183 | 184 | 185 | def create_cond_with_same_area_if_none(conds: List[dict], c: dict) -> None: 186 | """#### Create a condition with the same area if none exists. 187 | 188 | #### Args: 189 | - `conds` (List[dict]): The list of conditions. 190 | - `c` (dict): The condition. 191 | """ 192 | if "area" not in c: 193 | return 194 | 195 | c_area = c["area"] 196 | smallest = None 197 | for x in conds: 198 | if "area" in x: 199 | a = x["area"] 200 | if c_area[2] >= a[2] and c_area[3] >= a[3]: 201 | if a[0] + a[2] >= c_area[0] + c_area[2]: 202 | if a[1] + a[3] >= c_area[1] + c_area[3]: 203 | if smallest is None: 204 | smallest = x 205 | elif "area" not in smallest: 206 | smallest = x 207 | else: 208 | if smallest["area"][0] * smallest["area"][1] > a[0] * a[1]: 209 | smallest = x 210 | else: 211 | if smallest is None: 212 | smallest = x 213 | if smallest is None: 214 | return 215 | if "area" in smallest: 216 | if smallest["area"] == c_area: 217 | return 218 | 219 | out = c.copy() 220 | out["model_conds"] = smallest[ 221 | "model_conds" 222 | ].copy() 223 | conds += [out] 224 | -------------------------------------------------------------------------------- /modules/tests/test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from modules.user.pipeline import pipeline 4 | 5 | class TestPipeline(unittest.TestCase): 6 | def setUp(self): 7 | self.test_prompt = "a cute cat, high quality, detailed" 8 | self.test_img_path = "../_internal/Flux_00001.png" # Make sure this test image exists 9 | 10 | def test_basic_generation_small(self): 11 | pipeline(self.test_prompt, 128, 128, number=1) 12 | # Check if output files exist 13 | 14 | def test_basic_generation_medium(self): 15 | pipeline(self.test_prompt, 512, 512, number=1) 16 | 17 | def test_basic_generation_large(self): 18 | pipeline(self.test_prompt, 1024, 1024, number=1) 19 | 20 | def test_hires_fix(self): 21 | pipeline(self.test_prompt, 512, 512, number=1, hires_fix=True) 22 | 23 | def test_adetailer(self): 24 | pipeline( 25 | "a portrait of a person, high quality", 26 | 512, 27 | 512, 28 | number=1, 29 | adetailer=True 30 | ) 31 | 32 | def test_enhance_prompt(self): 33 | pipeline( 34 | self.test_prompt, 35 | 512, 36 | 512, 37 | number=1, 38 | enhance_prompt=True 39 | ) 40 | 41 | def test_img2img(self): 42 | # Skip if test image doesn't exist 43 | if not os.path.exists(self.test_img_path): 44 | 
self.skipTest("Test image not found") 45 | 46 | pipeline( 47 | self.test_img_path, 48 | 512, 49 | 512, 50 | number=1, 51 | img2img=True 52 | ) 53 | 54 | def test_stable_fast(self): 55 | resolutions = [(128, 128), (512, 512), (1024, 1024)] 56 | for w, h in resolutions: 57 | pipeline( 58 | self.test_prompt, 59 | w, 60 | h, 61 | number=1, 62 | stable_fast=True 63 | ) 64 | 65 | def test_reuse_seed(self): 66 | pipeline( 67 | self.test_prompt, 68 | 512, 69 | 512, 70 | number=2, 71 | reuse_seed=True 72 | ) 73 | 74 | def test_flux_mode(self): 75 | resolutions = [(128, 128), (512, 512), (1024, 1024)] 76 | for w, h in resolutions: 77 | pipeline( 78 | self.test_prompt, 79 | w, 80 | h, 81 | number=1, 82 | flux_enabled=True 83 | ) 84 | 85 | if __name__ == '__main__': 86 | unittest.main() -------------------------------------------------------------------------------- /modules/user/app_instance.py: -------------------------------------------------------------------------------- 1 | from modules.user.GUI import App 2 | 3 | app = App() 4 | -------------------------------------------------------------------------------- /pipeline.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | SET VENV_DIR=.venv 3 | 4 | REM Check if .venv exists 5 | IF NOT EXIST %VENV_DIR% ( 6 | echo Creating virtual environment... 7 | python -m venv %VENV_DIR% 8 | ) 9 | 10 | REM Activate the virtual environment 11 | CALL %VENV_DIR%\Scripts\activate 12 | 13 | REM Upgrade pip 14 | echo Upgrading pip... 15 | python -m pip install --upgrade pip 16 | 17 | REM Install specific packages 18 | echo Installing required packages... 19 | pip install uv 20 | 21 | REM Check for NVIDIA GPU 22 | FOR /F "delims=" %%i IN ('nvidia-smi 2^>^&1') DO ( 23 | SET GPU_CHECK=%%i 24 | ) 25 | IF NOT ERRORLEVEL 1 ( 26 | echo NVIDIA GPU detected, installing GPU dependencies... 27 | uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu126 28 | ) ELSE ( 29 | echo No NVIDIA GPU detected, installing CPU dependencies... 30 | uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 31 | ) 32 | 33 | uv pip install "numpy>=1.24.3" 34 | 35 | REM Install additional requirements 36 | IF EXIST requirements.txt ( 37 | echo Installing additional requirements... 38 | uv pip install -r requirements.txt 39 | ) ELSE ( 40 | echo requirements.txt not found, skipping... 41 | ) 42 | 43 | REM Check for enhance-prompt argument 44 | echo Checking for enhance-prompt argument... 45 | echo %* | findstr /i /c:"--enhance-prompt" >nul 46 | IF %ERRORLEVEL% EQU 0 ( 47 | echo Installing ollama with winget... 48 | winget install --id ollama.ollama 49 | ollama pull deepseek-r1 50 | ) 51 | 52 | REM Launch the script 53 | echo Launching LightDiffusion... 54 | python .\modules\user\pipeline.py %* 55 | 56 | REM Deactivate the virtual environment 57 | deactivate 58 | -------------------------------------------------------------------------------- /pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | VENV_DIR=.venv 4 | # for WSL2 Ubuntu install 5 | # sudo apt install software-properties-common 6 | # sudo add-apt-repository ppa:deadsnakes/ppa 7 | sudo apt-get install python3.10 python3.10-venv python3.10-full python3-pip 8 | 9 | # Check if .venv exists 10 | if [ ! -d "$VENV_DIR" ]; then 11 | echo "Creating virtual environment..." 
12 |     python3.10 -m venv $VENV_DIR
13 | fi
14 | 
15 | # Activate the virtual environment
16 | source $VENV_DIR/bin/activate
17 | 
18 | # Upgrade pip
19 | echo "Upgrading pip..."
20 | pip install --upgrade pip
21 | pip3 install uv
22 | 
23 | # Check GPU type
24 | TORCH_URL="https://download.pytorch.org/whl/cpu"
25 | if command -v nvidia-smi &> /dev/null; then
26 |     echo "NVIDIA GPU detected"
27 |     TORCH_URL="https://download.pytorch.org/whl/cu121"
28 |     uv pip install --index-url $TORCH_URL \
29 |         torch==2.2.2 torchvision "xformers>=0.0.22" "triton>=2.1.0" \
30 |         stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl
31 | elif command -v rocminfo &> /dev/null; then
32 |     echo "AMD GPU detected"
33 |     TORCH_URL="https://download.pytorch.org/whl/rocm5.7"
34 |     uv pip install --index-url $TORCH_URL \
35 |         torch==2.2.2 torchvision "triton>=2.1.0"
36 | else
37 |     echo "No compatible GPU detected, using CPU"
38 |     uv pip install --index-url $TORCH_URL \
39 |         torch==2.2.2+cpu torchvision
40 | fi
41 | 
42 | uv pip install "numpy<2.0.0"
43 | 
44 | # Install tkinter
45 | echo "Installing tkinter..."
46 | sudo apt-get install python3.10-tk
47 | 
48 | # Install additional requirements
49 | if [ -f requirements.txt ]; then
50 |     echo "Installing additional requirements..."
51 |     uv pip install -r requirements.txt
52 | else
53 |     echo "requirements.txt not found, skipping..."
54 | fi
55 | 
56 | # Check for enhance-prompt argument
57 | echo "Checking for enhance-prompt argument..."
58 | if [[ " $* " == *" --enhance-prompt "* ]]; then
59 |     echo "Installing ollama..."
60 |     curl -fsSL https://ollama.com/install.sh | sh
61 |     ollama pull deepseek-r1
62 | fi
63 | 
64 | # Launch the script
65 | echo "Launching LightDiffusion..."
66 | python3.10 "./modules/user/pipeline.py" "$@"
67 | 
68 | # Deactivate the virtual environment
69 | deactivate
70 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/requirements.txt
--------------------------------------------------------------------------------
/run.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | SET VENV_DIR=.venv
3 | 
4 | REM Check if .venv exists
5 | IF NOT EXIST %VENV_DIR% (
6 |     echo Creating virtual environment...
7 |     python -m venv %VENV_DIR%
8 | )
9 | 
10 | REM Activate the virtual environment
11 | CALL %VENV_DIR%\Scripts\activate
12 | 
13 | REM Upgrade pip
14 | echo Upgrading pip...
15 | python -m pip install --upgrade pip
16 | 
17 | REM Install specific packages
18 | echo Installing required packages...
19 | pip install uv
20 | 
21 | REM Check for NVIDIA GPU
22 | FOR /F "delims=" %%i IN ('nvidia-smi 2^>^&1') DO (
23 |     SET GPU_CHECK=%%i
24 | )
25 | IF NOT ERRORLEVEL 1 (
26 |     echo NVIDIA GPU detected, installing GPU dependencies...
27 |     uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu126
28 | ) ELSE (
29 |     echo No NVIDIA GPU detected, installing CPU dependencies...
30 |     uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
31 | )
32 | 
33 | uv pip install "numpy>=1.24.3"
34 | 
35 | REM Install additional requirements
36 | IF EXIST requirements.txt (
37 |     echo Installing additional requirements...
38 |     uv pip install -r requirements.txt
39 | ) ELSE (
40 |     echo requirements.txt not found, skipping...
41 | ) 42 | 43 | REM Launch the script 44 | echo Launching LightDiffusion... 45 | python .\modules\user\GUI.py 46 | 47 | REM Deactivate the virtual environment 48 | deactivate 49 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | VENV_DIR=.venv 4 | # for WSL2 Ubuntu install 5 | # sudo apt install software-properties-common 6 | # sudo add-apt-repository ppa:deadsnakes/ppa 7 | sudo apt-get install python3.10 python3.10-venv python3.10-full python3-pip 8 | 9 | # Check if .venv exists 10 | if [ ! -d "$VENV_DIR" ]; then 11 | echo "Creating virtual environment..." 12 | python3.10 -m venv $VENV_DIR 13 | fi 14 | 15 | # Activate the virtual environment 16 | source $VENV_DIR/bin/activate 17 | 18 | # Upgrade pip 19 | echo "Upgrading pip..." 20 | pip install --upgrade pip 21 | pip3 install uv 22 | 23 | # Check GPU type 24 | TORCH_URL="https://download.pytorch.org/whl/cpu" 25 | if command -v nvidia-smi &> /dev/null; then 26 | echo "NVIDIA GPU detected" 27 | TORCH_URL="https://download.pytorch.org/whl/cu121" 28 | uv pip install --index-url $TORCH_URL \ 29 | torch==2.2.2 torchvision "xformers>=0.0.22" "triton>=2.1.0" \ 30 | stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl 31 | elif command -v rocminfo &> /dev/null; then 32 | echo "AMD GPU detected" 33 | TORCH_URL="https://download.pytorch.org/whl/rocm5.7" 34 | uv pip install --index-url $TORCH_URL \ 35 | torch==2.2.2 torchvision "triton>=2.1.0" 36 | else 37 | echo "No compatible GPU detected, using CPU" 38 | uv pip install --index-url $TORCH_URL \ 39 | torch==2.2.2+cpu torchvision 40 | fi 41 | 42 | uv pip install "numpy<2.0.0" 43 | 44 | # Install tkinter 45 | echo "Installing tkinter..." 46 | sudo apt-get install python3.10-tk 47 | 48 | # Install additional requirements 49 | if [ -f requirements.txt ]; then 50 | echo "Installing additional requirements..." 51 | uv pip install -r requirements.txt 52 | else 53 | echo "requirements.txt not found, skipping..." 54 | fi 55 | 56 | # Launch the script 57 | echo "Launching LightDiffusion..." 58 | python3.10 ./modules/user/GUI.py 59 | 60 | # Deactivate the virtual environment 61 | deactivate 62 | -------------------------------------------------------------------------------- /run_web.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | SET VENV_DIR=.venv 3 | 4 | REM Check if .venv exists 5 | IF NOT EXIST %VENV_DIR% ( 6 | echo Creating virtual environment... 7 | python -m venv %VENV_DIR% 8 | ) 9 | 10 | REM Activate the virtual environment 11 | CALL %VENV_DIR%\Scripts\activate 12 | 13 | REM Upgrade pip 14 | echo Upgrading pip... 15 | python -m pip install --upgrade pip 16 | 17 | REM Install specific packages 18 | echo Installing required packages... 19 | pip install uv 20 | 21 | REM Check for NVIDIA GPU 22 | FOR /F "delims=" %%i IN ('nvidia-smi 2^>^&1') DO ( 23 | SET GPU_CHECK=%%i 24 | ) 25 | IF NOT ERRORLEVEL 1 ( 26 | echo NVIDIA GPU detected, installing GPU dependencies... 27 | uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu126 28 | ) ELSE ( 29 | echo No NVIDIA GPU detected, installing CPU dependencies... 
30 | uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu 31 | ) 32 | 33 | uv pip install "numpy>=1.24.3" 34 | 35 | REM Install additional requirements 36 | IF EXIST requirements.txt ( 37 | echo Installing additional requirements... 38 | uv pip install -r requirements.txt 39 | ) ELSE ( 40 | echo requirements.txt not found, skipping... 41 | ) 42 | 43 | REM Launch the script 44 | echo Launching LightDiffusion... 45 | python app.py 46 | 47 | REM Deactivate the virtual environment 48 | deactivate 49 | -------------------------------------------------------------------------------- /run_web.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | VENV_DIR=.venv 4 | # for WSL2 Ubuntu install 5 | # sudo apt install software-properties-common 6 | # sudo add-apt-repository ppa:deadsnakes/ppa 7 | sudo apt-get install python3.10 python3.10-venv python3.10-full python3-pip 8 | 9 | # Check if .venv exists 10 | if [ ! -d "$VENV_DIR" ]; then 11 | echo "Creating virtual environment..." 12 | python3.10 -m venv $VENV_DIR 13 | fi 14 | 15 | # Activate the virtual environment 16 | source $VENV_DIR/bin/activate 17 | 18 | # Upgrade pip 19 | echo "Upgrading pip..." 20 | pip install --upgrade pip 21 | pip3 install uv 22 | 23 | # Check GPU type 24 | TORCH_URL="https://download.pytorch.org/whl/cpu" 25 | if command -v nvidia-smi &> /dev/null; then 26 | echo "NVIDIA GPU detected" 27 | TORCH_URL="https://download.pytorch.org/whl/cu121" 28 | uv pip install --index-url $TORCH_URL \ 29 | torch==2.2.2 torchvision "xformers>=0.0.22" "triton>=2.1.0" \ 30 | stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl 31 | elif command -v rocminfo &> /dev/null; then 32 | echo "AMD GPU detected" 33 | TORCH_URL="https://download.pytorch.org/whl/rocm5.7" 34 | uv pip install --index-url $TORCH_URL \ 35 | torch==2.2.2 torchvision "triton>=2.1.0" 36 | else 37 | echo "No compatible GPU detected, using CPU" 38 | uv pip install --index-url $TORCH_URL \ 39 | torch==2.2.2+cpu torchvision 40 | fi 41 | 42 | uv pip install "numpy<2.0.0" 43 | 44 | # Install tkinter 45 | echo "Installing tkinter..." 46 | sudo apt-get install python3.10-tk 47 | 48 | # Install additional requirements 49 | if [ -f requirements.txt ]; then 50 | echo "Installing additional requirements..." 51 | uv pip install -r requirements.txt 52 | else 53 | echo "requirements.txt not found, skipping..." 54 | fi 55 | 56 | # Launch the script 57 | echo "Launching LightDiffusion..." 58 | python3.10 app.py 59 | 60 | # Deactivate the virtual environment 61 | deactivate 62 | -------------------------------------------------------------------------------- /stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl --------------------------------------------------------------------------------
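Closing note: the launcher scripts above (pipeline.bat, pipeline.sh, run_web.*) ultimately hand control to modules/user/pipeline.py or app.py, and modules/tests/test.py drives the same pipeline() entry point directly from Python. The following is a minimal, hypothetical usage sketch (not a repository file) based only on the keyword arguments that appear in those tests; argument defaults, whether options can be combined, and the example image path are assumptions:

# Editorial sketch mirroring the calls made in modules/tests/test.py.
# Run from the repository root with the environment set up by pipeline.sh / pipeline.bat.
from modules.user.pipeline import pipeline

# Plain text-to-image: one 512x512 image from a prompt.
pipeline("a cute cat, high quality, detailed", 512, 512, number=1)

# The same prompt with the optional high-resolution fix pass, as in test_hires_fix.
pipeline("a cute cat, high quality, detailed", 512, 512, number=1, hires_fix=True)

# Image-to-image from an existing file, as in test_img2img
# (the path below is a placeholder, not a file shipped with the repository).
pipeline("./_internal/output/classic/example.png", 512, 512, number=1, img2img=True)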