├── rust ├── .gitignore ├── src │ ├── error.rs │ ├── stable_diffusion_interface.rs │ └── lib.rs ├── Cargo.toml └── Cargo.lock ├── example ├── .gitignore ├── Cargo.toml ├── Cargo.lock └── src │ └── main.rs ├── assets ├── output.png ├── output2.png ├── output_lora_img2img.png └── output_lora_txt2img.png ├── .github └── workflows │ └── rust.yml └── README.md /rust/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /example/.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | sd-v* 3 | output* 4 | *.clpt 5 | *.gguf -------------------------------------------------------------------------------- /assets/output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WasmEdge/wasmedge-stable-diffusion/HEAD/assets/output.png -------------------------------------------------------------------------------- /assets/output2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WasmEdge/wasmedge-stable-diffusion/HEAD/assets/output2.png -------------------------------------------------------------------------------- /assets/output_lora_img2img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WasmEdge/wasmedge-stable-diffusion/HEAD/assets/output_lora_img2img.png -------------------------------------------------------------------------------- /assets/output_lora_txt2img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/WasmEdge/wasmedge-stable-diffusion/HEAD/assets/output_lora_txt2img.png -------------------------------------------------------------------------------- /example/Cargo.toml: 
-------------------------------------------------------------------------------- 1 | [package] 2 | name = "wasmedge_stable_diffusion_example" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | wasmedge_stable_diffusion = {path="../rust"} 8 | clap = { version = "4.4.6", features = ["cargo"] } 9 | rand = "0.8" -------------------------------------------------------------------------------- /rust/src/error.rs: -------------------------------------------------------------------------------- 1 | use thiserror::Error; 2 | 3 | /// Error types for the `wasmedge_stable_diffusion` library. 4 | #[derive(Error, Debug)] 5 | pub enum SDError { 6 | /// A path was invalid; the contained string is used verbatim as the display message. 7 | #[error("{0}")] 8 | InvalidPath(String), 9 | /// A general operation failed; the contained string is used verbatim as the display message. 10 | #[error("{0}")] 11 | Operation(String), 12 | } 13 | -------------------------------------------------------------------------------- /rust/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "wasmedge_stable_diffusion" 3 | version = "0.3.2" 4 | edition = "2021" 5 | readme = "../README.md" 6 | repository = "https://github.com/WasmEdge/wasmedge-stable-diffusion" 7 | license = "Apache-2.0" 8 | categories = ["wasm", "science"] 9 | description = "A Rust library for using stable diffusion functions when the Wasi is being executed on WasmEdge."
10 | 11 | 12 | [dependencies] 13 | thiserror = "1" 14 | -------------------------------------------------------------------------------- /.github/workflows/rust.yml: -------------------------------------------------------------------------------- 1 | name: Build Rust Crate 2 | 3 | on: 4 | push: 5 | branches: 6 | - dev 7 | - main 8 | - release-* 9 | - feat-* 10 | - ci-* 11 | - refactor-* 12 | - fix-* 13 | - test-* 14 | pull_request: 15 | branches: 16 | - dev 17 | - main 18 | - release-* 19 | - feat-* 20 | - ci-* 21 | - refactor-* 22 | - fix-* 23 | - test-* 24 | 25 | env: 26 | CARGO_TERM_COLOR: always 27 | 28 | jobs: 29 | build: 30 | runs-on: ubuntu-latest 31 | 32 | steps: 33 | - uses: actions/checkout@v4 34 | - uses: dtolnay/rust-toolchain@stable 35 | 36 | - name: Enable wasm32-wasip1 target 37 | run: rustup target add wasm32-wasip1 38 | 39 | - name: Run rustfmt 40 | run: | 41 | cd rust 42 | cargo fmt -- --check 43 | 44 | - name: Run clippy 45 | run: | 46 | cd rust 47 | cargo clippy -- -D warnings 48 | 49 | - name: Build 50 | run: | 51 | cd rust 52 | cargo build --target=wasm32-wasip1 --release 53 | -------------------------------------------------------------------------------- /rust/Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 
3 | version = 3 4 | 5 | [[package]] 6 | name = "proc-macro2" 7 | version = "1.0.86" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" 10 | dependencies = [ 11 | "unicode-ident", 12 | ] 13 | 14 | [[package]] 15 | name = "quote" 16 | version = "1.0.37" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" 19 | dependencies = [ 20 | "proc-macro2", 21 | ] 22 | 23 | [[package]] 24 | name = "syn" 25 | version = "2.0.77" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" 28 | dependencies = [ 29 | "proc-macro2", 30 | "quote", 31 | "unicode-ident", 32 | ] 33 | 34 | [[package]] 35 | name = "thiserror" 36 | version = "1.0.63" 37 | source = "registry+https://github.com/rust-lang/crates.io-index" 38 | checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" 39 | dependencies = [ 40 | "thiserror-impl", 41 | ] 42 | 43 | [[package]] 44 | name = "thiserror-impl" 45 | version = "1.0.63" 46 | source = "registry+https://github.com/rust-lang/crates.io-index" 47 | checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" 48 | dependencies = [ 49 | "proc-macro2", 50 | "quote", 51 | "syn", 52 | ] 53 | 54 | [[package]] 55 | name = "unicode-ident" 56 | version = "1.0.12" 57 | source = "registry+https://github.com/rust-lang/crates.io-index" 58 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 59 | 60 | [[package]] 61 | name = "wasmedge_stable_diffusion" 62 | version = "0.3.2" 63 | dependencies = [ 64 | "thiserror", 65 | ] 66 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 
wasmedge-stable-diffusion 2 | A Rust library for using Stable Diffusion functions in WASI programs running on WasmEdge. 3 | ## Set up WasmEdge 4 | 5 | ```Bash 6 | git clone https://github.com/WasmEdge/WasmEdge.git 7 | cd WasmEdge 8 | cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_TESTS=OFF -DWASMEDGE_PLUGIN_STABLEDIFFUSION=On -DWASMEDGE_USE_LLVM=OFF 9 | cmake --build build 10 | sudo cmake --install build 11 | ``` 12 | 13 | ## Download Model 14 | Download the weights or a quantized model with the following commands. 15 | You can also use our example to quantize the weights yourself (see the Convert section below). 16 | 17 | - stable-diffusion v1.4: [second-state/stable-diffusion-v-1-4-GGUF](https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF) 18 | - stable-diffusion v1.5: [second-state/stable-diffusion-v1-5-GGUF](https://huggingface.co/second-state/stable-diffusion-v1-5-GGUF) 19 | - stable-diffusion v2.1: [second-state/stable-diffusion-2-1-GGUF](https://huggingface.co/second-state/stable-diffusion-2-1-GGUF) 20 | 21 | ``` 22 | curl -L -O https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF/resolve/main/sd-v1-4.ckpt 23 | curl -L -O https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF/resolve/main/stable-diffusion-v1-4-Q8_0.gguf 24 | ``` 25 | 26 | ## Compile example file 27 | The compiled `.wasm` file is located at `./target/wasm32-wasi/release/` and is named `wasmedge_stable_diffusion_example.wasm`. 28 | ```Bash 29 | cargo build --target wasm32-wasi --release 30 | ``` 31 | 32 | ## Run 33 | It supports three modes: txt2img, img2img, and convert. 34 | 35 | ### txt2img 36 | Assume that the model `stable-diffusion-v1-4-Q8_0.gguf` from `stable-diffusion-v-1-4-GGUF` is located in a `models` folder in the same directory as this project. 37 | ```Bash 38 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm -m ../../models/stable-diffusion-v1-4-Q8_0.gguf -p "a lovely cat" 39 | ``` 40 |

41 | 42 |

43 | 44 | ### img2img 45 | - `./output.png` is the image generated from the above txt2img pipeline 46 | ```Bash 47 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm --mode img2img -m ../../models/stable-diffusion-v1-4-Q8_0.gguf -p "cat with red eyes" -i ./output.png -o ./img2img_output.png 48 | ``` 49 |

50 | 51 |

52 | 53 | ### Convert 54 | - Stable Diffusion model: [sd-v1-4.ckpt](https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF/resolve/main/sd-v1-4.ckpt), converted below to type Q8_0. 55 | ```Bash 56 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm --mode convert -o stable-diffusion-v1-4-Q8_0_test.gguf -m ../../models/sd-v1-4.ckpt --type q8_0 57 | ``` 58 | If you want to use the converted model, use `--type` to specify the type `Q8_0`. 59 | 60 | ## More Guides - LoRA 61 | ### Get weights 62 | - Base model: [v1-5-pruned-emaonly.safetensors](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors) 63 | ```Bash 64 | curl -L -O https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors 65 | ``` 66 | ### txt2img 67 | You can specify the directory where the LoRA weights are stored via `--lora-model-dir`. 68 | If it is not specified, the current working directory is used. 69 | ```Bash 70 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm \ 71 | --lora-model-dir ../../lora \ 72 | --model ../../lora/v1-5-pruned-emaonly.safetensors \ 73 | -p "a lovely cat" \ 74 | -o output_lora_txt2img.png 75 | ``` 76 | The lora model `../../lora/sd_xl_base_1.0.safetensors` and vae model `../../lora/sdxl_vae.safetensors` will be applied to the model. 77 |

78 | 79 |

80 | 81 | ### img2img 82 | ```Bash 83 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm \ 84 | -p "with blue eyes" \ 85 | --lora-model-dir ../../lora \ 86 | --model ../../lora/v1-5-pruned-emaonly.safetensors \ 87 | -i output_lora_txt2img.png \ 88 | -o output_lora_img2img.png 89 | ``` 90 |

91 | 92 |

