├── .gitignore ├── BuildConfiguratons ├── Debug.xcconfig ├── Release.xcconfig ├── Sign-Debug-template.xcconfig └── Sign-Release-template.xcconfig ├── CoreML ├── pipeline │ ├── DPMSolverMultistepScheduler.swift │ ├── Decoder.swift │ ├── ManagedMLModel.swift │ ├── Random.swift │ ├── ResourceManaging.swift │ ├── SafetyChecker.swift │ ├── Scheduler.swift │ ├── StableDiffusionPipeline+Resources.swift │ ├── StableDiffusionPipeline.swift │ ├── TextEncoder.swift │ └── Unet.swift └── tokenizer │ ├── BPETokenizer+Reading.swift │ └── BPETokenizer.swift ├── Diffusion-macOS ├── Capabilities.swift ├── ContentView.swift ├── ControlsView.swift ├── Diffusion_macOS.entitlements ├── Diffusion_macOSApp.swift ├── GeneratedImageView.swift ├── HelpContent.swift ├── Info.plist ├── Preview Content │ └── Preview Assets.xcassets │ │ └── Contents.json └── StatusView.swift ├── Diffusion.xcodeproj ├── project.pbxproj ├── project.xcworkspace │ ├── contents.xcworkspacedata │ ├── xcshareddata │ │ └── swiftpm │ │ │ └── Package.resolved │ └── xcuserdata │ │ └── cyril.xcuserdatad │ │ └── UserInterfaceState.xcuserstate ├── xcshareddata │ └── xcschemes │ │ └── Diffusion.xcscheme └── xcuserdata │ └── cyril.xcuserdatad │ └── xcschemes │ └── xcschememanagement.plist ├── Diffusion ├── Assets.xcassets │ ├── AccentColor.colorset │ │ └── Contents.json │ ├── AppIcon.appiconset │ │ ├── .DS_Store │ │ ├── 256x256@2x.png │ │ ├── 512x512@2x.png │ │ ├── Contents.json │ │ ├── diffusers_on_white_1024.png │ │ ├── icon_128x128.png │ │ ├── icon_128x128@2x.png │ │ ├── icon_16x16.png │ │ ├── icon_16x16@2x.png │ │ ├── icon_256x256.png │ │ ├── icon_256x256@2x.png │ │ ├── icon_32x32.png │ │ ├── icon_32x32@2x.png │ │ ├── icon_512x512.png │ │ └── icon_512x512@2x.png │ ├── Contents.json │ └── placeholder.imageset │ │ ├── -cell-blank.png │ │ └── Contents.json ├── Common │ ├── Downloader.swift │ ├── ModelInfo.swift │ ├── Pipeline │ │ ├── Pipeline.swift │ │ └── PipelineLoader.swift │ ├── State.swift │ ├── Utils.swift │ └── Views │ 
│ └── PromptTextField.swift ├── Diffusion.entitlements ├── DiffusionApp.swift ├── Info.plist ├── Preview Content │ └── Preview Assets.xcassets │ │ └── Contents.json ├── Support │ ├── AppState.swift │ ├── Extensions.swift │ ├── Functions.swift │ └── SDImage.swift └── Views │ ├── ErrorBanner.swift │ ├── Loading.swift │ ├── MainAppView.swift │ ├── PreviewView.swift │ └── TextToImage.swift ├── DiffusionTests └── DiffusionTests.swift ├── DiffusionUITests ├── DiffusionUITests.swift └── DiffusionUITestsLaunchTests.swift ├── LICENSE ├── README.md ├── RELEASE.md ├── assets └── screenshot.jpg ├── config ├── common.xcconfig └── debug.xcconfig └── screenshot.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | xcuserdata/ 3 | -------------------------------------------------------------------------------- /BuildConfiguratons/Debug.xcconfig: -------------------------------------------------------------------------------- 1 | // 2 | // Debug.xcconfig 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 15/12/2022. 6 | // 7 | 8 | // Detailed explanation of how to set this up can be found here: https://ajpagente.github.io/mobile/using-xcconfig/ 9 | // Note: This file is the one referenced from the Xcode project and this file can be checked into the Git repo. But this references the file with your personal info (Sign-Debug.xccconfig) and that file should not be checked into the repo. 10 | // Configuration settings file format documentation can be found at: 11 | // https://help.apple.com/xcode/#/dev745c5c974 12 | 13 | #include "Sign-Debug.xcconfig" 14 | -------------------------------------------------------------------------------- /BuildConfiguratons/Release.xcconfig: -------------------------------------------------------------------------------- 1 | // 2 | // Release.xcconfig 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 15/12/2022. 
6 | // 7 | 8 | // Detailed explanation of how to set this up can be found here: https://ajpagente.github.io/mobile/using-xcconfig/ 9 | // Note: This file is the one referenced from the Xcode project and this file can be checked into the Git repo. But this references the file with your personal info (Sign-Debug.xccconfig) and that file should not be checked into the repo. 10 | // Configuration settings file format documentation can be found at: 11 | // https://help.apple.com/xcode/#/dev745c5c974 12 | 13 | #include "Sign-Release.xcconfig" 14 | -------------------------------------------------------------------------------- /BuildConfiguratons/Sign-Debug-template.xcconfig: -------------------------------------------------------------------------------- 1 | // 2 | // Sign-Debug.xcconfig 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 15/12/2022. 6 | // 7 | 8 | // Detailed explanation of how to set this up can be found here: https://ajpagente.github.io/mobile/using-xcconfig/ 9 | // Note: This is your personal signing details. *Do not* check this into the repo. The Debug.xcconfig file includes this one so that you can modify this without impacting the project. 
10 | // Configuration settings file format documentation can be found at: 11 | // https://help.apple.com/xcode/#/dev745c5c974 12 | 13 | // See the first link above for details on how to get the following values 14 | PRODUCT_BUNDLE_IDENTIFIER = 15 | DEVELOPMENT_TEAM = <10 character Team ID> 16 | CODE_SIGN_IDENTITY[sdk=iphoneos*] = <40 character SHA1 Hash from provisionin profile for iOS> 17 | PROVISIONING_PROFILE_SPECIFIER[sdk=iphoneos*] = <36 character UUID from provisioning profile for iOS> 18 | 19 | CODE_SIGN_IDENTITY[sdk=macos*] = <40 character SHA1 Hash from provisionin profile for macOS> 20 | PROVISIONING_PROFILE_SPECIFIER[sdk=macos*] = <36 character UUID from provisioning profile for macOS> 21 | -------------------------------------------------------------------------------- /BuildConfiguratons/Sign-Release-template.xcconfig: -------------------------------------------------------------------------------- 1 | // 2 | // Sign-Release-template.xcconfig 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 15/12/2022. 6 | // 7 | 8 | // Detailed explanation of how to set this up can be found here: https://ajpagente.github.io/mobile/using-xcconfig/ 9 | // Note: This is your personal signing details. *Do not* check this into the repo. The Debug.xcconfig file includes this one so that you can modify this without impacting the project. 
10 | // Configuration settings file format documentation can be found at: 11 | // https://help.apple.com/xcode/#/dev745c5c974 12 | 13 | // See the first link above for details on how to get the following values 14 | PRODUCT_BUNDLE_IDENTIFIER = 15 | DEVELOPMENT_TEAM = <10 character Team ID> 16 | CODE_SIGN_IDENTITY[sdk=iphoneos*] = <40 character SHA1 Hash from provisionin profile for iOS> 17 | PROVISIONING_PROFILE_SPECIFIER[sdk=iphoneos*] = <36 character UUID from provisioning profile for iOS> 18 | 19 | CODE_SIGN_IDENTITY[sdk=macos*] = <40 character SHA1 Hash from provisionin profile for macOS> 20 | PROVISIONING_PROFILE_SPECIFIER[sdk=macos*] = <36 character UUID from provisioning profile for macOS> 21 | -------------------------------------------------------------------------------- /CoreML/pipeline/DPMSolverMultistepScheduler.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. and The HuggingFace Team. All Rights Reserved. 3 | 4 | import Accelerate 5 | import CoreML 6 | 7 | /// A scheduler used to compute a de-noised image 8 | /// 9 | /// This implementation matches: 10 | /// [Hugging Face Diffusers DPMSolverMultistepScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py) 11 | /// 12 | /// It uses the DPM-Solver++ algorithm: [code](https://github.com/LuChengTHU/dpm-solver) [paper](https://arxiv.org/abs/2211.01095). 13 | /// Limitations: 14 | /// - Only implemented for DPM-Solver++ algorithm (not DPM-Solver). 15 | /// - Second order only. 16 | /// - Assumes the model predicts epsilon. 17 | /// - No dynamic thresholding. 18 | /// - `midpoint` solver algorithm. 
19 | @available(iOS 16.2, macOS 13.1, *) 20 | public final class DPMSolverMultistepScheduler: Scheduler { 21 | public let trainStepCount: Int 22 | public let inferenceStepCount: Int 23 | public let betas: [Float] 24 | public let alphas: [Float] 25 | public let alphasCumProd: [Float] 26 | public let timeSteps: [Int] 27 | 28 | public let alpha_t: [Float] 29 | public let sigma_t: [Float] 30 | public let lambda_t: [Float] 31 | 32 | public let solverOrder = 2 33 | private(set) var lowerOrderStepped = 0 34 | 35 | /// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps. 36 | /// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps. 37 | public let useLowerOrderFinal = true 38 | 39 | // Stores solverOrder (2) items 40 | private(set) var modelOutputs: [MLShapedArray] = [] 41 | 42 | /// Create a scheduler that uses a second order DPM-Solver++ algorithm. 43 | /// 44 | /// - Parameters: 45 | /// - stepCount: Number of inference steps to schedule 46 | /// - trainStepCount: Number of training diffusion steps 47 | /// - betaSchedule: Method to schedule betas from betaStart to betaEnd 48 | /// - betaStart: The starting value of beta for inference 49 | /// - betaEnd: The end value for beta for inference 50 | /// - Returns: A scheduler ready for its first step 51 | public init( 52 | stepCount: Int = 50, 53 | trainStepCount: Int = 1000, 54 | betaSchedule: BetaSchedule = .scaledLinear, 55 | betaStart: Float = 0.00085, 56 | betaEnd: Float = 0.012 57 | ) { 58 | self.trainStepCount = trainStepCount 59 | self.inferenceStepCount = stepCount 60 | 61 | switch betaSchedule { 62 | case .linear: 63 | self.betas = linspace(betaStart, betaEnd, trainStepCount) 64 | case .scaledLinear: 65 | self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 }) 66 | } 67 | 68 | self.alphas = betas.map({ 1.0 - $0 }) 69 | var alphasCumProd = self.alphas 70 | for i in 1.., 
timestep: Int, sample: MLShapedArray) -> MLShapedArray { 86 | assert(modelOutput.scalars.count == sample.scalars.count) 87 | let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep]) 88 | 89 | // This could be optimized with a Metal kernel if we find we need to 90 | let x0_scalars = zip(modelOutput.scalars, sample.scalars).map { m, s in 91 | (s - m * sigma_t) / alpha_t 92 | } 93 | return MLShapedArray(scalars: x0_scalars, shape: modelOutput.shape) 94 | } 95 | 96 | /// One step for the first-order DPM-Solver (equivalent to DDIM). 97 | /// See https://arxiv.org/abs/2206.00927 for the detailed derivation. 98 | /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py 99 | func firstOrderUpdate( 100 | modelOutput: MLShapedArray, 101 | timestep: Int, 102 | prevTimestep: Int, 103 | sample: MLShapedArray 104 | ) -> MLShapedArray { 105 | let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep])) 106 | let p_alpha_t = Double(alpha_t[prevTimestep]) 107 | let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep])) 108 | let h = p_lambda_t - lambda_s 109 | // x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output 110 | let x_t = weightedSum( 111 | [p_sigma_t / sigma_s, -p_alpha_t * (exp(-h) - 1)], 112 | [sample, modelOutput] 113 | ) 114 | return x_t 115 | } 116 | 117 | /// One step for the second-order multistep DPM-Solver++ algorithm, using the midpoint method. 
118 | /// var names and code structure mostly follow https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py 119 | func secondOrderUpdate( 120 | modelOutputs: [MLShapedArray], 121 | timesteps: [Int], 122 | prevTimestep t: Int, 123 | sample: MLShapedArray 124 | ) -> MLShapedArray { 125 | let (s0, s1) = (timesteps[back: 1], timesteps[back: 2]) 126 | let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2]) 127 | let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1])) 128 | let p_alpha_t = Double(alpha_t[t]) 129 | let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0])) 130 | let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1) 131 | let r0 = h_0 / h 132 | let D0 = m0 133 | 134 | // D1 = (1.0 / r0) * (m0 - m1) 135 | let D1 = weightedSum( 136 | [1/r0, -1/r0], 137 | [m0, m1] 138 | ) 139 | 140 | // See https://arxiv.org/abs/2211.01095 for detailed derivations 141 | // x_t = ( 142 | // (sigma_t / sigma_s0) * sample 143 | // - (alpha_t * (torch.exp(-h) - 1.0)) * D0 144 | // - 0.5 * (alpha_t * (torch.exp(-h) - 1.0)) * D1 145 | // ) 146 | let x_t = weightedSum( 147 | [p_sigma_t/sigma_s0, -p_alpha_t * (exp(-h) - 1), -0.5 * p_alpha_t * (exp(-h) - 1)], 148 | [sample, D0, D1] 149 | ) 150 | return x_t 151 | } 152 | 153 | public func step(output: MLShapedArray, timeStep t: Int, sample: MLShapedArray) -> MLShapedArray { 154 | let stepIndex = timeSteps.firstIndex(of: t) ?? timeSteps.count - 1 155 | let prevTimestep = stepIndex == timeSteps.count - 1 ? 
0 : timeSteps[stepIndex + 1] 156 | 157 | let lowerOrderFinal = useLowerOrderFinal && stepIndex == timeSteps.count - 1 && timeSteps.count < 15 158 | let lowerOrderSecond = useLowerOrderFinal && stepIndex == timeSteps.count - 2 && timeSteps.count < 15 159 | let lowerOrder = lowerOrderStepped < 1 || lowerOrderFinal || lowerOrderSecond 160 | 161 | let modelOutput = convertModelOutput(modelOutput: output, timestep: t, sample: sample) 162 | if modelOutputs.count == solverOrder { modelOutputs.removeFirst() } 163 | modelOutputs.append(modelOutput) 164 | 165 | let prevSample: MLShapedArray 166 | if lowerOrder { 167 | prevSample = firstOrderUpdate(modelOutput: modelOutput, timestep: t, prevTimestep: prevTimestep, sample: sample) 168 | } else { 169 | prevSample = secondOrderUpdate( 170 | modelOutputs: modelOutputs, 171 | timesteps: [timeSteps[stepIndex - 1], t], 172 | prevTimestep: prevTimestep, 173 | sample: sample 174 | ) 175 | } 176 | if lowerOrderStepped < solverOrder { 177 | lowerOrderStepped += 1 178 | } 179 | 180 | return prevSample 181 | } 182 | } 183 | -------------------------------------------------------------------------------- /CoreML/pipeline/Decoder.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 
3 | 4 | import Foundation 5 | import CoreML 6 | import Accelerate 7 | 8 | /// A decoder model which produces RGB images from latent samples 9 | @available(iOS 16.2, macOS 13.1, *) 10 | public struct Decoder: ResourceManaging { 11 | 12 | /// VAE decoder model 13 | var model: ManagedMLModel 14 | 15 | /// Create decoder from Core ML model 16 | /// 17 | /// - Parameters: 18 | /// - url: Location of compiled VAE decoder Core ML model 19 | /// - configuration: configuration to be used when the model is loaded 20 | /// - Returns: A decoder that will lazily load its required resources when needed or requested 21 | public init(modelAt url: URL, configuration: MLModelConfiguration) { 22 | self.model = ManagedMLModel(modelAt: url, configuration: configuration) 23 | } 24 | 25 | /// Ensure the model has been loaded into memory 26 | public func loadResources() throws { 27 | try model.loadResources() 28 | } 29 | 30 | /// Unload the underlying model to free up memory 31 | public func unloadResources() { 32 | model.unloadResources() 33 | } 34 | 35 | /// Batch decode latent samples into images 36 | /// 37 | /// - Parameters: 38 | /// - latents: Batch of latent samples to decode 39 | /// - Returns: decoded images 40 | public func decode(_ latents: [MLShapedArray]) throws -> [CGImage] { 41 | 42 | // Form batch inputs for model 43 | let inputs: [MLFeatureProvider] = try latents.map { sample in 44 | // Reference pipeline scales the latent samples before decoding 45 | let sampleScaled = MLShapedArray( 46 | scalars: sample.scalars.map { $0 / 0.18215 }, 47 | shape: sample.shape) 48 | 49 | let dict = [inputName: MLMultiArray(sampleScaled)] 50 | return try MLDictionaryFeatureProvider(dictionary: dict) 51 | } 52 | let batch = MLArrayBatchProvider(array: inputs) 53 | 54 | // Batch predict with model 55 | let results = try model.perform { model in 56 | try model.predictions(fromBatch: batch) 57 | } 58 | 59 | // Transform the outputs to CGImages 60 | let images: [CGImage] = (0..(output)) 66 | } 
67 | 68 | return images 69 | } 70 | 71 | var inputName: String { 72 | try! model.perform { model in 73 | model.modelDescription.inputDescriptionsByName.first!.key 74 | } 75 | } 76 | 77 | typealias PixelBufferPFx1 = vImage.PixelBuffer 78 | typealias PixelBufferP8x3 = vImage.PixelBuffer 79 | typealias PixelBufferIFx3 = vImage.PixelBuffer 80 | typealias PixelBufferI8x3 = vImage.PixelBuffer 81 | 82 | func toRGBCGImage(_ array: MLShapedArray) -> CGImage { 83 | 84 | // array is [N,C,H,W], where C==3 85 | let channelCount = array.shape[1] 86 | assert(channelCount == 3, 87 | "Decoding model output has \(channelCount) channels, expected 3") 88 | let height = array.shape[2] 89 | let width = array.shape[3] 90 | 91 | // Normalize each channel into a float between 0 and 1.0 92 | let floatChannels = (0.. [0.0 1.0] 103 | cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut) 104 | } 105 | return cOut 106 | } 107 | 108 | // Convert to interleaved and then to UInt8 109 | let floatImage = PixelBufferIFx3(planarBuffers: floatChannels) 110 | let uint8Image = PixelBufferI8x3(width: width, height: height) 111 | floatImage.convert(to:uint8Image) // maps [0.0 1.0] -> [0 255] and clips 112 | 113 | // Convert to uint8x3 to RGB CGImage (no alpha) 114 | let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue) 115 | let cgImage = uint8Image.makeCGImage(cgImageFormat: 116 | .init(bitsPerComponent: 8, 117 | bitsPerPixel: 3*8, 118 | colorSpace: CGColorSpaceCreateDeviceRGB(), 119 | bitmapInfo: bitmapInfo)!)! 120 | 121 | return cgImage 122 | } 123 | } 124 | -------------------------------------------------------------------------------- /CoreML/pipeline/ManagedMLModel.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 
3 | 4 | import CoreML 5 | 6 | /// A class to manage and gate access to a Core ML model 7 | /// 8 | /// It will automatically load a model into memory when needed or requested 9 | /// It allows one to request to unload the model from memory 10 | @available(iOS 16.2, macOS 13.1, *) 11 | public final class ManagedMLModel: ResourceManaging { 12 | 13 | /// The location of the model 14 | var modelURL: URL 15 | 16 | /// The configuration to be used when the model is loaded 17 | var configuration: MLModelConfiguration 18 | 19 | /// The loaded model (when loaded) 20 | var loadedModel: MLModel? 21 | 22 | /// Queue to protect access to loaded model 23 | var queue: DispatchQueue 24 | 25 | /// Create a managed model given its location and desired loaded configuration 26 | /// 27 | /// - Parameters: 28 | /// - url: The location of the model 29 | /// - configuration: The configuration to be used when the model is loaded/used 30 | /// - Returns: A managed model that has not been loaded 31 | public init(modelAt url: URL, configuration: MLModelConfiguration) { 32 | self.modelURL = url 33 | self.configuration = configuration 34 | self.loadedModel = nil 35 | self.queue = DispatchQueue(label: "managed.\(url.lastPathComponent)") 36 | } 37 | 38 | /// Instantiation and load model into memory 39 | public func loadResources() throws { 40 | try queue.sync { 41 | try loadModel() 42 | } 43 | } 44 | 45 | /// Unload the model if it was loaded 46 | public func unloadResources() { 47 | queue.sync { 48 | loadedModel = nil 49 | } 50 | } 51 | 52 | /// Perform an operation with the managed model via a supplied closure. 
53 | /// The model will be loaded and supplied to the closure and should only be 54 | /// used within the closure to ensure all resource management is synchronized 55 | /// 56 | /// - Parameters: 57 | /// - body: Closure which performs and action on a loaded model 58 | /// - Returns: The result of the closure 59 | /// - Throws: An error if the model cannot be loaded or if the closure throws 60 | public func perform(_ body: (MLModel) throws -> R) throws -> R { 61 | return try queue.sync { 62 | try autoreleasepool { 63 | try loadModel() 64 | return try body(loadedModel!) 65 | } 66 | } 67 | } 68 | 69 | private func loadModel() throws { 70 | if loadedModel == nil { 71 | loadedModel = try MLModel(contentsOf: modelURL, 72 | configuration: configuration) 73 | } 74 | } 75 | 76 | 77 | } 78 | -------------------------------------------------------------------------------- /CoreML/pipeline/Random.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 3 | 4 | import Foundation 5 | import CoreML 6 | 7 | /// A random source consistent with NumPy 8 | /// 9 | /// This implementation matches: 10 | /// [NumPy's older randomkit.c](https://github.com/numpy/numpy/blob/v1.0/numpy/random/mtrand/randomkit.c) 11 | /// 12 | @available(iOS 16.2, macOS 13.1, *) 13 | struct NumPyRandomSource: RandomNumberGenerator { 14 | 15 | struct State { 16 | var key = [UInt32](repeating: 0, count: 624) 17 | var pos: Int = 0 18 | var nextGauss: Double? 
= nil 19 | } 20 | 21 | var state: State 22 | 23 | /// Initialize with a random seed 24 | /// 25 | /// - Parameters 26 | /// - seed: Seed for underlying Mersenne Twister 19937 generator 27 | /// - Returns random source 28 | init(seed: UInt32) { 29 | state = .init() 30 | var s = seed & 0xffffffff 31 | for i in 0 ..< state.key.count { 32 | state.key[i] = s 33 | s = UInt32((UInt64(1812433253) * UInt64(s ^ (s >> 30)) + UInt64(i) + 1) & 0xffffffff) 34 | } 35 | state.pos = state.key.count 36 | state.nextGauss = nil 37 | } 38 | 39 | /// Generate next UInt32 using fast 32bit Mersenne Twister 40 | mutating func nextUInt32() -> UInt32 { 41 | let n = 624 42 | let m = 397 43 | let matrixA: UInt64 = 0x9908b0df 44 | let upperMask: UInt32 = 0x80000000 45 | let lowerMask: UInt32 = 0x7fffffff 46 | 47 | var y: UInt32 48 | if state.pos == state.key.count { 49 | for i in 0 ..< (n - m) { 50 | y = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask) 51 | state.key[i] = state.key[i + m] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA) 52 | } 53 | for i in (n - m) ..< (n - 1) { 54 | y = (state.key[i] & upperMask) | (state.key[i + 1] & lowerMask) 55 | state.key[i] = state.key[i + (m - n)] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA) 56 | } 57 | y = (state.key[n - 1] & upperMask) | (state.key[0] & lowerMask) 58 | state.key[n - 1] = state.key[m - 1] ^ (y >> 1) ^ UInt32((UInt64(~(y & 1)) + 1) & matrixA) 59 | state.pos = 0 60 | } 61 | y = state.key[state.pos] 62 | state.pos += 1 63 | 64 | y ^= (y >> 11) 65 | y ^= (y << 7) & 0x9d2c5680 66 | y ^= (y << 15) & 0xefc60000 67 | y ^= (y >> 18) 68 | 69 | return y 70 | } 71 | 72 | mutating func next() -> UInt64 { 73 | let low = nextUInt32() 74 | let high = nextUInt32() 75 | return (UInt64(high) << 32) | UInt64(low) 76 | } 77 | 78 | /// Generate next random double value 79 | mutating func nextDouble() -> Double { 80 | let a = Double(nextUInt32() >> 5) 81 | let b = Double(nextUInt32() >> 6) 82 | return (a * 67108864.0 + b) / 
9007199254740992.0 83 | } 84 | 85 | /// Generate next random value from a standard normal 86 | mutating func nextGauss() -> Double { 87 | if let nextGauss = state.nextGauss { 88 | state.nextGauss = nil 89 | return nextGauss 90 | } 91 | var x1, x2, r2: Double 92 | repeat { 93 | x1 = 2.0 * nextDouble() - 1.0 94 | x2 = 2.0 * nextDouble() - 1.0 95 | r2 = x1 * x1 + x2 * x2 96 | } while r2 >= 1.0 || r2 == 0.0 97 | 98 | // Box-Muller transform 99 | let f = sqrt(-2.0 * log(r2) / r2) 100 | state.nextGauss = f * x1 101 | return f * x2 102 | } 103 | 104 | /// Generates a random value from a normal distribution with given mean and standard deviation. 105 | mutating func nextNormal(mean: Double = 0.0, stdev: Double = 1.0) -> Double { 106 | nextGauss() * stdev + mean 107 | } 108 | 109 | /// Generates an array of random values from a normal distribution with given mean and standard deviation. 110 | mutating func normalArray(count: Int, mean: Double = 0.0, stdev: Double = 1.0) -> [Double] { 111 | (0 ..< count).map { _ in nextNormal(mean: mean, stdev: stdev) } 112 | } 113 | 114 | /// Generate a shaped array with scalars from a normal distribution with given mean and standard deviation. 115 | mutating func normalShapedArray(_ shape: [Int], mean: Double = 0.0, stdev: Double = 1.0) -> MLShapedArray { 116 | let count = shape.reduce(1, *) 117 | return .init(scalars: normalArray(count: count, mean: mean, stdev: stdev), shape: shape) 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /CoreML/pipeline/ResourceManaging.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 
3 | 4 | /// Protocol for managing internal resources 5 | public protocol ResourceManaging { 6 | 7 | /// Request resources to be loaded and ready if possible 8 | func loadResources() throws 9 | 10 | /// Request resources are unloaded / remove from memory if possible 11 | func unloadResources() 12 | } 13 | 14 | extension ResourceManaging { 15 | /// Request resources are pre-warmed by loading and unloading 16 | func prewarmResources() throws { 17 | try loadResources() 18 | unloadResources() 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /CoreML/pipeline/SafetyChecker.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 3 | 4 | import Foundation 5 | import CoreML 6 | import Accelerate 7 | 8 | /// Image safety checking model 9 | @available(iOS 16.2, macOS 13.1, *) 10 | public struct SafetyChecker: ResourceManaging { 11 | 12 | /// Safety checking Core ML model 13 | var model: ManagedMLModel 14 | 15 | /// Creates safety checker 16 | /// 17 | /// - Parameters: 18 | /// - url: Location of compiled safety checking Core ML model 19 | /// - configuration: configuration to be used when the model is loaded 20 | /// - Returns: A safety cherker that will lazily load its required resources when needed or requested 21 | public init(modelAt url: URL, configuration: MLModelConfiguration) { 22 | self.model = ManagedMLModel(modelAt: url, configuration: configuration) 23 | } 24 | 25 | /// Ensure the model has been loaded into memory 26 | public func loadResources() throws { 27 | try model.loadResources() 28 | } 29 | 30 | /// Unload the underlying model to free up memory 31 | public func unloadResources() { 32 | model.unloadResources() 33 | } 34 | 35 | typealias PixelBufferPFx1 = vImage.PixelBuffer 36 | typealias PixelBufferP8x1 = vImage.PixelBuffer 37 | typealias PixelBufferPFx3 = 
vImage.PixelBuffer 38 | typealias PixelBufferP8x3 = vImage.PixelBuffer 39 | typealias PixelBufferIFx3 = vImage.PixelBuffer 40 | typealias PixelBufferI8x3 = vImage.PixelBuffer 41 | typealias PixelBufferI8x4 = vImage.PixelBuffer 42 | 43 | enum SafetyCheckError: Error { 44 | case imageResizeFailure 45 | case imageToFloatFailure 46 | case modelInputFailure 47 | case unexpectedModelOutput 48 | } 49 | 50 | /// Check if image is safe 51 | /// 52 | /// - Parameters: 53 | /// - image: Image to check 54 | /// - Returns: Whether the model considers the image to be safe 55 | public func isSafe(_ image: CGImage) throws -> Bool { 56 | 57 | let inputName = "clip_input" 58 | let adjustmentName = "adjustment" 59 | let imagesNames = "images" 60 | 61 | let inputInfo = try model.perform { model in 62 | model.modelDescription.inputDescriptionsByName 63 | } 64 | let inputShape = inputInfo[inputName]!.multiArrayConstraint!.shape 65 | 66 | let width = inputShape[2].intValue 67 | let height = inputShape[3].intValue 68 | 69 | let resizedImage = try resizeToRGBA(image, width: width, height: height) 70 | 71 | let bufferP8x3 = try getRGBPlanes(of: resizedImage) 72 | 73 | let arrayPFx3 = normalizeToFloatShapedArray(bufferP8x3) 74 | 75 | guard let input = try? 
MLDictionaryFeatureProvider( 76 | dictionary:[ 77 | // Input that is analyzed for safety 78 | inputName : MLMultiArray(arrayPFx3), 79 | // No adjustment, use default threshold 80 | adjustmentName : MLMultiArray(MLShapedArray(scalars: [0], shape: [1])), 81 | // Supplying dummy images to be filtered (will be ignored) 82 | imagesNames : MLMultiArray(shape:[1, 512, 512, 3], dataType: .float16) 83 | ] 84 | ) else { 85 | throw SafetyCheckError.modelInputFailure 86 | } 87 | 88 | let result = try model.perform { model in 89 | try model.prediction(from: input) 90 | } 91 | 92 | let output = result.featureValue(for: "has_nsfw_concepts") 93 | 94 | guard let unsafe = output?.multiArrayValue?[0].boolValue else { 95 | throw SafetyCheckError.unexpectedModelOutput 96 | } 97 | 98 | return !unsafe 99 | } 100 | 101 | func resizeToRGBA(_ image: CGImage, 102 | width: Int, height: Int) throws -> CGImage { 103 | 104 | guard let context = CGContext( 105 | data: nil, 106 | width: width, 107 | height: height, 108 | bitsPerComponent: 8, 109 | bytesPerRow: width*4, 110 | space: CGColorSpaceCreateDeviceRGB(), 111 | bitmapInfo: CGImageAlphaInfo.noneSkipLast.rawValue) else { 112 | throw SafetyCheckError.imageResizeFailure 113 | } 114 | 115 | context.interpolationQuality = .high 116 | context.draw(image, in: CGRect(x: 0, y: 0, width: width, height: height)) 117 | guard let resizedImage = context.makeImage() else { 118 | throw SafetyCheckError.imageResizeFailure 119 | } 120 | 121 | return resizedImage 122 | } 123 | 124 | func getRGBPlanes(of rgbaImage: CGImage) throws -> PixelBufferP8x3 { 125 | // Reference as interleaved 8 bit vImage PixelBuffer 126 | var emptyFormat = vImage_CGImageFormat() 127 | guard let bufferI8x4 = try? 
PixelBufferI8x4( 128 | cgImage: rgbaImage, 129 | cgImageFormat:&emptyFormat) else { 130 | throw SafetyCheckError.imageToFloatFailure 131 | } 132 | 133 | // Drop the alpha channel, keeping RGB 134 | let bufferI8x3 = PixelBufferI8x3(width: rgbaImage.width, height:rgbaImage.height) 135 | bufferI8x4.convert(to: bufferI8x3, channelOrdering: .RGBA) 136 | 137 | // De-interleave into 8-bit planes 138 | return PixelBufferP8x3(interleavedBuffer: bufferI8x3) 139 | } 140 | 141 | func normalizeToFloatShapedArray(_ bufferP8x3: PixelBufferP8x3) -> MLShapedArray { 142 | let width = bufferP8x3.width 143 | let height = bufferP8x3.height 144 | 145 | let means = [0.485, 0.456, 0.406] as [Float] 146 | let stds = [0.229, 0.224, 0.225] as [Float] 147 | 148 | // Convert to normalized float 1x3xWxH input (plannar) 149 | let arrayPFx3 = MLShapedArray(repeating: 0.0, shape: [1, 3, width, height]) 150 | for c in 0..<3 { 151 | arrayPFx3[0][c].withUnsafeShapedBufferPointer { ptr, _, strides in 152 | let floatChannel = PixelBufferPFx1(data: .init(mutating: ptr.baseAddress!), 153 | width: width, height: height, 154 | byteCountPerRow: strides[0]*4) 155 | 156 | bufferP8x3.withUnsafePixelBuffer(at: c) { uint8Channel in 157 | uint8Channel.convert(to: floatChannel) // maps [0 255] -> [0 1] 158 | floatChannel.multiply(by: 1.0/stds[c], 159 | preBias: -means[c], 160 | postBias: 0.0, 161 | destination: floatChannel) 162 | } 163 | } 164 | } 165 | return arrayPFx3 166 | } 167 | } 168 | -------------------------------------------------------------------------------- /CoreML/pipeline/Scheduler.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 
3 | 4 | import CoreML 5 | 6 | @available(iOS 16.2, macOS 13.1, *) 7 | public protocol Scheduler { 8 | /// Number of diffusion steps performed during training 9 | var trainStepCount: Int { get } 10 | 11 | /// Number of inference steps to be performed 12 | var inferenceStepCount: Int { get } 13 | 14 | /// Training diffusion time steps index by inference time step 15 | var timeSteps: [Int] { get } 16 | 17 | /// Schedule of betas which controls the amount of noise added at each timestep 18 | var betas: [Float] { get } 19 | 20 | /// 1 - betas 21 | var alphas: [Float] { get } 22 | 23 | /// Cached cumulative product of alphas 24 | var alphasCumProd: [Float] { get } 25 | 26 | /// Standard deviation of the initial noise distribution 27 | var initNoiseSigma: Float { get } 28 | 29 | /// Compute a de-noised image sample and step scheduler state 30 | /// 31 | /// - Parameters: 32 | /// - output: The predicted residual noise output of learned diffusion model 33 | /// - timeStep: The current time step in the diffusion chain 34 | /// - sample: The current input sample to the diffusion model 35 | /// - Returns: Predicted de-noised sample at the previous time step 36 | /// - Postcondition: The scheduler state is updated. 
37 | /// The state holds the current sample and history of model output noise residuals 38 | func step( 39 | output: MLShapedArray, 40 | timeStep t: Int, 41 | sample s: MLShapedArray 42 | ) -> MLShapedArray 43 | } 44 | 45 | @available(iOS 16.2, macOS 13.1, *) 46 | public extension Scheduler { 47 | var initNoiseSigma: Float { 1 } 48 | } 49 | 50 | @available(iOS 16.2, macOS 13.1, *) 51 | public extension Scheduler { 52 | /// Compute weighted sum of shaped arrays of equal shapes 53 | /// 54 | /// - Parameters: 55 | /// - weights: The weights each array is multiplied by 56 | /// - values: The arrays to be weighted and summed 57 | /// - Returns: sum_i weights[i]*values[i] 58 | func weightedSum(_ weights: [Double], _ values: [MLShapedArray]) -> MLShapedArray { 59 | assert(weights.count > 1 && values.count == weights.count) 60 | assert(values.allSatisfy({ $0.scalarCount == values.first!.scalarCount })) 61 | var w = Float(weights.first!) 62 | var scalars = values.first!.scalars.map({ $0 * w }) 63 | for next in 1 ..< values.count { 64 | w = Float(weights[next]) 65 | let nextScalars = values[next].scalars 66 | for i in 0 ..< scalars.count { 67 | scalars[i] += w * nextScalars[i] 68 | } 69 | } 70 | return MLShapedArray(scalars: scalars, shape: values.first!.shape) 71 | } 72 | } 73 | 74 | /// How to map a beta range to a sequence of betas to step over 75 | @available(iOS 16.2, macOS 13.1, *) 76 | public enum BetaSchedule { 77 | /// Linear stepping between start and end 78 | case linear 79 | /// Steps using linspace(sqrt(start),sqrt(end))^2 80 | case scaledLinear 81 | } 82 | 83 | 84 | /// A scheduler used to compute a de-noised image 85 | /// 86 | /// This implementation matches: 87 | /// [Hugging Face Diffusers PNDMScheduler](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_pndm.py) 88 | /// 89 | /// This scheduler uses the pseudo linear multi-step (PLMS) method only, skipping pseudo Runge-Kutta (PRK) steps 90 | @available(iOS 16.2, macOS 
13.1, *) 91 | public final class PNDMScheduler: Scheduler { 92 | public let trainStepCount: Int 93 | public let inferenceStepCount: Int 94 | public let betas: [Float] 95 | public let alphas: [Float] 96 | public let alphasCumProd: [Float] 97 | public let timeSteps: [Int] 98 | 99 | // Internal state 100 | var counter: Int 101 | var ets: [MLShapedArray] 102 | var currentSample: MLShapedArray? 103 | 104 | /// Create a scheduler that uses a pseudo linear multi-step (PLMS) method 105 | /// 106 | /// - Parameters: 107 | /// - stepCount: Number of inference steps to schedule 108 | /// - trainStepCount: Number of training diffusion steps 109 | /// - betaSchedule: Method to schedule betas from betaStart to betaEnd 110 | /// - betaStart: The starting value of beta for inference 111 | /// - betaEnd: The end value for beta for inference 112 | /// - Returns: A scheduler ready for its first step 113 | public init( 114 | stepCount: Int = 50, 115 | trainStepCount: Int = 1000, 116 | betaSchedule: BetaSchedule = .scaledLinear, 117 | betaStart: Float = 0.00085, 118 | betaEnd: Float = 0.012 119 | ) { 120 | self.trainStepCount = trainStepCount 121 | self.inferenceStepCount = stepCount 122 | 123 | switch betaSchedule { 124 | case .linear: 125 | self.betas = linspace(betaStart, betaEnd, trainStepCount) 126 | case .scaledLinear: 127 | self.betas = linspace(pow(betaStart, 0.5), pow(betaEnd, 0.5), trainStepCount).map({ $0 * $0 }) 128 | } 129 | self.alphas = betas.map({ 1.0 - $0 }) 130 | var alphasCumProd = self.alphas 131 | for i in 1.., 164 | timeStep t: Int, 165 | sample s: MLShapedArray 166 | ) -> MLShapedArray { 167 | 168 | var timeStep = t 169 | let stepInc = (trainStepCount / inferenceStepCount) 170 | var prevStep = timeStep - stepInc 171 | var modelOutput = output 172 | var sample = s 173 | 174 | if counter != 1 { 175 | if ets.count > 3 { 176 | ets = Array(ets[(ets.count - 3).., 226 | _ timeStep: Int, 227 | _ prevStep: Int, 228 | _ modelOutput: MLShapedArray 229 | ) -> MLShapedArray { 
230 | 231 | // Compute x_(t−δ) using formula (9) from 232 | // "Pseudo Numerical Methods for Diffusion Models on Manifolds", 233 | // Luping Liu, Yi Ren, Zhijie Lin & Zhou Zhao. 234 | // ICLR 2022 235 | // 236 | // Notation: 237 | // 238 | // alphaProdt α_t 239 | // alphaProdtPrev α_(t−δ) 240 | // betaProdt (1 - α_t) 241 | // betaProdtPrev (1 - α_(t−δ)) 242 | let alphaProdt = alphasCumProd[timeStep] 243 | let alphaProdtPrev = alphasCumProd[max(0,prevStep)] 244 | let betaProdt = 1 - alphaProdt 245 | let betaProdtPrev = 1 - alphaProdtPrev 246 | 247 | // sampleCoeff = (α_(t−δ) - α_t) divided by 248 | // denominator of x_t in formula (9) and plus 1 249 | // Note: (α_(t−δ) - α_t) / (sqrt(α_t) * (sqrt(α_(t−δ)) + sqr(α_t))) = 250 | // sqrt(α_(t−δ)) / sqrt(α_t)) 251 | let sampleCoeff = sqrt(alphaProdtPrev / alphaProdt) 252 | 253 | // Denominator of e_θ(x_t, t) in formula (9) 254 | let modelOutputDenomCoeff = alphaProdt * sqrt(betaProdtPrev) 255 | + sqrt(alphaProdt * betaProdt * alphaProdtPrev) 256 | 257 | // full formula (9) 258 | let modelCoeff = -(alphaProdtPrev - alphaProdt)/modelOutputDenomCoeff 259 | let prevSample = weightedSum( 260 | [Double(sampleCoeff), Double(modelCoeff)], 261 | [sample, modelOutput] 262 | ) 263 | 264 | return prevSample 265 | } 266 | } 267 | 268 | /// Evenly spaced floats between specified interval 269 | /// 270 | /// - Parameters: 271 | /// - start: Start of the interval 272 | /// - end: End of the interval 273 | /// - count: The number of floats to return between [*start*, *end*] 274 | /// - Returns: Float array with *count* elements evenly spaced between at *start* and *end* 275 | func linspace(_ start: Float, _ end: Float, _ count: Int) -> [Float] { 276 | let scale = (end - start) / Float(count - 1) 277 | return (0.. 
Element { 283 | return self[index(endIndex, offsetBy: -i)] 284 | } 285 | } 286 | -------------------------------------------------------------------------------- /CoreML/pipeline/StableDiffusionPipeline+Resources.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 3 | 4 | import Foundation 5 | import CoreML 6 | 7 | @available(iOS 16.2, macOS 13.1, *) 8 | public extension StableDiffusionPipeline { 9 | 10 | struct ResourceURLs { 11 | 12 | public let textEncoderURL: URL 13 | public let unetURL: URL 14 | public let unetChunk1URL: URL 15 | public let unetChunk2URL: URL 16 | public let decoderURL: URL 17 | public let safetyCheckerURL: URL 18 | public let vocabURL: URL 19 | public let mergesURL: URL 20 | 21 | public init(resourcesAt baseURL: URL) { 22 | textEncoderURL = baseURL.appending(path: "TextEncoder.mlmodelc") 23 | unetURL = baseURL.appending(path: "Unet.mlmodelc") 24 | unetChunk1URL = baseURL.appending(path: "UnetChunk1.mlmodelc") 25 | unetChunk2URL = baseURL.appending(path: "UnetChunk2.mlmodelc") 26 | decoderURL = baseURL.appending(path: "VAEDecoder.mlmodelc") 27 | safetyCheckerURL = baseURL.appending(path: "SafetyChecker.mlmodelc") 28 | vocabURL = baseURL.appending(path: "vocab.json") 29 | mergesURL = baseURL.appending(path: "merges.txt") 30 | } 31 | } 32 | 33 | /// Create stable diffusion pipeline using model resources at a 34 | /// specified URL 35 | /// 36 | /// - Parameters: 37 | /// - baseURL: URL pointing to directory holding all model 38 | /// and tokenization resources 39 | /// - configuration: The configuration to load model resources with 40 | /// - disableSafety: Load time disable of safety to save memory 41 | /// - reduceMemory: Setup pipeline in reduced memory mode 42 | /// - Returns: 43 | /// Pipeline ready for image generation if all necessary resources loaded 44 | init(resourcesAt baseURL: URL, 45 | 
configuration config: MLModelConfiguration = .init(), 46 | disableSafety: Bool = false, 47 | reduceMemory: Bool = false) throws { 48 | 49 | /// Expect URL of each resource 50 | let urls = ResourceURLs(resourcesAt: baseURL) 51 | 52 | // Text tokenizer and encoder 53 | let tokenizer = try BPETokenizer(mergesAt: urls.mergesURL, vocabularyAt: urls.vocabURL) 54 | let textEncoder = TextEncoder(tokenizer: tokenizer, 55 | modelAt: urls.textEncoderURL, 56 | configuration: config) 57 | 58 | // Unet model 59 | let unet: Unet 60 | if FileManager.default.fileExists(atPath: urls.unetChunk1URL.path) && 61 | FileManager.default.fileExists(atPath: urls.unetChunk2URL.path) { 62 | unet = Unet(chunksAt: [urls.unetChunk1URL, urls.unetChunk2URL], 63 | configuration: config) 64 | } else { 65 | unet = Unet(modelAt: urls.unetURL, configuration: config) 66 | } 67 | 68 | // Image Decoder 69 | let decoder = Decoder(modelAt: urls.decoderURL, configuration: config) 70 | 71 | // Optional safety checker 72 | var safetyChecker: SafetyChecker? = nil 73 | if !disableSafety && 74 | FileManager.default.fileExists(atPath: urls.safetyCheckerURL.path) { 75 | safetyChecker = SafetyChecker(modelAt: urls.safetyCheckerURL, configuration: config) 76 | } 77 | 78 | // Construct pipeline 79 | self.init(textEncoder: textEncoder, 80 | unet: unet, 81 | decoder: decoder, 82 | safetyChecker: safetyChecker, 83 | reduceMemory: reduceMemory) 84 | } 85 | } 86 | -------------------------------------------------------------------------------- /CoreML/pipeline/StableDiffusionPipeline.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 
3 | 4 | import Foundation 5 | import CoreML 6 | import Accelerate 7 | import CoreGraphics 8 | 9 | /// Schedulers compatible with StableDiffusionPipeline 10 | public enum StableDiffusionScheduler: String, CaseIterable { 11 | /// Scheduler that uses a pseudo-linear multi-step (PLMS) method 12 | case pndm = "PNDM" 13 | /// Scheduler that uses a second order DPM-Solver++ algorithm 14 | case dpmpp = "DPMPP" 15 | } 16 | 17 | /// A pipeline used to generate image samples from text input using stable diffusion 18 | /// 19 | /// This implementation matches: 20 | /// [Hugging Face Diffusers Pipeline](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py) 21 | @available(iOS 16.2, macOS 13.1, *) 22 | public struct StableDiffusionPipeline: ResourceManaging { 23 | /// Model to generate embeddings for tokenized input text 24 | var textEncoder: TextEncoder 25 | 26 | /// Model used to predict noise residuals given an input, diffusion time step, and conditional embedding 27 | var unet: Unet 28 | 29 | /// Model used to generate final image from latent diffusion process 30 | var decoder: Decoder 31 | 32 | /// Optional model for checking safety of generated image 33 | var safetyChecker: SafetyChecker? 
= nil 34 | 35 | /// Reports whether this pipeline can perform safety checks 36 | public var canSafetyCheck: Bool { 37 | safetyChecker != nil 38 | } 39 | 40 | /// Option to reduce memory during image generation 41 | /// 42 | /// If true, the pipeline will lazily load TextEncoder, Unet, Decoder, and SafetyChecker 43 | /// when needed and aggressively unload their resources after 44 | /// 45 | /// This will increase latency in favor of reducing memory 46 | var reduceMemory: Bool = false 47 | 48 | /// Creates a pipeline using the specified models and tokenizer 49 | /// 50 | /// - Parameters: 51 | /// - textEncoder: Model for encoding tokenized text 52 | /// - unet: Model for noise prediction on latent samples 53 | /// - decoder: Model for decoding latent sample to image 54 | /// - safetyChecker: Optional model for checking safety of generated images 55 | /// - guidanceScale: Influence of the text prompt on generation process (0=random images) 56 | /// - reduceMemory: Option to enable reduced memory mode 57 | /// - Returns: Pipeline ready for image generation 58 | public init(textEncoder: TextEncoder, unet: Unet, decoder: Decoder, safetyChecker: SafetyChecker? 
= nil, reduceMemory: Bool = false) { 59 | self.textEncoder = textEncoder 60 | self.unet = unet 61 | self.decoder = decoder 62 | self.safetyChecker = safetyChecker 63 | self.reduceMemory = reduceMemory 64 | } 65 | 66 | /// Load required resources for this pipeline 67 | /// 68 | /// If reducedMemory is true this will instead call prewarmResources instead 69 | /// and let the pipeline lazily load resources as needed 70 | public func loadResources() throws { 71 | if reduceMemory { 72 | try prewarmResources() 73 | } else { 74 | try textEncoder.loadResources() 75 | try unet.loadResources() 76 | try decoder.loadResources() 77 | try safetyChecker?.loadResources() 78 | } 79 | } 80 | 81 | /// Unload the underlying resources to free up memory 82 | public func unloadResources() { 83 | textEncoder.unloadResources() 84 | unet.unloadResources() 85 | decoder.unloadResources() 86 | safetyChecker?.unloadResources() 87 | } 88 | 89 | // Prewarm resources one at a time 90 | public func prewarmResources() throws { 91 | try textEncoder.prewarmResources() 92 | try unet.prewarmResources() 93 | try decoder.prewarmResources() 94 | try safetyChecker?.prewarmResources() 95 | } 96 | 97 | /// Text to image generation using stable diffusion 98 | /// 99 | /// - Parameters: 100 | /// - prompt: Text prompt to guide sampling 101 | /// - stepCount: Number of inference steps to perform 102 | /// - imageCount: Number of samples/images to generate for the input prompt 103 | /// - seed: Random seed which allows us to re-generate the same image for the same prompt by re-using the seed. If the seed is -1, then a random seed is generated and returned at the end. 
104 | /// - guidanceScale: For classifier guidance 105 | /// - disableSafety: Safety checks are only performed if `self.canSafetyCheck && !disableSafety` 106 | /// - progressHandler: Callback to perform after each step, stops on receiving false response 107 | /// - Returns: A tuple containing an array of `imageCount` optional images and an `Int` for the random seed used. The images will be nil if safety checks were performed and found the result to be un-safe 108 | public func generateImages(prompt: String, negativePrompt: String = "", imageCount: Int = 1, stepCount: Int = 50, seed: Int = -1, guidanceScale: Float = 7.5, disableSafety: Bool = false, scheduler: StableDiffusionScheduler = .pndm, progressHandler: (Progress) -> Bool = { _ in true }) throws -> ([CGImage?], Int) { 109 | // Should we generate a random seed? 110 | let iseed = seed == -1 ? Int.random(in: 0...Int.max) : seed 111 | // Encode the input prompt as well as a blank unconditioned input 112 | let promptEmbedding = try textEncoder.encode(prompt) 113 | let negativePromptEmbedding = try textEncoder.encode(negativePrompt) 114 | // let blankEmbedding = try textEncoder.encode("") 115 | if reduceMemory { 116 | textEncoder.unloadResources() 117 | } 118 | // Concatenate the prompt and negative prompt embeddings 119 | let concatEmbedding = MLShapedArray(concatenating: [negativePromptEmbedding, promptEmbedding], alongAxis: 0) 120 | let hiddenStates = toHiddenStates(concatEmbedding) 121 | /// Setup schedulers 122 | let scheduler: [Scheduler] = (0..(concatenating: [$0, $0], alongAxis: 0) 136 | } 137 | // Predict noise residuals from latent samples and current time step conditioned on hidden states 138 | var noise = try unet.predictNoise(latents: latentUnetInput, timeStep: t, hiddenStates: hiddenStates) 139 | noise = performGuidance(noise: noise, guidance: guidanceScale) 140 | // Have the scheduler compute the previous (t-1) latent sample given the predicted noise and current sample 141 | for i in 0.. 
[MLShapedArray] { 160 | var sampleShape = unet.latentSampleShape 161 | sampleShape[0] = 1 162 | 163 | var random = NumPyRandomSource(seed: UInt32(truncatingIfNeeded: seed)) 164 | let samples = (0..( 166 | converting: random.normalShapedArray(sampleShape, mean: 0.0, stdev: Double(stdev))) 167 | } 168 | return samples 169 | } 170 | 171 | func toHiddenStates(_ embedding: MLShapedArray) -> MLShapedArray { 172 | // Unoptimized manual transpose [0, 2, None, 1] 173 | // e.g. From [2, 77, 768] to [2, 768, 1, 77] 174 | let fromShape = embedding.shape 175 | let stateShape = [fromShape[0],fromShape[2], 1, fromShape[1]] 176 | var states = MLShapedArray(repeating: 0.0, shape: stateShape) 177 | for i0 in 0..], guidance: Float) -> [MLShapedArray] { 188 | noise.map { performGuidance(noise: $0, guidance: guidance) } 189 | } 190 | 191 | func performGuidance(noise: MLShapedArray, guidance: Float) -> MLShapedArray { 192 | let blankNoiseScalars = noise[0].scalars 193 | let textNoiseScalars = noise[1].scalars 194 | var resultScalars = blankNoiseScalars 195 | for i in 0..(scalars: resultScalars, shape: shape) 202 | } 203 | 204 | func decodeToImages(_ latents: [MLShapedArray], disableSafety: Bool) throws -> [CGImage?] { 205 | let images = try decoder.decode(latents) 206 | if reduceMemory { 207 | decoder.unloadResources() 208 | } 209 | // If safety is disabled return what was decoded 210 | if disableSafety { 211 | return images 212 | } 213 | // If there is no safety checker return what was decoded 214 | guard let safetyChecker = safetyChecker else { 215 | return images 216 | } 217 | // Otherwise change images which are not safe to nil 218 | let safeImages = try images.map { image in 219 | try safetyChecker.isSafe(image) ? 
image : nil 220 | } 221 | if reduceMemory { 222 | safetyChecker.unloadResources() 223 | } 224 | return safeImages 225 | } 226 | 227 | } 228 | 229 | @available(iOS 16.2, macOS 13.1, *) 230 | extension StableDiffusionPipeline { 231 | /// Sampling progress details 232 | public struct Progress { 233 | public let pipeline: StableDiffusionPipeline 234 | public let prompt: String 235 | public let step: Int 236 | public let stepCount: Int 237 | public let currentLatentSamples: [MLShapedArray] 238 | public let isSafetyEnabled: Bool 239 | public var currentImages: [CGImage?] { 240 | try! pipeline.decodeToImages( 241 | currentLatentSamples, 242 | disableSafety: !isSafetyEnabled) 243 | } 244 | } 245 | } 246 | -------------------------------------------------------------------------------- /CoreML/pipeline/TextEncoder.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 
3 | 4 | import Foundation 5 | import CoreML 6 | 7 | /// A model for encoding text 8 | @available(iOS 16.2, macOS 13.1, *) 9 | public struct TextEncoder: ResourceManaging { 10 | 11 | /// Text tokenizer 12 | var tokenizer: BPETokenizer 13 | 14 | /// Embedding model 15 | var model: ManagedMLModel 16 | 17 | /// Creates text encoder which embeds a tokenized string 18 | /// 19 | /// - Parameters: 20 | /// - tokenizer: Tokenizer for input text 21 | /// - url: Location of compiled text encoding Core ML model 22 | /// - configuration: configuration to be used when the model is loaded 23 | /// - Returns: A text encoder that will lazily load its required resources when needed or requested 24 | public init(tokenizer: BPETokenizer, 25 | modelAt url: URL, 26 | configuration: MLModelConfiguration) { 27 | self.tokenizer = tokenizer 28 | self.model = ManagedMLModel(modelAt: url, configuration: configuration) 29 | } 30 | 31 | /// Ensure the model has been loaded into memory 32 | public func loadResources() throws { 33 | try model.loadResources() 34 | } 35 | 36 | /// Unload the underlying model to free up memory 37 | public func unloadResources() { 38 | model.unloadResources() 39 | } 40 | 41 | /// Encode input text/string 42 | /// 43 | /// - Parameters: 44 | /// - text: Input text to be tokenized and then embedded 45 | /// - Returns: Embedding representing the input text 46 | public func encode(_ text: String) throws -> MLShapedArray { 47 | 48 | // Get models expected input length 49 | let inputLength = inputShape.last! 
50 | 51 | // Tokenize, padding to the expected length 52 | var (tokens, ids) = tokenizer.tokenize(input: text, minCount: inputLength) 53 | 54 | // Truncate if necessary 55 | if ids.count > inputLength { 56 | tokens = tokens.dropLast(tokens.count - inputLength) 57 | ids = ids.dropLast(ids.count - inputLength) 58 | let truncated = tokenizer.decode(tokens: tokens) 59 | NSLog("Needed to truncate input '\(text)' to '\(truncated)'") 60 | } 61 | 62 | // Use the model to generate the embedding 63 | return try encode(ids: ids) 64 | } 65 | 66 | /// Prediction queue 67 | let queue = DispatchQueue(label: "textencoder.predict") 68 | 69 | func encode(ids: [Int]) throws -> MLShapedArray { 70 | let inputName = inputDescription.name 71 | let inputShape = inputShape 72 | 73 | let floatIds = ids.map { Float32($0) } 74 | let inputArray = MLShapedArray(scalars: floatIds, shape: inputShape) 75 | let inputFeatures = try! MLDictionaryFeatureProvider( 76 | dictionary: [inputName: MLMultiArray(inputArray)]) 77 | 78 | let result = try model.perform { model in 79 | try model.prediction(from: inputFeatures) 80 | } 81 | 82 | let embeddingFeature = result.featureValue(for: "last_hidden_state") 83 | return MLShapedArray(converting: embeddingFeature!.multiArrayValue!) 84 | } 85 | 86 | var inputDescription: MLFeatureDescription { 87 | try! model.perform { model in 88 | model.modelDescription.inputDescriptionsByName.first!.value 89 | } 90 | } 91 | 92 | var inputShape: [Int] { 93 | inputDescription.multiArrayConstraint!.shape.map { $0.intValue } 94 | } 95 | 96 | } 97 | -------------------------------------------------------------------------------- /CoreML/pipeline/Unet.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 
3 | 4 | import Foundation 5 | import CoreML 6 | 7 | /// U-Net noise prediction model for stable diffusion 8 | @available(iOS 16.2, macOS 13.1, *) 9 | public struct Unet: ResourceManaging { 10 | 11 | /// Model used to predict noise residuals given an input, diffusion time step, and conditional embedding 12 | /// 13 | /// It can be in the form of a single model or multiple stages 14 | var models: [ManagedMLModel] 15 | 16 | /// Creates a U-Net noise prediction model 17 | /// 18 | /// - Parameters: 19 | /// - url: Location of single U-Net compiled Core ML model 20 | /// - configuration: Configuration to be used when the model is loaded 21 | /// - Returns: U-net model that will lazily load its required resources when needed or requested 22 | public init(modelAt url: URL, 23 | configuration: MLModelConfiguration) { 24 | self.models = [ManagedMLModel(modelAt: url, configuration: configuration)] 25 | } 26 | 27 | /// Creates a U-Net noise prediction model 28 | /// 29 | /// - Parameters: 30 | /// - urls: Location of chunked U-Net via urls to each compiled chunk 31 | /// - configuration: Configuration to be used when the model is loaded 32 | /// - Returns: U-net model that will lazily load its required resources when needed or requested 33 | public init(chunksAt urls: [URL], 34 | configuration: MLModelConfiguration) { 35 | self.models = urls.map { ManagedMLModel(modelAt: $0, configuration: configuration) } 36 | } 37 | 38 | /// Load resources. 
39 | public func loadResources() throws { 40 | for model in models { 41 | try model.loadResources() 42 | } 43 | } 44 | 45 | /// Unload the underlying model to free up memory 46 | public func unloadResources() { 47 | for model in models { 48 | model.unloadResources() 49 | } 50 | } 51 | 52 | /// Pre-warm resources 53 | public func prewarmResources() throws { 54 | // Override default to pre-warm each model 55 | for model in models { 56 | try model.loadResources() 57 | model.unloadResources() 58 | } 59 | } 60 | 61 | var latentSampleDescription: MLFeatureDescription { 62 | try! models.first!.perform { model in 63 | model.modelDescription.inputDescriptionsByName["sample"]! 64 | } 65 | } 66 | 67 | /// The expected shape of the models latent sample input 68 | public var latentSampleShape: [Int] { 69 | latentSampleDescription.multiArrayConstraint!.shape.map { $0.intValue } 70 | } 71 | 72 | /// Batch prediction noise from latent samples 73 | /// 74 | /// - Parameters: 75 | /// - latents: Batch of latent samples in an array 76 | /// - timeStep: Current diffusion timestep 77 | /// - hiddenStates: Hidden state to condition on 78 | /// - Returns: Array of predicted noise residuals 79 | func predictNoise( 80 | latents: [MLShapedArray], 81 | timeStep: Int, 82 | hiddenStates: MLShapedArray 83 | ) throws -> [MLShapedArray] { 84 | 85 | // Match time step batch dimension to the model / latent samples 86 | let t = MLShapedArray(scalars:[Float(timeStep), Float(timeStep)],shape:[2]) 87 | 88 | // Form batch input to model 89 | let inputs = try latents.map { 90 | let dict: [String: Any] = [ 91 | "sample" : MLMultiArray($0), 92 | "timestep" : MLMultiArray(t), 93 | "encoder_hidden_states": MLMultiArray(hiddenStates) 94 | ] 95 | return try MLDictionaryFeatureProvider(dictionary: dict) 96 | } 97 | let batch = MLArrayBatchProvider(array: inputs) 98 | 99 | // Make predictions 100 | let results = try predictions(from: batch) 101 | 102 | // Pull out the results in Float32 format 103 | let noise = 
(0..(fp32Noise) 119 | } 120 | 121 | return noise 122 | } 123 | 124 | func predictions(from batch: MLBatchProvider) throws -> MLBatchProvider { 125 | 126 | var results = try models.first!.perform { model in 127 | try model.predictions(fromBatch: batch) 128 | } 129 | 130 | if models.count == 1 { 131 | return results 132 | } 133 | 134 | // Manual pipeline batch prediction 135 | let inputs = batch.arrayOfFeatureValueDictionaries 136 | for stage in models.dropFirst() { 137 | 138 | // Combine the original inputs with the outputs of the last stage 139 | let next = try results.arrayOfFeatureValueDictionaries 140 | .enumerated().map { (index, dict) in 141 | let nextDict = dict.merging(inputs[index]) { (out, _) in out } 142 | return try MLDictionaryFeatureProvider(dictionary: nextDict) 143 | } 144 | let nextBatch = MLArrayBatchProvider(array: next) 145 | 146 | // Predict 147 | results = try stage.perform { model in 148 | try model.predictions(fromBatch: nextBatch) 149 | } 150 | } 151 | 152 | return results 153 | } 154 | } 155 | 156 | extension MLFeatureProvider { 157 | var featureValueDictionary: [String : MLFeatureValue] { 158 | self.featureNames.reduce(into: [String : MLFeatureValue]()) { result, name in 159 | result[name] = self.featureValue(for: name) 160 | } 161 | } 162 | } 163 | 164 | extension MLBatchProvider { 165 | var arrayOfFeatureValueDictionaries: [[String : MLFeatureValue]] { 166 | (0.. 
[String: Int] { 13 | let content = try Data(contentsOf: url) 14 | return try JSONDecoder().decode([String: Int].self, from: content) 15 | } 16 | 17 | /// Read merges.txt file at URL into a dictionary mapping bigrams to the line number/rank/priority 18 | static func readMerges(url: URL) throws -> [TokenPair: Int] { 19 | let content = try String(contentsOf: url) 20 | let lines = content.split(separator: "\n") 21 | 22 | let merges: [(TokenPair, Int)] = try lines.enumerated().compactMap { (index, line) in 23 | if line.hasPrefix("#") { 24 | return nil 25 | } 26 | let pair = line.split(separator: " ") 27 | if pair.count != 2 { 28 | throw FileReadError.invalidMergeFileLine(index+1) 29 | } 30 | return (TokenPair(String(pair[0]), String(pair[1])),index) 31 | } 32 | return [TokenPair : Int](uniqueKeysWithValues: merges) 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /CoreML/tokenizer/BPETokenizer.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 3 | 4 | import Foundation 5 | 6 | /// A tokenizer based on byte pair encoding. 7 | public struct BPETokenizer { 8 | /// A dictionary that maps pairs of tokens to the rank/order of the merge. 9 | let merges: [TokenPair : Int] 10 | 11 | /// A dictionary from of tokens to identifiers. 12 | let vocabulary: [String: Int] 13 | 14 | /// The start token. 15 | let startToken: String = "<|startoftext|>" 16 | 17 | /// The end token. 18 | let endToken: String = "<|endoftext|>" 19 | 20 | /// The token used for padding 21 | let padToken: String = "<|endoftext|>" 22 | 23 | /// The unknown token. 24 | let unknownToken: String = "<|endoftext|>" 25 | 26 | var unknownTokenID: Int { 27 | vocabulary[unknownToken, default: 0] 28 | } 29 | 30 | /// Creates a tokenizer. 
31 | /// 32 | /// - Parameters: 33 | /// - merges: A dictionary that maps pairs of tokens to the rank/order of the merge. 34 | /// - vocabulary: A dictionary from of tokens to identifiers. 35 | public init(merges: [TokenPair: Int], vocabulary: [String: Int]) { 36 | self.merges = merges 37 | self.vocabulary = vocabulary 38 | } 39 | 40 | /// Creates a tokenizer by loading merges and vocabulary from URLs. 41 | /// 42 | /// - Parameters: 43 | /// - mergesURL: The URL of a text file containing merges. 44 | /// - vocabularyURL: The URL of a JSON file containing the vocabulary. 45 | public init(mergesAt mergesURL: URL, vocabularyAt vocabularyURL: URL) throws { 46 | self.merges = try Self.readMerges(url: mergesURL) 47 | self.vocabulary = try! Self.readVocabulary(url: vocabularyURL) 48 | } 49 | 50 | /// Tokenizes an input string. 51 | /// 52 | /// - Parameters: 53 | /// - input: A string. 54 | /// - minCount: The minimum number of tokens to return. 55 | /// - Returns: An array of tokens and an array of token identifiers. 56 | public func tokenize(input: String, minCount: Int? = nil) -> (tokens: [String], tokenIDs: [Int]) { 57 | var tokens: [String] = [] 58 | 59 | tokens.append(startToken) 60 | tokens.append(contentsOf: encode(input: input)) 61 | tokens.append(endToken) 62 | 63 | // Pad if there was a min length specified 64 | if let minLen = minCount, minLen > tokens.count { 65 | tokens.append(contentsOf: repeatElement(padToken, count: minLen - tokens.count)) 66 | } 67 | 68 | let ids = tokens.map({ vocabulary[$0, default: unknownTokenID] }) 69 | return (tokens: tokens, tokenIDs: ids) 70 | } 71 | 72 | /// Returns the token identifier for a token. 73 | public func tokenID(for token: String) -> Int? { 74 | vocabulary[token] 75 | } 76 | 77 | /// Returns the token for a token identifier. 78 | public func token(id: Int) -> String? 
{ 79 | vocabulary.first(where: { $0.value == id })?.key 80 | } 81 | 82 | /// Decodes a sequence of tokens into a fully formed string 83 | public func decode(tokens: [String]) -> String { 84 | String(tokens.joined()) 85 | .replacingOccurrences(of: "", with: " ") 86 | .replacingOccurrences(of: startToken, with: "") 87 | .replacingOccurrences(of: endToken, with: "") 88 | } 89 | 90 | /// Encode an input string to a sequence of tokens 91 | func encode(input: String) -> [String] { 92 | let normalized = input.trimmingCharacters(in: .whitespacesAndNewlines).lowercased() 93 | let words = normalized.split(separator: " ") 94 | return words.flatMap({ encode(word: $0) }) 95 | } 96 | 97 | /// Encode a single word into a sequence of tokens 98 | func encode(word: Substring) -> [String] { 99 | var tokens = word.map { String($0) } 100 | if let last = tokens.indices.last { 101 | tokens[last] = tokens[last] + "" 102 | } 103 | 104 | while true { 105 | let pairs = pairs(for: tokens) 106 | let canMerge = pairs.filter { merges[$0] != nil } 107 | 108 | if canMerge.isEmpty { 109 | break 110 | } 111 | 112 | // If multiple merges are found, use the one with the lowest rank 113 | let shouldMerge = canMerge.min { merges[$0]! < merges[$1]! }! 114 | tokens = update(tokens, merging: shouldMerge) 115 | } 116 | return tokens 117 | } 118 | 119 | /// Get the set of adjacent pairs / bigrams from a sequence of tokens 120 | func pairs(for tokens: [String]) -> Set { 121 | guard tokens.count > 1 else { 122 | return Set() 123 | } 124 | 125 | var pairs = Set(minimumCapacity: tokens.count - 1) 126 | var prev = tokens.first! 
127 | for current in tokens.dropFirst() { 128 | pairs.insert(TokenPair(prev, current)) 129 | prev = current 130 | } 131 | return pairs 132 | } 133 | 134 | /// Update the sequence of tokens by greedily merging instance of a specific bigram 135 | func update(_ tokens: [String], merging bigram: TokenPair) -> [String] { 136 | guard tokens.count > 1 else { 137 | return [] 138 | } 139 | 140 | var newTokens = [String]() 141 | newTokens.reserveCapacity(tokens.count - 1) 142 | 143 | var index = 0 144 | while index < tokens.count { 145 | let remainingTokens = tokens[index...] 146 | if let startMatchIndex = remainingTokens.firstIndex(of: bigram.first) { 147 | // Found a possible match, append everything before it 148 | newTokens.append(contentsOf: tokens[index...size 39 | 40 | // In M1/M2 perflevel0 refers to the performance cores and perflevel1 are the efficiency cores 41 | // In Intel there's only one performance level 42 | let result = sysctlbyname("hw.perflevel0.physicalcpu", &ncores, &bytes, nil, 0) 43 | guard result == 0 else { return 0 } 44 | return Int(ncores) 45 | }() 46 | } 47 | -------------------------------------------------------------------------------- /Diffusion-macOS/ContentView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ContentView.swift 3 | // Diffusion-macOS 4 | // 5 | // Created by Cyril Zakka on 1/12/23. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | import ImageIO 11 | 12 | 13 | // AppKit version that uses NSImage, NSSavePanel 14 | struct ShareButtons: View { 15 | var image: CGImage 16 | var name: String 17 | 18 | var filename: String { 19 | name.replacingOccurrences(of: " ", with: "_") 20 | } 21 | 22 | func showSavePanel() -> URL? 
{ 23 | let savePanel = NSSavePanel() 24 | savePanel.allowedContentTypes = [.png] 25 | savePanel.canCreateDirectories = true 26 | savePanel.isExtensionHidden = false 27 | savePanel.title = "Save your image" 28 | savePanel.message = "Choose a folder and a name to store the image." 29 | savePanel.nameFieldLabel = "File name:" 30 | savePanel.nameFieldStringValue = filename 31 | 32 | let response = savePanel.runModal() 33 | return response == .OK ? savePanel.url : nil 34 | } 35 | 36 | func savePNG(cgImage: CGImage, path: URL) { 37 | let image = NSImage(cgImage: cgImage, size: .zero) 38 | let imageRepresentation = NSBitmapImageRep(data: image.tiffRepresentation!) 39 | guard let pngData = imageRepresentation?.representation(using: .png, properties: [:]) else { 40 | print("Error generating PNG data") 41 | return 42 | } 43 | do { 44 | try pngData.write(to: path) 45 | } catch { 46 | print("Error saving: \(error)") 47 | } 48 | } 49 | 50 | var body: some View { 51 | let imageView = Image(image, scale: 1, label: Text(name)) 52 | HStack { 53 | ShareLink(item: imageView, preview: SharePreview(name, image: imageView)) 54 | Button() { 55 | if let url = showSavePanel() { 56 | savePNG(cgImage: image, path: url) 57 | } 58 | } label: { 59 | Label("Save…", systemImage: "square.and.arrow.down") 60 | } 61 | } 62 | } 63 | } 64 | 65 | struct ContentView: View { 66 | @StateObject var generation = GenerationContext() 67 | 68 | func toolbar() -> any View { 69 | if case .complete(let prompt, let cgImage, _, _) = generation.state, let cgImage = cgImage { 70 | // TODO: share seed too 71 | return ShareButtons(image: cgImage, name: prompt) 72 | } else { 73 | let prompt = DEFAULT_PROMPT 74 | let cgImage = NSImage(imageLiteralResourceName: "placeholder").cgImage(forProposedRect: nil, context: nil, hints: nil)! 
75 | return ShareButtons(image: cgImage, name: prompt) 76 | } 77 | } 78 | 79 | var body: some View { 80 | NavigationSplitView { 81 | ControlsView() 82 | .navigationSplitViewColumnWidth(min: 250, ideal: 300) 83 | } detail: { 84 | GeneratedImageView() 85 | .aspectRatio(contentMode: .fit) 86 | .frame(width: 512, height: 512) 87 | .cornerRadius(15) 88 | .toolbar { 89 | AnyView(toolbar()) 90 | } 91 | 92 | } 93 | .environmentObject(generation) 94 | } 95 | } 96 | 97 | struct ContentView_Previews: PreviewProvider { 98 | static var previews: some View { 99 | ContentView() 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /Diffusion-macOS/Diffusion_macOS.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.security.app-sandbox 6 | 7 | com.apple.security.files.user-selected.read-write 8 | 9 | com.apple.security.network.client 10 | 11 | 12 | 13 | -------------------------------------------------------------------------------- /Diffusion-macOS/Diffusion_macOSApp.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Diffusion_macOSApp.swift 3 | // Diffusion-macOS 4 | // 5 | // Created by Cyril Zakka on 1/12/23. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | 11 | @main 12 | struct Diffusion_macOSApp: App { 13 | var body: some Scene { 14 | WindowGroup { 15 | ContentView() 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /Diffusion-macOS/GeneratedImageView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // GeneratedImageView.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on 18/1/23. 
6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | 11 | struct GeneratedImageView: View { 12 | @EnvironmentObject var generation: GenerationContext 13 | 14 | var body: some View { 15 | switch generation.state { 16 | case .startup: return AnyView(Image("placeholder").resizable()) 17 | case .running(let progress): 18 | guard let progress = progress, progress.stepCount > 0 else { 19 | // The first time it takes a little bit before generation starts 20 | return AnyView(ProgressView()) 21 | } 22 | 23 | let step = Int(progress.step) + 1 24 | let fraction = Double(step) / Double(progress.stepCount) 25 | let label = "Step \(step) of \(progress.stepCount)" 26 | 27 | return AnyView(VStack { 28 | Group { 29 | if let safeImage = generation.previewImage { 30 | Image(safeImage, scale: 1, label: Text("generated")) 31 | .resizable() 32 | .clipShape(RoundedRectangle(cornerRadius: 20)) 33 | } 34 | } 35 | HStack { 36 | ProgressView(label, value: fraction, total: 1).padding() 37 | Button { 38 | generation.cancelGeneration() 39 | } label: { 40 | Image(systemName: "x.circle.fill").foregroundColor(.gray) 41 | } 42 | .buttonStyle(.plain) 43 | } 44 | }) 45 | case .complete(_, let image, _, _): 46 | guard let theImage = image else { 47 | return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) 48 | } 49 | 50 | return AnyView( 51 | Image(theImage, scale: 1, label: Text("generated")) 52 | .resizable() 53 | .clipShape(RoundedRectangle(cornerRadius: 20)) 54 | .contextMenu { 55 | Button { 56 | NSPasteboard.general.clearContents() 57 | let nsimage = NSImage(cgImage: theImage, size: NSSize(width: theImage.width, height: theImage.height)) 58 | NSPasteboard.general.writeObjects([nsimage]) 59 | } label: { 60 | Text("Copy Photo") 61 | } 62 | } 63 | ) 64 | case .failed(_): 65 | return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) 66 | case .userCanceled: 67 | return AnyView(Text("Generation 
canceled")) 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /Diffusion-macOS/HelpContent.swift: -------------------------------------------------------------------------------- 1 | // 2 | // HelpContent.swift 3 | // Diffusion-macOS 4 | // 5 | // Created by Pedro Cuenca on 7/2/23. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | 11 | func helpContent(title: String, description: Text, showing: Binding, width: Double = 400) -> some View { 12 | VStack { 13 | Text(title) 14 | .font(.title3) 15 | .padding(.top, 10) 16 | .padding(.all, 5) 17 | description 18 | .lineLimit(nil) 19 | .padding(.bottom, 5) 20 | .padding([.leading, .trailing], 15) 21 | Button { 22 | showing.wrappedValue.toggle() 23 | } label: { 24 | Text("Dismiss").frame(maxWidth: 200) 25 | } 26 | .padding(.bottom) 27 | } 28 | .frame(minWidth: width, idealWidth: width, maxWidth: width) 29 | } 30 | 31 | func helpContent(title: String, description: String, showing: Binding, width: Double = 400) -> some View { 32 | helpContent(title: title, description: Text(description), showing: showing) 33 | } 34 | 35 | func helpContent(title: String, description: AttributedString, showing: Binding, width: Double = 400) -> some View { 36 | helpContent(title: title, description: Text(description), showing: showing) 37 | } 38 | 39 | 40 | func modelsHelp(_ showing: Binding) -> some View { 41 | let description = try! AttributedString(markdown: 42 | """ 43 | Diffusers launches with a set of 5 models that can be downloaded from the Hugging Face Hub: 44 | 45 | **[Stable Diffusion 1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4)** 46 | 47 | This is the original Stable Diffusion model that changed the landscape of AI image generation. For more details, visit the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) or click on the title above. 
48 | 49 | **[Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5)** 50 | 51 | Same architecture as 1.4, but trained on additional images with a focus on aesthetics. 52 | 53 | **[Stable Diffusion 2](https://huggingface.co/StabilityAI/stable-diffusion-2-base)** 54 | 55 | Improved model, heavily retrained on millions of additional images. This version corresponds to the [`stable-diffusion-2-base`](https://huggingface.co/StabilityAI/stable-diffusion-2-base) version of the model (trained on 512 x 512 images). 56 | 57 | **[Stable Diffusion 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base)** 58 | 59 | The last reference in the Stable Diffusion family. Works great with _negative prompts_. 60 | 61 | **[OFA small v0](https://huggingface.co/OFA-Sys/small-stable-diffusion-v0)** 62 | 63 | This is a special so-called _distilled_ model, half the size of the others. It runs faster and requires less RAM, try it out if you find generation slow! 64 | 65 | """, options: AttributedString.MarkdownParsingOptions(interpretedSyntax: .inlineOnlyPreservingWhitespace)) 66 | return helpContent(title: "Available Models", description: description, showing: showing, width: 600) 67 | } 68 | 69 | func promptsHelp(_ showing: Binding) -> some View { 70 | let description = try! AttributedString(markdown: 71 | """ 72 | **Prompt** is the description of what you want, and **negative prompt** is what you _don't want_. 73 | 74 | Use the negative prompt to tweak a previous generation (by removing unwanted items), or to provide hints for the model. 75 | 76 | Many people like to use negative prompts such as "ugly, bad quality" to make the model try harder. \ 77 | Or consider excluding terms like "3d" or "realistic" if you're after particular drawing styles. 
78 | 79 | """, options: AttributedString.MarkdownParsingOptions(interpretedSyntax: .inlineOnlyPreservingWhitespace)) 80 | return helpContent(title: "Prompt and Negative Prompt", description: description, showing: showing, width: 600) 81 | } 82 | 83 | func guidanceHelp(_ showing: Binding) -> some View { 84 | let description = 85 | """ 86 | Indicates how much the image should resemble the prompt. 87 | 88 | Low values produce more varied results, while excessively high ones \ 89 | may result in image artifacts such as posterization. 90 | 91 | Values between 7 and 10 are usually good choices, but they affect \ 92 | differently to different models. 93 | 94 | Feel free to experiment! 95 | """ 96 | return helpContent(title: "Guidance Scale", description: description, showing: showing) 97 | } 98 | 99 | func stepsHelp(_ showing: Binding) -> some View { 100 | let description = 101 | """ 102 | How many times to go through the diffusion process. 103 | 104 | Quality increases the more steps you choose, but marginal improvements \ 105 | get increasingly smaller. 106 | 107 | 🧨 Diffusers currently uses the super efficient DPM Solver scheduler, \ 108 | which produces great results in just 20 or 25 steps 🤯 109 | """ 110 | return helpContent(title: "Inference Steps", description: description, showing: showing) 111 | } 112 | 113 | func previewHelp(_ showing: Binding) -> some View { 114 | let description = 115 | """ 116 | This number controls how many previews to display throughout the image generation process. 117 | 118 | Using more previews can be useful if you want more visibility into how \ 119 | generation is progressing. 120 | 121 | However, computing each preview takes some time and can slow down \ 122 | generation. If the process is too slow you can reduce the preview count, \ 123 | which will result in less visibility of intermediate steps during generation. 124 | 125 | You can try different values to see what works best for your hardware. 
126 | 127 | For the absolute fastest generation times, use 0 previews. 128 | """ 129 | return helpContent(title: "Preview Count", description: description, showing: showing) 130 | } 131 | 132 | func seedHelp(_ showing: Binding) -> some View { 133 | let description = 134 | """ 135 | This is a number that allows you to reproduce a previous generation. 136 | 137 | Use it like this: select a seed and write a prompt, then generate an image. \ 138 | Next, maybe add a negative prompt or tweak the prompt slightly, and see how the result changes. \ 139 | Rinse and repeat until you are satisfied, or select a new seed to start over. 140 | 141 | Set the value to 0 for a random seed to be chosen for you. 142 | """ 143 | return helpContent(title: "Generation Seed", description: description, showing: showing) 144 | } 145 | 146 | func advancedHelp(_ showing: Binding) -> some View { 147 | let description = 148 | """ 149 | This section allows you to try different optimization settings. 150 | 151 | Diffusers will try to select the best configuration for you, but it may not always be optimal \ 152 | for your computer. You can experiment with these settings to verify the combination that works faster \ 153 | in your system. 154 | 155 | Please, note that these settings may trigger downloads of additional model variants. 
156 | """ 157 | return helpContent(title: "Advanced Model Settings", description: description, showing: showing) 158 | } 159 | -------------------------------------------------------------------------------- /Diffusion-macOS/Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | ITSAppUsesNonExemptEncryption 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /Diffusion-macOS/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Diffusion-macOS/StatusView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // StatusView.swift 3 | // Diffusion-macOS 4 | // 5 | // Created by Cyril Zakka on 1/12/23. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | 11 | struct StatusView: View { 12 | @EnvironmentObject var generation: GenerationContext 13 | var pipelineState: Binding 14 | 15 | @State private var showErrorPopover = false 16 | 17 | func submit() { 18 | if case .running = generation.state { return } 19 | Task { 20 | generation.state = .running(nil) 21 | do { 22 | let result = try await generation.generate() 23 | if result.userCanceled { 24 | generation.state = .userCanceled 25 | } else { 26 | generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval) 27 | } 28 | } catch { 29 | generation.state = .failed(error) 30 | } 31 | } 32 | } 33 | 34 | func errorWithDetails(_ message: String, error: Error) -> any View { 35 | HStack { 36 | Text(message) 37 | Spacer() 38 | Button { 39 | showErrorPopover.toggle() 40 | } label: { 41 | Image(systemName: "info.circle") 42 | 
}.buttonStyle(.plain) 43 | .popover(isPresented: $showErrorPopover) { 44 | VStack { 45 | Text(verbatim: "\(error)") 46 | .lineLimit(nil) 47 | .padding(.all, 5) 48 | Button { 49 | showErrorPopover.toggle() 50 | } label: { 51 | Text("Dismiss").frame(maxWidth: 200) 52 | } 53 | .padding(.bottom) 54 | } 55 | .frame(minWidth: 400, idealWidth: 400, maxWidth: 400) 56 | .fixedSize() 57 | } 58 | } 59 | } 60 | 61 | func generationStatusView() -> any View { 62 | switch generation.state { 63 | case .startup: return EmptyView() 64 | case .running(let progress): 65 | guard let progress = progress, progress.stepCount > 0 else { 66 | // The first time it takes a little bit before generation starts 67 | return HStack { 68 | Text("Preparing model…") 69 | Spacer() 70 | } 71 | } 72 | let step = Int(progress.step) + 1 73 | let fraction = Double(step) / Double(progress.stepCount) 74 | return HStack { 75 | Text("Generating \(Int(round(100*fraction)))%") 76 | Spacer() 77 | } 78 | case .complete(_, let image, let lastSeed, let interval): 79 | guard let _ = image else { 80 | return HStack { 81 | Text("Safety checker triggered, please try a different prompt or seed.") 82 | Spacer() 83 | } 84 | } 85 | 86 | return HStack { 87 | let intervalString = String(format: "Time: %.1fs", interval ?? 
0) 88 | Text(intervalString) 89 | Spacer() 90 | if generation.seed != lastSeed { 91 | 92 | Text(String("Seed: \(formatLargeNumber(lastSeed))")) 93 | Button("Set") { 94 | generation.seed = lastSeed 95 | } 96 | } 97 | }.frame(maxHeight: 25) 98 | case .failed(let error): 99 | return errorWithDetails("Generation error", error: error) 100 | case .userCanceled: 101 | return HStack { 102 | Text("Generation canceled.") 103 | Spacer() 104 | } 105 | } 106 | } 107 | 108 | var body: some View { 109 | switch pipelineState.wrappedValue { 110 | case .downloading(let progress): 111 | ProgressView("Downloading…", value: progress*100, total: 110).padding() 112 | case .uncompressing: 113 | ProgressView("Uncompressing…", value: 100, total: 110).padding() 114 | case .loading: 115 | ProgressView("Loading…", value: 105, total: 110).padding() 116 | case .ready: 117 | VStack { 118 | Button { 119 | submit() 120 | } label: { 121 | Text("Generate") 122 | .frame(maxWidth: .infinity) 123 | .frame(height: 50) 124 | } 125 | .buttonStyle(.borderedProminent) 126 | 127 | AnyView(generationStatusView()) 128 | } 129 | case .failed(let error): 130 | AnyView(errorWithDetails("Pipeline loading error", error: error)) 131 | } 132 | } 133 | } 134 | 135 | struct StatusView_Previews: PreviewProvider { 136 | static var previews: some View { 137 | StatusView(pipelineState: .constant(.downloading(0.2))) 138 | } 139 | } 140 | -------------------------------------------------------------------------------- /Diffusion.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /Diffusion.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved: -------------------------------------------------------------------------------- 1 | { 2 | "pins" : [ 3 | { 4 | "identity" : "compactslider", 5 | "kind" : "remoteSourceControl", 6 | 
"location" : "https://github.com/buh/CompactSlider.git", 7 | "state" : { 8 | "branch" : "main", 9 | "revision" : "6d591a76caecd583ad69fbcd06f2fb83135318c0" 10 | } 11 | }, 12 | { 13 | "identity" : "ml-stable-diffusion", 14 | "kind" : "remoteSourceControl", 15 | "location" : "https://github.com/apple/ml-stable-diffusion", 16 | "state" : { 17 | "branch" : "main", 18 | "revision" : "8cf34376f9faf87fc6fe63159e5fae6cbbb71de6" 19 | } 20 | }, 21 | { 22 | "identity" : "swift-argument-parser", 23 | "kind" : "remoteSourceControl", 24 | "location" : "https://github.com/apple/swift-argument-parser.git", 25 | "state" : { 26 | "revision" : "fddd1c00396eed152c45a46bea9f47b98e59301d", 27 | "version" : "1.2.0" 28 | } 29 | }, 30 | { 31 | "identity" : "zipfoundation", 32 | "kind" : "remoteSourceControl", 33 | "location" : "https://github.com/weichsel/ZIPFoundation.git", 34 | "state" : { 35 | "revision" : "43ec568034b3731101dbf7670765d671c30f54f3", 36 | "version" : "0.9.16" 37 | } 38 | } 39 | ], 40 | "version" : 2 41 | } 42 | -------------------------------------------------------------------------------- /Diffusion.xcodeproj/project.xcworkspace/xcuserdata/cyril.xcuserdatad/UserInterfaceState.xcuserstate: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion.xcodeproj/project.xcworkspace/xcuserdata/cyril.xcuserdatad/UserInterfaceState.xcuserstate -------------------------------------------------------------------------------- /Diffusion.xcodeproj/xcshareddata/xcschemes/Diffusion.xcscheme: -------------------------------------------------------------------------------- 1 | 2 | 5 | 8 | 9 | 15 | 21 | 22 | 23 | 24 | 25 | 30 | 31 | 34 | 40 | 41 | 42 | 45 | 51 | 52 | 53 | 54 | 55 | 65 | 67 | 73 | 74 | 75 | 76 | 82 | 84 | 90 | 91 | 92 | 93 | 95 | 96 | 99 | 100 | 101 | -------------------------------------------------------------------------------- 
/Diffusion.xcodeproj/xcuserdata/cyril.xcuserdatad/xcschemes/xcschememanagement.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | SchemeUserState 6 | 7 | Diffusion-macOS.xcscheme_^#shared#^_ 8 | 9 | orderHint 10 | 1 11 | 12 | Diffusion.xcscheme_^#shared#^_ 13 | 14 | orderHint 15 | 0 16 | 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/.DS_Store -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/256x256@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/256x256@2x.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/512x512@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/512x512@2x.png -------------------------------------------------------------------------------- 
/Diffusion/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "filename" : "diffusers_on_white_1024.png", 5 | "idiom" : "universal", 6 | "platform" : "ios", 7 | "size" : "1024x1024" 8 | }, 9 | { 10 | "idiom" : "mac", 11 | "scale" : "1x", 12 | "size" : "16x16" 13 | }, 14 | { 15 | "idiom" : "mac", 16 | "scale" : "2x", 17 | "size" : "16x16" 18 | }, 19 | { 20 | "idiom" : "mac", 21 | "scale" : "1x", 22 | "size" : "32x32" 23 | }, 24 | { 25 | "idiom" : "mac", 26 | "scale" : "2x", 27 | "size" : "32x32" 28 | }, 29 | { 30 | "idiom" : "mac", 31 | "scale" : "1x", 32 | "size" : "128x128" 33 | }, 34 | { 35 | "idiom" : "mac", 36 | "scale" : "2x", 37 | "size" : "128x128" 38 | }, 39 | { 40 | "idiom" : "mac", 41 | "scale" : "1x", 42 | "size" : "256x256" 43 | }, 44 | { 45 | "filename" : "256x256@2x.png", 46 | "idiom" : "mac", 47 | "scale" : "2x", 48 | "size" : "256x256" 49 | }, 50 | { 51 | "filename" : "256x256@2x.png", 52 | "idiom" : "mac", 53 | "scale" : "1x", 54 | "size" : "512x512" 55 | }, 56 | { 57 | "filename" : "512x512@2x.png", 58 | "idiom" : "mac", 59 | "scale" : "2x", 60 | "size" : "512x512" 61 | } 62 | ], 63 | "info" : { 64 | "author" : "xcode", 65 | "version" : 1 66 | } 67 | } 68 | -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/diffusers_on_white_1024.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/diffusers_on_white_1024.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_128x128.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_128x128.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_128x128@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_128x128@2x.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_16x16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_16x16.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_16x16@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_16x16@2x.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_256x256.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_256x256.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_256x256@2x.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_256x256@2x.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_32x32.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_32x32@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_32x32@2x.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_512x512.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_512x512.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/AppIcon.appiconset/icon_512x512@2x.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/AppIcon.appiconset/icon_512x512@2x.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | 
"version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/placeholder.imageset/-cell-blank.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/Diffusion/Assets.xcassets/placeholder.imageset/-cell-blank.png -------------------------------------------------------------------------------- /Diffusion/Assets.xcassets/placeholder.imageset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "scale" : "1x" 6 | }, 7 | { 8 | "filename" : "-cell-blank.png", 9 | "idiom" : "universal", 10 | "scale" : "2x" 11 | }, 12 | { 13 | "idiom" : "universal", 14 | "scale" : "3x" 15 | } 16 | ], 17 | "info" : { 18 | "author" : "xcode", 19 | "version" : 1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /Diffusion/Common/Downloader.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Downloader.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on December 2022. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import Foundation 10 | import Combine 11 | 12 | class Downloader: NSObject, ObservableObject { 13 | private(set) var destination: URL 14 | 15 | enum DownloadState { 16 | case notStarted 17 | case downloading(Double) 18 | case completed(URL) 19 | case failed(Error) 20 | } 21 | 22 | private(set) lazy var downloadState: CurrentValueSubject = CurrentValueSubject(.notStarted) 23 | private var stateSubscriber: Cancellable? 24 | 25 | private var urlSession: URLSession? = nil 26 | 27 | init(from url: URL, to destination: URL, using authToken: String? 
= nil) { 28 | self.destination = destination 29 | super.init() 30 | 31 | // .background allows downloads to proceed in the background 32 | let config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download") 33 | urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue()) 34 | downloadState.value = .downloading(0) 35 | urlSession?.getAllTasks { tasks in 36 | // If there's an existing pending background task with the same URL, let it proceed. 37 | guard tasks.filter({ $0.originalRequest?.url == url }).isEmpty else { 38 | print("Already downloading \(url)") 39 | return 40 | } 41 | print("Starting download of \(url)") 42 | 43 | var request = URLRequest(url: url) 44 | if let authToken = authToken { 45 | request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization") 46 | } 47 | 48 | self.urlSession?.downloadTask(with: request).resume() 49 | } 50 | } 51 | 52 | @discardableResult 53 | func waitUntilDone() throws -> URL { 54 | // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky) 55 | let semaphore = DispatchSemaphore(value: 0) 56 | stateSubscriber = downloadState.sink { state in 57 | switch state { 58 | case .completed: semaphore.signal() 59 | case .failed: semaphore.signal() 60 | default: break 61 | } 62 | } 63 | semaphore.wait() 64 | 65 | switch downloadState.value { 66 | case .completed(let url): return url 67 | case .failed(let error): throw error 68 | default: throw("Should never happen, lol") 69 | } 70 | } 71 | 72 | func cancel() { 73 | urlSession?.invalidateAndCancel() 74 | } 75 | } 76 | 77 | extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate { 78 | func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten _: Int64, totalBytesExpectedToWrite _: Int64) { 79 | downloadState.value = .downloading(downloadTask.progress.fractionCompleted) 80 | } 81 | 82 | func 
urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) { 83 | guard FileManager.default.fileExists(atPath: location.path) else { 84 | downloadState.value = .failed("Invalid download location received: \(location)") 85 | return 86 | } 87 | do { 88 | try FileManager.default.moveItem(at: location, to: destination) 89 | downloadState.value = .completed(destination) 90 | } catch { 91 | downloadState.value = .failed(error) 92 | } 93 | } 94 | 95 | func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) { 96 | if let error = error { 97 | downloadState.value = .failed(error) 98 | } else if let response = task.response as? HTTPURLResponse { 99 | print("HTTP response status code: \(response.statusCode)") 100 | // let headers = response.allHeaderFields 101 | // print("HTTP response headers: \(headers)") 102 | } 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /Diffusion/Common/ModelInfo.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ModelInfo.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on 29/12/22. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import CoreML 10 | 11 | enum AttentionVariant: String { 12 | case original 13 | case splitEinsum 14 | case splitEinsumV2 15 | } 16 | 17 | extension AttentionVariant { 18 | var defaultComputeUnits: MLComputeUnits { self == .original ? .cpuAndGPU : .cpuAndNeuralEngine } 19 | } 20 | 21 | struct ModelInfo { 22 | /// Hugging Face model Id that contains .zip archives with compiled Core ML models 23 | let modelId: String 24 | 25 | /// Arbitrary string for presentation purposes. Something like "2.1-base" 26 | let modelVersion: String 27 | 28 | /// Suffix of the archive containing the ORIGINAL attention variant. 
Usually something like "original_compiled" 29 | let originalAttentionSuffix: String 30 | 31 | /// Suffix of the archive containing the SPLIT_EINSUM attention variant. Usually something like "split_einsum_compiled" 32 | let splitAttentionSuffix: String 33 | 34 | /// Suffix of the archive containing the SPLIT_EINSUM_V2 attention variant. Usually something like "split_einsum_v2_compiled" 35 | let splitAttentionV2Suffix: String 36 | 37 | /// Whether the archive contains the VAE Encoder (for image to image tasks). Not yet in use. 38 | let supportsEncoder: Bool 39 | 40 | /// Is attention v2 supported? (Ideally, we should know by looking at the repo contents) 41 | let supportsAttentionV2: Bool 42 | 43 | /// Are weights quantized? This is only used to decide whether to use `reduceMemory` 44 | let quantized: Bool 45 | 46 | /// Whether this is a Stable Diffusion XL model 47 | // TODO: retrieve from remote config 48 | let isXL: Bool 49 | 50 | //TODO: refactor all these properties 51 | init(modelId: String, modelVersion: String, 52 | originalAttentionSuffix: String = "original_compiled", 53 | splitAttentionSuffix: String = "split_einsum_compiled", 54 | splitAttentionV2Suffix: String = "split_einsum_v2_compiled", 55 | supportsEncoder: Bool = false, 56 | supportsAttentionV2: Bool = false, 57 | quantized: Bool = false, 58 | isXL: Bool = false) { 59 | self.modelId = modelId 60 | self.modelVersion = modelVersion 61 | self.originalAttentionSuffix = originalAttentionSuffix 62 | self.splitAttentionSuffix = splitAttentionSuffix 63 | self.splitAttentionV2Suffix = splitAttentionV2Suffix 64 | self.supportsEncoder = supportsEncoder 65 | self.supportsAttentionV2 = supportsAttentionV2 66 | self.quantized = quantized 67 | self.isXL = isXL 68 | } 69 | } 70 | 71 | extension ModelInfo { 72 | //TODO: set compute units instead and derive variant from it 73 | static var defaultAttention: AttentionVariant { 74 | guard runningOnMac else { return .splitEinsum } 75 | #if os(macOS) 76 | guard 
Capabilities.hasANE else { return .original } 77 | return Capabilities.performanceCores >= 8 ? .original : .splitEinsum 78 | #else 79 | return .splitEinsum 80 | #endif 81 | } 82 | 83 | static var defaultComputeUnits: MLComputeUnits { defaultAttention.defaultComputeUnits } 84 | 85 | var bestAttention: AttentionVariant { 86 | if !runningOnMac && supportsAttentionV2 { return .splitEinsumV2 } 87 | return ModelInfo.defaultAttention 88 | } 89 | var defaultComputeUnits: MLComputeUnits { bestAttention.defaultComputeUnits } 90 | 91 | func modelURL(for variant: AttentionVariant) -> URL { 92 | // Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip 93 | let suffix: String 94 | switch variant { 95 | case .original: suffix = originalAttentionSuffix 96 | case .splitEinsum: suffix = splitAttentionSuffix 97 | case .splitEinsumV2: suffix = splitAttentionV2Suffix 98 | } 99 | let repo = modelId.split(separator: "/").last! 100 | return URL(string: "https://huggingface.co/\(modelId)/resolve/main/\(repo)_\(suffix).zip")! 101 | } 102 | 103 | /// Best variant for the current platform. 104 | /// Currently using `split_einsum` for iOS and simple performance heuristics for macOS. 
105 | var bestURL: URL { modelURL(for: bestAttention) } 106 | 107 | var reduceMemory: Bool { 108 | // Enable on iOS devices, except when using quantization 109 | if runningOnMac { return false } 110 | return !(quantized && deviceHas6GBOrMore) 111 | } 112 | } 113 | 114 | extension ModelInfo { 115 | static let v14Base = ModelInfo( 116 | modelId: "pcuenq/coreml-stable-diffusion-1-4", 117 | modelVersion: "CompVis SD 1.4" 118 | ) 119 | 120 | static let v14Palettized = ModelInfo( 121 | modelId: "apple/coreml-stable-diffusion-1-4-palettized", 122 | modelVersion: "CompVis SD 1.4 [6 bit]", 123 | supportsEncoder: true, 124 | supportsAttentionV2: true, 125 | quantized: true 126 | ) 127 | 128 | static let v15Base = ModelInfo( 129 | modelId: "pcuenq/coreml-stable-diffusion-v1-5", 130 | modelVersion: "RunwayML SD 1.5" 131 | ) 132 | 133 | static let v15Palettized = ModelInfo( 134 | modelId: "apple/coreml-stable-diffusion-v1-5-palettized", 135 | modelVersion: "RunwayML SD 1.5 [6 bit]", 136 | supportsEncoder: true, 137 | supportsAttentionV2: true, 138 | quantized: true 139 | ) 140 | 141 | static let v2Base = ModelInfo( 142 | modelId: "pcuenq/coreml-stable-diffusion-2-base", 143 | modelVersion: "StabilityAI SD 2.0", 144 | supportsEncoder: true 145 | ) 146 | 147 | static let v2Palettized = ModelInfo( 148 | modelId: "apple/coreml-stable-diffusion-2-base-palettized", 149 | modelVersion: "StabilityAI SD 2.0 [6 bit]", 150 | supportsEncoder: true, 151 | supportsAttentionV2: true, 152 | quantized: true 153 | ) 154 | 155 | static let v21Base = ModelInfo( 156 | modelId: "pcuenq/coreml-stable-diffusion-2-1-base", 157 | modelVersion: "StabilityAI SD 2.1", 158 | supportsEncoder: true 159 | ) 160 | 161 | static let v21Palettized = ModelInfo( 162 | modelId: "apple/coreml-stable-diffusion-2-1-base-palettized", 163 | modelVersion: "StabilityAI SD 2.1 [6 bit]", 164 | supportsEncoder: true, 165 | supportsAttentionV2: true, 166 | quantized: true 167 | ) 168 | 169 | static let ofaSmall = ModelInfo( 170 
| modelId: "pcuenq/coreml-small-stable-diffusion-v0", 171 | modelVersion: "OFA-Sys/small-stable-diffusion-v0" 172 | ) 173 | 174 | static let xl = ModelInfo( 175 | modelId: "apple/coreml-stable-diffusion-xl-base", 176 | modelVersion: "Stable Diffusion XL base", 177 | supportsEncoder: true, 178 | isXL: true 179 | ) 180 | 181 | static let xlmbp = ModelInfo( 182 | modelId: "apple/coreml-stable-diffusion-mixed-bit-palettization", 183 | modelVersion: "Stable Diffusion XL base [4.5 bit]", 184 | supportsEncoder: true, 185 | quantized: true, 186 | isXL: true 187 | ) 188 | 189 | static let MODELS: [ModelInfo] = { 190 | if deviceSupportsQuantization { 191 | return [ 192 | ModelInfo.v14Base, 193 | ModelInfo.v14Palettized, 194 | ModelInfo.v15Base, 195 | ModelInfo.v15Palettized, 196 | ModelInfo.v2Base, 197 | ModelInfo.v2Palettized, 198 | ModelInfo.v21Base, 199 | ModelInfo.v21Palettized, 200 | ModelInfo.xl, 201 | ModelInfo.xlmbp 202 | ] 203 | } else { 204 | return [ 205 | ModelInfo.v14Base, 206 | ModelInfo.v15Base, 207 | ModelInfo.v2Base, 208 | ModelInfo.v21Base, 209 | ] 210 | } 211 | }() 212 | 213 | static func from(modelVersion: String) -> ModelInfo? { 214 | ModelInfo.MODELS.first(where: {$0.modelVersion == modelVersion}) 215 | } 216 | 217 | static func from(modelId: String) -> ModelInfo? { 218 | ModelInfo.MODELS.first(where: {$0.modelId == modelId}) 219 | } 220 | } 221 | 222 | extension ModelInfo : Equatable { 223 | static func ==(lhs: ModelInfo, rhs: ModelInfo) -> Bool { lhs.modelId == rhs.modelId } 224 | } 225 | -------------------------------------------------------------------------------- /Diffusion/Common/Pipeline/Pipeline.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Pipeline.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on December 2022. 
6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import Foundation 10 | import CoreML 11 | import Combine 12 | 13 | import StableDiffusion 14 | 15 | struct StableDiffusionProgress { 16 | var progress: StableDiffusionPipeline.Progress 17 | 18 | var step: Int { progress.step } 19 | var stepCount: Int { progress.stepCount } 20 | 21 | var currentImages: [CGImage?] 22 | 23 | init(progress: StableDiffusionPipeline.Progress, previewIndices: [Bool]) { 24 | self.progress = progress 25 | self.currentImages = [nil] 26 | 27 | // Since currentImages is a computed property, only access the preview image if necessary 28 | if progress.step < previewIndices.count, previewIndices[progress.step] { 29 | self.currentImages = progress.currentImages 30 | } 31 | } 32 | } 33 | 34 | struct GenerationResult { 35 | var image: CGImage? 36 | var lastSeed: UInt32 37 | var interval: TimeInterval? 38 | var userCanceled: Bool 39 | var itsPerSecond: Double? 40 | } 41 | 42 | class Pipeline { 43 | let pipeline: StableDiffusionPipelineProtocol 44 | let maxSeed: UInt32 45 | 46 | var progress: StableDiffusionProgress? = nil { 47 | didSet { 48 | progressPublisher.value = progress 49 | } 50 | } 51 | lazy private(set) var progressPublisher: CurrentValueSubject = CurrentValueSubject(progress) 52 | 53 | private var canceled = false 54 | 55 | init(_ pipeline: StableDiffusionPipelineProtocol, maxSeed: UInt32 = UInt32.max) { 56 | self.pipeline = pipeline 57 | self.maxSeed = maxSeed 58 | } 59 | 60 | func generate( 61 | prompt: String, 62 | negativePrompt: String = "", 63 | scheduler: StableDiffusionScheduler, 64 | numInferenceSteps stepCount: Int = 50, 65 | seed: UInt32 = 0, 66 | numPreviews previewCount: Int = 5, 67 | guidanceScale: Float = 7.5, 68 | disableSafety: Bool = false 69 | ) throws -> GenerationResult { 70 | let beginDate = Date() 71 | canceled = false 72 | let theSeed = seed > 0 ? 
seed : UInt32.random(in: 1...maxSeed) 73 | let sampleTimer = SampleTimer() 74 | sampleTimer.start() 75 | 76 | var config = StableDiffusionPipeline.Configuration(prompt: prompt) 77 | config.negativePrompt = negativePrompt 78 | config.stepCount = stepCount 79 | config.seed = theSeed 80 | config.guidanceScale = guidanceScale 81 | config.disableSafety = disableSafety 82 | config.schedulerType = scheduler 83 | config.useDenoisedIntermediates = true 84 | 85 | // Evenly distribute previews based on inference steps 86 | let previewIndices = previewIndices(stepCount, previewCount) 87 | 88 | let images = try pipeline.generateImages(configuration: config) { progress in 89 | sampleTimer.stop() 90 | handleProgress(StableDiffusionProgress(progress: progress, 91 | previewIndices: previewIndices), 92 | sampleTimer: sampleTimer) 93 | if progress.stepCount != progress.step { 94 | sampleTimer.start() 95 | } 96 | return !canceled 97 | } 98 | let interval = Date().timeIntervalSince(beginDate) 99 | print("Got images: \(images) in \(interval)") 100 | 101 | // Unwrap the 1 image we asked for, nil means safety checker triggered 102 | let image = images.compactMap({ $0 }).first 103 | return GenerationResult(image: image, lastSeed: theSeed, interval: interval, userCanceled: canceled, itsPerSecond: 1.0/sampleTimer.median) 104 | } 105 | 106 | func handleProgress(_ progress: StableDiffusionProgress, sampleTimer: SampleTimer) { 107 | self.progress = progress 108 | } 109 | 110 | func setCancelled() { 111 | canceled = true 112 | } 113 | } 114 | -------------------------------------------------------------------------------- /Diffusion/Common/Pipeline/PipelineLoader.swift: -------------------------------------------------------------------------------- 1 | // 2 | // PipelineLoader.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on December 2022. 
6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | 10 | import CoreML 11 | import Combine 12 | 13 | import ZIPFoundation 14 | import StableDiffusion 15 | 16 | class PipelineLoader { 17 | static let models = Settings.shared.applicationSupportURL().appendingPathComponent("hf-diffusion-models") 18 | let model: ModelInfo 19 | let computeUnits: ComputeUnits 20 | let maxSeed: UInt32 21 | 22 | private var downloadSubscriber: Cancellable? 23 | 24 | init(model: ModelInfo, computeUnits: ComputeUnits? = nil, maxSeed: UInt32 = UInt32.max) { 25 | self.model = model 26 | self.computeUnits = computeUnits ?? model.defaultComputeUnits 27 | self.maxSeed = maxSeed 28 | state = .undetermined 29 | setInitialState() 30 | } 31 | 32 | enum PipelinePreparationPhase { 33 | case undetermined 34 | case waitingToDownload 35 | case downloading(Double) 36 | case downloaded 37 | case uncompressing 38 | case readyOnDisk 39 | case loaded 40 | case failed(Error) 41 | } 42 | 43 | var state: PipelinePreparationPhase { 44 | didSet { 45 | statePublisher.value = state 46 | } 47 | } 48 | private(set) lazy var statePublisher: CurrentValueSubject = CurrentValueSubject(state) 49 | private(set) var downloader: Downloader? = nil 50 | 51 | func setInitialState() { 52 | if ready { 53 | state = .readyOnDisk 54 | return 55 | } 56 | if downloaded { 57 | state = .downloaded 58 | return 59 | } 60 | state = .waitingToDownload 61 | } 62 | } 63 | 64 | extension PipelineLoader { 65 | // Unused. Kept for debugging purposes. 
--pcuenca 66 | static func removeAll() { 67 | // Delete the parent models folder as it will be recreated when it's needed again 68 | do { 69 | try FileManager.default.removeItem(at: models) 70 | } catch { 71 | print("Failed to delete: \(models), error: \(error.localizedDescription)") 72 | } 73 | } 74 | } 75 | 76 | 77 | extension PipelineLoader { 78 | func cancel() { downloader?.cancel() } 79 | } 80 | 81 | extension PipelineLoader { 82 | var url: URL { 83 | return model.modelURL(for: variant) 84 | } 85 | 86 | var filename: String { 87 | return url.lastPathComponent 88 | } 89 | 90 | var downloadedURL: URL { PipelineLoader.models.appendingPathComponent(filename) } 91 | 92 | var uncompressURL: URL { PipelineLoader.models } 93 | 94 | var packagesFilename: String { (filename as NSString).deletingPathExtension } 95 | 96 | var compiledURL: URL { downloadedURL.deletingLastPathComponent().appendingPathComponent(packagesFilename) } 97 | 98 | var downloaded: Bool { 99 | return FileManager.default.fileExists(atPath: downloadedURL.path) 100 | } 101 | 102 | var ready: Bool { 103 | return FileManager.default.fileExists(atPath: compiledURL.path) 104 | } 105 | 106 | var variant: AttentionVariant { 107 | switch computeUnits { 108 | case .cpuOnly : return .original // Not supported yet 109 | case .cpuAndGPU : return .original 110 | case .cpuAndNeuralEngine: return model.supportsAttentionV2 ? 
.splitEinsumV2 : .splitEinsum 111 | case .all : return .splitEinsum 112 | @unknown default: 113 | fatalError("Unknown MLComputeUnits") 114 | } 115 | } 116 | 117 | func prepare() async throws -> Pipeline { 118 | do { 119 | do { 120 | try FileManager.default.createDirectory(atPath: PipelineLoader.models.path, withIntermediateDirectories: true, attributes: nil) 121 | } catch { 122 | print("Error creating PipelineLoader.models path: \(error)") 123 | } 124 | 125 | try await download() 126 | try await unzip() 127 | let pipeline = try await load(url: compiledURL) 128 | return Pipeline(pipeline, maxSeed: maxSeed) 129 | } catch { 130 | state = .failed(error) 131 | throw error 132 | } 133 | } 134 | 135 | @discardableResult 136 | func download() async throws -> URL { 137 | if ready || downloaded { return downloadedURL } 138 | 139 | let downloader = Downloader(from: url, to: downloadedURL) 140 | self.downloader = downloader 141 | downloadSubscriber = downloader.downloadState.sink { state in 142 | if case .downloading(let progress) = state { 143 | self.state = .downloading(progress) 144 | } 145 | } 146 | try downloader.waitUntilDone() 147 | return downloadedURL 148 | } 149 | 150 | func unzip() async throws { 151 | guard downloaded else { return } 152 | state = .uncompressing 153 | do { 154 | try FileManager().unzipItem(at: downloadedURL, to: uncompressURL) 155 | } catch { 156 | // Cleanup if error occurs while unzipping 157 | try FileManager.default.removeItem(at: uncompressURL) 158 | throw error 159 | } 160 | try FileManager.default.removeItem(at: downloadedURL) 161 | state = .readyOnDisk 162 | } 163 | 164 | func load(url: URL) async throws -> StableDiffusionPipelineProtocol { 165 | let beginDate = Date() 166 | let configuration = MLModelConfiguration() 167 | configuration.computeUnits = computeUnits 168 | let pipeline: StableDiffusionPipelineProtocol 169 | if model.isXL { 170 | if #available(macOS 14.0, iOS 17.0, *) { 171 | pipeline = try 
StableDiffusionXLPipeline(resourcesAt: url, 172 | configuration: configuration, 173 | reduceMemory: model.reduceMemory) 174 | } else { 175 | throw "Stable Diffusion XL requires macOS 14" 176 | } 177 | } else { 178 | pipeline = try StableDiffusionPipeline(resourcesAt: url, 179 | controlNet: [], 180 | configuration: configuration, 181 | disableSafety: false, 182 | reduceMemory: model.reduceMemory) 183 | } 184 | try pipeline.loadResources() 185 | print("Pipeline loaded in \(Date().timeIntervalSince(beginDate))") 186 | state = .loaded 187 | return pipeline 188 | } 189 | } 190 | -------------------------------------------------------------------------------- /Diffusion/Common/State.swift: -------------------------------------------------------------------------------- 1 | // 2 | // State.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on 17/1/23. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import Combine 10 | import SwiftUI 11 | import StableDiffusion 12 | import CoreML 13 | 14 | let DEFAULT_MODEL = ModelInfo.v2Base 15 | let DEFAULT_PROMPT = "Labrador in the style of Vermeer" 16 | 17 | enum GenerationState { 18 | case startup 19 | case running(StableDiffusionProgress?) 20 | case complete(String, CGImage?, UInt32, TimeInterval?) 21 | case userCanceled 22 | case failed(Error) 23 | } 24 | 25 | typealias ComputeUnits = MLComputeUnits 26 | 27 | class GenerationContext: ObservableObject { 28 | let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler 29 | 30 | @Published var pipeline: Pipeline? 
= nil { 31 | didSet { 32 | if let pipeline = pipeline { 33 | progressSubscriber = pipeline 34 | .progressPublisher 35 | .receive(on: DispatchQueue.main) 36 | .sink { progress in 37 | guard let progress = progress else { return } 38 | self.updatePreviewIfNeeded(progress) 39 | self.state = .running(progress) 40 | } 41 | } 42 | } 43 | } 44 | @Published var state: GenerationState = .startup 45 | 46 | @Published var positivePrompt = DEFAULT_PROMPT 47 | @Published var negativePrompt = "" 48 | 49 | // FIXME: Double to support the slider component 50 | @Published var steps = 25.0 51 | @Published var numImages = 1.0 52 | @Published var seed: UInt32 = 0 53 | @Published var guidanceScale = 7.5 54 | @Published var previews = 5.0 55 | @Published var disableSafety = false 56 | @Published var previewImage: CGImage? = nil 57 | 58 | @Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits 59 | 60 | private var progressSubscriber: Cancellable? 61 | 62 | private func updatePreviewIfNeeded(_ progress: StableDiffusionProgress) { 63 | if previews == 0 || progress.step == 0 { 64 | previewImage = nil 65 | } 66 | 67 | if previews > 0, let newImage = progress.currentImages.first, newImage != nil { 68 | previewImage = newImage 69 | } 70 | } 71 | 72 | func generate() async throws -> GenerationResult { 73 | guard let pipeline = pipeline else { throw "No pipeline" } 74 | return try pipeline.generate( 75 | prompt: positivePrompt, 76 | negativePrompt: negativePrompt, 77 | scheduler: scheduler, 78 | numInferenceSteps: Int(steps), 79 | seed: seed, 80 | numPreviews: Int(previews), 81 | guidanceScale: Float(guidanceScale), 82 | disableSafety: disableSafety 83 | ) 84 | } 85 | 86 | func cancelGeneration() { 87 | pipeline?.setCancelled() 88 | } 89 | } 90 | 91 | class Settings { 92 | static let shared = Settings() 93 | 94 | let defaults = UserDefaults.standard 95 | 96 | enum Keys: String { 97 | case model 98 | case safetyCheckerDisclaimer 99 | 
case computeUnits 100 | } 101 | 102 | private init() { 103 | defaults.register(defaults: [ 104 | Keys.model.rawValue: ModelInfo.v2Base.modelId, 105 | Keys.safetyCheckerDisclaimer.rawValue: false, 106 | Keys.computeUnits.rawValue: -1 // Use default 107 | ]) 108 | } 109 | 110 | var currentModel: ModelInfo { 111 | set { 112 | defaults.set(newValue.modelId, forKey: Keys.model.rawValue) 113 | } 114 | get { 115 | guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL } 116 | return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL 117 | } 118 | } 119 | 120 | var safetyCheckerDisclaimerShown: Bool { 121 | set { 122 | defaults.set(newValue, forKey: Keys.safetyCheckerDisclaimer.rawValue) 123 | } 124 | get { 125 | return defaults.bool(forKey: Keys.safetyCheckerDisclaimer.rawValue) 126 | } 127 | } 128 | 129 | /// Returns the option selected by the user, if overridden 130 | /// `nil` means: guess best 131 | var userSelectedComputeUnits: ComputeUnits? { 132 | set { 133 | // Any value other than the supported ones would cause `get` to return `nil` 134 | defaults.set(newValue?.rawValue ?? 
-1, forKey: Keys.computeUnits.rawValue) 135 | } 136 | get { 137 | let current = defaults.integer(forKey: Keys.computeUnits.rawValue) 138 | guard current != -1 else { return nil } 139 | return ComputeUnits(rawValue: current) 140 | } 141 | } 142 | 143 | public func applicationSupportURL() -> URL { 144 | let fileManager = FileManager.default 145 | guard let appDirectoryURL = fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first else { 146 | // To ensure we don't return an optional - if the user domain application support cannot be accessed use the top level application support directory 147 | return URL.applicationSupportDirectory 148 | } 149 | 150 | do { 151 | // Create the application support directory if it doesn't exist 152 | try fileManager.createDirectory(at: appDirectoryURL, withIntermediateDirectories: true, attributes: nil) 153 | return appDirectoryURL 154 | } catch { 155 | print("Error creating application support directory: \(error)") 156 | return fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! 157 | } 158 | } 159 | 160 | } 161 | -------------------------------------------------------------------------------- /Diffusion/Common/Utils.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Utils.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on 14/1/23. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import Foundation 10 | 11 | extension String: Error {} 12 | 13 | extension Double { 14 | func formatted(_ format: String) -> String { 15 | return String(format: "\(format)", self) 16 | } 17 | } 18 | 19 | /// Returns an array of booleans that indicates at which steps a preview should be generated. 20 | /// 21 | /// - Parameters: 22 | /// - numInferenceSteps: The total number of inference steps. 23 | /// - numPreviews: The desired number of previews. 
24 | /// 25 | /// - Returns: An array of booleans of size `numInferenceSteps`, where `true` values represent steps at which a preview should be made. 26 | func previewIndices(_ numInferenceSteps: Int, _ numPreviews: Int) -> [Bool] { 27 | // Ensure valid parameters 28 | guard numInferenceSteps > 0, numPreviews > 0 else { 29 | return [Bool](repeating: false, count: numInferenceSteps) 30 | } 31 | 32 | // Compute the ideal (floating-point) step size, which represents the average number of steps between previews 33 | let idealStep = Double(numInferenceSteps) / Double(numPreviews) 34 | 35 | // Compute the actual steps at which previews should be made. For each preview, we multiply the ideal step size by the preview number, and round to the nearest integer. 36 | // The result is converted to a `Set` for fast membership tests. 37 | let previewIndices: Set = Set((0.. Double { 49 | let multiplier = pow(10, Double(places)) 50 | let newDecimal = multiplier * self // move the decimal right 51 | let truncated = Double(Int(newDecimal)) // drop the fraction 52 | let originalDecimal = truncated / multiplier // move the decimal back 53 | return originalDecimal 54 | } 55 | } 56 | 57 | func formatLargeNumber(_ n: UInt32) -> String { 58 | let num = abs(Double(n)) 59 | 60 | switch num { 61 | case 1_000_000_000...: 62 | var formatted = num / 1_000_000_000 63 | formatted = formatted.reduceScale(to: 3) 64 | return "\(formatted)B" 65 | 66 | case 1_000_000...: 67 | var formatted = num / 1_000_000 68 | formatted = formatted.reduceScale(to: 3) 69 | return "\(formatted)M" 70 | 71 | case 1_000...: 72 | return "\(n)" 73 | 74 | case 0...: 75 | return "\(n)" 76 | 77 | default: 78 | return "\(n)" 79 | } 80 | } 81 | -------------------------------------------------------------------------------- /Diffusion/Common/Views/PromptTextField.swift: -------------------------------------------------------------------------------- 1 | // 2 | // PromptTextField.swift 3 | // Diffusion-macOS 4 | // 5 | // Created 
by Dolmere on 22/06/2023. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | import Combine 11 | import StableDiffusion 12 | 13 | struct PromptTextField: View { 14 | @State private var output: String = "" 15 | @State private var input: String = "" 16 | @State private var typing = false 17 | @State private var tokenCount: Int = 0 18 | @State var isPositivePrompt: Bool = true 19 | @State private var tokenizer: BPETokenizer? 20 | @State private var currentModelVersion: String = "" 21 | 22 | @Binding var textBinding: String 23 | @Binding var model: String // the model version as it's stored in Settings 24 | 25 | private let maxTokenCount = 77 26 | 27 | private var modelInfo: ModelInfo? { 28 | ModelInfo.from(modelVersion: $model.wrappedValue) 29 | } 30 | 31 | private var filename: String? { 32 | let variant = modelInfo?.bestAttention ?? .original 33 | return modelInfo?.modelURL(for: variant).lastPathComponent 34 | } 35 | 36 | private var downloadedURL: URL? { 37 | if let filename = filename { 38 | return PipelineLoader.models.appendingPathComponent(filename) 39 | } 40 | return nil 41 | } 42 | 43 | private var packagesFilename: String? { 44 | (filename as NSString?)?.deletingPathExtension 45 | } 46 | 47 | private var compiledURL: URL? 
{ 48 | if let packagesFilename = packagesFilename { 49 | return downloadedURL?.deletingLastPathComponent().appendingPathComponent(packagesFilename) 50 | } 51 | return nil 52 | } 53 | 54 | private var textColor: Color { 55 | switch tokenCount { 56 | case 0...65: 57 | return .green 58 | case 66...75: 59 | return .orange 60 | default: 61 | return .red 62 | } 63 | } 64 | 65 | // macOS initializer 66 | init(text: Binding, isPositivePrompt: Bool, model: Binding) { 67 | _textBinding = text 68 | self.isPositivePrompt = isPositivePrompt 69 | _model = model 70 | } 71 | 72 | // iOS initializer 73 | init(text: Binding, isPositivePrompt: Bool, model: String) { 74 | _textBinding = text 75 | self.isPositivePrompt = isPositivePrompt 76 | _model = .constant(model) 77 | } 78 | 79 | var body: some View { 80 | VStack { 81 | #if os(macOS) 82 | TextField(isPositivePrompt ? "Positive prompt" : "Negative Prompt", text: $textBinding, 83 | axis: .vertical) 84 | .lineLimit(20) 85 | .textFieldStyle(.squareBorder) 86 | .listRowInsets(EdgeInsets(top: 0, leading: -20, bottom: 0, trailing: 20)) 87 | .foregroundColor(textColor == .green ? .primary : textColor) 88 | .frame(minHeight: 30) 89 | if modelInfo != nil && tokenizer != nil { 90 | HStack { 91 | Spacer() 92 | if !textBinding.isEmpty { 93 | Text("\(tokenCount)") 94 | .foregroundColor(textColor) 95 | Text(" / \(maxTokenCount)") 96 | } 97 | } 98 | .onReceive(Just(textBinding)) { text in 99 | updateTokenCount(newText: text) 100 | } 101 | .font(.caption) 102 | } 103 | #else 104 | TextField("Prompt", text: $textBinding, axis: .vertical) 105 | .lineLimit(20) 106 | .listRowInsets(EdgeInsets(top: 0, leading: -20, bottom: 0, trailing: 20)) 107 | .foregroundColor(textColor == .green ? 
.primary : textColor) 108 | .frame(minHeight: 30) 109 | HStack { 110 | if !textBinding.isEmpty { 111 | Text("\(tokenCount)") 112 | .foregroundColor(textColor) 113 | Text(" / \(maxTokenCount)") 114 | } 115 | Spacer() 116 | } 117 | .onReceive(Just(textBinding)) { text in 118 | updateTokenCount(newText: text) 119 | } 120 | .font(.caption) 121 | #endif 122 | } 123 | .onChange(of: model) { model in 124 | updateTokenCount(newText: textBinding) 125 | } 126 | .onAppear { 127 | updateTokenCount(newText: textBinding) 128 | } 129 | } 130 | 131 | private func updateTokenCount(newText: String) { 132 | // ensure that the compiled URL exists 133 | guard let compiledURL = compiledURL else { return } 134 | // Initialize the tokenizer only when it's not created yet or the model changes 135 | // Check if the model version has changed 136 | let modelVersion = $model.wrappedValue 137 | if modelVersion != currentModelVersion { 138 | do { 139 | tokenizer = try BPETokenizer( 140 | mergesAt: compiledURL.appendingPathComponent("merges.txt"), 141 | vocabularyAt: compiledURL.appendingPathComponent("vocab.json") 142 | ) 143 | currentModelVersion = modelVersion 144 | } catch { 145 | print("Failed to create tokenizer: \(error)") 146 | return 147 | } 148 | } 149 | let (tokens, _) = tokenizer?.tokenize(input: newText) ?? 
([], []) 150 | 151 | DispatchQueue.main.async { 152 | self.tokenCount = tokens.count 153 | } 154 | } 155 | } 156 | -------------------------------------------------------------------------------- /Diffusion/Diffusion.entitlements: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | com.apple.developer.kernel.extended-virtual-addressing 6 | 7 | com.apple.developer.kernel.increased-memory-limit 8 | 9 | com.apple.security.app-sandbox 10 | 11 | com.apple.security.files.user-selected.read-write 12 | 13 | com.apple.security.network.client 14 | 15 | 16 | 17 | -------------------------------------------------------------------------------- /Diffusion/DiffusionApp.swift: -------------------------------------------------------------------------------- 1 | // 2 | // DiffusionApp.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on December 2022. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | 11 | @main 12 | struct DiffusionApp: App { 13 | var body: some Scene { 14 | WindowGroup { 15 | LoadingView() 16 | } 17 | } 18 | } 19 | 20 | let runningOnMac = ProcessInfo.processInfo.isMacCatalystApp 21 | let deviceHas6GBOrMore = ProcessInfo.processInfo.physicalMemory > 5924000000 // Different devices report different amounts, so approximate 22 | 23 | let deviceSupportsQuantization = { 24 | if #available(iOS 17, *) { 25 | true 26 | } else { 27 | false 28 | } 29 | }() 30 | -------------------------------------------------------------------------------- /Diffusion/Info.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | NSPhotoLibraryAddUsageDescription 6 | To be able to save generated images to your Photo Library, you’ll need to allow this. 
7 | CFBundleDisplayName 8 | Diffusers 9 | 10 | 11 | -------------------------------------------------------------------------------- /Diffusion/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /Diffusion/Support/AppState.swift: -------------------------------------------------------------------------------- 1 | // 2 | // AppState.swift 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 17/12/2022. 6 | // 7 | 8 | import Foundation 9 | import CoreML 10 | import Combine 11 | 12 | class AppState: ObservableObject { 13 | static let shared = AppState() 14 | 15 | @Published var pipeline: Pipeline? = nil 16 | @Published var modelDir = URL(string: "http://google.com")! 17 | @Published var models = [String]() 18 | 19 | private let def = UserDefaults.standard 20 | private(set) lazy var statePublisher: CurrentValueSubject = CurrentValueSubject(state) 21 | 22 | var state: MainViewState = .loading { 23 | didSet { 24 | statePublisher.value = state 25 | } 26 | } 27 | 28 | var currentModel: String = "" { 29 | didSet { 30 | NSLog("*** Model set") 31 | Task { 32 | NSLog("*** Loading model") 33 | await load(model: currentModel) 34 | model = currentModel 35 | } 36 | } 37 | } 38 | 39 | var prompt: String { 40 | set { 41 | def.set(newValue, forKey: "SD_Prompt") 42 | } 43 | get { 44 | return def.value(forKey: "SD_Prompt") as? String ?? "discworld the truth, Highly detailed, Artstation, Colorful" 45 | } 46 | } 47 | 48 | var negPrompt: String { 49 | set { 50 | def.set(newValue, forKey: "SD_NegPrompt") 51 | } 52 | get { 53 | return def.value(forKey: "SD_NegPrompt") as? String ?? 
"ugly, boring, bad anatomy" 54 | } 55 | } 56 | 57 | var model: String { 58 | set { 59 | def.set(newValue, forKey: "SD_Model") 60 | } 61 | get { 62 | return def.value(forKey: "SD_Model") as? String ?? models.first ?? "" 63 | } 64 | } 65 | 66 | var scheduler: StableDiffusionScheduler { 67 | set { 68 | def.set(newValue.rawValue, forKey: "SD_Scheduler") 69 | } 70 | get { 71 | if let key = def.value(forKey: "SD_Scheduler") as? String { 72 | return StableDiffusionScheduler(rawValue: key) ?? .dpmpp // fix: a stale/renamed raw value persisted in UserDefaults must not crash the app; fall back to the default scheduler 73 | } 74 | return StableDiffusionScheduler.dpmpp 75 | } 76 | } 77 | 78 | var guidance: Double { 79 | set { 80 | def.set(newValue, forKey: "SD_Guidance") 81 | } 82 | get { 83 | return def.value(forKey: "SD_Guidance") as? Double ?? 7.5 84 | } 85 | } 86 | 87 | var steps: Double { 88 | set { 89 | def.set(newValue, forKey: "SD_Steps") 90 | } 91 | get { 92 | return def.value(forKey: "SD_Steps") as? Double ?? 25 93 | } 94 | } 95 | 96 | var numImages: Double { 97 | set { 98 | def.set(newValue, forKey: "SD_NumImages") 99 | } 100 | get { 101 | return def.value(forKey: "SD_NumImages") as? Double ?? 1 102 | } 103 | } 104 | 105 | private init() { 106 | NSLog("*** AppState initialized") 107 | // Does the model path exist? 108 | guard var dir = docDir else { 109 | state = .error("Could not get user document directory") 110 | return 111 | } 112 | dir.append(path: "Diffusion/models", directoryHint: .isDirectory) 113 | let fm = FileManager.default 114 | if !fm.fileExists(atPath: dir.path) { 115 | NSLog("Models directory does not exist at: \(dir.path). Creating ...") 116 | try? 
fm.createDirectory(at: dir, withIntermediateDirectories: true) 117 | } 118 | modelDir = dir 119 | // Find models in model dir 120 | do { 121 | let subs = try dir.subDirectories() 122 | subs.forEach {sub in 123 | models.append(sub.lastPathComponent) 124 | } 125 | } catch { 126 | state = .error("Could not get sub-folders under model directory: \(dir.path)") 127 | return 128 | } 129 | NSLog("*** Setting model") 130 | self.currentModel = model 131 | // On start, didSet does not appear to fire 132 | Task { 133 | await load(model: currentModel) 134 | } 135 | } 136 | 137 | func load(model: String) async { 138 | NSLog("*** Loading model: \(model)") 139 | let dir = modelDir.appending(component: model, directoryHint: .isDirectory) 140 | let fm = FileManager.default 141 | if !fm.fileExists(atPath: dir.path) { 142 | let msg = "Model directory: \(model) does not exist at: \(dir.path). Cannot proceed." 143 | NSLog(msg) 144 | state = .error(msg) 145 | return 146 | } 147 | let beginDate = Date() 148 | let configuration = MLModelConfiguration() 149 | // .all works for v1.4, but not for v1.5 150 | configuration.computeUnits = .cpuAndGPU 151 | // TODO: measure performance on different devices 152 | do { 153 | let pipeline = try StableDiffusionPipeline(resourcesAt: dir, configuration: configuration, disableSafety: true) 154 | NSLog("Pipeline loaded in \(Date().timeIntervalSince(beginDate))") 155 | DispatchQueue.main.async { 156 | self.pipeline = Pipeline(pipeline) 157 | self.state = .ready("Ready") 158 | } 159 | } catch { 160 | NSLog("Error loading model: \(error)") 161 | DispatchQueue.main.async { 162 | self.state = .error(error.localizedDescription) 163 | } 164 | } 165 | } 166 | } 167 | -------------------------------------------------------------------------------- /Diffusion/Support/Extensions.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Extensions.swift 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 17/12/2022. 
6 | // 7 | 8 | import Foundation 9 | 10 | extension String: Error {} 11 | 12 | extension URL { 13 | func subDirectories() throws -> [URL] { 14 | guard hasDirectoryPath else { return [] } 15 | return try FileManager.default.contentsOfDirectory(at: self, includingPropertiesForKeys: nil, options: [.skipsHiddenFiles]).filter(\.hasDirectoryPath) 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /Diffusion/Support/Functions.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Functions.swift 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 17/12/2022. 6 | // 7 | 8 | import Foundation 9 | 10 | var docDir: URL? { 11 | return FileManager.default.urls(for: .documentDirectory, in: .userDomainMask).first 12 | } 13 | -------------------------------------------------------------------------------- /Diffusion/Support/SDImage.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SDImage.swift 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 18/12/2022. 6 | // 7 | 8 | import Foundation 9 | import CoreGraphics 10 | import UniformTypeIdentifiers 11 | //import CoreGraphics 12 | //import ImageIO 13 | #if os(macOS) 14 | import AppKit 15 | #endif 16 | 17 | struct SDImage { 18 | var image: CGImage? 
19 | var prompt = "" 20 | var negPrompt = "" 21 | var model = "" 22 | var scheduler = "" 23 | var seed = -1 24 | var numSteps = 25 25 | var guidance = 7.5 26 | var imageIndex = 0 27 | 28 | var seedStr: String { 29 | return String(seed) 30 | } 31 | 32 | // Save image with metadata 33 | func save() { 34 | guard let img = image else { 35 | NSLog("*** Image was not valid!") 36 | return 37 | } 38 | #if os(macOS) 39 | let panel = NSSavePanel() 40 | panel.allowedContentTypes = [.jpeg] 41 | panel.canCreateDirectories = true 42 | panel.isExtensionHidden = false 43 | panel.title = "Save your image" 44 | panel.message = "Choose a folder and a name to store the image." 45 | panel.nameFieldLabel = "Image file name:" 46 | let resp = panel.runModal() 47 | if resp != .OK { 48 | return 49 | } 50 | guard let url = panel.url else { return } 51 | guard let data = CFDataCreateMutable(nil, 0) else { return } 52 | guard let destination = CGImageDestinationCreateWithData(data, UTType.jpeg.identifier as CFString, 1, nil) else { return } 53 | let iptc = [kCGImagePropertyIPTCOriginatingProgram: meta(), kCGImagePropertyIPTCCaptionAbstract: title(), kCGImagePropertyIPTCProgramVersion: "\(seed)"] 54 | let meta = [kCGImagePropertyIPTCDictionary: iptc] 55 | CGImageDestinationAddImage(destination, img, meta as CFDictionary) 56 | guard CGImageDestinationFinalize(destination) else { return } 57 | // Save image that now has metadata 58 | do { 59 | try (data as Data).write(to: url) 60 | } catch { 61 | NSLog("*** Error saving image file: \(error)") 62 | } 63 | #endif 64 | } 65 | 66 | // IPTC metadata string; fix: "Seed: \(seed)" was embedded twice 67 | private func meta() -> String { 68 | return title() + " Seed: \(seed), Model: \(model), Scheduler: \(scheduler), Steps: \(numSteps), Guidance: \(guidance), Index: \(imageIndex)" 69 | } 70 | 71 | private func title() -> String { 72 | return "Prompt: \(prompt) + Negative: \(negPrompt)" 73 | } 74 | } 75 | -------------------------------------------------------------------------------- 
/Diffusion/Views/ErrorBanner.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ErrorPopover.swift 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 17/12/2022. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct ErrorBanner: View { 11 | var errorMessage: String 12 | 13 | var body: some View { 14 | Text(errorMessage) 15 | .frame(maxWidth: .infinity) 16 | .font(.headline) 17 | .padding(8) 18 | .foregroundColor(.red) 19 | .background(Color.white) 20 | .cornerRadius(8) 21 | .shadow(color: Color.black.opacity(0.2), radius: 8, x: 0, y: 4) 22 | } 23 | } 24 | 25 | struct ErrorBanner_Previews: PreviewProvider { 26 | static var previews: some View { 27 | ErrorBanner(errorMessage: "This is an error!") 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /Diffusion/Views/Loading.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Loading.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on December 2022. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | import Combine 11 | 12 | let model = deviceSupportsQuantization ? ModelInfo.v21Palettized : ModelInfo.v21Base 13 | 14 | struct LoadingView: View { 15 | @StateObject var generation = GenerationContext() 16 | 17 | @State private var preparationPhase = "Downloading…" 18 | @State private var downloadProgress: Double = 0 19 | 20 | enum CurrentView { 21 | case loading 22 | case textToImage 23 | case error(String) 24 | } 25 | @State private var currentView: CurrentView = .loading 26 | 27 | @State private var stateSubscriber: Cancellable? 
28 | 29 | var body: some View { 30 | VStack { 31 | switch currentView { 32 | case .textToImage: TextToImage().transition(.opacity) 33 | case .error(let message): ErrorPopover(errorMessage: message).transition(.move(edge: .top)) 34 | case .loading: 35 | // TODO: Don't present progress view if the pipeline is cached 36 | ProgressView(preparationPhase, value: downloadProgress, total: 1).padding() 37 | } 38 | } 39 | .animation(.easeIn, value: currentView) 40 | .environmentObject(generation) 41 | .onAppear { 42 | Task.init { 43 | let loader = PipelineLoader(model: model) 44 | stateSubscriber = loader.statePublisher.sink { state in 45 | DispatchQueue.main.async { 46 | switch state { 47 | case .downloading(let progress): 48 | preparationPhase = "Downloading" 49 | downloadProgress = progress 50 | case .uncompressing: 51 | preparationPhase = "Uncompressing" 52 | downloadProgress = 1 53 | case .readyOnDisk: 54 | preparationPhase = "Loading" 55 | downloadProgress = 1 56 | default: 57 | break 58 | } 59 | } 60 | } 61 | do { 62 | generation.pipeline = try await loader.prepare() 63 | self.currentView = .textToImage 64 | } catch { 65 | self.currentView = .error("Could not load model, error: \(error)") 66 | } 67 | } 68 | } 69 | } 70 | } 71 | 72 | // Required by .animation 73 | extension LoadingView.CurrentView: Equatable {} 74 | 75 | struct ErrorPopover: View { 76 | var errorMessage: String 77 | 78 | var body: some View { 79 | Text(errorMessage) 80 | .font(.headline) 81 | .padding() 82 | .foregroundColor(.red) 83 | .background(Color.white) 84 | .cornerRadius(8) 85 | .shadow(color: Color.black.opacity(0.2), radius: 8, x: 0, y: 4) 86 | } 87 | } 88 | 89 | struct LoadingView_Previews: PreviewProvider { 90 | static var previews: some View { 91 | LoadingView() 92 | } 93 | } 94 | -------------------------------------------------------------------------------- /Diffusion/Views/MainAppView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // 
TextToImageView.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on December 2022. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | import Combine 11 | 12 | enum MainViewState { 13 | case loading 14 | case idle 15 | case ready(String) 16 | case error(String) 17 | case running(StableDiffusionProgress?) 18 | } 19 | 20 | struct MainAppView: View { 21 | @StateObject var cfg = AppState.shared 22 | 23 | @State private var image: SDImage? = nil 24 | @State private var state: MainViewState = .loading 25 | @State private var prompt = "" 26 | @State private var negPrompt = "" 27 | @State private var scheduler = StableDiffusionScheduler.dpmpp 28 | @State private var guidance = 7.5 29 | @State private var width = 512.0 30 | @State private var height = 512.0 31 | @State private var steps = 25.0 32 | @State private var numImages = 1.0 33 | @State private var seed = -1 34 | @State private var safetyOn: Bool = true 35 | @State private var images = [SDImage]() 36 | 37 | @State private var stateSubscriber: Cancellable? 38 | @State private var progressSubscriber: Cancellable? 39 | @State private var progressSubs: Cancellable? 
40 | 41 | var isBusy: Bool { 42 | if case .loading = state { 43 | return true 44 | } 45 | if case .running = state { 46 | return true 47 | } 48 | return false 49 | } 50 | 51 | var body: some View { 52 | VStack(alignment: .leading) { 53 | getBannerView() 54 | getTopView() 55 | Spacer().frame(height: 16) 56 | 57 | HSplitView { 58 | getSidebarView().frame(minWidth: 200, maxWidth: 400) 59 | 60 | getPreviewPane() 61 | } 62 | } 63 | .padding() 64 | .onAppear { 65 | // Set saved values 66 | prompt = cfg.prompt 67 | negPrompt = cfg.negPrompt 68 | scheduler = cfg.scheduler 69 | guidance = cfg.guidance 70 | steps = cfg.steps 71 | numImages = cfg.numImages 72 | // AppState state subscriber 73 | stateSubscriber = cfg.statePublisher.sink { state in 74 | DispatchQueue.main.async { 75 | self.state = state 76 | } 77 | } 78 | // Pipeline progress subscriber 79 | progressSubscriber = cfg.pipeline?.progressPublisher.sink { progress in 80 | guard let progress = progress else { return } 81 | state = .running(progress) 82 | } 83 | } 84 | } 85 | 86 | private func getProgressView(progress: StableDiffusionProgress?) -> AnyView { 87 | if let progress = progress, progress.stepCount > 0 { 88 | let step = Int(progress.step) + 1 89 | let fraction = Double(step) / Double(progress.stepCount) 90 | let label = "Step \(step) of \(progress.stepCount)" 91 | return AnyView(ProgressView(label, value: fraction, total: 1).padding()) 92 | } 93 | // The first time it takes a little bit before generation starts 94 | return AnyView(ProgressView(label: {Text("Loading ...")}).progressViewStyle(.linear).padding()) 95 | } 96 | 97 | private func getBannerView() -> AnyView? 
{ 98 | if case .loading = state { 99 | return AnyView(ErrorBanner(errorMessage: "Loading ...")) 100 | } else if case let .error(msg) = state { 101 | return AnyView(ErrorBanner(errorMessage: msg)) 102 | } else if case let .running(progress) = state { 103 | return getProgressView(progress: progress) 104 | } 105 | return nil 106 | } 107 | 108 | private func getTopView() -> AnyView { 109 | let vw = HStack { 110 | VStack { 111 | TextField("Prompt", text: $prompt) 112 | .textFieldStyle(.roundedBorder) 113 | TextField("Negative Prompt", text: $negPrompt) 114 | .textFieldStyle(.roundedBorder) 115 | } 116 | Button("Generate") { 117 | submit() 118 | } 119 | .padding() 120 | .buttonStyle(.borderedProminent) 121 | .disabled(isBusy) 122 | } 123 | return AnyView(vw) 124 | } 125 | 126 | private func getSidebarView() -> AnyView { 127 | let vw = VStack(alignment: .leading) { 128 | Group { 129 | Picker("Model", selection: $cfg.currentModel) { 130 | ForEach(cfg.models, id: \.self) { s in 131 | Text(s).tag(s) 132 | } 133 | } 134 | 135 | Spacer().frame(height: 16) 136 | 137 | Picker("Scheduler", selection: $scheduler) { 138 | ForEach(StableDiffusionScheduler.allCases, id: \.self) { s in 139 | Text(s.rawValue).tag(s) 140 | } 141 | } 142 | 143 | Spacer().frame(height: 16) 144 | 145 | Text("Guidance Scale: \(String(format: "%.1f", guidance))") 146 | Slider(value: $guidance, in: 0...15, step: 0.1, label: {}, 147 | minimumValueLabel: {Text("0")}, 148 | maximumValueLabel: {Text("15")}) 149 | 150 | Spacer().frame(height: 16) 151 | } 152 | Group { 153 | Text("Number of Inference Steps: \(String(format: "%.0f", steps))") 154 | Slider(value: $steps, in: 1...300, step: 1, label: {}, 155 | minimumValueLabel: {Text("1")}, 156 | maximumValueLabel: {Text("300")}) 157 | 158 | Spacer().frame(height: 16) 159 | 160 | Text("Number of Images: \(String(format: "%.0f", numImages))") 161 | Slider(value: $numImages, in: 1...8, step: 1, label: {}, 162 | minimumValueLabel: {Text("1")}, 163 | maximumValueLabel: 
{Text("8")}) 164 | 165 | Spacer().frame(height: 16) 166 | } 167 | Group { 168 | Text("Safety Check On?") 169 | Toggle("", isOn: $safetyOn) 170 | 171 | Spacer().frame(height: 16) 172 | 173 | Text("Seed") 174 | TextField("", value: $seed, format: .number) 175 | } 176 | // Group { 177 | // Text("Image Width") 178 | // Slider(value: $width, in: 64...2048, step: 8, label: {}, 179 | // minimumValueLabel: {Text("64")}, 180 | // maximumValueLabel: {Text("2048")}) 181 | // Text("Image Height") 182 | // Slider(value: $height, in: 64...2048, step: 8, label: {}, 183 | // minimumValueLabel: {Text("64")}, 184 | // maximumValueLabel: {Text("2048")}) 185 | // } 186 | Spacer() 187 | } 188 | .padding() 189 | return AnyView(vw) 190 | } 191 | 192 | private func getPreviewPane() -> AnyView { 193 | let vw = VStack { 194 | PreviewView(image: $image) 195 | .scaledToFit() 196 | 197 | Divider() 198 | 199 | if images.count > 0 { 200 | ScrollView { 201 | HStack { 202 | ForEach(Array(images.enumerated()), id: \.offset) { i, img in 203 | Image(img.image!, scale: 5, label: Text("")) 204 | .onTapGesture { 205 | selectImage(index: i) 206 | } 207 | Divider() 208 | } 209 | } 210 | } 211 | .frame(height: 103) 212 | } 213 | } 214 | .padding() 215 | return AnyView(vw) 216 | } 217 | 218 | private func submit() { 219 | if case .running = state { return } 220 | guard let pipeline = cfg.pipeline else { 221 | state = .error("No pipeline available!") 222 | return 223 | } 224 | state = .running(nil) 225 | // Save current config 226 | cfg.prompt = prompt 227 | cfg.negPrompt = negPrompt 228 | cfg.scheduler = scheduler 229 | cfg.guidance = guidance 230 | cfg.steps = steps 231 | cfg.numImages = numImages 232 | // Pipeline progress subscriber 233 | progressSubs = pipeline.progressPublisher.sink { progress in 234 | guard let progress = progress else { return } 235 | DispatchQueue.main.async { 236 | state = .running(progress) 237 | } 238 | } 239 | DispatchQueue.global(qos: .background).async { 240 | do { 241 | // 
Generate 242 | let (imgs, seed) = try pipeline.generate(prompt: prompt, negPrompt: negPrompt, scheduler: scheduler, numInferenceSteps: Int(steps), imageCount: Int(numImages), safetyOn: safetyOn, seed: seed) 243 | progressSubs?.cancel() 244 | // Create array of SDImage instances from images 245 | var simgs = [SDImage]() 246 | for (ndx, img) in imgs.enumerated() { 247 | var s = SDImage() 248 | s.image = img 249 | s.prompt = prompt 250 | s.negPrompt = negPrompt // fix: was `= prompt`, which recorded the positive prompt as the negative prompt in image metadata 251 | s.model = cfg.currentModel 252 | s.scheduler = scheduler.rawValue 253 | s.seed = seed 254 | s.numSteps = Int(steps) 255 | s.guidance = guidance 256 | s.imageIndex = ndx 257 | simgs.append(s) 258 | } 259 | DispatchQueue.main.async { 260 | image = simgs.first 261 | images.append(contentsOf: simgs) 262 | state = .ready("Image generation complete") 263 | } 264 | } catch { 265 | let msg = "Error generating images: \(error)" 266 | NSLog(msg) 267 | DispatchQueue.main.async { 268 | state = .error(msg) 269 | } 270 | } 271 | } 272 | } 273 | 274 | private func selectImage(index: Int) { 275 | image = images[index] 276 | } 277 | } 278 | 279 | struct MainAppView_Previews: PreviewProvider { 280 | static var previews: some View { 281 | MainAppView().previewLayout(.sizeThatFits) 282 | } 283 | } 284 | -------------------------------------------------------------------------------- /Diffusion/Views/PreviewView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // PreviewView.swift 3 | // Diffusion 4 | // 5 | // Created by Fahim Farook on 15/12/2022. 
6 | // 7 | 8 | import SwiftUI 9 | import UniformTypeIdentifiers 10 | #if os(macOS) 11 | import AppKit 12 | #else 13 | import UIKit 14 | #endif 15 | 16 | struct PreviewView: View { 17 | var image: Binding 18 | 19 | var body: some View { 20 | if let sdi = image.wrappedValue, let img = sdi.image { 21 | let imageView = Image(img, scale: 1, label: Text("generated")) 22 | return AnyView( 23 | VStack { 24 | imageView.resizable().clipShape(RoundedRectangle(cornerRadius: 20)) 25 | HStack { 26 | Text("Seed: \(sdi.seedStr)") 27 | .help("The seed for this image. Tap to copy to clipboard.") 28 | .onTapGesture { 29 | #if os(macOS) 30 | let pb = NSPasteboard.general 31 | pb.declareTypes([.string], owner: nil) 32 | pb.setString(sdi.seedStr, forType: .string) 33 | #else 34 | UIPasteboard.general.setValue(sdi.seedStr, forPasteboardType: UTType.plainText.identifier) 35 | #endif 36 | } 37 | Spacer() 38 | ShareLink(item: imageView, preview: SharePreview(sdi.prompt, image: imageView)) 39 | Button("Save", action: { 40 | sdi.save() 41 | }) 42 | } 43 | }) 44 | } 45 | return AnyView(Image("placeholder").resizable()) 46 | } 47 | } 48 | 49 | 50 | struct PreviewView_Previews: PreviewProvider { 51 | static var previews: some View { 52 | var sd = SDImage() 53 | sd.prompt = "Test Prompt" 54 | return PreviewView(image: .constant(sd)) 55 | } 56 | } 57 | -------------------------------------------------------------------------------- /Diffusion/Views/TextToImage.swift: -------------------------------------------------------------------------------- 1 | // 2 | // TextToImage.swift 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on December 2022. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import SwiftUI 10 | import Combine 11 | import StableDiffusion 12 | 13 | 14 | /// Presents "Share" + "Save" buttons on Mac; just "Share" on iOS/iPadOS. 15 | /// This is because I didn't find a way for "Share" to show a Save option when running on macOS. 
16 | struct ShareButtons: View { 17 | var image: CGImage 18 | var name: String 19 | 20 | var filename: String { 21 | name.replacingOccurrences(of: " ", with: "_") 22 | } 23 | 24 | var body: some View { 25 | let imageView = Image(image, scale: 1, label: Text(name)) 26 | 27 | if runningOnMac { 28 | HStack { 29 | ShareLink(item: imageView, preview: SharePreview(name, image: imageView)) 30 | Button() { 31 | guard let imageData = UIImage(cgImage: image).pngData() else { 32 | return 33 | } 34 | do { 35 | let fileURL = FileManager.default.temporaryDirectory.appendingPathComponent("\(filename).png") 36 | try imageData.write(to: fileURL) 37 | let controller = UIDocumentPickerViewController(forExporting: [fileURL]) 38 | 39 | let scene = UIApplication.shared.connectedScenes.first as! UIWindowScene 40 | scene.windows.first!.rootViewController!.present(controller, animated: true) 41 | } catch { 42 | print("Error creating file") 43 | } 44 | } label: { 45 | Label("Save…", systemImage: "square.and.arrow.down") 46 | } 47 | } 48 | } else { 49 | ShareLink(item: imageView, preview: SharePreview(name, image: imageView)) 50 | } 51 | } 52 | } 53 | 54 | struct ImageWithPlaceholder: View { 55 | @EnvironmentObject var generation: GenerationContext 56 | var state: Binding 57 | 58 | var body: some View { 59 | switch state.wrappedValue { 60 | case .startup: return AnyView(Image("placeholder").resizable()) 61 | case .running(let progress): 62 | guard let progress = progress, progress.stepCount > 0 else { 63 | // The first time it takes a little bit before generation starts 64 | return AnyView(ProgressView()) 65 | } 66 | 67 | let step = Int(progress.step) + 1 68 | let fraction = Double(step) / Double(progress.stepCount) 69 | let label = "Step \(step) of \(progress.stepCount)" 70 | return AnyView(VStack { 71 | Group { 72 | if let safeImage = generation.previewImage { 73 | Image(safeImage, scale: 1, label: Text("generated")) 74 | .resizable() 75 | .clipShape(RoundedRectangle(cornerRadius: 20)) 76 
| } 77 | } 78 | ProgressView(label, value: fraction, total: 1).padding() 79 | }) 80 | case .complete(let lastPrompt, let image, _, let interval): 81 | guard let theImage = image else { 82 | return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) 83 | } 84 | 85 | let imageView = Image(theImage, scale: 1, label: Text("generated")) 86 | return AnyView( 87 | VStack { 88 | imageView.resizable().clipShape(RoundedRectangle(cornerRadius: 20)) 89 | HStack { 90 | let intervalString = String(format: "Time: %.1fs", interval ?? 0) 91 | Rectangle().fill(.clear).overlay(Text(intervalString).frame(maxWidth: .infinity, alignment: .leading).padding(.leading)) 92 | Rectangle().fill(.clear).overlay( 93 | HStack { 94 | Spacer() 95 | ShareButtons(image: theImage, name: lastPrompt).padding(.trailing) 96 | } 97 | ) 98 | }.frame(maxHeight: 25) 99 | }) 100 | case .failed(_): 101 | return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) 102 | case .userCanceled: 103 | return AnyView(Text("Generation canceled")) 104 | } 105 | } 106 | } 107 | 108 | struct TextToImage: View { 109 | @EnvironmentObject var generation: GenerationContext 110 | 111 | func submit() { 112 | if case .running = generation.state { return } 113 | Task { 114 | generation.state = .running(nil) 115 | do { 116 | let result = try await generation.generate() 117 | generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval) 118 | } catch { 119 | generation.state = .failed(error) 120 | } 121 | } 122 | } 123 | 124 | var body: some View { 125 | VStack { 126 | HStack { 127 | PromptTextField(text: $generation.positivePrompt, isPositivePrompt: true, model: deviceSupportsQuantization ? 
ModelInfo.v21Palettized.modelVersion : ModelInfo.v21Base.modelVersion) 128 | Button("Generate") { 129 | submit() 130 | } 131 | .padding() 132 | .buttonStyle(.borderedProminent) 133 | } 134 | ImageWithPlaceholder(state: $generation.state) 135 | .scaledToFit() 136 | Spacer() 137 | } 138 | .padding() 139 | .environmentObject(generation) 140 | } 141 | } 142 | -------------------------------------------------------------------------------- /DiffusionTests/DiffusionTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // DiffusionTests.swift 3 | // DiffusionTests 4 | // 5 | // Created by Pedro Cuenca on December 2022. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import XCTest 10 | 11 | final class DiffusionTests: XCTestCase { 12 | 13 | override func setUpWithError() throws { 14 | // Put setup code here. This method is called before the invocation of each test method in the class. 15 | } 16 | 17 | override func tearDownWithError() throws { 18 | // Put teardown code here. This method is called after the invocation of each test method in the class. 19 | } 20 | 21 | func testExample() throws { 22 | // This is an example of a functional test case. 23 | // Use XCTAssert and related functions to verify your tests produce the correct results. 24 | // Any test you write for XCTest can be annotated as throws and async. 25 | // Mark your test throws to produce an unexpected failure when your test encounters an uncaught error. 26 | // Mark your test async to allow awaiting for asynchronous code to complete. Check the results with assertions afterwards. 27 | } 28 | 29 | func testPerformanceExample() throws { 30 | // This is an example of a performance test case. 31 | measure { 32 | // Put the code you want to measure the time of here. 
33 | } 34 | } 35 | 36 | } 37 | -------------------------------------------------------------------------------- /DiffusionUITests/DiffusionUITests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // DiffusionUITests.swift 3 | // DiffusionUITests 4 | // 5 | // Created by Pedro Cuenca on December 2022. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import XCTest 10 | 11 | final class DiffusionUITests: XCTestCase { 12 | 13 | override func setUpWithError() throws { 14 | // Put setup code here. This method is called before the invocation of each test method in the class. 15 | 16 | // In UI tests it is usually best to stop immediately when a failure occurs. 17 | continueAfterFailure = false 18 | 19 | // In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this. 20 | } 21 | 22 | override func tearDownWithError() throws { 23 | // Put teardown code here. This method is called after the invocation of each test method in the class. 24 | } 25 | 26 | func testExample() throws { 27 | // UI tests must launch the application that they test. 28 | let app = XCUIApplication() 29 | app.launch() 30 | 31 | // Use XCTAssert and related functions to verify your tests produce the correct results. 32 | } 33 | 34 | func testLaunchPerformance() throws { 35 | if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 7.0, *) { 36 | // This measures how long it takes to launch your application. 
37 | measure(metrics: [XCTApplicationLaunchMetric()]) { 38 | XCUIApplication().launch() 39 | } 40 | } 41 | } 42 | } 43 | -------------------------------------------------------------------------------- /DiffusionUITests/DiffusionUITestsLaunchTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // DiffusionUITestsLaunchTests.swift 3 | // DiffusionUITests 4 | // 5 | // Created by Pedro Cuenca on December 2022. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | import XCTest 10 | 11 | final class DiffusionUITestsLaunchTests: XCTestCase { 12 | 13 | override class var runsForEachTargetApplicationUIConfiguration: Bool { 14 | true 15 | } 16 | 17 | override func setUpWithError() throws { 18 | continueAfterFailure = false 19 | } 20 | 21 | func testLaunch() throws { 22 | let app = XCUIApplication() 23 | app.launch() 24 | 25 | // Insert steps here to perform after app launch but before taking a screenshot, 26 | // such as logging into a test account or navigating somewhere in the app 27 | 28 | let attachment = XCTAttachment(screenshot: app.screenshot()) 29 | attachment.name = "Launch Screen" 30 | attachment.lifetime = .keepAlways 31 | add(attachment) 32 | } 33 | } 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Hugging Face SAS. 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Swift Core ML Diffusers 🧨 2 | 3 | This is a native app that shows how to integrate Apple's [Core ML Stable Diffusion implementation](https://github.com/apple/ml-stable-diffusion) in a native Swift UI application. The Core ML port is a simplification of the Stable Diffusion implementation from the [diffusers library](https://github.com/huggingface/diffusers). This application can be used for faster iteration, or as sample code for any use cases. 4 | 5 | This is what the app looks like on macOS: 6 | ![App Screenshot](screenshot.jpg) 7 | 8 | On first launch, the application downloads a zipped archive with a Core ML version of Stability AI's Stable Diffusion v2 base, from [this location in the Hugging Face Hub](https://huggingface.co/pcuenq/coreml-stable-diffusion-2-base/tree/main). This process takes a while, as several GB of data have to be downloaded and unarchived. 
9 | 10 | For faster inference, we use a very fast scheduler: [DPM-Solver++](https://github.com/LuChengTHU/dpm-solver), that we ported to Swift from our [diffusers DPMSolverMultistepScheduler implementation](https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py). 11 | 12 | The app supports models quantized with `coremltools` version 7 or later. This requires macOS 14 or iOS/iPadOS 17. 13 | 14 | ## Compatibility and Performance 15 | 16 | - macOS Ventura 13.1, iOS/iPadOS 16.2, Xcode 14.2. 17 | - Performance (after the initial generation, which is slower) 18 | * ~8s in macOS on MacBook Pro M1 Max (64 GB). Model: Stable Diffusion v2-base, ORIGINAL attention implementation, running on CPU + GPU. 19 | * 23 ~ 30s on iPhone 13 Pro. Model: Stable Diffusion v2-base, SPLIT_EINSUM attention, CPU + Neural Engine, memory reduction enabled. 20 | 21 | See [this post](https://huggingface.co/blog/fast-mac-diffusers) and [this issue](https://github.com/huggingface/swift-coreml-diffusers/issues/31) for additional performance figures. 22 | 23 | Quantized models run faster, but they require macOS Sonoma (14), or iOS/iPadOS 17. 24 | 25 | The application will try to guess the best hardware to run models on. You can override this setting using the `Advanced` section in the controls sidebar. 26 | 27 | ## How to Run 28 | 29 | The easiest way to test the app on macOS is by [downloading it from the Mac App Store](https://apps.apple.com/app/diffusers/id1666309574). 30 | 31 | ## How to Build 32 | 33 | You need [Xcode](https://developer.apple.com/xcode/) to build the app. When you clone the repo, please update `common.xcconfig` with your development team identifier. Code signing is required to run on iOS, but it's currently disabled for macOS. 34 | 35 | ## Known Issues 36 | 37 | Performance on iPhone is somewhat erratic, sometimes it's ~20x slower and the phone heats up. 
This happens because the model could not be scheduled to run on the Neural Engine and everything happens in the CPU. We have not been able to determine the reasons for this problem. If you observe the same, here are some recommendations: 38 | - Detach from Xcode 39 | - Kill apps you are not using. 40 | - Let the iPhone cool down before repeating the test. 41 | - Reboot your device. 42 | 43 | ## Next Steps 44 | 45 | - Allow additional models to be downloaded from the Hub. 46 | -------------------------------------------------------------------------------- /RELEASE.md: -------------------------------------------------------------------------------- 1 | # Release Notes 2 | 3 | ### 1.2 — WIP 4 | 5 | * Prompt history — pick from the last 20 prompts and negative prompts 6 | * 7 | 8 | ### 1.1 — 19 Dec 2022 9 | 10 | * Custom file naming on save 11 | * Changed initial placeholder image 12 | * Minor code changes 13 | 14 | ### 1.0 — 18 Dec 2022 15 | 16 | * Support for negative prompts. 17 | * Ability to switch between different models — tested on CoreML Stable Diffusion 1.5 and 2.0-base. 18 | * Ability to switch between different schedulers — currently supports PNDM and DPMPP. 19 | * Supports different generation configurations — guidance scale and number of inference steps. 20 | * Ability to generate multiple images in one run. 21 | * Display history of generated images and allows you to browse generated images via a gallery. 22 | * Show random seed used for each image so that you can use the seed to re-generate the same image. 23 | * Saves your current settings including prompt and negative prompt so that it is restored the next time around. 
-------------------------------------------------------------------------------- /assets/screenshot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/assets/screenshot.jpg -------------------------------------------------------------------------------- /config/common.xcconfig: -------------------------------------------------------------------------------- 1 | // 2 | // common.xcconfig 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on 202212. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | // Configuration settings file format documentation can be found at: 10 | // https://help.apple.com/xcode/#/dev745c5c974 11 | 12 | PRODUCT_NAME = Diffusers 13 | CURRENT_PROJECT_VERSION = 1.1.0 14 | MARKETING_VERSION = 1.1 15 | 16 | // Update if you fork this repo 17 | DEVELOPMENT_TEAM = 2EADP68M95 18 | PRODUCT_BUNDLE_IDENTIFIER = com.huggingface.Diffusers 19 | -------------------------------------------------------------------------------- /config/debug.xcconfig: -------------------------------------------------------------------------------- 1 | // 2 | // debug.xcconfig 3 | // Diffusion 4 | // 5 | // Created by Pedro Cuenca on 17/1/23. 6 | // See LICENSE at https://github.com/huggingface/swift-coreml-diffusers/LICENSE 7 | // 8 | 9 | #include "common.xcconfig" 10 | 11 | // Disable code-signing for macOS 12 | CODE_SIGN_IDENTITY[sdk=macos*] = 13 | -------------------------------------------------------------------------------- /screenshot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FahimF/swift-coreml-diffusers/1b41dd48158f59b96f6c597104d47c7e6e6df264/screenshot.jpg --------------------------------------------------------------------------------