├── .github └── workflows │ └── rust-ci.yml ├── .gitignore ├── Cargo.toml ├── LICENSE-APACHE ├── LICENSE-MIT ├── README.md ├── build.rs ├── clippy.toml ├── examples ├── controlnet │ └── main.rs ├── stable-diffusion-img2img │ └── main.rs ├── stable-diffusion-inpaint │ └── main.rs ├── stable-diffusion │ └── main.rs └── tensor-tools.rs ├── media ├── in_img2img.jpg ├── out_img2img.jpg ├── out_inpaint.jpg ├── robot11.jpg ├── robot13.jpg ├── robot3.jpg ├── robot4.jpg ├── robot7.jpg ├── robot8.jpg ├── vermeer-edges.png ├── vermeer-out1.jpg ├── vermeer-out2.jpg ├── vermeer-out3.jpg ├── vermeer-out4.jpg ├── vermeer-out5.jpg └── vermeer.jpg ├── rustfmt.toml ├── scripts ├── download_weights_1.5.sh ├── download_weights_2.1.sh └── get_weights.py └── src ├── lib.rs ├── models ├── attention.rs ├── controlnet.rs ├── embeddings.rs ├── mod.rs ├── resnet.rs ├── unet_2d.rs ├── unet_2d_blocks.rs └── vae.rs ├── pipelines ├── mod.rs └── stable_diffusion.rs ├── schedulers ├── ddim.rs ├── ddpm.rs ├── dpmsolver_multistep.rs ├── euler_ancestral_discrete.rs ├── euler_discrete.rs ├── heun_discrete.rs ├── integrate.rs ├── k_dpm_2_ancestral_discrete.rs ├── k_dpm_2_discrete.rs ├── lms_discrete.rs ├── mod.rs └── pndm.rs ├── transformers ├── clip.rs └── mod.rs └── utils.rs /.github/workflows/rust-ci.yml: -------------------------------------------------------------------------------- 1 | on: [push, pull_request] 2 | 3 | name: Continuous integration 4 | 5 | jobs: 6 | check: 7 | name: Check 8 | runs-on: ${{ matrix.os }} 9 | strategy: 10 | matrix: 11 | os: [ubuntu-latest, windows-2019, macOS-latest] 12 | rust: [stable, nightly] 13 | steps: 14 | - uses: actions/checkout@v2 15 | - uses: actions-rs/toolchain@v1 16 | with: 17 | profile: minimal 18 | toolchain: ${{ matrix.rust }} 19 | override: true 20 | - uses: actions-rs/cargo@v1 21 | with: 22 | command: check 23 | 24 | test: 25 | name: Test Suite 26 | runs-on: ${{ matrix.os }} 27 | strategy: 28 | matrix: 29 | os: [ubuntu-latest, windows-2019, macOS-latest] 30 | rust: [stable, nightly] 31 | steps: 32 | - uses: actions/checkout@v2 33 | - uses: actions-rs/toolchain@v1 34 | with: 35 | profile: minimal 36 | toolchain: ${{ matrix.rust }} 37 | override: true 38 | - uses: actions-rs/cargo@v1 39 | with: 40 | command: test 41 | 42 | fmt: 43 | name: Rustfmt 44 | runs-on: ubuntu-latest 45 | steps: 46 | - uses: actions/checkout@v2 47 | - uses: actions-rs/toolchain@v1 48 | with: 49 | profile: minimal 50 | toolchain: stable 51 | override: true 52 | - run: rustup component add rustfmt 53 | - uses: actions-rs/cargo@v1 54 | with: 55 | command: fmt 56 | args: --all -- --check 57 | 58 | clippy: 59 | name: Clippy 60 | runs-on: ubuntu-latest 61 | steps: 62 | - uses: actions/checkout@v2 63 | - uses: actions-rs/toolchain@v1 64 | with: 65 | profile: minimal 66 | toolchain: stable 67 | override: true 68 | - run: rustup component add clippy 69 | - uses: actions-rs/cargo@v1 70 | with: 71 | command: clippy 72 | args: --tests --examples --features=clap -- -D warnings 73 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | target/ 2 | target2/ 3 | _build/ 4 | data/ 5 | data*/ 6 | gen/.merlin 7 | **/*.rs.bk 8 | *.swp 9 | *.swo 10 | Cargo.lock 11 | __pycache__ 12 | *~ 13 | .*~ 14 | sd_*jpg 15 | sd_*png 16 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | 
[package] 2 | name = "diffusers" 3 | version = "0.3.1" 4 | authors = ["Laurent Mazare "] 5 | edition = "2021" 6 | build = "build.rs" 7 | 8 | description = "Rust implementation of the Diffusers library using Torch." 9 | repository = "https://github.com/LaurentMazare/diffusers-rs" 10 | keywords = ["pytorch", "deep-learning", "machine-learning", "diffusion", "transformers"] 11 | categories = ["science"] 12 | license = "MIT/Apache-2.0" 13 | readme = "README.md" 14 | 15 | exclude = [ 16 | "media/*", 17 | ] 18 | 19 | [dependencies] 20 | anyhow = "1" 21 | thiserror = "1" 22 | regex = "1.6.0" 23 | tch = "0.13" 24 | torch-sys = { version = "0.13", features = ["download-libtorch"] } 25 | 26 | clap = { version = "4.0.19", optional = true, features = ["derive"] } 27 | image = { version = "0.24.6", optional = true } 28 | imageproc = { version = "0.23.0", optional = true } 29 | 30 | [[example]] 31 | name = "stable-diffusion" 32 | required-features = ["clap"] 33 | 34 | [[example]] 35 | name = "stable-diffusion-img2img" 36 | required-features = ["clap"] 37 | 38 | [[example]] 39 | name = "stable-diffusion-inpaint" 40 | required-features = ["clap"] 41 | 42 | [[example]] 43 | name = "controlnet" 44 | required-features = ["clap", "imageproc"] 45 | 46 | [features] 47 | doc-only = ["tch/doc-only"] 48 | 49 | [package.metadata.docs.rs] 50 | features = ["doc-only"] 51 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /LICENSE-MIT: -------------------------------------------------------------------------------- 1 | Permission is hereby granted, free of charge, to any 2 | person obtaining a copy of this software and associated 3 | documentation files (the "Software"), to deal in the 4 | Software without restriction, including without 5 | limitation the rights to use, copy, modify, merge, 6 | publish, distribute, sublicense, and/or sell copies of 7 | the Software, and to permit persons to whom the Software 8 | is furnished to do so, subject to the following 9 | conditions: 10 | 11 | The above copyright notice and this permission notice 12 | shall be included in all copies or substantial portions 13 | of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF 16 | ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED 17 | TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A 18 | PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT 19 | SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 20 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION 21 | OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR 22 | IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 23 | DEALINGS IN THE SOFTWARE. 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # diffusers-rs: A Diffusers API in Rust/Torch 2 | 3 | [![Build Status](https://github.com/LaurentMazare/diffusers-rs/workflows/Continuous%20integration/badge.svg)](https://github.com/LaurentMazare/diffusers-rs/actions) 4 | [![Latest version](https://img.shields.io/crates/v/diffusers.svg)](https://crates.io/crates/diffusers) 5 | [![Documentation](https://docs.rs/diffusers/badge.svg)](https://docs.rs/diffusers) 6 | ![License](https://img.shields.io/crates/l/diffusers.svg) 7 | 8 | ![rusty robot holding a torch](media/robot13.jpg) 9 | 10 | _A rusty robot holding a fire torch_, generated by stable diffusion using Rust and libtorch. 11 | 12 | The `diffusers` crate is a Rust equivalent to Huggingface's amazing 13 | [diffusers](https://github.com/huggingface/diffusers) Python library. 14 | It is based on the [tch crate](https://github.com/LaurentMazare/tch-rs/). 15 | The implementation supports running Stable Diffusion v1.5 and v2.1. 16 | 17 | ## Getting the weights 18 | 19 | The weight files can be retrieved from the HuggingFace model repos and should be 20 | moved in the `data/` directory. 21 | - For Stable Diffusion v2.1, get the `bpe_simple_vocab_16e6.txt`, 22 | `clip_v2.1.safetensors`, `unet_v2.1.safetensors`, and `vae_v2.1.safetensors` 23 | files from the 24 | [v2.1 repo](https://huggingface.co/lmz/rust-stable-diffusion-v2-1/tree/main/weights). 25 | - For Stable Diffusion v1.5, get the `bpe_simple_vocab_16e6.txt`, 26 | `pytorch_model.safetensors`, `unet.safetensors`, and `vae.safetensors` 27 | files from this 28 | [v1.5 repo](https://huggingface.co/lmz/rust-stable-diffusion-v1-5/tree/main/weights). 29 | - Alternatively, you can run the following python script. 30 | ```bash 31 | # Add --sd_version 1.5 to get the v1.5 weights rather than the v2.1. 32 | python3 ./scripts/get_weights.py 33 | ``` 34 | 35 | ## Running some example. 36 | 37 | ```bash 38 | cargo run --example stable-diffusion --features clap -- --prompt "A rusty robot holding a fire torch." 39 | ``` 40 | 41 | The final image is named `sd_final.png` by default. 
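
For reference, the snippet below sketches what the example does under the hood with the crate's API, condensed from `examples/stable-diffusion/main.rs`. The hard-coded v2.1 weight paths, the single `device`, and the use of `anyhow` for error handling are simplifying assumptions made here for brevity; the real example exposes all of these as command-line flags.

```rust
use diffusers::pipelines::stable_diffusion;
use diffusers::transformers::clip;
use tch::{nn::Module, Device, Kind, Tensor};

fn generate(prompt: &str, n_steps: usize, device: Device) -> anyhow::Result<()> {
    // Configuration, scheduler and tokenizer for Stable Diffusion v2.1.
    let sd_config = stable_diffusion::StableDiffusionConfig::v2_1(None, None, None);
    let scheduler = sd_config.build_scheduler(n_steps);
    let tokenizer =
        clip::Tokenizer::create("data/bpe_simple_vocab_16e6.txt".to_string(), &sd_config.clip)?;

    // Encode the prompt and an empty prompt for classifier-free guidance.
    let tokens: Vec<i64> = tokenizer.encode(prompt)?.into_iter().map(|x| x as i64).collect();
    let tokens = Tensor::from_slice(&tokens).view((1, -1)).to(device);
    let uncond: Vec<i64> = tokenizer.encode("")?.into_iter().map(|x| x as i64).collect();
    let uncond = Tensor::from_slice(&uncond).view((1, -1)).to(device);

    let _no_grad = tch::no_grad_guard();
    let clip_weights = "data/clip_v2.1.safetensors".to_string();
    let vae_weights = "data/vae_v2.1.safetensors".to_string();
    let unet_weights = "data/unet_v2.1.safetensors".to_string();
    let text_model = sd_config.build_clip_transformer(&clip_weights, device)?;
    let text_embeddings =
        Tensor::cat(&[text_model.forward(&uncond), text_model.forward(&tokens)], 0).to(device);
    let vae = sd_config.build_vae(&vae_weights, device)?;
    let unet = sd_config.build_unet(&unet_weights, device, 4)?;

    // Start from Gaussian noise scaled by the scheduler's initial sigma.
    let mut latents = Tensor::randn(
        [1, 4, sd_config.height / 8, sd_config.width / 8],
        (Kind::Float, device),
    );
    latents *= scheduler.init_noise_sigma();

    for &timestep in scheduler.timesteps().iter() {
        // Run the UNet on the unconditional and conditional embeddings in one batch of two,
        // then combine the two noise predictions using the guidance scale (7.5).
        let latent_model_input = Tensor::cat(&[&latents, &latents], 0);
        let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep);
        let noise_pred = unet.forward(&latent_model_input, timestep as f64, &text_embeddings);
        let noise_pred = noise_pred.chunk(2, 0);
        let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
        let noise_pred = noise_pred_uncond + (noise_pred_text - noise_pred_uncond) * 7.5;
        latents = scheduler.step(&noise_pred, timestep, &latents);
    }

    // Decode the latents with the VAE and save the resulting image.
    let image = vae.decode(&(&latents / 0.18215));
    let image = (image / 2 + 0.5).clamp(0., 1.).to_device(Device::Cpu);
    let image = (image * 255.).to_kind(Kind::Uint8);
    tch::vision::image::save(&image, "sd_final.png")?;
    Ok(())
}
```

Calling something like `generate("A rusty robot holding a fire torch", 30, Device::cuda_if_available())` then mirrors the command above; the img2img, inpainting and ControlNet examples reuse this same loop and mostly differ in how the initial latents and the UNet inputs are prepared.
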
42 | The default scheduler is the Denoising Diffusion Implicit Model (DDIM) scheduler. The
43 | original paper and some code can be found in the [associated repo](https://github.com/ermongroup/ddim).
44 | 
45 | This generates some images of rusty robots holding some torches!
46 | 
47 | 
48 | 
49 | 
50 | 
51 | ## Image to Image Pipeline
52 | 
53 | The stable diffusion model can also be used to generate an image based on
54 | another image. The following command runs this image-to-image pipeline:
55 | 
56 | ```bash
57 | cargo run --example stable-diffusion-img2img --features clap -- --input-image media/in_img2img.jpg
58 | ```
59 | 
60 | The default prompt is "A fantasy landscape, trending on artstation.", but can
61 | be changed via the `--prompt` flag.
62 | 
63 | ![img2img input](media/in_img2img.jpg)
64 | ![img2img output](media/out_img2img.jpg)
65 | 
66 | ## Inpainting Pipeline
67 | 
68 | Inpainting modifies an existing image based on a prompt, repainting only the part of the
69 | initial image specified by a mask.
70 | This requires different UNet weights, `unet-inpaint.safetensors`, which can also be retrieved from this
71 | [repo](https://huggingface.co/lmz/rust-stable-diffusion-v1-5) and should also be
72 | placed in the `data/` directory.
73 | 
74 | The following commands download a sample image and mask, then run the inpainting pipeline:
75 | 
76 | ```bash
77 | wget https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png -O sd_input.png
78 | wget https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png -O sd_mask.png
79 | cargo run --example stable-diffusion-inpaint --features clap -- --input-image sd_input.png --mask-image sd_mask.png
80 | ```
81 | 
82 | The default prompt is "Face of a yellow cat, high resolution, sitting on a park bench.", but can
83 | be changed via the `--prompt` flag.
84 | 
85 | 
86 | 
87 | ![inpaint output](media/out_inpaint.jpg)
88 | 
89 | ## ControlNet Pipeline
90 | 
91 | The [ControlNet](https://github.com/lllyasviel/ControlNet) architecture can be
92 | used to control how stable diffusion generates images. It is meant to be used with
93 | the weights for stable diffusion 1.5 (see how to get these above). Additional
94 | weights have to be retrieved from this [HuggingFace
95 | repo](https://huggingface.co/lllyasviel/sd-controlnet-canny/blob/main/diffusion_pytorch_model.safetensors)
96 | and copied to `data/controlnet.safetensors`.
97 | 
98 | The ControlNet pipeline takes a sample image as input; in the default mode it
99 | performs edge detection on this image using the [Canny edge
100 | detector](https://en.wikipedia.org/wiki/Canny_edge_detector) and uses the
101 | resulting edge image as a guide.
102 | 
103 | ```bash
104 | cargo run --example controlnet --features clap,image,imageproc -- \
105 | --prompt "a rusty robot, lit by a fire torch, hd, very detailed" \
106 | --input-image media/vermeer.jpg
107 | ```
108 | The `media/vermeer.jpg` image is the well-known painting shown on the left-hand side;
109 | running edge detection on it produces the image on the right-hand side.
110 | 
111 | 
112 | 
113 | Using only the edge detection image, the ControlNet model generates the following
114 | samples.
115 | 
116 | 
122 | 
123 | ## FAQ
124 | 
125 | ### Memory Issues
126 | 
127 | This requires a GPU with more than 8GB of memory; as a fallback, the CPU version can be used
128 | but is slower.
129 | 130 | ```bash 131 | cargo run --example stable-diffusion --features clap -- --prompt "A very rusty robot holding a fire torch." --cpu all 132 | ``` 133 | 134 | For a GPU with 8GB, one can use the [fp16 weights for the UNet](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/fp16/unet) and put only the UNet on the GPU. 135 | 136 | ```bash 137 | PYTORCH_CUDA_ALLOC_CONF=garbage_collection_threshold:0.6,max_split_size_mb:128 RUST_BACKTRACE=1 CARGO_TARGET_DIR=target2 cargo run \ 138 | --example stable-diffusion --features clap -- --cpu vae --cpu clip \ 139 | --unet-weights data/unet-fp16.safetensors 140 | ``` 141 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | fn main() { 2 | let os = std::env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); 3 | match os.as_str() { 4 | "linux" | "windows" => { 5 | if let Some(lib_path) = std::env::var_os("DEP_TCH_LIBTORCH_LIB") { 6 | println!("cargo:rustc-link-arg=-Wl,-rpath={}", lib_path.to_string_lossy()); 7 | } 8 | println!("cargo:rustc-link-arg=-Wl,--no-as-needed"); 9 | println!("cargo:rustc-link-arg=-Wl,--copy-dt-needed-entries"); 10 | println!("cargo:rustc-link-arg=-ltorch"); 11 | } 12 | _ => {} 13 | } 14 | } 15 | -------------------------------------------------------------------------------- /clippy.toml: -------------------------------------------------------------------------------- 1 | too-many-arguments-threshold = 20 2 | -------------------------------------------------------------------------------- /examples/controlnet/main.rs: -------------------------------------------------------------------------------- 1 | // The additional weight files can be found on HuggingFace hub: 2 | // https://huggingface.co/lllyasviel/sd-controlnet-canny/blob/main/diffusion_pytorch_model.safetensors 3 | // This has to be copied in data/controlnet.safetensors 4 | use clap::Parser; 5 | use diffusers::pipelines::stable_diffusion; 6 | use diffusers::transformers::clip; 7 | use tch::{nn, nn::Module, Device, Kind, Tensor}; 8 | 9 | const GUIDANCE_SCALE: f64 = 7.5; 10 | 11 | #[derive(Parser)] 12 | #[command(author, version, about, long_about = None)] 13 | struct Args { 14 | /// The input image. 15 | #[arg(long, value_name = "FILE")] 16 | input_image: String, 17 | 18 | /// The prompt to be used for image generation. 19 | #[arg( 20 | long, 21 | default_value = "A very realistic photo of a rusty robot walking on a sandy beach" 22 | )] 23 | prompt: String, 24 | 25 | /// When set, use the CPU for the listed devices, can be 'all', 'unet', 'clip', etc. 26 | /// Multiple values can be set. 27 | #[arg(long)] 28 | cpu: Vec, 29 | 30 | /// The height in pixels of the generated image. 31 | #[arg(long)] 32 | height: Option, 33 | 34 | /// The width in pixels of the generated image. 35 | #[arg(long)] 36 | width: Option, 37 | 38 | /// The UNet weight file, in .ot or .safetensors format. 39 | #[arg(long, value_name = "FILE", default_value = "data/unet.safetensors")] 40 | unet_weights: String, 41 | 42 | /// The ControlNet weight file, in .ot or .safetensors format. 43 | #[arg(long, value_name = "FILE", default_value = "data/controlnet.safetensors")] 44 | controlnet_weights: String, 45 | 46 | /// The CLIP weight file, in .ot or .safetensors format. 47 | #[arg(long, value_name = "FILE", default_value = "data/pytorch_model.safetensors")] 48 | clip_weights: String, 49 | 50 | /// The VAE weight file, in .ot or .safetensors format. 
51 | #[arg(long, value_name = "FILE", default_value = "data/vae.safetensors")] 52 | vae_weights: String, 53 | 54 | #[arg(long, value_name = "FILE", default_value = "data/bpe_simple_vocab_16e6.txt")] 55 | /// The file specifying the vocabulary to used for tokenization. 56 | vocab_file: String, 57 | 58 | /// The size of the sliced attention or 0 for automatic slicing (disabled by default) 59 | #[arg(long)] 60 | sliced_attention_size: Option, 61 | 62 | /// The number of steps to run the diffusion for. 63 | #[arg(long, default_value_t = 30)] 64 | n_steps: usize, 65 | 66 | /// The random seed to be used for the generation. 67 | #[arg(long, default_value_t = 32)] 68 | seed: i64, 69 | 70 | /// The number of samples to generate. 71 | #[arg(long, default_value_t = 1)] 72 | num_samples: i64, 73 | 74 | /// The name of the final image to generate. 75 | #[arg(long, value_name = "FILE", default_value = "sd_final.png")] 76 | final_image: String, 77 | 78 | /// Use autocast (disabled by default as it may use more memory in some cases). 79 | #[arg(long, action)] 80 | autocast: bool, 81 | 82 | /// Generate intermediary images at each step. 83 | #[arg(long, action)] 84 | intermediary_images: bool, 85 | 86 | /// The type of ControlNet model to be used. 87 | #[arg(long, value_enum, default_value = "canny")] 88 | control_type: ControlType, 89 | } 90 | 91 | fn output_filename( 92 | basename: &str, 93 | sample_idx: i64, 94 | num_samples: i64, 95 | timestep_idx: Option, 96 | ) -> String { 97 | let filename = if num_samples > 1 { 98 | match basename.rsplit_once('.') { 99 | None => format!("{basename}.{sample_idx}.png"), 100 | Some((filename_no_extension, extension)) => { 101 | format!("{filename_no_extension}.{sample_idx}.{extension}") 102 | } 103 | } 104 | } else { 105 | basename.to_string() 106 | }; 107 | match timestep_idx { 108 | None => filename, 109 | Some(timestep_idx) => match filename.rsplit_once('.') { 110 | None => format!("{filename}-{timestep_idx}.png"), 111 | Some((filename_no_extension, extension)) => { 112 | format!("{filename_no_extension}-{timestep_idx}.{extension}") 113 | } 114 | }, 115 | } 116 | } 117 | 118 | #[derive(Debug, Clone, Copy, clap::ValueEnum)] 119 | enum ControlType { 120 | Canny, 121 | } 122 | 123 | impl ControlType { 124 | fn image_preprocess>(&self, path: T) -> anyhow::Result { 125 | match self { 126 | Self::Canny => { 127 | // TODO: Use an implementation of the Canny edge detector in PyTorch 128 | // and remove this dependency. 129 | use image::EncodableLayout; 130 | let image = image::open(path)?.to_luma8(); 131 | let edges = imageproc::edges::canny(&image, 50., 100.); 132 | let tensor = Tensor::f_from_data_size( 133 | edges.as_bytes(), 134 | &[1, 1, edges.height() as i64, edges.width() as i64], 135 | Kind::Uint8, 136 | )?; 137 | let tensor = Tensor::f_concat(&[&tensor, &tensor, &tensor], 1)?; 138 | // In order to look at the detected edges, uncomment the following line: 139 | // tch::vision::image::save(&tensor.squeeze(), "/tmp/edges.png").unwrap(); 140 | let tensor = Tensor::f_concat(&[&tensor, &tensor], 0)?; 141 | Ok(tensor.to_kind(Kind::Float) / 255.) 142 | } 143 | } 144 | } 145 | } 146 | 147 | fn run(args: Args) -> anyhow::Result<()> { 148 | let Args { 149 | prompt, 150 | cpu, 151 | height, 152 | width, 153 | n_steps, 154 | seed, 155 | vocab_file, 156 | final_image, 157 | sliced_attention_size, 158 | num_samples, 159 | input_image, 160 | unet_weights, 161 | vae_weights, 162 | clip_weights, 163 | controlnet_weights, 164 | control_type, 165 | .. 
166 | } = args; 167 | tch::maybe_init_cuda(); 168 | println!("Cuda available: {}", tch::Cuda::is_available()); 169 | println!("Cudnn available: {}", tch::Cuda::cudnn_is_available()); 170 | println!("MPS available: {}", tch::utils::has_mps()); 171 | 172 | let sd_config = 173 | stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width); 174 | 175 | let image = control_type.image_preprocess(input_image)?; 176 | let device_setup = diffusers::utils::DeviceSetup::new(cpu); 177 | let clip_device = device_setup.get("clip"); 178 | let vae_device = device_setup.get("vae"); 179 | let unet_device = device_setup.get("unet"); 180 | let scheduler = sd_config.build_scheduler(n_steps); 181 | 182 | let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?; 183 | println!("Running with prompt \"{prompt}\"."); 184 | let tokens = tokenizer.encode(&prompt)?; 185 | let tokens: Vec = tokens.into_iter().map(|x| x as i64).collect(); 186 | let tokens = Tensor::from_slice(&tokens).view((1, -1)).to(clip_device); 187 | let uncond_tokens = tokenizer.encode("")?; 188 | let uncond_tokens: Vec = uncond_tokens.into_iter().map(|x| x as i64).collect(); 189 | let uncond_tokens = Tensor::from_slice(&uncond_tokens).view((1, -1)).to(clip_device); 190 | 191 | let no_grad_guard = tch::no_grad_guard(); 192 | 193 | println!("Building the Clip transformer."); 194 | let text_model = sd_config.build_clip_transformer(&clip_weights, clip_device)?; 195 | let text_embeddings = text_model.forward(&tokens); 196 | let uncond_embeddings = text_model.forward(&uncond_tokens); 197 | let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0).to(unet_device); 198 | 199 | println!("Building the autoencoder."); 200 | let vae = sd_config.build_vae(&vae_weights, vae_device)?; 201 | println!("Building the unet."); 202 | let unet = sd_config.build_unet(&unet_weights, unet_device, 4)?; 203 | println!("Building the controlnet."); 204 | let mut vs_controlnet = nn::VarStore::new(unet_device); 205 | let controlnet = 206 | diffusers::models::controlnet::ControlNet::new(vs_controlnet.root(), 4, Default::default()); 207 | vs_controlnet.load(controlnet_weights)?; 208 | 209 | let bsize = 1; 210 | for idx in 0..num_samples { 211 | tch::manual_seed(seed + idx); 212 | let mut latents = Tensor::randn( 213 | [bsize, 4, sd_config.height / 8, sd_config.width / 8], 214 | (Kind::Float, unet_device), 215 | ); 216 | 217 | // scale the initial noise by the standard deviation required by the scheduler 218 | latents *= scheduler.init_noise_sigma(); 219 | 220 | for (timestep_index, ×tep) in scheduler.timesteps().iter().enumerate() { 221 | println!("Timestep {timestep_index}/{n_steps}"); 222 | let latent_model_input = Tensor::cat(&[&latents, &latents], 0); 223 | 224 | let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep); 225 | let (down_block_additional_residuals, mid_block_additional_residuals) = controlnet 226 | .forward(&latent_model_input, timestep as f64, &text_embeddings, &image, 1.); 227 | let noise_pred = unet.forward_with_additional_residuals( 228 | &latent_model_input, 229 | timestep as f64, 230 | &text_embeddings, 231 | Some(&down_block_additional_residuals), 232 | Some(&mid_block_additional_residuals), 233 | ); 234 | let noise_pred = noise_pred.chunk(2, 0); 235 | let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); 236 | let noise_pred = 237 | noise_pred_uncond + (noise_pred_text - noise_pred_uncond) * GUIDANCE_SCALE; 238 | latents = scheduler.step(&noise_pred, 
timestep, &latents); 239 | 240 | if args.intermediary_images { 241 | let latents = latents.to(vae_device); 242 | let image = vae.decode(&(&latents / 0.18215)); 243 | let image = (image / 2 + 0.5).clamp(0., 1.).to_device(Device::Cpu); 244 | let image = (image * 255.).to_kind(Kind::Uint8); 245 | let final_image = 246 | output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1)); 247 | tch::vision::image::save(&image, final_image)?; 248 | } 249 | } 250 | 251 | println!("Generating the final image for sample {}/{}.", idx + 1, num_samples); 252 | let latents = latents.to(vae_device); 253 | let image = vae.decode(&(&latents / 0.18215)); 254 | let image = (image / 2 + 0.5).clamp(0., 1.).to_device(Device::Cpu); 255 | let image = (image * 255.).to_kind(Kind::Uint8); 256 | let final_image = output_filename(&final_image, idx + 1, num_samples, None); 257 | tch::vision::image::save(&image, final_image)?; 258 | } 259 | 260 | drop(no_grad_guard); 261 | Ok(()) 262 | } 263 | 264 | fn main() -> anyhow::Result<()> { 265 | let args = Args::parse(); 266 | if !args.autocast { 267 | run(args) 268 | } else { 269 | tch::autocast(true, || run(args)) 270 | } 271 | } 272 | -------------------------------------------------------------------------------- /examples/stable-diffusion-img2img/main.rs: -------------------------------------------------------------------------------- 1 | // Stable diffusion image to image pipeline. 2 | // See the main stable-diffusion example for how to get the weights. 3 | // 4 | // This has been mostly adapted from looking at the diff between the sample 5 | // diffusion standard and img2img pipelines in the diffusers library. 6 | // patdiff src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion{,_img2img}.py 7 | // 8 | // Suggestions: 9 | // image: https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg 10 | // prompt = "A fantasy landscape, trending on artstation" 11 | use clap::Parser; 12 | use diffusers::pipelines::stable_diffusion; 13 | use diffusers::transformers::clip; 14 | use tch::{nn::Module, Device, Kind, Tensor}; 15 | 16 | const GUIDANCE_SCALE: f64 = 7.5; 17 | 18 | #[derive(Parser)] 19 | #[command(author, version, about, long_about = None)] 20 | struct Args { 21 | /// The input image. 22 | #[arg(long, value_name = "FILE")] 23 | input_image: String, 24 | 25 | /// The prompt to be used for image generation. 26 | #[arg(long, default_value = "A fantasy landscape, trending on artstation.")] 27 | prompt: String, 28 | 29 | /// When set, use the CPU for the listed devices, can be 'all', 'unet', 'clip', etc. 30 | /// Multiple values can be set. 31 | #[arg(long)] 32 | cpu: Vec, 33 | 34 | /// The UNet weight file, in .ot or .safetensors format. 35 | #[arg(long, value_name = "FILE")] 36 | unet_weights: Option, 37 | 38 | /// The CLIP weight file, in .ot or .safetensors format. 39 | #[arg(long, value_name = "FILE")] 40 | clip_weights: Option, 41 | 42 | /// The VAE weight file, in .ot or .safetensors format. 43 | #[arg(long, value_name = "FILE")] 44 | vae_weights: Option, 45 | 46 | #[arg(long, value_name = "FILE", default_value = "data/bpe_simple_vocab_16e6.txt")] 47 | /// The file specifying the vocabulary to used for tokenization. 48 | vocab_file: String, 49 | 50 | /// The size of the sliced attention or 0 for automatic slicing (disabled by default) 51 | #[arg(long)] 52 | sliced_attention_size: Option, 53 | 54 | /// The number of steps to run the diffusion for. 
55 | #[arg(long, default_value_t = 30)] 56 | n_steps: usize, 57 | 58 | /// The strength, indicates how much to transform the initial image. The 59 | /// value must be between 0 and 1, a value of 1 discards the initial image 60 | /// information. 61 | #[arg(long, default_value_t = 0.8)] 62 | strength: f64, 63 | 64 | /// The random seed to be used for the generation. 65 | #[arg(long, default_value_t = 32)] 66 | seed: i64, 67 | 68 | /// The number of samples to generate. 69 | #[arg(long, default_value_t = 1)] 70 | num_samples: i64, 71 | 72 | /// The name of the final image to generate. 73 | #[arg(long, value_name = "FILE", default_value = "sd_final.png")] 74 | final_image: String, 75 | 76 | /// Do not use autocast. 77 | #[arg(long, action)] 78 | no_autocast: bool, 79 | 80 | #[arg(long, value_enum, default_value = "v2-1")] 81 | sd_version: StableDiffusionVersion, 82 | } 83 | 84 | #[derive(Debug, Clone, Copy, clap::ValueEnum)] 85 | enum StableDiffusionVersion { 86 | V1_5, 87 | V2_1, 88 | } 89 | 90 | impl Args { 91 | fn clip_weights(&self) -> String { 92 | match &self.clip_weights { 93 | Some(w) => w.clone(), 94 | None => match self.sd_version { 95 | StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(), 96 | StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(), 97 | }, 98 | } 99 | } 100 | 101 | fn vae_weights(&self) -> String { 102 | match &self.vae_weights { 103 | Some(w) => w.clone(), 104 | None => match self.sd_version { 105 | StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(), 106 | StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(), 107 | }, 108 | } 109 | } 110 | 111 | fn unet_weights(&self) -> String { 112 | match &self.unet_weights { 113 | Some(w) => w.clone(), 114 | None => match self.sd_version { 115 | StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(), 116 | StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(), 117 | }, 118 | } 119 | } 120 | } 121 | 122 | fn image_preprocess>(path: T) -> anyhow::Result { 123 | let image = tch::vision::image::load(path)?; 124 | let (_num_channels, height, width) = image.size3()?; 125 | let height = height - height % 32; 126 | let width = width - width % 32; 127 | let image = tch::vision::image::resize(&image, width, height)?; 128 | Ok((image / 255. * 2. - 1.).unsqueeze(0)) 129 | } 130 | 131 | fn run(args: Args) -> anyhow::Result<()> { 132 | let clip_weights = args.clip_weights(); 133 | let vae_weights = args.vae_weights(); 134 | let unet_weights = args.unet_weights(); 135 | let Args { 136 | prompt, 137 | cpu, 138 | n_steps, 139 | seed, 140 | final_image, 141 | sliced_attention_size, 142 | num_samples, 143 | strength, 144 | input_image, 145 | sd_version, 146 | vocab_file, 147 | .. 148 | } = args; 149 | if !(0. 
..=1.).contains(&strength) { 150 | anyhow::bail!("strength should be between 0 and 1, got {strength}") 151 | } 152 | tch::maybe_init_cuda(); 153 | println!("Cuda available: {}", tch::Cuda::is_available()); 154 | println!("Cudnn available: {}", tch::Cuda::cudnn_is_available()); 155 | let sd_config = match sd_version { 156 | StableDiffusionVersion::V1_5 => { 157 | stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, None, None) 158 | } 159 | StableDiffusionVersion::V2_1 => { 160 | stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, None, None) 161 | } 162 | }; 163 | 164 | let init_image = image_preprocess(input_image)?; 165 | let device_setup = diffusers::utils::DeviceSetup::new(cpu); 166 | let clip_device = device_setup.get("clip"); 167 | let vae_device = device_setup.get("vae"); 168 | let unet_device = device_setup.get("unet"); 169 | let scheduler = sd_config.build_scheduler(n_steps); 170 | 171 | let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?; 172 | println!("Running with prompt \"{prompt}\"."); 173 | let tokens = tokenizer.encode(&prompt)?; 174 | let tokens: Vec = tokens.into_iter().map(|x| x as i64).collect(); 175 | let tokens = Tensor::from_slice(&tokens).view((1, -1)).to(clip_device); 176 | let uncond_tokens = tokenizer.encode("")?; 177 | let uncond_tokens: Vec = uncond_tokens.into_iter().map(|x| x as i64).collect(); 178 | let uncond_tokens = Tensor::from_slice(&uncond_tokens).view((1, -1)).to(clip_device); 179 | 180 | let no_grad_guard = tch::no_grad_guard(); 181 | 182 | println!("Building the Clip transformer."); 183 | let text_model = sd_config.build_clip_transformer(&clip_weights, clip_device)?; 184 | let text_embeddings = text_model.forward(&tokens); 185 | let uncond_embeddings = text_model.forward(&uncond_tokens); 186 | let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0).to(unet_device); 187 | 188 | println!("Building the autoencoder."); 189 | let vae = sd_config.build_vae(&vae_weights, vae_device)?; 190 | println!("Building the unet."); 191 | let unet = sd_config.build_unet(&unet_weights, unet_device, 4)?; 192 | 193 | println!("Generating the latent from the input image {:?}.", init_image.size()); 194 | let init_image = init_image.to(vae_device); 195 | let init_latent_dist = vae.encode(&init_image); 196 | 197 | let t_start = n_steps - (n_steps as f64 * strength) as usize; 198 | 199 | for idx in 0..num_samples { 200 | tch::manual_seed(seed + idx); 201 | let latents = (init_latent_dist.sample() * 0.18215).to(unet_device); 202 | let timesteps = scheduler.timesteps(); 203 | let noise = latents.randn_like(); 204 | let mut latents = scheduler.add_noise(&latents, noise, timesteps[t_start]); 205 | 206 | for (timestep_index, ×tep) in timesteps.iter().enumerate() { 207 | if timestep_index < t_start { 208 | continue; 209 | } 210 | println!("Timestep {timestep_index}/{n_steps}"); 211 | let latent_model_input = Tensor::cat(&[&latents, &latents], 0); 212 | let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep); 213 | 214 | let noise_pred = unet.forward(&latent_model_input, timestep as f64, &text_embeddings); 215 | let noise_pred = noise_pred.chunk(2, 0); 216 | let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); 217 | let noise_pred = 218 | noise_pred_uncond + (noise_pred_text - noise_pred_uncond) * GUIDANCE_SCALE; 219 | latents = scheduler.step(&noise_pred, timestep, &latents); 220 | } 221 | 222 | println!("Generating the final image for sample {}/{}.", idx + 1, 
num_samples); 223 | let latents = latents.to(vae_device); 224 | let image = vae.decode(&(&latents / 0.18215)); 225 | let image = (image / 2 + 0.5).clamp(0., 1.).to_device(Device::Cpu); 226 | let image = (image * 255.).to_kind(Kind::Uint8); 227 | let final_image = if num_samples > 1 { 228 | match final_image.rsplit_once('.') { 229 | None => format!("{}.{}.png", final_image, idx + 1), 230 | Some((filename_no_extension, extension)) => { 231 | format!("{}.{}.{}", filename_no_extension, idx + 1, extension) 232 | } 233 | } 234 | } else { 235 | final_image.clone() 236 | }; 237 | tch::vision::image::save(&image, final_image)?; 238 | } 239 | 240 | drop(no_grad_guard); 241 | Ok(()) 242 | } 243 | 244 | fn main() -> anyhow::Result<()> { 245 | let args = Args::parse(); 246 | if args.no_autocast { 247 | run(args) 248 | } else { 249 | tch::autocast(true, || run(args)) 250 | } 251 | } 252 | -------------------------------------------------------------------------------- /examples/stable-diffusion-inpaint/main.rs: -------------------------------------------------------------------------------- 1 | // Stable diffusion inpainting pipeline. 2 | // See the main stable-diffusion example for how to get the weights. 3 | // 4 | // This has been mostly adapted from looking at the diff between the sample 5 | // diffusion standard and inpaint pipelines in the diffusers library. 6 | // patdiff src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion{,_inpaint}.py 7 | // 8 | // The unet weights should be downloaded from: 9 | // https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/unet/diffusion_pytorch_model.bin 10 | // Or for the fp16 version: 11 | // https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/fp16/unet/diffusion_pytorch_model.bin 12 | // 13 | // Sample input image: 14 | // https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png 15 | // Sample mask: 16 | // https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png 17 | use clap::Parser; 18 | use diffusers::pipelines::stable_diffusion; 19 | use diffusers::transformers::clip; 20 | use tch::{nn::Module, Device, Kind, Tensor}; 21 | 22 | const GUIDANCE_SCALE: f64 = 7.5; 23 | 24 | #[derive(Parser)] 25 | #[command(author, version, about, long_about = None)] 26 | struct Args { 27 | /// The input image that will be inpainted. 28 | #[arg(long, value_name = "FILE")] 29 | input_image: String, 30 | 31 | /// The mask image to be used for inpainting, white pixels are repainted whereas black pixels 32 | /// are preserved. 33 | #[arg(long, value_name = "FILE")] 34 | mask_image: String, 35 | 36 | /// The prompt to be used for image generation. 37 | #[arg(long, default_value = "Face of a yellow cat, high resolution, sitting on a park bench")] 38 | prompt: String, 39 | 40 | /// When set, use the CPU for the listed devices, can be 'all', 'unet', 'clip', etc. 41 | /// Multiple values can be set. 42 | #[arg(long)] 43 | cpu: Vec, 44 | 45 | /// The height in pixels of the generated image. 46 | #[arg(long)] 47 | height: Option, 48 | 49 | /// The width in pixels of the generated image. 50 | #[arg(long)] 51 | width: Option, 52 | 53 | #[arg(long, value_name = "FILE", default_value = "data/bpe_simple_vocab_16e6.txt")] 54 | /// The file specifying the vocabulary to used for tokenization. 55 | vocab_file: String, 56 | 57 | /// The UNet weight file, in .ot or .safetensors format. 
58 | #[arg(long, value_name = "FILE")] 59 | unet_weights: Option, 60 | 61 | /// The CLIP weight file, in .ot or .safetensors format. 62 | #[arg(long, value_name = "FILE")] 63 | clip_weights: Option, 64 | 65 | /// The VAE weight file, in .ot or .safetensors format. 66 | #[arg(long, value_name = "FILE")] 67 | vae_weights: Option, 68 | 69 | /// The size of the sliced attention or 0 for automatic slicing (disabled by default) 70 | #[arg(long)] 71 | sliced_attention_size: Option, 72 | 73 | /// The number of steps to run the diffusion for. 74 | #[arg(long, default_value_t = 30)] 75 | n_steps: usize, 76 | 77 | /// The random seed to be used for the generation. 78 | #[arg(long, default_value_t = 32)] 79 | seed: i64, 80 | 81 | /// The number of samples to generate. 82 | #[arg(long, default_value_t = 1)] 83 | num_samples: i64, 84 | 85 | /// The name of the final image to generate. 86 | #[arg(long, value_name = "FILE", default_value = "sd_final.png")] 87 | final_image: String, 88 | 89 | #[arg(long, value_enum, default_value = "v1-5")] 90 | sd_version: StableDiffusionVersion, 91 | } 92 | 93 | #[derive(Debug, Clone, Copy, clap::ValueEnum)] 94 | enum StableDiffusionVersion { 95 | V1_5, 96 | V2_1, 97 | } 98 | 99 | impl Args { 100 | fn clip_weights(&self) -> String { 101 | match &self.clip_weights { 102 | Some(w) => w.clone(), 103 | None => match self.sd_version { 104 | StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(), 105 | StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(), 106 | }, 107 | } 108 | } 109 | 110 | fn vae_weights(&self) -> String { 111 | match &self.vae_weights { 112 | Some(w) => w.clone(), 113 | None => match self.sd_version { 114 | StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(), 115 | StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(), 116 | }, 117 | } 118 | } 119 | 120 | fn unet_weights(&self) -> String { 121 | match &self.unet_weights { 122 | Some(w) => w.clone(), 123 | None => match self.sd_version { 124 | StableDiffusionVersion::V1_5 => "data/unet-inpaint.safetensors".to_string(), 125 | StableDiffusionVersion::V2_1 => "data/unet-inpaint_v2.1.safetensors".to_string(), 126 | }, 127 | } 128 | } 129 | } 130 | 131 | fn prepare_mask_and_masked_image>( 132 | path_input: T, 133 | path_mask: T, 134 | ) -> anyhow::Result<(Tensor, Tensor)> { 135 | let image = tch::vision::image::load(path_input)?; 136 | let image = image / 255. * 2. - 1.; 137 | 138 | let mask = tch::vision::image::load(path_mask)?; 139 | let mask = mask.mean_dim(Some([0].as_slice()), true, Kind::Float); 140 | let mask = mask.ge(122.5).totype(Kind::Float); 141 | let masked_image: Tensor = image * (1 - &mask); 142 | Ok((mask.unsqueeze(0), masked_image.unsqueeze(0))) 143 | } 144 | 145 | fn run(args: Args) -> anyhow::Result<()> { 146 | let clip_weights = args.clip_weights(); 147 | let vae_weights = args.vae_weights(); 148 | let unet_weights = args.unet_weights(); 149 | let Args { 150 | prompt, 151 | cpu, 152 | height, 153 | width, 154 | n_steps, 155 | seed, 156 | final_image, 157 | sliced_attention_size, 158 | num_samples, 159 | input_image, 160 | mask_image, 161 | vocab_file, 162 | sd_version, 163 | .. 
164 | } = args; 165 | tch::maybe_init_cuda(); 166 | println!("Cuda available: {}", tch::Cuda::is_available()); 167 | println!("Cudnn available: {}", tch::Cuda::cudnn_is_available()); 168 | let sd_config = match sd_version { 169 | StableDiffusionVersion::V1_5 => { 170 | stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width) 171 | } 172 | StableDiffusionVersion::V2_1 => stable_diffusion::StableDiffusionConfig::v2_1_inpaint( 173 | sliced_attention_size, 174 | height, 175 | width, 176 | ), 177 | }; 178 | let (mask, masked_image) = prepare_mask_and_masked_image(input_image, mask_image)?; 179 | println!("Loaded input image and mask, {:?} {:?}.", masked_image.size(), mask.size()); 180 | let device_setup = diffusers::utils::DeviceSetup::new(cpu); 181 | let clip_device = device_setup.get("clip"); 182 | let vae_device = device_setup.get("vae"); 183 | let unet_device = device_setup.get("unet"); 184 | let scheduler = sd_config.build_scheduler(n_steps); 185 | 186 | let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?; 187 | println!("Running with prompt \"{prompt}\"."); 188 | let tokens = tokenizer.encode(&prompt)?; 189 | let tokens: Vec = tokens.into_iter().map(|x| x as i64).collect(); 190 | let tokens = Tensor::from_slice(&tokens).view((1, -1)).to(clip_device); 191 | let uncond_tokens = tokenizer.encode("")?; 192 | let uncond_tokens: Vec = uncond_tokens.into_iter().map(|x| x as i64).collect(); 193 | let uncond_tokens = Tensor::from_slice(&uncond_tokens).view((1, -1)).to(clip_device); 194 | 195 | let no_grad_guard = tch::no_grad_guard(); 196 | 197 | println!("Building the Clip transformer."); 198 | let text_model = sd_config.build_clip_transformer(&clip_weights, clip_device)?; 199 | let text_embeddings = text_model.forward(&tokens); 200 | let uncond_embeddings = text_model.forward(&uncond_tokens); 201 | let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0).to(unet_device); 202 | 203 | println!("Building the autoencoder."); 204 | let vae = sd_config.build_vae(&vae_weights, vae_device)?; 205 | println!("Building the unet."); 206 | let unet = sd_config.build_unet(&unet_weights, unet_device, 9)?; 207 | 208 | let mask = mask.upsample_nearest2d([sd_config.height / 8, sd_config.width / 8], None, None); 209 | let mask = Tensor::cat(&[&mask, &mask], 0).to_device(unet_device); 210 | let masked_image_dist = vae.encode(&masked_image.to_device(vae_device)); 211 | 212 | let bsize = 1; 213 | for idx in 0..num_samples { 214 | tch::manual_seed(seed + idx); 215 | let masked_image_latents = (masked_image_dist.sample() * 0.18215).to(unet_device); 216 | let masked_image_latents = Tensor::cat(&[&masked_image_latents, &masked_image_latents], 0); 217 | let mut latents = Tensor::randn( 218 | [bsize, 4, sd_config.height / 8, sd_config.width / 8], 219 | (Kind::Float, unet_device), 220 | ); 221 | 222 | // scale the initial noise by the standard deviation required by the scheduler 223 | latents *= scheduler.init_noise_sigma(); 224 | 225 | for (timestep_index, ×tep) in scheduler.timesteps().iter().enumerate() { 226 | println!("Timestep {timestep_index}/{n_steps}"); 227 | let latent_model_input = Tensor::cat(&[&latents, &latents], 0); 228 | 229 | // concat latents, mask, masked_image_latents in the channel dimension 230 | let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep); 231 | let latent_model_input = 232 | Tensor::cat(&[&latent_model_input, &mask, &masked_image_latents], 1); 233 | let noise_pred = unet.forward(&latent_model_input, 
timestep as f64, &text_embeddings); 234 | let noise_pred = noise_pred.chunk(2, 0); 235 | let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); 236 | let noise_pred = 237 | noise_pred_uncond + (noise_pred_text - noise_pred_uncond) * GUIDANCE_SCALE; 238 | latents = scheduler.step(&noise_pred, timestep, &latents); 239 | } 240 | 241 | println!("Generating the final image for sample {}/{}.", idx + 1, num_samples); 242 | let latents = latents.to(vae_device); 243 | let image = vae.decode(&(&latents / 0.18215)); 244 | let image = (image / 2 + 0.5).clamp(0., 1.).to_device(Device::Cpu); 245 | let image = (image * 255.).to_kind(Kind::Uint8); 246 | let final_image = if num_samples > 1 { 247 | match final_image.rsplit_once('.') { 248 | None => format!("{}.{}.png", final_image, idx + 1), 249 | Some((filename_no_extension, extension)) => { 250 | format!("{}.{}.{}", filename_no_extension, idx + 1, extension) 251 | } 252 | } 253 | } else { 254 | final_image.clone() 255 | }; 256 | tch::vision::image::save(&image, final_image)?; 257 | } 258 | 259 | drop(no_grad_guard); 260 | Ok(()) 261 | } 262 | 263 | fn main() -> anyhow::Result<()> { 264 | let args = Args::parse(); 265 | run(args) 266 | } 267 | -------------------------------------------------------------------------------- /examples/stable-diffusion/main.rs: -------------------------------------------------------------------------------- 1 | // Stable Diffusion implementation inspired: 2 | // - Huggingface's amazing diffuser Python api: https://huggingface.co/blog/annotated-diffusion 3 | // - Huggingface's (also amazing) blog post: https://huggingface.co/blog/annotated-diffusion 4 | // - The "Grokking Stable Diffusion" notebook by Jonathan Whitaker. 5 | // https://colab.research.google.com/drive/1dlgggNa5Mz8sEAGU0wFCHhGLFooW_pf1?usp=sharing 6 | // 7 | // In order to run this, the weights first have to be downloaded and converted by following 8 | // the instructions below. 9 | // 10 | // mkdir -p data && cd data 11 | // wget https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz 12 | // gunzip bpe_simple_vocab_16e6.txt.gz 13 | // 14 | // Getting the weights then depend on the stable diffusion version (1.5 or 2.1). 15 | // 16 | // # How to get the weights for Stable Diffusion 2.1. 17 | // 18 | // 1. Clip Encoding Weights 19 | // wget https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/fp16/text_encoder/pytorch_model.bin -O clip.bin 20 | // From python, extract the weights and save them as a .npz file. 21 | // import torch 22 | // from safetensors.torch import save_file 23 | // 24 | // model = torch.load("./clip.bin") 25 | // save_file("./clip_v2.1.safetensors", **{k: v.numpy() for k, v in model.items() if "text_model" in k}) 26 | // 27 | // 2. VAE and Unet Weights 28 | // wget https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/fp16/vae/diffusion_pytorch_model.bin -O vae.bin 29 | // wget https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/fp16/unet/diffusion_pytorch_model.bin -O unet.bin 30 | // 31 | // import torch 32 | // from safetensors.torch import save_file 33 | // model = torch.load("./vae.bin") 34 | // save_file(dict(model), './vae.safetensors') 35 | // model = torch.load("./unet.bin") 36 | // save_file(dict(model), './unet.safetensors') 37 | // 38 | // # How to get the weights for Stable Diffusion 1.5. 39 | // 40 | // 1. 
Clip Encoding Weights 41 | // wget https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin 42 | // From python, extract the weights and save them as a .npz file. 43 | // import torch 44 | // from safetensors.torch import save_file 45 | // 46 | // model = torch.load("./pytorch_model.bin") 47 | // save_file("./pytorch_model.safetensors", **{k: v.numpy() for k, v in model.items() if "text_model" in k}) 48 | // 49 | // 2. VAE and Unet Weights 50 | // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/diffusion_pytorch_model.bin 51 | // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/diffusion_pytorch_model.bin 52 | // 53 | // import torch 54 | // from safetensors.torch import save_file 55 | // model = torch.load("./vae.bin") 56 | // save_file(dict(model), './vae.safetensors') 57 | // model = torch.load("./unet.bin") 58 | // save_file(dict(model), './unet.safetensors') 59 | use clap::Parser; 60 | use diffusers::pipelines::stable_diffusion; 61 | use diffusers::transformers::clip; 62 | use tch::{nn::Module, Device, Kind, Tensor}; 63 | 64 | const GUIDANCE_SCALE: f64 = 7.5; 65 | 66 | #[derive(Parser)] 67 | #[command(author, version, about, long_about = None)] 68 | struct Args { 69 | /// The prompt to be used for image generation. 70 | #[arg( 71 | long, 72 | default_value = "A very realistic photo of a rusty robot walking on a sandy beach" 73 | )] 74 | prompt: String, 75 | 76 | /// When set, use the CPU for the listed devices, can be 'all', 'unet', 'clip', etc. 77 | /// Multiple values can be set. 78 | #[arg(long)] 79 | cpu: Vec, 80 | 81 | /// The height in pixels of the generated image. 82 | #[arg(long)] 83 | height: Option, 84 | 85 | /// The width in pixels of the generated image. 86 | #[arg(long)] 87 | width: Option, 88 | 89 | /// The UNet weight file, in .ot or .safetensors format. 90 | #[arg(long, value_name = "FILE")] 91 | unet_weights: Option, 92 | 93 | /// The CLIP weight file, in .ot or .safetensors format. 94 | #[arg(long, value_name = "FILE")] 95 | clip_weights: Option, 96 | 97 | /// The VAE weight file, in .ot or .safetensors format. 98 | #[arg(long, value_name = "FILE")] 99 | vae_weights: Option, 100 | 101 | #[arg(long, value_name = "FILE", default_value = "data/bpe_simple_vocab_16e6.txt")] 102 | /// The file specifying the vocabulary to used for tokenization. 103 | vocab_file: String, 104 | 105 | /// The size of the sliced attention or 0 for automatic slicing (disabled by default) 106 | #[arg(long)] 107 | sliced_attention_size: Option, 108 | 109 | /// The number of steps to run the diffusion for. 110 | #[arg(long, default_value_t = 30)] 111 | n_steps: usize, 112 | 113 | /// The random seed to be used for the generation. 114 | #[arg(long, default_value_t = 32)] 115 | seed: i64, 116 | 117 | /// The number of samples to generate. 118 | #[arg(long, default_value_t = 1)] 119 | num_samples: i64, 120 | 121 | /// The name of the final image to generate. 122 | #[arg(long, value_name = "FILE", default_value = "sd_final.png")] 123 | final_image: String, 124 | 125 | /// Use autocast (disabled by default as it may use more memory in some cases). 126 | #[arg(long, action)] 127 | autocast: bool, 128 | 129 | #[arg(long, value_enum, default_value = "v2-1")] 130 | sd_version: StableDiffusionVersion, 131 | 132 | /// Generate intermediary images at each step. 
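/// Intermediate images reuse the name given by `--final-image`, with the timestep index
/// (and the sample index when more than one sample is generated) inserted before the extension.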
133 | #[arg(long, action)] 134 | intermediary_images: bool, 135 | } 136 | 137 | #[derive(Debug, Clone, Copy, clap::ValueEnum)] 138 | enum StableDiffusionVersion { 139 | V1_5, 140 | V2_1, 141 | } 142 | 143 | impl Args { 144 | fn clip_weights(&self) -> String { 145 | match &self.clip_weights { 146 | Some(w) => w.clone(), 147 | None => match self.sd_version { 148 | StableDiffusionVersion::V1_5 => "data/pytorch_model.safetensors".to_string(), 149 | StableDiffusionVersion::V2_1 => "data/clip_v2.1.safetensors".to_string(), 150 | }, 151 | } 152 | } 153 | 154 | fn vae_weights(&self) -> String { 155 | match &self.vae_weights { 156 | Some(w) => w.clone(), 157 | None => match self.sd_version { 158 | StableDiffusionVersion::V1_5 => "data/vae.safetensors".to_string(), 159 | StableDiffusionVersion::V2_1 => "data/vae_v2.1.safetensors".to_string(), 160 | }, 161 | } 162 | } 163 | 164 | fn unet_weights(&self) -> String { 165 | match &self.unet_weights { 166 | Some(w) => w.clone(), 167 | None => match self.sd_version { 168 | StableDiffusionVersion::V1_5 => "data/unet.safetensors".to_string(), 169 | StableDiffusionVersion::V2_1 => "data/unet_v2.1.safetensors".to_string(), 170 | }, 171 | } 172 | } 173 | } 174 | 175 | fn output_filename( 176 | basename: &str, 177 | sample_idx: i64, 178 | num_samples: i64, 179 | timestep_idx: Option, 180 | ) -> String { 181 | let filename = if num_samples > 1 { 182 | match basename.rsplit_once('.') { 183 | None => format!("{basename}.{sample_idx}.png"), 184 | Some((filename_no_extension, extension)) => { 185 | format!("{filename_no_extension}.{sample_idx}.{extension}") 186 | } 187 | } 188 | } else { 189 | basename.to_string() 190 | }; 191 | match timestep_idx { 192 | None => filename, 193 | Some(timestep_idx) => match filename.rsplit_once('.') { 194 | None => format!("{filename}-{timestep_idx}.png"), 195 | Some((filename_no_extension, extension)) => { 196 | format!("{filename_no_extension}-{timestep_idx}.{extension}") 197 | } 198 | }, 199 | } 200 | } 201 | 202 | fn run(args: Args) -> anyhow::Result<()> { 203 | let clip_weights = args.clip_weights(); 204 | let vae_weights = args.vae_weights(); 205 | let unet_weights = args.unet_weights(); 206 | let Args { 207 | prompt, 208 | cpu, 209 | height, 210 | width, 211 | n_steps, 212 | seed, 213 | vocab_file, 214 | final_image, 215 | sliced_attention_size, 216 | num_samples, 217 | sd_version, 218 | .. 
219 | } = args; 220 | tch::maybe_init_cuda(); 221 | println!("Cuda available: {}", tch::Cuda::is_available()); 222 | println!("Cudnn available: {}", tch::Cuda::cudnn_is_available()); 223 | println!("MPS available: {}", tch::utils::has_mps()); 224 | 225 | let sd_config = match sd_version { 226 | StableDiffusionVersion::V1_5 => { 227 | stable_diffusion::StableDiffusionConfig::v1_5(sliced_attention_size, height, width) 228 | } 229 | StableDiffusionVersion::V2_1 => { 230 | stable_diffusion::StableDiffusionConfig::v2_1(sliced_attention_size, height, width) 231 | } 232 | }; 233 | 234 | let device_setup = diffusers::utils::DeviceSetup::new(cpu); 235 | let clip_device = device_setup.get("clip"); 236 | let vae_device = device_setup.get("vae"); 237 | let unet_device = device_setup.get("unet"); 238 | let scheduler = sd_config.build_scheduler(n_steps); 239 | 240 | let tokenizer = clip::Tokenizer::create(vocab_file, &sd_config.clip)?; 241 | println!("Running with prompt \"{prompt}\"."); 242 | let tokens = tokenizer.encode(&prompt)?; 243 | let tokens: Vec<i64> = tokens.into_iter().map(|x| x as i64).collect(); 244 | let tokens = Tensor::from_slice(&tokens).view((1, -1)).to(clip_device); 245 | let uncond_tokens = tokenizer.encode("")?; 246 | let uncond_tokens: Vec<i64> = uncond_tokens.into_iter().map(|x| x as i64).collect(); 247 | let uncond_tokens = Tensor::from_slice(&uncond_tokens).view((1, -1)).to(clip_device); 248 | 249 | let no_grad_guard = tch::no_grad_guard(); 250 | 251 | println!("Building the Clip transformer."); 252 | let text_model = sd_config.build_clip_transformer(&clip_weights, clip_device)?; 253 | let text_embeddings = text_model.forward(&tokens); 254 | let uncond_embeddings = text_model.forward(&uncond_tokens); 255 | let text_embeddings = Tensor::cat(&[uncond_embeddings, text_embeddings], 0).to(unet_device); 256 | 257 | println!("Building the autoencoder."); 258 | let vae = sd_config.build_vae(&vae_weights, vae_device)?; 259 | println!("Building the unet."); 260 | let unet = sd_config.build_unet(&unet_weights, unet_device, 4)?; 261 | 262 | let bsize = 1; 263 | for idx in 0..num_samples { 264 | tch::manual_seed(seed + idx); 265 | let mut latents = Tensor::randn( 266 | [bsize, 4, sd_config.height / 8, sd_config.width / 8], 267 | (Kind::Float, unet_device), 268 | ); 269 | 270 | // scale the initial noise by the standard deviation required by the scheduler 271 | latents *= scheduler.init_noise_sigma(); 272 | 273 | for (timestep_index, &timestep) in scheduler.timesteps().iter().enumerate() { 274 | println!("Timestep {timestep_index}/{n_steps}"); 275 | let latent_model_input = Tensor::cat(&[&latents, &latents], 0); 276 | 277 | let latent_model_input = scheduler.scale_model_input(latent_model_input, timestep); 278 | let noise_pred = unet.forward(&latent_model_input, timestep as f64, &text_embeddings); 279 | let noise_pred = noise_pred.chunk(2, 0); 280 | let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); 281 | let noise_pred = 282 | noise_pred_uncond + (noise_pred_text - noise_pred_uncond) * GUIDANCE_SCALE; 283 | latents = scheduler.step(&noise_pred, timestep, &latents); 284 | 285 | if args.intermediary_images { 286 | let latents = latents.to(vae_device); 287 | let image = vae.decode(&(&latents / 0.18215)); 288 | let image = (image / 2 + 0.5).clamp(0., 1.).to_device(Device::Cpu); 289 | let image = (image * 255.).to_kind(Kind::Uint8); 290 | let final_image = 291 | output_filename(&final_image, idx + 1, num_samples, Some(timestep_index + 1)); 292 | tch::vision::image::save(&image,
final_image)?; 293 | } 294 | } 295 | 296 | println!("Generating the final image for sample {}/{}.", idx + 1, num_samples); 297 | let latents = latents.to(vae_device); 298 | let image = vae.decode(&(&latents / 0.18215)); 299 | let image = (image / 2 + 0.5).clamp(0., 1.).to_device(Device::Cpu); 300 | let image = (image * 255.).to_kind(Kind::Uint8); 301 | let final_image = output_filename(&final_image, idx + 1, num_samples, None); 302 | tch::vision::image::save(&image, final_image)?; 303 | } 304 | 305 | drop(no_grad_guard); 306 | Ok(()) 307 | } 308 | 309 | fn main() -> anyhow::Result<()> { 310 | let args = Args::parse(); 311 | if !args.autocast { 312 | run(args) 313 | } else { 314 | tch::autocast(true, || run(args)) 315 | } 316 | } 317 | -------------------------------------------------------------------------------- /examples/tensor-tools.rs: -------------------------------------------------------------------------------- 1 | // A small tensor tool utility. 2 | // 3 | // - List the content of some npy/npz/ot file. 4 | // tensor-tools ls a.npy b.npz c.ot 5 | // 6 | // - Convert a npz file to an ot file. 7 | // tensor-tools cp src.npz dst.ot 8 | // Or the other way around. 9 | // tensor-tools cp src.ot dst.npz 10 | 11 | use anyhow::{bail, ensure, Result}; 12 | 13 | pub fn main() -> Result<()> { 14 | let args: Vec<_> = std::env::args().collect(); 15 | ensure!(args.len() >= 2, "usage: {} (ls|cp) ...", args[0]); 16 | match args[1].as_str() { 17 | "ls" => { 18 | for filename in args.iter().skip(2) { 19 | if filename.ends_with(".npy") { 20 | let tensor = tch::Tensor::read_npy(filename)?; 21 | println!("{filename}: {tensor:?}"); 22 | } else if filename.ends_with(".npz") { 23 | let tensors = tch::Tensor::read_npz(filename)?; 24 | for (name, tensor) in tensors.iter() { 25 | println!("{filename}: {name} {tensor:?}") 26 | } 27 | } else if filename.ends_with(".ot") { 28 | let tensors = tch::Tensor::load_multi(filename)?; 29 | for (name, tensor) in tensors.iter() { 30 | println!("{filename}: {name} {tensor:?}") 31 | } 32 | } else { 33 | bail!("unhandled file {}", filename); 34 | } 35 | } 36 | } 37 | "cp" => { 38 | ensure!(args.len() == 4, "usage: {} cp src.ot dst.npz", args[0]); 39 | let src_filename = &args[2]; 40 | let dst_filename = &args[3]; 41 | let tensors = if src_filename.ends_with(".npz") { 42 | tch::Tensor::read_npz(src_filename)? 43 | } else if src_filename.ends_with(".ot") { 44 | tch::Tensor::load_multi(src_filename)? 45 | } else { 46 | bail!("unhandled file {}", src_filename) 47 | }; 48 | for (name, tensor) in tensors.iter() { 49 | println!("{src_filename}: {name} {tensor:?}") 50 | } 51 | if dst_filename.ends_with(".npz") { 52 | tch::Tensor::write_npz(&tensors, dst_filename)? 53 | } else if dst_filename.ends_with(".ot") { 54 | tch::Tensor::save_multi(&tensors, dst_filename)? 
55 | } else { 56 | bail!("unhandled file {}", dst_filename) 57 | }; 58 | } 59 | _ => bail!("usage: {} (ls|cp) ...", args[0]), 60 | } 61 | 62 | Ok(()) 63 | } 64 | -------------------------------------------------------------------------------- /media/in_img2img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/in_img2img.jpg -------------------------------------------------------------------------------- /media/out_img2img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/out_img2img.jpg -------------------------------------------------------------------------------- /media/out_inpaint.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/out_inpaint.jpg -------------------------------------------------------------------------------- /media/robot11.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/robot11.jpg -------------------------------------------------------------------------------- /media/robot13.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/robot13.jpg -------------------------------------------------------------------------------- /media/robot3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/robot3.jpg -------------------------------------------------------------------------------- /media/robot4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/robot4.jpg -------------------------------------------------------------------------------- /media/robot7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/robot7.jpg -------------------------------------------------------------------------------- /media/robot8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/robot8.jpg -------------------------------------------------------------------------------- /media/vermeer-edges.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/vermeer-edges.png -------------------------------------------------------------------------------- /media/vermeer-out1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/vermeer-out1.jpg 
-------------------------------------------------------------------------------- /media/vermeer-out2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/vermeer-out2.jpg -------------------------------------------------------------------------------- /media/vermeer-out3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/vermeer-out3.jpg -------------------------------------------------------------------------------- /media/vermeer-out4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/vermeer-out4.jpg -------------------------------------------------------------------------------- /media/vermeer-out5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/vermeer-out5.jpg -------------------------------------------------------------------------------- /media/vermeer.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LaurentMazare/diffusers-rs/f19c33f84599eb7dea3a65e5b0810ea55c4c57c3/media/vermeer.jpg -------------------------------------------------------------------------------- /rustfmt.toml: -------------------------------------------------------------------------------- 1 | use_small_heuristics = "Max" 2 | edition = "2018" 3 | -------------------------------------------------------------------------------- /scripts/download_weights_1.5.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | ROOT=$(pwd) 5 | 6 | # This can be either fp16 or main for float32 weights. 
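# For the larger float32 checkpoints, change the line below to: BRANCH=main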
7 | BRANCH=fp16 8 | 9 | wget_vocab() { 10 | wget https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz 11 | gunzip bpe_simple_vocab_16e6.txt.gz 12 | } 13 | 14 | wget_clip_weights() { 15 | wget -c https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin 16 | LD_LIBRARY_PATH= python3 -c " 17 | import torch 18 | from safetensors.torch import save_file 19 | 20 | model = torch.load('./pytorch_model.bin') 21 | tensors = {k: v.clone().detach() for k, v in model.items() if 'text_model' in k} 22 | save_file(tensors, 'pytorch_model.safetensors') 23 | " 24 | } 25 | 26 | wget_vae_unet_weights() { 27 | # download weights for vae 28 | header="Authorization: Bearer $1" 29 | wget --header="$header" https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/$BRANCH/vae/diffusion_pytorch_model.bin -O vae.bin 30 | # download weights for unet 31 | wget --header="$header" https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/$BRANCH/unet/diffusion_pytorch_model.bin -O unet.bin 32 | 33 | # convert to npz 34 | LD_LIBRARY_PATH= python3 -c " 35 | import torch 36 | from safetensors.torch import save_file 37 | 38 | model = torch.load('./vae.bin') 39 | save_file(dict(model), './vae.safetensors') 40 | 41 | model = torch.load('./unet.bin') 42 | save_file(dict(model), './unet.safetensors') 43 | " 44 | } 45 | 46 | if [ $# -ne 1 ]; then 47 | echo 'Usage: ./download_weights_1.5.sh ' >&2 48 | exit 1 49 | fi 50 | 51 | echo "Setting up for diffusers-rs..." 52 | 53 | mkdir -p data 54 | cd data 55 | 56 | echo "Getting the Weights and the Vocab File" 57 | # get the weights 58 | wget_vocab 59 | wget_clip_weights 60 | wget_vae_unet_weights $1 61 | 62 | echo "Cleaning ..." 63 | rm -rf $ROOT/data/*.bin 64 | 65 | echo "Done." 66 | -------------------------------------------------------------------------------- /scripts/download_weights_2.1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euxo pipefail 3 | 4 | ROOT=$(pwd) 5 | 6 | # This can be either fp16 or main for float32 weights. 7 | BRANCH=fp16 8 | 9 | wget_vocab() { 10 | wget https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz 11 | gunzip bpe_simple_vocab_16e6.txt.gz 12 | } 13 | 14 | wget_weights() { 15 | header="Authorization: Bearer $1" 16 | # download weights for clip 17 | wget --header="$header" https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/$BRANCH/text_encoder/pytorch_model.bin -O clip.bin 18 | # download weights for vae 19 | wget --header="$header" https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/$BRANCH/vae/diffusion_pytorch_model.bin -O vae.bin 20 | # download weights for unet 21 | wget --header="$header" https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/$BRANCH/unet/diffusion_pytorch_model.bin -O unet.bin 22 | 23 | # convert to npz 24 | LD_LIBRARY_PATH= python3 -c " 25 | import torch 26 | from safetensors.torch import save_file 27 | 28 | model = torch.load('./clip.bin') 29 | save_file({k: v for k, v in model.items() if 'text_model' in k}, './clip_v2.1.safetensors') 30 | 31 | model = torch.load('./vae.bin') 32 | save_file(dict(model), './vae_v2.1.safetensors') 33 | 34 | model = torch.load('./unet.bin') 35 | save_file(dict(model), './unet_v2.1.safetensors') 36 | " 37 | } 38 | 39 | if [ $# -ne 1 ]; then 40 | echo 'Usage: ./download_weights_2.1.sh ' >&2 41 | exit 1 42 | fi 43 | 44 | echo "Setting up for diffusers-rs..." 
45 | 46 | mkdir -p data 47 | cd data 48 | 49 | echo "Getting the Weights and the Vocab File" 50 | # get the weights 51 | wget_vocab 52 | wget_weights $1 53 | 54 | echo "Cleaning ..." 55 | rm -rf $ROOT/data/*.bin 56 | 57 | echo "Done." 58 | -------------------------------------------------------------------------------- /scripts/get_weights.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from safetensors.torch import save_file 3 | import urllib.request 4 | import os 5 | import sys 6 | import argparse 7 | 8 | data_path = os.path.join(os.path.dirname(__file__), "../data/") 9 | vocab_filename = "bpe_simple_vocab_16e6.txt" 10 | 11 | def ensure_data_dir(safetensors): 12 | print("Ensuring empty data directory...") 13 | 14 | if os.path.exists(data_path): 15 | # Fail if conflicting files exist 16 | files = os.listdir(data_path) 17 | newfiles = [x for name in safetensors for x in (f"{name}.bin", f"{name}.safetensors")] 18 | newfiles += [vocab_filename, f"{vocab_filename}.gz"] 19 | conflicts = set(files) & set(newfiles) 20 | if len(conflicts) != 0: 21 | print("Error: please remove the following files from data directory:") 22 | print(conflicts) 23 | sys.exit("Found conflicting files in data directory.") 24 | else: 25 | os.mkdir(data_path) 26 | 27 | print("Found no conflicts!") 28 | 29 | def get_safetensors(safetensors, weight_bits): 30 | for name, url in safetensors.items(): 31 | print(f"Getting {name} {weight_bits} bit tensors...") 32 | 33 | # Download bin file 34 | urllib.request.urlretrieve(url, os.path.join(data_path, f"{name}.bin")) 35 | 36 | # Make safetensors file 37 | model = torch.load(os.path.join(data_path, f"{name}.bin"), map_location=torch.device("cpu")) 38 | tensors = {k: v.clone().detach() for k, v in model.items() if 'text_model' in k} if name in ["clip_v2.1", "pytorch_model"] else dict(model) 39 | save_file(tensors, os.path.join(data_path, f"{name}.safetensors")) 40 | 41 | # Remove bin file 42 | os.remove(os.path.join(data_path, f"{name}.bin")) 43 | 44 | def get_vocab(vocab_url): 45 | print("Getting vocab...") 46 | urllib.request.urlretrieve(vocab_url, os.path.join(data_path, f"{vocab_filename}.gz")) 47 | import gzip 48 | with gzip.open(os.path.join(data_path, f"{vocab_filename}.gz"), 'rb') as g: 49 | with open(os.path.join(data_path, vocab_filename), "xb") as f: 50 | f.write(g.read()) 51 | os.remove(os.path.join(data_path, f"{vocab_filename}.gz")) 52 | 53 | def get_urls(sd_version, weight_bits): 54 | branch = "main" 55 | if weight_bits == "16": 56 | branch = "fp16" # fp16 for float16 weights or main for float32 weights 57 | 58 | safetensors_v1_5 = { 59 | "vae": f"https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/{branch}/vae/diffusion_pytorch_model.bin", 60 | "unet": f"https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/{branch}/unet/diffusion_pytorch_model.bin", 61 | "pytorch_model": f"https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/pytorch_model.bin" 62 | } 63 | safetensors_v2_1 = { 64 | "vae_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/vae/diffusion_pytorch_model.bin", 65 | "unet_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/unet/diffusion_pytorch_model.bin", 66 | "clip_v2.1": f"https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/{branch}/text_encoder/pytorch_model.bin" 67 | } 68 | vocab_url = "https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz" 69 | 70 | return 
safetensors_v1_5 if sd_version == "1.5" else safetensors_v2_1, vocab_url 71 | 72 | if __name__ == "__main__": 73 | parser = argparse.ArgumentParser(description="Download weights for diffusers-rs.") 74 | parser.add_argument("--sd_version", "-v", choices=["2.1", "1.5"], default="2.1") 75 | parser.add_argument("--weight_bits", "-w", choices=["16", "32"], default="16") 76 | args = parser.parse_args() 77 | 78 | print("Setting up model weights for diffusers-rs...") 79 | 80 | safetensors, vocab_url = get_urls(args.sd_version, args.weight_bits) 81 | ensure_data_dir(safetensors) 82 | get_vocab(vocab_url) 83 | get_safetensors(safetensors, args.weight_bits) 84 | 85 | print("Finished!") 86 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | //! # Diffusion pipelines and models 2 | //! 3 | //! This is a Rust port of Hugging Face's [diffusers](https://github.com/huggingface/diffusers) Python api using Torch via the [tch-rs](https://github.com/LaurentMazare/tch-rs). 4 | //! 5 | //! This library includes: 6 | //! - Multiple type of UNet based models, with a ResNet backend. 7 | //! - Training examples including version 1.5 of Stable Diffusion. 8 | //! - Some basic transformers implementation for handling user prompts. 9 | //! 10 | //! The models can used pre-trained weights adapted from the Python 11 | //! implementation. 12 | 13 | pub mod models; 14 | pub mod pipelines; 15 | pub mod schedulers; 16 | pub mod transformers; 17 | pub mod utils; 18 | -------------------------------------------------------------------------------- /src/models/controlnet.rs: -------------------------------------------------------------------------------- 1 | // https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/controlnet.py 2 | use super::unet_2d::{BlockConfig, UNetDownBlock}; 3 | use crate::models::embeddings::{TimestepEmbedding, Timesteps}; 4 | use crate::models::unet_2d_blocks::*; 5 | use tch::{nn, nn::Module, Kind, Tensor}; 6 | 7 | #[derive(Debug)] 8 | pub struct ControlNetConditioningEmbedding { 9 | conv_in: nn::Conv2D, 10 | conv_out: nn::Conv2D, 11 | blocks: Vec<(nn::Conv2D, nn::Conv2D)>, 12 | } 13 | 14 | impl ControlNetConditioningEmbedding { 15 | pub fn new( 16 | vs: nn::Path, 17 | conditioning_embedding_channels: i64, 18 | conditioning_channels: i64, 19 | blocks: &[i64], 20 | ) -> Self { 21 | let b_channels = blocks[0]; 22 | let bl_channels = *blocks.last().unwrap(); 23 | let conv_cfg = nn::ConvConfig { padding: 1, ..Default::default() }; 24 | let conv_cfg2 = nn::ConvConfig { stride: 2, padding: 1, ..Default::default() }; 25 | let conv_in = nn::conv2d(&vs / "conv_in", conditioning_channels, b_channels, 3, conv_cfg); 26 | let conv_out = 27 | nn::conv2d(&vs / "conv_out", bl_channels, conditioning_embedding_channels, 3, conv_cfg); 28 | let vs_b = &vs / "blocks"; 29 | let blocks = (0..(blocks.len() - 1)) 30 | .map(|i| { 31 | let channel_in = blocks[i]; 32 | let channel_out = blocks[i + 1]; 33 | let c1 = nn::conv2d(&vs_b / (2 * i), channel_in, channel_in, 3, conv_cfg); 34 | let c2 = nn::conv2d(&vs_b / (2 * i + 1), channel_in, channel_out, 3, conv_cfg2); 35 | (c1, c2) 36 | }) 37 | .collect(); 38 | Self { conv_in, conv_out, blocks } 39 | } 40 | } 41 | 42 | impl tch::nn::Module for ControlNetConditioningEmbedding { 43 | fn forward(&self, xs: &Tensor) -> Tensor { 44 | let mut xs = xs.apply(&self.conv_in).silu(); 45 | for (c1, c2) in self.blocks.iter() { 46 | xs = 
xs.apply(c1).silu().apply(c2).silu(); 47 | } 48 | xs.apply(&self.conv_out) 49 | } 50 | } 51 | 52 | pub struct ControlNetConfig { 53 | pub flip_sin_to_cos: bool, 54 | pub freq_shift: f64, 55 | pub blocks: Vec, 56 | pub conditioning_embedding_out_channels: Vec, 57 | pub layers_per_block: i64, 58 | pub downsample_padding: i64, 59 | pub mid_block_scale_factor: f64, 60 | pub norm_num_groups: i64, 61 | pub norm_eps: f64, 62 | pub cross_attention_dim: i64, 63 | pub use_linear_projection: bool, 64 | } 65 | 66 | impl Default for ControlNetConfig { 67 | // https://huggingface.co/lllyasviel/sd-controlnet-canny/blob/main/config.json 68 | fn default() -> Self { 69 | Self { 70 | flip_sin_to_cos: true, 71 | freq_shift: 0., 72 | blocks: vec![ 73 | BlockConfig { out_channels: 320, use_cross_attn: true, attention_head_dim: 8 }, 74 | BlockConfig { out_channels: 640, use_cross_attn: true, attention_head_dim: 8 }, 75 | BlockConfig { out_channels: 1280, use_cross_attn: true, attention_head_dim: 8 }, 76 | BlockConfig { out_channels: 1280, use_cross_attn: false, attention_head_dim: 8 }, 77 | ], 78 | conditioning_embedding_out_channels: vec![16, 32, 96, 256], 79 | layers_per_block: 2, 80 | downsample_padding: 1, 81 | mid_block_scale_factor: 1., 82 | norm_num_groups: 32, 83 | norm_eps: 1e-5, 84 | // The default value for the following is 1280 in diffusers/models/controlnet.py but 85 | // 768 in the actual config file. 86 | cross_attention_dim: 768, 87 | use_linear_projection: false, 88 | } 89 | } 90 | } 91 | 92 | #[allow(dead_code)] 93 | pub struct ControlNet { 94 | conv_in: nn::Conv2D, 95 | controlnet_mid_block: nn::Conv2D, 96 | controlnet_cond_embedding: ControlNetConditioningEmbedding, 97 | time_proj: Timesteps, 98 | time_embedding: TimestepEmbedding, 99 | down_blocks: Vec, 100 | controlnet_down_blocks: Vec, 101 | mid_block: UNetMidBlock2DCrossAttn, 102 | pub config: ControlNetConfig, 103 | } 104 | 105 | impl ControlNet { 106 | pub fn new(vs: nn::Path, in_channels: i64, config: ControlNetConfig) -> Self { 107 | let n_blocks = config.blocks.len(); 108 | let b_channels = config.blocks[0].out_channels; 109 | let bl_channels = config.blocks.last().unwrap().out_channels; 110 | let time_embed_dim = b_channels * 4; 111 | let time_proj = 112 | Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift, vs.device()); 113 | let time_embedding = 114 | TimestepEmbedding::new(&vs / "time_embedding", b_channels, time_embed_dim); 115 | let conv_cfg = nn::ConvConfig { stride: 1, padding: 1, ..Default::default() }; 116 | let conv_in = nn::conv2d(&vs / "conv_in", in_channels, b_channels, 3, conv_cfg); 117 | let controlnet_mid_block = nn::conv2d( 118 | &vs / "controlnet_mid_block", 119 | bl_channels, 120 | bl_channels, 121 | 1, 122 | Default::default(), 123 | ); 124 | let controlnet_cond_embedding = ControlNetConditioningEmbedding::new( 125 | &vs / "controlnet_cond_embedding", 126 | b_channels, 127 | 3, 128 | &config.conditioning_embedding_out_channels, 129 | ); 130 | let vs_db = &vs / "down_blocks"; 131 | let down_blocks = (0..n_blocks) 132 | .map(|i| { 133 | let BlockConfig { out_channels, use_cross_attn, attention_head_dim } = 134 | config.blocks[i]; 135 | 136 | let in_channels = 137 | if i > 0 { config.blocks[i - 1].out_channels } else { b_channels }; 138 | let db_cfg = DownBlock2DConfig { 139 | num_layers: config.layers_per_block, 140 | resnet_eps: config.norm_eps, 141 | resnet_groups: config.norm_num_groups, 142 | add_downsample: i < n_blocks - 1, 143 | downsample_padding: config.downsample_padding, 144 | 
output_scale_factor: 1., 145 | }; 146 | if use_cross_attn { 147 | let config = CrossAttnDownBlock2DConfig { 148 | downblock: db_cfg, 149 | attn_num_head_channels: attention_head_dim, 150 | cross_attention_dim: config.cross_attention_dim, 151 | sliced_attention_size: None, 152 | use_linear_projection: config.use_linear_projection, 153 | }; 154 | let block = CrossAttnDownBlock2D::new( 155 | &vs_db / i, 156 | in_channels, 157 | out_channels, 158 | Some(time_embed_dim), 159 | config, 160 | ); 161 | UNetDownBlock::CrossAttn(block) 162 | } else { 163 | let block = DownBlock2D::new( 164 | &vs_db / i, 165 | in_channels, 166 | out_channels, 167 | Some(time_embed_dim), 168 | db_cfg, 169 | ); 170 | UNetDownBlock::Basic(block) 171 | } 172 | }) 173 | .collect(); 174 | let bl_channels = config.blocks.last().unwrap().out_channels; 175 | let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim; 176 | let mid_cfg = UNetMidBlock2DCrossAttnConfig { 177 | resnet_eps: config.norm_eps, 178 | output_scale_factor: config.mid_block_scale_factor, 179 | cross_attn_dim: config.cross_attention_dim, 180 | attn_num_head_channels: bl_attention_head_dim, 181 | resnet_groups: Some(config.norm_num_groups), 182 | use_linear_projection: config.use_linear_projection, 183 | ..Default::default() 184 | }; 185 | let mid_block = UNetMidBlock2DCrossAttn::new( 186 | &vs / "mid_block", 187 | bl_channels, 188 | Some(time_embed_dim), 189 | mid_cfg, 190 | ); 191 | 192 | let vs_c = &vs / "controlnet_down_blocks"; 193 | let controlnet_block = nn::conv2d(&vs_c / 0, b_channels, b_channels, 1, Default::default()); 194 | let mut controlnet_down_blocks = vec![controlnet_block]; 195 | for (i, block) in config.blocks.iter().enumerate() { 196 | let out_channels = block.out_channels; 197 | for _ in 0..config.layers_per_block { 198 | let conv1 = nn::conv2d( 199 | &vs_c / controlnet_down_blocks.len(), 200 | out_channels, 201 | out_channels, 202 | 1, 203 | Default::default(), 204 | ); 205 | controlnet_down_blocks.push(conv1); 206 | } 207 | if i + 1 != config.blocks.len() { 208 | let conv2 = nn::conv2d( 209 | &vs_c / controlnet_down_blocks.len(), 210 | out_channels, 211 | out_channels, 212 | 1, 213 | Default::default(), 214 | ); 215 | controlnet_down_blocks.push(conv2); 216 | } 217 | } 218 | 219 | Self { 220 | conv_in, 221 | controlnet_mid_block, 222 | controlnet_cond_embedding, 223 | controlnet_down_blocks, 224 | time_proj, 225 | time_embedding, 226 | down_blocks, 227 | mid_block, 228 | config, 229 | } 230 | } 231 | 232 | pub fn forward( 233 | &self, 234 | xs: &Tensor, 235 | timestep: f64, 236 | encoder_hidden_states: &Tensor, 237 | controlnet_cond: &Tensor, 238 | conditioning_scale: f64, 239 | ) -> (Vec, Tensor) { 240 | let (bsize, _channels, _height, _width) = xs.size4().unwrap(); 241 | let device = xs.device(); 242 | // Only support: 243 | // - The default channel order (rgb). 244 | // - No class embedding, class_embed_type and num_class_embeds are both None. 245 | // - No guess mode. 246 | 247 | // 1. Time 248 | let emb = (Tensor::ones([bsize], (Kind::Float, device)) * timestep) 249 | .apply(&self.time_proj) 250 | .apply(&self.time_embedding); 251 | 252 | // 2. Pre-process. 253 | let xs = xs.apply(&self.conv_in); 254 | let controlnet_cond = controlnet_cond.apply(&self.controlnet_cond_embedding); 255 | let xs = xs + controlnet_cond; 256 | 257 | // 3. Down. 
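        // Each down block returns its intermediate residuals; these are collected in
        // `down_block_res_xs` and later projected by the 1x1 convolutions in
        // `controlnet_down_blocks` and scaled by `conditioning_scale` (step 5 below).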
258 | let mut down_block_res_xs = vec![xs.shallow_clone()]; 259 | let mut xs = xs; 260 | for down_block in self.down_blocks.iter() { 261 | let (_xs, res_xs) = match down_block { 262 | UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb)), 263 | UNetDownBlock::CrossAttn(b) => { 264 | b.forward(&xs, Some(&emb), Some(encoder_hidden_states)) 265 | } 266 | }; 267 | down_block_res_xs.extend(res_xs); 268 | xs = _xs; 269 | } 270 | 271 | // 4. Mid. 272 | let xs = self.mid_block.forward(&xs, Some(&emb), Some(encoder_hidden_states)); 273 | 274 | // 5. ControlNet blocks. 275 | let controlnet_down_block_res_xs = self 276 | .controlnet_down_blocks 277 | .iter() 278 | .enumerate() 279 | .map(|(i, block)| block.forward(&down_block_res_xs[i]) * conditioning_scale) 280 | .collect::>(); 281 | 282 | let xs = xs.apply(&self.controlnet_mid_block); 283 | (controlnet_down_block_res_xs, xs * conditioning_scale) 284 | } 285 | } 286 | -------------------------------------------------------------------------------- /src/models/embeddings.rs: -------------------------------------------------------------------------------- 1 | use tch::{nn, nn::Module, Device, Kind, Tensor}; 2 | 3 | #[derive(Debug)] 4 | pub struct TimestepEmbedding { 5 | linear_1: nn::Linear, 6 | linear_2: nn::Linear, 7 | } 8 | 9 | impl TimestepEmbedding { 10 | // act_fn: "silu" 11 | pub fn new(vs: nn::Path, channel: i64, time_embed_dim: i64) -> Self { 12 | let linear_cfg = Default::default(); 13 | let linear_1 = nn::linear(&vs / "linear_1", channel, time_embed_dim, linear_cfg); 14 | let linear_2 = nn::linear(&vs / "linear_2", time_embed_dim, time_embed_dim, linear_cfg); 15 | Self { linear_1, linear_2 } 16 | } 17 | } 18 | 19 | impl Module for TimestepEmbedding { 20 | fn forward(&self, xs: &Tensor) -> Tensor { 21 | xs.apply(&self.linear_1).silu().apply(&self.linear_2) 22 | } 23 | } 24 | 25 | #[derive(Debug)] 26 | pub struct Timesteps { 27 | num_channels: i64, 28 | flip_sin_to_cos: bool, 29 | downscale_freq_shift: f64, 30 | device: Device, 31 | } 32 | 33 | impl Timesteps { 34 | pub fn new( 35 | num_channels: i64, 36 | flip_sin_to_cos: bool, 37 | downscale_freq_shift: f64, 38 | device: Device, 39 | ) -> Self { 40 | Self { num_channels, flip_sin_to_cos, downscale_freq_shift, device } 41 | } 42 | } 43 | 44 | impl Module for Timesteps { 45 | fn forward(&self, xs: &Tensor) -> Tensor { 46 | let half_dim = self.num_channels / 2; 47 | let exponent = Tensor::arange(half_dim, (Kind::Float, self.device)) * -f64::ln(10000.); 48 | let exponent = exponent / (half_dim as f64 - self.downscale_freq_shift); 49 | let emb = exponent.exp(); 50 | // emb = timesteps[:, None].float() * emb[None, :] 51 | let emb = xs.unsqueeze(-1) * emb.unsqueeze(0); 52 | let emb = if self.flip_sin_to_cos { 53 | Tensor::cat(&[emb.cos(), emb.sin()], -1) 54 | } else { 55 | Tensor::cat(&[emb.sin(), emb.cos()], -1) 56 | }; 57 | if self.num_channels % 2 == 1 { 58 | emb.pad([0, 1, 0, 0], "constant", None) 59 | } else { 60 | emb 61 | } 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /src/models/mod.rs: -------------------------------------------------------------------------------- 1 | //! # Models 2 | //! 3 | //! A collection of models to be used in a diffusion loop. 
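//!
//! The models defined here (`unet_2d::UNet2DConditionModel`, the `vae` auto-encoder,
//! `controlnet::ControlNet`, ...) are usually not instantiated directly: the examples in
//! this repository build them from pre-trained weights through
//! `pipelines::stable_diffusion::StableDiffusionConfig` (`build_unet`, `build_vae`,
//! `build_clip_transformer`).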
4 | 5 | pub mod attention; 6 | pub mod controlnet; 7 | pub mod embeddings; 8 | pub mod resnet; 9 | pub mod unet_2d; 10 | pub mod unet_2d_blocks; 11 | pub mod vae; 12 | -------------------------------------------------------------------------------- /src/models/resnet.rs: -------------------------------------------------------------------------------- 1 | //! ResNet Building Blocks 2 | //! 3 | //! Some Residual Network blocks used in UNet models. 4 | //! 5 | //! Denoising Diffusion Implicit Models, K. He and al, 2015. 6 | //! https://arxiv.org/abs/1512.03385 7 | use tch::{nn, Tensor}; 8 | 9 | /// Configuration for a ResNet block. 10 | #[derive(Debug, Clone, Copy)] 11 | pub struct ResnetBlock2DConfig { 12 | /// The number of output channels, defaults to the number of input channels. 13 | pub out_channels: Option, 14 | pub temb_channels: Option, 15 | /// The number of groups to use in group normalization. 16 | pub groups: i64, 17 | pub groups_out: Option, 18 | /// The epsilon to be used in the group normalization operations. 19 | pub eps: f64, 20 | /// Whether to use a 2D convolution in the skip connection. When using None, 21 | /// such a convolution is used if the number of input channels is different from 22 | /// the number of output channels. 23 | pub use_in_shortcut: Option, 24 | // non_linearity: silu 25 | /// The final output is scaled by dividing by this value. 26 | pub output_scale_factor: f64, 27 | } 28 | 29 | impl Default for ResnetBlock2DConfig { 30 | fn default() -> Self { 31 | Self { 32 | out_channels: None, 33 | temb_channels: Some(512), 34 | groups: 32, 35 | groups_out: None, 36 | eps: 1e-6, 37 | use_in_shortcut: None, 38 | output_scale_factor: 1., 39 | } 40 | } 41 | } 42 | 43 | #[derive(Debug)] 44 | pub struct ResnetBlock2D { 45 | norm1: nn::GroupNorm, 46 | conv1: nn::Conv2D, 47 | norm2: nn::GroupNorm, 48 | conv2: nn::Conv2D, 49 | time_emb_proj: Option, 50 | conv_shortcut: Option, 51 | config: ResnetBlock2DConfig, 52 | } 53 | 54 | impl ResnetBlock2D { 55 | pub fn new(vs: nn::Path, in_channels: i64, config: ResnetBlock2DConfig) -> Self { 56 | let out_channels = config.out_channels.unwrap_or(in_channels); 57 | let conv_cfg = nn::ConvConfig { stride: 1, padding: 1, ..Default::default() }; 58 | let group_cfg = nn::GroupNormConfig { eps: config.eps, affine: true, ..Default::default() }; 59 | let norm1 = nn::group_norm(&vs / "norm1", config.groups, in_channels, group_cfg); 60 | let conv1 = nn::conv2d(&vs / "conv1", in_channels, out_channels, 3, conv_cfg); 61 | let groups_out = config.groups_out.unwrap_or(config.groups); 62 | let norm2 = nn::group_norm(&vs / "norm2", groups_out, out_channels, group_cfg); 63 | let conv2 = nn::conv2d(&vs / "conv2", out_channels, out_channels, 3, conv_cfg); 64 | let use_in_shortcut = config.use_in_shortcut.unwrap_or(in_channels != out_channels); 65 | let conv_shortcut = if use_in_shortcut { 66 | let conv_cfg = nn::ConvConfig { stride: 1, padding: 0, ..Default::default() }; 67 | Some(nn::conv2d(&vs / "conv_shortcut", in_channels, out_channels, 1, conv_cfg)) 68 | } else { 69 | None 70 | }; 71 | let time_emb_proj = config.temb_channels.map(|temb_channels| { 72 | nn::linear(&vs / "time_emb_proj", temb_channels, out_channels, Default::default()) 73 | }); 74 | Self { norm1, conv1, norm2, conv2, time_emb_proj, config, conv_shortcut } 75 | } 76 | 77 | pub fn forward(&self, xs: &Tensor, temb: Option<&Tensor>) -> Tensor { 78 | let shortcut_xs = match &self.conv_shortcut { 79 | Some(conv_shortcut) => xs.apply(conv_shortcut), 80 | None => xs.shallow_clone(), 81 | 
}; 82 | let xs = xs.apply(&self.norm1).silu().apply(&self.conv1); 83 | let xs = match (temb, &self.time_emb_proj) { 84 | (Some(temb), Some(time_emb_proj)) => { 85 | temb.silu().apply(time_emb_proj).unsqueeze(-1).unsqueeze(-1) + xs 86 | } 87 | _ => xs, 88 | }; 89 | let xs = xs.apply(&self.norm2).silu().apply(&self.conv2); 90 | (shortcut_xs + xs) / self.config.output_scale_factor 91 | } 92 | } 93 | -------------------------------------------------------------------------------- /src/models/unet_2d.rs: -------------------------------------------------------------------------------- 1 | //! 2D UNet Denoising Models 2 | //! 3 | //! The 2D Unet models take as input a noisy sample and the current diffusion 4 | //! timestep and return a denoised version of the input. 5 | use crate::models::embeddings::{TimestepEmbedding, Timesteps}; 6 | use crate::models::unet_2d_blocks::*; 7 | use tch::{nn, Kind, Tensor}; 8 | 9 | #[derive(Debug, Clone, Copy)] 10 | pub struct BlockConfig { 11 | pub out_channels: i64, 12 | pub use_cross_attn: bool, 13 | pub attention_head_dim: i64, 14 | } 15 | 16 | #[derive(Debug, Clone)] 17 | pub struct UNet2DConditionModelConfig { 18 | pub center_input_sample: bool, 19 | pub flip_sin_to_cos: bool, 20 | pub freq_shift: f64, 21 | pub blocks: Vec, 22 | pub layers_per_block: i64, 23 | pub downsample_padding: i64, 24 | pub mid_block_scale_factor: f64, 25 | pub norm_num_groups: i64, 26 | pub norm_eps: f64, 27 | pub cross_attention_dim: i64, 28 | pub sliced_attention_size: Option, 29 | pub use_linear_projection: bool, 30 | } 31 | 32 | impl Default for UNet2DConditionModelConfig { 33 | fn default() -> Self { 34 | Self { 35 | center_input_sample: false, 36 | flip_sin_to_cos: true, 37 | freq_shift: 0., 38 | blocks: vec![ 39 | BlockConfig { out_channels: 320, use_cross_attn: true, attention_head_dim: 8 }, 40 | BlockConfig { out_channels: 640, use_cross_attn: true, attention_head_dim: 8 }, 41 | BlockConfig { out_channels: 1280, use_cross_attn: true, attention_head_dim: 8 }, 42 | BlockConfig { out_channels: 1280, use_cross_attn: false, attention_head_dim: 8 }, 43 | ], 44 | layers_per_block: 2, 45 | downsample_padding: 1, 46 | mid_block_scale_factor: 1., 47 | norm_num_groups: 32, 48 | norm_eps: 1e-5, 49 | cross_attention_dim: 1280, 50 | sliced_attention_size: None, 51 | use_linear_projection: false, 52 | } 53 | } 54 | } 55 | 56 | #[derive(Debug)] 57 | pub(crate) enum UNetDownBlock { 58 | Basic(DownBlock2D), 59 | CrossAttn(CrossAttnDownBlock2D), 60 | } 61 | 62 | #[derive(Debug)] 63 | enum UNetUpBlock { 64 | Basic(UpBlock2D), 65 | CrossAttn(CrossAttnUpBlock2D), 66 | } 67 | 68 | #[derive(Debug)] 69 | pub struct UNet2DConditionModel { 70 | conv_in: nn::Conv2D, 71 | time_proj: Timesteps, 72 | time_embedding: TimestepEmbedding, 73 | down_blocks: Vec, 74 | mid_block: UNetMidBlock2DCrossAttn, 75 | up_blocks: Vec, 76 | conv_norm_out: nn::GroupNorm, 77 | conv_out: nn::Conv2D, 78 | config: UNet2DConditionModelConfig, 79 | } 80 | 81 | impl UNet2DConditionModel { 82 | pub fn new( 83 | vs: nn::Path, 84 | in_channels: i64, 85 | out_channels: i64, 86 | config: UNet2DConditionModelConfig, 87 | ) -> Self { 88 | let n_blocks = config.blocks.len(); 89 | let b_channels = config.blocks[0].out_channels; 90 | let bl_channels = config.blocks.last().unwrap().out_channels; 91 | let bl_attention_head_dim = config.blocks.last().unwrap().attention_head_dim; 92 | let time_embed_dim = b_channels * 4; 93 | let conv_cfg = nn::ConvConfig { stride: 1, padding: 1, ..Default::default() }; 94 | let conv_in = nn::conv2d(&vs / 
"conv_in", in_channels, b_channels, 3, conv_cfg); 95 | 96 | let time_proj = 97 | Timesteps::new(b_channels, config.flip_sin_to_cos, config.freq_shift, vs.device()); 98 | let time_embedding = 99 | TimestepEmbedding::new(&vs / "time_embedding", b_channels, time_embed_dim); 100 | 101 | let vs_db = &vs / "down_blocks"; 102 | let down_blocks = (0..n_blocks) 103 | .map(|i| { 104 | let BlockConfig { out_channels, use_cross_attn, attention_head_dim } = 105 | config.blocks[i]; 106 | 107 | // Enable automatic attention slicing if the config sliced_attention_size is set to 0. 108 | let sliced_attention_size = match config.sliced_attention_size { 109 | Some(0) => Some(attention_head_dim / 2), 110 | _ => config.sliced_attention_size, 111 | }; 112 | 113 | let in_channels = 114 | if i > 0 { config.blocks[i - 1].out_channels } else { b_channels }; 115 | let db_cfg = DownBlock2DConfig { 116 | num_layers: config.layers_per_block, 117 | resnet_eps: config.norm_eps, 118 | resnet_groups: config.norm_num_groups, 119 | add_downsample: i < n_blocks - 1, 120 | downsample_padding: config.downsample_padding, 121 | ..Default::default() 122 | }; 123 | if use_cross_attn { 124 | let config = CrossAttnDownBlock2DConfig { 125 | downblock: db_cfg, 126 | attn_num_head_channels: attention_head_dim, 127 | cross_attention_dim: config.cross_attention_dim, 128 | sliced_attention_size, 129 | use_linear_projection: config.use_linear_projection, 130 | }; 131 | let block = CrossAttnDownBlock2D::new( 132 | &vs_db / i, 133 | in_channels, 134 | out_channels, 135 | Some(time_embed_dim), 136 | config, 137 | ); 138 | UNetDownBlock::CrossAttn(block) 139 | } else { 140 | let block = DownBlock2D::new( 141 | &vs_db / i, 142 | in_channels, 143 | out_channels, 144 | Some(time_embed_dim), 145 | db_cfg, 146 | ); 147 | UNetDownBlock::Basic(block) 148 | } 149 | }) 150 | .collect(); 151 | 152 | let mid_cfg = UNetMidBlock2DCrossAttnConfig { 153 | resnet_eps: config.norm_eps, 154 | output_scale_factor: config.mid_block_scale_factor, 155 | cross_attn_dim: config.cross_attention_dim, 156 | attn_num_head_channels: bl_attention_head_dim, 157 | resnet_groups: Some(config.norm_num_groups), 158 | use_linear_projection: config.use_linear_projection, 159 | ..Default::default() 160 | }; 161 | let mid_block = UNetMidBlock2DCrossAttn::new( 162 | &vs / "mid_block", 163 | bl_channels, 164 | Some(time_embed_dim), 165 | mid_cfg, 166 | ); 167 | 168 | let vs_ub = &vs / "up_blocks"; 169 | let up_blocks = (0..n_blocks) 170 | .map(|i| { 171 | let BlockConfig { out_channels, use_cross_attn, attention_head_dim } = 172 | config.blocks[n_blocks - 1 - i]; 173 | 174 | // Enable automatic attention slicing if the config sliced_attention_size is set to 0. 
175 | let sliced_attention_size = match config.sliced_attention_size { 176 | Some(0) => Some(attention_head_dim / 2), 177 | _ => config.sliced_attention_size, 178 | }; 179 | 180 | let prev_out_channels = 181 | if i > 0 { config.blocks[n_blocks - i].out_channels } else { bl_channels }; 182 | let in_channels = { 183 | let index = if i == n_blocks - 1 { 0 } else { n_blocks - i - 2 }; 184 | config.blocks[index].out_channels 185 | }; 186 | let ub_cfg = UpBlock2DConfig { 187 | num_layers: config.layers_per_block + 1, 188 | resnet_eps: config.norm_eps, 189 | resnet_groups: config.norm_num_groups, 190 | add_upsample: i < n_blocks - 1, 191 | ..Default::default() 192 | }; 193 | if use_cross_attn { 194 | let config = CrossAttnUpBlock2DConfig { 195 | upblock: ub_cfg, 196 | attn_num_head_channels: attention_head_dim, 197 | cross_attention_dim: config.cross_attention_dim, 198 | sliced_attention_size, 199 | use_linear_projection: config.use_linear_projection, 200 | }; 201 | let block = CrossAttnUpBlock2D::new( 202 | &vs_ub / i, 203 | in_channels, 204 | prev_out_channels, 205 | out_channels, 206 | Some(time_embed_dim), 207 | config, 208 | ); 209 | UNetUpBlock::CrossAttn(block) 210 | } else { 211 | let block = UpBlock2D::new( 212 | &vs_ub / i, 213 | in_channels, 214 | prev_out_channels, 215 | out_channels, 216 | Some(time_embed_dim), 217 | ub_cfg, 218 | ); 219 | UNetUpBlock::Basic(block) 220 | } 221 | }) 222 | .collect(); 223 | 224 | let group_cfg = nn::GroupNormConfig { eps: config.norm_eps, ..Default::default() }; 225 | let conv_norm_out = 226 | nn::group_norm(&vs / "conv_norm_out", config.norm_num_groups, b_channels, group_cfg); 227 | let conv_out = nn::conv2d(&vs / "conv_out", b_channels, out_channels, 3, conv_cfg); 228 | Self { 229 | conv_in, 230 | time_proj, 231 | time_embedding, 232 | down_blocks, 233 | mid_block, 234 | up_blocks, 235 | conv_norm_out, 236 | conv_out, 237 | config, 238 | } 239 | } 240 | } 241 | 242 | impl UNet2DConditionModel { 243 | pub fn forward(&self, xs: &Tensor, timestep: f64, encoder_hidden_states: &Tensor) -> Tensor { 244 | self.forward_with_additional_residuals(xs, timestep, encoder_hidden_states, None, None) 245 | } 246 | 247 | pub fn forward_with_additional_residuals( 248 | &self, 249 | xs: &Tensor, 250 | timestep: f64, 251 | encoder_hidden_states: &Tensor, 252 | down_block_additional_residuals: Option<&[Tensor]>, 253 | mid_block_additional_residual: Option<&Tensor>, 254 | ) -> Tensor { 255 | let (bsize, _channels, height, width) = xs.size4().unwrap(); 256 | let device = xs.device(); 257 | let n_blocks = self.config.blocks.len(); 258 | let num_upsamplers = n_blocks - 1; 259 | let default_overall_up_factor = 2i64.pow(num_upsamplers as u32); 260 | let forward_upsample_size = 261 | height % default_overall_up_factor != 0 || width % default_overall_up_factor != 0; 262 | // 0. center input if necessary 263 | let xs = if self.config.center_input_sample { xs * 2.0 - 1.0 } else { xs.shallow_clone() }; 264 | // 1. time 265 | let emb = (Tensor::ones([bsize], (Kind::Float, device)) * timestep) 266 | .apply(&self.time_proj) 267 | .apply(&self.time_embedding); 268 | // 2. pre-process 269 | let xs = xs.apply(&self.conv_in); 270 | // 3. 
down 271 | let mut down_block_res_xs = vec![xs.shallow_clone()]; 272 | let mut xs = xs; 273 | for down_block in self.down_blocks.iter() { 274 | let (_xs, res_xs) = match down_block { 275 | UNetDownBlock::Basic(b) => b.forward(&xs, Some(&emb)), 276 | UNetDownBlock::CrossAttn(b) => { 277 | b.forward(&xs, Some(&emb), Some(encoder_hidden_states)) 278 | } 279 | }; 280 | down_block_res_xs.extend(res_xs); 281 | xs = _xs; 282 | } 283 | 284 | let new_down_block_res_xs = 285 | if let Some(down_block_additional_residuals) = down_block_additional_residuals { 286 | let mut v = vec![]; 287 | // A previous version of this code had a bug because of the addition being made 288 | // in place via += hence modifying the input of the mid block. 289 | for (i, residuals) in down_block_additional_residuals.iter().enumerate() { 290 | v.push(&down_block_res_xs[i] + residuals) 291 | } 292 | v 293 | } else { 294 | down_block_res_xs 295 | }; 296 | let mut down_block_res_xs = new_down_block_res_xs; 297 | 298 | // 4. mid 299 | let xs = self.mid_block.forward(&xs, Some(&emb), Some(encoder_hidden_states)); 300 | let xs = match mid_block_additional_residual { 301 | None => xs, 302 | Some(m) => m + xs, 303 | }; 304 | // 5. up 305 | let mut xs = xs; 306 | let mut upsample_size = None; 307 | for (i, up_block) in self.up_blocks.iter().enumerate() { 308 | let n_resnets = match up_block { 309 | UNetUpBlock::Basic(b) => b.resnets.len(), 310 | UNetUpBlock::CrossAttn(b) => b.upblock.resnets.len(), 311 | }; 312 | let res_xs = down_block_res_xs.split_off(down_block_res_xs.len() - n_resnets); 313 | if i < n_blocks - 1 && forward_upsample_size { 314 | let (_, _, h, w) = down_block_res_xs.last().unwrap().size4().unwrap(); 315 | upsample_size = Some((h, w)) 316 | } 317 | xs = match up_block { 318 | UNetUpBlock::Basic(b) => b.forward(&xs, &res_xs, Some(&emb), upsample_size), 319 | UNetUpBlock::CrossAttn(b) => { 320 | b.forward(&xs, &res_xs, Some(&emb), upsample_size, Some(encoder_hidden_states)) 321 | } 322 | }; 323 | } 324 | // 6. post-process 325 | xs.apply(&self.conv_norm_out).silu().apply(&self.conv_out) 326 | } 327 | } 328 | -------------------------------------------------------------------------------- /src/models/vae.rs: -------------------------------------------------------------------------------- 1 | //! # Variational Auto-Encoder (VAE) Models. 2 | //! 3 | //! Auto-encoder models compress their input to a usually smaller latent space 4 | //! before expanding it back to its original shape. This results in the latent values 5 | //! compressing the original information. 
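//!
//! A rough usage sketch, following the Stable Diffusion examples in this repository
//! (the 0.18215 latent scaling factor comes from the Stable Diffusion configuration):
//!
//! ```ignore
//! // Encode an image tensor into a latent distribution and draw a sample from it.
//! let latent_dist = vae.encode(&image);
//! let latents = latent_dist.sample() * 0.18215;
//! // Map latents back to image space.
//! let decoded = vae.decode(&(&latents / 0.18215));
//! ```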
6 | use crate::models::unet_2d_blocks::{ 7 | DownEncoderBlock2D, DownEncoderBlock2DConfig, UNetMidBlock2D, UNetMidBlock2DConfig, 8 | UpDecoderBlock2D, UpDecoderBlock2DConfig, 9 | }; 10 | use tch::{nn, nn::Module, Tensor}; 11 | 12 | #[derive(Debug, Clone)] 13 | struct EncoderConfig { 14 | // down_block_types: DownEncoderBlock2D 15 | block_out_channels: Vec, 16 | layers_per_block: i64, 17 | norm_num_groups: i64, 18 | double_z: bool, 19 | } 20 | 21 | impl Default for EncoderConfig { 22 | fn default() -> Self { 23 | Self { 24 | block_out_channels: vec![64], 25 | layers_per_block: 2, 26 | norm_num_groups: 32, 27 | double_z: true, 28 | } 29 | } 30 | } 31 | 32 | #[derive(Debug)] 33 | struct Encoder { 34 | conv_in: nn::Conv2D, 35 | down_blocks: Vec, 36 | mid_block: UNetMidBlock2D, 37 | conv_norm_out: nn::GroupNorm, 38 | conv_out: nn::Conv2D, 39 | #[allow(dead_code)] 40 | config: EncoderConfig, 41 | } 42 | 43 | impl Encoder { 44 | fn new(vs: nn::Path, in_channels: i64, out_channels: i64, config: EncoderConfig) -> Self { 45 | let conv_cfg = nn::ConvConfig { stride: 1, padding: 1, ..Default::default() }; 46 | let conv_in = 47 | nn::conv2d(&vs / "conv_in", in_channels, config.block_out_channels[0], 3, conv_cfg); 48 | let mut down_blocks = vec![]; 49 | let vs_down_blocks = &vs / "down_blocks"; 50 | for index in 0..config.block_out_channels.len() { 51 | let out_channels = config.block_out_channels[index]; 52 | let in_channels = if index > 0 { 53 | config.block_out_channels[index - 1] 54 | } else { 55 | config.block_out_channels[0] 56 | }; 57 | let is_final = index + 1 == config.block_out_channels.len(); 58 | let cfg = DownEncoderBlock2DConfig { 59 | num_layers: config.layers_per_block, 60 | resnet_eps: 1e-6, 61 | resnet_groups: config.norm_num_groups, 62 | add_downsample: !is_final, 63 | downsample_padding: 0, 64 | ..Default::default() 65 | }; 66 | let down_block = 67 | DownEncoderBlock2D::new(&vs_down_blocks / index, in_channels, out_channels, cfg); 68 | down_blocks.push(down_block) 69 | } 70 | let last_block_out_channels = *config.block_out_channels.last().unwrap(); 71 | let mid_cfg = UNetMidBlock2DConfig { 72 | resnet_eps: 1e-6, 73 | output_scale_factor: 1., 74 | attn_num_head_channels: None, 75 | resnet_groups: Some(config.norm_num_groups), 76 | ..Default::default() 77 | }; 78 | let mid_block = 79 | UNetMidBlock2D::new(&vs / "mid_block", last_block_out_channels, None, mid_cfg); 80 | let group_cfg = nn::GroupNormConfig { eps: 1e-6, ..Default::default() }; 81 | let conv_norm_out = nn::group_norm( 82 | &vs / "conv_norm_out", 83 | config.norm_num_groups, 84 | last_block_out_channels, 85 | group_cfg, 86 | ); 87 | let conv_out_channels = if config.double_z { 2 * out_channels } else { out_channels }; 88 | let conv_cfg = nn::ConvConfig { padding: 1, ..Default::default() }; 89 | let conv_out = 90 | nn::conv2d(&vs / "conv_out", last_block_out_channels, conv_out_channels, 3, conv_cfg); 91 | Self { conv_in, down_blocks, mid_block, conv_norm_out, conv_out, config } 92 | } 93 | } 94 | 95 | impl Module for Encoder { 96 | fn forward(&self, xs: &Tensor) -> Tensor { 97 | let mut xs = xs.apply(&self.conv_in); 98 | for down_block in self.down_blocks.iter() { 99 | xs = xs.apply(down_block) 100 | } 101 | self.mid_block.forward(&xs, None).apply(&self.conv_norm_out).silu().apply(&self.conv_out) 102 | } 103 | } 104 | 105 | #[derive(Debug, Clone)] 106 | struct DecoderConfig { 107 | // up_block_types: UpDecoderBlock2D 108 | block_out_channels: Vec, 109 | layers_per_block: i64, 110 | norm_num_groups: i64, 111 | } 112 | 113 | 
impl Default for DecoderConfig { 114 | fn default() -> Self { 115 | Self { block_out_channels: vec![64], layers_per_block: 2, norm_num_groups: 32 } 116 | } 117 | } 118 | 119 | #[derive(Debug)] 120 | struct Decoder { 121 | conv_in: nn::Conv2D, 122 | up_blocks: Vec, 123 | mid_block: UNetMidBlock2D, 124 | conv_norm_out: nn::GroupNorm, 125 | conv_out: nn::Conv2D, 126 | #[allow(dead_code)] 127 | config: DecoderConfig, 128 | } 129 | 130 | impl Decoder { 131 | fn new(vs: nn::Path, in_channels: i64, out_channels: i64, config: DecoderConfig) -> Self { 132 | let n_block_out_channels = config.block_out_channels.len(); 133 | let last_block_out_channels = *config.block_out_channels.last().unwrap(); 134 | let conv_cfg = nn::ConvConfig { stride: 1, padding: 1, ..Default::default() }; 135 | let conv_in = 136 | nn::conv2d(&vs / "conv_in", in_channels, last_block_out_channels, 3, conv_cfg); 137 | let mid_cfg = UNetMidBlock2DConfig { 138 | resnet_eps: 1e-6, 139 | output_scale_factor: 1., 140 | attn_num_head_channels: None, 141 | resnet_groups: Some(config.norm_num_groups), 142 | ..Default::default() 143 | }; 144 | let mid_block = 145 | UNetMidBlock2D::new(&vs / "mid_block", last_block_out_channels, None, mid_cfg); 146 | let mut up_blocks = vec![]; 147 | let vs_up_blocks = &vs / "up_blocks"; 148 | let reversed_block_out_channels: Vec<_> = 149 | config.block_out_channels.iter().copied().rev().collect(); 150 | for index in 0..n_block_out_channels { 151 | let out_channels = reversed_block_out_channels[index]; 152 | let in_channels = if index > 0 { 153 | reversed_block_out_channels[index - 1] 154 | } else { 155 | reversed_block_out_channels[0] 156 | }; 157 | let is_final = index + 1 == n_block_out_channels; 158 | let cfg = UpDecoderBlock2DConfig { 159 | num_layers: config.layers_per_block + 1, 160 | resnet_eps: 1e-6, 161 | resnet_groups: config.norm_num_groups, 162 | add_upsample: !is_final, 163 | ..Default::default() 164 | }; 165 | let up_block = 166 | UpDecoderBlock2D::new(&vs_up_blocks / index, in_channels, out_channels, cfg); 167 | up_blocks.push(up_block) 168 | } 169 | let group_cfg = nn::GroupNormConfig { eps: 1e-6, ..Default::default() }; 170 | let conv_norm_out = nn::group_norm( 171 | &vs / "conv_norm_out", 172 | config.norm_num_groups, 173 | config.block_out_channels[0], 174 | group_cfg, 175 | ); 176 | let conv_cfg = nn::ConvConfig { padding: 1, ..Default::default() }; 177 | let conv_out = 178 | nn::conv2d(&vs / "conv_out", config.block_out_channels[0], out_channels, 3, conv_cfg); 179 | Self { conv_in, up_blocks, mid_block, conv_norm_out, conv_out, config } 180 | } 181 | } 182 | 183 | impl Module for Decoder { 184 | fn forward(&self, xs: &Tensor) -> Tensor { 185 | let mut xs = self.mid_block.forward(&xs.apply(&self.conv_in), None); 186 | for up_block in self.up_blocks.iter() { 187 | xs = xs.apply(up_block) 188 | } 189 | xs.apply(&self.conv_norm_out).silu().apply(&self.conv_out) 190 | } 191 | } 192 | 193 | #[derive(Debug, Clone)] 194 | pub struct AutoEncoderKLConfig { 195 | pub block_out_channels: Vec, 196 | pub layers_per_block: i64, 197 | pub latent_channels: i64, 198 | pub norm_num_groups: i64, 199 | } 200 | 201 | impl Default for AutoEncoderKLConfig { 202 | fn default() -> Self { 203 | Self { 204 | block_out_channels: vec![64], 205 | layers_per_block: 1, 206 | latent_channels: 4, 207 | norm_num_groups: 32, 208 | } 209 | } 210 | } 211 | 212 | pub struct DiagonalGaussianDistribution { 213 | mean: Tensor, 214 | std: Tensor, 215 | device: tch::Device, 216 | } 217 | 218 | impl DiagonalGaussianDistribution 
{ 219 | pub fn new(parameters: &Tensor) -> Self { 220 | let mut parameters = parameters.chunk(2, 1).into_iter(); 221 | let mean = parameters.next().unwrap(); 222 | let logvar = parameters.next().unwrap(); 223 | let std = (logvar * 0.5).exp(); 224 | let device = std.device(); 225 | DiagonalGaussianDistribution { mean, std, device } 226 | } 227 | 228 | pub fn sample(&self) -> Tensor { 229 | let sample = Tensor::randn_like(&self.mean).to(self.device); 230 | &self.mean + &self.std * sample 231 | } 232 | } 233 | 234 | // https://github.com/huggingface/diffusers/blob/970e30606c2944e3286f56e8eb6d3dc6d1eb85f7/src/diffusers/models/vae.py#L485 235 | // This implementation is specific to the config used in stable-diffusion-v1-5 236 | // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json 237 | #[derive(Debug)] 238 | pub struct AutoEncoderKL { 239 | encoder: Encoder, 240 | decoder: Decoder, 241 | quant_conv: nn::Conv2D, 242 | post_quant_conv: nn::Conv2D, 243 | pub config: AutoEncoderKLConfig, 244 | } 245 | 246 | impl AutoEncoderKL { 247 | pub fn new( 248 | vs: nn::Path, 249 | in_channels: i64, 250 | out_channels: i64, 251 | config: AutoEncoderKLConfig, 252 | ) -> Self { 253 | let latent_channels = config.latent_channels; 254 | let encoder_cfg = EncoderConfig { 255 | block_out_channels: config.block_out_channels.clone(), 256 | layers_per_block: config.layers_per_block, 257 | norm_num_groups: config.norm_num_groups, 258 | double_z: true, 259 | }; 260 | let encoder = Encoder::new(&vs / "encoder", in_channels, latent_channels, encoder_cfg); 261 | let decoder_cfg = DecoderConfig { 262 | block_out_channels: config.block_out_channels.clone(), 263 | layers_per_block: config.layers_per_block, 264 | norm_num_groups: config.norm_num_groups, 265 | }; 266 | let decoder = Decoder::new(&vs / "decoder", latent_channels, out_channels, decoder_cfg); 267 | let conv_cfg = Default::default(); 268 | let quant_conv = 269 | nn::conv2d(&vs / "quant_conv", 2 * latent_channels, 2 * latent_channels, 1, conv_cfg); 270 | let post_quant_conv = 271 | nn::conv2d(&vs / "post_quant_conv", latent_channels, latent_channels, 1, conv_cfg); 272 | Self { encoder, decoder, quant_conv, post_quant_conv, config } 273 | } 274 | 275 | /// Returns the distribution in the latent space. 276 | pub fn encode(&self, xs: &Tensor) -> DiagonalGaussianDistribution { 277 | let parameters = xs.apply(&self.encoder).apply(&self.quant_conv); 278 | DiagonalGaussianDistribution::new(¶meters) 279 | } 280 | 281 | /// Takes as input some sampled values. 282 | pub fn decode(&self, xs: &Tensor) -> Tensor { 283 | xs.apply(&self.post_quant_conv).apply(&self.decoder) 284 | } 285 | } 286 | -------------------------------------------------------------------------------- /src/pipelines/mod.rs: -------------------------------------------------------------------------------- 1 | //! 
# Pipelines 2 | 3 | pub mod stable_diffusion; 4 | -------------------------------------------------------------------------------- /src/pipelines/stable_diffusion.rs: -------------------------------------------------------------------------------- 1 | use crate::models::{unet_2d, vae}; 2 | use crate::schedulers::ddim; 3 | use crate::schedulers::PredictionType; 4 | use crate::transformers::clip; 5 | use tch::{nn, Device}; 6 | 7 | #[derive(Clone, Debug)] 8 | pub struct StableDiffusionConfig { 9 | pub width: i64, 10 | pub height: i64, 11 | pub clip: clip::Config, 12 | autoencoder: vae::AutoEncoderKLConfig, 13 | unet: unet_2d::UNet2DConditionModelConfig, 14 | scheduler: ddim::DDIMSchedulerConfig, 15 | } 16 | 17 | impl StableDiffusionConfig { 18 | pub fn v1_5( 19 | sliced_attention_size: Option, 20 | height: Option, 21 | width: Option, 22 | ) -> Self { 23 | let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { 24 | out_channels, 25 | use_cross_attn, 26 | attention_head_dim, 27 | }; 28 | // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/unet/config.json 29 | let unet = unet_2d::UNet2DConditionModelConfig { 30 | blocks: vec![bc(320, true, 8), bc(640, true, 8), bc(1280, true, 8), bc(1280, false, 8)], 31 | center_input_sample: false, 32 | cross_attention_dim: 768, 33 | downsample_padding: 1, 34 | flip_sin_to_cos: true, 35 | freq_shift: 0., 36 | layers_per_block: 2, 37 | mid_block_scale_factor: 1., 38 | norm_eps: 1e-5, 39 | norm_num_groups: 32, 40 | sliced_attention_size, 41 | use_linear_projection: false, 42 | }; 43 | let autoencoder = vae::AutoEncoderKLConfig { 44 | block_out_channels: vec![128, 256, 512, 512], 45 | layers_per_block: 2, 46 | latent_channels: 4, 47 | norm_num_groups: 32, 48 | }; 49 | let height = if let Some(height) = height { 50 | assert_eq!(height % 8, 0, "heigh has to be divisible by 8"); 51 | height 52 | } else { 53 | 512 54 | }; 55 | 56 | let width = if let Some(width) = width { 57 | assert_eq!(width % 8, 0, "width has to be divisible by 8"); 58 | width 59 | } else { 60 | 512 61 | }; 62 | 63 | Self { 64 | width, 65 | height, 66 | clip: clip::Config::v1_5(), 67 | autoencoder, 68 | scheduler: Default::default(), 69 | unet, 70 | } 71 | } 72 | 73 | fn v2_1_( 74 | sliced_attention_size: Option, 75 | height: Option, 76 | width: Option, 77 | prediction_type: PredictionType, 78 | ) -> Self { 79 | let bc = |out_channels, use_cross_attn, attention_head_dim| unet_2d::BlockConfig { 80 | out_channels, 81 | use_cross_attn, 82 | attention_head_dim, 83 | }; 84 | // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/unet/config.json 85 | let unet = unet_2d::UNet2DConditionModelConfig { 86 | blocks: vec![ 87 | bc(320, true, 5), 88 | bc(640, true, 10), 89 | bc(1280, true, 20), 90 | bc(1280, false, 20), 91 | ], 92 | center_input_sample: false, 93 | cross_attention_dim: 1024, 94 | downsample_padding: 1, 95 | flip_sin_to_cos: true, 96 | freq_shift: 0., 97 | layers_per_block: 2, 98 | mid_block_scale_factor: 1., 99 | norm_eps: 1e-5, 100 | norm_num_groups: 32, 101 | sliced_attention_size, 102 | use_linear_projection: true, 103 | }; 104 | // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/vae/config.json 105 | let autoencoder = vae::AutoEncoderKLConfig { 106 | block_out_channels: vec![128, 256, 512, 512], 107 | layers_per_block: 2, 108 | latent_channels: 4, 109 | norm_num_groups: 32, 110 | }; 111 | let scheduler = ddim::DDIMSchedulerConfig { prediction_type, ..Default::default() }; 112 | 113 | let height = if let 
Some(height) = height { 114 | assert_eq!(height % 8, 0, "heigh has to be divisible by 8"); 115 | height 116 | } else { 117 | 768 118 | }; 119 | 120 | let width = if let Some(width) = width { 121 | assert_eq!(width % 8, 0, "width has to be divisible by 8"); 122 | width 123 | } else { 124 | 768 125 | }; 126 | 127 | Self { width, height, clip: clip::Config::v2_1(), autoencoder, scheduler, unet } 128 | } 129 | 130 | pub fn v2_1( 131 | sliced_attention_size: Option, 132 | height: Option, 133 | width: Option, 134 | ) -> Self { 135 | // https://huggingface.co/stabilityai/stable-diffusion-2-1/blob/main/scheduler/scheduler_config.json 136 | Self::v2_1_(sliced_attention_size, height, width, PredictionType::VPrediction) 137 | } 138 | 139 | pub fn v2_1_inpaint( 140 | sliced_attention_size: Option, 141 | height: Option, 142 | width: Option, 143 | ) -> Self { 144 | // https://huggingface.co/stabilityai/stable-diffusion-2-inpainting/blob/main/scheduler/scheduler_config.json 145 | // This uses a PNDM scheduler rather than DDIM but the biggest difference is the prediction 146 | // type being "epsilon" by default and not "v_prediction". 147 | Self::v2_1_(sliced_attention_size, height, width, PredictionType::Epsilon) 148 | } 149 | 150 | pub fn build_vae( 151 | &self, 152 | vae_weights: &str, 153 | device: Device, 154 | ) -> anyhow::Result { 155 | let mut vs_ae = nn::VarStore::new(device); 156 | // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json 157 | let autoencoder = vae::AutoEncoderKL::new(vs_ae.root(), 3, 3, self.autoencoder.clone()); 158 | vs_ae.load(vae_weights)?; 159 | Ok(autoencoder) 160 | } 161 | 162 | pub fn build_unet( 163 | &self, 164 | unet_weights: &str, 165 | device: Device, 166 | in_channels: i64, 167 | ) -> anyhow::Result { 168 | let mut vs_unet = nn::VarStore::new(device); 169 | let unet = 170 | unet_2d::UNet2DConditionModel::new(vs_unet.root(), in_channels, 4, self.unet.clone()); 171 | vs_unet.load(unet_weights)?; 172 | Ok(unet) 173 | } 174 | 175 | pub fn build_scheduler(&self, n_steps: usize) -> ddim::DDIMScheduler { 176 | ddim::DDIMScheduler::new(n_steps, self.scheduler) 177 | } 178 | 179 | pub fn build_clip_transformer( 180 | &self, 181 | clip_weights: &str, 182 | device: tch::Device, 183 | ) -> anyhow::Result { 184 | let mut vs = tch::nn::VarStore::new(device); 185 | let text_model = clip::ClipTextTransformer::new(vs.root(), &self.clip); 186 | vs.load(clip_weights)?; 187 | Ok(text_model) 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /src/schedulers/ddim.rs: -------------------------------------------------------------------------------- 1 | //! # Denoising Diffusion Implicit Models 2 | //! 3 | //! The Denoising Diffusion Implicit Models (DDIM) is a simple scheduler 4 | //! similar to Denoising Diffusion Probabilistic Models (DDPM). The DDPM 5 | //! generative process is the reverse of a Markovian process, DDIM generalizes 6 | //! this to non-Markovian guidance. 7 | //! 8 | //! Denoising Diffusion Implicit Models, J. Song et al, 2020. 9 | //! https://arxiv.org/abs/2010.02502 10 | use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; 11 | use tch::{kind, Kind, Tensor}; 12 | 13 | /// The configuration for the DDIM scheduler. 14 | #[derive(Debug, Clone, Copy)] 15 | pub struct DDIMSchedulerConfig { 16 | /// The value of beta at the beginning of training. 17 | pub beta_start: f64, 18 | /// The value of beta at the end of training. 
19 | pub beta_end: f64, 20 | /// How beta evolved during training. 21 | pub beta_schedule: BetaSchedule, 22 | /// The amount of noise to be added at each step. 23 | pub eta: f64, 24 | /// Adjust the indexes of the inference schedule by this value. 25 | pub steps_offset: usize, 26 | /// prediction type of the scheduler function, one of `epsilon` (predicting 27 | /// the noise of the diffusion process), `sample` (directly predicting the noisy sample`) 28 | /// or `v_prediction` (see section 2.4 https://imagen.research.google/video/paper.pdf) 29 | pub prediction_type: PredictionType, 30 | /// number of diffusion steps used to train the model 31 | pub train_timesteps: usize, 32 | } 33 | 34 | impl Default for DDIMSchedulerConfig { 35 | fn default() -> Self { 36 | Self { 37 | beta_start: 0.00085f64, 38 | beta_end: 0.012f64, 39 | beta_schedule: BetaSchedule::ScaledLinear, 40 | eta: 0., 41 | steps_offset: 1, 42 | prediction_type: PredictionType::Epsilon, 43 | train_timesteps: 1000, 44 | } 45 | } 46 | } 47 | 48 | /// The DDIM scheduler. 49 | #[derive(Debug, Clone)] 50 | pub struct DDIMScheduler { 51 | timesteps: Vec, 52 | alphas_cumprod: Vec, 53 | step_ratio: usize, 54 | init_noise_sigma: f64, 55 | pub config: DDIMSchedulerConfig, 56 | } 57 | 58 | // clip_sample: False, set_alpha_to_one: False 59 | impl DDIMScheduler { 60 | /// Creates a new DDIM scheduler given the number of steps to be 61 | /// used for inference as well as the number of steps that was used 62 | /// during training. 63 | pub fn new(inference_steps: usize, config: DDIMSchedulerConfig) -> Self { 64 | let step_ratio = config.train_timesteps / inference_steps; 65 | let timesteps: Vec = 66 | (0..(inference_steps)).map(|s| s * step_ratio + config.steps_offset).rev().collect(); 67 | let betas = match config.beta_schedule { 68 | BetaSchedule::ScaledLinear => Tensor::linspace( 69 | config.beta_start.sqrt(), 70 | config.beta_end.sqrt(), 71 | config.train_timesteps as i64, 72 | kind::FLOAT_CPU, 73 | ) 74 | .square(), 75 | BetaSchedule::Linear => Tensor::linspace( 76 | config.beta_start, 77 | config.beta_end, 78 | config.train_timesteps as i64, 79 | kind::FLOAT_CPU, 80 | ), 81 | BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999), 82 | }; 83 | let alphas: Tensor = 1.0 - betas; 84 | let alphas_cumprod = Vec::::try_from(alphas.cumprod(0, Kind::Double)).unwrap(); 85 | Self { alphas_cumprod, timesteps, step_ratio, init_noise_sigma: 1., config } 86 | } 87 | 88 | pub fn timesteps(&self) -> &[usize] { 89 | self.timesteps.as_slice() 90 | } 91 | 92 | /// Ensures interchangeability with schedulers that need to scale the denoising model input 93 | /// depending on the current timestep. 94 | pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { 95 | sample 96 | } 97 | 98 | /// Performs a backward step during inference. 99 | pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { 100 | let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { timestep }; 101 | // https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195 102 | let prev_timestep = if timestep > self.step_ratio { timestep - self.step_ratio } else { 0 }; 103 | 104 | let alpha_prod_t = self.alphas_cumprod[timestep]; 105 | let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]; 106 | let beta_prod_t = 1. - alpha_prod_t; 107 | let beta_prod_t_prev = 1. 
- alpha_prod_t_prev; 108 | 109 | let (pred_original_sample, pred_epsilon) = match self.config.prediction_type { 110 | PredictionType::Epsilon => { 111 | let pred_original_sample = 112 | (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt(); 113 | (pred_original_sample, model_output.shallow_clone()) 114 | } 115 | PredictionType::VPrediction => { 116 | let pred_original_sample = 117 | alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output; 118 | let pred_epsilon = alpha_prod_t.sqrt() * model_output + beta_prod_t.sqrt() * sample; 119 | (pred_original_sample, pred_epsilon) 120 | } 121 | PredictionType::Sample => { 122 | let pred_original_sample = model_output.shallow_clone(); 123 | let pred_epsilon = 124 | (sample - alpha_prod_t.sqrt() * &pred_original_sample) / beta_prod_t.sqrt(); 125 | (pred_original_sample, pred_epsilon) 126 | } 127 | }; 128 | 129 | let variance = (beta_prod_t_prev / beta_prod_t) * (1. - alpha_prod_t / alpha_prod_t_prev); 130 | let std_dev_t = self.config.eta * variance.sqrt(); 131 | 132 | let pred_sample_direction = 133 | (1. - alpha_prod_t_prev - std_dev_t * std_dev_t).sqrt() * pred_epsilon; 134 | let prev_sample = alpha_prod_t_prev.sqrt() * pred_original_sample + pred_sample_direction; 135 | if self.config.eta > 0. { 136 | &prev_sample + Tensor::randn_like(&prev_sample) * std_dev_t 137 | } else { 138 | prev_sample 139 | } 140 | } 141 | 142 | pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Tensor { 143 | let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { timestep }; 144 | let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt(); 145 | let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); 146 | sqrt_alpha_prod * original + sqrt_one_minus_alpha_prod * noise 147 | } 148 | 149 | pub fn init_noise_sigma(&self) -> f64 { 150 | self.init_noise_sigma 151 | } 152 | } 153 | -------------------------------------------------------------------------------- /src/schedulers/ddpm.rs: -------------------------------------------------------------------------------- 1 | use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; 2 | use tch::{kind, Kind, Tensor}; 3 | 4 | #[derive(Debug, Clone, PartialEq, Eq)] 5 | pub enum DDPMVarianceType { 6 | FixedSmall, 7 | FixedSmallLog, 8 | FixedLarge, 9 | FixedLargeLog, 10 | Learned, 11 | } 12 | 13 | impl Default for DDPMVarianceType { 14 | fn default() -> Self { 15 | Self::FixedSmall 16 | } 17 | } 18 | 19 | #[derive(Debug, Clone)] 20 | pub struct DDPMSchedulerConfig { 21 | /// The value of beta at the beginning of training. 22 | pub beta_start: f64, 23 | /// The value of beta at the end of training. 24 | pub beta_end: f64, 25 | /// How beta evolved during training. 26 | pub beta_schedule: BetaSchedule, 27 | /// Option to predicted sample between -1 and 1 for numerical stability. 28 | pub clip_sample: bool, 29 | /// Option to clip the variance used when adding noise to the denoised sample. 30 | pub variance_type: DDPMVarianceType, 31 | /// prediction type of the scheduler function 32 | pub prediction_type: PredictionType, 33 | /// number of diffusion steps used to train the model. 
34 | pub train_timesteps: usize, 35 | } 36 | 37 | impl Default for DDPMSchedulerConfig { 38 | fn default() -> Self { 39 | Self { 40 | beta_start: 0.00085, 41 | beta_end: 0.012, 42 | beta_schedule: BetaSchedule::ScaledLinear, 43 | clip_sample: false, 44 | variance_type: DDPMVarianceType::FixedSmall, 45 | prediction_type: PredictionType::Epsilon, 46 | train_timesteps: 1000, 47 | } 48 | } 49 | } 50 | 51 | pub struct DDPMScheduler { 52 | alphas_cumprod: Vec, 53 | init_noise_sigma: f64, 54 | timesteps: Vec, 55 | step_ratio: usize, 56 | pub config: DDPMSchedulerConfig, 57 | } 58 | 59 | impl DDPMScheduler { 60 | pub fn new(inference_steps: usize, config: DDPMSchedulerConfig) -> Self { 61 | let betas = match config.beta_schedule { 62 | BetaSchedule::ScaledLinear => Tensor::linspace( 63 | config.beta_start.sqrt(), 64 | config.beta_end.sqrt(), 65 | config.train_timesteps as i64, 66 | kind::FLOAT_CPU, 67 | ) 68 | .square(), 69 | BetaSchedule::Linear => Tensor::linspace( 70 | config.beta_start, 71 | config.beta_end, 72 | config.train_timesteps as i64, 73 | kind::FLOAT_CPU, 74 | ), 75 | 76 | BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999), 77 | }; 78 | 79 | // &betas to avoid moving it 80 | let alphas: Tensor = 1. - betas; 81 | let alphas_cumprod = Vec::::try_from(alphas.cumprod(0, Kind::Double)).unwrap(); 82 | 83 | // min(train_timesteps, inference_steps) 84 | // https://github.com/huggingface/diffusers/blob/8331da46837be40f96fbd24de6a6fb2da28acd11/src/diffusers/schedulers/scheduling_ddpm.py#L187 85 | let inference_steps = inference_steps.min(config.train_timesteps); 86 | // arange the number of the scheduler's timesteps 87 | let step_ratio = config.train_timesteps / inference_steps; 88 | let timesteps: Vec = (0..inference_steps).map(|s| s * step_ratio).rev().collect(); 89 | 90 | Self { alphas_cumprod, init_noise_sigma: 1.0, timesteps, step_ratio, config } 91 | } 92 | 93 | fn get_variance(&self, timestep: usize) -> f64 { 94 | let prev_t = timestep as isize - self.step_ratio as isize; 95 | let alpha_prod_t = self.alphas_cumprod[timestep]; 96 | let alpha_prod_t_prev = 97 | if prev_t >= 0 { self.alphas_cumprod[prev_t as usize] } else { 1.0 }; 98 | let current_beta_t = 1. - alpha_prod_t / alpha_prod_t_prev; 99 | 100 | // For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf) 101 | // and sample from it to get previous sample 102 | // x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample 103 | let variance = (1. - alpha_prod_t_prev) / (1. - alpha_prod_t) * current_beta_t; 104 | 105 | // retrieve variance 106 | match self.config.variance_type { 107 | DDPMVarianceType::FixedSmall => variance.max(1e-20), 108 | // for rl-diffuser https://arxiv.org/abs/2205.09991 109 | DDPMVarianceType::FixedSmallLog => { 110 | let variance = variance.max(1e-20).ln(); 111 | (variance * 0.5).exp() 112 | } 113 | DDPMVarianceType::FixedLarge => current_beta_t, 114 | DDPMVarianceType::FixedLargeLog => current_beta_t.ln(), 115 | DDPMVarianceType::Learned => variance, 116 | } 117 | } 118 | 119 | pub fn timesteps(&self) -> &[usize] { 120 | self.timesteps.as_slice() 121 | } 122 | 123 | /// Ensures interchangeability with schedulers that need to scale the denoising model input 124 | /// depending on the current timestep. 
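/// For DDPM this is the identity, but routing the model input through it keeps a
/// sampling loop scheduler-agnostic. A minimal sketch (`scheduler`, `latents` and the
/// `unet` closure are assumed bindings, with `unet` standing in for the denoising
/// model rather than anything provided by this crate):
///
/// ```ignore
/// // Copy the timesteps first so the same loop also works with the stateful
/// // schedulers in this crate whose `step` takes `&mut self`.
/// let timesteps = scheduler.timesteps().to_vec();
/// for &t in timesteps.iter() {
///     let input = scheduler.scale_model_input(latents.shallow_clone(), t);
///     let noise_pred = unet(&input, t);
///     latents = scheduler.step(&noise_pred, t, &latents);
/// }
/// ```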
125 | pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { 126 | sample 127 | } 128 | 129 | pub fn step(&self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { 130 | let prev_t = timestep as isize - self.step_ratio as isize; 131 | 132 | // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L272 133 | // 1. compute alphas, betas 134 | let alpha_prod_t = self.alphas_cumprod[timestep]; 135 | let alpha_prod_t_prev = 136 | if prev_t >= 0 { self.alphas_cumprod[prev_t as usize] } else { 1.0 }; 137 | let beta_prod_t = 1. - alpha_prod_t; 138 | let beta_prod_t_prev = 1. - alpha_prod_t_prev; 139 | let current_alpha_t = alpha_prod_t / alpha_prod_t_prev; 140 | let current_beta_t = 1. - current_alpha_t; 141 | 142 | // 2. compute predicted original sample from predicted noise also called "predicted x_0" of formula (15) 143 | let mut pred_original_sample = match self.config.prediction_type { 144 | PredictionType::Epsilon => { 145 | (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt() 146 | } 147 | PredictionType::Sample => model_output.shallow_clone(), 148 | PredictionType::VPrediction => { 149 | alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output 150 | } 151 | }; 152 | 153 | // 3. clip predicted x_0 154 | if self.config.clip_sample { 155 | pred_original_sample = pred_original_sample.clamp(-1., 1.); 156 | } 157 | 158 | // 4. Compute coefficients for pred_original_sample x_0 and current sample x_t 159 | // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 160 | let pred_original_sample_coeff = (alpha_prod_t_prev.sqrt() * current_beta_t) / beta_prod_t; 161 | let current_sample_coeff = current_alpha_t.sqrt() * beta_prod_t_prev / beta_prod_t; 162 | 163 | // 5. Compute predicted previous sample µ_t 164 | // See formula (7) from https://arxiv.org/pdf/2006.11239.pdf 165 | let pred_prev_sample = 166 | pred_original_sample_coeff * &pred_original_sample + current_sample_coeff * sample; 167 | 168 | // https://github.com/huggingface/diffusers/blob/df2b548e893ccb8a888467c2508756680df22821/src/diffusers/schedulers/scheduling_ddpm.py#L305 169 | // 6. Add noise 170 | let mut variance = model_output.zeros_like(); 171 | if timestep > 0 { 172 | let variance_noise = model_output.randn_like(); 173 | if self.config.variance_type == DDPMVarianceType::FixedSmallLog { 174 | variance = self.get_variance(timestep) * variance_noise; 175 | } else { 176 | variance = self.get_variance(timestep).sqrt() * variance_noise; 177 | } 178 | } 179 | &pred_prev_sample + variance 180 | } 181 | 182 | pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: usize) -> Tensor { 183 | self.alphas_cumprod[timestep].sqrt() * original_samples 184 | + (1. - self.alphas_cumprod[timestep]).sqrt() * noise 185 | } 186 | 187 | pub fn init_noise_sigma(&self) -> f64 { 188 | self.init_noise_sigma 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /src/schedulers/euler_ancestral_discrete.rs: -------------------------------------------------------------------------------- 1 | use super::{interp, BetaSchedule, PredictionType}; 2 | use tch::{kind, Kind, Tensor}; 3 | 4 | #[derive(Debug, Clone)] 5 | pub struct EulerAncestralDiscreteSchedulerConfig { 6 | /// The value of beta at the beginning of training. 7 | pub beta_start: f64, 8 | /// The value of beta at the end of training. 9 | pub beta_end: f64, 10 | /// How beta evolved during training. 
11 | pub beta_schedule: BetaSchedule, 12 | /// number of diffusion steps used to train the model. 13 | pub train_timesteps: usize, 14 | /// prediction type of the scheduler function 15 | pub prediction_type: PredictionType, 16 | } 17 | 18 | impl Default for EulerAncestralDiscreteSchedulerConfig { 19 | fn default() -> Self { 20 | Self { 21 | beta_start: 0.00085, 22 | beta_end: 0.012, 23 | beta_schedule: BetaSchedule::ScaledLinear, 24 | train_timesteps: 1000, 25 | prediction_type: PredictionType::Epsilon, 26 | } 27 | } 28 | } 29 | 30 | /// Ancestral sampling with Euler method steps. 31 | /// Based on the original k-diffusion implementation by Katherine Crowson: 32 | /// 33 | /// https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 34 | #[derive(Clone)] 35 | pub struct EulerAncestralDiscreteScheduler { 36 | timesteps: Vec, 37 | sigmas: Vec, 38 | init_noise_sigma: f64, 39 | pub config: EulerAncestralDiscreteSchedulerConfig, 40 | } 41 | 42 | impl EulerAncestralDiscreteScheduler { 43 | pub fn new(inference_steps: usize, config: EulerAncestralDiscreteSchedulerConfig) -> Self { 44 | let betas = match config.beta_schedule { 45 | BetaSchedule::ScaledLinear => Tensor::linspace( 46 | config.beta_start.sqrt(), 47 | config.beta_end.sqrt(), 48 | config.train_timesteps as i64, 49 | kind::FLOAT_CPU, 50 | ) 51 | .square(), 52 | BetaSchedule::Linear => Tensor::linspace( 53 | config.beta_start, 54 | config.beta_end, 55 | config.train_timesteps as i64, 56 | kind::FLOAT_CPU, 57 | ), 58 | _ => unimplemented!( 59 | "EulerAncestralDiscreteScheduler only implements linear and scaled_linear betas." 60 | ), 61 | }; 62 | 63 | let alphas: Tensor = 1. - betas; 64 | let alphas_cumprod = alphas.cumprod(0, Kind::Double); 65 | 66 | let timesteps = Tensor::linspace( 67 | (config.train_timesteps - 1) as f64, 68 | 0., 69 | inference_steps as i64, 70 | kind::FLOAT_CPU, 71 | ); 72 | 73 | let sigmas = ((1. - &alphas_cumprod) as Tensor / &alphas_cumprod).sqrt(); 74 | let sigmas = interp( 75 | ×teps, // x-coordinates at which to evaluate the interpolated values 76 | Tensor::range(0, sigmas.size1().unwrap() - 1, kind::FLOAT_CPU), 77 | sigmas, 78 | ); 79 | 80 | let sigmas = Tensor::concat(&[sigmas, Tensor::from_slice(&[0.0])], 0); 81 | 82 | // standard deviation of the initial noise distribution 83 | let init_noise_sigma: f64 = sigmas.max().try_into().unwrap(); 84 | Self { 85 | timesteps: timesteps.try_into().unwrap(), 86 | sigmas: sigmas.try_into().unwrap(), 87 | init_noise_sigma, 88 | config, 89 | } 90 | } 91 | 92 | pub fn timesteps(&self) -> &[f64] { 93 | self.timesteps.as_slice() 94 | } 95 | 96 | pub fn scale_model_input(&self, sample: Tensor, timestep: f64) -> Tensor { 97 | let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); 98 | let sigma = self.sigmas[step_index]; 99 | 100 | // https://github.com/huggingface/diffusers/blob/aba2a65d6ab47c0d1c12fa47e9b238c1d3e34512/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py#L132 101 | sample / (sigma.powi(2) + 1.).sqrt() 102 | } 103 | 104 | pub fn step(&self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor { 105 | let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); 106 | let sigma = self.sigmas[step_index]; 107 | 108 | // 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise 109 | let pred_original_sample = match self.config.prediction_type { 110 | PredictionType::Epsilon => sample - sigma * model_output, 111 | PredictionType::VPrediction => { 112 | model_output * (-sigma / (sigma.powi(2) + 1.).sqrt()) 113 | + (sample / (sigma.powi(2) + 1.)) 114 | } 115 | _ => unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`"), 116 | }; 117 | 118 | let sigma_from = self.sigmas[step_index]; 119 | let sigma_to = self.sigmas[step_index + 1]; 120 | let sigma_up = (sigma_to.powi(2) * (sigma_from.powi(2) - sigma_to.powi(2)) 121 | / sigma_from.powi(2)) 122 | .sqrt(); 123 | let sigma_down = (sigma_to.powi(2) - sigma_up.powi(2)).sqrt(); 124 | 125 | // 2. Convert to an ODE derivative 126 | let derivative = (sample - pred_original_sample) / sigma; 127 | let dt = sigma_down - sigma; 128 | 129 | let prev_sample = sample + derivative * dt; 130 | let noise = Tensor::randn_like(model_output); 131 | 132 | prev_sample + noise * sigma_up 133 | } 134 | 135 | pub fn init_noise_sigma(&self) -> f64 { 136 | self.init_noise_sigma 137 | } 138 | 139 | pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: f64) -> Tensor { 140 | let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); 141 | let sigma = self.sigmas[step_index]; 142 | 143 | // noisy samples 144 | original_samples + noise * sigma 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/schedulers/euler_discrete.rs: -------------------------------------------------------------------------------- 1 | use super::{interp, BetaSchedule, PredictionType}; 2 | use tch::{kind, Kind, Tensor}; 3 | 4 | #[derive(Debug, Clone)] 5 | pub struct EulerDiscreteSchedulerConfig { 6 | /// The value of beta at the beginning of training. 7 | pub beta_start: f64, 8 | /// The value of beta at the end of training. 9 | pub beta_end: f64, 10 | /// How beta evolved during training. 11 | pub beta_schedule: BetaSchedule, 12 | /// number of diffusion steps used to train the model. 13 | pub train_timesteps: usize, 14 | /// prediction type of the scheduler function 15 | pub prediction_type: PredictionType, 16 | } 17 | 18 | impl Default for EulerDiscreteSchedulerConfig { 19 | fn default() -> Self { 20 | Self { 21 | beta_start: 0.00085, 22 | beta_end: 0.012, 23 | beta_schedule: BetaSchedule::ScaledLinear, 24 | train_timesteps: 1000, 25 | prediction_type: PredictionType::Epsilon, 26 | } 27 | } 28 | } 29 | 30 | /// Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. 
31 | /// Based on the original 32 | /// k-diffusion implementation by Katherine Crowson: 33 | /// https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51 34 | #[derive(Clone)] 35 | pub struct EulerDiscreteScheduler { 36 | timesteps: Vec, 37 | sigmas: Vec, 38 | init_noise_sigma: f64, 39 | pub config: EulerDiscreteSchedulerConfig, 40 | } 41 | 42 | impl EulerDiscreteScheduler { 43 | pub fn new(inference_steps: usize, config: EulerDiscreteSchedulerConfig) -> Self { 44 | let betas = match config.beta_schedule { 45 | BetaSchedule::ScaledLinear => Tensor::linspace( 46 | config.beta_start.sqrt(), 47 | config.beta_end.sqrt(), 48 | config.train_timesteps as i64, 49 | kind::FLOAT_CPU, 50 | ) 51 | .square(), 52 | BetaSchedule::Linear => Tensor::linspace( 53 | config.beta_start, 54 | config.beta_end, 55 | config.train_timesteps as i64, 56 | kind::FLOAT_CPU, 57 | ), 58 | _ => unimplemented!( 59 | "EulerDiscreteScheduler only implements linear and scaled_linear betas." 60 | ), 61 | }; 62 | 63 | let alphas: Tensor = 1. - betas; 64 | let alphas_cumprod = alphas.cumprod(0, Kind::Double); 65 | 66 | let timesteps = Tensor::linspace( 67 | (config.train_timesteps - 1) as f64, 68 | 0., 69 | inference_steps as i64, 70 | kind::FLOAT_CPU, 71 | ); 72 | 73 | let sigmas = ((1. - &alphas_cumprod) as Tensor / &alphas_cumprod).sqrt(); 74 | let sigmas = interp( 75 | ×teps, // x-coordinates at which to evaluate the interpolated values 76 | Tensor::range(0, sigmas.size1().unwrap() - 1, kind::FLOAT_CPU), 77 | sigmas, 78 | ); 79 | let sigmas = Tensor::concat(&[sigmas, Tensor::from_slice(&[0.0])], 0); 80 | 81 | // standard deviation of the initial noise distribution 82 | let init_noise_sigma: f64 = sigmas.max().try_into().unwrap(); 83 | 84 | Self { 85 | timesteps: timesteps.try_into().unwrap(), 86 | sigmas: sigmas.try_into().unwrap(), 87 | init_noise_sigma, 88 | config, 89 | } 90 | } 91 | 92 | pub fn timesteps(&self) -> &[f64] { 93 | self.timesteps.as_slice() 94 | } 95 | 96 | pub fn scale_model_input(&self, sample: Tensor, timestep: f64) -> Tensor { 97 | let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); 98 | let sigma = self.sigmas[step_index]; 99 | 100 | // https://github.com/huggingface/diffusers/blob/2bd53a940c60d13421d9e8887af96b30a53c1b95/src/diffusers/schedulers/scheduling_euler_discrete.py#L133 101 | sample / (sigma.powi(2) + 1.).sqrt() 102 | } 103 | 104 | pub fn step(&self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor { 105 | let (s_churn, s_tmin, s_tmax, s_noise) = (0.0, 0.0, f64::INFINITY, 1.0); 106 | 107 | let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); 108 | let sigma = self.sigmas[step_index]; 109 | 110 | let gamma = if s_tmin <= sigma && sigma <= s_tmax { 111 | (s_churn / (self.sigmas.len() as f64 - 1.)).min(2.0_f64.sqrt() - 1.) 112 | } else { 113 | 0.0 114 | }; 115 | 116 | let noise = Tensor::randn_like(model_output); 117 | let eps = noise * s_noise; 118 | let sigma_hat = sigma * (gamma + 1.); 119 | 120 | let sample = if gamma > 0.0 { 121 | sample + eps * (sigma_hat.powi(2) - sigma.powi(2)).sqrt() 122 | } else { 123 | sample.shallow_clone() 124 | }; 125 | 126 | // 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise 127 | let pred_original_sample = match self.config.prediction_type { 128 | PredictionType::Epsilon => &sample - sigma_hat * model_output, 129 | PredictionType::VPrediction => { 130 | model_output * (-sigma / (sigma.powi(2) + 1.).sqrt()) 131 | + (&sample / (sigma.powi(2) + 1.)) 132 | } 133 | _ => unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`"), 134 | }; 135 | 136 | // 2. Convert to an ODE derivative 137 | let derivative = (&sample - pred_original_sample) / sigma_hat; 138 | let dt = self.sigmas[step_index + 1] - sigma_hat; 139 | 140 | sample + derivative * dt 141 | } 142 | 143 | pub fn init_noise_sigma(&self) -> f64 { 144 | self.init_noise_sigma 145 | } 146 | 147 | pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: f64) -> Tensor { 148 | let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); 149 | let sigma = self.sigmas[step_index]; 150 | 151 | original_samples + noise * sigma 152 | } 153 | } 154 | -------------------------------------------------------------------------------- /src/schedulers/heun_discrete.rs: -------------------------------------------------------------------------------- 1 | use super::{interp, BetaSchedule, PredictionType}; 2 | use tch::{kind, IndexOp, Kind, Tensor}; 3 | 4 | #[derive(Debug, Clone)] 5 | pub struct HeunDiscreteSchedulerConfig { 6 | /// The value of beta at the beginning of training. 7 | pub beta_start: f64, 8 | /// The value of beta at the end of training. 9 | pub beta_end: f64, 10 | /// How beta evolved during training. 11 | pub beta_schedule: BetaSchedule, 12 | /// number of diffusion steps used to train the model. 13 | pub train_timesteps: usize, 14 | /// prediction type of the scheduler function 15 | pub prediction_type: PredictionType, 16 | } 17 | 18 | impl Default for HeunDiscreteSchedulerConfig { 19 | fn default() -> Self { 20 | Self { 21 | beta_start: 0.00085, // sensible defaults 22 | beta_end: 0.012, 23 | beta_schedule: BetaSchedule::Linear, 24 | train_timesteps: 1000, 25 | prediction_type: PredictionType::Epsilon, 26 | } 27 | } 28 | } 29 | 30 | pub struct HeunDiscreteScheduler { 31 | timesteps: Vec, 32 | sigmas: Vec, 33 | init_noise_sigma: f64, 34 | prev_derivative: Option, 35 | sample: Option, 36 | dt: Option, 37 | pub config: HeunDiscreteSchedulerConfig, 38 | } 39 | 40 | impl HeunDiscreteScheduler { 41 | pub fn new(inference_steps: usize, config: HeunDiscreteSchedulerConfig) -> Self { 42 | let betas = match config.beta_schedule { 43 | BetaSchedule::ScaledLinear => Tensor::linspace( 44 | config.beta_start.sqrt(), 45 | config.beta_end.sqrt(), 46 | config.train_timesteps as i64, 47 | kind::FLOAT_CPU, 48 | ) 49 | .square(), 50 | BetaSchedule::Linear => Tensor::linspace( 51 | config.beta_start, 52 | config.beta_end, 53 | config.train_timesteps as i64, 54 | kind::FLOAT_CPU, 55 | ), 56 | _ => unimplemented!( 57 | "HeunDiscreteScheduler only implements linear and scaled_linear betas." 58 | ), 59 | }; 60 | 61 | let alphas: Tensor = 1. - betas; 62 | let alphas_cumprod = alphas.cumprod(0, Kind::Double); 63 | 64 | let timesteps = Tensor::linspace( 65 | (config.train_timesteps - 1) as f64, 66 | 0., 67 | inference_steps as i64, 68 | kind::FLOAT_CPU, 69 | ); 70 | 71 | let sigmas = ((1. 
- &alphas_cumprod) as Tensor / &alphas_cumprod).sqrt(); 72 | let sigmas = interp( 73 | ×teps, // x-coordinates at which to evaluate the interpolated values 74 | Tensor::range(0, sigmas.size1().unwrap() - 1, kind::FLOAT_CPU), 75 | sigmas, 76 | ); 77 | 78 | // https://github.com/huggingface/diffusers/blob/aba2a65d6ab47c0d1c12fa47e9b238c1d3e34512/src/diffusers/schedulers/scheduling_heun_discrete.py#L132-L134 79 | let sigmas = Tensor::cat( 80 | &[ 81 | // sigmas[:1] 82 | sigmas.i(..1), 83 | // sigmas[1:].repeat_interleave(2) 84 | sigmas.i(1..).repeat_interleave_self_int(2, 0, None), 85 | // append 0.0 86 | Tensor::from_slice(&[0.0]), 87 | ], 88 | 0, 89 | ); 90 | 91 | let init_noise_sigma: f64 = sigmas.max().try_into().unwrap(); 92 | 93 | // https://github.com/huggingface/diffusers/blob/aba2a65d6ab47c0d1c12fa47e9b238c1d3e34512/src/diffusers/schedulers/scheduling_heun_discrete.py#L140 94 | let timesteps = Tensor::cat( 95 | &[ 96 | // timesteps[:1] 97 | timesteps.i(..1), 98 | // timesteps[1:].repeat_interleave(2) 99 | timesteps.i(1..).repeat_interleave_self_int(2, 0, None), 100 | ], 101 | 0, 102 | ); 103 | 104 | Self { 105 | timesteps: timesteps.try_into().unwrap(), 106 | sigmas: sigmas.try_into().unwrap(), 107 | prev_derivative: None, 108 | dt: None, 109 | sample: None, 110 | init_noise_sigma, 111 | config, 112 | } 113 | } 114 | 115 | pub fn timesteps(&self) -> &[f64] { 116 | self.timesteps.as_slice() 117 | } 118 | 119 | fn index_for_timestep(&self, timestep: f64) -> usize { 120 | // find all the positions of the timesteps corresponding to timestep 121 | let indices = self 122 | .timesteps 123 | .iter() 124 | .enumerate() 125 | .filter_map(|(idx, &t)| (t == timestep).then_some(idx)) 126 | .collect::>(); 127 | 128 | if self.state_in_first_order() { 129 | *indices.last().unwrap() 130 | } else { 131 | indices[0] 132 | } 133 | } 134 | 135 | /// Ensures interchangeability with schedulers that need to scale the denoising model input depending on the 136 | /// current timestep. 137 | pub fn scale_model_input(&self, sample: Tensor, timestep: f64) -> Tensor { 138 | let step_index = self.index_for_timestep(timestep); 139 | let sigma = self.sigmas[step_index]; 140 | 141 | // https://github.com/huggingface/diffusers/blob/aba2a65d6ab47c0d1c12fa47e9b238c1d3e34512/src/diffusers/schedulers/scheduling_heun_discrete.py#L106 142 | sample / (sigma.powi(2) + 1.).sqrt() 143 | } 144 | 145 | fn state_in_first_order(&self) -> bool { 146 | self.dt.is_none() 147 | } 148 | 149 | pub fn step(&mut self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor { 150 | let step_index = self.index_for_timestep(timestep); 151 | 152 | let (sigma, sigma_next) = if self.state_in_first_order() { 153 | (self.sigmas[step_index], self.sigmas[step_index + 1]) 154 | } else { 155 | // 2nd order / Heun's method 156 | (self.sigmas[step_index - 1], self.sigmas[step_index]) 157 | }; 158 | 159 | // currently only gamma=0 is supported. This usually works best anyways. 160 | // We can support gamma in the future but then need to scale the timestep before 161 | // passing it to the model which requires a change in API 162 | let gamma = 0.0; 163 | let sigma_hat = sigma * (gamma + 1.); // sigma_hat == sigma for now 164 | 165 | // 1. 
compute predicted original sample (x_0) from sigma-scaled predicted noise 166 | let sigma_input = if self.state_in_first_order() { sigma_hat } else { sigma_next }; 167 | 168 | let pred_original_sample = match self.config.prediction_type { 169 | PredictionType::Epsilon => sample - sigma_input * model_output, 170 | PredictionType::VPrediction => { 171 | model_output * (-sigma_input / (sigma_input.powi(2) + 1.).sqrt()) 172 | + (sample / (sigma_input.powi(2) + 1.)) 173 | } 174 | _ => unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`"), 175 | }; 176 | 177 | let (derivative, dt, sample) = if self.state_in_first_order() { 178 | // 2. Convert to an ODE derivative for 1st order 179 | ( 180 | (sample - pred_original_sample) / sigma_hat, 181 | sigma_next - sigma_hat, 182 | sample.shallow_clone(), 183 | ) 184 | } else { 185 | // 2. 2nd order / Heun's method 186 | let derivative = (sample - &pred_original_sample) / sigma_next; 187 | ( 188 | (self.prev_derivative.as_ref().unwrap() + derivative) / 2., 189 | self.dt.unwrap(), 190 | self.sample.as_ref().unwrap().shallow_clone(), 191 | ) 192 | }; 193 | 194 | if self.state_in_first_order() { 195 | // store for 2nd order step 196 | self.prev_derivative = Some(derivative.shallow_clone()); 197 | self.dt = Some(dt); 198 | self.sample = Some(sample.shallow_clone()); 199 | } else { 200 | // free dt and derivative 201 | // Note, this puts the scheduler in "first order mode" 202 | self.prev_derivative = None; 203 | self.dt = None; 204 | self.sample = None; 205 | } 206 | 207 | sample + derivative * dt 208 | } 209 | 210 | pub fn init_noise_sigma(&self) -> f64 { 211 | self.init_noise_sigma 212 | } 213 | 214 | pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: f64) -> Tensor { 215 | let step_index = self.index_for_timestep(timestep); 216 | let sigma = self.sigmas[step_index]; 217 | 218 | // noisy samples 219 | original_samples + noise * sigma 220 | } 221 | } 222 | -------------------------------------------------------------------------------- /src/schedulers/k_dpm_2_ancestral_discrete.rs: -------------------------------------------------------------------------------- 1 | use super::{interp, BetaSchedule, PredictionType}; 2 | use tch::{kind, IndexOp, Kind, Tensor}; 3 | 4 | #[derive(Debug, Clone)] 5 | pub struct KDPM2AncestralDiscreteSchedulerConfig { 6 | /// The value of beta at the beginning of training. 7 | pub beta_start: f64, 8 | /// The value of beta at the end of training. 9 | pub beta_end: f64, 10 | /// How beta evolved during training. 11 | pub beta_schedule: BetaSchedule, 12 | /// number of diffusion steps used to train the model. 13 | pub train_timesteps: usize, 14 | /// prediction type of the scheduler function 15 | pub prediction_type: PredictionType, 16 | } 17 | 18 | impl Default for KDPM2AncestralDiscreteSchedulerConfig { 19 | fn default() -> Self { 20 | Self { 21 | beta_start: 0.00085, // sensible defaults 22 | beta_end: 0.012, 23 | beta_schedule: BetaSchedule::ScaledLinear, 24 | train_timesteps: 1000, 25 | prediction_type: PredictionType::Epsilon, 26 | } 27 | } 28 | } 29 | 30 | /// Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: 31 | /// https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188 32 | /// 33 | /// Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022). 
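///
/// A minimal construction sketch (the 30 inference steps, the latent shape and the
/// `denoise` closure standing in for a UNet forward pass are illustrative
/// assumptions):
///
/// ```ignore
/// let mut scheduler = KDPM2AncestralDiscreteScheduler::new(
///     30,
///     KDPM2AncestralDiscreteSchedulerConfig::default(),
/// );
/// // Initial latents are pure noise scaled by the scheduler's maximum sigma.
/// let mut latents =
///     Tensor::randn(&[1, 4, 64, 64], kind::FLOAT_CPU) * scheduler.init_noise_sigma();
/// // `timesteps()` interleaves first- and second-order sub-steps and `step` takes
/// // `&mut self` (it keeps the first-order sample for the second-order update),
/// // so collect the timesteps before iterating.
/// for &t in scheduler.timesteps().to_vec().iter() {
///     let input = scheduler.scale_model_input(latents.shallow_clone(), t);
///     latents = scheduler.step(&denoise(&input, t), t, &latents);
/// }
/// ```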
34 | pub struct KDPM2AncestralDiscreteScheduler { 35 | timesteps: Vec, 36 | sigmas: Vec, 37 | sigmas_interpol: Vec, 38 | sigmas_up: Vec, 39 | sigmas_down: Vec, 40 | init_noise_sigma: f64, 41 | sample: Option, 42 | pub config: KDPM2AncestralDiscreteSchedulerConfig, 43 | } 44 | 45 | impl KDPM2AncestralDiscreteScheduler { 46 | pub fn new(inference_steps: usize, config: KDPM2AncestralDiscreteSchedulerConfig) -> Self { 47 | let betas = match config.beta_schedule { 48 | BetaSchedule::ScaledLinear => Tensor::linspace( 49 | config.beta_start.sqrt(), 50 | config.beta_end.sqrt(), 51 | config.train_timesteps as i64, 52 | kind::FLOAT_CPU, 53 | ) 54 | .square(), 55 | BetaSchedule::Linear => Tensor::linspace( 56 | config.beta_start, 57 | config.beta_end, 58 | config.train_timesteps as i64, 59 | kind::FLOAT_CPU, 60 | ), 61 | _ => unimplemented!( 62 | "KDPM2AncestralDiscreteScheduler only implements linear and scaled_linear betas." 63 | ), 64 | }; 65 | 66 | let alphas: Tensor = 1. - betas; 67 | let alphas_cumprod = alphas.cumprod(0, Kind::Double); 68 | 69 | let timesteps = Tensor::linspace( 70 | (config.train_timesteps - 1) as f64, 71 | 0., 72 | inference_steps as i64, 73 | kind::FLOAT_CPU, 74 | ); 75 | 76 | let sigmas = ((1. - &alphas_cumprod) as Tensor / alphas_cumprod).sqrt(); 77 | let log_sigmas = sigmas.log(); 78 | 79 | let sigmas = interp( 80 | ×teps, // x-coordinates at which to evaluate the interpolated values 81 | Tensor::range(0, sigmas.size1().unwrap() - 1, kind::FLOAT_CPU), 82 | sigmas, 83 | ); 84 | // append 0.0 85 | let sigmas = Tensor::concat(&[sigmas, [0.0].as_slice().into()], 0); 86 | let sz = sigmas.size1().unwrap(); 87 | 88 | // compute up and down sigmas 89 | let sigmas_next = sigmas.roll([-1], [0]); 90 | // sigmas_next[-1] = 0.0 91 | let sigmas_next = sigmas_next.index_fill(0, &[sz - 1].as_slice().into(), 0.0); 92 | let sigmas_up = (sigmas_next.square() * (sigmas.square() - sigmas_next.square()) 93 | / sigmas.square()) 94 | .sqrt(); 95 | let sigmas_down = (sigmas_next.square() - sigmas_up.square()).sqrt(); 96 | // sigmas_down[-1] = 0.0 97 | let sigmas_down = sigmas_down.index_fill(0, &[sz - 1].as_slice().into(), 0.0); 98 | 99 | // interpolate sigmas 100 | let sigmas_interpol = sigmas.log().lerp(&sigmas_down.log(), 0.5).exp(); 101 | // sigmas_interpol[-2] = 0.0 102 | let sigmas_interpol = 103 | sigmas_interpol.index_fill(0, &[sz - 2, sz - 1].as_slice().into(), 0.0); 104 | 105 | // interpolate timesteps 106 | let timesteps_interpol = Self::sigma_to_t(&sigmas_interpol, log_sigmas); 107 | let interleaved_timesteps = Tensor::stack( 108 | &[ 109 | // timesteps_interpol[:-2, None] 110 | timesteps_interpol.slice(0, 0, -2, 1).unsqueeze(-1), 111 | // timesteps[1:, None] 112 | timesteps.i(1..).unsqueeze(-1), 113 | ], 114 | -1, 115 | ) 116 | .flatten(0, -1); 117 | 118 | // set sigmas 119 | let sigmas = Tensor::cat( 120 | &[ 121 | // sigmas[:1] 122 | sigmas.i(..1), 123 | // sigmas[1:].repeat_interleave(2) 124 | sigmas.i(1..).repeat_interleave_self_int(2, 0, None), 125 | //sigmas[-1:] 126 | sigmas.i(-1..0), 127 | ], 128 | 0, 129 | ); 130 | // https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py#L155-L157 131 | let sigmas_interpol = Tensor::cat( 132 | &[ 133 | // sigmas_interpol[:1] 134 | sigmas_interpol.i(..1), 135 | // sigmas_interpol[1:].repeat_interleave(2) 136 | sigmas_interpol.i(1..).repeat_interleave_self_int(2, 0, None), 137 | //sigmas_interpol[-1:] 138 | sigmas_interpol.i(-1..0), 139 | ], 140 | 0, 141 | ); 142 | 143 | let 
sigmas_up = Tensor::cat( 144 | &[ 145 | // sigmas_up[:1] 146 | sigmas_up.i(..1), 147 | // sigmas_up[1:].repeat_interleave(2) 148 | sigmas_up.i(1..).repeat_interleave_self_int(2, 0, None), 149 | // sigmas_up[-1:] 150 | sigmas_up.i(-1..0), 151 | ], 152 | 0, 153 | ); 154 | let sigmas_down = Tensor::cat( 155 | &[ 156 | // sigmas_down[:1] 157 | sigmas_down.i(..1), 158 | // sigmas_down[1:].repeat_interleave(2) 159 | sigmas_down.i(1..).repeat_interleave_self_int(2, 0, None), 160 | // sigmas_up[-1:] 161 | sigmas_down.i(-1..0), 162 | ], 163 | 0, 164 | ); 165 | 166 | // https://github.com/huggingface/diffusers/blob/9b37ed33b5fa09e594b38e4e6f7477beff3bd66a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py#L158 167 | let timesteps = Tensor::cat( 168 | &[ 169 | // timesteps[:1] 170 | timesteps.i(..1), 171 | interleaved_timesteps, 172 | ], 173 | 0, 174 | ); 175 | 176 | // standard deviation of the initial noise distribution 177 | let init_noise_sigma: f64 = sigmas.max().try_into().unwrap(); 178 | 179 | Self { 180 | timesteps: timesteps.try_into().unwrap(), 181 | sigmas: sigmas.try_into().unwrap(), 182 | sigmas_interpol: sigmas_interpol.try_into().unwrap(), 183 | sigmas_up: sigmas_up.try_into().unwrap(), 184 | sigmas_down: sigmas_down.try_into().unwrap(), 185 | init_noise_sigma, 186 | sample: None, 187 | config, 188 | } 189 | } 190 | 191 | fn sigma_to_t(sigma: &Tensor, log_sigmas: Tensor) -> Tensor { 192 | // get log sigma 193 | let log_sigma = sigma.log(); 194 | 195 | // get distribution 196 | let dists = &log_sigma - log_sigmas.unsqueeze(-1); 197 | 198 | // get sigmas range 199 | let low_idx = dists 200 | .ge(0) 201 | .cumsum(0, Kind::Int64) 202 | .argmax(0, false) 203 | .clamp_max(log_sigmas.size1().unwrap() - 2); 204 | let high_idx = &low_idx + 1; 205 | 206 | let low = log_sigmas.index_select(0, &low_idx); 207 | let high = log_sigmas.index_select(0, &high_idx); 208 | 209 | // interpolate sigmas 210 | let w = (&low - log_sigma) / (low - high); 211 | let w = w.clamp(0., 1.); 212 | 213 | // transform interpolation to time range 214 | let t: Tensor = (1 - &w) * low_idx + w * high_idx; 215 | 216 | t.view(sigma.size().as_slice()) 217 | } 218 | 219 | pub fn timesteps(&self) -> &[f64] { 220 | self.timesteps.as_slice() 221 | } 222 | 223 | fn index_for_timestep(&self, timestep: f64) -> usize { 224 | // find all the positions of the timesteps corresponding to timestep 225 | let indices = self 226 | .timesteps 227 | .iter() 228 | .enumerate() 229 | .filter_map(|(idx, &t)| (t == timestep).then_some(idx)) 230 | .collect::>(); 231 | 232 | if self.state_in_first_order() { 233 | *indices.last().unwrap() 234 | } else { 235 | indices[0] 236 | } 237 | } 238 | 239 | /// Scales model input by (sigma^2 + 1) ^ .5 240 | pub fn scale_model_input(&self, sample: Tensor, timestep: f64) -> Tensor { 241 | let step_index = self.index_for_timestep(timestep); 242 | let step_index_minus_one = 243 | if step_index == 0 { self.sigmas.len() - 1 } else { step_index - 1 }; 244 | 245 | let sigma = if self.state_in_first_order() { 246 | self.sigmas[step_index] 247 | } else { 248 | self.sigmas_interpol[step_index_minus_one] 249 | }; 250 | 251 | sample / (sigma.powi(2) + 1.).sqrt() 252 | } 253 | 254 | fn state_in_first_order(&self) -> bool { 255 | self.sample.is_none() 256 | } 257 | 258 | pub fn step(&mut self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor { 259 | let step_index = self.index_for_timestep(timestep); 260 | let step_index_minus_one = 261 | if step_index == 0 { self.sigmas.len() - 1 } else { step_index - 1 
}; 262 | 263 | let (sigma, sigma_interpol, sigma_up, sigma_down) = if self.state_in_first_order() { 264 | ( 265 | self.sigmas[step_index], 266 | self.sigmas_interpol[step_index], 267 | self.sigmas_up[step_index], 268 | self.sigmas_down[step_index_minus_one], 269 | ) 270 | } else { 271 | // 2nd order / KDPM2's method 272 | ( 273 | self.sigmas[step_index_minus_one], 274 | self.sigmas_interpol[step_index_minus_one], 275 | self.sigmas_up[step_index_minus_one], 276 | self.sigmas_down[step_index_minus_one], 277 | ) 278 | }; 279 | 280 | // currently only gamma=0 is supported. This usually works best anyways. 281 | // We can support gamma in the future but then need to scale the timestep before 282 | // passing it to the model which requires a change in API 283 | let gamma = 0.0; 284 | let sigma_hat = sigma * (gamma + 1.); // sigma_hat == sigma for now 285 | 286 | let noise = model_output.randn_like(); 287 | 288 | // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 289 | let sigma_input = if self.state_in_first_order() { sigma_hat } else { sigma_interpol }; 290 | let pred_original_sample = match self.config.prediction_type { 291 | PredictionType::Epsilon => sample - sigma_input * model_output, 292 | PredictionType::VPrediction => { 293 | model_output * (-sigma_input / (sigma_input.powi(2) + 1.).sqrt()) 294 | + (sample / (sigma_input.powi(2) + 1.)) 295 | } 296 | _ => unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`"), 297 | }; 298 | 299 | let mut prev_sample; 300 | if self.state_in_first_order() { 301 | // 2. Convert to an ODE derivative for 1st order 302 | let derivative = (sample - pred_original_sample) / sigma_hat; 303 | // 3. delta timestep 304 | let dt = sigma_interpol - sigma_hat; 305 | 306 | // store for 2nd order step 307 | self.sample = Some(sample.shallow_clone()); 308 | prev_sample = sample + derivative * dt; 309 | } else { 310 | // DPM-Solver-2 311 | // 2. Convert to an ODE derivative for 2nd order 312 | let derivative = (sample - pred_original_sample) / sigma_interpol; 313 | // 3. delta timestep 314 | let dt = sigma_down - sigma_hat; 315 | 316 | let sample = self.sample.as_ref().unwrap().shallow_clone(); 317 | self.sample = None; 318 | 319 | prev_sample = sample + derivative * dt; 320 | prev_sample += noise * sigma_up; 321 | } 322 | 323 | prev_sample 324 | } 325 | 326 | pub fn init_noise_sigma(&self) -> f64 { 327 | self.init_noise_sigma 328 | } 329 | 330 | pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: f64) -> Tensor { 331 | let step_index = self.index_for_timestep(timestep); 332 | let sigma = self.sigmas[step_index]; 333 | 334 | // noisy samples 335 | original_samples + noise * sigma 336 | } 337 | } 338 | -------------------------------------------------------------------------------- /src/schedulers/k_dpm_2_discrete.rs: -------------------------------------------------------------------------------- 1 | use super::{interp, BetaSchedule, PredictionType}; 2 | use tch::{kind, IndexOp, Kind, Tensor}; 3 | 4 | #[derive(Debug, Clone)] 5 | pub struct KDPM2DiscreteSchedulerConfig { 6 | /// The value of beta at the beginning of training. 7 | pub beta_start: f64, 8 | /// The value of beta at the end of training. 9 | pub beta_end: f64, 10 | /// How beta evolved during training. 11 | pub beta_schedule: BetaSchedule, 12 | /// number of diffusion steps used to train the model. 
13 | pub train_timesteps: usize, 14 | /// prediction type of the scheduler function 15 | pub prediction_type: PredictionType, 16 | } 17 | 18 | impl Default for KDPM2DiscreteSchedulerConfig { 19 | fn default() -> Self { 20 | Self { 21 | beta_start: 0.00085, // sensible defaults 22 | beta_end: 0.012, 23 | beta_schedule: BetaSchedule::ScaledLinear, 24 | train_timesteps: 1000, 25 | prediction_type: PredictionType::Epsilon, 26 | } 27 | } 28 | } 29 | 30 | /// Scheduler created by @crowsonkb in [k_diffusion](https://github.com/crowsonkb/k-diffusion), see: 31 | /// https://github.com/crowsonkb/k-diffusion/blob/5b3af030dd83e0297272d861c19477735d0317ec/k_diffusion/sampling.py#L188 32 | /// 33 | /// Scheduler inspired by DPM-Solver-2 and Algorthim 2 from Karras et al. (2022). 34 | pub struct KDPM2DiscreteScheduler { 35 | timesteps: Vec, 36 | sigmas: Vec, 37 | sigmas_interpol: Vec, 38 | init_noise_sigma: f64, 39 | sample: Option, 40 | pub config: KDPM2DiscreteSchedulerConfig, 41 | } 42 | 43 | impl KDPM2DiscreteScheduler { 44 | pub fn new(inference_steps: usize, config: KDPM2DiscreteSchedulerConfig) -> Self { 45 | let betas = match config.beta_schedule { 46 | BetaSchedule::ScaledLinear => Tensor::linspace( 47 | config.beta_start.sqrt(), 48 | config.beta_end.sqrt(), 49 | config.train_timesteps as i64, 50 | kind::FLOAT_CPU, 51 | ) 52 | .square(), 53 | BetaSchedule::Linear => Tensor::linspace( 54 | config.beta_start, 55 | config.beta_end, 56 | config.train_timesteps as i64, 57 | kind::FLOAT_CPU, 58 | ), 59 | _ => unimplemented!( 60 | "KDPM2DiscreteScheduler only implements linear and scaled_linear betas." 61 | ), 62 | }; 63 | 64 | let alphas: Tensor = 1. - betas; 65 | let alphas_cumprod = alphas.cumprod(0, Kind::Double); 66 | 67 | let timesteps = Tensor::linspace( 68 | (config.train_timesteps - 1) as f64, 69 | 0., 70 | inference_steps as i64, 71 | kind::FLOAT_CPU, 72 | ); 73 | 74 | let sigmas = ((1. 
- &alphas_cumprod) as Tensor / alphas_cumprod).sqrt(); 75 | let log_sigmas = sigmas.log(); 76 | 77 | let sigmas = interp( 78 | &timesteps, // x-coordinates at which to evaluate the interpolated values 79 | Tensor::range(0, sigmas.size1().unwrap() - 1, kind::FLOAT_CPU), 80 | sigmas, 81 | ); 82 | // append 0.0 83 | let sigmas = Tensor::concat(&[sigmas, Tensor::from_slice(&[0.0])], 0); 84 | 85 | // interpolate sigmas 86 | let sigmas_interpol = sigmas.log().lerp(&sigmas.roll([1], [0]).log(), 0.5).exp(); 87 | 88 | // https://github.com/huggingface/diffusers/blob/9b37ed33b5fa09e594b38e4e6f7477beff3bd66a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py#L145 89 | let sigmas = Tensor::cat( 90 | &[ 91 | // sigmas[:1] 92 | sigmas.i(..1), 93 | // sigmas[1:].repeat_interleave(2) 94 | sigmas.i(1..).repeat_interleave_self_int(2, 0, None), 95 | //sigmas[-1:] 96 | sigmas.i(-1..0), 97 | ], 98 | 0, 99 | ); 100 | 101 | let init_noise_sigma: f64 = sigmas.max().try_into().unwrap(); 102 | 103 | // interpolate timesteps 104 | let timesteps_interpol = Self::sigma_to_t(&sigmas_interpol, log_sigmas); 105 | let interleaved_timesteps = Tensor::stack( 106 | &[ 107 | // timesteps_interpol[1:-1, None] 108 | timesteps_interpol.slice(0, 1, -1, 1).unsqueeze(-1), 109 | // timesteps[1:, None] 110 | timesteps.i(1..).unsqueeze(-1), 111 | ], 112 | -1, 113 | ) 114 | .flatten(0, -1); 115 | 116 | // https://github.com/huggingface/diffusers/blob/9b37ed33b5fa09e594b38e4e6f7477beff3bd66a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py#L146-L148 117 | let sigmas_interpol = Tensor::cat( 118 | &[ 119 | // sigmas_interpol[:1] 120 | sigmas_interpol.i(..1), 121 | // sigmas_interpol[1:].repeat_interleave(2) 122 | sigmas_interpol.i(1..).repeat_interleave_self_int(2, 0, None), 123 | //sigmas_interpol[-1:] 124 | sigmas_interpol.i(-1..0), 125 | ], 126 | 0, 127 | ); 128 | // https://github.com/huggingface/diffusers/blob/9b37ed33b5fa09e594b38e4e6f7477beff3bd66a/src/diffusers/schedulers/scheduling_k_dpm_2_discrete.py#L158 129 | let timesteps = Tensor::cat( 130 | &[ 131 | // timesteps[:1] 132 | timesteps.i(..1), 133 | interleaved_timesteps, 134 | ], 135 | 0, 136 | ); 137 | 138 | Self { 139 | timesteps: timesteps.try_into().unwrap(), 140 | sigmas: sigmas.try_into().unwrap(), 141 | sigmas_interpol: sigmas_interpol.try_into().unwrap(), 142 | init_noise_sigma, 143 | sample: None, 144 | config, 145 | } 146 | } 147 | 148 | fn sigma_to_t(sigma: &Tensor, log_sigmas: Tensor) -> Tensor { 149 | // get log sigma 150 | let log_sigma = sigma.log(); 151 | 152 | // get distribution 153 | let dists = &log_sigma - log_sigmas.unsqueeze(-1); 154 | 155 | // get sigmas range 156 | let low_idx = dists 157 | .ge(0) 158 | .cumsum(0, Kind::Int64) 159 | .argmax(0, false) 160 | .clamp_max(log_sigmas.size1().unwrap() - 2); 161 | let high_idx = &low_idx + 1; 162 | 163 | let low = log_sigmas.index_select(0, &low_idx); 164 | let high = log_sigmas.index_select(0, &high_idx); 165 | 166 | // interpolate sigmas 167 | let w = (&low - log_sigma) / (low - high); 168 | let w = w.clamp(0., 1.); 169 | 170 | // transform interpolation to time range 171 | let t: Tensor = (1 - &w) * low_idx + w * high_idx; 172 | 173 | t.view(sigma.size().as_slice()) 174 | } 175 | 176 | pub fn timesteps(&self) -> &[f64] { 177 | self.timesteps.as_slice() 178 | } 179 | 180 | fn index_for_timestep(&self, timestep: f64) -> usize { 181 | // find all the positions of the timesteps corresponding to timestep 182 | let indices = self 183 | .timesteps 184 | .iter() 185 | .enumerate() 186 | .filter_map(|(idx,
&t)| (t == timestep).then_some(idx)) 187 | .collect::<Vec<_>>(); 188 | 189 | if self.state_in_first_order() { 190 | *indices.last().unwrap() 191 | } else { 192 | indices[0] 193 | } 194 | } 195 | 196 | /// Scales model input by (sigma^2 + 1) ^ .5 197 | pub fn scale_model_input(&self, sample: Tensor, timestep: f64) -> Tensor { 198 | let step_index = self.index_for_timestep(timestep); 199 | 200 | let sigma = if self.state_in_first_order() { 201 | self.sigmas[step_index] 202 | } else { 203 | self.sigmas_interpol[step_index] 204 | }; 205 | 206 | sample / (sigma.powi(2) + 1.).sqrt() 207 | } 208 | 209 | fn state_in_first_order(&self) -> bool { 210 | self.sample.is_none() 211 | } 212 | 213 | pub fn step(&mut self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor { 214 | let step_index = self.index_for_timestep(timestep); 215 | 216 | let (sigma, sigma_interpol, sigma_next) = if self.state_in_first_order() { 217 | ( 218 | self.sigmas[step_index], 219 | self.sigmas_interpol[step_index + 1], 220 | self.sigmas[step_index + 1], 221 | ) 222 | } else { 223 | // 2nd order / KDPM2's method 224 | ( 225 | self.sigmas[step_index - 1], 226 | self.sigmas_interpol[step_index + 1], 227 | self.sigmas[step_index], 228 | ) 229 | }; 230 | 231 | // currently only gamma=0 is supported. This usually works best anyways. 232 | // We can support gamma in the future but then need to scale the timestep before 233 | // passing it to the model which requires a change in API 234 | let gamma = 0.0; 235 | let sigma_hat = sigma * (gamma + 1.); // sigma_hat == sigma for now 236 | 237 | // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 238 | let sigma_input = if self.state_in_first_order() { sigma_hat } else { sigma_interpol }; 239 | let pred_original_sample = match self.config.prediction_type { 240 | PredictionType::Epsilon => sample - sigma_input * model_output, 241 | PredictionType::VPrediction => { 242 | model_output * (-sigma_input / (sigma_input.powi(2) + 1.).sqrt()) 243 | + (sample / (sigma_input.powi(2) + 1.)) 244 | } 245 | _ => unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`"), 246 | }; 247 | 248 | let (derivative, dt, sample) = if self.state_in_first_order() { 249 | ( 250 | // 2. Convert to an ODE derivative for 1st order 251 | (sample - pred_original_sample) / sigma_hat, 252 | // 3. delta timestep 253 | sigma_interpol - sigma_hat, 254 | sample.shallow_clone(), 255 | ) 256 | } else { 257 | ( 258 | // DPM-Solver-2 259 | // 2. Convert to an ODE derivative for 2nd order 260 | (sample - pred_original_sample) / sigma_interpol, 261 | // 3.
delta timestep 262 | sigma_next - sigma_hat, 263 | self.sample.as_ref().unwrap().shallow_clone(), 264 | ) 265 | }; 266 | 267 | if self.state_in_first_order() { 268 | // store for 2nd order step 269 | self.sample = Some(sample.shallow_clone()); 270 | } else { 271 | self.sample = None 272 | }; 273 | 274 | sample + derivative * dt 275 | } 276 | 277 | pub fn init_noise_sigma(&self) -> f64 { 278 | self.init_noise_sigma 279 | } 280 | 281 | pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: f64) -> Tensor { 282 | let step_index = self.index_for_timestep(timestep); 283 | let sigma = self.sigmas[step_index]; 284 | 285 | // noisy samples 286 | original_samples + noise * sigma 287 | } 288 | } 289 | -------------------------------------------------------------------------------- /src/schedulers/lms_discrete.rs: -------------------------------------------------------------------------------- 1 | use super::integrate::integrate; 2 | use super::{interp, BetaSchedule, PredictionType}; 3 | use tch::{kind, Kind, Tensor}; 4 | 5 | #[derive(Debug, Clone)] 6 | pub struct LMSDiscreteSchedulerConfig { 7 | /// The value of beta at the beginning of training. 8 | pub beta_start: f64, 9 | /// The value of beta at the end of training. 10 | pub beta_end: f64, 11 | /// How beta evolved during training. 12 | pub beta_schedule: BetaSchedule, 13 | /// number of diffusion steps used to train the model. 14 | pub train_timesteps: usize, 15 | /// coefficient for multi-step inference. 16 | /// https://github.com/huggingface/diffusers/blob/9b37ed33b5fa09e594b38e4e6f7477beff3bd66a/src/diffusers/schedulers/scheduling_lms_discrete.py#L189 17 | pub order: usize, 18 | /// prediction type of the scheduler function 19 | pub prediction_type: PredictionType, 20 | } 21 | 22 | impl Default for LMSDiscreteSchedulerConfig { 23 | fn default() -> Self { 24 | Self { 25 | beta_start: 0.00085, 26 | beta_end: 0.012, 27 | beta_schedule: BetaSchedule::ScaledLinear, 28 | train_timesteps: 1000, 29 | order: 4, 30 | prediction_type: PredictionType::Epsilon, 31 | } 32 | } 33 | } 34 | 35 | pub struct LMSDiscreteScheduler { 36 | timesteps: Vec<f64>, 37 | sigmas: Vec<f64>, 38 | init_noise_sigma: f64, 39 | derivatives: Vec<Tensor>, 40 | pub config: LMSDiscreteSchedulerConfig, 41 | } 42 | 43 | impl LMSDiscreteScheduler { 44 | pub fn new(inference_steps: usize, config: LMSDiscreteSchedulerConfig) -> Self { 45 | let betas = match config.beta_schedule { 46 | BetaSchedule::ScaledLinear => Tensor::linspace( 47 | config.beta_start.sqrt(), 48 | config.beta_end.sqrt(), 49 | config.train_timesteps as i64, 50 | kind::FLOAT_CPU, 51 | ) 52 | .square(), 53 | BetaSchedule::Linear => Tensor::linspace( 54 | config.beta_start, 55 | config.beta_end, 56 | config.train_timesteps as i64, 57 | kind::FLOAT_CPU, 58 | ), 59 | _ => unimplemented!( 60 | "LMSDiscreteScheduler only implements linear and scaled_linear betas." 61 | ), 62 | }; 63 | 64 | let alphas: Tensor = 1. - betas; 65 | let alphas_cumprod = alphas.cumprod(0, Kind::Double); 66 | 67 | let timesteps = Tensor::linspace( 68 | (config.train_timesteps - 1) as f64, 69 | 0., 70 | inference_steps as i64, 71 | kind::FLOAT_CPU, 72 | ); 73 | 74 | let sigmas = ((1.
- &alphas_cumprod) as Tensor / &alphas_cumprod).sqrt(); 75 | let sigmas = interp( 76 | &timesteps, // x-coordinates at which to evaluate the interpolated values 77 | Tensor::range(0, sigmas.size1().unwrap() - 1, kind::FLOAT_CPU), 78 | sigmas, 79 | ); 80 | let sigmas = Tensor::concat(&[sigmas, Tensor::from_slice(&[0.0])], 0); 81 | 82 | // standard deviation of the initial noise distribution 83 | let init_noise_sigma: f64 = sigmas.max().try_into().unwrap(); 84 | 85 | Self { 86 | timesteps: timesteps.try_into().unwrap(), 87 | sigmas: sigmas.try_into().unwrap(), 88 | init_noise_sigma, 89 | derivatives: vec![], 90 | config, 91 | } 92 | } 93 | 94 | pub fn timesteps(&self) -> &[f64] { 95 | self.timesteps.as_slice() 96 | } 97 | 98 | /// Scales the denoising model input by `(sigma^2 + 1)^0.5` to match the K-LMS algorithm. 99 | pub fn scale_model_input(&self, sample: Tensor, timestep: f64) -> Tensor { 100 | let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); 101 | let sigma = self.sigmas[step_index]; 102 | 103 | // https://github.com/huggingface/diffusers/blob/769f0be8fb41daca9f3cbcffcfd0dbf01cc194b8/src/diffusers/schedulers/scheduling_lms_discrete.py#L132 104 | sample / (sigma.powi(2) + 1.).sqrt() 105 | } 106 | 107 | /// Compute a linear multistep coefficient 108 | fn get_lms_coefficient(&mut self, order: usize, t: usize, current_order: usize) -> f64 { 109 | let lms_derivative = |tau| -> f64 { 110 | let mut prod = 1.0; 111 | for k in 0..order { 112 | if current_order == k { 113 | continue; 114 | } 115 | prod *= (tau - self.sigmas[t - k]) 116 | / (self.sigmas[t - current_order] - self.sigmas[t - k]); 117 | } 118 | prod 119 | }; 120 | 121 | // Integrate `lms_derivative` over two consecutive timesteps. 122 | // Absolute tolerances and limit are taken from 123 | // the defaults of `scipy.integrate.quad` 124 | // https://docs.scipy.org/doc/scipy/reference/generated/scipy.integrate.quad.html 125 | let integration_out = 126 | integrate(lms_derivative, self.sigmas[t], self.sigmas[t + 1], 1.49e-8); 127 | // integrated coeff 128 | integration_out.integral 129 | } 130 | 131 | pub fn step(&mut self, model_output: &Tensor, timestep: f64, sample: &Tensor) -> Tensor { 132 | let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); 133 | let sigma = self.sigmas[step_index]; 134 | 135 | // 1. compute predicted original sample (x_0) from sigma-scaled predicted noise 136 | let pred_original_sample = match self.config.prediction_type { 137 | PredictionType::Epsilon => sample - sigma * model_output, 138 | PredictionType::VPrediction => { 139 | model_output * (-sigma / (sigma.powi(2) + 1.).sqrt()) 140 | + (sample / (sigma.powi(2) + 1.)) 141 | } 142 | _ => unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`"), 143 | }; 144 | 145 | // 2. Convert to an ODE derivative 146 | let derivative = (sample - pred_original_sample) / sigma; 147 | self.derivatives.push(derivative); 148 | if self.derivatives.len() > self.config.order { 149 | // remove the first element 150 | self.derivatives.drain(0..1); 151 | } 152 | 153 | // 3. compute linear multistep coefficients 154 | let order = self.config.order.min(step_index + 1); 155 | let lms_coeffs: Vec<_> = 156 | (0..order).map(|o| self.get_lms_coefficient(order, step_index, o)).collect(); 157 | 158 | // 4.
compute previous sample based on the derivatives path 159 | // https://github.com/huggingface/diffusers/blob/769f0be8fb41daca9f3cbcffcfd0dbf01cc194b8/src/diffusers/schedulers/scheduling_lms_discrete.py#L243-L245 160 | let deriv_sum: Tensor = lms_coeffs 161 | .iter() 162 | .zip(self.derivatives.iter().rev()) 163 | .map(|(coeff, derivative)| *coeff * derivative) 164 | .sum(); 165 | 166 | sample + deriv_sum 167 | } 168 | 169 | pub fn init_noise_sigma(&self) -> f64 { 170 | self.init_noise_sigma 171 | } 172 | 173 | pub fn add_noise(&self, original_samples: &Tensor, noise: Tensor, timestep: f64) -> Tensor { 174 | let step_index = self.timesteps.iter().position(|&t| t == timestep).unwrap(); 175 | let sigma = self.sigmas[step_index]; 176 | 177 | // noisy samples 178 | original_samples + noise * sigma 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /src/schedulers/mod.rs: -------------------------------------------------------------------------------- 1 | //! # Noise schedulers 2 | //! 3 | //! Noise schedulers can be used to set the trade-off between 4 | //! inference speed and quality. 5 | 6 | use tch::{IndexOp, Kind, Tensor}; 7 | 8 | pub mod ddim; 9 | pub mod ddpm; 10 | pub mod dpmsolver_multistep; 11 | pub mod euler_ancestral_discrete; 12 | pub mod euler_discrete; 13 | pub mod heun_discrete; 14 | mod integrate; 15 | pub mod k_dpm_2_ancestral_discrete; 16 | pub mod k_dpm_2_discrete; 17 | pub mod lms_discrete; 18 | pub mod pndm; 19 | 20 | /// This represents how beta ranges from its minimum value to the maximum 21 | /// during training. 22 | #[derive(Debug, Clone, Copy)] 23 | pub enum BetaSchedule { 24 | /// Linear interpolation. 25 | Linear, 26 | /// Linear interpolation of the square root of beta. 27 | ScaledLinear, 28 | /// Glide cosine schedule 29 | SquaredcosCapV2, 30 | } 31 | 32 | #[derive(Debug, Clone, Copy)] 33 | pub enum PredictionType { 34 | Epsilon, 35 | VPrediction, 36 | Sample, 37 | } 38 | 39 | /// Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of 40 | /// `(1-beta)` over time from `t = [0,1]`. 41 | /// 42 | /// Contains a function `alpha_bar` that takes an argument `t` and transforms it to the cumulative product of `(1-beta)` 43 | /// up to that part of the diffusion process. 44 | pub(crate) fn betas_for_alpha_bar(num_diffusion_timesteps: usize, max_beta: f64) -> Tensor { 45 | let alpha_bar = |time_step: f64| { 46 | f64::cos((time_step + 0.008) / 1.008 * std::f64::consts::FRAC_PI_2).powi(2) 47 | }; 48 | let mut betas = Vec::with_capacity(num_diffusion_timesteps); 49 | for i in 0..num_diffusion_timesteps { 50 | let t1 = i as f64 / num_diffusion_timesteps as f64; 51 | let t2 = (i + 1) as f64 / num_diffusion_timesteps as f64; 52 | betas.push((1.0 - alpha_bar(t2) / alpha_bar(t1)).min(max_beta)); 53 | } 54 | Tensor::from_slice(&betas) 55 | } 56 | 57 | /// One-dimensional linear interpolation for monotonically increasing sample 58 | /// points, mimicking np.interp(). 59 | /// 60 | /// Based on https://github.com/pytorch/pytorch/issues/50334#issuecomment-1000917964 61 | pub fn interp(x: &Tensor, xp: Tensor, yp: Tensor) -> Tensor { 62 | assert_eq!(xp.size(), yp.size()); 63 | let sz = xp.size1().unwrap(); 64 | 65 | // (yp[1:] - yp[:-1]) / (xp[1:] - xp[:-1]) 66 | let m = (yp.i(1..) - yp.i(..sz - 1)) / (xp.i(1..)
- xp.i(..sz - 1)); 67 | 68 | // yp[:-1] - (m * xp[:-1]) 69 | let b = yp.i(..sz - 1) - (&m * xp.i(..sz - 1)); 70 | 71 | // torch.sum(torch.ge(x[:, None], xp[None, :]), 1) - 1 72 | let indices = x.unsqueeze(-1).ge_tensor(&xp.unsqueeze(0)); 73 | let indices = indices.sum_dim_intlist(1, false, Kind::Int64) - 1; 74 | // torch.clamp(indices, 0, len(m) - 1) 75 | let indices = indices.clamp(0, m.size1().unwrap() - 1); 76 | 77 | m.take(&indices) * x + b.take(&indices) 78 | } 79 | -------------------------------------------------------------------------------- /src/schedulers/pndm.rs: -------------------------------------------------------------------------------- 1 | use super::{betas_for_alpha_bar, BetaSchedule, PredictionType}; 2 | use tch::{kind, Kind, Tensor}; 3 | 4 | #[derive(Debug, Clone)] 5 | pub struct PNDMSchedulerConfig { 6 | /// The value of beta at the beginning of training. 7 | pub beta_start: f64, 8 | /// The value of beta at the end of training. 9 | pub beta_end: f64, 10 | /// How beta evolved during training. 11 | pub beta_schedule: BetaSchedule, 12 | /// each diffusion step uses the value of alphas product at that step and 13 | /// at the previous one. For the final step there is no previous alpha. 14 | /// When this option is `True` the previous alpha product is fixed to `1`, 15 | /// otherwise it uses the value of alpha at step 0. 16 | pub set_alpha_to_one: bool, 17 | /// prediction type of the scheduler function 18 | pub prediction_type: PredictionType, 19 | /// an offset added to the inference steps. 20 | pub steps_offset: usize, 21 | /// number of diffusion steps used to train the model. 22 | pub train_timesteps: usize, 23 | } 24 | 25 | impl Default for PNDMSchedulerConfig { 26 | fn default() -> Self { 27 | Self { 28 | beta_start: 0.00085, 29 | beta_end: 0.012, 30 | beta_schedule: BetaSchedule::ScaledLinear, 31 | set_alpha_to_one: false, 32 | prediction_type: PredictionType::Epsilon, 33 | steps_offset: 1, 34 | train_timesteps: 1000, 35 | } 36 | } 37 | } 38 | 39 | /// Pseudo numerical methods for diffusion models (PNDM) proposes using more advanced ODE 40 | /// integration techniques, namely the Runge-Kutta method and a linear multi-step method. 41 | pub struct PNDMScheduler { 42 | alphas_cumprod: Vec<f64>, 43 | final_alpha_cumprod: f64, 44 | step_ratio: usize, 45 | init_noise_sigma: f64, 46 | counter: usize, 47 | cur_sample: Option<Tensor>, 48 | ets: Vec<Tensor>, 49 | timesteps: Vec<usize>, 50 | pub config: PNDMSchedulerConfig, 51 | } 52 | 53 | impl PNDMScheduler { 54 | pub fn new(inference_steps: usize, config: PNDMSchedulerConfig) -> Self { 55 | let betas = match config.beta_schedule { 56 | BetaSchedule::ScaledLinear => Tensor::linspace( 57 | config.beta_start.sqrt(), 58 | config.beta_end.sqrt(), 59 | config.train_timesteps as i64, 60 | kind::FLOAT_CPU, 61 | ) 62 | .square(), 63 | BetaSchedule::Linear => Tensor::linspace( 64 | config.beta_start, 65 | config.beta_end, 66 | config.train_timesteps as i64, 67 | kind::FLOAT_CPU, 68 | ), 69 | 70 | BetaSchedule::SquaredcosCapV2 => betas_for_alpha_bar(config.train_timesteps, 0.999), 71 | }; 72 | 73 | // &betas to avoid moving it 74 | let alphas: Tensor = 1.
- betas; 75 | let alphas_cumprod = Vec::<f64>::try_from(alphas.cumprod(0, Kind::Double)).unwrap(); 76 | 77 | let final_alpha_cumprod = if config.set_alpha_to_one { 1.0 } else { alphas_cumprod[0] }; 78 | // creates integer timesteps by multiplying by ratio 79 | // casting to int to avoid issues when num_inference_step is power of 3 80 | let step_ratio = config.train_timesteps / inference_steps; 81 | let timesteps: Vec<usize> = 82 | (0..(inference_steps)).map(|s| s * step_ratio + config.steps_offset).collect(); 83 | 84 | let n_ts = timesteps.len(); 85 | // https://github.com/huggingface/diffusers/blob/8f581591598255eff72cce8858f365eace47481f/src/diffusers/schedulers/scheduling_pndm.py#L173 86 | let plms_timesteps = 87 | [&timesteps[..n_ts - 2], &[timesteps[n_ts - 2]], &timesteps[n_ts - 2..]] 88 | .concat() 89 | .into_iter() 90 | .rev() 91 | .collect(); 92 | 93 | Self { 94 | alphas_cumprod, 95 | final_alpha_cumprod, 96 | step_ratio, 97 | init_noise_sigma: 1.0, 98 | counter: 0, 99 | cur_sample: None, 100 | ets: vec![], 101 | timesteps: plms_timesteps, 102 | config, 103 | } 104 | } 105 | 106 | pub fn timesteps(&self) -> &[usize] { 107 | self.timesteps.as_slice() 108 | } 109 | 110 | /// Ensures interchangeability with schedulers that need to scale the denoising model input 111 | /// depending on the current timestep. 112 | pub fn scale_model_input(&self, sample: Tensor, _timestep: usize) -> Tensor { 113 | sample 114 | } 115 | 116 | pub fn step(&mut self, model_output: &Tensor, timestep: usize, sample: &Tensor) -> Tensor { 117 | self.step_plms(model_output, timestep, sample) 118 | } 119 | 120 | /// Step function propagating the sample with the linear multi-step method. 121 | /// This has one forward pass with multiple times to approximate the solution. 122 | fn step_plms(&mut self, model_output: &Tensor, mut timestep: usize, sample: &Tensor) -> Tensor { 123 | let mut prev_timestep = timestep as isize - self.step_ratio as isize; 124 | 125 | if self.counter != 1 { 126 | // make sure `ets` has at least size 3 before 127 | // taking a slice of the last 3 128 | if self.ets.len() > 3 { 129 | // self.ets = self.ets[-3:] 130 | self.ets.drain(0..self.ets.len() - 3); 131 | } 132 | self.ets.push(model_output.shallow_clone()); 133 | } else { 134 | prev_timestep = timestep as isize; 135 | timestep += self.step_ratio; 136 | } 137 | 138 | let (ets_last, n_ets) = (self.ets.last().unwrap(), self.ets.len()); 139 | let (mut model_output, mut sample) = (model_output.shallow_clone(), sample.shallow_clone()); 140 | 141 | if n_ets == 1 && self.counter == 0 { 142 | self.cur_sample = Some(sample.shallow_clone()); 143 | } else if n_ets == 1 && self.counter == 1 { 144 | sample = self.cur_sample.as_ref().unwrap().shallow_clone(); 145 | self.cur_sample = None; 146 | model_output = (model_output + ets_last) / 2.; 147 | } else if n_ets == 2 { 148 | model_output = (3. * ets_last - &self.ets[n_ets - 2]) / 2.; 149 | } else if n_ets == 3 { 150 | model_output = 151 | (23. * ets_last - 16. * &self.ets[n_ets - 2] + 5. * &self.ets[n_ets - 3]) / 12.; 152 | } else { 153 | model_output = (1. / 24.) 154 | * (55. * ets_last - 59. * &self.ets[n_ets - 2] + 37. * &self.ets[n_ets - 3] 155 | - 9.
* &self.ets[n_ets - 4]); 156 | } 157 | 158 | let prev_sample = self.get_prev_sample(sample, timestep, prev_timestep, model_output); 159 | self.counter += 1; 160 | 161 | prev_sample 162 | } 163 | 164 | fn get_prev_sample( 165 | &self, 166 | sample: Tensor, 167 | timestep: usize, 168 | prev_timestep: isize, 169 | model_output: Tensor, 170 | ) -> Tensor { 171 | // See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf 172 | // this function computes x_(t−δ) using the formula of (9) 173 | // Note that x_t needs to be added to both sides of the equation 174 | // 175 | // Notation (<variable name> -> <name in paper>) 176 | // alpha_prod_t -> α_t 177 | // alpha_prod_t_prev -> α_(t−δ) 178 | // beta_prod_t -> (1 - α_t) 179 | // beta_prod_t_prev -> (1 - α_(t−δ)) 180 | // sample -> x_t 181 | // model_output -> e_θ(x_t, t) 182 | // prev_sample -> x_(t−δ) 183 | let alpha_prod_t = self.alphas_cumprod[timestep]; 184 | let alpha_prod_t_prev = if prev_timestep >= 0 { 185 | self.alphas_cumprod[prev_timestep as usize] 186 | } else { 187 | self.final_alpha_cumprod 188 | }; 189 | 190 | let beta_prod_t = 1. - alpha_prod_t; 191 | let beta_prod_t_prev = 1. - alpha_prod_t_prev; 192 | 193 | let model_output = match self.config.prediction_type { 194 | PredictionType::VPrediction => { 195 | alpha_prod_t.sqrt() * model_output + beta_prod_t.sqrt() * &sample 196 | } 197 | PredictionType::Epsilon => model_output.shallow_clone(), 198 | _ => unimplemented!("Prediction type must be one of `epsilon` or `v_prediction`"), 199 | }; 200 | 201 | // corresponds to (α_(t−δ) - α_t) divided by 202 | // denominator of x_t in formula (9) and plus 1 203 | // Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqrt(α_t))) = 204 | // sqrt(α_(t−δ)) / sqrt(α_t) 205 | let sample_coeff = (alpha_prod_t_prev / alpha_prod_t).sqrt(); 206 | 207 | // corresponds to denominator of e_θ(x_t, t) in formula (9) 208 | let model_output_denom_coeff = alpha_prod_t * beta_prod_t_prev.sqrt() 209 | + (alpha_prod_t * beta_prod_t * alpha_prod_t_prev).sqrt(); 210 | 211 | // full formula (9) 212 | // prev sample 213 | sample_coeff * sample 214 | - (alpha_prod_t_prev - alpha_prod_t) * model_output / model_output_denom_coeff 215 | } 216 | 217 | pub fn add_noise(&self, original: &Tensor, noise: Tensor, timestep: usize) -> Tensor { 218 | let timestep = if timestep >= self.alphas_cumprod.len() { timestep - 1 } else { timestep }; 219 | let sqrt_alpha_prod = self.alphas_cumprod[timestep].sqrt(); 220 | let sqrt_one_minus_alpha_prod = (1.0 - self.alphas_cumprod[timestep]).sqrt(); 221 | // noisy samples 222 | sqrt_alpha_prod * original + sqrt_one_minus_alpha_prod * noise 223 | } 224 | 225 | pub fn init_noise_sigma(&self) -> f64 { 226 | self.init_noise_sigma 227 | } 228 | } 229 | -------------------------------------------------------------------------------- /src/transformers/mod.rs: -------------------------------------------------------------------------------- 1 | //! # Transformers 2 | //! 3 | //! The transformers module contains some basic implementation 4 | //! of transformers based models used to process the user prompt 5 | //! and generate the related embeddings. It also includes some 6 | //! simple tokenization. 7 | 8 | pub mod clip; 9 | -------------------------------------------------------------------------------- /src/utils.rs: -------------------------------------------------------------------------------- 1 | // A simple wrapper around File::open adding details about the 2 | // problematic file.
3 | use std::path::Path; 4 | use tch::Device; 5 | 6 | pub(crate) fn file_open<P: AsRef<Path>>(path: P) -> anyhow::Result<std::fs::File> { 7 | std::fs::File::open(path.as_ref()).map_err(|e| { 8 | let context = format!("error opening {:?}", path.as_ref().to_string_lossy()); 9 | anyhow::Error::new(e).context(context) 10 | }) 11 | } 12 | 13 | pub struct DeviceSetup { 14 | accelerator_device: Device, 15 | cpu: Vec<String>, 16 | } 17 | 18 | impl DeviceSetup { 19 | pub fn new(cpu: Vec<String>) -> Self { 20 | let accelerator_device = 21 | if tch::utils::has_mps() { Device::Mps } else { Device::cuda_if_available() }; 22 | Self { accelerator_device, cpu } 23 | } 24 | 25 | pub fn get(&self, name: &str) -> Device { 26 | if self.cpu.iter().any(|c| c == "all" || c == name) { 27 | Device::Cpu 28 | } else { 29 | self.accelerator_device 30 | } 31 | } 32 | } 33 | --------------------------------------------------------------------------------
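The schedulers listed above all expose the same small surface: new, timesteps, init_noise_sigma, scale_model_input and step. The sketch below shows how that surface is typically driven in a denoising loop, using LMSDiscreteScheduler from src/schedulers/lms_discrete.rs. It is a minimal illustration only: the unet closure stands in for the real noise-prediction model (which in the Stable Diffusion pipeline also receives text embeddings), and the latent shape and the diffusers::schedulers module path are assumptions rather than something taken from the listing itself.

    use diffusers::schedulers::lms_discrete::{LMSDiscreteScheduler, LMSDiscreteSchedulerConfig};
    use tch::{kind, Tensor};

    // `unet` is a hypothetical stand-in: it receives the scaled latents and the
    // current timestep and returns the predicted noise (epsilon).
    fn denoise(unet: impl Fn(&Tensor, f64) -> Tensor, n_steps: usize) -> Tensor {
        let mut scheduler = LMSDiscreteScheduler::new(n_steps, LMSDiscreteSchedulerConfig::default());
        // Start from pure noise scaled by the initial sigma (assumed 1x4x64x64 latent shape).
        let mut latents =
            Tensor::randn(&[1, 4, 64, 64], kind::FLOAT_CPU) * scheduler.init_noise_sigma();
        let timesteps = scheduler.timesteps().to_vec();
        for &timestep in timesteps.iter() {
            // K-LMS rescales the model input by 1 / (sigma^2 + 1).sqrt() at every step.
            let input = scheduler.scale_model_input(latents.shallow_clone(), timestep);
            let noise_pred = unet(&input, timestep);
            // `step` turns the noise prediction into the sample at the next (lower) sigma.
            latents = scheduler.step(&noise_pred, timestep, &latents);
        }
        latents
    }

The same loop shape works for the other discrete schedulers in this directory; only the timestep element type differs (f64 for the sigma-based schedulers, usize for PNDM).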
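Along the same lines, here is a hedged sketch of how DeviceSetup from src/utils.rs might be used to pin selected components to the CPU while the rest runs on CUDA or MPS. The component names ("clip", "vae", "unet") are illustrative labels chosen by the caller, and the diffusers::utils path assumes the module is publicly exposed from lib.rs.

    use diffusers::utils::DeviceSetup;
    use tch::Device;

    fn main() {
        // Force the VAE onto the CPU; everything else goes to the accelerator when one is found.
        let devices = DeviceSetup::new(vec!["vae".to_string()]);
        let clip_device: Device = devices.get("clip"); // Cuda(0) / Mps if available, otherwise Cpu
        let vae_device: Device = devices.get("vae"); // always Cpu because of the flag above
        println!("clip on {clip_device:?}, vae on {vae_device:?}");
        // Passing "all" would place every component on the CPU:
        // let devices = DeviceSetup::new(vec!["all".to_string()]);
    }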