├── .gitignore ├── LICENSE ├── README.md ├── demonstration.jpg ├── maple-convert.py ├── maple-diffusion.xcodeproj ├── project.pbxproj └── project.xcworkspace │ ├── contents.xcworkspacedata │ └── xcshareddata │ └── IDEWorkspaceChecks.plist ├── maple-diffusion ├── Assets.xcassets │ ├── AccentColor.colorset │ │ └── Contents.json │ ├── AppIcon.appiconset │ │ └── Contents.json │ └── Contents.json ├── ContentView.swift ├── MapleDiffusion.swift ├── Preview Content │ └── Preview Assets.xcassets │ │ └── Contents.json ├── bins │ ├── .gitkeep │ └── alphas_cumprod.bin ├── maple_diffusion.entitlements └── maple_diffusionApp.swift ├── requirements.txt └── screenshot.jpg /.gitignore: -------------------------------------------------------------------------------- 1 | maple-diffusion/bins/*.txt 2 | maple-diffusion/bins/*.bin 3 | xcuserdata/ 4 | .DS_Store 5 | # dreambooth ckpts don't have alphas 6 | !maple-diffusion/bins/alphas_cumprod.bin 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ollin Boer Bohan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🍁 Maple Diffusion 2 | 3 | Maple Diffusion runs Stable Diffusion models **locally** on macOS / iOS devices, in Swift, using the MPSGraph framework (not Python). 4 | 5 | ![](demonstration.jpg) 6 | 7 | Maple Diffusion should be capable of generating a reasonable image [in a minute or two](https://twitter.com/madebyollin/status/1579213789823893504) on a recent iPhone (I get around ~2.3s / step on an iPhone 13 Pro). 8 | 9 | To attain usable performance without tripping over iOS's 4GB memory limit, Maple Diffusion relies internally on FP16 (NHWC) tensors, operator fusion from MPSGraph, and a truly pitiable degree of swapping models to device storage. 10 | 11 | On macOS, Maple Diffusion uses slightly more memory (~6GB), to reach <1s / step. 12 | 13 | ![](screenshot.jpg) 14 | 15 | # Related Projects 16 | 17 | * **Core ML Stable Diffusion** ([repo](https://github.com/apple/ml-stable-diffusion)) is Apple's recommended way of running Stable Diffusion in Swift, using CoreML instead of MPSGraph. 
CoreML was originally much slower than MPSGraph ([I tried it back in August](https://gist.github.com/madebyollin/86b9596ffa4ab0fa7674a16ca2aeab3d)), but Apple has improved CoreML performance a lot on recent macOS / iOS versions. 18 | * **Native Diffusion** ([repo](https://github.com/mortenjust/native-diffusion/)) is a Swift Package-ified version of this codebase with several improvements (including image-to-image) 19 | * **Waifu Art AI** ([announcement](https://twitter.com/dgspitzer/status/1596652212964712449), [App Store link](https://apps.apple.com/us/app/waifu-art-ai-local-generator/id6444585505)) is an iOS / macOS app for (anime-style) Stable Diffusion based on this codebase 20 | * **Draw Things** ([announcement](https://liuliu.me/eyes/stretch-iphone-to-its-limit-a-2gib-model-that-can-draw-everything-in-your-pocket/), [App Store link](https://apps.apple.com/us/app/draw-things-ai-generation/id6444050820)) is an iOS app for Stable Diffusion (using an independent codebase with similar MPSGraph-based approach) 21 | 22 | # Device Requirements 23 | 24 | Maple Diffusion should run on any Apple Silicon Mac (M1, M2, etc.). Intel Macs should also work now thanks to [this PR](https://github.com/madebyollin/maple-diffusion/pull/14#issuecomment-1282166802). 25 | 26 | Maple Diffusion should run on any iOS device with [sufficient RAM](https://blakespot.com/ios_device_specifications_grid.html) (≥6144MB RAM definitely works; 4096MB [doesn't](https://github.com/madebyollin/maple-diffusion/issues/25)). That means recent iPads should work out of the box, and recent iPhones should work if you can get the `Increase Memory Limit` capability working (to unlock 4GB of app-usable RAM). iPhone 14 variants reportedly didn't work until [iOS 16.1 stable](https://github.com/madebyollin/maple-diffusion/issues/5#issuecomment-1304410263). 27 | 28 | Maple Diffusion currently expects **Xcode 14** and **iOS 16**; other versions may require changing build settings or just not work. iOS 16.1 (beta) was reportedly [broken](https://github.com/madebyollin/maple-diffusion/issues/8) and always generating a gray image, but I think that's fixed 29 | 30 | # Usage 31 | 32 | To build and run Maple Diffusion: 33 | 34 | 1. Download a Stable Diffusion PyTorch model checkpoint ([`sd-v1-4.ckpt`](https://huggingface.co/CompVis/stable-diffusion-v-1-4-original), or some derivation thereof) 35 | 36 | 2. Download this repo 37 | 38 | ```bash 39 | git clone https://github.com/madebyollin/maple-diffusion.git && cd maple-diffusion 40 | ``` 41 | 42 | 3. Setup & install Python with PyTorch, if you haven't already. 43 | 44 | ```bash 45 | # may need to install conda first https://github.com/conda-forge/miniforge#homebrew 46 | conda deactivate 47 | conda remove -n maple-diffusion --all 48 | conda create -n maple-diffusion python=3.10 49 | conda activate maple-diffusion 50 | pip install torch typing_extensions numpy Pillow requests pytorch_lightning 51 | ``` 52 | 53 | 4. Convert the PyTorch model checkpoint into a bunch of fp16 binary blobs. 54 | 55 | ```bash 56 | ./maple-convert.py ~/Downloads/sd-v1-4.ckpt 57 | ``` 58 | 59 | 5. Open the `maple-diffusion` Xcode project. Select the device you want to run on from the `Product > Destination` menu. 60 | 61 | 6. [Manually add](https://github.com/madebyollin/maple-diffusion/issues/5#issuecomment-1279111878) the `Increased Memory Limit` capability to the `maple-diffusion` target (this step might not be needed on iPads, but it's definitely needed on iPhones - the default limit is 3GB). 62 | 63 | 7. 
Build & run the project on your device with the `Product > Run` menu. 64 | -------------------------------------------------------------------------------- /demonstration.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madebyollin/maple-diffusion/6304d68a066b8a3d9a2d5faded29be271ea5a55a/demonstration.jpg -------------------------------------------------------------------------------- /maple-convert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import sys 3 | if len(sys.argv) < 2: raise ValueError(f"Usage: {sys.argv[0]} path_to_ckpt") 4 | 5 | from pathlib import Path 6 | import torch as th 7 | import numpy as np 8 | 9 | ckpt = th.load(sys.argv[1], map_location="cpu") 10 | outpath = Path("maple-diffusion/bins") 11 | outpath.mkdir(exist_ok=True) 12 | 13 | # vocab for clip 14 | vocab_url = "https://openaipublic.blob.core.windows.net/clip/bpe_simple_vocab_16e6.txt" 15 | vocab_dest = outpath / vocab_url.split("/")[-1] 16 | if not vocab_dest.exists(): 17 | print("downloading clip vocab") 18 | import requests 19 | with requests.get(vocab_url, stream=True) as r: 20 | assert r.status_code == 200, f"{vocab_url} failed to download. please copy it to {vocab_dest} manually." 21 | with vocab_dest.open('wb') as vf: 22 | for c in r.iter_content(chunk_size=8192): 23 | vf.write(c) 24 | print("downloaded clip vocab") 25 | 26 | # model weights 27 | for k, v in ckpt["state_dict"].items(): 28 | if "first_stage_model.encoder" in k: continue 29 | if not hasattr(v, "numpy"): continue 30 | v.numpy().astype('float16').tofile(outpath / (k + ".bin")) 31 | print("exporting state_dict", k, end="\r") 32 | print("\nexporting other stuff...") 33 | 34 | # other stuff 35 | th.exp(-th.log(th.tensor([10000])) * th.arange(0, 160) / 160).numpy().tofile(outpath / "temb_coefficients_fp32.bin") 36 | np.triu(np.ones((1,1,77,77), dtype=np.float16) * -65500.0, k=1).astype(np.float16).tofile(outpath / "causal_mask.bin") 37 | np.array([0.14013671875, 0.0711669921875, -0.03271484375, -0.11407470703125, 0.126220703125, 0.10101318359375, 0.034515380859375, -0.1383056640625, 0.126220703125, 0.07733154296875, 0.042633056640625, -0.177978515625]).astype(np.float16).tofile(outpath / "aux_output_conv.weight.bin") 38 | np.array([0.423828125, 0.471923828125, 0.473876953125]).astype(np.float16).tofile(outpath / "aux_output_conv.bias.bin") 39 | print(f"Done!") 40 | -------------------------------------------------------------------------------- /maple-diffusion.xcodeproj/project.pbxproj: -------------------------------------------------------------------------------- 1 | // !$*UTF8*$! 
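// Note: the "bins" folder reference below is added to the Resources build phase, so the fp16 weight
// blobs emitted by maple-convert.py ship inside the app bundle and are read at runtime via
// Bundle.main.url(forResource: "bins/...", withExtension: ".bin").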
2 | { 3 | archiveVersion = 1; 4 | classes = { 5 | }; 6 | objectVersion = 56; 7 | objects = { 8 | 9 | /* Begin PBXBuildFile section */ 10 | 97B171AF28F23B7700B97242 /* maple_diffusionApp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 97B171AE28F23B7700B97242 /* maple_diffusionApp.swift */; }; 11 | 97B171B128F23B7700B97242 /* ContentView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 97B171B028F23B7700B97242 /* ContentView.swift */; }; 12 | 97B171B328F23B7800B97242 /* Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 97B171B228F23B7800B97242 /* Assets.xcassets */; }; 13 | 97B171B728F23B7800B97242 /* Preview Assets.xcassets in Resources */ = {isa = PBXBuildFile; fileRef = 97B171B628F23B7800B97242 /* Preview Assets.xcassets */; }; 14 | 97B171BE28F23BBC00B97242 /* MapleDiffusion.swift in Sources */ = {isa = PBXBuildFile; fileRef = 97B171BD28F23BBC00B97242 /* MapleDiffusion.swift */; }; 15 | 97B171C028F23C7D00B97242 /* bins in Resources */ = {isa = PBXBuildFile; fileRef = 97B171BF28F23C7D00B97242 /* bins */; }; 16 | /* End PBXBuildFile section */ 17 | 18 | /* Begin PBXFileReference section */ 19 | 97B171AB28F23B7700B97242 /* maple-diffusion.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = "maple-diffusion.app"; sourceTree = BUILT_PRODUCTS_DIR; }; 20 | 97B171AE28F23B7700B97242 /* maple_diffusionApp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = maple_diffusionApp.swift; sourceTree = ""; }; 21 | 97B171B028F23B7700B97242 /* ContentView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ContentView.swift; sourceTree = ""; }; 22 | 97B171B228F23B7800B97242 /* Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = Assets.xcassets; sourceTree = ""; }; 23 | 97B171B428F23B7800B97242 /* maple_diffusion.entitlements */ = {isa = PBXFileReference; lastKnownFileType = text.plist.entitlements; path = maple_diffusion.entitlements; sourceTree = ""; }; 24 | 97B171B628F23B7800B97242 /* Preview Assets.xcassets */ = {isa = PBXFileReference; lastKnownFileType = folder.assetcatalog; path = "Preview Assets.xcassets"; sourceTree = ""; }; 25 | 97B171BD28F23BBC00B97242 /* MapleDiffusion.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = MapleDiffusion.swift; sourceTree = ""; }; 26 | 97B171BF28F23C7D00B97242 /* bins */ = {isa = PBXFileReference; lastKnownFileType = folder; path = bins; sourceTree = ""; }; 27 | /* End PBXFileReference section */ 28 | 29 | /* Begin PBXFrameworksBuildPhase section */ 30 | 97B171A828F23B7700B97242 /* Frameworks */ = { 31 | isa = PBXFrameworksBuildPhase; 32 | buildActionMask = 2147483647; 33 | files = ( 34 | ); 35 | runOnlyForDeploymentPostprocessing = 0; 36 | }; 37 | /* End PBXFrameworksBuildPhase section */ 38 | 39 | /* Begin PBXGroup section */ 40 | 97B171A228F23B7700B97242 = { 41 | isa = PBXGroup; 42 | children = ( 43 | 97B171AD28F23B7700B97242 /* maple-diffusion */, 44 | 97B171AC28F23B7700B97242 /* Products */, 45 | ); 46 | sourceTree = ""; 47 | }; 48 | 97B171AC28F23B7700B97242 /* Products */ = { 49 | isa = PBXGroup; 50 | children = ( 51 | 97B171AB28F23B7700B97242 /* maple-diffusion.app */, 52 | ); 53 | name = Products; 54 | sourceTree = ""; 55 | }; 56 | 97B171AD28F23B7700B97242 /* maple-diffusion */ = { 57 | isa = PBXGroup; 58 | children = ( 59 | 97B171AE28F23B7700B97242 /* maple_diffusionApp.swift */, 60 | 97B171B028F23B7700B97242 /* ContentView.swift */, 61 | 
97B171BD28F23BBC00B97242 /* MapleDiffusion.swift */, 62 | 97B171BF28F23C7D00B97242 /* bins */, 63 | 97B171B228F23B7800B97242 /* Assets.xcassets */, 64 | 97B171B428F23B7800B97242 /* maple_diffusion.entitlements */, 65 | 97B171B528F23B7800B97242 /* Preview Content */, 66 | ); 67 | path = "maple-diffusion"; 68 | sourceTree = ""; 69 | }; 70 | 97B171B528F23B7800B97242 /* Preview Content */ = { 71 | isa = PBXGroup; 72 | children = ( 73 | 97B171B628F23B7800B97242 /* Preview Assets.xcassets */, 74 | ); 75 | path = "Preview Content"; 76 | sourceTree = ""; 77 | }; 78 | /* End PBXGroup section */ 79 | 80 | /* Begin PBXNativeTarget section */ 81 | 97B171AA28F23B7700B97242 /* maple-diffusion */ = { 82 | isa = PBXNativeTarget; 83 | buildConfigurationList = 97B171BA28F23B7800B97242 /* Build configuration list for PBXNativeTarget "maple-diffusion" */; 84 | buildPhases = ( 85 | 97B171A728F23B7700B97242 /* Sources */, 86 | 97B171A828F23B7700B97242 /* Frameworks */, 87 | 97B171A928F23B7700B97242 /* Resources */, 88 | ); 89 | buildRules = ( 90 | ); 91 | dependencies = ( 92 | ); 93 | name = "maple-diffusion"; 94 | productName = "maple-diffusion"; 95 | productReference = 97B171AB28F23B7700B97242 /* maple-diffusion.app */; 96 | productType = "com.apple.product-type.application"; 97 | }; 98 | /* End PBXNativeTarget section */ 99 | 100 | /* Begin PBXProject section */ 101 | 97B171A328F23B7700B97242 /* Project object */ = { 102 | isa = PBXProject; 103 | attributes = { 104 | BuildIndependentTargetsInParallel = 1; 105 | LastSwiftUpdateCheck = 1400; 106 | LastUpgradeCheck = 1400; 107 | TargetAttributes = { 108 | 97B171AA28F23B7700B97242 = { 109 | CreatedOnToolsVersion = 14.0.1; 110 | }; 111 | }; 112 | }; 113 | buildConfigurationList = 97B171A628F23B7700B97242 /* Build configuration list for PBXProject "maple-diffusion" */; 114 | compatibilityVersion = "Xcode 14.0"; 115 | developmentRegion = en; 116 | hasScannedForEncodings = 0; 117 | knownRegions = ( 118 | en, 119 | Base, 120 | ); 121 | mainGroup = 97B171A228F23B7700B97242; 122 | productRefGroup = 97B171AC28F23B7700B97242 /* Products */; 123 | projectDirPath = ""; 124 | projectRoot = ""; 125 | targets = ( 126 | 97B171AA28F23B7700B97242 /* maple-diffusion */, 127 | ); 128 | }; 129 | /* End PBXProject section */ 130 | 131 | /* Begin PBXResourcesBuildPhase section */ 132 | 97B171A928F23B7700B97242 /* Resources */ = { 133 | isa = PBXResourcesBuildPhase; 134 | buildActionMask = 2147483647; 135 | files = ( 136 | 97B171C028F23C7D00B97242 /* bins in Resources */, 137 | 97B171B728F23B7800B97242 /* Preview Assets.xcassets in Resources */, 138 | 97B171B328F23B7800B97242 /* Assets.xcassets in Resources */, 139 | ); 140 | runOnlyForDeploymentPostprocessing = 0; 141 | }; 142 | /* End PBXResourcesBuildPhase section */ 143 | 144 | /* Begin PBXSourcesBuildPhase section */ 145 | 97B171A728F23B7700B97242 /* Sources */ = { 146 | isa = PBXSourcesBuildPhase; 147 | buildActionMask = 2147483647; 148 | files = ( 149 | 97B171B128F23B7700B97242 /* ContentView.swift in Sources */, 150 | 97B171BE28F23BBC00B97242 /* MapleDiffusion.swift in Sources */, 151 | 97B171AF28F23B7700B97242 /* maple_diffusionApp.swift in Sources */, 152 | ); 153 | runOnlyForDeploymentPostprocessing = 0; 154 | }; 155 | /* End PBXSourcesBuildPhase section */ 156 | 157 | /* Begin XCBuildConfiguration section */ 158 | 97B171B828F23B7800B97242 /* Debug */ = { 159 | isa = XCBuildConfiguration; 160 | buildSettings = { 161 | ALWAYS_SEARCH_USER_PATHS = NO; 162 | CLANG_ANALYZER_NONNULL = YES; 163 | 
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; 164 | CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; 165 | CLANG_ENABLE_MODULES = YES; 166 | CLANG_ENABLE_OBJC_ARC = YES; 167 | CLANG_ENABLE_OBJC_WEAK = YES; 168 | CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; 169 | CLANG_WARN_BOOL_CONVERSION = YES; 170 | CLANG_WARN_COMMA = YES; 171 | CLANG_WARN_CONSTANT_CONVERSION = YES; 172 | CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; 173 | CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; 174 | CLANG_WARN_DOCUMENTATION_COMMENTS = YES; 175 | CLANG_WARN_EMPTY_BODY = YES; 176 | CLANG_WARN_ENUM_CONVERSION = YES; 177 | CLANG_WARN_INFINITE_RECURSION = YES; 178 | CLANG_WARN_INT_CONVERSION = YES; 179 | CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; 180 | CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; 181 | CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; 182 | CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; 183 | CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; 184 | CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; 185 | CLANG_WARN_STRICT_PROTOTYPES = YES; 186 | CLANG_WARN_SUSPICIOUS_MOVE = YES; 187 | CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; 188 | CLANG_WARN_UNREACHABLE_CODE = YES; 189 | CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; 190 | COPY_PHASE_STRIP = NO; 191 | DEBUG_INFORMATION_FORMAT = dwarf; 192 | ENABLE_STRICT_OBJC_MSGSEND = YES; 193 | ENABLE_TESTABILITY = YES; 194 | GCC_C_LANGUAGE_STANDARD = gnu11; 195 | GCC_DYNAMIC_NO_PIC = NO; 196 | GCC_NO_COMMON_BLOCKS = YES; 197 | GCC_OPTIMIZATION_LEVEL = 0; 198 | GCC_PREPROCESSOR_DEFINITIONS = ( 199 | "DEBUG=1", 200 | "$(inherited)", 201 | ); 202 | GCC_WARN_64_TO_32_BIT_CONVERSION = YES; 203 | GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; 204 | GCC_WARN_UNDECLARED_SELECTOR = YES; 205 | GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; 206 | GCC_WARN_UNUSED_FUNCTION = YES; 207 | GCC_WARN_UNUSED_VARIABLE = YES; 208 | MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE; 209 | MTL_FAST_MATH = YES; 210 | ONLY_ACTIVE_ARCH = YES; 211 | SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG; 212 | SWIFT_OPTIMIZATION_LEVEL = "-Onone"; 213 | }; 214 | name = Debug; 215 | }; 216 | 97B171B928F23B7800B97242 /* Release */ = { 217 | isa = XCBuildConfiguration; 218 | buildSettings = { 219 | ALWAYS_SEARCH_USER_PATHS = NO; 220 | CLANG_ANALYZER_NONNULL = YES; 221 | CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE; 222 | CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; 223 | CLANG_ENABLE_MODULES = YES; 224 | CLANG_ENABLE_OBJC_ARC = YES; 225 | CLANG_ENABLE_OBJC_WEAK = YES; 226 | CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES; 227 | CLANG_WARN_BOOL_CONVERSION = YES; 228 | CLANG_WARN_COMMA = YES; 229 | CLANG_WARN_CONSTANT_CONVERSION = YES; 230 | CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES; 231 | CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR; 232 | CLANG_WARN_DOCUMENTATION_COMMENTS = YES; 233 | CLANG_WARN_EMPTY_BODY = YES; 234 | CLANG_WARN_ENUM_CONVERSION = YES; 235 | CLANG_WARN_INFINITE_RECURSION = YES; 236 | CLANG_WARN_INT_CONVERSION = YES; 237 | CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES; 238 | CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES; 239 | CLANG_WARN_OBJC_LITERAL_CONVERSION = YES; 240 | CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR; 241 | CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; 242 | CLANG_WARN_RANGE_LOOP_ANALYSIS = YES; 243 | CLANG_WARN_STRICT_PROTOTYPES = YES; 244 | CLANG_WARN_SUSPICIOUS_MOVE = YES; 245 | CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE; 246 | CLANG_WARN_UNREACHABLE_CODE = YES; 247 | CLANG_WARN__DUPLICATE_METHOD_MATCH = YES; 248 | COPY_PHASE_STRIP = NO; 249 | DEBUG_INFORMATION_FORMAT = 
"dwarf-with-dsym"; 250 | ENABLE_NS_ASSERTIONS = NO; 251 | ENABLE_STRICT_OBJC_MSGSEND = YES; 252 | GCC_C_LANGUAGE_STANDARD = gnu11; 253 | GCC_NO_COMMON_BLOCKS = YES; 254 | GCC_WARN_64_TO_32_BIT_CONVERSION = YES; 255 | GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR; 256 | GCC_WARN_UNDECLARED_SELECTOR = YES; 257 | GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE; 258 | GCC_WARN_UNUSED_FUNCTION = YES; 259 | GCC_WARN_UNUSED_VARIABLE = YES; 260 | MTL_ENABLE_DEBUG_INFO = NO; 261 | MTL_FAST_MATH = YES; 262 | SWIFT_COMPILATION_MODE = wholemodule; 263 | SWIFT_OPTIMIZATION_LEVEL = "-O"; 264 | }; 265 | name = Release; 266 | }; 267 | 97B171BB28F23B7800B97242 /* Debug */ = { 268 | isa = XCBuildConfiguration; 269 | buildSettings = { 270 | ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; 271 | ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; 272 | CODE_SIGN_ENTITLEMENTS = "maple-diffusion/maple_diffusion.entitlements"; 273 | "CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development"; 274 | CODE_SIGN_STYLE = Automatic; 275 | CURRENT_PROJECT_VERSION = 1; 276 | DEVELOPMENT_ASSET_PATHS = "\"maple-diffusion/Preview Content\""; 277 | DEVELOPMENT_TEAM = MGZW3M7DL4; 278 | ENABLE_PREVIEWS = YES; 279 | GENERATE_INFOPLIST_FILE = YES; 280 | INFOPLIST_KEY_CFBundleDisplayName = "Maple Diffusion"; 281 | "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES; 282 | "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES; 283 | "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES; 284 | "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES; 285 | "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES; 286 | "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES; 287 | "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault; 288 | "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault; 289 | INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; 290 | INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; 291 | IPHONEOS_DEPLOYMENT_TARGET = 16.0; 292 | LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; 293 | "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; 294 | MACOSX_DEPLOYMENT_TARGET = 12.3; 295 | MARKETING_VERSION = 1.0; 296 | PRODUCT_BUNDLE_IDENTIFIER = "com.madebyollin.maple-diffusion"; 297 | PRODUCT_NAME = "$(TARGET_NAME)"; 298 | SDKROOT = auto; 299 | SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; 300 | SWIFT_EMIT_LOC_STRINGS = YES; 301 | SWIFT_VERSION = 5.0; 302 | TARGETED_DEVICE_FAMILY = "1,2"; 303 | }; 304 | name = Debug; 305 | }; 306 | 97B171BC28F23B7800B97242 /* Release */ = { 307 | isa = XCBuildConfiguration; 308 | buildSettings = { 309 | ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; 310 | ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor; 311 | CODE_SIGN_ENTITLEMENTS = "maple-diffusion/maple_diffusion.entitlements"; 312 | "CODE_SIGN_IDENTITY[sdk=macosx*]" = "Apple Development"; 313 | CODE_SIGN_STYLE = Automatic; 314 | CURRENT_PROJECT_VERSION = 1; 315 | DEVELOPMENT_ASSET_PATHS = "\"maple-diffusion/Preview Content\""; 316 | DEVELOPMENT_TEAM = MGZW3M7DL4; 317 | ENABLE_PREVIEWS = YES; 318 | GENERATE_INFOPLIST_FILE = YES; 319 | 
INFOPLIST_KEY_CFBundleDisplayName = "Maple Diffusion"; 320 | "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES; 321 | "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES; 322 | "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES; 323 | "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES; 324 | "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES; 325 | "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES; 326 | "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault; 327 | "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault; 328 | INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; 329 | INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight"; 330 | IPHONEOS_DEPLOYMENT_TARGET = 16.0; 331 | LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks"; 332 | "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks"; 333 | MACOSX_DEPLOYMENT_TARGET = 12.3; 334 | MARKETING_VERSION = 1.0; 335 | PRODUCT_BUNDLE_IDENTIFIER = "com.madebyollin.maple-diffusion"; 336 | PRODUCT_NAME = "$(TARGET_NAME)"; 337 | SDKROOT = auto; 338 | SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx"; 339 | SWIFT_EMIT_LOC_STRINGS = YES; 340 | SWIFT_VERSION = 5.0; 341 | TARGETED_DEVICE_FAMILY = "1,2"; 342 | }; 343 | name = Release; 344 | }; 345 | /* End XCBuildConfiguration section */ 346 | 347 | /* Begin XCConfigurationList section */ 348 | 97B171A628F23B7700B97242 /* Build configuration list for PBXProject "maple-diffusion" */ = { 349 | isa = XCConfigurationList; 350 | buildConfigurations = ( 351 | 97B171B828F23B7800B97242 /* Debug */, 352 | 97B171B928F23B7800B97242 /* Release */, 353 | ); 354 | defaultConfigurationIsVisible = 0; 355 | defaultConfigurationName = Release; 356 | }; 357 | 97B171BA28F23B7800B97242 /* Build configuration list for PBXNativeTarget "maple-diffusion" */ = { 358 | isa = XCConfigurationList; 359 | buildConfigurations = ( 360 | 97B171BB28F23B7800B97242 /* Debug */, 361 | 97B171BC28F23B7800B97242 /* Release */, 362 | ); 363 | defaultConfigurationIsVisible = 0; 364 | defaultConfigurationName = Release; 365 | }; 366 | /* End XCConfigurationList section */ 367 | }; 368 | rootObject = 97B171A328F23B7700B97242 /* Project object */; 369 | } 370 | -------------------------------------------------------------------------------- /maple-diffusion.xcodeproj/project.xcworkspace/contents.xcworkspacedata: -------------------------------------------------------------------------------- 1 | 2 | 4 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /maple-diffusion.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | IDEDidComputeMac32BitWarning 6 | 7 | 8 | 9 | -------------------------------------------------------------------------------- /maple-diffusion/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "idiom" : "universal" 5 | } 6 | ], 7 | "info" : { 8 | "author" : "xcode", 9 | "version" : 1 
10 | } 11 | } 12 | -------------------------------------------------------------------------------- /maple-diffusion/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "idiom" : "universal", 5 | "platform" : "ios", 6 | "size" : "1024x1024" 7 | }, 8 | { 9 | "idiom" : "mac", 10 | "scale" : "1x", 11 | "size" : "16x16" 12 | }, 13 | { 14 | "idiom" : "mac", 15 | "scale" : "2x", 16 | "size" : "16x16" 17 | }, 18 | { 19 | "idiom" : "mac", 20 | "scale" : "1x", 21 | "size" : "32x32" 22 | }, 23 | { 24 | "idiom" : "mac", 25 | "scale" : "2x", 26 | "size" : "32x32" 27 | }, 28 | { 29 | "idiom" : "mac", 30 | "scale" : "1x", 31 | "size" : "128x128" 32 | }, 33 | { 34 | "idiom" : "mac", 35 | "scale" : "2x", 36 | "size" : "128x128" 37 | }, 38 | { 39 | "idiom" : "mac", 40 | "scale" : "1x", 41 | "size" : "256x256" 42 | }, 43 | { 44 | "idiom" : "mac", 45 | "scale" : "2x", 46 | "size" : "256x256" 47 | }, 48 | { 49 | "idiom" : "mac", 50 | "scale" : "1x", 51 | "size" : "512x512" 52 | }, 53 | { 54 | "idiom" : "mac", 55 | "scale" : "2x", 56 | "size" : "512x512" 57 | } 58 | ], 59 | "info" : { 60 | "author" : "xcode", 61 | "version" : 1 62 | } 63 | } 64 | -------------------------------------------------------------------------------- /maple-diffusion/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /maple-diffusion/ContentView.swift: -------------------------------------------------------------------------------- 1 | import SwiftUI 2 | 3 | struct ContentView: View { 4 | #if os(iOS) 5 | let mapleDiffusion = MapleDiffusion(saveMemoryButBeSlower: true) 6 | #else 7 | let mapleDiffusion = MapleDiffusion(saveMemoryButBeSlower: false) 8 | #endif 9 | let dispatchQueue = DispatchQueue(label: "Generation") 10 | @State var steps: Float = 20 11 | @State var image: Image? 12 | @State var prompt: String = "" 13 | @State var negativePrompt: String = "" 14 | @State var guidanceScale: Float = 7.5 15 | @State var running: Bool = false 16 | @State var progressProp: Float = 1 17 | @State var progressStage: String = "Ready" 18 | 19 | func loadModels() { 20 | dispatchQueue.async { 21 | running = true 22 | mapleDiffusion.initModels() { (p, s) -> () in 23 | progressProp = p 24 | progressStage = s 25 | } 26 | running = false 27 | } 28 | } 29 | 30 | func generate() { 31 | dispatchQueue.async { 32 | running = true 33 | progressStage = "" 34 | progressProp = 0 35 | mapleDiffusion.generate(prompt: prompt, negativePrompt: negativePrompt, seed: Int.random(in: 1.. () in 36 | if (cgim != nil) { 37 | image = Image(cgim!, scale: 1.0, label: Text("Generated image")) 38 | } 39 | progressProp = p 40 | progressStage = s 41 | } 42 | running = false 43 | } 44 | } 45 | var body: some View { 46 | VStack { 47 | #if os(iOS) 48 | Text("🍁 Maple Diffusion").foregroundColor(.orange).bold().frame(alignment: Alignment.center) 49 | #endif 50 | if (image == nil) { 51 | Rectangle().fill(.gray).aspectRatio(1.0, contentMode: .fit).frame(idealWidth: mapleDiffusion.width as? CGFloat, idealHeight: mapleDiffusion.height as? CGFloat) 52 | } else { 53 | #if os(iOS) 54 | ShareLink(item: image!, preview: SharePreview(prompt, image: image!)) { 55 | image!.resizable().aspectRatio(contentMode: .fit).frame(idealWidth: mapleDiffusion.width as? 
CGFloat, idealHeight: mapleDiffusion.height as? CGFloat) 56 | } 57 | #else 58 | image!.resizable().aspectRatio(contentMode: .fit).frame(idealWidth: mapleDiffusion.width as? CGFloat, idealHeight: mapleDiffusion.height as? CGFloat) 59 | #endif 60 | } 61 | HStack { 62 | Text("Prompt").bold() 63 | TextField("What you want", text: $prompt) 64 | } 65 | HStack { 66 | Text("Negative Prompt").bold() 67 | TextField("What you don't want", text: $negativePrompt) 68 | } 69 | HStack { 70 | HStack { 71 | Text("Scale").bold() 72 | Text(String(format: "%.1f", guidanceScale)).foregroundColor(.secondary) 73 | }.frame(width: 96, alignment: .leading) 74 | Slider(value: $guidanceScale, in: 1...20) 75 | } 76 | HStack { 77 | HStack { 78 | Text("Steps").bold() 79 | Text("\(Int(steps))").foregroundColor(.secondary) 80 | }.frame(width: 96, alignment: .leading) 81 | Slider(value: $steps, in: 5...150) 82 | } 83 | ProgressView(progressStage, value: progressProp, total: 1).opacity(running ? 1 : 0).foregroundColor(.secondary) 84 | Spacer(minLength: 8) 85 | Button(action: generate) { 86 | Text("Generate Image") 87 | .frame(minWidth: 100, maxWidth: .infinity, minHeight: 64, alignment: .center) 88 | .background(running ? .gray : .blue) 89 | .foregroundColor(.white) 90 | .font(Font.title) 91 | .cornerRadius(32) 92 | }.buttonStyle(.borderless).disabled(running) 93 | }.padding(16).onAppear(perform: loadModels) 94 | } 95 | } 96 | 97 | struct ContentView_Previews: PreviewProvider { 98 | static var previews: some View { 99 | ContentView() 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /maple-diffusion/MapleDiffusion.swift: -------------------------------------------------------------------------------- 1 | import MetalPerformanceShadersGraph 2 | import Foundation 3 | 4 | // Maple Diffusion implements stable diffusion (original v1.4 model) 5 | // inference via MPSGraph. iOS has a hard memory limit of 4GB (with 6 | // a special entitlement), so this implementation trades off latency 7 | // for memory usage in many places (tagged with MEM-HACK) in order to 8 | // stay under the limit and minimize probability of oom. 9 | 10 | func makeGraph(synchonize: Bool) -> MPSGraph { 11 | let graph = MPSGraph() 12 | graph.options = synchonize ? MPSGraphOptions.synchronizeResults : .none 13 | return graph 14 | } 15 | 16 | func loadConstant(graph: MPSGraph, name: String, shape: [NSNumber], fp32: Bool = false) -> MPSGraphTensor { 17 | let numels = shape.map({$0.intValue}).reduce(1, *) 18 | let fileUrl: URL = Bundle.main.url(forResource: "bins/" + name + (fp32 ? "_fp32" : ""), withExtension: ".bin")! 19 | let data: Data = try! Data(contentsOf: fileUrl, options: Data.ReadingOptions.alwaysMapped) 20 | let expectedCount = numels * (fp32 ? 4 : 2) 21 | assert(data.count == expectedCount, "Mismatch between byte count of data \(data.count) and expected size \(expectedCount) for \(numels) els in \(fileUrl)") 22 | return graph.constant(data, shape: shape, dataType: fp32 ? 
MPSDataType.float32 : MPSDataType.float16) 23 | } 24 | 25 | func makeConv(graph: MPSGraph, xIn: MPSGraphTensor, name: String, outChannels: NSNumber, khw: NSNumber, stride: Int = 1, bias: Bool = true) -> MPSGraphTensor { 26 | let w = loadConstant(graph: graph, name: name + ".weight", shape: [outChannels, xIn.shape![3], khw, khw]) 27 | let p: Int = khw.intValue / 2; 28 | let convDesc = MPSGraphConvolution2DOpDescriptor(strideInX: stride, strideInY: stride, dilationRateInX: 1, dilationRateInY: 1, groups: 1, paddingLeft: p, paddingRight: p, paddingTop: p, paddingBottom: p, paddingStyle: MPSGraphPaddingStyle.explicit, dataLayout: MPSGraphTensorNamedDataLayout.NHWC, weightsLayout: MPSGraphTensorNamedDataLayout.OIHW)! 29 | let conv = graph.convolution2D(xIn, weights: w, descriptor: convDesc, name: nil) 30 | if (bias) { 31 | let b = loadConstant(graph: graph, name: name + ".bias", shape: [1, 1, 1, outChannels]) 32 | return graph.addition(conv, b, name: nil) 33 | } 34 | return conv 35 | } 36 | 37 | func makeUpsampleNearest(graph: MPSGraph, xIn: MPSGraphTensor, scaleFactor: Int=2) -> MPSGraphTensor { 38 | return graph.resize(xIn, size: [NSNumber(value:xIn.shape![1].intValue * scaleFactor), NSNumber(value:xIn.shape![2].intValue * scaleFactor)], mode: MPSGraphResizeMode.nearest, centerResult: true, alignCorners: false, layout: MPSGraphTensorNamedDataLayout.NHWC, name: nil) 39 | } 40 | 41 | func makeGroupNorm(graph: MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 42 | var x = xIn 43 | if (xIn.shape!.count == 3) { 44 | x = graph.expandDims(x, axes: [1], name: nil) 45 | } 46 | let shape = x.shape! 47 | let nGroups: NSNumber = 32 48 | let nGrouped: NSNumber = shape[3].floatValue / nGroups.floatValue as NSNumber 49 | let gamma = loadConstant(graph: graph, name: name + ".weight", shape: [1, 1, 1, nGroups, nGrouped]) 50 | let beta = loadConstant(graph: graph, name: name + ".bias", shape: [1, 1, 1, nGroups, nGrouped]) 51 | x = graph.reshape(x, shape: [shape[0], shape[1], shape[2], nGroups, nGrouped], name: nil) 52 | let mean = graph.mean(of: x, axes: [1, 2, 4], name: nil) 53 | let variance = graph.variance(of: x, axes: [1, 2, 4], name: nil) 54 | x = graph.normalize(x, mean: mean, variance: variance, gamma: gamma, beta: beta, epsilon: 1e-5, name: nil) 55 | return graph.reshape(x, shape: xIn.shape!, name: nil) 56 | } 57 | 58 | func makeSwish(graph: MPSGraph, xIn: MPSGraphTensor) -> MPSGraphTensor { 59 | return graph.multiplication(xIn, graph.sigmoid(with: xIn, name: nil), name: nil) 60 | } 61 | 62 | func makeGroupNormSwish(graph: MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 63 | return makeSwish(graph: graph, xIn: makeGroupNorm(graph: graph, xIn: xIn, name: name)) 64 | } 65 | 66 | func makeDecoderResBlock(graph: MPSGraph, xIn: MPSGraphTensor, name: String, outChannels: NSNumber) -> MPSGraphTensor { 67 | var x = xIn 68 | x = makeGroupNormSwish(graph: graph, xIn: x, name: name + ".norm1") 69 | x = makeConv(graph: graph, xIn: x, name: name + ".conv1", outChannels: outChannels, khw: 3) 70 | x = makeGroupNormSwish(graph: graph, xIn: x, name: name + ".norm2") 71 | x = makeConv(graph: graph, xIn: x, name: name + ".conv2", outChannels: outChannels, khw: 3) 72 | if (xIn.shape![3] != outChannels) { 73 | let ninShortcut = makeConv(graph: graph, xIn: xIn, name: name + ".nin_shortcut", outChannels: outChannels, khw: 1) 74 | return graph.addition(x, ninShortcut, name: "skip") 75 | } 76 | return graph.addition(x, xIn, name: "skip") 77 | } 78 | 79 | func makeDecoderAttention(graph: 
MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 80 | var x = makeGroupNorm(graph: graph, xIn: xIn, name: name + ".norm") 81 | let c = x.shape![3] 82 | x = graph.reshape(x, shape: [x.shape![0], NSNumber(value:x.shape![1].intValue * x.shape![2].intValue), c], name: nil) 83 | let q = makeLinear(graph: graph, xIn: x, name: name + ".q", outChannels: c, bias: false) 84 | var k = makeLinear(graph: graph, xIn: x, name: name + ".k", outChannels: c, bias: false) 85 | k = graph.multiplication(k, graph.constant(1.0 / sqrt(c.doubleValue), dataType: MPSDataType.float16), name: nil) 86 | k = graph.transposeTensor(k, dimension: 1, withDimension: 2, name: nil) 87 | let v = makeLinear(graph: graph, xIn: x, name: name + ".v", outChannels: c, bias: false) 88 | var att = graph.matrixMultiplication(primary: q, secondary: k, name: nil) 89 | att = graph.softMax(with: att, axis: 2, name: nil) 90 | att = graph.matrixMultiplication(primary: att, secondary: v, name: nil) 91 | x = makeLinear(graph: graph, xIn: att, name: name + ".proj_out", outChannels: c) 92 | x = graph.reshape(x, shape: xIn.shape!, name: nil) 93 | return graph.addition(x, xIn, name: nil) 94 | } 95 | 96 | func makeByteConverter(graph: MPSGraph, xIn: MPSGraphTensor) -> MPSGraphTensor { 97 | var x = xIn 98 | x = graph.clamp(x, min: graph.constant(0, shape: [1], dataType: MPSDataType.float16), max: graph.constant(1.0, shape: [1], dataType: MPSDataType.float16), name: nil) 99 | x = graph.multiplication(x, graph.constant(255, shape: [1], dataType: MPSDataType.float16), name: nil) 100 | x = graph.round(with: x, name: nil) 101 | x = graph.cast(x, to: MPSDataType.uInt8, name: "cast to uint8 rgba") 102 | let alpha = graph.constant(255, shape: [1, x.shape![1], x.shape![2], 1], dataType: MPSDataType.uInt8) 103 | return graph.concatTensors([x, alpha], dimension: 3, name: nil) 104 | } 105 | 106 | func makeDecoder(graph: MPSGraph, xIn: MPSGraphTensor) -> MPSGraphTensor { 107 | var x = xIn 108 | let name = "first_stage_model.decoder" 109 | x = graph.multiplication(x, graph.constant(1 / 0.18215, dataType: MPSDataType.float16), name: "rescale") 110 | x = makeConv(graph: graph, xIn: x, name: "first_stage_model.post_quant_conv", outChannels: 4, khw: 1) 111 | x = makeConv(graph: graph, xIn: x, name: name + ".conv_in", outChannels: 512, khw: 3) 112 | 113 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".mid.block_1", outChannels: 512) 114 | x = makeDecoderAttention(graph: graph, xIn: x, name: name + ".mid.attn_1") 115 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".mid.block_2", outChannels: 512) 116 | 117 | // block 3 118 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.3.block.0", outChannels: 512) 119 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.3.block.1", outChannels: 512) 120 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.3.block.2", outChannels: 512) 121 | x = makeUpsampleNearest(graph: graph, xIn: x) 122 | x = makeConv(graph: graph, xIn: x, name: name + ".up.3.upsample.conv", outChannels: 512, khw: 3) 123 | 124 | // block 2 125 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.2.block.0", outChannels: 512) 126 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.2.block.1", outChannels: 512) 127 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.2.block.2", outChannels: 512) 128 | x = makeUpsampleNearest(graph: graph, xIn: x) 129 | x = makeConv(graph: graph, xIn: x, name: name + ".up.2.upsample.conv", outChannels: 512, 
khw: 3) 130 | 131 | // block 1 132 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.1.block.0", outChannels: 256) 133 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.1.block.1", outChannels: 256) 134 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.1.block.2", outChannels: 256) 135 | x = makeUpsampleNearest(graph: graph, xIn: x) 136 | x = makeConv(graph: graph, xIn: x, name: name + ".up.1.upsample.conv", outChannels: 256, khw: 3) 137 | // block 0 138 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.0.block.0", outChannels: 128) 139 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.0.block.1", outChannels: 128) 140 | x = makeDecoderResBlock(graph: graph, xIn: x, name: name + ".up.0.block.2", outChannels: 128) 141 | 142 | x = makeGroupNormSwish(graph: graph, xIn: x, name: name + ".norm_out") 143 | x = makeConv(graph: graph, xIn: x, name: name + ".conv_out", outChannels: 3, khw: 3) 144 | x = graph.addition(x, graph.constant(1.0, dataType: MPSDataType.float16), name: nil) 145 | x = graph.multiplication(x, graph.constant(0.5, dataType: MPSDataType.float16), name: nil) 146 | return makeByteConverter(graph: graph, xIn: x) 147 | } 148 | 149 | func makeLayerNorm(graph: MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 150 | assert(xIn.shape!.count == 3, "layernorm requires NTC") 151 | let gamma = loadConstant(graph: graph, name: name + ".weight", shape: [1, 1, xIn.shape![2]]) 152 | let beta = loadConstant(graph: graph, name: name + ".bias", shape: [1, 1, xIn.shape![2]]) 153 | let mean = graph.mean(of: xIn, axes: [2], name: nil) 154 | let variance = graph.variance(of: xIn, axes: [2], name: nil) 155 | let x = graph.normalize(xIn, mean: mean, variance: variance, gamma: gamma, beta: beta, epsilon: 1e-5, name: nil) 156 | return graph.reshape(x, shape: xIn.shape!, name: nil) 157 | } 158 | 159 | func makeLinear(graph: MPSGraph, xIn: MPSGraphTensor, name: String, outChannels: NSNumber, bias: Bool = true) -> MPSGraphTensor { 160 | if (xIn.shape!.count == 2) { 161 | var x = graph.reshape(xIn, shape: [xIn.shape![0], 1, 1, xIn.shape![1]], name: nil) 162 | x = makeConv(graph: graph, xIn: x, name: name, outChannels: outChannels, khw: 1, bias: bias) 163 | return graph.reshape(x, shape: [xIn.shape![0], outChannels], name: nil) 164 | } 165 | var x = graph.reshape(xIn, shape: [xIn.shape![0], 1, xIn.shape![1], xIn.shape![2]], name: nil) 166 | x = makeConv(graph: graph, xIn: x, name: name, outChannels: outChannels, khw: 1, bias: bias) 167 | return graph.reshape(x, shape: [xIn.shape![0], xIn.shape![1], outChannels], name: nil) 168 | } 169 | 170 | func makeTimeEmbed(graph: MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 171 | var x = xIn 172 | x = makeLinear(graph: graph, xIn: x, name: name + ".0", outChannels: 1280) 173 | x = makeSwish(graph: graph, xIn: x) 174 | return makeLinear(graph: graph, xIn: x, name: name + ".2", outChannels: 1280) 175 | } 176 | 177 | func makeUNetResBlock(graph: MPSGraph, xIn: MPSGraphTensor, embIn: MPSGraphTensor, name: String, inChannels: NSNumber, outChannels: NSNumber) -> MPSGraphTensor { 178 | var x = xIn 179 | x = makeGroupNormSwish(graph: graph, xIn: x, name: name + ".in_layers.0") 180 | x = makeConv(graph: graph, xIn: x, name: name + ".in_layers.2", outChannels: outChannels, khw: 3) 181 | var emb = embIn 182 | emb = makeSwish(graph: graph, xIn: emb) 183 | emb = makeLinear(graph: graph, xIn: emb, name: name + ".emb_layers.1", outChannels: outChannels) 184 | emb = 
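// the expandDims below reshapes the projected timestep embedding from [N, C] to [N, 1, 1, C],
// so the addition that follows broadcasts it across the spatial (H, W) axes of the NHWC feature map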
graph.expandDims(emb, axes: [1, 2], name: nil) 185 | x = graph.addition(x, emb, name: nil) 186 | x = makeGroupNormSwish(graph: graph, xIn: x, name: name + ".out_layers.0") 187 | x = makeConv(graph: graph, xIn: x, name: name + ".out_layers.3", outChannels: outChannels, khw: 3) 188 | 189 | var skip = xIn 190 | if (inChannels != outChannels) { 191 | skip = makeConv(graph: graph, xIn: xIn, name: name + ".skip_connection", outChannels: outChannels, khw: 1) 192 | } 193 | return graph.addition(x, skip, name: nil) 194 | } 195 | 196 | func makeCrossAttention(graph: MPSGraph, xIn: MPSGraphTensor, name: String, context: MPSGraphTensor?, saveMemory: Bool) -> MPSGraphTensor { 197 | let c = xIn.shape![2] 198 | let (nHeads, dHead) = (NSNumber(8), NSNumber(value: c.intValue / 8)) 199 | var q = makeLinear(graph: graph, xIn: xIn, name: name + ".to_q", outChannels: c, bias: false) 200 | let context = context ?? xIn 201 | var k = makeLinear(graph: graph, xIn: context, name: name + ".to_k", outChannels: c, bias: false) 202 | var v = makeLinear(graph: graph, xIn: context, name: name + ".to_v", outChannels: c, bias: false) 203 | let n = xIn.shape![0] 204 | let hw = xIn.shape![1] 205 | let t = context.shape![1] 206 | q = graph.reshape(q, shape: [n, hw, nHeads, dHead], name: nil) 207 | k = graph.reshape(k, shape: [n, t, nHeads, dHead], name: nil) 208 | v = graph.reshape(v, shape: [n, t, nHeads, dHead], name: nil) 209 | 210 | q = graph.transposeTensor(q, dimension: 1, withDimension: 2, name: nil) 211 | k = graph.transposeTensor(k, dimension: 1, withDimension: 2, name: nil) 212 | k = graph.transposeTensor(k, dimension: 2, withDimension: 3, name: nil) 213 | k = graph.multiplication(k, graph.constant(1.0 / sqrt(dHead.doubleValue), dataType: MPSDataType.float16), name: nil) 214 | v = graph.transposeTensor(v, dimension: 1, withDimension: 2, name: nil) 215 | 216 | var att: MPSGraphTensor 217 | if (saveMemory) { 218 | // MEM-HACK - silly graph seems to use less peak memory 219 | var attRes = [MPSGraphTensor]() 220 | let sliceSize = 1 221 | for i in 0.. 
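// MEM-HACK attention: rather than forming the full q·k^T matrix in one shot, the saveMemory path
// computes attention in small slices (slice q/k/v, matmul, softmax, matmul with v, then concat the
// partial results), which appears to keep MPSGraph's peak memory lower at the cost of extra ops;
// the non-saveMemory path performs the same computation in a single pass.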
MPSGraphTensor { 243 | var x = xIn 244 | x = graph.multiplication(x, graph.constant(1/sqrt(2), dataType: MPSDataType.float16), name: nil) 245 | x = graph.erf(with: x, name: nil) 246 | x = graph.addition(x, graph.constant(1, dataType: MPSDataType.float16), name: nil) 247 | x = graph.multiplication(x, graph.constant(0.5, dataType: MPSDataType.float16), name: nil) 248 | return graph.multiplication(xIn, x, name: nil) 249 | } 250 | 251 | func makeFeedForward(graph: MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 252 | assert(xIn.shape!.count == 3) 253 | let dim = xIn.shape![2] 254 | let dimMult = dim.intValue * 4 255 | let dimProj = NSNumber(value: dimMult * 2) 256 | let proj = makeLinear(graph: graph, xIn: xIn, name: name + ".0.proj", outChannels: dimProj) 257 | var x = graph.sliceTensor(proj, dimension: 2, start: 0, length: dimMult, name: nil) 258 | var gate = graph.sliceTensor(proj, dimension: 2, start: dimMult, length: dimMult, name: nil) 259 | gate = makeGelu(graph: graph, xIn: gate) 260 | x = graph.multiplication(x, gate, name: nil) 261 | return makeLinear(graph: graph, xIn: x, name: name + ".2", outChannels: dim) 262 | } 263 | 264 | func makeBasicTransformerBlock(graph: MPSGraph, xIn: MPSGraphTensor, name: String, contextIn: MPSGraphTensor, saveMemory: Bool) -> MPSGraphTensor { 265 | var x = xIn 266 | var attn1 = makeLayerNorm(graph: graph, xIn: x, name: name + ".norm1") 267 | attn1 = makeCrossAttention(graph: graph, xIn: attn1, name: name + ".attn1", context: nil, saveMemory: saveMemory) 268 | x = graph.addition(attn1, x, name: nil) 269 | var attn2 = makeLayerNorm(graph: graph, xIn: x, name: name + ".norm2") 270 | attn2 = makeCrossAttention(graph: graph, xIn: attn2, name: name + ".attn2", context: contextIn, saveMemory: saveMemory) 271 | x = graph.addition(attn2, x, name: nil) 272 | var ff = makeLayerNorm(graph: graph, xIn: x, name: name + ".norm3") 273 | ff = makeFeedForward(graph: graph, xIn: ff, name: name + ".ff.net") 274 | return graph.addition(ff, x, name: nil) 275 | } 276 | 277 | func makeSpatialTransformerBlock(graph: MPSGraph, xIn: MPSGraphTensor, name: String, contextIn: MPSGraphTensor, saveMemory: Bool) -> MPSGraphTensor { 278 | let n, h, w, c: NSNumber 279 | (n, h, w, c) = (xIn.shape![0], xIn.shape![1], xIn.shape![2], xIn.shape![3]) 280 | var x = xIn 281 | x = makeGroupNorm(graph: graph, xIn: x, name: name + ".norm") 282 | x = makeConv(graph: graph, xIn: x, name: name + ".proj_in", outChannels: c, khw: 1) 283 | x = graph.reshape(x, shape: [n, (h.intValue * w.intValue) as NSNumber, c], name: nil) 284 | x = makeBasicTransformerBlock(graph: graph, xIn: x, name: name + ".transformer_blocks.0", contextIn: contextIn, saveMemory: saveMemory) 285 | x = graph.reshape(x, shape: [n, h, w, c], name: nil) 286 | x = makeConv(graph: graph, xIn: x, name: name + ".proj_out", outChannels: c, khw: 1) 287 | return graph.addition(x, xIn, name: nil) 288 | } 289 | 290 | func makeOutputBlock(graph: MPSGraph, xIn: MPSGraphTensor, embIn: MPSGraphTensor, condIn: MPSGraphTensor, inChannels: NSNumber, outChannels: NSNumber, dHead: NSNumber, name: String, saveMemory: Bool, spatialTransformer: Bool = true, upsample: Bool = false) -> MPSGraphTensor { 291 | var x = xIn 292 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: embIn, name: name + ".0", inChannels: inChannels, outChannels: outChannels) 293 | if (spatialTransformer) { 294 | x = makeSpatialTransformerBlock(graph: graph, xIn: x, name: name + ".1", contextIn: condIn, saveMemory: saveMemory) 295 | } 296 | if (upsample) { 297 | x = 
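// upsample path: nearest-neighbour resize doubles H and W, then the 3x3 conv below
// (named ".2" when a spatial transformer precedes it, otherwise ".1") refines the upsampled features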
makeUpsampleNearest(graph: graph, xIn: x) 298 | x = makeConv(graph: graph, xIn: x, name: name + (spatialTransformer ? ".2" : ".1") + ".conv", outChannels: outChannels, khw: 3) 299 | } 300 | return x 301 | } 302 | 303 | 304 | func makeUNetAnUnexpectedJourney(graph: MPSGraph, xIn: MPSGraphTensor, tembIn: MPSGraphTensor, condIn: MPSGraphTensor, name: String, saveMemory: Bool = true) -> [MPSGraphTensor] { 305 | let emb = makeTimeEmbed(graph: graph, xIn: tembIn, name: name + ".time_embed") 306 | 307 | var savedInputs = [MPSGraphTensor]() 308 | var x = xIn 309 | 310 | if (!saveMemory) { 311 | // need to explicitly batch to avoid shape errors later iirc 312 | // TODO: did we actually need this 313 | x = graph.broadcast(x, shape: [condIn.shape![0], x.shape![1], x.shape![2], x.shape![3]], name: nil) 314 | } 315 | 316 | // input blocks 317 | x = makeConv(graph: graph, xIn: x, name: name + ".input_blocks.0.0", outChannels: 320, khw: 3) 318 | savedInputs.append(x) 319 | 320 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".input_blocks.1.0", inChannels: 320, outChannels: 320) 321 | x = makeSpatialTransformerBlock(graph: graph, xIn: x, name: name + ".input_blocks.1.1", contextIn: condIn, saveMemory: saveMemory) 322 | savedInputs.append(x) 323 | 324 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".input_blocks.2.0", inChannels: 320, outChannels: 320) 325 | x = makeSpatialTransformerBlock(graph: graph, xIn: x, name: name + ".input_blocks.2.1", contextIn: condIn, saveMemory: saveMemory) 326 | savedInputs.append(x) 327 | 328 | // downsample 329 | x = makeConv(graph: graph, xIn: x, name: name + ".input_blocks.3.0.op", outChannels: 320, khw: 3, stride: 2) 330 | savedInputs.append(x) 331 | 332 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".input_blocks.4.0", inChannels: 320, outChannels: 640) 333 | x = makeSpatialTransformerBlock(graph: graph, xIn: x, name: name + ".input_blocks.4.1", contextIn: condIn, saveMemory: saveMemory) 334 | savedInputs.append(x) 335 | 336 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".input_blocks.5.0", inChannels: 640, outChannels: 640) 337 | x = makeSpatialTransformerBlock(graph: graph, xIn: x, name: name + ".input_blocks.5.1", contextIn: condIn, saveMemory: saveMemory) 338 | savedInputs.append(x) 339 | 340 | // downsample 341 | x = makeConv(graph: graph, xIn: x, name: name + ".input_blocks.6.0.op", outChannels: 640, khw: 3, stride: 2) 342 | savedInputs.append(x) 343 | 344 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".input_blocks.7.0", inChannels: 640, outChannels: 1280) 345 | x = makeSpatialTransformerBlock(graph: graph, xIn: x, name: name + ".input_blocks.7.1", contextIn: condIn, saveMemory: saveMemory) 346 | savedInputs.append(x) 347 | 348 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".input_blocks.8.0", inChannels: 1280, outChannels: 1280) 349 | x = makeSpatialTransformerBlock(graph: graph, xIn: x, name: name + ".input_blocks.8.1", contextIn: condIn, saveMemory: saveMemory) 350 | savedInputs.append(x) 351 | 352 | // downsample 353 | x = makeConv(graph: graph, xIn: x, name: name + ".input_blocks.9.0.op", outChannels: 1280, khw: 3, stride: 2) 354 | savedInputs.append(x) 355 | 356 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".input_blocks.10.0", inChannels: 1280, outChannels: 1280) 357 | savedInputs.append(x) 358 | 359 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".input_blocks.11.0", 
inChannels: 1280, outChannels: 1280) 360 | savedInputs.append(x) 361 | 362 | // middle blocks 363 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".middle_block.0", inChannels: 1280, outChannels: 1280) 364 | x = makeSpatialTransformerBlock(graph: graph, xIn: x, name: name + ".middle_block.1", contextIn: condIn, saveMemory: saveMemory) 365 | x = makeUNetResBlock(graph: graph, xIn: x, embIn: emb, name: name + ".middle_block.2", inChannels: 1280, outChannels: 1280) 366 | 367 | return savedInputs + [emb] + [x] 368 | } 369 | 370 | func makeUNetTheDesolationOfSmaug(graph: MPSGraph, savedInputsIn: [MPSGraphTensor], name: String, saveMemory: Bool = true) -> [MPSGraphTensor] { 371 | var savedInputs = savedInputsIn 372 | let condIn = savedInputs.popLast()! 373 | var x = savedInputs.popLast()! 374 | let emb = savedInputs.popLast()! 375 | // output blocks 376 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 377 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 2560, outChannels: 1280, dHead: 160, name: name + ".output_blocks.0", saveMemory: saveMemory, spatialTransformer: false, upsample: false) 378 | 379 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 380 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 2560, outChannels: 1280, dHead: 160, name: name + ".output_blocks.1", saveMemory: saveMemory, spatialTransformer: false, upsample: false) 381 | 382 | // upsample 383 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 384 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 2560, outChannels: 1280, dHead: 160, name: name + ".output_blocks.2", saveMemory: saveMemory, spatialTransformer: false, upsample: true) 385 | 386 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 387 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 2560, outChannels: 1280, dHead: 160, name: name + ".output_blocks.3", saveMemory: saveMemory, spatialTransformer: true, upsample: false) 388 | 389 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 390 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 2560, outChannels: 1280, dHead: 160, name: name + ".output_blocks.4", saveMemory: saveMemory, spatialTransformer: true, upsample: false) 391 | 392 | return savedInputs + [emb] + [x] 393 | } 394 | 395 | func makeUNetTheBattleOfTheFiveArmies(graph: MPSGraph, savedInputsIn: [MPSGraphTensor], name: String, saveMemory: Bool = true) -> MPSGraphTensor { 396 | var savedInputs = savedInputsIn 397 | let condIn = savedInputs.popLast()! 398 | var x = savedInputs.popLast()! 399 | let emb = savedInputs.popLast()! 
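    // The UNet is split across three functions (AnUnexpectedJourney = input + middle blocks,
    // TheDesolationOfSmaug = output blocks 0-4, TheBattleOfTheFiveArmies = output blocks 5-11 plus the
    // out head) so that, in memory-saving mode, each third can be built and executed separately.
    // savedInputs acts as a stack of skip connections: every popLast() here pairs with a
    // concatTensors along the channel axis (dimension 3, NHWC), mirroring the appends in the input blocks.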
400 | // upsample 401 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 402 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 1920, outChannels: 1280, dHead: 160, name: name + ".output_blocks.5", saveMemory: saveMemory, spatialTransformer: true, upsample: true) 403 | 404 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 405 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 1920, outChannels: 640, dHead: 80, name: name + ".output_blocks.6", saveMemory: saveMemory, spatialTransformer: true, upsample: false) 406 | 407 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 408 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 1280, outChannels: 640, dHead: 80, name: name + ".output_blocks.7", saveMemory: saveMemory, spatialTransformer: true, upsample: false) 409 | 410 | // upsample 411 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 412 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 960, outChannels: 640, dHead: 80, name: name + ".output_blocks.8", saveMemory: saveMemory, spatialTransformer: true, upsample: true) 413 | 414 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 415 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 960, outChannels: 320, dHead: 40, name: name + ".output_blocks.9", saveMemory: saveMemory, spatialTransformer: true, upsample: false) 416 | 417 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 418 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 640, outChannels: 320, dHead: 40, name: name + ".output_blocks.10", saveMemory: saveMemory, spatialTransformer: true, upsample: false) 419 | 420 | x = graph.concatTensors([x, savedInputs.popLast()!], dimension: 3, name: nil) 421 | x = makeOutputBlock(graph: graph, xIn: x, embIn: emb, condIn: condIn, inChannels: 640, outChannels: 320, dHead: 40, name: name + ".output_blocks.11", saveMemory: saveMemory, spatialTransformer: true, upsample: false) 422 | 423 | // out 424 | x = makeGroupNormSwish(graph: graph, xIn: x, name: "model.diffusion_model.out.0") 425 | return makeConv(graph: graph, xIn: x, name: "model.diffusion_model.out.2", outChannels: 4, khw: 3) 426 | } 427 | 428 | func makeTimeFeatures(graph: MPSGraph, tIn: MPSGraphTensor) -> MPSGraphTensor { 429 | var temb = graph.cast(tIn, to: MPSDataType.float32, name: "temb") 430 | var coeffs = loadConstant(graph: graph, name: "temb_coefficients", shape: [160], fp32: true) 431 | coeffs = graph.cast(coeffs, to: MPSDataType.float32, name: "coeffs") 432 | temb = graph.multiplication(temb, coeffs, name: nil) 433 | temb = graph.concatTensors([graph.cos(with: temb, name: nil), graph.sin(with: temb, name: nil)], dimension: 0, name: nil) 434 | temb = graph.reshape(temb, shape: [1, 320], name: nil) 435 | return graph.cast(temb, to: MPSDataType.float16, name: "temb fp16") 436 | } 437 | 438 | func makeSqrtOneMinus(graph: MPSGraph, xIn: MPSGraphTensor) -> MPSGraphTensor { 439 | return graph.squareRoot(with: graph.subtraction(graph.constant(1.0, dataType: MPSDataType.float16), xIn, name: nil), name: nil) 440 | } 441 | 442 | func makeDiffusionStep(graph: MPSGraph, xIn: MPSGraphTensor, etaUncondIn: MPSGraphTensor, etaCondIn: MPSGraphTensor, tIn: MPSGraphTensor, tPrevIn: MPSGraphTensor, guidanceScaleIn: MPSGraphTensor) -> 
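// Per-step update: classifier-free guidance first forms
//   eta = etaUncond + tanh(guidanceScale * (etaCond - etaUncond))
// (the tanh clamp is a Maple Diffusion addition; stock Stable Diffusion does not clamp), then a
// deterministic DDIM-style step uses alphas_cumprod gathered at tIn / tPrevIn:
//   predX0 = (x - sqrt(1 - alpha_t) * eta) / sqrt(alpha_t)
//   xPrev  = sqrt(alpha_tPrev) * predX0 + sqrt(1 - alpha_tPrev) * eta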
MPSGraphTensor { 443 | 444 | // superconditioning 445 | var deltaCond = graph.multiplication(graph.subtraction(etaCondIn, etaUncondIn, name: nil), guidanceScaleIn, name: nil) 446 | deltaCond = graph.tanh(with: deltaCond, name: nil) // NOTE: normal SD doesn't clamp here iirc 447 | let eta = graph.addition(etaUncondIn, deltaCond, name: nil) 448 | 449 | // scheduler conditioning 450 | let alphasCumprod = loadConstant(graph: graph, name: "alphas_cumprod", shape: [1000]) 451 | let alphaIn = graph.gatherAlongAxis(0, updates: alphasCumprod, indices: tIn, name: nil) 452 | let alphasCumprodPrev = graph.concatTensors([graph.constant(1, dataType: MPSDataType.float16), alphasCumprod], dimension: 0, name: nil) 453 | let tPrevInOffset = graph.reLU(with: graph.addition(tPrevIn, graph.constant(1, dataType: MPSDataType.int32), name: nil), name: nil) 454 | let alphaPrevIn = graph.gatherAlongAxis(0, updates: alphasCumprodPrev, indices: tPrevInOffset, name: nil) 455 | 456 | // scheduler step 457 | let deltaX0 = graph.multiplication(makeSqrtOneMinus(graph: graph, xIn: alphaIn), eta, name: nil) 458 | let predX0Unscaled = graph.subtraction(xIn, deltaX0, name: nil) 459 | let predX0 = graph.division(predX0Unscaled, graph.squareRoot(with: alphaIn, name: nil), name: nil) 460 | let dirX = graph.multiplication(makeSqrtOneMinus(graph: graph, xIn: alphaPrevIn), eta, name: nil) 461 | let xPrevBase = graph.multiplication(graph.squareRoot(with: alphaPrevIn, name: nil), predX0, name:nil) 462 | return graph.addition(xPrevBase, dirX, name: nil) 463 | } 464 | 465 | class BPETokenizer { 466 | // why didn't they just byte-encode 467 | func whitespaceClean(s: String) -> String { return s.components(separatedBy: .whitespacesAndNewlines).filter { !$0.isEmpty }.joined(separator: " ").trimmingCharacters(in: .whitespacesAndNewlines) } 468 | 469 | func getPairs(s: [String]) -> Set { return Set((1.."}) 487 | let vocabFile = try! String(contentsOf: Bundle.main.url(forResource: "bins/bpe_simple_vocab_16e6", withExtension: "txt")!) 488 | for (i, m) in vocabFile.split(separator: "\n")[1..<48_895].enumerated() { 489 | ranks[String(m)] = i 490 | vocabList.append(m.split(separator: " ").joined(separator: "")) 491 | } 492 | vocab = vocabList.enumerated().reduce(into: [:], {$0[$1.element] = $1.offset}) 493 | } 494 | 495 | func encodeToken(s: String) -> [Int] { 496 | let token = String(s.utf8.map{bytesToUnicode[Int($0)]!}) 497 | var word = token[.."] 498 | var pairs = getPairs(s: Array(word)) 499 | var mergedWordTokens = [token + ""] 500 | var count = 0 501 | if (!pairs.isEmpty) { 502 | while (true) { 503 | count += 1 504 | assert(count < 8192, "encodeToken is trapped in a token factory for input \(s)") 505 | let highestRankedBigram = pairs.min(by: {ranks[$0, default: Int.max] < ranks[$1, default: Int.max]})! 506 | if (ranks[highestRankedBigram] == nil) { break } 507 | let fs = highestRankedBigram.split(separator: " ") 508 | let (first, second) = (String(fs[0]), String(fs[1])) 509 | var (newWord, i) = ([String](), 0) 510 | while (i < word.count) { 511 | let j = word[i.. [Int] { 540 | let ns = NSString(string: whitespaceClean(s: s.lowercased())) 541 | var bpe: [Int] = [] 542 | for match in pat.matches(in: String(ns), range: NSRange(location: 0, length: ns.length)) { 543 | bpe.append(contentsOf: encodeToken(s: ns.substring(with: match.range))) 544 | } 545 | if (bpe.count > 75) { 546 | print("Prompt of \(bpe.count) bpe tokens will be truncated: \(s)") 547 | } 548 | return [49406] + bpe[.. 
MPSGraphTensor { 553 | let nHeads: NSNumber = 12 554 | let dHead: NSNumber = 64 555 | let c: NSNumber = 768 556 | var q = makeLinear(graph: graph, xIn: xIn, name: name + ".q_proj", outChannels: c) 557 | var k = makeLinear(graph: graph, xIn: xIn, name: name + ".k_proj", outChannels: c) 558 | var v = makeLinear(graph: graph, xIn: xIn, name: name + ".v_proj", outChannels: c) 559 | 560 | let n = xIn.shape![0] 561 | let t = xIn.shape![1] 562 | q = graph.reshape(q, shape: [n, t, nHeads, dHead], name: nil) 563 | k = graph.reshape(k, shape: [n, t, nHeads, dHead], name: nil) 564 | v = graph.reshape(v, shape: [n, t, nHeads, dHead], name: nil) 565 | 566 | q = graph.transposeTensor(q, dimension: 1, withDimension: 2, name: nil) 567 | k = graph.transposeTensor(k, dimension: 1, withDimension: 2, name: nil) 568 | v = graph.transposeTensor(v, dimension: 1, withDimension: 2, name: nil) 569 | 570 | var att = graph.matrixMultiplication(primary: q, secondary: graph.transposeTensor(k, dimension: 2, withDimension: 3, name: nil), name: nil) 571 | att = graph.multiplication(att, graph.constant(1.0 / sqrt(dHead.doubleValue), dataType: MPSDataType.float16), name: nil) 572 | att = graph.addition(att, loadConstant(graph: graph, name: "causal_mask", shape: [1, 1, 77, 77]), name: nil) 573 | att = graph.softMax(with: att, axis: 3, name: nil) 574 | att = graph.matrixMultiplication(primary: att, secondary: v, name: nil) 575 | att = graph.transposeTensor(att, dimension: 1, withDimension: 2, name: nil) 576 | att = graph.reshape(att, shape: [n, t, c], name: nil) 577 | return makeLinear(graph: graph, xIn: att, name: name + ".out_proj", outChannels: c) 578 | } 579 | 580 | func makeTextEncoderLayer(graph: MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 581 | var x = xIn 582 | x = makeLayerNorm(graph: graph, xIn: x, name: name + ".layer_norm1") 583 | x = makeTextAttention(graph: graph, xIn: x, name: name + ".self_attn") 584 | x = graph.addition(x, xIn, name: nil) 585 | let skip = x 586 | x = makeLayerNorm(graph: graph, xIn: x, name: name + ".layer_norm2") 587 | x = makeLinear(graph: graph, xIn: x, name: name + ".mlp.fc1", outChannels: 3072) 588 | x = makeGelu(graph: graph, xIn: x) 589 | x = makeLinear(graph: graph, xIn: x, name: name + ".mlp.fc2", outChannels: 768) 590 | return graph.addition(x, skip, name: nil) 591 | } 592 | 593 | func makeTextEncoder(graph: MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 594 | var x = xIn 595 | for i in 0..<12 { 596 | x = makeTextEncoderLayer(graph: graph, xIn: x, name: name + ".layers.\(i)") 597 | } 598 | return x 599 | } 600 | 601 | func makeTextEmbeddings(graph: MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 602 | var tokenEmbeddings = loadConstant(graph: graph, name: name + ".token_embedding.weight", shape: [1, 49408, 768]) 603 | tokenEmbeddings = graph.broadcast(tokenEmbeddings, shape: [2, 49408, 768], name: nil) 604 | let positionEmbeddings = loadConstant(graph: graph, name: name + ".position_embedding.weight", shape: [1, 77, 768]) 605 | var embeddings = graph.broadcast(graph.expandDims(xIn, axes: [2], name: nil), shape: [2, 77, 768], name: nil) 606 | embeddings = graph.gatherAlongAxis(1, updates: tokenEmbeddings, indices: embeddings, name: nil) 607 | return graph.addition(embeddings, positionEmbeddings, name: nil) 608 | } 609 | 610 | func makeTextGuidance(graph: MPSGraph, xIn: MPSGraphTensor, name: String) -> MPSGraphTensor { 611 | var x = makeTextEmbeddings(graph: graph, xIn: xIn, name: name + ".embeddings") 612 | x = 
makeTextEncoder(graph: graph, xIn: x, name: name + ".encoder") 613 | return makeLayerNorm(graph: graph, xIn: x, name: name + ".final_layer_norm") 614 | } 615 | 616 | func makeAuxUpsampler(graph: MPSGraph, xIn: MPSGraphTensor) -> MPSGraphTensor { 617 | var x = xIn 618 | x = makeConv(graph: graph, xIn: xIn, name: "aux_output_conv", outChannels: 3, khw: 1) 619 | x = makeUpsampleNearest(graph: graph, xIn: x, scaleFactor: 8) 620 | return makeByteConverter(graph: graph, xIn: x) 621 | } 622 | 623 | class MapleDiffusion { 624 | let device: MTLDevice 625 | let graphDevice: MPSGraphDevice 626 | let commandQueue: MTLCommandQueue 627 | let saveMemory: Bool 628 | let shouldSynchronize: Bool 629 | 630 | // text tokenization 631 | let tokenizer: BPETokenizer 632 | 633 | // text guidance 634 | var textGuidanceExecutable: MPSGraphExecutable? 635 | 636 | // time embedding 637 | let tembGraph: MPSGraph 638 | let tembTIn: MPSGraphTensor 639 | let tembOut: MPSGraphTensor 640 | 641 | // diffusion 642 | let diffGraph: MPSGraph 643 | let diffGuidanceScaleIn: MPSGraphTensor 644 | let diffXIn: MPSGraphTensor 645 | let diffEtaUncondIn: MPSGraphTensor 646 | let diffEtaCondIn: MPSGraphTensor 647 | let diffTIn: MPSGraphTensor 648 | let diffTPrevIn: MPSGraphTensor 649 | let diffOut: MPSGraphTensor 650 | let diffAuxOut: MPSGraphTensor 651 | 652 | // unet 653 | // MEM-HACK: split into subgraphs 654 | var unetAnUnexpectedJourneyExecutable: MPSGraphExecutable? 655 | var anUnexpectedJourneyShapes = [[NSNumber]]() 656 | var unetTheDesolationOfSmaugExecutable: MPSGraphExecutable? 657 | var theDesolationOfSmaugShapes = [[NSNumber]]() 658 | var theDesolationOfSmaugIndices = [MPSGraphTensor: Int]() 659 | var unetTheBattleOfTheFiveArmiesExecutable: MPSGraphExecutable? 660 | var theBattleOfTheFiveArmiesIndices = [MPSGraphTensor: Int]() 661 | 662 | var width: NSNumber = 64 663 | var height: NSNumber = 64 664 | 665 | public init(saveMemoryButBeSlower: Bool = true) { 666 | saveMemory = saveMemoryButBeSlower 667 | device = MTLCreateSystemDefaultDevice()! 668 | graphDevice = MPSGraphDevice(mtlDevice: device) 669 | commandQueue = device.makeCommandQueue()! 
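// shouldSynchronize presumably tells makeGraph to synchronize results back to CPU-visible memory; it is only enabled when the Metal device lacks unified memory (e.g. a discrete-GPU Mac).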
670 | shouldSynchronize = !device.hasUnifiedMemory 671 | 672 | // text tokenization 673 | tokenizer = BPETokenizer() 674 | 675 | // time embedding 676 | tembGraph = makeGraph(synchonize: shouldSynchronize) 677 | tembTIn = tembGraph.placeholder(shape: [1], dataType: MPSDataType.int32, name: nil) 678 | tembOut = makeTimeFeatures(graph: tembGraph, tIn: tembTIn) 679 | 680 | // diffusion 681 | diffGraph = makeGraph(synchonize: shouldSynchronize) 682 | diffXIn = diffGraph.placeholder(shape: [1, height, width, 4], dataType: MPSDataType.float16, name: nil) 683 | diffEtaUncondIn = diffGraph.placeholder(shape: [1, height, width, 4], dataType: MPSDataType.float16, name: nil) 684 | diffEtaCondIn = diffGraph.placeholder(shape: [1, height, width, 4], dataType: MPSDataType.float16, name: nil) 685 | diffTIn = diffGraph.placeholder(shape: [1], dataType: MPSDataType.int32, name: nil) 686 | diffTPrevIn = diffGraph.placeholder(shape: [1], dataType: MPSDataType.int32, name: nil) 687 | diffGuidanceScaleIn = diffGraph.placeholder(shape: [1], dataType: MPSDataType.float32, name: nil) 688 | diffOut = makeDiffusionStep(graph: diffGraph, xIn: diffXIn, etaUncondIn: diffEtaUncondIn, etaCondIn: diffEtaCondIn, tIn: diffTIn, tPrevIn: diffTPrevIn, guidanceScaleIn: diffGraph.cast(diffGuidanceScaleIn, to: MPSDataType.float16, name: "this string must not be the empty string")) 689 | diffAuxOut = makeAuxUpsampler(graph: diffGraph, xIn: diffOut) 690 | } 691 | 692 | public func initModels(completion: (Float, String)->()) { 693 | // text guidance 694 | completion(0, "Loading text guidance...") 695 | initTextGuidance() 696 | 697 | // unet 698 | completion(0.25, "Loading UNet part 1/3...") 699 | initAnUnexpectedJourney() 700 | completion(0.5, "Loading UNet part 2/3...") 701 | initTheDesolationOfSmaug() 702 | completion(0.75, "Loading UNet part 3/3...") 703 | initTheBattleOfTheFiveArmies() 704 | completion(1, "Loaded models") 705 | } 706 | 707 | private func initTextGuidance() { 708 | let graph = makeGraph(synchonize: shouldSynchronize) 709 | let textGuidanceIn = graph.placeholder(shape: [2, 77], dataType: MPSDataType.int32, name: nil) 710 | let textGuidanceOut = makeTextGuidance(graph: graph, xIn: textGuidanceIn, name: "cond_stage_model.transformer.text_model") 711 | let textGuidanceOut0 = graph.sliceTensor(textGuidanceOut, dimension: 0, start: 0, length: 1, name: nil) 712 | let textGuidanceOut1 = graph.sliceTensor(textGuidanceOut, dimension: 0, start: 1, length: 1, name: nil) 713 | textGuidanceExecutable = graph.compile(with: graphDevice, feeds: [textGuidanceIn: MPSGraphShapedType(shape: textGuidanceIn.shape, dataType: MPSDataType.int32)], targetTensors: [textGuidanceOut0, textGuidanceOut1], targetOperations: nil, compilationDescriptor: nil) 714 | } 715 | 716 | private func initAnUnexpectedJourney() { 717 | let graph = makeGraph(synchonize: shouldSynchronize) 718 | let xIn = graph.placeholder(shape: [1, height, width, 4], dataType: MPSDataType.float16, name: nil) 719 | let condIn = graph.placeholder(shape: [saveMemory ? 
1 : 2, 77, 768], dataType: MPSDataType.float16, name: nil) 720 | let tembIn = graph.placeholder(shape: [1, 320], dataType: MPSDataType.float16, name: nil) 721 | let unetOuts = makeUNetAnUnexpectedJourney(graph: graph, xIn: xIn, tembIn: tembIn, condIn: condIn, name: "model.diffusion_model", saveMemory: saveMemory) 722 | let unetFeeds = [xIn, condIn, tembIn].reduce(into: [:], {$0[$1] = MPSGraphShapedType(shape: $1.shape!, dataType: $1.dataType)}) 723 | unetAnUnexpectedJourneyExecutable = graph.compile(with: graphDevice, feeds: unetFeeds, targetTensors: unetOuts, targetOperations: nil, compilationDescriptor: nil) 724 | anUnexpectedJourneyShapes = unetOuts.map{$0.shape!} 725 | } 726 | 727 | private func initTheDesolationOfSmaug() { 728 | let graph = makeGraph(synchonize: shouldSynchronize) 729 | let condIn = graph.placeholder(shape: [saveMemory ? 1 : 2, 77, 768], dataType: MPSDataType.float16, name: nil) 730 | let placeholders = anUnexpectedJourneyShapes.map{graph.placeholder(shape: $0, dataType: MPSDataType.float16, name: nil)} + [condIn] 731 | theDesolationOfSmaugIndices.removeAll() 732 | for i in 0.. MPSGraphTensorData { 755 | let graph = makeGraph(synchonize: shouldSynchronize) 756 | let out = graph.randomTensor(withShape: [1, height, width, 4], descriptor: MPSGraphRandomOpDescriptor(distribution: .normal, dataType: .float16)!, seed: seed, name: nil) 757 | return graph.run(with: commandQueue, feeds: [:], targetTensors: [out], targetOperations: nil)[out]! 758 | } 759 | 760 | private func runTextGuidance(baseTokens: [Int], tokens: [Int]) -> (MPSGraphTensorData, MPSGraphTensorData) { 761 | let tokensData = (baseTokens + tokens).map({Int32($0)}).withUnsafeBufferPointer {Data(buffer: $0)} 762 | let tokensMPSData = MPSGraphTensorData(device: graphDevice, data: tokensData, shape: [2, 77], dataType: MPSDataType.int32) 763 | let res = textGuidanceExecutable!.run(with: commandQueue, inputs: [tokensMPSData], results: nil, executionDescriptor: nil) 764 | return (res[0], res[1]) 765 | } 766 | 767 | private func loadDecoderAndGetFinalImage(xIn: MPSGraphTensorData) -> MPSGraphTensorData { 768 | // MEM-HACK: decoder is loaded from disc and deallocated to save memory (at cost of latency) 769 | let x = xIn 770 | let decoderGraph = makeGraph(synchonize: shouldSynchronize) 771 | let decoderIn = decoderGraph.placeholder(shape: x.shape, dataType: MPSDataType.float16, name: nil) 772 | let decoderOut = makeDecoder(graph: decoderGraph, xIn: decoderIn) 773 | return decoderGraph.run(with: commandQueue, feeds: [decoderIn: x], targetTensors: [decoderOut], targetOperations: nil)[decoderOut]! 774 | } 775 | 776 | private func reorderAnUnexpectedJourney(x: [MPSGraphTensorData]) -> [MPSGraphTensorData] { 777 | var out = [MPSGraphTensorData]() 778 | for r in unetAnUnexpectedJourneyExecutable!.feedTensors! { 779 | for i in x { 780 | if (i.shape == r.shape) { 781 | out.append(i) 782 | } 783 | } 784 | } 785 | return out 786 | } 787 | 788 | private func reorderTheDesolationOfSmaug(x: [MPSGraphTensorData]) -> [MPSGraphTensorData] { 789 | var out = [MPSGraphTensorData]() 790 | for r in unetTheDesolationOfSmaugExecutable!.feedTensors! { 791 | out.append(x[theDesolationOfSmaugIndices[r]!]) 792 | } 793 | return out 794 | } 795 | 796 | private func reorderTheBattleOfTheFiveArmies(x: [MPSGraphTensorData]) -> [MPSGraphTensorData] { 797 | var out = [MPSGraphTensorData]() 798 | for r in unetTheBattleOfTheFiveArmiesExecutable!.feedTensors! 
{ 799 | out.append(x[theBattleOfTheFiveArmiesIndices[r]!]) 800 | } 801 | return out 802 | } 803 | 804 | private func runUNet(latent: MPSGraphTensorData, guidance: MPSGraphTensorData, temb: MPSGraphTensorData) -> MPSGraphTensorData { 805 | var x = unetAnUnexpectedJourneyExecutable!.run(with: commandQueue, inputs: reorderAnUnexpectedJourney(x: [latent, guidance, temb]), results: nil, executionDescriptor: nil) 806 | x = unetTheDesolationOfSmaugExecutable!.run(with: commandQueue, inputs: reorderTheDesolationOfSmaug(x: x + [guidance]), results: nil, executionDescriptor: nil) 807 | return unetTheBattleOfTheFiveArmiesExecutable!.run(with: commandQueue, inputs: reorderTheBattleOfTheFiveArmies(x: x + [guidance]), results: nil, executionDescriptor: nil)[0] 808 | } 809 | 810 | private func runBatchedUNet(latent: MPSGraphTensorData, baseGuidance: MPSGraphTensorData, textGuidance: MPSGraphTensorData, temb: MPSGraphTensorData) -> (MPSGraphTensorData, MPSGraphTensorData) { 811 | // concat 812 | var graph = makeGraph(synchonize: shouldSynchronize) 813 | let bg = graph.placeholder(shape: baseGuidance.shape, dataType: MPSDataType.float16, name: nil) 814 | let tg = graph.placeholder(shape: textGuidance.shape, dataType: MPSDataType.float16, name: nil) 815 | let concatGuidance = graph.concatTensors([bg, tg], dimension: 0, name: nil) 816 | let concatGuidanceData = graph.run(with: commandQueue, feeds: [bg : baseGuidance, tg: textGuidance], targetTensors: [concatGuidance], targetOperations: nil)[concatGuidance]! 817 | // run 818 | let concatEtaData = runUNet(latent: latent, guidance: concatGuidanceData, temb: temb) 819 | // split 820 | graph = makeGraph(synchonize: shouldSynchronize) 821 | let etas = graph.placeholder(shape: concatEtaData.shape, dataType: concatEtaData.dataType, name: nil) 822 | let eta0 = graph.sliceTensor(etas, dimension: 0, start: 0, length: 1, name: nil) 823 | let eta1 = graph.sliceTensor(etas, dimension: 0, start: 1, length: 1, name: nil) 824 | let etaRes = graph.run(with: commandQueue, feeds: [etas: concatEtaData], targetTensors: [eta0, eta1], targetOperations: nil) 825 | return (etaRes[eta0]!, etaRes[eta1]!) 826 | } 827 | 828 | private func generateLatent(prompt: String, negativePrompt: String, seed: Int, steps: Int, guidanceScale: Float, completion: @escaping (CGImage?, Float, String)->()) -> MPSGraphTensorData { 829 | completion(nil, 0, "Tokenizing...") 830 | 831 | // 1. String -> Tokens 832 | let baseTokens = tokenizer.encode(s: negativePrompt) 833 | let tokens = tokenizer.encode(s: prompt) 834 | completion(nil, 0.25 * 1 / Float(steps), "Encoding...") 835 | 836 | // 2. Tokens -> Embedding 837 | let (baseGuidance, textGuidance) = runTextGuidance(baseTokens: baseTokens, tokens: tokens) 838 | if (saveMemory) { 839 | // MEM-HACK unload the text guidance to fit the unet 840 | textGuidanceExecutable = nil 841 | } 842 | completion(nil, 0.5 * 1 / Float(steps), "Generating noise...") 843 | 844 | // 3. Noise generation 845 | var latent = randomLatent(seed: seed) 846 | let timesteps = Array(stride(from: 1, to: 1000, by: Int(1000 / steps))) 847 | completion(nil, 0.75 * 1 / Float(steps), "Starting diffusion...") 848 | 849 | // 4. Diffusion 850 | for t in (0..<timesteps.count).reversed() { 851 | let tick = CFAbsoluteTimeGetCurrent() 852 | 853 | 854 | let tsPrev = t > 0 ? 
timesteps[t - 1] : timesteps[t] - 1000 / steps 855 | let tData = [Int32(timesteps[t])].withUnsafeBufferPointer {Data(buffer: $0)} 856 | let tMPSData = MPSGraphTensorData(device: graphDevice, data: tData, shape: [1], dataType: MPSDataType.int32) 857 | let tPrevData = [Int32(tsPrev)].withUnsafeBufferPointer {Data(buffer: $0)} 858 | let tPrevMPSData = MPSGraphTensorData(device: graphDevice, data: tPrevData, shape: [1], dataType: MPSDataType.int32) 859 | let guidanceScaleData = [Float32(guidanceScale)].withUnsafeBufferPointer {Data(buffer: $0)} 860 | let guidanceScaleMPSData = MPSGraphTensorData(device: graphDevice, data: guidanceScaleData, shape: [1], dataType: MPSDataType.float32) 861 | let temb = tembGraph.run(with: commandQueue, feeds: [tembTIn: tMPSData], targetTensors: [tembOut], targetOperations: nil)[tembOut]! 862 | let etaUncond: MPSGraphTensorData 863 | let etaCond: MPSGraphTensorData 864 | if (saveMemory) { 865 | // MEM-HACK: un/neg-conditional and text-conditional are run in two separate passes (not batched) to save memory 866 | etaUncond = runUNet(latent: latent, guidance: baseGuidance, temb: temb) 867 | etaCond = runUNet(latent: latent, guidance: textGuidance, temb: temb) 868 | } else { 869 | (etaUncond, etaCond) = runBatchedUNet(latent: latent, baseGuidance: baseGuidance, textGuidance: textGuidance, temb: temb) 870 | } 871 | let res = diffGraph.run(with: commandQueue, feeds: [diffXIn: latent, diffEtaUncondIn: etaUncond, diffEtaCondIn: etaCond, diffTIn: tMPSData, diffTPrevIn: tPrevMPSData, diffGuidanceScaleIn: guidanceScaleMPSData], targetTensors: [diffOut, diffAuxOut], targetOperations: nil) 872 | latent = res[diffOut]! 873 | 874 | // update ui 875 | let tock = CFAbsoluteTimeGetCurrent() 876 | let stepRuntime = String(format:"%.2fs", tock - tick) 877 | let progressDesc = t == 0 ? "Decoding..." : "Step \(timesteps.count - t) / \(timesteps.count) (\(stepRuntime) / step)" 878 | completion(tensorToCGImage(data: res[diffAuxOut]!), Float(timesteps.count - t) / Float(timesteps.count), progressDesc) 879 | } 880 | return latent 881 | } 882 | 883 | public func generate(prompt: String, negativePrompt: String, seed: Int, steps: Int, guidanceScale: Float, completion: @escaping (CGImage?, Float, String)->()) { 884 | let latent = generateLatent(prompt: prompt, negativePrompt: negativePrompt, seed: seed, steps: steps, guidanceScale: guidanceScale, completion: completion) 885 | 886 | if (saveMemory) { 887 | // MEM-HACK: unload the unet to fit the decoder 888 | unetAnUnexpectedJourneyExecutable = nil 889 | unetTheDesolationOfSmaugExecutable = nil 890 | unetTheBattleOfTheFiveArmiesExecutable = nil 891 | } 892 | 893 | // 5. 
Decoder 894 | let decoderRes = loadDecoderAndGetFinalImage(xIn: latent) 895 | completion(tensorToCGImage(data: decoderRes), 1.0, "Cooling down...") 896 | 897 | if (saveMemory) { 898 | // reload the unet and text guidance 899 | initAnUnexpectedJourney() 900 | initTheDesolationOfSmaug() 901 | initTheBattleOfTheFiveArmies() 902 | initTextGuidance() 903 | } 904 | } 905 | } 906 | 907 | func tensorToCGImage(data: MPSGraphTensorData) -> CGImage { 908 | let shape = data.shape.map{$0.intValue} 909 | var imageArrayCPUBytes = [UInt8](repeating: 0, count: shape.reduce(1, *)) 910 | data.mpsndarray().readBytes(&imageArrayCPUBytes, strideBytes: nil) 911 | return CGImage(width: shape[2], height: shape[1], bitsPerComponent: 8, bitsPerPixel: 32, bytesPerRow: shape[2]*shape[3], space: CGColorSpaceCreateDeviceRGB(), bitmapInfo: CGBitmapInfo(rawValue: CGBitmapInfo.byteOrder32Big.rawValue | CGImageAlphaInfo.noneSkipLast.rawValue), provider: CGDataProvider(data: NSData(bytes: &imageArrayCPUBytes, length: imageArrayCPUBytes.count))!, decode: nil, shouldInterpolate: true, intent: CGColorRenderingIntent.defaultIntent)! 912 | } 913 | 914 | -------------------------------------------------------------------------------- /maple-diffusion/Preview Content/Preview Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /maple-diffusion/bins/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madebyollin/maple-diffusion/6304d68a066b8a3d9a2d5faded29be271ea5a55a/maple-diffusion/bins/.gitkeep -------------------------------------------------------------------------------- /maple-diffusion/bins/alphas_cumprod.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madebyollin/maple-diffusion/6304d68a066b8a3d9a2d5faded29be271ea5a55a/maple-diffusion/bins/alphas_cumprod.bin -------------------------------------------------------------------------------- /maple-diffusion/maple_diffusion.entitlements: -------------------------------------------------------------------------------- 1 | <?xml version="1.0" encoding="UTF-8"?> 2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd"> 3 | <plist version="1.0"> 4 | <dict> 5 | <key>com.apple.developer.kernel.increased-memory-limit</key> 6 | <true/> 7 | <key>com.apple.security.app-sandbox</key> 8 | <true/> 9 | <key>com.apple.security.files.user-selected.read-only</key> 10 | <true/> 11 | </dict> 12 | </plist> 13 | -------------------------------------------------------------------------------- /maple-diffusion/maple_diffusionApp.swift: -------------------------------------------------------------------------------- 1 | import SwiftUI 2 | 3 | @main 4 | struct maple_diffusionApp: App { 5 | var body: some Scene { 6 | WindowGroup { 7 | ContentView().frame(minWidth: 192, minHeight: 192).navigationTitle("🍁 Maple Diffusion") 8 | } 9 | } 10 | } 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.23 2 | pytorch-lightning>=1.7 3 | requests>=2.28 4 | torch>=1.12 5 | -------------------------------------------------------------------------------- /screenshot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/madebyollin/maple-diffusion/6304d68a066b8a3d9a2d5faded29be271ea5a55a/screenshot.jpg 
--------------------------------------------------------------------------------