├── .gitattributes
├── .github
│   └── workflows
│       └── manual.yml
├── .gitignore
├── .gradio
│   └── certificate.pem
├── .vscode
│   └── settings.json
├── Compiler.py
├── HomeImage.png
├── LICENSE
├── README.md
├── _internal
│   ├── ESRGAN
│   │   └── put_esrgan_and_other_upscale_models_here
│   ├── checkpoints
│   │   └── put_checkpoints_here
│   ├── clip
│   │   └── sd1_clip_config.json
│   ├── embeddings
│   │   └── put_embeddings_or_textual_inversion_concepts_here
│   ├── last_seed.txt
│   ├── loras
│   │   └── put_loras_here
│   ├── output
│   │   ├── Adetailer
│   │   │   └── Adetailer_images_end_up_here
│   │   ├── Flux
│   │   │   └── Flux_images_end_up_here
│   │   ├── HiresFix
│   │   │   └── HiresFixed_images_end_up_here
│   │   ├── Img2Img
│   │   │   └── Upscaled_images_end_up_here
│   │   └── classic
│   │       └── normal_images_end_up_here
│   ├── prompt.txt
│   ├── sd1_tokenizer
│   │   ├── merges.txt
│   │   ├── special_tokens_map.json
│   │   ├── tokenizer_config.json
│   │   └── vocab.json
│   └── yolos
│       └── put_yolo_and_seg_files_here
├── app.py
├── dependency_flow.png
├── modules
│   ├── Attention
│   │   ├── Attention.py
│   │   └── AttentionMethods.py
│   ├── AutoDetailer
│   │   ├── AD_util.py
│   │   ├── ADetailer.py
│   │   ├── SAM.py
│   │   ├── SEGS.py
│   │   ├── bbox.py
│   │   ├── mask_util.py
│   │   └── tensor_util.py
│   ├── AutoEncoders
│   │   ├── ResBlock.py
│   │   ├── VariationalAE.py
│   │   └── taesd.py
│   ├── AutoHDR
│   │   └── ahdr.py
│   ├── BlackForest
│   │   └── Flux.py
│   ├── Device
│   │   └── Device.py
│   ├── FileManaging
│   │   ├── Downloader.py
│   │   ├── ImageSaver.py
│   │   └── Loader.py
│   ├── Model
│   │   ├── LoRas.py
│   │   ├── ModelBase.py
│   │   └── ModelPatcher.py
│   ├── NeuralNetwork
│   │   ├── transformer.py
│   │   └── unet.py
│   ├── Quantize
│   │   └── Quantizer.py
│   ├── SD15
│   │   ├── SD15.py
│   │   ├── SDClip.py
│   │   └── SDToken.py
│   ├── StableFast
│   │   └── StableFast.py
│   ├── UltimateSDUpscale
│   │   ├── RDRB.py
│   │   ├── USDU_upscaler.py
│   │   ├── USDU_util.py
│   │   ├── UltimateSDUpscale.py
│   │   └── image_util.py
│   ├── Utilities
│   │   ├── Enhancer.py
│   │   ├── Latent.py
│   │   ├── upscale.py
│   │   └── util.py
│   ├── WaveSpeed
│   │   ├── fbcache_nodes.py
│   │   ├── first_block_cache.py
│   │   ├── misc_nodes.py
│   │   └── utils.py
│   ├── clip
│   │   ├── CLIPTextModel.py
│   │   ├── Clip.py
│   │   ├── FluxClip.py
│   │   └── clip
│   │       ├── hydit_clip.json
│   │       ├── long_clipl.json
│   │       ├── mt5_config_xl.json
│   │       ├── sd1_clip_config.json
│   │       ├── sd2_clip_config.json
│   │       ├── t5_config_base.json
│   │       ├── t5_config_xxl.json
│   │       ├── t5_pile_config_xl.json
│   │       └── t5_tokenizer
│   │           ├── special_tokens_map.json
│   │           ├── tokenizer.json
│   │           └── tokenizer_config.json
│   ├── cond
│   │   ├── Activation.py
│   │   ├── cast.py
│   │   ├── cond.py
│   │   └── cond_util.py
│   ├── hidiffusion
│   │   ├── msw_msa_attention.py
│   │   └── utils.py
│   ├── sample
│   │   ├── CFG.py
│   │   ├── ksampler_util.py
│   │   ├── samplers.py
│   │   ├── sampling.py
│   │   └── sampling_util.py
│   ├── tests
│   │   └── test.py
│   └── user
│       ├── GUI.py
│       ├── app_instance.py
│       └── pipeline.py
├── pipeline.bat
├── pipeline.sh
├── requirements.txt
├── run.bat
├── run.sh
├── run_web.bat
├── run_web.sh
└── stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.github/workflows/manual.yml:
--------------------------------------------------------------------------------
1 | name: Manual workflow
2 |
3 | on:
4 | push:
5 | branches: [ main ]
6 | pull_request:
7 | branches: [ main ]
8 |
9 | jobs:
10 | test:
11 | runs-on: self-hosted
12 |
13 | steps:
14 | - uses: actions/checkout@v3
15 |
16 | - name: Set up Python 3.10
17 | uses: actions/setup-python@v4
18 | with:
19 | python-version: '3.10'
20 |
21 | - name: Cache dependencies
22 | uses: actions/cache@v3
23 | with:
24 | path: ~/.cache/pip
25 | key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
26 | restore-keys: |
27 | ${{ runner.os }}-pip-
28 |
29 | - name: Create virtual environment
30 | run: |
31 | python -m venv .venv
32 | if [ "$RUNNER_OS" == "Windows" ]; then
33 | . .venv/Scripts/activate
34 | else
35 | . .venv/bin/activate
36 | fi
37 | shell: bash
38 |
39 | - name: Install dependencies
40 | run: |
41 | if [ "$RUNNER_OS" == "Windows" ]; then
42 | . .venv/Scripts/activate
43 | else
44 | . .venv/bin/activate
45 | fi
46 | python -m pip install --upgrade pip
47 | pip install uv
48 | pip install torch torchvision --index-url https://download.pytorch.org/whl/cu124
49 | pip install "numpy<2.0.0"
50 | if [ -f requirements.txt ]; then
51 | uv pip install -r requirements.txt
52 | fi
53 | shell: bash
54 |
55 | - name: Test pipeline variants
56 | run: |
57 | if [ "$RUNNER_OS" == "Windows" ]; then
58 | . .venv/Scripts/activate
59 | else
60 | . .venv/bin/activate
61 | fi
62 | # Test basic pipeline
63 | python modules/user/pipeline.py "1girl" 512 512 1 1 --hires-fix --adetailer --autohdr --prio-speed
64 | # Test image to image
65 | python modules/user/pipeline.py "./_internal/output/Adetailer/LD-head_00001_.png" 512 512 1 1 --img2img --prio-speed
66 | shell: bash
67 |
68 | - name: Upload test artifacts
69 | if: always()
70 | uses: actions/upload-artifact@v4
71 | with:
72 | name: test-outputs-${{ github.sha }}
73 | path: |
74 | _internal/output/**/*.png
75 | _internal/output/Classic/*.png
76 | _internal/output/Flux/*.png
77 | _internal/output/HF/*.png
78 | retention-days: 5
79 | compression-level: 6
80 | if-no-files-found: warn
81 |
82 | - name: Report status
83 | if: always()
84 | run: |
85 | if [ ${{ job.status }} == 'success' ]; then
86 | echo "All tests passed successfully!"
87 | else
88 | echo "Some tests failed. Check the logs above for details."
89 | exit 1
90 | fi
91 | shell: bash
92 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | *.pyc
3 | *.pth
4 | *.pt
5 | *.safetensors
6 | *.gguf
7 | *.png
8 | /.idea
9 | /htmlcov
10 | .coverage
11 | .toml
12 | __pycache__
13 | .venv
14 | !HomeImage.png
15 | *.txt
16 |
--------------------------------------------------------------------------------
/.gradio/certificate.pem:
--------------------------------------------------------------------------------
1 | -----BEGIN CERTIFICATE-----
2 | MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3 | TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4 | cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5 | WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6 | ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7 | MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8 | h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9 | 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10 | A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11 | T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12 | B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13 | B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14 | KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15 | OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16 | jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17 | qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18 | rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19 | HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20 | hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21 | ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22 | 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23 | NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24 | ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25 | TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26 | jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27 | oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28 | 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29 | mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30 | emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31 | -----END CERTIFICATE-----
32 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.unittestArgs": [
3 | "-v",
4 | "-s",
5 | ".",
6 | "-p",
7 | "*test.py"
8 | ],
9 | "python.testing.pytestEnabled": false,
10 | "python.testing.unittestEnabled": true,
11 | "python.analysis.autoImportCompletions": true,
12 | "python.analysis.typeCheckingMode": "off"
13 | }
--------------------------------------------------------------------------------
/Compiler.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 |
4 | files_ordered = [
5 | "./modules/Utilities/util.py",
6 | "./modules/sample/sampling_util.py",
7 | "./modules/Device/Device.py",
8 | "./modules/cond/cond_util.py",
9 | "./modules/cond/cond.py",
10 | "./modules/sample/ksampler_util.py",
11 | "./modules/cond/cast.py",
12 | "./modules/Attention/AttentionMethods.py",
13 | "./modules/AutoEncoders/taesd.py",
14 | "./modules/cond/cond.py",
15 | "./modules/cond/Activation.py",
16 | "./modules/Attention/Attention.py",
17 | "./modules/sample/samplers.py",
18 | "./modules/sample/CFG.py",
19 | "./modules/NeuralNetwork/transformer.py",
20 | "./modules/sample/sampling.py",
21 | "./modules/clip/CLIPTextModel.py",
22 | "./modules/AutoEncoders/ResBlock.py",
23 | "./modules/AutoDetailer/mask_util.py",
24 | "./modules/NeuralNetwork/unet.py",
25 | "./modules/SD15/SDClip.py",
26 | "./modules/SD15/SDToken.py",
27 | "./modules/UltimateSDUpscale/USDU_util.py",
28 | "./modules/StableFast/SF_util.py",
29 | "./modules/Utilities/Latent.py",
30 | "./modules/AutoDetailer/SEGS.py",
31 | "./modules/AutoDetailer/tensor_util.py",
32 | "./modules/AutoDetailer/AD_util.py",
33 | "./modules/clip/FluxClip.py",
34 | "./modules/Model/ModelPatcher.py",
35 | "./modules/Model/ModelBase.py",
36 | "./modules/UltimateSDUpscale/image_util.py",
37 | "./modules/UltimateSDUpscale/RDRB.py",
38 | "./modules/StableFast/ModuleFactory.py",
39 | "./modules/AutoDetailer/bbox.py",
40 | "./modules/AutoEncoders/VariationalAE.py",
41 | "./modules/clip/Clip.py",
42 | "./modules/Model/LoRas.py",
43 | "./modules/BlackForest/Flux.py",
44 | "./modules/UltimateSDUpscale/USDU_upscaler.py",
45 | "./modules/StableFast/ModuleTracing.py",
46 | "./modules/hidiffusion/utils.py",
47 | "./modules/FileManaging/Downloader.py",
48 | "./modules/AutoDetailer/SAM.py",
49 | "./modules/AutoDetailer/ADetailer.py",
50 | "./modules/Quantize/Quantizer.py",
51 | "./modules/FileManaging/Loader.py",
52 | "./modules/SD15/SD15.py",
53 | "./modules/UltimateSDUpscale/UltimateSDUpscale.py",
54 | "./modules/StableFast/StableFast.py",
55 | "./modules/hidiffusion/msw_msa_attention.py",
56 | "./modules/FileManaging/ImageSaver.py",
57 | "./modules/Utilities/Enhancer.py",
58 | "./modules/Utilities/upscale.py",
59 | "./modules/user/pipeline.py",
60 | ]
61 |
62 | def get_file_patterns():
63 | patterns = []
64 | seen = set()
65 | for path in files_ordered:
66 | filename = os.path.basename(path)
67 | name = os.path.splitext(filename)[0]
68 | if name not in seen:
69 | # Pattern 1: matches module name when not in brackets or after a dot
70 | pattern1 = rf'(?
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
2 |
3 | # Say hi to LightDiffusion-Next 👋
4 |
5 | [](https://huggingface.co/spaces/Aatricks/LightDiffusion-Next)
6 |
7 | **LightDiffusion-Next** is the fastest AI-powered image generation GUI/CLI, combining speed, precision, and flexibility in one cohesive tool.
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 | As a refactored and improved version of the original [LightDiffusion repository](https://github.com/Aatrick/LightDiffusion), this project enhances usability, maintainability, and functionality while introducing a host of new features to streamline your creative workflows.
18 |
19 | ## Motivation:
20 |
21 | **LightDiffusion** was originally meant to be written in Rust, but given the limited support for Rust in the AI community, it was written in Python instead, with the goal of being the simplest and fastest AI image generation tool.
22 |
23 | That's when the first version of LightDiffusion was born, counting only [3000 lines of code](https://github.com/LightDiffusion/LightDiffusion-original) and relying solely on PyTorch. Over time, the [project](https://github.com/Aatrick/LightDiffusion) grew more complex, and the need for a refactor became evident. This is where **LightDiffusion-Next** comes in, with a more modular and maintainable codebase and a host of new features and optimizations.
24 |
25 | 📚 Learn more in the [official documentation](https://aatrick.github.io/LightDiffusion/).
26 |
27 | ---
28 |
29 | ## 🌟 Highlights
30 |
31 | 
32 |
33 | **LightDiffusion-Next** offers a powerful suite of tools to cater to creators at every level. At its core, it supports **Text-to-Image** (Txt2Img) and **Image-to-Image** (Img2Img) generation, with a variety of upscaling methods and samplers that make it easy to create stunning images with minimal effort.
34 |
35 | Advanced users can take advantage of features like **attention syntax**, **Hires-Fix** or **ADetailer**. These tools provide better quality and flexibility for generating complex and high-resolution outputs.
36 |
37 | **LightDiffusion-Next** is fine-tuned for **performance**. Features such as **Xformers** acceleration, **BFloat16** precision support, **WaveSpeed** dynamic caching, and **Stable-Fast** model compilation (which offers up to a 70% speed boost) ensure smooth and efficient operation, even on demanding workloads.
38 |
39 | ---
40 |
41 | ## ✨ Feature Showcase
42 |
43 | Here’s what makes LightDiffusion-Next stand out:
44 |
45 | - **Speed and Efficiency**:
46 | Enjoy industry-leading performance with built-in Xformers, PyTorch, WaveSpeed and Stable-Fast optimizations, running up to 30% faster than other AI image generation backends for SD1.5 and up to 2x faster for Flux.
47 |
48 | - **Automatic Detailing**:
49 | Effortlessly enhance faces and body details with AI-driven tools based on the [Impact Pack](https://github.com/ltdrdata/ComfyUI-Impact-Pack).
50 |
51 | - **State Preservation**:
52 | Save and resume your progress with saved states, ensuring seamless transitions between sessions.
53 |
54 | - **Advanced GUI, WebUI and CLI**:
55 | Work in a user-friendly desktop GUI, in the browser through the Gradio WebUI, or use the streamlined pipeline for CLI-based workflows.
56 |
57 | - **Integration-Ready**:
58 | Collaborate and create directly in Discord with [Boubou](https://github.com/Aatrick/Boubou), or preview images dynamically with the optional **TAESD preview mode**.
59 |
60 | - **Image Previewing**:
61 | Get a real-time preview of your generated images with TAESD, allowing for user-friendly and interactive workflows.
62 |
63 | - **Image Upscaling**:
64 | Enhance your images with advanced upscaling options like UltimateSDUpscaling, ensuring high-quality results every time.
65 |
66 | - **Prompt Refinement**:
67 | Use the Ollama-powered automatic prompt enhancer to refine your prompts and generate more accurate and detailed outputs.
68 |
69 | - **LoRa and Textual Inversion Embeddings**:
70 | Leverage LoRa and textual inversion embeddings for highly customized and nuanced results, adding a new dimension to your creative process.
71 |
72 | - **Low-End Device Support**:
73 | Run LightDiffusion-Next on low-end devices with as little as 2GB of VRAM or even no GPU, ensuring accessibility for all users.
74 |
75 | - **CFG++**:
76 | Uses samplers modified for CFG++, yielding better-quality results than traditional CFG-based methods.
77 |
78 | ---
79 |
80 | ## ⚡ Performance Benchmarks
81 |
82 | **LightDiffusion-Next** dominates in performance:
83 |
84 | | **Tool** | **Speed (it/s)** |
85 | |------------------------------------|------------------|
86 | | **LightDiffusion with Stable-Fast** | 2.8 |
87 | | **LightDiffusion** | 1.9 |
88 | | **ComfyUI** | 1.4 |
89 | | **SDForge** | 1.3 |
90 | | **SDWebUI** | 0.9 |
91 |
92 | (All benchmarks use SD1.5 at 1024x1024 resolution with a batch size of 1 and BFloat16 precision, on untweaked installations, measured on a mobile RTX 3060 GPU.)
93 |
94 | With its unmatched speed and efficiency, LightDiffusion-Next sets the benchmark for AI image generation tools.
95 |
96 | ---
97 |
98 | ## 🛠 Installation
99 |
100 | ### Quick Start
101 |
102 | 1. Download a release or clone this repository.
103 | 2. Run `run.bat` (Windows) or `run.sh` (Linux) in a terminal.
104 | 3. Start creating!
105 |
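If you prefer a browser UI, the repository also includes a Gradio front end (`app.py`, wrapped by `run_web.bat`/`run_web.sh`). A minimal sketch, assuming dependencies and a checkpoint are already installed:

```bash
python app.py
# When run outside Hugging Face Spaces, the app binds to 0.0.0.0:8000,
# so open http://localhost:8000 in your browser.
```
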
106 | ### Command-Line Pipeline
107 |
108 | For a GUI-free experience, use the pipeline:
109 | ```bash
110 | pipeline.bat
111 | ```
112 | Use `pipeline.bat -h` for more options.
113 |
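For reference, the CI workflow in this repository calls the pipeline script directly. A hedged example based on that workflow (the positional arguments appear to be prompt, width, height, number of images, and batch size; on Linux, use `pipeline.sh` instead of `pipeline.bat`):

```bash
python modules/user/pipeline.py "1girl" 512 512 1 1 --hires-fix --adetailer --autohdr --prio-speed
```
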
114 | ---
115 |
116 | ### Advanced Setup
117 |
118 | - **Install from Source**:
119 | Install dependencies via:
120 | ```bash
121 | pip install -r requirements.txt
122 | ```
123 | Add your SD1/1.5 safetensors model to the `checkpoints` directory, then launch the application.
124 |
125 | - **⚡Stable-Fast Optimization**:
126 | Follow [this guide](https://github.com/chengzeyi/stable-fast?tab=readme-ov-file#installation) to enable Stable-Fast mode for optimal performance.
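
  The repository also ships a prebuilt wheel, `stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl`. Judging by its tags it targets Linux x86_64 with Python 3.10 and torch 2.2.2 + CUDA 12.1; if that matches your environment, installing it directly is a possible shortcut:
  ```bash
  pip install ./stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl
  ```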
127 |
128 | - **🦙 Prompt Enhancer**:
129 | Refine your prompts with Ollama:
130 | ```bash
131 | pip install ollama
132 | ollama run deepseek-r1
133 | ```
134 | See the [Ollama guide](https://github.com/ollama/ollama?tab=readme-ov-file) for details.
135 |
136 | - **🤖 Discord Integration**:
137 | Set up the Discord bot by following the [Boubou installation guide](https://github.com/Aatrick/Boubou).
138 |
139 | ---
140 |
141 | 🎨 Enjoy exploring the powerful features of LightDiffusion-Next!
142 |
--------------------------------------------------------------------------------
/_internal/ESRGAN/put_esrgan_and_other_upscale_models_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/ESRGAN/put_esrgan_and_other_upscale_models_here
--------------------------------------------------------------------------------
/_internal/checkpoints/put_checkpoints_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/checkpoints/put_checkpoints_here
--------------------------------------------------------------------------------
/_internal/clip/sd1_clip_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "openai/clip-vit-large-patch14",
3 | "architectures": [
4 | "CLIPTextModel"
5 | ],
6 | "attention_dropout": 0.0,
7 | "bos_token_id": 0,
8 | "dropout": 0.0,
9 | "eos_token_id": 2,
10 | "hidden_act": "quick_gelu",
11 | "hidden_size": 768,
12 | "initializer_factor": 1.0,
13 | "initializer_range": 0.02,
14 | "intermediate_size": 3072,
15 | "layer_norm_eps": 1e-05,
16 | "max_position_embeddings": 77,
17 | "model_type": "clip_text_model",
18 | "num_attention_heads": 12,
19 | "num_hidden_layers": 12,
20 | "pad_token_id": 1,
21 | "projection_dim": 768,
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.24.0",
24 | "vocab_size": 49408
25 | }
26 |
--------------------------------------------------------------------------------
/_internal/embeddings/put_embeddings_or_textual_inversion_concepts_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/embeddings/put_embeddings_or_textual_inversion_concepts_here
--------------------------------------------------------------------------------
/_internal/last_seed.txt:
--------------------------------------------------------------------------------
1 | 7100032452232484160
--------------------------------------------------------------------------------
/_internal/loras/put_loras_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/loras/put_loras_here
--------------------------------------------------------------------------------
/_internal/output/Adetailer/Adetailer_images_end_up_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/Adetailer/Adetailer_images_end_up_here
--------------------------------------------------------------------------------
/_internal/output/Flux/Flux_images_end_up_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/Flux/Flux_images_end_up_here
--------------------------------------------------------------------------------
/_internal/output/HiresFix/HiresFixed_images_end_up_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/HiresFix/HiresFixed_images_end_up_here
--------------------------------------------------------------------------------
/_internal/output/Img2Img/Upscaled_images_end_up_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/Img2Img/Upscaled_images_end_up_here
--------------------------------------------------------------------------------
/_internal/output/classic/normal_images_end_up_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/output/classic/normal_images_end_up_here
--------------------------------------------------------------------------------
/_internal/prompt.txt:
--------------------------------------------------------------------------------
1 | prompt: masterpiece, best quality, (extremely detailed CG unity 8k wallpaper, masterpiece, best quality, ultra-detailed, best shadow), dynamic angle, light particles, high contrast, (best illumination), ((cinematic light)), colorful, hyper detail, dramatic light, intricate details, depth of field, 1girl, cowboy shot, white hair, long hair,
2 | neg: (worst quality, low quality:1.4), (zombie, sketch, interlocked fingers, comic), (embedding:EasyNegative), (embedding:badhandv4)
3 | w: 768
4 | h: 1024
5 | cfg: 8
6 |
--------------------------------------------------------------------------------
/_internal/sd1_tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "bos_token": {
3 | "content": "<|startoftext|>",
4 | "lstrip": false,
5 | "normalized": true,
6 | "rstrip": false,
7 | "single_word": false
8 | },
9 | "eos_token": {
10 | "content": "<|endoftext|>",
11 | "lstrip": false,
12 | "normalized": true,
13 | "rstrip": false,
14 | "single_word": false
15 | },
16 | "pad_token": "<|endoftext|>",
17 | "unk_token": {
18 | "content": "<|endoftext|>",
19 | "lstrip": false,
20 | "normalized": true,
21 | "rstrip": false,
22 | "single_word": false
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/_internal/sd1_tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_prefix_space": false,
3 | "bos_token": {
4 | "__type": "AddedToken",
5 | "content": "<|startoftext|>",
6 | "lstrip": false,
7 | "normalized": true,
8 | "rstrip": false,
9 | "single_word": false
10 | },
11 | "do_lower_case": true,
12 | "eos_token": {
13 | "__type": "AddedToken",
14 | "content": "<|endoftext|>",
15 | "lstrip": false,
16 | "normalized": true,
17 | "rstrip": false,
18 | "single_word": false
19 | },
20 | "errors": "replace",
21 | "model_max_length": 77,
22 | "name_or_path": "openai/clip-vit-large-patch14",
23 | "pad_token": "<|endoftext|>",
24 | "special_tokens_map_file": "./special_tokens_map.json",
25 | "tokenizer_class": "CLIPTokenizer",
26 | "unk_token": {
27 | "__type": "AddedToken",
28 | "content": "<|endoftext|>",
29 | "lstrip": false,
30 | "normalized": true,
31 | "rstrip": false,
32 | "single_word": false
33 | }
34 | }
35 |
--------------------------------------------------------------------------------
/_internal/yolos/put_yolo_and_seg_files_here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/_internal/yolos/put_yolo_and_seg_files_here
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import gradio as gr
3 | import sys
4 | import os
5 | from PIL import Image
6 | import numpy as np
7 | import spaces
8 |
9 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))
10 |
11 | from modules.user.pipeline import pipeline
12 | import torch
13 |
14 |
15 | def load_generated_images():
16 | """Load generated images with given prefix from disk"""
17 | image_files = glob.glob("./_internal/output/**/*.png")
18 |
19 | # If there are no image files, return
20 | if not image_files:
21 | return []
22 |
23 | # Sort files by modification time in descending order
24 | image_files.sort(key=os.path.getmtime, reverse=True)
25 |
26 | # Get most recent timestamp
27 | latest_time = os.path.getmtime(image_files[0])
28 |
29 | # Get all images from same batch (within 1 second of most recent)
30 | batch_images = []
31 | for file in image_files:
32 | if abs(os.path.getmtime(file) - latest_time) < 1.0:
33 | try:
34 | img = Image.open(file)
35 | batch_images.append(img)
36 | except Exception:
37 | continue
38 |
39 | if not batch_images:
40 | return []
41 | return batch_images
42 |
43 |
44 | @spaces.GPU
45 | def generate_images(
46 | prompt: str,
47 | width: int = 512,
48 | height: int = 512,
49 | num_images: int = 1,
50 | batch_size: int = 1,
51 | hires_fix: bool = False,
52 | adetailer: bool = False,
53 | enhance_prompt: bool = False,
54 | img2img_enabled: bool = False,
55 | img2img_image: str = None,
56 | stable_fast: bool = False,
57 | reuse_seed: bool = False,
58 | flux_enabled: bool = False,
59 | prio_speed: bool = False,
60 | realistic_model: bool = False,
61 | progress=gr.Progress(),
62 | ):
63 | """Generate images using the LightDiffusion pipeline"""
64 | try:
65 | if img2img_enabled and img2img_image is not None:
66 | # Convert numpy array to PIL Image
67 | if isinstance(img2img_image, np.ndarray):
68 | img_pil = Image.fromarray(img2img_image)
69 | img_pil.save("temp_img2img.png")
70 | prompt = "temp_img2img.png"
71 |
72 | # Run pipeline and capture saved images
73 | with torch.inference_mode():
74 | pipeline(
75 | prompt=prompt,
76 | w=width,
77 | h=height,
78 | number=num_images,
79 | batch=batch_size,
80 | hires_fix=hires_fix,
81 | adetailer=adetailer,
82 | enhance_prompt=enhance_prompt,
83 | img2img=img2img_enabled,
84 | stable_fast=stable_fast,
85 | reuse_seed=reuse_seed,
86 | flux_enabled=flux_enabled,
87 | prio_speed=prio_speed,
88 | autohdr=True,
89 | realistic_model=realistic_model,
90 | )
91 |
92 | # Clean up temporary file if it exists
93 | if os.path.exists("temp_img2img.png"):
94 | os.remove("temp_img2img.png")
95 |
96 | return load_generated_images()
97 |
98 | except Exception:
99 | import traceback
100 |
101 | print(traceback.format_exc())
102 | # Clean up temporary file if it exists
103 | if os.path.exists("temp_img2img.png"):
104 | os.remove("temp_img2img.png")
105 | return [Image.new("RGB", (512, 512), color="black")]
106 |
107 |
108 | # Create Gradio interface
109 | with gr.Blocks(title="LightDiffusion Web UI") as demo:
110 | gr.Markdown("# LightDiffusion Web UI")
111 | gr.Markdown("Generate AI images using LightDiffusion")
112 | gr.Markdown(
113 | "This is the demo for LightDiffusion, the fastest diffusion backend for generating images. https://github.com/LightDiffusion/LightDiffusion-Next"
114 | )
115 |
116 | with gr.Row():
117 | with gr.Column():
118 | # Input components
119 | prompt = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
120 |
121 | with gr.Row():
122 | width = gr.Slider(
123 | minimum=64, maximum=2048, value=512, step=64, label="Width"
124 | )
125 | height = gr.Slider(
126 | minimum=64, maximum=2048, value=512, step=64, label="Height"
127 | )
128 |
129 | with gr.Row():
130 | num_images = gr.Slider(
131 | minimum=1, maximum=10, value=1, step=1, label="Number of Images"
132 | )
133 | batch_size = gr.Slider(
134 | minimum=1, maximum=4, value=1, step=1, label="Batch Size"
135 | )
136 |
137 | with gr.Row():
138 | hires_fix = gr.Checkbox(label="HiRes Fix")
139 | adetailer = gr.Checkbox(label="Auto Face/Body Enhancement")
140 | enhance_prompt = gr.Checkbox(label="Enhance Prompt")
141 | stable_fast = gr.Checkbox(label="Stable Fast Mode")
142 |
143 | with gr.Row():
144 | reuse_seed = gr.Checkbox(label="Reuse Seed")
145 | flux_enabled = gr.Checkbox(label="Flux Mode")
146 | prio_speed = gr.Checkbox(label="Prioritize Speed")
147 | realistic_model = gr.Checkbox(label="Realistic Model")
148 |
149 | with gr.Row():
150 | img2img_enabled = gr.Checkbox(label="Image to Image Mode")
151 | img2img_image = gr.Image(label="Input Image for img2img", visible=False)
152 |
153 | # Make input image visible only when img2img is enabled
154 | img2img_enabled.change(
155 | fn=lambda x: gr.update(visible=x),
156 | inputs=[img2img_enabled],
157 | outputs=[img2img_image],
158 | )
159 |
160 | generate_btn = gr.Button("Generate")
161 |
162 | # Output gallery
163 | gallery = gr.Gallery(
164 | label="Generated Images",
165 | show_label=True,
166 | elem_id="gallery",
167 | columns=[2],
168 | rows=[2],
169 | object_fit="contain",
170 | height="auto",
171 | )
172 |
173 | # Connect generate button to pipeline
174 | generate_btn.click(
175 | fn=generate_images,
176 | inputs=[
177 | prompt,
178 | width,
179 | height,
180 | num_images,
181 | batch_size,
182 | hires_fix,
183 | adetailer,
184 | enhance_prompt,
185 | img2img_enabled,
186 | img2img_image,
187 | stable_fast,
188 | reuse_seed,
189 | flux_enabled,
190 | prio_speed,
191 | realistic_model,
192 | ],
193 | outputs=gallery,
194 | )
195 |
196 |
197 | def is_huggingface_space():
198 | return "SPACE_ID" in os.environ
199 |
200 |
201 | # For local testing
202 | if __name__ == "__main__":
203 | if is_huggingface_space():
204 | demo.launch(
205 | debug=False,
206 | server_name="0.0.0.0",
207 | server_port=7860, # Standard HF Spaces port
208 | )
209 | else:
210 | demo.launch(
211 | server_name="0.0.0.0",
212 | server_port=8000,
213 | auth=None,
214 | share=True, # Only enable sharing locally
215 | debug=True,
216 | )
217 |
--------------------------------------------------------------------------------
/dependency_flow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/dependency_flow.png
--------------------------------------------------------------------------------
/modules/Attention/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import logging
4 |
5 | from modules.Utilities import util
6 | from modules.Attention import AttentionMethods
7 | from modules.Device import Device
8 | from modules.cond import cast
9 |
10 |
11 | def Normalize(
12 | in_channels: int, dtype: torch.dtype = None, device: torch.device = None
13 | ) -> torch.nn.GroupNorm:
14 | """#### Normalize the input channels.
15 |
16 | #### Args:
17 | - `in_channels` (int): The input channels.
18 | - `dtype` (torch.dtype, optional): The data type. Defaults to `None`.
19 | - `device` (torch.device, optional): The device. Defaults to `None`.
20 |
21 | #### Returns:
22 | - `torch.nn.GroupNorm`: The normalized input channels
23 | """
24 | return torch.nn.GroupNorm(
25 | num_groups=32,
26 | num_channels=in_channels,
27 | eps=1e-6,
28 | affine=True,
29 | dtype=dtype,
30 | device=device,
31 | )
32 |
33 |
34 | if Device.xformers_enabled():
35 | logging.info("Using xformers cross attention")
36 | optimized_attention = AttentionMethods.attention_xformers
37 | else:
38 | logging.info("Using pytorch cross attention")
39 | optimized_attention = AttentionMethods.attention_pytorch
40 |
41 | optimized_attention_masked = optimized_attention
42 |
43 |
44 | def optimized_attention_for_device() -> AttentionMethods.attention_pytorch:
45 | """#### Get the optimized attention for a device.
46 |
47 | #### Returns:
48 | - `function`: The optimized attention function.
49 | """
50 | return AttentionMethods.attention_pytorch
51 |
52 |
53 | class CrossAttention(nn.Module):
54 | """#### Cross attention module, which applies attention across the query and context.
55 |
56 | #### Args:
57 | - `query_dim` (int): The query dimension.
58 | - `context_dim` (int, optional): The context dimension. Defaults to `None`.
59 | - `heads` (int, optional): The number of heads. Defaults to `8`.
60 | - `dim_head` (int, optional): The head dimension. Defaults to `64`.
61 | - `dropout` (float, optional): The dropout rate. Defaults to `0.0`.
62 | - `dtype` (torch.dtype, optional): The data type. Defaults to `None`.
63 | - `device` (torch.device, optional): The device. Defaults to `None`.
64 | - `operations` (cast.disable_weight_init, optional): The operations. Defaults to `cast.disable_weight_init`.
65 | """
66 |
67 | def __init__(
68 | self,
69 | query_dim: int,
70 | context_dim: int = None,
71 | heads: int = 8,
72 | dim_head: int = 64,
73 | dropout: float = 0.0,
74 | dtype: torch.dtype = None,
75 | device: torch.device = None,
76 | operations: cast.disable_weight_init = cast.disable_weight_init,
77 | ):
78 | super().__init__()
79 | inner_dim = dim_head * heads
80 | context_dim = util.default(context_dim, query_dim)
81 |
82 | self.heads = heads
83 | self.dim_head = dim_head
84 |
85 | self.to_q = operations.Linear(
86 | query_dim, inner_dim, bias=False, dtype=dtype, device=device
87 | )
88 | self.to_k = operations.Linear(
89 | context_dim, inner_dim, bias=False, dtype=dtype, device=device
90 | )
91 | self.to_v = operations.Linear(
92 | context_dim, inner_dim, bias=False, dtype=dtype, device=device
93 | )
94 |
95 | self.to_out = nn.Sequential(
96 | operations.Linear(inner_dim, query_dim, dtype=dtype, device=device),
97 | nn.Dropout(dropout),
98 | )
99 |
100 | def forward(
101 | self,
102 | x: torch.Tensor,
103 | context: torch.Tensor = None,
104 | value: torch.Tensor = None,
105 | mask: torch.Tensor = None,
106 | ) -> torch.Tensor:
107 | """#### Forward pass of the cross attention module.
108 |
109 | #### Args:
110 | - `x` (torch.Tensor): The input tensor.
111 | - `context` (torch.Tensor, optional): The context tensor. Defaults to `None`.
112 | - `value` (torch.Tensor, optional): The value tensor. Defaults to `None`.
113 | - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`.
114 |
115 | #### Returns:
116 | - `torch.Tensor`: The output tensor.
117 | """
118 | q = self.to_q(x)
119 | context = util.default(context, x)
120 | k = self.to_k(context)
121 | v = self.to_v(context)
122 |
123 | out = optimized_attention(q, k, v, self.heads)
124 | return self.to_out(out)
125 |
126 |
127 | class AttnBlock(nn.Module):
128 | """#### Attention block, which applies attention to the input tensor.
129 |
130 | #### Args:
131 | - `in_channels` (int): The input channels.
132 | """
133 |
134 | def __init__(self, in_channels: int):
135 | super().__init__()
136 | self.in_channels = in_channels
137 |
138 | self.norm = Normalize(in_channels)
139 | self.q = cast.disable_weight_init.Conv2d(
140 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
141 | )
142 | self.k = cast.disable_weight_init.Conv2d(
143 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
144 | )
145 | self.v = cast.disable_weight_init.Conv2d(
146 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
147 | )
148 | self.proj_out = cast.disable_weight_init.Conv2d(
149 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
150 | )
151 |
152 | if Device.xformers_enabled_vae():
153 | logging.info("Using xformers attention in VAE")
154 | self.optimized_attention = AttentionMethods.xformers_attention
155 | else:
156 | logging.info("Using pytorch attention in VAE")
157 | self.optimized_attention = AttentionMethods.pytorch_attention
158 |
159 | def forward(self, x: torch.Tensor) -> torch.Tensor:
160 | """#### Forward pass of the attention block.
161 |
162 | #### Args:
163 | - `x` (torch.Tensor): The input tensor.
164 |
165 | #### Returns:
166 | - `torch.Tensor`: The output tensor.
167 | """
168 | h_ = x
169 | h_ = self.norm(h_)
170 | q = self.q(h_)
171 | k = self.k(h_)
172 | v = self.v(h_)
173 |
174 | h_ = self.optimized_attention(q, k, v)
175 |
176 | h_ = self.proj_out(h_)
177 |
178 | return x + h_
179 |
180 |
181 | def make_attn(in_channels: int, attn_type: str = "vanilla") -> AttnBlock:
182 | """#### Make an attention block.
183 |
184 | #### Args:
185 | - `in_channels` (int): The input channels.
186 | - `attn_type` (str, optional): The attention type. Defaults to "vanilla".
187 |
188 | #### Returns:
189 | - `AttnBlock`: A class instance of the attention block.
190 | """
191 | return AttnBlock(in_channels)
192 |
--------------------------------------------------------------------------------
/modules/Attention/AttentionMethods.py:
--------------------------------------------------------------------------------
1 | try:
2 | import xformers
3 | except ImportError:
4 | pass
5 | import torch
6 |
7 | BROKEN_XFORMERS = False
8 | try:
9 | x_vers = xformers.__version__
10 | # XFormers bug confirmed on all versions from 0.0.21 to 0.0.26 (q with bs bigger than 65535 gives CUDA error)
11 | BROKEN_XFORMERS = x_vers.startswith("0.0.2") and not x_vers.startswith("0.0.20")
12 | except:
13 | pass
14 |
15 |
16 | def attention_xformers(
17 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False, flux=False
18 | ) -> torch.Tensor:
19 | """#### Make an attention call using xformers. Fastest attention implementation.
20 |
21 | #### Args:
22 | - `q` (torch.Tensor): The query tensor.
23 | - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
24 | - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.
25 | - `heads` (int): The number of heads, must be a divisor of the hidden dimension.
26 | - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`.
27 |
28 | #### Returns:
29 | - `torch.Tensor`: The output tensor.
30 | """
31 | if not flux:
32 | b, _, dim_head = q.shape
33 | dim_head //= heads
34 |
35 | q, k, v = map(
36 | lambda t: t.unsqueeze(3)
37 | .reshape(b, -1, heads, dim_head)
38 | .permute(0, 2, 1, 3)
39 | .reshape(b * heads, -1, dim_head)
40 | .contiguous(),
41 | (q, k, v),
42 | )
43 |
44 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
45 |
46 | out = (
47 | out.unsqueeze(0)
48 | .reshape(b, heads, -1, dim_head)
49 | .permute(0, 2, 1, 3)
50 | .reshape(b, -1, heads * dim_head)
51 | )
52 | return out
53 | else:
54 | if skip_reshape:
55 | b, _, _, dim_head = q.shape
56 | else:
57 | b, _, dim_head = q.shape
58 | dim_head //= heads
59 |
60 | disabled_xformers = False
61 |
62 | if BROKEN_XFORMERS:
63 | if b * heads > 65535:
64 | disabled_xformers = True
65 |
66 | if not disabled_xformers:
67 | if torch.jit.is_tracing() or torch.jit.is_scripting():
68 | disabled_xformers = True
69 |
70 | if disabled_xformers:
71 | return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
72 |
73 | if skip_reshape:
74 | q, k, v = map(
75 | lambda t: t.reshape(b * heads, -1, dim_head),
76 | (q, k, v),
77 | )
78 | else:
79 | q, k, v = map(
80 | lambda t: t.reshape(b, -1, heads, dim_head),
81 | (q, k, v),
82 | )
83 |
84 | if mask is not None:
85 | pad = 8 - q.shape[1] % 8
86 | mask_out = torch.empty(
87 | [q.shape[0], q.shape[1], q.shape[1] + pad], dtype=q.dtype, device=q.device
88 | )
89 | mask_out[:, :, : mask.shape[-1]] = mask
90 | mask = mask_out[:, :, : mask.shape[-1]]
91 |
92 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask)
93 |
94 | if skip_reshape:
95 | out = (
96 | out.unsqueeze(0)
97 | .reshape(b, heads, -1, dim_head)
98 | .permute(0, 2, 1, 3)
99 | .reshape(b, -1, heads * dim_head)
100 | )
101 | else:
102 | out = out.reshape(b, -1, heads * dim_head)
103 |
104 | return out
105 |
106 |
107 | def attention_pytorch(
108 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, heads: int, mask=None, skip_reshape=False, flux=False
109 | ) -> torch.Tensor:
110 | """#### Make an attention call using PyTorch.
111 |
112 | #### Args:
113 | - `q` (torch.Tensor): The query tensor.
114 | - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
115 | - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.
116 | - `heads` (int): The number of heads, must be a divisor of the hidden dimension.
117 | - `mask` (torch.Tensor, optional): The mask tensor. Defaults to `None`.
118 |
119 | #### Returns:
120 | - `torch.Tensor`: The output tensor.
121 | """
122 | if not flux:
123 | b, _, dim_head = q.shape
124 | dim_head //= heads
125 | q, k, v = map(
126 | lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
127 | (q, k, v),
128 | )
129 |
130 | out = torch.nn.functional.scaled_dot_product_attention(
131 | q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
132 | )
133 | out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
134 | return out
135 | else:
136 | if skip_reshape:
137 | b, _, _, dim_head = q.shape
138 | else:
139 | b, _, dim_head = q.shape
140 | dim_head //= heads
141 | q, k, v = map(
142 | lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
143 | (q, k, v),
144 | )
145 |
146 | out = torch.nn.functional.scaled_dot_product_attention(
147 | q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
148 | )
149 | out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
150 | return out
151 |
152 | def xformers_attention(
153 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
154 | ) -> torch.Tensor:
155 | """#### Compute attention using xformers.
156 |
157 | #### Args:
158 | - `q` (torch.Tensor): The query tensor.
159 | - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
160 | - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.
161 |
162 | Returns:
163 | - `torch.Tensor`: The output tensor.
164 | """
165 | B, C, H, W = q.shape
166 | q, k, v = map(
167 | lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
168 | (q, k, v),
169 | )
170 | out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
171 | out = out.transpose(1, 2).reshape(B, C, H, W)
172 | return out
173 |
174 |
175 | def pytorch_attention(
176 | q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
177 | ) -> torch.Tensor:
178 | """#### Compute attention using PyTorch.
179 |
180 | #### Args:
181 | - `q` (torch.Tensor): The query tensor.
182 | - `k` (torch.Tensor): The key tensor, must have the same shape as `q`.
183 | - `v` (torch.Tensor): The value tensor, must have the same shape as `q`.
184 |
185 | #### Returns:
186 | - `torch.Tensor`: The output tensor.
187 | """
188 | B, C, H, W = q.shape
189 | q, k, v = map(
190 | lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
191 | (q, k, v),
192 | )
193 | out = torch.nn.functional.scaled_dot_product_attention(
194 | q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
195 | )
196 | out = out.transpose(2, 3).reshape(B, C, H, W)
197 | return out
198 |
--------------------------------------------------------------------------------
/modules/AutoDetailer/AD_util.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import cv2
3 | import numpy as np
4 | import torch
5 | from ultralytics import YOLO
6 | from PIL import Image
7 |
8 | orig_torch_load = torch.load
9 |
10 | # importing YOLO can break the original torch.load behavior, so restore the saved reference
11 | torch.load = orig_torch_load
12 |
13 |
14 | def load_yolo(model_path: str) -> YOLO:
15 | """#### Load YOLO model.
16 |
17 | #### Args:
18 | - `model_path` (str): The path to the YOLO model.
19 |
20 | #### Returns:
21 | - `YOLO`: The YOLO model initialized with the specified model path.
22 | """
23 | try:
24 | return YOLO(model_path)
25 | except ModuleNotFoundError:
26 | print("please download yolo model")
27 |
28 |
29 | def inference_bbox(
30 | model: YOLO,
31 | image: Image.Image,
32 | confidence: float = 0.3,
33 | device: str = "",
34 | ) -> List:
35 | """#### Perform inference on an image and return bounding boxes.
36 |
37 | #### Args:
38 | - `model` (YOLO): The YOLO model.
39 | - `image` (Image.Image): The image to perform inference on.
40 | - `confidence` (float): The confidence threshold for the bounding boxes.
41 | - `device` (str): The device to run the model on.
42 |
43 | #### Returns:
44 | - `List[List[str, List[int], np.ndarray, float]]`: The list of bounding boxes.
45 | """
46 | pred = model(image, conf=confidence, device=device)
47 |
48 | bboxes = pred[0].boxes.xyxy.cpu().numpy()
49 | cv2_image = np.array(image)
50 | cv2_image = cv2_image[:, :, ::-1].copy() # Convert RGB to BGR for cv2 processing
51 | cv2_gray = cv2.cvtColor(cv2_image, cv2.COLOR_BGR2GRAY)
52 |
53 | segms = []
54 | for x0, y0, x1, y1 in bboxes:
55 | cv2_mask = np.zeros(cv2_gray.shape, np.uint8)
56 | cv2.rectangle(cv2_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)
57 | cv2_mask_bool = cv2_mask.astype(bool)
58 | segms.append(cv2_mask_bool)
59 |
60 | results = [[], [], [], []]
61 | for i in range(len(bboxes)):
62 | results[0].append(pred[0].names[int(pred[0].boxes[i].cls.item())])
63 | results[1].append(bboxes[i])
64 | results[2].append(segms[i])
65 | results[3].append(pred[0].boxes[i].conf.cpu().numpy())
66 |
67 | return results
68 |
69 |
70 | def create_segmasks(results: List) -> List:
71 | """#### Create segmentation masks from the results of the inference.
72 |
73 | #### Args:
74 | - `results` (List[List[str, List[int], np.ndarray, float]]): The results of the inference.
75 |
76 | #### Returns:
77 | - `List[List[int], np.ndarray, float]`: The list of segmentation masks.
78 | """
79 | bboxs = results[1]
80 | segms = results[2]
81 | confidence = results[3]
82 |
83 | results = []
84 | for i in range(len(segms)):
85 | item = (bboxs[i], segms[i].astype(np.float32), confidence[i])
86 | results.append(item)
87 | return results
88 |
89 |
90 | def dilate_masks(segmasks: List, dilation_factor: int, iter: int = 1) -> List:
91 | """#### Dilate the segmentation masks.
92 |
93 | #### Args:
94 | - `segmasks` (List[List[int], np.ndarray, float]): The segmentation masks.
95 | - `dilation_factor` (int): The dilation factor.
96 | - `iter` (int): The number of iterations.
97 |
98 | #### Returns:
99 | - `List[List[int], np.ndarray, float]`: The dilated segmentation masks.
100 | """
101 | dilated_masks = []
102 | kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8)
103 |
104 | for i in range(len(segmasks)):
105 | cv2_mask = segmasks[i][1]
106 |
107 | dilated_mask = cv2.dilate(cv2_mask, kernel, iterations=iter)
108 |
109 | item = (segmasks[i][0], dilated_mask, segmasks[i][2])
110 | dilated_masks.append(item)
111 |
112 | return dilated_masks
113 |
114 |
115 | def normalize_region(limit: int, startp: int, size: int) -> List:
116 | """#### Normalize the region.
117 |
118 | #### Args:
119 | - `limit` (int): The limit.
120 | - `startp` (int): The start point.
121 | - `size` (int): The size.
122 |
123 | #### Returns:
124 | - `List[int]`: The normalized start and end points.
125 | """
126 | if startp < 0:
127 | new_endp = min(limit, size)
128 | new_startp = 0
129 | elif startp + size > limit:
130 | new_startp = max(0, limit - size)
131 | new_endp = limit
132 | else:
133 | new_startp = startp
134 | new_endp = min(limit, startp + size)
135 |
136 | return int(new_startp), int(new_endp)
137 |
138 |
139 | def make_crop_region(w: int, h: int, bbox: List, crop_factor: float) -> List:
140 | """#### Make the crop region.
141 |
142 | #### Args:
143 | - `w` (int): The width.
144 | - `h` (int): The height.
145 | - `bbox` (List[int]): The bounding box.
146 | - `crop_factor` (float): The crop factor.
147 |
148 | #### Returns:
149 | - `List[x1: int, y1: int, x2: int, y2: int]`: The crop region.
150 | """
151 | x1 = bbox[0]
152 | y1 = bbox[1]
153 | x2 = bbox[2]
154 | y2 = bbox[3]
155 |
156 | bbox_w = x2 - x1
157 | bbox_h = y2 - y1
158 |
159 | crop_w = bbox_w * crop_factor
160 | crop_h = bbox_h * crop_factor
161 |
162 | kernel_x = x1 + bbox_w / 2
163 | kernel_y = y1 + bbox_h / 2
164 |
165 | new_x1 = int(kernel_x - crop_w / 2)
166 | new_y1 = int(kernel_y - crop_h / 2)
167 |
168 | # make sure position in (w,h)
169 | new_x1, new_x2 = normalize_region(w, new_x1, crop_w)
170 | new_y1, new_y2 = normalize_region(h, new_y1, crop_h)
171 |
172 | return [new_x1, new_y1, new_x2, new_y2]
173 |
174 |
175 | def crop_ndarray2(npimg: np.ndarray, crop_region: List) -> np.ndarray:
176 | """#### Crop the ndarray in 2 dimensions.
177 |
178 | #### Args:
179 | - `npimg` (np.ndarray): The ndarray to crop.
180 | - `crop_region` (List[int]): The crop region.
181 |
182 | #### Returns:
183 | - `np.ndarray`: The cropped ndarray.
184 | """
185 | x1 = crop_region[0]
186 | y1 = crop_region[1]
187 | x2 = crop_region[2]
188 | y2 = crop_region[3]
189 |
190 | cropped = npimg[y1:y2, x1:x2]
191 |
192 | return cropped
193 |
194 |
195 | def crop_ndarray4(npimg: np.ndarray, crop_region: List) -> np.ndarray:
196 | """#### Crop the ndarray in 4 dimensions.
197 |
198 | #### Args:
199 | - `npimg` (np.ndarray): The ndarray to crop.
200 | - `crop_region` (List[int]): The crop region.
201 |
202 | #### Returns:
203 | - `np.ndarray`: The cropped ndarray.
204 | """
205 | x1 = crop_region[0]
206 | y1 = crop_region[1]
207 | x2 = crop_region[2]
208 | y2 = crop_region[3]
209 |
210 | cropped = npimg[:, y1:y2, x1:x2, :]
211 |
212 | return cropped
213 |
214 |
215 | def crop_image(image: Image.Image, crop_region: List) -> Image.Image:
216 | """#### Crop the image.
217 |
218 | #### Args:
219 | - `image` (Image.Image): The image to crop.
220 | - `crop_region` (List[int]): The crop region.
221 |
222 | #### Returns:
223 | - `Image.Image`: The cropped image.
224 | """
225 | return crop_ndarray4(image, crop_region)
226 |
227 |
228 | def segs_scale_match(segs: List[np.ndarray], target_shape: List) -> List:
229 | """#### Match the scale of the segmentation masks.
230 |
231 | #### Args:
232 | - `segs` (List[np.ndarray]): The segmentation masks.
233 | - `target_shape` (List[int]): The target shape.
234 |
235 | #### Returns:
236 | - `List[np.ndarray]`: The matched segmentation masks.
237 | """
238 | h = segs[0][0]
239 | w = segs[0][1]
240 |
241 | th = target_shape[1]
242 | tw = target_shape[2]
243 |
244 | if (h == th and w == tw) or h == 0 or w == 0:
245 | return segs
246 |
--------------------------------------------------------------------------------
/modules/AutoDetailer/SAM.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from segment_anything import SamPredictor, sam_model_registry
4 | import torch
5 |
6 | from modules.AutoDetailer import mask_util
7 | from modules.Device import Device
8 |
9 |
10 | def sam_predict(
11 | predictor: SamPredictor, points: list, plabs: list, bbox: list, threshold: float
12 | ) -> list:
13 | """#### Predict masks using SAM.
14 |
15 | #### Args:
16 | - `predictor` (SamPredictor): The SAM predictor.
17 | - `points` (list): List of points.
18 | - `plabs` (list): List of point labels.
19 | - `bbox` (list): Bounding box.
20 | - `threshold` (float): Threshold for mask selection.
21 |
22 | #### Returns:
23 | - `list`: List of predicted masks.
24 | """
25 | point_coords = None if not points else np.array(points)
26 | point_labels = None if not plabs else np.array(plabs)
27 |
28 | box = np.array([bbox]) if bbox is not None else None
29 |
30 | cur_masks, scores, _ = predictor.predict(
31 | point_coords=point_coords, point_labels=point_labels, box=box
32 | )
33 |
34 | total_masks = []
35 |
36 | selected = False
37 | max_score = 0
38 | max_mask = None
39 | for idx in range(len(scores)):
40 | if scores[idx] > max_score:
41 | max_score = scores[idx]
42 | max_mask = cur_masks[idx]
43 |
44 | if scores[idx] >= threshold:
45 | selected = True
46 | total_masks.append(cur_masks[idx])
47 | else:
48 | pass
49 |
50 | if not selected and max_mask is not None:
51 | total_masks.append(max_mask)
52 |
53 | return total_masks
54 |
55 |
56 | def is_same_device(a: torch.device, b: torch.device) -> bool:
57 | """#### Check if two devices are the same.
58 |
59 | #### Args:
60 | - `a` (torch.device): The first device.
61 | - `b` (torch.device): The second device.
62 |
63 | #### Returns:
64 | - `bool`: Whether the devices are the same.
65 | """
66 | a_device = torch.device(a) if isinstance(a, str) else a
67 | b_device = torch.device(b) if isinstance(b, str) else b
68 | return a_device.type == b_device.type and a_device.index == b_device.index
69 |
70 |
71 | class SafeToGPU:
72 | """#### Class to safely move objects to GPU."""
73 |
74 | def __init__(self, size: int):
75 | self.size = size
76 |
77 | def to_device(self, obj: torch.nn.Module, device: torch.device) -> None:
78 | """#### Move an object to a device.
79 |
80 | #### Args:
81 | - `obj` (torch.nn.Module): The object to move.
82 | - `device` (torch.device): The target device.
83 | """
84 | if is_same_device(device, "cpu"):
85 | obj.to(device)
86 | else:
87 | if is_same_device(obj.device, "cpu"): # cpu to gpu
88 | Device.free_memory(self.size * 1.3, device)
89 | if Device.get_free_memory(device) > self.size * 1.3:
90 | try:
91 | obj.to(device)
92 | except:
93 | print(
94 | f"WARN: The model is not moved to the '{device}' due to insufficient memory. [1]"
95 | )
96 | else:
97 | print(
98 | f"WARN: The model is not moved to the '{device}' due to insufficient memory. [2]"
99 | )
100 |
101 |
102 | class SAMWrapper:
103 | """#### Wrapper class for SAM model."""
104 |
105 | def __init__(
106 | self, model: torch.nn.Module, is_auto_mode: bool, safe_to_gpu: SafeToGPU = None
107 | ):
108 | self.model = model
109 | self.safe_to_gpu = safe_to_gpu if safe_to_gpu is not None else SafeToGPU()
110 | self.is_auto_mode = is_auto_mode
111 |
112 | def prepare_device(self) -> None:
113 | """#### Prepare the device for the model."""
114 | if self.is_auto_mode:
115 | device = Device.get_torch_device()
116 | self.safe_to_gpu.to_device(self.model, device=device)
117 |
118 | def release_device(self) -> None:
119 | """#### Release the device from the model."""
120 | if self.is_auto_mode:
121 | self.model.to(device="cpu")
122 |
123 | def predict(
124 | self, image: np.ndarray, points: list, plabs: list, bbox: list, threshold: float
125 | ) -> list:
126 | """#### Predict masks using the SAM model.
127 |
128 | #### Args:
129 | - `image` (np.ndarray): The input image.
130 | - `points` (list): List of points.
131 | - `plabs` (list): List of point labels.
132 | - `bbox` (list): Bounding box.
133 | - `threshold` (float): Threshold for mask selection.
134 |
135 | #### Returns:
136 | - `list`: List of predicted masks.
137 | """
138 | predictor = SamPredictor(self.model)
139 | predictor.set_image(image, "RGB")
140 |
141 | return sam_predict(predictor, points, plabs, bbox, threshold)
142 |
143 |
144 | class SAMLoader:
145 | """#### Class to load SAM models."""
146 |
147 | def load_model(self, model_name: str, device_mode: str = "auto") -> tuple:
148 | """#### Load a SAM model.
149 |
150 | #### Args:
151 | - `model_name` (str): The name of the model.
152 | - `device_mode` (str, optional): The device mode. Defaults to "auto".
153 |
154 | #### Returns:
155 | - `tuple`: The loaded SAM model.
156 | """
157 | modelname = "./_internal/yolos/" + model_name
158 |
159 | if "vit_h" in model_name:
160 | model_kind = "vit_h"
161 | elif "vit_l" in model_name:
162 | model_kind = "vit_l"
163 | else:
164 | model_kind = "vit_b"
165 |
166 | sam = sam_model_registry[model_kind](checkpoint=modelname)
167 | size = os.path.getsize(modelname)
168 | safe_to = SafeToGPU(size)
169 |
170 |         # Move to GPU only when the user explicitly prefers it; otherwise stay on CPU.
171 |         device = Device.get_torch_device() if device_mode == "Prefer GPU" else "cpu"
172 |
173 | if device_mode == "Prefer GPU":
174 | safe_to.to_device(sam, device)
175 |
176 | is_auto_mode = device_mode == "AUTO"
177 |
178 | sam_obj = SAMWrapper(sam, is_auto_mode=is_auto_mode, safe_to_gpu=safe_to)
179 | sam.sam_wrapper = sam_obj
180 |
181 |         print(f"Loaded SAM model: {modelname} (device mode: {device_mode})")
182 | return (sam,)
183 |
184 |
185 | def make_sam_mask(
186 | sam: SAMWrapper,
187 | segs: tuple,
188 | image: torch.Tensor,
189 | detection_hint: bool,
190 | dilation: int,
191 | threshold: float,
192 | bbox_expansion: int,
193 | mask_hint_threshold: float,
194 | mask_hint_use_negative: bool,
195 | ) -> torch.Tensor:
196 | """#### Create a SAM mask.
197 |
198 | #### Args:
199 | - `sam` (SAMWrapper): The SAM wrapper.
200 | - `segs` (tuple): Segmentation information.
201 | - `image` (torch.Tensor): The input image.
202 | - `detection_hint` (bool): Whether to use detection hint.
203 | - `dilation` (int): Dilation value.
204 | - `threshold` (float): Threshold for mask selection.
205 | - `bbox_expansion` (int): Bounding box expansion value.
206 | - `mask_hint_threshold` (float): Mask hint threshold.
207 | - `mask_hint_use_negative` (bool): Whether to use negative mask hint.
208 |
209 | #### Returns:
210 | - `torch.Tensor`: The created SAM mask.
211 | """
212 | sam_obj = sam.sam_wrapper
213 | sam_obj.prepare_device()
214 |
215 | try:
216 | image = np.clip(255.0 * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
217 |
218 | total_masks = []
219 | # seg_shape = segs[0]
220 | segs = segs[1]
221 | for i in range(len(segs)):
222 | bbox = segs[i].bbox
223 | center = mask_util.center_of_bbox(bbox)
224 | x1 = max(bbox[0] - bbox_expansion, 0)
225 | y1 = max(bbox[1] - bbox_expansion, 0)
226 | x2 = min(bbox[2] + bbox_expansion, image.shape[1])
227 | y2 = min(bbox[3] + bbox_expansion, image.shape[0])
228 | dilated_bbox = [x1, y1, x2, y2]
229 |             # prompt SAM with the bbox center as a single foreground point
230 |             # (1 = foreground point, 0 = background point)
231 |             points = [center]
232 |             plabs = [1]
233 | detected_masks = sam_obj.predict(
234 | image, points, plabs, dilated_bbox, threshold
235 | )
236 | total_masks += detected_masks
237 |
238 |         # merge all collected masks into a single mask
239 | mask = mask_util.combine_masks2(total_masks)
240 |
241 | finally:
242 | sam_obj.release_device()
243 |
244 | if mask is not None:
245 | mask = mask.float()
246 | mask = mask_util.dilate_mask(mask.cpu().numpy(), dilation)
247 | mask = torch.from_numpy(mask)
248 |
249 | mask = mask_util.make_3d_mask(mask)
250 | return mask
251 | else:
252 | return None
253 |
254 |
255 | class SAMDetectorCombined:
256 | """#### Class to combine SAM detection."""
257 |
258 | def doit(
259 | self,
260 | sam_model: SAMWrapper,
261 | segs: tuple,
262 | image: torch.Tensor,
263 | detection_hint: bool,
264 | dilation: int,
265 | threshold: float,
266 | bbox_expansion: int,
267 | mask_hint_threshold: float,
268 | mask_hint_use_negative: bool,
269 | ) -> tuple:
270 | """#### Combine SAM detection.
271 |
272 | #### Args:
273 | - `sam_model` (SAMWrapper): The SAM wrapper.
274 | - `segs` (tuple): Segmentation information.
275 | - `image` (torch.Tensor): The input image.
276 | - `detection_hint` (bool): Whether to use detection hint.
277 | - `dilation` (int): Dilation value.
278 | - `threshold` (float): Threshold for mask selection.
279 | - `bbox_expansion` (int): Bounding box expansion value.
280 | - `mask_hint_threshold` (float): Mask hint threshold.
281 | - `mask_hint_use_negative` (bool): Whether to use negative mask hint.
282 |
283 | #### Returns:
284 | - `tuple`: The combined SAM detection result.
285 | """
286 | sam = make_sam_mask(
287 | sam_model,
288 | segs,
289 | image,
290 | detection_hint,
291 | dilation,
292 | threshold,
293 | bbox_expansion,
294 | mask_hint_threshold,
295 | mask_hint_use_negative,
296 | )
297 | if sam is not None:
298 | return (sam,)
299 | else:
300 | return None
301 |
--------------------------------------------------------------------------------
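
A minimal usage sketch of the classes above, assuming `image` and `segs` come from the bbox detector in modules/AutoDetailer/bbox.py and the checkpoint is the `sam_vit_b_01ec64.pth` file fetched by modules/FileManaging/Downloader.py; the numeric thresholds are illustrative, not the pipeline's defaults:

    # Illustrative only: `image` is a [1, H, W, 3] float tensor in [0, 1] and
    # `segs` is the (shape, [SEG, ...]) tuple produced by UltraBBoxDetector.
    (sam,) = SAMLoader().load_model("sam_vit_b_01ec64.pth", device_mode="AUTO")

    result = SAMDetectorCombined().doit(
        sam, segs, image,
        detection_hint=False,
        dilation=0,
        threshold=0.93,            # assumed mask-selection threshold
        bbox_expansion=0,
        mask_hint_threshold=0.7,   # accepted but unused by this simplified implementation
        mask_hint_use_negative=False,
    )
    mask = result[0] if result is not None else None  # 3D mask tensor, or None if nothing was detected
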
/modules/AutoDetailer/SEGS.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import numpy as np
3 | import torch
4 | from modules.AutoDetailer import mask_util
5 |
6 | SEG = namedtuple(
7 | "SEG",
8 | [
9 | "cropped_image",
10 | "cropped_mask",
11 | "confidence",
12 | "crop_region",
13 | "bbox",
14 | "label",
15 | "control_net_wrapper",
16 | ],
17 | defaults=[None],
18 | )
19 |
20 |
21 | def segs_bitwise_and_mask(segs: tuple, mask: torch.Tensor) -> tuple:
22 | """#### Apply bitwise AND operation between segmentation masks and a given mask.
23 |
24 | #### Args:
25 | - `segs` (tuple): A tuple containing segmentation information.
26 | - `mask` (torch.Tensor): The mask tensor.
27 |
28 | #### Returns:
29 | - `tuple`: A tuple containing the original segmentation and the updated items.
30 | """
31 | mask = mask_util.make_2d_mask(mask)
32 | items = []
33 |
34 | mask = (mask.cpu().numpy() * 255).astype(np.uint8)
35 |
36 | for seg in segs[1]:
37 | cropped_mask = (seg.cropped_mask * 255).astype(np.uint8)
38 | crop_region = seg.crop_region
39 |
40 | cropped_mask2 = mask[
41 | crop_region[1] : crop_region[3], crop_region[0] : crop_region[2]
42 | ]
43 |
44 | new_mask = np.bitwise_and(cropped_mask.astype(np.uint8), cropped_mask2)
45 | new_mask = new_mask.astype(np.float32) / 255.0
46 |
47 | item = SEG(
48 | seg.cropped_image,
49 | new_mask,
50 | seg.confidence,
51 | seg.crop_region,
52 | seg.bbox,
53 | seg.label,
54 | None,
55 | )
56 | items.append(item)
57 |
58 | return segs[0], items
59 |
60 |
61 | class SegsBitwiseAndMask:
62 | """#### Class to apply bitwise AND operation between segmentation masks and a given mask."""
63 |
64 | def doit(self, segs: tuple, mask: torch.Tensor) -> tuple:
65 | """#### Apply bitwise AND operation between segmentation masks and a given mask.
66 |
67 | #### Args:
68 | - `segs` (tuple): A tuple containing segmentation information.
69 | - `mask` (torch.Tensor): The mask tensor.
70 |
71 | #### Returns:
72 | - `tuple`: A tuple containing the original segmentation and the updated items.
73 | """
74 | return (segs_bitwise_and_mask(segs, mask),)
75 |
76 |
77 | class SEGSLabelFilter:
78 | """#### Class to filter segmentation labels."""
79 |
80 | @staticmethod
81 | def filter(segs: tuple, labels: list) -> tuple:
82 | """#### Filter segmentation labels.
83 |
84 | #### Args:
85 | - `segs` (tuple): A tuple containing segmentation information.
86 | - `labels` (list): A list of labels to filter.
87 |
88 | #### Returns:
89 | - `tuple`: A tuple containing the original segmentation and an empty list.
90 | """
91 |         labels = set([label.strip() for label in labels])  # parsed but unused: filtering is a no-op here
92 |         return (
93 |             segs,
94 |             (segs[0], []),
95 |         )
96 |
--------------------------------------------------------------------------------
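
A small sketch of how a SEG record and the bitwise-AND helper fit together, using made-up sizes and a synthetic mask (everything below is illustrative):

    import numpy as np
    import torch
    from modules.AutoDetailer import SEGS

    # One hypothetical 64x64 segment covering the top-left corner of a 128x128 image.
    seg = SEGS.SEG(
        cropped_image=None,
        cropped_mask=np.ones((64, 64), dtype=np.float32),
        confidence=0.9,
        crop_region=[0, 0, 64, 64],   # x1, y1, x2, y2
        bbox=[0, 0, 64, 64],
        label="face",
        control_net_wrapper=None,
    )
    segs = ((128, 128), [seg])

    mask = torch.ones((1, 128, 128))    # full-frame mask to intersect with
    shape, items = SEGS.segs_bitwise_and_mask(segs, mask)
    print(items[0].cropped_mask.shape)  # (64, 64), values back in [0, 1]
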
/modules/AutoDetailer/bbox.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from ultralytics import YOLO
3 | from modules.AutoDetailer import SEGS, AD_util, tensor_util
4 | from typing import List, Tuple, Optional
5 |
6 |
7 | class UltraBBoxDetector:
8 | """#### Class to detect bounding boxes using a YOLO model."""
9 |
10 | bbox_model: Optional[YOLO] = None
11 |
12 | def __init__(self, bbox_model: YOLO):
13 | """#### Initialize the UltraBBoxDetector with a YOLO model.
14 |
15 | #### Args:
16 | - `bbox_model` (YOLO): The YOLO model to use for detection.
17 | """
18 | self.bbox_model = bbox_model
19 |
20 | def detect(
21 | self,
22 | image: torch.Tensor,
23 | threshold: float,
24 | dilation: int,
25 | crop_factor: float,
26 | drop_size: int = 1,
27 | detailer_hook: Optional[callable] = None,
28 | ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
29 | """#### Detect bounding boxes in an image.
30 |
31 | #### Args:
32 | - `image` (torch.Tensor): The input image tensor.
33 | - `threshold` (float): The detection threshold.
34 | - `dilation` (int): The dilation factor for masks.
35 | - `crop_factor` (float): The crop factor for bounding boxes.
36 | - `drop_size` (int, optional): The minimum size of bounding boxes to keep. Defaults to 1.
37 | - `detailer_hook` (callable, optional): A hook function for additional processing. Defaults to None.
38 |
39 | #### Returns:
40 | - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The shape of the image and a list of detected segments.
41 | """
42 | drop_size = max(drop_size, 1)
43 | detected_results = AD_util.inference_bbox(
44 | self.bbox_model, tensor_util.tensor2pil(image), threshold
45 | )
46 | segmasks = AD_util.create_segmasks(detected_results)
47 |
48 | if dilation > 0:
49 | segmasks = AD_util.dilate_masks(segmasks, dilation)
50 |
51 | items = []
52 | h = image.shape[1]
53 | w = image.shape[2]
54 |
55 | for x, label in zip(segmasks, detected_results[0]):
56 | item_bbox = x[0]
57 | item_mask = x[1]
58 |
59 | y1, x1, y2, x2 = item_bbox
60 |
61 | if (
62 | x2 - x1 > drop_size and y2 - y1 > drop_size
63 | ): # minimum dimension must be (2,2) to avoid squeeze issue
64 | crop_region = AD_util.make_crop_region(w, h, item_bbox, crop_factor)
65 |
66 | cropped_image = AD_util.crop_image(image, crop_region)
67 | cropped_mask = AD_util.crop_ndarray2(item_mask, crop_region)
68 | confidence = x[2]
69 |
70 | item = SEGS.SEG(
71 | cropped_image,
72 | cropped_mask,
73 | confidence,
74 | crop_region,
75 | item_bbox,
76 | label,
77 | None,
78 | )
79 |
80 | items.append(item)
81 |
82 | shape = image.shape[1], image.shape[2]
83 | segs = shape, items
84 |
85 | return segs
86 |
87 |
88 | class UltraSegmDetector:
89 | """#### Class to detect segments using a YOLO model."""
90 |
91 | bbox_model: Optional[YOLO] = None
92 |
93 | def __init__(self, bbox_model: YOLO):
94 | """#### Initialize the UltraSegmDetector with a YOLO model.
95 |
96 | #### Args:
97 | - `bbox_model` (YOLO): The YOLO model to use for detection.
98 | """
99 | self.bbox_model = bbox_model
100 |
101 |
102 | class NO_SEGM_DETECTOR:
103 | """#### Placeholder class for no segment detector."""
104 |
105 | pass
106 |
107 |
108 | class UltralyticsDetectorProvider:
109 | """#### Class to provide YOLO models for detection."""
110 |
111 | def doit(self, model_name: str) -> Tuple[UltraBBoxDetector, UltraSegmDetector]:
112 | """#### Load a YOLO model and return detectors.
113 |
114 | #### Args:
115 | - `model_name` (str): The name of the YOLO model to load.
116 |
117 | #### Returns:
118 | - `Tuple[UltraBBoxDetector, UltraSegmDetector]`: The bounding box and segment detectors.
119 | """
120 | model = AD_util.load_yolo("./_internal/yolos/" + model_name)
121 | return UltraBBoxDetector(model), UltraSegmDetector(model)
122 |
123 |
124 | class BboxDetectorForEach:
125 | """#### Class to detect bounding boxes for each segment."""
126 |
127 | def doit(
128 | self,
129 | bbox_detector: UltraBBoxDetector,
130 | image: torch.Tensor,
131 | threshold: float,
132 | dilation: int,
133 | crop_factor: float,
134 | drop_size: int,
135 | labels: Optional[str] = None,
136 | detailer_hook: Optional[callable] = None,
137 | ) -> Tuple[Tuple[int, int], List[SEGS.SEG]]:
138 | """#### Detect bounding boxes for each segment in an image.
139 |
140 | #### Args:
141 | - `bbox_detector` (UltraBBoxDetector): The bounding box detector.
142 | - `image` (torch.Tensor): The input image tensor.
143 | - `threshold` (float): The detection threshold.
144 | - `dilation` (int): The dilation factor for masks.
145 | - `crop_factor` (float): The crop factor for bounding boxes.
146 | - `drop_size` (int): The minimum size of bounding boxes to keep.
147 | - `labels` (str, optional): The labels to filter. Defaults to None.
148 | - `detailer_hook` (callable, optional): A hook function for additional processing. Defaults to None.
149 |
150 | #### Returns:
151 | - `Tuple[Tuple[int, int], List[SEGS.SEG]]`: The shape of the image and a list of detected segments.
152 | """
153 | segs = bbox_detector.detect(
154 | image, threshold, dilation, crop_factor, drop_size, detailer_hook
155 | )
156 |
157 | if labels is not None and labels != "":
158 | labels = labels.split(",")
159 | if len(labels) > 0:
160 | segs, _ = SEGS.SEGSLabelFilter.filter(segs, labels)
161 |
162 | return segs
163 |
164 |
165 | class WildcardChooser:
166 | """#### Class to choose wildcards for segments."""
167 |
168 | def __init__(self, items: List[Tuple[None, str]], randomize_when_exhaust: bool):
169 | """#### Initialize the WildcardChooser.
170 |
171 | #### Args:
172 | - `items` (List[Tuple[None, str]]): The list of items to choose from.
173 | - `randomize_when_exhaust` (bool): Whether to randomize when the list is exhausted.
174 | """
175 | self.i = 0
176 | self.items = items
177 | self.randomize_when_exhaust = randomize_when_exhaust
178 |
179 | def get(self, seg: SEGS.SEG) -> Tuple[None, str]:
180 | """#### Get the next item from the list.
181 |
182 | #### Args:
183 | - `seg` (SEGS.SEG): The segment.
184 |
185 | #### Returns:
186 | - `Tuple[None, str]`: The next item from the list.
187 | """
188 |         item = self.items[self.i % len(self.items)]  # wrap around instead of raising IndexError when exhausted
189 |         self.i += 1
190 |
191 |         return item
192 |
193 |
194 | def process_wildcard_for_segs(wildcard: str) -> Tuple[None, WildcardChooser]:
195 | """#### Process a wildcard for segments.
196 |
197 | #### Args:
198 | - `wildcard` (str): The wildcard.
199 |
200 | #### Returns:
201 | - `Tuple[None, WildcardChooser]`: The processed wildcard and a WildcardChooser.
202 | """
203 | return None, WildcardChooser([(None, wildcard)], False)
204 |
--------------------------------------------------------------------------------
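
A usage sketch of the detector provider and per-segment bbox detection, assuming the `face_yolov9c.pt` weights fetched by Downloader.py; the numeric settings are placeholders:

    # Illustrative only: `image` is a [1, H, W, 3] float tensor in [0, 1].
    bbox_detector, segm_detector = UltralyticsDetectorProvider().doit("face_yolov9c.pt")

    segs = BboxDetectorForEach().doit(
        bbox_detector,
        image,
        threshold=0.5,    # detection confidence cutoff
        dilation=4,       # grow each mask by a few pixels
        crop_factor=3.0,  # context kept around each bbox
        drop_size=10,     # ignore detections smaller than this
    )
    (height, width), items = segs
    print(f"{len(items)} detections on a {width}x{height} image")
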
/modules/AutoDetailer/mask_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def center_of_bbox(bbox: list) -> tuple[float, float]:
6 | """#### Calculate the center of a bounding box.
7 |
8 | #### Args:
9 | - `bbox` (list): The bounding box coordinates [x1, y1, x2, y2].
10 |
11 | #### Returns:
12 | - `tuple[float, float]`: The center coordinates (x, y).
13 | """
14 | w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
15 | return bbox[0] + w / 2, bbox[1] + h / 2
16 |
17 |
18 | def make_2d_mask(mask: torch.Tensor) -> torch.Tensor:
19 | """#### Convert a mask to 2D.
20 |
21 | #### Args:
22 | - `mask` (torch.Tensor): The input mask tensor.
23 |
24 | #### Returns:
25 | - `torch.Tensor`: The 2D mask tensor.
26 | """
27 | if len(mask.shape) == 4:
28 | return mask.squeeze(0).squeeze(0)
29 | elif len(mask.shape) == 3:
30 | return mask.squeeze(0)
31 | return mask
32 |
33 |
34 | def combine_masks2(masks: list) -> torch.Tensor | None:
35 | """#### Combine multiple masks into one.
36 |
37 | #### Args:
38 | - `masks` (list): A list of mask tensors.
39 |
40 | #### Returns:
41 | - `torch.Tensor | None`: The combined mask tensor or None if no masks are provided.
42 | """
43 |     try:
44 |         mask = torch.from_numpy(np.array(masks[0]).astype(np.uint8))  # simplified: keep only the first mask
45 |     except Exception:
46 |         print("No Human Detected")
47 |         return None
48 |     return mask
49 |
50 |
51 | def dilate_mask(
52 | mask: torch.Tensor, dilation_factor: int, iter: int = 1
53 | ) -> torch.Tensor:
54 | """#### Dilate a mask.
55 |
56 | #### Args:
57 | - `mask` (torch.Tensor): The input mask tensor.
58 | - `dilation_factor` (int): The dilation factor.
59 | - `iter` (int, optional): The number of iterations. Defaults to 1.
60 |
61 | #### Returns:
62 | - `torch.Tensor`: The dilated mask tensor.
63 | """
64 |     return make_2d_mask(mask)  # simplified: dilation_factor is ignored; only the shape is normalized to 2D
65 |
66 |
67 | def make_3d_mask(mask: torch.Tensor) -> torch.Tensor:
68 | """#### Convert a mask to 3D.
69 |
70 | #### Args:
71 | - `mask` (torch.Tensor): The input mask tensor.
72 |
73 | #### Returns:
74 | - `torch.Tensor`: The 3D mask tensor.
75 | """
76 | if len(mask.shape) == 4:
77 | return mask.squeeze(0)
78 | elif len(mask.shape) == 2:
79 | return mask.unsqueeze(0)
80 | return mask
81 |
--------------------------------------------------------------------------------
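
The helpers above are mostly shape plumbing; a quick sketch of what they do to a synthetic mask:

    import torch
    from modules.AutoDetailer import mask_util

    mask = torch.ones((1, 1, 64, 64))            # 4D mask, e.g. [B, C, H, W]
    flat = mask_util.make_2d_mask(mask)          # -> torch.Size([64, 64])
    back = mask_util.make_3d_mask(flat)          # -> torch.Size([1, 64, 64])

    cx, cy = mask_util.center_of_bbox([10, 20, 50, 60])
    print(flat.shape, back.shape, (cx, cy))      # ... (30.0, 40.0)
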
/modules/AutoDetailer/tensor_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image
4 | import torchvision
5 |
6 | from modules.Device import Device
7 |
8 |
9 | def _tensor_check_image(image: torch.Tensor) -> None:
10 | """#### Check if the input is a valid tensor image.
11 |
12 | #### Args:
13 | - `image` (torch.Tensor): The input tensor image.
14 | """
15 | return
16 |
17 |
18 | def tensor2pil(image: torch.Tensor) -> Image.Image:
19 | """#### Convert a tensor to a PIL image.
20 |
21 | #### Args:
22 | - `image` (torch.Tensor): The input tensor.
23 |
24 | #### Returns:
25 | - `Image.Image`: The converted PIL image.
26 | """
27 | _tensor_check_image(image)
28 | return Image.fromarray(
29 | np.clip(255.0 * image.cpu().numpy().squeeze(0), 0, 255).astype(np.uint8)
30 | )
31 |
32 |
33 | def general_tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
34 | """#### Resize a tensor image using bilinear interpolation.
35 |
36 | #### Args:
37 | - `image` (torch.Tensor): The input tensor image.
38 | - `w` (int): The target width.
39 | - `h` (int): The target height.
40 |
41 | #### Returns:
42 | - `torch.Tensor`: The resized tensor image.
43 | """
44 | _tensor_check_image(image)
45 | image = image.permute(0, 3, 1, 2)
46 | image = torch.nn.functional.interpolate(image, size=(h, w), mode="bilinear")
47 | image = image.permute(0, 2, 3, 1)
48 | return image
49 |
50 |
51 | def pil2tensor(image: Image.Image) -> torch.Tensor:
52 | """#### Convert a PIL image to a tensor.
53 |
54 | #### Args:
55 | - `image` (Image.Image): The input PIL image.
56 |
57 | #### Returns:
58 | - `torch.Tensor`: The converted tensor.
59 | """
60 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
61 |
62 |
63 | class TensorBatchBuilder:
64 | """#### Class for building a batch of tensors."""
65 |
66 | def __init__(self):
67 | self.tensor: torch.Tensor | None = None
68 |
69 | def concat(self, new_tensor: torch.Tensor) -> None:
70 | """#### Concatenate a new tensor to the batch.
71 |
72 | #### Args:
73 | - `new_tensor` (torch.Tensor): The new tensor to concatenate.
74 | """
75 |         self.tensor = new_tensor if self.tensor is None else torch.cat((self.tensor, new_tensor), dim=0)
76 |
77 |
78 | LANCZOS = Image.Resampling.LANCZOS if hasattr(Image, "Resampling") else Image.LANCZOS
79 |
80 |
81 | def tensor_resize(image: torch.Tensor, w: int, h: int) -> torch.Tensor:
82 | """#### Resize a tensor image.
83 |
84 | #### Args:
85 | - `image` (torch.Tensor): The input tensor image.
86 | - `w` (int): The target width.
87 | - `h` (int): The target height.
88 |
89 | #### Returns:
90 | - `torch.Tensor`: The resized tensor image.
91 | """
92 | _tensor_check_image(image)
93 | if image.shape[3] >= 3:
94 | scaled_images = TensorBatchBuilder()
95 | for single_image in image:
96 | single_image = single_image.unsqueeze(0)
97 | single_pil = tensor2pil(single_image)
98 | scaled_pil = single_pil.resize((w, h), resample=LANCZOS)
99 |
100 | single_image = pil2tensor(scaled_pil)
101 | scaled_images.concat(single_image)
102 |
103 | return scaled_images.tensor
104 | else:
105 | return general_tensor_resize(image, w, h)
106 |
107 |
108 | def tensor_paste(
109 | image1: torch.Tensor,
110 | image2: torch.Tensor,
111 | left_top: tuple[int, int],
112 | mask: torch.Tensor,
113 | ) -> None:
114 | """#### Paste one tensor image onto another using a mask.
115 |
116 | #### Args:
117 | - `image1` (torch.Tensor): The base tensor image.
118 | - `image2` (torch.Tensor): The tensor image to paste.
119 | - `left_top` (tuple[int, int]): The top-left corner where the image2 will be pasted.
120 | - `mask` (torch.Tensor): The mask tensor.
121 | """
122 | _tensor_check_image(image1)
123 | _tensor_check_image(image2)
124 | _tensor_check_mask(mask)
125 |
126 | x, y = left_top
127 | _, h1, w1, _ = image1.shape
128 | _, h2, w2, _ = image2.shape
129 |
130 | # calculate image patch size
131 | w = min(w1, x + w2) - x
132 | h = min(h1, y + h2) - y
133 |
134 | mask = mask[:, :h, :w, :]
135 | image1[:, y : y + h, x : x + w, :] = (1 - mask) * image1[
136 | :, y : y + h, x : x + w, :
137 | ] + mask * image2[:, :h, :w, :]
138 | return
139 |
140 |
141 | def tensor_convert_rgba(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
142 | """#### Convert a tensor image to RGBA format.
143 |
144 | #### Args:
145 | - `image` (torch.Tensor): The input tensor image.
146 | - `prefer_copy` (bool, optional): Whether to prefer copying the tensor. Defaults to True.
147 |
148 | #### Returns:
149 | - `torch.Tensor`: The converted RGBA tensor image.
150 | """
151 | _tensor_check_image(image)
152 |     alpha = torch.ones((*image.shape[:-1], 1), device=image.device, dtype=image.dtype)
153 |     return torch.cat((image, alpha), dim=-1)
154 |
155 |
156 | def tensor_convert_rgb(image: torch.Tensor, prefer_copy: bool = True) -> torch.Tensor:
157 | """#### Convert a tensor image to RGB format.
158 |
159 | #### Args:
160 | - `image` (torch.Tensor): The input tensor image.
161 | - `prefer_copy` (bool, optional): Whether to prefer copying the tensor. Defaults to True.
162 |
163 | #### Returns:
164 | - `torch.Tensor`: The converted RGB tensor image.
165 | """
166 | _tensor_check_image(image)
167 | return image
168 |
169 |
170 | def tensor_get_size(image: torch.Tensor) -> tuple[int, int]:
171 | """#### Get the size of a tensor image.
172 |
173 | #### Args:
174 | - `image` (torch.Tensor): The input tensor image.
175 |
176 | #### Returns:
177 | - `tuple[int, int]`: The width and height of the tensor image.
178 | """
179 | _tensor_check_image(image)
180 | _, h, w, _ = image.shape
181 | return (w, h)
182 |
183 |
184 | def tensor_putalpha(image: torch.Tensor, mask: torch.Tensor) -> None:
185 | """#### Add an alpha channel to a tensor image using a mask.
186 |
187 | #### Args:
188 | - `image` (torch.Tensor): The input tensor image.
189 | - `mask` (torch.Tensor): The mask tensor.
190 | """
191 | _tensor_check_image(image)
192 | _tensor_check_mask(mask)
193 | image[..., -1] = mask[..., 0]
194 |
195 |
196 | def _tensor_check_mask(mask: torch.Tensor) -> None:
197 | """#### Check if the input is a valid tensor mask.
198 |
199 | #### Args:
200 | - `mask` (torch.Tensor): The input tensor mask.
201 | """
202 | return
203 |
204 |
205 | def tensor_gaussian_blur_mask(
206 | mask: torch.Tensor | np.ndarray, kernel_size: int, sigma: float = 10.0
207 | ) -> torch.Tensor:
208 | """#### Apply Gaussian blur to a tensor mask.
209 |
210 | #### Args:
211 | - `mask` (torch.Tensor | np.ndarray): The input tensor mask.
212 | - `kernel_size` (int): The size of the Gaussian kernel.
213 | - `sigma` (float, optional): The standard deviation of the Gaussian kernel. Defaults to 10.0.
214 |
215 | #### Returns:
216 | - `torch.Tensor`: The blurred tensor mask.
217 | """
218 | if isinstance(mask, np.ndarray):
219 | mask = torch.from_numpy(mask)
220 |
221 | if mask.ndim == 2:
222 | mask = mask[None, ..., None]
223 |
224 | _tensor_check_mask(mask)
225 |
226 | kernel_size = kernel_size * 2 + 1
227 |
228 | prev_device = mask.device
229 | device = Device.get_torch_device()
230 |     mask = mask.to(device)  # Tensor.to() is not in-place; keep the returned tensor
231 |
232 | # apply gaussian blur
233 | mask = mask[:, None, ..., 0]
234 | blurred_mask = torchvision.transforms.GaussianBlur(
235 | kernel_size=kernel_size, sigma=sigma
236 | )(mask)
237 | blurred_mask = blurred_mask[:, 0, ..., None]
238 |
239 |     blurred_mask = blurred_mask.to(prev_device)
240 |
241 | return blurred_mask
242 |
243 |
244 | def to_tensor(image: np.ndarray) -> torch.Tensor:
245 | """#### Convert a numpy array to a tensor.
246 |
247 | #### Args:
248 | - `image` (np.ndarray): The input numpy array.
249 |
250 | #### Returns:
251 | - `torch.Tensor`: The converted tensor.
252 | """
253 | return torch.from_numpy(image)
254 |
--------------------------------------------------------------------------------
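
A short sketch of the conversion and resize helpers on a synthetic [B, H, W, C] batch (the sizes are arbitrary):

    import torch
    from modules.AutoDetailer import tensor_util

    batch = torch.rand((1, 128, 96, 3))                      # single image, [B, H, W, C]

    resized = tensor_util.tensor_resize(batch, 64, 48)       # LANCZOS path since C >= 3
    print(tensor_util.tensor_get_size(resized))              # (64, 48) as (width, height)

    pil = tensor_util.tensor2pil(resized)                    # round-trip through PIL
    again = tensor_util.pil2tensor(pil)                      # back to [1, 48, 64, 3]

    blurred = tensor_util.tensor_gaussian_blur_mask(
        torch.rand((128, 96)), kernel_size=10, sigma=10.0    # a bare 2D mask is auto-expanded
    )
    print(resized.shape, again.shape, blurred.shape)
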
/modules/AutoHDR/ahdr.py:
--------------------------------------------------------------------------------
1 | # Taken and adapted from https://github.com/SuperBeastsAI/ComfyUI-SuperBeasts
2 |
3 | import numpy as np
4 | from PIL import Image, ImageOps, ImageDraw, ImageFilter, ImageEnhance, ImageCms
5 | from PIL.PngImagePlugin import PngInfo
6 | import torch
7 | import torch.nn.functional as F
8 | import json
9 | import random
10 |
11 |
12 | sRGB_profile = ImageCms.createProfile("sRGB")
13 | Lab_profile = ImageCms.createProfile("LAB")
14 |
15 | # Tensor to PIL
16 | def tensor2pil(image):
17 | return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
18 |
19 | # PIL to Tensor
20 | def pil2tensor(image):
21 | return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
22 |
23 | def adjust_shadows_non_linear(luminance, shadow_intensity, max_shadow_adjustment=1.5):
24 | lum_array = np.array(luminance, dtype=np.float32) / 255.0 # Normalize
25 | # Apply a non-linear darkening effect based on shadow_intensity
26 | shadows = lum_array ** (1 / (1 + shadow_intensity * max_shadow_adjustment))
27 | return np.clip(shadows * 255, 0, 255).astype(np.uint8) # Re-scale to [0, 255]
28 |
29 | def adjust_highlights_non_linear(luminance, highlight_intensity, max_highlight_adjustment=1.5):
30 | lum_array = np.array(luminance, dtype=np.float32) / 255.0 # Normalize
31 | # Brighten highlights more aggressively based on highlight_intensity
32 | highlights = 1 - (1 - lum_array) ** (1 + highlight_intensity * max_highlight_adjustment)
33 | return np.clip(highlights * 255, 0, 255).astype(np.uint8) # Re-scale to [0, 255]
34 |
35 | def merge_adjustments_with_blend_modes(luminance, shadows, highlights, hdr_intensity, shadow_intensity, highlight_intensity):
36 | # Ensure the data is in the correct format for processing
37 | base = np.array(luminance, dtype=np.float32)
38 |
39 | # Scale the adjustments based on hdr_intensity
40 | scaled_shadow_intensity = shadow_intensity ** 2 * hdr_intensity
41 | scaled_highlight_intensity = highlight_intensity ** 2 * hdr_intensity
42 |
43 | # Create luminance-based masks for shadows and highlights
44 | shadow_mask = np.clip((1 - (base / 255)) ** 2, 0, 1)
45 | highlight_mask = np.clip((base / 255) ** 2, 0, 1)
46 |
47 | # Apply the adjustments using the masks
48 | adjusted_shadows = np.clip(base * (1 - shadow_mask * scaled_shadow_intensity), 0, 255)
49 | adjusted_highlights = np.clip(base + (255 - base) * highlight_mask * scaled_highlight_intensity, 0, 255)
50 |
51 | # Combine the adjusted shadows and highlights
52 | adjusted_luminance = np.clip(adjusted_shadows + adjusted_highlights - base, 0, 255)
53 |
54 | # Blend the adjusted luminance with the original luminance based on hdr_intensity
55 | final_luminance = np.clip(base * (1 - hdr_intensity) + adjusted_luminance * hdr_intensity, 0, 255).astype(np.uint8)
56 |
57 | return Image.fromarray(final_luminance)
58 |
59 | def apply_gamma_correction(lum_array, gamma):
60 | """
61 | Apply gamma correction to the luminance array.
62 | :param lum_array: Luminance channel as a NumPy array.
63 | :param gamma: Gamma value for correction.
64 | """
65 | if gamma == 0:
66 | return np.clip(lum_array, 0, 255).astype(np.uint8)
67 |
68 |     epsilon = 1e-7  # small value to avoid dividing by zero when gamma approaches 1.1
69 |     gamma_corrected = 1 / max(1.1 - gamma, epsilon)
70 | adjusted = 255 * ((lum_array / 255) ** gamma_corrected)
71 | return np.clip(adjusted, 0, 255).astype(np.uint8)
72 |
73 | # create a wrapper function that can apply a function to multiple images in a batch while passing all other arguments to the function
74 | def apply_to_batch(func):
75 | def wrapper(self, image, *args, **kwargs):
76 | images = []
77 | for img in image:
78 | images.append(func(self, img, *args, **kwargs))
79 | batch_tensor = torch.cat(images, dim=0)
80 | return (batch_tensor, )
81 | return wrapper
82 |
83 | class HDREffects:
84 | @apply_to_batch
85 | def apply_hdr2(self, image, hdr_intensity=0.75, shadow_intensity=0.25, highlight_intensity=0.5, gamma_intensity=0.25, contrast=0.1, enhance_color=0.25):
86 | # Load the image
87 | img = tensor2pil(image)
88 |
89 | # Step 1: Convert RGB to LAB for better color preservation
90 | img_lab = ImageCms.profileToProfile(img, sRGB_profile, Lab_profile, outputMode='LAB')
91 |
92 | # Extract L, A, and B channels
93 | luminance, a, b = img_lab.split()
94 |
95 | # Convert luminance to a NumPy array for processing
96 | lum_array = np.array(luminance, dtype=np.float32)
97 |
98 |         # Prepare shadow and highlight adjustment layers
99 |         # derived from the luminance channel
100 | shadows_adjusted = adjust_shadows_non_linear(luminance, shadow_intensity)
101 | highlights_adjusted = adjust_highlights_non_linear(luminance, highlight_intensity)
102 |
103 |
104 | merged_adjustments = merge_adjustments_with_blend_modes(lum_array, shadows_adjusted, highlights_adjusted, hdr_intensity, shadow_intensity, highlight_intensity)
105 |
106 |         # Apply gamma correction driven by gamma_intensity
107 | gamma_corrected = apply_gamma_correction(np.array(merged_adjustments), gamma_intensity)
108 | gamma_corrected = Image.fromarray(gamma_corrected).resize(a.size)
109 |
110 |
111 | # Merge L channel back with original A and B channels
112 | adjusted_lab = Image.merge('LAB', (gamma_corrected, a, b))
113 |
114 | # Step 3: Convert LAB back to RGB
115 | img_adjusted = ImageCms.profileToProfile(adjusted_lab, Lab_profile, sRGB_profile, outputMode='RGB')
116 |
117 |
118 | # Enhance contrast
119 | enhancer = ImageEnhance.Contrast(img_adjusted)
120 | contrast_adjusted = enhancer.enhance(1 + contrast)
121 |
122 |
123 | # Enhance color saturation
124 | enhancer = ImageEnhance.Color(contrast_adjusted)
125 | color_adjusted = enhancer.enhance(1 + enhance_color * 0.2)
126 |
127 | return pil2tensor(color_adjusted)
--------------------------------------------------------------------------------
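
A usage sketch for the HDR effect on a synthetic batch; the intensity values mirror the function's defaults, and the input layout is the [B, H, W, C] float format used by the other helpers in this repo:

    import torch
    from modules.AutoHDR import ahdr

    images = torch.rand((1, 512, 512, 3))        # hypothetical render in [0, 1]

    (hdr_batch,) = ahdr.HDREffects().apply_hdr2(
        images,
        hdr_intensity=0.75,
        shadow_intensity=0.25,
        highlight_intensity=0.5,
        gamma_intensity=0.25,
        contrast=0.1,
        enhance_color=0.25,
    )
    print(hdr_batch.shape)                       # torch.Size([1, 512, 512, 3])
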
/modules/FileManaging/Downloader.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from huggingface_hub import hf_hub_download
3 |
4 |
5 | def CheckAndDownload():
6 | """#### Check and download all the necessary safetensors and checkpoints models"""
7 | if glob.glob("./_internal/checkpoints/*.safetensors") == []:
8 |
9 | hf_hub_download(
10 | repo_id="Meina/MeinaMix",
11 | filename="Meina V10 - baked VAE.safetensors",
12 | local_dir="./_internal/checkpoints/",
13 | )
14 | hf_hub_download(
15 | repo_id="Lykon/DreamShaper",
16 | filename="DreamShaper_8_pruned.safetensors",
17 | local_dir="./_internal/checkpoints/",
18 | )
19 | if glob.glob("./_internal/yolos/*.pt") == []:
20 |
21 | hf_hub_download(
22 | repo_id="Bingsu/adetailer",
23 | filename="hand_yolov9c.pt",
24 | local_dir="./_internal/yolos/",
25 | )
26 | hf_hub_download(
27 | repo_id="Bingsu/adetailer",
28 | filename="face_yolov9c.pt",
29 | local_dir="./_internal/yolos/",
30 | )
31 | hf_hub_download(
32 | repo_id="Bingsu/adetailer",
33 | filename="person_yolov8m-seg.pt",
34 | local_dir="./_internal/yolos/",
35 | )
36 | hf_hub_download(
37 | repo_id="segments-arnaud/sam_vit_b",
38 | filename="sam_vit_b_01ec64.pth",
39 | local_dir="./_internal/yolos/",
40 | )
41 | if glob.glob("./_internal/ESRGAN/*.pth") == []:
42 |
43 | hf_hub_download(
44 | repo_id="lllyasviel/Annotators",
45 | filename="RealESRGAN_x4plus.pth",
46 | local_dir="./_internal/ESRGAN/",
47 | )
48 | if glob.glob("./_internal/loras/*.safetensors") == []:
49 |
50 | hf_hub_download(
51 | repo_id="EvilEngine/add_detail",
52 | filename="add_detail.safetensors",
53 | local_dir="./_internal/loras/",
54 | )
55 | if glob.glob("./_internal/embeddings/*.pt") == []:
56 |
57 | hf_hub_download(
58 | repo_id="EvilEngine/badhandv4",
59 | filename="badhandv4.pt",
60 | local_dir="./_internal/embeddings/",
61 | )
62 | # hf_hub_download(
63 | # repo_id="segments-arnaud/sam_vit_b",
64 | # filename="EasyNegative.safetensors",
65 | # local_dir="./_internal/embeddings/",
66 | # )
67 | if glob.glob("./_internal/vae_approx/*.pth") == []:
68 |
69 | hf_hub_download(
70 | repo_id="madebyollin/taesd",
71 | filename="taesd_decoder.safetensors",
72 | local_dir="./_internal/vae_approx/",
73 | )
74 |
75 | def CheckAndDownloadFlux():
76 | """#### Check and download all the necessary safetensors and checkpoints models for FLUX"""
77 | if glob.glob("./_internal/embeddings/*.pt") == []:
78 | hf_hub_download(
79 | repo_id="EvilEngine/badhandv4",
80 | filename="badhandv4.pt",
81 | local_dir="./_internal/embeddings",
82 | )
83 | if glob.glob("./_internal/unet/*.gguf") == []:
84 |
85 | hf_hub_download(
86 | repo_id="city96/FLUX.1-dev-gguf",
87 | filename="flux1-dev-Q8_0.gguf",
88 | local_dir="./_internal/unet",
89 | )
90 | if glob.glob("./_internal/clip/*.gguf") == []:
91 |
92 | hf_hub_download(
93 | repo_id="city96/t5-v1_1-xxl-encoder-gguf",
94 | filename="t5-v1_1-xxl-encoder-Q8_0.gguf",
95 | local_dir="./_internal/clip",
96 | )
97 | hf_hub_download(
98 | repo_id="comfyanonymous/flux_text_encoders",
99 | filename="clip_l.safetensors",
100 | local_dir="./_internal/clip",
101 | )
102 | if glob.glob("./_internal/vae/*.safetensors") == []:
103 |
104 | hf_hub_download(
105 | repo_id="black-forest-labs/FLUX.1-schnell",
106 | filename="ae.safetensors",
107 | local_dir="./_internal/vae",
108 | )
109 |
110 | if glob.glob("./_internal/vae_approx/*.pth") == []:
111 |
112 | hf_hub_download(
113 | repo_id="madebyollin/taef1",
114 | filename="diffusion_pytorch_model.safetensors",
115 | local_dir="./_internal/vae_approx/",
116 | )
117 |
--------------------------------------------------------------------------------
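
A minimal sketch of how these guards are intended to be used: call them once before the first generation (the second call is only needed for the FLUX pipeline):

    from modules.FileManaging import Downloader

    Downloader.CheckAndDownload()      # SD1.5 checkpoints, YOLO/SAM weights, ESRGAN, LoRA, embeddings
    Downloader.CheckAndDownloadFlux()  # FLUX GGUF UNet, T5/CLIP encoders, VAE
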
/modules/FileManaging/ImageSaver.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from PIL import Image
4 |
5 | output_directory = "./_internal/output"
6 |
7 |
8 | def get_output_directory() -> str:
9 | """#### Get the output directory.
10 |
11 | #### Returns:
12 | - `str`: The output directory.
13 | """
14 | global output_directory
15 | return output_directory
16 |
17 |
18 | def get_save_image_path(
19 | filename_prefix: str, output_dir: str, image_width: int = 0, image_height: int = 0
20 | ) -> tuple:
21 | """#### Get the save image path.
22 |
23 | #### Args:
24 | - `filename_prefix` (str): The filename prefix.
25 | - `output_dir` (str): The output directory.
26 | - `image_width` (int, optional): The image width. Defaults to 0.
27 | - `image_height` (int, optional): The image height. Defaults to 0.
28 |
29 | #### Returns:
30 | - `tuple`: The full output folder, filename, counter, subfolder, and filename prefix.
31 | """
32 |
33 | def map_filename(filename: str) -> tuple:
34 | prefix_len = len(os.path.basename(filename_prefix))
35 | prefix = filename[: prefix_len + 1]
36 | try:
37 | digits = int(filename[prefix_len + 1 :].split("_")[0])
38 |         except Exception:
39 | digits = 0
40 | return (digits, prefix)
41 |
42 | def compute_vars(input: str, image_width: int, image_height: int) -> str:
43 | input = input.replace("%width%", str(image_width))
44 | input = input.replace("%height%", str(image_height))
45 | return input
46 |
47 | filename_prefix = compute_vars(filename_prefix, image_width, image_height)
48 |
49 | subfolder = os.path.dirname(os.path.normpath(filename_prefix))
50 | filename = os.path.basename(os.path.normpath(filename_prefix))
51 |
52 | full_output_folder = os.path.join(output_dir, subfolder)
53 | subfolder_paths = [
54 | os.path.join(full_output_folder, x)
55 | for x in ["Classic", "HiresFix", "Img2Img", "Flux", "Adetailer"]
56 | ]
57 | for path in subfolder_paths:
58 | os.makedirs(path, exist_ok=True)
59 | # Find highest counter across all subfolders
60 | counter = 1
61 | for path in subfolder_paths:
62 | if os.path.exists(path):
63 | files = os.listdir(path)
64 | if files:
65 | numbers = [
66 | map_filename(f)[0]
67 | for f in files
68 | if f.startswith(filename) and f.endswith(".png")
69 | ]
70 | if numbers:
71 | counter = max(max(numbers) + 1, counter)
72 |
73 | return full_output_folder, filename, counter, subfolder, filename_prefix
74 |
75 |
76 | MAX_RESOLUTION = 16384
77 |
78 |
79 | class SaveImage:
80 | """#### Class for saving images."""
81 |
82 | def __init__(self):
83 | """#### Initialize the SaveImage class."""
84 | self.output_dir = get_output_directory()
85 | self.type = "output"
86 | self.prefix_append = ""
87 | self.compress_level = 4
88 |
89 | def save_images(
90 | self,
91 | images: list,
92 | filename_prefix: str = "LD",
93 | prompt: str = None,
94 | extra_pnginfo: dict = None,
95 | ) -> dict:
96 | """#### Save images to the output directory.
97 |
98 | #### Args:
99 | - `images` (list): The list of images.
100 | - `filename_prefix` (str, optional): The filename prefix. Defaults to "LD".
101 | - `prompt` (str, optional): The prompt. Defaults to None.
102 | - `extra_pnginfo` (dict, optional): Additional PNG info. Defaults to None.
103 |
104 | #### Returns:
105 | - `dict`: The saved images information.
106 | """
107 | filename_prefix += self.prefix_append
108 | full_output_folder, filename, counter, subfolder, filename_prefix = (
109 | get_save_image_path(
110 | filename_prefix, self.output_dir, images[0].shape[-2], images[0].shape[-1]
111 | )
112 | )
113 | results = list()
114 | for batch_number, image in enumerate(images):
115 | # Ensure correct shape by squeezing extra dimensions
116 | i = 255.0 * image.cpu().numpy()
117 | i = np.squeeze(i) # Remove extra dimensions
118 |
119 | # Ensure we have a valid 3D array (height, width, channels)
120 | if i.ndim == 4:
121 | i = i.reshape(-1, i.shape[-2], i.shape[-1])
122 |
123 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
124 | metadata = None
125 |
126 | filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
127 | file = f"{filename_with_batch_num}_{counter:05}_.png"
128 |             if filename_prefix == "LD-HF":
129 |                 save_folder = os.path.join(full_output_folder, "HiresFix")
130 |             elif filename_prefix == "LD-I2I":
131 |                 save_folder = os.path.join(full_output_folder, "Img2Img")
132 |             elif filename_prefix == "LD-Flux":
133 |                 save_folder = os.path.join(full_output_folder, "Flux")
134 |             elif filename_prefix == "LD-head" or filename_prefix == "LD-body":
135 |                 save_folder = os.path.join(full_output_folder, "Adetailer")
136 |             else:
137 |                 save_folder = os.path.join(full_output_folder, "Classic")
138 |             img.save(
139 |                 os.path.join(save_folder, file),  # per-image folder: batches don't nest subfolders repeatedly
140 |                 pnginfo=metadata,
141 |                 compress_level=self.compress_level,
142 |             )
143 | results.append(
144 | {"filename": file, "subfolder": subfolder, "type": self.type}
145 | )
146 | counter += 1
147 |
148 | return {"ui": {"images": results}}
--------------------------------------------------------------------------------
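
A sketch of saving a generated batch; the routing into Classic/HiresFix/Img2Img/Flux/Adetailer is driven purely by the filename prefix, so a plain "LD" prefix lands in the Classic folder:

    import torch
    from modules.FileManaging import ImageSaver

    images = torch.rand((1, 512, 512, 3))   # hypothetical [B, H, W, C] result in [0, 1]

    out = ImageSaver.SaveImage().save_images(images, filename_prefix="LD")
    print(out)  # e.g. {"ui": {"images": [{"filename": "LD_00001_.png", "subfolder": "", "type": "output"}]}}
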
/modules/FileManaging/Loader.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from modules.Utilities import util
4 | from modules.AutoEncoders import VariationalAE
5 | from modules.Device import Device
6 | from modules.Model import ModelPatcher
7 | from modules.NeuralNetwork import unet
8 | from modules.clip import Clip
9 |
10 |
11 | def load_checkpoint_guess_config(
12 | ckpt_path: str,
13 | output_vae: bool = True,
14 | output_clip: bool = True,
15 | output_clipvision: bool = False,
16 | embedding_directory: str = None,
17 | output_model: bool = True,
18 | ) -> tuple:
19 | """#### Load a checkpoint and guess the configuration.
20 |
21 | #### Args:
22 | - `ckpt_path` (str): The path to the checkpoint file.
23 | - `output_vae` (bool, optional): Whether to output the VAE. Defaults to True.
24 | - `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True.
25 | - `output_clipvision` (bool, optional): Whether to output the CLIP vision. Defaults to False.
26 | - `embedding_directory` (str, optional): The embedding directory. Defaults to None.
27 | - `output_model` (bool, optional): Whether to output the model. Defaults to True.
28 |
29 | #### Returns:
30 | - `tuple`: The model patcher, CLIP, VAE, and CLIP vision.
31 | """
32 | sd = util.load_torch_file(ckpt_path)
33 |     # (the raw state dict keys are consumed by the model/CLIP/VAE loaders below)
34 | clip = None
35 | clipvision = None
36 | vae = None
37 | model = None
38 | model_patcher = None
39 | clip_target = None
40 |
41 | parameters = util.calculate_parameters(sd, "model.diffusion_model.")
42 | load_device = Device.get_torch_device()
43 |
44 | model_config = unet.model_config_from_unet(sd, "model.diffusion_model.")
45 | unet_dtype = unet.unet_dtype1(
46 | model_params=parameters,
47 | supported_dtypes=model_config.supported_inference_dtypes,
48 | )
49 | manual_cast_dtype = Device.unet_manual_cast(
50 | unet_dtype, load_device, model_config.supported_inference_dtypes
51 | )
52 | model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
53 |
54 | if output_model:
55 | inital_load_device = Device.unet_inital_load_device(parameters, unet_dtype)
56 | Device.unet_offload_device()
57 | model = model_config.get_model(
58 | sd, "model.diffusion_model.", device=inital_load_device
59 | )
60 | model.load_model_weights(sd, "model.diffusion_model.")
61 |
62 | if output_vae:
63 | vae_sd = util.state_dict_prefix_replace(
64 | sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True
65 | )
66 | vae_sd = model_config.process_vae_state_dict(vae_sd)
67 | vae = VariationalAE.VAE(sd=vae_sd)
68 |
69 | if output_clip:
70 | clip_target = model_config.clip_target()
71 | if clip_target is not None:
72 | clip_sd = model_config.process_clip_state_dict(sd)
73 | if len(clip_sd) > 0:
74 | clip = Clip.CLIP(clip_target, embedding_directory=embedding_directory)
75 | m, u = clip.load_sd(clip_sd, full_model=True)
76 | if len(m) > 0:
77 | m_filter = list(
78 | filter(
79 | lambda a: ".logit_scale" not in a
80 | and ".transformer.text_projection.weight" not in a,
81 | m,
82 | )
83 | )
84 | if len(m_filter) > 0:
85 | logging.warning("clip missing: {}".format(m))
86 | else:
87 | logging.debug("clip missing: {}".format(m))
88 |
89 | if len(u) > 0:
90 | logging.debug("clip unexpected {}:".format(u))
91 | else:
92 | logging.warning(
93 | "no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded."
94 | )
95 |
96 | left_over = sd.keys()
97 | if len(left_over) > 0:
98 | logging.debug("left over keys: {}".format(left_over))
99 |
100 | if output_model:
101 | model_patcher = ModelPatcher.ModelPatcher(
102 | model,
103 | load_device=load_device,
104 | offload_device=Device.unet_offload_device(),
105 | current_device=inital_load_device,
106 | )
107 | if inital_load_device != torch.device("cpu"):
108 | logging.info("loaded straight to GPU")
109 | Device.load_model_gpu(model_patcher)
110 |
111 | return (model_patcher, clip, vae, clipvision)
112 |
113 |
114 | class CheckpointLoaderSimple:
115 | """#### Class for loading checkpoints."""
116 |
117 | def load_checkpoint(
118 | self, ckpt_name: str, output_vae: bool = True, output_clip: bool = True
119 | ) -> tuple:
120 | """#### Load a checkpoint.
121 |
122 | #### Args:
123 | - `ckpt_name` (str): The name of the checkpoint.
124 | - `output_vae` (bool, optional): Whether to output the VAE. Defaults to True.
125 | - `output_clip` (bool, optional): Whether to output the CLIP. Defaults to True.
126 |
127 | #### Returns:
128 | - `tuple`: The model patcher, CLIP, and VAE.
129 | """
130 | ckpt_path = f"{ckpt_name}"
131 | out = load_checkpoint_guess_config(
132 | ckpt_path,
133 | output_vae=output_vae,
134 | output_clip=output_clip,
135 | embedding_directory="./_internal/embeddings/",
136 | )
137 | print("loading", ckpt_path)
138 | return out[:3]
139 |
--------------------------------------------------------------------------------
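
A usage sketch, assuming the MeinaMix checkpoint fetched by Downloader.CheckAndDownload() is present on disk:

    from modules.FileManaging import Loader

    ckpt = "./_internal/checkpoints/Meina V10 - baked VAE.safetensors"
    model_patcher, clip, vae = Loader.CheckpointLoaderSimple().load_checkpoint(ckpt)

    # model_patcher wraps the UNet for sampling, clip encodes prompts, vae decodes latents.
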
/modules/Model/LoRas.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from modules.Utilities import util
3 | from modules.NeuralNetwork import unet
4 |
5 | LORA_CLIP_MAP = {
6 | "mlp.fc1": "mlp_fc1",
7 | "mlp.fc2": "mlp_fc2",
8 | "self_attn.k_proj": "self_attn_k_proj",
9 | "self_attn.q_proj": "self_attn_q_proj",
10 | "self_attn.v_proj": "self_attn_v_proj",
11 | "self_attn.out_proj": "self_attn_out_proj",
12 | }
13 |
14 |
15 | def load_lora(lora: dict, to_load: dict) -> dict:
16 | """#### Load a LoRA model.
17 |
18 | #### Args:
19 | - `lora` (dict): The LoRA model state dictionary.
20 | - `to_load` (dict): The keys to load from the LoRA model.
21 |
22 | #### Returns:
23 | - `dict`: The loaded LoRA model.
24 | """
25 | patch_dict = {}
26 | loaded_keys = set()
27 | for x in to_load:
28 | alpha_name = "{}.alpha".format(x)
29 | alpha = None
30 | if alpha_name in lora.keys():
31 | alpha = lora[alpha_name].item()
32 | loaded_keys.add(alpha_name)
33 |
34 |         # dora_scale keys are not handled by this trimmed loader
35 | dora_scale = None
36 |
37 | regular_lora = "{}.lora_up.weight".format(x)
38 |         # diffusers-style "_lora.up.weight" / "lora_linear_layer.up.weight"
39 |         # key variants are not handled by this trimmed loader
40 | A_name = None
41 |
42 | if regular_lora in lora.keys():
43 | A_name = regular_lora
44 | B_name = "{}.lora_down.weight".format(x)
45 |             # "lora_mid.weight" (LoCon mid weights) are not handled here
46 |
47 | if A_name is not None:
48 | mid = None
49 | patch_dict[to_load[x]] = (
50 | "lora",
51 | (lora[A_name], lora[B_name], alpha, mid, dora_scale),
52 | )
53 | loaded_keys.add(A_name)
54 | loaded_keys.add(B_name)
55 | return patch_dict
56 |
57 |
58 | def model_lora_keys_clip(model: torch.nn.Module, key_map: dict = {}) -> dict:
59 | """#### Get the keys for a LoRA model's CLIP component.
60 |
61 | #### Args:
62 | - `model` (torch.nn.Module): The LoRA model.
63 | - `key_map` (dict, optional): The key map. Defaults to {}.
64 |
65 | #### Returns:
66 | - `dict`: The keys for the CLIP component.
67 | """
68 | sdk = model.state_dict().keys()
69 |
70 | text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
71 | for b in range(32):
72 | for c in LORA_CLIP_MAP:
73 | k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
74 | if k in sdk:
75 | lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
76 | key_map[lora_key] = k
77 | lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(
78 | b, LORA_CLIP_MAP[c]
79 | ) # SDXL base
80 | key_map[lora_key] = k
81 | lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(
82 | b, c
83 | ) # diffusers lora
84 | key_map[lora_key] = k
85 | return key_map
86 |
87 |
88 | def model_lora_keys_unet(model: torch.nn.Module, key_map: dict = {}) -> dict:
89 | """#### Get the keys for a LoRA model's UNet component.
90 |
91 | #### Args:
92 | - `model` (torch.nn.Module): The LoRA model.
93 | - `key_map` (dict, optional): The key map. Defaults to {}.
94 |
95 | #### Returns:
96 | - `dict`: The keys for the UNet component.
97 | """
98 | sdk = model.state_dict().keys()
99 |
100 | for k in sdk:
101 | if k.startswith("diffusion_model.") and k.endswith(".weight"):
102 | key_lora = k[len("diffusion_model.") : -len(".weight")].replace(".", "_")
103 | key_map["lora_unet_{}".format(key_lora)] = k
104 | key_map["lora_prior_unet_{}".format(key_lora)] = k # cascade lora:
105 |
106 | diffusers_keys = unet.unet_to_diffusers(model.model_config.unet_config)
107 | for k in diffusers_keys:
108 | if k.endswith(".weight"):
109 | unet_key = "diffusion_model.{}".format(diffusers_keys[k])
110 | key_lora = k[: -len(".weight")].replace(".", "_")
111 | key_map["lora_unet_{}".format(key_lora)] = unet_key
112 |
113 | diffusers_lora_prefix = ["", "unet."]
114 | for p in diffusers_lora_prefix:
115 | diffusers_lora_key = "{}{}".format(
116 | p, k[: -len(".weight")].replace(".to_", ".processor.to_")
117 | )
118 | if diffusers_lora_key.endswith(".to_out.0"):
119 | diffusers_lora_key = diffusers_lora_key[:-2]
120 | key_map[diffusers_lora_key] = unet_key
121 | return key_map
122 |
123 |
124 | def load_lora_for_models(
125 | model: object, clip: object, lora: dict, strength_model: float, strength_clip: float
126 | ) -> tuple:
127 | """#### Load a LoRA model for the given models.
128 |
129 | #### Args:
130 | - `model` (object): The model.
131 | - `clip` (object): The CLIP model.
132 | - `lora` (dict): The LoRA model state dictionary.
133 | - `strength_model` (float): The strength of the model.
134 | - `strength_clip` (float): The strength of the CLIP model.
135 |
136 | #### Returns:
137 | - `tuple`: The new model patcher and CLIP model.
138 | """
139 | key_map = {}
140 | if model is not None:
141 | key_map = model_lora_keys_unet(model.model, key_map)
142 | if clip is not None:
143 | key_map = model_lora_keys_clip(clip.cond_stage_model, key_map)
144 |
145 | loaded = load_lora(lora, key_map)
146 | new_modelpatcher = model.clone()
147 | k = new_modelpatcher.add_patches(loaded, strength_model)
148 |
149 | new_clip = clip.clone()
150 | k1 = new_clip.add_patches(loaded, strength_clip)
151 | k = set(k)
152 | k1 = set(k1)
153 |
154 | return (new_modelpatcher, new_clip)
155 |
156 |
157 | class LoraLoader:
158 | """#### Class for loading LoRA models."""
159 |
160 | def __init__(self):
161 | """#### Initialize the LoraLoader class."""
162 | self.loaded_lora = None
163 |
164 | def load_lora(
165 | self,
166 | model: object,
167 | clip: object,
168 | lora_name: str,
169 | strength_model: float,
170 | strength_clip: float,
171 | ) -> tuple:
172 | """#### Load a LoRA model.
173 |
174 | #### Args:
175 | - `model` (object): The model.
176 | - `clip` (object): The CLIP model.
177 | - `lora_name` (str): The name of the LoRA model.
178 | - `strength_model` (float): The strength of the model.
179 | - `strength_clip` (float): The strength of the CLIP model.
180 |
181 | #### Returns:
182 | - `tuple`: The new model patcher and CLIP model.
183 | """
184 | lora_path = util.get_full_path("loras", lora_name)
185 |         # no cache lookup here: the LoRA is always (re)loaded from disk
186 |         lora = util.load_torch_file(lora_path, safe_load=True)
187 |         self.loaded_lora = (lora_path, lora)
188 |
189 |
190 | model_lora, clip_lora = load_lora_for_models(
191 | model, clip, lora, strength_model, strength_clip
192 | )
193 | return (model_lora, clip_lora)
194 |
--------------------------------------------------------------------------------
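
A sketch of stacking the bundled add_detail LoRA onto a loaded checkpoint; `model_patcher` and `clip` are the objects returned by CheckpointLoaderSimple, and the strengths are illustrative:

    from modules.Model import LoRas

    model_lora, clip_lora = LoRas.LoraLoader().load_lora(
        model_patcher,
        clip,
        lora_name="add_detail.safetensors",   # fetched by Downloader.CheckAndDownload()
        strength_model=0.7,
        strength_clip=0.7,
    )
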
/modules/SD15/SD15.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from modules.BlackForest import Flux
3 | from modules.Utilities import util
4 | from modules.Model import ModelBase
5 | from modules.SD15 import SDClip, SDToken
6 | from modules.Utilities import Latent
7 | from modules.clip import Clip
8 |
9 |
10 | class sm_SD15(ModelBase.BASE):
11 | """#### Class representing the SD15 model.
12 |
13 | #### Args:
14 | - `ModelBase.BASE` (ModelBase.BASE): The base model class.
15 | """
16 |
17 | unet_config: dict = {
18 | "context_dim": 768,
19 | "model_channels": 320,
20 | "use_linear_in_transformer": False,
21 | "adm_in_channels": None,
22 | "use_temporal_attention": False,
23 | }
24 |
25 | unet_extra_config: dict = {
26 | "num_heads": 8,
27 | "num_head_channels": -1,
28 | }
29 |
30 | latent_format: Latent.SD15 = Latent.SD15
31 |
32 | def process_clip_state_dict(self, state_dict: dict) -> dict:
33 | """#### Process the state dictionary for the CLIP model.
34 |
35 | #### Args:
36 | - `state_dict` (dict): The state dictionary.
37 |
38 | #### Returns:
39 | - `dict`: The processed state dictionary.
40 | """
41 | k = list(state_dict.keys())
42 | for x in k:
43 | if x.startswith("cond_stage_model.transformer.") and not x.startswith(
44 | "cond_stage_model.transformer.text_model."
45 | ):
46 | y = x.replace(
47 | "cond_stage_model.transformer.",
48 | "cond_stage_model.transformer.text_model.",
49 | )
50 | state_dict[y] = state_dict.pop(x)
51 |
52 | if (
53 | "cond_stage_model.transformer.text_model.embeddings.position_ids"
54 | in state_dict
55 | ):
56 | ids = state_dict[
57 | "cond_stage_model.transformer.text_model.embeddings.position_ids"
58 | ]
59 | if ids.dtype == torch.float32:
60 | state_dict[
61 | "cond_stage_model.transformer.text_model.embeddings.position_ids"
62 | ] = ids.round()
63 |
64 | replace_prefix = {}
65 | replace_prefix["cond_stage_model."] = "clip_l."
66 | state_dict = util.state_dict_prefix_replace(
67 | state_dict, replace_prefix, filter_keys=True
68 | )
69 | return state_dict
70 |
71 | def clip_target(self) -> Clip.ClipTarget:
72 | """#### Get the target CLIP model.
73 |
74 | #### Returns:
75 | - `Clip.ClipTarget`: The target CLIP model.
76 | """
77 | return Clip.ClipTarget(SDToken.SD1Tokenizer, SDClip.SD1ClipModel)
78 |
79 | models = [
80 | sm_SD15, Flux.Flux
81 | ]
--------------------------------------------------------------------------------
/modules/StableFast/StableFast.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import functools
3 | import logging
4 | from dataclasses import dataclass
5 |
6 | import torch
7 |
8 | try:
9 | from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig
10 | from sfast.compilers.diffusion_pipeline_compiler import (
11 | _enable_xformers,
12 | _modify_model,
13 | )
14 | from sfast.cuda.graphs import make_dynamic_graphed_callable
15 | from sfast.jit import utils as jit_utils
16 | from sfast.jit.trace_helper import trace_with_kwargs
17 | except Exception:
18 |     pass  # stable-fast is optional; keep the module importable without it
19 |
20 |
21 | def hash_arg(arg):
22 | # micro optimization: bool obj is an instance of int
23 | if isinstance(arg, (str, int, float, bytes)):
24 | return arg
25 | if isinstance(arg, (tuple, list)):
26 | return tuple(map(hash_arg, arg))
27 | if isinstance(arg, dict):
28 | return tuple(
29 | sorted(
30 | ((hash_arg(k), hash_arg(v)) for k, v in arg.items()), key=lambda x: x[0]
31 | )
32 | )
33 | return type(arg)
34 |
35 |
36 | class ModuleFactory:
37 | def get_converted_kwargs(self):
38 | return self.converted_kwargs
39 |
40 |
41 | import torch as th
42 | import torch.nn as nn
43 | import copy
44 |
45 |
46 | class BaseModelApplyModelModule(torch.nn.Module):
47 | def __init__(self, func, module):
48 | super().__init__()
49 | self.func = func
50 | self.module = module
51 |
52 | def forward(
53 | self,
54 | input_x,
55 | timestep,
56 | c_concat=None,
57 | c_crossattn=None,
58 | y=None,
59 | control=None,
60 | transformer_options={},
61 | ):
62 | kwargs = {"y": y}
63 |
64 | new_transformer_options = {}
65 |
66 | return self.func(
67 | input_x,
68 | timestep,
69 | c_concat=c_concat,
70 | c_crossattn=c_crossattn,
71 | control=control,
72 | transformer_options=new_transformer_options,
73 | **kwargs,
74 | )
75 |
76 |
77 | class BaseModelApplyModelModuleFactory(ModuleFactory):
78 | kwargs_name = (
79 | "input_x",
80 | "timestep",
81 | "c_concat",
82 | "c_crossattn",
83 | "y",
84 | "control",
85 | )
86 |
87 | def __init__(self, callable, kwargs) -> None:
88 | self.callable = callable
89 | self.unet_config = callable.__self__.model_config.unet_config
90 | self.kwargs = kwargs
91 | self.patch_module = {}
92 | self.patch_module_parameter = {}
93 | self.converted_kwargs = self.gen_converted_kwargs()
94 |
95 | def gen_converted_kwargs(self):
96 | converted_kwargs = {}
97 | for arg_name, arg in self.kwargs.items():
98 | if arg_name in self.kwargs_name:
99 | converted_kwargs[arg_name] = arg
100 |
101 | transformer_options = self.kwargs.get("transformer_options", {})
102 | patches = transformer_options.get("patches", {})
103 |
104 | patch_module = {}
105 | patch_module_parameter = {}
106 |
107 | new_transformer_options = {}
108 | new_transformer_options["patches"] = patch_module_parameter
109 |
110 | self.patch_module = patch_module
111 | self.patch_module_parameter = patch_module_parameter
112 | return converted_kwargs
113 |
114 | def gen_cache_key(self):
115 | key_kwargs = {}
116 | for k, v in self.converted_kwargs.items():
117 | key_kwargs[k] = v
118 |
119 | patch_module_cache_key = {}
120 | return (
121 | self.callable.__class__.__qualname__,
122 | hash_arg(self.unet_config),
123 | hash_arg(key_kwargs),
124 | hash_arg(patch_module_cache_key),
125 | )
126 |
127 | @contextlib.contextmanager
128 | def converted_module_context(self):
129 | module = BaseModelApplyModelModule(self.callable, self.callable.__self__)
130 | yield (module, self.converted_kwargs)
131 |
132 |
133 | logger = logging.getLogger()
134 |
135 |
136 | @dataclass
137 | class TracedModuleCacheItem:
138 | module: object
139 | patch_id: int
140 | device: str
141 |
142 |
143 | class LazyTraceModule:
144 | traced_modules = {}
145 |
146 | def __init__(self, config=None, patch_id=None, **kwargs_) -> None:
147 | self.config = config
148 | self.patch_id = patch_id
149 | self.kwargs_ = kwargs_
150 | self.modify_model = functools.partial(
151 | _modify_model,
152 | enable_cnn_optimization=config.enable_cnn_optimization,
153 | prefer_lowp_gemm=config.prefer_lowp_gemm,
154 | enable_triton=config.enable_triton,
155 | enable_triton_reshape=config.enable_triton,
156 | memory_format=config.memory_format,
157 | )
158 | self.cuda_graph_modules = {}
159 |
160 | def ts_compiler(
161 | self,
162 | m,
163 | ):
164 | with torch.jit.optimized_execution(True):
165 | if self.config.enable_jit_freeze:
166 | # raw freeze causes Tensor reference leak
167 | # because the constant Tensors in the GraphFunction of
168 | # the compilation unit are never freed.
169 | m.eval()
170 | m = jit_utils.better_freeze(m)
171 | self.modify_model(m)
172 |
173 | if self.config.enable_cuda_graph:
174 | m = make_dynamic_graphed_callable(m)
175 | return m
176 |
177 | def __call__(self, model_function, /, **kwargs):
178 | module_factory = BaseModelApplyModelModuleFactory(model_function, kwargs)
179 | kwargs = module_factory.get_converted_kwargs()
180 | key = module_factory.gen_cache_key()
181 |
182 | traced_module = self.cuda_graph_modules.get(key)
183 | if traced_module is None:
184 | with module_factory.converted_module_context() as (m_model, m_kwargs):
185 | logger.info(
186 | f'Tracing {getattr(m_model, "__name__", m_model.__class__.__name__)}'
187 | )
188 | traced_m, call_helper = trace_with_kwargs(
189 | m_model, None, m_kwargs, **self.kwargs_
190 | )
191 |
192 | traced_m = self.ts_compiler(traced_m)
193 | traced_module = call_helper(traced_m)
194 | self.cuda_graph_modules[key] = traced_module
195 |
196 | return traced_module(**kwargs)
197 |
198 |
199 | def build_lazy_trace_module(config, device, patch_id):
200 | config.enable_cuda_graph = config.enable_cuda_graph and device.type == "cuda"
201 |
202 | if config.enable_xformers:
203 | _enable_xformers(None)
204 |
205 | return LazyTraceModule(
206 | config=config,
207 | patch_id=patch_id,
208 | check_trace=True,
209 | strict=True,
210 | )
211 |
212 |
213 | def gen_stable_fast_config():
214 | config = CompilationConfig.Default()
215 | try:
216 | import xformers
217 |
218 | config.enable_xformers = True
219 | except ImportError:
220 | print("xformers not installed, skip")
221 |
222 | # CUDA Graph is suggested for small batch sizes.
223 | # After capturing, the model only accepts one fixed image size.
224 | # If you want the model to be dynamic, don't enable it.
225 | config.enable_cuda_graph = False
226 | # config.enable_jit_freeze = False
227 | return config
228 |
229 |
230 | class StableFastPatch:
231 | def __init__(self, model, config):
232 | self.model = model
233 | self.config = config
234 | self.stable_fast_model = None
235 |
236 | def __call__(self, model_function, params):
237 | input_x = params.get("input")
238 | timestep_ = params.get("timestep")
239 | c = params.get("c")
240 |
241 | if self.stable_fast_model is None:
242 | self.stable_fast_model = build_lazy_trace_module(
243 | self.config,
244 | input_x.device,
245 | id(self),
246 | )
247 |
248 | return self.stable_fast_model(
249 | model_function, input_x=input_x, timestep=timestep_, **c
250 | )
251 |
252 | def to(self, device):
253 | if type(device) == torch.device:
254 | if self.config.enable_cuda_graph or self.config.enable_jit_freeze:
255 | if device.type == "cpu":
256 | del self.stable_fast_model
257 | self.stable_fast_model = None
258 | print(
259 | "\33[93mWarning: Your graphics card doesn't have enough video memory to keep the model. If you experience a noticeable delay every time you start sampling, please consider disable enable_cuda_graph.\33[0m"
260 | )
261 | return self
262 |
263 |
264 | class ApplyStableFastUnet:
265 | def apply_stable_fast(self, model, enable_cuda_graph):
266 | config = gen_stable_fast_config()
267 |
268 | if config.memory_format is not None:
269 | model.model.to(memory_format=config.memory_format)
270 |
271 | patch = StableFastPatch(model, config)
272 | model_stable_fast = model.clone()
273 | model_stable_fast.set_model_unet_function_wrapper(patch)
274 | return (model_stable_fast,)
--------------------------------------------------------------------------------
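
A minimal usage sketch for the stable-fast wrapper above. It assumes the bundled stable_fast wheel is installed and that `model` is a ModelPatcher-style object exposing `clone()`, `.model`, and `set_model_unet_function_wrapper()`; `load_checkpoint` and the checkpoint path are placeholders, not repository APIs.

    # Hypothetical caller: wrap a loaded UNet patcher with stable-fast tracing.
    from modules.StableFast.StableFast import ApplyStableFastUnet

    model = load_checkpoint("./_internal/checkpoints/model.safetensors")  # placeholder loader
    (fast_model,) = ApplyStableFastUnet().apply_stable_fast(model, enable_cuda_graph=False)
    # The first sampling call triggers JIT tracing; later calls with the same
    # shapes and conditioning reuse the cached traced module.
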
/modules/UltimateSDUpscale/USDU_upscaler.py:
--------------------------------------------------------------------------------
1 | import logging as logger
2 | import torch
3 | from PIL import Image
4 |
5 | from modules.Device import Device
6 | from modules.UltimateSDUpscale import RDRB
7 | from modules.UltimateSDUpscale import image_util
8 | from modules.Utilities import util
9 |
10 |
11 | def load_state_dict(state_dict: dict) -> RDRB.PyTorchModel:
12 | """#### Load a state dictionary into a PyTorch model.
13 |
14 | #### Args:
15 | - `state_dict` (dict): The state dictionary.
16 |
17 | #### Returns:
18 | - `RDRB.PyTorchModel`: The loaded PyTorch model.
19 | """
20 | logger.debug("Loading state dict into pytorch model arch")
21 | state_dict_keys = list(state_dict.keys())
22 | if "params_ema" in state_dict_keys:
23 | state_dict = state_dict["params_ema"]
24 | model = RDRB.RRDBNet(state_dict)
25 | return model
26 |
27 |
28 | class UpscaleModelLoader:
29 | """#### Class for loading upscale models."""
30 |
31 | def load_model(self, model_name: str) -> tuple:
32 | """#### Load an upscale model.
33 |
34 | #### Args:
35 | - `model_name` (str): The name of the model.
36 |
37 | #### Returns:
38 | - `tuple`: The loaded model.
39 | """
40 | model_path = f"./_internal/ESRGAN/{model_name}"
41 | sd = util.load_torch_file(model_path, safe_load=True)
42 | if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
43 | sd = util.state_dict_prefix_replace(sd, {"module.": ""})
44 | out = load_state_dict(sd).eval()
45 | return (out,)
46 |
47 |
48 | class ImageUpscaleWithModel:
49 | """#### Class for upscaling images with a model."""
50 |
51 | def upscale(self, upscale_model: torch.nn.Module, image: torch.Tensor) -> tuple:
52 | """#### Upscale an image using a model.
53 |
54 | #### Args:
55 | - `upscale_model` (torch.nn.Module): The upscale model.
56 | - `image` (torch.Tensor): The input image tensor.
57 |
58 | #### Returns:
59 | - `tuple`: The upscaled image tensor.
60 | """
61 | if torch.cuda.is_available():
62 | device = torch.device(torch.cuda.current_device())
63 | else:
64 | device = torch.device("cpu")
65 | upscale_model.to(device)
66 | in_img = image.movedim(-1, -3).to(device)
67 | Device.get_free_memory(device)
68 |
69 | tile = 512
70 | overlap = 32
71 |
72 | oom = True
73 | while oom:
74 | steps = in_img.shape[0] * image_util.get_tiled_scale_steps(
75 | in_img.shape[3],
76 | in_img.shape[2],
77 | tile_x=tile,
78 | tile_y=tile,
79 | overlap=overlap,
80 | )
81 | pbar = util.ProgressBar(steps)
82 | s = image_util.tiled_scale(
83 | in_img,
84 | lambda a: upscale_model(a),
85 | tile_x=tile,
86 | tile_y=tile,
87 | overlap=overlap,
88 | upscale_amount=upscale_model.scale,
89 | pbar=pbar,
90 | )
91 | oom = False
92 |
93 | upscale_model.cpu()
94 | s = torch.clamp(s.movedim(-3, -1), min=0, max=1.0)
95 | return (s,)
96 |
97 |
98 | def torch_gc() -> None:
99 | """#### Perform garbage collection for PyTorch."""
100 | pass
101 |
102 |
103 | class Script:
104 | """#### Class representing a script."""
105 | pass
106 |
107 |
108 | class Options:
109 | """#### Class representing options."""
110 |
111 | img2img_background_color: str = "#ffffff" # Set to white for now
112 |
113 |
114 | class State:
115 | """#### Class representing the state."""
116 |
117 | interrupted: bool = False
118 |
119 | def begin(self) -> None:
120 | """#### Begin the state."""
121 | pass
122 |
123 | def end(self) -> None:
124 | """#### End the state."""
125 | pass
126 |
127 |
128 | opts = Options()
129 | state = State()
130 |
131 | # Will only ever hold 1 upscaler
132 | sd_upscalers = [None]
133 | actual_upscaler = None
134 |
135 | # Batch of images to upscale
136 | batch = None
137 |
138 |
139 | if not hasattr(Image, "Resampling"): # For older versions of Pillow
140 | Image.Resampling = Image
141 |
142 |
143 | class Upscaler:
144 | """#### Class for upscaling images."""
145 |
146 | def _upscale(self, img: Image.Image, scale: float) -> Image.Image:
147 | """#### Upscale an image.
148 |
149 | #### Args:
150 | - `img` (Image.Image): The input image.
151 | - `scale` (float): The scale factor.
152 |
153 | #### Returns:
154 | - `Image.Image`: The upscaled image.
155 | """
156 | global actual_upscaler
157 | tensor = image_util.pil_to_tensor(img)
158 | image_upscale_node = ImageUpscaleWithModel()
159 | (upscaled,) = image_upscale_node.upscale(actual_upscaler, tensor)
160 | return image_util.tensor_to_pil(upscaled)
161 |
162 | def upscale(self, img: Image.Image, scale: float, selected_model: str = None) -> Image.Image:
163 | """#### Upscale an image with a selected model.
164 |
165 | #### Args:
166 | - `img` (Image.Image): The input image.
167 | - `scale` (float): The scale factor.
168 | - `selected_model` (str, optional): The selected model. Defaults to None.
169 |
170 | #### Returns:
171 | - `Image.Image`: The upscaled image.
172 | """
173 | global batch
174 | batch = [self._upscale(img, scale) for img in batch]
175 | return batch[0]
176 |
177 |
178 | class UpscalerData:
179 | """#### Class for storing upscaler data."""
180 |
181 | name: str = ""
182 | data_path: str = ""
183 |
184 | def __init__(self):
185 | self.scaler = Upscaler()
--------------------------------------------------------------------------------
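
A minimal usage sketch for the loader and upscaler above, assuming an ESRGAN/RRDB-style checkpoint has been placed in ./_internal/ESRGAN/ (the file name below is a placeholder) and the repository's dependencies are installed.

    from PIL import Image
    from modules.UltimateSDUpscale import USDU_upscaler, image_util

    # Load the upscale model by file name (resolved under ./_internal/ESRGAN/).
    (esrgan,) = USDU_upscaler.UpscaleModelLoader().load_model("RealESRGAN_x4plus.pth")

    # Convert a PIL image to the (1, H, W, 3) float tensor layout the node expects.
    img = Image.open("input.png").convert("RGB")
    tensor = image_util.pil_to_tensor(img)

    # Tiled upscale, then back to PIL for saving.
    (upscaled,) = USDU_upscaler.ImageUpscaleWithModel().upscale(esrgan, tensor)
    image_util.tensor_to_pil(upscaled).save("output.png")
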
/modules/UltimateSDUpscale/USDU_util.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 | import torch
3 | import torch.nn as nn
4 |
5 | ConvMode = Literal["CNA", "NAC", "CNAC"]
6 |
7 | def act(act_type: str, inplace: bool = True, neg_slope: float = 0.2, n_prelu: int = 1) -> nn.Module:
8 | """#### Get the activation layer.
9 |
10 | #### Args:
11 | - `act_type` (str): The type of activation.
12 | - `inplace` (bool, optional): Whether to perform the operation in-place. Defaults to True.
13 | - `neg_slope` (float, optional): The negative slope for LeakyReLU. Defaults to 0.2.
14 | - `n_prelu` (int, optional): The number of PReLU parameters. Defaults to 1.
15 |
16 | #### Returns:
17 | - `nn.Module`: The activation layer.
18 | """
19 | act_type = act_type.lower()
20 | layer = nn.LeakyReLU(neg_slope, inplace)
21 | return layer
22 |
23 | def get_valid_padding(kernel_size: int, dilation: int) -> int:
24 | """#### Get the valid padding for a convolutional layer.
25 |
26 | #### Args:
27 | - `kernel_size` (int): The size of the kernel.
28 | - `dilation` (int): The dilation rate.
29 |
30 | #### Returns:
31 | - `int`: The valid padding.
32 | """
33 | kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
34 | padding = (kernel_size - 1) // 2
35 | return padding
36 |
37 | def sequential(*args: nn.Module) -> nn.Sequential:
38 | """#### Create a sequential container.
39 |
40 | #### Args:
41 | - `*args` (nn.Module): The modules to include in the sequential container.
42 |
43 | #### Returns:
44 | - `nn.Sequential`: The sequential container.
45 | """
46 | modules = []
47 | for module in args:
48 | if isinstance(module, nn.Sequential):
49 | for submodule in module.children():
50 | modules.append(submodule)
51 | elif isinstance(module, nn.Module):
52 | modules.append(module)
53 | return nn.Sequential(*modules)
54 |
55 | def conv_block(
56 | in_nc: int,
57 | out_nc: int,
58 | kernel_size: int,
59 | stride: int = 1,
60 | dilation: int = 1,
61 | groups: int = 1,
62 | bias: bool = True,
63 | pad_type: str = "zero",
64 | norm_type: str | None = None,
65 | act_type: str | None = "relu",
66 | mode: ConvMode = "CNA",
67 | c2x2: bool = False,
68 | ) -> nn.Sequential:
69 | """#### Create a convolutional block.
70 |
71 | #### Args:
72 | - `in_nc` (int): The number of input channels.
73 | - `out_nc` (int): The number of output channels.
74 | - `kernel_size` (int): The size of the kernel.
75 | - `stride` (int, optional): The stride of the convolution. Defaults to 1.
76 | - `dilation` (int, optional): The dilation rate. Defaults to 1.
77 | - `groups` (int, optional): The number of groups. Defaults to 1.
78 | - `bias` (bool, optional): Whether to include a bias term. Defaults to True.
79 | - `pad_type` (str, optional): The type of padding. Defaults to "zero".
80 | - `norm_type` (str | None, optional): The type of normalization. Defaults to None.
81 | - `act_type` (str | None, optional): The type of activation. Defaults to "relu".
82 | - `mode` (ConvMode, optional): The mode of the convolution. Defaults to "CNA".
83 | - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False.
84 |
85 | #### Returns:
86 | - `nn.Sequential`: The convolutional block.
87 | """
88 | assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
89 | padding = get_valid_padding(kernel_size, dilation)
90 | padding = padding if pad_type == "zero" else 0
91 |
92 | c = nn.Conv2d(
93 | in_nc,
94 | out_nc,
95 | kernel_size=kernel_size,
96 | stride=stride,
97 | padding=padding,
98 | dilation=dilation,
99 | bias=bias,
100 | groups=groups,
101 | )
102 | a = act(act_type) if act_type else None
103 | if mode in ("CNA", "CNAC"):
104 | return sequential(None, c, None, a)
105 |
106 | def upconv_block(
107 | in_nc: int,
108 | out_nc: int,
109 | upscale_factor: int = 2,
110 | kernel_size: int = 3,
111 | stride: int = 1,
112 | bias: bool = True,
113 | pad_type: str = "zero",
114 | norm_type: str | None = None,
115 | act_type: str = "relu",
116 | mode: str = "nearest",
117 | c2x2: bool = False,
118 | ) -> nn.Sequential:
119 | """#### Create an upsampling convolutional block.
120 |
121 | #### Args:
122 | - `in_nc` (int): The number of input channels.
123 | - `out_nc` (int): The number of output channels.
124 | - `upscale_factor` (int, optional): The upscale factor. Defaults to 2.
125 | - `kernel_size` (int, optional): The size of the kernel. Defaults to 3.
126 | - `stride` (int, optional): The stride of the convolution. Defaults to 1.
127 | - `bias` (bool, optional): Whether to include a bias term. Defaults to True.
128 | - `pad_type` (str, optional): The type of padding. Defaults to "zero".
129 | - `norm_type` (str | None, optional): The type of normalization. Defaults to None.
130 | - `act_type` (str, optional): The type of activation. Defaults to "relu".
131 | - `mode` (str, optional): The mode of upsampling. Defaults to "nearest".
132 | - `c2x2` (bool, optional): Whether to use 2x2 convolutions. Defaults to False.
133 |
134 | #### Returns:
135 | - `nn.Sequential`: The upsampling convolutional block.
136 | """
137 | upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
138 | conv = conv_block(
139 | in_nc,
140 | out_nc,
141 | kernel_size,
142 | stride,
143 | bias=bias,
144 | pad_type=pad_type,
145 | norm_type=norm_type,
146 | act_type=act_type,
147 | c2x2=c2x2,
148 | )
149 | return sequential(upsample, conv)
150 |
151 | class ShortcutBlock(nn.Module):
152 | """#### Elementwise sum the output of a submodule to its input."""
153 |
154 | def __init__(self, submodule: nn.Module):
155 | """#### Initialize the ShortcutBlock.
156 |
157 | #### Args:
158 | - `submodule` (nn.Module): The submodule to apply.
159 | """
160 | super(ShortcutBlock, self).__init__()
161 | self.sub = submodule
162 |
163 | def forward(self, x: torch.Tensor) -> torch.Tensor:
164 | """#### Forward pass.
165 |
166 | #### Args:
167 | - `x` (torch.Tensor): The input tensor.
168 |
169 | #### Returns:
170 | - `torch.Tensor`: The output tensor.
171 | """
172 | output = x + self.sub(x)
173 | return output
--------------------------------------------------------------------------------
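
A quick sketch of the building blocks above: a "CNA" convolution block followed by a nearest-neighbour upsampling block, applied to a random tensor (shapes only; no pretrained weights involved).

    import torch
    from modules.UltimateSDUpscale import USDU_util

    block = USDU_util.conv_block(3, 64, kernel_size=3, act_type="leakyrelu", mode="CNA")
    up = USDU_util.upconv_block(64, 64, upscale_factor=2, act_type="leakyrelu")

    x = torch.randn(1, 3, 32, 32)
    y = up(block(x))
    print(y.shape)  # torch.Size([1, 64, 64, 64]) -- 2x spatial upscale, 64 channels
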
/modules/UltimateSDUpscale/image_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import torch
4 | from PIL import Image
5 |
6 |
7 | def get_tiled_scale_steps(width: int, height: int, tile_x: int, tile_y: int, overlap: int) -> int:
8 | """#### Calculate the number of steps required for tiled scaling.
9 |
10 | #### Args:
11 | - `width` (int): The width of the image.
12 | - `height` (int): The height of the image.
13 | - `tile_x` (int): The width of each tile.
14 | - `tile_y` (int): The height of each tile.
15 | - `overlap` (int): The overlap between tiles.
16 |
17 | #### Returns:
18 | - `int`: The number of steps required for tiled scaling.
19 | """
20 | return math.ceil((height / (tile_y - overlap))) * math.ceil(
21 | (width / (tile_x - overlap))
22 | )
23 |
24 |
25 | @torch.inference_mode()
26 | def tiled_scale(
27 | samples: torch.Tensor,
28 | function: callable,
29 | tile_x: int = 64,
30 | tile_y: int = 64,
31 | overlap: int = 8,
32 | upscale_amount: float = 4,
33 | out_channels: int = 3,
34 | pbar: any = None,
35 | ) -> torch.Tensor:
36 | """#### Perform tiled scaling on a batch of samples.
37 |
38 | #### Args:
39 | - `samples` (torch.Tensor): The input samples.
40 | - `function` (callable): The function to apply to each tile.
41 | - `tile_x` (int, optional): The width of each tile. Defaults to 64.
42 | - `tile_y` (int, optional): The height of each tile. Defaults to 64.
43 | - `overlap` (int, optional): The overlap between tiles. Defaults to 8.
44 | - `upscale_amount` (float, optional): The upscale amount. Defaults to 4.
45 | - `out_channels` (int, optional): The number of output channels. Defaults to 3.
46 | - `pbar` (any, optional): The progress bar. Defaults to None.
47 |
48 | #### Returns:
49 | - `torch.Tensor`: The scaled output tensor.
50 | """
51 | output = torch.empty(
52 | (
53 | samples.shape[0],
54 | out_channels,
55 | round(samples.shape[2] * upscale_amount),
56 | round(samples.shape[3] * upscale_amount),
57 | ),
58 | device="cpu",
59 | )
60 | for b in range(samples.shape[0]):
61 | s = samples[b : b + 1]
62 | out = torch.zeros(
63 | (
64 | s.shape[0],
65 | out_channels,
66 | round(s.shape[2] * upscale_amount),
67 | round(s.shape[3] * upscale_amount),
68 | ),
69 | device="cpu",
70 | )
71 | out_div = torch.zeros(
72 | (
73 | s.shape[0],
74 | out_channels,
75 | round(s.shape[2] * upscale_amount),
76 | round(s.shape[3] * upscale_amount),
77 | ),
78 | device="cpu",
79 | )
80 | for y in range(0, s.shape[2], tile_y - overlap):
81 | for x in range(0, s.shape[3], tile_x - overlap):
82 | s_in = s[:, :, y : y + tile_y, x : x + tile_x]
83 |
84 | ps = function(s_in).cpu()
85 | mask = torch.ones_like(ps)
86 | feather = round(overlap * upscale_amount)
87 | for t in range(feather):
88 | mask[:, :, t : 1 + t, :] *= (1.0 / feather) * (t + 1)
89 | mask[:, :, mask.shape[2] - 1 - t : mask.shape[2] - t, :] *= (
90 | 1.0 / feather
91 | ) * (t + 1)
92 | mask[:, :, :, t : 1 + t] *= (1.0 / feather) * (t + 1)
93 | mask[:, :, :, mask.shape[3] - 1 - t : mask.shape[3] - t] *= (
94 | 1.0 / feather
95 | ) * (t + 1)
96 | out[
97 | :,
98 | :,
99 | round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
100 | round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
101 | ] += ps * mask
102 | out_div[
103 | :,
104 | :,
105 | round(y * upscale_amount) : round((y + tile_y) * upscale_amount),
106 | round(x * upscale_amount) : round((x + tile_x) * upscale_amount),
107 | ] += mask
108 |
109 | output[b : b + 1] = out / out_div
110 | return output
111 |
112 |
113 | def flatten(img: Image.Image, bgcolor: str) -> Image.Image:
114 | """#### Replace transparency with a background color.
115 |
116 | #### Args:
117 | - `img` (Image.Image): The input image.
118 | - `bgcolor` (str): The background color.
119 |
120 | #### Returns:
121 | - `Image.Image`: The image with transparency replaced by the background color.
122 | """
123 |     if img.mode == "RGB":
124 | return img
125 | return Image.alpha_composite(Image.new("RGBA", img.size, bgcolor), img).convert(
126 | "RGB"
127 | )
128 |
129 |
130 | BLUR_KERNEL_SIZE = 15
131 |
132 |
133 | def tensor_to_pil(img_tensor: torch.Tensor, batch_index: int = 0) -> Image.Image:
134 | """#### Convert a tensor to a PIL image.
135 |
136 | #### Args:
137 | - `img_tensor` (torch.Tensor): The input tensor.
138 | - `batch_index` (int, optional): The batch index. Defaults to 0.
139 |
140 | #### Returns:
141 | - `Image.Image`: The converted PIL image.
142 | """
143 | img_tensor = img_tensor[batch_index].unsqueeze(0)
144 | i = 255.0 * img_tensor.cpu().numpy()
145 | img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8).squeeze())
146 | return img
147 |
148 |
149 | def pil_to_tensor(image: Image.Image) -> torch.Tensor:
150 | """#### Convert a PIL image to a tensor.
151 |
152 | #### Args:
153 | - `image` (Image.Image): The input PIL image.
154 |
155 | #### Returns:
156 | - `torch.Tensor`: The converted tensor.
157 | """
158 | image = np.array(image).astype(np.float32) / 255.0
159 | image = torch.from_numpy(image).unsqueeze(0)
160 | return image
161 |
162 |
163 | def get_crop_region(mask: Image.Image, pad: int = 0) -> tuple:
164 | """#### Get the coordinates of the white rectangular mask region.
165 |
166 | #### Args:
167 | - `mask` (Image.Image): The input mask image in 'L' mode.
168 | - `pad` (int, optional): The padding to apply. Defaults to 0.
169 |
170 | #### Returns:
171 | - `tuple`: The coordinates of the crop region.
172 | """
173 | coordinates = mask.getbbox()
174 | if coordinates is not None:
175 | x1, y1, x2, y2 = coordinates
176 | else:
177 | x1, y1, x2, y2 = mask.width, mask.height, 0, 0
178 | # Apply padding
179 | x1 = max(x1 - pad, 0)
180 | y1 = max(y1 - pad, 0)
181 | x2 = min(x2 + pad, mask.width)
182 | y2 = min(y2 + pad, mask.height)
183 | return fix_crop_region((x1, y1, x2, y2), (mask.width, mask.height))
184 |
185 |
186 | def fix_crop_region(region: tuple, image_size: tuple) -> tuple:
187 | """#### Remove the extra pixel added by the get_crop_region function.
188 |
189 | #### Args:
190 | - `region` (tuple): The crop region coordinates.
191 | - `image_size` (tuple): The size of the image.
192 |
193 | #### Returns:
194 | - `tuple`: The fixed crop region coordinates.
195 | """
196 | image_width, image_height = image_size
197 | x1, y1, x2, y2 = region
198 | if x2 < image_width:
199 | x2 -= 1
200 | if y2 < image_height:
201 | y2 -= 1
202 | return x1, y1, x2, y2
203 |
204 |
205 | def expand_crop(region: tuple, width: int, height: int, target_width: int, target_height: int) -> tuple:
206 | """#### Expand a crop region to a specified target size.
207 |
208 | #### Args:
209 | - `region` (tuple): The crop region coordinates.
210 | - `width` (int): The width of the image.
211 | - `height` (int): The height of the image.
212 | - `target_width` (int): The desired width of the crop region.
213 | - `target_height` (int): The desired height of the crop region.
214 |
215 | #### Returns:
216 | - `tuple`: The expanded crop region coordinates and the target size.
217 | """
218 | x1, y1, x2, y2 = region
219 | actual_width = x2 - x1
220 | actual_height = y2 - y1
221 |
222 |     # Try to expand the region to the right by half the difference
223 | width_diff = target_width - actual_width
224 | x2 = min(x2 + width_diff // 2, width)
225 |     # Expand the region to the left by the remaining difference, including pixels that could not be added on the right
226 | width_diff = target_width - (x2 - x1)
227 | x1 = max(x1 - width_diff, 0)
228 | # Try the right again
229 | width_diff = target_width - (x2 - x1)
230 | x2 = min(x2 + width_diff, width)
231 |
232 |     # Try to expand the region downward by half the difference
233 | height_diff = target_height - actual_height
234 | y2 = min(y2 + height_diff // 2, height)
235 |     # Expand the region upward by the remaining difference, including pixels that could not be added at the bottom
236 | height_diff = target_height - (y2 - y1)
237 | y1 = max(y1 - height_diff, 0)
238 | # Try the bottom again
239 | height_diff = target_height - (y2 - y1)
240 | y2 = min(y2 + height_diff, height)
241 |
242 | return (x1, y1, x2, y2), (target_width, target_height)
243 |
244 |
245 | def crop_cond(cond: list, region: tuple, init_size: tuple, canvas_size: tuple, tile_size: tuple, w_pad: int = 0, h_pad: int = 0) -> list:
246 | """#### Crop conditioning data to match a specific region.
247 |
248 | #### Args:
249 | - `cond` (list): The conditioning data.
250 | - `region` (tuple): The crop region coordinates.
251 | - `init_size` (tuple): The initial size of the image.
252 | - `canvas_size` (tuple): The size of the canvas.
253 | - `tile_size` (tuple): The size of the tile.
254 | - `w_pad` (int, optional): The width padding. Defaults to 0.
255 | - `h_pad` (int, optional): The height padding. Defaults to 0.
256 |
257 | #### Returns:
258 | - `list`: The cropped conditioning data.
259 | """
260 | cropped = []
261 | for emb, x in cond:
262 | cond_dict = x.copy()
263 | n = [emb, cond_dict]
264 | cropped.append(n)
265 | return cropped
--------------------------------------------------------------------------------
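
A small sketch exercising tiled_scale without any model weights: the per-tile "model" is plain 2x bilinear interpolation, which makes the tiling, overlap, and feathering logic easy to verify on its own.

    import torch
    import torch.nn.functional as F
    from modules.UltimateSDUpscale import image_util

    samples = torch.rand(1, 3, 96, 96)
    out = image_util.tiled_scale(
        samples,
        lambda t: F.interpolate(t, scale_factor=2, mode="bilinear"),
        tile_x=64,
        tile_y=64,
        overlap=8,
        upscale_amount=2,
    )
    print(out.shape)  # torch.Size([1, 3, 192, 192])
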
/modules/Utilities/Enhancer.py:
--------------------------------------------------------------------------------
1 | import ollama
2 | import os
3 |
4 | from modules.Utilities import util
5 |
6 |
7 | def enhance_prompt(p: str) -> str:
8 | """#### Enhance a text-to-image prompt using Ollama.
9 |
10 | #### Args:
11 |     - `p` (str): The prompt. If `None`, the prompt stored in the parameters file is used.
12 |
13 | #### Returns:
14 | - `str`: The enhanced prompt
15 | """
16 |
17 | # Load the prompt from the file
18 | prompt = util.load_parameters_from_file()[0]
19 |     # Use the caller-supplied prompt when one is given,
20 |     # otherwise keep the prompt loaded from the parameters file.
21 |     if p is not None:
22 |         prompt = p
23 | print(prompt)
24 | response = ollama.chat(
25 | model="deepseek-r1",
26 | messages=[
27 | {
28 | "role": "user",
29 | "content": f"""Your goal is to generate a text-to-image prompt based on a user's input, detailing their desired final outcome for an image. The user will provide specific details about the characteristics, features, or elements they want the image to include. The prompt should guide the generation of an image that aligns with the user's desired outcome.
30 |
31 | Generate a text-to-image prompt by arranging the following blocks in a single string, separated by commas:
32 |
33 | Image Type: [Specify desired image type]
34 |
35 | Aesthetic or Mood: [Describe desired aesthetic or mood]
36 |
37 | Lighting Conditions: [Specify desired lighting conditions]
38 |
39 | Composition or Framing: [Provide details about desired composition or framing]
40 |
41 | Background: [Specify desired background elements or setting]
42 |
43 | Colors: [Mention any specific colors or color palette]
44 |
45 | Objects or Elements: [List specific objects or features]
46 |
47 | Style or Artistic Influence: [Mention desired artistic style or influence]
48 |
49 | Subject's Appearance: [Describe appearance of main subject]
50 |
51 | Ensure the blocks are arranged in order of visual importance, from the most significant to the least significant, to effectively guide image generation; a block can be surrounded by parentheses to gain additional significance.
52 |
53 | This is an example of a user's input: "a beautiful blonde lady in lingerie sitting in seiza in a seducing way with a focus on her assets"
54 |
55 | And this is an example of a desired output: "portrait| serene and mysterious| soft, diffused lighting| close-up shot, emphasizing facial features| simple and blurred background| earthy tones with a hint of warm highlights| renaissance painting| a beautiful lady with freckles and dark makeup"
56 |
57 | Here is the user's input: {prompt}
58 |
59 | Write the prompt in the same style as the example above, in a single line, with absolutely no additional information, words or symbols other than the enhanced prompt.
60 |
61 | Output:""",
62 | },
63 | ],
64 | )
65 | content = response["message"]["content"]
66 | print("here's the enhanced prompt :", content)
67 |
68 | if "" in content and "" in content:
69 | # Get everything after
70 | enhanced = content.split("")[-1].strip()
71 | else:
72 | enhanced = content.strip()
73 | print("here's the enhanced prompt:", enhanced)
74 | os.system("ollama stop deepseek-r1")
75 | return "masterpiece, best quality, (extremely detailed CG unity 8k wallpaper, masterpiece, best quality, ultra-detailed, best shadow), high contrast, (best illumination), ((cinematic light)), hyper detail, dramatic light, depth of field," + enhanced
76 |
--------------------------------------------------------------------------------
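
A minimal sketch of calling the enhancer above. It assumes a local Ollama server with the deepseek-r1 model pulled (`ollama pull deepseek-r1`) and that the parameters file read by util.load_parameters_from_file() exists, since the function always loads it before checking `p`.

    from modules.Utilities.Enhancer import enhance_prompt

    # The input prompt is illustrative; the result is the fixed quality-tag prefix
    # followed by the pipe-separated prompt produced by the model.
    enhanced = enhance_prompt("a cozy cabin in a snowy forest at dusk")
    print(enhanced)
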
/modules/Utilities/Latent.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple
2 | import torch
3 | from modules.Device import Device
4 | from modules.Utilities import util
5 |
6 | class LatentFormat:
7 | """#### Base class for latent formats.
8 |
9 | #### Attributes:
10 | - `scale_factor` (float): The scale factor for the latent format.
11 |
12 | #### Returns:
13 | - `LatentFormat`: A latent format object.
14 | """
15 |
16 | scale_factor: float = 1.0
17 | latent_channels: int = 4
18 |
19 | def process_in(self, latent: torch.Tensor) -> torch.Tensor:
20 | """#### Process the latent input, by multiplying it by the scale factor.
21 |
22 | #### Args:
23 | - `latent` (torch.Tensor): The latent tensor.
24 |
25 | #### Returns:
26 | - `torch.Tensor`: The processed latent tensor.
27 | """
28 | return latent * self.scale_factor
29 |
30 | def process_out(self, latent: torch.Tensor) -> torch.Tensor:
31 | """#### Process the latent output, by dividing it by the scale factor.
32 |
33 | #### Args:
34 | - `latent` (torch.Tensor): The latent tensor.
35 |
36 | #### Returns:
37 | - `torch.Tensor`: The processed latent tensor.
38 | """
39 | return latent / self.scale_factor
40 |
41 | class SD15(LatentFormat):
42 | """#### SD15 latent format.
43 |
44 | #### Args:
45 | - `LatentFormat` (LatentFormat): The base latent format class.
46 | """
47 | latent_channels: int = 4
48 | def __init__(self, scale_factor: float = 0.18215):
49 | """#### Initialize the SD15 latent format.
50 |
51 | #### Args:
52 | - `scale_factor` (float, optional): The scale factor. Defaults to 0.18215.
53 | """
54 | self.scale_factor = scale_factor
55 | self.latent_rgb_factors = [
56 | # R G B
57 | [0.3512, 0.2297, 0.3227],
58 | [0.3250, 0.4974, 0.2350],
59 | [-0.2829, 0.1762, 0.2721],
60 | [-0.2120, -0.2616, -0.7177],
61 | ]
62 | self.taesd_decoder_name = "taesd_decoder"
63 |
64 | class SD3(LatentFormat):
65 | latent_channels = 16
66 |
67 | def __init__(self):
68 | """#### Initialize the SD3 latent format."""
69 | self.scale_factor = 1.5305
70 | self.shift_factor = 0.0609
71 | self.latent_rgb_factors = [
72 | [-0.0645, 0.0177, 0.1052],
73 | [0.0028, 0.0312, 0.0650],
74 | [0.1848, 0.0762, 0.0360],
75 | [0.0944, 0.0360, 0.0889],
76 | [0.0897, 0.0506, -0.0364],
77 | [-0.0020, 0.1203, 0.0284],
78 | [0.0855, 0.0118, 0.0283],
79 | [-0.0539, 0.0658, 0.1047],
80 | [-0.0057, 0.0116, 0.0700],
81 | [-0.0412, 0.0281, -0.0039],
82 | [0.1106, 0.1171, 0.1220],
83 | [-0.0248, 0.0682, -0.0481],
84 | [0.0815, 0.0846, 0.1207],
85 | [-0.0120, -0.0055, -0.0867],
86 | [-0.0749, -0.0634, -0.0456],
87 | [-0.1418, -0.1457, -0.1259],
88 | ]
89 | self.taesd_decoder_name = "taesd3_decoder"
90 |
91 | def process_in(self, latent: torch.Tensor) -> torch.Tensor:
92 | """#### Process the latent input, by multiplying it by the scale factor and subtracting the shift factor.
93 |
94 | #### Args:
95 | - `latent` (torch.Tensor): The latent tensor.
96 |
97 | #### Returns:
98 | - `torch.Tensor`: The processed latent tensor.
99 | """
100 | return (latent - self.shift_factor) * self.scale_factor
101 |
102 | def process_out(self, latent: torch.Tensor) -> torch.Tensor:
103 | """#### Process the latent output, by dividing it by the scale factor and adding the shift factor.
104 |
105 | #### Args:
106 | - `latent` (torch.Tensor): The latent tensor.
107 |
108 | #### Returns:
109 | - `torch.Tensor`: The processed latent tensor.
110 | """
111 | return (latent / self.scale_factor) + self.shift_factor
112 |
113 |
114 | class Flux1(SD3):
115 | latent_channels = 16
116 |
117 | def __init__(self):
118 | """#### Initialize the Flux1 latent format."""
119 | self.scale_factor = 0.3611
120 | self.shift_factor = 0.1159
121 | self.latent_rgb_factors = [
122 | [-0.0404, 0.0159, 0.0609],
123 | [0.0043, 0.0298, 0.0850],
124 | [0.0328, -0.0749, -0.0503],
125 | [-0.0245, 0.0085, 0.0549],
126 | [0.0966, 0.0894, 0.0530],
127 | [0.0035, 0.0399, 0.0123],
128 | [0.0583, 0.1184, 0.1262],
129 | [-0.0191, -0.0206, -0.0306],
130 | [-0.0324, 0.0055, 0.1001],
131 | [0.0955, 0.0659, -0.0545],
132 | [-0.0504, 0.0231, -0.0013],
133 | [0.0500, -0.0008, -0.0088],
134 | [0.0982, 0.0941, 0.0976],
135 | [-0.1233, -0.0280, -0.0897],
136 | [-0.0005, -0.0530, -0.0020],
137 | [-0.1273, -0.0932, -0.0680],
138 | ]
139 | self.taesd_decoder_name = "taef1_decoder"
140 |
141 | def process_in(self, latent: torch.Tensor) -> torch.Tensor:
142 | """#### Process the latent input, by multiplying it by the scale factor and subtracting the shift factor.
143 |
144 | #### Args:
145 | - `latent` (torch.Tensor): The latent tensor.
146 |
147 | #### Returns:
148 | - `torch.Tensor`: The processed latent tensor.
149 | """
150 | return (latent - self.shift_factor) * self.scale_factor
151 |
152 | def process_out(self, latent: torch.Tensor) -> torch.Tensor:
153 | """#### Process the latent output, by dividing it by the scale factor and adding the shift factor.
154 |
155 | #### Args:
156 | - `latent` (torch.Tensor): The latent tensor.
157 |
158 | #### Returns:
159 | - `torch.Tensor`: The processed latent tensor.
160 | """
161 | return (latent / self.scale_factor) + self.shift_factor
162 |
163 | class EmptyLatentImage:
164 | """#### A class to generate an empty latent image.
165 |
166 | #### Args:
167 | - `Device` (Device): The device to use for the latent image.
168 | """
169 |
170 | def __init__(self):
171 | """#### Initialize the EmptyLatentImage class."""
172 | self.device = Device.intermediate_device()
173 |
174 | def generate(
175 | self, width: int, height: int, batch_size: int = 1
176 | ) -> Tuple[Dict[str, torch.Tensor]]:
177 | """#### Generate an empty latent image
178 |
179 | #### Args:
180 | - `width` (int): The width of the latent image.
181 | - `height` (int): The height of the latent image.
182 | - `batch_size` (int, optional): The batch size. Defaults to 1.
183 |
184 | #### Returns:
185 | - `Tuple[Dict[str, torch.Tensor]]`: The generated latent image.
186 | """
187 | latent = torch.zeros(
188 | [batch_size, 4, height // 8, width // 8], device=self.device
189 | )
190 | return ({"samples": latent},)
191 |
192 | def fix_empty_latent_channels(model, latent_image):
193 | """#### Fix the empty latent image channels.
194 |
195 | #### Args:
196 | - `model` (Model): The model object.
197 | - `latent_image` (torch.Tensor): The latent image.
198 |
199 | #### Returns:
200 | - `torch.Tensor`: The fixed latent image.
201 | """
202 | latent_channels = model.get_model_object(
203 | "latent_format"
204 | ).latent_channels # Resize the empty latent image so it has the right number of channels
205 | if (
206 | latent_channels != latent_image.shape[1]
207 | and torch.count_nonzero(latent_image) == 0
208 | ):
209 | latent_image = util.repeat_to_batch_size(latent_image, latent_channels, dim=1)
210 | return latent_image
--------------------------------------------------------------------------------
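
A round-trip sketch for the latent formats above: process_in followed by process_out should recover the original latent up to floating-point error, for both the scale-only SD15 format and the scale-plus-shift Flux1 format.

    import torch
    from modules.Utilities.Latent import SD15, Flux1

    latent = torch.randn(1, 4, 64, 64)
    sd15 = SD15()
    assert torch.allclose(sd15.process_out(sd15.process_in(latent)), latent, atol=1e-6)

    flux_latent = torch.randn(1, 16, 64, 64)
    flux = Flux1()
    assert torch.allclose(flux.process_out(flux.process_in(flux_latent)), flux_latent, atol=1e-5)
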
/modules/Utilities/upscale.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 |
4 |
5 | def bislerp(samples: torch.Tensor, width: int, height: int) -> torch.Tensor:
6 | """#### Perform bilinear interpolation on samples.
7 |
8 | #### Args:
9 | - `samples` (torch.Tensor): The input samples.
10 | - `width` (int): The target width.
11 | - `height` (int): The target height.
12 |
13 | #### Returns:
14 | - `torch.Tensor`: The interpolated samples.
15 | """
16 |
17 | def slerp(b1: torch.Tensor, b2: torch.Tensor, r: torch.Tensor) -> torch.Tensor:
18 | """#### Perform spherical linear interpolation between two vectors.
19 |
20 | #### Args:
21 | - `b1` (torch.Tensor): The first vector.
22 | - `b2` (torch.Tensor): The second vector.
23 | - `r` (torch.Tensor): The interpolation ratio.
24 |
25 | #### Returns:
26 | - `torch.Tensor`: The interpolated vector.
27 | """
28 |
29 | c = b1.shape[-1]
30 |
31 | # norms
32 | b1_norms = torch.norm(b1, dim=-1, keepdim=True)
33 | b2_norms = torch.norm(b2, dim=-1, keepdim=True)
34 |
35 | # normalize
36 | b1_normalized = b1 / b1_norms
37 | b2_normalized = b2 / b2_norms
38 |
39 | # zero when norms are zero
40 | b1_normalized[b1_norms.expand(-1, c) == 0.0] = 0.0
41 | b2_normalized[b2_norms.expand(-1, c) == 0.0] = 0.0
42 |
43 | # slerp
44 | dot = (b1_normalized * b2_normalized).sum(1)
45 | omega = torch.acos(dot)
46 | so = torch.sin(omega)
47 |
48 | # technically not mathematically correct, but more pleasing?
49 | res = (torch.sin((1.0 - r.squeeze(1)) * omega) / so).unsqueeze(
50 | 1
51 | ) * b1_normalized + (torch.sin(r.squeeze(1) * omega) / so).unsqueeze(
52 | 1
53 | ) * b2_normalized
54 | res *= (b1_norms * (1.0 - r) + b2_norms * r).expand(-1, c)
55 |
56 | # edge cases for same or polar opposites
57 | res[dot > 1 - 1e-5] = b1[dot > 1 - 1e-5]
58 | res[dot < 1e-5 - 1] = (b1 * (1.0 - r) + b2 * r)[dot < 1e-5 - 1]
59 | return res
60 |
61 | def generate_bilinear_data(
62 | length_old: int, length_new: int, device: torch.device
63 | ) -> List[torch.Tensor]:
64 | """#### Generate bilinear data for interpolation.
65 |
66 | #### Args:
67 | - `length_old` (int): The old length.
68 | - `length_new` (int): The new length.
69 | - `device` (torch.device): The device to use.
70 |
71 | #### Returns:
72 | - `torch.Tensor`: The ratios.
73 | - `torch.Tensor`: The first coordinates.
74 | - `torch.Tensor`: The second coordinates.
75 | """
76 | coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape(
77 | (1, 1, 1, -1)
78 | )
79 | coords_1 = torch.nn.functional.interpolate(
80 | coords_1, size=(1, length_new), mode="bilinear"
81 | )
82 | ratios = coords_1 - coords_1.floor()
83 | coords_1 = coords_1.to(torch.int64)
84 |
85 | coords_2 = (
86 | torch.arange(length_old, dtype=torch.float32, device=device).reshape(
87 | (1, 1, 1, -1)
88 | )
89 | + 1
90 | )
91 | coords_2[:, :, :, -1] -= 1
92 | coords_2 = torch.nn.functional.interpolate(
93 | coords_2, size=(1, length_new), mode="bilinear"
94 | )
95 | coords_2 = coords_2.to(torch.int64)
96 | return ratios, coords_1, coords_2
97 |
98 | orig_dtype = samples.dtype
99 | samples = samples.float()
100 | n, c, h, w = samples.shape
101 | h_new, w_new = (height, width)
102 |
103 | # linear w
104 | ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
105 | coords_1 = coords_1.expand((n, c, h, -1))
106 | coords_2 = coords_2.expand((n, c, h, -1))
107 | ratios = ratios.expand((n, 1, h, -1))
108 |
109 | pass_1 = samples.gather(-1, coords_1).movedim(1, -1).reshape((-1, c))
110 | pass_2 = samples.gather(-1, coords_2).movedim(1, -1).reshape((-1, c))
111 | ratios = ratios.movedim(1, -1).reshape((-1, 1))
112 |
113 | result = slerp(pass_1, pass_2, ratios)
114 | result = result.reshape(n, h, w_new, c).movedim(-1, 1)
115 |
116 | # linear h
117 | ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
118 | coords_1 = coords_1.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
119 | coords_2 = coords_2.reshape((1, 1, -1, 1)).expand((n, c, -1, w_new))
120 | ratios = ratios.reshape((1, 1, -1, 1)).expand((n, 1, -1, w_new))
121 |
122 | pass_1 = result.gather(-2, coords_1).movedim(1, -1).reshape((-1, c))
123 | pass_2 = result.gather(-2, coords_2).movedim(1, -1).reshape((-1, c))
124 | ratios = ratios.movedim(1, -1).reshape((-1, 1))
125 |
126 | result = slerp(pass_1, pass_2, ratios)
127 | result = result.reshape(n, h_new, w_new, c).movedim(-1, 1)
128 | return result.to(orig_dtype)
129 |
130 |
131 | def common_upscale(samples: List, width: int, height: int) -> torch.Tensor:
132 | """#### Upscales the given samples to the specified width and height using the specified method and crop settings.
133 | #### Args:
134 | - `samples` (list): The list of samples to be upscaled.
135 | - `width` (int): The target width for the upscaled samples.
136 | - `height` (int): The target height for the upscaled samples.
137 | #### Returns:
138 | - `torch.Tensor`: The upscaled samples.
139 | """
140 | s = samples
141 | return bislerp(s, width, height)
142 |
143 |
144 | class LatentUpscale:
145 | """#### A class to upscale latent codes."""
146 |
147 | def upscale(self, samples: dict, width: int, height: int) -> tuple:
148 | """#### Upscales the given latent codes.
149 |
150 | #### Args:
151 | - `samples` (dict): The latent codes to be upscaled.
152 | - `width` (int): The target width for the upscaled samples.
153 | - `height` (int): The target height for the upscaled samples.
154 |
155 | #### Returns:
156 | - `tuple`: The upscaled samples.
157 | """
158 | if width == 0 and height == 0:
159 | s = samples
160 | else:
161 | s = samples.copy()
162 | width = max(64, width)
163 | height = max(64, height)
164 |
165 | s["samples"] = common_upscale(samples["samples"], width // 8, height // 8)
166 | return (s,)
167 |
--------------------------------------------------------------------------------
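
A minimal sketch of the latent upscaler above on a random SD1.5-sized latent (a 512x512 image corresponds to a 64x64 latent; target sizes are given in pixels and divided by 8 internally).

    import torch
    from modules.Utilities.upscale import LatentUpscale, bislerp

    latent = {"samples": torch.randn(1, 4, 64, 64)}
    (up,) = LatentUpscale().upscale(latent, width=768, height=768)
    print(up["samples"].shape)  # torch.Size([1, 4, 96, 96])

    # bislerp can also be applied directly to any NCHW tensor:
    print(bislerp(torch.randn(1, 4, 32, 32), 48, 48).shape)  # torch.Size([1, 4, 48, 48])
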
/modules/WaveSpeed/fbcache_nodes.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import unittest.mock
3 | import torch
4 |
5 | from . import first_block_cache
6 |
7 |
8 | class ApplyFBCacheOnModel:
9 |
10 | def patch(
11 | self,
12 | model,
13 | object_to_patch,
14 | residual_diff_threshold,
15 | max_consecutive_cache_hits=-1,
16 | start=0.0,
17 | end=1.0,
18 | ):
19 | if residual_diff_threshold <= 0.0 or max_consecutive_cache_hits == 0:
20 |             return model  # model is passed in as a 1-tuple already
21 |
22 | # first_block_cache.patch_get_output_data()
23 |
24 | using_validation = max_consecutive_cache_hits >= 0 or start > 0 or end < 1
25 | if using_validation:
26 |             model_sampling = model[0].get_model_object("model_sampling")
27 | start_sigma, end_sigma = (float(
28 | model_sampling.percent_to_sigma(pct)) for pct in (start, end))
29 | del model_sampling
30 |
31 | @torch.compiler.disable()
32 | def validate_use_cache(use_cached):
33 | nonlocal consecutive_cache_hits
34 | use_cached = use_cached and end_sigma <= current_timestep <= start_sigma
35 | use_cached = use_cached and (max_consecutive_cache_hits < 0
36 | or consecutive_cache_hits
37 | < max_consecutive_cache_hits)
38 | consecutive_cache_hits = consecutive_cache_hits + 1 if use_cached else 0
39 | return use_cached
40 | else:
41 | validate_use_cache = None
42 |
43 | prev_timestep = None
44 | prev_input_state = None
45 | current_timestep = None
46 | consecutive_cache_hits = 0
47 |
48 | def reset_cache_state():
49 | # Resets the cache state and hits/time tracking variables.
50 | nonlocal prev_input_state, prev_timestep, consecutive_cache_hits
51 | prev_input_state = prev_timestep = None
52 | consecutive_cache_hits = 0
53 | first_block_cache.set_current_cache_context(
54 | first_block_cache.create_cache_context())
55 |
56 | def ensure_cache_state(model_input: torch.Tensor, timestep: float):
57 | # Validates the current cache state and hits/time tracking variables
58 | # and triggers a reset if necessary. Also updates current_timestep.
59 | nonlocal current_timestep
60 | input_state = (model_input.shape, model_input.dtype, model_input.device)
61 | need_reset = (
62 | prev_timestep is None or
63 | prev_input_state != input_state or
64 | first_block_cache.get_current_cache_context() is None or
65 | timestep >= prev_timestep
66 | )
67 | if need_reset:
68 | reset_cache_state()
69 | current_timestep = timestep
70 |
71 | def update_cache_state(model_input: torch.Tensor, timestep: float):
72 | # Updates the previous timestep and input state validation variables.
73 | nonlocal prev_timestep, prev_input_state
74 | prev_timestep = timestep
75 | prev_input_state = (model_input.shape, model_input.dtype, model_input.device)
76 |
77 | model = model[0].clone()
78 | diffusion_model = model.get_model_object(object_to_patch)
79 |
80 | if diffusion_model.__class__.__name__ in ("UNetModel", "Flux"):
81 |
82 | if diffusion_model.__class__.__name__ == "UNetModel":
83 | create_patch_function = first_block_cache.create_patch_unet_model__forward
84 | elif diffusion_model.__class__.__name__ == "Flux":
85 | create_patch_function = first_block_cache.create_patch_flux_forward_orig
86 | else:
87 | raise ValueError(
88 | f"Unsupported model {diffusion_model.__class__.__name__}")
89 |
90 | patch_forward = create_patch_function(
91 | diffusion_model,
92 | residual_diff_threshold=residual_diff_threshold,
93 | validate_can_use_cache_function=validate_use_cache,
94 | )
95 |
96 | def model_unet_function_wrapper(model_function, kwargs):
97 | try:
98 | input = kwargs["input"]
99 | timestep = kwargs["timestep"]
100 | c = kwargs["c"]
101 | t = timestep[0].item()
102 |
103 | ensure_cache_state(input, t)
104 |
105 | with patch_forward():
106 | result = model_function(input, timestep, **c)
107 | update_cache_state(input, t)
108 | return result
109 | except Exception as exc:
110 | reset_cache_state()
111 | raise exc from None
112 | else:
113 | is_non_native_ltxv = False
114 | if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
115 | is_non_native_ltxv = True
116 | diffusion_model = diffusion_model.transformer
117 |
118 | double_blocks_name = None
119 | single_blocks_name = None
120 | if hasattr(diffusion_model, "transformer_blocks"):
121 | double_blocks_name = "transformer_blocks"
122 | elif hasattr(diffusion_model, "double_blocks"):
123 | double_blocks_name = "double_blocks"
124 | elif hasattr(diffusion_model, "joint_blocks"):
125 | double_blocks_name = "joint_blocks"
126 | else:
127 | raise ValueError(
128 | f"No double blocks found for {diffusion_model.__class__.__name__}"
129 | )
130 |
131 | if hasattr(diffusion_model, "single_blocks"):
132 | single_blocks_name = "single_blocks"
133 |
134 | if is_non_native_ltxv:
135 | original_create_skip_layer_mask = getattr(
136 | diffusion_model, "create_skip_layer_mask", None)
137 | if original_create_skip_layer_mask is not None:
138 | # original_double_blocks = getattr(diffusion_model,
139 | # double_blocks_name)
140 |
141 | def new_create_skip_layer_mask(self, *args, **kwargs):
142 | # with unittest.mock.patch.object(self, double_blocks_name,
143 | # original_double_blocks):
144 | # return original_create_skip_layer_mask(*args, **kwargs)
145 | # return original_create_skip_layer_mask(*args, **kwargs)
146 | raise RuntimeError(
147 | "STG is not supported with FBCache yet")
148 |
149 | diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__(
150 | diffusion_model)
151 |
152 | cached_transformer_blocks = torch.nn.ModuleList([
153 | first_block_cache.CachedTransformerBlocks(
154 | None if double_blocks_name is None else getattr(
155 | diffusion_model, double_blocks_name),
156 | None if single_blocks_name is None else getattr(
157 | diffusion_model, single_blocks_name),
158 | residual_diff_threshold=residual_diff_threshold,
159 | validate_can_use_cache_function=validate_use_cache,
160 | cat_hidden_states_first=diffusion_model.__class__.__name__
161 | == "HunyuanVideo",
162 | return_hidden_states_only=diffusion_model.__class__.
163 | __name__ == "LTXVModel" or is_non_native_ltxv,
164 | clone_original_hidden_states=diffusion_model.__class__.
165 | __name__ == "LTXVModel",
166 | return_hidden_states_first=diffusion_model.__class__.
167 | __name__ != "OpenAISignatureMMDITWrapper",
168 | accept_hidden_states_first=diffusion_model.__class__.
169 | __name__ != "OpenAISignatureMMDITWrapper",
170 | )
171 | ])
172 | dummy_single_transformer_blocks = torch.nn.ModuleList()
173 |
174 | def model_unet_function_wrapper(model_function, kwargs):
175 | try:
176 | input = kwargs["input"]
177 | timestep = kwargs["timestep"]
178 | c = kwargs["c"]
179 | t = timestep[0].item()
180 |
181 | ensure_cache_state(input, t)
182 |
183 | with unittest.mock.patch.object(
184 | diffusion_model,
185 | double_blocks_name,
186 | cached_transformer_blocks,
187 | ), unittest.mock.patch.object(
188 | diffusion_model,
189 | single_blocks_name,
190 | dummy_single_transformer_blocks,
191 | ) if single_blocks_name is not None else contextlib.nullcontext(
192 | ):
193 | result = model_function(input, timestep, **c)
194 | update_cache_state(input, t)
195 | return result
196 | except Exception as exc:
197 | reset_cache_state()
198 | raise exc from None
199 |
200 | model.set_model_unet_function_wrapper(model_unet_function_wrapper)
201 | return (model, )
202 |
--------------------------------------------------------------------------------
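
A minimal sketch of applying first-block caching. `model` stands for a ModelPatcher-like object produced elsewhere in the pipeline (it must support clone(), get_model_object() and set_model_unet_function_wrapper()); note that patch() expects it wrapped in a 1-tuple, mirroring the node-style return values used throughout this repository. The parameter values are illustrative.

    from modules.WaveSpeed.fbcache_nodes import ApplyFBCacheOnModel

    (cached_model,) = ApplyFBCacheOnModel().patch(
        (model,),                      # placeholder ModelPatcher, wrapped in a tuple
        object_to_patch="diffusion_model",
        residual_diff_threshold=0.12,  # <= 0 disables caching entirely
        max_consecutive_cache_hits=3,  # cap on back-to-back cache reuse
        start=0.15,                    # only allow cache hits in the middle of
        end=0.95,                      #   the sampling schedule
    )
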
/modules/WaveSpeed/misc_nodes.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import json
3 |
4 | from . import utils
5 |
6 |
7 | class EnhancedCompileModel:
8 |
9 | def patch(
10 | self,
11 | model,
12 | is_patcher,
13 | object_to_patch,
14 | compiler,
15 | fullgraph,
16 | dynamic,
17 | mode,
18 | options,
19 | disable,
20 | backend,
21 | ):
22 | utils.patch_optimized_module()
23 | utils.patch_same_meta()
24 |
25 | import_path, function_name = compiler.rsplit(".", 1)
26 | module = importlib.import_module(import_path)
27 | compile_function = getattr(module, function_name)
28 |
29 | mode = mode if mode else None
30 | options = json.loads(options) if options else None
31 |
32 | if compiler == "torch.compile" and backend == "inductor" and dynamic:
33 | # TODO: Fix this
34 | # File "pytorch/torch/_inductor/fx_passes/post_grad.py", line 643, in same_meta
35 | # and statically_known_true(sym_eq(val1.size(), val2.size()))
36 | # AttributeError: 'SymInt' object has no attribute 'size'
37 | pass
38 |
39 | if is_patcher:
40 | patcher = model[0].clone()
41 | else:
42 | patcher = model.patcher
43 | patcher = patcher.clone()
44 |
45 | patcher.add_object_patch(
46 | object_to_patch,
47 | compile_function(
48 | patcher.get_model_object(object_to_patch),
49 | fullgraph=fullgraph,
50 | dynamic=dynamic,
51 | mode=mode,
52 | options=options,
53 | disable=disable,
54 | backend=backend,
55 | ),
56 | )
57 |
58 | if is_patcher:
59 | return (patcher,)
60 | else:
61 | model.patcher = patcher
62 | return (model,)
63 |
--------------------------------------------------------------------------------
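
A minimal sketch of EnhancedCompileModel using torch.compile with the inductor backend; `model` is again a placeholder ModelPatcher wrapped in a 1-tuple, and the empty mode/options strings fall back to the torch.compile defaults.

    from modules.WaveSpeed.misc_nodes import EnhancedCompileModel

    (compiled,) = EnhancedCompileModel().patch(
        (model,),                      # placeholder ModelPatcher in a tuple
        is_patcher=True,
        object_to_patch="diffusion_model",
        compiler="torch.compile",
        fullgraph=False,
        dynamic=False,
        mode="",                       # "" -> mode=None (torch.compile default)
        options="",                    # "" -> options=None
        disable=False,
        backend="inductor",
    )
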
/modules/WaveSpeed/utils.py:
--------------------------------------------------------------------------------
1 | import contextlib
2 | import unittest.mock
3 |
4 | import torch
5 |
6 |
7 |
8 |
9 | # wildcard trick is taken from pythongossss's
10 | class AnyType(str):
11 |
12 | def __ne__(self, __value: object) -> bool:
13 | return False
14 |
15 |
16 | any_typ = AnyType("*")
17 |
18 |
19 | def get_weight_dtype_inputs():
20 | return {
21 | "weight_dtype": (
22 | [
23 | "default",
24 | "float32",
25 | "float64",
26 | "bfloat16",
27 | "float16",
28 | "fp8_e4m3fn",
29 | "fp8_e4m3fn_fast",
30 | "fp8_e5m2",
31 | ],
32 | ),
33 | }
34 |
35 |
36 | def parse_weight_dtype(model_options, weight_dtype):
37 | dtype = {
38 | "float32": torch.float32,
39 | "float64": torch.float64,
40 | "bfloat16": torch.bfloat16,
41 | "float16": torch.float16,
42 | "fp8_e4m3fn": torch.float8_e4m3fn,
43 | "fp8_e4m3fn_fast": torch.float8_e4m3fn,
44 | "fp8_e5m2": torch.float8_e5m2,
45 | }.get(weight_dtype, None)
46 | if dtype is not None:
47 | model_options["dtype"] = dtype
48 | if weight_dtype == "fp8_e4m3fn_fast":
49 | model_options["fp8_optimizations"] = True
50 | return model_options
51 |
52 |
53 | @contextlib.contextmanager
54 | def disable_load_models_gpu():
55 | def foo(*args, **kwargs):
56 | pass
57 | from modules.Device import Device
58 | with unittest.mock.patch.object(Device, "load_models_gpu", foo):
59 | yield
60 |
61 |
62 | def patch_optimized_module():
63 | try:
64 | from torch._dynamo.eval_frame import OptimizedModule
65 | except ImportError:
66 | return
67 |
68 | if getattr(OptimizedModule, "_patched", False):
69 | return
70 |
71 | def __getattribute__(self, name):
72 | if name == "_orig_mod":
73 | return object.__getattribute__(self, "_modules")[name]
74 | if name in (
75 | "__class__",
76 | "_modules",
77 | "state_dict",
78 | "load_state_dict",
79 | "parameters",
80 | "named_parameters",
81 | "buffers",
82 | "named_buffers",
83 | "children",
84 | "named_children",
85 | "modules",
86 | "named_modules",
87 | ):
88 | return getattr(object.__getattribute__(self, "_orig_mod"), name)
89 | return object.__getattribute__(self, name)
90 |
91 | def __delattr__(self, name):
92 | # unload_lora_weights() wants to del peft_config
93 | return delattr(self._orig_mod, name)
94 |
95 | @classmethod
96 | def __instancecheck__(cls, instance):
97 | return isinstance(instance, OptimizedModule) or issubclass(
98 | object.__getattribute__(instance, "__class__"), cls
99 | )
100 |
101 | OptimizedModule.__getattribute__ = __getattribute__
102 | OptimizedModule.__delattr__ = __delattr__
103 | OptimizedModule.__instancecheck__ = __instancecheck__
104 | OptimizedModule._patched = True
105 |
106 |
107 | def patch_same_meta():
108 | try:
109 | from torch._inductor.fx_passes import post_grad
110 | except ImportError:
111 | return
112 |
113 | same_meta = getattr(post_grad, "same_meta", None)
114 | if same_meta is None:
115 | return
116 |
117 | if getattr(same_meta, "_patched", False):
118 | return
119 |
120 | def new_same_meta(a, b):
121 | try:
122 | return same_meta(a, b)
123 | except Exception:
124 | return False
125 |
126 | post_grad.same_meta = new_same_meta
127 | new_same_meta._patched = True
128 |
--------------------------------------------------------------------------------
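
A quick sketch of the weight-dtype helper above: known names populate model_options["dtype"], the fp8 "fast" variant additionally flags fp8 optimizations, and unknown names leave the options untouched.

    from modules.WaveSpeed.utils import parse_weight_dtype

    print(parse_weight_dtype({}, "bfloat16"))        # {'dtype': torch.bfloat16}
    print(parse_weight_dtype({}, "fp8_e4m3fn_fast")) # {'dtype': torch.float8_e4m3fn, 'fp8_optimizations': True}
    print(parse_weight_dtype({}, "default"))         # {} -- unchanged
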
/modules/clip/CLIPTextModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | class CLIPTextModel_(torch.nn.Module):
4 | """#### The CLIPTextModel_ module."""
5 | def __init__(
6 | self,
7 | config_dict: dict,
8 | dtype: torch.dtype,
9 | device: torch.device,
10 | operations: object,
11 | ):
12 | """#### Initialize the CLIPTextModel_ module.
13 |
14 | #### Args:
15 | - `config_dict` (dict): The configuration dictionary.
16 | - `dtype` (torch.dtype): The data type.
17 | - `device` (torch.device): The device to use.
18 | - `operations` (object): The operations object.
19 | """
20 | num_layers = config_dict["num_hidden_layers"]
21 | embed_dim = config_dict["hidden_size"]
22 | heads = config_dict["num_attention_heads"]
23 | intermediate_size = config_dict["intermediate_size"]
24 | intermediate_activation = config_dict["hidden_act"]
25 | num_positions = config_dict["max_position_embeddings"]
26 | self.eos_token_id = config_dict["eos_token_id"]
27 |
28 | super().__init__()
29 | from modules.clip.Clip import CLIPEmbeddings, CLIPEncoder
30 | self.embeddings = CLIPEmbeddings(
31 | embed_dim,
32 | num_positions=num_positions,
33 | dtype=dtype,
34 | device=device,
35 | operations=operations,
36 | )
37 | self.encoder = CLIPEncoder(
38 | num_layers,
39 | embed_dim,
40 | heads,
41 | intermediate_size,
42 | intermediate_activation,
43 | dtype,
44 | device,
45 | operations,
46 | )
47 | self.final_layer_norm = operations.LayerNorm(
48 | embed_dim, dtype=dtype, device=device
49 | )
50 |
51 | def forward(
52 | self,
53 | input_tokens: torch.Tensor,
54 | attention_mask: torch.Tensor = None,
55 | intermediate_output: int = None,
56 | final_layer_norm_intermediate: bool = True,
57 | dtype: torch.dtype = torch.float32,
58 | ) -> tuple:
59 | """#### Forward pass for the CLIPTextModel_ module.
60 |
61 | #### Args:
62 | - `input_tokens` (torch.Tensor): The input tokens.
63 | - `attention_mask` (torch.Tensor, optional): The attention mask. Defaults to None.
64 | - `intermediate_output` (int, optional): The intermediate output layer. Defaults to None.
65 | - `final_layer_norm_intermediate` (bool, optional): Whether to apply final layer normalization to the intermediate output. Defaults to True.
66 |
67 | #### Returns:
68 | - `tuple`: The output tensor, the intermediate output tensor, and the pooled output tensor.
69 | """
70 | x = self.embeddings(input_tokens, dtype=dtype)
71 | mask = None
72 | if attention_mask is not None:
73 | mask = 1.0 - attention_mask.to(x.dtype).reshape(
74 | (attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
75 | ).expand(
76 | attention_mask.shape[0],
77 | 1,
78 | attention_mask.shape[-1],
79 | attention_mask.shape[-1],
80 | )
81 | mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
82 |
83 | causal_mask = (
84 | torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device)
85 | .fill_(float("-inf"))
86 | .triu_(1)
87 | )
88 | if mask is not None:
89 | mask += causal_mask
90 | else:
91 | mask = causal_mask
92 |
93 | x, i = self.encoder(x, mask=mask, intermediate_output=intermediate_output)
94 | x = self.final_layer_norm(x)
95 | if i is not None and final_layer_norm_intermediate:
96 | i = self.final_layer_norm(i)
97 |
98 | pooled_output = x[
99 | torch.arange(x.shape[0], device=x.device),
100 | (
101 | torch.round(input_tokens).to(dtype=torch.int, device=x.device)
102 | == self.eos_token_id
103 | )
104 | .int()
105 | .argmax(dim=-1),
106 | ]
107 | return x, i, pooled_output
108 |
109 | class CLIPTextModel(torch.nn.Module):
110 | """#### The CLIPTextModel module."""
111 | def __init__(
112 | self,
113 | config_dict: dict,
114 | dtype: torch.dtype,
115 | device: torch.device,
116 | operations: object,
117 | ):
118 | """#### Initialize the CLIPTextModel module.
119 |
120 | #### Args:
121 | - `config_dict` (dict): The configuration dictionary.
122 | - `dtype` (torch.dtype): The data type.
123 | - `device` (torch.device): The device to use.
124 | - `operations` (object): The operations object.
125 | """
126 | super().__init__()
127 | self.num_layers = config_dict["num_hidden_layers"]
128 | self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
129 | embed_dim = config_dict["hidden_size"]
130 | self.text_projection = operations.Linear(
131 | embed_dim, embed_dim, bias=False, dtype=dtype, device=device
132 | )
133 | self.dtype = dtype
134 |
135 | def get_input_embeddings(self) -> torch.nn.Embedding:
136 | """#### Get the input embeddings.
137 |
138 | #### Returns:
139 | - `torch.nn.Embedding`: The input embeddings.
140 | """
141 | return self.text_model.embeddings.token_embedding
142 |
143 | def set_input_embeddings(self, embeddings: torch.nn.Embedding) -> None:
144 | """#### Set the input embeddings.
145 |
146 | #### Args:
147 | - `embeddings` (torch.nn.Embedding): The input embeddings.
148 | """
149 | self.text_model.embeddings.token_embedding = embeddings
150 |
151 | def forward(self, *args, **kwargs) -> tuple:
152 | """#### Forward pass for the CLIPTextModel module.
153 |
154 | #### Args:
155 | - `*args`: Variable length argument list.
156 | - `**kwargs`: Arbitrary keyword arguments.
157 |
158 | #### Returns:
159 | - `tuple`: The output tensors.
160 | """
161 | x = self.text_model(*args, **kwargs)
162 | out = self.text_projection(x[2])
163 | return (x[0], x[1], out, x[2])
--------------------------------------------------------------------------------
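A minimal standalone sketch of the EOS-pooling step in CLIPTextModel_.forward above; the token ids and tensor shapes here are illustrative, not taken from the repository:

    import torch

    eos_token_id = 49407                                        # matches sd1_clip_config.json below
    tokens = torch.tensor([[49406, 320, 2368, 49407, 49407]])   # BOS, "a", "cat", EOS, padding
    hidden = torch.randn(1, 5, 768)                             # (batch, seq_len, hidden_size)

    # argmax over the int-cast boolean mask picks the position of the first EOS token,
    # and the pooled output is the hidden state at that position for each sequence.
    first_eos = (tokens == eos_token_id).int().argmax(dim=-1)   # tensor([3])
    pooled = hidden[torch.arange(hidden.shape[0]), first_eos]   # shape (1, 768)
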
/modules/clip/clip/hydit_clip.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "hfl/chinese-roberta-wwm-ext-large",
3 | "architectures": [
4 | "BertModel"
5 | ],
6 | "attention_probs_dropout_prob": 0.1,
7 | "bos_token_id": 0,
8 | "classifier_dropout": null,
9 | "directionality": "bidi",
10 | "eos_token_id": 2,
11 | "hidden_act": "gelu",
12 | "hidden_dropout_prob": 0.1,
13 | "hidden_size": 1024,
14 | "initializer_range": 0.02,
15 | "intermediate_size": 4096,
16 | "layer_norm_eps": 1e-12,
17 | "max_position_embeddings": 512,
18 | "model_type": "bert",
19 | "num_attention_heads": 16,
20 | "num_hidden_layers": 24,
21 | "output_past": true,
22 | "pad_token_id": 0,
23 | "pooler_fc_size": 768,
24 | "pooler_num_attention_heads": 12,
25 | "pooler_num_fc_layers": 3,
26 | "pooler_size_per_head": 128,
27 | "pooler_type": "first_token_transform",
28 | "position_embedding_type": "absolute",
29 | "torch_dtype": "float32",
30 | "transformers_version": "4.22.1",
31 | "type_vocab_size": 2,
32 | "use_cache": true,
33 | "vocab_size": 47020
34 | }
35 |
36 |
--------------------------------------------------------------------------------
/modules/clip/clip/long_clipl.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "openai/clip-vit-large-patch14",
3 | "architectures": [
4 | "CLIPTextModel"
5 | ],
6 | "attention_dropout": 0.0,
7 | "bos_token_id": 0,
8 | "dropout": 0.0,
9 | "eos_token_id": 49407,
10 | "hidden_act": "quick_gelu",
11 | "hidden_size": 768,
12 | "initializer_factor": 1.0,
13 | "initializer_range": 0.02,
14 | "intermediate_size": 3072,
15 | "layer_norm_eps": 1e-05,
16 | "max_position_embeddings": 248,
17 | "model_type": "clip_text_model",
18 | "num_attention_heads": 12,
19 | "num_hidden_layers": 12,
20 | "pad_token_id": 1,
21 | "projection_dim": 768,
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.24.0",
24 | "vocab_size": 49408
25 | }
26 |
--------------------------------------------------------------------------------
/modules/clip/clip/mt5_config_xl.json:
--------------------------------------------------------------------------------
1 | {
2 | "d_ff": 5120,
3 | "d_kv": 64,
4 | "d_model": 2048,
5 | "decoder_start_token_id": 0,
6 | "dropout_rate": 0.1,
7 | "eos_token_id": 1,
8 | "dense_act_fn": "gelu_pytorch_tanh",
9 | "initializer_factor": 1.0,
10 | "is_encoder_decoder": true,
11 | "is_gated_act": true,
12 | "layer_norm_epsilon": 1e-06,
13 | "model_type": "mt5",
14 | "num_decoder_layers": 24,
15 | "num_heads": 32,
16 | "num_layers": 24,
17 | "output_past": true,
18 | "pad_token_id": 0,
19 | "relative_attention_num_buckets": 32,
20 | "tie_word_embeddings": false,
21 | "vocab_size": 250112
22 | }
23 |
--------------------------------------------------------------------------------
/modules/clip/clip/sd1_clip_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "openai/clip-vit-large-patch14",
3 | "architectures": [
4 | "CLIPTextModel"
5 | ],
6 | "attention_dropout": 0.0,
7 | "bos_token_id": 0,
8 | "dropout": 0.0,
9 | "eos_token_id": 49407,
10 | "hidden_act": "quick_gelu",
11 | "hidden_size": 768,
12 | "initializer_factor": 1.0,
13 | "initializer_range": 0.02,
14 | "intermediate_size": 3072,
15 | "layer_norm_eps": 1e-05,
16 | "max_position_embeddings": 77,
17 | "model_type": "clip_text_model",
18 | "num_attention_heads": 12,
19 | "num_hidden_layers": 12,
20 | "pad_token_id": 1,
21 | "projection_dim": 768,
22 | "torch_dtype": "float32",
23 | "transformers_version": "4.24.0",
24 | "vocab_size": 49408
25 | }
26 |
--------------------------------------------------------------------------------
/modules/clip/clip/sd2_clip_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "architectures": [
3 | "CLIPTextModel"
4 | ],
5 | "attention_dropout": 0.0,
6 | "bos_token_id": 0,
7 | "dropout": 0.0,
8 | "eos_token_id": 49407,
9 | "hidden_act": "gelu",
10 | "hidden_size": 1024,
11 | "initializer_factor": 1.0,
12 | "initializer_range": 0.02,
13 | "intermediate_size": 4096,
14 | "layer_norm_eps": 1e-05,
15 | "max_position_embeddings": 77,
16 | "model_type": "clip_text_model",
17 | "num_attention_heads": 16,
18 | "num_hidden_layers": 24,
19 | "pad_token_id": 1,
20 | "projection_dim": 1024,
21 | "torch_dtype": "float32",
22 | "vocab_size": 49408
23 | }
24 |
--------------------------------------------------------------------------------
/modules/clip/clip/t5_config_base.json:
--------------------------------------------------------------------------------
1 | {
2 | "d_ff": 3072,
3 | "d_kv": 64,
4 | "d_model": 768,
5 | "decoder_start_token_id": 0,
6 | "dropout_rate": 0.1,
7 | "eos_token_id": 1,
8 | "dense_act_fn": "relu",
9 | "initializer_factor": 1.0,
10 | "is_encoder_decoder": true,
11 | "is_gated_act": false,
12 | "layer_norm_epsilon": 1e-06,
13 | "model_type": "t5",
14 | "num_decoder_layers": 12,
15 | "num_heads": 12,
16 | "num_layers": 12,
17 | "output_past": true,
18 | "pad_token_id": 0,
19 | "relative_attention_num_buckets": 32,
20 | "tie_word_embeddings": false,
21 | "vocab_size": 32128
22 | }
23 |
--------------------------------------------------------------------------------
/modules/clip/clip/t5_config_xxl.json:
--------------------------------------------------------------------------------
1 | {
2 | "d_ff": 10240,
3 | "d_kv": 64,
4 | "d_model": 4096,
5 | "decoder_start_token_id": 0,
6 | "dropout_rate": 0.1,
7 | "eos_token_id": 1,
8 | "dense_act_fn": "gelu_pytorch_tanh",
9 | "initializer_factor": 1.0,
10 | "is_encoder_decoder": true,
11 | "is_gated_act": true,
12 | "layer_norm_epsilon": 1e-06,
13 | "model_type": "t5",
14 | "num_decoder_layers": 24,
15 | "num_heads": 64,
16 | "num_layers": 24,
17 | "output_past": true,
18 | "pad_token_id": 0,
19 | "relative_attention_num_buckets": 32,
20 | "tie_word_embeddings": false,
21 | "vocab_size": 32128
22 | }
23 |
--------------------------------------------------------------------------------
/modules/clip/clip/t5_pile_config_xl.json:
--------------------------------------------------------------------------------
1 | {
2 | "d_ff": 5120,
3 | "d_kv": 64,
4 | "d_model": 2048,
5 | "decoder_start_token_id": 0,
6 | "dropout_rate": 0.1,
7 | "eos_token_id": 2,
8 | "dense_act_fn": "gelu_pytorch_tanh",
9 | "initializer_factor": 1.0,
10 | "is_encoder_decoder": true,
11 | "is_gated_act": true,
12 | "layer_norm_epsilon": 1e-06,
13 | "model_type": "umt5",
14 | "num_decoder_layers": 24,
15 | "num_heads": 32,
16 | "num_layers": 24,
17 | "output_past": true,
18 | "pad_token_id": 1,
19 | "relative_attention_num_buckets": 32,
20 | "tie_word_embeddings": false,
21 | "vocab_size": 32128
22 | }
23 |
--------------------------------------------------------------------------------
/modules/clip/clip/t5_tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 |   "additional_special_tokens": [
3 |     "<extra_id_0>",
4 |     "<extra_id_1>",
5 |     "<extra_id_2>",
6 |     "<extra_id_3>",
7 |     "<extra_id_4>",
8 |     "<extra_id_5>",
9 |     "<extra_id_6>",
10 |     "<extra_id_7>",
11 |     "<extra_id_8>",
12 |     "<extra_id_9>",
13 |     "<extra_id_10>",
14 |     "<extra_id_11>",
15 |     "<extra_id_12>",
16 |     "<extra_id_13>",
17 |     "<extra_id_14>",
18 |     "<extra_id_15>",
19 |     "<extra_id_16>",
20 |     "<extra_id_17>",
21 |     "<extra_id_18>",
22 |     "<extra_id_19>",
23 |     "<extra_id_20>",
24 |     "<extra_id_21>",
25 |     "<extra_id_22>",
26 |     "<extra_id_23>",
27 |     "<extra_id_24>",
28 |     "<extra_id_25>",
29 |     "<extra_id_26>",
30 |     "<extra_id_27>",
31 |     "<extra_id_28>",
32 |     "<extra_id_29>",
33 |     "<extra_id_30>",
34 |     "<extra_id_31>",
35 |     "<extra_id_32>",
36 |     "<extra_id_33>",
37 |     "<extra_id_34>",
38 |     "<extra_id_35>",
39 |     "<extra_id_36>",
40 |     "<extra_id_37>",
41 |     "<extra_id_38>",
42 |     "<extra_id_39>",
43 |     "<extra_id_40>",
44 |     "<extra_id_41>",
45 |     "<extra_id_42>",
46 |     "<extra_id_43>",
47 |     "<extra_id_44>",
48 |     "<extra_id_45>",
49 |     "<extra_id_46>",
50 |     "<extra_id_47>",
51 |     "<extra_id_48>",
52 |     "<extra_id_49>",
53 |     "<extra_id_50>",
54 |     "<extra_id_51>",
55 |     "<extra_id_52>",
56 |     "<extra_id_53>",
57 |     "<extra_id_54>",
58 |     "<extra_id_55>",
59 |     "<extra_id_56>",
60 |     "<extra_id_57>",
61 |     "<extra_id_58>",
62 |     "<extra_id_59>",
63 |     "<extra_id_60>",
64 |     "<extra_id_61>",
65 |     "<extra_id_62>",
66 |     "<extra_id_63>",
67 |     "<extra_id_64>",
68 |     "<extra_id_65>",
69 |     "<extra_id_66>",
70 |     "<extra_id_67>",
71 |     "<extra_id_68>",
72 |     "<extra_id_69>",
73 |     "<extra_id_70>",
74 |     "<extra_id_71>",
75 |     "<extra_id_72>",
76 |     "<extra_id_73>",
77 |     "<extra_id_74>",
78 |     "<extra_id_75>",
79 |     "<extra_id_76>",
80 |     "<extra_id_77>",
81 |     "<extra_id_78>",
82 |     "<extra_id_79>",
83 |     "<extra_id_80>",
84 |     "<extra_id_81>",
85 |     "<extra_id_82>",
86 |     "<extra_id_83>",
87 |     "<extra_id_84>",
88 |     "<extra_id_85>",
89 |     "<extra_id_86>",
90 |     "<extra_id_87>",
91 |     "<extra_id_88>",
92 |     "<extra_id_89>",
93 |     "<extra_id_90>",
94 |     "<extra_id_91>",
95 |     "<extra_id_92>",
96 |     "<extra_id_93>",
97 |     "<extra_id_94>",
98 |     "<extra_id_95>",
99 |     "<extra_id_96>",
100 |     "<extra_id_97>",
101 |     "<extra_id_98>",
102 |     "<extra_id_99>"
103 |   ],
104 |   "eos_token": {
105 |     "content": "</s>",
106 |     "lstrip": false,
107 |     "normalized": false,
108 |     "rstrip": false,
109 |     "single_word": false
110 |   },
111 |   "pad_token": {
112 |     "content": "<pad>",
113 |     "lstrip": false,
114 |     "normalized": false,
115 |     "rstrip": false,
116 |     "single_word": false
117 |   },
118 |   "unk_token": {
119 |     "content": "<unk>",
120 |     "lstrip": false,
121 |     "normalized": false,
122 |     "rstrip": false,
123 |     "single_word": false
124 |   }
125 | }
126 |
--------------------------------------------------------------------------------
/modules/cond/Activation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from modules.cond import cast
4 |
5 |
6 | class GEGLU(nn.Module):
7 | """#### Class representing the GEGLU activation function.
8 |
9 | GEGLU is a gated linear unit (GLU) variant that uses GELU as its gate:
10 | the projection output is split into a value and a gate, and the value is multiplied by GELU(gate).
11 |
12 | #### Args:
13 | - `dim_in` (int): The input dimension.
14 | - `dim_out` (int): The output dimension.
15 | """
16 |
17 | def __init__(self, dim_in: int, dim_out: int):
18 | super().__init__()
19 | self.proj = cast.manual_cast.Linear(dim_in, dim_out * 2)
20 |
21 | def forward(self, x: torch.Tensor) -> torch.Tensor:
22 | """#### Forward pass for the GEGLU activation function.
23 |
24 | #### Args:
25 | - `x` (torch.Tensor): The input tensor.
26 |
27 | #### Returns:
28 | - `torch.Tensor`: The output tensor.
29 | """
30 | x, gate = self.proj(x).chunk(2, dim=-1)
31 | return x * torch.nn.functional.gelu(gate)
32 |
--------------------------------------------------------------------------------
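A minimal usage sketch for GEGLU above, assuming cast.manual_cast.Linear behaves like a standard nn.Linear under default dtype and device; the dimensions are illustrative:

    import torch
    from modules.cond.Activation import GEGLU

    geglu = GEGLU(dim_in=320, dim_out=1280)   # projects to 2 * 1280, then splits value/gate
    x = torch.randn(2, 77, 320)
    y = geglu(x)                              # value * GELU(gate) -> shape (2, 77, 1280)
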
/modules/cond/cond_util.py:
--------------------------------------------------------------------------------
1 | from modules.Device import Device
2 | import torch
3 | from typing import List, Tuple, Any
4 |
5 |
6 | def get_models_from_cond(cond: dict, model_type: str) -> List[object]:
7 | """#### Get models from a condition.
8 |
9 | #### Args:
10 | - `cond` (dict): The condition.
11 | - `model_type` (str): The model type.
12 |
13 | #### Returns:
14 | - `List[object]`: The list of models.
15 | """
16 | models = []
17 | for c in cond:
18 | if model_type in c:
19 | models += [c[model_type]]
20 | return models
21 |
22 |
23 | def get_additional_models(conds: dict, dtype: torch.dtype) -> Tuple[List[object], int]:
24 | """#### Load additional models in conditioning.
25 |
26 | #### Args:
27 | - `conds` (dict): The conditions.
28 | - `dtype` (torch.dtype): The data type.
29 |
30 | #### Returns:
31 | - `Tuple[List[object], int]`: The list of models and the inference memory.
32 | """
33 | cnets = []
34 | gligen = []
35 |
36 | for k in conds:
37 | cnets += get_models_from_cond(conds[k], "control")
38 | gligen += get_models_from_cond(conds[k], "gligen")
39 |
40 | control_nets = set(cnets)
41 |
42 | inference_memory = 0
43 | control_models = []
44 | for m in control_nets:
45 | control_models += m.get_models()
46 | inference_memory += m.inference_memory_requirements(dtype)
47 |
48 | gligen = [x[1] for x in gligen]
49 | models = control_models + gligen
50 | return models, inference_memory
51 |
52 |
53 | def prepare_sampling(
54 | model: object, noise_shape: Tuple[int], conds: dict, flux_enabled: bool = False
55 | ) -> Tuple[object, dict, List[object]]:
56 | """#### Prepare the model for sampling.
57 |
58 | #### Args:
59 | - `model` (object): The model.
60 | - `noise_shape` (Tuple[int]): The shape of the noise.
61 | - `conds` (dict): The conditions.
62 | - `flux_enabled` (bool, optional): Whether flux is enabled. Defaults to False.
63 |
64 | #### Returns:
65 | - `Tuple[object, dict, List[object]]`: The prepared model, conditions, and additional models.
66 | """
67 | real_model = None
68 | models, inference_memory = get_additional_models(conds, model.model_dtype())
69 | memory_required = (
70 | model.memory_required([noise_shape[0] * 2] + list(noise_shape[1:]))
71 | + inference_memory
72 | )
73 | minimum_memory_required = (
74 | model.memory_required([noise_shape[0]] + list(noise_shape[1:]))
75 | + inference_memory
76 | )
77 | Device.load_models_gpu(
78 | [model] + models,
79 | memory_required=memory_required,
80 | minimum_memory_required=minimum_memory_required,
81 | flux_enabled=flux_enabled,
82 | )
83 | real_model = model.model
84 |
85 | return real_model, conds, models
86 |
87 | def cleanup_additional_models(models: List[object]) -> None:
88 | """#### Clean up additional models.
89 |
90 | #### Args:
91 | - `models` (List[object]): The list of models.
92 | """
93 | for m in models:
94 | if hasattr(m, "cleanup"):
95 | m.cleanup()
96 |
97 | def cleanup_models(conds: dict, models: List[object]) -> None:
98 | """#### Clean up the models after sampling.
99 |
100 | #### Args:
101 | - `conds` (dict): The conditions.
102 | - `models` (List[object]): The list of models.
103 | """
104 | cleanup_additional_models(models)
105 |
106 | control_cleanup = []
107 | for k in conds:
108 | control_cleanup += get_models_from_cond(conds[k], "control")
109 |
110 | cleanup_additional_models(set(control_cleanup))
111 |
112 |
113 | def cond_equal_size(c1: Any, c2: Any) -> bool:
114 | """#### Check if two conditions have equal size.
115 |
116 | #### Args:
117 | - `c1` (Any): The first condition.
118 | - `c2` (Any): The second condition.
119 |
120 | #### Returns:
121 | - `bool`: Whether the conditions have equal size.
122 | """
123 | if c1 is c2:
124 | return True
125 | if c1.keys() != c2.keys():
126 | return False
127 | return True
128 |
129 |
130 | def can_concat_cond(c1: Any, c2: Any) -> bool:
131 | """#### Check if two conditions can be concatenated.
132 |
133 | #### Args:
134 | - `c1` (Any): The first condition.
135 | - `c2` (Any): The second condition.
136 |
137 | #### Returns:
138 | - `bool`: Whether the conditions can be concatenated.
139 | """
140 | if c1.input_x.shape != c2.input_x.shape:
141 | return False
142 |
143 | def objects_concatable(obj1, obj2):
144 | """#### Check if two objects can be concatenated."""
145 | if (obj1 is None) != (obj2 is None):
146 | return False
147 | if obj1 is not None:
148 | if obj1 is not obj2:
149 | return False
150 | return True
151 |
152 | if not objects_concatable(c1.control, c2.control):
153 | return False
154 |
155 | if not objects_concatable(c1.patches, c2.patches):
156 | return False
157 |
158 | return cond_equal_size(c1.conditioning, c2.conditioning)
159 |
160 |
161 | def cond_cat(c_list: List[dict]) -> dict:
162 | """#### Concatenate a list of conditions.
163 |
164 | #### Args:
165 | - `c_list` (List[dict]): The list of conditions.
166 |
167 | #### Returns:
168 | - `dict`: The concatenated conditions.
169 | """
170 | temp = {}
171 | for x in c_list:
172 | for k in x:
173 | cur = temp.get(k, [])
174 | cur.append(x[k])
175 | temp[k] = cur
176 |
177 | out = {}
178 | for k in temp:
179 | conds = temp[k]
180 | out[k] = conds[0].concat(conds[1:])
181 |
182 | return out
183 |
184 |
185 | def create_cond_with_same_area_if_none(conds: List[dict], c: dict) -> None:
186 | """#### Create a condition with the same area if none exists.
187 |
188 | #### Args:
189 | - `conds` (List[dict]): The list of conditions.
190 | - `c` (dict): The condition.
191 | """
192 | if "area" not in c:
193 | return
194 |
195 | c_area = c["area"]
196 | smallest = None
197 | for x in conds:
198 | if "area" in x:
199 | a = x["area"]
200 | if c_area[2] >= a[2] and c_area[3] >= a[3]:
201 | if a[0] + a[2] >= c_area[0] + c_area[2]:
202 | if a[1] + a[3] >= c_area[1] + c_area[3]:
203 | if smallest is None:
204 | smallest = x
205 | elif "area" not in smallest:
206 | smallest = x
207 | else:
208 | if smallest["area"][0] * smallest["area"][1] > a[0] * a[1]:
209 | smallest = x
210 | else:
211 | if smallest is None:
212 | smallest = x
213 | if smallest is None:
214 | return
215 | if "area" in smallest:
216 | if smallest["area"] == c_area:
217 | return
218 |
219 | out = c.copy()
220 | out["model_conds"] = smallest[
221 | "model_conds"
222 | ].copy()
223 | conds += [out]
224 |
--------------------------------------------------------------------------------
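A toy illustration of get_models_from_cond above: conditions are plain dicts, and the helper collects whatever is stored under the requested key (the string values here stand in for real control/gligen model objects):

    from modules.cond.cond_util import get_models_from_cond

    conds = [{"control": "ctrl_a"}, {"cross_attn": "emb"}, {"control": "ctrl_b"}]
    print(get_models_from_cond(conds, "control"))  # ['ctrl_a', 'ctrl_b']
    print(get_models_from_cond(conds, "gligen"))   # []
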
/modules/tests/test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | from modules.user.pipeline import pipeline
4 |
5 | class TestPipeline(unittest.TestCase):
6 | def setUp(self):
7 | self.test_prompt = "a cute cat, high quality, detailed"
8 | self.test_img_path = "../_internal/Flux_00001.png" # Make sure this test image exists
9 |
10 | def test_basic_generation_small(self):
11 | pipeline(self.test_prompt, 128, 128, number=1)
12 | # Check if output files exist
13 |
14 | def test_basic_generation_medium(self):
15 | pipeline(self.test_prompt, 512, 512, number=1)
16 |
17 | def test_basic_generation_large(self):
18 | pipeline(self.test_prompt, 1024, 1024, number=1)
19 |
20 | def test_hires_fix(self):
21 | pipeline(self.test_prompt, 512, 512, number=1, hires_fix=True)
22 |
23 | def test_adetailer(self):
24 | pipeline(
25 | "a portrait of a person, high quality",
26 | 512,
27 | 512,
28 | number=1,
29 | adetailer=True
30 | )
31 |
32 | def test_enhance_prompt(self):
33 | pipeline(
34 | self.test_prompt,
35 | 512,
36 | 512,
37 | number=1,
38 | enhance_prompt=True
39 | )
40 |
41 | def test_img2img(self):
42 | # Skip if test image doesn't exist
43 | if not os.path.exists(self.test_img_path):
44 | self.skipTest("Test image not found")
45 |
46 | pipeline(
47 | self.test_img_path,
48 | 512,
49 | 512,
50 | number=1,
51 | img2img=True
52 | )
53 |
54 | def test_stable_fast(self):
55 | resolutions = [(128, 128), (512, 512), (1024, 1024)]
56 | for w, h in resolutions:
57 | pipeline(
58 | self.test_prompt,
59 | w,
60 | h,
61 | number=1,
62 | stable_fast=True
63 | )
64 |
65 | def test_reuse_seed(self):
66 | pipeline(
67 | self.test_prompt,
68 | 512,
69 | 512,
70 | number=2,
71 | reuse_seed=True
72 | )
73 |
74 | def test_flux_mode(self):
75 | resolutions = [(128, 128), (512, 512), (1024, 1024)]
76 | for w, h in resolutions:
77 | pipeline(
78 | self.test_prompt,
79 | w,
80 | h,
81 | number=1,
82 | flux_enabled=True
83 | )
84 |
85 | if __name__ == '__main__':
86 | unittest.main()
--------------------------------------------------------------------------------
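These smoke tests can be run from the repository root with the standard unittest runner, assuming the required checkpoints and the test image referenced in setUp are already in place:

    import unittest

    # Load and run every test in modules/tests/test.py; verbosity=2 prints one line per test.
    suite = unittest.defaultTestLoader.loadTestsFromName("modules.tests.test")
    unittest.TextTestRunner(verbosity=2).run(suite)
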
/modules/user/app_instance.py:
--------------------------------------------------------------------------------
1 | from modules.user.GUI import App
2 |
3 | app = App()
4 |
--------------------------------------------------------------------------------
/pipeline.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | SET VENV_DIR=.venv
3 |
4 | REM Check if .venv exists
5 | IF NOT EXIST %VENV_DIR% (
6 | echo Creating virtual environment...
7 | python -m venv %VENV_DIR%
8 | )
9 |
10 | REM Activate the virtual environment
11 | CALL %VENV_DIR%\Scripts\activate
12 |
13 | REM Upgrade pip
14 | echo Upgrading pip...
15 | python -m pip install --upgrade pip
16 |
17 | REM Install specific packages
18 | echo Installing required packages...
19 | pip install uv
20 |
21 | REM Check for NVIDIA GPU
22 | FOR /F "delims=" %%i IN ('nvidia-smi 2^>^&1') DO (
23 | SET GPU_CHECK=%%i
24 | )
25 | IF NOT ERRORLEVEL 1 (
26 | echo NVIDIA GPU detected, installing GPU dependencies...
27 | uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu126
28 | ) ELSE (
29 | echo No NVIDIA GPU detected, installing CPU dependencies...
30 | uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
31 | )
32 |
33 | uv pip install "numpy>=1.24.3"
34 |
35 | REM Install additional requirements
36 | IF EXIST requirements.txt (
37 | echo Installing additional requirements...
38 | uv pip install -r requirements.txt
39 | ) ELSE (
40 | echo requirements.txt not found, skipping...
41 | )
42 |
43 | REM Check for enhance-prompt argument
44 | echo Checking for enhance-prompt argument...
45 | echo %* | findstr /i /c:"--enhance-prompt" >nul
46 | IF %ERRORLEVEL% EQU 0 (
47 | echo Installing ollama with winget...
48 | winget install --id ollama.ollama
49 | ollama pull deepseek-r1
50 | )
51 |
52 | REM Launch the script
53 | echo Launching LightDiffusion...
54 | python .\modules\user\pipeline.py %*
55 |
56 | REM Deactivate the virtual environment
57 | deactivate
58 |
--------------------------------------------------------------------------------
/pipeline.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | VENV_DIR=.venv
4 | # for WSL2 Ubuntu install
5 | # sudo apt install software-properties-common
6 | # sudo add-apt-repository ppa:deadsnakes/ppa
7 | sudo apt-get install python3.10 python3.10-venv python3.10-full python3-pip
8 |
9 | # Check if .venv exists
10 | if [ ! -d "$VENV_DIR" ]; then
11 | echo "Creating virtual environment..."
12 | python3.10 -m venv $VENV_DIR
13 | fi
14 |
15 | # Activate the virtual environment
16 | source $VENV_DIR/bin/activate
17 |
18 | # Upgrade pip
19 | echo "Upgrading pip..."
20 | pip install --upgrade pip
21 | pip3 install uv
22 |
23 | # Check GPU type
24 | TORCH_URL="https://download.pytorch.org/whl/cpu"
25 | if command -v nvidia-smi &> /dev/null; then
26 | echo "NVIDIA GPU detected"
27 | TORCH_URL="https://download.pytorch.org/whl/cu121"
28 | uv pip install --index-url $TORCH_URL \
29 | torch==2.2.2 torchvision "xformers>=0.0.22" "triton>=2.1.0" \
30 | stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl
31 | elif command -v rocminfo &> /dev/null; then
32 | echo "AMD GPU detected"
33 | TORCH_URL="https://download.pytorch.org/whl/rocm5.7"
34 | uv pip install --index-url $TORCH_URL \
35 | torch==2.2.2 torchvision "triton>=2.1.0"
36 | else
37 | echo "No compatible GPU detected, using CPU"
38 | uv pip install --index-url $TORCH_URL \
39 | torch==2.2.2+cpu torchvision
40 | fi
41 |
42 | uv pip install "numpy<2.0.0"
43 |
44 | # Install tkinter
45 | echo "Installing tkinter..."
46 | sudo apt-get install python3.10-tk
47 |
48 | # Install additional requirements
49 | if [ -f requirements.txt ]; then
50 | echo "Installing additional requirements..."
51 | uv pip install -r requirements.txt
52 | else
53 | echo "requirements.txt not found, skipping..."
54 | fi
55 |
56 | # Check for enhance-prompt argument
57 | echo "Checking for enhance-prompt argument..."
58 | if [[ " $* " == *" --enhance-prompt "* ]]; then
59 | echo "Installing ollama..."
60 | curl -fsSL https://ollama.com/install.sh | sh
61 | ollama pull deepseek-r1
62 | fi
63 |
64 | # Launch the script
65 | echo "Launching LightDiffusion..."
66 | python3.10 "./modules/user/pipeline.py" "$@"
67 |
68 | # Deactivate the virtual environment
69 | deactivate
70 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/requirements.txt
--------------------------------------------------------------------------------
/run.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | SET VENV_DIR=.venv
3 |
4 | REM Check if .venv exists
5 | IF NOT EXIST %VENV_DIR% (
6 | echo Creating virtual environment...
7 | python -m venv %VENV_DIR%
8 | )
9 |
10 | REM Activate the virtual environment
11 | CALL %VENV_DIR%\Scripts\activate
12 |
13 | REM Upgrade pip
14 | echo Upgrading pip...
15 | python -m pip install --upgrade pip
16 |
17 | REM Install specific packages
18 | echo Installing required packages...
19 | pip install uv
20 |
21 | REM Check for NVIDIA GPU
22 | FOR /F "delims=" %%i IN ('nvidia-smi 2^>^&1') DO (
23 | SET GPU_CHECK=%%i
24 | )
25 | IF NOT ERRORLEVEL 1 (
26 | echo NVIDIA GPU detected, installing GPU dependencies...
27 | uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu126
28 | ) ELSE (
29 | echo No NVIDIA GPU detected, installing CPU dependencies...
30 | uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
31 | )
32 |
33 | uv pip install "numpy>=1.24.3"
34 |
35 | REM Install additional requirements
36 | IF EXIST requirements.txt (
37 | echo Installing additional requirements...
38 | uv pip install -r requirements.txt
39 | ) ELSE (
40 | echo requirements.txt not found, skipping...
41 | )
42 |
43 | REM Launch the script
44 | echo Launching LightDiffusion...
45 | python .\modules\user\GUI.py
46 |
47 | REM Deactivate the virtual environment
48 | deactivate
49 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | VENV_DIR=.venv
4 | # for WSL2 Ubuntu install
5 | # sudo apt install software-properties-common
6 | # sudo add-apt-repository ppa:deadsnakes/ppa
7 | sudo apt-get install python3.10 python3.10-venv python3.10-full python3-pip
8 |
9 | # Check if .venv exists
10 | if [ ! -d "$VENV_DIR" ]; then
11 | echo "Creating virtual environment..."
12 | python3.10 -m venv $VENV_DIR
13 | fi
14 |
15 | # Activate the virtual environment
16 | source $VENV_DIR/bin/activate
17 |
18 | # Upgrade pip
19 | echo "Upgrading pip..."
20 | pip install --upgrade pip
21 | pip3 install uv
22 |
23 | # Check GPU type
24 | TORCH_URL="https://download.pytorch.org/whl/cpu"
25 | if command -v nvidia-smi &> /dev/null; then
26 | echo "NVIDIA GPU detected"
27 | TORCH_URL="https://download.pytorch.org/whl/cu121"
28 | uv pip install --index-url $TORCH_URL \
29 | torch==2.2.2 torchvision "xformers>=0.0.22" "triton>=2.1.0" \
30 | stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl
31 | elif command -v rocminfo &> /dev/null; then
32 | echo "AMD GPU detected"
33 | TORCH_URL="https://download.pytorch.org/whl/rocm5.7"
34 | uv pip install --index-url $TORCH_URL \
35 | torch==2.2.2 torchvision "triton>=2.1.0"
36 | else
37 | echo "No compatible GPU detected, using CPU"
38 | uv pip install --index-url $TORCH_URL \
39 | torch==2.2.2+cpu torchvision
40 | fi
41 |
42 | uv pip install "numpy<2.0.0"
43 |
44 | # Install tkinter
45 | echo "Installing tkinter..."
46 | sudo apt-get install python3.10-tk
47 |
48 | # Install additional requirements
49 | if [ -f requirements.txt ]; then
50 | echo "Installing additional requirements..."
51 | uv pip install -r requirements.txt
52 | else
53 | echo "requirements.txt not found, skipping..."
54 | fi
55 |
56 | # Launch the script
57 | echo "Launching LightDiffusion..."
58 | python3.10 ./modules/user/GUI.py
59 |
60 | # Deactivate the virtual environment
61 | deactivate
62 |
--------------------------------------------------------------------------------
/run_web.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | SET VENV_DIR=.venv
3 |
4 | REM Check if .venv exists
5 | IF NOT EXIST %VENV_DIR% (
6 | echo Creating virtual environment...
7 | python -m venv %VENV_DIR%
8 | )
9 |
10 | REM Activate the virtual environment
11 | CALL %VENV_DIR%\Scripts\activate
12 |
13 | REM Upgrade pip
14 | echo Upgrading pip...
15 | python -m pip install --upgrade pip
16 |
17 | REM Install specific packages
18 | echo Installing required packages...
19 | pip install uv
20 |
21 | REM Check for NVIDIA GPU
22 | FOR /F "delims=" %%i IN ('nvidia-smi 2^>^&1') DO (
23 | SET GPU_CHECK=%%i
24 | )
25 | IF NOT ERRORLEVEL 1 (
26 | echo NVIDIA GPU detected, installing GPU dependencies...
27 | uv pip install xformers torch torchvision --index-url https://download.pytorch.org/whl/cu126
28 | ) ELSE (
29 | echo No NVIDIA GPU detected, installing CPU dependencies...
30 | uv pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
31 | )
32 |
33 | uv pip install "numpy>=1.24.3"
34 |
35 | REM Install additional requirements
36 | IF EXIST requirements.txt (
37 | echo Installing additional requirements...
38 | uv pip install -r requirements.txt
39 | ) ELSE (
40 | echo requirements.txt not found, skipping...
41 | )
42 |
43 | REM Launch the script
44 | echo Launching LightDiffusion...
45 | python app.py
46 |
47 | REM Deactivate the virtual environment
48 | deactivate
49 |
--------------------------------------------------------------------------------
/run_web.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | VENV_DIR=.venv
4 | # for WSL2 Ubuntu install
5 | # sudo apt install software-properties-common
6 | # sudo add-apt-repository ppa:deadsnakes/ppa
7 | sudo apt-get install python3.10 python3.10-venv python3.10-full python3-pip
8 |
9 | # Check if .venv exists
10 | if [ ! -d "$VENV_DIR" ]; then
11 | echo "Creating virtual environment..."
12 | python3.10 -m venv $VENV_DIR
13 | fi
14 |
15 | # Activate the virtual environment
16 | source $VENV_DIR/bin/activate
17 |
18 | # Upgrade pip
19 | echo "Upgrading pip..."
20 | pip install --upgrade pip
21 | pip3 install uv
22 |
23 | # Check GPU type
24 | TORCH_URL="https://download.pytorch.org/whl/cpu"
25 | if command -v nvidia-smi &> /dev/null; then
26 | echo "NVIDIA GPU detected"
27 | TORCH_URL="https://download.pytorch.org/whl/cu121"
28 | uv pip install --index-url $TORCH_URL \
29 | torch==2.2.2 torchvision "xformers>=0.0.22" "triton>=2.1.0" \
30 | stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl
31 | elif command -v rocminfo &> /dev/null; then
32 | echo "AMD GPU detected"
33 | TORCH_URL="https://download.pytorch.org/whl/rocm5.7"
34 | uv pip install --index-url $TORCH_URL \
35 | torch==2.2.2 torchvision "triton>=2.1.0"
36 | else
37 | echo "No compatible GPU detected, using CPU"
38 | uv pip install --index-url $TORCH_URL \
39 | torch==2.2.2+cpu torchvision
40 | fi
41 |
42 | uv pip install "numpy<2.0.0"
43 |
44 | # Install tkinter
45 | echo "Installing tkinter..."
46 | sudo apt-get install python3.10-tk
47 |
48 | # Install additional requirements
49 | if [ -f requirements.txt ]; then
50 | echo "Installing additional requirements..."
51 | uv pip install -r requirements.txt
52 | else
53 | echo "requirements.txt not found, skipping..."
54 | fi
55 |
56 | # Launch the script
57 | echo "Launching LightDiffusion..."
58 | python3.10 app.py
59 |
60 | # Deactivate the virtual environment
61 | deactivate
62 |
--------------------------------------------------------------------------------
/stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LightDiffusion/LightDiffusion-Next/705f61bfc56300559fa1d08b5936eb721074fe12/stable_fast-1.0.5+torch222cu121-cp310-cp310-manylinux2014_x86_64.whl
--------------------------------------------------------------------------------