├── MyAppleIntelligence.xcodeproj
│   ├── project.pbxproj
│   ├── project.xcworkspace
│   │   ├── contents.xcworkspacedata
│   │   └── xcshareddata
│   │       └── swiftpm
│   │           └── Package.resolved
│   └── xcuserdata
│       └── stefanblos.xcuserdatad
│           ├── xcdebugger
│           │   └── Breakpoints_v2.xcbkptlist
│           └── xcschemes
│               └── xcschememanagement.plist
├── MyAppleIntelligence
│   ├── Assets.xcassets
│   │   ├── AccentColor.colorset
│   │   │   └── Contents.json
│   │   ├── AppIcon.appiconset
│   │   │   ├── App Icon DALL·E Nov 12.png
│   │   │   └── Contents.json
│   │   ├── Contents.json
│   │   ├── abstract.imageset
│   │   │   ├── Contents.json
│   │   │   └── abstract.jpg
│   │   ├── ai.imageset
│   │   │   ├── Contents.json
│   │   │   └── ai.png
│   │   ├── emin.imageset
│   │   │   ├── Contents.json
│   │   │   └── emin.png
│   │   ├── haikei.imageset
│   │   │   ├── Contents.json
│   │   │   └── haikei.png
│   │   ├── santa.imageset
│   │   │   ├── Contents.json
│   │   │   └── santa.png
│   │   └── stefan.imageset
│   │       ├── Contents.json
│   │       └── stefan.png
│   ├── Components
│   │   ├── MeshGradient+Custom.swift
│   │   └── TextInputView.swift
│   ├── Features
│   │   ├── Image Playground
│   │   │   ├── Helpers
│   │   │   │   ├── CustomError.swift
│   │   │   │   ├── Downloader.swift
│   │   │   │   ├── GenerationContext.swift
│   │   │   │   ├── GenerationState.swift
│   │   │   │   ├── Pipeline.swift
│   │   │   │   ├── PipelineLoader.swift
│   │   │   │   ├── Settings.swift
│   │   │   │   ├── State.swift
│   │   │   │   └── Utils.swift
│   │   │   ├── Model
│   │   │   │   ├── AttentionVariant.swift
│   │   │   │   ├── ImageStyle.swift
│   │   │   │   └── ModelInfo.swift
│   │   │   ├── View
│   │   │   │   ├── BlobShape.swift
│   │   │   │   ├── GradientLabelView.swift
│   │   │   │   ├── ImagePlaygroundView.swift
│   │   │   │   ├── ImageTemplateView.swift
│   │   │   │   ├── ImageWithPlaceholder.swift
│   │   │   │   └── SiriIcon.swift
│   │   │   └── ViewModel
│   │   │       └── ImagePlaygroundViewModel.swift
│   │   └── Writing tools
│   │       ├── Model
│   │       │   ├── OptionCardSize.swift
│   │       │   ├── Shaders
│   │       │   │   ├── Shimmer.metal
│   │       │   │   └── Wave.metal
│   │       │   └── WritingToolOption.swift
│   │       ├── View
│   │       │   ├── OptionCard.swift
│   │       │   ├── SheetButton.swift
│   │       │   ├── WritingToolsInputView.swift
│   │       │   └── WritingToolsView.swift
│   │       └── ViewModel
│   │           └── WritingToolsViewModel.swift
│   ├── Libraries
│   │   ├── Configuration.swift
│   │   ├── Evaluate.swift
│   │   ├── KVCache.swift
│   │   ├── LLMModel.swift
│   │   ├── Llama.swift
│   │   ├── Load.swift
│   │   ├── Lora.swift
│   │   ├── ModelConfiguration.swift
│   │   └── Tokenizer.swift
│   ├── MyAppleIntelligence.entitlements
│   ├── MyAppleIntelligenceApp.swift
│   ├── Navigation
│   │   ├── Model
│   │   │   └── NavigationOption.swift
│   │   ├── View
│   │   │   └── ContentView.swift
│   │   └── ViewModel
│   │       └── NavigationModel.swift
│   └── Preview Content
│       └── Preview Assets.xcassets
│           └── Contents.json
└── README.md

/MyAppleIntelligence.xcodeproj/project.pbxproj:
--------------------------------------------------------------------------------
// !$*UTF8*$!
{
  archiveVersion = 1;
  classes = {
  };
  objectVersion = 77;
  objects = {

/* Begin PBXBuildFile section */
    ACC22A252CDEA4910033EFD2 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = ACC22A242CDEA4910033EFD2 /* MLX */; };
    ACC22A272CDEA4910033EFD2 /* MLXFFT in Frameworks */ = {isa = PBXBuildFile; productRef = ACC22A262CDEA4910033EFD2 /* MLXFFT */; };
    ACC22A292CDEA4910033EFD2 /* MLXFast in Frameworks */ = {isa = PBXBuildFile; productRef = ACC22A282CDEA4910033EFD2 /* MLXFast */; };
    ACC22A2B2CDEA4910033EFD2 /* MLXLinalg in Frameworks */ = {isa = PBXBuildFile; productRef = ACC22A2A2CDEA4910033EFD2 /* MLXLinalg */; };
    ACC22A2D2CDEA4910033EFD2 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = ACC22A2C2CDEA4910033EFD2 /* MLXNN */; };
    ACC22A362CDEA6370033EFD2 /* Transformers in Frameworks */ = {isa = PBXBuildFile; productRef = ACC22A352CDEA6370033EFD2 /* Transformers */; };
    ACC22A462CDEA8A30033EFD2 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = ACC22A452CDEA8A30033EFD2 /* MLXOptimizers */; };
    ACC22A4D2CDEA9D40033EFD2 /* MarkdownUI in Frameworks */ = {isa = PBXBuildFile; productRef = ACC22A4C2CDEA9D40033EFD2 /* MarkdownUI */; };
    ACCD86F72CE103C400A369B1 /* ZIPFoundation in Frameworks */ = {isa = PBXBuildFile; productRef = ACCD86F62CE103C400A369B1 /* ZIPFoundation */; };
    ACCD87092CE109D200A369B1 /* StableDiffusion in Frameworks */ = {isa = PBXBuildFile; productRef = ACCD87082CE109D200A369B1 /* StableDiffusion */; };
/* End PBXBuildFile section */

/* Begin PBXFileReference section */
    ACD01E8B2CDA9A6B005B11E2 /* MyAppleIntelligence.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = MyAppleIntelligence.app; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */

/* Begin PBXFileSystemSynchronizedRootGroup section */
    ACD01E8D2CDA9A6B005B11E2 /* MyAppleIntelligence */ = {
      isa = PBXFileSystemSynchronizedRootGroup;
      path = MyAppleIntelligence;
      sourceTree = "<group>";
    };
/* End PBXFileSystemSynchronizedRootGroup section */

/* Begin PBXFrameworksBuildPhase section */
    ACD01E882CDA9A6B005B11E2 /* Frameworks */ = {
      isa = PBXFrameworksBuildPhase;
      buildActionMask = 2147483647;
      files = (
        ACC22A362CDEA6370033EFD2 /* Transformers in Frameworks */,
        ACC22A292CDEA4910033EFD2 /* MLXFast in Frameworks */,
        ACC22A2B2CDEA4910033EFD2 /* MLXLinalg in Frameworks */,
        ACC22A4D2CDEA9D40033EFD2 /* MarkdownUI in Frameworks */,
        ACC22A252CDEA4910033EFD2 /* MLX in Frameworks */,
        ACC22A462CDEA8A30033EFD2 /* MLXOptimizers in Frameworks */,
        ACC22A2D2CDEA4910033EFD2 /* MLXNN in Frameworks */,
        ACC22A272CDEA4910033EFD2 /* MLXFFT in Frameworks */,
        ACCD86F72CE103C400A369B1 /* ZIPFoundation in Frameworks */,
        ACCD87092CE109D200A369B1 /* StableDiffusion in Frameworks */,
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXFrameworksBuildPhase section */

/* Begin PBXGroup section */
    ACD01E822CDA9A6B005B11E2 = {
      isa = PBXGroup;
      children = (
        ACD01E8D2CDA9A6B005B11E2 /* MyAppleIntelligence */,
        ACD01E8C2CDA9A6B005B11E2 /* Products */,
      );
      sourceTree = "<group>";
    };
    ACD01E8C2CDA9A6B005B11E2 /* Products */ = {
      isa = PBXGroup;
      children = (
        ACD01E8B2CDA9A6B005B11E2 /* MyAppleIntelligence.app */,
      );
      name = Products;
      sourceTree = "<group>";
    };
/* End PBXGroup section */

/* Begin PBXNativeTarget section */
    ACD01E8A2CDA9A6B005B11E2 /* MyAppleIntelligence */ = {
      isa = PBXNativeTarget;
      buildConfigurationList = ACD01E9A2CDA9A6D005B11E2 /* Build configuration list for PBXNativeTarget "MyAppleIntelligence" */;
      buildPhases = (
        ACD01E872CDA9A6B005B11E2 /* Sources */,
        ACD01E882CDA9A6B005B11E2 /* Frameworks */,
        ACD01E892CDA9A6B005B11E2 /* Resources */,
      );
      buildRules = (
      );
      dependencies = (
      );
      fileSystemSynchronizedGroups = (
        ACD01E8D2CDA9A6B005B11E2 /* MyAppleIntelligence */,
      );
      name = MyAppleIntelligence;
      packageProductDependencies = (
        ACC22A242CDEA4910033EFD2 /* MLX */,
        ACC22A262CDEA4910033EFD2 /* MLXFFT */,
        ACC22A282CDEA4910033EFD2 /* MLXFast */,
        ACC22A2A2CDEA4910033EFD2 /* MLXLinalg */,
        ACC22A2C2CDEA4910033EFD2 /* MLXNN */,
        ACC22A352CDEA6370033EFD2 /* Transformers */,
        ACC22A452CDEA8A30033EFD2 /* MLXOptimizers */,
        ACC22A4C2CDEA9D40033EFD2 /* MarkdownUI */,
        ACCD86F62CE103C400A369B1 /* ZIPFoundation */,
        ACCD87082CE109D200A369B1 /* StableDiffusion */,
      );
      productName = MyAppleIntelligence;
      productReference = ACD01E8B2CDA9A6B005B11E2 /* MyAppleIntelligence.app */;
      productType = "com.apple.product-type.application";
    };
/* End PBXNativeTarget section */

/* Begin PBXProject section */
    ACD01E832CDA9A6B005B11E2 /* Project object */ = {
      isa = PBXProject;
      attributes = {
        BuildIndependentTargetsInParallel = 1;
        LastSwiftUpdateCheck = 1610;
        LastUpgradeCheck = 1610;
        TargetAttributes = {
          ACD01E8A2CDA9A6B005B11E2 = {
            CreatedOnToolsVersion = 16.1;
          };
        };
      };
      buildConfigurationList = ACD01E862CDA9A6B005B11E2 /* Build configuration list for PBXProject "MyAppleIntelligence" */;
      developmentRegion = en;
      hasScannedForEncodings = 0;
      knownRegions = (
        en,
        Base,
      );
      mainGroup = ACD01E822CDA9A6B005B11E2;
      minimizedProjectReferenceProxies = 1;
      packageReferences = (
        ACC22A232CDEA4910033EFD2 /* XCRemoteSwiftPackageReference "mlx-swift" */,
        ACC22A342CDEA6370033EFD2 /* XCRemoteSwiftPackageReference "swift-transformers" */,
        ACC22A4B2CDEA9D40033EFD2 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */,
        ACCD86F52CE103C400A369B1 /* XCRemoteSwiftPackageReference "ZIPFoundation" */,
        ACCD87072CE109D200A369B1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */,
      );
      preferredProjectObjectVersion = 77;
      productRefGroup = ACD01E8C2CDA9A6B005B11E2 /* Products */;
      projectDirPath = "";
      projectRoot = "";
      targets = (
        ACD01E8A2CDA9A6B005B11E2 /* MyAppleIntelligence */,
      );
    };
/* End PBXProject section */

/* Begin PBXResourcesBuildPhase section */
    ACD01E892CDA9A6B005B11E2 /* Resources */ = {
      isa = PBXResourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXResourcesBuildPhase section */

/* Begin PBXSourcesBuildPhase section */
    ACD01E872CDA9A6B005B11E2 /* Sources */ = {
      isa = PBXSourcesBuildPhase;
      buildActionMask = 2147483647;
      files = (
      );
      runOnlyForDeploymentPostprocessing = 0;
    };
/* End PBXSourcesBuildPhase section */

/* Begin XCBuildConfiguration section */
    ACD01E982CDA9A6D005B11E2 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ALWAYS_SEARCH_USER_PATHS = NO;
        ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
        CLANG_ANALYZER_NONNULL = YES;
        CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
        CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
        CLANG_ENABLE_MODULES = YES;
        CLANG_ENABLE_OBJC_ARC = YES;
        CLANG_ENABLE_OBJC_WEAK = YES;
        CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
        CLANG_WARN_BOOL_CONVERSION = YES;
        CLANG_WARN_COMMA = YES;
        CLANG_WARN_CONSTANT_CONVERSION = YES;
        CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
        CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
        CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
        CLANG_WARN_EMPTY_BODY = YES;
        CLANG_WARN_ENUM_CONVERSION = YES;
        CLANG_WARN_INFINITE_RECURSION = YES;
        CLANG_WARN_INT_CONVERSION = YES;
        CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
        CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
        CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
        CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
        CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
        CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
        CLANG_WARN_STRICT_PROTOTYPES = YES;
        CLANG_WARN_SUSPICIOUS_MOVE = YES;
        CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
        CLANG_WARN_UNREACHABLE_CODE = YES;
        CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
        COPY_PHASE_STRIP = NO;
        DEBUG_INFORMATION_FORMAT = dwarf;
        ENABLE_STRICT_OBJC_MSGSEND = YES;
        ENABLE_TESTABILITY = YES;
        ENABLE_USER_SCRIPT_SANDBOXING = YES;
        GCC_C_LANGUAGE_STANDARD = gnu17;
        GCC_DYNAMIC_NO_PIC = NO;
        GCC_NO_COMMON_BLOCKS = YES;
        GCC_OPTIMIZATION_LEVEL = 0;
        GCC_PREPROCESSOR_DEFINITIONS = (
          "DEBUG=1",
          "$(inherited)",
        );
        GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
        GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
        GCC_WARN_UNDECLARED_SELECTOR = YES;
        GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
        GCC_WARN_UNUSED_FUNCTION = YES;
        GCC_WARN_UNUSED_VARIABLE = YES;
        LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
        MTL_ENABLE_DEBUG_INFO = INCLUDE_SOURCE;
        MTL_FAST_MATH = YES;
        ONLY_ACTIVE_ARCH = YES;
        SWIFT_ACTIVE_COMPILATION_CONDITIONS = "DEBUG $(inherited)";
        SWIFT_OPTIMIZATION_LEVEL = "-Onone";
      };
      name = Debug;
    };
    ACD01E992CDA9A6D005B11E2 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ALWAYS_SEARCH_USER_PATHS = NO;
        ASSETCATALOG_COMPILER_GENERATE_SWIFT_ASSET_SYMBOL_EXTENSIONS = YES;
        CLANG_ANALYZER_NONNULL = YES;
        CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
        CLANG_CXX_LANGUAGE_STANDARD = "gnu++20";
        CLANG_ENABLE_MODULES = YES;
        CLANG_ENABLE_OBJC_ARC = YES;
        CLANG_ENABLE_OBJC_WEAK = YES;
        CLANG_WARN_BLOCK_CAPTURE_AUTORELEASING = YES;
        CLANG_WARN_BOOL_CONVERSION = YES;
        CLANG_WARN_COMMA = YES;
        CLANG_WARN_CONSTANT_CONVERSION = YES;
        CLANG_WARN_DEPRECATED_OBJC_IMPLEMENTATIONS = YES;
        CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
        CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
        CLANG_WARN_EMPTY_BODY = YES;
        CLANG_WARN_ENUM_CONVERSION = YES;
        CLANG_WARN_INFINITE_RECURSION = YES;
        CLANG_WARN_INT_CONVERSION = YES;
        CLANG_WARN_NON_LITERAL_NULL_CONVERSION = YES;
        CLANG_WARN_OBJC_IMPLICIT_RETAIN_SELF = YES;
        CLANG_WARN_OBJC_LITERAL_CONVERSION = YES;
        CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
        CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES;
        CLANG_WARN_RANGE_LOOP_ANALYSIS = YES;
        CLANG_WARN_STRICT_PROTOTYPES = YES;
        CLANG_WARN_SUSPICIOUS_MOVE = YES;
        CLANG_WARN_UNGUARDED_AVAILABILITY = YES_AGGRESSIVE;
        CLANG_WARN_UNREACHABLE_CODE = YES;
        CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
        COPY_PHASE_STRIP = NO;
        DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
        ENABLE_NS_ASSERTIONS = NO;
        ENABLE_STRICT_OBJC_MSGSEND = YES;
        ENABLE_USER_SCRIPT_SANDBOXING = YES;
        GCC_C_LANGUAGE_STANDARD = gnu17;
        GCC_NO_COMMON_BLOCKS = YES;
        GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
        GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
        GCC_WARN_UNDECLARED_SELECTOR = YES;
        GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
        GCC_WARN_UNUSED_FUNCTION = YES;
        GCC_WARN_UNUSED_VARIABLE = YES;
        LOCALIZATION_PREFERS_STRING_CATALOGS = YES;
        MTL_ENABLE_DEBUG_INFO = NO;
        MTL_FAST_MATH = YES;
        SWIFT_COMPILATION_MODE = wholemodule;
      };
      name = Release;
    };
    ACD01E9B2CDA9A6D005B11E2 /* Debug */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
        ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
        CODE_SIGN_ENTITLEMENTS = MyAppleIntelligence/MyAppleIntelligence.entitlements;
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        DEVELOPMENT_ASSET_PATHS = "\"MyAppleIntelligence/Preview Content\"";
        DEVELOPMENT_TEAM = JY49624W8T;
        ENABLE_HARDENED_RUNTIME = YES;
        ENABLE_PREVIEWS = YES;
        GENERATE_INFOPLIST_FILE = YES;
        "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
        "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
        "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
        "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
        "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
        "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
        "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
        "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        IPHONEOS_DEPLOYMENT_TARGET = 18.1;
        LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
        "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
        MACOSX_DEPLOYMENT_TARGET = 15.1;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = com.stefanblos.MyAppleIntelligence;
        PRODUCT_NAME = "$(TARGET_NAME)";
        SDKROOT = auto;
        SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
        SWIFT_EMIT_LOC_STRINGS = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2,7";
        XROS_DEPLOYMENT_TARGET = 2.1;
      };
      name = Debug;
    };
    ACD01E9C2CDA9A6D005B11E2 /* Release */ = {
      isa = XCBuildConfiguration;
      buildSettings = {
        ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
        ASSETCATALOG_COMPILER_GLOBAL_ACCENT_COLOR_NAME = AccentColor;
        CODE_SIGN_ENTITLEMENTS = MyAppleIntelligence/MyAppleIntelligence.entitlements;
        CODE_SIGN_STYLE = Automatic;
        CURRENT_PROJECT_VERSION = 1;
        DEVELOPMENT_ASSET_PATHS = "\"MyAppleIntelligence/Preview Content\"";
        DEVELOPMENT_TEAM = JY49624W8T;
        ENABLE_HARDENED_RUNTIME = YES;
        ENABLE_PREVIEWS = YES;
        GENERATE_INFOPLIST_FILE = YES;
        "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphoneos*]" = YES;
        "INFOPLIST_KEY_UIApplicationSceneManifest_Generation[sdk=iphonesimulator*]" = YES;
        "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphoneos*]" = YES;
        "INFOPLIST_KEY_UIApplicationSupportsIndirectInputEvents[sdk=iphonesimulator*]" = YES;
        "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphoneos*]" = YES;
        "INFOPLIST_KEY_UILaunchScreen_Generation[sdk=iphonesimulator*]" = YES;
        "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphoneos*]" = UIStatusBarStyleDefault;
        "INFOPLIST_KEY_UIStatusBarStyle[sdk=iphonesimulator*]" = UIStatusBarStyleDefault;
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPad = "UIInterfaceOrientationPortrait UIInterfaceOrientationPortraitUpsideDown UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        INFOPLIST_KEY_UISupportedInterfaceOrientations_iPhone = "UIInterfaceOrientationPortrait UIInterfaceOrientationLandscapeLeft UIInterfaceOrientationLandscapeRight";
        IPHONEOS_DEPLOYMENT_TARGET = 18.1;
        LD_RUNPATH_SEARCH_PATHS = "@executable_path/Frameworks";
        "LD_RUNPATH_SEARCH_PATHS[sdk=macosx*]" = "@executable_path/../Frameworks";
        MACOSX_DEPLOYMENT_TARGET = 15.1;
        MARKETING_VERSION = 1.0;
        PRODUCT_BUNDLE_IDENTIFIER = com.stefanblos.MyAppleIntelligence;
        PRODUCT_NAME = "$(TARGET_NAME)";
        SDKROOT = auto;
        SUPPORTED_PLATFORMS = "iphoneos iphonesimulator macosx xros xrsimulator";
        SWIFT_EMIT_LOC_STRINGS = YES;
        SWIFT_VERSION = 5.0;
        TARGETED_DEVICE_FAMILY = "1,2,7";
        XROS_DEPLOYMENT_TARGET = 2.1;
      };
      name = Release;
    };
/* End XCBuildConfiguration section */

/* Begin XCConfigurationList section */
    ACD01E862CDA9A6B005B11E2 /* Build configuration list for PBXProject "MyAppleIntelligence" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        ACD01E982CDA9A6D005B11E2 /* Debug */,
        ACD01E992CDA9A6D005B11E2 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
    ACD01E9A2CDA9A6D005B11E2 /* Build configuration list for PBXNativeTarget "MyAppleIntelligence" */ = {
      isa = XCConfigurationList;
      buildConfigurations = (
        ACD01E9B2CDA9A6D005B11E2 /* Debug */,
        ACD01E9C2CDA9A6D005B11E2 /* Release */,
      );
      defaultConfigurationIsVisible = 0;
      defaultConfigurationName = Release;
    };
/* End XCConfigurationList section */

/* Begin XCRemoteSwiftPackageReference section */
    ACC22A232CDEA4910033EFD2 /* XCRemoteSwiftPackageReference "mlx-swift" */ = {
      isa = XCRemoteSwiftPackageReference;
      repositoryURL = "https://github.com/ml-explore/mlx-swift";
      requirement = {
        kind = upToNextMajorVersion;
        minimumVersion = 0.18.1;
      };
    };
    ACC22A342CDEA6370033EFD2 /* XCRemoteSwiftPackageReference "swift-transformers" */ = {
      isa = XCRemoteSwiftPackageReference;
      repositoryURL = "https://github.com/huggingface/swift-transformers";
      requirement = {
        kind = upToNextMajorVersion;
        minimumVersion = 0.1.14;
      };
    };
    ACC22A4B2CDEA9D40033EFD2 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */ = {
      isa = XCRemoteSwiftPackageReference;
      repositoryURL = "https://github.com/gonzalezreal/swift-markdown-ui";
      requirement = {
        kind = upToNextMajorVersion;
        minimumVersion = 2.4.1;
      };
    };
    ACCD86F52CE103C400A369B1 /* XCRemoteSwiftPackageReference "ZIPFoundation" */ = {
      isa = XCRemoteSwiftPackageReference;
      repositoryURL = "https://github.com/weichsel/ZIPFoundation";
      requirement = {
        kind = upToNextMajorVersion;
        minimumVersion = 0.9.19;
      };
    };
    ACCD87072CE109D200A369B1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */ = {
      isa = XCRemoteSwiftPackageReference;
      repositoryURL = "https://github.com/apple/ml-stable-diffusion";
      requirement = {
        kind = upToNextMajorVersion;
        minimumVersion = 1.1.1;
      };
    };
/* End XCRemoteSwiftPackageReference section */

/* Begin XCSwiftPackageProductDependency section */
    ACC22A242CDEA4910033EFD2 /* MLX */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACC22A232CDEA4910033EFD2 /* XCRemoteSwiftPackageReference "mlx-swift" */;
      productName = MLX;
    };
    ACC22A262CDEA4910033EFD2 /* MLXFFT */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACC22A232CDEA4910033EFD2 /* XCRemoteSwiftPackageReference "mlx-swift" */;
      productName = MLXFFT;
    };
    ACC22A282CDEA4910033EFD2 /* MLXFast */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACC22A232CDEA4910033EFD2 /* XCRemoteSwiftPackageReference "mlx-swift" */;
      productName = MLXFast;
    };
    ACC22A2A2CDEA4910033EFD2 /* MLXLinalg */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACC22A232CDEA4910033EFD2 /* XCRemoteSwiftPackageReference "mlx-swift" */;
      productName = MLXLinalg;
    };
    ACC22A2C2CDEA4910033EFD2 /* MLXNN */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACC22A232CDEA4910033EFD2 /* XCRemoteSwiftPackageReference "mlx-swift" */;
      productName = MLXNN;
    };
    ACC22A352CDEA6370033EFD2 /* Transformers */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACC22A342CDEA6370033EFD2 /* XCRemoteSwiftPackageReference "swift-transformers" */;
      productName = Transformers;
    };
    ACC22A452CDEA8A30033EFD2 /* MLXOptimizers */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACC22A232CDEA4910033EFD2 /* XCRemoteSwiftPackageReference "mlx-swift" */;
      productName = MLXOptimizers;
    };
    ACC22A4C2CDEA9D40033EFD2 /* MarkdownUI */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACC22A4B2CDEA9D40033EFD2 /* XCRemoteSwiftPackageReference "swift-markdown-ui" */;
      productName = MarkdownUI;
    };
    ACCD86F62CE103C400A369B1 /* ZIPFoundation */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACCD86F52CE103C400A369B1 /* XCRemoteSwiftPackageReference "ZIPFoundation" */;
      productName = ZIPFoundation;
    };
    ACCD87082CE109D200A369B1 /* StableDiffusion */ = {
      isa = XCSwiftPackageProductDependency;
      package = ACCD87072CE109D200A369B1 /* XCRemoteSwiftPackageReference "ml-stable-diffusion" */;
      productName = StableDiffusion;
    };
/* End XCSwiftPackageProductDependency section */
  };
  rootObject = ACD01E832CDA9A6B005B11E2 /* Project object */;
}

--------------------------------------------------------------------------------
/MyAppleIntelligence.xcodeproj/project.xcworkspace/contents.xcworkspacedata:
--------------------------------------------------------------------------------
<!-- reconstructed: the XML markup of this file was stripped in the source dump; this is the standard Xcode workspace data for a project workspace -->
<?xml version="1.0" encoding="UTF-8"?>
<Workspace
   version = "1.0">
   <FileRef
      location = "self:">
   </FileRef>
</Workspace>

--------------------------------------------------------------------------------
/MyAppleIntelligence.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved:
--------------------------------------------------------------------------------
{
  "originHash" : "d1477efafbe24749de44196137834ce7ce64a6a2afa2155e7163eaf4cb83edd9",
  "pins" : [
    {
      "identity" : "jinja",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/maiqingqiang/Jinja",
      "state" : {
        "revision" : "6dbe4c449469fb586d0f7339f900f0dd4d78b167",
        "version" : "1.0.6"
      }
    },
    {
      "identity" : "ml-stable-diffusion",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/apple/ml-stable-diffusion",
      "state" : {
        "revision" : "5a170d29cf38e674b80541d7ce22929c6a11cdde",
        "version" : "1.1.1"
      }
    },
    {
      "identity" : "mlx-swift",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/ml-explore/mlx-swift",
      "state" : {
        "revision" : "d649c62b77c487c25012910b0d02b30283d388ca",
        "version" : "0.18.1"
      }
    },
    {
      "identity" : "networkimage",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/gonzalezreal/NetworkImage",
      "state" : {
        "revision" : "2849f5323265386e200484b0d0f896e73c3411b9",
        "version" : "6.0.1"
      }
    },
    {
      "identity" : "swift-argument-parser",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/apple/swift-argument-parser.git",
      "state" : {
        "revision" : "41982a3656a71c768319979febd796c6fd111d5c",
        "version" : "1.5.0"
      }
    },
    {
      "identity" : "swift-cmark",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/swiftlang/swift-cmark",
      "state" : {
        "revision" : "3ccff77b2dc5b96b77db3da0d68d28068593fa53",
        "version" : "0.5.0"
      }
    },
    {
      "identity" : "swift-markdown-ui",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/gonzalezreal/swift-markdown-ui",
      "state" : {
        "revision" : "5f613358148239d0292c0cef674a3c2314737f9e",
        "version" : "2.4.1"
      }
    },
    {
      "identity" : "swift-numerics",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/apple/swift-numerics",
      "state" : {
        "revision" : "0a5bc04095a675662cf24757cc0640aa2204253b",
        "version" : "1.0.2"
      }
    },
    {
      "identity" : "swift-transformers",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/huggingface/swift-transformers",
      "state" : {
        "revision" : "d42fdae473c49ea216671da8caae58e102d28709",
        "version" : "0.1.14"
      }
    },
    {
      "identity" : "zipfoundation",
      "kind" : "remoteSourceControl",
      "location" : "https://github.com/weichsel/ZIPFoundation",
      "state" : {
        "revision" : "02b6abe5f6eef7e3cbd5f247c5cc24e246efcfe0",
        "version" : "0.9.19"
      }
    }
  ],
  "version" : 3
}

--------------------------------------------------------------------------------
/MyAppleIntelligence.xcodeproj/xcuserdata/stefanblos.xcuserdatad/xcdebugger/Breakpoints_v2.xcbkptlist:
--------------------------------------------------------------------------------
<!-- contents not recoverable: the XML markup of this user breakpoint list (bucket and breakpoint proxies) was stripped in the source dump -->

--------------------------------------------------------------------------------
/MyAppleIntelligence.xcodeproj/xcuserdata/stefanblos.xcuserdatad/xcschemes/xcschememanagement.plist:
--------------------------------------------------------------------------------
<!-- reconstructed: the XML markup of this file was stripped in the source dump; the keys and values below are the ones that survived -->
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
  <key>SchemeUserState</key>
  <dict>
    <key>MyAppleIntelligence.xcscheme_^#shared#^_</key>
    <dict>
      <key>orderHint</key>
      <integer>0</integer>
    </dict>
  </dict>
</dict>
</plist>

--------------------------------------------------------------------------------
/MyAppleIntelligence/Assets.xcassets/AccentColor.colorset/Contents.json:
--------------------------------------------------------------------------------
{
  "colors" : [
    {
      "idiom" : "universal"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Assets.xcassets/AppIcon.appiconset/App Icon DALL·E Nov 12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DaemonLoki/MyAppleIntelligence/164b025abba0ff70a0bb4c6c9abf846df9938399/MyAppleIntelligence/Assets.xcassets/AppIcon.appiconset/App Icon DALL·E Nov 12.png

--------------------------------------------------------------------------------
/MyAppleIntelligence/Assets.xcassets/AppIcon.appiconset/Contents.json:
--------------------------------------------------------------------------------
{
  "images" : [
    {
      "filename" : "App Icon DALL·E Nov 12.png",
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "appearances" : [
        {
          "appearance" : "luminosity",
          "value" : "dark"
        }
      ],
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "appearances" : [
        {
          "appearance" : "luminosity",
          "value" : "tinted"
        }
      ],
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "16x16"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "16x16"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "32x32"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "32x32"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "128x128"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "128x128"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "256x256"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "256x256"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "512x512"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "512x512"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
{
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Assets.xcassets/abstract.imageset/Contents.json:
--------------------------------------------------------------------------------
"images" : [ 3 | { 4 | "filename" : "abstract.jpg", 5 | "idiom" : "universal", 6 | "scale" : "1x" 7 | }, 8 | { 9 | "idiom" : "universal", 10 | "scale" : "2x" 11 | }, 12 | { 13 | "idiom" : "universal", 14 | "scale" : "3x" 15 | } 16 | ], 17 | "info" : { 18 | "author" : "xcode", 19 | "version" : 1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/abstract.imageset/abstract.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaemonLoki/MyAppleIntelligence/164b025abba0ff70a0bb4c6c9abf846df9938399/MyAppleIntelligence/Assets.xcassets/abstract.imageset/abstract.jpg -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/ai.imageset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "filename" : "ai.png", 5 | "idiom" : "universal", 6 | "scale" : "1x" 7 | }, 8 | { 9 | "idiom" : "universal", 10 | "scale" : "2x" 11 | }, 12 | { 13 | "idiom" : "universal", 14 | "scale" : "3x" 15 | } 16 | ], 17 | "info" : { 18 | "author" : "xcode", 19 | "version" : 1 20 | }, 21 | "properties" : { 22 | "template-rendering-intent" : "template" 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/ai.imageset/ai.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaemonLoki/MyAppleIntelligence/164b025abba0ff70a0bb4c6c9abf846df9938399/MyAppleIntelligence/Assets.xcassets/ai.imageset/ai.png -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/emin.imageset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "filename" : "emin.png", 5 | "idiom" : "universal", 6 | "scale" : "1x" 7 | }, 8 | { 9 | "idiom" : "universal", 10 | "scale" : "2x" 11 | }, 12 | { 13 | "idiom" : "universal", 14 | "scale" : "3x" 15 | } 16 | ], 17 | "info" : { 18 | "author" : "xcode", 19 | "version" : 1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/emin.imageset/emin.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaemonLoki/MyAppleIntelligence/164b025abba0ff70a0bb4c6c9abf846df9938399/MyAppleIntelligence/Assets.xcassets/emin.imageset/emin.png -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/haikei.imageset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "filename" : "haikei.png", 5 | "idiom" : "universal", 6 | "scale" : "1x" 7 | }, 8 | { 9 | "idiom" : "universal", 10 | "scale" : "2x" 11 | }, 12 | { 13 | "idiom" : "universal", 14 | "scale" : "3x" 15 | } 16 | ], 17 | "info" : { 18 | "author" : "xcode", 19 | "version" : 1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/haikei.imageset/haikei.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DaemonLoki/MyAppleIntelligence/164b025abba0ff70a0bb4c6c9abf846df9938399/MyAppleIntelligence/Assets.xcassets/haikei.imageset/haikei.png -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/santa.imageset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "filename" : "santa.png", 5 | "idiom" : "universal", 6 | "scale" : "1x" 7 | }, 8 | { 9 | "idiom" : "universal", 10 | "scale" : "2x" 11 | }, 12 | { 13 | "idiom" : "universal", 14 | "scale" : "3x" 15 | } 16 | ], 17 | "info" : { 18 | "author" : "xcode", 19 | "version" : 1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/santa.imageset/santa.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaemonLoki/MyAppleIntelligence/164b025abba0ff70a0bb4c6c9abf846df9938399/MyAppleIntelligence/Assets.xcassets/santa.imageset/santa.png -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/stefan.imageset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "filename" : "stefan.png", 5 | "idiom" : "universal", 6 | "scale" : "1x" 7 | }, 8 | { 9 | "idiom" : "universal", 10 | "scale" : "2x" 11 | }, 12 | { 13 | "idiom" : "universal", 14 | "scale" : "3x" 15 | } 16 | ], 17 | "info" : { 18 | "author" : "xcode", 19 | "version" : 1 20 | } 21 | } 22 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Assets.xcassets/stefan.imageset/stefan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaemonLoki/MyAppleIntelligence/164b025abba0ff70a0bb4c6c9abf846df9938399/MyAppleIntelligence/Assets.xcassets/stefan.imageset/stefan.png -------------------------------------------------------------------------------- /MyAppleIntelligence/Components/MeshGradient+Custom.swift: -------------------------------------------------------------------------------- 1 | // 2 | // MeshGradient+Custom.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 11.11.24. 6 | // 7 | 8 | import SwiftUI 9 | 10 | extension MeshGradient { 11 | 12 | static var custom: MeshGradient { 13 | return MeshGradient( 14 | width: 3, 15 | height: 2, 16 | points: [ 17 | [0, 0], [0.5, 0], [1, 0], 18 | [0, 1], [0.5, 1], [1, 1] 19 | ], 20 | colors: [ 21 | .yellow, .orange, .red, 22 | .blue, .indigo, .purple 23 | 24 | ] 25 | ) 26 | } 27 | 28 | } 29 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Components/TextInputView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // TextInputView.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 07.11.24. 
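A minimal usage sketch (not part of the repository): MeshGradient is a ShapeStyle, so the custom gradient defined above can fill any shape, for example as a full-bleed background (MeshGradient requires iOS 18 / macOS 15).

    struct GradientBackground: View {
        var body: some View {
            Rectangle()
                .fill(MeshGradient.custom) // the 3×2 mesh defined above
                .ignoresSafeArea()
        }
    }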
--------------------------------------------------------------------------------
/MyAppleIntelligence/Components/TextInputView.swift:
--------------------------------------------------------------------------------
//
//  TextInputView.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 07.11.24.
//

import SwiftUI

struct TextInputView: View {

    @Binding var text: String
    var placeholder: String = "Describe an image"
    var showButton: Bool = true
    var sendAction: () -> Void

    var body: some View {
        HStack(spacing: 12) {
            SiriIcon()
                .frame(width: 32, height: 32)

            TextField(text: $text, prompt: Text(placeholder)) {
                EmptyView()
            }
            .frame(maxWidth: .infinity)

            if showButton {
                Button {
                    sendAction()
                } label: {
                    Image(systemName: "arrow.up")
                }
                .buttonStyle(.bordered)
                .buttonBorderShape(.circle)
                .tint(.gray)
            }
        }
        .frame(maxWidth: .infinity, maxHeight: 50)
        .padding(.vertical, 2)
        .padding(.horizontal, 12)
        .background(Color(uiColor: .secondarySystemBackground), in: Capsule())
    }
}

#Preview {
    @Previewable @State var text: String = "test"
    TextInputView(text: $text) {
        print("Send")
    }
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Features/Image Playground/Helpers/CustomError.swift:
--------------------------------------------------------------------------------
//
//  CustomError.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 12.11.24.
//

import Foundation

struct CustomError: Error {
    var message: String
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Features/Image Playground/Helpers/Downloader.swift:
--------------------------------------------------------------------------------
//
//  Downloader.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 10.11.24.
//

import Foundation
import Combine

class Downloader: NSObject, ObservableObject {
    private(set) var destination: URL

    enum DownloadState {
        case notStarted
        case downloading(Double)
        case completed(URL)
        case failed(CustomError?)
    }

    private(set) lazy var downloadState: CurrentValueSubject<DownloadState, Never> = CurrentValueSubject(.notStarted)
    private var stateSubscriber: Cancellable?

    private var urlSession: URLSession? = nil

    init(from url: URL, to destination: URL, using authToken: String? = nil) {
        self.destination = destination
        super.init()

        var config = URLSessionConfiguration.default
        #if !os(macOS)
        // .background allows downloads to proceed in the background
        // helpful for devices that may not keep the app in the foreground for the download duration
        config = URLSessionConfiguration.background(withIdentifier: "net.pcuenca.diffusion.download")
        config.isDiscretionary = false
        config.sessionSendsLaunchEvents = true
        #endif
        urlSession = URLSession(configuration: config, delegate: self, delegateQueue: OperationQueue())
        downloadState.value = .downloading(0)
        urlSession?.getAllTasks { tasks in
            // If there's an existing pending background task with the same URL, let it proceed.
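            // getAllTasks is asynchronous: the download task below is only created once the existing
            // tasks have been inspected, so an in-flight download of the same URL is reused rather than restarted.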
            guard tasks.filter({ $0.originalRequest?.url == url }).isEmpty else {
                print("Already downloading \(url)")
                return
            }
            print("Starting download of \(url)")

            var request = URLRequest(url: url)
            if let authToken = authToken {
                request.setValue("Bearer \(authToken)", forHTTPHeaderField: "Authorization")
            }

            self.urlSession?.downloadTask(with: request).resume()
        }
    }

    @discardableResult
    func waitUntilDone() throws -> URL {
        // It's either this, or stream the bytes ourselves (add to a buffer, save to disk, etc; boring and finicky)
        let semaphore = DispatchSemaphore(value: 0)
        stateSubscriber = downloadState.sink { state in
            switch state {
            case .completed: semaphore.signal()
            case .failed: semaphore.signal()
            default: break
            }
        }
        semaphore.wait()

        switch downloadState.value {
        case .completed(let url): return url
        case .failed(let error): throw error ?? CustomError(message: "Unknown error")
        default: throw CustomError(message: "Should never happen, lol")
        }
    }

    func cancel() {
        urlSession?.invalidateAndCancel()
    }
}

extension Downloader: URLSessionDelegate, URLSessionDownloadDelegate {
    func urlSession(_: URLSession, downloadTask: URLSessionDownloadTask, didWriteData _: Int64, totalBytesWritten: Int64, totalBytesExpectedToWrite: Int64) {
        downloadState.value = .downloading(Double(totalBytesWritten) / Double(totalBytesExpectedToWrite))
    }

    func urlSession(_: URLSession, downloadTask _: URLSessionDownloadTask, didFinishDownloadingTo location: URL) {
        guard FileManager.default.fileExists(atPath: location.path) else {
            downloadState.value = .failed(CustomError(message: "Invalid download location received: \(location)"))
            return
        }
        do {
            try FileManager.default.moveItem(at: location, to: destination)
            downloadState.value = .completed(destination)
        } catch {
            downloadState.value = .failed(error as? CustomError)
        }
    }

    func urlSession(_ session: URLSession, task: URLSessionTask, didCompleteWithError error: Error?) {
        if let error = error {
            downloadState.value = .failed(error as? CustomError)
        } else if let response = task.response as? HTTPURLResponse {
            print("HTTP response status code: \(response.statusCode)")
        }
    }
}
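A hedged usage sketch for Downloader (not part of the repository; PipelineLoader.download() below is the real call site): construct it with a source and a destination URL, then block until the transfer finishes.

    // Sketch only: assumes a reachable URL and a writable destination.
    let source = URL(string: "https://example.com/model.zip")!
    let destination = FileManager.default.temporaryDirectory.appendingPathComponent("model.zip")
    let downloader = Downloader(from: source, to: destination)
    // waitUntilDone() parks the current thread on a semaphore, so call it off the main thread.
    let fileURL = try downloader.waitUntilDone()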
--------------------------------------------------------------------------------
/MyAppleIntelligence/Features/Image Playground/Helpers/GenerationContext.swift:
--------------------------------------------------------------------------------
//
//  GenerationContext.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 10.11.24.
//

import SwiftUI
import Combine

class GenerationContext: ObservableObject {
    let scheduler = StableDiffusionScheduler.dpmSolverMultistepScheduler

    @Published var pipeline: Pipeline? = nil {
        didSet {
            if let pipeline = pipeline {
                progressSubscriber = pipeline
                    .progressPublisher
                    .receive(on: DispatchQueue.main)
                    .sink { progress in
                        guard let progress = progress else { return }
                        self.updatePreviewIfNeeded(progress)
                        self.state = .running(progress)
                    }
            }
        }
    }
    @Published var state: GenerationState = .startup

    @Published var positivePrompt = Settings.shared.prompt
    @Published var negativePrompt = Settings.shared.negativePrompt

    // FIXME: Double to support the slider component
    @Published var steps: Double = Settings.shared.stepCount
    @Published var numImages: Double = 1.0
    @Published var seed: UInt32 = Settings.shared.seed
    @Published var guidanceScale: Double = Settings.shared.guidanceScale
    @Published var previews: Double = 0.0
    @Published var disableSafety = false
    @Published var previewImage: CGImage? = nil

    @Published var computeUnits: ComputeUnits = Settings.shared.userSelectedComputeUnits ?? ModelInfo.defaultComputeUnits

    private var progressSubscriber: Cancellable?

    private func updatePreviewIfNeeded(_ progress: StableDiffusionProgress) {
        if previews == 0 || progress.step == 0 {
            previewImage = nil
        }

        if previews > 0, let newImage = progress.currentImages.first, newImage != nil {
            previewImage = newImage
        }
    }

    func generate(prompt: String) async throws -> GenerationResult {
        guard let pipeline = pipeline else { throw CustomError(message: "No pipeline") }
        return try pipeline.generate(
            prompt: prompt,
            negativePrompt: negativePrompt,
            scheduler: scheduler,
            numInferenceSteps: Int(steps),
            seed: seed,
            numPreviews: Int(previews),
            guidanceScale: Float(guidanceScale),
            disableSafety: disableSafety
        )
    }

    func cancelGeneration() {
        pipeline?.setCancelled()
    }
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Features/Image Playground/Helpers/GenerationState.swift:
--------------------------------------------------------------------------------
//
//  GenerationState.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 10.11.24.
//

import StableDiffusion
import SwiftUI

enum GenerationState {
    case startup
    case running(StableDiffusionProgress?)
    case complete(String, CGImage?, UInt32, TimeInterval?)
    case userCanceled
    case failed(Error)
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Features/Image Playground/Helpers/Pipeline.swift:
--------------------------------------------------------------------------------
//
//  Pipeline.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 10.11.24.
//

import Foundation
import CoreML
import Combine

import StableDiffusion

struct StableDiffusionProgress {
    var progress: StableDiffusionPipeline.Progress

    var step: Int { progress.step }
    var stepCount: Int { progress.stepCount }

    var currentImages: [CGImage?]

    init(progress: StableDiffusionPipeline.Progress, previewIndices: [Bool]) {
        self.progress = progress
        self.currentImages = [nil]

        // Since currentImages is a computed property, only access the preview image if necessary
        if progress.step < previewIndices.count, previewIndices[progress.step] {
            self.currentImages = progress.currentImages
        }
    }
}

struct GenerationResult {
    var image: CGImage?
    var lastSeed: UInt32
    var interval: TimeInterval?
    var userCanceled: Bool
    var itsPerSecond: Double?
}

class Pipeline {
    let pipeline: StableDiffusionPipelineProtocol
    let maxSeed: UInt32

    var isXL: Bool {
        if #available(macOS 14.0, iOS 17.0, *) {
            return (pipeline as? StableDiffusionXLPipeline) != nil
        }
        return false
    }

    var progress: StableDiffusionProgress? = nil {
        didSet {
            progressPublisher.value = progress
        }
    }
    lazy private(set) var progressPublisher: CurrentValueSubject<StableDiffusionProgress?, Never> = CurrentValueSubject(progress)

    private var canceled = false

    init(_ pipeline: StableDiffusionPipelineProtocol, maxSeed: UInt32 = UInt32.max) {
        self.pipeline = pipeline
        self.maxSeed = maxSeed
    }

    func generate(
        prompt: String,
        negativePrompt: String = "",
        scheduler: StableDiffusionScheduler,
        numInferenceSteps stepCount: Int = 50,
        seed: UInt32 = 0,
        numPreviews previewCount: Int = 5,
        guidanceScale: Float = 7.5,
        disableSafety: Bool = false
    ) throws -> GenerationResult {
        let beginDate = Date()
        canceled = false
        let theSeed = seed > 0 ? seed : UInt32.random(in: 1...maxSeed)
        let sampleTimer = SampleTimer()
        sampleTimer.start()

        var config = StableDiffusionPipeline.Configuration(prompt: prompt)
        config.negativePrompt = negativePrompt
        config.stepCount = stepCount
        config.seed = theSeed
        config.guidanceScale = guidanceScale
        config.disableSafety = disableSafety
        config.schedulerType = scheduler.asStableDiffusionScheduler()
        config.useDenoisedIntermediates = true
        if isXL {
            config.encoderScaleFactor = 0.13025
            config.decoderScaleFactor = 0.13025
            config.schedulerTimestepSpacing = .karras
        }

        // Evenly distribute previews based on inference steps
        let previewIndices = previewIndices(stepCount, previewCount)

        let images = try pipeline.generateImages(configuration: config) { progress in
            sampleTimer.stop()
            handleProgress(StableDiffusionProgress(progress: progress,
                                                   previewIndices: previewIndices),
                           sampleTimer: sampleTimer)
            if progress.stepCount != progress.step {
                sampleTimer.start()
            }
            return !canceled
        }
        let interval = Date().timeIntervalSince(beginDate)
        print("Got images: \(images) in \(interval)")

        // Unwrap the 1 image we asked for, nil means safety checker triggered
        let image = images.compactMap({ $0 }).first
        return GenerationResult(image: image, lastSeed: theSeed, interval: interval, userCanceled: canceled, itsPerSecond: 1.0/sampleTimer.median)
    }

    func handleProgress(_ progress: StableDiffusionProgress, sampleTimer: SampleTimer) {
        self.progress = progress
    }

    func setCancelled() {
        canceled = true
    }
}
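How this class is meant to be driven — a hedged sketch (not part of the repository; compare GenerationContext.generate above). A seed of 0 asks generate to pick a random seed, which is reported back via GenerationResult.lastSeed.

    // Sketch only: assumes `pipeline` came from PipelineLoader.prepare() (next file).
    let result = try pipeline.generate(
        prompt: "Labrador in the style of Vermeer", // DEFAULT_PROMPT in Settings.swift
        scheduler: .dpmSolverMultistepScheduler,
        numInferenceSteps: 25,
        seed: 0 // 0 → a random seed in 1...maxSeed; see result.lastSeed
    )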
--------------------------------------------------------------------------------
/MyAppleIntelligence/Features/Image Playground/Helpers/PipelineLoader.swift:
--------------------------------------------------------------------------------
//
//  PipelineLoader.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 10.11.24.
//

import Foundation

import CoreML
import Combine

import ZIPFoundation
import StableDiffusion

class PipelineLoader {
    static let models = Settings.shared.applicationSupportURL().appendingPathComponent("hf-diffusion-models")
    let model: ModelInfo
    let computeUnits: ComputeUnits
    let maxSeed: UInt32

    private var downloadSubscriber: Cancellable?

    init(model: ModelInfo, computeUnits: ComputeUnits? = nil, maxSeed: UInt32 = UInt32.max) {
        self.model = model
        self.computeUnits = computeUnits ?? model.defaultComputeUnits
        self.maxSeed = maxSeed
        state = .undetermined
        setInitialState()
    }

    enum PipelinePreparationPhase {
        case undetermined
        case waitingToDownload
        case downloading(Double)
        case downloaded
        case uncompressing
        case readyOnDisk
        case loaded
        case failed(Error)
    }

    var state: PipelinePreparationPhase {
        didSet {
            statePublisher.value = state
        }
    }
    private(set) lazy var statePublisher: CurrentValueSubject<PipelinePreparationPhase, Never> = CurrentValueSubject(state)
    private(set) var downloader: Downloader? = nil

    func setInitialState() {
        if ready {
            state = .readyOnDisk
            return
        }
        if downloaded {
            state = .downloaded
            return
        }
        state = .waitingToDownload
    }
}

extension PipelineLoader {
    // Unused. Kept for debugging purposes. --pcuenca
    static func removeAll() {
        // Delete the parent models folder as it will be recreated when it's needed again
        do {
            try FileManager.default.removeItem(at: models)
        } catch {
            print("Failed to delete: \(models), error: \(error.localizedDescription)")
        }
    }
}


extension PipelineLoader {
    func cancel() { downloader?.cancel() }
}

extension PipelineLoader {
    var url: URL {
        return model.modelURL(for: variant)
    }

    var filename: String {
        return url.lastPathComponent
    }

    var downloadedURL: URL { PipelineLoader.models.appendingPathComponent(filename) }

    var uncompressURL: URL { PipelineLoader.models }

    var packagesFilename: String { (filename as NSString).deletingPathExtension }

    var compiledURL: URL { downloadedURL.deletingLastPathComponent().appendingPathComponent(packagesFilename) }

    var downloaded: Bool {
        return FileManager.default.fileExists(atPath: downloadedURL.path)
    }

    var ready: Bool {
        return FileManager.default.fileExists(atPath: compiledURL.path)
    }

    var variant: AttentionVariant {
        switch computeUnits {
        case .cpuOnly: return .original // Not supported yet
        case .cpuAndGPU: return .original
        case .cpuAndNeuralEngine: return model.supportsAttentionV2 ? .splitEinsumV2 : .splitEinsum
        case .all: return model.isSD3 ? .original : .splitEinsum
        @unknown default:
            fatalError("Unknown MLComputeUnits")
        }
    }

    func prepare() async throws -> Pipeline {
        do {
            do {
                try FileManager.default.createDirectory(atPath: PipelineLoader.models.path, withIntermediateDirectories: true, attributes: nil)
            } catch {
                print("Error creating PipelineLoader.models path: \(error)")
            }

            try await download()
            try await unzip()
            let pipeline = try await load(url: compiledURL)
            return Pipeline(pipeline, maxSeed: maxSeed)
        } catch {
            state = .failed(error)
            throw error
        }
    }

    @discardableResult
    func download() async throws -> URL {
        if ready || downloaded { return downloadedURL }

        let downloader = Downloader(from: url, to: downloadedURL)
        self.downloader = downloader
        downloadSubscriber = downloader.downloadState.sink { state in
            if case .downloading(let progress) = state {
                self.state = .downloading(progress)
            }
        }
        try downloader.waitUntilDone()
        return downloadedURL
    }

    func unzip() async throws {
        guard downloaded else { return }
        state = .uncompressing
        do {
            try FileManager().unzipItem(at: downloadedURL, to: uncompressURL)
        } catch {
            // Cleanup if error occurs while unzipping
            try FileManager.default.removeItem(at: uncompressURL)
            throw error
        }
        try FileManager.default.removeItem(at: downloadedURL)
        state = .readyOnDisk
    }

    func load(url: URL) async throws -> StableDiffusionPipelineProtocol {
        let beginDate = Date()
        let configuration = MLModelConfiguration()
        configuration.computeUnits = computeUnits
        let pipeline: StableDiffusionPipelineProtocol
        if model.isXL {
            if #available(macOS 14.0, iOS 17.0, *) {
                pipeline = try StableDiffusionXLPipeline(resourcesAt: url,
                                                         configuration: configuration,
                                                         reduceMemory: model.reduceMemory)
            } else {
                throw CustomError(message: "Stable Diffusion XL requires macOS 14")
            }
        } else {
            pipeline = try StableDiffusionPipeline(resourcesAt: url,
                                                   controlNet: [],
                                                   configuration: configuration,
                                                   disableSafety: false,
                                                   reduceMemory: model.reduceMemory)
        }
        try pipeline.loadResources()
        print("Pipeline loaded in \(Date().timeIntervalSince(beginDate))")
        state = .loaded
        return pipeline
    }
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Features/Image Playground/Helpers/Settings.swift:
--------------------------------------------------------------------------------
//
//  Settings.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 10.11.24.
//

import CoreML

let DEFAULT_MODEL = ModelInfo.v21Palettized
let DEFAULT_PROMPT = "Labrador in the style of Vermeer"

typealias ComputeUnits = MLComputeUnits

class Settings {
    static let shared = Settings()

    let defaults = UserDefaults.standard

    enum Keys: String {
        case model
        case safetyCheckerDisclaimer
        case computeUnits
        case prompt
        case negativePrompt
        case guidanceScale
        case stepCount
        case previewCount
        case seed
    }

    private init() {
        defaults.register(defaults: [
            Keys.model.rawValue: ModelInfo.v21Palettized.modelId,
            Keys.safetyCheckerDisclaimer.rawValue: false,
            Keys.computeUnits.rawValue: -1, // Use default
            Keys.prompt.rawValue: DEFAULT_PROMPT,
            Keys.negativePrompt.rawValue: "",
            Keys.guidanceScale.rawValue: 7.5,
            Keys.stepCount.rawValue: 25,
            Keys.previewCount.rawValue: 5,
            Keys.seed.rawValue: 0
        ])
    }

    var currentModel: ModelInfo {
        set {
            defaults.set(newValue.modelId, forKey: Keys.model.rawValue)
        }
        get {
            guard let modelId = defaults.string(forKey: Keys.model.rawValue) else { return DEFAULT_MODEL }
            return ModelInfo.from(modelId: modelId) ?? DEFAULT_MODEL
        }
    }

    var prompt: String {
        set {
            defaults.set(newValue, forKey: Keys.prompt.rawValue)
        }
        get {
            return defaults.string(forKey: Keys.prompt.rawValue) ?? DEFAULT_PROMPT
        }
    }

    var negativePrompt: String {
        set {
            defaults.set(newValue, forKey: Keys.negativePrompt.rawValue)
        }
        get {
            return defaults.string(forKey: Keys.negativePrompt.rawValue) ?? ""
        }
    }

    var guidanceScale: Double {
        set {
            defaults.set(newValue, forKey: Keys.guidanceScale.rawValue)
        }
        get {
            return defaults.double(forKey: Keys.guidanceScale.rawValue)
        }
    }

    var stepCount: Double {
        set {
            defaults.set(newValue, forKey: Keys.stepCount.rawValue)
        }
        get {
            return defaults.double(forKey: Keys.stepCount.rawValue)
        }
    }

    var previewCount: Double {
        set {
            defaults.set(newValue, forKey: Keys.previewCount.rawValue)
        }
        get {
            return defaults.double(forKey: Keys.previewCount.rawValue)
        }
    }

    var seed: UInt32 {
        set {
            defaults.set(String(newValue), forKey: Keys.seed.rawValue)
        }
        get {
            if let seedString = defaults.string(forKey: Keys.seed.rawValue), let seedValue = UInt32(seedString) {
                return seedValue
            }
            return 0
        }
    }

    var safetyCheckerDisclaimerShown: Bool {
        set {
            defaults.set(newValue, forKey: Keys.safetyCheckerDisclaimer.rawValue)
        }
        get {
            return defaults.bool(forKey: Keys.safetyCheckerDisclaimer.rawValue)
        }
    }

    /// Returns the option selected by the user, if overridden
    /// `nil` means: guess best
    var userSelectedComputeUnits: ComputeUnits? {
        set {
            // Any value other than the supported ones would cause `get` to return `nil`
            defaults.set(newValue?.rawValue ?? -1, forKey: Keys.computeUnits.rawValue)
-1, forKey: Keys.computeUnits.rawValue) 128 | } 129 | get { 130 | let current = defaults.integer(forKey: Keys.computeUnits.rawValue) 131 | guard current != -1 else { return nil } 132 | return ComputeUnits(rawValue: current) 133 | } 134 | } 135 | 136 | public func applicationSupportURL() -> URL { 137 | let fileManager = FileManager.default 138 | guard let appDirectoryURL = fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first else { 139 | // If the user-domain Application Support directory cannot be accessed, fall back to the top-level Application Support directory so we never return an optional 140 | return URL.applicationSupportDirectory 141 | } 142 | 143 | do { 144 | // Create the application support directory if it doesn't exist 145 | try fileManager.createDirectory(at: appDirectoryURL, withIntermediateDirectories: true, attributes: nil) 146 | return appDirectoryURL 147 | } catch { 148 | print("Error creating application support directory: \(error)") 149 | return fileManager.urls(for: .applicationSupportDirectory, in: .userDomainMask).first! 150 | } 151 | } 152 | 153 | func tempStorageURL() -> URL { 154 | 155 | let tmpDir = applicationSupportURL().appendingPathComponent("hf-diffusion-tmp") 156 | 157 | // Create directory if it doesn't exist 158 | if !FileManager.default.fileExists(atPath: tmpDir.path) { 159 | do { 160 | try FileManager.default.createDirectory(at: tmpDir, withIntermediateDirectories: true, attributes: nil) 161 | } catch { 162 | print("Failed to create temporary directory: \(error)") 163 | return FileManager.default.temporaryDirectory 164 | } 165 | } 166 | 167 | return tmpDir 168 | } 169 | 170 | } 171 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/Helpers/State.swift: -------------------------------------------------------------------------------- 1 | // 2 | // State.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 10.11.24. 6 | // 7 | 8 | import Combine 9 | import SwiftUI 10 | import StableDiffusion 11 | import CoreML 12 | 13 | /// Schedulers compatible with StableDiffusionPipeline. This is a local implementation of the StableDiffusionScheduler enum as a String representation to allow for compliance with NSSecureCoding. 14 | public enum StableDiffusionScheduler: String { 15 | /// Scheduler that uses a pseudo-linear multi-step (PLMS) method 16 | case pndmScheduler 17 | /// Scheduler that uses a second order DPM-Solver++ algorithm 18 | case dpmSolverMultistepScheduler 19 | 20 | func asStableDiffusionScheduler() -> StableDiffusion.StableDiffusionScheduler { 21 | switch self { 22 | case .pndmScheduler: return StableDiffusion.StableDiffusionScheduler.pndmScheduler 23 | case .dpmSolverMultistepScheduler: return StableDiffusion.StableDiffusionScheduler.dpmSolverMultistepScheduler 24 | } 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/Helpers/Utils.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Utils.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 10.11.24. 6 | // 7 | 8 | import Foundation 9 | 10 | //extension String: Error {} 11 | 12 | /// Returns an array of booleans that indicates at which steps a preview should be generated. 13 | /// 14 | /// - Parameters: 15 | /// - numInferenceSteps: The total number of inference steps. 16 | /// - numPreviews: The desired number of previews.
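/// - Example: `previewIndices(10, 3)` uses an ideal step of 3.33, so previews land on steps 0, 3 and 7 (example added as an annotation, not in the original source).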
17 | /// 18 | /// - Returns: An array of booleans of size `numInferenceSteps`, where `true` values represent steps at which a preview should be made. 19 | func previewIndices(_ numInferenceSteps: Int, _ numPreviews: Int) -> [Bool] { 20 | // Ensure valid parameters 21 | guard numInferenceSteps > 0, numPreviews > 0 else { 22 | return [Bool](repeating: false, count: numInferenceSteps) 23 | } 24 | 25 | // Compute the ideal (floating-point) step size, which represents the average number of steps between previews 26 | let idealStep = Double(numInferenceSteps) / Double(numPreviews) 27 | 28 | // Compute the actual steps at which previews should be made. For each preview, we multiply the ideal step size by the preview number, and round to the nearest integer. 29 | // The result is converted to a `Set` for fast membership tests. 30 | let previewIndices: Set<Int> = Set((0..<numPreviews).map { previewIndex in 31 | return Int(round(Double(previewIndex) * idealStep)) 32 | }) 33 | 34 | // Mark the computed steps as preview steps 35 | return (0..<numInferenceSteps).map { previewIndices.contains($0) } 36 | } 37 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/Model/ModelInfo.swift: -------------------------------------------------------------------------------- 85 | func modelURL(for variant: AttentionVariant) -> URL { 86 | // Pattern: https://huggingface.co/pcuenq/coreml-stable-diffusion/resolve/main/coreml-stable-diffusion-v1-5_original_compiled.zip 87 | let suffix: String 88 | switch variant { 89 | case .original: suffix = originalAttentionSuffix 90 | case .splitEinsum: suffix = splitAttentionSuffix 91 | case .splitEinsumV2: suffix = splitAttentionV2Suffix 92 | } 93 | let repo = modelId.split(separator: "/").last! 94 | return URL(string: "https://huggingface.co/\(modelId)/resolve/main/\(repo)_\(suffix).zip")! 95 | } 96 | 97 | /// Best variant for the current platform. 98 | /// Currently using `split_einsum` for iOS and simple performance heuristics for macOS. 99 | var bestURL: URL { modelURL(for: bestAttention) } 100 | 101 | var reduceMemory: Bool { 102 | // Enable on iOS devices, except when using quantization 103 | if isXL { return !deviceHas8GBOrMore } 104 | return !(quantized && deviceHas6GBOrMore) 105 | } 106 | } 107 | 108 | let deviceHas6GBOrMore = ProcessInfo.processInfo.physicalMemory > 5910000000 // Reported by iOS 17 beta (21A5319a) on iPhone 13 Pro: 5917753344 109 | let deviceHas8GBOrMore = ProcessInfo.processInfo.physicalMemory > 7900000000 // Reported by iOS 17.0.2 on iPhone 15 Pro Max: 8021032960 110 | 111 | extension ModelInfo { 112 | static let v21Base = ModelInfo( 113 | modelId: "pcuenq/coreml-stable-diffusion-2-1-base", 114 | modelVersion: "StabilityAI SD 2.1", 115 | supportsEncoder: true 116 | ) 117 | 118 | static let xlmbpChunked = ModelInfo( 119 | modelId: "apple/coreml-stable-diffusion-xl-base-ios", 120 | modelVersion: "SDXL base (768, iOS) [4 bit]", 121 | supportsEncoder: false, 122 | quantized: true, 123 | isXL: true 124 | ) 125 | 126 | static let v21Palettized = ModelInfo( 127 | modelId: "apple/coreml-stable-diffusion-2-1-base-palettized", 128 | modelVersion: "StabilityAI SD 2.1 [6 bit]", 129 | supportsEncoder: true, 130 | supportsAttentionV2: true, 131 | quantized: true 132 | ) 133 | 134 | static let MODELS: [ModelInfo] = { 135 | var models = [ModelInfo.v21Base, ModelInfo.v21Palettized] 136 | 137 | if deviceSupportsQuantization { 138 | models.append(ModelInfo.xlmbpChunked) 139 | } 140 | 141 | return models 142 | }() 143 | 144 | static func from(modelVersion: String) -> ModelInfo? { 145 | ModelInfo.MODELS.first(where: {$0.modelVersion == modelVersion}) 146 | } 147 | 148 | static func from(modelId: String) -> ModelInfo?
{ 149 | ModelInfo.MODELS.first(where: {$0.modelId == modelId}) 150 | } 151 | } 152 | 153 | let deviceSupportsQuantization = { 154 | if #available(iOS 17, *) { 155 | true 156 | } else { 157 | false 158 | } 159 | }() 160 | 161 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/View/BlobShape.swift: -------------------------------------------------------------------------------- 1 | // 2 | // BlobShape.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 06.11.24. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct BlobShape: Shape { 11 | func path(in rect: CGRect) -> Path { 12 | var path = Path() 13 | let width = rect.size.width 14 | let height = rect.size.height 15 | path.move(to: CGPoint(x: 0.16427*width, y: 0.36126*height)) 16 | path.addCurve(to: CGPoint(x: 0.25589*width, y: 0.18325*height), control1: CGPoint(x: 0.20113*width, y: 0.29298*height), control2: CGPoint(x: 0.17954*width, y: 0.226*height)) 17 | path.addCurve(to: CGPoint(x: 0.62238*width, y: 0.10471*height), control1: CGPoint(x: 0.33224*width, y: 0.14049*height), control2: CGPoint(x: 0.52727*width, y: 0.09991*height)) 18 | path.addCurve(to: CGPoint(x: 0.82657*width, y: 0.21204*height), control1: CGPoint(x: 0.7175*width, y: 0.10951*height), control2: CGPoint(x: 0.78207*width, y: 0.1538*height)) 19 | path.addCurve(to: CGPoint(x: 0.8894*width, y: 0.45419*height), control1: CGPoint(x: 0.87107*width, y: 0.27029*height), control2: CGPoint(x: 0.89136*width, y: 0.39092*height)) 20 | path.addCurve(to: CGPoint(x: 0.81479*width, y: 0.59162*height), control1: CGPoint(x: 0.88743*width, y: 0.51745*height), control2: CGPoint(x: 0.82657*width, y: 0.52029*height)) 21 | path.addCurve(to: CGPoint(x: 0.81872*width, y: 0.8822*height), control1: CGPoint(x: 0.80301*width, y: 0.66296*height), control2: CGPoint(x: 0.87456*width, y: 0.83159*height)) 22 | path.addCurve(to: CGPoint(x: 0.47971*width, y: 0.89529*height), control1: CGPoint(x: 0.76287*width, y: 0.93281*height), control2: CGPoint(x: 0.56719*width, y: 0.91907*height)) 23 | path.addCurve(to: CGPoint(x: 0.29385*width, y: 0.73953*height), control1: CGPoint(x: 0.39223*width, y: 0.87151*height), control2: CGPoint(x: 0.36802*width, y: 0.78992*height)) 24 | path.addCurve(to: CGPoint(x: 0.03469*width, y: 0.59293*height), control1: CGPoint(x: 0.21968*width, y: 0.68914*height), control2: CGPoint(x: 0.05628*width, y: 0.65598*height)) 25 | path.addCurve(to: CGPoint(x: 0.16427*width, y: 0.36126*height), control1: CGPoint(x: 0.01309*width, y: 0.52989*height), control2: CGPoint(x: 0.1274*width, y: 0.42954*height)) 26 | path.addCurve(to: CGPoint(x: 0.25589*width, y: 0.18325*height), control1: CGPoint(x: 0.20113*width, y: 0.29298*height), control2: CGPoint(x: 0.17954*width, y: 0.226*height)) 27 | return path 28 | } 29 | } 30 | 31 | #Preview { 32 | BlobShape() 33 | .frame(width: 300, height: 300) 34 | } 35 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/View/GradientLabelView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // GradientLabelView.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 11.11.24. 
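//
// [Annotation, not in the original source] `ModelInfo.from(modelId:)` above is
// what lets `Settings.currentModel` round-trip through UserDefaults, e.g.:
//
//     Settings.shared.currentModel = ModelInfo.v21Palettized
//     let restored = Settings.shared.currentModel  // resolved via ModelInfo.from(modelId:)
//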
6 | // 7 | 8 | import SwiftUI 9 | 10 | struct GradientLabelView: View { 11 | 12 | var text: String 13 | 14 | var body: some View { 15 | Text(text) 16 | .foregroundStyle(.secondary) 17 | .font(.subheadline) 18 | .padding(.horizontal, 16) 19 | .padding(.vertical, 8) 20 | .background(.thinMaterial, in: Capsule()) 21 | .background(MeshGradient.custom, in: Capsule()) 22 | } 23 | } 24 | 25 | #Preview { 26 | GradientLabelView(text: "Model loading...") 27 | } 28 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/View/ImagePlaygroundView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ImagePlaygroundView.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 05.11.24. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct ImagePlaygroundView: View { 11 | 12 | @ObservedObject var navigationModel: NavigationModel 13 | @StateObject var viewModel = ImagePlaygroundViewModel() 14 | 15 | @State private var text = "" 16 | 17 | var body: some View { 18 | VStack { 19 | HStack { 20 | Button { 21 | navigationModel.navigationPath.removeLast() 22 | } label: { 23 | Text("Cancel") 24 | } 25 | .buttonStyle(.bordered) 26 | .buttonBorderShape(.capsule) 27 | .foregroundStyle(.primary) 28 | 29 | Spacer() 30 | 31 | Button { 32 | navigationModel.navigationPath.removeLast() 33 | } label: { 34 | Text("Done") 35 | } 36 | .buttonStyle(.borderedProminent) 37 | .buttonBorderShape(.capsule) 38 | .foregroundStyle(.primary) 39 | .tint(.yellow) 40 | } 41 | .padding(.horizontal) 42 | 43 | Spacer() 44 | 45 | ZStack { 46 | if viewModel.generating { 47 | if let generatedImage = viewModel.generatedImage { 48 | TimelineView(.animation) { timeline in 49 | let x = (sin(timeline.date.timeIntervalSince1970) + 1) / 1 50 | 51 | MeshGradient(width: 3, height: 3, points: [ 52 | [0, 0], [0.5, 0], [1, 0], 53 | [0, 0.5], [Float(x), 0.5], [1, 0.5], 54 | [0, 1], [0.5, 1], [1, 1] 55 | ], colors: [ 56 | .blue, .red, .purple, 57 | .indigo, .orange, .pink, 58 | .cyan, .green, .yellow 59 | ]) 60 | .frame(width: 350, height: 350) 61 | .blur(radius: 10) 62 | .rotationEffect(.degrees(x * 100)) 63 | } 64 | .frame(width: 250, height: 250) 65 | .clipShape(.rect(cornerRadius: 20, style: .continuous)) 66 | .blur(radius: 6) 67 | .overlay { 68 | Image(uiImage: UIImage(cgImage: generatedImage)) 69 | .resizable() 70 | .scaledToFill() 71 | .frame(width: 235, height: 235) 72 | .background(Color.gray.opacity(0.4)) 73 | .clipShape(.rect(cornerRadius: 20, style: .continuous)) 74 | .overlay { 75 | RoundedRectangle( 76 | cornerRadius: 20, 77 | style: .continuous 78 | ) 79 | .stroke( 80 | MeshGradient(width: 3, height: 3, points: [ 81 | [0, 0], [0.5, 0], [1, 0], 82 | [0, 0.5], [0.5, 0.5], [1, 0.5], 83 | [0, 1], [0.5, 1], [1, 1] 84 | ], colors: [ 85 | .blue, .red, .purple, 86 | .indigo, .orange, .pink, 87 | .cyan, .green, .yellow 88 | ]), lineWidth: 20) 89 | .frame(width: 250, height: 250) 90 | .blur(radius: 10) 91 | } 92 | } 93 | .transition(.scale) 94 | } else { 95 | ImageTemplateView() 96 | .transition(.scale) 97 | } 98 | } else { 99 | if viewModel.currentView == .ready { 100 | GradientLabelView(text: "Enter text to generate an image.") 101 | .transition(.scale) 102 | } 103 | } 104 | 105 | ZStack { 106 | switch viewModel.currentView { 107 | case .loading: 108 | GradientLabelView(text: "Loading the model...") 109 | .transition(.scale) 110 | case .ready: 111 | EmptyView() 112 | case .error(let string): 113 | GradientLabelView(text: 
"Error: \(string)") 114 | .transition(.scale) 115 | } 116 | } 117 | 118 | if viewModel.preparationPhase == .Downloading { 119 | GradientLabelView(text: String(format: "Downloading: %.1f", viewModel.downloadProgress * 100)) 120 | } 121 | } 122 | 123 | 124 | Spacer() 125 | 126 | VStack { 127 | HStack { 128 | Text("Suggestions".uppercased()) 129 | .font(.subheadline) 130 | .foregroundStyle(.secondary) 131 | 132 | Spacer() 133 | 134 | Button { 135 | // Show more 136 | } label: { 137 | Text("Show more".uppercased()) 138 | .font(.subheadline) 139 | } 140 | .buttonStyle(.borderless) 141 | .tint(.yellow) 142 | 143 | } 144 | .padding() 145 | 146 | HStack { 147 | ForEach(ImageStyle.allCases, id: \.self) { imageStyle in 148 | Button { 149 | viewModel.generate(prompt: "\(text) with the following style: \(imageStyle.prompt)") 150 | } label: { 151 | VStack { 152 | Image(imageStyle.imageName) 153 | .resizable() 154 | .frame(width: 50, height: 50) 155 | .clipShape(Circle()) 156 | 157 | Text(imageStyle.title) 158 | .font(.subheadline) 159 | .foregroundStyle(.secondary) 160 | } 161 | } 162 | .foregroundStyle(.secondary) 163 | .frame(maxWidth: .infinity, maxHeight: 50) 164 | } 165 | } 166 | .padding() 167 | } 168 | 169 | HStack(spacing: 4) { 170 | HStack(spacing: 12) { 171 | SiriIcon() 172 | .frame(width: 32, height: 32) 173 | 174 | TextField(text: $text, prompt: Text("Describe an image")) { 175 | EmptyView() 176 | } 177 | } 178 | .frame(maxWidth: .infinity, maxHeight: 50) 179 | .padding(.vertical, 2) 180 | .padding(.horizontal, 12) 181 | .background(Color(uiColor: .secondarySystemBackground), in: Capsule()) 182 | 183 | Button { 184 | UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to:nil, from:nil, for:nil) 185 | 186 | viewModel.generate(prompt: text) 187 | } label: { 188 | Image(systemName: "arrow.up") 189 | .padding(12) 190 | } 191 | .foregroundStyle(.primary) 192 | .buttonStyle(.bordered) 193 | .buttonBorderShape(.circle) 194 | } 195 | .padding(.horizontal) 196 | 197 | HStack(alignment: .firstTextBaseline) { 198 | Text("BETA") 199 | .padding(4) 200 | .font(.caption) 201 | .background(.gray, in: Capsule()) 202 | 203 | Text("Image may vary based on description, prettiness of the creator, or president selected.") 204 | .font(.caption) 205 | .foregroundStyle(.secondary) 206 | } 207 | .padding(.vertical, 10) 208 | .padding(.horizontal) 209 | } 210 | .task { 211 | await viewModel.loadModel() 212 | } 213 | .navigationBarHidden(true) 214 | } 215 | } 216 | 217 | #Preview { 218 | ImagePlaygroundView(navigationModel: NavigationModel()) 219 | } 220 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/View/ImageTemplateView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ImageTemplateView.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 06.11.24. 
6 | // 7 | 8 | import SwiftUI 9 | 10 | struct ImageTemplateView: View { 11 | 12 | @State private var blurRadius: Double = 4 13 | @State private var rotationDegrees: Double = 10 14 | @State private var scaleValue: CGFloat = 1 15 | 16 | let opacityValue: Double = 0.75 17 | 18 | var body: some View { 19 | ZStack { 20 | BlobShape() 21 | .fill( 22 | MeshGradient(width: 3, height: 3, points: [ 23 | [0, 0], [0.5, 0], [1, 0], 24 | [0, 0.5], [0.5, 0.5], [1, 0.5], 25 | [0, 1], [0.5, 1], [1, 1] 26 | ], colors: [ 27 | .blue, .red, .purple, 28 | .indigo, .orange, .pink, 29 | .cyan, .green, .yellow 30 | ]) 31 | ) 32 | .frame(width: 200, height: 200) 33 | .blur(radius: blurRadius) 34 | .rotationEffect(.degrees(rotationDegrees * -1)) 35 | .opacity(opacityValue) 36 | .scaleEffect(CGSize(width: scaleValue, height: scaleValue)) 37 | 38 | BlobShape() 39 | .fill( 40 | MeshGradient(width: 3, height: 3, points: [ 41 | [0, 0], [0.5, 0], [1, 0], 42 | [0, 0.5], [0.5, 0.5], [1, 0.5], 43 | [0, 1], [0.5, 1], [1, 1] 44 | ], colors: [ 45 | .blue, .red, .purple, 46 | .indigo, .orange, .pink, 47 | .cyan, .green, .yellow 48 | ]) 49 | ) 50 | .frame(width: 200, height: 200) 51 | .blur(radius: blurRadius) 52 | .rotationEffect(.degrees(rotationDegrees)) 53 | .opacity(opacityValue) 54 | .scaleEffect(CGSize(width: scaleValue * 0.95, height: scaleValue * 0.95)) 55 | 56 | BlobShape() 57 | .fill( 58 | MeshGradient(width: 3, height: 3, points: [ 59 | [0, 0], [0.5, 0], [1, 0], 60 | [0, 0.5], [0.5, 0.5], [1, 0.5], 61 | [0, 1], [0.5, 1], [1, 1] 62 | ], colors: [ 63 | .blue, .red, .purple, 64 | .indigo, .orange, .pink, 65 | .cyan, .green, .yellow 66 | ]) 67 | ) 68 | .frame(width: 200, height: 200) 69 | .blur(radius: blurRadius) 70 | .rotationEffect(.degrees((rotationDegrees - 130) * 0.6)) 71 | .opacity(opacityValue) 72 | .scaleEffect(CGSize(width: scaleValue * 1.15, height: scaleValue * 1.15)) 73 | 74 | BlobShape() 75 | .fill( 76 | MeshGradient(width: 3, height: 3, points: [ 77 | [0, 0], [0.5, 0], [1, 0], 78 | [0, 0.5], [0.5, 0.5], [1, 0.5], 79 | [0, 1], [0.5, 1], [1, 1] 80 | ], colors: [ 81 | .blue, .red, .purple, 82 | .indigo, .orange, .pink, 83 | .cyan, .green, .yellow 84 | ]) 85 | ) 86 | .frame(width: 200, height: 200) 87 | .blur(radius: blurRadius) 88 | .rotationEffect(.degrees((rotationDegrees + 90) * 0.2)) 89 | .opacity(opacityValue) 90 | .scaleEffect(CGSize(width: scaleValue * 1.05, height: scaleValue * 1.05)) 91 | 92 | BlobShape() 93 | .fill( 94 | MeshGradient(width: 3, height: 3, points: [ 95 | [0, 0], [0.5, 0], [1, 0], 96 | [0, 0.5], [0.5, 0.5], [1, 0.5], 97 | [0, 1], [0.5, 1], [1, 1] 98 | ], colors: [ 99 | .blue, .red, .purple, 100 | .indigo, .orange, .pink, 101 | .cyan, .green, .yellow 102 | ]) 103 | ) 104 | .frame(width: 200, height: 200) 105 | .blur(radius: blurRadius) 106 | .rotationEffect(.degrees((rotationDegrees - 240) * 0.9)) 107 | .opacity(opacityValue) 108 | .scaleEffect(CGSize(width: scaleValue * 0.9, height: scaleValue * 0.95)) 109 | } 110 | .onAppear { 111 | withAnimation(.easeInOut(duration: 3).repeatForever(autoreverses: true)) { 112 | blurRadius = 6 113 | rotationDegrees = 200 114 | scaleValue = 1.2 115 | } 116 | } 117 | } 118 | } 119 | 120 | #Preview { 121 | ImageTemplateView() 122 | } 123 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/View/ImageWithPlaceholder.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ImageWithPlaceholder.swift 3 | // MyAppleIntelligence 4 | // 5 | 
// Created by Stefan Blos on 10.11.24. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct ImageWithPlaceholder: View { 11 | 12 | @ObservedObject var viewModel: ImagePlaygroundViewModel 13 | 14 | var body: some View { 15 | switch viewModel.generation.state { 16 | case .startup: return AnyView(Image("placeholder").resizable()) 17 | case .running(let progress): 18 | guard let progress = progress, progress.stepCount > 0 else { 19 | // The first time it takes a little bit before generation starts 20 | return AnyView(ProgressView()) 21 | } 22 | 23 | let step = Int(progress.step) + 1 24 | let fraction = Double(step) / Double(progress.stepCount) 25 | let label = "Step \(step) of \(progress.stepCount)" 26 | return AnyView(VStack { 27 | Group { 28 | if let safeImage = viewModel.generation.previewImage { 29 | Image(safeImage, scale: 1, label: Text("generated")) 30 | .resizable() 31 | .clipShape(RoundedRectangle(cornerRadius: 20)) 32 | } 33 | } 34 | ProgressView(label, value: fraction, total: 1).padding() 35 | }) 36 | case .complete(_, let image, _, let interval): 37 | guard let theImage = image else { 38 | return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) 39 | } 40 | 41 | let imageView = Image(theImage, scale: 1, label: Text("generated")) 42 | return AnyView( 43 | VStack { 44 | imageView 45 | .resizable() 46 | .frame(width: 200, height: 200) 47 | .clipShape(RoundedRectangle(cornerRadius: 20)) 48 | HStack { 49 | let intervalString = String(format: "Time: %.1fs", interval ?? 0) 50 | Rectangle().fill(.clear).overlay(Text(intervalString).frame(maxWidth: .infinity, alignment: .leading).padding(.leading)) 51 | }.frame(maxHeight: 25) 52 | }) 53 | case .failed(_): 54 | return AnyView(Image(systemName: "exclamationmark.triangle").resizable()) 55 | case .userCanceled: 56 | return AnyView(Text("Generation canceled")) 57 | } 58 | } 59 | } 60 | 61 | #Preview { 62 | ImageWithPlaceholder(viewModel: ImagePlaygroundViewModel()) 63 | } 64 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/View/SiriIcon.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SiriIcon.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 06.11.24. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct SiriIcon: View { 11 | var body: some View { 12 | ZStack { 13 | Image("ai") 14 | .resizable() 15 | .aspectRatio(contentMode: .fit) 16 | .foregroundStyle( 17 | MeshGradient.custom 18 | ) 19 | } 20 | } 21 | } 22 | 23 | #Preview { 24 | SiriIcon() 25 | } 26 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Image Playground/ViewModel/ImagePlaygroundViewModel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ImagePlaygroundViewModel.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 06.11.24. 6 | // 7 | 8 | import SwiftUI 9 | import Combine 10 | 11 | enum CurrentView: Equatable { 12 | case loading 13 | case ready 14 | case error(String) 15 | } 16 | 17 | enum PreparationPhase: String { 18 | case Downloading, Uncompressing, Loading 19 | } 20 | 21 | @MainActor 22 | class ImagePlaygroundViewModel: ObservableObject { 23 | var generation = GenerationContext() 24 | 25 | @Published var currentView: CurrentView = .loading 26 | @Published var imageName: String? 27 | @Published var generating = false 28 | 29 | @Published var generatedImage: CGImage? 
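// Final image produced by the pipeline; the view wraps it in UIImage(cgImage:) for display. (annotation)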
30 | 31 | @Published var preparationPhase: PreparationPhase = .Downloading 32 | @Published var downloadProgress: Double = 0 33 | 34 | @Published var stateSubscriber: Cancellable? 35 | 36 | func generate(prompt: String) { 37 | self.generatedImage = nil 38 | if case .running = generation.state { return } 39 | 40 | generating = true 41 | generation.positivePrompt = prompt 42 | Task { 43 | generation.state = .running(nil) 44 | do { 45 | let result = try await generation.generate(prompt: prompt) 46 | generation.state = .complete(generation.positivePrompt, result.image, result.lastSeed, result.interval) 47 | 48 | await MainActor.run { 49 | self.generatedImage = result.image 50 | } 51 | } catch { 52 | generation.state = .failed(error) 53 | generating = false 54 | } 55 | } 56 | } 57 | 58 | func loadModel() async { 59 | let loader = PipelineLoader(model: iosModel()) 60 | stateSubscriber = loader.statePublisher.sink { state in 61 | DispatchQueue.main.async { 62 | switch state { 63 | case .downloading(let progress): 64 | self.preparationPhase = .Downloading 65 | self.downloadProgress = progress 66 | case .uncompressing: 67 | self.preparationPhase = .Uncompressing 68 | self.downloadProgress = 1 69 | case .readyOnDisk: 70 | self.preparationPhase = .Loading 71 | self.downloadProgress = 1 72 | default: 73 | break 74 | } 75 | } 76 | } 77 | do { 78 | generation.pipeline = try await loader.prepare() 79 | self.currentView = .ready 80 | } catch { 81 | self.currentView = .error("Could not load model, error: \(error)") 82 | } 83 | } 84 | 85 | func iosModel() -> ModelInfo { 86 | guard deviceSupportsQuantization else { return ModelInfo.v21Base } 87 | if deviceHas6GBOrMore { return ModelInfo.xlmbpChunked } 88 | return ModelInfo.v21Palettized 89 | } 90 | 91 | } 92 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Writing tools/Model/OptionCardSize.swift: -------------------------------------------------------------------------------- 1 | // 2 | // OptionCardSize.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 07.11.24. 6 | // 7 | 8 | import SwiftUI 9 | 10 | enum OptionCardSize { 11 | case regular, small 12 | 13 | var buttonSize: CGFloat { 14 | switch self { 15 | case .regular: return 24 16 | case .small: return 18 17 | } 18 | } 19 | 20 | var font: Font { 21 | switch self { 22 | case .regular: 23 | return .body 24 | case .small: 25 | return .caption 26 | } 27 | } 28 | 29 | var padding: CGFloat { 30 | switch self { 31 | case .regular: 32 | return 8 33 | case .small: 34 | return 4 35 | } 36 | } 37 | } 38 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Writing tools/Model/Shaders/Shimmer.metal: -------------------------------------------------------------------------------- 1 | // 2 | // Shimmer.metal 3 | // Inferno 4 | // https://www.github.com/twostraws/Inferno 5 | // See LICENSE for license information. 6 | // 7 | 8 | #include <metal_stdlib> 9 | using namespace metal; 10 | 11 | /// Converts a color from RGB to HSL representation. 12 | /// Reference: https://en.wikipedia.org/wiki/HSL_and_HSV 13 | /// - Parameter rgb: A vector representing a color with components (R, G, B). 14 | /// - Returns: A vector representing a color with components (H, S, L).
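/// - Example: pure red (1, 0, 0) maps to H = 0, S = 1, L = 0.5. (example added as an annotation, not in the original source)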
15 | half3 rgbToHSL(half3 rgb) { 16 | half min = min3(rgb.r, rgb.g, rgb.b); 17 | half max = max3(rgb.r, rgb.g, rgb.b); 18 | half delta = max - min; 19 | 20 | half3 hsl = half3(0.0h, 0.0h, 0.5h * (max + min)); 21 | 22 | if (delta > 0.0h) { 23 | if (max == rgb.r) { 24 | hsl[0] = fmod((rgb.g - rgb.b) / delta, 6.0h); 25 | } else if (max == rgb.g) { 26 | hsl[0] = (rgb.b - rgb.r) / delta + 2.0h; 27 | } else { 28 | hsl[0] = (rgb.r - rgb.g) / delta + 4.0h; 29 | } 30 | hsl[0] /= 6.0h; 31 | if (hsl[2] > 0.0h && hsl[2] < 1.0h) { 32 | hsl[1] = delta / (1.0h - abs(2.0h * hsl[2] - 1.0h)); 33 | } else { 34 | hsl[1] = 0.0h; 35 | } 36 | } 37 | 38 | return hsl; 39 | } 40 | 41 | /// Converts a color from HSL to RGB representation. 42 | /// Reference: https://en.wikipedia.org/wiki/HSL_and_HSV 43 | /// - Parameter hsl: A vector representing a color with components (H, S, L). 44 | /// - Returns: A vector representing a color with components (R, G, B). 45 | half3 hslToRGB(half3 hsl) { 46 | half c = (1.0h - abs(2.0h * hsl[2] - 1.0h)) * hsl[1]; 47 | half h = hsl[0] * 6.0h; 48 | half x = c * (1.0h - abs(fmod(h, 2.0h) - 1.0h)); 49 | 50 | half3 rgb = half3(0.0h, 0.0h, 0.0h); 51 | 52 | if (h < 1.0h) { 53 | rgb = half3(c, x, 0.0h); 54 | } else if (h < 2.0h) { 55 | rgb = half3(x, c, 0.0h); 56 | } else if (h < 3.0h) { 57 | rgb = half3(0.0h, c, x); 58 | } else if (h < 4.0h) { 59 | rgb = half3(0.0h, x, c); 60 | } else if (h < 5.0h) { 61 | rgb = half3(x, 0.0h, c); 62 | } else { 63 | rgb = half3(c, 0.0h, x); 64 | } 65 | 66 | half m = hsl[2] - 0.5h * c; 67 | return rgb + m; 68 | } 69 | 70 | /// A shader that generates a shimmering effect. 71 | /// 72 | /// This works by creating a gradient that moves horizontally across the view, 73 | /// and then uses that gradient to modulate the lightness of the pixel. 74 | /// 75 | /// - Parameter position: The user-space coordinate of the current pixel. 76 | /// - Parameter color: The current color of the pixel. 77 | /// - Parameter size: The size of the entire view, in user-space. 78 | /// - Parameter time: The number of elapsed seconds since the shader was created. 79 | /// - Parameter animationDuration: The duration of a single loop of the shimmer animation, in seconds. 80 | /// - Parameter gradientWidth: The width of the shimmer gradient in UV space. 81 | /// - Parameter maxLightness: The maximum lightness at the peak of the gradient. 82 | /// - Returns: The new pixel color. 83 | [[ stitchable ]] half4 shimmer(float2 position, half4 color, float2 size, float time, float animationDuration, float gradientWidth, float maxLightness) { 84 | if (color.a == 0.0h) { 85 | return color; 86 | } 87 | 88 | // Calculate the current progress of the shimmer animation loop, from 0 to 1 89 | float loopedProgress = fmod(time, float(animationDuration)); 90 | half progress = loopedProgress / animationDuration; 91 | 92 | // Convert coordinate to UV space, 0 to 1. 
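// (Annotation: `position` is in user-space pixels, so dividing by the view's
// size makes the gradient math below resolution-independent.)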
93 | half2 uv = half2(position / size); 94 | 95 | // Calculate u beyond the view's edges based on the gradient size 96 | half minU = 0.0h - gradientWidth; 97 | half maxU = 1.0h + gradientWidth; 98 | 99 | // Based on the current progress and v, calculate the starting and ending u of the gradient 100 | half start = minU + maxU * progress + gradientWidth * uv.y; 101 | half end = start + gradientWidth; 102 | 103 | if (uv.x > start && uv.x < end) { 104 | // Determine the pixel's position within the gradient, from 0 to 1 105 | half gradient = smoothstep(start, end, uv.x); 106 | // Determine gradient intensity using a sine wave 107 | half intensity = sin(gradient * M_PI_H); 108 | 109 | // Convert from RGB to HSL 110 | half3 hsl = rgbToHSL(color.rgb); 111 | // Modify the lightness component based on intensity 112 | hsl[2] = hsl[2] + half(maxLightness * (maxLightness > 0.0h ? 1 - hsl[2] : hsl[2])) * intensity; 113 | // Convert back to RGB 114 | color.rgb = hslToRGB(hsl); 115 | } 116 | 117 | return color; 118 | } 119 | 120 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Writing tools/Model/Shaders/Wave.metal: -------------------------------------------------------------------------------- 1 | // 2 | // Wave.metal 3 | // Inferno 4 | // https://www.github.com/twostraws/Inferno 5 | // See LICENSE for license information. 6 | // 7 | 8 | #include <metal_stdlib> 9 | using namespace metal; 10 | 11 | /// A shader that generates a uniform wave effect. 12 | /// 13 | /// This works by offsetting the Y position of each pixel by some amount of its X 14 | /// position. Using sin() for this generates values between -1 and 1, which are then 15 | /// scaled by the strength parameter to make the effect more pronounced. 16 | /// 17 | /// - Parameter position: The user-space coordinate of the current pixel. 18 | /// - Parameter time: The number of elapsed seconds since the shader was created. 19 | /// - Parameter speed: How fast to make the waves ripple. Try starting with a value of 5. 20 | /// - Parameter smoothing: How much to smooth out the ripples, where greater values 21 | /// produce a smoother effect. Try starting with a value of 20. 22 | /// - Parameter strength: How pronounced to make the ripple effect. Try starting with a 23 | /// value of 5. 24 | /// - Returns: The new pixel position. 25 | [[ stitchable ]] float2 wave(float2 position, float time, float speed, float smoothing, float strength) { 26 | // Offset our Y value by some amount of our X position. 27 | // Multiplying time by speed makes the wave ripple faster, and dividing the 28 | // X position by the smoothing factor smooths out the wave to avoid jaggies. 29 | position.y += sin(time * speed + position.x / smoothing) * strength; 30 | position.x += cos(time * speed + position.y / smoothing) * strength; 31 | 32 | // Send back the offset position. 33 | return position; 34 | } 35 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Writing tools/Model/WritingToolOption.swift: -------------------------------------------------------------------------------- 1 | // 2 | // WritingToolOption.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 09.11.24. 6 | // 7 | 8 | import Foundation 9 | 10 | enum WritingToolOption { 11 | case proofread, rewrite, friendly, professional, concise, summary, keypoints, table, list 12 | 13 | // func generatePrompt(with text: String) -> String { 14 | // switch self { 15 | // case .proofread: 16 | // return "Proofread the following text.
Correct any spelling, grammar, or punctuation errors without changing the meaning. Make sure the text reads smoothly and is error-free: \(text)" 17 | // case .rewrite: 18 | // return "Rewrite the following text to improve clarity, readability, and flow. Preserve the meaning, but make it more engaging and natural: \(text)" 19 | // case .friendly: 20 | // return "Rewrite the following text in a friendly, approachable tone. Use conversational language and make it sound warm and welcoming while preserving the original message: \(text)" 21 | // case .professional: 22 | // return "Rewrite the following text in a professional and formal tone. Make it sound polished, precise, and suitable for a business or formal setting: \(text)" 23 | // case .concise: 24 | // return "Rewrite the following text to be as concise as possible without losing any important details or meaning. Eliminate any unnecessary words: \(text)" 25 | // case .summary: 26 | // return "Summarize the following text in a few sentences, capturing only the essential points and main ideas: \(text)" 27 | // case .keypoints: 28 | // return "Extract the key points from the following text. Present them as bullet points and cover the main ideas or important details: \(text)" 29 | // case .table: 30 | // return "Convert the information in the following text into a structured table format. Organize it into relevant categories or columns for clarity: \(text)" 31 | // case .list: 32 | // return "Convert the following text into a list of bullet points. Each point should represent a distinct idea or item for easy reading: \(text)" 33 | // } 34 | // } 35 | 36 | } 37 | 38 | 39 | extension WritingToolOption { 40 | func generatePrompt(with text: String) -> String { 41 | switch self { 42 | case .proofread: 43 | return "Proofread the following text. Correct any spelling, grammar, or punctuation errors without changing the meaning. Make sure the text reads smoothly and is error-free: \(text)" 44 | case .rewrite: 45 | return "Rewrite the following text to improve clarity, readability, and flow. Preserve the meaning, but make it more engaging and natural: \(text)" 46 | case .friendly: 47 | return "Rewrite the following text in a friendly, approachable tone. Use conversational language and make it sound warm and welcoming while preserving the original message: \(text)" 48 | case .professional: 49 | return "Rewrite the following text in a professional and formal tone. Make it sound polished, precise, and suitable for a business or formal setting: \(text)" 50 | case .concise: 51 | return "Rewrite the following text to be as concise as possible without losing any important details or meaning. Eliminate any unnecessary words: \(text)" 52 | case .summary: 53 | return "Summarize the following text in a few sentences, capturing only the essential points and main ideas: \(text)" 54 | case .keypoints: 55 | return "Extract the key points from the following text. Present them as bullet points and cover the main ideas or important details: \(text)" 56 | case .table: 57 | return "Convert the information in the following text into a structured table format. Organize it into relevant categories or columns for clarity: \(text)" 58 | case .list: 59 | return "Convert the following text into a list of bullet points. 
Each point should represent a distinct idea or item for easy reading: \(text)" 60 | } 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Writing tools/View/OptionCard.swift: -------------------------------------------------------------------------------- 1 | // 2 | // OptionCard.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 07.11.24. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct OptionCard: View { 11 | 12 | var title: String 13 | var systemImageName: String 14 | var action: () -> Void 15 | 16 | var option: OptionCardSize = .regular 17 | 18 | var body: some View { 19 | Button { 20 | action() 21 | } label: { 22 | VStack(spacing: option.padding) { 23 | Image(systemName: systemImageName) 24 | .resizable() 25 | .frame( 26 | width: option.buttonSize, 27 | height: option.buttonSize 28 | ) 29 | 30 | Text(title) 31 | .font(option.font) 32 | } 33 | .padding(option.padding) 34 | .frame(maxWidth: .infinity) 35 | } 36 | .buttonStyle(.bordered) 37 | .buttonBorderShape(.roundedRectangle(radius: 10) 38 | ) 39 | .tint(.secondary) 40 | } 41 | 42 | } 43 | 44 | #Preview { 45 | OptionCard(title: "Proofread", systemImageName: "text.magnifyingglass", action: { 46 | // 47 | }) 48 | } 49 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Writing tools/View/SheetButton.swift: -------------------------------------------------------------------------------- 1 | // 2 | // SheetButton.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 07.11.24. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct SheetButton: View { 11 | 12 | var text: String 13 | var action: () -> Void 14 | 15 | var body: some View { 16 | Button { 17 | action() 18 | } label: { 19 | VStack(spacing: 8) { 20 | RoundedRectangle(cornerRadius: 8, style: .continuous) 21 | .fill(Color(uiColor: .tertiarySystemGroupedBackground)) 22 | .aspectRatio(3 / 4, contentMode: .fit) 23 | .shadow(radius: 8) 24 | 25 | Text(text) 26 | .font(.caption) 27 | .fontWeight(.semibold) 28 | } 29 | .frame(maxWidth: .infinity) 30 | } 31 | .tint(.primary) 32 | } 33 | } 34 | 35 | #Preview { 36 | SheetButton(text: "Summary") {} 37 | } 38 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Writing tools/View/WritingToolsInputView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // WritingToolsInputView.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 09.11.24. 
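//
// [Annotation, not part of the original source] How a WritingToolOption turns
// user text into the user message that WritingToolsViewModel later runs through
// the chat template:
//
//     let message = WritingToolOption.proofread.generatePrompt(
//         with: "The bestest TV show I've ever had wathed...")
//     // -> "Proofread the following text. Correct any spelling, grammar, or
//     //     punctuation errors without changing the meaning. ...: The bestest ..."
//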
6 | // 7 | 8 | import SwiftUI 9 | 10 | struct WritingToolsInputView: View { 11 | 12 | @ObservedObject var navigationModel: NavigationModel 13 | @StateObject var viewModel = WritingToolsViewModel() 14 | 15 | var body: some View { 16 | VStack { 17 | HStack(alignment: .center) { 18 | Button { 19 | navigationModel.navigationPath.removeLast() 20 | } label: { 21 | Image(systemName: "xmark") 22 | .resizable() 23 | .frame(width: 12, height: 12) 24 | .padding(8) 25 | } 26 | .buttonStyle(.bordered) 27 | .buttonBorderShape(.circle) 28 | .tint(.gray) 29 | 30 | Text("Input text") 31 | .font(.title3) 32 | .fontWeight(.semibold) 33 | 34 | Spacer() 35 | 36 | 37 | 38 | Button { 39 | viewModel.showingWritingTools.toggle() 40 | } label: { 41 | SiriIcon() 42 | .frame(width: 32, height: 32) 43 | } 44 | } 45 | .padding() 46 | 47 | ZStack { 48 | TextEditor(text: $viewModel.textInput) 49 | .scrollContentBackground(.hidden) 50 | .foregroundStyle(.primary) 51 | .padding() 52 | .background(Color(uiColor: .secondarySystemBackground)) 53 | .clipShape(.rect(cornerRadius: 20, style: .continuous)) 54 | .padding() 55 | .overlay { 56 | if viewModel.analyzingText { 57 | ZStack { 58 | Color(.white) 59 | .opacity(0.6) 60 | .clipShape(.rect(cornerRadius: 20, style: .continuous)) 61 | .padding() 62 | 63 | TimelineView(.animation) { timeline in 64 | let time = viewModel.startDate.distance(to: timeline.date) 65 | 66 | Color.white.opacity(0.5) 67 | .clipShape(.rect(cornerRadius: 20, style: .continuous)) 68 | .padding() 69 | .visualEffect { content, proxy in 70 | content 71 | .colorEffect(ShaderLibrary.shimmer( 72 | .float2(proxy.size), 73 | .float(time), 74 | .float(2.0), 75 | .float(0.9), 76 | .float(0.5))) 77 | } 78 | } 79 | } 80 | } 81 | } 82 | 83 | 84 | } 85 | } 86 | .sheet(isPresented: $viewModel.showingWritingTools) { 87 | WritingToolsView(viewModel: viewModel) 88 | .presentationDetents([.medium]) 89 | .presentationDragIndicator(.hidden) 90 | } 91 | .task { 92 | _ = try? await viewModel.load() 93 | } 94 | .navigationBarHidden(true) 95 | } 96 | } 97 | 98 | #Preview { 99 | WritingToolsInputView(navigationModel: NavigationModel()) 100 | } 101 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Writing tools/View/WritingToolsView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // WritingToolsView.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 07.11.24. 
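//
// [Annotation, not part of the original project] A sketch of how the shimmer
// overlay in WritingToolsInputView above could be factored into a reusable
// modifier. The argument order mirrors the Metal entry point
// shimmer(position, color, size, time, animationDuration, gradientWidth, maxLightness);
// the type and its default values are assumptions for illustration.
//
import SwiftUI

struct ShimmerOverlay: ViewModifier {  // hypothetical helper, not in the project
    var startDate: Date = .init()
    var animationDuration: Double = 2.0  // seconds per shimmer sweep
    var gradientWidth: Double = 0.9      // width of the highlight band in UV space
    var maxLightness: Double = 0.5       // peak lightness boost

    func body(content: Content) -> some View {
        TimelineView(.animation) { timeline in
            // Elapsed time drives the Metal shader's animation loop.
            let time = startDate.distance(to: timeline.date)
            content
                .visualEffect { content, proxy in
                    content.colorEffect(ShaderLibrary.shimmer(
                        .float2(proxy.size),
                        .float(time),
                        .float(animationDuration),
                        .float(gradientWidth),
                        .float(maxLightness)))
                }
        }
    }
}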
6 | // 7 | 8 | import SwiftUI 9 | 10 | struct WritingToolsView: View { 11 | 12 | @ObservedObject var viewModel: WritingToolsViewModel 13 | 14 | @State private var text = "" 15 | 16 | var body: some View { 17 | VStack { 18 | HStack { 19 | Text("Writing Tools") 20 | .font(.title2) 21 | .fontWeight(.semibold) 22 | 23 | Spacer() 24 | 25 | Button { 26 | // Close 27 | } label: { 28 | Image(systemName: "xmark") 29 | .padding(4) 30 | } 31 | .buttonStyle(.bordered) 32 | .buttonBorderShape(.circle) 33 | .tint(.gray) 34 | } 35 | 36 | TextInputView(text: $text, placeholder: "Describe your change", showButton: true) { 37 | // Execute command 38 | } 39 | 40 | Divider() 41 | 42 | HStack { 43 | OptionCard(title: "Proofread", systemImageName: "text.magnifyingglass") { 44 | // Proofread 45 | viewModel.executeLLM(with: .proofread) 46 | } 47 | 48 | OptionCard(title: "Rewrite", systemImageName: "arrow.trianglehead.2.counterclockwise.rotate.90") { 49 | // Rewrite 50 | viewModel.executeLLM(with: .rewrite) 51 | } 52 | } 53 | 54 | HStack { 55 | OptionCard(title: "Friendly", systemImageName: "face.smiling", action: { 56 | // Friendly tone 57 | viewModel.executeLLM(with: .friendly) 58 | }, option: .small) 59 | 60 | OptionCard(title: "Professional", systemImageName: "briefcase", action: { 61 | // Professional tone 62 | viewModel.executeLLM(with: .professional) 63 | }, option: .small) 64 | 65 | OptionCard(title: "Concise", systemImageName: "arrow.down.and.line.horizontal.and.arrow.up", action: { 66 | // Concise tone 67 | viewModel.executeLLM(with: .concise) 68 | }, option: .small) 69 | } 70 | 71 | HStack(spacing: 30) { 72 | SheetButton(text: "Summary") { 73 | // Create summary 74 | viewModel.executeLLM(with: .summary) 75 | } 76 | 77 | SheetButton(text: "Key Points") { 78 | // Extract key points 79 | viewModel.executeLLM(with: .keypoints) 80 | } 81 | 82 | Divider() 83 | .frame(maxHeight: 40) 84 | 85 | SheetButton(text: "Table") { 86 | // Create table 87 | viewModel.executeLLM(with: .table) 88 | } 89 | 90 | SheetButton(text: "List") { 91 | // List 92 | viewModel.executeLLM(with: .list) 93 | } 94 | } 95 | .frame(maxHeight: 140) 96 | } 97 | .padding() 98 | } 99 | } 100 | 101 | #Preview { 102 | WritingToolsView(viewModel: WritingToolsViewModel()) 103 | } 104 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Features/Writing tools/ViewModel/WritingToolsViewModel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // WritingToolsViewModel.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 09.11.24. 6 | // 7 | 8 | import MLX 9 | import MLXRandom 10 | import MarkdownUI 11 | import Metal 12 | import SwiftUI 13 | import Tokenizers 14 | 15 | @MainActor 16 | class WritingToolsViewModel: ObservableObject { 17 | 18 | @Published var textInput = "The bestest TV show I've ever had wathed was The Game of Thrones seres. It ws s gdd!" 19 | @Published var analyzingText = false 20 | @Published var showingWritingTools = false 21 | @Published var startDate: Date = Date() 22 | 23 | var running = false 24 | 25 | var output = "" 26 | var modelInfo = "" 27 | 28 | let systemPrompt = "You are a professional writing assistant designed to improve and transform written text according to user needs. Your goal is to make text clearer, more concise, and appropriate in tone or format based on the user's instructions. You handle text with precision, ensuring accurate grammar, a clear message, and the specified tone or structure. 
Respond only with the improved or transformed text unless otherwise requested." 29 | 30 | /// This controls which model loads. `llama3_2_3B_4bit` is one of the smaller ones, so this will fit on 31 | /// more devices. 32 | let modelConfiguration = ModelConfiguration.llama3_2_3B_4bit 33 | 34 | /// parameters controlling the output 35 | let generateParameters = GenerateParameters(temperature: 0.6) 36 | let maxTokens = 240 37 | 38 | /// update the display every N tokens -- 4 looks like it updates continuously 39 | /// and is low overhead. observed ~15% reduction in tokens/s when updating 40 | /// on every token 41 | let displayEveryNTokens = 4 42 | 43 | enum LoadState { 44 | case idle 45 | case loaded(ModelContainer) 46 | } 47 | 48 | var loadState = LoadState.idle 49 | 50 | /// load and return the model -- can be called multiple times, subsequent calls will 51 | /// just return the loaded model 52 | func load() async throws -> ModelContainer { 53 | switch loadState { 54 | case .idle: 55 | // limit the buffer cache 56 | MLX.GPU.set(cacheLimit: 20 * 1024 * 1024) 57 | 58 | let modelContainer = try await loadModelContainer(configuration: modelConfiguration) 59 | { 60 | [modelConfiguration] progress in 61 | Task { @MainActor in 62 | self.modelInfo = 63 | "Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%" 64 | 65 | print("Downloading \(modelConfiguration.name): \(Int(progress.fractionCompleted * 100))%") 66 | } 67 | } 68 | let numParams = await modelContainer.perform { 69 | [] model, _ in 70 | return model.numParameters() 71 | } 72 | 73 | self.modelInfo = 74 | "Loaded \(modelConfiguration.id). Weights: \(numParams / (1024*1024))M" 75 | loadState = .loaded(modelContainer) 76 | return modelContainer 77 | 78 | case .loaded(let modelContainer): 79 | return modelContainer 80 | } 81 | } 82 | 83 | func executeLLM(with option: WritingToolOption) { 84 | Task { 85 | await generateText(text: textInput, option: option) 86 | } 87 | } 88 | 89 | private func generateText(text: String, option: WritingToolOption) async { 90 | guard !running else { return } 91 | 92 | self.startDate = Date() 93 | self.showingWritingTools = false 94 | 95 | withAnimation { 96 | analyzingText = true 97 | } 98 | 99 | try? 
await Task.sleep(for: .seconds(2)) 100 | 101 | withAnimation { 102 | analyzingText = false 103 | } 104 | 105 | 106 | running = true 107 | textInput = "" 108 | 109 | do { 110 | let modelContainer = try await load() 111 | 112 | let messages = [ 113 | ["role": "system", "content": systemPrompt], 114 | ["role": "user", "content": option.generatePrompt(with: text)] 115 | ] 116 | let promptTokens = try await modelContainer.perform { _, tokenizer in 117 | try tokenizer.applyChatTemplate(messages: messages) 118 | } 119 | 120 | // each time you generate you will get something new 121 | MLXRandom.seed(UInt64(Date.timeIntervalSinceReferenceDate * 1000)) 122 | 123 | let result = await modelContainer.perform { model, tokenizer in 124 | MyAppleIntelligence.generate( 125 | promptTokens: promptTokens, parameters: generateParameters, model: model, 126 | tokenizer: tokenizer, extraEOSTokens: modelConfiguration.extraEOSTokens 127 | ) { tokens in 128 | // update the output -- this will make the view show the text as it generates 129 | if tokens.count % displayEveryNTokens == 0 { 130 | let text = tokenizer.decode(tokens: tokens) 131 | Task { @MainActor in 132 | self.textInput = text 133 | } 134 | } 135 | 136 | if tokens.count >= maxTokens { 137 | return .stop 138 | } else { 139 | return .more 140 | } 141 | } 142 | } 143 | 144 | // update the text if needed, e.g. we haven't displayed because of displayEveryNTokens 145 | if result.output != self.output { 146 | self.output = result.output 147 | } 148 | 149 | print("Tokens/second: \(String(format: "%.3f", result.tokensPerSecond))") 150 | 151 | } catch { 152 | textInput = "Failed: \(error)" 153 | } 154 | 155 | running = false 156 | } 157 | } 158 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Libraries/Configuration.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Configuration.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 08.11.24. 6 | // 7 | 8 | import Foundation 9 | 10 | public enum StringOrNumber: Codable, Equatable, Sendable { 11 | case string(String) 12 | case float(Float) 13 | 14 | public init(from decoder: Decoder) throws { 15 | let values = try decoder.singleValueContainer() 16 | 17 | if let v = try? values.decode(Float.self) { 18 | self = .float(v) 19 | } else { 20 | let v = try values.decode(String.self) 21 | self = .string(v) 22 | } 23 | } 24 | 25 | public func encode(to encoder: Encoder) throws { 26 | var container = encoder.singleValueContainer() 27 | switch self { 28 | case .string(let v): try container.encode(v) 29 | case .float(let v): try container.encode(v) 30 | } 31 | } 32 | } 33 | 34 | private class ModelTypeRegistry: @unchecked Sendable { 35 | 36 | // Note: using NSLock as we have very small (just dictionary get/set) 37 | // critical sections and expect no contention. this allows the methods 38 | // to remain synchronous. 
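// (Annotation: lock.withLock { ... } below acquires the lock, runs the closure and
// releases the lock on every exit path, which keeps the accessors short and synchronous.)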
39 | private let lock = NSLock() 40 | 41 | @Sendable 42 | private static func createLlamaModel(url: URL) throws -> LLMModel { 43 | let configuration = try JSONDecoder().decode( 44 | LlamaConfiguration.self, from: Data(contentsOf: url)) 45 | return LlamaModel(configuration) 46 | } 47 | 48 | private var creators: [String: @Sendable (URL) throws -> LLMModel] = [ 49 | "llama": createLlamaModel, 50 | ] 51 | 52 | public func registerModelType( 53 | _ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel 54 | ) { 55 | lock.withLock { 56 | creators[type] = creator 57 | } 58 | } 59 | 60 | public func createModel(configuration: URL, rawValue: String) throws -> LLMModel { 61 | let creator = lock.withLock { 62 | creators[rawValue] 63 | } 64 | guard let creator else { 65 | throw LLMError(message: "Unsupported model type.") 66 | } 67 | return try creator(configuration) 68 | } 69 | 70 | } 71 | 72 | private let modelTypeRegistry = ModelTypeRegistry() 73 | 74 | public struct ModelType: RawRepresentable, Codable, Sendable { 75 | public let rawValue: String 76 | 77 | public init(rawValue: String) { 78 | self.rawValue = rawValue 79 | } 80 | 81 | public static func registerModelType( 82 | _ type: String, creator: @Sendable @escaping (URL) throws -> LLMModel 83 | ) { 84 | modelTypeRegistry.registerModelType(type, creator: creator) 85 | } 86 | 87 | public func createModel(configuration: URL) throws -> LLMModel { 88 | try modelTypeRegistry.createModel(configuration: configuration, rawValue: rawValue) 89 | } 90 | } 91 | 92 | public struct BaseConfiguration: Codable, Sendable { 93 | public let modelType: ModelType 94 | 95 | public struct Quantization: Codable, Sendable { 96 | public init(groupSize: Int, bits: Int) { 97 | self.groupSize = groupSize 98 | self.bits = bits 99 | } 100 | 101 | let groupSize: Int 102 | let bits: Int 103 | 104 | enum CodingKeys: String, CodingKey { 105 | case groupSize = "group_size" 106 | case bits = "bits" 107 | } 108 | } 109 | 110 | public var quantization: Quantization? 111 | 112 | enum CodingKeys: String, CodingKey { 113 | case modelType = "model_type" 114 | case quantization 115 | } 116 | } 117 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Libraries/Evaluate.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Evaluate.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 08.11.24. 6 | // 7 | 8 | import Foundation 9 | import MLX 10 | import MLXRandom 11 | import Tokenizers 12 | 13 | /// Parameters for text generation, see ``TokenIterator`` 14 | public struct GenerateParameters: Sendable { 15 | 16 | /// Step size for processing the prompt 17 | public var prefillStepSize = 512 18 | 19 | /// sampling temperature 20 | public var temperature: Float = 0.6 21 | 22 | /// top p sampling 23 | public var topP: Float = 1.0 24 | 25 | /// penalty factor for repeating tokens 26 | public var repetitionPenalty: Float? 27 | 28 | /// number of tokens to consider for repetition penalty 29 | public var repetitionContextSize: Int = 20 30 | 31 | public init( 32 | temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? 
= nil, 33 | repetitionContextSize: Int = 20 34 | ) { 35 | self.temperature = temperature 36 | self.topP = topP 37 | self.repetitionPenalty = repetitionPenalty 38 | self.repetitionContextSize = repetitionContextSize 39 | } 40 | } 41 | 42 | struct SampleContext { 43 | 44 | let temp: MLXArray 45 | let topP: MLXArray 46 | let useTopP: Bool 47 | let useArgMax: Bool 48 | 49 | init(parameters: GenerateParameters) { 50 | self.temp = MLXArray(parameters.temperature) 51 | self.topP = MLXArray(parameters.topP) 52 | self.useTopP = parameters.topP > 0 && parameters.topP < 1 53 | self.useArgMax = parameters.temperature == 0 54 | } 55 | 56 | private let compiledTopPSampling: (MLXArray, MLXArray, MLXArray) -> MLXArray = { 57 | compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { 58 | logits, topP, temp in 59 | let probs = softmax(logits / temp, axis: -1) 60 | let sortedIndices = argSort(probs, axis: -1) 61 | 62 | // probs shape is [B,V] and after take it will be [1, B, V], so we squeeze it back to [B, V] 63 | let sortedProbs = take(probs, sortedIndices, axis: -1).squeezed(axis: 0) 64 | 65 | let cumulativeProbs = cumsum(sortedProbs, axis: -1) 66 | 67 | let topProbs = MLX.where( 68 | cumulativeProbs .> (1 - topP), sortedProbs, zeros(like: sortedProbs)) 69 | 70 | let sortedToken = categorical(log(topProbs)) 71 | return sortedIndices.squeezed(axis: 0)[sortedToken] 72 | } 73 | }() 74 | 75 | private let compiledCategorical: (MLXArray, MLXArray) -> MLXArray = { 76 | compile(inputs: [MLXRandom.globalState], outputs: [MLXRandom.globalState]) { logits, temp in 77 | categorical(logits * (1 / temp)) 78 | } 79 | }() 80 | 81 | private func topPSampling(logits: MLXArray) -> MLXArray { 82 | var logits = logits 83 | if logits.dtype == .bfloat16 { 84 | logits = logits.asType(.float32) 85 | } 86 | 87 | return compiledTopPSampling(logits, topP, temp) 88 | } 89 | 90 | func sample(logits: MLXArray) -> MLXArray { 91 | if useArgMax { 92 | return argMax(logits, axis: -1) 93 | } else { 94 | if useTopP { 95 | return topPSampling(logits: logits) 96 | } else { 97 | return compiledCategorical(logits, temp) 98 | } 99 | } 100 | } 101 | } 102 | 103 | /// Encapsulation of the repetitionPenalty 104 | struct RepetitionContext: Sendable { 105 | /// tokens in the repetition context sliding window 106 | var tokens: [Int] 107 | 108 | /// current write index into the tokens circular array 109 | var index = 0 110 | 111 | /// penalty factor for repeating tokens 112 | let repetitionPenalty: Float?
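// (Annotation: with penalty p > 1, logits of recently seen tokens are divided by p
// when positive and multiplied by p when negative (see applyRepetitionPenalty below),
// which lowers the chance of sampling them again.)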
113 | 114 | /// number of tokens to consider for repetition penalty 115 | let repetitionContextSize: Int 116 | 117 | init(prompt: MLXArray, parameters: GenerateParameters) { 118 | self.repetitionPenalty = parameters.repetitionPenalty 119 | self.repetitionContextSize = parameters.repetitionContextSize 120 | 121 | if repetitionPenalty != nil && repetitionContextSize > 1 { 122 | if prompt.shape[0] <= repetitionContextSize { 123 | self.tokens = prompt.asArray(Int.self) 124 | } else { 125 | self.tokens = prompt[(-repetitionContextSize)...].asArray(Int.self) 126 | } 127 | } else { 128 | self.tokens = [] 129 | } 130 | } 131 | 132 | func applyRepetitionPenalty(logits: MLXArray) -> MLXArray { 133 | if let penalty = repetitionPenalty, tokens.count > 0 { 134 | let indices = MLXArray(tokens.map { UInt32($0) }) 135 | var selectedLogits = logits[0..., indices] 136 | 137 | selectedLogits = MLX.where( 138 | selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty) 139 | 140 | logits[0..., indices] = selectedLogits 141 | return logits 142 | } 143 | 144 | return logits 145 | } 146 | 147 | mutating func append(token: MLXArray) { 148 | if repetitionPenalty != nil { 149 | if tokens.count >= repetitionContextSize { 150 | tokens[index] = token.item(Int.self) 151 | index = (index + 1) % repetitionContextSize 152 | } else { 153 | tokens.append(token.item(Int.self)) 154 | } 155 | } 156 | } 157 | } 158 | 159 | /// Synchronous generator of tokens. 160 | /// 161 | /// Tokens are integers that can be passed through a `Tokenizer` or ``StreamingDetokenizer`` to produce Strings. 162 | /// 163 | /// Port of `generate_step()` from https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/utils.py 164 | /// 165 | /// Note: this uses `asyncEval()` and there may be an async evaluation running after a call to `next()`. 166 | public struct TokenIterator: Sequence, IteratorProtocol { 167 | let model: LLMModel 168 | let parameters: GenerateParameters 169 | 170 | var y: MLXArray 171 | var cache: [KVCache] 172 | var repetitionContext: RepetitionContext 173 | let sampleContext: SampleContext 174 | 175 | public init(prompt: MLXArray, model: LLMModel, parameters: GenerateParameters) { 176 | self.model = model 177 | self.parameters = parameters 178 | self.y = prompt 179 | self.cache = model.newCache(parameters: parameters) 180 | 181 | self.repetitionContext = RepetitionContext(prompt: prompt, parameters: parameters) 182 | self.sampleContext = SampleContext(parameters: parameters) 183 | 184 | // prepare the prompt in chunks if larger than the prefill size 185 | while y.size > parameters.prefillStepSize { 186 | _ = model( 187 | y[.newAxis, ..<parameters.prefillStepSize], cache: cache.isEmpty ? nil : cache) 188 | eval(cache) 189 | y = y[parameters.prefillStepSize...] 190 | } 191 | } 192 | 200 | mutating func step(previous: MLXArray) -> MLXArray { 201 | var logits: MLXArray 202 | logits = model(previous[.newAxis], cache: cache.isEmpty ? nil : cache) 203 | 204 | logits = logits[0..., -1, 0...] 205 | logits = repetitionContext.applyRepetitionPenalty(logits: logits) 206 | 207 | let y = sampleContext.sample(logits: logits) 208 | 209 | repetitionContext.append(token: y) 210 | 211 | return y 212 | } 213 | 214 | mutating public func next() -> Int?
{ 215 | // save current value -- this will be returned 216 | let previousY = y 217 | 218 | // compute the next state and async eval the next token 219 | y = step(previous: previousY) 220 | asyncEval(y) 221 | 222 | return previousY.item(Int.self) 223 | } 224 | } 225 | 226 | public struct GenerateResult: Sendable { 227 | /// input tokens 228 | public let promptTokens: [Int] 229 | 230 | /// output tokens 231 | public let tokens: [Int] 232 | 233 | /// output text 234 | public let output: String 235 | 236 | /// time to process the prompt / generate the first token 237 | public let promptTime: TimeInterval 238 | 239 | /// time to generate the remaining tokens 240 | public let generateTime: TimeInterval 241 | 242 | public var promptTokensPerSecond: Double { 243 | Double(promptTokens.count) / promptTime 244 | } 245 | 246 | public var tokensPerSecond: Double { 247 | Double(tokens.count) / generateTime 248 | } 249 | 250 | public func summary() -> String { 251 | """ 252 | Prompt: \(promptTokens.count) tokens, \(promptTokensPerSecond.formatted()) tokens/s 253 | Generation: \(tokens.count) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s 254 | """ 255 | } 256 | } 257 | 258 | public enum GenerateDisposition: Sendable { 259 | case more 260 | case stop 261 | } 262 | 263 | /// Given prompt tokens generate text using the given model and parameters. 264 | /// 265 | /// - Parameters: 266 | /// - promptTokens: tokenized prompt 267 | /// - parameters: generation parameters 268 | /// - model: model to evaluate 269 | /// - tokenizer: tokenizer to convert tokens back into strings and recognize special tokens 270 | /// - extraEOSTokens: additional tokens to treat as end of sequence 271 | /// - didGenerate: visitor for the tokens as they are generated 272 | public func generate( 273 | promptTokens: [Int], parameters: GenerateParameters, model: LLMModel, tokenizer: Tokenizer, 274 | extraEOSTokens: Set<String>? = nil, 275 | didGenerate: ([Int]) -> GenerateDisposition 276 | ) -> GenerateResult { 277 | var start = Date.timeIntervalSinceReferenceDate 278 | var promptTime: TimeInterval = 0 279 | 280 | let additionalEOSTokenIds = Set( 281 | (extraEOSTokens ?? []) 282 | .compactMap { 283 | tokenizer.convertTokenToId($0) 284 | }) 285 | 286 | var tokens = [Int]() 287 | 288 | for token in TokenIterator( 289 | prompt: MLXArray(promptTokens), model: model, parameters: parameters) 290 | { 291 | // compute the timing for the prompt 292 | if tokens.isEmpty { 293 | let now = Date.timeIntervalSinceReferenceDate 294 | promptTime = now - start 295 | start = now 296 | } 297 | 298 | if token == tokenizer.unknownTokenId || token == tokenizer.eosTokenId 299 | || additionalEOSTokenIds.contains(token) 300 | { 301 | break 302 | } 303 | tokens.append(token) 304 | 305 | if didGenerate(tokens) == .stop { 306 | break 307 | } 308 | } 309 | 310 | let now = Date.timeIntervalSinceReferenceDate 311 | let generateTime = now - start 312 | 313 | // TokenIterator uses `asyncEval()` to keep the pipeline full. If the caller 314 | // exits the program right away, those tasks will still be executing and will 315 | // hit assertions as the mlx scheduler is torn down. Synchronize with the stream 316 | // to make sure it is complete.
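// A sketch of a call site (names like `model`, `tokenizer` and `prompt` are
// placeholders, e.g. obtained via `load(configuration:)`):
//
//     let result = generate(
//         promptTokens: tokenizer.encode(text: prompt),
//         parameters: GenerateParameters(temperature: 0.6),
//         model: model, tokenizer: tokenizer
//     ) { tokens in
//         tokens.count < 240 ? .more : .stop   // cap the output length
//     }
//     print(result.summary())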
317 | Stream().synchronize() 318 | 319 | return GenerateResult( 320 | promptTokens: promptTokens, tokens: tokens, 321 | output: tokenizer.decode(tokens: tokens), 322 | promptTime: promptTime, generateTime: generateTime) 323 | } 324 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Libraries/KVCache.swift: -------------------------------------------------------------------------------- 1 | // 2 | // KVCache.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 08.11.24. 6 | // 7 | 8 | import Foundation 9 | import MLX 10 | 11 | /// Interface for Key/Value cache for LLMs. 12 | /// 13 | /// See ``LLMModel/newCache(parameters:)-47tyu`` 14 | public protocol KVCache: Evaluatable { 15 | 16 | /// get the current offset 17 | var offset: Int { get } 18 | 19 | func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) 20 | } 21 | 22 | func createAdditiveCausalMask(n: Int, offset: Int) -> MLXArray { 23 | let rinds = MLXArray(Int32(0) ..< Int32(offset + n)) 24 | let linds = offset != 0 ? MLXArray(Int32(offset) ..< Int32(offset + n)) : rinds 25 | let mask = linds[0..., .newAxis] .< rinds[.newAxis] 26 | return mask * Float32(-1e9) 27 | } 28 | 29 | /// create an attention mask using the parameters from the KVCache. 30 | /// 31 | /// See also ``MultiHeadAttention/createAdditiveCausalMask(_:dtype:)`` -- same idea 32 | /// but doesn't honor the cache offset. 33 | public func createAttentionMask(h: MLXArray, cache: [KVCache]?) -> MLXArray? { 34 | let t = h.dim(1) 35 | if t > 1 { 36 | var offset = 0 37 | if let c = cache?.first { 38 | offset = c.offset 39 | } 40 | return createAdditiveCausalMask(n: t, offset: offset) 41 | .asType(h.dtype) 42 | } 43 | return nil 44 | } 45 | 46 | /// See https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/base.py#L11 47 | class KVCacheSimple: KVCache, Evaluatable { 48 | let kHeadDim: Int 49 | let vHeadDim: Int 50 | let kvHeads: Int 51 | 52 | var keys: MLXArray? 53 | var values: MLXArray? 54 | 55 | var offset = 0 56 | var step = 256 57 | 58 | init(headDim: IntOrPair, kvHeads: Int) { 59 | self.kHeadDim = headDim.first 60 | self.vHeadDim = headDim.second 61 | self.kvHeads = kvHeads 62 | } 63 | 64 | public func innerState() -> [MLXArray] { 65 | [self.keys, self.values].compactMap { $0 } 66 | } 67 | 68 | func update(keys: MLXArray, values: MLXArray) -> (MLXArray, MLXArray) { 69 | let previous = self.offset 70 | 71 | let reset = 72 | if let currentKeys = self.keys, (previous + keys.dim(2)) > currentKeys.dim(2) { 73 | true 74 | } else { 75 | self.keys == nil 76 | } 77 | if reset { 78 | let B = keys.dim(0) 79 | let nSteps = (step + keys.dim(2) - 1) / step 80 | let kShape = [B, kvHeads, nSteps * step, kHeadDim] 81 | let vShape = [B, kvHeads, nSteps * step, vHeadDim] 82 | let newK = MLXArray.zeros(kShape, dtype: keys.dtype) 83 | let newV = MLXArray.zeros(vShape, dtype: values.dtype) 84 | 85 | if var currentKeys = self.keys, var currentValues = self.values { 86 | if previous % step != 0 { 87 | currentKeys = currentKeys[.ellipsis, ..<previous, 0...] 88 | currentValues = currentValues[.ellipsis, ..<previous, 0...] 89 | } 90 | self.keys = concatenated([currentKeys, newK], axis: 2) 91 | self.values = concatenated([currentValues, newV], axis: 2) 92 | } else { 93 | self.keys = newK 94 | self.values = newV 95 | } 96 | } 97 | 98 | self.keys?[.ellipsis, previous ..< (previous + keys.dim(2)), 0...] = keys 99 | self.values?[.ellipsis, previous ..< (previous + values.dim(2)), 0...] = values 100 | 101 | self.offset += keys.dim(2) 102 | 103 | return ( 104 | self.keys![.ellipsis, ..<self.offset, 0...], 105 | self.values![.ellipsis, ..<self.offset, 0...] 106 | ) 107 | } 108 | } 109 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Libraries/LLMModel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // LLMModel.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 08.11.24. 6 | // 7 | 8 | import Foundation 9 | import Hub 10 | import MLX 11 | import MLXNN 12 | import Tokenizers 13 | 14 | /// Container for models that guarantees single threaded access. 15 | /// 16 | /// Wrap models used by e.g. the UI in a ModelContainer. Callers can access 17 | /// the model and tokenizer, e.g. to tokenize a prompt: 18 | /// 19 | /// ```swift 20 | /// let promptTokens = await modelContainer.perform { _, tokenizer in 21 | /// tokenizer.encode(text: prompt) 22 | /// } 23 | /// ``` 24 | /// 25 | /// or to generate text: 26 | /// 27 | /// ```swift 28 | /// let result = await modelContainer.perform { model, tokenizer in 29 | /// generate( 30 | /// promptTokens: promptTokens, parameters: generateParameters, 31 | /// model: model, tokenizer: tokenizer, 32 | /// extraEOSTokens: modelConfiguration.extraEOSTokens 33 | /// ) { tokens in 34 | /// ... 35 | /// } 36 | /// } 37 | /// ``` 38 | public actor ModelContainer { 39 | let model: LLMModel 40 | let tokenizer: Tokenizer 41 | 42 | public init(model: LLMModel, tokenizer: Tokenizer) { 43 | self.model = model 44 | self.tokenizer = tokenizer 45 | } 46 | 47 | /// build the model and tokenizer in an async context 48 | public init( 49 | hub: HubApi, modelDirectory: URL, 50 | configuration: ModelConfiguration 51 | ) async throws { 52 | self.model = try loadSynchronous(modelDirectory: modelDirectory) 53 | self.tokenizer = try await loadTokenizer( 54 | configuration: configuration, hub: hub) 55 | } 56 | 57 | /// Perform an action on the model and/or tokenizer. Callers _must_ 58 | /// eval any result before returning as `MLXArray` is not `Sendable`, 59 | /// e.g. convert a generation result into plain Swift types before 60 | /// leaving the closure. 61 | public func perform<R>(_ action: @Sendable (LLMModel, Tokenizer) throws -> R) rethrows -> R { 62 | try action(model, tokenizer) 63 | } 64 | } 65 | 66 | extension Module { 67 | 68 | /// Compute the number of parameters in a possibly quantized model 69 | public func numParameters() -> Int { 70 | return leafModules().flattenedValues().map { 71 | mod -> Int in 72 | if let qlin = mod as? QuantizedLinear { 73 | return qlin.scales.size * qlin.groupSize 74 | } else if let qemb = mod as?
QuantizedEmbedding { 75 | return qemb.scales.size * qemb.groupSize 76 | } else { 77 | return mod.parameters().flattenedValues().reduce( 78 | 0, 79 | { 80 | $0 + $1.size 81 | }) 82 | } 83 | }.reduce(0, +) 84 | } 85 | } 86 | 87 | /// Interface for all LLM Models 88 | public protocol LLMModel: Module { 89 | 90 | var vocabularySize: Int { get } 91 | 92 | func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray 93 | 94 | /// create a new array of ``KVCache`` -- automatic implementation if self 95 | /// implements ``KVCacheDimensionProvider`` 96 | func newCache(parameters: GenerateParameters) -> [KVCache] 97 | 98 | /// Optionally preprocess the weights and modify / remove values as needed. 99 | func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] 100 | } 101 | 102 | /// Optional protocol that can be implemented by ``LLMModel`` and will 103 | /// provide an automatic implementation of ``LLMModel/newCache(parameters:)`` 104 | public protocol KVCacheDimensionProvider { 105 | var kvHeads: [Int] { get } 106 | var headDim: IntOrPair { get } 107 | } 108 | 109 | extension LLMModel { 110 | public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { 111 | weights 112 | } 113 | } 114 | 115 | extension LLMModel where Self: KVCacheDimensionProvider { 116 | public func newCache(parameters: GenerateParameters) -> [KVCache] { 117 | kvHeads.map { n in 118 | KVCacheSimple(headDim: headDim, kvHeads: n) 119 | } 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Libraries/Llama.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Llama.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 08.11.24. 6 | // 7 | 8 | import Foundation 9 | import MLX 10 | import MLXFast 11 | import MLXNN 12 | 13 | // port of https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/llama.py 14 | 15 | func computeBaseFrequency( 16 | base: Float, dims: Int, ropeType: String, ropeScaling: [String: StringOrNumber]? 17 | ) 18 | -> Float 19 | { 20 | if ropeType != "llama3" { 21 | return base 22 | } 23 | 24 | guard let ropeScaling = ropeScaling else { 25 | return base 26 | } 27 | 28 | guard case .float(let factor) = ropeScaling["factor"], 29 | case .float(let lowFreqFactor) = ropeScaling["low_freq_factor"] ?? .float(1.0), 30 | case .float(let highFreqFactor) = ropeScaling["high_freq_factor"] ?? .float(4.0), 31 | case .float(let oldContextLen) = ropeScaling["original_max_position_embeddings"] 32 | ?? .float(8192) 33 | else { 34 | return base 35 | } 36 | 37 | let lowFreqWavelen = oldContextLen / lowFreqFactor 38 | let highFreqWavelen = oldContextLen / highFreqFactor 39 | 40 | let freqs = (0 ..< dims).compactMap { index -> Float? in 41 | if index % 2 == 0 { 42 | return pow(base, Float(index) / Float(dims)) 43 | } 44 | return nil 45 | } 46 | 47 | let newBaseFreqs = freqs.map { freq -> Float in 48 | let wavelen = 2 * .pi / freq 49 | let smooth = max( 50 | 0, min(1, (wavelen - highFreqWavelen) / (lowFreqWavelen - highFreqWavelen))) 51 | return freq * ((1 - smooth) * factor + smooth) 52 | } 53 | 54 | return newBaseFreqs.reduce(0, +) / Float(newBaseFreqs.count) 55 | } 56 | 57 | private class DynamicNTKScalingRoPE: Module { 58 | let dims: Int 59 | let maxPositionEmbeddings: Int? 60 | let traditional: Bool 61 | let base: Float 62 | var scale: Float 63 | let ropeType: String 64 | let ropeScaling: [String: StringOrNumber]? 
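// Dynamic NTK scaling: when seqLen exceeds maxPositionEmbeddings,
// callAsFunction below inflates the RoPE base by
// scale * (seqLen / maxPositionEmbeddings)^(dims / (dims - 2)) -- the usual
// dynamic-NTK adjustment -- so the positional frequencies stretch smoothly
// instead of running past the trained context window.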
65 | 66 | init( 67 | dims: Int, maxPositionEmbeddings: Int?, traditional: Bool = false, 68 | base: Float = 10000, scale: Float = 1.0, ropeType: String = "default", 69 | ropeScaling: [String: StringOrNumber]? = nil 70 | ) { 71 | self.dims = dims 72 | self.maxPositionEmbeddings = maxPositionEmbeddings 73 | self.traditional = traditional 74 | self.base = computeBaseFrequency( 75 | base: base, dims: dims, ropeType: ropeType, ropeScaling: ropeScaling) 76 | self.scale = scale 77 | self.ropeType = ropeType 78 | self.ropeScaling = ropeScaling 79 | } 80 | 81 | func callAsFunction(_ x: MLXArray, offset: Int = 0) -> MLXArray { 82 | let seqLen = x.dim(1) + offset 83 | var base = self.base 84 | if let maxPositionEmbeddings, seqLen > maxPositionEmbeddings { 85 | let factorAdjustment = Float(seqLen) / Float(maxPositionEmbeddings) - 1 86 | let dimensionRatio = Float(dims) / Float(Float(dims) - 2) 87 | let adjustedScale = scale * pow(1 + factorAdjustment, dimensionRatio) 88 | base *= adjustedScale 89 | } 90 | return MLXFast.RoPE( 91 | x, dimensions: dims, traditional: traditional, base: base, scale: scale, offset: offset) 92 | } 93 | } 94 | 95 | private class Attention: Module { 96 | 97 | let args: LlamaConfiguration 98 | let scale: Float 99 | 100 | @ModuleInfo(key: "q_proj") var wq: Linear 101 | @ModuleInfo(key: "k_proj") var wk: Linear 102 | @ModuleInfo(key: "v_proj") var wv: Linear 103 | @ModuleInfo(key: "o_proj") var wo: Linear 104 | 105 | let rope: DynamicNTKScalingRoPE 106 | 107 | init(_ args: LlamaConfiguration) { 108 | self.args = args 109 | 110 | let dim = args.hiddenSize 111 | let heads = args.attentionHeads 112 | let kvHeads = args.kvHeads 113 | 114 | let headDim = args.resolvedHeadDimensions 115 | self.scale = pow(Float(headDim), -0.5) 116 | 117 | self._wq.wrappedValue = Linear(dim, heads * headDim, bias: args.attentionBias) 118 | self._wk.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias) 119 | self._wv.wrappedValue = Linear(dim, kvHeads * headDim, bias: args.attentionBias) 120 | self._wo.wrappedValue = Linear(heads * headDim, dim, bias: args.attentionBias) 121 | 122 | self.rope = DynamicNTKScalingRoPE( 123 | dims: headDim, 124 | maxPositionEmbeddings: args.maxPositionEmbeddings, 125 | traditional: args.ropeTraditional, 126 | base: args.ropeTheta, 127 | scale: 1.0, 128 | ropeType: { 129 | if case .string(let value) = args.ropeScaling?["type"] { 130 | return value 131 | } else { 132 | return "default" 133 | } 134 | }(), 135 | ropeScaling: args.ropeScaling) 136 | } 137 | 138 | func callAsFunction( 139 | _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? 
140 | ) -> MLXArray { 141 | let (B, L) = (x.dim(0), x.dim(1)) 142 | 143 | var queries = wq(x) 144 | var keys = wk(x) 145 | var values = wv(x) 146 | 147 | // Prepare the queries, keys and values for the attention computation 148 | queries = queries.reshaped(B, L, args.attentionHeads, -1).transposed(0, 2, 1, 3) 149 | keys = keys.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) 150 | values = values.reshaped(B, L, args.kvHeads, -1).transposed(0, 2, 1, 3) 151 | 152 | if let cache { 153 | queries = rope(queries, offset: cache.offset) 154 | keys = rope(keys, offset: cache.offset) 155 | (keys, values) = cache.update(keys: keys, values: values) 156 | } else { 157 | queries = rope(queries) 158 | keys = rope(keys) 159 | } 160 | 161 | let output = MLXFast.scaledDotProductAttention( 162 | queries: queries, keys: keys, values: values, scale: scale, mask: mask 163 | ) 164 | .transposed(0, 2, 1, 3) 165 | .reshaped(B, L, -1) 166 | 167 | return wo(output) 168 | } 169 | } 170 | 171 | private class MLP: Module, UnaryLayer { 172 | 173 | @ModuleInfo(key: "gate_proj") var gate: Linear 174 | @ModuleInfo(key: "down_proj") var down: Linear 175 | @ModuleInfo(key: "up_proj") var up: Linear 176 | 177 | init(_ args: LlamaConfiguration) { 178 | self._gate.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias) 179 | self._down.wrappedValue = Linear(args.intermediateSize, args.hiddenSize, bias: args.mlpBias) 180 | self._up.wrappedValue = Linear(args.hiddenSize, args.intermediateSize, bias: args.mlpBias) 181 | } 182 | 183 | func callAsFunction(_ x: MLXArray) -> MLXArray { 184 | let activation = silu(gate(x)) 185 | return down(activation * up(x)) 186 | } 187 | } 188 | 189 | private class TransformerBlock: Module { 190 | @ModuleInfo(key: "self_attn") var attention: Attention 191 | @ModuleInfo(key: "mlp") var mlp: MLP 192 | 193 | @ModuleInfo(key: "input_layernorm") var inputLayerNorm: RMSNorm 194 | @ModuleInfo(key: "post_attention_layernorm") var postAttentionLayerNorm: RMSNorm 195 | 196 | init(_ args: LlamaConfiguration) { 197 | self._attention.wrappedValue = Attention(args) 198 | self._mlp.wrappedValue = MLP(args) 199 | self._inputLayerNorm.wrappedValue = RMSNorm( 200 | dimensions: args.hiddenSize, eps: args.rmsNormEps) 201 | self._postAttentionLayerNorm.wrappedValue = RMSNorm( 202 | dimensions: args.hiddenSize, eps: args.rmsNormEps) 203 | } 204 | 205 | func callAsFunction( 206 | _ x: MLXArray, mask: MLXArray? = nil, cache: KVCache? 207 | ) -> MLXArray { 208 | var r = attention(inputLayerNorm(x), mask: mask, cache: cache) 209 | let h = x + r 210 | r = mlp(postAttentionLayerNorm(h)) 211 | let out = h + r 212 | return out 213 | } 214 | } 215 | 216 | private class LlamaModelInner: Module { 217 | 218 | @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding 219 | 220 | let layers: [TransformerBlock] 221 | let norm: RMSNorm 222 | 223 | init(_ args: LlamaConfiguration) { 224 | precondition(args.vocabularySize > 0) 225 | 226 | self._embedTokens.wrappedValue = Embedding( 227 | embeddingCount: args.vocabularySize, dimensions: args.hiddenSize) 228 | 229 | self.layers = (0 ..< args.hiddenLayers).map { _ in TransformerBlock(args) } 230 | self.norm = RMSNorm(dimensions: args.hiddenSize, eps: args.rmsNormEps) 231 | } 232 | 233 | func callAsFunction(_ inputs: MLXArray, cache: [KVCache]? = nil) -> MLXArray { 234 | var h = embedTokens(inputs) 235 | 236 | let mask: MLXArray? 
= createAttentionMask(h: h, cache: cache) 237 | 238 | for (i, layer) in layers.enumerated() { 239 | h = layer(h, mask: mask, cache: cache?[i]) 240 | } 241 | 242 | return norm(h) 243 | } 244 | } 245 | 246 | public class LlamaModel: Module, LLMModel, KVCacheDimensionProvider { 247 | 248 | public let vocabularySize: Int 249 | public let kvHeads: [Int] 250 | public let headDim: IntOrPair 251 | 252 | fileprivate let model: LlamaModelInner 253 | 254 | @ModuleInfo(key: "lm_head") var lmHead: Linear? 255 | 256 | public init(_ args: LlamaConfiguration) { 257 | self.vocabularySize = args.vocabularySize 258 | self.kvHeads = (0 ..< args.hiddenLayers).map { _ in args.kvHeads } 259 | self.headDim = .init(args.resolvedHeadDimensions) 260 | self.model = LlamaModelInner(args) 261 | if !args.tieWordEmbeddings { 262 | self._lmHead.wrappedValue = Linear(args.hiddenSize, args.vocabularySize, bias: false) 263 | } 264 | } 265 | 266 | public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray { 267 | let out = model(inputs, cache: cache) 268 | if let lmHead { 269 | return lmHead(out) 270 | } else { 271 | return model.embedTokens.asLinear(out) 272 | } 273 | } 274 | 275 | public func sanitize(weights: [String: MLXArray]) -> [String: MLXArray] { 276 | // Remove unused precomputed rotary frequencies 277 | weights.filter { 278 | !$0.key.contains("self_attn.rotary_emb.inv_freq") 279 | } 280 | } 281 | } 282 | 283 | public struct LlamaConfiguration: Codable, Sendable { 284 | 285 | var hiddenSize: Int 286 | var hiddenLayers: Int 287 | var intermediateSize: Int 288 | var attentionHeads: Int 289 | var headDimensions: Int? 290 | var rmsNormEps: Float 291 | var vocabularySize: Int 292 | var kvHeads: Int 293 | var maxPositionEmbeddings: Int? 294 | var ropeTheta: Float = 10_000 295 | var ropeTraditional: Bool = false 296 | var ropeScaling: [String: StringOrNumber]? 297 | var tieWordEmbeddings: Bool = true 298 | var attentionBias: Bool = false 299 | var mlpBias: Bool = false 300 | 301 | var resolvedHeadDimensions: Int { 302 | headDimensions ?? 
(hiddenSize / attentionHeads) 303 | } 304 | 305 | enum CodingKeys: String, CodingKey { 306 | case hiddenSize = "hidden_size" 307 | case hiddenLayers = "num_hidden_layers" 308 | case intermediateSize = "intermediate_size" 309 | case attentionHeads = "num_attention_heads" 310 | case headDimensions = "head_dim" 311 | case rmsNormEps = "rms_norm_eps" 312 | case vocabularySize = "vocab_size" 313 | case kvHeads = "num_key_value_heads" 314 | case maxPositionEmbeddings = "max_position_embeddings" 315 | case ropeTheta = "rope_theta" 316 | case ropeTraditional = "rope_traditional" 317 | case ropeScaling = "rope_scaling" 318 | case tieWordEmbeddings = "tie_word_embeddings" 319 | case attentionBias = "attention_bias" 320 | case mlpBias = "mlp_bias" 321 | } 322 | 323 | public init(from decoder: Decoder) throws { 324 | let container = try decoder.container(keyedBy: CodingKeys.self) 325 | 326 | hiddenSize = try container.decode(Int.self, forKey: .hiddenSize) 327 | hiddenLayers = try container.decode(Int.self, forKey: .hiddenLayers) 328 | intermediateSize = try container.decode(Int.self, forKey: .intermediateSize) 329 | attentionHeads = try container.decode(Int.self, forKey: .attentionHeads) 330 | headDimensions = try container.decodeIfPresent(Int.self, forKey: .headDimensions) 331 | rmsNormEps = try container.decode(Float.self, forKey: .rmsNormEps) 332 | vocabularySize = try container.decode(Int.self, forKey: .vocabularySize) 333 | kvHeads = try container.decodeIfPresent(Int.self, forKey: .kvHeads) ?? attentionHeads 334 | maxPositionEmbeddings = try container.decodeIfPresent( 335 | Int.self, forKey: .maxPositionEmbeddings) 336 | if let ropeTheta = try container.decodeIfPresent(Float.self, forKey: .ropeTheta) { 337 | self.ropeTheta = ropeTheta 338 | } 339 | if let ropeTraditional = try container.decodeIfPresent(Bool.self, forKey: .ropeTraditional) 340 | { 341 | self.ropeTraditional = ropeTraditional 342 | } 343 | ropeScaling = try container.decodeIfPresent( 344 | [String: StringOrNumber].self, forKey: .ropeScaling) 345 | if let tieWordEmbeddings = try container.decodeIfPresent( 346 | Bool.self, forKey: .tieWordEmbeddings) 347 | { 348 | self.tieWordEmbeddings = tieWordEmbeddings 349 | } 350 | if let attentionBias = try container.decodeIfPresent(Bool.self, forKey: .attentionBias) { 351 | self.attentionBias = attentionBias 352 | } 353 | if let mlpBias = try container.decodeIfPresent(Bool.self, forKey: .mlpBias) { 354 | self.mlpBias = mlpBias 355 | } 356 | 357 | if let ropeScaling { 358 | if ropeScaling["factor"] == nil { 359 | throw DecodingError.dataCorruptedError( 360 | forKey: .ropeScaling, in: container, 361 | debugDescription: "rope_scaling must contain 'factor'") 362 | } 363 | if let ropeType = ropeScaling["type"] ?? 
ropeScaling["rope_type"] { 364 | if case .string = ropeType { 365 | let options = [ 366 | StringOrNumber.string("linear"), StringOrNumber.string("dynamic"), 367 | StringOrNumber.string("llama3"), 368 | ] 369 | if !options.contains(ropeType) { 370 | throw DecodingError.dataCorruptedError( 371 | forKey: .ropeScaling, in: container, 372 | debugDescription: 373 | "rope_scaling 'type' currently only supports 'linear', 'dynamic', or 'llama3'" 374 | ) 375 | } 376 | } 377 | } else { 378 | throw DecodingError.dataCorruptedError( 379 | forKey: .ropeScaling, in: container, 380 | debugDescription: "rope_scaling must contain either 'type' or 'rope_type'") 381 | } 382 | } 383 | } 384 | } 385 | 386 | // MARK: - LoRA 387 | 388 | extension LlamaModel: LoRAModel { 389 | public func loraLinearLayers() -> LoRALinearLayers { 390 | model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } 391 | } 392 | } 393 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Libraries/Load.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Load.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 08.11.24. 6 | // 7 | 8 | import Foundation 9 | @preconcurrency import Hub 10 | import MLX 11 | import MLXNN 12 | import MLXRandom 13 | import Tokenizers 14 | 15 | struct LLMError: Error { 16 | let message: String 17 | } 18 | 19 | func prepareModelDirectory( 20 | hub: HubApi, configuration: ModelConfiguration, 21 | progressHandler: @Sendable @escaping (Progress) -> Void 22 | ) async throws -> URL { 23 | do { 24 | switch configuration.id { 25 | case .id(let id): 26 | // download the model weights 27 | let repo = Hub.Repo(id: id) 28 | let modelFiles = ["*.safetensors", "config.json"] 29 | return try await hub.snapshot( 30 | from: repo, matching: modelFiles, progressHandler: progressHandler) 31 | 32 | case .directory(let directory): 33 | return directory 34 | } 35 | } catch Hub.HubClientError.authorizationRequired { 36 | // an authorizationRequired means (typically) that the named repo doesn't exist on 37 | // on the server so retry with local only configuration 38 | return configuration.modelDirectory(hub: hub) 39 | } catch { 40 | let nserror = error as NSError 41 | if nserror.domain == NSURLErrorDomain && nserror.code == NSURLErrorNotConnectedToInternet { 42 | // Error Domain=NSURLErrorDomain Code=-1009 "The Internet connection appears to be offline." 
43 | // fall back to the local directory 44 | return configuration.modelDirectory(hub: hub) 45 | } else { 46 | throw error 47 | } 48 | } 49 | } 50 | 51 | /// Load and return the model and tokenizer 52 | public func load( 53 | hub: HubApi = HubApi(), configuration: ModelConfiguration, 54 | progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } 55 | ) async throws -> (LLMModel, Tokenizer) { 56 | let modelDirectory = try await prepareModelDirectory( 57 | hub: hub, configuration: configuration, progressHandler: progressHandler) 58 | let model = try loadSynchronous(modelDirectory: modelDirectory) 59 | let tokenizer = try await loadTokenizer(configuration: configuration, hub: hub) 60 | 61 | return (model, tokenizer) 62 | } 63 | 64 | func loadSynchronous(modelDirectory: URL) throws -> LLMModel { 65 | // create the model (no weights loaded) 66 | let configurationURL = modelDirectory.appending(component: "config.json") 67 | let baseConfig = try JSONDecoder().decode( 68 | BaseConfiguration.self, from: Data(contentsOf: configurationURL)) 69 | 70 | let model = try baseConfig.modelType.createModel(configuration: configurationURL) 71 | 72 | // load the weights 73 | var weights = [String: MLXArray]() 74 | let enumerator = FileManager.default.enumerator( 75 | at: modelDirectory, includingPropertiesForKeys: nil)! 76 | for case let url as URL in enumerator { 77 | if url.pathExtension == "safetensors" { 78 | let w = try loadArrays(url: url) 79 | for (key, value) in w { 80 | weights[key] = value 81 | } 82 | } 83 | } 84 | 85 | // per-model cleanup 86 | weights = model.sanitize(weights: weights) 87 | 88 | // quantize if needed 89 | if let quantization = baseConfig.quantization { 90 | quantize(model: model, groupSize: quantization.groupSize, bits: quantization.bits) { 91 | path, module in 92 | weights["\(path).scales"] != nil 93 | } 94 | } 95 | 96 | // apply the loaded weights 97 | let parameters = ModuleParameters.unflattened(weights) 98 | try model.update(parameters: parameters, verify: [.all]) 99 | 100 | eval(model) 101 | 102 | return model 103 | } 104 | 105 | /// Load and return the model and tokenizer wrapped in a ``ModelContainer`` (provides 106 | /// thread safe access). 107 | public func loadModelContainer( 108 | hub: HubApi = HubApi(), 109 | configuration: ModelConfiguration, 110 | progressHandler: @Sendable @escaping (Progress) -> Void = { _ in } 111 | ) async throws -> ModelContainer { 112 | let modelDirectory = try await prepareModelDirectory( 113 | hub: hub, 114 | configuration: configuration, 115 | progressHandler: progressHandler 116 | ) 117 | 118 | return try await ModelContainer( 119 | hub: hub, 120 | modelDirectory: modelDirectory, 121 | configuration: configuration 122 | ) 123 | } 124 | -------------------------------------------------------------------------------- /MyAppleIntelligence/Libraries/Lora.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Lora.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 08.11.24. 6 | // 7 | 8 | import Foundation 9 | import MLX 10 | import MLXNN 11 | import MLXOptimizers 12 | import MLXRandom 13 | import Tokenizers 14 | 15 | /// Layers to apply LoRA adapters to. 16 | /// 17 | /// This is the value returned by ``LoRAModel/loraLinearLayers()``. 18 | public typealias LoRALinearLayers = [(Module, [String])] 19 | 20 | public protocol LoRAModel { 21 | /// Return the layers and keys to apply LoRA adapters to. 
22 | /// 23 | /// For example this might apply the adapters to the `q` and `v` projections in the 24 | /// Attention layers: 25 | /// 26 | /// ```swift 27 | /// model.layers.map { ($0.attention, ["q_proj", "v_proj"]) } 28 | /// ``` 29 | /// 30 | /// It is not required that a model implement this protocol to have LoRA adapters applied, but 31 | /// the command line driver example uses this to produce the ``LoRALinearLayers``. 32 | /// 33 | /// ### See Also 34 | /// - ``LoRATrain/convert(model:layers:)`` 35 | func loraLinearLayers() -> LoRALinearLayers 36 | } 37 | 38 | /// Protocol for LoRA implementations that provides a method for converting back to a `Linear` 39 | /// (or subtype). 40 | /// 41 | /// This is normally called via ``LoRATrain/fuse(model:layers:deQuantize:)`` 42 | public protocol LoRAConvertToLinear { 43 | func toLinear(deQuantize: Bool) -> Linear 44 | } 45 | 46 | /// Implementation of LoRA `Linear` replacement layer. 47 | /// 48 | /// This layer implements the LoRA capabilities for `Linear` layers, specifically: 49 | /// 50 | /// - converting `Linear` or `QuantizedLinear` layers to ``LoRALinear`` / ``QLoRALinear`` 51 | /// - converting ``LoRALinear`` back to `Linear` or `QuantizedLinear` (``LoRAConvertToLinear``) 52 | /// - implementing the LoRA evaluation 53 | /// 54 | /// ``QLoRALinear`` is the equivalent class for `QuantizedLinear`. 55 | /// 56 | /// This is not typically used directly -- ``LoRATrain/convert(model:layers:)`` is used to 57 | /// add the adapter layers to a given model. 58 | /// 59 | /// ### See Also 60 | /// - [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) 61 | /// - [QLoRA: Efficient Finetuning of Quantized LLMs](https://arxiv.org/abs/2305.14314) 62 | /// - ``QLoRALinear`` 63 | /// - ``LoRATrain/convert(model:layers:)`` 64 | /// - ``LoRATrain/fuse(model:layers:deQuantize:)`` 65 | public class LoRALinear: Linear, LoRAConvertToLinear { 66 | 67 | let scale: Float 68 | 69 | @ParameterInfo(key: "lora_a") var loraA: MLXArray 70 | @ParameterInfo(key: "lora_b") var loraB: MLXArray 71 | 72 | required public init( 73 | _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false, 74 | scale: Float = 20.0, linear: Linear 75 | ) { 76 | // Scale for low-rank update 77 | self.scale = scale 78 | 79 | // Low rank lora weights 80 | let loraScale = 1 / sqrt(Float(inputDimensions)) 81 | self._loraA.wrappedValue = MLXRandom.uniform( 82 | low: -loraScale, high: loraScale, [inputDimensions, rank]) 83 | self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions]) 84 | 85 | super.init(weight: linear.weight, bias: linear.bias) 86 | 87 | freeze() 88 | } 89 | 90 | /// Freeze all parameters except the lora parameters 91 | public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) 92 | throws 93 | { 94 | // realize the keys and omit the lora parameters 95 | let keys = 96 | (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 }) 97 | .filter { 98 | $0 != "lora_a" && $0 != "lora_b" 99 | } 100 | try super.freeze(recursive: recursive, keys: keys, strict: strict) 101 | } 102 | 103 | /// Convert a `Linear` or `QuantizedLinear` layer into a new `Linear` layer 104 | /// that implements the `LoRA` adapter. 105 | /// 106 | /// This is typically called via ``LoRATrain/convert(model:layers:)``.
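/// A direct use looks like this (a sketch; `layer` stands for an existing
/// `Linear` from the model):
///
/// ```swift
/// let adapted = LoRALinear.from(linear: layer, rank: 8)
/// ```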
107 | /// 108 | /// ### See Also 109 | /// - ``LoRATrain/convert(model:layers:)`` 110 | /// - ``QLoRALinear/from(linear:rank:)`` 111 | public static func from(linear: Linear, rank: Int = 8) -> Linear { 112 | if let linear = linear as? QuantizedLinear { 113 | return QLoRALinear.from(linear: linear, rank: rank) 114 | } 115 | let (outputDimensions, inputDimensions) = linear.shape 116 | return LoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear) 117 | } 118 | 119 | /// Convert back into a fused `Linear` layer. 120 | /// 121 | /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``. 122 | /// 123 | /// ### See Also 124 | /// - ``LoRATrain/fuse(model:layers:deQuantize:)`` 125 | /// - ``LoRAConvertToLinear`` 126 | /// - ``QLoRALinear/toLinear(deQuantize:)`` 127 | public func toLinear(deQuantize: Bool = false) -> Linear { 128 | let dtype = weight.dtype 129 | let loraB = (scale * loraB.T).asType(dtype) 130 | let loraA = loraA.T.asType(dtype) 131 | return Linear(weight: weight + matmul(loraB, loraA), bias: bias) 132 | } 133 | 134 | public override func callAsFunction(_ x: MLXArray) -> MLXArray { 135 | let y = super.callAsFunction(x.asType(weight.dtype)) 136 | let z = matmul(matmul(x, self.loraA), self.loraB) 137 | return y + scale * z 138 | } 139 | } 140 | 141 | /// Implementation of LoRA `QuantizedLinear` replacement layer. 142 | /// 143 | /// See ``LoRALinear`` (equivalent class for `Linear` layers) for more information. 144 | public class QLoRALinear: QuantizedLinear, LoRAConvertToLinear { 145 | 146 | let scale: Float 147 | 148 | @ParameterInfo(key: "lora_a") var loraA: MLXArray 149 | @ParameterInfo(key: "lora_b") var loraB: MLXArray 150 | 151 | required public init( 152 | _ inputDimensions: Int, _ outputDimensions: Int, rank: Int = 8, bias: Bool = false, 153 | scale: Float = 20.0, linear: QuantizedLinear 154 | ) { 155 | 156 | // Scale for low-rank update 157 | self.scale = scale 158 | 159 | // Low rank lora weights 160 | let loraScale = 1 / sqrt(Float(inputDimensions)) 161 | self._loraA.wrappedValue = MLXRandom.uniform( 162 | low: -loraScale, high: loraScale, [inputDimensions, rank]) 163 | self._loraB.wrappedValue = MLXArray.zeros([rank, outputDimensions]) 164 | 165 | super.init( 166 | weight: linear.weight, bias: linear.bias, scales: linear.scales, biases: linear.biases, 167 | groupSize: linear.groupSize, bits: linear.bits) 168 | 169 | // start frozen except for the lora keys 170 | freeze() 171 | } 172 | 173 | /// Freeze all parameters except the lora parameters 174 | public override func freeze(recursive: Bool = true, keys: [String]? = nil, strict: Bool = false) 175 | throws 176 | { 177 | // realize the keys and omit the lora parameters 178 | let keys = 179 | (keys ?? self.filterMap(filter: Self.filterLocalParameters).flattened().map { $0.0 }) 180 | .filter { 181 | $0 != "lora_a" && $0 != "lora_b" 182 | } 183 | try super.freeze(recursive: recursive, keys: keys, strict: strict) 184 | } 185 | 186 | /// Convert a `QuantizedLinear` layer into a new `Linear` layer 187 | /// that implements the `LoRA` adapter. 188 | /// 189 | /// This is typically called via ``LoRATrain/convert(model:layers:)``. 
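/// Note: the `weight` stored by a `QuantizedLinear` is packed, so the
/// original input dimension is recovered below as `shape.1 * 32 / bits`
/// before building the adapter.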
190 | /// 191 | /// ### See Also 192 | /// - ``LoRATrain/convert(model:layers:)`` 193 | /// - ``LoRALinear/from(linear:rank:)`` 194 | public static func from(linear: QuantizedLinear, rank: Int = 8) -> Linear { 195 | var (outputDimensions, inputDimensions) = linear.shape 196 | inputDimensions = inputDimensions * 32 / linear.bits 197 | return QLoRALinear(inputDimensions, outputDimensions, rank: rank, linear: linear) 198 | } 199 | 200 | /// Convert back into a fused `QuantizedLinear` layer. 201 | /// 202 | /// This is typically called via ``LoRATrain/fuse(model:layers:deQuantize:)``. 203 | /// 204 | /// ### See Also 205 | /// - ``LoRATrain/fuse(model:layers:deQuantize:)`` 206 | public func toLinear(deQuantize: Bool = false) -> Linear { 207 | // convert back into full weights 208 | let weight = dequantized( 209 | weight, scales: scales, biases: biases, groupSize: groupSize, bits: bits) 210 | 211 | let loraB = (scale * loraB.T).asType(.float16) 212 | let loraA = loraA.T.asType(.float16) 213 | 214 | // convert back into quantized 215 | return QuantizedLinear( 216 | weight: weight + matmul(loraB, loraA), bias: bias, groupSize: groupSize, bits: bits) 217 | } 218 | 219 | public override func callAsFunction(_ x: MLXArray) -> MLXArray { 220 | let y = super.callAsFunction(x.asType(scales.dtype)) 221 | let z = matmul(matmul(x, self.loraA), self.loraB) 222 | return y + scale * z 223 | } 224 | } 225 | 226 | /// Equivalent to `lora.py/iterate_batches()`. Used internally by ``LoRATrain``. 227 | struct LoRABatchIterator: Sequence, IteratorProtocol { 228 | 229 | let dataset: [String] 230 | let batchSize: Int 231 | let tokenizer: Tokenizer 232 | 233 | let train: Bool 234 | 235 | var indices: [Int] 236 | var index = 0 237 | 238 | public init(dataset: [String], tokenizer: Tokenizer, batchSize: Int, train: Bool) { 239 | self.dataset = dataset 240 | self.batchSize = batchSize 241 | self.tokenizer = tokenizer 242 | self.train = train 243 | 244 | self.indices = Array(0 ..< dataset.count) 245 | if train { 246 | indices.shuffle() 247 | } 248 | } 249 | 250 | mutating public func next() -> (MLXArray, MLXArray, MLXArray)? { 251 | if index >= indices.count { 252 | if !train { 253 | return nil 254 | } 255 | 256 | indices.shuffle() 257 | index = 0 258 | } 259 | 260 | let endIndex = Swift.min(index + batchSize, indices.count) 261 | 262 | let batch = (index ..< endIndex) 263 | .map { tokenizer.encode(text: dataset[indices[$0]]) } 264 | let lengths = batch.map { $0.count } 265 | let maxLength = lengths.max() ?? 0 266 | 267 | if maxLength > 2048 { 268 | print( 269 | """ 270 | [WARNING] Some sequences are longer than 2048 tokens. 271 | Consider pre-splitting your data to save memory. 272 | """) 273 | } 274 | 275 | // pad to the max length 276 | let batchArray = MLXArray.zeros([lengths.count, maxLength], type: Int32.self) 277 | for (j, (b, l)) in zip(batch, lengths).enumerated() { 278 | batchArray[j, 0 ..< l] = MLXArray(b) 279 | } 280 | 281 | index = endIndex 282 | 283 | return (batchArray[0..., .stride(to: -1)], batchArray[0..., 1...], MLXArray(lengths)) 284 | } 285 | 286 | } 287 | 288 | /// Collection of functions for adding LoRA adapters to an LLM model, training, fusing and saving/loading weights. 
289 | /// 290 | /// The typical flow for training is: 291 | /// 292 | /// ```swift 293 | /// // load the base model and tokenizer 294 | /// let (model, tokenizer) = try await LLM.load(configuration: ModelConfiguration.mistral7B4bit) 295 | /// 296 | /// // add LoRALinear adapter layers 297 | /// LoRATrain.convert(model: model, layers: Array(model.loraLinearLayers().suffix(4))) 298 | /// 299 | /// // optionally load LoRA weights 300 | /// try LoRATrain.loadLoRAWeights(model: model, url: ...) 301 | /// 302 | /// // load the train/validation data 303 | /// let train = try loadLoRAData(directory: data, name: "train") 304 | /// let valid = try loadLoRAData(directory: data, name: "valid") 305 | /// 306 | /// // train 307 | /// let optimizer = Adam(learningRate: 1e-5) 308 | /// try await LoRATrain.train( 309 | /// model: model, train: train, validate: valid, optimizer: optimizer, tokenizer: tokenizer, 310 | /// parameters: LoRATrain.Parameters() 311 | /// ) { progress in 312 | /// print(progress) 313 | /// return .more 314 | /// } 315 | /// ``` 316 | /// 317 | /// At this point the model will be trained and you could do one of the following: 318 | /// 319 | /// - ``saveLoRAWeights(model:url:)`` -- write the LoRA weights to a file 320 | /// - ``fuse(model:layers:deQuantize:)`` -- fuse the LoRA weights and convert back into the original model 321 | /// architecture. These weights can be saved and reloaded with normal model handling code. 322 | /// - ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)`` -- compute the test loss 323 | /// against a test dataset 324 | /// - use the in-memory model as a normal `LLMModel` and evaluate a prompt 325 | /// 326 | public enum LoRATrain { 327 | 328 | public typealias LoraLossFunction = (Module, MLXArray, MLXArray, MLXArray) -> ( 329 | MLXArray, MLXArray 330 | ) 331 | 332 | /// LoRA training parameters 333 | public struct Parameters: Sendable { 334 | /// number of prompts to evaluate per iteration 335 | public var batchSize = 4 336 | 337 | /// number of iterations to train for 338 | public var iterations = 1000 339 | 340 | /// number of training steps between loss reporting 341 | public var stepsPerReport = 10 342 | 343 | /// number of steps between validations 344 | public var stepsPerEval = 100 345 | 346 | /// number of validation batches, `0` uses the entire validation set 347 | public var validationBatches = 10 348 | 349 | /// save the model every N iterations 350 | public var saveEvery = 100 351 | 352 | /// save path for the adapter `.safetensors` 353 | public var adapterURL: URL? 354 | 355 | public init( 356 | batchSize: Int = 4, iterations: Int = 1000, stepsPerReport: Int = 10, 357 | stepsPerEval: Int = 100, validationBatches: Int = 10, saveEvery: Int = 100, 358 | adapterURL: URL? = nil 359 | ) { 360 | self.batchSize = batchSize 361 | self.iterations = iterations 362 | self.stepsPerReport = stepsPerReport 363 | self.stepsPerEval = stepsPerEval 364 | self.validationBatches = validationBatches 365 | self.saveEvery = saveEvery 366 | self.adapterURL = adapterURL 367 | } 368 | } 369 | 370 | /// Freeze the model layers and replace the indicated modules (Linear) that should be 371 | /// converted to ``LoRALinear`` and remain trainable.
372 | /// 373 | /// Once a model has had the LoRA adapters applied, adapter weights can be loaded 374 | /// (if available): 375 | /// 376 | /// ```swift 377 | /// try LoRATrain.loadLoRAWeights(model: model, url: args.adapter) 378 | /// ``` 379 | /// 380 | /// At this point the model is ready for one or more of the following: 381 | /// 382 | /// - training with ``train(model:train:validate:optimizer:loss:tokenizer:parameters:progress:)`` 383 | /// - loss evaluation with ``evaluate(model:dataset:loss:tokenizer:batchSize:batchCount:)`` 384 | /// - fusing with ``fuse(model:layers:deQuantize:)`` 385 | /// - text generation with ``generate(promptTokens:parameters:model:tokenizer:extraEOSTokens:didGenerate:)`` 386 | /// - note that this is just using normal model text generation 387 | /// 388 | /// - Parameters: 389 | /// - model: model to convert 390 | /// - layers: the layers and keys to convert, e.g. from ``LoRAModel/loraLinearLayers()`` 391 | public static func convert(model: Module, layers: LoRALinearLayers) { 392 | model.freeze() 393 | 394 | for (layer, keys) in layers { 395 | var update = ModuleChildren() 396 | let children = layer.children() 397 | for key in keys { 398 | if let item = children[key], case .value(let child) = item { 399 | if let linear = child as? Linear { 400 | update[key] = .value(LoRALinear.from(linear: linear)) 401 | } else { 402 | print("\(key) on \(layer) is not Linear") 403 | } 404 | } else { 405 | print("failed to find key \(key) on \(layer)") 406 | } 407 | } 408 | layer.update(modules: update) 409 | } 410 | } 411 | 412 | /// Fuses the LoRA adapters back into the model weights. 413 | /// 414 | /// This produces a model in the original format with `Linear` or `QuantizedLinear` layer 415 | /// weights that incorporate the LoRA adapter. 416 | /// 417 | /// - Parameters: 418 | /// - model: model to convert 419 | /// - deQuantize: if `true` will convert `QuantizedLinear` back into `Linear` 420 | public static func fuse(model: Module, layers: LoRALinearLayers, deQuantize: Bool = false) { 421 | for (layer, keys) in layers { 422 | var update = ModuleChildren() 423 | let children = layer.children() 424 | for key in keys { 425 | if let item = children[key], case .value(let child) = item { 426 | if let lora = child as? LoRAConvertToLinear { 427 | update[key] = .value(lora.toLinear(deQuantize: deQuantize)) 428 | } 429 | } 430 | } 431 | if !update.isEmpty { 432 | layer.update(modules: update) 433 | } 434 | } 435 | } 436 | 437 | public static func loss(model: Module, inputs: MLXArray, targets: MLXArray, lengths: MLXArray) 438 | -> ( 439 | MLXArray, MLXArray 440 | ) 441 | { 442 | // def loss(model, inputs, targets, lengths): 443 | 444 | // run model on inputs 445 | let model = model as! LLMModel 446 | let logits = model(inputs, cache: nil).asType(.float32) 447 | 448 | // mask padding tokens 449 | let lengthMask = MLXArray(0 ..< inputs.dim(1))[.newAxis, 0...] .< lengths[0..., .newAxis] 450 | 451 | // calculate the loss 452 | let ntoks = lengthMask.sum() 453 | let ce = (crossEntropy(logits: logits, targets: targets) * lengthMask).sum() / ntoks 454 | return (ce, ntoks) 455 | } 456 | 457 | /// Evaluate the model and dataset and return the loss over the entire dataset.
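/// A sketch of a call site (assuming a loaded `model` and `tokenizer` and a
/// `test` dataset, e.g. from ``loadLoRAData(directory:name:)``):
///
/// ```swift
/// let testLoss = LoRATrain.evaluate(
///     model: model, dataset: test, tokenizer: tokenizer,
///     batchSize: 4, batchCount: 0)
/// ```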
458 | /// 459 | /// - Parameters: 460 | /// - model: the model to evaluate 461 | /// - dataset: the dataset 462 | /// - loss: loss function 463 | /// - tokenizer: tokenizer 464 | /// - batchSize: number of items from the dataset to evaluate at once 465 | /// - batchCount: number of batch elements to evaluate, 0 for all 466 | /// - Returns: the loss over the evaluated data 467 | /// 468 | /// ### See Also 469 | /// - ``loadLoRAData(directory:name:)`` 470 | public static func evaluate( 471 | model: Module, dataset: [String], loss: LoraLossFunction = loss, tokenizer: Tokenizer, 472 | batchSize: Int, batchCount: Int 473 | ) -> Float { 474 | var allLosses = [Float]() 475 | var tokenCount = 0 476 | 477 | for (iteration, (inputs, targets, lengths)) in LoRABatchIterator( 478 | dataset: dataset, tokenizer: tokenizer, batchSize: batchSize, train: false 479 | ).enumerated() { 480 | let (losses, tokens) = loss(model, inputs, targets, lengths) 481 | allLosses.append((losses * tokens).item(Float.self)) 482 | tokenCount += tokens.item(Int.self) 483 | 484 | if batchCount != 0 && iteration + 1 >= batchCount { 485 | break 486 | } 487 | } 488 | 489 | return (sum(MLXArray(allLosses), stream: .cpu) / tokenCount).item(Float.self) 490 | } 491 | 492 | /// Given a model with LoRA adapters applied, load adapter weights from a `.safetensors` file. 493 | /// 494 | /// ### See Also 495 | /// - ``convert(model:layers:)`` 496 | /// - ``saveLoRAWeights(model:url:)`` 497 | public static func loadLoRAWeights(model: Module, url: URL) throws { 498 | let weights = try ModuleParameters.unflattened(loadArrays(url: url)) 499 | try model.update(parameters: weights, verify: .noUnusedKeys) 500 | eval(model) 501 | } 502 | 503 | /// Given a model with LoRA adapters applied, write adapter weights to a `.safetensors` file. 504 | /// 505 | /// ### See Also 506 | /// - ``convert(model:layers:)`` 507 | /// - ``loadLoRAWeights(model:url:)`` 508 | public static func saveLoRAWeights(model: Module, url: URL) throws { 509 | let parameters = Dictionary( 510 | uniqueKeysWithValues: model.trainableParameters().flattened()) 511 | try save(arrays: parameters, url: url) 512 | } 513 | 514 | public enum Progress: CustomStringConvertible, Sendable { 515 | case train( 516 | iteration: Int, trainingLoss: Float, iterationsPerSecond: Double, 517 | tokensPerSecond: Double) 518 | case validation(iteration: Int, validationLoss: Float, validationTime: Double) 519 | case save(iteration: Int, url: URL) 520 | 521 | public var description: String { 522 | switch self { 523 | case .train( 524 | let iteration, let trainingLoss, let iterationsPerSecond, let tokensPerSecond): 525 | "Iteration \(iteration + 1): training loss \(trainingLoss.formatted()), " 526 | + "iterations/sec \(iterationsPerSecond.formatted()), " 527 | + "tokens/sec \(tokensPerSecond.formatted())" 528 | case .validation(let iteration, let validationLoss, let validationTime): 529 | "Iteration \(iteration + 1): " 530 | + "validation loss \(validationLoss.formatted()), " 531 | + "validation time \(validationTime.formatted())s" 532 | case .save(let iteration, let url): 533 | "Iteration \(iteration + 1): saved weights to \(url.path())" 534 | } 535 | } 536 | } 537 | 538 | public enum ProgressDisposition: Sendable { 539 | case stop 540 | case more 541 | } 542 | 543 | /// Train (or continue training) LoRA weights.
544 | /// 545 | /// - Parameters: 546 | /// - model: model to train 547 | /// - train: training dataset 548 | /// - validate: validate dataset 549 | /// - optimizer: optimizer used in training 550 | /// - loss: loss function 551 | /// - tokenizer: tokenizer 552 | /// - parameters: training parameters 553 | /// - progress: progress callback 554 | public static func train( 555 | model: Module, train: [String], validate: [String], optimizer: Optimizer, 556 | loss: @escaping LoraLossFunction = loss, tokenizer: Tokenizer, parameters: Parameters, 557 | progress: (Progress) -> ProgressDisposition 558 | ) throws { 559 | // def train(model, train_set, val_set, optimizer, loss, tokenizer, args) 560 | 561 | let lossValueGrad = valueAndGrad(model: model) { model, arrays in 562 | let (ce, ntoks) = loss(model, arrays[0], arrays[1], arrays[2]) 563 | return [ce, ntoks] 564 | } 565 | 566 | var losses = [Float]() 567 | var tokenCount = 0 568 | 569 | var start = Date.timeIntervalSinceReferenceDate 570 | 571 | for (iteration, (inputs, targets, lengths)) in LoRABatchIterator( 572 | dataset: train, tokenizer: tokenizer, batchSize: parameters.batchSize, train: true 573 | ).enumerated() { 574 | // forward and backward pass 575 | let (resultArray, grad) = lossValueGrad(model, [inputs, targets, lengths]) 576 | let lvalue = resultArray[0] 577 | let tokens = resultArray[1] 578 | 579 | // model update 580 | optimizer.update(model: model, gradients: grad) 581 | eval(model, optimizer, lvalue) 582 | 583 | // record loss 584 | losses.append(lvalue.item(Float.self)) 585 | tokenCount += tokens.item(Int.self) 586 | 587 | // report training loss 588 | if (iteration + 1) % parameters.stepsPerReport == 0 { 589 | let trainingLoss = MLXArray(losses).mean(stream: .cpu).item(Float.self) 590 | let now = Date.timeIntervalSinceReferenceDate 591 | 592 | let iterationsPerSecond = Double(parameters.stepsPerReport) / (now - start) 593 | let tokensPerSecond = Double(tokenCount) / (now - start) 594 | 595 | if progress( 596 | .train( 597 | iteration: iteration, trainingLoss: trainingLoss, 598 | iterationsPerSecond: iterationsPerSecond, tokensPerSecond: tokensPerSecond)) 599 | == .stop 600 | { 601 | break 602 | } 603 | 604 | losses.removeAll() 605 | tokenCount = 0 606 | start = Date.timeIntervalSinceReferenceDate 607 | } 608 | 609 | // report validation loss 610 | if iteration == 0 || (iteration + 1) % parameters.stepsPerEval == 0 { 611 | let validationStart = Date.timeIntervalSinceReferenceDate 612 | let validationLoss = evaluate( 613 | model: model, dataset: validate, loss: loss, tokenizer: tokenizer, 614 | batchSize: parameters.batchSize, batchCount: parameters.validationBatches) 615 | let now = Date.timeIntervalSinceReferenceDate 616 | 617 | if progress( 618 | .validation( 619 | iteration: iteration, validationLoss: validationLoss, 620 | validationTime: now - validationStart)) == .stop 621 | { 622 | break 623 | } 624 | 625 | start = Date.timeIntervalSinceReferenceDate 626 | } 627 | 628 | // save adapter weights if needed 629 | if let adapterURL = parameters.adapterURL, (iteration + 1) % parameters.saveEvery == 0 { 630 | try saveLoRAWeights(model: model, url: adapterURL) 631 | 632 | if progress(.save(iteration: iteration, url: adapterURL)) == .stop { 633 | break 634 | } 635 | 636 | start = Date.timeIntervalSinceReferenceDate 637 | } 638 | 639 | if iteration + 1 >= parameters.iterations { 640 | break 641 | } 642 | } 643 | } 644 | } 645 | -------------------------------------------------------------------------------- 
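An end-to-end sketch of how these libraries fit together, from configuration to generated text (illustrative only: it assumes default `GenerateParameters` arguments and uses placeholder values, and is not part of the project sources):

```swift
// look up a registered configuration (unknown ids are used as-is)
let configuration = await ModelConfiguration.configuration(
    id: "mlx-community/Llama-3.2-1B-Instruct-4bit")

// download (or locate) the weights and build the thread safe container
let container = try await loadModelContainer(configuration: configuration)

// tokenize, generate and decode inside the container
let result = await container.perform { model, tokenizer in
    generate(
        promptTokens: tokenizer.encode(text: configuration.defaultPrompt),
        parameters: GenerateParameters(temperature: 0.6),
        model: model, tokenizer: tokenizer,
        extraEOSTokens: configuration.extraEOSTokens
    ) { tokens in
        tokens.count < 240 ? .more : .stop
    }
}
print(result.summary())
```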
/MyAppleIntelligence/Libraries/ModelConfiguration.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ModelConfiguration.swift 3 | // MyAppleIntelligence 4 | // 5 | // Created by Stefan Blos on 08.11.24. 6 | // 7 | 8 | import Foundation 9 | import Hub 10 | 11 | /// Registry of models and any overrides that go with them, e.g. prompt augmentation. 12 | /// If asked for an unknown configuration this will use the model/tokenizer as-is. 13 | /// 14 | /// The python tokenizers have a very rich set of implementations and configuration. The 15 | /// swift-tokenizers code handles a good chunk of that and this is a place to augment that 16 | /// implementation, if needed. 17 | public struct ModelConfiguration: Sendable { 18 | 19 | public enum Identifier: Sendable { 20 | case id(String) 21 | case directory(URL) 22 | } 23 | 24 | public var id: Identifier 25 | 26 | public var name: String { 27 | switch id { 28 | case .id(let string): 29 | string 30 | case .directory(let url): 31 | url.deletingLastPathComponent().lastPathComponent + "/" + url.lastPathComponent 32 | } 33 | } 34 | 35 | /// pull the tokenizer from an alternate id 36 | public let tokenizerId: String? 37 | 38 | /// overrides for TokenizerModel/knownTokenizers -- useful before swift-transformers is updated 39 | public let overrideTokenizer: String? 40 | 41 | /// A reasonable default prompt for the model 42 | public let defaultPrompt: String 43 | 44 | /// Additional tokens to use for end of string 45 | public let extraEOSTokens: Set<String> 46 | 47 | public init( 48 | id: String, tokenizerId: String? = nil, overrideTokenizer: String? = nil, 49 | defaultPrompt: String = "hello", 50 | extraEOSTokens: Set<String> = [], 51 | preparePrompt: (@Sendable (String) -> String)? = nil 52 | ) { 53 | self.id = .id(id) 54 | self.tokenizerId = tokenizerId 55 | self.overrideTokenizer = overrideTokenizer 56 | self.defaultPrompt = defaultPrompt 57 | self.extraEOSTokens = extraEOSTokens 58 | } 59 | 60 | public init( 61 | directory: URL, tokenizerId: String? = nil, overrideTokenizer: String? = nil, 62 | defaultPrompt: String = "hello", 63 | extraEOSTokens: Set<String> = [] 64 | ) { 65 | self.id = .directory(directory) 66 | self.tokenizerId = tokenizerId 67 | self.overrideTokenizer = overrideTokenizer 68 | self.defaultPrompt = defaultPrompt 69 | self.extraEOSTokens = extraEOSTokens 70 | } 71 | 72 | public func modelDirectory(hub: HubApi = HubApi()) -> URL { 73 | switch id { 74 | case .id(let id): 75 | // download the model weights and config 76 | let repo = Hub.Repo(id: id) 77 | return hub.localRepoLocation(repo) 78 | 79 | case .directory(let directory): 80 | return directory 81 | } 82 | } 83 | 84 | @MainActor 85 | public static var registry = [String: ModelConfiguration]() 86 | 87 | @MainActor 88 | public static func register(configurations: [ModelConfiguration]) { 89 | bootstrap() 90 | 91 | for c in configurations { 92 | registry[c.name] = c 93 | } 94 | } 95 | 96 | @MainActor 97 | public static func configuration(id: String) -> ModelConfiguration { 98 | bootstrap() 99 | 100 | if let c = registry[id] { 101 | return c 102 | } else { 103 | return ModelConfiguration(id: id) 104 | } 105 | } 106 | } 107 | 108 | extension ModelConfiguration { 109 | public static let smolLM_135M_4bit = ModelConfiguration( 110 | id: "mlx-community/SmolLM-135M-Instruct-4bit", 111 | defaultPrompt: "Tell me about the history of Spain."
112 | ) 113 | 114 | public static let mistralNeMo4bit = ModelConfiguration( 115 | id: "mlx-community/Mistral-Nemo-Instruct-2407-4bit", 116 | defaultPrompt: "Explain quaternions." 117 | ) 118 | 119 | public static let mistral7B4bit = ModelConfiguration( 120 | id: "mlx-community/Mistral-7B-Instruct-v0.3-4bit", 121 | defaultPrompt: "Describe the Swift language." 122 | ) 123 | 124 | public static let codeLlama13b4bit = ModelConfiguration( 125 | id: "mlx-community/CodeLlama-13b-Instruct-hf-4bit-MLX", 126 | overrideTokenizer: "PreTrainedTokenizer", 127 | defaultPrompt: "func sortArray(_ array: [Int]) -> String { <FILL_ME> }" 128 | ) 129 | 130 | public static let phi4bit = ModelConfiguration( 131 | id: "mlx-community/phi-2-hf-4bit-mlx", 132 | // https://www.promptingguide.ai/models/phi-2 133 | defaultPrompt: "Why is the sky blue?" 134 | ) 135 | 136 | public static let phi3_5_4bit = ModelConfiguration( 137 | id: "mlx-community/Phi-3.5-mini-instruct-4bit", 138 | defaultPrompt: "What is the gravity on Mars and the moon?", 139 | extraEOSTokens: ["<|end|>"] 140 | ) 141 | 142 | public static let phi3_5MoE = ModelConfiguration( 143 | id: "mlx-community/Phi-3.5-MoE-instruct-4bit", 144 | defaultPrompt: "What is the gravity on Mars and the moon?", 145 | extraEOSTokens: ["<|end|>"] 146 | ) { 147 | prompt in 148 | "<|user|>\n\(prompt)<|end|>\n<|assistant|>\n" 149 | } 150 | 151 | public static let gemma2bQuantized = ModelConfiguration( 152 | id: "mlx-community/quantized-gemma-2b-it", 153 | overrideTokenizer: "PreTrainedTokenizer", 154 | // https://www.promptingguide.ai/models/gemma 155 | defaultPrompt: "what is the difference between lettuce and cabbage?" 156 | ) 157 | 158 | public static let gemma_2_9b_it_4bit = ModelConfiguration( 159 | id: "mlx-community/gemma-2-9b-it-4bit", 160 | overrideTokenizer: "PreTrainedTokenizer", 161 | // https://www.promptingguide.ai/models/gemma 162 | defaultPrompt: "What is the difference between lettuce and cabbage?" 163 | ) 164 | 165 | public static let gemma_2_2b_it_4bit = ModelConfiguration( 166 | id: "mlx-community/gemma-2-2b-it-4bit", 167 | overrideTokenizer: "PreTrainedTokenizer", 168 | // https://www.promptingguide.ai/models/gemma 169 | defaultPrompt: "What is the difference between lettuce and cabbage?" 170 | ) 171 | 172 | public static let qwen205b4bit = ModelConfiguration( 173 | id: "mlx-community/Qwen1.5-0.5B-Chat-4bit", 174 | overrideTokenizer: "PreTrainedTokenizer", 175 | defaultPrompt: "why is the sky blue?" 176 | ) 177 | 178 | public static let openelm270m4bit = ModelConfiguration( 179 | id: "mlx-community/OpenELM-270M-Instruct", 180 | // https://huggingface.co/apple/OpenELM 181 | defaultPrompt: "Once upon a time there was" 182 | ) 183 | 184 | public static let llama3_1_8B_4bit = ModelConfiguration( 185 | id: "mlx-community/Meta-Llama-3.1-8B-Instruct-4bit", 186 | defaultPrompt: "What is the difference between a fruit and a vegetable?" 187 | ) 188 | 189 | public static let llama3_8B_4bit = ModelConfiguration( 190 | id: "mlx-community/Meta-Llama-3-8B-Instruct-4bit", 191 | defaultPrompt: "What is the difference between a fruit and a vegetable?" 192 | ) 193 | 194 | public static let llama3_2_1B_4bit = ModelConfiguration( 195 | id: "mlx-community/Llama-3.2-1B-Instruct-4bit", 196 | defaultPrompt: "What is the difference between a fruit and a vegetable?" 197 | ) 198 | 199 | public static let llama3_2_3B_4bit = ModelConfiguration( 200 | id: "mlx-community/Llama-3.2-3B-Instruct-4bit", 201 | defaultPrompt: "What is the difference between a fruit and a vegetable?"

    private enum BootstrapState: Sendable {
        case idle
        case bootstrapping
        case bootstrapped
    }

    @MainActor
    static private var bootstrapState = BootstrapState.idle

    @MainActor
    static func bootstrap() {
        switch bootstrapState {
        case .idle:
            bootstrapState = .bootstrapping
            register(configurations: [
                codeLlama13b4bit,
                gemma2bQuantized,
                gemma_2_2b_it_4bit,
                gemma_2_9b_it_4bit,
                llama3_1_8B_4bit,
                llama3_2_1B_4bit,
                llama3_2_3B_4bit,
                llama3_8B_4bit,
                mistral7B4bit,
                mistralNeMo4bit,
                openelm270m4bit,
                phi3_5MoE,
                phi3_5_4bit,
                phi4bit,
                qwen205b4bit,
                smolLM_135M_4bit,
            ])
            bootstrapState = .bootstrapped

        case .bootstrapping, .bootstrapped:
            break
        }
    }
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Libraries/Tokenizer.swift:
--------------------------------------------------------------------------------
//
//  Tokenizer.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 08.11.24.
//

import Foundation
import Hub
import Tokenizers

public func loadTokenizer(configuration: ModelConfiguration, hub: HubApi) async throws -> Tokenizer {
    let (tokenizerConfig, tokenizerData) = try await loadTokenizerConfig(
        configuration: configuration, hub: hub)

    return try PreTrainedTokenizer(
        tokenizerConfig: tokenizerConfig, tokenizerData: tokenizerData)
}
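
// Usage sketch (illustrative, not part of the original file): loading the
// tokenizer for a registered configuration from an async context. `HubApi()`
// uses the default Hub cache location.
//
//     let tokenizer = try await loadTokenizer(
//         configuration: .llama3_2_3B_4bit, hub: HubApi())
//     let tokens = tokenizer.encode(text: "Hello, world!")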

func loadTokenizerConfig(configuration: ModelConfiguration, hub: HubApi) async throws -> (
    Config, Config
) {
    // from AutoTokenizer.from() -- this lets us override parts of the configuration
    let config: LanguageModelConfigurationFromHub

    switch configuration.id {
    case .id(let id):
        do {
            // the load can fail (async when we try to use it)
            let loaded = LanguageModelConfigurationFromHub(
                modelName: configuration.tokenizerId ?? id, hubApi: hub)
            _ = try await loaded.tokenizerConfig
            config = loaded
        } catch {
            let nserror = error as NSError
            if nserror.domain == NSURLErrorDomain
                && nserror.code == NSURLErrorNotConnectedToInternet
            {
                // Internet connection appears to be offline -- fall back to loading from
                // the local directory
                config = LanguageModelConfigurationFromHub(
                    modelFolder: configuration.modelDirectory(hub: hub), hubApi: hub)
            } else {
                throw error
            }
        }
    case .directory(let directory):
        config = LanguageModelConfigurationFromHub(modelFolder: directory, hubApi: hub)
    }

    guard var tokenizerConfig = try await config.tokenizerConfig else {
        throw LLMError(message: "missing config")
    }
    let tokenizerData = try await config.tokenizerData

    tokenizerConfig = updateTokenizerConfig(tokenizerConfig)

    return (tokenizerConfig, tokenizerData)
}

private func updateTokenizerConfig(_ tokenizerConfig: Config) -> Config {
    // workaround: replacement tokenizers for values unhandled by swift-transformers
    if let tokenizerClass = tokenizerConfig.tokenizerClass?.stringValue,
        let replacement = replacementTokenizers[tokenizerClass]
    {
        var dictionary = tokenizerConfig.dictionary
        dictionary["tokenizer_class"] = replacement
        return Config(dictionary)
    }

    return tokenizerConfig
}

public class TokenizerReplacementRegistry: @unchecked Sendable {

    // Note: using NSLock as we have very small (just dictionary get/set)
    // critical sections and expect no contention. this allows the methods
    // to remain synchronous.
    private let lock = NSLock()

    /// overrides for TokenizerModel/knownTokenizers
    private var replacementTokenizers = [
        "InternLM2Tokenizer": "PreTrainedTokenizer",
        "Qwen2Tokenizer": "PreTrainedTokenizer",
        "CohereTokenizer": "PreTrainedTokenizer",
    ]

    public subscript(key: String) -> String? {
        get {
            lock.withLock {
                replacementTokenizers[key]
            }
        }
        set {
            lock.withLock {
                replacementTokenizers[key] = newValue
            }
        }
    }
}

public let replacementTokenizers = TokenizerReplacementRegistry()

public protocol StreamingDetokenizer: IteratorProtocol {

    mutating func append(token: Int)

}

public struct NaiveStreamingDetokenizer: StreamingDetokenizer {
    let tokenizer: Tokenizer

    var segmentTokens = [Int]()
    var segment = ""

    public init(tokenizer: Tokenizer) {
        self.tokenizer = tokenizer
    }

    public mutating func append(token: Int) {
        segmentTokens.append(token)
    }

    mutating func startNewSegment() {
        // keep the last token around as decoding context for the next segment
        let lastToken = segmentTokens.last
        segmentTokens.removeAll()
        if let lastToken {
            segmentTokens.append(lastToken)
            segment = tokenizer.decode(tokens: segmentTokens)
        } else {
            segment = ""
        }
    }
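
    // `next()` below decodes everything accumulated in the current segment and
    // returns only the newly produced suffix, restarting the segment on newline
    // boundaries so the buffer stays small. Usage sketch (illustrative, not
    // part of the original file):
    //
    //     var detokenizer = NaiveStreamingDetokenizer(tokenizer: tokenizer)
    //     for token in generatedTokens {
    //         detokenizer.append(token: token)
    //         if let chunk = detokenizer.next() {
    //             print(chunk, terminator: "")
    //         }
    //     }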
    public mutating func next() -> String? {
        let newSegment = tokenizer.decode(tokens: segmentTokens)
        let new = newSegment.suffix(newSegment.count - segment.count)

        if new.hasSuffix("\n") {
            startNewSegment()
        } else {
            self.segment = newSegment
        }

        return String(new)
    }

}

--------------------------------------------------------------------------------
/MyAppleIntelligence/MyAppleIntelligence.entitlements:
--------------------------------------------------------------------------------
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
	<key>com.apple.security.app-sandbox</key>
	<true/>
	<key>com.apple.security.files.user-selected.read-only</key>
	<true/>
</dict>
</plist>

--------------------------------------------------------------------------------
/MyAppleIntelligence/MyAppleIntelligenceApp.swift:
--------------------------------------------------------------------------------
//
//  MyAppleIntelligenceApp.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 05.11.24.
//

import SwiftUI

@main
struct MyAppleIntelligenceApp: App {
    var body: some Scene {
        WindowGroup {
            ContentView()
        }
    }
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Navigation/Model/NavigationOption.swift:
--------------------------------------------------------------------------------
//
//  NavigationOption.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 12.11.24.
//

import Foundation

enum NavigationOption: String, CaseIterable, Hashable {
    case WritingTools, ImagePlayground
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Navigation/View/ContentView.swift:
--------------------------------------------------------------------------------
//
//  ContentView.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 05.11.24.
//

import SwiftUI
import MarkdownUI
import MLX

struct ContentView: View {

    @StateObject private var navigationModel = NavigationModel()

    var body: some View {
        NavigationStack(path: $navigationModel.navigationPath) {
            VStack {
                ForEach(NavigationOption.allCases, id: \.self) { navigationOption in
                    NavigationLink(value: navigationOption) {
                        Text(navigationOption.rawValue)
                            .foregroundStyle(.primary)
                            .font(.title)
                            .frame(maxWidth: .infinity, maxHeight: 200)
                            .padding(.horizontal, 16)
                            .padding(.vertical, 8)
                            .background(.thinMaterial, in: .rect(cornerRadius: 20))
                            .background(MeshGradient.custom, in: .rect(cornerRadius: 20))
                            .tint(.primary)
                    }
                }

                Spacer()
            }
            .navigationTitle("My A.I.")
            .navigationDestination(for: NavigationOption.self) { navigationOption in
                switch navigationOption {
                case .WritingTools:
                    WritingToolsInputView(navigationModel: navigationModel)
                case .ImagePlayground:
                    ImagePlaygroundView(navigationModel: navigationModel)
                }
            }
        }
    }
}

#Preview {
    ContentView()
}
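
// Programmatic navigation sketch (illustrative, not part of the original file):
// because the path is a published `NavigationPath`, any view holding the model
// can push a destination without a NavigationLink, e.g.
//
//     navigationModel.navigationPath.append(NavigationOption.WritingTools)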

--------------------------------------------------------------------------------
/MyAppleIntelligence/Navigation/ViewModel/NavigationModel.swift:
--------------------------------------------------------------------------------
//
//  NavigationModel.swift
//  MyAppleIntelligence
//
//  Created by Stefan Blos on 12.11.24.
//

import SwiftUI

@MainActor
class NavigationModel: ObservableObject {
    @Published var navigationPath = NavigationPath()
}

--------------------------------------------------------------------------------
/MyAppleIntelligence/Preview Content/Preview Assets.xcassets/Contents.json:
--------------------------------------------------------------------------------
{
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# My Apple Intelligence

This project explores building my own copy of Apple Intelligence. Currently, it covers the following two features:

- Writing Tools
- Image Playground

The project was built with a focus on:

- UX
- Open-Source
- On-device

It accompanies my talk at the [Do iOS Conference](https://do-ios.com) in Amsterdam on November 13-14.

It's not a perfectly polished app, but it can run the features mentioned above fully on-device.

If you enjoy the project, feel free to [follow me on X](https://x.com/stefanjblos) to learn more about it.

## Writing Tools demo

https://github.com/user-attachments/assets/24b075d3-f645-45d2-b56e-cf1af25181fe

## Image Playground demo

https://github.com/user-attachments/assets/7819eda8-5e20-44ae-9bde-3b166446a912

## Helpful Resources

Here are some of the resources that were incredibly helpful when creating this project:

- [mlx-swift](https://github.com/ml-explore/mlx-swift)
- [swift-transformers](https://github.com/huggingface/swift-transformers)
- [mlx-swift-examples](https://github.com/ml-explore/mlx-swift-examples)
- [Llama 3.2 3B Instruct 4bit-quantized model](https://huggingface.co/mlx-community/Llama-3.2-3B-Instruct-4bit)
- [SDXL 1.0-base model](https://huggingface.co/apple/coreml-stable-diffusion-xl-base-ios)
--------------------------------------------------------------------------------