├── .cargo └── config.toml ├── .github └── FUNDING.yml ├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── LICENSE ├── Makefile ├── README.md ├── cake-cli ├── Cargo.toml └── src │ └── main.rs ├── cake-core ├── Cargo.toml └── src │ ├── cake │ ├── api │ │ ├── image.rs │ │ ├── mod.rs │ │ └── text.rs │ ├── client.rs │ ├── master.rs │ ├── mod.rs │ ├── proto │ │ ├── message.rs │ │ └── mod.rs │ ├── topology.rs │ └── worker.rs │ ├── lib.rs │ ├── models │ ├── chat.rs │ ├── llama3 │ │ ├── attention.rs │ │ ├── cache.rs │ │ ├── config.rs │ │ ├── history.rs │ │ ├── llama.rs │ │ ├── mlp.rs │ │ ├── mod.rs │ │ └── transformer.rs │ ├── mod.rs │ └── sd │ │ ├── clip.rs │ │ ├── mod.rs │ │ ├── safe_scheduler.rs │ │ ├── sd.rs │ │ ├── sd_shardable.rs │ │ ├── unet.rs │ │ ├── util.rs │ │ └── vae.rs │ └── utils │ └── mod.rs ├── cake-ios-worker-app ├── Cake Worker.entitlements ├── Cake Worker.xcodeproj │ ├── project.pbxproj │ ├── project.xcworkspace │ │ ├── contents.xcworkspacedata │ │ ├── xcshareddata │ │ │ └── IDEWorkspaceChecks.plist │ │ └── xcuserdata │ │ │ └── evilsocket.xcuserdatad │ │ │ └── UserInterfaceState.xcuserstate │ └── xcuserdata │ │ └── evilsocket.xcuserdatad │ │ └── xcschemes │ │ └── xcschememanagement.plist └── Cake Worker │ ├── Assets.xcassets │ ├── AccentColor.colorset │ │ └── Contents.json │ ├── AppIcon.appiconset │ │ └── Contents.json │ └── Contents.json │ ├── Cake_WorkerApp.swift │ ├── ContentView.swift │ └── Preview Content │ └── Preview Assets.xcassets │ └── Contents.json ├── cake-ios ├── Cargo.toml └── src │ ├── bin │ └── uniffi-bindgen.rs │ └── lib.rs └── cake-split-model ├── Cargo.toml └── src └── main.rs /.cargo/config.toml: -------------------------------------------------------------------------------- 1 | 2 | # https://github.com/huggingface/candle/issues/1625 3 | [target.'cfg(any(target_arch = "arm", target_arch = "aarch64"))'] 4 | rustflags = ["-C", "target-feature=+fp16"] 5 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: evilsocket 2 | patreon: evilsocket -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /cake-data 3 | /cake-ios/bindings 4 | /cake-ios-worker-app/Cake.xcframework 5 | /cake-ios-worker-app/Cake\ Worker/Cake.swift 6 | 7 | .DS_Store 8 | 9 | # Xcode 10 | # 11 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 12 | 13 | ## User settings 14 | xcuserdata/ 15 | 16 | ## Obj-C/Swift specific 17 | *.hmap 18 | 19 | ## App packaging 20 | *.ipa 21 | *.dSYM.zip 22 | *.dSYM 23 | 24 | ## Playgrounds 25 | timeline.xctimeline 26 | playground.xcworkspace 27 | 28 | # Swift Package Manager 29 | # 30 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. 31 | # Packages/ 32 | # Package.pins 33 | # Package.resolved 34 | # *.xcodeproj 35 | # 36 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata 37 | # hence it is not needed unless you have added a package configuration file to your project 38 | # .swiftpm 39 | 40 | .build/ 41 | 42 | # CocoaPods 43 | # 44 | # We recommend against adding the Pods directory to your .gitignore. 
However 45 | # you should judge for yourself, the pros and cons are mentioned at: 46 | # https://guides.cocoapods.org/using/using-cocoapods.html#should-i-check-the-pods-directory-into-source-control 47 | # 48 | # Pods/ 49 | # 50 | # Add this line if you want to avoid checking in source code from the Xcode workspace 51 | # *.xcworkspace 52 | 53 | # Carthage 54 | # 55 | # Add this line if you want to avoid checking in source code from Carthage dependencies. 56 | # Carthage/Checkouts 57 | 58 | Carthage/Build/ 59 | 60 | # fastlane 61 | # 62 | # It is recommended to not store the screenshots in the git repo. 63 | # Instead, use fastlane to re-generate the screenshots whenever they are needed. 64 | # For more information about the recommended setup visit: 65 | # https://docs.fastlane.tools/best-practices/source-control/#source-control 66 | 67 | fastlane/report.xml 68 | fastlane/Preview.html 69 | fastlane/screenshots/**/*.png 70 | fastlane/test_output 71 | 72 | .idea 73 | /models 74 | /*.sh 75 | .vscode/ 76 | /images 77 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [workspace] 2 | resolver = "2" 3 | members = ["cake-core", "cake-cli", "cake-ios", "cake-split-model"] 4 | 5 | [workspace.package] 6 | version = "0.1.0" 7 | edition = "2021" 8 | description = "Distributed LLM inference for mobile, desktop and server." 9 | repository = "https://github.com/evilsocket/cake" 10 | keywords = ["blas", "tensor", "machine-learning"] 11 | authors = ["Simone Margaritelli "] 12 | license = "GPL-3.0" 13 | readme = "README.md" 14 | categories = ["science"] 15 | 16 | [profile.release] 17 | lto = true # Enable link-time optimization 18 | codegen-units = 1 # Reduce number of codegen units to increase optimizations 19 | panic = 'abort' # Abort on panic 20 | strip = true # Strip symbols from binary* 21 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | clean: 2 | cargo clean 3 | 4 | build: 5 | cargo build 6 | 7 | test: 8 | cargo test 9 | 10 | lint: 11 | cargo clippy --all-targets --all-features -- -D warnings 12 | 13 | build_release: 14 | cargo build --release 15 | 16 | ios_bindings: build 17 | cargo run --bin uniffi-bindgen generate --library ./target/debug/libcake.dylib --language swift --out-dir ./cake-ios/bindings 18 | 19 | ios: ios_bindings 20 | cargo build --release --target=aarch64-apple-ios 21 | mv ./cake-ios/bindings/cakeFFI.modulemap ./cake-ios/bindings/module.modulemap 22 | rm -rf ./cake-ios-worker-app/Cake\ Worker/Cake.swift 23 | mv ./cake-ios/bindings/cake.swift ./cake-ios-worker-app/Cake\ Worker/Cake.swift 24 | rm -rf "./cake-ios-worker-app/Cake.xcframework" 25 | xcodebuild -create-xcframework \ 26 | -library ./target/aarch64-apple-ios/release/libcake.a -headers ./cake-ios/bindings \ 27 | -output "./cake-ios-worker-app/Cake.xcframework" > /dev/null 28 | rm -rf ./cake-ios/bindings 29 | 30 | sync_bahamut: 31 | @echo "@ bahamut sync && build ..." 32 | @rsync -rvzc --exclude=cake-data --exclude=.git --exclude=target . bahamut.local:/home/evilsocket/cake 33 | @rsync -rvzc cake-data/8b-test/bahamut-node bahamut.local:/home/evilsocket/cake-data 34 | 35 | sync_blade: 36 | @echo "@ blade sync && build ..." 37 | @rsync -rvzc --exclude=cake-data --exclude=.git --exclude=target . 
blade.local:/home/evilsocket/cake 38 | @rsync -rvzc cake-data/8b-test/blade-node blade.local:/home/evilsocket/cake-data 39 | 40 | sync: sync_bahamut sync_blade -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | Join the project community on our server!
9 | 10 | 11 | `Cake` is a Rust framework for distributed inference of large models like [LLama3](https://x.com/evilsocket/status/1812110504531259900) and [Stable Diffusion](https://x.com/crynuxai/status/1822085290543960216) based on [Candle](https://github.com/huggingface/candle). The goal of the project is to run big (70B+) models by repurposing consumer hardware into a heterogeneous cluster of iOS, Android, macOS, Linux and Windows devices, effectively leveraging [planned obsolescence](https://en.wikipedia.org/wiki/Planned_obsolescence) as a tool to make AI more accessible and democratic. 12 | 13 |

14 | 15 | ⚠ This is experimental code that's being actively developed and changed very quickly, expect bugs ⚠ 16 | 17 |

18 | 19 | The idea is to shard the transformer blocks across multiple devices in order to run inference on models that wouldn't normally fit in the GPU memory of a single device. Inferences over contiguous transformer blocks on the same worker are batched in order to minimize latency due to data transfer. 20 | 21 | ## Support 22 | 23 | | OS | Architectures | Acceleration | Status | 24 | |----------------------------------|------------------|------------------|------------------| 25 | | GNU/Linux | arm, arm64, x86_64 | - | ✅ | 26 | | GNU/Linux | arm, arm64, x86_64 | CUDA | ✅ | 27 | | GNU/Linux | arm, arm64, x86_64 | BLAS | ✅ | 28 | | Windows | x86_64 | BLAS | [untested](https://github.com/evilsocket/cake/issues/7) | 29 | | Windows | x86_64 | CUDA | ✅ | 30 | | macOS | x86_64 | - | ✅ | 31 | | macOS | aarch64 | - | ✅ | 32 | | macOS | aarch64 | Metal | ✅ | 33 | | Android | arm, arm64, x86_64 | - | ✅ | 34 | | Android | arm, arm64, x86_64 | CUDA | [untested](https://docs.nvidia.com/gameworks/content/technologies/mobile/cuda_android_main.htm) | 35 | | iOS / iPadOS | aarch64 | - | ✅ | 36 | | iOS / iPadOS | aarch64 | Metal | 🛠️ [90% done, WIP](https://github.com/huggingface/candle/issues/2322) | 37 | | Web | - | WebGPU | [in theory possible, not done](https://onnxruntime.ai/docs/tutorials/web/ep-webgpu.html) | 38 | 39 | CUDA >= 12.2 is required for CUDA accelerated systems. 40 | 41 | ## Compile 42 | 43 | With [Rust installed](https://www.rust-lang.org/tools/install), you can build the core library and the CLI utilities with different accelerations. 44 | 45 | Without acceleration (will use CPU): 46 | 47 | ```sh 48 | cargo build --release 49 | ``` 50 | 51 | With Metal acceleration for Apple Silicon: 52 | 53 | ```sh 54 | cargo build --release --features metal 55 | ``` 56 | 57 | With CUDA acceleration: 58 | 59 | ```sh 60 | cargo build --release --features cuda 61 | ``` 62 | 63 | To generate the iOS bindings for the app that can then be [compiled and deployed via Xcode](https://github.com/evilsocket/cake/tree/main/cake-ios-worker-app): 64 | 65 | ```sh 66 | make ios 67 | ``` 68 | 69 | ## Using 70 | 71 | Run a worker node: 72 | 73 | ```sh 74 | cake-cli --model /path/to/Meta-Llama-3-8B \ # model path, read below on how to optimize model size for workers 75 | --mode worker \ # run as worker 76 | --name worker0 \ # worker name in topology file 77 | --topology topology.yml \ # topology 78 | --address 0.0.0.0:10128 # bind address 79 | ``` 80 | 81 | Run a master node with an OpenAI compatible REST API: 82 | 83 | ```sh 84 | cake-cli --model /path/to/Meta-Llama-3-8B \ # model path 85 | --api 0.0.0.0:8080 \ # API bind address 86 | --topology topology.yml # topology file 87 | ``` 88 | 89 | You can also omit the topology file to load the entire model in a single instance of cake: 90 | 91 | ```sh 92 | cake-cli --model /path/to/Meta-Llama-3-8B \ # model path 93 | --api 0.0.0.0:8080 # API bind address 94 | ``` 95 | 96 | ### Topology 97 | 98 | The `topology.yml` determines which layers are served by which worker (you can find a list of all the layers of a model in its [tensor index file](https://huggingface.co/meta-llama/Meta-Llama-3-70B/blob/main/model.safetensors.index.json)): 99 | 100 | ```yaml 101 | linux_server_1: 102 | host: 'linux_server.host:10128' 103 | description: 'NVIDIA Titan X Pascal (12GB)' 104 | layers: 105 | - 'model.layers.0-5' 106 | 107 | linux_server_2: 108 | host: 'linux_server2.host:10128' 109 | description: 'NVIDIA GeForce 3080 (10GB)' 110 | layers: 111 | - 'model.layers.6-16' 112 | 113 |
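# note: layer ranges are inclusive, so 'model.layers.0-5' expands to model.layers.0 through model.layers.5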
iphone: 114 | host: 'iphone.host:10128' 115 | description: 'iPhone 15 Pro Max' 116 | layers: 117 | - 'model.layers.17' 118 | 119 | ipad: 120 | host: 'ipad.host:10128' 121 | description: 'iPad' 122 | layers: 123 | - 'model.layers.18-19' 124 | 125 | macbook: 126 | host: 'macbook.host:10128' 127 | description: 'M1 Max' 128 | layers: 129 | - 'model.layers.20-31' 130 | ``` 131 | 132 | You can now interact with the cluster: 133 | 134 | ```sh 135 | curl http://master-ip:8080/api/v1/chat/completions \ 136 | -H "Content-Type: application/json" \ 137 | -d '{ 138 | "messages": [ 139 | { 140 | "role": "system", 141 | "content": "You are a helpful AI assistant." 142 | }, 143 | { 144 | "role": "user", 145 | "content": "Why is the sky blue?" 146 | } 147 | ] 148 | }' 149 | ``` 150 | 151 | ### Splitting the Model 152 | 153 | As a memory and disk space optimization, you might want to give the worker only the data it actually needs from the model instead of the whole folder, in which case you can use the `cake-split-model` utility. For instance, to generate a smaller version of the llama3 safetensors, you can: 154 | 155 | ```sh 156 | cake-split-model --model-path path/to/Meta-Llama-3-8B \ # source model to split 157 | --topology path/to/topology.yml \ # topology file 158 | --output output-folder-name # output folder where all the workers data bundles will be saved 159 | ``` 160 | 161 | This will create a smaller folder containing only the tensors of the required layers and the topology file for the specific worker. Remember to also copy the other model contents (config.json, tokenizer.json, etc.) into the worker bundle before deploying it. 162 | 163 | ### Stable Diffusion Image Generation 164 | 165 | Define the model parts inside `topology.yml`: 166 | 167 | ```yaml 168 | wsl2_on_windows: 169 | host: 192.168.1.2:10128 170 | description: NVIDIA RTX 4090 24GB 171 | layers: 172 | - unet 173 | 174 | macbook: 175 | host: 192.168.1.3:10128 176 | description: Macbook M2 177 | layers: 178 | - clip 179 | - vae 180 | ``` 181 | 182 | Run a worker node: 183 | 184 | ```sh 185 | cake-cli --model /path/to/hf/cache \ # the cache dir for Hugging Face models 186 | --mode worker \ # run as worker 187 | --name wsl2_on_windows \ # worker name in topology file 188 | --model-type image-model \ # use image-model for SD, text-model or skip for LLM 189 | --topology topology.yml \ # topology 190 | --address 0.0.0.0:10128 # bind address 191 | ``` 192 | 193 | The model can be switched among SD1.5, SD2.1, SDXL and SDXL Turbo by specifying [more command line arguments](./cake-core/src/lib.rs). 194 | 195 | The model files will be downloaded from Hugging Face automatically if not found in the local cache directory. 196 | 197 | Run a master node with REST API: 198 | 199 | ```sh 200 | cake-cli --model /path/to/hf/cache \ # the cache dir for Hugging Face models 201 | --api 0.0.0.0:8080 \ # API bind address 202 | --model-type image-model \ # use image-model for SD, text-model or skip for LLM 203 | --topology topology.yml # topology file 204 | ``` 205 | 206 | Generate images using the cluster: 207 | 208 | ```sh 209 | curl http://master-ip:8080/api/v1/image \ 210 | -H "Content-Type: application/json" \ 211 | -d '{ 212 | "image_args": { 213 | "sd-image-prompt": "An old man sitting on the chair at seaside", 214 | "sd-num-samples": 1, 215 | "sd-image-seed": 2439383 216 | } 217 | }' 218 | ``` 219 | 220 | More control arguments can be found [in the code](./cake-core/src/lib.rs). 221 | 222 | 223 | ## License 224 | 225 | Released under the GPL 3 license.
To see the licenses of the project dependencies, install cargo-license with `cargo install cargo-license` and then run `cargo license`. 226 | -------------------------------------------------------------------------------- /cake-cli/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "cake-cli" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | anyhow = "1.0.86" 8 | cake-core = { path = "../cake-core" } 9 | clap = "4.5.8" 10 | env_logger = "0.11.3" 11 | tokio = { version = "1.38.0", features = ["full"] } 12 | -------------------------------------------------------------------------------- /cake-cli/src/main.rs: -------------------------------------------------------------------------------- 1 | //! This is the cake command line utility. 2 | 3 | use cake_core::{ 4 | cake::{Context, Master, Mode, Worker}, 5 | Args, ModelType, 6 | }; 7 | 8 | use anyhow::Result; 9 | use clap::Parser; 10 | 11 | #[tokio::main] 12 | async fn main() -> Result<()> { 13 | // parse command line 14 | let args = Args::parse(); 15 | 16 | // setup logging 17 | if std::env::var_os("RUST_LOG").is_none() { 18 | // set `RUST_LOG=debug` to see debug logs 19 | std::env::set_var("RUST_LOG", "info,tokenizers=error,actix_server=warn"); 20 | } 21 | 22 | env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) 23 | .format_module_path(false) 24 | .format_target(false) 25 | .init(); 26 | 27 | // setup context 28 | let mut ctx = Context::from_args(args)?; 29 | 30 | // run either in master or worker mode depending on command line 31 | let ret = match ctx.args.mode { 32 | Mode::Master => { 33 | Master::<cake_core::models::llama3::LLama, cake_core::models::sd::SD>::new(ctx) 34 | .await? 35 | .run() 36 | .await 37 | } 38 | Mode::Worker => match ctx.args.model_type { 39 | ModelType::TextModel => { 40 | Worker::<cake_core::models::llama3::LLama>::new(&mut ctx) 41 | .await? 42 | .run() 43 | .await 44 | } 45 | ModelType::ImageModel => { 46 | Worker::<cake_core::models::sd::SD>::new(&mut ctx) 47 | .await?
48 | .run() 49 | .await 50 | } 51 | }, 52 | }; 53 | 54 | if ret.is_err() { 55 | // we were possibly streaming text, add a newline before reporting the error 56 | println!(); 57 | return ret; 58 | } 59 | 60 | Ok(()) 61 | } 62 | -------------------------------------------------------------------------------- /cake-core/Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "cake-core" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | anyhow = "1.0.86" 8 | async-trait = "0.1.80" 9 | 10 | clap = { version = "4.5.8", features = ["derive"] } 11 | human_bytes = "0.4.3" 12 | lazy_static = "1.5.0" 13 | log = "0.4.22" 14 | memmap2 = "0.9.4" 15 | memory-stats = "1.2.0" 16 | regex = "1.10.5" 17 | safetensors = "0.4.3" 18 | serde = { version = "1.0.203", features = ["derive"] } 19 | serde_json = "1.0.120" 20 | serde_yaml = "0.9.34" 21 | speedy = "0.8.7" 22 | tokenizers = { version = "0.19.1", features = ["onig"] } 23 | tokio = { version = "1.38.0", features = ["full"] } 24 | yoke = { version = "0.7.4", features = ["derive"] } 25 | 26 | actix-web = { version = "4.8.0", optional = true } 27 | uuid = { version = "1.10.0", optional = true, features = ["v4"] } 28 | 29 | candle-core = { version = "0.7.2" } 30 | candle-nn = { version = "0.7.2" } 31 | candle-transformers = { version = "0.7.2" } 32 | image = "0.25.2" 33 | hf-hub = "0.3.2" 34 | tracing-chrome = "0.7.2" 35 | tracing-subscriber = "0.3.18" 36 | base64 = "0.22.1" 37 | 38 | [features] 39 | default = ["master"] 40 | 41 | metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal"] 42 | cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda"] 43 | 44 | master = ["dep:actix-web", "dep:uuid"] 45 | -------------------------------------------------------------------------------- /cake-core/src/cake/api/image.rs: -------------------------------------------------------------------------------- 1 | use crate::cake::Master; 2 | use crate::models::ImageGenerator; 3 | use crate::models::TextGenerator; 4 | use crate::ImageGenerationArgs; 5 | use actix_web::{web, HttpRequest, HttpResponse, Responder}; 6 | use base64::engine::general_purpose; 7 | use base64::Engine; 8 | use image::{DynamicImage, ImageFormat}; 9 | use serde::{Deserialize, Serialize}; 10 | use std::io::Cursor; 11 | use std::sync::Arc; 12 | use std::sync::Mutex; 13 | use tokio::sync::RwLock; 14 | 15 | #[derive(Deserialize)] 16 | pub struct ImageRequest { 17 | pub image_args: ImageGenerationArgs, 18 | } 19 | 20 | #[derive(Serialize)] 21 | struct ImageResponse { 22 | pub images: Vec<String>, 23 | } 24 | 25 | pub async fn generate_image<TG, IG>( 26 | state: web::Data<Arc<RwLock<Master<TG, IG>>>>, 27 | req: HttpRequest, 28 | image_request: web::Json<ImageRequest>, 29 | ) -> impl Responder 30 | where 31 | TG: TextGenerator + Send + Sync + 'static, 32 | IG: ImageGenerator + Send + Sync + 'static, 33 | { 34 | let client = req.peer_addr().unwrap(); 35 | 36 | log::info!("starting image generation for {} ...", &client); 37 | 38 | let mut master = state.write().await; 39 | 40 | let result_images = Arc::new(Mutex::new(Vec::new())); 41 | let result_images_cloned = Arc::clone(&result_images); 42 | 43 | master 44 | .generate_image(image_request.image_args.clone(), move |images| { 45 | let mut base64_images: Vec<String> = images 46 | .iter() 47 | .map(|image| { 48 | let dynamic_image = DynamicImage::ImageRgb8(image.clone()); 49 | let mut png_bytes = Vec::new(); 50 | let mut cursor = Cursor::new(&mut png_bytes); 51 | dynamic_image 52 | .write_to(&mut cursor,
ImageFormat::Png) 53 | .unwrap(); 54 | general_purpose::STANDARD.encode(png_bytes) 55 | }) 56 | .collect(); 57 | 58 | let mut locked_result_images = 59 | result_images_cloned.lock().expect("Error acquiring lock"); 60 | locked_result_images.append(&mut base64_images); 61 | }) 62 | .await 63 | .expect("Error generating images"); 64 | 65 | let locked_result_images = result_images.lock().expect("Error acquiring lock"); 66 | let response = ImageResponse { 67 | images: locked_result_images.to_vec(), 68 | }; 69 | 70 | HttpResponse::Ok().json(response) 71 | } 72 | -------------------------------------------------------------------------------- /cake-core/src/cake/api/mod.rs: -------------------------------------------------------------------------------- 1 | mod image; 2 | mod text; 3 | 4 | use std::sync::Arc; 5 | 6 | use actix_web::web; 7 | use actix_web::App; 8 | use actix_web::HttpResponse; 9 | use actix_web::HttpServer; 10 | use tokio::sync::RwLock; 11 | 12 | use crate::models::{ImageGenerator, TextGenerator}; 13 | 14 | use image::*; 15 | use text::*; 16 | 17 | use super::Master; 18 | 19 | async fn not_found() -> actix_web::Result<HttpResponse> { 20 | Ok(HttpResponse::NotFound().body("nope")) 21 | } 22 | 23 | pub(crate) async fn start<TG, IG>(master: Master<TG, IG>) -> anyhow::Result<()> 24 | where 25 | TG: TextGenerator + Send + Sync + 'static, 26 | IG: ImageGenerator + Send + Sync + 'static, 27 | { 28 | let address = master.ctx.args.api.as_ref().unwrap().to_string(); 29 | 30 | log::info!("starting api on http://{} ...", &address); 31 | 32 | let state = Arc::new(RwLock::new(master)); 33 | 34 | HttpServer::new( 35 | move || { 36 | App::new() 37 | .app_data(web::Data::new(state.clone())) 38 | .route( 39 | "/api/v1/chat/completions", 40 | web::post().to(generate_text::<TG, IG>), 41 | ) 42 | .route("/api/v1/image", web::post().to(generate_image::<TG, IG>)) 43 | .default_service(web::route().to(not_found)) 44 | }, //.wrap(actix_web::middleware::Logger::default())) 45 | ) 46 | .bind(&address) 47 | .map_err(|e| anyhow!(e))?
48 | .run() 49 | .await 50 | .map_err(|e| anyhow!(e)) 51 | } 52 | -------------------------------------------------------------------------------- /cake-core/src/cake/api/text.rs: -------------------------------------------------------------------------------- 1 | use crate::cake::Master; 2 | use crate::models::chat::Message; 3 | use crate::models::{ImageGenerator, TextGenerator}; 4 | use actix_web::{web, HttpRequest, HttpResponse, Responder}; 5 | use serde::{Deserialize, Serialize}; 6 | use std::io::Write; 7 | use std::sync::Arc; 8 | use std::time::{SystemTime, UNIX_EPOCH}; 9 | use tokio::sync::RwLock; 10 | 11 | #[derive(Deserialize)] 12 | pub struct ChatRequest { 13 | pub messages: Vec<Message>, 14 | } 15 | 16 | #[derive(Serialize)] 17 | struct Choice { 18 | pub index: usize, 19 | pub message: Message, 20 | } 21 | 22 | #[derive(Serialize)] 23 | struct ChatResponse { 24 | pub id: String, 25 | pub object: String, 26 | pub created: u64, 27 | pub model: String, 28 | pub choices: Vec<Choice>, 29 | } 30 | 31 | impl ChatResponse { 32 | pub fn from_assistant_response(model: String, message: String) -> Self { 33 | let id = uuid::Uuid::new_v4().to_string(); 34 | let object = String::from("chat.completion"); 35 | let created = SystemTime::now() 36 | .duration_since(UNIX_EPOCH) 37 | .unwrap() 38 | .as_secs(); 39 | let choices = vec![Choice { 40 | index: 0, 41 | message: Message::assistant(message), 42 | }]; 43 | 44 | Self { 45 | id, 46 | object, 47 | created, 48 | model, 49 | choices, 50 | } 51 | } 52 | } 53 | 54 | pub async fn generate_text<TG, IG>( 55 | state: web::Data<Arc<RwLock<Master<TG, IG>>>>, 56 | req: HttpRequest, 57 | messages: web::Json<ChatRequest>, 58 | ) -> impl Responder 59 | where 60 | TG: TextGenerator + Send + Sync + 'static, 61 | IG: ImageGenerator + Send + Sync + 'static, 62 | { 63 | let client = req.peer_addr().unwrap(); 64 | 65 | log::info!("starting chat for {} ...", &client); 66 | 67 | let mut master = state.write().await; 68 | 69 | master.reset().unwrap(); 70 | 71 | let llm_model = master.llm_model.as_mut().expect("LLM model not found"); 72 | 73 | for message in messages.0.messages { 74 | llm_model.add_message(message).unwrap(); 75 | } 76 | 77 | let mut resp = String::new(); 78 | 79 | // just run one generation to stdout 80 | master 81 | .generate_text(|data| { 82 | resp += data; 83 | if data.is_empty() { 84 | println!(); 85 | } else { 86 | print!("{data}") 87 | } 88 | std::io::stdout().flush().unwrap(); 89 | }) 90 | .await 91 | .unwrap(); 92 | 93 | let response = ChatResponse::from_assistant_response(TG::MODEL_NAME.to_string(), resp); 94 | master.goodbye().await.unwrap(); 95 | 96 | HttpResponse::Ok().json(response) 97 | } 98 | -------------------------------------------------------------------------------- /cake-core/src/cake/client.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_trait::async_trait; 3 | use candle_core::{Device, Tensor}; 4 | use tokio::net::TcpStream; 5 | 6 | use super::{Context, Message, WorkerInfo}; 7 | 8 | /// A client object used by the master to connect and orchestrate the workers. 9 | /// From the Cake perspective, each worker is a server and the master uses 10 | /// multiple Client instances to connect to them. 11 | #[derive(Debug)] 12 | pub struct Client { 13 | device: Device, 14 | address: String, 15 | layer_name: String, 16 | stream: TcpStream, 17 | info: WorkerInfo, 18 | } 19 | 20 | impl Client { 21 | /// Connects to the given worker address. 22 | /// NOTE: device and layer_name here are only passed for std::fmt::Display.
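/// A minimal usage sketch (the address and layer name below are illustrative):
/// `let client = Client::new(device.clone(), "192.168.1.2:10128", "model.layers.0").await?;`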
23 | pub async fn new(device: Device, address: &str, layer_name: &str) -> Result<Self> { 24 | let address = address.to_string(); 25 | let layer_name = layer_name.to_string(); 26 | let stream = TcpStream::connect(&address) 27 | .await 28 | .map_err(|e| anyhow!("can't connect to {address}: {e}"))?; 29 | let worker_info = WorkerInfo::default(); 30 | 31 | let mut client = Self { 32 | address, 33 | device, 34 | stream, 35 | layer_name, 36 | info: worker_info, 37 | }; 38 | 39 | let resp = client.request(Message::Hello).await?; 40 | client.info = if let Message::WorkerInfo(info) = resp { 41 | info 42 | } else { 43 | return Err(anyhow!("unexpected worker info message: {:?}", &resp)); 44 | }; 45 | 46 | Ok(client) 47 | } 48 | 49 | /// Send a Message to the worker and return a response. 50 | async fn request(&mut self, req: Message) -> Result<Message> { 51 | req.to_writer(&mut self.stream) 52 | .await 53 | .map_err(|e| anyhow!("error sending message {:?}: {}", req, e))?; 54 | 55 | let (_, msg) = super::Message::from_reader(&mut self.stream) 56 | .await 57 | .map_err(|e| anyhow!("error receiving response for {:?}: {}", req, e))?; 58 | Ok(msg) 59 | } 60 | 61 | async fn forward_request(&mut self, req: Message) -> Result<Tensor> { 62 | let resp = self.request(req).await?; 63 | match resp { 64 | Message::Tensor(raw) => Ok(raw.to_tensor(&self.device)?), 65 | _ => Err(anyhow!("unexpected response {:?}", &resp)), 66 | } 67 | } 68 | } 69 | 70 | impl std::fmt::Display for Client { 71 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 72 | write!( 73 | f, 74 | "{}@{} [{}<{}> {}-{} latency={}ms]", 75 | &self.layer_name, 76 | &self.address, 77 | &self.info.device, 78 | &self.info.device_idx, 79 | &self.info.os, 80 | &self.info.arch, 81 | self.info.latency 82 | ) 83 | } 84 | } 85 | 86 | #[async_trait] 87 | impl super::Forwarder for Client { 88 | fn load(_: String, _: &Context) -> Result<Box<Self>> { 89 | Err(anyhow!("load should never be called on cake::Client")) 90 | } 91 | 92 | async fn forward(&self, _: &Tensor, _: usize, _: usize, _: &mut Context) -> Result<Tensor> { 93 | Err(anyhow!( 94 | "immutable forward should never be called on cake::Client" 95 | )) 96 | } 97 | 98 | /// Executes the worker's pipeline for this tensor. 99 | async fn forward_mut( 100 | &mut self, 101 | x: &Tensor, 102 | index_pos: usize, 103 | block_idx: usize, 104 | _: &mut Context, 105 | ) -> Result<Tensor> { 106 | self.forward_request(super::Message::single_op( 107 | &self.layer_name, 108 | x, 109 | index_pos, 110 | block_idx, 111 | )) 112 | .await 113 | } 114 | 115 | /// Executes the worker's pipeline with multiple batched steps for this tensor.
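/// Each batch entry is a `(layer_name, index_pos, block_idx)` triple that the worker applies in order to the same tensor.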
116 | async fn forward_batch( 117 | &mut self, 118 | x: &Tensor, 119 | batch: Vec<(String, usize, usize)>, 120 | _: &mut Context, 121 | ) -> Result<Tensor> { 122 | self.forward_request(super::Message::from_batch(x, batch)) 123 | .await 124 | } 125 | 126 | async fn goodbye(&mut self) -> Result<()> { 127 | self.request(Message::Goodbye).await?; 128 | Ok(()) 129 | } 130 | 131 | fn layer_name(&self) -> &str { 132 | &self.layer_name 133 | } 134 | 135 | fn ident(&self) -> &str { 136 | &self.address 137 | } 138 | } 139 | -------------------------------------------------------------------------------- /cake-core/src/cake/master.rs: -------------------------------------------------------------------------------- 1 | use std::io::Write; 2 | 3 | use crate::models::{chat::Message, ImageGenerator, TextGenerator}; 4 | 5 | use super::{api, Context}; 6 | 7 | use crate::{ImageGenerationArgs, ModelType}; 8 | use anyhow::Result; 9 | use image::{ImageBuffer, Rgb}; 10 | 11 | /// A master connects to, communicates with and orchestrates the workers. 12 | pub struct Master<TG, IG> { 13 | pub ctx: Context, 14 | pub llm_model: Option<Box<TG>>, 15 | pub sd_model: Option<Box<IG>>, 16 | } 17 | 18 | impl<TG: TextGenerator + Send + Sync + 'static, IG: ImageGenerator + Send + Sync + 'static> 19 | Master<TG, IG> 20 | { 21 | /// Create a new instance. 22 | pub async fn new(mut ctx: Context) -> Result<Self> { 23 | match ctx.args.model_type { 24 | ModelType::ImageModel => { 25 | let sd_model = IG::load(&mut ctx).await?; 26 | Ok(Self { 27 | ctx, 28 | sd_model, 29 | llm_model: None, 30 | }) 31 | } 32 | ModelType::TextModel => { 33 | let llm_model = TG::load(&mut ctx).await?; 34 | Ok(Self { 35 | ctx, 36 | llm_model, 37 | sd_model: None, 38 | }) 39 | } 40 | } 41 | } 42 | 43 | pub async fn run(mut self) -> Result<()> { 44 | if self.ctx.args.api.is_some() { 45 | // run as REST api 46 | api::start(self).await?; 47 | } else { 48 | // if running in cli mode, pre add system and user prompts 49 | if self.ctx.args.model_type == ModelType::TextModel { 50 | let llm_model = self.llm_model.as_mut().expect("LLM model not found"); 51 | llm_model.add_message(Message::system(self.ctx.args.system_prompt.clone()))?; 52 | llm_model.add_message(Message::user(self.ctx.args.prompt.clone()))?; 53 | 54 | // just run one generation to stdout 55 | self.generate_text(|data| { 56 | if data.is_empty() { 57 | println!(); 58 | } else { 59 | print!("{data}") 60 | } 61 | std::io::stdout().flush().unwrap(); 62 | }) 63 | .await?; 64 | } else { 65 | let mut step_num = 0; 66 | 67 | self.generate_image(self.ctx.args.sd_img_gen_args.clone(), move |images| { 68 | let mut batched_num = 0; 69 | for image in images { 70 | image 71 | .save(format!("images/image_{}_{}.png", batched_num, step_num)) 72 | .expect("Error saving image to disk"); 73 | batched_num += 1; 74 | } 75 | step_num += 1; 76 | }) 77 | .await?; 78 | } 79 | } 80 | 81 | Ok(()) 82 | } 83 | 84 | /// Reset the master state for a new inference. 85 | pub fn reset(&mut self) -> Result<()> { 86 | self.llm_model 87 | .as_mut() 88 | .expect("LLM model not found") 89 | .reset() 90 | } 91 | 92 | /// Clear the worker kv caches. 93 | pub async fn goodbye(&mut self) -> Result<()> { 94 | self.llm_model 95 | .as_mut() 96 | .expect("LLM model not found") 97 | .goodbye() 98 | .await 99 | } 100 | 101 | /// Start the generation loop and call the stream function for every token.
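/// The stream closure receives each decoded token as it is produced, then an empty string to signal the end of the stream.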
102 | pub async fn generate_text<S>(&mut self, mut stream: S) -> Result<()> 103 | where 104 | S: FnMut(&str), 105 | { 106 | log::info!( 107 | "starting the inference loop (mem={})\n\n", 108 | human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64) 109 | ); 110 | 111 | log::debug!(" ctx.args.sample_len = {}", self.ctx.args.sample_len); 112 | 113 | // stream(&self.ctx.args.prompt); 114 | 115 | let mut start_gen = std::time::Instant::now(); 116 | let llm_model = self.llm_model.as_mut().expect("LLM model not found"); 117 | 118 | for index in 0..self.ctx.args.sample_len { 119 | if index == 1 { 120 | // record start time again since the first token is the warmup 121 | start_gen = std::time::Instant::now() 122 | } 123 | 124 | let token = llm_model.next_token(index).await?; 125 | if token.is_end_of_stream { 126 | break; 127 | } else { 128 | stream(&token.to_string()); 129 | } 130 | } 131 | 132 | // signal end of stream 133 | stream(""); 134 | 135 | let dt = start_gen.elapsed(); 136 | let generated = llm_model.generated_tokens(); 137 | 138 | log::info!( 139 | "{} tokens generated ({} token/s) - mem={}", 140 | generated, 141 | (generated - 1) as f64 / dt.as_secs_f64(), 142 | human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64) 143 | ); 144 | 145 | Ok(()) 146 | } 147 | 148 | pub async fn generate_image<F>(&mut self, args: ImageGenerationArgs, callback: F) -> Result<()> 149 | where 150 | F: FnMut(Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>) + Send + 'static, 151 | { 152 | let sd_model = self.sd_model.as_mut().expect("SD model not found"); 153 | sd_model.generate_image(&args, callback).await 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /cake-core/src/cake/mod.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | fmt::{Debug, Display}, 3 | path::PathBuf, 4 | }; 5 | 6 | use crate::{ 7 | models::llama3::{Cache, Config, LlamaConfig}, 8 | utils, Args, ModelType, 9 | }; 10 | use anyhow::Result; 11 | use async_trait::async_trait; 12 | use candle_core::{DType, Device, Tensor}; 13 | use candle_nn::VarBuilder; 14 | 15 | #[cfg(feature = "master")] 16 | mod api; 17 | #[cfg(feature = "master")] 18 | mod master; 19 | 20 | mod client; 21 | mod proto; 22 | mod topology; 23 | mod worker; 24 | 25 | #[cfg(feature = "master")] 26 | pub use master::*; 27 | 28 | pub use client::*; 29 | pub use proto::*; 30 | pub use topology::*; 31 | pub use worker::*; 32 | 33 | /// Determines if we run in master or worker mode. 34 | #[derive(clap::ValueEnum, Clone, Debug, Default)] 35 | pub enum Mode { 36 | #[default] 37 | Master, 38 | Worker, 39 | } 40 | 41 | /// Main context object used as a shared state. 42 | #[derive(Clone)] 43 | pub struct Context { 44 | pub args: Args, 45 | pub dtype: DType, 46 | pub topology: Topology, 47 | pub data_path: PathBuf, 48 | pub device: Device, 49 | pub config: Option<Config>, // TODO: decouple 50 | pub cache: Option<Cache>, 51 | pub var_builder: Option<VarBuilder<'static>>, 52 | } 53 | 54 | impl Context { 55 | /// Create the context from the parsed command line arguments.
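/// The dtype defaults to f16 when no --dtype argument is provided; loading fails if the model path does not exist.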
56 | pub fn from_args(args: Args) -> Result<Self> { 57 | let dtype: DType = match args.dtype.as_deref() { 58 | Some("f16") => DType::F16, 59 | Some("bf16") => DType::BF16, 60 | Some("f32") => DType::F32, 61 | Some(dtype) => bail!("unsupported dtype {dtype}"), 62 | None => DType::F16, 63 | }; 64 | 65 | let device = utils::get_inference_device(args.cpu, args.device) 66 | .map_err(|e| anyhow!("can't attach to device: {:?}", e))?; 67 | 68 | log::info!( 69 | "[{:?}] dtype={:?} device={:?} mem={}", 70 | args.mode, 71 | &dtype, 72 | &device, 73 | human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64) 74 | ); 75 | 76 | let data_path = PathBuf::from(&args.model); 77 | if !data_path.exists() { 78 | bail!("model path does not exist: {}", data_path.display()); 79 | } 80 | 81 | let topology = if let Some(path) = &args.topology { 82 | Topology::from_path(path, &args.model_type)? 83 | } else { 84 | log::warn!("no topology file specified, the entire model will be loaded"); 85 | Topology::new() 86 | }; 87 | 88 | let mut config: Option<Config> = None; 89 | let mut cache: Option<Cache> = None; 90 | let mut var_builder: Option<VarBuilder> = None; 91 | 92 | if args.model_type == ModelType::TextModel { 93 | let config_filename = data_path.join("config.json"); 94 | let config_internal = LlamaConfig::from_path(&config_filename)?.into_config(); 95 | let model_tensors_index: PathBuf = data_path.join("model.safetensors.index.json"); 96 | var_builder = Some(utils::load_var_builder_from_index( 97 | model_tensors_index, 98 | dtype, 99 | device.clone(), 100 | )?); 101 | cache = Some(Cache::new(true, dtype, &config_internal, &device)?); 102 | config = Some(config_internal); 103 | } 104 | 105 | Ok(Context { 106 | args, 107 | dtype, 108 | topology, 109 | data_path, 110 | device, 111 | config, 112 | cache, 113 | var_builder, 114 | }) 115 | } 116 | } 117 | 118 | /// This is the trait that a shardable object must implement. 119 | #[async_trait] 120 | pub trait Forwarder: Debug + Send + Sync + Display { 121 | /// Create an instance of this object loading the specified layer(s) from a VarBuilder. 122 | fn load(name: String, ctx: &Context) -> Result<Box<Self>> 123 | where 124 | Self: Sized; 125 | 126 | /// Applies a forward operation to the input tensor, does not require mutability. 127 | async fn forward( 128 | &self, 129 | x: &Tensor, 130 | index_pos: usize, 131 | block_idx: usize, 132 | ctx: &mut Context, 133 | ) -> Result<Tensor>; 134 | 135 | /// Applies a forward operation to the input tensor, requires mutability. 136 | async fn forward_mut( 137 | &mut self, 138 | x: &Tensor, 139 | index_pos: usize, 140 | block_idx: usize, 141 | ctx: &mut Context, 142 | ) -> Result<Tensor>; 143 | 144 | /// Applies a batch of forward operations to the input tensor. 145 | async fn forward_batch( 146 | &mut self, 147 | _x: &Tensor, 148 | _batch: Vec<(String, usize, usize)>, 149 | _ctx: &mut Context, 150 | ) -> Result<Tensor> { 151 | unimplemented!() 152 | } 153 | 154 | async fn goodbye(&mut self) -> Result<()> { 155 | unimplemented!() 156 | } 157 | 158 | /// Return the layer name. 159 | fn layer_name(&self) -> &str; 160 | 161 | /// Return the unique identity, or "local" for locally loaded layers.
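/// Remote layers (see cake::Client) override this with the worker address.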
162 | fn ident(&self) -> &str { 163 | "local" 164 | } 165 | } 166 | -------------------------------------------------------------------------------- /cake-core/src/cake/proto/message.rs: -------------------------------------------------------------------------------- 1 | use std::str::FromStr; 2 | 3 | use anyhow::Result; 4 | use candle_core::{DType, Device, Tensor}; 5 | use safetensors::View; 6 | use speedy::{BigEndian, Readable, Writable}; 7 | use tokio::io::{AsyncReadExt, AsyncWriteExt}; 8 | 9 | /// Represents a tensor in the Cake protocol. 10 | #[derive(Debug, Readable, Writable)] 11 | pub struct RawTensor { 12 | /// Tensor data. 13 | pub data: Vec<u8>, 14 | /// The data type as string. 15 | pub dtype: String, 16 | /// The tensor shape. 17 | pub shape: Vec<usize>, 18 | } 19 | 20 | impl RawTensor { 21 | /// Convert x into a RawTensor. 22 | pub fn from_tensor(x: &Tensor) -> Self { 23 | let data: Vec<u8> = x.data().to_vec(); 24 | let dtype = x.dtype().as_str().to_string(); 25 | let shape = x.shape().clone().into_dims(); 26 | Self { data, dtype, shape } 27 | } 28 | 29 | /// Convert the raw tensor into a Tensor allocated on the given device. 30 | pub fn to_tensor(&self, device: &Device) -> Result<Tensor> { 31 | let dtype = DType::from_str(&self.dtype)?; 32 | Ok(Tensor::from_raw_buffer( 33 | &self.data, 34 | dtype, 35 | &self.shape, 36 | device, 37 | )?) 38 | } 39 | } 40 | 41 | /// Diagnostic information about a worker. 42 | #[derive(Debug, Default, Readable, Writable)] 43 | pub struct WorkerInfo { 44 | /// Protocol version. 45 | pub version: String, 46 | /// Tensors data type. 47 | pub dtype: String, 48 | /// Operating system. 49 | pub os: String, 50 | /// Architecture. 51 | pub arch: String, 52 | /// Device. 53 | pub device: String, 54 | /// Device index for multi GPU environments. 55 | pub device_idx: usize, 56 | /// Latency in milliseconds. 57 | pub latency: u128, 58 | } 59 | 60 | /// A Cake protocol message. 61 | #[derive(Debug, Readable, Writable)] 62 | pub enum Message { 63 | /// First message sent. 64 | Hello, 65 | /// Message that the worker sends when a master connects with runtime information. 66 | WorkerInfo(WorkerInfo), 67 | /// Single inference operation for a given layer. 68 | SingleOp { 69 | layer_name: String, 70 | x: RawTensor, 71 | index_pos: usize, 72 | block_idx: usize, 73 | }, 74 | /// Batched inference operations over a Tensor. 75 | Batch { 76 | x: RawTensor, 77 | batch: Vec<(String, usize, usize)>, 78 | }, 79 | /// A message to transmit tensors. 80 | Tensor(RawTensor), 81 | /// Last message sent. 82 | Goodbye, 83 | } 84 | 85 | #[inline] 86 | async fn read_u32be<R>(reader: &mut R) -> Result<u32> 87 | where 88 | R: AsyncReadExt + Unpin, 89 | { 90 | Ok(u32::from_be(reader.read_u32().await?)) 91 | } 92 | 93 | #[inline] 94 | async fn write_u32be<W>(writer: &mut W, n: u32) -> Result<()> 95 | where 96 | W: AsyncWriteExt + Unpin, 97 | { 98 | Ok(writer.write_u32(n.to_be()).await?) 99 | } 100 | 101 | impl Message { 102 | /// Create a Message::SingleOp message. 103 | pub fn single_op(layer_name: &str, x: &Tensor, index_pos: usize, block_idx: usize) -> Self { 104 | let layer_name = layer_name.to_owned(); 105 | let x = RawTensor::from_tensor(x); 106 | Self::SingleOp { 107 | layer_name, 108 | x, 109 | index_pos, 110 | block_idx, 111 | } 112 | } 113 | 114 | /// Create a Message::Tensor message. 115 | pub fn from_tensor(x: &Tensor) -> Self { 116 | Self::Tensor(RawTensor::from_tensor(x)) 117 | } 118 | 119 | /// Create a Message::Batch message.
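/// Batched operations share a single input tensor, avoiding one network round trip per layer.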
120 | pub fn from_batch(x: &Tensor, batch: Vec<(String, usize, usize)>) -> Self { 121 | Self::Batch { 122 | x: RawTensor::from_tensor(x), 123 | batch, 124 | } 125 | } 126 | 127 | // Yes, I could use gRPC, but this is simpler and faster. 128 | // Check speedy benchmarks ;) 129 | 130 | /// Serializes the message to raw bytes. 131 | fn to_bytes(&self) -> Result<Vec<u8>> { 132 | Ok(self.write_to_vec_with_ctx(BigEndian::default())?) 133 | } 134 | 135 | /// Deserializes a Message from raw bytes. 136 | fn from_bytes(raw: &[u8]) -> Result<Self> { 137 | Ok(Self::read_from_buffer_with_ctx(BigEndian::default(), raw)?) 138 | } 139 | 140 | /// Read a Message with the provided reader. 141 | pub async fn from_reader<R>(reader: &mut R) -> Result<(usize, Self)> 142 | where 143 | R: AsyncReadExt + Unpin, 144 | { 145 | let magic = read_u32be(reader).await?; 146 | if magic != super::PROTO_MAGIC { 147 | return Err(anyhow!("invalid magic value: {magic}")); 148 | } 149 | 150 | let req_size = read_u32be(reader).await?; 151 | if req_size > super::MESSAGE_MAX_SIZE { 152 | return Err(anyhow!("request size {req_size} > MESSAGE_MAX_SIZE")); 153 | } 154 | 155 | let mut req = vec![0_u8; req_size as usize]; 156 | 157 | reader.read_exact(&mut req).await?; 158 | 159 | Ok((req.len(), Self::from_bytes(&req)?)) 160 | } 161 | 162 | /// Write a Message with the provided writer. 163 | pub async fn to_writer<W>(&self, writer: &mut W) -> Result<usize> 164 | where 165 | W: AsyncWriteExt + Unpin, 166 | { 167 | let req = self.to_bytes()?; 168 | let req_size = req.len() as u32; 169 | if req_size > super::MESSAGE_MAX_SIZE { 170 | return Err(anyhow!("request size {req_size} > MESSAGE_MAX_SIZE")); 171 | } 172 | 173 | write_u32be(writer, super::PROTO_MAGIC).await?; 174 | write_u32be(writer, req_size).await?; 175 | writer.write_all(&req).await?; 176 | 177 | Ok(8 + req.len()) 178 | } 179 | } 180 | -------------------------------------------------------------------------------- /cake-core/src/cake/proto/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module contains Cake protocol specific objects and constants. 2 | 3 | /// Cake protocol header magic value. 4 | const PROTO_MAGIC: u32 = 0x104F4C7; 5 | 6 | /// Cake protocol message max size. 7 | const MESSAGE_MAX_SIZE: u32 = 512 * 1024 * 1024; 8 | 9 | mod message; 10 | 11 | pub use message::*; 12 | -------------------------------------------------------------------------------- /cake-core/src/cake/topology.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use crate::ModelType; 4 | use anyhow::Result; 5 | use lazy_static::lazy_static; 6 | use regex::Regex; 7 | use serde::{Deserialize, Serialize}; 8 | 9 | lazy_static! { 10 | static ref LAYER_RANGE_PARSER: Regex = Regex::new(r"(?m)^(.+[^\d])(\d+)-(\d+)$").unwrap(); 11 | } 12 | 13 | /// A single node (worker). 14 | #[derive(Clone, Serialize, Deserialize)] 15 | pub struct Node { 16 | /// Address and port of the worker. 17 | pub host: String, 18 | /// Optional description. 19 | pub description: Option<String>, 20 | pub layers: Vec<String>, 21 | } 22 | 23 | impl Node { 24 | /// Return true if this node hosts the specified layer. 25 | pub fn is_text_model_layer_owner(&self, full_layer_name: &str) -> bool { 26 | for prefix in self.layers.iter() { 27 | if full_layer_name.starts_with(&format!("{}.", prefix)) { 28 | return true; 29 | } 30 | } 31 | 32 | false 33 | } 34 | } 35 | 36 | /// The topology is a worker-name -> worker-info map.
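/// A minimal example (the host and layer names below are illustrative):
///
/// ```yaml
/// worker0:
///   host: 'worker0.host:10128'
///   layers:
///     - 'model.layers.0-5'
/// ```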
37 | #[derive(Clone, Serialize, Deserialize)] 38 | pub struct Topology(HashMap<String, Node>); 39 | 40 | impl Topology { 41 | /// Create a new empty topology. 42 | pub fn new() -> Self { 43 | Self(HashMap::new()) 44 | } 45 | 46 | /// Load the topology from a yaml file. 47 | pub fn from_path(path: &str, model_type: &ModelType) -> Result<Self> { 48 | log::info!("loading topology from {}", path); 49 | 50 | let mut topology: Self = serde_yaml::from_str(&std::fs::read_to_string(path)?) 51 | .map_err(|e| anyhow!("can't read {path}: {e}"))?; 52 | 53 | if *model_type == ModelType::TextModel { 54 | // check for range expressions 55 | for (_worker_name, node) in topology.iter_mut() { 56 | let mut layers = vec![]; 57 | for layer_name in &node.layers { 58 | if let Some(caps) = LAYER_RANGE_PARSER.captures_iter(layer_name).next() { 59 | let base = caps.get(1).unwrap().as_str().to_string(); 60 | let start = caps.get(2).unwrap().as_str().to_string().parse::<usize>()?; 61 | let stop = caps.get(3).unwrap().as_str().to_string().parse::<usize>()?; 62 | 63 | if stop <= start { 64 | return Err(anyhow!( 65 | "invalid range expression {layer_name}, end must be > start" 66 | )); 67 | } 68 | 69 | for n in start..=stop { 70 | layers.push(format!("{}{}", base, n)); 71 | } 72 | } else { 73 | layers.push(layer_name.to_string()); 74 | } 75 | } 76 | 77 | node.layers = layers; 78 | } 79 | } 80 | 81 | Ok(topology) 82 | } 83 | 84 | /// Return the node serving the specified layer, or None if not found. 85 | pub fn get_node_for_layer(&self, layer_name: &str) -> Option<(&str, &Node)> { 86 | for (node_name, node) in &self.0 { 87 | for node_layer_name in &node.layers { 88 | if layer_name == node_layer_name { 89 | return Some((node_name, node)); 90 | } 91 | } 92 | } 93 | None 94 | } 95 | } 96 | 97 | impl std::ops::Deref for Topology { 98 | type Target = HashMap<String, Node>; 99 | fn deref(&self) -> &HashMap<String, Node> { 100 | &self.0 101 | } 102 | } 103 | 104 | impl std::ops::DerefMut for Topology { 105 | fn deref_mut(&mut self) -> &mut Self::Target { 106 | &mut self.0 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /cake-core/src/cake/worker.rs: -------------------------------------------------------------------------------- 1 | use std::{ 2 | collections::HashMap, 3 | net::SocketAddr, 4 | sync::Arc, 5 | time::{Duration, Instant}, 6 | }; 7 | 8 | use super::{Context, Forwarder, Message, WorkerInfo}; 9 | use crate::{models::Generator, ModelType}; 10 | 11 | use anyhow::Result; 12 | use candle_core::{DType, Device}; 13 | use tokio::{ 14 | io::{AsyncReadExt, AsyncWriteExt}, 15 | net::{TcpListener, TcpStream}, 16 | }; 17 | 18 | /// Determines how often worker statistics are calculated and printed. 19 | const NUM_OPS_TO_STATS: usize = 5; 20 | 21 | /// A single worker state. 22 | #[derive(Clone)] 23 | struct WorkerContext<G: Generator> { 24 | device: Device, 25 | device_idx: usize, 26 | dtype: DType, 27 | blocks: Arc<HashMap<String, Box<G::Shardable>>>, 28 | context: Context, 29 | } 30 | 31 | impl<G: Generator> WorkerContext<G> { 32 | /// Create a WorkerInfo structure to be sent to the master.
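/// `latency` is the time spent reading the master's last message, in milliseconds.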
33 | fn to_info(&self, latency: u128) -> WorkerInfo { 34 | WorkerInfo { 35 | version: env!("CARGO_PKG_VERSION").to_string(), 36 | os: std::env::consts::OS.to_string(), 37 | arch: std::env::consts::ARCH.to_string(), 38 | device: if self.device.is_cuda() { 39 | "cuda".to_string() 40 | } else if self.device.is_metal() { 41 | "metal".to_string() 42 | } else { 43 | "cpu".to_string() 44 | }, 45 | device_idx: self.device_idx, 46 | latency, 47 | dtype: format!("{:?}", self.dtype), 48 | } 49 | } 50 | 51 | /// Create a copy of self with new kv-cache. 52 | fn get_client_context(&self) -> Self { 53 | let cache = self.context.cache.as_ref().map(|cache| cache.as_new()); 54 | 55 | let mut cloned_context = self.context.clone(); 56 | cloned_context.cache = cache; 57 | 58 | WorkerContext { 59 | device: self.device.clone(), 60 | device_idx: self.device_idx, 61 | dtype: self.dtype, 62 | blocks: self.blocks.clone(), 63 | // each client loop gets a new cache 64 | context: cloned_context, 65 | } 66 | } 67 | } 68 | 69 | /// Cake worker node. 70 | pub struct Worker<G: Generator> { 71 | listener: TcpListener, 72 | context: WorkerContext<G>, 73 | } 74 | 75 | impl<G: Generator + Send + Sync + 'static> Worker<G> { 76 | /// Create a new Worker from the context. 77 | pub async fn new(ctx: &mut Context) -> Result<Self> { 78 | let worker_name = if let Some(name) = &ctx.args.name { 79 | name.to_string() 80 | } else { 81 | return Err(anyhow!("no --name provided for worker")); 82 | }; 83 | 84 | let worker_topology = if let Some(node) = ctx.topology.get(&worker_name) { 85 | node 86 | } else if !ctx.topology.is_empty() { 87 | let first = ctx.topology.keys().next().unwrap(); 88 | log::warn!( 89 | "topology for worker name '{}' not found, using '{}'", 90 | &worker_name, 91 | first 92 | ); 93 | ctx.topology.get(first).unwrap() 94 | } else { 95 | return Err(anyhow!( 96 | "could not find topology for {worker_name} and topology file is empty" 97 | )); 98 | }; 99 | 100 | let mut blocks = HashMap::new(); 101 | 102 | let vb = ctx.var_builder.clone(); 103 | 104 | for block_layer_name in &worker_topology.layers { 105 | log::info!("loading {} ...", &block_layer_name); 106 | 107 | if ctx.args.model_type == ModelType::TextModel { 108 | ctx.var_builder = Some( 109 | vb.clone() 110 | .expect("Error retrieving var_builder") 111 | .pp(block_layer_name), 112 | ); 113 | } 114 | 115 | let block = G::Shardable::load(block_layer_name.to_string(), ctx)?; 116 | 117 | blocks.insert(block_layer_name.to_string(), block); 118 | } 119 | 120 | ctx.var_builder = vb; 121 | 122 | let blocks = Arc::new(blocks); 123 | 124 | let listener = TcpListener::bind(&ctx.args.address).await?; 125 | 126 | log::info!( 127 | "listening on {} (mem:{}) ...", 128 | &ctx.args.address, 129 | human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64) 130 | ); 131 | 132 | let device = ctx.device.clone(); 133 | let dtype = ctx.dtype; 134 | let device_idx = ctx.args.device; 135 | 136 | let context = WorkerContext { 137 | device, 138 | device_idx, 139 | dtype, 140 | blocks, 141 | context: ctx.clone(), 142 | }; 143 | 144 | Ok(Self { listener, context }) 145 | } 146 | 147 | /// Read a message from the socket and return elapsed time, message size and message.
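/// Messages are framed as a u32 magic value, a u32 payload size and the speedy-encoded payload, all big-endian.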
148 | async fn read_message_timed<R>(mut socket: R) -> Result<(Duration, usize, Message)> 149 | where 150 | R: AsyncReadExt + Unpin, 151 | { 152 | let start = Instant::now(); 153 | let (size, message) = Message::from_reader(&mut socket).await?; 154 | let latency = start.elapsed(); 155 | 156 | Ok((latency, size, message)) 157 | } 158 | 159 | /// Write a message to the socket and return the elapsed time with written size. 160 | async fn write_message_timed<W>(mut socket: W, message: Message) -> Result<(Duration, usize)> 161 | where 162 | W: AsyncWriteExt + Unpin, 163 | { 164 | let start = Instant::now(); 165 | let size = message.to_writer(&mut socket).await?; 166 | let latency = start.elapsed(); 167 | 168 | Ok((latency, size)) 169 | } 170 | 171 | /// Main loop handling communication with the master. 172 | async fn handle_master_client( 173 | mut socket: TcpStream, 174 | client: SocketAddr, 175 | mut context: WorkerContext<G>, 176 | ) -> Result<()> { 177 | // read and validate Hello 178 | let (latency, _size, hello) = Self::read_message_timed(&mut socket).await?; 179 | if !matches!(hello, Message::Hello) { 180 | return Err(anyhow!( 181 | "[{}] unexpected message instead of hello: {:?}", 182 | &client, 183 | hello 184 | )); 185 | } 186 | 187 | // send info 188 | if let Err(e) = Self::write_message_timed( 189 | &mut socket, 190 | Message::WorkerInfo(context.to_info(latency.as_millis())), 191 | ) 192 | .await 193 | { 194 | return Err(anyhow!("[{}] could not send worker info: {:?}", &client, e)); 195 | } 196 | 197 | let mut msg_idx = 0; 198 | let mut avg_ops = 0; 199 | let mut avg_write = 0; 200 | let mut avg_read = 0; 201 | 202 | // keep reading messages 203 | while let Ok((read_time, read_size, op_message)) = 204 | Self::read_message_timed(&mut socket).await 205 | { 206 | if matches!(op_message, Message::Goodbye) { 207 | log::info!("[{}] goodbye", &client); 208 | context 209 | .context 210 | .cache 211 | .as_mut() 212 | .expect("No cache specified") 213 | .clear(); 214 | 215 | // send info 216 | if let Err(e) = Self::write_message_timed( 217 | &mut socket, 218 | Message::WorkerInfo(context.to_info(read_time.as_millis())), 219 | ) 220 | .await 221 | { 222 | return Err(anyhow!("[{}] could not send worker info: {:?}", &client, e)); 223 | } 224 | 225 | continue; 226 | } 227 | 228 | let (x, ops) = match op_message { 229 | // single block operation 230 | Message::SingleOp { 231 | layer_name, 232 | x, 233 | index_pos, 234 | block_idx, 235 | } => (x, vec![(layer_name, index_pos, block_idx)]), 236 | // batched 237 | Message::Batch { x, batch } => (x, batch), 238 | _ => { 239 | return Err(anyhow!( 240 | "[{}] unhandled message in loop: {:?}", 241 | &client, 242 | op_message 243 | )); 244 | } 245 | }; 246 | 247 | // load raw tensor to device 248 | let mut x = x.to_tensor(&context.device).unwrap(); 249 | let num_ops = ops.len(); 250 | let start_ops = Instant::now(); 251 | 252 | // for each element in the ops batch 253 | for (layer_name, index_pos, block_idx) in ops { 254 | // get layer block by name 255 | if let Some(block) = context.blocks.get(&layer_name) { 256 | // run forward pass 257 | x = block 258 | .forward(&x, index_pos, block_idx, &mut context.context) 259 | .await 260 | .unwrap(); 261 | } else { 262 | return Err(anyhow!("could not find layer {}", &layer_name)); 263 | } 264 | } 265 | 266 | let elaps_ops = start_ops.elapsed(); 267 | 268 | // send response tensor 269 | match Self::write_message_timed(&mut socket, Message::from_tensor(&x)).await { 270 | Ok((elaps_write, written)) => { 271 | let ops_per_sec
= (num_ops as f64 / elaps_ops.as_secs_f64()) as usize; 272 | let write_bytes_per_sec = (written as f64 / elaps_write.as_secs_f64()) as usize; 273 | let read_bytes_per_sec = (read_size as f64 / read_time.as_secs_f64()) as usize; 274 | 275 | avg_ops += ops_per_sec; 276 | avg_write += write_bytes_per_sec; 277 | avg_read += read_bytes_per_sec; 278 | } 279 | Err(e) => { 280 | return Err(anyhow!( 281 | "[{}] could not send response tensor: {:?}", 282 | &client, 283 | e 284 | )); 285 | } 286 | } 287 | 288 | // compute and print stats every NUM_OPS_TO_STATS operations to avoid spamming stdout 289 | if msg_idx % NUM_OPS_TO_STATS == 0 { 290 | log::info!( 291 | "ops={}/s read={}/s write={}/s", 292 | avg_ops / NUM_OPS_TO_STATS, 293 | human_bytes::human_bytes(avg_read as f64 / NUM_OPS_TO_STATS as f64), 294 | human_bytes::human_bytes(avg_write as f64 / NUM_OPS_TO_STATS as f64) 295 | ); 296 | avg_ops = 0; 297 | avg_write = 0; 298 | avg_read = 0; 299 | } 300 | msg_idx += 1; 301 | } 302 | 303 | Ok(()) 304 | } 305 | 306 | /// Run the worker server accept loop. 307 | pub async fn run(&mut self) -> Result<()> { 308 | while let Ok((socket, client)) = self.listener.accept().await { 309 | log::info!("{} connected", &client); 310 | 311 | let context = self.context.get_client_context(); 312 | tokio::spawn(async move { 313 | if let Err(e) = Self::handle_master_client(socket, client, context).await { 314 | log::error!("{}", e); 315 | } 316 | }); 317 | } 318 | 319 | Ok(()) 320 | } 321 | } 322 | -------------------------------------------------------------------------------- /cake-core/src/lib.rs: -------------------------------------------------------------------------------- 1 | //! This is the core library where all Cake logic is implemented. 2 | #[macro_use] 3 | extern crate anyhow; 4 | 5 | use cake::Mode; 6 | 7 | use clap::{Parser, ValueEnum}; 8 | use serde::Deserialize; 9 | 10 | pub mod cake; 11 | pub mod models; 12 | pub mod utils; 13 | 14 | #[derive(Copy, Clone, Parser, Default, Debug, Eq, PartialEq, PartialOrd, Ord, ValueEnum)] 15 | pub enum ModelType { 16 | #[default] 17 | TextModel, 18 | ImageModel, 19 | } 20 | 21 | #[derive(Clone, Parser, Default, Debug)] 22 | #[command(author, version, about, long_about = None)] 23 | pub struct Args { 24 | /// GPU device index. 25 | #[arg(long, default_value_t = 0)] 26 | pub device: usize, 27 | /// Mode. 28 | #[arg(long, default_value_t, value_enum)] 29 | pub mode: Mode, 30 | /// Worker name. 31 | #[arg(long)] 32 | pub name: Option<String>, 33 | /// Binding address and port for workers. 34 | #[arg(long, default_value = "127.0.0.1:10128")] 35 | pub address: String, 36 | /// Enable OpenAI compatible chat completion API. 37 | #[arg(long)] 38 | pub api: Option<String>, 39 | /// Llama3 model data path. 40 | #[arg(long, default_value = "./cake-data/Meta-Llama-3-8B/")] 41 | pub model: String, 42 | /// Topology file. 43 | #[arg(long)] 44 | pub topology: Option<String>, 45 | /// The initial prompt. 46 | #[arg(long, default_value = "The sky is blue because ")] 47 | pub prompt: String, 48 | /// The system prompt. 49 | #[arg(long, default_value = "You are a helpful AI assistant.")] 50 | pub system_prompt: String, 51 | /// The seed to use when generating random samples. 52 | #[arg(long, default_value_t = 299792458)] 53 | pub seed: u64, 54 | /// The length of the sample to generate (in tokens). 55 | #[arg(short = 'n', long, default_value_t = 100)] 56 | pub sample_len: usize, 57 | /// The temperature used to generate samples.
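/// Lower values make sampling more deterministic, higher values more random.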
58 | #[arg(long, default_value_t = 1.0)] 59 | pub temperature: f64, 60 | /// Nucleus sampling probability cutoff. 61 | #[arg(long)] 62 | pub top_p: Option<f64>, 63 | /// Only sample among the top K samples. 64 | #[arg(long)] 65 | pub top_k: Option<usize>, 66 | /// Penalty to be applied for repeating tokens, 1. means no penalty. 67 | #[arg(long, default_value_t = 1.1)] 68 | pub repeat_penalty: f32, 69 | /// The context size to consider for the repeat penalty. 70 | #[arg(long, default_value_t = 128)] 71 | pub repeat_last_n: usize, 72 | /// Use a different dtype than f16. 73 | #[arg(long)] 74 | pub dtype: Option<String>, 75 | 76 | /// Run on CPU rather than on GPU. 77 | #[arg(long, default_value_t = false)] 78 | pub cpu: bool, 79 | 80 | #[arg(long, default_value = "text-model")] 81 | pub model_type: ModelType, 82 | 83 | #[clap(flatten)] 84 | pub sd_args: SDArgs, 85 | 86 | #[clap(flatten)] 87 | pub sd_img_gen_args: ImageGenerationArgs, 88 | } 89 | 90 | #[derive(Clone, Parser, Default, Debug)] 91 | pub struct SDArgs { 92 | #[arg(long = "sd-tokenizer")] 93 | pub tokenizer: Option<String>, 94 | 95 | #[arg(long = "sd-tokenizer-2")] 96 | pub tokenizer_2: Option<String>, 97 | 98 | #[arg(long = "sd-version", value_enum, default_value = "v1-5")] 99 | sd_version: StableDiffusionVersion, 100 | 101 | #[arg(long = "sd-use-f16", default_value_t = true)] 102 | use_f16: bool, 103 | 104 | #[arg(long = "sd-width")] 105 | width: Option<usize>, 106 | 107 | #[arg(long = "sd-height")] 108 | height: Option<usize>, 109 | 110 | #[arg(long = "sd-sliced-attention-size")] 111 | sliced_attention_size: Option<usize>, 112 | 113 | #[arg(long = "sd-clip")] 114 | clip: Option<String>, 115 | 116 | #[arg(long = "sd-clip2")] 117 | clip2: Option<String>, 118 | 119 | #[arg(long = "sd-vae")] 120 | vae: Option<String>, 121 | 122 | #[arg(long = "sd-unet")] 123 | unet: Option<String>, 124 | 125 | #[arg(long = "sd-use-flash-attention", default_value_t = false)] 126 | use_flash_attention: bool, 127 | } 128 | 129 | fn default_prompt() -> String { 130 | "A very realistic photo of a rusty robot walking on a sandy beach".to_string() 131 | } 132 | 133 | fn empty_str() -> String { 134 | "".to_string() 135 | } 136 | 137 | fn usize_one() -> usize { 138 | 1 139 | } 140 | 141 | fn default_img2img_strength() -> f64 { 142 | 0.8 143 | } 144 | 145 | #[derive(Clone, Parser, Default, Debug, Deserialize)] 146 | pub struct ImageGenerationArgs { 147 | /// The prompt to be used for image generation. 148 | #[arg( 149 | long = "sd-image-prompt", 150 | default_value = "A very realistic photo of a rusty robot walking on a sandy beach" 151 | )] 152 | #[serde(rename(deserialize = "sd-image-prompt"), default = "default_prompt")] 153 | image_prompt: String, 154 | 155 | #[arg(long = "sd-uncond-prompt", default_value = "")] 156 | #[serde(rename(deserialize = "sd-uncond-prompt"), default = "empty_str")] 157 | uncond_prompt: String, 158 | 159 | /// Enable tracing (generates a trace-timestamp.json file). 160 | #[arg(long = "sd-tracing", default_value_t = false)] 161 | #[serde(rename(deserialize = "sd-tracing"), default)] 162 | tracing: bool, 163 | 164 | /// The number of steps to run the diffusion for. 165 | #[arg(long = "sd-n-steps")] 166 | #[serde(rename(deserialize = "sd-n-steps"))] 167 | n_steps: Option<usize>, 168 | 169 | /// The number of samples to generate iteratively. 170 | #[arg(long = "sd-num-samples", default_value_t = 1)] 171 | #[serde(rename(deserialize = "sd-num-samples"), default = "usize_one")] 172 | num_samples: usize, 173 | 174 | /// The number of samples to generate simultaneously.
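/// (i.e. the batch size used for each diffusion step).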
175 | #[arg(long = "sd-bsize", default_value_t = 1)] 176 | #[serde(rename(deserialize = "sd-bsize"), default = "usize_one")] 177 | bsize: usize, 178 | 179 | /// Generate intermediary images every n steps. 180 | #[arg(long = "sd-intermediary-images", default_value_t = 0, action)] 181 | #[serde(rename(deserialize = "sd-intermediary-images"), default)] 182 | intermediary_images: usize, 183 | 184 | #[arg(long = "sd-guidance-scale")] 185 | #[serde(rename(deserialize = "sd-guidance-scale"))] 186 | guidance_scale: Option<f64>, 187 | 188 | #[arg(long = "sd-img2img", value_name = "FILE")] 189 | #[serde(rename(deserialize = "sd-img2img"))] 190 | img2img: Option<String>, 191 | 192 | /// The strength indicates how much to transform the initial image. The 193 | /// value must be between 0 and 1; a value of 1 discards the initial image 194 | /// information. 195 | #[arg(long = "sd-img2img-strength", default_value_t = 0.8)] 196 | #[serde( 197 | rename(deserialize = "sd-img2img-strength"), 198 | default = "default_img2img_strength" 199 | )] 200 | img2img_strength: f64, 201 | 202 | /// The seed to use when generating random samples. 203 | #[arg(long = "sd-seed")] 204 | #[serde(rename(deserialize = "sd-seed"))] 205 | image_seed: Option<u64>, 206 | } 207 | 208 | #[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq, Default)] 209 | pub enum StableDiffusionVersion { 210 | #[default] 211 | V1_5, 212 | V2_1, 213 | Xl, 214 | Turbo, 215 | } 216 | 217 | impl StableDiffusionVersion { 218 | fn repo(&self) -> &'static str { 219 | match self { 220 | Self::Xl => "stabilityai/stable-diffusion-xl-base-1.0", 221 | Self::V2_1 => "stabilityai/stable-diffusion-2-1", 222 | Self::V1_5 => "runwayml/stable-diffusion-v1-5", 223 | Self::Turbo => "stabilityai/sdxl-turbo", 224 | } 225 | } 226 | 227 | fn unet_file(&self, use_f16: bool) -> &'static str { 228 | match self { 229 | Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { 230 | if use_f16 { 231 | "unet/diffusion_pytorch_model.fp16.safetensors" 232 | } else { 233 | "unet/diffusion_pytorch_model.safetensors" 234 | } 235 | } 236 | } 237 | } 238 | 239 | fn vae_file(&self, use_f16: bool) -> &'static str { 240 | match self { 241 | Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { 242 | if use_f16 { 243 | "vae/diffusion_pytorch_model.fp16.safetensors" 244 | } else { 245 | "vae/diffusion_pytorch_model.safetensors" 246 | } 247 | } 248 | } 249 | } 250 | 251 | fn clip_file(&self, use_f16: bool) -> &'static str { 252 | match self { 253 | Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { 254 | if use_f16 { 255 | "text_encoder/model.fp16.safetensors" 256 | } else { 257 | "text_encoder/model.safetensors" 258 | } 259 | } 260 | } 261 | } 262 | 263 | fn clip2_file(&self, use_f16: bool) -> &'static str { 264 | match self { 265 | Self::V1_5 | Self::V2_1 | Self::Xl | Self::Turbo => { 266 | if use_f16 { 267 | "text_encoder_2/model.fp16.safetensors" 268 | } else { 269 | "text_encoder_2/model.safetensors" 270 | } 271 | } 272 | } 273 | } 274 | } 275 | -------------------------------------------------------------------------------- /cake-core/src/models/chat.rs: -------------------------------------------------------------------------------- 1 | use serde::{Deserialize, Serialize}; 2 | 3 | /// The role of a message in a chat. 4 | #[derive(Debug, Serialize, Deserialize)] 5 | pub enum MessageRole { 6 | /// System prompt. 7 | #[serde(alias = "system")] 8 | System, 9 | /// User prompt. 10 | #[serde(alias = "user")] 11 | User, 12 | /// Assistant response.
13 | #[serde(alias = "assistant")] 14 | Assistant, 15 | } 16 | 17 | impl std::fmt::Display for MessageRole { 18 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 19 | write!( 20 | f, 21 | "{}", 22 | match self { 23 | MessageRole::System => "system", 24 | MessageRole::User => "user", 25 | MessageRole::Assistant => "assistant", 26 | } 27 | ) 28 | } 29 | } 30 | 31 | /// A chat message. 32 | #[derive(Debug, Serialize, Deserialize)] 33 | pub struct Message { 34 | /// Message role. 35 | pub role: MessageRole, 36 | /// Message content. 37 | pub content: String, 38 | } 39 | 40 | impl Message { 41 | /// Create a system message. 42 | pub fn system(content: String) -> Self { 43 | Self { 44 | role: MessageRole::System, 45 | content, 46 | } 47 | } 48 | 49 | /// Create a user message. 50 | pub fn user(content: String) -> Self { 51 | Self { 52 | role: MessageRole::User, 53 | content, 54 | } 55 | } 56 | 57 | /// Create an assistant message. 58 | pub fn assistant(content: String) -> Self { 59 | Self { 60 | role: MessageRole::Assistant, 61 | content, 62 | } 63 | } 64 | } 65 | -------------------------------------------------------------------------------- /cake-core/src/models/llama3/attention.rs: -------------------------------------------------------------------------------- 1 | //! Causal self attention implementation. 2 | use candle_core::{DType, Result, Tensor, D}; 3 | use candle_nn::{linear_no_bias as linear, Linear, Module, VarBuilder}; 4 | 5 | #[derive(Debug, Clone)] 6 | pub struct CausalSelfAttention { 7 | q_proj: Linear, 8 | k_proj: Linear, 9 | v_proj: Linear, 10 | o_proj: Linear, 11 | num_attention_heads: usize, 12 | num_key_value_heads: usize, 13 | head_dim: usize, 14 | } 15 | 16 | #[inline] 17 | fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> { 18 | let shape = mask.shape(); 19 | let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?; 20 | let m = mask.where_cond(&on_true, on_false)?; 21 | Ok(m) 22 | } 23 | 24 | impl CausalSelfAttention { 25 | fn apply_rotary_emb( 26 | &self, 27 | x: &Tensor, 28 | index_pos: usize, 29 | cache: &super::Cache, 30 | ) -> Result<Tensor> { 31 | let (_batch_size, _, seq_len, _hidden_size) = x.dims4()?; 32 | let cos = cache.cosine(index_pos, seq_len)?; 33 | let sin = cache.sine(index_pos, seq_len)?; 34 | candle_nn::rotary_emb::rope(x, &cos, &sin) 35 | } 36 | 37 | /// Process the input tensor using the given state indexes and cache. 38 | pub fn forward( 39 | &self, 40 | x: &Tensor, 41 | index_pos: usize, 42 | block_idx: usize, 43 | cache: &mut super::Cache, 44 | ) -> anyhow::Result<Tensor> { 45 | let (b_sz, seq_len, hidden_size) = x.dims3().map_err(|e| anyhow!("x.dims3 -> {e}"))?; 46 | 47 | // log::info!("x.dims3 = {:?}", x.dims3().unwrap()); 48 | 49 | let q = self 50 | .q_proj 51 | .forward(x) 52 | .map_err(|e| anyhow!("q.forward -> {e}"))?; 53 | let k = self 54 | .k_proj 55 | .forward(x) 56 | .map_err(|e| anyhow!("k.forward -> {e}"))?; 57 | let v = self 58 | .v_proj 59 | .forward(x) 60 | .map_err(|e| anyhow!("v.forward -> {e}"))?; 61 | 62 | let q = q 63 | .reshape((b_sz, seq_len, self.num_attention_heads, self.head_dim))? 64 | .transpose(1, 2)? 65 | .contiguous() 66 | .map_err(|e| anyhow!("q.reshape -> {e}"))?; 67 | let k = k 68 | .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))? 69 | .transpose(1, 2)? 70 | .contiguous() 71 | .map_err(|e| anyhow!("k.reshape -> {e}"))?; 72 | let v = v 73 | .reshape((b_sz, seq_len, self.num_key_value_heads, self.head_dim))?
74 | .transpose(1, 2) 75 | .map_err(|e| anyhow!("v.reshape -> {e}"))?; 76 | 77 | let q = self 78 | .apply_rotary_emb(&q, index_pos, cache) 79 | .map_err(|e| anyhow!("q.apply_rotary_emb -> {e}"))?; 80 | 81 | let k = self 82 | .apply_rotary_emb(&k, index_pos, cache) 83 | .map_err(|e| anyhow!("k.apply_rotary_emb -> {e}"))?; 84 | 85 | let (k, v) = cache 86 | .process_kv(block_idx, k, v) 87 | .map_err(|e| anyhow!("cache.process_kv(block={block_idx}) -> {e}"))?; 88 | 89 | let k = self 90 | .repeat_kv(k) 91 | .map_err(|e| anyhow!("repeat_kv(k) -> {e}"))?; 92 | let v = self 93 | .repeat_kv(v) 94 | .map_err(|e| anyhow!("repeat_kv(v) -> {e}"))?; 95 | 96 | let y = { 97 | let in_dtype = q.dtype(); 98 | let q = q.to_dtype(DType::F32)?; 99 | let k = k.to_dtype(DType::F32)?; 100 | let v = v.to_dtype(DType::F32)?; 101 | let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?; 102 | let att = if seq_len == 1 { 103 | att 104 | } else { 105 | let mask = cache 106 | .mask(seq_len) 107 | .map_err(|e| anyhow!("cache.mask({seq_len}) -> {e}"))? 108 | .broadcast_as(att.shape()) 109 | .map_err(|e| anyhow!("mask.broadcast_as({:?}) -> {e}", att.shape()))?; 110 | 111 | masked_fill(&att, &mask, f32::NEG_INFINITY) 112 | .map_err(|e| anyhow!("masked_fill -> {e}"))? 113 | }; 114 | let att = candle_nn::ops::softmax(&att, D::Minus1)?; 115 | 116 | // Convert to contiguous as matmul doesn't support strided vs for now. 117 | att.matmul(&v.contiguous()?)?.to_dtype(in_dtype)? 118 | }; 119 | let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, hidden_size])?; 120 | let y = self.o_proj.forward(&y)?; 121 | 122 | Ok(y) 123 | } 124 | 125 | fn repeat_kv(&self, x: Tensor) -> Result<Tensor> { 126 | candle_transformers::utils::repeat_kv( 127 | x, 128 | self.num_attention_heads / self.num_key_value_heads, 129 | ) 130 | } 131 | 132 | /// Load an instance of this object from the VarBuilder object with the given configuration. 133 | pub fn load(vb: VarBuilder, cfg: &super::Config) -> Result<Self> { 134 | let size_in = cfg.hidden_size; 135 | let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads; 136 | let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads; 137 | let q_proj = linear(size_in, size_q, vb.pp("q_proj"))?; 138 | let k_proj = linear(size_in, size_kv, vb.pp("k_proj"))?; 139 | let v_proj = linear(size_in, size_kv, vb.pp("v_proj"))?; 140 | let o_proj = linear(size_q, size_in, vb.pp("o_proj"))?; 141 | Ok(Self { 142 | q_proj, 143 | k_proj, 144 | v_proj, 145 | o_proj, 146 | num_attention_heads: cfg.num_attention_heads, 147 | num_key_value_heads: cfg.num_key_value_heads, 148 | head_dim: cfg.hidden_size / cfg.num_attention_heads, 149 | }) 150 | } 151 | } 152 | -------------------------------------------------------------------------------- /cake-core/src/models/llama3/cache.rs: -------------------------------------------------------------------------------- 1 | use std::collections::HashMap; 2 | 3 | use candle_core::{DType, Device, Result, Tensor, D}; 4 | 5 | use super::{Config, MAX_SEQ_LEN}; 6 | 7 | /// Abstraction over cosine and sine tables, kv-caching and attention masking. 8 | #[derive(Debug, Clone)] 9 | pub struct Cache { 10 | cos: Tensor, 11 | sin: Tensor, 12 | 13 | masks: HashMap<usize, Tensor>, 14 | use_kv_cache: bool, 15 | kvs: Vec<Option<(Tensor, Tensor)>>, 16 | 17 | device: Device, 18 | } 19 | 20 | impl Cache { 21 | /// Creates a new cache instance with the provided configuration. 22 | /// Set `use_kv_cache` to false to disable kv-caching.
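/// /// A minimal construction sketch (the config path, dtype and device are illustrative): /// ```ignore /// let config = LlamaConfig::from_path(std::path::Path::new("config.json"))?.into_config(); /// let cache = Cache::new(true, DType::F16, &config, &Device::Cpu)?; /// assert!(cache.with_kv_cache()); /// ```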
23 | pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> { 24 | // precompute freqs_cis 25 | let n_elem = config.hidden_size / config.num_attention_heads; 26 | 27 | log::debug!("cache::n_elem = {n_elem}"); 28 | 29 | let theta: Vec<_> = (0..n_elem) 30 | .step_by(2) 31 | .map(|i| 1f32 / config.rope_theta.powf(i as f32 / n_elem as f32)) 32 | .collect(); 33 | 34 | let theta = Tensor::new(theta.as_slice(), device)?; 35 | 36 | log::debug!("cache::theta = {}", &theta); 37 | 38 | let idx_theta = Tensor::arange(0, super::MAX_SEQ_LEN as u32, device)? 39 | .to_dtype(DType::F32)? 40 | .reshape((super::MAX_SEQ_LEN, 1))? 41 | .matmul(&theta.reshape((1, theta.elem_count()))?)?; 42 | 43 | log::debug!("cache::idx_theta = {}", &idx_theta); 44 | 45 | // This is different from the paper, see: 46 | // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112 47 | let cos = idx_theta.cos()?.to_dtype(dtype)?; 48 | let sin = idx_theta.sin()?.to_dtype(dtype)?; 49 | 50 | log::debug!("cache::cos = {}", &cos); 51 | log::debug!("cache::sin = {}", &sin); 52 | 53 | Ok(Self { 54 | masks: HashMap::new(), 55 | use_kv_cache, 56 | kvs: vec![None; config.num_hidden_layers], 57 | device: device.clone(), 58 | cos, 59 | sin, 60 | }) 61 | } 62 | 63 | /// Return true if kv-caching is enabled. 64 | pub fn with_kv_cache(&self) -> bool { 65 | self.use_kv_cache 66 | } 67 | 68 | /// Return the cached cosine value for the given position and sequence length. 69 | pub fn cosine(&self, index_pos: usize, seq_len: usize) -> Result<Tensor> { 70 | self.cos.narrow(0, index_pos, seq_len) 71 | } 72 | 73 | /// Return the cached sine value for the given position and sequence length. 74 | pub fn sine(&self, index_pos: usize, seq_len: usize) -> Result<Tensor> { 75 | self.sin.narrow(0, index_pos, seq_len) 76 | } 77 | 78 | /// Get the attention mask for the given sequence length. 79 | pub fn mask(&mut self, seq_len: usize) -> Result<Tensor> { 80 | if let Some(mask) = self.masks.get(&seq_len) { 81 | Ok(mask.clone()) 82 | } else { 83 | let mask: Vec<_> = (0..seq_len) 84 | .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i))) 85 | .collect(); 86 | let mask = Tensor::from_slice(&mask, (seq_len, seq_len), &self.device)?; 87 | self.masks.insert(seq_len, mask.clone()); 88 | Ok(mask) 89 | } 90 | } 91 | 92 | /// Process the input k and v by either generating their cache entry or applying a previously cached one. 93 | pub fn process_kv( 94 | &mut self, 95 | block_idx: usize, 96 | mut k: Tensor, 97 | mut v: Tensor, 98 | ) -> Result<(Tensor, Tensor)> { 99 | if self.use_kv_cache { 100 | // if this block_idx in cache 101 | if let Some((cache_k, cache_v)) = &self.kvs[block_idx] { 102 | // update cache entry 103 | k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?; 104 | v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?; 105 | let k_seq_len = k.dims()[1]; 106 | if k_seq_len > MAX_SEQ_LEN { 107 | k = k 108 | .narrow(D::Minus1, k_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? 109 | .contiguous()? 110 | } 111 | let v_seq_len = v.dims()[1]; 112 | if v_seq_len > 2 * MAX_SEQ_LEN { 113 | v = v 114 | .narrow(D::Minus1, v_seq_len - MAX_SEQ_LEN, MAX_SEQ_LEN)? 115 | .contiguous()? 116 | } 117 | } 118 | // set entry for this block 119 | self.kvs[block_idx] = Some((k.clone(), v.clone())) 120 | } 121 | Ok((k, v)) 122 | } 123 | 124 | /// Return a copy of this cache with the same state but new kv table.
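/// (The precomputed cos/sin tables are kept, while the masks and kv entries of the copy start out empty.)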
125 | pub fn as_new(&self) -> Self { 126 | let mut copy = self.clone(); 127 | copy.clear(); 128 | copy 129 | } 130 | 131 | /// Clear the cache. 132 | pub fn clear(&mut self) { 133 | self.masks.clear(); 134 | self.kvs = vec![None; self.kvs.len()]; 135 | } 136 | } 137 | -------------------------------------------------------------------------------- /cake-core/src/models/llama3/config.rs: -------------------------------------------------------------------------------- 1 | use std::path::Path; 2 | 3 | use anyhow::Result; 4 | 5 | /// Max supported sequence length. 6 | pub const MAX_SEQ_LEN: usize = 4096; 7 | 8 | fn default_rope() -> f32 { 9 | 10_000.0 10 | } 11 | 12 | /// LLama specific configuration. 13 | #[derive(Debug, Clone, serde::Deserialize)] 14 | pub struct LlamaConfig { 15 | pub hidden_size: usize, 16 | pub intermediate_size: usize, 17 | pub vocab_size: usize, 18 | pub num_hidden_layers: usize, 19 | pub num_attention_heads: usize, 20 | pub num_key_value_heads: Option<usize>, 21 | pub rms_norm_eps: f64, 22 | #[serde(default = "default_rope")] 23 | pub rope_theta: f32, 24 | pub bos_token_id: Option<u32>, 25 | pub eos_token_id: Option<u32>, 26 | } 27 | 28 | impl LlamaConfig { 29 | /// Load the configuration from the given path. 30 | pub fn from_path(path: &Path) -> Result<Self> { 31 | log::info!("loading configuration from {}", path.display()); 32 | 33 | let data = 34 | std::fs::read(path).map_err(|e| anyhow!("can't read {}: {:?}", path.display(), e))?; 35 | serde_json::from_slice(&data) 36 | .map_err(|e| anyhow!("can't parse {}: {:?}", path.display(), e)) 37 | } 38 | 39 | /// Return the number of kv heads. 40 | pub fn num_key_value_heads(&self) -> usize { 41 | self.num_key_value_heads.unwrap_or(self.num_attention_heads) 42 | } 43 | 44 | /// Return a generalized Config object. 45 | pub fn into_config(self) -> Config { 46 | Config { 47 | hidden_size: self.hidden_size, 48 | intermediate_size: self.intermediate_size, 49 | vocab_size: self.vocab_size, 50 | num_hidden_layers: self.num_hidden_layers, 51 | num_attention_heads: self.num_attention_heads, 52 | num_key_value_heads: self.num_key_value_heads(), 53 | rms_norm_eps: self.rms_norm_eps, 54 | rope_theta: self.rope_theta, 55 | bos_token_id: self.bos_token_id, 56 | eos_token_id: self.eos_token_id, 57 | } 58 | } 59 | } 60 | 61 | /// Generalized LLama/LLM configuration. 62 | #[derive(Debug, Clone)] 63 | pub struct Config { 64 | pub hidden_size: usize, 65 | pub intermediate_size: usize, 66 | pub vocab_size: usize, 67 | pub num_hidden_layers: usize, 68 | pub num_attention_heads: usize, 69 | pub num_key_value_heads: usize, 70 | pub rms_norm_eps: f64, 71 | pub rope_theta: f32, 72 | pub bos_token_id: Option<u32>, 73 | pub eos_token_id: Option<u32>, 74 | } 75 | -------------------------------------------------------------------------------- /cake-core/src/models/llama3/history.rs: -------------------------------------------------------------------------------- 1 | use crate::models::chat::Message; 2 | 3 | /// Chat history. 4 | pub struct History(Vec<Message>); 5 | 6 | // Adapted from https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py#L202 7 | impl History { 8 | fn encode_header(message: &Message) -> String { 9 | format!("<|start_header_id|>{}<|end_header_id|>\n\n", message.role) 10 | } 11 | 12 | fn encode_message(message: &Message) -> String { 13 | Self::encode_header(message) + message.content.trim() + "<|eot_id|>" 14 | } 15 | 16 | /// Create a new instance of this object.
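/// /// A minimal usage sketch (the messages are illustrative); since History derefs to Vec<Message>, entries are added with push: /// ```ignore /// let mut history = History::new(); /// history.push(Message::system("You are a helpful AI assistant.".to_string())); /// history.push(Message::user("Why is the sky blue?".to_string())); /// let prompt = history.encode_dialog_to_prompt(); /// assert!(prompt.starts_with("<|begin_of_text|>")); /// ```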
17 | pub fn new() -> Self { 18 | Self(vec![]) 19 | } 20 | 21 | /// Encode the dialog to llama3 prompt format. 22 | pub fn encode_dialog_to_prompt(&self) -> String { 23 | let mut encoded = "<|begin_of_text|>".to_string(); 24 | 25 | for message in self.iter() { 26 | encoded += &Self::encode_message(message); 27 | } 28 | 29 | // Add the start of an assistant message for the model to complete. 30 | encoded += &Self::encode_header(&Message::assistant("".to_string())); 31 | 32 | encoded 33 | } 34 | } 35 | 36 | impl std::ops::Deref for History { 37 | type Target = Vec<Message>; 38 | fn deref(&self) -> &Vec<Message> { 39 | &self.0 40 | } 41 | } 42 | 43 | impl std::ops::DerefMut for History { 44 | fn deref_mut(&mut self) -> &mut Self::Target { 45 | &mut self.0 46 | } 47 | } 48 | -------------------------------------------------------------------------------- /cake-core/src/models/llama3/llama.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_trait::async_trait; 3 | use candle_core::{DType, IndexOp, Tensor}; 4 | use candle_nn::{linear_no_bias as linear, Embedding, Linear, Module, RmsNorm}; 5 | use candle_transformers::generation::{LogitsProcessor, Sampling}; 6 | use tokenizers::Tokenizer; 7 | 8 | use super::{transformer::Transformer, History}; 9 | use crate::models::TextGenerator; 10 | use crate::{ 11 | cake::{Context, Forwarder}, 12 | models::{chat::Message, Generator, Token}, 13 | }; 14 | 15 | /// Default end of stream token if not found in configuration. 16 | const DEFAULT_EOS_TOKEN: &str = "</s>"; 17 | 18 | /// Load the tokenizer and return it along with the end of stream token id, if any. 19 | fn load_tokenizer(ctx: &Context) -> Result<(Tokenizer, Option<u32>)> { 20 | let tokenizer_filename = ctx.data_path.join("tokenizer.json"); 21 | 22 | log::info!("loading tokenizer from {}", tokenizer_filename.display()); 23 | 24 | let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(anyhow::Error::msg)?; 25 | 26 | let eos_token_id = ctx 27 | .config 28 | .as_ref() 29 | .expect("No config specified") 30 | .eos_token_id 31 | .or_else(|| tokenizer.token_to_id(DEFAULT_EOS_TOKEN)); 32 | 33 | Ok((tokenizer, eos_token_id)) 34 | } 35 | 36 | /// Create the logit sampling logic from the context. 37 | fn create_logits_processor(ctx: &Context) -> LogitsProcessor { 38 | let temperature = ctx.args.temperature; 39 | let sampling = if temperature <= 0. { 40 | Sampling::ArgMax 41 | } else { 42 | match (ctx.args.top_k, ctx.args.top_p) { 43 | (None, None) => Sampling::All { temperature }, 44 | (Some(k), None) => Sampling::TopK { k, temperature }, 45 | (None, Some(p)) => Sampling::TopP { p, temperature }, 46 | (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature }, 47 | } 48 | }; 49 | LogitsProcessor::from_sampling(ctx.args.seed, sampling) 50 | } 51 | 52 | /// LLama main class.
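/// Owns the tokenizer, the token embeddings, the lm_head and the chain of transformer blocks, each of which may run locally or be proxied to a remote worker.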
53 | pub struct LLama { 54 | ctx: Context, 55 | 56 | tokenizer: Tokenizer, 57 | embedding: Embedding, 58 | eos_token_id: Option<u32>, 59 | index_pos: usize, 60 | generated: usize, 61 | 62 | blocks: Vec<Box<dyn Forwarder>>, 63 | 64 | ln_f: RmsNorm, 65 | lm_head: Linear, 66 | 67 | logits_processor: LogitsProcessor, 68 | 69 | history: History, 70 | tokens: Vec<u32>, 71 | } 72 | 73 | impl LLama { 74 | async fn forward(&mut self, x: &Tensor, idx: usize) -> Result<Tensor> { 75 | let (_batch_size, seq_len) = x.dims2()?; 76 | let mut x = self.embedding.forward(x)?; 77 | 78 | let num_blocks = self.blocks.len(); 79 | let mut block_idx = 0; 80 | 81 | // log::info!("X = {}", &x); 82 | 83 | while block_idx < num_blocks { 84 | let curr_block_id = self.blocks[block_idx].ident().to_owned(); 85 | if curr_block_id == "local" { 86 | // log::info!("x={:?} idx={idx} block={block_idx}", x.shape()); 87 | 88 | // do not batch local inferences 89 | x = self.blocks[block_idx] 90 | .forward_mut(&x, idx, block_idx, &mut self.ctx) 91 | .await 92 | .map_err(|e| { 93 | anyhow!("error in forward operation of local block {block_idx}: {e}") 94 | })?; 95 | 96 | block_idx += 1; 97 | } else { 98 | // collect all contiguous layers running on the same worker 99 | let mut batch = vec![]; 100 | let first = block_idx; 101 | while block_idx < num_blocks && self.blocks[block_idx].ident() == curr_block_id { 102 | batch.push(( 103 | self.blocks[block_idx].layer_name().to_string(), 104 | idx, 105 | block_idx, 106 | )); 107 | block_idx += 1; 108 | } 109 | 110 | x = self.blocks[first] 111 | .forward_batch(&x, batch, &mut self.ctx) 112 | .await 113 | .map_err(|e| { 114 | anyhow!("error in forward batch operation for block {block_idx}: {e}") 115 | })?; 116 | } 117 | 118 | // log::info!("{}.forward(X) -> {}", &curr_block_id, &x); 119 | } 120 | 121 | let x = self 122 | .ln_f 123 | .forward(&x) 124 | .map_err(|e| anyhow!("error in ln_f.forward: {e}"))?; 125 | 126 | let x = x 127 | .i((.., seq_len - 1, ..)) 128 | .map_err(|e| anyhow!("error in x.i: {e}"))? 129 | .contiguous() 130 | .map_err(|e| anyhow!("error in x.i.contiguous: {e}"))?; 131 | 132 | let logits = self 133 | .lm_head 134 | .forward(&x) 135 | .map_err(|e| anyhow!("error in lm_head.forward: {e}"))?; 136 | 137 | logits 138 | .to_dtype(DType::F32) 139 | .map_err(|e| anyhow!("error converting logits: {e}")) 140 | } 141 | 142 | fn start_dialog_prompt(&mut self) -> Result<()> { 143 | // make sure we start clean 144 | self.tokens.clear(); 145 | self.ctx.cache.as_mut().expect("No cache specified").clear(); 146 | self.index_pos = 0; 147 | 148 | log::debug!("generating history tokens ..."); 149 | 150 | // generate raw from history 151 | let dialog = self.history.encode_dialog_to_prompt(); 152 | 153 | log::debug!("dialog={}", &dialog); 154 | 155 | // tokenize raw 156 | self.tokens = self 157 | .tokenizer 158 | .encode(dialog, false) // do not add special tokens as we already added them 159 | .map_err(anyhow::Error::msg)? 160 | .get_ids() 161 | .to_vec(); 162 | 163 | log::debug!("encoded={:?}", &self.tokens); 164 | 165 | log::debug!("history tokens: {}", self.tokens.len()); 166 | 167 | Ok(()) 168 | } 169 | } 170 | 171 | #[async_trait] 172 | impl Generator for LLama { 173 | type Shardable = Transformer; 174 | const MODEL_NAME: &'static str = "llama3"; 175 | 176 | /// Load this model from the context.
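/// Layers assigned to another node in the topology are wrapped in a cake::Client proxy; all remaining layers are loaded locally as Transformer blocks.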
177 | async fn load(ctx: &mut Context) -> Result<Option<Box<Self>>> { 178 | let config = ctx.config.as_ref().expect("No config specified"); 179 | let var_builder = ctx.var_builder.as_ref().expect("No var_builder specified"); 180 | 181 | log::info!("loading embeddings ..."); 182 | let embedding: Embedding = candle_nn::embedding( 183 | config.vocab_size, 184 | config.hidden_size, 185 | var_builder.pp("model.embed_tokens"), 186 | )?; 187 | 188 | log::info!("loading lm_head ..."); 189 | let lm_head = linear( 190 | config.hidden_size, 191 | config.vocab_size, 192 | var_builder.pp("lm_head"), 193 | )?; 194 | 195 | log::info!("loading model.norm ..."); 196 | let ln_f = candle_nn::rms_norm( 197 | config.hidden_size, 198 | config.rms_norm_eps, 199 | var_builder.pp("model.norm"), 200 | )?; 201 | 202 | log::info!("loading {} blocks ...", config.num_hidden_layers); 203 | 204 | let mut blocks: Vec<Box<dyn Forwarder>> = vec![]; 205 | 206 | for i in 0..config.num_hidden_layers { 207 | let block_layer_name = format!("model.layers.{i}"); 208 | if let Some((node_name, node)) = ctx.topology.get_node_for_layer(&block_layer_name) { 209 | log::debug!("node {node_name} will serve {}", &block_layer_name); 210 | blocks.push(Box::new( 211 | crate::cake::Client::new(ctx.device.clone(), &node.host, &block_layer_name) 212 | .await?, 213 | )); 214 | } else { 215 | log::debug!("{} will be served locally", &block_layer_name); 216 | blocks.push(Transformer::load(block_layer_name.clone(), ctx)?); 217 | } 218 | } 219 | 220 | for block in &blocks { 221 | log::info!(" {}", block) 222 | } 223 | 224 | let (tokenizer, eos_token_id) = load_tokenizer(ctx)?; 225 | let tokens = vec![]; 226 | let history = History::new(); 227 | 228 | let logits_processor = create_logits_processor(ctx); 229 | let index_pos = 0; 230 | 231 | log::info!( 232 | "model loaded - mem={}", 233 | human_bytes::human_bytes(memory_stats::memory_stats().unwrap().physical_mem as f64) 234 | ); 235 | 236 | let generated = 0; 237 | 238 | Ok(Some(Box::new(Self { 239 | tokenizer, 240 | tokens, 241 | generated, 242 | history, 243 | eos_token_id, 244 | index_pos, 245 | ctx: ctx.clone(), 246 | embedding, 247 | blocks, 248 | ln_f, 249 | lm_head, 250 | logits_processor, 251 | }))) 252 | } 253 | } 254 | 255 | #[async_trait] 256 | impl TextGenerator for LLama { 257 | /// Add a message to the chat history. 258 | fn add_message(&mut self, message: Message) -> Result<()> { 259 | self.history.push(message); 260 | Ok(()) 261 | } 262 | 263 | /// Reset the chat pipeline state. 264 | fn reset(&mut self) -> Result<()> { 265 | self.tokens.clear(); 266 | self.history.clear(); 267 | self.ctx.cache.as_mut().expect("No cache specified").clear(); 268 | self.index_pos = 0; 269 | self.generated = 0; 270 | Ok(()) 271 | } 272 | 273 | async fn goodbye(&mut self) -> Result<()> { 274 | let num_blocks = self.blocks.len(); 275 | let mut block_idx = 0; 276 | while block_idx < num_blocks { 277 | self.blocks[block_idx] 278 | .goodbye() 279 | .await 280 | .map_err(|e| anyhow!("error in reset operation for block {block_idx}: {e}"))?; 281 | block_idx += 1; 282 | } 283 | Ok(()) 284 | } 285 | 286 | /// Return the next token. 287 | async fn next_token(&mut self, index: usize) -> Result<Token> { 288 | log::trace!("model.next_token({index})"); 289 | 290 | // Prefill tokens with chat history the first time.
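// (start_dialog_prompt also clears the kv cache and resets index_pos before tokenizing the encoded dialog)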
291 | if self.generated == 0 { 292 | self.start_dialog_prompt()?; 293 | } 294 | 295 | let num_tokens = self.tokens.len(); 296 | let (context_size, context_index) = if self 297 | .ctx 298 | .cache 299 | .as_ref() 300 | .expect("No cache specified") 301 | .with_kv_cache() 302 | && index > 0 303 | { 304 | (1, self.index_pos) 305 | } else { 306 | (num_tokens, 0) 307 | }; 308 | 309 | let context_offset = num_tokens.saturating_sub(context_size); 310 | let context_tokens = &self.tokens[context_offset..]; 311 | let num_context_tokens = context_tokens.len(); 312 | 313 | let input = Tensor::new(context_tokens, &self.ctx.device)? 314 | .unsqueeze(0) 315 | .map_err(|e| anyhow!("error unsqueezing context tokens: {e}"))?; 316 | 317 | // log::info!("input={:?} context_index={context_index}", input.shape()); 318 | 319 | let logits = self 320 | .forward(&input, context_index) 321 | .await 322 | .map_err(|e| anyhow!("error in model.forward: {e}"))?; 323 | 324 | let logits = logits 325 | .squeeze(0) 326 | .map_err(|e| anyhow!("error squeezing logits: {e}"))?; 327 | 328 | let logits = if self.ctx.args.repeat_penalty == 1. { 329 | logits 330 | } else { 331 | let start_at = num_tokens.saturating_sub(self.ctx.args.repeat_last_n); 332 | candle_transformers::utils::apply_repeat_penalty( 333 | &logits, 334 | self.ctx.args.repeat_penalty, 335 | &self.tokens[start_at..], 336 | )? 337 | }; 338 | self.index_pos += num_context_tokens; 339 | 340 | let next_token = self 341 | .logits_processor 342 | .sample(&logits) 343 | .map_err(|e| anyhow!("error sampling logits {logits}: {e}"))?; 344 | self.generated += 1; 345 | self.tokens.push(next_token); 346 | 347 | Ok(Token { 348 | id: next_token, 349 | text: match self.tokenizer.decode(&[next_token], false) { 350 | Ok(s) => Some(s), 351 | Err(e) => { 352 | log::error!("could not decode token {next_token}: {e}"); 353 | None 354 | } 355 | }, 356 | is_end_of_stream: Some(next_token) == self.eos_token_id, 357 | }) 358 | } 359 | 360 | /// Return the number of generated tokens so far. 361 | fn generated_tokens(&self) -> usize { 362 | self.generated 363 | } 364 | } 365 | -------------------------------------------------------------------------------- /cake-core/src/models/llama3/mlp.rs: -------------------------------------------------------------------------------- 1 | use candle_core::{Result, Tensor}; 2 | use candle_nn::{linear_no_bias as linear, Linear, Module, VarBuilder}; 3 | 4 | /// Multi-layer perceptron (MLP) implementation. 5 | #[allow(clippy::upper_case_acronyms)] 6 | #[derive(Debug, Clone)] 7 | pub struct MLP { 8 | gate_proj: Linear, 9 | up_proj: Linear, 10 | down_proj: Linear, 11 | } 12 | 13 | impl MLP { 14 | /// Execute MLP(x). 15 | pub fn forward(&self, x: &Tensor) -> Result<Tensor> { 16 | let x = (candle_nn::ops::silu(&self.gate_proj.forward(x)?)? * self.up_proj.forward(x)?)?; 17 | self.down_proj.forward(&x) 18 | } 19 | 20 | /// Load this block from the VarBuilder given the specific configuration.
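/// /// A construction sketch mirroring how Transformer::load wires it up (assumes `vb`, `cfg` and `x` are in scope; "mlp" is the per-layer weight namespace): /// ```ignore /// let mlp = MLP::load(vb.pp("mlp"), &cfg)?; /// let y = mlp.forward(&x)?; // down_proj(silu(gate_proj(x)) * up_proj(x)) /// ```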
21 | pub fn load(vb: VarBuilder, cfg: &super::Config) -> Result<Self> { 22 | let h_size = cfg.hidden_size; 23 | let i_size = cfg.intermediate_size; 24 | let gate_proj = linear(h_size, i_size, vb.pp("gate_proj"))?; 25 | let up_proj = linear(h_size, i_size, vb.pp("up_proj"))?; 26 | let down_proj = linear(i_size, h_size, vb.pp("down_proj"))?; 27 | Ok(Self { 28 | gate_proj, 29 | up_proj, 30 | down_proj, 31 | }) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /cake-core/src/models/llama3/mod.rs: -------------------------------------------------------------------------------- 1 | //! This module contains model and inference specific code. 2 | mod attention; 3 | mod cache; 4 | mod config; 5 | mod history; 6 | mod llama; 7 | mod mlp; 8 | mod transformer; 9 | 10 | pub use attention::*; 11 | pub use cache::*; 12 | pub use config::*; 13 | pub use history::*; 14 | pub use llama::*; 15 | pub use mlp::*; 16 | pub use transformer::*; 17 | -------------------------------------------------------------------------------- /cake-core/src/models/llama3/transformer.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use candle_core::Tensor; 3 | use candle_nn::{Module, RmsNorm}; 4 | 5 | use crate::cake::{Context, Forwarder}; 6 | use async_trait::async_trait; 7 | 8 | use super::{CausalSelfAttention, MLP}; 9 | 10 | /// Transformer block with causal self attention and several caching strategies. 11 | #[derive(Debug, Clone)] 12 | pub struct Transformer { 13 | name: String, 14 | rms_1: RmsNorm, 15 | attn: CausalSelfAttention, 16 | rms_2: RmsNorm, 17 | mlp: MLP, 18 | } 19 | 20 | impl std::fmt::Display for Transformer { 21 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 22 | write!(f, "{} (local)", &self.name) 23 | } 24 | } 25 | 26 | #[async_trait] 27 | impl Forwarder for Transformer { 28 | fn load(name: String, ctx: &Context) -> Result<Box<Self>> { 29 | let vb = ctx 30 | .var_builder 31 | .as_ref() 32 | .expect("No var_builder specified") 33 | .pp(&name); 34 | let cfg = ctx.config.as_ref().expect("No config specified"); 35 | 36 | let attn = super::CausalSelfAttention::load(vb.pp("self_attn"), cfg)?; 37 | let mlp = super::MLP::load(vb.pp("mlp"), cfg)?; 38 | let rms_1 = 39 | candle_nn::rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?; 40 | let rms_2 = candle_nn::rms_norm( 41 | cfg.hidden_size, 42 | cfg.rms_norm_eps, 43 | vb.pp("post_attention_layernorm"), 44 | )?; 45 | Ok(Box::new(Self { 46 | name, 47 | rms_1, 48 | attn, 49 | rms_2, 50 | mlp, 51 | })) 52 | } 53 | 54 | async fn forward( 55 | &self, 56 | x: &Tensor, 57 | index_pos: usize, 58 | block_idx: usize, 59 | ctx: &mut Context, 60 | ) -> Result<Tensor> { 61 | let residual = x; 62 | 63 | let x = self.rms_1.forward(x).map_err(|e| anyhow!("rms_1: {e}"))?; 64 | let x = (self 65 | .attn 66 | .forward( 67 | &x, 68 | index_pos, 69 | block_idx, 70 | ctx.cache.as_mut().expect("No cache specified"), 71 | ) 72 | .map_err(|e| anyhow!("attention: {e}"))? 73 | + residual) 74 | .map_err(|e| anyhow!("residual: {e}"))?; 75 | let residual = &x; 76 | let x = self.rms_2.forward(&x).map_err(|e| anyhow!("rms_2: {e}"))?; 77 | let x = (self.mlp.forward(&x).map_err(|e| anyhow!("mlp: {e}"))?
+ residual) 78 | .map_err(|e| anyhow!("mlp residual: {e}"))?; 79 | 80 | Ok(x) 81 | } 82 | 83 | async fn forward_mut( 84 | &mut self, 85 | x: &Tensor, 86 | index_pos: usize, 87 | block_idx: usize, 88 | ctx: &mut Context, 89 | ) -> Result<Tensor> { 90 | self.forward(x, index_pos, block_idx, ctx).await 91 | } 92 | 93 | fn layer_name(&self) -> &str { 94 | &self.name 95 | } 96 | } 97 | -------------------------------------------------------------------------------- /cake-core/src/models/mod.rs: -------------------------------------------------------------------------------- 1 | use anyhow::Result; 2 | use async_trait::async_trait; 3 | use image::{ImageBuffer, Rgb}; 4 | 5 | use chat::Message; 6 | 7 | use crate::cake::{Context, Forwarder}; 8 | use crate::ImageGenerationArgs; 9 | 10 | pub mod chat; 11 | pub mod llama3; 12 | pub mod sd; 13 | 14 | /// A token. 15 | pub struct Token { 16 | /// Numerical identifier. 17 | pub id: u32, 18 | /// Resolved text token or None if not present in the tokenizer. 19 | pub text: Option<String>, 20 | /// Set to true if the stream of tokens is over. 21 | pub is_end_of_stream: bool, 22 | } 23 | 24 | impl std::fmt::Display for Token { 25 | fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 26 | write!( 27 | f, 28 | "{}", 29 | if let Some(text) = &self.text { 30 | text.clone() 31 | } else { 32 | format!("<token {}>", self.id) 33 | } 34 | ) 35 | } 36 | } 37 | 38 | /// A model must implement this trait in order to be usable by the Cake framework. 39 | #[async_trait] 40 | pub trait Generator { 41 | /// This associated type determines which part of the model can be sharded. 42 | type Shardable: Forwarder; 43 | 44 | /// The model name. 45 | const MODEL_NAME: &'static str; 46 | 47 | /// Load the model from the context. 48 | async fn load(context: &mut Context) -> Result<Option<Box<Self>>>; 49 | } 50 | 51 | #[async_trait] 52 | pub trait TextGenerator: Generator { 53 | /// Add a message to the chat. 54 | fn add_message(&mut self, message: Message) -> Result<()>; 55 | /// Clear chat history. 56 | fn reset(&mut self) -> Result<()>; 57 | /// Clear the worker kv cache. 58 | async fn goodbye(&mut self) -> Result<()>; 59 | 60 | /// Return the next token. 61 | async fn next_token(&mut self, index: usize) -> Result<Token>; 62 | /// Return the number of generated tokens so far.
63 | fn generated_tokens(&self) -> usize; 64 | } 65 | 66 | #[async_trait] 67 | pub trait ImageGenerator: Generator { 68 | async fn generate_image<F>( 69 | &mut self, 70 | args: &ImageGenerationArgs, 71 | mut callback: F, 72 | ) -> Result<(), anyhow::Error> 73 | where 74 | F: FnMut(Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>) + Send + 'static; 75 | } 76 | -------------------------------------------------------------------------------- /cake-core/src/models/sd/clip.rs: -------------------------------------------------------------------------------- 1 | use crate::cake::{Context, Forwarder}; 2 | use crate::models::sd::sd::ModelFile; 3 | use crate::models::sd::util::get_sd_config; 4 | use crate::StableDiffusionVersion; 5 | use async_trait::async_trait; 6 | use candle_core::{DType, Device, Module, Tensor}; 7 | use candle_transformers::models::stable_diffusion; 8 | use candle_transformers::models::stable_diffusion::clip::ClipTextTransformer; 9 | use log::info; 10 | use std::fmt::{Debug, Display, Formatter}; 11 | 12 | #[derive(Debug)] 13 | pub struct Clip { 14 | clip_model: ClipTextTransformer, 15 | layer_name: &'static str, 16 | } 17 | 18 | impl Display for Clip { 19 | fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { 20 | write!(f, "{} (local)", &self.layer_name) 21 | } 22 | } 23 | 24 | #[async_trait] 25 | impl Forwarder for Clip { 26 | fn load(name: String, ctx: &Context) -> anyhow::Result<Box<Self>> 27 | where 28 | Self: Sized, 29 | { 30 | let model_file; 31 | let model_filename; 32 | let sd_config = get_sd_config(ctx)?; 33 | let clip_config; 34 | 35 | match name.as_str() { 36 | "clip" => { 37 | model_file = ModelFile::Clip; 38 | model_filename = ctx.args.sd_args.clip.clone(); 39 | clip_config = sd_config.clip; 40 | } 41 | "clip2" => { 42 | model_file = ModelFile::Clip2; 43 | model_filename = ctx.args.sd_args.clip2.clone(); 44 | clip_config = sd_config.clip2.unwrap(); 45 | } 46 | _ => { 47 | anyhow::bail!("name not recognized"); 48 | } 49 | }; 50 | 51 | Self::load_model( 52 | model_file, 53 | model_filename, 54 | ctx.args.sd_args.sd_version, 55 | ctx.args.sd_args.use_f16, 56 | &ctx.device, 57 | ctx.dtype, 58 | ctx.args.model.clone(), 59 | &clip_config, 60 | ) 61 | } 62 | 63 | async fn forward( 64 | &self, 65 | x: &Tensor, 66 | _index_pos: usize, 67 | _block_idx: usize, 68 | _ctx: &mut Context, 69 | ) -> anyhow::Result<Tensor> { 70 | info!("Clip model forwarding"); 71 | Ok(self 72 | .clip_model 73 | .forward(x) 74 | .expect("Error running Clip forward")) 75 | } 76 | 77 | async fn forward_mut( 78 | &mut self, 79 | x: &Tensor, 80 | index_pos: usize, 81 | block_idx: usize, 82 | ctx: &mut Context, 83 | ) -> anyhow::Result<Tensor> { 84 | self.forward(x, index_pos, block_idx, ctx).await 85 | } 86 | 87 | fn layer_name(&self) -> &str { 88 | self.layer_name 89 | } 90 | } 91 | 92 | impl Clip { 93 | pub fn load_model( 94 | model_file: ModelFile, 95 | name: Option<String>, 96 | version: StableDiffusionVersion, 97 | use_f16: bool, 98 | device: &Device, 99 | dtype: DType, 100 | cache_dir: String, 101 | config: &stable_diffusion::clip::Config, 102 | ) -> anyhow::Result<Box<Self>> 103 | where 104 | Self: Sized, 105 | { 106 | let clip_weights = model_file.get(name, version, use_f16, cache_dir)?; 107 | let clip_model = 108 | stable_diffusion::build_clip_transformer(config, clip_weights, device, dtype)?; 109 | let layer_name = model_file.name(); 110 | 111 | info!("Loading Clip model: {layer_name}"); 112 | 113 | Ok(Box::new(Self { 114 | clip_model, 115 | layer_name, 116 | })) 117 | } 118 | } 119 | --------------------------------------------------------------------------------
/cake-core/src/models/sd/mod.rs: -------------------------------------------------------------------------------- 1 | mod clip; 2 | mod safe_scheduler; 3 | mod sd; 4 | mod sd_shardable; 5 | mod unet; 6 | mod util; 7 | mod vae; 8 | 9 | pub use sd::*; 10 | pub use util::*; 11 | -------------------------------------------------------------------------------- /cake-core/src/models/sd/safe_scheduler.rs: -------------------------------------------------------------------------------- 1 | pub struct SafeScheduler<T> { 2 | pub(crate) scheduler: T, 3 | } 4 | 5 | unsafe impl<T> Send for SafeScheduler<T> {} 6 | -------------------------------------------------------------------------------- /cake-core/src/models/sd/sd.rs: -------------------------------------------------------------------------------- 1 | use crate::cake::{Context, Forwarder}; 2 | use crate::models::sd::clip::Clip; 3 | use crate::models::sd::safe_scheduler::SafeScheduler; 4 | use crate::models::sd::sd_shardable::SDShardable; 5 | use crate::models::sd::unet::UNet; 6 | use crate::models::sd::vae::VAE; 7 | use crate::models::{Generator, ImageGenerator}; 8 | use crate::{ImageGenerationArgs, SDArgs, StableDiffusionVersion}; 9 | use anyhow::{Error as E, Result}; 10 | use async_trait::async_trait; 11 | use candle_core::{DType, Device, IndexOp, Tensor, D}; 12 | use candle_transformers::models::stable_diffusion::StableDiffusionConfig; 13 | use hf_hub::api::sync::ApiBuilder; 14 | use hf_hub::Cache; 15 | use image::{ImageBuffer, Rgb}; 16 | use log::{debug, info}; 17 | use tokenizers::Tokenizer; 18 | 19 | #[derive(Debug, Clone, Copy, PartialEq, Eq)] 20 | pub enum ModelFile { 21 | Tokenizer, 22 | Tokenizer2, 23 | Clip, 24 | Clip2, 25 | Unet, 26 | Vae, 27 | } 28 | 29 | impl ModelFile { 30 | pub fn get( 31 | &self, 32 | filename: Option<String>, 33 | version: StableDiffusionVersion, 34 | use_f16: bool, 35 | cache_dir: String, 36 | ) -> Result<std::path::PathBuf> { 37 | match filename { 38 | Some(filename) => Ok(std::path::PathBuf::from(filename)), 39 | None => { 40 | let (repo, path) = match self { 41 | Self::Tokenizer => { 42 | let tokenizer_repo = match version { 43 | StableDiffusionVersion::V1_5 | StableDiffusionVersion::V2_1 => { 44 | "openai/clip-vit-base-patch32" 45 | } 46 | StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo => { 47 | // This seems similar to the patch32 version except some very small 48 | // difference in the split regex. 49 | "openai/clip-vit-large-patch14" 50 | } 51 | }; 52 | (tokenizer_repo, "tokenizer.json") 53 | } 54 | Self::Tokenizer2 => { 55 | ("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", "tokenizer.json") 56 | } 57 | Self::Clip => (version.repo(), version.clip_file(use_f16)), 58 | Self::Clip2 => (version.repo(), version.clip2_file(use_f16)), 59 | Self::Unet => (version.repo(), version.unet_file(use_f16)), 60 | Self::Vae => { 61 | // Override for SDXL when using f16 weights.
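// (the stock SDXL fp16 VAE is prone to producing NaNs, hence the fixed madebyollin/sdxl-vae-fp16-fix weights below)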
62 | // See https://github.com/huggingface/candle/issues/1060 63 | if matches!( 64 | version, 65 | StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo, 66 | ) && use_f16 67 | { 68 | ( 69 | "madebyollin/sdxl-vae-fp16-fix", 70 | "diffusion_pytorch_model.safetensors", 71 | ) 72 | } else { 73 | (version.repo(), version.vae_file(use_f16)) 74 | } 75 | } 76 | }; 77 | let mut cache_path = std::path::PathBuf::from(cache_dir.as_str()); 78 | cache_path.push("hub"); 79 | 80 | debug!("Model cache dir: {:?}", cache_path); 81 | 82 | let cache = Cache::new(cache_path); 83 | let api = ApiBuilder::from_cache(cache).build()?; 84 | 85 | let filename = api.model(repo.to_string()).get(path)?; 86 | Ok(filename) 87 | } 88 | } 89 | } 90 | 91 | pub(crate) fn name(&self) -> &'static str { 92 | match *self { 93 | ModelFile::Tokenizer => "tokenizer", 94 | ModelFile::Tokenizer2 => "tokenizer_2", 95 | ModelFile::Clip => "clip", 96 | ModelFile::Clip2 => "clip2", 97 | ModelFile::Unet => "unet", 98 | ModelFile::Vae => "vae", 99 | } 100 | } 101 | } 102 | 103 | pub struct SD { 104 | tokenizer: Tokenizer, 105 | pad_id: u32, 106 | tokenizer_2: Option<Tokenizer>, 107 | pad_id_2: Option<u32>, 108 | text_model: Box<dyn Forwarder>, 109 | text_model_2: Option<Box<dyn Forwarder>>, 110 | vae: Box<dyn Forwarder>, 111 | unet: Box<dyn Forwarder>, 112 | sd_version: StableDiffusionVersion, 113 | sd_config: StableDiffusionConfig, 114 | context: Context, 115 | } 116 | 117 | #[async_trait] 118 | impl Generator for SD { 119 | type Shardable = SDShardable; 120 | const MODEL_NAME: &'static str = "stable-diffusion"; 121 | 122 | async fn load(context: &mut Context) -> Result<Option<Box<Self>>> { 123 | let SDArgs { 124 | tokenizer, 125 | tokenizer_2, 126 | sd_version, 127 | use_f16, 128 | width, 129 | height, 130 | sliced_attention_size, 131 | clip, 132 | clip2, 133 | vae, 134 | unet, 135 | use_flash_attention, 136 | ..
137 | } = &context.args.sd_args; 138 | 139 | let sd_config = match *sd_version { 140 | StableDiffusionVersion::V1_5 => { 141 | StableDiffusionConfig::v1_5(*sliced_attention_size, *height, *width) 142 | } 143 | StableDiffusionVersion::V2_1 => { 144 | StableDiffusionConfig::v2_1(*sliced_attention_size, *height, *width) 145 | } 146 | StableDiffusionVersion::Xl => { 147 | StableDiffusionConfig::sdxl(*sliced_attention_size, *height, *width) 148 | } 149 | StableDiffusionVersion::Turbo => { 150 | StableDiffusionConfig::sdxl_turbo(*sliced_attention_size, *height, *width) 151 | } 152 | }; 153 | 154 | // Tokenizer 155 | info!("Loading the Tokenizer..."); 156 | 157 | let tokenizer_file = ModelFile::Tokenizer; 158 | let tokenizer = tokenizer_file.get( 159 | tokenizer.clone(), 160 | *sd_version, 161 | *use_f16, 162 | context.args.model.clone(), 163 | )?; 164 | let tokenizer = Tokenizer::from_file(tokenizer).map_err(E::msg)?; 165 | 166 | let pad_id = match &sd_config.clip.pad_with { 167 | Some(padding) => *tokenizer.get_vocab(true).get(padding.as_str()).unwrap(), 168 | None => *tokenizer.get_vocab(true).get("<|endoftext|>").unwrap(), 169 | }; 170 | 171 | info!("Tokenizer loaded!"); 172 | 173 | // Tokenizer 2 174 | 175 | let mut tokenizer_2_option: Option<Tokenizer> = None; 176 | let mut pad_id_2: Option<u32> = None; 177 | 178 | if let StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo = sd_version { 179 | info!("Loading the Tokenizer 2..."); 180 | 181 | let tokenizer_2_file = ModelFile::Tokenizer2; 182 | let tokenizer_2 = tokenizer_2_file.get( 183 | tokenizer_2.clone(), 184 | *sd_version, 185 | *use_f16, 186 | context.args.model.clone(), 187 | )?; 188 | let tokenizer_2 = Tokenizer::from_file(tokenizer_2).map_err(E::msg)?; 189 | 190 | if let Some(clip2) = &sd_config.clip2 { 191 | pad_id_2 = match &clip2.pad_with { 192 | Some(padding) => { 193 | Some(*tokenizer_2.get_vocab(true).get(padding.as_str()).unwrap()) 194 | } 195 | None => Some(*tokenizer_2.get_vocab(true).get("<|endoftext|>").unwrap()), 196 | }; 197 | } 198 | 199 | tokenizer_2_option = Some(tokenizer_2); 200 | 201 | info!("Tokenizer 2 loaded!"); 202 | } 203 | 204 | // Clip 205 | info!("Loading the Clip text model."); 206 | 207 | let text_model: Box<dyn Forwarder>; 208 | 209 | if let Some((node_name, node)) = context.topology.get_node_for_layer(ModelFile::Clip.name()) 210 | { 211 | info!("node {node_name} will serve Clip"); 212 | text_model = Box::new( 213 | crate::cake::Client::new( 214 | context.device.clone(), 215 | &node.host, 216 | ModelFile::Clip.name(), 217 | ) 218 | .await?, 219 | ); 220 | } else { 221 | info!("Clip will be served locally"); 222 | text_model = Clip::load_model( 223 | ModelFile::Clip, 224 | clip.clone(), 225 | *sd_version, 226 | *use_f16, 227 | &context.device, 228 | context.dtype, 229 | context.args.model.clone(), 230 | &sd_config.clip, 231 | )?; 232 | } 233 | 234 | info!("Clip text model loaded!"); 235 | 236 | // Clip 2 237 | 238 | let mut text_model_2: Option<Box<dyn Forwarder>> = None; 239 | if let StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo = sd_version { 240 | info!("Loading the Clip 2 text model."); 241 | 242 | if let Some((node_name, node)) = 243 | context.topology.get_node_for_layer(ModelFile::Clip2.name()) 244 | { 245 | info!("node {node_name} will serve clip2"); 246 | text_model_2 = Some(Box::new( 247 | crate::cake::Client::new( 248 | context.device.clone(), 249 | &node.host, 250 | ModelFile::Clip2.name(), 251 | ) 252 | .await?, 253 | )); 254 | } else { 255 | info!("Clip 2 will be served locally"); 256 | text_model_2 =
Some(Clip::load_model( 257 | ModelFile::Clip2, 258 | clip2.clone(), 259 | *sd_version, 260 | *use_f16, 261 | &context.device, 262 | context.dtype, 263 | context.args.model.clone(), 264 | sd_config.clip2.as_ref().unwrap(), 265 | )?); 266 | } 267 | 268 | info!("Clip 2 text model loaded!"); 269 | } 270 | 271 | // VAE 272 | info!("Loading the VAE..."); 273 | 274 | let vae_model: Box<dyn Forwarder>; 275 | 276 | if let Some((node_name, node)) = context.topology.get_node_for_layer(ModelFile::Vae.name()) 277 | { 278 | info!("node {node_name} will serve VAE"); 279 | vae_model = Box::new( 280 | crate::cake::Client::new(context.device.clone(), &node.host, ModelFile::Vae.name()) 281 | .await?, 282 | ); 283 | } else { 284 | info!("VAE will be served locally"); 285 | vae_model = VAE::load_model( 286 | vae.clone(), 287 | *sd_version, 288 | *use_f16, 289 | &context.device, 290 | context.dtype, 291 | context.args.model.clone(), 292 | &sd_config, 293 | )?; 294 | } 295 | 296 | info!("VAE loaded!"); 297 | 298 | // Unet 299 | info!("Loading the UNet."); 300 | 301 | let unet_model: Box<dyn Forwarder>; 302 | if let Some((node_name, node)) = context.topology.get_node_for_layer(ModelFile::Unet.name()) 303 | { 304 | info!("node {node_name} will serve UNet"); 305 | unet_model = Box::new( 306 | crate::cake::Client::new( 307 | context.device.clone(), 308 | &node.host, 309 | ModelFile::Unet.name(), 310 | ) 311 | .await?, 312 | ); 313 | } else { 314 | info!("UNet will be served locally"); 315 | unet_model = UNet::load_model( 316 | unet.clone(), 317 | *use_flash_attention, 318 | *sd_version, 319 | *use_f16, 320 | &context.device, 321 | context.dtype, 322 | context.args.model.clone(), 323 | &sd_config, 324 | )?; 325 | } 326 | 327 | info!("UNet loaded!"); 328 | 329 | Ok(Some(Box::new(Self { 330 | tokenizer, 331 | sd_version: *sd_version, 332 | sd_config, 333 | pad_id, 334 | text_model, 335 | tokenizer_2: tokenizer_2_option, 336 | pad_id_2, 337 | text_model_2, 338 | vae: vae_model, 339 | unet: unet_model, 340 | context: context.clone(), 341 | }))) 342 | } 343 | } 344 | 345 | #[async_trait] 346 | impl ImageGenerator for SD { 347 | async fn generate_image<F>( 348 | &mut self, 349 | args: &ImageGenerationArgs, 350 | mut callback: F, 351 | ) -> Result<(), anyhow::Error> 352 | where 353 | F: FnMut(Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>) + Send + 'static, 354 | { 355 | use tracing_chrome::ChromeLayerBuilder; 356 | use tracing_subscriber::prelude::*; 357 | 358 | let ImageGenerationArgs { 359 | image_prompt, 360 | uncond_prompt, 361 | n_steps, 362 | num_samples, 363 | bsize, 364 | tracing, 365 | guidance_scale, 366 | img2img, 367 | img2img_strength, 368 | image_seed, 369 | intermediary_images, 370 | .. 371 | } = args; 372 | 373 | let sd_version = self.sd_version; 374 | 375 | if !(0.
..=1.).contains(img2img_strength) { 376 | anyhow::bail!("img2img-strength should be between 0 and 1, got {img2img_strength}") 377 | } 378 | 379 | let _guard = if *tracing { 380 | let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); 381 | tracing_subscriber::registry().with(chrome_layer).init(); 382 | Some(guard) 383 | } else { 384 | None 385 | }; 386 | 387 | let guidance_scale = match guidance_scale { 388 | Some(guidance_scale) => guidance_scale, 389 | None => &match sd_version { 390 | StableDiffusionVersion::V1_5 391 | | StableDiffusionVersion::V2_1 392 | | StableDiffusionVersion::Xl => 7.5, 393 | StableDiffusionVersion::Turbo => 0., 394 | }, 395 | }; 396 | let n_steps = match n_steps { 397 | Some(n_steps) => n_steps, 398 | None => &match sd_version { 399 | StableDiffusionVersion::V1_5 400 | | StableDiffusionVersion::V2_1 401 | | StableDiffusionVersion::Xl => 30, 402 | StableDiffusionVersion::Turbo => 1, 403 | }, 404 | }; 405 | 406 | if let Some(seed) = image_seed { 407 | self.context.device.set_seed(*seed)?; 408 | } 409 | let use_guide_scale = guidance_scale > &1.0; 410 | 411 | let mut text_embeddings: Vec<Tensor> = Vec::new(); 412 | 413 | let text_embeddings_1 = self 414 | .text_embeddings(image_prompt, uncond_prompt, use_guide_scale, true) 415 | .await?; 416 | 417 | text_embeddings.push(text_embeddings_1); 418 | 419 | if let StableDiffusionVersion::Xl | StableDiffusionVersion::Turbo = sd_version { 420 | let text_embeddings_2 = self 421 | .text_embeddings(image_prompt, uncond_prompt, use_guide_scale, false) 422 | .await?; 423 | 424 | text_embeddings.push(text_embeddings_2); 425 | } 426 | 427 | let text_embeddings = Tensor::cat(&text_embeddings, D::Minus1)?; 428 | let text_embeddings = text_embeddings.repeat((*bsize, 1, 1))?; 429 | debug!("{text_embeddings:?}"); 430 | 431 | let init_latent_dist_sample = match &img2img { 432 | None => None, 433 | Some(image) => { 434 | let image = image_preprocess(image)?.to_device(&self.context.device)?; 435 | Some(VAE::encode(&mut self.vae, image, &mut self.context).await?) 436 | } 437 | }; 438 | 439 | let t_start = if img2img.is_some() { 440 | *n_steps - (*n_steps as f64 * img2img_strength) as usize 441 | } else { 442 | 0 443 | }; 444 | 445 | let vae_scale = match sd_version { 446 | StableDiffusionVersion::V1_5 447 | | StableDiffusionVersion::V2_1 448 | | StableDiffusionVersion::Xl => 0.18215, 449 | StableDiffusionVersion::Turbo => 0.13025, 450 | }; 451 | 452 | let safe_scheduler = SafeScheduler { 453 | scheduler: self.sd_config.build_scheduler(*n_steps)?, 454 | }; 455 | 456 | for idx in 0..(*num_samples) { 457 | let timesteps = safe_scheduler.scheduler.timesteps(); 458 | let latents = match &init_latent_dist_sample { 459 | Some(init_latent_dist) => { 460 | let latents = 461 | (init_latent_dist * vae_scale)?.to_device(&self.context.device)?; 462 | if t_start < timesteps.len() { 463 | let noise = latents.randn_like(0f64, 1f64)?; 464 | safe_scheduler 465 | .scheduler 466 | .add_noise(&latents, noise, timesteps[t_start])? 467 | } else { 468 | latents 469 | } 470 | } 471 | 472 | None => { 473 | let latents = Tensor::randn( 474 | 0f32, 475 | 1f32, 476 | ( 477 | *bsize, 478 | 4, 479 | self.sd_config.height / 8, 480 | self.sd_config.width / 8, 481 | ), 482 | &self.context.device, 483 | )?; 484 | // scale the initial noise by the standard deviation required by the scheduler 485 | (latents * safe_scheduler.scheduler.init_noise_sigma())?
486 | } 487 | }; 488 | 489 | let mut latents = latents.to_dtype(self.context.dtype)?; 490 | 491 | debug!("Starting sampling..."); 492 | 493 | for (timestep_index, &timestep) in timesteps.iter().enumerate() { 494 | if timestep_index < t_start { 495 | continue; 496 | } 497 | let start_time = std::time::Instant::now(); 498 | let latent_model_input = if use_guide_scale { 499 | Tensor::cat(&[&latents, &latents], 0)? 500 | } else { 501 | latents.clone() 502 | }; 503 | 504 | let latent_model_input = safe_scheduler 505 | .scheduler 506 | .scale_model_input(latent_model_input, timestep)?; 507 | 508 | debug!("UNet forwarding..."); 509 | 510 | let noise_pred = UNet::forward_unpacked( 511 | &mut self.unet, 512 | latent_model_input, 513 | text_embeddings.clone(), 514 | timestep, 515 | &mut self.context, 516 | ) 517 | .await?; 518 | 519 | debug!("UNet forwarding completed!"); 520 | 521 | let noise_pred = if use_guide_scale { 522 | debug!("Applying guidance scale..."); 523 | 524 | let noise_pred = noise_pred.chunk(2, 0)?; 525 | let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]); 526 | 527 | (noise_pred_uncond 528 | + ((noise_pred_text - noise_pred_uncond)? * *guidance_scale)?)? 529 | } else { 530 | noise_pred 531 | }; 532 | 533 | debug!("Scheduler stepping..."); 534 | 535 | latents = safe_scheduler 536 | .scheduler 537 | .step(&noise_pred, timestep, &latents)?; 538 | 539 | let dt = start_time.elapsed().as_secs_f32(); 540 | info!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt); 541 | 542 | if *intermediary_images != 0 && timestep_index % *intermediary_images == 0 { 543 | let intermediary_batched_images = 544 | self.split_images(&latents, vae_scale, *bsize).await?; 545 | callback(intermediary_batched_images); 546 | } 547 | } 548 | 549 | debug!( 550 | "Generating the final image for sample {}/{}.", 551 | idx + 1, 552 | num_samples 553 | ); 554 | 555 | let batched_images = self.split_images(&latents, vae_scale, *bsize).await?; 556 | 557 | callback(batched_images); 558 | } 559 | 560 | Ok(()) 561 | } 562 | } 563 | 564 | impl SD { 565 | async fn split_images( 566 | &mut self, 567 | latents: &Tensor, 568 | vae_scale: f64, 569 | bsize: usize, 570 | ) -> Result<Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>> { 571 | let mut images_vec = Vec::new(); 572 | 573 | let scaled = (latents / vae_scale)?; 574 | let images = VAE::decode(&mut self.vae, scaled, &mut self.context).await?; 575 | let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?; 576 | let images = (images.clamp(0f32, 1.)?
521 |                 let noise_pred = if use_guide_scale {
522 |                     debug!("Applying guidance scale...");
523 | 
524 |                     let noise_pred = noise_pred.chunk(2, 0)?;
525 |                     let (noise_pred_uncond, noise_pred_text) = (&noise_pred[0], &noise_pred[1]);
526 | 
527 |                     (noise_pred_uncond
528 |                         + ((noise_pred_text - noise_pred_uncond)? * *guidance_scale)?)?
529 |                 } else {
530 |                     noise_pred
531 |                 };
532 | 
533 |                 debug!("Scheduler stepping...");
534 | 
535 |                 latents = safe_scheduler
536 |                     .scheduler
537 |                     .step(&noise_pred, timestep, &latents)?;
538 | 
539 |                 let dt = start_time.elapsed().as_secs_f32();
540 |                 info!("step {}/{n_steps} done, {:.2}s", timestep_index + 1, dt);
541 | 
542 |                 if *intermediary_images != 0 && timestep_index % *intermediary_images == 0 {
543 |                     let intermediary_batched_images =
544 |                         self.split_images(&latents, vae_scale, *bsize).await?;
545 |                     callback(intermediary_batched_images);
546 |                 }
547 |             }
548 | 
549 |             debug!(
550 |                 "Generating the final image for sample {}/{}.",
551 |                 idx + 1,
552 |                 num_samples
553 |             );
554 | 
555 |             let batched_images = self.split_images(&latents, vae_scale, *bsize).await?;
556 | 
557 |             callback(batched_images);
558 |         }
559 | 
560 |         Ok(())
561 |     }
562 | }
563 | 
564 | impl SD {
565 |     async fn split_images(
566 |         &mut self,
567 |         latents: &Tensor,
568 |         vae_scale: f64,
569 |         bsize: usize,
570 |     ) -> Result<Vec<ImageBuffer<Rgb<u8>, Vec<u8>>>> {
571 |         let mut images_vec = Vec::new();
572 | 
573 |         let scaled = (latents / vae_scale)?;
574 |         let images = VAE::decode(&mut self.vae, scaled, &mut self.context).await?;
575 |         let images = ((images / 2.)? + 0.5)?.to_device(&Device::Cpu)?;
576 |         let images = (images.clamp(0f32, 1.)? * 255.)?.to_dtype(DType::U8)?;
577 |         for batch in 0..bsize {
578 |             let image_tensor = images.i(batch)?;
579 |             let (channel, height, width) = image_tensor.dims3()?;
580 |             if channel != 3 {
581 |                 anyhow::bail!("save_image expects an input of shape (3, height, width)")
582 |             }
583 |             let image_tensor = image_tensor.permute((1, 2, 0))?.flatten_all()?;
584 |             let pixels = image_tensor.to_vec1::<u8>()?;
585 | 
586 |             let image: ImageBuffer<Rgb<u8>, Vec<u8>> =
587 |                 match ImageBuffer::from_raw(width as u32, height as u32, pixels) {
588 |                     Some(image) => image,
589 |                     None => anyhow::bail!("Error splitting images"),
590 |                 };
591 |             images_vec.push(image)
592 |         }
593 |         Ok(images_vec)
594 |     }
595 | 
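    // `first == true` selects the primary CLIP tokenizer/encoder; SDXL and Turbo also
    // ship a second text encoder (`tokenizer_2` / `text_model_2`), and the caller
    // concatenates both embeddings along the last dimension.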
596 |     async fn text_embeddings(
597 |         &mut self,
598 |         prompt: &str,
599 |         uncond_prompt: &str,
600 |         use_guide_scale: bool,
601 |         first: bool,
602 |     ) -> Result<Tensor> {
603 |         let tokenizer;
604 |         let text_model;
605 |         let pad_id;
606 |         let max_token_embeddings;
607 | 
608 |         if first {
609 |             tokenizer = &self.tokenizer;
610 |             text_model = &mut self.text_model;
611 |             pad_id = self.pad_id;
612 |             max_token_embeddings = self.sd_config.clip.max_position_embeddings;
613 |         } else {
614 |             tokenizer = self.tokenizer_2.as_ref().unwrap();
615 |             text_model = self.text_model_2.as_mut().unwrap();
616 |             pad_id = self.pad_id_2.unwrap();
617 |             max_token_embeddings = self
618 |                 .sd_config
619 |                 .clip2
620 |                 .as_ref()
621 |                 .unwrap()
622 |                 .max_position_embeddings;
623 |         }
624 | 
625 |         info!("Running with prompt \"{prompt}\".");
626 | 
627 |         let mut tokens = tokenizer
628 |             .encode(prompt, true)
629 |             .map_err(E::msg)?
630 |             .get_ids()
631 |             .to_vec();
632 | 
633 |         if tokens.len() > max_token_embeddings {
634 |             anyhow::bail!(
635 |                 "the prompt is too long, {} > max-tokens ({})",
636 |                 tokens.len(),
637 |                 max_token_embeddings
638 |             )
639 |         }
640 | 
641 |         while tokens.len() < max_token_embeddings {
642 |             tokens.push(pad_id)
643 |         }
644 | 
645 |         let tokens = Tensor::new(tokens.as_slice(), &self.context.device)?.unsqueeze(0)?;
646 | 
647 |         let text_embeddings = text_model
648 |             .forward_mut(&tokens, 0, 0, &mut self.context)
649 |             .await?;
650 | 
651 |         let text_embeddings = if use_guide_scale {
652 |             let mut uncond_tokens = tokenizer
653 |                 .encode(uncond_prompt, true)
654 |                 .map_err(E::msg)?
655 |                 .get_ids()
656 |                 .to_vec();
657 |             if uncond_tokens.len() > max_token_embeddings {
658 |                 anyhow::bail!(
659 |                     "the negative prompt is too long, {} > max-tokens ({})",
660 |                     uncond_tokens.len(),
661 |                     max_token_embeddings
662 |                 )
663 |             }
664 |             while uncond_tokens.len() < max_token_embeddings {
665 |                 uncond_tokens.push(pad_id)
666 |             }
667 | 
668 |             let uncond_tokens =
669 |                 Tensor::new(uncond_tokens.as_slice(), &self.context.device)?.unsqueeze(0)?;
670 | 
671 |             info!("Clip forwarding...");
672 |             let uncond_embeddings = text_model
673 |                 .forward_mut(&uncond_tokens, 0, 0, &mut self.context)
674 |                 .await?;
675 | 
676 |             Tensor::cat(&[uncond_embeddings, text_embeddings], 0)?.to_dtype(self.context.dtype)?
677 |         } else {
678 |             text_embeddings.to_dtype(self.context.dtype)?
679 |         };
680 | 
681 |         Ok(text_embeddings)
682 |     }
683 | }
684 | 
685 | fn image_preprocess<T: AsRef<std::path::Path>>(path: T) -> Result<Tensor> {
686 |     let img = image::ImageReader::open(path)?.decode()?;
687 |     let (height, width) = (img.height() as usize, img.width() as usize);
688 |     let height = height - height % 32;
689 |     let width = width - width % 32;
690 |     let img = img.resize_to_fill(
691 |         width as u32,
692 |         height as u32,
693 |         image::imageops::FilterType::CatmullRom,
694 |     );
695 |     let img = img.to_rgb8();
696 |     let img = img.into_raw();
697 |     let img = Tensor::from_vec(img, (height, width, 3), &Device::Cpu)?
698 |         .permute((2, 0, 1))?
699 |         .to_dtype(DType::F32)?
700 |         .affine(2. / 255., -1.)?
701 |         .unsqueeze(0)?;
702 |     Ok(img)
703 | }
704 | 
--------------------------------------------------------------------------------
/cake-core/src/models/sd/sd_shardable.rs:
--------------------------------------------------------------------------------
1 | use crate::cake::{Context, Forwarder};
2 | use crate::models::sd::clip::Clip;
3 | use crate::models::sd::unet::UNet;
4 | use crate::models::sd::vae::VAE;
5 | use async_trait::async_trait;
6 | use candle_core::Tensor;
7 | use std::fmt::{Debug, Display, Formatter};
8 | 
9 | #[derive(Debug)]
10 | pub struct SDShardable {
11 |     forwarder: Box<dyn Forwarder>,
12 |     layer_name: String,
13 | }
14 | 
15 | impl Display for SDShardable {
16 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
17 |         write!(f, "{} (local)", &self.layer_name)
18 |     }
19 | }
20 | 
21 | #[async_trait]
22 | impl Forwarder for SDShardable {
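    // `load` dispatches on the layer name found in the topology file, so a single
    // worker can host any of the SD components (VAE, CLIP text encoders, UNet) behind
    // the same `Forwarder` interface; the remaining methods delegate to the wrapped model.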
23 |     fn load(name: String, ctx: &Context) -> anyhow::Result<Box<Self>>
24 |     where
25 |         Self: Sized,
26 |     {
27 |         let model: Box<dyn Forwarder>;
28 | 
29 |         match name.as_str() {
30 |             "vae" => {
31 |                 model = VAE::load(name.clone(), ctx)?;
32 |             }
33 |             "clip" => {
34 |                 model = Clip::load(name.clone(), ctx)?;
35 |             }
36 |             "clip2" => {
37 |                 model = Clip::load(name.clone(), ctx)?;
38 |             }
39 |             "unet" => {
40 |                 model = UNet::load(name.clone(), ctx)?;
41 |             }
42 |             _ => {
43 |                 anyhow::bail!("Model name not recognized");
44 |             }
45 |         }
46 | 
47 |         Ok(Box::new(Self {
48 |             forwarder: model,
49 |             layer_name: name,
50 |         }))
51 |     }
52 | 
53 |     async fn forward(
54 |         &self,
55 |         x: &Tensor,
56 |         index_pos: usize,
57 |         block_idx: usize,
58 |         ctx: &mut Context,
59 |     ) -> anyhow::Result<Tensor> {
60 |         self.forwarder.forward(x, index_pos, block_idx, ctx).await
61 |     }
62 | 
63 |     async fn forward_mut(
64 |         &mut self,
65 |         x: &Tensor,
66 |         index_pos: usize,
67 |         block_idx: usize,
68 |         ctx: &mut Context,
69 |     ) -> anyhow::Result<Tensor> {
70 |         self.forwarder
71 |             .forward_mut(x, index_pos, block_idx, ctx)
72 |             .await
73 |     }
74 | 
75 |     async fn forward_batch(
76 |         &mut self,
77 |         _x: &Tensor,
78 |         _batch: Vec<(String, usize, usize)>,
79 |         ctx: &mut Context,
80 |     ) -> anyhow::Result<Tensor> {
81 |         self.forwarder.forward_batch(_x, _batch, ctx).await
82 |     }
83 | 
84 |     fn layer_name(&self) -> &str {
85 |         &self.layer_name
86 |     }
87 | 
88 |     fn ident(&self) -> &str {
89 |         &self.layer_name
90 |     }
91 | }
92 | 
--------------------------------------------------------------------------------
/cake-core/src/models/sd/unet.rs:
--------------------------------------------------------------------------------
1 | use crate::cake::{Context, Forwarder};
2 | use crate::models::sd::util::{get_sd_config, pack_tensors, unpack_tensors};
3 | use crate::models::sd::ModelFile;
4 | use crate::StableDiffusionVersion;
5 | use async_trait::async_trait;
6 | use candle_core::{DType, Device, Tensor};
7 | use candle_transformers::models::stable_diffusion::unet_2d::UNet2DConditionModel;
8 | use candle_transformers::models::stable_diffusion::StableDiffusionConfig;
9 | use log::info;
10 | use std::fmt::{Debug, Display, Formatter};
11 | 
12 | #[derive(Debug)]
13 | pub struct UNet {
14 |     unet_model: UNet2DConditionModel,
15 | }
16 | 
17 | impl Display for UNet {
18 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
19 |         write!(f, "UNet (local)")
20 |     }
21 | }
22 | 
23 | #[async_trait]
24 | impl Forwarder for UNet {
25 |     fn load(_name: String, ctx: &Context) -> anyhow::Result<Box<Self>>
26 |     where
27 |         Self: Sized,
28 |     {
29 |         let sd_config = get_sd_config(ctx)?;
30 | 
31 |         Self::load_model(
32 |             ctx.args.sd_args.unet.clone(),
33 |             ctx.args.sd_args.use_flash_attention,
34 |             ctx.args.sd_args.sd_version,
35 |             ctx.args.sd_args.use_f16,
36 |             &ctx.device,
37 |             ctx.dtype,
38 |             ctx.args.model.clone(),
39 |             &sd_config,
40 |         )
41 |     }
42 | 
43 |     async fn forward(
44 |         &self,
45 |         x: &Tensor,
46 |         _index_pos: usize,
47 |         _block_idx: usize,
48 |         ctx: &mut Context,
49 |     ) -> anyhow::Result<Tensor> {
50 |         let unpacked_tensors = unpack_tensors(x)?;
51 |         let latent_model_input = &unpacked_tensors[0].to_dtype(ctx.dtype)?;
52 |         let text_embeddings = &unpacked_tensors[1].to_dtype(ctx.dtype)?;
53 | 
54 |         let timestep_tensor = &unpacked_tensors[2];
55 |         let timestep_vec = timestep_tensor.to_vec1()?;
56 |         let timestep_f32: &f32 = timestep_vec.first().expect("Error retrieving timestep");
57 | 
58 |         info!("UNet model forwarding...");
59 | 
60 |         Ok(self
61 |             .unet_model
62 |             .forward(latent_model_input, *timestep_f32 as f64, text_embeddings)
63 |             .expect("Error running UNet forward"))
64 |     }
65 | 
66 |     async fn forward_mut(
67 |         &mut self,
68 |         x: &Tensor,
69 |         index_pos: usize,
70 |         block_idx: usize,
71 |         ctx: &mut Context,
72 |     ) -> anyhow::Result<Tensor> {
73 |         self.forward(x, index_pos, block_idx, ctx).await
74 |     }
75 | 
76 |     fn layer_name(&self) -> &str {
77 |         "unet"
78 |     }
79 | }
80 | 
81 | impl UNet {
82 |     pub fn load_model(
83 |         name: Option<String>,
84 |         use_flash_attn: bool,
85 |         version: StableDiffusionVersion,
86 |         use_f16: bool,
87 |         device: &Device,
88 |         dtype: DType,
89 |         cache_dir: String,
90 |         config: &StableDiffusionConfig,
91 |     ) -> anyhow::Result<Box<Self>>
92 |     where
93 |         Self: Sized,
94 |     {
95 |         let unet_weights = ModelFile::Unet.get(name, version, use_f16, cache_dir)?;
96 |         let unet = config.build_unet(unet_weights, device, 4, use_flash_attn, dtype)?;
97 | 
98 |         info!("Loading UNet model...");
99 | 
100 |         Ok(Box::new(Self { unet_model: unet }))
101 |     }
102 | 
103 |     pub async fn forward_unpacked(
104 |         forwarder: &mut Box<dyn Forwarder>,
105 |         latent_model_input: Tensor,
106 |         text_embeddings: Tensor,
107 |         timestep: usize,
108 |         ctx: &mut Context,
109 |     ) -> anyhow::Result<Tensor> {
110 |         // Pack the tensors to be sent into one
111 |         let timestep_tensor = Tensor::from_slice(&[timestep as f32], 1, &ctx.device)?;
112 | 
113 |         let tensors = Vec::from([latent_model_input, text_embeddings, timestep_tensor]);
114 | 
115 |         let combined_tensor = pack_tensors(tensors, &ctx.device)?;
116 |         forwarder.forward_mut(&combined_tensor, 0, 0, ctx).await
117 |     }
118 | }
119 | 
--------------------------------------------------------------------------------
/cake-core/src/models/sd/util.rs:
--------------------------------------------------------------------------------
1 | use crate::cake::Context;
2 | use crate::StableDiffusionVersion;
3 | use anyhow::Result;
4 | use candle_core::utils::{cuda_is_available, metal_is_available};
5 | use candle_core::{Device, Tensor};
6 | use candle_transformers::models::stable_diffusion::StableDiffusionConfig;
7 | 
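/// Serializes several tensors into a single 1-D f32 tensor so they can cross the network
/// as one message. Layout (every value, including the header fields, encoded as f32):
///
///   [ num_tensors,
///     len(shape_1), shape_1..., data_1 (flattened),
///     len(shape_2), shape_2..., data_2 (flattened), ... ]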
8 | pub fn pack_tensors(tensors: Vec<Tensor>, device: &Device) -> Result<Tensor> {
9 |     let num_tensors = tensors.len();
10 |     let mut prepared_tensors = Vec::from([Tensor::from_slice(&[num_tensors as f32], 1, device)?]);
11 | 
12 |     for tensor in tensors {
13 |         let shape_info = tensor.shape().clone().into_dims();
14 | 
15 |         let shape_info_f32 = shape_info.clone().into_iter().map(|x| x as f32).collect();
16 | 
17 |         let shape_info_len = shape_info.len();
18 | 
19 |         let flattened_tensor = tensor.flatten_all()?.to_dtype(candle_core::DType::F32)?;
20 | 
21 |         prepared_tensors.push(Tensor::from_slice(&[shape_info_len as f32], 1, device)?);
22 |         prepared_tensors.push(Tensor::from_vec(shape_info_f32, shape_info_len, device)?);
23 |         prepared_tensors.push(flattened_tensor);
24 |     }
25 | 
26 |     Ok(Tensor::cat(&prepared_tensors, 0)?)
27 | }
28 | 
29 | pub fn unpack_tensors(tensor: &Tensor) -> Result<Vec<Tensor>> {
30 |     let mut unpacked_tensors: Vec<Tensor> = Vec::new();
31 | 
32 |     let num_tensors: f32 = tensor.get(0)?.to_scalar()?;
33 |     let num_tensors_i32 = num_tensors as i32;
34 | 
35 |     let mut idx: i32 = 1;
36 | 
37 |     for _i in 0..num_tensors_i32 {
38 |         let shape_info_len: f32 = tensor.get(idx as usize)?.to_scalar()?;
39 | 
40 |         idx += 1;
41 | 
42 |         let shape_info: Vec<i32> = tensor
43 |             .narrow(0, idx as usize, shape_info_len as usize)?
44 |             .to_vec1()?
45 |             .into_iter()
46 |             .map(|x: f32| x as i32)
47 |             .collect();
48 | 
49 |         idx += shape_info_len as i32;
50 | 
51 |         let num_elements: i32 = shape_info.iter().product();
52 | 
53 |         let shape_info_usize: Vec<_> = shape_info.iter().map(|&x| x as usize).collect();
54 | 
55 |         let extracted = tensor
56 |             .narrow(0, idx as usize, num_elements as usize)?
57 |             .reshape(shape_info_usize)?;
58 |         idx += num_elements;
59 | 
60 |         unpacked_tensors.push(extracted);
61 |     }
62 | 
63 |     Ok(unpacked_tensors)
64 | }
65 | 
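// Round-trip sketch (hypothetical usage, not from the original sources):
//
//   let packed = pack_tensors(vec![a.clone(), b.clone()], &device)?;
//   let unpacked = unpack_tensors(&packed)?;
//   assert_eq!(unpacked[0].dims(), a.dims());
//
// Everything travels as f32, so callers restore the working dtype with `to_dtype()`
// after unpacking (see `UNet::forward` and `VAE::forward`).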
66 | pub fn get_device(cpu: bool) -> Result<Device> {
67 |     if cpu {
68 |         Ok(Device::Cpu)
69 |     } else if cuda_is_available() {
70 |         Ok(Device::new_cuda(0)?)
71 |     } else if metal_is_available() {
72 |         Ok(Device::new_metal(0)?)
73 |     } else {
74 |         #[cfg(all(target_os = "macos", target_arch = "aarch64"))]
75 |         {
76 |             println!(
77 |                 "Running on CPU, to run on GPU(metal), build this example with `--features metal`"
78 |             );
79 |         }
80 |         #[cfg(not(all(target_os = "macos", target_arch = "aarch64")))]
81 |         {
82 |             println!("Running on CPU, to run on GPU, build this example with `--features cuda`");
83 |         }
84 |         Ok(Device::Cpu)
85 |     }
86 | }
87 | 
88 | pub fn get_sd_config(ctx: &Context) -> Result<StableDiffusionConfig> {
89 |     let height = ctx.args.sd_args.height;
90 |     let width = ctx.args.sd_args.width;
91 |     let sliced_attention_size = ctx.args.sd_args.sliced_attention_size;
92 |     let sd_config = match ctx.args.sd_args.sd_version {
93 |         StableDiffusionVersion::V1_5 => {
94 |             StableDiffusionConfig::v1_5(sliced_attention_size, height, width)
95 |         }
96 |         StableDiffusionVersion::V2_1 => {
97 |             StableDiffusionConfig::v2_1(sliced_attention_size, height, width)
98 |         }
99 |         StableDiffusionVersion::Xl => {
100 |             StableDiffusionConfig::sdxl(sliced_attention_size, height, width)
101 |         }
102 |         StableDiffusionVersion::Turbo => StableDiffusionConfig::sdxl_turbo(
103 |             ctx.args.sd_args.sliced_attention_size,
104 |             ctx.args.sd_args.height,
105 |             ctx.args.sd_args.width,
106 |         ),
107 |     };
108 |     Ok(sd_config)
109 | }
110 | 
--------------------------------------------------------------------------------
/cake-core/src/models/sd/vae.rs:
--------------------------------------------------------------------------------
1 | use crate::cake::{Context, Forwarder};
2 | use crate::models::sd::util::{get_sd_config, pack_tensors, unpack_tensors};
3 | use crate::models::sd::ModelFile;
4 | use crate::StableDiffusionVersion;
5 | use async_trait::async_trait;
6 | use candle_core::{DType, Device, Tensor};
7 | use candle_transformers::models::stable_diffusion::vae::AutoEncoderKL;
8 | use candle_transformers::models::stable_diffusion::StableDiffusionConfig;
9 | use log::{debug, info};
10 | use std::fmt::{Debug, Display, Formatter};
11 | 
12 | #[derive(Debug)]
13 | pub struct VAE {
14 |     vae_model: AutoEncoderKL,
15 | }
16 | 
17 | impl Display for VAE {
18 |     fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
19 |         write!(f, "VAE (local)")
20 |     }
21 | }
22 | 
23 | #[async_trait]
24 | impl Forwarder for VAE {
25 |     fn load(_name: String, ctx: &Context) -> anyhow::Result<Box<Self>>
26 |     where
27 |         Self: Sized,
28 |     {
29 |         let sd_config = get_sd_config(ctx)?;
30 | 
31 |         Self::load_model(
32 |             ctx.args.sd_args.vae.clone(),
33 |             ctx.args.sd_args.sd_version,
34 |             ctx.args.sd_args.use_f16,
35 |             &ctx.device,
36 |             ctx.dtype,
37 |             ctx.args.model.clone(),
38 |             &sd_config,
39 |         )
40 |     }
41 | 
42 |     async fn forward(
43 |         &self,
44 |         x: &Tensor,
45 |         _index_pos: usize,
46 |         _block_idx: usize,
47 |         ctx: &mut Context,
48 |     ) -> anyhow::Result<Tensor> {
49 |         info!("VAE model forwarding...");
50 | 
51 |         let unpacked_tensors = unpack_tensors(x)?;
52 | 
53 |         let direction_tensor = &unpacked_tensors[0];
54 |         let direction_vec = direction_tensor.to_vec1()?;
55 |         let direction_f32: f32 = *direction_vec
56 |             .first()
57 |             .expect("Error retrieving direction info");
58 | 
59 |         let input = &unpacked_tensors[1].to_dtype(ctx.dtype)?;
60 | 
61 |         debug!("VAE tensors decoded.");
62 | 
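        // The first packed tensor is a direction flag written by the `encode()` /
        // `decode()` helpers below: 1.0 selects VAE encoding (image -> latent
        // distribution sample), anything else selects decoding (latents -> image).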
63 |         if direction_f32 == 1.0 {
64 |             let dist = self.vae_model.encode(input)?;
65 |             Ok(dist.sample()?)
66 |         } else {
67 |             Ok(self.vae_model.decode(input)?)
68 |         }
69 |     }
70 | 
71 |     async fn forward_mut(
72 |         &mut self,
73 |         x: &Tensor,
74 |         index_pos: usize,
75 |         block_idx: usize,
76 |         ctx: &mut Context,
77 |     ) -> anyhow::Result<Tensor> {
78 |         self.forward(x, index_pos, block_idx, ctx).await
79 |     }
80 | 
81 |     fn layer_name(&self) -> &str {
82 |         "vae"
83 |     }
84 | }
85 | 
86 | impl VAE {
87 |     pub fn load_model(
88 |         name: Option<String>,
89 |         version: StableDiffusionVersion,
90 |         use_f16: bool,
91 |         device: &Device,
92 |         dtype: DType,
93 |         cache_dir: String,
94 |         config: &StableDiffusionConfig,
95 |     ) -> anyhow::Result<Box<Self>>
96 |     where
97 |         Self: Sized,
98 |     {
99 |         let vae_weights = ModelFile::Vae.get(name, version, use_f16, cache_dir)?;
100 |         let vae_model = config.build_vae(vae_weights, device, dtype)?;
101 | 
102 |         info!("Loading VAE model...");
103 | 
104 |         Ok(Box::new(Self { vae_model }))
105 |     }
106 | 
107 |     pub async fn encode(
108 |         forwarder: &mut Box<dyn Forwarder>,
109 |         image: Tensor,
110 |         ctx: &mut Context,
111 |     ) -> anyhow::Result<Tensor> {
112 |         let tensors = Vec::from([Tensor::from_slice(&[1f32], 1, &ctx.device)?, image]);
113 | 
114 |         let combined_tensor = pack_tensors(tensors, &ctx.device)?;
115 | 
116 |         forwarder.forward_mut(&combined_tensor, 0, 0, ctx).await
117 |     }
118 | 
119 |     pub async fn decode(
120 |         forwarder: &mut Box<dyn Forwarder>,
121 |         latents: Tensor,
122 |         ctx: &mut Context,
123 |     ) -> anyhow::Result<Tensor> {
124 |         let tensors = Vec::from([Tensor::from_slice(&[0f32], 1, &ctx.device)?, latents]);
125 | 
126 |         let combined_tensor = pack_tensors(tensors, &ctx.device)?;
127 | 
128 |         let result = forwarder.forward_mut(&combined_tensor, 0, 0, ctx).await?;
129 |         Ok(result)
130 |     }
131 | }
132 | 
--------------------------------------------------------------------------------
/cake-core/src/utils/mod.rs:
--------------------------------------------------------------------------------
1 | //! Utility functions and abstractions.
2 | 
3 | use std::path::{Path, PathBuf};
4 | 
5 | use candle_core::{
6 |     utils::{cuda_is_available, metal_is_available},
7 |     DType, Device, Tensor,
8 | };
9 | 
10 | use anyhow::{anyhow, bail, Result};
11 | 
12 | use candle_nn::VarBuilder;
13 | 
14 | /// Returns the best available device at `ordinal` index (in case of multiple GPUs), or CPU if `force_cpu` is true.
15 | pub fn get_inference_device(force_cpu: bool, ordinal: usize) -> Result<Device> {
16 |     if force_cpu {
17 |         log::debug!("device is forced cpu");
18 |         Ok(Device::Cpu)
19 |     } else if cuda_is_available() {
20 |         log::debug!("device is cuda {ordinal}");
21 |         Ok(Device::new_cuda(ordinal)?)
22 |     } else if metal_is_available() {
23 |         log::debug!("device is metal {ordinal}");
24 |         Ok(Device::new_metal(ordinal)?)
25 |     } else {
26 |         log::debug!("device is cpu");
27 |         // fallback to cpu if nothing else available
28 |         Ok(Device::Cpu)
29 |     }
30 | }
31 | 
32 | pub fn load_safetensors_from_model(path: &Path) -> Result<Vec<PathBuf>> {
33 |     log::info!("loading tensors from {} ...", "model.safetensors");
34 |     let result = vec![path.join("model.safetensors")];
35 |     Ok(result)
36 | }
37 | 
38 | /// Load the safetensors files for a model from the hub based on a json index file.
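/// The index follows the standard `model.safetensors.index.json` layout; an illustrative
/// (hypothetical) fragment:
///
/// ```json
/// { "weight_map": { "model.layers.0.mlp.up_proj.weight": "model-00001-of-00004.safetensors" } }
/// ```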
39 | pub fn load_safetensors_paths_from_index(
40 |     tensors_index_json_filename: PathBuf,
41 | ) -> Result<Vec<PathBuf>> {
42 |     log::info!(
43 |         "loading tensors from {} ...",
44 |         tensors_index_json_filename.display()
45 |     );
46 | 
47 |     let parent_dir = tensors_index_json_filename.parent().unwrap();
48 |     let json_file = std::fs::File::open(&tensors_index_json_filename).map_err(|e| {
49 |         anyhow!(
50 |             "can't open {}: {:?}",
51 |             tensors_index_json_filename.display(),
52 |             e
53 |         )
54 |     })?;
55 |     let json: serde_json::Value = serde_json::from_reader(&json_file).map_err(|e| {
56 |         anyhow!(
57 |             "can't parse {}: {:?}",
58 |             tensors_index_json_filename.display(),
59 |             e
60 |         )
61 |     })?;
62 |     let weight_map = match json.get("weight_map") {
63 |         None => bail!("no weight map in {json_file:?}"),
64 |         Some(serde_json::Value::Object(map)) => map,
65 |         Some(_) => bail!("weight map in {json_file:?} is not a map"),
66 |     };
67 |     let mut safetensors_files = std::collections::HashSet::new();
68 |     for value in weight_map.values() {
69 |         if let Some(file) = value.as_str() {
70 |             safetensors_files.insert(file.to_string());
71 |         }
72 |     }
73 |     let safetensors_files = safetensors_files
74 |         .iter()
75 |         .map(|v| parent_dir.join(v))
76 |         .collect::<Vec<PathBuf>>();
77 | 
78 |     Ok(safetensors_files)
79 | }
80 | 
81 | /// Create a VarBuilder with the tensors loaded from the index.
82 | pub fn load_var_builder_from_index<'a>(
83 |     tensor_index: PathBuf,
84 |     dtype: DType,
85 |     device: Device,
86 | ) -> Result<VarBuilder<'a>> {
87 |     let filenames: Vec<PathBuf> = if tensor_index.exists() {
88 |         load_safetensors_paths_from_index(tensor_index)
89 |             .map_err(|e| anyhow!("can't load tensors index: {:?}", e))?
90 |     } else {
91 |         load_safetensors_from_model(tensor_index.parent().unwrap())
92 |             .map_err(|e| anyhow!("can't load tensors index: {:?}", e))?
93 |     };
94 | 
95 |     unsafe {
96 |         VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)
97 |             .map_err(|e| anyhow!("can't create varbuilder from tensors: {:?}", e))
98 |     }
99 | }
100 | 
101 | /// Nasty hack to debug NaN in tensors.
102 | #[allow(dead_code)]
103 | pub(crate) fn panic_on_nan(t: &Tensor, name: &str) {
104 |     if t.to_string().contains("NaN") {
105 |         panic!("\ntensor '{name}' contains NaN: \n{t}");
106 |     }
107 | }
108 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker.entitlements:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>com.apple.developer.kernel.increased-memory-limit</key>
6 | 	<true/>
7 | 	<key>com.apple.developer.kernel.extended-virtual-addressing</key>
8 | 	<true/>
9 | 	<key>com.apple.security.device.usb</key>
10 | 	<true/>
11 | 	<key>com.apple.security.files.user-selected.read-only</key>
12 | 	<true/>
13 | 	<key>com.apple.security.network.server</key>
14 | 	<true/>
15 | 	<key>com.apple.security.network.client</key>
16 | 	<true/>
17 | </dict>
18 | </plist>
19 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker.xcodeproj/project.pbxproj:
--------------------------------------------------------------------------------
1 | // !$*UTF8*$!
2 | {
3 | 	archiveVersion = 1;
4 | 	classes = {
5 | 	};
6 | 	objectVersion = 56;
7 | 	objects = {
8 | 
9 | /* Begin PBXBuildFile section */
10 | 		8830A1182C3B019800BA42B5 /* Cake_WorkerApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8830A1172C3B019800BA42B5 /* Cake_WorkerApp.swift */; };
11 | 		8830A11A2C3B019800BA42B5 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8830A1192C3B019800BA42B5 /* ContentView.swift */; };
12 | 		8830A11C2C3B019900BA42B5 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 8830A11B2C3B019900BA42B5 /* Assets.xcassets */; };
13 | 		8830A11F2C3B019900BA42B5 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 8830A11E2C3B019900BA42B5 /* Preview Assets.xcassets */; };
14 | 		8876468D2C3B0261007A32E0 /* Cake.swift in Sources */ = {isa = PBXBuildFile; fileRef = 8876468C2C3B0261007A32E0 /* Cake.swift */; };
15 | 		8876468F2C3B05DB007A32E0 /* Cake.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 8876468E2C3B05DB007A32E0 /* Cake.xcframework */; };
16 | 		887646942C3B07A7007A32E0 /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 887646932C3B07A7007A32E0 /* Metal.framework */; };
17 | 		887646972C3B08B3007A32E0 /* MetalPerformanceShadersGraph.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 887646952C3B08B3007A32E0 /* MetalPerformanceShadersGraph.framework */; };
18 | 		887646982C3B08B3007A32E0 /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 887646962C3B08B3007A32E0 /* MetalPerformanceShaders.framework */; };
19 | /* End PBXBuildFile section */
20 | 
21 | /* Begin PBXFileReference section */
22 | 		8830A1142C3B019800BA42B5 /* Cake Worker.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "Cake Worker.app"; sourceTree = BUILT_PRODUCTS_DIR; };
23 | 		8830A1172C3B019800BA42B5 /* Cake_WorkerApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Cake_WorkerApp.swift; sourceTree = "<group>"; };
24 | 		8830A1192C3B019800BA42B5 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = "<group>"; };
25 | 		8830A11B2C3B019900BA42B5 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = "<group>"; };
26 | 		8830A11E2C3B019900BA42B5 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = "<group>"; };
27 | 		8876468C2C3B0261007A32E0 /* Cake.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = Cake.swift; sourceTree = "<group>"; };
28 | 		8876468E2C3B05DB007A32E0 /* Cake.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; path = Cake.xcframework; sourceTree = SOURCE_ROOT; };
29 | 		887646912C3B078F007A32E0 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; };
30 | 		887646932C3B07A7007A32E0 /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; };
31 | 		887646952C3B08B3007A32E0 /* MetalPerformanceShadersGraph.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShadersGraph.framework; path = System/Library/Frameworks/MetalPerformanceShadersGraph.framework; sourceTree = SDKROOT; };
32 | 		887646962C3B08B3007A32E0 /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = SDKROOT; };
33 | 		88811B6C2C45F2AF0010390A /* Cake Worker.entitlements */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.entitlements; path = "Cake Worker.entitlements"; sourceTree = SOURCE_ROOT; };
34 | /* End PBXFileReference section */
35 | 
36 | /* Begin PBXFrameworksBuildPhase section */
37 | 		8830A1112C3B019800BA42B5 /* Frameworks */ = {
38 | 			isa = PBXFrameworksBuildPhase;
39 | 			buildActionMask = 2147483647;
40 | 			files = (
41 | 				887646982C3B08B3007A32E0 /* MetalPerformanceShaders.framework in Frameworks */,
42 | 				8876468F2C3B05DB007A32E0 /* Cake.xcframework in Frameworks */,
43 | 				887646942C3B07A7007A32E0 /* Metal.framework in Frameworks */,
44 | 				887646972C3B08B3007A32E0 /* MetalPerformanceShadersGraph.framework in Frameworks */,
45 | 			);
46 | 			runOnlyForDeploymentPostprocessing = 0;
47 | 		};
48 | /* End PBXFrameworksBuildPhase section */
49 | 
50 | /* Begin PBXGroup section */
51 | 		8830A10B2C3B019800BA42B5 = {
52 | 			isa = PBXGroup;
53 | 			children = (
54 | 				8830A1162C3B019800BA42B5 /* Cake Worker */,
55 | 				8830A1152C3B019800BA42B5 /* Products */,
56 | 				887646902C3B078E007A32E0 /* Frameworks */,
57 | 			);
58 | 			sourceTree = "<group>";
59 | 		};
60 | 		8830A1152C3B019800BA42B5 /* Products */ = {
61 | 			isa = PBXGroup;
62 | 			children = (
63 | 				8830A1142C3B019800BA42B5 /* Cake Worker.app */,
64 | 			);
65 | 			name = Products;
66 | 			sourceTree = "<group>";
67 | 		};
68 | 		8830A1162C3B019800BA42B5 /* Cake Worker */ = {
69 | 			isa = PBXGroup;
70 | 			children = (
71 | 				88811B6C2C45F2AF0010390A /* Cake Worker.entitlements */,
72 | 				8876468E2C3B05DB007A32E0 /* Cake.xcframework */,
73 | 				8830A1172C3B019800BA42B5 /* Cake_WorkerApp.swift */,
74 | 				8876468C2C3B0261007A32E0 /* Cake.swift */,
75 | 				8830A1192C3B019800BA42B5 /* ContentView.swift */,
76 | 				8830A11B2C3B019900BA42B5 /* Assets.xcassets */,
77 | 				8830A11D2C3B019900BA42B5 /* Preview Content */,
78 | 			);
79 | 			path = "Cake Worker";
80 | 			sourceTree = "<group>";
81 | 		};
82 | 		8830A11D2C3B019900BA42B5 /* Preview Content */ = {
83 | 			isa = PBXGroup;
84 | 			children = (
85 | 				8830A11E2C3B019900BA42B5 /* Preview Assets.xcassets */,
86 | 			);
87 | 			path = "Preview Content";
88 | 			sourceTree = "<group>";
89 | 		};
90 | 		887646902C3B078E007A32E0 /* Frameworks */ = {
91 | 			isa = PBXGroup;
92 | 			children = (
93 | 				887646962C3B08B3007A32E0 /* MetalPerformanceShaders.framework */,
94 | 				887646952C3B08B3007A32E0 /* MetalPerformanceShadersGraph.framework */,
95 | 				887646932C3B07A7007A32E0 /* Metal.framework */,
96 | 				887646912C3B078F007A32E0 /* Accelerate.framework */,
97 | 			);
98 | 			name = Frameworks;
99 | 			sourceTree = "<group>";
100 | 		};
101 | /* End PBXGroup section */
102 | 
103 | /* Begin PBXNativeTarget section */
104 | 		8830A1132C3B019800BA42B5 /* Cake Worker */ = {
105 | 			isa = PBXNativeTarget;
106 | 			buildConfigurationList = 8830A1222C3B019900BA42B5 /* Build configuration list for PBXNativeTarget "Cake Worker" */;
107 | 			buildPhases = (
108 | 				8830A1102C3B019800BA42B5 /* Sources */,
109 | 				8830A1112C3B019800BA42B5 /* Frameworks */,
110 | 				8830A1122C3B019800BA42B5 /* Resources */,
111 | 			);
112 | 			buildRules = (
113 | 			);
114 | 			dependencies = (
115 | 			);
116 | 			name = "Cake Worker";
117 | 			productName = "Cake Worker";
118 | 			productReference = 8830A1142C3B019800BA42B5 /* Cake Worker.app */;
119 | 			productType = "com.apple.product-type.application";
120 | 		};
121 | /* End PBXNativeTarget section */
122 | 
123 | /* Begin PBXProject section */
124 | 		8830A10C2C3B019800BA42B5 /* Project object */ = {
125 | 			isa = PBXProject;
126 | 			attributes = {
127 | 				BuildIndependentTargetsInParallel = 1;
128 | 				LastSwiftUpdateCheck = 1540;
129 | 				LastUpgradeCheck = 1540;
130 | 				TargetAttributes = {
131 | 					8830A1132C3B019800BA42B5 = {
132 | 						CreatedOnToolsVersion = 15.4;
133 | 					};
134 | 				};
135 | 			};
136 | 			buildConfigurationList = 8830A10F2C3B019800BA42B5 /* Build configuration list for PBXProject "Cake Worker" */;
137 | 			compatibilityVersion = "Xcode 14.0";
138 | 			developmentRegion = en;
139 | 			hasScannedForEncodings = 0;
140 | 			knownRegions = (
141 | 				en,
142 | 				Base,
143 | 			);
144 | 			mainGroup = 8830A10B2C3B019800BA42B5;
145 | 			productRefGroup = 8830A1152C3B019800BA42B5 /* Products */;
146 | 			projectDirPath = "";
147 | 			projectRoot = "";
148 | 			targets = (
149 | 				8830A1132C3B019800BA42B5 /* Cake Worker */,
150 | 			);
151 | 		};
152 | /* End PBXProject section */
153 | 
154 | /* Begin PBXResourcesBuildPhase section */
155 | 		8830A1122C3B019800BA42B5 /* Resources */ = {
156 | 			isa = PBXResourcesBuildPhase;
157 | 			buildActionMask = 2147483647;
158 | 			files = (
159 | 				8830A11F2C3B019900BA42B5 /* Preview Assets.xcassets in Resources */,
160 | 				8830A11C2C3B019900BA42B5 /* Assets.xcassets in Resources */,
161 | 			);
162 | 			runOnlyForDeploymentPostprocessing = 0;
163 | 		};
164 | /* End PBXResourcesBuildPhase section */
165 | 
166 | /* Begin PBXSourcesBuildPhase section */
167 | 		8830A1102C3B019800BA42B5 /* Sources */ = {
168 | 			isa = PBXSourcesBuildPhase;
169 | 			buildActionMask = 2147483647;
170 | 			files = (
171 | 				8830A11A2C3B019800BA42B5 /* ContentView.swift in Sources */,
172 | 				8876468D2C3B0261007A32E0 /* Cake.swift in Sources */,
173 | 				8830A1182C3B019800BA42B5 /* Cake_WorkerApp.swift in Sources */,
174 | 			);
175 | 			runOnlyForDeploymentPostprocessing = 0;
176 | 		};
177 | /* End PBXSourcesBuildPhase section */
178 | 
179 | /* Begin XCBuildConfiguration section */
180 | 		8830A1202C3B019900BA42B5 /* Debug */ = {
181 | 			isa = XCBuildConfiguration;
182 | 			buildSettings = {
183 | 				ALWAYS_SEARCH_USER_PATHS = NO;
184 | 				ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
185 | 				CLANG_ANALYZER_NONNULL = YES;
186 | 				CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
187 | 				CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
188 | 				CLANG_ENABLE_MODULES = YES;
189 | 				CLANG_ENABLE_OBJC_ARC = YES;
190 | 				CLANG_ENABLE_OBJC_WEAK = YES;
191 | 				CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
192 | 				CLANG_WARN_BOOL_CONVERSION = YES;
193 | 				CLANG_WARN_COMMA = YES;
194 | 				CLANG_WARN_CONSTANT_CONVERSION = YES;
195 | 				CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
196 | 				CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
197 | 				CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
198 | 				CLANG_WARN_EMPTY_BODY = YES;
199 | 				CLANG_WARN_ENUM_CONVERSION = YES;
200 | 				CLANG_WARN_INFINITE_RECURSION = YES;
201 | 				CLANG_WARN_INT_CONVERSION = YES;
202 | 				CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
203 | 				CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
204 | 				CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
205 | 				CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
206 | 				CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
207 | 				CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
208 | 				CLANG_WARN_STRICT_PROTOTYPES = YES;
209 | 				CLANG_WARN_SUSPICIOUS_MOVE = YES;
210 | 				CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
211 | 				CLANG_WARN_UNREACHABLE_CODE = YES;
212 | 				CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
213 | 				COPY_PHASE_STRIP = NO;
214 | 				DEBUG_INFORMATION_FORMAT = dwarf;
215 | 				ENABLE_STRICT_OBJC_MSGSEND = YES;
216 | 				ENABLE_TESTABILITY = YES;
217 | 				ENABLE_USER_SCRIPT_SANDBOXING = YES;
218 | 				GCC_C_LANGUAGE_STANDARD = gnu17;
219 | 				GCC_DYNAMIC_NO_PIC = NO;
220 | 				GCC_NO_COMMON_BLOCKS = YES;
221 | 				GCC_OPTIMIZATION_LEVEL = 0;
222 | 				GCC_PREPROCESSOR_DEFINITIONS = (
223 | 					"DEBUG=1",
224 | 					"$(inherited)",
225 | 				);
226 | 				GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
227 | 				GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
228 | 				GCC_WARN_UNDECLARED_SELECTOR = YES;
229 | 				GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
230 | 				GCC_WARN_UNUSED_FUNCTION = YES;
231 | 				GCC_WARN_UNUSED_VARIABLE = YES;
232 | 				IPHONEOS_DEPLOYMENT_TARGET = 17.5;
233 | 				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
234 | 				MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
235 | 				MTL_FAST_MATH = YES;
236 | 				ONLY_ACTIVE_ARCH = YES;
237 | 				OTHER_LDFLAGS = "-lc++";
238 | 				SDKROOT = iphoneos;
239 | 				SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
240 | 				SWIFT_OPTIMIZATION_LEVEL = "-Onone";
241 | 			};
242 | 			name = Debug;
243 | 		};
244 | 		8830A1212C3B019900BA42B5 /* Release */ = {
245 | 			isa = XCBuildConfiguration;
246 | 			buildSettings = {
247 | 				ALWAYS_SEARCH_USER_PATHS = NO;
248 | 				ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
249 | 				CLANG_ANALYZER_NONNULL = YES;
250 | 				CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
251 | 				CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
252 | 				CLANG_ENABLE_MODULES = YES;
253 | 				CLANG_ENABLE_OBJC_ARC = YES;
254 | 				CLANG_ENABLE_OBJC_WEAK = YES;
255 | 				CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
256 | 				CLANG_WARN_BOOL_CONVERSION = YES;
257 | 				CLANG_WARN_COMMA = YES;
258 | 				CLANG_WARN_CONSTANT_CONVERSION = YES;
259 | 				CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
260 | 				CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
261 | 				CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
262 | 				CLANG_WARN_EMPTY_BODY = YES;
263 | 				CLANG_WARN_ENUM_CONVERSION = YES;
264 | 				CLANG_WARN_INFINITE_RECURSION = YES;
265 | 				CLANG_WARN_INT_CONVERSION = YES;
266 | 				CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
267 | 				CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
268 | 				CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
269 | 				CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
270 | 				CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
271 | 				CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
272 | 				CLANG_WARN_STRICT_PROTOTYPES = YES;
273 | 				CLANG_WARN_SUSPICIOUS_MOVE = YES;
274 | 				CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
275 | 				CLANG_WARN_UNREACHABLE_CODE = YES;
276 | 				CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
277 | 				COPY_PHASE_STRIP = NO;
278 | 				DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
279 | 				ENABLE_NS_ASSERTIONS = NO;
280 | 				ENABLE_STRICT_OBJC_MSGSEND = YES;
281 | 				ENABLE_USER_SCRIPT_SANDBOXING = YES;
282 | 				GCC_C_LANGUAGE_STANDARD = gnu17;
283 | 				GCC_NO_COMMON_BLOCKS = YES;
284 | 				GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
285 | 				GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
286 | 				GCC_WARN_UNDECLARED_SELECTOR = YES;
287 | 				GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
288 | 				GCC_WARN_UNUSED_FUNCTION = YES;
289 | 				GCC_WARN_UNUSED_VARIABLE = YES;
290 | 				IPHONEOS_DEPLOYMENT_TARGET = 17.5;
291 | 				LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
292 | 				MTL_ENABLE_DEBUG_INFO = NO;
293 | 				MTL_FAST_MATH = YES;
294 | 				OTHER_LDFLAGS = "-lc++";
295 | 				SDKROOT = iphoneos;
296 | 				SWIFT_COMPILATION_MODE = wholemodule;
297 | 				VALIDATE_PRODUCT = YES;
298 | 			};
299 | 			name = Release;
300 | 		};
301 | 		8830A1232C3B019900BA42B5 /* Debug */ = {
302 | 			isa = XCBuildConfiguration;
303 | 			buildSettings = {
304 | 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
305 | 				ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
306 | 				CODE_SIGN_ENTITLEMENTS = "$(SRCROOT)/Cake Worker.entitlements";
307 | 				"CODE_SIGN_ENTITLEMENTS[sdk=*]" = "";
308 | 				CODE_SIGN_STYLE = Automatic;
309 | 				CURRENT_PROJECT_VERSION = 1;
310 | 				DEVELOPMENT_ASSET_PATHS = "\"Cake Worker/Preview Content\"";
311 | 				DEVELOPMENT_TEAM = XKB2P2J9BH;
312 | 				ENABLE_PREVIEWS = YES;
313 | 				GENERATE_INFOPLIST_FILE = YES;
314 | 				INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
315 | 				INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
316 | 				INFOPLIST_KEY_UILaunchScreen_Generation = YES;
317 | 				INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
318 | 				INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
319 | 				LD_RUNPATH_SEARCH_PATHS = (
320 | 					"$(inherited)",
321 | 					"@executable_path/Frameworks",
322 | 				);
323 | 				MARKETING_VERSION = 1.0;
324 | 				PRODUCT_BUNDLE_IDENTIFIER = "cake.Cake-Worker";
325 | 				PRODUCT_NAME = "$(TARGET_NAME)";
326 | 				SWIFT_EMIT_LOC_STRINGS = YES;
327 | 				SWIFT_VERSION = 5.0;
328 | 				TARGETED_DEVICE_FAMILY = "1,2";
329 | 			};
330 | 			name = Debug;
331 | 		};
332 | 		8830A1242C3B019900BA42B5 /* Release */ = {
333 | 			isa = XCBuildConfiguration;
334 | 			buildSettings = {
335 | 				ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
336 | 				ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
337 | 				CODE_SIGN_ENTITLEMENTS = "$(SRCROOT)/Cake Worker.entitlements";
338 | 				CODE_SIGN_STYLE = Automatic;
339 | 				CURRENT_PROJECT_VERSION = 1;
340 | 				DEVELOPMENT_ASSET_PATHS = "\"Cake Worker/Preview Content\"";
341 | 				DEVELOPMENT_TEAM = XKB2P2J9BH;
342 | 				ENABLE_PREVIEWS = YES;
343 | 				GENERATE_INFOPLIST_FILE = YES;
344 | 				INFOPLIST_KEY_UIApplicationSceneManifest_Generation = YES;
345 | 				INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents = YES;
346 | 				INFOPLIST_KEY_UILaunchScreen_Generation = YES;
347 | 				INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
348 | 				INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
349 | 				LD_RUNPATH_SEARCH_PATHS = (
350 | 					"$(inherited)",
351 | 					"@executable_path/Frameworks",
352 | 				);
353 | 				MARKETING_VERSION = 1.0;
354 | 				PRODUCT_BUNDLE_IDENTIFIER = "cake.Cake-Worker";
355 | 				PRODUCT_NAME = "$(TARGET_NAME)";
356 | 				SWIFT_EMIT_LOC_STRINGS = YES;
357 | 				SWIFT_VERSION = 5.0;
358 | 				TARGETED_DEVICE_FAMILY = "1,2";
359 | 			};
360 | 			name = Release;
361 | 		};
362 | /* End XCBuildConfiguration section */
363 | 
364 | /* Begin XCConfigurationList section */
365 | 		8830A10F2C3B019800BA42B5 /* Build configuration list for PBXProject "Cake Worker" */ = {
366 | 			isa = XCConfigurationList;
367 | 			buildConfigurations = (
368 | 				8830A1202C3B019900BA42B5 /* Debug */,
369 | 				8830A1212C3B019900BA42B5 /* Release */,
370 | 			);
371 | 			defaultConfigurationIsVisible = 0;
372 | 			defaultConfigurationName = Release;
373 | 		};
374 | 		8830A1222C3B019900BA42B5 /* Build configuration list for PBXNativeTarget "Cake Worker" */ = {
375 | 			isa = XCConfigurationList;
376 | 			buildConfigurations = (
377 | 				8830A1232C3B019900BA42B5 /* Debug */,
378 | 				8830A1242C3B019900BA42B5 /* Release */,
379 | 			);
380 | 			defaultConfigurationIsVisible = 0;
381 | 			defaultConfigurationName = Release;
382 | 		};
383 | /* End XCConfigurationList section */
384 | 	};
385 | 	rootObject = 8830A10C2C3B019800BA42B5 /* Project object */;
386 | }
387 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker.xcodeproj/project.xcworkspace/contents.xcworkspacedata:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <Workspace
3 |    version = "1.0">
4 |    <FileRef
5 |       location = "self:">
6 |    </FileRef>
7 | </Workspace>
8 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>IDEDidComputeMac32BitWarning</key>
6 | 	<true/>
7 | </dict>
8 | </plist>
9 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker.xcodeproj/project.xcworkspace/xcuserdata/evilsocket.xcuserdatad/UserInterfaceState.xcuserstate:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/evilsocket/cake/afb792fd13169f49c2e53eb567fe40b6d0fae20c/cake-ios-worker-app/Cake Worker.xcodeproj/project.xcworkspace/xcuserdata/evilsocket.xcuserdatad/UserInterfaceState.xcuserstate
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker.xcodeproj/xcuserdata/evilsocket.xcuserdatad/xcschemes/xcschememanagement.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>SchemeUserState</key>
6 | 	<dict>
7 | 		<key>Cake Worker.xcscheme_^#shared#^_</key>
8 | 		<dict>
9 | 			<key>orderHint</key>
10 | 			<integer>0</integer>
11 | 		</dict>
12 | 	</dict>
13 | </dict>
14 | </plist>
15 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 |   "colors" : [
3 |     {
4 |       "idiom" : "universal"
5 |     }
6 |   ],
7 |   "info" : {
8 |     "author" : "xcode",
9 |     "version" : 1
10 |   }
11 | }
12 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 |   "images" : [
3 |     {
4 |       "idiom" : "universal",
5 |       "platform" : "ios",
6 |       "size" : "1024x1024"
7 |     }
8 |   ],
9 |   "info" : {
10 |     "author" : "xcode",
11 |     "version" : 1
12 |   }
13 | }
14 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 |   "info" : {
3 |     "author" : "xcode",
4 |     "version" : 1
5 |   }
6 | }
7 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker/Cake_WorkerApp.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  Cake_WorkerApp.swift
3 | //  Cake Worker
4 | //
5 | //  Created by Simone Margaritelli on 07/07/24.
6 | //
7 | 
8 | import SwiftUI
9 | 
10 | @main
11 | struct Cake_WorkerApp: App {
12 |     var body: some Scene {
13 |         WindowGroup {
14 |             ContentView()
15 |         }
16 |     }
17 | }
18 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker/ContentView.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  ContentView.swift
3 | //  Cake Worker
4 | //
5 | //  Created by Simone Margaritelli on 07/07/24.
6 | //
7 | 
8 | import SwiftUI
9 | 
10 | struct ContentView: View {
11 |     @State private var showActionSheet = false
12 |     @State private var buttonTitle: String = "Run Node"
13 |     @State private var selectedModelType: String = "text"
14 | 
15 |     var body: some View {
16 |         VStack {
17 |             Image(systemName: "brain")
18 |                 .imageScale(.large)
19 |                 .foregroundStyle(.tint)
20 |             Picker("Model type", selection: $selectedModelType) {
21 |                 Text("Text Model").tag("text")
22 |                 Text("Image Model").tag("image")
23 |             }.padding()
24 |             Button(buttonTitle) {
25 |                 showActionSheet = true
26 |                 buttonTitle = "Running ..."
27 |             }
28 |             .buttonStyle(.borderless)
29 |             .controlSize(.large)
30 |             .fileImporter(isPresented: $showActionSheet, allowedContentTypes: [.folder]) { result in
31 |                 switch result {
32 |                 case .success(let directory):
33 |                     // print("using \(directory)");
34 | 
35 |                     if directory.startAccessingSecurityScopedResource() {
36 |                         defer {
37 |                             print("revoking access");
38 |                             directory.stopAccessingSecurityScopedResource()
39 |                         }
40 | 
41 |                         let basePath = directory.path();
42 |                         let topologyPath = basePath + "topology.yml";
43 |                         let modelPath = basePath + "model";
44 | 
45 |                         // print("  topologyPath=\(topologyPath)");
46 |                         // print("  modelPath=\(modelPath)");
47 | 
48 |                         print("Model type: \(selectedModelType)");
49 | 
50 |                         startWorker(name:"iphone", modelPath: modelPath, topologyPath: topologyPath, modelType: selectedModelType)
51 |                     } else {
52 |                         print("access denied to \(directory)");
53 |                     }
54 | 
55 |                 case .failure(let error):
56 |                     print(error)
57 |                 }
58 |             }
59 |         }
60 |         .padding()
61 |     }
62 | }
63 | 
64 | #Preview {
65 |     ContentView()
66 | }
67 | 
--------------------------------------------------------------------------------
/cake-ios-worker-app/Cake Worker/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
1 | {
2 |   "info" : {
3 |     "author" : "xcode",
4 |     "version" : 1
5 |   }
6 | }
7 | 
--------------------------------------------------------------------------------
/cake-ios/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "cake-ios"
3 | version = "0.1.0"
4 | edition = "2021"
5 | 
6 | [lib]
7 | crate-type = ["cdylib", "staticlib"]
8 | name = "cake"
9 | 
10 | [dependencies]
11 | anyhow = "1.0.86"
12 | uniffi = { version = "0.28.0", features = ["cli", "tokio"] }
13 | env_logger = "0.11.3"
14 | log = "0.4.22"
15 | tokio = "1.38.0"
16 | 
17 | # we don't need master code here
18 | cake-core = { path = "../cake-core", default-features = false }
19 | 
--------------------------------------------------------------------------------
/cake-ios/src/bin/uniffi-bindgen.rs:
--------------------------------------------------------------------------------
1 | fn main() {
2 |     uniffi::uniffi_bindgen_main()
3 | }
4 | 
--------------------------------------------------------------------------------
/cake-ios/src/lib.rs:
--------------------------------------------------------------------------------
1 | //! This is a small library that wraps cake-core and exposes it as an API to the Swift side of things on iOS.
2 | uniffi::setup_scaffolding!();
3 | 
4 | use cake_core::{
5 |     cake::{Context, Mode, Worker},
6 |     Args, ModelType,
7 | };
8 | 
9 | #[uniffi::export]
10 | pub fn start_worker(name: String, model_path: String, topology_path: String, model_type: String) {
11 |     env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("debug")).init();
12 | 
13 |     log::debug!("@ creating context");
14 | 
15 |     log::debug!("@ model type: {model_type}");
16 | 
17 |     let model_type_arg = match model_type.as_str() {
18 |         "text" => ModelType::TextModel,
19 |         "image" => ModelType::ImageModel,
20 |         _ => panic!("Unrecognized model type"),
21 |     };
22 | 
23 |     let args = Args {
24 |         address: "0.0.0.0:10128".to_string(),
25 |         mode: Mode::Worker,
26 |         name: Some(name),
27 |         model: model_path,
28 |         topology: Some(topology_path),
29 |         model_type: model_type_arg,
30 |         ..Default::default()
31 |     };
32 | 
33 |     let mut ctx = match Context::from_args(args) {
34 |         Ok(ctx) => ctx,
35 |         Err(e) => {
36 |             log::error!("ERROR: {}", e);
37 |             return;
38 |         }
39 |     };
40 | 
41 |     tokio::runtime::Builder::new_multi_thread()
42 |         .enable_all()
43 |         .build()
44 |         .unwrap()
45 |         .block_on(async {
46 |             log::debug!("@ creating worker");
47 | 
48 |             match model_type.as_str() {
49 |                 "text" => {
50 |                     let mut worker =
51 |                         match Worker::<cake_core::models::llama3::LLama>::new(&mut ctx).await {
52 |                             Ok(w) => w,
53 |                             Err(e) => {
54 |                                 log::error!("ERROR: {}", e);
55 |                                 return;
56 |                             }
57 |                         };
58 | 
59 |                     log::info!("@ running worker for text model...");
60 | 
61 |                     match worker.run().await {
62 |                         Ok(_) => log::info!("worker exiting"),
63 |                         Err(e) => {
64 |                             log::error!("ERROR: {}", e);
65 |                         }
66 |                     }
67 |                 }
68 |                 "image" => {
69 |                     let mut worker = match Worker::<cake_core::models::sd::SD>::new(&mut ctx).await
70 |                     {
71 |                         Ok(w) => w,
72 |                         Err(e) => {
73 |                             log::error!("ERROR: {}", e);
74 |                             return;
75 |                         }
76 |                     };
77 | 
78 |                     log::info!("@ running worker for image model...");
79 | 
80 |                     match worker.run().await {
81 |                         Ok(_) => log::info!("worker exiting"),
82 |                         Err(e) => {
83 |                             log::error!("ERROR: {}", e);
84 |                         }
85 |                     }
86 |                 }
87 |                 _ => {
88 |                     log::error!("ERROR: unrecognized model type");
89 |                 }
90 |             }
91 |         })
92 | }
93 | 
--------------------------------------------------------------------------------
/cake-split-model/Cargo.toml:
--------------------------------------------------------------------------------
1 | [package]
2 | name = "cake-split-model"
3 | version = "0.1.0"
4 | edition = "2021"
5 | 
6 | [dependencies]
7 | clap = { version = "4.5.8", features = ["derive"] }
8 | safetensors = "0.4.3"
9 | cake-core = { path = "../cake-core" }
10 | serde = { version = "1.0.204", features = ["derive"] }
11 | serde_json = "1.0.120"
12 | memmap2 = "0.9.4"
13 | candle-core = "0.6.0"
14 | candle-nn = "0.6.0"
15 | anyhow = "1.0.86"
16 | serde_yaml = "0.9.34"
17 | 
--------------------------------------------------------------------------------
/cake-split-model/src/main.rs:
--------------------------------------------------------------------------------
1 | //! This is a utility to split a single, safetensors based model into parts
2 | //! that are smaller and can be distributed to the workers instead of the entire model.
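//!
//! Example invocation (flags come from `Args` below; the output folder is hypothetical):
//!
//!   cargo run --bin cake-split-model -- \
//!     --model-path ./cake-data/Meta-Llama-3-8B/ \
//!     --topology ./cake-data/topology.yml \
//!     --output ./cake-data/split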
3 | use std::{
4 |     collections::HashMap,
5 |     fs::File,
6 |     path::{Path, PathBuf},
7 | };
8 | 
9 | use anyhow::Result;
10 | use cake_core::{
11 |     cake::{Node, Topology},
12 |     utils, ModelType,
13 | };
14 | use clap::Parser;
15 | use safetensors::{Dtype, SafeTensors, View};
16 | use serde::{Deserialize, Serialize};
17 | 
18 | #[derive(Debug, Serialize, Deserialize)]
19 | struct Index {
20 |     pub weight_map: HashMap<String, String>,
21 | }
22 | 
23 | impl Index {
24 |     pub fn new() -> Self {
25 |         let weight_map = HashMap::new();
26 |         Self { weight_map }
27 |     }
28 | }
29 | 
30 | #[derive(Debug)]
31 | struct TensorStore {
32 |     dtype: Dtype,
33 |     shape: Vec<usize>,
34 |     data: Vec<u8>,
35 | }
36 | 
37 | impl View for TensorStore {
38 |     fn dtype(&self) -> Dtype {
39 |         self.dtype
40 |     }
41 | 
42 |     fn shape(&self) -> &[usize] {
43 |         &self.shape
44 |     }
45 | 
46 |     fn data(&self) -> std::borrow::Cow<[u8]> {
47 |         std::borrow::Cow::from(&self.data)
48 |     }
49 | 
50 |     fn data_len(&self) -> usize {
51 |         self.data.len()
52 |     }
53 | }
54 | 
55 | #[derive(Parser, Default, Debug)]
56 | #[command(author, version, about, long_about = None)]
57 | pub struct Args {
58 |     /// Input model path.
59 |     #[arg(long, default_value = "./cake-data/Meta-Llama-3-8B/")]
60 |     pub model_path: String,
61 |     /// Topology file.
62 |     #[arg(long, default_value = "./cake-data/topology.yml")]
63 |     pub topology: String,
64 |     /// Worker name or empty for all.
65 |     #[arg(long)]
66 |     pub worker: Option<String>,
67 |     /// Output folder.
68 |     #[arg(long)]
69 |     pub output: String,
70 | }
71 | 
72 | fn load_index(data_path: &Path) -> Result<Index> {
73 |     let tensors_index_path = data_path.join("model.safetensors.index.json");
74 |     let tensors_index_data = std::fs::read_to_string(tensors_index_path)?;
75 |     let tensors_index: Index = serde_json::from_str(&tensors_index_data)?;
76 | 
77 |     Ok(tensors_index)
78 | }
79 | 
80 | fn reduce_for_worker(
81 |     index: &Index,
82 |     worker: &Node,
83 | ) -> Result<(Index, HashMap<String, Vec<String>>)> {
84 |     println!("worker: {}", &worker.host);
85 | 
86 |     let mut reduced: HashMap<String, Vec<String>> = HashMap::new();
87 |     let mut new_index = Index::new();
88 | 
89 |     for (layer_full_name, filename) in &index.weight_map {
90 |         if worker.is_text_model_layer_owner(layer_full_name) {
91 |             //println!("{} {}", layer_full_name, filename);
92 |             if let Some(layers) = reduced.get_mut(filename) {
93 |                 layers.push(layer_full_name.to_string());
94 |             } else {
95 |                 reduced.insert(filename.to_string(), vec![layer_full_name.to_string()]);
96 |             }
97 | 
98 |             new_index.weight_map.insert(
99 |                 layer_full_name.to_string(),
100 |                 "reduced.safetensors".to_string(),
101 |             );
102 |         }
103 |     }
104 | 
105 |     Ok((new_index, reduced))
106 | }
107 | 
108 | fn create_new_metadata(
109 |     data_path: &Path,
110 |     reduced: &HashMap<String, Vec<String>>,
111 | ) -> Result<HashMap<String, TensorStore>> {
112 |     let mut metadata: HashMap<String, TensorStore> = HashMap::new();
113 | 
114 |     for (filename, tensor_names) in reduced {
115 |         let filepath = data_path.join(filename);
116 | 
117 |         println!("loading {} ...", filepath.display());
118 | 
119 |         let file = File::open(&filepath).unwrap();
120 |         let buffer = unsafe { memmap2::MmapOptions::new().map(&file).unwrap() };
121 |         let tensors = SafeTensors::deserialize(&buffer).unwrap();
122 | 
123 |         println!("  extracting {} tensors", tensor_names.len());
124 | 
125 |         for tensor_name in tensor_names {
126 |             let tensor = tensors.tensor(tensor_name).unwrap();
127 |             metadata.insert(
128 |                 tensor_name.to_string(),
129 |                 TensorStore {
130 |                     dtype: tensor.dtype(),
131 |                     shape: tensor.shape().to_vec(),
132 |                     data: tensor.data().to_vec(),
133 |                 },
134 |             );
135 |         }
136 | 
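        // Explicitly drop the deserialized view before the mmap it borrows from, and
        // both before the next file is processed, so only one shard stays mapped at a time.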
137 |         drop(tensors);
138 |         drop(buffer);
139 |     }
140 | 
141 |     Ok(metadata)
142 | }
143 | 
144 | fn main() {
145 |     let args = Args::parse();
146 |     let data_path = PathBuf::from(&args.model_path);
147 | 
148 |     let topology =
149 |         Topology::from_path(&args.topology, &ModelType::TextModel).expect("can't load topology");
150 |     let index = load_index(&data_path).expect("can't load index");
151 | 
152 |     println!("index has {} tensors", index.weight_map.len());
153 | 
154 |     let selected_workers = if let Some(name) = &args.worker {
155 |         vec![name.to_string()]
156 |     } else {
157 |         topology.keys().map(|s| s.to_string()).collect()
158 |     };
159 | 
160 |     println!("processing {} workers", selected_workers.len());
161 | 
162 |     for worker_name in &selected_workers {
163 |         println!("processing worker {worker_name} ...");
164 | 
165 |         let worker_node = topology
166 |             .get(worker_name)
167 |             .expect("can't find worker topology");
168 | 
169 |         let (new_index, reduced) =
170 |             reduce_for_worker(&index, worker_node).expect("can't reduce for worker");
171 | 
172 |         println!("compacting {} tensors ...", new_index.weight_map.len());
173 | 
174 |         let metadata =
175 |             create_new_metadata(&data_path, &reduced).expect("can't create metadata for worker");
176 | 
177 |         let bundle_name = format!("{worker_name}-node");
178 |         let output_path = PathBuf::from(&args.output).join(bundle_name);
179 |         let model_output_path = output_path.join("model");
180 |         if !output_path.exists() {
181 |             println!("creating {}", model_output_path.display());
182 |             std::fs::create_dir_all(&model_output_path).unwrap();
183 |         } else {
184 |             println!("saving model to {}", model_output_path.display());
185 |         }
186 | 
187 |         let new_index_path = model_output_path.join("model.safetensors.index.json");
188 | 
189 |         println!("saving new index to {} ...", new_index_path.display());
190 | 
191 |         let new_index_data = serde_json::to_string_pretty(&new_index).unwrap();
192 |         std::fs::write(&new_index_path, new_index_data).unwrap();
193 | 
194 |         let new_tensors_path = model_output_path.join("reduced.safetensors");
195 | 
196 |         println!(
197 |             "saving reduced tensors to {} ...",
198 |             new_tensors_path.display()
199 |         );
200 | 
201 |         safetensors::serialize_to_file(metadata, &None, &new_tensors_path).unwrap();
202 | 
203 |         let loaded = utils::load_safetensors_paths_from_index(new_index_path).unwrap();
204 | 
205 |         assert_eq!(loaded.len(), 1);
206 | 
207 |         let file = File::open(&loaded[0]).unwrap();
208 |         let buffer = unsafe { memmap2::MmapOptions::new().map(&file).unwrap() };
209 |         let _ = SafeTensors::deserialize(&buffer).unwrap();
210 | 
211 |         let new_topology_path = output_path.join("topology.yml");
212 | 
213 |         println!(
214 |             "saving worker topology to {} ...",
215 |             new_topology_path.display()
216 |         );
217 | 
218 |         let mut new_topology: HashMap<String, &Node> = HashMap::new();
219 | 
220 |         new_topology.insert(worker_name.to_string(), worker_node);
221 | 
222 |         let new_topology_data = serde_yaml::to_string(&new_topology).unwrap();
223 | 
224 |         std::fs::write(&new_topology_path, new_topology_data).unwrap();
225 |     }
226 | }
227 | 
--------------------------------------------------------------------------------