├── rust
├── .gitignore
├── src
│ ├── error.rs
│ ├── stable_diffusion_interface.rs
│ └── lib.rs
├── Cargo.toml
└── Cargo.lock
├── example
├── .gitignore
├── Cargo.toml
├── Cargo.lock
└── src
│ └── main.rs
├── assets
├── output.png
├── output2.png
├── output_lora_img2img.png
└── output_lora_txt2img.png
├── .github
└── workflows
│ └── rust.yml
└── README.md
/rust/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 |
--------------------------------------------------------------------------------
/example/.gitignore:
--------------------------------------------------------------------------------
1 | /target
2 | sd-v*
3 | output*
4 | *.clpt
5 | *.gguf
--------------------------------------------------------------------------------
/assets/output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WasmEdge/wasmedge-stable-diffusion/HEAD/assets/output.png
--------------------------------------------------------------------------------
/assets/output2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WasmEdge/wasmedge-stable-diffusion/HEAD/assets/output2.png
--------------------------------------------------------------------------------
/assets/output_lora_img2img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WasmEdge/wasmedge-stable-diffusion/HEAD/assets/output_lora_img2img.png
--------------------------------------------------------------------------------
/assets/output_lora_txt2img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/WasmEdge/wasmedge-stable-diffusion/HEAD/assets/output_lora_txt2img.png
--------------------------------------------------------------------------------
/example/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "wasmedge_stable_diffusion_example"
3 | version = "0.1.0"
4 | edition = "2021"
5 |
6 | [dependencies]
7 | wasmedge_stable_diffusion = {path="../rust"}
8 | clap = { version = "4.4.6", features = ["cargo"] }
9 | rand = "0.8"
--------------------------------------------------------------------------------
/rust/src/error.rs:
--------------------------------------------------------------------------------
1 | use thiserror::Error;
2 |
3 | /// Error types for the Llama Core library.
4 | #[derive(Error, Debug)]
5 | pub enum SDError {
6 | /// Invalid path error.
7 | #[error("{0}")]
8 | InvalidPath(String),
9 | /// Errors in General operation.
10 | #[error("{0}")]
11 | Operation(String),
12 | }
13 |
--------------------------------------------------------------------------------
/rust/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "wasmedge_stable_diffusion"
3 | version = "0.3.2"
4 | edition = "2021"
5 | readme = "../README.md"
6 | repository = "https://github.com/WasmEdge/wasmedge-stable-diffusion"
7 | license = "Apache-2.0"
8 | categories = ["wasm", "science"]
9 | description = "A Rust library for using stable diffusion functions when the Wasi is being executed on WasmEdge."
10 |
11 |
12 | [dependencies]
13 | thiserror = "1"
14 |
--------------------------------------------------------------------------------
/.github/workflows/rust.yml:
--------------------------------------------------------------------------------
1 | name: Build Rust Crate
2 |
3 | on:
4 | push:
5 | branches:
6 | - dev
7 | - main
8 | - release-*
9 | - feat-*
10 | - ci-*
11 | - refactor-*
12 | - fix-*
13 | - test-*
14 | pull_request:
15 | branches:
16 | - dev
17 | - main
18 | - release-*
19 | - feat-*
20 | - ci-*
21 | - refactor-*
22 | - fix-*
23 | - test-*
24 |
25 | env:
26 | CARGO_TERM_COLOR: always
27 |
28 | jobs:
29 | build:
30 | runs-on: ubuntu-latest
31 |
32 | steps:
33 | - uses: actions/checkout@v4
34 | - uses: dtolnay/rust-toolchain@stable
35 |
36 | - name: Enable wasm32-wasip1 target
37 | run: rustup target add wasm32-wasip1
38 |
39 | - name: Run rustfmt
40 | run: |
41 | cd rust
42 | cargo fmt -- --check
43 |
44 | - name: Run clippy
45 | run: |
46 | cd rust
47 | cargo clippy -- -D warnings
48 |
49 | - name: Build
50 | run: |
51 | cd rust
52 | cargo build --target=wasm32-wasip1 --release
53 |
--------------------------------------------------------------------------------
/rust/Cargo.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | version = 3
4 |
5 | [[package]]
6 | name = "proc-macro2"
7 | version = "1.0.86"
8 | source = "registry+https://github.com/rust-lang/crates.io-index"
9 | checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
10 | dependencies = [
11 | "unicode-ident",
12 | ]
13 |
14 | [[package]]
15 | name = "quote"
16 | version = "1.0.37"
17 | source = "registry+https://github.com/rust-lang/crates.io-index"
18 | checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
19 | dependencies = [
20 | "proc-macro2",
21 | ]
22 |
23 | [[package]]
24 | name = "syn"
25 | version = "2.0.77"
26 | source = "registry+https://github.com/rust-lang/crates.io-index"
27 | checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
28 | dependencies = [
29 | "proc-macro2",
30 | "quote",
31 | "unicode-ident",
32 | ]
33 |
34 | [[package]]
35 | name = "thiserror"
36 | version = "1.0.63"
37 | source = "registry+https://github.com/rust-lang/crates.io-index"
38 | checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
39 | dependencies = [
40 | "thiserror-impl",
41 | ]
42 |
43 | [[package]]
44 | name = "thiserror-impl"
45 | version = "1.0.63"
46 | source = "registry+https://github.com/rust-lang/crates.io-index"
47 | checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
48 | dependencies = [
49 | "proc-macro2",
50 | "quote",
51 | "syn",
52 | ]
53 |
54 | [[package]]
55 | name = "unicode-ident"
56 | version = "1.0.12"
57 | source = "registry+https://github.com/rust-lang/crates.io-index"
58 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
59 |
60 | [[package]]
61 | name = "wasmedge_stable_diffusion"
62 | version = "0.3.2"
63 | dependencies = [
64 | "thiserror",
65 | ]
66 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # wasmedge-stable-diffusion
2 | A Rust library for using Stable Diffusion functions from WASI programs running on WasmEdge.
3 | ## Set up WasmEdge
4 |
5 | ```Bash
6 | git clone https://github.com/WasmEdge/WasmEdge.git
7 | cd WasmEdge
8 | cmake -GNinja -Bbuild -DCMAKE_BUILD_TYPE=Release -DWASMEDGE_BUILD_TESTS=OFF -DWASMEDGE_PLUGIN_STABLEDIFFUSION=On -DWASMEDGE_USE_LLVM=OFF
9 | cmake --build build
10 | sudo cmake --install build
11 | ```
12 |
13 | ## Download Model
14 | Download the weights or a quantized model with the following commands.
15 | You can also use our example to quantize the weights yourself (see the Convert section below).
16 |
17 | stable-diffusion v1.4: [second-state/stable-diffusion-v-1-4-GGUF](https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF)
18 | stable-diffusion v1.5: [second-state/stable-diffusion-v1-5-GGUF](https://huggingface.co/second-state/stable-diffusion-v1-5-GGUF)
19 | stable-diffusion v2.1: [second-state/stable-diffusion-2-1-GGUF](https://huggingface.co/second-state/stable-diffusion-2-1-GGUF)
20 |
21 | ```
22 | curl -L -O https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF/resolve/main/sd-v1-4.ckpt
23 | curl -L -O https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF/resolve/main/stable-diffusion-v1-4-Q8_0.gguf
24 | ```
25 |
26 | ## Compile example file
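27 | Run the following from the `example` directory. The compiled `.wasm` file is located at `./target/wasm32-wasi/release/` and is named `wasmedge_stable_diffusion_example.wasm`. (On Rust 1.84 and later the `wasm32-wasi` target has been removed; use `wasm32-wasip1` instead, as the CI workflow does, and adjust the `target/` paths below accordingly.)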
28 | ```Bash
29 | cargo build --target wasm32-wasi --release
30 | ```
31 |
32 | ## Run
33 | It supports three modes: txt2img, img2img, and convert.
34 |
35 | ### txt2img
36 | The commands below assume the model file `stable-diffusion-v1-4-Q8_0.gguf` (from `stable-diffusion-v-1-4-GGUF`) is in a `models` directory two levels above the current directory (`../../models`).
37 | ```Bash
38 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm -m ../../models/stable-diffusion-v1-4-Q8_0.gguf -p "a lovely cat"
39 | ```
40 |
41 |
42 |
43 |
44 | ### img2img
45 | - `./output.png` is the image generated from the above txt2img pipeline
46 | ```Bash
47 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm --mode img2img -m ../../models/stable-diffusion-v1-4-Q8_0.gguf -p "cat with red eyes" -i ./output.png -o ./img2img_output.png
48 | ```
49 |
50 |
51 |
52 |
53 | ### Convert
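54 | - Stable Diffusion model: [sd-v1-4.ckpt](https://huggingface.co/second-state/stable-diffusion-v-1-4-GGUF/resolve/main/sd-v1-4.ckpt), converted here to type `q8_0`.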
55 | ```Bash
56 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm --mode convert -o stable-diffusion-v1-4-Q8_0_test.gguf -m ../../models/sd-v1-4.ckpt --type q8_0
57 | ```
58 | If you want to use the converted model, use `--type` to assign the type `Q8_0`.
59 |
60 | ## More Guides - LoRA
61 | ### Get weights
62 | - LoRA model: [v1-5-pruned-emaonly.safetensors](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors)
63 | ```Bash
64 | curl -L -O https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors
65 | ```
66 | ### txt2img
67 | You can specify the directory where the lora weights are stored via `--lora-model-dir`.
68 | If not specified, the default is the current working directory.
69 | ```Bash
70 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm \
71 | --lora-model-dir ../../lora \
72 | --model ../../lora/v1-5-pruned-emaonly.safetensors \
73 | -p "a lovely cat" \
74 | -o output_lora_txt2img.png
75 | ```
76 | The lora model `../../lora/sd_xl_base_1.0.safetensors` and vae model `../../lora/sdxl_vae.safetensors` will be applied to the model.
77 |
78 |
79 |
80 |
81 | ### img2img
82 | ```Bash
83 | wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm \
84 | -p "with blue eyes" \
85 | --lora-model-dir ../../lora \
86 | --model ../../lora/v1-5-pruned-emaonly.safetensors \
87 | -i output_lora_txt2img.png \
88 | -o output_lora_img2img.png
89 | ```
90 |
91 |
92 |
93 |
94 | ## Supported parameters
95 | ```
96 | usage: wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm [arguments]
97 |
98 | arguments:
99 | -M, --mode [MODE] run mode (txt2img or img2img or convert, default: txt2img)
100 | -t, --threads N number of threads to use during computation (default: -1). If threads <= 0, then threads will be set to the number of CPU physical cores
101 | -m, --model [MODEL] path to full model
102 | --diffusion-model path to the standalone diffusion model
103 | --clip_l path to the clip-l text encoder
104 | --t5xxl path to the t5xxl text encoder
105 | --vae [VAE] path to vae
106 | --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)
107 | --control-net [CONTROL_PATH] path to control net model
108 | --embd-dir [EMBEDDING_PATH] path to embeddings
109 | --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings
110 | --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir
111 | --normalize-input normalize PHOTOMAKER input id images
112 | --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generation; only RealESRGAN_x4plus_anime_6B is supported for now
113 | --upscale-repeats Run the ESRGAN upscaler this many times (default 1)
114 | --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)
115 | If not specified, the default is the type of the weight file
116 | --lora-model-dir [DIR] lora model directory
117 | -i, --init-img [IMAGE] path to the input image, required by img2img
118 | --control-image [IMAGE] path to image condition, control net
119 | -o, --output OUTPUT path to write result image to (default: ./output.png)
120 | -p, --prompt [PROMPT] the prompt to render
121 | -n, --negative-prompt PROMPT the negative prompt (default: "")
122 | --cfg-scale SCALE unconditional guidance scale: (default: 7.0)
123 | --strength STRENGTH strength for noising/unnoising (default: 0.75)
124 | --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%)
125 | --control-strength STRENGTH strength to apply Control Net (default: 0.9)
126 | 1.0 corresponds to full destruction of information in init image
127 |
128 | --guidance guidance
129 | -H, --height H image height, in pixel space (default: 512)
130 | -W, --width W image width, in pixel space (default: 512)
131 | --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}
132 | sampling method (default: "euler_a")
133 | --steps STEPS number of sample steps (default: 20)
134 | --rng {std_default, cuda} RNG (default: cuda)
135 | -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)
136 | -b, --batch-count COUNT number of images to generate
137 | --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)
138 | --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)
139 | <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x
140 | --vae-tiling process vae in tiles to reduce memory usage
141 | --vae-on-cpu keep vae in cpu (for low vram)
142 | --clip-on-cpu keep clip in cpu (for low vram)
143 | --control-net-cpu keep controlnet in cpu (for low vram)
144 | --canny apply canny preprocessor (edge detection)
145 | ```
146 |
147 |
148 |
--------------------------------------------------------------------------------
/example/Cargo.lock:
--------------------------------------------------------------------------------
1 | # This file is automatically @generated by Cargo.
2 | # It is not intended for manual editing.
3 | version = 3
4 |
5 | [[package]]
6 | name = "anstream"
7 | version = "0.6.15"
8 | source = "registry+https://github.com/rust-lang/crates.io-index"
9 | checksum = "64e15c1ab1f89faffbf04a634d5e1962e9074f2741eef6d97f3c4e322426d526"
10 | dependencies = [
11 | "anstyle",
12 | "anstyle-parse",
13 | "anstyle-query",
14 | "anstyle-wincon",
15 | "colorchoice",
16 | "is_terminal_polyfill",
17 | "utf8parse",
18 | ]
19 |
20 | [[package]]
21 | name = "anstyle"
22 | version = "1.0.8"
23 | source = "registry+https://github.com/rust-lang/crates.io-index"
24 | checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1"
25 |
26 | [[package]]
27 | name = "anstyle-parse"
28 | version = "0.2.5"
29 | source = "registry+https://github.com/rust-lang/crates.io-index"
30 | checksum = "eb47de1e80c2b463c735db5b217a0ddc39d612e7ac9e2e96a5aed1f57616c1cb"
31 | dependencies = [
32 | "utf8parse",
33 | ]
34 |
35 | [[package]]
36 | name = "anstyle-query"
37 | version = "1.1.1"
38 | source = "registry+https://github.com/rust-lang/crates.io-index"
39 | checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a"
40 | dependencies = [
41 | "windows-sys",
42 | ]
43 |
44 | [[package]]
45 | name = "anstyle-wincon"
46 | version = "3.0.4"
47 | source = "registry+https://github.com/rust-lang/crates.io-index"
48 | checksum = "5bf74e1b6e971609db8ca7a9ce79fd5768ab6ae46441c572e46cf596f59e57f8"
49 | dependencies = [
50 | "anstyle",
51 | "windows-sys",
52 | ]
53 |
54 | [[package]]
55 | name = "byteorder"
56 | version = "1.5.0"
57 | source = "registry+https://github.com/rust-lang/crates.io-index"
58 | checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
59 |
60 | [[package]]
61 | name = "cfg-if"
62 | version = "1.0.0"
63 | source = "registry+https://github.com/rust-lang/crates.io-index"
64 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
65 |
66 | [[package]]
67 | name = "clap"
68 | version = "4.5.17"
69 | source = "registry+https://github.com/rust-lang/crates.io-index"
70 | checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac"
71 | dependencies = [
72 | "clap_builder",
73 | ]
74 |
75 | [[package]]
76 | name = "clap_builder"
77 | version = "4.5.17"
78 | source = "registry+https://github.com/rust-lang/crates.io-index"
79 | checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73"
80 | dependencies = [
81 | "anstream",
82 | "anstyle",
83 | "clap_lex",
84 | "strsim",
85 | ]
86 |
87 | [[package]]
88 | name = "clap_lex"
89 | version = "0.7.2"
90 | source = "registry+https://github.com/rust-lang/crates.io-index"
91 | checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97"
92 |
93 | [[package]]
94 | name = "colorchoice"
95 | version = "1.0.2"
96 | source = "registry+https://github.com/rust-lang/crates.io-index"
97 | checksum = "d3fd119d74b830634cea2a0f58bbd0d54540518a14397557951e79340abc28c0"
98 |
99 | [[package]]
100 | name = "getrandom"
101 | version = "0.2.15"
102 | source = "registry+https://github.com/rust-lang/crates.io-index"
103 | checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7"
104 | dependencies = [
105 | "cfg-if",
106 | "libc",
107 | "wasi",
108 | ]
109 |
110 | [[package]]
111 | name = "is_terminal_polyfill"
112 | version = "1.70.1"
113 | source = "registry+https://github.com/rust-lang/crates.io-index"
114 | checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
115 |
116 | [[package]]
117 | name = "libc"
118 | version = "0.2.158"
119 | source = "registry+https://github.com/rust-lang/crates.io-index"
120 | checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
121 |
122 | [[package]]
123 | name = "ppv-lite86"
124 | version = "0.2.20"
125 | source = "registry+https://github.com/rust-lang/crates.io-index"
126 | checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04"
127 | dependencies = [
128 | "zerocopy",
129 | ]
130 |
131 | [[package]]
132 | name = "proc-macro2"
133 | version = "1.0.86"
134 | source = "registry+https://github.com/rust-lang/crates.io-index"
135 | checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77"
136 | dependencies = [
137 | "unicode-ident",
138 | ]
139 |
140 | [[package]]
141 | name = "quote"
142 | version = "1.0.37"
143 | source = "registry+https://github.com/rust-lang/crates.io-index"
144 | checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af"
145 | dependencies = [
146 | "proc-macro2",
147 | ]
148 |
149 | [[package]]
150 | name = "rand"
151 | version = "0.8.5"
152 | source = "registry+https://github.com/rust-lang/crates.io-index"
153 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
154 | dependencies = [
155 | "libc",
156 | "rand_chacha",
157 | "rand_core",
158 | ]
159 |
160 | [[package]]
161 | name = "rand_chacha"
162 | version = "0.3.1"
163 | source = "registry+https://github.com/rust-lang/crates.io-index"
164 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
165 | dependencies = [
166 | "ppv-lite86",
167 | "rand_core",
168 | ]
169 |
170 | [[package]]
171 | name = "rand_core"
172 | version = "0.6.4"
173 | source = "registry+https://github.com/rust-lang/crates.io-index"
174 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
175 | dependencies = [
176 | "getrandom",
177 | ]
178 |
179 | [[package]]
180 | name = "strsim"
181 | version = "0.11.1"
182 | source = "registry+https://github.com/rust-lang/crates.io-index"
183 | checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
184 |
185 | [[package]]
186 | name = "syn"
187 | version = "2.0.77"
188 | source = "registry+https://github.com/rust-lang/crates.io-index"
189 | checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed"
190 | dependencies = [
191 | "proc-macro2",
192 | "quote",
193 | "unicode-ident",
194 | ]
195 |
196 | [[package]]
197 | name = "thiserror"
198 | version = "1.0.63"
199 | source = "registry+https://github.com/rust-lang/crates.io-index"
200 | checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724"
201 | dependencies = [
202 | "thiserror-impl",
203 | ]
204 |
205 | [[package]]
206 | name = "thiserror-impl"
207 | version = "1.0.63"
208 | source = "registry+https://github.com/rust-lang/crates.io-index"
209 | checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261"
210 | dependencies = [
211 | "proc-macro2",
212 | "quote",
213 | "syn",
214 | ]
215 |
216 | [[package]]
217 | name = "unicode-ident"
218 | version = "1.0.12"
219 | source = "registry+https://github.com/rust-lang/crates.io-index"
220 | checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
221 |
222 | [[package]]
223 | name = "utf8parse"
224 | version = "0.2.2"
225 | source = "registry+https://github.com/rust-lang/crates.io-index"
226 | checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
227 |
228 | [[package]]
229 | name = "wasi"
230 | version = "0.11.0+wasi-snapshot-preview1"
231 | source = "registry+https://github.com/rust-lang/crates.io-index"
232 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
233 |
234 | [[package]]
235 | name = "wasmedge_stable_diffusion"
236 | version = "0.3.2"
237 | dependencies = [
238 | "thiserror",
239 | ]
240 |
241 | [[package]]
242 | name = "wasmedge_stable_diffusion_example"
243 | version = "0.1.0"
244 | dependencies = [
245 | "clap",
246 | "rand",
247 | "wasmedge_stable_diffusion",
248 | ]
249 |
250 | [[package]]
251 | name = "windows-sys"
252 | version = "0.52.0"
253 | source = "registry+https://github.com/rust-lang/crates.io-index"
254 | checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
255 | dependencies = [
256 | "windows-targets",
257 | ]
258 |
259 | [[package]]
260 | name = "windows-targets"
261 | version = "0.52.6"
262 | source = "registry+https://github.com/rust-lang/crates.io-index"
263 | checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973"
264 | dependencies = [
265 | "windows_aarch64_gnullvm",
266 | "windows_aarch64_msvc",
267 | "windows_i686_gnu",
268 | "windows_i686_gnullvm",
269 | "windows_i686_msvc",
270 | "windows_x86_64_gnu",
271 | "windows_x86_64_gnullvm",
272 | "windows_x86_64_msvc",
273 | ]
274 |
275 | [[package]]
276 | name = "windows_aarch64_gnullvm"
277 | version = "0.52.6"
278 | source = "registry+https://github.com/rust-lang/crates.io-index"
279 | checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
280 |
281 | [[package]]
282 | name = "windows_aarch64_msvc"
283 | version = "0.52.6"
284 | source = "registry+https://github.com/rust-lang/crates.io-index"
285 | checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
286 |
287 | [[package]]
288 | name = "windows_i686_gnu"
289 | version = "0.52.6"
290 | source = "registry+https://github.com/rust-lang/crates.io-index"
291 | checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
292 |
293 | [[package]]
294 | name = "windows_i686_gnullvm"
295 | version = "0.52.6"
296 | source = "registry+https://github.com/rust-lang/crates.io-index"
297 | checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
298 |
299 | [[package]]
300 | name = "windows_i686_msvc"
301 | version = "0.52.6"
302 | source = "registry+https://github.com/rust-lang/crates.io-index"
303 | checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
304 |
305 | [[package]]
306 | name = "windows_x86_64_gnu"
307 | version = "0.52.6"
308 | source = "registry+https://github.com/rust-lang/crates.io-index"
309 | checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
310 |
311 | [[package]]
312 | name = "windows_x86_64_gnullvm"
313 | version = "0.52.6"
314 | source = "registry+https://github.com/rust-lang/crates.io-index"
315 | checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
316 |
317 | [[package]]
318 | name = "windows_x86_64_msvc"
319 | version = "0.52.6"
320 | source = "registry+https://github.com/rust-lang/crates.io-index"
321 | checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
322 |
323 | [[package]]
324 | name = "zerocopy"
325 | version = "0.7.35"
326 | source = "registry+https://github.com/rust-lang/crates.io-index"
327 | checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0"
328 | dependencies = [
329 | "byteorder",
330 | "zerocopy-derive",
331 | ]
332 |
333 | [[package]]
334 | name = "zerocopy-derive"
335 | version = "0.7.35"
336 | source = "registry+https://github.com/rust-lang/crates.io-index"
337 | checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e"
338 | dependencies = [
339 | "proc-macro2",
340 | "quote",
341 | "syn",
342 | ]
343 |
--------------------------------------------------------------------------------
/rust/src/stable_diffusion_interface.rs:
--------------------------------------------------------------------------------
1 | use core::fmt;
2 | use core::mem::MaybeUninit;
/// Status code returned by the WasmEdge stable-diffusion host functions.
///
/// `0` means success. Values are constructed directly from raw host return
/// codes (e.g. the result of the `convert` host call), so a value may hold a
/// code that has no named constant below; all methods handle that case.
#[repr(transparent)]
#[derive(Copy, Clone, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct WasmedgeSdErrno(u32);
/// The call succeeded.
pub const WASMEDGE_SD_ERRNO_SUCCESS: WasmedgeSdErrno = WasmedgeSdErrno(0);
/// Error code 1: `INVALID_ARGUMENT`.
pub const WASMEDGE_SD_ERRNO_INVALID_ARGUMENT: WasmedgeSdErrno = WasmedgeSdErrno(1);
/// Error code 2: `INVALID_ENCODING`.
pub const WASMEDGE_SD_ERRNO_INVALID_ENCODING: WasmedgeSdErrno = WasmedgeSdErrno(2);
/// Error code 3: `MISSING_MEMORY`.
pub const WASMEDGE_SD_ERRNO_MISSING_MEMORY: WasmedgeSdErrno = WasmedgeSdErrno(3);
/// Error code 4: `BUSY`.
pub const WASMEDGE_SD_ERRNO_BUSY: WasmedgeSdErrno = WasmedgeSdErrno(4);
/// Error code 5: `RUNTIME_ERROR`.
pub const WASMEDGE_SD_ERRNO_RUNTIME_ERROR: WasmedgeSdErrno = WasmedgeSdErrno(5);
impl WasmedgeSdErrno {
    /// Returns the raw numeric status code.
    pub const fn raw(&self) -> u32 {
        self.0
    }

    /// Returns the symbolic name of the code, or `"UNKNOWN"` for a code that
    /// has no named constant.
    pub fn name(&self) -> &'static str {
        match self.0 {
            0 => "SUCCESS",
            1 => "INVALID_ARGUMENT",
            2 => "INVALID_ENCODING",
            3 => "MISSING_MEMORY",
            4 => "BUSY",
            5 => "RUNTIME_ERROR",
            // Codes arrive straight from the host, so this arm is reachable;
            // it previously used `unreachable_unchecked`, which was UB here.
            _ => "UNKNOWN",
        }
    }

    /// Returns a human-readable description of the code.
    ///
    /// All named codes currently have an empty description.
    pub fn message(&self) -> &'static str {
        match self.0 {
            0..=5 => "",
            // Reachable for host codes this crate does not know about.
            _ => "unrecognized error code",
        }
    }
}
impl fmt::Debug for WasmedgeSdErrno {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WasmedgeSdErrno")
            .field("code", &self.0)
            .field("name", &self.name())
            .field("message", &self.message())
            .finish()
    }
}
impl fmt::Display for WasmedgeSdErrno {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{} (error {})", self.name(), self.0)
    }
}
54 |
/// Weight (tensor) types understood by the stable-diffusion backend.
///
/// The numeric discriminants are sent across the host boundary with `as i32`,
/// so they must never be renumbered. Values 4 and 5 (formerly Q4_2 / Q4_3) are
/// deliberately unused because support for those types was removed.
/// `SdTypeCount` is one past the last real type.
#[derive(Clone, Copy, Debug)]
pub enum SdTypeT {
    SdTypeF32 = 0,
    SdTypeF16 = 1,
    SdTypeQ4_0 = 2,
    SdTypeQ4_1 = 3,
    // 4 (Q4_2) and 5 (Q4_3): support has been removed.
    SdTypeQ5_0 = 6,
    SdTypeQ5_1 = 7,
    SdTypeQ8_0 = 8,
    SdTypeQ8_1 = 9,
    SdTypeQ2K = 10,
    SdTypeQ3K = 11,
    SdTypeQ4K = 12,
    SdTypeQ5K = 13,
    SdTypeQ6K = 14,
    SdTypeQ8K = 15,
    SdTypeIq2Xxs = 16,
    SdTypeIq2Xs = 17,
    SdTypeIq3Xxs = 18,
    SdTypeIq1S = 19,
    SdTypeIq4Nl = 20,
    SdTypeIq3S = 21,
    SdTypeIq2S = 22,
    SdTypeIq4Xs = 23,
    SdTypeI8 = 24,
    SdTypeI16 = 25,
    SdTypeI32 = 26,
    SdTypeI64 = 27,
    SdTypeF64 = 28,
    SdTypeIq1M = 29,
    SdTypeBf16 = 30,
    SdTypeQ4044 = 31,
    SdTypeQ4048 = 32,
    SdTypeQ4088 = 33,
    SdTypeCount = 34,
}

/// Random-number generator selection.
///
/// Discriminants are explicit; presumably they must match the host plugin's
/// numbering — confirm before changing.
#[derive(Clone, Copy, Debug)]
pub enum RngTypeT {
    StdDefaultRng = 0,
    CUDARng = 1,
}

/// Sampling method used during denoising.
///
/// Discriminants are explicit; presumably they must match the host plugin's
/// numbering — confirm before changing.
#[derive(Clone, Copy, Debug)]
pub enum SampleMethodT {
    EULERA = 0,
    EULER = 1,
    HEUN = 2,
    DPM2 = 3,
    DPMPP2SA = 4,
    DPMPP2M = 5,
    DPMPP2Mv2 = 6,
    IPNDM = 7,
    IPNDMV = 8,
    LCM = 9,
}

/// Denoiser sigma schedule.
///
/// Discriminants are explicit and are sent to the host as `i32`; do not renumber.
#[derive(Clone, Copy, Debug)]
pub enum ScheduleT {
    DEFAULT = 0,
    DISCRETE = 1,
    KARRAS = 2,
    EXPONENTIAL = 3,
    AYS = 4,
    GITS = 5,
}
/// Image input for the stable-diffusion host calls; currently only a path.
#[derive(Debug)]
pub enum ImageType {
    /// Path to an image file.
    Path(String),
}

/// Builds `"path:" + path` and leaks it so it stays valid for the rest of the program.
fn leak_path_tag(path: &str) -> &'static str {
    Box::leak(format!("path:{path}").into_boxed_str())
}

/// Encodes `image` as a `(ptr, len)` pair of 32-bit integers.
///
/// A `Path` is encoded as the string `"path:<path>"`; an empty path yields `(0, 0)`.
///
/// The returned pointer can only be used after this function returns, so the
/// buffer it points to must outlive the call. The previous version built the
/// string in a local `String` that was dropped at the end of the match arm,
/// returning a dangling pointer. The string is now intentionally leaked (a few
/// bytes per call); avoiding the leak would require changing this signature.
///
/// Pointers and lengths are narrowed with `as i32`, which assumes a 32-bit
/// (wasm32) target.
fn parse_image(image: &ImageType) -> (i32, i32) {
    match image {
        ImageType::Path(path) if path.is_empty() => (0, 0),
        ImageType::Path(path) => {
            let tagged = leak_path_tag(path);
            (tagged.as_ptr() as i32, tagged.len() as i32)
        }
    }
}
135 |
impl SdTypeT {
    /// Maps a numeric weight-type index to the corresponding `SdTypeT` variant.
    ///
    /// Indices 0–3 and 6–33 map to the variant with that discriminant. Indices 4
    /// and 5 (the removed Q4_2/Q4_3 types) and every index of 34 or above are not
    /// matched explicitly and fall through to `SdTypeCount`.
    ///
    /// NOTE(review): every arm returns `Ok`, so an unknown index is never reported
    /// as an error. Presumably callers use `SdTypeCount` to mean "keep the weight
    /// file's own type" — confirm before turning unknown indices into an `Err`.
    pub fn from_index(index: usize) -> Result {
        match index {
            0 => Ok(SdTypeT::SdTypeF32),
            1 => Ok(SdTypeT::SdTypeF16),
            2 => Ok(SdTypeT::SdTypeQ4_0),
            3 => Ok(SdTypeT::SdTypeQ4_1),
            // 4 => Ok(SdTypeT::SdTypeQ4_2),// support has been removed
            // 5 => Ok(SdTypeT::SdTypeQ4_3),// support has been removed
            6 => Ok(SdTypeT::SdTypeQ5_0),
            7 => Ok(SdTypeT::SdTypeQ5_1),
            8 => Ok(SdTypeT::SdTypeQ8_0),
            9 => Ok(SdTypeT::SdTypeQ8_1),
            10 => Ok(SdTypeT::SdTypeQ2K),
            11 => Ok(SdTypeT::SdTypeQ3K),
            12 => Ok(SdTypeT::SdTypeQ4K),
            13 => Ok(SdTypeT::SdTypeQ5K),
            14 => Ok(SdTypeT::SdTypeQ6K),
            15 => Ok(SdTypeT::SdTypeQ8K),
            16 => Ok(SdTypeT::SdTypeIq2Xxs),
            17 => Ok(SdTypeT::SdTypeIq2Xs),
            18 => Ok(SdTypeT::SdTypeIq3Xxs),
            19 => Ok(SdTypeT::SdTypeIq1S),
            20 => Ok(SdTypeT::SdTypeIq4Nl),
            21 => Ok(SdTypeT::SdTypeIq3S),
            22 => Ok(SdTypeT::SdTypeIq2S),
            23 => Ok(SdTypeT::SdTypeIq4Xs),
            24 => Ok(SdTypeT::SdTypeI8),
            25 => Ok(SdTypeT::SdTypeI16),
            26 => Ok(SdTypeT::SdTypeI32),
            27 => Ok(SdTypeT::SdTypeI64),
            28 => Ok(SdTypeT::SdTypeF64),
            29 => Ok(SdTypeT::SdTypeIq1M),
            30 => Ok(SdTypeT::SdTypeBf16),
            31 => Ok(SdTypeT::SdTypeQ4044),
            32 => Ok(SdTypeT::SdTypeQ4048),
            33 => Ok(SdTypeT::SdTypeQ4088),
            // Catch-all: 4, 5 and any index >= 34 end up here.
            _ => Ok(SdTypeT::SdTypeCount),
        }
    }
}
impl SampleMethodT {
    /// Maps a numeric index to the corresponding `SampleMethodT` variant.
    ///
    /// Indices 0–8 map to the variant with that discriminant. Every other index
    /// (9 included) yields `LCM`, so this never returns an error.
    ///
    /// NOTE(review): out-of-range indices silently select `LCM` rather than
    /// failing — confirm callers validate the index beforehand.
    pub fn from_index(index: usize) -> Result {
        match index {
            0 => Ok(SampleMethodT::EULERA),
            1 => Ok(SampleMethodT::EULER),
            2 => Ok(SampleMethodT::HEUN),
            3 => Ok(SampleMethodT::DPM2),
            4 => Ok(SampleMethodT::DPMPP2SA),
            5 => Ok(SampleMethodT::DPMPP2M),
            6 => Ok(SampleMethodT::DPMPP2Mv2),
            7 => Ok(SampleMethodT::IPNDM),
            8 => Ok(SampleMethodT::IPNDMV),
            // Catch-all: 9 and every out-of-range index map to LCM.
            _ => Ok(SampleMethodT::LCM),
        }
    }
}
impl ScheduleT {
    /// Maps a numeric index to the corresponding `ScheduleT` variant.
    ///
    /// Indices 0–4 map to the variant with that discriminant. Every other index
    /// (5 included) yields `GITS`, so this never returns an error.
    ///
    /// NOTE(review): out-of-range indices silently select `GITS` rather than
    /// failing — confirm callers validate the index beforehand.
    pub fn from_index(index: usize) -> Result {
        match index {
            0 => Ok(ScheduleT::DEFAULT),
            1 => Ok(ScheduleT::DISCRETE),
            2 => Ok(ScheduleT::KARRAS),
            3 => Ok(ScheduleT::EXPONENTIAL),
            4 => Ok(ScheduleT::AYS),
            // Catch-all: 5 and every out-of-range index map to GITS.
            _ => Ok(ScheduleT::GITS),
        }
    }
}
205 |
206 | /// # Safety
207 | /// This will call an external wasm function
208 | pub unsafe fn convert(
209 | model_path: &str,
210 | vae_model_path: &str,
211 | output_path: &str,
212 | wtype: SdTypeT,
213 | ) -> Result<(), WasmedgeSdErrno> {
214 | let model_path_ptr = model_path.as_ptr() as i32;
215 | let model_path_len = model_path.len() as i32;
216 | let vae_model_path_ptr = vae_model_path.as_ptr() as i32;
217 | let vae_model_path_len = vae_model_path.len() as i32;
218 | let output_path_ptr = output_path.as_ptr() as i32;
219 | let output_path_len = output_path.len() as i32;
220 | let result = wasmedge_stablediffusion::convert(
221 | model_path_ptr,
222 | model_path_len,
223 | vae_model_path_ptr,
224 | vae_model_path_len,
225 | output_path_ptr,
226 | output_path_len,
227 | wtype as i32,
228 | );
229 | if result != 0 {
230 | Err(WasmedgeSdErrno(result as u32))
231 | } else {
232 | Ok(())
233 | }
234 | }
235 | #[allow(clippy::too_many_arguments)]
236 | /// # Safety
237 | /// This will call an external wasm function
238 | pub unsafe fn create_context(
239 | model_path: &str,
240 | clip_l_path: &str,
241 | t5xxl_path: &str,
242 | diffusion_model_path: &str,
243 | vae_path: &str,
244 | taesd_path: &str,
245 | control_net_path: &str,
246 | lora_model_dir: &str,
247 | embed_dir: &str,
248 | id_embed_dir: &str,
249 | vae_decode_only: bool,
250 | vae_tiling: bool,
251 | n_threads: i32,
252 | wtype: SdTypeT,
253 | rng_type: RngTypeT,
254 | schedule: ScheduleT,
255 | clip_on_cpu: bool,
256 | control_net_cpu: bool,
257 | vae_on_cpu: bool,
258 | session_id: *mut u32,
259 | ) -> Result<(), WasmedgeSdErrno> {
260 | let model_path_ptr = model_path.as_ptr() as i32;
261 | let model_path_len = model_path.len() as i32;
262 | let clip_l_path_ptr = clip_l_path.as_ptr() as i32;
263 | let clip_l_path_len = clip_l_path.len() as i32;
264 | let t5xxl_path_ptr = t5xxl_path.as_ptr() as i32;
265 | let t5xxl_path_len = t5xxl_path.len() as i32;
266 | let diffusion_model_path_ptr = diffusion_model_path.as_ptr() as i32;
267 | let diffusion_model_path_len = diffusion_model_path.len() as i32;
268 | let vae_path_ptr = vae_path.as_ptr() as i32;
269 | let vae_path_len = vae_path.len() as i32;
270 | let taesd_path_ptr = taesd_path.as_ptr() as i32;
271 | let taesd_path_len = taesd_path.len() as i32;
272 | let control_net_path_ptr = control_net_path.as_ptr() as i32;
273 | let control_net_path_len = control_net_path.len() as i32;
274 | let lora_model_dir_ptr = lora_model_dir.as_ptr() as i32;
275 | let lora_model_dir_len = lora_model_dir.len() as i32;
276 | let embed_dir_ptr = embed_dir.as_ptr() as i32;
277 | let embed_dir_len = embed_dir.len() as i32;
278 | let id_embed_dir_ptr = id_embed_dir.as_ptr() as i32;
279 | let id_embed_dir_len = id_embed_dir.len() as i32;
280 | let vae_decode_only = vae_decode_only as i32;
281 | let vae_tiling = vae_tiling as i32;
282 | let n_threads_ = n_threads;
283 | let wtype = wtype as i32;
284 | let rng_type = rng_type as i32;
285 | let schedule = schedule as i32;
286 | let clip_on_cpu = clip_on_cpu as i32;
287 | let control_net_cpu = control_net_cpu as i32;
288 | let vae_on_cpu = vae_on_cpu as i32;
289 | let session_id_ptr = session_id as i32;
290 | let result = wasmedge_stablediffusion::create_context(
291 | model_path_ptr,
292 | model_path_len,
293 | clip_l_path_ptr,
294 | clip_l_path_len,
295 | t5xxl_path_ptr,
296 | t5xxl_path_len,
297 | diffusion_model_path_ptr,
298 | diffusion_model_path_len,
299 | vae_path_ptr,
300 | vae_path_len,
301 | taesd_path_ptr,
302 | taesd_path_len,
303 | control_net_path_ptr,
304 | control_net_path_len,
305 | lora_model_dir_ptr,
306 | lora_model_dir_len,
307 | embed_dir_ptr,
308 | embed_dir_len,
309 | id_embed_dir_ptr,
310 | id_embed_dir_len,
311 | vae_decode_only,
312 | vae_tiling,
313 | n_threads_,
314 | wtype,
315 | rng_type,
316 | schedule,
317 | clip_on_cpu,
318 | control_net_cpu,
319 | vae_on_cpu,
320 | session_id_ptr,
321 | );
322 | if result != 0 {
323 | Err(WasmedgeSdErrno(result as u32))
324 | } else {
325 | Ok(())
326 | }
327 | }
328 |
329 | #[allow(clippy::too_many_arguments)]
330 | /// # Safety
331 | /// This will call an external wasm function
332 | pub unsafe fn text_to_image(
333 | prompt: &str,
334 | session_id: u32,
335 | control_image: &ImageType,
336 | negative_prompt: &str,
337 | guidance: f32,
338 | width: i32,
339 | height: i32,
340 | clip_skip: i32,
341 | cfg_scale: f32,
342 | sample_method: SampleMethodT,
343 | sample_steps: i32,
344 | seed: i32,
345 | batch_count: i32,
346 | control_strength: f32,
347 | style_ratio: f32,
348 | normalize_input: bool,
349 | input_id_images_dir: &str,
350 | canny_preprocess: bool,
351 | upscale_model: &str,
352 | upscale_repeats: i32,
353 | output_path: &str,
354 | output_buf: *mut u8,
355 | out_buffer_max_size: i32,
356 | ) -> Result {
357 | let prompt_ptr = prompt.as_ptr() as i32;
358 | let prompt_len = prompt.len() as i32;
359 | let session_id = session_id as i32;
360 | let (control_image_ptr, control_image_len) = parse_image(control_image);
361 | let negative_prompt_ptr = negative_prompt.as_ptr() as i32;
362 | let negative_prompt_len = negative_prompt.len() as i32;
363 | let sample_method = sample_method as i32;
364 | let input_id_images_dir_ptr = input_id_images_dir.as_ptr() as i32;
365 | let input_id_images_dir_len = input_id_images_dir.len() as i32;
366 | let normalize_input = normalize_input as i32;
367 | let canny_preprocess = canny_preprocess as i32;
368 | let upscale_model_path_ptr = upscale_model.as_ptr() as i32;
369 | let upscale_model_path_len = upscale_model.len() as i32;
370 | let output_path_ptr = output_path.as_ptr() as i32;
371 | let output_path_len = output_path.len() as i32;
372 | let output_buf_ptr = output_buf as i32;
373 | let out_buffer_max_size_ = out_buffer_max_size;
374 | let mut write_bytes = MaybeUninit::::uninit();
375 | let result = wasmedge_stablediffusion::text_to_image(
376 | prompt_ptr,
377 | prompt_len,
378 | session_id,
379 | control_image_ptr,
380 | control_image_len,
381 | negative_prompt_ptr,
382 | negative_prompt_len,
383 | guidance,
384 | width,
385 | height,
386 | clip_skip,
387 | cfg_scale,
388 | sample_method,
389 | sample_steps,
390 | seed,
391 | batch_count,
392 | control_strength,
393 | style_ratio,
394 | normalize_input,
395 | input_id_images_dir_ptr,
396 | input_id_images_dir_len,
397 | canny_preprocess,
398 | upscale_model_path_ptr,
399 | upscale_model_path_len,
400 | upscale_repeats,
401 | output_path_ptr,
402 | output_path_len,
403 | output_buf_ptr,
404 | out_buffer_max_size_,
405 | write_bytes.as_mut_ptr() as i32,
406 | );
407 | if result != 0 {
408 | Err(WasmedgeSdErrno(result as u32))
409 | } else {
410 | Ok(write_bytes.assume_init())
411 | }
412 | }
413 | #[allow(clippy::too_many_arguments)]
414 | /// # Safety
415 | /// This will call an external wasm function
416 | pub unsafe fn image_to_image(
417 | image: &ImageType,
418 | session_id: u32,
419 | guidance: f32,
420 | width: i32,
421 | height: i32,
422 | control_image: &ImageType,
423 | prompt: &str,
424 | negative_prompt: &str,
425 | clip_skip: i32,
426 | cfg_scale: f32,
427 | sample_method: SampleMethodT,
428 | sample_steps: i32,
429 | strength: f32,
430 | seed: i32,
431 | batch_count: i32,
432 | control_strength: f32,
433 | style_ratio: f32,
434 | normalize_input: bool,
435 | input_id_images_dir: &str,
436 | canny_preprocess: bool,
437 | upscale_model_path: &str,
438 | upscale_repeats: i32,
439 | output_path: &str,
440 | output_buf: *mut u8,
441 | out_buffer_max_size: i32,
442 | ) -> Result {
443 | let (image_ptr, image_len) = parse_image(image);
444 | let (control_image_ptr, control_image_len) = parse_image(control_image);
445 | let session_id = session_id as i32;
446 | let prompt_ptr = prompt.as_ptr() as i32;
447 | let prompt_len = prompt.len() as i32;
448 | let negative_prompt_ptr = negative_prompt.as_ptr() as i32;
449 | let negative_prompt_len = negative_prompt.len() as i32;
450 | let sample_method = sample_method as i32;
451 | let normalize_input = normalize_input as i32;
452 | let input_id_images_dir_ptr = input_id_images_dir.as_ptr() as i32;
453 | let input_id_images_dir_len = input_id_images_dir.len() as i32;
454 | let canny_preprocess = canny_preprocess as i32;
455 | let upscale_model_path_ptr = upscale_model_path.as_ptr() as i32;
456 | let upscale_model_path_len = upscale_model_path.len() as i32;
457 | let output_path_ptr = output_path.as_ptr() as i32;
458 | let output_path_len = output_path.len() as i32;
459 | let output_buf_ptr = output_buf as i32;
460 | let out_buffer_max_size_ = out_buffer_max_size;
461 | let mut write_bytes = MaybeUninit::::uninit();
462 |
463 | let result = wasmedge_stablediffusion::image_to_image(
464 | image_ptr,
465 | image_len,
466 | session_id,
467 | guidance,
468 | width,
469 | height,
470 | control_image_ptr,
471 | control_image_len,
472 | prompt_ptr,
473 | prompt_len,
474 | negative_prompt_ptr,
475 | negative_prompt_len,
476 | clip_skip,
477 | cfg_scale,
478 | sample_method,
479 | sample_steps,
480 | strength,
481 | seed,
482 | batch_count,
483 | control_strength,
484 | style_ratio,
485 | normalize_input,
486 | input_id_images_dir_ptr,
487 | input_id_images_dir_len,
488 | canny_preprocess,
489 | upscale_model_path_ptr,
490 | upscale_model_path_len,
491 | upscale_repeats,
492 | output_path_ptr,
493 | output_path_len,
494 | output_buf_ptr,
495 | out_buffer_max_size_,
496 | write_bytes.as_mut_ptr() as i32,
497 | );
498 | if result != 0 {
499 | Err(WasmedgeSdErrno(result as u32))
500 | } else {
501 | Ok(write_bytes.assume_init())
502 | }
503 | }
/// Raw imports from the `wasmedge_stablediffusion` host module.
///
/// Conventions, as used by the wrappers in this file:
/// - strings are passed as an `i32` pointer plus an `i32` byte length;
/// - booleans and enums are passed as `i32`;
/// - out-parameters are passed as `i32` pointers;
/// - a return value of `0` means success, and anything else is an error code.
pub mod wasmedge_stablediffusion {
    #[link(wasm_import_module = "wasmedge_stablediffusion")]
    extern "C" {
        /// Converts a model file into `output_path` using weight type `wtype`.
        pub fn convert(
            model_path_ptr: i32,
            model_path_len: i32,
            vae_model_path_ptr: i32,
            vae_model_path_len: i32,
            output_path_ptr: i32,
            output_path_len: i32,
            wtype: i32,
        ) -> i32;

        /// Creates a generation context and stores its id through `session_id_ptr`.
        pub fn create_context(
            model_path_ptr: i32,
            model_path_len: i32,
            clip_l_path_ptr: i32,
            clip_l_path_len: i32,
            t5xxl_path_ptr: i32,
            t5xxl_path_len: i32,
            diffusion_model_path_ptr: i32,
            diffusion_model_path_len: i32,
            vae_path_ptr: i32,
            vae_path_len: i32,
            taesd_path_ptr: i32,
            taesd_path_len: i32,
            control_net_path_ptr: i32,
            control_net_path_len: i32,
            lora_model_dir_ptr: i32,
            lora_model_dir_len: i32,
            embed_dir_ptr: i32,
            embed_dir_len: i32,
            id_embed_dir_ptr: i32,
            id_embed_dir_len: i32,
            vae_decode_only: i32,
            vae_tiling: i32,
            n_threads: i32,
            wtype: i32,
            rng_type: i32,
            schedule: i32,
            clip_on_cpu: i32,
            control_net_cpu: i32,
            vae_on_cpu: i32,
            session_id_ptr: i32,
        ) -> i32;

        /// Text-to-image generation on context `session_id`. The output buffer is described
        /// by `out_buffer_ptr`/`out_buffer_max_size`, and a count is stored through
        /// `bytes_written_ptr`.
        pub fn text_to_image(
            prompt_ptr: i32,
            prompt_len: i32,
            session_id: i32,
            control_image_ptr: i32,
            control_image_len: i32,
            negative_prompt_ptr: i32,
            negative_prompt_len: i32,
            guidance: f32,
            width: i32,
            height: i32,
            clip_skip: i32,
            cfg_scale: f32,
            sample_method: i32,
            sample_steps: i32,
            seed: i32,
            batch_count: i32,
            control_strength: f32,
            style_ratio: f32,
            normalize_input: i32,
            input_id_images_dir_ptr: i32,
            input_id_images_dir_len: i32,
            canny_preprocess: i32,
            upscale_model_path_ptr: i32,
            upscale_model_path_len: i32,
            upscale_repeats: i32,
            output_path_ptr: i32,
            output_path_len: i32,
            out_buffer_ptr: i32,
            out_buffer_max_size: i32,
            bytes_written_ptr: i32,
        ) -> i32;

        /// Image-to-image generation on context `session_id`. It uses the same buffer and
        /// count conventions as `text_to_image`.
        pub fn image_to_image(
            image_ptr: i32,
            image_len: i32,
            session_id: i32,
            guidance: f32,
            width: i32,
            height: i32,
            control_image_ptr: i32,
            control_image_len: i32,
            prompt_ptr: i32,
            prompt_len: i32,
            negative_prompt_ptr: i32,
            negative_prompt_len: i32,
            clip_skip: i32,
            cfg_scale: f32,
            sample_method: i32,
            sample_steps: i32,
            strength: f32,
            seed: i32,
            batch_count: i32,
            control_strength: f32,
            style_ratio: f32,
            normalize_input: i32,
            input_id_images_dir_ptr: i32,
            input_id_images_dir_len: i32,
            canny_preprocess: i32,
            upscale_model_path_ptr: i32,
            upscale_model_path_len: i32,
            upscale_repeats: i32,
            output_path_ptr: i32,
            output_path_len: i32,
            out_buffer_ptr: i32,
            out_buffer_max_size: i32,
            bytes_written_ptr: i32,
        ) -> i32;
    }
}
620 |
--------------------------------------------------------------------------------
/rust/src/lib.rs:
--------------------------------------------------------------------------------
1 | pub mod error;
2 | pub mod stable_diffusion_interface;
3 |
4 | use core::mem::MaybeUninit;
5 | use error::SDError;
6 | use stable_diffusion_interface::*;
7 | use std::path::Path;
8 |
/// Size in bytes of the scratch output buffer that each `generate` call allocates and hands
/// to the host. It is passed as `out_buffer_max_size`.
const BUF_LEN: i32 = 1_000_000;
10 |
11 | pub type SDResult = Result;
12 |
13 | /// Represents Quantization task.
14 | pub struct Quantization {
15 | pub model_path: String,
16 | pub vae_model_path: String,
17 | pub output_path: String,
18 | pub wtype: SdTypeT,
19 | }
20 | impl Quantization {
21 | pub fn new(
22 | model_path: &str,
23 | vae_model_path: String,
24 | output_path: &str,
25 | wtype: SdTypeT,
26 | ) -> Quantization {
27 | Quantization {
28 | model_path: model_path.to_string(),
29 | vae_model_path,
30 | output_path: output_path.to_string(),
31 | wtype,
32 | }
33 | }
34 | pub fn convert(&self) -> Result<(), WasmedgeSdErrno> {
35 | unsafe {
36 | stable_diffusion_interface::convert(
37 | &self.model_path,
38 | &self.vae_model_path,
39 | &self.output_path,
40 | self.wtype,
41 | )
42 | }
43 | }
44 | }
45 |
/// The kind of job a [`StableDiffusion`] instance is built for.
#[derive(Debug)]
pub enum Task {
    /// Generate an image from a prompt (`txt2img`).
    TextToImage,
    /// Generate from an input image plus a prompt (`img2img`).
    ImageToImage,
    /// Convert a model file (`convert`).
    Convert,
}

/// Parses the value of the `--mode` command-line option.
impl std::str::FromStr for Task {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let task = match s {
            "txt2img" => Task::TextToImage,
            "img2img" => Task::ImageToImage,
            "convert" => Task::Convert,
            _ => return Err(format!("Invalid mode: {}", s)),
        };
        Ok(task)
    }
}
64 |
65 | #[derive(Debug)]
66 | pub enum Context {
67 | TextToImage(TextToImage),
68 | ImageToImage(ImageToImage),
69 | }
70 |
71 | #[derive(Debug)]
72 | pub struct BaseContext {
73 | pub session_id: u32,
74 | pub prompt: String,
75 | pub guidance: f32,
76 | pub width: i32,
77 | pub height: i32,
78 | pub control_image: ImageType,
79 | pub negative_prompt: String,
80 | pub clip_skip: i32,
81 | pub cfg_scale: f32,
82 | pub sample_method: SampleMethodT,
83 | pub sample_steps: i32,
84 | pub seed: i32,
85 | pub batch_count: i32,
86 | pub control_strength: f32,
87 | pub style_ratio: f32,
88 | pub normalize_input: bool,
89 | pub input_id_images_dir: String,
90 | pub canny_preprocess: bool,
91 | pub upscale_model: String,
92 | pub upscale_repeats: i32,
93 | pub output_path: String,
94 | }
95 | pub trait BaseFunction {
96 | fn base(&mut self) -> &mut BaseContext;
97 | fn set_prompt(&mut self, prompt: String) -> &mut Self {
98 | {
99 | self.base().prompt = prompt;
100 | }
101 | self
102 | }
103 | fn set_guidance(&mut self, guidance: f32) -> &mut Self {
104 | {
105 | self.base().guidance = guidance;
106 | }
107 | self
108 | }
109 | fn set_width(&mut self, width: i32) -> &mut Self {
110 | {
111 | self.base().width = width;
112 | }
113 | self
114 | }
115 | fn set_height(&mut self, height: i32) -> &mut Self {
116 | {
117 | self.base().height = height;
118 | }
119 | self
120 | }
121 | fn set_control_image(&mut self, control_image: ImageType) -> &mut Self {
122 | {
123 | self.base().control_image = control_image;
124 | }
125 | self
126 | }
127 | fn set_negative_prompt(&mut self, negative_prompt: impl Into) -> &mut Self {
128 | {
129 | self.base().negative_prompt = negative_prompt.into();
130 | }
131 | self
132 | }
133 | fn set_clip_skip(&mut self, clip_skip: i32) -> &mut Self {
134 | {
135 | self.base().clip_skip = clip_skip;
136 | }
137 | self
138 | }
139 | fn set_cfg_scale(&mut self, cfg_scale: f32) -> &mut Self {
140 | {
141 | self.base().cfg_scale = cfg_scale;
142 | }
143 | self
144 | }
145 | fn set_sample_method(&mut self, sample_method: SampleMethodT) -> &mut Self {
146 | {
147 | self.base().sample_method = sample_method;
148 | }
149 | self
150 | }
151 | fn set_sample_steps(&mut self, sample_steps: i32) -> &mut Self {
152 | {
153 | self.base().sample_steps = sample_steps;
154 | }
155 | self
156 | }
157 | fn set_seed(&mut self, seed: i32) -> &mut Self {
158 | {
159 | self.base().seed = seed;
160 | }
161 | self
162 | }
163 | fn set_batch_count(&mut self, batch_count: i32) -> &mut Self {
164 | {
165 | self.base().batch_count = batch_count;
166 | }
167 | self
168 | }
169 | fn set_control_strength(&mut self, control_strength: f32) -> &mut Self {
170 | {
171 | self.base().control_strength = control_strength;
172 | }
173 | self
174 | }
175 | fn set_style_ratio(&mut self, style_ratio: f32) -> &mut Self {
176 | {
177 | self.base().style_ratio = style_ratio;
178 | }
179 | self
180 | }
181 | fn enable_normalize_input(&mut self, normalize_input: bool) -> &mut Self {
182 | {
183 | self.base().normalize_input = normalize_input;
184 | }
185 | self
186 | }
187 | fn set_input_id_images_dir(&mut self, input_id_images_dir: String) -> &mut Self {
188 | {
189 | self.base().input_id_images_dir = input_id_images_dir;
190 | }
191 | self
192 | }
193 | fn enable_canny_preprocess(&mut self, canny_preprocess: bool) -> &mut Self {
194 | {
195 | self.base().canny_preprocess = canny_preprocess;
196 | }
197 | self
198 | }
199 | fn set_upscale_model(&mut self, upscale_model: String) -> &mut Self {
200 | {
201 | self.base().upscale_model = upscale_model;
202 | }
203 | self
204 | }
205 | fn set_upscale_repeats(&mut self, upscale_repeats: i32) -> &mut Self {
206 | {
207 | self.base().upscale_repeats = upscale_repeats;
208 | }
209 | self
210 | }
211 | fn set_output_path(&mut self, output_path: String) -> &mut Self {
212 | {
213 | self.base().output_path = output_path;
214 | }
215 | self
216 | }
217 | fn generate(&self) -> Result<(), WasmedgeSdErrno>;
218 | }
219 |
220 | /// Represents computation context for text-to-image task
221 | #[derive(Debug)]
222 | pub struct TextToImage {
223 | pub common: BaseContext,
224 | }
225 | impl BaseFunction for TextToImage {
226 | fn base(&mut self) -> &mut BaseContext {
227 | &mut self.common
228 | }
229 | fn generate(&self) -> Result<(), WasmedgeSdErrno> {
230 | if self.common.prompt.is_empty() {
231 | return Err(WASMEDGE_SD_ERRNO_INVALID_ARGUMENT);
232 | }
233 | let mut data: Vec = vec![0; BUF_LEN as usize];
234 | let result = unsafe {
235 | stable_diffusion_interface::text_to_image(
236 | &self.common.prompt,
237 | self.common.session_id,
238 | &self.common.control_image,
239 | &self.common.negative_prompt,
240 | self.common.guidance,
241 | self.common.width,
242 | self.common.height,
243 | self.common.clip_skip,
244 | self.common.cfg_scale,
245 | self.common.sample_method,
246 | self.common.sample_steps,
247 | self.common.seed,
248 | self.common.batch_count,
249 | self.common.control_strength,
250 | self.common.style_ratio,
251 | self.common.normalize_input,
252 | &self.common.input_id_images_dir,
253 | self.common.canny_preprocess,
254 | &self.common.upscale_model,
255 | self.common.upscale_repeats,
256 | &self.common.output_path,
257 | data.as_mut_ptr(),
258 | BUF_LEN,
259 | )
260 | };
261 | result?;
262 | Ok(())
263 | }
264 | }
265 |
266 | /// Represents computation context for image-to-image task.
267 | #[derive(Debug)]
268 | pub struct ImageToImage {
269 | pub common: BaseContext,
270 | pub image: ImageType,
271 | pub strength: f32,
272 | }
273 | impl BaseFunction for ImageToImage {
274 | fn base(&mut self) -> &mut BaseContext {
275 | &mut self.common
276 | }
277 | fn generate(&self) -> Result<(), WasmedgeSdErrno> {
278 | if self.common.prompt.is_empty() {
279 | return Err(WASMEDGE_SD_ERRNO_INVALID_ARGUMENT);
280 | }
281 | match &self.image {
282 | ImageType::Path(path) => {
283 | if path.is_empty() {
284 | return Err(WASMEDGE_SD_ERRNO_INVALID_ARGUMENT);
285 | }
286 | }
287 | }
288 | let mut data: Vec = vec![0; BUF_LEN as usize];
289 | let result = unsafe {
290 | stable_diffusion_interface::image_to_image(
291 | &self.image,
292 | self.common.session_id,
293 | self.common.guidance,
294 | self.common.width,
295 | self.common.height,
296 | &self.common.control_image,
297 | &self.common.prompt,
298 | &self.common.negative_prompt,
299 | self.common.clip_skip,
300 | self.common.cfg_scale,
301 | self.common.sample_method,
302 | self.common.sample_steps,
303 | self.strength,
304 | self.common.seed,
305 | self.common.batch_count,
306 | self.common.control_strength,
307 | self.common.style_ratio,
308 | self.common.normalize_input,
309 | &self.common.input_id_images_dir,
310 | self.common.canny_preprocess,
311 | &self.common.upscale_model,
312 | self.common.upscale_repeats,
313 | &self.common.output_path,
314 | data.as_mut_ptr(),
315 | BUF_LEN,
316 | )
317 | };
318 | result?;
319 | Ok(())
320 | }
321 | }
322 | impl ImageToImage {
323 | pub fn set_image(&mut self, image: ImageType) -> &mut Self {
324 | {
325 | self.image = image;
326 | }
327 | self
328 | }
329 | pub fn set_strength(&mut self, strength: f32) -> &mut Self {
330 | {
331 | self.strength = strength;
332 | }
333 | self
334 | }
335 | }
336 |
337 | /// Builder for creating a StableDiffusion instance.
338 | #[derive(Debug)]
339 | pub struct SDBuidler {
340 | sd: StableDiffusion,
341 | }
342 | impl SDBuidler {
343 | pub fn new(task: Task, model_path: impl AsRef) -> SDResult {
344 | let path = model_path
345 | .as_ref()
346 | .to_str()
347 | .ok_or_else(|| SDError::Operation("The model path is not valid unicode.".into()))?;
348 | let sd = StableDiffusion::new(task, path);
349 | Ok(Self { sd })
350 | }
351 |
352 | /// Create a new builder with a full model path.
353 | pub fn new_with_full_model(task: Task, model_path: impl AsRef) -> SDResult {
354 | let path = model_path
355 | .as_ref()
356 | .to_str()
357 | .ok_or_else(|| SDError::Operation("The model path is not valid unicode.".into()))?;
358 | let sd = StableDiffusion::new(task, path);
359 | Ok(Self { sd })
360 | }
361 |
362 | /// Create a new builder with a standalone diffusion model.
363 | pub fn new_with_standalone_model(
364 | task: Task,
365 | diffusion_model_path: impl AsRef,
366 | ) -> SDResult {
367 | let path = diffusion_model_path
368 | .as_ref()
369 | .to_str()
370 | .ok_or_else(|| SDError::Operation("The model path is not valid unicode.".into()))?;
371 | let sd = StableDiffusion::new_with_standalone_model(task, path);
372 | Ok(Self { sd })
373 | }
374 |
375 | pub fn with_vae_path(mut self, path: impl AsRef) -> SDResult {
376 | let path = path.as_ref().to_str().ok_or_else(|| {
377 | SDError::InvalidPath("The path to the vae file is not valid unicode.".into())
378 | })?;
379 | self.sd.vae_path = path.into();
380 | Ok(self)
381 | }
382 |
383 | pub fn with_clip_l_path(mut self, path: impl AsRef) -> SDResult {
384 | let path = path.as_ref().to_str().ok_or_else(|| {
385 | SDError::InvalidPath("The path to the clip_l file is not valid unicode.".into())
386 | })?;
387 | self.sd.clip_l_path = path.into();
388 | Ok(self)
389 | }
390 |
391 | pub fn with_t5xxl_path(mut self, path: impl AsRef) -> SDResult {
392 | let path = path.as_ref().to_str().ok_or_else(|| {
393 | SDError::InvalidPath("The path to the t5xxl file is not valid unicode.".into())
394 | })?;
395 | self.sd.t5xxl_path = path.into();
396 | Ok(self)
397 | }
398 |
399 | pub fn with_taesd_path(mut self, path: impl AsRef) -> SDResult {
400 | let path = path.as_ref().to_str().ok_or_else(|| {
401 | SDError::InvalidPath("The path to the taesd file is not valid unicode.".into())
402 | })?;
403 | self.sd.taesd_path = path.into();
404 | Ok(self)
405 | }
406 |
407 | pub fn with_lora_model_dir(mut self, path: impl AsRef) -> SDResult {
408 | let path = path.as_ref().to_str().ok_or_else(|| {
409 | SDError::InvalidPath(
410 | "The path to the lora model directory is not valid unicode.".into(),
411 | )
412 | })?;
413 | self.sd.lora_model_dir = path.into();
414 | Ok(self)
415 | }
416 |
417 | pub fn use_control_net(mut self, path: impl AsRef, on_cpu: bool) -> SDResult {
418 | let path = path.as_ref().to_str().ok_or_else(|| {
419 | SDError::InvalidPath("The path to the controlnet file is not valid unicode.".into())
420 | })?;
421 | self.sd.control_net_path = path.into();
422 | self.sd.control_net_cpu = on_cpu;
423 | Ok(self)
424 | }
425 |
426 | pub fn with_embeddings_path(mut self, path: impl AsRef) -> SDResult {
427 | let path = path.as_ref().to_str().ok_or_else(|| {
428 | SDError::InvalidPath("The path to the embeddings dir is not valid unicode.".into())
429 | })?;
430 | self.sd.embed_dir = path.into();
431 | Ok(self)
432 | }
433 |
434 | pub fn with_stacked_id_embeddings_path(mut self, path: impl AsRef) -> SDResult {
435 | let path = path.as_ref().to_str().ok_or_else(|| {
436 | SDError::InvalidPath(
437 | "The path to the stacked id embeddings dir is not valid unicode.".into(),
438 | )
439 | })?;
440 | self.sd.id_embed_dir = path.into();
441 | Ok(self)
442 | }
443 |
444 | pub fn with_n_threads(mut self, n_threads: i32) -> Self {
445 | self.sd.n_threads = n_threads;
446 | self
447 | }
448 |
449 | pub fn with_wtype(mut self, wtype: SdTypeT) -> Self {
450 | self.sd.wtype = wtype;
451 | self
452 | }
453 |
454 | pub fn with_rng_type(mut self, rng_type: RngTypeT) -> Self {
455 | self.sd.rng_type = rng_type;
456 | self
457 | }
458 |
459 | pub fn with_schedule(mut self, schedule: ScheduleT) -> Self {
460 | self.sd.schedule = schedule;
461 | self
462 | }
463 |
464 | pub fn enable_vae_tiling(mut self, enable: bool) -> Self {
465 | self.sd.vae_tiling = enable;
466 | self
467 | }
468 |
469 | pub fn enable_clip_on_cpu(mut self, enable: bool) -> Self {
470 | self.sd.clip_on_cpu = enable;
471 | self
472 | }
473 |
474 | pub fn enable_vae_on_cpu(mut self, enable: bool) -> Self {
475 | self.sd.vae_on_cpu = enable;
476 | self
477 | }
478 |
479 | pub fn build(self) -> StableDiffusion {
480 | self.sd
481 | }
482 | }
483 |
484 | /// Represents a stable diffusion model.
485 | #[derive(Debug)]
486 | pub struct StableDiffusion {
487 | task: Task,
488 | model_path: String,
489 | clip_l_path: String,
490 | t5xxl_path: String,
491 | diffusion_model_path: String,
492 | vae_path: String,
493 | taesd_path: String,
494 | control_net_path: String,
495 | lora_model_dir: String,
496 | embed_dir: String,
497 | id_embed_dir: String,
498 | vae_decode_only: bool,
499 | vae_tiling: bool,
500 | n_threads: i32,
501 | wtype: SdTypeT,
502 | rng_type: RngTypeT,
503 | schedule: ScheduleT,
504 | clip_on_cpu: bool,
505 | control_net_cpu: bool,
506 | vae_on_cpu: bool,
507 | }
508 | impl StableDiffusion {
509 | pub fn new(task: Task, model_path: &str) -> StableDiffusion {
510 | let vae_decode_only = match task {
511 | Task::TextToImage => true,
512 | Task::ImageToImage => false,
513 | Task::Convert => false,
514 | };
515 | StableDiffusion {
516 | task,
517 | model_path: model_path.to_string(),
518 | clip_l_path: "".to_string(),
519 | t5xxl_path: "".to_string(),
520 | diffusion_model_path: "".to_string(),
521 | vae_path: "".to_string(),
522 | taesd_path: "".to_string(),
523 | control_net_path: "".to_string(),
524 | lora_model_dir: "".to_string(),
525 | embed_dir: "".to_string(),
526 | id_embed_dir: "".to_string(),
527 | vae_decode_only,
528 | vae_tiling: false,
529 | n_threads: -1,
530 | wtype: SdTypeT::SdTypeCount,
531 | rng_type: RngTypeT::StdDefaultRng,
532 | schedule: ScheduleT::DEFAULT,
533 | clip_on_cpu: false,
534 | control_net_cpu: false,
535 | vae_on_cpu: false,
536 | }
537 | }
538 |
539 | pub fn new_with_standalone_model(task: Task, diffusion_model_path: &str) -> StableDiffusion {
540 | let vae_decode_only = match task {
541 | Task::TextToImage => true,
542 | Task::ImageToImage => false,
543 | Task::Convert => false,
544 | };
545 | StableDiffusion {
546 | task,
547 | model_path: "".to_string(),
548 | clip_l_path: "".to_string(),
549 | t5xxl_path: "".to_string(),
550 | diffusion_model_path: diffusion_model_path.to_string(),
551 | vae_path: "".to_string(),
552 | taesd_path: "".to_string(),
553 | control_net_path: "".to_string(),
554 | lora_model_dir: "".to_string(),
555 | embed_dir: "".to_string(),
556 | id_embed_dir: "".to_string(),
557 | vae_decode_only,
558 | vae_tiling: false,
559 | n_threads: -1,
560 | wtype: SdTypeT::SdTypeCount,
561 | rng_type: RngTypeT::StdDefaultRng,
562 | schedule: ScheduleT::DEFAULT,
563 | clip_on_cpu: false,
564 | control_net_cpu: false,
565 | vae_on_cpu: false,
566 | }
567 | }
568 |
569 | pub fn create_context(&self) -> Result {
570 | let mut session_id = MaybeUninit::::uninit();
571 | unsafe {
572 | stable_diffusion_interface::create_context(
573 | &self.model_path,
574 | &self.clip_l_path,
575 | &self.t5xxl_path,
576 | &self.diffusion_model_path,
577 | &self.vae_path,
578 | &self.taesd_path,
579 | &self.control_net_path,
580 | &self.lora_model_dir,
581 | &self.embed_dir,
582 | &self.id_embed_dir,
583 | self.vae_decode_only,
584 | self.vae_tiling,
585 | self.n_threads,
586 | self.wtype,
587 | self.rng_type,
588 | self.schedule,
589 | self.clip_on_cpu,
590 | self.control_net_cpu,
591 | self.vae_on_cpu,
592 | session_id.as_mut_ptr(),
593 | )?;
594 | let common = BaseContext {
595 | prompt: "".to_string(),
596 | session_id: session_id.assume_init(),
597 | guidance: 3.5,
598 | width: 512,
599 | height: 512,
600 | control_image: ImageType::Path("".to_string()),
601 | negative_prompt: "".to_string(),
602 | clip_skip: -1,
603 | cfg_scale: 7.0,
604 | sample_method: SampleMethodT::EULERA,
605 | sample_steps: 20,
606 | seed: 42,
607 | batch_count: 1,
608 | control_strength: 0.9,
609 | style_ratio: 20.0,
610 | normalize_input: false,
611 | input_id_images_dir: "".to_string(),
612 | canny_preprocess: false,
613 | upscale_model: "".to_string(),
614 | upscale_repeats: 1,
615 | output_path: "".to_string(),
616 | };
617 | match self.task {
618 | Task::TextToImage => Ok(Context::TextToImage(TextToImage { common })),
619 | Task::ImageToImage => Ok(Context::ImageToImage(ImageToImage {
620 | common,
621 | image: ImageType::Path("".to_string()),
622 | strength: 0.75,
623 | })),
624 | Task::Convert => todo!(),
625 | }
626 | }
627 | }
628 | }
629 |
--------------------------------------------------------------------------------
/example/src/main.rs:
--------------------------------------------------------------------------------
1 | use wasmedge_stable_diffusion::stable_diffusion_interface::{
2 | ImageType, RngTypeT, SampleMethodT, ScheduleT, SdTypeT,
3 | };
4 | use wasmedge_stable_diffusion::{BaseFunction, Context, Quantization, SDBuidler, Task};
5 |
6 | use clap::{crate_version, Arg, ArgAction, Command};
7 | use rand::Rng;
8 | use std::str::FromStr;
9 | use std::time::{SystemTime, UNIX_EPOCH};
10 |
/// Weight-type names.
///
/// Positions line up with `SdTypeT::from_index`; for example, index 24 is `i8` in both.
/// The two empty strings fill indices 4 and 5, presumably for retired types, so that the
/// later names keep their positions.
const WTYPE_METHODS: [&str; 35] = [
    // 0-5
    "f32", "f16", "q4_0", "q4_1", "", "",
    // 6-15
    "q5_0", "q5_1", "q8_0", "q8_1", "q2k", "q3k", "q4k", "q5k", "q6k", "q8k",
    // 16-23
    "iq2Xxs", "iq2Xs", "iq3Xxs", "iq1S", "iq4N1", "iq3S", "iq2S", "iq4Xs",
    // 24-34
    "i8", "i16", "i32", "i64", "f64", "iq1M", "bf16", "q4044", "q4048", "q4088", "count",
];
/// Sampling-method names. Position `i` corresponds to `SampleMethodT::from_index(i)`, from
/// `euler_a` = 0 to `lcm` = 9.
const SAMPLE_METHODS: [&str; 10] = [
    "euler_a", "euler", "heun", "dpm2", "dpm++2s_a",
    "dpm++2m", "dpm++2mv2", "ipndm", "ipndm_v", "lcm",
];
/// Names of the sigma-schedule overrides. Position `i` corresponds to
/// `ScheduleT::from_index(i)`, from `default` = 0 to `gits` = 5. The original author notes
/// this is the same order as the schedule enum in stable-diffusion.h.
const SCHEDULE_STR: [&str; 6] = ["default", "discrete", "karras", "exponential", "ays", "gits"];
38 |
39 | fn main() -> Result<(), Box> {
40 | let matches = Command::new("wasmedge-stable-diffusion")
41 | .version(crate_version!())
42 | .arg(
43 | Arg::new("mode")
44 | .short('M')
45 | .long("mode")
46 | .value_name("MODE")
47 | .value_parser([
48 | "txt2img",
49 | "img2img",
50 | "convert",
51 | ])
52 | .help("run mode (txt2img or img2img or convert, default: txt2img)")
53 | .default_value("txt2img"),
54 | )
55 | .arg(
56 | Arg::new("n_threads")
57 | .short('t')
58 | .long("threads")
59 | .value_parser(clap::value_parser!(i32))
60 | .value_name("N")
61 | .help("number of threads to use during computation (default: -1).If threads <= 0, then threads will be set to the number of CPU physical cores")
62 | .default_value("-1"),
63 | )
64 | .arg(
65 | Arg::new("model")
66 | .short('m')
67 | .long("model")
68 | .value_name("MODEL")
69 | .help("path to full model")
70 | .default_value("stable-diffusion-v1-4-Q8_0.gguf"),
71 | )
72 | .arg(
73 | Arg::new("diffusion_model_path")
74 | .long("diffusion-model")
75 | .value_name("PATH")
76 | .help("path to the standalone diffusion model")
77 | .default_value(""),
78 | )
79 | .arg(
80 | Arg::new("clip_l_path")
81 | .long("clip_l")
82 | .value_name("PATH")
83 | .help("path to clip_l")
84 | .default_value(""),
85 | )
86 | .arg(
87 | Arg::new("t5xxl_path")
88 | .long("t5xxl")
89 | .value_name("PATH")
90 | .help("path to t5xxl")
91 | .default_value(""),
92 | )
93 | .arg(
94 | Arg::new("vae_path")
95 | .long("vae")
96 | .value_name("VAE")
97 | .help("path to vae")
98 | .default_value(""),
99 | )
100 | .arg(
101 | Arg::new("taesd_path")
102 | .long("taesd")
103 | .value_name("TAESD_PATH")
104 | .help("path to taesd. Using Tiny AutoEncoder for fast decoding (low quality).")
105 | .default_value(""),
106 | )
107 | .arg(
108 | Arg::new("control_net_path")
109 | .long("control-net")
110 | .value_name("CONTROL_PATH")
111 | .help("path to control net model.")
112 | .default_value(""),
113 | )
114 | .arg(
115 | Arg::new("embeddings_path")
116 | .long("embd-dir")
117 | .value_name("EMBEDDING_PATH")
118 | .help("path to embeddings.")
119 | .default_value(""),
120 | )
121 | .arg(
122 | Arg::new("stacked_id_embd_dir")
123 | .long("stacked-id-embd-dir")
124 | .value_name("DIR")
125 | .help("path to PHOTOMAKER stacked id embeddings.")
126 | .default_value(""),
127 | )
128 | .arg(
129 | Arg::new("input_id_images_dir")
130 | .long("input-id-images-dir")
131 | .value_name("DIR")
132 | .help("path to PHOTOMAKER input id images dir.")
133 | .default_value(""),
134 | )
135 | .arg(
136 | Arg::new("normalize_input")
137 | .long("normalize-input")
138 | .help("normalize PHOTOMAKER input id images.")
139 | .action(ArgAction::SetTrue),
140 | )
141 | .arg(
142 | Arg::new("upscale_model")
143 | .long("upscale-model")
144 | .value_name("ESRGAN_PATH")
145 | .help("path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.")
146 | .default_value(""),
147 | )
148 | .arg(
149 | Arg::new("upscale_repeats")
150 | .long("upscale-repeats")
151 | .value_parser(clap::value_parser!(i32))
152 | .value_name("UPSCALE_REPEATS")
153 | .help("Run the ESRGAN upscaler this many times (default 1).")
154 | .default_value("1"),
155 | )
156 | .arg(
157 | Arg::new("type")
158 | .long("type")
159 | .value_name("TYPE")
160 | .value_parser([
161 | "f32",
162 | "f16",
163 | "q4_0",
164 | "q4_1",
165 | "q5_0",
166 | "q5_1",
167 | "q8_0",
168 | "q8_1",
169 | "q2k",
170 | "q3k",
171 | "q4k",
172 | "q5k",
173 | "q6k",
174 | "q8k",
175 | "iq2Xxs",
176 | "iq2Xs",
177 | "iq3Xxs",
178 | "iq1S",
179 | "iq4N1",
180 | "iq3S",
181 | "iq2S",
182 | "iq4Xs",
183 | "i8",
184 | "i16",
185 | "i32",
186 | "i64",
187 | "f64",
188 | "iq1M",
189 | "bf16",
190 | "q4044",
191 | "q4048",
192 | "q4088",
193 | "count"
194 | ])
195 | .help("weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)If not specified, the default is the type of the weight file.")
196 | .default_value("count"),
197 | )
198 | .arg(
199 | Arg::new("lora_model_dir")
200 | .long("lora-model-dir")
201 | .value_name("DIR")
202 | .help("lora model directory.")
203 | .default_value(""),
204 | )
205 | .arg(
206 | Arg::new("init_img")
207 | .short('i')
208 | .long("init-img")
209 | .value_name("IMAGE")
210 | .help("path to the input image, required by img2img.")
211 | .default_value("./output.png"),
212 | )
213 | .arg(
214 | Arg::new("control_image")
215 | .long("control-image")
216 | .value_name("IMAGE")
217 | .help("path to image condition, control net.")
218 | .default_value(""),
219 | )
220 | .arg(
221 | Arg::new("output_path")
222 | .short('o')
223 | .long("output")
224 | .value_name("OUTPUT")
225 | .help("path to write result image to (default: ./output.png).")
226 | .default_value("./output2.png"),
227 | )
228 | .arg(
229 | Arg::new("prompt")
230 | .short('p')
231 | .long("prompt")
232 | .value_name("PROMPT")
233 | .help("the prompt to render.")
234 | .default_value("cat with blue eyes"),
235 | )
236 | .arg(
237 | Arg::new("negative_prompt")
238 | .short('n')
239 | .long("negative-prompt")
240 | .value_name("PROMPT")
241 | .help("the negative prompt.(default: '').")
242 | .default_value(""),
243 | )
244 | .arg(
245 | Arg::new("cfg_scale")
246 | .long("cfg-scale")
247 | .value_parser(clap::value_parser!(f32))
248 | .value_name("SCALE")
249 | .help("unconditional guidance scale: (default: 7.0)")
250 | .default_value("7.0"),
251 | )
252 | .arg(
253 | Arg::new("strength")
254 | .long("strength")
255 | .value_parser(clap::value_parser!(f32))
256 | .value_name("STRENGTH")
257 | .help("strength for noising/unnoising (default: 0.75).")
258 | .default_value("0.75"),
259 | )
260 | .arg(
261 | Arg::new("style_ratio")
262 | .long("style-ratio")
263 | .value_parser(clap::value_parser!(f32))
264 | .value_name("STYLE_RATIO")
265 | .help("strength for keeping input identity (default: 20%).")
266 | .default_value("20.0"),
267 | )
268 | .arg(
269 | Arg::new("control_strength")
270 | .long("control-strength")
271 | .value_parser(clap::value_parser!(f32))
272 | .value_name("CONTROL-STRENGTH")
273 | .help("strength to apply Control Net (default: 0.9) 1.0 corresponds to full destruction of information in init image.")
274 | .default_value("0.9"),
275 | )
276 | .arg(
277 | Arg::new("guidance")
278 | .long("guidance")
279 | .value_parser(clap::value_parser!(f32))
280 | .value_name("GUAIDANCE")
281 | .help("guidance scale")
282 | .default_value("3.5"),
283 | )
284 | .arg(
285 | Arg::new("height")
286 | .short('H')
287 | .long("height")
288 | .value_parser(clap::value_parser!(i32))
289 | .value_name("H")
290 | .help("image height, in pixel space (default: 512)")
291 | .default_value("512"),
292 | )
293 | .arg(
294 | Arg::new("width")
295 | .short('W')
296 | .long("width")
297 | .value_parser(clap::value_parser!(i32))
298 | .value_name("W")
299 | .help("image width, in pixel space (default: 512)")
300 | .default_value("512"),
301 | )
302 | .arg(
303 | Arg::new("sampling_method")
304 | .long("sampling-method")
305 | .value_parser([
306 | "euler_a",
307 | "euler",
308 | "heun",
309 | "dpm2",
310 | "dpm++2s_a",
311 | "dpm++2m",
312 | "dpm++2mv2",
313 | "ipndm",
314 | "ipndm_v",
315 | "lcm",
316 | ])
317 | .value_name("SAMPLING_METHOD")
318 | .help("the sampling method, include values {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, lcm}, sampling method (default: euler_a)")
319 | .default_value("euler_a"),
320 | )
321 | .arg(
322 | Arg::new("sample_steps")
323 | .long("steps")
324 | .value_parser(clap::value_parser!(i32))
325 | .value_name("STEPS")
326 | .help("number of sample steps (default: 20).")
327 | .default_value("20"),
328 | )
329 | .arg(
330 | Arg::new("rng_type")
331 | .long("rng")
332 | .value_name("RNG")
333 | .value_parser([
334 | "std_default",
335 | "cuda",
336 | ])
337 | .help("RNG (default: std_default).")
338 | .default_value("std_default"),
339 | )
340 | .arg(
341 | Arg::new("seed")
342 | .short('s')
343 | .long("seed")
344 | .value_parser(clap::value_parser!(i32))
345 | .value_name("SEED")
346 | .help("RNG seed (default: 42, use random seed for < 0).")
347 | .default_value("42"),
348 | )
349 | .arg(
350 | Arg::new("batch_count")
351 | .short('b')
352 | .long("batch-count")
353 | .value_parser(clap::value_parser!(i32))
354 | .value_name("BATCH_COUNT")
355 | .help("number of images to generate.")
356 | .default_value("1"),
357 | )
358 | .arg(
359 | Arg::new("schedule")
360 | .long("schedule")
361 | .value_name("SCHEDULE")
362 | .value_parser([
363 | "default",
364 | "discrete",
365 | "karras",
366 | "exponential",
367 | "ays",
368 | "gits"
369 | ])
370 | .help("Denoiser sigma schedule")
371 | .default_value("default"),
372 | )
373 | .arg(
374 | Arg::new("clip_skip")
375 | .long("clip-skip")
376 | .value_parser(clap::value_parser!(i32))
377 | .value_name("N")
378 | .help("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1), <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x.")
379 | .default_value("-1"),
380 | )
381 | .arg(
382 | Arg::new("vae_tiling")
383 | .long("vae-tiling")
384 | .help("process vae in tiles to reduce memory usage")
385 | .action(ArgAction::SetTrue),
386 | )
387 | .arg(
388 | Arg::new("vae_on_cpu")
389 | .long("vae-on-cpu")
390 | .help("keep vae in cpu (for low vram)")
391 | .action(ArgAction::SetTrue),
392 | )
393 | .arg(
394 | Arg::new("clip_on_cpu")
395 | .long("clip-on-cpu")
396 | .help("keep clip in cpu (for low vram)")
397 | .action(ArgAction::SetTrue),
398 | )
399 | .arg(
400 | Arg::new("control_net_cpu")
401 | .long("control-net-cpu")
402 | .help("keep controlnet in cpu (for low vram)")
403 | .action(ArgAction::SetTrue),
404 | )
405 | .arg(
406 | Arg::new("canny")
407 | .long("canny")
408 | .help("apply canny preprocessor (edge detection)")
409 | .action(ArgAction::SetTrue),
410 | )
411 | .arg(
412 | Arg::new("debug")
413 | .long("debug")
414 | .help("print debug informations")
415 | .action(ArgAction::SetTrue),
416 | )
417 | .after_help("run at the dir of .wasm, Example:wasmedge --dir .:. ./target/wasm32-wasi/release/wasmedge_stable_diffusion_example.wasm -m ../../models/stable-diffusion-v1-4-Q8_0.gguf -M img2img\n")
418 | .get_matches();
419 |
420 | //init the paraments--------------------------------------------------------------
421 | let mut options = Options::default();
422 |
423 | //mode, include "txt2img","img2img",----------"convert" is not yet-------.
424 | let sd_mode = matches.get_one::("mode").unwrap();
425 | let task = Task::from_str(sd_mode)?;
426 | options.mode = sd_mode.to_string();
427 |
428 | //n_threads
429 | let n_threads = matches.get_one::("n_threads").unwrap();
430 | options.n_threads = *n_threads as i32;
431 |
432 | //model
433 | let sd_model = matches.get_one::("model").unwrap();
434 | options.model_path = sd_model.to_string();
435 |
436 | //clip_l_path
437 | let clip_l_path = matches.get_one::("clip_l_path").unwrap();
438 | options.clip_l_path = clip_l_path.to_string();
439 |
440 | //t5xxl_path
441 | let t5xxl_path = matches.get_one::("t5xxl_path").unwrap();
442 | options.t5xxl_path = t5xxl_path.to_string();
443 |
444 | //diffusion_model_path
445 | let diffusion_model_path = matches.get_one::("diffusion_model_path").unwrap();
446 | options.diffusion_model_path = diffusion_model_path.to_string();
447 |
448 | //vae_path
449 | let vae_path = matches.get_one::("vae_path").unwrap();
450 | options.vae_path = vae_path.to_string();
451 |
452 | //taesd_path
453 | let taesd_path = matches.get_one::("taesd_path").unwrap();
454 | options.taesd_path = taesd_path.to_string();
455 |
456 | //control_net_path
457 | let control_net_path = matches.get_one::("control_net_path").unwrap();
458 | options.control_net_path = control_net_path.to_string();
459 |
460 | //embeddings_path
461 | let embeddings_path = matches.get_one::("embeddings_path").unwrap();
462 | options.embeddings_path = embeddings_path.to_string();
463 |
464 | //stacked_id_embd_dir
465 | let stacked_id_embd_dir = matches.get_one::("stacked_id_embd_dir").unwrap();
466 | options.stacked_id_embd_dir = stacked_id_embd_dir.to_string();
467 |
468 | //input_id_images_dir
469 | let input_id_images_dir = matches.get_one::("input_id_images_dir").unwrap();
470 | options.input_id_images_dir = input_id_images_dir.to_string();
471 |
472 | //normalize-input
473 | let normalize_input = matches.get_flag("normalize_input");
474 | options.normalize_input = normalize_input;
475 |
476 | //upscale_model
477 | let upscale_model = matches.get_one::("upscale_model").unwrap();
478 | options.upscale_model = upscale_model.to_string();
479 |
480 | //upscale_repeats
481 | let upscale_repeats = matches.get_one::("upscale_repeats").unwrap();
482 | if *upscale_repeats < 1 {
483 | return Err("Error: the upscale_repeats must be greater than 0".into());
484 | }
485 | options.upscale_repeats = *upscale_repeats as i32;
486 |
487 | //type
488 | let wtype_selected = matches.get_one::("type").unwrap();
489 | let wtype_found = WTYPE_METHODS
490 | .iter()
491 | .position(|&method| method == wtype_selected)
492 | .ok_or(format!("Invalid wtype: {}", wtype_selected))?;
493 | let wtype = SdTypeT::from_index(wtype_found)?;
494 | options.wtype = wtype;
495 |
496 | //lora_model_dir
497 | let lora_model_dir = matches.get_one::("lora_model_dir").unwrap();
498 | options.lora_model_dir = lora_model_dir.to_string();
499 |
500 | //init_img, used only for img2img
501 | let img = matches.get_one::("init_img").unwrap();
502 | if sd_mode == "img2img" {
503 | options.init_img = img.to_string();
504 | };
505 |
506 | //control_image
507 | let control_image = matches.get_one::("control_image").unwrap();
508 | options.control_image = control_image.to_string();
509 |
510 | //output_path
511 | let output_path = matches.get_one::("output_path").unwrap();
512 | options.output_path = output_path.to_string();
513 |
514 | //prompt
515 | let prompt = matches.get_one::("prompt").unwrap();
516 | options.prompt = prompt.to_string();
517 |
518 | //negative_prompt
519 | let negative_prompt = matches.get_one::("negative_prompt").unwrap();
520 | options.negative_prompt = negative_prompt.to_string();
521 |
522 | //cfg_scale
523 | let cfg_scale = matches.get_one::("cfg_scale").unwrap();
524 | options.cfg_scale = *cfg_scale as f32;
525 |
526 | //strength
527 | let strength = matches.get_one::("strength").unwrap();
528 | if *strength < 0.0 || *strength > 1.0 {
529 | return Err("Error: can only work with strength in [0.0, 1.0]".into());
530 | }
531 | options.strength = *strength as f32;
532 |
533 | //style_ratio
534 | let style_ratio = matches.get_one::("style_ratio").unwrap();
535 | if *style_ratio > 100.0 {
536 | return Err("Error: can only work with style_ratio in [0.0, 100.0]".into());
537 | }
538 | options.style_ratio = *style_ratio as f32;
539 |
540 | //control_strength
541 | let control_strength = matches.get_one::("control_strength").unwrap();
542 | if *control_strength > 1.0 {
543 | return Err("Error: can only work with control_strength in [0.0, 1.0]".into());
544 | }
545 | options.control_strength = *control_strength as f32;
546 |
547 | //guidance
548 | let guidance = matches.get_one::("guidance").unwrap();
549 | options.guidance = *guidance as f32;
550 |
551 | //height
552 | let height = matches.get_one::("height").unwrap();
553 | options.height = *height as i32;
554 |
555 | //width
556 | let width = matches.get_one::("width").unwrap();
557 | options.width = *width as i32;
558 |
559 | //sampling_method
560 | let sampling_method_selected = matches.get_one::("sampling_method").unwrap();
561 | let sample_method_found = SAMPLE_METHODS
562 | .iter()
563 | .position(|&method| method == sampling_method_selected)
564 | .ok_or(format!(
565 | "Invalid sampling method: {}",
566 | sampling_method_selected
567 | ))?;
568 | let sample_method = SampleMethodT::from_index(sample_method_found)?;
569 | options.sample_method = sample_method;
570 |
571 | //sample_steps
572 | let sample_steps = matches.get_one::("sample_steps").unwrap();
573 | if *sample_steps <= 0 {
574 | return Err("Error: the sample_steps must be greater than 0".into());
575 | }
576 | options.sample_steps = *sample_steps as i32;
577 |
578 | //rng_type
579 | let mut rng_type = RngTypeT::StdDefaultRng;
580 | let rng_type_str = matches.get_one::("rng_type").unwrap();
581 | if rng_type_str == "cuda" {
582 | rng_type = RngTypeT::CUDARng;
583 | }
584 | options.rng_type = rng_type;
585 |
586 | //seed
587 | let seed_str = matches.get_one::("seed").unwrap();
588 | let mut seed = *seed_str;
589 | // let mut seed: i32 = seed_str.parse().expect("Failed to parse seed as i32");
590 | if seed < 0 {
591 | let current_time = SystemTime::now()
592 | .duration_since(UNIX_EPOCH)
593 | .expect("Time went backwards");
594 | let current_time_secs = current_time.as_secs() as u32;
595 | let mut rng = rand::thread_rng();
596 | // Limit the result to i32 range
597 | seed = ((rng.gen::() ^ current_time_secs) & i32::MAX as u32) as i32;
598 | }
599 | options.seed = seed;
600 |
601 | //batch_count
602 | let batch_count = matches.get_one::("batch_count").unwrap();
603 | options.batch_count = *batch_count as i32;
604 |
605 | //schedule
606 | let schedule_selected = matches.get_one::("schedule").unwrap();
607 | let schedule_found = SCHEDULE_STR
608 | .iter()
609 | .position(|&method| method == schedule_selected)
610 | .ok_or(format!("Invalid sampling method: {}", schedule_selected))?;
611 | // Convert an index to an enumeration value
612 | let schedule = ScheduleT::from_index(schedule_found)?;
613 | options.schedule = schedule;
614 |
615 | //clip_skip
616 | let clip_skip = matches.get_one::("clip_skip").unwrap();
617 | options.clip_skip = *clip_skip as i32;
618 |
619 | //vae_tiling
620 | let vae_tiling = matches.get_flag("vae_tiling");
621 | options.vae_tiling = vae_tiling;
622 |
623 | //control_net_cpu
624 | let control_net_cpu = matches.get_flag("control_net_cpu");
625 | options.control_net_cpu = control_net_cpu;
626 |
627 | //canny
628 | let canny = matches.get_flag("canny");
629 | options.canny = canny;
630 |
631 | //clip_on_cpu
632 | let clip_on_cpu = matches.get_flag("clip_on_cpu");
633 | options.clip_on_cpu = clip_on_cpu;
634 |
635 | //vae_on_cpu
636 | let vae_on_cpu = matches.get_flag("vae_on_cpu");
637 | options.vae_on_cpu = vae_on_cpu;
638 |
639 | //debug
640 | let debug = matches.get_flag("debug");
641 |
642 | //DEBUG: print options from CL
643 | if debug {
644 | print_params(&mut options);
645 | }
646 |
647 | //------------------------------- run the model ----------------------------------------
648 | match options.mode.as_str() {
649 | "txt2img" => {
650 | let context = SDBuidler::new(task, &options.model_path)?
651 | .with_clip_l_path(options.clip_l_path)?
652 | .with_t5xxl_path(options.t5xxl_path)?
653 | .with_vae_path(options.vae_path)?
654 | .with_taesd_path(options.taesd_path)?
655 | .with_lora_model_dir(options.lora_model_dir)?
656 | .with_embeddings_path(options.embeddings_path)?
657 | .with_stacked_id_embeddings_path(options.stacked_id_embd_dir)?
658 | .use_control_net(options.control_net_path, options.control_net_cpu)?
659 | .with_n_threads(options.n_threads)
660 | .with_wtype(options.wtype)
661 | .with_rng_type(options.rng_type)
662 | .with_schedule(options.schedule)
663 | .enable_vae_tiling(options.vae_tiling)
664 | .enable_clip_on_cpu(options.clip_on_cpu)
665 | .enable_vae_on_cpu(options.vae_on_cpu)
666 | .build();
667 | if let Context::TextToImage(mut text_to_image) = context.create_context().unwrap() {
668 | text_to_image
669 | .set_prompt(options.prompt)
670 | .set_guidance(options.guidance)
671 | .set_width(options.width)
672 | .set_height(options.height)
673 | .set_control_image(ImageType::Path(options.control_image))
674 | .set_negative_prompt(options.negative_prompt)
675 | .set_clip_skip(options.clip_skip)
676 | .set_cfg_scale(options.cfg_scale)
677 | .set_sample_method(options.sample_method)
678 | .set_sample_steps(options.sample_steps)
679 | .set_seed(options.seed)
680 | .set_batch_count(options.batch_count)
681 | .set_control_strength(options.control_strength)
682 | .set_style_ratio(options.style_ratio)
683 | .enable_normalize_input(options.normalize_input)
684 | .set_input_id_images_dir(options.input_id_images_dir)
685 | .enable_canny_preprocess(options.canny)
686 | .set_upscale_model(options.upscale_model)
687 | .set_upscale_repeats(options.upscale_repeats)
688 | .set_output_path(options.output_path)
689 | .generate()
690 | .unwrap();
691 | }
692 | }
693 | "img2img" => {
694 | let context = SDBuidler::new(task, &options.model_path)?
695 | .with_clip_l_path(options.clip_l_path)?
696 | .with_t5xxl_path(options.t5xxl_path)?
697 | .with_vae_path(options.vae_path)?
698 | .with_taesd_path(options.taesd_path)?
699 | .with_lora_model_dir(options.lora_model_dir)?
700 | .with_embeddings_path(options.embeddings_path)?
701 | .with_stacked_id_embeddings_path(options.stacked_id_embd_dir)?
702 | .use_control_net(options.control_net_path, options.control_net_cpu)?
703 | .with_n_threads(options.n_threads)
704 | .with_wtype(options.wtype)
705 | .with_rng_type(options.rng_type)
706 | .with_schedule(options.schedule)
707 | .enable_vae_tiling(options.vae_tiling)
708 | .enable_clip_on_cpu(options.clip_on_cpu)
709 | .enable_vae_on_cpu(options.vae_on_cpu)
710 | .build();
711 | if let Context::ImageToImage(mut image_to_image) = context.create_context().unwrap() {
712 | image_to_image
713 | .set_prompt(options.prompt)
714 | .set_guidance(options.guidance)
715 | .set_width(options.width)
716 | .set_height(options.height)
717 | .set_control_image(ImageType::Path(options.control_image))
718 | .set_negative_prompt(options.negative_prompt)
719 | .set_clip_skip(options.clip_skip)
720 | .set_cfg_scale(options.cfg_scale)
721 | .set_sample_method(options.sample_method)
722 | .set_sample_steps(options.sample_steps)
723 | .set_seed(options.seed)
724 | .set_batch_count(options.batch_count)
725 | .set_control_strength(options.control_strength)
726 | .set_style_ratio(options.style_ratio)
727 | .enable_normalize_input(options.normalize_input)
728 | .set_input_id_images_dir(options.input_id_images_dir)
729 | .enable_canny_preprocess(options.canny)
730 | .set_upscale_model(options.upscale_model)
731 | .set_upscale_repeats(options.upscale_repeats)
732 | .set_output_path(options.output_path)
733 | //addtional options for img2img
734 | .set_image(ImageType::Path(options.init_img))
735 | .set_strength(options.strength)
736 | .generate()
737 | .unwrap();
738 | }
739 | }
740 | "convert" => {
741 | let quantization = Quantization::new(
742 | &options.model_path,
743 | options.vae_path,
744 | &options.output_path,
745 | options.wtype,
746 | );
747 | quantization.convert().unwrap();
748 | }
749 | _ => {
750 | println!("Error: this mode isn't supported!");
751 | }
752 | }
753 | return Ok(());
754 | }
755 |
/// Parsed and validated command-line options for one run.
///
/// Each field mirrors a CLI flag declared in `main`. Optional path fields hold
/// an empty string when the corresponding flag was not given.
#[derive(Debug)]
struct Options {
    // --threads; its help text says values <= 0 select the number of physical cores.
    n_threads: i32,
    // --mode: "txt2img", "img2img" or "convert".
    mode: String,
    // --model: path to the full model file (also the input of "convert").
    model_path: String,
    clip_l_path: String,
    t5xxl_path: String,
    // --diffusion-model: standalone diffusion model path.
    diffusion_model_path: String,
    vae_path: String,
    taesd_path: String,
    control_net_path: String,
    // --upscale-model: ESRGAN model applied after generation.
    upscale_model: String,
    embeddings_path: String,
    // PHOTOMAKER inputs.
    stacked_id_embd_dir: String,
    input_id_images_dir: String,
    // --type; SdTypeCount ("count") means "keep the weight file's own type".
    wtype: SdTypeT,
    lora_model_dir: String,
    output_path: String,
    // Only filled in for img2img; empty for the other modes.
    init_img: String,
    control_image: String,

    // Generation parameters.
    prompt: String,
    negative_prompt: String,
    cfg_scale: f32,
    guidance: f32,
    // Percentage in [0, 100] for keeping input identity.
    style_ratio: f32,
    // <= 0 means unspecified (per the --clip-skip help text).
    clip_skip: i32,
    width: i32,
    height: i32,
    batch_count: i32,

    // Sampler configuration.
    sample_method: SampleMethodT,
    schedule: ScheduleT,
    sample_steps: i32,
    // img2img noising strength in [0, 1].
    strength: f32,
    // ControlNet strength in [0, 1].
    control_strength: f32,
    rng_type: RngTypeT,
    // Always non-negative once parsed: a negative --seed is replaced by a random one.
    seed: i32,
    // Boolean flags (ArgAction::SetTrue).
    vae_tiling: bool,
    control_net_cpu: bool,
    normalize_input: bool,
    clip_on_cpu: bool,
    vae_on_cpu: bool,
    canny: bool,
    // Number of ESRGAN passes; validated to be >= 1.
    upscale_repeats: i32,
}
802 |
803 | impl Default for Options {
804 | fn default() -> Self {
805 | Self {
806 | n_threads: -1,
807 | mode: String::from("txt2img"),
808 | model_path: String::from(""),
809 | clip_l_path: String::from(""),
810 | t5xxl_path: String::from(""),
811 | diffusion_model_path: String::from(""),
812 | vae_path: String::from(""),
813 | taesd_path: String::from(""),
814 | control_net_path: String::from(""),
815 | upscale_model: String::from(""),
816 | embeddings_path: String::from(""),
817 | stacked_id_embd_dir: String::from(""),
818 | input_id_images_dir: String::from(""),
819 | wtype: SdTypeT::SdTypeCount,
820 | lora_model_dir: String::from(""),
821 | output_path: String::from(""),
822 | init_img: String::from(""),
823 | control_image: String::from(""),
824 |
825 | prompt: String::from(""),
826 | negative_prompt: String::from(""),
827 | cfg_scale: 7.0,
828 | guidance: 3.5,
829 | style_ratio: 20.0,
830 | clip_skip: -1,
831 | width: 512,
832 | height: 512,
833 | batch_count: 1,
834 |
835 | sample_method: SampleMethodT::EULERA,
836 | schedule: ScheduleT::DEFAULT,
837 | sample_steps: 20,
838 | strength: 0.75,
839 | control_strength: 0.9,
840 | rng_type: RngTypeT::StdDefaultRng,
841 | seed: 42,
842 | vae_tiling: false,
843 | control_net_cpu: false,
844 | normalize_input: false,
845 | clip_on_cpu: false,
846 | vae_on_cpu: false,
847 | canny: false,
848 | upscale_repeats: 1,
849 | }
850 | }
851 | }
852 |
853 | fn print_params(params: &mut Options) {
854 | println!("Option:");
855 | println!("[INFO] n_threads: {}", params.n_threads);
856 | println!("[INFO] mode: {}", params.mode);
857 | println!("[INFO] model_path: {}", params.model_path);
858 | println!(
859 | "[INFO] diffusion_model_path:{}",
860 | params.diffusion_model_path
861 | );
862 | println!("[INFO] clip_l_path: {}", params.clip_l_path);
863 | println!("[INFO] t5xxl_path: {}", params.t5xxl_path);
864 | println!("[INFO] vae_path: {}", params.vae_path);
865 | println!("[INFO] taesd_path: {}", params.taesd_path);
866 | println!("[INFO] control_net_path: {}", params.control_net_path);
867 | println!("[INFO] upscale_model: {}", params.upscale_model);
868 | println!("[INFO] embeddings_path: {}", params.embeddings_path);
869 | println!("[INFO] stacked_id_embd: {}", params.stacked_id_embd_dir);
870 | println!("[INFO] input_id_images: {}", params.input_id_images_dir);
871 | println!("[INFO] wtype: {:?}", params.wtype);
872 | println!("[INFO] lora_model_dir: {}", params.lora_model_dir);
873 | println!("[INFO] output_path: {}", params.output_path);
874 | println!("[INFO] init_img: {}", params.init_img);
875 | println!("[INFO] control_image: {}", params.control_image);
876 | println!("[INFO] prompt: {}", params.prompt);
877 | println!("[INFO] negative_prompt: {}", params.negative_prompt);
878 | println!("[INFO] cfg_scale: {}", params.cfg_scale);
879 | println!("[INFO] guidance: {}", params.guidance);
880 | println!("[INFO] style_ratio: {}", params.style_ratio);
881 | println!("[INFO] clip_skip: {}", params.clip_skip);
882 | println!("[INFO] width: {}", params.width);
883 | println!("[INFO] height: {}", params.height);
884 | println!("[INFO] batch_count: {}", params.batch_count);
885 | println!("[INFO] sample_method: {:?}", params.sample_method);
886 | println!("[INFO] schedule: {:?}", params.schedule);
887 | println!("[INFO] sample_steps: {}", params.sample_steps);
888 | println!("[INFO] strength: {}", params.strength);
889 | println!("[INFO] control_strength: {}", params.control_strength);
890 | println!("[INFO] rng_type: {:?}", params.rng_type);
891 | println!("[INFO] seed: {}", params.seed);
892 | println!("[INFO] vae_tiling: {}", params.vae_tiling);
893 | println!("[INFO] control_net_cpu: {}", params.control_net_cpu);
894 | println!("[INFO] normalize_input: {}", params.normalize_input);
895 | println!("[INFO] clip_on_cpu: {}", params.clip_on_cpu);
896 | println!("[INFO] vae_on_cpu: {}", params.vae_on_cpu);
897 | println!("[INFO] canny: {}", params.canny);
898 | println!("[INFO] upscale_repeats: {}", params.upscale_repeats);
899 | }
900 |
--------------------------------------------------------------------------------