├── GenerateApp ├── GuernikaModelConverter.version ├── GuernikaModelConverter_AppIcon.png ├── exportOptions.plist └── GuernikaTools.entitlements ├── GuernikaTools ├── guernikatools │ ├── _version.py │ ├── __init__.py │ ├── models │ │ ├── layer_norm.py │ │ ├── attention.py │ │ └── controlnet.py │ ├── convert │ │ ├── convert_text_encoder.py │ │ ├── convert_t2i_adapter.py │ │ ├── convert_safety_checker.py │ │ ├── convert_controlnet.py │ │ ├── convert_vae.py │ │ └── convert_unet.py │ └── utils │ │ ├── merge_lora.py │ │ ├── utils.py │ │ └── chunk_mlprogram.py ├── requirements.txt ├── build.sh ├── GuernikaTools.entitlements ├── setup.py └── GuernikaTools.spec ├── GuernikaModelConverter ├── Assets.xcassets │ ├── Contents.json │ ├── AppIcon.appiconset │ │ ├── AppIcon-128.png │ │ ├── AppIcon-16.png │ │ ├── AppIcon-256.png │ │ ├── AppIcon-32.png │ │ ├── AppIcon-512.png │ │ ├── AppIcon-64.png │ │ ├── AppIcon-1024.png │ │ └── Contents.json │ └── AccentColor.colorset │ │ └── Contents.json ├── Preview Content │ └── Preview Assets.xcassets │ │ └── Contents.json ├── Model │ ├── LoRAInfo.swift │ ├── Compression.swift │ ├── ModelOrigin.swift │ ├── ComputeUnits.swift │ ├── Version.swift │ ├── Logger.swift │ └── ConverterProcess.swift ├── GuernikaModelConverter.entitlements ├── GuernikaModelConverterApp.swift ├── Navigation │ ├── ContentView.swift │ ├── DetailColumn.swift │ └── Sidebar.swift ├── Views │ ├── CircularProgress.swift │ ├── DestinationToolbarButton.swift │ ├── IntegerField.swift │ └── DecimalField.swift ├── Log │ └── LogView.swift ├── ConvertControlNet │ └── ConvertControlNetViewModel.swift └── ConvertModel │ └── ConvertModelViewModel.swift ├── GuernikaModelConverter.xcodeproj ├── project.xcworkspace │ ├── contents.xcworkspacedata │ └── xcshareddata │ │ └── IDEWorkspaceChecks.plist └── xcuserdata │ └── guiye.xcuserdatad │ └── xcschemes │ └── xcschememanagement.plist ├── .github └── FUNDING.yml ├── GuernikaModelConverterUITests ├── GuernikaModelConverterUITestsLaunchTests.swift └── GuernikaModelConverterUITests.swift ├── GuernikaModelConverterTests └── GuernikaModelConverterTests.swift ├── .gitignore └── README.md /GenerateApp/GuernikaModelConverter.version: -------------------------------------------------------------------------------- 1 | 6.5.0 2 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "7.0.0" 2 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__ 2 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Assets.xcassets/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "info" : { 3 | "author" : "xcode", 4 | "version" : 1 5 | } 6 | } 7 | -------------------------------------------------------------------------------- /GenerateApp/GuernikaModelConverter_AppIcon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/GuernikaCore/GuernikaModelConverter/HEAD/GenerateApp/GuernikaModelConverter_AppIcon.png -------------------------------------------------------------------------------- /GuernikaModelConverter/Preview Content/Preview Assets.xcassets/Contents.json: 
--------------------------------------------------------------------------------
1 | {
2 |   "info" : {
3 |     "author" : "xcode",
4 |     "version" : 1
5 |   }
6 | }
7 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-128.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuernikaCore/GuernikaModelConverter/HEAD/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-128.png
--------------------------------------------------------------------------------
/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuernikaCore/GuernikaModelConverter/HEAD/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-16.png
--------------------------------------------------------------------------------
/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuernikaCore/GuernikaModelConverter/HEAD/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-256.png
--------------------------------------------------------------------------------
/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-32.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuernikaCore/GuernikaModelConverter/HEAD/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-32.png
--------------------------------------------------------------------------------
/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-512.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuernikaCore/GuernikaModelConverter/HEAD/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-512.png
--------------------------------------------------------------------------------
/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-64.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuernikaCore/GuernikaModelConverter/HEAD/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-64.png
--------------------------------------------------------------------------------
/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-1024.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GuernikaCore/GuernikaModelConverter/HEAD/GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/AppIcon-1024.png
--------------------------------------------------------------------------------
/GuernikaModelConverter.xcodeproj/project.xcworkspace/contents.xcworkspacedata:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <Workspace
3 |    version = "1.0">
4 |    <FileRef
5 |       location = "self:">
6 |    </FileRef>
7 | </Workspace>
8 | 
--------------------------------------------------------------------------------
/GuernikaTools/requirements.txt:
--------------------------------------------------------------------------------
1 | coremltools>=7.0b2
2 | diffusers[torch]
3 | torch
4 | transformers>=4.30.0
5 | scipy
6 | numpy<1.24
7 | pytest
8 | scikit-learn==1.1.2
9 | pytorch_lightning
10 | omegaconf
11 | six
12 | safetensors
13 | pyinstaller
14 | 
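15 | # numpy is pinned below 1.24, presumably because NumPy 1.24 removed the
16 | # long-deprecated aliases (np.bool, np.object) that older releases of the
17 | # dependencies above still reference.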
--------------------------------------------------------------------------------
/GuernikaTools/build.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | pyinstaller GuernikaTools.spec --clean -y --distpath "../GuernikaModelConverter"
4 | codesign -s - -i com.guiyec.GuernikaModelConverter.GuernikaTools -o runtime --entitlements GuernikaTools.entitlements -f "../GuernikaModelConverter/GuernikaTools"
5 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter.xcodeproj/project.xcworkspace/xcshareddata/IDEWorkspaceChecks.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>IDEDidComputeMac32BitWarning</key>
6 | 	<true/>
7 | </dict>
8 | </plist>
9 | 
--------------------------------------------------------------------------------
/GenerateApp/exportOptions.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>method</key>
6 | 	<string>developer-id</string>
7 | 	<key>signingStyle</key>
8 | 	<string>automatic</string>
9 | 	<key>teamID</key>
10 | 	<string>A5ZC2LG374</string>
11 | </dict>
12 | </plist>
13 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter/Model/LoRAInfo.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  LoRAInfo.swift
3 | //  GuernikaModelConverter
4 | //
5 | //  Created by Guillermo Cique Fernández on 11/8/23.
6 | //
7 | 
8 | import Foundation
9 | 
10 | struct LoRAInfo: Hashable, Identifiable {
11 |     var id: URL { url }
12 |     var url: URL
13 |     var ratio: Double
14 |     
15 |     var argument: String {
16 |         String(format: "%@:%0.2f", url.path(percentEncoded: false), ratio)
17 |     }
18 | }
19 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter.xcodeproj/xcuserdata/guiye.xcuserdatad/xcschemes/xcschememanagement.plist:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>SchemeUserState</key>
6 | 	<dict>
7 | 		<key>GuernikaModelConverter.xcscheme_^#shared#^_</key>
8 | 		<dict>
9 | 			<key>orderHint</key>
10 | 			<integer>0</integer>
11 | 		</dict>
12 | 	</dict>
13 | </dict>
14 | </plist>
15 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter/GuernikaModelConverter.entitlements:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>com.apple.security.cs.allow-dyld-environment-variables</key>
6 | 	<true/>
7 | 	<key>com.apple.security.cs.allow-jit</key>
8 | 	<true/>
9 | 	<key>com.apple.security.cs.disable-library-validation</key>
10 | 	<true/>
11 | 	<key>com.apple.security.network.client</key>
12 | 	<true/>
13 | </dict>
14 | </plist>
15 | 
--------------------------------------------------------------------------------
/GenerateApp/GuernikaTools.entitlements:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>com.apple.security.cs.allow-dyld-environment-variables</key>
6 | 	<true/>
7 | 	<key>com.apple.security.cs.allow-jit</key>
8 | 	<true/>
9 | 	<key>com.apple.security.cs.disable-library-validation</key>
10 | 	<true/>
11 | 	<key>com.apple.security.network.client</key>
12 | 	<true/>
13 | 	<key>com.apple.security.inherit</key>
14 | 	<true/>
15 | </dict>
16 | </plist>
17 | 
--------------------------------------------------------------------------------
/GuernikaTools/GuernikaTools.entitlements:
--------------------------------------------------------------------------------
1 | <?xml version="1.0" encoding="UTF-8"?>
2 | <!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
3 | <plist version="1.0">
4 | <dict>
5 | 	<key>com.apple.security.cs.allow-dyld-environment-variables</key>
6 | 	<true/>
7 | 	<key>com.apple.security.cs.allow-jit</key>
8 | 	<true/>
9 | 	<key>com.apple.security.cs.disable-library-validation</key>
10 | 	<true/>
11 | 	<key>com.apple.security.network.client</key>
12 | 	<true/>
13 | 	<key>com.apple.security.inherit</key>
14 | 	<true/>
15 | </dict>
16 | </plist>
17 | 
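18 | <!-- Presumably these hardened-runtime exceptions (allow-jit, allow-dyld-environment-variables,
19 |      disable-library-validation) let the PyInstaller-bundled GuernikaTools binary load its
20 |      embedded Python runtime and prebuilt native libraries after build.sh signs it with
21 |      "-o runtime"; com.apple.security.inherit additionally lets it inherit its parent's
22 |      sandbox when the app spawns it as a child process. -->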
--------------------------------------------------------------------------------
/GuernikaModelConverter/Model/Compression.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  Compression.swift
3 | //  GuernikaModelConverter
4 | //
5 | //  Created by Guillermo Cique Fernández on 19/6/23.
6 | //
7 | 
8 | import Foundation
9 | 
10 | enum Compression: String, CaseIterable, Identifiable, CustomStringConvertible {
11 |     // Note: the misspelled "quantizied" raw values are kept as-is; they are
12 |     // persisted through AppStorage (see ConvertControlNetViewModel.compressionString),
13 |     // so renaming the cases would break saved settings.
14 |     case quantizied6bit
15 |     case quantizied8bit
16 |     case fullSize
17 | 
18 |     var id: String { rawValue }
19 | 
20 |     var description: String {
21 |         switch self {
22 |         case .quantizied6bit: return "6-bit"
23 |         case .quantizied8bit: return "8-bit"
24 |         case .fullSize: return "Full size"
25 |         }
26 |     }
27 | }
28 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter/Model/ModelOrigin.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  ModelOrigin.swift
3 | //  GuernikaModelConverter
4 | //
5 | //  Created by Guillermo Cique Fernández on 30/3/23.
6 | //
7 | 
8 | import Foundation
9 | 
10 | enum ModelOrigin: String, CaseIterable, Identifiable, CustomStringConvertible {
11 |     case huggingface
12 |     case diffusers
13 |     case checkpoint
14 | 
15 |     var id: String { rawValue }
16 | 
17 |     var description: String {
18 |         switch self {
19 |         case .huggingface: return "🤗 Hugging Face"
20 |         case .diffusers: return "📂 Diffusers"
21 |         case .checkpoint: return "💽 Checkpoint"
22 |         }
23 |     }
24 | }
25 | 
--------------------------------------------------------------------------------
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 | 
3 | github: GuiyeC
4 | patreon: # Replace with a single Patreon username
5 | open_collective: # Replace with a single Open Collective username
6 | ko_fi: # Replace with a single Ko-fi username
7 | tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
8 | community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
9 | liberapay: # Replace with a single Liberapay username
10 | issuehunt: # Replace with a single IssueHunt username
11 | otechie: # Replace with a single Otechie username
12 | lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
13 | custom: # Replace with up to 4 custom sponsorship URLs e.g., ['link1', 'link2']
14 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter/GuernikaModelConverterApp.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  GuernikaModelConverterApp.swift
3 | //  GuernikaModelConverter
4 | //
5 | //  Created by Guillermo Cique Fernández on 19/3/23.
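//  Single-window SwiftUI entry point: the delegate quits the app when the last
//  window closes, and the File > New menu item is removed via commands.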
6 | // 7 | 8 | import SwiftUI 9 | 10 | final class AppDelegate: NSObject, NSApplicationDelegate { 11 | func applicationShouldTerminateAfterLastWindowClosed(_ sender: NSApplication) -> Bool { 12 | true 13 | } 14 | } 15 | 16 | @main 17 | struct GuernikaModelConverterApp: App { 18 | @NSApplicationDelegateAdaptor(AppDelegate.self) var delegate 19 | 20 | var body: some Scene { 21 | WindowGroup { 22 | ContentView() 23 | }.commands { 24 | SidebarCommands() 25 | 26 | CommandGroup(replacing: CommandGroupPlacement.newItem) {} 27 | } 28 | } 29 | } 30 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Assets.xcassets/AccentColor.colorset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "colors" : [ 3 | { 4 | "color" : { 5 | "color-space" : "srgb", 6 | "components" : { 7 | "alpha" : "1.000", 8 | "blue" : "0.271", 9 | "green" : "0.067", 10 | "red" : "0.522" 11 | } 12 | }, 13 | "idiom" : "mac" 14 | }, 15 | { 16 | "appearances" : [ 17 | { 18 | "appearance" : "luminosity", 19 | "value" : "dark" 20 | } 21 | ], 22 | "color" : { 23 | "color-space" : "srgb", 24 | "components" : { 25 | "alpha" : "1.000", 26 | "blue" : "0.271", 27 | "green" : "0.067", 28 | "red" : "0.522" 29 | } 30 | }, 31 | "idiom" : "mac" 32 | } 33 | ], 34 | "info" : { 35 | "author" : "xcode", 36 | "version" : 1 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /GuernikaModelConverterUITests/GuernikaModelConverterUITestsLaunchTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // GuernikaModelConverterUITestsLaunchTests.swift 3 | // GuernikaModelConverterUITests 4 | // 5 | // Created by Guillermo Cique Fernández on 19/3/23. 6 | // 7 | 8 | import XCTest 9 | 10 | final class GuernikaModelConverterUITestsLaunchTests: XCTestCase { 11 | 12 | override class var runsForEachTargetApplicationUIConfiguration: Bool { 13 | true 14 | } 15 | 16 | override func setUpWithError() throws { 17 | continueAfterFailure = false 18 | } 19 | 20 | func testLaunch() throws { 21 | let app = XCUIApplication() 22 | app.launch() 23 | 24 | // Insert steps here to perform after app launch but before taking a screenshot, 25 | // such as logging into a test account or navigating somewhere in the app 26 | 27 | let attachment = XCTAttachment(screenshot: app.screenshot()) 28 | attachment.name = "Launch Screen" 29 | attachment.lifetime = .keepAlways 30 | add(attachment) 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Navigation/ContentView.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ContentView.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 19/3/23. 
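//  Root navigation: the sidebar selection drives the detail column, and
//  switching panels resets the navigation path.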
6 | // 7 | 8 | import SwiftUI 9 | 10 | struct ContentView: View { 11 | @State private var selection: Panel = Panel.model 12 | @State private var path = NavigationPath() 13 | 14 | var body: some View { 15 | NavigationSplitView { 16 | Sidebar(path: $path, selection: $selection) 17 | } detail: { 18 | NavigationStack(path: $path) { 19 | DetailColumn(path: $path, selection: $selection) 20 | } 21 | } 22 | .onChange(of: selection) { _ in 23 | path.removeLast(path.count) 24 | } 25 | .frame(minWidth: 800, minHeight: 500) 26 | } 27 | } 28 | 29 | struct ContentView_Previews: PreviewProvider { 30 | struct Preview: View { 31 | var body: some View { 32 | ContentView() 33 | } 34 | } 35 | static var previews: some View { 36 | Preview() 37 | } 38 | } 39 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Navigation/DetailColumn.swift: -------------------------------------------------------------------------------- 1 | // 2 | // DetailColumn.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 19/3/23. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct DetailColumn: View { 11 | @Binding var path: NavigationPath 12 | @Binding var selection: Panel 13 | @StateObject var modelConverter = ConvertModelViewModel() 14 | @StateObject var controlNetConverter = ConvertControlNetViewModel() 15 | 16 | var body: some View { 17 | switch selection { 18 | case .model: 19 | ConvertModelView(model: modelConverter) 20 | case .controlNet: 21 | ConvertControlNetView(model: controlNetConverter) 22 | case .log: 23 | LogView() 24 | } 25 | } 26 | } 27 | 28 | struct DetailColumn_Previews: PreviewProvider { 29 | struct Preview: View { 30 | @State private var selection: Panel = .model 31 | 32 | var body: some View { 33 | DetailColumn(path: .constant(NavigationPath()), selection: $selection) 34 | } 35 | } 36 | static var previews: some View { 37 | Preview() 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Model/ComputeUnits.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ComputeUnits.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 30/3/23. 
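//  UI-facing compute-unit choice; asCTComputeUnits mirrors the coremltools
//  ComputeUnit names (CPU_AND_NE, CPU_AND_GPU, CPU_ONLY, ALL), presumably
//  passed through to the Python converter.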
6 | // 7 | 8 | import Foundation 9 | 10 | enum ComputeUnits: String, CaseIterable, Identifiable, CustomStringConvertible { 11 | case cpuAndNeuralEngine, cpuAndGPU, cpuOnly, all 12 | 13 | var id: String { rawValue } 14 | 15 | var description: String { 16 | switch self { 17 | case .cpuAndNeuralEngine: return "CPU and Neural Engine" 18 | case .cpuAndGPU: return "CPU and GPU" 19 | case .cpuOnly: return "CPU only" 20 | case .all: return "All" 21 | } 22 | } 23 | 24 | var shortDescription: String { 25 | switch self { 26 | case .cpuAndNeuralEngine: return "CPU & NE" 27 | case .cpuAndGPU: return "CPU & GPU" 28 | case .cpuOnly: return "CPU" 29 | case .all: return "All" 30 | } 31 | } 32 | 33 | var asCTComputeUnits: String { 34 | switch self { 35 | case .cpuAndNeuralEngine: return "CPU_AND_NE" 36 | case .cpuAndGPU: return "CPU_AND_GPU" 37 | case .cpuOnly: return "CPU_ONLY" 38 | case .all: return "ALL" 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /GuernikaTools/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | from guernikatools._version import __version__ 4 | 5 | setup( 6 | name='guernikatools', 7 | version=__version__, 8 | url='https://github.com/GuernikaCore/GuernikaModelConverter', 9 | description="Run Stable Diffusion on Apple Silicon with Guernika", 10 | author='Guernika', 11 | install_requires=[ 12 | "coremltools>=7.0", 13 | "diffusers[torch]", 14 | "torch", 15 | "transformers>=4.30.0", 16 | "scipy", 17 | "numpy", 18 | "pytest", 19 | "scikit-learn==1.1.2", 20 | "pytorch_lightning", 21 | "OmegaConf", 22 | "six", 23 | "safetensors", 24 | "pyinstaller", 25 | ], 26 | packages=find_packages(), 27 | classifiers=[ 28 | "Development Status :: 4 - Beta", 29 | "Intended Audience :: Developers", 30 | "Operating System :: MacOS :: MacOS X", 31 | "Programming Language :: Python :: 3", 32 | "Programming Language :: Python :: 3.7", 33 | "Programming Language :: Python :: 3.8", 34 | "Programming Language :: Python :: 3.9", 35 | "Topic :: Artificial Intelligence", 36 | "Topic :: Scientific/Engineering", 37 | "Topic :: Software Development", 38 | ], 39 | ) 40 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Views/CircularProgress.swift: -------------------------------------------------------------------------------- 1 | // 2 | // CircularProgress.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 15/12/22. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct CircularProgress: View { 11 | var progress: Float? 
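    // When progress is nil, or exactly 0 or 1, the body below shows the
    // indeterminate spinner; anything in between draws the determinate ring.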
12 | 13 | var body: some View { 14 | ZStack { 15 | if let progress, progress != 0 && progress != 1 { 16 | Circle() 17 | .stroke(lineWidth: 4) 18 | .opacity(0.2) 19 | .foregroundColor(Color.primary) 20 | .frame(width: 20, height: 20) 21 | Circle() 22 | .trim(from: 0, to: CGFloat(min(progress, 1))) 23 | .stroke(style: StrokeStyle(lineWidth: 4, lineCap: .round, lineJoin: .round)) 24 | .foregroundColor(Color.accentColor) 25 | .rotationEffect(Angle(degrees: 270)) 26 | .animation(.linear, value: progress) 27 | .frame(width: 20, height: 20) 28 | } else { 29 | ProgressView() 30 | .progressViewStyle(.circular) 31 | .scaleEffect(0.7) 32 | } 33 | }.frame(width: 24, height: 24) 34 | } 35 | } 36 | 37 | struct CircularProgress_Previews: PreviewProvider { 38 | static var previews: some View { 39 | CircularProgress(progress: 0.5) 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /GuernikaModelConverterTests/GuernikaModelConverterTests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // GuernikaModelConverterTests.swift 3 | // GuernikaModelConverterTests 4 | // 5 | // Created by Guillermo Cique Fernández on 19/3/23. 6 | // 7 | 8 | import XCTest 9 | @testable import GuernikaModelConverter 10 | 11 | final class GuernikaModelConverterTests: XCTestCase { 12 | 13 | override func setUpWithError() throws { 14 | // Put setup code here. This method is called before the invocation of each test method in the class. 15 | } 16 | 17 | override func tearDownWithError() throws { 18 | // Put teardown code here. This method is called after the invocation of each test method in the class. 19 | } 20 | 21 | func testExample() throws { 22 | // This is an example of a functional test case. 23 | // Use XCTAssert and related functions to verify your tests produce the correct results. 24 | // Any test you write for XCTest can be annotated as throws and async. 25 | // Mark your test throws to produce an unexpected failure when your test encounters an uncaught error. 26 | // Mark your test async to allow awaiting for asynchronous code to complete. Check the results with assertions afterwards. 27 | } 28 | 29 | func testPerformanceExample() throws { 30 | // This is an example of a performance test case. 31 | self.measure { 32 | // Put the code you want to measure the time of here. 33 | } 34 | } 35 | 36 | } 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Xcode 2 | # 3 | # gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore 4 | 5 | ## User settings 6 | xcuserdata/ 7 | 8 | ## Obj-C/Swift specific 9 | *.hmap 10 | 11 | ## App packaging 12 | *.ipa 13 | *.dSYM.zip 14 | *.dSYM 15 | 16 | ## Playgrounds 17 | timeline.xctimeline 18 | playground.xcworkspace 19 | 20 | # Swift Package Manager 21 | # 22 | # Add this line if you want to avoid checking in source code from Swift Package Manager dependencies. 23 | # Packages/ 24 | # Package.pins 25 | # Package.resolved 26 | # *.xcodeproj 27 | # 28 | # Xcode automatically generates this directory with a .xcworkspacedata file and xcuserdata 29 | # hence it is not needed unless you have added a package configuration file to your project 30 | .swiftpm 31 | 32 | .build/ 33 | 34 | # fastlane 35 | # 36 | # It is recommended to not store the screenshots in the git repo. 
37 | # Instead, use fastlane to re-generate the screenshots whenever they are needed. 38 | # For more information about the recommended setup visit: 39 | # https://docs.fastlane.tools/best-practices/source-control/#source-control 40 | 41 | fastlane/report.xml 42 | fastlane/Preview.html 43 | fastlane/screenshots/**/*.png 44 | fastlane/test_output 45 | 46 | # Code Injection 47 | # 48 | # After new code Injection tools there's a generated folder /iOSInjectionProject 49 | # https://github.com/johnno1962/injectionforxcode 50 | 51 | iOSInjectionProject/ 52 | .DS_Store 53 | 54 | __pycache__ 55 | GenerateApp/build 56 | GuernikaTools/build 57 | GuernikaTools/guernikatools.egg-info 58 | GuernikaModelConverter/GuernikaTools 59 | *.dmg 60 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Assets.xcassets/AppIcon.appiconset/Contents.json: -------------------------------------------------------------------------------- 1 | { 2 | "images" : [ 3 | { 4 | "size" : "16x16", 5 | "idiom" : "mac", 6 | "filename" : "AppIcon-16.png", 7 | "scale" : "1x" 8 | }, 9 | { 10 | "size" : "16x16", 11 | "idiom" : "mac", 12 | "filename" : "AppIcon-32.png", 13 | "scale" : "2x" 14 | }, 15 | { 16 | "size" : "32x32", 17 | "idiom" : "mac", 18 | "filename" : "AppIcon-32.png", 19 | "scale" : "1x" 20 | }, 21 | { 22 | "size" : "32x32", 23 | "idiom" : "mac", 24 | "filename" : "AppIcon-64.png", 25 | "scale" : "2x" 26 | }, 27 | { 28 | "size" : "128x128", 29 | "idiom" : "mac", 30 | "filename" : "AppIcon-128.png", 31 | "scale" : "1x" 32 | }, 33 | { 34 | "size" : "128x128", 35 | "idiom" : "mac", 36 | "filename" : "AppIcon-256.png", 37 | "scale" : "2x" 38 | }, 39 | { 40 | "size" : "256x256", 41 | "idiom" : "mac", 42 | "filename" : "AppIcon-256.png", 43 | "scale" : "1x" 44 | }, 45 | { 46 | "size" : "256x256", 47 | "idiom" : "mac", 48 | "filename" : "AppIcon-512.png", 49 | "scale" : "2x" 50 | }, 51 | { 52 | "size" : "512x512", 53 | "idiom" : "mac", 54 | "filename" : "AppIcon-512.png", 55 | "scale" : "1x" 56 | }, 57 | { 58 | "size" : "512x512", 59 | "idiom" : "mac", 60 | "filename" : "AppIcon-1024.png", 61 | "scale" : "2x" 62 | } 63 | ], 64 | "info" : { 65 | "version" : 1, 66 | "author" : "xcode" 67 | } 68 | } -------------------------------------------------------------------------------- /GuernikaModelConverterUITests/GuernikaModelConverterUITests.swift: -------------------------------------------------------------------------------- 1 | // 2 | // GuernikaModelConverterUITests.swift 3 | // GuernikaModelConverterUITests 4 | // 5 | // Created by Guillermo Cique Fernández on 19/3/23. 6 | // 7 | 8 | import XCTest 9 | 10 | final class GuernikaModelConverterUITests: XCTestCase { 11 | 12 | override func setUpWithError() throws { 13 | // Put setup code here. This method is called before the invocation of each test method in the class. 14 | 15 | // In UI tests it is usually best to stop immediately when a failure occurs. 16 | continueAfterFailure = false 17 | 18 | // In UI tests it’s important to set the initial state - such as interface orientation - required for your tests before they run. The setUp method is a good place to do this. 19 | } 20 | 21 | override func tearDownWithError() throws { 22 | // Put teardown code here. This method is called after the invocation of each test method in the class. 23 | } 24 | 25 | func testExample() throws { 26 | // UI tests must launch the application that they test. 
27 | let app = XCUIApplication() 28 | app.launch() 29 | 30 | // Use XCTAssert and related functions to verify your tests produce the correct results. 31 | } 32 | 33 | func testLaunchPerformance() throws { 34 | if #available(macOS 10.15, iOS 13.0, tvOS 13.0, watchOS 7.0, *) { 35 | // This measures how long it takes to launch your application. 36 | measure(metrics: [XCTApplicationLaunchMetric()]) { 37 | XCUIApplication().launch() 38 | } 39 | } 40 | } 41 | } 42 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Views/DestinationToolbarButton.swift: -------------------------------------------------------------------------------- 1 | // 2 | // DestinationToolbarButton.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 31/3/23. 6 | // 7 | 8 | import SwiftUI 9 | 10 | struct DestinationToolbarButton: View { 11 | @Binding var showOutputPicker: Bool 12 | var outputLocation: URL? 13 | 14 | var body: some View { 15 | ZStack(alignment: .leading) { 16 | Image(systemName: "folder") 17 | .padding(.leading, 18) 18 | .frame(width: 16) 19 | .foregroundColor(.secondary) 20 | .onTapGesture { 21 | guard let outputLocation else { return } 22 | NSWorkspace.shared.open(outputLocation) 23 | } 24 | Button { 25 | showOutputPicker = true 26 | } label: { 27 | Text(outputLocation?.lastPathComponent ?? "Select destination") 28 | .frame(minWidth: 200) 29 | } 30 | .foregroundColor(.primary) 31 | .background(Color.primary.opacity(0.1)) 32 | .clipShape(RoundedRectangle(cornerRadius: 8, style: .continuous)) 33 | .padding(.leading, 34) 34 | }.background { 35 | RoundedRectangle(cornerRadius: 8, style: .continuous) 36 | .stroke(.secondary, lineWidth: 1) 37 | .opacity(0.4) 38 | } 39 | .help("Destination") 40 | .padding(.trailing, 8) 41 | } 42 | } 43 | 44 | struct DestinationToolbarButton_Previews: PreviewProvider { 45 | static var previews: some View { 46 | DestinationToolbarButton(showOutputPicker: .constant(false), outputLocation: nil) 47 | } 48 | } 49 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Model/Version.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Version.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 14/3/23. 
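//  A dot-separated version ("6.5.0") parsed into numeric components; Comparable
//  so the app can tell whether a newer converter build has been published.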
6 | // 7 | 8 | import Foundation 9 | 10 | public struct Version: Comparable, Hashable, CustomStringConvertible { 11 | let components: [Int] 12 | public var major: Int { 13 | guard components.count > 0 else { return 0 } 14 | return components[0] 15 | } 16 | public var minor: Int { 17 | guard components.count > 1 else { return 0 } 18 | return components[1] 19 | } 20 | public var patch: Int { 21 | guard components.count > 2 else { return 0 } 22 | return components[2] 23 | } 24 | 25 | public var description: String { 26 | return components.map { $0.description }.joined(separator: ".") 27 | } 28 | 29 | public init(major: Int, minor: Int, patch: Int) { 30 | self.components = [major, minor, patch] 31 | } 32 | 33 | public static func == (lhs: Version, rhs: Version) -> Bool { 34 | return lhs.major == rhs.major && 35 | lhs.minor == rhs.minor && 36 | lhs.patch == rhs.patch 37 | } 38 | 39 | public static func < (lhs: Version, rhs: Version) -> Bool { 40 | return lhs.major < rhs.major || 41 | (lhs.major == rhs.major && lhs.minor < rhs.minor) || 42 | (lhs.major == rhs.major && lhs.minor == rhs.minor && lhs.patch < rhs.patch) 43 | } 44 | } 45 | 46 | extension Version: Codable { 47 | public init(from decoder: Decoder) throws { 48 | let container = try decoder.singleValueContainer() 49 | let stringValue = try container.decode(String.self) 50 | self.init(stringLiteral: stringValue) 51 | } 52 | 53 | public func encode(to encoder: Encoder) throws { 54 | var container = encoder.singleValueContainer() 55 | try container.encode(description) 56 | } 57 | } 58 | 59 | extension Version: ExpressibleByStringLiteral { 60 | public typealias StringLiteralType = String 61 | 62 | public init(stringLiteral value: StringLiteralType) { 63 | components = value.split(separator: ".").compactMap { Int($0) } 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Model/Logger.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Logger.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 30/3/23. 6 | // 7 | 8 | import SwiftUI 9 | 10 | class Logger: ObservableObject { 11 | static var shared: Logger = Logger() 12 | 13 | enum LogLevel { 14 | case debug 15 | case info 16 | case warning 17 | case error 18 | case success 19 | 20 | var backgroundColor: Color { 21 | switch self { 22 | case .debug: 23 | return .secondary 24 | case .info: 25 | return .blue 26 | case .warning: 27 | return .orange 28 | case .error: 29 | return .red 30 | case .success: 31 | return .green 32 | } 33 | } 34 | } 35 | 36 | var isEmpty: Bool = true 37 | var previousContent: Text? 38 | @Published var content: Text = Text("") 39 | 40 | func append(_ line: String) { 41 | if line.starts(with: "INFO:") { 42 | append(String(line.replacing(try! Regex(#"INFO:.*:"#), with: "")), level: .info) 43 | } else if line.starts(with: "WARNING:") { 44 | append(String(line.replacing(try! Regex(#"WARNING:.*:"#), with: "")), level: .warning) 45 | } else if line.starts(with: "ERROR:") { 46 | append(String(line.replacing(try! Regex(#"ERROR:.*:"#), with: "")), level: .error) 47 | } else { 48 | append(line, level: .debug) 49 | } 50 | } 51 | 52 | func append(_ line: String, level: LogLevel) { 53 | if level == .success { 54 | if previousContent == nil { 55 | previousContent = content 56 | } 57 | content = previousContent! 
+ Text(line + "\n").foregroundColor(level.backgroundColor)
58 |             isEmpty = false
59 |             print(line)
60 |         } else {
61 |             previousContent = nil
62 |             if !line.isEmpty {
63 |                 content = content + Text(line + "\n").foregroundColor(level.backgroundColor)
64 |                 isEmpty = false
65 |                 print(line)
66 |             }
67 |         }
68 |     }
69 |     
70 |     func clear() {
71 |         content = Text("")
72 |         previousContent = nil
73 |         isEmpty = true
74 |     }
75 | }
76 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Guernika Model Converter
2 | 
3 | This repository contains a model converter compatible with [Guernika](https://apps.apple.com/app/id1660407508).
4 | 
5 | ## Converting Models to Guernika
6 | 
7 | **WARNING:** Xcode is required to convert models:
8 | 
9 | - Make sure you have [Xcode](https://apps.apple.com/app/id497799835) installed.
10 | 
11 | - Once installed, run the following commands:
12 | 
13 | ```shell
14 | sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer/
15 | sudo xcodebuild -license accept
16 | ```
17 | 
18 | - You should now be ready to start converting models!
19 | 
20 | **Step 1:** Download and install [`Guernika Model Converter`](https://huggingface.co/Guernika/CoreMLStableDiffusion/resolve/main/GuernikaModelConverter.dmg).
21 | 
22 | **Step 2:** Launch `Guernika Model Converter` from your `Applications` folder; the app may take a few seconds to load.
23 | 
24 | **Step 3:** Once the app has loaded, you will be able to select which model you want to convert:
25 | 
26 | - You can input a model identifier (e.g. CompVis/stable-diffusion-v1-4) to download it from Hugging Face. You may have to log in to or register for your [Hugging Face account](https://huggingface.co), generate a [User Access Token](https://huggingface.co/settings/tokens) and use this token to set up Hugging Face API access by running `huggingface-cli login` in a Terminal window.
27 | 
28 | - You can select a local model from your machine: `Select local model`
29 | 
30 | - You can select a local .CKPT model from your machine: `Select CKPT`
31 | 
32 | *Guernika Model Converter interface*
33 | 
34 | **Step 4:** Once you've chosen the model you want to convert, you can choose which modules to convert and whether to chunk the UNet module (recommended for iOS/iPadOS devices).
35 | 
36 | **Step 5:** Once you're happy with your selection, click `Convert to Guernika` and wait for the app to complete the conversion.
37 | **WARNING:** This step may download several GB of PyTorch checkpoints from Hugging Face and may take a long time to complete (15-20 minutes on an M1 machine).
38 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter/Log/LogView.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  LogView.swift
3 | //  GuernikaModelConverter
4 | //
5 | //  Created by Guillermo Cique Fernández on 30/3/23.
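//  Scrollable log console backed by Logger.shared, with clear and
//  stick-to-bottom controls.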
6 | // 7 | 8 | import SwiftUI 9 | 10 | struct LogView: View { 11 | @ObservedObject var logger = Logger.shared 12 | @State var stickToBottom: Bool = true 13 | 14 | var body: some View { 15 | VStack { 16 | if logger.isEmpty { 17 | emptyView 18 | } else { 19 | contentView 20 | } 21 | }.navigationTitle("Log") 22 | .toolbar { 23 | ToolbarItemGroup { 24 | Button { 25 | logger.clear() 26 | } label: { 27 | Image(systemName: "trash") 28 | }.help("Clear log") 29 | .disabled(logger.isEmpty) 30 | Toggle(isOn: $stickToBottom) { 31 | Image(systemName: "dock.arrow.down.rectangle") 32 | }.help("Stick to bottom") 33 | } 34 | } 35 | } 36 | 37 | @ViewBuilder 38 | var emptyView: some View { 39 | Image(systemName: "moon.zzz.fill") 40 | .resizable() 41 | .aspectRatio(contentMode: .fit) 42 | .frame(width: 72) 43 | .opacity(0.3) 44 | .padding(8) 45 | Text("Log is empty") 46 | .font(.largeTitle) 47 | .opacity(0.3) 48 | } 49 | 50 | @ViewBuilder 51 | var contentView: some View { 52 | ScrollViewReader { proxy in 53 | ScrollView { 54 | logger.content 55 | .multilineTextAlignment(.leading) 56 | .font(.body.monospaced()) 57 | .textSelection(.enabled) 58 | .frame(maxWidth: .infinity, alignment: .leading) 59 | .padding() 60 | 61 | Divider().opacity(0) 62 | .id("bottom") 63 | }.onChange(of: logger.content) { _ in 64 | if stickToBottom { 65 | proxy.scrollTo("bottom", anchor: .bottom) 66 | } 67 | }.onChange(of: stickToBottom) { newValue in 68 | if newValue { 69 | proxy.scrollTo("bottom", anchor: .bottom) 70 | } 71 | }.onAppear { 72 | if stickToBottom { 73 | proxy.scrollTo("bottom", anchor: .bottom) 74 | } 75 | } 76 | } 77 | } 78 | } 79 | 80 | struct LogView_Previews: PreviewProvider { 81 | static var previews: some View { 82 | LogView() 83 | } 84 | } 85 | -------------------------------------------------------------------------------- /GuernikaTools/GuernikaTools.spec: -------------------------------------------------------------------------------- 1 | # -*- mode: python ; coding: utf-8 -*- 2 | from PyInstaller.utils.hooks import collect_all 3 | 4 | import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5) 5 | 6 | datas = [] 7 | binaries = [] 8 | hiddenimports = [] 9 | tmp_ret = collect_all('regex') 10 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 11 | tmp_ret = collect_all('tqdm') 12 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 13 | tmp_ret = collect_all('requests') 14 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 15 | tmp_ret = collect_all('packaging') 16 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 17 | tmp_ret = collect_all('filelock') 18 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 19 | tmp_ret = collect_all('numpy') 20 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 21 | tmp_ret = collect_all('tokenizers') 22 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 23 | tmp_ret = collect_all('transformers') 24 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 25 | tmp_ret = collect_all('huggingface-hub') 26 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 27 | tmp_ret = collect_all('pyyaml') 28 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 29 | tmp_ret = collect_all('omegaconf') 30 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 31 | tmp_ret = collect_all('pytorch_lightning') 32 | datas += tmp_ret[0]; binaries += 
tmp_ret[1]; hiddenimports += tmp_ret[2] 33 | tmp_ret = collect_all('pytorch-lightning') 34 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 35 | tmp_ret = collect_all('torch') 36 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 37 | tmp_ret = collect_all('safetensors') 38 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 39 | tmp_ret = collect_all('pillow') 40 | datas += tmp_ret[0]; binaries += tmp_ret[1]; hiddenimports += tmp_ret[2] 41 | 42 | block_cipher = None 43 | 44 | 45 | a = Analysis( 46 | ['./guernikatools/torch2coreml.py'], 47 | pathex=[], 48 | binaries=binaries, 49 | datas=datas, 50 | hiddenimports=hiddenimports, 51 | hookspath=[], 52 | hooksconfig={}, 53 | runtime_hooks=[], 54 | excludes=[], 55 | win_no_prefer_redirects=False, 56 | win_private_assemblies=False, 57 | cipher=block_cipher, 58 | noarchive=False, 59 | ) 60 | pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) 61 | 62 | exe = EXE( 63 | pyz, 64 | a.scripts, 65 | a.binaries, 66 | a.zipfiles, 67 | a.datas, 68 | [], 69 | name='GuernikaTools', 70 | debug=False, 71 | bootloader_ignore_signals=False, 72 | strip=False, 73 | upx=True, 74 | upx_exclude=[], 75 | runtime_tmpdir=None, 76 | console=False, 77 | disable_windowed_traceback=False, 78 | argv_emulation=False, 79 | target_arch=None, 80 | codesign_identity=None, 81 | entitlements_file=None, 82 | ) 83 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/models/layer_norm.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | # Reference: https://github.com/apple/ml-ane-transformers/blob/main/ane_transformers/reference/layer_norm.py 11 | class LayerNormANE(nn.Module): 12 | """ LayerNorm optimized for Apple Neural Engine (ANE) execution 13 | 14 | Note: This layer only supports normalization over the final dim. It expects `num_channels` 15 | as an argument and not `normalized_shape` which is used by `torch.nn.LayerNorm`. 16 | """ 17 | 18 | def __init__(self, 19 | num_channels, 20 | clip_mag=None, 21 | eps=1e-5, 22 | elementwise_affine=True): 23 | """ 24 | Args: 25 | num_channels: Number of channels (C) where the expected input data format is BC1S. S stands for sequence length. 26 | clip_mag: Optional float value to use for clamping the input range before layer norm is applied. 27 | If specified, helps reduce risk of overflow. 
28 | eps: Small value to avoid dividing by zero 29 | elementwise_affine: If true, adds learnable channel-wise shift (bias) and scale (weight) parameters 30 | """ 31 | super().__init__() 32 | # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine) 33 | self.expected_rank = len("BC1S") 34 | 35 | self.num_channels = num_channels 36 | self.eps = eps 37 | self.clip_mag = clip_mag 38 | self.elementwise_affine = elementwise_affine 39 | 40 | if self.elementwise_affine: 41 | self.weight = nn.Parameter(torch.Tensor(num_channels)) 42 | self.bias = nn.Parameter(torch.Tensor(num_channels)) 43 | 44 | self._reset_parameters() 45 | 46 | def _reset_parameters(self): 47 | if self.elementwise_affine: 48 | nn.init.ones_(self.weight) 49 | nn.init.zeros_(self.bias) 50 | 51 | def forward(self, inputs): 52 | input_rank = len(inputs.size()) 53 | 54 | # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine) 55 | # Migrate the data format from BSC to BC1S (most conducive to ANE) 56 | if input_rank == 3 and inputs.size(2) == self.num_channels: 57 | inputs = inputs.transpose(1, 2).unsqueeze(2) 58 | input_rank = len(inputs.size()) 59 | 60 | assert input_rank == self.expected_rank 61 | assert inputs.size(1) == self.num_channels 62 | 63 | if self.clip_mag is not None: 64 | inputs.clamp_(-self.clip_mag, self.clip_mag) 65 | 66 | channels_mean = inputs.mean(dim=1, keepdim=True) 67 | 68 | zero_mean = inputs - channels_mean 69 | 70 | zero_mean_sq = zero_mean * zero_mean 71 | 72 | denom = (zero_mean_sq.mean(dim=1, keepdim=True) + self.eps).rsqrt() 73 | 74 | out = zero_mean * denom 75 | 76 | if self.elementwise_affine: 77 | out = (out + self.bias.view(1, self.num_channels, 1, 1)) * self.weight.view(1, self.num_channels, 1, 1) 78 | 79 | return out 80 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Navigation/Sidebar.swift: -------------------------------------------------------------------------------- 1 | // 2 | // Sidebar.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 19/3/23. 6 | // 7 | 8 | import Cocoa 9 | import SwiftUI 10 | 11 | enum Panel: Hashable { 12 | case model 13 | case controlNet 14 | case log 15 | } 16 | 17 | struct Sidebar: View { 18 | @Binding var path: NavigationPath 19 | @Binding var selection: Panel 20 | @State var showUpdateButton: Bool = false 21 | 22 | var body: some View { 23 | List(selection: $selection) { 24 | NavigationLink(value: Panel.model) { 25 | Label("Model", systemImage: "shippingbox") 26 | } 27 | NavigationLink(value: Panel.controlNet) { 28 | Label("ControlNet", systemImage: "cube.transparent") 29 | } 30 | NavigationLink(value: Panel.log) { 31 | Label("Log", systemImage: "doc.text.below.ecg") 32 | } 33 | } 34 | .safeAreaInset(edge: .bottom, content: { 35 | VStack(spacing: 12) { 36 | if showUpdateButton { 37 | Button { 38 | NSWorkspace.shared.open(URL(string: "https://huggingface.co/Guernika/CoreMLStableDiffusion/blob/main/GuernikaModelConverter.dmg")!) 39 | } label: { 40 | Text("Update available") 41 | .frame(minWidth: 168) 42 | }.controlSize(.large) 43 | .buttonStyle(.borderedProminent) 44 | } 45 | if !FileManager.default.fileExists(atPath: "/Applications/Guernika.app") { 46 | Button { 47 | NSWorkspace.shared.open(URL(string: "macappstore://apps.apple.com/app/id1660407508")!) 
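                    // macappstore:// deep-links straight to Guernika's App Store product page.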
48 |                     } label: {
49 |                         Text("Install Guernika")
50 |                             .frame(minWidth: 168)
51 |                     }.controlSize(.large)
52 |                 }
53 |             }
54 |             .padding(16)
55 |         })
56 |         .navigationTitle("Guernika Model Converter")
57 |         .navigationSplitViewColumnWidth(min: 200, ideal: 200)
58 |         .onAppear { checkForUpdate() }
59 |     }
60 | 
61 |     func checkForUpdate() {
62 |         Task.detached {
63 |             let versionUrl = URL(string: "https://huggingface.co/Guernika/CoreMLStableDiffusion/raw/main/GuernikaModelConverter.version")!
64 |             guard let lastVersionString = try? String(contentsOf: versionUrl) else { return }
65 |             let lastVersion = Version(stringLiteral: lastVersionString)
66 |             let currentVersionString = Bundle.main.object(forInfoDictionaryKey: "CFBundleShortVersionString") as? String ?? ""
67 |             let currentVersion = Version(stringLiteral: currentVersionString)
68 |             await MainActor.run {
69 |                 withAnimation {
70 |                     showUpdateButton = currentVersion < lastVersion
71 |                 }
72 |             }
73 |         }
74 |     }
75 | }
76 | 
77 | struct Sidebar_Previews: PreviewProvider {
78 |     struct Preview: View {
79 |         @State private var selection: Panel = Panel.model
80 |         var body: some View {
81 |             Sidebar(path: .constant(NavigationPath()), selection: $selection)
82 |         }
83 |     }
84 | 
85 |     static var previews: some View {
86 |         NavigationSplitView {
87 |             Preview()
88 |         } detail: {
89 |             Text("Detail!")
90 |         }
91 |     }
92 | }
93 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter/Views/IntegerField.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  IntegerField.swift
3 | //  GuernikaModelConverter
4 | //
5 | //  Created by Guillermo Cique Fernández on 23/1/22.
6 | //
7 | 
8 | import SwiftUI
9 | 
10 | struct LabeledIntegerField<Content: View>: View {
11 |     @Binding var value: Int
12 |     var step: Int = 1
13 |     var minValue: Int?
14 |     var maxValue: Int?
15 |     @ViewBuilder var label: () -> Content
16 | 
17 |     init(
18 |         _ titleKey: LocalizedStringKey,
19 |         value: Binding<Int>,
20 |         step: Int = 1,
21 |         minValue: Int? = nil,
22 |         maxValue: Int? = nil
23 |     ) where Content == Text {
24 |         self.init(value: value, step: step, minValue: minValue, maxValue: maxValue, label: {
25 |             Text(titleKey)
26 |         })
27 |     }
28 | 
29 |     init(
30 |         value: Binding<Int>,
31 |         step: Int = 1,
32 |         minValue: Int? = nil,
33 |         maxValue: Int? = nil,
34 |         @ViewBuilder label: @escaping () -> Content
35 |     ) {
36 |         self.label = label
37 |         self.step = step
38 |         self.minValue = minValue
39 |         self.maxValue = maxValue
40 |         self._value = value
41 |     }
42 | 
43 |     var body: some View {
44 |         LabeledContent(content: {
45 |             IntegerField(
46 |                 value: $value,
47 |                 step: step,
48 |                 minValue: minValue,
49 |                 maxValue: maxValue
50 |             )
51 |             .frame(maxWidth: 120)
52 |         }, label: label)
53 |     }
54 | }
55 | 
56 | struct IntegerField: View {
57 |     @Binding var value: Int
58 |     var step: Int = 1
59 |     var minValue: Int?
60 |     var maxValue: Int?
61 |     @State private var text: String
62 |     @FocusState private var isFocused: Bool
63 | 
64 |     init(
65 |         value: Binding<Int>,
66 |         step: Int = 1,
67 |         minValue: Int? = nil,
68 |         maxValue: Int?
= nil 69 | ) { 70 | self.step = step 71 | self.minValue = minValue 72 | self.maxValue = maxValue 73 | self._value = value 74 | let text = String(describing: value.wrappedValue) 75 | self._text = State(wrappedValue: text) 76 | } 77 | 78 | var body: some View { 79 | HStack(spacing: 0) { 80 | TextField("", text: $text, prompt: Text("Value")) 81 | .multilineTextAlignment(.trailing) 82 | .textFieldStyle(.plain) 83 | .padding(.horizontal, 10) 84 | .submitLabel(.done) 85 | .focused($isFocused) 86 | .frame(minWidth: 70) 87 | .labelsHidden() 88 | #if !os(macOS) 89 | .keyboardType(.numberPad) 90 | #endif 91 | Stepper(label: {}, onIncrement: { 92 | if let maxValue { 93 | value = min(value + step, maxValue) 94 | } else { 95 | value += step 96 | } 97 | }, onDecrement: { 98 | if let minValue { 99 | value = max(value - step, minValue) 100 | } else { 101 | value -= step 102 | } 103 | }).labelsHidden() 104 | } 105 | #if os(macOS) 106 | .padding(3) 107 | .background(Color.primary.opacity(0.06)) 108 | .clipShape(RoundedRectangle(cornerRadius: 8, style: .continuous)) 109 | #else 110 | .padding(2) 111 | .background(Color.primary.opacity(0.05)) 112 | .clipShape(RoundedRectangle(cornerRadius: 10, style: .continuous)) 113 | #endif 114 | .onSubmit { 115 | updateValue() 116 | isFocused = false 117 | } 118 | .onChange(of: isFocused) { focused in 119 | if !focused { 120 | updateValue() 121 | } 122 | } 123 | .onChange(of: value) { _ in updateText() } 124 | .onChange(of: text) { _ in 125 | if let newValue = Int(text), value != newValue { 126 | value = newValue 127 | } 128 | } 129 | #if !os(macOS) 130 | .toolbar { 131 | if isFocused { 132 | ToolbarItem(placement: .keyboard) { 133 | HStack { 134 | Spacer() 135 | Button("Done") { 136 | isFocused = false 137 | } 138 | } 139 | } 140 | } 141 | } 142 | #endif 143 | } 144 | 145 | private func updateValue() { 146 | if let newValue = Int(text) { 147 | if let maxValue, newValue > maxValue { 148 | value = maxValue 149 | } else if let minValue, newValue < minValue { 150 | value = minValue 151 | } else { 152 | value = newValue 153 | } 154 | } 155 | } 156 | 157 | private func updateText() { 158 | text = String(describing: value) 159 | } 160 | } 161 | 162 | struct ValueField_Previews: PreviewProvider { 163 | static var previews: some View { 164 | LabeledIntegerField("Text", value: .constant(20)) 165 | } 166 | } 167 | 168 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/convert/convert_text_encoder.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 
4 | #
5 | 
6 | from guernikatools._version import __version__
7 | from guernikatools.utils import utils
8 | 
9 | from collections import OrderedDict, defaultdict
10 | from copy import deepcopy
11 | import gc
12 | 
13 | import logging
14 | 
15 | logging.basicConfig()
16 | logger = logging.getLogger(__name__)
17 | logger.setLevel(logging.INFO)
18 | 
19 | import numpy as np
20 | import os
21 | 
22 | import torch
23 | import torch.nn as nn
24 | import torch.nn.functional as F
25 | 
26 | torch.set_grad_enabled(False)
27 | 
28 | 
29 | def main(tokenizer, text_encoder, args, model_name="text_encoder"):
30 |     """ Converts the text encoder component of Stable Diffusion
31 |     """
32 |     out_path = utils.get_out_path(args, model_name)
33 |     if os.path.exists(out_path):
34 |         logger.info(
35 |             f"`{model_name}` already exists at {out_path}, skipping conversion."
36 |         )
37 |         return
38 | 
39 |     # Create sample inputs for tracing, conversion and correctness verification
40 |     text_encoder_sequence_length = tokenizer.model_max_length
41 |     text_encoder_hidden_size = text_encoder.config.hidden_size
42 | 
43 |     sample_text_encoder_inputs = {
44 |         "input_ids":
45 |         torch.randint(
46 |             text_encoder.config.vocab_size,
47 |             (1, text_encoder_sequence_length),
48 |             # https://github.com/apple/coremltools/issues/1423
49 |             dtype=torch.float32,
50 |         )
51 |     }
52 |     sample_text_encoder_inputs_spec = {
53 |         k: (v.shape, v.dtype)
54 |         for k, v in sample_text_encoder_inputs.items()
55 |     }
56 |     logger.info(f"Sample inputs spec: {sample_text_encoder_inputs_spec}")
57 | 
58 |     class TextEncoder(nn.Module):
59 | 
60 |         def __init__(self):
61 |             super().__init__()
62 |             self.text_encoder = text_encoder
63 | 
64 |         def forward(self, input_ids):
65 |             return text_encoder(input_ids, return_dict=False)
66 | 
67 |     class TextEncoderXL(nn.Module):
68 | 
69 |         def __init__(self):
70 |             super().__init__()
71 |             self.text_encoder = text_encoder
72 | 
73 |         def forward(self, input_ids):
74 |             output = text_encoder(input_ids, output_hidden_states=True)
75 |             return (output.hidden_states[-2], output[0])
76 | 
77 |     reference_text_encoder = TextEncoderXL().eval() if args.model_is_sdxl else TextEncoder().eval()
78 | 
79 |     logger.info(f"JIT tracing {model_name}..")
80 |     reference_text_encoder = torch.jit.trace(
81 |         reference_text_encoder,
82 |         (sample_text_encoder_inputs["input_ids"].to(torch.int32), ),
83 |     )
84 |     logger.info("Done.")
85 | 
86 |     sample_coreml_inputs = utils.get_coreml_inputs(sample_text_encoder_inputs)
87 |     coreml_text_encoder, out_path = utils.convert_to_coreml(
88 |         model_name, reference_text_encoder, sample_coreml_inputs,
89 |         ["last_hidden_state", "pooled_outputs"], args.precision_full, args
90 |     )
91 | 
92 |     # Set model metadata
93 |     coreml_text_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
94 |     coreml_text_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
95 |     coreml_text_encoder.version = args.model_version
96 |     coreml_text_encoder.short_description = \
97 |         "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \
98 |         "Please refer to https://arxiv.org/abs/2112.10752 for details."
99 | 
100 |     # Set the input descriptions
101 |     coreml_text_encoder.input_description["input_ids"] = "The token ids that represent the input text"
102 | 
103 |     # Set the output descriptions
104 |     coreml_text_encoder.output_description["last_hidden_state"] = "The token embeddings as encoded by the Transformer model"
105 |     coreml_text_encoder.output_description["pooled_outputs"] = "The version of the `last_hidden_state` output after pooling"
106 | 
107 |     # Set package version metadata
108 |     coreml_text_encoder.user_defined_metadata["identifier"] = args.model_version
109 |     coreml_text_encoder.user_defined_metadata["converter_version"] = __version__
110 |     coreml_text_encoder.user_defined_metadata["attention_implementation"] = args.attention_implementation
111 |     coreml_text_encoder.user_defined_metadata["compute_unit"] = args.compute_unit
112 |     coreml_text_encoder.user_defined_metadata["hidden_size"] = str(text_encoder.config.hidden_size)
113 | 
114 |     coreml_text_encoder.save(out_path)
115 | 
116 |     logger.info(f"Saved {model_name} into {out_path}")
117 | 
118 |     # Parity check PyTorch vs CoreML
119 |     if args.check_output_correctness:
120 |         baseline_out = text_encoder(
121 |             sample_text_encoder_inputs["input_ids"].to(torch.int32),
122 |             return_dict=False,
123 |         )[1].numpy()
124 | 
125 |         coreml_out = list(coreml_text_encoder.predict({
126 |             k: v.numpy() for k, v in sample_text_encoder_inputs.items()
127 |         }).values())[0]
128 |         utils.report_correctness(baseline_out, coreml_out, f"{model_name} baseline PyTorch to reference CoreML")
129 | 
130 |     del reference_text_encoder, coreml_text_encoder, text_encoder
131 |     gc.collect()
132 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter/Views/DecimalField.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  DecimalField.swift
3 | //  GuernikaModelConverter
4 | //
5 | //  Created by Guillermo Cique Fernández on 5/2/23.
6 | //
7 | 
8 | import SwiftUI
9 | 
10 | extension Formatter {
11 |     static let decimal: NumberFormatter = {
12 |         let formatter = NumberFormatter()
13 |         formatter.numberStyle = .decimal
14 |         formatter.usesGroupingSeparator = false
15 |         formatter.maximumFractionDigits = 2
16 |         return formatter
17 |     }()
18 | }
19 | 
20 | struct LabeledDecimalField<Content: View>: View {
21 |     @Binding var value: Double
22 |     var step: Double = 1
23 |     var minValue: Double?
24 |     var maxValue: Double?
25 |     @ViewBuilder var label: () -> Content
26 | 
27 |     init(
28 |         _ titleKey: LocalizedStringKey,
29 |         value: Binding<Double>,
30 |         step: Double = 1,
31 |         minValue: Double? = nil,
32 |         maxValue: Double? = nil
33 |     ) where Content == Text {
34 |         self.init(value: value, step: step, minValue: minValue, maxValue: maxValue, label: {
35 |             Text(titleKey)
36 |         })
37 |     }
38 | 
39 |     init(
40 |         value: Binding<Double>,
41 |         step: Double = 1,
42 |         minValue: Double? = nil,
43 |         maxValue: Double? = nil,
44 |         @ViewBuilder label: @escaping () -> Content
45 |     ) {
46 |         self.label = label
47 |         self.step = step
48 |         self.minValue = minValue
49 |         self.maxValue = maxValue
50 |         self._value = value
51 |     }
52 | 
53 |     var body: some View {
54 |         LabeledContent(content: {
55 |             DecimalField(
56 |                 value: $value,
57 |                 step: step,
58 |                 minValue: minValue,
59 |                 maxValue: maxValue
60 |             )
61 |         }, label: label)
62 |     }
63 | }
64 | 
65 | struct DecimalField: View {
66 |     @Binding var value: Double
67 |     var step: Double = 1
68 |     var minValue: Double?
69 |     var maxValue: Double?
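    // step/minValue/maxValue drive the stepper and clamp whatever the user
    // types once the field commits (see updateValue() below).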
70 |     @State private var text: String
71 |     @FocusState private var isFocused: Bool
72 | 
73 |     init(
74 |         value: Binding<Double>,
75 |         step: Double = 1,
76 |         minValue: Double? = nil,
77 |         maxValue: Double? = nil
78 |     ) {
79 |         self.step = step
80 |         self.minValue = minValue
81 |         self.maxValue = maxValue
82 |         self._value = value
83 |         let text = Formatter.decimal.string(from: value.wrappedValue as NSNumber) ?? ""
84 |         self._text = State(wrappedValue: text)
85 |     }
86 | 
87 |     var body: some View {
88 |         HStack(spacing: 0) {
89 |             TextField("", text: $text, prompt: Text("Value"))
90 |                 .multilineTextAlignment(.trailing)
91 |                 .textFieldStyle(.plain)
92 |                 .padding(.horizontal, 10)
93 |                 .submitLabel(.done)
94 |                 .focused($isFocused)
95 |                 .frame(minWidth: 70)
96 |                 .labelsHidden()
97 |                 #if !os(macOS)
98 |                 .keyboardType(.decimalPad)
99 |                 #endif
100 |             Stepper(label: {}, onIncrement: {
101 |                 if let maxValue {
102 |                     value = min(value + step, maxValue)
103 |                 } else {
104 |                     value += step
105 |                 }
106 |             }, onDecrement: {
107 |                 if let minValue {
108 |                     value = max(value - step, minValue)
109 |                 } else {
110 |                     value -= step
111 |                 }
112 |             }).labelsHidden()
113 |         }
114 |         #if os(macOS)
115 |         .padding(3)
116 |         .background(Color.primary.opacity(0.06))
117 |         .clipShape(RoundedRectangle(cornerRadius: 8, style: .continuous))
118 |         #else
119 |         .padding(2)
120 |         .background(Color.primary.opacity(0.05))
121 |         .clipShape(RoundedRectangle(cornerRadius: 10, style: .continuous))
122 |         #endif
123 |         .onSubmit {
124 |             updateValue()
125 |             isFocused = false
126 |         }
127 |         .onChange(of: isFocused) { focused in
128 |             if !focused {
129 |                 updateValue()
130 |             }
131 |         }
132 |         .onChange(of: value) { _ in updateText() }
133 |         .onChange(of: text) { _ in
134 |             if let newValue = Formatter.decimal.number(from: text)?.doubleValue, value != newValue {
135 |                 value = newValue
136 |             }
137 |         }
138 |         #if !os(macOS)
139 |         .toolbar {
140 |             if isFocused {
141 |                 ToolbarItem(placement: .keyboard) {
142 |                     HStack {
143 |                         Spacer()
144 |                         Button("Done") {
145 |                             isFocused = false
146 |                         }
147 |                     }
148 |                 }
149 |             }
150 |         }
151 |         #endif
152 |     }
153 | 
154 |     private func updateValue() {
155 |         if let newValue = Formatter.decimal.number(from: text)?.doubleValue {
156 |             if let maxValue, newValue > maxValue {
157 |                 value = maxValue
158 |             } else if let minValue, newValue < minValue {
159 |                 value = minValue
160 |             } else {
161 |                 value = newValue
162 |             }
163 |         }
164 |     }
165 | 
166 |     private func updateText() {
167 |         text = Formatter.decimal.string(from: value as NSNumber) ?? ""
168 |     }
169 | }
170 | 
171 | 
172 | struct DecimalField_Previews: PreviewProvider {
173 |     static var previews: some View {
174 |         DecimalField(value: .constant(20))
175 |             .padding()
176 |     }
177 | }
178 | 
--------------------------------------------------------------------------------
/GuernikaModelConverter/ConvertControlNet/ConvertControlNetViewModel.swift:
--------------------------------------------------------------------------------
1 | //
2 | //  ConvertControlNetViewModel.swift
3 | //  GuernikaModelConverter
4 | //
5 | //  Created by Guillermo Cique Fernández on 30/3/23.
6 | //
7 | 
8 | import SwiftUI
9 | import Combine
10 | 
11 | class ConvertControlNetViewModel: ObservableObject {
12 |     @Published var showSuccess: Bool = false
13 |     var showError: Bool {
14 |         get { error != nil }
15 |         set {
16 |             if !newValue {
17 |                 error = nil
18 |                 isCoreMLError = false
19 |             }
20 |         }
21 |     }
22 |     var isCoreMLError: Bool = false
23 |     @Published var error: String?
24 | 
25 |     var isReady: Bool {
26 |         switch controlNetOrigin {
27 |         case .huggingface: return !huggingfaceIdentifier.isEmpty
28 |         case .diffusers: return diffusersLocation != nil
29 |         case .checkpoint: return checkpointLocation != nil
30 |         }
31 |     }
32 |     @Published var process: ConverterProcess?
33 |     var isRunning: Bool { process != nil }
34 | 
35 |     @Published var showOutputPicker: Bool = false
36 |     @AppStorage("output_location") var outputLocation: URL?
37 | 
38 |     var controlNetOrigin: ModelOrigin {
39 |         get { ModelOrigin(rawValue: controlNetOriginString) ?? .huggingface }
40 |         set { controlNetOriginString = newValue.rawValue }
41 |     }
42 |     @AppStorage("controlnet_origin") var controlNetOriginString: String = ModelOrigin.huggingface.rawValue
43 |     @AppStorage("controlnet_huggingface_identifier") var huggingfaceIdentifier: String = ""
44 |     @AppStorage("controlnet_diffusers_location") var diffusersLocation: URL?
45 |     @AppStorage("controlnet_checkpoint_location") var checkpointLocation: URL?
46 |     var selectedControlNet: String? {
47 |         switch controlNetOrigin {
48 |         case .huggingface: return nil
49 |         case .diffusers: return diffusersLocation?.lastPathComponent
50 |         case .checkpoint: return checkpointLocation?.lastPathComponent
51 |         }
52 |     }
53 | 
54 |     var computeUnits: ComputeUnits {
55 |         get { ComputeUnits(rawValue: computeUnitsString) ?? .cpuAndNeuralEngine }
56 |         set { computeUnitsString = newValue.rawValue }
57 |     }
58 |     @AppStorage("compute_units") var computeUnitsString: String = ComputeUnits.cpuAndNeuralEngine.rawValue
59 | 
60 |     var compression: Compression {
61 |         get { Compression(rawValue: compressionString) ?? .fullSize }
62 |         set { compressionString = newValue.rawValue }
63 |     }
64 |     @AppStorage("controlnet_compression") var compressionString: String = Compression.fullSize.rawValue
65 | 
66 |     @AppStorage("custom_size") var customSize: Bool = false
67 |     @AppStorage("custom_size_width") var customWidth: Int = 512
68 |     @AppStorage("custom_size_height") var customHeight: Int = 512
69 |     @AppStorage("controlnet_multisize") var multisize: Bool = false
70 | 
71 |     private var cancellables: Set<AnyCancellable> = []
72 | 
73 |     func start() {
74 |         guard checkCoreMLCompiler() else {
75 |             isCoreMLError = true
76 |             error = "CoreMLCompiler not available.\nMake sure you have Xcode installed and you run \"sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer/\" on a Terminal."
77 |             return
78 |         }
79 |         do {
80 |             let process = try ConverterProcess(
81 |                 outputLocation: outputLocation,
82 |                 controlNetOrigin: controlNetOrigin,
83 |                 huggingfaceIdentifier: huggingfaceIdentifier,
84 |                 diffusersLocation: diffusersLocation,
85 |                 checkpointLocation: checkpointLocation,
86 |                 computeUnits: computeUnits,
87 |                 customWidth: customSize && !multisize && computeUnits == .cpuAndGPU ? customWidth : nil,
88 |                 customHeight: customSize && !multisize && computeUnits == .cpuAndGPU ?
customHeight : nil, 89 | multisize: multisize && computeUnits == .cpuAndGPU, 90 | compression: compression 91 | ) 92 | process.objectWillChange 93 | .receive(on: DispatchQueue.main) 94 | .sink { _ in 95 | self.objectWillChange.send() 96 | }.store(in: &cancellables) 97 | NotificationCenter.default.publisher(for: Process.didTerminateNotification, object: process.process) 98 | .receive(on: DispatchQueue.main) 99 | .sink { _ in 100 | withAnimation { 101 | if process.didComplete { 102 | self.showSuccess = true 103 | } 104 | self.cancel() 105 | } 106 | }.store(in: &cancellables) 107 | try process.start() 108 | withAnimation { 109 | self.process = process 110 | } 111 | } catch ConverterProcess.ArgumentError.noOutputLocation { 112 | showOutputPicker = true 113 | } catch ConverterProcess.ArgumentError.noHuggingfaceIdentifier { 114 | self.error = "Enter a valid identifier" 115 | } catch ConverterProcess.ArgumentError.noDiffusersLocation { 116 | self.error = "Enter a valid location" 117 | } catch ConverterProcess.ArgumentError.noCheckpointLocation { 118 | self.error = "Enter a valid location" 119 | } catch { 120 | print(error.localizedDescription) 121 | self.error = error.localizedDescription 122 | withAnimation { 123 | cancel() 124 | } 125 | } 126 | } 127 | 128 | func cancel() { 129 | withAnimation { 130 | process?.cancel() 131 | process = nil 132 | } 133 | } 134 | 135 | func checkCoreMLCompiler() -> Bool { 136 | let process = Process() 137 | process.executableURL = URL(string: "file:///usr/bin/xcrun") 138 | process.arguments = ["coremlcompiler", "version"] 139 | do { 140 | try process.run() 141 | } catch { 142 | Logger.shared.append("coremlcompiler not found") 143 | return false 144 | } 145 | return true 146 | } 147 | } 148 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/convert/convert_t2i_adapter.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 
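For orientation before the T2I-Adapter converter below: its `main(base_adapter, args)` entry point takes a loaded `diffusers.T2IAdapter` plus an argparse-style namespace. A hedged driver sketch — the attribute names mirror what the module reads, but the values and the adapter id are placeholders, not a published CLI:

```python
# Hypothetical driver for convert_t2i_adapter below; all values are assumptions.
from types import SimpleNamespace
from diffusers import T2IAdapter
from guernikatools.convert import convert_t2i_adapter

adapter = T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v1")
args = SimpleNamespace(
    o="output",                        # output directory
    model_version="example/base-model",
    t2i_adapter_version="TencentARC/t2iadapter_canny_sd15v1",
    attention_implementation="SPLIT_EINSUM",
    compute_unit="CPU_AND_NE",
    output_h=512, output_w=512,
    multisize=False,
    precision_full=False,
    check_output_correctness=False,
    clean_up_mlpackages=False,
)
convert_t2i_adapter.main(adapter, args)
```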
4 | # 5 | 6 | from guernikatools._version import __version__ 7 | from guernikatools.utils import utils 8 | from guernikatools.models import attention 9 | 10 | from diffusers import T2IAdapter 11 | 12 | from collections import OrderedDict, defaultdict 13 | from copy import deepcopy 14 | import coremltools as ct 15 | import gc 16 | 17 | import logging 18 | 19 | logging.basicConfig() 20 | logger = logging.getLogger(__name__) 21 | logger.setLevel(logging.INFO) 22 | 23 | import numpy as np 24 | import os 25 | 26 | import torch 27 | import torch.nn as nn 28 | import torch.nn.functional as F 29 | 30 | torch.set_grad_enabled(False) 31 | 32 | 33 | def main(base_adapter, args): 34 | """ Converts a T2IAdapter 35 | """ 36 | out_path = utils.get_out_path(args, "t2i_adapter") 37 | if os.path.exists(out_path): 38 | logger.info(f"`t2i_adapter` already exists at {out_path}, skipping conversion.") 39 | return 40 | 41 | # Register the selected attention implementation globally 42 | attention.ATTENTION_IMPLEMENTATION_IN_EFFECT = attention.AttentionImplementations[args.attention_implementation] 43 | logger.info(f"Attention implementation in effect: {attention.ATTENTION_IMPLEMENTATION_IN_EFFECT}") 44 | 45 | # Prepare sample input shapes and values 46 | batch_size = 2 # for classifier-free guidance 47 | adapter_type = base_adapter.config.adapter_type 48 | adapter_in_channels = base_adapter.config.in_channels 49 | 50 | input_shape = ( 51 | 1, # B 52 | adapter_in_channels, # C 53 | args.output_h, # H 54 | args.output_w, # W 55 | ) 56 | 57 | sample_adapter_inputs = { 58 | "input": torch.rand(*input_shape, dtype=torch.float16) 59 | } 60 | sample_adapter_inputs_spec = { 61 | k: (v.shape, v.dtype) 62 | for k, v in sample_adapter_inputs.items() 63 | } 64 | logger.info(f"Sample inputs spec: {sample_adapter_inputs_spec}") 65 | 66 | # Initialize reference adapter 67 | reference_adapter = T2IAdapter(**base_adapter.config).eval() 68 | load_state_dict_summary = reference_adapter.load_state_dict(base_adapter.state_dict()) 69 | 70 | # Prepare inputs 71 | baseline_sample_adapter_inputs = deepcopy(sample_adapter_inputs) 72 | 73 | # JIT trace 74 | logger.info("JIT tracing..") 75 | reference_adapter = torch.jit.trace(reference_adapter, (sample_adapter_inputs["input"].to(torch.float32), )) 76 | logger.info("Done.") 77 | 78 | if args.check_output_correctness: 79 | baseline_out = base_adapter(**baseline_sample_adapter_inputs, return_dict=False)[0].numpy() 80 | reference_out = reference_adapter(**sample_adapter_inputs)[0].numpy() 81 | utils.report_correctness(baseline_out, reference_out, "control baseline to reference PyTorch") 82 | 83 | del base_adapter 84 | gc.collect() 85 | 86 | coreml_sample_adapter_inputs = { 87 | k: v.numpy().astype(np.float16) 88 | for k, v in sample_adapter_inputs.items() 89 | } 90 | 91 | if args.multisize: 92 | input_size = args.output_h 93 | input_shape = ct.Shape(shape=( 94 | 1, 95 | adapter_in_channels, 96 | ct.RangeDim(int(input_size * 0.5), upper_bound=int(input_size * 2), default=input_size), 97 | ct.RangeDim(int(input_size * 0.5), upper_bound=int(input_size * 2), default=input_size) 98 | )) 99 | 100 | sample_coreml_inputs = utils.get_coreml_inputs(coreml_sample_adapter_inputs, {"input": input_shape}) 101 | else: 102 | sample_coreml_inputs = utils.get_coreml_inputs(coreml_sample_adapter_inputs) 103 | output_names = [ 104 | "adapter_res_samples_00", "adapter_res_samples_01", 105 | "adapter_res_samples_02", "adapter_res_samples_03" 106 | ] 107 | coreml_adapter, out_path = utils.convert_to_coreml( 108 | 
"t2i_adapter", 109 | reference_adapter, 110 | sample_coreml_inputs, 111 | output_names, 112 | args.precision_full, 113 | args 114 | ) 115 | del reference_adapter 116 | gc.collect() 117 | 118 | # Set model metadata 119 | coreml_adapter.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}" 120 | coreml_adapter.license = "T2IAdapter (https://github.com/TencentARC/T2I-Adapter)" 121 | coreml_adapter.version = args.t2i_adapter_version 122 | coreml_adapter.short_description = \ 123 | "T2IAdapter is a neural network structure to control diffusion models by adding extra conditions. " \ 124 | "Please refer to https://github.com/TencentARC/T2I-Adapter for details." 125 | 126 | # Set the input descriptions 127 | coreml_adapter.input_description["input"] = "Image used to condition adapter output" 128 | 129 | # Set the output descriptions 130 | coreml_adapter.output_description["adapter_res_samples_00"] = "Residual sample from T2IAdapter" 131 | coreml_adapter.output_description["adapter_res_samples_01"] = "Residual sample from T2IAdapter" 132 | coreml_adapter.output_description["adapter_res_samples_02"] = "Residual sample from T2IAdapter" 133 | coreml_adapter.output_description["adapter_res_samples_03"] = "Residual sample from T2IAdapter" 134 | 135 | # Set package version metadata 136 | coreml_adapter.user_defined_metadata["identifier"] = args.t2i_adapter_version 137 | coreml_adapter.user_defined_metadata["converter_version"] = __version__ 138 | coreml_adapter.user_defined_metadata["attention_implementation"] = args.attention_implementation 139 | coreml_adapter.user_defined_metadata["compute_unit"] = args.compute_unit 140 | coreml_adapter.user_defined_metadata["adapter_type"] = adapter_type 141 | adapter_method = conditioning_method_from(args.t2i_adapter_version) 142 | if adapter_method: 143 | coreml_adapter.user_defined_metadata["method"] = adapter_method 144 | 145 | coreml_adapter.save(out_path) 146 | logger.info(f"Saved adapter into {out_path}") 147 | 148 | # Parity check PyTorch vs CoreML 149 | if args.check_output_correctness: 150 | coreml_out = list(coreml_adapter.predict(coreml_sample_adapter_inputs).values())[0] 151 | utils.report_correctness(baseline_out, coreml_out, "control baseline PyTorch to reference CoreML") 152 | 153 | del coreml_adapter 154 | gc.collect() 155 | -------------------------------------------------------------------------------- /GuernikaModelConverter/ConvertModel/ConvertModelViewModel.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ConvertModelViewModel.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 30/3/23. 6 | // 7 | 8 | import SwiftUI 9 | import Combine 10 | 11 | class ConvertModelViewModel: ObservableObject { 12 | @Published var showSuccess: Bool = false 13 | var showError: Bool { 14 | get { error != nil } 15 | set { 16 | if !newValue { 17 | error = nil 18 | isCoreMLError = false 19 | } 20 | } 21 | } 22 | var isCoreMLError: Bool = false 23 | @Published var error: String? 24 | 25 | var isReady: Bool { 26 | guard convertUnet || convertTextEncoder || convertVaeEncoder || convertVaeDecoder || convertSafetyChecker else { 27 | return false 28 | } 29 | switch modelOrigin { 30 | case .huggingface: return !huggingfaceIdentifier.isEmpty 31 | case .diffusers: return diffusersLocation != nil 32 | case .checkpoint: return checkpointLocation != nil 33 | } 34 | } 35 | @Published var process: ConverterProcess? 
36 |     var isRunning: Bool { process != nil }
37 | 
38 |     @Published var showOutputPicker: Bool = false
39 |     @AppStorage("output_location") var outputLocation: URL?
40 |     var modelOrigin: ModelOrigin {
41 |         get { ModelOrigin(rawValue: modelOriginString) ?? .huggingface }
42 |         set { modelOriginString = newValue.rawValue }
43 |     }
44 |     @AppStorage("model_origin") var modelOriginString: String = ModelOrigin.huggingface.rawValue
45 |     @AppStorage("huggingface_identifier") var huggingfaceIdentifier: String = ""
46 |     @AppStorage("diffusers_location") var diffusersLocation: URL?
47 |     @AppStorage("checkpoint_location") var checkpointLocation: URL?
48 |     var selectedModel: String? {
49 |         switch modelOrigin {
50 |         case .huggingface: return nil
51 |         case .diffusers: return diffusersLocation?.lastPathComponent
52 |         case .checkpoint: return checkpointLocation?.lastPathComponent
53 |         }
54 |     }
55 | 
56 |     var computeUnits: ComputeUnits {
57 |         get { ComputeUnits(rawValue: computeUnitsString) ?? .cpuAndNeuralEngine }
58 |         set { computeUnitsString = newValue.rawValue }
59 |     }
60 |     @AppStorage("compute_units") var computeUnitsString: String = ComputeUnits.cpuAndNeuralEngine.rawValue
61 | 
62 |     var compression: Compression {
63 |         get { Compression(rawValue: compressionString) ?? .fullSize }
64 |         set { compressionString = newValue.rawValue }
65 |     }
66 |     @AppStorage("compression") var compressionString: String = Compression.fullSize.rawValue
67 | 
68 |     @AppStorage("custom_size") var customSize: Bool = false
69 |     @AppStorage("custom_size_width") var customWidth: Int = 512
70 |     @AppStorage("custom_size_height") var customHeight: Int = 512
71 |     @AppStorage("multisize") var multisize: Bool = false
72 | 
73 |     @AppStorage("convert_unet") var convertUnet: Bool = true
74 |     @AppStorage("chunk_unet") var chunkUnet: Bool = false
75 |     @AppStorage("controlnet_support") var controlNetSupport: Bool = true
76 |     @AppStorage("convert_text_encoder") var convertTextEncoder: Bool = true
77 |     @AppStorage("load_embeddings") var loadEmbeddings: Bool = false
78 |     @AppStorage("embeddings_location") var embeddingsLocation: URL?
79 |     @AppStorage("convert_vae_encoder") var convertVaeEncoder: Bool = true
80 |     @AppStorage("convert_vae_decoder") var convertVaeDecoder: Bool = true
81 |     @AppStorage("convert_safety_checker") var convertSafetyChecker: Bool = false
82 |     @AppStorage("precision_full") var precisionFull: Bool = false
83 |     @Published var loRAsToMerge: [LoRAInfo] = []
84 | 
85 |     private var cancellables: Set<AnyCancellable> = []
86 | 
87 |     func start() {
88 |         guard checkCoreMLCompiler() else {
89 |             isCoreMLError = true
90 |             error = "CoreMLCompiler not available.\nMake sure you have Xcode installed and you run \"sudo xcode-select --switch /Applications/Xcode.app/Contents/Developer/\" on a Terminal."
91 |             return
92 |         }
93 |         do {
94 |             let process = try ConverterProcess(
95 |                 outputLocation: outputLocation,
96 |                 modelOrigin: modelOrigin,
97 |                 huggingfaceIdentifier: huggingfaceIdentifier,
98 |                 diffusersLocation: diffusersLocation,
99 |                 checkpointLocation: checkpointLocation,
100 |                 computeUnits: computeUnits,
101 |                 customWidth: customSize && !multisize && computeUnits == .cpuAndGPU ? customWidth : nil,
102 |                 customHeight: customSize && !multisize && computeUnits == .cpuAndGPU ?
customHeight : nil, 103 | multisize: multisize && computeUnits == .cpuAndGPU, 104 | convertUnet: convertUnet, 105 | chunkUnet: chunkUnet, 106 | controlNetSupport: controlNetSupport && !multisize, // TODO: Currently not possible to have both 107 | convertTextEncoder: convertTextEncoder, 108 | embeddingsLocation: loadEmbeddings ? embeddingsLocation : nil, 109 | convertVaeEncoder: convertVaeEncoder, 110 | convertVaeDecoder: convertVaeDecoder, 111 | convertSafetyChecker: convertSafetyChecker, 112 | precisionFull: precisionFull, 113 | loRAsToMerge: loRAsToMerge, 114 | compression: compression 115 | ) 116 | process.objectWillChange 117 | .receive(on: DispatchQueue.main) 118 | .sink { _ in 119 | self.objectWillChange.send() 120 | }.store(in: &cancellables) 121 | NotificationCenter.default.publisher(for: Process.didTerminateNotification, object: process.process) 122 | .receive(on: DispatchQueue.main) 123 | .sink { _ in 124 | if process.didComplete { 125 | self.showSuccess = true 126 | } 127 | self.cancel() 128 | }.store(in: &cancellables) 129 | try process.start() 130 | withAnimation { 131 | self.process = process 132 | } 133 | } catch ConverterProcess.ArgumentError.noOutputLocation { 134 | showOutputPicker = true 135 | } catch ConverterProcess.ArgumentError.noHuggingfaceIdentifier { 136 | self.error = "Enter a valid identifier" 137 | } catch ConverterProcess.ArgumentError.noDiffusersLocation { 138 | self.error = "Enter a valid location" 139 | } catch ConverterProcess.ArgumentError.noCheckpointLocation { 140 | self.error = "Enter a valid location" 141 | } catch { 142 | print(error.localizedDescription) 143 | self.error = error.localizedDescription 144 | withAnimation { 145 | cancel() 146 | } 147 | } 148 | } 149 | 150 | func cancel() { 151 | withAnimation { 152 | process?.cancel() 153 | process = nil 154 | } 155 | } 156 | 157 | func selectAll() { 158 | convertUnet = true 159 | convertTextEncoder = true 160 | convertVaeEncoder = true 161 | convertVaeDecoder = true 162 | convertSafetyChecker = true 163 | } 164 | 165 | func selectNone() { 166 | convertUnet = false 167 | convertTextEncoder = false 168 | convertVaeEncoder = false 169 | convertVaeDecoder = false 170 | convertSafetyChecker = false 171 | } 172 | 173 | func checkCoreMLCompiler() -> Bool { 174 | let process = Process() 175 | process.executableURL = URL(string: "file:///usr/bin/xcrun") 176 | process.arguments = ["coremlcompiler", "version"] 177 | do { 178 | try process.run() 179 | } catch { 180 | Logger.shared.append("coremlcompiler not found") 181 | return false 182 | } 183 | return true 184 | } 185 | } 186 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/models/attention.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | logger.setLevel(logging.INFO) 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | from enum import Enum 10 | 11 | 12 | # Ensure minimum macOS version requirement is met for this particular model 13 | from coremltools.models.utils import _macos_version 14 | if not _macos_version() >= (13, 1): 15 | logger.warning( 16 | "!!! macOS 13.1 and newer or iOS/iPadOS 16.2 and newer is required for best performance !!!" 
17 | ) 18 | 19 | 20 | class AttentionImplementations(Enum): 21 | ORIGINAL = "ORIGINAL" 22 | SPLIT_EINSUM = "SPLIT_EINSUM" 23 | SPLIT_EINSUM_V2 = "SPLIT_EINSUM_V2" 24 | 25 | 26 | ATTENTION_IMPLEMENTATION_IN_EFFECT = AttentionImplementations.SPLIT_EINSUM_V2 27 | 28 | 29 | WARN_MSG = \ 30 | "This `nn.Module` is intended for Apple Silicon deployment only. " \ 31 | "PyTorch-specific optimizations and training is disabled" 32 | 33 | class Attention(nn.Module): 34 | """ Apple Silicon friendly version of `diffusers.models.attention.Attention` 35 | """ 36 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, bias=False): 37 | super().__init__() 38 | inner_dim = dim_head * heads 39 | context_dim = context_dim if context_dim is not None else query_dim 40 | 41 | self.scale = dim_head**-0.5 42 | self.heads = heads 43 | self.dim_head = dim_head 44 | 45 | self.to_q = nn.Conv2d(query_dim, inner_dim, kernel_size=1, bias=bias) 46 | self.to_k = nn.Conv2d(context_dim, inner_dim, kernel_size=1, bias=bias) 47 | self.to_v = nn.Conv2d(context_dim, inner_dim, kernel_size=1, bias=bias) 48 | if dropout > 0: 49 | self.to_out = nn.Sequential( 50 | nn.Conv2d(inner_dim, query_dim, kernel_size=1, bias=True), nn.Dropout(dropout) 51 | ) 52 | else: 53 | self.to_out = nn.Sequential( 54 | nn.Conv2d(inner_dim, query_dim, kernel_size=1, bias=True) 55 | ) 56 | 57 | 58 | def forward(self, hidden_states, encoder_hidden_states=None, mask=None): 59 | if self.training: 60 | raise NotImplementedError(WARN_MSG) 61 | 62 | batch_size, dim, _, sequence_length = hidden_states.shape 63 | 64 | q = self.to_q(hidden_states) 65 | encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states 66 | k = self.to_k(encoder_hidden_states) 67 | v = self.to_v(encoder_hidden_states) 68 | 69 | # Validate mask 70 | if mask is not None: 71 | expected_mask_shape = [batch_size, sequence_length, 1, 1] 72 | if mask.dtype == torch.bool: 73 | mask = mask.logical_not().float() * -1e4 74 | elif mask.dtype == torch.int64: 75 | mask = (1 - mask).float() * -1e4 76 | elif mask.dtype != torch.float32: 77 | raise TypeError(f"Unexpected dtype for mask: {mask.dtype}") 78 | 79 | if len(mask.size()) == 2: 80 | mask = mask.unsqueeze(2).unsqueeze(2) 81 | 82 | if list(mask.size()) != expected_mask_shape: 83 | raise RuntimeError( 84 | f"Invalid shape for `mask` (Expected {expected_mask_shape}, got {list(mask.size())}" 85 | ) 86 | 87 | if ATTENTION_IMPLEMENTATION_IN_EFFECT == AttentionImplementations.ORIGINAL: 88 | attn = original(q, k, v, mask, self.heads, self.dim_head) 89 | elif ATTENTION_IMPLEMENTATION_IN_EFFECT == AttentionImplementations.SPLIT_EINSUM: 90 | attn = split_einsum(q, k, v, mask, self.heads, self.dim_head) 91 | elif ATTENTION_IMPLEMENTATION_IN_EFFECT == AttentionImplementations.SPLIT_EINSUM_V2: 92 | attn = split_einsum_v2(q, k, v, mask, self.heads, self.dim_head) 93 | else: 94 | raise ValueError(ATTENTION_IMPLEMENTATION_IN_EFFECT) 95 | 96 | return self.to_out(attn) 97 | 98 | def split_einsum(q, k, v, mask, heads, dim_head): 99 | """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM 100 | 101 | - Implements https://machinelearning.apple.com/research/neural-engine-transformers 102 | - Recommended for ANE 103 | - Marginally slower on GPU 104 | """ 105 | mh_q = [ 106 | q[:, head_idx * dim_head:(head_idx + 1) * 107 | dim_head, :, :] for head_idx in range(heads) 108 | ] # (bs, dim_head, 1, max_seq_length) * heads 109 | 110 | k = k.transpose(1, 3) 111 | mh_k = [ 112 | k[:, :, :, 
113 | head_idx * dim_head:(head_idx + 1) * dim_head] 114 | for head_idx in range(heads) 115 | ] # (bs, max_seq_length, 1, dim_head) * heads 116 | 117 | mh_v = [ 118 | v[:, head_idx * dim_head:(head_idx + 1) * 119 | dim_head, :, :] for head_idx in range(heads) 120 | ] # (bs, dim_head, 1, max_seq_length) * heads 121 | 122 | attn_weights = [ 123 | torch.einsum("bchq,bkhc->bkhq", [qi, ki]) * (dim_head**-0.5) 124 | for qi, ki in zip(mh_q, mh_k) 125 | ] # (bs, max_seq_length, 1, max_seq_length) * heads 126 | 127 | if mask is not None: 128 | for head_idx in range(heads): 129 | attn_weights[head_idx] = attn_weights[head_idx] + mask 130 | 131 | attn_weights = [ 132 | aw.softmax(dim=1) for aw in attn_weights 133 | ] # (bs, max_seq_length, 1, max_seq_length) * heads 134 | attn = [ 135 | torch.einsum("bkhq,bchk->bchq", wi, vi) 136 | for wi, vi in zip(attn_weights, mh_v) 137 | ] # (bs, dim_head, 1, max_seq_length) * heads 138 | 139 | attn = torch.cat(attn, dim=1) # (bs, dim, 1, max_seq_length) 140 | return attn 141 | 142 | 143 | CHUNK_SIZE = 512 144 | 145 | def split_einsum_v2(q, k, v, mask, heads, dim_head): 146 | """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM_V2 147 | 148 | - Implements https://machinelearning.apple.com/research/neural-engine-transformers 149 | - Recommended for ANE 150 | - Marginally slower on GPU 151 | - Chunks the query sequence to avoid large intermediate tensors and improves ANE performance 152 | """ 153 | query_seq_length = q.size(3) 154 | num_chunks = query_seq_length // CHUNK_SIZE 155 | 156 | if num_chunks == 0: 157 | logger.info( 158 | "AttentionImplementations.SPLIT_EINSUM_V2: query sequence too short to chunk " 159 | f"({query_seq_length}<{CHUNK_SIZE}), fall back to AttentionImplementations.SPLIT_EINSUM (safe to ignore)") 160 | return split_einsum(q, k, v, mask, heads, dim_head) 161 | 162 | logger.info( 163 | "AttentionImplementations.SPLIT_EINSUM_V2: Splitting query sequence length of " 164 | f"{query_seq_length} into {num_chunks} chunks") 165 | 166 | mh_q = [ 167 | q[:, head_idx * dim_head:(head_idx + 1) * 168 | dim_head, :, :] for head_idx in range(heads) 169 | ] # (bs, dim_head, 1, max_seq_length) * heads 170 | 171 | # Chunk the query sequence for each head 172 | mh_q_chunked = [ 173 | [h_q[..., chunk_idx * CHUNK_SIZE:(chunk_idx + 1) * CHUNK_SIZE] for chunk_idx in range(num_chunks)] 174 | for h_q in mh_q 175 | ] # ((bs, dim_head, 1, QUERY_SEQ_CHUNK_SIZE) * num_chunks) * heads 176 | 177 | k = k.transpose(1, 3) 178 | mh_k = [ 179 | k[:, :, :, 180 | head_idx * dim_head:(head_idx + 1) * dim_head] 181 | for head_idx in range(heads) 182 | ] # (bs, max_seq_length, 1, dim_head) * heads 183 | 184 | mh_v = [ 185 | v[:, head_idx * dim_head:(head_idx + 1) * 186 | dim_head, :, :] for head_idx in range(heads) 187 | ] # (bs, dim_head, 1, max_seq_length) * heads 188 | 189 | attn_weights = [ 190 | [ 191 | torch.einsum("bchq,bkhc->bkhq", [qi_chunk, ki]) * (dim_head**-0.5) 192 | for qi_chunk in h_q_chunked 193 | ] for h_q_chunked, ki in zip(mh_q_chunked, mh_k) 194 | ] # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads 195 | 196 | attn_weights = [ 197 | [aw_chunk.softmax(dim=1) for aw_chunk in aw_chunked] 198 | for aw_chunked in attn_weights 199 | ] # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads 200 | 201 | attn = [ 202 | [ 203 | torch.einsum("bkhq,bchk->bchq", wi_chunk, vi) 204 | for wi_chunk in wi_chunked 205 | ] for wi_chunked, vi in zip(attn_weights, mh_v) 206 | ] # ((bs, dim_head, 1, chunk_size) * num_chunks) * heads 207 | 208 | 
attn = torch.cat([ 209 | torch.cat(attn_chunked, dim=3) for attn_chunked in attn 210 | ], dim=1) # (bs, dim, 1, max_seq_length) 211 | 212 | return attn 213 | 214 | 215 | def original(q, k, v, mask, heads, dim_head): 216 | """ Attention Implementation backing AttentionImplementations.ORIGINAL 217 | 218 | - Not recommended for ANE 219 | - Recommended for GPU 220 | """ 221 | bs = q.size(0) 222 | mh_q = q.view(bs, heads, dim_head, -1) 223 | mh_k = k.view(bs, heads, dim_head, -1) 224 | mh_v = v.view(bs, heads, dim_head, -1) 225 | 226 | attn_weights = torch.einsum("bhcq,bhck->bhqk", [mh_q, mh_k]) 227 | attn_weights.mul_(dim_head**-0.5) 228 | 229 | if mask is not None: 230 | attn_weights = attn_weights + mask 231 | 232 | attn_weights = attn_weights.softmax(dim=3) 233 | 234 | attn = torch.einsum("bhqk,bhck->bhcq", [attn_weights, mh_v]) 235 | attn = attn.contiguous().view(bs, heads * dim_head, 1, -1) 236 | return attn 237 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/convert/convert_safety_checker.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from guernikatools._version import __version__ 7 | from guernikatools.utils import utils 8 | 9 | from collections import OrderedDict, defaultdict 10 | from copy import deepcopy 11 | import coremltools as ct 12 | import gc 13 | 14 | import logging 15 | 16 | logging.basicConfig() 17 | logger = logging.getLogger(__name__) 18 | logger.setLevel(logging.INFO) 19 | 20 | import numpy as np 21 | import os 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | torch.set_grad_enabled(False) 28 | 29 | from types import MethodType 30 | 31 | 32 | def main(pipe, args): 33 | """ Converts the Safety Checker component of Stable Diffusion 34 | """ 35 | if pipe.safety_checker is None: 36 | logger.warning( 37 | f"diffusers pipeline for {args.model_version} does not have a `safety_checker` module! " \ 38 | "`--convert-safety-checker` will be ignored." 39 | ) 40 | return 41 | 42 | out_path = utils.get_out_path(args, "safety_checker") 43 | if os.path.exists(out_path): 44 | logger.info(f"`safety_checker` already exists at {out_path}, skipping conversion.") 45 | return 46 | 47 | sample_image = np.random.randn( 48 | 1, # B 49 | args.output_h, # H 50 | args.output_w, # w 51 | 3 # C 52 | ).astype(np.float32) 53 | 54 | # Note that pipe.feature_extractor is not an ML model. It simply 55 | # preprocesses data for the pipe.safety_checker module. 
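Concretely, that preprocessing is the standard CLIP image transform: resize to 224x224, scale to [0, 1], normalize per channel, and reorder to channels-first. A hedged sketch — the normalization constants are the published CLIP values, but the pipeline's own `feature_extractor` config is authoritative:

```python
# Illustrative only: what a CLIP feature extractor typically does to one image
# before the safety checker sees it (exact constants come from the pipeline's
# feature_extractor config, not from this sketch).
import numpy as np
from PIL import Image

def preprocess_like_clip(pil_image, size=224,
                         mean=(0.48145466, 0.4578275, 0.40821073),
                         std=(0.26862954, 0.26130258, 0.27577711)):
    image = pil_image.resize((size, size), Image.BICUBIC)
    x = np.asarray(image).astype(np.float32) / 255.0  # HWC in [0, 1]
    x = (x - np.array(mean)) / np.array(std)          # per-channel normalization
    return x.transpose(2, 0, 1)[None]                 # HWC -> BCHW
```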
56 | safety_checker_input = pipe.feature_extractor( 57 | pipe.numpy_to_pil(sample_image), 58 | return_tensors="pt", 59 | ).pixel_values.to(torch.float32) 60 | 61 | sample_safety_checker_inputs = OrderedDict([ 62 | ("clip_input", safety_checker_input), 63 | ("images", torch.from_numpy(sample_image)), 64 | ("adjustment", torch.tensor([0]).to(torch.float32)), 65 | ]) 66 | 67 | sample_safety_checker_inputs_spec = { 68 | k: (v.shape, v.dtype) 69 | for k, v in sample_safety_checker_inputs.items() 70 | } 71 | logger.info(f"Sample inputs spec: {sample_safety_checker_inputs_spec}") 72 | 73 | # Patch safety_checker's forward pass to be vectorized and avoid conditional blocks 74 | # (similar to pipe.safety_checker.forward_onnx) 75 | from diffusers.pipelines.stable_diffusion import safety_checker 76 | 77 | def forward_coreml(self, clip_input, images, adjustment): 78 | """ Forward pass implementation for safety_checker 79 | """ 80 | 81 | def cosine_distance(image_embeds, text_embeds): 82 | return F.normalize(image_embeds) @ F.normalize(text_embeds).transpose(0, 1) 83 | 84 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 85 | image_embeds = self.visual_projection(pooled_output) 86 | 87 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) 88 | cos_dist = cosine_distance(image_embeds, self.concept_embeds) 89 | 90 | special_scores = special_cos_dist - self.special_care_embeds_weights + adjustment 91 | special_care = special_scores.gt(0).float().sum(dim=1).gt(0).float() 92 | special_adjustment = special_care * 0.01 93 | special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) 94 | 95 | concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment 96 | has_nsfw_concepts = concept_scores.gt(0).float().sum(dim=1).gt(0) 97 | 98 | # There is a problem when converting using multisize, for now the workaround is to not filter the images 99 | # The swift implementations already filters the images checking `has_nsfw_concepts` so this should not have any impact 100 | 101 | #has_nsfw_concepts = concept_scores.gt(0).float().sum(dim=1).gt(0)[:, None, None, None] 102 | 103 | #has_nsfw_concepts_inds, _ = torch.broadcast_tensors(has_nsfw_concepts, images) 104 | #images[has_nsfw_concepts_inds] = 0.0 # black image 105 | 106 | return images, has_nsfw_concepts.float(), concept_scores 107 | 108 | baseline_safety_checker = deepcopy(pipe.safety_checker.eval()) 109 | setattr(baseline_safety_checker, "forward", MethodType(forward_coreml, baseline_safety_checker)) 110 | 111 | # In order to parity check the actual signal, we need to override the forward pass to return `concept_scores` which is the 112 | # output before thresholding 113 | # Reference: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py#L100 114 | def forward_extended_return(self, clip_input, images, adjustment): 115 | 116 | def cosine_distance(image_embeds, text_embeds): 117 | normalized_image_embeds = F.normalize(image_embeds) 118 | normalized_text_embeds = F.normalize(text_embeds) 119 | return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) 120 | 121 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 122 | image_embeds = self.visual_projection(pooled_output) 123 | 124 | special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) 125 | cos_dist = cosine_distance(image_embeds, self.concept_embeds) 126 | 127 | adjustment = 0.0 128 | 129 | special_scores = special_cos_dist - 
self.special_care_embeds_weights + adjustment 130 | special_care = torch.any(special_scores > 0, dim=1) 131 | special_adjustment = special_care * 0.01 132 | special_adjustment = special_adjustment.unsqueeze(1).expand(-1, cos_dist.shape[1]) 133 | 134 | concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment 135 | has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) 136 | 137 | # Don't make the images black as to align with the workaround in `forward_coreml` 138 | #images[has_nsfw_concepts] = 0.0 139 | 140 | return images, has_nsfw_concepts, concept_scores 141 | 142 | setattr(pipe.safety_checker, "forward", MethodType(forward_extended_return, pipe.safety_checker)) 143 | 144 | # Trace the safety_checker model 145 | logger.info("JIT tracing..") 146 | traced_safety_checker = torch.jit.trace(baseline_safety_checker, list(sample_safety_checker_inputs.values())) 147 | logger.info("Done.") 148 | del baseline_safety_checker 149 | gc.collect() 150 | 151 | # Cast all inputs to float16 152 | coreml_sample_safety_checker_inputs = { 153 | k: v.numpy().astype(np.float16) 154 | for k, v in sample_safety_checker_inputs.items() 155 | } 156 | 157 | # Convert safety_checker model to Core ML 158 | if args.multisize: 159 | clip_size = safety_checker_input.shape[2] 160 | clip_input_shape = ct.Shape(shape=( 161 | 1, 162 | 3, 163 | ct.RangeDim(int(clip_size * 0.5), upper_bound=int(clip_size * 2), default=clip_size), 164 | ct.RangeDim(int(clip_size * 0.5), upper_bound=int(clip_size * 2), default=clip_size) 165 | )) 166 | sample_size = args.output_h 167 | input_shape = ct.Shape(shape=( 168 | 1, 169 | ct.RangeDim(int(sample_size * 0.5), upper_bound=int(sample_size * 2), default=sample_size), 170 | ct.RangeDim(int(sample_size * 0.5), upper_bound=int(sample_size * 2), default=sample_size), 171 | 3 172 | )) 173 | 174 | sample_coreml_inputs = utils.get_coreml_inputs(coreml_sample_safety_checker_inputs, { 175 | "clip_input": clip_input_shape, "images": input_shape 176 | }) 177 | else: 178 | sample_coreml_inputs = utils.get_coreml_inputs(coreml_sample_safety_checker_inputs) 179 | coreml_safety_checker, out_path = utils.convert_to_coreml( 180 | "safety_checker", traced_safety_checker, 181 | sample_coreml_inputs, 182 | ["filtered_images", "has_nsfw_concepts", "concept_scores"], False, args 183 | ) 184 | 185 | # Set model metadata 186 | coreml_safety_checker.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}" 187 | coreml_safety_checker.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)" 188 | coreml_safety_checker.version = args.model_version 189 | coreml_safety_checker.short_description = \ 190 | "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \ 191 | "Please refer to https://arxiv.org/abs/2112.10752 for details." 
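A note on the `ct.RangeDim` bounds used above: the converted package becomes a flexible-shape model that accepts spatial sizes between 0.5x and 2x the defaults. A hedged sketch of exercising such a model after conversion — the path, input names, and sizes here are placeholders:

```python
# Illustrative only: querying a flexible-shape safety checker at two sizes.
import numpy as np
import coremltools as ct

model = ct.models.MLModel("example_safety_checker.mlpackage")  # placeholder path
for h, w in [(512, 512), (768, 512)]:
    out = model.predict({
        "clip_input": np.zeros((1, 3, 224, 224), dtype=np.float32),
        "images": np.zeros((1, h, w, 3), dtype=np.float32),
        "adjustment": np.zeros((1,), dtype=np.float32),
    })
    print(h, w, out["concept_scores"].shape)
```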
192 | 
193 |     # Set the input descriptions
194 |     coreml_safety_checker.input_description["clip_input"] = \
195 |         "The normalized image input tensor resized to (224x224) in channels-first (BCHW) format"
196 |     coreml_safety_checker.input_description["images"] = \
197 |         f"Output of the vae_decoder ({pipe.vae.config.sample_size}x{pipe.vae.config.sample_size}) in channels-last (BHWC) format"
198 |     coreml_safety_checker.input_description["adjustment"] = \
199 |         "Bias added to the concept scores to trade off increased recall for reduced precision in the safety checker classifier"
200 | 
201 |     # Set the output descriptions
202 |     coreml_safety_checker.output_description["filtered_images"] = \
203 |         "Identical to the input `images`. If safety checker detected any sensitive content, " \
204 |         "the corresponding image is replaced with a blank image (zeros)"
205 |     coreml_safety_checker.output_description["has_nsfw_concepts"] = \
206 |         "Indicates whether the safety checker model found any sensitive content in the given image"
207 |     coreml_safety_checker.output_description["concept_scores"] = \
208 |         "Concept scores are the raw scores; thresholding them at zero yields the `has_nsfw_concepts` output. " \
209 |         "These scores can be used to tune the `adjustment` input"
210 | 
211 |     # Set package version metadata
212 |     coreml_safety_checker.user_defined_metadata["identifier"] = args.model_version
213 |     coreml_safety_checker.user_defined_metadata["converter_version"] = __version__
214 |     coreml_safety_checker.user_defined_metadata["attention_implementation"] = args.attention_implementation
215 |     coreml_safety_checker.user_defined_metadata["compute_unit"] = args.compute_unit
216 | 
217 |     coreml_safety_checker.save(out_path)
218 | 
219 |     if args.check_output_correctness:
220 |         baseline_out = pipe.safety_checker(**sample_safety_checker_inputs)[2].numpy()
221 |         coreml_out = coreml_safety_checker.predict(coreml_sample_safety_checker_inputs)["concept_scores"]
222 |         utils.report_correctness(baseline_out, coreml_out, "safety_checker baseline PyTorch to reference CoreML")
223 | 
224 |     del traced_safety_checker, coreml_safety_checker, pipe.safety_checker
225 |     gc.collect()
226 | 
--------------------------------------------------------------------------------
/GuernikaTools/guernikatools/utils/merge_lora.py:
--------------------------------------------------------------------------------
1 | # Taken from: https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
2 | 
3 | import math
4 | import argparse
5 | import os
6 | import torch
7 | from safetensors.torch import load_file, save_file
8 | 
9 | # is it possible to apply conv_in and conv_out?
-> yes, newer LoCon supports it (^^;) 10 | UNET_TARGET_REPLACE_MODULE = ["Transformer2DModel", "Attention"] 11 | UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] 12 | TEXT_ENCODER_TARGET_REPLACE_MODULE = ["CLIPAttention", "CLIPMLP"] 13 | LORA_PREFIX_UNET = "lora_unet" 14 | LORA_PREFIX_TEXT_ENCODER = "lora_te" 15 | 16 | # SDXL: must starts with LORA_PREFIX_TEXT_ENCODER 17 | LORA_PREFIX_TEXT_ENCODER_1 = "lora_te1" 18 | LORA_PREFIX_TEXT_ENCODER_2 = "lora_te2" 19 | 20 | def load_state_dict(file_name, dtype): 21 | if os.path.splitext(file_name)[1] == ".safetensors": 22 | sd = load_file(file_name) 23 | else: 24 | sd = torch.load(file_name, map_location="cpu") 25 | for key in list(sd.keys()): 26 | if type(sd[key]) == torch.Tensor: 27 | sd[key] = sd[key].to(dtype) 28 | return sd 29 | 30 | 31 | def save_to_file(file_name, model, state_dict, dtype): 32 | if dtype is not None: 33 | for key in list(state_dict.keys()): 34 | if type(state_dict[key]) == torch.Tensor: 35 | state_dict[key] = state_dict[key].to(dtype) 36 | 37 | if os.path.splitext(file_name)[1] == ".safetensors": 38 | save_file(model, file_name) 39 | else: 40 | torch.save(model, file_name) 41 | 42 | def merge_to_sd_model(unet, text_encoder, text_encoder_2, models, ratios, merge_dtype=torch.float32): 43 | unet.to(merge_dtype) 44 | text_encoder.to(merge_dtype) 45 | if text_encoder_2: 46 | text_encoder_2.to(merge_dtype) 47 | 48 | layers_per_block = unet.config.layers_per_block 49 | 50 | # create module map 51 | name_to_module = {} 52 | for i, root_module in enumerate([unet, text_encoder, text_encoder_2]): 53 | if not root_module: 54 | continue 55 | if i == 0: 56 | prefix = LORA_PREFIX_UNET 57 | target_replace_modules = ( 58 | UNET_TARGET_REPLACE_MODULE + UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 59 | ) 60 | elif text_encoder_2: 61 | target_replace_modules = TEXT_ENCODER_TARGET_REPLACE_MODULE 62 | if i == 1: 63 | prefix = LORA_PREFIX_TEXT_ENCODER_1 64 | else: 65 | prefix = LORA_PREFIX_TEXT_ENCODER_2 66 | else: 67 | prefix = LORA_PREFIX_TEXT_ENCODER 68 | target_replace_modules = TEXT_ENCODER_TARGET_REPLACE_MODULE 69 | 70 | for name, module in root_module.named_modules(): 71 | if module.__class__.__name__ == "LoRACompatibleLinear" or module.__class__.__name__ == "LoRACompatibleConv": 72 | lora_name = prefix + "." + name 73 | lora_name = lora_name.replace(".", "_") 74 | name_to_module[lora_name] = module 75 | elif module.__class__.__name__ in target_replace_modules: 76 | for child_name, child_module in module.named_modules(): 77 | if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d": 78 | lora_name = prefix + "." + name + "." 
+ child_name 79 | lora_name = lora_name.replace(".", "_") 80 | name_to_module[lora_name] = child_module 81 | 82 | for model, ratio in zip(models, ratios): 83 | print(f"Merging: {model}") 84 | lora_sd = load_state_dict(model, merge_dtype) 85 | 86 | for key in lora_sd.keys(): 87 | if "lora_down" in key: 88 | up_key = key.replace("lora_down", "lora_up") 89 | alpha_key = key[: key.index("lora_down")] + "alpha" 90 | 91 | # find original module for this lora 92 | module_name = ".".join(key.split(".")[:-2]) # remove trailing ".lora_down.weight" 93 | if "input_blocks" in module_name: 94 | i = int(module_name.split("input_blocks_", 1)[1].split("_", 1)[0]) 95 | block_id = (i - 1) // (layers_per_block + 1) 96 | layer_in_block_id = (i - 1) % (layers_per_block + 1) 97 | module_name = module_name.replace(f"input_blocks_{i}_0", f"down_blocks_{block_id}_resnets_{layer_in_block_id}") 98 | module_name = module_name.replace(f"input_blocks_{i}_1", f"down_blocks_{block_id}_attentions_{layer_in_block_id}") 99 | module_name = module_name.replace(f"input_blocks_{i}_2", f"down_blocks_{block_id}_resnets_{layer_in_block_id}") 100 | if "middle_block" in module_name: 101 | module_name = module_name.replace("middle_block_0", "mid_block_resnets_0") 102 | module_name = module_name.replace("middle_block_1", "mid_block_attentions_0") 103 | module_name = module_name.replace("middle_block_2", "mid_block_resnets_1") 104 | if "output_blocks" in module_name: 105 | i = int(module_name.split("output_blocks_", 1)[1].split("_", 1)[0]) 106 | block_id = i // (layers_per_block + 1) 107 | layer_in_block_id = i % (layers_per_block + 1) 108 | module_name = module_name.replace(f"output_blocks_{i}_0", f"up_blocks_{block_id}_resnets_{layer_in_block_id}") 109 | module_name = module_name.replace(f"output_blocks_{i}_1", f"up_blocks_{block_id}_attentions_{layer_in_block_id}") 110 | module_name = module_name.replace(f"output_blocks_{i}_2", f"up_blocks_{block_id}_resnets_{layer_in_block_id}") 111 | 112 | module_name = module_name.replace("in_layers_0", "norm1") 113 | module_name = module_name.replace("in_layers_2", "conv1") 114 | 115 | module_name = module_name.replace("out_layers_0", "norm2") 116 | module_name = module_name.replace("out_layers_3", "conv2") 117 | 118 | module_name = module_name.replace("emb_layers_1", "time_emb_proj") 119 | module_name = module_name.replace("skip_connection", "conv_shortcut") 120 | 121 | if module_name not in name_to_module: 122 | print(f"no module found for LoRA weight: {key}") 123 | continue 124 | module = name_to_module[module_name] 125 | # print(f"apply {key} to {module}") 126 | 127 | down_weight = lora_sd[key] 128 | up_weight = lora_sd[up_key] 129 | 130 | dim = down_weight.size()[0] 131 | alpha = lora_sd.get(alpha_key, dim) 132 | scale = alpha / dim 133 | 134 | # W <- W + U * D 135 | weight = module.weight 136 | # print(module_name, down_weight.size(), up_weight.size()) 137 | if len(weight.size()) == 2: 138 | # linear 139 | if len(up_weight.size()) == 4: # use linear projection mismatch 140 | up_weight = up_weight.squeeze(3).squeeze(2) 141 | down_weight = down_weight.squeeze(3).squeeze(2) 142 | weight = weight + ratio * (up_weight @ down_weight) * scale 143 | elif down_weight.size()[2:4] == (1, 1): 144 | # conv2d 1x1 145 | weight = ( 146 | weight 147 | + ratio 148 | * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3) 149 | * scale 150 | ) 151 | else: 152 | # conv2d 3x3 153 | conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), 
up_weight).permute(1, 0, 2, 3)
154 |                     # print(conved.size(), weight.size(), module.stride, module.padding)
155 |                     weight = weight + ratio * conved * scale
156 | 
157 |                 module.weight = torch.nn.Parameter(weight)
158 | 
159 | 
160 | def merge_lora_models(models, ratios, merge_dtype):
161 |     base_alphas = {}  # alpha for merged model
162 |     base_dims = {}
163 | 
164 |     merged_sd = {}
165 |     for model, ratio in zip(models, ratios):
166 |         print(f"loading: {model}")
167 |         lora_sd = load_state_dict(model, merge_dtype)
168 | 
169 |         # get alpha and dim
170 |         alphas = {}  # alpha for current model
171 |         dims = {}  # dims for current model
172 |         for key in lora_sd.keys():
173 |             if "alpha" in key:
174 |                 lora_module_name = key[: key.rfind(".alpha")]
175 |                 alpha = float(lora_sd[key].detach().numpy())
176 |                 alphas[lora_module_name] = alpha
177 |                 if lora_module_name not in base_alphas:
178 |                     base_alphas[lora_module_name] = alpha
179 |             elif "lora_down" in key:
180 |                 lora_module_name = key[: key.rfind(".lora_down")]
181 |                 dim = lora_sd[key].size()[0]
182 |                 dims[lora_module_name] = dim
183 |                 if lora_module_name not in base_dims:
184 |                     base_dims[lora_module_name] = dim
185 | 
186 |         for lora_module_name in dims.keys():
187 |             if lora_module_name not in alphas:
188 |                 alpha = dims[lora_module_name]
189 |                 alphas[lora_module_name] = alpha
190 |                 if lora_module_name not in base_alphas:
191 |                     base_alphas[lora_module_name] = alpha
192 | 
193 |         print(f"dim: {list(set(dims.values()))}, alpha: {list(set(alphas.values()))}")
194 | 
195 |         # merge
196 |         print("merging...")
197 |         for key in lora_sd.keys():
198 |             if "alpha" in key:
199 |                 continue
200 | 
201 |             lora_module_name = key[: key.rfind(".lora_")]
202 | 
203 |             base_alpha = base_alphas[lora_module_name]
204 |             alpha = alphas[lora_module_name]
205 | 
206 |             scale = math.sqrt(alpha / base_alpha) * ratio
207 | 
208 |             if key in merged_sd:
209 |                 assert (
210 |                     merged_sd[key].size() == lora_sd[key].size()
211 |                 ), "weights shape mismatch merging v1 and v2, different dims? (weight sizes do not match; models with differing dimensions cannot be merged)"
212 |                 merged_sd[key] = merged_sd[key] + lora_sd[key] * scale
213 |             else:
214 |                 merged_sd[key] = lora_sd[key] * scale
215 | 
216 |     # set alpha to sd
217 |     for lora_module_name, alpha in base_alphas.items():
218 |         key = lora_module_name + ".alpha"
219 |         merged_sd[key] = torch.tensor(alpha)
220 | 
221 |     print("merged model")
222 |     print(f"dim: {list(set(base_dims.values()))}, alpha: {list(set(base_alphas.values()))}")
223 | 
224 |     return merged_sd
225 | 
--------------------------------------------------------------------------------
/GuernikaTools/guernikatools/utils/utils.py:
--------------------------------------------------------------------------------
1 | #
2 | # For licensing see accompanying LICENSE.md file.
3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved.
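Before the shared utilities below, a usage note for `merge_to_sd_model` above — a hedged sketch in which the base checkpoint, LoRA file, and ratio are all placeholders:

```python
# Hypothetical: bake a LoRA into a loaded pipeline before Core ML conversion.
import torch
from diffusers import StableDiffusionPipeline
from guernikatools.utils.merge_lora import merge_to_sd_model

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
merge_to_sd_model(
    pipe.unet,
    pipe.text_encoder,
    None,                                  # no second text encoder (non-SDXL)
    models=["my_style_lora.safetensors"],  # placeholder LoRA path
    ratios=[0.8],                          # merge strength
    merge_dtype=torch.float32,
)
```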
4 | # 5 | 6 | from guernikatools._version import __version__ 7 | from .merge_lora import merge_to_sd_model 8 | 9 | import json 10 | import argparse 11 | from collections import OrderedDict, defaultdict 12 | from copy import deepcopy 13 | import coremltools as ct 14 | from diffusers import DiffusionPipeline, StableDiffusionPipeline, StableDiffusionXLPipeline, AutoPipelineForInpainting 15 | from diffusers import T2IAdapter, StableDiffusionAdapterPipeline 16 | from diffusers import ControlNetModel, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline, AutoencoderKL 17 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt 18 | import gc 19 | 20 | import logging 21 | 22 | logging.basicConfig() 23 | logger = logging.getLogger(__name__) 24 | logger.setLevel(logging.INFO) 25 | 26 | import numpy as np 27 | import os 28 | from os import listdir 29 | from os.path import isfile, join 30 | import tempfile 31 | import requests 32 | import shutil 33 | import time 34 | import traceback 35 | 36 | import torch 37 | import torch.nn as nn 38 | import torch.nn.functional as F 39 | 40 | torch.set_grad_enabled(False) 41 | 42 | 43 | def conditioning_method_from(identifier): 44 | if "canny" in identifier: 45 | return "canny" 46 | if "depth" in identifier: 47 | return "depth" 48 | if "pose" in identifier: 49 | return "pose" 50 | if "mlsd" in identifier: 51 | return "mlsd" 52 | if "normal" in identifier: 53 | return "normal" 54 | if "scribble" in identifier: 55 | return "scribble" 56 | if "hed" in identifier: 57 | return "hed" 58 | if "seg" in identifier: 59 | return "segmentation" 60 | return None 61 | 62 | 63 | def get_out_path(args, submodule_name): 64 | fname = f"{args.model_version}_{submodule_name}.mlpackage" 65 | fname = fname.replace("/", "_") 66 | if args.clean_up_mlpackages: 67 | temp_dir = tempfile.gettempdir() 68 | return os.path.join(temp_dir, fname) 69 | return os.path.join(args.o, fname) 70 | 71 | 72 | def get_coreml_inputs(sample_inputs, samples_shapes=None): 73 | return [ 74 | ct.TensorType( 75 | name=k, 76 | shape=samples_shapes[k] if samples_shapes and k in samples_shapes else v.shape, 77 | dtype=v.numpy().dtype if isinstance(v, torch.Tensor) else v.dtype, 78 | ) for k, v in sample_inputs.items() 79 | ] 80 | 81 | 82 | ABSOLUTE_MIN_PSNR = 35 83 | 84 | def report_correctness(original_outputs, final_outputs, log_prefix): 85 | """ Report PSNR values across two compatible tensors 86 | """ 87 | original_psnr = compute_psnr(original_outputs, original_outputs) 88 | final_psnr = compute_psnr(original_outputs, final_outputs) 89 | 90 | dB_change = final_psnr - original_psnr 91 | logger.info( 92 | f"{log_prefix}: PSNR changed by {dB_change:.1f} dB ({original_psnr:.1f} -> {final_psnr:.1f})" 93 | ) 94 | 95 | if final_psnr < ABSOLUTE_MIN_PSNR: 96 | # raise ValueError(f"{final_psnr:.1f} dB is too low!") 97 | logger.info(f"{final_psnr:.1f} dB is too low!") 98 | else: 99 | logger.info( 100 | f"{final_psnr:.1f} dB > {ABSOLUTE_MIN_PSNR} dB (minimum allowed) parity check passed" 101 | ) 102 | return final_psnr 103 | 104 | 105 | def compute_psnr(a, b): 106 | """ Compute Peak-Signal-to-Noise-Ratio across two numpy.ndarray objects 107 | """ 108 | max_b = np.abs(b).max() 109 | sumdeltasq = 0.0 110 | 111 | sumdeltasq = ((a - b) * (a - b)).sum() 112 | 113 | sumdeltasq /= b.size 114 | sumdeltasq = np.sqrt(sumdeltasq) 115 | 116 | eps = 1e-5 117 | eps2 = 1e-10 118 | psnr = 20 * np.log10((max_b + eps) / (sumdeltasq + eps2)) 119 | 120 | return psnr 121 
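To make `compute_psnr` above concrete: ignoring the two epsilons, it returns 20·log10(max|b| / RMSE(a, b)). A small worked example:

```python
# Worked example of the PSNR definition used by report_correctness above.
import numpy as np

a = np.array([1.0, 2.0, 3.0])   # reference outputs
b = np.array([1.1, 2.0, 2.9])   # converted-model outputs
rmse = np.sqrt(((a - b) ** 2).mean())         # ~0.0816
psnr = 20 * np.log10(np.abs(b).max() / rmse)  # ~31.0 dB
print(round(psnr, 1))  # below the 35 dB floor, so the parity check would warn
```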
| 122 | 123 | def convert_to_coreml(submodule_name, torchscript_module, coreml_inputs, output_names, precision_full, args): 124 | out_path = get_out_path(args, submodule_name) 125 | 126 | if os.path.exists(out_path): 127 | logger.info(f"Skipping export because {out_path} already exists") 128 | logger.info(f"Loading model from {out_path}") 129 | 130 | start = time.time() 131 | # Note: Note that each model load will trigger a model compilation which takes up to a few minutes. 132 | # The Swifty CLI we provide uses precompiled Core ML models (.mlmodelc) which incurs compilation only 133 | # upon first load and mitigates the load time in subsequent runs. 134 | coreml_model = ct.models.MLModel(out_path, compute_units=ct.ComputeUnit[args.compute_unit]) 135 | logger.info(f"Loading {out_path} took {time.time() - start:.1f} seconds") 136 | 137 | coreml_model.compute_unit = ct.ComputeUnit[args.compute_unit] 138 | else: 139 | logger.info(f"Converting {submodule_name} to CoreML...") 140 | coreml_model = ct.convert( 141 | torchscript_module, 142 | convert_to="mlprogram", 143 | minimum_deployment_target=ct.target.macOS13, 144 | inputs=coreml_inputs, 145 | outputs=[ct.TensorType(name=name) for name in output_names], 146 | compute_units=ct.ComputeUnit[args.compute_unit], 147 | compute_precision=ct.precision.FLOAT32 if precision_full else ct.precision.FLOAT16, 148 | # skip_model_load=True, 149 | ) 150 | 151 | del torchscript_module 152 | gc.collect() 153 | 154 | return coreml_model, out_path 155 | 156 | 157 | def quantize_weights(models, args): 158 | """ Quantize weights to args.quantize_nbits using a palette (look-up table) 159 | """ 160 | for model_name in models: 161 | logger.info(f"Quantizing {model_name} to {args.quantize_nbits}-bit precision") 162 | out_path = get_out_path(args, model_name) 163 | _quantize_weights( 164 | out_path, 165 | model_name, 166 | args.quantize_nbits 167 | ) 168 | 169 | def _quantize_weights(out_path, model_name, nbits): 170 | if os.path.exists(out_path): 171 | logger.info(f"Quantizing {model_name}") 172 | mlmodel = ct.models.MLModel(out_path, compute_units=ct.ComputeUnit.CPU_ONLY) 173 | 174 | op_config = ct.optimize.coreml.OpPalettizerConfig( 175 | mode="kmeans", 176 | nbits=nbits, 177 | ) 178 | 179 | config = ct.optimize.coreml.OptimizationConfig( 180 | global_config=op_config, 181 | op_type_configs={ 182 | "gather": None # avoid quantizing the embedding table 183 | } 184 | ) 185 | 186 | model = ct.optimize.coreml.palettize_weights(mlmodel, config=config).save(out_path) 187 | logger.info("Done") 188 | else: 189 | logger.info(f"Skipped quantizing {model_name} (Not found at {out_path})") 190 | 191 | 192 | def bundle_resources_for_guernika(pipe, args): 193 | """ 194 | - Compiles Core ML models from mlpackage into mlmodelc format 195 | - Download tokenizer resources for the text encoder 196 | """ 197 | resources_dir = os.path.join(args.o, args.resources_dir_name.replace("/", "_")) 198 | if not os.path.exists(resources_dir): 199 | os.makedirs(resources_dir, exist_ok=True) 200 | logger.info(f"Created {resources_dir} for Guernika assets") 201 | 202 | # Compile model using coremlcompiler (Significantly reduces the load time for unet) 203 | for source_name, target_name in [ 204 | ("text_encoder", "TextEncoder"), 205 | ("text_encoder_2", "TextEncoder2"), 206 | ("text_encoder_prior", "TextEncoderPrior"), 207 | ("vae_encoder", "VAEEncoder"), 208 | ("vae_decoder", "VAEDecoder"), 209 | ("t2i_adapter", "T2IAdapter"), 210 | ("controlnet", "ControlNet"), 211 | ("unet", "Unet"), 212 | 
("unet_chunk1", "UnetChunk1"), 213 | ("unet_chunk2", "UnetChunk2"), 214 | ("safety_checker", "SafetyChecker"), 215 | ("wuerstchen_prior", "WuerstchenPrior"), 216 | ("wuerstchen_decoder", "WuerstchenDecoder"), 217 | ("wuerstchen_vqgan", "WuerstchenVQGAN") 218 | 219 | ]: 220 | source_path = get_out_path(args, source_name) 221 | if os.path.exists(source_path): 222 | target_path = _compile_coreml_model(source_path, resources_dir, target_name) 223 | logger.info(f"Compiled {source_path} to {target_path}") 224 | if source_name.startswith("text_encoder"): 225 | # Fetch and save vocabulary JSON file for text tokenizer 226 | logger.info("Downloading and saving tokenizer vocab.json") 227 | with open(os.path.join(target_path, "vocab.json"), "wb") as f: 228 | f.write(requests.get(args.text_encoder_vocabulary_url).content) 229 | logger.info("Done") 230 | 231 | # Fetch and save merged pairs JSON file for text tokenizer 232 | logger.info("Downloading and saving tokenizer merges.txt") 233 | with open(os.path.join(target_path, "merges.txt"), "wb") as f: 234 | f.write(requests.get(args.text_encoder_merges_url).content) 235 | logger.info("Done") 236 | 237 | if hasattr(args, "added_vocab") and args.added_vocab: 238 | logger.info("Saving added vocab") 239 | with open(os.path.join(target_path, "added_vocab.json"), 'w', encoding='utf-8') as f: 240 | json.dump(args.added_vocab, f, ensure_ascii=False, indent=4) 241 | else: 242 | logger.warning( 243 | f"{source_path} not found, skipping compilation to {target_name}.mlmodelc" 244 | ) 245 | 246 | return resources_dir 247 | 248 | 249 | def _compile_coreml_model(source_model_path, output_dir, final_name): 250 | """ Compiles Core ML models using the coremlcompiler utility from Xcode toolchain 251 | """ 252 | target_path = os.path.join(output_dir, f"{final_name}.mlmodelc") 253 | if os.path.exists(target_path): 254 | logger.warning(f"Found existing compiled model at {target_path}! Skipping..") 255 | return target_path 256 | 257 | logger.info(f"Compiling {source_model_path}") 258 | source_model_name = os.path.basename(os.path.splitext(source_model_path)[0]) 259 | 260 | os.system(f"xcrun coremlcompiler compile '{source_model_path}' '{output_dir}'") 261 | compiled_output = os.path.join(output_dir, f"{source_model_name}.mlmodelc") 262 | shutil.move(compiled_output, target_path) 263 | 264 | return target_path 265 | 266 | 267 | def remove_mlpackages(args): 268 | for package_name in [ 269 | "text_encoder", 270 | "text_encoder_2", 271 | "text_encoder_prior", 272 | "vae_encoder", 273 | "vae_decoder", 274 | "t2i_adapter", 275 | "controlnet", 276 | "unet", 277 | "unet_chunk1", 278 | "unet_chunk2", 279 | "safety_checker", 280 | "wuerstchen_prior", 281 | "wuerstchen_prior_chunk1", 282 | "wuerstchen_prior_chunk2", 283 | "wuerstchen_decoder", 284 | "wuerstchen_decoder_chunk1", 285 | "wuerstchen_decoder_chunk2", 286 | "wuerstchen_vqgan" 287 | ]: 288 | package_path = get_out_path(args, package_name) 289 | try: 290 | if os.path.exists(package_path): 291 | shutil.rmtree(package_path) 292 | except: 293 | traceback.print_exc() 294 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/convert/convert_controlnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 
-------------------------------------------------------------------------------- /GuernikaTools/guernikatools/convert/convert_controlnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from guernikatools._version import __version__ 7 | from guernikatools.utils import utils 8 | from guernikatools.models import attention, controlnet 9 | 10 | from collections import OrderedDict, defaultdict 11 | from copy import deepcopy 12 | import coremltools as ct 13 | import gc 14 | 15 | import logging 16 | 17 | logging.basicConfig() 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.INFO) 20 | 21 | import numpy as np 22 | import os 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | torch.set_grad_enabled(False) 29 | 30 | 31 | def main(pipe, args): 32 | """ Converts the ControlNet component of Stable Diffusion 33 | """ 34 | if not pipe.controlnet: 35 | logger.info("`controlnet` not available in this pipeline.") 36 | return 37 | 38 | out_path = utils.get_out_path(args, "controlnet") 39 | if os.path.exists(out_path): 40 | logger.info(f"`controlnet` already exists at {out_path}, skipping conversion.") 41 | return 42 | 43 | # Register the selected attention implementation globally 44 | attention.ATTENTION_IMPLEMENTATION_IN_EFFECT = attention.AttentionImplementations[args.attention_implementation] 45 | logger.info(f"Attention implementation in effect: {attention.ATTENTION_IMPLEMENTATION_IN_EFFECT}") 46 | 47 | # Prepare sample input shapes and values 48 | batch_size = 2 # for classifier-free guidance 49 | controlnet_in_channels = pipe.controlnet.config.in_channels 50 | vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1) 51 | height = int(args.output_h / vae_scale_factor) 52 | width = int(args.output_w / vae_scale_factor) 53 | 54 | sample_shape = ( 55 | batch_size, # B 56 | controlnet_in_channels, # C 57 | height, # H 58 | width, # W 59 | ) 60 | 61 | cond_shape = ( 62 | batch_size, # B 63 | 3, # C 64 | args.output_h, # H 65 | args.output_w, # W 66 | ) 67 | 68 | if not hasattr(pipe, "text_encoder"): 69 | raise RuntimeError( 70 | "convert_text_encoder() deletes pipe.text_encoder to save RAM. 
" 71 | "Please use convert_controlnet() before convert_text_encoder()") 72 | 73 | hidden_size = pipe.controlnet.config.cross_attention_dim 74 | encoder_hidden_states_shape = ( 75 | batch_size, 76 | hidden_size, 77 | 1, 78 | pipe.text_encoder.config.max_position_embeddings, 79 | ) 80 | 81 | # Create the scheduled timesteps for downstream use 82 | DEFAULT_NUM_INFERENCE_STEPS = 50 83 | pipe.scheduler.set_timesteps(DEFAULT_NUM_INFERENCE_STEPS) 84 | 85 | output_names = [ 86 | "down_block_res_samples_00", "down_block_res_samples_01", "down_block_res_samples_02", 87 | "down_block_res_samples_03", "down_block_res_samples_04", "down_block_res_samples_05", 88 | "down_block_res_samples_06", "down_block_res_samples_07", "down_block_res_samples_08" 89 | ] 90 | sample_controlnet_inputs = [ 91 | ("sample", torch.rand(*sample_shape)), 92 | ("timestep", 93 | torch.tensor([pipe.scheduler.timesteps[0].item()] * 94 | (batch_size)).to(torch.float32)), 95 | ("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape)), 96 | ("controlnet_cond", torch.rand(*cond_shape)) 97 | ] 98 | if hasattr(pipe.controlnet.config, "addition_embed_type") and pipe.controlnet.config.addition_embed_type == "text_time": 99 | text_embeds_shape = ( 100 | batch_size, 101 | pipe.text_encoder_2.config.hidden_size, 102 | ) 103 | time_ids_input = [ 104 | [args.output_h, args.output_w, 0, 0, args.output_h, args.output_w], 105 | [args.output_h, args.output_w, 0, 0, args.output_h, args.output_w] 106 | ] 107 | sample_controlnet_inputs = sample_controlnet_inputs + [ 108 | ("text_embeds", torch.rand(*text_embeds_shape)), 109 | ("time_ids", torch.tensor(time_ids_input).to(torch.float32)), 110 | ] 111 | else: 112 | # SDXL ControlNet does not generate these outputs 113 | output_names = output_names + ["down_block_res_samples_09", "down_block_res_samples_10", "down_block_res_samples_11"] 114 | output_names = output_names + ["mid_block_res_sample"] 115 | 116 | sample_controlnet_inputs = OrderedDict(sample_controlnet_inputs) 117 | sample_controlnet_inputs_spec = { 118 | k: (v.shape, v.dtype) 119 | for k, v in sample_controlnet_inputs.items() 120 | } 121 | logger.info(f"Sample inputs spec: {sample_controlnet_inputs_spec}") 122 | 123 | # Initialize reference controlnet 124 | reference_controlnet = controlnet.ControlNetModel(**pipe.controlnet.config).eval() 125 | load_state_dict_summary = reference_controlnet.load_state_dict(pipe.controlnet.state_dict()) 126 | 127 | # Prepare inputs 128 | baseline_sample_controlnet_inputs = deepcopy(sample_controlnet_inputs) 129 | baseline_sample_controlnet_inputs[ 130 | "encoder_hidden_states"] = baseline_sample_controlnet_inputs[ 131 | "encoder_hidden_states"].squeeze(2).transpose(1, 2) 132 | 133 | # JIT trace 134 | logger.info("JIT tracing..") 135 | reference_controlnet = torch.jit.trace(reference_controlnet, example_kwarg_inputs=sample_controlnet_inputs) 136 | logger.info("Done.") 137 | 138 | if args.check_output_correctness: 139 | baseline_out = pipe.controlnet(**baseline_sample_controlnet_inputs, return_dict=False)[0].numpy() 140 | reference_out = reference_controlnet(**sample_controlnet_inputs)[0].numpy() 141 | utils.report_correctness(baseline_out, reference_out, "control baseline to reference PyTorch") 142 | 143 | del pipe.controlnet 144 | gc.collect() 145 | 146 | coreml_sample_controlnet_inputs = { 147 | k: v.numpy().astype(np.float16) 148 | for k, v in sample_controlnet_inputs.items() 149 | } 150 | 151 | if args.multisize: 152 | sample_size = height 153 | sample_input_shape = ct.Shape(shape=( 154 | 
batch_size, 155 | controlnet_in_channels, 156 | ct.RangeDim(int(sample_size * 0.5), upper_bound=int(sample_size * 2), default=sample_size), 157 | ct.RangeDim(int(sample_size * 0.5), upper_bound=int(sample_size * 2), default=sample_size) 158 | )) 159 | 160 | cond_size = args.output_h 161 | cond_input_shape = ct.Shape(shape=( 162 | batch_size, 163 | 3, 164 | ct.RangeDim(int(cond_size * 0.5), upper_bound=int(cond_size * 2), default=cond_size), 165 | ct.RangeDim(int(cond_size * 0.5), upper_bound=int(cond_size * 2), default=cond_size) 166 | )) 167 | 168 | sample_coreml_inputs = utils.get_coreml_inputs(coreml_sample_controlnet_inputs, { 169 | "sample": sample_input_shape, 170 | "controlnet_cond": cond_input_shape, 171 | }) 172 | else: 173 | sample_coreml_inputs = utils.get_coreml_inputs(coreml_sample_controlnet_inputs) 174 | coreml_controlnet, out_path = utils.convert_to_coreml( 175 | "controlnet", 176 | reference_controlnet, 177 | sample_coreml_inputs, 178 | output_names, 179 | args.precision_full, 180 | args 181 | ) 182 | del reference_controlnet 183 | gc.collect() 184 | 185 | # Set model metadata 186 | coreml_controlnet.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}" 187 | coreml_controlnet.license = "ControlNet (https://github.com/lllyasviel/ControlNet)" 188 | coreml_controlnet.version = args.controlnet_version 189 | coreml_controlnet.short_description = \ 190 | "ControlNet is a neural network structure to control diffusion models by adding extra conditions. " \ 191 | "Please refer to https://github.com/lllyasviel/ControlNet for details." 192 | 193 | # Set the input descriptions 194 | coreml_controlnet.input_description["sample"] = \ 195 | "The low resolution latent feature maps being denoised through reverse diffusion" 196 | coreml_controlnet.input_description["timestep"] = \ 197 | "A value emitted by the associated scheduler object to condition the model on a given noise schedule" 198 | coreml_controlnet.input_description["encoder_hidden_states"] = \ 199 | "Output embeddings from the associated text_encoder model to condition the generated image on text. " \ 200 | "A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. " \ 201 | "Shorter text does not reduce computation."
202 | coreml_controlnet.input_description["controlnet_cond"] = \ 203 | "Image used to condition ControlNet output" 204 | 205 | # Set the output descriptions 206 | coreml_controlnet.output_description["down_block_res_samples_00"] = "Residual down sample from ControlNet" 207 | coreml_controlnet.output_description["down_block_res_samples_01"] = "Residual down sample from ControlNet" 208 | coreml_controlnet.output_description["down_block_res_samples_02"] = "Residual down sample from ControlNet" 209 | coreml_controlnet.output_description["down_block_res_samples_03"] = "Residual down sample from ControlNet" 210 | coreml_controlnet.output_description["down_block_res_samples_04"] = "Residual down sample from ControlNet" 211 | coreml_controlnet.output_description["down_block_res_samples_05"] = "Residual down sample from ControlNet" 212 | coreml_controlnet.output_description["down_block_res_samples_06"] = "Residual down sample from ControlNet" 213 | coreml_controlnet.output_description["down_block_res_samples_07"] = "Residual down sample from ControlNet" 214 | coreml_controlnet.output_description["down_block_res_samples_08"] = "Residual down sample from ControlNet" 215 | if "down_block_res_samples_09" in output_names: 216 | coreml_controlnet.output_description["down_block_res_samples_09"] = "Residual down sample from ControlNet" 217 | coreml_controlnet.output_description["down_block_res_samples_10"] = "Residual down sample from ControlNet" 218 | coreml_controlnet.output_description["down_block_res_samples_11"] = "Residual down sample from ControlNet" 219 | coreml_controlnet.output_description["mid_block_res_sample"] = "Residual mid sample from ControlNet" 220 | 221 | # Set package version metadata 222 | coreml_controlnet.user_defined_metadata["identifier"] = args.controlnet_version 223 | coreml_controlnet.user_defined_metadata["converter_version"] = __version__ 224 | coreml_controlnet.user_defined_metadata["attention_implementation"] = args.attention_implementation 225 | coreml_controlnet.user_defined_metadata["compute_unit"] = args.compute_unit 226 | coreml_controlnet.user_defined_metadata["hidden_size"] = str(hidden_size) 227 | controlnet_method = utils.conditioning_method_from(args.controlnet_version) 228 | if controlnet_method: 229 | coreml_controlnet.user_defined_metadata["method"] = controlnet_method 230 | 231 | coreml_controlnet.save(out_path) 232 | logger.info(f"Saved controlnet into {out_path}") 233 | 234 | # Parity check PyTorch vs CoreML 235 | if args.check_output_correctness: 236 | coreml_out = list(coreml_controlnet.predict(coreml_sample_controlnet_inputs).values())[0] 237 | utils.report_correctness(baseline_out, coreml_out, "control baseline PyTorch to reference CoreML") 238 | 239 | del coreml_controlnet 240 | gc.collect() 241 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/models/controlnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 
4 | # 5 | 6 | from guernikatools.models import attention 7 | from .layer_norm import LayerNormANE 8 | from .unet import Timesteps, TimestepEmbedding, UNetMidBlock2DCrossAttn 9 | from .unet import linear_to_conv2d_map, get_down_block, get_up_block 10 | 11 | from diffusers.configuration_utils import ConfigMixin, register_to_config 12 | from diffusers import ModelMixin 13 | 14 | from enum import Enum 15 | 16 | import logging 17 | 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.INFO) 20 | 21 | import math 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | # Ensure minimum macOS version requirement is met for this particular model 28 | from coremltools.models.utils import _macos_version 29 | if not _macos_version() >= (13, 1): 30 | logger.warning( 31 | "!!! macOS 13.1 and newer or iOS/iPadOS 16.2 and newer is required for best performance !!!" 32 | ) 33 | 34 | class ControlNetConditioningDefaultEmbedding(nn.Module): 35 | """ 36 | "Stable Diffusion uses a pre-processing method similar to VQ-GAN [11] to convert the entire dataset of 512 × 512 37 | images into smaller 64 × 64 “latent images” for stabilized training. This requires ControlNets to convert 38 | image-based conditions to 64 × 64 feature space to match the convolution size. We use a tiny network E(·) of four 39 | convolution layers with 4 × 4 kernels and 2 × 2 strides (activated by ReLU, channels are 16, 32, 64, 128, 40 | initialized with Gaussian weights, trained jointly with the full model) to encode image-space conditions ... into 41 | feature maps ..." 42 | """ 43 | 44 | def __init__( 45 | self, 46 | conditioning_embedding_channels: int, 47 | conditioning_channels: int = 3, 48 | block_out_channels = (16, 32, 96, 256), 49 | ): 50 | super().__init__() 51 | 52 | self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 53 | 54 | self.blocks = nn.ModuleList([]) 55 | 56 | for i in range(len(block_out_channels) - 1): 57 | channel_in = block_out_channels[i] 58 | channel_out = block_out_channels[i + 1] 59 | self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) 60 | self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) 61 | 62 | self.conv_out = zero_module( 63 | nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 64 | ) 65 | 66 | def forward(self, conditioning): 67 | embedding = self.conv_in(conditioning) 68 | embedding = F.silu(embedding) 69 | 70 | for block in self.blocks: 71 | embedding = block(embedding) 72 | embedding = F.silu(embedding) 73 | 74 | embedding = self.conv_out(embedding) 75 | return embedding 76 | 77 | class ControlNetModel(ModelMixin, ConfigMixin): 78 | 79 | @register_to_config 80 | def __init__( 81 | self, 82 | sample_size=None, 83 | in_channels=4, 84 | out_channels=4, 85 | center_input_sample=False, 86 | flip_sin_to_cos=True, 87 | freq_shift=0, 88 | down_block_types=( 89 | "CrossAttnDownBlock2D", 90 | "CrossAttnDownBlock2D", 91 | "CrossAttnDownBlock2D", 92 | "DownBlock2D", 93 | ), 94 | only_cross_attention=False, 95 | block_out_channels=(320, 640, 1280, 1280), 96 | layers_per_block=2, 97 | downsample_padding=1, 98 | mid_block_scale_factor=1, 99 | act_fn="silu", 100 | norm_num_groups=32, 101 | norm_eps=1e-5, 102 | cross_attention_dim=768, 103 | attention_head_dim=8, 104 | transformer_layers_per_block=1, 105 | conv_in_kernel=3, 106 | conv_out_kernel=3, 107 | addition_embed_type=None, 108 | 
addition_time_embed_dim=None, 109 | projection_class_embeddings_input_dim=None, 110 | conditioning_embedding_out_channels=(16, 32, 96, 256), 111 | **kwargs, 112 | ): 113 | if kwargs.get("dual_cross_attention", None): 114 | raise NotImplementedError 115 | if kwargs.get("num_class_embeds", None): 116 | raise NotImplementedError 117 | if only_cross_attention: 118 | raise NotImplementedError 119 | if kwargs.get("use_linear_projection", None): 120 | logger.warning("`use_linear_projection=True` is ignored!") 121 | 122 | super().__init__() 123 | self._register_load_state_dict_pre_hook(linear_to_conv2d_map) 124 | 125 | self.sample_size = sample_size 126 | time_embed_dim = block_out_channels[0] * 4 127 | 128 | # input 129 | conv_in_padding = (conv_in_kernel - 1) // 2 130 | self.conv_in = nn.Conv2d( 131 | in_channels, 132 | block_out_channels[0], 133 | kernel_size=conv_in_kernel, 134 | padding=conv_in_padding 135 | ) 136 | 137 | # time 138 | time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 139 | timestep_input_dim = block_out_channels[0] 140 | time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) 141 | 142 | if addition_embed_type == "text_time": 143 | self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) 144 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 145 | elif addition_embed_type is not None: 146 | raise NotImplementedError 147 | 148 | self.time_proj = time_proj 149 | self.time_embedding = time_embedding 150 | 151 | self.controlnet_cond_embedding = ControlNetConditioningDefaultEmbedding( 152 | conditioning_embedding_channels=block_out_channels[0], 153 | block_out_channels=conditioning_embedding_out_channels, 154 | ) 155 | 156 | self.down_blocks = nn.ModuleList([]) 157 | self.mid_block = None 158 | self.controlnet_down_blocks = nn.ModuleList([]) 159 | 160 | if isinstance(only_cross_attention, bool): 161 | only_cross_attention = [only_cross_attention] * len(down_block_types) 162 | 163 | if isinstance(attention_head_dim, int): 164 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 165 | 166 | if isinstance(transformer_layers_per_block, int): 167 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 168 | 169 | # down 170 | output_channel = block_out_channels[0] 171 | 172 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 173 | controlnet_block = zero_module(controlnet_block) 174 | self.controlnet_down_blocks.append(controlnet_block) 175 | 176 | for i, down_block_type in enumerate(down_block_types): 177 | input_channel = output_channel 178 | output_channel = block_out_channels[i] 179 | is_final_block = i == len(block_out_channels) - 1 180 | 181 | down_block = get_down_block( 182 | down_block_type, 183 | num_layers=layers_per_block, 184 | transformer_layers_per_block=transformer_layers_per_block[i], 185 | in_channels=input_channel, 186 | out_channels=output_channel, 187 | temb_channels=time_embed_dim, 188 | add_downsample=not is_final_block, 189 | resnet_eps=norm_eps, 190 | resnet_act_fn=act_fn, 191 | cross_attention_dim=cross_attention_dim, 192 | attn_num_head_channels=attention_head_dim[i], 193 | downsample_padding=downsample_padding, 194 | ) 195 | self.down_blocks.append(down_block) 196 | 197 | for _ in range(layers_per_block): 198 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 199 | controlnet_block = zero_module(controlnet_block) 200 | 
self.controlnet_down_blocks.append(controlnet_block) 201 | 202 | if not is_final_block: 203 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 204 | controlnet_block = zero_module(controlnet_block) 205 | self.controlnet_down_blocks.append(controlnet_block) 206 | 207 | # mid 208 | mid_block_channel = block_out_channels[-1] 209 | 210 | controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) 211 | controlnet_block = zero_module(controlnet_block) 212 | self.controlnet_mid_block = controlnet_block 213 | 214 | self.mid_block = UNetMidBlock2DCrossAttn( 215 | transformer_layers_per_block=transformer_layers_per_block[-1], 216 | in_channels=mid_block_channel, 217 | temb_channels=time_embed_dim, 218 | resnet_eps=norm_eps, 219 | resnet_act_fn=act_fn, 220 | output_scale_factor=mid_block_scale_factor, 221 | resnet_time_scale_shift="default", 222 | cross_attention_dim=cross_attention_dim, 223 | attn_num_head_channels=attention_head_dim[-1], 224 | resnet_groups=norm_num_groups, 225 | ) 226 | 227 | def forward( 228 | self, 229 | sample, 230 | timestep, 231 | encoder_hidden_states, 232 | controlnet_cond, 233 | text_embeds=None, 234 | time_ids=None 235 | ): 236 | # 0. Project (or look-up) time embeddings 237 | t_emb = self.time_proj(timestep) 238 | emb = self.time_embedding(t_emb) 239 | 240 | if hasattr(self.config, "addition_embed_type") and self.config.addition_embed_type == "text_time": 241 | time_embeds = self.add_time_proj(time_ids.flatten()) 242 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 243 | 244 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 245 | add_embeds = add_embeds.to(emb.dtype) 246 | aug_emb = self.add_embedding(add_embeds) 247 | emb = emb + aug_emb 248 | 249 | # 1. center input if necessary 250 | if self.config.center_input_sample: 251 | sample = 2 * sample - 1.0 252 | 253 | # 2. pre-process 254 | sample = self.conv_in(sample) 255 | 256 | controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) 257 | 258 | sample += controlnet_cond 259 | 260 | # 3. down 261 | down_block_res_samples = (sample, ) 262 | for downsample_block in self.down_blocks: 263 | if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: 264 | sample, res_samples = downsample_block( 265 | hidden_states=sample, 266 | temb=emb, 267 | encoder_hidden_states=encoder_hidden_states 268 | ) 269 | else: 270 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 271 | 272 | down_block_res_samples += res_samples 273 | 274 | # 4. mid 275 | sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states) 276 | 277 | # 5.
ControlNet blocks 278 | 279 | controlnet_down_block_res_samples = () 280 | 281 | for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): 282 | down_block_res_sample = controlnet_block(down_block_res_sample) 283 | controlnet_down_block_res_samples += (down_block_res_sample, ) 284 | 285 | down_block_res_samples = controlnet_down_block_res_samples 286 | 287 | mid_block_res_sample = self.controlnet_mid_block(sample) 288 | 289 | return (down_block_res_samples, mid_block_res_sample) 290 | 291 | 292 | def zero_module(module): 293 | for p in module.parameters(): 294 | nn.init.zeros_(p) 295 | return module 296 | -------------------------------------------------------------------------------- /GuernikaTools/guernikatools/convert/convert_vae.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from guernikatools._version import __version__ 7 | from guernikatools.utils import utils 8 | 9 | from collections import OrderedDict, defaultdict 10 | from copy import deepcopy 11 | import coremltools as ct 12 | import gc 13 | 14 | import logging 15 | 16 | logging.basicConfig() 17 | logger = logging.getLogger(__name__) 18 | logger.setLevel(logging.INFO) 19 | 20 | import numpy as np 21 | import os 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | 27 | torch.set_grad_enabled(False) 28 | 29 | 30 | def modify_coremltools_torch_frontend_badbmm(): 31 | """ 32 | Modifies coremltools torch frontend for baddbmm to be robust to the `beta` argument being of non-float dtype: 33 | e.g. https://github.com/huggingface/diffusers/blob/v0.8.1/src/diffusers/models/attention.py#L315 34 | """ 35 | from coremltools.converters.mil import register_torch_op 36 | from coremltools.converters.mil.mil import Builder as mb 37 | from coremltools.converters.mil.frontend.torch.ops import _get_inputs 38 | from coremltools.converters.mil.frontend.torch.torch_op_registry import _TORCH_OPS_REGISTRY 39 | if "baddbmm" in _TORCH_OPS_REGISTRY: 40 | del _TORCH_OPS_REGISTRY["baddbmm"] 41 | 42 | @register_torch_op 43 | def baddbmm(context, node): 44 | """ 45 | baddbmm(Tensor input, Tensor batch1, Tensor batch2, Scalar beta=1, Scalar alpha=1) 46 | output = beta * input + alpha * batch1 * batch2 47 | Notice that batch1 and batch2 must be 3-D tensors each containing the same number of matrices. 48 | If batch1 is a (b×n×m) tensor, batch2 is a (b×m×p) tensor, then input must be broadcastable with a (b×n×p) tensor 49 | and out will be a (b×n×p) tensor. 50 | """ 51 | assert len(node.outputs) == 1 52 | inputs = _get_inputs(context, node, expected=5) 53 | bias, batch1, batch2, beta, alpha = inputs 54 | 55 | if beta.val != 1.0: 56 | # Apply scaling factor beta to the bias. 57 | if beta.val.dtype == np.int32: 58 | beta = mb.cast(x=beta, dtype="fp32") 59 | logger.warning( 60 | f"Cast the `beta` (value={beta.val}) argument of `baddbmm` op " 61 | "from int32 to float32 dtype for conversion!") 62 | bias = mb.mul(x=beta, y=bias, name=bias.name + "_scaled") 63 | 64 | context.add(bias) 65 | 66 | if alpha.val != 1.0: 67 | # Apply scaling factor alpha to the input.
68 | batch1 = mb.mul(x=alpha, y=batch1, name=batch1.name + "_scaled") 69 | context.add(batch1) 70 | 71 | bmm_node = mb.matmul(x=batch1, y=batch2, name=node.name + "_bmm") 72 | context.add(bmm_node) 73 | 74 | baddbmm_node = mb.add(x=bias, y=bmm_node, name=node.name) 75 | context.add(baddbmm_node) 76 | 77 | 78 | def encoder(pipe, args): 79 | """ Converts the VAE Encoder component of Stable Diffusion 80 | """ 81 | out_path = utils.get_out_path(args, "vae_encoder") 82 | if os.path.exists(out_path): 83 | logger.info(f"`vae_encoder` already exists at {out_path}, skipping conversion.") 84 | return 85 | 86 | if not hasattr(pipe, "unet"): 87 | raise RuntimeError( 88 | "convert_unet() deletes pipe.unet to save RAM. " 89 | "Please use convert_vae_encoder() before convert_unet()") 90 | 91 | z_shape = ( 92 | 1, # B 93 | 3, # C 94 | args.output_h, # H 95 | args.output_w, # W 96 | ) 97 | 98 | sample_vae_encoder_inputs = { 99 | "z": torch.rand(*z_shape, dtype=torch.float16) 100 | } 101 | 102 | class VAEEncoder(nn.Module): 103 | """ Wrapper nn.Module for the pipe.vae.encode() method 104 | """ 105 | 106 | def __init__(self): 107 | super().__init__() 108 | self.quant_conv = pipe.vae.quant_conv 109 | self.encoder = pipe.vae.encoder 110 | 111 | def forward(self, z): 112 | return self.quant_conv(self.encoder(z)) 113 | 114 | baseline_encoder = VAEEncoder().eval() 115 | 116 | # No optimization needed for the VAE Encoder as it is a pure ConvNet 117 | traced_vae_encoder = torch.jit.trace(baseline_encoder, (sample_vae_encoder_inputs["z"].to(torch.float32), )) 118 | 119 | modify_coremltools_torch_frontend_badbmm() 120 | 121 | # TODO: For now using variable size takes too much memory and time 122 | # if args.multisize: 123 | # sample_size = args.output_h 124 | # input_shape = ct.Shape(shape=( 125 | # 1, 126 | # 3, 127 | # ct.RangeDim(lower_bound=int(sample_size * 0.5), upper_bound=int(sample_size * 2), default=sample_size), 128 | # ct.RangeDim(lower_bound=int(sample_size * 0.5), upper_bound=int(sample_size * 2), default=sample_size) 129 | # )) 130 | # sample_coreml_inputs = utils.get_coreml_inputs(sample_vae_encoder_inputs, {"z": input_shape}) 131 | # else: 132 | # sample_coreml_inputs = utils.get_coreml_inputs(sample_vae_encoder_inputs) 133 | sample_coreml_inputs = utils.get_coreml_inputs(sample_vae_encoder_inputs) 134 | 135 | # SDXL seems to require full precision 136 | precision_full = args.model_is_sdxl and args.model_version != "stabilityai/stable-diffusion-xl-base-1.0" 137 | coreml_vae_encoder, out_path = utils.convert_to_coreml( 138 | "vae_encoder", traced_vae_encoder, sample_coreml_inputs, 139 | ["latent"], precision_full, args 140 | ) 141 | 142 | # Set model metadata 143 | coreml_vae_encoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}" 144 | coreml_vae_encoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)" 145 | coreml_vae_encoder.version = args.model_version 146 | coreml_vae_encoder.short_description = \ 147 | "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \ 148 | "Please refer to https://arxiv.org/abs/2112.10752 for details."
149 | 150 | # Set the input descriptions 151 | coreml_vae_encoder.input_description["z"] = \ 152 | "The input image to base the initial latents on, normalized to range [-1, 1]" 153 | 154 | # Set the output descriptions 155 | coreml_vae_encoder.output_description["latent"] = "The latent embeddings produced by the VAE encoder from the input image." 156 | 157 | # Set package version metadata 158 | coreml_vae_encoder.user_defined_metadata["identifier"] = args.model_version 159 | coreml_vae_encoder.user_defined_metadata["converter_version"] = __version__ 160 | coreml_vae_encoder.user_defined_metadata["attention_implementation"] = args.attention_implementation 161 | coreml_vae_encoder.user_defined_metadata["compute_unit"] = args.compute_unit 162 | coreml_vae_encoder.user_defined_metadata["scaling_factor"] = str(pipe.vae.config.scaling_factor) 163 | # TODO: stop using this hack (and key) once variable size conversion works correctly 164 | coreml_vae_encoder.user_defined_metadata["supports_model_size_hack"] = "true" 165 | 166 | coreml_vae_encoder.save(out_path) 167 | 168 | logger.info(f"Saved vae_encoder into {out_path}") 169 | 170 | # Parity check PyTorch vs CoreML 171 | if args.check_output_correctness: 172 | baseline_out = baseline_encoder(z=sample_vae_encoder_inputs["z"].to(torch.float32)).detach().numpy() 173 | coreml_out = list(coreml_vae_encoder.predict({ 174 | k: v.numpy() for k, v in sample_vae_encoder_inputs.items() 175 | }).values())[0] 176 | utils.report_correctness(baseline_out, coreml_out, "vae_encoder baseline PyTorch to baseline CoreML") 177 | 178 | del traced_vae_encoder, pipe.vae.encoder, coreml_vae_encoder 179 | gc.collect() 180 | 181 | 182 | def decoder(pipe, args): 183 | """ Converts the VAE Decoder component of Stable Diffusion 184 | """ 185 | out_path = utils.get_out_path(args, "vae_decoder") 186 | if os.path.exists(out_path): 187 | logger.info(f"`vae_decoder` already exists at {out_path}, skipping conversion.") 188 | return 189 | 190 | if not hasattr(pipe, "unet"): 191 | raise RuntimeError( 192 | "convert_unet() deletes pipe.unet to save RAM. 
" 193 | "Please use convert_vae_decoder() before convert_unet()") 194 | 195 | vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1) 196 | height = int(args.output_h / vae_scale_factor) 197 | width = int(args.output_w / vae_scale_factor) 198 | vae_latent_channels = pipe.vae.config.latent_channels 199 | z_shape = ( 200 | 1, # B 201 | vae_latent_channels, # C 202 | height, # H 203 | width, # w 204 | ) 205 | 206 | sample_vae_decoder_inputs = { 207 | "z": torch.rand(*z_shape, dtype=torch.float16) 208 | } 209 | 210 | class VAEDecoder(nn.Module): 211 | """ Wrapper nn.Module wrapper for pipe.decode() method 212 | """ 213 | 214 | def __init__(self): 215 | super().__init__() 216 | self.post_quant_conv = pipe.vae.post_quant_conv 217 | self.decoder = pipe.vae.decoder 218 | 219 | def forward(self, z): 220 | return self.decoder(self.post_quant_conv(z)) 221 | 222 | baseline_decoder = VAEDecoder().eval() 223 | 224 | # No optimization needed for the VAE Decoder as it is a pure ConvNet 225 | traced_vae_decoder = torch.jit.trace(baseline_decoder, (sample_vae_decoder_inputs["z"].to(torch.float32), )) 226 | 227 | modify_coremltools_torch_frontend_badbmm() 228 | 229 | # TODO: For now using variable size takes too much memory and time 230 | # if args.multisize: 231 | # sample_size = height 232 | # input_shape = ct.Shape(shape=( 233 | # 1, 234 | # vae_latent_channels, 235 | # ct.RangeDim(int(sample_size * 0.5), upper_bound=int(sample_size * 2), default=sample_size), 236 | # ct.RangeDim(int(sample_size * 0.5), upper_bound=int(sample_size * 2), default=sample_size) 237 | # )) 238 | # sample_coreml_inputs = utils.get_coreml_inputs(sample_vae_decoder_inputs, {"z": input_shape}) 239 | # else: 240 | # sample_coreml_inputs = utils.get_coreml_inputs(sample_vae_decoder_inputs) 241 | sample_coreml_inputs = utils.get_coreml_inputs(sample_vae_decoder_inputs) 242 | 243 | # SDXL seems to require full precision 244 | precision_full = args.model_is_sdxl and args.model_version != "stabilityai/stable-diffusion-xl-base-1.0" 245 | coreml_vae_decoder, out_path = utils.convert_to_coreml( 246 | "vae_decoder", traced_vae_decoder, sample_coreml_inputs, 247 | ["image"], precision_full, args 248 | ) 249 | 250 | # Set model metadata 251 | coreml_vae_decoder.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}" 252 | coreml_vae_decoder.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)" 253 | coreml_vae_decoder.version = args.model_version 254 | coreml_vae_decoder.short_description = \ 255 | "Stable Diffusion generates images conditioned on text and/or other images as input through the diffusion process. " \ 256 | "Please refer to https://arxiv.org/abs/2112.10752 for details." 
257 | 258 | # Set the input descriptions 259 | coreml_vae_decoder.input_description["z"] = \ 260 | "The denoised latent embeddings from the unet model after the last step of reverse diffusion" 261 | 262 | # Set the output descriptions 263 | coreml_vae_decoder.output_description["image"] = "Generated image normalized to range [-1, 1]" 264 | 265 | # Set package version metadata 266 | coreml_vae_decoder.user_defined_metadata["identifier"] = args.model_version 267 | coreml_vae_decoder.user_defined_metadata["converter_version"] = __version__ 268 | coreml_vae_decoder.user_defined_metadata["attention_implementation"] = args.attention_implementation 269 | coreml_vae_decoder.user_defined_metadata["compute_unit"] = args.compute_unit 270 | coreml_vae_decoder.user_defined_metadata["scaling_factor"] = str(pipe.vae.config.scaling_factor) 271 | # TODO: stop using this hack (and key) once variable size conversion works correctly 272 | coreml_vae_decoder.user_defined_metadata["supports_model_size_hack"] = "true" 273 | 274 | coreml_vae_decoder.save(out_path) 275 | 276 | logger.info(f"Saved vae_decoder into {out_path}") 277 | 278 | # Parity check PyTorch vs CoreML 279 | if args.check_output_correctness: 280 | baseline_out = baseline_decoder(z=sample_vae_decoder_inputs["z"].to(torch.float32)).detach().numpy() 281 | coreml_out = list(coreml_vae_decoder.predict({ 282 | k: v.numpy() for k, v in sample_vae_decoder_inputs.items() 283 | }).values())[0] 284 | utils.report_correctness(baseline_out, coreml_out, "vae_decoder baseline PyTorch to baseline CoreML") 285 | 286 | del traced_vae_decoder, pipe.vae.decoder, coreml_vae_decoder 287 | gc.collect() 288 |
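# --- Editor's illustrative sketch (not part of the original converter) ---
# A hypothetical smoke test for an exported VAE decoder package. The path and the
# 64x64 latent shape are assumptions; the real latent H/W are output size divided
# by vae_scale_factor, and "z"/"image" are the input/output names set above.
def example_vae_decoder_smoke_test(mlpackage_path="/tmp/model_vae_decoder.mlpackage"):
    model = ct.models.MLModel(mlpackage_path, compute_units=ct.ComputeUnit.CPU_ONLY)
    z = np.random.rand(1, 4, 64, 64).astype(np.float32)  # B, latent C, H/8, W/8
    image = model.predict({"z": z})["image"]             # decoded image in [-1, 1]
    print("decoded image tensor shape:", image.shape)    # e.g. (1, 3, 512, 512)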
-------------------------------------------------------------------------------- /GuernikaTools/guernikatools/utils/chunk_mlprogram.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from guernikatools.utils import utils 7 | 8 | import argparse 9 | from collections import OrderedDict 10 | 11 | import coremltools as ct 12 | from coremltools.converters.mil import Block, Program, Var 13 | from coremltools.converters.mil.frontend.milproto.load import load as _milproto_to_pymil 14 | from coremltools.converters.mil.mil import Builder as mb 15 | from coremltools.converters.mil.mil import Placeholder 16 | from coremltools.converters.mil.mil import types as types 17 | from coremltools.converters.mil.mil.passes.helper import block_context_manager 18 | from coremltools.converters.mil.mil.passes.pass_registry import PASS_REGISTRY 19 | from coremltools.converters.mil.testing_utils import random_gen_input_feature_type 20 | 21 | import gc 22 | 23 | import logging 24 | 25 | logging.basicConfig() 26 | logger = logging.getLogger(__name__) 27 | logger.setLevel(logging.INFO) 28 | 29 | import numpy as np 30 | import os 31 | import tempfile 32 | # from guernikatools import torch2coreml  # removed: unused, and no such module exists in this package 33 | import shutil 34 | import time 35 | 36 | 37 | def _verify_output_correctness_of_chunks(full_model, first_chunk_model, 38 | second_chunk_model): 39 | """ Verifies the end-to-end output correctness of full (original) model versus chunked models 40 | """ 41 | # Generate inputs for first chunk and full model 42 | input_dict = {} 43 | for input_desc in full_model._spec.description.input: 44 | input_dict[input_desc.name] = random_gen_input_feature_type(input_desc) 45 | 46 | # Generate outputs for first chunk and full model 47 | outputs_from_full_model = full_model.predict(input_dict) 48 | outputs_from_first_chunk_model = first_chunk_model.predict(input_dict) 49 | 50 | # Prepare inputs for second chunk model from first chunk's outputs and regular inputs 51 | second_chunk_input_dict = {} 52 | for input_desc in second_chunk_model._spec.description.input: 53 | if input_desc.name in outputs_from_first_chunk_model: 54 | second_chunk_input_dict[ 55 | input_desc.name] = outputs_from_first_chunk_model[ 56 | input_desc.name] 57 | else: 58 | second_chunk_input_dict[input_desc.name] = input_dict[ 59 | input_desc.name] 60 | 61 | # Generate output for second chunk model 62 | outputs_from_second_chunk_model = second_chunk_model.predict( 63 | second_chunk_input_dict) 64 | 65 | # Verify correctness across all outputs from second chunk and full model 66 | for out_name in outputs_from_full_model.keys(): 67 | utils.report_correctness( 68 | original_outputs=outputs_from_full_model[out_name], 69 | final_outputs=outputs_from_second_chunk_model[out_name], 70 | log_prefix=out_name) 71 | 72 | 73 | def _load_prog_from_mlmodel(model): 74 | """ Load MIL Program from an MLModel 75 | """ 76 | model_spec = model.get_spec() 77 | start_ = time.time() 78 | logger.info( 79 | "Loading MLModel object into a MIL Program object (including the weights).."
80 | ) 81 | prog = _milproto_to_pymil( 82 | model_spec=model_spec, 83 | specification_version=model_spec.specificationVersion, 84 | file_weights_dir=model.weights_dir, 85 | ) 86 | logger.info(f"Program loaded in {time.time() - start_:.1f} seconds") 87 | 88 | return prog 89 | 90 | 91 | def _get_op_idx_split_location(prog: Program): 92 | """ Find the op that approximately bisects the graph as measured by weight size on each side 93 | """ 94 | main_block = prog.functions["main"] 95 | total_size_in_mb = 0 96 | 97 | for op in main_block.operations: 98 | if op.op_type == "const" and isinstance(op.val.val, np.ndarray): 99 | size_in_mb = op.val.val.size * op.val.val.itemsize / (1024 * 1024) 100 | total_size_in_mb += size_in_mb 101 | half_size = total_size_in_mb / 2 102 | 103 | # Find the first non const op (single child), where the total cumulative size exceeds 104 | # the half size for the first time 105 | cumulative_size_in_mb = 0 106 | for op in main_block.operations: 107 | if op.op_type == "const" and isinstance(op.val.val, np.ndarray): 108 | size_in_mb = op.val.val.size * op.val.val.itemsize / (1024 * 1024) 109 | cumulative_size_in_mb += size_in_mb 110 | 111 | # Note: The condition "not op.op_type.startswith("const")" is to make sure that the 112 | # incision op is neither of type "const" nor "constexpr_*" ops that 113 | # are used to store compressed weights 114 | if (cumulative_size_in_mb > half_size and not op.op_type.startswith("const") 115 | and len(op.outputs) == 1 116 | and len(op.outputs[0].child_ops) == 1): 117 | op_idx = main_block.operations.index(op) 118 | return op_idx, cumulative_size_in_mb, total_size_in_mb 119 | 120 | 121 | def _get_first_chunk_outputs(block, op_idx): 122 | # Get the list of all vars that go across from first program (all ops from 0 to op_idx (inclusive)) 123 | # to the second program (all ops from op_idx+1 till the end).
All these vars need to be made the output 124 | # of the first program and the input of the second program 125 | boundary_vars = set() 126 | for i in range(op_idx + 1): 127 | op = block.operations[i] 128 | if not op.op_type.startswith("const"): 129 | for var in op.outputs: 130 | if var.val is None: # only consider non const vars 131 | for child_op in var.child_ops: 132 | child_op_idx = block.operations.index(child_op) 133 | if child_op_idx > op_idx: 134 | boundary_vars.add(var) 135 | return list(boundary_vars) 136 | 137 | 138 | @block_context_manager 139 | def _add_fp32_casts(block, boundary_vars): 140 | new_boundary_vars = [] 141 | for var in boundary_vars: 142 | if var.dtype != types.fp16: 143 | new_boundary_vars.append(var) 144 | else: 145 | fp32_var = mb.cast(x=var, dtype="fp32", name=var.name) 146 | new_boundary_vars.append(fp32_var) 147 | return new_boundary_vars 148 | 149 | 150 | def _make_first_chunk_prog(prog, op_idx): 151 | """ Build first chunk by declaring early outputs and removing unused subgraph 152 | """ 153 | block = prog.functions["main"] 154 | boundary_vars = _get_first_chunk_outputs(block, op_idx) 155 | 156 | # Due to possible numerical issues, cast any fp16 var to fp32 157 | new_boundary_vars = _add_fp32_casts(block, boundary_vars) 158 | 159 | block.outputs.clear() 160 | block.set_outputs(new_boundary_vars) 161 | PASS_REGISTRY["common::dead_code_elimination"](prog) 162 | return prog 163 | 164 | 165 | def _make_second_chunk_prog(prog, op_idx): 166 | """ Build second chunk by rebuilding a pristine MIL Program from MLModel 167 | """ 168 | block = prog.functions["main"] 169 | block.opset_version = ct.target.iOS16 170 | 171 | # First chunk outputs are second chunk inputs (e.g. skip connections) 172 | boundary_vars = _get_first_chunk_outputs(block, op_idx) 173 | 174 | # This op will not be included in this program. Its output var will be made into an input 175 | boundary_op = block.operations[op_idx] 176 | 177 | # Add all boundary vars as inputs 178 | with block: 179 | for var in boundary_vars: 180 | new_placeholder = Placeholder( 181 | sym_shape=var.shape, 182 | dtype=var.dtype if var.dtype != types.fp16 else types.fp32, 183 | name=var.name, 184 | ) 185 | 186 | block._input_dict[ 187 | new_placeholder.outputs[0].name] = new_placeholder.outputs[0] 188 | 189 | block.function_inputs = tuple(block._input_dict.values()) 190 | new_var = None 191 | if var.dtype == types.fp16: 192 | new_var = mb.cast(x=new_placeholder.outputs[0], 193 | dtype="fp16", 194 | before_op=var.op) 195 | else: 196 | new_var = new_placeholder.outputs[0] 197 | 198 | block.replace_uses_of_var_after_op( 199 | anchor_op=boundary_op, 200 | old_var=var, 201 | new_var=new_var, 202 | # This is needed if the program contains "constexpr_*" ops. In normal cases, there are stricter 203 | # rules for removing them, and their presence may prevent replacing this var. 204 | # However in this case, since we want to remove all the ops in chunk 1, we can safely 205 | # set this to True.
206 | force_replace=True, 207 | ) 208 | 209 | PASS_REGISTRY["common::dead_code_elimination"](prog) 210 | 211 | # Remove any unused inputs 212 | new_input_dict = OrderedDict() 213 | for k, v in block._input_dict.items(): 214 | if len(v.child_ops) > 0: 215 | new_input_dict[k] = v 216 | block._input_dict = new_input_dict 217 | block.function_inputs = tuple(block._input_dict.values()) 218 | 219 | return prog 220 | 221 | 222 | def main(args): 223 | os.makedirs(args.o, exist_ok=True) 224 | 225 | # Check filename extension 226 | mlpackage_name = os.path.basename(args.mlpackage_path) 227 | name, ext = os.path.splitext(mlpackage_name) 228 | assert ext == ".mlpackage", f"`--mlpackage-path` ({args.mlpackage_path}) is not an .mlpackage file" 229 | 230 | # Load CoreML model 231 | logger.info("Loading model from {}".format(args.mlpackage_path)) 232 | start_ = time.time() 233 | model = ct.models.MLModel( 234 | args.mlpackage_path, 235 | compute_units=ct.ComputeUnit.CPU_ONLY, 236 | ) 237 | logger.info( 238 | f"Loading {args.mlpackage_path} took {time.time() - start_:.1f} seconds" 239 | ) 240 | 241 | # Load the MIL Program from MLModel 242 | prog = _load_prog_from_mlmodel(model) 243 | 244 | # Compute the incision point by bisecting the program based on weights size 245 | op_idx, first_chunk_weights_size, total_weights_size = _get_op_idx_split_location( 246 | prog) 247 | main_block = prog.functions["main"] 248 | incision_op = main_block.operations[op_idx] 249 | logger.info(f"{args.mlpackage_path} will be chunked into two pieces.") 250 | logger.info( 251 | f"The incision op: name={incision_op.name}, type={incision_op.op_type}, index={op_idx}/{len(main_block.operations)}" 252 | ) 253 | logger.info(f"First chunk size = {first_chunk_weights_size:.2f} MB") 254 | logger.info( 255 | f"Second chunk size = {total_weights_size - first_chunk_weights_size:.2f} MB" 256 | ) 257 | 258 | # Build first chunk (in-place modifies prog by declaring early exits and removing unused subgraph) 259 | prog_chunk1 = _make_first_chunk_prog(prog, op_idx) 260 | 261 | # Build the second chunk 262 | prog_chunk2 = _make_second_chunk_prog(_load_prog_from_mlmodel(model), op_idx) 263 | 264 | user_defined_metadata = model.user_defined_metadata 265 | 266 | if not args.check_output_correctness: 267 | # Original model no longer needed in memory 268 | del model 269 | gc.collect() 270 | 271 | # Convert the MIL Program objects into MLModels 272 | logger.info("Converting the two programs") 273 | model_chunk1 = ct.convert( 274 | prog_chunk1, 275 | convert_to="mlprogram", 276 | compute_units=ct.ComputeUnit.CPU_ONLY, 277 | minimum_deployment_target=ct.target.iOS16, 278 | ) 279 | del prog_chunk1 280 | gc.collect() 281 | logger.info("Conversion of first chunk done.") 282 | 283 | model_chunk2 = ct.convert( 284 | prog_chunk2, 285 | convert_to="mlprogram", 286 | compute_units=ct.ComputeUnit.CPU_ONLY, 287 | minimum_deployment_target=ct.target.iOS16, 288 | ) 289 | del prog_chunk2 290 | gc.collect() 291 | logger.info("Conversion of second chunk done.") 292 | 293 | for key in user_defined_metadata: 294 | if "com.github.apple" not in key: 295 | model_chunk1.user_defined_metadata[key] = user_defined_metadata[key] 296 | model_chunk2.user_defined_metadata[key] = user_defined_metadata[key] 297 | 298 | # Verify output correctness 299 | if args.check_output_correctness: 300 | logger.info("Verifying output correctness of chunks") 301 | _verify_output_correctness_of_chunks( 302 | full_model=model, 303 | first_chunk_model=model_chunk1, 304 | second_chunk_model=model_chunk2, 305
| ) 306 | 307 | # Remove original (non-chunked) model if requested 308 | if args.remove_original: 309 | logger.info( 310 | f"Removing original (non-chunked) model at {args.mlpackage_path}") 311 | shutil.rmtree(args.mlpackage_path) 312 | logger.info("Done.") 313 | 314 | # Save the chunked models to disk 315 | out_path_chunk1 = os.path.join(args.o, name + "_chunk1.mlpackage") 316 | out_path_chunk2 = os.path.join(args.o, name + "_chunk2.mlpackage") 317 | if args.clean_up_mlpackages: 318 | temp_dir = tempfile.gettempdir() 319 | out_path_chunk1 = os.path.join(temp_dir, name + "_chunk1.mlpackage") 320 | out_path_chunk2 = os.path.join(temp_dir, name + "_chunk2.mlpackage") 321 | 322 | model_chunk1.save(out_path_chunk1) 323 | model_chunk2.save(out_path_chunk2) 324 | logger.info( 325 | f"Saved chunks at {out_path_chunk1} and {out_path_chunk2}" 326 | ) 327 | logger.info("Done.") 328 | 329 | 330 | if __name__ == "__main__": 331 | parser = argparse.ArgumentParser() 332 | parser.add_argument( 333 | "--mlpackage-path", 334 | required=True, 335 | help= 336 | "Path to the mlpackage file to be split into two mlpackages of approximately the same file size.", 337 | ) 338 | parser.add_argument( 339 | "-o", 340 | required=True, 341 | help= 342 | "Path to output directory where the two model chunks should be saved.", 343 | ) 344 | parser.add_argument( 345 | "--clean-up-mlpackages", 346 | action="store_true", 347 | help="Removes mlpackages after a successful conversion.") 348 | parser.add_argument( 349 | "--remove-original", 350 | action="store_true", 351 | help= 352 | "If specified, removes the original (non-chunked) model to avoid duplicating storage." 353 | ) 354 | parser.add_argument( 355 | "--check-output-correctness", 356 | action="store_true", 357 | help= 358 | ("If specified, compares the outputs of the original Core ML model with those of the chunked Core ML models and reports PSNR in dB. " 359 | "Enabling this feature uses more memory. Disable it if your machine runs out of memory." 360 | )) 361 | 362 | args = parser.parse_args() 363 | main(args) 364 |
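# --- Editor's note (illustrative) ---
# chunk_mlprogram.py is also usable standalone. A hypothetical invocation (all
# paths are placeholders) that splits a large UNet package into two roughly
# equal-weight halves and verifies end-to-end parity of the pair:
#
#   python -m guernikatools.utils.chunk_mlprogram \
#       --mlpackage-path /tmp/model_unet.mlpackage \
#       -o /tmp/chunked \
#       --check-output-correctness
#
# The chunks are written as *_chunk1.mlpackage and *_chunk2.mlpackage, matching
# the names convert_unet.py probes for before deciding to re-convert.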
-------------------------------------------------------------------------------- /GuernikaTools/guernikatools/convert/convert_unet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from guernikatools._version import __version__ 7 | from guernikatools.utils import utils, chunk_mlprogram 8 | from guernikatools.models import attention, unet 9 | 10 | from collections import OrderedDict, defaultdict 11 | from copy import deepcopy 12 | import coremltools as ct 13 | import gc 14 | 15 | import logging 16 | 17 | logging.basicConfig() 18 | logger = logging.getLogger(__name__) 19 | logger.setLevel(logging.INFO) 20 | 21 | import numpy as np 22 | import os 23 | 24 | import torch 25 | import torch.nn as nn 26 | import torch.nn.functional as F 27 | 28 | torch.set_grad_enabled(False) 29 | 30 | 31 | def main(pipe, args): 32 | """ Converts the UNet component of Stable Diffusion 33 | """ 34 | out_path = utils.get_out_path(args, "unet") 35 | 36 | # Check if Unet was previously exported and then chunked 37 | unet_chunks_exist = all( 38 | os.path.exists(out_path.replace(".mlpackage", f"_chunk{idx+1}.mlpackage")) for idx in range(2) 39 | ) 40 | 41 | if args.chunk_unet and unet_chunks_exist: 42 | logger.info("`unet` chunks already exist, skipping conversion.") 43 | del pipe.unet 44 | gc.collect() 45 | return 46 | 47 | # If original Unet does not exist, export it from PyTorch+diffusers 48 | elif not os.path.exists(out_path): 49 | # Register the selected attention implementation globally 50 | attention.ATTENTION_IMPLEMENTATION_IN_EFFECT = attention.AttentionImplementations[args.attention_implementation] 51 | logger.info(f"Attention implementation in effect: {attention.ATTENTION_IMPLEMENTATION_IN_EFFECT}") 52 | 53 | # Prepare sample input shapes and values 54 | batch_size = 2 # for classifier-free guidance 55 | unet_in_channels = pipe.unet.config.in_channels 56 | # allow converting InstructPix2Pix (its UNet takes 8 input channels and a third batch) 57 | if unet_in_channels == 8: 58 | batch_size = 3 59 | vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1) 60 | height = int(args.output_h / vae_scale_factor) 61 | width = int(args.output_w / vae_scale_factor) 62 | 63 | sample_shape = ( 64 | batch_size, # B 65 | unet_in_channels, # C 66 | height, # H 67 | width, # W 68 | ) 69 | 70 | max_position_embeddings = 77 71 | if hasattr(pipe, "text_encoder") and pipe.text_encoder and pipe.text_encoder.config: 72 | max_position_embeddings = pipe.text_encoder.config.max_position_embeddings 73 | elif hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2 and pipe.text_encoder_2.config: 74 | max_position_embeddings = pipe.text_encoder_2.config.max_position_embeddings 75 | else: 76 | raise RuntimeError( 77 | "convert_text_encoder() deletes pipe.text_encoder to save RAM. 
" 78 | "Please use convert_unet() before convert_text_encoder()") 79 | 80 | args.hidden_size = pipe.unet.config.cross_attention_dim 81 | encoder_hidden_states_shape = ( 82 | batch_size, 83 | args.hidden_size, 84 | 1, 85 | max_position_embeddings, 86 | ) 87 | 88 | # Create the scheduled timesteps for downstream use 89 | DEFAULT_NUM_INFERENCE_STEPS = 50 90 | pipe.scheduler.set_timesteps(DEFAULT_NUM_INFERENCE_STEPS) 91 | 92 | sample_unet_inputs = [ 93 | ("sample", torch.rand(*sample_shape)), 94 | ("timestep", torch.tensor([pipe.scheduler.timesteps[0].item()] * (batch_size)).to(torch.float32)), 95 | ("encoder_hidden_states", torch.rand(*encoder_hidden_states_shape)), 96 | ] 97 | 98 | args.requires_aesthetics_score = False 99 | if hasattr(pipe.unet.config, "addition_embed_type") and pipe.unet.config.addition_embed_type == "text_time": 100 | text_embeds_shape = ( 101 | batch_size, 102 | pipe.text_encoder_2.config.hidden_size, 103 | ) 104 | time_ids_input = None 105 | if hasattr(pipe.config, "requires_aesthetics_score") and pipe.config.requires_aesthetics_score: 106 | args.requires_aesthetics_score = True 107 | time_ids_input = [ 108 | [args.output_h, args.output_w, 0, 0, 2.5], 109 | [args.output_h, args.output_w, 0, 0, 6] 110 | ] 111 | else: 112 | time_ids_input = [ 113 | [args.output_h, args.output_w, 0, 0, args.output_h, args.output_w], 114 | [args.output_h, args.output_w, 0, 0, args.output_h, args.output_w] 115 | ] 116 | sample_unet_inputs = sample_unet_inputs + [ 117 | ("text_embeds", torch.rand(*text_embeds_shape)), 118 | ("time_ids", torch.tensor(time_ids_input).to(torch.float32)), 119 | ] 120 | 121 | if args.controlnet_support: 122 | block_out_channels = pipe.unet.config.block_out_channels 123 | 124 | cn_output = 0 125 | cn_height = height 126 | cn_width = width 127 | 128 | # down 129 | output_channel = block_out_channels[0] 130 | sample_unet_inputs = sample_unet_inputs + [ 131 | (f"down_block_res_samples_{cn_output:02}", torch.rand(2, output_channel, cn_height, cn_width)) 132 | ] 133 | cn_output += 1 134 | 135 | for i, output_channel in enumerate(block_out_channels): 136 | is_final_block = i == len(block_out_channels) - 1 137 | sample_unet_inputs = sample_unet_inputs + [ 138 | (f"down_block_res_samples_{cn_output:02}", torch.rand(2, output_channel, cn_height, cn_width)), 139 | (f"down_block_res_samples_{cn_output+1:02}", torch.rand(2, output_channel, cn_height, cn_width)), 140 | ] 141 | cn_output += 2 142 | if not is_final_block: 143 | cn_height = int(cn_height / 2) 144 | cn_width = int(cn_width / 2) 145 | sample_unet_inputs = sample_unet_inputs + [ 146 | (f"down_block_res_samples_{cn_output:02}", torch.rand(2, output_channel, cn_height, cn_width)), 147 | ] 148 | cn_output += 1 149 | 150 | # mid 151 | output_channel = block_out_channels[-1] 152 | sample_unet_inputs = sample_unet_inputs + [ 153 | ("mid_block_res_sample", torch.rand(2, output_channel, cn_height, cn_width)) 154 | ] 155 | 156 | if args.t2i_adapter_support: 157 | block_out_channels = pipe.unet.config.block_out_channels 158 | 159 | if args.model_is_sdxl: 160 | t2ia_height = int(height / 2) 161 | t2ia_width = int(width / 2) 162 | else: 163 | t2ia_height = height 164 | t2ia_width = width 165 | 166 | for i, output_channel in enumerate(block_out_channels): 167 | sample_unet_inputs = sample_unet_inputs + [ 168 | (f"adapter_res_samples_{i:02}", torch.rand(2, output_channel, t2ia_height, t2ia_width)), 169 | ] 170 | if not args.model_is_sdxl or i == 1: 171 | t2ia_height = int(t2ia_height / 2) 172 | t2ia_width = int(t2ia_width / 2) 173 
174 |         multisize_inputs = None
175 |         if args.multisize:
176 |             # Flexible shapes: allow 0.5x-2x of the default sample size on each spatial dim
177 |             input_shape = ct.Shape(shape=(
178 |                 batch_size,
179 |                 unet_in_channels,
180 |                 ct.RangeDim(lower_bound=int(height * 0.5), upper_bound=int(height * 2), default=height),
181 |                 ct.RangeDim(lower_bound=int(width * 0.5), upper_bound=int(width * 2), default=width)
182 |             ))
183 |             multisize_inputs = {"sample": input_shape}
184 |             for k, v in sample_unet_inputs:
185 |                 if "res_sample" in k:
186 |                     v_height = v.shape[2]
187 |                     v_width = v.shape[3]
188 |                     multisize_inputs[k] = ct.Shape(shape=(
189 |                         2,
190 |                         v.shape[1],  # channel count of this residual input
191 |                         ct.RangeDim(lower_bound=int(v_height * 0.5), upper_bound=int(v_height * 2), default=v_height),
192 |                         ct.RangeDim(lower_bound=int(v_width * 0.5), upper_bound=int(v_width * 2), default=v_width)
193 |                     ))
194 | 
195 |         sample_unet_inputs = OrderedDict(sample_unet_inputs)
196 |         sample_unet_inputs_spec = {
197 |             k: (v.shape, v.dtype)
198 |             for k, v in sample_unet_inputs.items()
199 |         }
200 |         logger.info(f"Sample inputs spec: {sample_unet_inputs_spec}")
201 | 
202 |         # Initialize reference unet
203 |         reference_unet = unet.UNet2DConditionModel(**pipe.unet.config).eval()
204 |         load_state_dict_summary = reference_unet.load_state_dict(pipe.unet.state_dict())
205 | 
206 |         # Prepare inputs
207 |         baseline_sample_unet_inputs = deepcopy(sample_unet_inputs)
208 |         baseline_sample_unet_inputs[
209 |             "encoder_hidden_states"] = baseline_sample_unet_inputs[
210 |                 "encoder_hidden_states"].squeeze(2).transpose(1, 2)
211 | 
212 |         if not args.check_output_correctness:
213 |             del pipe.unet
214 |             gc.collect()
215 | 
216 |         # JIT trace
217 |         logger.info("JIT tracing...")
218 |         reference_unet = torch.jit.trace(reference_unet, example_kwarg_inputs=sample_unet_inputs)
219 |         logger.info("Done.")
220 | 
221 |         if args.check_output_correctness:
222 |             baseline_out = pipe.unet(
223 |                 sample=baseline_sample_unet_inputs["sample"],
224 |                 timestep=baseline_sample_unet_inputs["timestep"],
225 |                 encoder_hidden_states=baseline_sample_unet_inputs["encoder_hidden_states"],
226 |                 return_dict=False
227 |             )[0].detach().numpy()
228 |             reference_out = reference_unet(**sample_unet_inputs)[0].detach().numpy()
229 |             utils.report_correctness(baseline_out, reference_out, "unet baseline to reference PyTorch")
230 | 
231 |             del pipe.unet
232 |             gc.collect()
233 | 
234 |         coreml_sample_unet_inputs = {
235 |             k: v.numpy().astype(np.float16)
236 |             for k, v in sample_unet_inputs.items()
237 |         }
238 | 
239 |         sample_coreml_inputs = utils.get_coreml_inputs(coreml_sample_unet_inputs, multisize_inputs)
240 |         precision_full = args.precision_full
241 |         if not precision_full and pipe.scheduler.config.prediction_type == "v_prediction":
242 |             precision_full = True
243 |             logger.info("Full precision required: prediction_type == v_prediction")
244 |         coreml_unet, out_path = utils.convert_to_coreml(
245 |             "unet",
246 |             reference_unet,
247 |             sample_coreml_inputs,
248 |             ["noise_pred"],
249 |             precision_full,
250 |             args
251 |         )
252 |         del reference_unet
253 |         gc.collect()
254 | 
255 |         update_coreml_unet(pipe, coreml_unet, out_path, args)
256 |         logger.info(f"Saved unet into {out_path}")
257 | 
258 |         # Parity check PyTorch vs CoreML
259 |         if args.check_output_correctness:
260 |             coreml_out = list(coreml_unet.predict(coreml_sample_unet_inputs).values())[0]
261 |             utils.report_correctness(baseline_out, coreml_out, "unet baseline PyTorch to reference CoreML")
262 | 
263 |         del coreml_unet
264 |         gc.collect()
265 |     else:
266 |         del pipe.unet
267 |         gc.collect()
268 |         logger.info(f"`unet` already exists at {out_path}, skipping conversion.")
269 | 
270 |     if args.chunk_unet and not unet_chunks_exist:
271 |         logger.info("Chunking unet into two approximately equal MLModels")
272 |         args.mlpackage_path = out_path
273 |         args.remove_original = False
274 |         chunk_mlprogram.main(args)
275 | 
276 | 
277 | def update_coreml_unet(pipe, coreml_unet, out_path, args):
278 |     # make ControlNet/T2IAdapter inputs optional
279 |     coreml_spec = coreml_unet.get_spec()
280 |     for index, input_spec in enumerate(coreml_spec.description.input):
281 |         if "res_sample" in input_spec.name:
282 |             coreml_spec.description.input[index].type.isOptional = True
283 |     coreml_unet = ct.models.MLModel(coreml_spec, skip_model_load=True, weights_dir=coreml_unet.weights_dir)
284 | 
285 |     # Set model metadata
286 |     coreml_unet.author = f"Please refer to the Model Card available at huggingface.co/{args.model_version}"
287 |     coreml_unet.license = "OpenRAIL (https://huggingface.co/spaces/CompVis/stable-diffusion-license)"
288 |     coreml_unet.version = args.model_version
289 |     coreml_unet.short_description = \
290 |         "Stable Diffusion generates images conditioned on text or other images as input through the diffusion process. " \
291 |         "Please refer to https://arxiv.org/abs/2112.10752 for details."
292 | 
293 |     # Set the input descriptions
294 |     coreml_unet.input_description["sample"] = \
295 |         "The low resolution latent feature maps being denoised through reverse diffusion"
296 |     coreml_unet.input_description["timestep"] = \
297 |         "A value emitted by the associated scheduler object to condition the model on a given noise schedule"
298 |     coreml_unet.input_description["encoder_hidden_states"] = \
299 |         "Output embeddings from the associated text_encoder model to condition the generated image on text. " \
300 |         "A maximum of 77 tokens (~40 words) are allowed. Longer text is truncated. " \
301 |         "Shorter text does not reduce computation."
302 |     if hasattr(pipe, "text_encoder_2") and pipe.text_encoder_2:
303 |         coreml_unet.input_description["text_embeds"] = "Pooled output embeddings from the second text encoder (SDXL)"
304 |         coreml_unet.input_description["time_ids"] = "Size and crop conditioning values (SDXL)"
305 |     if args.t2i_adapter_support:
306 |         coreml_unet.input_description["adapter_res_samples_00"] = "Optional: Residual down sample from T2IAdapter"
307 |     if args.controlnet_support:
308 |         coreml_unet.input_description["down_block_res_samples_00"] = "Optional: Residual down sample from ControlNet"
309 |         coreml_unet.input_description["mid_block_res_sample"] = "Optional: Residual mid sample from ControlNet"
310 | 
311 |     # Set the output descriptions
312 |     coreml_unet.output_description["noise_pred"] = \
313 |         "Same shape and dtype as the `sample` input. 
" \ 314 | "The predicted noise to facilitate the reverse diffusion (denoising) process" 315 | 316 | # Set package version metadata 317 | coreml_unet.user_defined_metadata["identifier"] = args.model_version 318 | coreml_unet.user_defined_metadata["converter_version"] = __version__ 319 | coreml_unet.user_defined_metadata["attention_implementation"] = args.attention_implementation 320 | coreml_unet.user_defined_metadata["compute_unit"] = args.compute_unit 321 | coreml_unet.user_defined_metadata["prediction_type"] = pipe.scheduler.config.prediction_type 322 | coreml_unet.user_defined_metadata["hidden_size"] = str(args.hidden_size) 323 | if args.requires_aesthetics_score: 324 | coreml_unet.user_defined_metadata["requires_aesthetics_score"] = "true" 325 | 326 | coreml_unet.save(out_path) 327 | -------------------------------------------------------------------------------- /GuernikaModelConverter/Model/ConverterProcess.swift: -------------------------------------------------------------------------------- 1 | // 2 | // ConverterProcess.swift 3 | // GuernikaModelConverter 4 | // 5 | // Created by Guillermo Cique Fernández on 30/3/23. 6 | // 7 | 8 | import Combine 9 | import Foundation 10 | 11 | extension Process: Cancellable { 12 | public func cancel() { 13 | terminate() 14 | } 15 | } 16 | 17 | class ConverterProcess: ObservableObject { 18 | enum ArgumentError: Error { 19 | case noOutputLocation 20 | case noHuggingfaceIdentifier 21 | case noDiffusersLocation 22 | case noCheckpointLocation 23 | } 24 | enum Step: Int, Identifiable, CustomStringConvertible { 25 | case initialize = 0 26 | case loadEmbeddings 27 | case convertVaeEncoder 28 | case convertVaeDecoder 29 | case convertUnet 30 | case convertTextEncoder 31 | case convertSafetyChecker 32 | case convertControlNet 33 | case compressOutput 34 | case cleanUp 35 | case done 36 | 37 | var id: Int { rawValue } 38 | 39 | var description: String { 40 | switch self { 41 | case .initialize: 42 | return "Initialize" 43 | case .loadEmbeddings: 44 | return "Load embeddings" 45 | case .convertVaeEncoder: 46 | return "Convert Encoder" 47 | case .convertVaeDecoder: 48 | return "Convert Decoder" 49 | case .convertUnet: 50 | return "Convert Unet" 51 | case .convertTextEncoder: 52 | return "Convert Text encoder" 53 | case .convertSafetyChecker: 54 | return "Convert Safety checker" 55 | case .convertControlNet: 56 | return "Convert ControlNet" 57 | case .compressOutput: 58 | return "Compress" 59 | case .cleanUp: 60 | return "Clean up" 61 | case .done: 62 | return "Done" 63 | } 64 | } 65 | } 66 | struct StepProgress: CustomStringConvertible { 67 | let step: String 68 | let etaString: String 69 | var percentage: Float 70 | 71 | var description: String { "\(step).\(etaString)" } 72 | } 73 | 74 | let process: Process 75 | let steps: [Step] 76 | @Published var currentStep: Step = .initialize 77 | @Published var currentProgress: StepProgress? 
78 | var didComplete: Bool { 79 | currentStep == .done && !process.isRunning 80 | } 81 | 82 | init( 83 | outputLocation: URL?, 84 | modelOrigin: ModelOrigin, 85 | huggingfaceIdentifier: String, 86 | diffusersLocation: URL?, 87 | checkpointLocation: URL?, 88 | computeUnits: ComputeUnits, 89 | customWidth: Int?, 90 | customHeight: Int?, 91 | multisize: Bool, 92 | convertUnet: Bool, 93 | chunkUnet: Bool, 94 | controlNetSupport: Bool, 95 | convertTextEncoder: Bool, 96 | embeddingsLocation: URL?, 97 | convertVaeEncoder: Bool, 98 | convertVaeDecoder: Bool, 99 | convertSafetyChecker: Bool, 100 | precisionFull: Bool, 101 | loRAsToMerge: [LoRAInfo] = [], 102 | compression: Compression 103 | ) throws { 104 | guard let outputLocation else { 105 | throw ArgumentError.noOutputLocation 106 | } 107 | var steps: [Step] = [.initialize] 108 | var arguments: [String] = [ 109 | "-o", outputLocation.path(percentEncoded: false), 110 | "--bundle-resources-for-guernika", 111 | "--clean-up-mlpackages", 112 | "--compute-unit", "\(computeUnits.asCTComputeUnits)" 113 | ] 114 | 115 | switch modelOrigin { 116 | case .huggingface: 117 | guard !huggingfaceIdentifier.isEmpty else { 118 | throw ArgumentError.noHuggingfaceIdentifier 119 | } 120 | arguments.append("--model-version") 121 | arguments.append(huggingfaceIdentifier) 122 | arguments.append("--resources-dir-name") 123 | arguments.append(huggingfaceIdentifier) 124 | case .diffusers: 125 | guard let diffusersLocation else { 126 | throw ArgumentError.noDiffusersLocation 127 | } 128 | arguments.append("--model-location") 129 | arguments.append(diffusersLocation.path(percentEncoded: false)) 130 | arguments.append("--resources-dir-name") 131 | arguments.append(diffusersLocation.lastPathComponent) 132 | arguments.append("--model-version") 133 | arguments.append(diffusersLocation.lastPathComponent) 134 | case .checkpoint: 135 | guard let checkpointLocation else { 136 | throw ArgumentError.noCheckpointLocation 137 | } 138 | arguments.append("--model-checkpoint-location") 139 | arguments.append(checkpointLocation.path(percentEncoded: false)) 140 | arguments.append("--resources-dir-name") 141 | arguments.append(checkpointLocation.deletingPathExtension().lastPathComponent) 142 | arguments.append("--model-version") 143 | arguments.append(checkpointLocation.deletingPathExtension().lastPathComponent) 144 | } 145 | 146 | if computeUnits == .cpuAndGPU { 147 | arguments.append("--attention-implementation") 148 | arguments.append("ORIGINAL") 149 | } else { 150 | arguments.append("--attention-implementation") 151 | arguments.append("SPLIT_EINSUM_V2") 152 | } 153 | 154 | if let embeddingsLocation { 155 | steps.append(.loadEmbeddings) 156 | arguments.append("--embeddings-location") 157 | arguments.append(embeddingsLocation.path(percentEncoded: false)) 158 | } 159 | 160 | if convertVaeEncoder { 161 | steps.append(.convertVaeEncoder) 162 | arguments.append("--convert-vae-encoder") 163 | } 164 | if convertVaeDecoder { 165 | steps.append(.convertVaeDecoder) 166 | arguments.append("--convert-vae-decoder") 167 | } 168 | if convertUnet { 169 | steps.append(.convertUnet) 170 | arguments.append("--convert-unet") 171 | if chunkUnet { 172 | arguments.append("--chunk-unet") 173 | } 174 | if controlNetSupport { 175 | arguments.append("--controlnet-support") 176 | } 177 | } 178 | if convertTextEncoder { 179 | steps.append(.convertTextEncoder) 180 | arguments.append("--convert-text-encoder") 181 | } 182 | if convertSafetyChecker { 183 | steps.append(.convertSafetyChecker) 184 | 
arguments.append("--convert-safety-checker") 185 | } 186 | 187 | if !loRAsToMerge.isEmpty { 188 | arguments.append("--loras-to-merge") 189 | for loRA in loRAsToMerge { 190 | arguments.append(loRA.argument) 191 | } 192 | } 193 | 194 | if #available(macOS 14.0, *) { 195 | switch compression { 196 | case .quantizied6bit: 197 | arguments.append("--quantize-nbits") 198 | arguments.append("6") 199 | if convertUnet || convertTextEncoder { 200 | steps.append(.compressOutput) 201 | } 202 | case .quantizied8bit: 203 | arguments.append("--quantize-nbits") 204 | arguments.append("8") 205 | if convertUnet || convertTextEncoder { 206 | steps.append(.compressOutput) 207 | } 208 | case .fullSize: 209 | break 210 | } 211 | } 212 | 213 | if precisionFull { 214 | arguments.append("--precision-full") 215 | } 216 | 217 | if let customWidth { 218 | arguments.append("--output-w") 219 | arguments.append(String(describing: customWidth)) 220 | } 221 | if let customHeight { 222 | arguments.append("--output-h") 223 | arguments.append(String(describing: customHeight)) 224 | } 225 | if multisize { 226 | arguments.append("--multisize") 227 | } 228 | 229 | let process = Process() 230 | let pipe = Pipe() 231 | process.standardOutput = pipe 232 | process.standardError = pipe 233 | process.standardInput = nil 234 | 235 | process.executableURL = Bundle.main.url(forAuxiliaryExecutable: "GuernikaTools")! 236 | print("Arguments", arguments) 237 | process.arguments = arguments 238 | self.process = process 239 | steps.append(.cleanUp) 240 | self.steps = steps 241 | 242 | pipe.fileHandleForReading.readabilityHandler = { handle in 243 | let data = handle.availableData 244 | if data.count > 0 { 245 | if let newLine = String(data: data, encoding: .utf8) { 246 | self.handleOutput(newLine) 247 | } 248 | } 249 | } 250 | } 251 | 252 | init( 253 | outputLocation: URL?, 254 | controlNetOrigin: ModelOrigin, 255 | huggingfaceIdentifier: String, 256 | diffusersLocation: URL?, 257 | checkpointLocation: URL?, 258 | computeUnits: ComputeUnits, 259 | customWidth: Int?, 260 | customHeight: Int?, 261 | multisize: Bool, 262 | compression: Compression 263 | ) throws { 264 | guard let outputLocation else { 265 | throw ArgumentError.noOutputLocation 266 | } 267 | var steps: [Step] = [.initialize, .convertControlNet] 268 | var arguments: [String] = [ 269 | "-o", outputLocation.path(percentEncoded: false), 270 | "--bundle-resources-for-guernika", 271 | "--clean-up-mlpackages", 272 | "--compute-unit", "\(computeUnits.asCTComputeUnits)" 273 | ] 274 | 275 | switch controlNetOrigin { 276 | case .huggingface: 277 | guard !huggingfaceIdentifier.isEmpty else { 278 | throw ArgumentError.noHuggingfaceIdentifier 279 | } 280 | arguments.append("--controlnet-version") 281 | arguments.append(huggingfaceIdentifier) 282 | arguments.append("--resources-dir-name") 283 | arguments.append(huggingfaceIdentifier) 284 | case .diffusers: 285 | guard let diffusersLocation else { 286 | throw ArgumentError.noDiffusersLocation 287 | } 288 | arguments.append("--controlnet-location") 289 | arguments.append(diffusersLocation.path(percentEncoded: false)) 290 | arguments.append("--resources-dir-name") 291 | arguments.append(diffusersLocation.lastPathComponent) 292 | arguments.append("--controlnet-version") 293 | arguments.append(diffusersLocation.lastPathComponent) 294 | case .checkpoint: 295 | guard let checkpointLocation else { 296 | throw ArgumentError.noCheckpointLocation 297 | } 298 | arguments.append("--controlnet-checkpoint-location") 299 | 
arguments.append(checkpointLocation.path(percentEncoded: false)) 300 | arguments.append("--resources-dir-name") 301 | arguments.append(checkpointLocation.deletingPathExtension().lastPathComponent) 302 | arguments.append("--controlnet-version") 303 | arguments.append(checkpointLocation.deletingPathExtension().lastPathComponent) 304 | } 305 | 306 | if computeUnits == .cpuAndGPU { 307 | arguments.append("--attention-implementation") 308 | arguments.append("ORIGINAL") 309 | } else { 310 | arguments.append("--attention-implementation") 311 | arguments.append("SPLIT_EINSUM_V2") 312 | } 313 | 314 | if let customWidth { 315 | arguments.append("--output-w") 316 | arguments.append(String(describing: customWidth)) 317 | } 318 | if let customHeight { 319 | arguments.append("--output-h") 320 | arguments.append(String(describing: customHeight)) 321 | } 322 | if multisize { 323 | arguments.append("--multisize") 324 | } 325 | 326 | if #available(macOS 14.0, *) { 327 | switch compression { 328 | case .quantizied6bit: 329 | arguments.append("--quantize-nbits") 330 | arguments.append("6") 331 | steps.append(.compressOutput) 332 | case .quantizied8bit: 333 | arguments.append("--quantize-nbits") 334 | arguments.append("8") 335 | steps.append(.compressOutput) 336 | case .fullSize: 337 | break 338 | } 339 | } 340 | 341 | let process = Process() 342 | let pipe = Pipe() 343 | process.standardOutput = pipe 344 | process.standardError = pipe 345 | process.standardInput = nil 346 | 347 | process.executableURL = Bundle.main.url(forAuxiliaryExecutable: "GuernikaTools")! 348 | print("Arguments", arguments) 349 | process.arguments = arguments 350 | self.process = process 351 | steps.append(.cleanUp) 352 | self.steps = steps 353 | 354 | pipe.fileHandleForReading.readabilityHandler = { handle in 355 | let data = handle.availableData 356 | if data.count > 0 { 357 | if let newLine = String(data: data, encoding: .utf8) { 358 | self.handleOutput(newLine) 359 | } 360 | } 361 | } 362 | } 363 | 364 | private func handleOutput(_ newLine: String) { 365 | let isProgressStep = newLine.starts(with: "\r") 366 | let newLine = newLine.trimmingCharacters(in: .whitespacesAndNewlines) 367 | if isProgressStep { 368 | if 369 | let match = newLine.firstMatch(#": *?(\d*)%.*\|(.*)"#), 370 | let percentageRange = Range(match.range(at: 1), in: newLine), 371 | let percentage = Int(newLine[percentageRange]), 372 | let etaRange = Range(match.range(at: 2), in: newLine) 373 | { 374 | let etaString = String(newLine[etaRange]) 375 | if newLine.starts(with: "Running MIL") || newLine.starts(with: "Running compression pass") { 376 | currentProgress = StepProgress( 377 | step: String(newLine.split(separator: ":")[0]), 378 | etaString: etaString, 379 | percentage: Float(percentage) / 100 380 | ) 381 | } else { 382 | currentProgress = nil 383 | } 384 | } 385 | DispatchQueue.main.async { 386 | Logger.shared.append(newLine, level: .success) 387 | } 388 | return 389 | } 390 | 391 | if 392 | let match = newLine.firstMatch(#"INFO:.*Converting ([^\s]+)$"#), 393 | let moduleRange = Range(match.range(at: 1), in: newLine) 394 | { 395 | let module = String(newLine[moduleRange]) 396 | currentProgress = nil 397 | switch module { 398 | case "vae_encoder": 399 | currentStep = .convertVaeEncoder 400 | case "vae_decoder": 401 | currentStep = .convertVaeDecoder 402 | case "unet": 403 | currentStep = .convertUnet 404 | case "text_encoder": 405 | currentStep = .convertTextEncoder 406 | case "safety_checker": 407 | currentStep = .convertSafetyChecker 408 | case "controlnet": 409 
| currentStep = .convertControlNet 410 | default: 411 | break 412 | } 413 | } else if let _ = newLine.firstMatch(#"INFO:.*Loading embeddings"#) { 414 | currentProgress = nil 415 | currentStep = .loadEmbeddings 416 | } else if let _ = newLine.firstMatch(#"INFO:.*Quantizing weights"#) { 417 | currentProgress = nil 418 | currentStep = .compressOutput 419 | } else if let _ = newLine.firstMatch(#"INFO:.*Bundling resources for Guernika$"#) { 420 | currentProgress = nil 421 | currentStep = .cleanUp 422 | } else if let _ = newLine.firstMatch(#"INFO:.*MLPackages removed$"#) { 423 | currentProgress = nil 424 | currentStep = .done 425 | } else if newLine.hasPrefix("usage: GuernikaTools") || newLine.hasPrefix("GuernikaTools: error: unrecognized arguments") { 426 | print(newLine) 427 | return 428 | } else { 429 | currentProgress = nil 430 | } 431 | DispatchQueue.main.async { 432 | Logger.shared.append(newLine) 433 | } 434 | } 435 | 436 | func start() throws { 437 | guard !process.isRunning else { return } 438 | Logger.shared.append("Starting python converter", level: .info) 439 | try process.run() 440 | } 441 | 442 | func cancel() { 443 | process.cancel() 444 | } 445 | } 446 | 447 | extension String { 448 | func firstMatch(_ pattern: String) -> NSTextCheckingResult? { 449 | guard let regex = try? NSRegularExpression(pattern: pattern) else { return nil } 450 | return regex.firstMatch(in: self, options: [], range: NSRange(location: 0, length: utf16.count)) 451 | } 452 | } 453 | --------------------------------------------------------------------------------
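Usage sketch (illustrative, not part of the repository): a caller such as one of the view models above might drive ConverterProcess roughly as follows. All argument values are examples only; the initializer signature matches ConverterProcess.swift, and loRAsToMerge is omitted to use its empty default.

import Combine
import Foundation

// Construct a converter for a Hugging Face model (example identifier and path).
let converter = try ConverterProcess(
    outputLocation: URL(filePath: "/tmp/GuernikaModels"),        // example path
    modelOrigin: .huggingface,
    huggingfaceIdentifier: "runwayml/stable-diffusion-v1-5",     // example model
    diffusersLocation: nil,
    checkpointLocation: nil,
    computeUnits: .cpuAndGPU,
    customWidth: nil,
    customHeight: nil,
    multisize: false,
    convertUnet: true,
    chunkUnet: false,
    controlNetSupport: false,
    convertTextEncoder: true,
    embeddingsLocation: nil,
    convertVaeEncoder: true,
    convertVaeDecoder: true,
    convertSafetyChecker: false,
    precisionFull: false,
    compression: .fullSize
)

// Published properties are updated on the main queue, so they can feed SwiftUI directly.
var cancellables: Set<AnyCancellable> = []
converter.$currentStep
    .sink { step in print("Current step: \(step)") }
    .store(in: &cancellables)

try converter.start()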