93 | 94 | ## Supported parameters 95 | ``` 96 | usage: wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm [arguments] 97 | 98 | arguments: 99 | -M, --mode [MODE] run mode (txt2img or img2img or convert, default: txt2img) 100 | -t, --threads N number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of CPU physical cores 101 | -m, --model [MODEL] path to full model 102 | --diffusion-model path to the standalone diffusion model 103 | --clip_l path to the clip-l text encoder 104 | --t5xxl path to the t5xxl text encoder 105 | --vae [VAE] path to vae 106 | --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) 107 | --control-net [CONTROL_PATH] path to control net model 108 | --embd-dir [EMBEDDING_PATH] path to embeddings 109 | --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings 110 | --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir 111 | --normalize-input normalize PHOTOMAKER input id images 112 | --upscale-model [ESRGAN_PATH] path to esrgan model.
Images are upscaled after generation; only RealESRGAN_x4plus_anime_6B is supported for now 113 | --upscale-repeats Run the ESRGAN upscaler this many times (default 1) 114 | --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k) 115 | If not specified, the default is the type of the weight file 116 | --lora-model-dir [DIR] lora model directory 117 | -i, --init-img [IMAGE] path to the input image, required by img2img 118 | --control-image [IMAGE] path to image condition, control net 119 | -o, --output OUTPUT path to write result image to (default: ./output.png) 120 | -p, --prompt [PROMPT] the prompt to render 121 | -n, --negative-prompt PROMPT the negative prompt (default: "") 122 | --cfg-scale SCALE unconditional guidance scale: (default: 7.0) 123 | --strength STRENGTH strength for noising/unnoising (default: 0.75) 124 | --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%) 125 | --control-strength STRENGTH strength to apply Control Net (default: 0.9) 126 | 1.0 corresponds to full destruction of information in init image 127 | 128 | --guidance guidance 129 | -H, --height H image height, in pixel space (default: 512) 130 | -W, --width W image width, in pixel space (default: 512) 131 | --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm} 132 | sampling method (default: "euler_a") 133 | --steps STEPS number of sample steps (default: 20) 134 | --rng {std_default, cuda} RNG (default: cuda) 135 | -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0) 136 | -b, --batch-count COUNT number of images to generate 137 | --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete) 138 | --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) 139 | <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x 140 | --vae-tiling process vae in tiles to reduce memory usage 141 | --vae-on-cpu keep vae
in cpu (for low vram) 142 | --clip-on-cpu keep clip in cpu (for low vram) 143 | --control-net-cpu keep controlnet in cpu (for low vram) 144 | --canny apply canny preprocessor (edge detection) 145 | ``` 146 | 147 | 148 | -------------------------------------------------------------------------------- /example/Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "anstream" 7 | version = "0.6.15" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526" 10 | dependencies = [ 11 | "anstyle", 12 | "anstyle-parse", 13 | "anstyle-query", 14 | "anstyle-wincon", 15 | "colorchoice", 16 | "is_terminal_polyfill", 17 | "utf8parse", 18 | ] 19 | 20 | [[package]] 21 | name = "anstyle" 22 | version = "1.0.8" 23 | source = "registry+https://github.com/rust-lang/crates.io-index" 24 | checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" 25 | 26 | [[package]] 27 | name = "anstyle-parse" 28 | version = "0.2.5" 29 | source = "registry+https://github.com/rust-lang/crates.io-index" 30 | checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb" 31 | dependencies = [ 32 | "utf8parse", 33 | ] 34 | 35 | [[package]] 36 | name = "anstyle-query" 37 | version = "1.1.1" 38 | source = "registry+https://github.com/rust-lang/crates.io-index" 39 | checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a" 40 | dependencies = [ 41 | "windows-sys", 42 | ] 43 | 44 | [[package]] 45 | name = "anstyle-wincon" 46 | version = "3.0.4" 47 | source = "registry+https://github.com/rust-lang/crates.io-index" 48 | checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8" 49 | dependencies = [ 50 | "anstyle", 51 | "windows-sys", 52 | ] 53 | 54 | [[package]] 55 | 
name = "byteorder" 56 | version = "1.5.0" 57 | source = "registry+https://github.com/rust-lang/crates.io-index" 58 | checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" 59 | 60 | [[package]] 61 | name = "cfg-if" 62 | version = "1.0.0" 63 | source = "registry+https://github.com/rust-lang/crates.io-index" 64 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 65 | 66 | [[package]] 67 | name = "clap" 68 | version = "4.5.17" 69 | source = "registry+https://github.com/rust-lang/crates.io-index" 70 | checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" 71 | dependencies = [ 72 | "clap_builder", 73 | ] 74 | 75 | [[package]] 76 | name = "clap_builder" 77 | version = "4.5.17" 78 | source = "registry+https://github.com/rust-lang/crates.io-index" 79 | checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" 80 | dependencies = [ 81 | "anstream", 82 | "anstyle", 83 | "clap_lex", 84 | "strsim", 85 | ] 86 | 87 | [[package]] 88 | name = "clap_lex" 89 | version = "0.7.2" 90 | source = "registry+https://github.com/rust-lang/crates.io-index" 91 | checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" 92 | 93 | [[package]] 94 | name = "colorchoice" 95 | version = "1.0.2" 96 | source = "registry+https://github.com/rust-lang/crates.io-index" 97 | checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0" 98 | 99 | [[package]] 100 | name = "getrandom" 101 | version = "0.2.15" 102 | source = "registry+https://github.com/rust-lang/crates.io-index" 103 | checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" 104 | dependencies = [ 105 | "cfg-if", 106 | "libc", 107 | "wasi", 108 | ] 109 | 110 | [[package]] 111 | name = "is_terminal_polyfill" 112 | version = "1.70.1" 113 | source = "registry+https://github.com/rust-lang/crates.io-index" 114 | checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" 115 | 116 
| [[package]] 117 | name = "libc" 118 | version = "0.2.158" 119 | source = "registry+https://github.com/rust-lang/crates.io-index" 120 | checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" 121 | 122 | [[package]] 123 | name = "ppv-lite86" 124 | version = "0.2.20" 125 | source = "registry+https://github.com/rust-lang/crates.io-index" 126 | checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" 127 | dependencies = [ 128 | "zerocopy", 129 | ] 130 | 131 | [[package]] 132 | name = "proc-macro2" 133 | version = "1.0.86" 134 | source = "registry+https://github.com/rust-lang/crates.io-index" 135 | checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" 136 | dependencies = [ 137 | "unicode-ident", 138 | ] 139 | 140 | [[package]] 141 | name = "quote" 142 | version = "1.0.37" 143 | source = "registry+https://github.com/rust-lang/crates.io-index" 144 | checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" 145 | dependencies = [ 146 | "proc-macro2", 147 | ] 148 | 149 | [[package]] 150 | name = "rand" 151 | version = "0.8.5" 152 | source = "registry+https://github.com/rust-lang/crates.io-index" 153 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 154 | dependencies = [ 155 | "libc", 156 | "rand_chacha", 157 | "rand_core", 158 | ] 159 | 160 | [[package]] 161 | name = "rand_chacha" 162 | version = "0.3.1" 163 | source = "registry+https://github.com/rust-lang/crates.io-index" 164 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 165 | dependencies = [ 166 | "ppv-lite86", 167 | "rand_core", 168 | ] 169 | 170 | [[package]] 171 | name = "rand_core" 172 | version = "0.6.4" 173 | source = "registry+https://github.com/rust-lang/crates.io-index" 174 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 175 | dependencies = [ 176 | "getrandom", 177 | ] 178 | 179 | [[package]] 180 | name = "strsim" 181 | 
version = "0.11.1" 182 | source = "registry+https://github.com/rust-lang/crates.io-index" 183 | checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" 184 | 185 | [[package]] 186 | name = "syn" 187 | version = "2.0.77" 188 | source = "registry+https://github.com/rust-lang/crates.io-index" 189 | checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" 190 | dependencies = [ 191 | "proc-macro2", 192 | "quote", 193 | "unicode-ident", 194 | ] 195 | 196 | [[package]] 197 | name = "thiserror" 198 | version = "1.0.63" 199 | source = "registry+https://github.com/rust-lang/crates.io-index" 200 | checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" 201 | dependencies = [ 202 | "thiserror-impl", 203 | ] 204 | 205 | [[package]] 206 | name = "thiserror-impl" 207 | version = "1.0.63" 208 | source = "registry+https://github.com/rust-lang/crates.io-index" 209 | checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" 210 | dependencies = [ 211 | "proc-macro2", 212 | "quote", 213 | "syn", 214 | ] 215 | 216 | [[package]] 217 | name = "unicode-ident" 218 | version = "1.0.12" 219 | source = "registry+https://github.com/rust-lang/crates.io-index" 220 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" 221 | 222 | [[package]] 223 | name = "utf8parse" 224 | version = "0.2.2" 225 | source = "registry+https://github.com/rust-lang/crates.io-index" 226 | checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" 227 | 228 | [[package]] 229 | name = "wasi" 230 | version = "0.11.0+wasi-snapshot-preview1" 231 | source = "registry+https://github.com/rust-lang/crates.io-index" 232 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 233 | 234 | [[package]] 235 | name = "wasmedge_stable_diffusion" 236 | version = "0.3.2" 237 | dependencies = [ 238 | "thiserror", 239 | ] 240 | 241 | [[package]] 242 | name = 
"wasmedge_stable_diffusion_example" 243 | version = "0.1.0" 244 | dependencies = [ 245 | "clap", 246 | "rand", 247 | "wasmedge_stable_diffusion", 248 | ] 249 | 250 | [[package]] 251 | name = "windows-sys" 252 | version = "0.52.0" 253 | source = "registry+https://github.com/rust-lang/crates.io-index" 254 | checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" 255 | dependencies = [ 256 | "windows-targets", 257 | ] 258 | 259 | [[package]] 260 | name = "windows-targets" 261 | version = "0.52.6" 262 | source = "registry+https://github.com/rust-lang/crates.io-index" 263 | checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" 264 | dependencies = [ 265 | "windows_aarch64_gnullvm", 266 | "windows_aarch64_msvc", 267 | "windows_i686_gnu", 268 | "windows_i686_gnullvm", 269 | "windows_i686_msvc", 270 | "windows_x86_64_gnu", 271 | "windows_x86_64_gnullvm", 272 | "windows_x86_64_msvc", 273 | ] 274 | 275 | [[package]] 276 | name = "windows_aarch64_gnullvm" 277 | version = "0.52.6" 278 | source = "registry+https://github.com/rust-lang/crates.io-index" 279 | checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" 280 | 281 | [[package]] 282 | name = "windows_aarch64_msvc" 283 | version = "0.52.6" 284 | source = "registry+https://github.com/rust-lang/crates.io-index" 285 | checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" 286 | 287 | [[package]] 288 | name = "windows_i686_gnu" 289 | version = "0.52.6" 290 | source = "registry+https://github.com/rust-lang/crates.io-index" 291 | checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" 292 | 293 | [[package]] 294 | name = "windows_i686_gnullvm" 295 | version = "0.52.6" 296 | source = "registry+https://github.com/rust-lang/crates.io-index" 297 | checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" 298 | 299 | [[package]] 300 | name = "windows_i686_msvc" 301 | version = "0.52.6" 302 | source = 
"registry+https://github.com/rust-lang/crates.io-index" 303 | checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" 304 | 305 | [[package]] 306 | name = "windows_x86_64_gnu" 307 | version = "0.52.6" 308 | source = "registry+https://github.com/rust-lang/crates.io-index" 309 | checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" 310 | 311 | [[package]] 312 | name = "windows_x86_64_gnullvm" 313 | version = "0.52.6" 314 | source = "registry+https://github.com/rust-lang/crates.io-index" 315 | checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" 316 | 317 | [[package]] 318 | name = "windows_x86_64_msvc" 319 | version = "0.52.6" 320 | source = "registry+https://github.com/rust-lang/crates.io-index" 321 | checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" 322 | 323 | [[package]] 324 | name = "zerocopy" 325 | version = "0.7.35" 326 | source = "registry+https://github.com/rust-lang/crates.io-index" 327 | checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" 328 | dependencies = [ 329 | "byteorder", 330 | "zerocopy-derive", 331 | ] 332 | 333 | [[package]] 334 | name = "zerocopy-derive" 335 | version = "0.7.35" 336 | source = "registry+https://github.com/rust-lang/crates.io-index" 337 | checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" 338 | dependencies = [ 339 | "proc-macro2", 340 | "quote", 341 | "syn", 342 | ] 343 | -------------------------------------------------------------------------------- /rust/src/stable_diffusion_interface.rs: -------------------------------------------------------------------------------- 1 | use core::fmt; 2 | use core::mem::MaybeUninit; 3 | #[repr(transparent)] 4 | #[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)] 5 | pub struct WasmedgeSdErrno(u32); 6 | pub const WASMEDGE_SD_ERRNO_SUCCESS: WasmedgeSdErrno = WasmedgeSdErrno(0); 7 | pub const WASMEDGE_SD_ERRNO_INVALID_ARGUMENT: 
WasmedgeSdErrno = WasmedgeSdErrno(1); 8 | pub const WASMEDGE_SD_ERRNO_INVALID_ENCODING: WasmedgeSdErrno = WasmedgeSdErrno(2); 9 | pub const WASMEDGE_SD_ERRNO_MISSING_MEMORY: WasmedgeSdErrno = WasmedgeSdErrno(3); 10 | pub const WASMEDGE_SD_ERRNO_BUSY: WasmedgeSdErrno = WasmedgeSdErrno(4); 11 | pub const WASMEDGE_SD_ERRNO_RUNTIME_ERROR: WasmedgeSdErrno = WasmedgeSdErrno(5); 12 | impl WasmedgeSdErrno { 13 | pub const fn raw(&self) -> u32 { 14 | self.0 15 | } 16 | 17 | pub fn name(&self) -> &'static str { 18 | match self.0 { 19 | 0 => "SUCCESS", 20 | 1 => "INVALID_ARGUMENT", 21 | 2 => "INVALID_ENCODING", 22 | 3 => "MISSING_MEMORY", 23 | 4 => "BUSY", 24 | 5 => "RUNTIME_ERROR", 25 | _ => unsafe { core::hint::unreachable_unchecked() }, 26 | } 27 | } 28 | pub fn message(&self) -> &'static str { 29 | match self.0 { 30 | 0 => "", 31 | 1 => "", 32 | 2 => "", 33 | 3 => "", 34 | 4 => "", 35 | 5 => "", 36 | _ => unsafe { core::hint::unreachable_unchecked() }, 37 | } 38 | } 39 | } 40 | impl fmt::Debug for WasmedgeSdErrno { 41 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 42 | f.debug_struct("WasmedgeSdErrno") 43 | .field("code", &self.0) 44 | .field("name", &self.name()) 45 | .field("message", &self.message()) 46 | .finish() 47 | } 48 | } 49 | impl fmt::Display for WasmedgeSdErrno { 50 | fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 51 | write!(f, "{} (error {})", self.name(), self.0) 52 | } 53 | } 54 | 55 | #[derive(Debug, Copy, Clone)] 56 | pub enum SdTypeT { 57 | SdTypeF32 = 0, 58 | SdTypeF16 = 1, 59 | SdTypeQ4_0 = 2, 60 | SdTypeQ4_1 = 3, 61 | // SdTypeQ4_2 = 4, // support has been removed 62 | // SdTypeQ4_3 = 5, // support has been removed 63 | SdTypeQ5_0 = 6, 64 | SdTypeQ5_1 = 7, 65 | SdTypeQ8_0 = 8, 66 | SdTypeQ8_1 = 9, 67 | SdTypeQ2K = 10, 68 | SdTypeQ3K = 11, 69 | SdTypeQ4K = 12, 70 | SdTypeQ5K = 13, 71 | SdTypeQ6K = 14, 72 | SdTypeQ8K = 15, 73 | SdTypeIq2Xxs = 16, 74 | SdTypeIq2Xs = 17, 75 | SdTypeIq3Xxs = 18, 76 | SdTypeIq1S = 19, 77 | 
SdTypeIq4Nl = 20, 78 | SdTypeIq3S = 21, 79 | SdTypeIq2S = 22, 80 | SdTypeIq4Xs = 23, 81 | SdTypeI8 = 24, 82 | SdTypeI16 = 25, 83 | SdTypeI32 = 26, 84 | SdTypeI64 = 27, 85 | SdTypeF64 = 28, 86 | SdTypeIq1M = 29, 87 | SdTypeBf16 = 30, 88 | SdTypeQ4044 = 31, 89 | SdTypeQ4048 = 32, 90 | SdTypeQ4088 = 33, 91 | SdTypeCount = 34, 92 | } 93 | #[derive(Debug, Copy, Clone)] 94 | pub enum RngTypeT { 95 | StdDefaultRng = 0, 96 | CUDARng = 1, 97 | } 98 | #[derive(Debug, Copy, Clone)] 99 | pub enum SampleMethodT { 100 | EULERA = 0, 101 | EULER = 1, 102 | HEUN = 2, 103 | DPM2 = 3, 104 | DPMPP2SA = 4, 105 | DPMPP2M = 5, 106 | DPMPP2Mv2 = 6, 107 | IPNDM = 7, 108 | IPNDMV = 8, 109 | LCM = 9, 110 | } 111 | #[derive(Debug, Copy, Clone)] 112 | pub enum ScheduleT { 113 | DEFAULT = 0, 114 | DISCRETE = 1, 115 | KARRAS = 2, 116 | EXPONENTIAL = 3, 117 | AYS = 4, 118 | GITS = 5, 119 | } 120 | #[derive(Debug)] 121 | pub enum ImageType { 122 | Path(String), 123 | } 124 | fn parse_image(image: &ImageType) -> (i32, i32) { 125 | match image { 126 | ImageType::Path(path) => { 127 | if path.is_empty() { 128 | return (0, 0); 129 | } 130 | let path = "path:".to_string() + path; 131 | (path.as_ptr() as i32, path.len() as i32) 132 | } 133 | } 134 | } 135 | 136 | impl SdTypeT { 137 | pub fn from_index(index: usize) -> Result { 138 | match index { 139 | 0 => Ok(SdTypeT::SdTypeF32), 140 | 1 => Ok(SdTypeT::SdTypeF16), 141 | 2 => Ok(SdTypeT::SdTypeQ4_0), 142 | 3 => Ok(SdTypeT::SdTypeQ4_1), 143 | // 4 => Ok(SdTypeT::SdTypeQ4_2),// support has been removed 144 | // 5 => Ok(SdTypeT::SdTypeQ4_3),// support has been removed 145 | 6 => Ok(SdTypeT::SdTypeQ5_0), 146 | 7 => Ok(SdTypeT::SdTypeQ5_1), 147 | 8 => Ok(SdTypeT::SdTypeQ8_0), 148 | 9 => Ok(SdTypeT::SdTypeQ8_1), 149 | 10 => Ok(SdTypeT::SdTypeQ2K), 150 | 11 => Ok(SdTypeT::SdTypeQ3K), 151 | 12 => Ok(SdTypeT::SdTypeQ4K), 152 | 13 => Ok(SdTypeT::SdTypeQ5K), 153 | 14 => Ok(SdTypeT::SdTypeQ6K), 154 | 15 => Ok(SdTypeT::SdTypeQ8K), 155 | 16 => 
Ok(SdTypeT::SdTypeIq2Xxs), 156 | 17 => Ok(SdTypeT::SdTypeIq2Xs), 157 | 18 => Ok(SdTypeT::SdTypeIq3Xxs), 158 | 19 => Ok(SdTypeT::SdTypeIq1S), 159 | 20 => Ok(SdTypeT::SdTypeIq4Nl), 160 | 21 => Ok(SdTypeT::SdTypeIq3S), 161 | 22 => Ok(SdTypeT::SdTypeIq2S), 162 | 23 => Ok(SdTypeT::SdTypeIq4Xs), 163 | 24 => Ok(SdTypeT::SdTypeI8), 164 | 25 => Ok(SdTypeT::SdTypeI16), 165 | 26 => Ok(SdTypeT::SdTypeI32), 166 | 27 => Ok(SdTypeT::SdTypeI64), 167 | 28 => Ok(SdTypeT::SdTypeF64), 168 | 29 => Ok(SdTypeT::SdTypeIq1M), 169 | 30 => Ok(SdTypeT::SdTypeBf16), 170 | 31 => Ok(SdTypeT::SdTypeQ4044), 171 | 32 => Ok(SdTypeT::SdTypeQ4048), 172 | 33 => Ok(SdTypeT::SdTypeQ4088), 173 | _ => Ok(SdTypeT::SdTypeCount), 174 | } 175 | } 176 | } 177 | impl SampleMethodT { 178 | pub fn from_index(index: usize) -> Result { 179 | match index { 180 | 0 => Ok(SampleMethodT::EULERA), 181 | 1 => Ok(SampleMethodT::EULER), 182 | 2 => Ok(SampleMethodT::HEUN), 183 | 3 => Ok(SampleMethodT::DPM2), 184 | 4 => Ok(SampleMethodT::DPMPP2SA), 185 | 5 => Ok(SampleMethodT::DPMPP2M), 186 | 6 => Ok(SampleMethodT::DPMPP2Mv2), 187 | 7 => Ok(SampleMethodT::IPNDM), 188 | 8 => Ok(SampleMethodT::IPNDMV), 189 | _ => Ok(SampleMethodT::LCM), 190 | } 191 | } 192 | } 193 | impl ScheduleT { 194 | pub fn from_index(index: usize) -> Result { 195 | match index { 196 | 0 => Ok(ScheduleT::DEFAULT), 197 | 1 => Ok(ScheduleT::DISCRETE), 198 | 2 => Ok(ScheduleT::KARRAS), 199 | 3 => Ok(ScheduleT::EXPONENTIAL), 200 | 4 => Ok(ScheduleT::AYS), 201 | _ => Ok(ScheduleT::GITS), 202 | } 203 | } 204 | } 205 | 206 | /// # Safety 207 | /// This will call an external wasm function 208 | pub unsafe fn convert( 209 | model_path: &str, 210 | vae_model_path: &str, 211 | output_path: &str, 212 | wtype: SdTypeT, 213 | ) -> Result<(), WasmedgeSdErrno> { 214 | let model_path_ptr = model_path.as_ptr() as i32; 215 | let model_path_len = model_path.len() as i32; 216 | let vae_model_path_ptr = vae_model_path.as_ptr() as i32; 217 | let vae_model_path_len = 
vae_model_path.len() as i32;
    let output_path_ptr = output_path.as_ptr() as i32;
    let output_path_len = output_path.len() as i32;
    let result = wasmedge_stablediffusion::convert(
        model_path_ptr,
        model_path_len,
        vae_model_path_ptr,
        vae_model_path_len,
        output_path_ptr,
        output_path_len,
        wtype as i32,
    );
    if result != 0 {
        Err(WasmedgeSdErrno(result as u32))
    } else {
        Ok(())
    }
}

/// Creates a stable-diffusion context on the host and stores its session id through `session_id`.
///
/// Every string is passed to the host as an `(i32 pointer, i32 length)` pair (an empty string
/// is passed with length 0), every flag as `0`/`1`, and every enum as its `i32` discriminant.
///
/// # Errors
/// Returns `WasmedgeSdErrno(code)` when the host call returns a non-zero `code`.
///
/// # Safety
/// This calls an external wasm host function. `session_id` must be valid for a write of one
/// `u32`. Pointers are cast to `i32`, which is lossless only on a 32-bit target such as wasm32.
#[allow(clippy::too_many_arguments)]
pub unsafe fn create_context(
    model_path: &str,
    clip_l_path: &str,
    t5xxl_path: &str,
    diffusion_model_path: &str,
    vae_path: &str,
    taesd_path: &str,
    control_net_path: &str,
    lora_model_dir: &str,
    embed_dir: &str,
    id_embed_dir: &str,
    vae_decode_only: bool,
    vae_tiling: bool,
    n_threads: i32,
    wtype: SdTypeT,
    rng_type: RngTypeT,
    schedule: ScheduleT,
    clip_on_cpu: bool,
    control_net_cpu: bool,
    vae_on_cpu: bool,
    session_id: *mut u32,
) -> Result<(), WasmedgeSdErrno> {
    let model_path_ptr = model_path.as_ptr() as i32;
    let model_path_len = model_path.len() as i32;
    let clip_l_path_ptr = clip_l_path.as_ptr() as i32;
    let clip_l_path_len = clip_l_path.len() as i32;
    let t5xxl_path_ptr = t5xxl_path.as_ptr() as i32;
    let t5xxl_path_len = t5xxl_path.len() as i32;
    let diffusion_model_path_ptr = diffusion_model_path.as_ptr() as i32;
    let diffusion_model_path_len = diffusion_model_path.len() as i32;
    let vae_path_ptr = vae_path.as_ptr() as i32;
    let vae_path_len = vae_path.len() as i32;
    let taesd_path_ptr = taesd_path.as_ptr() as i32;
    let taesd_path_len = taesd_path.len() as i32;
    let control_net_path_ptr = control_net_path.as_ptr() as i32;
    let control_net_path_len = control_net_path.len() as i32;
    let lora_model_dir_ptr = lora_model_dir.as_ptr() as i32;
    let lora_model_dir_len = lora_model_dir.len() as i32;
    let embed_dir_ptr = embed_dir.as_ptr() as i32;
    let embed_dir_len = embed_dir.len() as i32;
    let id_embed_dir_ptr = id_embed_dir.as_ptr() as i32;
    let id_embed_dir_len = id_embed_dir.len() as i32;
    let result = wasmedge_stablediffusion::create_context(
        model_path_ptr,
        model_path_len,
        clip_l_path_ptr,
        clip_l_path_len,
        t5xxl_path_ptr,
        t5xxl_path_len,
        diffusion_model_path_ptr,
        diffusion_model_path_len,
        vae_path_ptr,
        vae_path_len,
        taesd_path_ptr,
        taesd_path_len,
        control_net_path_ptr,
        control_net_path_len,
        lora_model_dir_ptr,
        lora_model_dir_len,
        embed_dir_ptr,
        embed_dir_len,
        id_embed_dir_ptr,
        id_embed_dir_len,
        vae_decode_only as i32,
        vae_tiling as i32,
        n_threads,
        wtype as i32,
        rng_type as i32,
        schedule as i32,
        clip_on_cpu as i32,
        control_net_cpu as i32,
        vae_on_cpu as i32,
        session_id as i32,
    );
    if result != 0 {
        Err(WasmedgeSdErrno(result as u32))
    } else {
        Ok(())
    }
}

/// Runs text-to-image generation for host session `session_id`.
///
/// Returns the count the host writes through its `bytes_written_ptr` out-parameter.
///
/// # Errors
/// Returns `WasmedgeSdErrno(code)` when the host call returns a non-zero `code`.
///
/// # Safety
/// This calls an external wasm host function. `output_buf` must point to a writable buffer of
/// at least `out_buffer_max_size` bytes (presumably the host writes the generated output
/// there; confirm against the host plugin). Pointers are cast to `i32`, which is lossless only
/// on a 32-bit target such as wasm32.
#[allow(clippy::too_many_arguments)]
pub unsafe fn text_to_image(
    prompt: &str,
    session_id: u32,
    control_image: &ImageType,
    negative_prompt: &str,
    guidance: f32,
    width: i32,
    height: i32,
    clip_skip: i32,
    cfg_scale: f32,
    sample_method: SampleMethodT,
    sample_steps: i32,
    seed: i32,
    batch_count: i32,
    control_strength: f32,
    style_ratio: f32,
    normalize_input: bool,
    input_id_images_dir: &str,
    canny_preprocess: bool,
    upscale_model: &str,
    upscale_repeats: i32,
    output_path: &str,
    output_buf: *mut u8,
    out_buffer_max_size: i32,
) -> Result<u32, WasmedgeSdErrno> {
    let prompt_ptr = prompt.as_ptr() as i32;
    let prompt_len = prompt.len() as i32;
    let (control_image_ptr, control_image_len) = parse_image(control_image);
    let negative_prompt_ptr = negative_prompt.as_ptr() as i32;
    let negative_prompt_len = negative_prompt.len() as i32;
    let input_id_images_dir_ptr = input_id_images_dir.as_ptr() as i32;
    let input_id_images_dir_len = input_id_images_dir.len() as i32;
    let upscale_model_path_ptr = upscale_model.as_ptr() as i32;
    let upscale_model_path_len = upscale_model.len() as i32;
    let output_path_ptr = output_path.as_ptr() as i32;
    let output_path_len = output_path.len() as i32;
    // Out-parameter for the host; only read after the host reports success.
    let mut write_bytes = MaybeUninit::<u32>::uninit();
    let result = wasmedge_stablediffusion::text_to_image(
        prompt_ptr,
        prompt_len,
        session_id as i32,
        control_image_ptr,
        control_image_len,
        negative_prompt_ptr,
        negative_prompt_len,
        guidance,
        width,
        height,
        clip_skip,
        cfg_scale,
        sample_method as i32,
        sample_steps,
        seed,
        batch_count,
        control_strength,
        style_ratio,
        normalize_input as i32,
        input_id_images_dir_ptr,
        input_id_images_dir_len,
        canny_preprocess as i32,
        upscale_model_path_ptr,
        upscale_model_path_len,
        upscale_repeats,
        output_path_ptr,
        output_path_len,
        output_buf as i32,
        out_buffer_max_size,
        write_bytes.as_mut_ptr() as i32,
    );
    if result != 0 {
        Err(WasmedgeSdErrno(result as u32))
    } else {
        Ok(write_bytes.assume_init())
    }
}

/// Runs image-to-image generation starting from `image` for host session `session_id`.
///
/// Returns the count the host writes through its `bytes_written_ptr` out-parameter.
///
/// # Errors
/// Returns `WasmedgeSdErrno(code)` when the host call returns a non-zero `code`.
///
/// # Safety
/// This calls an external wasm host function. `output_buf` must point to a writable buffer of
/// at least `out_buffer_max_size` bytes (presumably the host writes the generated output
/// there; confirm against the host plugin). Pointers are cast to `i32`, which is lossless only
/// on a 32-bit target such as wasm32.
#[allow(clippy::too_many_arguments)]
pub unsafe fn image_to_image(
    image: &ImageType,
    session_id: u32,
    guidance: f32,
    width: i32,
    height: i32,
    control_image: &ImageType,
    prompt: &str,
    negative_prompt: &str,
    clip_skip: i32,
    cfg_scale: f32,
    sample_method: SampleMethodT,
    sample_steps: i32,
    strength: f32,
    seed: i32,
    batch_count: i32,
    control_strength: f32,
    style_ratio: f32,
    normalize_input: bool,
    input_id_images_dir: &str,
    canny_preprocess: bool,
    upscale_model_path: &str,
    upscale_repeats: i32,
    output_path: &str,
    output_buf: *mut u8,
    out_buffer_max_size: i32,
) -> Result<u32, WasmedgeSdErrno> {
    let (image_ptr, image_len) = parse_image(image);
    let (control_image_ptr, control_image_len) = parse_image(control_image);
    let prompt_ptr = prompt.as_ptr() as i32;
    let prompt_len = prompt.len() as i32;
    let negative_prompt_ptr = negative_prompt.as_ptr() as i32;
    let negative_prompt_len = negative_prompt.len() as i32;
    let input_id_images_dir_ptr = input_id_images_dir.as_ptr() as i32;
    let input_id_images_dir_len = input_id_images_dir.len() as i32;
    let upscale_model_path_ptr = upscale_model_path.as_ptr() as i32;
    let upscale_model_path_len = upscale_model_path.len() as i32;
    let output_path_ptr = output_path.as_ptr() as i32;
    let output_path_len = output_path.len() as i32;
    // Out-parameter for the host; only read after the host reports success.
    let mut write_bytes = MaybeUninit::<u32>::uninit();
    let result = wasmedge_stablediffusion::image_to_image(
        image_ptr,
        image_len,
        session_id as i32,
        guidance,
        width,
        height,
        control_image_ptr,
        control_image_len,
        prompt_ptr,
        prompt_len,
        negative_prompt_ptr,
        negative_prompt_len,
        clip_skip,
        cfg_scale,
        sample_method as i32,
        sample_steps,
        strength,
        seed,
        batch_count,
        control_strength,
        style_ratio,
        normalize_input as i32,
        input_id_images_dir_ptr,
        input_id_images_dir_len,
        canny_preprocess as i32,
        upscale_model_path_ptr,
        upscale_model_path_len,
        upscale_repeats,
        output_path_ptr,
        output_path_len,
        output_buf as i32,
        out_buffer_max_size,
        write_bytes.as_mut_ptr() as i32,
    );
    if result != 0 {
        Err(WasmedgeSdErrno(result as u32))
    } else {
        Ok(write_bytes.assume_init())
    }
}

/// Raw imports from the `wasmedge_stablediffusion` host module.
///
/// All pointers and lengths are `i32`. The wrapper functions in this file treat a return value
/// of 0 as success and any other value as an errno.
pub mod wasmedge_stablediffusion {
    #[link(wasm_import_module = "wasmedge_stablediffusion")]
    extern "C" {
        pub fn create_context(
            model_path_ptr: i32,
            model_path_len: i32,
            clip_l_path_ptr: i32,
            clip_l_path_len: i32,
            t5xxl_path_ptr: i32,
            t5xxl_path_len: i32,
            diffusion_model_path_ptr: i32,
            diffusion_model_path_len: i32,
            vae_path_ptr: i32,
            vae_path_len: i32,
            taesd_path_ptr: i32,
            taesd_path_len: i32,
            control_net_path_ptr: i32,
            control_net_path_len: i32,
            lora_model_dir_ptr: i32,
            lora_model_dir_len: i32,
            embed_dir_ptr: i32,
            embed_dir_len: i32,
            id_embed_dir_ptr: i32,
            id_embed_dir_len: i32,
            vae_decode_only: i32,
            vae_tiling: i32,
            n_threads: i32,
            wtype: i32,
            rng_type: i32,
            schedule: i32,
            clip_on_cpu: i32,
            control_net_cpu: i32,
            vae_on_cpu: i32,
            session_id_ptr: i32,
        ) -> i32;

        pub fn image_to_image(
            image_ptr: i32,
            image_len: i32,
            session_id: i32,
            guidance: f32,
            width: i32,
            height: i32,
            control_image_ptr: i32,
            control_image_len: i32,
            prompt_ptr: i32,
            prompt_len: i32,
            negative_prompt_ptr: i32,
            negative_prompt_len: i32,
            clip_skip: i32,
            cfg_scale: f32,
            sample_method: i32,
            sample_steps: i32,
            strength: f32,
            seed: i32,
            batch_count: i32,
            control_strength: f32,
            style_ratio: f32,
            normalize_input: i32,
            input_id_images_dir_ptr: i32,
            input_id_images_dir_len: i32,
            canny_preprocess: i32,
            upscale_model_path_ptr: i32,
            upscale_model_path_len: i32,
            upscale_repeats: i32,
            output_path_ptr: i32,
            output_path_len: i32,
            out_buffer_ptr: i32,
            out_buffer_max_size: i32,
            bytes_written_ptr: i32,
        ) -> i32;

        pub fn text_to_image(
            prompt_ptr: i32,
            prompt_len: i32,
            session_id: i32,
            control_image_ptr: i32,
            control_image_len: i32,
            negative_prompt_ptr: i32,
            negative_prompt_len: i32,
            guidance: f32,
            width: i32,
            height: i32,
            clip_skip: i32,
            cfg_scale: f32,
            sample_method: i32,
            sample_steps: i32,
            seed: i32,
            batch_count: i32,
            control_strength: f32,
            style_ratio: f32,
            normalize_input: i32,
            input_id_images_dir_ptr: i32,
            input_id_images_dir_len: i32,
            canny_preprocess: i32,
            upscale_model_path_ptr: i32,
            upscale_model_path_len: i32,
            upscale_repeats: i32,
            output_path_ptr: i32,
            output_path_len: i32,
            out_buffer_ptr: i32,
            out_buffer_max_size: i32,
            bytes_written_ptr: i32,
        ) -> i32;

        pub fn convert(
            model_path_ptr: i32,
            model_path_len: i32,
            vae_model_path_ptr: i32,
            vae_model_path_len: i32,
            output_path_ptr: i32,
            output_path_len: i32,
            wtype: i32,
        ) -> i32;
    }
}

--------------------------------------------------------------------------------
/rust/src/lib.rs:
--------------------------------------------------------------------------------
pub mod error;
pub mod stable_diffusion_interface;

use core::mem::MaybeUninit;
use error::SDError;
use stable_diffusion_interface::*;
use std::path::Path;

/// Size in bytes of the scratch buffer passed to the host as the generation output buffer.
const BUF_LEN: i32 = 1_000_000;

/// Result type for operations that fail with an [`SDError`].
pub type SDResult<T> = Result<T, SDError>;

/// Represents Quantization task.
pub struct Quantization {
    pub model_path: String,
    pub vae_model_path: String,
    pub output_path: String,
    pub wtype: SdTypeT,
}
impl Quantization {
    /// Stores the arguments for a host `convert` call; nothing runs until [`Quantization::convert`].
    pub fn new(
        model_path: &str,
        vae_model_path: String,
        output_path: &str,
        wtype: SdTypeT,
    ) -> Quantization {
        Quantization {
            model_path: model_path.to_string(),
            vae_model_path,
            output_path: output_path.to_string(),
            wtype,
        }
    }

    /// Runs the host `convert` call with the stored paths and weight type.
    ///
    /// # Errors
    /// Propagates the host errno from `stable_diffusion_interface::convert`.
    pub fn convert(&self) -> Result<(), WasmedgeSdErrno> {
        // SAFETY: the three `&str` borrows of `self` outlive the synchronous host call.
        unsafe {
            stable_diffusion_interface::convert(
                &self.model_path,
                &self.vae_model_path,
                &self.output_path,
                self.wtype,
            )
        }
    }
}

/// The kind of work a [`StableDiffusion`] instance is set up for.
#[derive(Debug)]
pub enum Task {
    TextToImage,
    ImageToImage,
    Convert,
}
// Parse command line arguments, for --mode
impl std::str::FromStr for Task {
    type Err = String;

    /// Accepts `txt2img`, `img2img` and `convert`; anything else yields `"Invalid mode: {s}"`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "txt2img" => Ok(Task::TextToImage),
            "img2img" => Ok(Task::ImageToImage),
            "convert" => Ok(Task::Convert),
            _ => Err(format!("Invalid mode: {}", s)),
        }
    }
}

/// A ready-to-run generation context, one variant per generating [`Task`].
#[derive(Debug)]
pub enum Context {
    TextToImage(TextToImage),
    ImageToImage(ImageToImage),
}

/// Generation parameters shared by every context; forwarded to the host by `generate`.
#[derive(Debug)]
pub struct BaseContext {
    /// Host session id obtained from `create_context`.
    pub session_id: u32,
    pub prompt: String,
    pub guidance: f32,
    pub width: i32,
    pub height: i32,
    pub control_image: ImageType,
    pub negative_prompt: String,
    pub clip_skip: i32,
    pub cfg_scale: f32,
    pub sample_method: SampleMethodT,
    pub sample_steps: i32,
    pub seed: i32,
    pub batch_count: i32,
    pub control_strength: f32,
    pub style_ratio: f32,
    pub normalize_input: bool,
    pub input_id_images_dir: String,
    pub canny_preprocess: bool,
    pub upscale_model: String,
    pub upscale_repeats: i32,
    pub output_path: String,
}

/// Chainable setters over a [`BaseContext`] plus the task-specific `generate` entry point.
///
/// Every setter overwrites one field of the context returned by [`BaseFunction::base`] and
/// returns `self` so calls can be chained.
pub trait BaseFunction {
    /// Mutable access to the shared generation parameters.
    fn base(&mut self) -> &mut BaseContext;

    /// Sets the prompt.
    fn set_prompt(&mut self, prompt: String) -> &mut Self {
        self.base().prompt = prompt;
        self
    }
    /// Sets `guidance`.
    fn set_guidance(&mut self, guidance: f32) -> &mut Self {
        self.base().guidance = guidance;
        self
    }
    /// Sets the output width.
    fn set_width(&mut self, width: i32) -> &mut Self {
        self.base().width = width;
        self
    }
    /// Sets the output height.
    fn set_height(&mut self, height: i32) -> &mut Self {
        self.base().height = height;
        self
    }
    /// Sets the control-net condition image.
    fn set_control_image(&mut self, control_image: ImageType) -> &mut Self {
        self.base().control_image = control_image;
        self
    }
    /// Sets the negative prompt from anything convertible into a `String`.
    fn set_negative_prompt(&mut self, negative_prompt: impl Into<String>) -> &mut Self {
        self.base().negative_prompt = negative_prompt.into();
        self
    }
    /// Sets `clip_skip`.
    fn set_clip_skip(&mut self, clip_skip: i32) -> &mut Self {
        self.base().clip_skip = clip_skip;
        self
    }
    /// Sets `cfg_scale`.
    fn set_cfg_scale(&mut self, cfg_scale: f32) -> &mut Self {
        self.base().cfg_scale = cfg_scale;
        self
    }
    /// Sets the sampling method.
    fn set_sample_method(&mut self, sample_method: SampleMethodT) -> &mut Self {
        self.base().sample_method = sample_method;
        self
    }
    /// Sets the number of sampling steps.
    fn set_sample_steps(&mut self, sample_steps: i32) -> &mut Self {
        self.base().sample_steps = sample_steps;
        self
    }
    /// Sets the seed.
    fn set_seed(&mut self, seed: i32) -> &mut Self {
        self.base().seed = seed;
        self
    }
    /// Sets `batch_count`.
    fn set_batch_count(&mut self, batch_count: i32) -> &mut Self {
        self.base().batch_count = batch_count;
        self
    }
    /// Sets `control_strength`.
    fn set_control_strength(&mut self, control_strength: f32) -> &mut Self {
        self.base().control_strength = control_strength;
        self
    }
    /// Sets `style_ratio`.
    fn set_style_ratio(&mut self, style_ratio: f32) -> &mut Self {
        self.base().style_ratio = style_ratio;
        self
    }
    /// Sets the `normalize_input` flag.
    fn enable_normalize_input(&mut self, normalize_input: bool) -> &mut Self {
        self.base().normalize_input = normalize_input;
        self
    }
    /// Sets the input id images directory.
    fn set_input_id_images_dir(&mut self, input_id_images_dir: String) -> &mut Self {
        self.base().input_id_images_dir = input_id_images_dir;
        self
    }
    /// Sets the `canny_preprocess` flag.
    fn enable_canny_preprocess(&mut self, canny_preprocess: bool) -> &mut Self {
        self.base().canny_preprocess = canny_preprocess;
        self
    }
    /// Sets the upscale model path.
    fn set_upscale_model(&mut self, upscale_model: String) -> &mut Self {
        self.base().upscale_model = upscale_model;
        self
    }
    /// Sets `upscale_repeats`.
    fn set_upscale_repeats(&mut self, upscale_repeats: i32) -> &mut Self {
        self.base().upscale_repeats = upscale_repeats;
        self
    }
    /// Sets the output path.
    fn set_output_path(&mut self, output_path: String) -> &mut Self {
        self.base().output_path = output_path;
        self
    }
    /// Runs the task on the host with the current parameters.
    fn generate(&self) -> Result<(), WasmedgeSdErrno>;
}

/// Represents computation context for text-to-image task
#[derive(Debug)]
pub struct TextToImage {
    pub common: BaseContext,
}
impl BaseFunction for TextToImage {
    fn base(&mut self) -> &mut BaseContext {
        &mut self.common
    }

    /// Runs text-to-image generation on the host for this session.
    ///
    /// # Errors
    /// Returns `WASMEDGE_SD_ERRNO_INVALID_ARGUMENT` for an empty prompt; otherwise propagates
    /// the host errno. The byte count reported by the host is discarded.
    fn generate(&self) -> Result<(), WasmedgeSdErrno> {
        if self.common.prompt.is_empty() {
            return Err(WASMEDGE_SD_ERRNO_INVALID_ARGUMENT);
        }
        // Scratch output buffer required by the host ABI; its contents are not read back here.
        let mut data: Vec<u8> = vec![0; BUF_LEN as usize];
        // SAFETY: all borrows of `self.common` outlive the call, and `data` is a live, writable
        // allocation of exactly BUF_LEN bytes, matching the size passed alongside it.
        let result = unsafe {
            stable_diffusion_interface::text_to_image(
                &self.common.prompt,
                self.common.session_id,
                &self.common.control_image,
                &self.common.negative_prompt,
                self.common.guidance,
                self.common.width,
                self.common.height,
                self.common.clip_skip,
                self.common.cfg_scale,
                self.common.sample_method,
                self.common.sample_steps,
                self.common.seed,
                self.common.batch_count,
                self.common.control_strength,
                self.common.style_ratio,
                self.common.normalize_input,
                &self.common.input_id_images_dir,
                self.common.canny_preprocess,
                &self.common.upscale_model,
                self.common.upscale_repeats,
                &self.common.output_path,
                data.as_mut_ptr(),
                BUF_LEN,
            )
        };
        result.map(|_| ())
    }
}

/// Represents computation context for image-to-image task.
267 | #[derive(Debug)] 268 | pub struct ImageToImage { 269 | pub common: BaseContext, 270 | pub image: ImageType, 271 | pub strength: f32, 272 | } 273 | impl BaseFunction for ImageToImage { 274 | fn base(&mut self) -> &mut BaseContext { 275 | &mut self.common 276 | } 277 | fn generate(&self) -> Result<(), WasmedgeSdErrno> { 278 | if self.common.prompt.is_empty() { 279 | return Err(WASMEDGE_SD_ERRNO_INVALID_ARGUMENT); 280 | } 281 | match &self.image { 282 | ImageType::Path(path) => { 283 | if path.is_empty() { 284 | return Err(WASMEDGE_SD_ERRNO_INVALID_ARGUMENT); 285 | } 286 | } 287 | } 288 | let mut data: Vec = vec![0; BUF_LEN as usize]; 289 | let result = unsafe { 290 | stable_diffusion_interface::image_to_image( 291 | &self.image, 292 | self.common.session_id, 293 | self.common.guidance, 294 | self.common.width, 295 | self.common.height, 296 | &self.common.control_image, 297 | &self.common.prompt, 298 | &self.common.negative_prompt, 299 | self.common.clip_skip, 300 | self.common.cfg_scale, 301 | self.common.sample_method, 302 | self.common.sample_steps, 303 | self.strength, 304 | self.common.seed, 305 | self.common.batch_count, 306 | self.common.control_strength, 307 | self.common.style_ratio, 308 | self.common.normalize_input, 309 | &self.common.input_id_images_dir, 310 | self.common.canny_preprocess, 311 | &self.common.upscale_model, 312 | self.common.upscale_repeats, 313 | &self.common.output_path, 314 | data.as_mut_ptr(), 315 | BUF_LEN, 316 | ) 317 | }; 318 | result?; 319 | Ok(()) 320 | } 321 | } 322 | impl ImageToImage { 323 | pub fn set_image(&mut self, image: ImageType) -> &mut Self { 324 | { 325 | self.image = image; 326 | } 327 | self 328 | } 329 | pub fn set_strength(&mut self, strength: f32) -> &mut Self { 330 | { 331 | self.strength = strength; 332 | } 333 | self 334 | } 335 | } 336 | 337 | /// Builder for creating a StableDiffusion instance. 
338 | #[derive(Debug)] 339 | pub struct SDBuidler { 340 | sd: StableDiffusion, 341 | } 342 | impl SDBuidler { 343 | pub fn new(task: Task, model_path: impl AsRef) -> SDResult { 344 | let path = model_path 345 | .as_ref() 346 | .to_str() 347 | .ok_or_else(|| SDError::Operation("The model path is not valid unicode.".into()))?; 348 | let sd = StableDiffusion::new(task, path); 349 | Ok(Self { sd }) 350 | } 351 | 352 | /// Create a new builder with a full model path. 353 | pub fn new_with_full_model(task: Task, model_path: impl AsRef) -> SDResult { 354 | let path = model_path 355 | .as_ref() 356 | .to_str() 357 | .ok_or_else(|| SDError::Operation("The model path is not valid unicode.".into()))?; 358 | let sd = StableDiffusion::new(task, path); 359 | Ok(Self { sd }) 360 | } 361 | 362 | /// Create a new builder with a standalone diffusion model. 363 | pub fn new_with_standalone_model( 364 | task: Task, 365 | diffusion_model_path: impl AsRef, 366 | ) -> SDResult { 367 | let path = diffusion_model_path 368 | .as_ref() 369 | .to_str() 370 | .ok_or_else(|| SDError::Operation("The model path is not valid unicode.".into()))?; 371 | let sd = StableDiffusion::new_with_standalone_model(task, path); 372 | Ok(Self { sd }) 373 | } 374 | 375 | pub fn with_vae_path(mut self, path: impl AsRef) -> SDResult { 376 | let path = path.as_ref().to_str().ok_or_else(|| { 377 | SDError::InvalidPath("The path to the vae file is not valid unicode.".into()) 378 | })?; 379 | self.sd.vae_path = path.into(); 380 | Ok(self) 381 | } 382 | 383 | pub fn with_clip_l_path(mut self, path: impl AsRef) -> SDResult { 384 | let path = path.as_ref().to_str().ok_or_else(|| { 385 | SDError::InvalidPath("The path to the clip_l file is not valid unicode.".into()) 386 | })?; 387 | self.sd.clip_l_path = path.into(); 388 | Ok(self) 389 | } 390 | 391 | pub fn with_t5xxl_path(mut self, path: impl AsRef) -> SDResult { 392 | let path = path.as_ref().to_str().ok_or_else(|| { 393 | SDError::InvalidPath("The path to the t5xxl 
file is not valid unicode.".into()) 394 | })?; 395 | self.sd.t5xxl_path = path.into(); 396 | Ok(self) 397 | } 398 | 399 | pub fn with_taesd_path(mut self, path: impl AsRef) -> SDResult { 400 | let path = path.as_ref().to_str().ok_or_else(|| { 401 | SDError::InvalidPath("The path to the taesd file is not valid unicode.".into()) 402 | })?; 403 | self.sd.taesd_path = path.into(); 404 | Ok(self) 405 | } 406 | 407 | pub fn with_lora_model_dir(mut self, path: impl AsRef) -> SDResult { 408 | let path = path.as_ref().to_str().ok_or_else(|| { 409 | SDError::InvalidPath( 410 | "The path to the lora model directory is not valid unicode.".into(), 411 | ) 412 | })?; 413 | self.sd.lora_model_dir = path.into(); 414 | Ok(self) 415 | } 416 | 417 | pub fn use_control_net(mut self, path: impl AsRef, on_cpu: bool) -> SDResult { 418 | let path = path.as_ref().to_str().ok_or_else(|| { 419 | SDError::InvalidPath("The path to the controlnet file is not valid unicode.".into()) 420 | })?; 421 | self.sd.control_net_path = path.into(); 422 | self.sd.control_net_cpu = on_cpu; 423 | Ok(self) 424 | } 425 | 426 | pub fn with_embeddings_path(mut self, path: impl AsRef) -> SDResult { 427 | let path = path.as_ref().to_str().ok_or_else(|| { 428 | SDError::InvalidPath("The path to the embeddings dir is not valid unicode.".into()) 429 | })?; 430 | self.sd.embed_dir = path.into(); 431 | Ok(self) 432 | } 433 | 434 | pub fn with_stacked_id_embeddings_path(mut self, path: impl AsRef) -> SDResult { 435 | let path = path.as_ref().to_str().ok_or_else(|| { 436 | SDError::InvalidPath( 437 | "The path to the stacked id embeddings dir is not valid unicode.".into(), 438 | ) 439 | })?; 440 | self.sd.id_embed_dir = path.into(); 441 | Ok(self) 442 | } 443 | 444 | pub fn with_n_threads(mut self, n_threads: i32) -> Self { 445 | self.sd.n_threads = n_threads; 446 | self 447 | } 448 | 449 | pub fn with_wtype(mut self, wtype: SdTypeT) -> Self { 450 | self.sd.wtype = wtype; 451 | self 452 | } 453 | 454 | pub fn 
with_rng_type(mut self, rng_type: RngTypeT) -> Self { 455 | self.sd.rng_type = rng_type; 456 | self 457 | } 458 | 459 | pub fn with_schedule(mut self, schedule: ScheduleT) -> Self { 460 | self.sd.schedule = schedule; 461 | self 462 | } 463 | 464 | pub fn enable_vae_tiling(mut self, enable: bool) -> Self { 465 | self.sd.vae_tiling = enable; 466 | self 467 | } 468 | 469 | pub fn enable_clip_on_cpu(mut self, enable: bool) -> Self { 470 | self.sd.clip_on_cpu = enable; 471 | self 472 | } 473 | 474 | pub fn enable_vae_on_cpu(mut self, enable: bool) -> Self { 475 | self.sd.vae_on_cpu = enable; 476 | self 477 | } 478 | 479 | pub fn build(self) -> StableDiffusion { 480 | self.sd 481 | } 482 | } 483 | 484 | /// Represents a stable diffusion model. 485 | #[derive(Debug)] 486 | pub struct StableDiffusion { 487 | task: Task, 488 | model_path: String, 489 | clip_l_path: String, 490 | t5xxl_path: String, 491 | diffusion_model_path: String, 492 | vae_path: String, 493 | taesd_path: String, 494 | control_net_path: String, 495 | lora_model_dir: String, 496 | embed_dir: String, 497 | id_embed_dir: String, 498 | vae_decode_only: bool, 499 | vae_tiling: bool, 500 | n_threads: i32, 501 | wtype: SdTypeT, 502 | rng_type: RngTypeT, 503 | schedule: ScheduleT, 504 | clip_on_cpu: bool, 505 | control_net_cpu: bool, 506 | vae_on_cpu: bool, 507 | } 508 | impl StableDiffusion { 509 | pub fn new(task: Task, model_path: &str) -> StableDiffusion { 510 | let vae_decode_only = match task { 511 | Task::TextToImage => true, 512 | Task::ImageToImage => false, 513 | Task::Convert => false, 514 | }; 515 | StableDiffusion { 516 | task, 517 | model_path: model_path.to_string(), 518 | clip_l_path: "".to_string(), 519 | t5xxl_path: "".to_string(), 520 | diffusion_model_path: "".to_string(), 521 | vae_path: "".to_string(), 522 | taesd_path: "".to_string(), 523 | control_net_path: "".to_string(), 524 | lora_model_dir: "".to_string(), 525 | embed_dir: "".to_string(), 526 | id_embed_dir: "".to_string(), 527 | 
vae_decode_only, 528 | vae_tiling: false, 529 | n_threads: -1, 530 | wtype: SdTypeT::SdTypeCount, 531 | rng_type: RngTypeT::StdDefaultRng, 532 | schedule: ScheduleT::DEFAULT, 533 | clip_on_cpu: false, 534 | control_net_cpu: false, 535 | vae_on_cpu: false, 536 | } 537 | } 538 | 539 | pub fn new_with_standalone_model(task: Task, diffusion_model_path: &str) -> StableDiffusion { 540 | let vae_decode_only = match task { 541 | Task::TextToImage => true, 542 | Task::ImageToImage => false, 543 | Task::Convert => false, 544 | }; 545 | StableDiffusion { 546 | task, 547 | model_path: "".to_string(), 548 | clip_l_path: "".to_string(), 549 | t5xxl_path: "".to_string(), 550 | diffusion_model_path: diffusion_model_path.to_string(), 551 | vae_path: "".to_string(), 552 | taesd_path: "".to_string(), 553 | control_net_path: "".to_string(), 554 | lora_model_dir: "".to_string(), 555 | embed_dir: "".to_string(), 556 | id_embed_dir: "".to_string(), 557 | vae_decode_only, 558 | vae_tiling: false, 559 | n_threads: -1, 560 | wtype: SdTypeT::SdTypeCount, 561 | rng_type: RngTypeT::StdDefaultRng, 562 | schedule: ScheduleT::DEFAULT, 563 | clip_on_cpu: false, 564 | control_net_cpu: false, 565 | vae_on_cpu: false, 566 | } 567 | } 568 | 569 | pub fn create_context(&self) -> Result { 570 | let mut session_id = MaybeUninit::::uninit(); 571 | unsafe { 572 | stable_diffusion_interface::create_context( 573 | &self.model_path, 574 | &self.clip_l_path, 575 | &self.t5xxl_path, 576 | &self.diffusion_model_path, 577 | &self.vae_path, 578 | &self.taesd_path, 579 | &self.control_net_path, 580 | &self.lora_model_dir, 581 | &self.embed_dir, 582 | &self.id_embed_dir, 583 | self.vae_decode_only, 584 | self.vae_tiling, 585 | self.n_threads, 586 | self.wtype, 587 | self.rng_type, 588 | self.schedule, 589 | self.clip_on_cpu, 590 | self.control_net_cpu, 591 | self.vae_on_cpu, 592 | session_id.as_mut_ptr(), 593 | )?; 594 | let common = BaseContext { 595 | prompt: "".to_string(), 596 | session_id: 
session_id.assume_init(), 597 | guidance: 3.5, 598 | width: 512, 599 | height: 512, 600 | control_image: ImageType::Path("".to_string()), 601 | negative_prompt: "".to_string(), 602 | clip_skip: -1, 603 | cfg_scale: 7.0, 604 | sample_method: SampleMethodT::EULERA, 605 | sample_steps: 20, 606 | seed: 42, 607 | batch_count: 1, 608 | control_strength: 0.9, 609 | style_ratio: 20.0, 610 | normalize_input: false, 611 | input_id_images_dir: "".to_string(), 612 | canny_preprocess: false, 613 | upscale_model: "".to_string(), 614 | upscale_repeats: 1, 615 | output_path: "".to_string(), 616 | }; 617 | match self.task { 618 | Task::TextToImage => Ok(Context::TextToImage(TextToImage { common })), 619 | Task::ImageToImage => Ok(Context::ImageToImage(ImageToImage { 620 | common, 621 | image: ImageType::Path("".to_string()), 622 | strength: 0.75, 623 | })), 624 | Task::Convert => todo!(), 625 | } 626 | } 627 | } 628 | } 629 | -------------------------------------------------------------------------------- /example/src/main.rs: -------------------------------------------------------------------------------- 1 | use wasmedge_stable_diffusion::stable_diffusion_interface::{ 2 | ImageType, RngTypeT, SampleMethodT, ScheduleT, SdTypeT, 3 | }; 4 | use wasmedge_stable_diffusion::{BaseFunction, Context, Quantization, SDBuidler, Task}; 5 | 6 | use clap::{crate_version, Arg, ArgAction, Command}; 7 | use rand::Rng; 8 | use std::str::FromStr; 9 | use std::time::{SystemTime, UNIX_EPOCH}; 10 | 11 | const WTYPE_METHODS: [&str; 35] = [ 12 | "f32", "f16", "q4_0", "q4_1", "", "", "q5_0", "q5_1", "q8_0", "q8_1", "q2k", "q3k", "q4k", 13 | "q5k", "q6k", "q8k", "iq2Xxs", "iq2Xs", "iq3Xxs", "iq1S", "iq4N1", "iq3S", "iq2S", "iq4Xs", 14 | "i8", "i16", "i32", "i64", "f64", "iq1M", "bf16", "q4044", "q4048", "q4088", "count", 15 | ]; 16 | //Sampling Methods 17 | const SAMPLE_METHODS: [&str; 10] = [ 18 | "euler_a", 19 | "euler", 20 | "heun", 21 | "dpm2", 22 | "dpm++2s_a", 23 | "dpm++2m", 24 | "dpm++2mv2", 25 | 
"ipndm", "ipndm_v", "lcm",
];

// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h
const SCHEDULE_STR: [&str; 6] = ["default", "discrete", "karras", "exponential", "ays", "gits"];

/// Entry point: parses the command line, validates it into [`Options`] and runs the selected
/// mode (txt2img, img2img or convert).
///
/// # Errors
/// Returns an error when an argument value is out of range or unrecognised, when the task
/// cannot be created from the mode string, or when building the context, generating the
/// image, or converting the model fails.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    let matches = build_cli().get_matches();

    // Resolve the task first so that an unusable mode is reported before anything else.
    let task = Task::from_str(&value_of::<String>(&matches, "mode"))?;
    let options = parse_options(&matches)?;

    // DEBUG: print the effective options collected from the command line.
    if matches.get_flag("debug") {
        print_params(&options);
    }

    run(task, options)
}

/// Declares every command-line argument understood by the example.
///
/// Every valued argument declares a default, so [`value_of`] can always retrieve it.
fn build_cli() -> Command {
    Command::new("wasmedge-stable-diffusion")
        .version(crate_version!())
        .arg(
            Arg::new("mode")
                .short('M')
                .long("mode")
                .value_name("MODE")
                .value_parser(["txt2img", "img2img", "convert"])
                .help("run mode (txt2img or img2img or convert, default: txt2img)")
                .default_value("txt2img"),
        )
        .arg(
            Arg::new("n_threads")
                .short('t')
                .long("threads")
                .value_parser(clap::value_parser!(i32))
                .value_name("N")
                .help("number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of CPU physical cores")
                .default_value("-1"),
        )
        .arg(
            Arg::new("model")
                .short('m')
                .long("model")
                .value_name("MODEL")
                .help("path to full model")
                .default_value("stable-diffusion-v1-4-Q8_0.gguf"),
        )
        .arg(
            Arg::new("diffusion_model_path")
                .long("diffusion-model")
                .value_name("PATH")
                .help("path to the standalone diffusion model")
                .default_value(""),
        )
        .arg(
            Arg::new("clip_l_path")
                .long("clip_l")
                .value_name("PATH")
                .help("path to clip_l")
                .default_value(""),
        )
        .arg(
            Arg::new("t5xxl_path")
                .long("t5xxl")
                .value_name("PATH")
                .help("path to t5xxl")
                .default_value(""),
        )
        .arg(
            Arg::new("vae_path")
                .long("vae")
                .value_name("VAE")
                .help("path to vae")
                .default_value(""),
        )
        .arg(
            Arg::new("taesd_path")
                .long("taesd")
                .value_name("TAESD_PATH")
                .help("path to taesd. Using Tiny AutoEncoder for fast decoding (low quality).")
                .default_value(""),
        )
        .arg(
            Arg::new("control_net_path")
                .long("control-net")
                .value_name("CONTROL_PATH")
                .help("path to control net model.")
                .default_value(""),
        )
        .arg(
            Arg::new("embeddings_path")
                .long("embd-dir")
                .value_name("EMBEDDING_PATH")
                .help("path to embeddings.")
                .default_value(""),
        )
        .arg(
            Arg::new("stacked_id_embd_dir")
                .long("stacked-id-embd-dir")
                .value_name("DIR")
                .help("path to PHOTOMAKER stacked id embeddings.")
                .default_value(""),
        )
        .arg(
            Arg::new("input_id_images_dir")
                .long("input-id-images-dir")
                .value_name("DIR")
                .help("path to PHOTOMAKER input id images dir.")
                .default_value(""),
        )
        .arg(
            Arg::new("normalize_input")
                .long("normalize-input")
                .help("normalize PHOTOMAKER input id images.")
                .action(ArgAction::SetTrue),
        )
        .arg(
            Arg::new("upscale_model")
                .long("upscale-model")
                .value_name("ESRGAN_PATH")
                .help("path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.")
                .default_value(""),
        )
        .arg(
            Arg::new("upscale_repeats")
                .long("upscale-repeats")
                .value_parser(clap::value_parser!(i32))
                .value_name("UPSCALE_REPEATS")
                .help("Run the ESRGAN upscaler this many times (default 1).")
                .default_value("1"),
        )
        .arg(
            Arg::new("type")
                .long("type")
                .value_name("TYPE")
                .value_parser([
                    "f32", "f16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q8_1", "q2k", "q3k",
                    "q4k", "q5k", "q6k", "q8k", "iq2Xxs", "iq2Xs", "iq3Xxs", "iq1S", "iq4N1",
                    "iq3S", "iq2S", "iq4Xs", "i8", "i16", "i32", "i64", "f64", "iq1M", "bf16",
                    "q4044", "q4048", "q4088", "count",
                ])
                .help("weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0). If not specified, the default is the type of the weight file.")
                .default_value("count"),
        )
        .arg(
            Arg::new("lora_model_dir")
                .long("lora-model-dir")
                .value_name("DIR")
                .help("lora model directory.")
                .default_value(""),
        )
        .arg(
            Arg::new("init_img")
                .short('i')
                .long("init-img")
                .value_name("IMAGE")
                .help("path to the input image, required by img2img.")
                .default_value("./output.png"),
        )
        .arg(
            Arg::new("control_image")
                .long("control-image")
                .value_name("IMAGE")
                .help("path to image condition, control net.")
                .default_value(""),
        )
        .arg(
            Arg::new("output_path")
                .short('o')
                .long("output")
                .value_name("OUTPUT")
                .help("path to write result image to (default: ./output2.png).")
                .default_value("./output2.png"),
        )
        .arg(
            Arg::new("prompt")
                .short('p')
                .long("prompt")
                .value_name("PROMPT")
                .help("the prompt to render.")
                .default_value("cat with blue eyes"),
        )
        .arg(
            Arg::new("negative_prompt")
                .short('n')
                .long("negative-prompt")
                .value_name("PROMPT")
                .help("the negative prompt (default: '').")
                .default_value(""),
        )
        .arg(
            Arg::new("cfg_scale")
                .long("cfg-scale")
                .value_parser(clap::value_parser!(f32))
                .value_name("SCALE")
                .help("unconditional guidance scale: (default: 7.0)")
                .default_value("7.0"),
        )
        .arg(
            Arg::new("strength")
                .long("strength")
                .value_parser(clap::value_parser!(f32))
                .value_name("STRENGTH")
                .help("strength for noising/unnoising (default: 0.75).")
                .default_value("0.75"),
        )
        .arg(
            Arg::new("style_ratio")
                .long("style-ratio")
                .value_parser(clap::value_parser!(f32))
                .value_name("STYLE_RATIO")
                .help("strength for keeping input identity (default: 20%).")
                .default_value("20.0"),
        )
        .arg(
            Arg::new("control_strength")
                .long("control-strength")
                .value_parser(clap::value_parser!(f32))
                .value_name("CONTROL-STRENGTH")
                .help("strength to apply Control Net (default: 0.9) 1.0 corresponds to full destruction of information in init image.")
                .default_value("0.9"),
        )
        .arg(
            Arg::new("guidance")
                .long("guidance")
                .value_parser(clap::value_parser!(f32))
                .value_name("GUIDANCE")
                .help("guidance scale")
                .default_value("3.5"),
        )
        .arg(
            Arg::new("height")
                .short('H')
                .long("height")
                .value_parser(clap::value_parser!(i32))
                .value_name("H")
                .help("image height, in pixel space (default: 512)")
                .default_value("512"),
        )
        .arg(
            Arg::new("width")
                .short('W')
                .long("width")
                .value_parser(clap::value_parser!(i32))
                .value_name("W")
                .help("image width, in pixel space (default: 512)")
                .default_value("512"),
        )
        .arg(
            Arg::new("sampling_method")
                .long("sampling-method")
                .value_parser([
                    "euler_a",
                    "euler",
                    "heun",
                    "dpm2",
                    "dpm++2s_a",
                    "dpm++2m",
                    "dpm++2mv2",
                    "ipndm",
                    "ipndm_v",
                    "lcm",
                ])
                .value_name("SAMPLING_METHOD")
                .help("the sampling method, include values {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}, sampling method (default: euler_a)")
                .default_value("euler_a"),
        )
        .arg(
            Arg::new("sample_steps")
                .long("steps")
                .value_parser(clap::value_parser!(i32))
                .value_name("STEPS")
                .help("number of sample steps (default: 20).")
                .default_value("20"),
        )
        .arg(
            Arg::new("rng_type")
                .long("rng")
                .value_name("RNG")
                .value_parser(["std_default", "cuda"])
                .help("RNG (default: std_default).")
                .default_value("std_default"),
        )
        .arg(
            Arg::new("seed")
                .short('s')
                .long("seed")
                .value_parser(clap::value_parser!(i32))
                .value_name("SEED")
                .help("RNG seed (default: 42, use random seed for < 0).")
                .default_value("42"),
        )
        .arg(
            Arg::new("batch_count")
                .short('b')
                .long("batch-count")
                .value_parser(clap::value_parser!(i32))
                .value_name("BATCH_COUNT")
                .help("number of images to generate.")
                .default_value("1"),
        )
        .arg(
            Arg::new("schedule")
                .long("schedule")
                .value_name("SCHEDULE")
                // Reuse the lookup table so the accepted values and the index lookup cannot drift apart.
                .value_parser(SCHEDULE_STR)
                .help("Denoiser sigma schedule")
                .default_value("default"),
        )
        .arg(
            Arg::new("clip_skip")
                .long("clip-skip")
                .value_parser(clap::value_parser!(i32))
                .value_name("N")
                .help("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1), <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x.")
                .default_value("-1"),
        )
        .arg(
            Arg::new("vae_tiling")
                .long("vae-tiling")
                .help("process vae in tiles to reduce memory usage")
                .action(ArgAction::SetTrue),
        )
        .arg(
            Arg::new("vae_on_cpu")
                .long("vae-on-cpu")
                .help("keep vae in cpu (for low vram)")
                .action(ArgAction::SetTrue),
        )
        .arg(
            Arg::new("clip_on_cpu")
                .long("clip-on-cpu")
                .help("keep clip in cpu (for low vram)")
                .action(ArgAction::SetTrue),
        )
        .arg(
            Arg::new("control_net_cpu")
                .long("control-net-cpu")
                .help("keep controlnet in cpu (for low vram)")
                .action(ArgAction::SetTrue),
        )
        .arg(
            Arg::new("canny")
                .long("canny")
                .help("apply canny preprocessor (edge detection)")
                .action(ArgAction::SetTrue),
        )
        .arg(
            Arg::new("debug")
                .long("debug")
                .help("print debug information")
                .action(ArgAction::SetTrue),
        )
        .after_help("run at the dir of .wasm, Example:wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm -m ../../models/stable-diffusion-v1-4-Q8_0.gguf -M img2img\n")
}

/// Returns an owned copy of the value stored for argument `id`.
///
/// # Panics
/// Panics if `id` has no value or was registered with a different value type. Every valued
/// argument in `build_cli` declares a default, so a panic here indicates a programming error.
fn value_of<T>(matches: &clap::ArgMatches, id: &str) -> T
where
    T: std::any::Any + Clone + Send + Sync + 'static,
{
    matches
        .get_one::<T>(id)
        .cloned()
        .expect("every valued argument declares a default value")
}

/// Derives a non-negative pseudo-random seed by mixing a random `u32` with the current Unix
/// time in seconds.
fn random_seed() -> i32 {
    let now_secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Time went backwards")
        .as_secs() as u32;
    let mut rng = rand::thread_rng();
    // Clearing the sign bit keeps the result inside the non-negative i32 range.
    ((rng.gen::<u32>() ^ now_secs) & i32::MAX as u32) as i32
}

/// Collects all arguments from `matches` into an [`Options`] value.
///
/// Numeric ranges are validated. The weight type, sampling method and schedule names are
/// resolved to library enums through their index in `WTYPE_METHODS`, `SAMPLE_METHODS` and
/// `SCHEDULE_STR`.
///
/// # Errors
/// Returns an error when a value is out of range or a name is not found in its lookup table.
fn parse_options(matches: &clap::ArgMatches) -> Result<Options, Box<dyn std::error::Error>> {
    let mode: String = value_of(matches, "mode");

    let upscale_repeats: i32 = value_of(matches, "upscale_repeats");
    if upscale_repeats < 1 {
        return Err("Error: the upscale_repeats must be greater than 0".into());
    }

    let wtype_name: String = value_of(matches, "type");
    let wtype_index = WTYPE_METHODS
        .iter()
        .position(|&name| name == wtype_name.as_str())
        .ok_or_else(|| format!("Invalid wtype: {}", wtype_name))?;
    let wtype = SdTypeT::from_index(wtype_index)?;

    // The range checks below also reject NaN, which the plain `<`/`>` comparisons let through.
    let strength: f32 = value_of(matches, "strength");
    if !(0.0..=1.0).contains(&strength) {
        return Err("Error: can only work with strength in [0.0, 1.0]".into());
    }

    let style_ratio: f32 = value_of(matches, "style_ratio");
    if !(0.0..=100.0).contains(&style_ratio) {
        return Err("Error: can only work with style_ratio in [0.0, 100.0]".into());
    }

    let control_strength: f32 = value_of(matches, "control_strength");
    if !(0.0..=1.0).contains(&control_strength) {
        return Err("Error: can only work with control_strength in [0.0, 1.0]".into());
    }

    let sampling_method_name: String = value_of(matches, "sampling_method");
    let sample_method_index = SAMPLE_METHODS
        .iter()
        .position(|&name| name == sampling_method_name.as_str())
        .ok_or_else(|| format!("Invalid sampling method: {}", sampling_method_name))?;
    let sample_method = SampleMethodT::from_index(sample_method_index)?;

    let sample_steps: i32 = value_of(matches, "sample_steps");
    if sample_steps <= 0 {
        return Err("Error: the sample_steps must be greater than 0".into());
    }

    let rng_type = if value_of::<String>(matches, "rng_type") == "cuda" {
        RngTypeT::CUDARng
    } else {
        RngTypeT::StdDefaultRng
    };

    // A negative seed requests a random one.
    let requested_seed: i32 = value_of(matches, "seed");
    let seed = if requested_seed < 0 {
        random_seed()
    } else {
        requested_seed
    };

    let schedule_name: String = value_of(matches, "schedule");
    let schedule_index = SCHEDULE_STR
        .iter()
        .position(|&name| name == schedule_name.as_str())
        .ok_or_else(|| format!("Invalid schedule: {}", schedule_name))?;
    let schedule = ScheduleT::from_index(schedule_index)?;

    // The init image is only forwarded for img2img; every other mode leaves it empty.
    let init_img = if mode == "img2img" {
        value_of(matches, "init_img")
    } else {
        String::new()
    };

    Ok(Options {
        n_threads: value_of(matches, "n_threads"),
        mode,
        model_path: value_of(matches, "model"),
        clip_l_path: value_of(matches, "clip_l_path"),
        t5xxl_path: value_of(matches, "t5xxl_path"),
        diffusion_model_path: value_of(matches, "diffusion_model_path"),
        vae_path: value_of(matches, "vae_path"),
        taesd_path: value_of(matches, "taesd_path"),
        control_net_path: value_of(matches, "control_net_path"),
        upscale_model: value_of(matches, "upscale_model"),
        embeddings_path: value_of(matches, "embeddings_path"),
        stacked_id_embd_dir: value_of(matches, "stacked_id_embd_dir"),
        input_id_images_dir: value_of(matches, "input_id_images_dir"),
        wtype,
        lora_model_dir: value_of(matches, "lora_model_dir"),
        output_path: value_of(matches, "output_path"),
        init_img,
        control_image: value_of(matches, "control_image"),

        prompt: value_of(matches, "prompt"),
        negative_prompt: value_of(matches, "negative_prompt"),
        cfg_scale: value_of(matches, "cfg_scale"),
        guidance: value_of(matches, "guidance"),
        style_ratio,
        clip_skip: value_of(matches, "clip_skip"),
        width: value_of(matches, "width"),
        height: value_of(matches, "height"),
        batch_count: value_of(matches, "batch_count"),

        sample_method,
        schedule,
        sample_steps,
        strength,
        control_strength,
        rng_type,
        seed,
        vae_tiling: matches.get_flag("vae_tiling"),
        control_net_cpu: matches.get_flag("control_net_cpu"),
        normalize_input: matches.get_flag("normalize_input"),
        clip_on_cpu: matches.get_flag("clip_on_cpu"),
        vae_on_cpu: matches.get_flag("vae_on_cpu"),
        canny: matches.get_flag("canny"),
        upscale_repeats,
    })
}

/// Runs the mode selected in `options.mode`.
///
/// "convert" quantizes the model through `Quantization`. "txt2img" and "img2img" build one
/// context and drive the generator that matches the mode.
///
/// # Errors
/// Propagates builder errors, and reports context-creation, generation or conversion
/// failures (formatted from their `Debug` output) instead of panicking.
fn run(task: Task, options: Options) -> Result<(), Box<dyn std::error::Error>> {
    if options.mode == "convert" {
        let quantization = Quantization::new(
            &options.model_path,
            options.vae_path,
            &options.output_path,
            options.wtype,
        );
        quantization
            .convert()
            .map_err(|e| format!("Error: model conversion failed: {:?}", e))?;
        return Ok(());
    }
    if options.mode != "txt2img" && options.mode != "img2img" {
        return Err(format!("Error: this mode isn't supported: {}", options.mode).into());
    }

    // NOTE(review): options.diffusion_model_path is parsed and printed but never passed to
    // the builder below. Confirm whether SDBuidler supports a standalone diffusion model.
    let context = SDBuidler::new(task, &options.model_path)?
        .with_clip_l_path(options.clip_l_path)?
        .with_t5xxl_path(options.t5xxl_path)?
        .with_vae_path(options.vae_path)?
        .with_taesd_path(options.taesd_path)?
        .with_lora_model_dir(options.lora_model_dir)?
        .with_embeddings_path(options.embeddings_path)?
        .with_stacked_id_embeddings_path(options.stacked_id_embd_dir)?
        .use_control_net(options.control_net_path, options.control_net_cpu)?
        .with_n_threads(options.n_threads)
        .with_wtype(options.wtype)
        .with_rng_type(options.rng_type)
        .with_schedule(options.schedule)
        .enable_vae_tiling(options.vae_tiling)
        .enable_clip_on_cpu(options.clip_on_cpu)
        .enable_vae_on_cpu(options.vae_on_cpu)
        .build();
    let created = context
        .create_context()
        .map_err(|e| format!("Error: failed to create context: {:?}", e))?;

    match (options.mode.as_str(), created) {
        ("txt2img", Context::TextToImage(mut text_to_image)) => {
            text_to_image
                .set_prompt(options.prompt)
                .set_guidance(options.guidance)
                .set_width(options.width)
                .set_height(options.height)
                .set_control_image(ImageType::Path(options.control_image))
                .set_negative_prompt(options.negative_prompt)
                .set_clip_skip(options.clip_skip)
                .set_cfg_scale(options.cfg_scale)
                .set_sample_method(options.sample_method)
                .set_sample_steps(options.sample_steps)
                .set_seed(options.seed)
                .set_batch_count(options.batch_count)
                .set_control_strength(options.control_strength)
                .set_style_ratio(options.style_ratio)
                .enable_normalize_input(options.normalize_input)
                .set_input_id_images_dir(options.input_id_images_dir)
                .enable_canny_preprocess(options.canny)
                .set_upscale_model(options.upscale_model)
                .set_upscale_repeats(options.upscale_repeats)
                .set_output_path(options.output_path)
                .generate()
                .map_err(|e| format!("Error: txt2img generation failed: {:?}", e))?;
        }
        ("img2img", Context::ImageToImage(mut image_to_image)) => {
            image_to_image
                .set_prompt(options.prompt)
                .set_guidance(options.guidance)
                .set_width(options.width)
                .set_height(options.height)
                .set_control_image(ImageType::Path(options.control_image))
                .set_negative_prompt(options.negative_prompt)
                .set_clip_skip(options.clip_skip)
                .set_cfg_scale(options.cfg_scale)
                .set_sample_method(options.sample_method)
                .set_sample_steps(options.sample_steps)
                .set_seed(options.seed)
                .set_batch_count(options.batch_count)
                .set_control_strength(options.control_strength)
                .set_style_ratio(options.style_ratio)
                .enable_normalize_input(options.normalize_input)
                .set_input_id_images_dir(options.input_id_images_dir)
                .enable_canny_preprocess(options.canny)
                .set_upscale_model(options.upscale_model)
                .set_upscale_repeats(options.upscale_repeats)
                .set_output_path(options.output_path)
                // additional options for img2img
                .set_image(ImageType::Path(options.init_img))
                .set_strength(options.strength)
                .generate()
                .map_err(|e| format!("Error: img2img generation failed: {:?}", e))?;
        }
        (mode, _) => {
            return Err(format!("Error: the created context does not match mode {}", mode).into());
        }
    }
    Ok(())
}

/// Effective settings for one run, collected from the command line.
///
/// Empty path strings mean "not provided".
#[derive(Debug)]
struct Options {
    /// Worker thread count; values <= 0 are documented (in the CLI help) as "use physical cores".
    n_threads: i32,
    /// One of "txt2img", "img2img" or "convert".
    mode: String,
    model_path: String,
    clip_l_path: String,
    t5xxl_path: String,
    diffusion_model_path: String,
    vae_path: String,
    taesd_path: String,
    control_net_path: String,
    upscale_model: String,
    embeddings_path: String,
    stacked_id_embd_dir: String,
    input_id_images_dir: String,
    wtype: SdTypeT,
    lora_model_dir: String,
    output_path: String,
    /// Input image; only filled in for img2img.
    init_img: String,
    control_image: String,

    prompt: String,
    negative_prompt: String,
    cfg_scale: f32,
    guidance: f32,
    /// Identity-keeping strength in [0.0, 100.0].
    style_ratio: f32,
    clip_skip: i32,
    width: i32,
    height: i32,
    batch_count: i32,

    sample_method: SampleMethodT,
    schedule: ScheduleT,
    /// Always > 0 after validation.
    sample_steps: i32,
    /// Noising strength in [0.0, 1.0].
    strength: f32,
    /// Control Net strength in [0.0, 1.0].
    control_strength: f32,
    rng_type: RngTypeT,
    /// Non-negative after parsing; a negative CLI seed is replaced by a random one.
    seed: i32,
    vae_tiling: bool,
    control_net_cpu: bool,
    normalize_input: bool,
    clip_on_cpu: bool,
    vae_on_cpu: bool,
    canny: bool,
    /// Always >= 1 after validation.
    upscale_repeats: i32,
}

impl Default for Options {
    /// Mirrors the command-line defaults, with all paths and prompts left empty.
    fn default() -> Self {
        Self {
            n_threads: -1,
            mode: String::from("txt2img"),
            model_path: String::from(""),
            clip_l_path: String::from(""),
            t5xxl_path: String::from(""),
            diffusion_model_path: String::from(""),
            vae_path: String::from(""),
            taesd_path: String::from(""),
            control_net_path: String::from(""),
            upscale_model: String::from(""),
            embeddings_path: String::from(""),
            stacked_id_embd_dir: String::from(""),
            input_id_images_dir: String::from(""),
            wtype: SdTypeT::SdTypeCount,
            lora_model_dir: String::from(""),
            output_path: String::from(""),
            init_img: String::from(""),
            control_image: String::from(""),

            prompt: String::from(""),
            negative_prompt: String::from(""),
            cfg_scale: 7.0,
            guidance: 3.5,
            style_ratio: 20.0,
            clip_skip: -1,
            width: 512,
            height: 512,
            batch_count: 1,

            sample_method: SampleMethodT::EULERA,
            schedule: ScheduleT::DEFAULT,
            sample_steps: 20,
            strength: 0.75,
            control_strength: 0.9,
            rng_type: RngTypeT::StdDefaultRng,
            seed: 42,
            vae_tiling: false,
            control_net_cpu: false,
            normalize_input: false,
            clip_on_cpu: false,
            vae_on_cpu: false,
            canny: false,
            upscale_repeats: 1,
        }
    }
}

/// Prints every option as an `[INFO]` line (used by `--debug`).
fn print_params(params: &Options) {
    println!("Option:");
    println!("[INFO] n_threads: {}", params.n_threads);
    println!("[INFO] mode: {}", params.mode);
    println!("[INFO] model_path: {}", params.model_path);
    println!("[INFO] diffusion_model_path: {}", params.diffusion_model_path);
    println!("[INFO] clip_l_path: {}", params.clip_l_path);
    println!("[INFO] t5xxl_path: {}", params.t5xxl_path);
    println!("[INFO] vae_path: {}", params.vae_path);
    println!("[INFO] taesd_path: {}", params.taesd_path);
    println!("[INFO] control_net_path: {}", params.control_net_path);
    println!("[INFO] upscale_model: {}", params.upscale_model);
    println!("[INFO] embeddings_path: {}", params.embeddings_path);
    println!("[INFO] stacked_id_embd: {}", params.stacked_id_embd_dir);
    println!("[INFO] input_id_images: {}", params.input_id_images_dir);
    println!("[INFO] wtype: {:?}", params.wtype);
    println!("[INFO] lora_model_dir: {}", params.lora_model_dir);
    println!("[INFO] output_path: {}", params.output_path);
    println!("[INFO] init_img: {}", params.init_img);
    println!("[INFO] control_image: {}", params.control_image);
    println!("[INFO] prompt: {}", params.prompt);
    println!("[INFO] negative_prompt: {}", params.negative_prompt);
    println!("[INFO] cfg_scale: {}", params.cfg_scale);
    println!("[INFO] guidance: {}", params.guidance);
    println!("[INFO] style_ratio: {}", params.style_ratio);
    println!("[INFO] clip_skip: {}", params.clip_skip);
    println!("[INFO] width: {}", params.width);
    println!("[INFO] height: {}", params.height);
    println!("[INFO] batch_count: {}", params.batch_count);
    println!("[INFO] sample_method: {:?}", params.sample_method);
    println!("[INFO] schedule: {:?}", params.schedule);
    println!("[INFO] sample_steps: {}", params.sample_steps);
    println!("[INFO] strength: {}", params.strength);
    println!("[INFO] control_strength: {}", params.control_strength);
    println!("[INFO] rng_type: {:?}", params.rng_type);
    println!("[INFO] seed: {}", params.seed);
    println!("[INFO] vae_tiling: {}", params.vae_tiling);
    println!("[INFO] control_net_cpu: {}", params.control_net_cpu);
    println!("[INFO] normalize_input: {}", params.normalize_input);
    println!("[INFO] clip_on_cpu: {}", params.clip_on_cpu);
    println!("[INFO] vae_on_cpu: {}", params.vae_on_cpu);
    println!("[INFO] canny: {}", params.canny);
    println!("[INFO] upscale_repeats: {}", params.upscale_repeats);
}

--------------------------------------------------------------------------------