├── .github └── pull_request_template.md ├── .gitignore ├── ACKNOWLEDGEMENTS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.md ├── Package.swift ├── README.md ├── assets ├── a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space │ ├── randomSeed_11_computeUnit_CPU_AND_GPU_modelVersion_runwayml_stable-diffusion-v1-5.png │ ├── randomSeed_11_computeUnit_CPU_AND_NE_modelVersion_runwayml_stable-diffusion-v1-5.png │ ├── randomSeed_123456789_computeUnit_ALL_modelVersion_CompVis_stable-diffusion-v1-4.png │ ├── randomSeed_123456789_computeUnit_CPU_AND_GPU_modelVersion_CompVis_stable-diffusion-v1-4.png │ ├── randomSeed_123456789_computeUnit_CPU_AND_NE_modelVersion_CompVis_stable-diffusion-v1-4.png │ ├── randomSeed_93_computeUnit_ALL_modelVersion_stabilityai_stable-diffusion-2-base.png │ ├── randomSeed_93_computeUnit_CPU_AND_GPU_modelVersion_stabilityai_stable-diffusion-2-base.png │ └── randomSeed_93_computeUnit_CPU_AND_NE_modelVersion_stabilityai_stable-diffusion-2-base.png ├── a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space │ ├── randomSeed_11_computeUnit_CPU_AND_GPU_modelVersion_stabilityai_stable-diffusion-2-base.png │ ├── randomSeed_11_computeUnit_CPU_AND_NE_modelVersion_stabilityai_stable-diffusion-2-base.png │ ├── randomSeed_13_computeUnit_ALL_modelVersion_CompVis_stable-diffusion-v1-4.png │ ├── randomSeed_13_computeUnit_CPU_AND_GPU_modelVersion_CompVis_stable-diffusion-v1-4.png │ ├── randomSeed_13_computeUnit_CPU_AND_NE_modelVersion_CompVis_stable-diffusion-v1-4.png │ ├── randomSeed_93_computeUnit_ALL_modelVersion_runwayml_stable-diffusion-v1-5.png │ ├── randomSeed_93_computeUnit_CPU_AND_GPU_modelVersion_runwayml_stable-diffusion-v1-5.png │ └── randomSeed_93_computeUnit_CPU_AND_NE_modelVersion_runwayml_stable-diffusion-v1-5.png ├── controlnet_readme_reel.png ├── float16_cpuandne_readmereel.png ├── float16_gpu_readmereel.png ├── mbp │ ├── a_high_quality_photo_of_a_surfing_dog.7667.final_3.41-bits.png │ ├── a_high_quality_photo_of_a_surfing_dog.7667.final_4.50-bits.png │ ├── a_high_quality_photo_of_a_surfing_dog.7667.final_6.55-bits.png │ ├── a_high_quality_photo_of_a_surfing_dog.7667.final_float16_original.png │ ├── runwayml_stable-diffusion-v1-5_psnr_vs_size.png │ ├── stabilityai_stable-diffusion-2-1-base_psnr_vs_size.png │ └── stabilityai_stable-diffusion-xl-base-1.0_psnr_vs_size.png ├── palette6_cpuandne_readmereel.png └── readme_reel.png ├── python_coreml_stable_diffusion ├── __init__.py ├── _version.py ├── activation_quantization.py ├── attention.py ├── chunk_mlprogram.py ├── controlnet.py ├── coreml_model.py ├── layer_norm.py ├── mixed_bit_compression_apply.py ├── mixed_bit_compression_pre_analysis.py ├── multilingual_projection.py ├── pipeline.py ├── torch2coreml.py └── unet.py ├── requirements.txt ├── setup.py ├── swift ├── StableDiffusion │ ├── pipeline │ │ ├── CGImage+vImage.swift │ │ ├── ControlNet.swift │ │ ├── DPMSolverMultistepScheduler.swift │ │ ├── Decoder.swift │ │ ├── DiscreteFlowScheduler.swift │ │ ├── Encoder.swift │ │ ├── ManagedMLModel.swift │ │ ├── MultiModalDiffusionTransformer.swift │ │ ├── MultilingualTextEncoder.swift │ │ ├── NumPyRandomSource.swift │ │ ├── NvRandomSource.swift │ │ ├── RandomSource.swift │ │ ├── ResourceManaging.swift │ │ ├── SafetyChecker.swift │ │ ├── SampleTimer.swift │ │ ├── Scheduler.swift │ │ ├── StableDiffusion3Pipeline+Resources.swift │ │ ├── StableDiffusion3Pipeline.swift │ │ ├── StableDiffusionPipeline+Resources.swift │ │ ├── StableDiffusionPipeline.Configuration.swift │ │ ├── StableDiffusionPipeline.swift │ │ 
├── StableDiffusionXL+Resources.swift │ │ ├── StableDiffusionXLPipeline.swift │ │ ├── TextEncoder.swift │ │ ├── TextEncoderT5.swift │ │ ├── TextEncoderXL.swift │ │ ├── TorchRandomSource.swift │ │ └── Unet.swift │ └── tokenizer │ │ ├── BPETokenizer+Reading.swift │ │ ├── BPETokenizer.swift │ │ └── T5Tokenizer.swift ├── StableDiffusionCLI │ └── main.swift └── StableDiffusionTests │ ├── Resources │ ├── merges.txt │ └── vocab.json │ └── StableDiffusionTests.swift └── tests ├── __init__.py └── test_stable_diffusion.py /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | Thank you for your interest in contributing to Core ML Stable Diffusion! Please review [CONTRIBUTING.md](../CONTRIBUTING.md) first. We appreciate your interest in the project! 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *~ 2 | 3 | # Swift Package 4 | .DS_Store 5 | /.build 6 | /Packages 7 | /*.xcodeproj 8 | .swiftpm 9 | .vscode 10 | .*.sw? 11 | *.docc-build 12 | *.vs 13 | Package.resolved 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | # macOS filesystem 146 | *.DS_Store 147 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thank you for your interest in contributing to Core ML Stable Diffusion! This project was released for system demonstration purposes and there are limited plans for future development of the repository. While we welcome new pull requests and issues please note that our response may be limited. 4 | 5 | 6 | ## Submitting a Pull Request 7 | 8 | The project is licensed under the MIT license. By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the MIT license. 9 | 10 | ## Code of Conduct 11 | 12 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 13 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Apple Inc. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 6 | 7 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 8 | 9 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
10 | -------------------------------------------------------------------------------- /Package.swift: -------------------------------------------------------------------------------- 1 | // swift-tools-version: 5.8 2 | // The swift-tools-version declares the minimum version of Swift required to build this package. 3 | 4 | import PackageDescription 5 | 6 | let package = Package( 7 | name: "stable-diffusion", 8 | platforms: [ 9 | .macOS(.v13), 10 | .iOS(.v16), 11 | ], 12 | products: [ 13 | .library( 14 | name: "StableDiffusion", 15 | targets: ["StableDiffusion"]), 16 | .executable( 17 | name: "StableDiffusionSample", 18 | targets: ["StableDiffusionCLI"]) 19 | ], 20 | dependencies: [ 21 | .package(url: "https://github.com/apple/swift-argument-parser.git", from: "1.2.3"), 22 | .package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.8"), 23 | ], 24 | targets: [ 25 | .target( 26 | name: "StableDiffusion", 27 | dependencies: [ 28 | .product(name: "Transformers", package: "swift-transformers"), 29 | ], 30 | path: "swift/StableDiffusion"), 31 | .executableTarget( 32 | name: "StableDiffusionCLI", 33 | dependencies: [ 34 | "StableDiffusion", 35 | .product(name: "ArgumentParser", package: "swift-argument-parser")], 36 | path: "swift/StableDiffusionCLI"), 37 | .testTarget( 38 | name: "StableDiffusionTests", 39 | dependencies: ["StableDiffusion"], 40 | path: "swift/StableDiffusionTests", 41 | resources: [ 42 | .copy("Resources/vocab.json"), 43 | .copy("Resources/merges.txt") 44 | ]), 45 | ] 46 | ) 47 | -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_11_computeUnit_CPU_AND_GPU_modelVersion_runwayml_stable-diffusion-v1-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_11_computeUnit_CPU_AND_GPU_modelVersion_runwayml_stable-diffusion-v1-5.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_11_computeUnit_CPU_AND_NE_modelVersion_runwayml_stable-diffusion-v1-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_11_computeUnit_CPU_AND_NE_modelVersion_runwayml_stable-diffusion-v1-5.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_123456789_computeUnit_ALL_modelVersion_CompVis_stable-diffusion-v1-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_123456789_computeUnit_ALL_modelVersion_CompVis_stable-diffusion-v1-4.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_123456789_computeUnit_CPU_AND_GPU_modelVersion_CompVis_stable-diffusion-v1-4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_123456789_computeUnit_CPU_AND_GPU_modelVersion_CompVis_stable-diffusion-v1-4.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_123456789_computeUnit_CPU_AND_NE_modelVersion_CompVis_stable-diffusion-v1-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_123456789_computeUnit_CPU_AND_NE_modelVersion_CompVis_stable-diffusion-v1-4.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_93_computeUnit_ALL_modelVersion_stabilityai_stable-diffusion-2-base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_93_computeUnit_ALL_modelVersion_stabilityai_stable-diffusion-2-base.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_93_computeUnit_CPU_AND_GPU_modelVersion_stabilityai_stable-diffusion-2-base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_93_computeUnit_CPU_AND_GPU_modelVersion_stabilityai_stable-diffusion-2-base.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_93_computeUnit_CPU_AND_NE_modelVersion_stabilityai_stable-diffusion-2-base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_dragon_in_space/randomSeed_93_computeUnit_CPU_AND_NE_modelVersion_stabilityai_stable-diffusion-2-base.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_11_computeUnit_CPU_AND_GPU_modelVersion_stabilityai_stable-diffusion-2-base.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_11_computeUnit_CPU_AND_GPU_modelVersion_stabilityai_stable-diffusion-2-base.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_11_computeUnit_CPU_AND_NE_modelVersion_stabilityai_stable-diffusion-2-base.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_11_computeUnit_CPU_AND_NE_modelVersion_stabilityai_stable-diffusion-2-base.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_13_computeUnit_ALL_modelVersion_CompVis_stable-diffusion-v1-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_13_computeUnit_ALL_modelVersion_CompVis_stable-diffusion-v1-4.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_13_computeUnit_CPU_AND_GPU_modelVersion_CompVis_stable-diffusion-v1-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_13_computeUnit_CPU_AND_GPU_modelVersion_CompVis_stable-diffusion-v1-4.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_13_computeUnit_CPU_AND_NE_modelVersion_CompVis_stable-diffusion-v1-4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_13_computeUnit_CPU_AND_NE_modelVersion_CompVis_stable-diffusion-v1-4.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_ALL_modelVersion_runwayml_stable-diffusion-v1-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_ALL_modelVersion_runwayml_stable-diffusion-v1-5.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_CPU_AND_GPU_modelVersion_runwayml_stable-diffusion-v1-5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_CPU_AND_GPU_modelVersion_runwayml_stable-diffusion-v1-5.png -------------------------------------------------------------------------------- /assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_CPU_AND_NE_modelVersion_runwayml_stable-diffusion-v1-5.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/a_high_quality_photo_of_an_astronaut_riding_a_horse_in_space/randomSeed_93_computeUnit_CPU_AND_NE_modelVersion_runwayml_stable-diffusion-v1-5.png -------------------------------------------------------------------------------- /assets/controlnet_readme_reel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/controlnet_readme_reel.png -------------------------------------------------------------------------------- /assets/float16_cpuandne_readmereel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/float16_cpuandne_readmereel.png -------------------------------------------------------------------------------- /assets/float16_gpu_readmereel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/float16_gpu_readmereel.png -------------------------------------------------------------------------------- /assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_3.41-bits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_3.41-bits.png -------------------------------------------------------------------------------- /assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_4.50-bits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_4.50-bits.png -------------------------------------------------------------------------------- /assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_6.55-bits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_6.55-bits.png -------------------------------------------------------------------------------- /assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_float16_original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_float16_original.png -------------------------------------------------------------------------------- /assets/mbp/runwayml_stable-diffusion-v1-5_psnr_vs_size.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/mbp/runwayml_stable-diffusion-v1-5_psnr_vs_size.png -------------------------------------------------------------------------------- /assets/mbp/stabilityai_stable-diffusion-2-1-base_psnr_vs_size.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/mbp/stabilityai_stable-diffusion-2-1-base_psnr_vs_size.png -------------------------------------------------------------------------------- /assets/mbp/stabilityai_stable-diffusion-xl-base-1.0_psnr_vs_size.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/mbp/stabilityai_stable-diffusion-xl-base-1.0_psnr_vs_size.png -------------------------------------------------------------------------------- /assets/palette6_cpuandne_readmereel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/palette6_cpuandne_readmereel.png -------------------------------------------------------------------------------- /assets/readme_reel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-stable-diffusion/e5d960c41a6a4ab200b8db379194127607b1c590/assets/readme_reel.png -------------------------------------------------------------------------------- /python_coreml_stable_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from ._version import __version__ 2 | -------------------------------------------------------------------------------- /python_coreml_stable_diffusion/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.0" 2 | -------------------------------------------------------------------------------- /python_coreml_stable_diffusion/attention.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logger = logging.getLogger(__name__) 4 | logger.setLevel(logging.INFO) 5 | 6 | import torch 7 | import math 8 | 9 | SPLIT_SOFTMAX = False 10 | 11 | def softmax(x, dim): 12 | # Reduction max 13 | max_x = x.max(dim=dim, keepdim=True).values 14 | # EW sub 15 | x -= max_x 16 | # Scale for EXP to EXP2, Activation EXP2 17 | scaled_x = x * (1 / math.log(2)) 18 | exp_act = torch.exp2(scaled_x) 19 | # Reduction Sum + Inv 20 | exp_sum_inv = 1 / exp_act.sum(dim=dim, keepdims=True) 21 | # EW Mult 22 | return exp_act * exp_sum_inv 23 | 24 | def split_einsum(q, k, v, mask, heads, dim_head): 25 | """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM 26 | 27 | - Implements https://machinelearning.apple.com/research/neural-engine-transformers 28 | - Recommended for ANE 29 | - Marginally slower on GPU 30 | """ 31 | mh_q = [ 32 | q[:, head_idx * dim_head:(head_idx + 1) * 33 | dim_head, :, :] for head_idx in range(heads) 34 | ] # (bs, dim_head, 1, max_seq_length) * heads 35 | 36 | k = k.transpose(1, 3) 37 | mh_k = [ 38 | k[:, :, :, 39 | head_idx * dim_head:(head_idx + 1) * dim_head] 40 | for head_idx in range(heads) 41 | ] # (bs, max_seq_length, 1, dim_head) * heads 42 | 43 | mh_v = [ 44 | v[:, head_idx * dim_head:(head_idx + 1) * 45 | dim_head, :, :] for head_idx in range(heads) 46 | ] # (bs, dim_head, 1, max_seq_length) * heads 47 | 48 | attn_weights = [ 49 | torch.einsum("bchq,bkhc->bkhq", [qi, ki]) * (dim_head**-0.5) 50 | for qi, ki in zip(mh_q, mh_k) 51 | ] # (bs, max_seq_length, 1, max_seq_length) * heads 52 | 53 | if mask is not None: 54 | for head_idx in range(heads): 55 
| attn_weights[head_idx] = attn_weights[head_idx] + mask 56 | 57 | if SPLIT_SOFTMAX: 58 | attn_weights = [ 59 | softmax(aw, dim=1) for aw in attn_weights 60 | ] # (bs, max_seq_length, 1, max_seq_length) * heads 61 | else: 62 | attn_weights = [ 63 | aw.softmax(dim=1) for aw in attn_weights 64 | ] # (bs, max_seq_length, 1, max_seq_length) * heads 65 | 66 | attn = [ 67 | torch.einsum("bkhq,bchk->bchq", wi, vi) 68 | for wi, vi in zip(attn_weights, mh_v) 69 | ] # (bs, dim_head, 1, max_seq_length) * heads 70 | 71 | attn = torch.cat(attn, dim=1) # (bs, dim, 1, max_seq_length) 72 | return attn 73 | 74 | 75 | CHUNK_SIZE = 512 76 | 77 | def split_einsum_v2(q, k, v, mask, heads, dim_head): 78 | """ Attention Implementation backing AttentionImplementations.SPLIT_EINSUM_V2 79 | 80 | - Implements https://machinelearning.apple.com/research/neural-engine-transformers 81 | - Recommended for ANE 82 | - Marginally slower on GPU 83 | - Chunks the query sequence to avoid large intermediate tensors and improves ANE performance 84 | """ 85 | query_seq_length = q.size(3) 86 | num_chunks = query_seq_length // CHUNK_SIZE 87 | 88 | if num_chunks == 0: 89 | logger.info( 90 | "AttentionImplementations.SPLIT_EINSUM_V2: query sequence too short to chunk " 91 | f"({query_seq_length}<{CHUNK_SIZE}), fall back to AttentionImplementations.SPLIT_EINSUM (safe to ignore)") 92 | return split_einsum(q, k, v, mask, heads, dim_head) 93 | 94 | logger.info( 95 | "AttentionImplementations.SPLIT_EINSUM_V2: Splitting query sequence length of " 96 | f"{query_seq_length} into {num_chunks} chunks") 97 | 98 | mh_q = [ 99 | q[:, head_idx * dim_head:(head_idx + 1) * 100 | dim_head, :, :] for head_idx in range(heads) 101 | ] # (bs, dim_head, 1, max_seq_length) * heads 102 | 103 | # Chunk the query sequence for each head 104 | mh_q_chunked = [ 105 | [h_q[..., chunk_idx * CHUNK_SIZE:(chunk_idx + 1) * CHUNK_SIZE] for chunk_idx in range(num_chunks)] 106 | for h_q in mh_q 107 | ] # ((bs, dim_head, 1, QUERY_SEQ_CHUNK_SIZE) * num_chunks) * heads 108 | 109 | k = k.transpose(1, 3) 110 | mh_k = [ 111 | k[:, :, :, 112 | head_idx * dim_head:(head_idx + 1) * dim_head] 113 | for head_idx in range(heads) 114 | ] # (bs, max_seq_length, 1, dim_head) * heads 115 | 116 | mh_v = [ 117 | v[:, head_idx * dim_head:(head_idx + 1) * 118 | dim_head, :, :] for head_idx in range(heads) 119 | ] # (bs, dim_head, 1, max_seq_length) * heads 120 | 121 | attn_weights = [ 122 | [ 123 | torch.einsum("bchq,bkhc->bkhq", [qi_chunk, ki]) * (dim_head**-0.5) 124 | for qi_chunk in h_q_chunked 125 | ] for h_q_chunked, ki in zip(mh_q_chunked, mh_k) 126 | ] # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads 127 | 128 | attn_weights = [ 129 | [aw_chunk.softmax(dim=1) for aw_chunk in aw_chunked] 130 | for aw_chunked in attn_weights 131 | ] # ((bs, max_seq_length, 1, chunk_size) * num_chunks) * heads 132 | 133 | attn = [ 134 | [ 135 | torch.einsum("bkhq,bchk->bchq", wi_chunk, vi) 136 | for wi_chunk in wi_chunked 137 | ] for wi_chunked, vi in zip(attn_weights, mh_v) 138 | ] # ((bs, dim_head, 1, chunk_size) * num_chunks) * heads 139 | 140 | attn = torch.cat([ 141 | torch.cat(attn_chunked, dim=3) for attn_chunked in attn 142 | ], dim=1) # (bs, dim, 1, max_seq_length) 143 | 144 | return attn 145 | 146 | 147 | def original(q, k, v, mask, heads, dim_head): 148 | """ Attention Implementation backing AttentionImplementations.ORIGINAL 149 | 150 | - Not recommended for ANE 151 | - Recommended for GPU 152 | """ 153 | bs = q.size(0) 154 | mh_q = q.view(bs, heads, dim_head, -1) 155 | mh_k = 
k.view(bs, heads, dim_head, -1) 156 | mh_v = v.view(bs, heads, dim_head, -1) 157 | 158 | attn_weights = torch.einsum("bhcq,bhck->bhqk", [mh_q, mh_k]) 159 | attn_weights.mul_(dim_head**-0.5) 160 | 161 | if mask is not None: 162 | attn_weights = attn_weights + mask 163 | 164 | attn_weights = attn_weights.softmax(dim=3) 165 | 166 | attn = torch.einsum("bhqk,bhck->bhcq", [attn_weights, mh_v]) 167 | attn = attn.contiguous().view(bs, heads * dim_head, 1, -1) 168 | return attn 169 | -------------------------------------------------------------------------------- /python_coreml_stable_diffusion/controlnet.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | from diffusers.configuration_utils import ConfigMixin, register_to_config 7 | from diffusers import ModelMixin 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .unet import Timesteps, TimestepEmbedding, get_down_block, UNetMidBlock2DCrossAttn, linear_to_conv2d_map 14 | 15 | class ControlNetConditioningEmbedding(nn.Module): 16 | 17 | def __init__( 18 | self, 19 | conditioning_embedding_channels, 20 | conditioning_channels=3, 21 | block_out_channels=(16, 32, 96, 256), 22 | ): 23 | super().__init__() 24 | 25 | self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) 26 | 27 | self.blocks = nn.ModuleList([]) 28 | 29 | for i in range(len(block_out_channels) - 1): 30 | channel_in = block_out_channels[i] 31 | channel_out = block_out_channels[i + 1] 32 | self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) 33 | self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) 34 | 35 | self.conv_out = nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) 36 | 37 | def forward(self, conditioning): 38 | embedding = self.conv_in(conditioning) 39 | embedding = F.silu(embedding) 40 | 41 | for block in self.blocks: 42 | embedding = block(embedding) 43 | embedding = F.silu(embedding) 44 | 45 | embedding = self.conv_out(embedding) 46 | 47 | return embedding 48 | 49 | class ControlNetModel(ModelMixin, ConfigMixin): 50 | 51 | @register_to_config 52 | def __init__( 53 | self, 54 | in_channels=4, 55 | flip_sin_to_cos=True, 56 | freq_shift=0, 57 | down_block_types=( 58 | "CrossAttnDownBlock2D", 59 | "CrossAttnDownBlock2D", 60 | "CrossAttnDownBlock2D", 61 | "DownBlock2D", 62 | ), 63 | only_cross_attention=False, 64 | block_out_channels=(320, 640, 1280, 1280), 65 | layers_per_block=2, 66 | downsample_padding=1, 67 | mid_block_scale_factor=1, 68 | act_fn="silu", 69 | norm_num_groups=32, 70 | norm_eps=1e-5, 71 | cross_attention_dim=1280, 72 | transformer_layers_per_block=1, 73 | attention_head_dim=8, 74 | use_linear_projection=False, 75 | upcast_attention=False, 76 | resnet_time_scale_shift="default", 77 | conditioning_embedding_out_channels=(16, 32, 96, 256), 78 | **kwargs, 79 | ): 80 | super().__init__() 81 | 82 | # Check inputs 83 | if len(block_out_channels) != len(down_block_types): 84 | raise ValueError( 85 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 
86 | ) 87 | 88 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 89 | raise ValueError( 90 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 91 | ) 92 | 93 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): 94 | raise ValueError( 95 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 96 | ) 97 | 98 | self._register_load_state_dict_pre_hook(linear_to_conv2d_map) 99 | 100 | # input 101 | conv_in_kernel = 3 102 | conv_in_padding = (conv_in_kernel - 1) // 2 103 | self.conv_in = nn.Conv2d( 104 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 105 | ) 106 | 107 | # time 108 | time_embed_dim = block_out_channels[0] * 4 109 | 110 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 111 | timestep_input_dim = block_out_channels[0] 112 | 113 | self.time_embedding = TimestepEmbedding( 114 | timestep_input_dim, 115 | time_embed_dim, 116 | ) 117 | 118 | # control net conditioning embedding 119 | self.controlnet_cond_embedding = ControlNetConditioningEmbedding( 120 | conditioning_embedding_channels=block_out_channels[0], 121 | block_out_channels=conditioning_embedding_out_channels, 122 | ) 123 | 124 | self.down_blocks = nn.ModuleList([]) 125 | self.controlnet_down_blocks = nn.ModuleList([]) 126 | 127 | if isinstance(only_cross_attention, bool): 128 | only_cross_attention = [only_cross_attention] * len(down_block_types) 129 | 130 | if isinstance(attention_head_dim, int): 131 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 132 | 133 | if isinstance(transformer_layers_per_block, int): 134 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 135 | 136 | # down 137 | output_channel = block_out_channels[0] 138 | 139 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 140 | self.controlnet_down_blocks.append(controlnet_block) 141 | 142 | for i, down_block_type in enumerate(down_block_types): 143 | input_channel = output_channel 144 | output_channel = block_out_channels[i] 145 | is_final_block = i == len(block_out_channels) - 1 146 | 147 | down_block = get_down_block( 148 | down_block_type, 149 | transformer_layers_per_block=transformer_layers_per_block[i], 150 | num_layers=layers_per_block, 151 | in_channels=input_channel, 152 | out_channels=output_channel, 153 | temb_channels=time_embed_dim, 154 | resnet_eps=norm_eps, 155 | resnet_act_fn=act_fn, 156 | cross_attention_dim=cross_attention_dim, 157 | attn_num_head_channels=attention_head_dim[i], 158 | downsample_padding=downsample_padding, 159 | add_downsample=not is_final_block, 160 | ) 161 | self.down_blocks.append(down_block) 162 | 163 | for _ in range(layers_per_block): 164 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 165 | self.controlnet_down_blocks.append(controlnet_block) 166 | 167 | if not is_final_block: 168 | controlnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) 169 | self.controlnet_down_blocks.append(controlnet_block) 170 | 171 | # mid 172 | mid_block_channel = block_out_channels[-1] 173 | 174 | controlnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) 175 | self.controlnet_mid_block = 
controlnet_block 176 | 177 | self.mid_block = UNetMidBlock2DCrossAttn( 178 | in_channels=mid_block_channel, 179 | temb_channels=time_embed_dim, 180 | resnet_eps=norm_eps, 181 | resnet_act_fn=act_fn, 182 | output_scale_factor=mid_block_scale_factor, 183 | resnet_time_scale_shift=resnet_time_scale_shift, 184 | cross_attention_dim=cross_attention_dim, 185 | attn_num_head_channels=attention_head_dim[-1], 186 | resnet_groups=norm_num_groups, 187 | use_linear_projection=use_linear_projection, 188 | upcast_attention=upcast_attention, 189 | ) 190 | 191 | def get_num_residuals(self): 192 | num_res = 2 # initial sample + mid block 193 | for down_block in self.down_blocks: 194 | num_res += len(down_block.resnets) 195 | if hasattr(down_block, "downsamplers") and down_block.downsamplers is not None: 196 | num_res += len(down_block.downsamplers) 197 | return num_res 198 | 199 | def forward( 200 | self, 201 | sample, 202 | timestep, 203 | encoder_hidden_states, 204 | controlnet_cond, 205 | ): 206 | # 1. time 207 | t_emb = self.time_proj(timestep) 208 | emb = self.time_embedding(t_emb) 209 | 210 | # 2. pre-process 211 | sample = self.conv_in(sample) 212 | 213 | controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) 214 | 215 | sample += controlnet_cond 216 | 217 | # 3. down 218 | down_block_res_samples = (sample,) 219 | for downsample_block in self.down_blocks: 220 | if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: 221 | sample, res_samples = downsample_block( 222 | hidden_states=sample, 223 | temb=emb, 224 | encoder_hidden_states=encoder_hidden_states, 225 | ) 226 | else: 227 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 228 | 229 | down_block_res_samples += res_samples 230 | 231 | # 4. mid 232 | if self.mid_block is not None: 233 | sample = self.mid_block( 234 | sample, 235 | emb, 236 | encoder_hidden_states=encoder_hidden_states, 237 | ) 238 | 239 | # 5. Control net blocks 240 | controlnet_down_block_res_samples = () 241 | 242 | for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks): 243 | down_block_res_sample = controlnet_block(down_block_res_sample) 244 | controlnet_down_block_res_samples += (down_block_res_sample,) 245 | 246 | down_block_res_samples = controlnet_down_block_res_samples 247 | 248 | mid_block_res_sample = self.controlnet_mid_block(sample) 249 | 250 | return down_block_res_samples, mid_block_res_sample -------------------------------------------------------------------------------- /python_coreml_stable_diffusion/coreml_model.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import coremltools as ct 7 | 8 | import logging 9 | import json 10 | 11 | logging.basicConfig() 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | 15 | import numpy as np 16 | 17 | import os 18 | import time 19 | import subprocess 20 | import sys 21 | 22 | 23 | def _macos_version(): 24 | """ 25 | Returns macOS version as a tuple of integers. On non-Macs, returns an empty tuple. 
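For example, macOS 14.4.1 (reported by `sw_vers -productVersion` as "14.4.1") is parsed and returned as (14, 4, 1).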
26 | """ 27 | if sys.platform == "darwin": 28 | try: 29 | ver_str = subprocess.run(["sw_vers", "-productVersion"], stdout=subprocess.PIPE).stdout.decode('utf-8').strip('\n') 30 | return tuple([int(v) for v in ver_str.split(".")]) 31 | except: 32 | raise Exception("Unable to determine the macOS version") 33 | return () 34 | 35 | 36 | class CoreMLModel: 37 | """ Wrapper for running CoreML models using coremltools 38 | """ 39 | 40 | def __init__(self, model_path, compute_unit, sources='packages', optimization_hints=None): 41 | 42 | logger.info(f"Loading {model_path}") 43 | 44 | start = time.time() 45 | if sources == 'packages': 46 | assert os.path.exists(model_path) and model_path.endswith(".mlpackage") 47 | 48 | self.model = ct.models.MLModel( 49 | model_path, 50 | compute_units=ct.ComputeUnit[compute_unit], 51 | optimization_hints=optimization_hints, 52 | ) 53 | DTYPE_MAP = { 54 | 65552: np.float16, 55 | 65568: np.float32, 56 | 131104: np.int32, 57 | } 58 | self.expected_inputs = { 59 | input_tensor.name: { 60 | "shape": tuple(input_tensor.type.multiArrayType.shape), 61 | "dtype": DTYPE_MAP[input_tensor.type.multiArrayType.dataType], 62 | } 63 | for input_tensor in self.model._spec.description.input 64 | } 65 | elif sources == 'compiled': 66 | assert os.path.exists(model_path) and model_path.endswith(".mlmodelc") 67 | 68 | self.model = ct.models.CompiledMLModel( 69 | model_path, 70 | compute_units=ct.ComputeUnit[compute_unit], 71 | optimization_hints=optimization_hints, 72 | ) 73 | 74 | # Grab expected inputs from metadata.json 75 | with open(os.path.join(model_path, 'metadata.json'), 'r') as f: 76 | config = json.load(f)[0] 77 | 78 | self.expected_inputs = { 79 | input_tensor['name']: { 80 | "shape": tuple(eval(input_tensor['shape'])), 81 | "dtype": np.dtype(input_tensor['dataType'].lower()), 82 | } 83 | for input_tensor in config['inputSchema'] 84 | } 85 | else: 86 | raise ValueError(f'Expected `packages` or `compiled` for sources, received {sources}') 87 | 88 | load_time = time.time() - start 89 | logger.info(f"Done. Took {load_time:.1f} seconds.") 90 | 91 | if load_time > LOAD_TIME_INFO_MSG_TRIGGER: 92 | logger.info( 93 | "Loading a CoreML model through coremltools triggers compilation every time. " 94 | "The Swift package we provide uses precompiled Core ML models (.mlmodelc) to avoid compile-on-load." 95 | ) 96 | 97 | def _verify_inputs(self, **kwargs): 98 | for k, v in kwargs.items(): 99 | if k in self.expected_inputs: 100 | if not isinstance(v, np.ndarray): 101 | raise TypeError( 102 | f"Expected numpy.ndarray, got {v} for input: {k}") 103 | 104 | expected_dtype = self.expected_inputs[k]["dtype"] 105 | if not v.dtype == expected_dtype: 106 | raise TypeError( 107 | f"Expected dtype {expected_dtype}, got {v.dtype} for input: {k}" 108 | ) 109 | 110 | expected_shape = self.expected_inputs[k]["shape"] 111 | if not v.shape == expected_shape: 112 | raise TypeError( 113 | f"Expected shape {expected_shape}, got {v.shape} for input: {k}" 114 | ) 115 | else: 116 | raise ValueError(f"Received unexpected input kwarg: {k}") 117 | 118 | def __call__(self, **kwargs): 119 | self._verify_inputs(**kwargs) 120 | return self.model.predict(kwargs) 121 | 122 | 123 | LOAD_TIME_INFO_MSG_TRIGGER = 10 # seconds 124 | 125 | 126 | def get_resource_type(resources_dir: str) -> str: 127 | """ 128 | Detect resource type based on filepath extensions. 
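Scans the immediate subdirectories of `resources_dir` and infers from their extensions whether CoreMLModel should load .mlpackage bundles (via ct.models.MLModel) or precompiled .mlmodelc bundles (via ct.models.CompiledMLModel).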
129 | returns: 130 | `packages`: for .mlpackage resources 131 | 'compiled`: for .mlmodelc resources 132 | """ 133 | directories = [f for f in os.listdir(resources_dir) if os.path.isdir(os.path.join(resources_dir, f))] 134 | 135 | # consider directories ending with extension 136 | extensions = set([os.path.splitext(e)[1] for e in directories if os.path.splitext(e)[1]]) 137 | 138 | # if one extension present we may be able to infer sources type 139 | if len(set(extensions)) == 1: 140 | extension = extensions.pop() 141 | else: 142 | raise ValueError(f'Multiple file extensions found at {resources_dir}.' 143 | f'Cannot infer resource type from contents.') 144 | 145 | if extension == '.mlpackage': 146 | sources = 'packages' 147 | elif extension == '.mlmodelc': 148 | sources = 'compiled' 149 | else: 150 | raise ValueError(f'Did not find .mlpackage or .mlmodelc at {resources_dir}') 151 | 152 | return sources 153 | 154 | 155 | def _load_mlpackage(submodule_name, 156 | mlpackages_dir, 157 | model_version, 158 | compute_unit, 159 | sources=None): 160 | """ 161 | Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py) 162 | 163 | """ 164 | 165 | # if sources not provided, attempt to infer `packages` or `compiled` from the 166 | # resources directory 167 | if sources is None: 168 | sources = get_resource_type(mlpackages_dir) 169 | 170 | if sources == 'packages': 171 | logger.info(f"Loading {submodule_name} mlpackage") 172 | fname = f"Stable_Diffusion_version_{model_version}_{submodule_name}.mlpackage".replace( 173 | "/", "_") 174 | mlpackage_path = os.path.join(mlpackages_dir, fname) 175 | 176 | if not os.path.exists(mlpackage_path): 177 | raise FileNotFoundError( 178 | f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}") 179 | 180 | elif sources == 'compiled': 181 | logger.info(f"Loading {submodule_name} mlmodelc") 182 | 183 | # FixMe: Submodule names and compiled resources names differ. Can change if names match in the future. 184 | submodule_names = ["text_encoder", "text_encoder_2", "unet", "vae_decoder", "vae_encoder", "safety_checker"] 185 | compiled_names = ['TextEncoder', 'TextEncoder2', 'Unet', 'VAEDecoder', 'VAEEncoder', 'SafetyChecker'] 186 | name_map = dict(zip(submodule_names, compiled_names)) 187 | 188 | cname = name_map[submodule_name] + '.mlmodelc' 189 | mlpackage_path = os.path.join(mlpackages_dir, cname) 190 | 191 | if not os.path.exists(mlpackage_path): 192 | raise FileNotFoundError( 193 | f"{submodule_name} CoreML model doesn't exist at {mlpackage_path}") 194 | 195 | # On macOS 15+, set fast prediction optimization hint for the unet. 
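# SpecializationStrategy.FastPrediction trades longer one-time model specialization at
# load time for lower per-call latency; that pays off for the unet, which is invoked
# once per denoising step, while the other submodules keep the default strategy.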
196 | optimization_hints = None 197 | if submodule_name == "unet" and _macos_version() >= (15, 0): 198 | optimization_hints = {"specializationStrategy": ct.SpecializationStrategy.FastPrediction} 199 | 200 | return CoreMLModel(mlpackage_path, 201 | compute_unit, 202 | sources=sources, 203 | optimization_hints=optimization_hints) 204 | 205 | 206 | def _load_mlpackage_controlnet(mlpackages_dir, model_version, compute_unit): 207 | """ Load Core ML (mlpackage) models from disk (As exported by torch2coreml.py) 208 | """ 209 | model_name = model_version.replace("/", "_") 210 | 211 | logger.info(f"Loading controlnet_{model_name} mlpackage") 212 | 213 | fname = f"ControlNet_{model_name}.mlpackage" 214 | 215 | mlpackage_path = os.path.join(mlpackages_dir, fname) 216 | 217 | if not os.path.exists(mlpackage_path): 218 | raise FileNotFoundError( 219 | f"controlnet_{model_name} CoreML model doesn't exist at {mlpackage_path}") 220 | 221 | return CoreMLModel(mlpackage_path, compute_unit) 222 | 223 | 224 | def get_available_compute_units(): 225 | return tuple(cu for cu in ct.ComputeUnit._member_names_) 226 | -------------------------------------------------------------------------------- /python_coreml_stable_diffusion/layer_norm.py: -------------------------------------------------------------------------------- 1 | # 2 | # For licensing see accompanying LICENSE.md file. 3 | # Copyright (C) 2022 Apple Inc. All Rights Reserved. 4 | # 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | # Reference: https://github.com/apple/ml-ane-transformers/blob/main/ane_transformers/reference/layer_norm.py 11 | class LayerNormANE(nn.Module): 12 | """ LayerNorm optimized for Apple Neural Engine (ANE) execution 13 | 14 | Note: This layer only supports normalization over the final dim. It expects `num_channels` 15 | as an argument and not `normalized_shape` which is used by `torch.nn.LayerNorm`. 16 | """ 17 | 18 | def __init__(self, 19 | num_channels, 20 | clip_mag=None, 21 | eps=1e-5, 22 | elementwise_affine=True): 23 | """ 24 | Args: 25 | num_channels: Number of channels (C) where the expected input data format is BC1S. S stands for sequence length. 26 | clip_mag: Optional float value to use for clamping the input range before layer norm is applied. 27 | If specified, helps reduce risk of overflow. 
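Overflow is a real risk because the ANE computes in float16 (largest finite value 65504), so the intermediate sum of squares used for the variance can saturate for large activations.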
28 | eps: Small value to avoid dividing by zero 29 | elementwise_affine: If true, adds learnable channel-wise shift (bias) and scale (weight) parameters 30 | """ 31 | super().__init__() 32 | # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine) 33 | self.expected_rank = len("BC1S") 34 | 35 | self.num_channels = num_channels 36 | self.eps = eps 37 | self.clip_mag = clip_mag 38 | self.elementwise_affine = elementwise_affine 39 | 40 | if self.elementwise_affine: 41 | self.weight = nn.Parameter(torch.Tensor(num_channels)) 42 | self.bias = nn.Parameter(torch.Tensor(num_channels)) 43 | 44 | self._reset_parameters() 45 | 46 | def _reset_parameters(self): 47 | if self.elementwise_affine: 48 | nn.init.ones_(self.weight) 49 | nn.init.zeros_(self.bias) 50 | 51 | def forward(self, inputs): 52 | input_rank = len(inputs.size()) 53 | 54 | # Principle 1: Picking the Right Data Format (machinelearning.apple.com/research/apple-neural-engine) 55 | # Migrate the data format from BSC to BC1S (most conducive to ANE) 56 | if input_rank == 3 and inputs.size(2) == self.num_channels: 57 | inputs = inputs.transpose(1, 2).unsqueeze(2) 58 | input_rank = len(inputs.size()) 59 | 60 | assert input_rank == self.expected_rank 61 | assert inputs.size(1) == self.num_channels 62 | 63 | if self.clip_mag is not None: 64 | inputs.clamp_(-self.clip_mag, self.clip_mag) 65 | 66 | channels_mean = inputs.mean(dim=1, keepdims=True) 67 | 68 | zero_mean = inputs - channels_mean 69 | 70 | zero_mean_sq = zero_mean * zero_mean 71 | 72 | denom = (zero_mean_sq.mean(dim=1, keepdims=True) + self.eps).rsqrt() 73 | 74 | out = zero_mean * denom 75 | 76 | if self.elementwise_affine: 77 | out = (out + self.bias.view(1, self.num_channels, 1, 1) 78 | ) * self.weight.view(1, self.num_channels, 1, 1) 79 | 80 | return out 81 | -------------------------------------------------------------------------------- /python_coreml_stable_diffusion/mixed_bit_compression_apply.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import gc 3 | import json 4 | import logging 5 | import os 6 | 7 | import coremltools as ct 8 | import coremltools.optimize.coreml as cto 9 | import numpy as np 10 | 11 | from python_coreml_stable_diffusion.torch2coreml import get_pipeline 12 | from python_coreml_stable_diffusion.mixed_bit_compression_pre_analysis import ( 13 | NBITS, 14 | PALETTIZE_MIN_SIZE as MIN_SIZE 15 | ) 16 | 17 | 18 | logging.basicConfig() 19 | logger = logging.getLogger(__name__) 20 | logger.setLevel(logging.INFO) 21 | 22 | 23 | def main(args): 24 | # Load Core ML model 25 | coreml_model = ct.models.MLModel(args.mlpackage_path, compute_units=ct.ComputeUnit.CPU_ONLY) 26 | logger.info(f"Loaded {args.mlpackage_path}") 27 | 28 | # Load palettization recipe 29 | with open(args.pre_analysis_json_path, 'r') as f: 30 | pre_analysis = json.load(f) 31 | 32 | if args.selected_recipe not in list(pre_analysis["recipes"]): 33 | raise KeyError( 34 | f"--selected-recipe ({args.selected_recipe}) not found in " 35 | f"--pre-analysis-json-path ({args.pre_analysis_json_path}). " 36 | f" Available recipes: {list(pre_analysis['recipes'])}" 37 | ) 38 | 39 | 40 | recipe = pre_analysis["recipes"][args.selected_recipe] 41 | assert all(nbits in NBITS + [16] for nbits in recipe.values()), \ 42 | f"Some nbits values in the recipe are illegal. 
Allowed values: {NBITS}" 43 | 44 | # Hash tensors to be able to match torch tensor names to mil tensors 45 | def get_tensor_hash(tensor): 46 | assert tensor.dtype == np.float16 47 | return tensor.ravel()[0] + np.prod(tensor.shape) 48 | 49 | args.model_version = pre_analysis["model_version"] 50 | pipe = get_pipeline(args) 51 | torch_model = pipe.unet 52 | 53 | hashed_recipe = {} 54 | for torch_module_name, nbits in recipe.items(): 55 | tensor = [ 56 | tensor.cpu().numpy().astype(np.float16) for name,tensor in torch_model.named_parameters() 57 | if name == torch_module_name + '.weight' 58 | ][0] 59 | hashed_recipe[get_tensor_hash(tensor)] = nbits 60 | 61 | del pipe 62 | gc.collect() 63 | 64 | op_name_configs = {} 65 | weight_metadata = cto.get_weights_metadata(coreml_model, weight_threshold=MIN_SIZE) 66 | hashes = np.array(list(hashed_recipe)) 67 | for name, metadata in weight_metadata.items(): 68 | # Look up target bits for this weight 69 | tensor_hash = get_tensor_hash(metadata.val) 70 | pdist = np.abs(hashes - tensor_hash) 71 | assert(pdist.min() < 0.01) 72 | matched = pdist.argmin() 73 | target_nbits = hashed_recipe[hashes[matched]] 74 | 75 | if target_nbits == 16: 76 | continue 77 | 78 | op_name_configs[name] = cto.OpPalettizerConfig( 79 | mode="kmeans", 80 | nbits=target_nbits, 81 | weight_threshold=int(MIN_SIZE) 82 | ) 83 | 84 | config = ct.optimize.coreml.OptimizationConfig(op_name_configs=op_name_configs) 85 | coreml_model = ct.optimize.coreml.palettize_weights(coreml_model, config) 86 | 87 | coreml_model.save(args.o) 88 | 89 | 90 | if __name__ == "__main__": 91 | parser = argparse.ArgumentParser() 92 | parser.add_argument( 93 | "-o", 94 | required=True, 95 | help="Output directory to save the custom palettized model" 96 | ) 97 | parser.add_argument( 98 | "--mlpackage-path", 99 | required=True, 100 | help="Path to .mlpackage model to be palettized" 101 | ) 102 | parser.add_argument( 103 | "--pre-analysis-json-path", 104 | required=True, 105 | type=str, 106 | help=("The JSON file generated by mixed_bit_compression_pre_analysis.py" 107 | )) 108 | parser.add_argument( 109 | "--selected-recipe", 110 | required=True, 111 | type=str, 112 | help=("The string key into --pre-analysis-json-path's baselines dict" 113 | )) 114 | parser.add_argument( 115 | "--custom-vae-version", 116 | type=str, 117 | default=None, 118 | help= 119 | ("Custom VAE checkpoint to override the pipeline's built-in VAE. " 120 | "If specified, the specified VAE will be converted instead of the one associated to the `--model-version` checkpoint. " 121 | "No precision override is applied when using a custom VAE." 
122 | )) 123 | 124 | args = parser.parse_args() 125 | 126 | if not os.path.exists(args.mlpackage_path): 127 | raise FileNotFoundError(args.mlpackage_path) 128 | if not os.path.exists(args.pre_analysis_json_path): 129 | raise FileNotFoundError(args.pre_analysis_json_path) 130 | if not args.pre_analysis_json_path.endswith('.json'): 131 | raise ValueError("--pre-analysis-json-path should end with '.json'") 132 | 133 | main(args) 134 | -------------------------------------------------------------------------------- /python_coreml_stable_diffusion/multilingual_projection.py: -------------------------------------------------------------------------------- 1 | from python_coreml_stable_diffusion.torch2coreml import _compile_coreml_model 2 | 3 | import argparse 4 | import coremltools as ct 5 | import numpy as np 6 | import os 7 | import torch 8 | import torch.nn as nn 9 | 10 | # TODO: Read these values off of the NLContextualEmbedding API to enforce dimensions and track API versioning 11 | MAX_SEQUENCE_LENGTH = 256 12 | EMBED_DIM = 512 13 | BATCH_SIZE = 1 14 | 15 | def main(args): 16 | # Layer that was trained to map NLContextualEmbedding to your text_encoder.hidden_size dimensionality 17 | text_encoder_projection = torch.jit.load(args.input_path) 18 | 19 | # Prepare random inputs for tracing the network before conversion 20 | random_input = torch.randn(BATCH_SIZE, MAX_SEQUENCE_LENGTH, EMBED_DIM) 21 | 22 | # Create a class to bake in the reshape operations required to fit the existing model interface 23 | class TextEncoderProjection(nn.Module): 24 | def __init__(self, proj): 25 | super().__init__() 26 | self.proj = proj 27 | 28 | def forward(self, x): 29 | return self.proj(x).transpose(1, 2).unsqueeze(2) # BSC -> BC1S 30 | 31 | # Trace the torch model 32 | text_encoder_projection = torch.jit.trace(TextEncoderProjection(text_encoder_projection), (random_input,)) 33 | 34 | # Convert the model to Core ML 35 | mlpackage_path = os.path.join(args.output_dir, "MultilingualTextEncoderProjection.mlpackage") 36 | ct.convert( 37 | text_encoder_projection, 38 | inputs=[ct.TensorType('nlcontextualembeddings_output', shape=(1, MAX_SEQUENCE_LENGTH, EMBED_DIM), dtype=np.float32)], 39 | outputs=[ct.TensorType('encoder_hidden_states', dtype=np.float32)], 40 | minimum_deployment_target=ct.target.macOS14, # NLContextualEmbedding minimum availability build 41 | convert_to='mlprogram', 42 | ).save(mlpackage_path) 43 | 44 | # Compile the model and save it under the specified directory 45 | _compile_coreml_model(mlpackage_path, args.output_dir, final_name="MultilingualTextEncoderProjection") 46 | 47 | 48 | if __name__ == "__main__": 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument( 51 | "--input-path", 52 | help="Path to the torchscript file that contains the projection layer" 53 | ) 54 | parser.add_argument( 55 | "--output-dir", 56 | help="Output directory in which the Core ML model should be saved", 57 | ) 58 | args = parser.parse_args() 59 | 60 | main(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | coremltools>=8.0 2 | diffusers[torch]==0.30.2 3 | diffusionkit==0.4.0 4 | torch 5 | transformers==4.44.2 6 | scipy 7 | scikit-learn 8 | pytest 9 | invisible-watermark 10 | safetensors 11 | matplotlib 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | from
python_coreml_stable_diffusion._version import __version__ 4 | 5 | with open('README.md') as f: 6 | readme = f.read() 7 | 8 | setup( 9 | name='python_coreml_stable_diffusion', 10 | version=__version__, 11 | url='https://github.com/apple/ml-stable-diffusion', 12 | description="Run Stable Diffusion on Apple Silicon with Core ML (Python and Swift)", 13 | long_description=readme, 14 | long_description_content_type='text/markdown', 15 | author='Apple Inc.', 16 | install_requires=[ 17 | "coremltools>=8.0", 18 | "diffusers[torch]==0.30.2", 19 | "torch", 20 | "transformers==4.44.2", 21 | "huggingface-hub==0.24.6", 22 | "scipy", 23 | "numpy<1.24", 24 | "pytest", 25 | "scikit-learn", 26 | "invisible-watermark", 27 | "safetensors", 28 | "matplotlib", 29 | "diffusionkit==0.4.0", 30 | ], 31 | packages=find_packages(), 32 | classifiers=[ 33 | "Development Status :: 4 - Beta", 34 | "Intended Audience :: Developers", 35 | "Operating System :: MacOS :: MacOS X", 36 | "Programming Language :: Python :: 3", 37 | "Programming Language :: Python :: 3.7", 38 | "Programming Language :: Python :: 3.8", 39 | "Programming Language :: Python :: 3.9", 40 | "Topic :: Artificial Intelligence", 41 | "Topic :: Scientific/Engineering", 42 | "Topic :: Software Development", 43 | ], 44 | ) 45 | -------------------------------------------------------------------------------- /swift/StableDiffusion/pipeline/CGImage+vImage.swift: -------------------------------------------------------------------------------- 1 | // For licensing see accompanying LICENSE.md file. 2 | // Copyright (C) 2022 Apple Inc. All Rights Reserved. 3 | 4 | import Foundation 5 | import Accelerate 6 | import CoreML 7 | import CoreGraphics 8 | 9 | @available(iOS 16.0, macOS 13.0, *) 10 | extension CGImage { 11 | 12 | typealias PixelBufferPFx1 = vImage.PixelBuffer 13 | typealias PixelBufferP8x3 = vImage.PixelBuffer 14 | typealias PixelBufferIFx3 = vImage.PixelBuffer 15 | typealias PixelBufferI8x3 = vImage.PixelBuffer 16 | 17 | public enum ShapedArrayError: String, Swift.Error { 18 | case wrongNumberOfChannels 19 | case incorrectFormatsConvertingToShapedArray 20 | case vImageConverterNotInitialized 21 | } 22 | 23 | public static func fromShapedArray(_ array: MLShapedArray) throws -> CGImage { 24 | 25 | // array is [N,C,H,W], where C==3 26 | let channelCount = array.shape[1] 27 | guard channelCount == 3 else { 28 | throw ShapedArrayError.wrongNumberOfChannels 29 | } 30 | 31 | let height = array.shape[2] 32 | let width = array.shape[3] 33 | 34 | // Normalize each channel into a float between 0 and 1.0 35 | let floatChannels = (0.. [0.0 1.0] 46 | cIn.multiply(by: 0.5, preBias: 1.0, postBias: 0.0, destination: cOut) 47 | } 48 | return cOut 49 | } 50 | 51 | // Convert to interleaved and then to UInt8 52 | let floatImage = PixelBufferIFx3(planarBuffers: floatChannels) 53 | let uint8Image = PixelBufferI8x3(width: width, height: height) 54 | floatImage.convert(to:uint8Image) // maps [0.0 1.0] -> [0 255] and clips 55 | 56 | // Convert to uint8x3 to RGB CGImage (no alpha) 57 | let bitmapInfo = CGBitmapInfo(rawValue: CGImageAlphaInfo.none.rawValue) 58 | let cgImage = uint8Image.makeCGImage(cgImageFormat: 59 | .init(bitsPerComponent: 8, 60 | bitsPerPixel: 3*8, 61 | colorSpace: CGColorSpace(name: CGColorSpace.sRGB) ?? CGColorSpaceCreateDeviceRGB(), 62 | bitmapInfo: bitmapInfo)!)! 
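// At this point the model's planar float channels have been mapped from [-1.0, 1.0]
// into [0.0, 1.0], interleaved, and quantized to 8 bits per component, so `cgImage`
// is an 8-bit-per-channel RGB image with no alpha (sRGB when available, device RGB
// otherwise). The force unwraps above assume this fixed 8-bit RGB format is always
// accepted by CoreGraphics.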
63 | 64 | return cgImage 65 | } 66 | 67 | public func planarRGBShapedArray(minValue: Float, maxValue: Float) 68 | throws -> MLShapedArray { 69 | guard 70 | var sourceFormat = vImage_CGImageFormat(cgImage: self), 71 | var mediumFormat = vImage_CGImageFormat( 72 | bitsPerComponent: 8 * MemoryLayout.size, 73 | bitsPerPixel: 8 * MemoryLayout.size * 4, 74 | colorSpace: CGColorSpaceCreateDeviceRGB(), 75 | bitmapInfo: CGBitmapInfo(rawValue: CGImageAlphaInfo.first.rawValue)), 76 | let width = vImagePixelCount(exactly: self.width), 77 | let height = vImagePixelCount(exactly: self.height) 78 | else { 79 | throw ShapedArrayError.incorrectFormatsConvertingToShapedArray 80 | } 81 | 82 | var sourceImageBuffer = try vImage_Buffer(cgImage: self) 83 | 84 | var mediumDestination = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: mediumFormat.bitsPerPixel) 85 | 86 | let converter = vImageConverter_CreateWithCGImageFormat( 87 | &sourceFormat, 88 | &mediumFormat, 89 | nil, 90 | vImage_Flags(kvImagePrintDiagnosticsToConsole), 91 | nil) 92 | 93 | guard let converter = converter?.takeRetainedValue() else { 94 | throw ShapedArrayError.vImageConverterNotInitialized 95 | } 96 | 97 | vImageConvert_AnyToAny(converter, &sourceImageBuffer, &mediumDestination, nil, vImage_Flags(kvImagePrintDiagnosticsToConsole)) 98 | 99 | var destinationA = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout.size)) 100 | var destinationR = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout.size)) 101 | var destinationG = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout.size)) 102 | var destinationB = try vImage_Buffer(width: Int(width), height: Int(height), bitsPerPixel: 8 * UInt32(MemoryLayout.size)) 103 | 104 | var minFloat: [Float] = Array(repeating: minValue, count: 4) 105 | var maxFloat: [Float] = Array(repeating: maxValue, count: 4) 106 | 107 | vImageConvert_ARGB8888toPlanarF(&mediumDestination, &destinationA, &destinationR, &destinationG, &destinationB, &maxFloat, &minFloat, .zero) 108 | 109 | let destAPtr = destinationA.data.assumingMemoryBound(to: Float.self) 110 | let destRPtr = destinationR.data.assumingMemoryBound(to: Float.self) 111 | let destGPtr = destinationG.data.assumingMemoryBound(to: Float.self) 112 | let destBPtr = destinationB.data.assumingMemoryBound(to: Float.self) 113 | 114 | for i in 0..(data: imageData, shape: [1, 3, self.height, self.width]) 129 | 130 | return shapedArray 131 | } 132 | 133 | private func normalizePixelValues(pixel: UInt8) -> Float { 134 | return (Float(pixel) / 127.5) - 1.0 135 | } 136 | 137 | public func toRGBShapedArray(minValue: Float, maxValue: Float) 138 | throws -> MLShapedArray { 139 | let image = self 140 | let width = image.width 141 | let height = image.height 142 | let alphaMaskValue: Float = minValue 143 | 144 | guard let colorSpace = CGColorSpace(name: CGColorSpace.sRGB), 145 | let context = CGContext(data: nil, width: width, height: height, bitsPerComponent: 8, bytesPerRow: 4 * width, space: colorSpace, bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue), 146 | let ptr = context.data?.bindMemory(to: UInt8.self, capacity: width * height * 4) else { 147 | return [] 148 | } 149 | 150 | context.draw(image, in: CGRect(x: 0, y: 0, width: width, height: height)) 151 | 152 | var redChannel = [Float](repeating: 0, count: width * height) 153 | var greenChannel = [Float](repeating: 0, count: width * height) 154 | var blueChannel 
= [Float](repeating: 0, count: width * height) 155 | 156 | for y in 0..(scalars: redChannel, shape: colorShape) 174 | let greenShapedArray = MLShapedArray(scalars: greenChannel, shape: colorShape) 175 | let blueShapedArray = MLShapedArray(scalars: blueChannel, shape: colorShape) 176 | 177 | let shapedArray = MLShapedArray(concatenating: [redShapedArray, greenShapedArray, blueShapedArray], alongAxis: 1) 178 | 179 | return shapedArray 180 | } 181 | } 182 | 183 | extension vImage_Buffer { 184 | func unpaddedData() -> Data { 185 | let bytesPerPixel = self.rowBytes / Int(self.width) 186 | let bytesPerRow = Int(self.width) * bytesPerPixel 187 | 188 | var contiguousPixelData = Data(capacity: bytesPerRow * Int(self.height)) 189 | for row in 0..